File size: 1,383 Bytes
465e160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1806f3
465e160
3b04946
 
 
 
465e160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig


class MakemoreConfig(PretrainedConfig):
    """Hyperparameters for the `MakemoreMLP` model defined in this file.

    Args:
        block_size: Number of token ids per example; multiplied by
            ``emb_dim`` to size the input of the model's first linear layer.
        emb_dim: Width of each embedding vector.
        hidden_dim: Output width of the first linear layer.
        vocab_size: Number of embedding rows and number of output logits.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "makemore_mlp"

    def __init__(self, block_size=3, emb_dim=10, hidden_dim=200, vocab_size=27, **kwargs):
        # Base-class initialisation runs first; the model-specific fields are
        # then stored on the instance in a fixed order.
        super().__init__(**kwargs)
        hyperparameters = (
            ("block_size", block_size),
            ("emb_dim", emb_dim),
            ("hidden_dim", hidden_dim),
            ("vocab_size", vocab_size),
        )
        for field_name, field_value in hyperparameters:
            setattr(self, field_name, field_value)


class MakemoreMLP(PreTrainedModel):
    """MLP over a fixed window of token ids.

    Pipeline: embedding lookup (``C``) -> flatten per example -> linear
    (``W1``) -> tanh -> linear (``W2``) producing one logit per vocabulary
    entry.
    """

    config_class = MakemoreConfig
    _tied_weights_keys = []

    # NOTE(review): this property and `_tied_weights_keys` above always report
    # "no tied weights". Presumably they override hooks read by the
    # `transformers` base class -- confirm against the installed version.
    @property
    def all_tied_weights_keys(self):
        """Return an empty dict (no tied weights are declared)."""
        return {}

    def __init__(self, config):
        """Create the embedding table and the two linear layers from `config`."""
        super().__init__(config)
        # Each example's embeddings are flattened into one vector before W1.
        flattened_dim = config.block_size * config.emb_dim
        self.C = nn.Embedding(config.vocab_size, config.emb_dim)
        self.W1 = nn.Linear(flattened_dim, config.hidden_dim, bias=True)
        self.W2 = nn.Linear(config.hidden_dim, config.vocab_size, bias=True)

    def forward(self, input_ids, labels=None):
        """Compute logits and, when `labels` is given, the cross-entropy loss.

        Args:
            input_ids: Token ids, expected as (batch, block_size).
            labels: Optional targets forwarded to ``F.cross_entropy``
                together with the logits.

        Returns:
            A dict with keys ``"loss"`` (``None`` when `labels` is ``None``)
            and ``"logits"`` of shape (batch, vocab_size).
        """
        embedded = self.C(input_ids)                     # (B, block_size, emb_dim)
        # Keep the leading (batch) dimension and merge everything after it.
        flat = embedded.view(embedded.shape[0], -1)      # (B, block_size * emb_dim)
        hidden = self.W1(flat).tanh()                    # (B, hidden_dim)
        logits = self.W2(hidden)                         # (B, vocab_size)

        loss = F.cross_entropy(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}