# tts_engine_v1 / miner.py
# Uploaded by arwin0727 with huggingface_hub (revision 08a3ea9, verified).
from __future__ import annotations
import io
import logging
import re
import threading
import wave
from functools import cached_property
from pathlib import Path
from types import SimpleNamespace
from typing import Any
import numpy as np
# File name of the YAML settings file looked up directly under the repo root.
_VOCENCE_YAML = "vocence_config.yaml"
# Longest reference clip, in seconds, that Miner.generate_wav accepts before
# handing the clip to the fish adapter.
_MAX_AUDIO_SEC = 120
def _load_yaml(path: Path) -> dict[str, Any]:
    """Load the top-level mapping of a YAML file.

    Args:
        path: Location of the YAML file.

    Returns:
        The parsed mapping, or ``{}`` when *path* is not a regular file or
        the document is empty / falsy.

    Raises:
        ValueError: If the document's top level is a non-empty value that is
            not a mapping (for example a list or a bare string).
    """
    if not path.is_file():
        return {}
    # Imported lazily so that a missing settings file never requires PyYAML.
    from yaml import safe_load

    with path.open("r", encoding="utf-8") as fh:
        data = safe_load(fh) or {}
    # Callers immediately use ``.get``; fail here with the file name instead
    # of letting a list/scalar surface later as an AttributeError.
    if not isinstance(data, dict):
        raise ValueError(
            f"{path} must contain a YAML mapping at top level, got {type(data).__name__}"
        )
    return data
def _normalize_instruction_for_qwen(instruction: str) -> str:
    """Shorten accent phrases in *instruction* for the Qwen prompt.

    "American accent" becomes "US accent" and "British accent" becomes
    "UK accent". Matching ignores case and tolerates any run of whitespace
    between the two words; everything else is returned unchanged.
    """
    rewrites = (
        (r"American\s+accent", "US accent"),
        (r"British\s+accent", "UK accent"),
    )
    normalized = instruction
    for pattern, replacement in rewrites:
        normalized = re.sub(pattern, replacement, normalized, flags=re.IGNORECASE)
    return normalized
def _accent_tag_from_instruction(instruction: str) -> str | None:
lower = instruction.lower()
if "american accent" in lower:
return "[American accent]"
if "british accent" in lower:
return "[British accent]"
return None
def _build_adapter_text(instruction: str, text: str) -> str:
    """Compose the text given to the fish adapter.

    The stripped *text* (or "Hello." when it is blank) is prefixed with the
    accent tag derived from *instruction* when one is found.
    """
    spoken = text.strip() or "Hello."
    accent = _accent_tag_from_instruction(instruction)
    if not accent:
        return spoken
    return f"{accent} {spoken}".strip()
def _wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Encode float audio as an in-memory mono 16-bit PCM WAV file.

    Args:
        audio: Samples, nominally in ``[-1.0, 1.0]``. Out-of-range values are
            clipped and NaN samples are written as silence.
        sample_rate: Frame rate stored in the WAV header.

    Returns:
        The complete WAV file contents.
    """
    samples = np.asarray(audio, dtype=np.float32)
    # NaN passes through np.clip untouched and its cast to int16 is undefined,
    # so map it to 0 first; infinities become finite and are clipped below.
    samples = np.nan_to_num(samples, nan=0.0)
    samples = np.clip(samples, -1.0, 1.0)
    pcm = (samples * 32767.0).astype(np.int16)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wv:
        wv.setnchannels(1)
        wv.setsampwidth(2)
        wv.setframerate(int(sample_rate))
        wv.writeframes(pcm.tobytes())
    return buf.getvalue()
def _resample(w: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Return *w* as float32 audio at *target_sr*.

    Input with more than one dimension is averaged over its last axis before
    anything else, so the result has the same layout whether or not the two
    rates differ (previously the equal-rate shortcut skipped the downmix).

    Args:
        w: Input samples.
        orig_sr: Sample rate of *w*.
        target_sr: Desired sample rate.

    Returns:
        A float32 array; resampled with librosa only when the rates differ.
    """
    y = np.asarray(w, dtype=np.float32)
    if y.ndim > 1:
        y = np.mean(y, axis=-1).astype(np.float32)
    if orig_sr == target_sr:
        return y
    # Imported lazily: librosa is only needed when a rate change is required.
    import librosa

    return librosa.resample(y, orig_sr=orig_sr, target_sr=target_sr).astype(np.float32)
def _ensure_fish_speech_project_root() -> None:
    """fish_speech calls pyrootutils.setup_root() on import; needs a .project-root marker.

    Creates an empty ``.project-root`` file in every candidate directory and
    in that directory's parent. Candidates are:

    * each entry of the ``FISH_SPEECH_ROOT`` environment variable
      (``os.pathsep``-separated), or its ``fish_speech`` subdirectory when
      that exists;
    * every ``fish_speech`` package on ``sys.path`` that has a ``models``
      subdirectory;
    * ``/app/fish-speech`` and its ``fish_speech`` subdirectory.

    Marker creation is best-effort: directories that do not exist, already
    hold a marker, or cannot be written are skipped, so one bad location does
    not prevent the remaining ones from being marked.
    """
    import os
    import sys

    log = logging.getLogger(__name__)
    seen: set[Path] = set()

    def _mark(root: Path) -> None:
        resolved = root.resolve()
        for candidate in (resolved, resolved.parent):
            if candidate in seen:
                continue
            seen.add(candidate)
            marker = candidate / ".project-root"
            try:
                # Skip existing markers: re-touching one on a read-only
                # location would raise although nothing needs to be done.
                if candidate.is_dir() and not marker.exists():
                    marker.touch(exist_ok=True)
            except OSError as exc:
                log.debug("could not create %s: %r", marker, exc)

    for env_root in os.environ.get("FISH_SPEECH_ROOT", "").split(os.pathsep):
        env_root = env_root.strip()
        if env_root:
            p = Path(env_root)
            _mark(p / "fish_speech" if (p / "fish_speech").is_dir() else p)

    for entry in sys.path:
        if not entry:
            continue
        pkg = Path(entry) / "fish_speech"
        if pkg.is_dir() and (pkg / "models").is_dir():
            _mark(pkg)

    for repo_root in (Path("/app/fish-speech"),):
        if repo_root.is_dir():
            _mark(repo_root)
            if (repo_root / "fish_speech").is_dir():
                _mark(repo_root / "fish_speech")
class _Adapter:
    """Fish-speech inference engine bound to one checkpoint directory.

    The directory must hold ``config.json`` and ``codec.pth``. The engine is
    built during construction and cached on the instance.
    """

    def __init__(self, model_dir: Path, *, use_half: bool = False, use_compile: bool = False) -> None:
        """Validate the checkpoint layout and load the engine.

        Args:
            model_dir: Directory with the adapter checkpoint.
            use_half: Use ``torch.half`` instead of ``torch.bfloat16``.
            use_compile: Forwarded as ``compile=`` to the fish-speech loaders.

        Raises:
            FileNotFoundError: If ``config.json`` or ``codec.pth`` is missing.
        """
        self.root = model_dir.resolve()
        config_path = self.root / "config.json"
        if not config_path.is_file():
            raise FileNotFoundError(f"config.json not present in {self.root}")
        codec_path = self.root / "codec.pth"
        if not codec_path.is_file():
            raise FileNotFoundError("codec.pth not present in adapter dir")
        self._use_half = use_half
        self._use_compile = use_compile
        # Force the cached property now so load errors surface at construction.
        _ = self.engine

    @cached_property
    def engine(self) -> Any:
        """Build the fish-speech ``TTSInferenceEngine`` (computed once)."""
        # The marker files must exist before fish_speech is imported.
        _ensure_fish_speech_project_root()
        import torch
        from fish_speech.inference_engine import TTSInferenceEngine
        from fish_speech.models.dac.inference import load_model as load_decoder_model
        from fish_speech.models.text2semantic.inference import launch_thread_safe_queue

        if torch.cuda.is_available():
            device = "cuda:0"
        else:
            device = "cpu"
        if self._use_half:
            precision = torch.half
        else:
            precision = torch.bfloat16

        queue = launch_thread_safe_queue(
            checkpoint_path=str(self.root),
            device=device,
            precision=precision,
            compile=self._use_compile,
        )
        codec = load_decoder_model(
            config_name="modded_dac_vq",
            checkpoint_path=str(self.root / "codec.pth"),
            device=device,
        )
        built = TTSInferenceEngine(
            llama_queue=queue,
            decoder_model=codec,
            precision=precision,
            compile=self._use_compile,
        )
        logging.getLogger(__name__).info("[Adapter] ready device=%s", device)
        return built

    def generate_wav(
        self,
        instruction: str,
        text: str,
        reference_audio: bytes,
        reference_text: str,
    ) -> tuple[np.ndarray, int]:
        """Synthesize *text* using *reference_audio* as the voice reference.

        Args:
            instruction: Style instruction; only its accent phrase is used,
                as a tag prefixed to the text.
            text: Text to speak.
            reference_audio: Encoded reference audio passed to the engine.
            reference_text: Transcript of the reference audio.

        Returns:
            ``(samples, sample_rate)`` with float32 samples. The rate is
            44100 unless the engine returns a ``(rate, samples)`` tuple.

        Raises:
            RuntimeError: If the engine yields an ``"error"`` result.
            ValueError: If no non-empty final audio was produced.
        """
        from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest

        prompt_text = _build_adapter_text(instruction, text)
        reference = ServeReferenceAudio(audio=reference_audio, text=reference_text)
        request = ServeTTSRequest(
            text=prompt_text,
            references=[reference],
            reference_id=None,
            max_new_tokens=1024,
            chunk_length=200,
            format="wav",
            streaming=False,
        )

        sample_rate = 44100
        waveform: np.ndarray | None = None
        for result in self.engine.inference(request):
            if result.code == "error":
                raise RuntimeError(str(result.error))
            if result.code != "final" or result.audio is None:
                continue
            if isinstance(result.audio, tuple):
                sample_rate = int(result.audio[0])
                waveform = np.asarray(result.audio[1], dtype=np.float32)
            else:
                waveform = np.asarray(result.audio, dtype=np.float32)

        if waveform is None or waveform.size == 0:
            raise ValueError("fish adapter produced no audio")
        if waveform.ndim > 1:
            # Collapse the second axis to obtain a single channel.
            waveform = waveform.mean(axis=1)
        return waveform.astype(np.float32), sample_rate
class QwenMiner:
REPO_SENTINEL = "config.json"
SETTINGS_FILE = "vocence_config.yaml"
WARMUP_TIMEOUT = 180.0
def __init__(self, path_hf_repo: Path | str) -> None:
self.root = Path(path_hf_repo).resolve()
if not (self.root / self.REPO_SENTINEL).is_file():
raise FileNotFoundError(f"{self.REPO_SENTINEL} not present in {self.root}")
_ = self.settings
_ = self.model
def __repr__(self) -> str:
return f"<QwenMiner root={self.root.name} language={self.settings.language!r}>"
@cached_property
def settings(self) -> SimpleNamespace:
raw = self._load_yaml(self.root / self.SETTINGS_FILE)
rt = raw.get("runtime") or {}
gen = raw.get("generation") or {}
lim = raw.get("limits") or {}
return SimpleNamespace(
language=str(lim.get("default_language") or rt.get("default_language") or "English"),
sample_rate=int(gen.get("sample_rate", 24000)),
max_instruction_chars=int(lim.get("max_instruction_chars", 600)),
max_text_chars=int(lim.get("max_text_chars", 2000)),
prefer_cuda=str(rt.get("device_preference", "cuda")).lower() == "cuda",
prefer_bf16=str(rt.get("dtype", "bfloat16")).lower() == "bfloat16",
prefer_flash=bool(rt.get("use_flash_attention_2", False)),
)
@cached_property
def model(self) -> Any:
return self._instantiate_engine()
def warmup(self) -> None:
outcome: dict[str, Any] = {"done": False, "err": None}
def _trial() -> None:
try:
self.generate_wav(instruction="Neutral voice.", text="Warming up.")
outcome["done"] = True
except Exception as exc:
outcome["err"] = repr(exc)
worker = threading.Thread(target=_trial, daemon=True)
worker.start()
worker.join(timeout=self.WARMUP_TIMEOUT)
if not outcome["done"]:
raise RuntimeError(
f"warmup did not complete within {self.WARMUP_TIMEOUT}s: "
f"{outcome['err'] or 'no completion signal'}"
)
def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
s = self.settings
prompt = instruction[: s.max_instruction_chars] if s.max_instruction_chars > 0 else instruction
prompt = _normalize_instruction_for_qwen(prompt)
body = text[: s.max_text_chars] if s.max_text_chars > 0 else text
wavs, sample_rate = self.model.generate_voice_design(
text=body,
instruct=prompt,
language=s.language,
)
if not wavs or wavs[0] is None:
raise ValueError("qwen3-tts produced no audio")
wave = np.asarray(wavs[0], dtype=np.float32)
if wave.ndim > 1:
wave = wave.mean(axis=1)
return wave, int(sample_rate)
def _instantiate_engine(self) -> Any:
import torch
from qwen_tts import Qwen3TTSModel
s = self.settings
cuda_ready = bool(torch.cuda.is_available())
device_map = "cuda:0" if (s.prefer_cuda and cuda_ready) else "cpu"
torch_dtype = torch.bfloat16 if (s.prefer_bf16 and cuda_ready) else torch.float32
attempts = ("flash_attention_2", "sdpa") if s.prefer_flash else ("sdpa",)
model_name = str(self.root)
last_failure: BaseException | None = None
for attn in attempts:
try:
engine = Qwen3TTSModel.from_pretrained(
model_name,
device_map=device_map,
dtype=torch_dtype,
attn_implementation=attn,
)
logging.getLogger(__name__).info(
"[QwenMiner] ready device=%s attn=%s", device_map, attn
)
return engine
except Exception as exc:
last_failure = exc
raise RuntimeError(f"qwen3-tts failed to load :: {last_failure!r}")
@staticmethod
def _load_yaml(path: Path) -> dict[str, Any]:
return _load_yaml(path)
class Miner:
REPO_SENTINEL = "config.json"
WARMUP_TIMEOUT = 180.0
def __init__(
self,
path_hf_repo: Path | str,
adapter_subdir: str | None = None,
output_sample_rate: int | None = None,
adapter_compile: bool = False,
) -> None:
self.repo_root = Path(path_hf_repo).resolve()
if not (self.repo_root / self.REPO_SENTINEL).is_file():
raise FileNotFoundError(f"{self.REPO_SENTINEL} not present in {self.repo_root}")
raw = _load_yaml(self.repo_root / _VOCENCE_YAML)
rt = raw.get("runtime") or {}
self.adapter_subdir = str(adapter_subdir or rt.get("adapter") or "adapter").strip()
self.adapter_dir = (self.repo_root / self.adapter_subdir).resolve()
if not (self.adapter_dir / self.REPO_SENTINEL).is_file():
raise FileNotFoundError(
f"adapter checkpoint missing: {self.adapter_dir} "
f"(expected {self.adapter_subdir}/ under the HF repo)"
)
gen = raw.get("generation") or {}
lim = raw.get("limits") or {}
self._out_sr = int(
output_sample_rate
if output_sample_rate is not None
else gen.get("sample_rate", 24000)
)
self._max_instruction_chars = int(lim.get("max_instruction_chars", 600))
self._max_text_chars = int(lim.get("max_text_chars", 2000))
self._qwen = QwenMiner(self.repo_root)
self._fish = _Adapter(
self.adapter_dir,
use_half=bool(rt.get("use_half", False)),
use_compile=adapter_compile or bool(rt.get("use_compile", False)),
)
_ = self._qwen.model
_ = self._fish.engine
def __repr__(self) -> str:
return (
f"<Miner repo={self.repo_root.name!r} "
f"adapter={self.adapter_dir.name!r} out_sr={self._out_sr}>"
)
@cached_property
def settings(self) -> SimpleNamespace:
return self._qwen.settings
@property
def output_sample_rate(self) -> int:
return self._out_sr
def warmup(self) -> None:
outcome: dict[str, Any] = {"done": False, "err": None}
def _trial() -> None:
try:
self.generate_wav(instruction="Neutral tone.", text="Warming up.")
outcome["done"] = True
except Exception as exc:
outcome["err"] = repr(exc)
worker = threading.Thread(target=_trial, daemon=True)
worker.start()
worker.join(timeout=self.WARMUP_TIMEOUT)
if not outcome["done"]:
raise RuntimeError(
f"warmup did not complete within {self.WARMUP_TIMEOUT}s: "
f"{outcome['err'] or 'no completion signal'}"
)
def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
log = logging.getLogger(__name__)
instruction = (
instruction[: self._max_instruction_chars]
if self._max_instruction_chars > 0
else instruction
)
text = text[: self._max_text_chars] if self._max_text_chars > 0 else text
ref_np, ref_sr = self._qwen.generate_wav(instruction=instruction, text=text)
ref_bytes = _wav_bytes(ref_np, ref_sr)
duration = float(len(ref_np)) / max(1, ref_sr)
if duration <= 0 or duration > _MAX_AUDIO_SEC:
raise ValueError(f"invalid reference duration: {duration:.2f}s")
audio_np, fish_sr = self._fish.generate_wav(
instruction=instruction,
text=text,
reference_audio=ref_bytes,
reference_text=text,
)
out = _resample(np.asarray(audio_np, dtype=np.float32), int(fish_sr), self._out_sr)
return out, self._out_sr