| |
| """ |
| GeoLIP Dual-Stream ViT |
| ======================= |
| Two parallel streams that cross-attend at bottlenecks: |
| Stream A (geometric): KSimplexChannel β geometric features β self-attn |
| Stream B (standard): learned projections β self-attn |
| |
| Architecture: |
| Shared encoder: patch_embed + pos_embed (no transformer blocks β raw patches) |
| β Split into geo_stream and std_stream |
| β 2Γ DualStreamBlock (self-attn + cross-attn per stream) |
| β Fuse: concat β proj |
| β 4Γ FusedBlock (standard transformer) |
| β Pool + InfoNCE + Constellation + Classifier |
| |
| The geometric structure survives because it has its own stream for 2 blocks. |
| Cross-attention lets info flow without mixing representations. |
| Fused blocks merge the two with the geometric signal already established. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from itertools import combinations |
|
|
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
| |
| |
| |
|
|
class CMValidator(nn.Module):
    """Batch-friendly Cayley-Menger determinant.

    Given simplex vertices, computes every pairwise squared edge length
    plus the squared k-volume via the Cayley-Menger determinant.
    """

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1
        # Upper-triangular vertex index pairs, split into two index buffers
        # so edge lengths can be gathered with advanced indexing.
        idx_pairs = list(combinations(range(self._nv), 2))
        self._npairs = len(idx_pairs)
        row_idx = torch.tensor([a for a, _ in idx_pairs], dtype=torch.long)
        col_idx = torch.tensor([b for _, b in idx_pairs], dtype=torch.long)
        self.register_buffer('_pi', row_idx)
        self.register_buffer('_pj', col_idx)
        # vol^2 = prefactor * det(CM) with
        # prefactor = (-1)^(k+1) / (2^k * (k!)^2).
        self._prefactor = ((-1.0) ** (k + 1)) / (
            (2.0 ** k) * (math.factorial(k) ** 2))

    def forward(self, verts):
        """verts: (..., nv, edim) -> (d2_pairs (..., npairs), vol2 (...))."""
        gram = torch.einsum('...ve,...we->...vw', verts, verts)
        sq_norms = torch.diagonal(gram, dim1=-2, dim2=-1)
        # ||v_i - v_j||^2 = |v_i|^2 + |v_j|^2 - 2<v_i, v_j>; clamp away
        # tiny negatives caused by round-off.
        sq_dist = F.relu(
            sq_norms.unsqueeze(-1) + sq_norms.unsqueeze(-2) - 2 * gram)
        d2_pairs = sq_dist[..., self._pi, self._pj]
        batch_shape = sq_dist.shape[:-2]
        nv = sq_dist.shape[-1]
        # Bordered Cayley-Menger matrix: zero corner, ones on first row/col.
        cm = torch.zeros(*batch_shape, nv + 1, nv + 1,
                         device=sq_dist.device, dtype=sq_dist.dtype)
        cm[..., 0, 1:] = 1.0
        cm[..., 1:, 0] = 1.0
        cm[..., 1:, 1:] = sq_dist
        # Determinant in float32 for stability, cast back afterwards.
        vol2 = (self._prefactor * torch.linalg.det(cm.float())).to(d2_pairs.dtype)
        return d2_pairs, vol2
|
|
|
|
class KSimplexChannel(nn.Module):
    """Per-position simplex encoder. k=4: 11 geometric features."""

    # Scale on the predicted vertex offsets, keeping the deformed simplex
    # close to the valid (positive-volume) template.
    BASE_DEFORM = 0.05

    def __init__(self, k, in_dim, edim):
        super().__init__()
        self._k = k
        self._nv = k + 1
        self._edim = edim
        self._cm = CMValidator(k)
        # One feature per edge (squared length) plus the squared volume.
        self._out_dim = self._cm._npairs + 1
        self.register_buffer('_template', self._make_regular_simplex(k, edim))
        self._to_deform = nn.Linear(in_dim, self._nv * edim)
        self._norm = nn.LayerNorm(self._out_dim)

    @staticmethod
    def _make_regular_simplex(k, edim):
        """Centered unit-edge simplex template: k+1 vertices in R^edim.

        Regular (all edges equal) only when k + 1 <= edim; the overflow
        branch fills in random unit vertices and is non-deterministic.
        """
        n_verts = k + 1
        template = torch.zeros(n_verts, edim)
        for i in range(min(n_verts, edim)):
            template[i, i] = 1.0
        if n_verts > edim:
            for i in range(edim, n_verts):
                extra = torch.randn(edim)
                template[i] = extra / (extra.norm() + 1e-8)
        # Center at the origin, then rescale so the first edge has length 1.
        template = template - template.mean(dim=0, keepdim=True)
        first_edge = (template[0] - template[1]).norm().clamp(min=1e-8)
        return template / first_edge

    @property
    def out_dim(self):
        """Width of the geometric feature vector (n_edges + 1)."""
        return self._out_dim

    def forward(self, x):
        """x: (..., in_dim) -> (geo (..., out_dim), vol2 (...))."""
        offsets = self._to_deform(x).unflatten(-1, (self._nv, self._edim))
        verts = self._template + self.BASE_DEFORM * offsets
        d2, vol2 = self._cm(verts)
        feats = self._norm(torch.cat([d2, vol2.unsqueeze(-1)], dim=-1))
        return feats, vol2
|
|
|
|
| |
| |
| |
|
|
class Constellation(nn.Module):
    """Learned unit anchors; embeddings are located by cosine to each."""

    def __init__(self, n_anchors, dim, anchor_drop=0.0):
        super().__init__()
        self.anchors = nn.Parameter(torch.randn(n_anchors, dim))
        nn.init.normal_(self.anchors, 0, 1.0 / dim ** 0.5)
        self.anchor_drop = anchor_drop

    def triangulate(self, emb, training=False):
        """Return (1 - cosine) distances to anchors and nearest-anchor index.

        With training=True and anchor_drop > 0, a random anchor subset is
        dropped; nearest indices are mapped back to the full anchor set,
        but the distance matrix covers only the surviving anchors.
        """
        unit_anchors = F.normalize(self.anchors, dim=-1)
        if not (training and self.anchor_drop > 0):
            cos = emb @ unit_anchors.T
            return 1.0 - cos, cos.argmax(dim=-1)
        keep = torch.rand(unit_anchors.shape[0],
                          device=unit_anchors.device) > self.anchor_drop
        if keep.sum() < 2:
            # Always triangulate against at least two anchors.
            keep[:2] = True
        cos = emb @ unit_anchors[keep].T
        local_best = cos.argmax(dim=-1)
        # Translate positions within the kept subset back to global indices.
        global_ids = keep.nonzero(as_tuple=True)[0]
        return 1.0 - cos, global_ids[local_best]
|
|
|
|
class Patchwork(nn.Module):
    """Round-robin split of anchor distances into per-component MLPs."""

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        self.d_comp = d_comp
        # Anchor i feeds component i % n_comp (round-robin assignment).
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)
        # NOTE(review): assumes n_anchors is divisible by n_comp; otherwise
        # the per-component Linear widths won't match the selected slices.
        per_comp = n_anchors // n_comp

        def _mlp():
            return nn.Sequential(
                nn.Linear(per_comp, d_comp * 2), nn.GELU(),
                nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))

        self.comps = nn.ModuleList([_mlp() for _ in range(n_comp)])

    def forward(self, tri):
        """tri: (B, n_anchors) -> (B, n_comp * d_comp)."""
        pieces = []
        for k in range(self.n_comp):
            pieces.append(self.comps[k](tri[:, self.asgn == k]))
        return torch.cat(pieces, dim=-1)
|
|
|
|
| |
| |
| |
|
|
class EmbeddingAutograd(torch.autograd.Function):
    """Geometric autograd: tangential projection + anchor separation.

    Identity in the forward pass. In the backward pass the gradient
    w.r.t. `x` is reshaped: the radial component (along the embedding
    direction) is attenuated by `tang`, and the component pushing toward
    the nearest anchor is damped by `sep`.
    """
    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        # Pass-through; only stash what backward needs.
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang; ctx.sep = sep
        return x


    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        # Detached float32 copies: this backward shapes gradients but must
        # not itself create autograd history.
        emb_n = F.normalize(embedding.detach().float(), dim=-1)
        anchors_n = F.normalize(anchors.detach().float(), dim=-1)
        grad_f = grad_output.float()
        # Radial part = projection of the gradient onto the unit embedding;
        # keep only (1 - tang) of it, leaving the tangential part intact.
        radial = (grad_f * emb_n).sum(-1, keepdim=True) * emb_n
        corrected = (grad_f - radial) + (1.0 - ctx.tang) * radial
        if ctx.sep > 0:
            # Damp the gradient component aimed at each sample's nearest
            # (highest-cosine) anchor — only when it actually points toward
            # the anchor (toward > 0).
            cos_to = emb_n @ anchors_n.T
            nearest = anchors_n[cos_to.argmax(dim=-1)]
            toward = (corrected * nearest).sum(-1, keepdim=True)
            corrected = corrected - ctx.sep * (toward > 0).float() * toward * nearest
        # Gradient flows only to x; embedding/anchors/tang/sep get None.
        return corrected.to(grad_output.dtype), None, None, None, None
|
|
|
|
| |
| |
| |
|
|
class DualStreamBlock(nn.Module):
    """
    Two parallel streams with self-attention + cross-attention.

    Geo stream: self_attn → KSimplex → cross_attn(q=geo, kv=std) → FFN
    Std stream: self_attn → cross_attn(q=std, kv=geo) → FFN

    Cross-attention is the bottleneck where info flows between streams.
    """
    def __init__(self, stream_dim, geo_dim, n_heads, ksimplex_k=4,
                 ksimplex_edim=8, dropout=0.1):
        super().__init__()
        self.stream_dim = stream_dim
        # NOTE(review): geo_dim is stored but never read below; the
        # geometric feature width actually comes from
        # self.geo_ksimplex.out_dim.
        self.geo_dim = geo_dim

        # --- Geometric stream modules ---
        self.geo_norm1 = nn.LayerNorm(stream_dim)
        self.geo_self_attn = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.geo_ksimplex = KSimplexChannel(
            k=ksimplex_k, in_dim=stream_dim, edim=ksimplex_edim)
        # Lifts the low-dimensional geometric features back up to
        # stream_dim so they can be added residually to the geo stream.
        self.geo_lift = nn.Sequential(
            nn.Linear(self.geo_ksimplex.out_dim, stream_dim), nn.GELU())
        self.geo_norm2 = nn.LayerNorm(stream_dim)
        self.geo_cross_attn = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.geo_norm3 = nn.LayerNorm(stream_dim)
        self.geo_ffn = nn.Sequential(
            nn.Linear(stream_dim, stream_dim * 4), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(stream_dim * 4, stream_dim), nn.Dropout(dropout))

        # --- Standard stream modules ---
        self.std_norm1 = nn.LayerNorm(stream_dim)
        self.std_self_attn = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.std_norm2 = nn.LayerNorm(stream_dim)
        self.std_cross_attn = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.std_norm3 = nn.LayerNorm(stream_dim)
        self.std_ffn = nn.Sequential(
            nn.Linear(stream_dim, stream_dim * 4), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(stream_dim * 4, stream_dim), nn.Dropout(dropout))


    def forward(self, geo_stream, std_stream):
        """
        geo_stream: (B, P, stream_dim)
        std_stream: (B, P, stream_dim)
        Returns: geo_stream, std_stream, geo_feats (B, P, 11), vol2 (B, P)
        """
        B, P, _ = geo_stream.shape

        # Geo stream: pre-norm self-attention with residual add.
        h = self.geo_norm1(geo_stream)
        h, _ = self.geo_self_attn(h, h, h, need_weights=False)
        geo_stream = geo_stream + h

        # Per-token simplex features — flatten tokens so KSimplexChannel
        # sees a plain (B*P, stream_dim) batch.
        flat = geo_stream.reshape(B * P, -1)
        geo_feats, vol2 = self.geo_ksimplex(flat)
        geo_feats = geo_feats.reshape(B, P, -1)
        vol2 = vol2.reshape(B, P)
        # Residually inject the lifted geometric features.
        geo_stream = geo_stream + self.geo_lift(geo_feats)

        # Geo queries cross-attend to the (normed) std stream.
        h = self.geo_norm2(geo_stream)
        std_ctx = self.std_norm2(std_stream)
        h, _ = self.geo_cross_attn(h, std_ctx, std_ctx, need_weights=False)
        geo_stream = geo_stream + h

        # Geo FFN.
        geo_stream = geo_stream + self.geo_ffn(self.geo_norm3(geo_stream))

        # Std stream: pre-norm self-attention with residual add.
        h = self.std_norm1(std_stream)
        h, _ = self.std_self_attn(h, h, h, need_weights=False)
        std_stream = std_stream + h

        # Std queries cross-attend to the already-updated geo stream.
        # NOTE(review): geo_norm2 (the geo stream's own cross-attn norm)
        # is reused here to norm the context — confirm the sharing is
        # intentional rather than a copy-paste of geo_norm2 vs a dedicated
        # context norm.
        h2 = self.std_norm2(std_stream)
        geo_ctx = self.geo_norm2(geo_stream)
        h2, _ = self.std_cross_attn(h2, geo_ctx, geo_ctx, need_weights=False)
        std_stream = std_stream + h2

        # Std FFN.
        std_stream = std_stream + self.std_ffn(self.std_norm3(std_stream))


        return geo_stream, std_stream, geo_feats, vol2
|
|
|
|
class FusedBlock(nn.Module):
    """Standard pre-norm transformer block on the fused stream."""

    def __init__(self, dim, n_heads, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(dim * 4, dim), nn.Dropout(dropout))

    def forward(self, x):
        """Pre-norm self-attention and MLP, each with a residual add."""
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        return x + self.ffn(self.norm2(x))
|
|
|
|
| |
| |
| |
|
|
class DualStreamViT(nn.Module):
    """
    GeoLIP Dual-Stream Vision Transformer.

    Architecture:
    patch_embed + pos → (B, 64, embed_dim)
    → geo_proj, std_proj → two streams at stream_dim
    → 2× DualStreamBlock (self-attn + cross-attn + KSimplex)
    → fuse: concat(geo, std) → proj to fused_dim
    → 4× FusedBlock
    → pool + constellation + InfoNCE + classifier
    """
    def __init__(
        self,
        num_classes=10,
        img_size=32,
        patch_size=4,
        embed_dim=384,
        stream_dim=192,
        fused_dim=256,
        n_dual_blocks=2,
        n_fused_blocks=4,
        n_heads=8,
        output_dim=128,
        n_anchors=64,
        n_comp=8,
        d_comp=64,
        anchor_drop=0.10,
        cv_target=0.22,
        ksimplex_k=4,
        ksimplex_edim=8,
        dropout=0.1,
        infonce_temp=0.07,
        infonce_weight=1.0,
        bce_weight=1.0,
        cm_weight=0.1,
        cv_weight=0.01,
        autograd_tang=0.5,
        autograd_sep=0.1,
        enable_autograd=True,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_patches = (img_size // patch_size) ** 2
        self.stream_dim = stream_dim
        self.fused_dim = fused_dim
        self.output_dim = output_dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp
        self.infonce_weight = infonce_weight
        self.bce_weight = bce_weight
        self.cm_weight = cm_weight
        self.cv_weight = cv_weight
        self.autograd_tang = autograd_tang
        self.autograd_sep = autograd_sep
        self.enable_autograd = enable_autograd


        # Snapshot of the constructor arguments: at this point locals()
        # still holds only the parameters ('self' and dunders filtered).
        self.config = {k: v for k, v in locals().items()
                       if k != 'self' and not k.startswith('_')}


        # Shared patch encoder: non-overlapping conv patchify plus a
        # learned absolute position embedding (no CLS token; mean-pooled
        # later in forward()).
        self.patch_embed = nn.Conv2d(
            3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)


        # Per-stream input projections.
        self.geo_proj = nn.Sequential(
            nn.Linear(embed_dim, stream_dim), nn.LayerNorm(stream_dim))
        self.std_proj = nn.Sequential(
            nn.Linear(embed_dim, stream_dim), nn.LayerNorm(stream_dim))


        # Dual-stream stage. geo_dim = 11 matches KSimplexChannel with
        # k=4: C(5, 2) squared edge lengths + 1 squared volume.
        geo_dim = 11
        self.dual_blocks = nn.ModuleList([
            DualStreamBlock(stream_dim, geo_dim, n_heads,
                            ksimplex_k, ksimplex_edim, dropout)
            for _ in range(n_dual_blocks)])


        # Fusion: concatenate both streams and project to fused_dim.
        self.fuse_proj = nn.Sequential(
            nn.Linear(stream_dim * 2, fused_dim),
            nn.LayerNorm(fused_dim), nn.GELU())


        # Post-fusion standard transformer stage.
        self.fused_blocks = nn.ModuleList([
            FusedBlock(fused_dim, n_heads, dropout)
            for _ in range(n_fused_blocks)])
        self.fused_norm = nn.LayerNorm(fused_dim)


        # Projection into the (unit-normalized) contrastive embedding.
        self.output_proj = nn.Sequential(
            nn.Linear(fused_dim, output_dim),
            nn.LayerNorm(output_dim))


        # Anchor triangulation head.
        self.constellation = Constellation(n_anchors, output_dim, anchor_drop)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        pw_dim = n_comp * d_comp


        # Classifier consumes both patchwork features and the embedding.
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(dropout),
            nn.Linear(pw_dim, num_classes))


        self._init_weights()


    def _init_weights(self):
        # ViT-style init: truncated-normal Linear weights, zero biases,
        # identity LayerNorms. Overrides the defaults set by submodules.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)


    def forward(self, x, targets=None, apply_autograd=True):
        """
        Args:
            x: (B, 3, H, W)
            targets: (B,) class indices (optional, for loss)
        Returns:
            dict with logits, embedding, geo_feats, vol2, etc.
        """
        # NOTE(review): `targets` is accepted but unused here; losses are
        # computed separately in compute_loss().
        output = {}
        B = x.shape[0]


        # Patchify: (B, 3, H, W) -> (B, P, embed_dim), add positions.
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)
        tokens = tokens + self.pos_embed
        P = tokens.shape[1]


        # Split into the geometric and standard streams.
        geo_stream = self.geo_proj(tokens)
        std_stream = self.std_proj(tokens)


        # Dual-stream stage; collect per-block geometric diagnostics.
        all_geo_feats = []
        all_vol2 = []
        for block in self.dual_blocks:
            geo_stream, std_stream, geo_feats, vol2 = block(
                geo_stream, std_stream)
            all_geo_feats.append(geo_feats)
            all_vol2.append(vol2)


        output['geo_feats'] = all_geo_feats[-1]
        output['all_geo_feats'] = torch.stack(all_geo_feats)
        output['vol2'] = torch.stack(all_vol2)  # (n_dual_blocks, B, P)


        # Fuse the two streams.
        fused = self.fuse_proj(
            torch.cat([geo_stream, std_stream], dim=-1))


        # Standard transformer stage on the fused tokens.
        for block in self.fused_blocks:
            fused = block(fused)
        fused = self.fused_norm(fused)


        # Mean-pool over patches (no CLS token).
        pooled = fused.mean(dim=1)


        # Unit-norm embedding used for InfoNCE and triangulation.
        emb = F.normalize(self.output_proj(pooled), dim=-1)


        # Optionally reshape gradients through the embedding; identity in
        # the forward pass (see EmbeddingAutograd).
        if (apply_autograd and self.training and self.enable_autograd):
            emb = EmbeddingAutograd.apply(
                emb, emb, self.constellation.anchors,
                self.autograd_tang, self.autograd_sep)


        output['embedding'] = emb


        # Pooled geometric stream, exposed for inspection.
        geo_pooled = geo_stream.mean(dim=1)
        output['geo_pooled'] = geo_pooled


        # Full-anchor triangulation feeds the patchwork classifier head
        # (anchor dropout off so the width matches Patchwork's Linears).
        tri_full, nearest_full = self.constellation.triangulate(
            emb, training=False)
        pw = self.patchwork(tri_full)
        output['triangulation'] = tri_full


        # During training, report the nearest anchor under anchor dropout.
        if self.training:
            _, nearest = self.constellation.triangulate(emb, training=True)
        else:
            nearest = nearest_full
        output['nearest'] = nearest


        # Logits from patchwork features + embedding.
        logits = self.classifier(torch.cat([pw, emb], dim=-1))
        output['logits'] = logits


        # Per-patch nearest anchors, diagnostics only (no gradients).
        with torch.no_grad():
            patch_embs = F.normalize(
                self.output_proj(fused.reshape(B * P, -1)), dim=-1)
            patch_embs = patch_embs.reshape(B, P, -1)
            anchors_n = F.normalize(self.constellation.anchors, dim=-1)
            patch_cos = torch.einsum('bpd,ad->bpa', patch_embs, anchors_n)
            output['patch_nearest'] = patch_cos.argmax(dim=-1)
            output['patch_embs'] = patch_embs


        return output


    def compute_loss(self, output, targets, output_aug=None):
        """
        Compute loss with InfoNCE between two augmented views.

        Args:
            output: dict from forward(view1)
            targets: (B,) class indices
            output_aug: dict from forward(view2) — optional, for InfoNCE
        Returns:
            loss, loss_dict
        """
        loss_dict = {}
        emb = output['embedding']


        # Classification term: BCE over one-hot targets (per-class
        # sigmoid objective rather than softmax cross-entropy).
        one_hot = F.one_hot(targets, self.num_classes).float()
        l_bce = F.binary_cross_entropy_with_logits(output['logits'], one_hot)
        loss_dict['bce'] = l_bce


        # InfoNCE between the two views; positives are matching batch rows.
        if output_aug is not None:
            emb_aug = output_aug['embedding']

            sim = emb @ emb_aug.T / self.infonce_temp
            labels_nce = torch.arange(emb.shape[0], device=emb.device)
            l_nce = F.cross_entropy(sim, labels_nce)
            nce_acc = (sim.argmax(1) == labels_nce).float().mean()
            loss_dict['nce'] = l_nce
            loss_dict['nce_acc'] = nce_acc.item()


        # Cayley-Menger validity: penalize negative squared volumes
        # (degenerate simplices) from the dual-stream blocks.
        vol2 = output['vol2']
        l_cm = F.relu(-vol2).mean()
        loss_dict['cm'] = l_cm
        loss_dict['cm_valid'] = (vol2 > 0).float().mean().item()


        # Coefficient-of-variation regularizer on sampled simplex volumes
        # in embedding space.
        l_cv = self._cv_loss_fast(emb, target=self.cv_target)
        loss_dict['cv'] = l_cv


        # Anchor spread: penalize positive cosine between distinct anchors.
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        anchor_sim = anchors_n @ anchors_n.T
        mask_a = ~torch.eye(anchors_n.shape[0], dtype=torch.bool,
                            device=anchors_n.device)
        l_spread = F.relu(anchor_sim[mask_a] - 0.0).mean()
        loss_dict['spread'] = l_spread


        # Weighted total; 'nce' contributes 0.0 when no second view given.
        loss = (l_bce * self.bce_weight
                + loss_dict.get('nce', 0.0) * self.infonce_weight
                + l_cm * self.cm_weight
                + l_cv * self.cv_weight
                + l_spread * 0.001)


        loss_dict['total'] = loss
        return loss, loss_dict


    @staticmethod
    def _cv_loss_fast(emb, target=0.22, n_samples=64, n_points=5):
        """Fast differentiable CV loss from random pentachora."""
        # Samples random 5-point subsets of the batch, measures their
        # simplex volumes via Cayley-Menger, and drives the coefficient of
        # variation of those volumes toward `target`.
        B = emb.shape[0]
        if B < n_points:
            return torch.tensor(0.0, device=emb.device)
        vols = []
        for _ in range(n_samples):
            # NOTE(review): randperm(min(B, 512)) only ever samples the
            # first 512 rows of emb when B > 512 — confirm intended.
            idx = torch.randperm(min(B, 512), device=emb.device)[:n_points]
            pts = emb[idx].unsqueeze(0)
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)
            N = n_points
            cm = torch.zeros(1, N + 1, N + 1,
                             device=emb.device, dtype=emb.dtype)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            # Cayley-Menger prefactor for a k = N-1 simplex.
            k = N - 1
            sign = (-1.0) ** (k + 1)
            fact = math.factorial(k)
            prefactor = sign / ((2.0 ** k) * (fact ** 2))
            vol2 = prefactor * torch.linalg.det(cm.float())
            # Keep only clearly non-degenerate simplices; the threshold
            # guards the sqrt below.
            if vol2[0].item() > 1e-20:
                vols.append(vol2[0].to(emb.dtype).sqrt())
        if len(vols) < 5:
            return torch.tensor(0.0, device=emb.device)
        vols_t = torch.stack(vols)
        cv = vols_t.std() / (vols_t.mean() + 1e-8)
        return (cv - target).pow(2)
|
|
|
|
| |
| |
| |
|
|
def create_dual_stream_vit(**kwargs):
    """Factory: build a DualStreamViT, forwarding all keyword overrides."""
    model = DualStreamViT(**kwargs)
    return model