sunf / flow_matching /dataset.py
anhtunguyen98's picture
Upload folder using huggingface_hub
4698bfc verified
"""
Dataset for CosyVoice-style TTS flow matching.
Each training sample comes from a HuggingFace dataset (saved to disk) with:
file_path (str) – path to audio file
text (str) – transcription (unused here)
codes (list) – SEMACS VQ codes at 12.5 Hz
For each sample we produce:
mel (T_mel, n_mels) log-mel spectrogram at 24 kHz (target for flow matching)
codes (num_q, T_codes) VQ conditioning codes
wav_16k (T_wav,) raw 16 kHz waveform for on-the-fly speaker extraction
Speaker embeddings are extracted in the training loop on GPU (batched via CAM++ fbank →
campplus forward), so no precomputation step is needed.
Mel spectrogram parameters default to the Vocos 24 kHz configuration:
sample_rate = 24000
n_mels = 100
n_fft = 1024
hop_length = 256 → frame rate ≈ 93.75 Hz
f_max = 12000
"""
from typing import Dict, List, Union
import numpy as np
import torch
import torchaudio
from datasets import concatenate_datasets, load_from_disk
from torch.utils.data import Dataset
class TTSDataset(Dataset):
    """
    VQ codes + mel spectrogram dataset for flow-matching TTS training.

    Speaker embeddings are NOT extracted here; the raw 16 kHz waveform is
    returned so the training loop can run the speaker encoder itself.

    Args:
        dataset_path: HuggingFace dataset directory (loaded with
            ``load_from_disk``), or a list/tuple of such directories, which
            are concatenated into one dataset.
        sample_rate: Target sample rate for the mel spectrogram.
        n_mels: Number of mel bins.
        n_fft: FFT size.
        hop_length: Hop size in samples.
        f_min: Lower edge of the mel filterbank (Hz).
        f_max: Upper edge of the mel filterbank (Hz).
        max_duration: Mel spectrograms are clipped to this many seconds.
        max_wav_16k: Maximum number of 16 kHz samples kept for speaker
            extraction (160_000 samples = 10 s).
    """

    # Number of consecutive samples tried before ``__getitem__`` gives up.
    _MAX_LOAD_ATTEMPTS = 32

    def __init__(
        self,
        dataset_path: Union[str, List[str]],
        sample_rate: int = 24000,
        n_mels: int = 100,
        n_fft: int = 1024,
        hop_length: int = 256,
        f_min: float = 0.0,
        f_max: float = 12000.0,
        max_duration: float = 30.0,
        max_wav_16k: int = 160_000,  # 10 s at 16 kHz
    ):
        if isinstance(dataset_path, (list, tuple)):
            datasets = [load_from_disk(p) for p in dataset_path]
            self.dataset = concatenate_datasets(datasets)
            print(f"[TTSDataset] Loaded {len(dataset_path)} datasets, "
                  f"{len(self.dataset):,} total samples")
        else:
            self.dataset = load_from_disk(dataset_path)
            print(f"[TTSDataset] Loaded {len(self.dataset):,} samples from {dataset_path}")

        self.sample_rate = sample_rate
        self.hop_length = hop_length
        # Longest mel kept, in frames: max_duration seconds at the mel frame rate.
        self.max_frames = int(max_duration * sample_rate / hop_length)
        self.max_wav_16k = max_wav_16k

        # Must match Vocos `charactr/vocos-mel-24khz` MelSpectrogramFeatures:
        # power=1 (amplitude, not power), default norm=None, default mel_scale="htk"
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            f_min=f_min,
            f_max=f_max,
            power=1,
        )

    def __len__(self) -> int:
        """Return the number of samples in the underlying dataset."""
        return len(self.dataset)

    def _load_with_retry(self, idx: int, n_items: int):
        """Load audio for ``idx``, walking forward past unreadable files.

        Returns ``(item, wav, sr)`` for the first sample whose audio loads.
        The returned ``item`` is the row that actually loaded, so codes and
        audio always come from the same sample.

        Raises:
            RuntimeError: if ``_MAX_LOAD_ATTEMPTS`` consecutive loads fail;
                the last underlying error is attached as ``__cause__``.
        """
        last_err = None
        for attempt in range(self._MAX_LOAD_ATTEMPTS):
            i = (idx + attempt) % n_items  # wrap around at the end of the dataset
            item = self.dataset[i]
            file_path = item["file_path"]
            try:
                wav, sr = torchaudio.load(file_path)
            except Exception as e:  # deliberate best-effort: any load error means "skip"
                last_err = e
                # Only the first failure is reported, to avoid flooding the log.
                if attempt == 0:
                    print(f"[TTSDataset] load failed at idx={i}: {file_path} ({type(e).__name__}: {e}); skipping")
                continue
            return item, wav, sr
        raise RuntimeError(
            f"TTSDataset: {self._MAX_LOAD_ATTEMPTS} consecutive load failures starting at idx={idx}"
        ) from last_err

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one training sample.

        Returns a dict with:
            mel      (T_mel, n_mels) log-mel spectrogram, clipped to ``max_frames``
            codes    int64 VQ codes; a 1-D code list becomes (1, T_codes),
                     higher-rank input is kept as stored
            wav_16k  (T_wav,) mono 16 kHz waveform, clipped to ``max_wav_16k``

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: if no audio could be loaded after the retry budget.
        """
        n_items = len(self)
        # Out-of-range indices must raise IndexError (not wrap via the modulo in
        # the retry loop): Python's sequence-iteration fallback relies on it to
        # terminate, and it also covers the empty-dataset case.
        if not -n_items <= idx < n_items:
            raise IndexError(
                f"TTSDataset index {idx} out of range for dataset of size {n_items}"
            )

        item, wav, sr = self._load_with_retry(idx, n_items)
        wav_mono = wav.mean(0)  # average channels -> (T,)

        # --- 16 kHz waveform for speaker extraction (clipped to max_wav_16k) ---
        wav_16k = torchaudio.functional.resample(wav_mono, sr, 16000) \
            if sr != 16000 else wav_mono
        wav_16k = wav_16k[: self.max_wav_16k]

        # --- Mel spectrogram at target sample rate ---
        wav_target = torchaudio.functional.resample(wav_mono, sr, self.sample_rate) \
            if sr != self.sample_rate else wav_mono
        mel = self.mel_transform(wav_target)            # (n_mels, T_mel)
        mel = torch.log(mel.clamp(min=1e-7))            # Vocos safe_log clip_val
        mel = mel[:, : self.max_frames].T.contiguous()  # (T_mel, n_mels)

        # --- VQ codes ---
        # NOTE(review): mel is clipped to max_frames but codes are returned at
        # full length; confirm the training loop handles the length mismatch
        # for clips longer than max_duration.
        codes = np.array(item["codes"], dtype=np.int64)
        if codes.ndim == 1:
            codes = codes[np.newaxis, :]                # (1, T_codes)

        return {
            "mel": mel,                         # (T_mel, n_mels)
            "codes": torch.from_numpy(codes),   # (num_q, T_codes)
            "wav_16k": wav_16k,                 # (T_wav,) for speaker encoder
        }
# ── Collate ───────────────────────────────────────────────────────────────────
def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Zero-pad the variable-length tensors of a batch to a common length.

    Each element of ``batch`` must provide ``"mel"`` (T_mel, n_mels),
    ``"codes"`` (num_q, T_codes) and ``"wav_16k"`` (T_wav,). The result holds
    the right-padded stacks together with the original lengths so that the
    padding can be masked out later.
    """
    batch_size = len(batch)
    mel_list = [sample["mel"] for sample in batch]
    code_list = [sample["codes"] for sample in batch]
    wav_list = [sample["wav_16k"] for sample in batch]

    # Mel: pad along time (dim 0 of each item).
    mel_lens = torch.tensor([mel.shape[0] for mel in mel_list])
    padded_mels = torch.zeros(batch_size, int(mel_lens.max()), mel_list[0].shape[1])
    for row, mel in zip(padded_mels, mel_list):
        row[: mel.shape[0]] = mel  # row is a view, so this writes into padded_mels

    # Codes: pad along time (dim 1 of each item); quantizer count taken from item 0.
    num_q = code_list[0].shape[0]
    code_lens = torch.tensor([code.shape[1] for code in code_list])
    padded_codes = torch.zeros(
        batch_size, num_q, int(code_lens.max()), dtype=torch.long
    )
    for row, code in zip(padded_codes, code_list):
        row[:, : code.shape[1]] = code

    # 16 kHz waveform: padded so the speaker front-end can run on the whole batch.
    wav_lens = torch.tensor([wav.shape[0] for wav in wav_list])
    padded_wavs = torch.zeros(batch_size, int(wav_lens.max()))
    for row, wav in zip(padded_wavs, wav_list):
        row[: wav.shape[0]] = wav

    return {
        "mel": padded_mels,        # (B, T_mel, n_mels)
        "mel_lens": mel_lens,      # (B,)
        "codes": padded_codes,     # (B, num_q, T_codes)
        "code_lens": code_lens,    # (B,)
        "wav_16k": padded_wavs,    # (B, T_wav)
        "wav_lens": wav_lens,      # (B,) valid samples in wav_16k
    }