phanerozoic's picture
Add tropical, compression, curvature, renormalization heads
fefb215 verified
"""Renormalization Depth: depth from the scale at which features stabilize.
Nearby objects fluctuate across fine RG steps (cofiber energy high at fine scales).
Distant objects are stable (cofiber energy concentrated at coarse scales).
Depth = sigmoid(learned weighted sum of per-scale cofiber energies + bias), rescaled to [min_depth, max_depth]. 6 parameters at N=5 scales (5 scale weights + 1 bias).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class RenormalizationDepth(nn.Module):
    """Depth head driven by per-scale coarse-graining residual energies.

    The input feature map is repeatedly average-pooled by a factor of 2. At
    each step the difference between the current map and its pooled-then-
    upsampled version (the "cofiber") is squared and summed over channels,
    giving one energy map per scale. The energy of the final (coarsest) map
    is appended as the last scale. Depth is a sigmoid of the learned weighted
    sum of these energy maps plus a bias, rescaled to
    ``[min_depth, max_depth]``.

    Learnable parameters: ``n_scales`` scale weights plus one scalar bias
    (6 scalars at the default ``n_scales=5``).

    Args:
        feat_dim: Accepted for interface compatibility; not used by this head.
        n_scales: Number of energy maps combined (``n_scales - 1`` cofiber
            energies plus the coarsest residual energy).
        min_depth: Lower bound of the predicted depth range.
        max_depth: Upper bound of the predicted depth range.
    """

    name = "renormalization"
    needs_intermediates = False

    def __init__(self, feat_dim=768, n_scales=5, min_depth=0.001, max_depth=10.0):
        super().__init__()
        self.n_scales = n_scales
        self.min_depth = min_depth
        self.max_depth = max_depth
        # One weight per scale: how much the fluctuation at that scale
        # contributes to depth. Initialised from -1 (finest) to +1 (coarsest).
        self.scale_weights = nn.Parameter(torch.linspace(-1, 1, n_scales))
        self.bias = nn.Parameter(torch.tensor(0.5))

    def forward(self, spatial, inter=None):
        """Predict a depth map from a 4-D feature map.

        Args:
            spatial: Feature tensor of shape ``(B, C, H, W)``.
            inter: Unused; accepted for interface compatibility.

        Returns:
            Tensor of shape ``(B, 1, H, W)`` with values in
            ``[min_depth, max_depth]``.
        """
        _, _, H, W = spatial.shape
        target_size = (H, W)

        energies = []
        residual = spatial
        n_fine = self.n_scales - 1
        for _ in range(n_fine):
            h, w = residual.shape[2], residual.shape[3]
            if h < 2 or w < 2:
                # The map can no longer be pooled by 2. `residual` would not
                # change in later iterations either, so every remaining scale
                # is a zero placeholder; those are filled in after the loop.
                break
            omega = F.avg_pool2d(residual, 2)
            sigma_omega = F.interpolate(
                omega, size=(h, w), mode="bilinear", align_corners=False
            )
            cofib = residual - sigma_omega
            energy = cofib.pow(2).sum(dim=1, keepdim=True)
            energy = F.interpolate(
                energy, size=target_size, mode="bilinear", align_corners=False
            )
            energies.append(energy)
            residual = omega

        # Coarsest residual energy.
        res_energy = residual.pow(2).sum(dim=1, keepdim=True)
        res_energy = F.interpolate(
            res_energy, size=target_size, mode="bilinear", align_corners=False
        )

        # Zero energy for scales that could not be computed. zeros_like keeps
        # the placeholders in the same dtype/device as the real energy maps,
        # so stacking never mixes dtypes (e.g. for half/double models).
        n_missing = n_fine - len(energies)
        if n_missing > 0:
            energies.extend([torch.zeros_like(res_energy)] * n_missing)
        energies.append(res_energy)

        # depth = sigmoid(sum_k(w_k * energy_k) + bias) mapped to depth range.
        stacked = torch.stack(energies, dim=0)  # (n_scales, B, 1, H, W)
        weights = self.scale_weights.reshape(-1, 1, 1, 1, 1)
        logits = (stacked * weights).sum(dim=0) + self.bias
        depth_range = self.max_depth - self.min_depth
        return torch.sigmoid(logits) * depth_range + self.min_depth