|
|
""" |
|
|
Dynamic batch sizing utilities. |
|
|
|
|
|
Automatically adjusts batch size to maximize GPU utilization while avoiding OOM errors. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import torch |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class DynamicBatchSampler:
    """
    Dynamic batch sampler that adjusts batch size based on GPU memory.

    Starts with a small batch size and gradually increases it after a run
    of successful batches; decreases it when an OOM is reported.
    """

    def __init__(
        self,
        dataset,
        initial_batch_size: int = 1,
        max_batch_size: int = 8,
        min_batch_size: int = 1,
        increase_factor: float = 2.0,
        decrease_factor: float = 0.5,
        patience: int = 5,
    ):
        """
        Args:
            dataset: Dataset to sample from
            initial_batch_size: Starting batch size
            max_batch_size: Maximum batch size
            min_batch_size: Minimum batch size
            increase_factor: Factor to increase batch size
            decrease_factor: Factor to decrease batch size on OOM
            patience: Number of successful batches before increasing
        """
        self.dataset = dataset
        self.current_batch_size = initial_batch_size
        self.max_batch_size = max_batch_size
        self.min_batch_size = min_batch_size
        self.increase_factor = increase_factor
        self.decrease_factor = decrease_factor
        self.patience = patience

        # Streak of successes since the last size change (drives increases).
        self.successful_batches = 0
        # Total batch attempts, including ones that hit OOM.
        self.total_batches = 0
        # Cumulative successes over the sampler's lifetime (drives stats);
        # kept separate because the streak counter above is reset on resize.
        self._lifetime_successes = 0

        logger.info(
            f"DynamicBatchSampler initialized: "
            f"initial={initial_batch_size}, "
            f"max={max_batch_size}, "
            f"min={min_batch_size}"
        )

    def get_batch_size(self) -> int:
        """Get current batch size."""
        return self.current_batch_size

    def on_success(self):
        """Called after successful batch processing.

        After `patience` consecutive successes, grows the batch size by
        `increase_factor`, clamped to `max_batch_size`.
        """
        self.successful_batches += 1
        self._lifetime_successes += 1
        self.total_batches += 1

        if self.successful_batches >= self.patience:
            # Restart the streak whether or not we can grow, so we don't
            # re-evaluate the increase on every subsequent call at the cap.
            self.successful_batches = 0
            # Clamp to the cap: without the clamp, a size whose scaled value
            # overshoots max_batch_size (e.g. 5 * 2 > 8) would never grow at
            # all, even though headroom remains.
            new_batch_size = min(
                int(self.current_batch_size * self.increase_factor),
                self.max_batch_size,
            )
            if new_batch_size > self.current_batch_size:
                old_size = self.current_batch_size
                self.current_batch_size = new_batch_size
                logger.info(f"Batch size increased: {old_size} -> {self.current_batch_size}")

    def on_oom(self):
        """Called when OOM error occurs.

        Shrinks the batch size by `decrease_factor` (floored at
        `min_batch_size`) and releases cached CUDA memory.
        """
        # The failed attempt still counts toward the total so that
        # success_rate reflects it.
        self.total_batches += 1

        new_batch_size = int(self.current_batch_size * self.decrease_factor)
        new_batch_size = max(new_batch_size, self.min_batch_size)

        if new_batch_size < self.current_batch_size:
            old_size = self.current_batch_size
            self.current_batch_size = new_batch_size
            self.successful_batches = 0
            logger.warning(
                f"OOM detected, batch size decreased: {old_size} -> {self.current_batch_size}"
            )

        # Return cached blocks to the driver so the retry has headroom.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_stats(self) -> dict:
        """Get sampler statistics.

        Returns:
            dict with current batch size, cumulative successful batches,
            total attempted batches, and the lifetime success rate.
        """
        return {
            "current_batch_size": self.current_batch_size,
            # Report the lifetime counter, not the resettable streak:
            # the streak is zeroed on every resize and would make the
            # rate below meaningless.
            "successful_batches": self._lifetime_successes,
            "total_batches": self.total_batches,
            "success_rate": self._lifetime_successes / max(self.total_batches, 1),
        }
|
|
|
|
|
|
|
|
class AdaptiveDataLoader:
    """
    DataLoader wrapper with dynamic batch sizing.

    Automatically adjusts batch size during training: grows it via the
    underlying DynamicBatchSampler after sustained success, and rebuilds
    the DataLoader with a smaller batch size when an OOM is raised while
    fetching a batch.
    """

    def __init__(
        self,
        dataset,
        initial_batch_size: int = 1,
        max_batch_size: int = 8,
        **dataloader_kwargs,
    ):
        """
        Args:
            dataset: Dataset
            initial_batch_size: Starting batch size
            max_batch_size: Maximum batch size
            **dataloader_kwargs: Additional DataLoader arguments
        """
        self.dataset = dataset
        self.initial_batch_size = initial_batch_size
        self.max_batch_size = max_batch_size
        self.dataloader_kwargs = dataloader_kwargs

        self.sampler = DynamicBatchSampler(
            dataset,
            initial_batch_size=initial_batch_size,
            max_batch_size=max_batch_size,
        )

        self.dataloader = None
        self._create_dataloader()

    def _create_dataloader(self):
        """(Re)create the DataLoader with the sampler's current batch size."""
        # Local import keeps module import cheap and mirrors original style.
        from torch.utils.data import DataLoader

        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.sampler.get_batch_size(),
            **self.dataloader_kwargs,
        )

    def __iter__(self):
        """Iterate over the dataloader, shrinking the batch size on OOM.

        Yields batches until the underlying iterator is exhausted. When a
        CUDA OOM RuntimeError is raised while fetching a batch, the batch
        size is reduced and the DataLoader is rebuilt, and iteration
        continues through the same handler — so repeated OOMs keep
        shrinking the batch size instead of escaping (the original retry
        path only guarded its second attempt against StopIteration).

        Raises:
            RuntimeError: re-raised for non-OOM errors, or when an OOM
                occurs while the batch size is already at its minimum and
                cannot shrink further (prevents an infinite retry loop).

        NOTE(review): rebuilding the DataLoader restarts iteration from the
        beginning of the dataset, so batches yielded before the OOM may be
        yielded again within the same epoch — pre-existing behavior of
        this design; confirm acceptable for the training loop.
        """
        iterator = iter(self.dataloader)

        while True:
            try:
                batch = next(iterator)
            except StopIteration:
                break
            except RuntimeError as e:
                if "out of memory" not in str(e):
                    raise
                if self.sampler.get_batch_size() <= self.sampler.min_batch_size:
                    # Already as small as allowed; shrinking again is
                    # impossible, so surface the OOM to the caller.
                    raise
                self.sampler.on_oom()
                # Rebuild with the reduced batch size and loop back so a
                # repeat OOM goes through this same handler.
                self._create_dataloader()
                iterator = iter(self.dataloader)
                continue

            yield batch
            self.sampler.on_success()

    def __len__(self):
        """Number of batches in the current underlying dataloader."""
        return len(self.dataloader)

    def get_batch_size(self) -> int:
        """Get current batch size."""
        return self.sampler.get_batch_size()

    def get_stats(self) -> dict:
        """Get statistics from the underlying sampler."""
        return self.sampler.get_stats()
|
|
|