File size: 10,335 Bytes
623fe6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3ec634
 
 
 
623fe6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310b223
 
 
 
 
 
57f0925
 
 
 
 
 
 
310b223
 
 
 
57f0925
310b223
 
57f0925
310b223
57f0925
310b223
57f0925
310b223
57f0925
310b223
 
f3cc6d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57f0925
 
 
 
 
310b223
57f0925
 
310b223
 
623fe6a
 
 
 
 
 
 
 
310b223
 
 
623fe6a
 
 
 
 
 
 
b3ec634
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623fe6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
"""Fish Audio (OpenAudio S1-mini) inference engine.

Loads the model once (cached), exposes a ZeroGPU-decorated ``synthesize()`` for a single
utterance with optional zero-shot voice cloning, and ``generate_podcast()`` to stitch a
multi-speaker script into one waveform.

Heavy deps (torch / fish_speech) are imported lazily so this module can be imported on a
CPU-only / local machine (Phase 1 development) without them installed.
"""

from __future__ import annotations

import os
import tempfile
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np

# ----------------------------------------------------------------- ZeroGPU decorator
try:
    import spaces  # provided in HF Spaces runtime

    GPU = spaces.GPU
except Exception:  # local / non-Space: no-op decorator

    def GPU(*dargs, **dkwargs):
        """No-op stand-in for ``spaces.GPU`` when not running on a Hugging Face Space.

        Both decorator spellings are accepted so call sites look the same everywhere:
        bare ``@GPU`` hands back the decorated function unchanged, while
        ``@GPU(duration=...)`` (or any other call form) returns an identity decorator
        that ignores its options.
        """
        used_bare = not dkwargs and len(dargs) == 1 and callable(dargs[0])
        if used_bare:
            return dargs[0]

        def _identity(fn):
            return fn

        return _identity


# Hugging Face repo holding the TTS checkpoint; override via the TTS_MODEL_REPO env var.
TTS_MODEL_REPO = os.getenv("TTS_MODEL_REPO", "fishaudio/openaudio-s1-mini")
# Decoder checkpoint filename inside the model repo — verify against the repo if it changes.
DECODER_CHECKPOINT = os.getenv("TTS_DECODER_CKPT", "codec.pth")
# Decoder config name handed to fish_speech's decoder loader as ``config_name``.
DECODER_CONFIG = os.getenv("TTS_DECODER_CONFIG", "modded_dac_vq")

# Lazily-built TTSInferenceEngine; populated on first use and reused afterwards.
_ENGINE = None
# Fallback output rate in Hz; replaced by the decoder's own rate once the engine loads.
_SAMPLE_RATE = 44100


class TTSModelAccessError(RuntimeError):
    """Signals that the configured TTS model could not be fetched from the HF Hub.

    Used when the model repository is gated or forbidden for the current token, so
    callers can surface an actionable message instead of a raw Hub error.
    """

    pass


@dataclass
class VoiceConfig:
    """Voice selection for a single speaker.

    Leave ``ref_audio`` as ``None`` to use the model's default voice. Otherwise it is a
    filesystem path to a reference clip used for zero-shot cloning, and ``ref_text`` is
    the text paired with that clip (presumably its transcript).
    """

    # Path to the reference audio clip, or None for the model's default voice.
    ref_audio: Optional[str] = None
    # Text accompanying the reference clip; empty when not provided.
    ref_text: str = ""


def is_available() -> bool:
    """Report whether the TTS stack can run here.

    Returns True only when both ``torch`` and ``fish_speech`` import successfully; any
    failure while importing either one yields False.
    """
    try:
        import torch  # noqa: F401
        import fish_speech  # noqa: F401
    except Exception:
        return False
    return True


def _patch_pyrootutils() -> None:
    """Make fish-speech importable when installed as a package (no source checkout).

    Several fish_speech modules call ``pyrootutils.setup_root(__file__,
    indicator='.project-root')`` at import time. That marker only exists in the source
    repo, so a pip-installed copy raises ``FileNotFoundError`` (and we can't write the
    marker into a root-owned site-packages at runtime).

    We wrap ``pyrootutils.setup_root`` — the exact attribute fish_speech calls — so the
    interception works as long as callers look the function up as
    ``pyrootutils.setup_root`` at call time (a ``from pyrootutils import setup_root``
    executed before this patch would bypass it). (Patching ``find_root`` does not work:
    ``setup_root`` lives in the ``pyrootutils.pyrootutils`` submodule and resolves
    ``find_root`` from that submodule's own globals, not the package-level re-export.)
    On failure we fall back to the installed package's parent dir, which mirrors the
    repo layout (``<root>/fish_speech/...``) closely enough for config resolution.

    The patch is idempotent: the wrapper is tagged with ``_podify_patched`` and a repeat
    call returns immediately. In the fallback path only the ``pythonpath`` (default
    False) and ``project_root_env_var`` (default True) keyword options are honoured —
    presumably mirroring pyrootutils' own defaults, verify if upgrading — and any other
    options passed to ``setup_root`` are ignored.
    """
    import pyrootutils

    # Already wrapped by an earlier call: don't stack wrappers.
    if getattr(pyrootutils.setup_root, "_podify_patched", False):
        return

    _orig_setup_root = pyrootutils.setup_root

    def _setup_root(*args, **kwargs):
        # Fast path: defer to the real implementation (works in a source checkout).
        try:
            return _orig_setup_root(*args, **kwargs)
        except FileNotFoundError:
            import sys
            from pathlib import Path

            # fish_speech is a PEP 420 namespace package here, so __file__ is None;
            # locate its directory via __path__, falling back to the calling module's
            # path (setup_root's first arg). The project root is the dir *containing*
            # the fish_speech package, mirroring the repo's .project-root location.
            pkg_dir = None
            try:
                import fish_speech

                paths = list(getattr(fish_speech, "__path__", []) or [])
                if paths:
                    pkg_dir = Path(paths[0]).resolve()
                elif getattr(fish_speech, "__file__", None):
                    pkg_dir = Path(fish_speech.__file__).resolve().parent
            except Exception:
                pkg_dir = None

            # Second fallback: walk up from the caller-supplied path (positional arg
            # only) until a directory named "fish_speech" is found.
            if pkg_dir is None and args:
                sf = Path(str(args[0])).resolve()
                for p in [sf, *sf.parents]:
                    if p.name == "fish_speech":
                        pkg_dir = p
                        break

            if pkg_dir is None:
                # Bare raise: the inner try/except above has already completed, so this
                # re-raises the FileNotFoundError handled by the enclosing except clause.
                raise  # nothing to fall back to — re-raise the original error

            root = pkg_dir.parent
            if kwargs.get("pythonpath", False) and str(root) not in sys.path:
                sys.path.insert(0, str(root))
            if kwargs.get("project_root_env_var", True):
                os.environ["PROJECT_ROOT"] = str(root)
            return root

    _setup_root._podify_patched = True
    pyrootutils.setup_root = _setup_root


def _is_hub_access_error(exc: BaseException) -> bool:
    """Return True if a ``snapshot_download`` failure looks like a gated/forbidden repo.

    Checks, in order:
    - whether the exception class (or a base class) is named ``GatedRepoError``;
    - whether an attached HTTP ``response`` reports status 403;
    - whether the message contains the Hub's gated-repo wording or a standalone ``403``
      token.

    The token match uses word boundaries, so digits embedded in URLs or commit hashes
    are not mistaken for an HTTP status.
    """
    import re

    # Compare by class name so huggingface_hub's error classes need not be imported here.
    if any(cls.__name__ == "GatedRepoError" for cls in type(exc).__mro__):
        return True
    response = getattr(exc, "response", None)
    if getattr(response, "status_code", None) == 403:
        return True
    msg = str(exc)
    return "Cannot access gated repo" in msg or re.search(r"\b403\b", msg) is not None


def _model_access_url() -> str:
    """Return the Hub page where access to ``TTS_MODEL_REPO`` should be requested."""
    # The default repo's access page is served under ``fishaudio/s1-mini``; any other
    # repo is assumed to host its own request page.
    if TTS_MODEL_REPO == "fishaudio/openaudio-s1-mini":
        return "https://huggingface.co/fishaudio/s1-mini"
    return f"https://huggingface.co/{TTS_MODEL_REPO}"


def _load_engine():
    """Build and cache the TTSInferenceEngine. Runs on the GPU worker.

    Later calls return the cached engine. The first call:
    - downloads ``TTS_MODEL_REPO`` from the HF Hub, using ``HF_TOKEN`` or
      ``HUGGING_FACE_HUB_TOKEN`` when set;
    - starts the text-to-semantic model queue;
    - loads the decoder;
    - stores the decoder's sample rate in ``_SAMPLE_RATE``, keeping 44100 Hz if the
      decoder does not expose a usable one.

    Raises:
        TTSModelAccessError: the model repo is gated or forbidden for the current token.
    """
    global _ENGINE, _SAMPLE_RATE
    if _ENGINE is not None:
        return _ENGINE

    import torch
    from huggingface_hub import snapshot_download

    _patch_pyrootutils()  # must precede the fish_speech inference imports below

    from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
    from fish_speech.models.dac.inference import load_model as load_decoder_model
    from fish_speech.inference_engine import TTSInferenceEngine

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision on CUDA, full precision on CPU.
    precision = torch.half if device == "cuda" else torch.float32

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    try:
        checkpoint_dir = snapshot_download(repo_id=TTS_MODEL_REPO, token=token)
    except Exception as e:
        if not _is_hub_access_error(e):
            raise
        access_url = _model_access_url()
        raise TTSModelAccessError(
            f"The TTS model '{TTS_MODEL_REPO}' is gated or not accessible with the current "
            f"Hugging Face token. Request access at {access_url}, then log in locally or set "
            "HF_TOKEN to a token with read access. You can also set TTS_MODEL_REPO to another "
            "compatible Fish Audio/OpenAudio checkpoint you can access."
        ) from e

    llama_queue = launch_thread_safe_queue(
        checkpoint_path=checkpoint_dir,
        device=device,
        precision=precision,
        compile=False,
    )
    decoder_model = load_decoder_model(
        config_name=DECODER_CONFIG,
        checkpoint_path=os.path.join(checkpoint_dir, DECODER_CHECKPOINT),
        device=device,
    )

    engine = TTSInferenceEngine(
        llama_queue=llama_queue,
        decoder_model=decoder_model,
        compile=False,
        precision=precision,
    )
    try:
        _SAMPLE_RATE = int(decoder_model.sample_rate)
    except (AttributeError, TypeError, ValueError):
        # Decoder has no (numeric) sample_rate attribute: keep the module default.
        _SAMPLE_RATE = 44100
    _ENGINE = engine
    return engine


def _build_request(
    text: str,
    voice: VoiceConfig,
    *,
    max_new_tokens: int = 1024,
    chunk_length: int = 200,
    top_p: float = 0.8,
    repetition_penalty: float = 1.1,
    temperature: float = 0.8,
):
    """Build a fish_speech ``ServeTTSRequest`` for *text* spoken in *voice*.

    If ``voice.ref_audio`` names an existing file, its raw bytes (plus ``voice.ref_text``)
    are attached as a zero-shot cloning reference. Otherwise no reference is sent and
    the model's default voice is used. A ``RuntimeWarning`` is emitted when a reference
    path was configured but does not exist.

    The keyword-only sampling parameters are forwarded unchanged to the request; their
    defaults match the values this module has always used. The output format is WAV.
    """
    import warnings

    from fish_speech.utils.schema import ServeTTSRequest, ServeReferenceAudio

    references = []
    if voice.ref_audio:
        if os.path.isfile(voice.ref_audio):
            with open(voice.ref_audio, "rb") as f:
                audio_bytes = f.read()
            references = [ServeReferenceAudio(audio=audio_bytes, text=voice.ref_text or "")]
        else:
            # Keep the default-voice fallback, but make it visible so a stale or mistyped
            # reference path doesn't silently change the speaker's voice.
            warnings.warn(
                f"Reference audio {voice.ref_audio!r} not found; using the model's default voice.",
                RuntimeWarning,
                stacklevel=2,
            )

    return ServeTTSRequest(
        text=text,
        references=references,
        reference_id=None,
        max_new_tokens=max_new_tokens,
        chunk_length=chunk_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        format="wav",
    )


@GPU(duration=120)
def synthesize(text: str, voice: VoiceConfig) -> Tuple[int, np.ndarray]:
    """Render a single utterance with the cached engine.

    Every ``"final"`` result carrying audio contributes a flattened float32 chunk; the
    chunks are joined in order. The returned sample rate is the one reported by the last
    audio result (or the module default if none reported one).

    Returns:
        ``(sample_rate, waveform)`` with a float32 mono waveform.

    Raises:
        RuntimeError: the engine yielded an ``"error"`` result, or no audio at all.
    """
    engine = _load_engine()
    req = _build_request(text, voice)

    pieces: List[np.ndarray] = []
    rate = _SAMPLE_RATE
    for item in engine.inference(req):
        status = getattr(item, "code", None)
        if status == "error":
            raise RuntimeError(f"TTS inference error: {getattr(item, 'error', 'unknown')}")
        if status != "final" or getattr(item, "audio", None) is None:
            continue
        rate, samples = item.audio
        pieces.append(np.asarray(samples, dtype=np.float32).reshape(-1))

    if len(pieces) == 0:
        raise RuntimeError("TTS produced no audio.")
    return int(rate), np.concatenate(pieces)


@GPU(duration=300)
def generate_podcast(
    lines: List[Tuple[str, str]],
    voice_map: dict,
    *,
    gap_seconds: float = 0.4,
    progress=None,
) -> Tuple[int, np.ndarray]:
    """Synthesize each (speaker, text) line and stitch into one waveform.

    ``voice_map`` maps speaker name -> VoiceConfig; speakers missing from it use the
    model's default voice. Lines whose text is blank are skipped. After every line that
    produced speech, ``gap_seconds`` of silence is appended. The whole loop runs inside a
    single GPU allocation so the model is loaded once per podcast.

    Args:
        lines: ordered ``(speaker, text)`` pairs.
        voice_map: speaker name -> VoiceConfig.
        gap_seconds: silence inserted after each voiced line; ``<= 0`` disables it.
        progress: optional callable invoked as ``progress(fraction, desc=...)``.

    Returns:
        ``(sample_rate, waveform)`` with a float32 mono waveform.

    Raises:
        RuntimeError: the engine reported an error for a line, or no line produced audio.
    """
    engine = _load_engine()
    segments: List[np.ndarray] = []
    sample_rate = _SAMPLE_RATE
    default_voice = VoiceConfig()
    voiced_any = False

    total = len(lines)
    for i, (speaker, text) in enumerate(lines):
        if not text.strip():
            continue
        if progress is not None:
            progress(i / max(total, 1), desc=f"Voicing line {i + 1}/{total} ({speaker})")
        voice = voice_map.get(speaker, default_voice)
        request = _build_request(text, voice)
        line_voiced = False
        for result in engine.inference(request):
            code = getattr(result, "code", None)
            if code == "error":
                # Fail loudly (as synthesize() does) rather than silently dropping the line.
                raise RuntimeError(
                    f"TTS inference error on line {i + 1} ({speaker}): "
                    f"{getattr(result, 'error', 'unknown')}"
                )
            if code == "final" and getattr(result, "audio", None) is not None:
                sample_rate, audio = result.audio
                segments.append(np.asarray(audio, dtype=np.float32).reshape(-1))
                line_voiced = True
        # Pad only after real speech, so silence alone never passes as generated audio.
        if line_voiced and gap_seconds > 0:
            segments.append(np.zeros(int(sample_rate * gap_seconds), dtype=np.float32))
        voiced_any = voiced_any or line_voiced

    if not voiced_any:
        raise RuntimeError("No audio was generated for this script.")
    return int(sample_rate), np.concatenate(segments)


def write_wav(sample_rate: int, audio: np.ndarray) -> str:
    """Write a waveform to a temp .wav file and return its path (for download).

    The file is created atomically with ``tempfile.mkstemp``; the deprecated
    ``tempfile.mktemp`` only returns a name and leaves a window for another process to
    claim it. The caller owns the returned file and is responsible for deleting it.

    Args:
        sample_rate: samples per second of *audio*.
        audio: waveform samples, passed straight to ``soundfile.write``.

    Returns:
        Filesystem path of the written ``.wav`` file.
    """
    import soundfile as sf

    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # soundfile reopens by path; we only needed the securely-reserved name
    try:
        sf.write(path, audio, sample_rate)
    except BaseException:
        # Don't leave an empty placeholder file behind if the write fails.
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path