Spaces:
Running on Zero
Running on Zero
| """ | |
| Multi-DataLoader for combining multiple dataloaders with weighted batch sizes. | |
| This module provides a way to train on multiple datasets simultaneously by | |
| creating separate dataloaders for each dataset and combining their batches. | |
| Batch sizes directly control the sampling weight for each dataset. | |
| """ | |
| import torch | |
| from typing import List, Optional | |
| from torch.utils.data import DataLoader | |
def collate_objaverse_batch(batch):
    """Collate Objaverse samples (SimplifiedViewpointDataset) into a batch dict.

    Turns the per-sample tuple format into the dict format used alongside
    CombinedDataset batches.

    Args:
        batch: List of tuples
            (image, rotation, translation, relative_rotation, spherical_angular).
            Only the first five elements of each tuple are read.

    Returns:
        Dict holding one stacked tensor per field plus ``dataset_type``, a
        long tensor with one 0 per sample (0 = objaverse).
    """
    # Tuple position -> output key; stacking happens in this order.
    field_names = (
        'image',
        'rotation',
        'translation',
        'relative_rotation',
        'spherical_angular',
    )
    collated = {
        name: torch.stack([sample[position] for sample in batch])
        for position, name in enumerate(field_names)
    }
    # 0 = objaverse
    collated['dataset_type'] = torch.zeros(len(batch), dtype=torch.long)
    return collated
def collate_compass_batch(batch):
    """Collate Compass samples (CompassDataset) into a batch dict.

    Turns the per-sample tuple format into the dict format used alongside
    CombinedDataset batches.

    Args:
        batch: List of tuples (image, azimuth, category). The category
            (third element) is never read.

    Returns:
        Dict with the stacked ``image`` tensor, a ``spherical_angular`` tensor
        that carries only the azimuth, and ``dataset_type``, a long tensor
        with one 1 per sample (1 = compass).
    """
    images = torch.stack([sample[0] for sample in batch])
    azimuths = torch.stack([sample[1] for sample in batch])
    # sample[2] is the category; it is deliberately left out.

    # Column layout:
    #   [sin(az), cos(az), sin(el), cos(el), norm_radius, norm_yaw, norm_pitch]
    # Only the two azimuth columns are populated; the other five are zeros.
    padding = torch.zeros_like(azimuths)
    columns = [torch.sin(azimuths), torch.cos(azimuths)] + [padding] * 5
    # Yields (B, 7) when each azimuth is a scalar tensor (assumed, not checked).
    spherical_angular = torch.stack(columns, dim=1)

    return dict(
        image=images,
        spherical_angular=spherical_angular,
        dataset_type=torch.ones(len(batch), dtype=torch.long),  # 1 = compass
    )
def combine_batches(batches: List[dict]) -> dict:
    """Combine multiple batched dicts into a single batch dict.

    Values are merged per key across every batch that contains the key:
    tensors are concatenated along dim 0, lists are flattened into one list,
    and anything else is collected into a list with one entry per batch.
    Which branch applies is decided by the first value found for the key.

    A key that is missing from some batches is merged from the batches that
    do have it, so its merged value can cover fewer samples than keys that
    are present in every batch.

    Args:
        batches: List of batch dicts from different dataloaders.

    Returns:
        Combined batch dict. Keys appear in first-seen order (batch order,
        then key order within each batch). When exactly one batch is given,
        that same dict object is returned unchanged; an empty list gives an
        empty dict.
    """
    if len(batches) == 1:
        return batches[0]

    # dict.fromkeys de-duplicates while keeping first-seen order, so the key
    # order of the result is deterministic (a set's order is not).
    ordered_keys = dict.fromkeys(key for batch in batches for key in batch)

    combined = {}
    for key in ordered_keys:
        values = [batch[key] for batch in batches if key in batch]
        first = values[0]
        if torch.is_tensor(first):
            combined[key] = torch.cat(values, dim=0)
        elif isinstance(first, list):
            # Single-pass flatten; sum(values, []) re-copies the accumulator
            # for every chunk.
            combined[key] = [item for chunk in values for item in chunk]
        else:
            combined[key] = values
    return combined
class MultiDataLoader:
    """Combines multiple dataloaders with specified batch sizes.

    Each iteration draws one batch from every dataloader. The epoch length is
    the length of the longest dataloader; a shorter dataloader is restarted
    (cyclic) whenever it runs out before the epoch ends.

    Args:
        dataloaders: List of DataLoader instances to combine.
        batch_sizes: Expected batch size of each dataloader, in the same
            order. A mismatch with ``dl.batch_size`` only prints a warning.
            The sum is stored as ``total_batch_size``.
        collate_fn: Optional callable that receives the list of per-loader
            batches and returns the combined batch. If None, the list of
            batches is yielded as-is, except that with a single dataloader
            its batch is yielded directly (not wrapped in a list).

    Raises:
        AssertionError: If ``dataloaders`` is empty or its length differs
            from ``batch_sizes``.

    Example:
        >>> dl1 = DataLoader(dataset1, batch_size=28, num_workers=4, shuffle=True)
        >>> dl2 = DataLoader(dataset2, batch_size=4, num_workers=4, shuffle=True)
        >>> multi_dl = MultiDataLoader([dl1, dl2], batch_sizes=[28, 4])
        >>> for batch1, batch2 in multi_dl:
        ...     # batch1 has 28 samples from dl1, batch2 has 4 from dl2
        ...     ...
    """

    def __init__(
        self,
        dataloaders: List[DataLoader],
        batch_sizes: List[int],
        collate_fn: Optional[callable] = None,
    ):
        assert len(dataloaders) > 0, "Must provide at least one dataloader"
        assert len(dataloaders) == len(batch_sizes), \
            f"Number of dataloaders ({len(dataloaders)}) must match number of batch_sizes ({len(batch_sizes)})"

        # Verify batch sizes match dataloader configurations (warn only).
        for i, (dl, expected_bs) in enumerate(zip(dataloaders, batch_sizes)):
            actual_bs = dl.batch_size
            if actual_bs != expected_bs:
                print(f"Warning: DataLoader {i} has batch_size={actual_bs} but expected {expected_bs}")

        self.dataloaders = dataloaders
        self.batch_sizes = batch_sizes
        self.collate_fn = collate_fn

        # Total number of samples drawn per iteration across all loaders.
        self.total_batch_size = sum(batch_sizes)

        # Epoch length is set by the longest dataloader.
        self._length = max(len(dl) for dl in dataloaders)

        print(f"MultiDataLoader initialized with {len(dataloaders)} dataloaders:")
        for i, (dl, bs) in enumerate(zip(dataloaders, batch_sizes)):
            print(f"  [{i}] batch_size={bs}, length={len(dl)} iterations")
        print(f"Total batch size: {self.total_batch_size}")
        print(f"Epoch length: {self._length} iterations (determined by longest dataloader)")

    def __len__(self) -> int:
        """Return the number of iterations per epoch (length of longest dataloader)."""
        return self._length

    def _next_cyclic(self, iterators, index):
        """Return the next batch of loader ``index``, restarting it if exhausted.

        Raises:
            RuntimeError: If the loader yields nothing even right after a
                restart, i.e. it is empty and cannot be cycled.
        """
        try:
            return next(iterators[index])
        except StopIteration:
            # This dataloader is exhausted: restart it (cyclic behavior).
            iterators[index] = iter(self.dataloaders[index])
        try:
            return next(iterators[index])
        except StopIteration:
            # A StopIteration escaping a generator is turned into an opaque
            # "generator raised StopIteration" RuntimeError (PEP 479); raise
            # the same exception type with a message naming the culprit.
            raise RuntimeError(
                f"DataLoader {index} yielded no batches; "
                "an empty dataloader cannot be cycled"
            ) from None

    def __iter__(self):
        """Iterate through all dataloaders in parallel, combining their batches.

        Yields:
            ``collate_fn(batches)`` when a collate function was given;
            otherwise the list of per-loader batches, or the batch itself
            when there is only one dataloader.

        Raises:
            RuntimeError: If a dataloader that needs restarting is empty.
        """
        iterators = [iter(dl) for dl in self.dataloaders]

        for _ in range(len(self)):
            # One batch from each dataloader, in dataloader order.
            batches = [
                self._next_cyclic(iterators, index)
                for index in range(len(iterators))
            ]

            if self.collate_fn is not None:
                yield self.collate_fn(batches)
            elif len(batches) > 1:
                # No combination: hand back the per-loader batches as a list.
                yield batches
            else:
                # Single dataloader: yield its batch unwrapped.
                yield batches[0]
def test_multi_dataloader():
    """Smoke-test MultiDataLoader on two random TensorDatasets of unequal size.

    Checks that each of the first five iterations yields one batch of 8 and
    one batch of 2, and that a full epoch runs for ``len(multi_dl)``
    iterations. Progress is reported with ``print``.
    """
    import torch
    from torch.utils.data import TensorDataset, DataLoader

    print("Testing MultiDataLoader...")

    # Dummy datasets of different sizes (100 vs. 30 samples). Images are
    # drawn before labels, large dataset before small one.
    dataset1 = TensorDataset(
        torch.randn(100, 3, 224, 224),
        torch.randint(0, 10, (100,)),
    )
    dataset2 = TensorDataset(
        torch.randn(30, 3, 224, 224),
        torch.randint(0, 10, (30,)),
    )

    # One shuffled loader per dataset, with different batch sizes.
    dl1 = DataLoader(dataset1, batch_size=8, shuffle=True)
    dl2 = DataLoader(dataset2, batch_size=2, shuffle=True)

    print(f"\nDataLoader 1: {len(dataset1)} samples, batch_size=8 → {len(dl1)} iterations")
    print(f"DataLoader 2: {len(dataset2)} samples, batch_size=2 → {len(dl2)} iterations")

    multi_dl = MultiDataLoader([dl1, dl2], batch_sizes=[8, 2])

    print(f"\nMultiDataLoader length: {len(multi_dl)} iterations")
    print("Expected samples per iteration: 10 (8+2)")

    # Inspect the first five combined batches.
    print("\nTesting first 5 iterations:")
    for step, (first, second) in enumerate(multi_dl):
        if step >= 5:
            break

        img1, _ = first
        img2, _ = second

        print(f"  Iteration {step}: batch1={img1.shape}, batch2={img2.shape}")

        # Each loader must contribute exactly its configured batch size.
        assert img1.shape[0] == 8, f"Expected batch size 8, got {img1.shape[0]}"
        assert img2.shape[0] == 2, f"Expected batch size 2, got {img2.shape[0]}"

    # A full pass must last exactly len(multi_dl) iterations.
    print("\nTesting full epoch...")
    total_iterations = sum(1 for _ in multi_dl)

    print(f"Total iterations in one epoch: {total_iterations}")
    assert total_iterations == len(multi_dl), \
        f"Expected {len(multi_dl)} iterations, got {total_iterations}"

    print("\nTest passed!")
# Run the MultiDataLoader smoke test when this file is executed as a script.
if __name__ == "__main__":
    test_multi_dataloader()