# wfloat-tts / src/wfloat_tts/infer.py
# (Provenance: commit f71bc95 "Init" by mitchsayre; file-viewer chrome removed.)
from __future__ import annotations
import json
import wave
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
from safetensors.torch import load_file as load_safetensors_file
from .constants import DEFAULT_ESPEAK_VOICE, DEFAULT_SAMPLE_RATE
from .processor import PreparedInput, prepare_input
from .vits import SynthesizerTrn
def _repo_root() -> Path:
return Path(__file__).resolve().parents[2]
def _default_model_path() -> Path:
    """Locate the bundled model weights, preferring the safetensors file.

    Falls back to ``model.ckpt`` at the repo root when no ``model.safetensors``
    exists (the fallback path is returned even if the ckpt is also missing;
    existence is checked by the caller).
    """
    root = _repo_root()
    preferred = root / "model.safetensors"
    return preferred if preferred.exists() else root / "model.ckpt"
def _default_config_path() -> Path:
    """Return the path of ``config.json`` at the repository root."""
    return _repo_root().joinpath("config.json")
def _import_torch() -> Any:
try:
import torch
except ImportError as exc:
raise ImportError("torch is required for checkpoint inference") from exc
return torch
def load_release_config(config_path: str | Path) -> dict[str, Any]:
with Path(config_path).open("r", encoding="utf-8") as config_file:
return json.load(config_file)
def audio_float_to_int16(audio: np.ndarray, max_wav_value: float = 32767.0) -> np.ndarray:
    """Peak-normalize float audio and quantize it to int16 PCM.

    The waveform is scaled so its peak magnitude maps to ``max_wav_value``;
    a floor of 0.01 on the peak keeps near-silent input from being blown up.
    Empty input passes through unscaled as an empty int16 array.
    """
    samples = np.asarray(audio, dtype=np.float32)
    if samples.size:
        peak = float(np.max(np.abs(samples)))
        divisor = peak if peak > 0.01 else 0.01
    else:
        divisor = 1.0
    scaled = samples * (max_wav_value / divisor)
    clipped = np.clip(scaled, -max_wav_value, max_wav_value)
    return clipped.astype(np.int16)
def write_wave(path: str | Path, samples: np.ndarray, sample_rate: int) -> Path:
    """Write mono float samples to ``path`` as 16-bit PCM WAV; return the path."""
    out_path = Path(path)
    pcm_bytes = audio_float_to_int16(samples).tobytes()
    with wave.open(str(out_path), "wb") as wav_handle:
        wav_handle.setnchannels(1)  # mono
        wav_handle.setsampwidth(2)  # 2 bytes per sample == int16
        wav_handle.setframerate(sample_rate)
        wav_handle.writeframes(pcm_bytes)
    return out_path
def _generator_kwargs_from_config(config: dict[str, Any]) -> dict[str, Any]:
model = config.get("model", {})
return {
"n_vocab": int(config["num_symbols"]),
"spec_channels": int(model["filter_length"]) // 2 + 1,
"segment_size": int(model["segment_size"]) // int(model["hop_length"]),
"inter_channels": int(model["inter_channels"]),
"hidden_channels": int(model["hidden_channels"]),
"filter_channels": int(model["filter_channels"]),
"n_heads": int(model["n_heads"]),
"n_layers": int(model["n_layers"]),
"kernel_size": int(model["kernel_size"]),
"p_dropout": float(model["p_dropout"]),
"resblock": model["resblock"],
"resblock_kernel_sizes": tuple(model["resblock_kernel_sizes"]),
"resblock_dilation_sizes": tuple(tuple(x) for x in model["resblock_dilation_sizes"]),
"upsample_rates": tuple(model["upsample_rates"]),
"upsample_initial_channel": int(model["upsample_initial_channel"]),
"upsample_kernel_sizes": tuple(model["upsample_kernel_sizes"]),
"n_speakers": int(config["num_speakers"]),
"gin_channels": int(model["gin_channels"]),
"use_sdp": bool(model.get("use_sdp", True)),
}
def _load_generator_state(model_path: Path, torch_module: Any) -> dict[str, Any]:
if model_path.suffix == ".safetensors":
return load_safetensors_file(str(model_path), device="cpu")
checkpoint = torch_module.load(model_path, map_location="cpu", weights_only=False)
state_dict = checkpoint["state_dict"]
return {
key[len("model_g.") :]: value
for key, value in state_dict.items()
if key.startswith("model_g.")
}
@dataclass(frozen=True)
class GeneratedAudio:
    """Immutable result of a single text-to-speech generation."""
    samples: np.ndarray  # mono float32 waveform (not yet quantized to int16)
    sample_rate: int  # sample rate of `samples`, in Hz
    prepared_input: PreparedInput  # text-processing result used for synthesis
class WfloatGenerator:
    """Text-to-speech generator wrapping a multi-speaker VITS ``SynthesizerTrn``."""
    def __init__(
        self,
        checkpoint_path: str | Path | None = None,
        config_path: str | Path | None = None,
        device: str = "cpu",
    ) -> None:
        """Load config and weights, build the model, and move it to ``device``.

        Raises:
            FileNotFoundError: if the checkpoint or config file does not exist.
            ImportError: if torch is not installed.
            ValueError: if config.json and the checkpoint disagree on speaker count.
        """
        self.checkpoint_path = Path(checkpoint_path or _default_model_path())
        self.config_path = Path(config_path or _default_config_path())
        self.device = device
        if not self.checkpoint_path.exists():
            raise FileNotFoundError(
                f"Checkpoint not found at {self.checkpoint_path}. "
                "Place a compatible multi-speaker checkpoint there or pass --checkpoint."
            )
        if not self.config_path.exists():
            raise FileNotFoundError(f"Config not found at {self.config_path}")
        self.config = load_release_config(self.config_path)
        # Package defaults apply when the config omits audio/espeak sections.
        self.sample_rate = int(self.config.get("audio", {}).get("sample_rate", DEFAULT_SAMPLE_RATE))
        self.espeak_voice = self.config.get("espeak", {}).get("voice", DEFAULT_ESPEAK_VOICE)
        self.num_speakers = int(self.config.get("num_speakers", 1))
        torch = _import_torch()
        self._torch = torch
        self._model = SynthesizerTrn(**_generator_kwargs_from_config(self.config))
        state_dict = _load_generator_state(self.checkpoint_path, torch)
        self._model.load_state_dict(state_dict, strict=True)
        self._model.eval()
        with torch.no_grad():
            # Strip the decoder's weight-norm reparameterization for inference.
            self._model.dec.remove_weight_norm()
        self._model.to(self.device)
        # Prefer the speaker count reported by the checkpoint, then verify
        # config.json agrees with it (mismatch means a wrong checkpoint/config pair).
        self.num_speakers = int(getattr(self._model, "n_speakers", self.num_speakers))
        configured_num_speakers = int(self.config.get("num_speakers", self.num_speakers))
        if configured_num_speakers != self.num_speakers:
            raise ValueError(
                "Checkpoint/config mismatch: "
                f"config.json declares num_speakers={configured_num_speakers}, "
                f"but checkpoint reports num_speakers={self.num_speakers}."
            )
    def generate(
        self,
        text: str,
        sid: int = 0,
        emotion: str = "neutral",
        intensity: float = 0.5,
        noise_scale: float | None = None,
        length_scale: float | None = None,
        noise_w: float | None = None,
    ) -> GeneratedAudio:
        """Synthesize ``text`` and return the waveform plus the prepared input.

        ``noise_scale``/``length_scale``/``noise_w`` left as None fall back to
        the config's "inference" section (then to VITS defaults 0.667/1.0/0.8).

        Raises:
            ValueError: if a nonzero ``sid`` is passed to a single-speaker model.
        """
        if self.num_speakers <= 1:
            if sid not in (0, None):
                raise ValueError(
                    f"Loaded checkpoint is single-speaker but sid={sid} was provided"
                )
            sid_tensor = None  # single-speaker models take no speaker embedding
        else:
            sid_tensor = self._torch.LongTensor([int(sid)]).to(self.device)
        prepared = prepare_input(
            text=text,
            config=self.config,
            emotion=emotion,
            intensity=intensity,
            espeak_voice=self.espeak_voice,
        )
        # Batch of one utterance: shape (1, num_tokens) plus its length.
        text_tensor = self._torch.LongTensor(prepared.token_ids).unsqueeze(0).to(self.device)
        text_lengths = self._torch.LongTensor([len(prepared.token_ids)]).to(self.device)
        inference = self.config.get("inference", {})
        # Per-call overrides win over config values; order: noise, length, noise_w.
        scales = [
            float(inference.get("noise_scale", 0.667) if noise_scale is None else noise_scale),
            float(inference.get("length_scale", 1.0) if length_scale is None else length_scale),
            float(inference.get("noise_w", 0.8) if noise_w is None else noise_w),
        ]
        with self._torch.no_grad():
            audio, *_ = self._model.infer(
                text_tensor,
                text_lengths,
                sid=sid_tensor,
                noise_scale=scales[0],
                length_scale=scales[1],
                noise_scale_w=scales[2],
            )
        # Drop batch/channel dims and hand back a plain float32 numpy vector.
        samples = audio.detach().cpu().numpy().squeeze().astype(np.float32)
        return GeneratedAudio(
            samples=samples,
            sample_rate=self.sample_rate,
            prepared_input=prepared,
        )
def load_generator(
    checkpoint_path: str | Path | None = None,
    config_path: str | Path | None = None,
    device: str = "cpu",
) -> WfloatGenerator:
    """Convenience factory: build a ``WfloatGenerator`` with the given options.

    All arguments are forwarded unchanged; None paths fall back to the
    repository defaults resolved inside the constructor.
    """
    generator = WfloatGenerator(
        checkpoint_path=checkpoint_path,
        config_path=config_path,
        device=device,
    )
    return generator