# cfm_svc/svc_data/mel_svc_dataset.py
# Author: Hector Li — "Initial commit for Hugging Face" (commit df93d13)
"""
MelSVCDataset: training data for F5-SVC.
Each sample returns:
target_mel (T_mel, 100) log-mel spectrogram of the full clip
ref_mel (T_mel, 100) first ref_frames unmasked, rest zeroed (F5-TTS style)
ppg (T_mel, 1280) Whisper PPG resampled to mel frame rate
hubert (T_mel, 256) HuBERT resampled to mel frame rate
f0 (T_mel, 1) log-F0 resampled to mel frame rate
spk (256,) speaker d-vector
ref_len int number of reference (non-zero) frames
Expected directory layout (same as v1):
data_svc/
audio/<speaker>/<id>.wav
whisper/<speaker>/<id>.ppg.npy
hubert/<speaker>/<id>.vec.npy
pitch/<speaker>/<id>.pit.npy
speaker/<speaker>/<id>.spk.npy
"""
from __future__ import annotations
import glob
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchaudio
import torchaudio.transforms as T
# Mel parameters — must match Vocos (charactr/vocos-mel-24khz) exactly,
# otherwise the trained mels will not invert cleanly through the vocoder.
SAMPLE_RATE = 24_000  # audio sample rate (Hz)
HOP_LENGTH = 256  # STFT hop in samples -> one mel frame per 256 samples
N_FFT = 1_024  # FFT size
WIN_LENGTH = 1_024  # analysis window length (same as N_FFT)
N_MELS = 100  # mel bins per frame
F_MIN = 0  # lowest mel filterbank frequency (Hz)
F_MAX = None # = sr/2 = 12 kHz (Vocos default)
MEL_FRAME_RATE = SAMPLE_RATE / HOP_LENGTH # 93.75 Hz
def _build_mel_transform(sample_rate: int = SAMPLE_RATE) -> T.MelSpectrogram:
    """Build the torchaudio mel-spectrogram transform matched to Vocos.

    All parameters come from the module-level constants above; only the
    sample rate is overridable. Returns an amplitude (power=1.0), centered
    mel transform with torchaudio's default (htk) mel scale, which is what
    charactr/vocos-mel-24khz expects.
    """
    mel_kwargs = dict(
        sample_rate=sample_rate,
        n_fft=N_FFT,
        win_length=WIN_LENGTH,
        hop_length=HOP_LENGTH,
        f_min=F_MIN,
        f_max=F_MAX,
        n_mels=N_MELS,
        power=1.0,    # amplitude (matches Vocos)
        center=True,  # matches Vocos
        # No norm, no mel_scale — PyTorch defaults (htk) match Vocos.
    )
    return T.MelSpectrogram(**mel_kwargs)
def _resample_to(seq: torch.Tensor, target_len: int) -> torch.Tensor:
"""(T, D) → (target_len, D) via linear interpolation."""
if seq.shape[0] == target_len:
return seq
x = seq.unsqueeze(0).transpose(1, 2)
x = F.interpolate(x, size=target_len, mode="linear", align_corners=False)
return x.squeeze(0).transpose(0, 1)
class MelSVCDataset(Dataset):
    """Training dataset for F5-SVC conditional flow matching.

    Two loading modes:
      * packed:   one ``.pt`` file per sample with mel/ppg/hubert/f0/spk
                  precomputed — a single (NFS-efficient) read per item.
      * unpacked: load the wav plus four ``.npy`` feature files and compute
                  the log-mel spectrogram on the fly.

    Each item is ``(mel, ref_mel, ppg, hubert, f0, spk, ref_len)``; see the
    module docstring for shapes. All per-frame features are resampled to the
    mel frame rate before cropping so a single crop index applies to all.

    Fix vs. v1: the random crop used ``torch.randint(0, t_mel - max_frames)``
    whose exclusive upper bound made the final crop window unreachable; the
    bound is now ``t_mel - max_frames + 1``.
    """

    def __init__(
        self,
        audio_dir: str = "./data_svc/audio",
        ppg_dir: str = "./data_svc/whisper",
        hubert_dir: str = "./data_svc/hubert",
        f0_dir: str = "./data_svc/pitch",
        spk_dir: str = "./data_svc/speaker",
        packed_dir: str | None = None,  # if set, load from packed .pt files (NFS-efficient)
        max_frames: int = 800,  # ~8.5 sec at 93.75 Hz
        ref_frames: int = 280,  # ~3 sec reference region
        sample_rate: int = SAMPLE_RATE,
        strict: bool = True,
    ):
        """Index the dataset.

        Args:
            audio_dir/ppg_dir/hubert_dir/f0_dir/spk_dir: per-feature roots
                laid out as ``<root>/<speaker>/<id>.<ext>`` (unpacked mode).
            packed_dir: if given, recursively glob ``*.pt`` files and ignore
                the per-feature directories entirely.
            max_frames: crop length in mel frames.
            ref_frames: length of the unmasked reference region in frames.
            sample_rate: target audio sample rate; wavs are resampled to it.
            strict: if True, skip samples missing any of the four feature
                files; if False, keep them (missing features load as zeros).

        Raises:
            RuntimeError: if no files (or no complete samples) are found.
        """
        self.max_frames = max_frames
        self.ref_frames = ref_frames
        self.sample_rate = sample_rate
        self.packed = packed_dir is not None
        if self.packed:
            # One .pt file per sample — single NFS read, mel already computed
            pt_files = glob.glob(os.path.join(packed_dir, "**", "*.pt"), recursive=True)
            if not pt_files:
                raise RuntimeError(f"No .pt files found under {packed_dir}. "
                                   f"Run: python prepare/preprocess_pack.py -w <wav_dir> -o {packed_dir}")
            self.samples: list[dict] = [{"packed": p} for p in pt_files]
            print(f"MelSVCDataset (packed): {len(self.samples)} samples from {packed_dir}")
        else:
            self.mel_tf = _build_mel_transform(sample_rate)
            wav_files = (
                glob.glob(os.path.join(audio_dir, "**", "*.wav"), recursive=True)
                + glob.glob(os.path.join(audio_dir, "**", "*.flac"), recursive=True)
            )
            if not wav_files:
                raise RuntimeError(f"No .wav/.flac files found under {audio_dir}")
            self.samples = []
            skipped = 0
            for wav_path in wav_files:
                # Speaker name is the parent directory; file id is the stem.
                file_id = os.path.splitext(os.path.basename(wav_path))[0]
                spk_name = os.path.basename(os.path.dirname(wav_path))
                ppg_path = os.path.join(ppg_dir, spk_name, f"{file_id}.ppg.npy")
                hbt_path = os.path.join(hubert_dir, spk_name, f"{file_id}.vec.npy")
                f0_path = os.path.join(f0_dir, spk_name, f"{file_id}.pit.npy")
                spk_path = os.path.join(spk_dir, spk_name, f"{file_id}.spk.npy")
                if strict and any(not os.path.isfile(p) for p in [ppg_path, hbt_path, f0_path, spk_path]):
                    skipped += 1
                    continue
                self.samples.append(dict(wav=wav_path, ppg=ppg_path, hubert=hbt_path,
                                         f0=f0_path, spk=spk_path))
            if not self.samples:
                raise RuntimeError(f"No valid samples (skipped={skipped}). Check directory layout.")
            print(f"MelSVCDataset: {len(self.samples)} samples (skipped={skipped})")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return (mel, ref_mel, ppg, hubert, f0, spk, ref_len) for sample idx."""
        s = self.samples[idx]
        if self.packed:
            # --- Packed path: single load, mel already computed ---
            data = torch.load(s["packed"], weights_only=True)
            mel = data["mel"]              # (T_mel, N_MELS)
            ppg = data["ppg"]              # (T_feat, 1280)
            hubert = data["hubert"]        # (T_feat, 256)
            f0 = data["f0"].unsqueeze(-1)  # (T_feat, 1)
            spk = data["spk"]              # (256,)
        else:
            # --- Unpacked path: load wav + 4 npy files, compute mel ---
            wav, sr = torchaudio.load(s["wav"])
            if wav.shape[0] > 1:  # downmix multi-channel audio to mono
                wav = wav.mean(dim=0, keepdim=True)
            if sr != self.sample_rate:
                wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
            # Log-amplitude mel; clamp avoids log(0) on silent frames.
            mel = torch.log(self.mel_tf(wav).clamp(min=1e-5)).squeeze(0).T  # (T_mel, N_MELS)
            # Missing/corrupt feature files degrade to zeros (best-effort;
            # with strict=True in __init__ these should all exist).
            try:
                ppg = torch.tensor(np.load(s["ppg"])).float()
            except Exception:
                ppg = torch.zeros(mel.shape[0], 1280)
            try:
                hubert = torch.tensor(np.load(s["hubert"])).float()
            except Exception:
                hubert = torch.zeros(mel.shape[0], 256)
            try:
                f0_raw = torch.tensor(np.load(s["f0"])).float()
                # log-F0 on voiced frames (f0 > 0), exact 0 on unvoiced.
                f0 = torch.where(f0_raw > 0,
                                 torch.log(f0_raw.clamp(min=1.0)),
                                 torch.zeros_like(f0_raw)).unsqueeze(-1)
            except Exception:
                f0 = torch.zeros(mel.shape[0], 1)
            try:
                spk = torch.tensor(np.load(s["spk"])).float()
            except Exception:
                spk = torch.zeros(256)
        # Resample features from their native rate (50 Hz) to mel frame rate
        # (93.75 Hz) BEFORE cropping, so crop indices are consistent across
        # mel and all features.
        t_mel = mel.shape[0]
        ppg = _resample_to(ppg, t_mel)
        hubert = _resample_to(hubert, t_mel)
        f0 = _resample_to(f0, t_mel)
        # Random crop to max_frames (all tensors are now at mel frame rate).
        if t_mel > self.max_frames:
            # +1 because torch.randint's upper bound is exclusive; without it
            # the final window [t_mel - max_frames, t_mel) is never sampled.
            start = torch.randint(0, t_mel - self.max_frames + 1, (1,)).item()
            mel = mel[start: start + self.max_frames]
            ppg = ppg[start: start + self.max_frames]
            hubert = hubert[start: start + self.max_frames]
            f0 = f0[start: start + self.max_frames]
            t_mel = self.max_frames
        # Reference region (F5-TTS inpainting convention): the first ref_len
        # frames are kept as the unmasked prompt, the rest are zeroed.
        ref_len = min(self.ref_frames, t_mel)
        ref_mel = torch.zeros_like(mel)
        ref_mel[:ref_len] = mel[:ref_len]
        return mel, ref_mel, ppg, hubert, f0, spk, ref_len
def collate_fn(batch):
    """Zero-pad a list of MelSVCDataset samples into batch tensors.

    Args:
        batch: list of ``(mel, ref_mel, ppg, hubert, f0, spk, ref_len)``
            tuples where every per-frame tensor is ``(T_i, D)``, ``spk`` is
            a fixed-size vector and ``ref_len`` is an int.

    Returns:
        ``(mel, ref_mel, ppg, hubert, f0, spk, mask, ref_lens)`` where the
        per-frame tensors are zero-padded to ``(B, T_max, D)``, ``mask`` is
        a ``(B, T_max)`` bool tensor marking valid frames, ``spk`` is
        ``(B, D_spk)`` and ``ref_lens`` is a ``(B,)`` long tensor.
    """
    mels, ref_mels, ppgs, huberts, f0s, spks, ref_lens = zip(*batch)
    lengths = [m.shape[0] for m in mels]
    max_len = max(lengths)
    bsz = len(batch)
    # Derive every feature width from the tensors themselves (the original
    # hardcoded N_MELS for mel/ref only) — consistent across features and
    # correct for any mel configuration.
    mel_padded = torch.zeros(bsz, max_len, mels[0].shape[1])
    ref_padded = torch.zeros(bsz, max_len, ref_mels[0].shape[1])
    ppg_padded = torch.zeros(bsz, max_len, ppgs[0].shape[1])
    hbt_padded = torch.zeros(bsz, max_len, huberts[0].shape[1])
    f0_padded = torch.zeros(bsz, max_len, 1)
    mask = torch.zeros(bsz, max_len, dtype=torch.bool)
    for i, ln in enumerate(lengths):
        mel_padded[i, :ln] = mels[i]
        ref_padded[i, :ln] = ref_mels[i]
        ppg_padded[i, :ln] = ppgs[i]
        hbt_padded[i, :ln] = huberts[i]
        f0_padded[i, :ln] = f0s[i]
        mask[i, :ln] = True
    return (
        mel_padded, ref_padded,
        ppg_padded, hbt_padded, f0_padded,
        torch.stack(spks),
        mask,
        torch.tensor(list(ref_lens), dtype=torch.long),
    )