""" Size-aware batching utilities for variable-sized seismic images """ import torch from torch.utils.data import DataLoader, Sampler import numpy as np from collections import defaultdict import random class SizeAwareSampler(Sampler): """ Groups samples by size and creates batches with images of the same size """ def __init__(self, dataset, batch_size, get_size_fn=None): """ Args: dataset: PyTorch dataset batch_size: batch size for each size group get_size_fn: function that takes dataset index and returns (height, width) If None, will try to infer from dataset """ self.dataset = dataset self.batch_size = batch_size self.get_size_fn = get_size_fn # Group indices by size self.size_groups = self._group_by_size() # Create batches self.batches = self._create_batches() def _group_by_size(self): """Group dataset indices by image size""" size_groups = defaultdict(list) for idx in range(len(self.dataset)): if self.get_size_fn: size = self.get_size_fn(idx) else: # Try to get size from dataset item sample = self.dataset[idx] if isinstance(sample, (tuple, list)): # Assume first element is the image tensor img_tensor = sample[0] else: img_tensor = sample # Get size from tensor shape (assuming shape is [C, H, W] or [H, W]) if len(img_tensor.shape) == 3: size = (img_tensor.shape[1], img_tensor.shape[2]) # H, W elif len(img_tensor.shape) == 2: size = (img_tensor.shape[0], img_tensor.shape[1]) # H, W else: raise ValueError(f"Unexpected tensor shape: {img_tensor.shape}") size_groups[size].append(idx) return size_groups def _create_batches(self, random_size = True): """Create batches from size groups""" batches = [] for size, indices in self.size_groups.items(): # Shuffle indices within each size group random.shuffle(indices) # Create batches of the specified size for i in range(0, len(indices), self.batch_size): batch = indices[i:i + self.batch_size] batches.append(batch) return batches def __iter__(self): # Shuffle the order of batches random.shuffle(self.batches) for batch in self.batches: yield batch def __len__(self): return len(self.batches) class FixedSizeSampler(Sampler): """ Sampler for datasets where you know the exact 3 size categories More efficient than SizeAwareSampler when sizes are known """ def __init__(self, dataset, batch_size, size_categories): """ Args: dataset: PyTorch dataset batch_size: batch size for each size category size_categories: list of (height, width) tuples for the 3 categories e.g., [(601, 200), (200, 255), (601, 255)] """ self.dataset = dataset self.batch_size = batch_size self.size_categories = size_categories # Map indices to size categories self.size_to_indices = {size: [] for size in size_categories} self._categorize_indices() # Create batches self.batches = self._create_batches() def _categorize_indices(self): """Categorize dataset indices by their size""" for idx in range(len(self.dataset)): sample = self.dataset[idx] if isinstance(sample, (tuple, list)): img_tensor = sample[0] else: img_tensor = sample # Get size from tensor if len(img_tensor.shape) == 3: size = (img_tensor.shape[1], img_tensor.shape[2]) elif len(img_tensor.shape) == 2: size = (img_tensor.shape[0], img_tensor.shape[1]) else: raise ValueError(f"Unexpected tensor shape: {img_tensor.shape}") # Find matching category if size in self.size_categories: self.size_to_indices[size].append(idx) else: # Find closest size category (optional) closest_size = min(self.size_categories, key=lambda cat: abs(cat[0] - size[0]) + abs(cat[1] - size[1])) print(f"Warning: Size {size} not in categories, assigning to {closest_size}") self.size_to_indices[closest_size].append(idx) def _create_batches(self, random_size = True): """Create batches from size categories""" batches = [] for size, indices in self.size_to_indices.items(): if not indices: continue # Shuffle indices within each size category random.shuffle(indices) # Create batches for i in range(0, len(indices), self.batch_size): batch = indices[i:i + self.batch_size] batches.append(batch) return batches def __iter__(self): # Shuffle the order of batches across all size categories random.shuffle(self.batches) for batch in self.batches: yield batch def __len__(self): return len(self.batches) def get_size_distribution(self): """Get the distribution of samples across size categories""" distribution = {} for size, indices in self.size_to_indices.items(): distribution[size] = len(indices) return distribution def create_size_aware_dataloader(dataset, batch_size=8, size_categories=None, num_workers=4, pin_memory=True, **kwargs): """ Create a DataLoader that batches samples by size Args: dataset: PyTorch dataset batch_size: batch size for each size group size_categories: list of (height, width) tuples for known size categories If None, will auto-detect sizes num_workers: number of worker processes pin_memory: whether to pin memory **kwargs: additional arguments for DataLoader Returns: DataLoader with size-aware batching """ if size_categories: sampler = FixedSizeSampler(dataset, batch_size, size_categories) else: sampler = SizeAwareSampler(dataset, batch_size) # Remove batch_size from kwargs since we're using a custom sampler kwargs.pop('batch_size', None) kwargs.pop('shuffle', None) # Sampler handles shuffling return DataLoader( dataset, batch_sampler=sampler, num_workers=num_workers, pin_memory=pin_memory, **kwargs ) # Custom collate function for same-size batches (no padding needed) def same_size_collate_fn(batch): """ Collate function for batches where all items have the same size No padding required since all images in batch are same size """ if isinstance(batch[0], (tuple, list)): # Assuming (image, target) pairs images, targets = zip(*batch) return torch.stack(images), torch.stack(targets) else: # Just images return torch.stack(batch) # Utility function to check batch sizes def validate_batch_sizes(dataloader, num_batches_to_check=5): """ Validate that all images in each batch have the same size """ print("Validating batch sizes...") for i, batch in enumerate(dataloader): if i >= num_batches_to_check: break if isinstance(batch, (tuple, list)): images = batch[0] else: images = batch batch_size = images.shape[0] height = images.shape[2] width = images.shape[3] print(f"Batch {i}: {batch_size} images of size {height}x{width}") # Verify all images in batch have same size for j in range(batch_size): img_h, img_w = images[j].shape[1], images[j].shape[2] if img_h != height or img_w != width: print(f" WARNING: Image {j} has different size {img_h}x{img_w}") print("Validation complete!")