| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Dict, List, Optional, Tuple, Union |
| import numpy as np |
| import torch |
|
|
def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], quant_multi: int = 8, **kwargs):
    """Normalise point tracks and resample them along the time axis.

    Args:
        tracks_np: 4-D array whose last axis holds ``(x, y, visibility)``.
            Channels beyond the third are ignored. The time axis is expected
            first. If axis 1 has length 121, the first two axes are swapped
            (treated as a track-major layout).
            NOTE(review): that heuristic is ambiguous when both leading axes
            have length 121 -- confirm the expected layout with callers.
        frame_size: Pair subtracted (halved) from ``(x, y)`` to centre the
            coordinates; presumably ``(width, height)`` -- TODO confirm.
            The smaller value is the normalisation scale.
        quant_multi: Not used by this function; kept for interface
            compatibility. NOTE(review): no de-quantisation is applied, so
            coordinates are taken as-is.
        **kwargs: Ignored.

    Returns:
        Float tensor of shape ``(1 + (2 * T - 1) // 3, P, 4)``, where ``T`` is
        the number of input frames and ``P`` is the product of the two middle
        axes. For ``T == 121`` this is 81 frames. The last axis is
        ``(time, x, y, visibility)``:

        - ``time`` is a per-frame stamp in ``[-1, 1]``.
        - ``x`` and ``y`` are centred on the frame and scaled so that the
          short edge spans ``[-1, 1]``.
        - ``visibility`` is mapped via ``v * 2 - 1``.
    """
    tracks = torch.from_numpy(tracks_np).float()

    # Layout heuristic: a length-121 second axis means time is not the first axis.
    if tracks.shape[1] == 121:
        tracks = torch.permute(tracks, (1, 0, 2, 3))

    num_frames = tracks.shape[0]
    tracks, visibles = tracks[..., :2], tracks[..., 2:3]
    short_edge = min(*frame_size)

    # Centre on the frame middle, then scale so the short edge spans [-1, 1].
    tracks = tracks - torch.tensor([*frame_size]).type_as(tracks) / 2
    tracks = tracks / short_edge * 2

    # Map visibility v -> 2v - 1 (so {0, 1} becomes {-1, 1}).
    visibles = visibles * 2 - 1

    # Per-frame time stamp in [-1, 1], broadcast over every point.
    trange = torch.linspace(-1, 1, num_frames).view(-1, 1, 1, 1).expand(*visibles.shape)

    # Merge the two middle axes into a single point axis. The frame count comes
    # from the data rather than a hard-coded 121, so other clip lengths are not
    # silently scrambled (or rejected) by the reshape.
    out_ = torch.cat([trange, tracks, visibles], dim=-1).view(num_frames, -1, 4)
    out_0 = out_[:1]
    out_l = out_[1:]
    # Keep frame 0, then resample the remaining frames by 2/3: duplicate each
    # frame and take every third one starting at index 1
    # (for T == 121: 120 -> 240 -> 80 frames).
    out_l = torch.repeat_interleave(out_l, 2, dim=0)[1::3]
    return torch.cat([out_0, out_l], dim=0)
|
|