Create model.py
Browse files
model.py
ADDED
|
@@ -0,0 +1,793 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Cell 2
|
| 2 |
+
# === Capacity Head ============================================================
|
| 3 |
+
|
| 4 |
+
class CapacityHead(nn.Module):
    """Evidence-capacity head that splits features into retained and overflow streams.

    A learnable scalar capacity (kept strictly positive via a softplus
    reparameterization) bounds the non-negative evidence produced per input.
    Evidence under the capacity drives the *retained* stream; evidence beyond
    it drives the *overflow* stream, each through its own sigmoid gate.
    """

    def __init__(self, in_dim, feat_dim, init_capacity=1.0):
        super().__init__()
        # Inverse softplus: chosen so softplus(_raw_capacity) == init_capacity.
        self._raw_capacity = nn.Parameter(
            torch.tensor(math.log(math.exp(init_capacity) - 1)))
        # GELU for cascade: smooth gradients needed for overflow propagation
        self.evidence_net = nn.Sequential(
            nn.Linear(in_dim, feat_dim), nn.GELU(), nn.Linear(feat_dim, 1))
        self.feature_net = nn.Sequential(
            nn.Linear(in_dim, feat_dim), nn.GELU(), nn.Linear(feat_dim, feat_dim))
        self.retain_gate = nn.Sequential(
            nn.Linear(feat_dim + 1, feat_dim), nn.Sigmoid())
        self.overflow_gate = nn.Sequential(
            nn.Linear(feat_dim + 1, feat_dim), nn.Sigmoid())

    @property
    def capacity(self):
        """Effective (positive) capacity derived from the raw parameter."""
        return F.softplus(self._raw_capacity)

    def forward(self, x):
        """Return (fill, overflow, retained, capacity, raw_evidence) for x."""
        cap = self.capacity
        evidence = F.relu(self.evidence_net(x))  # non-negative scalar evidence
        eps = 1e-8
        # Fraction of capacity consumed, saturating at 1.
        fill_ratio = torch.clamp(evidence / (cap + eps), max=1.0)
        # Relative excess above capacity; zero while under capacity.
        excess = torch.clamp((evidence - cap) / (cap + eps), min=0.0)
        features = self.feature_net(x)
        retain_w = self.retain_gate(torch.cat([features, fill_ratio], -1))
        retained = retain_w * features * fill_ratio
        over_w = self.overflow_gate(torch.cat([features, excess], -1))
        overflow = over_w * features * torch.clamp(excess, max=1.0)
        return fill_ratio, overflow, retained, cap, evidence
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# === Differentiation Gate =====================================================
|
| 34 |
+
|
| 35 |
+
class DifferentiationGate(nn.Module):
    """
    Curvature direction analysis via occupancy field differentiation.

    Computes gradient and Laplacian of the 3D occupancy field to determine:
    - Curvature direction: convex (normals point outward) vs concave (inward)
    - Curvature alternation: where sign flips (saddle points, torus inner/outer)
    - Perturbation robustness: smoothed gradient features survive noise

    The key insight: a hemisphere and bowl occupy nearly identical voxels,
    but their occupancy gradients point in opposite directions relative
    to the center of mass. The Laplacian's sign distinguishes them.

    Outputs gate signals that modulate curvature features:
    - direction_gate: learned weighting based on gradient analysis
    - alternation_score: how much curvature sign varies spatially
    - directional_features: rich features encoding curvature orientation

    NOTE(review): the shape comments throughout assume the module-level
    constant GS == 5 (a 5x5x5 grid) — confirm against the definition of GS.
    """

    def __init__(self, embed_dim=64):
        super().__init__()

        # Fixed 3D differentiation kernels — fused into single conv.
        # 4 output channels: [grad_x, grad_y, grad_z, laplacian]
        diff_kernels = torch.zeros(4, 1, 3, 3, 3)
        # X gradient (two-point central difference along dim 0; the original
        # "Sobel" label is a misnomer — there is no cross-axis smoothing)
        diff_kernels[0, 0, 0, 1, 1] = -1; diff_kernels[0, 0, 2, 1, 1] = 1
        # Y gradient (central difference along dim 1)
        diff_kernels[1, 0, 1, 0, 1] = -1; diff_kernels[1, 0, 1, 2, 1] = 1
        # Z gradient (central difference along dim 2)
        diff_kernels[2, 0, 1, 1, 0] = -1; diff_kernels[2, 0, 1, 1, 2] = 1
        # Discrete 6-neighbor Laplacian: -6 at center, +1 at each face neighbor
        diff_kernels[3, 0, 1, 1, 1] = -6
        diff_kernels[3, 0, 0, 1, 1] = 1; diff_kernels[3, 0, 2, 1, 1] = 1
        diff_kernels[3, 0, 1, 0, 1] = 1; diff_kernels[3, 0, 1, 2, 1] = 1
        diff_kernels[3, 0, 1, 1, 0] = 1; diff_kernels[3, 0, 1, 1, 2] = 1
        self.register_buffer("diff_kernels", diff_kernels)

        # Precompute coordinate grid once; registered as a buffer so it
        # follows the module across devices.
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            indexing="ij"), dim=-1)  # (5,5,5,3)
        self.register_buffer("coords", coords)

        # Process gradient-derived features.
        # Per-voxel: gradient direction, Laplacian sign, centroid-relative direction,
        # summarized as histograms and statistics:
        # Gradient direction relative to centroid: 3 histogram bins per axis
        # + Laplacian sign distribution: 3 values (frac_pos, frac_neg, frac_zero)
        # + Alternation score: 1 value
        # + Per-axis gradient asymmetry: 3 values
        # + Radial gradient profile: 5 bins
        raw_feat_dim = 3 + 3 + 1 + 3 + 5  # = 15
        # Plus the 3D conv on the Laplacian field preserving spatial structure
        self.lap_conv = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv3d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))  # -> (B, 16, 2, 2, 2) = 128
        lap_conv_dim = 16 * 8  # 128

        # Gradient magnitude 3D conv (encodes where boundaries are + direction)
        self.grad_conv = nn.Sequential(
            nn.Conv3d(3, 16, 3, padding=1), nn.GELU(),  # 3-channel: dx, dy, dz
            nn.Conv3d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))  # -> (B, 16, 2, 2, 2) = 128
        grad_conv_dim = 16 * 8  # 128

        total_feat_dim = raw_feat_dim + lap_conv_dim + grad_conv_dim  # 15 + 128 + 128 = 271

        # Direction gate: SwiGLU for sharp convex/concave gating
        self.direction_net = nn.Sequential(
            SwiGLU(total_feat_dim, embed_dim),
            nn.Linear(embed_dim, embed_dim), nn.Sigmoid())

        # Directional features: SwiGLU for crisp direction encoding
        self.direction_feat_net = nn.Sequential(
            SwiGLU(total_feat_dim, embed_dim),
            nn.Linear(embed_dim, embed_dim))

    def forward(self, grid):
        """
        grid: (B, 5, 5, 5) binary occupancy

        Returns:
            direction_gate: (B, embed_dim) sigmoid gate for curvature features
            direction_feat: (B, embed_dim) additive directional features
            alternation_score: (B, 1) how much curvature alternates
        """
        B = grid.shape[0]
        device = grid.device
        vox = grid.unsqueeze(1)  # (B, 1, 5, 5, 5)

        # === Smooth occupancy before differentiation ===
        # Binary voxels produce spike gradients. Light blur creates
        # a continuous field whose derivatives are geometrically meaningful.
        # Replicate padding keeps the output at the input resolution.
        vox_smooth = F.avg_pool3d(
            F.pad(vox, (1,1,1,1,1,1), mode='replicate'),
            kernel_size=3, stride=1, padding=0)  # (B, 1, 5, 5, 5)

        # === Compute gradients + Laplacian in single fused conv ===
        diff = F.conv3d(vox_smooth, self.diff_kernels, padding=1)  # (B, 4, 5, 5, 5)
        grad_field = diff[:, :3]  # (B, 3, 5, 5, 5) — gx, gy, gz
        gx, gy, gz = diff[:, 0:1], diff[:, 1:2], diff[:, 2:3]
        lap = diff[:, 3:4]  # (B, 1, 5, 5, 5)

        # === Centroid ===
        flat_grid = grid.reshape(B, -1)  # (B, 125)
        flat_coords = self.coords.reshape(-1, 3)  # (125, 3)
        # clamp(min=1) guards the divisions below against an empty grid
        total_occ = flat_grid.sum(dim=-1, keepdim=True).clamp(min=1)  # (B, 1)
        centroids = (flat_grid.unsqueeze(-1) * flat_coords.unsqueeze(0)).sum(dim=1) / total_occ  # (B, 3)

        # === Gradient direction relative to centroid ===
        grad_flat = grad_field.reshape(B, 3, -1).permute(0, 2, 1)  # (B, 125, 3)
        diff_from_center = flat_coords.unsqueeze(0) - centroids.unsqueeze(1)  # (B, 125, 3)
        diff_norm = diff_from_center / (diff_from_center.norm(dim=-1, keepdim=True) + 1e-8)
        dot_products = (grad_flat * diff_norm).sum(dim=-1)  # (B, 125)
        grad_mag = grad_flat.norm(dim=-1)  # (B, 125)
        # "Active" = occupied voxels carrying a non-negligible gradient
        active = (flat_grid > 0.5) & (grad_mag > 0.01)

        # Histogram of dot product signs (convex/concave/neutral fractions)
        n_active = active.float().sum(-1).clamp(min=1)
        frac_outward = ((dot_products > 0.1) & active).float().sum(-1) / n_active
        frac_inward = ((dot_products < -0.1) & active).float().sum(-1) / n_active
        frac_neutral = 1.0 - frac_outward - frac_inward
        direction_hist = torch.stack([frac_outward, frac_inward, frac_neutral], dim=-1)  # (B, 3)

        # === Laplacian sign distribution (active voxels only) ===
        lap_flat = lap.reshape(B, -1)  # (B, 125)
        lap_active = flat_grid > 0.5
        n_lap_active = lap_active.float().sum(-1).clamp(min=1)
        frac_pos_lap = ((lap_flat > 0.1) & lap_active).float().sum(-1) / n_lap_active
        frac_neg_lap = ((lap_flat < -0.1) & lap_active).float().sum(-1) / n_lap_active
        frac_zero_lap = 1.0 - frac_pos_lap - frac_neg_lap
        lap_hist = torch.stack([frac_pos_lap, frac_neg_lap, frac_zero_lap], dim=-1)  # (B, 3)

        # === Alternation score (ACTIVE VOXELS ONLY) ===
        # Only count sign flips between neighbor pairs where BOTH voxels are
        # near occupied regions. Otherwise empty space dilutes the signal.
        lap_3d = lap.squeeze(1)  # (B, 5, 5, 5)
        # Boundary mask: dilate occupancy by 1 to include immediate neighbors
        boundary_mask = F.max_pool3d(vox, kernel_size=3, stride=1, padding=1).squeeze(1)  # (B,5,5,5)

        # X-axis: both neighbors must be in boundary region
        bm_x = boundary_mask[:, 1:, :, :] * boundary_mask[:, :-1, :, :]  # (B,4,5,5)
        flip_x = (torch.sign(lap_3d[:, 1:, :, :]) * torch.sign(lap_3d[:, :-1, :, :]) < 0).float()
        active_flips_x = (flip_x * bm_x).sum(dim=(1, 2, 3))
        active_pairs_x = bm_x.sum(dim=(1, 2, 3)).clamp(min=1)

        # Y-axis, same construction
        bm_y = boundary_mask[:, :, 1:, :] * boundary_mask[:, :, :-1, :]
        flip_y = (torch.sign(lap_3d[:, :, 1:, :]) * torch.sign(lap_3d[:, :, :-1, :]) < 0).float()
        active_flips_y = (flip_y * bm_y).sum(dim=(1, 2, 3))
        active_pairs_y = bm_y.sum(dim=(1, 2, 3)).clamp(min=1)

        # Z-axis, same construction
        bm_z = boundary_mask[:, :, :, 1:] * boundary_mask[:, :, :, :-1]
        flip_z = (torch.sign(lap_3d[:, :, :, 1:]) * torch.sign(lap_3d[:, :, :, :-1]) < 0).float()
        active_flips_z = (flip_z * bm_z).sum(dim=(1, 2, 3))
        active_pairs_z = bm_z.sum(dim=(1, 2, 3)).clamp(min=1)

        # Mean per-axis flip rate
        alternation = ((active_flips_x / active_pairs_x +
                        active_flips_y / active_pairs_y +
                        active_flips_z / active_pairs_z) / 3.0).unsqueeze(-1)  # (B, 1)

        # === Per-axis gradient asymmetry ===
        # Asymmetry: mean gradient along each axis (nonzero = asymmetric curvature)
        gx_mean = (gx.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        gy_mean = (gy.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        gz_mean = (gz.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        grad_asym = torch.stack([gx_mean, gy_mean, gz_mean], dim=-1)  # (B, 3)

        # === Radial gradient profile ===
        # How does gradient magnitude vary with distance from centroid?
        dists = diff_from_center.norm(dim=-1)  # (B, 125)
        # Arithmetic binning (Inductor-safe, no bucketize)
        # nan_to_num prevents NaN→long producing garbage indices under BF16
        bin_idx = torch.nan_to_num(dists * (5.0 / 3.5), nan=0.0).long().clamp(0, 4)
        active_mask = (flat_grid > 0.5)  # (B, 125)
        # NOTE(review): this zeros tensor is dead — it is unconditionally
        # overwritten by the masked mean a few lines below.
        radial_grad = torch.zeros(B, 5, device=device)
        # Scatter-add: accumulate grad_mag and counts per bin
        weighted_mag = grad_mag * active_mask.float()  # zero out inactive
        one_hot = F.one_hot(bin_idx, 5).float()  # (B, 125, 5)
        active_oh = one_hot * active_mask.float().unsqueeze(-1)  # mask inactive
        counts = active_oh.sum(dim=1).clamp(min=1)  # (B, 5)
        radial_grad = (weighted_mag.unsqueeze(-1) * active_oh).sum(dim=1) / counts
        # (B, 5)

        # === Conv on Laplacian field (spatial curvature map) ===
        lap_feat = self.lap_conv(lap).reshape(B, -1)  # (B, 128)

        # === Conv on gradient field (directional boundaries) ===
        grad_feat = self.grad_conv(grad_field).reshape(B, -1)  # (B, 128)

        # === Combine all ===
        raw_feat = torch.cat([
            direction_hist,  # 3
            lap_hist,        # 3
            alternation,     # 1
            grad_asym,       # 3
            radial_grad,     # 5
        ], dim=-1)  # (B, 15)

        all_feat = torch.cat([raw_feat, lap_feat, grad_feat], dim=-1)  # (B, 271)

        direction_gate = self.direction_net(all_feat)  # (B, embed_dim) sigmoid
        direction_feat = self.direction_feat_net(all_feat)  # (B, embed_dim)

        return direction_gate, direction_feat, alternation
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# === Deformation Augmentation =================================================
|
| 247 |
+
|
| 248 |
+
def deform_grid(grid, p_dropout=0.1, p_add=0.1, p_shift=0.15):
    """Fully vectorized voxel augmentation — zero CPU-GPU sync points.

    Applies, per sample: random voxel dropout (prob p_dropout), random
    boundary-voxel addition (prob p_add), and a one-voxel axis translation
    (prob p_shift). All branches are computed for the whole batch and
    selected with masks so no host-device synchronization occurs.
    """
    B = grid.shape[0]
    device = grid.device
    # One uniform draw per sample per augmentation kind.
    r = torch.rand(B, 3, device=device)
    out = grid.clone()

    # --- Voxel dropout (batched, no .any() sync) ---
    apply_drop = (r[:, 0] < p_dropout).view(B, 1, 1, 1)
    survive = torch.rand_like(out) > 0.15  # per-voxel keep mask
    out = torch.where(apply_drop, out * survive.float(), out)

    # --- Boundary addition (batched, no .any() sync) ---
    apply_add = (r[:, 1] < p_add).view(B, 1, 1, 1).float()
    dilated = F.max_pool3d(out.unsqueeze(1), kernel_size=3, stride=1, padding=1).squeeze(1)
    # Shell of empty voxels adjacent to occupied ones.
    shell = ((dilated > 0.5) & (out < 0.5)).float()
    sprinkle = (torch.rand_like(out) < 0.3).float()
    out = (out + shell * sprinkle * apply_add).clamp(max=1.0)

    # --- Small translation (fully vectorized, no loops over the batch) ---
    apply_shift = r[:, 2] < p_shift  # (B,)
    axes = torch.randint(3, (B,), device=device)
    dirs = torch.randint(0, 2, (B,), device=device) * 2 - 1

    # Build all 6 one-step shifts of the full batch plus the identity
    # (cheap for a 5x5x5 grid). Slot encoding: axis * 2 + (dir == +1);
    # slot 6 = no shift.
    versions = []
    for axis in range(3):
        for step in (-1, 1):
            rolled = torch.roll(out, shifts=step, dims=axis + 1)  # +1 skips batch dim
            # Zero the plane that wrapped around the edge.
            sel = [slice(None)] * 4
            sel[axis + 1] = 0 if step == 1 else -1
            rolled[tuple(sel)] = 0
            versions.append(rolled)
    versions.append(out)  # identity at slot 6
    stacked = torch.stack(versions, dim=0)  # (7, B, 5, 5, 5)

    # Per-sample slot choice; non-shifted samples select the identity slot.
    assign = torch.where(apply_shift,
                         axes * 2 + (dirs == 1).long(),
                         torch.full_like(axes, 6))
    # Gather stacked[assign[b], b] for every b.
    out = stacked[assign, torch.arange(B, device=device)]

    return out
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# === Curvature Head (axis-aware) ==============================================
|
| 300 |
+
|
| 301 |
+
class CurvatureHead(nn.Module):
    """
    Axis-aware curvature detection with differentiation gating.

    1. Per-axis max projections -> 2D conv (keeps 2×2 spatial)
    2. Radial occupancy profile from centroid
    3. Axial symmetry + translation invariance scores
    4. 3D conv with spatial preservation (2×2×2)
    5. DifferentiationGate: gradient/Laplacian analysis for direction detection

    The DifferentiationGate modulates curvature features so that
    convex and concave shapes get distinct representations even when
    their occupancy patterns are nearly identical.
    """

    def __init__(self, rigid_feat_dim, fill_dim, embed_dim):
        super().__init__()

        # 2D conv over per-axis max projections; AdaptiveAvgPool2d(2)
        # keeps a 2x2 spatial map per projection.
        self.plane_conv = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv2d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool2d(2))
        plane_feat_dim = 3 * 16 * 4  # 192 (3 projection axes x 16ch x 2x2)

        n_radial = 5
        self.radial_net = nn.Sequential(
            nn.Linear(n_radial, 32), nn.GELU(), nn.Linear(32, 16))
        radial_feat_dim = 16

        # 3 axes x (mirror symmetry, translation invariance)
        symmetry_feat_dim = 6

        self.voxel_conv = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv3d(16, 32, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))
        voxel3d_feat_dim = 32 * 8  # 256

        # DifferentiationGate for curvature direction
        self.diff_gate = DifferentiationGate(embed_dim)

        # Pre-gate combine (without direction features)
        pre_gate_dim = (plane_feat_dim + radial_feat_dim + symmetry_feat_dim +
                        voxel3d_feat_dim + rigid_feat_dim + fill_dim)

        # Pre-gate feature projection: SwiGLU for sharp geometric feature gating
        self.pre_gate_proj = nn.Sequential(
            SwiGLU(pre_gate_dim, embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim))

        # Post-gate: gated features + direction features + alternation + raw combine
        # = embed_dim (gated) + embed_dim (direction) + 1 (alternation) + pre_gate_dim
        post_gate_dim = embed_dim + embed_dim + 1 + pre_gate_dim

        # SwiGLU for all curvature decision heads: sharp geometric classification
        self.curved_head = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim),
            nn.Linear(embed_dim, 1), nn.Sigmoid())
        self.curv_type_head = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim),
            nn.Linear(embed_dim, NUM_CURVATURES))
        self.curv_features = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim))

    def forward(self, grid, rigid_retained, fill_ratios):
        """
        grid: (B, 5, 5, 5) occupancy (shape per the class comments — confirm GS).
        rigid_retained: (B, rigid_feat_dim) features from an upstream head.
        fill_ratios: (B, fill_dim) capacity fill signals.

        Returns:
            is_curved: (B, 1) sigmoid probability the shape is curved
            curv_logits: (B, NUM_CURVATURES) curvature-type logits
            curv_feat: (B, embed_dim) curvature embedding
            alternation: (B, 1) curvature alternation score from the gate
        """
        B = grid.shape[0]

        # Per-axis max projections collapse one spatial dim each.
        proj_x = grid.max(dim=1).values
        proj_y = grid.max(dim=2).values
        proj_z = grid.max(dim=3).values

        # Batch all 3 projections through plane_conv in single pass
        projs_batched = torch.cat([
            proj_x.unsqueeze(1), proj_y.unsqueeze(1), proj_z.unsqueeze(1)
        ], dim=0)  # (3B, 1, 5, 5)
        plane_all = self.plane_conv(projs_batched).reshape(3, B, -1)  # (3, B, 64)
        plane_feat = plane_all.permute(1, 0, 2).reshape(B, -1)  # (B, 192)

        radial = self._radial_profile(grid)
        radial_feat = self.radial_net(radial)

        sym_feat = self._symmetry_features(proj_x, proj_y, proj_z)

        vox3d_feat = self.voxel_conv(grid.unsqueeze(1)).reshape(B, -1)

        # Raw curvature features (shape-aware but direction-blind)
        raw_combined = torch.cat([
            plane_feat, radial_feat, sym_feat, vox3d_feat,
            rigid_retained, fill_ratios], dim=-1)

        # Project to gatable dimension
        pre_gate = self.pre_gate_proj(raw_combined)  # (B, embed_dim)

        # Direction analysis
        dir_gate, dir_feat, alternation = self.diff_gate(grid)

        # Apply gate: direction-modulated curvature features
        gated = pre_gate * dir_gate  # (B, embed_dim) — convex/concave differentiation

        # Full post-gate features
        combined = torch.cat([gated, dir_feat, alternation, raw_combined], dim=-1)

        is_curved = self.curved_head(combined)
        curv_logits = self.curv_type_head(combined)
        curv_feat = self.curv_features(combined)
        return is_curved, curv_logits, curv_feat, alternation

    def _radial_profile(self, grid):
        """Fraction of occupied mass per radial distance bin from the centroid."""
        B = grid.shape[0]
        device = grid.device
        # Rebuilt per call (unlike DifferentiationGate, which buffers coords).
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, device=device, dtype=torch.float32),
            torch.arange(GS, device=device, dtype=torch.float32),
            torch.arange(GS, device=device, dtype=torch.float32),
            indexing="ij"), dim=-1)
        flat_grid = grid.reshape(B, -1)
        flat_coords = coords.reshape(-1, 3)
        # clamp(min=1) guards the divisions below against an empty grid
        total_occ = flat_grid.sum(dim=-1, keepdim=True).clamp(min=1)
        centroids = (flat_grid.unsqueeze(-1) * flat_coords.unsqueeze(0)).sum(dim=1) / total_occ
        diffs = flat_coords.unsqueeze(0) - centroids.unsqueeze(1)
        dists = diffs.norm(dim=-1)  # (B, 125)
        max_dist = 3.5
        n_bins = 5
        # Arithmetic binning (Inductor-safe, no bucketize)
        bin_idx = torch.nan_to_num(dists * (float(n_bins) / max_dist), nan=0.0).long().clamp(0, n_bins - 1)
        one_hot = F.one_hot(bin_idx, n_bins).float()  # (B, 125, 5)
        weighted = flat_grid.unsqueeze(-1) * one_hot  # (B, 125, 5)
        profile = weighted.sum(dim=1) / total_occ  # (B, 5)
        return profile

    def _symmetry_features(self, proj_x, proj_y, proj_z):
        """Mirror-symmetry and translation-invariance scores per projection."""
        projs = torch.stack([proj_x, proj_y, proj_z], dim=1)  # (B, 3, H, W)
        fh = torch.flip(projs, dims=[2])  # horizontal mirror
        fv = torch.flip(projs, dims=[3])  # vertical mirror
        # 1 = perfectly mirror-symmetric, lower = asymmetric
        sym = 1.0 - ((projs - fh).abs().mean(dim=(2, 3)) +
                     (projs - fv).abs().mean(dim=(2, 3))) / 2  # (B, 3)
        # Mean absolute difference between adjacent rows: 0 = shift-invariant
        shift_diff = (projs[:, :, 1:, :] - projs[:, :, :-1, :]).abs().mean(dim=(2, 3))  # (B, 3)
        trans_inv = 1.0 - shift_diff
        # Interleave: [sym0, trans0, sym1, trans1, sym2, trans2]
        return torch.stack([sym[:, 0], trans_inv[:, 0],
                            sym[:, 1], trans_inv[:, 1],
                            sym[:, 2], trans_inv[:, 2]], dim=-1)  # (B, 6)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
# === Confidence Computation ====================================================
|
| 446 |
+
|
| 447 |
+
def compute_confidence(logits):
    """Derive calibrated confidence signals from raw classification logits.

    Args:
        logits: (B, C) unnormalized class scores.

    Returns:
        dict with
            max_prob:   top-class softmax probability
            margin:     gap between the top-1 and top-2 probabilities
            entropy:    softmax entropy normalized to [0, 1] by log(C)
            confidence: alias of margin — the primary gating signal
    """
    probs = F.softmax(logits, dim=-1)
    top_prob = probs.max(dim=-1).values

    # Disambiguation strength: how far ahead the winner is.
    best_two = probs.topk(2, dim=-1).values
    gap = best_two[:, 0] - best_two[:, 1]

    # Shannon entropy, scaled so a uniform distribution maps to 1.0.
    log_probs = F.log_softmax(logits, dim=-1)
    raw_entropy = -(probs * log_probs).sum(dim=-1)
    scaled_entropy = raw_entropy / math.log(logits.shape[-1])

    return {
        "max_prob": top_prob,
        "margin": gap,
        "entropy": scaled_entropy,
        "confidence": gap,  # primary signal
    }
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
# === Rectified Flow Arbiter ===================================================
|
| 478 |
+
|
| 479 |
+
class RectifiedFlowArbiter(nn.Module):
|
| 480 |
+
"""
|
| 481 |
+
Rectified flow matching for ambiguous classification refinement.
|
| 482 |
+
|
| 483 |
+
Real flow matching requires a target endpoint to define the velocity field.
|
| 484 |
+
We learn class prototypes in latent space as targets: for a sample of class c,
|
| 485 |
+
the target is prototype[c]. The velocity field learns to transport the
|
| 486 |
+
encoded feature z0 toward the correct prototype z1 in straight lines:
|
| 487 |
+
|
| 488 |
+
v_target = z1 - z0 (rectified: straight path from source to target)
|
| 489 |
+
loss = ||v_predicted - v_target||^2 (flow matching objective)
|
| 490 |
+
|
| 491 |
+
At inference, the arbiter integrates the learned velocity field from z0,
|
| 492 |
+
landing near the correct class prototype. Classification reads off the
|
| 493 |
+
nearest prototype.
|
| 494 |
+
|
| 495 |
+
Confidence gating: velocity magnitude is scaled by (1 - margin), so
|
| 496 |
+
confident first-pass predictions receive minimal correction.
|
| 497 |
+
"""
|
| 498 |
+
|
| 499 |
+
    def __init__(self, feat_dim, n_classes, n_steps=4, latent_dim=128, embed_dim=64):
        """Build the flow-matching arbiter.

        Args:
            feat_dim: width of the incoming feature vector.
            n_classes: number of classes (one learned prototype each).
            n_steps: Euler integration steps at inference; dt = 1 / n_steps.
            latent_dim: width of the flow latent space.
            embed_dim: width of the time and confidence embeddings.
        """
        super().__init__()
        self.n_steps = n_steps
        self.n_classes = n_classes
        self.dt = 1.0 / n_steps
        self.latent_dim = latent_dim

        # Project features to latent space
        self.encode = nn.Sequential(
            nn.Linear(feat_dim, latent_dim * 2), nn.GELU(),
            nn.Linear(latent_dim * 2, latent_dim))

        # Learnable class prototypes — target endpoints for flow
        # (small init keeps prototypes near the origin at start of training)
        self.prototypes = nn.Parameter(torch.randn(n_classes, latent_dim) * 0.05)

        # Timestep embedding (input 16 = sin/cos of 8 frequencies,
        # matching _time_encoding's output width)
        self.time_embed = nn.Sequential(
            nn.Linear(16, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, embed_dim))

        # Confidence embedding (input 3 = max_prob, margin, entropy
        # from compute_confidence)
        self.conf_embed = nn.Sequential(
            nn.Linear(3, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, embed_dim))

        # Velocity network: predicts flow direction in latent space
        vel_in = latent_dim + embed_dim + embed_dim
        self.velocity = nn.Sequential(
            SwiGLU(vel_in, latent_dim),
            nn.Linear(latent_dim, latent_dim),
            SwiGLU(latent_dim, latent_dim),
            nn.Linear(latent_dim, latent_dim))

        # Velocity gate: low confidence → full correction, high → minimal
        self.vel_gate = nn.Sequential(
            nn.Linear(embed_dim, latent_dim), nn.Sigmoid())

        # Classification from latent: distance to prototypes + learned head
        self.classifier_head = nn.Sequential(
            SwiGLU(latent_dim + n_classes, 96),
            nn.Linear(96, n_classes))

        # Learned confidence head for blending (differentiable, not topk)
        self.blend_head = nn.Sequential(
            nn.Linear(feat_dim, 64), nn.GELU(),
            nn.Linear(64, 1), nn.Sigmoid())

        # Post-refinement confidence
        self.refined_confidence = nn.Sequential(
            SwiGLU(latent_dim, 32),
            nn.Linear(32, 1), nn.Sigmoid())
|
| 550 |
+
|
| 551 |
+
def _time_encoding(self, t, device):
|
| 552 |
+
freqs = torch.exp(torch.linspace(0, -4, 8, device=device))
|
| 553 |
+
args = t.unsqueeze(-1) * freqs.unsqueeze(0)
|
| 554 |
+
return torch.cat([args.sin(), args.cos()], dim=-1)
|
| 555 |
+
|
| 556 |
+
def _proto_logits(self, z):
|
| 557 |
+
"""Classify by negative distance to prototypes."""
|
| 558 |
+
# (B, latent) vs (C, latent) → (B, C) distances
|
| 559 |
+
dists = torch.cdist(z.unsqueeze(0), self.prototypes.unsqueeze(0)).squeeze(0)
|
| 560 |
+
# Combine distance signal with learned head
|
| 561 |
+
combined = torch.cat([z, -dists], dim=-1) # (B, latent + n_classes)
|
| 562 |
+
return self.classifier_head(combined)
|
| 563 |
+
|
| 564 |
+
    def forward(self, features, initial_logits, labels=None):
        """
        Refine classification logits by integrating a learned velocity field.

        features: (B, feat_dim) pre-classifier features
        initial_logits: (B, n_classes) first-pass classifier output
        labels: (B,) — only during training, for flow matching target

        Returns:
            refined_logits, refined_conf, initial_conf, trajectory_logits,
            flow_loss, blend_weight
        """
        B = features.shape[0]
        device = features.device

        # Confidence from initial logits — assumes compute_confidence returns a
        # dict with "max_prob", "margin", "entropy" keys (defined elsewhere).
        initial_conf = compute_confidence(initial_logits)
        conf_input = torch.stack([
            initial_conf["max_prob"],
            initial_conf["margin"],
            initial_conf["entropy"]], dim=-1)
        conf_emb = self.conf_embed(conf_input)

        # Confidence-gated velocity magnitude: the learned gate is further
        # scaled by (1 - margin), so confidently-classified samples receive
        # little correction and ambiguous ones receive more.
        gate = self.vel_gate(conf_emb)
        inv_conf = (1.0 - initial_conf["margin"]).unsqueeze(-1)
        adaptive_gate = gate * inv_conf

        # Encode to latent
        z0 = self.encode(features)

        # === Flow matching target (training only) ===
        flow_loss = torch.tensor(0.0, device=device)
        if labels is not None:
            # Target: class prototype for each sample
            z1 = self.prototypes[labels]  # (B, latent_dim)
            # Target velocity: constant along the straight path z0 → z1
            v_target = z1 - z0  # (B, latent_dim)

            # Sample random timestep for flow matching training
            t_rand = torch.rand(B, device=device)
            t_emb = self.time_embed(self._time_encoding(t_rand, device))

            # Interpolated position along straight path
            z_t = z0 + t_rand.unsqueeze(-1) * v_target  # (B, latent_dim)

            # Predicted velocity at this point (gated, same as inference)
            vel_input = torch.cat([z_t, t_emb, conf_emb], dim=-1)
            v_pred = self.velocity(vel_input) * adaptive_gate
            v_pred = v_pred.clamp(-20, 20)

            # Flow matching loss: predicted velocity should match the target;
            # both sides clamped identically to keep the loss bounded.
            flow_loss = F.mse_loss(v_pred, v_target.clamp(-20, 20))

        # === Inference: integrate velocity field (Euler, n_steps × dt) ===
        z = z0
        trajectory_logits = []
        for step in range(self.n_steps):
            t_val = torch.full((B,), step * self.dt, device=device)
            t_emb = self.time_embed(self._time_encoding(t_val, device))

            vel_input = torch.cat([z, t_emb, conf_emb], dim=-1)
            v = self.velocity(vel_input) * adaptive_gate
            # Prevent BF16 divergence: clamp velocity magnitude
            v = v.clamp(-20, 20)

            z = z + self.dt * v
            # Logits at every step are kept so callers can supervise the
            # whole trajectory, not just the endpoint.
            trajectory_logits.append(self._proto_logits(z))

        refined_logits = trajectory_logits[-1]
        refined_conf = self.refined_confidence(z)

        # Learned blend weight (differentiable, from initial features) —
        # used by the caller to mix initial vs refined logits.
        blend_weight = self.blend_head(features)  # (B, 1)

        return refined_logits, refined_conf, initial_conf, trajectory_logits, flow_loss, blend_weight
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
# === Model ====================================================================
|
| 640 |
+
|
| 641 |
+
class GeometricShapeClassifier(nn.Module):
    """Voxel-grid shape classifier with tracer attention, capacity heads,
    a curvature head, and a rectified-flow arbiter for ambiguous samples.

    Expects a (B, GS, GS, GS) occupancy grid (GS defined elsewhere in file).
    """

    def __init__(self, n_classes=NUM_CLASSES, embed_dim=64, n_tracers=5):
        super().__init__()
        self.n_tracers = n_tracers
        self.embed_dim = embed_dim

        # Per-voxel embedding of [occupancy, x, y, z] → embed_dim
        self.voxel_embed = nn.Sequential(
            nn.Linear(4, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))

        # Normalized voxel coordinates in [0, 1]^3, stored as a buffer so it
        # moves with the module across devices.
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            indexing="ij"), dim=-1) / (GS - 1)  # (GS,GS,GS,3) normalized
        self.register_buffer("pos_grid", coords)

        # Learnable tracer tokens that attend over voxel embeddings
        self.tracer_tokens = nn.Parameter(torch.randn(n_tracers, embed_dim) * 0.02)
        self.tracer_attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)
        self.tracer_gate = nn.Sequential(nn.Linear(embed_dim * 2, embed_dim), nn.Sigmoid())
        self.tracer_interact = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))
        # SwiGLU for edge detection: sharp "edge present?" decision
        self.edge_head = nn.Sequential(
            SwiGLU(embed_dim * 2, 32), nn.Linear(32, 1))

        # Precompute all C(n_tracers, 2) pair indices for vectorized interaction
        _pi, _pj = [], []
        for i in range(n_tracers):
            for j in range(i + 1, n_tracers):
                _pi.append(i); _pj.append(j)
        self.register_buffer("_pair_i", torch.tensor(_pi, dtype=torch.long))
        self.register_buffer("_pair_j", torch.tensor(_pj, dtype=torch.long))
        self.n_pairs = len(_pi)

        pool_dim = embed_dim * n_tracers

        # Chained capacity heads (CapacityHead defined elsewhere): each later
        # head also consumes the previous head's overflow vector.
        self.dim0 = CapacityHead(pool_dim, embed_dim, init_capacity=0.5)
        self.dim1 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=1.0)
        self.dim2 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=1.5)
        self.dim3 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=2.0)

        rigid_feat_dim = embed_dim * 4
        self.curvature = CurvatureHead(rigid_feat_dim, fill_dim=4, embed_dim=embed_dim)

        # Feature width fed to the classifier and auxiliary heads:
        # pooled tracers + 4 fill ratios + retained features + curvature
        # features + is_curved scalar.
        class_in = pool_dim + 4 + rigid_feat_dim + embed_dim + 1
        self.class_in = class_in  # Store for arbiter
        self.classifier = nn.Sequential(
            nn.Linear(class_in, 256), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(256, 128), nn.GELU(), nn.Linear(128, n_classes))

        # SwiGLU for peak dimension: sharp "which dimension?" decision
        self.peak_head = nn.Sequential(
            SwiGLU(class_in, 32), nn.Linear(32, 4))
        # Volume is continuous interpolation — keep GELU
        self.volume_head = nn.Sequential(
            nn.Linear(class_in, 64), nn.GELU(), nn.Linear(64, 1))
        # SwiGLU for CM determinant sign: sharp geometric determinant
        self.cm_head = nn.Sequential(
            SwiGLU(class_in, 64), nn.Linear(64, 1), nn.Tanh())

        # Rectified flow arbiter for ambiguous classification
        self.arbiter = RectifiedFlowArbiter(
            feat_dim=class_in, n_classes=n_classes,
            n_steps=4, latent_dim=128, embed_dim=embed_dim)

    def forward(self, grid, labels=None):
        """Run the full pipeline; returns a dict of logits, confidences and
        auxiliary-head outputs (see return statement for keys).

        grid: (B, GS, GS, GS) occupancy — TODO confirm against caller.
        labels: optional (B,) class ids, forwarded to the arbiter for its
        flow-matching loss during training.
        """
        B = grid.shape[0]
        occ = grid.reshape(B, GS**3, 1)
        pos = self.pos_grid.reshape(1, GS**3, 3).expand(B, -1, -1)
        voxel_emb = self.voxel_embed(torch.cat([occ, pos], dim=-1))

        # Tracer tokens cross-attend over all voxels (queries=tracers)
        tracers = self.tracer_tokens.unsqueeze(0).expand(B, -1, -1)
        tracers, _ = self.tracer_attn(tracers, voxel_emb, voxel_emb)

        # Vectorized pair interaction: all C(n_tracers,2) pairs at once
        left = tracers[:, self._pair_i]   # (B, n_pairs, embed_dim)
        right = tracers[:, self._pair_j]  # (B, n_pairs, embed_dim)
        pairs = torch.cat([left, right], dim=-1)  # (B, n_pairs, embed_dim*2)

        # Flatten to batch, run networks, reshape back
        flat_pairs = pairs.reshape(B * self.n_pairs, -1)
        gate = self.tracer_gate(flat_pairs).reshape(B, self.n_pairs, -1)
        interaction = self.tracer_interact(flat_pairs).reshape(B, self.n_pairs, -1)
        edge_lengths = self.edge_head(flat_pairs).reshape(B, self.n_pairs)

        # Scatter-add gated interactions back to both tracers in each pair.
        # Each pair contributes the SAME gated vector to both endpoints.
        gated = gate * interaction  # (B, n_pairs, embed_dim)
        tracer_out = tracers.clone()
        pi_exp = self._pair_i.view(1, self.n_pairs, 1).expand(B, -1, self.embed_dim)
        pj_exp = self._pair_j.view(1, self.n_pairs, 1).expand(B, -1, self.embed_dim)
        tracer_out.scatter_add_(1, pi_exp, gated)
        tracer_out.scatter_add_(1, pj_exp, gated)
        pooled = tracer_out.reshape(B, -1)

        # Cascade of capacity heads: overflow from each dimension feeds the
        # next (0D → 1D → 2D → 3D).
        fill0, ovf0, ret0, cap0, _ = self.dim0(pooled)
        fill1, ovf1, ret1, cap1, _ = self.dim1(torch.cat([pooled, ovf0], -1))
        fill2, ovf2, ret2, cap2, _ = self.dim2(torch.cat([pooled, ovf1], -1))
        fill3, ovf3, ret3, cap3, _ = self.dim3(torch.cat([pooled, ovf2], -1))

        fill_ratios = torch.cat([fill0, fill1, fill2, fill3], dim=-1)
        rigid_retained = torch.cat([ret0, ret1, ret2, ret3], dim=-1)
        ovf_norms = torch.stack([
            ovf0.norm(dim=-1), ovf1.norm(dim=-1),
            ovf2.norm(dim=-1), ovf3.norm(dim=-1)], dim=-1)

        is_curved, curv_logits, curv_feat, alternation = self.curvature(grid, rigid_retained, fill_ratios)
        full = torch.cat([pooled, fill_ratios, rigid_retained, curv_feat, is_curved], dim=-1)

        # === First pass classification ===
        initial_logits = self.classifier(full)

        # === Rectified flow arbitration ===
        refined_logits, refined_conf, initial_conf, trajectory_logits, flow_loss, blend_weight = \
            self.arbiter(full, initial_logits, labels=labels)

        # === Blend: learned confidence head decides trust ===
        # blend_weight is (B, 1) sigmoid output from learned head; 1 → trust
        # the initial classifier entirely, 0 → trust the refined logits.
        final_logits = blend_weight * initial_logits + (1.0 - blend_weight) * refined_logits

        return {
            # Classification
            "class_logits": final_logits,
            "initial_logits": initial_logits,
            "refined_logits": refined_logits,
            "trajectory_logits": trajectory_logits,
            # Flow matching
            "flow_loss": flow_loss,
            # Confidence
            "confidence": initial_conf["confidence"],
            "max_prob": initial_conf["max_prob"],
            "entropy": initial_conf["entropy"],
            "refined_confidence": refined_conf,
            "blend_weight": blend_weight.squeeze(-1),
            # Auxiliary heads
            "peak_logits": self.peak_head(full),
            "volume_pred": self.volume_head(full).squeeze(-1),
            "cm_pred": self.cm_head(full).squeeze(-1),
            "edge_lengths": edge_lengths,
            "fill_ratios": fill_ratios,
            "overflows": ovf_norms,
            "capacities": torch.stack([cap0, cap1, cap2, cap3]),
            "is_curved_pred": is_curved,
            "curv_type_logits": curv_logits,
            "alternation": alternation,
            # Pre-classifier features (for cross-contrast)
            "features": full,
        }
|
| 788 |
+
|
| 789 |
+
|
| 790 |
+
# Quick sanity: build the model once at import time to catch wiring errors
# (shape mismatches in __init__) early, report the parameter count, then
# free the throwaway instance.
_m = GeometricShapeClassifier()
print(f'GeometricShapeClassifier: {sum(p.numel() for p in _m.parameters()):,} params')
del _m
|