"""
Dataset for CosyVoice-style TTS flow matching.

Each training sample comes from a HuggingFace dataset (saved to disk) with:
    file_path (str)   – path to audio file
    text      (str)   – transcription (unused here)
    codes     (list)  – SEMACS VQ codes at 12.5 Hz

For each sample we produce:
    mel      (T_mel, n_mels)   log-mel spectrogram at 24 kHz (target for flow matching)
    codes    (num_q, T_codes)  VQ conditioning codes
    wav_16k  (T_wav,)          raw 16 kHz waveform for on-the-fly speaker extraction

Speaker embeddings are extracted in the training loop on GPU (batched via
CAM++ fbank → campplus forward), so no precomputation step is needed.

Mel spectrogram parameters default to the Vocos 24 kHz configuration:
    sample_rate = 24000
    n_mels      = 100
    n_fft       = 1024
    hop_length  = 256    → frame rate ≈ 93.75 Hz
    f_max       = 12000
"""

from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torchaudio
from datasets import concatenate_datasets, load_from_disk
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset


class TTSDataset(Dataset):
    """
    VQ codes + mel spectrogram dataset.

    Speaker embeddings are NOT extracted here — the raw 16 kHz waveform is
    returned so the training loop can run batched CAM++ inference on GPU.

    Args:
        dataset_path  HuggingFace dataset directory, or list of directories
        sample_rate   Target sample rate for mel (24000)
        n_mels        Mel bins (100 for Vocos 24 kHz)
        n_fft         FFT size
        hop_length    Hop size in samples
        f_min/f_max   Mel filterbank range
        max_duration  Clip samples longer than this (seconds). Applied to the
                      mel frames and, when ``code_rate`` is set, to the VQ
                      codes so both stay the same duration.
        max_wav_16k   Max 16 kHz samples kept for speaker extraction (10 s default)
        code_rate     Frame rate of the stored VQ codes in Hz (12.5). Used only
                      to clip codes to ``max_duration``; pass None to leave
                      codes unclipped.
    """

    # How many consecutive indices __getitem__ tries before giving up.
    _MAX_LOAD_RETRIES = 32

    def __init__(
        self,
        dataset_path: Union[str, List[str]],
        sample_rate: int = 24000,
        n_mels: int = 100,
        n_fft: int = 1024,
        hop_length: int = 256,
        f_min: float = 0.0,
        f_max: float = 12000.0,
        max_duration: float = 30.0,
        max_wav_16k: int = 160_000,  # 10 s at 16 kHz
        code_rate: Optional[float] = 12.5,
    ):
        if isinstance(dataset_path, (list, tuple)):
            parts = [load_from_disk(p) for p in dataset_path]
            self.dataset = concatenate_datasets(parts)
            print(f"[TTSDataset] Loaded {len(dataset_path)} datasets, "
                  f"{len(self.dataset):,} total samples")
        else:
            self.dataset = load_from_disk(dataset_path)
            print(f"[TTSDataset] Loaded {len(self.dataset):,} samples from {dataset_path}")

        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.max_frames = int(max_duration * sample_rate / hop_length)
        # Codes are clipped to the same duration as the mel; otherwise clips
        # longer than max_duration would yield codes that outrun the mel target.
        self.max_code_frames: Optional[int] = (
            int(max_duration * code_rate) if code_rate is not None else None
        )
        self.max_wav_16k = max_wav_16k

        # Must match Vocos `charactr/vocos-mel-24khz` MelSpectrogramFeatures:
        # power=1 (amplitude, not power), default norm=None, default mel_scale="htk"
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            f_min=f_min,
            f_max=f_max,
            power=1,
        )

    def __len__(self) -> int:
        return len(self.dataset)

    def _load_audio(self, idx: int) -> Tuple[Dict, torch.Tensor, int]:
        """Load the audio for ``idx``, walking forward past unreadable files.

        Tries up to ``_MAX_LOAD_RETRIES`` consecutive indices, wrapping around
        the end of the dataset. Only the first failure is printed.

        Returns:
            ``(item, wav, sr)`` for the first index whose file loaded, where
            ``item`` is the dataset row and ``wav, sr`` come from
            ``torchaudio.load``.

        Raises:
            RuntimeError: every attempted file failed to load; chained to the
                last load error so the root cause is visible.
        """
        n = len(self)
        last_err: Optional[Exception] = None
        for attempt in range(self._MAX_LOAD_RETRIES):
            i = (idx + attempt) % n
            item = self.dataset[i]
            file_path = item["file_path"]
            try:
                wav, sr = torchaudio.load(file_path)
            except Exception as e:  # best-effort: skip corrupt/missing files
                if attempt == 0:
                    print(f"[TTSDataset] load failed at idx={i}: {file_path} "
                          f"({type(e).__name__}: {e}); skipping")
                last_err = e
                continue
            return item, wav, sr
        raise RuntimeError(
            f"TTSDataset: {self._MAX_LOAD_RETRIES} consecutive load failures "
            f"starting at idx={idx}"
        ) from last_err

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item, wav, sr = self._load_audio(idx)
        wav_mono = wav.mean(0)  # (T,) — average over the channel dimension

        # ── 16 kHz waveform for speaker extraction (truncated to max_wav_16k) ─
        wav_16k = (torchaudio.functional.resample(wav_mono, sr, 16000)
                   if sr != 16000 else wav_mono)
        wav_16k = wav_16k[: self.max_wav_16k]

        # ── Mel spectrogram at target sample rate ────────────────────────────
        wav_target = (torchaudio.functional.resample(wav_mono, sr, self.sample_rate)
                      if sr != self.sample_rate else wav_mono)
        mel = self.mel_transform(wav_target)           # (n_mels, T_mel)
        mel = torch.log(mel.clamp(min=1e-7))           # Vocos safe_log clip_val
        mel = mel[:, : self.max_frames].T.contiguous()  # (T_mel, n_mels)

        # ── VQ codes ─────────────────────────────────────────────────────────
        codes = np.array(item["codes"], dtype=np.int64)
        if codes.ndim == 1:
            codes = codes[np.newaxis, :]  # (1, T_codes)
        if self.max_code_frames is not None:
            # Keep codes time-aligned with the clipped mel.
            codes = np.ascontiguousarray(codes[:, : self.max_code_frames])

        return {
            "mel": mel,                         # (T_mel, n_mels)
            "codes": torch.from_numpy(codes),   # (num_q, T_codes)
            "wav_16k": wav_16k,                 # (T_wav,) for speaker encoder
        }


# ── Collate ───────────────────────────────────────────────────────────────────

def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Zero-pad variable-length sequences within a batch.

    Expects each item as produced by ``TTSDataset.__getitem__``. Returns
    padded tensors plus the per-sample valid lengths for masking:

        mel      (B, T_mel, n_mels)    mel_lens   (B,)
        codes    (B, num_q, T_codes)   code_lens  (B,)
        wav_16k  (B, T_wav)            wav_lens   (B,)

    All padding is 0; note 0 may also be a valid code index, so consumers
    must mask codes with ``code_lens`` rather than by value.
    """
    mels = [item["mel"] for item in batch]
    codes = [item["codes"] for item in batch]
    wavs = [item["wav_16k"] for item in batch]

    mel_lens = torch.tensor([m.shape[0] for m in mels])
    code_lens = torch.tensor([c.shape[1] for c in codes])
    wav_lens = torch.tensor([w.shape[0] for w in wavs])

    # pad_sequence pads along dim 0. Mels (T, n_mels) and wavs (T,) are
    # already time-first; codes are (num_q, T), so pad their transposes and
    # swap back to (B, num_q, T_codes).
    padded_mels = pad_sequence(mels, batch_first=True)
    padded_codes = (
        pad_sequence([c.T for c in codes], batch_first=True)
        .transpose(1, 2)
        .contiguous()
    )
    padded_wavs = pad_sequence(wavs, batch_first=True)  # for batched fbank extraction

    return {
        "mel": padded_mels,      # (B, T_mel, n_mels)
        "mel_lens": mel_lens,    # (B,)
        "codes": padded_codes,   # (B, num_q, T_codes)
        "code_lens": code_lens,  # (B,)
        "wav_16k": padded_wavs,  # (B, T_wav)
        "wav_lens": wav_lens,    # (B,) valid samples in wav_16k
    }