| """Renormalization Depth: depth from the scale at which features stabilize. |
| |
| Nearby objects fluctuate across fine RG steps (cofiber energy high at fine scales). |
| Distant objects are stable (cofiber energy concentrated at coarse scales). |
| Depth = learned weighted sum of per-scale cofiber energies plus a bias. 5 scale weights + 1 bias (6 parameters) at N=5 scales. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class RenormalizationDepth(nn.Module):
    """Predict a depth map from a multi-scale coarse-graining of a feature map.

    The input is repeatedly coarsened with 2x2 average pooling. At each step
    the "cofiber" is the part of the current map that the pooled map fails to
    reconstruct after bilinear upsampling; its channel-summed squared
    magnitude is that scale's energy. The energy of the coarsest map itself
    is the last scale. Depth is a learned weighted sum of the per-scale
    energies plus a bias, passed through a sigmoid and rescaled to lie
    between ``min_depth`` and ``max_depth``.

    Learnable scalars: ``n_scales`` scale weights plus one bias
    (6 at the default ``n_scales=5``).
    """

    name = "renormalization"
    needs_intermediates = False

    def __init__(self, feat_dim=768, n_scales=5, min_depth=0.001, max_depth=10.0):
        """Create the per-scale weights and the bias.

        Args:
            feat_dim: Unused by this module; kept in the signature.
            n_scales: Number of energy maps combined (``n_scales - 1``
                cofiber energies plus one residual energy). Must be >= 1.
            min_depth: Lower end of the output range.
            max_depth: Upper end of the output range.

        Raises:
            ValueError: If ``n_scales`` is smaller than 1.
        """
        super().__init__()
        if n_scales < 1:
            # With zero weights the broadcasted product would be empty and the
            # output would silently degenerate to sigmoid(bias).
            raise ValueError(f"n_scales must be >= 1, got {n_scales}")
        self.n_scales = n_scales
        self.min_depth = min_depth
        self.max_depth = max_depth

        # Initial weights run from -1 (finest scale) to +1 (coarsest scale).
        self.scale_weights = nn.Parameter(torch.linspace(-1, 1, n_scales))
        self.bias = nn.Parameter(torch.tensor(0.5))

    def forward(self, spatial, inter=None):
        """Compute the depth map for a batch of feature maps.

        Args:
            spatial: Feature tensor unpacked as ``(B, C, H, W)``.
            inter: Ignored; accepted but never read.

        Returns:
            Tensor of shape ``(B, 1, H, W)`` with values between
            ``min_depth`` and ``max_depth``.
        """
        batch, _, height, width = spatial.shape
        target_size = (height, width)
        energies = []
        residual = spatial

        for _ in range(self.n_scales - 1):
            h, w = residual.shape[2], residual.shape[3]
            if h < 2 or w < 2:
                # The map cannot be pooled any further, so this scale
                # contributes zero energy. new_zeros keeps the input's dtype
                # and device so every stacked map is homogeneous.
                energies.append(spatial.new_zeros((batch, 1, height, width)))
                continue

            coarse = F.avg_pool2d(residual, 2)
            reconstructed = F.interpolate(
                coarse, size=(h, w), mode="bilinear", align_corners=False
            )
            # Detail lost by this coarsening step.
            cofiber = residual - reconstructed
            energy = cofiber.pow(2).sum(dim=1, keepdim=True)
            # Bring every scale's energy back to the input resolution.
            energy = F.interpolate(
                energy, size=target_size, mode="bilinear", align_corners=False
            )
            energies.append(energy)
            residual = coarse

        # Energy of whatever survives all coarsening steps (the last scale).
        residual_energy = residual.pow(2).sum(dim=1, keepdim=True)
        residual_energy = F.interpolate(
            residual_energy, size=target_size, mode="bilinear", align_corners=False
        )
        energies.append(residual_energy)

        # (n_scales, B, 1, H, W) weighted by one scalar per scale.
        stacked = torch.stack(energies, dim=0)
        weights = self.scale_weights.reshape(-1, 1, 1, 1, 1)
        logits = (stacked * weights).sum(dim=0) + self.bias
        depth_range = self.max_depth - self.min_depth
        return torch.sigmoid(logits) * depth_range + self.min_depth
|
|