|
|
|
|
|
"""
|
|
|
Optimized Data Loader for Training
|
|
|
|
|
|
This module provides an optimized data loader with prefetching, caching,
|
|
|
and efficient batch processing to improve training performance.
|
|
|
|
|
|
Author: Louis Chua Bean Chong
|
|
|
License: GPLv3
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torch.utils.data import DataLoader, Dataset, Sampler
|
|
|
from typing import Optional, List, Tuple, Dict, Any
|
|
|
import numpy as np
|
|
|
import threading
|
|
|
import queue
|
|
|
import time
|
|
|
from collections import deque
|
|
|
import psutil
|
|
|
import os
|
|
|
|
|
|
|
|
|
class OptimizedDataset(Dataset):
    """
    Optimized dataset with optional in-memory sample caching.

    Wraps a pair of tensors and serves ``(data, target)`` samples. When
    ``cache_size`` is set, up to that many samples are kept in memory so
    repeated accesses skip indexing (and re-pinning) work. Hit/miss counters
    are only updated while caching is enabled, so ``hit_rate`` reflects
    actual cache behaviour rather than being diluted by accesses that could
    never hit.
    """

    def __init__(self,
                 data: torch.Tensor,
                 targets: torch.Tensor,
                 cache_size: Optional[int] = None,
                 pin_memory: bool = True):
        """
        Initialize optimized dataset.

        Args:
            data: Input data tensor (first dimension indexes samples)
            targets: Target tensor, same length as ``data``
            cache_size: Number of samples to cache in memory (None or 0
                disables caching)
            pin_memory: Whether to pin memory for faster GPU transfer
                (only takes effect when CUDA is available)

        Raises:
            ValueError: If ``data`` and ``targets`` have different lengths.
        """
        if targets is not None and len(data) != len(targets):
            raise ValueError(
                f"data and targets must have the same length "
                f"({len(data)} != {len(targets)})"
            )

        self.data = data
        self.targets = targets
        self.cache_size = cache_size
        self.pin_memory = pin_memory
        # Resolve the pinning decision once instead of querying CUDA
        # availability on every __getitem__ call.
        self._pin_enabled = pin_memory and torch.cuda.is_available()

        self.cache = {}
        self.cache_hits = 0
        self.cache_misses = 0

        if cache_size and cache_size > 0:
            print(f"Initializing cache with {cache_size} samples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Serve from the cache when enabled and already populated.
        if self.cache_size and idx in self.cache:
            self.cache_hits += 1
            return self.cache[idx]

        # Only count misses while caching is active; otherwise the hit-rate
        # statistic would be skewed by accesses that could never hit.
        if self.cache_size:
            self.cache_misses += 1

        sample_data = self.data[idx]
        sample_target = self.targets[idx]

        if self._pin_enabled:
            sample_data = sample_data.pin_memory()
            sample_target = sample_target.pin_memory()

        # Insert-only cache: fill until full, never evict.
        if self.cache_size and len(self.cache) < self.cache_size:
            self.cache[idx] = (sample_data, sample_target)

        return sample_data, sample_target

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache statistics (hits, misses, hit rate, sizes)."""
        total_requests = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / total_requests if total_requests > 0 else 0

        return {
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "hit_rate": hit_rate,
            "cache_size": len(self.cache),
            "max_cache_size": self.cache_size,
        }
|
|
|
|
|
|
|
|
|
class PrefetchDataLoader:
    """
    Data loader with prefetching for improved performance.

    A background thread pulls batches from an internal ``DataLoader`` into a
    bounded queue so the consumer rarely waits on data. An end-of-epoch
    sentinel is pushed after the last batch so iteration terminates cleanly,
    and iterating the loader again starts a fresh epoch.
    """

    # Unique end-of-epoch marker; compared by identity in __next__.
    _SENTINEL = object()

    def __init__(self,
                 dataset: Dataset,
                 batch_size: int = 32,
                 num_workers: int = 4,
                 prefetch_factor: int = 2,
                 pin_memory: bool = True,
                 shuffle: bool = True,
                 drop_last: bool = False):
        """
        Initialize prefetch data loader.

        Args:
            dataset: Dataset to load
            batch_size: Batch size
            num_workers: Number of worker processes
            prefetch_factor: Number of batches to prefetch (0 disables the
                background thread and falls back to plain iteration)
            pin_memory: Whether to pin memory
            shuffle: Whether to shuffle data
            drop_last: Whether to drop incomplete batches
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.data_loader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            persistent_workers=True if num_workers > 0 else False
        )

        # One extra slot so the sentinel can usually be enqueued without
        # blocking behind a full queue.
        self.prefetch_queue = queue.Queue(maxsize=max(1, prefetch_factor) + 1)
        self.prefetch_thread = None
        self.stop_prefetch = False
        self._epoch_done = False
        self._direct_iter = None  # used only when prefetching is disabled

        self._start_prefetch()

        print(f"PrefetchDataLoader initialized with {num_workers} workers")

    def _start_prefetch(self):
        """Start the background prefetching thread (no-op when disabled)."""
        if self.prefetch_factor > 0:
            self._epoch_done = False
            self.prefetch_thread = threading.Thread(
                target=self._prefetch_worker, daemon=True
            )
            self.prefetch_thread.start()

    def _prefetch_worker(self):
        """Pump one epoch of batches into the queue, then push the sentinel."""
        try:
            for batch in self.data_loader:
                if not self._put(batch):
                    return  # stop() was requested while the queue was full
        except Exception as e:
            print(f"Prefetch worker error: {e}")
        finally:
            # Always signal end-of-epoch so the consumer can stop cleanly.
            self._put(self._SENTINEL)

    def _put(self, item) -> bool:
        """Blocking put that stays responsive to stop(); True when enqueued."""
        while not self.stop_prefetch:
            try:
                self.prefetch_queue.put(item, timeout=0.1)
                return True
            except queue.Full:
                continue
        return False

    def _drain(self):
        """Discard anything left over in the queue."""
        while True:
            try:
                self.prefetch_queue.get_nowait()
            except queue.Empty:
                return

    def __iter__(self):
        """Begin (or restart) an epoch of batches."""
        if self.prefetch_factor <= 0:
            self._direct_iter = iter(self.data_loader)
            return self
        # Restart the pump only when the previous epoch fully finished;
        # a mid-epoch __iter__ keeps consuming the running thread's output.
        if self.prefetch_thread is None or self._epoch_done:
            if self.prefetch_thread is not None:
                self.prefetch_thread.join(timeout=5.0)
            self._drain()
            self._start_prefetch()
        return self

    def __next__(self):
        """Get the next batch; raises StopIteration at end of epoch."""
        if self.stop_prefetch:
            raise StopIteration

        if self.prefetch_factor <= 0:
            # No prefetch thread: iterate the underlying loader directly.
            if self._direct_iter is None:
                self._direct_iter = iter(self.data_loader)
            try:
                return next(self._direct_iter)
            except StopIteration:
                self._direct_iter = None
                raise

        while True:
            try:
                batch = self.prefetch_queue.get(timeout=1.0)
            except queue.Empty:
                # Worker died without delivering the sentinel: end the epoch
                # instead of spinning forever.
                if (self.prefetch_thread is None
                        or not self.prefetch_thread.is_alive()):
                    self._epoch_done = True
                    raise StopIteration
                continue
            if batch is self._SENTINEL:
                self._epoch_done = True
                raise StopIteration
            return batch

    def __len__(self):
        return len(self.data_loader)

    def stop(self):
        """Stop prefetching and release the worker thread."""
        self.stop_prefetch = True
        if self.prefetch_thread:
            # Unblock a worker that may be stuck on a full queue before join.
            self._drain()
            self.prefetch_thread.join(timeout=5.0)
|
|
|
|
|
|
|
|
|
class DynamicBatchSampler(Sampler):
    """
    Dynamic batch sampler that adjusts batch size based on memory availability.

    Monitors system memory (via psutil) and grows or shrinks the batch size
    between ``base_batch_size`` and ``max_batch_size`` while iterating.

    Note: because the batch size can change mid-epoch, ``__len__`` is only an
    estimate based on the current batch size.
    """

    def __init__(self,
                 dataset_size: int,
                 base_batch_size: int = 32,
                 max_batch_size: int = 128,
                 memory_threshold: float = 0.8,
                 adjustment_factor: float = 1.2,
                 shuffle: bool = True):
        """
        Initialize dynamic batch sampler.

        Args:
            dataset_size: Size of the dataset
            base_batch_size: Base (and minimum) batch size
            max_batch_size: Maximum batch size
            memory_threshold: Memory usage fraction above which batches shrink
            adjustment_factor: Multiplicative factor for batch size adjustment
            shuffle: Whether to shuffle indices each epoch (defaults to True,
                matching the previous always-shuffle behaviour)
        """
        self.dataset_size = dataset_size
        self.base_batch_size = base_batch_size
        self.max_batch_size = max_batch_size
        self.memory_threshold = memory_threshold
        self.adjustment_factor = adjustment_factor
        self.shuffle = shuffle

        self.current_batch_size = base_batch_size
        # Rolling window of recent batch sizes for diagnostics.
        self.batch_history = deque(maxlen=10)

        print(f"DynamicBatchSampler initialized with base batch size: {base_batch_size}")

    def _get_memory_usage(self) -> float:
        """Get current system memory usage as a fraction in [0, 1]."""
        memory = psutil.virtual_memory()
        return memory.percent / 100.0

    def _adjust_batch_size(self):
        """Shrink the batch size under memory pressure, grow it otherwise."""
        memory_usage = self._get_memory_usage()

        if memory_usage > self.memory_threshold:
            # Back off, but never below the base batch size.
            self.current_batch_size = max(
                self.base_batch_size,
                int(self.current_batch_size / self.adjustment_factor)
            )
        else:
            # Headroom available: grow toward the maximum.
            self.current_batch_size = min(
                self.max_batch_size,
                int(self.current_batch_size * self.adjustment_factor)
            )

        self.batch_history.append(self.current_batch_size)

    def __iter__(self):
        """Yield lists of indices, one list per (dynamically sized) batch.

        The cursor advances by the length of the batch actually emitted, so
        every index is yielded exactly once even while the batch size changes.
        (A ``range(..., step)`` loop would freeze the stride at the initial
        batch size, causing overlaps or skipped samples after adjustment.)
        """
        indices = list(range(self.dataset_size))

        if self.shuffle:
            np.random.shuffle(indices)

        cursor = 0
        while cursor < len(indices):
            batch_indices = indices[cursor:cursor + self.current_batch_size]
            cursor += len(batch_indices)

            # Re-evaluate memory pressure after each batch is formed.
            self._adjust_batch_size()

            yield batch_indices

    def __len__(self):
        # Estimate only: the batch size may change while iterating.
        return (self.dataset_size + self.current_batch_size - 1) // self.current_batch_size

    def get_stats(self) -> Dict[str, Any]:
        """Get sampler statistics."""
        return {
            "current_batch_size": self.current_batch_size,
            "base_batch_size": self.base_batch_size,
            "max_batch_size": self.max_batch_size,
            "memory_usage": self._get_memory_usage(),
            "batch_history": list(self.batch_history)
        }
|
|
|
|
|
|
|
|
|
class OptimizedDataLoader:
    """
    High-performance data loader with multiple optimizations.

    This data loader combines multiple optimization techniques:
    - Dynamic batch sizing via ``DynamicBatchSampler`` (when enabled)
    - Prefetching with a background thread (when dynamic batching is off)
    - Memory pinning
    - Optional sample caching via ``OptimizedDataset``

    Exactly one underlying loader drives iteration: the batch-sampler-based
    ``DataLoader`` when dynamic batching is enabled, otherwise the
    ``PrefetchDataLoader``.
    """

    def __init__(self,
                 dataset: Dataset,
                 batch_size: int = 32,
                 num_workers: int = 4,
                 prefetch_factor: int = 2,
                 pin_memory: bool = True,
                 shuffle: bool = True,
                 drop_last: bool = False,
                 use_dynamic_batching: bool = True,
                 cache_size: Optional[int] = None):
        """
        Initialize optimized data loader.

        Args:
            dataset: Dataset to load
            batch_size: Base batch size
            num_workers: Number of worker processes
            prefetch_factor: Number of batches to prefetch
            pin_memory: Whether to pin memory
            shuffle: Whether to shuffle data
            drop_last: Whether to drop incomplete batches
            use_dynamic_batching: Whether to use dynamic batch sizing
            cache_size: Number of samples to cache
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.use_dynamic_batching = use_dynamic_batching
        self.cache_size = cache_size

        # Wrap with a caching dataset only when the raw tensors are actually
        # exposed; wrapping with targets=None would fail on first access.
        if (cache_size and cache_size > 0
                and hasattr(dataset, 'data') and hasattr(dataset, 'targets')):
            self.dataset = OptimizedDataset(
                dataset.data,
                dataset.targets,
                cache_size=cache_size,
                pin_memory=pin_memory
            )

        if use_dynamic_batching:
            self.sampler = DynamicBatchSampler(
                dataset_size=len(self.dataset),
                base_batch_size=batch_size,
                max_batch_size=batch_size * 4
            )
            # The dynamic sampler yields *lists* of indices, so it must be
            # passed as batch_sampler (batch_sampler is mutually exclusive
            # with batch_size / shuffle / drop_last — the sampler owns
            # batching and shuffling).
            self.data_loader = DataLoader(
                dataset=self.dataset,
                batch_sampler=self.sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                persistent_workers=True if num_workers > 0 else False
            )
            # No prefetch loader: it would ignore the dynamic sampler and
            # waste a duplicate set of workers.
            self.prefetch_loader = None
        else:
            self.sampler = None
            self.data_loader = DataLoader(
                dataset=self.dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=drop_last,
                persistent_workers=True if num_workers > 0 else False
            )
            self.prefetch_loader = PrefetchDataLoader(
                dataset=self.dataset,
                batch_size=batch_size,
                num_workers=num_workers,
                prefetch_factor=prefetch_factor,
                pin_memory=pin_memory,
                shuffle=shuffle,
                drop_last=drop_last
            )

        print(f"OptimizedDataLoader initialized with {num_workers} workers")

    def __iter__(self):
        """Iterate over batches from the active underlying loader."""
        if self.prefetch_loader is not None:
            return iter(self.prefetch_loader)
        return iter(self.data_loader)

    def __len__(self):
        return len(self.data_loader)

    def get_stats(self) -> Dict[str, Any]:
        """Get loader statistics."""
        stats = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "prefetch_factor": self.prefetch_factor,
            "cache_enabled": self.cache_size is not None,
            "dynamic_batching": self.use_dynamic_batching
        }

        if hasattr(self.dataset, 'get_cache_stats'):
            stats.update(self.dataset.get_cache_stats())

        # Explicit None check: truthiness on a Sampler would invoke __len__.
        if self.sampler is not None:
            stats.update(self.sampler.get_stats())

        return stats

    def stop(self):
        """Stop the data loader (shuts down the prefetch thread if present)."""
        if self.prefetch_loader is not None:
            self.prefetch_loader.stop()
|
|
|
|
|
|
|
|
|
def create_optimized_loader(dataset: Dataset,
                            batch_size: int = 32,
                            num_workers: Optional[int] = None,
                            **kwargs) -> OptimizedDataLoader:
    """
    Build an OptimizedDataLoader with automatic configuration.

    Args:
        dataset: Dataset to load
        batch_size: Batch size
        num_workers: Number of workers; when None, a conservative count of
            min(4, available CPU cores) is chosen automatically
        **kwargs: Additional arguments forwarded to OptimizedDataLoader

    Returns:
        OptimizedDataLoader: Configured data loader
    """
    # Auto-detect a worker count when none was supplied; os.cpu_count() can
    # return None, so fall back to a single worker in that case.
    workers = num_workers if num_workers is not None else min(4, os.cpu_count() or 1)

    return OptimizedDataLoader(dataset=dataset,
                               batch_size=batch_size,
                               num_workers=workers,
                               **kwargs)
|
|
|
|