| |
| |
|
|
| |
| |
|
|
| import random |
| from dataclasses import dataclass |
| from typing import List |
|
|
| from training.dataset.vos_segment_loader import LazySegments |
|
|
# Upper bound on the number of sampling attempts RandomUniformSampler.sample
# makes while looking for a first frame with at least one visible object.
MAX_RETRIES = 1000
|
|
|
|
@dataclass
class SampledFramesAndObjects:
    """Result of a sampler's ``sample`` call: the chosen frames and object ids."""

    # Frames selected from the video, in the order the sampler produced them.
    # NOTE(review): the samplers in this file store the elements of
    # ``video.frames`` here (objects exposing ``frame_idx``), not plain ints --
    # confirm whether the annotation should be updated.
    frames: List[int]
    # Ids of the objects selected for the sampled frames.
    object_ids: List[int]
|
|
|
|
class VOSSampler:
    """Base class for samplers that select frames and objects from a video.

    Subclasses must override :meth:`sample`.
    """

    def __init__(self, sort_frames=True):
        """Store the sampler configuration.

        Args:
            sort_frames: Flag kept on the instance for subclasses to consult.
                ``EvalSampler`` sorts frames by ``frame_idx`` when it is true.
        """
        self.sort_frames = sort_frames

    def sample(self, video, segment_loader=None, epoch=None):
        """Select frames and objects from ``video``.

        The signature mirrors the one used by the concrete samplers in this
        file; the extra parameters are optional so existing one-argument
        calls keep working.

        Args:
            video: Video to sample from.
            segment_loader: Loader used by subclasses to fetch per-frame
                segments.
            epoch: Optional epoch value accepted by subclasses.

        Raises:
            NotImplementedError: Always; subclasses provide the behavior.
        """
        raise NotImplementedError()
|
|
|
|
class RandomUniformSampler(VOSSampler):
    """Sample a random run of consecutive frames and a random subset of objects.

    The run of ``num_frames`` frames starts at a uniformly random position in
    ``video.frames`` and may be reversed in time. Objects are drawn from those
    considered visible in the first sampled frame.
    """

    def __init__(
        self,
        num_frames,
        max_num_objects,
        reverse_time_prob=0.0,
    ):
        """Configure the sampler.

        Args:
            num_frames: Number of consecutive frames to return per sample.
            max_num_objects: Upper bound on the number of object ids returned.
            reverse_time_prob: Probability of reversing the sampled frame order.
        """
        # Initialise the base class so inherited state (``sort_frames``) is set,
        # as EvalSampler already does.
        super().__init__()
        self.num_frames = num_frames
        self.max_num_objects = max_num_objects
        self.reverse_time_prob = reverse_time_prob

    def sample(self, video, segment_loader, epoch=None):
        """Pick ``num_frames`` consecutive frames and up to ``max_num_objects`` objects.

        Args:
            video: Object exposing ``frames`` (indexable, each element having a
                ``frame_idx``) and ``video_name``.
            segment_loader: Object whose ``load(frame_idx)`` returns a mapping
                from object id to segment (or a ``LazySegments``).
            epoch: Accepted for interface compatibility; not used here.

        Returns:
            SampledFramesAndObjects with the sampled frames and object ids.

        Raises:
            Exception: If the video has fewer than ``num_frames`` frames, or if
                no attempt out of ``MAX_RETRIES`` yields a first frame with a
                visible object.
        """
        # The frame count does not change between attempts, so validate once.
        num_available = len(video.frames)
        if num_available < self.num_frames:
            raise Exception(
                f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
            )

        for _ in range(MAX_RETRIES):
            start = random.randrange(0, num_available - self.num_frames + 1)
            frames = [video.frames[start + step] for step in range(self.num_frames)]
            if random.uniform(0, 1) < self.reverse_time_prob:
                # Play the clip backwards.
                frames = frames[::-1]

            # Load the first frame's segments once and reuse the result.
            loaded_segms = segment_loader.load(frames[0].frame_idx)
            if isinstance(loaded_segms, LazySegments):
                # LazySegments keys are taken as-is; no per-mask emptiness
                # check is performed for them.
                visible_object_ids = list(loaded_segms.keys())
            else:
                # An object counts as visible when its segment has a non-zero sum.
                visible_object_ids = [
                    object_id
                    for object_id, segment in loaded_segms.items()
                    if segment.sum()
                ]

            if visible_object_ids:
                break
        else:
            # Every attempt produced a first frame without visible objects.
            raise Exception("No visible objects")

        object_ids = random.sample(
            visible_object_ids,
            min(len(visible_object_ids), self.max_num_objects),
        )
        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|
|
|
|
class EvalSampler(VOSSampler):
    """
    VOS Sampler for evaluation: sampling all the frames and all the objects in a video
    """

    def __init__(
        self,
    ):
        super().__init__()

    def sample(self, video, segment_loader, epoch=None):
        """
        Sampling all the frames and all the objects

        Args:
            video: Object exposing ``frames``, each element having a ``frame_idx``.
            segment_loader: Object whose ``load(frame_idx)`` returns a mapping
                keyed by object id.
            epoch: Accepted for interface compatibility; not used here.

        Returns:
            SampledFramesAndObjects holding every frame of the video and the
            object ids present in the first frame's loaded segments.

        Raises:
            Exception: If the first frame has no objects.
        """
        if self.sort_frames:
            # Order frames by their frame index.
            frames = sorted(video.frames, key=lambda x: x.frame_idx)
        else:
            # Keep the frames in the order the video provides them.
            frames = video.frames

        # Materialise the keys view into a list so the result matches the
        # ``List[int]`` field and what RandomUniformSampler returns.
        object_ids = list(segment_loader.load(frames[0].frame_idx).keys())
        if not object_ids:
            raise Exception("First frame of the video has no objects")

        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
|
|