| | |
| | |
| |
|
| | |
| | |
| |
|
| |
|
| | import torch |
| | import dataclasses |
| | import torch.nn.functional as F |
| | from dataclasses import dataclass |
| | from typing import Any, Optional |
| |
|
| |
|
@dataclass(eq=False)
class CoTrackerData:
    """
    Container for one sample (or one collated batch) of video tracks data.

    `video`, `trajectory` and `visibility` are required; every other field
    defaults to None. Because of `eq=False` no field-wise `__eq__` is
    generated, so instances compare (and hash) by identity.
    """

    # Required tensors.
    # NOTE(review): expected shapes/dtypes are not visible in this file --
    # confirm against the dataset code that builds these objects.
    video: torch.Tensor
    trajectory: torch.Tensor
    visibility: torch.Tensor

    # Optional extras; None when not provided.
    valid: Optional[torch.Tensor] = None
    segmentation: Optional[torch.Tensor] = None
    # Annotated as a single string, but the collate functions in this module
    # store a list with one name per sample here.
    seq_name: Optional[str] = None
    query_points: Optional[torch.Tensor] = None
| |
|
| |
|
def collate_fn(batch):
    """
    Merge a sequence of CoTrackerData samples into one batched CoTrackerData.

    Args:
        batch: Indexable sequence of samples exposing `video`, `trajectory`,
            `visibility`, `query_points`, `segmentation` and `seq_name`.

    Returns:
        A CoTrackerData whose tensor fields are the per-sample tensors
        stacked along a new leading dimension and whose `seq_name` is the
        list of per-sample names. `valid` is not collated and keeps its
        default.
    """

    def stack_field(name):
        # Stack one attribute across all samples along a new dim 0.
        return torch.stack([getattr(sample, name) for sample in batch], dim=0)

    collated = {
        name: stack_field(name)
        for name in ("video", "trajectory", "visibility")
    }

    # Optional tensors: only the first sample is inspected to decide whether
    # the field is present, so a batch mixing None and tensors for the same
    # field is not handled here.
    for name in ("query_points", "segmentation"):
        present = getattr(batch[0], name) is not None
        collated[name] = stack_field(name) if present else None

    collated["seq_name"] = [sample.seq_name for sample in batch]
    return CoTrackerData(**collated)
| |
|
| |
|
def collate_fn_train(batch):
    """
    Collate `(sample, gotit)` pairs into a batched CoTrackerData for training.

    Args:
        batch: Sequence of 2-tuples `(sample, gotit)`; each sample exposes
            `video`, `trajectory`, `visibility`, `valid` and `seq_name`.
            NOTE(review): `gotit` is presumably a per-sample "loaded
            successfully" flag -- confirm against the dataset code.

    Returns:
        A tuple `(data, gotit)`: `data` is a CoTrackerData with the tensor
        fields stacked along a new leading dimension and `seq_name` as a
        list; `gotit` is the list of second tuple elements in batch order.
        `query_points` and `segmentation` are not collated.
    """
    gotit = [flag for _, flag in batch]
    samples = [sample for sample, _ in batch]

    def stack_field(name):
        # Stack one attribute across all samples along a new dim 0.
        return torch.stack([getattr(s, name) for s in samples], dim=0)

    data = CoTrackerData(
        video=stack_field("video"),
        trajectory=stack_field("trajectory"),
        visibility=stack_field("visibility"),
        valid=stack_field("valid"),
        seq_name=[s.seq_name for s in samples],
    )
    return data, gotit
| |
|
| |
|
def try_to_cuda(t: Any) -> Any:
    """
    Best-effort conversion of `t` to a float value on a cuda device.

    The result of `t.float().cuda()` is returned, so the value is cast via
    `.float()` before being moved -- not merely moved. Inputs for which this
    call chain raises AttributeError (e.g. strings, lists, None) are handed
    back unchanged. Any other exception raised by those calls propagates.

    Args:
        t: Input.

    Returns:
        t_cuda: `t.float().cuda()` when supported, otherwise `t` itself.
    """
    try:
        return t.float().cuda()
    except AttributeError:
        # Object lacks `.float()` / `.cuda()`: leave it as it is.
        return t
| |
|
| |
|
def dataclass_to_cuda_(obj):
    """
    Replace every field of a dataclass instance with its cuda version, in place.

    Each field value is passed through `try_to_cuda` and written back onto
    `obj`; values that `try_to_cuda` cannot convert are left as they are.

    Args:
        obj: Input dataclass instance (mutated by this call).

    Returns:
        obj: The same object that was passed in, after the update.
    """
    for field in dataclasses.fields(obj):
        name = field.name
        converted = try_to_cuda(getattr(obj, name))
        setattr(obj, name, converted)
    return obj
| |
|