"""Fish Audio (OpenAudio S1-mini) inference engine.

Loads the model once (cached), exposes a ZeroGPU-decorated ``synthesize()``
for a single utterance with optional zero-shot voice cloning, and
``generate_podcast()`` to stitch a multi-speaker script into one waveform.

Heavy deps (torch / fish_speech) are imported lazily so this module can be
imported on a CPU-only / local machine (Phase 1 development) without them
installed.
"""

from __future__ import annotations

import os
import tempfile
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np

# ----------------------------------------------------------------- ZeroGPU decorator
try:
    import spaces  # provided in HF Spaces runtime

    GPU = spaces.GPU
except Exception:  # local / non-Space: no-op decorator

    def GPU(*dargs, **dkwargs):
        """No-op stand-in for ``spaces.GPU`` usable as ``@GPU`` or ``@GPU(...)``."""

        def _wrap(fn):
            return fn

        # support both @GPU and @GPU(duration=...)
        if len(dargs) == 1 and callable(dargs[0]) and not dkwargs:
            return dargs[0]
        return _wrap


TTS_MODEL_REPO = os.environ.get("TTS_MODEL_REPO", "fishaudio/openaudio-s1-mini")
# Filenames inside the model repo — verify against the repo if it changes.
DECODER_CHECKPOINT = os.environ.get("TTS_DECODER_CKPT", "codec.pth")
DECODER_CONFIG = os.environ.get("TTS_DECODER_CONFIG", "modded_dac_vq")

_DEFAULT_SAMPLE_RATE = 44100  # used until the decoder reports its own rate
_ENGINE = None  # cached TTSInferenceEngine
_SAMPLE_RATE = _DEFAULT_SAMPLE_RATE  # updated from the decoder in _load_engine()


class TTSModelAccessError(RuntimeError):
    """Raised when the configured TTS model cannot be downloaded from HF Hub."""


@dataclass
class VoiceConfig:
    """Resolved voice for one speaker: a reference clip+text, or model default."""

    ref_audio: Optional[str] = None  # path to a reference clip for zero-shot cloning
    ref_text: str = ""  # transcript of ref_audio


def is_available() -> bool:
    """True if the TTS stack can run (fish_speech + torch importable)."""
    try:
        import torch  # noqa: F401
        import fish_speech  # noqa: F401

        return True
    except Exception:
        return False


def _patch_pyrootutils() -> None:
    """Make fish-speech importable when installed as a package (no source checkout).

    Several fish_speech modules call
    ``pyrootutils.setup_root(__file__, indicator='.project-root')`` at import
    time. That marker only exists in the source repo, so a pip-installed copy
    raises ``FileNotFoundError``. We also cannot write the marker into a
    root-owned site-packages at runtime.

    We wrap ``pyrootutils.setup_root`` — the exact attribute fish_speech calls
    — so the interception is guaranteed. Patching ``find_root`` does not work:
    ``setup_root`` lives in the ``pyrootutils.pyrootutils`` submodule and
    resolves ``find_root`` from that submodule's own globals, not from the
    package-level re-export.

    On failure we fall back to the installed package's parent dir. That
    directory mirrors the repo layout (``<parent>/fish_speech/...``) closely
    enough for config resolution. The patch is idempotent.
    """
    import pyrootutils

    if getattr(pyrootutils.setup_root, "_podify_patched", False):
        return

    _orig_setup_root = pyrootutils.setup_root

    def _setup_root(*args, **kwargs):
        try:
            return _orig_setup_root(*args, **kwargs)
        except FileNotFoundError:
            import sys
            from pathlib import Path

            # fish_speech is a PEP 420 namespace package here, so __file__ is
            # None; locate its directory via __path__. Failing that, fall back
            # to the calling module's path (setup_root's ``search_from``).
            # The project root is the dir *containing* the fish_speech
            # package, mirroring the repo's .project-root location.
            pkg_dir = None
            try:
                import fish_speech

                paths = list(getattr(fish_speech, "__path__", []) or [])
                if paths:
                    pkg_dir = Path(paths[0]).resolve()
                elif getattr(fish_speech, "__file__", None):
                    pkg_dir = Path(fish_speech.__file__).resolve().parent
            except Exception:
                pkg_dir = None

            # search_from may arrive positionally or as a keyword.
            search_from = args[0] if args else kwargs.get("search_from")
            if pkg_dir is None and search_from is not None:
                start = Path(str(search_from)).resolve()
                for candidate in [start, *start.parents]:
                    if candidate.name == "fish_speech":
                        pkg_dir = candidate
                        break

            if pkg_dir is None:
                raise  # nothing to fall back to — re-raise the original error

            root = pkg_dir.parent
            if kwargs.get("pythonpath", False) and str(root) not in sys.path:
                sys.path.insert(0, str(root))
            if kwargs.get("project_root_env_var", True):
                os.environ["PROJECT_ROOT"] = str(root)
            return root

    _setup_root._podify_patched = True
    pyrootutils.setup_root = _setup_root


def _load_engine():
    """Build and cache the TTSInferenceEngine. Runs on the GPU worker.

    Downloads ``TTS_MODEL_REPO`` via ``snapshot_download`` and starts the
    text-to-semantic queue and the decoder. The engine is cached in
    ``_ENGINE``, and the decoder's sample rate is stored in ``_SAMPLE_RATE``.

    Raises:
        TTSModelAccessError: if the repo is gated / not accessible with the
            current Hugging Face token.
    """
    global _ENGINE, _SAMPLE_RATE
    if _ENGINE is not None:
        return _ENGINE

    import torch
    from huggingface_hub import snapshot_download

    _patch_pyrootutils()  # must precede the fish_speech inference imports below
    from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
    from fish_speech.models.dac.inference import load_model as load_decoder_model
    from fish_speech.inference_engine import TTSInferenceEngine

    device = "cuda" if torch.cuda.is_available() else "cpu"
    precision = torch.half if device == "cuda" else torch.float32

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    try:
        checkpoint_dir = snapshot_download(repo_id=TTS_MODEL_REPO, token=token)
    except Exception as e:
        msg = str(e)
        # Match by name/message so this works across huggingface_hub versions
        # that expose GatedRepoError from different modules.
        if (
            type(e).__name__ == "GatedRepoError"
            or "Cannot access gated repo" in msg
            or "403" in msg
        ):
            access_url = (
                "https://huggingface.co/fishaudio/s1-mini"
                if TTS_MODEL_REPO == "fishaudio/openaudio-s1-mini"
                else f"https://huggingface.co/{TTS_MODEL_REPO}"
            )
            raise TTSModelAccessError(
                f"The TTS model '{TTS_MODEL_REPO}' is gated or not accessible with the current "
                f"Hugging Face token. Request access at {access_url}, then log in locally or set "
                "HF_TOKEN to a token with read access. You can also set TTS_MODEL_REPO to another "
                "compatible Fish Audio/OpenAudio checkpoint you can access."
            ) from e
        raise

    llama_queue = launch_thread_safe_queue(
        checkpoint_path=checkpoint_dir,
        device=device,
        precision=precision,
        compile=False,
    )
    decoder_model = load_decoder_model(
        config_name=DECODER_CONFIG,
        checkpoint_path=os.path.join(checkpoint_dir, DECODER_CHECKPOINT),
        device=device,
    )
    engine = TTSInferenceEngine(
        llama_queue=llama_queue,
        decoder_model=decoder_model,
        compile=False,
        precision=precision,
    )

    try:
        _SAMPLE_RATE = int(decoder_model.sample_rate)
    except (AttributeError, TypeError, ValueError):
        _SAMPLE_RATE = _DEFAULT_SAMPLE_RATE

    _ENGINE = engine
    return engine


def _build_request(text: str, voice: Optional[VoiceConfig]):
    """Build a fish_speech ``ServeTTSRequest`` for one utterance.

    If ``voice.ref_audio`` names an existing file, its bytes (with
    ``voice.ref_text``) are sent as a zero-shot cloning reference. Otherwise no
    reference is attached and the model's default voice is used. This also
    applies when ``voice`` is ``None`` or the path does not exist.
    """
    from fish_speech.utils.schema import ServeTTSRequest, ServeReferenceAudio

    if voice is None:
        voice = VoiceConfig()

    references = []
    if voice.ref_audio and os.path.isfile(voice.ref_audio):
        with open(voice.ref_audio, "rb") as f:
            audio_bytes = f.read()
        references = [ServeReferenceAudio(audio=audio_bytes, text=voice.ref_text or "")]

    return ServeTTSRequest(
        text=text,
        references=references,
        reference_id=None,
        max_new_tokens=1024,
        chunk_length=200,
        top_p=0.8,
        repetition_penalty=1.1,
        temperature=0.8,
        format="wav",
    )


def _collect_audio(results, default_sample_rate: int) -> Tuple[int, Optional[np.ndarray]]:
    """Drain an inference result iterator into one mono float32 waveform.

    Only results with ``code == "final"`` and a non-None ``audio`` payload of
    ``(sample_rate, samples)`` contribute audio; other codes are ignored.

    Returns:
        ``(sample_rate, waveform)``. The waveform is ``None`` (and the rate is
        ``default_sample_rate``) when no final audio was produced.

    Raises:
        RuntimeError: if any result has ``code == "error"``.
    """
    chunks: List[np.ndarray] = []
    sample_rate = default_sample_rate
    for result in results:
        code = getattr(result, "code", None)
        payload = getattr(result, "audio", None)
        if code == "final" and payload is not None:
            sample_rate, audio = payload
            chunks.append(np.asarray(audio, dtype=np.float32).reshape(-1))
        elif code == "error":
            raise RuntimeError(f"TTS inference error: {getattr(result, 'error', 'unknown')}")
    if not chunks:
        return int(sample_rate), None
    return int(sample_rate), np.concatenate(chunks)


@GPU(duration=120)
def synthesize(text: str, voice: VoiceConfig) -> Tuple[int, np.ndarray]:
    """Synthesize one utterance.

    Returns:
        ``(sample_rate, float32 mono waveform)``.

    Raises:
        RuntimeError: on an inference error or if no audio was produced.
    """
    engine = _load_engine()
    request = _build_request(text, voice)
    sample_rate, audio = _collect_audio(engine.inference(request), _SAMPLE_RATE)
    if audio is None:
        raise RuntimeError("TTS produced no audio.")
    return sample_rate, audio


@GPU(duration=300)
def generate_podcast(
    lines: List[Tuple[str, str]],
    voice_map: dict,
    *,
    gap_seconds: float = 0.4,
    progress=None,
) -> Tuple[int, np.ndarray]:
    """Synthesize each (speaker, text) line and stitch into one waveform.

    ``voice_map`` maps speaker name -> VoiceConfig. Speakers missing from it,
    or every speaker when the map is ``None``, use the model default. The whole
    loop runs inside a single GPU allocation so the model is loaded once per
    podcast.

    Blank lines are skipped, as are lines for which the model returns no
    audio. ``gap_seconds`` of silence is inserted *between* voiced lines.

    Raises:
        RuntimeError: if inference reports an error for any line, or if no
            line produced any audio.
    """
    engine = _load_engine()
    voice_map = voice_map or {}
    default_voice = VoiceConfig()
    sample_rate = _SAMPLE_RATE
    segments: List[np.ndarray] = []
    total = len(lines)

    for i, (speaker, text) in enumerate(lines):
        if not text or not text.strip():
            continue
        if progress is not None:
            progress((i / max(total, 1)), desc=f"Voicing line {i + 1}/{total} ({speaker})")

        request = _build_request(text, voice_map.get(speaker, default_voice))
        sample_rate, audio = _collect_audio(engine.inference(request), sample_rate)
        if audio is None:
            continue  # best-effort: a silent line is skipped, not fatal

        # Gap only between voiced lines. This keeps the "no audio" check
        # below meaningful and avoids trailing silence.
        if segments and gap_seconds > 0:
            segments.append(np.zeros(int(sample_rate * gap_seconds), dtype=np.float32))
        segments.append(audio)

    if not segments:
        raise RuntimeError("No audio was generated for this script.")
    return int(sample_rate), np.concatenate(segments)


def write_wav(sample_rate: int, audio: np.ndarray) -> str:
    """Write a waveform to a temp .wav file and return its path (for download).

    Uses ``tempfile.mkstemp`` (atomic, race-free creation) rather than the
    deprecated ``mktemp``. The caller owns the file and should delete it when
    done.
    """
    import soundfile as sf  # imported first so a missing dep leaves no temp file

    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # soundfile reopens by path; close our handle (needed on Windows)
    try:
        sf.write(path, audio, sample_rate)
    except Exception:
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path