"""
Dynamic batch sizing utilities.

Automatically adjusts batch size to maximize GPU utilization while avoiding
OOM errors.
"""

import logging

import torch
from torch.utils.data import DataLoader, IterableDataset, Subset

logger = logging.getLogger(__name__)


def _is_oom_error(error: BaseException) -> bool:
    """Return True if *error*'s message looks like an out-of-memory failure.

    This is a substring match on "out of memory"; PyTorch's CUDA OOM errors
    ("CUDA out of memory. ...") match it. Other allocator failures with
    different wording are not detected.
    """
    return "out of memory" in str(error)


class DynamicBatchSampler:
    """
    Track and adapt a batch size from success / OOM feedback.

    Starts at ``initial_batch_size``. After ``patience`` consecutive successful
    batches the size is multiplied by ``increase_factor`` (capped at
    ``max_batch_size``); on every reported OOM it is multiplied by
    ``decrease_factor`` (floored at ``min_batch_size``).

    Despite the name this is not a ``torch.utils.data.Sampler``: it yields no
    indices, it only holds the batch size callers should use.
    """

    def __init__(
        self,
        dataset,
        initial_batch_size: int = 1,
        max_batch_size: int = 8,
        min_batch_size: int = 1,
        increase_factor: float = 2.0,
        decrease_factor: float = 0.5,
        patience: int = 5,
    ):
        """
        Args:
            dataset: Dataset to sample from (stored only; not read here).
            initial_batch_size: Starting batch size.
            max_batch_size: Upper bound for growth.
            min_batch_size: Lower bound for shrinking; must be >= 1.
            increase_factor: Multiplier applied after ``patience`` successes.
            decrease_factor: Multiplier applied on OOM.
            patience: Consecutive successful batches required before growing.

        Raises:
            ValueError: If ``min_batch_size < 1`` or
                ``max_batch_size < min_batch_size``.
        """
        if min_batch_size < 1:
            raise ValueError(f"min_batch_size must be >= 1, got {min_batch_size}")
        if max_batch_size < min_batch_size:
            raise ValueError(
                f"max_batch_size ({max_batch_size}) must be >= "
                f"min_batch_size ({min_batch_size})"
            )

        self.dataset = dataset
        self.current_batch_size = initial_batch_size
        self.max_batch_size = max_batch_size
        self.min_batch_size = min_batch_size
        self.increase_factor = increase_factor
        self.decrease_factor = decrease_factor
        self.patience = patience

        # Consecutive successes since the last size increase or OOM.
        self.successful_batches = 0
        # Successful batches over the sampler's lifetime.
        self.total_batches = 0
        # OOM events reported through on_oom().
        self.oom_count = 0

        logger.info(
            f"DynamicBatchSampler initialized: "
            f"initial={initial_batch_size}, "
            f"max={max_batch_size}, "
            f"min={min_batch_size}"
        )

    def get_batch_size(self) -> int:
        """Return the current batch size."""
        return self.current_batch_size

    def on_success(self):
        """Record a successful batch; grow the size after ``patience`` in a row."""
        self.successful_batches += 1
        self.total_batches += 1

        if self.successful_batches < self.patience:
            return

        target = int(self.current_batch_size * self.increase_factor)
        if self.increase_factor > 1 and target <= self.current_batch_size:
            # int() truncation can stall growth (e.g. 1 * 1.5 -> 1); step by one.
            target = self.current_batch_size + 1
        # Clamp instead of skipping: 3 -> 6 with max 4 should land on 4.
        target = min(target, self.max_batch_size)

        if target > self.current_batch_size:
            old_size = self.current_batch_size
            self.current_batch_size = target
            self.successful_batches = 0
            logger.info(f"Batch size increased: {old_size} -> {self.current_batch_size}")

    def on_oom(self):
        """Record an OOM: shrink the batch size and release cached CUDA memory."""
        self.oom_count += 1
        # An OOM breaks the success streak even when we cannot shrink further.
        self.successful_batches = 0

        new_batch_size = int(self.current_batch_size * self.decrease_factor)
        new_batch_size = max(new_batch_size, self.min_batch_size)

        if new_batch_size < self.current_batch_size:
            old_size = self.current_batch_size
            self.current_batch_size = new_batch_size
            logger.warning(
                f"OOM detected, batch size decreased: {old_size} -> {self.current_batch_size}"
            )
        else:
            logger.warning(
                f"OOM detected at minimum batch size {self.current_batch_size}"
            )

        # Clear cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_stats(self) -> dict:
        """
        Return sampler statistics.

        ``success_rate`` is successful batches divided by all reported
        outcomes (successes + OOMs); 0.0 before anything is reported.
        """
        attempts = self.total_batches + self.oom_count
        return {
            "current_batch_size": self.current_batch_size,
            "successful_batches": self.successful_batches,
            "total_batches": self.total_batches,
            "oom_count": self.oom_count,
            "success_rate": self.total_batches / max(attempts, 1),
        }


class AdaptiveDataLoader:
    """
    DataLoader wrapper with dynamic batch sizing.

    Wraps ``torch.utils.data.DataLoader`` and rebuilds it whenever the
    ``DynamicBatchSampler`` changes the batch size.
    """

    def __init__(
        self,
        dataset,
        initial_batch_size: int = 1,
        max_batch_size: int = 8,
        **dataloader_kwargs,
    ):
        """
        Args:
            dataset: Dataset.
            initial_batch_size: Starting batch size.
            max_batch_size: Maximum batch size.
            **dataloader_kwargs: Additional DataLoader arguments (must not
                include ``batch_size``, which is managed here).
        """
        self.dataset = dataset
        self.initial_batch_size = initial_batch_size
        self.max_batch_size = max_batch_size
        self.dataloader_kwargs = dataloader_kwargs

        self.sampler = DynamicBatchSampler(
            dataset,
            initial_batch_size=initial_batch_size,
            max_batch_size=max_batch_size,
        )

        self.dataloader = None
        # Batch size self.dataloader was built with.
        self._loader_batch_size = None
        self._create_dataloader()

    def _can_resume(self) -> bool:
        """Return True if a mid-epoch position maps to a contiguous index range.

        Resuming requires in-order iteration of a sized, map-style dataset;
        custom ordering (shuffle / sampler / batch_sampler) or an
        IterableDataset cannot be resumed by skipping a prefix.
        """
        kwargs = self.dataloader_kwargs
        if (
            kwargs.get("shuffle")
            or kwargs.get("sampler") is not None
            or kwargs.get("batch_sampler") is not None
        ):
            return False
        if isinstance(self.dataset, IterableDataset):
            return False
        try:
            len(self.dataset)
        except TypeError:
            return False
        return True

    def _build_dataloader(self, start: int = 0) -> DataLoader:
        """Build a DataLoader at the current batch size over samples ``start:``."""
        dataset = self.dataset
        if start > 0:
            dataset = Subset(self.dataset, range(start, len(self.dataset)))
        return DataLoader(
            dataset,
            batch_size=self.sampler.get_batch_size(),
            **self.dataloader_kwargs,
        )

    def _create_dataloader(self):
        """Create the full-dataset DataLoader with the current batch size."""
        self.dataloader = self._build_dataloader()
        self._loader_batch_size = self.sampler.get_batch_size()

    def _restart(self, start: int):
        """Rebuild at the current size; return (iterator from sample ``start``, batch size).

        ``self.dataloader`` is always refreshed to the full dataset so that
        ``len(self)`` reflects the new batch size.
        """
        self._create_dataloader()
        loader = self.dataloader if start == 0 else self._build_dataloader(start)
        return iter(loader), self._loader_batch_size

    def __iter__(self):
        """
        Yield batches, shrinking the batch size and retrying on OOM.

        Only OOM errors raised while *fetching* a batch (dataset access,
        collation, pinning, ...) are caught here. Exceptions raised by the
        caller's own loop body are not delivered into this generator; report
        those with ``self.sampler.on_oom()``.

        If the dataset is map-style and read in order (no ``shuffle``,
        ``sampler`` or ``batch_sampler``), a batch-size change resumes from the
        first sample not yet yielded. Otherwise an OOM restarts the epoch
        from the beginning (samples may repeat) and size increases take
        effect from the next epoch.

        Raises:
            RuntimeError: Non-OOM errors, and OOM errors that persist at the
                minimum batch size.
        """
        resumable = self._can_resume()
        if self._loader_batch_size != self.sampler.get_batch_size():
            # Apply size changes made since the loader was last built.
            self._create_dataloader()
        iterator = iter(self.dataloader)
        loader_batch_size = self._loader_batch_size
        position = 0  # samples yielded so far in this epoch

        while True:
            try:
                batch = next(iterator)
            except StopIteration:
                return
            except RuntimeError as error:
                if not _is_oom_error(error):
                    raise
                size_before = self.sampler.get_batch_size()
                self.sampler.on_oom()
                if self.sampler.get_batch_size() >= size_before:
                    # Already at the minimum: a retry would fail identically.
                    raise
                if resumable:
                    start = position
                else:
                    if position:
                        logger.warning(
                            "Cannot resume mid-epoch for this dataset/sampler; "
                            "restarting the epoch with a smaller batch size"
                        )
                    start = 0
                    position = 0
                iterator, loader_batch_size = self._restart(start)
                continue

            yield batch
            position += loader_batch_size
            self.sampler.on_success()
            if resumable and self.sampler.get_batch_size() != loader_batch_size:
                # Apply a batch-size increase immediately without repeating data.
                iterator, loader_batch_size = self._restart(position)

    def __len__(self):
        """Number of batches per epoch at the current loader's batch size."""
        return len(self.dataloader)

    def get_batch_size(self) -> int:
        """Return the current batch size."""
        return self.sampler.get_batch_size()

    def get_stats(self) -> dict:
        """Return the underlying sampler's statistics."""
        return self.sampler.get_stats()