# Cell 2

# === Capacity Head ============================================================

class CapacityHead(nn.Module):
    """Evidence head with a learnable, strictly positive capacity.

    A scalar "evidence" value is predicted from the input and compared with
    the learned capacity.  Evidence up to the capacity yields the ``fill``
    ratio; evidence beyond it yields a saturation value that drives the
    ``overflow`` features (fed to the next head by the caller).

    Args:
        in_dim: size of the last dimension of the input.
        feat_dim: width of the feature / gate networks.
        init_capacity: initial value of ``capacity`` (must be > 0).
    """

    def __init__(self, in_dim, feat_dim, init_capacity=1.0):
        super().__init__()
        # Inverse softplus, so that softplus(_raw_capacity) == init_capacity.
        self._raw_capacity = nn.Parameter(
            torch.tensor(math.log(math.exp(init_capacity) - 1)))
        # GELU for cascade: smooth gradients needed for overflow propagation
        self.evidence_net = nn.Sequential(
            nn.Linear(in_dim, feat_dim), nn.GELU(), nn.Linear(feat_dim, 1))
        self.feature_net = nn.Sequential(
            nn.Linear(in_dim, feat_dim), nn.GELU(), nn.Linear(feat_dim, feat_dim))
        self.retain_gate = nn.Sequential(
            nn.Linear(feat_dim + 1, feat_dim), nn.Sigmoid())
        self.overflow_gate = nn.Sequential(
            nn.Linear(feat_dim + 1, feat_dim), nn.Sigmoid())

    @property
    def capacity(self):
        """Current capacity: softplus of the raw parameter (always > 0)."""
        return F.softplus(self._raw_capacity)

    def forward(self, x):
        """Return ``(fill, overflow, retained, cap, raw_ev)`` for input ``x``.

        ``fill`` and ``raw_ev`` have a trailing dimension of 1, ``overflow``
        and ``retained`` have a trailing dimension of ``feat_dim``, and
        ``cap`` is the scalar capacity tensor.
        """
        cap = self.capacity
        raw_ev = F.relu(self.evidence_net(x))
        # Fraction of capacity used, capped at 1.
        fill = torch.clamp(raw_ev / (cap + 1e-8), max=1.0)
        # Relative excess over capacity, floored at 0.
        sat = torch.clamp((raw_ev - cap) / (cap + 1e-8), min=0.0)
        feat = self.feature_net(x)
        retained = self.retain_gate(torch.cat([feat, fill], -1)) * feat * fill
        overflow = (self.overflow_gate(torch.cat([feat, sat], -1))
                    * feat * torch.clamp(sat, max=1.0))
        return fill, overflow, retained, cap, raw_ev


# === Differentiation Gate =====================================================

class DifferentiationGate(nn.Module):
    """
    Curvature direction analysis via occupancy field differentiation.

    Computes gradient and Laplacian of the 3D occupancy field to determine:
    - Curvature direction: convex (normals point outward) vs concave (inward)
    - Curvature alternation: where sign flips (saddle points, torus inner/outer)
    - Perturbation robustness: smoothed gradient features survive noise

    The key insight: a hemisphere and bowl occupy nearly identical voxels,
    but their occupancy gradients point in opposite directions relative to
    the center of mass. The Laplacian's sign distinguishes them.

    Outputs gate signals that modulate curvature features:
    - direction_gate: learned weighting based on gradient analysis
    - alternation_score: how much curvature sign varies spatially
    - directional_features: rich features encoding curvature orientation
    """

    def __init__(self, embed_dim=64):
        super().__init__()
        # Fixed 3D differentiation kernels — fused into single conv
        # 4 output channels: [grad_x, grad_y, grad_z, laplacian]
        # NOTE: the gradient kernels are unnormalized two-point central
        # differences (called "Sobel" in the original notes).
        diff_kernels = torch.zeros(4, 1, 3, 3, 3)
        # Sobel X
        diff_kernels[0, 0, 0, 1, 1] = -1
        diff_kernels[0, 0, 2, 1, 1] = 1
        # Sobel Y
        diff_kernels[1, 0, 1, 0, 1] = -1
        diff_kernels[1, 0, 1, 2, 1] = 1
        # Sobel Z
        diff_kernels[2, 0, 1, 1, 0] = -1
        diff_kernels[2, 0, 1, 1, 2] = 1
        # Laplacian
        diff_kernels[3, 0, 1, 1, 1] = -6
        diff_kernels[3, 0, 0, 1, 1] = 1
        diff_kernels[3, 0, 2, 1, 1] = 1
        diff_kernels[3, 0, 1, 0, 1] = 1
        diff_kernels[3, 0, 1, 2, 1] = 1
        diff_kernels[3, 0, 1, 1, 0] = 1
        diff_kernels[3, 0, 1, 1, 2] = 1
        self.register_buffer("diff_kernels", diff_kernels)

        # Precompute coordinate grid
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            indexing="ij"), dim=-1)  # (5,5,5,3)
        self.register_buffer("coords", coords)

        # Process gradient-derived features
        # Per-voxel: gradient direction, Laplacian sign, centroid-relative direction
        # Summarized as histograms and statistics

        # Gradient direction relative to centroid: 3 histogram bins per axis
        # + Laplacian sign distribution: 3 values (frac_pos, frac_neg, frac_zero)
        # + Alternation score: 1 value
        # + Per-axis gradient asymmetry: 3 values
        # + Radial gradient profile: 5 bins
        raw_feat_dim = 3 + 3 + 1 + 3 + 5  # = 15

        # Plus the 3D conv on the Laplacian field preserving spatial structure
        self.lap_conv = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv3d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))  # -> (B, 16, 2, 2, 2) = 128
        lap_conv_dim = 16 * 8  # 128

        # Gradient magnitude 3D conv (encodes where boundaries are + direction)
        self.grad_conv = nn.Sequential(
            nn.Conv3d(3, 16, 3, padding=1), nn.GELU(),  # 3-channel: dx, dy, dz
            nn.Conv3d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))  # -> (B, 16, 2, 2, 2) = 128
        grad_conv_dim = 16 * 8  # 128

        total_feat_dim = raw_feat_dim + lap_conv_dim + grad_conv_dim  # 15 + 128 + 128 = 271

        # Direction gate: SwiGLU for sharp convex/concave gating
        self.direction_net = nn.Sequential(
            SwiGLU(total_feat_dim, embed_dim),
            nn.Linear(embed_dim, embed_dim), nn.Sigmoid())

        # Directional features: SwiGLU for crisp direction encoding
        self.direction_feat_net = nn.Sequential(
            SwiGLU(total_feat_dim, embed_dim),
            nn.Linear(embed_dim, embed_dim))

    def forward(self, grid):
        """
        grid: (B, 5, 5, 5) binary occupancy
        Returns:
            direction_gate: (B, embed_dim) sigmoid gate for curvature features
            direction_feat: (B, embed_dim) additive directional features
            alternation_score: (B, 1) how much curvature alternates
        """
        B = grid.shape[0]
        vox = grid.unsqueeze(1)  # (B, 1, 5, 5, 5)

        # === Smooth occupancy before differentiation ===
        # Binary voxels produce spike gradients. Light blur creates
        # a continuous field whose derivatives are geometrically meaningful.
        vox_smooth = F.avg_pool3d(
            F.pad(vox, (1, 1, 1, 1, 1, 1), mode='replicate'),
            kernel_size=3, stride=1, padding=0)  # (B, 1, 5, 5, 5)

        # === Compute gradients + Laplacian in single fused conv ===
        diff = F.conv3d(vox_smooth, self.diff_kernels, padding=1)  # (B, 4, 5, 5, 5)
        grad_field = diff[:, :3]  # (B, 3, 5, 5, 5) — gx, gy, gz
        gx, gy, gz = diff[:, 0:1], diff[:, 1:2], diff[:, 2:3]
        lap = diff[:, 3:4]  # (B, 1, 5, 5, 5)

        # === Centroid ===
        flat_grid = grid.reshape(B, -1)  # (B, 125)
        flat_coords = self.coords.reshape(-1, 3)  # (125, 3)
        # clamp(min=1) avoids dividing by zero for empty grids
        total_occ = flat_grid.sum(dim=-1, keepdim=True).clamp(min=1)  # (B, 1)
        centroids = (flat_grid.unsqueeze(-1) * flat_coords.unsqueeze(0)).sum(dim=1) / total_occ  # (B, 3)

        # === Gradient direction relative to centroid ===
        grad_flat = grad_field.reshape(B, 3, -1).permute(0, 2, 1)  # (B, 125, 3)
        diff_from_center = flat_coords.unsqueeze(0) - centroids.unsqueeze(1)  # (B, 125, 3)
        diff_norm = diff_from_center / (diff_from_center.norm(dim=-1, keepdim=True) + 1e-8)
        dot_products = (grad_flat * diff_norm).sum(dim=-1)  # (B, 125)
        grad_mag = grad_flat.norm(dim=-1)  # (B, 125)
        active = (flat_grid > 0.5) & (grad_mag > 0.01)

        # Histogram of dot product signs (convex/concave/neutral fractions)
        n_active = active.float().sum(-1).clamp(min=1)
        frac_outward = ((dot_products > 0.1) & active).float().sum(-1) / n_active
        frac_inward = ((dot_products < -0.1) & active).float().sum(-1) / n_active
        frac_neutral = 1.0 - frac_outward - frac_inward
        direction_hist = torch.stack([frac_outward, frac_inward, frac_neutral], dim=-1)  # (B, 3)

        # === Laplacian sign distribution (active voxels only) ===
        lap_flat = lap.reshape(B, -1)  # (B, 125)
        lap_active = flat_grid > 0.5
        n_lap_active = lap_active.float().sum(-1).clamp(min=1)
        frac_pos_lap = ((lap_flat > 0.1) & lap_active).float().sum(-1) / n_lap_active
        frac_neg_lap = ((lap_flat < -0.1) & lap_active).float().sum(-1) / n_lap_active
        frac_zero_lap = 1.0 - frac_pos_lap - frac_neg_lap
        lap_hist = torch.stack([frac_pos_lap, frac_neg_lap, frac_zero_lap], dim=-1)  # (B, 3)

        # === Alternation score (ACTIVE VOXELS ONLY) ===
        # Only count sign flips between neighbor pairs where BOTH voxels are
        # near occupied regions. Otherwise empty space dilutes the signal.
        lap_3d = lap.squeeze(1)  # (B, 5, 5, 5)
        # Boundary mask: dilate occupancy by 1 to include immediate neighbors
        boundary_mask = F.max_pool3d(vox, kernel_size=3, stride=1, padding=1).squeeze(1)  # (B,5,5,5)

        # X-axis: both neighbors must be in boundary region
        bm_x = boundary_mask[:, 1:, :, :] * boundary_mask[:, :-1, :, :]  # (B,4,5,5)
        flip_x = (torch.sign(lap_3d[:, 1:, :, :]) * torch.sign(lap_3d[:, :-1, :, :]) < 0).float()
        active_flips_x = (flip_x * bm_x).sum(dim=(1, 2, 3))
        active_pairs_x = bm_x.sum(dim=(1, 2, 3)).clamp(min=1)

        bm_y = boundary_mask[:, :, 1:, :] * boundary_mask[:, :, :-1, :]
        flip_y = (torch.sign(lap_3d[:, :, 1:, :]) * torch.sign(lap_3d[:, :, :-1, :]) < 0).float()
        active_flips_y = (flip_y * bm_y).sum(dim=(1, 2, 3))
        active_pairs_y = bm_y.sum(dim=(1, 2, 3)).clamp(min=1)

        bm_z = boundary_mask[:, :, :, 1:] * boundary_mask[:, :, :, :-1]
        flip_z = (torch.sign(lap_3d[:, :, :, 1:]) * torch.sign(lap_3d[:, :, :, :-1]) < 0).float()
        active_flips_z = (flip_z * bm_z).sum(dim=(1, 2, 3))
        active_pairs_z = bm_z.sum(dim=(1, 2, 3)).clamp(min=1)

        alternation = ((active_flips_x / active_pairs_x +
                        active_flips_y / active_pairs_y +
                        active_flips_z / active_pairs_z) / 3.0).unsqueeze(-1)  # (B, 1)

        # === Per-axis gradient asymmetry ===
        # Asymmetry: mean gradient along each axis (nonzero = asymmetric curvature)
        gx_mean = (gx.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        gy_mean = (gy.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        gz_mean = (gz.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        grad_asym = torch.stack([gx_mean, gy_mean, gz_mean], dim=-1)  # (B, 3)

        # === Radial gradient profile ===
        # How does gradient magnitude vary with distance from centroid?
        dists = diff_from_center.norm(dim=-1)  # (B, 125)
        # Arithmetic binning (Inductor-safe, no bucketize)
        # nan_to_num prevents NaN→long producing garbage indices under BF16
        bin_idx = torch.nan_to_num(dists * (5.0 / 3.5), nan=0.0).long().clamp(0, 4)
        active_mask = (flat_grid > 0.5)  # (B, 125)
        # Per-bin mean of grad_mag over active voxels, via one-hot reduction
        weighted_mag = grad_mag * active_mask.float()  # zero out inactive
        one_hot = F.one_hot(bin_idx, 5).float()  # (B, 125, 5)
        active_oh = one_hot * active_mask.float().unsqueeze(-1)  # mask inactive
        counts = active_oh.sum(dim=1).clamp(min=1)  # (B, 5)
        radial_grad = (weighted_mag.unsqueeze(-1) * active_oh).sum(dim=1) / counts  # (B, 5)

        # === Conv on Laplacian field (spatial curvature map) ===
        lap_feat = self.lap_conv(lap).reshape(B, -1)  # (B, 128)

        # === Conv on gradient field (directional boundaries) ===
        grad_feat = self.grad_conv(grad_field).reshape(B, -1)  # (B, 128)

        # === Combine all ===
        raw_feat = torch.cat([
            direction_hist,  # 3
            lap_hist,        # 3
            alternation,     # 1
            grad_asym,       # 3
            radial_grad,     # 5
        ], dim=-1)  # (B, 15)

        all_feat = torch.cat([raw_feat, lap_feat, grad_feat], dim=-1)  # (B, 271)

        direction_gate = self.direction_net(all_feat)  # (B, embed_dim) sigmoid
        direction_feat = self.direction_feat_net(all_feat)  # (B, embed_dim)

        return direction_gate, direction_feat, alternation


# === Deformation Augmentation =================================================

def deform_grid(grid, p_dropout=0.1, p_add=0.1, p_shift=0.15):
    """Fully vectorized voxel augmentation — zero CPU-GPU sync points.

    Each sample independently receives, with the given probabilities:
    random voxel dropout, random additions on the dilated boundary, and a
    one-voxel translation along a random axis (vacated voxels are zeroed).

    Args:
        grid: (B, D, H, W) occupancy tensor; not modified in place.
        p_dropout: per-sample probability of applying voxel dropout.
        p_add: per-sample probability of adding boundary voxels.
        p_shift: per-sample probability of applying a translation.

    Returns:
        Augmented tensor with the same shape as ``grid``.
    """
    B = grid.shape[0]
    device = grid.device
    r = torch.rand(B, 3, device=device)
    out = grid.clone()

    # --- Voxel dropout (batched, no .any() sync) ---
    drop_sel = (r[:, 0] < p_dropout).view(B, 1, 1, 1)
    keep = torch.rand_like(out) > 0.15
    out = torch.where(drop_sel, out * keep.float(), out)

    # --- Boundary addition (batched, no .any() sync) ---
    add_sel = (r[:, 1] < p_add).view(B, 1, 1, 1).float()
    dilated = F.max_pool3d(out.unsqueeze(1), kernel_size=3, stride=1, padding=1).squeeze(1)
    boundary = ((dilated > 0.5) & (out < 0.5)).float()
    add_noise = (torch.rand_like(out) < 0.3).float()
    out = (out + boundary * add_noise * add_sel).clamp(max=1.0)

    # --- Small translation (fully vectorized, no loops, no boolean indexing) ---
    shift_sel = (r[:, 2] < p_shift)  # (B,)
    axes = torch.randint(3, (B,), device=device)
    dirs = torch.randint(0, 2, (B,), device=device) * 2 - 1

    # Precompute all 6 shifted versions of full batch (cheap for 5x5x5)
    # Encode: idx = axis * 2 + (dir==1) → [0..5], 6 = no shift
    versions = []
    for ax in range(3):
        for d in [-1, 1]:
            s = torch.roll(out, shifts=d, dims=ax + 1)  # +1 for batch dim
            # Zero wrapped edge
            if d == 1:
                if ax == 0:
                    s[:, 0, :, :] = 0
                elif ax == 1:
                    s[:, :, 0, :] = 0
                else:
                    s[:, :, :, 0] = 0
            else:
                if ax == 0:
                    s[:, -1, :, :] = 0
                elif ax == 1:
                    s[:, :, -1, :] = 0
                else:
                    s[:, :, :, -1] = 0
            versions.append(s)
    versions.append(out)  # index 6 = no shift (identity)
    stacked = torch.stack(versions, dim=0)  # (7, B, 5, 5, 5)

    # Per-sample assignment: which version to pick
    assign = torch.where(shift_sel, axes * 2 + (dirs == 1).long(),
                         torch.full_like(axes, 6))
    # Gather: stacked[assign[b], b] for each b
    out = stacked[assign, torch.arange(B, device=device)]

    return out


# === Curvature Head (axis-aware) ==============================================

class CurvatureHead(nn.Module):
    """
    Axis-aware curvature detection with differentiation gating.

    1. Per-axis max projections -> 2D conv (keeps 2×2 spatial)
    2. Radial occupancy profile from centroid
    3. Axial symmetry + translation invariance scores
    4. 3D conv with spatial preservation (2×2×2)
    5. DifferentiationGate: gradient/Laplacian analysis for direction detection

    The DifferentiationGate modulates curvature features so that convex and
    concave shapes get distinct representations even when their occupancy
    patterns are nearly identical.
    """

    def __init__(self, rigid_feat_dim, fill_dim, embed_dim):
        super().__init__()

        self.plane_conv = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv2d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool2d(2))
        plane_feat_dim = 3 * 16 * 4  # 192

        n_radial = 5
        self.radial_net = nn.Sequential(
            nn.Linear(n_radial, 32), nn.GELU(), nn.Linear(32, 16))
        radial_feat_dim = 16

        symmetry_feat_dim = 6

        self.voxel_conv = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv3d(16, 32, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))
        voxel3d_feat_dim = 32 * 8  # 256

        # DifferentiationGate for curvature direction
        self.diff_gate = DifferentiationGate(embed_dim)

        # Pre-gate combine (without direction features)
        pre_gate_dim = (plane_feat_dim + radial_feat_dim + symmetry_feat_dim +
                        voxel3d_feat_dim + rigid_feat_dim + fill_dim)

        # Pre-gate feature projection: SwiGLU for sharp geometric feature gating
        self.pre_gate_proj = nn.Sequential(
            SwiGLU(pre_gate_dim, embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim))

        # Post-gate: gated features + direction features + alternation + raw combine
        # = embed_dim (gated) + embed_dim (direction) + 1 (alternation) + pre_gate_dim
        post_gate_dim = embed_dim + embed_dim + 1 + pre_gate_dim

        # SwiGLU for all curvature decision heads: sharp geometric classification
        self.curved_head = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim),
            nn.Linear(embed_dim, 1), nn.Sigmoid())
        self.curv_type_head = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim),
            nn.Linear(embed_dim, NUM_CURVATURES))
        self.curv_features = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim))

    def forward(self, grid, rigid_retained, fill_ratios):
        """Return ``(is_curved, curv_logits, curv_feat, alternation)``.

        Args:
            grid: (B, GS, GS, GS) occupancy grid.
            rigid_retained: (B, rigid_feat_dim) features concatenated as-is.
            fill_ratios: (B, fill_dim) features concatenated as-is.
        """
        B = grid.shape[0]

        proj_x = grid.max(dim=1).values
        proj_y = grid.max(dim=2).values
        proj_z = grid.max(dim=3).values

        # Batch all 3 projections through plane_conv in single pass
        projs_batched = torch.cat([
            proj_x.unsqueeze(1), proj_y.unsqueeze(1), proj_z.unsqueeze(1)
        ], dim=0)  # (3B, 1, 5, 5)
        plane_all = self.plane_conv(projs_batched).reshape(3, B, -1)  # (3, B, 64)
        plane_feat = plane_all.permute(1, 0, 2).reshape(B, -1)  # (B, 192)

        radial = self._radial_profile(grid)
        radial_feat = self.radial_net(radial)

        sym_feat = self._symmetry_features(proj_x, proj_y, proj_z)

        vox3d_feat = self.voxel_conv(grid.unsqueeze(1)).reshape(B, -1)

        # Raw curvature features (shape-aware but direction-blind)
        raw_combined = torch.cat([
            plane_feat, radial_feat, sym_feat, vox3d_feat,
            rigid_retained, fill_ratios], dim=-1)

        # Project to gatable dimension
        pre_gate = self.pre_gate_proj(raw_combined)  # (B, embed_dim)

        # Direction analysis
        dir_gate, dir_feat, alternation = self.diff_gate(grid)

        # Apply gate: direction-modulated curvature features
        gated = pre_gate * dir_gate  # (B, embed_dim) — convex/concave differentiation

        # Full post-gate features
        combined = torch.cat([gated, dir_feat, alternation, raw_combined], dim=-1)

        is_curved = self.curved_head(combined)
        curv_logits = self.curv_type_head(combined)
        curv_feat = self.curv_features(combined)
        return is_curved, curv_logits, curv_feat, alternation

    def _radial_profile(self, grid):
        """Fraction of occupied mass in each of 5 distance bins from the centroid."""
        B = grid.shape[0]
        device = grid.device
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, device=device, dtype=torch.float32),
            torch.arange(GS, device=device, dtype=torch.float32),
            torch.arange(GS, device=device, dtype=torch.float32),
            indexing="ij"), dim=-1)
        flat_grid = grid.reshape(B, -1)
        flat_coords = coords.reshape(-1, 3)
        total_occ = flat_grid.sum(dim=-1, keepdim=True).clamp(min=1)
        centroids = (flat_grid.unsqueeze(-1) * flat_coords.unsqueeze(0)).sum(dim=1) / total_occ
        diffs = flat_coords.unsqueeze(0) - centroids.unsqueeze(1)
        dists = diffs.norm(dim=-1)  # (B, 125)
        max_dist = 3.5
        n_bins = 5
        # Arithmetic binning (Inductor-safe, no bucketize)
        bin_idx = torch.nan_to_num(
            dists * (float(n_bins) / max_dist), nan=0.0).long().clamp(0, n_bins - 1)
        one_hot = F.one_hot(bin_idx, n_bins).float()  # (B, 125, 5)
        weighted = flat_grid.unsqueeze(-1) * one_hot  # (B, 125, 5)
        profile = weighted.sum(dim=1) / total_occ  # (B, 5)
        return profile

    def _symmetry_features(self, proj_x, proj_y, proj_z):
        """Per-projection mirror-symmetry and shift-similarity scores, (B, 6)."""
        projs = torch.stack([proj_x, proj_y, proj_z], dim=1)  # (B, 3, H, W)
        fh = torch.flip(projs, dims=[2])
        fv = torch.flip(projs, dims=[3])
        sym = 1.0 - ((projs - fh).abs().mean(dim=(2, 3)) +
                     (projs - fv).abs().mean(dim=(2, 3))) / 2  # (B, 3)
        shift_diff = (projs[:, :, 1:, :] - projs[:, :, :-1, :]).abs().mean(dim=(2, 3))  # (B, 3)
        trans_inv = 1.0 - shift_diff
        # Interleave: [sym0, trans0, sym1, trans1, sym2, trans2]
        return torch.stack([sym[:, 0], trans_inv[:, 0],
                            sym[:, 1], trans_inv[:, 1],
                            sym[:, 2], trans_inv[:, 2]], dim=-1)  # (B, 6)


# === Confidence Computation ====================================================

def compute_confidence(logits):
    """
    Compute confidence metrics from logits.

    Args:
        logits: tensor whose last dimension indexes classes, e.g. (B, C).

    Returns dict with:
        max_prob: max(softmax(logits)) — top-class probability
        margin: top1_prob - top2_prob — disambiguation strength
        entropy: -sum(p * log(p)) / log(C) — normalized to [0, 1]
            (lower = more confident)
        confidence: margin — primary confidence signal for gating

    With a single class there is no runner-up: margin equals max_prob and
    the entropy is reported as 0 instead of dividing by log(1) == 0.
    """
    n_classes = logits.shape[-1]
    probs = F.softmax(logits, dim=-1)
    max_prob, _ = probs.max(dim=-1)
    if n_classes > 1:
        top2 = probs.topk(2, dim=-1).values
        margin = top2[..., 0] - top2[..., 1]
    else:
        # topk(2) would raise on a single-class axis.
        margin = max_prob
    # Entropy normalized to [0, 1] range
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(probs * log_probs).sum(dim=-1)
    max_entropy = math.log(n_classes) if n_classes > 1 else 1.0
    norm_entropy = entropy / max_entropy
    return {
        "max_prob": max_prob,
        "margin": margin,
        "entropy": norm_entropy,
        "confidence": margin,  # primary signal
    }


# === Rectified Flow Arbiter ===================================================

class RectifiedFlowArbiter(nn.Module):
    """
    Rectified flow matching for ambiguous classification refinement.

    Real flow matching requires a target endpoint to define the velocity field.
    We learn class prototypes in latent space as targets: for a sample of class c,
    the target is prototype[c]. The velocity field learns to transport the
    encoded feature z0 toward the correct prototype z1 in straight lines:

        v_target = z1 - z0           (rectified: straight path from source to target)
        loss = ||v_predicted - v_target||^2   (flow matching objective)

    At inference, the arbiter integrates the learned velocity field from z0,
    landing near the correct class prototype. Classification reads off the
    nearest prototype.

    Confidence gating: velocity magnitude is scaled by (1 - margin), so confident
    first-pass predictions receive minimal correction.
    """

    def __init__(self, feat_dim, n_classes, n_steps=4, latent_dim=128, embed_dim=64):
        super().__init__()
        self.n_steps = n_steps
        self.n_classes = n_classes
        self.dt = 1.0 / n_steps
        self.latent_dim = latent_dim

        # Project features to latent space
        self.encode = nn.Sequential(
            nn.Linear(feat_dim, latent_dim * 2), nn.GELU(),
            nn.Linear(latent_dim * 2, latent_dim))

        # Learnable class prototypes — target endpoints for flow
        self.prototypes = nn.Parameter(torch.randn(n_classes, latent_dim) * 0.05)

        # Timestep embedding
        self.time_embed = nn.Sequential(
            nn.Linear(16, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))

        # Confidence embedding
        self.conf_embed = nn.Sequential(
            nn.Linear(3, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))

        # Velocity network: predicts flow direction in latent space
        vel_in = latent_dim + embed_dim + embed_dim
        self.velocity = nn.Sequential(
            SwiGLU(vel_in, latent_dim),
            nn.Linear(latent_dim, latent_dim),
            SwiGLU(latent_dim, latent_dim),
            nn.Linear(latent_dim, latent_dim))

        # Velocity gate: low confidence → full correction, high → minimal
        self.vel_gate = nn.Sequential(
            nn.Linear(embed_dim, latent_dim), nn.Sigmoid())

        # Classification from latent: distance to prototypes + learned head
        self.classifier_head = nn.Sequential(
            SwiGLU(latent_dim + n_classes, 96), nn.Linear(96, n_classes))

        # Learned confidence head for blending (differentiable, not topk)
        self.blend_head = nn.Sequential(
            nn.Linear(feat_dim, 64), nn.GELU(), nn.Linear(64, 1), nn.Sigmoid())

        # Post-refinement confidence
        self.refined_confidence = nn.Sequential(
            SwiGLU(latent_dim, 32), nn.Linear(32, 1), nn.Sigmoid())

    def _time_encoding(self, t, device):
        """Sinusoidal encoding of timesteps ``t`` (B,) -> (B, 16)."""
        freqs = torch.exp(torch.linspace(0, -4, 8, device=device))
        args = t.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat([args.sin(), args.cos()], dim=-1)

    def _proto_logits(self, z):
        """Classify by negative distance to prototypes."""
        # (B, latent) vs (C, latent) → (B, C) distances
        dists = torch.cdist(z.unsqueeze(0), self.prototypes.unsqueeze(0)).squeeze(0)
        # Combine distance signal with learned head
        combined = torch.cat([z, -dists], dim=-1)  # (B, latent + n_classes)
        return self.classifier_head(combined)

    def forward(self, features, initial_logits, labels=None):
        """
        features: (B, feat_dim)
        initial_logits: (B, n_classes)
        labels: (B,) — only during training, for flow matching target

        Returns a 6-tuple:
            refined_logits: (B, n_classes) logits after the last integration step
            refined_conf: (B, 1) sigmoid confidence of the final latent
            initial_conf: dict from compute_confidence(initial_logits)
            trajectory_logits: list of n_steps (B, n_classes) logits
            flow_loss: scalar flow-matching loss (0.0 when labels is None)
            blend_weight: (B, 1) sigmoid weight from the blend head
        """
        B = features.shape[0]
        device = features.device

        # Confidence from initial logits
        initial_conf = compute_confidence(initial_logits)
        conf_input = torch.stack([
            initial_conf["max_prob"],
            initial_conf["margin"],
            initial_conf["entropy"]], dim=-1)
        conf_emb = self.conf_embed(conf_input)

        # Confidence-gated velocity magnitude
        gate = self.vel_gate(conf_emb)
        inv_conf = (1.0 - initial_conf["margin"]).unsqueeze(-1)
        adaptive_gate = gate * inv_conf

        # Encode to latent
        z0 = self.encode(features)

        # === Flow matching target ===
        flow_loss = torch.tensor(0.0, device=device)
        if labels is not None:
            # Target: class prototype for each sample
            z1 = self.prototypes[labels]  # (B, latent_dim)
            # Target velocity: straight path z0 → z1
            v_target = z1 - z0  # (B, latent_dim)

            # Sample random timestep for flow matching training
            t_rand = torch.rand(B, device=device)
            t_emb = self.time_embed(self._time_encoding(t_rand, device))

            # Interpolated position along straight path
            z_t = z0 + t_rand.unsqueeze(-1) * v_target  # (B, latent_dim)

            # Predicted velocity at this point
            vel_input = torch.cat([z_t, t_emb, conf_emb], dim=-1)
            v_pred = self.velocity(vel_input) * adaptive_gate
            v_pred = v_pred.clamp(-20, 20)

            # Flow matching loss: predicted velocity should match target
            flow_loss = F.mse_loss(v_pred, v_target.clamp(-20, 20))

        # === Inference: integrate velocity field ===
        z = z0
        trajectory_logits = []

        for step in range(self.n_steps):
            t_val = torch.full((B,), step * self.dt, device=device)
            t_emb = self.time_embed(self._time_encoding(t_val, device))

            vel_input = torch.cat([z, t_emb, conf_emb], dim=-1)
            v = self.velocity(vel_input) * adaptive_gate
            # Prevent BF16 divergence: clamp velocity magnitude
            v = v.clamp(-20, 20)
            z = z + self.dt * v

            trajectory_logits.append(self._proto_logits(z))

        refined_logits = trajectory_logits[-1]
        refined_conf = self.refined_confidence(z)

        # Learned blend weight (differentiable, from initial features)
        blend_weight = self.blend_head(features)  # (B, 1)

        return refined_logits, refined_conf, initial_conf, trajectory_logits, flow_loss, blend_weight


# === Model ====================================================================

class GeometricShapeClassifier(nn.Module):
    """Voxel-grid shape classifier.

    Pipeline as implemented in ``forward``: per-voxel embedding -> tracer
    tokens cross-attending to the voxels -> gated pairwise tracer
    interactions -> a cascade of four CapacityHeads (each fed the previous
    head's overflow) -> CurvatureHead -> first-pass classifier ->
    RectifiedFlowArbiter refinement, blended by a learned weight.
    """

    def __init__(self, n_classes=NUM_CLASSES, embed_dim=64, n_tracers=5):
        super().__init__()
        self.n_tracers = n_tracers
        self.embed_dim = embed_dim

        self.voxel_embed = nn.Sequential(
            nn.Linear(4, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))

        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            indexing="ij"), dim=-1) / (GS - 1)  # (5,5,5,3) normalized
        self.register_buffer("pos_grid", coords)

        self.tracer_tokens = nn.Parameter(torch.randn(n_tracers, embed_dim) * 0.02)
        self.tracer_attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)
        self.tracer_gate = nn.Sequential(nn.Linear(embed_dim * 2, embed_dim), nn.Sigmoid())
        self.tracer_interact = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))
        # SwiGLU for edge detection: sharp "edge present?" decision
        self.edge_head = nn.Sequential(
            SwiGLU(embed_dim * 2, 32), nn.Linear(32, 1))

        # Precompute all C(n_tracers, 2) pair indices for vectorized interaction
        _pi, _pj = [], []
        for i in range(n_tracers):
            for j in range(i + 1, n_tracers):
                _pi.append(i)
                _pj.append(j)
        self.register_buffer("_pair_i", torch.tensor(_pi, dtype=torch.long))
        self.register_buffer("_pair_j", torch.tensor(_pj, dtype=torch.long))
        self.n_pairs = len(_pi)

        pool_dim = embed_dim * n_tracers

        self.dim0 = CapacityHead(pool_dim, embed_dim, init_capacity=0.5)
        self.dim1 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=1.0)
        self.dim2 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=1.5)
        self.dim3 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=2.0)

        rigid_feat_dim = embed_dim * 4

        self.curvature = CurvatureHead(rigid_feat_dim, fill_dim=4, embed_dim=embed_dim)

        class_in = pool_dim + 4 + rigid_feat_dim + embed_dim + 1
        self.class_in = class_in  # Store for arbiter

        self.classifier = nn.Sequential(
            nn.Linear(class_in, 256), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(256, 128), nn.GELU(),
            nn.Linear(128, n_classes))

        # SwiGLU for peak dimension: sharp "which dimension?" decision
        self.peak_head = nn.Sequential(
            SwiGLU(class_in, 32), nn.Linear(32, 4))
        # Volume is continuous interpolation — keep GELU
        self.volume_head = nn.Sequential(
            nn.Linear(class_in, 64), nn.GELU(), nn.Linear(64, 1))
        # SwiGLU for CM determinant sign: sharp geometric determinant
        self.cm_head = nn.Sequential(
            SwiGLU(class_in, 64), nn.Linear(64, 1), nn.Tanh())

        # Rectified flow arbiter for ambiguous classification
        self.arbiter = RectifiedFlowArbiter(
            feat_dim=class_in, n_classes=n_classes,
            n_steps=4, latent_dim=128, embed_dim=embed_dim)

    def forward(self, grid, labels=None):
        """Classify a batch of occupancy grids.

        Args:
            grid: (B, GS, GS, GS) occupancy grid.
            labels: optional (B,) class indices; when given, the arbiter
                also computes its flow-matching loss.

        Returns:
            dict of logits, confidences, auxiliary predictions and the
            pre-classifier feature vector (see the keys below).
        """
        B = grid.shape[0]

        occ = grid.reshape(B, GS**3, 1)
        pos = self.pos_grid.reshape(1, GS**3, 3).expand(B, -1, -1)
        voxel_emb = self.voxel_embed(torch.cat([occ, pos], dim=-1))

        tracers = self.tracer_tokens.unsqueeze(0).expand(B, -1, -1)
        tracers, _ = self.tracer_attn(tracers, voxel_emb, voxel_emb)

        # Vectorized pair interaction: all C(5,2)=10 pairs at once
        left = tracers[:, self._pair_i]  # (B, 10, embed_dim)
        right = tracers[:, self._pair_j]  # (B, 10, embed_dim)
        pairs = torch.cat([left, right], dim=-1)  # (B, 10, embed_dim*2)

        # Flatten to batch, run networks, reshape back
        flat_pairs = pairs.reshape(B * self.n_pairs, -1)
        gate = self.tracer_gate(flat_pairs).reshape(B, self.n_pairs, -1)
        interaction = self.tracer_interact(flat_pairs).reshape(B, self.n_pairs, -1)
        edge_lengths = self.edge_head(flat_pairs).reshape(B, self.n_pairs)

        # Scatter-add gated interactions back to both tracers in each pair
        gated = gate * interaction  # (B, 10, embed_dim)
        tracer_out = tracers.clone()
        pi_exp = self._pair_i.view(1, self.n_pairs, 1).expand(B, -1, self.embed_dim)
        pj_exp = self._pair_j.view(1, self.n_pairs, 1).expand(B, -1, self.embed_dim)
        tracer_out.scatter_add_(1, pi_exp, gated)
        tracer_out.scatter_add_(1, pj_exp, gated)

        pooled = tracer_out.reshape(B, -1)

        # Capacity cascade: each head also sees the previous head's overflow
        fill0, ovf0, ret0, cap0, _ = self.dim0(pooled)
        fill1, ovf1, ret1, cap1, _ = self.dim1(torch.cat([pooled, ovf0], -1))
        fill2, ovf2, ret2, cap2, _ = self.dim2(torch.cat([pooled, ovf1], -1))
        fill3, ovf3, ret3, cap3, _ = self.dim3(torch.cat([pooled, ovf2], -1))

        fill_ratios = torch.cat([fill0, fill1, fill2, fill3], dim=-1)
        rigid_retained = torch.cat([ret0, ret1, ret2, ret3], dim=-1)
        ovf_norms = torch.stack([
            ovf0.norm(dim=-1), ovf1.norm(dim=-1),
            ovf2.norm(dim=-1), ovf3.norm(dim=-1)], dim=-1)

        is_curved, curv_logits, curv_feat, alternation = self.curvature(
            grid, rigid_retained, fill_ratios)

        full = torch.cat([pooled, fill_ratios, rigid_retained, curv_feat, is_curved], dim=-1)

        # === First pass classification ===
        initial_logits = self.classifier(full)

        # === Rectified flow arbitration ===
        refined_logits, refined_conf, initial_conf, trajectory_logits, flow_loss, blend_weight = \
            self.arbiter(full, initial_logits, labels=labels)

        # === Blend: learned confidence head decides trust ===
        # blend_weight is (B, 1) sigmoid output from learned head
        final_logits = blend_weight * initial_logits + (1.0 - blend_weight) * refined_logits

        return {
            # Classification
            "class_logits": final_logits,
            "initial_logits": initial_logits,
            "refined_logits": refined_logits,
            "trajectory_logits": trajectory_logits,
            # Flow matching
            "flow_loss": flow_loss,
            # Confidence
            "confidence": initial_conf["confidence"],
            "max_prob": initial_conf["max_prob"],
            "entropy": initial_conf["entropy"],
            "refined_confidence": refined_conf,
            "blend_weight": blend_weight.squeeze(-1),
            # Auxiliary heads
            "peak_logits": self.peak_head(full),
            "volume_pred": self.volume_head(full).squeeze(-1),
            "cm_pred": self.cm_head(full).squeeze(-1),
            "edge_lengths": edge_lengths,
            "fill_ratios": fill_ratios,
            "overflows": ovf_norms,
            "capacities": torch.stack([cap0, cap1, cap2, cap3]),
            "is_curved_pred": is_curved,
            "curv_type_logits": curv_logits,
            "alternation": alternation,
            # Pre-classifier features (for cross-contrast)
            "features": full,
        }


# Quick sanity
_m = GeometricShapeClassifier()
print(f'GeometricShapeClassifier: {sum(p.numel() for p in _m.parameters()):,} params')
del _m