| | from tqdm import tqdm
|
| | import torch
|
| | from torch import nn
|
| |
|
| | from .optical_flow import OpticalFlow
|
| | from .shelf import CoTracker, CoTracker2, Tapir
|
| | from dot.utils.io import read_config
|
| | from dot.utils.torch import sample_points, sample_mask_points, get_grid
|
| |
|
| |
|
class PointTracker(nn.Module):
    """Point tracker paired with an optical-flow estimator.

    Wraps one tracker backbone (``CoTracker``, ``CoTracker2`` or ``Tapir``), selected by the
    ``name`` field of the tracker config. It also owns an ``OpticalFlow`` module whose motion
    boundaries are used to choose the query points to track.
    """

    def __init__(self, height, width, tracker_config, tracker_path, estimator_config, estimator_path):
        """Build the tracker backbone and the optical-flow estimator.

        Args:
            height, width: Frame size, forwarded to ``OpticalFlow``.
            tracker_config: Passed to ``read_config``; the resulting ``name`` field selects the
                backbone ("cotracker", "cotracker2", "tapir" or "bootstapir").
            tracker_path: Optional tracker checkpoint; nothing is loaded when ``None``.
            estimator_config, estimator_path: Forwarded to ``OpticalFlow``.

        Raises:
            KeyError: If the configured tracker name is not a supported backbone.
        """
        super().__init__()
        model_args = read_config(tracker_config)
        model_dict = {
            "cotracker": CoTracker,
            "cotracker2": CoTracker2,
            "tapir": Tapir,
            "bootstapir": Tapir,
        }
        self.name = model_args.name
        self.model = model_dict[model_args.name](model_args)
        if tracker_path is not None:
            device = next(self.model.parameters()).device
            # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
            # strict=False silently ignores missing/unexpected keys; presumably intentional
            # for partially matching checkpoints, but confirm.
            self.model.load_state_dict(torch.load(tracker_path, map_location=device), strict=False)
        self.optical_flow_estimator = OpticalFlow(height, width, estimator_config, estimator_path)

    def forward(self, data, mode, **kwargs):
        """Dispatch to one of the tracking modes.

        Args:
            data: Dict holding at least a ``"video"`` tensor.
            mode: "tracks_at_motion_boundaries" or "flow_from_last_to_first_frame".
            **kwargs: Forwarded to the selected method.

        Raises:
            ValueError: If ``mode`` is not recognised.
        """
        if mode == "tracks_at_motion_boundaries":
            return self.get_tracks_at_motion_boundaries(data, **kwargs)
        elif mode == "flow_from_last_to_first_frame":
            return self.get_flow_from_last_to_first_frame(data, **kwargs)
        else:
            raise ValueError(f"Unknown mode {mode}")

    @staticmethod
    def _sampling_plan(sample_mode, num_samples, num_frames):
        """Return ``(samples_per_step, backward_tracking, flip)`` for a sampling mode.

        ``samples_per_step[t]`` is how many query points are drawn on frame ``t`` per batch.
        """
        if sample_mode == "all":
            # Spread the samples evenly over all frames; frame 0 takes the remainder.
            samples_per_step = [num_samples // num_frames] * num_frames
            samples_per_step[0] += num_samples - sum(samples_per_step)
            return samples_per_step, True, False
        if sample_mode in ("first", "last"):
            # All samples on frame 0. For "last" the video is time-flipped first, so
            # frame 0 of the flipped video is the original last frame.
            samples_per_step = [0] * num_frames
            samples_per_step[0] += num_samples
            return samples_per_step, False, sample_mode == "last"
        raise ValueError(f"Unknown sample mode {sample_mode}")

    def get_tracks_at_motion_boundaries(self, data, num_tracks=8192, sim_tracks=2048, sample_mode="all", **kwargs):
        """Track ``num_tracks`` points sampled at motion boundaries, ``sim_tracks`` at a time.

        Args:
            data: Dict with a ``"video"`` tensor of shape (B, T, C, H, W).
            num_tracks: Total number of tracks; must be a positive multiple of ``sim_tracks``.
            sim_tracks: Number of points tracked simultaneously per model call.
            sample_mode: "all" (query points spread over every frame), "first" or "last"
                (all query points on that frame).
            **kwargs: Forwarded to the optical-flow estimator.

        Returns:
            ``{"tracks": tensor}`` where each track's last channel is the visibility and the
            preceding channels are the trajectory coordinates. Batches are concatenated on
            dim 2, and dim 1 is the time axis (presumably (B, T, num_tracks, 3); confirm).

        Raises:
            ValueError: On invalid track counts or an unknown ``sample_mode``.
        """
        video = data["video"]
        N, S = num_tracks, sim_tracks
        _, T, _, _, _ = video.shape
        if S <= 0 or N <= 0 or N % S != 0:
            raise ValueError(f"num_tracks ({N}) must be a positive multiple of sim_tracks ({S})")

        samples_per_step, backward_tracking, flip = self._sampling_plan(sample_mode, S, T)
        if flip:
            video = video.flip(dims=[1])

        tracks = []
        motion_boundaries = {}  # per-frame cache, reused across batches
        cache_features = True  # only the first model call computes video features
        for _ in tqdm(range(N // S), desc="Track batch of points", leave=False):
            src_points = []
            for src_step, src_samples in enumerate(samples_per_step):
                if src_samples == 0:
                    continue
                if src_step not in motion_boundaries:
                    # Pair each frame with its predecessor (frame 0 with frame 1).
                    tgt_step = src_step - 1 if src_step > 0 else src_step + 1
                    frames = {"src_frame": video[:, src_step], "tgt_frame": video[:, tgt_step]}
                    pred = self.optical_flow_estimator(frames, mode="motion_boundaries", **kwargs)
                    motion_boundaries[src_step] = pred["motion_boundaries"]
                src_points.append(sample_points(src_step, motion_boundaries[src_step], src_samples))
            src_points = torch.cat(src_points, dim=1)
            traj, vis = self.model(video, src_points, backward_tracking, cache_features)
            tracks.append(torch.cat([traj, vis[..., None]], dim=-1))
            cache_features = False
        tracks = torch.cat(tracks, dim=2)

        if flip:
            tracks = tracks.flip(dims=[1])
        return {"tracks": tracks}

    def get_flow_from_last_to_first_frame(self, data, sim_tracks=2048, **kwargs):
        """Compute dense flow from the last frame to the first by tracking every pixel.

        The video is time-reversed so that the last frame becomes step 0. All of its pixels
        are then tracked, ``sim_tracks`` at a time, until every pixel has been covered.

        Args:
            data: Dict with a ``"video"`` tensor of shape (B, T, C, H, W).
            sim_tracks: Number of pixels tracked per model call.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``{"flow": (B, H, W, 2), "alpha": (B, H, W)}``, both on the video's device.
            ``flow`` is the tracked end position minus the pixel position. ``alpha`` is the
            visibility at the final step.
        """
        video = data["video"].flip(dims=[1])
        device = video.device
        src_step = 0
        B, _, _, H, W = video.shape
        S = sim_tracks
        backward_tracking = False
        cache_features = True

        # Pixel coordinates of every source pixel. The scaling assumes get_grid returns
        # coordinates normalised to [0, 1] with x in channel 0 (TODO confirm).
        grid = get_grid(H, W, shape=[B]).to(device)
        pixel_coords = grid * grid.new_tensor([W - 1, H - 1])
        flow = pixel_coords.reshape(B, H * W, 2)  # flat working buffer, reshaped on return
        alpha = torch.zeros(B, H * W, device=device)
        mask = torch.ones(H, W)  # 1 = pixel not tracked yet

        # Ceil: the final batch may be partial. Assumes sample_mask_points returns up to S
        # points per call (TODO confirm).
        num_batches = (H * W + S - 1) // S
        with tqdm(total=num_batches, desc="Track batch of points", leave=False) as pbar:
            while torch.any(mask):
                points, (i, j) = sample_mask_points(src_step, mask, S)
                idx = i * W + j
                points = points.to(device)[None].expand(B, -1, -1)

                traj, vis = self.model(video, points, backward_tracking, cache_features)
                traj = traj[:, -1]  # positions at the last step = original first frame
                vis = vis[:, -1].float()

                mask.view(-1)[idx] = 0
                flow[:, idx] = traj - flow[:, idx]
                alpha[:, idx] = vis

                cache_features = False
                pbar.update(1)

        return {"flow": flow.reshape(B, H, W, 2), "alpha": alpha.reshape(B, H, W)}
|
| |
|