AbstractPhil commited on
Commit
654b110
Β·
verified Β·
1 Parent(s): 7a468da

Create rapid_prototype_trainer.py

Browse files
Files changed (1) hide show
  1. rapid_prototype_trainer.py +670 -0
rapid_prototype_trainer.py ADDED
@@ -0,0 +1,670 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # RAPID PROTOTYPE: 2-Expert Consensus + Alignment Bank
3
+ #
4
+ # Fast iteration cycle:
5
+ # Phase 1: Train student on 2-BERT consensus (20K captions, ~2 epochs)
6
+ # Phase 2: Freeze student, train alignment bank on its output
7
+ # Phase 3: Verify bank preserves geometry
8
+ # Phase 4: Snap a tiny classifier on bank output, check stability
9
+ # ============================================================================
10
+
11
+ import gc
12
+ import math
13
+ import os
14
+ import time
15
+ import json
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from tqdm import tqdm
22
+
23
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
+
25
+ EXPERTS = [
26
+ ("google-bert/bert-base-uncased", "bert", 512),
27
+ ("answerdotai/ModernBERT-base", "modern", 512),
28
+ ]
29
+
30
+ print("=" * 65)
31
+ print("RAPID PROTOTYPE: 2-Expert Consensus + Alignment Bank")
32
+ print("=" * 65)
33
+ print(f" Device: {DEVICE}")
34
+
35
+
36
+ # ══════════════════════════════════════════════════════════════════
37
+ # STUDENT MODEL
38
+ # ══════════════════════════════════════════════════════════════════
39
+
40
+ class MiniStudent(nn.Module):
41
+ def __init__(self, vocab_size=30522, max_len=512, d_model=256,
42
+ n_heads=4, n_layers=4, d_ff=1024, output_dim=768,
43
+ dropout=0.1, pad_token_id=0):
44
+ super().__init__()
45
+ self.pad_token_id = pad_token_id
46
+ self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
47
+ self.pos_emb = nn.Embedding(max_len, d_model)
48
+ self.emb_norm = nn.LayerNorm(d_model)
49
+ self.emb_drop = nn.Dropout(dropout)
50
+ encoder_layer = nn.TransformerEncoderLayer(
51
+ d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
52
+ dropout=dropout, activation="gelu", batch_first=True,
53
+ norm_first=True)
54
+ self.encoder = nn.TransformerEncoder(
55
+ encoder_layer, num_layers=n_layers, enable_nested_tensor=False)
56
+ self.output_proj = nn.Sequential(
57
+ nn.Linear(d_model, d_model), nn.GELU(),
58
+ nn.LayerNorm(d_model), nn.Linear(d_model, output_dim))
59
+
60
+ def forward(self, input_ids, attention_mask=None):
61
+ B, L = input_ids.shape
62
+ positions = torch.arange(L, device=input_ids.device).unsqueeze(0)
63
+ x = self.token_emb(input_ids) + self.pos_emb(positions)
64
+ x = self.emb_drop(self.emb_norm(x))
65
+ kpm = ~attention_mask.bool() if attention_mask is not None else (input_ids == self.pad_token_id)
66
+ x = self.encoder(x, src_key_padding_mask=kpm)
67
+ mask = attention_mask.unsqueeze(-1).float() if attention_mask is not None else (~kpm).unsqueeze(-1).float()
68
+ pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
69
+ return F.normalize(self.output_proj(pooled), dim=-1)
70
+
71
+
72
+ # ══════════════════════════════════════════════════════════════════
73
+ # ALIGNMENT BANK
74
+ # ══════════════════════════════════════════════════════════════════
75
+
76
+ class AlignmentBank(nn.Module):
77
+ """
78
+ Geometric interface layer. Learns to annotate student embeddings
79
+ with per-expert alignment context and anchor distances.
80
+
81
+ Trained on frozen student output. Provides geometric memory of
82
+ the expert consensus for downstream heads.
83
+ """
84
+ def __init__(self, d_embed=768, n_experts=2, n_anchors=128, d_bank=64):
85
+ super().__init__()
86
+ self.d_embed = d_embed
87
+ self.n_experts = n_experts
88
+ self.n_anchors = n_anchors
89
+ self.d_bank = d_bank
90
+
91
+ # Per-expert rotation matrices (initialized from Procrustes)
92
+ self.expert_rotations = nn.ParameterList([
93
+ nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
94
+ ])
95
+
96
+ # Per-expert bias (mean offset in each expert's space)
97
+ self.expert_means = nn.ParameterList([
98
+ nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)
99
+ ])
100
+
101
+ # Anchor bank: learned consensus landmarks
102
+ self.anchors = nn.Parameter(
103
+ F.normalize(torch.randn(n_anchors, d_embed), dim=-1))
104
+
105
+ # Project geometric features into compact context
106
+ # Input: n_experts (consistency) + n_anchors (distances) + n_experts (reconstruction quality)
107
+ geo_dim = n_experts + n_anchors + n_experts
108
+ self.geo_proj = nn.Sequential(
109
+ nn.Linear(geo_dim, d_bank * 2),
110
+ nn.GELU(),
111
+ nn.LayerNorm(d_bank * 2),
112
+ nn.Linear(d_bank * 2, d_bank),
113
+ nn.LayerNorm(d_bank),
114
+ )
115
+
116
+ def init_from_procrustes(self, procrustes_results, expert_names,
117
+ consensus_embeddings=None):
118
+ """Initialize from consensus training artifacts."""
119
+ device = self.anchors.device
120
+ for i, name in enumerate(expert_names[:self.n_experts]):
121
+ info = procrustes_results[name]
122
+ self.expert_rotations[i].data = info["rotation"].float().to(device)
123
+ self.expert_means[i].data = info["source_mean"].float().to(device)
124
+ print(f" Expert {i} ({name}): rotation loaded, cos_after={info['cos_after']:.4f}")
125
+
126
+ if consensus_embeddings is not None:
127
+ n = min(self.n_anchors, consensus_embeddings.shape[0])
128
+ indices = torch.linspace(0, consensus_embeddings.shape[0] - 1, n).long()
129
+ self.anchors.data[:n] = F.normalize(
130
+ consensus_embeddings[indices].float(), dim=-1).to(device)
131
+ print(f" Anchors: {n} initialized from consensus embeddings")
132
+
133
+ def forward(self, embedding):
134
+ """
135
+ Annotate embedding with geometric context.
136
+
137
+ Args:
138
+ embedding: (B, 768) L2-normalized
139
+
140
+ Returns:
141
+ enriched: (B, 768 + d_bank)
142
+ aux: dict with geometric losses and diagnostics
143
+ """
144
+ B = embedding.shape[0]
145
+ emb = embedding.float()
146
+
147
+ # Per-expert: rotate into expert space, measure reconstruction quality
148
+ expert_consistency = [] # cosine between original and round-tripped
149
+ expert_recon = [] # MSE of round-trip
150
+ for i in range(self.n_experts):
151
+ R = self.expert_rotations[i]
152
+ # Forward rotation: consensus β†’ expert space
153
+ in_expert = emb @ R
154
+ # Backward rotation: expert space β†’ consensus
155
+ round_trip = in_expert @ R.T
156
+ # How well does round-trip recover original?
157
+ cos = F.cosine_similarity(emb, round_trip, dim=-1) # (B,)
158
+ recon = (emb - round_trip).pow(2).mean(dim=-1) # (B,)
159
+ expert_consistency.append(cos)
160
+ expert_recon.append(recon)
161
+
162
+ expert_cos = torch.stack(expert_consistency, dim=-1) # (B, n_experts)
163
+ expert_mse = torch.stack(expert_recon, dim=-1) # (B, n_experts)
164
+
165
+ # Anchor distances
166
+ anchors_n = F.normalize(self.anchors, dim=-1)
167
+ anchor_cos = emb @ anchors_n.T # (B, n_anchors)
168
+
169
+ # Geometric context vector
170
+ geo_input = torch.cat([expert_cos, anchor_cos, expert_mse], dim=-1)
171
+ geo_context = self.geo_proj(geo_input) # (B, d_bank)
172
+
173
+ # Enriched output
174
+ enriched = torch.cat([embedding, geo_context], dim=-1)
175
+
176
+ # ── Geometric losses ──
177
+ aux = {}
178
+
179
+ # 1. Expert agreement: all experts should see the embedding similarly
180
+ expert_mean = expert_cos.mean(dim=-1, keepdim=True)
181
+ aux["expert_agreement"] = (expert_cos - expert_mean).pow(2).mean()
182
+
183
+ # 2. Rotation orthogonality: rotations should stay orthogonal
184
+ ortho_loss = 0.0
185
+ for i in range(self.n_experts):
186
+ R = self.expert_rotations[i]
187
+ RRT = R @ R.T
188
+ ortho_loss += (RRT - torch.eye(self.d_embed, device=R.device)).pow(2).mean()
189
+ aux["rotation_ortho"] = ortho_loss / self.n_experts
190
+
191
+ # 3. Anchor spread: anchors should be well-distributed
192
+ anchor_sim = anchors_n @ anchors_n.T
193
+ anchor_sim.fill_diagonal_(0)
194
+ aux["anchor_spread"] = anchor_sim.pow(2).mean()
195
+
196
+ # 4. Anchor sharpness: each embedding should have clear nearest anchors
197
+ anchor_probs = F.softmax(anchor_cos * 10, dim=-1)
198
+ entropy = -(anchor_probs * (anchor_probs + 1e-12).log()).sum(-1).mean()
199
+ aux["anchor_entropy"] = entropy
200
+
201
+ # 5. Pentachoron CV of enriched space (sample from geo_context)
202
+ if B >= 10:
203
+ ctx_n = F.normalize(geo_context, dim=-1)
204
+ vols = []
205
+ for _ in range(32):
206
+ idx = torch.randperm(B, device=embedding.device)[:5]
207
+ pts = ctx_n[idx].unsqueeze(0)
208
+ diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
209
+ d2 = (diff * diff).sum(-1)
210
+ Bv, V, _ = d2.shape
211
+ cm = torch.zeros(Bv, V+1, V+1, device=d2.device, dtype=torch.float32)
212
+ cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
213
+ s = (-1.0)**V; f = math.factorial(V-1)
214
+ v2 = s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)
215
+ vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
216
+ stacked = torch.stack(vols)
217
+ bank_cv = stacked.std() / (stacked.mean() + 1e-8)
218
+ aux["bank_cv"] = bank_cv
219
+ else:
220
+ aux["bank_cv"] = torch.tensor(0.0, device=embedding.device)
221
+
222
+ # Summary diagnostics
223
+ aux["expert_cos_mean"] = expert_cos.mean().item()
224
+ aux["expert_cos_std"] = expert_cos.std().item()
225
+ aux["anchor_max_cos"] = anchor_cos.max(dim=-1).values.mean().item()
226
+ aux["anchor_mean_cos"] = anchor_cos.mean().item()
227
+
228
+ return enriched, aux
229
+
230
+ def bank_loss(self, aux, cv_target=0.15):
231
+ """Combined bank training loss."""
232
+ loss = (1.0 * aux["expert_agreement"] +
233
+ 1.0 * aux["rotation_ortho"] +
234
+ 0.5 * aux["anchor_spread"] +
235
+ 0.1 * aux["anchor_entropy"] +
236
+ 0.3 * (aux["bank_cv"] - cv_target).abs())
237
+ return loss
238
+
239
+
240
+ # ══════════════════════════════════════════════════════════════════
241
+ # GEOMETRY
242
+ # ══════════════════════════════════════════════════════════════════
243
+
244
+ def infonce(a, b, temperature=0.07):
245
+ a = F.normalize(a, dim=-1)
246
+ b = F.normalize(b, dim=-1)
247
+ logits = (a @ b.T) / temperature
248
+ labels = torch.arange(logits.shape[0], device=logits.device)
249
+ loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
250
+ with torch.no_grad():
251
+ acc = (logits.argmax(-1) == labels).float().mean().item()
252
+ return loss, acc
253
+
254
+ def cayley_menger_vol2(pts):
255
+ pts = pts.float()
256
+ diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
257
+ d2 = (diff * diff).sum(-1)
258
+ B, V, _ = d2.shape
259
+ cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)
260
+ cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
261
+ s = (-1.0)**V; f = math.factorial(V-1)
262
+ return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)
263
+
264
+ def cv_loss(emb, target=0.12, n_samples=16):
265
+ B = emb.shape[0]
266
+ if B < 5: return torch.tensor(0.0, device=emb.device)
267
+ vols = []
268
+ for _ in range(n_samples):
269
+ idx = torch.randperm(B, device=emb.device)[:5]
270
+ v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
271
+ vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
272
+ stacked = torch.stack(vols)
273
+ cv = stacked.std() / (stacked.mean() + 1e-8)
274
+ return (cv - target).abs()
275
+
276
+ def cv_metric(emb, n=200):
277
+ B = emb.shape[0]
278
+ if B < 5: return 0.0
279
+ vols = []
280
+ for _ in range(n):
281
+ idx = torch.randperm(B, device=emb.device)[:5]
282
+ v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
283
+ v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()
284
+ if v > 0: vols.append(v)
285
+ if len(vols) < 10: return 0.0
286
+ a = np.array(vols)
287
+ return float(a.std() / (a.mean() + 1e-8))
288
+
289
+
290
+ # ══════════════════════════════════════════════════════════════════
291
+ # EXTRACTION + ALIGNMENT
292
+ # ══════════════════════════════════════════════════════════════════
293
+
294
+ def symmetric_inv_sqrt(cov, eps=1e-6):
295
+ evals, evecs = torch.linalg.eigh(cov)
296
+ evals = torch.clamp(evals, min=eps)
297
+ return evecs @ torch.diag(evals.rsqrt()) @ evecs.T
298
+
299
+ def procrustes_align(source, target, n_align=5000):
300
+ N = min(n_align, source.shape[0], target.shape[0])
301
+ S = source[:N].float()
302
+ T = target[:N].float()
303
+ s_mean = S.mean(0, keepdim=True)
304
+ t_mean = T.mean(0, keepdim=True)
305
+ Sc = S - s_mean; Tc = T - t_mean
306
+ N_s = Sc.shape[0]
307
+ cos_before = F.cosine_similarity(Sc, Tc, dim=-1).mean().item()
308
+ s_cov = (Sc.T @ Sc) / max(N_s - 1, 1)
309
+ t_cov = (Tc.T @ Tc) / max(N_s - 1, 1)
310
+ s_whiten = symmetric_inv_sqrt(s_cov)
311
+ t_whiten = symmetric_inv_sqrt(t_cov)
312
+ Sc_w = F.normalize(Sc @ s_whiten, dim=-1)
313
+ Tc_w = F.normalize(Tc @ t_whiten, dim=-1)
314
+ U, _, Vt = torch.linalg.svd(Tc_w.T @ Sc_w, full_matrices=False)
315
+ R = U @ Vt
316
+ cos_after = F.cosine_similarity(Sc_w @ R.T, Tc_w, dim=-1).mean().item()
317
+ return {
318
+ "rotation": R, "source_mean": s_mean.squeeze(0),
319
+ "source_whitener": s_whiten,
320
+ "target_unwhitener": torch.linalg.pinv(t_whiten),
321
+ "cos_before": cos_before, "cos_after": cos_after,
322
+ }
323
+
324
+ def apply_align(emb, a):
325
+ x = emb.float() - a["source_mean"]
326
+ x = x @ a["source_whitener"]
327
+ x = x @ a["rotation"].T
328
+ x = x @ a["target_unwhitener"]
329
+ return x
330
+
331
+
332
+ # ══════════════════════════════════════════════════════════════════
333
+ # MAIN
334
+ # ══════════════════════════════════════════════════════════════════
335
+
336
+ def run():
337
+ torch.manual_seed(42)
338
+ np.random.seed(42)
339
+ N_SAMPLES = 20000
340
+ MAX_LEN = 128
341
+ BATCH = 256
342
+
343
+ # ── Phase 0: Extract ──
344
+ print(f"\n{'='*65}")
345
+ print("PHASE 0: EXTRACTION")
346
+ print(f"{'='*65}")
347
+
348
+ from datasets import load_dataset
349
+ from transformers import AutoModel, AutoTokenizer
350
+
351
+ ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
352
+ split="train", streaming=True)
353
+ captions = []
354
+ for row in ds:
355
+ cap = row.get("caption_llava", "")
356
+ if isinstance(cap, str) and len(cap) > 50:
357
+ captions.append(cap)
358
+ if len(captions) >= N_SAMPLES:
359
+ break
360
+ print(f" Captions: {len(captions):,}")
361
+
362
+ embeds = {}
363
+ for model_name, short, max_len in EXPERTS:
364
+ print(f"\n Extracting: {short}...")
365
+ model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
366
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
367
+ all_emb = []
368
+ with torch.no_grad():
369
+ for i in tqdm(range(0, len(captions), 128), desc=f" {short}"):
370
+ batch = captions[i:i+128]
371
+ inputs = tokenizer(batch, max_length=max_len, padding=True,
372
+ truncation=True, return_tensors="pt").to(DEVICE)
373
+ out = model(**inputs)
374
+ m = inputs.attention_mask.unsqueeze(-1).float()
375
+ pooled = (out.last_hidden_state * m).sum(1) / m.sum(1).clamp(min=1)
376
+ all_emb.append(pooled.cpu())
377
+ embeds[short] = torch.cat(all_emb)
378
+ print(f" Shape: {embeds[short].shape}")
379
+ del model; gc.collect(); torch.cuda.empty_cache()
380
+
381
+ # ── Phase 0b: Align + Consensus ──
382
+ print(f"\n{'='*65}")
383
+ print("PHASE 0b: PROCRUSTES ALIGNMENT")
384
+ print(f"{'='*65}")
385
+
386
+ ref = "bert"
387
+ names = [s for _, s, _ in EXPERTS]
388
+ procrustes_results = {}
389
+ aligned = {}
390
+ for name in names:
391
+ info = procrustes_align(embeds[name], embeds[ref])
392
+ procrustes_results[name] = info
393
+ aligned[name] = apply_align(embeds[name], info)
394
+ print(f" {name:10s}: cos {info['cos_before']:.4f} β†’ {info['cos_after']:.4f}")
395
+
396
+ consensus = F.normalize(sum(aligned[n] for n in names) / len(names), dim=-1)
397
+ print(f" Consensus: {consensus.shape}")
398
+ for name in names:
399
+ cos = F.cosine_similarity(consensus[:2000], aligned[name][:2000], dim=-1).mean().item()
400
+ print(f" cos(consensus, {name}): {cos:.4f}")
401
+
402
+ consensus_cv = cv_metric(consensus[:2000].to(DEVICE))
403
+ print(f" Consensus CV: {consensus_cv:.4f}")
404
+
405
+ del embeds, aligned
406
+ gc.collect(); torch.cuda.empty_cache()
407
+
408
+ # ── Phase 1: Train Student ──
409
+ print(f"\n{'='*65}")
410
+ print("PHASE 1: TRAIN STUDENT (2 experts, 20K captions)")
411
+ print(f"{'='*65}")
412
+
413
+ tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
414
+ tokens = tokenizer(captions, max_length=MAX_LEN, padding="max_length",
415
+ truncation=True, return_tensors="pt")
416
+ input_ids = tokens["input_ids"]
417
+ attention_mask = tokens["attention_mask"]
418
+
419
+ n_train = N_SAMPLES - 2000
420
+ train_ids = input_ids[:n_train].to(DEVICE)
421
+ train_mask = attention_mask[:n_train].to(DEVICE)
422
+ train_targets = consensus[:n_train].to(DEVICE)
423
+ val_ids = input_ids[n_train:].to(DEVICE)
424
+ val_mask = attention_mask[n_train:].to(DEVICE)
425
+ val_targets = consensus[n_train:].to(DEVICE)
426
+
427
+ student = MiniStudent(
428
+ vocab_size=tokenizer.vocab_size, max_len=MAX_LEN,
429
+ d_model=256, n_heads=4, n_layers=4, d_ff=1024,
430
+ output_dim=768, dropout=0.1, pad_token_id=tokenizer.pad_token_id
431
+ ).to(DEVICE)
432
+ n_params = sum(p.numel() for p in student.parameters())
433
+ print(f" Student: {n_params:,} params")
434
+
435
+ optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay=0.01)
436
+
437
+ for epoch in range(5):
438
+ student.train()
439
+ perm = torch.randperm(n_train, device=DEVICE)
440
+ t_loss, t_acc, t_cos, n = 0, 0, 0, 0
441
+ t0 = time.time()
442
+
443
+ for i in range(0, n_train, BATCH):
444
+ idx = perm[i:i+BATCH]
445
+ if len(idx) < 8: continue
446
+ emb = student(train_ids[idx], train_mask[idx])
447
+ tgt = train_targets[idx]
448
+ l_nce, acc = infonce(emb, tgt)
449
+ l_mse = F.mse_loss(emb, tgt)
450
+ l_cv = cv_loss(emb, target=consensus_cv)
451
+ loss = l_nce + l_mse + 0.1 * l_cv
452
+ loss.backward()
453
+ torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
454
+ optimizer.step(); optimizer.zero_grad(set_to_none=True)
455
+ with torch.no_grad():
456
+ cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()
457
+ t_loss += loss.item(); t_acc += acc; t_cos += cos; n += 1
458
+
459
+ elapsed = time.time() - t0
460
+ d = max(n, 1)
461
+ student.eval()
462
+ with torch.no_grad():
463
+ v_emb = student(val_ids, val_mask)
464
+ _, v_acc = infonce(v_emb[:1000], val_targets[:1000])
465
+ v_cos = F.cosine_similarity(v_emb, val_targets, dim=-1).mean().item()
466
+ v_cv = cv_metric(v_emb[:1000])
467
+
468
+ print(f" E{epoch+1}: {elapsed:.0f}s loss={t_loss/d:.4f} "
469
+ f"t_acc={t_acc/d:.3f} t_cos={t_cos/d:.3f} "
470
+ f"v_acc={v_acc:.3f} v_cos={v_cos:.3f} v_cv={v_cv:.3f}")
471
+
472
+ # Save student
473
+ torch.save(student.state_dict(), "mini_student.pt")
474
+ print(f"\n Student saved. v_cos={v_cos:.3f}, v_cv={v_cv:.3f}")
475
+
476
+ # ── Phase 2: Train Alignment Bank ──
477
+ print(f"\n{'='*65}")
478
+ print("PHASE 2: TRAIN ALIGNMENT BANK (student frozen)")
479
+ print(f"{'='*65}")
480
+
481
+ # Freeze student
482
+ student.eval()
483
+ for p in student.parameters():
484
+ p.requires_grad = False
485
+
486
+ # Pre-encode everything through frozen student
487
+ print(" Pre-encoding through frozen student...")
488
+ with torch.no_grad():
489
+ all_embs = []
490
+ for i in range(0, n_train, 512):
491
+ j = min(i + 512, n_train)
492
+ emb = student(train_ids[i:j], train_mask[i:j])
493
+ all_embs.append(emb)
494
+ student_embs = torch.cat(all_embs) # (n_train, 768)
495
+ val_student_embs = student(val_ids, val_mask)
496
+
497
+ print(f" Student embeddings: {student_embs.shape}")
498
+
499
+ # Build bank
500
+ bank = AlignmentBank(
501
+ d_embed=768, n_experts=len(EXPERTS),
502
+ n_anchors=128, d_bank=64
503
+ ).to(DEVICE)
504
+
505
+ bank.init_from_procrustes(procrustes_results, names, consensus[:n_train])
506
+ bank_params = sum(p.numel() for p in bank.parameters())
507
+ print(f" Bank: {bank_params:,} params")
508
+
509
+ bank_opt = torch.optim.AdamW(bank.parameters(), lr=1e-3, weight_decay=0.01)
510
+ BANK_EPOCHS = 20
511
+ BANK_BATCH = 256
512
+
513
+ for epoch in range(BANK_EPOCHS):
514
+ bank.train()
515
+ perm = torch.randperm(n_train, device=DEVICE)
516
+ total_loss = 0
517
+ stats = {"expert_agreement": 0, "rotation_ortho": 0,
518
+ "anchor_spread": 0, "bank_cv": 0}
519
+ n = 0
520
+ t0 = time.time()
521
+
522
+ for i in range(0, n_train, BANK_BATCH):
523
+ idx = perm[i:i+BANK_BATCH]
524
+ if len(idx) < 16: continue
525
+
526
+ emb = student_embs[idx]
527
+ enriched, aux = bank(emb)
528
+ loss = bank.bank_loss(aux, cv_target=consensus_cv + 0.02)
529
+
530
+ loss.backward()
531
+ torch.nn.utils.clip_grad_norm_(bank.parameters(), 1.0)
532
+ bank_opt.step(); bank_opt.zero_grad(set_to_none=True)
533
+
534
+ total_loss += loss.item()
535
+ for k in stats:
536
+ if k in aux:
537
+ v = aux[k]
538
+ stats[k] += v.item() if torch.is_tensor(v) else v
539
+ n += 1
540
+
541
+ elapsed = time.time() - t0
542
+ d = max(n, 1)
543
+
544
+ # Validation
545
+ bank.eval()
546
+ with torch.no_grad():
547
+ v_enriched, v_aux = bank(val_student_embs)
548
+ v_loss = bank.bank_loss(v_aux, cv_target=consensus_cv + 0.02).item()
549
+
550
+ print(f" E{epoch+1:2d}: {elapsed:.0f}s loss={total_loss/d:.4f} "
551
+ f"v_loss={v_loss:.4f} "
552
+ f"expert_agr={stats['expert_agreement']/d:.5f} "
553
+ f"ortho={stats['rotation_ortho']/d:.5f} "
554
+ f"spread={stats['anchor_spread']/d:.5f} "
555
+ f"cv={stats['bank_cv']/d:.4f} "
556
+ f"anchor_max={v_aux['anchor_max_cos']:.3f} "
557
+ f"expert_cos={v_aux['expert_cos_mean']:.3f}Β±{v_aux['expert_cos_std']:.3f}")
558
+
559
+ torch.save(bank.state_dict(), "alignment_bank.pt")
560
+
561
+ # ── Phase 3: Verify Geometry ──
562
+ print(f"\n{'='*65}")
563
+ print("PHASE 3: GEOMETRIC VERIFICATION")
564
+ print(f"{'='*65}")
565
+
566
+ bank.eval()
567
+ with torch.no_grad():
568
+ # Check that enriched embeddings preserve original structure
569
+ enriched_val, _ = bank(val_student_embs)
570
+ original_768 = enriched_val[:, :768] # first 768 dims = original embedding
571
+ geo_context = enriched_val[:, 768:] # last d_bank dims = geometric annotation
572
+
573
+ # Original embedding should be unchanged (passthrough)
574
+ passthrough_cos = F.cosine_similarity(
575
+ original_768[:100], val_student_embs[:100], dim=-1).mean().item()
576
+
577
+ # Geometric context should be informative
578
+ geo_cv = cv_metric(F.normalize(geo_context[:1000], dim=-1))
579
+ geo_eff_dim = torch.linalg.svdvals(
580
+ geo_context[:1000].float() - geo_context[:1000].float().mean(0)).pow(2)
581
+ geo_eff_dim = (geo_eff_dim.sum() ** 2) / (geo_eff_dim.pow(2).sum() + 1e-12)
582
+
583
+ print(f" Passthrough integrity: {passthrough_cos:.6f} (should be ~1.000)")
584
+ print(f" Geo context CV: {geo_cv:.4f}")
585
+ print(f" Geo context eff_dim: {geo_eff_dim:.1f}")
586
+ print(f" Geo context shape: {geo_context.shape}")
587
+
588
+ # ── Phase 4: Quick Classifier Test ──
589
+ print(f"\n{'='*65}")
590
+ print("PHASE 4: CLASSIFIER STABILITY TEST")
591
+ print(f"{'='*65}")
592
+
593
+ # Create synthetic 3-class task from similarity structure
594
+ # Class 0: high consensus cosine pairs (similar)
595
+ # Class 1: medium consensus cosine pairs
596
+ # Class 2: low consensus cosine pairs (different)
597
+ with torch.no_grad():
598
+ # Generate synthetic labels from embedding distances
599
+ embs = val_student_embs[:1000]
600
+ sim = embs @ embs.T
601
+ sim.fill_diagonal_(-1) # exclude self
602
+
603
+ # Random pairs
604
+ n_pairs = 3000
605
+ idx_a = torch.randint(0, 1000, (n_pairs,))
606
+ idx_b = torch.randint(0, 1000, (n_pairs,))
607
+ pair_cos = sim[idx_a, idx_b]
608
+
609
+ # Assign labels by cosine terciles
610
+ sorted_cos, _ = pair_cos.sort()
611
+ t1 = sorted_cos[n_pairs // 3].item()
612
+ t2 = sorted_cos[2 * n_pairs // 3].item()
613
+ labels = torch.zeros(n_pairs, dtype=torch.long, device=DEVICE)
614
+ labels[pair_cos > t2] = 0 # similar
615
+ labels[(pair_cos <= t2) & (pair_cos > t1)] = 1 # medium
616
+ labels[pair_cos <= t1] = 2 # different
617
+
618
+ # Get enriched representations
619
+ enriched_a, _ = bank(embs[idx_a])
620
+ enriched_b, _ = bank(embs[idx_b])
621
+
622
+ # Train tiny classifier: with bank vs without bank
623
+ for mode in ["with_bank", "without_bank"]:
624
+ if mode == "with_bank":
625
+ feat_dim = (768 + 64) * 2 # enriched
626
+ features = torch.cat([enriched_a, enriched_b], dim=-1)
627
+ else:
628
+ feat_dim = 768 * 2 # raw
629
+ features = torch.cat([embs[idx_a], embs[idx_b]], dim=-1)
630
+
631
+ clf = nn.Sequential(
632
+ nn.Linear(feat_dim, 128), nn.GELU(),
633
+ nn.Linear(128, 3)
634
+ ).to(DEVICE)
635
+
636
+ clf_opt = torch.optim.Adam(clf.parameters(), lr=1e-3)
637
+ n_clf_train = 2400
638
+ train_f = features[:n_clf_train].detach()
639
+ train_l = labels[:n_clf_train]
640
+ val_f = features[n_clf_train:].detach()
641
+ val_l = labels[n_clf_train:]
642
+
643
+ for e in range(20):
644
+ clf.train()
645
+ logits = clf(train_f)
646
+ loss = F.cross_entropy(logits, train_l)
647
+ loss.backward()
648
+ clf_opt.step(); clf_opt.zero_grad()
649
+
650
+ clf.eval()
651
+ with torch.no_grad():
652
+ val_logits = clf(val_f)
653
+ val_acc = (val_logits.argmax(-1) == val_l).float().mean().item()
654
+ train_logits = clf(train_f)
655
+ train_acc = (train_logits.argmax(-1) == train_l).float().mean().item()
656
+
657
+ print(f" {mode:15s}: train_acc={train_acc:.3f} val_acc={val_acc:.3f} "
658
+ f"gap={train_acc-val_acc:.3f}")
659
+
660
+ print(f"\n{'='*65}")
661
+ print("DONE")
662
+ print(f"{'='*65}")
663
+ print(f"\n Student: mini_student.pt")
664
+ print(f" Bank: alignment_bank.pt")
665
+ print(f" Consensus CV: {consensus_cv:.4f}")
666
+ print(f" Student v_cos: {v_cos:.3f}")
667
+
668
+
669
+ if __name__ == "__main__":
670
+ run()