phanerozoic's picture
7 depth head candidates with shared losses/utils and registry
a103957 verified
"""Wavelet: Haar decomposition + per-subband depth bin prediction."""
import torch
import torch.nn as nn
import torch.nn.functional as F
class Wavelet(nn.Module):
    """Multi-scale depth-bin head over a Haar low-pass pyramid.

    The input feature map is repeatedly halved with a 2x2 average (the Haar
    approximation band; no detail bands are computed). At every scale a 1x1
    convolution predicts ``n_bins`` logits, which are bilinearly resized back
    to the input resolution and summed across scales. The summed logits are
    turned into per-pixel bin weights and the output is the weighted mean of
    ``n_bins`` depth values evenly spaced in ``[min_depth, max_depth]``.
    """

    # Identifier and capability flag; presumably read by an external head
    # registry that is not visible in this file -- confirm against the caller.
    name = "wavelet"
    needs_intermediates = False

    def __init__(
        self,
        feat_dim: int = 768,
        n_bins: int = 256,
        min_depth: float = 0.001,
        max_depth: float = 10.0,
        n_scales: int = 3,
    ):
        """Build one 1x1 convolution head per pyramid scale.

        Args:
            feat_dim: Number of channels of the input feature map.
            n_bins: Number of depth bins predicted per pixel.
            min_depth: Depth value assigned to the first bin.
            max_depth: Depth value assigned to the last bin.
            n_scales: Number of pyramid levels (including full resolution).

        Raises:
            ValueError: If ``n_scales`` is smaller than 1 (``forward`` would
                have no logits to normalise).
        """
        super().__init__()
        if n_scales < 1:
            raise ValueError(f"n_scales must be >= 1, got {n_scales}")
        self.n_scales = n_scales
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.heads = nn.ModuleList(
            [nn.Conv2d(feat_dim, n_bins, 1) for _ in range(n_scales)]
        )

    @staticmethod
    def haar_down(x: torch.Tensor) -> torch.Tensor:
        """Halve the two trailing spatial axes by averaging 2x2 blocks.

        ``x`` is indexed as a 4-D ``(B, C, H, W)`` tensor. A trailing odd row
        or column is dropped before averaging. If either spatial axis is
        shorter than 2 the map cannot be halved and is returned unchanged
        (previously this produced an empty tensor that broke later layers).
        """
        h, w = x.shape[2], x.shape[3]
        if h < 2 or w < 2:
            return x
        # Crop to even extents so the four strided views line up.
        x = x[:, :, :h - h % 2, :w - w % 2]
        return (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] +
                x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) / 4

    def forward(self, spatial: torch.Tensor, inter=None) -> torch.Tensor:
        """Predict a depth map from a feature map.

        Args:
            spatial: Feature map of shape ``(B, feat_dim, H, W)``.
            inter: Unused by this head; accepted only so the call signature
                stays the same.

        Returns:
            Tensor of shape ``(B, 1, H, W)`` whose values lie between
            ``min_depth`` and ``max_depth`` (a convex combination of the bins).
        """
        target_size = spatial.shape[2:]
        f = spatial
        logits = None
        for i in range(self.n_scales):
            out = self.heads[i](f)
            out = F.interpolate(out, size=target_size, mode="bilinear", align_corners=False)
            logits = out if logits is None else logits + out
            if i < self.n_scales - 1:
                # Stays at the coarsest valid map once it can no longer be
                # halved, so the remaining heads still get a non-empty input.
                f = self.haar_down(f)
        # ReLU plus a constant floor keeps every bin weight strictly positive,
        # so the normalising sum below is never zero.
        dist = torch.relu(logits) + 0.1
        dist = dist / dist.sum(dim=1, keepdim=True)
        # Build the bin centres in the default dtype, then match the weights'
        # dtype so einsum also works for half/double precision models.
        bins = torch.linspace(
            self.min_depth, self.max_depth, self.n_bins, device=spatial.device
        ).to(dist.dtype)
        return torch.einsum("bkhw,k->bhw", dist, bins).unsqueeze(1)