phanerozoic committed on
Commit
a103957
·
verified ·
1 Parent(s): 4f1f9bf

7 depth head candidates with shared losses/utils and registry

Browse files
heads/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Depth head registry."""
2
+
3
+ from .linear_probe.head import LinearProbe
4
+ from .cofiber_linear.head import CofiberLinear
5
+ from .cofiber_threshold.head import CofiberThreshold
6
+ from .wavelet.head import Wavelet
7
+ from .log_linear.head import LogLinear
8
+ from .ordinal_regression.head import OrdinalRegression
9
+ from .multiscale_gradient.head import MultiscaleGradient
10
+
11
+ REGISTRY = {
12
+ "linear_probe": LinearProbe,
13
+ "cofiber_linear": CofiberLinear,
14
+ "cofiber_threshold": CofiberThreshold,
15
+ "wavelet": Wavelet,
16
+ "log_linear": LogLinear,
17
+ "ordinal_regression": OrdinalRegression,
18
+ "multiscale_gradient": MultiscaleGradient,
19
+ }
20
+
21
+ ALL_NAMES = list(REGISTRY.keys())
22
+
23
+ def get_head(name):
24
+ if name not in REGISTRY:
25
+ raise ValueError(f"Unknown head: {name}. Available: {ALL_NAMES}")
26
+ return REGISTRY[name]()
heads/cofiber_linear/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/cofiber_linear/head.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cofiber Linear: analytic decomposition + shared depth bin prediction per scale."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def cofiber_decompose(f, n_scales):
9
+ cofibers = []
10
+ residual = f
11
+ for _ in range(n_scales - 1):
12
+ omega = F.avg_pool2d(residual, 2)
13
+ sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False)
14
+ cofibers.append(residual - sigma_omega)
15
+ residual = omega
16
+ cofibers.append(residual)
17
+ return cofibers
18
+
19
+
20
+ class CofiberLinear(nn.Module):
21
+ name = "cofiber_linear"
22
+ needs_intermediates = False
23
+
24
+ def __init__(self, feat_dim=768, n_bins=256, min_depth=0.001, max_depth=10.0, n_scales=3):
25
+ super().__init__()
26
+ self.n_scales = n_scales
27
+ self.n_bins = n_bins
28
+ self.min_depth = min_depth
29
+ self.max_depth = max_depth
30
+ self.conv = nn.Conv2d(feat_dim, n_bins, 1)
31
+
32
+ def forward(self, spatial, inter=None):
33
+ cofibers = cofiber_decompose(spatial, self.n_scales)
34
+ target_size = spatial.shape[2:]
35
+ logits = None
36
+ for cof in cofibers:
37
+ out = self.conv(cof)
38
+ out = F.interpolate(out, size=target_size, mode="bilinear", align_corners=False)
39
+ logits = out if logits is None else logits + out
40
+ dist = torch.relu(logits) + 0.1
41
+ dist = dist / dist.sum(dim=1, keepdim=True)
42
+ bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=spatial.device)
43
+ return torch.einsum("bkhw,k->bhw", dist, bins).unsqueeze(1)
heads/cofiber_threshold/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/cofiber_threshold/head.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cofiber Threshold: analytic decomposition + per-scale LayerNorm + prototype depth prediction."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def cofiber_decompose(f, n_scales):
9
+ cofibers = []
10
+ residual = f
11
+ for _ in range(n_scales - 1):
12
+ omega = F.avg_pool2d(residual, 2)
13
+ sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False)
14
+ cofibers.append(residual - sigma_omega)
15
+ residual = omega
16
+ cofibers.append(residual)
17
+ return cofibers
18
+
19
+
20
+ class CofiberThreshold(nn.Module):
21
+ name = "cofiber_threshold"
22
+ needs_intermediates = False
23
+
24
+ def __init__(self, feat_dim=768, n_bins=256, min_depth=0.001, max_depth=10.0, n_scales=3):
25
+ super().__init__()
26
+ self.n_scales = n_scales
27
+ self.n_bins = n_bins
28
+ self.min_depth = min_depth
29
+ self.max_depth = max_depth
30
+ self.scale_norms = nn.ModuleList([nn.LayerNorm(feat_dim) for _ in range(n_scales)])
31
+ self.weight = nn.Parameter(torch.randn(n_bins, feat_dim) * 0.01)
32
+ self.bias = nn.Parameter(torch.zeros(n_bins))
33
+
34
+ def forward(self, spatial, inter=None):
35
+ cofibers = cofiber_decompose(spatial, self.n_scales)
36
+ target_size = spatial.shape[2:]
37
+ logits = None
38
+ for i, cof in enumerate(cofibers):
39
+ B, C, H, W = cof.shape
40
+ f = self.scale_norms[i](cof.permute(0, 2, 3, 1).reshape(-1, C))
41
+ out = (f @ self.weight.T + self.bias).reshape(B, H, W, -1).permute(0, 3, 1, 2)
42
+ out = F.interpolate(out, size=target_size, mode="bilinear", align_corners=False)
43
+ logits = out if logits is None else logits + out
44
+ dist = torch.relu(logits) + 0.1
45
+ dist = dist / dist.sum(dim=1, keepdim=True)
46
+ bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=spatial.device)
47
+ return torch.einsum("bkhw,k->bhw", dist, bins).unsqueeze(1)
heads/linear_probe/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/linear_probe/head.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Linear probe: BatchNorm + 1x1 conv -> 256 depth bins. The EUPE paper baseline."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class LinearProbe(nn.Module):
8
+ name = "linear_probe"
9
+ needs_intermediates = False
10
+
11
+ def __init__(self, feat_dim=768, n_bins=256, min_depth=0.001, max_depth=10.0):
12
+ super().__init__()
13
+ self.bn = nn.BatchNorm2d(feat_dim)
14
+ self.conv = nn.Conv2d(feat_dim, n_bins, 1)
15
+ self.n_bins = n_bins
16
+ self.min_depth = min_depth
17
+ self.max_depth = max_depth
18
+
19
+ def forward(self, spatial, inter=None):
20
+ logits = self.conv(self.bn(spatial))
21
+ dist = torch.relu(logits) + 0.1
22
+ dist = dist / dist.sum(dim=1, keepdim=True)
23
+ bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=spatial.device)
24
+ return torch.einsum("bkhw,k->bhw", dist, bins).unsqueeze(1)
heads/log_linear/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/log_linear/head.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Log-Linear: predict log-depth with a single linear layer. 769 parameters."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class LogLinear(nn.Module):
8
+ name = "log_linear"
9
+ needs_intermediates = False
10
+
11
+ def __init__(self, feat_dim=768, min_depth=0.001, max_depth=10.0):
12
+ super().__init__()
13
+ self.conv = nn.Conv2d(feat_dim, 1, 1)
14
+ self.min_depth = min_depth
15
+ self.max_depth = max_depth
16
+
17
+ def forward(self, spatial, inter=None):
18
+ log_depth = self.conv(spatial)
19
+ return log_depth.exp().clamp(self.min_depth, self.max_depth)
heads/multiscale_gradient/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/multiscale_gradient/head.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-scale Gradient: predict depth gradients per cofiber scale, integrate for absolute depth.
2
+ Cofiber features are inherently edge-like; predicting gradients aligns with the feature structure."""
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ def cofiber_decompose(f, n_scales):
10
+ cofibers = []
11
+ residual = f
12
+ for _ in range(n_scales - 1):
13
+ omega = F.avg_pool2d(residual, 2)
14
+ sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False)
15
+ cofibers.append(residual - sigma_omega)
16
+ residual = omega
17
+ cofibers.append(residual)
18
+ return cofibers
19
+
20
+
21
+ class MultiscaleGradient(nn.Module):
22
+ name = "multiscale_gradient"
23
+ needs_intermediates = False
24
+
25
+ def __init__(self, feat_dim=768, n_scales=3, min_depth=0.001, max_depth=10.0):
26
+ super().__init__()
27
+ self.n_scales = n_scales
28
+ self.min_depth = min_depth
29
+ self.max_depth = max_depth
30
+ # Per-scale: predict dx and dy gradients
31
+ self.grad_heads = nn.ModuleList([nn.Conv2d(feat_dim, 2, 1) for _ in range(n_scales)])
32
+ # Base depth from coarsest residual
33
+ self.base_head = nn.Conv2d(feat_dim, 1, 1)
34
+
35
+ def forward(self, spatial, inter=None):
36
+ cofibers = cofiber_decompose(spatial, self.n_scales)
37
+ target_size = spatial.shape[2:]
38
+
39
+ # Base depth from coarsest scale
40
+ base = self.base_head(cofibers[-1])
41
+ base = F.interpolate(base, size=target_size, mode="bilinear", align_corners=False)
42
+ depth = base
43
+
44
+ # Add integrated gradients from each finer scale
45
+ for i in range(self.n_scales - 1):
46
+ grads = self.grad_heads[i](cofibers[i])
47
+ grads = F.interpolate(grads, size=target_size, mode="bilinear", align_corners=False)
48
+ dx, dy = grads[:, 0:1], grads[:, 1:2]
49
+ depth = depth + dx.cumsum(dim=3) + dy.cumsum(dim=2)
50
+
51
+ return depth.clamp(self.min_depth, self.max_depth)
heads/ordinal_regression/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/ordinal_regression/head.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Ordinal Regression: predict 'is this pixel deeper than threshold t?' for K thresholds.
2
+ Each threshold is a 768->1 linear classifier. Depth = sum of positive predictions * bin width."""
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ class OrdinalRegression(nn.Module):
9
+ name = "ordinal_regression"
10
+ needs_intermediates = False
11
+
12
+ def __init__(self, feat_dim=768, n_thresholds=64, min_depth=0.001, max_depth=10.0):
13
+ super().__init__()
14
+ self.conv = nn.Conv2d(feat_dim, n_thresholds, 1)
15
+ self.n_thresholds = n_thresholds
16
+ self.min_depth = min_depth
17
+ self.max_depth = max_depth
18
+
19
+ def forward(self, spatial, inter=None):
20
+ logits = self.conv(spatial)
21
+ probs = torch.sigmoid(logits)
22
+ bin_width = (self.max_depth - self.min_depth) / self.n_thresholds
23
+ depth = self.min_depth + probs.sum(dim=1, keepdim=True) * bin_width
24
+ return depth
heads/wavelet/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/wavelet/head.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wavelet: Haar decomposition + per-subband depth bin prediction."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class Wavelet(nn.Module):
9
+ name = "wavelet"
10
+ needs_intermediates = False
11
+
12
+ def __init__(self, feat_dim=768, n_bins=256, min_depth=0.001, max_depth=10.0, n_scales=3):
13
+ super().__init__()
14
+ self.n_scales = n_scales
15
+ self.n_bins = n_bins
16
+ self.min_depth = min_depth
17
+ self.max_depth = max_depth
18
+ self.heads = nn.ModuleList([nn.Conv2d(feat_dim, n_bins, 1) for _ in range(n_scales)])
19
+
20
+ @staticmethod
21
+ def haar_down(x):
22
+ h, w = x.shape[2], x.shape[3]
23
+ x = x[:, :, :h - h % 2, :w - w % 2]
24
+ return (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] +
25
+ x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) / 4
26
+
27
+ def forward(self, spatial, inter=None):
28
+ target_size = spatial.shape[2:]
29
+ f = spatial
30
+ logits = None
31
+ for i in range(self.n_scales):
32
+ out = self.heads[i](f)
33
+ out = F.interpolate(out, size=target_size, mode="bilinear", align_corners=False)
34
+ logits = out if logits is None else logits + out
35
+ if i < self.n_scales - 1:
36
+ f = self.haar_down(f)
37
+ dist = torch.relu(logits) + 0.1
38
+ dist = dist / dist.sum(dim=1, keepdim=True)
39
+ bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=spatial.device)
40
+ return torch.einsum("bkhw,k->bhw", dist, bins).unsqueeze(1)
losses/__init__.py ADDED
File without changes
losses/depth.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Depth losses."""
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def silog_loss(pred, target, mask=None, variance_focus=0.85):
8
+ """Scale-invariant logarithmic loss."""
9
+ pred = pred.flatten(1)
10
+ target = target.flatten(1)
11
+ if mask is not None:
12
+ mask = mask.flatten(1).bool()
13
+ pred = pred[mask]
14
+ target = target[mask]
15
+ else:
16
+ pred = pred.reshape(-1)
17
+ target = target.reshape(-1)
18
+ pred = pred.clamp(min=1e-6)
19
+ target = target.clamp(min=1e-6)
20
+ d = torch.log(pred) - torch.log(target)
21
+ return torch.sqrt((d ** 2).mean() - variance_focus * (d.mean() ** 2) + 1e-8)
22
+
23
+
24
+ def l1_depth_loss(pred, target, mask=None):
25
+ """Simple L1 loss on depth values."""
26
+ if mask is not None:
27
+ return F.l1_loss(pred[mask], target[mask])
28
+ return F.l1_loss(pred, target)
utils/__init__.py ADDED
File without changes