""" Data loading and dataset management. Single Responsibility: Only handles data loading and dataset creation. Open/Closed: Can extend with new dataset formats without modifying existing code. Optimizations: - Lazy loading for memory efficiency - Sequence length bucketing for reduced padding overhead - LRU cache for batch files """ import gc import random import torch import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader, ConcatDataset, Sampler from pathlib import Path from typing import List, Dict, Any, Optional, Callable, Tuple, Iterator from .utils import log # ============================================================ # Sequence Length Bucketing (Reduces padding overhead) # ============================================================ class BucketBatchSampler(Sampler[List[int]]): """ Batch sampler that groups samples by sequence length into buckets. This reduces padding overhead by batching similar-length sequences together. Based on TensorFlow's bucket_by_sequence_length concept. Benefits: - Reduces wasted computation on padding tokens - More consistent memory usage per batch - Can improve training speed by 10-30% Args: lengths: List of sequence lengths for each sample batch_size: Number of samples per batch bucket_boundaries: Length boundaries for buckets (auto-computed if None) shuffle: Whether to shuffle within buckets drop_last: Whether to drop incomplete batches """ def __init__( self, lengths: List[int], batch_size: int, bucket_boundaries: Optional[List[int]] = None, shuffle: bool = True, drop_last: bool = False, ): self.lengths = lengths self.batch_size = batch_size self.shuffle = shuffle self.drop_last = drop_last # Auto-compute bucket boundaries if not provided if bucket_boundaries is None: # Create ~10 buckets based on length distribution sorted_lens = sorted(lengths) n = len(sorted_lens) bucket_boundaries = [ sorted_lens[int(n * p)] for p in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] ] # Remove duplicates and sort bucket_boundaries = sorted(set(bucket_boundaries)) self.bucket_boundaries = bucket_boundaries # Assign samples to buckets self.buckets: Dict[int, List[int]] = {i: [] for i in range(len(bucket_boundaries) + 1)} for idx, length in enumerate(lengths): bucket_id = self._get_bucket_id(length) self.buckets[bucket_id].append(idx) def _get_bucket_id(self, length: int) -> int: """Find which bucket a length belongs to.""" for i, boundary in enumerate(self.bucket_boundaries): if length <= boundary: return i return len(self.bucket_boundaries) def __iter__(self) -> Iterator[List[int]]: """Generate batches from buckets.""" # Collect all batches from all buckets all_batches = [] for bucket_id, indices in self.buckets.items(): if not indices: continue # Shuffle within bucket bucket_indices = indices.copy() if self.shuffle: random.shuffle(bucket_indices) # Create batches for i in range(0, len(bucket_indices), self.batch_size): batch = bucket_indices[i:i + self.batch_size] if len(batch) == self.batch_size or not self.drop_last: all_batches.append(batch) # Shuffle batches across buckets if self.shuffle: random.shuffle(all_batches) yield from all_batches def __len__(self) -> int: """Return total number of batches.""" total = 0 for indices in self.buckets.values(): n_batches = len(indices) // self.batch_size if not self.drop_last and len(indices) % self.batch_size != 0: n_batches += 1 total += n_batches return total # ============================================================ # Lazy Sharded Dataset (Memory Efficient) # ============================================================ class LazyShardedDataset(Dataset): """ Memory-efficient dataset that loads batch files on-demand. Instead of loading all data into memory, maintains an index of which sample is in which batch file, and loads batches as needed. Also stores approximate sequence lengths for bucketing optimization. """ def __init__( self, batch_files: List[Path], tokenizer, max_audio_len: int = 500, max_seq_len: int = 2048, cache_size: int = 3, # Number of batches to keep in memory verbose: bool = True ): self.batch_files = batch_files self.tokenizer = tokenizer self.max_audio = max_audio_len * 5 self.max_seq_len = max_seq_len self.cache_size = cache_size self.verbose = verbose # Build index: sample_idx -> (batch_idx, local_idx) self.index: List[Tuple[int, int]] = [] self.batch_sizes: List[int] = [] # Store sequence lengths for bucketing self.sequence_lengths: List[int] = [] if verbose: log(f" Indexing {len(batch_files)} batch files...") for batch_idx, bf in enumerate(batch_files): # Quick load to get size and lengths data = torch.load(bf, map_location="cpu", weights_only=False) batch_size = len(data) self.batch_sizes.append(batch_size) for local_idx in range(batch_size): self.index.append((batch_idx, local_idx)) # Estimate sequence length (SNAC tokens dominate) item = data[local_idx] snac_len = len(item.get("snac_tokens", [])) // 7 * 7 self.sequence_lengths.append(snac_len) del data if (batch_idx + 1) % 100 == 0: gc.collect() if verbose: log(f" Indexed {batch_idx+1}/{len(batch_files)} batches ({len(self.index):,} samples)") if verbose: log(f" Total indexed: {len(self.index):,} samples") # LRU cache for loaded batches self._cache: Dict[int, List[Dict]] = {} self._cache_order: List[int] = [] def get_sequence_lengths(self) -> List[int]: """Return sequence lengths for bucketing.""" return self.sequence_lengths def __len__(self) -> int: return len(self.index) def _load_batch(self, batch_idx: int) -> List[Dict]: """Load a batch file, using cache.""" if batch_idx in self._cache: # Move to end of cache order (most recently used) self._cache_order.remove(batch_idx) self._cache_order.append(batch_idx) return self._cache[batch_idx] # Load from disk data = torch.load(self.batch_files[batch_idx], map_location="cpu", weights_only=False) # Add to cache self._cache[batch_idx] = data self._cache_order.append(batch_idx) # Evict old batches if cache is full while len(self._cache_order) > self.cache_size: old_idx = self._cache_order.pop(0) del self._cache[old_idx] gc.collect() return data def __getitem__(self, idx: int) -> Dict[str, Any]: batch_idx, local_idx = self.index[idx] batch_data = self._load_batch(batch_idx) item = batch_data[local_idx] # Process item (same as InterleavedDataset) whisper = item["whisper_features"][:self.max_audio] text_tokens = self._get_text_tokens(item) snac_list = self._get_snac_tokens(item) word_alignments = item.get("word_alignments", None) answer_text = item.get("answer", "") return { "whisper": whisper, "text_tokens": text_tokens, "snac_tokens": snac_list, "word_alignments": word_alignments, "answer_text": answer_text } def _get_text_tokens(self, item: Dict[str, Any]) -> List[int]: if "text_tokens" in item and len(item["text_tokens"]) > 0: tt = item["text_tokens"] return tt.tolist() if hasattr(tt, 'tolist') else list(tt) text = item.get("answer", item.get("text", "")) if isinstance(text, str) and len(text) > 0: return self.tokenizer.encode(text, add_special_tokens=False) return [] def _get_snac_tokens(self, item: Dict[str, Any]) -> List[int]: snac = item["snac_tokens"] snac_len = (len(snac) // 7) * 7 snac = snac[:snac_len] if snac_len > 0 else snac[:7] return snac.tolist() if hasattr(snac, 'tolist') else list(snac) # ============================================================ # Sharded Dataset Loading (Industry Standard) # ============================================================ class ShardedDatasetLoader: """ Load datasets from single files or sharded batch directories. Supports: - Single .pt file - Batch directory with batch_*.pt files - Mixed (base file + batch files) Dependency Inversion: Uses abstract path interface, not concrete file operations. """ def __init__(self, verbose: bool = True): self.verbose = verbose def load(self, path: str) -> List[Dict[str, Any]]: """Load dataset from path (file or directory).""" path = Path(path) samples = [] # Case 1: Explicit batches directory if path.name.endswith('.batches') and path.is_dir(): samples = self._load_batches_dir(path) # Case 2: Single file (possibly with batches) elif path.exists() and path.is_file(): samples = self._load_file_with_batches(path) # Case 3: Only batches directory exists else: batches_dir = Path(f"{path}.batches") if batches_dir.exists() and batches_dir.is_dir(): samples = self._load_batches_dir(batches_dir) else: raise FileNotFoundError(f"No dataset found at {path}") return samples def _load_batches_dir(self, batches_dir: Path) -> List[Dict[str, Any]]: """Load all batch files from a directory.""" batch_files = sorted(batches_dir.glob("batch_*.pt")) if not batch_files: raise FileNotFoundError(f"No batch files in {batches_dir}") if self.verbose: log(f" Loading {len(batch_files)} batch files from {batches_dir.name}/") samples = [] for i, bf in enumerate(batch_files): items = torch.load(bf, map_location="cpu", weights_only=False) samples.extend(items) del items if (i + 1) % 100 == 0: gc.collect() if self.verbose: log(f" Loaded {i+1}/{len(batch_files)} batches ({len(samples):,} samples)") return samples def _load_file_with_batches(self, path: Path) -> List[Dict[str, Any]]: """Load base file and any associated batch files.""" # Load base file base = torch.load(path, map_location="cpu", weights_only=False, mmap=True) samples = list(base) del base gc.collect() if self.verbose: log(f" Base file: {len(samples):,} samples") # Check for batch files batches_dir = Path(f"{path}.batches") if batches_dir.exists() and batches_dir.is_dir(): batch_files = sorted(batches_dir.glob("batch_*.pt")) if batch_files: if self.verbose: log(f" Found {len(batch_files)} batch files") for i, bf in enumerate(batch_files): items = torch.load(bf, map_location="cpu", weights_only=False) samples.extend(items) del items if (i + 1) % 100 == 0: gc.collect() if self.verbose: log(f" Loaded {i+1}/{len(batch_files)} batches ({len(samples):,} total)") return samples def load_sharded_dataset(path: str, verbose: bool = True) -> List[Dict[str, Any]]: """Convenience function for loading sharded datasets.""" loader = ShardedDatasetLoader(verbose=verbose) return loader.load(path) # ============================================================ # Dataset Classes # ============================================================ class InterleavedDataset(Dataset): """ Dataset that prepares samples for interleaved training. Single Responsibility: Only handles sample access and preprocessing. """ def __init__( self, data: List[Dict[str, Any]], tokenizer, max_audio_len: int = 500, max_seq_len: int = 2048 ): self.data = data self.tokenizer = tokenizer self.max_audio = max_audio_len * 5 # Account for downsampling self.max_seq_len = max_seq_len def __len__(self) -> int: return len(self.data) def __getitem__(self, idx: int) -> Dict[str, Any]: item = self.data[idx] # Whisper features (truncate if needed) whisper = item["whisper_features"][:self.max_audio] # Text tokens - use pre-computed if available text_tokens = self._get_text_tokens(item) # SNAC tokens (ensure multiple of 7) snac_list = self._get_snac_tokens(item) # Optional fields word_alignments = item.get("word_alignments", None) answer_text = item.get("answer", "") return { "whisper": whisper, "text_tokens": text_tokens, "snac_tokens": snac_list, "word_alignments": word_alignments, "answer_text": answer_text } def _get_text_tokens(self, item: Dict[str, Any]) -> List[int]: """Extract or generate text tokens from item.""" if "text_tokens" in item and len(item["text_tokens"]) > 0: tt = item["text_tokens"] return tt.tolist() if hasattr(tt, 'tolist') else list(tt) text = item.get("answer", item.get("text", "")) if isinstance(text, str) and len(text) > 0: return self.tokenizer.encode(text, add_special_tokens=False) return [] def _get_snac_tokens(self, item: Dict[str, Any]) -> List[int]: """Extract SNAC tokens, ensuring multiple of 7.""" snac = item["snac_tokens"] snac_len = (len(snac) // 7) * 7 snac = snac[:snac_len] if snac_len > 0 else snac[:7] return snac.tolist() if hasattr(snac, 'tolist') else list(snac) # ============================================================ # Collate Functions # ============================================================ def collate_simple(batch: List[Dict[str, Any]]) -> Dict[str, Any]: """ Simple collate that pads whisper features. Interleaving is done in training loop for correct text_ratio. """ max_w = max(b["whisper"].shape[0] for b in batch) max_w = ((max_w + 4) // 5) * 5 # Align to downsample factor whisper_batch = [] raw_data = [] for b in batch: w = b["whisper"] w_pad = F.pad(w, (0, 0, 0, max_w - w.shape[0])) whisper_batch.append(w_pad) raw_data.append({ "text_tokens": b["text_tokens"], "snac_tokens": b["snac_tokens"], "word_alignments": b.get("word_alignments"), "answer_text": b.get("answer_text", "") }) return { "whisper": torch.stack(whisper_batch), "raw_data": raw_data } # ============================================================ # DataLoader Factory # ============================================================ class DataLoaderFactory: """ Factory for creating DataLoaders with optimal settings. Single Responsibility: Only handles DataLoader creation. Open/Closed: Can extend multiprocessing strategies without modification. Optimizations: - Sequence length bucketing for reduced padding - Optimal worker count based on system resources """ @staticmethod def get_optimal_workers() -> int: """Calculate optimal num_workers based on system resources.""" if not torch.cuda.is_available(): return 0 try: import os num_gpus = torch.cuda.device_count() cpu_cores = os.cpu_count() or 4 max_workers = max(1, cpu_cores // 2) # 2 workers per GPU, capped by CPU ideal_workers = num_gpus * 2 num_workers = min(ideal_workers, max_workers) # Check VRAM pressure try: free_bytes, _ = torch.cuda.mem_get_info(0) if free_bytes / 1024**3 < 5: num_workers = min(num_workers, 1) except Exception: pass return max(0, num_workers) except Exception: return 2 @classmethod def create( cls, dataset: Dataset, batch_size: int, shuffle: bool = True, collate_fn: Callable = collate_simple, verbose: bool = True, use_bucketing: bool = True, ) -> DataLoader: """ Create DataLoader with optimal settings. Args: dataset: The dataset to load from batch_size: Batch size shuffle: Whether to shuffle data collate_fn: Function to collate samples verbose: Whether to log details use_bucketing: Use sequence length bucketing (reduces padding overhead) """ optimal_workers = cls.get_optimal_workers() # Try to use sequence length bucketing batch_sampler = None if use_bucketing and shuffle: try: # Get sequence lengths from dataset lengths = cls._get_sequence_lengths(dataset) if lengths and len(lengths) > batch_size * 10: # Only bucket if enough samples batch_sampler = BucketBatchSampler( lengths=lengths, batch_size=batch_size, shuffle=True, drop_last=False, ) if verbose: log(f"[DataLoader] Using sequence length bucketing ({len(batch_sampler)} batches)") except Exception as e: if verbose: log(f"[DataLoader] Bucketing failed: {e}, using standard batching") # Try different multiprocessing methods for mp_method in ['spawn', 'fork', None]: try: loader = cls._try_create_loader( dataset, batch_size, shuffle, collate_fn, optimal_workers, mp_method, batch_sampler ) if loader is not None: if verbose and mp_method: log(f"[DataLoader] Using '{mp_method}' with {optimal_workers} workers") elif verbose: log("[DataLoader] Using single-process mode") return loader except Exception as e: if verbose: log(f"[DataLoader] {mp_method} failed: {str(e)[:50]}...") continue # Final fallback if verbose: log("[DataLoader] Fallback to num_workers=0") return DataLoader( dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, num_workers=0, pin_memory=True ) @staticmethod def _get_sequence_lengths(dataset: Dataset) -> Optional[List[int]]: """Extract sequence lengths from dataset for bucketing.""" # Check if dataset has get_sequence_lengths method if hasattr(dataset, 'get_sequence_lengths'): return dataset.get_sequence_lengths() # For ConcatDataset, try to combine lengths from components if isinstance(dataset, ConcatDataset): lengths = [] for ds in dataset.datasets: if hasattr(ds, 'get_sequence_lengths'): lengths.extend(ds.get_sequence_lengths()) else: return None # Can't get lengths for all, skip bucketing return lengths return None @staticmethod def _try_create_loader( dataset: Dataset, batch_size: int, shuffle: bool, collate_fn: Callable, num_workers: int, mp_method: Optional[str], batch_sampler: Optional[Sampler] = None, ) -> Optional[DataLoader]: """Try to create DataLoader with given settings.""" import multiprocessing # When using batch_sampler, don't set batch_size, shuffle, or sampler common_kwargs = { 'collate_fn': collate_fn, 'pin_memory': True, } if batch_sampler is not None: common_kwargs['batch_sampler'] = batch_sampler else: common_kwargs['batch_size'] = batch_size common_kwargs['shuffle'] = shuffle if mp_method and num_workers > 0: mp_context = multiprocessing.get_context(mp_method) loader = DataLoader( dataset, num_workers=num_workers, multiprocessing_context=mp_context, persistent_workers=True, **common_kwargs ) # Test if it works test_iter = iter(loader) del test_iter return loader else: return DataLoader( dataset, num_workers=0, **common_kwargs ) def create_dataloader( dataset: Dataset, batch_size: int, shuffle: bool = True, verbose: bool = True, use_bucketing: bool = True, ) -> DataLoader: """ Convenience function for creating DataLoaders. Args: dataset: Dataset to load from batch_size: Batch size shuffle: Whether to shuffle verbose: Whether to log use_bucketing: Use sequence length bucketing """ return DataLoaderFactory.create( dataset, batch_size, shuffle, collate_fn=collate_simple, verbose=verbose, use_bucketing=use_bucketing, ) # ============================================================ # Dataset Loading Pipeline # ============================================================ def load_datasets( paths: List[str], tokenizer, max_audio_len: int = 500, max_seq_len: int = 2048, verbose: bool = True, lazy_loading: bool = True, # Use memory-efficient lazy loading ) -> Dataset: """ Load and combine multiple datasets. Args: paths: List of dataset paths (files or directories) tokenizer: Tokenizer for text encoding max_audio_len: Maximum audio length max_seq_len: Maximum sequence length verbose: Whether to log progress lazy_loading: Use memory-efficient lazy loading for batch directories Returns: Combined dataset """ if verbose: log("\nLoading datasets (lazy loading enabled)..." if lazy_loading else "\nLoading datasets...") all_datasets = [] for path in paths: path = Path(path) try: # Check if this is a batch directory (use lazy loading) batches_dir = None if path.name.endswith('.batches') and path.is_dir(): batches_dir = path elif Path(f"{path}.batches").exists(): batches_dir = Path(f"{path}.batches") elif path.is_dir(): batch_files = list(path.glob("batch_*.pt")) if batch_files: batches_dir = path if batches_dir and lazy_loading: # Use memory-efficient lazy loading batch_files = sorted(batches_dir.glob("batch_*.pt")) if batch_files: dataset = LazyShardedDataset( batch_files, tokenizer, max_audio_len=max_audio_len, max_seq_len=max_seq_len, cache_size=5, # Keep 5 batches in memory verbose=verbose ) all_datasets.append(dataset) if verbose: log(f" {path.name}: {len(dataset):,} samples (lazy loading)") else: # Fall back to full loading for single files loader = ShardedDatasetLoader(verbose=verbose) data = loader.load(str(path)) if data: dataset = InterleavedDataset( data, tokenizer, max_audio_len=max_audio_len, max_seq_len=max_seq_len ) all_datasets.append(dataset) if verbose: log(f" {path.name}: {len(data):,} samples") except FileNotFoundError as e: if verbose: log(f" [WARN] {e}") if not all_datasets: raise ValueError("No datasets loaded!") return ConcatDataset(all_datasets) if len(all_datasets) > 1 else all_datasets[0]