|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from functools import partial |
|
|
from typing import Callable |
|
|
import collections |
|
|
from torch import Tensor |
|
|
from itertools import repeat |
|
|
|
|
|
from .utils import bilinear_sampler |
|
|
|
|
|
from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock |
|
|
|
|
|
|
|
|
class BasicEncoder(nn.Module):
    """Multi-scale residual CNN feature encoder.

    A 7x7 stride-2 convolutional stem is followed by four residual stages
    (built with strides 1, 2, 2, 2). The output of every stage is bilinearly
    resized to ``(H // stride, W // stride)``, the four maps are concatenated
    along the channel axis, and a 3x3 conv + norm + ReLU + 1x1 conv head fuses
    them into ``output_dim`` channels.

    Args:
        input_dim: number of channels of the input tensor.
        output_dim: number of channels of the output feature map.
        stride: divisor applied to the input ``H, W`` to get the output size.
    """

    def __init__(self, input_dim=3, output_dim=128, stride=4):
        super(BasicEncoder, self).__init__()

        self.stride = stride
        self.norm_fn = "instance"
        self.in_planes = output_dim // 2

        self.norm1 = nn.InstanceNorm2d(self.in_planes)
        self.norm2 = nn.InstanceNorm2d(output_dim * 2)

        self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros")
        self.relu1 = nn.ReLU(inplace=True)

        # Channel widths of layer1..layer4; _make_layer chains them through
        # self.in_planes, so construction order matters.
        stage_widths = (output_dim // 2, output_dim // 4 * 3, output_dim, output_dim)
        self.layer1 = self._make_layer(stage_widths[0], stride=1)
        self.layer2 = self._make_layer(stage_widths[1], stride=2)
        self.layer3 = self._make_layer(stage_widths[2], stride=2)
        self.layer4 = self._make_layer(stage_widths[3], stride=2)

        # conv2 consumes the channel-wise concatenation of all four stage
        # outputs, so its input width is the sum of the stage widths. The former
        # closed form (output_dim * 3 + output_dim // 4) only equals this sum
        # when output_dim is divisible by 4.
        self.conv2 = nn.Conv2d(
            sum(stage_widths), output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros"
        )
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.InstanceNorm2d):
                # InstanceNorm2d is non-affine by default, so weight/bias may be None.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build one stage: two ResidualBlocks mapping ``self.in_planes`` -> ``dim``.

        Only the first block uses ``stride``. ``self.in_planes`` is updated to
        ``dim`` afterwards so the next stage starts from this stage's width.
        """
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        self.in_planes = dim
        return nn.Sequential(layer1, layer2)

    def forward(self, x):
        """Encode ``x``.

        Args:
            x: 4-D tensor ``(B, input_dim, H, W)``.

        Returns:
            Tensor with ``output_dim`` channels at resolution
            ``(H // stride, W // stride)``.
        """
        _, _, H, W = x.shape

        x = self.relu1(self.norm1(self.conv1(x)))

        a = self.layer1(x)
        b = self.layer2(a)
        c = self.layer3(b)
        d = self.layer4(c)

        # Bring every stage to the common output resolution before fusing.
        a = _bilinear_intepolate(a, self.stride, H, W)
        b = _bilinear_intepolate(b, self.stride, H, W)
        c = _bilinear_intepolate(c, self.stride, H, W)
        d = _bilinear_intepolate(d, self.stride, H, W)

        x = self.conv2(torch.cat([a, b, c, d], dim=1))
        x = self.relu2(self.norm2(x))
        return self.conv3(x)
|
|
|
|
|
|
|
|
class ShallowEncoder(nn.Module):
    """Lightweight convolutional feature encoder.

    A 3x3 stride-2 convolutional stem is followed by two residual stages
    (built with stride 2). Each stage output is bilinearly resized back to the
    stem resolution and added to the running feature map. A 1x1 conv is added
    residually, and the result is resized to ``(H // stride, W // stride)``.

    Args:
        input_dim: number of channels of the input tensor.
        output_dim: channel width used throughout the encoder and of the output.
        stride: divisor applied to the input ``H, W`` to get the output size.
        norm_fn: one of ``"group"``, ``"batch"``, ``"instance"`` or ``"none"``.

    Raises:
        ValueError: if ``norm_fn`` is not one of the supported values.
    """

    def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"):
        super(ShallowEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = norm_fn
        self.in_planes = output_dim

        # norm2 is not used by forward(); it is kept so module/parameter names
        # (and hence state_dict keys) are unchanged.
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
            self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(self.in_planes)
            self.norm2 = nn.BatchNorm2d(output_dim * 2)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(self.in_planes)
            self.norm2 = nn.InstanceNorm2d(output_dim * 2)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
        else:
            # Fail fast: otherwise norm1 would be missing and forward() would
            # break with an unrelated AttributeError.
            raise ValueError(
                f"Unsupported norm_fn {norm_fn!r}; expected 'group', 'batch', 'instance' or 'none'"
            )

        self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=3, stride=2, padding=1, padding_mode="zeros")
        self.relu1 = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(output_dim, stride=2)
        self.layer2 = self._make_layer(output_dim, stride=2)
        self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build a single ResidualBlock with ``dim`` input and output channels.

        ``self.in_planes`` is set to ``dim`` *before* building the block. The
        residual sums in forward() require every stage to keep the stem width.
        """
        self.in_planes = dim
        return ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)

    def forward(self, x):
        """Encode ``x``.

        Args:
            x: 4-D tensor ``(B, input_dim, H, W)``.

        Returns:
            Tensor with ``output_dim`` channels at resolution
            ``(H // stride, W // stride)``.
        """
        _, _, H, W = x.shape

        x = self.relu1(self.norm1(self.conv1(x)))

        # Stages are built with stride 2, so resize each output back to the stem
        # resolution before accumulating it into x.
        tmp = self.layer1(x)
        x = x + F.interpolate(tmp, x.shape[-2:], mode="bilinear", align_corners=True)
        tmp = self.layer2(tmp)
        x = x + F.interpolate(tmp, x.shape[-2:], mode="bilinear", align_corners=True)

        x = self.conv2(x) + x

        return F.interpolate(x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True)
|
|
|
|
|
|
|
|
def _bilinear_intepolate(x, stride, H, W): |
|
|
return F.interpolate(x, (H // stride, W // stride), mode="bilinear", align_corners=True) |
|
|
|
|
|
|
|
|
class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.

    Input tokens ``(B, N, T, input_dim)`` are projected to ``hidden_size``.
    They then pass through ``time_depth`` attention blocks over the T axis
    (one sequence per entry of N). When ``add_space_attn`` is set,
    ``num_virtual_tracks`` learned tokens are appended along the N axis, and
    ``space_depth`` rounds of attention across that axis are interleaved,
    routed through the virtual tokens. The result plus the initial projection
    goes through a linear head to ``output_dim``.
    """

    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        add_space_attn=True,
        num_virtual_tracks=64,
    ):
        super().__init__()

        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.add_space_attn = add_space_attn
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
        self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
        self.num_virtual_tracks = num_virtual_tracks

        # The attribute name "virual_tracks" (sic) is kept: renaming it would
        # change the state_dict key.
        if self.add_space_attn:
            self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
        else:
            self.virual_tracks = None

        self.time_blocks = nn.ModuleList(
            [
                AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
                for _ in range(time_depth)
            ]
        )

        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
            )
            # Each space round is attached to a time block, so there cannot be
            # more space rounds than time blocks. (Raise rather than assert so
            # the check survives `python -O`.)
            if len(self.time_blocks) < len(self.space_virtual2point_blocks):
                raise ValueError(
                    f"time_depth ({len(self.time_blocks)}) must be >= "
                    f"space_depth ({len(self.space_virtual2point_blocks)})"
                )

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-uniform init for every ``nn.Linear`` weight in the model; zero biases.

        NOTE: the helper used to be defined but never applied (and an unused
        timm-style variant referenced an undefined ``trunc_normal_``), so this
        method silently did nothing.
        """

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def forward(self, input_tensor, mask=None):
        """Update track tokens.

        Args:
            input_tensor: 4-D tensor ``(B, N, T, input_dim)``.
            mask: optional mask, forwarded unchanged to the cross-attention blocks.

        Returns:
            Tensor ``(B, N, T, output_dim)``.
        """
        tokens = self.input_transform(input_tensor)
        init_tokens = tokens

        B, _, T, _ = tokens.shape

        if self.add_space_attn:
            virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
            tokens = torch.cat([tokens, virtual_tokens], dim=1)

        _, N, _, _ = tokens.shape

        # Space attention runs after every `space_every`-th time block. The
        # `j < num_space` guard prevents an IndexError when time_depth is not a
        # multiple of space_depth. With num_space == 0 no space rounds run
        # (previously a ZeroDivisionError).
        num_space = len(self.space_virtual_blocks) if self.add_space_attn else 0
        space_every = len(self.time_blocks) // num_space if num_space else 0

        j = 0
        for i, time_block in enumerate(self.time_blocks):
            time_tokens = tokens.contiguous().view(B * N, T, -1)
            time_tokens = time_block(time_tokens)
            tokens = time_tokens.view(B, N, T, -1)

            if space_every and i % space_every == 0 and j < num_space:
                space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
                # The real tokens come first along N; the virtual ones were appended last.
                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]

                virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)

                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3)
                j += 1

        if self.add_space_attn:
            # Drop the virtual tokens before the output head.
            tokens = tokens[:, : N - self.num_virtual_tracks]

        tokens = tokens + init_tokens

        flow = self.flow_head(tokens)
        return flow
|
|
|
|
|
|
|
|
class CorrBlock:
    """Multi-scale correlation between per-track features and feature maps.

    ``fmaps`` is ``(B, S, C, H, W)``. A pyramid of ``num_levels`` maps is built
    by repeated 2x average pooling. :meth:`corr` correlates target features
    with every pyramid level. :meth:`sample` then reads a ``(2r+1) x (2r+1)``
    window of those correlations around the given coordinates at each level.
    """

    def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.padding_mode = padding_mode
        self.num_levels = num_levels
        self.radius = radius
        self.multiple_track_feats = multiple_track_feats
        # Populated by corr(); sample() requires it.
        self.corrs_pyramid = None

        self.fmaps_pyramid = [fmaps]
        for _ in range(self.num_levels - 1):
            fmaps_ = F.avg_pool2d(fmaps.reshape(B * S, C, H, W), 2, stride=2)
            _, _, H, W = fmaps_.shape
            fmaps = fmaps_.reshape(B, S, C, H, W)
            self.fmaps_pyramid.append(fmaps)

    def sample(self, coords):
        """Sample correlation windows around ``coords`` at every pyramid level.

        Args:
            coords: ``(B, S, N, 2)`` positions in the level-0 frame. They are
                divided by ``2**i`` for pyramid level ``i``.

        Returns:
            ``(B, S, N, K)`` tensor with all levels concatenated on the last dim
            (presumably ``K = num_levels * (2r+1)**2``; depends on
            ``bilinear_sampler``'s output layout).

        Raises:
            RuntimeError: if called before :meth:`corr`.
        """
        if self.corrs_pyramid is None:
            raise RuntimeError("CorrBlock.sample() called before corr()")

        r = self.radius
        B, S, N, D = coords.shape
        assert D == 2

        # The offset grid is identical for every level: build it once, directly
        # on the coords device.
        dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
        dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
        delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)
        delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            corrs = self.corrs_pyramid[i]
            *_, H, W = corrs.shape

            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
            coords_lvl = centroid_lvl + delta_lvl

            corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode)
            out_pyramid.append(corrs.view(B, S, N, -1))

        return torch.cat(out_pyramid, dim=-1).contiguous()

    def corr(self, targets):
        """Correlate ``targets`` with every pyramid level and store the result.

        Args:
            targets: ``(B, S, N, C_t)`` features. ``C_t`` must equal ``C``, or
                ``num_levels * C`` when ``multiple_track_feats`` is set (level
                ``i`` then uses the ``i``-th chunk of size ``C_t // num_levels``).

        Side effects:
            Sets ``self.corrs_pyramid`` to a list of ``(B, S, N, H_i, W_i)``
            tensors, each divided by ``sqrt(C)``.
        """
        B, S, N, C = targets.shape
        if self.multiple_track_feats:
            targets_split = targets.split(C // self.num_levels, dim=-1)
            B, S, N, C = targets_split[0].shape

        assert C == self.C
        assert S == self.S

        scale = torch.sqrt(torch.tensor(C).float())
        fmap1 = targets

        self.corrs_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            *_, H, W = fmaps.shape
            # reshape (not view): the caller's level-0 fmaps may be non-contiguous.
            fmap2s = fmaps.reshape(B, S, C, H * W)
            if self.multiple_track_feats:
                fmap1 = targets_split[i]
            corrs = torch.matmul(fmap1, fmap2s).view(B, S, N, H, W)
            self.corrs_pyramid.append(corrs / scale)
|
|
|