mcrimi commited on
Commit
8873ab6
·
verified ·
1 Parent(s): 4b7659f

Upload gpt.py

Browse files
Files changed (1) hide show
  1. gpt.py +198 -0
gpt.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from dataclasses import dataclass
5
+
6
+
7
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model defined in this file."""

    vocab_size: int  # number of distinct tokens (size of embedding table and lm_head)
    block_size: int  # maximum context length (side of the causal mask)
    n_embd: int  # embedding dimension (model width)
    n_head: int  # number of attention heads per transformer block
    n_layer: int  # number of transformer blocks
    dropout: float = 0.0  # dropout probability used in attention and MLP layers
    device: str = "cpu"  # "cpu" or "cuda" or "mps"
    # NOTE(review): `device` is not read by any code visible in this file —
    # presumably consumed by the training/inference driver; confirm at call sites.
16
+
17
+
18
class Head(nn.Module):
    """One head of causal self-attention.

    Projects the input into key/query/value spaces of size ``head_size``,
    computes scaled dot-product attention under a lower-triangular mask
    (each position attends only to itself and earlier positions), and
    returns the attention-weighted values.

    Input:  x of shape (B, T, n_embd), with T <= config.block_size.
    Output: tensor of shape (B, T, head_size).
    """

    def __init__(self, config, head_size):
        super().__init__()
        self.key = nn.Linear(config.n_embd, head_size, bias=False)
        self.query = nn.Linear(config.n_embd, head_size, bias=False)
        self.value = nn.Linear(config.n_embd, head_size, bias=False)
        # Causal mask: tril[i, j] == 1 iff j <= i. Registered as a buffer so
        # it moves with the module across devices but is not a parameter.
        self.register_buffer(
            "tril", torch.tril(torch.ones(config.block_size, config.block_size))
        )
        self.dropout = nn.Dropout(config.dropout)
        self.config = config

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)  # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        # Compute attention scores ("affinities").
        # BUGFIX: scale by 1/sqrt(head_size) — the d_k of "Attention Is All
        # You Need" — not 1/sqrt(n_embd) as before; the dot products run over
        # head_size dimensions, so that is the variance-normalizing factor.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        # Mask out future positions, then normalize to attention weights.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        v = self.value(x)  # (B, T, hs)
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out
44
+
45
+
46
class MultiHeadAttention(nn.Module):
    """Several self-attention heads run side by side, then projected back."""

    def __init__(self, config, head_size):
        super().__init__()
        self.heads = nn.ModuleList(
            Head(config, head_size) for _ in range(config.n_head)
        )
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Run every head on the same input and stitch the results back
        # together along the channel dimension, then mix with a projection.
        per_head = [head(x) for head in self.heads]
        combined = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(combined))
61
+
62
+
63
class FeedFoward(nn.Module):
    """Position-wise MLP: expand to 4x width, ReLU, project back, dropout.

    (Class name kept exactly as in the original file — note the missing 'r'
    in 'Forward' — so existing callers and checkpoints keep working.)
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.net = nn.Sequential(
            nn.Linear(config.n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        out = self.net(x)
        return out
77
+
78
+
79
class Block(nn.Module):
    """Transformer block: attention ("communication") then MLP ("computation").

    Pre-norm residual layout: each sub-layer reads a LayerNorm'd copy of the
    stream and its result is added back onto the residual.
    """

    def __init__(self, config):
        super().__init__()
        self.sa = MultiHeadAttention(config, config.n_embd // config.n_head)
        self.ffwd = FeedFoward(config)
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

    def forward(self, x):
        attended = self.sa(self.ln1(x))
        x = x + attended
        computed = self.ffwd(self.ln2(x))
        return x + computed
94
+
95
+
96
class GPT(nn.Module):
    """Decoder-only transformer language model (GPT-style).

    Token + learned position embeddings feed a stack of pre-norm transformer
    blocks, a final LayerNorm, and a linear head producing per-token logits
    over the vocabulary.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)  # final layer norm
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and (optionally) cross-entropy loss.

        Args:
            idx: (B, T) integer token ids, T <= config.block_size.
            targets: optional (B, T) integer token ids to score against.

        Returns:
            (logits, loss): logits is (B, T, vocab_size) when targets is
            None, otherwise flattened to (B*T, vocab_size); loss is None
            when targets is None, else a scalar cross-entropy tensor.
        """
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        # position ids are created on idx's device so the model works on any device
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=idx.device)
        )  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C)
        x = self.blocks(x)  # (B,T,C)
        x = self.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            # flatten batch and time so F.cross_entropy sees (N, C) vs (N,)
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens, stop_token_id=None):
        """Greedily extend idx by up to max_new_tokens tokens.

        Args:
            idx: (B, T) integer context tokens.
            max_new_tokens: maximum number of tokens to append.
            stop_token_id: if given, stop early once the generated token
                equals it (for B > 1: once every row generated it).

        Returns:
            (B, T + k) tensor, k <= max_new_tokens.
        """
        for _ in range(max_new_tokens):
            # crop context to the last block_size tokens
            idx_cond = idx[:, -self.config.block_size :]

            # get the predictions
            logits, _ = self(idx_cond)

            # get probabilities for the last position only
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)

            # greedy decoding: take the single most probable token
            _, idx_next = torch.topk(probs, k=1, dim=-1)

            # append to the sequence and keep going
            idx = torch.cat((idx, idx_next), dim=1)

            # stopping rule to avoid unnecessary inference.
            # BUGFIX: the original used idx_next.item(), which raises for
            # batch size > 1; .all() is identical for B == 1 and batch-safe.
            if stop_token_id is not None and bool((idx_next == stop_token_id).all()):
                # every row emitted the stop token, so we stop inference
                return idx
            # -------------------------------

        return idx

    def train_step(self, optimizer, idx, target_idx, importance_weight=1.0):
        """
        Single training step for RL correction.

        Args:
            optimizer: torch optimizer over this model's parameters.
            idx: (B, T) tensor of context inputs.
            target_idx: (B,), (B, 1), or scalar tensor of the target token
                to predict at the last position.
            importance_weight: float multiplier applied to the loss.

        Returns:
            (loss_value, probs): the weighted mean loss as a Python float,
            and the (B, vocab_size) softmax over the last-token logits.
        """
        self.train()
        optimizer.zero_grad()

        # 1. Forward Pass
        # We only care about the last token prediction for the loss;
        # the input 'idx' should be the full context up to the target.
        logits, _ = self(idx)

        # Get the logits for the VERY LAST token (the one we are trying to predict)
        last_token_logits = logits[:, -1, :]  # Shape: (B, VocabSize)

        # 2. Loss Calculation — normalize target_idx to shape (B,)
        if target_idx.dim() == 2:
            target_idx = target_idx.squeeze(-1)
        elif target_idx.dim() == 0:
            # BUGFIX: the docstring allows a scalar tensor, but a 0-dim
            # target would crash F.cross_entropy against (B, V) logits.
            target_idx = target_idx.unsqueeze(0)

        loss = F.cross_entropy(last_token_logits, target_idx, reduction="none")

        # Apply importance weight (e.g. an RL correction factor)
        weighted_loss = loss * importance_weight
        final_loss = weighted_loss.mean()

        # 3. Update
        final_loss.backward()

        # Clip gradients to prevent explosion during online updates
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=0.5)

        optimizer.step()

        # Return probs for visualization (no grad needed here)
        with torch.no_grad():
            probs = F.softmax(last_token_logits, dim=-1)

        return final_loss.item(), probs