AbstractPhil committed on
Commit
3d7d199
·
verified ·
1 Parent(s): 4706821

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +793 -0
model.py ADDED
@@ -0,0 +1,793 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Cell 2
2
+ # === Capacity Head ============================================================
3
+
4
+ class CapacityHead(nn.Module):
5
+ def __init__(self, in_dim, feat_dim, init_capacity=1.0):
6
+ super().__init__()
7
+ self._raw_capacity = nn.Parameter(torch.tensor(math.log(math.exp(init_capacity) - 1)))
8
+ # GELU for cascade: smooth gradients needed for overflow propagation
9
+ self.evidence_net = nn.Sequential(
10
+ nn.Linear(in_dim, feat_dim), nn.GELU(), nn.Linear(feat_dim, 1))
11
+ self.feature_net = nn.Sequential(
12
+ nn.Linear(in_dim, feat_dim), nn.GELU(), nn.Linear(feat_dim, feat_dim))
13
+ self.retain_gate = nn.Sequential(
14
+ nn.Linear(feat_dim + 1, feat_dim), nn.Sigmoid())
15
+ self.overflow_gate = nn.Sequential(
16
+ nn.Linear(feat_dim + 1, feat_dim), nn.Sigmoid())
17
+
18
+ @property
19
+ def capacity(self):
20
+ return F.softplus(self._raw_capacity)
21
+
22
+ def forward(self, x):
23
+ cap = self.capacity
24
+ raw_ev = F.relu(self.evidence_net(x))
25
+ fill = torch.clamp(raw_ev / (cap + 1e-8), max=1.0)
26
+ sat = torch.clamp((raw_ev - cap) / (cap + 1e-8), min=0.0)
27
+ feat = self.feature_net(x)
28
+ retained = self.retain_gate(torch.cat([feat, fill], -1)) * feat * fill
29
+ overflow = self.overflow_gate(torch.cat([feat, sat], -1)) * feat * torch.clamp(sat, max=1.0)
30
+ return fill, overflow, retained, cap, raw_ev
31
+
32
+
33
+ # === Differentiation Gate =====================================================
34
+
35
class DifferentiationGate(nn.Module):
    """
    Curvature direction analysis via occupancy field differentiation.

    Computes gradient and Laplacian of the 3D occupancy field to determine:
      - Curvature direction: convex (normals point outward) vs concave (inward)
      - Curvature alternation: where sign flips (saddle points, torus inner/outer)
      - Perturbation robustness: smoothed gradient features survive noise

    The key insight: a hemisphere and bowl occupy nearly identical voxels,
    but their occupancy gradients point in opposite directions relative
    to the center of mass. The Laplacian's sign distinguishes them.

    Outputs gate signals that modulate curvature features:
      - direction_gate: learned weighting based on gradient analysis
      - alternation_score: how much curvature sign varies spatially
      - directional_features: rich features encoding curvature orientation
    """

    def __init__(self, embed_dim=64):
        super().__init__()

        # Fixed 3D differentiation kernels — fused into a single conv.
        # 4 output channels: [grad_x, grad_y, grad_z, laplacian]
        diff_kernels = torch.zeros(4, 1, 3, 3, 3)
        # Central difference along X (±1 stencil; not a full Sobel kernel)
        diff_kernels[0, 0, 0, 1, 1] = -1; diff_kernels[0, 0, 2, 1, 1] = 1
        # Central difference along Y
        diff_kernels[1, 0, 1, 0, 1] = -1; diff_kernels[1, 0, 1, 2, 1] = 1
        # Central difference along Z
        diff_kernels[2, 0, 1, 1, 0] = -1; diff_kernels[2, 0, 1, 1, 2] = 1
        # 7-point discrete Laplacian
        diff_kernels[3, 0, 1, 1, 1] = -6
        diff_kernels[3, 0, 0, 1, 1] = 1; diff_kernels[3, 0, 2, 1, 1] = 1
        diff_kernels[3, 0, 1, 0, 1] = 1; diff_kernels[3, 0, 1, 2, 1] = 1
        diff_kernels[3, 0, 1, 1, 0] = 1; diff_kernels[3, 0, 1, 1, 2] = 1
        self.register_buffer("diff_kernels", diff_kernels)

        # Precompute coordinate grid, shape (GS, GS, GS, 3)
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            indexing="ij"), dim=-1)
        self.register_buffer("coords", coords)

        # Process gradient-derived features.
        # Per-voxel: gradient direction, Laplacian sign, centroid-relative direction,
        # summarized as histograms and statistics:
        #   direction histogram (3) + Laplacian sign distribution (3)
        #   + alternation score (1) + per-axis gradient asymmetry (3)
        #   + radial gradient profile (5)
        raw_feat_dim = 3 + 3 + 1 + 3 + 5  # = 15
        # Plus the 3D conv on the Laplacian field preserving spatial structure
        self.lap_conv = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv3d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))  # -> (B, 16, 2, 2, 2) = 128
        lap_conv_dim = 16 * 8  # 128

        # Gradient magnitude 3D conv (encodes where boundaries are + direction)
        self.grad_conv = nn.Sequential(
            nn.Conv3d(3, 16, 3, padding=1), nn.GELU(),  # 3-channel: dx, dy, dz
            nn.Conv3d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))  # -> (B, 16, 2, 2, 2) = 128
        grad_conv_dim = 16 * 8  # 128

        total_feat_dim = raw_feat_dim + lap_conv_dim + grad_conv_dim  # 15 + 128 + 128 = 271

        # Direction gate: SwiGLU for sharp convex/concave gating
        self.direction_net = nn.Sequential(
            SwiGLU(total_feat_dim, embed_dim),
            nn.Linear(embed_dim, embed_dim), nn.Sigmoid())

        # Directional features: SwiGLU for crisp direction encoding
        self.direction_feat_net = nn.Sequential(
            SwiGLU(total_feat_dim, embed_dim),
            nn.Linear(embed_dim, embed_dim))

    def forward(self, grid):
        """
        grid: (B, GS, GS, GS) binary occupancy (comments below assume GS=5)

        Returns:
            direction_gate: (B, embed_dim) sigmoid gate for curvature features
            direction_feat: (B, embed_dim) additive directional features
            alternation_score: (B, 1) how much curvature alternates
        """
        B = grid.shape[0]
        vox = grid.unsqueeze(1)  # (B, 1, 5, 5, 5)

        # === Smooth occupancy before differentiation ===
        # Binary voxels produce spike gradients. Light blur creates
        # a continuous field whose derivatives are geometrically meaningful.
        vox_smooth = F.avg_pool3d(
            F.pad(vox, (1,1,1,1,1,1), mode='replicate'),
            kernel_size=3, stride=1, padding=0)  # (B, 1, 5, 5, 5)

        # === Compute gradients + Laplacian in single fused conv ===
        diff = F.conv3d(vox_smooth, self.diff_kernels, padding=1)  # (B, 4, 5, 5, 5)
        grad_field = diff[:, :3]  # (B, 3, 5, 5, 5) — gx, gy, gz
        gx, gy, gz = diff[:, 0:1], diff[:, 1:2], diff[:, 2:3]
        lap = diff[:, 3:4]  # (B, 1, 5, 5, 5)

        # === Centroid ===
        flat_grid = grid.reshape(B, -1)  # (B, 125)
        flat_coords = self.coords.reshape(-1, 3)  # (125, 3)
        total_occ = flat_grid.sum(dim=-1, keepdim=True).clamp(min=1)  # (B, 1)
        centroids = (flat_grid.unsqueeze(-1) * flat_coords.unsqueeze(0)).sum(dim=1) / total_occ  # (B, 3)

        # === Gradient direction relative to centroid ===
        grad_flat = grad_field.reshape(B, 3, -1).permute(0, 2, 1)  # (B, 125, 3)
        diff_from_center = flat_coords.unsqueeze(0) - centroids.unsqueeze(1)  # (B, 125, 3)
        diff_norm = diff_from_center / (diff_from_center.norm(dim=-1, keepdim=True) + 1e-8)
        dot_products = (grad_flat * diff_norm).sum(dim=-1)  # (B, 125)
        grad_mag = grad_flat.norm(dim=-1)  # (B, 125)
        active = (flat_grid > 0.5) & (grad_mag > 0.01)

        # Histogram of dot product signs (convex/concave/neutral fractions)
        n_active = active.float().sum(-1).clamp(min=1)
        frac_outward = ((dot_products > 0.1) & active).float().sum(-1) / n_active
        frac_inward = ((dot_products < -0.1) & active).float().sum(-1) / n_active
        frac_neutral = 1.0 - frac_outward - frac_inward
        direction_hist = torch.stack([frac_outward, frac_inward, frac_neutral], dim=-1)  # (B, 3)

        # === Laplacian sign distribution (active voxels only) ===
        lap_flat = lap.reshape(B, -1)  # (B, 125)
        lap_active = flat_grid > 0.5
        n_lap_active = lap_active.float().sum(-1).clamp(min=1)
        frac_pos_lap = ((lap_flat > 0.1) & lap_active).float().sum(-1) / n_lap_active
        frac_neg_lap = ((lap_flat < -0.1) & lap_active).float().sum(-1) / n_lap_active
        frac_zero_lap = 1.0 - frac_pos_lap - frac_neg_lap
        lap_hist = torch.stack([frac_pos_lap, frac_neg_lap, frac_zero_lap], dim=-1)  # (B, 3)

        # === Alternation score (ACTIVE VOXELS ONLY) ===
        # Only count sign flips between neighbor pairs where BOTH voxels are
        # near occupied regions. Otherwise empty space dilutes the signal.
        lap_3d = lap.squeeze(1)  # (B, 5, 5, 5)
        # Boundary mask: dilate occupancy by 1 to include immediate neighbors
        boundary_mask = F.max_pool3d(vox, kernel_size=3, stride=1, padding=1).squeeze(1)  # (B,5,5,5)

        # X-axis: both neighbors must be in boundary region
        bm_x = boundary_mask[:, 1:, :, :] * boundary_mask[:, :-1, :, :]  # (B,4,5,5)
        flip_x = (torch.sign(lap_3d[:, 1:, :, :]) * torch.sign(lap_3d[:, :-1, :, :]) < 0).float()
        active_flips_x = (flip_x * bm_x).sum(dim=(1, 2, 3))
        active_pairs_x = bm_x.sum(dim=(1, 2, 3)).clamp(min=1)

        bm_y = boundary_mask[:, :, 1:, :] * boundary_mask[:, :, :-1, :]
        flip_y = (torch.sign(lap_3d[:, :, 1:, :]) * torch.sign(lap_3d[:, :, :-1, :]) < 0).float()
        active_flips_y = (flip_y * bm_y).sum(dim=(1, 2, 3))
        active_pairs_y = bm_y.sum(dim=(1, 2, 3)).clamp(min=1)

        bm_z = boundary_mask[:, :, :, 1:] * boundary_mask[:, :, :, :-1]
        flip_z = (torch.sign(lap_3d[:, :, :, 1:]) * torch.sign(lap_3d[:, :, :, :-1]) < 0).float()
        active_flips_z = (flip_z * bm_z).sum(dim=(1, 2, 3))
        active_pairs_z = bm_z.sum(dim=(1, 2, 3)).clamp(min=1)

        alternation = ((active_flips_x / active_pairs_x +
                        active_flips_y / active_pairs_y +
                        active_flips_z / active_pairs_z) / 3.0).unsqueeze(-1)  # (B, 1)

        # === Per-axis gradient asymmetry ===
        # Asymmetry: mean gradient along each axis (nonzero = asymmetric curvature)
        gx_mean = (gx.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        gy_mean = (gy.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        gz_mean = (gz.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        grad_asym = torch.stack([gx_mean, gy_mean, gz_mean], dim=-1)  # (B, 3)

        # === Radial gradient profile ===
        # How does gradient magnitude vary with distance from centroid?
        dists = diff_from_center.norm(dim=-1)  # (B, 125)
        # Arithmetic binning (Inductor-safe, no bucketize)
        # nan_to_num prevents NaN→long producing garbage indices under BF16
        bin_idx = torch.nan_to_num(dists * (5.0 / 3.5), nan=0.0).long().clamp(0, 4)
        active_mask = (flat_grid > 0.5)  # (B, 125)
        # NOTE: a dead `radial_grad = torch.zeros(...)` pre-allocation was removed
        # here; the tensor below is built directly by the masked one-hot reduction.
        weighted_mag = grad_mag * active_mask.float()  # zero out inactive
        one_hot = F.one_hot(bin_idx, 5).float()  # (B, 125, 5)
        active_oh = one_hot * active_mask.float().unsqueeze(-1)  # mask inactive
        counts = active_oh.sum(dim=1).clamp(min=1)  # (B, 5)
        radial_grad = (weighted_mag.unsqueeze(-1) * active_oh).sum(dim=1) / counts  # (B, 5)

        # === Conv on Laplacian field (spatial curvature map) ===
        lap_feat = self.lap_conv(lap).reshape(B, -1)  # (B, 128)

        # === Conv on gradient field (directional boundaries) ===
        grad_feat = self.grad_conv(grad_field).reshape(B, -1)  # (B, 128)

        # === Combine all ===
        raw_feat = torch.cat([
            direction_hist,  # 3
            lap_hist,        # 3
            alternation,     # 1
            grad_asym,       # 3
            radial_grad,     # 5
        ], dim=-1)  # (B, 15)

        all_feat = torch.cat([raw_feat, lap_feat, grad_feat], dim=-1)  # (B, 271)

        direction_gate = self.direction_net(all_feat)  # (B, embed_dim) sigmoid
        direction_feat = self.direction_feat_net(all_feat)  # (B, embed_dim)

        return direction_gate, direction_feat, alternation
244
+
245
+
246
+ # === Deformation Augmentation =================================================
247
+
248
def deform_grid(grid, p_dropout=0.1, p_add=0.1, p_shift=0.15):
    """Fully vectorized voxel augmentation — zero CPU-GPU sync points.

    Three independent perturbations, each enabled per-sample by its own
    Bernoulli draw:
      * dropout  (prob ``p_dropout``): selected samples keep each voxel
        with probability 0.85;
      * addition (prob ``p_add``): selected samples switch on each voxel of
        the 1-voxel dilation boundary with probability 0.3;
      * shift    (prob ``p_shift``): selected samples are translated by one
        voxel along a random axis/direction, with the wrapped slice zeroed
        (a zero-padded shift, not a circular roll).

    Args:
        grid: (B, D, H, W) float occupancy — assumed binary {0, 1} and
            5×5×5 per the comments below; TODO(review) confirm for other sizes.

    Returns:
        Augmented tensor of the same shape; ``grid`` itself is not modified.
    """
    B = grid.shape[0]
    device = grid.device
    # One uniform draw per sample per augmentation type (columns 0/1/2).
    r = torch.rand(B, 3, device=device)
    out = grid.clone()

    # --- Voxel dropout (batched, no .any() sync) ---
    drop_sel = (r[:, 0] < p_dropout).view(B, 1, 1, 1)
    # Per-voxel keep mask: 85% keep rate within selected samples.
    keep = torch.rand_like(out) > 0.15
    out = torch.where(drop_sel, out * keep.float(), out)

    # --- Boundary addition (batched, no .any() sync) ---
    add_sel = (r[:, 1] < p_add).view(B, 1, 1, 1).float()
    # Max-pool dilation finds voxels adjacent to occupied ones.
    dilated = F.max_pool3d(out.unsqueeze(1), kernel_size=3, stride=1, padding=1).squeeze(1)
    boundary = ((dilated > 0.5) & (out < 0.5)).float()
    add_noise = (torch.rand_like(out) < 0.3).float()
    out = (out + boundary * add_noise * add_sel).clamp(max=1.0)

    # --- Small translation (fully vectorized, no loops, no boolean indexing) ---
    shift_sel = (r[:, 2] < p_shift)  # (B,)
    axes = torch.randint(3, (B,), device=device)
    dirs = torch.randint(0, 2, (B,), device=device) * 2 - 1  # maps {0,1} -> {-1,+1}

    # Precompute all 6 shifted versions of full batch (cheap for 5x5x5)
    # Encode: idx = axis * 2 + (dir==1) → [0..5], 6 = no shift
    versions = []
    for ax in range(3):
        for d in [-1, 1]:
            # torch.roll returns a copy, so zeroing below is safe.
            s = torch.roll(out, shifts=d, dims=ax + 1)  # +1 for batch dim
            # Zero wrapped edge so the roll behaves as a zero-padded shift
            if d == 1:
                if ax == 0: s[:, 0, :, :] = 0
                elif ax == 1: s[:, :, 0, :] = 0
                else: s[:, :, :, 0] = 0
            else:
                if ax == 0: s[:, -1, :, :] = 0
                elif ax == 1: s[:, :, -1, :] = 0
                else: s[:, :, :, -1] = 0
            versions.append(s)
    versions.append(out)  # index 6 = no shift (identity)
    stacked = torch.stack(versions, dim=0)  # (7, B, 5, 5, 5)

    # Per-sample assignment: which version to pick
    # (d=-1 occupies ax*2, d=+1 occupies ax*2+1 — matches the loop order above.)
    assign = torch.where(shift_sel, axes * 2 + (dirs == 1).long(), torch.full_like(axes, 6))
    # Gather: stacked[assign[b], b] for each b
    out = stacked[assign, torch.arange(B, device=device)]

    return out
297
+
298
+
299
+ # === Curvature Head (axis-aware) ==============================================
300
+
301
class CurvatureHead(nn.Module):
    """
    Axis-aware curvature detection with differentiation gating.

    1. Per-axis max projections -> 2D conv (keeps 2×2 spatial)
    2. Radial occupancy profile from centroid
    3. Axial symmetry + translation invariance scores
    4. 3D conv with spatial preservation (2×2×2)
    5. DifferentiationGate: gradient/Laplacian analysis for direction detection

    The DifferentiationGate modulates curvature features so that
    convex and concave shapes get distinct representations even when
    their occupancy patterns are nearly identical.
    """

    def __init__(self, rigid_feat_dim, fill_dim, embed_dim):
        super().__init__()

        # Shared 2D conv applied to each of the 3 axis projections.
        self.plane_conv = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv2d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool2d(2))
        plane_feat_dim = 3 * 16 * 4  # 192

        n_radial = 5
        self.radial_net = nn.Sequential(
            nn.Linear(n_radial, 32), nn.GELU(), nn.Linear(32, 16))
        radial_feat_dim = 16

        # 3 axes × (symmetry score + translation-invariance score).
        symmetry_feat_dim = 6

        self.voxel_conv = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv3d(16, 32, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))
        voxel3d_feat_dim = 32 * 8  # 256

        # DifferentiationGate for curvature direction
        self.diff_gate = DifferentiationGate(embed_dim)

        # Pre-gate combine (without direction features)
        pre_gate_dim = (plane_feat_dim + radial_feat_dim + symmetry_feat_dim +
                        voxel3d_feat_dim + rigid_feat_dim + fill_dim)

        # Pre-gate feature projection: SwiGLU for sharp geometric feature gating
        self.pre_gate_proj = nn.Sequential(
            SwiGLU(pre_gate_dim, embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim))

        # Post-gate: gated features + direction features + alternation + raw combine
        # = embed_dim (gated) + embed_dim (direction) + 1 (alternation) + pre_gate_dim
        post_gate_dim = embed_dim + embed_dim + 1 + pre_gate_dim

        # SwiGLU for all curvature decision heads: sharp geometric classification
        self.curved_head = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim),
            nn.Linear(embed_dim, 1), nn.Sigmoid())
        self.curv_type_head = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim),
            nn.Linear(embed_dim, NUM_CURVATURES))
        self.curv_features = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim))

    def forward(self, grid, rigid_retained, fill_ratios):
        """Run curvature analysis on an occupancy grid.

        Args:
            grid: (B, GS, GS, GS) occupancy — assumed 5×5×5 in the shape
                comments below; TODO(review) confirm for other GS.
            rigid_retained: (B, rigid_feat_dim) retained features from the
                capacity cascade.
            fill_ratios: (B, fill_dim) per-dimension fill ratios.

        Returns:
            is_curved: (B, 1) sigmoid curvature probability
            curv_logits: (B, NUM_CURVATURES) curvature-type logits
            curv_feat: (B, embed_dim) curvature feature embedding
            alternation: (B, 1) curvature alternation score
        """
        B = grid.shape[0]

        # Max projection collapses one axis each, giving 3 silhouettes.
        proj_x = grid.max(dim=1).values
        proj_y = grid.max(dim=2).values
        proj_z = grid.max(dim=3).values

        # Batch all 3 projections through plane_conv in single pass
        projs_batched = torch.cat([
            proj_x.unsqueeze(1), proj_y.unsqueeze(1), proj_z.unsqueeze(1)
        ], dim=0)  # (3B, 1, 5, 5)
        plane_all = self.plane_conv(projs_batched).reshape(3, B, -1)  # (3, B, 64)
        plane_feat = plane_all.permute(1, 0, 2).reshape(B, -1)  # (B, 192)

        radial = self._radial_profile(grid)
        radial_feat = self.radial_net(radial)

        sym_feat = self._symmetry_features(proj_x, proj_y, proj_z)

        vox3d_feat = self.voxel_conv(grid.unsqueeze(1)).reshape(B, -1)

        # Raw curvature features (shape-aware but direction-blind)
        raw_combined = torch.cat([
            plane_feat, radial_feat, sym_feat, vox3d_feat,
            rigid_retained, fill_ratios], dim=-1)

        # Project to gatable dimension
        pre_gate = self.pre_gate_proj(raw_combined)  # (B, embed_dim)

        # Direction analysis
        dir_gate, dir_feat, alternation = self.diff_gate(grid)

        # Apply gate: direction-modulated curvature features
        gated = pre_gate * dir_gate  # (B, embed_dim) — convex/concave differentiation

        # Full post-gate features
        combined = torch.cat([gated, dir_feat, alternation, raw_combined], dim=-1)

        is_curved = self.curved_head(combined)
        curv_logits = self.curv_type_head(combined)
        curv_feat = self.curv_features(combined)
        return is_curved, curv_logits, curv_feat, alternation

    def _radial_profile(self, grid):
        """Occupancy histogram over 5 distance bins from the shape centroid,
        normalized by total occupancy. Returns (B, 5)."""
        B = grid.shape[0]
        device = grid.device
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, device=device, dtype=torch.float32),
            torch.arange(GS, device=device, dtype=torch.float32),
            torch.arange(GS, device=device, dtype=torch.float32),
            indexing="ij"), dim=-1)
        flat_grid = grid.reshape(B, -1)
        flat_coords = coords.reshape(-1, 3)
        # clamp(min=1) guards the empty-grid division below.
        total_occ = flat_grid.sum(dim=-1, keepdim=True).clamp(min=1)
        centroids = (flat_grid.unsqueeze(-1) * flat_coords.unsqueeze(0)).sum(dim=1) / total_occ
        diffs = flat_coords.unsqueeze(0) - centroids.unsqueeze(1)
        dists = diffs.norm(dim=-1)  # (B, 125)
        max_dist = 3.5
        n_bins = 5
        # Arithmetic binning (Inductor-safe, no bucketize)
        bin_idx = torch.nan_to_num(dists * (float(n_bins) / max_dist), nan=0.0).long().clamp(0, n_bins - 1)
        one_hot = F.one_hot(bin_idx, n_bins).float()  # (B, 125, 5)
        weighted = flat_grid.unsqueeze(-1) * one_hot  # (B, 125, 5)
        profile = weighted.sum(dim=1) / total_occ  # (B, 5)
        return profile

    def _symmetry_features(self, proj_x, proj_y, proj_z):
        """Per-axis mirror-symmetry and one-step translation-invariance
        scores computed on the 2D projections. Returns (B, 6)."""
        projs = torch.stack([proj_x, proj_y, proj_z], dim=1)  # (B, 3, H, W)
        fh = torch.flip(projs, dims=[2])
        fv = torch.flip(projs, dims=[3])
        # 1 means perfectly mirror-symmetric in both flip directions.
        sym = 1.0 - ((projs - fh).abs().mean(dim=(2, 3)) +
                     (projs - fv).abs().mean(dim=(2, 3))) / 2  # (B, 3)
        # Mean |difference| between adjacent rows: low → translation-invariant.
        shift_diff = (projs[:, :, 1:, :] - projs[:, :, :-1, :]).abs().mean(dim=(2, 3))  # (B, 3)
        trans_inv = 1.0 - shift_diff
        # Interleave: [sym0, trans0, sym1, trans1, sym2, trans2]
        return torch.stack([sym[:, 0], trans_inv[:, 0],
                            sym[:, 1], trans_inv[:, 1],
                            sym[:, 2], trans_inv[:, 2]], dim=-1)  # (B, 6)
443
+
444
+
445
+ # === Confidence Computation ====================================================
446
+
447
def compute_confidence(logits):
    """Derive calibrated confidence signals from classification logits.

    Returns a dict with:
        max_prob:   top-class softmax probability — calibrated confidence
        margin:     top1 - top2 probability gap — disambiguation strength
        entropy:    softmax entropy normalized to [0, 1] (lower = more confident)
        confidence: alias of ``margin`` — the primary gating signal
    """
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)

    top_two = probs.topk(2, dim=-1).values
    margin = top_two[:, 0] - top_two[:, 1]

    # Shannon entropy scaled by log(C) so a uniform distribution maps to 1.
    norm_entropy = -(probs * log_probs).sum(dim=-1) / math.log(logits.shape[-1])

    return {
        "max_prob": probs.max(dim=-1).values,
        "margin": margin,
        "entropy": norm_entropy,
        "confidence": margin,  # primary signal
    }
475
+
476
+
477
+ # === Rectified Flow Arbiter ===================================================
478
+
479
class RectifiedFlowArbiter(nn.Module):
    """
    Rectified flow matching for ambiguous classification refinement.

    Real flow matching requires a target endpoint to define the velocity field.
    We learn class prototypes in latent space as targets: for a sample of class c,
    the target is prototype[c]. The velocity field learns to transport the
    encoded feature z0 toward the correct prototype z1 in straight lines:

        v_target = z1 - z0 (rectified: straight path from source to target)
        loss = ||v_predicted - v_target||^2 (flow matching objective)

    At inference, the arbiter integrates the learned velocity field from z0,
    landing near the correct class prototype. Classification reads off the
    nearest prototype.

    Confidence gating: velocity magnitude is scaled by (1 - margin), so
    confident first-pass predictions receive minimal correction.
    """

    def __init__(self, feat_dim, n_classes, n_steps=4, latent_dim=128, embed_dim=64):
        super().__init__()
        self.n_steps = n_steps
        self.n_classes = n_classes
        # Euler step size for the fixed-step velocity integration in forward().
        self.dt = 1.0 / n_steps
        self.latent_dim = latent_dim

        # Project features to latent space
        self.encode = nn.Sequential(
            nn.Linear(feat_dim, latent_dim * 2), nn.GELU(),
            nn.Linear(latent_dim * 2, latent_dim))

        # Learnable class prototypes — target endpoints for flow
        self.prototypes = nn.Parameter(torch.randn(n_classes, latent_dim) * 0.05)

        # Timestep embedding (input width 16 = sin+cos of 8 frequencies)
        self.time_embed = nn.Sequential(
            nn.Linear(16, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, embed_dim))

        # Confidence embedding (input: max_prob, margin, entropy)
        self.conf_embed = nn.Sequential(
            nn.Linear(3, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, embed_dim))

        # Velocity network: predicts flow direction in latent space
        vel_in = latent_dim + embed_dim + embed_dim
        self.velocity = nn.Sequential(
            SwiGLU(vel_in, latent_dim),
            nn.Linear(latent_dim, latent_dim),
            SwiGLU(latent_dim, latent_dim),
            nn.Linear(latent_dim, latent_dim))

        # Velocity gate: low confidence → full correction, high → minimal
        self.vel_gate = nn.Sequential(
            nn.Linear(embed_dim, latent_dim), nn.Sigmoid())

        # Classification from latent: distance to prototypes + learned head
        self.classifier_head = nn.Sequential(
            SwiGLU(latent_dim + n_classes, 96),
            nn.Linear(96, n_classes))

        # Learned confidence head for blending (differentiable, not topk)
        self.blend_head = nn.Sequential(
            nn.Linear(feat_dim, 64), nn.GELU(),
            nn.Linear(64, 1), nn.Sigmoid())

        # Post-refinement confidence
        self.refined_confidence = nn.Sequential(
            SwiGLU(latent_dim, 32),
            nn.Linear(32, 1), nn.Sigmoid())

    def _time_encoding(self, t, device):
        """Sinusoidal timestep encoding: 8 log-spaced frequencies,
        sin and cos concatenated → (B, 16), matching time_embed's input."""
        freqs = torch.exp(torch.linspace(0, -4, 8, device=device))
        args = t.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat([args.sin(), args.cos()], dim=-1)

    def _proto_logits(self, z):
        """Classify by negative distance to prototypes."""
        # (B, latent) vs (C, latent) → (B, C) distances
        dists = torch.cdist(z.unsqueeze(0), self.prototypes.unsqueeze(0)).squeeze(0)
        # Combine distance signal with learned head
        combined = torch.cat([z, -dists], dim=-1)  # (B, latent + n_classes)
        return self.classifier_head(combined)

    def forward(self, features, initial_logits, labels=None):
        """
        features: (B, feat_dim)
        initial_logits: (B, n_classes)
        labels: (B,) — only during training, for flow matching target

        Returns:
            refined_logits, refined_conf, initial_conf, trajectory_logits,
            flow_loss, blend_weight
        """
        B = features.shape[0]
        device = features.device

        # Confidence from initial logits
        initial_conf = compute_confidence(initial_logits)
        conf_input = torch.stack([
            initial_conf["max_prob"],
            initial_conf["margin"],
            initial_conf["entropy"]], dim=-1)
        conf_emb = self.conf_embed(conf_input)

        # Confidence-gated velocity magnitude
        gate = self.vel_gate(conf_emb)
        # (1 - margin): confident samples get their correction scaled toward zero.
        inv_conf = (1.0 - initial_conf["margin"]).unsqueeze(-1)
        adaptive_gate = gate * inv_conf

        # Encode to latent
        z0 = self.encode(features)

        # === Flow matching target (training only) ===
        flow_loss = torch.tensor(0.0, device=device)
        if labels is not None:
            # Target: class prototype for each sample
            z1 = self.prototypes[labels]  # (B, latent_dim)
            # Target velocity: straight path z0 → z1
            v_target = z1 - z0  # (B, latent_dim)

            # Sample random timestep for flow matching training
            t_rand = torch.rand(B, device=device)
            t_emb = self.time_embed(self._time_encoding(t_rand, device))

            # Interpolated position along straight path
            z_t = z0 + t_rand.unsqueeze(-1) * v_target  # (B, latent_dim)

            # Predicted velocity at this point
            vel_input = torch.cat([z_t, t_emb, conf_emb], dim=-1)
            v_pred = self.velocity(vel_input) * adaptive_gate
            v_pred = v_pred.clamp(-20, 20)

            # Flow matching loss: predicted velocity should match target
            flow_loss = F.mse_loss(v_pred, v_target.clamp(-20, 20))

        # === Inference: integrate velocity field (fixed-step Euler) ===
        z = z0
        trajectory_logits = []
        for step in range(self.n_steps):
            t_val = torch.full((B,), step * self.dt, device=device)
            t_emb = self.time_embed(self._time_encoding(t_val, device))

            vel_input = torch.cat([z, t_emb, conf_emb], dim=-1)
            v = self.velocity(vel_input) * adaptive_gate
            # Prevent BF16 divergence: clamp velocity magnitude
            v = v.clamp(-20, 20)

            z = z + self.dt * v
            trajectory_logits.append(self._proto_logits(z))

        refined_logits = trajectory_logits[-1]
        refined_conf = self.refined_confidence(z)

        # Learned blend weight (differentiable, from initial features)
        blend_weight = self.blend_head(features)  # (B, 1)

        return refined_logits, refined_conf, initial_conf, trajectory_logits, flow_loss, blend_weight
637
+
638
+
639
+ # === Model ====================================================================
640
+
641
+ class GeometricShapeClassifier(nn.Module):
642
+ def __init__(self, n_classes=NUM_CLASSES, embed_dim=64, n_tracers=5):
643
+ super().__init__()
644
+ self.n_tracers = n_tracers
645
+ self.embed_dim = embed_dim
646
+
647
+ self.voxel_embed = nn.Sequential(
648
+ nn.Linear(4, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))
649
+
650
+ coords = torch.stack(torch.meshgrid(
651
+ torch.arange(GS, dtype=torch.float32),
652
+ torch.arange(GS, dtype=torch.float32),
653
+ torch.arange(GS, dtype=torch.float32),
654
+ indexing="ij"), dim=-1) / (GS - 1) # (5,5,5,3) normalized
655
+ self.register_buffer("pos_grid", coords)
656
+
657
+ self.tracer_tokens = nn.Parameter(torch.randn(n_tracers, embed_dim) * 0.02)
658
+ self.tracer_attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)
659
+ self.tracer_gate = nn.Sequential(nn.Linear(embed_dim * 2, embed_dim), nn.Sigmoid())
660
+ self.tracer_interact = nn.Sequential(
661
+ nn.Linear(embed_dim * 2, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))
662
+ # SwiGLU for edge detection: sharp "edge present?" decision
663
+ self.edge_head = nn.Sequential(
664
+ SwiGLU(embed_dim * 2, 32), nn.Linear(32, 1))
665
+
666
+ # Precompute all C(n_tracers, 2) pair indices for vectorized interaction
667
+ _pi, _pj = [], []
668
+ for i in range(n_tracers):
669
+ for j in range(i + 1, n_tracers):
670
+ _pi.append(i); _pj.append(j)
671
+ self.register_buffer("_pair_i", torch.tensor(_pi, dtype=torch.long))
672
+ self.register_buffer("_pair_j", torch.tensor(_pj, dtype=torch.long))
673
+ self.n_pairs = len(_pi)
674
+
675
+ pool_dim = embed_dim * n_tracers
676
+
677
+ self.dim0 = CapacityHead(pool_dim, embed_dim, init_capacity=0.5)
678
+ self.dim1 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=1.0)
679
+ self.dim2 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=1.5)
680
+ self.dim3 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=2.0)
681
+
682
+ rigid_feat_dim = embed_dim * 4
683
+ self.curvature = CurvatureHead(rigid_feat_dim, fill_dim=4, embed_dim=embed_dim)
684
+
685
+ class_in = pool_dim + 4 + rigid_feat_dim + embed_dim + 1
686
+ self.class_in = class_in # Store for arbiter
687
+ self.classifier = nn.Sequential(
688
+ nn.Linear(class_in, 256), nn.GELU(), nn.Dropout(0.1),
689
+ nn.Linear(256, 128), nn.GELU(), nn.Linear(128, n_classes))
690
+
691
+ # SwiGLU for peak dimension: sharp "which dimension?" decision
692
+ self.peak_head = nn.Sequential(
693
+ SwiGLU(class_in, 32), nn.Linear(32, 4))
694
+ # Volume is continuous interpolation — keep GELU
695
+ self.volume_head = nn.Sequential(
696
+ nn.Linear(class_in, 64), nn.GELU(), nn.Linear(64, 1))
697
+ # SwiGLU for CM determinant sign: sharp geometric determinant
698
+ self.cm_head = nn.Sequential(
699
+ SwiGLU(class_in, 64), nn.Linear(64, 1), nn.Tanh())
700
+
701
+ # Rectified flow arbiter for ambiguous classification
702
+ self.arbiter = RectifiedFlowArbiter(
703
+ feat_dim=class_in, n_classes=n_classes,
704
+ n_steps=4, latent_dim=128, embed_dim=embed_dim)
705
+
706
    def forward(self, grid, labels=None):
        """Classify a voxel occupancy grid and return all head outputs.

        Args:
            grid: occupancy tensor reshapeable to (B, GS**3, 1) — presumably
                a (B, GS, GS, GS) binary/float voxel grid; TODO confirm with caller.
            labels: optional ground-truth class labels, forwarded to the
                rectified-flow arbiter (presumably used only for its
                flow-matching loss during training — verify in the arbiter).

        Returns:
            dict with final/initial/refined class logits, flow loss,
            confidence statistics, auxiliary head predictions (peak dim,
            volume, CM sign, edge lengths, fills/overflows/capacities,
            curvature), and the pre-classifier feature vector ("features").
        """
        B = grid.shape[0]
        # Flatten voxels to a sequence and pair each occupancy value with its
        # fixed 3D position before embedding.
        occ = grid.reshape(B, GS**3, 1)
        pos = self.pos_grid.reshape(1, GS**3, 3).expand(B, -1, -1)
        voxel_emb = self.voxel_embed(torch.cat([occ, pos], dim=-1))

        # Learned tracer tokens attend over the voxel embeddings (tracers are
        # queries; voxels are keys/values).
        tracers = self.tracer_tokens.unsqueeze(0).expand(B, -1, -1)
        tracers, _ = self.tracer_attn(tracers, voxel_emb, voxel_emb)

        # Vectorized pair interaction: all C(5,2)=10 pairs at once
        left = tracers[:, self._pair_i]    # (B, 10, embed_dim)
        right = tracers[:, self._pair_j]   # (B, 10, embed_dim)
        pairs = torch.cat([left, right], dim=-1)  # (B, 10, embed_dim*2)

        # Flatten to batch, run networks, reshape back
        flat_pairs = pairs.reshape(B * self.n_pairs, -1)
        gate = self.tracer_gate(flat_pairs).reshape(B, self.n_pairs, -1)
        interaction = self.tracer_interact(flat_pairs).reshape(B, self.n_pairs, -1)
        edge_lengths = self.edge_head(flat_pairs).reshape(B, self.n_pairs)

        # Scatter-add gated interactions back to both tracers in each pair
        gated = gate * interaction  # (B, 10, embed_dim)
        # Clone first: scatter_add_ mutates in place and `tracers` is still
        # needed as the base value.
        tracer_out = tracers.clone()
        pi_exp = self._pair_i.view(1, self.n_pairs, 1).expand(B, -1, self.embed_dim)
        pj_exp = self._pair_j.view(1, self.n_pairs, 1).expand(B, -1, self.embed_dim)
        tracer_out.scatter_add_(1, pi_exp, gated)
        tracer_out.scatter_add_(1, pj_exp, gated)
        pooled = tracer_out.reshape(B, -1)

        # Cascaded capacity heads: each dimension receives the pooled features
        # plus the previous dimension's overflow. Each head returns
        # (fill, overflow, retained, capacity, _) — see CapacityHead.
        fill0, ovf0, ret0, cap0, _ = self.dim0(pooled)
        fill1, ovf1, ret1, cap1, _ = self.dim1(torch.cat([pooled, ovf0], -1))
        fill2, ovf2, ret2, cap2, _ = self.dim2(torch.cat([pooled, ovf1], -1))
        fill3, ovf3, ret3, cap3, _ = self.dim3(torch.cat([pooled, ovf2], -1))

        fill_ratios = torch.cat([fill0, fill1, fill2, fill3], dim=-1)
        rigid_retained = torch.cat([ret0, ret1, ret2, ret3], dim=-1)
        # Scalar summary of each dimension's overflow magnitude, (B, 4).
        ovf_norms = torch.stack([
            ovf0.norm(dim=-1), ovf1.norm(dim=-1),
            ovf2.norm(dim=-1), ovf3.norm(dim=-1)], dim=-1)

        is_curved, curv_logits, curv_feat, alternation = self.curvature(grid, rigid_retained, fill_ratios)
        # Full feature vector fed to classifier / aux heads / arbiter;
        # width must equal self.class_in.
        full = torch.cat([pooled, fill_ratios, rigid_retained, curv_feat, is_curved], dim=-1)

        # === First pass classification ===
        initial_logits = self.classifier(full)

        # === Rectified flow arbitration ===
        refined_logits, refined_conf, initial_conf, trajectory_logits, flow_loss, blend_weight = \
            self.arbiter(full, initial_logits, labels=labels)

        # === Blend: learned confidence head decides trust ===
        # blend_weight is (B, 1) sigmoid output from learned head
        final_logits = blend_weight * initial_logits + (1.0 - blend_weight) * refined_logits

        return {
            # Classification
            "class_logits": final_logits,
            "initial_logits": initial_logits,
            "refined_logits": refined_logits,
            "trajectory_logits": trajectory_logits,
            # Flow matching
            "flow_loss": flow_loss,
            # Confidence
            "confidence": initial_conf["confidence"],
            "max_prob": initial_conf["max_prob"],
            "entropy": initial_conf["entropy"],
            "refined_confidence": refined_conf,
            "blend_weight": blend_weight.squeeze(-1),
            # Auxiliary heads
            "peak_logits": self.peak_head(full),
            "volume_pred": self.volume_head(full).squeeze(-1),
            "cm_pred": self.cm_head(full).squeeze(-1),
            "edge_lengths": edge_lengths,
            "fill_ratios": fill_ratios,
            "overflows": ovf_norms,
            "capacities": torch.stack([cap0, cap1, cap2, cap3]),
            "is_curved_pred": is_curved,
            "curv_type_logits": curv_logits,
            "alternation": alternation,
            # Pre-classifier features (for cross-contrast)
            "features": full,
        }
788
+
789
+
790
# Quick sanity: instantiate once, report the parameter count, then free it.
_probe = GeometricShapeClassifier()
_param_count = sum(p.numel() for p in _probe.parameters())
print(f'GeometricShapeClassifier: {_param_count:,} params')
del _probe