Spaces:
Running on Zero
Running on Zero
| """Fish Audio (OpenAudio S1-mini) inference engine. | |
| Loads the model once (cached), exposes a ZeroGPU-decorated ``synthesize()`` for a single | |
| utterance with optional zero-shot voice cloning, and ``generate_podcast()`` to stitch a | |
| multi-speaker script into one waveform. | |
| Heavy deps (torch / fish_speech) are imported lazily so this module can be imported on a | |
| CPU-only / local machine (Phase 1 development) without them installed. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import tempfile | |
| from dataclasses import dataclass | |
| from typing import List, Optional, Tuple | |
| import numpy as np | |
# ----------------------------------------------------------------- ZeroGPU decorator
try:
    import spaces  # only importable inside the HF Spaces runtime

    GPU = spaces.GPU
except Exception:  # anywhere else: substitute a decorator that changes nothing

    def GPU(*dargs, **dkwargs):
        """No-op stand-in that works both as ``@GPU`` and as ``@GPU(duration=...)``."""
        used_bare = len(dargs) == 1 and callable(dargs[0]) and not dkwargs
        if used_bare:
            # ``@GPU``: the decorated function itself was passed in, hand it back.
            return dargs[0]
        # ``@GPU(...)``: return a decorator that leaves its target untouched.
        return lambda fn: fn
# Model repo id passed to snapshot_download(); override with the TTS_MODEL_REPO env var.
TTS_MODEL_REPO = os.getenv("TTS_MODEL_REPO", "fishaudio/openaudio-s1-mini")

# Decoder checkpoint filename and config name inside that repo — re-check them against
# the repo if its layout changes.
DECODER_CHECKPOINT = os.getenv("TTS_DECODER_CKPT", "codec.pth")
DECODER_CONFIG = os.getenv("TTS_DECODER_CONFIG", "modded_dac_vq")

# Process-wide cache filled in by _load_engine(): the TTSInferenceEngine instance, and
# the default sample rate that _load_engine() overwrites from the decoder model.
_ENGINE = None
_SAMPLE_RATE = 44100
class TTSModelAccessError(RuntimeError):
    """Signals that the configured TTS model could not be downloaded from the HF Hub.

    Subclasses ``RuntimeError`` so existing ``except RuntimeError`` handlers still catch it.
    """
@dataclass
class VoiceConfig:
    """Resolved voice for one speaker: a reference clip plus its text, or the model default.

    Declared as a dataclass so it can be built with arguments
    (``VoiceConfig(ref_audio=..., ref_text=...)``) and gets value-based ``__eq__`` and a
    readable ``__repr__``. ``VoiceConfig()`` with no arguments selects the model's
    default voice.
    """

    # Path to the reference audio clip; None (or a path that is not a file) means no
    # reference is attached to the request.
    ref_audio: Optional[str] = None
    # Text accompanying the reference clip; an empty value is sent as "".
    ref_text: str = ""
def is_available() -> bool:
    """Report whether the TTS stack can run, i.e. both torch and fish_speech import.

    Any exception raised while importing either package (not only ``ImportError``) is
    treated as "not available".
    """
    try:
        import torch  # noqa: F401
        import fish_speech  # noqa: F401
    except Exception:
        return False
    return True
def _patch_pyrootutils() -> None:
    """Wrap ``pyrootutils.setup_root`` so a pip-installed fish-speech stays importable.

    Several fish_speech modules call ``pyrootutils.setup_root(__file__,
    indicator='.project-root')`` at import time. That marker file only exists in the
    source repo, so an installed copy raises ``FileNotFoundError`` (and the marker can't
    be written into a root-owned site-packages at runtime).

    The wrapper replaces ``pyrootutils.setup_root`` itself — the exact attribute
    fish_speech calls. (Patching ``find_root`` does not work: ``setup_root`` lives in
    the ``pyrootutils.pyrootutils`` submodule and resolves ``find_root`` from that
    submodule's own globals, not the package-level re-export.) The original function is
    always tried first; only on ``FileNotFoundError`` does the wrapper fall back to the
    directory *containing* the installed ``fish_speech`` package, which mirrors the repo
    layout (``<root>/fish_speech/...``) closely enough for config resolution.

    Calling this more than once is harmless: an already-wrapped ``setup_root`` is
    detected through its ``_podify_patched`` attribute and left alone.
    """
    import pyrootutils

    wrapped = pyrootutils.setup_root
    if getattr(wrapped, "_podify_patched", False):
        return  # already patched; don't stack wrappers

    def _setup_root(*args, **kwargs):
        try:
            return wrapped(*args, **kwargs)
        except FileNotFoundError:
            import sys
            from pathlib import Path

            # fish_speech is a PEP 420 namespace package here, so __file__ is None;
            # prefer __path__ and only then __file__.
            package_dir = None
            try:
                import fish_speech

                search_paths = list(getattr(fish_speech, "__path__", []) or [])
                if search_paths:
                    package_dir = Path(search_paths[0]).resolve()
                elif getattr(fish_speech, "__file__", None):
                    package_dir = Path(fish_speech.__file__).resolve().parent
            except Exception:
                package_dir = None

            # Second chance: walk up from the calling module's path (setup_root's first
            # positional argument) until a directory named "fish_speech" is found.
            if package_dir is None and args:
                start = Path(str(args[0])).resolve()
                package_dir = next(
                    (p for p in (start, *start.parents) if p.name == "fish_speech"),
                    None,
                )

            if package_dir is None:
                raise  # nothing to fall back to — re-raise the original error

            # The project root is the directory that contains the fish_speech package,
            # mirroring where .project-root sits in the source repo.
            root = package_dir.parent
            root_str = str(root)
            # Honour the two setup_root options handled on this fallback path.
            if kwargs.get("pythonpath", False) and root_str not in sys.path:
                sys.path.insert(0, root_str)
            if kwargs.get("project_root_env_var", True):
                os.environ["PROJECT_ROOT"] = root_str
            return root

    _setup_root._podify_patched = True
    pyrootutils.setup_root = _setup_root
def _load_engine():
    """Return the shared TTSInferenceEngine, building it on the first call.

    The first call downloads ``TTS_MODEL_REPO`` with ``snapshot_download``, starts the
    text2semantic queue, loads the decoder and wraps both in a ``TTSInferenceEngine``.
    The result is stored in the module-level ``_ENGINE`` and ``_SAMPLE_RATE`` is updated
    from the decoder model (44100 is kept if that attribute cannot be read). Later calls
    return the cached engine immediately.

    Raises:
        TTSModelAccessError: the download failed and the error looks like a gated-repo
            or HTTP 403 refusal. Any other download error propagates unchanged.
    """
    global _ENGINE, _SAMPLE_RATE
    if _ENGINE is not None:
        return _ENGINE

    import torch
    from huggingface_hub import snapshot_download

    _patch_pyrootutils()  # must precede the fish_speech inference imports below
    from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
    from fish_speech.models.dac.inference import load_model as load_decoder_model
    from fish_speech.inference_engine import TTSInferenceEngine

    # Half precision only on CUDA; everything else runs in float32 on the CPU.
    on_cuda = torch.cuda.is_available()
    device = "cuda" if on_cuda else "cpu"
    precision = torch.half if on_cuda else torch.float32

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    try:
        checkpoint_dir = snapshot_download(repo_id=TTS_MODEL_REPO, token=token)
    except Exception as e:
        detail = str(e)
        # Match on the class name and message text so no hub-specific exception class
        # has to be imported here.
        access_denied = (
            type(e).__name__ == "GatedRepoError"
            or "Cannot access gated repo" in detail
            or "403" in detail
        )
        if not access_denied:
            raise
        # The default repo's access request page lives under a different repo name.
        if TTS_MODEL_REPO == "fishaudio/openaudio-s1-mini":
            access_url = "https://huggingface.co/fishaudio/s1-mini"
        else:
            access_url = f"https://huggingface.co/{TTS_MODEL_REPO}"
        raise TTSModelAccessError(
            f"The TTS model '{TTS_MODEL_REPO}' is gated or not accessible with the current "
            f"Hugging Face token. Request access at {access_url}, then log in locally or set "
            "HF_TOKEN to a token with read access. You can also set TTS_MODEL_REPO to another "
            "compatible Fish Audio/OpenAudio checkpoint you can access."
        ) from e

    semantic_queue = launch_thread_safe_queue(
        checkpoint_path=checkpoint_dir,
        device=device,
        precision=precision,
        compile=False,
    )
    decoder = load_decoder_model(
        config_name=DECODER_CONFIG,
        checkpoint_path=os.path.join(checkpoint_dir, DECODER_CHECKPOINT),
        device=device,
    )
    built = TTSInferenceEngine(
        llama_queue=semantic_queue,
        decoder_model=decoder,
        compile=False,
        precision=precision,
    )

    # Prefer the decoder's own sample rate; keep the 44100 default if it is unavailable.
    try:
        _SAMPLE_RATE = int(decoder.sample_rate)
    except Exception:
        _SAMPLE_RATE = 44100

    _ENGINE = built
    return built
def _build_request(text: str, voice: VoiceConfig):
    """Assemble the ``ServeTTSRequest`` for one utterance.

    When ``voice.ref_audio`` names an existing file, its bytes and ``voice.ref_text``
    are attached as the single reference. Otherwise the request carries no references;
    a missing or unset clip is not treated as an error.
    """
    from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest

    clip_path = voice.ref_audio
    if clip_path and os.path.isfile(clip_path):
        with open(clip_path, "rb") as clip:
            clip_bytes = clip.read()
        references = [ServeReferenceAudio(audio=clip_bytes, text=voice.ref_text or "")]
    else:
        references = []

    # Fixed sampling / chunking settings used for every request built by this module.
    return ServeTTSRequest(
        text=text,
        references=references,
        reference_id=None,
        max_new_tokens=1024,
        chunk_length=200,
        top_p=0.8,
        repetition_penalty=1.1,
        temperature=0.8,
        format="wav",
    )
# NOTE(review): the module docstring describes synthesize() as ZeroGPU-decorated, but no
# @GPU decorator is applied in this source — confirm whether one was lost.
def synthesize(text: str, voice: VoiceConfig) -> Tuple[int, np.ndarray]:
    """Synthesize a single utterance.

    Returns ``(sample_rate, waveform)``, where the waveform is every "final" audio
    result from the engine, converted to float32, flattened to 1-D and concatenated
    in order.

    Raises:
        RuntimeError: the engine reported an "error" result, or produced no audio.
    """
    engine = _load_engine()
    request = _build_request(text, voice)

    # Read after _load_engine(), which may have replaced the module default.
    sample_rate = _SAMPLE_RATE
    pieces: List[np.ndarray] = []

    for result in engine.inference(request):
        status = getattr(result, "code", None)
        if status == "final":
            if getattr(result, "audio", None) is None:
                continue
            # result.audio unpacks to (sample_rate, samples).
            sample_rate, samples = result.audio
            pieces.append(np.asarray(samples, dtype=np.float32).reshape(-1))
        elif status == "error":
            raise RuntimeError(f"TTS inference error: {getattr(result, 'error', 'unknown')}")

    if not pieces:
        raise RuntimeError("TTS produced no audio.")
    return int(sample_rate), np.concatenate(pieces)
# NOTE(review): the original docstring says the loop runs inside a single GPU allocation,
# but no @GPU decorator is applied in this source — confirm whether one was lost.
def generate_podcast(
    lines: List[Tuple[str, str]],
    voice_map: dict,
    *,
    gap_seconds: float = 0.4,
    progress=None,
) -> Tuple[int, np.ndarray]:
    """Synthesize each (speaker, text) line and stitch the results into one waveform.

    The engine is obtained once and reused for every line; the inference loop is kept
    inline here rather than calling ``synthesize()`` per line.

    Args:
        lines: ``(speaker, text)`` pairs in playback order. Lines whose text is empty,
            ``None`` or whitespace-only are skipped.
        voice_map: speaker name -> ``VoiceConfig``. Speakers that are missing (or mapped
            to ``None``) use a default ``VoiceConfig()``. ``None`` is treated as ``{}``.
        gap_seconds: silence appended after every line that produced audio; values
            ``<= 0`` disable the gap.
        progress: optional callable invoked as ``progress(fraction, desc=...)`` before
            each voiced line.

    Returns:
        ``(sample_rate, waveform)`` with a 1-D float32 waveform.

    Raises:
        RuntimeError: the engine reported an "error" result for a line, or no line
            produced any audio.
    """
    engine = _load_engine()
    segments: List[np.ndarray] = []
    sample_rate = _SAMPLE_RATE  # read after _load_engine(), which may update it
    default_voice = VoiceConfig()
    voices = voice_map or {}
    total = len(lines)

    for i, (speaker, text) in enumerate(lines):
        if not text or not text.strip():
            continue
        if progress is not None:
            progress((i / max(total, 1)), desc=f"Voicing line {i + 1}/{total} ({speaker})")

        voice = voices.get(speaker) or default_voice
        request = _build_request(text, voice)

        produced_audio = False
        for result in engine.inference(request):
            code = getattr(result, "code", None)
            if code == "error":
                # Same policy as synthesize(): surface the failure instead of silently
                # dropping the line from the finished podcast.
                raise RuntimeError(
                    f"TTS inference error on line {i + 1} ({speaker}): "
                    f"{getattr(result, 'error', 'unknown')}"
                )
            if code == "final" and getattr(result, "audio", None) is not None:
                sample_rate, audio = result.audio
                segments.append(np.asarray(audio, dtype=np.float32).reshape(-1))
                produced_audio = True

        # Add the gap only after real audio; otherwise a script whose lines all came
        # back empty would yield pure silence and slip past the check below.
        if produced_audio and gap_seconds > 0:
            segments.append(np.zeros(int(sample_rate * gap_seconds), dtype=np.float32))

    if not segments:
        raise RuntimeError("No audio was generated for this script.")
    return int(sample_rate), np.concatenate(segments)
def write_wav(sample_rate: int, audio: np.ndarray) -> str:
    """Write a waveform to a temp .wav file and return its path (for download).

    The file is created with ``tempfile.mkstemp`` rather than the deprecated
    ``tempfile.mktemp``, so the name is reserved atomically instead of being guessed
    and opened later. The caller owns the returned file and is responsible for
    deleting it.

    Raises:
        Whatever ``soundfile.write`` raises; the temp file is removed before re-raising.
    """
    import soundfile as sf

    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # soundfile reopens the file by path; only the reserved name is needed
    try:
        sf.write(path, audio, sample_rate)
    except Exception:
        # Don't leave an empty or partial file behind when the write fails.
        try:
            os.unlink(path)
        except OSError:
            pass
        raise
    return path