|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Optional, Tuple, Union |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
    """Gather slices of ``target`` along ``dim`` using ``ind``, broadcasting as needed.

    Leading axes of ``target`` (those before ``dim``) that have size 1 are
    expanded to the matching size of ``ind``. ``ind`` receives trailing
    singleton axes and is expanded over every axis of ``target`` after
    ``dim``, so each index selects a whole trailing slice.

    :param target: ``[... (k or 1), n, ...]`` tensor; axis ``dim`` is indexed.
    :param ind: integer index tensor ``[... (k), M]``. It is expected to have
        ``dim + 1`` axes; the assert only checks that it has more than ``dim``.
    :param dim: axis of ``target`` to index.
    :return: ``[... (k), M, ...]`` -- the shape of ``ind`` followed by the
        axes of ``target`` that come after ``dim``.
    """
    assert (
        len(ind.shape) > dim
    ), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape))

    # Broadcast size-1 leading axes of target to ind's sizes; keep the rest (-1).
    leading = [ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)]
    target = target.expand(*leading, *([-1] * (target.dim() - dim)))

    index = ind
    n_trailing = target.dim() - dim - 1
    if n_trailing > 0:
        # Append one singleton axis per trailing axis of target, then stretch
        # them so gather picks complete trailing slices.
        index = index[(...,) + (None,) * n_trailing]
        index = index.expand(*([-1] * (dim + 1)), *target.shape[dim + 1:])

    return torch.gather(target, dim=dim, index=index)
|
|
|
|
|
|
|
|
def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor):
    """Blend per-vertex attributes using selected indices and their weights.

    For every position of ``vert_assign``, the ``M`` indexed rows of
    ``vert_attr`` are gathered, scaled by the matching entries of ``weight``,
    and summed over the ``M`` axis.

    :param vert_attr: ``[n, d]`` attributes shared by every position, or
        ``[b, n, d]`` with a leading batch axis (color or feature per vertex).
    :param weight: ``[b(optional), w, h, M]`` weight of each selected vertex.
    :param vert_assign: ``[b(optional), w, h, M]`` indices into the ``n`` axis
        of ``vert_attr``; cast to ``long`` before gathering.
    :return: ``[b(optional), w, h, d]`` weighted sum of the selected attributes.
    """
    index_dim = vert_assign.dim() - 1

    if vert_attr.dim() == 2:
        # Unbatched table: prepend singleton axes so ind_sel broadcasts it over
        # every leading axis of vert_assign.
        assert vert_attr.shape[0] > vert_assign.max()
        lookup_shape = [1] * index_dim + list(vert_attr.shape)
    else:
        # Batched table: keep the batch axis and add singletons for the
        # remaining leading axes of vert_assign.
        assert vert_attr.shape[1] > vert_assign.max()
        lookup_shape = [vert_attr.shape[0]] + [1] * (index_dim - 1) + list(vert_attr.shape[1:])

    gathered = ind_sel(vert_attr.reshape(lookup_shape), vert_assign.long(), dim=index_dim)
    # gathered is [..., M, d]; give weight a trailing axis so it broadcasts over d.
    return (gathered * weight.unsqueeze(-1)).sum(dim=-2)
|
|
|
|
|
|
|
|
def patch_motion(
    tracks: torch.FloatTensor,
    vid: torch.FloatTensor,
    temperature: float = 220.0,
    vae_divide: tuple = (4, 16),
    topk: int = 2,
    temporal_stride: int = 4,
):
    """Paint first-frame features along point tracks into a latent video.

    Each tracked point's feature is bilinearly sampled from frame 0 of ``vid``.
    For every later frame, those features are splatted around the track's
    visibility-averaged position with a Gaussian falloff and blended over the
    existing content. The per-location blend weight is returned in the
    leading mask channels.

    :param tracks: point tracks whose last axis holds 4 values, split as
        ``[discarded, x, y, visibility]``; axis 2 indexes the ``N`` points.
        The reshape below requires
        ``tracks.shape[0] * (tracks.shape[1] - 1) == temporal_stride * (T - 1)``,
        and ``grid_sample`` requires ``tracks.shape[0] == 1``. In practice the
        shape is therefore ``(1, 1 + temporal_stride * (T - 1), N, 4)``.
        ``x`` / ``y`` are expected in units of the shorter image side, i.e.
        within ``[-W/min(H, W), W/min(H, W)]`` and
        ``[-H/min(H, W), H/min(H, W)]`` (values outside are clamped for sampling).
    :param vid: 4-D tensor ``(C, T, H, W)``. The first ``vae_divide[0]``
        channels are treated as mask channels (discarded and rebuilt); the
        remaining channels are the features that get patched.
    :param temperature: sharpness of the falloff ``exp(-d**2 * temperature)``.
    :param vae_divide: only ``vae_divide[0]`` (number of mask channels) is used.
    :param topk: number of strongest tracks blended at each location.
    :param temporal_stride: number of consecutive track frames (after frame 0)
        averaged into each of the ``T - 1`` later video frames. Defaults to 4,
        the value that was previously hard-coded.
    :return: ``(C, T, H, W)`` tensor with the same channel layout as ``vid``:
        ``vae_divide[0]`` identical mask channels followed by the patched
        features. The mask is 1 at frame 0 and is the unclamped sum of the
        top-k weights afterwards, so it can exceed 1 when ``topk > 1``.
        NOTE(review): the result dtype follows torch type promotion between
        ``vid`` and ``tracks`` and is not cast back to ``vid.dtype`` -- confirm
        callers handle this.
    """
    with torch.no_grad():
        _, T, H, W = vid.shape
        N = tracks.shape[2]
        # Split the last axis into [discarded | (x, y) | visibility].
        _, tracks, visible = torch.split(tracks, [1, 2, 1], dim=-1)

        # Half-extents of the frame in units of the shorter side; dividing by
        # them maps track coordinates into grid_sample's [-1, 1] range.
        short_side = min(H, W)
        x_extent = W / short_side
        y_extent = H / short_side
        tracks_n = tracks / torch.tensor([x_extent, y_extent], device=tracks.device)
        tracks_n = tracks_n.clamp(-1, 1)
        visible = visible.clamp(0, 1)

        # (H, W, 2) grid of (x, y) coordinates in the same units as `tracks`,
        # created directly on the target device rather than on the host and
        # then copied.
        xx = torch.linspace(-x_extent, x_extent, W, device=tracks.device)
        yy = torch.linspace(-y_extent, y_extent, H, device=tracks.device)
        grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1)

        # Drop track frame 0, then average each run of `temporal_stride`
        # consecutive track frames into one position per later video frame,
        # weighted by visibility. The 1e-5 avoids division by zero when a point
        # is invisible for the whole run.
        tracks_pad = tracks[:, 1:]
        visible_pad = visible[:, 1:]
        visible_align = visible_pad.view(
            T - 1, temporal_stride, *visible_pad.shape[2:]
        ).sum(1)  # (T - 1, N, 1)
        tracks_align = (tracks_pad * visible_pad).view(
            T - 1, temporal_stride, *tracks_pad.shape[2:]
        ).sum(1) / (visible_align + 1e-5)  # (T - 1, N, 2)

        # Squared distance from every grid location to every track: (T - 1, H, W, N).
        dist_ = (tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1)
        # Gaussian falloff, scaled by min(summed visibility, 1): zero for points
        # with no visibility in the run.
        weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view(
            T - 1, 1, 1, N
        )
        # Keep only the `topk` strongest tracks per location (all of them if N < topk).
        vert_weight, vert_index = torch.topk(
            weight, k=min(topk, weight.shape[-1]), dim=-1
        )

        # Sample the non-mask channels of video frame 0 at each track's
        # frame-0 position, giving (N, C_feat) after the reshapes below.
        point_feature = torch.nn.functional.grid_sample(
            vid[vae_divide[0]:].permute(1, 0, 2, 3)[:1],
            tracks_n[:, :1].type(vid.dtype),
            mode="bilinear",
            padding_mode="zeros",
            align_corners=False,
        )
        point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0)

        # Weighted sum of the selected tracks' features: (C_feat, T - 1, H, W).
        out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2)
        out_weight = vert_weight.sum(-1)  # (T - 1, H, W)

        # out_feature is already weight-scaled; fade the original content by
        # the complementary (clamped) weight.
        mix_feature = out_feature + vid[vae_divide[0]:, 1:] * (1 - out_weight.clamp(0, 1))

        # Frame 0 passes through unchanged and gets a full mask.
        out_feature_full = torch.cat([vid[vae_divide[0]:, :1], mix_feature], dim=1)
        out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0)
        return torch.cat(
            [out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full],
            dim=0,
        )
|
|
|