phanerozoic commited on
Commit
fefb215
·
verified ·
1 Parent(s): 78ae7da

Add tropical, compression, curvature, renormalization heads

Browse files
heads/__init__.py CHANGED
@@ -8,6 +8,7 @@ from .log_linear.head import LogLinear
8
  from .ordinal_regression.head import OrdinalRegression
9
  from .multiscale_gradient.head import MultiscaleGradient
10
  from .harmonic.head import HarmonicDepth
 
11
 
12
  REGISTRY = {
13
  "linear_probe": LinearProbe,
@@ -18,6 +19,7 @@ REGISTRY = {
18
  "ordinal_regression": OrdinalRegression,
19
  "multiscale_gradient": MultiscaleGradient,
20
  "harmonic": HarmonicDepth,
 
21
  }
22
 
23
  ALL_NAMES = list(REGISTRY.keys())
 
8
  from .ordinal_regression.head import OrdinalRegression
9
  from .multiscale_gradient.head import MultiscaleGradient
10
  from .harmonic.head import HarmonicDepth
11
+ from .renormalization.head import RenormalizationDepth
12
 
13
  REGISTRY = {
14
  "linear_probe": LinearProbe,
 
19
  "ordinal_regression": OrdinalRegression,
20
  "multiscale_gradient": MultiscaleGradient,
21
  "harmonic": HarmonicDepth,
22
+ "renormalization": RenormalizationDepth,
23
  }
24
 
25
  ALL_NAMES = list(REGISTRY.keys())
heads/compression/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/curvature/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/renormalization/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/renormalization/head.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Renormalization Depth: depth from the scale at which features stabilize.
2
+
3
+ Nearby objects fluctuate across fine RG steps (cofiber energy high at fine scales).
4
+ Distant objects are stable (cofiber energy concentrated at coarse scales).
5
+ Depth = learned weighted sum of per-scale cofiber energies. 5 parameters at N=5 scales.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class RenormalizationDepth(nn.Module):
14
+ """Depth from renormalization group flow. 5 parameters at N=5 scales."""
15
+ name = "renormalization"
16
+ needs_intermediates = False
17
+
18
+ def __init__(self, feat_dim=768, n_scales=5, min_depth=0.001, max_depth=10.0):
19
+ super().__init__()
20
+ self.n_scales = n_scales
21
+ self.min_depth = min_depth
22
+ self.max_depth = max_depth
23
+ # One weight per scale: how much does fluctuation at this scale contribute to depth
24
+ self.scale_weights = nn.Parameter(torch.linspace(-1, 1, n_scales))
25
+ self.bias = nn.Parameter(torch.tensor(0.5))
26
+
27
+ def forward(self, spatial, inter=None):
28
+ B, C, H, W = spatial.shape
29
+ target_size = (H, W)
30
+ energies = []
31
+ residual = spatial
32
+
33
+ for _ in range(self.n_scales - 1):
34
+ h, w = residual.shape[2], residual.shape[3]
35
+ if h < 2 or w < 2:
36
+ energies.append(torch.zeros(B, 1, *target_size, device=spatial.device))
37
+ continue
38
+ omega = F.avg_pool2d(residual, 2)
39
+ sigma_omega = F.interpolate(omega, size=(h, w), mode="bilinear", align_corners=False)
40
+ cofib = residual - sigma_omega
41
+ energy = cofib.pow(2).sum(dim=1, keepdim=True)
42
+ energy = F.interpolate(energy, size=target_size, mode="bilinear", align_corners=False)
43
+ energies.append(energy)
44
+ residual = omega
45
+
46
+ # Coarsest residual energy
47
+ res_energy = residual.pow(2).sum(dim=1, keepdim=True)
48
+ res_energy = F.interpolate(res_energy, size=target_size, mode="bilinear", align_corners=False)
49
+ energies.append(res_energy)
50
+
51
+ # Weighted sum: depth = sigmoid(sum(w_k * energy_k) + bias) * depth_range
52
+ stacked = torch.stack(energies, dim=0) # (n_scales, B, 1, H, W)
53
+ weights = self.scale_weights.reshape(-1, 1, 1, 1, 1)
54
+ depth = (stacked * weights).sum(dim=0) + self.bias
55
+ depth = torch.sigmoid(depth) * (self.max_depth - self.min_depth) + self.min_depth
56
+ return depth
heads/tropical/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *