Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # | |
| # Modified from https://github.com/facebookresearch/vggt | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .modules import AttnBlock, CrossAttnBlock, ResidualBlock | |
| from .utils import bilinear_sampler | |
class BasicEncoder(nn.Module):
    """Multi-scale CNN encoder that fuses four residual stages into one map.

    A 7x7 stride-2 stem is followed by four two-block residual stages (the
    first keeps resolution, the other three downsample by 2). Each stage
    output is bilinearly resized to ``(H // stride, W // stride)``. The four
    maps are concatenated along channels and fused by a 3x3 conv + instance
    norm + ReLU, then a 1x1 conv projects to ``output_dim`` channels.

    Args:
        input_dim: number of channels of the input tensor.
        output_dim: number of channels of the returned feature map.
        stride: output resolution is the input resolution divided by this.
    """

    def __init__(self, input_dim=3, output_dim=128, stride=4):
        super(BasicEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = "instance"
        # Running input-channel count; _make_layer reads and then updates it.
        self.in_planes = output_dim // 2
        self.norm1 = nn.InstanceNorm2d(self.in_planes)
        self.norm2 = nn.InstanceNorm2d(output_dim * 2)
        self.conv1 = nn.Conv2d(
            input_dim,
            self.in_planes,
            kernel_size=7,
            stride=2,
            padding=3,
            padding_mode="zeros",
        )
        self.relu1 = nn.ReLU(inplace=True)

        # Output width of each residual stage (assumes ResidualBlock(in, dim)
        # emits `dim` channels, as the stage chaining below relies on).
        stage_dims = (output_dim // 2, output_dim // 4 * 3, output_dim, output_dim)
        self.layer1 = self._make_layer(stage_dims[0], stride=1)
        self.layer2 = self._make_layer(stage_dims[1], stride=2)
        self.layer3 = self._make_layer(stage_dims[2], stride=2)
        self.layer4 = self._make_layer(stage_dims[3], stride=2)

        # conv2 consumes the channel-wise concatenation of all four stage
        # outputs, so its input width is the sum of the stage widths. For
        # output_dim divisible by 4 this equals output_dim * 3 + output_dim // 4;
        # summing explicitly keeps other output_dim values consistent too.
        self.conv2 = nn.Conv2d(
            sum(stage_dims),
            output_dim * 2,
            kernel_size=3,
            padding=1,
            padding_mode="zeros",
        )
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.InstanceNorm2d)):
                # InstanceNorm2d is non-affine by default, so weight/bias may be None.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build a stage of two ResidualBlocks ending at ``dim`` channels.

        Only the first block uses ``stride``. Side effect: sets
        ``self.in_planes = dim`` so the next stage chains from this one.
        """
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` of shape (B, input_dim, H, W).

        Returns a (B, output_dim, H // stride, W // stride) feature map.
        """
        _, _, H, W = x.shape
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        a = self.layer1(x)
        b = self.layer2(a)
        c = self.layer3(b)
        d = self.layer4(c)

        # Bring every stage to the common output resolution before fusing.
        a = _bilinear_intepolate(a, self.stride, H, W)
        b = _bilinear_intepolate(b, self.stride, H, W)
        c = _bilinear_intepolate(c, self.stride, H, W)
        d = _bilinear_intepolate(d, self.stride, H, W)

        x = self.conv2(torch.cat([a, b, c, d], dim=1))
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        return x
class ShallowEncoder(nn.Module):
    """Lightweight CNN encoder with two residual stages fused by upsample-and-add.

    A 3x3 stride-2 stem is followed by two stride-2 ResidualBlocks. Each
    block's output is bilinearly upsampled back to the stem resolution and
    added to the running map. A 1x1 conv residual follows, and the result is
    resized to ``(H // stride, W // stride)``.

    Args:
        input_dim: number of channels of the input tensor.
        output_dim: channel width of every intermediate map and of the output.
        stride: output resolution is the input resolution divided by this.
        norm_fn: one of ``"group"``, ``"batch"``, ``"instance"``, ``"none"``;
            also forwarded to each ResidualBlock.

    Raises:
        ValueError: if ``norm_fn`` is not one of the supported values.
    """

    def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"):
        super(ShallowEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = norm_fn
        self.in_planes = output_dim

        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
            self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(self.in_planes)
            self.norm2 = nn.BatchNorm2d(output_dim * 2)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(self.in_planes)
            self.norm2 = nn.InstanceNorm2d(output_dim * 2)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()
            # Parameter-free identity so every branch defines the same attributes.
            self.norm2 = nn.Sequential()
        else:
            # Without this, norm1 would be missing and forward() would fail
            # later with an unhelpful AttributeError.
            raise ValueError(
                f"Unsupported norm_fn {norm_fn!r}; expected one of "
                "'group', 'batch', 'instance', 'none'"
            )

        self.conv1 = nn.Conv2d(
            input_dim,
            self.in_planes,
            kernel_size=3,
            stride=2,
            padding=1,
            padding_mode="zeros",
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(output_dim, stride=2)
        self.layer2 = self._make_layer(output_dim, stride=2)
        self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Return a single ResidualBlock mapping ``dim`` -> ``dim`` channels."""
        self.in_planes = dim
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        return layer1

    def forward(self, x):
        """Encode ``x`` of shape (B, input_dim, H, W).

        Returns a (B, output_dim, H // stride, W // stride) feature map.
        """
        _, _, H, W = x.shape
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        # layer2 consumes layer1's (downsampled) output, not the fused map;
        # each stage is upsampled back to x's resolution and accumulated.
        tmp = self.layer1(x)
        x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
        tmp = self.layer2(tmp)
        x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
        del tmp  # release the stage output before the final conv

        x = self.conv2(x) + x
        x = F.interpolate(
            x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True
        )
        return x
| def _bilinear_intepolate(x, stride, H, W): | |
| return F.interpolate( | |
| x, (H // stride, W // stride), mode="bilinear", align_corners=True | |
| ) | |
class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.

    Tokens are laid out as (B, N, T, C): N tracks over T time steps. Each of
    the ``time_depth`` iterations runs self-attention along T for every track.
    When ``add_space_attn`` is True, ``num_virtual_tracks`` learned tokens are
    appended along N. Every ``time_depth // space_depth`` iterations, a space
    stage runs per frame: virtual tokens are updated from the point tokens,
    then refined with self-attention, and point tokens are updated from the
    virtual tokens. The virtual tokens are dropped before the output head, and
    a residual from the projected input is added.

    Args:
        space_depth: number of space stages. Must be >= 1 and divide
            ``time_depth`` when ``add_space_attn`` is True and ``time_depth > 0``.
        time_depth: number of temporal attention blocks.
        input_dim: channel size of ``input_tensor``.
        hidden_size: transformer width.
        num_heads: attention heads per block.
        output_dim: channel size of the returned tensor.
        mlp_ratio: MLP expansion ratio forwarded to the attention blocks.
        add_space_attn: enable the virtual-track space attention path.
        num_virtual_tracks: number of learned virtual tracks.

    Raises:
        ValueError: if the space/time depth combination cannot be scheduled.
    """

    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        add_space_attn=True,
        num_virtual_tracks=64,
    ):
        super().__init__()
        # forward() triggers a space stage whenever i % (time_depth // space_depth)
        # == 0. Unless space_depth divides time_depth, that schedule needs more
        # space stages than exist (IndexError), and space_depth == 0 divides by
        # zero. Reject such configurations here instead of mid-forward.
        if add_space_attn and time_depth > 0 and (
            space_depth < 1 or time_depth % space_depth != 0
        ):
            raise ValueError(
                "space_depth must be >= 1 and divide time_depth when "
                f"add_space_attn=True (got time_depth={time_depth}, "
                f"space_depth={space_depth})"
            )

        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.add_space_attn = add_space_attn
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
        self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
        self.num_virtual_tracks = num_virtual_tracks

        # Attribute name (including its spelling) is kept for checkpoint compatibility.
        if self.add_space_attn:
            self.virual_tracks = nn.Parameter(
                torch.randn(1, num_virtual_tracks, 1, hidden_size)
            )
        else:
            self.virual_tracks = None

        self.time_blocks = nn.ModuleList(
            [
                AttnBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_class=nn.MultiheadAttention,
                )
                for _ in range(time_depth)
            ]
        )

        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(
                        hidden_size,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        attn_class=nn.MultiheadAttention,
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(
                        hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(
                        hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
                    )
                    for _ in range(space_depth)
                ]
            )
            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
        self.initialize_weights()

    def initialize_weights(self):
        """Hook for custom weight initialisation.

        NOTE(review): both helpers below are defined but never applied (there
        is no ``self.apply(...)`` call), so this method leaves PyTorch's
        default initialisation in place. Applying either would change the
        initial weights of freshly constructed models; confirm intent first.
        """

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        def init_weights_vit_timm(module: nn.Module, name: str = ""):
            """ViT weight initialization, original timm impl (for reproducibility)"""
            if isinstance(module, nn.Linear):
                torch.nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, input_tensor, mask=None):
        """Update track features.

        Args:
            input_tensor: (B, N, T, input_dim) per-track, per-frame features.
            mask: optional; passed unchanged to every CrossAttnBlock call.

        Returns:
            (B, N, T, output_dim) tensor from ``flow_head``.
        """
        tokens = self.input_transform(input_tensor)
        init_tokens = tokens
        B, _, T, _ = tokens.shape

        if self.add_space_attn:
            virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
            tokens = torch.cat([tokens, virtual_tokens], dim=1)

        _, N, _, _ = tokens.shape  # N includes the virtual tracks, if any
        j = 0  # index of the next space stage
        for i in range(len(self.time_blocks)):
            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C
            time_tokens = self.time_blocks[i](time_tokens)
            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C

            if self.add_space_attn and (
                i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0
            ):
                space_tokens = (
                    tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
                )  # B N T C -> (B T) N C
                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]

                virtual_tokens = self.space_virtual2point_blocks[j](
                    virtual_tokens, point_tokens, mask=mask
                )
                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                point_tokens = self.space_point2virtual_blocks[j](
                    point_tokens, virtual_tokens, mask=mask
                )

                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(
                    0, 2, 1, 3
                )  # (B T) N C -> B N T C
                j += 1

        if self.add_space_attn:
            # Drop the virtual tracks so the residual lines up with init_tokens.
            tokens = tokens[:, : N - self.num_virtual_tracks]

        tokens = tokens + init_tokens
        flow = self.flow_head(tokens)
        return flow
class CorrBlock:
    """Multi-level correlation between per-track features and feature maps.

    ``fmaps`` of shape (B, S, C, H, W) is average-pooled by 2 a total of
    ``num_levels - 1`` times to build a pyramid. ``corr`` correlates per-track
    target features with every pyramid level. ``sample`` then reads a
    ``(2 * radius + 1)`` x ``(2 * radius + 1)`` window from each level's
    correlation map around the given coordinates.
    """

    def __init__(
        self,
        fmaps,
        num_levels=4,
        radius=4,
        multiple_track_feats=False,
        padding_mode="zeros",
    ):
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.padding_mode = padding_mode
        self.num_levels = num_levels
        self.radius = radius
        self.multiple_track_feats = multiple_track_feats

        self.fmaps_pyramid = [fmaps]
        for _ in range(self.num_levels - 1):
            fmaps_ = fmaps.reshape(B * S, C, H, W)
            fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
            _, _, H, W = fmaps_.shape
            fmaps = fmaps_.reshape(B, S, C, H, W)
            self.fmaps_pyramid.append(fmaps)

    def sample(self, coords):
        """Sample local correlation windows around ``coords`` at every level.

        Must be called after ``corr`` (it reads ``self.corrs_pyramid``).

        Args:
            coords: (B, S, N, 2) centres in level-0 coordinates; they are
                divided by ``2**i`` for pyramid level ``i``.

        Returns:
            (B, S, N, K) tensor with the per-level samples concatenated on the
            last axis. K is presumably ``num_levels * (2r + 1)**2``; this
            depends on ``bilinear_sampler``'s output layout.
        """
        r = self.radius
        B, S, N, D = coords.shape
        assert D == 2

        # Integer offsets -r..r on each axis; identical for every level.
        side = 2 * r + 1
        dx = torch.linspace(-r, r, side)
        dy = torch.linspace(-r, r, side)
        delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1).to(
            coords.device
        )
        delta_lvl = delta.view(1, side, side, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            corrs = self.corrs_pyramid[i]  # B, S, N, H, W
            *_, H, W = corrs.shape

            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
            coords_lvl = centroid_lvl + delta_lvl

            corrs = bilinear_sampler(
                corrs.reshape(B * S * N, 1, H, W),
                coords_lvl,
                padding_mode=self.padding_mode,
            )
            corrs = corrs.view(B, S, N, -1)
            out_pyramid.append(corrs)

        out = torch.cat(out_pyramid, dim=-1).contiguous()
        return out

    def corr(self, targets):
        """Compute scaled dot-product correlation maps for every pyramid level.

        Args:
            targets: (B, S, N, C) per-track features. With
                ``multiple_track_feats`` the last axis is split into
                ``num_levels`` chunks, one per level, each of width ``self.C``.

        Stores a list of (B, S, N, H_l, W_l) tensors in ``self.corrs_pyramid``,
        each divided by ``sqrt(C)``.
        """
        B, S, N, C = targets.shape
        if self.multiple_track_feats:
            targets_split = targets.split(C // self.num_levels, dim=-1)
            B, S, N, C = targets_split[0].shape

        assert C == self.C
        assert S == self.S

        scale = torch.sqrt(torch.tensor(C).float())  # same for every level
        fmap1 = targets
        self.corrs_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            *_, H, W = fmaps.shape
            # reshape, not view: level 0 is the caller's tensor, which may be
            # non-contiguous (e.g. produced by a permute).
            fmap2s = fmaps.reshape(B, S, C, H * W)  # B S C H W -> B S C (H W)
            if self.multiple_track_feats:
                fmap1 = targets_split[i]
            corrs = torch.matmul(fmap1, fmap2s)
            corrs = corrs.view(B, S, N, H, W)  # B S N (H W) -> B S N H W
            corrs = corrs / scale
            self.corrs_pyramid.append(corrs)