import logging
import random
from copy import deepcopy

import numpy as np
import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from torchvision.datasets.vision import VisionDataset

from training.dataset.vos_raw_dataset import VOSRawDataset
from training.dataset.vos_sampler import VOSSampler
from training.dataset.vos_segment_loader import JSONSegmentLoader
from training.utils.data_utils import Frame, Object, VideoDatapoint

MAX_RETRIES = 100


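# VOSDataset pairs a raw video dataset with a frame/object sampler and yields
# VideoDatapoint samples ready for the transform pipeline.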
class VOSDataset(VisionDataset):
    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        # Per-video repeat factors; `multiplier` lets an external sampler
        # oversample this dataset relative to others.
        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        self.curr_epoch = 0  # used by samplers/transforms that vary across epochs
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()
                # Sample a video
                video, segment_loader = self.video_dataset.get_video(idx)
                # Sample frames and object ids to be used in a datapoint
                sampled_frms_and_objs = self.sampler.sample(
                    video, segment_loader, epoch=self.curr_epoch
                )
                break  # Successfully loaded video
            except Exception as e:
                if self.training:
                    logging.warning(
                        f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
                    )
                    # Retry with a different, randomly chosen video
                    idx = random.randrange(0, len(self.video_dataset))
                else:
                    # Evaluation should never fail to load a video
                    raise e

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids

        images = []
        rgb_images = load_images(sampled_frames)
        # Iterate over the sampled frames and store their rgb data and
        # objects (segmentation data)
        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size
            images.append(
                Frame(
                    data=rgb_images[frame_idx],
                    objects=[],
                )
            )
            # Load the ground-truth segments associated with the current frame
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(
                    frame.frame_idx, obj_ids=sampled_object_ids
                )
            else:
                segments = segment_loader.load(frame.frame_idx)
            for obj_id in sampled_object_ids:
                if obj_id in segments:
                    assert (
                        segments[obj_id] is not None
                    ), "None targets are not supported"
                    # Segments are uint8 and remain uint8 throughout the transforms
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # No target for this object: either substitute a zero mask
                    # or drop the object entirely
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)

                images[frame_idx].objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        return self._get_datapoint(idx)

    def __len__(self):
        return len(self.video_dataset)


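# Load the RGB data of all sampled frames, caching by image path so that a
# frame sampled more than once is decoded from disk only once.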
def load_images(frames):
    all_images = []
    cache = {}
    for frame in frames:
        if frame.data is None:
            # Load the frame rgb data from file
            path = frame.image_path
            if path in cache:
                # Deep-copy the cached image so transforms never mutate shared data
                all_images.append(deepcopy(all_images[cache[path]]))
                continue
            with g_pathmgr.open(path, "rb") as fopen:
                all_images.append(PILImage.open(fopen).convert("RGB"))
            cache[path] = len(all_images) - 1
        else:
            # The frame rgb data has already been loaded; convert it to a PIL image
            all_images.append(tensor_2_PIL(frame.data))

    return all_images


def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    # Convert a float CHW tensor with values in [0, 1] to an HWC uint8 PIL image
    data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
    data = data.astype(np.uint8)
    return PILImage.fromarray(data)
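

# A minimal usage sketch, not part of the module above: wiring VOSDataset into a
# PyTorch DataLoader. `PNGRawDataset`, `RandomUniformSampler`, and `collate_fn` are
# assumptions here (hypothetical stand-ins for a concrete VOSRawDataset subclass,
# a VOSSampler subclass, and a collate function); substitute whatever the training
# config actually provides.
#
#     from torch.utils.data import DataLoader
#
#     dataset = VOSDataset(
#         transforms=[],                      # callables taking (datapoint, epoch=...)
#         training=True,
#         video_dataset=PNGRawDataset(...),   # hypothetical VOSRawDataset subclass
#         sampler=RandomUniformSampler(...),  # hypothetical VOSSampler subclass
#         multiplier=1,
#     )
#     loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
#     batch = next(iter(loader))              # a collated VideoDatapoint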