|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import PreTrainedModel |
|
|
from .configuration_tinymodel import TinyModelConfig |
|
|
|
|
|
class TinyCore(nn.Module):
    """Two-layer MLP classifier used as the backbone of ``TinyModel``.

    Pipeline: ``linear1`` -> ReLU -> ``linear2`` -> softmax over the last
    dimension, so the module emits per-label probabilities, not raw logits.
    """

    def __init__(self, cfg: TinyModelConfig):
        super().__init__()
        in_features = cfg.input_size
        hidden_features = cfg.hidden_size
        out_features = cfg.num_labels
        # Attribute names (and their registration order) determine the
        # state_dict keys such as ``linear1.weight``; keep them stable so
        # previously saved checkpoints still load.
        self.linear1 = nn.Linear(in_features, hidden_features)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(hidden_features, out_features)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map input features to label probabilities.

        Args:
            x: Tensor whose last dimension equals ``cfg.input_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``cfg.num_labels``; entries along that last
            dimension are softmax-normalized.
        """
        hidden = self.activation(self.linear1(x))
        logits = self.linear2(hidden)
        return self.softmax(logits)
|
|
|
|
|
class TinyModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``TinyCore`` MLP.

    ``config_class`` ties this model to ``TinyModelConfig`` so the
    ``transformers`` loading/saving machinery knows which config to build.
    """

    config_class = TinyModelConfig

    def __init__(self, config: TinyModelConfig):
        super().__init__(config)
        # Stored under ``core``, so every parameter name is prefixed "core.".
        self.core = TinyCore(config)
        # transformers' post-construction hook (e.g. weight initialization);
        # must run after all submodules have been created.
        self.post_init()

    def forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run the core MLP on ``inputs``.

        Args:
            inputs: Expected shape ``(batch, config.input_size)``.
            **kwargs: Accepted so callers may pass extra keyword arguments;
                they are currently ignored (no loss is computed here).

        Returns:
            Tensor of shape ``(batch, config.num_labels)`` containing
            softmax probabilities (already normalized -- not logits).
        """
        return self.core(inputs)

    def predict_proba(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return label probabilities for ``inputs``.

        Calls the module instance (``self(...)``) rather than
        ``self.forward(...)`` so that any registered forward hooks and
        pre-hooks run, as ``nn.Module.__call__`` intends. Gradient tracking
        and train/eval mode are left to the caller (e.g. use
        ``torch.no_grad()`` and ``model.eval()`` for inference).
        """
        return self(inputs)
|
|
|