# grid-geometric-multishape / cell2_model_v10.py
# Author: AbstractPhil
# Commit 7e37f31 (verified): "Create cell2_model_v10.py"
"""
Superposition Patch Classifier - Two-Tier Gated Transformer
=============================================================
Colab Cell 2 of 3 - depends on Cell 1 (generator.py) namespace.
Architecture:
voxels β†’ patch_embed β†’ eβ‚€
Stage 0 (local gates): From raw embeddings, no attention
eβ‚€ β†’ local_dim_head β†’ dim_soft ─┐
eβ‚€ β†’ local_curv_head β†’ curv_soft ── LOCAL_GATE_DIM = 11
eβ‚€ β†’ local_bound_head β†’ bound_soft ──
eβ‚€ β†’ local_axis_head β†’ axis_soft β”€β”˜β†’ local_gates (detached)
Stage 1 (bootstrap): Attention sees local gates
proj([eβ‚€, local_gates]) β†’ bootstrap_block Γ— N β†’ h
Stage 1.5 (structural gates): From h, after cross-patch context
h β†’ struct_topo_head β†’ topo_soft ─┐
h β†’ struct_neighbor_head β†’ neighbor_soft ── STRUCTURAL_GATE_DIM = 6
h β†’ struct_role_head β†’ role_soft β”€β”˜β†’ structural_gates (detached)
Stage 2 (geometric routing): Both gate tiers
(h, local_gates, structural_gates) β†’ geometric_block Γ— N β†’ h'
Stage 3 (classification): Gated shape heads
[h', local_gates, structural_gates] β†’ shape_heads
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Cell 1 provides: all constants including LOCAL_GATE_DIM, STRUCTURAL_GATE_DIM, TOTAL_GATE_DIM
# === Patch Embedding ==========================================================
class PatchEmbedding3D(nn.Module):
    """Split a voxel grid into non-overlapping 3D patches and embed each one.

    Relies on Cell 1 constants: MACRO_Z/MACRO_Y/MACRO_X (patches per axis),
    PATCH_Z/PATCH_Y/PATCH_X (voxels per patch per axis), MACRO_N (total number
    of patches) and PATCH_VOL (voxels per patch).

    Args:
        patch_dim: Output embedding width per patch.

    Forward input:
        x: Voxel tensor with B * (MACRO_Z*PATCH_Z) * (MACRO_Y*PATCH_Y) *
           (MACRO_X*PATCH_X) elements, laid out as (B, Z, Y, X) with Z outermost.
           Assumes a floating dtype matching ``proj``'s weights; TODO confirm
           against Cell 1.

    Returns:
        (B, MACRO_N, patch_dim) patch embeddings, each plus a learned
        projection of that patch's normalised (z, y, x) grid position.
    """

    def __init__(self, patch_dim=64):
        super().__init__()
        self.proj = nn.Linear(PATCH_VOL, patch_dim)
        # Normalised patch-grid coordinates in [0, 1), one row per patch,
        # flattened z-major ('ij' indexing) so rows line up with forward().
        pz = torch.arange(MACRO_Z).float() / MACRO_Z
        py = torch.arange(MACRO_Y).float() / MACRO_Y
        px = torch.arange(MACRO_X).float() / MACRO_X
        pos = torch.stack(torch.meshgrid(pz, py, px, indexing='ij'), dim=-1).reshape(MACRO_N, 3)
        self.register_buffer('pos_embed', pos)
        self.pos_proj = nn.Linear(3, patch_dim)

    def forward(self, x):
        B = x.shape[0]
        # reshape (not view) so non-contiguous voxel tensors are accepted too.
        patches = x.reshape(B, MACRO_Z, PATCH_Z, MACRO_Y, PATCH_Y, MACRO_X, PATCH_X)
        # Move the three patch-index axes to the front, then flatten each
        # patch's voxels: (B, MACRO_N, PATCH_VOL). reshape copies only if needed.
        patches = patches.permute(0, 1, 3, 5, 2, 4, 6).reshape(B, MACRO_N, PATCH_VOL)
        return self.proj(patches) + self.pos_proj(self.pos_embed)
# === Standard Transformer Block ===============================================
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer encoder block (self-attention + MLP).

    Args:
        dim: Token width; nn.MultiheadAttention requires it to be divisible
            by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout on the attention weights and inside the MLP.

    Forward input ``x`` is (B, N, dim) (``batch_first=True``); the output has
    the same shape.
    """

    def __init__(self, dim, n_heads, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(dim * 4, dim), nn.Dropout(dropout),
        )
        self.ln1, self.ln2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x):
        # Normalise once and pass the same tensor as query/key/value. This
        # avoids two redundant LayerNorm passes and lets MultiheadAttention
        # use its fused self-attention input projection.
        normed = self.ln1(x)
        # need_weights=False skips building the (unused) averaged attention map.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        return x + self.ff(self.ln2(x))
# === Geometric Gated Attention ================================================
class GatedGeometricAttention(nn.Module):
    """
    Multi-head attention with two-tier gate modulation.

    Q, K see both local and structural gates.
    V modulated by combined gate vector.
    Per-head compatibility bias from gate interactions.

    Args:
        embed_dim: Token width; must be a positive multiple of ``n_heads``.
        gate_dim: Width of the per-token gate vector passed to forward().
        n_heads: Number of attention heads (must be > 0).
        dropout: Dropout applied to the attention probabilities.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, gate_dim, n_heads, dropout=0.1):
        super().__init__()
        # Fail fast: otherwise the per-head .view() in forward() fails with an
        # opaque size-mismatch error.
        if n_heads <= 0 or embed_dim % n_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        # Q, K from [h, all_gates]
        self.q_proj = nn.Linear(embed_dim + gate_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim + gate_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # Per-head gate compatibility
        self.gate_q = nn.Linear(gate_dim, n_heads)
        self.gate_k = nn.Linear(gate_dim, n_heads)
        # Value modulation by gates: elementwise factor in (0, 1)
        self.v_gate = nn.Sequential(nn.Linear(gate_dim, embed_dim), nn.Sigmoid())
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, h, gate_features):
        """Attend over tokens ``h`` (B, N, embed_dim) using per-token gates.

        ``gate_features`` must share the leading (B, N) dims with ``h`` and
        have ``gate_dim`` features. Returns (B, N, embed_dim).
        """
        B, N, _ = h.shape
        hg = torch.cat([h, gate_features], dim=-1)
        # (B, N, D) -> (B, heads, N, head_dim)
        Q = self.q_proj(hg).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(hg).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(h)
        V = (V * self.v_gate(gate_features)).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        content_scores = (Q @ K.transpose(-2, -1)) / self.scale
        # Additive per-head bias: compat[b, h, i, j] = gq[b, i, h] * gk[b, j, h],
        # a rank-1 (per head) query-gate x key-gate interaction.
        gq = self.gate_q(gate_features)
        gk = self.gate_k(gate_features)
        compat = torch.einsum('bih,bjh->bhij', gq, gk)
        attn = F.softmax(content_scores + compat, dim=-1)
        attn = self.attn_drop(attn)
        out = (attn @ V).transpose(1, 2).reshape(B, N, self.embed_dim)
        return self.out_proj(out)
class GeometricTransformerBlock(nn.Module):
    """Pre-LayerNorm residual block whose attention sub-layer is gate-aware.

    Computes
        h <- h + GatedGeometricAttention(LN1(h), gate_features)
        h <- h + MLP(LN2(h))
    where the MLP widens to ``embed_dim * ff_mult`` with a GELU in between.

    Args:
        embed_dim: Token width.
        gate_dim: Width of the gate vector handed to the attention layer.
        n_heads: Attention heads.
        dropout: Dropout for attention probabilities and the MLP.
        ff_mult: Hidden-width multiplier of the MLP.
    """

    def __init__(self, embed_dim, gate_dim, n_heads, dropout=0.1, ff_mult=4):
        super().__init__()
        hidden_width = embed_dim * ff_mult
        # Registration order (ln1, attn, ln2, ff) is kept so parameter order
        # and state_dict keys match existing checkpoints.
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = GatedGeometricAttention(embed_dim, gate_dim, n_heads, dropout)
        self.ln2 = nn.LayerNorm(embed_dim)
        mlp_layers = [
            nn.Linear(embed_dim, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, embed_dim),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*mlp_layers)

    def forward(self, h, gate_features):
        attn_update = self.attn(self.ln1(h), gate_features)
        h = h + attn_update
        mlp_update = self.ff(self.ln2(h))
        return h + mlp_update
# === Main Classifier ==========================================================
class SuperpositionPatchClassifier(nn.Module):
    """
    Two-tier gated transformer for multi-shape superposition.

    Tier 1 (local): Gates from raw patch embeddings — what IS in this patch
    Tier 2 (structural): Gates from post-attention h — what ROLE this patch plays
    Both tiers feed into geometric attention and classification.

    Relies on Cell 1 constants (NUM_LOCAL_*, NUM_STRUCT_*, LOCAL_GATE_DIM,
    TOTAL_GATE_DIM, NUM_CLASSES, NUM_GATES, MACRO_N, ...). LOCAL_GATE_DIM must
    equal NUM_LOCAL_DIMS + NUM_LOCAL_CURVS + NUM_LOCAL_BOUNDARY + NUM_LOCAL_AXES,
    and TOTAL_GATE_DIM must equal LOCAL_GATE_DIM + NUM_STRUCT_TOPO +
    NUM_STRUCT_NEIGHBOR + NUM_STRUCT_ROLE. Otherwise the projections below
    receive mis-sized inputs.

    Args:
        embed_dim: Width of the transformer token stream ``h``.
        patch_dim: Width of the raw patch embedding ``e``.
        n_bootstrap: Number of plain TransformerBlocks (Stage 1).
        n_geometric: Number of GeometricTransformerBlocks (Stage 2).
        n_heads: Attention heads in every block.
        dropout: Dropout rate used throughout.
        detach_gates: Keyword-only. If True, the gate probability tensors are
            detached before they are fed to later stages. The gate heads are
            then trained only through their own returned logits. Default False
            keeps the original behaviour, where downstream losses also
            back-propagate into the gate heads.
    """

    def __init__(self, embed_dim=128, patch_dim=64, n_bootstrap=2, n_geometric=2,
                 n_heads=4, dropout=0.1, *, detach_gates=False):
        super().__init__()
        self.embed_dim = embed_dim
        # Plain attribute (not a buffer), so state_dicts are unaffected.
        self.detach_gates = detach_gates
        # Patch embedding
        self.patch_embed = PatchEmbedding3D(patch_dim)

        # === Stage 0: Local encoder + gate heads (pre-attention) ===
        # The shared MLP gives the local heads enough capacity to extract
        # dims/curvature/boundary/axes from a single patch's embedding, with
        # no cross-patch information.
        local_hidden = patch_dim * 2  # 128 with the default patch_dim=64
        self.local_encoder = nn.Sequential(
            nn.Linear(patch_dim, local_hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(local_hidden, local_hidden), nn.GELU(), nn.Dropout(dropout),
        )
        self.local_dim_head = nn.Linear(local_hidden, NUM_LOCAL_DIMS)
        self.local_curv_head = nn.Linear(local_hidden, NUM_LOCAL_CURVS)
        self.local_bound_head = nn.Linear(local_hidden, NUM_LOCAL_BOUNDARY)
        self.local_axis_head = nn.Linear(local_hidden, NUM_LOCAL_AXES)
        # Project [embedding, local_gates] → embed_dim for bootstrap
        self.proj = nn.Linear(patch_dim + LOCAL_GATE_DIM, embed_dim)

        # === Stage 1: Bootstrap blocks (attention with local gate context) ===
        self.bootstrap_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads, dropout)
            for _ in range(n_bootstrap)
        ])

        # === Stage 1.5: Structural gate heads (from h, post-attention) ===
        self.struct_topo_head = nn.Linear(embed_dim, NUM_STRUCT_TOPO)
        self.struct_neighbor_head = nn.Linear(embed_dim, NUM_STRUCT_NEIGHBOR)
        self.struct_role_head = nn.Linear(embed_dim, NUM_STRUCT_ROLE)

        # === Stage 2: Geometric gated blocks (see both gate tiers) ===
        self.geometric_blocks = nn.ModuleList([
            GeometricTransformerBlock(embed_dim, TOTAL_GATE_DIM, n_heads, dropout)
            for _ in range(n_geometric)
        ])

        # === Stage 3: Gated classification ===
        gated_dim = embed_dim + TOTAL_GATE_DIM
        self.patch_shape_head = nn.Sequential(
            nn.Linear(gated_dim, embed_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(embed_dim, NUM_CLASSES)
        )
        self.global_pool = nn.Sequential(
            nn.Linear(gated_dim, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )
        self.global_gate_head = nn.Linear(embed_dim, NUM_GATES)
        self.global_shape_head = nn.Linear(embed_dim, NUM_CLASSES)

    def forward(self, x):
        """Run all stages on a voxel batch ``x`` (as accepted by PatchEmbedding3D).

        Returns a dict with:
            local_*_logits / struct_*_logits: per-patch gate-head logits
                (B, MACRO_N, <head width>), returned undetached for supervision.
            patch_shape_logits: (B, MACRO_N, NUM_CLASSES).
            patch_features: final token stream h, (B, MACRO_N, embed_dim).
            global_features: pooled features g, (B, embed_dim).
            global_gates: (B, NUM_GATES) logits.
            global_shapes: (B, NUM_CLASSES) logits.
        """
        # getattr: modules pickled before detach_gates existed lack the attribute.
        detach = getattr(self, "detach_gates", False)

        # === Raw patch embedding ===
        e = self.patch_embed(x)  # (B, MACRO_N, patch_dim)

        # === Stage 0: Local gates from raw embedding via local encoder ===
        e_local = self.local_encoder(e)  # (B, MACRO_N, 2 * patch_dim)
        local_dim_logits = self.local_dim_head(e_local)
        local_curv_logits = self.local_curv_head(e_local)
        local_bound_logits = self.local_bound_head(e_local)
        local_axis_logits = self.local_axis_head(e_local)
        # Mutually exclusive categories use softmax; independent flags use sigmoid.
        local_gates = torch.cat([
            F.softmax(local_dim_logits, dim=-1),
            F.softmax(local_curv_logits, dim=-1),
            torch.sigmoid(local_bound_logits),
            torch.sigmoid(local_axis_logits),
        ], dim=-1)  # (B, MACRO_N, LOCAL_GATE_DIM)
        if detach:
            local_gates = local_gates.detach()

        # === Stage 1: Bootstrap with local gate context ===
        h = self.proj(torch.cat([e, local_gates], dim=-1))
        for blk in self.bootstrap_blocks:
            h = blk(h)

        # === Stage 1.5: Structural gates from h (after cross-patch context) ===
        struct_topo_logits = self.struct_topo_head(h)
        struct_neighbor_logits = self.struct_neighbor_head(h)
        struct_role_logits = self.struct_role_head(h)
        structural_gates = torch.cat([
            F.softmax(struct_topo_logits, dim=-1),
            torch.sigmoid(struct_neighbor_logits),
            F.softmax(struct_role_logits, dim=-1),
        ], dim=-1)  # (B, MACRO_N, NUM_STRUCT_TOPO + NUM_STRUCT_NEIGHBOR + NUM_STRUCT_ROLE)
        if detach:
            structural_gates = structural_gates.detach()

        # === Combined gate vector ===
        all_gates = torch.cat([local_gates, structural_gates], dim=-1)  # (B, MACRO_N, TOTAL_GATE_DIM)

        # === Stage 2: Geometric gated transformer ===
        for blk in self.geometric_blocks:
            h = blk(h, all_gates)

        # === Stage 3: Classification from gated representations ===
        h_gated = torch.cat([h, all_gates], dim=-1)
        shape_logits = self.patch_shape_head(h_gated)
        # Mean over all MACRO_N patches, occupied or not.
        g = self.global_pool(h_gated.mean(dim=1))
        return {
            # Local gate predictions (Stage 0)
            "local_dim_logits": local_dim_logits,
            "local_curv_logits": local_curv_logits,
            "local_bound_logits": local_bound_logits,
            "local_axis_logits": local_axis_logits,
            # Structural gate predictions (Stage 1.5)
            "struct_topo_logits": struct_topo_logits,
            "struct_neighbor_logits": struct_neighbor_logits,
            "struct_role_logits": struct_role_logits,
            # Shape predictions (Stage 3)
            "patch_shape_logits": shape_logits,
            "patch_features": h,
            "global_features": g,
            "global_gates": self.global_gate_head(g),
            "global_shapes": self.global_shape_head(g),
        }
# === Loss =====================================================================
class SuperpositionLoss(nn.Module):
    """Weighted multi-task loss over SuperpositionPatchClassifier outputs.

    Per-patch terms (local gates, structural gates, patch shapes) are averaged
    only over occupied patches, i.e. those with
    ``targets["patch_occupancy"] > 0.01``; the divisor is clamped to at least 1.
    The two global BCE terms use the default mean reduction over all elements.
    Class labels are clamped into range before cross-entropy. The neighbor
    target is compared against a sigmoid, so it is presumably normalised to
    [0, 1] (TODO confirm against Cell 1).

    Args:
        local_weight: Weight of the Stage-0 local gate loss.
        struct_weight: Weight of the Stage-1.5 structural gate loss.
        shape_weight: Weight of the per-patch shape-membership loss.
        global_weight: Weight of the global gate + global shape loss.

    forward() returns a dict: ``total`` (weighted sum) plus the unweighted
    ``local``, ``struct``, ``shape`` and ``global`` components.
    """

    def __init__(self, local_weight=1.0, struct_weight=1.0, shape_weight=1.0, global_weight=0.5):
        super().__init__()
        self.lw = local_weight
        self.sw = struct_weight
        self.shw = shape_weight
        self.gw = global_weight

    @staticmethod
    def _patch_ce(logits, labels, n_classes, mask):
        """Unreduced cross-entropy per patch, reshaped to ``mask``'s shape."""
        per_patch = F.cross_entropy(
            logits.view(-1, n_classes),
            labels.clamp(0, n_classes - 1).view(-1),
            reduction='none',
        )
        return per_patch.view_as(mask)

    def forward(self, outputs, targets):
        occ_mask = targets["patch_occupancy"] > 0.01
        occ_weight = occ_mask.float()
        n_occ = occ_mask.sum().clamp(min=1)

        def masked_mean(per_patch):
            # Sum over occupied patches only, divided by their (clamped) count.
            return (per_patch * occ_weight).sum() / n_occ

        # --- Local gate losses ---
        local_terms = (
            self._patch_ce(outputs["local_dim_logits"], targets["patch_dims"],
                           NUM_LOCAL_DIMS, occ_mask)
            + self._patch_ce(outputs["local_curv_logits"], targets["patch_curvature"],
                             NUM_LOCAL_CURVS, occ_mask)
            + F.binary_cross_entropy_with_logits(
                outputs["local_bound_logits"].squeeze(-1),
                targets["patch_boundary"],
                reduction='none')
            + F.binary_cross_entropy_with_logits(
                outputs["local_axis_logits"],
                targets["patch_axis_active"],
                reduction='none').mean(dim=-1)
        )
        local_loss = masked_mean(local_terms)

        # --- Structural gate losses ---
        neighbor_pred = torch.sigmoid(outputs["struct_neighbor_logits"].squeeze(-1))
        struct_terms = (
            self._patch_ce(outputs["struct_topo_logits"], targets["patch_topology"],
                           NUM_STRUCT_TOPO, occ_mask)
            + F.mse_loss(neighbor_pred, targets["patch_neighbor_count"], reduction='none')
            + self._patch_ce(outputs["struct_role_logits"], targets["patch_surface_role"],
                             NUM_STRUCT_ROLE, occ_mask)
        )
        struct_loss = masked_mean(struct_terms)

        # --- Shape losses ---
        shape_terms = F.binary_cross_entropy_with_logits(
            outputs["patch_shape_logits"],
            targets["patch_shape_membership"],
            reduction='none').mean(dim=-1)
        shape_loss = masked_mean(shape_terms)

        # --- Global losses ---
        global_loss = (
            F.binary_cross_entropy_with_logits(outputs["global_gates"], targets["global_gates"])
            + F.binary_cross_entropy_with_logits(outputs["global_shapes"], targets["global_shapes"])
        )

        total = (self.lw * local_loss + self.sw * struct_loss
                 + self.shw * shape_loss + self.gw * global_loss)
        return {
            "total": total,
            "local": local_loss,
            "struct": struct_loss,
            "shape": shape_loss,
            "global": global_loss,
        }
print("βœ“ Model ready (Two-Tier Gated Transformer)")