"""
Patch Cross-Attention Shape Classifier — VAE-Matched (8×16×16)
================================================================
Replaces Conv3d backbone with v11-style decomposition + cross-attention.

Input: (B, 8, 16, 16) binary voxel grid
  → Decompose into patches (macro grid)
  → Shared patch encoder (MLP + handcrafted)
  → Positional embedding
  → Cross-attention layers (patches attend to each other)
  → Pool → Classify

Patch scheme: 2×4×4 patches → 4×4×4 macro grid (64 patches, 32 voxels each)
  - Preserves aspect ratio at macro level
  - 32 voxels per patch = tractable for shared MLP
  - 64 patches = reasonable sequence length for attention
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# === Grid Constants ===========================================================

GZ = 8
GY = 16
GX = 16
GRID_SHAPE = (GZ, GY, GX)
GRID_VOLUME = GZ * GY * GX  # 2048

# Patch decomposition
PATCH_Z = 2
PATCH_Y = 4
PATCH_X = 4
PATCH_VOL = PATCH_Z * PATCH_Y * PATCH_X  # 32

MACRO_Z = GZ // PATCH_Z  # 4
MACRO_Y = GY // PATCH_Y  # 4
MACRO_X = GX // PATCH_X  # 4
MACRO_N = MACRO_Z * MACRO_Y * MACRO_X  # 64

# Shape classes
NUM_CLASSES = 38
NUM_CURVATURES = 8

CLASS_NAMES = [
    "point", "line_x", "line_y", "line_z", "line_diag",
    "cross", "l_shape", "collinear",
    "triangle_xy", "triangle_xz", "triangle_3d",
    "square_xy", "square_xz", "rectangle", "coplanar", "plane",
    "tetrahedron", "pyramid", "pentachoron", "cube", "cuboid",
    "triangular_prism", "octahedron",
    "arc", "helix", "circle", "ellipse", "disc",
    "sphere", "hemisphere", "cylinder", "cone", "capsule",
    "torus", "shell", "tube", "bowl", "saddle",
]

CURVATURE_NAMES = [
    "none", "convex", "concave", "cylindrical",
    "conical", "toroidal", "hyperbolic", "helical",
]


# === Shared handcrafted statistics ============================================

def _occupancy_stats(vox: torch.Tensor) -> torch.Tensor:
    """Return five occupancy statistics for a batch of voxel volumes.

    Shared by ``PatchEncoder`` (per-patch volumes) and
    ``PatchCrossAttentionClassifier._global_features`` (full grids) so that
    both compute these features with exactly the same code.

    Args:
        vox: floating-point tensor of shape (N, Z, Y, X).

    Returns:
        (N, 5) tensor with columns, in order:
        mean occupancy, std of per-z-slice occupancy, std of per-y-slice
        occupancy, std of per-x-slice occupancy, surface ratio.
    """
    n = vox.shape[0]
    flat = vox.reshape(n, -1)
    occ = flat.mean(dim=-1, keepdim=True)

    # Spread of slice-wise occupancy along each axis.
    ax_z = vox.mean(dim=(2, 3)).std(dim=1, keepdim=True)
    ax_y = vox.mean(dim=(1, 3)).std(dim=1, keepdim=True)
    ax_x = vox.mean(dim=(1, 2)).std(dim=1, keepdim=True)

    # Surface ratio: a filled voxel is "surface" when its zero-padded 3×3×3
    # neighbourhood is not completely filled (average below 1).
    padded = F.pad(vox.unsqueeze(1), (1, 1, 1, 1, 1, 1),
                   mode='constant', value=0)
    neighbors = F.avg_pool3d(padded, kernel_size=3, stride=1, padding=0)
    neighbors = neighbors.squeeze(1)
    surface = ((neighbors < 1.0) & (vox > 0.5)).float().sum(dim=(1, 2, 3))
    # clamp avoids 0/0 for empty volumes (their surface count is 0 anyway).
    total = flat.sum(dim=-1).clamp(min=1)
    surf_ratio = (surface / total).unsqueeze(-1)

    return torch.cat([occ, ax_z, ax_y, ax_x, surf_ratio], dim=-1)


# === SwiGLU ===================================================================

class SwiGLU(nn.Module):
    """Gated linear unit: ``w1(x) * silu(w2(x))``."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.w1 = nn.Linear(in_dim, out_dim)
        self.w2 = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w1(x) * F.silu(self.w2(x))


# === Patch Encoder ============================================================

class PatchEncoder(nn.Module):
    """
    Shared encoder for each 2×4×4 local patch.

    Input:  (M, 2, 4, 4) binary grids where M = B * 64
    Output: (M, patch_feat_dim) feature vectors
    """

    def __init__(self, patch_feat_dim=96):
        super().__init__()
        # Learned features from raw voxels
        self.mlp = nn.Sequential(
            nn.Linear(PATCH_VOL, 256), nn.GELU(),
            nn.Linear(256, 128), nn.GELU(),
            nn.Linear(128, patch_feat_dim))

        # Handcrafted: occupancy(1) + 3 axis std(3) + surface ratio(1)
        #              + z_spread(1) + yx_spread(1) = 7
        n_hand = 7
        self.combine = nn.Sequential(
            nn.Linear(patch_feat_dim + n_hand, patch_feat_dim), nn.GELU(),
            nn.Linear(patch_feat_dim, patch_feat_dim))

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        """Encode patches of shape (M, 2, 4, 4) into (M, patch_feat_dim)."""
        M = patches.shape[0]
        learned = self.mlp(patches.reshape(M, -1))

        # Handcrafted features: occupancy, axis stds, surface ratio.
        stats = _occupancy_stats(patches)

        # Spread: fraction of z slices / (y, x) columns holding any voxel.
        z_spread = (patches.sum(dim=(2, 3)) > 0).float().mean(
            dim=1, keepdim=True)
        yx_spread = (patches.sum(dim=1) > 0).float().mean(
            dim=(1, 2)).unsqueeze(-1)

        hand = torch.cat([stats, z_spread, yx_spread], dim=-1)
        return self.combine(torch.cat([learned, hand], dim=-1))


# === Cross-Attention Block ====================================================

class CrossAttentionBlock(nn.Module):
    """
    Pre-norm transformer block: LN → MHA → residual → LN → FFN → residual.

    Patches cross-attend to each other (self-attention over patch sequence):
    query, key and value are all the same normalized input.
    """

    def __init__(self, embed_dim, num_heads=8, ff_mult=2, dropout=0.05):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads=num_heads,
            batch_first=True, dropout=dropout)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * ff_mult), nn.GELU(),
            nn.Linear(embed_dim * ff_mult, embed_dim),
            nn.Dropout(dropout))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a (B, N, embed_dim) sequence."""
        # Self-attention (each patch attends to all patches)
        normed = self.ln1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        x = x + self.ff(self.ln2(x))
        return x


# === Main Classifier ==========================================================

class PatchCrossAttentionClassifier(nn.Module):
    """
    8×16×16 → patch decomposition → shared encoder → cross-attention → classify.

    Architecture:
      1. Decompose (B, 8, 16, 16) into (B, 64, 2, 4, 4) patches
      2. Shared PatchEncoder → (B, 64, patch_feat_dim)
      3. Project + add 3D positional embedding → (B, 64, embed_dim)
      4. N cross-attention layers
      5. Global pool → classify

    Parameter count depends on config; running this module as a script
    prints it for the default configuration.
    """

    def __init__(self, n_classes=NUM_CLASSES, embed_dim=128, patch_feat_dim=96,
                 n_layers=3, n_heads=8, dropout=0.05):
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_feat_dim = patch_feat_dim

        # Shared patch encoder
        self.patch_encoder = PatchEncoder(patch_feat_dim)

        # Project patch features + occupancy + position → embed_dim
        patch_in = patch_feat_dim + 1 + 3  # feat + occ + 3D pos
        self.patch_proj = nn.Sequential(
            nn.Linear(patch_in, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, embed_dim))

        # Learnable 3D positional embedding for macro grid
        self.pos_embed = nn.Parameter(
            torch.randn(1, MACRO_N, embed_dim) * 0.02)

        # Cross-attention layers
        self.layers = nn.ModuleList([
            CrossAttentionBlock(embed_dim, n_heads, ff_mult=2, dropout=dropout)
            for _ in range(n_layers)
        ])

        # Final norm before pooling
        self.final_ln = nn.LayerNorm(embed_dim)

        # Global features: occupancy stats from full grid
        n_global = 11  # same as VAEShapeClassifier handcrafted
        self.global_proj = nn.Sequential(
            nn.Linear(n_global, 64), nn.GELU(),
            nn.Linear(64, 64))

        # Classification
        class_in = embed_dim + 64  # pooled attention + global features
        self.class_in = class_in

        self.classifier = nn.Sequential(
            nn.Linear(class_in, 256), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(256, 128), nn.GELU(),
            nn.Linear(128, n_classes))

        # Auxiliary heads
        self.dim_head = nn.Sequential(
            nn.Linear(class_in, 64), nn.GELU(), nn.Linear(64, 4))
        self.curved_head = nn.Sequential(
            nn.Linear(class_in, 64), nn.GELU(), nn.Linear(64, 1))
        self.curv_type_head = nn.Sequential(
            nn.Linear(class_in, 64), nn.GELU(),
            nn.Linear(64, NUM_CURVATURES))

        # Precompute macro grid positions (normalized to [0, 1] per axis)
        coords = torch.stack(torch.meshgrid(
            torch.arange(MACRO_Z, dtype=torch.float32) / max(MACRO_Z - 1, 1),
            torch.arange(MACRO_Y, dtype=torch.float32) / max(MACRO_Y - 1, 1),
            torch.arange(MACRO_X, dtype=torch.float32) / max(MACRO_X - 1, 1),
            indexing="ij"), dim=-1)
        self.register_buffer("macro_pos", coords.reshape(1, MACRO_N, 3))

    def _decompose_patches(self, grid: torch.Tensor) -> torch.Tensor:
        """
        (B, 8, 16, 16) → (B*64, 2, 4, 4)

        Reshape into (B, 4, 2, 4, 4, 4, 4) then permute/flatten.
          Z: 8  = 4 macro × 2 local
          Y: 16 = 4 macro × 4 local
          X: 16 = 4 macro × 4 local

        Patches are ordered with the macro z index slowest and the macro x
        index fastest, matching the order of ``macro_pos``.
        """
        B = grid.shape[0]
        # (B, 8, 16, 16) → (B, MZ, PZ, MY, PY, MX, PX)
        x = grid.reshape(B, MACRO_Z, PATCH_Z, MACRO_Y, PATCH_Y,
                         MACRO_X, PATCH_X)
        # → (B, MZ, MY, MX, PZ, PY, PX)
        x = x.permute(0, 1, 3, 5, 2, 4, 6).contiguous()
        # → (B*64, 2, 4, 4)
        return x.reshape(B * MACRO_N, PATCH_Z, PATCH_Y, PATCH_X)

    def _global_features(self, grid: torch.Tensor) -> torch.Tensor:
        """Extract global geometric statistics from (B, 8, 16, 16) grid.

        Returns a (B, 11) tensor: occupancy(1), axis stds(3),
        surface ratio(1), projection symmetry(3), spatial extent(3).
        """
        stats = _occupancy_stats(grid)

        # Axis projection symmetry: compare each max-projection with itself
        # flipped along both in-plane axes; 1.0 means identical.
        proj_z = grid.max(dim=1).values
        proj_y = grid.max(dim=2).values
        proj_x = grid.max(dim=3).values
        sym = torch.stack([
            1.0 - (proj - torch.flip(proj, [1, 2])).abs().mean(dim=(1, 2))
            for proj in (proj_z, proj_y, proj_x)
        ], dim=-1)

        # Spatial extent: fraction of slices along each axis with any voxel.
        z_extent = (grid.sum(dim=(2, 3)) > 0).float().sum(
            dim=1, keepdim=True) / GZ
        y_extent = (grid.sum(dim=(1, 3)) > 0).float().sum(
            dim=1, keepdim=True) / GY
        x_extent = (grid.sum(dim=(1, 2)) > 0).float().sum(
            dim=1, keepdim=True) / GX
        extent = torch.cat([z_extent, y_extent, x_extent], dim=-1)

        return torch.cat([stats, sym, extent], dim=-1)

    def forward(self, grid: torch.Tensor, labels=None) -> dict:
        """
        Classify a batch of voxel grids.

        Args:
            grid: (B, 8, 16, 16) binary voxel grid. Bool/integer tensors are
                accepted and cast to the model's parameter dtype.
            labels: unused; kept so existing call sites that pass labels
                continue to work.

        Returns:
            dict with keys "class_logits" (B, n_classes), "dim_logits"
            (B, 4), "is_curved_pred" (B, 1), "curv_type_logits"
            (B, NUM_CURVATURES) and "features" (B, class_in).
        """
        # mean()/std() reject non-floating tensors, so a literally binary
        # (bool / uint8 / int) grid must be converted first.
        if not torch.is_floating_point(grid):
            grid = grid.to(self.pos_embed.dtype)

        B = grid.shape[0]

        # === Global features ===
        global_feat = self.global_proj(self._global_features(grid))

        # === Patch decomposition + encoding ===
        patches = self._decompose_patches(grid)       # (B*64, 2, 4, 4)
        patch_feats = self.patch_encoder(patches)     # (B*64, patch_feat_dim)
        patch_feats = patch_feats.reshape(B, MACRO_N, self.patch_feat_dim)

        # Per-patch occupancy
        patch_occ = patches.reshape(B, MACRO_N, PATCH_VOL).mean(
            dim=-1, keepdim=True)

        # Combine: features + occupancy + position
        pos = self.macro_pos.expand(B, -1, -1)
        patch_input = torch.cat([patch_feats, patch_occ, pos], dim=-1)
        x = self.patch_proj(patch_input)

        # Add learnable positional embedding
        x = x + self.pos_embed

        # === Cross-attention layers ===
        for layer in self.layers:
            x = layer(x)
        x = self.final_ln(x)

        # === Pool: mean over patches ===
        pooled = x.mean(dim=1)  # (B, embed_dim)

        # === Combine with global features ===
        feat = torch.cat([pooled, global_feat], dim=-1)  # (B, class_in)

        # === Classification ===
        return {
            "class_logits": self.classifier(feat),
            "dim_logits": self.dim_head(feat),
            "is_curved_pred": self.curved_head(feat),
            "curv_type_logits": self.curv_type_head(feat),
            "features": feat,
        }


# === Confidence ===============================================================

def compute_confidence(logits: torch.Tensor) -> dict:
    """Derive confidence measures from classification logits.

    Args:
        logits: (..., C) tensor; the softmax is taken over the last dim.

    Returns:
        dict of (...)-shaped tensors:
          "max_prob":   highest class probability
          "margin":     top-1 minus top-2 probability (equals max_prob when
                        there is only one class)
          "entropy":    entropy normalized by log(C) into [0, 1]
                        (0 when there is only one class)
          "confidence": same tensor as "margin"
    """
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    max_prob = probs.max(dim=-1).values

    n_classes = logits.shape[-1]
    if n_classes > 1:
        top2 = probs.topk(2, dim=-1).values
        margin = top2[..., 0] - top2[..., 1]
        max_entropy = math.log(n_classes)
    else:
        # topk(2) would raise and log(1) == 0 would divide by zero.
        margin = max_prob
        max_entropy = 1.0

    # Masked (-inf) logits give p == 0 and log p == -inf; their product is
    # NaN, but the entropy contribution of a zero-probability class is 0.
    plogp = torch.where(probs > 0, probs * log_probs,
                        torch.zeros_like(probs))
    entropy = -plogp.sum(dim=-1)

    return {"max_prob": max_prob, "margin": margin,
            "entropy": entropy / max_entropy, "confidence": margin}


# === Sanity check =============================================================

def _sanity_check():
    """Build the default model, print its size, run one dummy forward pass."""
    model = PatchCrossAttentionClassifier()
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f'PatchCrossAttentionClassifier: {n_params:,} params')
    print(f' Patches: {MACRO_Z}×{MACRO_Y}×{MACRO_X} = {MACRO_N} patches of {PATCH_Z}×{PATCH_Y}×{PATCH_X}')
    dummy = torch.zeros(2, GZ, GY, GX)
    with torch.no_grad():
        out = model(dummy)
    print(f' class_logits: {out["class_logits"].shape}')
    print(f' features: {out["features"].shape}')
    print(f' class_in: {model.class_in}')


if __name__ == "__main__":
    _sanity_check()