|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| Misc functions, including distributed helpers.
|
|
|
| Mostly copy-paste from torchvision references.
|
| """
|
|
|
| from dataclasses import dataclass
|
| from typing import List, Optional, Tuple, Union
|
|
|
| import torch
|
|
|
| from PIL import Image as PILImage
|
| from tensordict import tensorclass
|
|
|
|
|
@tensorclass
class BatchedVideoMetaData:
    """
    Per-object bookkeeping that travels alongside a batch of videos.

    Attributes:
        unique_objects_identifier: Integer tensor whose last dimension holds the
            triple (video_id, obj_id, frame_id) identifying one object.
            `collate_fn` in this file builds it with shape [T, O, 3].
        frame_orig_size: Integer tensor whose last dimension holds the original
            size of the video the object comes from. `collate_fn` in this file
            builds it with shape [T, O, 2].
    """

    unique_objects_identifier: torch.LongTensor
    frame_orig_size: torch.LongTensor
|
|
|
|
|
@tensorclass
class BatchedVideoDatapoint:
    """
    A collated batch of videos together with its object annotations.

    Attributes:
        img_batch: [T, B, C, H, W] image tensor; T is the number of frames per
            video and B the number of videos in the batch.
        obj_to_frame_idx: [T, O, 2] tensor locating, for each of the O objects,
            the entry of `img_batch` it belongs to. The last dimension is
            (frame index, video index).
        masks: [T, O, H, W] boolean mask per object.
        metadata: BatchedVideoMetaData instance describing the batch.
        dict_key: String key used to identify the batch.
    """

    img_batch: torch.FloatTensor
    obj_to_frame_idx: torch.IntTensor
    masks: torch.BoolTensor
    metadata: BatchedVideoMetaData

    dict_key: str

    def pin_memory(self, device=None):
        """
        Run `torch.Tensor.pin_memory` over this object via `self.apply`,
        forwarding `device` as a keyword argument to `apply`.
        """
        pinned = self.apply(torch.Tensor.pin_memory, device=device)
        return pinned

    @property
    def num_frames(self) -> int:
        """
        Number of frames per video, i.e. the leading batch dimension.
        """
        leading_dim = self.batch_size[0]
        return leading_dim

    @property
    def num_videos(self) -> int:
        """
        Number of videos in the batch (second dimension of `img_batch`).
        """
        return self.img_batch.size(1)

    @property
    def flat_obj_to_img_idx(self) -> torch.IntTensor:
        """
        Object-to-image index expressed for the flattened image batch.

        Each (frame, video) pair is mapped to `video * num_frames + frame`,
        which is the row of that image in `flat_img_batch` (video-major order).
        """
        # Last dimension is (frame index, video index).
        t_index, b_index = torch.unbind(self.obj_to_frame_idx, dim=-1)
        return b_index * self.num_frames + t_index

    @property
    def flat_img_batch(self) -> torch.FloatTensor:
        """
        Image batch flattened to shape [(B*T), C, H, W].
        """
        # Bring the video dimension first so that flattening is video-major.
        videos_first = self.img_batch.transpose(0, 1)
        return torch.flatten(videos_first, 0, 1)
|
|
|
|
|
@dataclass
class Object:
    """Annotation of a single object attached to a single frame."""

    # Identifier of the object; collate_fn stores it as the `obj_id` entry of
    # the (video_id, obj_id, frame_id) triple.
    object_id: int

    # Frame index recorded for this annotation; collate_fn stores it as the
    # `frame_id` entry of the same triple.
    frame_index: int
    # Segmentation of the object. collate_fn calls `.to(torch.bool)` on it, so
    # it must be a tensor by the time batches are collated.
    segment: Union[torch.Tensor, dict]
|
|
|
|
|
@dataclass
class Frame:
    """One frame of a video together with the objects annotated on it."""

    # Image content, either as a tensor or as a PIL image. collate_fn stacks
    # this field with torch.stack, so it must be a tensor at collation time.
    data: Union[torch.Tensor, PILImage.Image]
    # Object annotations that belong to this frame.
    objects: List[Object]
|
|
|
|
|
@dataclass
class VideoDatapoint:
    """An image or video sample bundled with all of its annotations."""

    # Frames of the sample, in the order they are iterated by collate_fn.
    frames: List[Frame]
    # Identifier of the video; ends up as the `video_id` entry of the
    # per-object identifier triple built by collate_fn.
    video_id: int
    # Original size of the sample as a pair of ints (axis order is not fixed
    # by this file -- confirm against the dataset code that fills it in).
    size: Tuple[int, int]
|
|
|
|
|
def collate_fn(
    batch: List[VideoDatapoint],
    dict_key,
) -> BatchedVideoDatapoint:
    """
    Collate a list of videos into a single time-major BatchedVideoDatapoint.

    Args:
        batch: A list of VideoDatapoint instances.
        dict_key (str): A string key used to identify the batch.

    Returns:
        A BatchedVideoDatapoint with batch_size [T] holding the [T, B, C, H, W]
        images, the per-step object masks, the (frame, video) index of every
        object, and the matching metadata.

    Notes:
        Everything is assembled with torch.stack, so all videos must have the
        same number of equally shaped frames, every time step must contain at
        least one object, and every time step must contain the same number of
        objects. Each object's `segment` must be a tensor.
    """

    def _stack_per_step(per_step):
        # Stack the items of each time step, then stack the steps: [T, O, ...].
        return torch.stack(
            [torch.stack(items, dim=0) for items in per_step], dim=0
        )

    # One [T, C, H, W] tensor per video, stacked to [B, T, C, H, W] and then
    # reordered so that time is the leading dimension.
    per_video_frames = [
        torch.stack([frame.data for frame in video.frames], dim=0)
        for video in batch
    ]
    img_batch = torch.stack(per_video_frames, dim=0).permute((1, 0, 2, 3, 4))
    T = img_batch.shape[0]

    # Per-time-step accumulators; each inner list receives one entry per object.
    ids_per_step = [[] for _ in range(T)]
    sizes_per_step = [[] for _ in range(T)]
    masks_per_step = [[] for _ in range(T)]
    frame_idx_per_step = [[] for _ in range(T)]

    for video_idx, video in enumerate(batch):
        source_video_id = video.video_id
        source_frame_size = video.size
        for t, frame in enumerate(video.frames):
            for obj in frame.objects:
                # (frame index, video index) into the [T, B, ...] image batch.
                frame_idx_per_step[t].append(
                    torch.tensor([t, video_idx], dtype=torch.int)
                )
                masks_per_step[t].append(obj.segment.to(torch.bool))
                ids_per_step[t].append(
                    torch.tensor(
                        [source_video_id, obj.object_id, obj.frame_index]
                    )
                )
                sizes_per_step[t].append(torch.tensor(source_frame_size))

    obj_to_frame_idx = _stack_per_step(frame_idx_per_step)
    masks = _stack_per_step(masks_per_step)
    objects_identifier = _stack_per_step(ids_per_step)
    frame_orig_size = _stack_per_step(sizes_per_step)

    metadata = BatchedVideoMetaData(
        unique_objects_identifier=objects_identifier,
        frame_orig_size=frame_orig_size,
    )
    return BatchedVideoDatapoint(
        img_batch=img_batch,
        obj_to_frame_idx=obj_to_frame_idx,
        masks=masks,
        metadata=metadata,
        dict_key=dict_key,
        batch_size=[T],
    )
|
|
|