ZipVoice.AXERA / scripts /local_audio.py
HY-2012's picture
Add ZipVoice_distill model
a17e6b8 verified
Raw
History Blame Contribute Delete
4.91 kB
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
@dataclass
class VocosFbankConfig:
    """Settings for the STFT / mel front-end used by ``LocalVocosFbank``."""

    # Sampling rate that incoming audio is required to have.
    sampling_rate: int = 24000
    # Number of mel filters, i.e. the feature dimension per frame.
    n_mels: int = 100
    # FFT size; also used as the analysis window length.
    n_fft: int = 1024
    # Number of samples between the starts of consecutive frames.
    hop_length: int = 256
def compute_num_frames(num_samples: int, hop_length: int) -> int:
    """Return the number of feature frames for a waveform of ``num_samples``.

    Half a hop (using integer division) is added before floor-dividing by
    ``hop_length``, so the sample count is rounded to the nearest whole hop
    rather than truncated.
    """
    half_hop = hop_length // 2
    padded_length = num_samples + half_hop
    return int(padded_length // hop_length)
class LocalVocosFbank:
    """Log-mel feature extractor built from ``torch.stft`` and a local filterbank.

    The Hann window and the mel filterbank are created once in ``__init__``
    from the default ``VocosFbankConfig`` and reused by every ``extract`` call.
    """

    def __init__(self) -> None:
        """Build the default config, the analysis window and the mel filterbank."""
        cfg = VocosFbankConfig()
        self.config = cfg
        self.window = torch.hann_window(cfg.n_fft)
        self.mel_basis = _create_mel_filterbank(
            sample_rate=cfg.sampling_rate,
            n_fft=cfg.n_fft,
            n_mels=cfg.n_mels,
        )

    def extract(self, samples: torch.Tensor, sampling_rate: int) -> torch.Tensor:
        """Compute log-mel features for a waveform.

        Args:
            samples: Waveform of shape ``[T]`` or ``[C, T]``. A two-channel
                input is averaged down to a single channel.
            sampling_rate: Sampling rate of ``samples``; must equal the
                configured rate.

        Returns:
            A 2-D tensor whose first axis is time (frames). For one- or
            two-channel input the second axis has ``n_mels`` entries; for any
            other channel count the channels are folded into the second axis.

        Raises:
            ValueError: If the sampling rate does not match the config or the
                waveform is not 1-D / 2-D.
        """
        if sampling_rate != self.config.sampling_rate:
            raise ValueError(
                f"Mismatched sampling rate: expected {self.config.sampling_rate}, got {sampling_rate}"
            )

        # Normalise the layout to [C, T] before any further checks.
        if samples.ndim == 1:
            samples = samples.unsqueeze(0)
        if samples.ndim != 2:
            raise ValueError(f"Expected waveform shape [C, T], got {tuple(samples.shape)}")

        # Only the stereo case is down-mixed; other channel counts pass through.
        if samples.shape[0] == 2:
            samples = samples.mean(dim=0, keepdim=True)

        cfg = self.config
        device = samples.device

        # Magnitude spectrogram: [C, n_fft // 2 + 1, frames].
        magnitude = torch.stft(
            samples,
            n_fft=cfg.n_fft,
            hop_length=cfg.hop_length,
            win_length=cfg.n_fft,
            window=self.window.to(device),
            center=True,
            pad_mode="reflect",
            return_complex=True,
        ).abs()

        # Project onto the mel filters, floor at 1e-7 so the log stays finite.
        mel_projection = self.mel_basis.to(device).t()
        log_mel = torch.matmul(mel_projection, magnitude).clamp(min=1e-7).log()

        # Collapse leading axes and put time first: [frames, features].
        frames = log_mel.reshape(-1, log_mel.shape[-1]).t()

        # Trim or replicate-pad along time so the length matches the
        # hop-based frame count.
        expected_frames = compute_num_frames(samples.shape[1], cfg.hop_length)
        actual_frames = frames.shape[0]
        if actual_frames > expected_frames:
            return frames[:expected_frames]
        if actual_frames < expected_frames:
            padded = torch.nn.functional.pad(
                frames.unsqueeze(0),
                (0, 0, 0, expected_frames - actual_frames),
                mode="replicate",
            )
            return padded.squeeze(0)
        return frames
def _hz_to_mel(freq: torch.Tensor) -> torch.Tensor:
    """Map frequencies to the mel scale via ``2595 * log10(1 + f / 700)``."""
    scaled = freq / 700.0
    return 2595.0 * torch.log10(1.0 + scaled)
def _mel_to_hz(mels: torch.Tensor) -> torch.Tensor:
    """Map mel values back to frequencies via ``700 * (10 ** (m / 2595) - 1)``."""
    exponent = mels / 2595.0
    return 700.0 * (torch.pow(10.0, exponent) - 1.0)
def _create_mel_filterbank(sample_rate: int, n_fft: int, n_mels: int) -> torch.Tensor:
    """Build a bank of triangular mel filters.

    Filter edges are spaced evenly on the mel scale between 0 and
    ``sample_rate // 2``. No per-filter normalisation is applied.

    Args:
        sample_rate: Sampling rate; the top edge is ``sample_rate // 2``.
        n_fft: FFT size; the bank covers ``n_fft // 2 + 1`` frequency bins.
        n_mels: Number of filters to generate.

    Returns:
        Tensor of shape ``(n_fft // 2 + 1, n_mels)``.

    Raises:
        ValueError: If any filter is zero at every frequency bin.
    """
    nyquist = sample_rate // 2
    n_freqs = n_fft // 2 + 1

    # Centre frequency of every FFT bin.
    bin_freqs = torch.linspace(0, nyquist, n_freqs)

    # n_mels + 2 edge points, equally spaced in mel and mapped back to
    # frequency.
    mel_low = _hz_to_mel(torch.tensor(0.0))
    mel_high = _hz_to_mel(torch.tensor(float(nyquist)))
    edge_freqs = _mel_to_hz(torch.linspace(mel_low, mel_high, n_mels + 2))

    # Widths between neighbouring edges and the signed distance from every
    # bin to every edge: (n_freqs, n_mels + 2).
    edge_widths = edge_freqs[1:] - edge_freqs[:-1]
    distances = edge_freqs.unsqueeze(0) - bin_freqs.unsqueeze(1)

    # Rising and falling flanks of each triangle; the filter is their
    # pointwise minimum, floored at zero.
    rising = -distances[:, :-2] / edge_widths[:-1]
    falling = distances[:, 2:] / edge_widths[1:]
    filters = torch.maximum(torch.zeros(1), torch.minimum(rising, falling))

    peak_per_filter = filters.max(dim=0).values
    if (peak_per_filter == 0.0).any():
        raise ValueError("Mel filterbank has empty filters")
    return filters
def _resample_linear(wav: torch.Tensor, orig_freq: int, new_freq: int) -> torch.Tensor:
    """Resample a ``[C, T]`` waveform by per-channel linear interpolation.

    When the two rates are equal the input tensor is returned as-is.
    Otherwise the first and last samples are kept and the output length is
    ``round(T * new_freq / orig_freq)`` (at least 1). The result is a new
    float32 CPU tensor.
    """
    if orig_freq == new_freq:
        return wav

    src_len = wav.shape[-1]
    dst_len = max(1, int(round(src_len * new_freq / orig_freq)))

    # Source sample indices and the fractional positions to sample them at.
    src_grid = np.arange(src_len, dtype=np.float64)
    dst_grid = np.linspace(0, src_len - 1, dst_len, dtype=np.float64)

    resampled = [
        np.interp(dst_grid, src_grid, row).astype(np.float32)
        for row in wav.cpu().numpy()
    ]
    return torch.from_numpy(np.stack(resampled, axis=0))
def load_prompt_wav(prompt_wav: str | Path, sampling_rate: int) -> torch.Tensor:
    """Read an audio file as a float32 ``[C, T]`` tensor at ``sampling_rate``.

    The file is read with ``soundfile`` (always 2-D, frames first) and
    transposed to channel-first. If its native rate differs from
    ``sampling_rate`` it is resampled with ``_resample_linear``.
    """
    audio, native_rate = sf.read(str(prompt_wav), always_2d=True, dtype="float32")
    # Transpose to [C, T]; the copy gives torch its own buffer instead of the
    # transposed view.
    waveform = torch.from_numpy(audio.T.copy())
    if sr_matches := (native_rate == sampling_rate):
        return waveform
    return _resample_linear(waveform, orig_freq=native_rate, new_freq=sampling_rate)
def rms_norm(wav: torch.Tensor, target_rms: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Raise a quiet waveform to ``target_rms``; leave louder ones untouched.

    Args:
        wav: Waveform tensor of any shape.
        target_rms: Minimum RMS level the returned waveform should have.

    Returns:
        ``(wav, wav_rms)`` where ``wav`` is the possibly rescaled waveform and
        ``wav_rms`` is the RMS of the *input* as a 0-dim tensor.
    """
    wav_rms = torch.sqrt(torch.mean(torch.square(wav)))
    # A silent (all-zero) waveform has zero RMS; scaling it would evaluate
    # 0 * target / 0 and fill the result with NaN, so it is returned as-is.
    if 0.0 < wav_rms < target_rms:
        wav = wav * target_rms / wav_rms
    return wav, wav_rms
def load_local_vocos(vocoder_dir: str | Path):
    """Instantiate ``LocalVocos`` and load its weights from ``vocoder_dir``.

    Reads ``pytorch_model.bin`` from the directory onto the CPU, keeps only
    the entries whose keys start with ``backbone.`` or ``head.``, and loads
    them into a freshly constructed ``LocalVocos``.

    Args:
        vocoder_dir: Directory containing ``pytorch_model.bin``.

    Returns:
        The ``LocalVocos`` instance with the filtered state dict loaded.
    """
    from scripts.local_vocos import LocalVocos

    vocoder_dir = Path(vocoder_dir)
    vocoder = LocalVocos()

    checkpoint_path = str(vocoder_dir / "pytorch_model.bin")
    try:
        checkpoint = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
    except TypeError:
        # torch.load rejected the ``weights_only`` keyword, so retry without it.
        # NOTE(review): this fallback performs a full pickle load; only use it
        # with checkpoint files from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Drop every entry that does not belong to the backbone or the head.
    wanted_prefixes = ("backbone.", "head.")
    filtered = {
        name: tensor
        for name, tensor in checkpoint.items()
        if name.startswith(wanted_prefixes)
    }
    vocoder.load_state_dict(filtered)
    return vocoder