| from __future__ import annotations |
|
|
| import asyncio |
| import json |
| import os |
| import uuid |
| from typing import Any, Dict, Optional |
|
|
| import numpy as np |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
| from fastapi.responses import JSONResponse |
|
|
| from fastrtc import Stream, ReplyOnPause, get_stt_model, get_tts_model |
|
|
| from .gemini_text import ( |
| gemini_chat_turn, |
| get_session, |
| deliver_function_result, |
| ) |
|
|
| app = FastAPI() |
|
|
|
|
| |
| |
| |
|
|
| |
# Model names can be overridden through the environment; the fallbacks are
# "moonshine/tiny" for speech-to-text and "kokoro" for text-to-speech.
STT_MODEL_NAME = os.environ.get("FASTRTC_STT_MODEL", "moonshine/tiny")
TTS_MODEL_NAME = os.environ.get("FASTRTC_TTS_MODEL", "kokoro")

# Both models are created once at import time and used by the voice handler.
stt = get_stt_model(model=STT_MODEL_NAME)
tts = get_tts_model(model=TTS_MODEL_NAME)
|
|
|
|
def _voice_reply_fn(audio: tuple[int, np.ndarray]):
    """
    Turn one spoken utterance into a streamed spoken reply.

    Called when the user pauses (VAD). The audio is transcribed with the
    module-level ``stt`` model, the transcript is run through
    ``gemini_chat_turn`` on the shared ``"voice-global"`` session, and the
    resulting text is synthesized with the module-level ``tts`` model.

    Args:
        audio: The captured utterance, handed unchanged to ``stt.stt``.
            NOTE(review): presumably ``(sample_rate, samples)`` - confirm
            against the FastRTC handler contract.

    Yields:
        The chunks produced by ``tts.stream_tts_sync`` for the reply text.
        Nothing is yielded when the transcript or the reply text is empty.
    """
    user_text = stt.stt(audio).strip()
    if not user_text:
        return

    # Every voice caller shares this one conversation; no per-caller session
    # identifier is available in this handler.
    voice_session_id = "voice-global"

    async def _discard_event(_evt: dict) -> None:
        # Voice mode has nowhere to forward chat events, so they are dropped.
        return

    async def _chat_turn():
        return await gemini_chat_turn(
            session_id=voice_session_id,
            user_text=user_text,
            emit_event=_discard_event,
            model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"),
        )

    # asyncio.run() creates (and afterwards closes) a fresh event loop, so this
    # works from a thread that has no loop set. The previous
    # asyncio.get_event_loop().run_until_complete(...) raises RuntimeError in
    # such a thread.
    # NOTE(review): FastRTC appears to run sync handlers off the main thread -
    # confirm. Also verify that the "voice-global" session does not keep
    # loop-bound asyncio objects alive between turns.
    reply_text = asyncio.run(_chat_turn())

    # Do not hand an empty/None reply to the TTS engine.
    if not reply_text:
        return

    yield from tts.stream_tts_sync(reply_text)
|
|
|
|
# Bidirectional audio stream: ReplyOnPause wraps _voice_reply_fn, which is
# documented as being called when the user pauses.
voice_stream = Stream(
    handler=ReplyOnPause(_voice_reply_fn),
    mode="send-receive",
    modality="audio",
)

# Expose the stream's endpoints on the FastAPI app under /rtc.
voice_stream.mount(app, path="/rtc")
|
|
|
|
| |
| |
| |
|
|
@app.get("/")
async def root():
    """Return a small JSON document naming the service and its endpoints."""
    usage_notes = [
        "Use /ws for Scratch JSON chat + function calling.",
        "Use /rtc for FastRTC voice chat endpoints (VAD/STT/TTS handled by FastRTC).",
    ]
    payload = {
        "ok": True,
        "service": "salexai-api",
        "ws": "/ws",
        "fastrtc": "/rtc",
        "notes": usage_notes,
    }
    return JSONResponse(payload)
|
|
|
|
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """
    Serve the JSON chat protocol over one WebSocket connection.

    Each client message is a JSON object with a ``type`` field:

    * ``connect`` - bind the connection to ``session_id`` (a new UUID4 string
      is generated when none is supplied) and reply with ``ready``.
    * ``add_function`` / ``remove_function`` / ``list_functions`` - edit or
      list the session's ``functions`` mapping.
    * ``function_result`` - pass ``result`` for ``call_id`` to
      ``deliver_function_result`` and acknowledge (or warn if it returns a
      falsy value).
    * ``send`` - run ``gemini_chat_turn`` on ``text`` and reply with
      ``assistant``, or with ``error`` if the turn raises.

    Any other type, and any non-``connect`` message sent before ``connect``,
    is answered with an ``error`` event. Malformed JSON and payloads that are
    not JSON objects are also answered with ``error``; the connection stays
    open in both cases.

    Chat turns run as background tasks so the receive loop keeps reading
    while a turn is in flight. Turns started on this connection run one at a
    time, in arrival order, and are cancelled when the connection ends.
    """
    await ws.accept()

    session_id: Optional[str] = None

    # The receive loop and background turn tasks both write to the socket, so
    # writes are serialized.
    send_lock = asyncio.Lock()
    # Only one chat turn at a time per connection; waiters are served FIFO.
    turn_lock = asyncio.Lock()
    # Strong references to in-flight turn tasks (also used for cleanup).
    turn_tasks: set = set()

    async def emit(evt: dict):
        async with send_lock:
            await ws.send_text(json.dumps(evt))

    async def run_turn(sid: str, text: str):
        """Run one chat turn for ``sid`` and emit its outcome."""
        async with turn_lock:
            try:
                assistant_text = await gemini_chat_turn(
                    session_id=sid,
                    user_text=text,
                    emit_event=emit,
                    model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"),
                )
                await emit({"type": "assistant", "text": assistant_text})
            except Exception as e:
                try:
                    await emit({"type": "error", "message": f"Gemini error: {e}"})
                except Exception:
                    # Best effort: the socket is most likely closed already.
                    pass

    try:
        while True:
            raw = await ws.receive_text()

            # A bad payload is the client's mistake, not a reason to drop the
            # connection: report it and keep reading.
            try:
                msg = json.loads(raw) if raw else {}
            except json.JSONDecodeError as e:
                await emit({"type": "error", "message": f"Invalid JSON: {e}"})
                continue
            if not isinstance(msg, dict):
                await emit({"type": "error", "message": "Message must be a JSON object"})
                continue

            mtype = msg.get("type")

            if mtype == "connect":
                session_id = msg.get("session_id") or str(uuid.uuid4())
                get_session(session_id)
                await emit({"type": "ready", "session_id": session_id})
                continue

            if not session_id:
                await emit({"type": "error", "message": "Not connected. Send {type:'connect'} first."})
                continue

            if mtype == "add_function":
                name = str(msg.get("name") or "").strip()
                schema = msg.get("schema") or {}
                if not name:
                    await emit({"type": "error", "message": "add_function missing name"})
                    continue
                s = get_session(session_id)
                s.functions[name] = schema
                await emit({"type": "function_added", "name": name})
                continue

            if mtype == "remove_function":
                name = str(msg.get("name") or "").strip()
                s = get_session(session_id)
                if name in s.functions:
                    s.functions.pop(name, None)
                    await emit({"type": "function_removed", "name": name})
                else:
                    await emit({"type": "warning", "message": f"Function not found: {name}"})
                continue

            if mtype == "list_functions":
                s = get_session(session_id)
                await emit({"type": "functions", "items": list(s.functions.keys())})
                continue

            if mtype == "function_result":
                call_id = msg.get("call_id")
                result = msg.get("result")
                if not call_id:
                    await emit({"type": "error", "message": "function_result missing call_id"})
                    continue
                ok = deliver_function_result(session_id, call_id, result)
                if not ok:
                    await emit({"type": "warning", "message": f"No pending call_id: {call_id}"})
                else:
                    await emit({"type": "function_result_ack", "call_id": call_id})
                continue

            if mtype == "send":
                text = str(msg.get("text") or "")
                if not text.strip():
                    await emit({"type": "error", "message": "Empty text"})
                    continue

                # Run the turn in the background instead of awaiting it here.
                # NOTE(review): presumably gemini_chat_turn waits for results
                # handed over via deliver_function_result - confirm. If so,
                # awaiting it inline would stop this loop from ever reading
                # the client's function_result message.
                task = asyncio.create_task(run_turn(session_id, text))
                turn_tasks.add(task)
                task.add_done_callback(turn_tasks.discard)
                continue

            await emit({"type": "error", "message": f"Unknown type: {mtype}"})

    except WebSocketDisconnect:
        return
    except Exception as e:
        try:
            await emit({"type": "error", "message": f"WS crashed: {e}"})
        except Exception:
            # Best effort: nothing more can be done if the socket is gone.
            pass
    finally:
        # Do not leave chat turns running for a connection that has ended.
        pending = list(turn_tasks)
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
|
|