| """ | |
| Chatterbox Turbo TTS -- FastAPI Server | |
| ====================================== | |
| Production-ready API with true real-time MP3 streaming, | |
| in-memory voice cloning, and fully non-blocking inference. | |
| Endpoints: | |
| GET /health -> health check + optional warmup | |
| GET /info -> model info, supported tags, parameters | |
| POST /tts -> full audio response (WAV/MP3/FLAC) | |
| POST /tts/stream -> chunked MP3 streaming (MediaSource-ready) | |
| POST /tts/true-stream -> alias for /tts/stream (Kokoro compat) | |
| POST /tts/stop/{stream_id}-> cancel a specific active stream | |
| POST /tts/stop -> cancel ALL active streams | |
| POST /v1/audio/speech -> OpenAI-compatible streaming | |
| """ | |
import asyncio
import io
import json
import logging
import queue as stdlib_queue
import threading
import time
import urllib.error
import urllib.parse
import urllib.request
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import Generator, Optional

import numpy as np
import soundfile as sf
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import Response, StreamingResponse

from config import Config
from chatterbox_wrapper import ChatterboxWrapper, GenerationCancelled, VoiceProfile
import text_processor
# ── Logging ─────────────────────────────────────────────────────────
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s │ %(levelname)-7s │ %(name)s │ %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)
# ── Thread pool for CPU-bound inference ─────────────────────────────
tts_executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
# ── Lifespan ────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        wrapper = ChatterboxWrapper()
        app.state.wrapper = wrapper
        logger.info("✅ Model loaded, server ready")
    except Exception as e:
        logger.error(f"❌ Model loading failed: {e}")
        raise
    yield
    tts_executor.shutdown(wait=False)
app = FastAPI(
    title="Chatterbox Turbo TTS API",
    version="1.0.0",
    docs_url="/docs",
    lifespan=lifespan,
)
# ── CORS Middleware ─────────────────────────────────────────────────
@app.middleware("http")
async def cors_middleware(request: Request, call_next):
    origin = request.headers.get("origin")
    # Preflight
    if request.method == "OPTIONS" and origin in Config.ALLOWED_ORIGINS:
        return Response(
            status_code=200,
            headers={
                "Access-Control-Allow-Origin": origin,
                "Access-Control-Allow-Methods": "*",
                "Access-Control-Allow-Headers": "*",
                "Access-Control-Allow-Credentials": "true",
            },
        )
    if not origin or origin in Config.ALLOWED_ORIGINS:
        response = await call_next(request)
        if origin:
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Access-Control-Allow-Credentials"] = "true"
            response.headers["Access-Control-Allow-Methods"] = "*"
            response.headers["Access-Control-Allow-Headers"] = "*"
            response.headers["Access-Control-Expose-Headers"] = "X-Stream-Id"
        return response
    logger.warning(f"🚫 Blocked origin: {origin}")
    return Response(status_code=403, content="Forbidden: Origin not allowed")
# ─────────────────────────────────────────────────────────────────────
# Helper: resolve voice from optional upload
# ─────────────────────────────────────────────────────────────────────
async def _resolve_voice(
    voice_ref: Optional[UploadFile],
    voice_name: str,
    wrapper: ChatterboxWrapper,
) -> VoiceProfile:
    """Return a VoiceProfile from uploaded audio, built-in voice name, or default."""
    # 1) If a file was uploaded, encode it (highest priority)
    if voice_ref is not None and voice_ref.filename:
        audio_bytes = await voice_ref.read()
        if len(audio_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
            raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
        if len(audio_bytes) == 0:
            raise HTTPException(status_code=400, detail="Empty voice file")
        loop = asyncio.get_running_loop()
        try:
            return await loop.run_in_executor(
                tts_executor, wrapper.encode_voice_from_bytes, audio_bytes
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))
        except Exception as e:
            logger.error(f"Voice encoding failed: {e}")
            raise HTTPException(
                status_code=400,
                detail=f"Could not process voice file: {str(e)}. "
                       f"Supported formats: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM.",
            )
    # 2) Resolve by built-in voice name (returns cached profile → no encoding)
    try:
        return wrapper.get_builtin_voice(voice_name)
    except (ValueError, KeyError) as e:
        raise HTTPException(status_code=400, detail=str(e))
# ─────────────────────────────────────────────────────────────────────
# Helper: encode numpy audio to bytes in given format
# ─────────────────────────────────────────────────────────────────────
def _encode_audio(audio: np.ndarray, fmt: str = "wav") -> tuple[bytes, str]:
    buf = io.BytesIO()
    fmt_lower = fmt.lower()
    if fmt_lower == "mp3":
        sf.write(buf, audio, Config.SAMPLE_RATE, format="mp3")
        media = "audio/mpeg"
    elif fmt_lower == "flac":
        sf.write(buf, audio, Config.SAMPLE_RATE, format="flac")
        media = "audio/flac"
    else:
        sf.write(buf, audio, Config.SAMPLE_RATE, format="wav")
        media = "audio/wav"
    return buf.getvalue(), media


def _encode_mp3_chunk(audio: np.ndarray) -> bytes:
    """Encode one numpy chunk to MP3 bytes (same encoder path as the full /tts response)."""
    data, _ = _encode_audio(audio, fmt="mp3")
    return data
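# Quick illustration of the two encoders above (assumes Config.SAMPLE_RATE is
# the model's output rate; the silence buffer is just an example input):
#
#   silence = np.zeros(Config.SAMPLE_RATE, dtype=np.float32)   # 1 s of silence
#   wav_bytes, media = _encode_audio(silence, "wav")            # media == "audio/wav"
#   mp3_bytes = _encode_mp3_chunk(silence)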
def _build_helper_endpoint(base_url: str, path: str) -> str:
    return f"{base_url.rstrip('/')}{path}"


def _internal_headers() -> dict[str, str]:
    headers = {"Content-Type": "application/json", "Accept": "audio/mpeg"}
    if Config.INTERNAL_SHARED_SECRET:
        headers["X-Internal-Secret"] = Config.INTERNAL_SHARED_SECRET
    return headers


def _helper_request_chunk(
    helper_base_url: str,
    payload: dict,
    timeout_sec: float,
) -> bytes:
    url = _build_helper_endpoint(helper_base_url, "/internal/chunk/synthesize")
    body = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(
        url=url,
        data=body,
        headers=_internal_headers(),
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout_sec) as resp:
        return resp.read()


def _helper_register_voice(
    helper_base_url: str,
    stream_id: str,
    audio_bytes: bytes,
    timeout_sec: float,
) -> str:
    """Register reference voice on helper once, return voice_key for chunk calls."""
    query = urllib.parse.urlencode({"stream_id": stream_id})
    url = _build_helper_endpoint(helper_base_url, f"/internal/voice/register?{query}")
    headers = {"Content-Type": "application/octet-stream", "Accept": "application/json"}
    if Config.INTERNAL_SHARED_SECRET:
        headers["X-Internal-Secret"] = Config.INTERNAL_SHARED_SECRET
    req = urllib.request.Request(
        url=url,
        data=audio_bytes,
        headers=headers,
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout_sec) as resp:
        data = json.loads(resp.read().decode("utf-8"))
    voice_key = (data.get("voice_key") or "").strip()
    if not voice_key:
        raise RuntimeError("helper voice registration returned no voice_key")
    return voice_key


def _helper_cancel_stream(helper_base_url: str, stream_id: str):
    """Best-effort cancellation signal to helper."""
    try:
        url = _build_helper_endpoint(helper_base_url, f"/internal/chunk/cancel/{stream_id}")
        req = urllib.request.Request(
            url=url,
            data=b"",
            headers=_internal_headers(),
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=3.0):
            pass
    except Exception:
        pass
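# For reference, the chunk request the primary sends to a helper instance has
# the shape of the payload built in _even_worker below (values here are
# illustrative only):
#
#   POST {helper_base_url}/internal/chunk/synthesize
#   {
#       "stream_id": "ab12cd34ef56",
#       "chunk_index": 1,
#       "text": "Second sentence of the request.",
#       "max_new_tokens": 512,
#       "repetition_penalty": 1.2,
#       "output_format": "mp3",
#       "voice_key": "<hash returned by /internal/voice/register>"   # optional
#   }
#
# The helper replies with raw MP3 bytes plus X-Stream-Id / X-Chunk-Index headers
# (see internal_chunk_synthesize below).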
# ─────────────────────────────────────────────────────────────────────
# Endpoints
# ─────────────────────────────────────────────────────────────────────
@app.get("/health")
async def health(warm_up: bool = False):
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    status = {
        "status": "healthy" if wrapper else "loading",
        "model_loaded": wrapper is not None,
        "model_dtype": Config.MODEL_DTYPE,
        "streaming_supported": True,
        "voice_cache_entries": wrapper._voice_cache.size if wrapper else 0,
    }
    if warm_up and wrapper:
        try:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(tts_executor, wrapper.warmup)
            status["warm_up"] = "success"
        except Exception as e:
            status["warm_up"] = f"failed: {e}"
    return status
@app.get("/info")
async def info():
    return {
        "model": Config.MODEL_ID,
        "dtype": Config.MODEL_DTYPE,
        "sample_rate": Config.SAMPLE_RATE,
        "paralinguistic_tags": list(Config.PARALINGUISTIC_TAGS),
        "tag_usage": "Insert tags directly in text, e.g. 'That is so funny! [laugh] Anyway…'",
        "parameters": {
            "max_new_tokens": {"default": Config.MAX_NEW_TOKENS, "range": "64–2048"},
            "repetition_penalty": {"default": Config.REPETITION_PENALTY, "range": "1.0–2.0"},
        },
        "voice_cloning": {
            "description": "Upload 3–30s reference WAV/MP3 as 'voice_ref' field",
            "max_upload_mb": Config.MAX_VOICE_UPLOAD_BYTES // (1024 * 1024),
        },
        "parallel_mode": {
            "enabled": Config.ENABLE_PARALLEL_MODE,
            "helper_configured": bool(Config.HELPER_BASE_URL),
            "helper_base_url": Config.HELPER_BASE_URL or None,
            "supports_voice_ref": True,
        },
    }
# NOTE: route path assumed; this endpoint is not listed in the module docstring.
@app.get("/voices")
async def list_voices():
    """Return all built-in voices available for selection."""
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    return {
        "default": wrapper.default_voice_name,
        "voices": wrapper.list_builtin_voices(),
    }
# ── POST /tts ───────────────────────────────────────────────────────
@app.post("/tts")
async def text_to_speech(
    text: str = Form(...),
    voice_ref: Optional[UploadFile] = File(None),
    voice_name: str = Form("default"),
    output_format: str = Form("wav"),
    max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
    repetition_penalty: float = Form(Config.REPETITION_PENALTY),
):
    """Generate complete audio for the given text."""
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    if not text or not text.strip():
        raise HTTPException(400, "Text is required")
    voice = await _resolve_voice(voice_ref, voice_name, wrapper)
    loop = asyncio.get_running_loop()
    try:
        audio = await loop.run_in_executor(
            tts_executor,
            wrapper.generate_speech,
            text, voice, max_new_tokens, repetition_penalty,
        )
    except ValueError as e:
        raise HTTPException(400, str(e))
    except Exception as e:
        logger.error(f"TTS error: {e}")
        raise HTTPException(500, "Internal server error")
    data, media_type = _encode_audio(audio, output_format)
    return Response(
        content=data,
        media_type=media_type,
        headers={"Content-Disposition": f"attachment; filename=tts_output.{output_format}"},
    )
# ─────────────────────────────────────────────────────────────────────
# Active Stream Registry (for cancellation)
# ─────────────────────────────────────────────────────────────────────
_active_streams: dict[str, threading.Event] = {}
_internal_cancelled_streams: set[str] = set()
_internal_cancel_lock = threading.Lock()
_internal_stream_voice_keys: dict[str, set[str]] = {}
# ─────────────────────────────────────────────────────────────────────
# Pipeline Streaming Generator
# ─────────────────────────────────────────────────────────────────────
def _pipeline_stream_generator(
    wrapper: ChatterboxWrapper,
    text: str,
    voice: VoiceProfile,
    max_new_tokens: int,
    repetition_penalty: float,
    stream_id: str,
) -> Generator[bytes, None, None]:
    """Two-stage producer-consumer pipeline for minimal inter-chunk gaps.

    Architecture:
        Producer thread (heavyweight, ~80% CPU):
            ONNX token generation → audio decoding → raw numpy arrays → queue
        Consumer (this generator, lightweight, ~20% CPU):
            queue → MP3 encode → yield to HTTP response

    Why this helps:
        - ONNX model runs CONTINUOUSLY without waiting for MP3 encode or HTTP
        - MP3 encoding (libsndfile, C code) releases the GIL → true parallelism
        - ONNX inference (C++ code) also releases the GIL → both run simultaneously
        - Queue(maxsize=2) lets the producer stay 1-2 chunks ahead

    Cancellation:
        - cancel_event checked between chunks + every 25 autoregressive steps
        - Client disconnect triggers GeneratorExit → finally sets cancel
        - /tts/stop endpoint sets cancel externally
    """
    cancel_event = threading.Event()
    _active_streams[stream_id] = cancel_event
    # Raw audio buffer: producer puts numpy arrays, consumer takes them
    audio_buffer: stdlib_queue.Queue = stdlib_queue.Queue(maxsize=2)

    def _producer():
        """Heavyweight worker: runs the ONNX model continuously."""
        try:
            for audio_chunk in wrapper.stream_speech(
                text, voice,
                max_new_tokens=max_new_tokens,
                repetition_penalty=repetition_penalty,
                is_cancelled=cancel_event.is_set,
            ):
                if cancel_event.is_set():
                    break
                while not cancel_event.is_set():
                    try:
                        audio_buffer.put(audio_chunk, timeout=0.1)
                        break
                    except stdlib_queue.Full:
                        continue
        except GenerationCancelled:
            logger.info(f"[{stream_id}] Generation cancelled")
        except Exception as e:
            while not cancel_event.is_set():
                try:
                    audio_buffer.put(e, timeout=0.1)
                    break
                except stdlib_queue.Full:
                    continue
        finally:
            # Sentinel: None tells the consumer the producer has finished
            while not cancel_event.is_set():
                try:
                    audio_buffer.put(None, timeout=0.1)
                    break
                except stdlib_queue.Full:
                    continue

    producer = threading.Thread(target=_producer, daemon=True)
    producer.start()
    try:
        # Consumer: lightweight MP3 encoding + yield
        while True:
            try:
                # Timed get so an externally cancelled stream (e.g. via /tts/stop)
                # cannot leave this generator blocked forever on the queue.
                item = audio_buffer.get(timeout=0.5)
            except stdlib_queue.Empty:
                if cancel_event.is_set():
                    break
                continue
            if item is None:
                break
            if isinstance(item, Exception):
                logger.error(f"[{stream_id}] Stream error: {item}")
                break
            if cancel_event.is_set():
                break
            # MP3 encode (C code, releases GIL, runs parallel with next ONNX step)
            buf = io.BytesIO()
            sf.write(buf, item, Config.SAMPLE_RATE, format="mp3")
            yield buf.getvalue()
    finally:
        # Cleanup: signal producer to stop + deregister
        cancel_event.set()
        _active_streams.pop(stream_id, None)
def _parallel_odd_even_stream_generator(
    wrapper: ChatterboxWrapper,
    text: str,
    local_voice: VoiceProfile,
    helper_voice_bytes: Optional[bytes],
    max_new_tokens: int,
    repetition_penalty: float,
    stream_id: str,
    helper_base_url: str,
) -> Generator[bytes, None, None]:
    """Additive odd/even split streamer (primary handles odd, helper handles even)."""
    cancel_event = threading.Event()
    _active_streams[stream_id] = cancel_event
    clean_text = text_processor.sanitize(text.strip()[: Config.MAX_TEXT_LENGTH])
    chunks = text_processor.split_for_streaming(clean_text)
    total_chunks = len(chunks)
    if total_chunks == 0:
        _active_streams.pop(stream_id, None)
        return

    lock = threading.Lock()
    cond = threading.Condition(lock)
    ready: dict[int, bytes] = {}
    first_error: Optional[Exception] = None
    workers_done = 0

    def _publish(idx: int, data: bytes):
        with cond:
            ready[idx] = data
            cond.notify_all()

    def _set_error(err: Exception):
        nonlocal first_error
        with cond:
            if first_error is None:
                first_error = err
            cond.notify_all()

    def _worker_done():
        nonlocal workers_done
        with cond:
            workers_done += 1
            cond.notify_all()

    def _synth_local(chunk_text: str) -> bytes:
        audio = wrapper.generate_speech(
            chunk_text,
            local_voice,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
        )
        return _encode_mp3_chunk(audio)

    def _odd_worker():
        # Primary instance synthesizes chunks 0, 2, 4, ... locally.
        try:
            for idx in range(0, total_chunks, 2):
                if cancel_event.is_set():
                    break
                data = _synth_local(chunks[idx])
                _publish(idx, data)
        except Exception as e:
            _set_error(e)
        finally:
            _worker_done()

    def _even_worker():
        # Helper instance handles chunks 1, 3, 5, ...; falls back to local
        # synthesis if the helper becomes unavailable.
        helper_available = True
        helper_voice_key: Optional[str] = None
        try:
            if helper_voice_bytes:
                attempts = 2 if Config.HELPER_RETRY_ONCE else 1
                last_err: Optional[Exception] = None
                for _ in range(attempts):
                    try:
                        helper_voice_key = _helper_register_voice(
                            helper_base_url=helper_base_url,
                            stream_id=stream_id,
                            audio_bytes=helper_voice_bytes,
                            timeout_sec=max(1.0, Config.HELPER_TIMEOUT_SEC),
                        )
                        last_err = None
                        break
                    except Exception as reg_err:
                        last_err = reg_err
                        continue
                if last_err is not None:
                    helper_available = False
                    logger.warning(
                        f"[{stream_id}] Helper voice registration failed; "
                        "falling back to local synthesis for even chunks"
                    )
            for idx in range(1, total_chunks, 2):
                if cancel_event.is_set():
                    break
                if helper_available:
                    payload = {
                        "stream_id": stream_id,
                        "chunk_index": idx,
                        "text": chunks[idx],
                        "max_new_tokens": max_new_tokens,
                        "repetition_penalty": repetition_penalty,
                        "output_format": "mp3",
                    }
                    if helper_voice_key:
                        payload["voice_key"] = helper_voice_key
                    attempts = 2 if Config.HELPER_RETRY_ONCE else 1
                    last_err: Optional[Exception] = None
                    for _ in range(attempts):
                        try:
                            helper_data = _helper_request_chunk(
                                helper_base_url=helper_base_url,
                                payload=payload,
                                timeout_sec=max(1.0, Config.HELPER_TIMEOUT_SEC),
                            )
                            _publish(idx, helper_data)
                            last_err = None
                            break
                        except Exception as helper_err:
                            last_err = helper_err
                            continue
                    if last_err is None:
                        continue
                    helper_available = False
                    logger.warning(
                        f"[{stream_id}] Helper failed at chunk {idx}; "
                        "falling back to local synthesis for remaining even chunks"
                    )
                # Local fallback for even chunks
                data = _synth_local(chunks[idx])
                _publish(idx, data)
        except Exception as e:
            _set_error(e)
        finally:
            _worker_done()

    odd_thread = threading.Thread(target=_odd_worker, daemon=True)
    even_thread = threading.Thread(target=_even_worker, daemon=True)
    odd_thread.start()
    even_thread.start()

    next_idx = 0
    try:
        # Emit chunks strictly in index order so the MP3 stream stays coherent.
        while next_idx < total_chunks:
            with cond:
                while (
                    next_idx not in ready
                    and first_error is None
                    and not cancel_event.is_set()
                    and workers_done < 2
                ):
                    cond.wait(timeout=0.1)
                if cancel_event.is_set():
                    break
                if next_idx in ready:
                    data = ready.pop(next_idx)
                elif first_error is not None:
                    logger.error(f"[{stream_id}] Parallel stream error: {first_error}")
                    break
                elif workers_done >= 2:
                    logger.error(
                        f"[{stream_id}] Parallel stream ended with missing chunk index {next_idx}"
                    )
                    break
                else:
                    continue
            yield data
            next_idx += 1
    finally:
        cancel_event.set()
        _helper_cancel_stream(helper_base_url, stream_id)
        odd_thread.join(timeout=1.0)
        even_thread.join(timeout=1.0)
        _active_streams.pop(stream_id, None)
# ── POST /tts/stream & /tts/true-stream ─────────────────────────────
@app.post("/tts/stream")
@app.post("/tts/true-stream")  # alias kept for Kokoro-compatible clients
async def stream_text_to_speech(
    text: str = Form(...),
    voice_ref: Optional[UploadFile] = File(None),
    voice_name: str = Form("default"),
    max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
    repetition_penalty: float = Form(Config.REPETITION_PENALTY),
):
    """True real-time streaming: yields MP3 chunks as each sentence finishes.

    Response includes the X-Stream-Id header for cancellation via /tts/stop.
    Compatible with the frontend's MediaSource + ReadableStream pattern.
    """
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    if not text or not text.strip():
        raise HTTPException(400, "Text is required")
    voice = await _resolve_voice(voice_ref, voice_name, wrapper)
    stream_id = uuid.uuid4().hex[:12]
    return StreamingResponse(
        _pipeline_stream_generator(
            wrapper, text, voice, max_new_tokens, repetition_penalty, stream_id,
        ),
        media_type="audio/mpeg",
        headers={
            "Content-Disposition": "attachment; filename=tts_stream.mp3",
            "Transfer-Encoding": "chunked",
            "X-Stream-Id": stream_id,
            "X-Streaming-Type": "true-realtime",
            "Cache-Control": "no-cache",
        },
    )
# NOTE: route path assumed; this endpoint is not listed in the module docstring.
@app.post("/tts/parallel-stream")
async def parallel_stream_text_to_speech(
    text: str = Form(...),
    voice_ref: Optional[UploadFile] = File(None),
    voice_name: str = Form("default"),
    max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
    repetition_penalty: float = Form(Config.REPETITION_PENALTY),
    helper_url: Optional[str] = Form(None),
):
    """Additive odd/even split stream mode (primary + helper)."""
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    if not Config.ENABLE_PARALLEL_MODE:
        raise HTTPException(503, "Parallel mode is disabled")
    if not text or not text.strip():
        raise HTTPException(400, "Text is required")

    local_voice: VoiceProfile = wrapper.default_voice
    helper_voice_bytes: Optional[bytes] = None
    if voice_ref is not None and voice_ref.filename:
        helper_voice_bytes = await voice_ref.read()
        if len(helper_voice_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
            raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
        if len(helper_voice_bytes) == 0:
            raise HTTPException(status_code=400, detail="Empty voice file")
        loop = asyncio.get_running_loop()
        try:
            local_voice = await loop.run_in_executor(
                tts_executor, wrapper.encode_voice_from_bytes, helper_voice_bytes
            )
        except Exception as e:
            logger.error(f"Parallel voice encoding failed: {e}")
            raise HTTPException(400, "Could not process voice file for parallel mode")
    else:
        # Built-in voice selected by name → resolve locally and prepare
        # bytes for helper registration so helpers cache the same hash.
        try:
            selected_voice_id = wrapper.resolve_voice_id(voice_name)
            local_voice = wrapper.get_builtin_voice(selected_voice_id)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))
        # Only send bytes to the helper if a non-default voice was selected,
        # because the helper's own default is already loaded.
        if selected_voice_id != wrapper.default_voice_name:
            helper_voice_bytes = wrapper.get_builtin_voice_bytes(selected_voice_id)
            if not helper_voice_bytes:
                raise HTTPException(
                    status_code=400,
                    detail=f"Selected voice '{voice_name}' is unavailable for helper registration",
                )

    resolved_helper = (helper_url or Config.HELPER_BASE_URL or "").strip()
    if not resolved_helper:
        raise HTTPException(
            400,
            "Helper URL not configured. Set CB_HELPER_BASE_URL or pass helper_url.",
        )
    stream_id = uuid.uuid4().hex[:12]
    return StreamingResponse(
        _parallel_odd_even_stream_generator(
            wrapper=wrapper,
            text=text,
            local_voice=local_voice,
            helper_voice_bytes=helper_voice_bytes,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            stream_id=stream_id,
            helper_base_url=resolved_helper,
        ),
        media_type="audio/mpeg",
        headers={
            "Content-Disposition": "attachment; filename=tts_parallel_stream.mp3",
            "Transfer-Encoding": "chunked",
            "X-Stream-Id": stream_id,
            "X-Streaming-Type": "parallel-odd-even",
            "Cache-Control": "no-cache",
        },
    )
# ── JSON body variant (Kokoro/OpenAI compatibility) ─────────────────
from pydantic import BaseModel, Field


class InternalChunkRequest(BaseModel):
    stream_id: str = Field(..., min_length=1, max_length=64)
    chunk_index: int = Field(..., ge=0)
    text: str = Field(..., min_length=1, max_length=10000)
    max_new_tokens: int = Field(default=Config.MAX_NEW_TOKENS, ge=64, le=2048)
    repetition_penalty: float = Field(default=Config.REPETITION_PENALTY, ge=1.0, le=2.0)
    output_format: str = Field(default="mp3")
    voice_key: Optional[str] = Field(default=None, min_length=1, max_length=64)


class TTSJsonRequest(BaseModel):
    text: str = Field(..., min_length=1, max_length=50000)
    voice: str = Field(default="default")
    speed: float = Field(default=1.0, ge=0.5, le=2.0)  # reserved for future use
    max_new_tokens: int = Field(default=Config.MAX_NEW_TOKENS, ge=64, le=2048)
    repetition_penalty: float = Field(default=Config.REPETITION_PENALTY, ge=1.0, le=2.0)
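# Example /v1/audio/speech request body matching TTSJsonRequest (values are
# illustrative; `speed` is accepted but currently reserved for future use):
#
#   {
#       "text": "Hello from the OpenAI-compatible endpoint.",
#       "voice": "default",
#       "speed": 1.0,
#       "max_new_tokens": 1024,
#       "repetition_penalty": 1.2
#   }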
@app.post("/internal/voice/register")
async def internal_voice_register(http_request: Request):
    """Register voice once for a stream; returns reusable voice_key."""
    if Config.INTERNAL_SHARED_SECRET:
        provided = http_request.headers.get("X-Internal-Secret", "")
        if provided != Config.INTERNAL_SHARED_SECRET:
            raise HTTPException(403, "Forbidden")
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    audio_bytes = await http_request.body()
    if len(audio_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
    if len(audio_bytes) == 0:
        raise HTTPException(status_code=400, detail="Empty voice file")
    loop = asyncio.get_running_loop()
    try:
        voice = await loop.run_in_executor(
            tts_executor, wrapper.encode_voice_from_bytes, audio_bytes
        )
    except Exception as e:
        logger.error(f"[internal] voice register failed: {e}")
        raise HTTPException(400, "Voice registration failed")
    voice_key = (voice.audio_hash or "").strip()
    if not voice_key:
        raise HTTPException(500, "Voice key unavailable")
    stream_id = (http_request.query_params.get("stream_id") or "").strip()
    if stream_id:
        with _internal_cancel_lock:
            keys = _internal_stream_voice_keys.setdefault(stream_id, set())
            keys.add(voice_key)
    return {"status": "registered", "voice_key": voice_key}
@app.post("/internal/chunk/synthesize")
async def internal_chunk_synthesize(
    request: InternalChunkRequest,
    http_request: Request,
):
    """Internal endpoint used by primary/helper parallel routing."""
    if Config.INTERNAL_SHARED_SECRET:
        provided = http_request.headers.get("X-Internal-Secret", "")
        if provided != Config.INTERNAL_SHARED_SECRET:
            raise HTTPException(403, "Forbidden")
    with _internal_cancel_lock:
        if request.stream_id in _internal_cancelled_streams:
            raise HTTPException(409, "Stream already cancelled")
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    voice_profile = wrapper.default_voice
    if request.voice_key:
        cached_voice = wrapper._voice_cache.get(request.voice_key)
        if cached_voice is None:
            # Built-in voices stay in the wrapper registry even if the TTL cache entry expired.
            cached_voice = wrapper.get_builtin_voice_by_hash(request.voice_key)
        if cached_voice is None:
            raise HTTPException(409, "Voice key expired or not found")
        voice_profile = cached_voice
    loop = asyncio.get_running_loop()
    try:
        audio = await loop.run_in_executor(
            tts_executor,
            wrapper.generate_speech,
            request.text,
            voice_profile,
            request.max_new_tokens,
            request.repetition_penalty,
        )
    except Exception as e:
        logger.error(f"[internal] chunk {request.chunk_index} failed: {e}")
        raise HTTPException(500, "Chunk synthesis failed")
    fmt = (request.output_format or "mp3").lower()
    if fmt not in {"mp3", "wav", "flac"}:
        fmt = "mp3"
    data, media_type = _encode_audio(audio, fmt=fmt)
    return Response(
        content=data,
        media_type=media_type,
        headers={
            "X-Stream-Id": request.stream_id,
            "X-Chunk-Index": str(request.chunk_index),
        },
    )
@app.post("/internal/chunk/cancel/{stream_id}")
async def internal_chunk_cancel(stream_id: str, http_request: Request):
    if Config.INTERNAL_SHARED_SECRET:
        provided = http_request.headers.get("X-Internal-Secret", "")
        if provided != Config.INTERNAL_SHARED_SECRET:
            raise HTTPException(403, "Forbidden")
    with _internal_cancel_lock:
        _internal_cancelled_streams.add(stream_id)
        _internal_stream_voice_keys.pop(stream_id, None)
    return {"status": "cancelled", "stream_id": stream_id}
@app.post("/v1/audio/speech")
async def openai_compatible_tts(request: TTSJsonRequest):
    """OpenAI-compatible streaming endpoint (JSON body, no file upload).

    Uses built-in voice selection via `voice`. For voice cloning, use /tts/stream with FormData.
    """
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    try:
        selected_voice = wrapper.get_builtin_voice(request.voice)
    except ValueError as e:
        raise HTTPException(400, str(e))
    stream_id = uuid.uuid4().hex[:12]
    return StreamingResponse(
        _pipeline_stream_generator(
            wrapper, request.text, selected_voice,
            request.max_new_tokens, request.repetition_penalty, stream_id,
        ),
        media_type="audio/mpeg",
        headers={
            "Transfer-Encoding": "chunked",
            "X-Stream-Id": stream_id,
            "Cache-Control": "no-cache",
        },
    )
# ─────────────────────────────────────────────────────────────────────
# Stop / Cancel Endpoints
# ─────────────────────────────────────────────────────────────────────
@app.post("/tts/stop/{stream_id}")
async def stop_stream(stream_id: str):
    """Stop an active TTS stream by its ID (from the X-Stream-Id header).

    Cancels the ONNX generation loop mid-token, freeing CPU immediately.
    """
    event = _active_streams.get(stream_id)
    if event:
        event.set()
        logger.info(f"Stream {stream_id} cancelled by client")
        return {"status": "stopped", "stream_id": stream_id}
    return {"status": "not_found", "stream_id": stream_id}


@app.post("/tts/stop")
async def stop_all_streams():
    """Emergency stop: cancel ALL active TTS streams."""
    count = len(_active_streams)
    for sid, event in list(_active_streams.items()):
        event.set()
    _active_streams.clear()
    logger.info(f"Stopped all streams ({count} active)")
    return {"status": "stopped_all", "count": count}
# ─────────────────────────────────────────────────────────────────────
# Entrypoint
# ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=Config.HOST, port=Config.PORT)
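# Alternatively, run under the uvicorn CLI. The module name "app" and the port
# below are assumptions; Config.HOST / Config.PORT are what __main__ uses.
#
#   uvicorn app:app --host 0.0.0.0 --port 8000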