# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import random
from copy import deepcopy

import numpy as np
import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from torchvision.datasets.vision import VisionDataset

from training.dataset.vos_raw_dataset import VOSRawDataset
from training.dataset.vos_sampler import VOSSampler
from training.dataset.vos_segment_loader import JSONSegmentLoader
from training.utils.data_utils import Frame, Object, VideoDatapoint

MAX_RETRIES = 100


class VOSDataset(VisionDataset):
def __init__(
self,
transforms,
training: bool,
video_dataset: VOSRawDataset,
sampler: VOSSampler,
multiplier: int,
always_target=True,
target_segments_available=True,
):
self._transforms = transforms
self.training = training
self.video_dataset = video_dataset
self.sampler = sampler
self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
self.repeat_factors *= multiplier
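        # `multiplier` scales how often each video is (re)sampled per epoch;
        # these weights are intended to be consumed by a repeat-factor-aware
        # batch sampler rather than used inside this class.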
print(f"Raw dataset length = {len(self.video_dataset)}")
self.curr_epoch = 0 # Used in case data loader behavior changes across epochs
self.always_target = always_target
self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
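        # Sample a video plus frames/objects; during training, a failed load
        # retries with a different randomly chosen video, up to MAX_RETRIES.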
for retry in range(MAX_RETRIES):
try:
if isinstance(idx, torch.Tensor):
idx = idx.item()
# sample a video
video, segment_loader = self.video_dataset.get_video(idx)
# sample frames and object indices to be used in a datapoint
sampled_frms_and_objs = self.sampler.sample(
video, segment_loader, epoch=self.curr_epoch
)
                break  # Successfully loaded video
except Exception as e:
if self.training:
logging.warning(
f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
)
idx = random.randrange(0, len(self.video_dataset))
else:
# Shouldn't fail to load a val video
raise e
datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
for transform in self._transforms:
datapoint = transform(datapoint, epoch=self.curr_epoch)
return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms.
        """
sampled_frames = sampled_frms_and_objs.frames
sampled_object_ids = sampled_frms_and_objs.object_ids
images = []
rgb_images = load_images(sampled_frames)
# Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
for frame_idx, frame in enumerate(sampled_frames):
w, h = rgb_images[frame_idx].size
images.append(
Frame(
data=rgb_images[frame_idx],
objects=[],
)
)
# We load the gt segments associated with the current frame
if isinstance(segment_loader, JSONSegmentLoader):
segments = segment_loader.load(
frame.frame_idx, obj_ids=sampled_object_ids
)
else:
segments = segment_loader.load(frame.frame_idx)
for obj_id in sampled_object_ids:
# Extract the segment
if obj_id in segments:
assert (
segments[obj_id] is not None
), "None targets are not supported"
                    # segment is uint8 and remains uint8 throughout the transforms;
                    # the mask is kept paired with its granularity value below.
                    segment = segments[obj_id].to(torch.uint8)
else:
# There is no target, we either use a zero mask target or drop this object
if not self.always_target:
continue
segment = torch.zeros(h, w, dtype=torch.uint8)
                # Prefer a sampler-provided, interval-sampled granularity when
                # available (assumed layout: {frame_idx: {obj_id: granularity}});
                # otherwise fall back to the granularity stored with the segments.
                per_frame = getattr(sampled_frms_and_objs, "sampled_granularities", None)
                if per_frame is not None:
                    per_frame = per_frame.get(frame.frame_idx)
                if per_frame is not None and obj_id in per_frame:
                    gran_val = per_frame[obj_id]
                else:
                    gran_val = segments.get_granularity(key=obj_id)
images[frame_idx].objects.append(
Object(
object_id=obj_id,
frame_index=frame.frame_idx,
segment=segment,
                        granularity=gran_val,
)
)
return VideoDatapoint(
frames=images,
video_id=video.video_id,
size=(h, w),
)

    def __getitem__(self, idx):
return self._get_datapoint(idx)

    def __len__(self):
return len(self.video_dataset)


def load_images(frames):
all_images = []
cache = {}
for frame in frames:
if frame.data is None:
# Load the frame rgb data from file
path = frame.image_path
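            # Reuse previously loaded frames via the cache, deep-copying so
            # duplicate frames don't alias the same underlying PIL image.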
if path in cache:
all_images.append(deepcopy(all_images[cache[path]]))
continue
with g_pathmgr.open(path, "rb") as fopen:
all_images.append(PILImage.open(fopen).convert("RGB"))
cache[path] = len(all_images) - 1
else:
# The frame rgb data has already been loaded
# Convert it to a PILImage
all_images.append(tensor_2_PIL(frame.data))
return all_images


def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    # Expects a float CHW tensor scaled to [0, 1]; returns an HWC uint8 PIL image.
    data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
    data = data.astype(np.uint8)
    return PILImage.fromarray(data)
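

# Minimal usage sketch. This assumes PNGRawDataset and RandomUniformSampler are
# available from the imported training.dataset modules (as in the upstream
# SAM 2 training code); the folder paths and argument values are placeholders,
# and the stock sampler may not attach the optional `sampled_granularities`
# field that `construct` looks for.
#
#   from training.dataset.vos_raw_dataset import PNGRawDataset
#   from training.dataset.vos_sampler import RandomUniformSampler
#
#   raw_dataset = PNGRawDataset(
#       img_folder="/path/to/JPEGImages",
#       gt_folder="/path/to/Annotations",
#   )
#   dataset = VOSDataset(
#       transforms=[],
#       training=True,
#       video_dataset=raw_dataset,
#       sampler=RandomUniformSampler(num_frames=8, max_num_objects=3),
#       multiplier=1,
#   )
#   datapoint = dataset[0]  # a VideoDatapoint with 8 frames and their objects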