import asyncio
import logging
import os
import shlex
import signal
import subprocess
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
import websockets
from websockets.exceptions import ConnectionClosed
|
|
# Logging is configured once at import time; LOG_LEVEL overrides the INFO default.
_log_level_name = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=_log_level_name,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("s2s-endpoint")
|
|
def _env_text(name: str, default: str = "") -> str:
    """Return environment variable *name* (or *default*) without surrounding whitespace."""
    return os.getenv(name, default).strip()


def _env_flag(name: str, default: str) -> bool:
    """Return True when *name* (or *default*) is 1/true/yes, compared case-insensitively."""
    return _env_text(name, default).lower() in {"1", "true", "yes"}


# Port setting read from the environment.
# NOTE(review): presumably the port the FastAPI app is served on -- confirm against the launcher.
PORT = int(os.getenv("PORT", "7860"))

# Address of the websocket server exposed by the speech-to-speech subprocess.
INTERNAL_WS_HOST = os.getenv("INTERNAL_WS_HOST", "127.0.0.1")
INTERNAL_WS_PORT = int(os.getenv("INTERNAL_WS_PORT", "9000"))
INTERNAL_WS_URL = f"ws://{INTERNAL_WS_HOST}:{INTERNAL_WS_PORT}"

# Directory used both as `uv run --directory` target and as the subprocess cwd.
S2S_REPO_DIR = os.getenv("S2S_REPO_DIR", "/opt/speech-to-speech")

# General pipeline options, forwarded verbatim as command-line values.
DEVICE = _env_text("DEVICE", "cuda")
LANGUAGE = _env_text("LANGUAGE", "en")
CHAT_SIZE = _env_text("CHAT_SIZE", "10")

# Component selection for the three pipeline stages.
STT = _env_text("STT", "parakeet-tdt")
LLM = _env_text("LLM", "open_api")
TTS = _env_text("TTS", "qwen3")

# Live transcription: the flag is on by default; the interval is only passed when set.
ENABLE_LIVE_TRANSCRIPTION = _env_flag("ENABLE_LIVE_TRANSCRIPTION", "1")
LIVE_TRANSCRIPTION_UPDATE_INTERVAL = _env_text("LIVE_TRANSCRIPTION_UPDATE_INTERVAL")

# Only forwarded when STT is "whisper" or "faster-whisper" (see build_s2s_command).
STT_COMPILE_MODE = _env_text("STT_COMPILE_MODE")

# Options for the "open_api" LLM backend. The API key falls back to HF_TOKEN
# when OPEN_API_API_KEY is unset or blank.
OPEN_API_MODEL_NAME = _env_text("OPEN_API_MODEL_NAME", "Qwen/Qwen3.5-9B:together")
OPEN_API_BASE_URL = _env_text("OPEN_API_BASE_URL", "https://router.huggingface.co/v1")
OPEN_API_API_KEY = _env_text("OPEN_API_API_KEY") or _env_text("HF_TOKEN")
OPEN_API_STREAM = _env_flag("OPEN_API_STREAM", "1")
OPEN_API_INIT_CHAT_PROMPT = _env_text("OPEN_API_INIT_CHAT_PROMPT")
OPEN_API_IMAGE_PATHS = _env_text("OPEN_API_IMAGE_PATHS")

# Free-form extra arguments appended to the end of the pipeline command.
EXTRA_S2S_ARGS = _env_text("EXTRA_S2S_ARGS")

# Handle of the running speech-to-speech subprocess; None until start_pipeline() runs.
pipeline_process: Optional[subprocess.Popen] = None
|
|
|
|
def _add_bool_flag(cmd: list[str], enabled: bool, flag: str) -> None:
    """Append the value-less switch *flag* to *cmd* in place when *enabled* is truthy."""
    if not enabled:
        return
    cmd.append(flag)
|
|
|
|
def _add_str_flag(cmd: list[str], value: str, flag: str) -> None:
    """Append *flag* followed by *value* to *cmd* in place; do nothing for an empty value."""
    if not value:
        return
    cmd.extend((flag, value))
|
|
|
|
def build_s2s_command() -> list[str]:
    """Build the argv list that launches the speech-to-speech pipeline.

    The command runs ``s2s_pipeline.py`` through ``uv run`` in websocket mode
    and forwards the module-level settings as command-line flags. Optional
    settings are only added when they are non-empty / enabled.

    Returns:
        The full argument list, suitable for ``subprocess.Popen``.
    """
    cmd = [
        "uv",
        "run",
        "--directory",
        S2S_REPO_DIR,
        "python",
        "s2s_pipeline.py",
        "--mode",
        "websocket",
        "--ws_host",
        INTERNAL_WS_HOST,
        "--ws_port",
        str(INTERNAL_WS_PORT),
        "--device",
        DEVICE,
        "--language",
        LANGUAGE,
        "--chat_size",
        CHAT_SIZE,
        "--stt",
        STT,
        "--llm",
        LLM,
        "--tts",
        TTS,
    ]

    # Live transcription switches.
    _add_bool_flag(cmd, ENABLE_LIVE_TRANSCRIPTION, "--enable_live_transcription")
    _add_str_flag(cmd, LIVE_TRANSCRIPTION_UPDATE_INTERVAL, "--live_transcription_update_interval")

    # The compile mode is only forwarded for the whisper-family STT backends.
    if STT_COMPILE_MODE and STT in {"whisper", "faster-whisper"}:
        cmd.extend(["--stt_compile_mode", STT_COMPILE_MODE])

    # Flags specific to the "open_api" LLM backend.
    if LLM == "open_api":
        _add_str_flag(cmd, OPEN_API_MODEL_NAME, "--open_api_model_name")
        _add_str_flag(cmd, OPEN_API_BASE_URL, "--open_api_base_url")
        # NOTE(review): the key travels on the command line and is therefore
        # visible in the process list; an env-based hand-off would be safer
        # if the pipeline supports one.
        _add_str_flag(cmd, OPEN_API_API_KEY, "--open_api_api_key")
        _add_bool_flag(cmd, OPEN_API_STREAM, "--open_api_stream")
        _add_str_flag(cmd, OPEN_API_INIT_CHAT_PROMPT, "--open_api_init_chat_prompt")
        # NOTE(review): grouped with the other open_api flags; the original
        # nesting of this line could not be determined -- confirm.
        _add_str_flag(cmd, OPEN_API_IMAGE_PATHS, "--open_api_image_paths")

    if EXTRA_S2S_ARGS:
        # Shell-style splitting keeps quoted values (e.g. a prompt containing
        # spaces) together as a single argument.
        try:
            extra_args = shlex.split(EXTRA_S2S_ARGS)
        except ValueError:
            # Unbalanced quotes: fall back to plain whitespace splitting
            # rather than refusing to start.
            logger.warning(
                "EXTRA_S2S_ARGS could not be parsed shell-style; splitting on whitespace"
            )
            extra_args = EXTRA_S2S_ARGS.split()
        cmd.extend(extra_args)

    return cmd
|
|
|
|
async def wait_for_internal_ws(timeout_s: float = 900.0) -> None:
    """Poll the internal websocket until a connection succeeds.

    A probe connection is attempted every two seconds. At least one attempt
    is always made, even when *timeout_s* is zero.

    Args:
        timeout_s: Maximum time in seconds to keep retrying.

    Raises:
        RuntimeError: If the pipeline subprocess has exited, or if no
            connection succeeded within *timeout_s*. In the timeout case the
            last connection error is chained as the cause.
    """
    # get_running_loop() is the preferred call inside a coroutine.
    loop = asyncio.get_running_loop()
    started_at = loop.time()
    last_error: Optional[Exception] = None

    while True:
        # Stop early if the subprocess died; retrying would never succeed.
        if pipeline_process is not None and pipeline_process.poll() is not None:
            raise RuntimeError(
                f"speech-to-speech process exited early with code {pipeline_process.returncode}"
            )

        try:
            # The connection is opened only as a readiness probe and closed right away.
            async with websockets.connect(
                INTERNAL_WS_URL,
                open_timeout=5,
                ping_interval=None,
                max_size=None,
            ):
                logger.info("Internal speech-to-speech websocket is ready at %s", INTERNAL_WS_URL)
                return
        except Exception as exc:
            # Any failure means "not ready yet"; remember it for the timeout message.
            last_error = exc

        if loop.time() - started_at > timeout_s:
            raise RuntimeError(
                f"Timed out waiting for internal websocket server at {INTERNAL_WS_URL}. "
                f"Last error: {last_error}"
            ) from last_error

        await asyncio.sleep(2.0)
|
|
|
|
# Flag-name endings whose values must never be written to the log.
_SECRET_FLAG_SUFFIXES = ("api_key", "token", "secret", "password")


def _redact_secrets(cmd: list[str]) -> list[str]:
    """Return a copy of *cmd* with the values of secret-bearing ``--flags`` masked."""
    redacted: list[str] = []
    mask_next = False
    for arg in cmd:
        if mask_next:
            redacted.append("***")
            mask_next = False
            continue
        name, sep, _value = arg.partition("=")
        is_secret_flag = name.startswith("--") and name.lower().endswith(_SECRET_FLAG_SUFFIXES)
        if is_secret_flag and sep:
            # "--flag=value" form: keep the flag name, hide the value.
            redacted.append(f"{name}=***")
        else:
            redacted.append(arg)
            # "--flag value" form: the next argument is the secret.
            mask_next = is_secret_flag
    return redacted


def start_pipeline() -> None:
    """Launch the speech-to-speech subprocess unless one is already running.

    The process handle is stored in the module-level ``pipeline_process``.
    Only the logged copy of the command is redacted; the subprocess receives
    the real arguments.
    """
    global pipeline_process

    if pipeline_process is not None and pipeline_process.poll() is None:
        logger.info("speech-to-speech process already running")
        return

    cmd = build_s2s_command()
    # Never log API keys or tokens; shlex.join keeps the line copy-pasteable.
    logger.info("Starting speech-to-speech subprocess:\n%s", shlex.join(_redact_secrets(cmd)))

    env = os.environ.copy()

    pipeline_process = subprocess.Popen(
        cmd,
        cwd=S2S_REPO_DIR,
        env=env,
        # Inherit stdout/stderr so the pipeline's output goes to this process's streams.
        stdout=None,
        stderr=None,
        # Put the child in its own session/process group so stop_pipeline()
        # can signal the whole group. start_new_session performs the same
        # setsid() call as the former preexec_fn without its thread-safety caveats.
        start_new_session=os.name != "nt",
    )
|
|
|
|
def stop_pipeline() -> None:
    """Stop the speech-to-speech subprocess, escalating to a hard kill if needed.

    A termination request is sent first (SIGTERM to the whole process group
    on POSIX, ``terminate()`` on Windows) and the process is given 20 seconds
    to exit. If that fails it is killed and then reaped. The module-level
    ``pipeline_process`` is reset to None once a stop has been attempted. A
    process that had already exited is left in place, so its exit code stays
    available.
    """
    global pipeline_process

    proc = pipeline_process
    if proc is None:
        return

    if proc.poll() is not None:
        logger.info("speech-to-speech process already stopped")
        return

    logger.info("Stopping speech-to-speech subprocess")

    try:
        if os.name != "nt":
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
        else:
            proc.terminate()
        proc.wait(timeout=20)
    except Exception:
        logger.exception("Graceful shutdown failed, killing subprocess")
        try:
            if os.name != "nt":
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
            else:
                proc.kill()
            # Reap the killed child so it does not linger as a zombie.
            proc.wait(timeout=5)
        except Exception:
            logger.exception("Failed to kill subprocess")
    finally:
        pipeline_process = None
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the speech-to-speech subprocess for the lifetime of the application.

    The subprocess is started before the application begins serving and is
    stopped when the application shuts down, including on error.
    """
    start_pipeline()
    try:
        # Control returns here only when the application is shutting down.
        yield
    finally:
        stop_pipeline()


# The FastAPI application; its lifespan hook owns the pipeline subprocess.
app = FastAPI(lifespan=lifespan)
|
|
|
|
@app.get("/")
async def root():
    """Return a small index document describing the service and its active configuration."""
    active_config = {
        "stt": STT,
        "llm": LLM,
        "tts": TTS,
        "device": DEVICE,
        "language": LANGUAGE,
    }
    index = {
        "message": "s2s endpoint is up",
        "health": "/health",
        "websocket": "/ws",
        "internal_ws": INTERNAL_WS_URL,
        "config": active_config,
    }
    return index
|
|
|
|
@app.get("/health")
async def health():
    """Report whether the pipeline subprocess is alive and its websocket is reachable.

    Returns:
        A JSON body with ``status: ok`` and the active component names.

    Raises:
        HTTPException: 503 when the subprocess was never started, has exited,
            or its internal websocket did not accept a connection in time.
    """
    if pipeline_process is None:
        raise HTTPException(status_code=503, detail="speech-to-speech process not started")

    if pipeline_process.poll() is not None:
        raise HTTPException(
            status_code=503,
            detail=f"speech-to-speech process exited with code {pipeline_process.returncode}",
        )

    try:
        # The outer wait_for bounds the whole probe, including a hung connect.
        await asyncio.wait_for(wait_for_internal_ws(timeout_s=5), timeout=6)
    except Exception as exc:
        # A timeout raised by wait_for has an empty message; fall back to the
        # exception class name so the detail is never blank.
        reason = str(exc) or type(exc).__name__
        raise HTTPException(
            status_code=503,
            detail=f"internal websocket not ready: {reason}",
        ) from exc

    return JSONResponse(
        {
            "status": "ok",
            "internal_ws": INTERNAL_WS_URL,
            "stt": STT,
            "llm": LLM,
            "tts": TTS,
        }
    )
|
|
|
|
@app.websocket("/ws")
async def websocket_proxy(client_ws: WebSocket):
    """Relay frames in both directions between a client and the internal pipeline websocket.

    Binary frames are forwarded as binary and text frames as text. The two
    relay directions run as separate tasks; as soon as one of them ends, the
    other is cancelled so that no task outlives the handler.
    """
    await client_ws.accept()
    logger.info("Client websocket connected")

    try:
        async with websockets.connect(
            INTERNAL_WS_URL,
            open_timeout=30,
            ping_interval=20,
            ping_timeout=20,
            max_size=None,
        ) as upstream_ws:

            async def client_to_upstream() -> None:
                """Forward every client frame to the pipeline until the client disconnects."""
                while True:
                    message = await client_ws.receive()

                    if message["type"] == "websocket.disconnect":
                        raise WebSocketDisconnect()

                    # Bytes take precedence over text, as a message carries one or the other.
                    payload = message.get("bytes")
                    if payload is None:
                        payload = message.get("text")
                    if payload is not None:
                        await upstream_ws.send(payload)

            async def upstream_to_client() -> None:
                """Forward every pipeline frame to the client, preserving the frame type."""
                while True:
                    frame = await upstream_ws.recv()
                    if isinstance(frame, bytes):
                        await client_ws.send_bytes(frame)
                    else:
                        await client_ws.send_text(frame)

            pumps = [
                asyncio.create_task(client_to_upstream()),
                asyncio.create_task(upstream_to_client()),
            ]
            try:
                await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
            finally:
                # gather() alone would leave the surviving direction running
                # after the first one fails; cancel it and wait for it to finish.
                for pump in pumps:
                    pump.cancel()
                await asyncio.gather(*pumps, return_exceptions=True)

            # Re-raise whatever ended the relay so the handlers below can react.
            for pump in pumps:
                if not pump.cancelled():
                    pump.result()

    except WebSocketDisconnect:
        logger.info("Client websocket disconnected")
    except ConnectionClosed:
        logger.info("Upstream websocket disconnected")
        try:
            await client_ws.close()
        except Exception:
            # Best effort: the client may already be gone.
            pass
    except Exception:
        logger.exception("Websocket proxy failed")
        try:
            await client_ws.close(code=1011, reason="Proxy failure")
        except Exception:
            # Best effort: the client may already be gone.
            pass
|
|