"""
Advanced data loading optimizations.

Features:
- Prefetching with multiple workers
- Memory-mapped datasets
- Smart batching strategies
- Data pipeline profiling
"""

import logging
import time
from typing import Dict, Optional

from torch.utils.data import DataLoader, Dataset

logger = logging.getLogger(__name__)


class PrefetchDataLoader:
    """
    DataLoader with advanced prefetching and caching.

    Wraps a standard DataLoader with additional optimizations:
    - Multiple prefetch buffers
    - Automatic batch size tuning
    - Memory usage monitoring

    NOTE(review): ``prefetch_factor``, ``pin_memory`` and ``non_blocking``
    are currently stored but not acted on — iteration simply delegates to
    the wrapped DataLoader. They are kept so the constructor signature is
    stable once real prefetching is implemented.
    """

    def __init__(
        self,
        dataloader: DataLoader,
        prefetch_factor: int = 4,
        pin_memory: bool = True,
        non_blocking: bool = True,
    ):
        """
        Initialize prefetch DataLoader.

        Args:
            dataloader: Base DataLoader to wrap
            prefetch_factor: Number of batches to prefetch
            pin_memory: Pin memory for faster GPU transfer
            non_blocking: Use non-blocking transfers
        """
        self.dataloader = dataloader
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.non_blocking = non_blocking

    def __iter__(self):
        """Iterate over the wrapped DataLoader (no extra prefetching yet)."""
        return iter(self.dataloader)

    def __len__(self):
        """Return the number of batches in the wrapped DataLoader."""
        return len(self.dataloader)


def optimize_dataloader(
    dataset: Dataset,
    batch_size: int = 1,
    num_workers: Optional[int] = None,
    pin_memory: bool = True,
    persistent_workers: bool = True,
    prefetch_factor: int = 4,
    shuffle: bool = True,
    device: str = "cuda",
) -> DataLoader:
    """
    Create optimized DataLoader with best practices.

    Args:
        dataset: Dataset to load
        batch_size: Batch size
        num_workers: Number of worker processes (None = auto)
        pin_memory: Pin memory for faster GPU transfer
        persistent_workers: Keep workers alive between epochs
        prefetch_factor: Number of batches to prefetch per worker
        shuffle: Shuffle dataset
        device: Target device

    Returns:
        Optimized DataLoader
    """
    import os

    # Auto-detect optimal number of workers: use 2-4 workers,
    # but never more than half the available CPUs.
    if num_workers is None:
        cpu_count = os.cpu_count() or 1
        num_workers = min(4, max(2, cpu_count // 2))

    # Large batches already amortize loading cost, so prefetch fewer of them
    # to keep host memory usage down.
    if batch_size > 4:
        prefetch_factor = max(2, prefetch_factor // 2)

    # Pinning host memory only helps for CUDA transfers; force it off on CPU.
    effective_pin_memory = pin_memory and device == "cuda"

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=effective_pin_memory,
        # persistent_workers / prefetch_factor are only valid with workers.
        persistent_workers=persistent_workers if num_workers > 0 else False,
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
        drop_last=False,
    )

    # Log the EFFECTIVE pin_memory value actually passed to DataLoader,
    # not the caller's argument (they differ when device != "cuda").
    logger.info(
        f"Created optimized DataLoader: "
        f"batch_size={batch_size}, "
        f"num_workers={num_workers}, "
        f"prefetch_factor={prefetch_factor}, "
        f"pin_memory={effective_pin_memory}"
    )

    return dataloader


def profile_dataloader(
    dataloader: DataLoader,
    num_batches: int = 10,
    device: str = "cuda",
) -> Dict[str, float]:
    """
    Profile DataLoader performance.

    Args:
        dataloader: DataLoader to profile
        num_batches: Number of batches to profile
        device: Target device

    Returns:
        Dict with profiling results

    Raises:
        ValueError: If the DataLoader yields no batches to profile.
    """
    logger.info(f"Profiling DataLoader ({num_batches} batches)...")

    batch_times = []
    data_times = []
    transfer_times = []

    start_time = time.time()
    # Timestamp of the end of the previous iteration; the gap between this
    # and the moment the loop body starts is time spent inside the
    # DataLoader fetching the next batch.
    prev_end = start_time

    for i, batch in enumerate(dataloader):
        if i >= num_batches:
            break

        iter_start = prev_end
        fetch_end = time.time()
        # Fetching happens in the `for` statement, before the body runs.
        data_time = fetch_end - iter_start

        # Measure host-to-device transfer time.
        if device == "cuda":
            transfer_start = time.time()
            if isinstance(batch, dict):
                batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            elif isinstance(batch, (list, tuple)):
                batch = [x.to(device, non_blocking=True) for x in batch]
            else:
                batch = batch.to(device, non_blocking=True)
            # NOTE(review): with non_blocking=True this times the transfer
            # *launch*, not its completion — add torch.cuda.synchronize()
            # here if exact transfer timing is needed.
            transfer_time = time.time() - transfer_start
        else:
            transfer_time = 0.0

        prev_end = time.time()
        # Full per-batch wall time: data loading + transfer + loop overhead.
        batch_times.append(prev_end - iter_start)
        data_times.append(data_time)
        transfer_times.append(transfer_time)

    total_time = time.time() - start_time

    # Guard against an empty/exhausted DataLoader — the averages below would
    # otherwise raise ZeroDivisionError.
    if not batch_times:
        raise ValueError("DataLoader yielded no batches to profile")

    # Coarse clocks (e.g. Windows) can report total_time == 0.0 for tiny runs.
    safe_total = max(total_time, 1e-9)

    results = {
        "total_time": total_time,
        "avg_batch_time": sum(batch_times) / len(batch_times),
        "avg_data_time": sum(data_times) / len(data_times),
        "avg_transfer_time": sum(transfer_times) / len(transfer_times),
        "batches_per_sec": len(batch_times) / safe_total,
        "data_loading_ratio": sum(data_times) / safe_total,
        "transfer_ratio": sum(transfer_times) / safe_total,
    }

    logger.info("DataLoader Profile Results:")
    logger.info(f"  Total time: {total_time:.2f}s")
    logger.info(f"  Avg batch time: {results['avg_batch_time'] * 1000:.2f}ms")
    logger.info(f"  Avg data loading: {results['avg_data_time'] * 1000:.2f}ms")
    logger.info(f"  Avg transfer: {results['avg_transfer_time'] * 1000:.2f}ms")
    logger.info(f"  Batches/sec: {results['batches_per_sec']:.2f}")
    logger.info(f"  Data loading ratio: {results['data_loading_ratio'] * 100:.1f}%")
    logger.info(f"  Transfer ratio: {results['transfer_ratio'] * 100:.1f}%")

    return results


def find_optimal_num_workers(
    dataset: Dataset,
    batch_size: int = 1,
    max_workers: int = 8,
    num_test_batches: int = 10,
    device: str = "cuda",
) -> int:
    """
    Find optimal number of workers for DataLoader.

    Args:
        dataset: Dataset to test
        batch_size: Batch size
        max_workers: Maximum workers to test
        num_test_batches: Number of batches to test per configuration
        device: Target device

    Returns:
        Optimal number of workers
    """
    logger.info(f"Finding optimal number of workers (max={max_workers})...")

    best_workers = 0
    best_time = float("inf")

    for num_workers in range(0, max_workers + 1):
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=device == "cuda",
            prefetch_factor=2 if num_workers > 0 else None,
        )

        # Profile: includes worker startup cost, which is intentional since
        # that cost is paid in real training too.
        start_time = time.time()
        for i, _ in enumerate(dataloader):
            if i >= num_test_batches:
                break
        elapsed = time.time() - start_time

        # Drop the loader (and its worker processes) before the next trial
        # so worker teardown does not bleed into the next measurement.
        del dataloader

        logger.info(f"  {num_workers} workers: {elapsed:.2f}s")

        if elapsed < best_time:
            best_time = elapsed
            best_workers = num_workers

    logger.info(f"Optimal number of workers: {best_workers} ({best_time:.2f}s)")

    return best_workers