sem-v6-training / src /sem_v6 /data /shard_loader.py
icarus112's picture
Upload folder using huggingface_hub
518db7a verified
import torch
import numpy as np
from pathlib import Path
from typing import Union, Optional, Iterator, Any, Sequence, cast, List
import threading
import logging
# Module-level logger shared by ShardLoader and MultiShardLoader below.
logger: logging.Logger = logging.getLogger(__name__)
class ShardLoader:
    """
    Memory-mapped shard loader for efficient streaming data access.

    Shard data is loaded lazily on first access of ``data``. ``.npy`` files
    are opened with ``np.load(..., mmap_mode='r')`` so rows are read from disk
    only when sliced; ``.pt`` files are deserialized in full by ``torch.load``
    onto the CPU (deferred until first access, but not memory-mapped).
    An optional background thread can prefetch the next shard so that
    switching shards does not stall the caller.

    Example:
        >>> loader = ShardLoader("data/shard_00.npy", device="cuda")
        >>> batch = loader.get_batch(start=0, size=32)
        >>> loader.prefetch_next("data/shard_01.npy")
    """

    # File suffixes accepted by the constructor and by _read_shard.
    _SUPPORTED_SUFFIXES = (".pt", ".npy")

    def __init__(
        self,
        shard_path: Union[str, Path],
        device: str = "cuda",
        prefetch: bool = True,
    ) -> None:
        """
        Initialize the loader; no shard data is read until it is accessed.

        Args:
            shard_path: Path to shard file (.pt or .npy format)
            device: Target device for tensor operations (default: "cuda")
            prefetch: Enable background prefetching for next shard (default: True)

        Raises:
            FileNotFoundError: If ``shard_path`` does not exist.
            ValueError: If ``shard_path`` is not a .pt or .npy file.
        """
        self.shard_path = Path(shard_path)
        self.device = torch.device(device)
        self.prefetch_enabled = prefetch

        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        if not self.shard_path.exists():
            raise FileNotFoundError(f"Shard file not found: {shard_path}")
        if self.shard_path.suffix not in self._SUPPORTED_SUFFIXES:
            raise ValueError(
                f"Unsupported format: {self.shard_path.suffix}. Use .pt or .npy"
            )

        # Lazily loaded data (populated on first access of ``data``).
        self._data: Optional[Union[np.ndarray, torch.Tensor, List]] = None
        self._data_loaded = False

        # Prefetch state. ``_prefetch_data`` is written by the worker thread
        # and read by swap_to_prefetched(), both under ``_prefetch_lock``.
        self._prefetch_thread: Optional[threading.Thread] = None
        self._prefetch_data: Optional[Union[np.ndarray, torch.Tensor, List]] = None
        self._prefetch_path: Optional[Path] = None
        self._prefetch_lock = threading.Lock()

        logger.info("Initialized ShardLoader for %s", self.shard_path.name)

    @staticmethod
    def _read_shard(path: Path) -> Union[np.ndarray, torch.Tensor, List]:
        """
        Read one shard from disk (shared by the foreground and prefetch paths).

        Args:
            path: Shard file path; its suffix selects the reader.

        Returns:
            Read-only memory-mapped array for .npy, or whatever object
            ``torch.load`` returns (mapped to CPU) for .pt.

        Raises:
            ValueError: If the suffix is not .npy or .pt.
        """
        if path.suffix == ".npy":
            # mmap_mode='r' maps the file read-only; pages are read on access.
            return cast(np.ndarray, np.load(path, mmap_mode="r"))
        if path.suffix == ".pt":
            # torch.load deserializes the whole file; map_location keeps it on CPU.
            return cast(Union[torch.Tensor, List], torch.load(path, map_location="cpu"))
        raise ValueError(f"Unsupported format: {path.suffix}")

    def _load_data(self) -> Union[np.ndarray, torch.Tensor, List]:
        """
        Load this loader's shard and log a short description of it.

        Returns:
            Memory-mapped array (.npy), tensor (.pt), or list (for text data)

        Raises:
            ValueError: If the shard suffix is unsupported.
        """
        data = self._read_shard(self.shard_path)
        name = self.shard_path.name
        if self.shard_path.suffix == ".npy":
            logger.debug(
                "Memory-mapped .npy shard: %s, shape=%s, dtype=%s",
                name, data.shape, data.dtype,  # type: ignore[union-attr]
            )
        elif isinstance(data, torch.Tensor):
            logger.debug(
                "Loaded .pt tensor: %s, shape=%s, dtype=%s",
                name, data.shape, data.dtype,
            )
        elif isinstance(data, list):
            logger.debug(
                "Loaded .pt list: %s, length=%s, type=%s",
                name, len(data), type(data[0]) if data else "empty",
            )
        else:
            logger.debug("Loaded .pt shard: %s, type=%s", name, type(data))
        return data

    @property
    def data(self) -> Union[np.ndarray, torch.Tensor, List]:
        """
        Shard data, loaded on first access and cached afterwards.

        Returns:
            Shard data (memory-mapped array, tensor, or list)
        """
        if not self._data_loaded:
            self._data = self._load_data()
            self._data_loaded = True
        assert self._data is not None, "Data should be loaded"
        return self._data

    def get_batch(
        self,
        start: int,
        size: int,
        to_gpu: bool = True
    ) -> Union[torch.Tensor, List]:
        """
        Extract ``data[start:start + size]`` (clipped to the shard length).

        Args:
            start: Starting index in shard
            size: Batch size (number of samples)
            to_gpu: Move tensor batches to ``self.device`` when it is not the
                CPU (default: True)

        Returns:
            A tensor (for array/tensor shards) or a list (for list shards).
            The batch may be shorter than ``size`` at the end of the shard.
        """
        data = self.data
        end = min(start + size, len(data))
        batch_data = data[start:end]

        if isinstance(batch_data, list):
            # List shards (e.g. text) are returned as-is, never moved to a device.
            return batch_data

        if isinstance(batch_data, np.ndarray):
            # np.array copies the read-only memmap slice into RAM so the tensor
            # owns writable memory.
            batch = torch.from_numpy(np.array(batch_data))
        else:
            # Already a tensor.
            batch = batch_data

        # Any non-CPU target (cuda, mps, ...) gets the batch. non_blocking only
        # overlaps the copy when the source memory is pinned.
        if to_gpu and self.device.type != "cpu":
            batch = batch.to(self.device, non_blocking=True)
        return batch

    def __len__(self) -> int:
        """
        Number of samples in the shard (``len`` of the loaded data).
        """
        return len(self.data)

    def prefetch_next(self, next_shard_path: Union[str, Path]) -> None:
        """
        Start loading another shard in a background thread.

        Does nothing when prefetching is disabled. Blocks until any previous
        prefetch thread has finished before starting the new one.

        Args:
            next_shard_path: Path to next shard file to prefetch
        """
        if not self.prefetch_enabled:
            return
        next_path = Path(next_shard_path)

        # Only one prefetch in flight at a time.
        if self._prefetch_thread is not None:
            self._prefetch_thread.join()

        self._prefetch_path = next_path
        self._prefetch_thread = threading.Thread(
            target=self._background_prefetch,
            daemon=True,
        )
        self._prefetch_thread.start()
        logger.debug("Started prefetch for %s", next_path.name)

    def _background_prefetch(self) -> None:
        """
        Worker body: read ``_prefetch_path`` and publish the result.

        Best-effort: on failure the error is logged with its traceback and
        ``_prefetch_data`` is left as None, so swap_to_prefetched() returns
        False and the caller can fall back to a direct load.
        """
        path = self._prefetch_path
        if path is None:
            return
        try:
            data = self._read_shard(path)
        except Exception:
            # Thread boundary: nothing above us can catch this.
            logger.exception("Prefetch failed for %s", path)
            with self._prefetch_lock:
                self._prefetch_data = None
            return
        with self._prefetch_lock:
            self._prefetch_data = data
        logger.debug("Prefetch completed: %s", path.name)

    def swap_to_prefetched(self) -> bool:
        """
        Replace the current shard with the prefetched one, if available.

        Waits for an in-flight prefetch to finish. Prefetch state is cleared
        whether or not the swap succeeds, so a failed prefetch does not linger.

        Returns:
            True if swap successful, False if no prefetched data available
        """
        if self._prefetch_thread is not None:
            self._prefetch_thread.join()
            self._prefetch_thread = None

        with self._prefetch_lock:
            data, path = self._prefetch_data, self._prefetch_path
            self._prefetch_data = None
            self._prefetch_path = None

        if data is None or path is None:
            return False

        self._data = data
        self.shard_path = path
        self._data_loaded = True
        logger.info("Swapped to prefetched shard: %s", self.shard_path.name)
        return True

    def close(self) -> None:
        """
        Wait for any background prefetch and drop all data references.

        The loader stays usable: accessing ``data`` afterwards reloads the
        current shard.
        """
        if self._prefetch_thread is not None:
            self._prefetch_thread.join()
            self._prefetch_thread = None
        # Clear references to allow garbage collection.
        self._data = None
        self._data_loaded = False
        with self._prefetch_lock:
            self._prefetch_data = None
            self._prefetch_path = None
        logger.debug("Closed ShardLoader for %s", self.shard_path.name)

    def __enter__(self) -> "ShardLoader":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Context manager exit; calls close()."""
        self.close()
class MultiShardLoader:
    """
    Stream batches over a list of shard files, optionally cycling forever.

    Holds one ShardLoader at a time and, when prefetching is enabled, asks it
    to load the following shard in the background while the current one is
    consumed.

    Example:
        >>> shard_paths = ["shard_00.npy", "shard_01.npy", "shard_02.npy"]
        >>> loader = MultiShardLoader(shard_paths, device="cuda")
        >>> for batch in loader.iter_batches(batch_size=32, num_batches=1000):
        ...     process(batch)
    """

    def __init__(
        self,
        shard_paths: Sequence[Union[str, Path]],
        device: str = "cuda",
        prefetch: bool = True,
        cycle: bool = True,
    ) -> None:
        """
        Initialize multi-shard loader (no shard data is read yet).

        Args:
            shard_paths: List of paths to shard files
            device: Target device for tensors (default: "cuda")
            prefetch: Enable prefetching (default: True)
            cycle: Cycle through shards infinitely (default: True)

        Raises:
            ValueError: If ``shard_paths`` is empty.
            FileNotFoundError: If any shard path does not exist.
        """
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        if len(shard_paths) == 0:
            raise ValueError("Must provide at least one shard path")
        self.shard_paths = [Path(p) for p in shard_paths]
        self.device = device
        self.prefetch_enabled = prefetch
        self.cycle = cycle

        for path in self.shard_paths:
            if not path.exists():
                raise FileNotFoundError(f"Shard file not found: {path}")

        self.current_shard_idx = 0
        self.current_loader: Optional[ShardLoader] = None
        logger.info("Initialized MultiShardLoader with %d shards", len(self.shard_paths))

    def _load_shard(self, idx: int) -> ShardLoader:
        """
        Create a ShardLoader for ``shard_paths[idx]``.

        Args:
            idx: Index into shard_paths list

        Returns:
            ShardLoader for the shard
        """
        return ShardLoader(
            self.shard_paths[idx],
            device=self.device,
            prefetch=self.prefetch_enabled,
        )

    def _next_shard_idx(self) -> Optional[int]:
        """
        Index of the shard after ``current_shard_idx``.

        Returns:
            Next shard index (wrapping to 0 when cycling), or None if at the
            last shard and not cycling
        """
        next_idx = self.current_shard_idx + 1
        if next_idx < len(self.shard_paths):
            return next_idx
        return 0 if self.cycle else None

    def _prefetch_following(self) -> None:
        """Ask the current loader to prefetch the shard after the current one."""
        if not self.prefetch_enabled or self.current_loader is None:
            return
        next_idx = self._next_shard_idx()
        if next_idx is not None:
            self.current_loader.prefetch_next(self.shard_paths[next_idx])

    def _advance_shard(self) -> bool:
        """
        Switch ``current_loader`` to the next shard and prefetch the one after.

        Uses the prefetched data when available, otherwise loads directly.

        Returns:
            True if a next shard exists, False at the end when not cycling.
        """
        next_idx = self._next_shard_idx()
        if next_idx is None:
            return False
        assert self.current_loader is not None, "current_loader must be set"
        if not self.current_loader.swap_to_prefetched():
            self.current_loader.close()
            self.current_loader = self._load_shard(next_idx)
        self.current_shard_idx = next_idx
        self._prefetch_following()
        return True

    def iter_batches(
        self,
        batch_size: int,
        num_batches: Optional[int] = None,
        to_gpu: bool = True,
    ) -> Iterator[Union[torch.Tensor, List]]:
        """
        Iterate over batches from all shards.

        Empty shards are skipped (no empty batches are yielded). If every
        shard is empty, iteration stops even when cycling.

        NOTE: the position within the current shard is local to each call,
        so a new call restarts from the beginning of the current shard.

        Args:
            batch_size: Number of samples per batch (must be positive)
            num_batches: Maximum number of batches (None = infinite)
            to_gpu: Transfer batches to the target device (default: True)

        Yields:
            Tensors (or lists, for list-valued shards). The last batch of a
            shard may be shorter than ``batch_size``.

        Raises:
            ValueError: If ``batch_size`` is not positive (raised on first
                iteration, since this is a generator).
        """
        # A non-positive step would never advance current_pos -> infinite loop.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        batches_yielded = 0

        # Load the first shard (and start prefetching the next) on first use.
        if self.current_loader is None:
            self.current_loader = self._load_shard(self.current_shard_idx)
            self._prefetch_following()

        current_pos = 0
        while num_batches is None or batches_yielded < num_batches:
            # Move past exhausted or empty shards. After len(shard_paths) hops
            # while still exhausted, every shard has been seen empty; stop
            # instead of spinning forever when cycling.
            hops = 0
            while current_pos >= len(self.current_loader):
                if hops >= len(self.shard_paths) or not self._advance_shard():
                    return
                hops += 1
                current_pos = 0

            batch = self.current_loader.get_batch(
                start=current_pos,
                size=batch_size,
                to_gpu=to_gpu,
            )
            current_pos += batch_size
            batches_yielded += 1
            yield batch

    def close(self) -> None:
        """
        Close the current ShardLoader (waits for its prefetch thread).
        """
        if self.current_loader is not None:
            self.current_loader.close()
        logger.info("Closed MultiShardLoader")

    def __enter__(self) -> "MultiShardLoader":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Context manager exit; calls close()."""
        self.close()