| from tqdm import tqdm
|
| import torch
|
| from torch import nn
|
|
|
| from .optical_flow import OpticalFlow
|
| from .shelf import CoTracker, CoTracker2, Tapir
|
| from dot.utils.io import read_config
|
| from dot.utils.torch import sample_points, sample_mask_points, get_grid
|
|
|
|
|
class PointTracker(nn.Module):
    """Point-tracking backbone paired with an optical-flow estimator.

    ``forward`` exposes two modes:

    * tracking points sampled at motion boundaries across a whole video, and
    * densely tracking every pixel of the last frame back to the first frame,
      expressed as a flow field plus a visibility map.
    """

    def __init__(self, height, width, tracker_config, tracker_path, estimator_config, estimator_path):
        """Build the tracking backbone and the optical-flow estimator.

        Args:
            height, width: Frame size, forwarded to ``OpticalFlow``.
            tracker_config: Config read with ``read_config``. Its ``name`` selects
                the backbone: "cotracker", "cotracker2", "tapir" or "bootstapir".
            tracker_path: Optional backbone checkpoint. When given, it is loaded
                with ``strict=False``, so missing or unexpected keys are tolerated.
            estimator_config, estimator_path: Forwarded to ``OpticalFlow``.

        Raises:
            KeyError: If the configured tracker name is not supported.
        """
        super().__init__()
        model_args = read_config(tracker_config)
        model_dict = {
            "cotracker": CoTracker,
            "cotracker2": CoTracker2,
            "tapir": Tapir,
            "bootstapir": Tapir,
        }
        try:
            model_cls = model_dict[model_args.name]
        except KeyError:
            raise KeyError(
                f"Unknown tracker {model_args.name!r}; expected one of {sorted(model_dict)}"
            ) from None
        self.name = model_args.name
        self.model = model_cls(model_args)
        if tracker_path is not None:
            # Load the weights onto whichever device the freshly built model lives on.
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            device = next(self.model.parameters()).device
            self.model.load_state_dict(torch.load(tracker_path, map_location=device), strict=False)
        self.optical_flow_estimator = OpticalFlow(height, width, estimator_config, estimator_path)

    def forward(self, data, mode, **kwargs):
        """Dispatch to one of the tracking modes.

        Args:
            data: Dict holding at least a ``"video"`` tensor.
            mode: ``"tracks_at_motion_boundaries"`` or ``"flow_from_last_to_first_frame"``.
            **kwargs: Forwarded to the selected method.

        Raises:
            ValueError: If ``mode`` is unknown.
        """
        if mode == "tracks_at_motion_boundaries":
            return self.get_tracks_at_motion_boundaries(data, **kwargs)
        if mode == "flow_from_last_to_first_frame":
            return self.get_flow_from_last_to_first_frame(data, **kwargs)
        raise ValueError(f"Unknown mode {mode}")

    @staticmethod
    def _plan_samples(sample_mode, num_samples, num_frames):
        """Return ``(samples_per_step, backward_tracking, flip)`` for a sampling mode.

        ``samples_per_step[t]`` is how many points to draw on frame ``t`` (of the
        possibly flipped video) for each batch of ``num_samples`` points.

        Raises:
            ValueError: If ``sample_mode`` is unknown.
        """
        if sample_mode == "all":
            # Spread the samples evenly over every frame; the remainder goes to frame 0.
            samples_per_step = [num_samples // num_frames] * num_frames
            samples_per_step[0] += num_samples - sum(samples_per_step)
            return samples_per_step, True, False
        if sample_mode in ("first", "last"):
            # Sample only on frame 0. For "last" the caller flips the video in time
            # first, so frame 0 is the original last frame.
            samples_per_step = [0] * num_frames
            samples_per_step[0] = num_samples
            return samples_per_step, False, sample_mode == "last"
        raise ValueError(f"Unknown sample mode {sample_mode}")

    def get_tracks_at_motion_boundaries(self, data, num_tracks=8192, sim_tracks=2048, sample_mode="all", **kwargs):
        """Track points sampled on motion boundaries through the whole video.

        Source points are drawn with ``sample_points`` from per-frame motion
        boundaries. ``self.optical_flow_estimator`` computes them between a frame
        and its previous frame; frame 0 uses frame 1 instead, so the video needs
        at least two frames. Boundaries are computed once per frame and reused
        for every batch. Tracking runs in ``num_tracks // sim_tracks`` model calls
        of ``sim_tracks`` points each.

        Args:
            data: Dict with a ``"video"`` tensor unpacked as (B, T, C, H, W).
            num_tracks: Total number of tracks to produce.
            sim_tracks: Points tracked per model call; must divide ``num_tracks``.
            sample_mode: One of:
                * ``"all"``: samples from every frame and passes
                  ``backward_tracking=True`` to the model.
                * ``"first"``: samples only on the first frame.
                * ``"last"``: samples only on the last frame, by flipping the
                  video in time and flipping the tracks back afterwards.
            **kwargs: Forwarded to the optical-flow estimator.

        Returns:
            ``{"tracks": tensor}``. For each batch, trajectories and visibilities
            are concatenated on the last axis; batches are concatenated on axis 2.
            The shape is presumably (B, T, num_tracks, 3) for 2-D trajectories;
            confirm against the backbone.

        Raises:
            ValueError: If ``num_tracks``/``sim_tracks`` are not positive with
                ``sim_tracks`` dividing ``num_tracks``, or if ``sample_mode`` is
                unknown.
        """
        video = data["video"]
        N, S = num_tracks, sim_tracks
        _, T, _, _, _ = video.shape  # unpacking also enforces a 5-D video
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if S <= 0 or N <= 0 or N % S != 0:
            raise ValueError(
                f"num_tracks ({N}) must be a positive multiple of sim_tracks ({S})"
            )

        samples_per_step, backward_tracking, flip = self._plan_samples(sample_mode, S, T)
        if flip:
            video = video.flip(dims=[1])

        tracks = []
        motion_boundaries = {}  # source frame index -> boundaries, computed once per frame
        # Only the first model call gets cache_features=True; presumably this lets
        # the backbone reuse video features across batches. Confirm in the wrappers.
        cache_features = True
        for _ in tqdm(range(N // S), desc="Track batch of points", leave=False):
            src_points = []
            for src_step, src_samples in enumerate(samples_per_step):
                if src_samples == 0:
                    continue
                if src_step not in motion_boundaries:
                    # Compare against the previous frame (the next one for frame 0).
                    tgt_step = src_step - 1 if src_step > 0 else src_step + 1
                    flow_data = {"src_frame": video[:, src_step], "tgt_frame": video[:, tgt_step]}
                    pred = self.optical_flow_estimator(flow_data, mode="motion_boundaries", **kwargs)
                    motion_boundaries[src_step] = pred["motion_boundaries"]
                src_points.append(sample_points(src_step, motion_boundaries[src_step], src_samples))
            src_points = torch.cat(src_points, dim=1)
            traj, vis = self.model(video, src_points, backward_tracking, cache_features)
            tracks.append(torch.cat([traj, vis[..., None]], dim=-1))
            cache_features = False
        tracks = torch.cat(tracks, dim=2)

        if flip:
            # Restore the original time order.
            tracks = tracks.flip(dims=[1])
        return {"tracks": tracks}

    def get_flow_from_last_to_first_frame(self, data, sim_tracks=2048, **kwargs):
        """Densely track every pixel of the last frame back to the first frame.

        The video is flipped in time, so the original last frame becomes frame 0.
        Its pixels are tracked in batches of points chosen by
        ``sample_mask_points``, guided by a mask of pixels not yet tracked, until
        the mask is empty. Only the result at the final flipped frame (the
        original first frame) is kept. All tensors follow the video's device.

        Args:
            data: Dict with a ``"video"`` tensor unpacked as (B, T, C, H, W).
            sim_tracks: Number of points requested from ``sample_mask_points``
                per model call.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``{"flow": (B, H, W, 2), "alpha": (B, H, W)}``.
            ``flow`` is the tracked position minus each pixel's own (x, y) pixel
            coordinate on the last frame. ``alpha`` is the visibility, cast to float.
        """
        video = data["video"].flip(dims=[1])
        device = video.device
        src_step = 0
        B, _, _, H, W = video.shape
        S = sim_tracks
        backward_tracking = False
        cache_features = True  # only the first model call is asked to cache features

        # Pixel-coordinate grid; assumes get_grid yields coordinates in [0, 1].
        # .contiguous() guarantees a dense buffer for the in-place writes and .view below.
        flow = get_grid(H, W, shape=[B]).to(device).contiguous()
        flow[..., 0] = flow[..., 0] * (W - 1)
        flow[..., 1] = flow[..., 1] * (H - 1)
        alpha = torch.zeros(B, H, W, device=device)
        mask = torch.ones(H, W)  # 1 = pixel of the source frame not yet tracked

        num_batches = (H * W + S - 1) // S  # ceil: the last batch may be partial
        with tqdm(total=num_batches, desc="Track batch of points", leave=False) as pbar:
            while torch.any(mask):
                # Assumes (i, j) are the row/column indices of the sampled points.
                points, (i, j) = sample_mask_points(src_step, mask, S)
                idx = i * W + j
                points = points.to(device)[None].expand(B, -1, -1)

                traj, vis = self.model(video, points, backward_tracking, cache_features)
                traj = traj[:, -1]
                vis = vis[:, -1].float()

                # Mark these pixels as done.
                mask = mask.view(-1)
                mask[idx] = 0
                mask = mask.view(H, W)

                # Replace the grid coordinate with the displacement to the tracked position.
                flow = flow.view(B, -1, 2)
                flow[:, idx] = traj - flow[:, idx]
                flow = flow.view(B, H, W, 2)

                alpha = alpha.view(B, -1)
                alpha[:, idx] = vis
                alpha = alpha.view(B, H, W)

                cache_features = False
                pbar.update(1)
        return {"flow": flow, "alpha": alpha}
|
|
|