ColamanAI's picture
Upload 169 files
b74998d verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Modified from https://github.com/facebookresearch/vggt
import torch
import torch.nn as nn
import torch.nn.functional as F
from .modules import AttnBlock, CrossAttnBlock, ResidualBlock
from .utils import bilinear_sampler
class BasicEncoder(nn.Module):
    """Multi-scale convolutional feature encoder.

    A 7x7 stride-2 stem is followed by four residual stages (strides 1, 2, 2, 2).
    Every stage output is bilinearly resized to ``(H // stride, W // stride)``,
    the four maps are concatenated along channels, and a 3x3 conv plus a 1x1
    conv fuse them into ``output_dim`` channels.

    Args:
        input_dim: number of channels of the input image/tensor.
        output_dim: number of channels of the returned feature map.
        stride: downsampling factor of the returned map w.r.t. the input size.
    """

    def __init__(self, input_dim=3, output_dim=128, stride=4):
        super(BasicEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = "instance"
        self.in_planes = output_dim // 2

        # Output widths of layer1..layer4. conv2 consumes their concatenation,
        # so its input width is the exact sum. The former closed form
        # ``output_dim * 3 + output_dim // 4`` only equals this sum when
        # output_dim is a multiple of 4.
        stage_dims = (
            output_dim // 2,
            output_dim // 4 * 3,
            output_dim,
            output_dim,
        )

        self.norm1 = nn.InstanceNorm2d(self.in_planes)
        self.norm2 = nn.InstanceNorm2d(output_dim * 2)
        self.conv1 = nn.Conv2d(
            input_dim,
            self.in_planes,
            kernel_size=7,
            stride=2,
            padding=3,
            padding_mode="zeros",
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(stage_dims[0], stride=1)
        self.layer2 = self._make_layer(stage_dims[1], stride=2)
        self.layer3 = self._make_layer(stage_dims[2], stride=2)
        self.layer4 = self._make_layer(stage_dims[3], stride=2)
        self.conv2 = nn.Conv2d(
            sum(stage_dims),
            output_dim * 2,
            kernel_size=3,
            padding=1,
            padding_mode="zeros",
        )
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.InstanceNorm2d)):
                # InstanceNorm2d is non-affine by default, so weight/bias may be None.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build two ResidualBlocks (the first may downsample) ending at ``dim`` channels.

        Updates ``self.in_planes`` so the next stage starts from ``dim`` channels.
        """
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` (a 4-D N, C, H, W tensor) into ``output_dim`` channels
        at spatial size ``(H // stride, W // stride)``."""
        _, _, H, W = x.shape

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        a = self.layer1(x)
        b = self.layer2(a)
        c = self.layer3(b)
        d = self.layer4(c)

        # Bring every stage to the common output resolution before fusing.
        a = _bilinear_intepolate(a, self.stride, H, W)
        b = _bilinear_intepolate(b, self.stride, H, W)
        c = _bilinear_intepolate(c, self.stride, H, W)
        d = _bilinear_intepolate(d, self.stride, H, W)

        x = self.conv2(torch.cat([a, b, c, d], dim=1))
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        return x
class ShallowEncoder(nn.Module):
    """Lightweight convolutional feature encoder.

    A 3x3 stride-2 stem is followed by two residual stages. Each stage output
    is resized back to the stem resolution and added to the running feature.
    A 1x1 residual conv follows, and the result is finally resized to
    ``(H // stride, W // stride)``.

    Args:
        input_dim: number of channels of the input tensor.
        output_dim: number of channels used throughout and returned.
        stride: downsampling factor of the returned map w.r.t. the input size.
        norm_fn: one of ``"group"``, ``"batch"``, ``"instance"`` or ``"none"``.

    Raises:
        ValueError: if ``norm_fn`` is not one of the supported values.
    """

    def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"):
        super(ShallowEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = norm_fn
        self.in_planes = output_dim

        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
            self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(self.in_planes)
            self.norm2 = nn.BatchNorm2d(output_dim * 2)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(self.in_planes)
            self.norm2 = nn.InstanceNorm2d(output_dim * 2)
        elif self.norm_fn == "none":
            # Empty Sequential acts as identity. norm2 is not used in forward;
            # it is defined so the attribute exists for every supported option.
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
        else:
            # Fail at construction instead of with an AttributeError on
            # self.norm1 the first time forward() runs.
            raise ValueError(
                f"Unsupported norm_fn {norm_fn!r}; "
                "expected 'group', 'batch', 'instance' or 'none'"
            )

        self.conv1 = nn.Conv2d(
            input_dim,
            self.in_planes,
            kernel_size=3,
            stride=2,
            padding=1,
            padding_mode="zeros",
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(output_dim, stride=2)
        self.layer2 = self._make_layer(output_dim, stride=2)
        self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build a single ResidualBlock mapping ``dim`` -> ``dim`` channels."""
        self.in_planes = dim
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        return layer1

    def forward(self, x):
        """Encode ``x`` (a 4-D N, C, H, W tensor) into ``output_dim`` channels
        at spatial size ``(H // stride, W // stride)``."""
        _, _, H, W = x.shape

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        # Each stage consumes the previous stage's raw output; its result is
        # resized to the stem resolution and accumulated into x.
        tmp = self.layer1(x)
        x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
        tmp = self.layer2(tmp)
        x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)

        x = self.conv2(x) + x

        x = F.interpolate(
            x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True
        )
        return x
def _bilinear_intepolate(x, stride, H, W):
return F.interpolate(
x, (H // stride, W // stride), mode="bilinear", align_corners=True
)
class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.

    Tokens are projected to ``hidden_size`` and then processed by
    ``time_depth`` temporal attention blocks (attention along T for each
    track). When ``add_space_attn`` is True, ``num_virtual_tracks`` learnable
    virtual tokens are appended along the track axis. After every
    ``time_depth // space_depth`` temporal blocks, one spatial stage runs:
    virtual tokens cross-attend to point tokens, self-attend, and then point
    tokens cross-attend back to them. The virtual tokens are dropped at the
    end, a residual to the projected input is added, and ``flow_head`` maps
    back to ``output_dim``.

    Args:
        space_depth: number of spatial (virtual-track) stages.
        time_depth: number of temporal attention blocks.
        input_dim: channel width of the input tokens.
        hidden_size: internal token width.
        num_heads: attention heads per block.
        output_dim: channel width of the returned tokens.
        mlp_ratio: MLP expansion ratio forwarded to the attention blocks.
        add_space_attn: enable the spatial/virtual-track stages.
        num_virtual_tracks: number of learnable virtual tokens.

    Raises:
        ValueError: if ``add_space_attn`` is set, ``time_depth > 0`` and
            ``space_depth`` is not a positive divisor of ``time_depth``. Such
            schedules previously failed inside ``forward`` (ZeroDivisionError
            or an out-of-range index into the spatial block lists).
    """

    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        add_space_attn=True,
        num_virtual_tracks=64,
    ):
        super().__init__()

        # forward() runs spatial stage j on every (time_depth // space_depth)-th
        # temporal block; this only visits exactly space_depth stages when
        # space_depth divides time_depth.
        if add_space_attn and time_depth > 0 and (
            space_depth < 1 or time_depth % space_depth != 0
        ):
            raise ValueError(
                f"space_depth ({space_depth}) must be a positive divisor of "
                f"time_depth ({time_depth}) when add_space_attn is True"
            )

        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.add_space_attn = add_space_attn
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
        self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
        self.num_virtual_tracks = num_virtual_tracks

        if self.add_space_attn:
            # Learnable virtual tokens, broadcast over batch and time in forward().
            self.virual_tracks = nn.Parameter(
                torch.randn(1, num_virtual_tracks, 1, hidden_size)
            )
        else:
            self.virual_tracks = None

        self.time_blocks = nn.ModuleList(
            [
                AttnBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_class=nn.MultiheadAttention,
                )
                for _ in range(time_depth)
            ]
        )

        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(
                        hidden_size,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        attn_class=nn.MultiheadAttention,
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(
                        hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(
                        hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
                    )
                    for _ in range(space_depth)
                ]
            )
            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)

        self.initialize_weights()

    def initialize_weights(self):
        """Weight-initialization hook.

        NOTE(review): neither helper below is applied (there is no
        ``self.apply(...)``), so this method currently leaves every module at
        its PyTorch default init. Applying one would change training-from-scratch
        behavior; confirm which (if any) was intended before wiring it up.
        """

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        def init_weights_vit_timm(module: nn.Module, name: str = ""):
            """ViT weight initialization, original timm impl (for reproducibility)"""
            if isinstance(module, nn.Linear):
                torch.nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, input_tensor, mask=None):
        """Update track tokens.

        Args:
            input_tensor: 4-D tokens laid out as (B, N, T, input_dim), where N is
                the number of tracks and T the number of frames.
            mask: optional mask forwarded unchanged to the cross-attention
                blocks; its expected shape is defined by CrossAttnBlock.

        Returns:
            Tensor of shape (B, N, T, output_dim) for the N real tracks.
        """
        tokens = self.input_transform(input_tensor)
        init_tokens = tokens

        B, _, T, _ = tokens.shape

        if self.add_space_attn:
            # Append the virtual tracks after the real ones along the track axis.
            virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
            tokens = torch.cat([tokens, virtual_tokens], dim=1)

        _, N, _, _ = tokens.shape

        j = 0
        for i in range(len(self.time_blocks)):
            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C
            time_tokens = self.time_blocks[i](time_tokens)
            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C

            # One spatial stage every len(time_blocks) // len(space_virtual_blocks)
            # temporal blocks; __init__ guarantees j stays within range.
            if self.add_space_attn and (
                i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0
            ):
                space_tokens = (
                    tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
                )  # B N T C -> (B T) N C
                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]

                virtual_tokens = self.space_virtual2point_blocks[j](
                    virtual_tokens, point_tokens, mask=mask
                )
                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                point_tokens = self.space_point2virtual_blocks[j](
                    point_tokens, virtual_tokens, mask=mask
                )

                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(
                    0, 2, 1, 3
                )  # (B T) N C -> B N T C
                j += 1

        if self.add_space_attn:
            # Drop the virtual tracks; only real tracks are returned.
            tokens = tokens[:, : N - self.num_virtual_tracks]

        tokens = tokens + init_tokens

        flow = self.flow_head(tokens)
        return flow
class CorrBlock:
    """Multi-level correlation between per-track features and feature maps.

    At construction an average-pooled pyramid of the feature maps is built
    (each subsequent level is 2x2 average-pooled with stride 2). ``corr``
    correlates per-track target features against every level and stores the
    results. ``sample`` then reads a (2r+1) x (2r+1) window of each stored
    correlation map around given coordinates.

    Args:
        fmaps: feature maps of shape (B, S, C, H, W).
        num_levels: number of pyramid levels.
        radius: sampling window radius r.
        multiple_track_feats: if True, ``corr`` splits the target channels
            into chunks of ``C // num_levels`` and uses chunk i for level i.
        padding_mode: forwarded to ``bilinear_sampler``.
    """

    def __init__(
        self,
        fmaps,
        num_levels=4,
        radius=4,
        multiple_track_feats=False,
        padding_mode="zeros",
    ):
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.padding_mode = padding_mode
        self.num_levels = num_levels
        self.radius = radius
        self.fmaps_pyramid = []
        self.multiple_track_feats = multiple_track_feats

        self.fmaps_pyramid.append(fmaps)
        for i in range(self.num_levels - 1):
            fmaps_ = fmaps.reshape(B * S, C, H, W)
            fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
            _, _, H, W = fmaps_.shape
            fmaps = fmaps_.reshape(B, S, C, H, W)
            self.fmaps_pyramid.append(fmaps)

    def sample(self, coords):
        """Sample a window of every correlation level around ``coords``.

        Must be called after ``corr``, which populates ``self.corrs_pyramid``.

        Args:
            coords: tensor of shape (B, S, N, 2) with positions at the finest
                level; level i samples around ``coords / 2**i``. The coordinate
                convention is whatever ``bilinear_sampler`` expects (TODO confirm).

        Returns:
            Tensor of shape (B, S, N, K): the per-level samples concatenated on
            the last axis. K is presumably num_levels * (2r+1)**2, assuming
            bilinear_sampler returns one value per window location.
        """
        r = self.radius
        B, S, N, D = coords.shape
        assert D == 2

        # The window offsets do not depend on the level: build them once
        # instead of once per pyramid level.
        dx = torch.linspace(-r, r, 2 * r + 1)
        dy = torch.linspace(-r, r, 2 * r + 1)
        delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
            coords.device
        )
        delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            corrs = self.corrs_pyramid[i]  # B, S, N, H, W
            *_, H, W = corrs.shape

            # Coordinates shrink by 2 per level to match the pooled maps.
            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
            coords_lvl = centroid_lvl + delta_lvl

            corrs = bilinear_sampler(
                corrs.reshape(B * S * N, 1, H, W),
                coords_lvl,
                padding_mode=self.padding_mode,
            )
            corrs = corrs.view(B, S, N, -1)
            out_pyramid.append(corrs)

        out = torch.cat(out_pyramid, dim=-1).contiguous()
        return out

    def corr(self, targets):
        """Compute and store correlation maps for every pyramid level.

        Args:
            targets: per-track features of shape (B, S, N, C). Without
                ``multiple_track_feats``, C must equal the feature-map channel
                count; with it, each chunk of size ``C // num_levels`` must.

        Side effects:
            Sets ``self.corrs_pyramid`` to a list with one (B, S, N, H_l, W_l)
            tensor per level, each scaled by ``1 / sqrt(C)``.
        """
        B, S, N, C = targets.shape
        if self.multiple_track_feats:
            targets_split = targets.split(C // self.num_levels, dim=-1)
            B, S, N, C = targets_split[0].shape

        assert C == self.C
        assert S == self.S

        # Same normalizer for every level; compute it once.
        scale = torch.sqrt(torch.tensor(C).float())

        fmap1 = targets
        self.corrs_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            *_, H, W = fmaps.shape
            fmap2s = fmaps.view(B, S, C, H * W)  # B S C H W -> B S C (H W)
            if self.multiple_track_feats:
                fmap1 = targets_split[i]
            corrs = torch.matmul(fmap1, fmap2s)
            corrs = corrs.view(B, S, N, H, W)  # B S N (H W) -> B S N H W
            corrs = corrs / scale
            self.corrs_pyramid.append(corrs)