|
|
""" |
|
|
Advanced data loading optimizations. |
|
|
|
|
|
Features: |
|
|
- Prefetching with multiple workers |
|
|
- Memory-mapped datasets |
|
|
- Smart batching strategies |
|
|
- Data pipeline profiling |
|
|
""" |
|
|
|
|
|
import logging
import time
from itertools import islice
from typing import Dict, Optional

from torch.utils.data import DataLoader, Dataset
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class PrefetchDataLoader:
    """
    DataLoader wrapper that prefetches batches on a background thread.

    Wraps a standard DataLoader and keeps up to ``prefetch_factor`` batches
    buffered ahead of the consumer, so data loading overlaps with whatever
    work consumes the previous batch. Batch order and contents are identical
    to iterating the wrapped loader directly.

    NOTE: ``pin_memory`` and ``non_blocking`` are stored for callers that
    perform the host-to-device copy themselves; this wrapper does not move
    batches to any device.
    """

    def __init__(
        self,
        dataloader: DataLoader,
        prefetch_factor: int = 4,
        pin_memory: bool = True,
        non_blocking: bool = True,
    ):
        """
        Initialize prefetch DataLoader.

        Args:
            dataloader: Base DataLoader to wrap
            prefetch_factor: Number of batches to buffer ahead (must be >= 1)
            pin_memory: Pin memory for faster GPU transfer (advisory; see class note)
            non_blocking: Use non-blocking transfers (advisory; see class note)

        Raises:
            ValueError: If ``prefetch_factor`` is less than 1 (a queue with
                maxsize <= 0 would be unbounded and buffer the whole epoch).
        """
        if prefetch_factor < 1:
            raise ValueError(f"prefetch_factor must be >= 1, got {prefetch_factor}")
        self.dataloader = dataloader
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.non_blocking = non_blocking

    def __iter__(self):
        """Yield batches from the wrapped loader, prefetched on a background thread."""
        import queue
        import threading

        # Bounded queue: the producer blocks once prefetch_factor batches
        # are buffered, capping memory use.
        buffer = queue.Queue(maxsize=self.prefetch_factor)

        def _produce():
            # Fill the buffer; tag each item so batches, errors, and
            # end-of-stream can share one queue.
            try:
                for batch in self.dataloader:
                    buffer.put(("batch", batch))
            except BaseException as exc:
                # Forward the exception so it re-raises in the consumer thread.
                buffer.put(("error", exc))
            else:
                buffer.put(("done", None))

        # Daemon thread: if the consumer abandons iteration early, the
        # blocked producer won't prevent interpreter exit.
        worker = threading.Thread(target=_produce, daemon=True)
        worker.start()

        while True:
            tag, payload = buffer.get()
            if tag == "batch":
                yield payload
            elif tag == "error":
                raise payload
            else:
                break

    def __len__(self):
        """Return length of underlying DataLoader."""
        return len(self.dataloader)
|
|
|
|
|
|
|
|
def optimize_dataloader(
    dataset: Dataset,
    batch_size: int = 1,
    num_workers: Optional[int] = None,
    pin_memory: bool = True,
    persistent_workers: bool = True,
    prefetch_factor: int = 4,
    shuffle: bool = True,
    device: str = "cuda",
) -> DataLoader:
    """
    Create optimized DataLoader with best practices.

    Args:
        dataset: Dataset to load
        batch_size: Batch size
        num_workers: Number of worker processes (None = auto-select from CPU count)
        pin_memory: Pin memory for faster GPU transfer (only applied when
            device == "cuda"; pinning is pointless for CPU-only loading)
        persistent_workers: Keep workers alive between epochs (forced off
            when num_workers == 0, as PyTorch requires)
        prefetch_factor: Number of batches to prefetch per worker (halved,
            with a floor of 2, for batch_size > 4 to bound buffered memory)
        shuffle: Shuffle dataset
        device: Target device

    Returns:
        Optimized DataLoader
    """
    import os

    if num_workers is None:
        cpu_count = os.cpu_count() or 1
        # Leave cores free for the training process itself; clamp to [2, 4].
        num_workers = min(4, max(2, cpu_count // 2))

    if batch_size > 4:
        # Larger batches mean more memory per prefetched batch, so buffer fewer.
        prefetch_factor = max(2, prefetch_factor // 2)

    # Resolve the settings PyTorch will actually use, so construction and
    # logging agree (the raw arguments may be overridden above/below).
    effective_pin_memory = pin_memory and device == "cuda"
    effective_persistent = persistent_workers if num_workers > 0 else False
    effective_prefetch = prefetch_factor if num_workers > 0 else None

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=effective_pin_memory,
        persistent_workers=effective_persistent,
        prefetch_factor=effective_prefetch,
        drop_last=False,
    )

    # Log the *effective* settings (not the requested ones) so the log
    # matches the loader that was actually constructed.
    logger.info(
        "Created optimized DataLoader: batch_size=%d, num_workers=%d, "
        "prefetch_factor=%s, pin_memory=%s",
        batch_size,
        num_workers,
        effective_prefetch,
        effective_pin_memory,
    )

    return dataloader
|
|
|
|
|
|
|
|
def profile_dataloader(
    dataloader: DataLoader,
    num_batches: int = 10,
    device: str = "cuda",
) -> Dict[str, float]:
    """
    Profile DataLoader performance.

    Measures, per batch: time spent waiting on the loader ("data time"),
    host-to-device transfer time (when device == "cuda"), and total
    per-batch processing time.

    Args:
        dataloader: DataLoader to profile
        num_batches: Number of batches to profile
        device: Target device

    Returns:
        Dict with profiling results (times in seconds, ratios in [0, 1])

    Raises:
        ValueError: If the DataLoader yields no batches.
    """
    logger.info(f"Profiling DataLoader ({num_batches} batches)...")

    times = []  # (batch processing time, wall-clock timestamp when batch finished)
    data_times = []
    transfer_times = []

    start_time = time.time()

    # islice consumes exactly num_batches batches; an enumerate-then-break
    # loop would fetch one batch past the limit and skew total_time.
    for batch in islice(dataloader, num_batches):
        batch_start = time.time()

        # Time elapsed since the previous batch finished == time spent
        # waiting on the loader for this batch.
        data_time = batch_start - (times[-1][1] if times else start_time)

        if device == "cuda":
            transfer_start = time.time()
            # NOTE(review): assumes batch elements are tensors (have .to) — confirm.
            if isinstance(batch, dict):
                batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            elif isinstance(batch, (list, tuple)):
                batch = [x.to(device, non_blocking=True) for x in batch]
            else:
                batch = batch.to(device, non_blocking=True)
            transfer_time = time.time() - transfer_start
        else:
            transfer_time = 0.0

        batch_time = time.time() - batch_start

        times.append((batch_time, time.time()))
        data_times.append(data_time)
        transfer_times.append(transfer_time)

    if not times:
        # Guard the averages below against ZeroDivisionError on an empty loader.
        raise ValueError("DataLoader yielded no batches; nothing to profile")

    total_time = time.time() - start_time

    results = {
        "total_time": total_time,
        "avg_batch_time": sum(t[0] for t in times) / len(times),
        "avg_data_time": sum(data_times) / len(data_times),
        "avg_transfer_time": sum(transfer_times) / len(transfer_times),
        "batches_per_sec": len(times) / total_time,
        "data_loading_ratio": sum(data_times) / total_time,
        "transfer_ratio": sum(transfer_times) / total_time,
    }

    logger.info("DataLoader Profile Results:")
    logger.info(f"  Total time: {total_time:.2f}s")
    logger.info(f"  Avg batch time: {results['avg_batch_time'] * 1000:.2f}ms")
    logger.info(f"  Avg data loading: {results['avg_data_time'] * 1000:.2f}ms")
    logger.info(f"  Avg transfer: {results['avg_transfer_time'] * 1000:.2f}ms")
    logger.info(f"  Batches/sec: {results['batches_per_sec']:.2f}")
    logger.info(f"  Data loading ratio: {results['data_loading_ratio'] * 100:.1f}%")
    logger.info(f"  Transfer ratio: {results['transfer_ratio'] * 100:.1f}%")

    return results
|
|
|
|
|
|
|
|
def find_optimal_num_workers(
    dataset: Dataset,
    batch_size: int = 1,
    max_workers: int = 8,
    num_test_batches: int = 10,
    device: str = "cuda",
) -> int:
    """
    Find optimal number of workers for DataLoader.

    Times iteration over num_test_batches batches for every worker count in
    [0, max_workers] and returns the fastest. Worker startup cost is included
    in each measurement, which penalizes high worker counts on short tests —
    use a num_test_batches large enough to amortize startup.

    Args:
        dataset: Dataset to test
        batch_size: Batch size
        max_workers: Maximum workers to test
        num_test_batches: Number of batches to test per configuration
        device: Target device (pin_memory is enabled only for "cuda")

    Returns:
        Optimal number of workers
    """
    logger.info(f"Finding optimal number of workers (max={max_workers})...")

    best_workers = 0
    best_time = float("inf")

    for num_workers in range(0, max_workers + 1):
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=device == "cuda",
            # prefetch_factor must be None when there are no workers.
            prefetch_factor=2 if num_workers > 0 else None,
        )

        start_time = time.time()
        # islice consumes exactly num_test_batches batches; an
        # enumerate-then-break loop would fetch one batch past the limit
        # and inflate the measurement. Consistent with profile_dataloader.
        for _ in islice(dataloader, num_test_batches):
            pass
        elapsed = time.time() - start_time

        logger.info(f"  {num_workers} workers: {elapsed:.2f}s")

        if elapsed < best_time:
            best_time = elapsed
            best_workers = num_workers

    logger.info(f"Optimal number of workers: {best_workers} ({best_time:.2f}s)")

    return best_workers
|
|
|