| | |
| | |
| |
|
| | |
| | |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from typing import Tuple |
| |
|
| | from cotracker.models.core.cotracker.cotracker import CoTracker2 |
| | from cotracker.models.core.model_utils import get_points_on_a_grid |
| |
|
| |
|
class EvaluationPredictor(torch.nn.Module):
    """Evaluation wrapper that runs a tracking model on a resized video.

    The input video is bilinearly resized to ``interp_shape`` before it is
    handed to the wrapped model; query coordinates are rescaled into the
    resized frame and the predicted trajectories are rescaled back to the
    original resolution before being returned.

    Two modes are supported:

    * ``single_point=True``: every query point is tracked in its own model
      call, on the clip cropped to start at that query's frame, optionally
      together with extra "support" query points (a local grid around the
      point and a global grid over the frame).
    * ``single_point=False``: all queries are tracked in one model call,
      optionally together with a global support grid.

    Args:
        cotracker_model: The model to evaluate. It is switched to eval mode
            and is called as ``model(video=..., queries=..., iters=...)``,
            returning a 3-tuple whose first two items are used as the
            trajectories and visibilities.
        interp_shape: ``(height, width)`` the video is resized to.
        grid_size: Size passed to ``get_points_on_a_grid`` for the global
            support grid; ``0`` (or less) disables it.
        local_grid_size: Size passed to ``get_points_on_a_grid`` for the
            local support grid used in single-point mode; ``0`` (or less)
            disables it.
        single_point: Selects the per-point mode described above.
        n_iters: Forwarded to the model as ``iters``.
    """

    def __init__(
        self,
        cotracker_model: "CoTracker2",
        interp_shape: Tuple[int, int] = (384, 512),
        grid_size: int = 5,
        local_grid_size: int = 8,
        single_point: bool = True,
        n_iters: int = 6,
    ) -> None:
        super().__init__()
        self.grid_size = grid_size
        self.local_grid_size = local_grid_size
        self.single_point = single_point
        self.interp_shape = interp_shape
        self.n_iters = n_iters

        self.model = cotracker_model
        self.model.eval()

    @staticmethod
    def _as_frame_zero_queries(xy, device):
        """Prepend a zero column (frame index 0) to grid points along dim 2."""
        # NOTE(review): presumably ``xy`` is (1, n, 2) as returned by
        # get_points_on_a_grid -- confirm against that helper.
        return torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)

    def forward(self, video, queries):
        """Track ``queries`` through ``video``.

        Args:
            video: Tensor of shape ``(B, T, C, H, W)``.
            queries: Tensor of shape ``(B, N, 3)``; the last axis holds
                ``(frame index, x, y)`` with x/y in original-resolution
                pixel coordinates. The caller's tensor is not modified.

        Returns:
            ``(traj_e, vis_e)``. In single-point mode these have shapes
            ``(B, T, N, 2)`` and ``(B, T, N)``; frames before a query's own
            frame are left at zero. Otherwise they are the model outputs.
            Trajectory x/y are expressed at the original ``(H, W)`` scale.
        """
        # Work on a copy: the coordinates are rescaled in place below.
        queries = queries.clone()
        B, T, C, H, W = video.shape
        B, N, D = queries.shape

        assert D == 3, "queries must have shape (B, N, 3)"

        interp_h, interp_w = self.interp_shape

        # Resize every frame; keep the input's channel count when restoring
        # the (B, T, ...) layout instead of assuming exactly 3 channels.
        video = video.reshape(B * T, C, H, W)
        video = F.interpolate(
            video, (interp_h, interp_w), mode="bilinear", align_corners=True
        )
        video = video.reshape(B, T, C, interp_h, interp_w)

        device = video.device

        # Map query x/y from the original resolution to the resized frame.
        queries[:, :, 1] *= (interp_w - 1) / (W - 1)
        queries[:, :, 2] *= (interp_h - 1) / (H - 1)

        if self.single_point:
            traj_e = torch.zeros((B, T, N, 2), device=device)
            vis_e = torch.zeros((B, T, N), device=device)
            for pind in range(N):
                query = queries[:, pind : pind + 1]

                # The start frame is read from batch element 0 and applied
                # to the whole batch.
                t = int(query[0, 0, 0].item())

                traj_e_pind, vis_e_pind = self._process_one_point(video, query)
                # Only the first returned track belongs to the query itself;
                # any further tracks come from support points.
                traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
                vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
        else:
            if self.grid_size > 0:
                xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
                queries = torch.cat(
                    [queries, self._as_frame_zero_queries(xy, device)], dim=1
                )

            # NOTE(review): when the support grid is enabled the model is
            # given more than N queries and its output is returned as-is;
            # presumably callers slice off the extra tracks -- confirm.
            traj_e, vis_e, __ = self.model(
                video=video,
                queries=queries,
                iters=self.n_iters,
            )

        # Map predicted x/y back to the original resolution.
        traj_e[:, :, :, 0] *= (W - 1) / (interp_w - 1)
        traj_e[:, :, :, 1] *= (H - 1) / (interp_h - 1)
        return traj_e, vis_e

    def _process_one_point(self, video, query):
        """Track a single query on the clip starting at the query's frame.

        Args:
            video: Resized video of shape ``(B, T, C, H, W)``.
            query: Tensor of shape ``(B, 1, 3)`` holding
                ``(frame index, x, y)``. The caller's tensor is not modified.

        Returns:
            The first two outputs of the model for ``video[:, t:]``, where
            ``t`` is the query's frame index (taken from batch element 0).
            The query itself is the first entry along the query axis.
        """
        t = int(query[0, 0, 0].item())
        device = query.device

        # Copy so the frame-index reset below never leaks into the caller's
        # tensor (without support grids ``query`` would otherwise be the
        # caller's own storage).
        query = query.clone()

        if self.local_grid_size > 0:
            # Grid built from a fixed (50, 50) extent and the query's (y, x).
            # NOTE(review): presumably this is a 50x50 px window centred on
            # the query point -- confirm against get_points_on_a_grid.
            xy_target = get_points_on_a_grid(
                self.local_grid_size,
                (50, 50),
                [query[0, 0, 2].item(), query[0, 0, 1].item()],
            )
            query = torch.cat(
                [query, self._as_frame_zero_queries(xy_target, device)], dim=1
            )

        if self.grid_size > 0:
            xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
            query = torch.cat(
                [query, self._as_frame_zero_queries(xy, device)], dim=1
            )

        # The clip is cropped to start at frame t for every batch element, so
        # the tracked point's frame index becomes 0 for the whole batch.
        query[:, 0, 0] = 0
        traj_e_pind, vis_e_pind, __ = self.model(
            video=video[:, t:], queries=query, iters=self.n_iters
        )

        return traj_e_pind, vis_e_pind
| |
|