Nitin2004 committed on
Commit
f881e8c
·
verified ·
1 Parent(s): 5626bda

Upload full model_ADPB folder

Browse files
Files changed (3) hide show
  1. APE.py +509 -0
  2. config.json +15 -0
  3. pytorch_model.bin +3 -0
APE.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import math
4
+ import pickle
5
+ import random
6
+ import json
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import matplotlib.pyplot as plt
12
+
13
+ # We use Hugging Face’s transformers only for pretrained weight loading and tokenizer.
14
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
15
+ from dataclasses import dataclass
16
+
17
+ # ----------------------------
18
+ # Helper: ALiBi slopes computation
19
+ # ----------------------------
20
def get_alibi_slopes(n_head):
    """Return the per-head ALiBi slopes as a float32 tensor of shape (n_head,).

    For a power-of-two head count ``n`` the slopes form a geometric sequence
    whose first term and ratio are both ``2 ** (-8 / n)``.  For any other head
    count, the sequence for the largest power of two not exceeding ``n_head``
    is used first, followed by every second slope of the sequence for twice
    that power of two until ``n_head`` values have been collected.
    """
    def _geometric_slopes(count):
        base = 2 ** (-2 ** -(math.log2(count) - 3))
        return [base * (base ** exponent) for exponent in range(count)]

    if math.log2(n_head).is_integer():
        values = _geometric_slopes(n_head)
    else:
        lower_pow2 = 2 ** math.floor(math.log2(n_head))
        missing = n_head - lower_pow2
        # Fill the remaining heads from the finer-grained (2x) schedule.
        interleaved = _geometric_slopes(2 * lower_pow2)[0::2]
        values = _geometric_slopes(lower_pow2) + interleaved[:missing]
    return torch.tensor(values, dtype=torch.float32)
36
+
37
+ # ----------------------------
38
+ # Model Components
39
+ # ----------------------------
40
+
41
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and an optional learnable shift."""

    def __init__(self, ndim, bias: bool):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            # No shift term at all when the model is configured without biases.
            self.bias = None

    def forward(self, input):
        return F.layer_norm(
            input,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )
50
+
51
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional RoPE, ALiBi or APE score modification.

    Position handling is selected by the config flags:

    * ``use_rope``  - rotate queries and keys with rotary position embeddings.
    * ``use_alibi`` - subtract a per-head linear distance penalty from the
      attention scores.  RoPE is switched off in this mode, which mirrors the
      ``use_rope and not use_alibi`` rule applied by ``GPT``.
    * ``use_ape``   - scale every score by ``1 / (1 + |i - j|)`` and add
      ``-beta * log(1 + |i - j|)``, where ``beta`` is a learned scalar.

    ALiBi and APE are mutually exclusive.

    The fused ``scaled_dot_product_attention`` kernel is used only when no
    score modification and no entropy output is requested; otherwise the
    scores are materialised explicitly so that the bias really influences the
    output and the reported entropy is that of the causally masked
    distribution the layer actually uses.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.rope_base = config.rope_base
        # Existing APE support.
        self.use_ape = getattr(config, 'use_ape', False)
        # ALiBi support.
        self.use_alibi = getattr(config, 'use_alibi', False)
        if self.use_alibi and self.use_ape:
            raise ValueError("Cannot use both ALiBi and APE simultaneously.")
        # ALiBi replaces RoPE; without this the layer would still rotate q/k
        # even though the enclosing model considers RoPE disabled.
        self.use_rope = bool(config.use_rope) and not self.use_alibi
        # For APE, learn a parameter beta.
        if self.use_ape:
            self.beta = nn.Parameter(torch.tensor(1.0))
        # Fused attention is available on recent PyTorch versions only.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        # The buffer is registered under the same condition as before so that
        # existing checkpoints keep loading with identical state-dict keys.
        if (not self.flash) or self.use_ape:
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(config.block_size, config.block_size))
                .view(1, 1, config.block_size, config.block_size),
            )

    @staticmethod
    def _rotate(tensor, cos, sin):
        """Rotate consecutive feature pairs of ``tensor`` by the given cos/sin tables."""
        pairs = tensor.reshape(*tensor.shape[:-1], -1, 2)
        even = pairs[..., 0]
        odd = pairs[..., 1]
        rotated = torch.stack([even * cos - odd * sin, even * sin + odd * cos], dim=-1)
        return rotated.flatten(start_dim=-2)

    def _apply_rope(self, q, k, T, head_dim, dtype, device):
        """Apply rotary position embeddings to queries and keys."""
        half = head_dim // 2
        theta = 1.0 / (self.rope_base ** (2 * torch.arange(0, half, dtype=dtype, device=device) / head_dim))
        positions = torch.arange(T, device=device, dtype=dtype)
        freqs = torch.outer(positions, theta)
        cos = torch.cos(freqs).unsqueeze(0).unsqueeze(0)
        sin = torch.sin(freqs).unsqueeze(0).unsqueeze(0)
        return self._rotate(q, cos, sin), self._rotate(k, cos, sin)

    def _causal_mask(self, T, device):
        """Return a [1, 1, T, T] lower-triangular mask (non-zero = visible)."""
        registered = getattr(self, "bias", None)
        if registered is not None and T <= registered.size(-1):
            return registered[:, :, :T, :T]
        # Sequence longer than the registered buffer (or no buffer at all).
        return torch.tril(torch.ones(T, T, device=device)).view(1, 1, T, T)

    def _add_position_bias(self, att, T, device):
        """Apply the ALiBi or APE modification to raw attention scores."""
        if self.use_alibi:
            slopes = get_alibi_slopes(self.n_head).to(device)  # shape: (n_head,)
            pos = torch.arange(T, device=device)
            # distance[i, j] = i - j, i.e. how far key j lies behind query i.
            distance = (pos.unsqueeze(1) - pos.unsqueeze(0)).view(1, 1, T, T)
            # Penalise distant keys: the further back, the lower the score.
            return att - slopes.view(1, self.n_head, 1, 1) * distance
        if self.use_ape:
            pos_ids = torch.arange(T, device=device)
            abs_rel = (pos_ids.unsqueeze(0) - pos_ids.unsqueeze(1)).abs().float()
            temp_matrix = (1.0 / (1.0 + abs_rel)).unsqueeze(0).unsqueeze(0)
            bias_matrix = (- self.beta * torch.log(1.0 + abs_rel)).unsqueeze(0).unsqueeze(0)
            return temp_matrix * att + bias_matrix
        return att

    def forward(self, x, return_attn_entropy=False, aggregate_heads=False):
        """
        Args:
            x: Input tensor [B, T, C]
            return_attn_entropy (bool): If True, also return the entropy of
                each query's (causally masked) attention distribution.
            aggregate_heads (bool): If True, average the entropy across heads.
        Returns:
            y: Output tensor [B, T, C], or (y, entropy) where entropy is
            [B, n_head, T] (or [B, T] when aggregate_heads is True).
        """
        B, T, C = x.size()
        head_dim = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        # Reshape to [B, n_head, T, head_dim]
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        if self.use_rope:
            q, k = self._apply_rope(q, k, T, head_dim, x.dtype, x.device)

        # The fused kernel cannot take the ALiBi/APE modification as used here
        # and does not expose attention probabilities, so it is only chosen
        # when neither is needed.
        needs_scores = self.use_ape or self.use_alibi or return_attn_entropy
        entropy = None
        if self.flash and not needs_scores:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
            att = self._add_position_bias(att, T, x.device)
            att = att.masked_fill(self._causal_mask(T, x.device) == 0, float('-inf'))
            p_att = F.softmax(att, dim=-1)
            if return_attn_entropy:
                entropy = -(p_att * torch.log(p_att + 1e-9)).sum(dim=-1)  # [B, n_head, T]
            y = self.attn_dropout(p_att) @ v  # [B, n_head, T, head_dim]

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))

        if return_attn_entropy:
            if aggregate_heads:
                entropy = entropy.mean(dim=1)  # [B, T]
            return y, entropy
        return y
168
+
169
class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x width, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden_width, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        expanded = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(expanded))
182
+
183
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x, return_attn_entropy=False, aggregate_heads=False):
        """Run the block; returns ``x`` or ``(x, entropy)`` when entropy is requested."""
        normed = self.ln_1(x)
        if return_attn_entropy:
            attn_output, entropy = self.attn(
                normed, return_attn_entropy=True, aggregate_heads=aggregate_heads
            )
        else:
            attn_output = self.attn(normed, return_attn_entropy=False)
            entropy = None

        # Residual around attention, then residual around the MLP.
        x = x + attn_output
        x = x + self.mlp(self.ln_2(x))

        if return_attn_entropy:
            return x, entropy
        return x
201
+
202
@dataclass
class GPTConfig:
    """Hyperparameters consumed by ``GPT`` and its sub-modules."""

    # Size of the learned position table and of the registered causal-mask buffer.
    block_size: int = 128
    # For GPT-2
    vocab_size: int = 50304
    # Number of transformer blocks.
    n_layer: int = 6
    # Attention heads per block; must divide n_embd.
    n_head: int = 6
    # Embedding / residual width.
    n_embd: int = 384
    # Dropout probability used by the embedding, attention and MLP dropouts.
    dropout: float = 0.0
    # Whether Linear and LayerNorm modules carry bias terms.
    bias: bool = True
    # Rotary position embeddings on queries and keys.
    use_rope: bool = True
    # Base of the rotary frequency schedule.
    rope_base: float = 10000.0
    # APE score scaling/bias with a learned beta (mutually exclusive with ALiBi).
    use_ape: bool = False
    # NOTE(review): not read by any code visible in this file.
    lambda_temp: float = 0.1
    # ALiBi linear distance bias (mutually exclusive with APE).
    use_alibi: bool = False
216
+
217
class GPT(nn.Module):
    """Decoder-only GPT language model with tied input/output embeddings.

    Position information comes from RoPE/ALiBi/APE inside the attention
    layers, or from a learned position table (``wpe``) when RoPE is disabled.
    RoPE is considered disabled whenever ALiBi is enabled.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None and config.block_size is not None
        self.config = config
        # If using ALiBi, disable RoPE.
        self.use_rope = config.use_rope and not config.use_alibi
        print(f"Using RoPE in GPT init: {self.use_rope}")
        # NOTE(review): with ALiBi a ``wpe`` table is still created although
        # ``forward`` never reads it; kept so existing checkpoints keep loading.
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=None if self.use_rope else nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: token embedding and output projection share one matrix.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)
        # Residual projections get a smaller init, scaled by network depth.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
        print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))

    def get_num_params(self, non_embedding=True):
        """Count parameters; with ``non_embedding`` the learned position table is excluded."""
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and (not self.use_rope) and (self.transformer.wpe is not None):
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights, zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, return_attn_entropy=False, aggregate_heads=False):
        """Run the model on token ids ``idx`` of shape [b, t].

        Args:
            idx: Long tensor of token ids, [b, t].
            targets: Optional target ids, [b, t]; entries equal to -1 are ignored.
            return_attn_entropy: If True, also return one entropy tensor per block.
            aggregate_heads: Passed to the blocks; averages entropy over heads.
        Returns:
            ``(logits, loss)`` or ``(logits, loss, attn_entropies)``.  With
            ``targets`` the logits cover every position and ``loss`` is the
            cross-entropy; without, logits cover only the last position
            ([b, 1, vocab]) and ``loss`` is None.
        Raises:
            ValueError: If the learned position table is in use and ``t``
                exceeds ``config.block_size``.
        """
        device = idx.device
        b, t = idx.size()
        tok_emb = self.transformer.wte(idx)
        if self.use_rope or self.config.use_alibi:
            x = self.transformer.drop(tok_emb)
        else:
            wpe = self.transformer.wpe
            if wpe is not None:
                # A learned table cannot extrapolate; fail with a clear message
                # instead of an opaque embedding index error.
                if t > self.config.block_size:
                    raise ValueError(
                        f"Cannot forward sequence of length {t}: learned position "
                        f"embeddings only cover block_size={self.config.block_size}"
                    )
                pos = torch.arange(0, t, dtype=torch.long, device=device)
                pos_emb = wpe(pos)
            else:
                pos_emb = 0
            x = self.transformer.drop(tok_emb + pos_emb)
        attn_entropies = []
        for block in self.transformer.h:
            if return_attn_entropy:
                x, entropy = block(x, return_attn_entropy=True, aggregate_heads=aggregate_heads)
                attn_entropies.append(entropy)
            else:
                x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # Inference: only the final position is needed for the next token.
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        if return_attn_entropy:
            return logits, loss, attn_entropies
        return logits, loss

    @torch.no_grad()
    def generate_and_compute_perplexity(self, prompt, ground_truth, temperature=1.0, return_attn_entropy=False, aggregate_heads=False):
        """Teacher-force ``ground_truth`` after ``prompt`` and measure perplexity.

        The continuation tokens are taken from ``ground_truth`` (nothing is
        sampled); at each step the loss of the true next token is accumulated.

        Returns:
            ``(idx, perplexity, avg_entropy)`` where ``idx`` is the prompt
            extended by the teacher-forced tokens and ``avg_entropy`` is the
            mean over layers of the mean attention entropy on the prompt, or
            None when ``return_attn_entropy`` is False.
        """
        if return_attn_entropy:
            _, _, attn_entropies = self(prompt, return_attn_entropy=True, aggregate_heads=aggregate_heads)
            per_layer_avgs = [entropy.mean().item() for entropy in attn_entropies]
            avg_entropy = np.mean(per_layer_avgs)
        else:
            avg_entropy = None
        total_loss = 0.0
        total_tokens = 0
        prompt_length = prompt.size(1)
        num_target_tokens = ground_truth.size(1) - prompt_length
        idx = prompt.clone()
        for i in range(num_target_tokens):
            logits, _ = self(idx)
            logits = logits[:, -1, :] / temperature
            target = ground_truth[:, prompt_length + i]
            loss = F.cross_entropy(logits, target, reduction='sum')
            total_loss += loss.item()
            total_tokens += target.numel()
            idx = torch.cat((idx, target.unsqueeze(1)), dim=1)
        avg_neg_log_likelihood = total_loss / total_tokens if total_tokens > 0 else float('inf')
        perplexity = math.exp(avg_neg_log_likelihood)
        return idx, perplexity, avg_entropy

    @torch.no_grad()
    def generate_until_end(self, idx, temperature=1.0, top_k=None, max_new_tokens=1000, eos_token_id=50256):
        """Sample tokens until every sequence has emitted EOS or the budget is spent.

        Args:
            idx: Long tensor of prompt token ids, [b, t].
            temperature: Divides the logits before sampling.
            top_k: If given, sample only from the ``top_k`` most likely tokens.
            max_new_tokens: Upper bound on the number of sampled tokens.
            eos_token_id: Token id that terminates a sequence (GPT-2's
                end-of-text id by default).
        Returns:
            ``idx`` extended by the sampled tokens.  With a batch, rows that
            finished early keep receiving sampled tokens until all rows are done.
        """
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        for _ in range(max_new_tokens):
            logits, _ = self(idx)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            # Per-row bookkeeping; ``.item()`` would fail for batch size > 1.
            finished |= idx_next.squeeze(1) == eos_token_id
            if bool(finished.all()):
                break
        return idx
319
+
320
+ # ----------------------------
321
+ # Utility Functions for Training & Evaluation
322
+ # ----------------------------
323
+
324
# Data Loader Functions
train_data_path = "/data1/home/nitinvetcha/Topics in AI/Streamlined/COLM2025/train_tinystories.bin"
val_data_path = "/data1/home/nitinvetcha/Topics in AI/Streamlined/COLM2025/val_tinystories.bin"


def get_batch(split):
    """Sample a random batch of (input, next-token target) windows.

    Reads the uint16 token file for ``split`` ('train' selects the training
    file, anything else the validation file) through a memory map and draws
    ``batch_size`` windows of ``gptconf.block_size`` tokens.  Targets are the
    inputs shifted by one token.  Both tensors are int64 and moved to the
    module-level ``device``.
    """
    if split == 'train':
        data_path = train_data_path
    else:
        data_path = val_data_path
    # The memory map is re-created on every call rather than kept open.
    data = np.memmap(data_path, dtype=np.uint16, mode='r')
    window = gptconf.block_size
    upper = max(1, len(data) - window)
    starts = torch.randint(0, upper, (batch_size,))
    inputs = [torch.from_numpy(data[s:s + window].astype(np.int64)) for s in starts]
    targets = [torch.from_numpy(data[s + 1:s + 1 + window].astype(np.int64)) for s in starts]
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)
336
+
337
def evaluate_prompt_perplexity(model, token_file, prompt_length, num_trials, generation_params, device):
    """Estimate average perplexity and attention entropy for one prompt length.

    For each trial a random window of ``prompt_length + max_new_tokens`` tokens
    is cut from ``token_file``; the first ``prompt_length`` tokens form the
    prompt and the whole window is the ground truth handed to
    ``model.generate_and_compute_perplexity``.

    Args:
        model: Object providing ``generate_and_compute_perplexity``.
        token_file: Path to a flat file of uint16 token ids.
        prompt_length: Number of prompt tokens per trial.
        num_trials: Number of random windows to evaluate.
        generation_params: Dict; ``"max_new_tokens"`` (default 50) and
            ``"temperature"`` (default 1.0) are read.
        device: Device the token tensors are moved to.
    Returns:
        ``(avg_perplexity, avg_entropy)`` averaged over the trials.
    Raises:
        ValueError: If the token file holds fewer tokens than one window.
    """
    tokens = np.fromfile(token_file, dtype=np.uint16)
    total_tokens = len(tokens)
    max_new_tokens = generation_params.get("max_new_tokens", 50)
    temperature = generation_params.get("temperature", 1.0)
    total_length = prompt_length + max_new_tokens
    if total_tokens < total_length:
        raise ValueError(
            f"Token file {token_file!r} has {total_tokens} tokens but "
            f"{total_length} are needed (prompt {prompt_length} + {max_new_tokens} new)."
        )
    perplexities = []
    entropy_trials = []
    for trial in range(num_trials):
        start_idx = random.randint(0, total_tokens - total_length)
        # Cast to int64 before handing to torch: uint16 arrays are not
        # convertible on every torch version, and get_batch does the same.
        window = tokens[start_idx:start_idx + total_length].astype(np.int64)
        ground_truth_tensor = torch.from_numpy(window).unsqueeze(0).to(device)
        prompt_tensor = ground_truth_tensor[:, :prompt_length].contiguous()
        _, ppl, trial_entropy = model.generate_and_compute_perplexity(
            prompt_tensor, ground_truth_tensor,
            temperature=temperature,
            return_attn_entropy=True, aggregate_heads=True
        )
        perplexities.append(ppl)
        entropy_trials.append(trial_entropy)
        print(f"Trial {trial+1}/{num_trials} for prompt length {prompt_length}: Perplexity = {ppl:.2f}, Avg Entropy = {trial_entropy:.4f}")
    avg_ppl = np.mean(perplexities)
    avg_entropy = np.mean(entropy_trials)
    print(f"Prompt Length {prompt_length} - Avg Perplexity: {avg_ppl:.2f}, Avg Attention Entropy: {avg_entropy:.4f}\n")
    return avg_ppl, avg_entropy
363
+
364
+ # ----------------------------
365
+ # Training Loop
366
+ # ----------------------------
367
# Training hyperparameters
batch_size = 12
max_iters = 25001
save_interval = 5000
learning_rate = 6e-4
weight_decay = 1e-1
grad_clip = 1.0
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(1337)

# Model configuration: adjust these flags as needed.
model_args = {
    "n_layer": 6,
    "n_head": 6,
    "n_embd": 384,
    "block_size": 64,      # You can change this as needed.
    "bias": False,
    "use_rope": True,
    "use_ape": True,       # Set to True if you want APE.
    "use_alibi": False,    # Set to True to use ALiBi.
    "rope_base": 10000.0,
    "vocab_size": 50304,
    "dropout": 0.0,
}
gptconf = GPTConfig(**model_args)
model = GPT(gptconf).to(device)
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
iter_num = 0
start_time = time.time()
training_losses = []
validation_losses = []
save_iters = []

# Build a flag string for naming: e.g. "rope_ape" or "alibi" etc.
flag_parts = [
    label
    for label, enabled in (
        ("rope", gptconf.use_rope),
        ("ape", gptconf.use_ape),
        ("alibi", gptconf.use_alibi),
    )
    if enabled
]
flag_str = "_".join(flag_parts) if flag_parts else "none"
weight_dir = f"weights_{flag_str}_{gptconf.block_size}"
os.makedirs(weight_dir, exist_ok=True)

while iter_num < max_iters:
    # One optimisation step on a fresh training batch.
    X_train, Y_train = get_batch('train')
    optimizer.zero_grad()
    logits, loss_train = model(X_train, Y_train)
    loss_train.backward()
    if grad_clip > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    training_losses.append(loss_train.item())

    # One validation batch per iteration so both curves share an x-axis.
    model.eval()
    X_val, Y_val = get_batch('val')
    with torch.no_grad():
        logits_val, loss_val = model(X_val, Y_val)
    validation_losses.append(loss_val.item())
    model.train()

    if iter_num % 100 == 0:
        elapsed = time.time() - start_time
        print(f"Iter {iter_num:5d}: train loss = {loss_train.item():.4f}, val loss = {loss_val.item():.4f}, time/iter = {elapsed/(iter_num+1):.4f}s")

    if iter_num > 0 and iter_num % save_interval == 0:
        save_iters.append(iter_num)
        ckpt = {
            'iter_num': iter_num,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'training_losses': training_losses,
            'validation_losses': validation_losses,
            'save_iters': save_iters,
        }
        ckpt_path = os.path.join(weight_dir, f"ckpt_{iter_num}.pt")
        torch.save(ckpt, ckpt_path)
        print(f"Checkpoint saved to {ckpt_path}")

    iter_num += 1

print("Training complete.")

# Loss curves, one point per iteration.
plt.figure(figsize=(10, 6))
plt.plot(range(len(training_losses)), training_losses, label="Training Loss")
plt.plot(range(len(validation_losses)), validation_losses, label="Validation Loss", alpha=0.7)
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("Training and Validation Loss per Iteration")
plt.legend()
plt.grid(True)
plt.show()
462
+
463
+ # ----------------------------
464
+ # Perplexity & Entropy Evaluation
465
+ # ----------------------------
466
+
467
# The training loop leaves the model in train mode; switch to eval mode so
# that dropout (when configured > 0) does not randomise the measurements.
model.eval()

token_file = val_data_path  # Use validation data for evaluation.
prompt_lengths = [64, 128, 256, 512, 1024, 2048, 4096, 8192]
num_trials = 5
generation_params = {"temperature": 1.0, "max_new_tokens": 50}

avg_perplexities = []
avg_entropies = []

# Sweep prompt lengths, including ones far beyond the training block size.
for pl in prompt_lengths:
    print(f"Evaluating for prompt length: {pl}")
    avg_ppl, avg_entropy = evaluate_prompt_perplexity(model, token_file, pl, num_trials, generation_params, device)
    avg_perplexities.append(avg_ppl)
    avg_entropies.append(avg_entropy)

results = {
    "prompt_lengths": prompt_lengths,
    "avg_perplexities": avg_perplexities,
    "avg_entropies": avg_entropies
}
results_filename = f"results_{flag_str}_{gptconf.block_size}.json"
with open(results_filename, "w") as f:
    json.dump(results, f)
print(f"Results saved to {results_filename}")

plt.figure(figsize=(8, 6))
plt.plot(prompt_lengths, avg_perplexities, marker='o')
plt.xlabel("Prompt Length")
plt.ylabel("Avg Generated Perplexity")
plt.title("Avg Generated Perplexity vs Prompt Length")
plt.grid(True)
plt.xscale('log')
plt.savefig(f"avg_generated_perplexity_{flag_str}_{gptconf.block_size}.png")
plt.show()

plt.figure(figsize=(8, 6))
plt.plot(prompt_lengths, avg_entropies, marker='o', color='red')
plt.xlabel("Prompt Length")
plt.ylabel("Avg Attention Entropy")
plt.title("Avg Attention Entropy vs Prompt Length\n(Averaged over Layers)")
plt.grid(True)
plt.xscale('log')
plt.savefig(f"avg_attention_entropy_{flag_str}_{gptconf.block_size}.png")
plt.show()
config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "custom",
3
+ "architectures": ["APE"],
4
+ "bias" : "False",
5
+ "use_rope" : "True",
6
+ "use_ape" : "True",
7
+ "use_alibi" : "False",
8
+ "n_layer": 6,
9
+ "n_head": 6,
10
+ "n_embd": 384,
11
+ "block_size": 64,
12
+ "vocab_size": 50304,
13
+ "rope_base": 10000.0,
14
+ "dropout" : 0
15
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ed529ea1054b5d58bc32ab9f3a1ff524cc3ade9788d909260aefcf144ab4c40
3
+ size 359869364