|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from abc import ABC |
|
|
|
|
|
from hydra.utils import instantiate |
|
|
import torch |
|
|
import random |
|
|
import numpy as np |
|
|
from torch.utils.data import Dataset |
|
|
from torch.utils.data import ConcatDataset |
|
|
import bisect |
|
|
from .dataset_util import * |
|
|
from .track_util import * |
|
|
from .augmentation import get_image_augmentation |
|
|
|
|
|
|
|
|
class ComposedDataset(Dataset, ABC):
    """
    Composes multiple base datasets and applies common configurations.

    This dataset provides a flexible way to combine multiple base datasets while
    applying shared augmentations, track generation, and other processing steps.
    It handles image normalization, tensor conversion, and other preparations
    needed for training computer vision models with sequences of images.
    """

    def __init__(self, dataset_configs: dict, common_config: dict, **kwargs):
        """
        Initializes the ComposedDataset.

        Args:
            dataset_configs (dict): List of Hydra configurations for base datasets.
            common_config (dict): Shared configurations (augs, tracks, mode, etc.).
            **kwargs: Additional arguments (unused).
        """
        base_dataset_list = [
            instantiate(baseset_dict, common_conf=common_config)
            for baseset_dict in dataset_configs
        ]

        # Concatenate all base datasets behind a single tuple-aware index.
        self.base_dataset = TupleConcatDataset(base_dataset_list, common_config)

        # Photometric augmentation settings. When the cojitter branch triggers,
        # one augmentation call covers the whole image stack (shared parameters);
        # otherwise each frame is augmented independently.
        self.cojitter = common_config.augs.cojitter
        self.cojitter_ratio = common_config.augs.cojitter_ratio
        self.image_aug = get_image_augmentation(
            color_jitter=common_config.augs.color_jitter,
            gray_scale=common_config.augs.gray_scale,
            gau_blur=common_config.augs.gau_blur,
        )

        # When fixed_num_images > 0, every sample is forced to this image count
        # and aspect ratio regardless of the requested index tuple.
        self.fixed_num_images = common_config.fix_img_num
        self.fixed_aspect_ratio = common_config.fix_aspect_ratio

        # Track supervision: load precomputed tracks when available, otherwise
        # build them from depth; track_num is the per-sample track budget.
        self.load_track = common_config.load_track
        self.track_num = common_config.track_num

        self.training = common_config.training
        self.common_config = common_config

        self.total_samples = len(self.base_dataset)

    def __len__(self):
        """Returns the total number of sequences in the dataset."""
        return self.total_samples

    def _augment_images(self, images):
        """Apply photometric augmentation to an image stack (training only).

        Args:
            images (torch.Tensor): Float image stack of shape (N, C, H, W).

        Returns:
            torch.Tensor: Augmented (or untouched) image stack.
        """
        if not self.training or self.image_aug is None:
            return images
        # NOTE(review): `random.random() > cojitter_ratio` means a HIGHER ratio
        # makes joint augmentation LESS likely — looks inverted relative to the
        # name "ratio"; preserved as-is, confirm intent with the config author.
        if self.cojitter and random.random() > self.cojitter_ratio:
            # Joint augmentation: one parameter draw shared by every frame.
            images = self.image_aug(images)
        else:
            # Per-frame augmentation: independent parameters per image.
            for aug_img_idx in range(len(images)):
                images[aug_img_idx] = self.image_aug(images[aug_img_idx])
        return images

    def _sample_loaded_tracks(self, batch):
        """Sample `self.track_num` tracks from the batch's precomputed tracks.

        Anchors validity on the first frame's visibility mask, then samples
        without replacement when enough tracks are visible, with replacement
        otherwise.

        Args:
            batch (dict): Raw batch containing "tracks" and "track_masks".

        Returns:
            tuple or None: (tracks, track_vis_mask, track_positive_mask), or
            None when no track is visible in the first frame (caller should
            fall back to building tracks from depth).
        """
        tracks = torch.from_numpy(np.stack(batch["tracks"]).astype(np.float32))
        track_vis_mask = torch.from_numpy(np.stack(batch["track_masks"]).astype(bool))

        valid_indices = torch.where(track_vis_mask[0])[0]
        if len(valid_indices) == 0:
            # BUGFIX: the original code fell into torch.randint(0, 0, ...) here,
            # which raises a RuntimeError on the empty range.
            return None

        if len(valid_indices) >= self.track_num:
            # Enough visible tracks: sample without replacement.
            sampled_indices = valid_indices[torch.randperm(len(valid_indices))][:self.track_num]
        else:
            # Too few visible tracks: sample with replacement to fill the budget.
            sampled_indices = valid_indices[torch.randint(0, len(valid_indices),
                                                          (self.track_num,),
                                                          dtype=torch.int64,
                                                          device=valid_indices.device)]

        tracks = tracks[:, sampled_indices, :]
        track_vis_mask = track_vis_mask[:, sampled_indices]
        # All precomputed tracks count as positives.
        track_positive_mask = torch.ones(track_vis_mask.shape[1]).bool()
        return tracks, track_vis_mask, track_positive_mask

    def __getitem__(self, idx_tuple):
        """
        Retrieves a data sample (sequence) from the dataset.

        Loads raw data, converts to PyTorch tensors, applies augmentations,
        and prepares tracks if enabled.

        Args:
            idx_tuple (tuple): a tuple of (seq_idx, num_images, aspect_ratio)

        Returns:
            dict: A dictionary containing the sequence data (images, poses, tracks, etc.).
        """
        # Optionally override the requested image count / aspect ratio.
        if self.fixed_num_images > 0:
            seq_idx = idx_tuple[0] if isinstance(idx_tuple, tuple) else idx_tuple
            idx_tuple = (seq_idx, self.fixed_num_images, self.fixed_aspect_ratio)

        batch = self.base_dataset[idx_tuple]
        seq_name = batch["seq_name"]

        # Images: stack to (N, H, W, C) float32, reorder to (N, C, H, W),
        # and scale from [0, 255] to [0, 1].
        images = torch.from_numpy(np.stack(batch["images"]).astype(np.float32)).contiguous()
        images = images.permute(0, 3, 1, 2).to(torch.get_default_dtype()).div(255)

        # Geometry and supervision tensors.
        depths = torch.from_numpy(np.stack(batch["depths"]).astype(np.float32))
        extrinsics = torch.from_numpy(np.stack(batch["extrinsics"]).astype(np.float32))
        intrinsics = torch.from_numpy(np.stack(batch["intrinsics"]).astype(np.float32))
        cam_points = torch.from_numpy(np.stack(batch["cam_points"]).astype(np.float32))
        world_points = torch.from_numpy(np.stack(batch["world_points"]).astype(np.float32))
        point_masks = torch.from_numpy(np.stack(batch["point_masks"]))
        ids = torch.from_numpy(batch["ids"])

        images = self._augment_images(images)

        sample = {
            "seq_name": seq_name,
            "ids": ids,
            "images": images,
            "depths": depths,
            "extrinsics": extrinsics,
            "intrinsics": intrinsics,
            "cam_points": cam_points,
            "world_points": world_points,
            "point_masks": point_masks,
        }

        if self.load_track:
            track_data = None
            if batch["tracks"] is not None:
                # Prefer precomputed tracks when the batch provides them.
                track_data = self._sample_loaded_tracks(batch)
            if track_data is None:
                # No precomputed tracks (or none visible in the first frame):
                # derive tracks geometrically from depth and camera parameters.
                track_data = build_tracks_by_depth(
                    extrinsics, intrinsics, world_points, depths, point_masks, images,
                    target_track_num=self.track_num, seq_name=seq_name
                )

            sample["tracks"], sample["track_vis_mask"], sample["track_positive_mask"] = track_data

        return sample
|
|
|
|
|
|
|
|
class TupleConcatDataset(ConcatDataset):
    """
    A custom ConcatDataset that supports indexing with a tuple.

    Standard PyTorch ConcatDataset only accepts an integer index. This class extends
    that functionality to allow passing a tuple like (sample_idx, num_images, aspect_ratio),
    where the first element is used to determine which sample to fetch, and the full
    tuple is passed down to the selected dataset's __getitem__ method.

    It also supports an option to randomly sample across all datasets, ignoring the
    provided index. This is useful during training when shuffling the entire dataset
    might cause memory issues due to duplicating dictionaries. If doing this, you can
    set pytorch's dataloader shuffle to False.
    """

    def __init__(self, datasets, common_config):
        """
        Initialize the TupleConcatDataset.

        Args:
            datasets (iterable): An iterable of PyTorch Dataset objects to concatenate.
            common_config (dict): Common configuration dict, used to check for random sampling.
        """
        super().__init__(datasets)
        # When True, __getitem__ ignores the requested index and samples
        # uniformly over the concatenated length.
        self.inside_random = common_config.inside_random

    def __getitem__(self, idx):
        """
        Retrieves an item using either an integer index or a tuple index.

        Args:
            idx (int or tuple): The index. If tuple, the first element is the sequence
                                index across the concatenated datasets, and the rest are
                                passed down. If int, it's treated as the sequence index.

        Returns:
            The item returned by the underlying dataset's __getitem__ method.

        Raises:
            ValueError: If the index is out of range or the tuple doesn't have exactly 3 elements.
        """
        idx_tuple = None
        if isinstance(idx, tuple):
            idx_tuple = idx
            idx = idx_tuple[0]

        # Optionally ignore the caller's index and sample uniformly.
        if self.inside_random:
            total_len = self.cumulative_sizes[-1]
            idx = random.randint(0, total_len - 1)

        # Normalize negative indices (mirrors ConcatDataset semantics).
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx

        # Locate the sub-dataset and the local index within it.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]

        # BUGFIX: the original raised TypeError (len(None)) on integer indices,
        # despite the docstring promising plain-int support. Forward the local
        # integer index directly when no tuple was supplied.
        if idx_tuple is None:
            return self.datasets[dataset_idx][sample_idx]

        if len(idx_tuple) != 3:
            raise ValueError("Tuple index must have exactly three elements")

        # Replace the global sequence index with the dataset-local one.
        idx_tuple = (sample_idx,) + idx_tuple[1:]
        return self.datasets[dataset_idx][idx_tuple]
|
|
|