kauroy1994 committed
Commit 4a261c9 · verified · 1 Parent(s): 8dcf11e

Delete training_script.py

Files changed (1)
  1. training_script.py +0 -509
training_script.py DELETED
@@ -1,509 +0,0 @@
- import os
- import time
- import math
- import pickle
- import random
- import json
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import matplotlib.pyplot as plt
-
- # We use Hugging Face’s transformers only for pretrained weight loading and tokenizer.
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
- from dataclasses import dataclass
-
- # ----------------------------
- # Helper: ALiBi slopes computation
- # ----------------------------
- def get_alibi_slopes(n_head):
-     """Compute ALiBi slopes for each head.
-     This implementation follows the approach used in several ALiBi implementations.
-     """
-     def get_slopes_power_of_2(n):
-         start = 2 ** (-2 ** -(math.log2(n) - 3))
-         ratio = start
-         return [start * (ratio ** i) for i in range(n)]
-     if math.log2(n_head).is_integer():
-         slopes = get_slopes_power_of_2(n_head)
-     else:
-         closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
-         slopes = get_slopes_power_of_2(closest_power_of_2)
-         extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][: n_head - closest_power_of_2]
-         slopes.extend(extra_slopes)
-     return torch.tensor(slopes, dtype=torch.float32)
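As a quick check of the helper above (not part of the original script), the construction reproduces the standard ALiBi slope values; for a non-power-of-two head count such as the 6 heads used later, the extra slopes are taken from the interleaved 2 * closest_power_of_2 sequence:

    print(get_alibi_slopes(8))  # tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
    print(get_alibi_slopes(6))  # tensor([0.2500, 0.0625, 0.0156, 0.0039, 0.5000, 0.1250])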
-
- # ----------------------------
- # Model Components
- # ----------------------------
-
- class LayerNorm(nn.Module):
-     """LayerNorm with an optional bias."""
-     def __init__(self, ndim, bias: bool):
-         super().__init__()
-         self.weight = nn.Parameter(torch.ones(ndim))
-         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
-
-     def forward(self, input):
-         return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
-
- class CausalSelfAttention(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         assert config.n_embd % config.n_head == 0
-         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
-         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
-         self.attn_dropout = nn.Dropout(config.dropout)
-         self.resid_dropout = nn.Dropout(config.dropout)
-         self.n_head = config.n_head
-         self.n_embd = config.n_embd
-         self.dropout = config.dropout
-         self.use_rope = config.use_rope
-         self.rope_base = config.rope_base
-         # Existing APE support.
-         self.use_ape = getattr(config, 'use_ape', False)
-         # New: ALiBi support.
-         self.use_alibi = getattr(config, 'use_alibi', False)
-         if self.use_alibi and self.use_ape:
-             raise ValueError("Cannot use both ALiBi and APE simultaneously.")
-         # For APE, learn a parameter beta.
-         if self.use_ape:
-             self.beta = nn.Parameter(torch.tensor(1.0))
-         # Use Flash Attention if available (but disable when APE is enabled).
-         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
-         if (not self.flash) or self.use_ape:
-             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
-                                  .view(1, 1, config.block_size, config.block_size))
-
-     def forward(self, x, return_attn_entropy=False, aggregate_heads=False):
-         """
-         Args:
-             x: Input tensor [B, T, C]
-             return_attn_entropy (bool): If True, return attention entropy.
-             aggregate_heads (bool): If True, average entropy across heads.
-         Returns:
-             y: Output tensor [B, T, C] or (y, entropy)
-         """
-         B, T, C = x.size()
-         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
-         head_dim = C // self.n_head
-
-         # Reshape to [B, n_head, T, head_dim]
-         q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
-         k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
-         v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)
-
-         # Optionally, apply RoPE if enabled.
-         if self.use_rope:
-             hs = head_dim
-             d = hs // 2
-             theta = 1.0 / (self.rope_base ** (2 * torch.arange(0, d, dtype=x.dtype, device=x.device) / hs))
-             t_pos = torch.arange(T, device=x.device, dtype=x.dtype)
-             freqs = torch.outer(t_pos, theta)
-             freqs_cos = torch.cos(freqs).unsqueeze(0).unsqueeze(0)
-             freqs_sin = torch.sin(freqs).unsqueeze(0).unsqueeze(0)
-             def apply_rope(tensor, cos, sin):
-                 tensor = tensor.reshape(*tensor.shape[:-1], -1, 2)
-                 x0 = tensor[..., 0]
-                 x1 = tensor[..., 1]
-                 x0_rot = x0 * cos - x1 * sin
-                 x1_rot = x0 * sin + x1 * cos
-                 return torch.stack([x0_rot, x1_rot], dim=-1).flatten(start_dim=-2)
-             q = apply_rope(q, freqs_cos, freqs_sin)
-             k = apply_rope(k, freqs_cos, freqs_sin)
-
-         # Compute scaled dot-product attention scores.
-         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
-
-         # --- Apply positional biases ---
-         if self.use_alibi:
-             slopes = get_alibi_slopes(self.n_head).to(x.device)  # shape: (n_head,)
-             rel_positions = torch.arange(T, device=x.device).unsqueeze(0) - torch.arange(T, device=x.device).unsqueeze(1)
-             alibi_bias = slopes.view(1, self.n_head, 1, 1) * rel_positions.view(1, 1, T, T)
-             att = att - alibi_bias
-         elif self.use_ape:
-             pos_ids = torch.arange(T, device=x.device)
-             rel_dist = pos_ids.unsqueeze(0) - pos_ids.unsqueeze(1)
-             abs_rel = rel_dist.abs().float()
-             temp_matrix = 1.0 / (1.0 + abs_rel)
-             bias_matrix = -self.beta * torch.log(1.0 + abs_rel)
-             temp_matrix = temp_matrix.unsqueeze(0).unsqueeze(0)
-             bias_matrix = bias_matrix.unsqueeze(0).unsqueeze(0)
-             att = temp_matrix * att + bias_matrix
-
-         p_att = F.softmax(att, dim=-1)
-         entropy = -(p_att * torch.log(p_att + 1e-9)).sum(dim=-1)  # [B, n_head, T]
-
-         if self.flash and not self.use_ape:
-             y = torch.nn.functional.scaled_dot_product_attention(
-                 q, k, v,
-                 attn_mask=None,
-                 dropout_p=self.dropout if self.training else 0,
-                 is_causal=True
-             )
-         else:
-             if T > self.bias.size(-1):
-                 bias = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
-             else:
-                 bias = self.bias[:, :, :T, :T]
-             att = att.masked_fill(bias == 0, float('-inf'))
-             p_att = F.softmax(att, dim=-1)
-             entropy = -(p_att * torch.log(p_att + 1e-9)).sum(dim=-1)
-             att = self.attn_dropout(p_att)
-             y = att @ v  # [B, n_head, T, head_dim]
-
-         y = y.transpose(1, 2).contiguous().view(B, T, C)
-         y = self.resid_dropout(self.c_proj(y))
-
-         if return_attn_entropy:
-             if aggregate_heads:
-                 entropy = entropy.mean(dim=1)  # [B, T]
-             return y, entropy
-         else:
-             return y
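To make the shapes in the ALiBi branch above concrete, here is a small standalone sketch (toy values, not part of the script); att in the forward pass has shape [B, n_head, T, T] and the bias broadcasts against it:

    T, n_head = 4, 2
    slopes = get_alibi_slopes(n_head)                                    # (n_head,)
    rel = torch.arange(T).unsqueeze(0) - torch.arange(T).unsqueeze(1)    # (T, T), rel[i, j] = j - i
    alibi_bias = slopes.view(1, n_head, 1, 1) * rel.view(1, 1, T, T)     # broadcasts to (1, n_head, T, T)
    print(alibi_bias.shape)  # torch.Size([1, 2, 4, 4])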
-
- class MLP(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
-         self.gelu = nn.GELU()
-         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
-         self.dropout = nn.Dropout(config.dropout)
-     def forward(self, x):
-         x = self.c_fc(x)
-         x = self.gelu(x)
-         x = self.c_proj(x)
-         x = self.dropout(x)
-         return x
-
- class Block(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
-         self.attn = CausalSelfAttention(config)
-         self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
-         self.mlp = MLP(config)
-     def forward(self, x, return_attn_entropy=False, aggregate_heads=False):
-         if return_attn_entropy:
-             attn_output, entropy = self.attn(self.ln_1(x), return_attn_entropy=True, aggregate_heads=aggregate_heads)
-             x = x + attn_output
-             x = x + self.mlp(self.ln_2(x))
-             return x, entropy
-         else:
-             attn_output = self.attn(self.ln_1(x), return_attn_entropy=False)
-             x = x + attn_output
-             x = x + self.mlp(self.ln_2(x))
-             return x
-
- @dataclass
- class GPTConfig:
-     block_size: int = 128
-     vocab_size: int = 50304  # GPT-2's 50257-token vocab, padded up to 50304 for efficiency
-     n_layer: int = 6
-     n_head: int = 6
-     n_embd: int = 384
-     dropout: float = 0.0
-     bias: bool = True
-     use_rope: bool = True
-     rope_base: float = 10000.0
-     use_ape: bool = False
-     lambda_temp: float = 0.1
-     use_alibi: bool = False
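Given these flags, one plausible way to select a positional-encoding variant is to toggle them as below (illustrative configs, not taken from the script; the script's own model_args further down combines RoPE with APE):

    cfg_rope     = GPTConfig(use_rope=True,  use_ape=False, use_alibi=False)
    cfg_rope_ape = GPTConfig(use_rope=True,  use_ape=True,  use_alibi=False)  # APE also forces the non-flash attention path
    cfg_alibi    = GPTConfig(use_rope=False, use_ape=False, use_alibi=True)   # GPT.__init__ disables RoPE whenever ALiBi is on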
-
- class GPT(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         assert config.vocab_size is not None and config.block_size is not None
-         self.config = config
-         # If using ALiBi, disable RoPE.
-         self.use_rope = config.use_rope and not config.use_alibi
-         print(f"Using RoPE in GPT init: {self.use_rope}")
-         self.transformer = nn.ModuleDict(dict(
-             wte = nn.Embedding(config.vocab_size, config.n_embd),
-             wpe = None if self.use_rope else nn.Embedding(config.block_size, config.n_embd),
-             drop = nn.Dropout(config.dropout),
-             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
-             ln_f = LayerNorm(config.n_embd, bias=config.bias),
-         ))
-         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-         self.transformer.wte.weight = self.lm_head.weight
-         self.apply(self._init_weights)
-         for pn, p in self.named_parameters():
-             if pn.endswith('c_proj.weight'):
-                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
-         print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
-     def get_num_params(self, non_embedding=True):
-         n_params = sum(p.numel() for p in self.parameters())
-         if non_embedding and (not self.use_rope) and (self.transformer.wpe is not None):
-             n_params -= self.transformer.wpe.weight.numel()
-         return n_params
-     def _init_weights(self, module):
-         if isinstance(module, nn.Linear):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-             if module.bias is not None:
-                 torch.nn.init.zeros_(module.bias)
-         elif isinstance(module, nn.Embedding):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-     def forward(self, idx, targets=None, return_attn_entropy=False, aggregate_heads=False):
-         device = idx.device
-         b, t = idx.size()
-         pos = torch.arange(0, t, dtype=torch.long, device=device)
-         tok_emb = self.transformer.wte(idx)
-         if self.use_rope or self.config.use_alibi:
-             x = self.transformer.drop(tok_emb)
-         else:
-             pos_emb = self.transformer.wpe(pos) if self.transformer.wpe is not None else 0
-             x = self.transformer.drop(tok_emb + pos_emb)
-         attn_entropies = []
-         for block in self.transformer.h:
-             if return_attn_entropy:
-                 x, entropy = block(x, return_attn_entropy=True, aggregate_heads=aggregate_heads)
-                 attn_entropies.append(entropy)
-             else:
-                 x = block(x)
-         x = self.transformer.ln_f(x)
-         if targets is not None:
-             logits = self.lm_head(x)
-             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-         else:
-             logits = self.lm_head(x[:, [-1], :])
-             loss = None
-         if return_attn_entropy:
-             return logits, loss, attn_entropies
-         else:
-             return logits, loss
-     @torch.no_grad()
-     def generate_and_compute_perplexity(self, prompt, ground_truth, temperature=1.0, return_attn_entropy=False, aggregate_heads=False):
-         if return_attn_entropy:
-             _, _, attn_entropies = self(prompt, return_attn_entropy=True, aggregate_heads=aggregate_heads)
-             per_layer_avgs = [entropy.mean().item() for entropy in attn_entropies]
-             avg_entropy = np.mean(per_layer_avgs)
-         else:
-             avg_entropy = None
-         total_loss = 0.0
-         total_tokens = 0
-         prompt_length = prompt.size(1)
-         num_target_tokens = ground_truth.size(1) - prompt_length
-         idx = prompt.clone()
-         for i in range(num_target_tokens):
-             logits, _ = self(idx)
-             logits = logits[:, -1, :] / temperature
-             target = ground_truth[:, prompt_length + i]
-             loss = F.cross_entropy(logits, target, reduction='sum')
-             total_loss += loss.item()
-             total_tokens += target.numel()
-             target_token = target.unsqueeze(1)
-             idx = torch.cat((idx, target_token), dim=1)
-         avg_neg_log_likelihood = total_loss / total_tokens if total_tokens > 0 else float('inf')
-         perplexity = math.exp(avg_neg_log_likelihood)
-         return idx, perplexity, avg_entropy
-     @torch.no_grad()
-     def generate_until_end(self, idx, temperature=1.0, top_k=None, max_new_tokens=1000):
-         for i in range(max_new_tokens):
-             idx_cond = idx
-             logits, _ = self(idx_cond)
-             logits = logits[:, -1, :] / temperature
-             if top_k is not None:
-                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-                 logits[logits < v[:, [-1]]] = -float('Inf')
-             probs = F.softmax(logits, dim=-1)
-             idx_next = torch.multinomial(probs, num_samples=1)
-             idx = torch.cat((idx, idx_next), dim=1)
-             if idx_next.item() == 50256:
-                 break
-         return idx
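For reference, 50256 is GPT-2's <|endoftext|> id, so sampling stops at end-of-text. A minimal usage sketch (the prompt string is made up, and model/device are the objects created in the training section below):

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    prompt_ids = torch.tensor(tokenizer.encode("Once upon a time"), dtype=torch.long, device=device).unsqueeze(0)
    model.eval()
    out = model.generate_until_end(prompt_ids, temperature=1.0, top_k=50, max_new_tokens=200)
    print(tokenizer.decode(out[0].tolist()))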
-
- # ----------------------------
- # Utility Functions for Training & Evaluation
- # ----------------------------
-
- # Data Loader Functions
- train_data_path = "/data1/home/nitinvetcha/Topics in AI/Streamlined/COLM2025/train_tinystories.bin"
- val_data_path = "/data1/home/nitinvetcha/Topics in AI/Streamlined/COLM2025/val_tinystories.bin"
- def get_batch(split):
-     data_path = train_data_path if split == 'train' else val_data_path
-     data = np.memmap(data_path, dtype=np.uint16, mode='r')
-     total_tokens = len(data)
-     max_ix = max(1, total_tokens - gptconf.block_size)
-     ix = torch.randint(0, max_ix, (batch_size,))
-     X = torch.stack([torch.from_numpy(data[i:i+gptconf.block_size].astype(np.int64)) for i in ix])
-     Y = torch.stack([torch.from_numpy(data[i+1:i+1+gptconf.block_size].astype(np.int64)) for i in ix])
-     return X.to(device), Y.to(device)
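get_batch assumes the .bin files are flat arrays of uint16 GPT-2 token ids (all GPT-2 ids fit in uint16). One plausible way to produce such a file, assuming a plain-text dump named tinystories.txt (a placeholder, not a file from this repo):

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    with open("tinystories.txt") as f:
        ids = tokenizer.encode(f.read())
    np.array(ids, dtype=np.uint16).tofile("train_tinystories.bin")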
-
- def evaluate_prompt_perplexity(model, token_file, prompt_length, num_trials, generation_params, device):
-     tokens = np.fromfile(token_file, dtype=np.uint16)
-     total_tokens = len(tokens)
-     perplexities = []
-     entropy_trials = []
-     max_new_tokens = generation_params.get("max_new_tokens", 50)
-     total_length = prompt_length + max_new_tokens
-     for trial in range(num_trials):
-         start_idx = random.randint(0, total_tokens - total_length)
-         sequence_tokens = tokens[start_idx : start_idx + total_length]
-         prompt_tokens = sequence_tokens[:prompt_length]
-         ground_truth_tokens = sequence_tokens
-         # Cast uint16 token ids to int64 before building tensors, as in get_batch.
-         prompt_tensor = torch.tensor(prompt_tokens.astype(np.int64), dtype=torch.long).unsqueeze(0).to(device)
-         ground_truth_tensor = torch.tensor(ground_truth_tokens.astype(np.int64), dtype=torch.long).unsqueeze(0).to(device)
-         _, ppl, trial_entropy = model.generate_and_compute_perplexity(
-             prompt_tensor, ground_truth_tensor,
-             temperature=generation_params.get("temperature", 1.0),
-             return_attn_entropy=True, aggregate_heads=True
-         )
-         perplexities.append(ppl)
-         entropy_trials.append(trial_entropy)
-         print(f"Trial {trial+1}/{num_trials} for prompt length {prompt_length}: Perplexity = {ppl:.2f}, Avg Entropy = {trial_entropy:.4f}")
-     avg_ppl = np.mean(perplexities)
-     avg_entropy = np.mean(entropy_trials)
-     print(f"Prompt Length {prompt_length} - Avg Perplexity: {avg_ppl:.2f}, Avg Attention Entropy: {avg_entropy:.4f}\n")
-     return avg_ppl, avg_entropy
-
- # ----------------------------
- # Training Loop
- # ----------------------------
- # Training hyperparameters
- batch_size = 12
- max_iters = 25001
- save_interval = 5000
- learning_rate = 6e-4
- weight_decay = 1e-1
- grad_clip = 1.0
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- torch.manual_seed(1337)
-
- # Model configuration: adjust these flags as needed.
- model_args = dict(
-     n_layer=6,
-     n_head=6,
-     n_embd=384,
-     block_size=64,    # You can change this as needed.
-     bias=False,
-     use_rope=True,
-     use_ape=True,     # Set to True if you want APE.
-     use_alibi=False,  # Set to True to use ALiBi.
-     rope_base=10000.0,
-     vocab_size=50304,
-     dropout=0.0
- )
- gptconf = GPTConfig(**model_args)
- model = GPT(gptconf).to(device)
- model.train()
-
- optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
- iter_num = 0
- start_time = time.time()
- training_losses = []
- validation_losses = []
- save_iters = []
-
- # Build a flag string for naming: e.g. "rope_ape" or "alibi" etc.
- flag_parts = []
- if gptconf.use_rope:
-     flag_parts.append("rope")
- if gptconf.use_ape:
-     flag_parts.append("ape")
- if gptconf.use_alibi:
-     flag_parts.append("alibi")
- flag_str = "_".join(flag_parts) if flag_parts else "none"
- weight_dir = f"weights_{flag_str}_{gptconf.block_size}"
- os.makedirs(weight_dir, exist_ok=True)
-
- while iter_num < max_iters:
-     X_train, Y_train = get_batch('train')
-     optimizer.zero_grad()
-     logits, loss_train = model(X_train, Y_train)
-     loss_train.backward()
-     if grad_clip > 0:
-         torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
-     optimizer.step()
-     training_losses.append(loss_train.item())
-
-     model.eval()
-     X_val, Y_val = get_batch('val')
-     with torch.no_grad():
-         logits_val, loss_val = model(X_val, Y_val)
-     validation_losses.append(loss_val.item())
-     model.train()
-
-     if iter_num % 100 == 0:
-         elapsed = time.time() - start_time
-         print(f"Iter {iter_num:5d}: train loss = {loss_train.item():.4f}, val loss = {loss_val.item():.4f}, time/iter = {elapsed/(iter_num+1):.4f}s")
-
-     if iter_num > 0 and iter_num % save_interval == 0:
-         save_iters.append(iter_num)
-         ckpt = {
-             'iter_num': iter_num,
-             'model_state_dict': model.state_dict(),
-             'optimizer_state_dict': optimizer.state_dict(),
-             'training_losses': training_losses,
-             'validation_losses': validation_losses,
-             'save_iters': save_iters,
-         }
-         ckpt_path = os.path.join(weight_dir, f"ckpt_{iter_num}.pt")
-         torch.save(ckpt, ckpt_path)
-         print(f"Checkpoint saved to {ckpt_path}")
-
-     iter_num += 1
-
- print("Training complete.")
-
- plt.figure(figsize=(10, 6))
- plt.plot(range(len(training_losses)), training_losses, label="Training Loss")
- plt.plot(range(len(validation_losses)), validation_losses, label="Validation Loss", alpha=0.7)
- plt.xlabel("Iteration")
- plt.ylabel("Loss")
- plt.title("Training and Validation Loss per Iteration")
- plt.legend()
- plt.grid(True)
- plt.show()
-
- # ----------------------------
- # Perplexity & Entropy Evaluation
- # ----------------------------
-
- token_file = val_data_path  # Use validation data for evaluation.
- prompt_lengths = [64, 128, 256, 512, 1024, 2048, 4096, 8192]
- num_trials = 5
- generation_params = {"temperature": 1.0, "max_new_tokens": 50}
-
- avg_perplexities = []
- avg_entropies = []
-
- for pl in prompt_lengths:
-     print(f"Evaluating for prompt length: {pl}")
-     avg_ppl, avg_entropy = evaluate_prompt_perplexity(model, token_file, pl, num_trials, generation_params, device)
-     avg_perplexities.append(float(avg_ppl))      # cast from np.float64 so json.dump can serialize it
-     avg_entropies.append(float(avg_entropy))
-
- results = {
-     "prompt_lengths": prompt_lengths,
-     "avg_perplexities": avg_perplexities,
-     "avg_entropies": avg_entropies
- }
- results_filename = f"results_{flag_str}_{gptconf.block_size}.json"
- with open(results_filename, "w") as f:
-     json.dump(results, f)
- print(f"Results saved to {results_filename}")
-
- plt.figure(figsize=(8, 6))
- plt.plot(prompt_lengths, avg_perplexities, marker='o')
- plt.xlabel("Prompt Length")
- plt.ylabel("Avg Generated Perplexity")
- plt.title("Avg Generated Perplexity vs Prompt Length")
- plt.grid(True)
- plt.xscale('log')
- plt.savefig(f"avg_generated_perplexity_{flag_str}_{gptconf.block_size}.png")
- plt.show()
-
- plt.figure(figsize=(8, 6))
- plt.plot(prompt_lengths, avg_entropies, marker='o', color='red')
- plt.xlabel("Prompt Length")
- plt.ylabel("Avg Attention Entropy")
- plt.title("Avg Attention Entropy vs Prompt Length\n(Averaged over Layers)")
- plt.grid(True)
- plt.xscale('log')
- plt.savefig(f"avg_attention_entropy_{flag_str}_{gptconf.block_size}.png")
- plt.show()