AbstractPhil committed on
Commit
33f612a
·
verified ·
1 Parent(s): 8f6ef1c

Create trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +912 -0
trainer.py ADDED
@@ -0,0 +1,912 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #@title Geometric Autoregressive LM - Full Training with HF Upload + TensorBoard generated valid Shakespeare
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import DataLoader, Dataset
6
+ from torch.utils.tensorboard import SummaryWriter
7
+ import math
8
+ from itertools import combinations
9
+ import time
10
+ import os
11
+ import json
12
+ from tqdm.auto import tqdm
13
+ from pathlib import Path
14
+
15
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Device: {device}")
17
+
18
+ from geovocab2.shapes.factory.simplex_factory import SimplexFactory
19
+ from huggingface_hub import HfApi, create_repo, upload_folder
20
+ import tiktoken
21
+
22
+ # ============================================================================
23
+ # CONFIG
24
+ # ============================================================================
25
+
26
# Hub repository that receives the uploaded checkpoint folder.
HF_REPO = "AbstractPhil/ksimplex-llm-prototype"
# Every launch gets its own name derived from the current Unix timestamp.
RUN_NAME = "run_{}".format(int(time.time()))
CHECKPOINT_DIR = Path(f"./checkpoints/{RUN_NAME}")
TENSORBOARD_DIR = Path(f"./runs/{RUN_NAME}")

# Create both output directories up front so later writes cannot fail on a
# missing parent folder.
for _out_dir in (CHECKPOINT_DIR, TENSORBOARD_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)
33
+
34
+ # ============================================================================
35
+ # CAYLEY-MENGER VALIDATOR
36
+ # ============================================================================
37
+
38
class CMValidator(nn.Module):
    """Cayley-Menger evaluator for k-simplices.

    Given vertex coordinates of shape ``(..., k + 1, embed_dim)`` the module
    returns the squared edge lengths for every unordered vertex pair and the
    signed squared volume obtained from the Cayley-Menger determinant.
    """

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1

        # All unordered vertex pairs (i < j); stored as index buffers so they
        # follow the module across devices.
        vertex_pairs = list(combinations(range(self._nv), 2))
        self._npairs = len(vertex_pairs)
        first_idx = [i for i, _ in vertex_pairs]
        second_idx = [j for _, j in vertex_pairs]
        self.register_buffer('_pi', torch.tensor(first_idx, dtype=torch.long))
        self.register_buffer('_pj', torch.tensor(second_idx, dtype=torch.long))

        # vol^2 = (-1)^(k+1) / (2^k * (k!)^2) * det(CM)
        self._prefactor = (-1.0) ** (k + 1) / ((2.0 ** k) * (math.factorial(k) ** 2))

    def forward(self, verts):
        """Return ``(d2_pairs, vol2)`` for ``verts`` of shape ``(..., V, E)``."""
        # Squared distances from the Gram matrix: |a - b|^2 = |a|^2 + |b|^2 - 2 a.b
        gram = torch.einsum('...ve,...we->...vw', verts, verts)
        sq_norms = gram.diagonal(dim1=-2, dim2=-1)
        sq_dist = sq_norms[..., :, None] + sq_norms[..., None, :] - 2 * gram
        # Round-off can push values slightly below zero; clip them.
        sq_dist = F.relu(sq_dist)

        pair_sq_dist = sq_dist[..., self._pi, self._pj]

        # Border the distance matrix with a row/column of ones and a zero in
        # the corner to obtain the (V + 1) x (V + 1) Cayley-Menger matrix.
        cayley_menger = F.pad(sq_dist, (1, 0, 1, 0), value=1.0)
        cayley_menger[..., 0, 0] = 0.0

        signed_vol2 = self._prefactor * torch.linalg.det(cayley_menger)
        return pair_sq_dist, signed_vol2
71
+
72
+
73
+ # ============================================================================
74
+ # K-SIMPLEX CHANNEL ENCODER
75
+ # ============================================================================
76
+
77
class KSimplexChannel(nn.Module):
    """Encode a hidden vector as a deformed k-simplex plus gated vertex features.

    The output concatenates ``feat_dim`` gated features with the geometric
    descriptor (all pairwise squared distances followed by the squared
    volume), so ``out_dim = feat_dim + n_pairs + 1``.
    """

    # Scale applied to the predicted vertex offsets before they are added to
    # the template simplex.
    BASE_DEFORM = 0.05

    def __init__(self, k, in_dim, edim, feat_dim):
        super().__init__()
        num_vertices = k + 1
        self._k = k
        self._nv = num_vertices
        self._edim = edim
        self._feat_dim = feat_dim

        self._cm = CMValidator(k)
        # Pairwise squared distances plus one slot for the squared volume.
        self._geo_dim = self._cm._npairs + 1

        # NOTE(review): presumably a (k + 1, edim) tensor holding the vertices
        # of a regular simplex - confirm against SimplexFactory.
        template = SimplexFactory(
            k=k, embed_dim=edim, method="regular", scale=1.0
        ).build_torch(dtype=torch.float32)
        self.register_buffer('_template', template)

        self._to_coords = nn.Linear(in_dim, num_vertices * edim)
        self._to_feats = nn.Linear(in_dim, num_vertices * feat_dim)

        self._geo_gate = nn.Sequential(
            nn.Linear(self._geo_dim, feat_dim),
            nn.Sigmoid(),
        )

        self._out_dim = feat_dim + self._geo_dim

    @property
    def out_dim(self):
        """Width of the last dimension of the encoded output."""
        return self._out_dim

    def forward(self, x):
        """Return ``(out, vol2, mean_d2)`` for inputs of shape ``(..., in_dim)``."""
        # Predicted per-vertex offsets, applied as a small deformation of the template.
        offsets = self._to_coords(x).unflatten(-1, (self._nv, self._edim))
        vertices = self._template + self.BASE_DEFORM * offsets

        per_vertex = self._to_feats(x).unflatten(-1, (self._nv, self._feat_dim))

        sq_dists, vol2 = self._cm(vertices)
        geometry = torch.cat([sq_dists, vol2.unsqueeze(-1)], dim=-1)

        geo_gate = self._geo_gate(geometry)
        # Very steep sigmoid: close to 1 for positive squared volume and close
        # to 0 for negative squared volume.
        validity = torch.sigmoid(vol2 * 1e6).unsqueeze(-1)

        pooled = per_vertex.mean(dim=-2) * geo_gate * validity
        encoded = torch.cat([pooled, geometry], dim=-1)

        return encoded, vol2, sq_dists.mean(dim=-1)
124
+
125
+
126
+ # ============================================================================
127
+ # TOKEN TO K-SIMPLEX CHANNELS
128
+ # ============================================================================
129
+
130
class TokenToKChannels(nn.Module):
    """Project token embeddings and encode them with one simplex channel per k.

    Channel ``i`` (0-based) uses a ``(i + 1)``-simplex. Channel outputs have
    different widths, so each is zero-padded on the right to the widest one
    before stacking.
    """

    def __init__(self, embed_dim, depth, edim, feat_dim, hidden=256):
        super().__init__()
        self._depth = depth

        # Two-layer MLP shared by all channels.
        self._proj = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
        )

        encoders = [
            KSimplexChannel(k=order, in_dim=hidden, edim=edim, feat_dim=feat_dim)
            for order in range(1, depth + 1)
        ]
        self._k_encoders = nn.ModuleList(encoders)

        self._k_out_dims = [encoder.out_dim for encoder in self._k_encoders]
        self._max_out_dim = max(self._k_out_dims)

    def forward(self, x):
        """Return ``(k_channels, vol2, d2_mean)``.

        ``k_channels`` is stacked on a new second-to-last axis; ``vol2`` and
        ``d2_mean`` are stacked on a new last axis (one entry per channel).
        """
        hidden_state = self._proj(x)

        padded_outputs, volumes, mean_dists = [], [], []
        for encoder in self._k_encoders:
            encoded, vol2_k, d2_k = encoder(hidden_state)

            # Right-pad narrower channels with zeros to the common width.
            shortfall = self._max_out_dim - encoded.shape[-1]
            if shortfall > 0:
                encoded = F.pad(encoded, (0, shortfall))

            padded_outputs.append(encoded)
            volumes.append(vol2_k)
            mean_dists.append(d2_k)

        return (
            torch.stack(padded_outputs, dim=-2),
            torch.stack(volumes, dim=-1),
            torch.stack(mean_dists, dim=-1),
        )
173
+
174
+
175
+ # ============================================================================
176
+ # K-CHANNEL CROSS-ATTENTION
177
+ # ============================================================================
178
+
179
class KChannelCrossAttention(nn.Module):
    """Multi-head self-attention across the K simplex channels of each token.

    Input and output have shape ``(B, T, K, feat_dim)``. Attention is computed
    independently for every (batch, position) pair over the K axis, and the
    result is added to the input (residual connection).

    Raises:
        ValueError: if ``feat_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1):
        super().__init__()
        # Fail early with a clear message instead of an opaque view() error
        # in the first forward pass.
        if feat_dim % num_heads != 0:
            raise ValueError(
                f"feat_dim ({feat_dim}) must be divisible by num_heads ({num_heads})"
            )

        self._depth = depth
        self._feat_dim = feat_dim
        self._num_heads = num_heads
        self._head_dim = feat_dim // num_heads

        self._norm_q = nn.LayerNorm(feat_dim)
        self._norm_kv = nn.LayerNorm(feat_dim)

        self._to_q = nn.Linear(feat_dim, feat_dim)
        self._to_k = nn.Linear(feat_dim, feat_dim)
        self._to_v = nn.Linear(feat_dim, feat_dim)
        self._out = nn.Linear(feat_dim, feat_dim)
        self._drop = nn.Dropout(dropout)

        self._scale = self._head_dim ** -0.5

    def forward(self, x):
        """Apply channel attention to ``x`` of shape ``(B, T, K, feat_dim)``."""
        # `D` rather than `F` so torch.nn.functional is not shadowed.
        B, T, K, D = x.shape

        # reshape() also accepts non-contiguous inputs, where view() raises.
        tokens = x.reshape(B * T, K, D)

        q = self._to_q(self._norm_q(tokens))
        # Keys and values share one normalized input; compute it once.
        kv_input = self._norm_kv(tokens)
        k = self._to_k(kv_input)
        v = self._to_v(kv_input)

        def split_heads(t):
            # (B*T, K, D) -> (B*T, heads, K, head_dim)
            return t.view(B * T, K, self._num_heads, self._head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        attn = (q @ k.transpose(-2, -1)) * self._scale
        attn = attn.softmax(dim=-1)
        attn = self._drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B * T, K, D)
        out = self._out(out)
        out = self._drop(out)

        return x + out.view(B, T, K, D)
220
+
221
+
222
+ # ============================================================================
223
+ # CAUSAL SEQUENCE ATTENTION
224
+ # ============================================================================
225
+
226
class CausalSequenceAttention(nn.Module):
    """Causal multi-head self-attention along the sequence axis.

    The K channels of every token are flattened into one vector of size
    ``depth * feat_dim``. Each position attends only to itself and earlier
    positions, and the result is added to the input (residual connection).

    Raises:
        ValueError: if ``depth * feat_dim`` is not divisible by ``num_heads``
            (constructor), or if the sequence is longer than ``max_seq_len``
            (forward).
    """

    def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1, max_seq_len=2048):
        super().__init__()
        total_dim = depth * feat_dim
        # Fail early with a clear message instead of an opaque view() error
        # in the first forward pass.
        if total_dim % num_heads != 0:
            raise ValueError(
                f"depth * feat_dim ({total_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self._num_heads = num_heads
        self._head_dim = total_dim // num_heads

        self._norm = nn.LayerNorm(total_dim)
        self._to_qkv = nn.Linear(total_dim, 3 * total_dim)
        self._out = nn.Linear(total_dim, total_dim)
        self._drop = nn.Dropout(dropout)

        self._scale = self._head_dim ** -0.5

        # Lower-triangular mask: True where a query may attend to a key.
        self.register_buffer(
            '_causal_mask',
            torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
        )

    def forward(self, x):
        """Apply causal attention to ``x`` of shape ``(B, T, K, feat_dim)``."""
        # `D` rather than `F` so torch.nn.functional is not shadowed.
        B, T, K, D = x.shape

        max_len = self._causal_mask.size(0)
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {max_len}"
            )

        # reshape() also accepts non-contiguous inputs, where view() raises.
        x_flat = x.reshape(B, T, K * D)
        x_norm = self._norm(x_flat)

        qkv = self._to_qkv(x_norm).chunk(3, dim=-1)
        # (B, T, total) -> (B, heads, T, head_dim)
        q, k, v = [
            t.view(B, T, self._num_heads, self._head_dim).transpose(1, 2)
            for t in qkv
        ]

        attn = (q @ k.transpose(-2, -1)) * self._scale

        # Block attention to future positions before the softmax.
        mask = self._causal_mask[:T, :T]
        attn = attn.masked_fill(~mask, float('-inf'))
        attn = attn.softmax(dim=-1)
        attn = self._drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, T, K * D)
        out = self._out(out)
        out = self._drop(out)

        return x + out.view(B, T, K, D)
267
+
268
+
269
+ # ============================================================================
270
+ # TRANSFORMER BLOCK
271
+ # ============================================================================
272
+
273
class GeoBlock(nn.Module):
    """One transformer block: channel attention, causal sequence attention, MLP.

    Operates on tensors of shape ``(B, T, K, feat_dim)``. The MLP works on the
    flattened ``K * feat_dim`` vector of each token and is applied with a
    pre-norm residual connection.
    """

    def __init__(self, depth, feat_dim, num_heads, mlp_ratio=4.0, dropout=0.1, max_seq_len=2048):
        super().__init__()

        self._k_attn = KChannelCrossAttention(depth, feat_dim, num_heads, dropout)
        self._seq_attn = CausalSequenceAttention(depth, feat_dim, num_heads, dropout, max_seq_len)

        total_dim = depth * feat_dim
        mlp_width = int(total_dim * mlp_ratio)
        self._norm = nn.LayerNorm(total_dim)
        self._mlp = nn.Sequential(
            nn.Linear(total_dim, mlp_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_width, total_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Run both attention stages and the residual MLP; shape is preserved."""
        batch, seq, channels, width = x.shape

        # Mix across channels first, then across (past) sequence positions.
        x = self._seq_attn(self._k_attn(x))

        merged = x.view(batch, seq, channels * width)
        merged = merged + self._mlp(self._norm(merged))
        return merged.view(batch, seq, channels, width)
301
+
302
+
303
+ # ============================================================================
304
+ # GEOMETRIC LM
305
+ # ============================================================================
306
+
307
class GeometricLM(nn.Module):
    """Autoregressive language model built on k-simplex channel encodings.

    Tokens are embedded (token + learned position embeddings), encoded into
    ``depth`` simplex channels, projected to ``feat_dim`` per channel, passed
    through ``num_blocks`` GeoBlocks and finally mapped to vocabulary logits
    from the flattened ``depth * feat_dim`` representation.
    """

    def __init__(
        self,
        vocab_size,
        max_seq_len=512,
        embed_dim=256,
        depth=4,
        edim=16,
        feat_dim=64,
        hidden=256,
        num_heads=8,
        num_blocks=8,
        dropout=0.1,
    ):
        super().__init__()

        self._vocab_size = vocab_size
        self._max_seq_len = max_seq_len
        self._depth = depth
        self._feat_dim = feat_dim

        self._tok_embed = nn.Embedding(vocab_size, embed_dim)
        self._pos_embed = nn.Embedding(max_seq_len, embed_dim)

        self._tok_to_k = TokenToKChannels(embed_dim, depth, edim, feat_dim, hidden)
        self._max_out_dim = self._tok_to_k._max_out_dim

        # Bring every (padded) channel encoding down to feat_dim.
        self._proj = nn.Linear(self._max_out_dim, feat_dim)

        self._blocks = nn.ModuleList([
            GeoBlock(depth, feat_dim, num_heads, dropout=dropout, max_seq_len=max_seq_len)
            for _ in range(num_blocks)
        ])

        total_dim = depth * feat_dim
        self._norm = nn.LayerNorm(total_dim)
        self._lm_head = nn.Linear(total_dim, vocab_size, bias=False)

        # Hyper-parameters kept for logging and checkpoint metadata.
        self._config = {
            'vocab_size': vocab_size,
            'max_seq_len': max_seq_len,
            'embed_dim': embed_dim,
            'depth': depth,
            'edim': edim,
            'feat_dim': feat_dim,
            'hidden': hidden,
            'num_heads': num_heads,
            'num_blocks': num_blocks,
            'dropout': dropout,
            'total_dim': total_dim,
        }

    def forward(self, tokens):
        """Compute next-token logits.

        Args:
            tokens: integer tensor of shape ``(B, T)`` with ``T <= max_seq_len``.

        Returns:
            ``(logits, info)`` where ``logits`` has shape ``(B, T, vocab_size)``
            and ``info`` holds the per-channel ``'vol2'`` and ``'d2_mean'``
            tensors produced by the simplex encoders.

        Raises:
            ValueError: if ``T`` exceeds ``max_seq_len``.
        """
        B, T = tokens.shape
        # The position embedding table only has max_seq_len rows; report the
        # problem clearly instead of failing inside the embedding lookup.
        if T > self._max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {self._max_seq_len}"
            )

        pos = torch.arange(T, device=tokens.device)
        x = self._tok_embed(tokens) + self._pos_embed(pos)

        k_channels, vol2, d2_mean = self._tok_to_k(x)
        k_channels = self._proj(k_channels)

        for blk in self._blocks:
            k_channels = blk(k_channels)

        out = k_channels.flatten(-2)
        logits = self._lm_head(self._norm(out))

        return logits, {'vol2': vol2, 'd2_mean': d2_mean}

    @torch.no_grad()
    def generate(self, prompt_tokens, max_new_tokens=100, temperature=1.0, top_k=50):
        """Sample ``max_new_tokens`` tokens after ``prompt_tokens``.

        Args:
            prompt_tokens: integer tensor of shape ``(B, T0)``.
            max_new_tokens: number of tokens to append.
            temperature: positive softmax temperature.
            top_k: keep only the ``top_k`` most likely tokens before sampling;
                ``0`` or ``None`` disables the filter. Values larger than the
                vocabulary are clamped.

        Returns:
            Tensor of shape ``(B, T0 + max_new_tokens)`` (prompt included).

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        # Sample in eval mode (no dropout) but hand the model back in the
        # mode the caller left it in.
        was_training = self.training
        self.eval()
        tokens = prompt_tokens.clone()

        try:
            for _ in range(max_new_tokens):
                # Only the most recent max_seq_len tokens fit the context.
                ctx = tokens[:, -self._max_seq_len:]
                logits, _ = self(ctx)
                logits = logits[:, -1, :] / temperature

                if top_k is not None and top_k > 0:
                    # torch.topk raises when k exceeds the vocabulary size.
                    k_eff = min(top_k, logits.size(-1))
                    kth_value = torch.topk(logits, k_eff).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth_value, float('-inf'))

                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)
                tokens = torch.cat([tokens, next_tok], dim=1)
        finally:
            self.train(was_training)

        return tokens
395
+
396
+
397
+ # ============================================================================
398
+ # DATASET
399
+ # ============================================================================
400
+
401
class TokenizedDataset(Dataset):
    """Sliding-window next-token dataset over a flat token sequence.

    Item ``i`` covers ``tokens[i * stride : i * stride + seq_len + 1]`` and is
    returned as ``(x, y)`` where ``y`` is ``x`` shifted left by one token.

    Args:
        tokens: sequence of integer token ids.
        seq_len: length of every ``x`` / ``y`` sample.
        stride: step between window starts; a falsy value selects
            ``seq_len // 2`` (50% overlap), with a minimum of 1.
    """

    def __init__(self, tokens, seq_len, stride=None):
        self._tokens = tokens
        self._seq_len = seq_len
        # max(1, ...) keeps the default stride usable when seq_len == 1.
        self._stride = stride if stride else max(1, seq_len // 2)  # 50% overlap max

    def __len__(self):
        # A window needs seq_len + 1 tokens (inputs plus the shifted target).
        last_start = len(self._tokens) - (self._seq_len + 1)
        if last_start < 0:
            return 0
        # Starts 0, stride, 2*stride, ... up to and including last_start.
        return last_start // self._stride + 1

    def __getitem__(self, idx):
        size = len(self)
        if idx < 0:
            idx += size
        # Raising IndexError keeps out-of-range access from silently
        # returning truncated or empty samples.
        if not 0 <= idx < size:
            raise IndexError(f"index {idx} out of range for dataset of size {size}")

        start = idx * self._stride
        chunk = self._tokens[start:start + self._seq_len + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y
416
+
417
+
418
+ # ============================================================================
419
+ # LOSS & METRICS
420
+ # ============================================================================
421
+
422
def lm_loss(logits, targets, info, ce_weight=1.0, validity_weight=0.1):
    """Cross-entropy plus a penalty on negative squared simplex volumes.

    Args:
        logits: tensor of shape ``(B, T, V)``.
        targets: integer tensor with ``B * T`` class indices.
        info: dict with a ``'vol2'`` tensor of squared volumes.
        ce_weight: weight of the cross-entropy term.
        validity_weight: weight of the validity penalty.

    Returns:
        ``(total, ce, validity)`` as scalar tensors, where ``validity`` is the
        mean of ``relu(-vol2)`` (zero when every squared volume is
        non-negative).
    """
    B, T, V = logits.shape
    # reshape() instead of view() so non-contiguous tensors (for example
    # transposed or sliced logits) are accepted as well.
    ce = F.cross_entropy(logits.reshape(B * T, V), targets.reshape(B * T))
    validity = F.relu(-info['vol2']).mean()
    total = ce_weight * ce + validity_weight * validity
    return total, ce, validity
428
+
429
+
430
@torch.no_grad()
def compute_metrics(info, depth):
    """Summarize simplex geometry as plain Python floats.

    Reads ``info['vol2']`` and ``info['d2_mean']`` (last axis indexes the
    channel) and reports the overall fraction of positive squared volumes plus,
    for each channel ``k1 .. k{depth}``, its valid fraction, mean squared
    volume and mean squared edge length.
    """
    vol2, d2_mean = info['vol2'], info['d2_mean']

    # 1.0 where the squared volume is strictly positive, else 0.0.
    is_valid = (vol2 > 0).float()

    metrics = {'valid_rate': is_valid.mean().item()}
    for level in range(1, depth + 1):
        channel = level - 1
        metrics[f'k{level}_valid'] = is_valid[..., channel].mean().item()
        metrics[f'k{level}_vol2'] = vol2[..., channel].mean().item()
        metrics[f'k{level}_d2'] = d2_mean[..., channel].mean().item()
    return metrics
441
+
442
+
443
+ # ============================================================================
444
+ # SANITY CHECK
445
+ # ============================================================================
446
+
447
@torch.no_grad()
def sanity_check(model, enc, device):
    """Verify no information leak.

    Runs three printed checks and returns True only if all of them pass:

    1. Random inputs scored against random targets give a cross-entropy near
       ``ln(vocab_size)``.
    2. Changing tokens in the second half of a sequence leaves the logits of
       the first half unchanged (causal masking works).
    3. ``TokenizedDataset`` yields targets that are the inputs shifted by one.

    The probe sizes adapt to ``enc.n_vocab`` and, when present, to the model's
    ``_max_seq_len``, so small vocabularies or short context windows do not
    cause index errors. The model's train/eval mode is restored on exit.
    """
    print("\n" + "=" * 60)
    print("SANITY CHECK")
    print("=" * 60)

    was_training = model.training
    model.eval()

    vocab = enc.n_vocab
    # Cap the probes by what the tokenizer and the model can actually handle.
    seq_len = min(256, getattr(model, '_max_seq_len', 256))
    half = seq_len // 2
    token_cap = min(1000, vocab)
    changed_token = min(999, vocab - 1)

    # Test 1: Random input should give high CE
    random_tokens = torch.randint(0, token_cap, (4, seq_len), device=device)
    logits, _ = model(random_tokens)
    random_targets = torch.randint(0, vocab, (4, seq_len), device=device)
    ce = F.cross_entropy(logits.view(-1, vocab), random_targets.view(-1))

    expected_ce = math.log(vocab)
    print(f"Test 1 - Random input:")
    print(f" CE: {ce.item():.2f} (expected ~{expected_ce:.2f})")
    print(f" PPL: {math.exp(min(ce.item(), 20)):.0f} (expected ~{enc.n_vocab})")

    # Should be close to ln(50257) ≈ 10.8 for GPT-2; the threshold is scaled
    # down for smaller vocabularies and stays at 8.0 for the GPT-2 one.
    ce_threshold = min(8.0, 0.75 * expected_ce)
    test1_pass = ce.item() > ce_threshold
    print(f" Status: {'✓ PASS' if test1_pass else '✗ FAIL'}")

    # Test 2: Causal mask - early positions shouldn't depend on late tokens
    tokens1 = torch.zeros(1, seq_len, dtype=torch.long, device=device)
    tokens2 = torch.zeros(1, seq_len, dtype=torch.long, device=device)
    tokens2[0, half:] = changed_token  # Change later tokens

    logits1, _ = model(tokens1)
    logits2, _ = model(tokens2)

    diff_early = (logits1[0, :half] - logits2[0, :half]).abs().max().item()
    diff_late = (logits1[0, half:] - logits2[0, half:]).abs().max().item()

    print(f"\nTest 2 - Causal mask:")
    print(f" Early positions diff: {diff_early:.6f} (should be ~0)")
    print(f" Late positions diff: {diff_late:.6f} (should be >0)")

    test2_pass = diff_early < 1e-5 and diff_late > 1e-3
    print(f" Status: {'✓ PASS' if test2_pass else '✗ FAIL'}")

    # Test 3: Dataset sanity - x and y should be offset by 1
    print(f"\nTest 3 - Dataset offset:")
    test_tokens = list(range(100))
    ds = TokenizedDataset(test_tokens, seq_len=10)
    x, y = ds[0]
    # On consecutive integers the shifted target must equal input + 1.
    offset_correct = bool(torch.equal(x + 1, y))
    print(f" x: {x[:5].tolist()}...")
    print(f" y: {y[:5].tolist()}...")
    print(f" Offset correct: {'✓ PASS' if offset_correct else '✗ FAIL'}")

    print("=" * 60)

    all_pass = test1_pass and test2_pass and offset_correct
    if not all_pass:
        print("⚠️ WARNING: Some sanity checks failed!")
    else:
        print("✓ All sanity checks passed!")

    print("=" * 60 + "\n")

    # Hand the model back in the mode it arrived in.
    model.train(was_training)
    return all_pass
510
+
511
+
512
+ # ============================================================================
513
+ # GENERATION SAMPLING
514
+ # ============================================================================
515
+
516
# Fixed prompts used for the periodic generation samples. Trailing spaces and
# the trailing newline are part of the prompts and are kept on purpose.
PROMPTS = [
    "ROMEO: ",
    "JULIET: ",
    "To be or not to be",
    "The king ",
    "Once upon a time",
    "First Citizen:\n",
    "What light through yonder",
    "Friends, Romans, countrymen",
    "Now is the winter of",
    "All the world's a stage",
]
528
+
529
@torch.no_grad()
def generate_samples(model, enc, device, epoch, writer=None):
    """Generate samples from all prompts.

    Every prompt in ``PROMPTS`` is continued by 100 tokens (temperature 0.8,
    top-k 50), printed (first 300 characters) and, when a ``writer`` is
    given, logged to TensorBoard under ``samples/generated``. The model is
    left in train mode. Returns a list of ``{'prompt', 'generated'}`` dicts.
    """
    model.eval()

    banner = '=' * 60
    print(f"\n{banner}")
    print(f"GENERATION SAMPLES - Epoch {epoch}")
    print(f"{banner}")

    samples = []
    for number, prompt in enumerate(PROMPTS, start=1):
        # Batch of one: the encoded prompt becomes a (1, T) tensor.
        encoded_prompt = torch.tensor([enc.encode(prompt)], device=device)

        completed = model.generate(
            encoded_prompt,
            max_new_tokens=100,
            temperature=0.8,
            top_k=50
        )

        text = enc.decode(completed[0].tolist())
        samples.append({'prompt': prompt, 'generated': text})

        print(f"\n--- Prompt {number}: '{prompt.strip()}' ---")
        print(text[:300])
        if len(text) > 300:
            print("...")

    print(f"{banner}\n")

    # Log to tensorboard
    if writer:
        entries = []
        for sample in samples:
            entries.append(
                f"**Prompt:** {sample['prompt']}\n**Generated:**\n{sample['generated'][:500]}"
            )
        writer.add_text("samples/generated", "\n\n".join(entries), epoch)

    model.train()
    return samples
569
+
570
+
571
+ # ============================================================================
572
+ # CHECKPOINTING & HF UPLOAD
573
+ # ============================================================================
574
+
575
def save_checkpoint(model, optimizer, scheduler, epoch, config, metrics, checkpoint_dir):
    """Save checkpoint locally.

    Writes three files into ``checkpoint_dir`` (created if missing):
    ``checkpoint_epoch_XXX.pt``, a copy named ``checkpoint_latest.pt`` and
    ``config.json``.

    Args:
        model: module to save; for a ``torch.compile`` wrapper the wrapped
            module's ``state_dict`` is stored (via ``_orig_mod``).
        optimizer: optimizer whose state is stored.
        scheduler: LR scheduler whose state is stored.
        epoch: epoch number used in the file name and stored in the checkpoint.
        config: JSON-serializable configuration dict.
        metrics: metrics dict stored in the checkpoint.
        checkpoint_dir: target directory (``str`` or ``Path``).

    Returns:
        ``Path`` of the per-epoch checkpoint file.
    """
    import shutil

    # Accept plain strings as well as Path objects, and make sure the target
    # directory exists before writing.
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Unwrap torch.compile'd models so the saved keys carry no wrapper prefix.
    base_model = getattr(model, '_orig_mod', model)
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': base_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': config,
        'metrics': metrics,
    }

    path = checkpoint_dir / f"checkpoint_epoch_{epoch:03d}.pt"
    torch.save(checkpoint, path)

    # Also save latest: copy the file just written instead of serializing
    # the whole state a second time.
    shutil.copyfile(path, checkpoint_dir / "checkpoint_latest.pt")

    # Save config as JSON
    with open(checkpoint_dir / "config.json", 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    print(f"Saved checkpoint: {path}")
    return path
598
+
599
+
600
def upload_to_hf(checkpoint_dir, repo_id, epoch):
    """Upload checkpoint directory to HuggingFace.

    Best effort: any failure is printed and reported through the return
    value instead of being raised. Returns True when the folder upload
    completed, False otherwise.
    """
    uploaded = False
    try:
        hub = HfApi()

        # Create repo if doesn't exist; a failure here is only reported,
        # the upload is still attempted.
        try:
            create_repo(repo_id, exist_ok=True, repo_type="model")
        except Exception as repo_err:
            print(f"Repo creation note: {repo_err}")

        # Upload folder
        hub.upload_folder(
            folder_path=str(checkpoint_dir),
            repo_id=repo_id,
            commit_message=f"Epoch {epoch} checkpoint",
        )

        print(f"Uploaded to HuggingFace: {repo_id}")
        uploaded = True
    except Exception as upload_err:
        print(f"HuggingFace upload failed: {upload_err}")
    return uploaded
623
+
624
+
625
+ # ============================================================================
626
+ # TRAIN
627
+ # ============================================================================
628
+
629
def _run_training(writer):
    """Prepare data and model, then run the full epoch loop.

    Logs to the given TensorBoard ``writer`` and returns
    ``(model, enc, best_val, best_ppl)``. The caller owns the writer.
    """
    import urllib.request

    print(f"TensorBoard logs: {TENSORBOARD_DIR}")
    print(f"Checkpoints: {CHECKPOINT_DIR}")
    print(f"HuggingFace repo: {HF_REPO}")

    # Data
    data_path = './data/shakespeare.txt'
    if not os.path.exists(data_path):
        os.makedirs('./data', exist_ok=True)
        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        print("Downloading Shakespeare...")
        urllib.request.urlretrieve(url, data_path)

    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(data_path, 'r', encoding='utf-8') as f:
        text = f.read()

    print(f"Text length: {len(text):,} chars")

    # Tokenizer
    print("Loading tokenizer...")
    enc = tiktoken.get_encoding("gpt2")

    print("Tokenizing...")
    tokens = enc.encode(text)
    print(f"Token count: {len(tokens):,}")
    print(f"Vocab size: {enc.n_vocab:,}")
    print(f"Compression ratio: {len(text) / len(tokens):.2f}x")

    # Split
    seq_len = 256
    split_idx = int(len(tokens) * 0.9)
    train_tokens = tokens[:split_idx]
    val_tokens = tokens[split_idx:]

    train_ds = TokenizedDataset(train_tokens, seq_len)
    val_ds = TokenizedDataset(val_tokens, seq_len)
    # Without this check an empty split only surfaces later as a division
    # by zero when the epoch averages are computed.
    if len(train_ds) == 0 or len(val_ds) == 0:
        raise ValueError(
            f"Not enough tokens for seq_len={seq_len}: "
            f"{len(train_tokens)} train / {len(val_tokens)} val tokens"
        )

    batch_size = 12
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)

    print(f"Train sequences: {len(train_ds):,} ({len(train_dl)} batches)")
    print(f"Val sequences: {len(val_ds):,} ({len(val_dl)} batches)")

    # Model config
    model_config = {
        'vocab_size': enc.n_vocab,
        'max_seq_len': seq_len,
        'embed_dim': 384,
        'depth': 4,
        'edim': 16,
        'feat_dim': 96,
        'hidden': 384,
        'num_heads': 8,
        'num_blocks': 8,
        'dropout': 0.1,
    }

    # Training config
    train_config = {
        'batch_size': batch_size,
        'seq_len': seq_len,
        'lr': 3e-4,
        'weight_decay': 0.1,
        'num_epochs': 14,
        'grad_clip': 1.0,
        'ce_weight': 1.0,
        'validity_weight': 0.1,
    }

    full_config = {
        'model': model_config,
        'training': train_config,
        'data': {
            'train_tokens': len(train_tokens),
            'val_tokens': len(val_tokens),
            'vocab_size': enc.n_vocab,
        },
        'run_name': RUN_NAME,
    }

    # Save config
    with open(CHECKPOINT_DIR / "config.json", 'w', encoding='utf-8') as f:
        json.dump(full_config, f, indent=2)

    # Model
    print("\nBuilding model...")
    model = GeometricLM(**model_config).to(device)

    print(f"\nConfig:")
    for k, v in model._config.items():
        print(f" {k}: {v}")

    params = sum(p.numel() for p in model.parameters())
    print(f" params: {params:,}")
    # model_config is the same dict object as full_config['model'], so the
    # parameter count is added only after the model has been built from it.
    full_config['model']['params'] = params

    # Sanity check BEFORE compile
    sanity_check(model, enc, device)

    # NOTE: compilation is currently disabled; the message is kept as is.
    print("\nCompiling...")
    #model = torch.compile(model, mode="reduce-overhead")

    # Optimizer
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=train_config['lr'],
        weight_decay=train_config['weight_decay']
    )
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=train_config['num_epochs'])

    num_epochs = train_config['num_epochs']
    depth = model._config['depth']

    best_val = float('inf')
    best_ppl = float('inf')
    global_step = 0

    print("\nTraining...")
    print("=" * 120)

    epoch_pbar = tqdm(range(num_epochs), desc="Epochs", position=0)

    for ep in epoch_pbar:
        epoch_start = time.time()
        is_last_epoch = ep == num_epochs - 1

        # ==================== TRAIN ====================
        model.train()
        ce_sum, val_sum, n = 0, 0, 0

        train_pbar = tqdm(train_dl, desc=f"Train {ep+1}", leave=False, position=1)
        for batch_idx, (x, y) in enumerate(train_pbar):
            x, y = x.to(device), y.to(device)

            opt.zero_grad()
            logits, info = model(x)
            loss, ce, val = lm_loss(
                logits, y, info,
                ce_weight=train_config['ce_weight'],
                validity_weight=train_config['validity_weight']
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config['grad_clip'])
            opt.step()

            # Weight by batch size so a smaller final batch does not skew
            # the epoch average.
            ce_sum += ce.item() * x.size(0)
            val_sum += val.item() * x.size(0)
            n += x.size(0)

            # TensorBoard - batch level
            if global_step % 100 == 0:
                writer.add_scalar("train/ce_batch", ce.item(), global_step)
                writer.add_scalar("train/ppl_batch", math.exp(min(ce.item(), 10)), global_step)
                writer.add_scalar("train/validity_batch", val.item(), global_step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], global_step)

            global_step += 1

            train_pbar.set_postfix({
                'CE': f'{ce.item():.3f}',
                'PPL': f'{math.exp(min(ce.item(), 10)):.1f}'
            })

        tr_ce = ce_sum / n
        # CE is capped at 10 before exponentiation to keep the reported
        # perplexity bounded.
        tr_ppl = math.exp(min(tr_ce, 10))
        tr_val = val_sum / n

        # ==================== VAL ====================
        model.eval()
        ce_sum, n = 0, 0
        metrics_agg = []

        val_pbar = tqdm(val_dl, desc=f"Val {ep+1}", leave=False, position=1)
        with torch.no_grad():
            for x, y in val_pbar:
                x, y = x.to(device), y.to(device)
                logits, info = model(x)
                _, ce, _ = lm_loss(logits, y, info)
                ce_sum += ce.item() * x.size(0)
                n += x.size(0)
                metrics_agg.append(compute_metrics(info, depth))

                val_pbar.set_postfix({
                    'CE': f'{ce.item():.3f}',
                    'PPL': f'{math.exp(min(ce.item(), 10)):.1f}'
                })

        va_ce = ce_sum / n
        va_ppl = math.exp(min(va_ce, 10))

        sched.step()

        if va_ce < best_val:
            best_val = va_ce
            best_ppl = va_ppl

        # Aggregate metrics (unweighted mean over validation batches)
        m = {k: sum(d[k] for d in metrics_agg) / len(metrics_agg) for k in metrics_agg[0]}

        epoch_time = time.time() - epoch_start

        # ==================== TENSORBOARD - EPOCH ====================
        writer.add_scalar("epoch/train_ce", tr_ce, ep)
        writer.add_scalar("epoch/train_ppl", tr_ppl, ep)
        writer.add_scalar("epoch/val_ce", va_ce, ep)
        writer.add_scalar("epoch/val_ppl", va_ppl, ep)
        writer.add_scalar("epoch/best_ppl", best_ppl, ep)
        writer.add_scalar("epoch/validity_loss", tr_val, ep)
        writer.add_scalar("epoch/time", epoch_time, ep)

        for k in range(depth):
            writer.add_scalar(f"geometry/k{k+1}_valid", m[f'k{k+1}_valid'], ep)
            writer.add_scalar(f"geometry/k{k+1}_vol2", m[f'k{k+1}_vol2'], ep)
            writer.add_scalar(f"geometry/k{k+1}_d2", m[f'k{k+1}_d2'], ep)

        writer.add_scalar("geometry/valid_rate", m['valid_rate'], ep)

        # ==================== LOGGING ====================
        epoch_pbar.set_postfix({
            'TrPPL': f'{tr_ppl:.1f}',
            'VaPPL': f'{va_ppl:.1f}',
            'Best': f'{best_ppl:.1f}',
            'Valid': f"{m['valid_rate']:.0%}"
        })

        tqdm.write(
            f"\nEp {ep+1:3d} | TrCE {tr_ce:.4f} | VaCE {va_ce:.4f} | "
            f"TrPPL {tr_ppl:7.2f} | VaPPL {va_ppl:7.2f} | BestPPL {best_ppl:.2f} | "
            f"Time {epoch_time:.1f}s"
        )
        # One entry per configured channel, so the line follows `depth`
        # instead of assuming exactly four channels.
        geo_parts = []
        for k in range(1, depth + 1):
            valid_k = m[f'k{k}_valid']
            vol2_k = m[f'k{k}_vol2']
            geo_parts.append(f"k{k} {valid_k:5.1%} vol²={vol2_k:.2e}")
        tqdm.write(" | " + " | ".join(geo_parts))

        # ==================== GENERATE SAMPLES ====================
        if ep % 25 == 0 or is_last_epoch:
            samples = generate_samples(model, enc, device, ep + 1, writer)

            # Save samples to file
            with open(CHECKPOINT_DIR / f"samples_epoch_{ep+1:03d}.json", 'w', encoding='utf-8') as f:
                json.dump(samples, f, indent=2)

        # ==================== CHECKPOINT ====================
        metrics = {
            'epoch': ep + 1,
            'train_ce': tr_ce,
            'train_ppl': tr_ppl,
            'val_ce': va_ce,
            'val_ppl': va_ppl,
            'best_ppl': best_ppl,
            'geometry': m,
        }

        if ep % 2 == 0 or is_last_epoch:
            save_checkpoint(model, opt, sched, ep + 1, full_config, metrics, CHECKPOINT_DIR)

        # ==================== HF UPLOAD ====================
        if is_last_epoch:
            upload_to_hf(CHECKPOINT_DIR, HF_REPO, ep + 1)

    return model, enc, best_val, best_ppl


def train():
    """Train GeometricLM on Tiny Shakespeare and return ``(model, enc)``.

    Downloads the corpus if it is missing, tokenizes it with the GPT-2
    tiktoken encoding, trains for the configured number of epochs while
    logging to TensorBoard, writes checkpoints and sample generations to
    ``CHECKPOINT_DIR`` and uploads that folder to ``HF_REPO`` after the last
    epoch.
    """
    # TensorBoard
    writer = SummaryWriter(log_dir=str(TENSORBOARD_DIR))
    try:
        model, enc, best_val, best_ppl = _run_training(writer)
    finally:
        # Close the event file even when training fails or is interrupted.
        writer.close()

    # ==================== FINAL ====================
    print("\n" + "=" * 120)
    print(f"Training complete!")
    print(f"Best val CE: {best_val:.4f}, PPL: {best_ppl:.2f}")
    print(f"Checkpoints: {CHECKPOINT_DIR}")
    print(f"TensorBoard: {TENSORBOARD_DIR}")
    print(f"HuggingFace: https://huggingface.co/{HF_REPO}")
    print("=" * 120)

    return model, enc
909
+
910
+
911
# Script entry point: run the whole training pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    model, tokenizer = train()