chudai1019 committed on
Commit
ad4e7b1
·
verified ·
1 Parent(s): eae2184

Update model_transformer.py

Browse files
Files changed (1) hide show
  1. model_transformer.py +64 -24
model_transformer.py CHANGED
@@ -1,37 +1,77 @@
 
 
 
1
  import torch
2
  import torch.nn as nn
3
 
4
  class PositionalEncoding(nn.Module):
5
- def __init__(self, emb, max_len=2048):
6
  super().__init__()
7
- pe = torch.zeros(max_len, emb)
8
- pos = torch.arange(0, max_len).unsqueeze(1)
9
- div = torch.exp(torch.arange(0, emb, 2) * (-torch.log(torch.tensor(10000.0)) / emb))
10
- pe[:, 0::2] = torch.sin(pos * div)
11
- pe[:, 1::2] = torch.cos(pos * div)
12
- self.pe = pe.unsqueeze(0)
13
 
14
  def forward(self, x):
15
- return x + self.pe[:, :x.size(1), :].to(x.device)
 
 
16
 
17
  class TransformerLM(nn.Module):
18
- def __init__(self, vocab_size, emb=256, n_heads=4, n_layers=4):
 
 
 
19
  super().__init__()
20
- self.embed = nn.Embedding(vocab_size, emb)
21
- self.pos = PositionalEncoding(emb)
 
 
 
 
 
22
 
23
- encoder_layer = nn.TransformerEncoderLayer(
24
- d_model=emb,
25
- nhead=n_heads,
26
- dim_feedforward=512,
27
- batch_first=True
28
- )
29
 
30
- self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
31
- self.head = nn.Linear(emb, vocab_size)
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- def forward(self, x):
34
- x = self.embed(x)
35
- x = self.pos(x)
36
- x = self.transformer(x)
37
- return self.head(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_transformer.py
2
+ # Requires: pip install torch
3
+ import math
4
  import torch
5
  import torch.nn as nn
6
 
7
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position embeddings (Vaswani et al., 2017).

    Adds a precomputed sin/cos position table to the input embeddings.

    Args:
        d_model: embedding width. Both even and odd widths are supported
            (odd widths get one extra sin slot, matching the slicing below).
        max_len: longest sequence length the table can serve.
    """

    def __init__(self, d_model, max_len=2048):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency schedule: 10000^(-2i/d_model) for i = 0..d/2-1.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # BUGFIX: slice div_term so odd d_model works — the cos columns number
        # floor(d_model/2), one fewer than div_term when d_model is odd.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer (not a Parameter): moves with .to(device)/state_dict but is
        # excluded from gradient updates.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Add positions to x.

        Args:
            x: (B, T, D) float tensor of token embeddings.
        Returns:
            (B, T, D) tensor, x plus the first T rows of the position table.
        Raises:
            ValueError: if T exceeds max_len (previously a cryptic broadcast
                error deep inside the addition).
        """
        L = x.size(1)
        if L > self.pe.size(1):
            raise ValueError(
                f"sequence length {L} exceeds PositionalEncoding max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :L, :]
21
 
22
class TransformerLM(nn.Module):
    """Decoder-style autoregressive language model.

    Built from a TransformerEncoder stack plus an explicit causal mask, so
    each position can only attend to itself and earlier positions — required
    for the next-token training/sampling this class performs in `generate`.

    Args:
        vocab_size: size of the token vocabulary.
        d_model: embedding / hidden width (384 with 4 layers is a reasonable
            ~10M-parameter model depending on vocab).
        nhead: number of attention heads (must divide d_model).
        num_layers: number of transformer layers.
        dim_feedforward: hidden width of each layer's MLP.
        dropout: dropout probability inside the transformer layers.
        pad_id: token id used for padding; excluded from attention and
            zero-embedded via padding_idx.
    """

    def __init__(self, vocab_size, d_model=384, nhead=8, num_layers=4,
                 dim_feedforward=1536, dropout=0.1, pad_id=0):
        super().__init__()
        self.pad_id = pad_id
        self.tok_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos_enc = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # GPT-style small-std init for embedding and output projection.
        nn.init.normal_(self.tok_embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.head.weight, mean=0.0, std=0.02)

    def forward(self, input_ids):
        """Compute next-token logits.

        Args:
            input_ids: (B, T) LongTensor of token ids.
        Returns:
            logits: (B, T, V) float tensor.
        """
        T = input_ids.size(1)
        x = self.tok_embedding(input_ids)  # (B, T, D)
        x = self.pos_enc(x)
        # BUGFIX: an autoregressive LM must not attend to future positions.
        # Upper-triangular bool mask, True = attention blocked.
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=input_ids.device), diagonal=1
        )
        # key_padding_mask: (B, T) bool, True = pad token (ignored by attention).
        key_padding_mask = input_ids == self.pad_id
        x = self.transformer(x, mask=causal_mask, src_key_padding_mask=key_padding_mask)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, tokenizer, device, prompt, max_new_tokens=64,
                 temperature=1.0, top_k=40):
        """Sample tokens autoregressively from a text prompt.

        Feeds the running sequence through the model each step and samples
        the next token from the last position's distribution (simple and
        adequate for small models).

        Args:
            tokenizer: object with encode(str) -> ids and decode(ids) -> str.
            device: torch device to run on.
            prompt: seed text.
            max_new_tokens: number of tokens to append.
            temperature: softmax temperature; clamped away from zero.
            top_k: keep only the k most likely tokens (None/0 disables).
        Returns:
            Decoded string of prompt + generated tokens.
        """
        self.eval()
        ids = [i for i in tokenizer.encode(prompt) if i is not None]
        input_ids = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)  # (1, T)
        # Positional table size bounds the usable context window.
        max_ctx = self.pos_enc.pe.size(1)
        for _ in range(max_new_tokens):
            # Crop to the context window so long generations cannot overflow
            # the positional encoding.
            logits = self(input_ids[:, -max_ctx:])  # (1, T', V)
            next_logits = logits[:, -1, :] / max(temperature, 1e-8)
            if top_k is not None and top_k > 0:
                k = min(top_k, next_logits.size(-1))
                topk_vals, topk_idx = torch.topk(next_logits, k)
                # Renormalize over the top-k only; everything else gets prob 0.
                probs = torch.zeros_like(next_logits).scatter_(
                    1, topk_idx, nn.functional.softmax(topk_vals, dim=-1)
                )
            else:
                probs = nn.functional.softmax(next_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1).item()
            input_ids = torch.cat(
                [input_ids, torch.tensor([[next_id]], device=device)], dim=1
            )
        return tokenizer.decode(input_ids.squeeze(0).tolist())