| """ |
| MelSVCDataset: training data for F5-SVC. |
| |
| Each sample returns: |
| target_mel (T_mel, 100) log-mel spectrogram of the full clip |
| ref_mel (T_mel, 100) first ref_frames unmasked, rest zeroed (F5-TTS style) |
| ppg (T_mel, 1280) Whisper PPG resampled to mel frame rate |
| hubert (T_mel, 256) HuBERT resampled to mel frame rate |
| f0 (T_mel, 1) log-F0 resampled to mel frame rate |
| spk (256,) speaker d-vector |
| ref_len int number of reference (non-zero) frames |
| |
| Expected directory layout (same as v1): |
| data_svc/ |
| audio/<speaker>/<id>.wav |
| whisper/<speaker>/<id>.ppg.npy |
| hubert/<speaker>/<id>.vec.npy |
| pitch/<speaker>/<id>.pit.npy |
| speaker/<speaker>/<id>.spk.npy |
| """ |
|
|
| from __future__ import annotations |
|
|
| import glob |
| import os |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset |
| import torchaudio |
| import torchaudio.transforms as T |
|
|
| |
| SAMPLE_RATE = 24_000 |
| HOP_LENGTH = 256 |
| N_FFT = 1_024 |
| WIN_LENGTH = 1_024 |
| N_MELS = 100 |
| F_MIN = 0 |
| F_MAX = None |
| MEL_FRAME_RATE = SAMPLE_RATE / HOP_LENGTH |
|
|
|
|
def _build_mel_transform(sample_rate: int = SAMPLE_RATE) -> T.MelSpectrogram:
    """Construct the mel-spectrogram transform used for target features.

    All STFT/mel settings come from the module-level constants; ``power=1.0``
    produces a magnitude (not power) spectrogram, which is what the log-mel
    computation in ``MelSVCDataset.__getitem__`` expects.
    """
    mel_kwargs = dict(
        sample_rate=sample_rate,
        n_fft=N_FFT,
        win_length=WIN_LENGTH,
        hop_length=HOP_LENGTH,
        f_min=F_MIN,
        f_max=F_MAX,
        n_mels=N_MELS,
        power=1.0,
        center=True,
    )
    return T.MelSpectrogram(**mel_kwargs)
|
|
|
|
| def _resample_to(seq: torch.Tensor, target_len: int) -> torch.Tensor: |
| """(T, D) → (target_len, D) via linear interpolation.""" |
| if seq.shape[0] == target_len: |
| return seq |
| x = seq.unsqueeze(0).transpose(1, 2) |
| x = F.interpolate(x, size=target_len, mode="linear", align_corners=False) |
| return x.squeeze(0).transpose(0, 1) |
|
|
|
|
class MelSVCDataset(Dataset):
    """Training dataset for F5-SVC mel-spectrogram voice conversion.

    Each item is a 7-tuple ``(mel, ref_mel, ppg, hubert, f0, spk, ref_len)``
    (see the module docstring for shapes). Two loading modes are supported:

    * **packed** (``packed_dir`` given): each sample is one ``.pt`` file
      holding precomputed ``mel``/``ppg``/``hubert``/``f0``/``spk`` tensors.
    * **on-the-fly** (default): the mel is computed from the waveform and
      the conditioning features are read from per-speaker ``.npy`` files.
    """

    def __init__(
        self,
        audio_dir: str = "./data_svc/audio",
        ppg_dir: str = "./data_svc/whisper",
        hubert_dir: str = "./data_svc/hubert",
        f0_dir: str = "./data_svc/pitch",
        spk_dir: str = "./data_svc/speaker",
        packed_dir: str | None = None,
        max_frames: int = 800,
        ref_frames: int = 280,
        sample_rate: int = SAMPLE_RATE,
        strict: bool = True,
    ):
        """
        Args:
            audio_dir / ppg_dir / hubert_dir / f0_dir / spk_dir: feature
                roots laid out as ``<dir>/<speaker>/<id>.<ext>`` (see the
                module docstring).
            packed_dir: if given, load preprocessed ``.pt`` samples from
                this directory instead of the per-feature directories.
            max_frames: random-crop length (in mel frames) for long clips.
            ref_frames: number of leading frames kept visible in ``ref_mel``.
            sample_rate: target waveform sample rate for mel extraction.
            strict: skip samples missing any feature file at scan time;
                when False, missing files are zero-filled at load time.

        Raises:
            RuntimeError: if no audio/.pt files or no valid samples are found.
        """
        self.max_frames = max_frames
        self.ref_frames = ref_frames
        self.sample_rate = sample_rate
        self.packed = packed_dir is not None

        if self.packed:
            pt_files = glob.glob(os.path.join(packed_dir, "**", "*.pt"), recursive=True)
            if not pt_files:
                raise RuntimeError(f"No .pt files found under {packed_dir}. "
                                   f"Run: python prepare/preprocess_pack.py -w <wav_dir> -o {packed_dir}")
            self.samples: list[dict] = [{"packed": p} for p in pt_files]
            print(f"MelSVCDataset (packed): {len(self.samples)} samples from {packed_dir}")
        else:
            self.mel_tf = _build_mel_transform(sample_rate)
            wav_files = (
                glob.glob(os.path.join(audio_dir, "**", "*.wav"), recursive=True)
                + glob.glob(os.path.join(audio_dir, "**", "*.flac"), recursive=True)
            )
            if not wav_files:
                raise RuntimeError(f"No .wav/.flac files found under {audio_dir}")

            self.samples = []
            skipped = 0
            for wav_path in wav_files:
                file_id = os.path.splitext(os.path.basename(wav_path))[0]
                # Speaker name is the immediate parent directory of the wav.
                spk_name = os.path.basename(os.path.dirname(wav_path))

                ppg_path = os.path.join(ppg_dir, spk_name, f"{file_id}.ppg.npy")
                hbt_path = os.path.join(hubert_dir, spk_name, f"{file_id}.vec.npy")
                f0_path = os.path.join(f0_dir, spk_name, f"{file_id}.pit.npy")
                spk_path = os.path.join(spk_dir, spk_name, f"{file_id}.spk.npy")

                # In strict mode a sample is usable only when every feature
                # file exists; otherwise __getitem__ zero-fills the missing
                # features at load time.
                if strict and any(not os.path.isfile(p) for p in [ppg_path, hbt_path, f0_path, spk_path]):
                    skipped += 1
                    continue

                self.samples.append(dict(wav=wav_path, ppg=ppg_path, hubert=hbt_path,
                                         f0=f0_path, spk=spk_path))

            if not self.samples:
                raise RuntimeError(f"No valid samples (skipped={skipped}). Check directory layout.")
            print(f"MelSVCDataset: {len(self.samples)} samples (skipped={skipped})")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load one sample; see the class docstring for the returned tuple."""
        s = self.samples[idx]

        if self.packed:
            data = torch.load(s["packed"], weights_only=True)
            mel = data["mel"]
            ppg = data["ppg"]
            hubert = data["hubert"]
            f0 = data["f0"].unsqueeze(-1)  # (T,) -> (T, 1)
            spk = data["spk"]
        else:
            wav, sr = torchaudio.load(s["wav"])
            if wav.shape[0] > 1:  # downmix multi-channel audio to mono
                wav = wav.mean(dim=0, keepdim=True)
            if sr != self.sample_rate:
                wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
            # Log-magnitude mel clamped away from log(0): (T_mel, N_MELS).
            mel = torch.log(self.mel_tf(wav).clamp(min=1e-5)).squeeze(0).T

            # Best-effort loading: a missing/corrupt .npy becomes zeros
            # instead of aborting the whole epoch.
            try:
                ppg = torch.tensor(np.load(s["ppg"])).float()
            except Exception:
                ppg = torch.zeros(mel.shape[0], 1280)
            try:
                hubert = torch.tensor(np.load(s["hubert"])).float()
            except Exception:
                hubert = torch.zeros(mel.shape[0], 256)
            try:
                f0_raw = torch.tensor(np.load(s["f0"])).float()
                # log-F0 on voiced frames; 0 on unvoiced (f0 <= 0) frames.
                f0 = torch.where(f0_raw > 0,
                                 torch.log(f0_raw.clamp(min=1.0)),
                                 torch.zeros_like(f0_raw)).unsqueeze(-1)
            except Exception:
                f0 = torch.zeros(mel.shape[0], 1)
            try:
                spk = torch.tensor(np.load(s["spk"])).float()
            except Exception:
                spk = torch.zeros(256)

        # Align all conditioning features to the mel frame count.
        t_mel = mel.shape[0]
        ppg = _resample_to(ppg, t_mel)
        hubert = _resample_to(hubert, t_mel)
        f0 = _resample_to(f0, t_mel)

        # Random crop of long clips. randint's upper bound is exclusive, so
        # "+ 1" makes the last valid window start (t_mel - max_frames)
        # reachable; without it that final window was never sampled.
        if t_mel > self.max_frames:
            start = torch.randint(0, t_mel - self.max_frames + 1, (1,)).item()
            mel = mel[start: start + self.max_frames]
            ppg = ppg[start: start + self.max_frames]
            hubert = hubert[start: start + self.max_frames]
            f0 = f0[start: start + self.max_frames]
            t_mel = self.max_frames

        # F5-TTS style reference: keep the first ref_len frames, zero the rest.
        ref_len = min(self.ref_frames, t_mel)
        ref_mel = torch.zeros_like(mel)
        ref_mel[:ref_len] = mel[:ref_len]

        return mel, ref_mel, ppg, hubert, f0, spk, ref_len
|
|
|
|
def collate_fn(batch):
    """Pad a list of variable-length samples into batched tensors.

    Args:
        batch: list of ``(mel, ref_mel, ppg, hubert, f0, spk, ref_len)``
            tuples as produced by ``MelSVCDataset.__getitem__``; the first
            five entries are (T_i, D) tensors with per-sample lengths T_i.

    Returns:
        ``(mel, ref_mel, ppg, hubert, f0, spk, mask, ref_lens)`` where the
        sequence tensors are zero-padded to the batch's max length, ``mask``
        is True on valid frames, ``spk`` is stacked to (B, 256) and
        ``ref_lens`` is a (B,) long tensor.
    """
    mels, ref_mels, ppgs, huberts, f0s, spks, ref_lens = zip(*batch)

    lengths = [m.shape[0] for m in mels]
    max_len = max(lengths)
    bsz = len(batch)

    # Infer every feature dim from the batch itself (previously the mel dim
    # was hard-coded to the module constant N_MELS while ppg/hubert were
    # inferred) — consistent, and works for any mel dimensionality.
    n_mels = mels[0].shape[1]
    mel_padded = torch.zeros(bsz, max_len, n_mels)
    ref_padded = torch.zeros(bsz, max_len, n_mels)
    ppg_padded = torch.zeros(bsz, max_len, ppgs[0].shape[1])
    hbt_padded = torch.zeros(bsz, max_len, huberts[0].shape[1])
    f0_padded = torch.zeros(bsz, max_len, 1)
    mask = torch.zeros(bsz, max_len, dtype=torch.bool)

    for i, ln in enumerate(lengths):
        mel_padded[i, :ln] = mels[i]
        ref_padded[i, :ln] = ref_mels[i]
        ppg_padded[i, :ln] = ppgs[i]
        hbt_padded[i, :ln] = huberts[i]
        f0_padded[i, :ln] = f0s[i]
        mask[i, :ln] = True

    return (
        mel_padded, ref_padded,
        ppg_padded, hbt_padded, f0_padded,
        torch.stack(spks),
        mask,
        torch.tensor(list(ref_lens), dtype=torch.long),
    )
|
|