"""
Patch Cross-Attention Shape Classifier — VAE-Matched (8×16×16)
================================================================
Replaces Conv3d backbone with v11-style decomposition + cross-attention.

Input: (B, 8, 16, 16) binary voxel grid
  → Decompose into patches (macro grid)
  → Shared patch encoder (MLP + handcrafted)
  → Positional embedding
  → Cross-attention layers (patches attend to each other)
  → Pool → Classify

Patch scheme: 2×4×4 patches → 4×4×4 macro grid (64 patches, 32 voxels each)
  - Preserves aspect ratio at macro level
  - 32 voxels per patch = tractable for shared MLP
  - 64 patches = reasonable sequence length for attention
"""
| |
|
| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | |
# Voxel grid extent along (Z, Y, X).
GRID_SHAPE = (8, 16, 16)
GZ, GY, GX = GRID_SHAPE
GRID_VOLUME = GZ * GY * GX

# Local patch extent along (Z, Y, X) and the number of voxels per patch.
PATCH_Z, PATCH_Y, PATCH_X = 2, 4, 4
PATCH_VOL = PATCH_Z * PATCH_Y * PATCH_X

# Macro grid: how many patches fit along each axis, and in total.
MACRO_Z, MACRO_Y, MACRO_X = GZ // PATCH_Z, GY // PATCH_Y, GX // PATCH_X
MACRO_N = MACRO_Z * MACRO_Y * MACRO_X

# Output sizes of the shape-class head and the curvature-type head.
NUM_CLASSES = 38
NUM_CURVATURES = 8

# Human-readable label for each of the NUM_CLASSES shape classes (index = class id).
CLASS_NAMES = [
    "point",
    "line_x",
    "line_y",
    "line_z",
    "line_diag",
    "cross",
    "l_shape",
    "collinear",
    "triangle_xy",
    "triangle_xz",
    "triangle_3d",
    "square_xy",
    "square_xz",
    "rectangle",
    "coplanar",
    "plane",
    "tetrahedron",
    "pyramid",
    "pentachoron",
    "cube",
    "cuboid",
    "triangular_prism",
    "octahedron",
    "arc",
    "helix",
    "circle",
    "ellipse",
    "disc",
    "sphere",
    "hemisphere",
    "cylinder",
    "cone",
    "capsule",
    "torus",
    "shell",
    "tube",
    "bowl",
    "saddle",
]

# Human-readable label for each of the NUM_CURVATURES curvature types.
CURVATURE_NAMES = [
    "none",
    "convex",
    "concave",
    "cylindrical",
    "conical",
    "toroidal",
    "hyperbolic",
    "helical",
]
| |
|
| |
|
| | |
| |
|
class SwiGLU(nn.Module):
    """
    Gated linear unit: ``w1(x) * silu(w2(x))``.

    Two independent linear maps ``in_dim -> out_dim`` see the same input; the
    second is passed through SiLU and multiplies the first element-wise.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Value branch first, gate branch second (same creation order as the
        # attribute names, so seeded initialisation stays reproducible).
        self.w1 = nn.Linear(in_dim, out_dim)
        self.w2 = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        value = self.w1(x)
        gate = F.silu(self.w2(x))
        return value * gate
| |
|
| |
|
| | |
| |
|
class PatchEncoder(nn.Module):
    """
    Encode every local patch with one shared network.

    A learned MLP over the flattened patch is concatenated with seven
    hand-computed statistics and fused by a second small MLP.

    Input:  (M, 2, 4, 4) patches, where M = B * 64
    Output: (M, patch_feat_dim) feature vectors
    """

    def __init__(self, patch_feat_dim=96):
        super().__init__()

        # Learned branch: flattened patch voxels -> patch_feat_dim.
        self.mlp = nn.Sequential(
            nn.Linear(PATCH_VOL, 256),
            nn.GELU(),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, patch_feat_dim),
        )

        # Fusion of the learned features with the handcrafted statistics:
        # occupancy, 3 axis-profile stds, surface ratio, z spread, yx spread.
        n_hand = 7
        self.combine = nn.Sequential(
            nn.Linear(patch_feat_dim + n_hand, patch_feat_dim),
            nn.GELU(),
            nn.Linear(patch_feat_dim, patch_feat_dim),
        )

    def forward(self, patches):
        """patches: (M, 2, 4, 4) -> (M, patch_feat_dim)"""
        n_patches = patches.shape[0]
        voxels = patches.reshape(n_patches, -1)

        learned = self.mlp(voxels)

        # Fraction of occupied voxels in the patch.
        occupancy = voxels.mean(dim=-1, keepdim=True)

        # Std of the mean-occupancy profile along z, y and x respectively
        # (each profile is obtained by averaging over the other two axes).
        axis_std = [
            patches.mean(dim=other_axes).std(dim=1, keepdim=True)
            for other_axes in ((2, 3), (1, 3), (1, 2))
        ]

        # Surface ratio: occupied voxels whose zero-padded 3x3x3 neighbourhood
        # is not completely full, divided by the occupied count (at least 1).
        # NOTE(review): for depth-2 patches every 3x3x3 window overlaps the
        # zero padding along z, so for a binary patch this ratio is 1 when the
        # patch is non-empty and 0 when it is empty -- confirm this is intended.
        padded = F.pad(patches.unsqueeze(1), (1, 1, 1, 1, 1, 1), mode="constant", value=0)
        neighborhood = F.avg_pool3d(padded, kernel_size=3, stride=1, padding=0).squeeze(1)
        on_surface = (neighborhood < 1.0) & (patches > 0.5)
        n_surface = on_surface.float().sum(dim=(1, 2, 3))
        n_filled = voxels.sum(dim=-1).clamp(min=1)
        surf_ratio = (n_surface / n_filled).unsqueeze(-1)

        # Fraction of z-slices, and of (y, x) columns, that hold any voxel.
        z_spread = (patches.sum(dim=(2, 3)) > 0).float().mean(dim=1, keepdim=True)
        yx_spread = (patches.sum(dim=1) > 0).float().mean(dim=(1, 2)).unsqueeze(-1)

        hand = torch.cat(
            [occupancy, *axis_std, surf_ratio, z_spread, yx_spread], dim=-1)

        fused_in = torch.cat([learned, hand], dim=-1)
        return self.combine(fused_in)
| |
|
| |
|
| | |
| |
|
class CrossAttentionBlock(nn.Module):
    """
    Pre-norm transformer block: LN → MHA → residual → LN → FFN → residual.

    Query, key and value are all the same layer-normalised patch sequence, so
    despite the name this is self-attention over the patches.

    Args:
        embed_dim: token width; input and output are (B, N, embed_dim).
        num_heads: number of attention heads.
        ff_mult: hidden width of the feed-forward net as a multiple of embed_dim.
        dropout: dropout probability used inside the attention module and
            after the feed-forward net.
    """

    def __init__(self, embed_dim, num_heads=8, ff_mult=2, dropout=0.05):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads=num_heads, batch_first=True, dropout=dropout)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * ff_mult), nn.GELU(),
            nn.Linear(embed_dim * ff_mult, embed_dim),
            nn.Dropout(dropout))

    def forward(self, x):
        """x: (B, N, embed_dim) -> (B, N, embed_dim)"""
        normed = self.ln1(x)
        # The attention weights are never used, so do not ask for them:
        # need_weights=False lets MultiheadAttention skip building the
        # averaged weight matrix and take its optimized attention path.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        return x + self.ff(self.ln2(x))
| |
|
| |
|
| | |
| |
|
class PatchCrossAttentionClassifier(nn.Module):
    """
    8×16×16 → patch decomposition → shared encoder → cross-attention → classify.

    Architecture:
      1. Decompose (B, 8, 16, 16) into (B, 64, 2, 4, 4) patches
      2. Shared PatchEncoder → (B, 64, patch_feat_dim)
      3. Concatenate patch occupancy and normalised macro coordinates,
         project, and add a learned positional embedding → (B, 64, embed_dim)
      4. n_layers CrossAttentionBlock layers, then a final LayerNorm
      5. Mean-pool over patches, concatenate projected global statistics,
         and feed four heads (class, dimension, is-curved, curvature type)

    On the order of 0.6M parameters with the default configuration.

    Args:
        n_classes: width of the shape-class head.
        embed_dim: token width inside the attention stack.
        patch_feat_dim: output width of the shared PatchEncoder.
        n_layers: number of CrossAttentionBlock layers.
        n_heads: attention heads per layer.
        dropout: dropout probability passed to every CrossAttentionBlock
            (the classifier head uses its own fixed 0.1).
    """

    def __init__(self, n_classes=NUM_CLASSES, embed_dim=128, patch_feat_dim=96,
                 n_layers=3, n_heads=8, dropout=0.05):
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_feat_dim = patch_feat_dim

        # Shared encoder applied to every local patch.
        self.patch_encoder = PatchEncoder(patch_feat_dim)

        # Token input = encoder features + patch occupancy (1) + macro coords (3).
        patch_in = patch_feat_dim + 1 + 3
        self.patch_proj = nn.Sequential(
            nn.Linear(patch_in, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, embed_dim))

        # Learned positional embedding, one vector per macro cell.
        self.pos_embed = nn.Parameter(torch.randn(1, MACRO_N, embed_dim) * 0.02)

        # Attention stack.
        self.layers = nn.ModuleList([
            CrossAttentionBlock(embed_dim, n_heads, ff_mult=2, dropout=dropout)
            for _ in range(n_layers)
        ])

        self.final_ln = nn.LayerNorm(embed_dim)

        # Global statistics: occupancy (1) + axis stds (3) + surface ratio (1)
        # + projection symmetries (3) + axis extents (3) = 11.
        n_global = 11
        self.global_proj = nn.Sequential(
            nn.Linear(n_global, 64), nn.GELU(),
            nn.Linear(64, 64))

        # Every head reads the pooled tokens concatenated with the global vector.
        class_in = embed_dim + 64
        self.class_in = class_in
        self.classifier = nn.Sequential(
            nn.Linear(class_in, 256), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(256, 128), nn.GELU(),
            nn.Linear(128, n_classes))

        # Auxiliary heads.
        self.dim_head = nn.Sequential(
            nn.Linear(class_in, 64), nn.GELU(), nn.Linear(64, 4))
        self.curved_head = nn.Sequential(
            nn.Linear(class_in, 64), nn.GELU(), nn.Linear(64, 1))
        self.curv_type_head = nn.Sequential(
            nn.Linear(class_in, 64), nn.GELU(), nn.Linear(64, NUM_CURVATURES))

        # Macro-cell coordinates scaled to [0, 1], flattened in (z, y, x) order
        # so they line up with the patch order of _decompose_patches.
        coords = torch.stack(torch.meshgrid(
            torch.arange(MACRO_Z, dtype=torch.float32) / max(MACRO_Z - 1, 1),
            torch.arange(MACRO_Y, dtype=torch.float32) / max(MACRO_Y - 1, 1),
            torch.arange(MACRO_X, dtype=torch.float32) / max(MACRO_X - 1, 1),
            indexing="ij"), dim=-1)
        self.register_buffer("macro_pos", coords.reshape(1, MACRO_N, 3))

    def _decompose_patches(self, grid):
        """
        (B, 8, 16, 16) → (B*64, 2, 4, 4)

        Each axis is split into (macro, local), the three macro axes are moved
        in front of the three local axes, and batch + macro axes are merged.
          Z: 8  = 4 macro × 2 local
          Y: 16 = 4 macro × 4 local
          X: 16 = 4 macro × 4 local
        """
        B = grid.shape[0]
        x = grid.reshape(B, MACRO_Z, PATCH_Z, MACRO_Y, PATCH_Y, MACRO_X, PATCH_X)
        # (B, mz, pz, my, py, mx, px) -> (B, mz, my, mx, pz, py, px)
        x = x.permute(0, 1, 3, 5, 2, 4, 6).contiguous()
        return x.reshape(B * MACRO_N, PATCH_Z, PATCH_Y, PATCH_X)

    @staticmethod
    def _flip_symmetry(proj):
        """1 - mean |proj - proj flipped along both of its spatial axes|, per sample."""
        return 1.0 - (proj - torch.flip(proj, [1, 2])).abs().mean(dim=(1, 2))

    def _global_features(self, grid):
        """Extract 11 global geometric statistics from a (B, 8, 16, 16) grid."""
        B = grid.shape[0]
        flat = grid.reshape(B, -1)

        # Overall occupancy.
        occ = flat.mean(dim=-1, keepdim=True)

        # Std of the mean-occupancy profile along z, y and x.
        ax_z = grid.mean(dim=(2, 3)).std(dim=1, keepdim=True)
        ax_y = grid.mean(dim=(1, 3)).std(dim=1, keepdim=True)
        ax_x = grid.mean(dim=(1, 2)).std(dim=1, keepdim=True)

        # Surface ratio: occupied voxels whose zero-padded 3x3x3 neighbourhood
        # is not completely full, divided by the occupied count (at least 1).
        padded = F.pad(grid.unsqueeze(1), (1, 1, 1, 1, 1, 1), mode='constant', value=0)
        neighbors = F.avg_pool3d(padded, kernel_size=3, stride=1, padding=0)
        neighbors = neighbors.squeeze(1)
        surface = ((neighbors < 1.0) & (grid > 0.5)).float().sum(dim=(1, 2, 3))
        total = flat.sum(dim=-1).clamp(min=1)
        surf_ratio = (surface / total).unsqueeze(-1)

        # Symmetry of the max-projection along each axis under a flip of both
        # remaining axes.
        sym = torch.stack([
            self._flip_symmetry(grid.max(dim=1).values),
            self._flip_symmetry(grid.max(dim=2).values),
            self._flip_symmetry(grid.max(dim=3).values),
        ], dim=-1)

        # Fraction of slices along each axis that contain any voxel.
        z_extent = (grid.sum(dim=(2, 3)) > 0).float().sum(dim=1, keepdim=True) / GZ
        y_extent = (grid.sum(dim=(1, 3)) > 0).float().sum(dim=1, keepdim=True) / GY
        x_extent = (grid.sum(dim=(1, 2)) > 0).float().sum(dim=1, keepdim=True) / GX
        extent = torch.cat([z_extent, y_extent, x_extent], dim=-1)

        return torch.cat([occ, ax_z, ax_y, ax_x, surf_ratio, sym, extent], dim=-1)

    def forward(self, grid, labels=None):
        """
        grid: (B, 8, 16, 16) binary voxel grid (float, bool or integer dtype)
        labels: unused; accepted so callers may pass it.

        Returns a dict with
          "class_logits":     (B, n_classes)
          "dim_logits":       (B, 4)
          "is_curved_pred":   (B, 1)
          "curv_type_logits": (B, NUM_CURVATURES)
          "features":         (B, embed_dim + 64), the input shared by all heads

        Raises:
            ValueError: if grid is not 4-D with trailing shape (8, 16, 16).
        """
        # reshape() in _decompose_patches accepts any tensor with the right
        # element count, so a wrong axis order would be patched silently.
        if grid.dim() != 4 or tuple(grid.shape[1:]) != GRID_SHAPE:
            raise ValueError(
                f"expected grid of shape (B, {GZ}, {GY}, {GX}), "
                f"got {tuple(grid.shape)}")
        # mean()/std() reject bool and integer tensors; binary grids are often
        # stored that way, so promote them to the parameter dtype.
        if not grid.is_floating_point():
            grid = grid.to(self.pos_embed.dtype)

        B = grid.shape[0]

        # Global statistics branch.
        global_feat = self.global_proj(self._global_features(grid))

        # Patch branch: decompose, then encode all patches with shared weights.
        patches = self._decompose_patches(grid)
        patch_feats = self.patch_encoder(patches)
        patch_feats = patch_feats.reshape(B, MACRO_N, self.patch_feat_dim)

        # Per-patch occupancy.
        patch_occ = patches.reshape(B, MACRO_N, PATCH_VOL).mean(dim=-1, keepdim=True)

        # Build tokens from features, occupancy and normalised coordinates.
        pos = self.macro_pos.expand(B, -1, -1)
        patch_input = torch.cat([patch_feats, patch_occ, pos], dim=-1)
        x = self.patch_proj(patch_input)

        # Learned positional embedding.
        x = x + self.pos_embed

        for layer in self.layers:
            x = layer(x)

        x = self.final_ln(x)

        # Mean-pool over the patch sequence.
        pooled = x.mean(dim=1)

        feat = torch.cat([pooled, global_feat], dim=-1)

        return {
            "class_logits": self.classifier(feat),
            "dim_logits": self.dim_head(feat),
            "is_curved_pred": self.curved_head(feat),
            "curv_type_logits": self.curv_type_head(feat),
            "features": feat,
        }
| |
|
| |
|
| | |
| |
|
def compute_confidence(logits):
    """
    Summarise how confident a softmax distribution is.

    Args:
        logits: tensor of shape (..., C); the class axis is the last one and
            must hold at least two classes (the top-2 margin and the log(C)
            entropy normaliser both need it).

    Returns:
        dict of tensors with shape ``logits.shape[:-1]``:
          "max_prob":   largest class probability
          "margin":     top-1 probability minus top-2 probability
          "entropy":    entropy divided by log(C), so 1.0 means uniform
          "confidence": same tensor as "margin"
    """
    probs = F.softmax(logits, dim=-1)
    max_prob, _ = probs.max(dim=-1)
    top2 = probs.topk(2, dim=-1).values
    # Index the class axis with "..." so any number of leading dimensions
    # (including none) works, matching the dim=-1 reductions around it.
    margin = top2[..., 0] - top2[..., 1]
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(probs * log_probs).sum(dim=-1)
    max_entropy = math.log(logits.shape[-1])
    return {"max_prob": max_prob, "margin": margin,
            "entropy": entropy / max_entropy, "confidence": margin}
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Smoke test: build the default model, report its size, and push one
    # all-zero batch through it to print the output shapes.
    _m = PatchCrossAttentionClassifier()
    _n = sum(p.numel() for p in _m.parameters())
    print(f'PatchCrossAttentionClassifier: {_n:,} params')
    # Plain ASCII "x" as the dimension separator so the line prints on any
    # console encoding.
    print(f' Patches: {MACRO_Z}x{MACRO_Y}x{MACRO_X} = {MACRO_N} patches of {PATCH_Z}x{PATCH_Y}x{PATCH_X}')
    _dummy = torch.zeros(2, GZ, GY, GX)
    with torch.no_grad():
        _out = _m(_dummy)
    print(f' class_logits: {_out["class_logits"].shape}')
    print(f' features: {_out["features"].shape}')
    print(f' class_in: {_m.class_in}')
    # Drop every temporary so nothing leaks into the caller's namespace.
    del _m, _n, _dummy, _out