adityashisharma committed (verified)
Commit 348cbf2 · 1 Parent(s): 70ae50e

Upload tiny_gpt2.py

Files changed (1)
  1. model/tiny_gpt2.py +120 -0
model/tiny_gpt2.py ADDED
@@ -0,0 +1,120 @@
# Minimal GPT-2-ish decoder-only LM, written for clarity.
from dataclasses import dataclass
import math
import torch
import torch.nn as nn


@dataclass
class GPTConfig:
    vocab_size: int = 16000
    n_layer: int = 6
    n_head: int = 6
    n_embed: int = 384
    block_size: int = 256
    attn_pdrop: float = 0.0
    resid_pdrop: float = 0.0


class CausalSelfAttention(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.n_embed % cfg.n_head == 0
        self.n_head = cfg.n_head
        self.key = nn.Linear(cfg.n_embed, cfg.n_embed, bias=False)
        self.query = nn.Linear(cfg.n_embed, cfg.n_embed, bias=False)
        self.value = nn.Linear(cfg.n_embed, cfg.n_embed, bias=False)
        self.proj = nn.Linear(cfg.n_embed, cfg.n_embed, bias=False)
        self.attn_drop = nn.Dropout(cfg.attn_pdrop)
        self.resid_drop = nn.Dropout(cfg.resid_pdrop)
        # Lower-triangular mask so each position attends only to itself and earlier positions.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(cfg.block_size, cfg.block_size)).view(1, 1, cfg.block_size, cfg.block_size),
        )

    def forward(self, x):
        B, T, C = x.size()
        H = self.n_head
        # Project and reshape to (B, H, T, head_dim) for multi-head attention.
        k = self.key(x).view(B, T, H, C // H).transpose(1, 2)
        q = self.query(x).view(B, T, H, C // H).transpose(1, 2)
        v = self.value(x).view(B, T, H, C // H).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = torch.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        return y


class Block(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(cfg.n_embed)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = nn.LayerNorm(cfg.n_embed)
        self.mlp = nn.Sequential(
            nn.Linear(cfg.n_embed, 4 * cfg.n_embed),
            nn.GELU(),
            nn.Linear(4 * cfg.n_embed, cfg.n_embed),
            nn.Dropout(cfg.resid_pdrop),
        )

    def forward(self, x):
        # Pre-norm residual connections, as in GPT-2.
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class TinyGPT2(nn.Module):
    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embed)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embed)
        self.drop = nn.Dropout(cfg.resid_pdrop)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.n_embed)
        self.head = nn.Linear(cfg.n_embed, cfg.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=64, top_k=50, top_p=0.95, temperature=1.0):
        self.eval()
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens.
            idx_cond = idx[:, -self.cfg.block_size:]
            logits = self(idx_cond)[:, -1, :] / max(temperature, 1e-5)
            logits = self._top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
        return idx

    @staticmethod
    def _top_k_top_p_filtering(logits, top_k=0, top_p=1.0):
        if top_k and top_k > 0:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("inf")
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumprobs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            # Mask tokens beyond the nucleus, always keeping at least the top-1 token.
            idx = cumprobs > top_p
            idx[..., 1:] = idx[..., :-1].clone()
            idx[..., 0] = 0
            sorted_logits[idx] = -float("inf")
            logits.scatter_(1, sorted_indices, sorted_logits)
        return logits

    def forward(self, idx):
        B, T = idx.size()
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        x = self.drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)
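
A minimal sketch of how the uploaded module might be exercised, assuming the repo root is on the Python path so `model.tiny_gpt2` is importable. Random token ids stand in for a real tokenizer (none is included in this file), and the cross-entropy loss on shifted targets only illustrates how a training step would consume the logits returned by `forward`.

# Sketch only: random ids instead of a tokenizer; import path assumes repo root on PYTHONPATH.
import torch
import torch.nn.functional as F

from model.tiny_gpt2 import GPTConfig, TinyGPT2

cfg = GPTConfig()
model = TinyGPT2(cfg)

# Forward pass plus a next-token cross-entropy on shifted targets.
tokens = torch.randint(0, cfg.vocab_size, (2, 32))   # (batch, time)
logits = model(tokens)                                # (batch, time, vocab_size)
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, cfg.vocab_size),
    tokens[:, 1:].reshape(-1),
)

# Sampling: start from a single token id and extend with top-k / nucleus sampling.
prompt = torch.zeros((1, 1), dtype=torch.long)
out = model.generate(prompt, max_new_tokens=16, top_k=50, top_p=0.95, temperature=1.0)
print(loss.item(), out.shape)  # out has shape (1, 17)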