import random
from abc import ABC
from typing import Callable, Optional

import numpy as np
from hydra.utils import instantiate
from torch.utils.data import DataLoader, Dataset, DistributedSampler, Sampler

from .worker_fn import get_worker_init_fn

class DynamicTorchDataset(ABC):
    """Wraps a Hydra-instantiated dataset with dynamic, per-batch sampling.

    Each batch is drawn with a randomly chosen aspect ratio and
    images-per-sample count, shared by every sample in the batch; the batch
    size is then adjusted so the total image count fits the GPU budget.
    """

    def __init__(
        self,
        dataset: dict,
        common_config: dict,
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool = True,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        persistent_workers: bool = False,
        seed: int = 42,
        max_img_per_gpu: int = 48,
    ) -> None:
        self.dataset_config = dataset
        self.common_config = common_config
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        self.persistent_workers = persistent_workers
        self.seed = seed
        self.max_img_per_gpu = max_img_per_gpu

        # Build the underlying dataset from its Hydra config; _recursive_=False
        # leaves any nested configs for the dataset itself to instantiate.
        self.dataset = instantiate(dataset, common_config=common_config, _recursive_=False)

        # Sampling bounds: [min, max] aspect ratio and [min, max] images per sample.
        self.aspect_ratio_range = common_config.augs.aspects
        self.image_num_range = common_config.img_nums

        if len(self.aspect_ratio_range) != 2 or self.aspect_ratio_range[0] > self.aspect_ratio_range[1]:
            raise ValueError(
                f"aspect_ratio_range must be [min, max] with min <= max, got {self.aspect_ratio_range}"
            )
        if len(self.image_num_range) != 2 or self.image_num_range[0] < 1 or self.image_num_range[0] > self.image_num_range[1]:
            raise ValueError(
                f"image_num_range must be [min, max] with 1 <= min <= max, got {self.image_num_range}"
            )

        self.sampler = DynamicDistributedSampler(self.dataset, seed=seed, shuffle=shuffle)
        self.batch_sampler = DynamicBatchSampler(
            self.sampler,
            self.aspect_ratio_range,
            self.image_num_range,
            seed=seed,
            max_img_per_gpu=max_img_per_gpu,
        )

    def get_loader(self, epoch):
        """Builds a fresh DataLoader for the given epoch.

        Call this once per epoch so the samplers (and the dataset, if it
        tracks epochs) reshuffle deterministically.
        """
        print(f"Building dynamic dataloader for epoch {epoch}")

        self.sampler.set_epoch(epoch)
        if hasattr(self.dataset, "epoch"):
            self.dataset.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

        return DataLoader(
            self.dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            batch_sampler=self.batch_sampler,
            collate_fn=self.collate_fn,
            persistent_workers=self.persistent_workers,
            worker_init_fn=get_worker_init_fn(
                seed=self.seed,
                num_workers=self.num_workers,
                epoch=epoch,
                worker_init_fn=self.worker_init_fn,
            ),
        )

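# A minimal sketch (not from the original configs) of the shapes that
# DynamicTorchDataset.__init__ consumes: `dataset` needs a `_target_` for
# hydra.utils.instantiate, and `common_config` must expose `augs.aspects`
# and `img_nums`. The target path and values below are hypothetical;
# OmegaConf provides the attribute access used above:
#
#   from omegaconf import OmegaConf
#
#   common_config = OmegaConf.create(
#       {"augs": {"aspects": [0.5, 2.0]}, "img_nums": [2, 8]}
#   )
#   dataset_cfg = OmegaConf.create({"_target_": "my_project.data.MyDataset"})
#   data = DynamicTorchDataset(
#       dataset_cfg, common_config, num_workers=4, shuffle=True, pin_memory=True
#   )
#   loader = data.get_loader(epoch=0)
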
class DynamicBatchSampler(Sampler):
    """
    A batch sampler that dynamically varies batch size, aspect ratio, and
    images-per-sample count for each batch. All samples within a batch share
    the same aspect ratio and image number.
    """

    def __init__(self,
                 sampler,
                 aspect_ratio_range,
                 image_num_range,
                 epoch=0,
                 seed=42,
                 max_img_per_gpu=48):
        """
        Initializes the dynamic batch sampler.

        Args:
            sampler: Instance of DynamicDistributedSampler.
            aspect_ratio_range: List containing [min_aspect_ratio, max_aspect_ratio].
            image_num_range: List containing [min_images, max_images] per sample.
            epoch: Current epoch number.
            seed: Random seed for reproducibility.
            max_img_per_gpu: Maximum number of images that fit in GPU memory.
        """
        self.sampler = sampler
        self.aspect_ratio_range = aspect_ratio_range
        self.image_num_range = image_num_range
        self.rng = random.Random()

        # Uniform weights over every admissible images-per-sample count;
        # adjust these to bias sampling toward particular counts.
        self.image_num_weights = {num_images: 1.0 for num_images in range(image_num_range[0], image_num_range[1] + 1)}
        self.possible_nums = np.array([n for n in self.image_num_weights
                                       if self.image_num_range[0] <= n <= self.image_num_range[1]])
        weights = [self.image_num_weights[n] for n in self.possible_nums]
        self.normalized_weights = np.array(weights) / sum(weights)

        self.max_img_per_gpu = max_img_per_gpu
        self.set_epoch(epoch + seed)

    def set_epoch(self, epoch):
        """
        Sets the epoch for this sampler, affecting the random sequence.

        Args:
            epoch: The epoch number.
        """
        self.sampler.set_epoch(epoch)
        self.epoch = epoch
        self.rng.seed(epoch * 100)

    def __iter__(self):
        """
        Yields batches of samples with synchronized dynamic parameters.

        Returns:
            Iterator yielding batches of indices with associated parameters.
        """
        sampler_iterator = iter(self.sampler)

        while True:
            # Draw this batch's shared parameters from the epoch-seeded RNG
            # so the sequence is reproducible across restarts.
            random_image_num = int(self.rng.choices(
                list(self.possible_nums), weights=self.normalized_weights
            )[0])
            random_aspect_ratio = round(self.rng.uniform(self.aspect_ratio_range[0], self.aspect_ratio_range[1]), 2)

            # Push the parameters into the sampler so every index it yields
            # for this batch carries the same (image_num, aspect_ratio).
            self.sampler.update_parameters(
                aspect_ratio=random_aspect_ratio,
                image_num=random_image_num,
            )

            # Shrink the batch as images-per-sample grows so the total image
            # count stays within budget, e.g. 48 images / 6 per sample -> 8.
            batch_size = max(1, self.max_img_per_gpu // random_image_num)

            current_batch = []
            for _ in range(batch_size):
                try:
                    current_batch.append(next(sampler_iterator))
                except StopIteration:
                    break

            if not current_batch:
                break

            yield current_batch

    def __len__(self):
        # The true number of batches depends on the randomly drawn batch
        # sizes, so report a large placeholder; iteration actually ends when
        # the underlying sampler is exhausted.
        return 1000000

class DynamicDistributedSampler(DistributedSampler):
    """
    Extends PyTorch's DistributedSampler with dynamic aspect_ratio and
    image_num parameters, which are passed through to the dataset's
    __getitem__ method. Both parameters are None until update_parameters
    is first called (DynamicBatchSampler does this before each batch).
    """

    def __init__(
        self,
        dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
        drop_last: bool = False,
    ):
        super().__init__(
            dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
        self.aspect_ratio = None
        self.image_num = None

    def __iter__(self):
        """
        Yields a sequence of (index, image_num, aspect_ratio) tuples.
        Relies on the parent class's logic for shuffling/distributing
        the indices across replicas, then attaches the extra parameters.
        """
        indices_iter = super().__iter__()

        for idx in indices_iter:
            yield (idx, self.image_num, self.aspect_ratio)

    def update_parameters(self, aspect_ratio, image_num):
        """
        Updates the dynamic parameters for the next batch (or epoch).

        Args:
            aspect_ratio: The aspect ratio to set.
            image_num: The number of images per sample to set.
        """
        self.aspect_ratio = aspect_ratio
        self.image_num = image_num
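
# Minimal single-process sketch of how the two samplers cooperate (not part
# of the training pipeline). ToyDataset is hypothetical; a real dataset's
# __getitem__ must unpack the (index, image_num, aspect_ratio) tuples that
# DynamicDistributedSampler yields. num_replicas/rank are pinned so the demo
# runs without torch.distributed being initialized.
if __name__ == "__main__":

    class ToyDataset(Dataset):
        def __len__(self):
            return 32

        def __getitem__(self, item):
            idx, image_num, aspect_ratio = item
            return idx, image_num, aspect_ratio

    sampler = DynamicDistributedSampler(ToyDataset(), num_replicas=1, rank=0, shuffle=True)
    batch_sampler = DynamicBatchSampler(
        sampler,
        aspect_ratio_range=[0.5, 2.0],
        image_num_range=[2, 8],
        seed=0,
        max_img_per_gpu=16,
    )
    batch_sampler.set_epoch(0)
    for batch in batch_sampler:
        # Every tuple in a batch carries the same (image_num, aspect_ratio).
        image_num, aspect_ratio = batch[0][1], batch[0][2]
        print(f"batch of {len(batch)} | image_num={image_num} aspect_ratio={aspect_ratio}")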