| from torch import nn
|
| import torch.nn.functional as F
|
| from einops import rearrange
|
|
|
| from .tapir_utils.tapir_model import TAPIR
|
|
|
class Tapir(nn.Module):
    """Adapter around the project's ``TAPIR`` model.

    Converts the caller's video/query layout into the one passed to the
    wrapped model and turns the model's raw outputs into per-frame tracks
    and boolean visibility flags.
    """

    def __init__(self, args):
        """Build the wrapped ``TAPIR`` model.

        Args:
            args: Configuration object; the attributes ``pyramid_level``,
                ``softmax_temperature`` and ``extra_convs`` are read from it
                and forwarded to ``TAPIR`` unchanged.
        """
        super().__init__()
        self.model = TAPIR(
            pyramid_level=args.pyramid_level,
            softmax_temperature=args.softmax_temperature,
            extra_convs=args.extra_convs,
        )

    def forward(self, video, queries, backward_tracking, cache_features=False):
        """Run the wrapped model and return tracks plus visibility.

        Args:
            video: Tensor laid out as ``(b, t, c, h, w)``.
                NOTE(review): the ``* 2 - 1`` rescaling suggests values are
                expected in [0, 1] -- confirm against the data loader.
            queries: Tensor whose last axis has (at least) three components;
                components 1 and 2 are swapped before the model call.
                NOTE(review): presumably (t, x, y) -> (t, y, x) -- confirm.
            backward_tracking: Accepted for interface compatibility; it is
                not used by this method.
            cache_features: Forwarded to the wrapped model as-is.

        Returns:
            A ``(tracks, visibles)`` tuple. ``tracks`` is the model's
            ``'tracks'`` output with axes 1 and 2 swapped
            (``b s t c -> b t s c``); ``visibles`` is a boolean tensor with
            the same axis swap applied (``b s t -> b t s``).
        """
        # --- Video preprocessing: rescale, then move channels last. ---
        scaled_video = video * 2 - 1
        channels_last = rearrange(scaled_video, "b t c h w -> b t h w c")

        # --- Query preprocessing: swap the last two of the three components. ---
        reordered_queries = queries[..., [0, 2, 1]]

        # --- Inference. ---
        preds = self.model(
            channels_last, reordered_queries, cache_features=cache_features
        )
        raw_tracks = preds['tracks']
        occlusion_logits = preds['occlusion']
        dist_logits = preds['expected_dist']

        # --- Postprocessing: put the time axis before the point axis. ---
        tracks = rearrange(raw_tracks, "b s t c -> b t s c")

        # A point counts as visible only when the product of the two
        # complementary sigmoid scores exceeds 0.5.
        not_occluded = 1 - occlusion_logits.sigmoid()
        confident = 1 - dist_logits.sigmoid()
        visibles = rearrange(not_occluded * confident > 0.5, "b s t -> b t s")

        return tracks, visibles