# Source: sem-v6-training/src/sem_v6/data/streaming_dataset.py
# Uploaded by icarus112 using huggingface_hub (commit 518db7a, verified).
import torch
import torch.utils.data
from pathlib import Path
from typing import Union, Optional, Iterator, List
import logging
from .shard_loader import MultiShardLoader
# Module-level logger named after this module, so its records follow the
# standard logging hierarchy and can be configured per module.
logger: logging.Logger = logging.getLogger(name=__name__)
class StreamingDataset(torch.utils.data.IterableDataset):
    """
    Streaming dataset that yields ready-made batches from a list of shard files.

    Implements ``torch.utils.data.IterableDataset``. Batching happens inside the
    dataset, so wrap it in a ``DataLoader`` with ``batch_size=None``.

    Design:
    - Delegates loading to ``MultiShardLoader``; the ``prefetch`` and ``cycle``
      flags are forwarded to it (background prefetching and cycling over
      shards are implemented there).
    - Partitions shards round-robin across DataLoader worker processes so no
      two workers of the same DataLoader read the same shard.
    - Batches are requested with ``to_gpu=True`` for the configured ``device``.
    - Memory use is governed by ``MultiShardLoader`` (intended to keep only the
      active shard chunks resident).

    NOTE(review): partitioning only consults ``torch.utils.data.get_worker_info()``.
    Shards are NOT split across ``torch.distributed`` ranks, so every rank reads
    the full shard list unless that is handled elsewhere. Confirm.

    NOTE(review): with ``num_workers > 0`` and a CUDA ``device``, batches may be
    moved to the GPU inside DataLoader worker subprocesses. PyTorch does not
    support that with the default ``fork`` start method once CUDA is
    initialized, and ``pin_memory=True`` only pins CPU tensors. Confirm how
    ``MultiShardLoader`` handles ``to_gpu`` before combining these options as in
    the example below.

    Reference: DeepSeek V3 training infrastructure (2025)

    Example:
        >>> dataset = StreamingDataset(
        ...     shard_paths=["shard_00.pt", "shard_01.pt"],
        ...     batch_size=32,
        ...     device="cuda"
        ... )
        >>> loader = torch.utils.data.DataLoader(
        ...     dataset,
        ...     batch_size=None,  # dataset handles batching
        ...     num_workers=4,
        ...     pin_memory=True,
        ...     persistent_workers=True
        ... )
        >>> for batch in loader:
        ...     process(batch)
    """

    def __init__(
        self,
        shard_paths: Union[str, Path, list[Union[str, Path]]],
        batch_size: int = 32,
        device: str = "cuda",
        prefetch: bool = True,
        cycle: bool = True,
        max_samples_per_worker: Optional[int] = None,
    ) -> None:
        """
        Validate arguments and record the shard list.

        Args:
            shard_paths: Paths to shard files (.pt or .npy). Any iterable of
                str/Path is accepted. A single str/Path is treated as one shard.
            batch_size: Number of samples per batch (default: 32). Must be > 0.
            device: Target device handed to the shard loader (default: "cuda").
            prefetch: Enable background prefetching in MultiShardLoader
                (default: True).
            cycle: Ask MultiShardLoader to cycle through shards (default: True).
            max_samples_per_worker: Cap on samples each worker yields, rounded
                down to whole batches (None = no cap). Must be >= 0.

        Raises:
            ValueError: If no shard paths are given, ``batch_size`` is not
                positive, or ``max_samples_per_worker`` is negative.
            FileNotFoundError: If any shard path does not exist.
        """
        # A bare str would otherwise be iterated character by character.
        if isinstance(shard_paths, (str, Path)):
            shard_paths = [shard_paths]
        # Materialize first so generators and tuples work with the checks below.
        self.shard_paths = [Path(p) for p in shard_paths]

        # Raise explicitly: `assert` statements are stripped under `python -O`.
        if not self.shard_paths:
            raise ValueError("Must provide at least one shard path")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if max_samples_per_worker is not None and max_samples_per_worker < 0:
            raise ValueError(
                "max_samples_per_worker must be non-negative or None, "
                f"got {max_samples_per_worker}"
            )

        self.batch_size = batch_size
        self.device = device
        self.prefetch = prefetch
        self.cycle = cycle
        self.max_samples_per_worker = max_samples_per_worker

        # Fail fast on missing shards rather than inside a worker process.
        for path in self.shard_paths:
            if not path.exists():
                raise FileNotFoundError(f"Shard file not found: {path}")

        logger.info(
            "Initialized StreamingDataset with %d shards, batch_size=%s, device=%s",
            len(self.shard_paths),
            batch_size,
            device,
        )

    def _get_worker_shards(self) -> list[Path]:
        """
        Return the subset of shards this DataLoader worker should read.

        Outside a DataLoader worker (``get_worker_info()`` is None) every shard
        is returned. Inside a worker, shards are dealt round-robin by worker id,
        so the subsets are disjoint and together cover all shards.

        Returns:
            List of shard paths assigned to the calling worker (may be empty
            when there are more workers than shards).
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process data loading: this process reads everything.
            return self.shard_paths

        worker_id = worker_info.id
        num_workers = worker_info.num_workers
        # Round-robin: worker k takes shards k, k + W, k + 2W, ...
        worker_shards = self.shard_paths[worker_id::num_workers]
        logger.debug(
            "Worker %d/%d assigned %d shards",
            worker_id,
            num_workers,
            len(worker_shards),
        )
        return worker_shards

    def __iter__(self) -> Iterator[Union[torch.Tensor, List]]:
        """
        Yield batches from the shards assigned to the current worker.

        Shards are partitioned across DataLoader workers first. The worker's
        shards are then streamed through ``MultiShardLoader``, which cycles over
        them when ``cycle=True``.

        Yields:
            Batches produced by ``MultiShardLoader.iter_batches`` for
            ``self.device``.
        """
        worker_shards = self._get_worker_shards()
        if not worker_shards:
            logger.warning("No shards assigned to this worker")
            return

        # Translate the per-worker sample cap into a whole number of batches.
        num_batches: Optional[int] = None
        if self.max_samples_per_worker is not None:
            num_batches = self.max_samples_per_worker // self.batch_size
            if num_batches == 0:
                # The cap is smaller than one batch, so there is nothing to yield.
                # Return here instead of passing 0 downstream, where a falsy
                # check could read it as "no limit", and skip building a loader.
                logger.warning(
                    "max_samples_per_worker=%d is smaller than batch_size=%d; "
                    "yielding no batches",
                    self.max_samples_per_worker,
                    self.batch_size,
                )
                return

        # The context manager guarantees the loader is cleaned up, even when the
        # consumer stops iterating early and the generator is closed.
        with MultiShardLoader(
            shard_paths=worker_shards,
            device=self.device,
            prefetch=self.prefetch,
            cycle=self.cycle,
        ) as loader:
            yield from loader.iter_batches(
                batch_size=self.batch_size,
                num_batches=num_batches,
                to_gpu=True,
            )

    def __len__(self) -> int:
        """
        Estimate the total number of batches across all shards.

        The estimate opens the first shard and takes its sample count as
        representative of every shard (assumes all shards have similar size).
        It then floor-divides the total by ``batch_size``. It ignores how shards
        are split across workers and ``max_samples_per_worker``. It also ignores
        that ``cycle=True`` may make the stream unbounded.

        Returns:
            Approximate number of batches, or 0 if the first shard could not be
            opened or measured (the error is logged as a warning).
        """
        try:
            # Local import: only this estimate needs the single-shard loader.
            from .shard_loader import ShardLoader

            # NOTE(review): the shard is opened on `self.device` just to count
            # samples; opening it on CPU may be cheaper. Confirm ShardLoader
            # semantics before changing this.
            first_loader = ShardLoader(self.shard_paths[0], device=self.device)
            try:
                samples_per_shard = len(first_loader)
            finally:
                # Release the shard even if measuring it fails.
                first_loader.close()
        except Exception as e:  # best-effort length hint; never crash callers
            logger.warning("Failed to compute dataset length: %s", e)
            return 0

        total_samples = len(self.shard_paths) * samples_per_shard
        return total_samples // self.batch_size