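# Web-page residue from the Hugging Face file viewer, kept as comments:
# Image-to-Video
# MotionPro/data/dot_single_video/dot/models/dense_optical_tracking.py
# Uploaded by zzwustc ("Upload folder using huggingface_hub"), commit ef296aa (verified)
# Viewer links: raw / history / blame -- file size 11.9 kB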
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from einops import rearrange, repeat
from .optical_flow import OpticalFlow
from .point_tracking import PointTracker
from dot.utils.torch import get_grid
class DenseOpticalTracker(nn.Module):
    """Dense tracker built from a point tracker and an optical-flow refiner.

    Every public method first asks ``self.point_tracker`` for tracks in
    ``"tracks_at_motion_boundaries"`` mode, rescales their first two channels
    by the input video size, and hands them to ``self.optical_flow_refiner``
    (``"flow_with_tracks_init"`` mode) as source/target points.

    Attributes:
        point_tracker: ``PointTracker`` built from the tracker/estimator files.
        optical_flow_refiner: ``OpticalFlow`` built from the refiner files.
        name: Tracker name and refiner name joined by an underscore.
        resolution: ``[height, width]`` the refiner works at; videos of any
            other size are bilinearly resized before dense refinement.
    """

    def __init__(self,
                 height=512,
                 width=512,
                 tracker_config="configs/cotracker2_patch_4_wind_8.json",
                 tracker_path="checkpoints/movi_f_cotracker2_patch_4_wind_8.pth",
                 estimator_config="configs/raft_patch_8.json",
                 estimator_path="checkpoints/cvo_raft_patch_8.pth",
                 refiner_config="configs/raft_patch_4_alpha.json",
                 refiner_path="checkpoints/movi_f_raft_patch_4_alpha.pth"):
        """Build the two sub-models.

        Args:
            height: Working height passed to both sub-models.
            width: Working width passed to both sub-models.
            tracker_config: Config path forwarded to ``PointTracker``.
            tracker_path: Checkpoint path forwarded to ``PointTracker``.
            estimator_config: Config path forwarded to ``PointTracker``.
            estimator_path: Checkpoint path forwarded to ``PointTracker``.
            refiner_config: Config path forwarded to ``OpticalFlow``.
            refiner_path: Checkpoint path forwarded to ``OpticalFlow``.
        """
        super().__init__()
        # Keep this registration order: it fixes the sub-module order of the Module.
        self.point_tracker = PointTracker(
            height, width, tracker_config, tracker_path, estimator_config, estimator_path
        )
        self.optical_flow_refiner = OpticalFlow(height, width, refiner_config, refiner_path)
        self.resolution = [height, width]
        self.name = "_".join((self.point_tracker.name, self.optical_flow_refiner.name))

    def forward(self, data, mode, **kwargs):
        """Run the method selected by ``mode`` on ``data``.

        Args:
            data: Input dictionary handed unchanged to the selected method.
            mode: One of ``"flow_from_last_to_first_frame"``,
                ``"tracks_for_queries"``,
                ``"tracks_from_first_to_every_other_frame"`` or
                ``"tracks_from_every_cell_in_every_frame"``.
            **kwargs: Extra options forwarded to the selected method.

        Returns:
            Whatever the selected method returns.

        Raises:
            ValueError: If ``mode`` is none of the names above.
        """
        if mode == "flow_from_last_to_first_frame":
            return self.get_flow_from_last_to_first_frame(data, **kwargs)
        if mode == "tracks_for_queries":
            return self.get_tracks_for_queries(data, **kwargs)
        if mode == "tracks_from_first_to_every_other_frame":
            return self.get_tracks_from_first_to_every_other_frame(data, **kwargs)
        if mode == "tracks_from_every_cell_in_every_frame":
            return self.get_tracks_from_every_cell_in_every_frame(data, **kwargs)
        raise ValueError(f"Unknown mode {mode}")

    def _initial_tracks(self, data, h, w, kwargs):
        """Query the point tracker and rescale its first two channels.

        Channel 0 is divided by ``w - 1`` and channel 1 by ``h - 1``; channel 2
        is passed through untouched. ``kwargs`` is taken as a plain dict so that
        caller options can never collide with this helper's own parameters.
        """
        raw = self.point_tracker(data, mode="tracks_at_motion_boundaries", **kwargs)["tracks"]
        return torch.stack([raw[..., 0] / (w - 1), raw[..., 1] / (h - 1), raw[..., 2]], dim=-1)

    def _resize_to_resolution(self, video):
        """Bilinearly resize a ``(B, T, C, h, w)`` video to ``self.resolution``.

        The input is returned as is when it already has the target size.
        """
        B, T, C, h, w = video.shape
        H, W = self.resolution
        if h == H and w == W:
            return video
        frames = video.reshape(B * T, C, h, w)
        frames = F.interpolate(frames, size=(H, W), mode="bilinear")
        return frames.reshape(B, T, C, H, W)

    def _tracks_from_source(self, video, init, feats, grid, src_step, kwargs):
        """Refine dense tracks from one source step to every time step.

        For each target step the refiner's flow is divided by ``W - 1`` /
        ``H - 1`` (in place, on the refiner's own output tensor), added to
        ``grid`` and concatenated with ``alpha``. The source step itself gets
        zero flow and an all-ones ``alpha``.

        Returns:
            The per-target results stacked along dim 1.
        """
        B, T = video.shape[:2]
        H, W = self.resolution
        device = video.device
        src_points = init[:, src_step]
        src_feats = feats[:, src_step]

        per_target = []
        for tgt_step in tqdm(range(T), desc="Refine target step", leave=False):
            if src_step == tgt_step:
                flow = torch.zeros(B, H, W, 2, device=device)
                alpha = torch.ones(B, H, W, device=device)
            else:
                refiner_input = {
                    "src_feats": src_feats,
                    "tgt_feats": feats[:, tgt_step],
                    "src_points": src_points,
                    "tgt_points": init[:, tgt_step],
                }
                pred = self.optical_flow_refiner(
                    refiner_input, mode="flow_with_tracks_init", **kwargs
                )
                flow, alpha = pred["flow"], pred["alpha"]
            flow[..., 0] /= W - 1
            flow[..., 1] /= H - 1
            per_target.append(torch.cat([flow + grid, alpha[..., None]], dim=-1))
        return torch.stack(per_target, dim=1)

    def get_flow_from_last_to_first_frame(self, data, **kwargs):
        """Estimate the flow from the last video frame to the first one.

        Args:
            data: Dictionary with a 5-D ``"video"`` tensor ``(B, T, C, h, w)``.
            **kwargs: Forwarded to the point tracker and to the refiner.

        Returns:
            The refiner's prediction dictionary, extended with the
            ``"src_points"`` and ``"tgt_points"`` that were used as initialisation.
        """
        _, _, _, h, w = data["video"].shape
        init = self._initial_tracks(data, h, w, kwargs)
        # Read the video again only after the tracker ran, as the original code did.
        video = data["video"]
        refiner_input = {
            "src_frame": video[:, -1],
            "tgt_frame": video[:, 0],
            "src_points": init[:, -1],
            "tgt_points": init[:, 0],
        }
        pred = self.optical_flow_refiner(refiner_input, mode="flow_with_tracks_init", **kwargs)
        pred["src_points"] = refiner_input["src_points"]
        pred["tgt_points"] = refiner_input["tgt_points"]
        return pred

    def get_tracks_for_queries(self, data, **kwargs):
        """Track a set of query points through the whole video.

        Args:
            data: Dictionary with ``"video"`` ``(B, T, C, h, w)`` and
                ``"query_points"``, whose last axis is read as
                (time step, y, x) with x/y in input-video pixels.
            **kwargs: Forwarded to the point tracker and to the refiner.

        Returns:
            ``{"tracks": tensor}`` of shape ``(B, T, S, 3)``, where ``S`` is the
            number of queries per batch element; the last axis holds x, y (in
            input-video pixels) and a 0/1 flag.
        """
        video, query_points = data["video"], data["query_points"]
        B, T, _, h, w = video.shape
        S = query_points.size(1)
        H, W = self.resolution

        init = self._initial_tracks(data, h, w, kwargs)
        video = self._resize_to_resolution(video)
        feats = self.optical_flow_refiner({"video": video}, mode="feats", **kwargs)["feats"]
        grid = get_grid(H, W, device=video.device)

        tracks = torch.zeros(B, T, S, 3, device=video.device)
        # Refine once per distinct query time step; all queries sharing it reuse the result.
        src_steps = [int(step) for step in torch.unique(query_points[..., 0])]
        for src_step in tqdm(src_steps, desc="Refine source step", leave=False):
            tracks_from_src = self._tracks_from_source(video, init, feats, grid, src_step, kwargs)
            for b in range(B):
                selected = query_points[b, :, 0] == src_step
                if not torch.any(selected):
                    continue
                points = query_points[b, selected]
                x = points[..., 2] / (w - 1)
                y = points[..., 1] / (h - 1)
                tracks[b, :, selected] = dense_to_sparse_tracks(x, y, tracks_from_src[b], h, w)
        return {"tracks": tracks}

    def get_tracks_from_first_to_every_other_frame(self, data, return_flow=False, **kwargs):
        """Compute dense tracks from the first frame to every frame.

        Args:
            data: Dictionary with a ``"video"`` tensor ``(B, T, C, h, w)``.
            return_flow: When true, return the raw flow instead of ``flow + grid``.
            **kwargs: Forwarded to the point tracker and to the refiner.

        Returns:
            ``{"tracks": tensor}`` with the per-frame results stacked along
            dim 1; the last axis is two flow/position channels followed by
            ``alpha``. Values are at ``self.resolution``, not the input size.
        """
        video = data["video"]
        B, T, _, h, w = video.shape
        H, W = self.resolution

        video = self._resize_to_resolution(video)
        init = self._initial_tracks(data, h, w, kwargs)

        # Scale the grid to pixel units (presumably get_grid is normalised to [0, 1] -- TODO confirm).
        grid = get_grid(H, W, device=video.device)
        grid[..., 0] *= (W - 1)
        grid[..., 1] *= (H - 1)

        src_step = 0
        src_points, src_frame = init[:, src_step], video[:, src_step]
        tracks = []
        for tgt_step in tqdm(range(T), desc="Refine target step", leave=False):
            if src_step == tgt_step:
                flow = torch.zeros(B, H, W, 2, device=video.device)
                alpha = torch.ones(B, H, W, device=video.device)
            else:
                refiner_input = {
                    "src_frame": src_frame,
                    "tgt_frame": video[:, tgt_step],
                    "src_points": src_points,
                    "tgt_points": init[:, tgt_step],
                }
                pred = self.optical_flow_refiner(
                    refiner_input, mode="flow_with_tracks_init", **kwargs
                )
                flow, alpha = pred["flow"], pred["alpha"]
            # Original author's note: flow is the per-pixel motion from the first
            # frame to frame i, grid holds the first-frame pixel coordinates, and
            # alpha is a confidence value.
            positions = flow if return_flow else flow + grid
            tracks.append(torch.cat([positions, alpha[..., None]], dim=-1))
        return {"tracks": torch.stack(tracks, dim=1)}

    def get_tracks_from_every_cell_in_every_frame(self, data, cell_size=1, cell_time_steps=20, **kwargs):
        """Start tracks from every not-yet-visited cell of selected frames.

        The frame is split into ``(h // cell_size) x (w // cell_size)`` cells.
        For each selected source step, tracks are started at the centre of
        every cell no earlier track has visited, and the cells those new tracks
        pass through are marked as visited for all time steps.

        Args:
            data: Dictionary with a ``"video"`` tensor ``(B, T, C, h, w)``.
            cell_size: Cell side length in input-video pixels.
            cell_time_steps: Upper bound on the temporal spacing; the number of
                source steps is ``T // min(T, cell_time_steps)``.
            **kwargs: Forwarded to the point tracker and to the refiner.

        Returns:
            ``{"tracks": list}`` with one tensor per batch element, made by
            concatenating the tracks of all source steps along dim 1.
        """
        video = data["video"]
        B, T, _, h, w = video.shape
        H, W = self.resolution
        ch, cw = h // cell_size, w // cell_size
        ct = min(T, cell_time_steps)

        video = self._resize_to_resolution(video)
        init = self._initial_tracks(data, h, w, kwargs)
        feats = self.optical_flow_refiner({"video": video}, mode="feats", **kwargs)["feats"]
        grid = get_grid(H, W, device=video.device)

        visited_cells = torch.zeros(B, T, ch, cw, device=video.device)
        src_steps = torch.linspace(0, T - 1, T // ct).long()
        tracks = [[] for _ in range(B)]
        for src_step in tqdm(src_steps, desc="Refine source step", leave=False):
            # Nothing left to start from at this step for any batch element.
            if visited_cells[:, src_step].all():
                continue
            tracks_from_src = self._tracks_from_source(video, init, feats, grid, src_step, kwargs)
            for b in range(B):
                src_cell = visited_cells[b, src_step]
                if src_cell.all():
                    continue
                rows, cols = (1 - src_cell).nonzero(as_tuple=True)
                # Query the centre of each unvisited cell, as a fraction of the frame.
                cur_x = (cols + 0.5) / cw
                cur_y = (rows + 0.5) / ch
                cur_tracks = dense_to_sparse_tracks(cur_x, cur_y, tracks_from_src[b], h, w)
                visited_cells[b] = update_visited(visited_cells[b], cur_tracks, h, w, ch, cw)
                tracks[b].append(cur_tracks)
        return {"tracks": [torch.cat(per_batch, dim=1) for per_batch in tracks]}
def dense_to_sparse_tracks(x, y, tracks, height, width):
    """Sample dense per-pixel tracks at a sparse set of query locations.

    Args:
        x: 1-D tensor of query x positions, as fractions of the frame width
            (0 maps to the first column, 1 to the last).
        y: 1-D tensor of query y positions, as fractions of the frame height.
        tracks: Dense tracks of shape ``(T, H, W, C)``; the first three
            channels are rescaled/thresholded below.
        height: Height used to scale channel 1 to pixels (``height - 1``).
        width: Width used to scale channel 0 to pixels (``width - 1``).

    Returns:
        Tensor of shape ``(T, S, C)`` (``S`` = number of queries) with channel
        0 multiplied by ``width - 1``, channel 1 by ``height - 1`` and channel
        2 binarised to 1.0 where the sampled value is positive, else 0.0.
    """
    num_steps = tracks.size(0)

    # grid_sample addresses the image in [-1, 1]; the same queries are used at every step.
    sample_grid = torch.stack([x, y], dim=-1) * 2 - 1
    sample_grid = sample_grid[None, :, None, :].expand(num_steps, -1, 1, -1)

    channels_first = tracks.permute(0, 3, 1, 2)
    sampled = F.grid_sample(channels_first, sample_grid, align_corners=True, mode="bilinear")

    # (T, C, S, 1) -> (T, S, C)
    sparse = sampled[..., 0].permute(0, 2, 1)
    sparse[..., 0] = sparse[..., 0] * (width - 1)
    sparse[..., 1] = sparse[..., 1] * (height - 1)
    sparse[..., 2] = (sparse[..., 2] > 0).float()
    return sparse
def update_visited(visited_cells, tracks, height, width, cell_height, cell_width):
    """Mark the grid cells that visible track points fall into.

    Args:
        visited_cells: Tensor indexed by time step on dim 0; each slice is
            flattened and addressed as ``row * cell_width + col`` (so it is
            expected to be ``(T, cell_height, cell_width)``). Modified in place.
        tracks: Tensor of shape ``(T, S, 3+)`` whose last axis starts with
            x, y (in pixels) and a visibility value; non-zero means visible.
        height: Frame height; y is divided by ``height - 1``.
        width: Frame width; x is divided by ``width - 1``.
        cell_height: Number of cell rows.
        cell_width: Number of cell columns.

    Returns:
        The same ``visited_cells`` tensor, with a 1 written into every cell hit
        by a visible point whose cell index lies inside the grid.
    """
    x_extent = width - 1
    y_extent = height - 1
    for step, points in enumerate(tracks):
        seen = points[points[:, 2].bool()]
        if seen.size(0) == 0:
            continue
        cols = (seen[:, 0] / x_extent * cell_width).floor().long()
        rows = (seen[:, 1] / y_extent * cell_height).floor().long()
        # Points on or beyond the frame border map outside the grid and are skipped.
        outside = (cols < 0) | (cols >= cell_width) | (rows < 0) | (rows >= cell_height)
        inside = ~outside
        flat_index = rows[inside] * cell_width + cols[inside]
        # view(-1) aliases visited_cells[step], so this write updates it directly.
        flat_cells = visited_cells[step].view(-1)
        flat_cells[flat_index] = 1.
    return visited_cells