| import torch |
| import torch.utils.data |
| from pathlib import Path |
| from typing import Union, Optional, Iterator, List |
| import logging |
|
|
| from .shard_loader import MultiShardLoader |
|
|
# Module-scoped logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
class StreamingDataset(torch.utils.data.IterableDataset):
    """
    Streaming dataset over a list of shard files.

    Implements torch.utils.data.IterableDataset. Iteration is delegated to
    MultiShardLoader, which is configured with this dataset's ``device``,
    ``prefetch`` and ``cycle`` settings.

    Design:
    - Uses MultiShardLoader for background prefetching and cycling
    - Partitions shards round-robin (by index) across DataLoader *workers*,
      so no two workers of the same DataLoader read the same shard
    - Does NOT partition across distributed (DDP) ranks; give each rank its
      own subset of ``shard_paths`` if that is required
    - Cycles through shards infinitely when ``cycle=True``
    - Batches are requested from MultiShardLoader with ``to_gpu=True``
    - Memory use depends on MultiShardLoader (presumably only active shard
      chunks are resident in RAM -- confirm against shard_loader)

    Reference: DeepSeek V3 training infrastructure (2025)

    Example:
        >>> dataset = StreamingDataset(
        ...     shard_paths=["shard_00.pt", "shard_01.pt"],
        ...     batch_size=32,
        ...     device="cuda"
        ... )
        >>> loader = torch.utils.data.DataLoader(
        ...     dataset,
        ...     batch_size=None,  # dataset handles batching
        ...     num_workers=4,
        ...     pin_memory=True,
        ...     persistent_workers=True
        ... )
        >>> for batch in loader:
        ...     process(batch)  # batch already on GPU

    NOTE(review): PyTorch advises against returning CUDA tensors from
    multi-process DataLoader workers, and ``pin_memory`` only applies to CPU
    tensors. If MultiShardLoader moves batches to CUDA inside the workers, the
    example above (device="cuda", num_workers=4, pin_memory=True) may fail.
    Confirm this, or use num_workers=0, or device="cpu" with pin_memory=True.
    """

    def __init__(
        self,
        shard_paths: list[Union[str, Path]],
        batch_size: int = 32,
        device: str = "cuda",
        prefetch: bool = True,
        cycle: bool = True,
        max_samples_per_worker: Optional[int] = None,
    ) -> None:
        """
        Initialize streaming dataset with shard paths.

        Args:
            shard_paths: Paths to shard files (.pt or .npy). Any iterable of
                str/Path is accepted and is materialized into a list.
            batch_size: Number of samples per batch (default: 32)
            device: Target device passed to MultiShardLoader (default: "cuda")
            prefetch: Enable background prefetching (default: True)
            cycle: Cycle through shards infinitely (default: True)
            max_samples_per_worker: Cap on the samples each worker yields,
                rounded down to whole batches (None = no cap)

        Raises:
            ValueError: If ``shard_paths`` is empty, ``batch_size`` is not
                positive, or ``max_samples_per_worker`` is negative.
            FileNotFoundError: If any shard file does not exist.
        """
        # Materialize first so generators/iterators work and are consumed once.
        paths = [Path(p) for p in shard_paths]

        # Explicit raises rather than `assert`, which `python -O` strips.
        if not paths:
            raise ValueError("Must provide at least one shard path")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if max_samples_per_worker is not None and max_samples_per_worker < 0:
            raise ValueError(
                "max_samples_per_worker must be non-negative or None, "
                f"got {max_samples_per_worker}"
            )

        # Fail fast on missing files instead of inside a DataLoader worker.
        for path in paths:
            if not path.exists():
                raise FileNotFoundError(f"Shard file not found: {path}")

        self.shard_paths = paths
        self.batch_size = batch_size
        self.device = device
        self.prefetch = prefetch
        self.cycle = cycle
        self.max_samples_per_worker = max_samples_per_worker

        if max_samples_per_worker is not None and max_samples_per_worker < batch_size:
            logger.warning(
                "max_samples_per_worker=%s is smaller than batch_size=%s; "
                "each worker will yield no batches",
                max_samples_per_worker,
                batch_size,
            )

        logger.info(
            "Initialized StreamingDataset with %s shards, batch_size=%s, device=%s",
            len(paths),
            batch_size,
            device,
        )

    def _get_worker_shards(self) -> list[Path]:
        """
        Return the shards this DataLoader worker should read.

        Shards are dealt round-robin by index: worker ``k`` of ``n`` gets
        indices ``k, k + n, k + 2n, ...``. The subsets are therefore disjoint
        within one DataLoader. When ``num_workers`` exceeds the number of
        shards, some workers get an empty list.

        Returns:
            List of shard paths assigned to this worker. In the main process
            (no worker info) this is a copy of all shard paths.
        """
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is None:
            # Single-process loading: this process reads every shard.
            return list(self.shard_paths)

        num_workers = worker_info.num_workers
        worker_id = worker_info.id

        # Extended slice == "keep index i where i % num_workers == worker_id".
        worker_shards = self.shard_paths[worker_id::num_workers]

        logger.debug(
            "Worker %s/%s assigned %s shards",
            worker_id,
            num_workers,
            len(worker_shards),
        )
        return worker_shards

    def __iter__(self) -> Iterator[Union[torch.Tensor, List]]:
        """
        Iterate over batches from this worker's assigned shards.

        Shards are partitioned across DataLoader workers via
        ``_get_worker_shards``. The stream is infinite when ``cycle=True`` and
        ``max_samples_per_worker`` is None.

        Yields:
            Batches produced by ``MultiShardLoader.iter_batches`` for
            ``self.device``.
        """
        worker_shards = self._get_worker_shards()

        if not worker_shards:
            logger.warning("No shards assigned to this worker")
            return

        num_batches = None
        if self.max_samples_per_worker is not None:
            # Round down so the cap on samples is never exceeded.
            num_batches = self.max_samples_per_worker // self.batch_size
            if num_batches == 0:
                # A cap below one batch means this worker yields nothing. Stop
                # here rather than pass num_batches=0 to MultiShardLoader,
                # whose handling of 0 is not visible from this module.
                return

        with MultiShardLoader(
            shard_paths=worker_shards,
            device=self.device,
            prefetch=self.prefetch,
            cycle=self.cycle,
        ) as loader:
            # NOTE(review): to_gpu is always True, even when device is not a
            # GPU; confirm how MultiShardLoader treats this for device="cpu".
            yield from loader.iter_batches(
                batch_size=self.batch_size,
                num_batches=num_batches,
                to_gpu=True,
            )

    def __len__(self) -> int:
        """
        Approximate the number of batches in one pass over all shards.

        Estimate = ``len(shard_paths) * samples_in_first_shard // batch_size``.
        This assumes every shard holds as many samples as the first one. It
        ignores:
        - per-worker partitioning (each worker floors its own batch count);
        - ``cycle`` (with ``cycle=True`` iteration does not end by itself);
        - ``max_samples_per_worker``.

        The first shard is opened with ``ShardLoader(..., device=self.device)``
        just to count its samples, so each call has an I/O cost.

        Returns:
            Approximate number of batches, or 0 if the estimate could not be
            computed (the failure is logged as a warning).
        """
        try:
            # Imported locally; only this method uses ShardLoader.
            from .shard_loader import ShardLoader

            first_loader = ShardLoader(self.shard_paths[0], device=self.device)
            try:
                samples_per_shard = len(first_loader)
            finally:
                # Close even if len() raises, so the shard handle is not leaked.
                first_loader.close()
        except Exception as e:  # best-effort estimate: log and fall back to 0
            logger.warning("Failed to compute dataset length: %s", e)
            return 0

        total_samples = len(self.shard_paths) * samples_per_shard
        return total_samples // self.batch_size
|
|