# 3d_model/ylff/utils/dynamic_batch.py
"""
Dynamic batch sizing utilities.
Automatically adjusts batch size to maximize GPU utilization while avoiding OOM errors.
"""
import logging
import torch
from torch.utils.data import DataLoader
logger = logging.getLogger(__name__)
class DynamicBatchSampler:
"""
Dynamic batch sampler that adjusts batch size based on GPU memory.
    Starts with a small batch size and gradually increases it after a run of
    successful batches; decreases it when an out-of-memory (OOM) error occurs.
"""
def __init__(
self,
dataset,
initial_batch_size: int = 1,
max_batch_size: int = 8,
min_batch_size: int = 1,
increase_factor: float = 2.0,
decrease_factor: float = 0.5,
patience: int = 5,
):
"""
Args:
dataset: Dataset to sample from
initial_batch_size: Starting batch size
max_batch_size: Maximum batch size
min_batch_size: Minimum batch size
increase_factor: Factor to increase batch size
decrease_factor: Factor to decrease batch size on OOM
patience: Number of successful batches before increasing
"""
self.dataset = dataset
self.current_batch_size = initial_batch_size
self.max_batch_size = max_batch_size
self.min_batch_size = min_batch_size
self.increase_factor = increase_factor
self.decrease_factor = decrease_factor
self.patience = patience
        self.successful_batches = 0  # successes since the last batch-size change
        self.total_successes = 0  # successes over the sampler's lifetime
        self.total_batches = 0
logger.info(
f"DynamicBatchSampler initialized: "
f"initial={initial_batch_size}, "
f"max={max_batch_size}, "
f"min={min_batch_size}"
)
def get_batch_size(self) -> int:
"""Get current batch size."""
return self.current_batch_size
    def on_success(self):
        """Called after a batch was processed successfully."""
        self.successful_batches += 1
        self.total_successes += 1
        self.total_batches += 1
        # Increase the batch size once enough consecutive successes accumulate,
        # capping at max_batch_size (mirroring the min_batch_size clamp in on_oom)
        if self.successful_batches >= self.patience:
            new_batch_size = int(self.current_batch_size * self.increase_factor)
            new_batch_size = min(new_batch_size, self.max_batch_size)
            if new_batch_size > self.current_batch_size:
                old_size = self.current_batch_size
                self.current_batch_size = new_batch_size
                logger.info(f"Batch size increased: {old_size} -> {self.current_batch_size}")
            # Reset the streak so another `patience` successes are required before retrying
            self.successful_batches = 0
def on_oom(self):
"""Called when OOM error occurs."""
new_batch_size = int(self.current_batch_size * self.decrease_factor)
new_batch_size = max(new_batch_size, self.min_batch_size)
if new_batch_size < self.current_batch_size:
old_size = self.current_batch_size
self.current_batch_size = new_batch_size
self.successful_batches = 0
logger.warning(
f"OOM detected, batch size decreased: {old_size} -> {self.current_batch_size}"
)
        # Release cached CUDA blocks so the smaller batch has a better chance of fitting
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get_stats(self) -> dict:
"""Get sampler statistics."""
return {
"current_batch_size": self.current_batch_size,
"successful_batches": self.successful_batches,
"total_batches": self.total_batches,
"success_rate": self.successful_batches / max(self.total_batches, 1),
}
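# A minimal sketch (comments only) of how a training loop is expected to drive
# the sampler's on_success/on_oom protocol directly, without AdaptiveDataLoader.
# `model`, `train_step`, `make_loader`, and `num_epochs` are hypothetical
# placeholders, not part of this module:
#
#     sampler = DynamicBatchSampler(dataset, initial_batch_size=1, max_batch_size=8)
#     for epoch in range(num_epochs):
#         loader = make_loader(dataset, batch_size=sampler.get_batch_size())
#         for batch in loader:
#             try:
#                 train_step(model, batch)
#                 sampler.on_success()
#             except RuntimeError as err:
#                 if "out of memory" not in str(err):
#                     raise
#                 sampler.on_oom()  # shrinks the batch size and clears the CUDA cache
#                 break             # rebuild the loader at the smaller size next epoch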
class AdaptiveDataLoader:
"""
DataLoader wrapper with dynamic batch sizing.
Automatically adjusts batch size during training.
"""
def __init__(
self,
dataset,
initial_batch_size: int = 1,
max_batch_size: int = 8,
**dataloader_kwargs,
):
"""
Args:
dataset: Dataset
initial_batch_size: Starting batch size
max_batch_size: Maximum batch size
**dataloader_kwargs: Additional DataLoader arguments
"""
self.dataset = dataset
self.initial_batch_size = initial_batch_size
self.max_batch_size = max_batch_size
self.dataloader_kwargs = dataloader_kwargs
self.sampler = DynamicBatchSampler(
dataset,
initial_batch_size=initial_batch_size,
max_batch_size=max_batch_size,
)
self.dataloader = None
self._create_dataloader()
    def _create_dataloader(self):
        """(Re)create the DataLoader using the sampler's current batch size."""
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.sampler.get_batch_size(),
            **self.dataloader_kwargs,
        )
    def __iter__(self):
        """Iterate over the dataloader, shrinking the batch size on CUDA OOM."""
        # Recreate the dataloader so each epoch picks up the sampler's current
        # batch size (it may have been increased by on_success since the last epoch)
        self._create_dataloader()
        iterator = iter(self.dataloader)
        while True:
            try:
                batch = next(iterator)
                yield batch
                self.sampler.on_success()
            except StopIteration:
                break
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self.sampler.on_oom()
                    # Recreate the dataloader with the reduced batch size. Note that
                    # this restarts iteration from the beginning of the dataset, so
                    # samples yielded before the OOM may be seen again this epoch.
                    self._create_dataloader()
                    iterator = iter(self.dataloader)
                else:
                    raise
def __len__(self):
"""Length of dataloader."""
return len(self.dataloader)
def get_batch_size(self) -> int:
"""Get current batch size."""
return self.sampler.get_batch_size()
def get_stats(self) -> dict:
"""Get statistics."""
return self.sampler.get_stats()
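# Minimal self-contained smoke test, a sketch rather than part of the deployed
# API: it wraps a tiny random TensorDataset in AdaptiveDataLoader and walks one
# epoch, then prints the batch size and sampler stats. The tensor shapes and
# loader arguments below are arbitrary placeholders chosen only for illustration.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    logging.basicConfig(level=logging.INFO)
    dummy = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
    loader = AdaptiveDataLoader(dummy, initial_batch_size=2, max_batch_size=8, shuffle=True)
    for inputs, targets in loader:
        # Real training would run the model forward/backward here; an OOM raised
        # from that step is not visible to __iter__ and would have to be reported
        # via loader.sampler.on_oom() by the caller.
        _ = inputs.shape
    print("current batch size:", loader.get_batch_size())
    print("stats:", loader.get_stats())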