|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image, ImageFile |
|
|
|
|
|
from torch.utils.data import Dataset |
|
|
from .dataset_util import * |
|
|
|
|
|
# Process-wide Pillow settings applied as an import side effect of this module.
# Setting MAX_IMAGE_PIXELS to None removes Pillow's pixel-count limit so very
# large images can be opened.
# NOTE(review): this also turns off Pillow's decompression-bomb guard; confirm
# every dataset read through this module comes from a trusted source.
Image.MAX_IMAGE_PIXELS = None
# Allow Pillow to load image files whose data is truncated instead of failing.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
|
|
|
|
class BaseDataset(Dataset):
    """
    Base dataset class for VGGT and VGGSfM training.

    This abstract class handles common operations like image resizing,
    augmentation, and coordinate transformations. Concrete dataset
    implementations should inherit from this class and implement
    ``get_data``.

    Attributes:
        img_size: Target image size (typically the width)
        patch_size: Size of patches for vit
        aug_scale: Scale range for data augmentation [min, max]
            (read from ``common_conf.augs.scales``)
        rescale: Whether to rescale images
        rescale_aug: Whether to apply augmentation during rescaling
        landscape_check: Whether to handle landscape vs portrait orientation
        training: Whether random scale augmentation is applied in
            ``process_one_image``. Defaults to False; subclasses are
            expected to set it.
        len_train: Value returned by ``__len__``. It is NOT set by this base
            class; subclasses must assign it.
    """

    # Safe default so process_one_image() can run even when a subclass has not
    # assigned ``self.training``; an instance attribute set by a subclass
    # simply shadows this value.
    training = False

    def __init__(
        self,
        common_conf,
    ):
        """
        Initialize the base dataset with common configuration.

        Args:
            common_conf: Configuration object with the following properties, shared by all datasets:
                - img_size: Default is 518
                - patch_size: Default is 14
                - augs.scales: Default is [0.8, 1.2]
                - rescale: Default is True
                - rescale_aug: Default is True
                - landscape_check: Default is True
        """
        super().__init__()
        self.img_size = common_conf.img_size
        self.patch_size = common_conf.patch_size
        self.aug_scale = common_conf.augs.scales
        self.rescale = common_conf.rescale
        self.rescale_aug = common_conf.rescale_aug
        self.landscape_check = common_conf.landscape_check

    def __len__(self):
        """Return ``self.len_train``, which subclasses are responsible for setting."""
        return self.len_train

    def __getitem__(self, idx_N):
        """
        Get an item from the dataset.

        Args:
            idx_N: Tuple containing (seq_index, img_per_seq, aspect_ratio)

        Returns:
            Dataset item as returned by get_data()
        """
        seq_index, img_per_seq, aspect_ratio = idx_N
        return self.get_data(
            seq_index=seq_index, img_per_seq=img_per_seq, aspect_ratio=aspect_ratio
        )

    def get_data(
        self,
        seq_index=None,
        seq_name=None,
        ids=None,
        aspect_ratio=1.0,
        img_per_seq=None,
    ):
        """
        Abstract method to retrieve data for a given sequence.

        Args:
            seq_index (int, optional): Index of the sequence
            seq_name (str, optional): Name of the sequence
            ids (list, optional): List of frame IDs
            aspect_ratio (float, optional): Target aspect ratio.
            img_per_seq (int, optional): Number of images requested for the
                sequence. Accepted here because ``__getitem__`` always
                forwards it; without it the base class would fail with a
                TypeError instead of the intended NotImplementedError.

        Returns:
            Dataset-specific data

        Raises:
            NotImplementedError: This method must be implemented by subclasses
        """
        raise NotImplementedError(
            "This is an abstract method and should be implemented in the subclass, i.e., each dataset should implement its own get_data method."
        )

    def get_target_shape(self, aspect_ratio):
        """
        Calculate the target shape based on the given aspect ratio.

        The first dimension is ``img_size * aspect_ratio`` truncated to an
        integer and then rounded down to a multiple of ``patch_size``; the
        second dimension is always ``img_size``.

        Args:
            aspect_ratio: Target aspect ratio

        Returns:
            numpy.ndarray: Target image shape [height, width]
        """
        short_size = int(self.img_size * aspect_ratio)
        # Floor to a whole number of patches (a no-op when already aligned).
        short_size = (short_size // self.patch_size) * self.patch_size
        return np.array([short_size, self.img_size])

    def process_one_image(
        self,
        image,
        depth_map,
        extri_opencv,
        intri_opencv,
        original_size,
        target_image_shape,
        track=None,
        filepath=None,
        safe_bound=4,
    ):
        """
        Process a single image and its associated data.

        This method handles image transformations, depth processing, and coordinate conversions.
        All array inputs are copied first, so the caller's arrays are not modified.

        Args:
            image (numpy.ndarray): Input image array
            depth_map (numpy.ndarray): Depth map array
            extri_opencv (numpy.ndarray): Extrinsic camera matrix (OpenCV convention)
            intri_opencv (numpy.ndarray): Intrinsic camera matrix (OpenCV convention)
            original_size (numpy.ndarray): Original image size [height, width]
            target_image_shape (numpy.ndarray): Target image shape after processing
            track (numpy.ndarray, optional): Optional tracking information. Defaults to None.
            filepath (str, optional): Optional file path for debugging. Defaults to None.
            safe_bound (int, optional): Safety margin for cropping operations. Defaults to 4.

        Returns:
            tuple: (
                image (numpy.ndarray): Processed image,
                depth_map (numpy.ndarray): Processed depth map,
                extri_opencv (numpy.ndarray): Updated extrinsic matrix,
                intri_opencv (numpy.ndarray): Updated intrinsic matrix,
                world_coords_points (numpy.ndarray): 3D points in world coordinates,
                cam_coords_points (numpy.ndarray): 3D points in camera coordinates,
                point_mask (numpy.ndarray): Boolean mask of valid points,
                track (numpy.ndarray, optional): Updated tracking information
            )
        """
        # Work on private copies so the caller's data is never mutated.
        image = np.copy(image)
        depth_map = np.copy(depth_map)
        extri_opencv = np.copy(extri_opencv)
        intri_opencv = np.copy(intri_opencv)
        if track is not None:
            track = np.copy(track)

        # Random crop-size augmentation (training only). Each axis gets its
        # own factor, capped at 1.0 so the crop never exceeds the original.
        if self.training and self.aug_scale:
            scales = np.random.uniform(self.aug_scale[0], self.aug_scale[1], 2)
            scales = np.minimum(scales, 1.0)
            aug_size = (original_size * scales).astype(np.int32)
        else:
            aug_size = original_size

        # First crop to the (possibly augmented) size. The helper comes from
        # dataset_util; judging by its name it crops around the principal
        # point and updates the intrinsics accordingly.
        image, depth_map, intri_opencv, track = crop_image_depth_and_intrinsic_by_pp(
            image, depth_map, intri_opencv, aug_size, track=track, filepath=filepath,
        )

        # From here on "original" means the size after the first crop.
        original_size = np.array(image.shape[:2])
        target_shape = target_image_shape

        # When the image is clearly taller than wide (height > 1.25 * width)
        # and the target is not square, swap the target dimensions with 50%
        # probability and rotate by 90 degrees after the final crop. The
        # random draw is only made when all earlier conditions hold.
        rotate_to_portrait = False
        if (
            self.landscape_check
            and original_size[0] > 1.25 * original_size[1]
            and target_image_shape[0] != target_image_shape[1]
            and np.random.rand() > 0.5
        ):
            target_shape = np.array([target_image_shape[1], target_image_shape[0]])
            rotate_to_portrait = True

        if self.rescale:
            image, depth_map, intri_opencv, track = resize_image_depth_and_intrinsic(
                image,
                depth_map,
                intri_opencv,
                target_shape,
                original_size,
                track=track,
                safe_bound=safe_bound,
                rescale_aug=self.rescale_aug,
            )
        else:
            print("Not rescaling the images")

        # Final strict crop to exactly the target shape.
        image, depth_map, intri_opencv, track = crop_image_depth_and_intrinsic_by_pp(
            image, depth_map, intri_opencv, target_shape, track=track, filepath=filepath, strict=True,
        )

        if rotate_to_portrait:
            # Rotation direction is chosen at random.
            clockwise = np.random.rand() > 0.5
            image, depth_map, extri_opencv, intri_opencv, track = rotate_90_degrees(
                image,
                depth_map,
                extri_opencv,
                intri_opencv,
                clockwise=clockwise,
                track=track,
            )

        # Derive per-pixel 3D points from the final depth map and cameras.
        world_coords_points, cam_coords_points, point_mask = (
            depth_to_world_coords_points(depth_map, extri_opencv, intri_opencv)
        )

        return (
            image,
            depth_map,
            extri_opencv,
            intri_opencv,
            world_coords_points,
            cam_coords_points,
            point_mask,
            track,
        )

    def get_nearby_ids(self, ids, full_seq_num, expand_ratio=None, expand_range=None):
        """
        TODO: add the function to sample the ids by pose similarity ranking.

        Sample a set of IDs from a sequence close to a given start index.

        You can specify the range either as a ratio of the number of input IDs
        or as a fixed integer window. The window is symmetric around the start
        index, ``[start - expand_range, start + expand_range]`` (both ends
        inclusive), clamped to ``[0, full_seq_num - 1]``. Sampling is done with
        replacement, so the result may contain duplicates.

        Args:
            ids (list): Initial list of IDs. The first element is used as the anchor.
            full_seq_num (int): Total number of items in the full sequence.
            expand_ratio (float, optional): Factor by which the number of IDs expands
                around the start index. Default is 2.0 if neither expand_ratio nor
                expand_range is provided.
            expand_range (int, optional): Fixed number of items to expand around the
                start index. If provided, expand_ratio is ignored.

        Returns:
            numpy.ndarray: Array of sampled IDs with the same length as ``ids``,
                the first element being the original start index.

        Examples:
            # Using expand_ratio (default behavior)
            # If ids=[100,101,102] and full_seq_num=200, with expand_ratio=2.0,
            # expand_range = int(3 * 2.0) = 6, so IDs sampled from [94...106] (if boundaries allow).

            # Using expand_range directly
            # If ids=[100,101,102] and full_seq_num=200, with expand_range=10,
            # IDs are sampled from [90...110] (if boundaries allow).

        Raises:
            ValueError: If no IDs are provided, or if the clamped window is
                empty while more than one ID is requested.
        """
        if len(ids) == 0:
            raise ValueError("No IDs provided.")

        if expand_range is None and expand_ratio is None:
            expand_ratio = 2.0

        total_ids = len(ids)
        start_idx = ids[0]

        if expand_range is None:
            # Window half-width scales with the number of requested IDs.
            expand_range = int(total_ids * expand_ratio)

        low_bound = max(0, start_idx - expand_range)
        # +1 because np.arange excludes its stop value; without it the window
        # would be one element short on the upper side.
        high_bound = min(full_seq_num, start_idx + expand_range + 1)

        valid_range = np.arange(low_bound, high_bound)
        if valid_range.size == 0 and total_ids > 1:
            raise ValueError(
                f"No valid IDs to sample from: start index {start_idx} with "
                f"expand_range {expand_range} does not overlap [0, {full_seq_num})."
            )

        # The anchor occupies slot 0, so only total_ids - 1 IDs are drawn.
        sampled_ids = np.random.choice(
            valid_range,
            size=(total_ids - 1),
            replace=True,
        )

        return np.insert(sampled_ids, 0, start_idx)
|
|
|