File size: 7,873 Bytes
f71ac1d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 | """Reference View Sampling.
These classes sample reference views from a dataset that contains videos.
They are typically used when a model needs multiple samples of the same
video during training.
"""
from __future__ import annotations
from abc import abstractmethod
from typing import Callable, List
import numpy as np
from torch.utils.data import Dataset
from .const import CommonKeys as K
from .datasets.base import VideoDataset
from .typing import DictData
# Signature of a view-sorting function: it receives the key view and the list
# of reference views and returns all views in the order they should be emitted.
SortingFunc = Callable[[DictData, list[DictData]], List[DictData]]
def sort_key_first(
    cur_sample: DictData, ref_data: list[DictData]
) -> list[DictData]:
    """Order the views so that the key view precedes all reference views.

    Args:
        cur_sample (DictData): The key view.
        ref_data (list[DictData]): The reference views, kept in the order
            given.

    Returns:
        list[DictData]: A new list holding the key view followed by the
            reference views. The input list is not modified.
    """
    ordered_views = [cur_sample]
    ordered_views.extend(ref_data)
    return ordered_views
def sort_temporal(
    cur_sample: DictData, ref_data: list[DictData]
) -> list[DictData]:
    """Order the key view and reference views by their frame id.

    Args:
        cur_sample (DictData): The key view.
        ref_data (list[DictData]): The reference views.

    Returns:
        list[DictData]: A new list with all views in ascending order of the
            value stored under ``K.frame_ids``. The sort is stable, so views
            with equal frame ids keep the key view first, followed by the
            reference views in their given order.
    """
    all_views = [cur_sample, *ref_data]
    all_views.sort(key=lambda view: view[K.frame_ids])
    return all_views
class ReferenceViewSampler:
    """Common interface for samplers that pick reference views for a key view.

    Subclasses implement ``__call__``. Note that this class does not derive
    from ``abc.ABC``, so the ``abstractmethod`` marker is not enforced at
    instantiation time; invoking the base implementation raises
    ``NotImplementedError`` instead.
    """

    def __init__(self, num_ref_samples: int) -> None:
        """Store the number of reference views to draw per key view.

        Args:
            num_ref_samples (int): Number of reference views to sample.
        """
        self.num_ref_samples = num_ref_samples

    @abstractmethod
    def __call__(
        self,
        key_dataset_index: int,
        indices_in_video: list[int],
        frame_ids: list[int],
    ) -> list[int]:
        """Pick ``num_ref_samples`` reference views for the given key view.

        Args:
            key_dataset_index (int): Dataset index of the key view.
            indices_in_video (list[int]): All dataset indices in the video.
            frame_ids (list[int]): Frame ids of all views in the video.

        Returns:
            list[int]: Dataset indices of the reference views.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
class SequentialViewSampler(ReferenceViewSampler):
    """Sampler that returns the views adjacent to the key view."""

    def __call__(
        self,
        key_dataset_index: int,
        indices_in_video: list[int],
        frame_ids: list[int],
    ) -> list[int]:
        """Pick the views that directly follow the key view in the video.

        If fewer than ``num_ref_samples`` views follow the key view, the
        shortfall is filled with the views immediately preceding it.

        Args:
            key_dataset_index (int): Dataset index of the key view.
            indices_in_video (list[int]): All dataset indices in the video.
            frame_ids (list[int]): Frame ids of all views in the video.

        Returns:
            list[int]: Dataset indices of the reference views, in video
                order.
        """
        assert len(frame_ids) >= self.num_ref_samples + 1

        key_pos = indices_in_video.index(key_dataset_index)
        num_views = len(indices_in_video)
        window_end = key_pos + 1 + self.num_ref_samples

        # Enough views after the key view: take the next ones.
        if window_end <= num_views:
            return indices_in_video[key_pos + 1 : window_end]

        # Otherwise borrow the missing views from right before the key view.
        shortfall = window_end - num_views
        preceding = indices_in_video[key_pos - shortfall : key_pos]
        following = indices_in_video[key_pos + 1 :]
        return preceding + following
class UniformViewSampler(ReferenceViewSampler):
    """Sampler that draws reference views uniformly at random near the key."""

    def __init__(self, scope: int, num_ref_samples: int) -> None:
        """Store the sampling scope and the number of reference views.

        Args:
            scope (int): Maximum frame id distance to the key view within
                which reference views may be drawn. With a scope of 0 the
                key view itself is returned as every reference view.
            num_ref_samples (int): Number of reference views to sample.

        Raises:
            ValueError: If scope is non-zero and smaller than
                ``num_ref_samples // 2``.
        """
        super().__init__(num_ref_samples)
        scope_too_small = scope != 0 and scope < num_ref_samples // 2
        if scope_too_small:
            raise ValueError("Scope must be higher than num_ref_imgs / 2.")
        self.scope = scope

    def _get_valid_indices(
        self, key_index: int, indices_in_video: list[int], frame_ids: list[int]
    ) -> list[int]:
        """Collect the dataset indices that lie within scope of the key view.

        Args:
            key_index (int): Position of the key view inside the video.
            indices_in_video (list[int]): All dataset indices in the video.
            frame_ids (list[int]): Frame ids of all views in the video. The
                last entry is used as the upper bound of the frame id range.

        Returns:
            list[int]: Dataset indices of all views other than the key view
                whose frame id is within the allowed range.
        """
        key_fid = frame_ids[key_index]
        lowest_fid = max(0, key_fid - self.scope)
        highest_fid = min(key_fid + self.scope, frame_ids[-1])

        candidates = []
        for pos, dataset_index in enumerate(indices_in_video):
            if pos == key_index:
                # The key view is never its own candidate.
                continue
            if lowest_fid <= frame_ids[pos] <= highest_fid:
                candidates.append(dataset_index)
        return candidates

    def __call__(
        self,
        key_dataset_index: int,
        indices_in_video: list[int],
        frame_ids: list[int],
    ) -> list[int]:
        """Draw reference views uniformly, without replacement, within scope.

        Falls back to repeating the key view when the scope is not positive
        or when no other view lies within scope.

        Args:
            key_dataset_index (int): Dataset index of the key view.
            indices_in_video (list[int]): All dataset indices in the video.
            frame_ids (list[int]): Frame ids of all views in the video.

        Returns:
            list[int]: Dataset indices of the reference views.
        """
        if self.scope <= 0:
            return [key_dataset_index] * self.num_ref_samples

        key_index = indices_in_video.index(key_dataset_index)
        candidates = self._get_valid_indices(
            key_index, indices_in_video, frame_ids
        )
        if not candidates:
            return [key_dataset_index] * self.num_ref_samples

        assert len(candidates) >= self.num_ref_samples
        chosen = np.random.choice(
            candidates, self.num_ref_samples, replace=False
        )
        return chosen.tolist()
class MultiViewDataset(Dataset[list[DictData]]):
    """Video dataset wrapper that yields a key view with reference views."""

    def __init__(
        self,
        dataset: VideoDataset,
        sampler: ReferenceViewSampler,
        sort_fn: SortingFunc = sort_key_first,
        num_retry: int = 3,
        match_key: str = K.boxes2d_track_ids,
        skip_nomatch_samples: bool = False,
    ) -> None:
        """Store the wrapped dataset and the sampling configuration.

        Args:
            dataset (VideoDataset): Video dataset to sample from.
            sampler (ReferenceViewSampler): Sampler that picks the reference
                views for a key view.
            sort_fn (SortingFunc, optional): Function that orders the key
                and reference views. Defaults to sort_key_first.
            num_retry (int, optional): Number of sampling attempts before
                falling back to copies of the key view. Defaults to 3.
            match_key (str, optional): Key of the entry compared between key
                and reference views. Defaults to K.boxes2d_track_ids.
            skip_nomatch_samples (bool, optional): Whether to reject sampled
                reference views that share no match with the key view.
                Defaults to False.
        """
        self.dataset = dataset
        self.sampler = sampler
        self.sort_fn = sort_fn
        self.num_retry = num_retry
        self.match_key = match_key
        self.skip_nomatch_samples = skip_nomatch_samples

    def has_matches(
        self, key_data: DictData, ref_data: list[DictData]
    ) -> bool:
        """Tell whether any reference view shares a match with the key view.

        The ``match_key`` entries are compared pairwise via broadcasting
        (presumably 1-D arrays of ids -- TODO confirm).

        Args:
            key_data (DictData): The key view.
            ref_data (list[DictData]): The reference views.

        Returns:
            bool: True if at least one value of the key view also occurs in
                at least one reference view, False otherwise.
        """
        key_target = key_data[self.match_key]
        return any(
            np.equal(
                np.expand_dims(key_target, axis=1),
                ref_view[self.match_key][None],
            ).any()
            for ref_view in ref_data
        )

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def get_ref_data(self, ref_indices: list[int]) -> list[DictData]:
        """Load the reference views and flag them as non-keyframes.

        Args:
            ref_indices (list[int]): Dataset indices of the reference views.

        Returns:
            list[DictData]: The loaded views, each with ``"keyframes"`` set
                to False.
        """

        def _load_reference(dataset_index: int) -> DictData:
            view = self.dataset[dataset_index]
            view["keyframes"] = False
            return view

        ref_views = [_load_reference(idx) for idx in ref_indices]
        assert self.sampler.num_ref_samples == len(ref_views)
        return ref_views

    def __getitem__(self, index: int) -> list[DictData]:
        """Load the key view at ``index`` together with its reference views.

        Args:
            index (int): Dataset index of the key view.

        Returns:
            list[DictData]: Only the key view if the sampler draws no
                reference views. Otherwise the key and reference views as
                ordered by ``sort_fn``, or, when every sampling attempt was
                rejected, the key view followed by copies of itself loaded
                as reference views.
        """
        key_view = self.dataset[index]
        key_view["keyframes"] = True

        sequence_name = key_view[K.sequence_names]
        indices_in_video = self.dataset.video_mapping["video_to_indices"][
            sequence_name
        ]
        frame_ids = self.dataset.video_mapping["video_to_frame_ids"][
            sequence_name
        ]

        if self.sampler.num_ref_samples <= 0:
            return [key_view]

        for _ in range(self.num_retry):
            sampled = self.sampler(index, indices_in_video, frame_ids)
            ref_views = self.get_ref_data(sampled)
            # Accept unless matching is required and none was found.
            if not self.skip_nomatch_samples or self.has_matches(
                key_view, ref_views
            ):
                return self.sort_fn(key_view, ref_views)

        # All attempts rejected: use the key view itself as every reference.
        fallback_views = self.get_ref_data(
            [index] * self.sampler.num_ref_samples
        )
        return [key_view, *fallback_views]
|