geolip-vit-tri-stream / modeling_tri_stream.py
AbstractPhil's picture
Update modeling_tri_stream.py
8d7f7cd verified
#!/usr/bin/env python3
"""
GeoLIP Tri-Stream ViT v8 — Geometric Arbitration (fixed)
==========================================================
v7→v8 changes:
1. Uniform hypersphere orthogonal init for GAL anchors + constellation
2. Gate init at 1/(2*n_blocks) — geometry enters immediately
3. InfoNCE on emb_b (Stream B survives through contrastive, not BCE)
4. InfoNCE weight on geo_emb raised β€” geo was starved
5. No residual scaling (per Phil)
6. GAL update interval + lr controlled from trainer
Three processing paths:
Stream A (CE loss): self-attn + FFN, standard cross-entropy
Stream B (BCE+NCE): self-attn + FFN, binary CE + InfoNCE
GAL (geometric): KSimplex features, accumulated over time,
provides cross-attention to shared anchors
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from itertools import combinations
# ══════════════════════════════════════════════════════════════════
# UNIFORM HYPERSPHERE INIT
# ══════════════════════════════════════════════════════════════════
def uniform_hypersphere_init(n, d, n_iter=200, step=0.05):
    """
    Generate n well-spread unit vectors in R^d.

    n <= d: the rows are the orthonormal columns of the QR factorisation of
            a random Gaussian (d, n) matrix, i.e. mutually orthogonal.
    n > d:  d orthonormal rows from QR plus (n - d) random unit vectors,
            followed by an iterative repulsion heuristic that pushes every
            point away from its current nearest neighbour.

    Args:
        n: number of points to generate.
        d: dimensionality of the ambient space.
        n_iter: number of repulsion iterations (only used when n > d).
        step: how far each point moves away from its nearest neighbour per
              iteration, before re-normalisation (only used when n > d).

    Returns: (n, d) tensor whose rows have unit norm.
    """
    if n <= d:
        # Perfect orthogonal set
        gauss = torch.randn(d, n)
        q_mat, _ = torch.linalg.qr(gauss)
        return q_mat.T.contiguous()  # (n, d), each row unit-norm & orthogonal

    # Start with d orthogonal vectors, fill the remainder with random ones
    gauss = torch.randn(d, d)
    q_mat, _ = torch.linalg.qr(gauss)
    extra = F.normalize(torch.randn(n - d, d), dim=-1)
    vecs = torch.cat([q_mat.T, extra], dim=0)  # (n, d)

    # Iterative repulsion — push points apart on the sphere
    for _ in range(n_iter):
        sim = vecs @ vecs.T
        sim.fill_diagonal_(-2.0)  # below any cosine, so a point never picks itself
        nearest = vecs[sim.argmax(dim=1)]
        # Step away from the nearest neighbour, then project back to the sphere
        vecs = F.normalize(vecs - step * nearest, dim=-1)
    return vecs
# ══════════════════════════════════════════════════════════════════
# CAYLEY-MENGER + KSIMPLEX (unchanged)
# ══════════════════════════════════════════════════════════════════
class CMValidator(nn.Module):
    """
    Cayley-Menger evaluator for k-simplices.

    Given the (k + 1) vertices of a simplex, returns the squared edge
    lengths for every vertex pair and the squared volume obtained from the
    Cayley-Menger determinant.
    """

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1
        # Every unordered vertex pair (i < j), in itertools order
        edges = list(combinations(range(self._nv), 2))
        self._npairs = len(edges)
        self.register_buffer(
            '_pi', torch.tensor([i for i, _ in edges], dtype=torch.long))
        self.register_buffer(
            '_pj', torch.tensor([j for _, j in edges], dtype=torch.long))
        # vol^2 = (-1)^(k+1) / (2^k * (k!)^2) * det(CM)
        self._prefactor = (-1.0) ** (k + 1) / (
            (2.0 ** k) * (math.factorial(k) ** 2))

    def forward(self, verts):
        """
        Args:
            verts: (..., n_vertices, embed_dim) vertex coordinates.

        Returns:
            (d2_pairs, vol2): squared edge lengths of shape (..., n_pairs)
            and squared volumes of shape (...), both in the dtype of the
            squared-distance matrix.
        """
        gram = verts @ verts.transpose(-1, -2)
        sq_norms = gram.diagonal(dim1=-2, dim2=-1)
        # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b, clipped against round-off
        d2_mat = F.relu(
            sq_norms.unsqueeze(-1) + sq_norms.unsqueeze(-2) - 2 * gram)
        d2_pairs = d2_mat[..., self._pi, self._pj]

        # Border the distance matrix with a row/column of ones and a zero corner
        cm = F.pad(d2_mat, (1, 0, 1, 0), value=1.0)
        cm[..., 0, 0] = 0.0

        # Determinant is always evaluated in float32
        vol2 = self._prefactor * torch.linalg.det(cm.float())
        return d2_pairs, vol2.to(d2_pairs.dtype)
class KSimplexChannel(nn.Module):
    """
    Maps a feature vector to geometric features of a deformed k-simplex.

    A linear layer predicts per-vertex offsets that are added (scaled by
    BASE_DEFORM) to a fixed template simplex; the squared edge lengths and
    squared volume of the result are layer-normalised and returned.
    """

    # Scale applied to the predicted vertex offsets
    BASE_DEFORM = 0.05

    def __init__(self, k, in_dim, edim):
        super().__init__()
        self._k = k
        self._nv = k + 1
        self._edim = edim
        self._cm = CMValidator(k)
        # One feature per vertex pair plus one for the squared volume
        self._out_dim = self._cm._npairs + 1
        self.register_buffer('_template', self._make_regular_simplex(k, edim))
        self._to_deform = nn.Linear(in_dim, self._nv * edim)
        self._norm = nn.LayerNorm(self._out_dim)

    @staticmethod
    def _make_regular_simplex(k, edim):
        """
        Build the (k + 1, edim) template simplex.

        The first min(k + 1, edim) vertices are standard basis vectors; any
        further vertices are random unit vectors.  The result is centred
        and rescaled so that the edge between vertex 0 and vertex 1 has
        length 1.
        """
        nv = k + 1
        verts = torch.zeros(nv, edim)
        n_axis = min(nv, edim)
        verts[:n_axis, :n_axis] = torch.eye(n_axis)
        # Only runs when there are more vertices than dimensions
        for row in range(edim, nv):
            rnd = torch.randn(edim)
            verts[row] = rnd / (rnd.norm() + 1e-8)
        centred = verts - verts.mean(dim=0, keepdim=True)
        edge_len = (centred[0] - centred[1]).norm().clamp(min=1e-8)
        return centred / edge_len

    @property
    def out_dim(self):
        """Size of the last dimension of the feature tensor from forward()."""
        return self._out_dim

    def forward(self, x):
        """
        Args:
            x: (..., in_dim) input features.

        Returns:
            (features, vol2): layer-normalised (..., out_dim) geometric
            features and the raw squared volumes of shape (...).
        """
        offsets = self._to_deform(x)
        offsets = offsets.reshape(*offsets.shape[:-1], self._nv, self._edim)
        verts = self._template + self.BASE_DEFORM * offsets
        d2, vol2 = self._cm(verts)
        geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)
        return self._norm(geo), vol2
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION + PATCHWORK
# ══════════════════════════════════════════════════════════════════
class Constellation(nn.Module):
    """
    Learnable set of anchor directions used to triangulate embeddings.

    Anchors are initialised with uniform_hypersphere_init and normalised
    on every call to triangulate().
    """

    def __init__(self, n_anchors, dim, anchor_drop=0.0):
        super().__init__()
        # v8: uniform hypersphere init
        seed_points = uniform_hypersphere_init(n_anchors, dim)
        self.anchors = nn.Parameter(seed_points)
        self.anchor_drop = anchor_drop

        # Diagnostic: off-diagonal pairwise cosine of the initial anchors
        with torch.no_grad():
            unit = F.normalize(seed_points, dim=-1)
            cosines = unit @ unit.T
            off_diag = cosines[~torch.eye(n_anchors, dtype=torch.bool)]
        print(f" βœ“ Constellation: {n_anchors}Γ—{dim} uniform hypersphere")
        print(f" pairwise cos: mean={off_diag.mean():.4f} max={off_diag.max():.4f}")

    def triangulate(self, emb, training=False):
        """
        Compute cosine distances from each embedding to the anchors.

        Args:
            emb: (batch, dim) embeddings; they are used as given (not
                 re-normalised here).
            training: when True and anchor_drop > 0, a random subset of
                      anchors is dropped for this call (at least the first
                      two are force-kept if fewer than two survive).

        Returns:
            (tri, nearest): tri is 1 - cosine against the anchors that were
            kept; nearest holds, per embedding, the index of the most
            similar anchor expressed in the full (undropped) anchor set.
        """
        anchors = F.normalize(self.anchors, dim=-1)

        keep = None
        if training and self.anchor_drop > 0:
            keep = torch.rand(
                anchors.shape[0], device=anchors.device) > self.anchor_drop
            if keep.sum() < 2:
                keep[:2] = True
            anchors = anchors[keep]

        cos = emb @ anchors.T
        tri = 1.0 - cos
        nearest = cos.max(dim=-1).indices
        if keep is not None:
            # Translate positions within the kept subset back to global ids
            nearest = keep.nonzero(as_tuple=True)[0][nearest]
        return tri, nearest
class Patchwork(nn.Module):
    """
    Splits triangulation distances into n_comp interleaved groups and runs
    each group through its own small MLP.

    Anchor column i belongs to component (i % n_comp).  Each component's
    first linear layer is sized from the number of columns it actually
    receives, so n_anchors does not have to be a multiple of n_comp.

    Args:
        n_anchors: number of anchor columns in the triangulation input.
        n_comp: number of components (independent MLPs).
        d_comp: output width of each component.

    Raises:
        ValueError: if n_comp < 1 or n_anchors < n_comp (a component would
                    receive no anchor column).
    """

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        if n_comp < 1 or n_anchors < n_comp:
            raise ValueError(
                f"Patchwork needs 1 <= n_comp <= n_anchors, "
                f"got n_anchors={n_anchors}, n_comp={n_comp}")
        self.n_comp = n_comp
        self.d_comp = d_comp
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)
        # Columns k, k + n_comp, k + 2*n_comp, ... go to component k; when
        # n_anchors is divisible by n_comp every width is n_anchors // n_comp.
        widths = [len(range(k, n_anchors, n_comp)) for k in range(n_comp)]
        self.comps = nn.ModuleList([nn.Sequential(
            nn.Linear(width, d_comp * 2), nn.GELU(),
            nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))
            for width in widths])

    def forward(self, tri):
        """
        Args:
            tri: (batch, n_anchors) triangulation distances.

        Returns:
            (batch, n_comp * d_comp) concatenation of all component outputs.
        """
        return torch.cat(
            [comp(tri[:, self.asgn == k])
             for k, comp in enumerate(self.comps)], -1)
# ══════════════════════════════════════════════════════════════════
# EMBEDDING AUTOGRAD (unchanged)
# ══════════════════════════════════════════════════════════════════
class EmbeddingAutograd(torch.autograd.Function):
    """
    Identity in the forward pass; reshapes the gradient in the backward pass.

    Backward (computed in float32, returned in the incoming gradient dtype):
      1. The gradient component along the (normalised) embedding direction
         is scaled by (1 - tang); the tangential remainder is kept as is.
      2. If sep > 0, the part of the result that points toward the nearest
         normalised anchor (positive projection only) is reduced by the
         factor sep.
    Only `x` receives a gradient; the other inputs get None.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.save_for_backward(embedding, anchors)
        ctx.tang, ctx.sep = tang, sep
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        unit_emb = F.normalize(embedding.detach().float(), dim=-1)
        unit_anchors = F.normalize(anchors.detach().float(), dim=-1)
        grad = grad_output.float()

        # Split into radial (along the embedding) and tangential parts
        radial = (grad * unit_emb).sum(-1, keepdim=True) * unit_emb
        tangential = grad - radial
        corrected = tangential + (1.0 - ctx.tang) * radial

        if ctx.sep > 0:
            closest = unit_anchors[(unit_emb @ unit_anchors.T).argmax(dim=-1)]
            pull = (corrected * closest).sum(-1, keepdim=True)
            # Only damp gradients whose projection on the anchor is positive
            corrected = corrected - ctx.sep * (pull > 0).float() * pull * closest

        return corrected.to(grad_output.dtype), None, None, None, None
# ══════════════════════════════════════════════════════════════════
# PROCRUSTES ALIGNMENT (unchanged)
# ══════════════════════════════════════════════════════════════════
def procrustes_align(source, target, whiten=False):
    """
    Orthogonal Procrustes alignment of `source` onto `target`.

    Both point sets are cast to float32 and mean-centred (optionally also
    divided by their per-dimension standard deviation).  The rotation is
    built from the SVD of source^T @ target, with the last singular
    direction flipped when needed so that det(R) = +1.

    Args:
        source: (n_points, dim) tensor.
        target: (n_points, dim) tensor.
        whiten: if True, scale each dimension to unit standard deviation
                before computing the cross-covariance.

    Returns:
        (R, score): the (dim, dim) rotation to apply as `source @ R`, and
        the sum of singular values of the cross-covariance as a float.
    """
    src = source.float()
    tgt = target.float()
    src = src - src.mean(0, keepdim=True)
    tgt = tgt - tgt.mean(0, keepdim=True)
    if whiten:
        src = src / (src.std(0, keepdim=True) + 1e-8)
        tgt = tgt / (tgt.std(0, keepdim=True) + 1e-8)

    cross_cov = (src.T @ tgt).float()
    U, S, Vt = torch.linalg.svd(cross_cov)

    # Reflection guard: flip the last axis if U @ Vt is an improper rotation
    flip = torch.ones(U.shape[0], device=U.device, dtype=U.dtype)
    flip[-1] = torch.det(U @ Vt).sign()
    rotation = U @ torch.diag(flip) @ Vt
    return rotation, S.sum().item()
# ══════════════════════════════════════════════════════════════════
# SIMPLEX BUFFER (unchanged)
# ══════════════════════════════════════════════════════════════════
class SimplexBuffer:
    """
    Rolling store of detached features and their labels.

    New rows are appended at the end; once more than `max_size` rows would
    be held after an append, the oldest rows are discarded.
    """

    def __init__(self, dim, max_size=50000, device='cuda'):
        self.dim = dim
        self.max_size = max_size
        self.device = device
        self._feats = None
        self._labels = None

    def push(self, feats, labels):
        """Append a batch of features/labels (detached, moved to self.device)."""
        feats = feats.detach().to(self.device)
        labels = labels.detach().to(self.device)
        if self._feats is not None:
            # Concatenate with history and keep only the newest max_size rows
            feats = torch.cat([self._feats, feats], 0)[-self.max_size:]
            labels = torch.cat([self._labels, labels], 0)[-self.max_size:]
        self._feats, self._labels = feats, labels

    @property
    def size(self):
        """Number of rows currently stored."""
        if self._feats is None:
            return 0
        return self._feats.shape[0]

    def class_centroids(self, num_classes):
        """
        Mean feature per class for labels 0 .. num_classes - 1.

        Returns None when the buffer is empty, holds fewer than
        10 * num_classes rows, or any class has no stored sample; otherwise
        a (num_classes, feature_dim) tensor.
        """
        if self._feats is None or self.size < num_classes * 10:
            return None
        per_class = []
        for cls in range(num_classes):
            members = self._feats[self._labels == cls]
            if members.shape[0] == 0:
                return None
            per_class.append(members.mean(0))
        return torch.stack(per_class)
# ══════════════════════════════════════════════════════════════════
# GAL — v8: uniform hypersphere anchors
# ══════════════════════════════════════════════════════════════════
class GAL(nn.Module):
    """
    Shared geometric layer: a buffer of anchors plus the KSimplex feature
    extractor and the projections built around them.

    The anchors are a registered buffer (not a Parameter); within this
    class they change only through rotate_anchors().  `n_heads` and
    `dropout` are accepted but not used by this constructor.
    """

    def __init__(self, stream_dim, n_gal_anchors, n_heads,
                 ksimplex_k=4, ksimplex_edim=8, dropout=0.1):
        super().__init__()
        self.stream_dim = stream_dim
        self.n_gal_anchors = n_gal_anchors

        # v8: uniform hypersphere init for anchors
        anchors0 = uniform_hypersphere_init(n_gal_anchors, stream_dim)
        self.register_buffer('gal_anchors', anchors0)

        # Diagnostic: off-diagonal pairwise cosine of the initial anchors
        with torch.no_grad():
            unit = F.normalize(anchors0, dim=-1)
            cosines = unit @ unit.T
            off_diag = cosines[~torch.eye(n_gal_anchors, dtype=torch.bool)]
        print(f" βœ“ GAL anchors: {n_gal_anchors}Γ—{stream_dim} "
              f"uniform hypersphere")
        print(f" pairwise cos: mean={off_diag.mean():.4f} "
              f"max={off_diag.max():.4f}")

        self.ksimplex = KSimplexChannel(
            k=ksimplex_k, in_dim=stream_dim, edim=ksimplex_edim)
        # Lifts KSimplex features back to the stream width
        self.geo_lift = nn.Sequential(
            nn.Linear(self.ksimplex.out_dim, stream_dim), nn.GELU())
        self.anchor_proj = nn.Sequential(
            nn.Linear(stream_dim, stream_dim), nn.LayerNorm(stream_dim))

    @torch.no_grad()
    def rotate_anchors(self, rotation_matrix):
        """Replace the anchors in place with `anchors @ rotation_matrix`."""
        rotated = torch.matmul(self.gal_anchors, rotation_matrix).contiguous()
        self.gal_anchors.copy_(rotated)

    def get_anchor_kv(self):
        """Return the anchors passed through anchor_proj."""
        return self.anchor_proj(self.gal_anchors)
class GALBlock(nn.Module):
    """
    Per-layer GAL injection with non-zero gate init.

    Each stream cross-attends to the shared anchor keys/values; the
    attention output plus a linear projection of the lifted geometric
    features is added to the stream, scaled by a learnable scalar gate.
    v8: gates start at 1/(2*n_blocks) so geometry enters immediately.
    `n_gal_anchors` is accepted but not used by this constructor.
    """

    def __init__(self, stream_dim, n_gal_anchors, n_heads,
                 gate_init=0.055, dropout=0.1):
        super().__init__()
        attn_kwargs = dict(dropout=dropout, batch_first=True)
        self.cross_attn_a = nn.MultiheadAttention(
            stream_dim, n_heads, **attn_kwargs)
        self.cross_attn_b = nn.MultiheadAttention(
            stream_dim, n_heads, **attn_kwargs)
        self.norm_ga = nn.LayerNorm(stream_dim)
        self.norm_gb = nn.LayerNorm(stream_dim)
        self.lift_proj_a = nn.Linear(stream_dim, stream_dim)
        self.lift_proj_b = nn.Linear(stream_dim, stream_dim)
        # v8: init at small positive value, NOT zero
        self.gate_a = nn.Parameter(torch.tensor(gate_init))
        self.gate_b = nn.Parameter(torch.tensor(gate_init))

    def forward(self, stream_a, stream_b, anchor_kv, geo_lifted):
        """
        Args:
            stream_a, stream_b: (batch, tokens, stream_dim) stream states.
            anchor_kv: (n_anchors, stream_dim) anchor keys/values, shared
                       across the batch.
            geo_lifted: tensor added to both streams after a per-stream
                        linear projection.

        Returns:
            (stream_a, stream_b) after the gated residual update.
        """
        batch = stream_a.shape[0]
        kv = anchor_kv.unsqueeze(0).expand(batch, -1, -1)

        attn_a = self.cross_attn_a(
            self.norm_ga(stream_a), kv, kv, need_weights=False)[0]
        attn_b = self.cross_attn_b(
            self.norm_gb(stream_b), kv, kv, need_weights=False)[0]

        out_a = stream_a + self.gate_a * (attn_a + self.lift_proj_a(geo_lifted))
        out_b = stream_b + self.gate_b * (attn_b + self.lift_proj_b(geo_lifted))
        return out_a, out_b
# ══════════════════════════════════════════════════════════════════
# TRI-STREAM BLOCK (unchanged structure)
# ══════════════════════════════════════════════════════════════════
class TriStreamBlock(nn.Module):
    """
    One layer of the tri-stream model.

    Streams A and B each run pre-norm self-attention followed by a
    pre-norm feed-forward network (both residual).  The normalised sum of
    the two streams is then turned into KSimplex geometric features by the
    shared GAL, and a GALBlock injects anchor cross-attention plus the
    lifted geometric features back into both streams.
    """

    def __init__(self, stream_dim, n_gal_anchors, n_heads,
                 gate_init=0.055, dropout=0.1):
        super().__init__()
        hidden = stream_dim * 4

        # Stream A
        self.norm_a1 = nn.LayerNorm(stream_dim)
        self.attn_a = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.norm_a2 = nn.LayerNorm(stream_dim)
        self.ffn_a = nn.Sequential(
            nn.Linear(stream_dim, hidden), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, stream_dim), nn.Dropout(dropout))

        # Stream B
        self.norm_b1 = nn.LayerNorm(stream_dim)
        self.attn_b = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.norm_b2 = nn.LayerNorm(stream_dim)
        self.ffn_b = nn.Sequential(
            nn.Linear(stream_dim, hidden), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, stream_dim), nn.Dropout(dropout))

        # GAL block — v8: gate_init passed through
        self.gal_block = GALBlock(
            stream_dim, n_gal_anchors, n_heads,
            gate_init=gate_init, dropout=dropout)
        self.geo_combine_norm = nn.LayerNorm(stream_dim)

    @staticmethod
    def _run_stream(x, norm_attn, attn, norm_ffn, ffn):
        """Pre-norm self-attention then pre-norm FFN, each with a residual."""
        normed = norm_attn(x)
        attended, _ = attn(normed, normed, normed, need_weights=False)
        x = x + attended
        return x + ffn(norm_ffn(x))

    def forward(self, stream_a, stream_b, gal, anchor_kv):
        """
        Args:
            stream_a, stream_b: (batch, patches, stream_dim) stream states.
            gal: shared module exposing `ksimplex` and `geo_lift`.
            anchor_kv: anchor keys/values handed to the GALBlock.

        Returns:
            (stream_a, stream_b, geo_feats, vol2, geo_lifted) where
            geo_feats is (batch, patches, -1), vol2 is (batch, patches) and
            geo_lifted is the output of gal.geo_lift on geo_feats.
        """
        batch, n_patches, width = stream_a.shape

        # A is processed before B, so dropout draws happen in that order
        stream_a = self._run_stream(
            stream_a, self.norm_a1, self.attn_a, self.norm_a2, self.ffn_a)
        stream_b = self._run_stream(
            stream_b, self.norm_b1, self.attn_b, self.norm_b2, self.ffn_b)

        # GAL: geometric features from the fused streams, one simplex per token
        fused = self.geo_combine_norm(stream_a + stream_b)
        geo_feats, vol2 = gal.ksimplex(fused.reshape(batch * n_patches, width))
        geo_feats = geo_feats.reshape(batch, n_patches, -1)
        vol2 = vol2.reshape(batch, n_patches)
        geo_lifted = gal.geo_lift(geo_feats)

        stream_a, stream_b = self.gal_block(
            stream_a, stream_b, anchor_kv, geo_lifted)
        return stream_a, stream_b, geo_feats, vol2, geo_lifted
# ══════════════════════════════════════════════════════════════════
# TRI-STREAM VIT v8
# ══════════════════════════════════════════════════════════════════
class TriStreamViT(nn.Module):
    """
    Tri-stream vision transformer (v8).

    A shared patch embedding feeds two token streams (A and B) that run
    through a stack of TriStreamBlocks together with a shared GAL.  Each
    stream and the accumulated geometric signal are pooled, projected and
    L2-normalised; their normalised sum is the combined embedding, which is
    triangulated against a Constellation and fed through a Patchwork into
    the stream classifiers.

    The constructor arguments are stored verbatim in `self.config`.
    """

    def __init__(
        self,
        num_classes=10,
        img_size=32,
        patch_size=4,
        embed_dim=384,
        stream_dim=192,
        n_blocks=9,
        n_heads=8,
        output_dim=256,
        n_anchors=128,
        n_gal_anchors=64,
        n_comp=16,
        d_comp=128,
        anchor_drop=0.10,
        cv_target=0.22,
        ksimplex_k=4,
        ksimplex_edim=8,
        dropout=0.1,
        infonce_temp=0.07,
        infonce_weight=0.1,
        bce_weight=1.0,
        cm_weight=0.1,
        cv_weight=0.1,
        autograd_tang=1.0,
        autograd_sep=0.1,
        enable_autograd=True,
        label_smoothing=0.1,
        # v8: stream B + geo InfoNCE weights (separate)
        stream_b_nce_weight=0.5,
        geo_nce_weight=0.5,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_patches = (img_size // patch_size) ** 2
        self.stream_dim = stream_dim
        self.output_dim = output_dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp
        self.infonce_weight = infonce_weight
        self.bce_weight = bce_weight
        self.cm_weight = cm_weight
        self.cv_weight = cv_weight
        self.autograd_tang = autograd_tang
        self.autograd_sep = autograd_sep
        self.enable_autograd = enable_autograd
        self.label_smoothing = label_smoothing
        self.stream_b_nce_weight = stream_b_nce_weight
        self.geo_nce_weight = geo_nce_weight
        # Snapshot of the constructor arguments.  This must stay ahead of
        # any other local assignment, otherwise those locals leak into it.
        self.config = {k: v for k, v in locals().items()
                       if k != 'self' and not k.startswith('_')}

        # v8: gate init from block count
        gate_init = 1.0 / (2.0 * n_blocks)  # ~0.055 for 9 blocks
        print(f" Gate init: {gate_init:.4f} (1/(2Γ—{n_blocks}))")

        # Shared patch embedding
        self.patch_embed = nn.Conv2d(
            3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches, embed_dim) * 0.02)

        # Stream projections
        self.proj_a = nn.Sequential(
            nn.Linear(embed_dim, stream_dim), nn.LayerNorm(stream_dim))
        self.proj_b = nn.Sequential(
            nn.Linear(embed_dim, stream_dim), nn.LayerNorm(stream_dim))

        # Shared GAL
        self.gal = GAL(stream_dim, n_gal_anchors, n_heads,
                       ksimplex_k, ksimplex_edim, dropout)

        # Tri-stream blocks — v8: pass gate_init
        self.blocks = nn.ModuleList([
            TriStreamBlock(stream_dim, n_gal_anchors, n_heads,
                           gate_init=gate_init, dropout=dropout)
            for _ in range(n_blocks)])

        # Output norms
        self.norm_a = nn.LayerNorm(stream_dim)
        self.norm_b = nn.LayerNorm(stream_dim)

        # Sphere projections
        self.proj_sphere_a = nn.Sequential(
            nn.Linear(stream_dim, output_dim), nn.LayerNorm(output_dim))
        self.proj_sphere_b = nn.Sequential(
            nn.Linear(stream_dim, output_dim), nn.LayerNorm(output_dim))
        self.proj_sphere_geo = nn.Sequential(
            nn.Linear(stream_dim, output_dim), nn.LayerNorm(output_dim))

        # Constellation + Patchwork (uniform hypersphere via Constellation)
        self.constellation = Constellation(n_anchors, output_dim, anchor_drop)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        pw_dim = n_comp * d_comp

        # Classifiers
        self.classifier_a = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(dropout),
            nn.Linear(pw_dim, num_classes))
        self.classifier_b = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(dropout),
            nn.Linear(pw_dim, num_classes))
        self.geo_classifier = nn.Sequential(
            nn.Linear(output_dim, output_dim), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(output_dim, num_classes))

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal (std 0.02) Linear weights, zero biases, unit LayerNorm."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                # A LayerNorm built without affine parameters has None here
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, apply_autograd=True):
        """
        Run the model on a batch of images.

        Args:
            x: (batch, 3, H, W) images; the number of patches must match
               the positional embedding.
            apply_autograd: when True (and the module is training with
                            enable_autograd set), wrap the combined, B and
                            geo embeddings in EmbeddingAutograd.

        Returns:
            dict with keys 'geo_feats', 'all_geo_feats', 'vol2',
            'embedding', 'emb_a', 'emb_b', 'geo_emb', 'pool_geo',
            'triangulation', 'nearest', 'logits_a', 'logits_b',
            'geo_logits', 'gates_a', 'gates_b' (the two gate entries are
            lists of Python floats, one per block).
        """
        output = {}

        # Patch embedding: (B, C, h, w) -> (B, P, C)
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)
        tokens = tokens + self.pos_embed

        # Split
        stream_a = self.proj_a(tokens)
        stream_b = self.proj_b(tokens)

        # Anchor KV once
        anchor_kv = self.gal.get_anchor_kv()

        # Process through blocks
        all_geo_feats = []
        all_vol2 = []
        geo_accum = torch.zeros_like(stream_a)
        for block in self.blocks:
            stream_a, stream_b, geo_feats, vol2, geo_lifted = block(
                stream_a, stream_b, self.gal, anchor_kv)
            all_geo_feats.append(geo_feats)
            all_vol2.append(vol2)
            geo_accum = geo_accum + geo_lifted
        output['geo_feats'] = all_geo_feats[-1]
        output['all_geo_feats'] = torch.stack(all_geo_feats)
        output['vol2'] = torch.stack(all_vol2)

        # Norms
        stream_a = self.norm_a(stream_a)
        stream_b = self.norm_b(stream_b)

        # Pool over tokens
        pool_a = stream_a.mean(dim=1)
        pool_b = stream_b.mean(dim=1)
        pool_geo = geo_accum.mean(dim=1)

        # Project onto the unit sphere
        emb_a = F.normalize(self.proj_sphere_a(pool_a), dim=-1)
        emb_b = F.normalize(self.proj_sphere_b(pool_b), dim=-1)
        geo_emb = F.normalize(self.proj_sphere_geo(pool_geo), dim=-1)

        # Combined
        emb = F.normalize(emb_a + emb_b + geo_emb, dim=-1)

        # EmbeddingAutograd — v8: applied to the combined embedding, emb_b
        # and geo_emb.  emb_a is NOT wrapped.
        if apply_autograd and self.training and self.enable_autograd:
            emb = EmbeddingAutograd.apply(
                emb, emb, self.constellation.anchors,
                self.autograd_tang, self.autograd_sep)
            emb_b = EmbeddingAutograd.apply(
                emb_b, emb_b, self.constellation.anchors,
                self.autograd_tang, self.autograd_sep)
            geo_emb = EmbeddingAutograd.apply(
                geo_emb, geo_emb, self.constellation.anchors,
                self.autograd_tang, self.autograd_sep)

        output['embedding'] = emb
        output['emb_a'] = emb_a
        output['emb_b'] = emb_b
        output['geo_emb'] = geo_emb
        output['pool_geo'] = pool_geo

        # Constellation + Patchwork.  The patchwork always sees the full
        # (undropped) triangulation so its input width is fixed.
        tri_full, nearest_full = self.constellation.triangulate(
            emb, training=False)
        pw = self.patchwork(tri_full)
        output['triangulation'] = tri_full
        if self.training:
            # Anchor dropout only affects the reported nearest anchor
            _, nearest = self.constellation.triangulate(emb, training=True)
        else:
            nearest = nearest_full
        output['nearest'] = nearest

        # Classifiers
        logits_a = self.classifier_a(torch.cat([pw, emb_a], dim=-1))
        logits_b = self.classifier_b(torch.cat([pw, emb_b], dim=-1))
        geo_logits = self.geo_classifier(geo_emb)
        output['logits_a'] = logits_a
        output['logits_b'] = logits_b
        output['geo_logits'] = geo_logits

        # Gate monitoring: gather every gate into one tensor so that only a
        # single device-to-host transfer is needed instead of 2 * n_blocks.
        with torch.no_grad():
            gate_values = torch.stack([
                torch.stack((blk.gal_block.gate_a, blk.gal_block.gate_b))
                for blk in self.blocks]).tolist()
        output['gates_a'] = [pair[0] for pair in gate_values]
        output['gates_b'] = [pair[1] for pair in gate_values]
        return output

    # ──────────────────────────────────────────────────────────
    # PROCRUSTES ANCHOR UPDATE
    # ──────────────────────────────────────────────────────────
    @torch.no_grad()
    def update_gal_anchors(self, simplex_buffer, lr=0.015, whiten=False):
        """
        Nudge the GAL anchors toward the class centroids in the buffer.

        Each centroid is matched to its most cosine-similar anchor, a
        rotation aligning the matched anchors with the centroids is
        estimated with procrustes_align, and every anchor is moved a
        fraction `lr` of the way toward its rotated position and
        re-normalised.

        Args:
            simplex_buffer: object exposing class_centroids(num_classes).
            lr: interpolation factor between current and rotated anchors.
            whiten: forwarded to procrustes_align.

        Returns:
            The alignment score from procrustes_align, or None when the
            buffer cannot provide centroids yet (anchors are left untouched).
        """
        with torch.amp.autocast("cuda", enabled=False):
            centroids = simplex_buffer.class_centroids(self.num_classes)
            if centroids is None:
                return None
            anchors = self.gal.gal_anchors.float()
            # The buffer may live on a different device than the model
            centroids = centroids.float().to(anchors.device)

            centroid_n = F.normalize(centroids, dim=-1)
            anchor_n = F.normalize(anchors, dim=-1)
            cos = centroid_n @ anchor_n.T
            matched_idx = cos.argmax(dim=1)
            matched_anchors = anchors[matched_idx]

            R, score = procrustes_align(
                matched_anchors, centroids, whiten=whiten)
            rotated = anchors @ R
            new_anchors = F.normalize(
                anchors + lr * (rotated - anchors), dim=-1)
            self.gal.gal_anchors.copy_(
                new_anchors.to(self.gal.gal_anchors.dtype))
            return score

    # ──────────────────────────────────────────────────────────
    # LOSS — v8: InfoNCE on emb_b + stronger geo_emb signal
    # ──────────────────────────────────────────────────────────
    def compute_loss(self, output, targets, output_aug=None,
                     mastery_queue=None):
        """
        Compute the total training loss and a dictionary of its parts.

        Args:
            output: dict returned by forward() for the batch.
            targets: (batch,) integer class labels.
            output_aug: optional forward() dict for an augmented view of
                        the same batch; enables the InfoNCE terms.
            mastery_queue: optional queue object; its hard-negative margin
                           term is used only while `mastery_queue.active`.

        Returns:
            (loss, loss_dict): the scalar total and a dict mixing tensors
            (loss terms) and Python floats (accuracies / diagnostics).
            Note that the CE, both BCE and the mastery terms are all
            scaled by `bce_weight`.
        """
        loss_dict = {}
        emb = output['embedding']
        emb_b = output['emb_b']
        geo_emb = output['geo_emb']
        B = emb.shape[0]
        is_mastery = mastery_queue is not None and mastery_queue.active

        # CE on Stream A
        l_ce = F.cross_entropy(output['logits_a'], targets)
        loss_dict['ce'] = l_ce
        acc_a = (output['logits_a'].argmax(-1) == targets).float().mean().item()
        loss_dict['acc_a'] = acc_a

        # BCE on Stream B (optionally label-smoothed one-hot targets)
        one_hot = F.one_hot(targets, self.num_classes).float()
        ls = self.label_smoothing
        one_hot_smooth = one_hot * (1.0 - ls) + ls / self.num_classes if ls > 0 else one_hot
        l_bce = F.binary_cross_entropy_with_logits(
            output['logits_b'], one_hot_smooth)
        loss_dict['bce'] = l_bce
        acc_b = (output['logits_b'].argmax(-1) == targets).float().mean().item()
        loss_dict['acc_b'] = acc_b

        # Geo classifier BCE
        l_geo_bce = F.binary_cross_entropy_with_logits(
            output['geo_logits'], one_hot_smooth)
        loss_dict['geo_bce'] = l_geo_bce
        geo_acc = (output['geo_logits'].argmax(-1) == targets).float().mean().item()
        loss_dict['geo_acc'] = geo_acc

        # InfoNCE — v8: on combined, emb_b, AND geo_emb
        if output_aug is not None:
            # Positive pairs sit on the diagonal of each similarity matrix
            labels_nce = torch.arange(B, device=emb.device)

            # Combined embedding InfoNCE
            emb_aug = output_aug['embedding']
            sim = emb @ emb_aug.T / self.infonce_temp
            l_nce = F.cross_entropy(sim, labels_nce)
            nce_acc = (sim.argmax(1) == labels_nce).float().mean().item()
            loss_dict['nce'] = l_nce
            loss_dict['nce_acc'] = nce_acc

            # v8: Stream B InfoNCE (this is what keeps B alive)
            emb_b_aug = output_aug.get('emb_b')
            if emb_b_aug is not None:
                sim_b = emb_b @ emb_b_aug.T / self.infonce_temp
                l_nce_b = F.cross_entropy(sim_b, labels_nce)
                nce_b_acc = (sim_b.argmax(1) == labels_nce).float().mean().item()
                loss_dict['nce_b'] = l_nce_b
                loss_dict['nce_b_acc'] = nce_b_acc

            # v8: Geo InfoNCE (this is what feeds the geo path)
            geo_emb_aug = output_aug.get('geo_emb')
            if geo_emb_aug is not None:
                sim_g = geo_emb @ geo_emb_aug.T / self.infonce_temp
                l_geo_nce = F.cross_entropy(sim_g, labels_nce)
                geo_nce_acc = (sim_g.argmax(1) == labels_nce).float().mean().item()
                loss_dict['geo_nce'] = l_geo_nce
                loss_dict['geo_nce_acc'] = geo_nce_acc

        # Mastery: margin between hardest negative and hardest positive
        if is_mastery:
            q_emb, q_labels = mastery_queue.get()
            if q_emb is not None and q_emb.shape[0] >= B:
                cross_sim = emb @ q_emb.T
                same_mask = targets.unsqueeze(1) == q_labels.unsqueeze(0)
                # Hardest negative: most similar queue entry of another class
                hn_sim = cross_sim.clone()
                hn_sim[same_mask] = -1e9
                hn_cos = hn_sim.max(dim=1).values
                # Hardest positive: least similar queue entry of the same class
                hp_sim = cross_sim.clone()
                hp_sim[~same_mask] = 1e9
                hp_cos = hp_sim.min(dim=1).values
                # Rows need at least one positive and one negative in the queue
                valid = same_mask.any(1) & (~same_mask).any(1)
                if valid.sum() > 0:
                    margin = mastery_queue.current_margin
                    l_mastery = F.relu(
                        hn_cos[valid] - hp_cos[valid] + margin).mean()
                    loss_dict['mastery'] = l_mastery
                    loss_dict['hard_neg_cos'] = hn_cos[valid].mean().item()
                    loss_dict['hard_pos_cos'] = hp_cos[valid].mean().item()
                    loss_dict['margin'] = margin
            # The queue is fed whenever mastery is active, even while it is
            # still too small to contribute a loss term.
            mastery_queue.push(emb.detach(), targets.detach())

        # CM validity: penalise negative squared volumes
        vol2 = output['vol2']
        l_cm = F.relu(-vol2).mean()
        loss_dict['cm'] = l_cm
        loss_dict['cm_valid'] = (vol2 > 0).float().mean().item()

        # CV on combined + geo
        l_cv_main = self._cv_loss_fast(emb, target=self.cv_target)
        l_cv_geo = self._cv_loss_fast(geo_emb, target=self.cv_target)
        l_cv = l_cv_main + l_cv_geo
        loss_dict['cv'] = l_cv
        loss_dict['cv_main'] = l_cv_main.item() if torch.is_tensor(l_cv_main) else l_cv_main
        loss_dict['cv_geo'] = l_cv_geo.item() if torch.is_tensor(l_cv_geo) else l_cv_geo

        # Anchor spread: penalise positive cosine between distinct anchors
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        anchor_sim = anchors_n @ anchors_n.T
        mask_a = ~torch.eye(anchors_n.shape[0], dtype=torch.bool,
                            device=anchors_n.device)
        l_spread = F.relu(anchor_sim[mask_a] - 0.0).mean()
        loss_dict['spread'] = l_spread

        # Combine — v8: explicit weights for B and geo NCE
        loss = (l_ce * self.bce_weight
                + l_bce * self.bce_weight
                + l_geo_bce * self.bce_weight
                + loss_dict.get('nce', 0.0) * self.infonce_weight
                + loss_dict.get('nce_b', 0.0) * self.stream_b_nce_weight
                + loss_dict.get('geo_nce', 0.0) * self.geo_nce_weight
                + loss_dict.get('mastery', 0.0) * self.bce_weight
                + l_cm * self.cm_weight
                + l_cv * self.cv_weight
                + l_spread * 0.001)
        loss_dict['total'] = loss
        return loss, loss_dict

    @staticmethod
    def _cv_loss_fast(emb, target=0.22, n_samples=64, n_points=5):
        """
        Squared deviation of the simplex-volume coefficient of variation
        (std / mean) from `target`.

        `n_samples` random subsets of `n_points` embeddings are drawn from
        the first min(batch, 512) rows; each subset's simplex volume comes
        from its Cayley-Menger determinant.  All subsets are evaluated in
        one batched determinant, so only one host synchronisation is needed.

        Args:
            emb: (batch, dim) embeddings.
            target: desired coefficient of variation.
            n_samples: number of random simplices.
            n_points: vertices per simplex.

        Returns:
            Scalar tensor; 0.0 when the batch is smaller than `n_points` or
            fewer than 5 sampled simplices have squared volume > 1e-20.
        """
        B = emb.shape[0]
        if B < n_points or n_samples < 1:
            return torch.tensor(0.0, device=emb.device)

        pool = min(B, 512)
        idx = torch.stack([
            torch.randperm(pool, device=emb.device)[:n_points]
            for _ in range(n_samples)])
        pts = emb[idx]  # (n_samples, n_points, dim)

        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)

        # Bordered squared-distance matrices, one per sampled simplex
        N = n_points
        cm = torch.zeros(n_samples, N + 1, N + 1,
                         device=emb.device, dtype=emb.dtype)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2

        k = N - 1
        prefactor = (-1.0) ** (k + 1) / ((2.0 ** k) * (math.factorial(k) ** 2))
        vol2 = prefactor * torch.linalg.det(cm.float())

        # Discard degenerate simplices before taking the square root
        valid = vol2 > 1e-20
        if int(valid.sum()) < 5:
            return torch.tensor(0.0, device=emb.device)
        vols = vol2[valid].to(emb.dtype).sqrt()
        cv = vols.std() / (vols.mean() + 1e-8)
        return (cv - target).pow(2)
# ══════════════════════════════════════════════════════════════════
# MASTERY QUEUE (unchanged)
# ══════════════════════════════════════════════════════════════════
class MasteryQueue:
    """
    Rolling queue of embeddings/labels that switches on once InfoNCE
    accuracy has been near-perfect for `patience` consecutive batches.

    While active it exposes a margin that is linearly warmed up from
    `margin_start` to `margin_end`, and its capacity can be grown or shrunk
    from the train/validation accuracy gap via update_size().
    """

    def __init__(self, dim, min_size=1024, max_size=8192, initial_size=4096,
                 patience=50, device='cuda',
                 margin_start=0.1, margin_end=0.3, margin_warmup=5000,
                 resize_step=1024, resize_cooldown=5, overfit_threshold=3.0):
        self.dim = dim
        self.min_size = min_size
        self.max_size = max_size
        self._current_max = initial_size
        self.patience = patience
        self.device = device
        self.active = False
        self._embs = None
        self._labels = None
        self._perfect_count = 0
        self._total_batches = 0
        self._activated_at = None
        self._margin_start = margin_start
        self._margin_end = margin_end
        self._margin_warmup = margin_warmup
        self._mastery_steps = 0
        self._resize_step = resize_step
        self._resize_cooldown = resize_cooldown
        self._overfit_threshold = overfit_threshold
        # Start "cooled down" so the first eligible epoch may resize
        self._epochs_since_resize = resize_cooldown
        self._gap_history = []
        self._gap_window = 5
        self._resize_history = []

    def check_activation(self, nce_acc):
        """Track the streak of batches with nce_acc >= 0.99; activate once it reaches `patience`."""
        self._total_batches += 1
        self._perfect_count = self._perfect_count + 1 if nce_acc >= 0.99 else 0
        if self._perfect_count >= self.patience and not self.active:
            self.active = True
            self._activated_at = self._total_batches
            print(f"\n β˜… MASTERY ACTIVATED at batch {self._total_batches} "
                  f"(nce_acc=1.0 for {self.patience} consecutive) "
                  f"queue={self._current_max}")
        if self.active:
            # Drives the margin warm-up
            self._mastery_steps += 1

    def update_size(self, train_acc, val_acc, epoch):
        """
        Adapt the queue capacity from the train/val accuracy gap.

        No-op while inactive.  The gap is always recorded; a resize is only
        considered once `resize_cooldown` calls have passed since the last
        one.  Rules, in order:
          * gap > 2 * threshold            -> grow
          * 0 < gap < threshold for the whole recent window -> shrink
          * otherwise, compare with the gap `gap_window` entries back:
            rise above threshold -> grow; fall below -threshold with a
            positive gap -> shrink.
        Capacity stays within [min_size, max_size]; stored entries are
        trimmed if the capacity dropped below the current fill.
        """
        if not self.active:
            return
        self._epochs_since_resize += 1
        gap = train_acc - val_acc
        self._gap_history.append((epoch, gap))
        if self._epochs_since_resize < self._resize_cooldown:
            return

        thr = self._overfit_threshold
        window = self._gap_window
        step = self._resize_step
        have_window = len(self._gap_history) >= window
        old_size = self._current_max
        reason = None

        if gap > thr * 2:
            self._current_max = min(self._current_max + step, self.max_size)
            reason = f"grow: gap={gap:.1f}%"
        elif 0 < gap < thr and have_window:
            tail = self._gap_history[-window:]
            if all(0 < g < thr for _, g in tail):
                self._current_max = max(self._current_max - step, self.min_size)
                reason = f"shrink: stable gap={gap:.1f}%"

        if reason is None and have_window:
            drift = gap - self._gap_history[-window][1]
            if drift > thr:
                self._current_max = min(self._current_max + step, self.max_size)
                reason = f"drift: {drift:+.1f}%"
            elif drift < -thr and gap > 0:
                self._current_max = max(self._current_max - step, self.min_size)
                reason = f"drift: {drift:+.1f}%"

        if self._current_max == old_size:
            return

        d = "↑" if self._current_max > old_size else "↓"
        print(f" βš™ Queue {d} {old_size}β†’{self._current_max} ({reason})")
        self._epochs_since_resize = 0
        self._resize_history.append(
            (epoch, old_size, self._current_max, gap, reason))
        if self._embs is not None and self._embs.shape[0] > self._current_max:
            # Keep only the newest entries that still fit
            self._embs = self._embs[-self._current_max:]
            self._labels = self._labels[-self._current_max:]

    @property
    def current_margin(self):
        """Margin, linearly interpolated over `margin_warmup` active steps."""
        if not self.active:
            return self._margin_start
        progress = min(self._mastery_steps / max(self._margin_warmup, 1), 1.0)
        return self._margin_start + progress * (self._margin_end - self._margin_start)

    def push(self, emb, labels):
        """Append detached embeddings/labels, keeping at most the current capacity."""
        emb = emb.detach().to(self.device)
        labels = labels.detach().to(self.device)
        if self._embs is not None:
            emb = torch.cat([self._embs, emb], 0)[-self._current_max:]
            labels = torch.cat([self._labels, labels], 0)[-self._current_max:]
        self._embs, self._labels = emb, labels

    def get(self):
        """Return (embeddings, labels), or (None, None) when nothing was pushed."""
        if self._embs is None:
            return None, None
        return self._embs, self._labels

    @property
    def size(self):
        """Number of entries currently stored."""
        if self._embs is None:
            return 0
        return self._embs.shape[0]

    def state_dict(self):
        """Summary of the queue's bookkeeping state (stored tensors are not included)."""
        return {
            'active': self.active,
            'total_batches': self._total_batches,
            'activated_at': self._activated_at,
            'mastery_steps': self._mastery_steps,
            'current_margin': self.current_margin,
            'current_max': self._current_max,
            'gap_history': self._gap_history[-20:],
            'resize_history': self._resize_history,
        }
# ══════════════════════════════════════════════════════════════════
# FACTORY
# ══════════════════════════════════════════════════════════════════
def create_tri_stream_vit(**kwargs):
    """Build a TriStreamViT, forwarding all keyword arguments to its constructor."""
    model = TriStreamViT(**kwargs)
    return model