AbstractPhil commited on
Commit
7cc3d76
·
verified ·
1 Parent(s): 04ba8f5

Create prototype_55_geodesic_bank_multitest.py

Browse files
prototype_55_geodesic_bank_multitest.py ADDED
@@ -0,0 +1,924 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # RAPID PROTOTYPE v2: Differentiation-Centered Alignment Bank
3
+ #
4
+ # The bank aligns to the DIFFERENTIATION between experts, not to any
5
+ # arbitrary target. The consensus CV, spectral profile, and pairwise
6
+ # statistics measured during alignment become the exact targets.
7
+ #
8
+ # The bank embodies the centerpoint of expert disagreement.
9
+ # ============================================================================
10
+
11
+ import gc
12
+ import math
13
+ import os
14
+ import time
15
+ import json
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from tqdm import tqdm
22
+
23
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
+
25
+ EXPERTS = [
26
+ ("google-bert/bert-base-uncased", "bert", 512),
27
+ ("answerdotai/ModernBERT-base", "modern", 512),
28
+ ]
29
+
30
+ print("=" * 65)
31
+ print("RAPID PROTOTYPE v2: Differentiation-Centered Bank")
32
+ print("=" * 65)
33
+ print(f" Device: {DEVICE}")
34
+
35
+
36
+ # ══════════════════════════════════════════════════════════════════
37
+ # STUDENT MODEL
38
+ # ══════════════════════════════════════════════════════════════════
39
+
40
class MiniStudent(nn.Module):
    """Small Transformer encoder that distills expert consensus embeddings.

    Token + learned positional embeddings feed a pre-norm Transformer
    encoder; masked mean pooling and a small MLP head produce a unit-norm
    vector of size ``output_dim`` (the experts' embedding width).
    """

    def __init__(self, vocab_size=30522, max_len=512, d_model=256,
                 n_heads=4, n_layers=4, d_ff=1024, output_dim=768,
                 dropout=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        # Embedding stack: tokens (pad row zeroed) + absolute positions.
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(
            layer, num_layers=n_layers, enable_nested_tensor=False)
        # Projection head up to the experts' embedding dimensionality.
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(),
            nn.LayerNorm(d_model), nn.Linear(d_model, output_dim))

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` -> L2-normalized (B, output_dim) embeddings.

        When ``attention_mask`` is omitted, padding is inferred from
        ``pad_token_id`` positions instead.
        """
        seq_len = input_ids.shape[1]
        pos_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        hidden = self.token_emb(input_ids) + self.pos_emb(pos_ids)
        hidden = self.emb_drop(self.emb_norm(hidden))
        # Key-padding mask is True at positions attention must ignore.
        if attention_mask is not None:
            pad_mask = ~attention_mask.bool()
        else:
            pad_mask = input_ids == self.pad_token_id
        hidden = self.encoder(hidden, src_key_padding_mask=pad_mask)
        # Mean-pool over real (non-pad) tokens only.
        if attention_mask is not None:
            pool_w = attention_mask.unsqueeze(-1).float()
        else:
            pool_w = (~pad_mask).unsqueeze(-1).float()
        pooled = (hidden * pool_w).sum(1) / pool_w.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(pooled), dim=-1)
+
71
+
72
+ # ══════════════════════════════════════════════════════════════════
73
+ # ALIGNMENT BANK
74
+ # ══════════════════════════════════════════════════════════════════
75
+
76
class AlignmentBank(nn.Module):
    """
    Differentiation-centered geometric interface.

    Aligns to the CENTERPOINT between experts — the consensus itself.
    Stores per-expert rotation matrices (the differentiation structure)
    and learned anchor landmarks (the consensus manifold topology).

    The bank doesn't invent geometry. It mirrors the measured consensus.
    Every loss term pulls toward measured consensus statistics.

    Args:
        d_embed: dimensionality of the consensus embedding space.
        n_experts: number of teacher experts whose Procrustes artifacts
            (rotation / whitener / mean) are stored.
        n_anchors: number of learned landmark directions on the hypersphere.
        d_bank: width of the geometric context vector appended to each
            input embedding by ``forward``.
    """
    def __init__(self, d_embed=768, n_experts=2, n_anchors=512, d_bank=128):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank

        # Per-expert rotation matrices (differentiation structure).
        # Initialized to identity; overwritten by init_from_procrustes.
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
        ])

        # Per-expert whiteners (captures variance structure per expert)
        self.expert_whiteners = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
        ])

        # Per-expert means (centering offset per expert)
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)
        ])

        # Anchor bank: consensus landmarks on the hypersphere
        self.anchors = nn.Parameter(
            F.normalize(torch.randn(n_anchors, d_embed), dim=-1))

        # Project: expert_cos (n) + expert_mse (n) + cross (n*(n-1)/2) +
        # disagreement_ratio (1) + norm_ratio (n) + anchor_cos (n_anchors)
        n_cross = n_experts * (n_experts - 1) // 2
        geo_dim = n_experts + n_experts + n_cross + 1 + n_experts + n_anchors
        self.geo_proj = nn.Sequential(
            nn.Linear(geo_dim, d_bank * 2),
            nn.GELU(),
            nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank),
            nn.LayerNorm(d_bank),
        )

        # Consensus statistics (set during init, used as exact targets).
        # Buffers (not Parameters): saved with state_dict but never trained.
        self.register_buffer("target_cv", torch.tensor(0.12))
        self.register_buffer("target_mean_cos", torch.tensor(0.0))
        self.register_buffer("target_spectral", torch.zeros(50))
        # Disagreement structure (measured once, preserved forever)
        self.register_buffer("target_cross_cos_mean", torch.tensor(0.0))
        self.register_buffer("target_cross_cos_std", torch.tensor(0.0))
        self.register_buffer("target_disagreement_ratio", torch.tensor(0.0))

    def init_from_procrustes(self, procrustes_results, expert_names,
                             consensus_embeddings=None,
                             consensus_stats=None):
        """Initialize from consensus training artifacts.

        Args:
            procrustes_results: dict name -> procrustes_align() result with
                "rotation" and optionally "source_whitener"/"source_mean".
            expert_names: ordering of experts; only the first n_experts used.
            consensus_embeddings: optional (N, d_embed) tensor; evenly-spaced
                rows seed the anchor bank.
            consensus_stats: optional dict from measure_consensus_stats()
                supplying "cv", "mean_cos" and optionally "spectral".
        """
        device = self.anchors.device
        for i, name in enumerate(expert_names[:self.n_experts]):
            info = procrustes_results[name]
            # .data assignment: overwrite in place without touching autograd.
            self.expert_rotations[i].data = info["rotation"].float().to(device)
            if "source_whitener" in info:
                self.expert_whiteners[i].data = info["source_whitener"].float().to(device)
            if "source_mean" in info:
                self.expert_means[i].data = info["source_mean"].float().to(device)
            print(f" Expert {i} ({name}): rotation + whitener + mean loaded, "
                  f"cos_after={info['cos_after']:.4f}")

        if consensus_embeddings is not None:
            # Evenly-spaced subsample of consensus rows seeds the anchors.
            n = min(self.n_anchors, consensus_embeddings.shape[0])
            indices = torch.linspace(0, consensus_embeddings.shape[0] - 1, n).long()
            self.anchors.data[:n] = F.normalize(
                consensus_embeddings[indices].float(), dim=-1).to(device)
            print(f" Anchors: {n} initialized from consensus embeddings")

        if consensus_stats is not None:
            self.target_cv.fill_(consensus_stats["cv"])
            self.target_mean_cos.fill_(consensus_stats["mean_cos"])
            if "spectral" in consensus_stats:
                s = torch.tensor(consensus_stats["spectral"][:50], dtype=torch.float32)
                self.target_spectral[:len(s)] = s.to(device)
            print(f" Targets: CV={consensus_stats['cv']:.4f}, "
                  f"mean_cos={consensus_stats['mean_cos']:.4f}")

    def forward(self, embedding):
        """Enrich embeddings with geometric context.

        Args:
            embedding: (B, d_embed) tensor; assumed roughly unit-norm
                student outputs — TODO confirm against caller.

        Returns:
            (enriched, aux): enriched is (B, d_embed + d_bank) — the input
            concatenated with the geo-context vector; aux holds loss terms
            (tensors) and diagnostics (Python floats via .item()).
        """
        B = embedding.shape[0]
        emb = embedding.float()

        # ── Per-expert projections (full whitened Procrustes) ──
        # Chain: center → whiten → normalize → rotate
        # This is EXACTLY what was computed during alignment.
        # The rotation only makes geometric sense in whitened-normalized space.
        expert_consistency = []
        expert_recon = []
        expert_projected = []
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]

            # Forward: center → whiten → normalize → rotate
            centered = emb - mu
            whitened = centered @ W
            whitened_n = F.normalize(whitened, dim=-1)
            in_expert = whitened_n @ R.T  # now in expert's whitened-normalized space

            # Round-trip: rotate back (orthogonal, so R.T inverse = R)
            back = in_expert @ R

            # Consistency: round-trip should recover whitened_n exactly
            # (it only fails to if R drifts from orthogonality during training).
            cos = F.cosine_similarity(whitened_n, back, dim=-1)
            recon = (whitened_n - back).pow(2).mean(dim=-1)

            expert_consistency.append(cos)
            expert_recon.append(recon)
            expert_projected.append(in_expert)

        expert_cos = torch.stack(expert_consistency, dim=-1)  # (B, n_experts)
        expert_mse = torch.stack(expert_recon, dim=-1)  # (B, n_experts)

        # ── Cross-expert differentiation ──
        # How each expert's projection relates to every other expert's projection
        # This IS the disagreement structure. Preserve it exactly.
        cross_cos = []
        for i in range(self.n_experts):
            for j in range(i + 1, self.n_experts):
                cc = F.cosine_similarity(
                    expert_projected[i], expert_projected[j], dim=-1)
                cross_cos.append(cc)
        cross_features = torch.stack(cross_cos, dim=-1) if cross_cos else torch.zeros(B, 0, device=emb.device)

        # Per-sample disagreement: how much do experts disagree on THIS embedding?
        # High disagreement = embedding is in contested territory
        # Low disagreement = all experts agree (well-anchored)
        per_sample_agreement = expert_cos.mean(dim=-1)  # (B,) mean round-trip cos
        per_sample_disagreement = expert_cos.std(dim=-1)  # (B,) std across experts
        # Ratio: how much agreement relative to disagreement
        disagreement_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8)  # (B,)

        # Expert projection norms before normalization (captures magnitude structure)
        # NOTE(review): recomputes centered/whitened per expert — duplicates
        # the first loop's work; candidate for hoisting later.
        expert_norms = []
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]
            centered = emb - mu
            whitened = centered @ W
            expert_norms.append(whitened.norm(dim=-1))  # (B,)
        expert_norm_features = torch.stack(expert_norms, dim=-1)  # (B, n_experts)
        norm_ratio = expert_norm_features / (expert_norm_features.mean(dim=-1, keepdim=True) + 1e-8)

        # ── Anchor distances ──
        anchor_cos = emb @ F.normalize(self.anchors, dim=-1).T if False else None
        anchors_n = F.normalize(self.anchors, dim=-1)
        anchor_cos = emb @ anchors_n.T  # (B, n_anchors)

        # ── Geometric context ──
        # Full feature set: expert consistency + reconstruction + cross-expert +
        # disagreement ratio + norm ratios + anchor distances
        geo_input = torch.cat([
            expert_cos,  # (B, n_experts)
            expert_mse,  # (B, n_experts)
            cross_features,  # (B, n_cross)
            disagreement_ratio.unsqueeze(-1),  # (B, 1)
            norm_ratio,  # (B, n_experts)
            anchor_cos,  # (B, n_anchors)
        ], dim=-1)
        geo_context = self.geo_proj(geo_input)

        enriched = torch.cat([embedding, geo_context], dim=-1)

        # ── Losses + Diagnostics ──
        aux = {}

        # 1. Expert agreement: all experts should see embedding equally
        expert_mean = expert_cos.mean(dim=-1, keepdim=True)
        aux["expert_agreement"] = (expert_cos - expert_mean).pow(2).mean()

        # 2. Rotation orthogonality (keeps the round-trip identity valid)
        ortho_loss = 0.0
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            RRT = R @ R.T
            ortho_loss += (RRT - torch.eye(self.d_embed, device=R.device)).pow(2).mean()
        aux["rotation_ortho"] = ortho_loss / self.n_experts

        # 3. Anchor spread (penalize anchors collapsing onto each other)
        anchor_sim = anchors_n @ anchors_n.T
        anchor_sim.fill_diagonal_(0)
        aux["anchor_spread"] = anchor_sim.pow(2).mean()

        # 4. Anchor sharpness (low softmax entropy = each sample near few anchors)
        anchor_probs = F.softmax(anchor_cos * 10, dim=-1)
        entropy = -(anchor_probs * (anchor_probs + 1e-12).log()).sum(-1).mean()
        aux["anchor_entropy"] = entropy

        # 5. Cross-expert differentiation consistency
        if cross_features.shape[1] > 0:
            aux["cross_expert_var"] = cross_features.var(dim=0).mean()
        else:
            aux["cross_expert_var"] = torch.tensor(0.0, device=emb.device)

        # 6. Disagreement preservation
        # The distribution of disagreement should stay at the measured target
        batch_cross_mean = cross_features.mean() if cross_features.shape[1] > 0 else torch.tensor(0.0, device=emb.device)
        batch_cross_std = cross_features.std() if cross_features.shape[1] > 0 else torch.tensor(0.0, device=emb.device)
        batch_disagree_ratio = disagreement_ratio.mean()
        aux["disagree_preserve"] = (
            (batch_cross_mean - self.target_cross_cos_mean).pow(2) +
            (batch_cross_std - self.target_cross_cos_std).pow(2) +
            (batch_disagree_ratio - self.target_disagreement_ratio).pow(2)
        )

        # 7. Bank CV: coefficient of variation of random 5-point simplex
        # volumes (Cayley–Menger) in geo-context space; 32 Monte-Carlo draws.
        if B >= 10:
            ctx_n = F.normalize(geo_context, dim=-1)
            vols = []
            for _ in range(32):
                idx = torch.randperm(B, device=embedding.device)[:5]
                pts = ctx_n[idx].unsqueeze(0)
                diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
                d2 = (diff * diff).sum(-1)
                Bv, V, _ = d2.shape
                cm = torch.zeros(Bv, V+1, V+1, device=d2.device, dtype=torch.float32)
                cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
                s = (-1.0)**V; f = math.factorial(V-1)
                v2 = s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)
                vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
            stacked = torch.stack(vols)
            bank_cv = stacked.std() / (stacked.mean() + 1e-8)
            aux["bank_cv"] = bank_cv
        else:
            aux["bank_cv"] = torch.tensor(0.0, device=embedding.device)

        # 8. Emb CV: same statistic on the (normalized) input embeddings.
        if B >= 10:
            emb_n = F.normalize(emb, dim=-1)
            vols = []
            for _ in range(32):
                idx = torch.randperm(B, device=embedding.device)[:5]
                pts = emb_n[idx].unsqueeze(0)
                diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
                d2 = (diff * diff).sum(-1)
                Bv, V, _ = d2.shape
                cm = torch.zeros(Bv, V+1, V+1, device=d2.device, dtype=torch.float32)
                cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
                s = (-1.0)**V; f = math.factorial(V-1)
                v2 = s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)
                vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
            stacked = torch.stack(vols)
            emb_cv = stacked.std() / (stacked.mean() + 1e-8)
            aux["emb_cv"] = emb_cv
        else:
            aux["emb_cv"] = torch.tensor(0.0, device=embedding.device)

        # Diagnostics (plain floats — not part of the loss graph)
        aux["expert_cos_mean"] = expert_cos.mean().item()
        aux["expert_cos_std"] = expert_cos.std().item()
        aux["anchor_max_cos"] = anchor_cos.max(dim=-1).values.mean().item()
        aux["anchor_mean_cos"] = anchor_cos.mean().item()
        if cross_features.shape[1] > 0:
            aux["cross_expert_cos"] = cross_features.mean().item()
            aux["cross_expert_cos_std"] = cross_features.std().item()
        aux["disagreement_ratio"] = disagreement_ratio.mean().item()
        aux["norm_ratio_spread"] = norm_ratio.std(dim=-1).mean().item()

        return enriched, aux

    def bank_loss(self, aux):
        """All targets from measured consensus. Preserves disagreement structure.

        Args:
            aux: the dict returned by ``forward`` (uses the tensor-valued
                entries only).

        Returns:
            Scalar tensor: fixed-weight sum of the alignment losses.
        """
        loss = (
            1.0 * aux["expert_agreement"] +
            1.0 * aux["rotation_ortho"] +
            0.5 * aux["anchor_spread"] +
            0.1 * aux["anchor_entropy"] +
            0.3 * aux["cross_expert_var"] +
            0.3 * (aux["bank_cv"] - self.target_cv).abs() +
            0.3 * (aux["emb_cv"] - self.target_cv).abs() +
            0.5 * aux["disagree_preserve"]  # preserve the disagreement distribution
        )
        return loss

    @torch.no_grad()
    def calibrate_disagreement(self, embeddings):
        """
        Measure the initial disagreement structure from per-sample distribution.
        Uses the full batch to capture the spread, not just the mean.

        Fills the target_cross_cos_mean/std and target_disagreement_ratio
        buffers in place; call once before training the bank.
        """
        B = embeddings.shape[0]
        emb = embeddings.float()

        # Compute per-sample disagreement directly
        # (same center → whiten → normalize → rotate chain as forward()).
        per_sample_expert_cos = []
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]
            centered = emb - mu
            whitened = centered @ W
            whitened_n = F.normalize(whitened, dim=-1)
            in_expert = whitened_n @ R.T
            back = in_expert @ R
            cos = F.cosine_similarity(whitened_n, back, dim=-1)
            per_sample_expert_cos.append(cos)

        expert_cos = torch.stack(per_sample_expert_cos, dim=-1)  # (B, n_experts)
        per_sample_agreement = expert_cos.mean(dim=-1)
        per_sample_disagreement = expert_cos.std(dim=-1)
        per_sample_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8)

        # Cross-expert cosines
        cross_vals = []
        expert_projected = []
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]
            centered = emb - mu
            whitened = centered @ W
            whitened_n = F.normalize(whitened, dim=-1)
            expert_projected.append(whitened_n @ R.T)

        for i in range(self.n_experts):
            for j in range(i + 1, self.n_experts):
                cc = F.cosine_similarity(expert_projected[i], expert_projected[j], dim=-1)
                cross_vals.append(cc)

        if cross_vals:
            cross_all = torch.stack(cross_vals, dim=-1)
            self.target_cross_cos_mean.fill_(cross_all.mean().item())
            self.target_cross_cos_std.fill_(cross_all.std().item())

        # Use MEDIAN of per-sample ratio (robust to outliers)
        self.target_disagreement_ratio.fill_(per_sample_ratio.median().item())

        print(f" Calibrated disagreement (n={B}):")
        print(f" cross_cos: {self.target_cross_cos_mean.item():.4f} ± {self.target_cross_cos_std.item():.4f}")
        print(f" disagree_ratio: median={self.target_disagreement_ratio.item():.6f} "
              f"mean={per_sample_ratio.mean().item():.6f} "
              f"std={per_sample_ratio.std().item():.6f}")
        print(f" expert_cos: {expert_cos.mean().item():.4f} ± {expert_cos.std().item():.4f}")
+
422
+
423
+ # ══════════════════════════════════════════════════════════════════
424
+ # GEOMETRY
425
+ # ══════════════════════════════════════════════════════════════════
426
+
427
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE over matched rows of `a` and `b`.

    Rows are L2-normalized; row i of `a` is the positive for row i of `b`
    and vice versa. Returns (loss, accuracy) where accuracy is the fraction
    of rows whose best match is their own pair (computed without grad).
    """
    a_n = F.normalize(a, dim=-1)
    b_n = F.normalize(b, dim=-1)
    sim = (a_n @ b_n.T) / temperature
    targets = torch.arange(sim.shape[0], device=sim.device)
    # Average the a→b and b→a cross-entropies for a symmetric objective.
    loss = (F.cross_entropy(sim, targets) + F.cross_entropy(sim.T, targets)) / 2
    with torch.no_grad():
        acc = (sim.argmax(-1) == targets).float().mean().item()
    return loss, acc
+
437
def cayley_menger_vol2(pts):
    """Squared simplex volume via the Cayley–Menger determinant.

    Args:
        pts: (B, V, D) batch of V points each.

    Returns:
        (B,) tensor of squared volumes of the (V-1)-simplex spanned by
        each point set (may be slightly negative from numerical error).
    """
    pts = pts.float()
    # Squared pairwise distances, shape (B, V, V).
    pairwise = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    sq_dist = (pairwise * pairwise).sum(-1)
    batch, n_pts, _ = sq_dist.shape
    # Bordered (V+1)x(V+1) Cayley–Menger matrix: first row/col of ones,
    # zero corner, distance matrix in the lower-right block.
    bordered = torch.zeros(batch, n_pts + 1, n_pts + 1,
                           device=sq_dist.device, dtype=torch.float32)
    bordered[:, 0, 1:] = 1
    bordered[:, 1:, 0] = 1
    bordered[:, 1:, 1:] = sq_dist
    sign = (-1.0) ** n_pts
    fact = math.factorial(n_pts - 1)
    scale = sign / ((2.0 ** (n_pts - 1)) * fact * fact)
    return scale * torch.linalg.det(bordered)
+
447
def cv_loss(emb, target=0.12, n_samples=16):
    """Differentiable |CV - target| over random 5-point simplex volumes.

    Draws `n_samples` random 5-subsets of the batch, computes each subset's
    simplex volume, and penalizes deviation of the volumes' coefficient of
    variation from `target`. Returns 0 when the batch is too small.
    """
    batch = emb.shape[0]
    if batch < 5:
        return torch.tensor(0.0, device=emb.device)
    volumes = []
    for _ in range(n_samples):
        pick = torch.randperm(batch, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        # relu clamps tiny negative determinants; epsilon keeps sqrt smooth.
        volumes.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    vol_t = torch.stack(volumes)
    cv = vol_t.std() / (vol_t.mean() + 1e-8)
    return (cv - target).abs()
+
459
def cv_metric(emb, n=200):
    """Monte-Carlo CV of random 5-point simplex volumes (diagnostic only).

    Like cv_loss but returns a plain float, keeps only strictly positive
    volumes, and returns 0.0 when fewer than 10 usable samples remain.
    """
    batch = emb.shape[0]
    if batch < 5:
        return 0.0
    volumes = []
    for _ in range(n):
        pick = torch.randperm(batch, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        vol = torch.sqrt(F.relu(sq_vol[0]) + 1e-12).item()
        if vol > 0:
            volumes.append(vol)
    if len(volumes) < 10:
        return 0.0
    arr = np.array(volumes)
    return float(arr.std() / (arr.mean() + 1e-8))
+
472
def measure_consensus_stats(consensus_embs, n_check=2000):
    """Measure exact geometric statistics of the consensus manifold.

    Uses at most the first `n_check` embeddings. Returns a dict with the
    simplex-volume CV, mean off-diagonal pairwise cosine, the first 50
    normalized singular values, and the effective dimensionality.
    NOTE(review): assumes `consensus_embs` rows are unit-norm so that the
    Gram matrix equals cosine similarity — confirm against caller.
    """
    embs = consensus_embs[:n_check].float()
    # Coefficient of variation (computed on DEVICE for speed).
    cv = cv_metric(embs.to(DEVICE))
    # Mean pairwise cosine, excluding the diagonal self-similarities.
    gram = embs @ embs.T
    off_diag = ~torch.eye(embs.shape[0], dtype=torch.bool)
    mean_cos = gram[off_diag].mean().item()
    # Normalized spectral profile of the centered embeddings.
    centered = embs - embs.mean(0, keepdim=True)
    svals = torch.linalg.svdvals(centered)
    spectral = (svals / (svals.sum() + 1e-8)).tolist()[:50]
    # Participation-ratio effective dimensionality.
    eff_dim = float((svals.sum() ** 2) / (svals.pow(2).sum() + 1e-12))

    return {
        "cv": cv,
        "mean_cos": mean_cos,
        "spectral": spectral,
        "eff_dim": eff_dim,
    }
+
496
+
497
+ # ══════════════════════════════════════════════════════════════════
498
+ # EXTRACTION + ALIGNMENT
499
+ # ══════════════════════════════════════════════════════════════════
500
+
501
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Inverse matrix square root of a symmetric PSD matrix via eigh.

    Eigenvalues are clamped to `eps` so near-singular covariances stay
    invertible. Returns U diag(lambda^-1/2) U^T.
    """
    eigvals, eigvecs = torch.linalg.eigh(cov)
    clipped = eigvals.clamp(min=eps)
    return eigvecs @ torch.diag(clipped.rsqrt()) @ eigvecs.T
+
506
def procrustes_align(source, target, n_align=5000):
    """Whitened orthogonal Procrustes alignment from source to target space.

    Both point sets are centered, whitened (via the inverse covariance
    square root), row-normalized, and the optimal rotation R mapping the
    whitened source onto the whitened target is found by SVD. Returns the
    rotation plus the transforms needed to replay the alignment, with
    mean row-cosines before and after as quality diagnostics.
    """
    count = min(n_align, source.shape[0], target.shape[0])
    src = source[:count].float()
    tgt = target[:count].float()
    src_mean = src.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt_mean
    n_rows = src_c.shape[0]
    cos_before = F.cosine_similarity(src_c, tgt_c, dim=-1).mean().item()
    # Sample covariances (unbiased; max() guards the single-row case).
    src_cov = (src_c.T @ src_c) / max(n_rows - 1, 1)
    tgt_cov = (tgt_c.T @ tgt_c) / max(n_rows - 1, 1)
    src_whiten = symmetric_inv_sqrt(src_cov)
    tgt_whiten = symmetric_inv_sqrt(tgt_cov)
    src_w = F.normalize(src_c @ src_whiten, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whiten, dim=-1)
    # Orthogonal Procrustes: R maps whitened source onto whitened target.
    U, _, Vt = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    R = U @ Vt
    cos_after = F.cosine_similarity(src_w @ R.T, tgt_w, dim=-1).mean().item()
    return {
        "rotation": R,
        "source_mean": src_mean.squeeze(0),
        "source_whitener": src_whiten,
        "target_unwhitener": torch.linalg.pinv(tgt_whiten),
        "cos_before": cos_before,
        "cos_after": cos_after,
    }
+
528
def apply_align(emb, a):
    """Replay a stored Procrustes alignment: center, whiten, rotate, unwhiten.

    `a` is a dict from procrustes_align(); returns the embeddings mapped
    into the target expert's (unwhitened) space.
    """
    out = emb.float() - a["source_mean"]
    out = out @ a["source_whitener"]
    out = out @ a["rotation"].T
    return out @ a["target_unwhitener"]
+
533
+
534
+ # ══════════════════════════════════════════════════════════════════
535
+ # MAIN
536
+ # ══════════════════════════════════════════════════════════════════
537
+
538
+ def run():
539
+ torch.manual_seed(42)
540
+ np.random.seed(42)
541
+ N_SAMPLES = 20000
542
+ MAX_LEN = 128
543
+ BATCH = 256
544
+
545
+ # ── Phase 0: Extract ──
546
+ print(f"\n{'='*65}")
547
+ print("PHASE 0: EXTRACTION")
548
+ print(f"{'='*65}")
549
+
550
+ from datasets import load_dataset
551
+ from transformers import AutoModel, AutoTokenizer
552
+
553
+ ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
554
+ split="train", streaming=True)
555
+ captions = []
556
+ for row in ds:
557
+ cap = row.get("caption_llava", "")
558
+ if isinstance(cap, str) and len(cap) > 50:
559
+ captions.append(cap)
560
+ if len(captions) >= N_SAMPLES:
561
+ break
562
+ print(f" Captions: {len(captions):,}")
563
+
564
+ embeds = {}
565
+ for model_name, short, max_len in EXPERTS:
566
+ print(f"\n Extracting: {short}...")
567
+ model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
568
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
569
+ all_emb = []
570
+ with torch.no_grad():
571
+ for i in tqdm(range(0, len(captions), 128), desc=f" {short}"):
572
+ batch = captions[i:i+128]
573
+ inputs = tokenizer(batch, max_length=max_len, padding=True,
574
+ truncation=True, return_tensors="pt").to(DEVICE)
575
+ out = model(**inputs)
576
+ m = inputs.attention_mask.unsqueeze(-1).float()
577
+ pooled = (out.last_hidden_state * m).sum(1) / m.sum(1).clamp(min=1)
578
+ all_emb.append(pooled.cpu())
579
+ embeds[short] = torch.cat(all_emb)
580
+ print(f" Shape: {embeds[short].shape}")
581
+ del model; gc.collect(); torch.cuda.empty_cache()
582
+
583
+ # ── Phase 0b: Align + Consensus + Measure ──
584
+ print(f"\n{'='*65}")
585
+ print("PHASE 0b: GENERALIZED PROCRUSTES ALIGNMENT (no reference bias)")
586
+ print(f"{'='*65}")
587
+
588
+ names = [s for _, s, _ in EXPERTS]
589
+
590
+ # Generalized Procrustes: iteratively align all to their mean
591
+ # No expert is the reference. The centerpoint emerges.
592
+ GPA_ITERS = 10
593
+ current = {name: embeds[name].float() for name in names}
594
+
595
+ for gpa_iter in range(GPA_ITERS):
596
+ # Compute mean shape
597
+ mean_shape = sum(current[n] for n in names) / len(names)
598
+
599
+ # Align each to mean
600
+ new_current = {}
601
+ total_delta = 0.0
602
+ for name in names:
603
+ info = procrustes_align(current[name], mean_shape)
604
+ new_current[name] = apply_align(current[name], info)
605
+ # Measure how much this iteration changed things
606
+ delta = (new_current[name] - current[name]).pow(2).mean().item()
607
+ total_delta += delta
608
+
609
+ current = new_current
610
+ if gpa_iter == 0 or (gpa_iter + 1) % 3 == 0 or total_delta < 1e-8:
611
+ print(f" GPA iter {gpa_iter+1}: delta={total_delta:.8f}")
612
+ if total_delta < 1e-8:
613
+ print(f" Converged at iteration {gpa_iter+1}")
614
+ break
615
+
616
+ # Final alignment: align each expert to the converged mean
617
+ mean_shape = sum(current[n] for n in names) / len(names)
618
+ procrustes_results = {}
619
+ aligned = {}
620
+ for name in names:
621
+ info = procrustes_align(embeds[name], mean_shape)
622
+ procrustes_results[name] = info
623
+ aligned[name] = apply_align(embeds[name], info)
624
+ cos = F.cosine_similarity(
625
+ aligned[name][:2000], mean_shape[:2000], dim=-1).mean().item()
626
+ print(f" {name:10s}: cos_after={info['cos_after']:.4f} cos_to_mean={cos:.4f}")
627
+
628
+ # Consensus = normalized centroid (now equidistant from all experts)
629
+ consensus = F.normalize(sum(aligned[n] for n in names) / len(names), dim=-1)
630
+ for name in names:
631
+ cos = F.cosine_similarity(consensus[:2000], aligned[name][:2000], dim=-1).mean().item()
632
+ print(f" cos(consensus, {name}): {cos:.4f}")
633
+
634
+ # Verify equidistance
635
+ expert_cos_to_consensus = []
636
+ for name in names:
637
+ c = F.cosine_similarity(consensus[:2000], aligned[name][:2000], dim=-1).mean().item()
638
+ expert_cos_to_consensus.append(c)
639
+ equidist_range = max(expert_cos_to_consensus) - min(expert_cos_to_consensus)
640
+ print(f" Equidistance range: {equidist_range:.4f} (should be near 0)")
641
+
642
+ # Measure EXACT consensus statistics
643
+ print(f"\n Measuring consensus statistics...")
644
+ consensus_stats = measure_consensus_stats(consensus)
645
+ print(f" CV: {consensus_stats['cv']:.4f}")
646
+ print(f" Mean cos: {consensus_stats['mean_cos']:.4f}")
647
+ print(f" Eff dim: {consensus_stats['eff_dim']:.1f}")
648
+ print(f" Spectral: [{', '.join(f'{s:.4f}' for s in consensus_stats['spectral'][:5])}...]")
649
+
650
+ del embeds, aligned
651
+ gc.collect(); torch.cuda.empty_cache()
652
+
653
+ # ── Phase 1: Train Student ──
654
+ print(f"\n{'='*65}")
655
+ print("PHASE 1: TRAIN STUDENT")
656
+ print(f"{'='*65}")
657
+
658
+ tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
659
+ tokens = tokenizer(captions, max_length=MAX_LEN, padding="max_length",
660
+ truncation=True, return_tensors="pt")
661
+ input_ids = tokens["input_ids"]
662
+ attention_mask = tokens["attention_mask"]
663
+
664
+ n_train = N_SAMPLES - 2000
665
+ train_ids = input_ids[:n_train].to(DEVICE)
666
+ train_mask = attention_mask[:n_train].to(DEVICE)
667
+ train_targets = consensus[:n_train].to(DEVICE)
668
+ val_ids = input_ids[n_train:].to(DEVICE)
669
+ val_mask = attention_mask[n_train:].to(DEVICE)
670
+ val_targets = consensus[n_train:].to(DEVICE)
671
+
672
+ student = MiniStudent(
673
+ vocab_size=tokenizer.vocab_size, max_len=MAX_LEN,
674
+ d_model=256, n_heads=4, n_layers=4, d_ff=1024,
675
+ output_dim=768, dropout=0.1, pad_token_id=tokenizer.pad_token_id
676
+ ).to(DEVICE)
677
+ n_params = sum(p.numel() for p in student.parameters())
678
+ print(f" Student: {n_params:,} params")
679
+ print(f" CV target: {consensus_stats['cv']:.4f}")
680
+
681
+ optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay=0.01)
682
+
683
+ for epoch in range(5):
684
+ student.train()
685
+ perm = torch.randperm(n_train, device=DEVICE)
686
+ t_loss, t_acc, t_cos, n = 0, 0, 0, 0
687
+ t0 = time.time()
688
+ for i in range(0, n_train, BATCH):
689
+ idx = perm[i:i+BATCH]
690
+ if len(idx) < 8: continue
691
+ emb = student(train_ids[idx], train_mask[idx])
692
+ tgt = train_targets[idx]
693
+ l_nce, acc = infonce(emb, tgt)
694
+ l_mse = F.mse_loss(emb, tgt)
695
+ l_cv = cv_loss(emb, target=consensus_stats["cv"])
696
+ loss = l_nce + l_mse + 0.1 * l_cv
697
+ loss.backward()
698
+ torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
699
+ optimizer.step(); optimizer.zero_grad(set_to_none=True)
700
+ with torch.no_grad():
701
+ cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()
702
+ t_loss += loss.item(); t_acc += acc; t_cos += cos; n += 1
703
+ elapsed = time.time() - t0; d = max(n, 1)
704
+ student.eval()
705
+ with torch.no_grad():
706
+ v_emb = student(val_ids, val_mask)
707
+ _, v_acc = infonce(v_emb[:1000], val_targets[:1000])
708
+ v_cos = F.cosine_similarity(v_emb, val_targets, dim=-1).mean().item()
709
+ v_cv = cv_metric(v_emb[:1000])
710
+ print(f" E{epoch+1}: {elapsed:.0f}s loss={t_loss/d:.4f} "
711
+ f"t_acc={t_acc/d:.3f} t_cos={t_cos/d:.3f} "
712
+ f"v_acc={v_acc:.3f} v_cos={v_cos:.3f} v_cv={v_cv:.3f}")
713
+
714
+ torch.save(student.state_dict(), "mini_student.pt")
715
+ print(f"\n Student saved. v_cos={v_cos:.3f}, v_cv={v_cv:.3f}")
716
+
717
+ # ── Phase 2: Train Alignment Bank ──
718
+ print(f"\n{'='*65}")
719
+ print("PHASE 2: TRAIN ALIGNMENT BANK (student frozen)")
720
+ print(f"{'='*65}")
721
+
722
+ student.eval()
723
+ for p in student.parameters():
724
+ p.requires_grad = False
725
+
726
+ print(" Pre-encoding through frozen student...")
727
+ with torch.no_grad():
728
+ all_embs = []
729
+ for i in range(0, n_train, 512):
730
+ j = min(i + 512, n_train)
731
+ emb = student(train_ids[i:j], train_mask[i:j])
732
+ all_embs.append(emb)
733
+ student_embs = torch.cat(all_embs)
734
+ val_student_embs = student(val_ids, val_mask)
735
+ print(f" Student embeddings: {student_embs.shape}")
736
+
737
+ bank = AlignmentBank(
738
+ d_embed=768, n_experts=len(EXPERTS),
739
+ n_anchors=512, d_bank=128
740
+ ).to(DEVICE)
741
+
742
+ bank.init_from_procrustes(procrustes_results, names,
743
+ consensus[:n_train], consensus_stats)
744
+ bank_params = sum(p.numel() for p in bank.parameters())
745
+ print(f" Bank: {bank_params:,} params")
746
+ print(f" Bank targets: CV={bank.target_cv.item():.4f}, "
747
+ f"mean_cos={bank.target_mean_cos.item():.4f}")
748
+
749
+ # Calibrate disagreement from initial state (before any training)
750
+ bank.calibrate_disagreement(student_embs[:2000])
751
+
752
+ bank_opt = torch.optim.AdamW(bank.parameters(), lr=1e-3, weight_decay=0.01)
753
+ BANK_EPOCHS = 20
754
+ BANK_BATCH = 256
755
+
756
+ for epoch in range(BANK_EPOCHS):
757
+ bank.train()
758
+ perm = torch.randperm(n_train, device=DEVICE)
759
+ total_loss = 0
760
+ stats = {"expert_agreement": 0, "rotation_ortho": 0,
761
+ "anchor_spread": 0, "bank_cv": 0, "emb_cv": 0,
762
+ "cross_expert_var": 0, "disagree_preserve": 0}
763
+ n = 0
764
+ t0 = time.time()
765
+ for i in range(0, n_train, BANK_BATCH):
766
+ idx = perm[i:i+BANK_BATCH]
767
+ if len(idx) < 16: continue
768
+ emb = student_embs[idx]
769
+ enriched, aux = bank(emb)
770
+ loss = bank.bank_loss(aux)
771
+ loss.backward()
772
+ torch.nn.utils.clip_grad_norm_(bank.parameters(), 1.0)
773
+ bank_opt.step(); bank_opt.zero_grad(set_to_none=True)
774
+ total_loss += loss.item()
775
+ for k in stats:
776
+ if k in aux:
777
+ v = aux[k]
778
+ stats[k] += v.item() if torch.is_tensor(v) else v
779
+ n += 1
780
+ elapsed = time.time() - t0; d = max(n, 1)
781
+
782
+ bank.eval()
783
+ with torch.no_grad():
784
+ v_enriched, v_aux = bank(val_student_embs)
785
+ v_loss = bank.bank_loss(v_aux).item()
786
+
787
+ print(f"\n E{epoch+1:2d}: {elapsed:.0f}s loss={total_loss/d:.4f} v_loss={v_loss:.4f}")
788
+ print(f" Geometry: b_cv={stats['bank_cv']/d:.4f} e_cv={stats['emb_cv']/d:.4f} "
789
+ f"spread={stats['anchor_spread']/d:.5f} a_max={v_aux['anchor_max_cos']:.3f}")
790
+ print(f" Experts: cos={v_aux['expert_cos_mean']:.3f}±{v_aux['expert_cos_std']:.3f} "
791
+ f"agr={stats['expert_agreement']/d:.6f} ortho={stats['rotation_ortho']/d:.6f}")
792
+ print(f" Disagree: x_cos={v_aux.get('cross_expert_cos', 0):.4f}±{v_aux.get('cross_expert_cos_std', 0):.4f} "
793
+ f"ratio={v_aux['disagreement_ratio']:.6f} "
794
+ f"preserve={stats['disagree_preserve']/d:.6f} "
795
+ f"norms={v_aux['norm_ratio_spread']:.4f}")
796
+
797
+ torch.save(bank.state_dict(), "alignment_bank.pt")
798
+
799
+ # ── Phase 3: Geometric Verification ──
800
+ print(f"\n{'='*65}")
801
+ print("PHASE 3: GEOMETRIC VERIFICATION")
802
+ print(f"{'='*65}")
803
+
804
+ bank.eval()
805
+ with torch.no_grad():
806
+ enriched_val, v_aux = bank(val_student_embs)
807
+ original_768 = enriched_val[:, :768]
808
+ geo_context = enriched_val[:, 768:]
809
+
810
+ passthrough_cos = F.cosine_similarity(
811
+ original_768[:100], val_student_embs[:100], dim=-1).mean().item()
812
+ geo_cv = cv_metric(F.normalize(geo_context[:1000], dim=-1))
813
+ S = torch.linalg.svdvals(
814
+ geo_context[:1000].float() - geo_context[:1000].float().mean(0))
815
+ geo_eff_dim = float((S.sum() ** 2) / (S.pow(2).sum() + 1e-12))
816
+
817
+ # Verify consensus stats are preserved
818
+ emb_cv = cv_metric(val_student_embs[:1000])
819
+
820
+ print(f" Passthrough: {passthrough_cos:.6f} (target: 1.000)")
821
+ print(f" Emb CV: {emb_cv:.4f} (consensus: {consensus_stats['cv']:.4f})")
822
+ print(f" Geo context CV: {geo_cv:.4f}")
823
+ print(f" Geo eff_dim: {geo_eff_dim:.1f} / {bank.d_bank}")
824
+ print(f" Expert cos: {v_aux['expert_cos_mean']:.3f} ± {v_aux['expert_cos_std']:.3f}")
825
+ print(f" Anchor max cos: {v_aux['anchor_max_cos']:.3f}")
826
+ print(f" Disagreement:")
827
+ print(f" Cross-expert: {v_aux.get('cross_expert_cos', 0):.4f} ± {v_aux.get('cross_expert_cos_std', 0):.4f}")
828
+ print(f" Ratio: {v_aux['disagreement_ratio']:.6f} (target: {bank.target_disagreement_ratio.item():.6f})")
829
+ print(f" Norm spread: {v_aux['norm_ratio_spread']:.4f}")
830
+
831
+ # ── Phase 4: Classifier Stability Test ──
832
+ print(f"\n{'='*65}")
833
+ print("PHASE 4: CLASSIFIER STABILITY TEST")
834
+ print(f"{'='*65}")
835
+
836
+ with torch.no_grad():
837
+ embs = val_student_embs[:1000]
838
+ sim = embs @ embs.T
839
+ sim.fill_diagonal_(-1)
840
+ n_pairs = 3000
841
+ idx_a = torch.randint(0, 1000, (n_pairs,))
842
+ idx_b = torch.randint(0, 1000, (n_pairs,))
843
+ pair_cos = sim[idx_a, idx_b]
844
+ sorted_cos, _ = pair_cos.sort()
845
+ t1 = sorted_cos[n_pairs // 3].item()
846
+ t2 = sorted_cos[2 * n_pairs // 3].item()
847
+ labels = torch.zeros(n_pairs, dtype=torch.long, device=DEVICE)
848
+ labels[pair_cos > t2] = 0
849
+ labels[(pair_cos <= t2) & (pair_cos > t1)] = 1
850
+ labels[pair_cos <= t1] = 2
851
+
852
+ enriched_a, aux_a = bank(embs[idx_a])
853
+ enriched_b, aux_b = bank(embs[idx_b])
854
+
855
+ # Build explicit geometric features per pair
856
+ # These are interpretable and hard to overfit
857
+ a_emb = embs[idx_a]; b_emb = embs[idx_b]
858
+ a_geo = enriched_a[:, 768:]; b_geo = enriched_b[:, 768:]
859
+
860
+ geo_explicit = torch.cat([
861
+ # Pair-level
862
+ F.cosine_similarity(a_emb, b_emb, dim=-1).unsqueeze(-1), # raw cosine
863
+ (a_emb - b_emb).pow(2).mean(dim=-1).unsqueeze(-1), # MSE
864
+ F.cosine_similarity(a_geo, b_geo, dim=-1).unsqueeze(-1), # geo context cosine
865
+ (a_geo - b_geo).pow(2).mean(dim=-1).unsqueeze(-1), # geo context MSE
866
+ # Per-sample bank diagnostics (already computed in forward)
867
+ torch.abs(a_emb - b_emb).mean(dim=-1).unsqueeze(-1), # L1 distance
868
+ (a_emb * b_emb).sum(dim=-1).unsqueeze(-1), # dot product
869
+ ], dim=-1) # (n_pairs, 6)
870
+
871
+ modes = {
872
+ "raw_768": torch.cat([a_emb, b_emb], dim=-1),
873
+ "raw+diff": torch.cat([a_emb, b_emb, torch.abs(a_emb - b_emb), a_emb * b_emb], dim=-1),
874
+ "bank_enriched": torch.cat([enriched_a, enriched_b], dim=-1),
875
+ "bank+diff": torch.cat([enriched_a, enriched_b,
876
+ torch.abs(enriched_a - enriched_b),
877
+ enriched_a * enriched_b], dim=-1),
878
+ "geo_explicit": geo_explicit,
879
+ }
880
+
881
+ print(f"\n {'Mode':<20} {'Dim':>6} {'Train':>7} {'Val':>7} {'Gap':>7}")
882
+ print(f" {'-'*50}")
883
+
884
+ for mode_name, features in modes.items():
885
+ feat_dim = features.shape[1]
886
+ clf = nn.Sequential(
887
+ nn.Linear(feat_dim, min(256, feat_dim)), nn.GELU(), nn.LayerNorm(min(256, feat_dim)),
888
+ nn.Dropout(0.1),
889
+ nn.Linear(min(256, feat_dim), 3)
890
+ ).to(DEVICE)
891
+ clf_opt = torch.optim.Adam(clf.parameters(), lr=1e-3)
892
+ n_clf_train = 2400
893
+ train_f = features[:n_clf_train].detach()
894
+ train_l = labels[:n_clf_train]
895
+ val_f = features[n_clf_train:].detach()
896
+ val_l = labels[n_clf_train:]
897
+ for e in range(30):
898
+ clf.train()
899
+ logits = clf(train_f)
900
+ loss = F.cross_entropy(logits, train_l)
901
+ loss.backward(); clf_opt.step(); clf_opt.zero_grad()
902
+ clf.eval()
903
+ with torch.no_grad():
904
+ v_acc = (clf(val_f).argmax(-1) == val_l).float().mean().item()
905
+ t_acc = (clf(train_f).argmax(-1) == train_l).float().mean().item()
906
+ print(f" {mode_name:<20} {feat_dim:>6} {t_acc:>7.3f} {v_acc:>7.3f} {t_acc-v_acc:>7.3f}")
907
+
908
+ print(f"\n{'='*65}")
909
+ print("SUMMARY")
910
+ print(f"{'='*65}")
911
+ print(f" Consensus CV: {consensus_stats['cv']:.4f}")
912
+ print(f" Consensus eff_dim:{consensus_stats['eff_dim']:.1f}")
913
+ print(f" Student v_cos: {v_cos:.3f}")
914
+ print(f" Student v_cv: {v_cv:.3f}")
915
+ print(f" Bank params: {bank_params:,}")
916
+ print(f" Bank geo_eff_dim: {geo_eff_dim:.1f}")
917
+ print(f" Bank geo_cv: {geo_cv:.4f}")
918
+ print(f"\n{'='*65}")
919
+ print("DONE")
920
+ print(f"{'='*65}")
921
+
922
+
923
# Script entry point: execute the full multi-phase pipeline
# (student distillation, alignment-bank training, geometric
# verification, and the classifier stability test) defined in run().
if __name__ == "__main__":
    run()