Spaces:
Build error
Build error
| """ | |
| Size-aware batching utilities for variable-sized seismic images | |
| """ | |
| import torch | |
| from torch.utils.data import DataLoader, Sampler | |
| import numpy as np | |
| from collections import defaultdict | |
| import random | |
class SizeAwareSampler(Sampler):
    """Batch sampler that groups dataset indices by image size.

    Yields lists of indices in which every sample has the same
    (height, width), so batches can be stacked without padding.
    Intended to be passed as ``batch_sampler=`` to a DataLoader.
    """

    def __init__(self, dataset, batch_size, get_size_fn=None):
        """
        Args:
            dataset: PyTorch dataset.
            batch_size: maximum batch size per size group (the last batch
                of a group may be smaller).
            get_size_fn: function that takes a dataset index and returns
                (height, width). If None, the size is inferred by loading
                each sample once — NOTE(review): this loads the whole
                dataset at construction time, which may be slow.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.get_size_fn = get_size_fn
        # Group indices by size once; group membership never changes.
        self.size_groups = self._group_by_size()
        # Initial batching; rebuilt at the start of every epoch (__iter__).
        self.batches = self._create_batches()

    def _group_by_size(self):
        """Map each (H, W) size to the list of dataset indices having it."""
        size_groups = defaultdict(list)
        for idx in range(len(self.dataset)):
            if self.get_size_fn:
                size = self.get_size_fn(idx)
            else:
                # Fall back to loading the sample and reading the tensor shape.
                sample = self.dataset[idx]
                if isinstance(sample, (tuple, list)):
                    # Assume the first element is the image tensor.
                    img_tensor = sample[0]
                else:
                    img_tensor = sample
                # Shape is assumed to be [C, H, W] or [H, W].
                if len(img_tensor.shape) == 3:
                    size = (img_tensor.shape[1], img_tensor.shape[2])
                elif len(img_tensor.shape) == 2:
                    size = (img_tensor.shape[0], img_tensor.shape[1])
                else:
                    raise ValueError(f"Unexpected tensor shape: {img_tensor.shape}")
            size_groups[size].append(idx)
        return size_groups

    def _create_batches(self):
        """Shuffle each size group and slice it into batches.

        (Fix: dropped the unused ``random_size`` parameter.)
        """
        batches = []
        for indices in self.size_groups.values():
            # Shuffle within the group so batch composition varies.
            random.shuffle(indices)
            for i in range(0, len(indices), self.batch_size):
                batches.append(indices[i:i + self.batch_size])
        return batches

    def __iter__(self):
        # Fix: rebuild batches every epoch so that batch *composition*
        # is reshuffled, not just batch order. The original built the
        # batches once in __init__ and only reordered them here, so the
        # same samples were always grouped together across epochs.
        self.batches = self._create_batches()
        random.shuffle(self.batches)
        yield from self.batches

    def __len__(self):
        # Number of batches per epoch (constant: group sizes never change).
        return len(self.batches)
class FixedSizeSampler(Sampler):
    """Batch sampler for datasets with a known, fixed set of image sizes.

    More efficient to reason about than SizeAwareSampler when the size
    categories are known up front. Samples whose size is not an exact
    category are assigned to the closest one (with a warning).
    """

    def __init__(self, dataset, batch_size, size_categories):
        """
        Args:
            dataset: PyTorch dataset.
            batch_size: maximum batch size per size category.
            size_categories: list of (height, width) tuples,
                e.g. [(601, 200), (200, 255), (601, 255)].
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.size_categories = size_categories
        # Map each known size to the dataset indices that belong to it.
        self.size_to_indices = {size: [] for size in size_categories}
        self._categorize_indices()
        # Initial batching; rebuilt at the start of every epoch (__iter__).
        self.batches = self._create_batches()

    @staticmethod
    def _sample_size(sample):
        """Return (H, W) of a dataset item ([C, H, W] or [H, W] tensor)."""
        if isinstance(sample, (tuple, list)):
            img_tensor = sample[0]
        else:
            img_tensor = sample
        if len(img_tensor.shape) == 3:
            return (img_tensor.shape[1], img_tensor.shape[2])
        if len(img_tensor.shape) == 2:
            return (img_tensor.shape[0], img_tensor.shape[1])
        raise ValueError(f"Unexpected tensor shape: {img_tensor.shape}")

    def _categorize_indices(self):
        """Assign every dataset index to a size category.

        NOTE(review): this loads every sample once at construction time,
        which may be slow for large datasets.
        """
        for idx in range(len(self.dataset)):
            size = self._sample_size(self.dataset[idx])
            if size in self.size_categories:
                self.size_to_indices[size].append(idx)
            else:
                # Fall back to the closest category by Manhattan distance.
                closest_size = min(
                    self.size_categories,
                    key=lambda cat: abs(cat[0] - size[0]) + abs(cat[1] - size[1]))
                print(f"Warning: Size {size} not in categories, assigning to {closest_size}")
                self.size_to_indices[closest_size].append(idx)

    def _create_batches(self):
        """Shuffle each size category and slice it into batches.

        (Fix: dropped the unused ``random_size`` parameter.)
        """
        batches = []
        for indices in self.size_to_indices.values():
            if not indices:
                continue
            # Shuffle within the category so batch composition varies.
            random.shuffle(indices)
            for i in range(0, len(indices), self.batch_size):
                batches.append(indices[i:i + self.batch_size])
        return batches

    def __iter__(self):
        # Fix: rebuild batches every epoch so that batch *composition*
        # is reshuffled, not just the order of the fixed batches that
        # the original created once in __init__.
        self.batches = self._create_batches()
        random.shuffle(self.batches)
        yield from self.batches

    def __len__(self):
        # Number of batches per epoch (constant: category sizes never change).
        return len(self.batches)

    def get_size_distribution(self):
        """Return {size_category: sample_count} for all categories."""
        return {size: len(indices) for size, indices in self.size_to_indices.items()}
def create_size_aware_dataloader(dataset, batch_size=8, size_categories=None,
                                 num_workers=4, pin_memory=True, **kwargs):
    """
    Create a DataLoader that batches samples by size.

    Args:
        dataset: PyTorch dataset.
        batch_size: batch size for each size group.
        size_categories: list of (height, width) tuples for known size
            categories. If None, sizes are auto-detected (slower: each
            sample is loaded once at construction time).
        num_workers: number of worker processes.
        pin_memory: whether to pin memory.
        **kwargs: additional arguments for DataLoader. Options that
            conflict with ``batch_sampler`` are silently dropped.

    Returns:
        DataLoader with size-aware batching.
    """
    if size_categories:
        sampler = FixedSizeSampler(dataset, batch_size, size_categories)
    else:
        sampler = SizeAwareSampler(dataset, batch_size)
    # Fix: DataLoader raises ValueError if batch_sampler is combined with
    # any of these options; the original only stripped the first two, so
    # a caller passing sampler/drop_last/batch_sampler would crash.
    for conflicting in ('batch_size', 'shuffle', 'sampler', 'batch_sampler', 'drop_last'):
        kwargs.pop(conflicting, None)
    return DataLoader(
        dataset,
        batch_sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        **kwargs
    )
# Custom collate function for same-size batches (no padding needed)
def same_size_collate_fn(batch):
    """
    Collate function for batches where all items have the same size
    No padding required since all images in batch are same size
    """
    first = batch[0]
    if not isinstance(first, (tuple, list)):
        # Batch of bare image tensors.
        return torch.stack(batch)
    # Batch of (image, target) pairs: stack each component separately.
    images, targets = zip(*batch)
    return torch.stack(images), torch.stack(targets)
# Utility function to check batch sizes
def validate_batch_sizes(dataloader, num_batches_to_check=5):
    """
    Validate that all images in each batch have the same size
    """
    # NOTE(review): assumes image batches are 4-D [N, C, H, W] — confirm
    # against the dataset's collate output.
    print("Validating batch sizes...")
    for batch_idx, batch in enumerate(dataloader):
        if batch_idx >= num_batches_to_check:
            break
        images = batch[0] if isinstance(batch, (tuple, list)) else batch
        n_images = images.shape[0]
        height, width = images.shape[2], images.shape[3]
        print(f"Batch {batch_idx}: {n_images} images of size {height}x{width}")
        # Verify every image in the batch matches the batch-level size.
        for img_idx in range(n_images):
            img_h, img_w = images[img_idx].shape[1], images[img_idx].shape[2]
            if img_h != height or img_w != width:
                print(f" WARNING: Image {img_idx} has different size {img_h}x{img_w}")
    print("Validation complete!")