# Source: omini-model/training/data.py
# Author: marcos
# Commit e20f447 - feat: Refactor training with SOLID principles and add optimizations
# (Web-viewer residue: Raw / History / Blame / Contribute / Delete; file size 25.7 kB)
"""
Data loading and dataset management.
Single Responsibility: Only handles data loading and dataset creation.
Open/Closed: Can extend with new dataset formats without modifying existing code.
Optimizations:
- Lazy loading for memory efficiency
- Sequence length bucketing for reduced padding overhead
- LRU cache for batch files
"""
import gc
import random
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Sampler
from pathlib import Path
from typing import List, Dict, Any, Optional, Callable, Tuple, Iterator
from .utils import log
# ============================================================
# Sequence Length Bucketing (Reduces padding overhead)
# ============================================================
class BucketBatchSampler(Sampler[List[int]]):
    """
    Batch sampler that groups samples by sequence length into buckets.

    Samples whose lengths fall between the same pair of boundaries are batched
    together, so each batch contains similar-length sequences and needs less
    padding. Batches never mix samples from different buckets.

    Args:
        lengths: Sequence length of each sample; position ``i`` is the length
            of dataset index ``i``.
        batch_size: Number of samples per batch (must be positive).
        bucket_boundaries: Upper (inclusive) length bound of each bucket. When
            ``None`` they are derived from the 10%..90% quantiles of
            ``lengths``. Caller-supplied boundaries are used as given.
        shuffle: Shuffle samples within each bucket and shuffle the order of
            the resulting batches.
        drop_last: Drop the incomplete trailing batch of each bucket.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    # Quantiles of the length distribution used as automatic bucket edges.
    _AUTO_QUANTILES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)

    def __init__(
        self,
        lengths: List[int],
        batch_size: int,
        bucket_boundaries: Optional[List[int]] = None,
        shuffle: bool = True,
        drop_last: bool = False,
    ):
        # Fail early: a zero batch size would otherwise surface later as a
        # ZeroDivisionError in __len__ or a range() error in __iter__.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        if bucket_boundaries is None:
            bucket_boundaries = self._auto_boundaries(lengths)
        self.bucket_boundaries = bucket_boundaries

        # One bucket per boundary plus a final overflow bucket for lengths
        # above the last boundary.
        self.buckets: Dict[int, List[int]] = {
            i: [] for i in range(len(bucket_boundaries) + 1)
        }
        for idx, length in enumerate(lengths):
            self.buckets[self._get_bucket_id(length)].append(idx)

    @classmethod
    def _auto_boundaries(cls, lengths: List[int]) -> List[int]:
        """Return sorted, de-duplicated quantile boundaries ([] for no samples)."""
        sorted_lens = sorted(lengths)
        n = len(sorted_lens)
        if n == 0:
            # Nothing to bucket; avoids indexing into an empty list.
            return []
        return sorted({sorted_lens[int(n * p)] for p in cls._AUTO_QUANTILES})

    def _get_bucket_id(self, length: int) -> int:
        """Return the index of the first boundary that is >= ``length``.

        Lengths larger than every boundary go to the overflow bucket, whose id
        is ``len(self.bucket_boundaries)``.
        """
        for i, boundary in enumerate(self.bucket_boundaries):
            if length <= boundary:
                return i
        return len(self.bucket_boundaries)

    def __iter__(self) -> Iterator[List[int]]:
        """Yield batches (lists of dataset indices), one bucket per batch."""
        all_batches = []
        for indices in self.buckets.values():
            if not indices:
                continue
            # Work on a copy so the stored bucket order is never mutated.
            bucket_indices = indices.copy()
            if self.shuffle:
                random.shuffle(bucket_indices)
            for start in range(0, len(bucket_indices), self.batch_size):
                batch = bucket_indices[start:start + self.batch_size]
                if len(batch) == self.batch_size or not self.drop_last:
                    all_batches.append(batch)
        # Shuffle batch order so training does not see buckets in length order.
        if self.shuffle:
            random.shuffle(all_batches)
        yield from all_batches

    def __len__(self) -> int:
        """Return the number of batches produced by one pass of ``__iter__``."""
        total = 0
        for indices in self.buckets.values():
            full, remainder = divmod(len(indices), self.batch_size)
            total += full
            if remainder and not self.drop_last:
                total += 1
        return total
# ============================================================
# Lazy Sharded Dataset (Memory Efficient)
# ============================================================
class LazyShardedDataset(Dataset):
    """
    Memory-efficient dataset that loads batch files on demand.

    At construction every batch file is opened once to build an index
    ``sample_idx -> (batch_idx, local_idx)`` and to record an approximate
    sequence length per sample (used for bucketing). Afterwards batch files
    are loaded only when one of their samples is requested, and the most
    recently used ones are kept in a small LRU cache.

    Args:
        batch_files: Paths of the batch files; each is read with ``torch.load``
            and must support ``len()`` and iteration over dict-like items.
        tokenizer: Object with ``encode(text, add_special_tokens=False)``,
            used only when an item has no pre-computed ``text_tokens``.
        max_audio_len: Audio length limit; whisper features are truncated to
            ``max_audio_len * 5`` entries.
        max_seq_len: Stored on the instance; not used by this class itself.
        cache_size: Number of loaded batch files kept in memory. Values below
            zero are treated as zero (no caching).
        verbose: Log indexing progress via ``log``.
    """

    def __init__(
        self,
        batch_files: List[Path],
        tokenizer,
        max_audio_len: int = 500,
        max_seq_len: int = 2048,
        cache_size: int = 3,  # Number of batches to keep in memory
        verbose: bool = True
    ):
        self.batch_files = batch_files
        self.tokenizer = tokenizer
        self.max_audio = max_audio_len * 5
        self.max_seq_len = max_seq_len
        # A negative size would make the eviction loop pop from an empty list.
        self.cache_size = max(0, cache_size)
        self.verbose = verbose

        # sample_idx -> (batch_idx, local_idx)
        self.index: List[Tuple[int, int]] = []
        self.batch_sizes: List[int] = []
        # Approximate per-sample lengths, consumed by bucketing samplers.
        self.sequence_lengths: List[int] = []

        if verbose:
            log(f" Indexing {len(batch_files)} batch files...")

        for batch_idx, bf in enumerate(batch_files):
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only point this at batch files from a trusted source.
            data = torch.load(bf, map_location="cpu", weights_only=False)
            self.batch_sizes.append(len(data))
            for local_idx, item in enumerate(data):
                self.index.append((batch_idx, local_idx))
                # Length estimate: SNAC token count rounded down to a
                # multiple of 7 (0 when the key is missing).
                snac_len = len(item.get("snac_tokens", [])) // 7 * 7
                self.sequence_lengths.append(snac_len)
            del data
            # Periodic collection keeps peak memory down while indexing.
            if (batch_idx + 1) % 100 == 0:
                gc.collect()
                if verbose:
                    log(f" Indexed {batch_idx+1}/{len(batch_files)} batches ({len(self.index):,} samples)")

        if verbose:
            log(f" Total indexed: {len(self.index):,} samples")

        # LRU cache: batch_idx -> loaded data, plus usage order (oldest first).
        self._cache: Dict[int, List[Dict]] = {}
        self._cache_order: List[int] = []

    def get_sequence_lengths(self) -> List[int]:
        """Return the per-sample length estimates collected during indexing."""
        return self.sequence_lengths

    def __len__(self) -> int:
        return len(self.index)

    def _load_batch(self, batch_idx: int) -> List[Dict]:
        """Return the contents of batch file ``batch_idx``, using the LRU cache."""
        if batch_idx in self._cache:
            # Mark as most recently used.
            self._cache_order.remove(batch_idx)
            self._cache_order.append(batch_idx)
            return self._cache[batch_idx]

        # NOTE(security): pickle-based load; see the note in __init__.
        data = torch.load(self.batch_files[batch_idx], map_location="cpu", weights_only=False)

        if self.cache_size == 0:
            # Caching disabled: skip the insert-then-evict round trip.
            return data

        self._cache[batch_idx] = data
        self._cache_order.append(batch_idx)
        # Evict least recently used entries until within the limit.
        while len(self._cache_order) > self.cache_size:
            old_idx = self._cache_order.pop(0)
            del self._cache[old_idx]
            gc.collect()
        return data

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the processed sample at ``idx``.

        The result has the keys ``whisper``, ``text_tokens``, ``snac_tokens``,
        ``word_alignments`` and ``answer_text``.
        """
        batch_idx, local_idx = self.index[idx]
        item = self._load_batch(batch_idx)[local_idx]
        return {
            "whisper": item["whisper_features"][:self.max_audio],
            "text_tokens": self._get_text_tokens(item),
            "snac_tokens": self._get_snac_tokens(item),
            "word_alignments": item.get("word_alignments", None),
            "answer_text": item.get("answer", ""),
        }

    def _get_text_tokens(self, item: Dict[str, Any]) -> List[int]:
        """Return pre-computed text tokens, else tokenize ``answer``/``text``."""
        if "text_tokens" in item and len(item["text_tokens"]) > 0:
            tt = item["text_tokens"]
            return tt.tolist() if hasattr(tt, 'tolist') else list(tt)
        text = item.get("answer", item.get("text", ""))
        if isinstance(text, str) and len(text) > 0:
            return self.tokenizer.encode(text, add_special_tokens=False)
        return []

    def _get_snac_tokens(self, item: Dict[str, Any]) -> List[int]:
        """Return SNAC tokens truncated to a multiple of 7.

        When fewer than 7 tokens are present they are returned unchanged, so
        the result can be shorter than 7 in that case.
        """
        snac = item["snac_tokens"]
        snac_len = (len(snac) // 7) * 7
        snac = snac[:snac_len] if snac_len > 0 else snac[:7]
        return snac.tolist() if hasattr(snac, 'tolist') else list(snac)
# ============================================================
# Sharded Dataset Loading (Industry Standard)
# ============================================================
class ShardedDatasetLoader:
    """
    Load datasets from single files or sharded batch directories.

    Supports:
    - Single .pt file
    - Batch directory with batch_*.pt files (named ``*.batches`` or not)
    - Mixed (base file + sibling ``<file>.batches`` directory)

    Everything is loaded eagerly into one list of samples.
    """

    def __init__(self, verbose: bool = True):
        self.verbose = verbose

    def load(self, path: str) -> List[Dict[str, Any]]:
        """Load all samples found at ``path``.

        Resolution order:
        1. ``path`` is a directory that is named ``*.batches`` or contains
           ``batch_*.pt`` files: load its batch files.
        2. ``path`` is a file: load it plus any ``<path>.batches`` directory.
        3. Otherwise: load ``<path>.batches`` if that directory exists.

        Raises:
            FileNotFoundError: If nothing loadable is found, or a ``*.batches``
                directory contains no batch files.
        """
        path = Path(path)

        # Case 1: batches directory (explicitly named, or any directory that
        # actually holds batch files).
        if path.is_dir() and (
            path.name.endswith('.batches') or any(path.glob("batch_*.pt"))
        ):
            return self._load_batches_dir(path)

        # Case 2: single file (possibly with batches)
        if path.is_file():
            return self._load_file_with_batches(path)

        # Case 3: only the sibling batches directory exists
        batches_dir = Path(f"{path}.batches")
        if batches_dir.is_dir():
            return self._load_batches_dir(batches_dir)

        raise FileNotFoundError(f"No dataset found at {path}")

    def _extend_from_batch_files(
        self,
        samples: List[Dict[str, Any]],
        batch_files: List[Path],
        unit: str,
    ) -> None:
        """Append the contents of every batch file to ``samples`` in place.

        ``unit`` is the noun used in the progress message ("samples"/"total").
        """
        for i, bf in enumerate(batch_files):
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load batch files from a trusted source.
            items = torch.load(bf, map_location="cpu", weights_only=False)
            samples.extend(items)
            del items
            # Collect and report every 100 files to bound peak memory.
            if (i + 1) % 100 == 0:
                gc.collect()
                if self.verbose:
                    log(f" Loaded {i+1}/{len(batch_files)} batches ({len(samples):,} {unit})")

    def _load_batches_dir(self, batches_dir: Path) -> List[Dict[str, Any]]:
        """Load all ``batch_*.pt`` files of a directory, in sorted name order.

        Raises:
            FileNotFoundError: If the directory has no batch files.
        """
        batch_files = sorted(batches_dir.glob("batch_*.pt"))
        if not batch_files:
            raise FileNotFoundError(f"No batch files in {batches_dir}")
        if self.verbose:
            log(f" Loading {len(batch_files)} batch files from {batches_dir.name}/")
        samples: List[Dict[str, Any]] = []
        self._extend_from_batch_files(samples, batch_files, "samples")
        return samples

    def _load_file_with_batches(self, path: Path) -> List[Dict[str, Any]]:
        """Load the base file, then append any ``<path>.batches`` batch files."""
        # str(): believed safer for mmap=True, since some torch releases may
        # only accept a string filename there - TODO confirm against the
        # torch.load docs for the pinned version.
        # NOTE(security): pickle-based load; trusted files only.
        base = torch.load(str(path), map_location="cpu", weights_only=False, mmap=True)
        samples = list(base)
        del base
        gc.collect()
        if self.verbose:
            log(f" Base file: {len(samples):,} samples")

        batches_dir = Path(f"{path}.batches")
        if batches_dir.exists() and batches_dir.is_dir():
            batch_files = sorted(batches_dir.glob("batch_*.pt"))
            if batch_files:
                if self.verbose:
                    log(f" Found {len(batch_files)} batch files")
                self._extend_from_batch_files(samples, batch_files, "total")
        return samples
def load_sharded_dataset(path: str, verbose: bool = True) -> List[Dict[str, Any]]:
    """Load every sample at ``path`` through a one-off ShardedDatasetLoader."""
    return ShardedDatasetLoader(verbose=verbose).load(path)
# ============================================================
# Dataset Classes
# ============================================================
class InterleavedDataset(Dataset):
    """
    Dataset that prepares samples for interleaved training.

    Wraps an in-memory list of dict-like items and exposes each one as a dict
    with the keys ``whisper``, ``text_tokens``, ``snac_tokens``,
    ``word_alignments`` and ``answer_text``.

    Args:
        data: Items to serve; each needs ``whisper_features`` and
            ``snac_tokens`` and may carry ``text_tokens``, ``answer``,
            ``text`` and ``word_alignments``.
        tokenizer: Object with ``encode(text, add_special_tokens=False)``,
            used only when an item has no pre-computed ``text_tokens``.
        max_audio_len: Audio length limit; whisper features are truncated to
            ``max_audio_len * 5`` entries.
        max_seq_len: Stored on the instance; not used by this class itself.
    """

    def __init__(
        self,
        data: List[Dict[str, Any]],
        tokenizer,
        max_audio_len: int = 500,
        max_seq_len: int = 2048
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.max_audio = max_audio_len * 5  # Account for downsampling
        self.max_seq_len = max_seq_len

    def __len__(self) -> int:
        return len(self.data)

    def get_sequence_lengths(self) -> List[int]:
        """Return an approximate sequence length for every sample.

        The estimate is the SNAC token count rounded down to a multiple of 7
        (0 when an item has no ``snac_tokens``). Providing this method lets
        length-bucketing samplers work with eagerly loaded data as well.
        The list is recomputed on each call so it always reflects ``data``.
        """
        return [len(item.get("snac_tokens", [])) // 7 * 7 for item in self.data]

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the processed sample at ``idx``."""
        item = self.data[idx]
        return {
            # Whisper features, truncated to the configured maximum.
            "whisper": item["whisper_features"][:self.max_audio],
            # Pre-computed text tokens when available, else tokenized text.
            "text_tokens": self._get_text_tokens(item),
            # SNAC tokens cut to whole groups of 7.
            "snac_tokens": self._get_snac_tokens(item),
            # Optional fields.
            "word_alignments": item.get("word_alignments", None),
            "answer_text": item.get("answer", ""),
        }

    def _get_text_tokens(self, item: Dict[str, Any]) -> List[int]:
        """Return pre-computed text tokens, else tokenize ``answer``/``text``.

        Returns an empty list when neither tokens nor non-empty text exist.
        """
        if "text_tokens" in item and len(item["text_tokens"]) > 0:
            tt = item["text_tokens"]
            return tt.tolist() if hasattr(tt, 'tolist') else list(tt)
        text = item.get("answer", item.get("text", ""))
        if isinstance(text, str) and len(text) > 0:
            return self.tokenizer.encode(text, add_special_tokens=False)
        return []

    def _get_snac_tokens(self, item: Dict[str, Any]) -> List[int]:
        """Return SNAC tokens truncated to a multiple of 7.

        When fewer than 7 tokens are present they are returned unchanged, so
        the result can be shorter than 7 in that case.
        """
        snac = item["snac_tokens"]
        snac_len = (len(snac) // 7) * 7
        snac = snac[:snac_len] if snac_len > 0 else snac[:7]
        return snac.tolist() if hasattr(snac, 'tolist') else list(snac)
# ============================================================
# Collate Functions
# ============================================================
def collate_simple(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Pad whisper features to a common length and pass the rest through.

    The common length is the longest ``whisper`` first dimension in the batch,
    rounded up to a multiple of 5 (the downsample factor). Token fields are
    returned untouched under ``raw_data`` because interleaving happens later
    in the training loop.

    Returns:
        Dict with ``whisper`` (stacked, padded tensor) and ``raw_data`` (one
        dict per sample with text/SNAC tokens, alignments and answer text).
    """
    frame_counts = [entry["whisper"].shape[0] for entry in batch]
    # Ceiling division: round the longest sample up to the next multiple of 5.
    target_len = -(-max(frame_counts) // 5) * 5

    padded_features = []
    passthrough = []
    for entry, n_frames in zip(batch, frame_counts):
        # Pad only at the end of the first (time) dimension.
        padded_features.append(
            F.pad(entry["whisper"], (0, 0, 0, target_len - n_frames))
        )
        passthrough.append({
            "text_tokens": entry["text_tokens"],
            "snac_tokens": entry["snac_tokens"],
            "word_alignments": entry.get("word_alignments"),
            "answer_text": entry.get("answer_text", ""),
        })

    return {
        "whisper": torch.stack(padded_features),
        "raw_data": passthrough,
    }
# ============================================================
# DataLoader Factory
# ============================================================
class DataLoaderFactory:
    """
    Factory for creating DataLoaders with sensible settings.

    - Picks a worker count from GPU count, CPU count and free GPU memory.
    - Uses sequence length bucketing when the dataset exposes lengths.
    - Tries the 'spawn' and 'fork' multiprocessing contexts before falling
      back to single-process loading.
    """

    @staticmethod
    def get_optimal_workers() -> int:
        """Return the number of DataLoader workers to use.

        Returns 0 without CUDA. Otherwise 2 workers per GPU, capped at half
        the CPU cores, and at most 1 when GPU 0 has under 5 GiB free. Returns
        2 if the probing itself fails unexpectedly.
        """
        if not torch.cuda.is_available():
            return 0
        try:
            import os
            num_gpus = torch.cuda.device_count()
            cpu_cores = os.cpu_count() or 4
            max_workers = max(1, cpu_cores // 2)
            # 2 workers per GPU, capped by CPU
            num_workers = min(num_gpus * 2, max_workers)
            # Under memory pressure keep worker overhead minimal.
            try:
                free_bytes, _ = torch.cuda.mem_get_info(0)
                if free_bytes / 1024**3 < 5:
                    num_workers = min(num_workers, 1)
            except Exception:
                # Best effort: the memory query is optional.
                pass
            return max(0, num_workers)
        except Exception:
            return 2

    @classmethod
    def create(
        cls,
        dataset: Dataset,
        batch_size: int,
        shuffle: bool = True,
        collate_fn: Callable = collate_simple,
        verbose: bool = True,
        use_bucketing: bool = True,
    ) -> DataLoader:
        """
        Create a DataLoader for ``dataset``.

        Args:
            dataset: The dataset to load from
            batch_size: Batch size
            shuffle: Whether to shuffle data
            collate_fn: Function to collate samples
            verbose: Whether to log details
            use_bucketing: Use sequence length bucketing (reduces padding
                overhead). Only applied when ``shuffle`` is true, the dataset
                exposes lengths, and it has more than ``batch_size * 10``
                samples.

        Returns:
            A DataLoader; multi-process when workers are available and a
            multiprocessing context can be set up, single-process otherwise.
        """
        optimal_workers = cls.get_optimal_workers()

        batch_sampler = None
        if use_bucketing and shuffle:
            try:
                lengths = cls._get_sequence_lengths(dataset)
                # Only bucket if there are enough samples to fill buckets.
                if lengths and len(lengths) > batch_size * 10:
                    batch_sampler = BucketBatchSampler(
                        lengths=lengths,
                        batch_size=batch_size,
                        shuffle=True,
                        drop_last=False,
                    )
                    if verbose:
                        log(f"[DataLoader] Using sequence length bucketing ({len(batch_sampler)} batches)")
            except Exception as e:
                # Bucketing is an optimization; never let it block loading.
                batch_sampler = None
                if verbose:
                    log(f"[DataLoader] Bucketing failed: {e}, using standard batching")

        # With no workers every multiprocessing method yields the same
        # single-process loader, so go straight to it.
        mp_methods = ['spawn', 'fork', None] if optimal_workers > 0 else [None]
        for mp_method in mp_methods:
            try:
                loader = cls._try_create_loader(
                    dataset, batch_size, shuffle, collate_fn,
                    optimal_workers, mp_method, batch_sampler
                )
            except Exception as e:
                if verbose:
                    log(f"[DataLoader] {mp_method} failed: {str(e)[:50]}...")
                continue
            if loader is not None:
                if verbose:
                    # Report what the loader actually uses, not what was tried.
                    if mp_method and loader.num_workers > 0:
                        log(f"[DataLoader] Using '{mp_method}' with {loader.num_workers} workers")
                    else:
                        log("[DataLoader] Using single-process mode")
                return loader

        # Final fallback: plain single-process loader without bucketing.
        if verbose:
            log("[DataLoader] Fallback to num_workers=0")
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            num_workers=0,
            pin_memory=True
        )

    @staticmethod
    def _get_sequence_lengths(dataset: Dataset) -> Optional[List[int]]:
        """Return per-sample lengths for ``dataset``, or None if unavailable.

        Uses ``dataset.get_sequence_lengths()`` when present. A ConcatDataset
        is resolved component by component (nested ones included); if any
        component cannot provide lengths the result is None.
        """
        if hasattr(dataset, 'get_sequence_lengths'):
            return dataset.get_sequence_lengths()
        if isinstance(dataset, ConcatDataset):
            lengths: List[int] = []
            for ds in dataset.datasets:
                part = DataLoaderFactory._get_sequence_lengths(ds)
                if part is None:
                    return None  # Can't get lengths for all, skip bucketing
                lengths.extend(part)
            return lengths
        return None

    @staticmethod
    def _try_create_loader(
        dataset: Dataset,
        batch_size: int,
        shuffle: bool,
        collate_fn: Callable,
        num_workers: int,
        mp_method: Optional[str],
        batch_sampler: Optional[Sampler] = None,
    ) -> Optional[DataLoader]:
        """Build a DataLoader for one multiprocessing method.

        With a method and ``num_workers > 0`` a multi-process loader with
        persistent workers is created and an iterator is opened once as a
        smoke test; any failure propagates to the caller. Otherwise a
        single-process loader is returned.
        """
        import multiprocessing

        common_kwargs = {
            'collate_fn': collate_fn,
            'pin_memory': True,
        }
        # batch_sampler is mutually exclusive with batch_size and shuffle.
        if batch_sampler is not None:
            common_kwargs['batch_sampler'] = batch_sampler
        else:
            common_kwargs['batch_size'] = batch_size
            common_kwargs['shuffle'] = shuffle

        if not (mp_method and num_workers > 0):
            return DataLoader(
                dataset,
                num_workers=0,
                **common_kwargs
            )

        mp_context = multiprocessing.get_context(mp_method)
        loader = DataLoader(
            dataset,
            num_workers=num_workers,
            multiprocessing_context=mp_context,
            persistent_workers=True,
            **common_kwargs
        )
        # Creating the iterator starts the workers; this surfaces pickling or
        # context errors here instead of in the training loop.
        test_iter = iter(loader)
        del test_iter
        return loader
def create_dataloader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool = True,
    verbose: bool = True,
    use_bucketing: bool = True,
) -> DataLoader:
    """
    Build a DataLoader via DataLoaderFactory using the default collate.

    Args:
        dataset: Dataset to load from
        batch_size: Batch size
        shuffle: Whether to shuffle
        verbose: Whether to log
        use_bucketing: Use sequence length bucketing

    Returns:
        The DataLoader produced by ``DataLoaderFactory.create`` with
        ``collate_simple`` as the collate function.
    """
    factory_options = {
        "collate_fn": collate_simple,
        "verbose": verbose,
        "use_bucketing": use_bucketing,
    }
    return DataLoaderFactory.create(dataset, batch_size, shuffle, **factory_options)
# ============================================================
# Dataset Loading Pipeline
# ============================================================
def _find_batches_dir(path: Path) -> Optional[Path]:
    """Return the batch directory associated with ``path``, or None.

    Checks, in order: ``path`` itself when it is a ``*.batches`` directory,
    the sibling ``<path>.batches`` directory, and ``path`` as a plain
    directory that contains ``batch_*.pt`` files.
    """
    if path.name.endswith('.batches') and path.is_dir():
        return path
    sibling = Path(f"{path}.batches")
    if sibling.is_dir():
        return sibling
    if path.is_dir() and any(path.glob("batch_*.pt")):
        return path
    return None


def load_datasets(
    paths: List[str],
    tokenizer,
    max_audio_len: int = 500,
    max_seq_len: int = 2048,
    verbose: bool = True,
    lazy_loading: bool = True,  # Use memory-efficient lazy loading
) -> Dataset:
    """
    Load and combine multiple datasets.

    For each path with batch files and ``lazy_loading`` enabled, a
    LazyShardedDataset is created over the batch files; if ``path`` is also
    an existing base file, its samples are loaded as an InterleavedDataset
    placed before the batches, so lazy and eager loading cover the same
    samples. All other paths are loaded eagerly through ShardedDatasetLoader.
    Paths that cannot be found are skipped with a warning.

    Args:
        paths: List of dataset paths (files or directories)
        tokenizer: Tokenizer for text encoding
        max_audio_len: Maximum audio length
        max_seq_len: Maximum sequence length
        verbose: Whether to log progress
        lazy_loading: Use memory-efficient lazy loading for batch directories

    Returns:
        The single loaded dataset, or a ConcatDataset when there are several.

    Raises:
        ValueError: If no dataset could be loaded from any path.
    """
    if verbose:
        log("\nLoading datasets (lazy loading enabled)..." if lazy_loading else "\nLoading datasets...")

    all_datasets = []
    for raw_path in paths:
        path = Path(raw_path)
        try:
            batches_dir = _find_batches_dir(path) if lazy_loading else None
            batch_files = sorted(batches_dir.glob("batch_*.pt")) if batches_dir else []

            if batch_files:
                # Mixed layout: a base file next to its batches directory.
                # Its samples must not be dropped just because lazy loading
                # is used for the batches.
                if path.is_file():
                    # NOTE(security): pickle-based load; trusted files only.
                    base = torch.load(str(path), map_location="cpu", weights_only=False, mmap=True)
                    base_samples = list(base)
                    del base
                    if base_samples:
                        all_datasets.append(InterleavedDataset(
                            base_samples, tokenizer,
                            max_audio_len=max_audio_len,
                            max_seq_len=max_seq_len
                        ))
                        if verbose:
                            log(f" {path.name}: {len(base_samples):,} samples (base file)")

                # Memory-efficient lazy loading for the batch files.
                dataset = LazyShardedDataset(
                    batch_files,
                    tokenizer,
                    max_audio_len=max_audio_len,
                    max_seq_len=max_seq_len,
                    cache_size=5,  # Keep 5 batches in memory
                    verbose=verbose
                )
                all_datasets.append(dataset)
                if verbose:
                    log(f" {path.name}: {len(dataset):,} samples (lazy loading)")
            else:
                # Eager loading: single files, lazy loading disabled, or a
                # batches directory without any batch files.
                loader = ShardedDatasetLoader(verbose=verbose)
                data = loader.load(str(path))
                if data:
                    dataset = InterleavedDataset(
                        data, tokenizer,
                        max_audio_len=max_audio_len,
                        max_seq_len=max_seq_len
                    )
                    all_datasets.append(dataset)
                    if verbose:
                        log(f" {path.name}: {len(data):,} samples")
        except FileNotFoundError as e:
            if verbose:
                log(f" [WARN] {e}")

    if not all_datasets:
        raise ValueError("No datasets loaded!")
    return ConcatDataset(all_datasets) if len(all_datasets) > 1 else all_datasets[0]