AbstractPhil commited on
Commit
38cce53
Β·
verified Β·
1 Parent(s): 5ae1c5f

Create trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +755 -0
trainer.py ADDED
@@ -0,0 +1,755 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # TRAINER: MEMORY-CLIP-SEQ β€” Sequence Reconstruction
3
+ #
4
+ # Extends the v2 trainer with:
5
+ # - Teacher full sequence capture (ModernBERT last_hidden_state)
6
+ # - Sequence reconstruction loss (reconstructed 77 vs teacher projected 77)
7
+ # - Two-phase training:
8
+ # Phase 1: freeze v1 memory weights, train only seq head
9
+ # Phase 2: unfreeze all, joint fine-tune
10
+ # - v1 checkpoint loading at startup
11
+ #
12
+ # Core training loop (InfoNCE + Procrustes + CV) UNCHANGED from v2.
13
+ # Sequence loss is ADDED alongside existing losses.
14
+ # ============================================================================
15
+
16
+ import gc
17
+ import math
18
+ import os
19
+ import json
20
+ import time
21
+ from dataclasses import dataclass, asdict
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ from torch.utils.tensorboard import SummaryWriter
28
+ from tqdm import tqdm
29
+ from safetensors.torch import save_file as safetensors_save, load_file
30
+
31
+
32
+ # ══════════════════════════════════════════════════════════════════
33
+ # CONFIG
34
+ # ══════════════════════════════════════════════════════════════════
35
+
36
+ @dataclass
37
+ class TrainSeqConfig:
38
+ # Data
39
+ max_train_samples: int = 50000
40
+ max_val_samples: int = 2000
41
+ min_caption_length: int = 100
42
+
43
+ # Training β€” phase 1 (seq head only)
44
+ phase1_epochs: int = 5
45
+ phase1_lr_seq: float = 2e-3
46
+ phase1_lr_proj: float = 1e-3
47
+
48
+ # Training β€” phase 2 (joint fine-tune)
49
+ phase2_epochs: int = 5
50
+ phase2_lr_bank: float = 5e-4 # reduced from v2's 2e-3
51
+ phase2_lr_output: float = 2e-4 # reduced
52
+ phase2_lr_proj: float = 5e-4
53
+ phase2_lr_seq: float = 1e-3
54
+
55
+ # Shared
56
+ batch_size: int = 64
57
+ min_lr: float = 1e-6
58
+ weight_decay: float = 0.01
59
+ grad_clip: float = 1.0
60
+ warmup_steps: int = 200
61
+
62
+ # Loss weights β€” existing (unchanged from v2)
63
+ modern_weight: float = 1.0
64
+ procrustes_weight: float = 0.3
65
+ cv_weight: float = 0.05
66
+ temperature: float = 0.07
67
+
68
+ # Loss weights β€” sequence (NEW)
69
+ sequence_weight: float = 1.0 # MSE between reconstructed and teacher seq
70
+ sequence_cosine_weight: float = 0.5 # per-position cosine similarity
71
+
72
+ # Teacher
73
+ modern_max_len: int = 4096
74
+ procrustes_n_samples: int = 300
75
+
76
+ # v1 checkpoint β€” local path or HuggingFace URL
77
+ v1_checkpoint: str = ""
78
+ v1_repo_id: str = "AbstractPhil/geolip-clip-vit-large-patch14-ctx576"
79
+ v1_filename: str = "model.safetensors"
80
+
81
+ # Logging
82
+ checkpoint_dir: str = "/home/claude/memory_clip_seq_checkpoints"
83
+ tensorboard_dir: str = "/home/claude/memory_clip_seq_tb"
84
+ metrics_file: str = "/home/claude/memory_clip_seq_checkpoints/metrics.json"
85
+ log_every: int = 20
86
+ eval_every: int = 200
87
+
88
+
89
+ TCFG = TrainSeqConfig()
90
+
91
+
92
+ # ══════════════════════════════════════════════════════════════════
93
+ # GEOMETRIC UTILITIES β€” IDENTICAL to v2
94
+ # ══════════════════════════════════════════════════════════════════
95
+
96
+ def cayley_menger_vol2(pts):
97
+ with torch.amp.autocast("cuda", enabled=False):
98
+ pts = pts.float()
99
+ diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
100
+ d2 = (diff * diff).sum(-1)
101
+ B, V, _ = d2.shape
102
+ cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)
103
+ cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
104
+ s = (-1.0)**V; f = math.factorial(V-1)
105
+ return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)
106
+
107
+ def pentachoron_cv(embeddings, n_samples=16):
108
+ B = embeddings.shape[0]
109
+ if B < 5:
110
+ return torch.tensor(0.0, device=embeddings.device)
111
+ vols = []
112
+ for _ in range(n_samples):
113
+ idx = torch.randperm(B, device=embeddings.device)[:5]
114
+ v2 = cayley_menger_vol2(embeddings[idx].unsqueeze(0))
115
+ vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
116
+ stacked = torch.stack(vols)
117
+ return stacked.std() / (stacked.mean() + 1e-8)
118
+
119
+ def procrustes_alignment_loss(emb_a, emb_b):
120
+ with torch.amp.autocast("cuda", enabled=False):
121
+ A = F.normalize(emb_a.float(), dim=-1)
122
+ B_e = F.normalize(emb_b.float(), dim=-1)
123
+ A = A - A.mean(0, keepdim=True)
124
+ B_e = B_e - B_e.mean(0, keepdim=True)
125
+ S = torch.linalg.svdvals(A.T @ B_e)
126
+ N, D = A.shape
127
+ return 1.0 - S.sum() / (math.sqrt(N) * D)
128
+
129
+
130
+ # ════════════════════════════════════════════���═════════════════════
131
+ # LOSSES β€” v2 existing + NEW sequence loss
132
+ # ══════════════════════════════════════════════════════════════════
133
+
134
+ def infonce_loss(emb_a, emb_b, temperature=0.07):
135
+ a = F.normalize(emb_a, dim=-1)
136
+ b = F.normalize(emb_b, dim=-1)
137
+ logits = (a @ b.T) / temperature
138
+ labels = torch.arange(logits.shape[0], device=logits.device)
139
+ loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
140
+ with torch.no_grad():
141
+ acc = (logits.argmax(-1) == labels).float().mean().item()
142
+ top5 = logits.topk(min(5, logits.shape[1]), dim=-1).indices
143
+ acc5 = (top5 == labels.unsqueeze(-1)).any(-1).float().mean().item()
144
+ return loss, acc, acc5
145
+
146
+
147
+ def batch_cv_loss(all_anchors, n_reals, cv_target=0.20):
148
+ device = all_anchors.device
149
+ B = all_anchors.shape[0]
150
+ total_loss = torch.tensor(0.0, device=device)
151
+ total_cv = 0.0; n_valid = 0
152
+ per_sample_cv = []
153
+ for b in range(B):
154
+ n = n_reals[b].item() if isinstance(n_reals[b], torch.Tensor) else n_reals[b]
155
+ if n < 5:
156
+ continue
157
+ cv_val = pentachoron_cv(all_anchors[b, :n], n_samples=16)
158
+ total_loss = total_loss + (cv_val - cv_target).abs()
159
+ total_cv += cv_val.item()
160
+ per_sample_cv.append(cv_val.item())
161
+ n_valid += 1
162
+ stats = {
163
+ "cv_raw": total_cv / max(n_valid, 1),
164
+ "cv_std": float(np.std(per_sample_cv)) if per_sample_cv else 0.0,
165
+ "cv_n_valid": n_valid,
166
+ }
167
+ return total_loss / max(n_valid, 1), stats
168
+
169
+
170
+ def sequence_reconstruction_loss(pred_seq, target_seq):
171
+ """
172
+ pred_seq: (B, 77, 768) β€” reconstructed sequence
173
+ target_seq: (B, 77, 768) β€” teacher projected sequence
174
+
175
+ Returns:
176
+ mse_loss: mean squared error
177
+ cos_loss: 1 - mean per-position cosine similarity
178
+ mean_cos: scalar metric (not differentiable)
179
+ """
180
+ mse = F.mse_loss(pred_seq, target_seq)
181
+
182
+ # Per-position cosine similarity
183
+ pred_norm = F.normalize(pred_seq, dim=-1)
184
+ tgt_norm = F.normalize(target_seq, dim=-1)
185
+ cos_sim = (pred_norm * tgt_norm).sum(-1) # (B, 77)
186
+ cos_loss = 1.0 - cos_sim.mean()
187
+
188
+ with torch.no_grad():
189
+ mean_cos = cos_sim.mean().item()
190
+
191
+ return mse, cos_loss, mean_cos
192
+
193
+
194
+ # ══════════════════════════════════════════════════════════════════
195
+ # TEACHER β€” returns BOTH pooled AND full sequence
196
+ # ══════════════════════════════════════════════════════════════════
197
+
198
+ @torch.no_grad()
199
+ def teacher_forward(model, tokenizer, texts, device, max_len):
200
+ """Returns pooled (B, 1024) from ModernBERT. Sequence target comes from CLIP."""
201
+ inputs = tokenizer(texts, max_length=max_len, padding=True,
202
+ truncation=True, return_tensors="pt").to(device)
203
+ out = model(**inputs)
204
+ mask = inputs.attention_mask.unsqueeze(-1).float()
205
+ pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1)
206
+ return pooled
207
+
208
+
209
+ # ══════════════════════════════════════════════════════════════════
210
+ # PARAM GROUPS
211
+ # ══════════════════════════════════════════════════════════════════
212
+
213
+ def make_param_groups_phase1(model):
214
+ """Phase 1: Only train sequence head + teacher seq projector."""
215
+ seq_params = []
216
+ for name, param in model.named_parameters():
217
+ param.requires_grad = False # freeze everything first
218
+ for name, param in model.named_parameters():
219
+ if "sequence_reconstructor" in name:
220
+ param.requires_grad = True
221
+ seq_params.append(param)
222
+ # Also keep proj_modern trainable (it's the pooled projector)
223
+ for name, param in model.named_parameters():
224
+ if "proj_modern" in name:
225
+ param.requires_grad = True
226
+
227
+ proj_params = [p for n, p in model.named_parameters()
228
+ if "proj_modern" in n and p.requires_grad]
229
+
230
+ groups = [
231
+ {"params": seq_params, "lr": TCFG.phase1_lr_seq, "name": "seq_head",
232
+ "weight_decay": TCFG.weight_decay},
233
+ {"params": proj_params, "lr": TCFG.phase1_lr_proj, "name": "proj",
234
+ "weight_decay": TCFG.weight_decay},
235
+ ]
236
+ for g in groups:
237
+ n = sum(p.numel() for p in g["params"])
238
+ print(f" {g['name']:12s}: {n:>10,} params @ lr={g['lr']}")
239
+ return groups
240
+
241
+
242
+ def make_param_groups_phase2(model):
243
+ """Phase 2: Unfreeze everything, differential LRs."""
244
+ # Unfreeze all trainable (non-CLIP) params
245
+ for name, param in model.named_parameters():
246
+ if "clip_text" not in name and "_clip_text" not in name:
247
+ param.requires_grad = True
248
+
249
+ bank_names = {"bank.", "clip_cross_attn", "clip_cross_norms",
250
+ "clip_cross_ffns", "clip_cross_ffn_norms"}
251
+ seq_names = {"sequence_reconstructor"}
252
+ proj_names = {"proj_modern"}
253
+
254
+ bank_p, seq_p, proj_p, output_p = [], [], [], []
255
+ for name, param in model.named_parameters():
256
+ if not param.requires_grad:
257
+ continue
258
+ if any(name.startswith(s) or s in name for s in seq_names):
259
+ seq_p.append(param)
260
+ elif any(name.startswith(s) or s in name for s in proj_names):
261
+ proj_p.append(param)
262
+ elif any(name.startswith(s) or s in name for s in bank_names):
263
+ bank_p.append(param)
264
+ else:
265
+ output_p.append(param)
266
+
267
+ groups = [
268
+ {"params": bank_p, "lr": TCFG.phase2_lr_bank, "name": "bank",
269
+ "weight_decay": TCFG.weight_decay},
270
+ {"params": seq_p, "lr": TCFG.phase2_lr_seq, "name": "seq_head",
271
+ "weight_decay": TCFG.weight_decay},
272
+ {"params": proj_p, "lr": TCFG.phase2_lr_proj, "name": "proj",
273
+ "weight_decay": TCFG.weight_decay},
274
+ {"params": output_p, "lr": TCFG.phase2_lr_output, "name": "output",
275
+ "weight_decay": TCFG.weight_decay},
276
+ ]
277
+ for g in groups:
278
+ n = sum(p.numel() for p in g["params"])
279
+ print(f" {g['name']:12s}: {n:>10,} params @ lr={g['lr']}")
280
+ return groups
281
+
282
+
283
+ # ══════════════════════════════════════════════════════════════════
284
+ # PROCRUSTES INIT β€” IDENTICAL to v2
285
+ # ══════════════════════════════════════════════════════════════════
286
+
287
+ @torch.no_grad()
288
+ def compute_and_init_procrustes(student_model, modern_model, modern_tok,
289
+ captions, device):
290
+ print(f"\n Computing static Procrustes on {len(captions)} captions...")
291
+ student_embs, modern_embs = [], []
292
+ clip_tok = student_model.clip_tokenizer
293
+ for i in range(0, len(captions), 16):
294
+ batch = captions[i:i+16]
295
+ tokens = clip_tok(batch, max_length=77, padding=True,
296
+ truncation=True, return_tensors="pt").to(device)
297
+ clip_out = student_model.clip_text(
298
+ input_ids=tokens.input_ids,
299
+ attention_mask=tokens.attention_mask,
300
+ output_hidden_states=False)
301
+ student_embs.append(clip_out.pooler_output.cpu())
302
+ pooled = teacher_forward(modern_model, modern_tok, batch,
303
+ device, TCFG.modern_max_len)
304
+ modern_embs.append(pooled.cpu())
305
+ student_all = torch.cat(student_embs)
306
+ modern_all = torch.cat(modern_embs)
307
+ print(f" Student: {student_all.shape}, Teacher: {modern_all.shape}")
308
+ X = student_all.float(); Y = modern_all.float()
309
+ mu_x, mu_y = X.mean(0), Y.mean(0)
310
+ Xc, Yc = X - mu_x, Y - mu_y
311
+ if Xc.shape[1] < Yc.shape[1]:
312
+ pad = torch.zeros(Xc.shape[0], Yc.shape[1] - Xc.shape[1])
313
+ Xc = torch.cat([Xc, pad], dim=1)
314
+ mu_x = torch.cat([mu_x, torch.zeros(Yc.shape[1] - mu_x.shape[0])])
315
+ U, S, Vt = torch.linalg.svd(Xc.T @ Yc)
316
+ R = (U @ Vt).T
317
+ cos_before = F.cosine_similarity(Xc, Yc, dim=-1).mean()
318
+ cos_after = F.cosine_similarity((Xc @ R.T), Yc, dim=-1).mean()
319
+ print(f" Procrustes: cos {cos_before:.4f} β†’ {cos_after:.4f}")
320
+ # Init the pooled projector if it has init_from_procrustes
321
+ if hasattr(student_model.proj_modern, 'init_from_procrustes'):
322
+ student_model.proj_modern.init_from_procrustes(R, mu_x, mu_y)
323
+ return {"cos_before": cos_before.item(), "cos_after": cos_after.item()}
324
+
325
+
326
+ # ══════════════════════════════════════════════════════════════════
327
+ # V1 WEIGHT LOADING
328
+ # ══════════════════════════════════════════════════════════════════
329
+
330
+ def load_v1_weights(model, device):
331
+ """Load v1 memory system weights into the expanded seq model.
332
+ Tries local path first, then downloads from HuggingFace."""
333
+ checkpoint_path = TCFG.v1_checkpoint
334
+
335
+ # Try local path first
336
+ if checkpoint_path and os.path.exists(checkpoint_path):
337
+ print(f" Loading v1 weights (local): {checkpoint_path}")
338
+ else:
339
+ # Download from HuggingFace
340
+ from huggingface_hub import hf_hub_download
341
+ print(f" Downloading v1 weights from {TCFG.v1_repo_id}/{TCFG.v1_filename}...")
342
+ checkpoint_path = hf_hub_download(
343
+ repo_id=TCFG.v1_repo_id,
344
+ filename=TCFG.v1_filename)
345
+ print(f" Downloaded to: {checkpoint_path}")
346
+
347
+ state = load_file(checkpoint_path, device=str(device))
348
+ missing, unexpected = model.load_state_dict(state, strict=False)
349
+
350
+ n_loaded = len(state) - len(unexpected)
351
+ print(f" Loaded: {n_loaded} tensors from v1")
352
+ print(f" Missing (new modules): {len(missing)}")
353
+ if missing:
354
+ new_module_keys = [k for k in missing if "sequence_reconstructor" in k
355
+ ]
356
+ other_missing = [k for k in missing if k not in new_module_keys]
357
+ print(f" Seq head (expected new): {len(new_module_keys)}")
358
+ if other_missing:
359
+ print(f" Other (check!): {other_missing[:5]}")
360
+ if unexpected:
361
+ print(f" Unexpected (v1 buffers, ignorable): {len(unexpected)}")
362
+ return True
363
+
364
+
365
+ # ══════════════════════════════════════════════════════════════════
366
+ # TRAINING
367
+ # ══════════════════════════════════════════════════════════════════
368
+
369
+ def train_phase(model, modern_model, modern_tok, train_captions, val_captions,
370
+ param_groups, n_epochs, phase_name, writer, all_metrics,
371
+ global_step=0):
372
+ """
373
+ Single training phase. Used for both phase 1 and phase 2.
374
+ """
375
+ device = next(model.parameters()).device
376
+ optimizer = torch.optim.AdamW(param_groups)
377
+ all_params = [p for g in param_groups for p in g["params"]]
378
+
379
+ n_batches = len(train_captions) // TCFG.batch_size
380
+ total_steps = n_batches * n_epochs
381
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
382
+ optimizer,
383
+ [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
384
+ total_iters=TCFG.warmup_steps),
385
+ torch.optim.lr_scheduler.CosineAnnealingLR(
386
+ optimizer, T_max=max(total_steps, 1), eta_min=TCFG.min_lr)],
387
+ milestones=[TCFG.warmup_steps])
388
+
389
+ scaler = torch.amp.GradScaler()
390
+ clip_tokenizer = model.clip_tokenizer
391
+ best_val_loss = float("inf")
392
+
393
+ print(f"\n {phase_name}: {sum(p.numel() for p in all_params):,} trainable params")
394
+ print(f" {n_batches} batches/epoch Γ— {TCFG.batch_size}")
395
+
396
+ # segment_text is in notebook namespace from architecture cell
397
+ _segment_text = segment_text
398
+
399
+ for epoch in range(n_epochs):
400
+ model.train()
401
+ perm = np.random.permutation(len(train_captions))
402
+ losses = {"total": 0, "modern": 0, "procrustes": 0, "cv": 0,
403
+ "seq_mse": 0, "seq_cos": 0}
404
+ metrics = {"modern_acc": 0, "modern_acc5": 0,
405
+ "cv_raw": 0, "seq_cos_sim": 0, "n_segments_avg": 0}
406
+ n = 0
407
+ t0 = time.time()
408
+
409
+ pbar = tqdm(range(0, len(train_captions), TCFG.batch_size),
410
+ desc=f"{phase_name} E{epoch+1}/{n_epochs}", unit="batch")
411
+
412
+ for batch_start in pbar:
413
+ idx = perm[batch_start:batch_start + TCFG.batch_size]
414
+ if len(idx) < 2:
415
+ continue
416
+ batch_captions = [train_captions[i] for i in idx]
417
+ B = len(batch_captions)
418
+
419
+ # ── Teacher: ModernBERT pooled (sequence target comes from CLIP) ──
420
+ with torch.no_grad():
421
+ with torch.amp.autocast("cuda"):
422
+ modern_cls = teacher_forward(
423
+ modern_model, modern_tok, batch_captions,
424
+ device, TCFG.modern_max_len)
425
+
426
+ # ── Student: segment-by-segment processing ──
427
+ state = model.init_state(B, device)
428
+ all_segments = [_segment_text(cap, clip_tokenizer,
429
+ model.config.max_content_tokens,
430
+ model.config.segment_overlap,
431
+ model.config.max_segments)
432
+ for cap in batch_captions]
433
+ max_segs = max(len(s) for s in all_segments)
434
+ n_segs = [len(s) for s in all_segments]
435
+
436
+ for seg_k in range(max_segs):
437
+ batch_ids, batch_masks = [], []
438
+ for b in range(B):
439
+ if seg_k < len(all_segments[b]):
440
+ batch_ids.append(all_segments[b][seg_k]["input_ids"])
441
+ batch_masks.append(all_segments[b][seg_k]["attention_mask"])
442
+ else:
443
+ batch_ids.append(torch.zeros(77, dtype=torch.long))
444
+ batch_masks.append(torch.zeros(77, dtype=torch.long))
445
+ ids = torch.stack(batch_ids).to(device)
446
+ masks = torch.stack(batch_masks).to(device)
447
+ with torch.amp.autocast("cuda"):
448
+ fused_output, state = model.forward_segment(ids, masks, state)
449
+
450
+ student_cls = fused_output # pooled output from last segment
451
+
452
+ # Bank anchors for CV loss β€” accumulated in state during segment processing
453
+ bank_anchors = state["bank"]["anchors"] # (B, N_written, 768)
454
+ # Pad to max_segs for batch CV computation
455
+ all_anchors = torch.zeros(B, max_segs, model.config.anchor_dim, device=device)
456
+ n_written = min(bank_anchors.shape[1], max_segs)
457
+ all_anchors[:, :n_written] = bank_anchors[:, :n_written]
458
+
459
+ # ── Existing losses (UNCHANGED from v2) ──
460
+ with torch.amp.autocast("cuda"):
461
+ proj_m = model.proj_modern(student_cls)
462
+ l_modern, acc_m, acc5_m = infonce_loss(
463
+ proj_m, modern_cls, TCFG.temperature)
464
+ l_procrustes = procrustes_alignment_loss(
465
+ student_cls, modern_cls[:, :model.config.clip_hidden])
466
+ n_reals_t = torch.tensor(n_segs, device=device)
467
+ l_cv, cv_stats = batch_cv_loss(
468
+ all_anchors, n_reals_t, model.config.cv_target)
469
+
470
+ # ── NEW: Sequence reconstruction loss ──
471
+ # Target: CLIP's own last_hidden_state on the truncated caption.
472
+ # This is what the UNet was trained on β€” the reconstructor must
473
+ # produce sequences in CLIP's distribution.
474
+ with torch.no_grad():
475
+ clip_inputs = clip_tokenizer(
476
+ batch_captions, max_length=77, padding="max_length",
477
+ truncation=True, return_tensors="pt").to(device)
478
+ with torch.amp.autocast("cuda"):
479
+ clip_target_out = model.clip_text(
480
+ input_ids=clip_inputs.input_ids,
481
+ attention_mask=clip_inputs.attention_mask,
482
+ output_hidden_states=False, return_dict=True)
483
+ clip_target_seq = clip_target_out.last_hidden_state # (B, 77, 768)
484
+
485
+ with torch.amp.autocast("cuda"):
486
+ # Reconstruct sequence from memory state
487
+ recon_seq = model.reconstruct_sequence(state) # (B, 77, 768)
488
+
489
+ l_seq_mse, l_seq_cos, seq_cos_metric = sequence_reconstruction_loss(
490
+ recon_seq, clip_target_seq.detach())
491
+
492
+ # ── Combined loss ──
493
+ with torch.amp.autocast("cuda"):
494
+ loss = (TCFG.modern_weight * l_modern +
495
+ TCFG.procrustes_weight * l_procrustes +
496
+ TCFG.cv_weight * l_cv +
497
+ TCFG.sequence_weight * l_seq_mse +
498
+ TCFG.sequence_cosine_weight * l_seq_cos)
499
+
500
+ scaler.scale(loss).backward()
501
+ scaler.unscale_(optimizer)
502
+ torch.nn.utils.clip_grad_norm_(all_params, TCFG.grad_clip)
503
+ scaler.step(optimizer)
504
+ scaler.update()
505
+ optimizer.zero_grad(set_to_none=True)
506
+ scheduler.step()
507
+ global_step += 1
508
+
509
+ # ── Metrics ──
510
+ losses["total"] += loss.item()
511
+ losses["modern"] += l_modern.item()
512
+ losses["procrustes"] += l_procrustes.item()
513
+ losses["cv"] += l_cv.item()
514
+ losses["seq_mse"] += l_seq_mse.item()
515
+ losses["seq_cos"] += l_seq_cos.item()
516
+ metrics["modern_acc"] += acc_m
517
+ metrics["modern_acc5"] += acc5_m
518
+ metrics["cv_raw"] += cv_stats.get("cv_raw", 0)
519
+ metrics["seq_cos_sim"] += seq_cos_metric
520
+ metrics["n_segments_avg"] += np.mean(n_segs)
521
+ n += 1
522
+
523
+ d = max(n, 1)
524
+ pbar.set_postfix(
525
+ loss=f"{losses['total']/d:.3f}",
526
+ m_acc=f"{metrics['modern_acc']/d:.3f}",
527
+ s_cos=f"{metrics['seq_cos_sim']/d:.3f}",
528
+ cv=f"{metrics['cv_raw']/d:.3f}")
529
+
530
+ # Tensorboard
531
+ if global_step % TCFG.log_every == 0:
532
+ writer.add_scalar(f"{phase_name}/loss", losses["total"]/d, global_step)
533
+ writer.add_scalar(f"{phase_name}/modern_loss", losses["modern"]/d, global_step)
534
+ writer.add_scalar(f"{phase_name}/seq_mse", losses["seq_mse"]/d, global_step)
535
+ writer.add_scalar(f"{phase_name}/seq_cos_loss", losses["seq_cos"]/d, global_step)
536
+ writer.add_scalar(f"{phase_name}/seq_cos_sim", metrics["seq_cos_sim"]/d, global_step)
537
+ writer.add_scalar(f"{phase_name}/m_acc", metrics["modern_acc"]/d, global_step)
538
+ writer.add_scalar(f"{phase_name}/cv_raw", metrics["cv_raw"]/d, global_step)
539
+ all_metrics["steps"].append({
540
+ "step": global_step, "phase": phase_name,
541
+ "epoch": epoch + 1,
542
+ "loss": losses["total"]/d,
543
+ "m_acc": metrics["modern_acc"]/d,
544
+ "seq_cos": metrics["seq_cos_sim"]/d,
545
+ "cv_raw": metrics["cv_raw"]/d,
546
+ })
547
+
548
+ pbar.close()
549
+ elapsed = time.time() - t0
550
+ d = max(n, 1)
551
+
552
+ epoch_summary = {
553
+ "phase": phase_name, "epoch": epoch + 1,
554
+ "elapsed_s": elapsed,
555
+ "loss": losses["total"]/d,
556
+ "modern_loss": losses["modern"]/d,
557
+ "seq_mse": losses["seq_mse"]/d,
558
+ "seq_cos_loss": losses["seq_cos"]/d,
559
+ "m_acc": metrics["modern_acc"]/d,
560
+ "m_acc5": metrics["modern_acc5"]/d,
561
+ "seq_cos_sim": metrics["seq_cos_sim"]/d,
562
+ "cv_raw": metrics["cv_raw"]/d,
563
+ "global_step": global_step,
564
+ }
565
+
566
+ all_metrics["epochs"].append(epoch_summary)
567
+
568
+ print(f"\n {phase_name} E{epoch+1}: {elapsed:.0f}s "
569
+ f"loss={epoch_summary['loss']:.4f} "
570
+ f"m_acc={epoch_summary['m_acc']:.3f} "
571
+ f"seq_cos={epoch_summary['seq_cos_sim']:.3f} "
572
+ f"cv={epoch_summary['cv_raw']:.3f}")
573
+
574
+ # Save
575
+ save_checkpoint(model, optimizer, epoch + 1, global_step, phase_name,
576
+ os.path.join(TCFG.checkpoint_dir, f"{phase_name}_e{epoch+1:02d}"))
577
+
578
+ with open(TCFG.metrics_file, "w") as f:
579
+ json.dump(all_metrics, f, indent=2, default=str)
580
+
581
+ return global_step
582
+
583
+
584
+ def save_checkpoint(model, optimizer, epoch, global_step, phase, path):
585
+ os.makedirs(path, exist_ok=True)
586
+ state = {}
587
+ for name, param in model.named_parameters():
588
+ if param.requires_grad:
589
+ state[name] = param.data.contiguous().cpu()
590
+ for name, buf in model.named_buffers():
591
+ state[f"buffer.{name}"] = buf.contiguous().cpu()
592
+ safetensors_save(state, os.path.join(path, "memory_system.safetensors"))
593
+ torch.save({"optimizer": optimizer.state_dict() if optimizer else {},
594
+ "epoch": epoch,
595
+ "global_step": global_step, "phase": phase},
596
+ os.path.join(path, "training_state.pt"))
597
+ model_cfg = model.config.to_dict() if hasattr(model.config, 'to_dict') else {}
598
+ config_data = {"model": model_cfg, "training": asdict(TCFG)}
599
+ with open(os.path.join(path, "config.json"), "w") as f:
600
+ json.dump(config_data, f, indent=2, default=str)
601
+
602
+
603
+ # ══════════════════════════════════════════════════════════════════
604
+ # DATA
605
+ # ══════════════════════════════════════════════════════════════════
606
+
607
+ def load_long_captions(max_train, max_val, min_length=100):
608
+ from datasets import load_dataset
609
+ print(f" Loading CaptionEmporium/conceptual-captions-cc12m-llavanext...")
610
+ ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
611
+ split="train", streaming=True)
612
+ captions = []
613
+ for row in ds:
614
+ cap = row.get("caption_llava", "")
615
+ if isinstance(cap, str) and len(cap) > min_length:
616
+ captions.append(cap)
617
+ if len(captions) >= max_train + max_val:
618
+ break
619
+ train_caps = captions[:max_train]
620
+ val_caps = captions[max_train:max_train + max_val]
621
+ print(f" Train: {len(train_caps)}, Val: {len(val_caps)}")
622
+ return train_caps, val_caps
623
+
624
+
625
+ # ══════════════════════════════════════════════════════════════════
626
+ # MAIN
627
+ # ══════════════════════════════════════════════════════════════════
628
+
629
+ def main():
630
+ print("=" * 70)
631
+ print("TRAINING: MEMORY-CLIP-SEQ (SEQUENCE RECONSTRUCTION)")
632
+ print("=" * 70)
633
+
634
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
635
+ print(f" Device: {device}")
636
+ if torch.cuda.is_available():
637
+ print(f" GPU: {torch.cuda.get_device_name()}")
638
+ print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
639
+
640
+ # ── Model (classes loaded from architecture cell) ──
641
+ config = MemoryCLIPSeqConfig()
642
+ model = MemoryCLIPSeqModel(config).to(device)
643
+
644
+ # Trigger CLIP lazy load
645
+ _ = model.clip_text
646
+ _ = model.clip_tokenizer
647
+
648
+ # ── Load v1 weights ──
649
+ load_v1_weights(model, device)
650
+ print(" v1 memory system weights loaded")
651
+
652
+ # ── Teacher ──
653
+ from transformers import AutoModel, AutoTokenizer
654
+ print(f"\n Loading ModernBERT-large...")
655
+ modern_model = AutoModel.from_pretrained(
656
+ config.teacher_model, torch_dtype=torch.float16).to(device)
657
+ modern_model.eval()
658
+ for p in modern_model.parameters():
659
+ p.requires_grad = False
660
+ modern_tok = AutoTokenizer.from_pretrained(config.teacher_model)
661
+ print(f" {sum(p.numel() for p in modern_model.parameters()):,} params (frozen)")
662
+
663
+ # ── Data ──
664
+ train_captions, val_captions = load_long_captions(
665
+ TCFG.max_train_samples, TCFG.max_val_samples, TCFG.min_caption_length)
666
+
667
+ from transformers import CLIPTokenizer
668
+ tok_temp = CLIPTokenizer.from_pretrained(config.clip_model)
669
+ lengths = [len(tok_temp.encode(c)) for c in train_captions[:500]]
670
+ print(f" Caption tokens (sample 500): mean={np.mean(lengths):.0f} "
671
+ f"median={np.median(lengths):.0f} max={max(lengths)} "
672
+ f">77: {sum(1 for l in lengths if l > 77)/len(lengths):.1%}")
673
+ del tok_temp
674
+
675
+ # ── Procrustes init ──
676
+ compute_and_init_procrustes(
677
+ model, modern_model, modern_tok,
678
+ train_captions[:TCFG.procrustes_n_samples], device)
679
+
680
+ # ── Setup logging ──
681
+ os.makedirs(TCFG.checkpoint_dir, exist_ok=True)
682
+ os.makedirs(TCFG.tensorboard_dir, exist_ok=True)
683
+ writer = SummaryWriter(log_dir=TCFG.tensorboard_dir)
684
+ all_metrics = {
685
+ "config": {**{k: v for k, v in config.to_dict().items()
686
+ if not k.startswith("_")},
687
+ **asdict(TCFG)},
688
+ "epochs": [], "steps": [],
689
+ }
690
+
691
+ global_step = 0
692
+
693
+ # ═══════════════════════════════════════════════════
694
+ # PHASE 1: Train sequence head only (v1 weights frozen)
695
+ # ═══════════════════════════════════════════════════
696
+ print(f"\n{'='*70}")
697
+ print(f"PHASE 1: Sequence head training ({TCFG.phase1_epochs} epochs)")
698
+ print(f" v1 memory system: FROZEN")
699
+ print(f" Sequence reconstructor: TRAINING")
700
+ print(f"{'='*70}")
701
+
702
+ phase1_groups = make_param_groups_phase1(model)
703
+ global_step = train_phase(
704
+ model, modern_model, modern_tok,
705
+ train_captions, val_captions,
706
+ phase1_groups, TCFG.phase1_epochs,
707
+ "phase1", writer, all_metrics, global_step)
708
+
709
+ save_checkpoint(model, None, TCFG.phase1_epochs, global_step, "phase1",
710
+ os.path.join(TCFG.checkpoint_dir, "phase1_final"))
711
+
712
+ # ═══════════════════════════════════════════════════
713
+ # PHASE 2: Joint fine-tune (everything unfrozen)
714
+ # ═══════════════════════════════════════════════════
715
+ print(f"\n{'='*70}")
716
+ print(f"PHASE 2: Joint fine-tune ({TCFG.phase2_epochs} epochs)")
717
+ print(f" All trainable modules: TRAINING")
718
+ print(f" v1 components: reduced LR")
719
+ print(f"{'='*70}")
720
+
721
+ phase2_groups = make_param_groups_phase2(model)
722
+ global_step = train_phase(
723
+ model, modern_model, modern_tok,
724
+ train_captions, val_captions,
725
+ phase2_groups, TCFG.phase2_epochs,
726
+ "phase2", writer, all_metrics, global_step)
727
+
728
+ # ── Final save ──
729
+ save_checkpoint(model, None, TCFG.phase1_epochs + TCFG.phase2_epochs,
730
+ global_step, "final",
731
+ os.path.join(TCFG.checkpoint_dir, "final"))
732
+
733
+ all_metrics["final"] = {
734
+ "total_steps": global_step,
735
+ "final_m_acc": all_metrics["epochs"][-1]["m_acc"],
736
+ "final_seq_cos": all_metrics["epochs"][-1]["seq_cos_sim"],
737
+ "final_cv": all_metrics["epochs"][-1]["cv_raw"],
738
+ }
739
+ with open(TCFG.metrics_file, "w") as f:
740
+ json.dump(all_metrics, f, indent=2, default=str)
741
+
742
+ writer.flush()
743
+ writer.close()
744
+
745
+ print(f"\n{'='*70}")
746
+ print(f"FINAL:")
747
+ final = all_metrics["epochs"][-1]
748
+ print(f" m_acc: {final['m_acc']:.4f}")
749
+ print(f" seq_cos: {final['seq_cos_sim']:.4f}")
750
+ print(f" CV: {final['cv_raw']:.4f}")
751
+ print(f"{'='*70}")
752
+
753
+
754
+ if __name__ == "__main__":
755
+ main()