from __future__ import annotations
import asyncio
import json
import os
import uuid
from typing import Any, Dict, Optional
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from fastrtc import Stream, ReplyOnPause, get_stt_model, get_tts_model
from .gemini_text import (
gemini_chat_turn,
get_session,
deliver_function_result,
)
# Single FastAPI application; FastRTC mounts its own routes onto it below.
app = FastAPI()
# ----------------------------
# FastRTC Voice Chat (VAD + STT + TTS)
# ----------------------------
# These are CPU-friendly, but still heavy on Spaces. Keep them global.
# Model choices are overridable via environment variables; the defaults are
# the lightest options so the process can boot on constrained hardware.
STT_MODEL_NAME = os.getenv("FASTRTC_STT_MODEL", "moonshine/tiny")
TTS_MODEL_NAME = os.getenv("FASTRTC_TTS_MODEL", "kokoro")
# Loaded once at import time and shared by every voice session in this process.
stt = get_stt_model(model=STT_MODEL_NAME)
tts = get_tts_model(model=TTS_MODEL_NAME)
def _voice_reply_fn(audio: tuple[int, np.ndarray]):
    """
    FastRTC ReplyOnPause handler: runs each time VAD detects the user paused.

    Transcribes the captured utterance (STT), runs a single Gemini chat turn,
    and yields synthesized reply audio frames (TTS) back to FastRTC.

    Args:
        audio: ``(sample_rate, samples)`` tuple as delivered by FastRTC
            (int16 mono ndarray per the FastRTC docs examples).

    Yields:
        Audio frames compatible with FastRTC streaming. Yields nothing when
        the transcription is empty.
    """
    user_text = stt.stt(audio).strip()
    if not user_text:
        # Nothing intelligible was transcribed; produce no reply audio.
        return
    # For voice sessions we create a synthetic session_id (not Scratch ws session)
    # because FastRTC's ReplyOnPause fn signature doesn't expose the RTC session id.
    # This keeps a stable conversation state per-process, but not per-user.
    #
    # If you need per-user memory for voice, switch to a stateful StreamHandler.
    voice_session_id = "voice-global"

    async def run() -> str:
        # No tool bounce for voice by default (still supported via the same
        # session registry if you want).
        async def noop_emit(_evt: dict):
            return

        return await gemini_chat_turn(
            session_id=voice_session_id,
            user_text=user_text,
            emit_event=noop_emit,
            model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"),
        )

    # BUG FIX: asyncio.get_event_loop().run_until_complete() breaks when this
    # handler is invoked from a worker thread (no event loop is set outside
    # the main thread, and get_event_loop() without a running loop is
    # deprecated since 3.10). asyncio.run() creates a fresh loop, runs the
    # coroutine, and tears the loop down cleanly.
    text = asyncio.run(run())
    # Stream TTS back; each chunk is already an audio frame FastRTC accepts.
    for chunk in tts.stream_tts_sync(text):
        yield chunk
# Bidirectional audio stream: FastRTC handles WebRTC negotiation and frame
# transport; our handler supplies the reply audio after each user pause.
voice_stream = Stream(
    modality="audio",
    mode="send-receive",
    # ReplyOnPause wraps the callback so it fires only once VAD detects silence.
    handler=ReplyOnPause(_voice_reply_fn),
)
# Mount FastRTC endpoints (WebRTC + WebSocket) under /rtc
voice_stream.mount(app, path="/rtc")
# ----------------------------
# Scratch-friendly WebSocket API (text + function calling)
# ----------------------------
@app.get("/")
async def root():
return JSONResponse(
{
"ok": True,
"service": "salexai-api",
"ws": "/ws",
"fastrtc": "/rtc",
"notes": [
"Use /ws for Scratch JSON chat + function calling.",
"Use /rtc for FastRTC voice chat endpoints (VAD/STT/TTS handled by FastRTC).",
],
}
)
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
await ws.accept()
session_id: Optional[str] = None
async def emit(evt: dict):
await ws.send_text(json.dumps(evt))
try:
while True:
raw = await ws.receive_text()
msg = json.loads(raw) if raw else {}
mtype = msg.get("type")
if mtype == "connect":
session_id = msg.get("session_id") or str(uuid.uuid4())
get_session(session_id) # ensure exists
await emit({"type": "ready", "session_id": session_id})
continue
if not session_id:
await emit({"type": "error", "message": "Not connected. Send {type:'connect'} first."})
continue
# -------- function registry --------
if mtype == "add_function":
name = str(msg.get("name") or "").strip()
schema = msg.get("schema") or {}
if not name:
await emit({"type": "error", "message": "add_function missing name"})
continue
s = get_session(session_id)
s.functions[name] = schema
await emit({"type": "function_added", "name": name})
continue
if mtype == "remove_function":
name = str(msg.get("name") or "").strip()
s = get_session(session_id)
if name in s.functions:
s.functions.pop(name, None)
await emit({"type": "function_removed", "name": name})
else:
await emit({"type": "warning", "message": f"Function not found: {name}"})
continue
if mtype == "list_functions":
s = get_session(session_id)
await emit({"type": "functions", "items": list(s.functions.keys())})
continue
# Client returns tool results
if mtype == "function_result":
call_id = msg.get("call_id")
result = msg.get("result")
if not call_id:
await emit({"type": "error", "message": "function_result missing call_id"})
continue
ok = deliver_function_result(session_id, call_id, result)
if not ok:
await emit({"type": "warning", "message": f"No pending call_id: {call_id}"})
else:
await emit({"type": "function_result_ack", "call_id": call_id})
continue
# -------- chat --------
if mtype == "send":
text = str(msg.get("text") or "")
if not text.strip():
await emit({"type": "error", "message": "Empty text"})
continue
try:
assistant_text = await gemini_chat_turn(
session_id=session_id,
user_text=text,
emit_event=emit, # this is where tool calls get emitted
model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"),
)
await emit({"type": "assistant", "text": assistant_text})
except Exception as e:
await emit({"type": "error", "message": f"Gemini error: {e}"})
continue
await emit({"type": "error", "message": f"Unknown type: {mtype}"})
except WebSocketDisconnect:
return
except Exception as e:
try:
await emit({"type": "error", "message": f"WS crashed: {e}"})
except Exception:
pass