| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import numpy as np |
| import soundfile as sf |
| import torch |
|
|
|
|
@dataclass
class VocosFbankConfig:
    """Settings for the log-mel front end built by ``LocalVocosFbank``."""

    # Sampling rate the extractor accepts; other rates are rejected by extract().
    sampling_rate: int = 24000
    # Number of mel filters, i.e. features produced per frame and channel.
    n_mels: int = 100
    # FFT size; also used as the STFT window length.
    n_fft: int = 1024
    # Step between consecutive STFT frames, in samples.
    hop_length: int = 256
|
|
|
|
def compute_num_frames(num_samples: int, hop_length: int) -> int:
    """Return the number of hops in ``num_samples``, rounded to the nearest hop.

    Half a hop is added before the floor division, so a trailing remainder of
    at least ``hop_length // 2`` samples counts as one more frame.
    """
    half_hop = hop_length // 2
    padded_length = num_samples + half_hop
    return int(padded_length // hop_length)
|
|
|
|
class LocalVocosFbank:
    """Log-mel spectrogram extractor built on ``torch.stft`` and a mel filterbank.

    The Hann window and the mel filterbank are created once at construction
    time from a default ``VocosFbankConfig``; ``extract`` moves them to the
    device of the incoming waveform on every call.
    """

    def __init__(self) -> None:
        cfg = VocosFbankConfig()
        self.config = cfg
        self.window = torch.hann_window(cfg.n_fft)
        self.mel_basis = _create_mel_filterbank(
            sample_rate=cfg.sampling_rate,
            n_fft=cfg.n_fft,
            n_mels=cfg.n_mels,
        )

    def extract(self, samples: torch.Tensor, sampling_rate: int) -> torch.Tensor:
        """Compute log-mel features for a waveform.

        Args:
            samples: Waveform of shape ``[T]`` or ``[C, T]``. A two-channel
                input is averaged down to a single channel first.
            sampling_rate: Sampling rate of ``samples``; must equal
                ``self.config.sampling_rate``.

        Returns:
            A tensor with one row per frame. For one- or two-channel input the
            row width is ``n_mels``; for any other channel count the
            per-channel features are laid side by side along the last axis.

        Raises:
            ValueError: If the sampling rate does not match the configuration
                or the waveform is neither 1-D nor 2-D.
        """
        cfg = self.config
        if sampling_rate != cfg.sampling_rate:
            raise ValueError(
                f"Mismatched sampling rate: expected {self.config.sampling_rate}, got {sampling_rate}"
            )

        # Promote a bare [T] waveform to [1, T]; anything else must already be 2-D.
        waveform = samples.unsqueeze(0) if samples.ndim == 1 else samples
        if waveform.ndim != 2:
            raise ValueError(f"Expected waveform shape [C, T], got {tuple(samples.shape)}")

        # Stereo is down-mixed to mono by averaging the two channels.
        if waveform.shape[0] == 2:
            waveform = waveform.mean(dim=0, keepdim=True)

        device = waveform.device
        spectrum = torch.stft(
            waveform,
            n_fft=cfg.n_fft,
            hop_length=cfg.hop_length,
            win_length=cfg.n_fft,
            window=self.window.to(device),
            center=True,
            pad_mode="reflect",
            return_complex=True,
        )
        magnitude = torch.abs(spectrum)

        # Project magnitudes onto the mel filters, floor at 1e-7, take the natural log.
        projection = self.mel_basis.to(device).t()
        log_mel = torch.log(torch.clamp(torch.matmul(projection, magnitude), min=1e-7))

        # [C, n_mels, frames] -> [frames, C * n_mels]
        frames = log_mel.reshape(-1, log_mel.shape[-1]).t()

        target = compute_num_frames(waveform.shape[1], cfg.hop_length)
        deficit = target - frames.shape[0]
        if deficit < 0:
            # More frames than expected: drop the surplus at the end.
            frames = frames[:target]
        elif deficit > 0:
            # Fewer frames than expected: repeat the final frame to fill the gap.
            frames = torch.nn.functional.pad(
                frames.unsqueeze(0),
                (0, 0, 0, deficit),
                mode="replicate",
            ).squeeze(0)
        return frames
|
|
|
|
def _hz_to_mel(freq: torch.Tensor) -> torch.Tensor:
    """Convert frequencies in Hz to mels using ``2595 * log10(1 + f / 700)``."""
    scaled = freq / 700.0 + 1.0
    return torch.log10(scaled) * 2595.0
|
|
|
|
def _mel_to_hz(mels: torch.Tensor) -> torch.Tensor:
    """Convert mels back to Hz using ``700 * (10 ** (m / 2595) - 1)``."""
    exponent = mels / 2595.0
    return (torch.pow(10.0, exponent) - 1.0) * 700.0
|
|
|
|
def _create_mel_filterbank(sample_rate: int, n_fft: int, n_mels: int) -> torch.Tensor:
    """Build a bank of triangular mel filters.

    Filter edges are spaced evenly on the mel scale between 0 Hz and
    ``sample_rate // 2`` and evaluated at ``n_fft // 2 + 1`` linearly spaced
    bin frequencies over the same range.

    Returns:
        A tensor of shape ``(n_fft // 2 + 1, n_mels)`` holding one filter per
        column.

    Raises:
        ValueError: If the peak weight of any filter is exactly zero.
    """
    upper_hz = sample_rate // 2
    bin_freqs = torch.linspace(0, upper_hz, n_fft // 2 + 1)

    # n_mels + 2 edge points: each filter uses three consecutive edges.
    mel_low = _hz_to_mel(torch.tensor(0.0))
    mel_high = _hz_to_mel(torch.tensor(float(upper_hz)))
    edges_hz = _mel_to_hz(torch.linspace(mel_low, mel_high, n_mels + 2))

    spacing = edges_hz[1:] - edges_hz[:-1]
    # offsets[i, j] = edge j minus bin frequency i
    offsets = edges_hz.unsqueeze(0) - bin_freqs.unsqueeze(1)
    # Ramp measured from each filter's left edge, scaled by the left spacing.
    left_ramp = -offsets[:, :-2] / spacing[:-1]
    # Ramp measured from each filter's right edge, scaled by the right spacing.
    right_ramp = offsets[:, 2:] / spacing[1:]
    # The triangle is the lower of the two ramps, clipped at zero.
    filterbank = torch.maximum(torch.zeros(1), torch.minimum(left_ramp, right_ramp))

    peak_per_filter = filterbank.max(dim=0).values
    if torch.any(peak_per_filter == 0.0):
        raise ValueError("Mel filterbank has empty filters")
    return filterbank
|
|
|
|
def _resample_linear(wav: torch.Tensor, orig_freq: int, new_freq: int) -> torch.Tensor:
    """Resample each row of ``wav`` by linear interpolation.

    When the two rates are equal the input is returned untouched. Otherwise
    every row is interpolated with ``numpy.interp`` onto an evenly spaced grid
    spanning the original first and last sample, and the result comes back as
    a new float32 CPU tensor.
    """
    if orig_freq == new_freq:
        return wav

    n_in = wav.shape[-1]
    # Always keep at least one output sample.
    n_out = max(1, int(round(n_in * new_freq / orig_freq)))
    src_grid = np.arange(n_in, dtype=np.float64)
    dst_grid = np.linspace(0, n_in - 1, n_out, dtype=np.float64)

    resampled_rows = [
        np.interp(dst_grid, src_grid, row).astype(np.float32)
        for row in wav.cpu().numpy()
    ]
    return torch.from_numpy(np.stack(resampled_rows, axis=0))
|
|
|
|
def load_prompt_wav(prompt_wav: str | Path, sampling_rate: int) -> torch.Tensor:
    """Read an audio file as a float32 tensor laid out channels-first.

    The file is read with ``soundfile`` as a 2-D array and transposed so the
    first axis indexes channels. If the file's own rate differs from
    ``sampling_rate`` the audio is resampled with ``_resample_linear``.
    """
    audio, native_rate = sf.read(str(prompt_wav), always_2d=True, dtype="float32")
    # The transpose is copied so the tensor wraps a contiguous buffer.
    channels_first = torch.from_numpy(audio.T.copy())
    if native_rate == sampling_rate:
        return channels_first
    return _resample_linear(channels_first, orig_freq=native_rate, new_freq=sampling_rate)
|
|
|
|
def rms_norm(wav: torch.Tensor, target_rms: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Raise the level of a quiet waveform to ``target_rms``.

    Args:
        wav: Waveform tensor; the RMS is taken over all of its elements.
        target_rms: Minimum RMS level the returned waveform should have.

    Returns:
        A ``(wav, wav_rms)`` pair. ``wav`` is scaled up when its RMS is below
        ``target_rms`` and returned unchanged otherwise; ``wav_rms`` is the RMS
        measured before any scaling, as a 0-dim tensor.
    """
    wav_rms = torch.sqrt(torch.mean(torch.square(wav)))
    # A silent input has RMS 0; scaling it would divide by zero and fill the
    # waveform with NaN, so it is passed through unchanged instead.
    if wav_rms > 0 and wav_rms < target_rms:
        wav = wav * target_rms / wav_rms
    return wav, wav_rms
|
|
|
|
def load_local_vocos(vocoder_dir: str | Path):
    """Create a ``LocalVocos`` and load its weights from ``vocoder_dir``.

    The checkpoint is read from ``pytorch_model.bin`` inside ``vocoder_dir``
    onto the CPU. Only entries whose keys start with ``backbone.`` or
    ``head.`` are handed to ``load_state_dict``; everything else in the
    checkpoint is dropped.
    """
    # Imported here rather than at module level, as in the original code.
    from scripts.local_vocos import LocalVocos

    checkpoint_path = str(Path(vocoder_dir) / "pytorch_model.bin")
    vocoder = LocalVocos()

    try:
        checkpoint = torch.load(
            checkpoint_path,
            weights_only=True,
            map_location="cpu",
        )
    except TypeError:
        # Retry without ``weights_only`` when the call raises TypeError.
        # NOTE(review): presumably this targets torch versions without that
        # keyword. This path unpickles without restriction, so the checkpoint
        # must come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

    wanted_prefixes = ("backbone.", "head.")
    filtered = {
        name: tensor
        for name, tensor in checkpoint.items()
        if name.startswith(wanted_prefixes)
    }
    vocoder.load_state_dict(filtered)
    return vocoder
|
|