| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, PretrainedConfig |
|
|
|
|
class MakemoreConfig(PretrainedConfig):
    """Hyperparameter container for :class:`MakemoreMLP`.

    Args:
        block_size: Number of token ids in one input context; the model's first
            linear layer takes ``block_size * emb_dim`` features.
        emb_dim: Width of each token embedding.
        hidden_dim: Width of the single hidden layer.
        vocab_size: Number of embedding rows and number of output logits.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "makemore_mlp"

    def __init__(self, block_size=3, emb_dim=10, hidden_dim=200, vocab_size=27, **kwargs):
        super().__init__(**kwargs)
        # Record the architecture settings as plain attributes, in a fixed order,
        # after the base class has processed the generic kwargs.
        hyperparams = (
            ("block_size", block_size),
            ("emb_dim", emb_dim),
            ("hidden_dim", hidden_dim),
            ("vocab_size", vocab_size),
        )
        for attr_name, attr_value in hyperparams:
            setattr(self, attr_name, attr_value)
|
|
|
|
class MakemoreMLP(PreTrainedModel):
    """Fixed-context MLP language model exposed as a transformers ``PreTrainedModel``.

    For each row, the ``block_size`` token ids are embedded, the embeddings are
    laid side by side, passed through one tanh hidden layer, and projected to one
    logit per vocabulary entry.
    """

    config_class = MakemoreConfig
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        """Report that the model has no tied weights (always an empty dict)."""
        # NOTE(review): presumably shadows a PreTrainedModel attribute of the same
        # name so transformers sees no tied weights; the base class is not visible
        # here — confirm against the installed transformers version.
        return {}

    def __init__(self, config):
        super().__init__(config)
        context_features = config.block_size * config.emb_dim
        # Keep the creation order C -> W1 -> W2: it fixes parameter registration
        # order (state_dict layout) and the order random initialization happens.
        self.C = nn.Embedding(config.vocab_size, config.emb_dim)
        self.W1 = nn.Linear(context_features, config.hidden_dim, bias=True)
        self.W2 = nn.Linear(config.hidden_dim, config.vocab_size, bias=True)

    def forward(self, input_ids, labels=None):
        """Compute logits and, when targets are supplied, the cross-entropy loss.

        Args:
            input_ids: Integer tensor of token ids. All dimensions after the
                first are flattened together, so each row must hold exactly
                ``config.block_size`` ids to match ``W1``'s input width
                (presumably shape ``(batch, block_size)`` — confirm with callers).
            labels: Optional targets passed to ``F.cross_entropy`` alongside the
                logits; assumed to be shape ``(batch,)`` — TODO confirm.

        Returns:
            A dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"`` of shape ``(rows, config.vocab_size)``.
        """
        embedded = self.C(input_ids)
        # Concatenate each row's context embeddings into a single feature vector.
        flat_context = embedded.view(embedded.shape[0], -1)
        hidden = self.W1(flat_context).tanh()
        logits = self.W2(hidden)

        loss = None if labels is None else F.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}
|
|