import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
class SimpleConfig(PretrainedConfig):
    """Configuration for :class:`SimpleModel`.

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_size: Dimensionality of the token embeddings.
        num_labels: Number of classification labels produced by the head.
        **kwargs: Forwarded verbatim to ``PretrainedConfig``.
    """

    model_type = "simple-model"

    def __init__(self, vocab_size=100, hidden_size=32, num_labels=2, **kwargs):
        # Hand any unrecognized options to the HF base config first so its
        # bookkeeping (e.g. label maps) is set up before our own fields.
        super().__init__(**kwargs)
        self.num_labels = num_labels
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
class SimpleModel(PreTrainedModel):
    """Minimal HF-compatible classifier: embed -> mean-pool -> linear head.

    The original ``forward`` averaged over *every* sequence position, so
    padding tokens contaminated the pooled representation. An optional
    ``attention_mask`` parameter now enables masked mean pooling; omitting
    it preserves the original behavior exactly.
    """

    config_class = SimpleConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()  # important for HF weight init

    def forward(self, input_ids, attention_mask=None):
        """Compute classification logits.

        Args:
            input_ids: Long tensor of token ids, shape ``(batch, seq)``.
            attention_mask: Optional 0/1 tensor of shape ``(batch, seq)``;
                when given, positions with 0 are excluded from pooling.

        Returns:
            Dict with ``"logits"`` of shape ``(batch, num_labels)``.
        """
        x = self.embedding(input_ids)  # (batch, seq, hidden)
        if attention_mask is None:
            pooled = x.mean(dim=1)  # simple pooling (original behavior)
        else:
            mask = attention_mask.unsqueeze(-1).to(x.dtype)
            # Masked mean; clamp guards against all-padding rows (div by 0).
            pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        logits = self.classifier(pooled)
        return {"logits": logits}