phanerozoic committed on
Commit
78ae7da
·
verified ·
1 Parent(s): a103957

Add novel heads: Optimal Transport (det), Info Bottleneck (seg), Harmonic Depth (depth)

Browse files
heads/__init__.py CHANGED
@@ -7,6 +7,7 @@ from .wavelet.head import Wavelet
7
  from .log_linear.head import LogLinear
8
  from .ordinal_regression.head import OrdinalRegression
9
  from .multiscale_gradient.head import MultiscaleGradient
 
10
 
11
  REGISTRY = {
12
  "linear_probe": LinearProbe,
@@ -16,6 +17,7 @@ REGISTRY = {
16
  "log_linear": LogLinear,
17
  "ordinal_regression": OrdinalRegression,
18
  "multiscale_gradient": MultiscaleGradient,
 
19
  }
20
 
21
  ALL_NAMES = list(REGISTRY.keys())
 
7
  from .log_linear.head import LogLinear
8
  from .ordinal_regression.head import OrdinalRegression
9
  from .multiscale_gradient.head import MultiscaleGradient
10
+ from .harmonic.head import HarmonicDepth
11
 
12
  REGISTRY = {
13
  "linear_probe": LinearProbe,
 
17
  "log_linear": LogLinear,
18
  "ordinal_regression": OrdinalRegression,
19
  "multiscale_gradient": MultiscaleGradient,
20
+ "harmonic": HarmonicDepth,
21
  }
22
 
23
  ALL_NAMES = list(REGISTRY.keys())
heads/harmonic/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import HarmonicDepth
heads/harmonic/head.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Harmonic Depth: cofiber edge detection + boundary depth prediction + PDE solve.
2
+
3
+ 769 parameters at the default feat_dim=768 (feat_dim weights + 1 bias). Depth at edge locations is predicted by a single linear layer.
4
+ Depth at non-edge locations is solved via the discrete Laplace equation
5
+ (iterative Jacobi relaxation). The cofiber energy identifies edges — locations
6
+ where the feature changes across scales, which correspond to depth discontinuities.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
def cofiber_energy(spatial):
    """Return the per-location energy of the high-frequency residual.

    The input is average-pooled with a 2x2 window, bilinearly resized back to
    its original spatial size, and subtracted from the input. The residual is
    squared and summed over dim 1, which is kept as a singleton dimension.

    Note that the returned value is the *squared* L2 norm over dim 1: no
    square root is applied.
    """
    target_size = spatial.shape[2:]
    coarse = F.avg_pool2d(spatial, 2)
    # Bring the coarse map back onto the input grid so it can be subtracted.
    smooth = F.interpolate(
        coarse, size=target_size, mode="bilinear", align_corners=False
    )
    residual = spatial - smooth
    return residual.pow(2).sum(dim=1, keepdim=True)
20
+
21
+
22
def jacobi_solve(depth, mask, n_iters=50):
    """Relax ``depth`` towards a discrete harmonic function by Jacobi sweeps.

    Where ``mask`` is 1 the value of ``depth`` is held fixed; where ``mask``
    is 0 the value is replaced, on every sweep, by the mean of its four
    axis-aligned neighbours. The border is handled by replicate padding, so
    an out-of-range neighbour reads the border cell's own value.

    Exactly ``n_iters`` sweeps are performed; there is no convergence test.
    With ``n_iters=0`` the input is returned unchanged.
    """
    free = 1 - mask  # weight of the locations that are allowed to move
    for _ in range(n_iters):
        ext = F.pad(depth, (1, 1, 1, 1), mode="replicate")
        up = ext[:, :, :-2, 1:-1]
        down = ext[:, :, 2:, 1:-1]
        left = ext[:, :, 1:-1, :-2]
        right = ext[:, :, 1:-1, 2:]
        neighbour_mean = (up + down + left + right) / 4
        # Fixed cells keep their value, free cells take the neighbour mean.
        depth = mask * depth + free * neighbour_mean
    return depth
34
+
35
+
36
class HarmonicDepth(nn.Module):
    """Depth head that solves for depth as a harmonic function.

    A 1x1 convolution predicts depth at every location, but only the
    predictions at high-energy ("edge") locations are kept as fixed boundary
    values. All remaining locations are filled in by Jacobi relaxation of the
    discrete Laplace equation.

    The only learnable component is the 1x1 convolution with one output
    channel, i.e. ``feat_dim + 1`` parameters (769 at the default
    ``feat_dim=768``).
    """
    name = "harmonic"
    # NOTE(review): presumably read by the caller to decide whether to pass
    # intermediate features as ``inter`` — confirm against the caller.
    needs_intermediates = False

    def __init__(self, feat_dim=768, min_depth=0.001, max_depth=10.0,
                 edge_percentile=0.2, n_iters=50):
        """Configure the head.

        Args:
            feat_dim: Number of input channels of the feature map.
            min_depth: Lower clamp applied to predicted and solved depth.
            max_depth: Upper clamp applied to predicted and solved depth.
            edge_percentile: Fraction (not a percentage) of the H*W locations
                treated as edges, e.g. 0.2 selects the top 20% by energy.
            n_iters: Number of Jacobi sweeps used to fill non-edge locations.
        """
        super().__init__()
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.edge_percentile = edge_percentile
        self.n_iters = n_iters
        # Boundary depth predictor: single linear layer at edge locations
        self.depth_proj = nn.Conv2d(feat_dim, 1, 1)

    def forward(self, spatial, inter=None):
        """Predict a depth map from a feature map.

        Args:
            spatial: Feature tensor unpacked as (B, C, H, W).
            inter: Unused by this head.

        Returns:
            Tensor of shape (B, 1, H, W) clamped to [min_depth, max_depth].
        """
        B, C, H, W = spatial.shape

        # Cofiber energy identifies edges
        energy = cofiber_energy(spatial)

        # Threshold: the top edge_percentile fraction of locations are edges.
        # k is clamped to [1, H*W]: topk raises when k exceeds the number of
        # elements, which would happen for edge_percentile > 1.
        n_locs = H * W
        k = min(n_locs, max(1, int(n_locs * self.edge_percentile)))
        flat_energy = energy.reshape(B, -1)
        # The k-th largest energy per sample is the edge threshold.
        threshold = flat_energy.topk(k, dim=1).values[:, -1:].reshape(B, 1, 1, 1)
        # ">=" means ties at the threshold can mark more than k locations.
        edge_mask = (energy >= threshold).float()

        # Predict depth everywhere; only edge locations are used as boundary
        boundary_depth = self.depth_proj(spatial)
        boundary_depth = boundary_depth.clamp(self.min_depth, self.max_depth)

        # Initialize: boundary values at edges, mean edge depth elsewhere.
        # The clamp on the edge count guards against division by zero.
        edge_count = edge_mask.sum(dim=(2, 3), keepdim=True).clamp(min=1)
        mean_depth = (boundary_depth * edge_mask).sum(dim=(2, 3), keepdim=True) / edge_count
        depth = edge_mask * boundary_depth + (1 - edge_mask) * mean_depth

        # Solve Laplace equation at non-edge locations
        depth = jacobi_solve(depth, edge_mask, self.n_iters)

        return depth.clamp(self.min_depth, self.max_depth)