File size: 1,383 Bytes
465e160 e1806f3 465e160 3b04946 465e160 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 | import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
class MakemoreConfig(PretrainedConfig):
    """Hyperparameters for :class:`MakemoreMLP`.

    Args:
        block_size: Number of context tokens fed to the model per prediction.
        emb_dim: Width of each token embedding vector.
        hidden_dim: Width of the hidden layer.
        vocab_size: Number of distinct token ids (rows of the embedding table).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "makemore_mlp"

    def __init__(
        self,
        block_size: int = 3,
        emb_dim: int = 10,
        hidden_dim: int = 200,
        vocab_size: int = 27,
        **kwargs,
    ):
        # Let the base config consume its own options first.
        super().__init__(**kwargs)
        # Record the model hyperparameters as plain attributes; targets are
        # assigned left to right, i.e. block, embedding, hidden, vocabulary.
        (self.block_size, self.emb_dim, self.hidden_dim, self.vocab_size) = (
            block_size,
            emb_dim,
            hidden_dim,
            vocab_size,
        )
class MakemoreMLP(PreTrainedModel):
    """Next-token MLP over a fixed-length context, packaged as a transformers model.

    Each context token is looked up in the embedding table ``C``. The
    embeddings are concatenated and passed through a tanh hidden layer
    (``W1``), then projected to one logit per vocabulary entry (``W2``).
    """

    config_class = MakemoreConfig
    # This model declares no tied weights.
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        # NOTE(review): presumably overrides an attribute that some
        # transformers versions query on PreTrainedModel; an empty mapping
        # reports "no tied weights" -- confirm against the pinned version.
        return {}

    def __init__(self, config):
        """Build the embedding, hidden and output layers from ``config``.

        Args:
            config: A ``MakemoreConfig`` supplying ``vocab_size``,
                ``emb_dim``, ``block_size`` and ``hidden_dim``.
        """
        super().__init__(config)
        # Token embedding table: one emb_dim vector per vocabulary entry.
        self.C = nn.Embedding(config.vocab_size, config.emb_dim)
        # Hidden layer reads the concatenated embeddings of the whole context.
        self.W1 = nn.Linear(
            config.block_size * config.emb_dim, config.hidden_dim, bias=True
        )
        # Output layer: one logit per vocabulary entry.
        self.W2 = nn.Linear(config.hidden_dim, config.vocab_size, bias=True)
        # NOTE(review): transformers models usually end __init__ with
        # self.post_init(); it is not called here -- confirm that is intended.

    def forward(self, input_ids, labels=None):
        """Predict next-token logits from a context window of token ids.

        Args:
            input_ids: Integer token ids whose last dimension is the context
                window, typically ``(batch, block_size)``. A single unbatched
                window ``(block_size,)``, an empty batch, or several leading
                dimensions are also accepted.
            labels: Optional targets for ``F.cross_entropy``. Either class
                indices with shape ``input_ids.shape[:-1]``, or class
                probabilities with shape ``(*input_ids.shape[:-1], vocab_size)``.

        Returns:
            dict with ``"loss"`` (mean cross-entropy, or ``None`` when no
            ``labels`` are given) and ``"logits"`` of shape
            ``(*input_ids.shape[:-1], vocab_size)``.
        """
        emb = self.C(input_ids)  # (*leading, block_size, emb_dim)
        # Concatenate the context embeddings by flattening only the last two
        # axes. Unlike emb.view(emb.size(0), -1), this works for an empty
        # batch (view cannot infer -1 from zero elements). It also works for
        # unbatched input or extra leading dims.
        flat = emb.flatten(start_dim=-2)  # (*leading, block_size * emb_dim)
        h = torch.tanh(self.W1(flat))  # (*leading, hidden_dim)
        logits = self.W2(h)  # (*leading, vocab_size)
        loss = None
        if labels is not None:
            num_classes = logits.size(-1)
            lead_ndim = logits.dim() - 1
            # F.cross_entropy treats dim 1 as the class axis, so fold every
            # leading dim into one batch axis. Any label dims past the
            # leading ones (the class axis of probability targets) are kept.
            # For the usual 2-D batch both reshapes are no-ops.
            loss = F.cross_entropy(
                logits.reshape(-1, num_classes),
                labels.reshape(-1, *labels.shape[lead_ndim:]),
            )
        return {"loss": loss, "logits": logits}
|