File size: 927 Bytes
2216aae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
class SimpleConfig(PretrainedConfig):
    """Configuration for ``SimpleModel``: an embedding table plus a linear head."""

    model_type = "simple-model"

    def __init__(self, vocab_size=100, hidden_size=32, num_labels=2, **kwargs):
        """Store the model hyper-parameters.

        Args:
            vocab_size: Number of rows in the embedding table.
            hidden_size: Width of each embedding vector.
            num_labels: Number of output classes. Ignored when an
                ``id2label`` mapping is supplied in ``kwargs``; the label
                count is then derived from that mapping.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        # `num_labels` on PretrainedConfig is a property derived from
        # `id2label`; assigning it after the base constructor ran would
        # regenerate `id2label`/`label2id` whenever the lengths differ.
        # A serialized config carries `id2label` but not `num_labels`, so
        # on reload the default `num_labels=2` would silently overwrite a
        # saved label map with more than two entries. Let the base class
        # resolve the label count instead, and only hand it `num_labels`
        # when no explicit label map is present.
        if kwargs.get("id2label") is None:
            kwargs["num_labels"] = num_labels
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
class SimpleModel(PreTrainedModel):
    """Minimal classifier: embed token ids, mean-pool, apply a linear head."""

    config_class = SimpleConfig

    def __init__(self, config):
        """Build the embedding table and classification head from ``config``.

        Args:
            config: Provides ``vocab_size``, ``hidden_size`` and
                ``num_labels``.
        """
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()  # important for HF weight init

    def forward(self, input_ids, attention_mask=None):
        """Compute classification logits for a batch of token ids.

        Args:
            input_ids: Integer tensor of token indices. Axis 1 is the one
                that gets averaged (assumed to be the sequence axis, i.e.
                shape ``(batch, seq_len)`` -- confirm against callers).
            attention_mask: Optional tensor with the same shape as
                ``input_ids``; non-zero marks positions to include in the
                average. When omitted, every position is averaged, which
                matches the previous behavior.

        Returns:
            A dict with a single key ``"logits"`` holding the classifier
            output.
        """
        hidden = self.embedding(input_ids)
        if attention_mask is None:
            pooled = hidden.mean(dim=1)  # simple pooling
        else:
            # Masked mean: padded positions contribute nothing to the sum
            # and are excluded from the divisor.
            weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
            # Clamp so a fully-masked row yields zeros instead of NaN.
            token_counts = weights.sum(dim=1).clamp(min=1.0)
            pooled = (hidden * weights).sum(dim=1) / token_counts
        logits = self.classifier(pooled)
        return {"logits": logits}
|