ZipVoice.AXERA / scripts /common_infer.py
HY-2012's picture
First commit
ea47387 verified
Raw
History Blame Contribute Delete
2.18 kB
from __future__ import annotations
from pathlib import Path
import numpy as np
def load_tokenizer(repo_dir: str | Path):
    """Build the Emilia tokenizer from the repository's bundled vocabulary file.

    Args:
        repo_dir: Repository root; the vocabulary is read from
            ``<repo_dir>/resources/zipvoice_hf/zipvoice/tokens.txt``.

    Returns:
        A ``LocalEmiliaTokenizer`` constructed with that file path (as ``str``).

    Raises:
        FileNotFoundError: If ``tokens.txt`` is not an existing regular file.
    """
    # Deferred import: the project tokenizer module is only loaded when called.
    from scripts.local_tokenizer import LocalEmiliaTokenizer

    vocab_path = Path(repo_dir).joinpath("resources", "zipvoice_hf", "zipvoice", "tokens.txt")
    if vocab_path.is_file():
        return LocalEmiliaTokenizer(token_file=str(vocab_path))
    raise FileNotFoundError(f"tokens.txt not found: {vocab_path}")
def extract_prompt_features(
    prompt_wav: str | Path,
    repo_dir: str | Path,
    sampling_rate: int,
    feat_scale: float,
    target_rms: float,
):
    """Load a prompt recording, RMS-normalize it and compute scaled features.

    Args:
        prompt_wav: Path to the prompt audio file.
        repo_dir: Repository root. Not used by this function — NOTE(review):
            presumably kept for a signature parallel to the other loaders.
        sampling_rate: Rate passed to both the audio loader and the extractor.
        feat_scale: Multiplier applied to the extracted features.
        target_rms: Loudness target handed to ``rms_norm``.

    Returns:
        ``(features, prompt_rms)``: ``features`` is the extractor output with a
        size-1 leading axis prepended, multiplied by ``feat_scale`` and returned
        as a float32 NumPy array; ``prompt_rms`` is the second value returned
        by ``rms_norm``, converted to a Python ``float``.
    """
    from scripts.local_audio import LocalVocosFbank, load_prompt_wav, rms_norm
    import torch

    audio = load_prompt_wav(prompt_wav, sampling_rate=sampling_rate)
    audio, rms_value = rms_norm(audio, target_rms)

    raw_feats = LocalVocosFbank().extract(audio, sampling_rate=sampling_rate)
    # The extractor may hand back either a tensor or a NumPy array.
    feats = raw_feats if isinstance(raw_feats, torch.Tensor) else torch.from_numpy(raw_feats)

    batched = feats.unsqueeze(0) * feat_scale
    as_array = batched.cpu().numpy().astype(np.float32)
    return as_array, float(rms_value)
def load_vocoder(repo_dir: str | Path):
    """Load the bundled Vocos vocoder onto the CPU in evaluation mode.

    Args:
        repo_dir: Repository root; model files are expected under
            ``<repo_dir>/resources/vocos-mel-24khz``.

    Returns:
        The object returned by ``load_local_vocos``, moved to the CPU device
        and switched to eval mode.

    Raises:
        FileNotFoundError: If ``config.yaml`` or ``pytorch_model.bin`` is
            missing from the model directory.
    """
    from scripts.local_audio import load_local_vocos
    import torch

    model_dir = Path(repo_dir) / "resources" / "vocos-mel-24khz"
    required_files = ("config.yaml", "pytorch_model.bin")
    if not all((model_dir / fname).is_file() for fname in required_files):
        raise FileNotFoundError(f"Local Vocos files not found in {model_dir}")

    model = load_local_vocos(model_dir).to(torch.device("cpu"))
    model.eval()
    return model
def vocoder_decode_loaded(
    vocoder,
    features: np.ndarray,
    feat_scale: float,
    target_rms: float,
    prompt_rms: float,
) -> np.ndarray:
    """Decode scaled features into a waveform using an already-loaded vocoder.

    Args:
        vocoder: Object exposing ``decode(tensor)``, e.g. the result of
            ``load_vocoder``.
        features: 3-D array ``(batch, A, B)`` whose last two axes are swapped
            before decoding (presumably ``(batch, frames, n_mels)`` ->
            ``(batch, n_mels, frames)`` — confirm against the vocoder). An
            unbatched 2-D ``(A, B)`` array is treated as a batch of one. Any
            array-like convertible to float32 is accepted.
        feat_scale: Scale that was applied to the features; divided out here.
        target_rms: Loudness target passed to ``rms_norm``.
        prompt_rms: Prompt loudness; when it is below ``target_rms`` the output
            is attenuated by ``prompt_rms / target_rms``.

    Returns:
        The decoded waveform as a NumPy array with all size-1 axes squeezed out.

    Raises:
        ValueError: If ``features`` is neither 2-D nor 3-D.
    """
    from scripts.local_audio import rms_norm
    import torch

    # Same float32 conversion the original `.float()` performed, but also
    # accepts non-ndarray inputs (lists, other array-likes).
    feats = np.asarray(features, dtype=np.float32)
    if feats.ndim == 2:
        # Unbatched input: add the leading batch axis the permute below expects.
        feats = feats[np.newaxis]
    if feats.ndim != 3:
        raise ValueError(f"features must be 2-D or 3-D, got shape {feats.shape}")

    # torch.tensor copies the data, so read-only buffers can be passed safely.
    feat_tensor = torch.tensor(feats).permute(0, 2, 1) / feat_scale
    with torch.no_grad():
        wav = vocoder.decode(feat_tensor).squeeze(1).clamp(-1, 1)
    # rms_norm returns (waveform, rms); only the waveform is kept.
    wav = rms_norm(wav, target_rms)[0]
    if prompt_rms < target_rms:
        # Only ever attenuate: a quiet prompt yields proportionally quiet output.
        wav = wav * prompt_rms / target_rms
    return wav.squeeze().cpu().numpy()