Add tropical, compression, curvature, renormalization heads
Browse files- heads/__init__.py +2 -0
- heads/compression/__init__.py +1 -0
- heads/curvature/__init__.py +1 -0
- heads/renormalization/__init__.py +1 -0
- heads/renormalization/head.py +56 -0
- heads/tropical/__init__.py +1 -0
heads/__init__.py
CHANGED
|
@@ -8,6 +8,7 @@ from .log_linear.head import LogLinear
|
|
| 8 |
from .ordinal_regression.head import OrdinalRegression
|
| 9 |
from .multiscale_gradient.head import MultiscaleGradient
|
| 10 |
from .harmonic.head import HarmonicDepth
|
|
|
|
| 11 |
|
| 12 |
REGISTRY = {
|
| 13 |
"linear_probe": LinearProbe,
|
|
@@ -18,6 +19,7 @@ REGISTRY = {
|
|
| 18 |
"ordinal_regression": OrdinalRegression,
|
| 19 |
"multiscale_gradient": MultiscaleGradient,
|
| 20 |
"harmonic": HarmonicDepth,
|
|
|
|
| 21 |
}
|
| 22 |
|
| 23 |
ALL_NAMES = list(REGISTRY.keys())
|
|
|
|
| 8 |
from .ordinal_regression.head import OrdinalRegression
|
| 9 |
from .multiscale_gradient.head import MultiscaleGradient
|
| 10 |
from .harmonic.head import HarmonicDepth
|
| 11 |
+
from .renormalization.head import RenormalizationDepth
|
| 12 |
|
| 13 |
REGISTRY = {
|
| 14 |
"linear_probe": LinearProbe,
|
|
|
|
| 19 |
"ordinal_regression": OrdinalRegression,
|
| 20 |
"multiscale_gradient": MultiscaleGradient,
|
| 21 |
"harmonic": HarmonicDepth,
|
| 22 |
+
"renormalization": RenormalizationDepth,
|
| 23 |
}
|
| 24 |
|
| 25 |
ALL_NAMES = list(REGISTRY.keys())
|
heads/compression/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/curvature/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/renormalization/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/renormalization/head.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Renormalization Depth: depth from the scale at which features stabilize.
|
| 2 |
+
|
| 3 |
+
Nearby objects fluctuate across fine RG steps (cofiber energy high at fine scales).
|
| 4 |
+
Distant objects are stable (cofiber energy concentrated at coarse scales).
|
| 5 |
+
Depth = learned weighted sum of per-scale cofiber energies. 5 parameters at N=5 scales.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RenormalizationDepth(nn.Module):
|
| 14 |
+
"""Depth from renormalization group flow. 5 parameters at N=5 scales."""
|
| 15 |
+
name = "renormalization"
|
| 16 |
+
needs_intermediates = False
|
| 17 |
+
|
| 18 |
+
def __init__(self, feat_dim=768, n_scales=5, min_depth=0.001, max_depth=10.0):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.n_scales = n_scales
|
| 21 |
+
self.min_depth = min_depth
|
| 22 |
+
self.max_depth = max_depth
|
| 23 |
+
# One weight per scale: how much does fluctuation at this scale contribute to depth
|
| 24 |
+
self.scale_weights = nn.Parameter(torch.linspace(-1, 1, n_scales))
|
| 25 |
+
self.bias = nn.Parameter(torch.tensor(0.5))
|
| 26 |
+
|
| 27 |
+
def forward(self, spatial, inter=None):
|
| 28 |
+
B, C, H, W = spatial.shape
|
| 29 |
+
target_size = (H, W)
|
| 30 |
+
energies = []
|
| 31 |
+
residual = spatial
|
| 32 |
+
|
| 33 |
+
for _ in range(self.n_scales - 1):
|
| 34 |
+
h, w = residual.shape[2], residual.shape[3]
|
| 35 |
+
if h < 2 or w < 2:
|
| 36 |
+
energies.append(torch.zeros(B, 1, *target_size, device=spatial.device))
|
| 37 |
+
continue
|
| 38 |
+
omega = F.avg_pool2d(residual, 2)
|
| 39 |
+
sigma_omega = F.interpolate(omega, size=(h, w), mode="bilinear", align_corners=False)
|
| 40 |
+
cofib = residual - sigma_omega
|
| 41 |
+
energy = cofib.pow(2).sum(dim=1, keepdim=True)
|
| 42 |
+
energy = F.interpolate(energy, size=target_size, mode="bilinear", align_corners=False)
|
| 43 |
+
energies.append(energy)
|
| 44 |
+
residual = omega
|
| 45 |
+
|
| 46 |
+
# Coarsest residual energy
|
| 47 |
+
res_energy = residual.pow(2).sum(dim=1, keepdim=True)
|
| 48 |
+
res_energy = F.interpolate(res_energy, size=target_size, mode="bilinear", align_corners=False)
|
| 49 |
+
energies.append(res_energy)
|
| 50 |
+
|
| 51 |
+
# Weighted sum: depth = sigmoid(sum(w_k * energy_k) + bias) * depth_range
|
| 52 |
+
stacked = torch.stack(energies, dim=0) # (n_scales, B, 1, H, W)
|
| 53 |
+
weights = self.scale_weights.reshape(-1, 1, 1, 1, 1)
|
| 54 |
+
depth = (stacked * weights).sum(dim=0) + self.bias
|
| 55 |
+
depth = torch.sigmoid(depth) * (self.max_depth - self.min_depth) + self.min_depth
|
| 56 |
+
return depth
|
heads/tropical/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|