import logging
import threading
from pathlib import Path
from typing import Any, Iterator, List, Optional, Sequence, Union, cast

import numpy as np
import torch

logger = logging.getLogger(__name__)

# Everything a shard file can deserialize to: a memory-mapped array (.npy),
# or a tensor / plain list stored with torch.save (.pt).
ShardData = Union[np.ndarray, torch.Tensor, List]


class ShardLoader:
    """
    Lazy loader for a single data shard with optional background prefetching.

    ``.npy`` shards are opened with ``np.load(..., mmap_mode='r')`` so that
    only the slices actually requested are read from disk. ``.pt`` shards are
    read with ``torch.load`` onto the CPU the first time the data is accessed.
    A background thread can read the *next* shard while the current one is
    being consumed; ``swap_to_prefetched`` then makes it the current shard.

    Example:
        >>> loader = ShardLoader("data/shard_00.npy", device="cuda")
        >>> batch = loader.get_batch(start=0, size=32)
        >>> loader.prefetch_next("data/shard_01.npy")
    """

    # File suffixes this loader knows how to read.
    SUPPORTED_SUFFIXES = (".pt", ".npy")

    def __init__(
        self,
        shard_path: Union[str, Path],
        device: str = "cuda",
        prefetch: bool = True,
    ) -> None:
        """
        Validate the shard path and set up lazy-loading / prefetch state.

        No shard data is read here; it is loaded on first access of ``data``.

        Args:
            shard_path: Path to shard file (.pt or .npy format)
            device: Target device for tensor operations (default: "cuda")
            prefetch: Enable background prefetching for next shard (default: True)

        Raises:
            AssertionError: If the file does not exist or has an unsupported
                suffix. (Assertions are skipped under ``python -O``; in that
                case the error surfaces when the shard is first read.)
        """
        self.shard_path = Path(shard_path)
        self.device = torch.device(device)
        self.prefetch_enabled = prefetch

        assert self.shard_path.exists(), f"Shard file not found: {shard_path}"
        assert self.shard_path.suffix in self.SUPPORTED_SUFFIXES, \
            f"Unsupported format: {self.shard_path.suffix}. Use .pt or .npy"

        # Lazy-loaded data (only loaded when accessed)
        self._data: Optional[ShardData] = None
        self._data_loaded = False

        # Prefetch state; _prefetch_lock guards _prefetch_data/_prefetch_path,
        # which are written by the background thread and read by the caller.
        self._prefetch_thread: Optional[threading.Thread] = None
        self._prefetch_data: Optional[ShardData] = None
        self._prefetch_path: Optional[Path] = None
        self._prefetch_lock = threading.Lock()

        logger.info(f"Initialized ShardLoader for {self.shard_path.name}")

    @staticmethod
    def _read_shard(path: Path) -> ShardData:
        """
        Read one shard file; shared by the foreground and prefetch paths.

        Args:
            path: Shard file to read (.npy or .pt)

        Returns:
            Memory-mapped array for .npy, or whatever ``torch.load`` returns
            for .pt (loaded onto the CPU).

        Raises:
            ValueError: If the suffix is neither .npy nor .pt
        """
        if path.suffix == ".npy":
            # mmap_mode='r' maps the file read-only instead of reading it all.
            return cast(np.ndarray, np.load(path, mmap_mode='r'))
        if path.suffix == ".pt":
            # NOTE(security): torch.load unpickles the file. Only load shards
            # from trusted sources (consider weights_only=True where the
            # installed torch version and shard contents allow it).
            return cast(
                Union[torch.Tensor, List],
                torch.load(path, map_location='cpu'),
            )
        raise ValueError(f"Unsupported format: {path.suffix}")

    def _load_data(self) -> ShardData:
        """
        Read the current shard from disk and log what was loaded.

        Returns:
            Memory-mapped array (.npy), tensor (.pt), or list (.pt)

        Raises:
            ValueError: If the shard suffix is unsupported
        """
        loaded = self._read_shard(self.shard_path)

        if self.shard_path.suffix == ".npy":
            logger.debug(f"Memory-mapped .npy shard: {self.shard_path.name}, "
                         f"shape={loaded.shape}, dtype={loaded.dtype}")
        elif isinstance(loaded, torch.Tensor):
            logger.debug(f"Loaded .pt tensor: {self.shard_path.name}, "
                         f"shape={loaded.shape}, dtype={loaded.dtype}")
        elif isinstance(loaded, list):
            logger.debug(f"Loaded .pt list: {self.shard_path.name}, "
                         f"length={len(loaded)}, type={type(loaded[0]) if loaded else 'empty'}")
        else:
            logger.debug(f"Loaded .pt shard: {self.shard_path.name}, "
                         f"type={type(loaded)}")

        return loaded

    @property
    def data(self) -> ShardData:
        """
        Lazily load and return shard data.

        Returns:
            Shard data (memory-mapped array, tensor, or list)
        """
        if not self._data_loaded:
            self._data = self._load_data()
            self._data_loaded = True
        assert self._data is not None, "Data should be loaded"
        return self._data

    def get_batch(
        self,
        start: int,
        size: int,
        to_gpu: bool = True
    ) -> Union[torch.Tensor, List]:
        """
        Extract ``data[start:start + size]``, clipped to the shard length.

        Args:
            start: Starting index in shard (must be >= 0)
            size: Batch size (number of samples, must be >= 0)
            to_gpu: Move the batch to ``self.device`` when that device is
                CUDA (default: True)

        Returns:
            A tensor for array/tensor shards (shorter than ``size`` at the end
            of the shard, empty when ``start`` is past the end), or a list
            slice for list shards.

        Raises:
            ValueError: If ``start`` or ``size`` is negative. A negative start
                would otherwise be interpreted as a wrap-around slice index.
        """
        if start < 0 or size < 0:
            raise ValueError(
                f"start and size must be non-negative, got start={start}, size={size}"
            )

        shard = self.data
        end = min(start + size, len(shard))
        batch_data = shard[start:end]

        if isinstance(batch_data, list):
            # List shards (e.g. text) are returned as-is; no device transfer.
            return batch_data

        if isinstance(batch_data, np.ndarray):
            # np.array(...) copies the read-only memmap slice into a regular
            # in-memory array before handing it to torch.
            batch = torch.from_numpy(np.array(batch_data))
        else:
            # Already a tensor
            batch = batch_data

        if to_gpu and self.device.type == 'cuda':
            batch = batch.to(self.device, non_blocking=True)

        return batch

    def __len__(self) -> int:
        """
        Get number of samples in shard (loads the shard if needed).

        Returns:
            ``len()`` of the shard data
        """
        return len(self.data)

    def prefetch_next(self, next_shard_path: Union[str, Path]) -> None:
        """
        Start reading the next shard in a background thread.

        Does nothing when prefetching is disabled. If a previous prefetch is
        still running, this call waits for it before starting the new one; any
        previously prefetched but unused data is discarded.

        Args:
            next_shard_path: Path to next shard file to prefetch
        """
        if not self.prefetch_enabled:
            return

        next_path = Path(next_shard_path)

        # Wait for previous prefetch to complete
        if self._prefetch_thread is not None:
            self._prefetch_thread.join()

        with self._prefetch_lock:
            self._prefetch_data = None
            self._prefetch_path = next_path

        # The path is passed explicitly so the worker does not depend on
        # shared state that a later call could change.
        worker = threading.Thread(
            target=self._background_prefetch,
            args=(next_path,),
            daemon=True,
        )
        self._prefetch_thread = worker
        worker.start()

        logger.debug(f"Started prefetch for {next_path.name}")

    def _background_prefetch(self, path: Optional[Path] = None) -> None:
        """
        Thread target: read a shard and publish it as the prefetched data.

        Failures are logged and leave the prefetched data as ``None`` so that
        ``swap_to_prefetched`` reports failure instead of raising in a thread.

        Args:
            path: Shard to read; defaults to ``self._prefetch_path``
        """
        target = path if path is not None else self._prefetch_path
        if target is None:
            return

        try:
            fetched: Optional[ShardData] = self._read_shard(target)
        except Exception as e:  # thread boundary: report, never propagate
            logger.error(f"Prefetch failed for {target}: {e}")
            fetched = None
        else:
            logger.debug(f"Prefetch completed: {target.name}")

        with self._prefetch_lock:
            self._prefetch_data = fetched

    def swap_to_prefetched(self) -> bool:
        """
        Make the prefetched shard the current shard.

        Waits for a running prefetch to finish first. Prefetch state is
        cleared whether or not the swap succeeds.

        Returns:
            True if swap successful, False if no prefetched data available
        """
        if self._prefetch_thread is not None:
            self._prefetch_thread.join()
            self._prefetch_thread = None

        with self._prefetch_lock:
            fetched, fetched_path = self._prefetch_data, self._prefetch_path
            self._prefetch_data = None
            self._prefetch_path = None

            if fetched is None or fetched_path is None:
                return False

            self._data = fetched
            self.shard_path = fetched_path
            self._data_loaded = True

        logger.info(f"Swapped to prefetched shard: {self.shard_path.name}")
        return True

    def close(self) -> None:
        """
        Wait for the background thread and drop all data references.

        The loader remains usable: accessing ``data`` afterwards reloads the
        current shard from disk.
        """
        if self._prefetch_thread is not None:
            self._prefetch_thread.join()
            self._prefetch_thread = None

        # Clear references to allow garbage collection
        with self._prefetch_lock:
            self._prefetch_data = None
            self._prefetch_path = None
        self._data = None
        self._data_loaded = False

        logger.debug(f"Closed ShardLoader for {self.shard_path.name}")

    def __enter__(self) -> "ShardLoader":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Context manager exit."""
        self.close()


class MultiShardLoader:
    """
    Streams batches across an ordered list of shards.

    Shards are consumed one after another (optionally cycling back to the
    first). While one shard is being consumed the next one is prefetched in
    the background when prefetching is enabled. Batches never span two
    shards, so the last batch of a shard may be shorter than ``batch_size``.

    Example:
        >>> shard_paths = ["shard_00.npy", "shard_01.npy", "shard_02.npy"]
        >>> loader = MultiShardLoader(shard_paths, device="cuda")
        >>> for batch in loader.iter_batches(batch_size=32, num_batches=1000):
        ...     process(batch)
    """

    def __init__(
        self,
        shard_paths: Sequence[Union[str, Path]],
        device: str = "cuda",
        prefetch: bool = True,
        cycle: bool = True,
    ) -> None:
        """
        Initialize multi-shard loader.

        Args:
            shard_paths: List of paths to shard files
            device: Target device for tensors (default: "cuda")
            prefetch: Enable prefetching (default: True)
            cycle: Cycle through shards infinitely (default: True)

        Raises:
            AssertionError: If ``shard_paths`` is empty or a file is missing
        """
        assert len(shard_paths) > 0, "Must provide at least one shard path"

        self.shard_paths = [Path(p) for p in shard_paths]
        self.device = device
        self.prefetch_enabled = prefetch
        self.cycle = cycle

        # Verify all shard files exist
        for path in self.shard_paths:
            assert path.exists(), f"Shard file not found: {path}"

        self.current_shard_idx = 0
        self.current_loader: Optional[ShardLoader] = None

        logger.info(f"Initialized MultiShardLoader with {len(self.shard_paths)} shards")

    def _load_shard(self, idx: int) -> ShardLoader:
        """
        Create a ShardLoader for the shard at the given index.

        Args:
            idx: Index into shard_paths list

        Returns:
            ShardLoader for the shard
        """
        return ShardLoader(
            self.shard_paths[idx],
            device=self.device,
            prefetch=self.prefetch_enabled
        )

    def _next_shard_idx(self) -> Optional[int]:
        """
        Get the index following ``current_shard_idx``.

        Returns:
            Next shard index (wrapping to 0 when cycling), or None if at the
            end and not cycling
        """
        next_idx = self.current_shard_idx + 1
        if next_idx < len(self.shard_paths):
            return next_idx
        return 0 if self.cycle else None

    def _prefetch_upcoming(self) -> None:
        """Start prefetching the shard that follows the current one, if any."""
        if not self.prefetch_enabled or self.current_loader is None:
            return
        upcoming = self._next_shard_idx()
        # With a single cycling shard the "next" shard is the current one;
        # its data is already loaded, so there is nothing to prefetch.
        if upcoming is None or upcoming == self.current_shard_idx:
            return
        self.current_loader.prefetch_next(self.shard_paths[upcoming])

    def _advance_shard(self) -> bool:
        """
        Move ``current_loader`` to the next shard.

        Returns:
            True if a next shard is now current, False if the shard list is
            exhausted and cycling is disabled
        """
        next_idx = self._next_shard_idx()
        if next_idx is None:
            return False

        loader = self.current_loader
        if next_idx == self.current_shard_idx and loader is not None:
            # Single shard with cycling: reuse the loaded data and just
            # restart from position 0 instead of re-reading the file.
            return True

        expected_path = self.shard_paths[next_idx]
        swapped = loader is not None and loader.swap_to_prefetched()
        if not swapped or loader is None or loader.shard_path != expected_path:
            # Prefetch disabled/failed, or it fetched some other file:
            # fall back to a synchronous load of the expected shard.
            if loader is not None:
                loader.close()
            self.current_loader = self._load_shard(next_idx)

        self.current_shard_idx = next_idx
        self._prefetch_upcoming()
        return True

    def iter_batches(
        self,
        batch_size: int,
        num_batches: Optional[int] = None,
        to_gpu: bool = True,
    ) -> Iterator[Union[torch.Tensor, List]]:
        """
        Iterate over batches from all shards.

        Iteration starts at position 0 of the current shard. Empty shards are
        skipped. Iteration ends when ``num_batches`` batches were produced,
        when the last shard is exhausted and cycling is off, or when a full
        pass over the shard list yields no data at all.

        Args:
            batch_size: Number of samples per batch (must be > 0)
            num_batches: Maximum number of batches (None = no limit)
            to_gpu: Transfer batches to GPU (default: True)

        Yields:
            Batches as returned by ``ShardLoader.get_batch``

        Raises:
            ValueError: If ``batch_size`` is not positive (raised when the
                iterator is first advanced); a non-positive size would never
                advance through the shard.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        batches_yielded = 0

        # Load first shard
        if self.current_loader is None:
            self.current_loader = self._load_shard(self.current_shard_idx)
            self._prefetch_upcoming()

        current_pos = 0
        # Number of consecutive shards that contained no samples; used to
        # stop instead of spinning forever when every shard is empty.
        empty_streak = 0

        while num_batches is None or batches_yielded < num_batches:
            if current_pos >= len(self.current_loader):
                if current_pos == 0:
                    empty_streak += 1
                    if empty_streak >= len(self.shard_paths):
                        logger.warning("All shards are empty; stopping iteration")
                        break
                if not self._advance_shard():
                    # No more shards and not cycling
                    break
                current_pos = 0
                # Re-check the new shard: it may itself be empty.
                continue

            empty_streak = 0
            batch = self.current_loader.get_batch(
                start=current_pos,
                size=batch_size,
                to_gpu=to_gpu
            )
            current_pos += batch_size
            batches_yielded += 1

            yield batch

    def close(self) -> None:
        """
        Clean up all resources held by the current shard loader.
        """
        if self.current_loader is not None:
            self.current_loader.close()
        logger.info("Closed MultiShardLoader")

    def __enter__(self) -> "MultiShardLoader":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Context manager exit."""
        self.close()