# Source: modeling_simple/modeling_simple.py
# Uploaded by jonl521 via huggingface_hub ("Upload folder using
# huggingface_hub"), commit 2216aae.
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
class SimpleConfig(PretrainedConfig):
    """Configuration for the ``simple-model`` bag-of-embeddings classifier.

    Args:
        vocab_size: Number of token ids, i.e. rows of the embedding table.
        hidden_size: Embedding width; also the classifier's input size.
        num_labels: Number of output classes. If ``id2label`` is supplied in
            ``kwargs`` (as happens when a saved config is reloaded), that
            mapping is authoritative and determines the label count instead.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "simple-model"

    def __init__(self, vocab_size=100, hidden_size=32, num_labels=2, **kwargs):
        # On PretrainedConfig, `num_labels` is a property derived from
        # `id2label`; assigning it after super().__init__ would replace a
        # caller-supplied (or reloaded) id2label of a different length with
        # generic LABEL_i entries. Let the base constructor resolve it, and
        # only pass our value when no label mapping was provided.
        if kwargs.get("id2label") is None:
            kwargs["num_labels"] = num_labels
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
class SimpleModel(PreTrainedModel):
    """Bag-of-embeddings sequence classifier.

    Embeds each token id, mean-pools the embeddings over dimension 1 (the
    sequence axis), and maps the pooled vector to ``config.num_labels``
    logits with a single linear layer.
    """

    config_class = SimpleConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Hugging Face hook expected at the end of a PreTrainedModel's
        # __init__ (weight initialization and related final setup).
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Compute classification logits and, optionally, a loss.

        Args:
            input_ids: Integer token ids, presumably ``(batch, seq_len)``;
                pooling is over dimension 1.
            attention_mask: Optional mask with the same shape as
                ``input_ids``: 1 for real tokens, 0 for padding. When given,
                padding positions are excluded from the mean. When ``None``,
                every position is averaged.
            labels: Optional integer class indices, one per example. When
                given, a cross-entropy loss is included in the output.

        Returns:
            dict with ``"logits"`` (one row of ``num_labels`` scores per
            example) and, when ``labels`` is not ``None``, ``"loss"``.
        """
        x = self.embedding(input_ids)
        if attention_mask is None:
            pooled = x.mean(dim=1)  # simple pooling over all positions
        else:
            # Broadcast the mask over the hidden dimension and average only
            # the unmasked positions; clamp avoids 0/0 for all-padding rows.
            mask = attention_mask.unsqueeze(-1).to(x.dtype)
            pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        logits = self.classifier(pooled)
        outputs = {"logits": logits}
        if labels is not None:
            outputs["loss"] = nn.functional.cross_entropy(logits, labels)
        return outputs