# 3d_model/ylff/utils/data_loading_utils.py
"""
Advanced data loading optimizations.
Features:
- Prefetching with multiple workers
- Memory-mapped datasets
- Smart batching strategies
- Data pipeline profiling
"""
import logging
import os
import queue
import threading
import time
from typing import Dict, Optional

import torch
from torch.utils.data import DataLoader, Dataset

logger = logging.getLogger(__name__)
class PrefetchDataLoader:
    """
    DataLoader wrapper that prefetches batches on a background thread.

    Batches are pulled from the wrapped DataLoader into a bounded queue so
    the consumer rarely waits on I/O. The ``pin_memory`` and ``non_blocking``
    flags are stored for callers to use when moving batches to the device.
    """

    _SENTINEL = object()

    def __init__(
        self,
        dataloader: DataLoader,
        prefetch_factor: int = 4,
        pin_memory: bool = True,
        non_blocking: bool = True,
    ):
        """
        Initialize prefetch DataLoader.

        Args:
            dataloader: Base DataLoader to wrap
            prefetch_factor: Number of batches to buffer ahead of the consumer
            pin_memory: Pin memory for faster GPU transfer
            non_blocking: Use non-blocking transfers
        """
        self.dataloader = dataloader
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.non_blocking = non_blocking

    def __iter__(self):
        """Iterate over the wrapped DataLoader, prefetching in the background."""
        buffer: queue.Queue = queue.Queue(maxsize=self.prefetch_factor)

        def _producer():
            # Fill the buffer ahead of the consumer; the sentinel marks
            # exhaustion, and loader errors are forwarded for re-raising.
            try:
                for item in self.dataloader:
                    buffer.put(item)
                buffer.put(self._SENTINEL)
            except Exception as exc:
                buffer.put(exc)

        # Daemon thread so an early consumer exit cannot hang interpreter shutdown.
        thread = threading.Thread(target=_producer, daemon=True)
        thread.start()
        while True:
            item = buffer.get()
            if item is self._SENTINEL:
                break
            if isinstance(item, Exception):
                raise item
            yield item
        thread.join()

    def __len__(self):
        """Return length of underlying DataLoader."""
        return len(self.dataloader)
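
# Example (a minimal sketch; `train_loader` and `model` are hypothetical):
#
#     prefetched = PrefetchDataLoader(train_loader, prefetch_factor=4)
#     for batch in prefetched:
#         batch = batch.to("cuda", non_blocking=prefetched.non_blocking)
#         loss = model(batch)
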
def optimize_dataloader(
dataset: Dataset,
batch_size: int = 1,
num_workers: Optional[int] = None,
pin_memory: bool = True,
persistent_workers: bool = True,
prefetch_factor: int = 4,
shuffle: bool = True,
device: str = "cuda",
) -> DataLoader:
"""
Create optimized DataLoader with best practices.
Args:
dataset: Dataset to load
batch_size: Batch size
num_workers: Number of worker processes (None = auto)
pin_memory: Pin memory for faster GPU transfer
persistent_workers: Keep workers alive between epochs
prefetch_factor: Number of batches to prefetch per worker
shuffle: Shuffle dataset
device: Target device
Returns:
Optimized DataLoader
"""
    # Auto-detect a reasonable number of workers
    if num_workers is None:
        cpu_count = os.cpu_count() or 1
        # Clamp half the core count to the range [2, 4]
        num_workers = min(4, max(2, cpu_count // 2))

    # Larger batches take longer to assemble, so less lookahead is needed
    # to keep the device fed; halve the prefetch factor for batch sizes > 4.
    if batch_size > 4:
        prefetch_factor = max(2, prefetch_factor // 2)

    # startswith() keeps pinning enabled for specific devices like "cuda:0"
    effective_pin_memory = pin_memory and device.startswith("cuda")
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=effective_pin_memory,
        persistent_workers=persistent_workers if num_workers > 0 else False,
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
        drop_last=False,
    )
    logger.info(
        f"Created optimized DataLoader: "
        f"batch_size={batch_size}, "
        f"num_workers={num_workers}, "
        f"prefetch_factor={prefetch_factor}, "
        f"pin_memory={effective_pin_memory}"
    )
return dataloader
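
# Example (a sketch; the synthetic dataset and its shapes are illustrative):
#
#     from torch.utils.data import TensorDataset
#     ds = TensorDataset(torch.randn(1024, 3, 64, 64), torch.randint(0, 10, (1024,)))
#     loader = optimize_dataloader(ds, batch_size=8, num_workers=2, device="cuda")
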
def profile_dataloader(
dataloader: DataLoader,
num_batches: int = 10,
device: str = "cuda",
) -> Dict[str, float]:
"""
Profile DataLoader performance.
Args:
dataloader: DataLoader to profile
num_batches: Number of batches to profile
device: Target device
Returns:
Dict with profiling results
"""
logger.info(f"Profiling DataLoader ({num_batches} batches)...")
    data_times = []
    transfer_times = []
    start_time = time.time()
    prev_end = start_time
    for i, batch in enumerate(dataloader):
        if i >= num_batches:
            break
        fetch_end = time.time()
        # The for-loop fetches the batch between iterations, so the time
        # elapsed since the previous iteration ended is the loading time.
        data_times.append(fetch_end - prev_end)
        # Measure host-to-device transfer time
        if device.startswith("cuda"):
            transfer_start = time.time()
            # Move batch to device, leaving non-tensor entries untouched
            if isinstance(batch, dict):
                batch = {
                    k: v.to(device, non_blocking=True) if torch.is_tensor(v) else v
                    for k, v in batch.items()
                }
            elif isinstance(batch, (list, tuple)):
                batch = [
                    x.to(device, non_blocking=True) if torch.is_tensor(x) else x
                    for x in batch
                ]
            else:
                batch = batch.to(device, non_blocking=True)
            # Non-blocking copies return immediately; synchronize so the
            # elapsed time reflects the actual transfer.
            torch.cuda.synchronize()
            transfer_times.append(time.time() - transfer_start)
        else:
            transfer_times.append(0.0)
        prev_end = time.time()
    total_time = time.time() - start_time
    num_profiled = len(data_times)
    if num_profiled == 0:
        logger.warning("DataLoader produced no batches; nothing to profile.")
        return {}
    batch_times = [d + t for d, t in zip(data_times, transfer_times)]
    results = {
        "total_time": total_time,
        "avg_batch_time": sum(batch_times) / num_profiled,
        "avg_data_time": sum(data_times) / num_profiled,
        "avg_transfer_time": sum(transfer_times) / num_profiled,
        "batches_per_sec": num_profiled / total_time,
        "data_loading_ratio": sum(data_times) / total_time,
        "transfer_ratio": sum(transfer_times) / total_time,
    }
logger.info("DataLoader Profile Results:")
logger.info(f" Total time: {total_time:.2f}s")
logger.info(f" Avg batch time: {results['avg_batch_time'] * 1000:.2f}ms")
logger.info(f" Avg data loading: {results['avg_data_time'] * 1000:.2f}ms")
logger.info(f" Avg transfer: {results['avg_transfer_time'] * 1000:.2f}ms")
logger.info(f" Batches/sec: {results['batches_per_sec']:.2f}")
logger.info(f" Data loading ratio: {results['data_loading_ratio'] * 100:.1f}%")
logger.info(f" Transfer ratio: {results['transfer_ratio'] * 100:.1f}%")
return results
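
# Reading the profile (a heuristic, not a hard rule): a data_loading_ratio
# well above ~50% means the pipeline is input-bound, so more workers or
# persistent_workers is the first lever to try; a high transfer_ratio points
# at enabling pin_memory. Hypothetical usage:
#
#     stats = profile_dataloader(loader, num_batches=20, device="cuda")
#     if stats.get("data_loading_ratio", 0.0) > 0.5:
#         loader = optimize_dataloader(loader.dataset, num_workers=8)
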
def find_optimal_num_workers(
dataset: Dataset,
batch_size: int = 1,
max_workers: int = 8,
num_test_batches: int = 10,
device: str = "cuda",
) -> int:
"""
Find optimal number of workers for DataLoader.
Args:
dataset: Dataset to test
batch_size: Batch size
max_workers: Maximum workers to test
num_test_batches: Number of batches to test per configuration
device: Target device
Returns:
Optimal number of workers
"""
logger.info(f"Finding optimal number of workers (max={max_workers})...")
best_workers = 0
best_time = float("inf")
    for num_workers in range(0, max_workers + 1):
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=device.startswith("cuda"),
            prefetch_factor=2 if num_workers > 0 else None,
        )
        # Time the first num_test_batches batches. Worker startup is included,
        # which slightly biases short tests against higher worker counts.
        start_time = time.time()
        for i, _ in enumerate(dataloader):
            if i >= num_test_batches:
                break
        elapsed = time.time() - start_time
logger.info(f" {num_workers} workers: {elapsed:.2f}s")
if elapsed < best_time:
best_time = elapsed
best_workers = num_workers
logger.info(f"Optimal number of workers: {best_workers} ({best_time:.2f}s)")
return best_workers
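

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the module API: the synthetic
    # TensorDataset and every shape/parameter below are illustrative
    # assumptions, not values from the original codebase.
    from torch.utils.data import TensorDataset

    logging.basicConfig(level=logging.INFO)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    dataset = TensorDataset(
        torch.randn(256, 3, 32, 32),
        torch.randint(0, 10, (256,)),
    )
    workers = find_optimal_num_workers(
        dataset, batch_size=8, max_workers=4, device=device
    )
    loader = optimize_dataloader(
        dataset, batch_size=8, num_workers=workers, device=device
    )
    profile_dataloader(loader, num_batches=10, device=device)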