""" Superposition Patch Classifier - Two-Tier Gated Transformer ============================================================= Colab Cell 2 of 3 - depends on Cell 1 (generator.py) namespace. Architecture: voxels → patch_embed → e₀ Stage 0 (local gates): From raw embeddings, no attention e₀ → local_dim_head → dim_soft ─┐ e₀ → local_curv_head → curv_soft ─┤ LOCAL_GATE_DIM = 11 e₀ → local_bound_head → bound_soft ─┤ e₀ → local_axis_head → axis_soft ─┘→ local_gates (detached) Stage 1 (bootstrap): Attention sees local gates proj([e₀, local_gates]) → bootstrap_block × N → h Stage 1.5 (structural gates): From h, after cross-patch context h → struct_topo_head → topo_soft ─┐ h → struct_neighbor_head → neighbor_soft ─┤ STRUCTURAL_GATE_DIM = 6 h → struct_role_head → role_soft ─┘→ structural_gates (detached) Stage 2 (geometric routing): Both gate tiers (h, local_gates, structural_gates) → geometric_block × N → h' Stage 3 (classification): Gated shape heads [h', local_gates, structural_gates] → shape_heads """ import math import torch import torch.nn as nn import torch.nn.functional as F # Cell 1 provides: all constants including LOCAL_GATE_DIM, STRUCTURAL_GATE_DIM, TOTAL_GATE_DIM # === Patch Embedding ========================================================== class PatchEmbedding3D(nn.Module): def __init__(self, patch_dim=64): super().__init__() self.proj = nn.Linear(PATCH_VOL, patch_dim) pz = torch.arange(MACRO_Z).float() / MACRO_Z py = torch.arange(MACRO_Y).float() / MACRO_Y px = torch.arange(MACRO_X).float() / MACRO_X pos = torch.stack(torch.meshgrid(pz, py, px, indexing='ij'), dim=-1).reshape(MACRO_N, 3) self.register_buffer('pos_embed', pos) self.pos_proj = nn.Linear(3, patch_dim) def forward(self, x): B = x.shape[0] patches = x.view(B, MACRO_Z, PATCH_Z, MACRO_Y, PATCH_Y, MACRO_X, PATCH_X) patches = patches.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(B, MACRO_N, PATCH_VOL) return self.proj(patches) + self.pos_proj(self.pos_embed) # === Standard Transformer Block 
# ==============================================================================
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: self-attention then a 4x MLP."""

    def __init__(self, dim, n_heads, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(dim * 4, dim), nn.Dropout(dropout)
        )
        self.ln1, self.ln2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x):
        # Normalise once and reuse for q/k/v (original recomputed ln1(x) thrice).
        z = self.ln1(x)
        x = x + self.attn(z, z, z)[0]
        return x + self.ff(self.ln2(x))


# === Geometric Gated Attention ================================================
class GatedGeometricAttention(nn.Module):
    """
    Multi-head attention with two-tier gate modulation.

    Q, K see both local and structural gates.
    V modulated by combined gate vector.
    Per-head compatibility bias from gate interactions.
    """

    def __init__(self, embed_dim, gate_dim, n_heads, dropout=0.1):
        super().__init__()
        if embed_dim % n_heads != 0:
            # Original silently truncated embed_dim // n_heads, which would
            # only surface later as an opaque view() failure.
            raise ValueError("embed_dim must be divisible by n_heads")
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        # Q, K from [h, all_gates]
        self.q_proj = nn.Linear(embed_dim + gate_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim + gate_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # Per-head gate compatibility: one scalar per head per patch.
        self.gate_q = nn.Linear(gate_dim, n_heads)
        self.gate_k = nn.Linear(gate_dim, n_heads)
        # Value modulation by gates (sigmoid → per-channel scaling in (0, 1)).
        self.v_gate = nn.Sequential(nn.Linear(gate_dim, embed_dim), nn.Sigmoid())
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, h, gate_features):
        """h: (B, N, embed_dim); gate_features: (B, N, gate_dim) → (B, N, embed_dim)."""
        B, N, _ = h.shape
        hg = torch.cat([h, gate_features], dim=-1)
        Q = self.q_proj(hg).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(hg).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(h)
        V = (V * self.v_gate(gate_features)).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product content scores: (B, heads, N, N).
        content_scores = (Q @ K.transpose(-2, -1)) / self.scale
        # Additive per-head bias from gate-gate compatibility:
        # compat[b, h, i, j] = gate_q(i)[h] * gate_k(j)[h].
        gq = self.gate_q(gate_features)
        gk = self.gate_k(gate_features)
        compat = torch.einsum('bih,bjh->bhij', gq, gk)
        attn = F.softmax(content_scores + compat, dim=-1)
        attn = self.attn_drop(attn)
        out = (attn @ V).transpose(1, 2).reshape(B, N, self.embed_dim)
        return self.out_proj(out)


class GeometricTransformerBlock(nn.Module):
    """Pre-norm block pairing GatedGeometricAttention with a feed-forward MLP.

    Gate features are passed through un-normalised; only h is layer-normed.
    """

    def __init__(self, embed_dim, gate_dim, n_heads, dropout=0.1, ff_mult=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = GatedGeometricAttention(embed_dim, gate_dim, n_heads, dropout)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * ff_mult), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(embed_dim * ff_mult, embed_dim), nn.Dropout(dropout)
        )

    def forward(self, h, gate_features):
        h = h + self.attn(self.ln1(h), gate_features)
        h = h + self.ff(self.ln2(h))
        return h


# === Main Classifier ==========================================================
class SuperpositionPatchClassifier(nn.Module):
    """
    Two-tier gated transformer for multi-shape superposition.

    Tier 1 (local): Gates from raw patch embeddings — what IS in this patch
    Tier 2 (structural): Gates from post-attention h — what ROLE this patch plays

    Both tiers feed into geometric attention and classification. As specified
    in the module docstring, the gate vectors are DETACHED before being fed
    downstream: gate heads learn only from their own supervised losses, never
    from the shape objective.
    """

    def __init__(self, embed_dim=128, patch_dim=64, n_bootstrap=2, n_geometric=2,
                 n_heads=4, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim

        # Patch embedding
        self.patch_embed = PatchEmbedding3D(patch_dim)

        # === Stage 0: Local encoder + gate heads (pre-attention) ===
        # Shared MLP gives local heads enough capacity to extract
        # dims/curvature/boundary from 32 voxels without cross-patch info
        local_hidden = patch_dim * 2  # 128
        self.local_encoder = nn.Sequential(
            nn.Linear(patch_dim, local_hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(local_hidden, local_hidden), nn.GELU(), nn.Dropout(dropout),
        )
        self.local_dim_head = nn.Linear(local_hidden, NUM_LOCAL_DIMS)
        self.local_curv_head = nn.Linear(local_hidden, NUM_LOCAL_CURVS)
        self.local_bound_head = nn.Linear(local_hidden, NUM_LOCAL_BOUNDARY)
        self.local_axis_head = nn.Linear(local_hidden, NUM_LOCAL_AXES)

        # Project [embedding, local_gates] → embed_dim for bootstrap
        self.proj = nn.Linear(patch_dim + LOCAL_GATE_DIM, embed_dim)

        # === Stage 1: Bootstrap blocks (attention with local gate context) ===
        self.bootstrap_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads, dropout) for _ in range(n_bootstrap)
        ])

        # === Stage 1.5: Structural gate heads (from h, post-attention) ===
        self.struct_topo_head = nn.Linear(embed_dim, NUM_STRUCT_TOPO)
        self.struct_neighbor_head = nn.Linear(embed_dim, NUM_STRUCT_NEIGHBOR)
        self.struct_role_head = nn.Linear(embed_dim, NUM_STRUCT_ROLE)

        # === Stage 2: Geometric gated blocks (see both gate tiers) ===
        self.geometric_blocks = nn.ModuleList([
            GeometricTransformerBlock(embed_dim, TOTAL_GATE_DIM, n_heads, dropout)
            for _ in range(n_geometric)
        ])

        # === Stage 3: Gated classification ===
        gated_dim = embed_dim + TOTAL_GATE_DIM
        self.patch_shape_head = nn.Sequential(
            nn.Linear(gated_dim, embed_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(embed_dim, NUM_CLASSES)
        )
        self.global_pool = nn.Sequential(
            nn.Linear(gated_dim, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim)
        )
        self.global_gate_head = nn.Linear(embed_dim, NUM_GATES)
        self.global_shape_head = nn.Linear(embed_dim, NUM_CLASSES)

    def forward(self, x):
        """x: voxel volume (B, Z, Y, X). Returns dict of logits/features.

        All *_logits outputs are raw (pre-softmax/sigmoid) and remain attached
        to the graph for their supervised losses; only the normalised gate
        vectors fed back into the network are detached.
        """
        # === Raw patch embedding ===
        e = self.patch_embed(x)  # (B, 64, patch_dim)

        # === Stage 0: Local gates from raw embedding via local encoder ===
        e_local = self.local_encoder(e)  # (B, 64, local_hidden)
        local_dim_logits = self.local_dim_head(e_local)
        local_curv_logits = self.local_curv_head(e_local)
        local_bound_logits = self.local_bound_head(e_local)
        local_axis_logits = self.local_axis_head(e_local)
        local_gates = torch.cat([
            F.softmax(local_dim_logits, dim=-1),
            F.softmax(local_curv_logits, dim=-1),
            torch.sigmoid(local_bound_logits),
            torch.sigmoid(local_axis_logits),
        ], dim=-1)  # (B, 64, 11)
        # FIX: detach per the documented design — the shape objective must not
        # backprop into the local gate heads; they train on their own losses.
        local_gates = local_gates.detach()

        # === Stage 1: Bootstrap with local gate context ===
        h = self.proj(torch.cat([e, local_gates], dim=-1))
        for blk in self.bootstrap_blocks:
            h = blk(h)

        # === Stage 1.5: Structural gates from h (after cross-patch context) ===
        struct_topo_logits = self.struct_topo_head(h)
        struct_neighbor_logits = self.struct_neighbor_head(h)
        struct_role_logits = self.struct_role_head(h)
        structural_gates = torch.cat([
            F.softmax(struct_topo_logits, dim=-1),
            torch.sigmoid(struct_neighbor_logits),
            F.softmax(struct_role_logits, dim=-1),
        ], dim=-1)  # (B, 64, 6)
        # FIX: detached for the same reason as local_gates.
        structural_gates = structural_gates.detach()

        # === Combined gate vector ===
        all_gates = torch.cat([local_gates, structural_gates], dim=-1)  # (B, 64, 17)

        # === Stage 2: Geometric gated transformer ===
        for blk in self.geometric_blocks:
            h = blk(h, all_gates)

        # === Stage 3: Classification from gated representations ===
        h_gated = torch.cat([h, all_gates], dim=-1)
        shape_logits = self.patch_shape_head(h_gated)
        g = self.global_pool(h_gated.mean(dim=1))

        return {
            # Local gate predictions (Stage 0)
            "local_dim_logits": local_dim_logits,
            "local_curv_logits": local_curv_logits,
            "local_bound_logits": local_bound_logits,
            "local_axis_logits": local_axis_logits,
            # Structural gate predictions (Stage 1.5)
            "struct_topo_logits": struct_topo_logits,
            "struct_neighbor_logits": struct_neighbor_logits,
            "struct_role_logits": struct_role_logits,
            # Shape predictions (Stage 3)
            "patch_shape_logits": shape_logits,
            "patch_features": h,
            "global_features": g,
            "global_gates": self.global_gate_head(g),
            "global_shapes": self.global_shape_head(g),
        }


# === Loss =====================================================================
class SuperpositionLoss(nn.Module):
    """Weighted sum of per-patch gate, structural, shape, and global losses.

    Per-patch terms are masked to occupied patches (occupancy > 0.01) and
    averaged over the occupied count, so empty patches contribute nothing.
    """

    def __init__(self, local_weight=1.0, struct_weight=1.0, shape_weight=1.0,
                 global_weight=0.5):
        super().__init__()
        self.lw, self.sw, self.shw, self.gw = local_weight, struct_weight, shape_weight, global_weight

    def forward(self, outputs, targets):
        occ_mask = targets["patch_occupancy"] > 0.01  # (B, 64) bool
        n_occ = occ_mask.sum().clamp(min=1)  # avoid division by zero

        # --- Local gate losses ---
        dim_loss = F.cross_entropy(
            outputs["local_dim_logits"].view(-1, NUM_LOCAL_DIMS),
            targets["patch_dims"].clamp(0, NUM_LOCAL_DIMS - 1).view(-1),
            reduction='none').view_as(occ_mask)
        curv_loss = F.cross_entropy(
            outputs["local_curv_logits"].view(-1, NUM_LOCAL_CURVS),
            targets["patch_curvature"].clamp(0, NUM_LOCAL_CURVS - 1).view(-1),
            reduction='none').view_as(occ_mask)
        bound_loss = F.binary_cross_entropy_with_logits(
            outputs["local_bound_logits"].squeeze(-1), targets["patch_boundary"],
            reduction='none')
        axis_loss = F.binary_cross_entropy_with_logits(
            outputs["local_axis_logits"], targets["patch_axis_active"],
            reduction='none').mean(dim=-1)
        local_loss = ((dim_loss + curv_loss + bound_loss + axis_loss)
                      * occ_mask.float()).sum() / n_occ

        # --- Structural gate losses ---
        topo_loss = F.cross_entropy(
            outputs["struct_topo_logits"].view(-1, NUM_STRUCT_TOPO),
            targets["patch_topology"].clamp(0, NUM_STRUCT_TOPO - 1).view(-1),
            reduction='none').view_as(occ_mask)
        # NOTE(review): regression on a sigmoid output — assumes
        # patch_neighbor_count is pre-normalised to [0, 1]; verify in Cell 1.
        neighbor_loss = F.mse_loss(
            torch.sigmoid(outputs["struct_neighbor_logits"].squeeze(-1)),
            targets["patch_neighbor_count"], reduction='none')
        role_loss = F.cross_entropy(
            outputs["struct_role_logits"].view(-1, NUM_STRUCT_ROLE),
            targets["patch_surface_role"].clamp(0, NUM_STRUCT_ROLE - 1).view(-1),
            reduction='none').view_as(occ_mask)
        struct_loss = ((topo_loss + neighbor_loss + role_loss)
                       * occ_mask.float()).sum() / n_occ

        # --- Shape losses (multi-label BCE over shape membership) ---
        shape_loss = F.binary_cross_entropy_with_logits(
            outputs["patch_shape_logits"], targets["patch_shape_membership"],
            reduction='none').mean(dim=-1)
        shape_loss = (shape_loss * occ_mask.float()).sum() / n_occ

        # --- Global losses (scene-level multi-label heads, unmasked) ---
        global_gate_loss = F.binary_cross_entropy_with_logits(
            outputs["global_gates"], targets["global_gates"])
        global_shape_loss = F.binary_cross_entropy_with_logits(
            outputs["global_shapes"], targets["global_shapes"])
        global_loss = global_gate_loss + global_shape_loss

        total = (self.lw * local_loss + self.sw * struct_loss
                 + self.shw * shape_loss + self.gw * global_loss)
        return {
            "total": total,
            "local": local_loss,
            "struct": struct_loss,
            "shape": shape_loss,
            "global": global_loss,
        }


print("✓ Model ready (Two-Tier Gated Transformer)")