| |
| """ |
| GeoLIP Tri-Stream ViT v8 β Geometric Arbitration (fixed) |
| ========================================================== |
| v7βv8 changes: |
| 1. Uniform hypersphere orthogonal init for GAL anchors + constellation |
| 2. Gate init at 1/(2*n_blocks) β geometry enters immediately |
| 3. InfoNCE on emb_b (Stream B survives through contrastive, not BCE) |
| 4. InfoNCE weight on geo_emb raised β geo was starved |
| 5. No residual scaling (per Phil) |
| 6. GAL update interval + lr controlled from trainer |
| |
| Three processing paths: |
| Stream A (CE loss): self-attn + FFN, standard cross-entropy |
| Stream B (BCE+NCE): self-attn + FFN, binary CE + InfoNCE |
| GAL (geometric): KSimplex features, accumulated over time, |
| provides cross-attention to shared anchors |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from itertools import combinations |
|
|
|
|
| |
| |
| |
|
|
def uniform_hypersphere_init(n, d):
    """
    Sample n well-separated unit vectors in R^d.

    For n <= d an exactly orthonormal set is produced via QR decomposition
    (perfect spread).  For n > d an orthonormal basis of R^d is padded with
    random unit vectors, which are then pushed apart by a short nearest-
    neighbour repulsion loop.

    Returns a (n, d) tensor whose rows lie on the unit sphere.
    """
    if n <= d:
        # QR of a random d x n Gaussian matrix gives n orthonormal columns;
        # transposing yields n orthonormal rows in R^d.
        q, _ = torch.linalg.qr(torch.randn(d, n))
        return q.T.contiguous()

    # Start from a full orthonormal basis of R^d ...
    q, _ = torch.linalg.qr(torch.randn(d, d))
    points = torch.cat(
        [q.T, F.normalize(torch.randn(n - d, d), dim=-1)], dim=0)

    # ... then iteratively nudge each point away from its nearest neighbour,
    # re-normalizing after every step to stay on the sphere.
    for _ in range(200):
        cosines = points @ points.T
        cosines.fill_diagonal_(-2.0)  # exclude self-similarity from the argmax
        nearest = points[cosines.argmax(dim=1)]
        points = F.normalize(points - 0.05 * nearest, dim=-1)

    return points
|
|
|
|
| |
| |
| |
|
|
class CMValidator(nn.Module):
    """Cayley-Menger geometry extractor for k-simplices.

    Given vertex coordinates, produces the squared length of every vertex
    pair and the squared k-volume via the Cayley-Menger determinant:
        vol^2 = (-1)^(k+1) / (2^k (k!)^2) * det(CM)
    """

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1
        idx_pairs = list(combinations(range(self._nv), 2))
        self._npairs = len(idx_pairs)
        # Row/column index buffers for gathering the upper-triangular pairs.
        self.register_buffer('_pi', torch.tensor([a for a, _ in idx_pairs], dtype=torch.long))
        self.register_buffer('_pj', torch.tensor([b for _, b in idx_pairs], dtype=torch.long))
        self._prefactor = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))

    def forward(self, verts):
        """verts: (..., nv, e) vertex coordinates -> (d2_pairs, vol2)."""
        gram = torch.einsum('...ve,...we->...vw', verts, verts)
        sq_norms = torch.diagonal(gram, dim1=-2, dim2=-1)
        # Squared pairwise distances, clamped at zero against round-off.
        d2 = F.relu(sq_norms.unsqueeze(-1) + sq_norms.unsqueeze(-2) - 2 * gram)
        d2_pairs = d2[..., self._pi, self._pj]
        batch_shape, nv = d2.shape[:-2], d2.shape[-1]
        # Bordered Cayley-Menger matrix: ones on the first row/column,
        # zero corner, squared distances in the interior.
        cm = torch.zeros(*batch_shape, nv + 1, nv + 1,
                         device=d2.device, dtype=d2.dtype)
        cm[..., 0, 1:] = 1.0
        cm[..., 1:, 0] = 1.0
        cm[..., 1:, 1:] = d2
        # Determinant in float32 for numerical stability, cast back after.
        vol2 = (self._prefactor * torch.linalg.det(cm.float())).to(d2_pairs.dtype)
        return d2_pairs, vol2
|
|
|
|
class KSimplexChannel(nn.Module):
    """Maps a feature vector to a deformed k-simplex and returns its
    normalized edge/volume descriptor.

    A fixed template simplex is perturbed by a learned, input-dependent
    vertex deformation (scaled by BASE_DEFORM); CMValidator then extracts
    the pairwise squared distances and the Cayley-Menger squared volume.
    """
    # Scale of the learned vertex deformation relative to the template.
    BASE_DEFORM = 0.05


    def __init__(self, k, in_dim, edim):
        super().__init__()
        self._k = k; self._nv = k + 1; self._edim = edim
        self._cm = CMValidator(k)
        # One feature per vertex pair, plus the squared volume.
        self._out_dim = self._cm._npairs + 1
        template = self._make_regular_simplex(k, edim)
        self.register_buffer('_template', template)
        # Predicts a per-vertex displacement in embedding space.
        self._to_deform = nn.Linear(in_dim, self._nv * edim)
        self._norm = nn.LayerNorm(self._out_dim)


    @staticmethod
    def _make_regular_simplex(k, edim):
        """Build a centred template simplex with unit first edge.

        Vertices are standard basis vectors while they fit in edim; overflow
        vertices are random unit directions (NOTE: in that case the simplex
        is not exactly regular, despite the method name).
        """
        nv = k + 1
        verts = torch.zeros(nv, edim)
        for i in range(min(nv, edim)):
            verts[i, i] = 1.0
        if nv > edim:
            # More vertices than dimensions: fall back to random unit vectors.
            for i in range(edim, nv):
                v = torch.randn(edim)
                verts[i] = v / (v.norm() + 1e-8)
        verts = verts - verts.mean(dim=0, keepdim=True)
        # Rescale so the first edge has unit length (clamp avoids div-by-0).
        edge_len = (verts[0] - verts[1]).norm().clamp(min=1e-8)
        return verts / edge_len


    @property
    def out_dim(self):
        # Number of geometric features emitted per input vector.
        return self._out_dim


    def forward(self, x):
        """x: (..., in_dim) -> (LayerNorm'd geo features (..., out_dim), vol2)."""
        deform = self._to_deform(x).unflatten(-1, (self._nv, self._edim))
        verts = self._template + self.BASE_DEFORM * deform
        d2, vol2 = self._cm(verts)
        # Concatenate distances with the (un-normalized) squared volume;
        # raw vol2 is also returned for the validity loss.
        geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)
        return self._norm(geo), vol2
|
|
|
|
| |
| |
| |
|
|
class Constellation(nn.Module):
    """Learnable anchor set on the output hypersphere.

    Embeddings are described by their (1 - cosine) distances to every
    anchor ("triangulation"); during training a random subset of anchors
    may be dropped for regularization.
    """

    def __init__(self, n_anchors, dim, anchor_drop=0.0):
        super().__init__()
        # Anchors start maximally spread on the sphere and remain trainable.
        init_vecs = uniform_hypersphere_init(n_anchors, dim)
        self.anchors = nn.Parameter(init_vecs)
        self.anchor_drop = anchor_drop

        # Diagnostic print of the initial pairwise geometry.
        with torch.no_grad():
            unit = F.normalize(init_vecs, dim=-1)
            cos_mat = unit @ unit.T
            off_diag = cos_mat[~torch.eye(n_anchors, dtype=torch.bool)]
            print(f" β Constellation: {n_anchors}Γ{dim} uniform hypersphere")
            print(f" pairwise cos: mean={off_diag.mean():.4f} max={off_diag.max():.4f}")

    def triangulate(self, emb, training=False):
        """Return ((1 - cos) distances, index of the nearest anchor).

        With training=True and anchor_drop > 0, a random anchor subset is
        used for the distances; the nearest index is mapped back to the
        full anchor numbering.
        """
        anchors = F.normalize(self.anchors, dim=-1)
        if training and self.anchor_drop > 0:
            keep = torch.rand(anchors.shape[0], device=anchors.device) > self.anchor_drop
            if keep.sum() < 2:
                keep[:2] = True  # never operate with fewer than two anchors
            anchors = anchors[keep]
            cos = emb @ anchors.T
            tri = 1.0 - cos
            _, local_idx = cos.max(dim=-1)
            # Translate the surviving-anchor index into the full index space.
            nearest = keep.nonzero(as_tuple=True)[0][local_idx]
        else:
            cos = emb @ anchors.T
            tri = 1.0 - cos
            _, nearest = cos.max(dim=-1)
        return tri, nearest
|
|
|
|
| class Patchwork(nn.Module): |
| def __init__(self, n_anchors, n_comp, d_comp): |
| super().__init__() |
| self.n_comp = n_comp; self.d_comp = d_comp |
| self.register_buffer('asgn', torch.arange(n_anchors) % n_comp) |
| anchors_per = n_anchors // n_comp |
| self.comps = nn.ModuleList([nn.Sequential( |
| nn.Linear(anchors_per, d_comp * 2), nn.GELU(), |
| nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp)) |
| for _ in range(n_comp)]) |
|
|
| def forward(self, tri): |
| return torch.cat([self.comps[k](tri[:, self.asgn == k]) |
| for k in range(self.n_comp)], -1) |
|
|
|
|
| |
| |
| |
|
|
class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass; reshapes gradients on the backward pass.

    The incoming gradient is decomposed relative to the (detached) embedding
    direction:
      * the radial component is scaled by (1 - tang), so tang=1 keeps only
        the tangential part (pure on-sphere motion);
      * if sep > 0, the component pointing toward the nearest anchor is
        damped by a factor sep, discouraging collapse onto anchors.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang
        ctx.sep = sep
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        unit_emb = F.normalize(embedding.detach().float(), dim=-1)
        unit_anc = F.normalize(anchors.detach().float(), dim=-1)
        grad = grad_output.float()
        # Radial part of the gradient (along the embedding direction).
        radial = (grad * unit_emb).sum(-1, keepdim=True) * unit_emb
        shaped = (grad - radial) + (1.0 - ctx.tang) * radial
        if ctx.sep > 0:
            # Damp only the positive component pointing at the nearest anchor.
            nearest = unit_anc[(unit_emb @ unit_anc.T).argmax(dim=-1)]
            toward = (shaped * nearest).sum(-1, keepdim=True)
            shaped = shaped - ctx.sep * (toward > 0).float() * toward * nearest
        # Only x receives a gradient; the other four inputs are treated
        # as constants.
        return shaped.to(grad_output.dtype), None, None, None, None
|
|
|
|
| |
| |
| |
|
|
def procrustes_align(source, target, whiten=False):
    """Best proper rotation aligning centred `source` to centred `target`.

    Standard Kabsch/Procrustes solution via SVD of the cross-covariance,
    with the last singular direction flipped when needed to force det(R)=+1
    (a rotation, never a reflection).

    Returns (R, score) where score is the sum of singular values of the
    cross-covariance (larger = stronger alignment).
    """
    src = source.float() - source.float().mean(0, keepdim=True)
    tgt = target.float() - target.float().mean(0, keepdim=True)
    if whiten:
        # Per-dimension standardization before solving.
        src = src / (src.std(0, keepdim=True) + 1e-8)
        tgt = tgt / (tgt.std(0, keepdim=True) + 1e-8)
    cov = (src.T @ tgt).float()
    U, S, Vt = torch.linalg.svd(cov)
    flips = torch.ones(U.shape[0], device=U.device, dtype=U.dtype)
    flips[-1] = torch.det(U @ Vt).sign()
    R = U @ torch.diag(flips) @ Vt
    return R, S.sum().item()
|
|
|
|
| |
| |
| |
|
|
class SimplexBuffer:
    """Gradient-free FIFO store of (feature, label) rows.

    Used to accumulate recent embeddings so per-class centroids can be
    estimated for the GAL anchor update.  Oldest rows are evicted once
    max_size is exceeded.
    """

    def __init__(self, dim, max_size=50000, device='cuda'):
        self.dim = dim
        self.max_size = max_size
        self.device = device
        self._feats = None
        self._labels = None

    def push(self, feats, labels):
        """Append a detached batch, keeping only the newest max_size rows."""
        feats = feats.detach().to(self.device)
        labels = labels.detach().to(self.device)
        if self._feats is None:
            self._feats, self._labels = feats, labels
        else:
            self._feats = torch.cat([self._feats, feats], 0)[-self.max_size:]
            self._labels = torch.cat([self._labels, labels], 0)[-self.max_size:]

    @property
    def size(self):
        # Number of rows currently stored.
        return self._feats.shape[0] if self._feats is not None else 0

    def class_centroids(self, num_classes):
        """Mean stored feature per class.

        Returns None while coverage is too thin: fewer than 10*num_classes
        rows overall, or any class with no examples at all.
        """
        if self._feats is None or self.size < num_classes * 10:
            return None
        means = []
        for cls in range(num_classes):
            hits = self._labels == cls
            if hits.sum() == 0:
                return None
            means.append(self._feats[hits].mean(0))
        return torch.stack(means)
|
|
|
|
| |
| |
| |
|
|
class GAL(nn.Module):
    """Shared Geometric Arbitration Layer state.

    Holds the fixed-by-SGD anchor set (a buffer, moved only by explicit
    rotate/update calls from the trainer), the KSimplex feature channel,
    and the projections used by every TriStreamBlock.
    """

    def __init__(self, stream_dim, n_gal_anchors, n_heads,
                 ksimplex_k=4, ksimplex_edim=8, dropout=0.1):
        super().__init__()
        self.stream_dim = stream_dim
        self.n_gal_anchors = n_gal_anchors

        # Buffer, not Parameter: anchors evolve via Procrustes updates only.
        init_anchors = uniform_hypersphere_init(n_gal_anchors, stream_dim)
        self.register_buffer('gal_anchors', init_anchors)
        with torch.no_grad():
            unit = F.normalize(init_anchors, dim=-1)
            cos_mat = unit @ unit.T
            off_diag = cos_mat[~torch.eye(n_gal_anchors, dtype=torch.bool)]
            print(f" β GAL anchors: {n_gal_anchors}Γ{stream_dim} "
                  f"uniform hypersphere")
            print(f" pairwise cos: mean={off_diag.mean():.4f} "
                  f"max={off_diag.max():.4f}")

        # Geometry channel shared across all blocks.
        self.ksimplex = KSimplexChannel(
            k=ksimplex_k, in_dim=stream_dim, edim=ksimplex_edim)
        self.geo_lift = nn.Sequential(
            nn.Linear(self.ksimplex.out_dim, stream_dim), nn.GELU())
        self.anchor_proj = nn.Sequential(
            nn.Linear(stream_dim, stream_dim), nn.LayerNorm(stream_dim))

    @torch.no_grad()
    def rotate_anchors(self, rotation_matrix):
        """Apply a global rotation to the anchor buffer in place."""
        self.gal_anchors.copy_(
            (self.gal_anchors @ rotation_matrix).contiguous())

    def get_anchor_kv(self):
        """Project the anchors into the key/value space for cross-attention."""
        return self.anchor_proj(self.gal_anchors)
|
|
|
|
| class GALBlock(nn.Module): |
| """ |
| Per-layer GAL injection with non-zero gate init. |
| v8: gates start at 1/(2*n_blocks) so geometry enters immediately. |
| """ |
| def __init__(self, stream_dim, n_gal_anchors, n_heads, |
| gate_init=0.055, dropout=0.1): |
| super().__init__() |
|
|
| self.cross_attn_a = nn.MultiheadAttention( |
| stream_dim, n_heads, dropout=dropout, batch_first=True) |
| self.cross_attn_b = nn.MultiheadAttention( |
| stream_dim, n_heads, dropout=dropout, batch_first=True) |
|
|
| self.norm_ga = nn.LayerNorm(stream_dim) |
| self.norm_gb = nn.LayerNorm(stream_dim) |
|
|
| self.lift_proj_a = nn.Linear(stream_dim, stream_dim) |
| self.lift_proj_b = nn.Linear(stream_dim, stream_dim) |
|
|
| |
| self.gate_a = nn.Parameter(torch.tensor(gate_init)) |
| self.gate_b = nn.Parameter(torch.tensor(gate_init)) |
|
|
| def forward(self, stream_a, stream_b, anchor_kv, geo_lifted): |
| B = stream_a.shape[0] |
| kv = anchor_kv.unsqueeze(0).expand(B, -1, -1) |
|
|
| qa = self.norm_ga(stream_a) |
| ha, _ = self.cross_attn_a(qa, kv, kv, need_weights=False) |
|
|
| qb = self.norm_gb(stream_b) |
| hb, _ = self.cross_attn_b(qb, kv, kv, need_weights=False) |
|
|
| stream_a = stream_a + self.gate_a * (ha + self.lift_proj_a(geo_lifted)) |
| stream_b = stream_b + self.gate_b * (hb + self.lift_proj_b(geo_lifted)) |
|
|
| return stream_a, stream_b |
|
|
|
|
| |
| |
| |
|
|
class TriStreamBlock(nn.Module):
    """One layer of the tri-stream model.

    Streams A and B each run an independent pre-norm self-attention + FFN
    sub-layer; the combined streams are then read out through the shared
    KSimplex channel, and geometry is injected back via the gated GALBlock.
    """

    def __init__(self, stream_dim, n_gal_anchors, n_heads,
                 gate_init=0.055, dropout=0.1):
        super().__init__()

        # Stream A transformer sub-layer.
        self.norm_a1 = nn.LayerNorm(stream_dim)
        self.attn_a = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.norm_a2 = nn.LayerNorm(stream_dim)
        self.ffn_a = nn.Sequential(
            nn.Linear(stream_dim, stream_dim * 4), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(stream_dim * 4, stream_dim), nn.Dropout(dropout))

        # Stream B: identical architecture, independent weights.
        self.norm_b1 = nn.LayerNorm(stream_dim)
        self.attn_b = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.norm_b2 = nn.LayerNorm(stream_dim)
        self.ffn_b = nn.Sequential(
            nn.Linear(stream_dim, stream_dim * 4), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(stream_dim * 4, stream_dim), nn.Dropout(dropout))

        # Geometric injection; the shared GAL state arrives at forward time.
        self.gal_block = GALBlock(
            stream_dim, n_gal_anchors, n_heads,
            gate_init=gate_init, dropout=dropout)

        self.geo_combine_norm = nn.LayerNorm(stream_dim)

    def forward(self, stream_a, stream_b, gal, anchor_kv):
        batch, n_patch, dim = stream_a.shape

        # Stream A: pre-norm self-attention, then FFN, both residual.
        normed = self.norm_a1(stream_a)
        attn_out, _ = self.attn_a(normed, normed, normed, need_weights=False)
        stream_a = stream_a + attn_out
        stream_a = stream_a + self.ffn_a(self.norm_a2(stream_a))

        # Stream B: same pattern with its own weights.
        normed = self.norm_b1(stream_b)
        attn_out, _ = self.attn_b(normed, normed, normed, need_weights=False)
        stream_b = stream_b + attn_out
        stream_b = stream_b + self.ffn_b(self.norm_b2(stream_b))

        # Per-patch KSimplex geometry from the combined, normalized streams.
        combined = self.geo_combine_norm(stream_a + stream_b)
        geo_feats, vol2 = gal.ksimplex(combined.reshape(batch * n_patch, dim))
        geo_feats = geo_feats.reshape(batch, n_patch, -1)
        vol2 = vol2.reshape(batch, n_patch)
        geo_lifted = gal.geo_lift(geo_feats)

        # Inject geometry into both streams through the gated GAL block.
        stream_a, stream_b = self.gal_block(
            stream_a, stream_b, anchor_kv, geo_lifted)

        return stream_a, stream_b, geo_feats, vol2, geo_lifted
|
|
|
|
| |
| |
| |
|
|
class TriStreamViT(nn.Module):
    """Tri-stream Vision Transformer with geometric arbitration (v8).

    Stream A trains with cross-entropy, Stream B with BCE + InfoNCE, and a
    geometric path accumulates KSimplex features across blocks (GAL).  All
    three paths are pooled to unit-sphere embeddings, fused, triangulated
    against a learnable anchor constellation, and classified.
    """
    def __init__(
        self,
        num_classes=10,
        img_size=32,
        patch_size=4,
        embed_dim=384,
        stream_dim=192,
        n_blocks=9,
        n_heads=8,
        output_dim=256,
        n_anchors=128,
        n_gal_anchors=64,
        n_comp=16,
        d_comp=128,
        anchor_drop=0.10,
        cv_target=0.22,
        ksimplex_k=4,
        ksimplex_edim=8,
        dropout=0.1,
        infonce_temp=0.07,
        infonce_weight=0.1,
        bce_weight=1.0,
        cm_weight=0.1,
        cv_weight=0.1,
        autograd_tang=1.0,
        autograd_sep=0.1,
        enable_autograd=True,
        label_smoothing=0.1,
        # v8: dedicated InfoNCE weights for Stream B and the geometric path.
        stream_b_nce_weight=0.5,
        geo_nce_weight=0.5,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_patches = (img_size // patch_size) ** 2
        self.stream_dim = stream_dim
        self.output_dim = output_dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp
        self.infonce_weight = infonce_weight
        self.bce_weight = bce_weight
        self.cm_weight = cm_weight
        self.cv_weight = cv_weight
        self.autograd_tang = autograd_tang
        self.autograd_sep = autograd_sep
        self.enable_autograd = enable_autograd
        self.label_smoothing = label_smoothing
        self.stream_b_nce_weight = stream_b_nce_weight
        self.geo_nce_weight = geo_nce_weight

        # Snapshot of constructor arguments.  At this point locals() holds
        # only 'self' and the parameters, so this records the full config.
        self.config = {k: v for k, v in locals().items()
                       if k != 'self' and not k.startswith('_')}

        # v8: non-zero gate init so geometry contributes from step one.
        gate_init = 1.0 / (2.0 * n_blocks)
        print(f" Gate init: {gate_init:.4f} (1/(2Γ{n_blocks}))")

        # Patchify via strided conv; learned positions, no CLS token.
        self.patch_embed = nn.Conv2d(
            3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches, embed_dim) * 0.02)

        # Separate projections from the patch width into each stream width.
        self.proj_a = nn.Sequential(
            nn.Linear(embed_dim, stream_dim), nn.LayerNorm(stream_dim))
        self.proj_b = nn.Sequential(
            nn.Linear(embed_dim, stream_dim), nn.LayerNorm(stream_dim))

        # Shared geometric state used by every block.
        self.gal = GAL(stream_dim, n_gal_anchors, n_heads,
                       ksimplex_k, ksimplex_edim, dropout)

        self.blocks = nn.ModuleList([
            TriStreamBlock(stream_dim, n_gal_anchors, n_heads,
                           gate_init=gate_init, dropout=dropout)
            for _ in range(n_blocks)])

        # Final per-stream norms before pooling.
        self.norm_a = nn.LayerNorm(stream_dim)
        self.norm_b = nn.LayerNorm(stream_dim)

        # Per-path heads projecting pooled features onto the output sphere.
        self.proj_sphere_a = nn.Sequential(
            nn.Linear(stream_dim, output_dim), nn.LayerNorm(output_dim))
        self.proj_sphere_b = nn.Sequential(
            nn.Linear(stream_dim, output_dim), nn.LayerNorm(output_dim))
        self.proj_sphere_geo = nn.Sequential(
            nn.Linear(stream_dim, output_dim), nn.LayerNorm(output_dim))

        # Triangulation anchors + piecewise encoder of the distance vector.
        self.constellation = Constellation(n_anchors, output_dim, anchor_drop)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        pw_dim = n_comp * d_comp

        # Classifier heads: A and B see patchwork features + their own
        # embedding; the geometric head sees only its embedding.
        self.classifier_a = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(dropout),
            nn.Linear(pw_dim, num_classes))

        self.classifier_b = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(dropout),
            nn.Linear(pw_dim, num_classes))

        self.geo_classifier = nn.Sequential(
            nn.Linear(output_dim, output_dim), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(output_dim, num_classes))

        self._init_weights()

    def _init_weights(self):
        # Truncated-normal init for Linear weights, standard LayerNorm init.
        # Runs after all submodules are built, so it also re-initializes them.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, apply_autograd=True):
        """x: (B, 3, H, W) images -> dict of embeddings, logits, diagnostics."""
        output = {}
        B = x.shape[0]

        # Patchify: (B, 3, H, W) -> (B, P, embed_dim), plus positions.
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)
        tokens = tokens + self.pos_embed
        P = tokens.shape[1]

        stream_a = self.proj_a(tokens)
        stream_b = self.proj_b(tokens)

        # Project the shared anchors once; reused by every block.
        anchor_kv = self.gal.get_anchor_kv()

        all_geo_feats = []
        all_vol2 = []
        geo_accum = torch.zeros_like(stream_a)  # running sum of lifted geometry

        for block in self.blocks:
            stream_a, stream_b, geo_feats, vol2, geo_lifted = block(
                stream_a, stream_b, self.gal, anchor_kv)
            all_geo_feats.append(geo_feats)
            all_vol2.append(vol2)
            geo_accum = geo_accum + geo_lifted

        output['geo_feats'] = all_geo_feats[-1]
        output['all_geo_feats'] = torch.stack(all_geo_feats)
        output['vol2'] = torch.stack(all_vol2)  # (n_blocks, B, P)

        stream_a = self.norm_a(stream_a)
        stream_b = self.norm_b(stream_b)

        # Mean-pool over patches.
        pool_a = stream_a.mean(dim=1)
        pool_b = stream_b.mean(dim=1)
        pool_geo = geo_accum.mean(dim=1)

        # Unit-sphere embeddings per path.
        emb_a = F.normalize(self.proj_sphere_a(pool_a), dim=-1)
        emb_b = F.normalize(self.proj_sphere_b(pool_b), dim=-1)
        geo_emb = F.normalize(self.proj_sphere_geo(pool_geo), dim=-1)

        # Fused embedding: renormalized sum of the three paths.
        emb = F.normalize(emb_a + emb_b + geo_emb, dim=-1)

        # Gradient shaping (training only): tangential projection + anchor
        # separation via the custom autograd Function.
        # NOTE(review): emb_a is deliberately(?) left unshaped — only the
        # fused, B, and geo embeddings pass through EmbeddingAutograd.
        if apply_autograd and self.training and self.enable_autograd:
            emb = EmbeddingAutograd.apply(
                emb, emb, self.constellation.anchors,
                self.autograd_tang, self.autograd_sep)
            emb_b = EmbeddingAutograd.apply(
                emb_b, emb_b, self.constellation.anchors,
                self.autograd_tang, self.autograd_sep)
            geo_emb = EmbeddingAutograd.apply(
                geo_emb, geo_emb, self.constellation.anchors,
                self.autograd_tang, self.autograd_sep)

        output['embedding'] = emb
        output['emb_a'] = emb_a
        output['emb_b'] = emb_b
        output['geo_emb'] = geo_emb
        output['pool_geo'] = pool_geo

        # Triangulation with the full anchor set feeds the classifiers
        # (anchor-drop is never applied to the patchwork input).
        tri_full, nearest_full = self.constellation.triangulate(
            emb, training=False)
        pw = self.patchwork(tri_full)
        output['triangulation'] = tri_full

        # 'nearest' uses the dropped-anchor view during training.
        if self.training:
            _, nearest = self.constellation.triangulate(emb, training=True)
        else:
            nearest = nearest_full
        output['nearest'] = nearest

        logits_a = self.classifier_a(torch.cat([pw, emb_a], dim=-1))
        logits_b = self.classifier_b(torch.cat([pw, emb_b], dim=-1))
        geo_logits = self.geo_classifier(geo_emb)

        output['logits_a'] = logits_a
        output['logits_b'] = logits_b
        output['geo_logits'] = geo_logits

        # Diagnostic: current gate values per block (plain floats).
        gates_a = [b.gal_block.gate_a.item() for b in self.blocks]
        gates_b = [b.gal_block.gate_b.item() for b in self.blocks]
        output['gates_a'] = gates_a
        output['gates_b'] = gates_b

        return output

    @torch.no_grad()
    def update_gal_anchors(self, simplex_buffer, lr=0.015, whiten=False):
        """Trainer-driven GAL anchor update.

        Matches each class centroid (from the buffer) to its most similar
        anchor, solves a Procrustes rotation from the matched anchors to the
        centroids, and moves ALL anchors a fraction `lr` toward their rotated
        positions (then renormalizes).  Returns the alignment score, or None
        while the buffer lacks class coverage.
        """
        with torch.amp.autocast("cuda", enabled=False):
            centroids = simplex_buffer.class_centroids(self.num_classes)
            if centroids is None:
                return None

            anchors = self.gal.gal_anchors.float()
            centroid_n = F.normalize(centroids.float(), dim=-1)
            anchor_n = F.normalize(anchors, dim=-1)
            # Greedy matching: each centroid picks its nearest anchor
            # (several centroids may select the same anchor).
            cos = centroid_n @ anchor_n.T
            matched_idx = cos.argmax(dim=1)
            matched_anchors = anchors[matched_idx]

            R, score = procrustes_align(
                matched_anchors, centroids.float(), whiten=whiten)

            # Blend toward the rotated configuration, stay on the sphere.
            rotated = anchors @ R
            new_anchors = F.normalize(
                anchors + lr * (rotated - anchors), dim=-1)
            self.gal.gal_anchors.copy_(
                new_anchors.to(self.gal.gal_anchors.dtype))

            return score

    def compute_loss(self, output, targets, output_aug=None,
                     mastery_queue=None):
        """Combine all loss terms.

        output      : dict from forward() on the clean view
        targets     : (B,) class indices
        output_aug  : optional dict from forward() on an augmented view
                      (enables the InfoNCE terms)
        mastery_queue: optional hard-example queue for the margin loss
        Returns (total_loss, loss_dict).
        """
        loss_dict = {}
        emb = output['embedding']
        emb_b = output['emb_b']
        geo_emb = output['geo_emb']
        B = emb.shape[0]
        is_mastery = mastery_queue is not None and mastery_queue.active

        # Stream A: standard cross-entropy.
        l_ce = F.cross_entropy(output['logits_a'], targets)
        loss_dict['ce'] = l_ce
        acc_a = (output['logits_a'].argmax(-1) == targets).float().mean().item()
        loss_dict['acc_a'] = acc_a

        # Stream B: BCE against (optionally label-smoothed) one-hot targets.
        one_hot = F.one_hot(targets, self.num_classes).float()
        ls = self.label_smoothing
        one_hot_smooth = one_hot * (1.0 - ls) + ls / self.num_classes if ls > 0 else one_hot
        l_bce = F.binary_cross_entropy_with_logits(
            output['logits_b'], one_hot_smooth)
        loss_dict['bce'] = l_bce
        acc_b = (output['logits_b'].argmax(-1) == targets).float().mean().item()
        loss_dict['acc_b'] = acc_b

        # Geometric head: same BCE objective.
        l_geo_bce = F.binary_cross_entropy_with_logits(
            output['geo_logits'], one_hot_smooth)
        loss_dict['geo_bce'] = l_geo_bce
        geo_acc = (output['geo_logits'].argmax(-1) == targets).float().mean().item()
        loss_dict['geo_acc'] = geo_acc

        # InfoNCE terms require the augmented second view.
        nce_acc = 0.0
        if output_aug is not None:
            labels_nce = torch.arange(B, device=emb.device)

            # Fused embedding vs its augmented counterpart.
            emb_aug = output_aug['embedding']
            sim = emb @ emb_aug.T / self.infonce_temp
            l_nce = F.cross_entropy(sim, labels_nce)
            nce_acc = (sim.argmax(1) == labels_nce).float().mean().item()
            loss_dict['nce'] = l_nce
            loss_dict['nce_acc'] = nce_acc

            # v8: InfoNCE directly on Stream B's embedding.
            emb_b_aug = output_aug.get('emb_b')
            if emb_b_aug is not None:
                sim_b = emb_b @ emb_b_aug.T / self.infonce_temp
                l_nce_b = F.cross_entropy(sim_b, labels_nce)
                nce_b_acc = (sim_b.argmax(1) == labels_nce).float().mean().item()
                loss_dict['nce_b'] = l_nce_b
                loss_dict['nce_b_acc'] = nce_b_acc

            # v8: InfoNCE on the geometric embedding (was starved in v7).
            geo_emb_aug = output_aug.get('geo_emb')
            if geo_emb_aug is not None:
                sim_g = geo_emb @ geo_emb_aug.T / self.infonce_temp
                l_geo_nce = F.cross_entropy(sim_g, labels_nce)
                geo_nce_acc = (sim_g.argmax(1) == labels_nce).float().mean().item()
                loss_dict['geo_nce'] = l_geo_nce
                loss_dict['geo_nce_acc'] = geo_nce_acc

        # Mastery margin loss: hardest negative vs hardest positive against
        # the queue, with a warmup-scheduled margin.
        if is_mastery:
            q_emb, q_labels = mastery_queue.get()
            if q_emb is not None and q_emb.shape[0] >= B:
                cross_sim = emb @ q_emb.T
                same_mask = targets.unsqueeze(1) == q_labels.unsqueeze(0)
                # Hardest negative: highest similarity among other classes.
                hn_sim = cross_sim.clone(); hn_sim[same_mask] = -1e9
                hn_cos = hn_sim.max(dim=1).values
                # Hardest positive: lowest similarity within the same class.
                hp_sim = cross_sim.clone(); hp_sim[~same_mask] = 1e9
                hp_cos = hp_sim.min(dim=1).values
                # Only rows with at least one positive AND one negative count.
                valid = same_mask.any(1) & (~same_mask).any(1)
                if valid.sum() > 0:
                    margin = mastery_queue.current_margin
                    l_mastery = F.relu(
                        hn_cos[valid] - hp_cos[valid] + margin).mean()
                    loss_dict['mastery'] = l_mastery
                    loss_dict['hard_neg_cos'] = hn_cos[valid].mean().item()
                    loss_dict['hard_pos_cos'] = hp_cos[valid].mean().item()
                    loss_dict['margin'] = margin
            mastery_queue.push(emb.detach(), targets.detach())

        # Cayley-Menger validity: penalize negative squared volumes
        # (degenerate/invalid simplices).
        vol2 = output['vol2']
        l_cm = F.relu(-vol2).mean()
        loss_dict['cm'] = l_cm
        loss_dict['cm_valid'] = (vol2 > 0).float().mean().item()

        # Coefficient-of-variation regularizer on sampled simplex volumes.
        l_cv_main = self._cv_loss_fast(emb, target=self.cv_target)
        l_cv_geo = self._cv_loss_fast(geo_emb, target=self.cv_target)
        l_cv = l_cv_main + l_cv_geo
        loss_dict['cv'] = l_cv
        loss_dict['cv_main'] = l_cv_main.item() if torch.is_tensor(l_cv_main) else l_cv_main
        loss_dict['cv_geo'] = l_cv_geo.item() if torch.is_tensor(l_cv_geo) else l_cv_geo

        # Constellation spread: penalize positive off-diagonal cosine
        # similarity between anchors.
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        anchor_sim = anchors_n @ anchors_n.T
        mask_a = ~torch.eye(anchors_n.shape[0], dtype=torch.bool,
                            device=anchors_n.device)
        l_spread = F.relu(anchor_sim[mask_a] - 0.0).mean()
        loss_dict['spread'] = l_spread

        # Weighted total.
        # NOTE(review): CE is scaled by bce_weight (there is no separate CE
        # weight) — confirm this is intentional.
        loss = (l_ce * self.bce_weight
                + l_bce * self.bce_weight
                + l_geo_bce * self.bce_weight
                + loss_dict.get('nce', 0.0) * self.infonce_weight
                + loss_dict.get('nce_b', 0.0) * self.stream_b_nce_weight
                + loss_dict.get('geo_nce', 0.0) * self.geo_nce_weight
                + loss_dict.get('mastery', 0.0) * self.bce_weight
                + l_cm * self.cm_weight
                + l_cv * self.cv_weight
                + l_spread * 0.001)

        loss_dict['total'] = loss
        return loss, loss_dict

    @staticmethod
    def _cv_loss_fast(emb, target=0.22, n_samples=64, n_points=5):
        """Push the coefficient of variation of sampled simplex volumes
        toward `target`.

        Randomly samples n_points embeddings n_samples times, computes each
        sample's Cayley-Menger volume, and penalizes (cv - target)^2 over
        the volumes that are numerically positive.  Returns 0 when the batch
        is too small or too few valid volumes were found.
        """
        B = emb.shape[0]
        if B < n_points:
            return torch.tensor(0.0, device=emb.device)
        vols = []
        for _ in range(n_samples):
            # NOTE(review): permutation is over min(B, 512), so rows past
            # index 511 are never sampled for very large batches.
            idx = torch.randperm(min(B, 512), device=emb.device)[:n_points]
            pts = emb[idx].unsqueeze(0)
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)
            N = n_points
            # Bordered Cayley-Menger matrix (see CMValidator).
            cm = torch.zeros(1, N + 1, N + 1,
                             device=emb.device, dtype=emb.dtype)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            k = N - 1
            sign = (-1.0) ** (k + 1)
            fact = math.factorial(k)
            prefactor = sign / ((2.0 ** k) * (fact ** 2))
            vol2 = prefactor * torch.linalg.det(cm.float())
            # Keep only clearly positive volumes (grad still flows via sqrt).
            if vol2[0].item() > 1e-20:
                vols.append(vol2[0].to(emb.dtype).sqrt())
        if len(vols) < 5:
            return torch.tensor(0.0, device=emb.device)
        vols_t = torch.stack(vols)
        cv = vols_t.std() / (vols_t.mean() + 1e-8)
        return (cv - target).pow(2)
|
|
|
|
| |
| |
| |
|
|
class MasteryQueue:
    """Self-activating hard-example queue for the mastery margin loss.

    Stays dormant until InfoNCE accuracy has been >= 0.99 for `patience`
    consecutive batches; once active, stores recent embeddings/labels and
    adaptively resizes itself based on the train/val accuracy gap.  The
    margin used by the loss warms up from margin_start to margin_end over
    margin_warmup active steps.
    """
    def __init__(self, dim, min_size=1024, max_size=8192, initial_size=4096,
                 patience=50, device='cuda',
                 margin_start=0.1, margin_end=0.3, margin_warmup=5000,
                 resize_step=1024, resize_cooldown=5, overfit_threshold=3.0):
        self.dim = dim
        self.min_size = min_size; self.max_size = max_size
        self._current_max = initial_size          # current capacity (adaptive)
        self.patience = patience; self.device = device
        self.active = False
        self._embs = None; self._labels = None
        self._perfect_count = 0; self._total_batches = 0
        self._activated_at = None
        self._margin_start = margin_start
        self._margin_end = margin_end
        self._margin_warmup = margin_warmup
        self._mastery_steps = 0                   # batches since activation
        self._resize_step = resize_step
        self._resize_cooldown = resize_cooldown   # epochs between resizes
        self._overfit_threshold = overfit_threshold
        self._epochs_since_resize = resize_cooldown
        self._gap_history = []; self._gap_window = 5
        self._resize_history = []

    def check_activation(self, nce_acc):
        """Call once per batch with the current InfoNCE accuracy; flips the
        queue to active after `patience` consecutive near-perfect batches."""
        self._total_batches += 1
        if nce_acc >= 0.99:
            self._perfect_count += 1
        else:
            self._perfect_count = 0               # streak broken
        if not self.active and self._perfect_count >= self.patience:
            self.active = True
            self._activated_at = self._total_batches
            print(f"\n β MASTERY ACTIVATED at batch {self._total_batches} "
                  f"(nce_acc=1.0 for {self.patience} consecutive) "
                  f"queue={self._current_max}")
        if self.active:
            self._mastery_steps += 1

    def update_size(self, train_acc, val_acc, epoch):
        """Adapt queue capacity from the train/val accuracy gap (percent).

        Grows on a large gap (overfitting -> harder negatives needed),
        shrinks on a small stable gap, and also reacts to gap drift over
        the last `_gap_window` epochs.  Resizes at most once per cooldown.
        """
        if not self.active: return
        self._epochs_since_resize += 1
        gap = train_acc - val_acc
        self._gap_history.append((epoch, gap))
        if self._epochs_since_resize < self._resize_cooldown: return
        old_size = self._current_max; reason = None
        if gap > self._overfit_threshold * 2:
            # Strong overfitting: enlarge the pool of hard examples.
            self._current_max = min(self._current_max + self._resize_step, self.max_size)
            reason = f"grow: gap={gap:.1f}%"
        elif gap < self._overfit_threshold and gap > 0:
            # Small positive gap sustained over the window: shrink.
            if len(self._gap_history) >= self._gap_window:
                recent = [g for _, g in self._gap_history[-self._gap_window:]]
                if all(0 < g < self._overfit_threshold for g in recent):
                    self._current_max = max(self._current_max - self._resize_step, self.min_size)
                    reason = f"shrink: stable gap={gap:.1f}%"
        if reason is None and len(self._gap_history) >= self._gap_window:
            # No absolute trigger: react to the gap's drift over the window.
            drift = gap - self._gap_history[-self._gap_window][1]
            if drift > self._overfit_threshold:
                self._current_max = min(self._current_max + self._resize_step, self.max_size)
                reason = f"drift: {drift:+.1f}%"
            elif drift < -self._overfit_threshold and gap > 0:
                self._current_max = max(self._current_max - self._resize_step, self.min_size)
                reason = f"drift: {drift:+.1f}%"
        if self._current_max != old_size:
            # NOTE(review): both arrow strings are the same mojibake char
            # ("β"); originally these were likely distinct up/down arrows.
            d = "β" if self._current_max > old_size else "β"
            print(f" β Queue {d} {old_size}β{self._current_max} ({reason})")
            self._epochs_since_resize = 0
            self._resize_history.append((epoch, old_size, self._current_max, gap, reason))
            if self._embs is not None and self._embs.shape[0] > self._current_max:
                # Drop oldest entries to fit the new capacity.
                self._embs = self._embs[-self._current_max:]
                self._labels = self._labels[-self._current_max:]

    @property
    def current_margin(self):
        # Linear warmup from margin_start to margin_end over margin_warmup
        # active steps; before activation, the start margin is reported.
        if not self.active: return self._margin_start
        t = min(self._mastery_steps / max(self._margin_warmup, 1), 1.0)
        return self._margin_start + t * (self._margin_end - self._margin_start)

    def push(self, emb, labels):
        """Append a detached batch, keeping the newest _current_max rows."""
        emb = emb.detach().to(self.device)
        labels = labels.detach().to(self.device)
        if self._embs is None:
            self._embs = emb; self._labels = labels
        else:
            self._embs = torch.cat([self._embs, emb], 0)[-self._current_max:]
            self._labels = torch.cat([self._labels, labels], 0)[-self._current_max:]

    def get(self):
        """Return (embeddings, labels), or (None, None) while empty."""
        if self._embs is None: return None, None
        return self._embs, self._labels

    @property
    def size(self):
        # Number of rows currently stored.
        return 0 if self._embs is None else self._embs.shape[0]

    def state_dict(self):
        """Lightweight, JSON-friendly snapshot for logging/checkpointing
        (does NOT include the stored embeddings)."""
        return {
            'active': self.active, 'total_batches': self._total_batches,
            'activated_at': self._activated_at,
            'mastery_steps': self._mastery_steps,
            'current_margin': self.current_margin,
            'current_max': self._current_max,
            'gap_history': self._gap_history[-20:],
            'resize_history': self._resize_history,
        }
|
|
|
|
| |
| |
| |
|
|
def create_tri_stream_vit(**kwargs):
    """Factory: build a TriStreamViT, forwarding all keyword arguments."""
    return TriStreamViT(**kwargs)