Add novel heads: Optimal Transport (det), Info Bottleneck (seg), Harmonic Depth (depth)
Browse files- heads/__init__.py +2 -0
- heads/harmonic/__init__.py +1 -0
- heads/harmonic/head.py +74 -0
heads/__init__.py
CHANGED
|
@@ -7,6 +7,7 @@ from .wavelet.head import Wavelet
|
|
| 7 |
from .log_linear.head import LogLinear
|
| 8 |
from .ordinal_regression.head import OrdinalRegression
|
| 9 |
from .multiscale_gradient.head import MultiscaleGradient
|
|
|
|
| 10 |
|
| 11 |
REGISTRY = {
|
| 12 |
"linear_probe": LinearProbe,
|
|
@@ -16,6 +17,7 @@ REGISTRY = {
|
|
| 16 |
"log_linear": LogLinear,
|
| 17 |
"ordinal_regression": OrdinalRegression,
|
| 18 |
"multiscale_gradient": MultiscaleGradient,
|
|
|
|
| 19 |
}
|
| 20 |
|
| 21 |
ALL_NAMES = list(REGISTRY.keys())
|
|
|
|
| 7 |
from .log_linear.head import LogLinear
|
| 8 |
from .ordinal_regression.head import OrdinalRegression
|
| 9 |
from .multiscale_gradient.head import MultiscaleGradient
|
| 10 |
+
from .harmonic.head import HarmonicDepth
|
| 11 |
|
| 12 |
REGISTRY = {
|
| 13 |
"linear_probe": LinearProbe,
|
|
|
|
| 17 |
"log_linear": LogLinear,
|
| 18 |
"ordinal_regression": OrdinalRegression,
|
| 19 |
"multiscale_gradient": MultiscaleGradient,
|
| 20 |
+
"harmonic": HarmonicDepth,
|
| 21 |
}
|
| 22 |
|
| 23 |
ALL_NAMES = list(REGISTRY.keys())
|
heads/harmonic/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import HarmonicDepth
|
heads/harmonic/head.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Harmonic Depth: cofiber edge detection + boundary depth prediction + PDE solve.
|
| 2 |
+
|
| 3 |
+
769 parameters at the default feat_dim=768 (768 weights + 1 bias). Depth at edge locations is predicted by a single linear layer.
|
| 4 |
+
Depth at non-edge locations is solved via the discrete Laplace equation
|
| 5 |
+
(iterative Jacobi relaxation). The cofiber energy identifies edges — locations
|
| 6 |
+
where the feature changes across scales, which correspond to depth discontinuities.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def cofiber_energy(spatial):
    """Return the per-location energy of the high-frequency (cofiber) residual.

    The input is average-pooled by a factor of 2, bilinearly resampled back to
    the input's own spatial size, and subtracted from the input. The squared
    residual is summed over dim 1, which is kept as a singleton dimension.

    Note that the result is the *squared* L2 norm of the residual; no square
    root is taken.
    """
    out_size = spatial.shape[2:]
    coarse = F.avg_pool2d(spatial, 2)
    # Low-pass reconstruction of the input at its original resolution.
    smooth = F.interpolate(coarse, size=out_size, mode="bilinear", align_corners=False)
    residual = spatial - smooth
    return (residual ** 2).sum(dim=1, keepdim=True)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def jacobi_solve(depth, mask, n_iters=50):
    """Relax ``depth`` toward a discrete harmonic function with Jacobi sweeps.

    Where ``mask`` is 1 (edge locations) the current value is held fixed;
    where ``mask`` is 0 the value is replaced by the mean of its four
    axis-aligned neighbours. Exactly ``n_iters`` sweeps are performed; there
    is no convergence test, so ``n_iters=0`` returns ``depth`` unchanged.
    """
    keep = mask
    relax = 1 - mask
    current = depth
    for _ in range(n_iters):
        # Replicate-pad by one cell so border cells use their own value in
        # place of the neighbour that falls outside the grid.
        halo = F.pad(current, (1, 1, 1, 1), mode="replicate")
        above = halo[:, :, :-2, 1:-1]
        below = halo[:, :, 2:, 1:-1]
        left = halo[:, :, 1:-1, :-2]
        right = halo[:, :, 1:-1, 2:]
        neighbour_mean = (above + below + left + right) / 4
        # Fixed cells keep their value; free cells take the neighbour mean.
        current = keep * current + relax * neighbour_mean
    return current
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class HarmonicDepth(nn.Module):
    """Depth head that solves for depth as a harmonic function.

    A 1x1 convolution predicts a depth value at every location. Only the
    predictions at high-cofiber-energy ("edge") locations are kept as fixed
    boundary values; all other locations are filled in by Jacobi relaxation
    of the discrete Laplace equation.

    The only learnable module is ``depth_proj``, a ``Conv2d(feat_dim, 1, 1)``:
    ``feat_dim`` weights plus one bias, i.e. 769 parameters at the default
    ``feat_dim=768``.

    Args:
        feat_dim: Number of input feature channels.
        min_depth: Lower clamp bound applied to predicted depth.
        max_depth: Upper clamp bound applied to predicted depth.
        edge_percentile: Fraction (not a percentage) of spatial locations
            treated as edges; ``0.2`` selects the top 20% by cofiber energy.
        n_iters: Number of Jacobi sweeps run in ``forward``.
    """
    name = "harmonic"
    needs_intermediates = False

    def __init__(self, feat_dim=768, min_depth=0.001, max_depth=10.0,
                 edge_percentile=0.2, n_iters=50):
        super().__init__()
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.edge_percentile = edge_percentile
        self.n_iters = n_iters
        # Boundary depth predictor: a single linear layer (1x1 conv) whose
        # output is only kept at edge locations.
        self.depth_proj = nn.Conv2d(feat_dim, 1, 1)

    def forward(self, spatial, inter=None):
        """Predict a depth map from a spatial feature map.

        Args:
            spatial: Feature tensor unpacked as ``(B, C, H, W)``.
            inter: Unused by this head.

        Returns:
            Tensor of shape ``(B, 1, H, W)`` clamped to
            ``[min_depth, max_depth]``.
        """
        B, _, H, W = spatial.shape

        # Cofiber energy identifies edges.
        energy = cofiber_energy(spatial)

        # Threshold: the top ``edge_percentile`` fraction of locations are
        # edges. ``k`` is kept within [1, H * W] so that a fraction above 1
        # selects every location instead of making ``topk`` raise.
        flat_energy = energy.reshape(B, -1)
        k = min(H * W, max(1, int(H * W * self.edge_percentile)))
        threshold = flat_energy.topk(k, dim=1).values[:, -1:].reshape(B, 1, 1, 1)
        # ``>=`` means ties at the threshold can mark more than ``k`` edges.
        edge_mask = (energy >= threshold).float()

        # Predict depth everywhere; only the edge locations are used below.
        boundary_depth = self.depth_proj(spatial)
        boundary_depth = boundary_depth.clamp(self.min_depth, self.max_depth)

        # Initialize: boundary values at edges, per-sample mean edge depth
        # elsewhere. The clamp on the count guards against division by zero.
        edge_count = edge_mask.sum(dim=(2, 3), keepdim=True).clamp(min=1)
        mean_depth = (boundary_depth * edge_mask).sum(dim=(2, 3), keepdim=True) / edge_count
        depth = edge_mask * boundary_depth + (1 - edge_mask) * mean_depth

        # Solve the Laplace equation at non-edge locations.
        depth = jacobi_solve(depth, edge_mask, self.n_iters)

        return depth.clamp(self.min_depth, self.max_depth)
|