# _vggt/training/data/dynamic_dataloader.py
# Uploaded by CgvKodai using huggingface_hub (commit 66003a2, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
import itertools
import random
from typing import Callable, Optional

from hydra.utils import instantiate
import numpy as np
from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset, Sampler

from .worker_fn import get_worker_init_fn
class DynamicTorchDataset(ABC):
    """
    Wrap a Hydra-configured dataset with dynamic per-batch sampling and build
    a DataLoader for each epoch.

    Every batch yielded by the loader shares one randomly drawn image count
    (from ``common_config.img_nums``) and one aspect ratio (from
    ``common_config.augs.aspects``). The batch size is chosen so that at most
    ``max_img_per_gpu`` images are requested per batch, with at least one
    sample per batch.
    """

    def __init__(
        self,
        dataset: dict,
        common_config: dict,
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool = True,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        persistent_workers: bool = False,
        seed: int = 42,
        max_img_per_gpu: int = 48,
    ) -> None:
        """
        Args:
            dataset: Hydra config for the dataset; instantiated with
                ``common_config`` passed through and ``_recursive_=False``.
            common_config: Shared config exposing ``augs.aspects`` ([min, max]
                aspect ratio) and ``img_nums`` ([min, max] images per sample).
            num_workers: Number of DataLoader worker processes.
            shuffle: Whether the distributed sampler shuffles indices.
            pin_memory: Forwarded to the DataLoader.
            drop_last: Stored on the instance only (see NOTE below).
            collate_fn: Forwarded to the DataLoader.
            worker_init_fn: Optional user hook, wrapped by ``get_worker_init_fn``.
            persistent_workers: Forwarded to the DataLoader.
            seed: Base seed for the samplers and the worker init function.
            max_img_per_gpu: Image budget used to derive each batch's size.

        Raises:
            ValueError: If either range is not a well-formed ``[min, max]`` pair.
        """
        self.dataset_config = dataset
        self.common_config = common_config
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        # NOTE(review): drop_last is stored but not forwarded to either sampler.
        # The DataLoader cannot take it either, because it is mutually exclusive
        # with batch_sampler.
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        self.persistent_workers = persistent_workers
        self.seed = seed
        self.max_img_per_gpu = max_img_per_gpu

        # Instantiate the dataset
        self.dataset = instantiate(dataset, common_config=common_config, _recursive_=False)

        # Extract aspect ratio and image number ranges from the configuration
        self.aspect_ratio_range = common_config.augs.aspects  # e.g., [0.5, 1.0]
        self.image_num_range = common_config.img_nums  # e.g., [2, 24]

        # Validate the aspect ratio and image number ranges
        if len(self.aspect_ratio_range) != 2 or self.aspect_ratio_range[0] > self.aspect_ratio_range[1]:
            raise ValueError(f"aspect_ratio_range must be [min, max] with min <= max, got {self.aspect_ratio_range}")
        if len(self.image_num_range) != 2 or self.image_num_range[0] < 1 or self.image_num_range[0] > self.image_num_range[1]:
            raise ValueError(f"image_num_range must be [min, max] with 1 <= min <= max, got {self.image_num_range}")

        # Create samplers
        self.sampler = DynamicDistributedSampler(self.dataset, seed=seed, shuffle=shuffle)
        self.batch_sampler = DynamicBatchSampler(
            self.sampler,
            self.aspect_ratio_range,
            self.image_num_range,
            seed=seed,
            max_img_per_gpu=max_img_per_gpu,
        )

    def get_loader(self, epoch):
        """
        Build and return a DataLoader for ``epoch``.

        This reseeds the batch sampler's aspect-ratio RNG and sets the
        distributed sampler's epoch, which controls the shuffle order. It then
        forwards the epoch to the dataset when the dataset exposes ``epoch`` or
        ``set_epoch``, and finally returns a DataLoader driven by the dynamic
        batch sampler.

        Args:
            epoch: Current epoch number.

        Returns:
            A ``torch.utils.data.DataLoader`` over ``self.dataset``.
        """
        print("Building dynamic dataloader with epoch:", epoch)

        # Reseed the aspect-ratio RNG deterministically for each epoch, so a
        # resumed run reproduces the same sequence instead of depending on how
        # many batches this process has already drawn. ``epoch + seed`` matches
        # the seeding done in DynamicBatchSampler.__init__, so epoch 0 is
        # unchanged. DynamicBatchSampler.set_epoch also forwards its argument to
        # the distributed sampler, so set the real epoch on it afterwards.
        self.batch_sampler.set_epoch(epoch + self.seed)
        self.sampler.set_epoch(epoch)

        if hasattr(self.dataset, "epoch"):
            self.dataset.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

        # Create and return the dataloader
        return DataLoader(
            self.dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            batch_sampler=self.batch_sampler,
            collate_fn=self.collate_fn,
            persistent_workers=self.persistent_workers,
            worker_init_fn=get_worker_init_fn(
                seed=self.seed,
                num_workers=self.num_workers,
                epoch=epoch,
                worker_init_fn=self.worker_init_fn,
            ),
        )
class DynamicBatchSampler(Sampler):
    """
    A batch sampler that draws a fresh image count and aspect ratio for every batch.

    All samples within one yielded batch share the same ``image_num`` and
    ``aspect_ratio``. The batch size is ``floor(max_img_per_gpu / image_num)``,
    clamped to at least 1.
    """

    def __init__(self,
                 sampler,
                 aspect_ratio_range,
                 image_num_range,
                 epoch=0,
                 seed=42,
                 max_img_per_gpu=48):
        """
        Initializes the dynamic batch sampler.

        Args:
            sampler: Instance of DynamicDistributedSampler. It must provide
                ``set_epoch``, ``update_parameters`` and iteration over
                ``(idx, image_num, aspect_ratio)`` tuples.
            aspect_ratio_range: List containing [min_aspect_ratio, max_aspect_ratio].
            image_num_range: List containing [min_images, max_images] per sample (inclusive).
            epoch: Initial epoch number.
            seed: Added to ``epoch`` for the initial ``set_epoch`` call.
            max_img_per_gpu: Image budget used to derive each batch's size.

        Raises:
            ValueError: If ``image_num_range`` contains no integers (min > max).
        """
        self.sampler = sampler
        self.aspect_ratio_range = aspect_ratio_range
        self.image_num_range = image_num_range
        self.rng = random.Random()

        # Uniformly sample from the range of possible image numbers.
        # Every image number gets weight 1.0; other weights can be set here.
        self.image_num_weights = {num_images: 1.0 for num_images in range(image_num_range[0], image_num_range[1] + 1)}

        # Possible image numbers, e.g., [2, 3, 4, ..., 24]
        self.possible_nums = np.array([n for n in self.image_num_weights
                                       if self.image_num_range[0] <= n <= self.image_num_range[1]])
        # Fail here rather than later inside np.random.choice during iteration.
        if self.possible_nums.size == 0:
            raise ValueError(f"image_num_range must be [min, max] with min <= max, got {image_num_range}")

        # Normalize weights for sampling
        weights = [self.image_num_weights[n] for n in self.possible_nums]
        self.normalized_weights = np.array(weights) / sum(weights)

        # Image budget per batch
        self.max_img_per_gpu = max_img_per_gpu

        # Seed the wrapped sampler and the aspect-ratio RNG
        self.set_epoch(epoch + seed)

    def set_epoch(self, epoch):
        """
        Set the epoch, which determines the random sequence.

        The value is forwarded unchanged to the wrapped sampler's
        ``set_epoch``. The aspect-ratio RNG is reseeded with ``epoch * 100``.

        Args:
            epoch: The epoch number.
        """
        self.sampler.set_epoch(epoch)
        self.epoch = epoch
        self.rng.seed(epoch * 100)

    def __iter__(self):
        """
        Yield batches whose samples share one randomly drawn image count and aspect ratio.

        Each batch is a list of ``(idx, image_num, aspect_ratio)`` tuples from the
        wrapped sampler. Iteration stops once the wrapped sampler is exhausted. The
        final batch may be smaller than the computed batch size.
        """
        sampler_iterator = iter(self.sampler)

        while True:
            # NOTE: the image count is drawn from NumPy's global RNG, not self.rng,
            # so ``seed``/``set_epoch`` control only the aspect ratio.
            random_image_num = int(np.random.choice(self.possible_nums, p=self.normalized_weights))
            random_aspect_ratio = round(self.rng.uniform(self.aspect_ratio_range[0], self.aspect_ratio_range[1]), 2)

            # DynamicDistributedSampler reads these attributes when it yields each
            # item, so every item drawn below carries the values set here.
            self.sampler.update_parameters(
                aspect_ratio=random_aspect_ratio,
                image_num=random_image_num,
            )

            # Fit as many samples as the image budget allows (at least one).
            batch_size = max(1, int(np.floor(self.max_img_per_gpu / random_image_num)))

            # Each item is (idx, image_num, aspect_ratio).
            current_batch = list(itertools.islice(sampler_iterator, batch_size))
            if not current_batch:
                break  # Wrapped sampler exhausted
            yield current_batch

    def __len__(self):
        # Placeholder length: the true number of batches depends on the random
        # image counts drawn during iteration.
        return 1000000
class DynamicDistributedSampler(DistributedSampler):
    """
    DistributedSampler variant that tags every index with the current dynamic parameters.

    Index selection is inherited unchanged from ``DistributedSampler``: sharding
    across replicas, optional shuffling, and padding or dropping. Each yielded
    item is the tuple ``(index, image_num, aspect_ratio)``, built from the values
    most recently passed to ``update_parameters``. Both values are ``None`` until
    the first update.
    """

    def __init__(
        self,
        dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
        drop_last: bool = False,
    ):
        super().__init__(
            dataset=dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
        # Dynamic parameters attached to every yielded index; set per batch.
        self.image_num = None
        self.aspect_ratio = None

    def __iter__(self):
        """
        Yield ``(index, image_num, aspect_ratio)`` for each index chosen by the parent sampler.

        The parameters are read when each item is yielded. Calls to
        ``update_parameters`` made between items therefore apply to the items
        that follow.
        """
        for index in super().__iter__():
            yield index, self.image_num, self.aspect_ratio

    def update_parameters(self, aspect_ratio, image_num):
        """
        Set the dynamic parameters attached to subsequently yielded indices.

        Args:
            aspect_ratio: Aspect ratio to attach.
            image_num: Number of images to attach.
        """
        self.image_num, self.aspect_ratio = image_num, aspect_ratio