| """ |
| Dataset for CosyVoice-style TTS flow matching. |
| |
Each training sample comes from a HuggingFace dataset (saved to disk) with:
    file_path (str)  — path to audio file
    text      (str)  — transcription (unused here)
    codes     (list) — SEMACS VQ codes at 12.5 Hz
| |
| For each sample we produce: |
| mel (T_mel, n_mels) log-mel spectrogram at 24 kHz (target for flow matching) |
| codes (num_q, T_codes) VQ conditioning codes |
| wav_16k (T_wav,) raw 16 kHz waveform for on-the-fly speaker extraction |
| |
Speaker embeddings are extracted in the training loop on GPU (batched via CAM++ fbank →
campplus forward), so no precomputation step is needed.
| |
Mel spectrogram parameters default to the Vocos 24 kHz configuration:
    sample_rate = 24000
    n_mels      = 100
    n_fft       = 1024
    hop_length  = 256    → frame rate = 24000 / 256 = 93.75 Hz
    f_max       = 12000
| """ |
|
|
| from typing import Dict, List, Union |
|
|
| import numpy as np |
| import torch |
| import torchaudio |
| from datasets import concatenate_datasets, load_from_disk |
| from torch.utils.data import Dataset |
|
|
|
|
class TTSDataset(Dataset):
    """
    VQ codes + mel spectrogram dataset. Speaker embeddings are NOT extracted
    here — the raw 16 kHz waveform is returned so the training loop can run
    batched CAM++ inference on GPU.

    Rows whose audio cannot be loaded (or is empty) are skipped: the next row
    (wrapping around) is used instead, for up to 32 consecutive attempts.

    Args:
        dataset_path  HuggingFace dataset directory, or list of directories
                      (concatenated in the given order)
        sample_rate   Target sample rate for mel (24000)
        n_mels        Mel bins (100 for Vocos 24 kHz)
        n_fft         FFT size
        hop_length    Hop size in samples
        f_min/f_max   Mel filterbank range
        max_duration  Clip audio longer than this (seconds). The mel is cut to
                      int(max_duration * sample_rate / hop_length) frames, and
                      the codes are cut to the same fraction of the utterance,
                      so both still describe the same stretch of audio.
        max_wav_16k   Max 16 kHz samples kept for speaker extraction (10 s default)
    """

    # Number of consecutive unreadable rows tolerated before giving up.
    _MAX_LOAD_ATTEMPTS = 32

    def __init__(
        self,
        dataset_path: Union[str, List[str]],
        sample_rate: int = 24000,
        n_mels: int = 100,
        n_fft: int = 1024,
        hop_length: int = 256,
        f_min: float = 0.0,
        f_max: float = 12000.0,
        max_duration: float = 30.0,
        max_wav_16k: int = 160_000,
    ):
        if isinstance(dataset_path, (list, tuple)):
            parts = [load_from_disk(p) for p in dataset_path]
            self.dataset = concatenate_datasets(parts)
            print(f"[TTSDataset] Loaded {len(dataset_path)} datasets, "
                  f"{len(self.dataset):,} total samples")
        else:
            self.dataset = load_from_disk(dataset_path)
            print(f"[TTSDataset] Loaded {len(self.dataset):,} samples from {dataset_path}")

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Mel frames kept per utterance; longer clips are truncated to this.
        self.max_frames = int(max_duration * sample_rate / hop_length)
        self.max_wav_16k = max_wav_16k

        # power=1 -> magnitude (not power) mel spectrogram.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            f_min=f_min,
            f_max=f_max,
            power=1,
        )

    def __len__(self) -> int:
        return len(self.dataset)

    def _load_audio(self, idx: int):
        """Load the audio of row ``idx``, falling forward to later rows on failure.

        Returns:
            (item, wav, sr): the dataset row actually used, the waveform returned
            by ``torchaudio.load`` for it, and its sample rate.

        Raises:
            IndexError: if the dataset is empty.
            RuntimeError: if 32 consecutive rows (wrapping around) cannot be used.
        """
        n = len(self)
        if n == 0:
            raise IndexError("TTSDataset is empty")
        for attempt in range(self._MAX_LOAD_ATTEMPTS):
            i = (idx + attempt) % n
            item = self.dataset[i]
            file_path = item["file_path"]
            try:
                wav, sr = torchaudio.load(file_path)
            except Exception as e:  # best-effort: any unreadable file is skipped
                reason = f"{type(e).__name__}: {e}"
            else:
                if wav.numel() > 0:
                    return item, wav, sr
                # A zero-length clip would make the mel STFT fail; skip it too.
                reason = "empty audio"
            # Only the first failure per call is logged to keep output quiet.
            if attempt == 0:
                print(f"[TTSDataset] load failed at idx={i}: {file_path} ({reason}); skipping")
        raise RuntimeError(
            f"TTSDataset: {self._MAX_LOAD_ATTEMPTS} consecutive load failures starting at idx={idx}"
        )

    @staticmethod
    def _trim_codes(codes: np.ndarray, kept_samples: int, total_samples: int) -> np.ndarray:
        """Drop trailing code frames so ``codes`` covers only the first
        ``kept_samples`` of an utterance that is ``total_samples`` long.

        Code frames are assumed to be evenly spaced over the utterance, so the
        kept count is round(T_codes * kept_samples / total_samples), clamped to
        [1, T_codes]. ``codes`` is returned unchanged when nothing was cut.
        """
        if total_samples <= kept_samples:
            return codes
        n_codes = codes.shape[-1]
        n_keep = int(round(n_codes * kept_samples / total_samples))
        n_keep = max(1, min(n_keep, n_codes))
        return codes[..., :n_keep]

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the sample for row ``idx`` (or the next loadable row).

        Keys:
            mel      (T_mel, n_mels) log-magnitude mel at ``sample_rate``,
                     at most ``max_frames`` frames
            codes    (num_q, T_codes) int64 VQ codes, trimmed consistently with mel
            wav_16k  (T_wav,) mono 16 kHz waveform, at most ``max_wav_16k`` samples
        """
        item, wav, sr = self._load_audio(idx)
        wav_mono = wav.mean(0)  # average over the channel axis

        # 16 kHz copy for on-the-fly speaker embedding in the training loop.
        wav_16k = torchaudio.functional.resample(wav_mono, sr, 16000) \
            if sr != 16000 else wav_mono
        wav_16k = wav_16k[: self.max_wav_16k]

        wav_target = torchaudio.functional.resample(wav_mono, sr, self.sample_rate) \
            if sr != self.sample_rate else wav_mono
        total_samples = wav_target.shape[-1]
        # Audio span represented by the kept mel frames.
        kept_samples = self.max_frames * self.hop_length
        # Skip STFT work on audio no kept frame can see. A centered frame t reads
        # samples [t*hop - n_fft/2, t*hop + n_fft/2), so an n_fft margin leaves the
        # input window of every kept frame unchanged.
        wav_target = wav_target[: kept_samples + self.n_fft]
        mel = self.mel_transform(wav_target)
        mel = torch.log(mel.clamp(min=1e-7))
        mel = mel[:, : self.max_frames].T.contiguous()

        codes = np.array(item["codes"], dtype=np.int64)
        if codes.ndim == 1:
            codes = codes[np.newaxis, :]  # single codebook -> (1, T_codes)
        # Keep the conditioning codes aligned with the (possibly truncated) mel.
        codes = self._trim_codes(codes, kept_samples, total_samples)

        return {
            "mel": mel,
            "codes": torch.from_numpy(codes),
            "wav_16k": wav_16k,
        }
|
|
|
|
| |
|
|
def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Zero-pad the variable-length fields of a list of samples into batch tensors.

    Each sample is a dict with "mel" (T_mel, n_mels), "codes" (num_q, T_codes)
    and "wav_16k" (T_wav,), the keys returned by ``TTSDataset.__getitem__``.
    ``n_mels`` and ``num_q`` are read from the first sample; padding is appended
    at the end of each time axis.

    Returns:
        mel        (B, max T_mel, n_mels)       default float dtype, zero-padded
        mel_lens   (B,)                         unpadded mel lengths
        codes      (B, num_q, max T_codes)      torch.long, zero-padded
        code_lens  (B,)                         unpadded code lengths
        wav_16k    (B, max T_wav)               default float dtype, zero-padded
        wav_lens   (B,)                         unpadded waveform lengths
    """
    size = len(batch)
    mel_seq = [sample["mel"] for sample in batch]
    code_seq = [sample["codes"] for sample in batch]
    wav_seq = [sample["wav_16k"] for sample in batch]

    mel_lens = torch.tensor([mel.shape[0] for mel in mel_seq])
    code_lens = torch.tensor([codes.shape[1] for codes in code_seq])
    wav_lens = torch.tensor([wav.shape[0] for wav in wav_seq])

    # Zero-filled buffers sized by the longest entry of each field.
    padded_mels = torch.zeros(size, int(mel_lens.max()), mel_seq[0].shape[1])
    padded_codes = torch.zeros(
        size, code_seq[0].shape[0], int(code_lens.max()), dtype=torch.long
    )
    padded_wavs = torch.zeros(size, int(wav_lens.max()))

    for row, (mel, codes, wav) in enumerate(zip(mel_seq, code_seq, wav_seq)):
        padded_mels[row, : mel.shape[0]] = mel
        padded_codes[row, :, : codes.shape[1]] = codes
        padded_wavs[row, : wav.shape[0]] = wav

    return {
        "mel": padded_mels,
        "mel_lens": mel_lens,
        "codes": padded_codes,
        "code_lens": code_lens,
        "wav_16k": padded_wavs,
        "wav_lens": wav_lens,
    }
|
|