File size: 1,228 Bytes
db32fc7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_tinymodel import TinyModelConfig
class TinyCore(nn.Module):
    """Two-layer feed-forward classifier that outputs class probabilities.

    The data path is ``linear1`` -> ReLU -> ``linear2`` -> softmax over the
    last dimension. It is a plain ``nn.Module`` so that it can be embedded
    inside a wrapper model.

    Args:
        cfg: Configuration object. Its ``input_size``, ``hidden_size`` and
            ``num_labels`` attributes size the two linear layers.
    """

    def __init__(self, cfg: TinyModelConfig):
        super().__init__()
        # Attribute names are also the state-dict key prefixes
        # ("linear1.weight", ...), so they must stay stable.
        self.linear1 = nn.Linear(
            in_features=cfg.input_size,
            out_features=cfg.hidden_size,
        )
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(
            in_features=cfg.hidden_size,
            out_features=cfg.num_labels,
        )
        # Normalise across the label axis (the last dimension).
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor):
        """Map input features to a probability distribution over labels.

        Args:
            x: Tensor whose last dimension equals ``cfg.input_size``.

        Returns:
            Tensor whose last dimension equals ``cfg.num_labels``; the values
            along that dimension are softmax-normalised (they sum to 1).
        """
        hidden = self.activation(self.linear1(x))
        logits = self.linear2(hidden)
        return self.softmax(logits)
class TinyModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``TinyCore``.

    All computation is delegated to the ``core`` submodule; this class ties
    the network to ``TinyModelConfig`` through ``config_class``.

    Args:
        config: Model configuration, forwarded both to ``PreTrainedModel``
            and to ``TinyCore``.
    """

    config_class = TinyModelConfig

    def __init__(self, config: TinyModelConfig):
        super().__init__(config)
        self.core = TinyCore(config)
        # PreTrainedModel hook, called once every submodule exists
        # (initializes weights if needed).
        self.post_init()

    def forward(self, inputs: torch.Tensor, **kwargs):
        """Run the core network on ``inputs``.

        Args:
            inputs: Tensor expected to have shape
                ``(batch, config.input_size)``.
            **kwargs: Accepted but ignored; nothing is passed on to the core.

        Returns:
            The output of ``TinyCore`` for ``inputs``. The core ends in a
            softmax, so these are class probabilities, not raw logits.
        """
        probabilities = self.core(inputs)
        return probabilities

    def predict_proba(self, inputs: torch.Tensor):
        """Return class probabilities for ``inputs``.

        Optional convenience alias: it calls ``forward`` directly and returns
        its result unchanged.
        """
        return self.forward(inputs)
|