import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_tinymodel import TinyModelConfig


class TinyCore(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear -> Softmax.

    A plain ``nn.Module`` that ``TinyModel`` embeds as ``self.core``.

    Args:
        cfg: config providing ``input_size``, ``hidden_size`` and
            ``num_labels``.
    """

    def __init__(self, cfg: TinyModelConfig):
        super().__init__()
        # These attribute names determine the state_dict keys
        # (``core.linear1.weight`` etc. once wrapped); renaming them would
        # break loading of existing checkpoints.
        self.linear1 = nn.Linear(cfg.input_size, cfg.hidden_size)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(cfg.hidden_size, cfg.num_labels)
        # Parameter-free; kept as a submodule to preserve the original layout.
        self.softmax = nn.Softmax(dim=-1)

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores.

        Args:
            x: tensor whose last dimension is ``input_size``; any leading
                dimensions are passed through by ``nn.Linear``.

        Returns:
            Tensor of shape ``(..., num_labels)`` before softmax.
        """
        x = self.linear1(x)
        x = self.activation(x)
        return self.linear2(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities (softmax over the last dimension).

        Args:
            x: tensor whose last dimension is ``input_size``.

        Returns:
            Tensor of shape ``(..., num_labels)`` whose last dimension sums
            to 1.
        """
        return self.softmax(self.logits(x))


class TinyModel(PreTrainedModel):
    """``transformers`` ``PreTrainedModel`` wrapper around :class:`TinyCore`.

    Note: :meth:`forward` returns *probabilities*, not logits. Losses such
    as ``nn.CrossEntropyLoss`` expect unnormalised scores, so use
    :meth:`get_logits` for training instead of ``forward``.
    """

    config_class = TinyModelConfig

    def __init__(self, config: TinyModelConfig):
        super().__init__(config)
        self.core = TinyCore(config)
        # transformers post-construction hook (runs weight initialisation
        # and related setup once all submodules exist).
        self.post_init()

    def forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return class probabilities for ``inputs``.

        Args:
            inputs: tensor whose last dimension is ``config.input_size``,
                typically ``(batch, config.input_size)``.
            **kwargs: accepted but ignored. Any ``labels`` passed here are
                NOT used; no loss is computed.

        Returns:
            Probabilities of shape ``(..., config.num_labels)``.
        """
        return self.core(inputs)

    def get_logits(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores (the pre-softmax output).

        Args:
            inputs: tensor whose last dimension is ``config.input_size``.

        Returns:
            Logits of shape ``(..., config.num_labels)``.
        """
        return self.core.logits(inputs)

    def predict_proba(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return class probabilities; equivalent to :meth:`forward`.

        This returns softmax probabilities, not logits; see
        :meth:`get_logits` for the latter.
        """
        return self.forward(inputs)