AbstractPhil committed on
Commit
162ef56
Β·
verified Β·
1 Parent(s): 1c6fbd2

bug fixed, queue didn't proc requires retrain

Browse files
Files changed (1) hide show
  1. trainer.py +427 -1053
trainer.py CHANGED
@@ -1,1083 +1,457 @@
1
  #!/usr/bin/env python3
2
  """
3
- GeoLIP Dual-Stream ViT β€” Full Bidirectional, Decoupled Gradients
4
- ==================================================================
5
- Two parallel streams that cross-attend at EVERY layer:
6
- Stream A (geometric): KSimplexChannel β†’ geometric features β†’ self-attn
7
- Stream B (standard): learned projections β†’ self-attn
8
-
9
- Architecture (two gradient paths):
10
-
11
- GEOMETRIC PATH (InfoNCE + CV + CM shape dual blocks):
12
- patch_embed β†’ split β†’ geo_stream, std_stream
13
- β†’ NΓ— DualStreamBlock (self-attn + cross-attn + KSimplex)
14
- β†’ pool BOTH β†’ geo_emb, std_emb, emb on S^d
15
- β†’ InfoNCE, CV loss, CM validity, mastery, autograd
16
-
17
- CLASSIFICATION PATH (BCE shapes cross blocks + classifier):
18
- dual block outputs [DETACHED β€” gradient wall]
19
- β†’ NΓ— CrossBlock (bidirectional cross-attn)
20
- β†’ pool BOTH β†’ class projections β†’ S^d
21
- β†’ constellation + patchwork + classifier β†’ BCE
22
-
23
- The dual blocks form geometry shaped ONLY by contrastive + geometric forces.
24
- The cross blocks learn to READ the geometry for classification.
25
- BCE cannot corrupt the geometric formation.
26
  """
27
 
28
  import torch
29
  import torch.nn as nn
30
  import torch.nn.functional as F
31
- import math
32
- from itertools import combinations
 
 
 
33
 
34
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  # ══════════════════════════════════════════════════════════════════
38
- # CAYLEY-MENGER + KSIMPLEX CHANNEL
39
  # ══════════════════════════════════════════════════════════════════
40
 
41
class CMValidator(nn.Module):
    """Vectorized Cayley-Menger determinant over batches of simplices.

    Given vertex coordinates of a k-simplex, produces the squared
    pairwise edge lengths and the squared k-volume computed from the
    Cayley-Menger determinant.
    """

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1
        idx_pairs = list(combinations(range(self._nv), 2))
        self._npairs = len(idx_pairs)
        self.register_buffer(
            '_pi', torch.tensor([a for a, _ in idx_pairs], dtype=torch.long))
        self.register_buffer(
            '_pj', torch.tensor([b for _, b in idx_pairs], dtype=torch.long))
        # vol^2 = (-1)^(k+1) / (2^k (k!)^2) * det(CM)
        self._prefactor = ((-1.0) ** (k + 1)) / (
            (2.0 ** k) * (math.factorial(k) ** 2))

    def forward(self, verts):
        """verts: (..., k+1, e) -> (d2_pairs: (..., npairs), vol2: (...))."""
        gram = torch.einsum('...ve,...we->...vw', verts, verts)
        sq_norms = torch.diagonal(gram, dim1=-2, dim2=-1)
        # Squared distances via ||a||^2 + ||b||^2 - 2 a.b, clamped at zero.
        d2_mat = F.relu(sq_norms.unsqueeze(-1) + sq_norms.unsqueeze(-2) - 2 * gram)
        d2_pairs = d2_mat[..., self._pi, self._pj]
        batch_shape = d2_mat.shape[:-2]
        nv = d2_mat.shape[-1]
        cm = torch.zeros(*batch_shape, nv + 1, nv + 1,
                         device=d2_mat.device, dtype=d2_mat.dtype)
        cm[..., 0, 1:] = 1.0
        cm[..., 1:, 0] = 1.0
        cm[..., 1:, 1:] = d2_mat
        # Determinant in float32 for numerical stability, cast back after.
        vol2 = (self._prefactor * torch.linalg.det(cm.float())).to(d2_pairs.dtype)
        return d2_pairs, vol2
70
-
71
-
72
class KSimplexChannel(nn.Module):
    """Per-position simplex encoder (k=4 yields 11 geometric features).

    Each input vector predicts a small deformation of a fixed regular
    simplex template; the deformed simplex is summarized by its squared
    edge lengths and squared volume via the Cayley-Menger determinant.
    """

    BASE_DEFORM = 0.05  # scale of the learned deformation around the template

    def __init__(self, k, in_dim, edim):
        super().__init__()
        self._k = k
        self._nv = k + 1
        self._edim = edim
        self._cm = CMValidator(k)
        # npairs squared distances plus one squared volume.
        self._out_dim = self._cm._npairs + 1
        self.register_buffer('_template', self._make_regular_simplex(k, edim))
        self._to_deform = nn.Linear(in_dim, self._nv * edim)
        self._norm = nn.LayerNorm(self._out_dim)

    @staticmethod
    def _make_regular_simplex(k, edim):
        """Zero-mean, unit-edge simplex template with k+1 vertices in R^edim."""
        nv = k + 1
        verts = torch.zeros(nv, edim)
        for i in range(min(nv, edim)):
            verts[i, i] = 1.0
        # Extra vertices (nv > edim): random unit directions.
        for i in range(edim, nv):
            rnd = torch.randn(edim)
            verts[i] = rnd / (rnd.norm() + 1e-8)
        verts = verts - verts.mean(dim=0, keepdim=True)
        return verts / (verts[0] - verts[1]).norm().clamp(min=1e-8)

    @property
    def out_dim(self):
        return self._out_dim

    def forward(self, x):
        """x: (..., in_dim) -> (geo: (..., out_dim), vol2: (...))."""
        offsets = self._to_deform(x).unflatten(-1, (self._nv, self._edim))
        verts = self._template + self.BASE_DEFORM * offsets
        d2, vol2 = self._cm(verts)
        geo = self._norm(torch.cat([d2, vol2.unsqueeze(-1)], dim=-1))
        return geo, vol2
114
-
115
-
116
- # ══════════════════════════════════════════════════════════════════
117
- # CONSTELLATION + PATCHWORK
118
- # ══════════════════════════════════════════════════════════════════
119
-
120
class Constellation(nn.Module):
    """Learned anchor set on the unit sphere; triangulates embeddings."""

    def __init__(self, n_anchors, dim, anchor_drop=0.0):
        super().__init__()
        self.anchors = nn.Parameter(torch.randn(n_anchors, dim))
        nn.init.normal_(self.anchors, 0, 1.0 / dim ** 0.5)
        self.anchor_drop = anchor_drop

    def triangulate(self, emb, training=False):
        """Return (1 - cosine to each anchor, index of nearest anchor).

        During training with anchor_drop > 0, a random subset of anchors
        is dropped (at least two are always kept); the nearest index is
        mapped back to the full anchor numbering.
        """
        anchors = F.normalize(self.anchors, dim=-1)
        drop_active = training and self.anchor_drop > 0
        if not drop_active:
            cos = emb @ anchors.T
            _, nearest = cos.max(dim=-1)
            return 1.0 - cos, nearest
        keep = torch.rand(anchors.shape[0], device=anchors.device) > self.anchor_drop
        if keep.sum() < 2:
            keep[:2] = True
        kept = anchors[keep]
        cos = emb @ kept.T
        _, local = cos.max(dim=-1)
        # Translate the index among kept anchors to the full index space.
        nearest = keep.nonzero(as_tuple=True)[0][local]
        return 1.0 - cos, nearest
144
-
145
-
146
class Patchwork(nn.Module):
    """Routes anchor triangulations through n_comp small MLP components."""

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        self.d_comp = d_comp
        # Round-robin anchor -> component assignment.
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)
        anchors_per = n_anchors // n_comp

        def _component():
            return nn.Sequential(
                nn.Linear(anchors_per, d_comp * 2), nn.GELU(),
                nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))

        self.comps = nn.ModuleList(_component() for _ in range(n_comp))

    def forward(self, tri):
        """tri: (B, n_anchors) -> concatenated component outputs (B, n_comp*d_comp)."""
        pieces = [comp(tri[:, self.asgn == k])
                  for k, comp in enumerate(self.comps)]
        return torch.cat(pieces, -1)
162
-
163
 
164
  # ══════════════════════════════════════════════════════════════════
165
- # EMBEDDING AUTOGRAD
166
  # ══════════════════════════════════════════════════════════════════
167
 
168
class EmbeddingAutograd(torch.autograd.Function):
    """Geometric gradient shaping: tangential projection + anchor separation.

    Forward is the identity on ``x``. Backward rewrites the incoming
    gradient using the (detached) embedding and anchors:
      * the radial component along the embedding is scaled by (1 - tang),
      * when sep > 0, any component pushing toward the nearest anchor is
        damped by sep.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang
        ctx.sep = sep
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        unit_emb = F.normalize(embedding.detach().float(), dim=-1)
        unit_anc = F.normalize(anchors.detach().float(), dim=-1)
        g = grad_output.float()
        # Decompose the gradient into radial (along embedding) + tangential.
        radial = (g * unit_emb).sum(-1, keepdim=True) * unit_emb
        shaped = (g - radial) + (1.0 - ctx.tang) * radial
        if ctx.sep > 0:
            nearest = unit_anc[(unit_emb @ unit_anc.T).argmax(dim=-1)]
            toward = (shaped * nearest).sum(-1, keepdim=True)
            # Only damp the component actually pointing at the anchor.
            shaped = shaped - ctx.sep * (toward > 0).float() * toward * nearest
        return shaped.to(grad_output.dtype), None, None, None, None
190
-
191
-
192
class DisagreementCache(nn.Module):
    """Rolling cross-batch embedding cache for hard-negative mining.

    Stores recent (embedding, label) pairs. Once standard InfoNCE
    saturates (acc = 1.0), the "second guess" clause activates:
    ``compute_second_guess`` mines the closest cached embeddings of
    *other* classes as hard negatives and the closest cached embedding
    of the *same* class as the positive, producing a sharper contrastive
    loss that refines the boundary regions.
    """

    def __init__(self, dim, max_size=4096):
        super().__init__()
        self.dim = dim
        self.max_size = max_size
        self.register_buffer('emb_cache', torch.zeros(0, dim))
        self.register_buffer('label_cache', torch.zeros(0, dtype=torch.long))
        self.active = False  # flipped on when nce_acc hits 1.0

    @torch.no_grad()
    def update(self, emb, labels):
        """Append a batch, keeping only the newest ``max_size`` entries."""
        self.emb_cache = torch.cat(
            [self.emb_cache, emb.detach()], dim=0)[-self.max_size:]
        self.label_cache = torch.cat(
            [self.label_cache, labels.detach()], dim=0)[-self.max_size:]

    def compute_second_guess(self, emb, labels, temp=0.04, n_hard=16):
        """Hard-negative InfoNCE against the cache.

        For each current embedding: the positive is the most similar
        cached same-class entry, the negatives are the ``n_hard`` most
        similar cached different-class entries. ``temp`` is lower than
        the standard InfoNCE temperature for sharper discrimination.

        Returns (loss, stats); (0, {}) when the cache is below the
        activation size or too few samples have a cached positive.
        """
        if self.emb_cache.shape[0] < 256:
            return torch.tensor(0.0, device=emb.device), {}

        sim = emb @ self.emb_cache.T  # (B, cache_size)
        same_mask = labels.unsqueeze(1) == self.label_cache.unsqueeze(0)

        # Hardest negatives: highest-similarity different-class entries.
        neg_sim = sim.clone()
        neg_sim[same_mask] = -2.0
        hard_neg_vals, _ = neg_sim.topk(n_hard, dim=1)  # (B, n_hard)

        # Positive: highest-similarity same-class entry.
        pos_sim = sim.clone()
        pos_sim[~same_mask] = -2.0
        pos_vals, _ = pos_sim.max(dim=1, keepdim=True)  # (B, 1)

        # Drop samples with no same-class entry cached.
        has_pos = same_mask.any(dim=1)
        if not has_pos.all():
            if has_pos.sum() < 2:
                return torch.tensor(0.0, device=emb.device), {}
            pos_vals = pos_vals[has_pos]
            hard_neg_vals = hard_neg_vals[has_pos]

        # InfoNCE over (1 + n_hard) logits; column 0 is the positive.
        logits = torch.cat([pos_vals, hard_neg_vals], dim=1) / temp
        target = torch.zeros(logits.shape[0], dtype=torch.long,
                             device=logits.device)
        l_second = F.cross_entropy(logits, target)

        # How hard are the negatives?
        margin = pos_vals.squeeze(-1) - hard_neg_vals[:, 0]
        stats = {
            'second_acc': (logits.argmax(1) == 0).float().mean().item(),
            'margin_mean': margin.mean().item(),
            'margin_min': margin.min().item(),
            'hardest_neg_cos': hard_neg_vals[:, 0].mean().item(),
        }
        return l_second, stats
284
-
285
 
286
  # ══════════════════════════════════════════════════════════════════
287
- # DUAL-STREAM BLOCKS
288
  # ══════════════════════════════════════════════════════════════════
289
 
290
class DualStreamBlock(nn.Module):
    """Two parallel transformer streams coupled through cross-attention.

    Geo stream: self-attn -> per-patch KSimplex features -> cross-attn
    (queries geo, keys/values std) -> FFN.
    Std stream: self-attn -> cross-attn (queries std, keys/values geo) -> FFN.
    Cross-attention is the only bottleneck through which information
    flows between the streams.
    """

    def __init__(self, stream_dim, geo_dim, n_heads, ksimplex_k=4,
                 ksimplex_edim=8, dropout=0.1):
        super().__init__()
        self.stream_dim = stream_dim
        self.geo_dim = geo_dim

        def _mha():
            return nn.MultiheadAttention(
                stream_dim, n_heads, dropout=dropout, batch_first=True)

        def _ffn():
            return nn.Sequential(
                nn.Linear(stream_dim, stream_dim * 4), nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(stream_dim * 4, stream_dim), nn.Dropout(dropout))

        # Geo stream (construction order matches the original layout).
        self.geo_norm1 = nn.LayerNorm(stream_dim)
        self.geo_self_attn = _mha()
        self.geo_ksimplex = KSimplexChannel(
            k=ksimplex_k, in_dim=stream_dim, edim=ksimplex_edim)
        # Lift low-dim geometric features back up to stream width.
        self.geo_lift = nn.Sequential(
            nn.Linear(self.geo_ksimplex.out_dim, stream_dim), nn.GELU())
        self.geo_norm2 = nn.LayerNorm(stream_dim)
        self.geo_cross_attn = _mha()
        self.geo_norm3 = nn.LayerNorm(stream_dim)
        self.geo_ffn = _ffn()

        # Std stream.
        self.std_norm1 = nn.LayerNorm(stream_dim)
        self.std_self_attn = _mha()
        self.std_norm2 = nn.LayerNorm(stream_dim)
        self.std_cross_attn = _mha()
        self.std_norm3 = nn.LayerNorm(stream_dim)
        self.std_ffn = _ffn()

    def forward(self, geo_stream, std_stream):
        """Both inputs (B, P, stream_dim).

        Returns: updated geo_stream, std_stream, plus geo_feats
        (B, P, ksimplex out_dim) and vol2 (B, P).
        """
        B, P, _ = geo_stream.shape

        # Geo: self-attention.
        q = self.geo_norm1(geo_stream)
        attn_out, _ = self.geo_self_attn(q, q, q, need_weights=False)
        geo_stream = geo_stream + attn_out

        # Geo: per-patch KSimplex features, lifted back as a residual.
        geo_feats, vol2 = self.geo_ksimplex(geo_stream.reshape(B * P, -1))
        geo_feats = geo_feats.reshape(B, P, -1)
        vol2 = vol2.reshape(B, P)
        geo_stream = geo_stream + self.geo_lift(geo_feats)

        # Geo: cross-attend to the std stream.
        q = self.geo_norm2(geo_stream)
        ctx = self.std_norm2(std_stream)
        attn_out, _ = self.geo_cross_attn(q, ctx, ctx, need_weights=False)
        geo_stream = geo_stream + attn_out

        # Geo: FFN.
        geo_stream = geo_stream + self.geo_ffn(self.geo_norm3(geo_stream))

        # Std: self-attention.
        q = self.std_norm1(std_stream)
        attn_out, _ = self.std_self_attn(q, q, q, need_weights=False)
        std_stream = std_stream + attn_out

        # Std: cross-attend to the (already updated) geo stream.
        # NOTE(review): geo_norm2 is intentionally reused for the context
        # here, mirroring the original wiring.
        q = self.std_norm2(std_stream)
        ctx = self.geo_norm2(geo_stream)
        attn_out, _ = self.std_cross_attn(q, ctx, ctx, need_weights=False)
        std_stream = std_stream + attn_out

        # Std: FFN.
        std_stream = std_stream + self.std_ffn(self.std_norm3(std_stream))

        return geo_stream, std_stream, geo_feats, vol2
381
-
382
-
383
class CrossBlock(nn.Module):
    """Bidirectional cross-attention block; both streams keep their identity.

    No fusion, no concatenation. Each stream self-attends, then attends
    to the other stream's pre-cross state. The geometric rocks stay rocks.
    """

    def __init__(self, stream_dim, n_heads, dropout=0.1):
        super().__init__()

        def _mha():
            return nn.MultiheadAttention(
                stream_dim, n_heads, dropout=dropout, batch_first=True)

        def _ffn():
            return nn.Sequential(
                nn.Linear(stream_dim, stream_dim * 4), nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(stream_dim * 4, stream_dim), nn.Dropout(dropout))

        # Geo path (construction order matches the original layout).
        self.geo_norm1 = nn.LayerNorm(stream_dim)
        self.geo_self_attn = _mha()
        self.geo_norm2 = nn.LayerNorm(stream_dim)
        self.geo_cross_attn = _mha()
        self.geo_norm3 = nn.LayerNorm(stream_dim)
        self.geo_ffn = _ffn()

        # Std path.
        self.std_norm1 = nn.LayerNorm(stream_dim)
        self.std_self_attn = _mha()
        self.std_norm2 = nn.LayerNorm(stream_dim)
        self.std_cross_attn = _mha()
        self.std_norm3 = nn.LayerNorm(stream_dim)
        self.std_ffn = _ffn()

    def forward(self, geo_stream, std_stream):
        # Self-attention on each stream.
        q = self.geo_norm1(geo_stream)
        out, _ = self.geo_self_attn(q, q, q, need_weights=False)
        geo_stream = geo_stream + out

        q = self.std_norm1(std_stream)
        out, _ = self.std_self_attn(q, q, q, need_weights=False)
        std_stream = std_stream + out

        # Bidirectional cross-attention: each side reads the other's
        # normalized pre-cross state.
        g = self.geo_norm2(geo_stream)
        s = self.std_norm2(std_stream)
        g_to_s, _ = self.geo_cross_attn(g, s, s, need_weights=False)
        s_to_g, _ = self.std_cross_attn(s, g, g, need_weights=False)
        geo_stream = geo_stream + g_to_s
        std_stream = std_stream + s_to_g

        # Position-wise FFNs.
        geo_stream = geo_stream + self.geo_ffn(self.geo_norm3(geo_stream))
        std_stream = std_stream + self.std_ffn(self.std_norm3(std_stream))

        return geo_stream, std_stream
447
-
448
-
449
- # ══════════════════════════════════════════════════════════════════
450
- # DUAL-STREAM VIT
451
- # ══════════════════════════════════════════════════════════════════
452
-
453
- class DualStreamViT(nn.Module):
454
- """
455
- GeoLIP Dual-Stream ViT β€” Decoupled Geometric + Classification Paths.
456
-
457
- Geometric path (InfoNCE/CV/CM β†’ dual blocks):
458
- patch_embed + pos β†’ split β†’ geo_stream, std_stream
459
- β†’ NΓ— DualStreamBlock (KSimplex + cross-attn)
460
- β†’ pool β†’ geo_emb, std_emb, emb on S^d
461
-
462
- Classification path (BCE β†’ cross blocks + classifier):
463
- dual block outputs.detach() [gradient wall]
464
- β†’ NΓ— CrossBlock (bidirectional cross-attn)
465
- β†’ pool β†’ class projections β†’ patchwork + classifier
466
-
467
- BCE cannot reach the dual blocks. The geometry forms under
468
- pure contrastive + geometric pressure. The cross blocks learn
469
- to read the geometry for classification without corrupting it.
470
- """
471
- def __init__(
472
- self,
473
- num_classes=10,
474
- img_size=32,
475
- patch_size=4,
476
- embed_dim=384,
477
- stream_dim=192,
478
- fused_dim=256,
479
- n_dual_blocks=2,
480
- n_fused_blocks=4,
481
- n_heads=8,
482
- output_dim=128,
483
- n_anchors=64,
484
- n_comp=8,
485
- d_comp=64,
486
- anchor_drop=0.10,
487
- cv_target=0.22,
488
- ksimplex_k=4,
489
- ksimplex_edim=8,
490
- dropout=0.1,
491
- infonce_temp=0.07,
492
- infonce_weight=1.0,
493
- bce_weight=1.0,
494
- cm_weight=0.1,
495
- cv_weight=0.01,
496
- autograd_tang=0.5,
497
- autograd_sep=0.1,
498
- enable_autograd=True,
499
- label_smoothing=0.1,
500
- second_guess_weight=0.5,
501
- second_guess_temp=0.04,
502
- second_guess_n_hard=16,
503
- cache_size=4096,
504
- ):
505
- super().__init__()
506
- self.num_classes = num_classes
507
- self.num_patches = (img_size // patch_size) ** 2
508
- self.stream_dim = stream_dim
509
- self.fused_dim = fused_dim # kept for config compat, not used in forward
510
- self.output_dim = output_dim
511
- self.cv_target = cv_target
512
- self.infonce_temp = infonce_temp
513
- self.infonce_weight = infonce_weight
514
- self.bce_weight = bce_weight
515
- self.cm_weight = cm_weight
516
- self.cv_weight = cv_weight
517
- self.autograd_tang = autograd_tang
518
- self.autograd_sep = autograd_sep
519
- self.enable_autograd = enable_autograd
520
- self.label_smoothing = label_smoothing
521
- self.second_guess_weight = second_guess_weight
522
- self.second_guess_temp = second_guess_temp
523
- self.second_guess_n_hard = second_guess_n_hard
524
-
525
- # Save config for checkpoint
526
- self.config = {k: v for k, v in locals().items()
527
- if k != 'self' and not k.startswith('_')}
528
-
529
- # ── Patch embedding ──
530
- self.patch_embed = nn.Conv2d(
531
- 3, embed_dim, kernel_size=patch_size, stride=patch_size)
532
- self.pos_embed = nn.Parameter(
533
- torch.zeros(1, self.num_patches, embed_dim))
534
- nn.init.trunc_normal_(self.pos_embed, std=0.02)
535
-
536
- # ── Stream projections ──
537
- self.geo_proj = nn.Sequential(
538
- nn.Linear(embed_dim, stream_dim), nn.LayerNorm(stream_dim))
539
- self.std_proj = nn.Sequential(
540
- nn.Linear(embed_dim, stream_dim), nn.LayerNorm(stream_dim))
541
-
542
- # ── Dual-stream blocks ──
543
- geo_dim = 11 # KSimplex output
544
- self.dual_blocks = nn.ModuleList([
545
- DualStreamBlock(stream_dim, geo_dim, n_heads,
546
- ksimplex_k, ksimplex_edim, dropout)
547
- for _ in range(n_dual_blocks)])
548
-
549
- # ── Cross-attention blocks (both streams preserved, bidirectional) ──
550
- self.cross_blocks = nn.ModuleList([
551
- CrossBlock(stream_dim, n_heads, dropout)
552
- for _ in range(n_fused_blocks)])
553
- self.geo_norm = nn.LayerNorm(stream_dim)
554
- self.std_norm = nn.LayerNorm(stream_dim)
555
-
556
- # ── Output projections: GEOMETRIC path (InfoNCE/CV/CM train these) ──
557
- self.output_proj = nn.Sequential(
558
- nn.Linear(stream_dim, output_dim),
559
- nn.LayerNorm(output_dim))
560
- self.geo_output_proj = nn.Sequential(
561
- nn.Linear(stream_dim, output_dim),
562
- nn.LayerNorm(output_dim))
563
-
564
- # ── Output projections: CLASSIFICATION path (BCE trains these) ──
565
- self.class_output_proj = nn.Sequential(
566
- nn.Linear(stream_dim, output_dim),
567
- nn.LayerNorm(output_dim))
568
- self.class_geo_output_proj = nn.Sequential(
569
- nn.Linear(stream_dim, output_dim),
570
- nn.LayerNorm(output_dim))
571
-
572
- # ── Constellation + Patchwork (on classification embeddings) ──
573
- self.constellation = Constellation(n_anchors, output_dim, anchor_drop)
574
- self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
575
- pw_dim = n_comp * d_comp
576
-
577
- # ── Classifier: patchwork + class_geo_emb + class_std_emb ──
578
- self.classifier = nn.Sequential(
579
- nn.Linear(pw_dim + output_dim * 2, pw_dim), nn.GELU(),
580
- nn.LayerNorm(pw_dim), nn.Dropout(dropout),
581
- nn.Linear(pw_dim, num_classes))
582
 
583
- # ── Geo classifier: probe on geo_emb (detached β€” pure measurement) ──
584
- self.geo_classifier = nn.Sequential(
585
- nn.Linear(output_dim, output_dim), nn.GELU(),
586
- nn.Dropout(dropout),
587
- nn.Linear(output_dim, num_classes))
588
-
589
- self._init_weights()
590
-
591
- def _init_weights(self):
592
- for m in self.modules():
593
- if isinstance(m, nn.Linear):
594
- nn.init.trunc_normal_(m.weight, std=0.02)
595
- if m.bias is not None:
596
- nn.init.zeros_(m.bias)
597
- elif isinstance(m, nn.LayerNorm):
598
- nn.init.ones_(m.weight)
599
- nn.init.zeros_(m.bias)
600
-
601
- def forward(self, x, targets=None, apply_autograd=True):
602
- """
603
- Args:
604
- x: (B, 3, H, W)
605
- targets: (B,) class indices (optional, for loss)
606
- Returns:
607
- dict with logits, embedding, geo_feats, vol2, etc.
608
- """
609
- output = {}
610
- B = x.shape[0]
611
-
612
- # ── Patch embedding ──
613
- tokens = self.patch_embed(x).flatten(2).transpose(1, 2)
614
- tokens = tokens + self.pos_embed
615
- P = tokens.shape[1]
616
-
617
- # ── Split into two streams ──
618
- geo_stream = self.geo_proj(tokens) # (B, P, stream_dim)
619
- std_stream = self.std_proj(tokens) # (B, P, stream_dim)
620
-
621
- # ── Dual-stream blocks ──
622
- all_geo_feats = []
623
- all_vol2 = []
624
- for block in self.dual_blocks:
625
- geo_stream, std_stream, geo_feats, vol2 = block(
626
- geo_stream, std_stream)
627
- all_geo_feats.append(geo_feats)
628
- all_vol2.append(vol2)
629
-
630
- output['geo_feats'] = all_geo_feats[-1]
631
- output['all_geo_feats'] = torch.stack(all_geo_feats)
632
- output['vol2'] = torch.stack(all_vol2)
633
-
634
- # ════════════════════════════════════════════════════════
635
- # PATH A: GEOMETRIC (direct from dual blocks β†’ sphere)
636
- # InfoNCE + CV + CM + autograd shape these.
637
- # Gradients flow freely back into dual blocks.
638
- # This IS the geometric representation.
639
- # ════════════════════════════════════════════════════════
640
-
641
- geo_pooled = geo_stream.mean(dim=1)
642
- std_pooled = std_stream.mean(dim=1)
643
-
644
- geo_emb = F.normalize(self.geo_output_proj(geo_pooled), dim=-1)
645
- std_emb = F.normalize(self.output_proj(std_pooled), dim=-1)
646
- emb = F.normalize(geo_emb + std_emb, dim=-1)
647
-
648
- if (apply_autograd and self.training and self.enable_autograd):
649
- emb = EmbeddingAutograd.apply(
650
- emb, emb, self.constellation.anchors,
651
- self.autograd_tang, self.autograd_sep)
652
- geo_emb = EmbeddingAutograd.apply(
653
- geo_emb, geo_emb, self.constellation.anchors,
654
- self.autograd_tang, self.autograd_sep)
655
- std_emb = EmbeddingAutograd.apply(
656
- std_emb, std_emb, self.constellation.anchors,
657
- self.autograd_tang, self.autograd_sep)
658
-
659
- output['embedding'] = emb # for InfoNCE, CV, mastery
660
- output['geo_emb'] = geo_emb # for CV (geo), geo_div
661
- output['std_emb'] = std_emb
662
-
663
- # ════════════════════════════════════════════════════════
664
- # PATH B: CLASSIFICATION (through cross blocks, DETACHED)
665
- # BCE shapes cross blocks + classifier.
666
- # Gradient wall at detach β€” dual blocks never see BCE.
667
- # Cross blocks learn to READ the geometry, not WRITE it.
668
- # ════════════════════════════════════════════════════════
669
-
670
- geo_cross = geo_stream.detach() # ← gradient wall
671
- std_cross = std_stream.detach() # ← gradient wall
672
-
673
- for block in self.cross_blocks:
674
- geo_cross, std_cross = block(geo_cross, std_cross)
675
- geo_cross = self.geo_norm(geo_cross)
676
- std_cross = self.std_norm(std_cross)
677
-
678
- geo_class = F.normalize(
679
- self.class_geo_output_proj(geo_cross.mean(dim=1)), dim=-1)
680
- std_class = F.normalize(
681
- self.class_output_proj(std_cross.mean(dim=1)), dim=-1)
682
- emb_class = F.normalize(geo_class + std_class, dim=-1)
683
-
684
- output['emb_class'] = emb_class
685
- output['geo_class'] = geo_class
686
- output['std_class'] = std_class
687
-
688
- # Constellation + patchwork on classification embedding
689
- tri_full, nearest_full = self.constellation.triangulate(
690
- emb_class, training=False)
691
- pw = self.patchwork(tri_full)
692
- output['triangulation'] = tri_full
693
-
694
- if self.training:
695
- _, nearest = self.constellation.triangulate(emb_class, training=True)
696
- else:
697
- nearest = nearest_full
698
- output['nearest'] = nearest
699
-
700
- # Classifier reads classification-path embeddings
701
- logits = self.classifier(
702
- torch.cat([pw, geo_class, std_class], dim=-1))
703
- output['logits'] = logits
704
-
705
- # Geo classifier: probe on GEOMETRIC geo_emb (detached β€” pure measurement)
706
- geo_logits = self.geo_classifier(geo_emb.detach())
707
- output['geo_logits'] = geo_logits
708
-
709
- # ── Patch-level anchor tracking (no grad, uses geometric path) ──
710
  with torch.no_grad():
711
- geo_patch_embs = F.normalize(
712
- self.geo_output_proj(geo_stream.reshape(B * P, -1)), dim=-1)
713
- std_patch_embs = F.normalize(
714
- self.output_proj(std_stream.reshape(B * P, -1)), dim=-1)
715
- patch_embs = F.normalize(
716
- geo_patch_embs + std_patch_embs, dim=-1).reshape(B, P, -1)
717
- anchors_n = F.normalize(self.constellation.anchors, dim=-1)
718
- patch_cos = torch.einsum('bpd,ad->bpa', patch_embs, anchors_n)
719
- output['patch_nearest'] = patch_cos.argmax(dim=-1)
720
- output['patch_embs'] = patch_embs
721
-
722
- return output
723
-
724
- def compute_loss(self, output, targets, output_aug=None,
725
- mastery_queue=None):
726
- """
727
- Decoupled loss: geometric and classification gradients separated.
728
-
729
- GEOMETRIC PATH (trains dual blocks + geo projections):
730
- InfoNCE, CV, CM, geo_div, autograd, mastery
731
- Uses output['embedding'], output['geo_emb']
732
-
733
- CLASSIFICATION PATH (trains cross blocks + classifier):
734
- BCE on output['logits'] (from detached streams through cross blocks)
735
- Gradient wall at dual block boundary.
736
-
737
- GEO PROBE (trains only geo_classifier head):
738
- BCE on output['geo_logits'] (from geo_emb.detach())
739
- Pure measurement β€” does not shape any representation.
740
- """
741
- loss_dict = {}
742
- emb = output['embedding']
743
- B = emb.shape[0]
744
- is_mastery = mastery_queue is not None and mastery_queue.active
745
-
746
- # ── BCE classification (always primary, with label smoothing) ──
747
- one_hot = F.one_hot(targets, self.num_classes).float()
748
- # Label smoothing: 1.0 β†’ 0.9, 0.0 β†’ 0.1/(C-1)
749
- ls = self.label_smoothing
750
- if ls > 0:
751
- one_hot = one_hot * (1.0 - ls) + ls / self.num_classes
752
- l_bce = F.binary_cross_entropy_with_logits(output['logits'], one_hot)
753
- loss_dict['bce'] = l_bce
754
-
755
- # ── Geo classifier BCE (same smoothing) ──
756
- geo_logits = output.get('geo_logits')
757
- if geo_logits is not None:
758
- l_geo_bce = F.binary_cross_entropy_with_logits(geo_logits, one_hot)
759
- loss_dict['geo_bce'] = l_geo_bce
760
- geo_preds = geo_logits.argmax(-1)
761
- loss_dict['geo_acc'] = (geo_preds == targets).float().mean().item()
762
-
763
- # ── Geo diversity (prevent intra-class collapse) ──
764
- # Penalizes same-class geo embeddings from being too similar
765
- geo_emb = output.get('geo_emb')
766
- if geo_emb is not None and B > 4:
767
- geo_sim = geo_emb @ geo_emb.T # (B, B)
768
- same_class = targets.unsqueeze(0) == targets.unsqueeze(1)
769
- diag = torch.eye(B, dtype=torch.bool, device=emb.device)
770
- same_not_self = same_class & ~diag
771
- if same_not_self.any():
772
- # Penalize same-class cos > 0.8 (should have SOME variation)
773
- same_cos = geo_sim[same_not_self]
774
- l_geo_div = F.relu(same_cos - 0.8).mean()
775
- loss_dict['geo_div'] = l_geo_div
776
-
777
- # ── InfoNCE: ALWAYS active at full weight ──
778
- # The bidirectional cross-attention preserves structure;
779
- # InfoNCE maintains the spreading force at all times.
780
- nce_acc = 0.0
781
- if output_aug is not None:
782
- emb_aug = output_aug['embedding']
783
- sim = emb @ emb_aug.T / self.infonce_temp
784
- labels_nce = torch.arange(B, device=emb.device)
785
- l_nce = F.cross_entropy(sim, labels_nce)
786
- nce_acc = (sim.argmax(1) == labels_nce).float().mean().item()
787
- loss_dict['nce'] = l_nce
788
- loss_dict['nce_acc'] = nce_acc
789
-
790
- # ── Mastery clause (progressive margin) ──
791
- if is_mastery:
792
- q_emb, q_labels = mastery_queue.get()
793
- if q_emb is not None and q_emb.shape[0] >= B:
794
- cross_sim = emb @ q_emb.T # (B, Q)
795
-
796
- same_class_mask = targets.unsqueeze(1) == q_labels.unsqueeze(0)
797
- hard_neg_sim = cross_sim.clone()
798
- hard_neg_sim[same_class_mask] = -1e9
799
- hard_neg_cos = hard_neg_sim.max(dim=1).values
800
-
801
- hard_pos_sim = cross_sim.clone()
802
- hard_pos_sim[~same_class_mask] = 1e9
803
- hard_pos_cos = hard_pos_sim.min(dim=1).values
804
-
805
- has_same = same_class_mask.any(dim=1)
806
- has_diff = (~same_class_mask).any(dim=1)
807
- valid = has_same & has_diff
808
-
809
- if valid.sum() > 0:
810
- # Progressive margin: grows as hard_pos improves
811
- margin = mastery_queue.current_margin
812
- l_mastery = F.relu(
813
- hard_neg_cos[valid] - hard_pos_cos[valid] + margin
814
- ).mean()
815
- loss_dict['mastery'] = l_mastery
816
- loss_dict['hard_neg_cos'] = hard_neg_cos[valid].mean().item()
817
- loss_dict['hard_pos_cos'] = hard_pos_cos[valid].mean().item()
818
- loss_dict['margin'] = margin
819
-
820
- mastery_queue.push(emb.detach(), targets.detach())
821
-
822
- # ── CM validity ──
823
- vol2 = output['vol2']
824
- l_cm = F.relu(-vol2).mean()
825
- loss_dict['cm'] = l_cm
826
- loss_dict['cm_valid'] = (vol2 > 0).float().mean().item()
827
-
828
- # ── CV loss on BOTH streams ──
829
- l_cv_fused = self._cv_loss_fast(emb, target=self.cv_target)
830
- geo_emb = output.get('geo_emb')
831
- if geo_emb is not None:
832
- l_cv_geo = self._cv_loss_fast(geo_emb, target=self.cv_target)
833
- else:
834
- l_cv_geo = torch.tensor(0.0, device=emb.device)
835
- l_cv = l_cv_fused + l_cv_geo
836
- loss_dict['cv'] = l_cv
837
- loss_dict['cv_fused'] = l_cv_fused.item() if torch.is_tensor(l_cv_fused) else l_cv_fused
838
- loss_dict['cv_geo'] = l_cv_geo.item() if torch.is_tensor(l_cv_geo) else l_cv_geo
839
-
840
- # ── Anchor CV (dedicated, separate from embedding CV) ──
841
- anchors_n = F.normalize(self.constellation.anchors, dim=-1)
842
- l_anchor_cv = self._cv_loss_fast(anchors_n, target=self.cv_target)
843
- loss_dict['anchor_cv'] = l_anchor_cv
844
-
845
- # ── Anchor spread (prevent clustering, lighter than before) ──
846
- anchor_sim = anchors_n @ anchors_n.T
847
- mask_a = ~torch.eye(anchors_n.shape[0], dtype=torch.bool,
848
- device=anchors_n.device)
849
- l_spread = F.relu(anchor_sim[mask_a] - 0.0).mean()
850
- loss_dict['spread'] = l_spread
851
-
852
- # ── Combine ──
853
- loss = (l_bce * self.bce_weight
854
- + loss_dict.get('geo_bce', 0.0) * 0.3
855
- + loss_dict.get('geo_div', 0.0) * 0.5
856
- + loss_dict.get('nce', 0.0) * self.infonce_weight
857
- + loss_dict.get('mastery', 0.0) * self.bce_weight
858
- + l_cm * self.cm_weight
859
- + l_cv * self.cv_weight
860
- + l_anchor_cv * self.cv_weight * 0.5
861
- + l_spread * 0.001)
862
-
863
- loss_dict['total'] = loss
864
- return loss, loss_dict
865
-
866
- @staticmethod
867
- def _cv_loss_fast(emb, target=0.22, n_samples=64, n_points=5):
868
- """Fast differentiable CV loss from random pentachora."""
869
- B = emb.shape[0]
870
- if B < n_points:
871
- return torch.tensor(0.0, device=emb.device)
872
  vols = []
873
- for _ in range(n_samples):
874
- idx = torch.randperm(min(B, 512), device=emb.device)[:n_points]
875
- pts = emb[idx].unsqueeze(0) # (1, 5, D)
876
  gram = torch.bmm(pts, pts.transpose(1, 2))
877
  norms = torch.diagonal(gram, dim1=1, dim2=2)
878
  d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
879
  d2 = F.relu(d2)
880
- N = n_points
881
- cm = torch.zeros(1, N + 1, N + 1,
882
- device=emb.device, dtype=emb.dtype)
883
  cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
884
- k = N - 1
885
- sign = (-1.0) ** (k + 1)
886
- fact = math.factorial(k)
887
- prefactor = sign / ((2.0 ** k) * (fact ** 2))
888
- vol2 = prefactor * torch.linalg.det(cm.float())
889
- if vol2[0].item() > 1e-20:
890
- vols.append(vol2[0].to(emb.dtype).sqrt())
891
- if len(vols) < 5:
892
- return torch.tensor(0.0, device=emb.device)
893
- vols_t = torch.stack(vols)
894
- cv = vols_t.std() / (vols_t.mean() + 1e-8)
895
- return (cv - target).pow(2)
896
-
897
-
898
- # ══════════════════════════════════════════════════════════════════
899
- # MASTERY QUEUE β€” Progressive cross-batch hard contrastive
900
- # ══════════════════════════════════════════════════════════════════
901
-
902
- class MasteryQueue:
903
- """
904
- Cross-batch embedding cache with adaptive queue sizing.
905
-
906
- Activation: when nce_acc >= 0.99 for `patience` consecutive batches.
907
- Progressive margin: ramps from margin_start β†’ margin_end over margin_warmup.
908
-
909
- Adaptive queue sizing (call update_size each epoch):
910
- Dual trigger with cooldown:
911
- 1. ABSOLUTE: gap > 3Γ— threshold β†’ grow (strongly overfitting)
912
- gap < 1Γ— threshold β†’ shrink (well-balanced)
913
- 2. DRIFT: gap grew > threshold over 5-epoch window β†’ grow
914
- gap shrank > threshold over 5-epoch window β†’ shrink
915
- Cooldown: no resize for `resize_cooldown` epochs after each change.
916
- """
917
- def __init__(self, dim, min_size=1024, max_size=8192, initial_size=4096,
918
- patience=50, device='cuda',
919
- margin_start=0.1, margin_end=0.3, margin_warmup=5000,
920
- resize_step=1024, resize_cooldown=5, overfit_threshold=3.0):
921
- self.dim = dim
922
- self.min_size = min_size
923
- self.max_size = max_size
924
- self._current_max = initial_size
925
- self.patience = patience
926
- self.device = device
927
- self.active = False
928
-
929
- # Queue storage
930
- self._embs = None
931
- self._labels = None
932
-
933
- # Activation tracking
934
- self._perfect_count = 0
935
- self._total_batches = 0
936
- self._activated_at = None
937
-
938
- # Progressive margin
939
- self._margin_start = margin_start
940
- self._margin_end = margin_end
941
- self._margin_warmup = margin_warmup
942
- self._mastery_steps = 0
943
-
944
- # Adaptive sizing
945
- self._resize_step = resize_step
946
- self._resize_cooldown = resize_cooldown
947
- self._overfit_threshold = overfit_threshold
948
- self._epochs_since_resize = resize_cooldown # allow first resize
949
- self._gap_history = [] # rolling window of (epoch, gap) pairs
950
- self._gap_window = 5 # look back this many epochs for drift
951
- self._resize_history = []
952
-
953
- def check_activation(self, nce_acc):
954
- """Call each batch. Activates when nce_acc >= 0.99 for patience steps."""
955
- self._total_batches += 1
956
- if nce_acc >= 0.99:
957
- self._perfect_count += 1
958
  else:
959
- self._perfect_count = 0
960
-
961
- if not self.active and self._perfect_count >= self.patience:
962
- self.active = True
963
- self._activated_at = self._total_batches
964
- print(f"\n β˜… MASTERY ACTIVATED at batch {self._total_batches} "
965
- f"(nce_acc=1.0 for {self.patience} consecutive) "
966
- f"[InfoNCE stays ON, margin {self._margin_start}β†’{self._margin_end}]"
967
- f" queue={self._current_max}")
968
-
969
- if self.active:
970
- self._mastery_steps += 1
971
-
972
- def update_size(self, train_acc, val_acc, epoch):
973
- """
974
- Adjusts queue size based on overfit gap. Dual trigger:
975
-
976
- 1. ABSOLUTE: gap > threshold β†’ grow queue
977
- gap < threshold/2 β†’ shrink queue
978
- 2. DRIFT: gap grew > threshold over rolling window β†’ grow queue
979
- gap shrank > threshold over rolling window β†’ shrink queue
980
-
981
- Cooldown prevents oscillation: no resize for `resize_cooldown` epochs.
982
- """
983
- if not self.active:
984
- return
985
-
986
- self._epochs_since_resize += 1
987
- gap = train_acc - val_acc
988
- self._gap_history.append((epoch, gap))
989
-
990
- if self._epochs_since_resize < self._resize_cooldown:
991
- return
992
-
993
- old_size = self._current_max
994
- reason = None
995
-
996
- # ── Trigger 1: Absolute gap ──
997
- if gap > self._overfit_threshold * 3:
998
- # Gap > 9% (3Γ— threshold) β€” strongly overfitting, grow queue
999
- self._current_max = min(
1000
- self._current_max + self._resize_step, self.max_size)
1001
- reason = f"abs gap={gap:.1f}%"
1002
- elif gap < self._overfit_threshold:
1003
- # Gap < 3% β€” underfitting or well-balanced, shrink for sharper signal
1004
- self._current_max = max(
1005
- self._current_max - self._resize_step, self.min_size)
1006
- reason = f"abs gap={gap:.1f}%"
1007
-
1008
- # ── Trigger 2: Drift over rolling window ──
1009
- if reason is None and len(self._gap_history) >= self._gap_window:
1010
- window_start = self._gap_history[-self._gap_window][1]
1011
- drift = gap - window_start
1012
- if drift > self._overfit_threshold:
1013
- # Gap grew by threshold over window β€” overfitting accelerating
1014
- self._current_max = min(
1015
- self._current_max + self._resize_step, self.max_size)
1016
- reason = f"drift={drift:+.1f}% over {self._gap_window}ep"
1017
- elif drift < -self._overfit_threshold:
1018
- # Gap shrank by threshold over window β€” can tighten
1019
- self._current_max = max(
1020
- self._current_max - self._resize_step, self.min_size)
1021
- reason = f"drift={drift:+.1f}% over {self._gap_window}ep"
1022
-
1023
- if self._current_max != old_size:
1024
- direction = "↑" if self._current_max > old_size else "↓"
1025
- print(f" βš™ Queue {direction} {old_size}β†’{self._current_max} "
1026
- f"({reason})")
1027
- self._epochs_since_resize = 0
1028
- self._resize_history.append(
1029
- (epoch, old_size, self._current_max, gap, reason))
1030
-
1031
- # Trim queue if it shrunk
1032
- if self._embs is not None and self._embs.shape[0] > self._current_max:
1033
- self._embs = self._embs[-self._current_max:]
1034
- self._labels = self._labels[-self._current_max:]
1035
-
1036
- @property
1037
- def current_margin(self):
1038
- if not self.active:
1039
- return self._margin_start
1040
- t = min(self._mastery_steps / max(self._margin_warmup, 1), 1.0)
1041
- return self._margin_start + t * (self._margin_end - self._margin_start)
1042
-
1043
- def push(self, emb, labels):
1044
- """Add batch to queue. FIFO eviction at current_max."""
1045
- emb = emb.detach().to(self.device)
1046
- labels = labels.detach().to(self.device)
1047
-
1048
- if self._embs is None:
1049
- self._embs = emb
1050
- self._labels = labels
1051
- else:
1052
- self._embs = torch.cat([self._embs, emb], 0)[-self._current_max:]
1053
- self._labels = torch.cat([self._labels, labels], 0)[-self._current_max:]
1054
-
1055
- def get(self):
1056
- if self._embs is None:
1057
- return None, None
1058
- return self._embs, self._labels
1059
-
1060
- @property
1061
- def size(self):
1062
- return 0 if self._embs is None else self._embs.shape[0]
1063
-
1064
- def state_dict(self):
1065
- return {
1066
- 'active': self.active,
1067
- 'perfect_count': self._perfect_count,
1068
- 'total_batches': self._total_batches,
1069
- 'activated_at': self._activated_at,
1070
- 'mastery_steps': self._mastery_steps,
1071
- 'current_margin': self.current_margin,
1072
- 'current_max': self._current_max,
1073
- 'gap_history': self._gap_history[-20:], # last 20 entries
1074
- 'resize_history': self._resize_history,
1075
- }
1076
-
1077
-
1078
- # ══════════════════════════════════════════════════════════════════
1079
- # FACTORY
1080
- # ══════════════════════════════════════════════════════════════════
1081
-
1082
- def create_dual_stream_vit(**kwargs):
1083
- return DualStreamViT(**kwargs)
 
1
  #!/usr/bin/env python3
2
  """
3
+ CIFAR-10 β€” Dual-Stream GeoLIP ViT β€” Experiment 6
4
+ ==================================================
5
+ Full bidirectional. 3Γ— DualBlock + 6Γ— CrossBlock.
6
+ Wider sphere: 256-d embeddings, 128 anchors, 16Γ—128 patchwork.
7
+ Adaptive mastery queue: grows/shrinks based on overfit gap with cooldown.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  """
9
 
10
  import torch
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
+ import os, time
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+ from torchvision import datasets, transforms
17
+ from torch.utils.tensorboard import SummaryWriter
18
 
19
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
+ torch.backends.cuda.matmul.allow_tf32 = True
21
+ torch.backends.cudnn.allow_tf32 = True
22
+
23
+ # ── Architecture ──
24
+ NUM_CLASSES = 10
25
+ IMG_SIZE = 32
26
+ PATCH_SIZE = 4
27
+ EMBED_DIM = 384
28
+ STREAM_DIM = 192
29
+ FUSED_DIM = 256 # unused in bidirectional, kept for config compat
30
+ N_DUAL_BLOCKS = 3 # ↑ from 2 β€” more geometric processing
31
+ N_CROSS_BLOCKS = 6 # ↑ from 4 β€” deeper bidirectional cooperation
32
+ N_HEADS = 8
33
+ OUTPUT_DIM = 256 # ↑ from 128 β€” wider hypersphere
34
+ N_ANCHORS = 128 # ↑ from 64 β€” denser navigation frame
35
+ N_COMP = 16 # ↑ from 8 β€” more patchwork compartments
36
+ D_COMP = 128 # ↑ from 64 β€” richer per-anchor representation
37
+ ANCHOR_DROP = 0.10
38
+ CV_TARGET = 0.22
39
+
40
+ # ── Loss weights ──
41
+ CV_WEIGHT = 0.1
42
+ ENABLE_AUTOGRAD = True
43
+ AUTOGRAD_TANG = 1.0
44
+ AUTOGRAD_SEP = 0.1
45
+ LABEL_SMOOTHING = 0.1
46
+ INFONCE_WEIGHT = 0.1
47
+ BCE_WEIGHT = 1.0
48
+ CM_WEIGHT = 0.1
49
+ INFONCE_TEMP = 0.07
50
+
51
+ # ── Mastery queue ──
52
+ MASTERY_PATIENCE = 50
53
+ MASTERY_MARGIN_START = 0.1
54
+ MASTERY_MARGIN_END = 0.3
55
+ MASTERY_MARGIN_WARMUP = 5000
56
+ MASTERY_MIN_SIZE = 1024
57
+ MASTERY_MAX_SIZE = 16384
58
+ MASTERY_INITIAL_SIZE = 4096
59
+ MASTERY_RESIZE_STEP = 2048
60
+ MASTERY_RESIZE_COOLDOWN = 5 # epochs between resizes
61
+ MASTERY_OVERFIT_THRESH = 3.0 # abs trigger at 3Γ—, drift trigger at 1Γ—
62
+
63
+ # ── Training ──
64
+ BATCH = 1024
65
+ EPOCHS = 100
66
+ LR = 3e-4
67
+ WARMUP = 5
68
+ GRAD_CLIP = 1.0
69
+
70
+ # No warm start
71
+ V1_CKPT = ""
72
+
73
+ print("=" * 60)
74
+ print("CIFAR-10 β€” Dual-Stream GeoLIP ViT β€” EXP 6")
75
+ print(f" From scratch, {EPOCHS} epochs, lr={LR}")
76
+ print(f" Architecture: {N_DUAL_BLOCKS}Γ— DualBlock + {N_CROSS_BLOCKS}Γ— CrossBlock")
77
+ print(f" Sphere: {OUTPUT_DIM}-d emb, {N_ANCHORS} anchors, "
78
+ f"{N_COMP}Γ—{D_COMP} patchwork")
79
+ print(f" InfoNCE={INFONCE_WEIGHT} β€” ALWAYS ON")
80
+ print(f" CV={CV_WEIGHT}, autograd tang={AUTOGRAD_TANG}")
81
+ print(f" Mastery: patience={MASTERY_PATIENCE}, "
82
+ f"margin {MASTERY_MARGIN_START}β†’{MASTERY_MARGIN_END}, "
83
+ f"queue {MASTERY_INITIAL_SIZE} [{MASTERY_MIN_SIZE}–{MASTERY_MAX_SIZE}]")
84
+ print(f" Queue resize: step={MASTERY_RESIZE_STEP}, "
85
+ f"cooldown={MASTERY_RESIZE_COOLDOWN}ep, "
86
+ f"abs>{MASTERY_OVERFIT_THRESH*3:.0f}%/drift>{MASTERY_OVERFIT_THRESH:.0f}%")
87
+ print(f" Device: {DEVICE}")
88
+ print("=" * 60)
89
 
90
  # ══════════════════════════════════════════════════════════════════
91
+ # DATA
92
  # ══════════════════════════════════════════════════════════════════
93
 
94
+ CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
95
+ CIFAR_STD = (0.2470, 0.2435, 0.2616)
96
+
97
+ train_transform = transforms.Compose([
98
+ transforms.RandomCrop(32, padding=4),
99
+ transforms.RandomHorizontalFlip(),
100
+ transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
101
+ transforms.RandomGrayscale(p=0.2),
102
+ transforms.ToTensor(),
103
+ transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
104
+ ])
105
+ val_transform = transforms.Compose([
106
+ transforms.ToTensor(),
107
+ transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
108
+ ])
109
+
110
+
111
+ class TwoViewDataset(torch.utils.data.Dataset):
112
+ def __init__(self, base_ds, transform):
113
+ self.base = base_ds
114
+ self.transform = transform
115
+
116
+ def __len__(self):
117
+ return len(self.base)
118
+
119
+ def __getitem__(self, idx):
120
+ img, label = self.base.data[idx], self.base.targets[idx]
121
+ from PIL import Image
122
+ img = Image.fromarray(img)
123
+ return self.transform(img), self.transform(img), label
124
+
125
+
126
+ raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
127
+ train_ds = TwoViewDataset(raw_train, train_transform)
128
+ val_ds = datasets.CIFAR10(root='./data', train=False,
129
+ download=True, transform=val_transform)
130
+
131
+ train_loader = torch.utils.data.DataLoader(
132
+ train_ds, batch_size=BATCH, shuffle=True,
133
+ num_workers=2, pin_memory=True, drop_last=True)
134
+ val_loader = torch.utils.data.DataLoader(
135
+ val_ds, batch_size=BATCH, shuffle=False,
136
+ num_workers=2, pin_memory=True)
137
+
138
+ CIFAR_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
139
+ 'dog', 'frog', 'horse', 'ship', 'truck']
140
+ print(f" Train: {len(train_ds):,} (two views) Val: {len(val_ds):,}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  # ══════════════════════════════════════════════════════════════════
143
+ # BUILD MODEL + WARM START
144
  # ══════════════════════════════════════════════════════════════════
145
 
146
+ print(f"\n Building model...")
147
+ model = create_dual_stream_vit(
148
+ num_classes=NUM_CLASSES, img_size=IMG_SIZE, patch_size=PATCH_SIZE,
149
+ embed_dim=EMBED_DIM, stream_dim=STREAM_DIM, fused_dim=FUSED_DIM,
150
+ n_dual_blocks=N_DUAL_BLOCKS, n_fused_blocks=N_CROSS_BLOCKS,
151
+ n_heads=N_HEADS, output_dim=OUTPUT_DIM,
152
+ n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP,
153
+ anchor_drop=ANCHOR_DROP, cv_target=CV_TARGET,
154
+ dropout=0.1, infonce_temp=INFONCE_TEMP,
155
+ infonce_weight=INFONCE_WEIGHT, bce_weight=BCE_WEIGHT,
156
+ cm_weight=CM_WEIGHT, cv_weight=CV_WEIGHT,
157
+ autograd_tang=AUTOGRAD_TANG, autograd_sep=AUTOGRAD_SEP,
158
+ enable_autograd=ENABLE_AUTOGRAD,
159
+ label_smoothing=LABEL_SMOOTHING,
160
+ ).to(DEVICE)
161
+
162
+ # Optional warm start
163
+ if V1_CKPT and os.path.exists(V1_CKPT):
164
+ ckpt = torch.load(V1_CKPT, map_location="cpu", weights_only=False)
165
+ model.load_state_dict(ckpt["state_dict"], strict=False)
166
+ print(f" βœ“ Loaded weights: epoch {ckpt['epoch']}, "
167
+ f"val_acc {ckpt['val_acc']:.1f}%")
168
+ else:
169
+ print(f" Training from scratch")
170
+
171
+ n_params = sum(p.numel() for p in model.parameters())
172
+
173
+ # Param groups: geo params get separate tracking
174
+ geo_names = {'geo_proj', 'dual_blocks', 'constellation', 'patchwork'}
175
+ geo_params, std_params = [], []
176
+ for name, param in model.named_parameters():
177
+ if not param.requires_grad:
178
+ continue
179
+ if any(gn in name for gn in geo_names):
180
+ geo_params.append(param)
181
+ else:
182
+ std_params.append(param)
183
+
184
+ n_geo = sum(p.numel() for p in geo_params)
185
+ n_std = sum(p.numel() for p in std_params)
186
+ print(f" Parameters: {n_params:,}")
187
+ print(f" Geo route: {n_geo:,} ({100*n_geo/n_params:.1f}%)")
188
+ print(f" Std route: {n_std:,} ({100*n_std/n_params:.1f}%)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  # ══════════════════════════════════════════════════════════════════
191
+ # TRAINING
192
  # ══════════════════════════════════════════════════════════════════
193
 
194
+ print(f"\n{'='*60}")
195
+ print(f"TRAINING β€” {EPOCHS} epochs, lr={LR}, batch={BATCH}")
196
+ print(f" Architecture: {N_DUAL_BLOCKS}Γ— DualBlock + {N_CROSS_BLOCKS}Γ— CrossBlock")
197
+ print(f" Sphere: {OUTPUT_DIM}-d, {N_ANCHORS} anchors, {N_COMP}Γ—{D_COMP} patchwork")
198
+ print(f" CV={CV_WEIGHT}, autograd={'ON' if ENABLE_AUTOGRAD else 'OFF'} "
199
+ f"(tang={AUTOGRAD_TANG})")
200
+ print(f" Mastery: patience={MASTERY_PATIENCE}, "
201
+ f"margin {MASTERY_MARGIN_START}β†’{MASTERY_MARGIN_END}, "
202
+ f"queue {MASTERY_INITIAL_SIZE} adaptive [{MASTERY_MIN_SIZE}–{MASTERY_MAX_SIZE}]")
203
+ print(f" InfoNCE={INFONCE_WEIGHT}, Geo cls=0.3, Geo div=0.5, LS={LABEL_SMOOTHING}")
204
+ print(f" Optimizer: AdamW (wd=0.01)")
205
+ print(f"{'='*60}")
206
+
207
+ optimizer = torch.optim.AdamW([
208
+ {'params': geo_params, 'lr': LR},
209
+ {'params': std_params, 'lr': LR},
210
+ ], lr=LR, weight_decay=0.01)
211
+
212
+ total_steps = len(train_loader) * EPOCHS
213
+ warmup_steps = len(train_loader) * WARMUP
214
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
215
+ optimizer,
216
+ [torch.optim.lr_scheduler.LinearLR(
217
+ optimizer, start_factor=0.01, total_iters=warmup_steps),
218
+ torch.optim.lr_scheduler.CosineAnnealingLR(
219
+ optimizer, T_max=max(total_steps - warmup_steps, 1), eta_min=1e-6)],
220
+ milestones=[warmup_steps])
221
+
222
+ scaler = torch.amp.GradScaler("cuda")
223
+ os.makedirs("checkpoints", exist_ok=True)
224
+ writer = SummaryWriter("runs/cifar10_dual_stream_v6_wide")
225
+ best_acc = 0.0
226
+ gs = 0
227
+
228
+ # Mastery queue with adaptive sizing
229
+ mastery = MasteryQueue(
230
+ dim=OUTPUT_DIM,
231
+ min_size=MASTERY_MIN_SIZE,
232
+ max_size=MASTERY_MAX_SIZE,
233
+ initial_size=MASTERY_INITIAL_SIZE,
234
+ patience=MASTERY_PATIENCE,
235
+ device=DEVICE,
236
+ margin_start=MASTERY_MARGIN_START,
237
+ margin_end=MASTERY_MARGIN_END,
238
+ margin_warmup=MASTERY_MARGIN_WARMUP,
239
+ resize_step=MASTERY_RESIZE_STEP,
240
+ resize_cooldown=MASTERY_RESIZE_COOLDOWN,
241
+ overfit_threshold=MASTERY_OVERFIT_THRESH,
242
+ )
243
+
244
+ for epoch in range(EPOCHS):
245
+ model.train()
246
+ t0 = time.time()
247
+
248
+ acc_dict = {"loss": 0, "bce": 0, "geo_bce": 0, "geo_acc": 0, "geo_div": 0,
249
+ "nce": 0, "nce_acc": 0,
250
+ "cm": 0, "cm_valid": 0, "cv": 0, "cv_fused": 0, "cv_geo": 0,
251
+ "anchor_cv": 0, "spread": 0,
252
+ "mastery": 0, "hard_neg": 0, "hard_pos": 0, "margin": 0,
253
+ "correct": 0, "total": 0, "n": 0}
254
+
255
+ pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="batch")
256
+ for v1, v2, labels in pbar:
257
+ v1 = v1.to(DEVICE, non_blocking=True)
258
+ v2 = v2.to(DEVICE, non_blocking=True)
259
+ labels = labels.to(DEVICE, non_blocking=True)
260
+
261
+ with torch.amp.autocast("cuda", dtype=torch.bfloat16):
262
+ out1 = model(v1, targets=labels)
263
+ out2 = model(v2, targets=labels)
264
+ loss, ld = model.compute_loss(
265
+ out1, labels, output_aug=out2, mastery_queue=mastery)
266
+
267
+ # Check mastery activation
268
+ mastery.check_activation(ld.get('nce_acc', 0))
269
+
270
+ scaler.scale(loss).backward()
271
+ scaler.unscale_(optimizer)
272
+ torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
273
+ scaler.step(optimizer)
274
+ scaler.update()
275
+ optimizer.zero_grad(set_to_none=True)
276
+ scheduler.step()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  with torch.no_grad():
279
+ preds = out1['logits'].argmax(dim=-1)
280
+ acc_dict["correct"] += (preds == labels).sum().item()
281
+ acc_dict["total"] += labels.shape[0]
282
+
283
+ acc_dict["loss"] += loss.item()
284
+ for k in ["bce", "geo_bce", "geo_div", "nce", "cm", "cv", "spread", "mastery", "anchor_cv"]:
285
+ v = ld.get(k, 0)
286
+ acc_dict[k] += v.item() if torch.is_tensor(v) else v
287
+ acc_dict["nce_acc"] += ld.get("nce_acc", 0)
288
+ acc_dict["cm_valid"] += ld.get("cm_valid", 0)
289
+ acc_dict["hard_neg"] += ld.get("hard_neg_cos", 0)
290
+ acc_dict["hard_pos"] += ld.get("hard_pos_cos", 0)
291
+ acc_dict["cv_fused"] += ld.get("cv_fused", 0)
292
+ acc_dict["cv_geo"] += ld.get("cv_geo", 0)
293
+ acc_dict["geo_acc"] += ld.get("geo_acc", 0)
294
+ acc_dict["margin"] += ld.get("margin", 0)
295
+ acc_dict["n"] += 1; gs += 1
296
+
297
+ if gs % 20 == 0:
298
+ writer.add_scalar("step/loss", loss.item(), gs)
299
+ writer.add_scalar("step/geo_acc", ld.get("geo_acc", 0), gs)
300
+ if mastery.active:
301
+ writer.add_scalar("step/mastery",
302
+ ld.get("mastery", torch.tensor(0)).item()
303
+ if torch.is_tensor(ld.get("mastery", 0))
304
+ else ld.get("mastery", 0), gs)
305
+ writer.add_scalar("step/margin", mastery.current_margin, gs)
306
+
307
+ if acc_dict["n"] % 10 == 0:
308
+ d = acc_dict["n"]
309
+ train_acc = 100 * acc_dict["correct"] / acc_dict["total"]
310
+ cvf = acc_dict["cv_fused"] / d
311
+ cvg = acc_dict["cv_geo"] / d
312
+ cmv = acc_dict["cm_valid"] / d
313
+ mst = acc_dict["mastery"] / d
314
+ ga = 100 * acc_dict["geo_acc"] / d
315
+ stage = "M" if mastery.active else "S1"
316
+ pbar.set_postfix(
317
+ loss=f"{acc_dict['loss']/d:.4f}",
318
+ acc=f"{train_acc:.1f}%",
319
+ ga=f"{ga:.0f}%",
320
+ cvf=f"{cvf:.4f}",
321
+ mst=f"{mst:.3f}",
322
+ mrg=f"{mastery.current_margin:.2f}",
323
+ stg=stage,
324
+ ordered=True)
325
+
326
+ elapsed = time.time() - t0
327
+ d = max(acc_dict["n"], 1)
328
+ train_acc = 100 * acc_dict["correct"] / acc_dict["total"]
329
+
330
+ writer.add_scalar("epoch/train_loss", acc_dict["loss"] / d, epoch + 1)
331
+ writer.add_scalar("epoch/train_acc", train_acc, epoch + 1)
332
+ writer.add_scalar("epoch/geo_acc", 100 * acc_dict["geo_acc"] / d, epoch + 1)
333
+ writer.add_scalar("epoch/geo_div", acc_dict["geo_div"] / d, epoch + 1)
334
+ writer.add_scalar("epoch/nce_acc", acc_dict["nce_acc"] / d, epoch + 1)
335
+ writer.add_scalar("epoch/cv_loss", acc_dict["cv"] / d, epoch + 1)
336
+ writer.add_scalar("epoch/cv_fused", acc_dict["cv_fused"] / d, epoch + 1)
337
+ writer.add_scalar("epoch/cv_geo", acc_dict["cv_geo"] / d, epoch + 1)
338
+ writer.add_scalar("epoch/anchor_cv", acc_dict["anchor_cv"] / d, epoch + 1)
339
+ writer.add_scalar("epoch/cm_valid", acc_dict["cm_valid"] / d, epoch + 1)
340
+ writer.add_scalar("epoch/margin", mastery.current_margin, epoch + 1)
341
+
342
+ # ── Validation ──
343
+ model.eval()
344
+ val_correct, val_total, val_loss_sum, val_n = 0, 0, 0, 0
345
+ val_geo_correct = 0
346
+ all_embs = []
347
+
348
+ with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
349
+ for images, labels_v in val_loader:
350
+ images = images.to(DEVICE, non_blocking=True)
351
+ labels_v = labels_v.to(DEVICE, non_blocking=True)
352
+ out = model(images, apply_autograd=False)
353
+ preds = out['logits'].argmax(dim=-1)
354
+ val_correct += (preds == labels_v).sum().item()
355
+ if 'geo_logits' in out:
356
+ geo_preds = out['geo_logits'].argmax(dim=-1)
357
+ val_geo_correct += (geo_preds == labels_v).sum().item()
358
+ val_total += labels_v.shape[0]
359
+ one_hot = F.one_hot(labels_v, NUM_CLASSES).float()
360
+ loss_v = F.binary_cross_entropy_with_logits(out['logits'], one_hot)
361
+ val_loss_sum += loss_v.item()
362
+ val_n += 1
363
+ all_embs.append(out['embedding'].float().cpu())
364
+
365
+ val_acc = 100 * val_correct / val_total
366
+ val_geo_acc = 100 * val_geo_correct / val_total
367
+ val_loss = val_loss_sum / max(val_n, 1)
368
+
369
+ # Quick CV check on val embeddings
370
+ embs = torch.cat(all_embs)
371
+ with torch.no_grad():
372
+ sample = embs[:2000].to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  vols = []
374
+ for _ in range(200):
375
+ idx = torch.randperm(2000)[:5]
376
+ pts = sample[idx].unsqueeze(0).float()
377
  gram = torch.bmm(pts, pts.transpose(1, 2))
378
  norms = torch.diagonal(gram, dim1=1, dim2=2)
379
  d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
380
  d2 = F.relu(d2)
381
+ cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
 
 
382
  cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
383
+ v2 = -torch.linalg.det(cm) / 9216
384
+ if v2[0].item() > 1e-20:
385
+ vols.append(v2[0].sqrt())
386
+ if len(vols) > 10:
387
+ vols_t = torch.stack(vols)
388
+ v_cv = (vols_t.std() / (vols_t.mean() + 1e-8)).item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  else:
390
+ v_cv = 0.0
391
+
392
+ # Anchor utilization
393
+ with torch.no_grad():
394
+ _, v_np = model.constellation.triangulate(
395
+ embs[:2000].to(DEVICE), training=False)
396
+ n_active = v_np.cpu().unique().numel()
397
+
398
+ writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
399
+ writer.add_scalar("epoch/val_geo_acc", val_geo_acc, epoch + 1)
400
+ writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
401
+ writer.add_scalar("epoch/val_anchors", n_active, epoch + 1)
402
+ writer.add_scalar("epoch/queue_max", mastery._current_max, epoch + 1)
403
+ writer.add_scalar("epoch/queue_size", mastery.size, epoch + 1)
404
+
405
+ # ── Adaptive mastery queue resize ──
406
+ mastery.update_size(train_acc, val_acc, epoch + 1)
407
+
408
+ mk = ""
409
+ if val_acc > best_acc:
410
+ best_acc = val_acc
411
+ torch.save({
412
+ "state_dict": model.state_dict(),
413
+ "config": model.config,
414
+ "epoch": epoch + 1,
415
+ "val_acc": val_acc,
416
+ "val_geo_acc": val_geo_acc,
417
+ "val_loss": val_loss,
418
+ "val_cv": v_cv,
419
+ "mastery": mastery.state_dict(),
420
+ }, "checkpoints/dual_stream_v6_best.pt")
421
+ mk = " β˜…"
422
+
423
+ if (epoch + 1) % 10 == 0:
424
+ torch.save({
425
+ "state_dict": model.state_dict(),
426
+ "config": model.config,
427
+ "epoch": epoch + 1,
428
+ "val_acc": val_acc,
429
+ "optimizer": optimizer.state_dict(),
430
+ }, f"checkpoints/dual_stream_v6_e{epoch+1:03d}.pt")
431
+
432
+ cv_m = acc_dict["cv"] / d
433
+ cvf = acc_dict["cv_fused"] / d
434
+ cvg = acc_dict["cv_geo"] / d
435
+ nce_a = acc_dict["nce_acc"] / d
436
+ cmv = acc_dict["cm_valid"] / d
437
+ mst_m = acc_dict["mastery"] / d
438
+ hn = acc_dict["hard_neg"] / d if mastery.active else 0
439
+ hp = acc_dict["hard_pos"] / d if mastery.active else 0
440
+ ga = 100 * acc_dict["geo_acc"] / d
441
+ gd = acc_dict["geo_div"] / d
442
+ mrg = mastery.current_margin
443
+ stage = "MASTERY" if mastery.active else "stage1"
444
+ print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
445
+ f"geo={ga:.0f}/{val_geo_acc:.0f}% "
446
+ f"loss={acc_dict['loss']/d:.4f}/{val_loss:.4f} "
447
+ f"cv={v_cv:.4f}(f={cvf:.5f} g={cvg:.5f}) "
448
+ f"gd={gd:.4f} cm={cmv:.0%} anch={n_active}/{N_ANCHORS} "
449
+ f"[{stage}] mst={mst_m:.3f} mrg={mrg:.2f} "
450
+ f"hn={hn:.3f} hp={hp:.3f} "
451
+ f"q={mastery.size}/{mastery._current_max} ({elapsed:.0f}s){mk}")
452
+
453
+ writer.close()
454
+ print(f"\n Best val accuracy: {best_acc:.1f}%")
455
+ print(f"\n{'='*60}")
456
+ print("DONE")
457
+ print(f"{'='*60}")