| | import warnings
|
| | import torch
|
| |
|
# torch3d is an optional backend: when its module cannot be found the name is
# bound to None so the rest of the file can test for it instead of failing at
# import time.
try:
    from dot.utils import torch3d
except ModuleNotFoundError:
    torch3d = None

# True only when the optional backend was imported successfully.
TORCH3D_AVAILABLE = bool(torch3d)
|
| |
|
| |
|
def interpolate(src_points, tgt_points, grid, version="torch3d"):
    """Build a dense flow field from sparse source/target point correspondences.

    Every grid location is assigned the flow (target position minus source
    position) and the target alpha of its nearest source point.

    Args:
        src_points: Tensor of shape (B, S, C) with C >= 3. Channels 0-1 are
            read as the 2D position and channel 2 as an alpha value; points
            whose alpha is 0 are never selected as nearest neighbour.
        tgt_points: Tensor with the same layout as ``src_points``; entry ``i``
            is the correspondence of source point ``i``.
        grid: Tensor of shape (H, W, 2) holding the query position of every
            grid cell, in the same coordinate system as the points.
            NOTE(review): the final scaling by (W - 1, H - 1) suggests
            coordinates normalised to [0, 1] in (x, y) order - confirm
            against callers.
        version: ``"torch"`` for the dense pure-PyTorch nearest-neighbour
            search, or ``"torch3d"`` for the ``torch3d`` KNN backend. When
            ``"torch3d"`` is requested but the backend is not installed, a
            warning is issued and the pure-PyTorch search is used instead.

    Returns:
        A tuple ``(flow, alpha)`` where ``flow`` has shape (B, H, W, 2), with
        the x component multiplied by ``W - 1`` and the y component by
        ``H - 1``, and ``alpha`` has shape (B, H, W) and holds the target
        alpha of the selected point.

    Raises:
        ValueError: If ``version`` is neither ``"torch"`` nor ``"torch3d"``.
    """
    # Fail early with a clear message; otherwise an unknown version skips both
    # search branches and crashes later on an unbound `idx`.
    if version not in ("torch", "torch3d"):
        raise ValueError(
            f"Unknown interpolation version {version!r}; expected 'torch' or 'torch3d'."
        )

    # The 3-way unpack also enforces that src_points is a 3-D tensor.
    B, _, _ = src_points.shape
    H, W, _ = grid.shape

    # For each point in the regular grid, find the closest source point.
    # reshape (rather than view) also accepts non-contiguous grids.
    grid = grid.reshape(1, H * W, 2).expand(B, -1, -1)
    src_pos, src_alpha = src_points[..., :2], src_points[..., 2]

    use_torch3d = version == "torch3d"
    if use_torch3d and not TORCH3D_AVAILABLE:
        warnings.warn(
            "Torch3D is not available. For optimal speed and memory consumption, consider setting it up.",
            stacklevel=2,
        )
        use_torch3d = False

    if use_torch3d:
        # Keep only the points with non-zero alpha, then re-pad them per batch
        # element so that the KNN indices address the compacted tensors.
        # tgt_points is filtered with the *source* mask so that source/target
        # pairs stay aligned.
        # NOTE(review): `lengths` is a sum of alpha values, so this assumes
        # alpha is a 0/1 indicator - confirm.
        visible = src_alpha.bool()
        src_pos_packed = src_pos[visible]
        tgt_points_packed = tgt_points[visible]
        lengths = src_alpha.sum(dim=1).long()
        max_length = int(lengths.max())
        # Start offset of every batch element inside the packed tensors
        # (exclusive cumulative sum of the lengths).
        cum_lengths = lengths.cumsum(dim=0)
        cum_lengths = torch.cat([torch.zeros_like(cum_lengths[:1]), cum_lengths[:-1]])
        src_pos = torch3d.packed_to_padded(src_pos_packed, cum_lengths, max_length)
        tgt_points = torch3d.packed_to_padded(tgt_points_packed, cum_lengths, max_length)
        _, idx, _ = torch3d.knn_points(grid, src_pos, lengths2=lengths, return_nn=False)
        idx = idx.view(B, H * W, 1)
    else:
        # Squared distances via |g|^2 + |s|^2 - 2 s.g, giving shape (B, S, H*W).
        dis = (
            (grid ** 2).sum(-1)[:, None]
            + (src_pos ** 2).sum(-1)[:, :, None]
            - 2 * src_pos @ grid.permute(0, 2, 1)
        )
        # Points with zero alpha must never win the nearest-neighbour search.
        dis[src_alpha == 0] = float("inf")
        _, idx = dis.min(dim=1)
        idx = idx.view(B, H * W, 1)

    # Use correspondences between source and target points to initialize the flow.
    tgt_pos, tgt_alpha = tgt_points[..., :2], tgt_points[..., 2]
    flow = tgt_pos - src_pos
    flow = torch.cat([flow, tgt_alpha[..., None]], dim=-1)
    # Pick, for every grid cell, the (dx, dy, alpha) triple of its nearest point.
    flow = flow.gather(dim=1, index=idx.expand(-1, -1, flow.size(-1)))
    flow = flow.view(B, H, W, -1)
    flow, alpha = flow[..., :2], flow[..., 2]
    # Rescale each flow component by the grid extent along its axis.
    flow[..., 0] = flow[..., 0] * (W - 1)
    flow[..., 1] = flow[..., 1] * (H - 1)
    return flow, alpha