# tts_engine_v2 / miner.py
# arwin0727's picture
# Upload miner.py with huggingface_hub
# 8960e0b verified
from __future__ import annotations
"""Vocence TTS miner: HF snapshot (weights + ``fish-speech/``) then Fish inference.
After ``snapshot_download``, code lives at ``<snapshot>/fish-speech/``. Missing
PyPI packages are installed one at a time via ``pip install <module>`` from
``ModuleNotFoundError`` (avoids full ``pyproject.toml`` install pulling ``pyaudio``,
which needs system ``portaudio`` headers). Set ``VOCENCE_SKIP_FISH_SPEECH_PIP=1`` to
disable. Override path with ``fish_speech.repo_root`` or ``FISH_SPEECH_ROOT``.
"""
import io
import logging
import os
import sys
import wave
from pathlib import Path
from typing import Any, Mapping
import numpy as np
# Directory that contains this file; the dev server loads its config from here.
REPO = Path(__file__).resolve().parent
# File name of the YAML config that is looked up inside a miner repo.
_VOCENCE_YAML = "vocence_config.yaml"
# Longest audio duration (seconds) the /speak endpoint will return.
_MAX_AUDIO_SEC = 30
# Required output sample rate in Hz; Miner rejects any other generation.sample_rate.
_OUT_SR = 24000
# True once OmegaConf.register_new_resolver has been replaced by the retrying wrapper.
_OMEGA_RESOLVER_PATCHED: bool = False
# The unwrapped OmegaConf.register_new_resolver, captured the first time the patch runs.
_OrigOmegaRegister: Any = None
def _patch_omegaconf_register_new_resolver() -> None:
    """Install a wrapper around ``OmegaConf.register_new_resolver`` that tolerates duplicates.

    The wrapper first calls the original function unchanged; if that raises a
    ``ValueError`` whose message contains "already registered", it calls it
    again with ``replace=True``. Any other ``ValueError`` propagates.

    Does nothing when the patch is already installed or when ``omegaconf``
    cannot be imported. The original function is remembered in
    ``_OrigOmegaRegister`` the first time the patch is applied.
    """
    global _OMEGA_RESOLVER_PATCHED, _OrigOmegaRegister

    if _OMEGA_RESOLVER_PATCHED:
        return
    try:
        from omegaconf import OmegaConf
    except ImportError:
        return

    # Capture the pristine function only once so re-patching never wraps the wrapper.
    if _OrigOmegaRegister is None:
        _OrigOmegaRegister = OmegaConf.register_new_resolver

    def _register_with_retry(name, resolver, *args, **kwargs):
        options = {**kwargs}
        try:
            return _OrigOmegaRegister(name, resolver, *args, **options)
        except ValueError as err:
            is_duplicate = "already registered" in str(err).lower()
            if not is_duplicate:
                raise
            options["replace"] = True
            return _OrigOmegaRegister(name, resolver, *args, **options)

    OmegaConf.register_new_resolver = _register_with_retry  # type: ignore[method-assign]
    _OMEGA_RESOLVER_PATCHED = True
def _read_yaml(repo: Path) -> dict[str, Any]:
    """Load ``<repo>/vocence_config.yaml`` and return its top-level mapping.

    Returns an empty dict when the file does not exist or when its top-level
    YAML value is not a mapping.
    """
    from yaml import safe_load

    config_path = repo / _VOCENCE_YAML
    if config_path.is_file():
        with config_path.open(encoding="utf-8") as handle:
            loaded = safe_load(handle)
        if isinstance(loaded, Mapping):
            return dict(loaded)
    return {}
def _hf_token() -> str | None:
    """Return the Hugging Face token from the environment, or ``None`` if unset/blank.

    ``HF_TOKEN`` is preferred; ``HUGGING_FACE_HUB_TOKEN`` is consulted only when
    ``HF_TOKEN`` is missing or empty. The chosen value is whitespace-stripped.
    """
    raw = os.environ.get("HF_TOKEN")
    if not raw:
        raw = os.environ.get("HUGGING_FACE_HUB_TOKEN")
    token = raw.strip() if raw else ""
    return token if token else None
def _weights_dir(repo: Path, repo_id: str) -> Path:
    """Return the resolved local directory used to store the snapshot of ``repo_id``.

    The directory lives under ``<repo>/_vocence_hf_weights``; ``/`` in the repo
    id becomes ``__`` and ``:`` becomes ``_`` so the id is a single path segment.
    """
    folder_name = repo_id.translate(str.maketrans({"/": "__", ":": "_"}))
    weights_root = repo / "_vocence_hf_weights"
    return (weights_root / folder_name).resolve()
def download_hub(repo: Path, repo_id: str, revision: str | None) -> Path:
    """Download a Hugging Face snapshot into the miner repo and return its directory.

    Args:
        repo: Miner repo root; the snapshot is stored beneath it.
        repo_id: Hub repository id to download.
        revision: Revision passed through to ``snapshot_download`` (may be ``None``).

    Returns:
        The local snapshot directory.

    Raises:
        FileNotFoundError: If the snapshot has no ``codec.pth`` at its top level.
    """
    from huggingface_hub import snapshot_download

    target = _weights_dir(repo, repo_id)
    target.mkdir(parents=True, exist_ok=True)
    logger = logging.getLogger(__name__)
    logger.info("Downloading %s → %s", repo_id, target)
    snapshot_download(
        repo_id=repo_id,
        revision=revision,
        local_dir=str(target),
        token=_hf_token(),
    )
    codec_file = target / "codec.pth"
    if codec_file.is_file():
        return target
    raise FileNotFoundError(f"missing codec.pth under {target}")
def _purge_tools_modules() -> None:
    """Drop ``tools``, ``fish_speech`` and all their submodules from ``sys.modules``.

    This forces the next import of those packages to start from scratch.
    """
    packages = ("tools", "fish_speech")
    prefixes = tuple(f"{pkg}." for pkg in packages)
    # Snapshot the keys first: entries are deleted while walking them.
    for name in tuple(sys.modules):
        if name in packages or name.startswith(prefixes):
            del sys.modules[name]
# Top-level import name -> PyPI distribution (wrong names break installs, e.g. ``hydra`` vs ``hydra-core``).
# Import names that are absent from this table are passed to ``pip install`` unchanged.
_PIP_ALIASES: dict[str, str] = {
    "PIL": "Pillow",
    "yaml": "PyYAML",
    "sklearn": "scikit-learn",
    "hydra": "hydra-core",
    "pytorch_lightning": "lightning",
}
def _pip_install_module(mod: str) -> None:
    """Run ``pip install`` for the distribution that provides import name ``mod``.

    Only the first dotted segment of ``mod`` is used; it is mapped through
    ``_PIP_ALIASES`` to obtain the distribution name. pip is invoked with the
    current interpreter and a 600 second timeout.

    Raises:
        ValueError: If ``mod`` is empty or blank.
        RuntimeError: If the name is ``fish_speech``/``tools``, is a standard
            library module (when ``sys.stdlib_module_names`` is available), or
            pip exits with a non-zero status.
    """
    import subprocess

    top_level = (mod or "").strip().partition(".")[0]
    if not top_level:
        raise ValueError("empty module name")
    if top_level in ("fish_speech", "tools"):
        raise RuntimeError(f"refusing to pip install std project name {top_level!r}")

    # ``sys.stdlib_module_names`` does not exist on older interpreters; skip the check there.
    stdlib_names = getattr(sys, "stdlib_module_names", None)
    if stdlib_names is not None and top_level in stdlib_names:
        raise RuntimeError(f"refusing to pip install stdlib name {top_level!r}")

    distribution = _PIP_ALIASES.get(top_level, top_level)
    logger = logging.getLogger(__name__)
    command = [sys.executable, "-m", "pip", "install", distribution]
    logger.info("Running: %s", " ".join(command))
    completed = subprocess.run(command, capture_output=True, text=True, timeout=600)
    if completed.returncode == 0:
        return
    detail = (completed.stderr or completed.stdout or "").strip() or f"exit {completed.returncode}"
    raise RuntimeError(f"pip install {distribution!r} failed: {detail}")
def _tools_already_importable() -> bool:
    """Report whether ``tools.server.model_manager`` imports with the current ``sys.path``.

    The OmegaConf resolver patch is applied before the attempt. On any failure
    the partially imported ``tools``/``fish_speech`` modules are purged and
    ``False`` is returned.
    """
    from importlib import import_module

    _patch_omegaconf_register_new_resolver()
    try:
        import_module("tools.server.model_manager")
    except Exception:
        _purge_tools_modules()
        return False
    return True
def _ensure_fish_speech(miner_repo: Path, model_root: Path, fs: Mapping[str, Any]) -> None:
    """Make ``tools.server.model_manager`` importable, installing missing PyPI deps if allowed.

    Candidate code roots are ``<model_root>/fish-speech`` and, when configured,
    ``fs["repo_root"]`` / ``FISH_SPEECH_ROOT`` (relative values are resolved
    against ``miner_repo``). The first candidate that contains
    ``tools/server/model_manager.py`` is put at the front of ``sys.path`` and the
    import is retried up to ``VOCENCE_FISH_SPEECH_PIP_MAX_ROUNDS`` times
    (default 60), pip-installing one missing module per round.

    Only the first candidate that contains the file is attempted: every failure
    path for it raises, so later candidates are not tried.

    Args:
        miner_repo: Miner repo root, used to resolve a relative ``repo_root``.
        model_root: Directory expected to contain a ``fish-speech`` tree.
        fs: The ``fish_speech`` config section.

    Raises:
        ImportError: If a dependency is missing and pip is disabled
            (``VOCENCE_SKIP_FISH_SPEECH_PIP``), the missing module is part of the
            project itself, pip fails, the import fails for another reason, or
            the round limit is exhausted.
        FileNotFoundError: If no candidate root contains the model manager file.
    """
    global _OMEGA_RESOLVER_PATCHED
    import importlib
    log = logging.getLogger(__name__)
    # Nothing to do when the package already imports from the existing sys.path.
    if _tools_already_importable():
        return
    roots: list[Path] = [(model_root / "fish-speech").resolve()]
    raw = (fs.get("repo_root") or os.environ.get("FISH_SPEECH_ROOT") or "").strip()
    if raw:
        p = Path(raw).expanduser()
        roots.append(p.resolve() if p.is_absolute() else (miner_repo / p).resolve())
    # Marker file whose presence identifies a usable fish-speech checkout.
    mm = Path("tools") / "server" / "model_manager.py"
    skip_pip = os.environ.get("VOCENCE_SKIP_FISH_SPEECH_PIP", "").strip().lower() in ("1", "true", "yes")
    max_rounds = int(os.environ.get("VOCENCE_FISH_SPEECH_PIP_MAX_ROUNDS", "60"))
    for code_root in roots:
        if not (code_root / mm).is_file():
            continue
        s = str(code_root.resolve())
        if s not in sys.path:
            sys.path.insert(0, s)
        last_err: BaseException | None = None
        # Each round either succeeds, fixes one problem and retries, or raises.
        for _ in range(max_rounds):
            _patch_omegaconf_register_new_resolver()
            try:
                importlib.import_module("tools.server.model_manager")
                return
            except ModuleNotFoundError as e:
                last_err = e
                # Discard half-imported project modules so the retry starts clean.
                _purge_tools_modules()
                mod = e.name
                if skip_pip or mod is None:
                    # Undo the sys.path change before giving up on this root.
                    try:
                        sys.path.remove(s)
                    except ValueError:
                        pass
                    raise ImportError(
                        f"{code_root}: missing {mod!r}. Install deps or unset VOCENCE_SKIP_FISH_SPEECH_PIP."
                    ) from e
                head = mod.split(".")[0]
                pkg = _PIP_ALIASES.get(head, head)
                log.warning("Missing Python module %r — pip install %r …", mod, pkg)
                # A missing project module cannot be fixed by pip; treat it as a broken tree.
                if head in ("fish_speech", "tools"):
                    try:
                        sys.path.remove(s)
                    except ValueError:
                        pass
                    raise ImportError(
                        f"{code_root}: project import {mod!r} failed (broken tree or path?)."
                    ) from e
                try:
                    _pip_install_module(mod)
                except Exception as pip_e:
                    try:
                        sys.path.remove(s)
                    except ValueError:
                        pass
                    raise ImportError(f"{code_root}: could not install missing {mod!r}: {pip_e}") from pip_e
                # Re-assert the code root on sys.path before the next attempt.
                if s not in sys.path:
                    sys.path.insert(0, s)
                continue
            except Exception as e:
                msg_l = str(e).lower()
                # Resolver clash: clear the ``eval`` resolver, reinstall the patch and retry.
                if "already registered" in msg_l and "resolver" in msg_l:
                    log.warning("OmegaConf resolver clash (%s); clearing ``eval`` and retrying …", e)
                    last_err = e
                    _purge_tools_modules()
                    # Best effort: failures while resetting OmegaConf are deliberately ignored.
                    try:
                        from omegaconf import OmegaConf
                        cr = getattr(OmegaConf, "clear_resolver", None)
                        if callable(cr):
                            cr("eval")
                        if _OrigOmegaRegister is not None:
                            OmegaConf.register_new_resolver = _OrigOmegaRegister  # type: ignore[method-assign]
                    except Exception:
                        pass
                    # Force the patch helper to wrap register_new_resolver again.
                    _OMEGA_RESOLVER_PATCHED = False
                    _patch_omegaconf_register_new_resolver()
                    if s not in sys.path:
                        sys.path.insert(0, s)
                    continue
                # Any other import-time failure is not recoverable here.
                last_err = e
                _purge_tools_modules()
                try:
                    sys.path.remove(s)
                except ValueError:
                    pass
                raise ImportError(
                    f"{code_root}: import failed after resolving modules (not a simple missing PyPI dep): {e}"
                ) from e
        # Round budget exhausted without a successful import.
        try:
            sys.path.remove(s)
        except ValueError:
            pass
        raise ImportError(
            f"{code_root}: exceeded {max_rounds} pip rounds (last error: {last_err}). "
            "Install fish-speech deps manually or raise VOCENCE_FISH_SPEECH_PIP_MAX_ROUNDS."
        ) from last_err
    # No candidate root contained tools/server/model_manager.py.
    raise FileNotFoundError(
        f"Missing {roots[0] / mm}. HF repo should include a fish-speech/ tree next to codec.pth, "
        f"or set fish_speech.repo_root in {_VOCENCE_YAML} / FISH_SPEECH_ROOT."
    )
def load_tts_inference_engine(
    *,
    llama_checkpoint_path: str,
    decoder_checkpoint_path: str,
    decoder_config_name: str = "modded_dac_vq",
    device: str = "cuda",
    half: bool = False,
    compile_model: bool = False,
) -> Any:
    """Build a ``ModelManager`` in TTS mode and return its ``tts_inference_engine``.

    All arguments are keyword-only and are forwarded to ``ModelManager``;
    ``compile_model`` is passed as its ``compile`` argument. Requires
    ``tools.server.model_manager`` to be importable.
    """
    from tools.server.model_manager import ModelManager

    manager_options = {
        "mode": "tts",
        "device": device,
        "half": half,
        "compile": compile_model,
        "llama_checkpoint_path": llama_checkpoint_path,
        "decoder_checkpoint_path": decoder_checkpoint_path,
        "decoder_config_name": decoder_config_name,
    }
    manager = ModelManager(**manager_options)
    return manager.tts_inference_engine
def synthesize_wav(
    engine: Any,
    *,
    text: str,
    reference_audio_path: str | None = None,
    reference_text: str | None = None,
    max_new_tokens: int = 1024,
    chunk_length: int = 200,
    top_p: float = 0.8,
    repetition_penalty: float = 1.1,
    temperature: float = 0.8,
    seed: int | None = None,
) -> tuple[int, np.ndarray]:
    """Run one non-streaming TTS request and return ``(sample_rate, float32 samples)``.

    Args:
        engine: Object exposing ``inference(request)`` that yields result objects
            with ``code``, ``error`` and ``audio`` attributes.
        text: Text to synthesize.
        reference_audio_path: Optional reference audio file; must be paired
            with ``reference_text``.
        reference_text: Transcript accompanying the reference audio.
        max_new_tokens, chunk_length, top_p, repetition_penalty, temperature, seed:
            Forwarded unchanged to ``ServeTTSRequest``.

    Returns:
        The sample rate as ``int`` and the audio as a float32 array; arrays with
        more than one dimension are averaged over their last axis.

    Raises:
        ValueError: If only one of the two reference arguments is provided.
        FileNotFoundError: If the reference audio path is not a file.
        RuntimeError: If the engine reports an error or yields no final audio.
    """
    from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest

    has_ref_audio = bool(reference_audio_path)
    has_ref_text = bool(reference_text)
    if has_ref_audio != has_ref_text:
        raise ValueError("reference_audio_path and reference_text must be both set or both omitted")

    references: list[ServeReferenceAudio] = []
    if reference_audio_path:
        ref_file = Path(reference_audio_path)
        if not ref_file.is_file():
            raise FileNotFoundError(ref_file)
        references = [ServeReferenceAudio(audio=ref_file.read_bytes(), text=reference_text or "")]

    request = ServeTTSRequest(
        text=text,
        references=references,
        reference_id=None,
        max_new_tokens=max_new_tokens,
        chunk_length=chunk_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        format="wav",
        streaming=False,
        seed=seed,
    )

    sample_rate: int | None = None
    samples: np.ndarray | None = None
    for result in engine.inference(request):
        if result.code == "error":
            raise RuntimeError(str(result.error or "inference error"))
        # Stop at the first final result that actually carries audio.
        if result.code == "final" and result.audio is not None:
            sample_rate, samples = result.audio
            break
    if sample_rate is None or samples is None:
        raise RuntimeError("no audio")

    mono = np.asarray(samples, dtype=np.float32)
    if mono.ndim > 1:
        mono = np.mean(mono, axis=-1).astype(np.float32)
    return int(sample_rate), mono
def _resample(w: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Return ``w`` as a mono float32 waveform at ``target_sr``.

    Arrays with more than one dimension are averaged over their last axis
    (assumes the channel axis is last — confirm against the engine output).
    The downmix happens before the sample-rate check so that the result is
    1-D whether or not resampling is needed; previously the same-rate path
    returned multi-dimensional input unchanged.

    Args:
        w: Input waveform.
        orig_sr: Sample rate of ``w`` in Hz.
        target_sr: Desired sample rate in Hz.

    Returns:
        A float32 array. ``librosa`` is imported only when the rates differ.
    """
    y = np.asarray(w, dtype=np.float32)
    if y.ndim > 1:
        y = np.mean(y, axis=-1).astype(np.float32)
    if orig_sr == target_sr:
        return y
    import librosa

    return librosa.resample(y, orig_sr=orig_sr, target_sr=target_sr).astype(np.float32)
def _wav_bytes(w: np.ndarray, sample_rate: int) -> bytes:
    """Encode a float waveform as an in-memory mono 16-bit PCM WAV file.

    Samples are clipped to ``[-1.0, 1.0]``, scaled by 32767 and truncated to
    16-bit integers. NaN samples are written as silence instead of reaching
    the integer cast, whose result for NaN is undefined. Frames are emitted
    explicitly as little-endian, as the WAV format requires, rather than in
    the host's native byte order.

    Args:
        w: Waveform samples; written as a single channel.
        sample_rate: Frame rate stored in the WAV header.

    Returns:
        The complete WAV file contents.
    """
    samples = np.asarray(w, dtype=np.float32)
    # Infinities are mapped to large finite values here and then clipped to +/-1.
    samples = np.clip(np.nan_to_num(samples, nan=0.0), -1.0, 1.0)
    pcm = (samples * 32767.0).astype("<i2")
    buf = io.BytesIO()
    with wave.open(buf, "wb") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(2)
        writer.setframerate(sample_rate)
        writer.writeframes(pcm.tobytes())
    return buf.getvalue()
def _resolve_ckpt(raw: str | None, *, model_root: Path, miner_repo: Path) -> Path | None:
    """Turn a configured checkpoint path into a resolved ``Path``.

    Returns ``None`` for a missing or blank value. Absolute paths (after ``~``
    expansion) are resolved as-is. Relative paths are tried under
    ``model_root`` and then ``miner_repo``, and the first existing candidate
    wins; if neither exists the ``miner_repo``-relative path is returned.
    """
    text = str(raw).strip() if raw else ""
    if not text:
        return None

    expanded = Path(text).expanduser()
    if expanded.is_absolute():
        return expanded.resolve()

    candidates = [(base / text).resolve() for base in (model_root, miner_repo)]
    existing = next((c for c in candidates if c.exists()), None)
    # ``candidates[-1]`` is the miner_repo-relative path, the fallback when nothing exists.
    return existing if existing is not None else candidates[-1]
def _llama_and_decoder(model_root: Path, miner_repo: Path, fs: Mapping[str, Any]) -> tuple[str, str]:
    """Work out the llama checkpoint path and the decoder (``codec.pth``) path.

    Each path comes from the ``fish_speech`` config section or, failing that,
    from ``FISH_SPEECH_LLAMA_PATH`` / ``FISH_SPEECH_DECODER_PATH``. Missing
    values are inferred:

    * only llama given: the decoder is the shallowest ``codec.pth`` below it;
    * only decoder given: the llama path is the decoder's parent directory;
    * neither given: ``<model_root>/codec.pth`` if present, otherwise the
      shallowest ``codec.pth`` below ``model_root``; llama is its parent.

    Raises:
        FileNotFoundError: If a ``codec.pth`` has to be located and none exists.
    """

    def _shallowest_codec(root: Path) -> Path:
        # Fewest path components first, i.e. closest to ``root``.
        found = sorted(root.rglob("codec.pth"), key=lambda x: len(x.parts))
        if not found:
            raise FileNotFoundError(f"no codec.pth under {root}")
        return found[0]

    llama_raw = (fs.get("llama_checkpoint_path") or os.environ.get("FISH_SPEECH_LLAMA_PATH") or "").strip()
    decoder_raw = (fs.get("decoder_checkpoint_path") or os.environ.get("FISH_SPEECH_DECODER_PATH") or "").strip()
    llama = _resolve_ckpt(llama_raw or None, model_root=model_root, miner_repo=miner_repo)
    decoder = _resolve_ckpt(decoder_raw or None, model_root=model_root, miner_repo=miner_repo)

    if llama is not None:
        if decoder is not None:
            return str(llama), str(decoder)
        return str(llama), str(_shallowest_codec(Path(llama)))
    if decoder is not None:
        return str(Path(decoder).parent), str(decoder)

    direct = model_root / "codec.pth"
    if direct.is_file():
        return str(direct.parent), str(direct)
    codec = _shallowest_codec(model_root)
    return str(codec.parent), str(codec)
def _prompt(instruction: str, text: str) -> str:
    """Prefix ``text`` with bracketed tags built from a ``|``-separated instruction.

    Each non-blank segment of ``instruction`` becomes ``[segment]``; the tags
    are concatenated and separated from the stripped text by a single space.
    With no tags only the text is returned, and with no text only the tags.
    """
    segments = [part.strip() for part in instruction.strip().split("|")]
    tags = "".join(f"[{segment}]" for segment in segments if segment)
    body = text.strip()
    if tags and body:
        return f"{tags} {body}"
    return tags or body
class Miner:
    """Loads the Fish Speech TTS engine from a miner repo and synthesizes audio."""

    def __init__(self, miner_repo: Path) -> None:
        """Read ``vocence_config.yaml`` from ``miner_repo`` and load the inference engine.

        When ``runtime.hub_model_id`` (or ``runtime.model_id``) is set, the
        model snapshot is downloaded first; otherwise the miner repo itself is
        used as the model root.

        Raises:
            ValueError: If ``generation.sample_rate`` differs from the required rate.
            FileNotFoundError: If the decoder file or the llama path is missing.
        """
        if not logging.root.handlers:
            logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
        self._repo = Path(miner_repo).resolve()
        config = _read_yaml(self._repo)

        # Input length caps; a falsy cap disables truncation in generate_wav.
        limits = config.get("limits") or {}
        self._cap_t = int(limits.get("max_text_chars", 2000))
        self._cap_i = int(limits.get("max_instruction_chars", 600))

        generation = config.get("generation") or {}
        self._out_sr = int(generation.get("sample_rate", _OUT_SR))
        if self._out_sr != _OUT_SR:
            raise ValueError(f"generation.sample_rate must be {_OUT_SR} in {_VOCENCE_YAML}")

        fish_cfg = config.get("fish_speech") or {}
        runtime = config.get("runtime") or {}
        logger = logging.getLogger(__name__)

        hub_id = (runtime.get("hub_model_id") or runtime.get("model_id") or "").strip()
        revision_raw = (
            runtime.get("model_revision")
            or runtime.get("hub_revision")
            or os.environ.get("VOCENCE_MODEL_REVISION")
            or ""
        )
        revision = str(revision_raw).strip() or None
        if hub_id:
            model_root = download_hub(self._repo, hub_id, revision)
        else:
            model_root = self._repo

        # The fish-speech code must be importable before the engine can be built.
        _ensure_fish_speech(self._repo, model_root, fish_cfg)
        llama_path, decoder_path = _llama_and_decoder(model_root, self._repo, fish_cfg)
        if not Path(decoder_path).is_file():
            raise FileNotFoundError(f"decoder not a file: {decoder_path}")
        if not Path(llama_path).exists():
            raise FileNotFoundError(f"llama path missing: {llama_path}")

        device = str(
            fish_cfg.get("device")
            or runtime.get("device_preference")
            or os.environ.get("FISH_SPEECH_DEVICE")
            or "cuda"
        )
        self._engine = load_tts_inference_engine(
            llama_checkpoint_path=llama_path,
            decoder_checkpoint_path=decoder_path,
            decoder_config_name=str(fish_cfg.get("decoder_config_name", "modded_dac_vq")),
            device=device,
            half=bool(fish_cfg.get("half", False)),
            compile_model=bool(fish_cfg.get("compile", False)),
        )

        # Sampling parameters reused for every request.
        self._tok = int(fish_cfg.get("max_new_tokens", 1024))
        self._chunk = int(fish_cfg.get("chunk_length", 200))
        self._top_p = float(fish_cfg.get("top_p", 0.8))
        self._rep = float(fish_cfg.get("repetition_penalty", 1.1))
        self._temp = float(fish_cfg.get("temperature", 0.8))
        configured_seed = fish_cfg.get("seed")
        self._seed = None if configured_seed is None else int(configured_seed)
        self._adapter = str(runtime.get("adapter", "finetuned-tts"))
        logger.info("Miner ready (hub=%s, llama=%s)", hub_id or "local", llama_path)

    def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
        """Synthesize ``text`` styled by ``instruction``.

        Both inputs are truncated to their configured character caps, combined
        into one prompt, synthesized, and resampled to the output rate.

        Returns:
            ``(waveform, sample_rate)`` with the waveform at the output rate.
        """
        if self._cap_t:
            text = text[: self._cap_t]
        if self._cap_i:
            instruction = instruction[: self._cap_i]
        native_rate, samples = synthesize_wav(
            self._engine,
            text=_prompt(instruction, text),
            max_new_tokens=self._tok,
            chunk_length=self._chunk,
            top_p=self._top_p,
            repetition_penalty=self._rep,
            temperature=self._temp,
            seed=self._seed,
        )
        resampled = _resample(samples, int(native_rate), self._out_sr)
        return resampled, self._out_sr
# Dev-server state, filled in by the FastAPI lifespan handler in _run_dev_server.
# The loaded Miner, or None when startup failed or the server has shut down.
_engine: Miner | None = None
# "<ExceptionType>: <message>" from a failed Miner startup, else None.
_err: str | None = None
# Sample rate reported by /health, read from generation.sample_rate in the config.
_sr: int = _OUT_SR
# Adapter name reported by /health, read from runtime.adapter in the config.
_adapter: str = "finetuned-tts"
def _run_dev_server() -> None:
    """Serve the miner over HTTP with FastAPI/uvicorn (blocks until the server exits).

    Endpoints:
        ``GET /health``: reports whether the model loaded, plus sample rate,
            adapter name and any startup error.
        ``POST /speak``: takes JSON ``text`` and ``instruction`` and returns
            the synthesized audio as ``audio/wav``.

    The bind address comes from ``HOST`` (default ``0.0.0.0``) and ``PORT``
    (default ``8765``). A failure to construct the ``Miner`` at startup is
    logged and surfaced through ``/health`` rather than stopping the server.
    """
    from contextlib import asynccontextmanager

    import uvicorn
    from fastapi import Body, FastAPI, HTTPException, status
    from fastapi.responses import Response
    from pydantic import BaseModel

    if not logging.root.handlers:
        logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")

    @asynccontextmanager
    async def lifespan(_: Any):
        global _engine, _err, _sr, _adapter
        config = _read_yaml(REPO)
        generation = config.get("generation") or {}
        _sr = int(generation.get("sample_rate", _OUT_SR))
        runtime = config.get("runtime") or {}
        _adapter = str(runtime.get("adapter", "finetuned-tts"))
        try:
            _engine = Miner(REPO)
        except Exception as e:
            # Keep serving so /health can report why the model is unavailable.
            _engine = None
            _err = f"{type(e).__name__}: {e}"
            logging.getLogger(__name__).exception("Miner startup failed")
        else:
            _err = None
        yield
        _engine = None

    class Health(BaseModel):
        status: str
        model_loaded: bool
        sample_rate: int
        adapter: str
        error: str | None = None

    app = FastAPI(title="Vocence TTS", lifespan=lifespan)

    @app.get("/health", response_model=Health)
    async def health() -> Health:
        loaded = _engine is not None
        return Health(
            status="healthy" if loaded else "unhealthy",
            model_loaded=loaded,
            sample_rate=_sr,
            adapter=_adapter,
            error=None if loaded else _err,
        )

    # Request-size limits are fixed when the route is declared, so read them now.
    limits = _read_yaml(REPO).get("limits") or {}
    max_text = int(limits.get("max_text_chars", 2000))
    max_instruction = int(limits.get("max_instruction_chars", 600))

    @app.post("/speak", response_class=Response, response_model=None)
    async def speak(
        text: str = Body(..., min_length=1, max_length=max_text, embed=True),
        instruction: str = Body(..., min_length=1, max_length=max_instruction, embed=True),
    ) -> Response:
        if _engine is None:
            raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=_err or "engine not loaded")
        samples, rate = _engine.generate_wav(instruction, text)
        samples = np.asarray(samples)
        if samples.ndim != 1 or samples.size == 0:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="invalid waveform")
        seconds = float(samples.shape[0]) / float(rate)
        if seconds <= 0 or seconds > _MAX_AUDIO_SEC:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="invalid duration")
        return Response(content=_wav_bytes(samples, int(rate)), media_type="audio/wav")

    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "8765"))
    uvicorn.run(app, host=host, port=port)
# Running this file as a script starts the development HTTP server.
if __name__ == "__main__":
    _run_dev_server()