from __future__ import annotations import json import wave from dataclasses import dataclass from pathlib import Path from typing import Any import numpy as np from safetensors.torch import load_file as load_safetensors_file from .constants import DEFAULT_ESPEAK_VOICE, DEFAULT_SAMPLE_RATE from .processor import PreparedInput, prepare_input from .vits import SynthesizerTrn def _repo_root() -> Path: return Path(__file__).resolve().parents[2] def _default_model_path() -> Path: safetensors_path = _repo_root() / "model.safetensors" if safetensors_path.exists(): return safetensors_path return _repo_root() / "model.ckpt" def _default_config_path() -> Path: return _repo_root() / "config.json" def _import_torch() -> Any: try: import torch except ImportError as exc: raise ImportError("torch is required for checkpoint inference") from exc return torch def load_release_config(config_path: str | Path) -> dict[str, Any]: with Path(config_path).open("r", encoding="utf-8") as config_file: return json.load(config_file) def audio_float_to_int16(audio: np.ndarray, max_wav_value: float = 32767.0) -> np.ndarray: audio = np.asarray(audio, dtype=np.float32) scale = max(0.01, float(np.max(np.abs(audio)))) if audio.size else 1.0 audio_norm = audio * (max_wav_value / scale) audio_norm = np.clip(audio_norm, -max_wav_value, max_wav_value) return audio_norm.astype(np.int16) def write_wave(path: str | Path, samples: np.ndarray, sample_rate: int) -> Path: path = Path(path) pcm = audio_float_to_int16(samples) with wave.open(str(path), "wb") as wav_file: wav_file.setnchannels(1) wav_file.setsampwidth(2) wav_file.setframerate(sample_rate) wav_file.writeframes(pcm.tobytes()) return path def _generator_kwargs_from_config(config: dict[str, Any]) -> dict[str, Any]: model = config.get("model", {}) return { "n_vocab": int(config["num_symbols"]), "spec_channels": int(model["filter_length"]) // 2 + 1, "segment_size": int(model["segment_size"]) // int(model["hop_length"]), "inter_channels": int(model["inter_channels"]), "hidden_channels": int(model["hidden_channels"]), "filter_channels": int(model["filter_channels"]), "n_heads": int(model["n_heads"]), "n_layers": int(model["n_layers"]), "kernel_size": int(model["kernel_size"]), "p_dropout": float(model["p_dropout"]), "resblock": model["resblock"], "resblock_kernel_sizes": tuple(model["resblock_kernel_sizes"]), "resblock_dilation_sizes": tuple(tuple(x) for x in model["resblock_dilation_sizes"]), "upsample_rates": tuple(model["upsample_rates"]), "upsample_initial_channel": int(model["upsample_initial_channel"]), "upsample_kernel_sizes": tuple(model["upsample_kernel_sizes"]), "n_speakers": int(config["num_speakers"]), "gin_channels": int(model["gin_channels"]), "use_sdp": bool(model.get("use_sdp", True)), } def _load_generator_state(model_path: Path, torch_module: Any) -> dict[str, Any]: if model_path.suffix == ".safetensors": return load_safetensors_file(str(model_path), device="cpu") checkpoint = torch_module.load(model_path, map_location="cpu", weights_only=False) state_dict = checkpoint["state_dict"] return { key[len("model_g.") :]: value for key, value in state_dict.items() if key.startswith("model_g.") } @dataclass(frozen=True) class GeneratedAudio: samples: np.ndarray sample_rate: int prepared_input: PreparedInput class WfloatGenerator: def __init__( self, checkpoint_path: str | Path | None = None, config_path: str | Path | None = None, device: str = "cpu", ) -> None: self.checkpoint_path = Path(checkpoint_path or _default_model_path()) self.config_path = Path(config_path or _default_config_path()) self.device = device if not self.checkpoint_path.exists(): raise FileNotFoundError( f"Checkpoint not found at {self.checkpoint_path}. " "Place a compatible multi-speaker checkpoint there or pass --checkpoint." ) if not self.config_path.exists(): raise FileNotFoundError(f"Config not found at {self.config_path}") self.config = load_release_config(self.config_path) self.sample_rate = int(self.config.get("audio", {}).get("sample_rate", DEFAULT_SAMPLE_RATE)) self.espeak_voice = self.config.get("espeak", {}).get("voice", DEFAULT_ESPEAK_VOICE) self.num_speakers = int(self.config.get("num_speakers", 1)) torch = _import_torch() self._torch = torch self._model = SynthesizerTrn(**_generator_kwargs_from_config(self.config)) state_dict = _load_generator_state(self.checkpoint_path, torch) self._model.load_state_dict(state_dict, strict=True) self._model.eval() with torch.no_grad(): self._model.dec.remove_weight_norm() self._model.to(self.device) self.num_speakers = int(getattr(self._model, "n_speakers", self.num_speakers)) configured_num_speakers = int(self.config.get("num_speakers", self.num_speakers)) if configured_num_speakers != self.num_speakers: raise ValueError( "Checkpoint/config mismatch: " f"config.json declares num_speakers={configured_num_speakers}, " f"but checkpoint reports num_speakers={self.num_speakers}." ) def generate( self, text: str, sid: int = 0, emotion: str = "neutral", intensity: float = 0.5, noise_scale: float | None = None, length_scale: float | None = None, noise_w: float | None = None, ) -> GeneratedAudio: if self.num_speakers <= 1: if sid not in (0, None): raise ValueError( f"Loaded checkpoint is single-speaker but sid={sid} was provided" ) sid_tensor = None else: sid_tensor = self._torch.LongTensor([int(sid)]).to(self.device) prepared = prepare_input( text=text, config=self.config, emotion=emotion, intensity=intensity, espeak_voice=self.espeak_voice, ) text_tensor = self._torch.LongTensor(prepared.token_ids).unsqueeze(0).to(self.device) text_lengths = self._torch.LongTensor([len(prepared.token_ids)]).to(self.device) inference = self.config.get("inference", {}) scales = [ float(inference.get("noise_scale", 0.667) if noise_scale is None else noise_scale), float(inference.get("length_scale", 1.0) if length_scale is None else length_scale), float(inference.get("noise_w", 0.8) if noise_w is None else noise_w), ] with self._torch.no_grad(): audio, *_ = self._model.infer( text_tensor, text_lengths, sid=sid_tensor, noise_scale=scales[0], length_scale=scales[1], noise_scale_w=scales[2], ) samples = audio.detach().cpu().numpy().squeeze().astype(np.float32) return GeneratedAudio( samples=samples, sample_rate=self.sample_rate, prepared_input=prepared, ) def load_generator( checkpoint_path: str | Path | None = None, config_path: str | Path | None = None, device: str = "cpu", ) -> WfloatGenerator: return WfloatGenerator( checkpoint_path=checkpoint_path, config_path=config_path, device=device, )