"""Wavelet: Haar approximation (LL) pyramid + per-scale depth-bin prediction."""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class Wavelet(nn.Module):
    """Multi-scale depth head built on a Haar approximation (LL) pyramid.

    The input feature map is repeatedly downsampled with a 2x2 Haar
    approximation step (the mean of each 2x2 block). At every scale a 1x1
    convolution predicts ``n_bins`` logits. The logits of all scales are
    bilinearly resized back to the input resolution and summed. The sum is
    turned into a per-pixel distribution over ``n_bins`` uniformly spaced
    depth values, and the expected depth is returned.
    """

    name = "wavelet"
    needs_intermediates = False  # forward() does not use its ``inter`` argument

    def __init__(self, feat_dim=768, n_bins=256, min_depth=0.001, max_depth=10.0, n_scales=3):
        """Build one 1x1 conv head per pyramid level.

        Args:
            feat_dim: number of channels expected in the ``spatial`` input.
            n_bins: number of depth bins (output channels of every head).
            min_depth: first bin value (inclusive).
            max_depth: last bin value (inclusive).
            n_scales: number of pyramid levels; level 0 is the input itself.

        Raises:
            ValueError: if ``n_scales`` or ``n_bins`` is less than 1.
        """
        super().__init__()
        # Without at least one head, forward() would have no logits at all.
        if n_scales < 1:
            raise ValueError(f"n_scales must be >= 1, got {n_scales}")
        if n_bins < 1:
            raise ValueError(f"n_bins must be >= 1, got {n_bins}")
        self.n_scales = n_scales
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.heads = nn.ModuleList([nn.Conv2d(feat_dim, n_bins, 1) for _ in range(n_scales)])

    @staticmethod
    def haar_down(x):
        """Halve the last two dims by averaging each non-overlapping 2x2 block.

        This is the Haar approximation (LL) band scaled as a plain mean
        (sum / 4), equivalent to 2x2 average pooling with stride 2. A trailing
        odd row/column is dropped first so the 2x2 blocks tile exactly.
        """
        h, w = x.shape[2], x.shape[3]
        x = x[:, :, :h - h % 2, :w - w % 2]
        return (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] +
                x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) / 4

    def forward(self, spatial, inter=None):
        """Predict a depth map from a feature map.

        Args:
            spatial: 4-D feature tensor indexed as (batch, channels, H, W),
                where ``channels`` must equal ``feat_dim``. H and W should be
                at least ``2 ** (n_scales - 1)`` so no pyramid level is empty.
            inter: ignored; accepted for interface compatibility.

        Returns:
            Tensor of shape (batch, 1, H, W) holding the expected depth per
            pixel, in the same dtype as the heads' outputs.
        """
        target_size = spatial.shape[2:]
        feats = spatial
        logits = None
        last = len(self.heads) - 1
        for i, head in enumerate(self.heads):
            out = head(feats)
            # Level 0 is already at the target resolution. A bilinear resize at
            # scale 1 with align_corners=False is an exact copy, so skip it.
            if out.shape[2:] != target_size:
                out = F.interpolate(out, size=target_size, mode="bilinear", align_corners=False)
            logits = out if logits is None else logits + out
            if i < last:
                feats = self.haar_down(feats)
        # ReLU + 0.1 makes every bin weight strictly positive, so the
        # normalisation below never divides by zero.
        weights = torch.relu(logits) + 0.1
        probs = weights / weights.sum(dim=1, keepdim=True)
        # Match the bins' dtype to the probabilities. einsum does not promote
        # dtypes, so default-dtype (float32) bins would fail against
        # float64/half activations.
        bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins,
                              device=probs.device).to(dtype=probs.dtype)
        return torch.einsum("bkhw,k->bhw", probs, bins).unsqueeze(1)
|
|