| """ |
| Data loading and dataset management. |
| |
| Single Responsibility: Only handles data loading and dataset creation. |
| Open/Closed: Can extend with new dataset formats without modifying existing code. |
| |
| Optimizations: |
| - Lazy loading for memory efficiency |
| - Sequence length bucketing for reduced padding overhead |
| - LRU cache for batch files |
| """ |
|
|
| import gc |
| import random |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader, ConcatDataset, Sampler |
| from pathlib import Path |
| from typing import List, Dict, Any, Optional, Callable, Tuple, Iterator |
| from .utils import log |
|
|
|
|
| |
| |
| |
class BucketBatchSampler(Sampler[List[int]]):
    """
    Batch sampler that groups samples by sequence length into buckets.

    Every batch is drawn from a single bucket, so the samples inside a batch
    have similar lengths and need less padding.

    Args:
        lengths: Sequence length of each sample; position ``i`` describes sample ``i``.
        batch_size: Number of samples per batch (must be positive).
        bucket_boundaries: Inclusive upper length bound of each bucket. The values
            are sorted and de-duplicated before use. When None, they are taken
            from the 10th..90th percentiles of ``lengths``.
        shuffle: Whether to shuffle inside each bucket and across batches.
        drop_last: Whether to drop batches smaller than ``batch_size``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    # Percentiles of the length distribution used as default bucket boundaries.
    _DEFAULT_PERCENTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)

    def __init__(
        self,
        lengths: List[int],
        batch_size: int,
        bucket_boundaries: Optional[List[int]] = None,
        shuffle: bool = True,
        drop_last: bool = False,
    ):
        # Fail early: a non-positive batch size would otherwise only surface
        # later as a ZeroDivisionError / range() error during iteration.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        if bucket_boundaries is None:
            sorted_lens = sorted(lengths)
            n = len(sorted_lens)
            # With no samples there is nothing to take percentiles of; a single
            # (empty) catch-all bucket is enough.
            bucket_boundaries = (
                [sorted_lens[int(n * p)] for p in self._DEFAULT_PERCENTILES] if n else []
            )

        # _get_bucket_id relies on ascending, unique boundaries, so normalise
        # caller-supplied values as well as the auto-computed ones.
        self.bucket_boundaries = sorted(set(bucket_boundaries))

        # One bucket per boundary plus a final bucket for anything longer.
        self.buckets: Dict[int, List[int]] = {
            i: [] for i in range(len(self.bucket_boundaries) + 1)
        }
        for idx, length in enumerate(lengths):
            self.buckets[self._get_bucket_id(length)].append(idx)

    def _get_bucket_id(self, length: int) -> int:
        """Return the index of the first bucket whose boundary is >= ``length``."""
        for i, boundary in enumerate(self.bucket_boundaries):
            if length <= boundary:
                return i
        # Longer than every boundary: overflow bucket.
        return len(self.bucket_boundaries)

    def __iter__(self) -> Iterator[List[int]]:
        """Yield batches of sample indices; each batch comes from one bucket."""
        all_batches = []

        for indices in self.buckets.values():
            if not indices:
                continue

            # Work on a copy so the stored bucket order is never mutated.
            bucket_indices = indices.copy()
            if self.shuffle:
                random.shuffle(bucket_indices)

            for start in range(0, len(bucket_indices), self.batch_size):
                batch = bucket_indices[start:start + self.batch_size]
                if len(batch) == self.batch_size or not self.drop_last:
                    all_batches.append(batch)

        # Shuffle whole batches so buckets are not visited in length order.
        if self.shuffle:
            random.shuffle(all_batches)

        yield from all_batches

    def __len__(self) -> int:
        """Return the total number of batches produced by one iteration."""
        total = 0
        for indices in self.buckets.values():
            full, remainder = divmod(len(indices), self.batch_size)
            total += full
            if remainder and not self.drop_last:
                total += 1
        return total
|
|
|
|
| |
| |
| |
class LazyShardedDataset(Dataset):
    """
    Memory-efficient dataset that loads batch files on demand.

    Instead of keeping all data in memory, the constructor reads every batch
    file once to build an index of ``(batch_file_index, index_inside_file)``
    pairs, and ``__getitem__`` re-loads the needed file through a small LRU
    cache.

    Also stores approximate sequence lengths (SNAC token counts rounded down to
    a multiple of 7) for bucketing.

    Args:
        batch_files: Paths of the batch files; each is expected to hold a list
            of sample dicts.
        tokenizer: Object with ``encode(text, add_special_tokens=False)``; only
            used for samples that carry no ``text_tokens``.
        max_audio_len: Audio length limit; whisper features are truncated to
            ``max_audio_len * 5`` rows (the factor 5 is taken from the code;
            its meaning is not documented here -- TODO confirm).
        max_seq_len: Stored on the instance; not used by this class itself.
        cache_size: Maximum number of batch files kept in memory at once.
        verbose: Whether to log indexing progress.
    """

    def __init__(
        self,
        batch_files: List[Path],
        tokenizer,
        max_audio_len: int = 500,
        max_seq_len: int = 2048,
        cache_size: int = 3,
        verbose: bool = True
    ):
        self.batch_files = batch_files
        self.tokenizer = tokenizer
        self.max_audio = max_audio_len * 5
        self.max_seq_len = max_seq_len
        self.cache_size = cache_size
        self.verbose = verbose

        # Global sample index -> (batch file index, position inside that file).
        self.index: List[Tuple[int, int]] = []
        self.batch_sizes: List[int] = []
        # Per-sample length estimate, aligned with self.index.
        self.sequence_lengths: List[int] = []

        # LRU cache of loaded batch files. A plain dict keeps insertion order,
        # so the first key is always the least recently used entry.
        self._cache: Dict[int, List[Dict]] = {}

        if verbose:
            log(f" Indexing {len(batch_files)} batch files...")

        for batch_idx, bf in enumerate(batch_files):
            data = self._read_batch_file(bf)
            self.batch_sizes.append(len(data))

            for local_idx, item in enumerate(data):
                self.index.append((batch_idx, local_idx))
                self.sequence_lengths.append(len(item.get("snac_tokens", [])) // 7 * 7)

            # Drop the file contents right away; only the index is kept.
            del data
            if (batch_idx + 1) % 100 == 0:
                gc.collect()
                if verbose:
                    log(f" Indexed {batch_idx+1}/{len(batch_files)} batches ({len(self.index):,} samples)")

        if verbose:
            log(f" Total indexed: {len(self.index):,} samples")

    @staticmethod
    def _read_batch_file(path) -> List[Dict]:
        """Read one batch file from disk onto the CPU."""
        # SECURITY NOTE: weights_only=False runs the full pickle machinery, which
        # can execute arbitrary code. Only load batch files from trusted sources.
        return torch.load(path, map_location="cpu", weights_only=False)

    def get_sequence_lengths(self) -> List[int]:
        """Return sequence lengths for bucketing."""
        return self.sequence_lengths

    def __len__(self) -> int:
        return len(self.index)

    def _load_batch(self, batch_idx: int) -> List[Dict]:
        """Return the contents of batch file ``batch_idx``, using the LRU cache."""
        if batch_idx in self._cache:
            # Re-insert to mark the entry as most recently used.
            data = self._cache.pop(batch_idx)
            self._cache[batch_idx] = data
            return data

        data = self._read_batch_file(self.batch_files[batch_idx])
        self._cache[batch_idx] = data

        # Evict least recently used entries. Clamp the capacity so a negative
        # cache_size behaves like "no caching" instead of raising.
        capacity = max(self.cache_size, 0)
        while len(self._cache) > capacity:
            oldest = next(iter(self._cache))
            del self._cache[oldest]
            gc.collect()

        return data

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Return the preprocessed sample at global index ``idx``.

        Returns:
            Dict with keys ``whisper``, ``text_tokens``, ``snac_tokens``,
            ``word_alignments`` and ``answer_text``.
        """
        batch_idx, local_idx = self.index[idx]
        item = self._load_batch(batch_idx)[local_idx]

        return {
            "whisper": item["whisper_features"][:self.max_audio],
            "text_tokens": self._get_text_tokens(item),
            "snac_tokens": self._get_snac_tokens(item),
            "word_alignments": item.get("word_alignments", None),
            "answer_text": item.get("answer", ""),
        }

    def _get_text_tokens(self, item: Dict[str, Any]) -> List[int]:
        """Return stored text tokens, or tokenize the answer/text field."""
        tt = item.get("text_tokens")
        # A stored None is treated like a missing key instead of crashing on len().
        if tt is not None and len(tt) > 0:
            return tt.tolist() if hasattr(tt, 'tolist') else list(tt)

        text = item.get("answer", item.get("text", ""))
        if isinstance(text, str) and len(text) > 0:
            return self.tokenizer.encode(text, add_special_tokens=False)
        return []

    def _get_snac_tokens(self, item: Dict[str, Any]) -> List[int]:
        """
        Return SNAC tokens truncated to a multiple of 7.

        When fewer than 7 tokens are present they are returned unchanged.
        A missing ``snac_tokens`` key yields an empty list, matching the
        length index built in ``__init__``, which also treats it as optional.
        """
        snac = item.get("snac_tokens", [])
        snac_len = (len(snac) // 7) * 7
        snac = snac[:snac_len] if snac_len > 0 else snac[:7]
        return snac.tolist() if hasattr(snac, 'tolist') else list(snac)
|
|
|
|
| |
| |
| |
| class ShardedDatasetLoader: |
| """ |
| Load datasets from single files or sharded batch directories. |
| |
| Supports: |
| - Single .pt file |
| - Batch directory with batch_*.pt files |
| - Mixed (base file + batch files) |
| |
| Dependency Inversion: Uses abstract path interface, not concrete file operations. |
| """ |
|
|
| def __init__(self, verbose: bool = True): |
| self.verbose = verbose |
|
|
| def load(self, path: str) -> List[Dict[str, Any]]: |
| """Load dataset from path (file or directory).""" |
| path = Path(path) |
| samples = [] |
|
|
| |
| if path.name.endswith('.batches') and path.is_dir(): |
| samples = self._load_batches_dir(path) |
|
|
| |
| elif path.exists() and path.is_file(): |
| samples = self._load_file_with_batches(path) |
|
|
| |
| else: |
| batches_dir = Path(f"{path}.batches") |
| if batches_dir.exists() and batches_dir.is_dir(): |
| samples = self._load_batches_dir(batches_dir) |
| else: |
| raise FileNotFoundError(f"No dataset found at {path}") |
|
|
| return samples |
|
|
| def _load_batches_dir(self, batches_dir: Path) -> List[Dict[str, Any]]: |
| """Load all batch files from a directory.""" |
| batch_files = sorted(batches_dir.glob("batch_*.pt")) |
| if not batch_files: |
| raise FileNotFoundError(f"No batch files in {batches_dir}") |
|
|
| if self.verbose: |
| log(f" Loading {len(batch_files)} batch files from {batches_dir.name}/") |
|
|
| samples = [] |
| for i, bf in enumerate(batch_files): |
| items = torch.load(bf, map_location="cpu", weights_only=False) |
| samples.extend(items) |
| del items |
| if (i + 1) % 100 == 0: |
| gc.collect() |
| if self.verbose: |
| log(f" Loaded {i+1}/{len(batch_files)} batches ({len(samples):,} samples)") |
|
|
| return samples |
|
|
| def _load_file_with_batches(self, path: Path) -> List[Dict[str, Any]]: |
| """Load base file and any associated batch files.""" |
| |
| base = torch.load(path, map_location="cpu", weights_only=False, mmap=True) |
| samples = list(base) |
| del base |
| gc.collect() |
|
|
| if self.verbose: |
| log(f" Base file: {len(samples):,} samples") |
|
|
| |
| batches_dir = Path(f"{path}.batches") |
| if batches_dir.exists() and batches_dir.is_dir(): |
| batch_files = sorted(batches_dir.glob("batch_*.pt")) |
| if batch_files: |
| if self.verbose: |
| log(f" Found {len(batch_files)} batch files") |
| for i, bf in enumerate(batch_files): |
| items = torch.load(bf, map_location="cpu", weights_only=False) |
| samples.extend(items) |
| del items |
| if (i + 1) % 100 == 0: |
| gc.collect() |
| if self.verbose: |
| log(f" Loaded {i+1}/{len(batch_files)} batches ({len(samples):,} total)") |
|
|
| return samples |
|
|
|
|
def load_sharded_dataset(path: str, verbose: bool = True) -> List[Dict[str, Any]]:
    """Load a (possibly sharded) dataset from ``path`` via a one-off ShardedDatasetLoader."""
    return ShardedDatasetLoader(verbose=verbose).load(path)
|
|
|
|
| |
| |
| |
class InterleavedDataset(Dataset):
    """
    In-memory dataset that prepares samples for interleaved training.

    Single Responsibility: Only handles sample access and preprocessing.

    Args:
        data: List of sample dicts, all held in memory.
        tokenizer: Object with ``encode(text, add_special_tokens=False)``; only
            used for samples that carry no ``text_tokens``.
        max_audio_len: Audio length limit; whisper features are truncated to
            ``max_audio_len * 5`` rows (the factor 5 is taken from the code;
            its meaning is not documented here -- TODO confirm).
        max_seq_len: Stored on the instance; not used by this class itself.
    """

    def __init__(
        self,
        data: List[Dict[str, Any]],
        tokenizer,
        max_audio_len: int = 500,
        max_seq_len: int = 2048
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.max_audio = max_audio_len * 5
        self.max_seq_len = max_seq_len

    def __len__(self) -> int:
        return len(self.data)

    def get_sequence_lengths(self) -> List[int]:
        """
        Return one length estimate per sample for sequence-length bucketing.

        The estimate is the SNAC token count rounded down to a multiple of 7
        (0 when the sample has no ``snac_tokens``). Exposing this method lets
        length-bucketing code that probes for ``get_sequence_lengths`` work
        with in-memory datasets too. Computed on each call from ``self.data``.
        """
        return [len(item.get("snac_tokens", [])) // 7 * 7 for item in self.data]

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Return the preprocessed sample at ``idx``.

        Returns:
            Dict with keys ``whisper``, ``text_tokens``, ``snac_tokens``,
            ``word_alignments`` and ``answer_text``.
        """
        item = self.data[idx]

        return {
            # Truncate audio features to the configured maximum.
            "whisper": item["whisper_features"][:self.max_audio],
            "text_tokens": self._get_text_tokens(item),
            "snac_tokens": self._get_snac_tokens(item),
            # Optional metadata; absent keys fall back to None / "".
            "word_alignments": item.get("word_alignments", None),
            "answer_text": item.get("answer", ""),
        }

    def _get_text_tokens(self, item: Dict[str, Any]) -> List[int]:
        """Extract stored text tokens, or tokenize the answer/text field."""
        tt = item.get("text_tokens")
        # A stored None is treated like a missing key instead of crashing on len().
        if tt is not None and len(tt) > 0:
            return tt.tolist() if hasattr(tt, 'tolist') else list(tt)

        text = item.get("answer", item.get("text", ""))
        if isinstance(text, str) and len(text) > 0:
            return self.tokenizer.encode(text, add_special_tokens=False)

        return []

    def _get_snac_tokens(self, item: Dict[str, Any]) -> List[int]:
        """
        Extract SNAC tokens truncated to a multiple of 7.

        When fewer than 7 tokens are present they are returned unchanged.
        A missing ``snac_tokens`` key yields an empty list.
        """
        snac = item.get("snac_tokens", [])
        snac_len = (len(snac) // 7) * 7
        snac = snac[:snac_len] if snac_len > 0 else snac[:7]
        return snac.tolist() if hasattr(snac, 'tolist') else list(snac)
|
|
|
|
| |
| |
| |
def collate_simple(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Collate samples by zero-padding whisper features to a common length.

    The common length is the longest sample in the batch rounded up to a
    multiple of 5. Token lists and metadata are passed through untouched
    under ``raw_data``; interleaving is done in the training loop for
    correct text_ratio.

    Returns:
        Dict with ``whisper`` (stacked, padded tensor) and ``raw_data``
        (one dict per sample).
    """
    longest = max(sample["whisper"].shape[0] for sample in batch)
    leftover = longest % 5
    target_len = longest if leftover == 0 else longest + (5 - leftover)

    padded_features = []
    passthrough = []
    for sample in batch:
        features = sample["whisper"]
        missing_rows = target_len - features.shape[0]
        # Pad only at the end of the first (time) dimension.
        padded_features.append(F.pad(features, (0, 0, 0, missing_rows)))
        passthrough.append(
            dict(
                text_tokens=sample["text_tokens"],
                snac_tokens=sample["snac_tokens"],
                word_alignments=sample.get("word_alignments"),
                answer_text=sample.get("answer_text", ""),
            )
        )

    stacked = torch.stack(padded_features)
    return {"whisper": stacked, "raw_data": passthrough}
|
|
|
|
| |
| |
| |
class DataLoaderFactory:
    """
    Factory for creating DataLoaders.

    Single Responsibility: Only handles DataLoader creation.

    Optimizations:
    - Sequence length bucketing for reduced padding
    - Worker count derived from GPU count, CPU cores and free GPU memory
    """

    @staticmethod
    def get_optimal_workers() -> int:
        """
        Calculate num_workers based on system resources.

        Returns 0 without CUDA. Otherwise uses two workers per GPU, capped at
        half the CPU cores, and at most 1 when GPU 0 has under 5 GiB free.
        Returns 2 if probing the system fails unexpectedly.
        """
        if not torch.cuda.is_available():
            return 0

        try:
            import os
            num_gpus = torch.cuda.device_count()
            cpu_cores = os.cpu_count() or 4
            max_workers = max(1, cpu_cores // 2)

            ideal_workers = num_gpus * 2
            num_workers = min(ideal_workers, max_workers)

            # Best-effort memory probe: failure to query must not block loading.
            try:
                free_bytes, _ = torch.cuda.mem_get_info(0)
                if free_bytes / 1024**3 < 5:
                    num_workers = min(num_workers, 1)
            except Exception:
                pass

            return max(0, num_workers)
        except Exception:
            # Deliberate catch-all: fall back to a conservative default.
            return 2

    @classmethod
    def create(
        cls,
        dataset: Dataset,
        batch_size: int,
        shuffle: bool = True,
        collate_fn: Callable = collate_simple,
        verbose: bool = True,
        use_bucketing: bool = True,
    ) -> DataLoader:
        """
        Create a DataLoader, trying multiprocessing strategies in order.

        Tries 'spawn', then 'fork', then single-process loading. When the
        computed worker count is 0, the multiprocessing attempts are skipped.

        Args:
            dataset: The dataset to load from
            batch_size: Batch size
            shuffle: Whether to shuffle data
            collate_fn: Function to collate samples
            verbose: Whether to log details
            use_bucketing: Use sequence length bucketing (only applied when
                ``shuffle`` is True and the dataset exposes sequence lengths)
        """
        optimal_workers = cls.get_optimal_workers()

        batch_sampler = None
        if use_bucketing and shuffle:
            batch_sampler = cls._build_bucket_sampler(dataset, batch_size, verbose)

        # Without workers every strategy would build the same single-process
        # loader, so go straight to it (and report it accurately).
        mp_methods = ['spawn', 'fork', None] if optimal_workers > 0 else [None]

        for mp_method in mp_methods:
            try:
                loader = cls._try_create_loader(
                    dataset, batch_size, shuffle, collate_fn,
                    optimal_workers, mp_method, batch_sampler
                )
            except Exception as e:
                if verbose:
                    log(f"[DataLoader] {mp_method} failed: {str(e)[:50]}...")
                continue

            if loader is not None:
                if verbose:
                    if mp_method:
                        log(f"[DataLoader] Using '{mp_method}' with {optimal_workers} workers")
                    else:
                        log("[DataLoader] Using single-process mode")
                return loader

        # Last resort: plain single-process loader without the bucket sampler.
        if verbose:
            log("[DataLoader] Fallback to num_workers=0")
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            num_workers=0,
            pin_memory=torch.cuda.is_available()
        )

    @classmethod
    def _build_bucket_sampler(
        cls,
        dataset: Dataset,
        batch_size: int,
        verbose: bool,
    ) -> Optional[Sampler]:
        """
        Build a BucketBatchSampler, or return None when bucketing is not applicable.

        Bucketing is skipped when the dataset exposes no lengths or has at most
        ``batch_size * 10`` samples. Any failure disables bucketing rather than
        aborting loader creation.
        """
        try:
            lengths = cls._get_sequence_lengths(dataset)
            if not lengths or len(lengths) <= batch_size * 10:
                return None

            sampler = BucketBatchSampler(
                lengths=lengths,
                batch_size=batch_size,
                shuffle=True,
                drop_last=False,
            )
            if verbose:
                log(f"[DataLoader] Using sequence length bucketing ({len(sampler)} batches)")
            return sampler
        except Exception as e:
            if verbose:
                log(f"[DataLoader] Bucketing failed: {e}, using standard batching")
            return None

    @staticmethod
    def _get_sequence_lengths(dataset: Dataset) -> Optional[List[int]]:
        """
        Extract sequence lengths from dataset for bucketing.

        Returns None when the dataset (or, for a ConcatDataset, any of its
        parts) does not provide ``get_sequence_lengths``.
        """
        if hasattr(dataset, 'get_sequence_lengths'):
            return dataset.get_sequence_lengths()

        if isinstance(dataset, ConcatDataset):
            lengths = []
            for ds in dataset.datasets:
                if not hasattr(ds, 'get_sequence_lengths'):
                    return None
                lengths.extend(ds.get_sequence_lengths())
            return lengths

        return None

    @staticmethod
    def _try_create_loader(
        dataset: Dataset,
        batch_size: int,
        shuffle: bool,
        collate_fn: Callable,
        num_workers: int,
        mp_method: Optional[str],
        batch_sampler: Optional[Sampler] = None,
    ) -> Optional[DataLoader]:
        """
        Try to create DataLoader with given settings.

        Uses worker processes only when both ``mp_method`` and a positive
        ``num_workers`` are given; otherwise builds a single-process loader.
        Exceptions propagate to the caller.
        """
        import multiprocessing

        common_kwargs = {
            'collate_fn': collate_fn,
            # Pinning only makes sense with CUDA; avoids a warning on CPU-only hosts.
            'pin_memory': torch.cuda.is_available(),
        }

        # batch_sampler is mutually exclusive with batch_size/shuffle in DataLoader.
        if batch_sampler is not None:
            common_kwargs['batch_sampler'] = batch_sampler
        else:
            common_kwargs['batch_size'] = batch_size
            common_kwargs['shuffle'] = shuffle

        if mp_method and num_workers > 0:
            mp_context = multiprocessing.get_context(mp_method)
            loader = DataLoader(
                dataset,
                num_workers=num_workers,
                multiprocessing_context=mp_context,
                persistent_workers=True,
                **common_kwargs
            )
            # Creating an iterator starts the workers, so start-up problems
            # surface here instead of in the training loop.
            test_iter = iter(loader)
            del test_iter
            return loader

        return DataLoader(
            dataset,
            num_workers=0,
            **common_kwargs
        )
|
|
|
|
def create_dataloader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool = True,
    verbose: bool = True,
    use_bucketing: bool = True,
) -> DataLoader:
    """
    Build a DataLoader through DataLoaderFactory using ``collate_simple``.

    Args:
        dataset: Dataset to load from
        batch_size: Batch size
        shuffle: Whether to shuffle
        verbose: Whether to log
        use_bucketing: Use sequence length bucketing
    """
    factory_options = {
        "collate_fn": collate_simple,
        "verbose": verbose,
        "use_bucketing": use_bucketing,
    }
    return DataLoaderFactory.create(dataset, batch_size, shuffle, **factory_options)
|
|
|
|
| |
| |
| |
def load_datasets(
    paths: List[str],
    tokenizer,
    max_audio_len: int = 500,
    max_seq_len: int = 2048,
    verbose: bool = True,
    lazy_loading: bool = True,
) -> Dataset:
    """
    Load and combine multiple datasets.

    For each path a batch directory is looked up: the path itself when it is a
    ``*.batches`` directory, a sibling ``<path>.batches`` directory, or the
    path itself when it is a directory containing ``batch_*.pt`` files. With
    ``lazy_loading`` the batch files are wrapped in a LazyShardedDataset; a
    base file next to a ``.batches`` directory is loaded eagerly as well, so
    lazy and eager loading cover the same samples. Everything else goes
    through ShardedDatasetLoader. Paths that cannot be found are skipped
    (with a warning when ``verbose``).

    Args:
        paths: List of dataset paths (files or directories)
        tokenizer: Tokenizer for text encoding
        max_audio_len: Maximum audio length
        max_seq_len: Maximum sequence length
        verbose: Whether to log progress
        lazy_loading: Use memory-efficient lazy loading for batch directories

    Returns:
        The single dataset, or a ConcatDataset when more than one was loaded.

    Raises:
        ValueError: If no dataset could be loaded from any path.
    """
    if verbose:
        log("\nLoading datasets (lazy loading enabled)..." if lazy_loading else "\nLoading datasets...")

    all_datasets = []

    for raw_path in paths:
        path = Path(raw_path)
        try:
            # Locate the directory holding batch_*.pt files, if any.
            sibling_dir = Path(f"{path}.batches")
            batches_dir = None
            if path.name.endswith('.batches') and path.is_dir():
                batches_dir = path
            elif sibling_dir.is_dir():
                batches_dir = sibling_dir
            elif path.is_dir() and any(path.glob("batch_*.pt")):
                batches_dir = path

            if batches_dir and lazy_loading:
                # A base file with a sibling .batches directory holds samples of
                # its own; load it eagerly so they are not silently dropped.
                # SECURITY NOTE: weights_only=False performs full unpickling;
                # only load trusted files.
                if batches_dir != path and path.is_file():
                    base_samples = list(
                        torch.load(str(path), map_location="cpu", weights_only=False, mmap=True)
                    )
                    if base_samples:
                        all_datasets.append(
                            InterleavedDataset(
                                base_samples, tokenizer,
                                max_audio_len=max_audio_len,
                                max_seq_len=max_seq_len
                            )
                        )
                        if verbose:
                            log(f" {path.name}: {len(base_samples):,} samples (base file)")

                batch_files = sorted(batches_dir.glob("batch_*.pt"))
                if batch_files:
                    dataset = LazyShardedDataset(
                        batch_files,
                        tokenizer,
                        max_audio_len=max_audio_len,
                        max_seq_len=max_seq_len,
                        cache_size=5,
                        verbose=verbose
                    )
                    all_datasets.append(dataset)
                    if verbose:
                        log(f" {path.name}: {len(dataset):,} samples (lazy loading)")
                elif verbose:
                    log(f" [WARN] No batch files in {batches_dir}")
            else:
                # Eager path: everything is read into memory at once.
                loader = ShardedDatasetLoader(verbose=verbose)
                data = loader.load(str(path))
                if data:
                    dataset = InterleavedDataset(
                        data, tokenizer,
                        max_audio_len=max_audio_len,
                        max_seq_len=max_seq_len
                    )
                    all_datasets.append(dataset)
                    if verbose:
                        log(f" {path.name}: {len(data):,} samples")

        except FileNotFoundError as e:
            # A missing dataset is not fatal as long as another path loads.
            if verbose:
                log(f" [WARN] {e}")

    if not all_datasets:
        raise ValueError("No datasets loaded!")

    return ConcatDataset(all_datasets) if len(all_datasets) > 1 else all_datasets[0]
|
|