import random
from typing import Optional

import numpy as np
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, DistributedSampler, Sampler


def custom_collate_fn(batch):
    """Collate function that tolerates pre-batched inputs.

    Args:
        batch: A list where each element is either
            - a single sample tuple ``(idx, num_images, ...)``, or
            - a list of such tuples ``[(idx1, num_images1, ...), (idx2, num_images2, ...)]``,
              as produced when a batch sampler yields whole batches.
    """
    # Flatten one level of nesting so default_collate always sees a flat
    # list of samples.
    if isinstance(batch[0], list):
        flattened = []
        for item in batch:
            flattened.extend(item)
        batch = flattened

    return torch.utils.data.default_collate(batch)
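

# Illustrative sketch (not part of the original API): a nested batch, as a
# batch sampler would deliver it, collates to the same tensors as the
# equivalent flat batch.
def _demo_custom_collate():
    nested = [[(0, 3), (1, 3)], [(2, 3)]]
    flat = [(0, 3), (1, 3), (2, 3)]
    # Both calls yield [tensor([0, 1, 2]), tensor([3, 3, 3])].
    assert all(
        torch.equal(a, b)
        for a, b in zip(custom_collate_fn(nested), custom_collate_fn(flat))
    )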


class BatchedRandomSampler:
    """Random sampling under a constraint: every sample in a batch shares the same
    feature, which is chosen at random from a known pool of features for each batch.

    For instance, the feature could be the image aspect ratio.

    The index returned is a tuple (sample_idx, feat_idx).
    This sampler ensures that each run of `batch_size` consecutive indices has the
    same `feat_idx`.
    """

    def __init__(
        self, dataset, batch_size, num_context_views,
        min_patch_num=20, max_patch_num=32, world_size=1, rank=0, drop_last=True,
    ):
        self.batch_size = batch_size
        self.num_context_views = num_context_views

        self.len_dataset = N = len(dataset)
        self.total_size = round_by(N, batch_size * world_size) if drop_last else N
        self.min_patch_num = min_patch_num
        self.max_patch_num = max_patch_num
        assert (
            world_size == 1 or drop_last
        ), "must drop the last batch in distributed mode"

        self.world_size = world_size
        self.rank = rank
        self.epoch = None

    def __len__(self):
        return self.total_size // self.world_size

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        # Derive a seed: random in single-process mode, epoch-based (and thus
        # identical across ranks) in distributed mode.
        if self.epoch is None:
            assert (
                self.world_size == 1 and self.rank == 0
            ), "use set_epoch() if distributed mode is used"
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
        else:
            seed = self.epoch + 777
        rng = np.random.default_rng(seed=seed)

        # Shuffle all sample indices for this epoch.
        sample_idxs = np.arange(self.total_size)
        rng.shuffle(sample_idxs)

        # Draw one image count per batch in [2, num_context_views) and repeat
        # it for every sample in that batch, so a whole batch shares it.
        n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
        num_imgs = rng.integers(low=2, high=self.num_context_views, size=n_batches)
        num_imgs = np.broadcast_to(num_imgs[:, None], (n_batches, self.batch_size))
        num_imgs = num_imgs.ravel()[: self.total_size]

        # Pair each sample index with its batch's image count.
        idxs = np.c_[sample_idxs, num_imgs]  # shape (total_size, 2)

        # Keep only this rank's contiguous slice of whole batches.
        size_per_proc = self.batch_size * (
            (self.total_size + self.world_size * self.batch_size - 1)
            // (self.world_size * self.batch_size)
        )
        idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]

        yield from (tuple(idx) for idx in idxs)
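

# A minimal usage sketch (the toy dataset below is hypothetical): the sampler
# yields (sample_idx, num_imgs) tuples, so the dataset's __getitem__ must
# accept tuple keys, and custom_collate_fn handles the resulting batches.
class _ToyTupleDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, key):
        idx, num_images = key
        return idx, num_images


def _demo_batched_random_sampler():
    ds = _ToyTupleDataset()
    sampler = BatchedRandomSampler(ds, batch_size=4, num_context_views=8)
    loader = DataLoader(ds, batch_size=4, sampler=sampler, collate_fn=custom_collate_fn)
    for idx, num_images in loader:
        # Every sample in a batch shares the same image count.
        assert int(num_images.min()) == int(num_images.max())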


class DynamicBatchSampler(Sampler):
    """
    A batch sampler that dynamically draws the batch size, patch-grid height,
    and image count for each batch. All samples within a batch share the same
    patch-grid height and image count.
    """
    def __init__(self,
                 sampler,
                 image_num_range,
                 h_range,
                 epoch=0,
                 seed=42,
                 max_img_per_gpu=48):
        """
        Initializes the dynamic batch sampler.

        Args:
            sampler: Instance of DynamicDistributedSampler.
            image_num_range: List containing [min_images, max_images] per sample.
            h_range: List containing [min_height, max_height] in pixels
                (converted internally to 14-px patch units).
            epoch: Current epoch number.
            seed: Random seed for reproducibility.
            max_img_per_gpu: Maximum number of images that fit in GPU memory.
        """
        self.sampler = sampler
        self.image_num_range = image_num_range
        self.h_range = h_range
        self.rng = random.Random()

        # Weight each candidate image count by its square, biasing the draw
        # toward batches with more images.
        self.image_num_weights = {
            num_images: float(num_images ** 2)
            for num_images in range(image_num_range[0], image_num_range[1] + 1)
        }
        self.possible_nums = np.array(sorted(self.image_num_weights.keys()))

        weights = [self.image_num_weights[n] for n in self.possible_nums]
        self.normalized_weights = np.array(weights) / sum(weights)

        self.max_img_per_gpu = max_img_per_gpu

        self.set_epoch(epoch + seed)

    def set_epoch(self, epoch):
        """
        Sets the epoch for this sampler, affecting the random sequence.

        Args:
            epoch: The epoch number.
        """
        self.sampler.set_epoch(epoch)
        self.epoch = epoch
        self.rng.seed(epoch * 100)

    def __iter__(self):
        """
        Yields batches of samples with synchronized dynamic parameters.

        Returns:
            Iterator yielding batches of indices with associated parameters.
        """
        sampler_iterator = iter(self.sampler)

        while True:
            # Draw this batch's image count and patch-grid height from the
            # epoch-seeded RNG so all ranks stay synchronized.
            random_image_num = int(self.rng.choices(
                self.possible_nums.tolist(),
                weights=self.normalized_weights.tolist(),
                k=1,
            )[0])
            random_ps_h = self.rng.randint(self.h_range[0] // 14, self.h_range[1] // 14)

            # Attach the drawn parameters to every index the underlying
            # sampler yields from now on.
            self.sampler.update_parameters(
                image_num=random_image_num,
                ps_h=random_ps_h,
            )

            # The more images per sample, the fewer samples fit on one GPU.
            batch_size = max(1, self.max_img_per_gpu // random_image_num)

            current_batch = []
            for _ in range(batch_size):
                try:
                    current_batch.append(next(sampler_iterator))
                except StopIteration:
                    break

            if not current_batch:
                break

            yield current_batch

    def __len__(self):
        # The true number of batches depends on the random image counts;
        # return a large placeholder and let __iter__ bound the iteration.
        return 1000000
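

# Worked example of the batch-size rule above, assuming the default
# max_img_per_gpu=48: drawing image_num=4 gives batch_size=12 (48 images on
# the GPU), image_num=24 gives batch_size=2, and any image_num > 48 still
# yields the minimum batch_size of 1.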


class DynamicDistributedSampler(DistributedSampler):
    """
    Extends PyTorch's DistributedSampler with dynamic image_num and ps_h
    parameters, which are passed through to the dataset's __getitem__ method.
    """
    def __init__(
        self,
        dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
        drop_last: bool = False,
    ):
        super().__init__(
            dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
        self.image_num = None
        self.ps_h = None

    def __iter__(self):
        """
        Yields a sequence of (index, image_num, ps_h) tuples.
        Relies on the parent class's logic for shuffling/distributing
        the indices across replicas, then attaches the extra parameters.
        """
        indices_iter = super().__iter__()

        for idx in indices_iter:
            yield (idx, self.image_num, self.ps_h)

    def update_parameters(self, image_num, ps_h):
        """
        Updates the dynamic parameters for each new epoch or iteration.

        Args:
            image_num: The number of images per sample.
            ps_h: The patch-grid height.
        """
        self.image_num = image_num
        self.ps_h = ps_h
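

# A minimal end-to-end sketch of the two classes above (the toy dataset is
# hypothetical; explicit num_replicas=1/rank=0 avoid needing an initialized
# process group). Each yielded batch shares one (image_num, ps_h) draw.
class _ToyViewDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, key):
        idx, image_num, ps_h = key
        return idx, image_num, ps_h


def _demo_dynamic_batching():
    ds = _ToyViewDataset()
    dist_sampler = DynamicDistributedSampler(ds, num_replicas=1, rank=0, shuffle=True)
    batch_sampler = DynamicBatchSampler(
        dist_sampler, image_num_range=[2, 6], h_range=[140, 280], max_img_per_gpu=12
    )
    loader = DataLoader(ds, batch_sampler=batch_sampler, collate_fn=custom_collate_fn)
    for idx, image_num, ps_h in loader:
        # All samples in the batch share one (image_num, ps_h) draw.
        assert int(image_num.min()) == int(image_num.max())
        assert int(ps_h.min()) == int(ps_h.max())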


class MixedBatchSampler(BatchSampler):
    """Sample each batch from one source dataset, selected with a given
    probability. Compatible with datasets at different resolutions.
    """

    def __init__(
        self, src_dataset_ls, batch_size, num_context_views,
        world_size=1, rank=0, prob=None, sampler=None, generator=None,
    ):
        self.batch_size = batch_size
        self.num_context_views = num_context_views
        self.world_size = world_size
        self.rank = rank
        self.drop_last = True
        self.generator = generator

        self.src_dataset_ls = src_dataset_ls
        self.n_dataset = len(self.src_dataset_ls)

        # Cumulative lengths give the index offset of each dataset inside
        # the ConcatDataset of src_dataset_ls.
        self.dataset_length = [len(ds) for ds in self.src_dataset_ls]
        self.cum_dataset_length = [
            sum(self.dataset_length[:i]) for i in range(self.n_dataset)
        ]

        # Build one dynamic batch sampler per source dataset.
        self.src_batch_samplers = []
        for ds in self.src_dataset_ls:
            dist_sampler = DynamicDistributedSampler(
                ds, num_replicas=self.world_size, rank=self.rank, seed=42, shuffle=True
            )
            dist_sampler.set_epoch(0)

            if hasattr(ds, "epoch"):
                ds.epoch = 0
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(0)
            batch_sampler = DynamicBatchSampler(
                dist_sampler,
                [2, ds.cfg.view_sampler.num_context_views],
                ds.cfg.input_image_shape,
                seed=42,
                max_img_per_gpu=ds.cfg.view_sampler.max_img_per_gpu,
            )
            self.src_batch_samplers.append(batch_sampler)

        # Materialize each sampler's batches for this epoch.
        self.raw_batches = [list(bs) for bs in self.src_batch_samplers]
        self.n_batches = [len(b) for b in self.raw_batches]
        self.n_total_batch = sum(self.n_batches)

        # Default to sampling each dataset proportionally to its batch count.
        if prob is None:
            self.prob = torch.tensor(self.n_batches) / self.n_total_batch
        else:
            self.prob = torch.as_tensor(prob)

    def __iter__(self):
        """Yields batches of (sample_idx, image_num, ps_h) tuples, where
        sample_idx refers to the ConcatDataset of src_dataset_ls.
        """
        for _ in range(self.n_total_batch):
            # Pick a source dataset according to the sampling probabilities.
            idx_ds = torch.multinomial(
                self.prob, 1, replacement=True, generator=self.generator
            ).item()

            # Refill the chosen dataset's batches if they are exhausted.
            if 0 == len(self.raw_batches[idx_ds]):
                self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds])

            batch_raw = self.raw_batches[idx_ds].pop()

            # Shift local indices into the ConcatDataset's global index space.
            shift = self.cum_dataset_length[idx_ds]
            processed_batch = [
                (item[0] + shift, item[1], item[2]) for item in batch_raw
            ]
            yield processed_batch

    def set_epoch(self, epoch):
        """Set the epoch for all underlying batch samplers."""
        for sampler in self.src_batch_samplers:
            sampler.set_epoch(epoch)
        # Re-materialize the batches with the new epoch's random sequences.
        self.raw_batches = [list(bs) for bs in self.src_batch_samplers]

    def __len__(self):
        return self.n_total_batch
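

# A minimal wiring sketch, assuming hypothetical source datasets that expose
# the cfg fields MixedBatchSampler reads (view_sampler.num_context_views,
# view_sampler.max_img_per_gpu, input_image_shape); indices are resolved
# against a ConcatDataset of the sources, per the __iter__ docstring above.
def _demo_mixed_batch_sampler(src_dataset_ls):
    from torch.utils.data import ConcatDataset

    mixed = MixedBatchSampler(
        src_dataset_ls, batch_size=4, num_context_views=6, world_size=1, rank=0
    )
    mixed.set_epoch(0)
    loader = DataLoader(
        ConcatDataset(src_dataset_ls),
        batch_sampler=mixed,
        collate_fn=custom_collate_fn,
    )
    return loader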


def round_by(total, multiple, up=False):
    """Round `total` down (or up, if `up=True`) to the nearest multiple.

    For example, round_by(10, 4) == 8 and round_by(10, 4, up=True) == 12.
    """
    if up:
        total = total + multiple - 1
    return (total // multiple) * multiple