|
|
from tqdm import tqdm
|
|
|
import torch
|
|
|
from torch import nn
|
|
|
|
|
|
from .optical_flow import OpticalFlow
|
|
|
from .shelf import CoTracker, CoTracker2, Tapir
|
|
|
from dot.utils.io import read_config
|
|
|
from dot.utils.torch import sample_points, sample_mask_points, get_grid
|
|
|
|
|
|
|
|
|
class PointTracker(nn.Module):
    """Point tracker paired with an optical-flow estimator.

    Wraps one of the shelf trackers (CoTracker, CoTracker2, Tapir) together
    with an ``OpticalFlow`` module and exposes two modes through ``forward``:
    tracking points sampled at motion boundaries, and dense flow from the
    last frame of a video to its first frame.
    """

    def __init__(self, height, width, tracker_config, tracker_path, estimator_config, estimator_path):
        """Build the tracker and the optical-flow estimator.

        Args:
            height, width: forwarded to ``OpticalFlow``.
            tracker_config: passed to ``read_config``; the resulting ``name``
                field selects the tracker ("cotracker", "cotracker2", "tapir"
                or "bootstapir").
            tracker_path: optional checkpoint for the tracker. It is loaded
                with ``strict=False``, so missing/unexpected keys are tolerated.
            estimator_config, estimator_path: forwarded to ``OpticalFlow``.
        """
        super().__init__()
        model_args = read_config(tracker_config)
        model_dict = {
            "cotracker": CoTracker,
            "cotracker2": CoTracker2,
            "tapir": Tapir,
            "bootstapir": Tapir,
        }
        self.name = model_args.name
        self.model = model_dict[model_args.name](model_args)
        if tracker_path is not None:
            # Map the checkpoint onto the device of the freshly built model.
            device = next(self.model.parameters()).device
            self.model.load_state_dict(torch.load(tracker_path, map_location=device), strict=False)
        self.optical_flow_estimator = OpticalFlow(height, width, estimator_config, estimator_path)

    def forward(self, data, mode, **kwargs):
        """Dispatch to one of the tracking modes.

        Args:
            data: dict containing at least the key ``"video"``.
            mode: ``"tracks_at_motion_boundaries"`` or
                ``"flow_from_last_to_first_frame"``.
            **kwargs: forwarded to the selected method.

        Raises:
            ValueError: if ``mode`` is not recognised.
        """
        if mode == "tracks_at_motion_boundaries":
            return self.get_tracks_at_motion_boundaries(data, **kwargs)
        elif mode == "flow_from_last_to_first_frame":
            return self.get_flow_from_last_to_first_frame(data, **kwargs)
        else:
            raise ValueError(f"Unknown mode {mode}")

    @staticmethod
    def _sampling_plan(sample_mode, num_samples, num_frames):
        """Return ``(samples_per_step, backward_tracking, flip)`` for ``sample_mode``.

        "all" splits ``num_samples`` evenly over the frames, gives the remainder
        to frame 0, and enables backward tracking. "first" and "last" put every
        sample on frame 0 with backward tracking disabled. "last" also requests
        a temporal flip, so that frame 0 is the original last frame.
        """
        if sample_mode == "all":
            samples_per_step = [num_samples // num_frames] * num_frames
            samples_per_step[0] += num_samples - sum(samples_per_step)
            return samples_per_step, True, False
        if sample_mode in ("first", "last"):
            samples_per_step = [0] * num_frames
            samples_per_step[0] += num_samples
            return samples_per_step, False, sample_mode == "last"
        raise ValueError(f"Unknown sample mode {sample_mode}")

    def get_tracks_at_motion_boundaries(self, data, num_tracks=8192, sim_tracks=2048, sample_mode="all", **kwargs):
        """Track points sampled at motion boundaries, ``sim_tracks`` at a time.

        For each source frame that receives samples, the optical-flow
        estimator predicts motion boundaries against a neighbouring frame.
        These are cached per frame and reused across batches. Points are
        drawn from them with ``sample_points`` and tracked by ``self.model``.

        Args:
            data: dict with key ``"video"``, a 5-D tensor whose dim 1 indexes
                frames (assumed ``(B, T, C, H, W)`` -- TODO confirm).
            num_tracks: total number of tracked points; must be a multiple of
                ``sim_tracks``.
            sim_tracks: number of points tracked per model call.
            sample_mode: "all", "first" or "last" (see ``_sampling_plan``).
            **kwargs: forwarded to the optical-flow estimator.

        Returns:
            ``{"tracks": tensor}``. The last dim holds the model's trajectory
            values followed by visibility. Batches are concatenated on dim 2,
            and dim 1 is in the original frame order.

        Raises:
            ValueError: if ``num_tracks`` is not a multiple of ``sim_tracks``,
                or if ``sample_mode`` is unknown.
        """
        video = data["video"]
        N, S = num_tracks, sim_tracks
        _, T, _, _, _ = video.shape
        if N % S != 0:
            raise ValueError(f"num_tracks ({N}) must be a multiple of sim_tracks ({S})")

        samples_per_step, backward_tracking, flip = self._sampling_plan(sample_mode, S, T)
        if flip:
            video = video.flip(dims=[1])

        tracks = []
        motion_boundaries = {}  # source frame index -> predicted motion boundaries
        # True only on the first model call; presumably lets the tracker compute
        # per-video features once and reuse them (TODO confirm in the wrappers).
        cache_features = True
        for _ in tqdm(range(N // S), desc="Track batch of points", leave=False):
            src_points = []
            for src_step, src_samples in enumerate(samples_per_step):
                if src_samples == 0:
                    continue
                if src_step not in motion_boundaries:
                    # Pair with the previous frame, or with the next one for frame 0.
                    tgt_step = src_step - 1 if src_step > 0 else src_step + 1
                    frames = {"src_frame": video[:, src_step], "tgt_frame": video[:, tgt_step]}
                    pred = self.optical_flow_estimator(frames, mode="motion_boundaries", **kwargs)
                    motion_boundaries[src_step] = pred["motion_boundaries"]
                src_points.append(sample_points(src_step, motion_boundaries[src_step], src_samples))
            src_points = torch.cat(src_points, dim=1)
            traj, vis = self.model(video, src_points, backward_tracking, cache_features)
            tracks.append(torch.cat([traj, vis[..., None]], dim=-1))
            cache_features = False
        tracks = torch.cat(tracks, dim=2)

        if flip:
            # Restore the original temporal order.
            tracks = tracks.flip(dims=[1])
        return {"tracks": tracks}

    def get_flow_from_last_to_first_frame(self, data, sim_tracks=2048, **kwargs):
        """Densely track every pixel of the last frame back to the first frame.

        The video is flipped in time, so frame 0 is the original last frame.
        Each pixel of that frame is tracked by ``self.model``, ``sim_tracks``
        pixels per call, until the pixel mask is exhausted. All intermediate
        tensors are allocated on ``video.device``.

        Args:
            data: dict with key ``"video"``, a 5-D tensor whose dims 0, 3 and 4
                are batch, height and width (assumed ``(B, T, C, H, W)``).
            sim_tracks: number of pixels tracked per model call.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``{"flow": (B, H, W, 2), "alpha": (B, H, W)}``.
            ``flow`` is the tracked position in the original first frame
            minus the pixel's own position. The position is the ``get_grid``
            grid scaled by ``W - 1`` / ``H - 1``, presumably giving pixel
            (x, y) units if ``get_grid`` is normalised to [0, 1] (TODO
            confirm). ``alpha`` is the model's visibility at that frame, as
            float.
        """
        video = data["video"].flip(dims=[1])
        B, _, _, H, W = video.shape
        S = sim_tracks
        device = video.device
        src_step = 0
        backward_tracking = False
        cache_features = True

        flow = get_grid(H, W, shape=[B]).to(device)
        flow[..., 0] = flow[..., 0] * (W - 1)
        flow[..., 1] = flow[..., 1] * (H - 1)
        alpha = torch.zeros(B, H, W, device=device)
        mask = torch.ones(H, W)  # 1 = pixel not yet tracked

        # Flat views share storage, so writes through them update the originals.
        flow_flat = flow.view(B, -1, 2)
        alpha_flat = alpha.view(B, -1)
        mask_flat = mask.view(-1)

        num_batches = (H * W + S - 1) // S  # ceil: the last batch may be partial
        with tqdm(total=num_batches, desc="Track batch of points", leave=False) as pbar:
            # NOTE(review): termination relies on sample_mask_points returning at
            # least one still-masked pixel per call -- TODO confirm.
            while torch.any(mask):
                points, (i, j) = sample_mask_points(src_step, mask, S)
                idx = i * W + j
                points = points.to(device)[None].expand(B, -1, -1)

                traj, vis = self.model(video, points, backward_tracking, cache_features)
                traj = traj[:, -1]  # positions at the last flipped step (original first frame)
                vis = vis[:, -1].float()

                mask_flat[idx] = 0
                flow_flat[:, idx] = traj - flow_flat[:, idx]
                alpha_flat[:, idx] = vis

                cache_features = False
                pbar.update(1)

        return {"flow": flow_flat.view(B, H, W, 2), "alpha": alpha}
|
|
|
|