import logging
from pathlib import Path
from typing import Iterable, Iterator, List, Optional, Union

import torch
import torch.utils.data

from .shard_loader import MultiShardLoader

logger = logging.getLogger(__name__)


class StreamingDataset(torch.utils.data.IterableDataset):
    """
    Streaming dataset for efficient memory-mapped data loading.

    Implements torch.utils.data.IterableDataset for (optionally infinite)
    streaming over shards. When used with a multi-worker DataLoader, shards
    are partitioned round-robin across DataLoader worker processes so no two
    workers read the same shard. Note: shards are NOT partitioned by
    torch.distributed rank; for multi-process (DDP) training, pass each rank
    its own subset of shard_paths.

    Design:
    - Uses MultiShardLoader for loading, optional background prefetching,
      and cycling
    - Partitions shards across DataLoader workers to avoid duplication
    - Cycles through shards infinitely when cycle=True
    - Requests device transfer from MultiShardLoader (to_gpu=True); memory
      behaviour and asynchrony are whatever MultiShardLoader provides

    Reference: DeepSeek V3 training infrastructure (2025)

    Example:
        >>> dataset = StreamingDataset(
        ...     shard_paths=["shard_00.pt", "shard_01.pt"],
        ...     batch_size=32,
        ...     device="cuda"
        ... )
        >>> loader = torch.utils.data.DataLoader(
        ...     dataset,
        ...     batch_size=None,  # dataset handles batching
        ...     num_workers=0,    # see note on CUDA + worker processes
        ... )
        >>> for batch in loader:
        ...     process(batch)  # batch already on the target device

    Note:
        PyTorch advises against returning CUDA tensors from multi-process
        data loading, and ``pin_memory=True`` only pins CPU tensors. With
        ``device="cuda"``, prefer ``num_workers=0``.
    """

    def __init__(
        self,
        shard_paths: Union[str, Path, Iterable[Union[str, Path]]],
        batch_size: int = 32,
        device: str = "cuda",
        prefetch: bool = True,
        cycle: bool = True,
        max_samples_per_worker: Optional[int] = None,
    ) -> None:
        """
        Initialize streaming dataset with shard paths.

        Args:
            shard_paths: Paths to shard files (.pt or .npy). A single str/Path
                is accepted and treated as a one-element list.
            batch_size: Number of samples per batch (default: 32)
            device: Target device for batches (default: "cuda")
            prefetch: Enable background prefetching (default: True)
            cycle: Cycle through shards infinitely (default: True)
            max_samples_per_worker: Maximum samples per worker (None = no
                limit). Only whole batches are produced, so the effective
                limit is rounded down to a multiple of batch_size.

        Raises:
            ValueError: If no shard paths are given, batch_size is not
                positive, or max_samples_per_worker is negative.
            FileNotFoundError: If any shard path does not exist.
        """
        if isinstance(shard_paths, (str, Path)):
            # A bare path string would otherwise be iterated char by char.
            shard_paths = [shard_paths]
        paths = [Path(p) for p in shard_paths]

        # Explicit raises instead of `assert`, which is stripped under -O.
        if not paths:
            raise ValueError("Must provide at least one shard path")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if max_samples_per_worker is not None and max_samples_per_worker < 0:
            raise ValueError(
                "max_samples_per_worker must be non-negative, "
                f"got {max_samples_per_worker}"
            )

        self.shard_paths = paths
        self.batch_size = batch_size
        self.device = device
        self.prefetch = prefetch
        self.cycle = cycle
        self.max_samples_per_worker = max_samples_per_worker

        # Fail fast at construction rather than inside a worker process.
        for path in self.shard_paths:
            if not path.exists():
                raise FileNotFoundError(f"Shard file not found: {path}")

        logger.info(
            "Initialized StreamingDataset with %d shards, batch_size=%d, device=%s",
            len(self.shard_paths),
            batch_size,
            device,
        )

    def _get_worker_shards(self) -> list[Path]:
        """
        Partition shards across DataLoader workers to avoid duplication.

        Each worker gets a disjoint, round-robin subset of shards, selected
        via torch.utils.data.get_worker_info(). In single-process loading
        (no worker info) all shards are returned. Workers whose id is >= the
        number of shards receive an empty list.

        Returns:
            List of shard paths assigned to this worker
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process data loading: this process reads every shard.
            return self.shard_paths

        num_workers = worker_info.num_workers
        worker_id = worker_info.id

        # Round-robin: worker k takes shards k, k + num_workers, ...
        worker_shards = self.shard_paths[worker_id::num_workers]

        logger.debug(
            "Worker %d/%d assigned %d shards",
            worker_id,
            num_workers,
            len(worker_shards),
        )
        return worker_shards

    def __iter__(self) -> Iterator[Union[torch.Tensor, List]]:
        """
        Iterate over batches from this worker's assigned shards.

        Shards are partitioned across DataLoader workers. Iteration is
        infinite when cycle=True and max_samples_per_worker is None;
        otherwise it stops after at most
        max_samples_per_worker // batch_size batches.

        Yields:
            Batches as produced by MultiShardLoader.iter_batches, on the
            configured device
        """
        worker_shards = self._get_worker_shards()
        if not worker_shards:
            logger.warning("No shards assigned to this worker")
            return

        num_batches = None
        if self.max_samples_per_worker is not None:
            num_batches = self.max_samples_per_worker // self.batch_size
            if num_batches == 0:
                # Limit is smaller than one batch: yield nothing rather than
                # hand 0 to the loader, whose meaning for 0 is not defined here.
                logger.warning(
                    "max_samples_per_worker=%d is smaller than batch_size=%d; "
                    "no batches will be produced",
                    self.max_samples_per_worker,
                    self.batch_size,
                )
                return

        # The context manager is exited when this generator finishes or is
        # closed early by the consumer.
        with MultiShardLoader(
            shard_paths=worker_shards,
            device=self.device,
            prefetch=self.prefetch,
            cycle=self.cycle,
        ) as loader:
            yield from loader.iter_batches(
                batch_size=self.batch_size,
                num_batches=num_batches,
                to_gpu=True,
            )

    def __len__(self) -> int:
        """
        Get the approximate total number of batches.

        Estimated by opening the first shard and assuming every shard holds
        the same number of samples. The estimate ignores cycling and
        max_samples_per_worker, and may differ from the actual count.

        Returns:
            Approximate number of batches across all shards, or 0 if the
            estimate could not be computed (a warning is logged).
        """
        try:
            from .shard_loader import ShardLoader

            first_loader = ShardLoader(self.shard_paths[0], device=self.device)
            try:
                samples_per_shard = len(first_loader)
            finally:
                # Always release the probe loader, even if len() fails.
                first_loader.close()
        except Exception as e:
            # Best effort: callers get 0 instead of an exception.
            logger.warning("Failed to compute dataset length: %s", e)
            return 0

        total_samples = len(self.shard_paths) * samples_per_shard
        return total_samples // self.batch_size