import logging
import threading
from pathlib import Path
from typing import Any, Iterator, List, Optional, Sequence, Union, cast

import numpy as np
import torch


# Module-level logger named after this module, so applications can configure
# its level and handlers through the standard logging hierarchy.
logger = logging.getLogger(__name__)


class ShardLoader:
    """
    Lazy loader for a single data shard with optional background prefetch.

    ``.npy`` shards are opened with ``np.load(..., mmap_mode='r')``; ``.pt``
    shards are deserialized onto the CPU with ``torch.load`` (the whole file is
    read, it is not memory-mapped).  In both cases nothing is read until the
    data is first accessed.  A background thread can load the next shard while
    the current one is being consumed.

    Example:
        >>> loader = ShardLoader("data/shard_00.npy", device="cuda")
        >>> batch = loader.get_batch(start=0, size=32)
        >>> loader.prefetch_next("data/shard_01.npy")
    """

    def __init__(
        self,
        shard_path: Union[str, Path],
        device: str = "cuda",
        prefetch: bool = True,
    ) -> None:
        """
        Record the shard location and target device; no data is read yet.

        Args:
            shard_path: Path to shard file (.pt or .npy format)
            device: Target device for tensor operations (default: "cuda")
            prefetch: Enable background prefetching for next shard (default: True)

        Raises:
            AssertionError: If the file does not exist or has an unsupported suffix.
        """
        self.shard_path = Path(shard_path)
        self.device = torch.device(device)
        self.prefetch_enabled = prefetch

        # NOTE: these checks are ``assert`` statements and are skipped under
        # ``python -O``.  They are kept so callers that expect AssertionError
        # keep working; _load_data() still rejects unsupported suffixes.
        assert self.shard_path.exists(), f"Shard file not found: {shard_path}"
        assert self.shard_path.suffix in [".pt", ".npy"], \
            f"Unsupported format: {self.shard_path.suffix}. Use .pt or .npy"

        # Current shard payload, filled on first access through ``data``.
        self._data: Optional[Union[np.ndarray, torch.Tensor, List]] = None
        self._data_loaded = False

        # Prefetch state.  The lock guards ``_prefetch_data``, which is written
        # by the worker thread and consumed by swap_to_prefetched().
        self._prefetch_thread: Optional[threading.Thread] = None
        self._prefetch_data: Optional[Union[np.ndarray, torch.Tensor, List]] = None
        self._prefetch_path: Optional[Path] = None
        self._prefetch_lock = threading.Lock()

        logger.info("Initialized ShardLoader for %s", self.shard_path.name)

    def _load_data(
        self, path: Optional[Path] = None
    ) -> Union[np.ndarray, torch.Tensor, List]:
        """
        Load a shard file from disk.

        Args:
            path: File to load; defaults to ``self.shard_path``.  The prefetch
                worker passes the next shard's path so both code paths share
                the same format handling.

        Returns:
            Memory-mapped array (.npy) or whatever object the .pt file holds
            (typically a tensor, or a list for text data)

        Raises:
            ValueError: If the file suffix is neither ``.npy`` nor ``.pt``.
        """
        target = self.shard_path if path is None else path
        suffix = target.suffix

        if suffix == ".npy":
            # mmap_mode='r' maps the file read-only instead of reading it all.
            mmap_data = cast(np.ndarray, np.load(target, mmap_mode="r"))
            logger.debug(
                "Memory-mapped .npy shard: %s, shape=%s, dtype=%s",
                target.name, mmap_data.shape, mmap_data.dtype,
            )
            return mmap_data

        if suffix == ".pt":
            # NOTE(security): torch.load unpickles the file; only load shards
            # that come from a trusted source.
            pt_data = cast(
                Union[torch.Tensor, List],
                torch.load(target, map_location="cpu"),
            )
            if isinstance(pt_data, torch.Tensor):
                logger.debug(
                    "Loaded .pt tensor: %s, shape=%s, dtype=%s",
                    target.name, pt_data.shape, pt_data.dtype,
                )
            elif isinstance(pt_data, list):
                logger.debug(
                    "Loaded .pt list: %s, length=%s, type=%s",
                    target.name, len(pt_data),
                    type(pt_data[0]) if pt_data else "empty",
                )
            else:
                logger.debug(
                    "Loaded .pt shard: %s, type=%s", target.name, type(pt_data)
                )
            return pt_data

        raise ValueError(f"Unsupported format: {suffix}")

    @property
    def data(self) -> Union[np.ndarray, torch.Tensor, List]:
        """
        Return the shard payload, loading it on first access.

        Returns:
            Shard data (memory-mapped, tensor, or list)
        """
        if not self._data_loaded:
            self._data = self._load_data()
            self._data_loaded = True

        assert self._data is not None, "Data should be loaded"
        return self._data

    def get_batch(
        self,
        start: int,
        size: int,
        to_gpu: bool = True
    ) -> Union[torch.Tensor, List]:
        """
        Extract a batch from the shard with optional GPU transfer.

        The slice is clipped at the end of the shard, so the last batch may
        hold fewer than ``size`` samples (or none when ``start`` is past the end).

        Args:
            start: Starting index in shard
            size: Batch size (number of samples)
            to_gpu: Move tensor batches to ``self.device`` when that device is
                a CUDA device (default: True)

        Returns:
            A tensor for array/tensor shards; for any other payload (e.g. a
            list of text samples) the plain slice is returned unchanged.
        """
        data = self.data
        end = min(start + size, len(data))
        chunk = data[start:end]

        if isinstance(chunk, np.ndarray):
            # np.array() copies the slice, so the resulting tensor owns its
            # memory instead of aliasing the (read-only, for .npy) source array.
            batch = torch.from_numpy(np.array(chunk))
        elif isinstance(chunk, torch.Tensor):
            batch = chunk
        else:
            # Lists and other non-tensor sequences have no device to move to.
            return chunk

        if to_gpu and self.device.type == "cuda":
            batch = batch.to(self.device, non_blocking=True)

        return batch

    def __len__(self) -> int:
        """
        Get number of samples in shard (loads the shard if necessary).

        Returns:
            Number of samples (``len()`` of the shard data)
        """
        return len(self.data)

    def prefetch_next(self, next_shard_path: Union[str, Path]) -> None:
        """
        Start loading the next shard in a background thread.

        Does nothing when prefetching is disabled.  Any prefetch that is still
        running is joined first, so only one worker is active at a time.

        Args:
            next_shard_path: Path to next shard file to prefetch
        """
        if not self.prefetch_enabled:
            return

        next_path = Path(next_shard_path)

        if self._prefetch_thread is not None:
            self._prefetch_thread.join()

        # Drop any earlier, unclaimed result before the new worker starts.
        with self._prefetch_lock:
            self._prefetch_data = None

        self._prefetch_path = next_path
        self._prefetch_thread = threading.Thread(
            target=self._background_prefetch,
            daemon=True
        )
        self._prefetch_thread.start()

        logger.debug("Started prefetch for %s", next_path.name)

    def _background_prefetch(self) -> None:
        """
        Worker body: load ``self._prefetch_path`` and publish the result.

        Failures are logged and leave ``_prefetch_data`` as None so that
        swap_to_prefetched() reports that nothing is available.
        """
        path = self._prefetch_path
        if path is None:
            return

        try:
            data = self._load_data(path)
            with self._prefetch_lock:
                self._prefetch_data = data
            logger.debug("Prefetch completed: %s", path.name)
        except Exception as e:
            # Best effort by design: the caller falls back to a direct load.
            logger.error("Prefetch failed for %s: %s", path, e)
            with self._prefetch_lock:
                self._prefetch_data = None

    def swap_to_prefetched(self) -> bool:
        """
        Replace the current shard with the prefetched one.

        Waits for a running prefetch to finish before checking its result.

        Returns:
            True if swap successful, False if no prefetched data available
        """
        if self._prefetch_thread is not None:
            self._prefetch_thread.join()

        with self._prefetch_lock:
            if self._prefetch_data is None or self._prefetch_path is None:
                return False

            self._data = self._prefetch_data
            self.shard_path = self._prefetch_path
            self._data_loaded = True

            self._prefetch_data = None
            self._prefetch_path = None
            self._prefetch_thread = None

            logger.info("Swapped to prefetched shard: %s", self.shard_path.name)
            return True

    def close(self) -> None:
        """
        Wait for the prefetch worker and drop all references to shard data.

        The loader stays usable: the next access to ``data`` reloads the shard.
        """
        if self._prefetch_thread is not None:
            self._prefetch_thread.join()
            self._prefetch_thread = None

        self._data = None
        self._data_loaded = False
        with self._prefetch_lock:
            self._prefetch_data = None
        self._prefetch_path = None

        logger.debug("Closed ShardLoader for %s", self.shard_path.name)

    def __enter__(self) -> "ShardLoader":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Context manager exit."""
        self.close()


class MultiShardLoader:
    """
    Streams batches from an ordered list of shard files.

    One ShardLoader is kept for the shard currently being read; when
    prefetching is enabled the following shard is loaded in the background.
    With ``cycle=True`` the list wraps around, giving an endless stream.

    Example:
        >>> shard_paths = ["shard_00.npy", "shard_01.npy", "shard_02.npy"]
        >>> loader = MultiShardLoader(shard_paths, device="cuda")
        >>> for batch in loader.iter_batches(batch_size=32, num_batches=1000):
        ...     process(batch)
    """

    def __init__(
        self,
        shard_paths: Sequence[Union[str, Path]],
        device: str = "cuda",
        prefetch: bool = True,
        cycle: bool = True,
    ) -> None:
        """
        Initialize multi-shard loader; no shard is opened yet.

        Args:
            shard_paths: List of paths to shard files
            device: Target device for tensors (default: "cuda")
            prefetch: Enable prefetching (default: True)
            cycle: Cycle through shards infinitely (default: True)

        Raises:
            AssertionError: If the list is empty or a shard file is missing.
        """
        # NOTE: ``assert``-based validation is skipped under ``python -O``; it
        # is kept so callers that expect AssertionError keep working.
        assert len(shard_paths) > 0, "Must provide at least one shard path"

        self.shard_paths = [Path(p) for p in shard_paths]
        self.device = device
        self.prefetch_enabled = prefetch
        self.cycle = cycle

        for path in self.shard_paths:
            assert path.exists(), f"Shard file not found: {path}"

        self.current_shard_idx = 0
        self.current_loader: Optional["ShardLoader"] = None

        logger.info(
            "Initialized MultiShardLoader with %d shards", len(self.shard_paths)
        )

    def _load_shard(self, idx: int) -> "ShardLoader":
        """
        Create a ShardLoader for the shard at the given index.

        Args:
            idx: Index into shard_paths list

        Returns:
            ShardLoader for the shard
        """
        return ShardLoader(
            self.shard_paths[idx],
            device=self.device,
            prefetch=self.prefetch_enabled
        )

    def _next_shard_idx(self) -> Optional[int]:
        """
        Get the index that follows ``current_shard_idx``.

        Returns:
            Next shard index (wrapping to 0 when cycling), or None if at end
            and not cycling
        """
        next_idx = self.current_shard_idx + 1
        if next_idx < len(self.shard_paths):
            return next_idx
        return 0 if self.cycle else None

    def _prefetch_upcoming(self, loader: "ShardLoader") -> None:
        """Ask ``loader`` to prefetch the shard after the current one, if any."""
        if not self.prefetch_enabled:
            return
        upcoming = self._next_shard_idx()
        if upcoming is not None:
            loader.prefetch_next(self.shard_paths[upcoming])

    def _advance_shard(self, loader: "ShardLoader") -> Optional["ShardLoader"]:
        """
        Switch to the next shard and schedule the prefetch of the one after it.

        Returns:
            The loader now serving the new current shard, or None when the
            shard list is exhausted and cycling is off.
        """
        next_idx = self._next_shard_idx()
        if next_idx is None:
            return None

        # Prefer the prefetched data; otherwise open the shard directly.
        if not loader.swap_to_prefetched():
            loader.close()
            loader = self._load_shard(next_idx)
            self.current_loader = loader

        self.current_shard_idx = next_idx
        self._prefetch_upcoming(loader)
        return loader

    def iter_batches(
        self,
        batch_size: int,
        num_batches: Optional[int] = None,
        to_gpu: bool = True,
    ) -> Iterator[Union[torch.Tensor, List]]:
        """
        Iterate over batches from all shards.

        Batches never span two shards, so the last batch of a shard may be
        smaller than ``batch_size``.  Empty shards are skipped.  The read
        position is local to each call: a new call starts again at the
        beginning of whichever shard is current.

        Args:
            batch_size: Number of samples per batch (must be positive)
            num_batches: Maximum number of batches (None = no limit)
            to_gpu: Transfer batches to GPU (default: True)

        Yields:
            Batches as returned by ShardLoader.get_batch

        Raises:
            ValueError: If ``batch_size`` is not positive.  As this is a
                generator, the error surfaces when iteration starts.
        """
        if batch_size <= 0:
            # A non-positive size would never advance the read position.
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        loader = self.current_loader
        if loader is None:
            loader = self._load_shard(self.current_shard_idx)
            self.current_loader = loader
            self._prefetch_upcoming(loader)

        batches_yielded = 0
        current_pos = 0
        empty_streak = 0  # consecutive empty shards seen while advancing

        while num_batches is None or batches_yielded < num_batches:
            if current_pos >= len(loader):
                advanced = self._advance_shard(loader)
                if advanced is None:
                    return
                loader = advanced
                current_pos = 0

                if len(loader) == 0:
                    empty_streak += 1
                    if empty_streak >= len(self.shard_paths):
                        # A full pass found no data; stop instead of spinning.
                        logger.warning("All shards are empty; stopping iteration")
                        return
                    continue
                empty_streak = 0

            batch = loader.get_batch(
                start=current_pos,
                size=batch_size,
                to_gpu=to_gpu
            )
            current_pos += batch_size
            batches_yielded += 1

            yield batch

    def close(self) -> None:
        """
        Close the current shard loader, if one has been created.
        """
        if self.current_loader is not None:
            self.current_loader.close()

        logger.info("Closed MultiShardLoader")

    def __enter__(self) -> "MultiShardLoader":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Context manager exit."""
        self.close()
|