#!/usr/bin/env python3
"""
GeoLIP Tri-Stream ViT v8 — Geometric Arbitration (fixed)
==========================================================

v7→v8 changes:
  1. Uniform hypersphere orthogonal init for GAL anchors + constellation.
  2. Gate init at 1/(2*n_blocks) — geometry enters immediately.
  3. InfoNCE on emb_b (Stream B survives through contrastive, not BCE).
  4. InfoNCE weight on geo_emb raised — geo was starved.
  5. No residual scaling (per Phil).
  6. GAL update interval + lr controlled from the trainer.

Three processing paths:
  Stream A (CE loss):    self-attn + FFN, standard cross-entropy.
  Stream B (BCE + NCE):  self-attn + FFN, binary CE + InfoNCE.
  GAL (geometric):       KSimplex features, accumulated over time;
                         provides cross-attention to shared anchors.
"""

import math
from itertools import combinations

import torch
import torch.nn as nn
import torch.nn.functional as F


# ══════════════════════════════════════════════════════════════════
# UNIFORM HYPERSPHERE INIT
# ══════════════════════════════════════════════════════════════════

def uniform_hypersphere_init(n, d, n_iters=200, step=0.05):
    """
    Generate ``n`` well-spread unit vectors in ``R^d``.

    - ``n <= d``: rows of a random orthonormal set (QR of a Gaussian
      matrix), so every pairwise cosine is exactly 0.
    - ``n > d``: start from ``d`` orthonormal rows plus ``n - d`` random
      unit vectors, then run ``n_iters`` rounds of nearest-neighbour
      repulsion. In each round *every* point (the initial basis included)
      is pushed away from its current nearest neighbour by ``step`` and
      re-projected onto the unit sphere.

    Args:
        n: number of points.
        d: ambient dimension.
        n_iters: repulsion rounds (only used when ``n > d``).
        step: repulsion step size (only used when ``n > d``).

    Returns:
        (n, d) tensor whose rows have unit L2 norm.
    """
    if n <= d:
        # Reduced QR of a (d, n) Gaussian -> Q has n orthonormal columns.
        q, _ = torch.linalg.qr(torch.randn(d, n))
        return q.T.contiguous()

    q, _ = torch.linalg.qr(torch.randn(d, d))
    extra = F.normalize(torch.randn(n - d, d), dim=-1)
    vecs = torch.cat([q.T, extra], dim=0)  # (n, d)

    for _ in range(n_iters):
        sim = vecs @ vecs.T
        sim.fill_diagonal_(-2.0)  # below any real cosine -> never picks self
        nn_vec = vecs[sim.argmax(dim=1)]
        vecs = F.normalize(vecs - step * nn_vec, dim=-1)
    return vecs


# ══════════════════════════════════════════════════════════════════
# CAYLEY-MENGER + KSIMPLEX
# ══════════════════════════════════════════════════════════════════

class CMValidator(nn.Module):
    """
    Cayley–Menger squared volume of a k-simplex.

    ``forward(verts)`` takes vertices shaped (..., k+1, edim) and returns

    - ``d2_pairs``: (..., C(k+1, 2)) squared pairwise edge lengths, in
      ``itertools.combinations`` order;
    - ``vol2``: (...) squared k-volume,
      ``(-1)^(k+1) / (2^k (k!)^2) * det(CM)`` with CM the bordered
      squared-distance matrix. A negative ``vol2`` means the vertex set
      is not a valid (embeddable) simplex.
    """

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1
        pairs = list(combinations(range(self._nv), 2))
        self._npairs = len(pairs)
        self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))
        self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))
        sign = (-1.0) ** (k + 1)
        fact = math.factorial(k)
        self._prefactor = sign / ((2.0 ** k) * (fact ** 2))

    def forward(self, verts):
        gram = torch.einsum('...ve,...we->...vw', verts, verts)
        norms = torch.diagonal(gram, dim1=-2, dim2=-1)
        # |a-b|^2 = |a|^2 + |b|^2 - 2<a,b>; relu clips tiny negative round-off.
        d2_mat = F.relu(norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram)
        d2_pairs = d2_mat[..., self._pi, self._pj]

        shape = d2_mat.shape[:-2]
        V = d2_mat.shape[-1]
        cm = torch.zeros(*shape, V + 1, V + 1, device=d2_mat.device, dtype=d2_mat.dtype)
        cm[..., 0, 1:] = 1.0
        cm[..., 1:, 0] = 1.0
        cm[..., 1:, 1:] = d2_mat
        # det in fp32: fp16/bf16 determinants lose too much precision.
        vol2 = self._prefactor * torch.linalg.det(cm.float())
        vol2 = vol2.to(d2_pairs.dtype)
        return d2_pairs, vol2


class KSimplexChannel(nn.Module):
    """
    Per-token k-simplex geometry features.

    A fixed unit-edge simplex template (buffer) is deformed by a learned
    linear map of the input, scaled by ``BASE_DEFORM``. The deformed simplex
    is summarised by its Cayley–Menger invariants.

    ``forward(x)`` with x (..., in_dim) returns
    ``(LayerNorm([d2_pairs, vol2]), vol2)`` with shapes (..., out_dim) and
    (...), where ``out_dim = C(k+1, 2) + 1``.
    """

    BASE_DEFORM = 0.05

    def __init__(self, k, in_dim, edim):
        super().__init__()
        self._k = k
        self._nv = k + 1
        self._edim = edim
        self._cm = CMValidator(k)
        self._out_dim = self._cm._npairs + 1
        template = self._make_regular_simplex(k, edim)
        self.register_buffer('_template', template)
        self._to_deform = nn.Linear(in_dim, self._nv * edim)
        self._norm = nn.LayerNorm(self._out_dim)

    @staticmethod
    def _make_regular_simplex(k, edim):
        """
        Centered simplex with edge(0, 1) == 1.

        Regular only when ``k + 1 <= edim`` (scaled standard basis vectors).
        Otherwise the extra vertices are random unit vectors.
        """
        nv = k + 1
        verts = torch.zeros(nv, edim)
        for i in range(min(nv, edim)):
            verts[i, i] = 1.0
        if nv > edim:
            for i in range(edim, nv):
                v = torch.randn(edim)
                verts[i] = v / (v.norm() + 1e-8)
        verts = verts - verts.mean(dim=0, keepdim=True)
        edge_len = (verts[0] - verts[1]).norm().clamp(min=1e-8)
        return verts / edge_len

    @property
    def out_dim(self):
        return self._out_dim

    def forward(self, x):
        deform = self._to_deform(x).unflatten(-1, (self._nv, self._edim))
        verts = self._template + self.BASE_DEFORM * deform
        d2, vol2 = self._cm(verts)
        geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)
        return self._norm(geo), vol2


# ══════════════════════════════════════════════════════════════════
# CONSTELLATION + PATCHWORK
# ══════════════════════════════════════════════════════════════════

class Constellation(nn.Module):
    """
    Learnable anchor set on the unit sphere used to triangulate embeddings.

    Anchors are initialised with ``uniform_hypersphere_init``. During
    training, ``anchor_drop`` randomly hides a fraction of anchors when
    picking the nearest one.
    """

    def __init__(self, n_anchors, dim, anchor_drop=0.0):
        super().__init__()
        # ── v8: uniform hypersphere init ──
        init_vecs = uniform_hypersphere_init(n_anchors, dim)
        self.anchors = nn.Parameter(init_vecs)
        self.anchor_drop = anchor_drop

        # Diagnostic print of initial pairwise spread.
        with torch.no_grad():
            an = F.normalize(init_vecs, dim=-1)
            sim = an @ an.T
            mask = ~torch.eye(n_anchors, dtype=torch.bool)
            off = sim[mask]
            print(f" ✓ Constellation: {n_anchors}×{dim} uniform hypersphere")
            # A single anchor has no off-diagonal pairs; .max() would raise.
            if off.numel() > 0:
                print(f" pairwise cos: mean={off.mean():.4f} max={off.max():.4f}")

    def triangulate(self, emb, training=False):
        """
        Cosine distances to the anchors, plus the index of the nearest anchor.

        Args:
            emb: (B, dim) embeddings. ``tri`` is a true cosine distance only
                if rows are unit-norm (presumably normalised by the caller).
            training: when True and ``anchor_drop > 0``, a random subset of
                anchors is used. ``tri`` then has fewer columns, and
                ``nearest`` is mapped back to full-set indices.

        Returns:
            (tri, nearest): tri = 1 - cos, shape (B, n_used_anchors);
            nearest: (B,) long indices into the full anchor set.
        """
        anchors = F.normalize(self.anchors, dim=-1)
        if training and self.anchor_drop > 0:
            mask = torch.rand(anchors.shape[0], device=anchors.device) > self.anchor_drop
            if mask.sum() < 2:
                mask[:2] = True
            anchors = anchors[mask]
            cos = emb @ anchors.T
            tri = 1.0 - cos
            _, nearest_local = cos.max(dim=-1)
            nearest = mask.nonzero(as_tuple=True)[0][nearest_local]
        else:
            cos = emb @ anchors.T
            tri = 1.0 - cos
            _, nearest = cos.max(dim=-1)
        return tri, nearest


class Patchwork(nn.Module):
    """
    Split triangulation columns into ``n_comp`` compartments and embed each.

    Anchor ``i`` belongs to compartment ``i % n_comp``. Each compartment
    maps its anchor distances to ``d_comp`` features. Outputs are
    concatenated to (B, n_comp * d_comp).
    """

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        self.d_comp = d_comp
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)
        # Round-robin assignment gives the first (n_anchors % n_comp)
        # compartments one extra anchor. Size each compartment's input from
        # its real count; n_anchors // n_comp breaks non-divisible configs.
        counts = torch.bincount(self.asgn, minlength=n_comp).tolist()
        self.comps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(count, d_comp * 2), nn.GELU(),
                nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))
            for count in counts])

    def forward(self, tri):
        # tri: (B, n_anchors) — full anchor set, column order = anchor index.
        return torch.cat(
            [comp(tri[:, self.asgn == k]) for k, comp in enumerate(self.comps)], -1)


# ══════════════════════════════════════════════════════════════════
# EMBEDDING AUTOGRAD
# ══════════════════════════════════════════════════════════════════

class EmbeddingAutograd(torch.autograd.Function):
    """
    Identity in the forward pass; reshapes the gradient in the backward pass.

    Backward:
      1. Split the incoming gradient into a radial part (along the
         normalised ``embedding``) and a tangential part; keep the tangential
         part plus ``(1 - tang)`` × the radial part.
      2. If ``sep > 0``: find the anchor with highest cosine to the
         embedding. Where the corrected gradient has a positive projection
         onto that anchor, subtract ``sep`` × that projection.

    Only ``x`` receives a gradient. ``embedding``, ``anchors``, ``tang`` and
    ``sep`` get ``None``, so anchors are not trained through this path.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang
        ctx.sep = sep
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        emb_n = F.normalize(embedding.detach().float(), dim=-1)
        anchors_n = F.normalize(anchors.detach().float(), dim=-1)
        grad_f = grad_output.float()

        radial = (grad_f * emb_n).sum(-1, keepdim=True) * emb_n
        corrected = (grad_f - radial) + (1.0 - ctx.tang) * radial

        if ctx.sep > 0:
            cos_to = emb_n @ anchors_n.T
            nearest = anchors_n[cos_to.argmax(dim=-1)]
            toward = (corrected * nearest).sum(-1, keepdim=True)
            corrected = corrected - ctx.sep * (toward > 0).float() * toward * nearest

        return corrected.to(grad_output.dtype), None, None, None, None


# ══════════════════════════════════════════════════════════════════
# PROCRUSTES ALIGNMENT
# ══════════════════════════════════════════════════════════════════

def procrustes_align(source, target, whiten=False):
    """
    Orthogonal Procrustes: proper rotation R minimising ||source_c @ R - target_c||.

    Both inputs are mean-centred (and, if ``whiten``, scaled to unit
    per-column std) before solving. The last singular direction is flipped
    when needed so that det(R) = +1.

    Args:
        source, target: (N, D) point sets with matching rows.
        whiten: standardise columns before alignment.

    Returns:
        (R, score): R is (D, D); score is the sum of singular values of
        ``source_c.T @ target_c`` as a Python float.
    """
    source_c = source.float() - source.float().mean(0, keepdim=True)
    target_c = target.float() - target.float().mean(0, keepdim=True)
    if whiten:
        source_c = source_c / (source_c.std(0, keepdim=True) + 1e-8)
        target_c = target_c / (target_c.std(0, keepdim=True) + 1e-8)
    M = (source_c.T @ target_c).float()
    U, S, Vt = torch.linalg.svd(M)
    d = torch.ones(U.shape[0], device=U.device, dtype=U.dtype)
    d[-1] = torch.det(U @ Vt).sign()  # reflection fix -> det(R) = +1
    R = U @ torch.diag(d) @ Vt
    return R, S.sum().item()


# ══════════════════════════════════════════════════════════════════
# SIMPLEX BUFFER
# ══════════════════════════════════════════════════════════════════

class SimplexBuffer:
    """
    FIFO store of (feature, label) rows, capped at ``max_size`` newest rows.

    Used to compute per-class feature centroids.
    """

    def __init__(self, dim, max_size=50000, device='cuda'):
        self.dim = dim
        self.max_size = max_size
        self.device = device
        self._feats = None
        self._labels = None

    def push(self, feats, labels):
        feats = feats.detach().to(self.device)
        labels = labels.detach().to(self.device)
        if self._feats is None:
            self._feats = feats
            self._labels = labels
        else:
            self._feats = torch.cat([self._feats, feats], 0)[-self.max_size:]
            self._labels = torch.cat([self._labels, labels], 0)[-self.max_size:]

    @property
    def size(self):
        return 0 if self._feats is None else self._feats.shape[0]

    def class_centroids(self, num_classes):
        """
        (num_classes, dim) mean feature per class.

        Returns None when the buffer holds fewer than ``10 * num_classes``
        rows or any class has no rows.
        """
        if self._feats is None or self.size < num_classes * 10:
            return None
        centroids = []
        for c in range(num_classes):
            mask = self._labels == c
            if mask.sum() == 0:
                return None
            centroids.append(self._feats[mask].mean(0))
        return torch.stack(centroids)


# ══════════════════════════════════════════════════════════════════
# GAL — v8: uniform hypersphere anchors
# ══════════════════════════════════════════════════════════════════

class GAL(nn.Module):
    """
    Shared geometric arbitration layer.

    Holds non-trainable GAL anchors (buffer, uniform hypersphere init), the
    KSimplex feature channel, a lift back to ``stream_dim``, and the
    projection producing the anchor keys/values for cross-attention.

    NOTE(review): ``n_heads`` and ``dropout`` are accepted but unused here.
    """

    def __init__(self, stream_dim, n_gal_anchors, n_heads,
                 ksimplex_k=4, ksimplex_edim=8, dropout=0.1):
        super().__init__()
        self.stream_dim = stream_dim
        self.n_gal_anchors = n_gal_anchors

        # ── v8: uniform hypersphere init for anchors ──
        init_anchors = uniform_hypersphere_init(n_gal_anchors, stream_dim)
        self.register_buffer('gal_anchors', init_anchors)

        with torch.no_grad():
            an = F.normalize(init_anchors, dim=-1)
            sim = an @ an.T
            mask = ~torch.eye(n_gal_anchors, dtype=torch.bool)
            off = sim[mask]
            print(f" ✓ GAL anchors: {n_gal_anchors}×{stream_dim} "
                  f"uniform hypersphere")
            # A single anchor has no off-diagonal pairs; .max() would raise.
            if off.numel() > 0:
                print(f" pairwise cos: mean={off.mean():.4f} "
                      f"max={off.max():.4f}")

        self.ksimplex = KSimplexChannel(
            k=ksimplex_k, in_dim=stream_dim, edim=ksimplex_edim)
        self.geo_lift = nn.Sequential(
            nn.Linear(self.ksimplex.out_dim, stream_dim), nn.GELU())
        self.anchor_proj = nn.Sequential(
            nn.Linear(stream_dim, stream_dim), nn.LayerNorm(stream_dim))

    @torch.no_grad()
    def rotate_anchors(self, rotation_matrix):
        """In-place right-multiply the anchors by ``rotation_matrix``."""
        self.gal_anchors.copy_(
            (self.gal_anchors @ rotation_matrix).contiguous())

    def get_anchor_kv(self):
        """(n_gal_anchors, stream_dim) projected anchors for use as K/V."""
        return self.anchor_proj(self.gal_anchors)


class GALBlock(nn.Module):
    """
    Per-layer GAL injection with a non-zero gate init.

    Each stream cross-attends to the shared anchor K/V. The attention output
    plus a projection of the lifted geometry features is added back, scaled
    by a learnable scalar gate per stream.

    v8: gates start at 1/(2*n_blocks) so geometry enters immediately.
    NOTE(review): ``n_gal_anchors`` is accepted but unused here.
    """

    def __init__(self, stream_dim, n_gal_anchors, n_heads,
                 gate_init=0.055, dropout=0.1):
        super().__init__()
        self.cross_attn_a = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn_b = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.norm_ga = nn.LayerNorm(stream_dim)
        self.norm_gb = nn.LayerNorm(stream_dim)
        self.lift_proj_a = nn.Linear(stream_dim, stream_dim)
        self.lift_proj_b = nn.Linear(stream_dim, stream_dim)
        # ── v8: init at small positive value, NOT zero ──
        self.gate_a = nn.Parameter(torch.tensor(gate_init))
        self.gate_b = nn.Parameter(torch.tensor(gate_init))

    def forward(self, stream_a, stream_b, anchor_kv, geo_lifted):
        # stream_*: (B, P, D); anchor_kv: (n_anchors, D); geo_lifted: (B, P, D)
        B = stream_a.shape[0]
        kv = anchor_kv.unsqueeze(0).expand(B, -1, -1)

        qa = self.norm_ga(stream_a)
        ha, _ = self.cross_attn_a(qa, kv, kv, need_weights=False)
        qb = self.norm_gb(stream_b)
        hb, _ = self.cross_attn_b(qb, kv, kv, need_weights=False)

        stream_a = stream_a + self.gate_a * (ha + self.lift_proj_a(geo_lifted))
        stream_b = stream_b + self.gate_b * (hb + self.lift_proj_b(geo_lifted))
        return stream_a, stream_b


# ══════════════════════════════════════════════════════════════════
# TRI-STREAM BLOCK
# ══════════════════════════════════════════════════════════════════

class TriStreamBlock(nn.Module):
    """
    One layer of the tri-stream network.

    Pre-norm self-attention + FFN on each of stream A and stream B
    independently. Then the KSimplex geometry of ``LayerNorm(A + B)`` is
    computed per token through the shared ``gal``, lifted, and injected
    into both streams via ``GALBlock``.

    ``forward`` returns ``(stream_a, stream_b, geo_feats, vol2, geo_lifted)``
    with geo_feats (B, P, ksimplex.out_dim), vol2 (B, P), and
    geo_lifted (B, P, D).
    """

    def __init__(self, stream_dim, n_gal_anchors, n_heads,
                 gate_init=0.055, dropout=0.1):
        super().__init__()
        # Stream A
        self.norm_a1 = nn.LayerNorm(stream_dim)
        self.attn_a = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.norm_a2 = nn.LayerNorm(stream_dim)
        self.ffn_a = nn.Sequential(
            nn.Linear(stream_dim, stream_dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(stream_dim * 4, stream_dim), nn.Dropout(dropout))

        # Stream B
        self.norm_b1 = nn.LayerNorm(stream_dim)
        self.attn_b = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.norm_b2 = nn.LayerNorm(stream_dim)
        self.ffn_b = nn.Sequential(
            nn.Linear(stream_dim, stream_dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(stream_dim * 4, stream_dim), nn.Dropout(dropout))

        # GAL block — v8: gate_init passed through
        self.gal_block = GALBlock(
            stream_dim, n_gal_anchors, n_heads,
            gate_init=gate_init, dropout=dropout)
        self.geo_combine_norm = nn.LayerNorm(stream_dim)

    def forward(self, stream_a, stream_b, gal, anchor_kv):
        B, P, D = stream_a.shape

        # Stream A
        h = self.norm_a1(stream_a)
        h, _ = self.attn_a(h, h, h, need_weights=False)
        stream_a = stream_a + h
        stream_a = stream_a + self.ffn_a(self.norm_a2(stream_a))

        # Stream B
        h = self.norm_b1(stream_b)
        h, _ = self.attn_b(h, h, h, need_weights=False)
        stream_b = stream_b + h
        stream_b = stream_b + self.ffn_b(self.norm_b2(stream_b))

        # GAL: per-token simplex geometry of the combined streams
        geo_input = self.geo_combine_norm(stream_a + stream_b)
        flat = geo_input.reshape(B * P, D)
        geo_feats, vol2 = gal.ksimplex(flat)
        geo_feats = geo_feats.reshape(B, P, -1)
        vol2 = vol2.reshape(B, P)
        geo_lifted = gal.geo_lift(geo_feats)

        stream_a, stream_b = self.gal_block(
            stream_a, stream_b, anchor_kv, geo_lifted)

        return stream_a, stream_b, geo_feats, vol2, geo_lifted
# ══════════════════════════════════════════════════════════════════
# TRI-STREAM VIT v8
# ══════════════════════════════════════════════════════════════════

class TriStreamViT(nn.Module):
    """
    Tri-stream ViT: two transformer streams (A: CE, B: BCE + InfoNCE) plus a
    shared geometric arbitration layer (GAL) injected after every block.

    Each stream is pooled and projected to the unit sphere (``emb_a``,
    ``emb_b``). The lifted GAL features, summed across blocks, give
    ``geo_emb``. The normalised sum of the three is the combined
    ``embedding``, which is triangulated against a ``Constellation`` and fed
    through a ``Patchwork`` into the A/B classifiers.
    """

    def __init__(
        self,
        num_classes=10,
        img_size=32,
        patch_size=4,
        embed_dim=384,
        stream_dim=192,
        n_blocks=9,
        n_heads=8,
        output_dim=256,
        n_anchors=128,
        n_gal_anchors=64,
        n_comp=16,
        d_comp=128,
        anchor_drop=0.10,
        cv_target=0.22,
        ksimplex_k=4,
        ksimplex_edim=8,
        dropout=0.1,
        infonce_temp=0.07,
        infonce_weight=0.1,
        bce_weight=1.0,
        cm_weight=0.1,
        cv_weight=0.1,
        autograd_tang=1.0,
        autograd_sep=0.1,
        enable_autograd=True,
        label_smoothing=0.1,
        # ── v8: stream B + geo InfoNCE weights (separate) ──
        stream_b_nce_weight=0.5,
        geo_nce_weight=0.5,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_patches = (img_size // patch_size) ** 2
        self.stream_dim = stream_dim
        self.output_dim = output_dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp
        self.infonce_weight = infonce_weight
        self.bce_weight = bce_weight
        self.cm_weight = cm_weight
        self.cv_weight = cv_weight
        self.autograd_tang = autograd_tang
        self.autograd_sep = autograd_sep
        self.enable_autograd = enable_autograd
        self.label_smoothing = label_smoothing
        self.stream_b_nce_weight = stream_b_nce_weight
        self.geo_nce_weight = geo_nce_weight
        # Snapshot of constructor arguments. Must stay before any other local
        # is bound: locals() here is meant to contain only the parameters.
        self.config = {k: v for k, v in locals().items()
                       if k != 'self' and not k.startswith('_')}

        # ── v8: gate init from block count ──
        gate_init = 1.0 / (2.0 * n_blocks)  # ~0.055 for 9 blocks
        print(f" Gate init: {gate_init:.4f} (1/(2×{n_blocks}))")

        # Shared patch embedding
        self.patch_embed = nn.Conv2d(
            3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches, embed_dim) * 0.02)

        # Stream projections
        self.proj_a = nn.Sequential(
            nn.Linear(embed_dim, stream_dim), nn.LayerNorm(stream_dim))
        self.proj_b = nn.Sequential(
            nn.Linear(embed_dim, stream_dim), nn.LayerNorm(stream_dim))

        # Shared GAL
        self.gal = GAL(stream_dim, n_gal_anchors, n_heads,
                       ksimplex_k, ksimplex_edim, dropout)

        # Tri-stream blocks — v8: pass gate_init
        self.blocks = nn.ModuleList([
            TriStreamBlock(stream_dim, n_gal_anchors, n_heads,
                           gate_init=gate_init, dropout=dropout)
            for _ in range(n_blocks)])

        # Output norms
        self.norm_a = nn.LayerNorm(stream_dim)
        self.norm_b = nn.LayerNorm(stream_dim)

        # Sphere projections
        self.proj_sphere_a = nn.Sequential(
            nn.Linear(stream_dim, output_dim), nn.LayerNorm(output_dim))
        self.proj_sphere_b = nn.Sequential(
            nn.Linear(stream_dim, output_dim), nn.LayerNorm(output_dim))
        self.proj_sphere_geo = nn.Sequential(
            nn.Linear(stream_dim, output_dim), nn.LayerNorm(output_dim))

        # Constellation + Patchwork (uniform hypersphere via Constellation)
        self.constellation = Constellation(n_anchors, output_dim, anchor_drop)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        pw_dim = n_comp * d_comp

        # Classifiers
        self.classifier_a = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(dropout),
            nn.Linear(pw_dim, num_classes))
        self.classifier_b = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(dropout),
            nn.Linear(pw_dim, num_classes))
        self.geo_classifier = nn.Sequential(
            nn.Linear(output_dim, output_dim), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(output_dim, num_classes))

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal(0.02) Linear weights, zero biases, identity LayerNorms."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, apply_autograd=True):
        """
        Run both streams and the GAL path on an image batch.

        Args:
            x: image batch fed to a 3-channel Conv2d patch embed.
                Presumably (B, 3, img_size, img_size); pos_embed fixes the
                patch count.
            apply_autograd: allow the ``EmbeddingAutograd`` gradient
                reshaping (still requires ``self.training`` and
                ``self.enable_autograd``).

        Returns:
            dict of embeddings, logits, geometry diagnostics, and gate values
            (``gates_a`` / ``gates_b`` as lists of Python floats).
        """
        output = {}

        # Patch embedding -> (B, P, embed_dim)
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)
        tokens = tokens + self.pos_embed

        # Split
        stream_a = self.proj_a(tokens)
        stream_b = self.proj_b(tokens)

        # Anchor K/V computed once and shared by every block
        anchor_kv = self.gal.get_anchor_kv()

        # Process through blocks
        all_geo_feats = []
        all_vol2 = []
        geo_accum = torch.zeros_like(stream_a)
        for block in self.blocks:
            stream_a, stream_b, geo_feats, vol2, geo_lifted = block(
                stream_a, stream_b, self.gal, anchor_kv)
            all_geo_feats.append(geo_feats)
            all_vol2.append(vol2)
            geo_accum = geo_accum + geo_lifted

        output['geo_feats'] = all_geo_feats[-1]
        output['all_geo_feats'] = torch.stack(all_geo_feats)
        output['vol2'] = torch.stack(all_vol2)

        # Norms
        stream_a = self.norm_a(stream_a)
        stream_b = self.norm_b(stream_b)

        # Pool over patches
        pool_a = stream_a.mean(dim=1)
        pool_b = stream_b.mean(dim=1)
        pool_geo = geo_accum.mean(dim=1)

        # → sphere
        emb_a = F.normalize(self.proj_sphere_a(pool_a), dim=-1)
        emb_b = F.normalize(self.proj_sphere_b(pool_b), dim=-1)
        geo_emb = F.normalize(self.proj_sphere_geo(pool_geo), dim=-1)

        # Combined (built from the un-wrapped sub-embeddings)
        emb = F.normalize(emb_a + emb_b + geo_emb, dim=-1)

        # EmbeddingAutograd — v8: wraps the combined embedding, emb_b and
        # geo_emb. emb_a is deliberately left untouched.
        if apply_autograd and self.training and self.enable_autograd:
            anchors = self.constellation.anchors
            emb = EmbeddingAutograd.apply(
                emb, emb, anchors, self.autograd_tang, self.autograd_sep)
            emb_b = EmbeddingAutograd.apply(
                emb_b, emb_b, anchors, self.autograd_tang, self.autograd_sep)
            geo_emb = EmbeddingAutograd.apply(
                geo_emb, geo_emb, anchors, self.autograd_tang, self.autograd_sep)

        output['embedding'] = emb
        output['emb_a'] = emb_a
        output['emb_b'] = emb_b
        output['geo_emb'] = geo_emb
        output['pool_geo'] = pool_geo

        # Constellation + Patchwork (Patchwork always sees the full anchor set)
        tri_full, nearest_full = self.constellation.triangulate(
            emb, training=False)
        pw = self.patchwork(tri_full)
        output['triangulation'] = tri_full

        # Anchor-dropped nearest only differs when anchor_drop > 0. Skipping
        # the call otherwise avoids recomputing an identical result.
        if self.training and self.constellation.anchor_drop > 0:
            _, nearest = self.constellation.triangulate(emb, training=True)
        else:
            nearest = nearest_full
        output['nearest'] = nearest

        # Classifiers
        logits_a = self.classifier_a(torch.cat([pw, emb_a], dim=-1))
        logits_b = self.classifier_b(torch.cat([pw, emb_b], dim=-1))
        geo_logits = self.geo_classifier(geo_emb)
        output['logits_a'] = logits_a
        output['logits_b'] = logits_b
        output['geo_logits'] = geo_logits

        # Gate monitoring: one device->host transfer instead of 2*n_blocks .item() syncs.
        n_blocks = len(self.blocks)
        with torch.no_grad():
            gates = torch.stack(
                [b.gal_block.gate_a for b in self.blocks]
                + [b.gal_block.gate_b for b in self.blocks]).tolist()
        output['gates_a'] = gates[:n_blocks]
        output['gates_b'] = gates[n_blocks:]

        return output

    # ──────────────────────────────────────────────────────────
    # PROCRUSTES ANCHOR UPDATE
    # ──────────────────────────────────────────────────────────

    @torch.no_grad()
    def update_gal_anchors(self, simplex_buffer, lr=0.015, whiten=False):
        """
        Nudge GAL anchors toward a Procrustes rotation onto class centroids.

        Each class centroid from ``simplex_buffer`` is matched to its
        highest-cosine anchor. The rotation aligning those matched anchors to
        the centroids is applied to *all* anchors. Anchors move a fraction
        ``lr`` of the way there and are re-normalised.

        Returns:
            Procrustes score (sum of singular values), or None when the
            buffer cannot yet provide centroids for every class.
        """
        with torch.amp.autocast("cuda", enabled=False):
            centroids = simplex_buffer.class_centroids(self.num_classes)
            if centroids is None:
                return None
            anchors = self.gal.gal_anchors.float()
            centroid_n = F.normalize(centroids.float(), dim=-1)
            anchor_n = F.normalize(anchors, dim=-1)
            cos = centroid_n @ anchor_n.T
            matched_idx = cos.argmax(dim=1)
            matched_anchors = anchors[matched_idx]

            R, score = procrustes_align(
                matched_anchors, centroids.float(), whiten=whiten)
            rotated = anchors @ R
            new_anchors = F.normalize(
                anchors + lr * (rotated - anchors), dim=-1)
            self.gal.gal_anchors.copy_(
                new_anchors.to(self.gal.gal_anchors.dtype))
            return score

    # ──────────────────────────────────────────────────────────
    # LOSS — v8: InfoNCE on emb_b + stronger geo_emb signal
    # ──────────────────────────────────────────────────────────

    @staticmethod
    def _info_nce(anchor, positive, temperature):
        """One-directional InfoNCE (row i ↔ row i); returns (loss, top-1 acc)."""
        labels = torch.arange(anchor.shape[0], device=anchor.device)
        sim = anchor @ positive.T / temperature
        loss = F.cross_entropy(sim, labels)
        acc = (sim.argmax(1) == labels).float().mean().item()
        return loss, acc

    def compute_loss(self, output, targets, output_aug=None,
                     mastery_queue=None):
        """
        Total training loss plus a dict of components and diagnostics.

        Components: CE (stream A), BCE (stream B and geo classifier, with
        label smoothing), InfoNCE on the combined embedding / emb_b / geo_emb
        (only when ``output_aug`` is given), mastery margin (only when the
        queue is active), CM validity, CV target, and anchor spread.

        NOTE(review): CE and mastery terms are scaled by ``bce_weight``
        (there is no separate CE weight) — confirm this is intended.
        """
        loss_dict = {}
        emb = output['embedding']
        emb_b = output['emb_b']
        geo_emb = output['geo_emb']
        B = emb.shape[0]
        is_mastery = mastery_queue is not None and mastery_queue.active

        # ── CE on Stream A ──
        l_ce = F.cross_entropy(output['logits_a'], targets)
        loss_dict['ce'] = l_ce
        loss_dict['acc_a'] = (output['logits_a'].argmax(-1) == targets).float().mean().item()

        # ── BCE on Stream B ──
        one_hot = F.one_hot(targets, self.num_classes).float()
        ls = self.label_smoothing
        if ls > 0:
            one_hot_smooth = one_hot * (1.0 - ls) + ls / self.num_classes
        else:
            one_hot_smooth = one_hot
        l_bce = F.binary_cross_entropy_with_logits(
            output['logits_b'], one_hot_smooth)
        loss_dict['bce'] = l_bce
        loss_dict['acc_b'] = (output['logits_b'].argmax(-1) == targets).float().mean().item()

        # ── Geo classifier BCE ──
        l_geo_bce = F.binary_cross_entropy_with_logits(
            output['geo_logits'], one_hot_smooth)
        loss_dict['geo_bce'] = l_geo_bce
        loss_dict['geo_acc'] = (output['geo_logits'].argmax(-1) == targets).float().mean().item()

        # ── InfoNCE — v8: on combined, emb_b, AND geo_emb ──
        if output_aug is not None:
            l_nce, nce_acc = self._info_nce(
                emb, output_aug['embedding'], self.infonce_temp)
            loss_dict['nce'] = l_nce
            loss_dict['nce_acc'] = nce_acc

            # ── v8: Stream B InfoNCE (this is what keeps B alive) ──
            emb_b_aug = output_aug.get('emb_b')
            if emb_b_aug is not None:
                l_nce_b, nce_b_acc = self._info_nce(
                    emb_b, emb_b_aug, self.infonce_temp)
                loss_dict['nce_b'] = l_nce_b
                loss_dict['nce_b_acc'] = nce_b_acc

            # ── v8: Geo InfoNCE (this is what feeds the geo path) ──
            geo_emb_aug = output_aug.get('geo_emb')
            if geo_emb_aug is not None:
                l_geo_nce, geo_nce_acc = self._info_nce(
                    geo_emb, geo_emb_aug, self.infonce_temp)
                loss_dict['geo_nce'] = l_geo_nce
                loss_dict['geo_nce_acc'] = geo_nce_acc

        # ── Mastery: hardest-negative vs hardest-positive margin vs queue ──
        if is_mastery:
            q_emb, q_labels = mastery_queue.get()
            if q_emb is not None and q_emb.shape[0] >= B:
                cross_sim = emb @ q_emb.T
                same_mask = targets.unsqueeze(1) == q_labels.unsqueeze(0)
                hn_sim = cross_sim.clone()
                hn_sim[same_mask] = -1e9
                hn_cos = hn_sim.max(dim=1).values
                hp_sim = cross_sim.clone()
                hp_sim[~same_mask] = 1e9
                hp_cos = hp_sim.min(dim=1).values
                # Rows need at least one positive and one negative in the queue.
                valid = same_mask.any(1) & (~same_mask).any(1)
                if valid.sum() > 0:
                    margin = mastery_queue.current_margin
                    l_mastery = F.relu(
                        hn_cos[valid] - hp_cos[valid] + margin).mean()
                    loss_dict['mastery'] = l_mastery
                    loss_dict['hard_neg_cos'] = hn_cos[valid].mean().item()
                    loss_dict['hard_pos_cos'] = hp_cos[valid].mean().item()
                    loss_dict['margin'] = margin
            mastery_queue.push(emb.detach(), targets.detach())

        # ── CM validity: penalise negative squared simplex volumes ──
        vol2 = output['vol2']
        l_cm = F.relu(-vol2).mean()
        loss_dict['cm'] = l_cm
        loss_dict['cm_valid'] = (vol2 > 0).float().mean().item()

        # ── CV on combined + geo ──
        l_cv_main = self._cv_loss_fast(emb, target=self.cv_target)
        l_cv_geo = self._cv_loss_fast(geo_emb, target=self.cv_target)
        l_cv = l_cv_main + l_cv_geo
        loss_dict['cv'] = l_cv
        loss_dict['cv_main'] = l_cv_main.item() if torch.is_tensor(l_cv_main) else l_cv_main
        loss_dict['cv_geo'] = l_cv_geo.item() if torch.is_tensor(l_cv_geo) else l_cv_geo

        # ── Anchor spread: penalise positive pairwise cosines ──
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        anchor_sim = anchors_n @ anchors_n.T
        mask_a = ~torch.eye(anchors_n.shape[0], dtype=torch.bool,
                            device=anchors_n.device)
        l_spread = F.relu(anchor_sim[mask_a] - 0.0).mean()
        loss_dict['spread'] = l_spread

        # ── Combine — v8: explicit weights for B and geo NCE ──
        loss = (l_ce * self.bce_weight
                + l_bce * self.bce_weight
                + l_geo_bce * self.bce_weight
                + loss_dict.get('nce', 0.0) * self.infonce_weight
                + loss_dict.get('nce_b', 0.0) * self.stream_b_nce_weight
                + loss_dict.get('geo_nce', 0.0) * self.geo_nce_weight
                + loss_dict.get('mastery', 0.0) * self.bce_weight
                + l_cm * self.cm_weight
                + l_cv * self.cv_weight
                + l_spread * 0.001)

        loss_dict['total'] = loss
        return loss, loss_dict

    @staticmethod
    def _cv_loss_fast(emb, target=0.22, n_samples=64, n_points=5):
        """
        Squared deviation of the coefficient of variation of random simplex
        volumes from ``target``.

        Draws ``n_samples`` random ``n_points``-subsets of the first
        min(B, 512) rows of ``emb``. Each subset's (n_points - 1)-volume is
        computed via Cayley–Menger, batched into a single determinant call.
        Subsets with vol² <= 1e-20 are discarded. Returns 0 when B <
        n_points or fewer than 5 subsets survive.
        """
        B = emb.shape[0]
        if B < n_points or n_samples <= 0:
            return torch.tensor(0.0, device=emb.device)

        pool = min(B, 512)
        # One randperm per sample, exactly like a per-sample loop would draw.
        idx = torch.stack([
            torch.randperm(pool, device=emb.device)[:n_points]
            for _ in range(n_samples)])                      # (S, N)
        pts = emb[idx]                                       # (S, N, D)
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)

        N = n_points
        cm = torch.zeros(n_samples, N + 1, N + 1, device=emb.device, dtype=emb.dtype)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        k = N - 1
        sign = (-1.0) ** (k + 1)
        prefactor = sign / ((2.0 ** k) * (math.factorial(k) ** 2))
        vol2 = prefactor * torch.linalg.det(cm.float())      # (S,)

        valid = vol2 > 1e-20
        if int(valid.sum()) < 5:
            return torch.tensor(0.0, device=emb.device)
        vols_t = vol2[valid].to(emb.dtype).sqrt()
        cv = vols_t.std() / (vols_t.mean() + 1e-8)
        return (cv - target).pow(2)


# ══════════════════════════════════════════════════════════════════
# MASTERY QUEUE
# ══════════════════════════════════════════════════════════════════

class MasteryQueue:
    """
    FIFO queue of past embeddings/labels that activates once InfoNCE
    accuracy stays saturated for ``patience`` consecutive batches.

    After activation it supplies hard positives/negatives for a margin loss.
    The margin ramps linearly from ``margin_start`` to ``margin_end`` over
    ``margin_warmup`` mastery steps. The queue capacity is adapted each
    epoch from the train/val accuracy gap.
    """

    def __init__(self, dim, min_size=1024, max_size=8192, initial_size=4096,
                 patience=50, device='cuda',
                 margin_start=0.1, margin_end=0.3, margin_warmup=5000,
                 resize_step=1024, resize_cooldown=5, overfit_threshold=3.0,
                 activation_threshold=0.99):
        self.dim = dim
        self.min_size = min_size
        self.max_size = max_size
        self._current_max = initial_size
        self.patience = patience
        self.device = device
        self.activation_threshold = activation_threshold
        self.active = False
        self._embs = None
        self._labels = None
        self._perfect_count = 0
        self._total_batches = 0
        self._activated_at = None
        self._margin_start = margin_start
        self._margin_end = margin_end
        self._margin_warmup = margin_warmup
        self._mastery_steps = 0
        self._resize_step = resize_step
        self._resize_cooldown = resize_cooldown
        self._overfit_threshold = overfit_threshold
        # Start "cooled down" so the first eligible epoch may resize.
        self._epochs_since_resize = resize_cooldown
        self._gap_history = []
        self._gap_window = 5
        self._resize_history = []

    def check_activation(self, nce_acc):
        """Count consecutive near-perfect batches; activate after ``patience``."""
        self._total_batches += 1
        if nce_acc >= self.activation_threshold:
            self._perfect_count += 1
        else:
            self._perfect_count = 0
        if not self.active and self._perfect_count >= self.patience:
            self.active = True
            self._activated_at = self._total_batches
            print(f"\n ★ MASTERY ACTIVATED at batch {self._total_batches} "
                  f"(nce_acc≥{self.activation_threshold} for {self.patience} consecutive) "
                  f"queue={self._current_max}")
        if self.active:
            self._mastery_steps += 1

    def update_size(self, train_acc, val_acc, epoch):
        """
        Grow/shrink queue capacity from the train−val accuracy gap (in %).

        Grows on a large gap or upward drift, shrinks on a small stable gap
        or downward drift. At most one resize per ``resize_cooldown`` epochs.
        """
        if not self.active:
            return
        self._epochs_since_resize += 1
        gap = train_acc - val_acc
        self._gap_history.append((epoch, gap))
        if self._epochs_since_resize < self._resize_cooldown:
            return

        old_size = self._current_max
        reason = None
        if gap > self._overfit_threshold * 2:
            self._current_max = min(self._current_max + self._resize_step, self.max_size)
            reason = f"grow: gap={gap:.1f}%"
        elif gap < self._overfit_threshold and gap > 0:
            if len(self._gap_history) >= self._gap_window:
                recent = [g for _, g in self._gap_history[-self._gap_window:]]
                if all(0 < g < self._overfit_threshold for g in recent):
                    self._current_max = max(self._current_max - self._resize_step, self.min_size)
                    reason = f"shrink: stable gap={gap:.1f}%"

        if reason is None and len(self._gap_history) >= self._gap_window:
            drift = gap - self._gap_history[-self._gap_window][1]
            if drift > self._overfit_threshold:
                self._current_max = min(self._current_max + self._resize_step, self.max_size)
                reason = f"drift: {drift:+.1f}%"
            elif drift < -self._overfit_threshold and gap > 0:
                self._current_max = max(self._current_max - self._resize_step, self.min_size)
                reason = f"drift: {drift:+.1f}%"

        if self._current_max != old_size:
            d = "↑" if self._current_max > old_size else "↓"
            print(f" ⚙ Queue {d} {old_size}→{self._current_max} ({reason})")
            self._epochs_since_resize = 0
            self._resize_history.append(
                (epoch, old_size, self._current_max, gap, reason))
            if self._embs is not None and self._embs.shape[0] > self._current_max:
                self._embs = self._embs[-self._current_max:]
                self._labels = self._labels[-self._current_max:]

    @property
    def current_margin(self):
        """Margin linearly ramped over ``margin_warmup`` mastery steps."""
        if not self.active:
            return self._margin_start
        t = min(self._mastery_steps / max(self._margin_warmup, 1), 1.0)
        return self._margin_start + t * (self._margin_end - self._margin_start)

    def push(self, emb, labels):
        """Append detached rows, keeping only the newest ``_current_max``."""
        emb = emb.detach().to(self.device)
        labels = labels.detach().to(self.device)
        if self._embs is None:
            self._embs = emb
            self._labels = labels
        else:
            self._embs = torch.cat([self._embs, emb], 0)[-self._current_max:]
            self._labels = torch.cat([self._labels, labels], 0)[-self._current_max:]

    def get(self):
        if self._embs is None:
            return None, None
        return self._embs, self._labels

    @property
    def size(self):
        return 0 if self._embs is None else self._embs.shape[0]

    def state_dict(self):
        return {
            'active': self.active,
            'total_batches': self._total_batches,
            'activated_at': self._activated_at,
            'mastery_steps': self._mastery_steps,
            'current_margin': self.current_margin,
            'current_max': self._current_max,
            'gap_history': self._gap_history[-20:],
            'resize_history': self._resize_history,
        }


# ══════════════════════════════════════════════════════════════════
# FACTORY
# ══════════════════════════════════════════════════════════════════

def create_tri_stream_vit(**kwargs):
    """Construct a ``TriStreamViT`` from keyword arguments."""
    return TriStreamViT(**kwargs)