# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from PIL import Image, ImageFile

from torch.utils.data import Dataset
from .dataset_util import *

# Lift PIL's pixel-count limit and let it load truncated image files.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True


class BaseDataset(Dataset):
    """
    Base dataset class for VGGT and VGGSfM training.

    This abstract class handles common operations like image resizing,
    augmentation, and coordinate transformations. Concrete dataset
    implementations should inherit from this class.

    Attributes:
        img_size: Target image size (typically the width)
        patch_size: Size of patches for vit
        aug_scale: Scale range for data augmentation [min, max]
            (read from ``common_conf.augs.scales``)
        rescale: Whether to rescale images
        rescale_aug: Whether to apply augmentation during rescaling
        landscape_check: Whether to handle landscape vs portrait orientation

    Note:
        This class reads two attributes that it never assigns itself:
        ``len_train`` (used by ``__len__``) and ``training`` (used by
        ``process_one_image``). Subclasses are expected to set both.
    """

    def __init__(
        self,
        common_conf,
    ):
        """
        Initialize the base dataset with common configuration.

        Args:
            common_conf: Configuration object with the following properties, shared by all datasets:
                - img_size: Default is 518
                - patch_size: Default is 14
                - augs.scales: Default is [0.8, 1.2]
                - rescale: Default is True
                - rescale_aug: Default is True
                - landscape_check: Default is True
        """
        super().__init__()
        self.img_size = common_conf.img_size
        self.patch_size = common_conf.patch_size
        self.aug_scale = common_conf.augs.scales
        self.rescale = common_conf.rescale
        self.rescale_aug = common_conf.rescale_aug
        self.landscape_check = common_conf.landscape_check

    def __len__(self):
        """Return ``self.len_train``, which must be set by the subclass."""
        return self.len_train

    def __getitem__(self, idx_N):
        """
        Get an item from the dataset.

        Args:
            idx_N: Tuple containing (seq_index, img_per_seq, aspect_ratio)

        Returns:
            Dataset item as returned by get_data()
        """
        seq_index, img_per_seq, aspect_ratio = idx_N
        return self.get_data(
            seq_index=seq_index, img_per_seq=img_per_seq, aspect_ratio=aspect_ratio
        )

    def get_data(
        self,
        seq_index=None,
        seq_name=None,
        ids=None,
        aspect_ratio=1.0,
        img_per_seq=None,
    ):
        """
        Abstract method to retrieve data for a given sequence.

        The signature accepts every keyword that ``__getitem__`` forwards, so
        calling the un-overridden base method raises ``NotImplementedError``
        rather than a ``TypeError`` about an unexpected keyword.

        Args:
            seq_index (int, optional): Index of the sequence
            seq_name (str, optional): Name of the sequence
            ids (list, optional): List of frame IDs
            aspect_ratio (float, optional): Target aspect ratio.
            img_per_seq (int, optional): Second element of the index tuple
                received by ``__getitem__``; presumably the number of images
                to load for the sequence.

        Returns:
            Dataset-specific data

        Raises:
            NotImplementedError: This method must be implemented by subclasses
        """
        raise NotImplementedError(
            "This is an abstract method and should be implemented in the subclass, i.e., each dataset should implement its own get_data method."
        )

    def get_target_shape(self, aspect_ratio):
        """
        Calculate the target shape based on the given aspect ratio.

        Args:
            aspect_ratio: Target aspect ratio

        Returns:
            numpy.ndarray: Target image shape [height, width]
        """
        short_size = int(self.img_size * aspect_ratio)

        # Ensure the input shape is friendly to vision transformer: round the
        # short side down to a multiple of the patch size. The floor is a
        # no-op when the size is already a multiple.
        short_size = (short_size // self.patch_size) * self.patch_size

        return np.array([short_size, self.img_size])

    def process_one_image(
        self,
        image,
        depth_map,
        extri_opencv,
        intri_opencv,
        original_size,
        target_image_shape,
        track=None,
        filepath=None,
        safe_bound=4,
    ):
        """
        Process a single image and its associated data.

        This method handles image transformations, depth processing, and coordinate conversions.

        Reads ``self.training``, which is not assigned by this base class and
        must be provided by the subclass.

        Args:
            image (numpy.ndarray): Input image array
            depth_map (numpy.ndarray): Depth map array
            extri_opencv (numpy.ndarray): Extrinsic camera matrix (OpenCV convention)
            intri_opencv (numpy.ndarray): Intrinsic camera matrix (OpenCV convention)
            original_size (numpy.ndarray): Original image size [height, width]
            target_image_shape (numpy.ndarray): Target image shape after processing
            track (numpy.ndarray, optional): Optional tracking information. Defaults to None.
            filepath (str, optional): Optional file path for debugging. Defaults to None.
            safe_bound (int, optional): Safety margin for cropping operations. Defaults to 4.

        Returns:
            tuple: (
                image (numpy.ndarray): Processed image,
                depth_map (numpy.ndarray): Processed depth map,
                extri_opencv (numpy.ndarray): Updated extrinsic matrix,
                intri_opencv (numpy.ndarray): Updated intrinsic matrix,
                world_coords_points (numpy.ndarray): 3D points in world coordinates,
                cam_coords_points (numpy.ndarray): 3D points in camera coordinates,
                point_mask (numpy.ndarray): Boolean mask of valid points,
                track (numpy.ndarray, optional): Updated tracking information
            )
        """
        # Make copies to avoid in-place operations affecting original data
        image = np.copy(image)
        depth_map = np.copy(depth_map)
        extri_opencv = np.copy(extri_opencv)
        intri_opencv = np.copy(intri_opencv)
        if track is not None:
            track = np.copy(track)

        # Apply random scale augmentation during training if enabled
        if self.training and self.aug_scale:
            random_h_scale, random_w_scale = np.random.uniform(
                self.aug_scale[0], self.aug_scale[1], 2
            )
            # Avoid random padding by capping at 1.0
            random_h_scale = min(random_h_scale, 1.0)
            random_w_scale = min(random_w_scale, 1.0)
            aug_size = original_size * np.array([random_h_scale, random_w_scale])
            aug_size = aug_size.astype(np.int32)
        else:
            aug_size = original_size

        # Move principal point to the image center and crop if necessary
        image, depth_map, intri_opencv, track = crop_image_depth_and_intrinsic_by_pp(
            image,
            depth_map,
            intri_opencv,
            aug_size,
            track=track,
            filepath=filepath,
        )

        original_size = np.array(image.shape[:2])  # update original_size
        target_shape = target_image_shape

        # Handle landscape vs. portrait orientation.
        # NOTE(review): despite its name, `rotate_to_portrait` is set when the
        # *input* is taller than wide (height > 1.25 * width). The target shape
        # is transposed for the resize/crop below, and the final 90-degree
        # rotation presumably brings the result back to `target_image_shape`
        # -- confirm against rotate_90_degrees.
        rotate_to_portrait = False
        if self.landscape_check:
            # Switch between landscape and portrait if necessary
            if original_size[0] > 1.25 * original_size[1]:
                if (target_image_shape[0] != target_image_shape[1]) and (
                    np.random.rand() > 0.5
                ):
                    target_shape = np.array(
                        [target_image_shape[1], target_image_shape[0]]
                    )
                    rotate_to_portrait = True

        # Resize images and update intrinsics
        if self.rescale:
            image, depth_map, intri_opencv, track = resize_image_depth_and_intrinsic(
                image,
                depth_map,
                intri_opencv,
                target_shape,
                original_size,
                track=track,
                safe_bound=safe_bound,
                rescale_aug=self.rescale_aug,
            )
        else:
            print("Not rescaling the images")

        # Ensure final crop to target shape
        image, depth_map, intri_opencv, track = crop_image_depth_and_intrinsic_by_pp(
            image,
            depth_map,
            intri_opencv,
            target_shape,
            track=track,
            filepath=filepath,
            strict=True,
        )

        # Apply 90-degree rotation if needed
        if rotate_to_portrait:
            assert self.landscape_check
            # Rotation direction is chosen at random with equal probability.
            clockwise = np.random.rand() > 0.5
            image, depth_map, extri_opencv, intri_opencv, track = rotate_90_degrees(
                image,
                depth_map,
                extri_opencv,
                intri_opencv,
                clockwise=clockwise,
                track=track,
            )

        # Convert depth to world and camera coordinates
        world_coords_points, cam_coords_points, point_mask = (
            depth_to_world_coords_points(depth_map, extri_opencv, intri_opencv)
        )

        return (
            image,
            depth_map,
            extri_opencv,
            intri_opencv,
            world_coords_points,
            cam_coords_points,
            point_mask,
            track,
        )

    def get_nearby_ids(self, ids, full_seq_num, expand_ratio=None, expand_range=None):
        """
        TODO: add the function to sample the ids by pose similarity ranking.

        Sample a set of IDs from a sequence close to a given start index.

        You can specify the range either as a ratio of the number of input IDs
        or as a fixed integer window. The sampling window is the half-open
        interval ``[start - expand_range, start + expand_range)`` clipped to
        ``[0, full_seq_num)``; sampling is done with replacement.

        Args:
            ids (list): Initial list of IDs. The first element is used as the anchor.
            full_seq_num (int): Total number of items in the full sequence.
            expand_ratio (float, optional): Factor by which the number of IDs expands
                around the start index. Default is 2.0 if neither expand_ratio nor
                expand_range is provided.
            expand_range (int, optional): Fixed number of items to expand around the
                start index. If provided, expand_ratio is ignored.

        Returns:
            numpy.ndarray: Array of sampled IDs, with the first element being the
                original start index.

        Examples:
            # Using expand_ratio (default behavior)
            # If ids=[100,101,102] and full_seq_num=200, with expand_ratio=2.0,
            # expand_range = int(3 * 2.0) = 6, so IDs are sampled from 94..105
            # (upper end 106 excluded, if boundaries allow).

            # Using expand_range directly
            # If ids=[100,101,102] and full_seq_num=200, with expand_range=10,
            # IDs are sampled from 90..109 (upper end 110 excluded, if boundaries allow).

        Raises:
            ValueError: If no IDs are provided, or if more than one ID is requested
                but the sampling window is empty.
        """
        if len(ids) == 0:
            raise ValueError("No IDs provided.")

        total_ids = len(ids)
        start_idx = ids[0]

        # Determine the actual expand_range; an explicit range takes precedence
        # over the ratio, and the ratio defaults to 2.0.
        if expand_range is None:
            if expand_ratio is None:
                expand_ratio = 2.0  # Default behavior
            expand_range = int(total_ids * expand_ratio)

        # Calculate valid boundaries (upper bound is exclusive)
        low_bound = max(0, start_idx - expand_range)
        high_bound = min(full_seq_num, start_idx + expand_range)

        # Create the valid range of indices
        valid_range = np.arange(low_bound, high_bound)

        # Sample 'total_ids - 1' items, because we already have the start_idx
        num_extra = total_ids - 1
        if num_extra > 0 and valid_range.size == 0:
            # np.random.choice would raise a ValueError here anyway; raise the
            # same exception type with a message that names the actual cause.
            raise ValueError(
                f"Empty sampling window [{low_bound}, {high_bound}) around start "
                f"index {start_idx}; check expand_range/expand_ratio and "
                f"full_seq_num={full_seq_num}."
            )

        sampled_ids = np.random.choice(
            valid_range,
            size=num_extra,
            replace=True,  # we accept the situation that some sampled ids are the same
        )

        # Insert the start_idx at the beginning
        result_ids = np.insert(sampled_ids, 0, start_idx)

        return result_ids