from tqdm import tqdm
import torch
from torch import nn
from .optical_flow import OpticalFlow
from .shelf import CoTracker, CoTracker2, Tapir
from dot.utils.io import read_config
from dot.utils.torch import sample_points, sample_mask_points, get_grid
class PointTracker(nn.Module):
    """Track points across a video with a point-tracking model (CoTracker/TAPIR),
    using an optical-flow estimator to find motion boundaries for point sampling.

    Args:
        height, width: working resolution passed to the optical-flow estimator.
        tracker_config: path to the tracker config (must define ``name``).
        tracker_path: optional checkpoint path for the tracker weights.
        estimator_config, estimator_path: config/checkpoint for the flow estimator.
    """

    def __init__(self, height, width, tracker_config, tracker_path, estimator_config, estimator_path):
        super().__init__()
        model_args = read_config(tracker_config)
        model_dict = {
            "cotracker": CoTracker,
            "cotracker2": CoTracker2,
            "tapir": Tapir,
            "bootstapir": Tapir,  # bootstapir reuses the Tapir implementation
        }
        self.name = model_args.name
        self.model = model_dict[model_args.name](model_args)
        if tracker_path is not None:
            device = next(self.model.parameters()).device
            # strict=False: tolerate checkpoints that omit auxiliary buffers
            self.model.load_state_dict(torch.load(tracker_path, map_location=device), strict=False)
        self.optical_flow_estimator = OpticalFlow(height, width, estimator_config, estimator_path)

    def forward(self, data, mode, **kwargs):
        """Dispatch to the requested tracking mode.

        Raises:
            ValueError: if ``mode`` is not a known mode.
        """
        if mode == "tracks_at_motion_boundaries":
            return self.get_tracks_at_motion_boundaries(data, **kwargs)
        elif mode == "flow_from_last_to_first_frame":
            return self.get_flow_from_last_to_first_frame(data, **kwargs)
        else:
            raise ValueError(f"Unknown mode {mode}")

    def get_tracks_at_motion_boundaries(self, data, num_tracks=8192, sim_tracks=2048, sample_mode="all", **kwargs):
        """Track ``num_tracks`` points sampled at motion boundaries.

        Points are tracked in batches of ``sim_tracks`` simultaneous tracks.
        ``sample_mode`` selects which frames seed the points:
          - "all":   spread samples across every frame, track backward too;
          - "first": seed all points in the first frame;
          - "last":  seed all points in the last frame (video is flipped).

        Returns:
            dict with "tracks": (B, T, num_tracks, 3) tensor of (x, y, visibility).

        Raises:
            ValueError: if ``sample_mode`` is not recognized.
        """
        video = data["video"]
        N, S = num_tracks, sim_tracks
        B, T, _, H, W = video.shape
        assert N % S == 0
        # Define sampling strategy
        if sample_mode == "all":
            samples_per_step = [S // T for _ in range(T)]
            samples_per_step[0] += S - sum(samples_per_step)  # absorb rounding remainder
            backward_tracking = True
            flip = False
        elif sample_mode == "first":
            samples_per_step = [0 for _ in range(T)]
            samples_per_step[0] += S
            backward_tracking = False
            flip = False
        elif sample_mode == "last":
            samples_per_step = [0 for _ in range(T)]
            samples_per_step[0] += S
            backward_tracking = False
            flip = True  # seed in the last frame by flipping the temporal axis
        else:
            raise ValueError(f"Unknown sample mode {sample_mode}")

        if flip:
            video = video.flip(dims=[1])

        # Track batches of points
        tracks = []
        motion_boundaries = {}  # per-frame boundary maps, cached across batches
        cache_features = True  # only extract video features on the first batch
        for _ in tqdm(range(N // S), desc="Track batch of points", leave=False):
            src_points = []
            for src_step, src_samples in enumerate(samples_per_step):
                if src_samples == 0:
                    continue
                if src_step not in motion_boundaries:
                    # Estimate boundaries w.r.t. an adjacent frame (previous if any)
                    tgt_step = src_step - 1 if src_step > 0 else src_step + 1
                    flow_data = {"src_frame": video[:, src_step], "tgt_frame": video[:, tgt_step]}
                    pred = self.optical_flow_estimator(flow_data, mode="motion_boundaries", **kwargs)
                    motion_boundaries[src_step] = pred["motion_boundaries"]
                src_boundaries = motion_boundaries[src_step]
                src_points.append(sample_points(src_step, src_boundaries, src_samples))
            src_points = torch.cat(src_points, dim=1)
            traj, vis = self.model(video, src_points, backward_tracking, cache_features)
            tracks.append(torch.cat([traj, vis[..., None]], dim=-1))
            cache_features = False
        tracks = torch.cat(tracks, dim=2)

        if flip:
            tracks = tracks.flip(dims=[1])

        return {"tracks": tracks}

    def get_flow_from_last_to_first_frame(self, data, sim_tracks=2048, **kwargs):
        """Compute dense flow (and visibility alpha) from the last frame to the first.

        Every pixel of the last frame is tracked (in batches of ``sim_tracks``)
        through the temporally flipped video; flow is the displacement of each
        pixel's track at the first frame.

        Returns:
            dict with "flow": (B, H, W, 2) displacements and
            "alpha": (B, H, W) visibility in the first frame.
        """
        video = data["video"]
        video = video.flip(dims=[1])
        src_step = 0  # We have flipped video over temporal axis so src_step is 0
        B, T, C, H, W = video.shape
        S = sim_tracks
        backward_tracking = False
        cache_features = True
        device = video.device  # follow the input's device instead of forcing CUDA
        # Initialize flow with source pixel coordinates; later replaced by
        # (target position - source position).
        flow = get_grid(H, W, shape=[B]).to(device)
        flow[..., 0] = flow[..., 0] * (W - 1)
        flow[..., 1] = flow[..., 1] * (H - 1)
        alpha = torch.zeros(B, H, W, device=device)
        mask = torch.ones(H, W)  # pixels not yet tracked
        pbar = tqdm(total=(H * W + S - 1) // S, desc="Track batch of points", leave=False)
        while torch.any(mask):
            points, (i, j) = sample_mask_points(src_step, mask, S)
            idx = i * W + j
            points = points.to(device)[None].expand(B, -1, -1)
            traj, vis = self.model(video, points, backward_tracking, cache_features)
            traj = traj[:, -1]  # position in the (flipped) last frame = original first frame
            vis = vis[:, -1].float()
            # Update mask
            mask = mask.view(-1)
            mask[idx] = 0
            mask = mask.view(H, W)
            # Update flow
            flow = flow.view(B, -1, 2)
            flow[:, idx] = traj - flow[:, idx]
            flow = flow.view(B, H, W, 2)
            # Update alpha
            alpha = alpha.view(B, -1)
            alpha[:, idx] = vis
            alpha = alpha.view(B, H, W)
            cache_features = False
            pbar.update(1)
        pbar.close()
        return {"flow": flow, "alpha": alpha}