SFM_Inference_Demo / util /size_aware_batching.py
Anirudh Bhalekar
added models and util folder
a3f0d6c
"""
Size-aware batching utilities for variable-sized seismic images
"""
import torch
from torch.utils.data import DataLoader, Sampler
import numpy as np
from collections import defaultdict
import random
class SizeAwareSampler(Sampler):
    """
    Batch sampler that only puts images of the same size into a batch.

    Dataset indices are grouped by image size once, at construction time.
    Iterating the sampler yields lists of indices (one list per batch) in
    which every sample shares the same size, so it is meant to be passed to
    ``DataLoader`` as ``batch_sampler``.
    """

    def __init__(self, dataset, batch_size, get_size_fn=None):
        """
        Args:
            dataset: indexable dataset that supports ``len()``
            batch_size: maximum number of samples per batch (must be > 0);
                the last batch of each size group may be smaller
            get_size_fn: function that takes a dataset index and returns
                (height, width). If None, every sample is loaded once and
                the size is read from the image tensor's shape.

        Raises:
            ValueError: if batch_size is not positive
        """
        # A negative step would make range() empty and silently drop every
        # sample; zero would fail later with an unrelated range() error.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.dataset = dataset
        self.batch_size = batch_size
        self.get_size_fn = get_size_fn

        # Group indices by size
        self.size_groups = self._group_by_size()

        # Create batches
        self.batches = self._create_batches()

    @staticmethod
    def _infer_size(sample):
        """Return (height, width) read from a sample's image tensor.

        A tuple/list sample is assumed to carry the image as its first
        element. The image must have shape [C, H, W] or [H, W].

        Raises:
            ValueError: if the image is neither 2- nor 3-dimensional
        """
        img_tensor = sample[0] if isinstance(sample, (tuple, list)) else sample
        shape = img_tensor.shape
        if len(shape) not in (2, 3):
            raise ValueError(f"Unexpected tensor shape: {img_tensor.shape}")
        # For both supported layouts H and W are the last two dimensions.
        return (shape[-2], shape[-1])

    def _group_by_size(self):
        """Group dataset indices by image size.

        Returns:
            defaultdict mapping each size key to the list of indices that
            have that size.
        """
        size_groups = defaultdict(list)
        for idx in range(len(self.dataset)):
            if self.get_size_fn is not None:
                size = self.get_size_fn(idx)
                # Lists are unhashable and cannot be used as a dict key.
                if isinstance(size, list):
                    size = tuple(size)
            else:
                # No size function: load the sample and inspect its tensor.
                size = self._infer_size(self.dataset[idx])
            size_groups[size].append(idx)
        return size_groups

    def _create_batches(self, random_size=True):
        """Create batches from size groups.

        Each group's index list is shuffled in place and then cut into
        consecutive chunks of at most ``batch_size`` indices.

        Args:
            random_size: currently unused; kept so existing callers that
                pass it keep working.

        Returns:
            list of batches, each a list of dataset indices of one size.
        """
        batches = []
        for indices in self.size_groups.values():
            # Shuffle indices within each size group
            random.shuffle(indices)
            batches.extend(
                indices[start:start + self.batch_size]
                for start in range(0, len(indices), self.batch_size)
            )
        return batches

    def __iter__(self):
        # Re-draw batch membership on every pass so that successive epochs
        # do not replay the exact same groups of samples. The number of
        # batches is unchanged because the size groups are fixed.
        self.batches = self._create_batches()
        # Shuffle the order of batches
        random.shuffle(self.batches)
        yield from self.batches

    def __len__(self):
        return len(self.batches)
class FixedSizeSampler(Sampler):
    """
    Batch sampler for datasets whose image sizes fall into known categories.

    Every sample is loaded once at construction and assigned to the category
    matching its size; a sample whose size matches no category is assigned
    to the closest one (with a printed warning). Iterating yields lists of
    indices that all belong to one category, so the sampler is meant to be
    passed to ``DataLoader`` as ``batch_sampler``.
    """

    def __init__(self, dataset, batch_size, size_categories):
        """
        Args:
            dataset: indexable dataset that supports ``len()``
            batch_size: maximum number of samples per batch (must be > 0);
                the last batch of each category may be smaller
            size_categories: non-empty sequence of (height, width) pairs,
                e.g., [(601, 200), (200, 255), (601, 255)]. Pairs may be
                tuples or lists.

        Raises:
            ValueError: if batch_size is not positive or size_categories
                is empty
        """
        # A negative step would make range() empty and silently drop every
        # sample; zero would fail later with an unrelated range() error.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # Normalize to tuples: lists are unhashable (cannot be dict keys)
        # and never compare equal to the tuple sizes read from tensors.
        categories = [tuple(size) for size in size_categories]
        if not categories:
            raise ValueError("size_categories must contain at least one size")

        self.dataset = dataset
        self.batch_size = batch_size
        self.size_categories = categories

        # Map indices to size categories
        self.size_to_indices = {size: [] for size in categories}
        self._categorize_indices()

        # Create batches
        self.batches = self._create_batches()

    @staticmethod
    def _infer_size(sample):
        """Return (height, width) read from a sample's image tensor.

        A tuple/list sample is assumed to carry the image as its first
        element. The image must have shape [C, H, W] or [H, W].

        Raises:
            ValueError: if the image is neither 2- nor 3-dimensional
        """
        img_tensor = sample[0] if isinstance(sample, (tuple, list)) else sample
        shape = img_tensor.shape
        if len(shape) not in (2, 3):
            raise ValueError(f"Unexpected tensor shape: {img_tensor.shape}")
        # For both supported layouts H and W are the last two dimensions.
        return (shape[-2], shape[-1])

    def _categorize_indices(self):
        """Assign every dataset index to a size category.

        Exact matches go to their category. Any other size goes to the
        category with the smallest |dH| + |dW| distance (the first such
        category on ties) and a warning is printed.
        """
        for idx in range(len(self.dataset)):
            size = self._infer_size(self.dataset[idx])

            # Dict membership instead of scanning the category list.
            if size in self.size_to_indices:
                self.size_to_indices[size].append(idx)
                continue

            # Find closest size category
            closest_size = min(
                self.size_to_indices,
                key=lambda cat: abs(cat[0] - size[0]) + abs(cat[1] - size[1]),
            )
            print(f"Warning: Size {size} not in categories, assigning to {closest_size}")
            self.size_to_indices[closest_size].append(idx)

    def _create_batches(self, random_size=True):
        """Create batches from size categories.

        Each non-empty category's index list is shuffled in place and then
        cut into consecutive chunks of at most ``batch_size`` indices.

        Args:
            random_size: currently unused; kept so existing callers that
                pass it keep working.

        Returns:
            list of batches, each a list of dataset indices of one category.
        """
        batches = []
        for indices in self.size_to_indices.values():
            if not indices:
                continue
            # Shuffle indices within each size category
            random.shuffle(indices)
            batches.extend(
                indices[start:start + self.batch_size]
                for start in range(0, len(indices), self.batch_size)
            )
        return batches

    def __iter__(self):
        # Re-draw batch membership on every pass so that successive epochs
        # do not replay the exact same groups of samples. The number of
        # batches is unchanged because the categories are fixed.
        self.batches = self._create_batches()
        # Shuffle the order of batches across all size categories
        random.shuffle(self.batches)
        yield from self.batches

    def __len__(self):
        return len(self.batches)

    def get_size_distribution(self):
        """Return a dict mapping each size category to its sample count."""
        return {size: len(indices) for size, indices in self.size_to_indices.items()}
def create_size_aware_dataloader(dataset, batch_size=8, size_categories=None,
                                 num_workers=4, pin_memory=True, **kwargs):
    """
    Build a DataLoader whose batches contain only same-sized samples.

    Args:
        dataset: PyTorch dataset
        batch_size: batch size used within each size group
        size_categories: list of (height, width) tuples for known size
            categories. If None (or empty), sizes are auto-detected.
        num_workers: number of worker processes
        pin_memory: whether to pin memory
        **kwargs: additional arguments forwarded to DataLoader; any
            'batch_size' or 'shuffle' entry is discarded

    Returns:
        DataLoader driven by a size-aware batch sampler
    """
    # Known categories -> FixedSizeSampler, otherwise detect sizes on the fly.
    batch_sampler = (
        FixedSizeSampler(dataset, batch_size, size_categories)
        if size_categories
        else SizeAwareSampler(dataset, batch_size)
    )

    # The batch sampler already decides batch size and ordering, so these
    # options must not reach DataLoader alongside batch_sampler.
    loader_options = {
        option: value
        for option, value in kwargs.items()
        if option not in ('batch_size', 'shuffle')
    }

    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        **loader_options
    )
# Custom collate function for same-size batches (no padding needed)
def _collate_field(values):
    """Combine one field (e.g. all images, or all targets) of a batch.

    Tensors are stacked along a new leading dimension, numpy arrays are
    stacked and converted to a tensor, other values (ints, floats, ...) are
    converted with ``torch.as_tensor``. Values that cannot become a tensor
    (e.g. strings) are returned as a plain list.
    """
    first = values[0]
    if isinstance(first, torch.Tensor):
        return torch.stack(values)
    if isinstance(first, np.ndarray):
        return torch.as_tensor(np.stack(values))
    try:
        return torch.as_tensor(values)
    except (TypeError, ValueError, RuntimeError):
        # Not representable as a tensor; hand the raw values back.
        return list(values)


def same_size_collate_fn(batch):
    """
    Collate function for batches where all items have the same size.

    No padding is performed: tensors are stacked directly, which requires
    every item in the batch to have the same shape.

    Args:
        batch: non-empty list of samples. Each sample is either a single
            image tensor or a tuple/list of fields such as (image, target).

    Returns:
        A stacked tensor when samples are single tensors; otherwise a tuple
        with one collated entry per field, e.g. (images, targets).

    Raises:
        ValueError: if batch is empty
    """
    if not batch:
        raise ValueError("Cannot collate an empty batch")

    if isinstance(batch[0], (tuple, list)):
        # zip(*batch) transposes samples into per-field tuples, so samples
        # with any number of fields are handled uniformly.
        return tuple(_collate_field(field) for field in zip(*batch))

    # Just images
    return torch.stack(batch)
# Utility function to check batch sizes
def validate_batch_sizes(dataloader, num_batches_to_check=5):
    """
    Print the size of the first few batches and flag size mismatches.

    Height and width are taken from the last two dimensions of the image
    batch, so both [B, C, H, W] and [B, H, W] batches are supported.

    Args:
        dataloader: iterable of batches. A tuple/list batch is assumed to
            carry the images as its first element.
        num_batches_to_check: number of batches to inspect

    Returns:
        None; results are reported on stdout.
    """
    print("Validating batch sizes...")
    for i, batch in enumerate(dataloader):
        if i >= num_batches_to_check:
            break

        images = batch[0] if isinstance(batch, (tuple, list)) else batch

        # Need at least a batch dimension plus H and W to report a size.
        if len(images.shape) < 3:
            print(f"  WARNING: Batch {i} has shape {tuple(images.shape)}, "
                  f"expected at least 3 dimensions")
            continue

        batch_size = images.shape[0]
        # Last two dims are H and W regardless of a channel dimension.
        height, width = images.shape[-2], images.shape[-1]
        print(f"Batch {i}: {batch_size} images of size {height}x{width}")

        # Verify all images in batch have same size
        for j in range(batch_size):
            img_h, img_w = images[j].shape[-2], images[j].shape[-1]
            if img_h != height or img_w != width:
                print(f" WARNING: Image {j} has different size {img_h}x{img_w}")
    print("Validation complete!")