import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from einops import rearrange, repeat

from .optical_flow import OpticalFlow
from .point_tracking import PointTracker
from dot.utils.torch import get_grid


class DenseOpticalTracker(nn.Module):
    """Dense tracker combining a sparse point tracker with an optical-flow refiner.

    Every public mode first asks ``PointTracker`` for initial tracks
    (``mode="tracks_at_motion_boundaries"``) and then calls ``OpticalFlow``
    with ``mode="flow_with_tracks_init"`` to turn them into dense flow.

    Args:
        height, width: Resolution handed to both sub-models; videos of a
            different size are bilinearly resized to it before dense feature
            extraction / refinement.
        tracker_config, tracker_path: Config and checkpoint of the point tracker.
        estimator_config, estimator_path: Config and checkpoint of the estimator
            used inside the point tracker.
        refiner_config, refiner_path: Config and checkpoint of the flow refiner.
    """

    def __init__(self,
                 height=512,
                 width=512,
                 tracker_config="configs/cotracker2_patch_4_wind_8.json",
                 tracker_path="checkpoints/movi_f_cotracker2_patch_4_wind_8.pth",
                 estimator_config="configs/raft_patch_8.json",
                 estimator_path="checkpoints/cvo_raft_patch_8.pth",
                 refiner_config="configs/raft_patch_4_alpha.json",
                 refiner_path="checkpoints/movi_f_raft_patch_4_alpha.pth"):
        super().__init__()
        self.point_tracker = PointTracker(height, width, tracker_config, tracker_path,
                                          estimator_config, estimator_path)
        self.optical_flow_refiner = OpticalFlow(height, width, refiner_config, refiner_path)
        self.name = self.point_tracker.name + "_" + self.optical_flow_refiner.name
        self.resolution = [height, width]

    def forward(self, data, mode, **kwargs):
        """Dispatch to the ``get_*`` method selected by ``mode``.

        Args:
            data: Input dict; every mode reads ``data["video"]``.
            mode: One of ``"flow_from_last_to_first_frame"``,
                ``"tracks_for_queries"``,
                ``"tracks_from_first_to_every_other_frame"`` or
                ``"tracks_from_every_cell_in_every_frame"``.
            **kwargs: Forwarded unchanged to the selected method.

        Raises:
            ValueError: If ``mode`` is not one of the names above.
        """
        if mode == "flow_from_last_to_first_frame":
            return self.get_flow_from_last_to_first_frame(data, **kwargs)
        elif mode == "tracks_for_queries":
            return self.get_tracks_for_queries(data, **kwargs)
        elif mode == "tracks_from_first_to_every_other_frame":
            return self.get_tracks_from_first_to_every_other_frame(data, **kwargs)
        elif mode == "tracks_from_every_cell_in_every_frame":
            return self.get_tracks_from_every_cell_in_every_frame(data, **kwargs)
        else:
            raise ValueError(f"Unknown mode {mode}")

    # ------------------------------------------------------------------ #
    # Private helpers shared by the public modes.
    # ``kwargs`` is passed as a plain dict (not unpacked) so that caller
    # keywords can never collide with the helpers' own parameter names.
    # ------------------------------------------------------------------ #

    def _resize_video(self, video):
        """Return ``video`` (B, T, C, h, w) bilinearly resized to ``self.resolution``.

        The input tensor is returned untouched when it already has that size.
        """
        B, T, C, h, w = video.shape
        H, W = self.resolution
        if h != H or w != W:
            video = video.reshape(B * T, C, h, w)
            video = F.interpolate(video, size=(H, W), mode="bilinear")
            video = video.reshape(B, T, C, H, W)
        return video

    def _init_tracks(self, data, height, width, kwargs):
        """Run the point tracker and divide x by ``width - 1`` and y by ``height - 1``.

        The third channel of the tracker output is passed through unchanged.
        """
        init = self.point_tracker(data, mode="tracks_at_motion_boundaries", **kwargs)["tracks"]
        return torch.stack(
            [init[..., 0] / (width - 1), init[..., 1] / (height - 1), init[..., 2]],
            dim=-1,
        )

    def _dense_tracks_from_source(self, init, feats, grid, src_step,
                                  batch_size, num_steps, device, kwargs):
        """Refine dense tracks from frame ``src_step`` to every frame.

        For each target step the refiner predicts ``flow`` and ``alpha`` from
        the source/target features and the initial tracks; the flow is divided
        by the model resolution and added to ``grid``. The source frame itself
        gets zero flow and an alpha of one.

        Returns:
            Tensor stacked over target steps on dim 1; last dim is
            (x, y, alpha).
        """
        H, W = self.resolution
        src_points = init[:, src_step]
        src_feats = feats[:, src_step]
        tracks_from_src = []
        for tgt_step in tqdm(range(num_steps), desc="Refine target step", leave=False):
            if src_step == tgt_step:
                flow = torch.zeros(batch_size, H, W, 2, device=device)
                alpha = torch.ones(batch_size, H, W, device=device)
            else:
                refiner_input = {
                    "src_feats": src_feats,
                    "tgt_feats": feats[:, tgt_step],
                    "src_points": src_points,
                    "tgt_points": init[:, tgt_step],
                }
                pred = self.optical_flow_refiner(refiner_input, mode="flow_with_tracks_init", **kwargs)
                flow, alpha = pred["flow"], pred["alpha"]
                # Bring the flow to the same scale as `grid`
                # (assumes get_grid returns coordinates normalized to [0, 1] -- TODO confirm).
                flow[..., 0] = flow[..., 0] / (W - 1)
                flow[..., 1] = flow[..., 1] / (H - 1)
            tracks_from_src.append(torch.cat([flow + grid, alpha[..., None]], dim=-1))
        return torch.stack(tracks_from_src, dim=1)

    # ------------------------------------------------------------------ #
    # Public modes.
    # ------------------------------------------------------------------ #

    def get_flow_from_last_to_first_frame(self, data, **kwargs):
        """Estimate flow from the last frame of the video to the first one.

        Args:
            data: Dict with ``"video"`` of shape (B, T, C, h, w).
            **kwargs: Forwarded to the point tracker and to the refiner.

        Returns:
            The refiner's prediction dict, extended with the ``"src_points"``
            and ``"tgt_points"`` initial tracks that were fed to it.
        """
        video = data["video"]
        _, _, _, h, w = video.shape
        init = self._init_tracks(data, h, w, kwargs)
        refiner_input = {
            "src_frame": video[:, -1],
            "tgt_frame": video[:, 0],
            "src_points": init[:, -1],
            "tgt_points": init[:, 0],
        }
        pred = self.optical_flow_refiner(refiner_input, mode="flow_with_tracks_init", **kwargs)
        pred["src_points"] = refiner_input["src_points"]
        pred["tgt_points"] = refiner_input["tgt_points"]
        return pred

    def get_tracks_for_queries(self, data, **kwargs):
        """Track a set of query points through the whole video.

        Args:
            data: Dict with ``"video"`` of shape (B, T, C, h, w) and
                ``"query_points"`` whose second dim indexes the S queries and
                whose last dim is read as (source step, y, x), with x / y in
                pixels of the input video.
            **kwargs: Forwarded to the point tracker and to the refiner.

        Returns:
            ``{"tracks": tensor}`` of shape (B, T, S, 3) holding, per frame and
            query, (x, y) scaled by (w - 1, h - 1) and a 0/1 visibility flag.
        """
        query_points = data["query_points"]
        video = data["video"]
        S = query_points.size(1)
        B, T, C, h, w = video.shape
        H, W = self.resolution

        init = self._init_tracks(data, h, w, kwargs)
        video = self._resize_video(video)
        feats = self.optical_flow_refiner({"video": video}, mode="feats", **kwargs)["feats"]

        grid = get_grid(H, W, device=video.device)
        # Only frames that actually host a query need to be refined as sources.
        src_steps = [int(v) for v in torch.unique(query_points[..., 0])]
        tracks = torch.zeros(B, T, S, 3, device=video.device)
        for src_step in tqdm(src_steps, desc="Refine source step", leave=False):
            tracks_from_src = self._dense_tracks_from_source(
                init, feats, grid, src_step, B, T, video.device, kwargs)
            for b in range(B):
                cur = query_points[b, :, 0] == src_step
                if torch.any(cur):
                    cur_points = query_points[b, cur]
                    cur_x = cur_points[..., 2] / (w - 1)
                    cur_y = cur_points[..., 1] / (h - 1)
                    cur_tracks = dense_to_sparse_tracks(cur_x, cur_y, tracks_from_src[b], h, w)
                    tracks[b, :, cur] = cur_tracks
        return {"tracks": tracks}

    def get_tracks_from_first_to_every_other_frame(self, data, return_flow=False, **kwargs):
        """Compute dense tracks (or flow) from the first frame to every frame.

        Args:
            data: Dict with ``"video"`` of shape (B, T, C, h, w).
            return_flow: If True the first two output channels are the raw
                flow; otherwise the flow is added to the first-frame grid.
            **kwargs: Forwarded to the point tracker and to the refiner.

        Returns:
            ``{"tracks": tensor}`` stacked over target frames on dim 1, with
            last dim (x, y, alpha); (B, T, H, W, 3) where (H, W) is
            ``self.resolution``.
        """
        video = data["video"]
        B, T, C, h, w = video.shape
        H, W = self.resolution

        video = self._resize_video(video)
        init = self._init_tracks(data, h, w, kwargs)

        grid = get_grid(H, W, device=video.device)
        # Rescale the grid to match the un-normalized flow used below
        # (assumes get_grid returns coordinates normalized to [0, 1] -- TODO confirm).
        grid[..., 0] *= (W - 1)
        grid[..., 1] *= (H - 1)

        src_step = 0
        src_points = init[:, src_step]
        src_frame = video[:, src_step]
        tracks = []
        for tgt_step in tqdm(range(T), desc="Refine target step", leave=False):
            if src_step == tgt_step:
                flow = torch.zeros(B, H, W, 2, device=video.device)
                alpha = torch.ones(B, H, W, device=video.device)
            else:
                refiner_input = {
                    "src_frame": src_frame,
                    "tgt_frame": video[:, tgt_step],
                    "src_points": src_points,
                    "tgt_points": init[:, tgt_step],
                }
                pred = self.optical_flow_refiner(refiner_input, mode="flow_with_tracks_init", **kwargs)
                flow, alpha = pred["flow"], pred["alpha"]
            # `flow` is the per-pixel displacement from the first frame to frame
            # `tgt_step`, `grid` holds the first-frame pixel coordinates, and
            # `alpha` is the confidence returned by the refiner.
            if return_flow:
                tracks.append(torch.cat([flow, alpha[..., None]], dim=-1))
            else:
                tracks.append(torch.cat([flow + grid, alpha[..., None]], dim=-1))
        tracks = torch.stack(tracks, dim=1)
        return {"tracks": tracks}

    def get_tracks_from_every_cell_in_every_frame(self, data, cell_size=1, cell_time_steps=20, **kwargs):
        """Collect tracks until the sampled source frames' cells are all covered.

        Frames are divided into ``cell_size``-pixel cells. Source frames are
        sampled evenly over the video; from each one, tracks are started at the
        centre of every cell of that frame not yet crossed by a visible track,
        and every cell those new tracks visit (in any frame) is marked.

        Args:
            data: Dict with ``"video"`` of shape (B, T, C, h, w).
            cell_size: Side of a cell in input-video pixels.
            cell_time_steps: Target number of frames per source step; the number
                of source frames is ``T // min(T, cell_time_steps)``.
            **kwargs: Forwarded to the point tracker and to the refiner.

        Returns:
            ``{"tracks": list}`` with one tensor per batch element, tracks from
            all source steps concatenated on dim 1; last dim is
            (x, y, visibility) with x / y scaled by (w - 1, h - 1).
        """
        video = data["video"]
        B, T, C, h, w = video.shape
        H, W = self.resolution
        ch, cw, ct = h // cell_size, w // cell_size, min(T, cell_time_steps)

        video = self._resize_video(video)
        init = self._init_tracks(data, h, w, kwargs)
        feats = self.optical_flow_refiner({"video": video}, mode="feats", **kwargs)["feats"]

        grid = get_grid(H, W, device=video.device)
        visited_cells = torch.zeros(B, T, ch, cw, device=video.device)
        # Plain Python ints index exactly like the 0-d tensors they come from.
        src_steps = torch.linspace(0, T - 1, T // ct).long().tolist()
        tracks = [[] for _ in range(B)]
        for src_step in tqdm(src_steps, desc="Refine source step", leave=False):
            # Skip the expensive refinement when nothing is left to cover here.
            if visited_cells[:, src_step].all():
                continue
            tracks_from_src = self._dense_tracks_from_source(
                init, feats, grid, src_step, B, T, video.device, kwargs)
            for b in range(B):
                src_cell = visited_cells[b, src_step]
                if src_cell.all():
                    continue
                # Start new tracks at the centres of the still-unvisited cells.
                cur_y, cur_x = (1 - src_cell).nonzero(as_tuple=True)
                cur_x = (cur_x + 0.5) / cw
                cur_y = (cur_y + 0.5) / ch
                cur_tracks = dense_to_sparse_tracks(cur_x, cur_y, tracks_from_src[b], h, w)
                visited_cells[b] = update_visited(visited_cells[b], cur_tracks, h, w, ch, cw)
                tracks[b].append(cur_tracks)
        tracks = [torch.cat(t, dim=1) for t in tracks]
        return {"tracks": tracks}


def dense_to_sparse_tracks(x, y, tracks, height, width):
    """Sample dense tracks at sparse source positions.

    Args:
        x, y: 1-D tensors of S source coordinates normalized to [0, 1]
            (mapped to [-1, 1] for ``grid_sample``).
        tracks: Dense tracks of shape (T, H, W, 3) with last dim
            (x, y, alpha).
        height, width: Size used to scale the sampled x / y by
            ``width - 1`` / ``height - 1``.

    Returns:
        Tensor of shape (T, S, 3): scaled (x, y) and a 0/1 flag that is 1
        where the bilinearly sampled alpha is strictly positive.
    """
    h, w = height, width
    T = tracks.size(0)
    grid = torch.stack([x, y], dim=-1) * 2 - 1
    # One sampling row per frame: (T, S, 1, 2).
    grid = repeat(grid, "s c -> t s r c", t=T, r=1)
    tracks = rearrange(tracks, "t h w c -> t c h w")
    tracks = F.grid_sample(tracks, grid, align_corners=True, mode="bilinear")
    tracks = rearrange(tracks[..., 0], "t c s -> t s c")
    tracks[..., 0] = tracks[..., 0] * (w - 1)
    tracks[..., 1] = tracks[..., 1] * (h - 1)
    tracks[..., 2] = (tracks[..., 2] > 0).float()
    return tracks


def update_visited(visited_cells, tracks, height, width, cell_height, cell_width):
    """Mark the cells crossed by visible track points as visited (in place).

    Args:
        visited_cells: Tensor of shape (T, cell_height, cell_width); entries
            are set to 1 for visited cells. Modified in place.
        tracks: Tensor of shape (T, S, 3) with last dim (x, y, visibility),
            x / y on the ``width - 1`` / ``height - 1`` scale.
        height, width: Size used to normalize x / y before mapping to cells.
        cell_height, cell_width: Number of cells along y / x.

    Returns:
        The same ``visited_cells`` tensor, for convenience.
    """
    T = tracks.size(0)
    h, w = height, width
    ch, cw = cell_height, cell_width
    for tgt_step in range(T):
        tgt_points = tracks[tgt_step]
        visited = tgt_points[tgt_points[:, 2].bool()]
        if len(visited) == 0:
            continue
        visited_x = (visited[:, 0] / (w - 1) * cw).floor().long()
        visited_y = (visited[:, 1] / (h - 1) * ch).floor().long()
        # Points whose cell index falls outside the grid are dropped.
        valid = (visited_x >= 0) & (visited_x < cw) & (visited_y >= 0) & (visited_y < ch)
        visited_x = visited_x[valid]
        visited_y = visited_y[valid]
        # `.view(-1)` shares storage with `visited_cells`, so this write
        # updates it directly; no copy back is needed.
        tgt_cell = visited_cells[tgt_step].view(-1)
        tgt_cell[visited_y * cw + visited_x] = 1.
    return visited_cells