# Vocos-style log-mel feature extraction implemented with torch only, plus
# helpers for loading / loudness-normalising a prompt waveform and for loading
# a local Vocos vocoder checkpoint.
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch


@dataclass
class VocosFbankConfig:
    """Static parameters of the log-mel front end used by ``LocalVocosFbank``."""

    sampling_rate: int = 24000  # sample rate (Hz) that extract() requires
    n_mels: int = 100  # number of mel bands per output frame
    n_fft: int = 1024  # FFT size; also used as the window length
    hop_length: int = 256  # STFT hop size in samples


def compute_num_frames(num_samples: int, hop_length: int) -> int:
    """Return the number of feature frames for ``num_samples`` samples.

    The sample count is divided by ``hop_length`` and rounded to the nearest
    integer (a remainder of exactly half a hop rounds up), e.g.
    ``compute_num_frames(24000, 256) == 94``.
    """
    return int((num_samples + hop_length // 2) // hop_length)


class LocalVocosFbank:
    """Compute Vocos-style log-mel features from a waveform.

    Pipeline: Hann-windowed STFT (``center=True``, reflect padding) ->
    magnitude -> projection onto a triangular mel filterbank -> natural log
    with a floor of ``1e-7``.
    """

    def __init__(self) -> None:
        self.config = VocosFbankConfig()
        self.window = torch.hann_window(self.config.n_fft)
        # Shape [n_fft // 2 + 1, n_mels]; built once and reused on every call.
        self.mel_basis = _create_mel_filterbank(
            sample_rate=self.config.sampling_rate,
            n_fft=self.config.n_fft,
            n_mels=self.config.n_mels,
        )

    def extract(self, samples: torch.Tensor, sampling_rate: int) -> torch.Tensor:
        """Return log-mel features of shape ``[num_frames, n_mels]``.

        Args:
            samples: waveform as ``[T]`` or ``[C, T]``. Multi-channel input is
                averaged to mono before analysis.
            sampling_rate: sample rate of ``samples``; must equal
                ``self.config.sampling_rate``.

        Returns:
            A tensor with ``compute_num_frames(T, hop_length)`` rows. Surplus
            STFT frames are dropped; missing frames are filled by repeating
            the last frame.

        Raises:
            ValueError: on a sample-rate mismatch or a tensor that is not 1-D/2-D.
        """
        if sampling_rate != self.config.sampling_rate:
            raise ValueError(
                f"Mismatched sampling rate: expected {self.config.sampling_rate}, got {sampling_rate}"
            )
        if samples.ndim == 1:
            samples = samples.unsqueeze(0)
        if samples.ndim != 2:
            raise ValueError(f"Expected waveform shape [C, T], got {tuple(samples.shape)}")
        # Average every channel (not just exactly two) so the output always has
        # n_mels features per frame; otherwise the reshape below would
        # concatenate per-channel features into [frames, C * n_mels].
        if samples.shape[0] > 1:
            samples = samples.mean(dim=0, keepdim=True)

        # Match the precomputed tensors to the input's device and, for floating
        # inputs, its dtype (e.g. float64), which torch.stft / matmul require.
        dtype = samples.dtype if samples.is_floating_point() else self.window.dtype
        window = self.window.to(device=samples.device, dtype=dtype)
        mel_basis = self.mel_basis.to(device=samples.device, dtype=dtype)

        stft = torch.stft(
            samples,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            win_length=self.config.n_fft,
            window=window,
            center=True,
            pad_mode="reflect",
            return_complex=True,
        )
        spec = stft.abs()  # [1, n_fft // 2 + 1, frames]
        mel = torch.matmul(mel_basis.t(), spec).clamp(min=1e-7).log()  # [1, n_mels, frames]
        mel = mel.reshape(-1, mel.shape[-1]).t()  # [frames, n_mels]

        num_frames = compute_num_frames(samples.shape[1], self.config.hop_length)
        if mel.shape[0] > num_frames:
            mel = mel[:num_frames]
        elif mel.shape[0] < num_frames:
            # Pad along the frame axis by replicating the final frame.
            mel = torch.nn.functional.pad(
                mel.unsqueeze(0),
                (0, 0, 0, num_frames - mel.shape[0]),
                mode="replicate",
            ).squeeze(0)
        return mel


def _hz_to_mel(freq: torch.Tensor) -> torch.Tensor:
    """Convert Hz to mel with the HTK formula ``2595 * log10(1 + f / 700)``."""
    return 2595.0 * torch.log10(1.0 + freq / 700.0)


def _mel_to_hz(mels: torch.Tensor) -> torch.Tensor:
    """Inverse of ``_hz_to_mel``: ``700 * (10 ** (m / 2595) - 1)``."""
    return 700.0 * (torch.pow(10.0, mels / 2595.0) - 1.0)


def _create_mel_filterbank(sample_rate: int, n_fft: int, n_mels: int) -> torch.Tensor:
    """Build an un-normalised triangular mel filterbank on the HTK mel scale.

    Filters span 0 Hz to ``sample_rate // 2``. This follows the construction
    of torchaudio's ``melscale_fbanks`` with ``mel_scale="htk"`` and
    ``norm=None``.

    Returns:
        Tensor of shape ``[n_fft // 2 + 1, n_mels]``.

    Raises:
        ValueError: if any filter has no non-zero weight on the FFT bins.
    """
    n_freqs = n_fft // 2 + 1
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
    m_min = _hz_to_mel(torch.tensor(0.0))
    m_max = _hz_to_mel(torch.tensor(float(sample_rate // 2)))
    # n_mels + 2 edge points: each filter uses (left edge, peak, right edge).
    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    f_pts = _mel_to_hz(m_pts)
    f_diff = f_pts[1:] - f_pts[:-1]
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # [n_freqs, n_mels + 2]
    down_slopes = -slopes[:, :-2] / f_diff[:-1]
    up_slopes = slopes[:, 2:] / f_diff[1:]
    fb = torch.maximum(torch.zeros(1), torch.minimum(down_slopes, up_slopes))
    if (fb.max(dim=0).values == 0.0).any():
        raise ValueError("Mel filterbank has empty filters")
    return fb


def _resample_linear(wav: torch.Tensor, orig_freq: int, new_freq: int) -> torch.Tensor:
    """Resample a ``[C, T]`` waveform by per-channel linear interpolation.

    The first and last input samples are mapped onto the first and last output
    samples. No anti-aliasing low-pass filter is applied, so downsampling can
    alias content above the new Nyquist frequency. The output is a float32
    CPU tensor; when the rates are equal, ``wav`` is returned unchanged.

    Raises:
        ValueError: if ``wav`` has no samples.
    """
    if orig_freq == new_freq:
        return wav
    old_len = wav.shape[-1]
    if old_len == 0:
        raise ValueError("Cannot resample an empty waveform")
    new_len = max(1, int(round(old_len * new_freq / orig_freq)))
    old_pos = np.arange(old_len, dtype=np.float64)
    new_pos = np.linspace(0, old_len - 1, new_len, dtype=np.float64)
    channels = [
        np.interp(new_pos, old_pos, channel).astype(np.float32)
        for channel in wav.detach().cpu().numpy()
    ]
    return torch.from_numpy(np.stack(channels, axis=0))


def load_prompt_wav(prompt_wav: str | Path, sampling_rate: int) -> torch.Tensor:
    """Read an audio file as a float32 ``[C, T]`` tensor at ``sampling_rate``.

    Files stored at a different rate are resampled with ``_resample_linear``.
    """
    # Imported lazily so the feature-extraction code in this module does not
    # require soundfile to be installed.
    import soundfile as sf

    wav_np, sr = sf.read(str(prompt_wav), always_2d=True, dtype="float32")
    # soundfile returns [frames, channels]; transpose to [channels, frames].
    wav = torch.from_numpy(wav_np.T.copy())
    if sr != sampling_rate:
        wav = _resample_linear(wav, orig_freq=sr, new_freq=sampling_rate)
    return wav


def rms_norm(wav: torch.Tensor, target_rms: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Boost ``wav`` up to ``target_rms`` if it is quieter; never attenuate.

    Returns:
        ``(wav, wav_rms)`` where ``wav_rms`` is the RMS measured *before*
        scaling. Silent (zero-RMS) input is returned unscaled instead of being
        divided by zero, which would produce NaNs.
    """
    wav_rms = torch.sqrt(torch.mean(torch.square(wav)))
    if 0 < wav_rms < target_rms:
        wav = wav * target_rms / wav_rms
    return wav, wav_rms


def load_local_vocos(vocoder_dir: str | Path):
    """Build a ``LocalVocos`` and load ``pytorch_model.bin`` from ``vocoder_dir``.

    Only the ``backbone.*`` and ``head.*`` entries of the checkpoint are
    loaded; any other keys are discarded before ``load_state_dict``.
    """
    from scripts.local_vocos import LocalVocos

    vocoder_dir = Path(vocoder_dir)
    vocoder = LocalVocos()
    try:
        state_dict = torch.load(
            str(vocoder_dir / "pytorch_model.bin"),
            weights_only=True,
            map_location="cpu",
        )
    except TypeError:
        # Presumably an older torch that does not accept ``weights_only``.
        state_dict = torch.load(str(vocoder_dir / "pytorch_model.bin"), map_location="cpu")
    state_dict = {
        key: value
        for key, value in state_dict.items()
        if key.startswith(("backbone.", "head."))
    }
    vocoder.load_state_dict(state_dict)
    return vocoder