# minerTTS / miner.py
# Uploaded to the Hugging Face Hub by aiseosae using huggingface_hub (commit 485e837, verified).
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout
from pathlib import Path
from typing import Any
import numpy as np
# Settings file read from the model snapshot directory by _settings().
VOCENCE_CONFIG: str = "vocence_config.yaml"
# File whose presence Miner.__init__ requires in the snapshot (FileNotFoundError otherwise).
QWEN_ANCHOR: str = "config.json"
# Timeout, in seconds, that Miner.warmup allows for its single synthesis call.
WARMUP_SECONDS: float = 180.0
def _load_yaml(path: Path) -> dict[str, Any]:
if not path.is_file():
return {}
from yaml import safe_load
with path.open("r", encoding="utf-8") as fh:
return safe_load(fh) or {}
def _select_device(prefer_cuda: bool):
import torch
has_cuda = torch.cuda.is_available()
device = "cuda:0" if (prefer_cuda and has_cuda) else "cpu"
return device, torch, has_cuda
def _select_dtype(torch_mod, want_bf16: bool, has_cuda: bool):
return torch_mod.bfloat16 if (want_bf16 and has_cuda) else torch_mod.float32
def _build_qwen(model_name: str, device: str, dtype: Any, attn: str):
    """Instantiate ``qwen_tts.Qwen3TTSModel`` through ``from_pretrained``.

    ``qwen_tts`` is imported on each call rather than at module load. The
    arguments map to ``pretrained_model_name_or_path``, ``device_map``,
    ``dtype`` and ``attn_implementation``. Any exception raised by
    ``from_pretrained`` propagates unchanged.
    """
    from qwen_tts import Qwen3TTSModel

    load_kwargs = {
        "pretrained_model_name_or_path": model_name,
        "device_map": device,
        "dtype": dtype,
        "attn_implementation": attn,
    }
    return Qwen3TTSModel.from_pretrained(**load_kwargs)
def _attn_order(prefer_flash: bool) -> tuple[str, ...]:
return ("flash_attention_2", "sdpa") if prefer_flash else ("sdpa",)
def _mono_pcm(arr: Any) -> np.ndarray:
wave = np.asarray(arr, dtype=np.float32)
return wave.mean(axis=1) if wave.ndim > 1 else wave
def _as_flag(value: Any, default: bool) -> bool:
    """Interpret a YAML config value as a boolean flag.

    ``None`` (key absent or explicitly null) yields *default*. Strings are
    parsed so that a quoted ``"false"``/``"0"``/``"no"``/``"off"`` means
    False; plain ``bool("false")`` would be True. Other values go through
    ``bool()``.
    """
    if value is None:
        return default
    if isinstance(value, str):
        return value.strip().lower() not in {"", "0", "false", "no", "off"}
    return bool(value)


def _value_or(section: dict[str, Any], key: str, default: Any) -> Any:
    """Return ``section[key]``, or *default* when the key is absent or null.

    Falsy-but-meaningful values such as ``0`` are kept. For example, a cap of
    0 disables truncation in ``Miner.generate_wav``.
    """
    value = section.get(key)
    return default if value is None else value


def _section(raw: dict[str, Any], key: str) -> dict[str, Any]:
    """Return the sub-mapping ``raw[key]``; missing/null/empty becomes ``{}``.

    Raises:
        ValueError: if the section is present but is not a mapping.
    """
    value = raw.get(key) or {}
    if not isinstance(value, dict):
        raise ValueError(
            f"{VOCENCE_CONFIG}: section '{key}' must be a mapping, "
            f"got {type(value).__name__}"
        )
    return value


def _settings(snapshot: Path) -> dict[str, Any]:
    """Read ``vocence_config.yaml`` from *snapshot* into a flat settings dict.

    Produced keys: ``model_name``, ``language``, ``sample_rate``,
    ``cap_instruct``, ``cap_text``, ``prefer_cuda``, ``prefer_bf16``,
    ``prefer_flash``. Config values that are absent or null fall back to the
    defaults written below. ``language`` also falls back when empty.

    Raises:
        KeyError: if ``model_name`` is missing or null. This includes the
            case where the config file does not exist.
        ValueError: if ``runtime``/``generation``/``limits`` is present but
            is not a mapping.
    """
    config_path = snapshot / VOCENCE_CONFIG
    raw = _load_yaml(config_path)
    model_name = raw.get("model_name")
    if model_name is None:
        raise KeyError(f"'model_name' missing from {config_path}")
    rt = _section(raw, "runtime")
    gen = _section(raw, "generation")
    lim = _section(raw, "limits")
    return {
        "model_name": str(model_name),
        "language": str(lim.get("default_language") or rt.get("default_language") or "English"),
        "sample_rate": int(_value_or(gen, "sample_rate", 24000)),
        "cap_instruct": int(_value_or(lim, "max_instruction_chars", 600)),
        "cap_text": int(_value_or(lim, "max_text_chars", 2000)),
        "prefer_cuda": str(_value_or(rt, "device_preference", "cuda")).lower() == "cuda",
        "prefer_bf16": str(_value_or(rt, "dtype", "bfloat16")).lower() == "bfloat16",
        # Parsed like the flags above: a quoted "false" must not enable flash attention.
        "prefer_flash": _as_flag(rt.get("use_flash_attention_2"), False),
    }
class Miner:
    """Qwen3-TTS voice-design synthesizer loaded from a local model snapshot.

    The snapshot directory must contain ``config.json``. Settings come from
    ``vocence_config.yaml`` in the same directory via ``_settings``.
    """

    def __init__(self, path_hf_repo: Path) -> None:
        """Load the Qwen3-TTS model described by the snapshot at *path_hf_repo*.

        Attention backends from ``_attn_order`` are tried in order, and the
        first one that loads is used. Each failed attempt is printed.

        Raises:
            FileNotFoundError: if the snapshot lacks the ``QWEN_ANCHOR`` file.
            RuntimeError: if every backend fails; the last failure is chained
                as ``__cause__``.
        """
        snapshot = Path(path_hf_repo).resolve()
        if not (snapshot / QWEN_ANCHOR).is_file():
            raise FileNotFoundError(f"snapshot missing {QWEN_ANCHOR}: {snapshot}")
        self.snapshot = snapshot
        self.cfg = _settings(snapshot)
        model_name = self.cfg["model_name"]
        device, torch_mod, has_cuda = _select_device(self.cfg["prefer_cuda"])
        dtype = _select_dtype(torch_mod, self.cfg["prefer_bf16"], has_cuda)
        # Derive the log tag from the dtype actually chosen so it cannot drift
        # from _select_dtype's rules.
        tag = "bf16" if dtype == torch_mod.bfloat16 else "fp32"
        last_err: BaseException | None = None
        engine = None
        for attn in _attn_order(self.cfg["prefer_flash"]):
            try:
                engine = _build_qwen(model_name, device, dtype, attn)
            except Exception as exc:
                # Report the failure before falling back. Otherwise a
                # flash-attention error is silently masked by the sdpa retry.
                last_err = exc
                print(f"[Miner] attn={attn} load failed: {exc!r}")
                continue
            print(f"[Miner] qwen3-tts ready: model={model_name} device={device} dtype={tag} attn={attn}")
            break
        if engine is None:
            raise RuntimeError(f"qwen3-tts load failed: {last_err!r}") from last_err
        self.engine = engine

    def __repr__(self) -> str:
        return f"<Miner model={self.cfg['model_name']!r} lang={self.cfg['language']!r}>"

    def warmup(self) -> None:
        """Run one short synthesis on a worker thread, bounded by ``WARMUP_SECONDS``.

        Raises:
            RuntimeError: if the synthesis has not finished within
                ``WARMUP_SECONDS``. Exceptions raised by ``generate_wav``
                propagate unchanged.
        """
        instruct = (
            "An adult female with an American accent, speaking at a normal pace "
            "in a mid-range pitch with a neutral emotional tone."
        )
        pool = ThreadPoolExecutor(max_workers=1)
        timed_out = False
        try:
            future = pool.submit(self.generate_wav, instruct, "Warmup phrase for inference.")
            future.result(timeout=WARMUP_SECONDS)
        except FutureTimeout:
            timed_out = True
            raise RuntimeError(f"Miner warmup exceeded {WARMUP_SECONDS}s") from None
        finally:
            # A `with ThreadPoolExecutor(...)` block always calls
            # shutdown(wait=True), which would block on the still-running
            # worker and defeat the timeout. The worker thread cannot be
            # killed; it keeps running, and concurrent.futures still joins it
            # at interpreter exit.
            pool.shutdown(wait=not timed_out)

    def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
        """Synthesize mono float32 PCM.

        Vocence requires `instruction` and `text` to be passed verbatim to the model.
        Do not rewrite, enrich, or reformat either string. The only alteration
        made here is length capping: each string is truncated to
        ``cfg["cap_instruct"]`` / ``cfg["cap_text"]`` characters when that cap
        is positive. A cap of 0 or less disables truncation.

        Returns:
            ``(pcm, sample_rate)``. ``pcm`` is the first waveform returned by
            the engine after ``_mono_pcm``. ``sample_rate`` is the rate the
            engine reports, not ``cfg["sample_rate"]``.

        Raises:
            ValueError: if the engine returns no waveform or an empty one.
        """
        cap_i = self.cfg["cap_instruct"]
        cap_t = self.cfg["cap_text"]
        instruct = instruction[:cap_i] if cap_i > 0 else instruction
        body = text[:cap_t] if cap_t > 0 else text
        wavs, sr = self.engine.generate_voice_design(
            text=body,
            instruct=instruct,
            language=self.cfg["language"],
        )
        # len() rather than truthiness: `not ndarray` raises for multi-element arrays.
        if wavs is None or len(wavs) == 0 or wavs[0] is None:
            raise ValueError("qwen3-tts returned no audio")
        pcm = _mono_pcm(wavs[0])
        if pcm.size == 0:
            raise ValueError("qwen3-tts returned no audio")
        return pcm, int(sr)