podify / tts /engine.py
jayaspjacob's picture
pr/2 (#2)
b3ec634
Raw
History Blame Contribute Delete
10.3 kB
"""Fish Audio (OpenAudio S1-mini) inference engine.
Loads the model once (cached), exposes a ZeroGPU-decorated ``synthesize()`` for a single
utterance with optional zero-shot voice cloning, and ``generate_podcast()`` to stitch a
multi-speaker script into one waveform.
Heavy deps (torch / fish_speech) are imported lazily so this module can be imported on a
CPU-only / local machine (Phase 1 development) without them installed.
"""
from __future__ import annotations
import os
import tempfile
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
# ----------------------------------------------------------------- ZeroGPU decorator
# ``GPU`` is either the real ``spaces.GPU`` decorator (when the ``spaces`` package can be
# imported and exposes it) or a local stand-in that leaves the decorated function untouched.
try:
    import spaces  # provided in HF Spaces runtime

    GPU = spaces.GPU
except Exception:  # local / non-Space: no-op decorator

    def GPU(*dargs, **dkwargs):
        """No-op replacement usable both as ``@GPU`` and as ``@GPU(duration=...)``."""
        # Bare ``@GPU``: the decorated function is the sole positional argument.
        used_bare = not dkwargs and len(dargs) == 1 and callable(dargs[0])
        if used_bare:
            return dargs[0]
        # Parameterised ``@GPU(...)``: hand back a decorator that returns fn unchanged.
        return lambda fn: fn
# Hugging Face Hub repo id handed to ``snapshot_download`` in ``_load_engine``; override
# with the ``TTS_MODEL_REPO`` environment variable.
TTS_MODEL_REPO = os.environ.get("TTS_MODEL_REPO", "fishaudio/openaudio-s1-mini")
# Decoder checkpoint filename (joined onto the downloaded snapshot directory) and decoder
# config name (passed as ``config_name`` to the decoder loader). Both are overridable via
# environment variables — verify against the model repo if it changes.
DECODER_CHECKPOINT = os.environ.get("TTS_DECODER_CKPT", "codec.pth")
DECODER_CONFIG = os.environ.get("TTS_DECODER_CONFIG", "modded_dac_vq")
_ENGINE = None  # cached TTSInferenceEngine; populated on first ``_load_engine()`` call
# Default output sample rate in Hz; ``_load_engine`` overwrites it with the decoder's
# ``sample_rate`` attribute when that can be read.
_SAMPLE_RATE = 44100
class TTSModelAccessError(RuntimeError):
    """Signal that the configured TTS model could not be fetched from the HF Hub.

    Raised by ``_load_engine`` when the snapshot download fails in a way that looks like a
    gated or forbidden repository, with a message explaining how to obtain access.
    """
@dataclass
class VoiceConfig:
    """Voice selection for a single speaker.

    Leave both fields at their defaults to use the model's default voice; set them to
    clone a voice from a reference recording.
    """

    # Path to a reference audio clip; ``None`` means no reference is attached.
    ref_audio: Optional[str] = None
    # Transcript accompanying the reference clip (empty string when not supplied).
    ref_text: str = ""
def is_available() -> bool:
    """Report whether the TTS stack is usable, i.e. torch and fish_speech both import."""
    try:
        import torch  # noqa: F401
        import fish_speech  # noqa: F401
    except Exception:
        # Any import-time failure (missing package, broken install) means "not available".
        return False
    else:
        return True
def _patch_pyrootutils() -> None:
    """Make fish-speech importable when installed as a package (no source checkout).

    Several fish_speech modules call ``pyrootutils.setup_root(__file__,
    indicator='.project-root')`` at import time. That marker only exists in the source
    repo, so a pip-installed copy raises ``FileNotFoundError`` (and we can't write the
    marker into a root-owned site-packages at runtime).

    We wrap ``pyrootutils.setup_root`` — the exact attribute fish_speech calls — so the
    interception is guaranteed. (Patching ``find_root`` does not work: ``setup_root``
    lives in the ``pyrootutils.pyrootutils`` submodule and resolves ``find_root`` from
    that submodule's own globals, not the package-level re-export.) On failure we fall
    back to the installed package's parent dir, which mirrors the repo layout
    (``<root>/fish_speech/...``) closely enough for config resolution.

    The wrapper only changes behaviour when the original call raises
    ``FileNotFoundError``; every other outcome is passed through untouched. Calling this
    function more than once is harmless: an already-wrapped ``setup_root`` is detected
    via its ``_podify_patched`` attribute and left alone.
    """
    import pyrootutils

    if getattr(pyrootutils.setup_root, "_podify_patched", False):
        return  # already wrapped; wrapping again would only nest the fallback

    _orig_setup_root = pyrootutils.setup_root

    def _setup_root(*args, **kwargs):
        """Call the real ``setup_root``; on a missing marker, derive the root ourselves."""
        try:
            return _orig_setup_root(*args, **kwargs)
        except FileNotFoundError:
            import sys
            from pathlib import Path

            # fish_speech is a PEP 420 namespace package here, so __file__ is None;
            # locate its directory via __path__, falling back to the calling module's
            # path (setup_root's ``search_from``). The project root is the dir
            # *containing* the fish_speech package, mirroring the repo's .project-root
            # location.
            pkg_dir = None
            try:
                import fish_speech

                paths = list(getattr(fish_speech, "__path__", []) or [])
                if paths:
                    pkg_dir = Path(paths[0]).resolve()
                elif getattr(fish_speech, "__file__", None):
                    pkg_dir = Path(fish_speech.__file__).resolve().parent
            except Exception:
                pkg_dir = None

            # ``search_from`` may be given positionally or as a keyword; honour both so
            # the path-based fallback is not lost for keyword-style callers.
            search_from = args[0] if args else kwargs.get("search_from")
            if pkg_dir is None and search_from is not None:
                sf = Path(str(search_from)).resolve()
                for p in [sf, *sf.parents]:
                    if p.name == "fish_speech":
                        pkg_dir = p
                        break

            if pkg_dir is None:
                raise  # nothing to fall back to — re-raise the original error

            root = pkg_dir.parent
            # NOTE(review): only keyword-passed options are honoured here; options given
            # positionally after ``search_from`` are not inspected — confirm callers.
            if kwargs.get("pythonpath", False) and str(root) not in sys.path:
                sys.path.insert(0, str(root))
            if kwargs.get("project_root_env_var", True):
                os.environ["PROJECT_ROOT"] = str(root)
            return root

    _setup_root._podify_patched = True
    pyrootutils.setup_root = _setup_root
def _load_engine():
    """Return the shared TTSInferenceEngine, constructing it on first use.

    On the first call this downloads the model snapshot for ``TTS_MODEL_REPO``, starts the
    text-to-semantic queue, loads the decoder, builds the engine, and records the
    decoder's sample rate in ``_SAMPLE_RATE`` (keeping 44100 if it cannot be read). Later
    calls return the cached engine immediately. Intended to run on the GPU worker.

    Raises:
        TTSModelAccessError: the snapshot download failed and the failure looks like a
            gated / forbidden repository.
        Exception: any other download failure is re-raised unchanged.
    """
    global _ENGINE, _SAMPLE_RATE
    if _ENGINE is not None:
        return _ENGINE

    import torch
    from huggingface_hub import snapshot_download

    _patch_pyrootutils()  # must precede the fish_speech inference imports below
    from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
    from fish_speech.models.dac.inference import load_model as load_decoder_model
    from fish_speech.inference_engine import TTSInferenceEngine

    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"
    precision = torch.half if on_gpu else torch.float32

    # Either of the two conventional environment variables may carry the Hub token.
    hub_token = os.environ.get("HF_TOKEN")
    if not hub_token:
        hub_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")

    try:
        model_dir = snapshot_download(repo_id=TTS_MODEL_REPO, token=hub_token)
    except Exception as exc:
        detail = str(exc)
        looks_gated = (
            type(exc).__name__ == "GatedRepoError"
            or "Cannot access gated repo" in detail
            or "403" in detail
        )
        if not looks_gated:
            raise
        # The default repo's access page lives under a different slug.
        if TTS_MODEL_REPO == "fishaudio/openaudio-s1-mini":
            access_url = "https://huggingface.co/fishaudio/s1-mini"
        else:
            access_url = f"https://huggingface.co/{TTS_MODEL_REPO}"
        raise TTSModelAccessError(
            f"The TTS model '{TTS_MODEL_REPO}' is gated or not accessible with the current "
            f"Hugging Face token. Request access at {access_url}, then log in locally or set "
            "HF_TOKEN to a token with read access. You can also set TTS_MODEL_REPO to another "
            "compatible Fish Audio/OpenAudio checkpoint you can access."
        ) from exc

    semantic_queue = launch_thread_safe_queue(
        checkpoint_path=model_dir,
        device=device,
        precision=precision,
        compile=False,
    )
    decoder = load_decoder_model(
        config_name=DECODER_CONFIG,
        checkpoint_path=os.path.join(model_dir, DECODER_CHECKPOINT),
        device=device,
    )
    built = TTSInferenceEngine(
        llama_queue=semantic_queue,
        decoder_model=decoder,
        compile=False,
        precision=precision,
    )

    # Prefer the decoder's own sample rate; fall back to 44100 if it is missing/invalid.
    try:
        detected_rate = int(decoder.sample_rate)
    except Exception:
        detected_rate = 44100
    _SAMPLE_RATE = detected_rate

    _ENGINE = built
    return built
def _build_request(text: str, voice: VoiceConfig):
    """Assemble the ServeTTSRequest for one utterance.

    A reference clip is attached only when ``voice.ref_audio`` is set and names an
    existing file; otherwise the request carries no references. Sampling settings are
    fixed and the output format is WAV.
    """
    from fish_speech.utils.schema import ServeTTSRequest, ServeReferenceAudio

    clip_path = voice.ref_audio
    refs = []
    if clip_path and os.path.isfile(clip_path):
        with open(clip_path, "rb") as handle:
            clip_bytes = handle.read()
        refs.append(ServeReferenceAudio(audio=clip_bytes, text=voice.ref_text or ""))

    # Fixed generation / sampling settings used for every request.
    sampling = {
        "max_new_tokens": 1024,
        "chunk_length": 200,
        "top_p": 0.8,
        "repetition_penalty": 1.1,
        "temperature": 0.8,
    }
    return ServeTTSRequest(
        text=text,
        references=refs,
        reference_id=None,
        format="wav",
        **sampling,
    )
@GPU(duration=120)
def synthesize(text: str, voice: VoiceConfig) -> Tuple[int, np.ndarray]:
    """Render a single utterance to audio.

    Returns ``(sample_rate, waveform)`` where the waveform is a flattened float32 array
    made by concatenating every "final" result the engine yields.

    Raises:
        RuntimeError: the engine yielded an "error" result, or yielded no audio at all.
    """
    tts = _load_engine()
    req = _build_request(text, voice)
    rate = _SAMPLE_RATE  # used only if the engine never reports its own rate
    pieces: List[np.ndarray] = []

    for item in tts.inference(req):
        status = getattr(item, "code", None)
        if status == "error":
            raise RuntimeError(f"TTS inference error: {getattr(item, 'error', 'unknown')}")
        if status != "final":
            continue
        if getattr(item, "audio", None) is None:
            continue
        rate, samples = item.audio
        pieces.append(np.asarray(samples, dtype=np.float32).reshape(-1))

    if pieces:
        return int(rate), np.concatenate(pieces)
    raise RuntimeError("TTS produced no audio.")
@GPU(duration=300)
def generate_podcast(
    lines: List[Tuple[str, str]],
    voice_map: dict,
    *,
    gap_seconds: float = 0.4,
    progress=None,
) -> Tuple[int, np.ndarray]:
    """Synthesize each (speaker, text) line and stitch into one waveform.

    ``voice_map`` maps speaker name -> VoiceConfig; speakers missing from the map use a
    default ``VoiceConfig()``. The whole loop runs inside a single GPU allocation so the
    model is loaded once per podcast.

    Args:
        lines: ``(speaker, text)`` pairs in playback order. Lines whose text is empty or
            whitespace-only are skipped.
        voice_map: speaker name -> VoiceConfig.
        gap_seconds: silence inserted after each line that produced audio; values
            ``<= 0`` disable the gap.
        progress: optional callable invoked as ``progress(fraction, desc=...)`` before
            each voiced line.

    Returns:
        ``(sample_rate, waveform)`` with a flattened float32 waveform.

    Raises:
        RuntimeError: no line produced any audio (the message includes the last engine
            error, if one was reported).
    """
    import logging

    log = logging.getLogger(__name__)

    engine = _load_engine()
    segments: List[np.ndarray] = []
    sample_rate = _SAMPLE_RATE
    default_voice = VoiceConfig()
    total = len(lines)
    spoken_lines = 0  # lines that yielded real audio (silence gaps do not count)
    last_error = None

    for i, (speaker, text) in enumerate(lines):
        if not text.strip():
            continue
        if progress is not None:
            progress((i / max(total, 1)), desc=f"Voicing line {i + 1}/{total} ({speaker})")
        voice = voice_map.get(speaker, default_voice)
        request = _build_request(text, voice)

        produced_audio = False
        for result in engine.inference(request):
            code = getattr(result, "code", None)
            if code == "final" and getattr(result, "audio", None) is not None:
                sample_rate, audio = result.audio
                segments.append(np.asarray(audio, dtype=np.float32).reshape(-1))
                produced_audio = True
            elif code == "error":
                # Best-effort: keep voicing the remaining lines, but do not hide the
                # failure — record it and surface it in the log.
                last_error = getattr(result, "error", "unknown")
                log.warning(
                    "TTS inference error on line %d/%d (%s): %s",
                    i + 1,
                    total,
                    speaker,
                    last_error,
                )

        if not produced_audio:
            # No speech for this line: adding a gap would stack silences and could make
            # an all-failed script look like a successful (silent) podcast.
            continue
        spoken_lines += 1
        if gap_seconds > 0:
            segments.append(np.zeros(int(sample_rate * gap_seconds), dtype=np.float32))

    if not spoken_lines:
        message = "No audio was generated for this script."
        if last_error is not None:
            message = f"{message} Last TTS error: {last_error}"
        raise RuntimeError(message)
    return int(sample_rate), np.concatenate(segments)
def write_wav(sample_rate: int, audio: np.ndarray) -> str:
    """Write a waveform to a temp .wav file and return its path (for download).

    The file is created with ``tempfile.mkstemp`` rather than the deprecated, race-prone
    ``tempfile.mktemp``, so the path is reserved atomically before anything is written.
    The file is not deleted automatically; the caller owns its cleanup.

    Args:
        sample_rate: sample rate passed to ``soundfile.write``.
        audio: waveform samples to write.

    Returns:
        Path of the newly written ``.wav`` file.
    """
    import soundfile as sf

    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # soundfile opens the path itself; we only needed the reservation
    try:
        sf.write(path, audio, sample_rate)
    except Exception:
        # Do not leave an empty/partial temp file behind when the write fails.
        try:
            os.unlink(path)
        except OSError:
            pass
        raise
    return path