AbstractPhil commited on
Commit
3d51f39
Β·
verified Β·
1 Parent(s): ebd9fd5

Create distillation_trainer_v3.py

Browse files
Files changed (1) hide show
  1. distillation_trainer_v3.py +718 -0
distillation_trainer_v3.py ADDED
@@ -0,0 +1,718 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # DEEP BERT v3 β€” TRAINER
3
+ #
4
+ # Teacher-distilled training. Frozen long-context experts teach the memory
5
+ # system what correct recall looks like.
6
+ #
7
+ # Colab cells:
8
+ # Cell 1: deep_bert_v3.py (architecture)
9
+ # Cell 2: this file (training)
10
+ #
11
+ # Flow per document:
12
+ # 1. ModernBERT (8192 ctx) β†’ teacher_cls (frozen, no grad)
13
+ # 2. Longformer (4096 ctx) β†’ teacher_cls_2 (frozen, no grad)
14
+ # 3. BERT + memory (16Γ—480) β†’ student_cls (memory trains)
15
+ # 4. Loss: student should match teachers
16
+ # ============================================================================
17
+
18
+ import gc
19
+ import json
20
+ import math
21
+ import os
22
+ import time
23
+ from dataclasses import dataclass
24
+ from typing import Any, Dict, List, Optional, Tuple
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from torch.utils.data import Dataset, DataLoader
31
+ from torch.utils.tensorboard import SummaryWriter
32
+ from safetensors.torch import save_file as safetensors_save
33
+ from datasets import load_dataset
34
+ from transformers import AutoModel, AutoTokenizer, BertTokenizer
35
+ from tqdm import tqdm
36
+
37
+
38
+ # ══════════════════════════════════════════════════════════════════
39
+ # CONFIG
40
+ # ══════════════════════════════════════════════════════════════════
41
+
42
@dataclass
class TrainConfig:
    """Hyperparameters for teacher-distilled training of the memory system."""
    # --- Data ---
    max_documents: int = 50000        # cap on training documents loaded
    max_val_documents: int = 500      # cap on validation documents loaded
    segment_length: int = 480         # tokens per student segment
    segment_overlap: int = 64         # token overlap between consecutive segments
    target_chain_segments: int = 16   # desired number of segments per chain
    max_segments: int = 16            # hard cap; chains are zero-padded to this
    min_segments: int = 6             # chains with fewer real segments are dropped

    # --- Teachers (frozen long-context models) ---
    modern_bert_model: str = "answerdotai/ModernBERT-large"
    longformer_model: str = "allenai/longformer-large-4096"
    modern_max_len: int = 8192        # ModernBERT context window
    longformer_max_len: int = 4096    # Longformer context window
    procrustes_n_samples: int = 500  # docs for static pre-alignment

    # --- Training ---
    epochs: int = 10
    batch_size: int = 4
    lr_bank: float = 2e-3     # memory-bank parameter group
    lr_output: float = 5e-4   # output-head parameter group
    lr_proj: float = 1e-3     # teacher-space projector group
    min_lr: float = 1e-6      # cosine-decay floor
    weight_decay: float = 0.01
    grad_clip: float = 1.0
    warmup_steps: int = 300
    tbptt_segments: int = 0  # 0 = no truncation (clean bank, safe now)

    # --- Loss weights ---
    modern_weight: float = 1.0       # ModernBERT distillation term
    longformer_weight: float = 0.5   # Longformer distillation term
    cv_weight: float = 0.05          # geometric CV regularizer
    temperature: float = 0.07        # InfoNCE temperature

    # --- Logging / checkpoints ---
    checkpoint_dir: str = "/home/claude/deep_bert_v3_checkpoints"
    tensorboard_dir: str = "/home/claude/deep_bert_v3_tb"
    log_every: int = 20       # steps between TensorBoard scalar writes
    eval_every: int = 200     # steps between validation passes
    save_every_epoch: bool = True
84
+
85
+
86
# Module-level singleton: every function below reads hyperparameters from TCFG.
TCFG = TrainConfig()
87
+
88
+
89
+ # ══════════════════════════════════════════════════════════════════
90
+ # DATA PIPELINE β€” raw text + student segments
91
+ # ══════════════════════════════════════════════════════════════════
92
+
93
def load_wikitext_documents(split, max_docs):
    """Load WikiText-103 and group its lines into whole documents.

    WikiText-103 stores one line per dataset row; a top-level heading of the
    form "= Title =" (but not a subsection "= = Section = =") marks the start
    of a new article, and blank rows also terminate the current document.

    Args:
        split: dataset split name ("train", "validation", "test").
        max_docs: maximum number of documents to return.

    Returns:
        List of raw-text documents, each longer than 100 characters.
    """
    print(f" Loading wikitext-103 ({split})...")
    ds = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)

    documents = []
    current_doc = []

    def _flush():
        # Join accumulated lines into one document; keep only non-trivial ones.
        # (Previously this logic was copy-pasted in three places.)
        if current_doc:
            full = " ".join(current_doc)
            if len(full) > 100:
                documents.append(full)
            current_doc.clear()

    for row in ds:
        text = row.get("text", "").strip()
        if not text:
            _flush()
            continue
        # Top-level heading starts a new article.
        if text.startswith("= ") and not text.startswith("= = "):
            _flush()
        current_doc.append(text)
    _flush()  # don't drop the trailing document

    print(f" {len(documents)} documents")
    return documents[:max_docs]
122
+
123
+
124
def build_chains_with_text(raw_docs, bert_tokenizer):
    """Build student segment chains AND track raw text for teacher tokenization.

    Whole documents are greedily packed (SEP-joined) until a chain reaches
    roughly target_chain_segments * stride tokens, then split into
    overlapping fixed-length segments for the student. The raw text of the
    same documents is kept so the teachers can tokenize it independently.

    Returns:
        (ids, masks, n_reals, texts): ids/masks are
        (n_chains, max_segments, segment_length) int32/int8 tensors, n_reals
        is the per-chain count of real (non-padding) segments, texts is the
        list of raw chain strings.
    """
    stride = TCFG.segment_length - TCFG.segment_overlap
    sep_id = bert_tokenizer.sep_token_id

    all_ids, all_masks, all_n_reals, all_texts = [], [], [], []
    doc_idx = 0

    while doc_idx < len(raw_docs):
        target_tokens = TCFG.target_chain_segments * stride
        current_ids = []
        chain_docs = []

        # Greedily pack whole documents (SEP-separated) into one chain.
        while len(current_ids) < target_tokens and doc_idx < len(raw_docs):
            if current_ids:
                current_ids.append(sep_id)
            ids = bert_tokenizer.encode(raw_docs[doc_idx], add_special_tokens=False)
            if len(ids) > 50:  # skip near-empty documents
                current_ids.extend(ids)
                chain_docs.append(doc_idx)
            doc_idx += 1

        # Chain too short to produce min_segments segments — drop it.
        if len(current_ids) < TCFG.min_segments * stride:
            continue

        # Build segments
        seg_ids_list, seg_masks_list = [], []
        pos = 0
        while pos < len(current_ids) and len(seg_ids_list) < TCFG.max_segments:
            end = min(pos + TCFG.segment_length, len(current_ids))
            seg = current_ids[pos:end]
            pad = TCFG.segment_length - len(seg)
            if pad > 0:
                # Final short segment: right-pad ids with 0, mask marks real tokens.
                ids_t = torch.tensor(seg + [0] * pad, dtype=torch.int32)
                mask_t = torch.cat([torch.ones(len(seg), dtype=torch.int8),
                                    torch.zeros(pad, dtype=torch.int8)])
            else:
                ids_t = torch.tensor(seg[:TCFG.segment_length], dtype=torch.int32)
                mask_t = torch.ones(TCFG.segment_length, dtype=torch.int8)
            seg_ids_list.append(ids_t)
            seg_masks_list.append(mask_t)
            if end >= len(current_ids):
                break
            pos += stride  # consecutive segments overlap by segment_overlap tokens

        n_real = len(seg_ids_list)
        if n_real < TCFG.min_segments:
            continue
        # Zero-pad the chain up to max_segments so all chains stack uniformly.
        while len(seg_ids_list) < TCFG.max_segments:
            seg_ids_list.append(torch.zeros(TCFG.segment_length, dtype=torch.int32))
            seg_masks_list.append(torch.zeros(TCFG.segment_length, dtype=torch.int8))

        all_ids.append(torch.stack(seg_ids_list))
        all_masks.append(torch.stack(seg_masks_list))
        all_n_reals.append(n_real)
        # Raw text for teachers
        all_texts.append(" ".join(raw_docs[i] for i in chain_docs))

    # NOTE(review): min()/max()/torch.stack below raise if no chain survived
    # the filters (empty all_n_reals) — confirm inputs always yield >= 1 chain.
    print(f" {len(all_n_reals)} chains, segs: "
          f"min={min(all_n_reals)}, max={max(all_n_reals)}, "
          f"mean={np.mean(all_n_reals):.1f}")
    return (torch.stack(all_ids), torch.stack(all_masks),
            torch.tensor(all_n_reals, dtype=torch.long), all_texts)
187
+
188
+
189
class ChainDataset(Dataset):
    """Dataset of pre-segmented student chains plus their raw text for teachers."""

    def __init__(self, ids, masks, n_reals, texts):
        self.ids = ids          # (n_chains, max_segments, segment_length)
        self.masks = masks      # same shape as ids
        self.n_reals = n_reals  # (n_chains,) count of real segments
        self.texts = texts      # list[str], raw chain text

    def __len__(self):
        return len(self.n_reals)

    def __getitem__(self, i):
        # One sample = (segment ids, segment masks, real-segment count, raw text).
        return self.ids[i], self.masks[i], self.n_reals[i], self.texts[i]
199
+
200
+
201
def chain_collate(batch):
    """Collate dataset samples into batched tensors plus the raw-text list."""
    id_list, mask_list, n_list, text_list = [], [], [], []
    for seg_ids, seg_mask, n_real, text in batch:
        id_list.append(seg_ids)
        mask_list.append(seg_mask)
        n_list.append(n_real)
        text_list.append(text)
    return (torch.stack(id_list), torch.stack(mask_list),
            torch.tensor(n_list, dtype=torch.long), text_list)
205
+
206
+
207
+ # ══════════════════════════════════════════════════════════════════
208
+ # GEOMETRIC UTILITIES
209
+ # ══════════════════════════════════════════════════════════════════
210
+
211
def cayley_menger_vol2(pts):
    """Squared simplex volume via the Cayley-Menger determinant.

    Args:
        pts: (B, V, D) batch of V points each; computed in float32.

    Returns:
        (B,) squared volumes of the (V-1)-simplices spanned by the points.
    """
    # Determinants are numerically fragile under mixed precision — force fp32.
    with torch.amp.autocast("cuda", enabled=False):
        pts = pts.float()
        # Pairwise squared distances: (B, V, V).
        delta = pts.unsqueeze(-2) - pts.unsqueeze(-3)
        sq_dist = (delta * delta).sum(dim=-1)
        batch, n_pts, _ = sq_dist.shape
        # Bordered matrix: zero corner, ones along first row/col, distances inside.
        cm = torch.zeros(batch, n_pts + 1, n_pts + 1,
                         device=sq_dist.device, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = sq_dist
        # vol^2 = (-1)^V / (2^(V-1) * ((V-1)!)^2) * det(CM)
        sign = (-1.0) ** n_pts
        fact = math.factorial(n_pts - 1)
        coeff = sign / ((2.0 ** (n_pts - 1)) * fact * fact)
        return coeff * torch.linalg.det(cm)
221
+
222
+
223
def pentachoron_cv(embeddings, n_samples=16):
    """Coefficient of variation of random pentachoron (5-vertex simplex) volumes.

    Draws n_samples random 5-point subsets of the batch, measures each
    simplex volume, and returns std/mean. Returns 0 for batches under 5.
    """
    count = embeddings.shape[0]
    if count < 5:
        return torch.tensor(0.0, device=embeddings.device)
    volumes = []
    for _ in range(n_samples):
        choice = torch.randperm(count, device=embeddings.device)[:5]
        vol_sq = cayley_menger_vol2(embeddings[choice].unsqueeze(0))
        # Clamp tiny negatives from determinant round-off before the sqrt.
        volumes.append(torch.sqrt(F.relu(vol_sq[0]) + 1e-12))
    vols = torch.stack(volumes)
    return vols.std() / (vols.mean() + 1e-8)
234
+
235
+
236
+ # ══════════════════════════════════════════════════════════════════
237
+ # TEACHER UTILITIES
238
+ # ══════════════════════════════════════════════════════════════════
239
+
240
def mean_pool(hidden_states, attention_mask):
    """Masked mean of hidden states over the sequence dimension.

    Args:
        hidden_states: (B, T, H) token embeddings.
        attention_mask: (B, T) with 1 for real tokens, 0 for padding.

    Returns:
        (B, H) average over real tokens only.
    """
    weights = attention_mask.unsqueeze(-1).float()
    summed = (hidden_states * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1)  # avoid divide-by-zero on all-pad rows
    return summed / counts
243
+
244
+
245
@torch.no_grad()
def teacher_forward_modern(model, tokenizer, texts, device, max_len):
    """ModernBERT teacher pass: standard full attention, masked mean-pool output."""
    encoded = tokenizer(texts, max_length=max_len, padding=True,
                        truncation=True, return_tensors="pt").to(device)
    hidden = model(**encoded).last_hidden_state
    return mean_pool(hidden, encoded.attention_mask)
252
+
253
+
254
@torch.no_grad()
def teacher_forward_longformer(model, tokenizer, texts, device, max_len):
    """Longformer teacher pass: CLS gets global attention; returns its embedding."""
    encoded = tokenizer(texts, max_length=max_len, padding=True,
                        truncation=True, return_tensors="pt").to(device)
    # Only position 0 (the CLS token) attends globally.
    global_mask = torch.zeros_like(encoded.input_ids)
    global_mask[:, 0] = 1
    result = model(input_ids=encoded.input_ids,
                   attention_mask=encoded.attention_mask,
                   global_attention_mask=global_mask)
    return result.last_hidden_state[:, 0]  # CLS with global attention
266
+
267
+
268
+ # ══════════════════════════════════════════════════════════════════
269
+ # LOSSES
270
+ # ══════════════════════════════════════════════════════════════════
271
+
272
def distillation_loss(student_emb, teacher_emb, temperature=0.07):
    """InfoNCE distillation: student[i] should be closest to teacher[i] in-batch.

    Args:
        student_emb: (B, H) student outputs (projected into teacher space).
        teacher_emb: (B, H) frozen teacher targets.
        temperature: softmax temperature for the similarity logits.

    Returns:
        (loss, accuracy): cross-entropy over the cosine-similarity matrix,
        plus the fraction of rows whose argmax lands on the diagonal.
    """
    student = F.normalize(student_emb, dim=-1)
    teacher = F.normalize(teacher_emb, dim=-1)
    sims = torch.matmul(student, teacher.T) / temperature
    targets = torch.arange(sims.shape[0], device=sims.device)
    nce = F.cross_entropy(sims, targets)
    with torch.no_grad():
        hits = (sims.argmax(dim=-1) == targets).float()
        acc = hits.mean().item()
    return nce, acc
282
+
283
+
284
def batch_cv_loss(all_anchors, n_reals, cv_target=0.20):
    """Penalize deviation of each chain's anchor-volume CV from cv_target.

    Chains with fewer than 5 real segments are skipped (a pentachoron needs
    five vertices). Returns (mean |cv - target| over valid chains, stats dict
    with the mean raw CV).
    """
    device = all_anchors.device
    loss_sum = torch.tensor(0.0, device=device)
    cv_sum = 0.0
    valid = 0
    for chain_idx in range(all_anchors.shape[0]):
        real = n_reals[chain_idx].item()
        if real < 5:
            continue
        cv = pentachoron_cv(all_anchors[chain_idx, :real], n_samples=16)
        loss_sum = loss_sum + (cv - cv_target).abs()
        cv_sum += cv.item()
        valid += 1
    stats = {"cv_raw": cv_sum / max(valid, 1)}
    if valid == 0:
        return loss_sum, stats  # zero loss when nothing was long enough
    return loss_sum / valid, stats
301
+
302
+
303
+ # ══════════════════════════════════════════════════════════════════
304
+ # PARAM GROUPS
305
+ # ══════════════════════════════════════════════════════════════════
306
+
307
def make_param_groups(model):
    """Split trainable parameters into bank / proj / output optimizer groups.

    Each group gets its own learning rate from TCFG; weight decay is shared.
    Frozen (requires_grad=False) parameters are excluded.
    """
    bank_prefixes = ("bank.depth_compressor", "bank.temporal_proj",
                     "bank.cross_attn", "bank.cross_norms",
                     "bank.cross_ffns", "bank.ffn_norms")
    proj_prefixes = ("proj_modern", "proj_longformer")

    buckets = {"bank": [], "proj": [], "output": []}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # str.startswith accepts a tuple of prefixes.
        if name.startswith(proj_prefixes):
            buckets["proj"].append(param)
        elif name.startswith(bank_prefixes):
            buckets["bank"].append(param)
        else:
            buckets["output"].append(param)

    groups = [
        {"params": buckets["bank"], "lr": TCFG.lr_bank, "name": "bank"},
        {"params": buckets["proj"], "lr": TCFG.lr_proj, "name": "proj"},
        {"params": buckets["output"], "lr": TCFG.lr_output, "name": "output"},
    ]
    for g in groups:
        g["weight_decay"] = TCFG.weight_decay
        n = sum(p.numel() for p in g["params"])
        print(f" {g['name']:8s}: {n:>10,} params @ lr={g['lr']}")
    return groups
334
+
335
+
336
+ # ══════════════════════════════════════════════════════════════════
337
+ # STATIC PROCRUSTES PRE-ALIGNMENT
338
+ # ══════════════════════════════════════════════════════════════════
339
+
340
@torch.no_grad()
def compute_and_init_procrustes(student_model, modern_model, modern_tok,
                                long_model, long_tok, bert_tok,
                                texts, device):
    """
    Feed N texts through BERT (CLS) and each teacher (mean-pool/CLS).
    Compute Procrustes rotation, initialize projectors.

    NOTE(review): calls `compute_static_procrustes`, which is not defined in
    this file — presumably provided by the architecture cell
    (deep_bert_v3.py). Verify it is in scope before running.
    """
    print(f"\n Computing static Procrustes on {len(texts)} texts...")
    student_embs, modern_embs, long_embs = [], [], []

    # Mini-batches of 16 to bound GPU memory during embedding collection.
    for i in range(0, len(texts), 16):
        batch = texts[i:i+16]

        # Student: just BERT CLS (no memory, single segment)
        bert_inputs = bert_tok(batch, max_length=480, padding=True,
                               truncation=True, return_tensors="pt").to(device)
        bert_out = student_model.bert(
            input_ids=bert_inputs.input_ids,
            attention_mask=bert_inputs.attention_mask,
            return_dict=True)
        student_embs.append(bert_out.last_hidden_state[:, 0].cpu())

        # ModernBERT: masked mean-pool embedding.
        modern_embs.append(
            teacher_forward_modern(modern_model, modern_tok, batch,
                                   device, TCFG.modern_max_len).cpu())

        # Longformer: global-attention CLS embedding.
        long_embs.append(
            teacher_forward_longformer(long_model, long_tok, batch,
                                       device, TCFG.longformer_max_len).cpu())

    student_all = torch.cat(student_embs)
    modern_all = torch.cat(modern_embs)
    long_all = torch.cat(long_embs)

    # Procrustes: student -> ModernBERT rotation + means, seeded into projector.
    print(" ModernBERT alignment:")
    R_m, mu_s_m, mu_t_m = compute_static_procrustes(student_all, modern_all)
    student_model.proj_modern.init_from_procrustes(R_m, mu_s_m, mu_t_m)

    # Procrustes: student -> Longformer.
    print(" Longformer alignment:")
    R_l, mu_s_l, mu_t_l = compute_static_procrustes(student_all, long_all)
    student_model.proj_longformer.init_from_procrustes(R_l, mu_s_l, mu_t_l)
386
+
387
+
388
+ # ══════════════════════════════════════════════════════════════════
389
+ # TRAINING
390
+ # ══════════════════════════════════════════════════════════════════
391
+
392
def train(model, modern_model, modern_tok, long_model, long_tok,
          train_loader, val_loader=None):
    """Distillation training loop.

    Per batch: run both frozen teachers on the raw chain text, unroll the
    student over all segments to populate its memory, then pull the student's
    fused output toward each teacher embedding (InfoNCE) and regularize the
    geometry of the live anchors (CV term). Logs to TensorBoard, periodically
    evaluates, and checkpoints the best / per-epoch / final states.
    """
    device = next(model.parameters()).device
    os.makedirs(TCFG.checkpoint_dir, exist_ok=True)
    os.makedirs(TCFG.tensorboard_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=TCFG.tensorboard_dir)

    param_groups = make_param_groups(model)
    optimizer = torch.optim.AdamW(param_groups)
    # Flat list of trainable params for gradient clipping.
    all_params = [p for g in param_groups for p in g["params"]]

    # Linear warmup followed by cosine decay down to min_lr.
    total_steps = len(train_loader) * TCFG.epochs
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                           total_iters=TCFG.warmup_steps),
         torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(total_steps, 1),
                                                    eta_min=TCFG.min_lr)],
        milestones=[TCFG.warmup_steps])

    scaler = torch.amp.GradScaler()
    global_step = 0
    best_val_loss = float("inf")

    print(f"\n Training: {sum(p.numel() for p in all_params):,} params")
    print(f" {len(train_loader)} batches/epoch Γ— {TCFG.batch_size} chains")
    print(f" Losses: modern({TCFG.modern_weight}) + long({TCFG.longformer_weight}) "
          f"+ cv({TCFG.cv_weight})")

    for epoch in range(TCFG.epochs):
        model.train()
        # Running per-epoch sums (divided by n_batches for display/logging).
        losses = {"total": 0, "modern": 0, "longformer": 0, "cv": 0}
        metrics = {"modern_acc": 0, "long_acc": 0, "cv_raw": 0}
        n_batches = 0
        t0 = time.time()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{TCFG.epochs}", unit="batch")

        for student_ids, student_masks, n_reals, raw_texts in pbar:
            B = n_reals.shape[0]

            # -- Teacher forwards (frozen, no grad) --
            with torch.no_grad():
                with torch.amp.autocast("cuda"):
                    modern_cls = teacher_forward_modern(
                        modern_model, modern_tok, raw_texts,
                        device, TCFG.modern_max_len)
                    long_cls = teacher_forward_longformer(
                        long_model, long_tok, raw_texts,
                        device, TCFG.longformer_max_len)

            # -- Student forward (memory system trains) --
            state = model.init_state(B, device)
            all_anchors = torch.zeros(B, TCFG.max_segments, model.config.anchor_dim,
                                      device=device)

            for seg_k in range(TCFG.max_segments):
                # Optional truncated BPTT: cut the graph every tbptt_segments.
                if TCFG.tbptt_segments > 0 and seg_k > 0 and seg_k % TCFG.tbptt_segments == 0:
                    state = DeepBertV3.detach_state(state)
                    all_anchors = all_anchors.detach()

                ids = student_ids[:, seg_k].to(device).long()
                mask = student_masks[:, seg_k].to(device).long()

                with torch.amp.autocast("cuda"):
                    outputs, state = model(ids, mask, state)
                all_anchors[:, seg_k] = outputs["live_anchor"]

            # Student output: fused (CLS + memory delta) from last real segment
            student_cls = outputs["memory_output"]

            # -- Project into teacher spaces --
            with torch.amp.autocast("cuda"):
                proj_m = model.proj_modern(student_cls)
                proj_l = model.proj_longformer(student_cls)

            # -- Distillation losses --
            l_modern, acc_m = distillation_loss(
                proj_m, modern_cls, TCFG.temperature)
            l_long, acc_l = distillation_loss(
                proj_l, long_cls, TCFG.temperature)

            # -- CV on live anchors --
            l_cv, cv_stats = batch_cv_loss(
                all_anchors, n_reals.to(device), model.config.cv_target)

            loss = (TCFG.modern_weight * l_modern +
                    TCFG.longformer_weight * l_long +
                    TCFG.cv_weight * l_cv)

            # AMP-safe step: unscale before clipping so clip sees true grads.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(all_params, TCFG.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            global_step += 1

            losses["total"] += loss.item()
            losses["modern"] += l_modern.item()
            losses["longformer"] += l_long.item()
            losses["cv"] += l_cv.item()
            metrics["modern_acc"] += acc_m
            metrics["long_acc"] += acc_l
            metrics["cv_raw"] += cv_stats.get("cv_raw", 0)
            n_batches += 1

            # Show running epoch averages in the progress bar.
            n = max(n_batches, 1)
            pbar.set_postfix(
                loss=f"{losses['total']/n:.3f}",
                m_acc=f"{metrics['modern_acc']/n:.3f}",
                l_acc=f"{metrics['long_acc']/n:.3f}",
                cv=f"{metrics['cv_raw']/n:.3f}")

            if global_step % TCFG.log_every == 0:
                writer.add_scalar("train/loss", losses["total"] / n, global_step)
                writer.add_scalar("train/modern_acc", metrics["modern_acc"] / n, global_step)
                writer.add_scalar("train/long_acc", metrics["long_acc"] / n, global_step)
                writer.add_scalar("train/cv_raw", metrics["cv_raw"] / n, global_step)
                for k in ["modern", "longformer", "cv"]:
                    writer.add_scalar(f"train/{k}_loss", losses[k] / n, global_step)

            if val_loader and global_step % TCFG.eval_every == 0:
                vl = evaluate(model, modern_model, modern_tok,
                              long_model, long_tok, val_loader, writer, global_step)
                if vl < best_val_loss:
                    best_val_loss = vl
                    save_checkpoint(model, optimizer, epoch, global_step,
                                    os.path.join(TCFG.checkpoint_dir, "best"))
                model.train()  # evaluate() put the model in eval mode

        pbar.close()
        elapsed = time.time() - t0
        n = max(n_batches, 1)
        print(f"\n Epoch {epoch+1}: {n_batches * TCFG.batch_size / elapsed:.1f} chains/s "
              f"loss={losses['total']/n:.4f} "
              f"m_acc={metrics['modern_acc']/n:.3f} "
              f"l_acc={metrics['long_acc']/n:.3f} "
              f"cv={metrics['cv_raw']/n:.3f}")

        if TCFG.save_every_epoch:
            save_checkpoint(model, optimizer, epoch + 1, global_step,
                            os.path.join(TCFG.checkpoint_dir, f"epoch_{epoch+1:03d}"))

    save_checkpoint(model, optimizer, TCFG.epochs, global_step,
                    os.path.join(TCFG.checkpoint_dir, "final"))
    writer.flush()
    writer.close()
541
+
542
+
543
+ # ══════════════════════════════════════════════════════════════════
544
+ # EVAL
545
+ # ══════════════════════════════════════════════════════════════════
546
+
547
@torch.no_grad()
def evaluate(model, modern_model, modern_tok, long_model, long_tok,
             val_loader, writer=None, global_step=0):
    """Validation pass: same weighted losses as training, averaged over batches.

    Fix: the CV regularizer now uses the model's own `config.cv_target`
    (exactly as the training loop does) instead of a hard-coded 0.20, so the
    validation loss is directly comparable to the training loss.

    Args:
        model: the student (switched to eval mode here; caller restores train).
        writer: optional SummaryWriter for val/* scalars.
        global_step: step index used for TensorBoard logging.

    Returns:
        Mean weighted validation loss (float).
    """
    model.eval()
    device = next(model.parameters()).device
    total = {"loss": 0, "modern_acc": 0, "long_acc": 0, "cv_raw": 0}
    n = 0

    for student_ids, student_masks, n_reals, raw_texts in tqdm(val_loader, desc="Eval", leave=False):
        B = n_reals.shape[0]

        # Frozen teacher targets for this batch.
        with torch.amp.autocast("cuda"):
            modern_cls = teacher_forward_modern(
                modern_model, modern_tok, raw_texts, device, TCFG.modern_max_len)
            long_cls = teacher_forward_longformer(
                long_model, long_tok, raw_texts, device, TCFG.longformer_max_len)

        # Unroll the student over all segments, collecting live anchors.
        state = model.init_state(B, device)
        all_anc = torch.zeros(B, TCFG.max_segments, model.config.anchor_dim, device=device)
        for seg_k in range(TCFG.max_segments):
            with torch.amp.autocast("cuda"):
                out, state = model(student_ids[:, seg_k].to(device).long(),
                                   student_masks[:, seg_k].to(device).long(), state)
            all_anc[:, seg_k] = out["live_anchor"]

        with torch.amp.autocast("cuda"):
            student_cls = out["memory_output"]
            l_m, acc_m = distillation_loss(
                model.proj_modern(student_cls), modern_cls, TCFG.temperature)
            l_l, acc_l = distillation_loss(
                model.proj_longformer(student_cls), long_cls, TCFG.temperature)
            # Same CV target as training (previously hard-coded 0.20).
            l_cv, cv_s = batch_cv_loss(all_anc, n_reals.to(device),
                                       model.config.cv_target)

        total["loss"] += (TCFG.modern_weight * l_m.item() +
                          TCFG.longformer_weight * l_l.item() +
                          TCFG.cv_weight * l_cv.item())
        total["modern_acc"] += acc_m
        total["long_acc"] += acc_l
        total["cv_raw"] += cv_s.get("cv_raw", 0)
        n += 1

    d = max(n, 1)
    print(f" Val: loss={total['loss']/d:.4f} "
          f"m_acc={total['modern_acc']/d:.3f} "
          f"l_acc={total['long_acc']/d:.3f} "
          f"cv={total['cv_raw']/d:.3f}")
    if writer:
        for k, v in total.items():
            writer.add_scalar(f"val/{k}", v / d, global_step)
    return total["loss"] / d
597
+
598
+
599
+ # ══════════════════════════════════════════════════════════════════
600
+ # CHECKPOINT
601
+ # ══════════════════════════════════════════════════════════════════
602
+
603
def save_checkpoint(model, optimizer, epoch, global_step, path):
    """Persist trainable weights + buffers (safetensors), optimizer state, and configs."""
    import dataclasses

    os.makedirs(path, exist_ok=True)

    # Only trainable parameters (the frozen backbone is excluded) plus all buffers.
    tensors = {name: param.data.contiguous().cpu()
               for name, param in model.named_parameters()
               if param.requires_grad}
    for buf_name, buf in model.named_buffers():
        tensors[f"buffer.{buf_name}"] = buf.contiguous().cpu()
    safetensors_save(tensors, os.path.join(path, "memory_system.safetensors"))

    # Optimizer + progress counters for resuming.
    torch.save(
        {"optimizer": optimizer.state_dict(),
         "epoch": epoch,
         "global_step": global_step},
        os.path.join(path, "training_state.pt"))

    # Human-readable record of both model and training configs.
    with open(os.path.join(path, "config.json"), "w") as f:
        json.dump({"model": dataclasses.asdict(model.config),
                   "training": dataclasses.asdict(TCFG)},
                  f, indent=2, default=str)
618
+
619
+
620
+ # ══════════════════════════════════════════════════════════════════
621
+ # MAIN
622
+ # ══════════════════════════════════════════════════════════════════
623
+
624
def main():
    """End-to-end pipeline: load models, build data, pre-align, train, evaluate.

    NOTE(review): `DeepBertV3Config` and `DeepBertV3` are not defined in this
    file — per the header they come from the architecture cell
    (deep_bert_v3.py), which must run first.
    """
    print("=" * 70)
    print("DEEP BERT v3 β€” TEACHER-DISTILLED GEOMETRIC MEMORY")
    print("=" * 70)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")
    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name()}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # -- Load student model --
    print(f"\n{'='*70}")
    print("LOADING MODELS")
    print(f"{'='*70}")

    config = DeepBertV3Config()
    model = DeepBertV3.from_pretrained(config).to(device)
    bert_tokenizer = BertTokenizer.from_pretrained(config.bert_model)

    # -- Load teachers (frozen, fp16) --
    print(f"\n Loading ModernBERT-large...")
    modern_model = AutoModel.from_pretrained(TCFG.modern_bert_model,
                                             torch_dtype=torch.float16).to(device)
    modern_model.eval()
    for p in modern_model.parameters():
        p.requires_grad = False
    modern_tok = AutoTokenizer.from_pretrained(TCFG.modern_bert_model)
    print(f" {sum(p.numel() for p in modern_model.parameters()):,} params (frozen)")

    print(f"\n Loading Longformer-large...")
    long_model = AutoModel.from_pretrained(TCFG.longformer_model,
                                           torch_dtype=torch.float16).to(device)
    long_model.eval()
    for p in long_model.parameters():
        p.requires_grad = False
    long_tok = AutoTokenizer.from_pretrained(TCFG.longformer_model)
    print(f" {sum(p.numel() for p in long_model.parameters()):,} params (frozen)")

    # -- Data: chains of student segments + raw text for teachers --
    print(f"\n{'='*70}")
    print("DATA")
    print(f"{'='*70}")

    train_docs = load_wikitext_documents("train", TCFG.max_documents)
    train_ids, train_masks, train_nr, train_texts = build_chains_with_text(
        train_docs, bert_tokenizer)

    val_docs = load_wikitext_documents("validation", TCFG.max_val_documents)
    val_ids, val_masks, val_nr, val_texts = build_chains_with_text(
        val_docs, bert_tokenizer)

    train_ds = ChainDataset(train_ids, train_masks, train_nr, train_texts)
    val_ds = ChainDataset(val_ids, val_masks, val_nr, val_texts)

    train_loader = DataLoader(train_ds, batch_size=TCFG.batch_size, shuffle=True,
                              num_workers=0, pin_memory=True, drop_last=True,
                              collate_fn=chain_collate)
    val_loader = DataLoader(val_ds, batch_size=TCFG.batch_size, shuffle=False,
                            num_workers=0, pin_memory=True,
                            collate_fn=chain_collate)

    print(f"\n Train: {len(train_ds)} chains β†’ {len(train_loader)} batches")
    print(f" Val: {len(val_ds)} chains β†’ {len(val_loader)} batches")

    # -- Static Procrustes pre-alignment of student projectors --
    print(f"\n{'='*70}")
    print("PROCRUSTES PRE-ALIGNMENT")
    print(f"{'='*70}")

    # Use first N train docs for alignment
    align_texts = train_texts[:TCFG.procrustes_n_samples]
    compute_and_init_procrustes(
        model, modern_model, modern_tok, long_model, long_tok,
        bert_tokenizer, align_texts, device)

    # -- Train --
    print(f"\n{'='*70}")
    print("TRAINING")
    print(f"{'='*70}")

    train(model, modern_model, modern_tok, long_model, long_tok,
          train_loader, val_loader)

    # -- Final eval --
    print(f"\n{'='*70}")
    print("FINAL EVALUATION")
    print(f"{'='*70}")

    evaluate(model, modern_model, modern_tok, long_model, long_tok, val_loader)
    print("\nDone.")
715
+
716
+
717
if __name__ == "__main__":
    # Script entry point: run the full distillation pipeline.
    main()