sidharthg committed on
Commit
c60dfcc
·
verified ·
1 Parent(s): c4d4771

Upload gpt2_train_per.py

Browse files
Files changed (1) hide show
  1. scripts/gpt2_train_per.py +366 -1
scripts/gpt2_train_per.py CHANGED
@@ -1 +1,366 @@
1
- # This file is intentionally left blank.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Solving for residual std scaling issue
2
+ import os
3
+ import math
4
+ import time
5
+ import inspect
6
+ from dataclasses import dataclass
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.nn.utils import clip_grad_norm_
11
+ from torch.utils.checkpoint import checkpoint # Moved this import to the top
12
+
13
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
14
+
15
+ class CausalSelfAttention(nn.Module):
16
+
17
+ def __init__(self, config):
18
+ super().__init__()
19
+ assert config.n_embd % config.n_head == 0
20
+ # key, query, value projections for all heads, but in a batch
21
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
22
+ # output projection
23
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
24
+ self.c_proj.NANGPT_SCALE_INIT = 1
25
+ # regularization
26
+ self.n_head = config.n_head
27
+ self.n_embd = config.n_embd
28
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
29
+
30
+ def forward(self, x):
31
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
32
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
33
+ # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
34
+ # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
35
+ qkv = self.c_attn(x)
36
+ q, k, v = qkv.split(self.n_embd, dim=2)
37
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
38
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
39
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
40
+
41
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
42
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
43
+ att = F.softmax(att, dim=-1)
44
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs)
45
+
46
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
47
+ # output projection
48
+ y = self.c_proj(y)
49
+ return y
50
+
51
+
52
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU(tanh) -> Linear.

    Expands the embedding 4x, applies the GPT-2 tanh-approximated GELU,
    then projects back down to n_embd.
    """

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        # Fix: GPT._init_weights checks for 'NANGPT_SCALE_INIT' (the spelling
        # used by CausalSelfAttention); the previous 'NANOGPT_SCALE_INIT'
        # spelling was never matched, so this residual projection silently
        # missed the 1/sqrt(2*n_layer) init scaling.
        self.c_proj.NANGPT_SCALE_INIT = 1

    def forward(self, x):
        """Apply the MLP to a tensor of shape (B, T, n_embd)."""
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x
66
+
67
class Block(nn.Module):
    """Transformer block: pre-LayerNorm attention and MLP, each wrapped in a
    residual connection. The forward pass is gradient-checkpointed during
    training to trade recompute for activation memory.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        def _forward_block(x):
            x = x + self.attn(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
            return x
        # Fix: checkpoint only when it can pay off (training with grad
        # enabled), and pass use_reentrant=False explicitly -- the implicit
        # reentrant default is deprecated and emits warnings on current
        # PyTorch; the non-reentrant path is the supported implementation.
        if self.training and torch.is_grad_enabled():
            return checkpoint(_forward_block, x, use_reentrant=False)
        return _forward_block(x)
83
+
84
+
85
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model (a reduced-size GPT-2 variant)."""
    block_size: int = 1024 # max sequence length (context window)
    vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
    n_layer: int = 6 # number of transformer blocks (reduced from GPT-2 small's 12)
    n_head: int = 6 # number of attention heads (reduced from 12)
    n_embd: int = 384 # embedding dimension (reduced from 768)
92
+
93
+
94
class GPT(nn.Module):
    """GPT-2 style decoder-only transformer language model.

    Token + learned positional embeddings feed a stack of pre-norm
    transformer blocks, followed by a final LayerNorm and a linear head
    whose weight is tied to the token embedding matrix.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # weight sharing between input embedding and output head (GPT-2 trick)
        self.transformer.wte.weight = self.lm_head.weight

        # weight initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-2 init: N(0, 0.02) weights, zero bias; residual projections
        (flagged on the module) are scaled by 1/sqrt(2*n_layer) to keep the
        residual-stream std roughly constant with depth."""
        if isinstance(module, nn.Linear):
            std = 0.02
            # Fix: this file flags the attention projection with
            # 'NANGPT_SCALE_INIT' but the MLP projection with
            # 'NANOGPT_SCALE_INIT'; accept either spelling so both residual
            # projections actually receive the scaled init.
            if hasattr(module, 'NANGPT_SCALE_INIT') or hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean = 0.0, std = std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std = 0.02)

    def forward(self, idx, targets=None):
        """Run the model on token ids `idx` of shape (B, T).

        Returns (logits, loss); loss is None unless `targets` (same shape
        as `idx`) is given, in which case it is the mean cross-entropy.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        # forward the token and position embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
        x = tok_emb + pos_emb
        # forward the blocks of the transformer
        for block in self.transformer.h:
            x = block(x)
        # forward the final layernorm and the classifier
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """Loads pretrained GPT-2 model weights from huggingface"""
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model
195
+
196
# Device setup: prefer CUDA, then Apple MPS, else CPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device: {device}")

# Seed for reproducibility (CPU and, when present, CUDA)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Hyperparameters
B, T = 8,128 # micro-batch size and sequence length (8 * 128 = 1024 tokens per forward pass)
max_iters = 2000 # total optimizer steps
warmup_iters = 200 # linear LR warmup steps
base_lr = 3e-4 # peak learning rate
final_lr = 1e-5 # LR floor reached at the end of cosine decay
grad_clip = 1.0 # global-norm gradient clipping threshold
patience = 20 # early stopping patience, counted in validation checks without improvement
num_val_batches = 10 # batches averaged per validation pass
accum_steps = 4 # gradient accumulation: effective batch = B * accum_steps = 32
219
+
220
+ import tiktoken
221
+
222
class DataLoaderLite:
    """Minimal sequential batch loader over a single token stream.

    Reads and GPT-2-encodes 'input.txt' at construction time, then serves
    contiguous (inputs, targets) next-token pairs, wrapping back to the
    start of the stream when a full batch no longer fits.
    """

    def __init__(self, B, T):
        self.B = B
        self.T = T

        # load and tokenize the whole corpus up front
        with open('input.txt', 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        self.tokens = torch.tensor(enc.encode(text))
        print(f'loaded {len(self.tokens)} tokens')
        print(f'1 epoch = {len(self.tokens) // (B * T)} batches')

        # read cursor into self.tokens
        self.current_position = 0

    def next_batch(self):
        """Return the next (x, y) pair, each of shape (B, T); y is x shifted by one."""
        B, T = self.B, self.T
        window = self.tokens[self.current_position: self.current_position + B * T + 1]
        x = window[:-1].view(B, T)  # inputs
        y = window[1:].view(B, T)   # targets
        self.current_position += B * T
        # wrap around when the next batch would run past the end of the stream
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y
250
# Load full tokens
# NOTE(review): input.txt is also read and re-encoded inside each
# DataLoaderLite constructor below, so the corpus is tokenized three times.
with open('input.txt', 'r') as f:
    text = f.read()
enc = tiktoken.get_encoding('gpt2')
tokens = torch.tensor(enc.encode(text))

# Simple 90/10 train-val split to avoid data leakage
num_train_tokens = int(0.9 * len(tokens))
train_tokens = tokens[:num_train_tokens]
val_tokens = tokens[num_train_tokens:]

# Create data loaders, then overwrite the full-corpus tokens the
# constructors loaded with the split slices and reset the read cursors
train_loader = DataLoaderLite(B, T)
train_loader.tokens = train_tokens
train_loader.current_position = 0

val_loader = DataLoaderLite(B, T)
val_loader.tokens = val_tokens
val_loader.current_position = 0

# Clear CUDA cache before model initialization
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Initialize model and move it to the selected device
model = GPT(GPTConfig())
model.to(device)

# Optimizer
# NOTE(review): lr is fixed at base_lr here; in this script get_lr() is
# computed later but never written back into optimizer.param_groups, so
# the warmup/cosine schedule is effectively unused -- verify intent.
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)
280
+
281
# Learning rate schedule: linear warmup + cosine decay
def get_lr(step, warmup=None, total=None, lr_max=None, lr_min=None):
    """Return the learning rate for optimizer step `step`.

    Linear warmup from 0 to `lr_max` over the first `warmup` steps, then
    cosine decay down to `lr_min` at step `total`. The keyword arguments
    default to the module-level hyperparameters (warmup_iters, max_iters,
    base_lr, final_lr) so existing call sites -- get_lr(step) -- behave
    exactly as before; passing them explicitly makes the schedule reusable
    and testable in isolation.
    """
    warmup = warmup_iters if warmup is None else warmup
    total = max_iters if total is None else total
    lr_max = base_lr if lr_max is None else lr_max
    lr_min = final_lr if lr_min is None else lr_min
    if step < warmup:
        return lr_max * step / warmup
    progress = (step - warmup) / (total - warmup)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))
287
+
288
# Early-stopping state
best_val_loss = float('inf')
no_improve_steps = 0

model.train()

from torch.amp import GradScaler, autocast

# Mixed precision (and GradScaler) only apply on CUDA; disable both
# elsewhere so the same loop also runs cleanly on CPU / MPS.
use_amp = (device == 'cuda')
scaler = GradScaler('cuda', enabled=use_amp)

for step in range(max_iters):
    # Fix: actually apply the LR schedule. Previously get_lr() was computed
    # only for logging and the optimizer ran at a constant base_lr.
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    optimizer.zero_grad()
    loss_accum = 0.0  # full accumulated (mean) train loss, for logging
    for _ in range(accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)

        with autocast('cuda', enabled=use_amp):
            _, loss = model(x, y)
            loss = loss / accum_steps  # scale so accumulated grads average the micro-batches

        loss_accum += loss.detach()
        scaler.scale(loss).backward()

    # Gradient clipping (on unscaled grads) and optimizer step
    scaler.unscale_(optimizer)
    clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()

    torch.cuda.empty_cache()

    # Validation, logging, checkpointing and early stopping every N steps
    if step % (100 // accum_steps) == 0 or step == max_iters - 1:
        model.eval()
        val_losses = []
        with torch.no_grad():
            for _ in range(num_val_batches):
                xv, yv = val_loader.next_batch()
                xv, yv = xv.to(device), yv.to(device)
                _, val_loss = model(xv, yv)
                val_losses.append(val_loss.item())
        avg_val_loss = sum(val_losses) / len(val_losses)

        # Fix: report the full accumulated train loss; the old log printed
        # the last micro-batch loss already divided by accum_steps.
        print(f"Step {step}: train loss {loss_accum.item():.5f}, val loss {avg_val_loss:.5f}, lr {lr:.6f}")

        # Early stopping and checkpoint saving
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            no_improve_steps = 0
            torch.save(model.state_dict(), 'best_model.pt')
            print("Checkpoint saved.")
        else:
            no_improve_steps += 1
            if no_improve_steps >= patience:
                print("Early stopping triggered.")
                break
        model.train()
343
+
344
# Load best model for sampling/generation
model.load_state_dict(torch.load('best_model.pt'))
model.eval()

# Sampling: top-50 multinomial, extending the context one token at a time
num_return_sequences = 5
max_length = 30
# Fix: the old seed kept all T=128 context tokens per row, so
# x.size(1) >= max_length from the start and the generation loop below
# never executed (the script just decoded validation data). Seed with a
# short prefix so sampling actually runs.
prompt_len = 8
x = val_loader.next_batch()[0][:num_return_sequences, :prompt_len].to(device)

while x.size(1) < max_length:
    with torch.no_grad():
        logits = model(x)[0]       # (B, T, vocab_size)
        logits = logits[:, -1, :]  # only the last position predicts the next token
        probs = F.softmax(logits, dim=-1)
        # sample from the 50 most likely tokens
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        ix = torch.multinomial(topk_probs, 1)
        xcol = torch.gather(topk_indices, -1, ix)
        x = torch.cat((x, xcol), dim=1)

for i in range(num_return_sequences):
    tokens = x[i, :max_length].tolist()
    decoded = enc.decode(tokens)
    print(">", decoded)