import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_simple_model import SimpleNNConfig


class SimpleNN(PreTrainedModel):
    """Single-linear-layer model wrapped as a Hugging Face ``PreTrainedModel``.

    The layer is sized from ``config.input_size`` (input features) and
    ``config.num_classes`` (output features). Subclassing ``PreTrainedModel``
    provides the ``save_pretrained`` / ``from_pretrained`` machinery.
    """

    # Tells transformers which configuration class pairs with this model
    # (used e.g. when loading via ``from_pretrained``).
    config_class = SimpleNNConfig

    def __init__(self, config: SimpleNNConfig) -> None:
        """Build the model from *config*.

        Args:
            config: Configuration providing ``input_size`` (number of input
                features) and ``num_classes`` (number of output features).
        """
        super().__init__(config)
        # The attribute name ``dense`` determines the state-dict keys
        # (``dense.weight`` / ``dense.bias``); renaming it would break
        # loading of previously saved checkpoints.
        self.dense = nn.Linear(config.input_size, config.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to *x*.

        Args:
            x: Tensor whose last dimension has size ``config.input_size``.

        Returns:
            Tensor with the same leading dimensions as *x* and a last
            dimension of size ``config.num_classes``. No activation or
            softmax is applied here, so the values are raw (unnormalized)
            scores.
        """
        return self.dense(x)