File size: 14,931 Bytes
3afa0cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
"""
OpenAI-compatible FastAPI server for MOSS-TTS longform (1.5B / 7B) backend.

Exposes the same ``POST /v1/audio/speech`` endpoint shape as the realtime
server so clients can switch backends by changing the base URL only.

Environment variables
---------------------
MOSS_TTS_LONGFORM_MODEL_PATH   HF repo or local path for the backbone model
                                (default: OpenMOSS-Team/MOSS-TTS-Local-Transformer)
MOSS_TTS_CODEC_MODEL_PATH      HF repo or local path for the audio tokenizer
                                (default: OpenMOSS-Team/MOSS-Audio-Tokenizer)
MOSS_TTS_DEVICE                PyTorch device (default: cuda:0)
MOSS_TTS_ATTN_IMPLEMENTATION   sdpa | flash_attention_2 | eager | auto
                                (default: auto)
MOSS_TTS_TORCH_DTYPE           bfloat16 | float16 | float32 | auto
                                (default: auto → bfloat16 when CUDA present)
MOSS_TTS_VOICE_DIR             Directory that holds voice-prompt WAV/MP3 files
                                named after OpenAI voice IDs (default: built-in
                                audio/ next to openai_api.py in moss_tts_realtime)
MOSS_TTS_MAX_NEW_TOKENS        Max generation tokens (default: 4096)
                                and upper bound for the per-request heuristic cap
MOSS_TTS_TEMPERATURE           Audio sampling temperature (default: 1.0)
MOSS_TTS_TOP_P                 Audio top-p (default: 0.95)
MOSS_TTS_TOP_K                 Audio top-k (default: 50)
MOSS_TTS_REPETITION_PENALTY    Audio repetition penalty (default: 1.1)
MOSS_TTS_WARMUP_ON_START       true/1/yes → run a short warmup (default: true)
MOSS_TTS_MAX_CONCURRENT        Max simultaneous synthesis requests (default: 1)
MOSS_TTS_HOST                  Bind host (default: 0.0.0.0)
MOSS_TTS_PORT                  Bind port (default: 8013)
MOSS_TTS_LOG_LEVEL             Logging verbosity (default: INFO)
"""

from __future__ import annotations

import io
import logging
import os
import sys
import threading
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Literal

import numpy as np
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict, Field

# Put the repository root (two directories above this file) at the front of
# sys.path, unless it is already there, so that the `runner` package below
# resolves even when this file is executed directly as a script.
_PROJECT_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(_PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_PROJECT_ROOT))

from runner.adapters.longform_native import LongformNativeAdapter

log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration from environment
# ---------------------------------------------------------------------------
# Every setting is read once, at import time. The module docstring lists the
# meaning and default of each variable.

DEFAULT_MODEL_PATH = os.environ.get(
    "MOSS_TTS_LONGFORM_MODEL_PATH", "OpenMOSS-Team/MOSS-TTS-Local-Transformer"
)
DEFAULT_CODEC_PATH = os.environ.get(
    "MOSS_TTS_CODEC_MODEL_PATH", "OpenMOSS-Team/MOSS-Audio-Tokenizer"
)
DEFAULT_DEVICE = os.environ.get("MOSS_TTS_DEVICE", "cuda:0")
DEFAULT_ATTN = os.environ.get("MOSS_TTS_ATTN_IMPLEMENTATION", "auto")
DEFAULT_DTYPE = os.environ.get("MOSS_TTS_TORCH_DTYPE", "auto")

# Sampling defaults applied to every request.
DEFAULT_MAX_NEW_TOKENS = int(os.environ.get("MOSS_TTS_MAX_NEW_TOKENS", "4096"))
DEFAULT_TEMPERATURE = float(os.environ.get("MOSS_TTS_TEMPERATURE", "1.0"))
DEFAULT_TOP_P = float(os.environ.get("MOSS_TTS_TOP_P", "0.95"))
DEFAULT_TOP_K = int(os.environ.get("MOSS_TTS_TOP_K", "50"))
DEFAULT_REPETITION_PENALTY = float(os.environ.get("MOSS_TTS_REPETITION_PENALTY", "1.1"))

# Case-insensitive truthy spellings enable warmup; anything else disables it.
WARMUP_ON_START = os.environ.get("MOSS_TTS_WARMUP_ON_START", "true").lower() in {"true", "1", "yes"}
# Always keep at least one synthesis slot, even for 0 or negative values.
MAX_CONCURRENT = max(int(os.environ.get("MOSS_TTS_MAX_CONCURRENT", "1")), 1)

# Directory holding per-voice reference audio files (optional).
# If MOSS_TTS_VOICE_DIR is unset, use the audio/ folder of the sibling
# moss_tts_realtime package, which sits next to the realtime openai_api.py.
_DEFAULT_VOICE_DIR = Path(__file__).resolve().parents[1].joinpath("moss_tts_realtime", "audio")
VOICE_DIR = Path(os.environ.get("MOSS_TTS_VOICE_DIR", str(_DEFAULT_VOICE_DIR)))

# Every accepted OpenAI-style `model` alias is served by the same backbone.
_SUPPORTED_MODELS = dict.fromkeys(
    ("tts-1", "tts-1-hd", "moss-tts-longform", "moss-tts-delay"),
    DEFAULT_MODEL_PATH,
)

# OpenAI voice id -> reference prompt file inside VOICE_DIR. The six OpenAI
# names alternate between the two bundled prompts. "default" maps to None,
# meaning no reference audio is used.
_VOICE_PRESETS: dict[str, Path | None] = {
    voice_id: None if prompt_file is None else VOICE_DIR / prompt_file
    for voice_id, prompt_file in (
        ("alloy", "prompt_audio.mp3"),
        ("echo", "prompt_audio1.mp3"),
        ("fable", "prompt_audio.mp3"),
        ("nova", "prompt_audio1.mp3"),
        ("onyx", "prompt_audio.mp3"),
        ("shimmer", "prompt_audio1.mp3"),
        ("default", None),
    )
}

# Limits how many threads may run model generation at the same time.
_generation_semaphore = threading.BoundedSemaphore(value=MAX_CONCURRENT)

# Process-wide adapter: assigned during app startup, reset to None on shutdown.
_adapter: LongformNativeAdapter | None = None


# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------

class OpenAISpeechRequest(BaseModel):
    """Request body for ``POST /v1/audio/speech``, shaped like OpenAI's speech API.

    Unknown fields are silently dropped (``extra="ignore"``), so clients built
    for the OpenAI API can send their usual payload unchanged.
    """

    model_config = ConfigDict(extra="ignore")

    # Model alias; membership in the supported set is checked at synthesis time.
    model: str = "tts-1"
    # Text to speak, 1 to 8192 characters.
    input: str = Field(min_length=1, max_length=8192)
    # Preset voice name, or a path to a reference audio file.
    voice: str = "alloy"
    response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = "mp3"
    # Range-validated for OpenAI compatibility, but ignored by the longform backend.
    speed: float = Field(1.0, ge=0.25, le=4.0)


class VoiceInfo(BaseModel):
    """A single entry in the ``GET /v1/voices`` listing."""

    id: str
    name: str
    description: str | None = Field(default=None)


# ---------------------------------------------------------------------------
# Audio helpers (shared with realtime server)
# ---------------------------------------------------------------------------

def _content_type(audio_format: str) -> str:
    """Return the HTTP ``Content-Type`` for an OpenAI ``response_format`` value.

    Every supported format lives under ``audio/``. Only mp3 uses a subtype
    name (``mpeg``) that differs from the format name.

    Raises:
        KeyError: *audio_format* is not one of the six supported formats.
    """
    subtype_by_format = {
        "mp3": "mpeg",
        "opus": "opus",
        "aac": "aac",
        "flac": "flac",
        "wav": "wav",
        "pcm": "pcm",
    }
    return "audio/" + subtype_by_format[audio_format]


def _wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Encode *audio* as a mono 16-bit PCM WAV file and return the file bytes.

    *audio* is flattened to 1-D float32. Samples are expected in [-1, 1]:
    anything outside is clipped, and NaN samples are written as silence.
    Samples are scaled by 32767 and truncated to int16.
    """
    import wave

    samples = np.asarray(audio, dtype=np.float32).reshape(-1)
    # NaN survives np.clip and would cast to an arbitrary int16 (an audible
    # click, plus a RuntimeWarning). Treat it as silence instead. +/-inf
    # become large finite values that the clip pulls back to +/-1.
    samples = np.clip(np.nan_to_num(samples, nan=0.0), -1.0, 1.0)
    # Native byte order is intentional here: the `wave` writer expects native
    # samples and handles the WAV byte order itself.
    pcm = (samples * 32767.0).astype(np.int16)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(pcm.tobytes())
    return buf.getvalue()


def _pcm_bytes(audio: np.ndarray) -> bytes:
    """Return *audio* as raw, headerless, mono 16-bit signed little-endian PCM.

    *audio* is flattened to 1-D float32. NaN samples become silence, values
    are clipped to [-1, 1], then scaled by 32767 and truncated to int16.
    """
    samples = np.asarray(audio, dtype=np.float32).reshape(-1)
    # NaN would survive np.clip and cast to an arbitrary int16; use silence.
    samples = np.clip(np.nan_to_num(samples, nan=0.0), -1.0, 1.0)
    # Raw PCM has no header to declare its byte order, so pin little-endian
    # ("<i2") instead of inheriting the host's native order.
    return (samples * 32767.0).astype("<i2").tobytes()


def _encode_audio(audio: np.ndarray, sample_rate: int, response_format: str) -> bytes:
    """Encode a waveform into the bytes of the requested ``response_format``.

    ``wav`` and ``pcm`` are produced in-process. The compressed formats
    (mp3, opus, aac, flac) are transcoded with pydub from an intermediate WAV.

    Raises:
        RuntimeError: a compressed format was requested but pydub is not
            importable.
        KeyError: *response_format* is not a supported format.
    """
    # Uncompressed outputs never touch pydub.
    uncompressed = {
        "wav": lambda: _wav_bytes(audio, sample_rate),
        "pcm": lambda: _pcm_bytes(audio),
    }
    if response_format in uncompressed:
        return uncompressed[response_format]()

    try:
        from pydub import AudioSegment
    except ImportError as exc:
        raise RuntimeError(
            f"Compressed output ('{response_format}') requires pydub: {exc}"
        ) from exc

    segment = AudioSegment.from_wav(io.BytesIO(_wav_bytes(audio, sample_rate)))
    # (export container, bitrate or None) per compressed format.
    # AAC is exported as an ADTS stream.
    container, bitrate = {
        "mp3": ("mp3", "192k"),
        "opus": ("opus", "128k"),
        "aac": ("adts", "192k"),
        "flac": ("flac", None),
    }[response_format]
    export_options = {} if bitrate is None else {"bitrate": bitrate}
    sink = io.BytesIO()
    segment.export(sink, format=container, **export_options)
    return sink.getvalue()


# ---------------------------------------------------------------------------
# Voice resolution
# ---------------------------------------------------------------------------

def _voice_reference_path(voice: str) -> str | None:
    """Return the filesystem path to the reference audio for *voice*, or None.

    Resolution order:

    1. If *voice* (case-insensitive, surrounding whitespace ignored) is a key
       of ``_VOICE_PRESETS``:
       - return None for a preset mapped to None (``"default"``);
       - return None, with a logged warning, when the bundled file is missing;
       - otherwise return the preset file's resolved path.
    2. Otherwise *voice* is treated as a filesystem path (``~`` is expanded).
       If it names an existing regular file, its resolved path is returned.

    Raises:
        HTTPException: 400 when *voice* is blank, or is neither a known
            preset nor a usable path to an existing file.
    """
    normalized = voice.strip().lower()
    if not normalized:
        raise HTTPException(status_code=400, detail="voice is required")

    if normalized in _VOICE_PRESETS:
        preset = _VOICE_PRESETS[normalized]
        if preset is None:
            return None  # "default" preset → no reference audio
        if not preset.exists():
            log.warning("Bundled voice prompt missing: %s – using no reference.", preset)
            return None
        return str(preset.resolve())

    # Allow callers to pass an absolute or relative path directly.
    # NOTE(review): `voice` comes straight from the HTTP request body, so this
    # lets any client point the model at any readable file on the server
    # (relative paths resolve against the process CWD). Consider confining
    # lookups to VOICE_DIR before exposing the server to untrusted clients.
    try:
        candidate = Path(voice).expanduser()
        found = candidate.is_file()
    except (OSError, RuntimeError, ValueError):
        # Inputs like an over-long name (ENAMETOOLONG) or an unknown "~user"
        # prefix used to escape here and surface as a 500. They are simply
        # unsupported voices.
        found = False
    if found:
        return str(candidate.resolve())

    raise HTTPException(
        status_code=400,
        detail=(
            f"Unsupported voice '{voice}'.  "
            f"Available voices: {', '.join(sorted(_VOICE_PRESETS))}"
        ),
    )


def _estimate_max_new_tokens(
    text: str,
    *,
    tokens_per_word: int = 6,
    base_tokens: int = 64,
    floor: int = 128,
    ceiling: int | None = None,
) -> int:
    """Estimate a practical generation cap from prompt length.

    The local-transformer backend does not always emit EOS promptly, so a fixed
    4096-token cap causes short prompts to run for minutes. Approximate speech
    length from a word count and clamp the result to ``[floor, ceiling]``.

    Words are whitespace-separated tokens, except that each CJK ideograph or
    Japanese kana character counts as one word on its own. Chinese and
    Japanese are written without spaces, so a plain ``split()`` would see a
    whole paragraph as a single word and truncate generation at ``floor``.

    Args:
        text: The input text to be spoken.
        tokens_per_word: Generation tokens budgeted per word.
        base_tokens: Fixed allowance added on top of the per-word budget.
        floor: Minimum cap, so very short inputs still get room to finish.
        ceiling: Hard upper bound. Defaults to ``DEFAULT_MAX_NEW_TOKENS``
            (``MOSS_TTS_MAX_NEW_TOKENS``). It wins over ``floor`` when the
            two conflict, so the configured limit is never exceeded.

    Returns:
        The token cap, at least 1.
    """
    import re

    if ceiling is None:
        ceiling = DEFAULT_MAX_NEW_TOKENS

    # Kana, CJK Unified Ideographs (plus Extension A) and CJK compatibility
    # ideographs.
    cjk = re.compile(r"[\u3040-\u30ff\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]")
    cjk_chars = len(cjk.findall(text))
    other_words = len(cjk.sub(" ", text).split())
    words = max(1, cjk_chars + other_words)

    estimated = words * tokens_per_word + base_tokens
    # Apply the floor first and the ceiling last, so the ceiling always wins.
    return max(1, min(ceiling, max(floor, estimated)))


# ---------------------------------------------------------------------------
# Core synthesis
# ---------------------------------------------------------------------------

def _synthesize(payload: OpenAISpeechRequest) -> tuple[bytes, dict[str, float]]:
    """Run synthesis for *payload* and return ``(encoded_audio_bytes, metrics)``.

    Blocking: model inference and audio encoding both run on the calling
    thread. Model calls are limited by ``_generation_semaphore``. A caller
    waits at most 120 s for a free slot.

    Metrics keys:
        model_generation_seconds: wall time inside ``adapter.synthesize``.
        audio_emit_seconds: wall time spent encoding to ``response_format``.
        total_seconds: from just before the slot wait until encoding ends.
        audio_seconds: ``waveform.size / sample_rate`` (assumes a mono
            waveform; TODO confirm against the adapter).
        rtf: real-time factor, ``model_generation_seconds / audio_seconds``.
        ttfb_ms: milliseconds from just before the slot wait until generation
            finishes.

    Raises:
        HTTPException: 503 if the backend is not loaded or no slot frees up
            within 120 s; 400 for an unsupported model or voice.
    """
    # Snapshot the global once and check it explicitly. An `assert` is
    # stripped under `python -O`, and when it does fire it only surfaces as
    # a generic 500.
    adapter = _adapter
    if adapter is None:
        raise HTTPException(status_code=503, detail="Backend not ready.")

    if payload.model not in _SUPPORTED_MODELS:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Unsupported model '{payload.model}'.  "
                f"Supported: {', '.join(sorted(_SUPPORTED_MODELS))}"
            ),
        )

    reference_path = _voice_reference_path(payload.voice)

    t0 = time.perf_counter()

    # Bounded wait: a request queued behind long generations gets a retryable
    # 503 instead of hanging indefinitely.
    acquired = _generation_semaphore.acquire(timeout=120)
    if not acquired:
        raise HTTPException(
            status_code=503,
            detail="Server busy – all generation slots occupied.  Retry shortly.",
        )

    try:
        t_gen_start = time.perf_counter()
        waveform, sample_rate = adapter.synthesize(
            text=payload.input,
            reference_audio=reference_path,
            max_new_tokens=_estimate_max_new_tokens(payload.input),
            audio_temperature=DEFAULT_TEMPERATURE,
            audio_top_p=DEFAULT_TOP_P,
            audio_top_k=DEFAULT_TOP_K,
            audio_repetition_penalty=DEFAULT_REPETITION_PENALTY,
        )
        t_gen_end = time.perf_counter()
    finally:
        # Release before encoding, so a slot is held only for model time.
        _generation_semaphore.release()

    t_encode_start = time.perf_counter()
    encoded = _encode_audio(waveform, sample_rate, payload.response_format)
    t_encode_end = time.perf_counter()

    audio_seconds = float(waveform.size) / sample_rate
    gen_seconds = t_gen_end - t_gen_start
    total_seconds = t_encode_end - t0

    metrics = {
        "model_generation_seconds": gen_seconds,
        "audio_emit_seconds": t_encode_end - t_encode_start,
        "total_seconds": total_seconds,
        "audio_seconds": audio_seconds,
        "rtf": gen_seconds / max(audio_seconds, 1e-9),
        "ttfb_ms": (t_gen_end - t0) * 1000.0,
    }

    log.info(
        "synthesize: %.1f s audio in %.1f s  (RTF=%.3f)",
        audio_seconds,
        gen_seconds,
        metrics["rtf"],
    )
    return encoded, metrics


# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load (and optionally warm up) the model before serving; drop it on shutdown.

    The adapter is published to the module-level ``_adapter`` only after
    ``load()`` (and warmup, when enabled) has succeeded. It is cleared in a
    ``finally`` so the teardown also runs when the app exits with an error.
    """
    global _adapter
    import asyncio

    adapter = LongformNativeAdapter(
        model_path=DEFAULT_MODEL_PATH,
        device=DEFAULT_DEVICE,
        attn_implementation=DEFAULT_ATTN,
        codec_path=DEFAULT_CODEC_PATH,
        torch_dtype=DEFAULT_DTYPE,
    )
    adapter.load()

    if WARMUP_ON_START:
        # Warmup is blocking work; run it off the event-loop thread.
        # get_running_loop() is the supported way to get the loop inside a
        # coroutine.
        await asyncio.get_running_loop().run_in_executor(None, adapter.warmup)

    _adapter = adapter
    try:
        yield
    finally:
        _adapter = None


app = FastAPI(
    lifespan=lifespan,
    title="MOSS-TTS Longform",
    version="1.0.0",
    description="OpenAI-compatible TTS API backed by the MOSS-TTS 1.5B / 7B PyTorch model.",
)

# Fully permissive CORS, so browser clients on any origin can call the API.
# NOTE(review): any web page can therefore call this server from a visitor's
# browser. Narrow allow_origins if it is reachable beyond trusted clients.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)


@app.get("/")
async def root():
    """Basic service descriptor.

    Always reports ``"ok"`` and does not check model readiness. It includes
    the configured backbone model path.
    """
    info = {"service": "moss-tts-longform", "status": "ok"}
    info["model"] = DEFAULT_MODEL_PATH
    return info


@app.get("/health")
async def health():
    """Readiness probe.

    Returns 503 until the global adapter exists and its private ``_model``
    attribute is set (presumably by ``load()``; verify in the adapter).
    """
    ready = _adapter is not None and _adapter._model is not None
    if not ready:
        raise HTTPException(status_code=503, detail="Backend not ready.")
    return {"status": "ok", "backend": "longform-native"}


@app.get("/v1/models")
async def list_models():
    """List accepted ``model`` aliases, sorted, in OpenAI's list-envelope shape."""
    entries = [
        dict(id=model_id, object="model", owned_by="OpenMOSS-Team")
        for model_id in sorted(_SUPPORTED_MODELS)
    ]
    return {"object": "list", "data": entries}


@app.get("/v1/voices")
async def list_voices():
    """List voice presets, sorted by id, each with a capitalized display name."""
    entries = [
        VoiceInfo(id=voice_id, name=voice_id.capitalize()).model_dump()
        for voice_id in sorted(_VOICE_PRESETS)
    ]
    return {"object": "list", "data": entries}


@app.post("/v1/audio/speech")
async def create_speech(payload: OpenAISpeechRequest):
    """Synthesize ``payload.input`` and return the encoded audio in one response.

    ``_synthesize`` is blocking, so it runs on the event loop's default
    executor. Its timing metrics are returned as ``X-MOSS-*`` headers.

    Raises:
        HTTPException: re-raised unchanged from ``_synthesize``; any other
            failure is logged and wrapped as a 500.
    """
    import asyncio

    # Inside a coroutine, get_running_loop() is the supported API;
    # get_event_loop() is discouraged there.
    loop = asyncio.get_running_loop()
    try:
        encoded, metrics = await loop.run_in_executor(None, _synthesize, payload)
    except HTTPException:
        raise
    except Exception as exc:
        log.exception("Synthesis failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    headers = {
        "Content-Disposition": f"attachment; filename=speech.{payload.response_format}",
        "X-MOSS-TTFB-MS": f"{metrics['ttfb_ms']:.1f}",
        "X-MOSS-RTF": f"{metrics['rtf']:.4f}",
        "X-MOSS-AUDIO-SECONDS": f"{metrics['audio_seconds']:.4f}",
        "X-MOSS-STAGE-MODEL-MS": f"{metrics['model_generation_seconds'] * 1000:.1f}",
        "X-MOSS-STAGE-EMIT-MS": f"{metrics['audio_emit_seconds'] * 1000:.1f}",
        "X-MOSS-STAGE-TOTAL-MS": f"{metrics['total_seconds'] * 1000:.1f}",
    }

    return Response(
        content=encoded,
        media_type=_content_type(payload.response_format),
        headers=headers,
    )


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main() -> None:
    """Configure logging from the environment, then serve ``app`` with uvicorn.

    Reads MOSS_TTS_LOG_LEVEL (default INFO), MOSS_TTS_HOST (default 0.0.0.0)
    and MOSS_TTS_PORT (default 8013).
    """
    import uvicorn

    # One setting drives both loggers: stdlib logging wants upper-case level
    # names, uvicorn wants lower-case ones.
    level = os.getenv("MOSS_TTS_LOG_LEVEL", "INFO")
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=level.upper(),
    )
    bind_host = os.getenv("MOSS_TTS_HOST", "0.0.0.0")
    bind_port = int(os.getenv("MOSS_TTS_PORT", "8013"))
    uvicorn.run(app, host=bind_host, port=bind_port, log_level=level.lower())


# Start the server when this file is executed directly rather than imported.
if __name__ == "__main__":
    main()