File size: 4,906 Bytes
ea47387 a17e6b8 ea47387 a17e6b8 ea47387 a17e6b8 ea47387 a17e6b8 ea47387 a17e6b8 ea47387 a17e6b8 ea47387 a17e6b8 ea47387 a17e6b8 ea47387 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
@dataclass
class VocosFbankConfig:
    """Fixed parameters of the log-mel front end used by ``LocalVocosFbank``."""

    # Sample rate (Hz) the extractor accepts.
    sampling_rate: int = 24000
    # Number of mel bands, i.e. the feature dimension.
    n_mels: int = 100
    # FFT size of the STFT.
    n_fft: int = 1024
    # Hop between successive STFT frames, in samples.
    hop_length: int = 256
def compute_num_frames(num_samples: int, hop_length: int) -> int:
    """Return how many hop-sized frames cover ``num_samples`` samples.

    The count is rounded to the nearest whole frame: a remainder of at least
    ``hop_length // 2`` samples adds one more frame.
    """
    half_hop = hop_length // 2
    padded = num_samples + half_hop
    return int(padded // hop_length)
class LocalVocosFbank:
    """Log-mel feature extractor configured by ``VocosFbankConfig``.

    Features are ``log(clamp(mel_basis.T @ |STFT(x)|, min=1e-7))``. The STFT uses a
    Hann window of length ``n_fft``, centered frames and reflect padding.
    """

    def __init__(self) -> None:
        self.config = VocosFbankConfig()
        # Hann window of length n_fft, also used as the STFT win_length.
        self.window = torch.hann_window(self.config.n_fft)
        # Triangular mel filterbank; it is transposed before the matmul in extract().
        self.mel_basis = _create_mel_filterbank(
            sample_rate=self.config.sampling_rate,
            n_fft=self.config.n_fft,
            n_mels=self.config.n_mels,
        )

    def extract(self, samples: torch.Tensor, sampling_rate: int) -> torch.Tensor:
        """Compute log-mel features for a waveform.

        Args:
            samples: Waveform of shape ``[T]`` or ``[C, T]``. Input with more than
                one channel is averaged down to mono first.
            sampling_rate: Sample rate of ``samples``. It must equal
                ``self.config.sampling_rate``.

        Returns:
            Tensor of shape ``[num_frames, n_mels]``, where ``num_frames`` is
            ``compute_num_frames(T, hop_length)``. Surplus STFT frames are dropped
            and missing frames are filled by repeating the last one.

        Raises:
            ValueError: If the sampling rate does not match, or if ``samples`` is
                neither 1-D nor 2-D.
        """
        if sampling_rate != self.config.sampling_rate:
            raise ValueError(
                f"Mismatched sampling rate: expected {self.config.sampling_rate}, got {sampling_rate}"
            )
        if samples.ndim == 1:
            samples = samples.unsqueeze(0)
        if samples.ndim != 2:
            raise ValueError(f"Expected waveform shape [C, T], got {tuple(samples.shape)}")
        # Downmix any multi-channel input, not just stereo. Otherwise the channels
        # would be flattened into the feature axis below, giving C * n_mels features.
        if samples.shape[0] > 1:
            samples = samples.mean(dim=0, keepdim=True)

        # Move the cached float32 window and filterbank to the input's device and,
        # for floating inputs, its dtype. Without this, float64 input hits a
        # dtype mismatch in the matmul.
        dtype = samples.dtype if samples.is_floating_point() else self.window.dtype
        window = self.window.to(device=samples.device, dtype=dtype)
        mel_basis = self.mel_basis.to(device=samples.device, dtype=dtype)

        # center=True pads n_fft // 2 samples on each side by reflection. torch
        # rejects reflect padding that is not smaller than the input length.
        stft = torch.stft(
            samples,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            win_length=self.config.n_fft,
            window=window,
            center=True,
            pad_mode="reflect",
            return_complex=True,
        )
        spec = stft.abs()  # magnitude spectrogram, [1, n_fft // 2 + 1, frames]
        # The clamp avoids log(0) on silent frames.
        mel = torch.matmul(mel_basis.t(), spec).clamp(min=1e-7).log()
        mel = mel.reshape(-1, mel.shape[-1]).t()  # -> [frames, n_mels]

        num_frames = compute_num_frames(samples.shape[1], self.config.hop_length)
        if mel.shape[0] > num_frames:
            mel = mel[:num_frames]
        elif mel.shape[0] < num_frames:
            # Repeat the last frame until the expected frame count is reached.
            mel = torch.nn.functional.pad(
                mel.unsqueeze(0),
                (0, 0, 0, num_frames - mel.shape[0]),
                mode="replicate",
            ).squeeze(0)
        return mel
def _hz_to_mel(freq: torch.Tensor) -> torch.Tensor:
return 2595.0 * torch.log10(1.0 + freq / 700.0)
def _mel_to_hz(mels: torch.Tensor) -> torch.Tensor:
return 700.0 * (torch.pow(10.0, mels / 2595.0) - 1.0)
def _create_mel_filterbank(sample_rate: int, n_fft: int, n_mels: int) -> torch.Tensor:
    """Build a triangular mel filterbank on the one-sided STFT frequency grid.

    The ``n_mels + 2`` band edges are spaced evenly on the mel scale between 0 Hz
    and ``sample_rate // 2``. Filter ``j`` rises linearly from edge ``j`` to edge
    ``j + 1`` and falls linearly to edge ``j + 2``. Filters are not area-normalized.

    Returns:
        Tensor of shape ``(n_fft // 2 + 1, n_mels)``; column ``j`` holds filter ``j``.

    Raises:
        ValueError: If some filter has no non-zero weight on the frequency grid.
    """
    nyquist = sample_rate // 2
    bin_freqs = torch.linspace(0, nyquist, n_fft // 2 + 1)
    mel_edges = torch.linspace(
        _hz_to_mel(torch.tensor(0.0)),
        _hz_to_mel(torch.tensor(float(nyquist))),
        n_mels + 2,
    )
    hz_edges = _mel_to_hz(mel_edges)
    widths = hz_edges[1:] - hz_edges[:-1]
    # offsets[i, k] = hz_edges[k] - bin_freqs[i]
    offsets = hz_edges.unsqueeze(0) - bin_freqs.unsqueeze(1)
    rising = -offsets[:, :-2] / widths[:-1]   # (f - lower) / (center - lower)
    falling = offsets[:, 2:] / widths[1:]     # (upper - f) / (upper - center)
    triangles = torch.minimum(rising, falling)
    filterbank = torch.maximum(torch.zeros(1), triangles)
    peak_per_filter = filterbank.max(dim=0).values
    if (peak_per_filter == 0.0).any():
        raise ValueError("Mel filterbank has empty filters")
    return filterbank
def _resample_linear(wav: torch.Tensor, orig_freq: int, new_freq: int) -> torch.Tensor:
if orig_freq == new_freq:
return wav
old_len = wav.shape[-1]
new_len = max(1, int(round(old_len * new_freq / orig_freq)))
old_pos = np.arange(old_len, dtype=np.float64)
new_pos = np.linspace(0, old_len - 1, new_len, dtype=np.float64)
channels = []
for channel in wav.cpu().numpy():
channels.append(np.interp(new_pos, old_pos, channel).astype(np.float32))
return torch.from_numpy(np.stack(channels, axis=0))
def load_prompt_wav(prompt_wav: str | Path, sampling_rate: int) -> torch.Tensor:
    """Read an audio file as a channels-first ``[C, T]`` float32 tensor.

    soundfile returns ``(frames, channels)``, so the array is transposed. If the
    file's native rate differs from ``sampling_rate``, the audio is passed through
    ``_resample_linear``.
    """
    frames, file_sr = sf.read(str(prompt_wav), always_2d=True, dtype="float32")
    # Copy the transposed view so torch owns a contiguous buffer.
    channels_first = frames.T.copy()
    wav = torch.from_numpy(channels_first)
    if file_sr == sampling_rate:
        return wav
    return _resample_linear(wav, orig_freq=file_sr, new_freq=sampling_rate)
def rms_norm(wav: torch.Tensor, target_rms: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Scale quiet audio up to ``target_rms``; louder audio is returned unchanged.

    Args:
        wav: Waveform tensor of any shape. The RMS is taken over all elements.
        target_rms: Minimum RMS level the output should have.

    Returns:
        ``(wav_out, wav_rms)``, where ``wav_rms`` is the RMS of the *input*
        (0-dim tensor).

    A silent (all-zero) input has RMS 0 and is returned as-is. Scaling it would
    divide by zero and fill the waveform with NaNs.
    """
    wav_rms = torch.sqrt(torch.mean(torch.square(wav)))
    if wav_rms > 0 and wav_rms < target_rms:
        wav = wav * target_rms / wav_rms
    return wav, wav_rms
def load_local_vocos(vocoder_dir: str | Path):
    """Create a ``LocalVocos`` and load weights from ``<vocoder_dir>/pytorch_model.bin``.

    The checkpoint is loaded onto the CPU. Only entries whose keys start with
    ``backbone.`` or ``head.`` are kept, and they are passed to
    ``load_state_dict`` with its default strict key matching.
    """
    from scripts.local_vocos import LocalVocos

    checkpoint = str(Path(vocoder_dir) / "pytorch_model.bin")
    vocoder = LocalVocos()
    try:
        # weights_only=True restricts unpickling to tensors and plain containers.
        raw_state = torch.load(checkpoint, weights_only=True, map_location="cpu")
    except TypeError:
        # Older torch.load versions do not accept the weights_only keyword.
        raw_state = torch.load(checkpoint, map_location="cpu")
    kept_prefixes = ("backbone.", "head.")
    filtered = {
        name: tensor
        for name, tensor in raw_state.items()
        if name.startswith(kept_prefixes)
    }
    vocoder.load_state_dict(filtered)
    return vocoder
|