|
|
|
|
|
|
|
|
|
|
|
class CapacityHead(nn.Module):
    """Evidence-accumulation head with a learned soft capacity.

    A scalar "evidence" signal is measured against a learned positive
    capacity. The resulting fill fraction (capped at 1) and normalized
    overflow gate a feature vector into retained and overflow parts.

    Args:
        in_dim: width of the input feature vector.
        feat_dim: width of the internal / output feature vectors.
        init_capacity: initial value of the (softplus-parameterized) capacity.
    """

    def __init__(self, in_dim, feat_dim, init_capacity=1.0):
        super().__init__()
        # Capacity is stored pre-softplus so it stays strictly positive;
        # initialize the raw value so softplus(raw) == init_capacity.
        raw_init = math.log(math.exp(init_capacity) - 1)
        self._raw_capacity = nn.Parameter(torch.tensor(raw_init))

        def _two_layer(out_dim):
            # Small MLP: in_dim -> feat_dim -> out_dim with GELU in between.
            return nn.Sequential(
                nn.Linear(in_dim, feat_dim), nn.GELU(), nn.Linear(feat_dim, out_dim))

        # NOTE: submodule construction order is kept (evidence, feature,
        # retain, overflow) so parameter initialization draws from the RNG
        # in the same sequence as before.
        self.evidence_net = _two_layer(1)
        self.feature_net = _two_layer(feat_dim)
        self.retain_gate = nn.Sequential(
            nn.Linear(feat_dim + 1, feat_dim), nn.Sigmoid())
        self.overflow_gate = nn.Sequential(
            nn.Linear(feat_dim + 1, feat_dim), nn.Sigmoid())

    @property
    def capacity(self):
        """Effective (always-positive) capacity."""
        return F.softplus(self._raw_capacity)

    def forward(self, x):
        """Return (fill, overflow, retained, capacity, raw_evidence).

        fill:      fraction of capacity used, clamped to [0, 1].
        overflow:  gated features scaled by the (clamped) saturation.
        retained:  gated features scaled by the fill fraction.
        """
        cap = self.capacity
        evidence = F.relu(self.evidence_net(x))
        # How full the head is, saturating at 1; epsilon guards division.
        fill = torch.clamp(evidence / (cap + 1e-8), max=1.0)
        # Normalized excess beyond capacity; zero while under capacity.
        sat = torch.clamp((evidence - cap) / (cap + 1e-8), min=0.0)

        feat = self.feature_net(x)
        retained = self.retain_gate(torch.cat([feat, fill], -1)) * feat * fill
        overflow = (self.overflow_gate(torch.cat([feat, sat], -1))
                    * feat * torch.clamp(sat, max=1.0))
        return fill, overflow, retained, cap, evidence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DifferentiationGate(nn.Module):
    """
    Curvature direction analysis via occupancy field differentiation.

    Computes gradient and Laplacian of the 3D occupancy field to determine:
    - Curvature direction: convex (normals point outward) vs concave (inward)
    - Curvature alternation: where sign flips (saddle points, torus inner/outer)
    - Perturbation robustness: smoothed gradient features survive noise

    The key insight: a hemisphere and bowl occupy nearly identical voxels,
    but their occupancy gradients point in opposite directions relative
    to the center of mass. The Laplacian's sign distinguishes them.

    Outputs gate signals that modulate curvature features:
    - direction_gate: learned weighting based on gradient analysis
    - alternation_score: how much curvature sign varies spatially
    - directional_features: rich features encoding curvature orientation
    """

    def __init__(self, embed_dim=64):
        super().__init__()

        # Fixed (non-learned) finite-difference kernels, registered as a
        # buffer: channels 0-2 are central differences along x/y/z, channel 3
        # is the 6-neighbor discrete Laplacian stencil.
        diff_kernels = torch.zeros(4, 1, 3, 3, 3)

        diff_kernels[0, 0, 0, 1, 1] = -1; diff_kernels[0, 0, 2, 1, 1] = 1

        diff_kernels[1, 0, 1, 0, 1] = -1; diff_kernels[1, 0, 1, 2, 1] = 1

        diff_kernels[2, 0, 1, 1, 0] = -1; diff_kernels[2, 0, 1, 1, 2] = 1

        # Laplacian: -6 at center, +1 at each of the six face neighbors.
        diff_kernels[3, 0, 1, 1, 1] = -6
        diff_kernels[3, 0, 0, 1, 1] = 1; diff_kernels[3, 0, 2, 1, 1] = 1
        diff_kernels[3, 0, 1, 0, 1] = 1; diff_kernels[3, 0, 1, 2, 1] = 1
        diff_kernels[3, 0, 1, 1, 0] = 1; diff_kernels[3, 0, 1, 1, 2] = 1
        self.register_buffer("diff_kernels", diff_kernels)

        # Integer voxel coordinates of the GS^3 grid, shape (GS, GS, GS, 3),
        # used to compute centroids and radial directions in forward().
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            indexing="ij"), dim=-1)
        self.register_buffer("coords", coords)

        # Hand-crafted feature widths: direction hist (3) + Laplacian hist (3)
        # + alternation (1) + gradient asymmetry (3) + radial profile (5).
        raw_feat_dim = 3 + 3 + 1 + 3 + 5

        # Learned conv features over the Laplacian field (1 input channel).
        self.lap_conv = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv3d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))
        lap_conv_dim = 16 * 8  # 16 channels x 2x2x2 pooled cells

        # Learned conv features over the 3-channel gradient field.
        self.grad_conv = nn.Sequential(
            nn.Conv3d(3, 16, 3, padding=1), nn.GELU(),
            nn.Conv3d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))
        grad_conv_dim = 16 * 8

        total_feat_dim = raw_feat_dim + lap_conv_dim + grad_conv_dim

        # Sigmoid gate over curvature features, conditioned on all features.
        self.direction_net = nn.Sequential(
            SwiGLU(total_feat_dim, embed_dim),
            nn.Linear(embed_dim, embed_dim), nn.Sigmoid())

        # Additive directional feature vector (no squashing).
        self.direction_feat_net = nn.Sequential(
            SwiGLU(total_feat_dim, embed_dim),
            nn.Linear(embed_dim, embed_dim))

    def forward(self, grid):
        """
        grid: (B, GS, GS, GS) binary occupancy (GS is presumably 5 here —
              the radial binning below hard-codes a 5-bin / 3.5 max-distance
              scale; confirm if GS changes)

        Returns:
            direction_gate: (B, embed_dim) sigmoid gate for curvature features
            direction_feat: (B, embed_dim) additive directional features
            alternation_score: (B, 1) how much curvature alternates
        """
        B = grid.shape[0]
        device = grid.device
        vox = grid.unsqueeze(1)  # add channel dim -> (B, 1, GS, GS, GS)

        # 3x3x3 box smoothing with replicate padding so border voxels are not
        # attenuated; smoothing makes the finite differences noise-robust.
        vox_smooth = F.avg_pool3d(
            F.pad(vox, (1,1,1,1,1,1), mode='replicate'),
            kernel_size=3, stride=1, padding=0)

        # One conv applies all four stencils: channels 0-2 = gradient, 3 = Laplacian.
        diff = F.conv3d(vox_smooth, self.diff_kernels, padding=1)
        grad_field = diff[:, :3]
        gx, gy, gz = diff[:, 0:1], diff[:, 1:2], diff[:, 2:3]
        lap = diff[:, 3:4]

        # Occupancy centroid per sample (clamped so empty grids don't divide by 0).
        flat_grid = grid.reshape(B, -1)
        flat_coords = self.coords.reshape(-1, 3)
        total_occ = flat_grid.sum(dim=-1, keepdim=True).clamp(min=1)
        centroids = (flat_grid.unsqueeze(-1) * flat_coords.unsqueeze(0)).sum(dim=1) / total_occ

        # Project each voxel's gradient onto the unit vector from centroid to
        # voxel: positive dot product = gradient points outward (convex side).
        grad_flat = grad_field.reshape(B, 3, -1).permute(0, 2, 1)
        diff_from_center = flat_coords.unsqueeze(0) - centroids.unsqueeze(1)
        diff_norm = diff_from_center / (diff_from_center.norm(dim=-1, keepdim=True) + 1e-8)
        dot_products = (grad_flat * diff_norm).sum(dim=-1)
        grad_mag = grad_flat.norm(dim=-1)
        # Only occupied voxels with a non-negligible gradient count.
        active = (flat_grid > 0.5) & (grad_mag > 0.01)

        # Histogram of gradient direction: outward / inward / neutral fractions.
        n_active = active.float().sum(-1).clamp(min=1)
        frac_outward = ((dot_products > 0.1) & active).float().sum(-1) / n_active
        frac_inward = ((dot_products < -0.1) & active).float().sum(-1) / n_active
        frac_neutral = 1.0 - frac_outward - frac_inward
        direction_hist = torch.stack([frac_outward, frac_inward, frac_neutral], dim=-1)

        # Histogram of Laplacian sign over occupied voxels (+ / - / ~0).
        lap_flat = lap.reshape(B, -1)
        lap_active = flat_grid > 0.5
        n_lap_active = lap_active.float().sum(-1).clamp(min=1)
        frac_pos_lap = ((lap_flat > 0.1) & lap_active).float().sum(-1) / n_lap_active
        frac_neg_lap = ((lap_flat < -0.1) & lap_active).float().sum(-1) / n_lap_active
        frac_zero_lap = 1.0 - frac_pos_lap - frac_neg_lap
        lap_hist = torch.stack([frac_pos_lap, frac_neg_lap, frac_zero_lap], dim=-1)

        # Curvature alternation: fraction of adjacent voxel pairs (within the
        # dilated occupancy, i.e. near the surface) where the Laplacian flips
        # sign — high for saddles / torus-like shapes.
        lap_3d = lap.squeeze(1)

        boundary_mask = F.max_pool3d(vox, kernel_size=3, stride=1, padding=1).squeeze(1)

        # Per-axis: mask of pairs where both neighbors are near the shape,
        # and where the Laplacian sign flips between them.
        bm_x = boundary_mask[:, 1:, :, :] * boundary_mask[:, :-1, :, :]
        flip_x = (torch.sign(lap_3d[:, 1:, :, :]) * torch.sign(lap_3d[:, :-1, :, :]) < 0).float()
        active_flips_x = (flip_x * bm_x).sum(dim=(1, 2, 3))
        active_pairs_x = bm_x.sum(dim=(1, 2, 3)).clamp(min=1)

        bm_y = boundary_mask[:, :, 1:, :] * boundary_mask[:, :, :-1, :]
        flip_y = (torch.sign(lap_3d[:, :, 1:, :]) * torch.sign(lap_3d[:, :, :-1, :]) < 0).float()
        active_flips_y = (flip_y * bm_y).sum(dim=(1, 2, 3))
        active_pairs_y = bm_y.sum(dim=(1, 2, 3)).clamp(min=1)

        bm_z = boundary_mask[:, :, :, 1:] * boundary_mask[:, :, :, :-1]
        flip_z = (torch.sign(lap_3d[:, :, :, 1:]) * torch.sign(lap_3d[:, :, :, :-1]) < 0).float()
        active_flips_z = (flip_z * bm_z).sum(dim=(1, 2, 3))
        active_pairs_z = bm_z.sum(dim=(1, 2, 3)).clamp(min=1)

        # Average flip rate over the three axes -> (B, 1).
        alternation = ((active_flips_x / active_pairs_x +
                        active_flips_y / active_pairs_y +
                        active_flips_z / active_pairs_z) / 3.0).unsqueeze(-1)

        # Mean gradient per axis over occupied voxels: nonzero values indicate
        # an asymmetric (e.g. one-sided) occupancy distribution.
        gx_mean = (gx.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        gy_mean = (gy.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        gz_mean = (gz.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        grad_asym = torch.stack([gx_mean, gy_mean, gz_mean], dim=-1)

        # Radial profile of gradient magnitude: 5 distance bins from centroid,
        # scale 5/3.5 (3.5 ≈ max distance in a 5^3 grid — see forward docstring).
        dists = diff_from_center.norm(dim=-1)

        # nan_to_num is defensive; dists is a norm and should never be NaN.
        bin_idx = torch.nan_to_num(dists * (5.0 / 3.5), nan=0.0).long().clamp(0, 4)
        active_mask = (flat_grid > 0.5)
        radial_grad = torch.zeros(B, 5, device=device)  # overwritten below; kept as-is

        # Scatter-free binned mean: one-hot bins masked to occupied voxels.
        weighted_mag = grad_mag * active_mask.float()
        one_hot = F.one_hot(bin_idx, 5).float()
        active_oh = one_hot * active_mask.float().unsqueeze(-1)
        counts = active_oh.sum(dim=1).clamp(min=1)
        radial_grad = (weighted_mag.unsqueeze(-1) * active_oh).sum(dim=1) / counts

        # Learned conv summaries of the raw differential fields.
        lap_feat = self.lap_conv(lap).reshape(B, -1)

        grad_feat = self.grad_conv(grad_field).reshape(B, -1)

        # Concatenate hand-crafted stats (order must match raw_feat_dim above).
        raw_feat = torch.cat([
            direction_hist,
            lap_hist,
            alternation,
            grad_asym,
            radial_grad,
        ], dim=-1)

        all_feat = torch.cat([raw_feat, lap_feat, grad_feat], dim=-1)

        direction_gate = self.direction_net(all_feat)
        direction_feat = self.direction_feat_net(all_feat)

        return direction_gate, direction_feat, alternation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def deform_grid(grid, p_dropout=0.1, p_add=0.1, p_shift=0.15):
    """Fully vectorized voxel augmentation — zero CPU-GPU sync points."""
    n = grid.shape[0]
    dev = grid.device
    # One uniform draw per sample per augmentation type (dropout/add/shift).
    # NOTE: the sequence of RNG calls below is deliberately identical to the
    # previous implementation so seeded runs reproduce the same output.
    sample_roll = torch.rand(n, 3, device=dev)
    out = grid.clone()

    # --- voxel dropout: selected samples lose ~15% of their voxels ---
    chosen_drop = (sample_roll[:, 0] < p_dropout).view(n, 1, 1, 1)
    keep_mask = torch.rand_like(out) > 0.15
    out = torch.where(chosen_drop, out * keep_mask.float(), out)

    # --- additive noise: sprinkle voxels onto the dilated boundary shell ---
    chosen_add = (sample_roll[:, 1] < p_add).view(n, 1, 1, 1).float()
    dilated = F.max_pool3d(out.unsqueeze(1), kernel_size=3, stride=1, padding=1).squeeze(1)
    shell = ((dilated > 0.5) & (out < 0.5)).float()
    sprinkle = (torch.rand_like(out) < 0.3).float()
    out = (out + shell * sprinkle * chosen_add).clamp(max=1.0)

    # --- one-voxel shift along a random axis and direction ---
    chosen_shift = sample_roll[:, 2] < p_shift
    axis_pick = torch.randint(3, (n,), device=dev)
    dir_pick = torch.randint(0, 2, (n,), device=dev) * 2 - 1  # -1 or +1

    # Precompute all six shifted variants (plus identity) and gather the one
    # assigned to each sample — avoids any per-sample Python branching.
    variants = []
    for axis in range(3):
        for step in (-1, 1):
            rolled = torch.roll(out, shifts=step, dims=axis + 1)
            # Zero the plane that wrapped around instead of letting it wrap.
            edge = 0 if step == 1 else -1
            rolled.select(axis + 1, edge).zero_()
            variants.append(rolled)
    variants.append(out)  # index 6: identity (sample not selected for shift)
    stacked = torch.stack(variants, dim=0)

    # Map (axis, direction) -> variant index; unselected samples get 6.
    choice = torch.where(chosen_shift,
                         axis_pick * 2 + (dir_pick == 1).long(),
                         torch.full_like(axis_pick, 6))
    return stacked[choice, torch.arange(n, device=dev)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CurvatureHead(nn.Module):
    """
    Axis-aware curvature detection with differentiation gating.

    1. Per-axis max projections -> 2D conv (keeps 2×2 spatial)
    2. Radial occupancy profile from centroid
    3. Axial symmetry + translation invariance scores
    4. 3D conv with spatial preservation (2×2×2)
    5. DifferentiationGate: gradient/Laplacian analysis for direction detection

    The DifferentiationGate modulates curvature features so that
    convex and concave shapes get distinct representations even when
    their occupancy patterns are nearly identical.
    """

    def __init__(self, rigid_feat_dim, fill_dim, embed_dim):
        super().__init__()

        # Shared 2D conv applied to each of the three axis projections.
        self.plane_conv = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv2d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool2d(2))
        plane_feat_dim = 3 * 16 * 4  # 3 projections x 16 channels x 2x2 cells

        # MLP over the 5-bin radial occupancy profile.
        n_radial = 5
        self.radial_net = nn.Sequential(
            nn.Linear(n_radial, 32), nn.GELU(), nn.Linear(32, 16))
        radial_feat_dim = 16

        # (symmetry, translation-invariance) per axis -> 6 scalars.
        symmetry_feat_dim = 6

        # Full-resolution 3D conv features over the raw occupancy.
        self.voxel_conv = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv3d(16, 32, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))
        voxel3d_feat_dim = 32 * 8  # 32 channels x 2x2x2 pooled cells

        # Gradient/Laplacian analyzer producing the direction gate.
        self.diff_gate = DifferentiationGate(embed_dim)

        pre_gate_dim = (plane_feat_dim + radial_feat_dim + symmetry_feat_dim +
                        voxel3d_feat_dim + rigid_feat_dim + fill_dim)

        # Projects all raw features down to embed_dim before gating.
        self.pre_gate_proj = nn.Sequential(
            SwiGLU(pre_gate_dim, embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim))

        # gated (embed) + dir_feat (embed) + alternation (1) + raw skip-connection.
        post_gate_dim = embed_dim + embed_dim + 1 + pre_gate_dim

        # Binary "is this shape curved?" probability head.
        self.curved_head = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim),
            nn.Linear(embed_dim, 1), nn.Sigmoid())
        # Curvature-type classification logits.
        self.curv_type_head = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim),
            nn.Linear(embed_dim, NUM_CURVATURES))
        # Dense curvature feature vector for the downstream classifier.
        self.curv_features = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim))

    def forward(self, grid, rigid_retained, fill_ratios):
        """grid: (B, GS, GS, GS) occupancy; rigid_retained/fill_ratios come
        from the CapacityHead stack. Returns (is_curved, curv_logits,
        curv_feat, alternation)."""
        B = grid.shape[0]

        # Max projection along each axis -> three (B, GS, GS) silhouettes.
        proj_x = grid.max(dim=1).values
        proj_y = grid.max(dim=2).values
        proj_z = grid.max(dim=3).values

        # Run all three projections through plane_conv in one batched call
        # (stacked along the batch dim), then regroup per sample.
        projs_batched = torch.cat([
            proj_x.unsqueeze(1), proj_y.unsqueeze(1), proj_z.unsqueeze(1)
        ], dim=0)
        plane_all = self.plane_conv(projs_batched).reshape(3, B, -1)
        plane_feat = plane_all.permute(1, 0, 2).reshape(B, -1)

        radial = self._radial_profile(grid)
        radial_feat = self.radial_net(radial)

        sym_feat = self._symmetry_features(proj_x, proj_y, proj_z)

        vox3d_feat = self.voxel_conv(grid.unsqueeze(1)).reshape(B, -1)

        # All raw features, also reused later as a skip connection.
        raw_combined = torch.cat([
            plane_feat, radial_feat, sym_feat, vox3d_feat,
            rigid_retained, fill_ratios], dim=-1)

        pre_gate = self.pre_gate_proj(raw_combined)

        dir_gate, dir_feat, alternation = self.diff_gate(grid)

        # Direction gate modulates the projected features multiplicatively.
        gated = pre_gate * dir_gate

        combined = torch.cat([gated, dir_feat, alternation, raw_combined], dim=-1)

        is_curved = self.curved_head(combined)
        curv_logits = self.curv_type_head(combined)
        curv_feat = self.curv_features(combined)
        return is_curved, curv_logits, curv_feat, alternation

    def _radial_profile(self, grid):
        """Fraction of occupied voxels in each of 5 distance bins from the
        occupancy centroid (bins scaled by max_dist=3.5)."""
        B = grid.shape[0]
        device = grid.device
        # NOTE(review): coords is rebuilt on every call; DifferentiationGate
        # caches the same grid as a buffer — could be shared.
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, device=device, dtype=torch.float32),
            torch.arange(GS, device=device, dtype=torch.float32),
            torch.arange(GS, device=device, dtype=torch.float32),
            indexing="ij"), dim=-1)
        flat_grid = grid.reshape(B, -1)
        flat_coords = coords.reshape(-1, 3)
        # clamp(min=1) guards the empty-grid division.
        total_occ = flat_grid.sum(dim=-1, keepdim=True).clamp(min=1)
        centroids = (flat_grid.unsqueeze(-1) * flat_coords.unsqueeze(0)).sum(dim=1) / total_occ
        diffs = flat_coords.unsqueeze(0) - centroids.unsqueeze(1)
        dists = diffs.norm(dim=-1)
        max_dist = 3.5
        n_bins = 5

        # nan_to_num is defensive; dists is a norm and should never be NaN.
        bin_idx = torch.nan_to_num(dists * (float(n_bins) / max_dist), nan=0.0).long().clamp(0, n_bins - 1)
        one_hot = F.one_hot(bin_idx, n_bins).float()
        weighted = flat_grid.unsqueeze(-1) * one_hot
        profile = weighted.sum(dim=1) / total_occ
        return profile

    def _symmetry_features(self, proj_x, proj_y, proj_z):
        """Per-projection mirror-symmetry and translation-invariance scores,
        interleaved as [sym_x, trans_x, sym_y, trans_y, sym_z, trans_z]."""
        projs = torch.stack([proj_x, proj_y, proj_z], dim=1)
        fh = torch.flip(projs, dims=[2])
        fv = torch.flip(projs, dims=[3])
        # 1 = perfectly mirror-symmetric (both flips), 0 = maximally asymmetric.
        sym = 1.0 - ((projs - fh).abs().mean(dim=(2, 3)) +
                     (projs - fv).abs().mean(dim=(2, 3))) / 2
        # Difference between adjacent rows (dim 2 only): low for shapes that
        # look the same under a one-pixel translation.
        shift_diff = (projs[:, :, 1:, :] - projs[:, :, :-1, :]).abs().mean(dim=(2, 3))
        trans_inv = 1.0 - shift_diff

        return torch.stack([sym[:, 0], trans_inv[:, 0],
                            sym[:, 1], trans_inv[:, 1],
                            sym[:, 2], trans_inv[:, 2]], dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_confidence(logits):
    """
    Compute calibrated confidence metrics from classification logits.

    logits: (B, n_classes)

    Returns dict with:
        max_prob:   max(softmax(logits)) — calibrated top-class probability
        margin:     top1_prob - top2_prob — disambiguation strength
        entropy:    Shannon entropy normalized to [0, 1] by log(n_classes)
                    (lower = more confident)
        confidence: alias of margin — the primary signal used for gating
    """
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)

    top_prob = probs.max(dim=-1).values

    # Gap between the two most likely classes.
    best_two = probs.topk(2, dim=-1).values
    separation = best_two[:, 0] - best_two[:, 1]

    # Entropy scaled by its maximum (the uniform distribution) so the value
    # is comparable across different class counts.
    raw_entropy = -(probs * log_probs).sum(dim=-1)
    scaled_entropy = raw_entropy / math.log(logits.shape[-1])

    return {
        "max_prob": top_prob,
        "margin": separation,
        "entropy": scaled_entropy,
        "confidence": separation,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RectifiedFlowArbiter(nn.Module):
    """
    Rectified flow matching for ambiguous classification refinement.

    Real flow matching requires a target endpoint to define the velocity field.
    We learn class prototypes in latent space as targets: for a sample of class c,
    the target is prototype[c]. The velocity field learns to transport the
    encoded feature z0 toward the correct prototype z1 in straight lines:

        v_target = z1 - z0 (rectified: straight path from source to target)
        loss = ||v_predicted - v_target||^2 (flow matching objective)

    At inference, the arbiter integrates the learned velocity field from z0,
    landing near the correct class prototype. Classification reads off the
    nearest prototype.

    Confidence gating: velocity magnitude is scaled by (1 - margin), so
    confident first-pass predictions receive minimal correction.
    """

    def __init__(self, feat_dim, n_classes, n_steps=4, latent_dim=128, embed_dim=64):
        super().__init__()
        self.n_steps = n_steps          # Euler integration steps
        self.n_classes = n_classes
        self.dt = 1.0 / n_steps         # integration step size over t in [0, 1]
        self.latent_dim = latent_dim

        # Feature encoder into the flow's latent space.
        self.encode = nn.Sequential(
            nn.Linear(feat_dim, latent_dim * 2), nn.GELU(),
            nn.Linear(latent_dim * 2, latent_dim))

        # Learned per-class target points in latent space.
        self.prototypes = nn.Parameter(torch.randn(n_classes, latent_dim) * 0.05)

        # Embeds the 16-dim sinusoidal time encoding (see _time_encoding).
        self.time_embed = nn.Sequential(
            nn.Linear(16, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, embed_dim))

        # Embeds the 3 first-pass confidence scalars (max_prob, margin, entropy).
        self.conf_embed = nn.Sequential(
            nn.Linear(3, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, embed_dim))

        # Velocity field v(z_t, t, confidence).
        vel_in = latent_dim + embed_dim + embed_dim
        self.velocity = nn.Sequential(
            SwiGLU(vel_in, latent_dim),
            nn.Linear(latent_dim, latent_dim),
            SwiGLU(latent_dim, latent_dim),
            nn.Linear(latent_dim, latent_dim))

        # Learned per-dimension gate on the velocity, driven by confidence.
        self.vel_gate = nn.Sequential(
            nn.Linear(embed_dim, latent_dim), nn.Sigmoid())

        # Classifies from [z, -distances-to-prototypes].
        self.classifier_head = nn.Sequential(
            SwiGLU(latent_dim + n_classes, 96),
            nn.Linear(96, n_classes))

        # How much to keep the initial logits vs the refined ones.
        self.blend_head = nn.Sequential(
            nn.Linear(feat_dim, 64), nn.GELU(),
            nn.Linear(64, 1), nn.Sigmoid())

        # Confidence estimate read from the final latent position.
        self.refined_confidence = nn.Sequential(
            SwiGLU(latent_dim, 32),
            nn.Linear(32, 1), nn.Sigmoid())

    def _time_encoding(self, t, device):
        # Sinusoidal features at 8 log-spaced frequencies -> 16 dims (sin+cos).
        freqs = torch.exp(torch.linspace(0, -4, 8, device=device))
        args = t.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat([args.sin(), args.cos()], dim=-1)

    def _proto_logits(self, z):
        """Classify by negative distance to prototypes."""
        # cdist over a singleton batch dim: (B, n_classes) distances.
        dists = torch.cdist(z.unsqueeze(0), self.prototypes.unsqueeze(0)).squeeze(0)
        # Negate so closer prototypes contribute larger values.
        combined = torch.cat([z, -dists], dim=-1)
        return self.classifier_head(combined)

    def forward(self, features, initial_logits, labels=None):
        """
        features: (B, feat_dim)
        initial_logits: (B, n_classes)
        labels: (B,) — only during training, for flow matching target

        Returns:
            refined_logits, refined_conf, initial_conf, trajectory_logits, flow_loss
            (plus blend_weight — see the return statement)
        """
        B = features.shape[0]
        device = features.device

        # First-pass confidence -> embedding used throughout the flow.
        initial_conf = compute_confidence(initial_logits)
        conf_input = torch.stack([
            initial_conf["max_prob"],
            initial_conf["margin"],
            initial_conf["entropy"]], dim=-1)
        conf_emb = self.conf_embed(conf_input)

        # Velocity gate scaled by (1 - margin): confident samples move less.
        gate = self.vel_gate(conf_emb)
        inv_conf = (1.0 - initial_conf["margin"]).unsqueeze(-1)
        adaptive_gate = gate * inv_conf

        z0 = self.encode(features)

        # Flow-matching loss (training only): regress the gated velocity at a
        # random time t on the straight z0 -> prototype[label] path.
        flow_loss = torch.tensor(0.0, device=device)
        if labels is not None:
            z1 = self.prototypes[labels]
            v_target = z1 - z0

            t_rand = torch.rand(B, device=device)
            t_emb = self.time_embed(self._time_encoding(t_rand, device))

            # Point on the straight path at time t (rectified flow).
            z_t = z0 + t_rand.unsqueeze(-1) * v_target

            vel_input = torch.cat([z_t, t_emb, conf_emb], dim=-1)
            v_pred = self.velocity(vel_input) * adaptive_gate
            # Clamp both sides for stability against outlier latents.
            v_pred = v_pred.clamp(-20, 20)

            flow_loss = F.mse_loss(v_pred, v_target.clamp(-20, 20))

        # Euler integration of the learned velocity field from z0 over [0, 1).
        z = z0
        trajectory_logits = []
        for step in range(self.n_steps):
            t_val = torch.full((B,), step * self.dt, device=device)
            t_emb = self.time_embed(self._time_encoding(t_val, device))

            vel_input = torch.cat([z, t_emb, conf_emb], dim=-1)
            v = self.velocity(vel_input) * adaptive_gate

            v = v.clamp(-20, 20)

            z = z + self.dt * v
            # Logits at every step, so training can supervise the whole path.
            trajectory_logits.append(self._proto_logits(z))

        refined_logits = trajectory_logits[-1]
        refined_conf = self.refined_confidence(z)

        # Blend weight for mixing initial vs refined logits (used by caller).
        blend_weight = self.blend_head(features)

        return refined_logits, refined_conf, initial_conf, trajectory_logits, flow_loss, blend_weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GeometricShapeClassifier(nn.Module):
    """Top-level shape classifier: tracer attention over voxel embeddings,
    stacked CapacityHeads, curvature analysis, and a rectified-flow arbiter
    that refines low-confidence predictions."""

    def __init__(self, n_classes=NUM_CLASSES, embed_dim=64, n_tracers=5):
        super().__init__()
        self.n_tracers = n_tracers
        self.embed_dim = embed_dim

        # Per-voxel embedding of [occupancy, x, y, z] (4 inputs).
        self.voxel_embed = nn.Sequential(
            nn.Linear(4, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))

        # Normalized voxel coordinates in [0, 1], cached as a buffer.
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            indexing="ij"), dim=-1) / (GS - 1)
        self.register_buffer("pos_grid", coords)

        # Learned query tokens that attend over all voxel embeddings.
        self.tracer_tokens = nn.Parameter(torch.randn(n_tracers, embed_dim) * 0.02)
        self.tracer_attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)
        # Pairwise tracer interaction: gated message passed to both endpoints.
        self.tracer_gate = nn.Sequential(nn.Linear(embed_dim * 2, embed_dim), nn.Sigmoid())
        self.tracer_interact = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))

        # Auxiliary head: predicts a scalar per tracer pair ("edge length").
        self.edge_head = nn.Sequential(
            SwiGLU(embed_dim * 2, 32), nn.Linear(32, 1))

        # Precompute the upper-triangular (i < j) tracer pair indices.
        _pi, _pj = [], []
        for i in range(n_tracers):
            for j in range(i + 1, n_tracers):
                _pi.append(i); _pj.append(j)
        self.register_buffer("_pair_i", torch.tensor(_pi, dtype=torch.long))
        self.register_buffer("_pair_j", torch.tensor(_pj, dtype=torch.long))
        self.n_pairs = len(_pi)

        pool_dim = embed_dim * n_tracers

        # Cascaded capacity heads; each consumes the previous head's overflow,
        # with increasing initial capacities (0.5 -> 2.0).
        self.dim0 = CapacityHead(pool_dim, embed_dim, init_capacity=0.5)
        self.dim1 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=1.0)
        self.dim2 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=1.5)
        self.dim3 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=2.0)

        rigid_feat_dim = embed_dim * 4  # retained features from the 4 heads
        self.curvature = CurvatureHead(rigid_feat_dim, fill_dim=4, embed_dim=embed_dim)

        # pooled + 4 fill ratios + retained + curvature features + is_curved.
        class_in = pool_dim + 4 + rigid_feat_dim + embed_dim + 1
        self.class_in = class_in
        self.classifier = nn.Sequential(
            nn.Linear(class_in, 256), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(256, 128), nn.GELU(), nn.Linear(128, n_classes))

        # Auxiliary prediction heads (peak count, volume, center of mass).
        self.peak_head = nn.Sequential(
            SwiGLU(class_in, 32), nn.Linear(32, 4))

        self.volume_head = nn.Sequential(
            nn.Linear(class_in, 64), nn.GELU(), nn.Linear(64, 1))

        self.cm_head = nn.Sequential(
            SwiGLU(class_in, 64), nn.Linear(64, 1), nn.Tanh())

        # Second-pass refinement of ambiguous classifications.
        self.arbiter = RectifiedFlowArbiter(
            feat_dim=class_in, n_classes=n_classes,
            n_steps=4, latent_dim=128, embed_dim=embed_dim)

    def forward(self, grid, labels=None):
        """grid: (B, GS, GS, GS) occupancy; labels optional (training only,
        forwarded to the arbiter's flow-matching loss). Returns a dict of
        logits, auxiliary predictions, and diagnostics."""
        B = grid.shape[0]
        # Flatten voxels to a sequence of [occ, x, y, z] tokens.
        occ = grid.reshape(B, GS**3, 1)
        pos = self.pos_grid.reshape(1, GS**3, 3).expand(B, -1, -1)
        voxel_emb = self.voxel_embed(torch.cat([occ, pos], dim=-1))

        # Tracers cross-attend to the voxel tokens (queries = tracers).
        tracers = self.tracer_tokens.unsqueeze(0).expand(B, -1, -1)
        tracers, _ = self.tracer_attn(tracers, voxel_emb, voxel_emb)

        # Gather both endpoints of every tracer pair.
        left = tracers[:, self._pair_i]
        right = tracers[:, self._pair_j]
        pairs = torch.cat([left, right], dim=-1)

        # One batched pass through the pairwise heads.
        flat_pairs = pairs.reshape(B * self.n_pairs, -1)
        gate = self.tracer_gate(flat_pairs).reshape(B, self.n_pairs, -1)
        interaction = self.tracer_interact(flat_pairs).reshape(B, self.n_pairs, -1)
        edge_lengths = self.edge_head(flat_pairs).reshape(B, self.n_pairs)

        # Scatter each gated pairwise message back onto both endpoint tracers.
        gated = gate * interaction
        tracer_out = tracers.clone()
        pi_exp = self._pair_i.view(1, self.n_pairs, 1).expand(B, -1, self.embed_dim)
        pj_exp = self._pair_j.view(1, self.n_pairs, 1).expand(B, -1, self.embed_dim)
        tracer_out.scatter_add_(1, pi_exp, gated)
        tracer_out.scatter_add_(1, pj_exp, gated)
        pooled = tracer_out.reshape(B, -1)

        # Capacity cascade: each head sees the pooled features plus the
        # previous head's overflow.
        fill0, ovf0, ret0, cap0, _ = self.dim0(pooled)
        fill1, ovf1, ret1, cap1, _ = self.dim1(torch.cat([pooled, ovf0], -1))
        fill2, ovf2, ret2, cap2, _ = self.dim2(torch.cat([pooled, ovf1], -1))
        fill3, ovf3, ret3, cap3, _ = self.dim3(torch.cat([pooled, ovf2], -1))

        fill_ratios = torch.cat([fill0, fill1, fill2, fill3], dim=-1)
        rigid_retained = torch.cat([ret0, ret1, ret2, ret3], dim=-1)
        ovf_norms = torch.stack([
            ovf0.norm(dim=-1), ovf1.norm(dim=-1),
            ovf2.norm(dim=-1), ovf3.norm(dim=-1)], dim=-1)

        is_curved, curv_logits, curv_feat, alternation = self.curvature(grid, rigid_retained, fill_ratios)
        full = torch.cat([pooled, fill_ratios, rigid_retained, curv_feat, is_curved], dim=-1)

        # First-pass classification.
        initial_logits = self.classifier(full)

        # Arbiter refines via rectified flow; flow_loss is 0 unless labels given.
        refined_logits, refined_conf, initial_conf, trajectory_logits, flow_loss, blend_weight = \
            self.arbiter(full, initial_logits, labels=labels)

        # Learned convex blend of initial and refined logits.
        final_logits = blend_weight * initial_logits + (1.0 - blend_weight) * refined_logits

        return {
            # Classification outputs.
            "class_logits": final_logits,
            "initial_logits": initial_logits,
            "refined_logits": refined_logits,
            "trajectory_logits": trajectory_logits,
            # Training signal from the arbiter.
            "flow_loss": flow_loss,
            # Confidence diagnostics.
            "confidence": initial_conf["confidence"],
            "max_prob": initial_conf["max_prob"],
            "entropy": initial_conf["entropy"],
            "refined_confidence": refined_conf,
            "blend_weight": blend_weight.squeeze(-1),
            # Auxiliary predictions.
            "peak_logits": self.peak_head(full),
            "volume_pred": self.volume_head(full).squeeze(-1),
            "cm_pred": self.cm_head(full).squeeze(-1),
            "edge_lengths": edge_lengths,
            "fill_ratios": fill_ratios,
            "overflows": ovf_norms,
            "capacities": torch.stack([cap0, cap1, cap2, cap3]),
            "is_curved_pred": is_curved,
            "curv_type_logits": curv_logits,
            "alternation": alternation,
            # Raw feature vector for downstream use.
            "features": full,
        }
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke check: build the model and report its parameter count.
    # Guarded so importing this module does not instantiate the full model
    # (a module-level side effect that slowed and noised every import).
    _m = GeometricShapeClassifier()
    print(f'GeometricShapeClassifier: {sum(p.numel() for p in _m.parameters()):,} params')
    del _m