File size: 5,950 Bytes
7a87926 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
"""
Dynamic batch sizing utilities.
Automatically adjusts batch size to maximize GPU utilization while avoiding OOM errors.
"""
import logging
import torch
logger = logging.getLogger(__name__)
class DynamicBatchSampler:
    """
    Dynamic batch sampler that adjusts batch size based on GPU memory.

    Starts with a small batch size and grows it after `patience` consecutive
    successful batches; shrinks it (and clears the CUDA cache) when an OOM
    is reported.
    """

    def __init__(
        self,
        dataset,
        initial_batch_size: int = 1,
        max_batch_size: int = 8,
        min_batch_size: int = 1,
        increase_factor: float = 2.0,
        decrease_factor: float = 0.5,
        patience: int = 5,
    ):
        """
        Args:
            dataset: Dataset to sample from
            initial_batch_size: Starting batch size
            max_batch_size: Maximum batch size
            min_batch_size: Minimum batch size
            increase_factor: Factor to increase batch size
            decrease_factor: Factor to decrease batch size on OOM
            patience: Number of consecutive successful batches before increasing
        """
        self.dataset = dataset
        self.current_batch_size = initial_batch_size
        self.max_batch_size = max_batch_size
        self.min_batch_size = min_batch_size
        self.increase_factor = increase_factor
        self.decrease_factor = decrease_factor
        self.patience = patience
        # Consecutive successes since the last batch-size change (the
        # "patience" streak); reset whenever the size changes.
        self.successful_batches = 0
        # Every batch seen, successes and OOMs alike.
        self.total_batches = 0
        # Cumulative successes, never reset — used for the success rate.
        # (Bug fix: success_rate previously used the streak counter, which
        # is zeroed on every resize, making the reported rate meaningless.)
        self.total_successes = 0
        logger.info(
            "DynamicBatchSampler initialized: initial=%d, max=%d, min=%d",
            initial_batch_size,
            max_batch_size,
            min_batch_size,
        )

    def get_batch_size(self) -> int:
        """Get current batch size."""
        return self.current_batch_size

    def on_success(self):
        """Record a successful batch; grow the batch size after `patience` wins."""
        self.successful_batches += 1
        self.total_successes += 1
        self.total_batches += 1
        if self.successful_batches >= self.patience:
            # Clamp to max_batch_size so an overshooting factor can still
            # reach the cap (bug fix: the old code refused to grow at all
            # when current * factor exceeded max, e.g. 8 * 2 > 10 left the
            # size stuck at 8 forever with max=10).
            new_batch_size = min(
                int(self.current_batch_size * self.increase_factor),
                self.max_batch_size,
            )
            if new_batch_size > self.current_batch_size:
                old_size = self.current_batch_size
                self.current_batch_size = new_batch_size
                logger.info(
                    "Batch size increased: %d -> %d", old_size, new_batch_size
                )
            # Reset the streak even when already at max, so the counter does
            # not grow unbounded and the check does not re-fire every call.
            self.successful_batches = 0

    def on_oom(self):
        """Record an OOM: shrink the batch size and clear the CUDA cache."""
        # Bug fix: OOM batches now count toward total_batches as well,
        # otherwise the success rate overstated reality.
        self.total_batches += 1
        new_batch_size = max(
            int(self.current_batch_size * self.decrease_factor),
            self.min_batch_size,
        )
        if new_batch_size < self.current_batch_size:
            old_size = self.current_batch_size
            self.current_batch_size = new_batch_size
            self.successful_batches = 0
            logger.warning(
                "OOM detected, batch size decreased: %d -> %d",
                old_size,
                new_batch_size,
            )
        # Release cached allocator blocks so the retry has the best chance.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_stats(self) -> dict:
        """Get sampler statistics.

        Returns:
            dict with the current batch size, the current success streak,
            the total batches seen, and the overall success rate
            (cumulative successes / total batches; 0-division guarded).
        """
        return {
            "current_batch_size": self.current_batch_size,
            "successful_batches": self.successful_batches,
            "total_batches": self.total_batches,
            "success_rate": self.total_successes / max(self.total_batches, 1),
        }
class AdaptiveDataLoader:
    """
    DataLoader wrapper with dynamic batch sizing.

    Wraps a `torch.utils.data.DataLoader`; when a CUDA out-of-memory error
    surfaces while fetching a batch, asks its `DynamicBatchSampler` for a
    smaller batch size, rebuilds the loader, and retries.
    """

    def __init__(
        self,
        dataset,
        initial_batch_size: int = 1,
        max_batch_size: int = 8,
        **dataloader_kwargs,
    ):
        """
        Args:
            dataset: Dataset
            initial_batch_size: Starting batch size
            max_batch_size: Maximum batch size
            **dataloader_kwargs: Additional DataLoader arguments
        """
        self.dataset = dataset
        self.initial_batch_size = initial_batch_size
        self.max_batch_size = max_batch_size
        self.dataloader_kwargs = dataloader_kwargs
        self.sampler = DynamicBatchSampler(
            dataset,
            initial_batch_size=initial_batch_size,
            max_batch_size=max_batch_size,
        )
        self.dataloader = None
        self._create_dataloader()

    def _create_dataloader(self):
        """(Re)build the underlying DataLoader at the sampler's current batch size."""
        from torch.utils.data import DataLoader

        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.sampler.get_batch_size(),
            **self.dataloader_kwargs,
        )

    def __iter__(self):
        """Yield batches, shrinking the batch size and retrying on CUDA OOM.

        Bug fix: the original code retried exactly once after an OOM, with a
        copy-pasted inner block whose `try` caught only StopIteration — a
        second OOM at the reduced size escaped uncaught even though the
        batch size could shrink further.  This version keeps shrinking and
        retrying until the sampler can no longer reduce the size, and only
        then re-raises.

        NOTE(review): rebuilding the DataLoader restarts iteration from the
        beginning of the dataset, so batches yielded before the OOM are
        repeated within the same epoch — confirm this is acceptable for the
        training loop using this wrapper.
        """
        iterator = iter(self.dataloader)
        while True:
            try:
                batch = next(iterator)
            except StopIteration:
                return
            except RuntimeError as e:
                # Only handle CUDA OOMs; any other RuntimeError propagates.
                if "out of memory" not in str(e):
                    raise
                old_size = self.sampler.get_batch_size()
                self.sampler.on_oom()  # shrinks batch size, clears CUDA cache
                if self.sampler.get_batch_size() >= old_size:
                    # Already at the minimum batch size; nothing left to try.
                    raise
                # Rebuild the loader at the smaller batch size and retry.
                self._create_dataloader()
                iterator = iter(self.dataloader)
                continue
            yield batch
            # Count the success only after the consumer resumes us, i.e.
            # after it finished processing the batch without reporting OOM.
            self.sampler.on_success()

    def __len__(self):
        """Number of batches in the loader at the current batch size."""
        return len(self.dataloader)

    def get_batch_size(self) -> int:
        """Current batch size as tracked by the sampler."""
        return self.sampler.get_batch_size()

    def get_stats(self) -> dict:
        """Sampler statistics (see DynamicBatchSampler.get_stats)."""
        return self.sampler.get_stats()
|