# geolip-vit-dual-stream / run2 / dual_stream_modeling.py
# Provenance (Hugging Face upload metadata): renamed from dual_stream_modeling.py
# to run2/dual_stream_modeling.py by AbstractPhil, commit 96998c5 (verified).
#!/usr/bin/env python3
"""
GeoLIP Dual-Stream ViT
=======================
Two parallel streams that cross-attend at bottlenecks:
Stream A (geometric): KSimplexChannel β†’ geometric features β†’ self-attn
Stream B (standard): learned projections β†’ self-attn
Architecture:
Shared encoder: patch_embed + pos_embed (no transformer blocks β€” raw patches)
β†’ Split into geo_stream and std_stream
β†’ 2Γ— DualStreamBlock (self-attn + cross-attn per stream)
β†’ Fuse: concat β†’ proj
β†’ 4Γ— FusedBlock (standard transformer)
β†’ Pool + InfoNCE + Constellation + Classifier
The geometric structure survives because it has its own stream for 2 blocks.
Cross-attention lets info flow without mixing representations.
Fused blocks merge the two with the geometric signal already established.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from itertools import combinations
# Default compute device string; NOTE(review): appears unused within this file.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ══════════════════════════════════════════════════════════════════
# CAYLEY-MENGER + KSIMPLEX CHANNEL
# ══════════════════════════════════════════════════════════════════
class CMValidator(nn.Module):
    """Batch-friendly Cayley-Menger determinant.

    For a k-simplex given by its k+1 vertices, returns the squared pairwise
    edge lengths and the squared k-volume computed from the bordered
    Cayley-Menger matrix.
    """

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1
        idx_pairs = list(combinations(range(self._nv), 2))
        self._npairs = len(idx_pairs)
        row_idx = torch.tensor([i for i, _ in idx_pairs], dtype=torch.long)
        col_idx = torch.tensor([j for _, j in idx_pairs], dtype=torch.long)
        self.register_buffer('_pi', row_idx)
        self.register_buffer('_pj', col_idx)
        # vol^2 = (-1)^(k+1) / (2^k * (k!)^2) * det(CM)
        self._prefactor = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))

    def forward(self, verts):
        """verts: (..., k+1, edim) -> (d2_pairs: (..., npairs), vol2: (...))."""
        gram = torch.einsum('...ve,...we->...vw', verts, verts)
        sq_norms = torch.diagonal(gram, dim1=-2, dim2=-1)
        # Squared pairwise distances, clamped at zero against round-off.
        sq_dist = F.relu(sq_norms.unsqueeze(-1) + sq_norms.unsqueeze(-2) - 2 * gram)
        edge_d2 = sq_dist[..., self._pi, self._pj]
        batch_shape, nv = sq_dist.shape[:-2], sq_dist.shape[-1]
        # Bordered Cayley-Menger matrix: first row/column are (0, 1, 1, ..., 1).
        cm = torch.zeros(*batch_shape, nv + 1, nv + 1,
                         device=sq_dist.device, dtype=sq_dist.dtype)
        cm[..., 0, 1:] = 1.0
        cm[..., 1:, 0] = 1.0
        cm[..., 1:, 1:] = sq_dist
        # Determinant in float32 for numerical stability, then cast back.
        vol2 = (self._prefactor * torch.linalg.det(cm.float())).to(edge_d2.dtype)
        return edge_d2, vol2
class KSimplexChannel(nn.Module):
    """Per-position simplex encoder. k=4: 11 geometric features."""

    # Scale applied to the learned vertex offsets before deforming the template.
    BASE_DEFORM = 0.05

    def __init__(self, k, in_dim, edim):
        super().__init__()
        self._k = k
        self._nv = k + 1
        self._edim = edim
        self._cm = CMValidator(k)
        # C(k+1, 2) squared edge lengths plus one squared volume.
        self._out_dim = self._cm._npairs + 1
        self.register_buffer('_template', self._make_regular_simplex(k, edim))
        self._to_deform = nn.Linear(in_dim, self._nv * edim)
        self._norm = nn.LayerNorm(self._out_dim)

    @staticmethod
    def _make_regular_simplex(k, edim):
        """Centered simplex template with unit first edge.

        Uses orthonormal basis vectors when k+1 <= edim (a genuinely regular
        simplex after centering/scaling); overflow vertices fall back to
        random unit directions.
        """
        nv = k + 1
        verts = torch.zeros(nv, edim)
        for i in range(min(nv, edim)):
            verts[i, i] = 1.0
        for i in range(edim, nv):
            extra = torch.randn(edim)
            verts[i] = extra / (extra.norm() + 1e-8)
        verts -= verts.mean(dim=0, keepdim=True)
        return verts / (verts[0] - verts[1]).norm().clamp(min=1e-8)

    @property
    def out_dim(self):
        return self._out_dim

    def forward(self, x):
        """x: (..., in_dim) -> (geo: (..., out_dim), vol2: (...))."""
        offsets = self._to_deform(x).unflatten(-1, (self._nv, self._edim))
        verts = self._template + self.BASE_DEFORM * offsets
        edge_d2, vol2 = self._cm(verts)
        geo = self._norm(torch.cat([edge_d2, vol2.unsqueeze(-1)], dim=-1))
        return geo, vol2
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION + PATCHWORK
# ══════════════════════════════════════════════════════════════════
class Constellation(nn.Module):
    """Learnable anchor set used to triangulate unit-sphere embeddings."""

    def __init__(self, n_anchors, dim, anchor_drop=0.0):
        super().__init__()
        self.anchors = nn.Parameter(torch.randn(n_anchors, dim))
        nn.init.normal_(self.anchors, 0, 1.0 / dim ** 0.5)
        self.anchor_drop = anchor_drop

    def triangulate(self, emb, training=False):
        """Return (1 - cos) distances to anchors and each row's nearest anchor.

        With training=True and anchor_drop > 0, a random subset of anchors is
        dropped (always keeping at least two); `tri` then covers only the
        surviving anchors while `nearest` is mapped back to full indices.
        """
        anchors = F.normalize(self.anchors, dim=-1)
        if not (training and self.anchor_drop > 0):
            cos = emb @ anchors.T
            return 1.0 - cos, cos.argmax(dim=-1)
        keep = torch.rand(anchors.shape[0], device=anchors.device) > self.anchor_drop
        if keep.sum() < 2:
            keep[:2] = True  # never triangulate against fewer than two anchors
        cos = emb @ anchors[keep].T
        local_idx = cos.argmax(dim=-1)
        # Map subset-local indices back to the full anchor numbering.
        nearest = keep.nonzero(as_tuple=True)[0][local_idx]
        return 1.0 - cos, nearest
class Patchwork(nn.Module):
    """Splits the triangulation vector across n_comp small MLP heads.

    Anchors are assigned round-robin to components; each head maps its slice
    of the (1 - cos) triangulation vector to a d_comp embedding, and the
    results are concatenated into a (B, n_comp * d_comp) feature.
    """

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        self.d_comp = d_comp
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer('asgn', asgn)
        # BUGFIX: size each head from its actual anchor count instead of the
        # fixed n_anchors // n_comp. Round-robin assignment gives unequal
        # slice sizes when n_anchors % n_comp != 0, which previously crashed
        # the fixed-width Linear at forward time. Identical when divisible.
        self.comps = nn.ModuleList([nn.Sequential(
            nn.Linear(int((asgn == k).sum()), d_comp * 2), nn.GELU(),
            nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))
            for k in range(n_comp)])

    def forward(self, tri):
        """tri: (B, n_anchors) -> (B, n_comp * d_comp)."""
        return torch.cat([self.comps[k](tri[:, self.asgn == k])
                          for k in range(self.n_comp)], -1)
# ══════════════════════════════════════════════════════════════════
# EMBEDDING AUTOGRAD
# ══════════════════════════════════════════════════════════════════
class EmbeddingAutograd(torch.autograd.Function):
    """Geometric autograd: tangential projection + anchor separation.

    Forward is the identity on `x`. Backward shrinks the radial component of
    the incoming gradient (along the embedding direction) by `tang`, and when
    `sep` > 0 subtracts the part of the gradient pointing toward each row's
    nearest anchor (only when it actually points toward it).
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang
        ctx.sep = sep
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        unit_emb = F.normalize(embedding.detach().float(), dim=-1)
        unit_anc = F.normalize(anchors.detach().float(), dim=-1)
        g = grad_output.float()
        # (g - radial) + (1 - tang) * radial  ==  g - tang * radial
        radial = (g * unit_emb).sum(-1, keepdim=True) * unit_emb
        g = g - ctx.tang * radial
        if ctx.sep > 0:
            # Damp the gradient component aimed at the closest anchor.
            closest = unit_anc[(unit_emb @ unit_anc.T).argmax(dim=-1)]
            toward = (g * closest).sum(-1, keepdim=True)
            g = g - ctx.sep * (toward > 0).float() * toward * closest
        return g.to(grad_output.dtype), None, None, None, None
# ══════════════════════════════════════════════════════════════════
# DUAL-STREAM BLOCKS
# ══════════════════════════════════════════════════════════════════
class DualStreamBlock(nn.Module):
    """
    Two parallel streams with self-attention + cross-attention.
    Geo stream: self_attn -> KSimplex -> cross_attn(q=geo, kv=std) -> FFN
    Std stream: self_attn -> cross_attn(q=std, kv=geo) -> FFN
    Cross-attention is the bottleneck where info flows between streams.
    """
    def __init__(self, stream_dim, geo_dim, n_heads, ksimplex_k=4,
                 ksimplex_edim=8, dropout=0.1):
        super().__init__()
        self.stream_dim = stream_dim
        # NOTE(review): geo_dim is stored but never read inside this class;
        # geo_lift sizes itself from the KSimplexChannel's out_dim instead.
        self.geo_dim = geo_dim
        # ── Geo stream ──
        self.geo_norm1 = nn.LayerNorm(stream_dim)
        self.geo_self_attn = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.geo_ksimplex = KSimplexChannel(
            k=ksimplex_k, in_dim=stream_dim, edim=ksimplex_edim)
        # Project geo features back to stream dim
        self.geo_lift = nn.Sequential(
            nn.Linear(self.geo_ksimplex.out_dim, stream_dim), nn.GELU())
        self.geo_norm2 = nn.LayerNorm(stream_dim)
        self.geo_cross_attn = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.geo_norm3 = nn.LayerNorm(stream_dim)
        self.geo_ffn = nn.Sequential(
            nn.Linear(stream_dim, stream_dim * 4), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(stream_dim * 4, stream_dim), nn.Dropout(dropout))
        # ── Std stream ──
        self.std_norm1 = nn.LayerNorm(stream_dim)
        self.std_self_attn = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.std_norm2 = nn.LayerNorm(stream_dim)
        self.std_cross_attn = nn.MultiheadAttention(
            stream_dim, n_heads, dropout=dropout, batch_first=True)
        self.std_norm3 = nn.LayerNorm(stream_dim)
        self.std_ffn = nn.Sequential(
            nn.Linear(stream_dim, stream_dim * 4), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(stream_dim * 4, stream_dim), nn.Dropout(dropout))
    def forward(self, geo_stream, std_stream):
        """
        geo_stream: (B, P, stream_dim)
        std_stream: (B, P, stream_dim)
        Returns: geo_stream, std_stream,
                 geo_feats (B, P, out_dim; 11 for the default k=4),
                 vol2 (B, P)
        """
        B, P, _ = geo_stream.shape
        # ── Geo: self-attention ──
        h = self.geo_norm1(geo_stream)
        h, _ = self.geo_self_attn(h, h, h, need_weights=False)
        geo_stream = geo_stream + h
        # ── Geo: KSimplex per patch ──
        # Each patch token independently parameterizes a deformed simplex.
        flat = geo_stream.reshape(B * P, -1)
        geo_feats, vol2 = self.geo_ksimplex(flat)
        geo_feats = geo_feats.reshape(B, P, -1)  # (B, P, 11)
        vol2 = vol2.reshape(B, P)  # (B, P)
        # Lift geo features and add as residual
        geo_stream = geo_stream + self.geo_lift(geo_feats)
        # ── Geo: cross-attend FROM std ──
        h = self.geo_norm2(geo_stream)
        std_ctx = self.std_norm2(std_stream)
        h, _ = self.geo_cross_attn(h, std_ctx, std_ctx, need_weights=False)
        geo_stream = geo_stream + h
        # ── Geo: FFN ──
        geo_stream = geo_stream + self.geo_ffn(self.geo_norm3(geo_stream))
        # ── Std: self-attention ──
        h = self.std_norm1(std_stream)
        h, _ = self.std_self_attn(h, h, h, need_weights=False)
        std_stream = std_stream + h
        # ── Std: cross-attend FROM geo ──
        # NOTE(review): std_norm2 and geo_norm2 each serve double duty — as
        # query norm for one stream and key/value norm for the other, and the
        # geo context here is the already-updated geo_stream. Presumably
        # intentional weight sharing, but worth confirming.
        h2 = self.std_norm2(std_stream)
        geo_ctx = self.geo_norm2(geo_stream)
        h2, _ = self.std_cross_attn(h2, geo_ctx, geo_ctx, need_weights=False)
        std_stream = std_stream + h2
        # ── Std: FFN ──
        std_stream = std_stream + self.std_ffn(self.std_norm3(std_stream))
        return geo_stream, std_stream, geo_feats, vol2
class FusedBlock(nn.Module):
    """Standard pre-norm transformer block: self-attention + 4x-expansion MLP."""

    def __init__(self, dim, n_heads, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(dim * 4, dim), nn.Dropout(dropout))

    def forward(self, x):
        """x: (B, P, dim) -> (B, P, dim), residual at both sub-layers."""
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        return x + self.ffn(self.norm2(x))
# ══════════════════════════════════════════════════════════════════
# DUAL-STREAM VIT
# ══════════════════════════════════════════════════════════════════
class DualStreamViT(nn.Module):
    """
    GeoLIP Dual-Stream Vision Transformer.
    Architecture:
    patch_embed + pos -> (B, 64, embed_dim)
    -> geo_proj, std_proj -> two streams at stream_dim
    -> 2x DualStreamBlock (self-attn + cross-attn + KSimplex)
    -> fuse: concat(geo, std) -> proj to fused_dim
    -> 4x FusedBlock
    -> pool + constellation + InfoNCE + classifier
    """
    def __init__(
        self,
        num_classes=10,
        img_size=32,
        patch_size=4,
        embed_dim=384,
        stream_dim=192,
        fused_dim=256,
        n_dual_blocks=2,
        n_fused_blocks=4,
        n_heads=8,
        output_dim=128,
        n_anchors=64,
        n_comp=8,
        d_comp=64,
        anchor_drop=0.10,
        cv_target=0.22,
        ksimplex_k=4,
        ksimplex_edim=8,
        dropout=0.1,
        infonce_temp=0.07,
        infonce_weight=1.0,
        bce_weight=1.0,
        cm_weight=0.1,
        cv_weight=0.01,
        autograd_tang=0.5,
        autograd_sep=0.1,
        enable_autograd=True,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_patches = (img_size // patch_size) ** 2
        self.stream_dim = stream_dim
        self.fused_dim = fused_dim
        self.output_dim = output_dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp
        self.infonce_weight = infonce_weight
        self.bce_weight = bce_weight
        self.cm_weight = cm_weight
        self.cv_weight = cv_weight
        self.autograd_tang = autograd_tang
        self.autograd_sep = autograd_sep
        self.enable_autograd = enable_autograd
        # Save config for checkpoint
        # (locals() at this point holds exactly the constructor arguments —
        # the preceding statements only wrote attributes on self.)
        self.config = {k: v for k, v in locals().items()
                       if k != 'self' and not k.startswith('_')}
        # ── Patch embedding ──
        self.patch_embed = nn.Conv2d(
            3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        # ── Stream projections ──
        self.geo_proj = nn.Sequential(
            nn.Linear(embed_dim, stream_dim), nn.LayerNorm(stream_dim))
        self.std_proj = nn.Sequential(
            nn.Linear(embed_dim, stream_dim), nn.LayerNorm(stream_dim))
        # ── Dual-stream blocks ──
        geo_dim = 11  # KSimplex output (C(k+1, 2) + 1 = 11 for the default k=4)
        self.dual_blocks = nn.ModuleList([
            DualStreamBlock(stream_dim, geo_dim, n_heads,
                            ksimplex_k, ksimplex_edim, dropout)
            for _ in range(n_dual_blocks)])
        # ── Fusion ──
        self.fuse_proj = nn.Sequential(
            nn.Linear(stream_dim * 2, fused_dim),
            nn.LayerNorm(fused_dim), nn.GELU())
        # ── Fused blocks ──
        self.fused_blocks = nn.ModuleList([
            FusedBlock(fused_dim, n_heads, dropout)
            for _ in range(n_fused_blocks)])
        self.fused_norm = nn.LayerNorm(fused_dim)
        # ── Output projection ──
        self.output_proj = nn.Sequential(
            nn.Linear(fused_dim, output_dim),
            nn.LayerNorm(output_dim))
        # ── Constellation + Patchwork ──
        self.constellation = Constellation(n_anchors, output_dim, anchor_drop)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        pw_dim = n_comp * d_comp
        # ── Classifier: patchwork + pooled emb ──
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(dropout),
            nn.Linear(pw_dim, num_classes))
        self._init_weights()
    def _init_weights(self):
        """ViT-style init: truncated-normal Linear weights, unit LayerNorm."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
    def forward(self, x, targets=None, apply_autograd=True):
        """
        Args:
            x: (B, 3, H, W)
            targets: (B,) class indices (optional, for loss)
            apply_autograd: gate for the EmbeddingAutograd gradient hook
                (only active while self.training and enable_autograd).
        Returns:
            dict with logits, embedding, geo_feats, all_geo_feats, vol2,
            geo_pooled, triangulation, nearest, patch_nearest, patch_embs.
        """
        output = {}
        B = x.shape[0]
        # ── Patch embedding ──
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)
        tokens = tokens + self.pos_embed
        P = tokens.shape[1]
        # ── Split into two streams ──
        geo_stream = self.geo_proj(tokens)  # (B, P, stream_dim)
        std_stream = self.std_proj(tokens)  # (B, P, stream_dim)
        # ── Dual-stream blocks ──
        all_geo_feats = []
        all_vol2 = []
        for block in self.dual_blocks:
            geo_stream, std_stream, geo_feats, vol2 = block(
                geo_stream, std_stream)
            all_geo_feats.append(geo_feats)
            all_vol2.append(vol2)
        output['geo_feats'] = all_geo_feats[-1]  # (B, P, 11) from last block
        output['all_geo_feats'] = torch.stack(all_geo_feats)  # (n_dual, B, P, 11)
        output['vol2'] = torch.stack(all_vol2)  # (n_dual, B, P)
        # ── Fuse ──
        fused = self.fuse_proj(
            torch.cat([geo_stream, std_stream], dim=-1))  # (B, P, fused_dim)
        # ── Fused blocks ──
        for block in self.fused_blocks:
            fused = block(fused)
        fused = self.fused_norm(fused)  # (B, P, fused_dim)
        # ── Pool: mean over patches ──
        pooled = fused.mean(dim=1)  # (B, fused_dim)
        # ── Output projection to sphere ──
        emb = F.normalize(self.output_proj(pooled), dim=-1)  # (B, output_dim)
        # ── EmbeddingAutograd on the sphere embedding ──
        # Corrects gradients flowing back through the std/fused path:
        # tangential projection keeps updates on the sphere,
        # separation pushes away from nearest anchor.
        # (emb is passed as both the pass-through and the saved embedding;
        # the gradient flows only through the first argument slot.)
        if (apply_autograd and self.training and self.enable_autograd):
            emb = EmbeddingAutograd.apply(
                emb, emb, self.constellation.anchors,
                self.autograd_tang, self.autograd_sep)
        output['embedding'] = emb
        # ── Geo-only embedding (for CV measurement) ──
        # Pool geo stream separately — this is the geometric contribution
        geo_pooled = geo_stream.mean(dim=1)  # (B, stream_dim)
        output['geo_pooled'] = geo_pooled
        # ── Constellation triangulation ──
        # Full tri for patchwork (no dropout — needs all anchors)
        tri_full, nearest_full = self.constellation.triangulate(
            emb, training=False)
        pw = self.patchwork(tri_full)
        output['triangulation'] = tri_full
        # Dropout version for nearest anchor tracking (training regularization)
        if self.training:
            _, nearest = self.constellation.triangulate(emb, training=True)
        else:
            nearest = nearest_full
        output['nearest'] = nearest
        # ── Classifier ──
        logits = self.classifier(torch.cat([pw, emb], dim=-1))
        output['logits'] = logits
        # ── Patch-level anchor tracking ──
        # Project all patches to output space for tracking
        # (no_grad: diagnostics only, kept out of the training graph)
        with torch.no_grad():
            patch_embs = F.normalize(
                self.output_proj(fused.reshape(B * P, -1)), dim=-1)
            patch_embs = patch_embs.reshape(B, P, -1)
            anchors_n = F.normalize(self.constellation.anchors, dim=-1)
            patch_cos = torch.einsum('bpd,ad->bpa', patch_embs, anchors_n)
            output['patch_nearest'] = patch_cos.argmax(dim=-1)
            output['patch_embs'] = patch_embs
        return output
    def compute_loss(self, output, targets, output_aug=None):
        """
        Compute loss with InfoNCE between two augmented views.
        Args:
            output: dict from forward(view1)
            targets: (B,) class indices
            output_aug: dict from forward(view2) — optional, for InfoNCE
        Returns:
            loss, loss_dict (tensor entries, except 'nce_acc' and 'cm_valid'
            which are Python floats recorded for logging only)
        """
        loss_dict = {}
        emb = output['embedding']
        # ── BCE classification ──
        one_hot = F.one_hot(targets, self.num_classes).float()
        l_bce = F.binary_cross_entropy_with_logits(output['logits'], one_hot)
        loss_dict['bce'] = l_bce
        # ── InfoNCE between augmented views ──
        if output_aug is not None:
            emb_aug = output_aug['embedding']
            # Each image's two views should be closest to each other
            sim = emb @ emb_aug.T / self.infonce_temp
            labels_nce = torch.arange(emb.shape[0], device=emb.device)
            l_nce = F.cross_entropy(sim, labels_nce)
            nce_acc = (sim.argmax(1) == labels_nce).float().mean()
            loss_dict['nce'] = l_nce
            loss_dict['nce_acc'] = nce_acc.item()
        # ── CM validity (penalize negative volumes) ──
        vol2 = output['vol2']  # (n_dual, B, P)
        l_cm = F.relu(-vol2).mean()
        loss_dict['cm'] = l_cm
        loss_dict['cm_valid'] = (vol2 > 0).float().mean().item()
        # ── CV loss (gentle push toward 0.20-0.23 band) ──
        # Measured on the final embedding — gradient flows to
        # constellation anchors and geo stream via backprop
        l_cv = self._cv_loss_fast(emb, target=self.cv_target)
        loss_dict['cv'] = l_cv
        # ── Anchor spread (prevent anchor clustering) ──
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        anchor_sim = anchors_n @ anchors_n.T
        mask_a = ~torch.eye(anchors_n.shape[0], dtype=torch.bool,
                            device=anchors_n.device)
        l_spread = F.relu(anchor_sim[mask_a] - 0.0).mean()
        loss_dict['spread'] = l_spread
        # ── Combine ──
        # 'nce' falls back to 0.0 when no second view was provided; the
        # spread term uses a hard-coded, intentionally tiny weight (0.001).
        loss = (l_bce * self.bce_weight
                + loss_dict.get('nce', 0.0) * self.infonce_weight
                + l_cm * self.cm_weight
                + l_cv * self.cv_weight
                + l_spread * 0.001)
        loss_dict['total'] = loss
        return loss, loss_dict
    @staticmethod
    def _cv_loss_fast(emb, target=0.22, n_samples=64, n_points=5):
        """Fast differentiable CV loss from random pentachora.

        Samples n_samples random n_points-subsets of the batch (only the
        first min(B, 512) rows are eligible), computes each subset's simplex
        volume via the Cayley-Menger determinant, and penalizes the squared
        deviation of the volumes' coefficient of variation from `target`.
        Returns a grad-free 0 when the batch is smaller than n_points or
        fewer than 5 subsets yield a numerically positive volume.
        """
        B = emb.shape[0]
        if B < n_points:
            return torch.tensor(0.0, device=emb.device)
        vols = []
        for _ in range(n_samples):
            idx = torch.randperm(min(B, 512), device=emb.device)[:n_points]
            pts = emb[idx].unsqueeze(0)  # (1, 5, D)
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)
            N = n_points
            cm = torch.zeros(1, N + 1, N + 1,
                             device=emb.device, dtype=emb.dtype)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            # vol^2 = (-1)^(k+1) / (2^k (k!)^2) * det(CM), k = N - 1
            k = N - 1
            sign = (-1.0) ** (k + 1)
            fact = math.factorial(k)
            prefactor = sign / ((2.0 ** k) * (fact ** 2))
            vol2 = prefactor * torch.linalg.det(cm.float())
            # Keep only numerically positive volumes (degenerate subsets drop out)
            if vol2[0].item() > 1e-20:
                vols.append(vol2[0].to(emb.dtype).sqrt())
        if len(vols) < 5:
            return torch.tensor(0.0, device=emb.device)
        vols_t = torch.stack(vols)
        cv = vols_t.std() / (vols_t.mean() + 1e-8)
        return (cv - target).pow(2)
# ══════════════════════════════════════════════════════════════════
# FACTORY
# ══════════════════════════════════════════════════════════════════
def create_dual_stream_vit(**kwargs):
    """Factory: build a DualStreamViT from keyword config overrides."""
    model = DualStreamViT(**kwargs)
    return model