import torch
import torch.nn as nn

from transformers import PreTrainedModel, PretrainedConfig


class SimpleConfig(PretrainedConfig):
    """Configuration for :class:`SimpleModel`.

    Stores the three sizes the model needs to build its layers; every other
    keyword argument is forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "simple-model"

    def __init__(self, vocab_size=100, hidden_size=32, num_labels=2, **kwargs):
        """Create the configuration.

        Args:
            vocab_size: Number of rows in the embedding table.
            hidden_size: Width of each embedding vector (and of the
                classifier input).
            num_labels: Number of output logits produced by the classifier.
            **kwargs: Passed through to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_labels = num_labels


class SimpleModel(PreTrainedModel):
    """Embedding -> mean over dim 1 -> linear classifier."""

    config_class = SimpleConfig

    def __init__(self, config):
        """Build the embedding table and classification head from ``config``.

        Args:
            config: Object providing ``vocab_size``, ``hidden_size`` and
                ``num_labels`` (normally a :class:`SimpleConfig`).
        """
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Must run after all submodules are created: important for HF weight init.
        self.post_init()

    def forward(self, input_ids: torch.Tensor) -> dict:
        """Compute classification logits for a batch of token ids.

        Args:
            input_ids: Integer tensor of token ids used to index the
                embedding table.
                NOTE(review): presumably shaped (batch, seq_len) -- confirm
                against callers; the pooling below averages over dim 1.

        Returns:
            A dict with a single key ``"logits"`` holding the classifier
            output.
        """
        embedded = self.embedding(input_ids)
        # Simple pooling: plain (unmasked) mean over dim 1, so every position
        # contributes equally to the pooled vector.
        pooled = embedded.mean(dim=1)
        logits = self.classifier(pooled)
        return {"logits": logits}