SalexAI committed on
Commit
fc77df5
·
verified ·
1 Parent(s): ed9acde

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +145 -106
app/main.py CHANGED
@@ -1,161 +1,200 @@
 
 
1
  import asyncio
2
- import base64
3
  import json
 
4
  import uuid
5
- from typing import Optional, Literal
6
 
7
  import numpy as np
8
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
9
  from fastapi.responses import JSONResponse
10
 
11
- from fastrtc import Stream, AsyncStreamHandler, wait_for_item, AdditionalOutputs
 
 
 
 
 
 
12
 
13
  app = FastAPI()
14
 
15
 
16
- # ---------------------------
17
- # A tiny headless audio handler (echo) to validate the pipe.
18
- # Swap this out later for Gemini / other realtime models.
19
- # ---------------------------
20
- class EchoHandler(AsyncStreamHandler):
21
- def __init__(self, expected_layout: Literal["mono"] = "mono", output_sample_rate: int = 24000):
22
- super().__init__(expected_layout=expected_layout, output_sample_rate=output_sample_rate, input_sample_rate=16000)
23
- self.out_q: asyncio.Queue[tuple[int, np.ndarray] | AdditionalOutputs] = asyncio.Queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- def copy(self):
26
- return EchoHandler()
 
 
27
 
28
- async def receive(self, frame: tuple[int, np.ndarray]) -> None:
29
- sr, audio = frame
30
- audio = np.asarray(audio)
31
- if audio.ndim == 2:
32
- audio = audio.squeeze()
33
- if audio.dtype != np.int16:
34
- audio = audio.astype(np.int16)
35
 
36
- # Echo straight back
37
- self.out_q.put_nowait((sr, audio.reshape(1, -1)))
38
 
39
- async def emit(self):
40
- return await wait_for_item(self.out_q)
 
 
41
 
42
 
43
- # IMPORTANT: no additional_inputs here (strings crash in 0.0.34)
44
- stream = Stream(
45
- handler=EchoHandler(),
46
  modality="audio",
47
  mode="send-receive",
 
48
  )
49
 
50
- # This mounts FastRTC’s internal routes; we’re also adding /ws below.
51
- stream.mount(app)
52
-
53
 
54
- # ---------------------------
55
- # Helpers
56
- # ---------------------------
57
- def b64_to_int16(b64: str) -> np.ndarray:
58
- raw = base64.b64decode(b64)
59
- return np.frombuffer(raw, dtype=np.int16)
60
-
61
- def int16_to_b64(audio: np.ndarray) -> str:
62
- if audio.dtype != np.int16:
63
- audio = audio.astype(np.int16)
64
- return base64.b64encode(audio.tobytes()).decode("utf-8")
65
 
 
 
 
66
 
67
  @app.get("/")
68
  async def root():
69
- return {"ok": True, "message": "FastRTC mounted. Headless mode. Use /ws for Scratch."}
70
-
71
- @app.get("/health")
72
- async def health():
73
- return {"ok": True}
 
 
 
 
 
 
 
74
 
75
 
76
- # ---------------------------
77
- # Scratch-friendly WS bridge (no WebRTC needed client-side)
78
- # ---------------------------
79
  @app.websocket("/ws")
80
- async def ws_bridge(ws: WebSocket):
81
  await ws.accept()
82
 
83
  session_id: Optional[str] = None
84
- pump_task: Optional[asyncio.Task] = None
85
 
86
- async def pump_outputs():
87
- """
88
- Pull audio outputs from FastRTC and forward to client.
89
- NOTE: FastRTC 0.0.34 uses fetch_output() polling style.
90
- """
91
- try:
92
- while True:
93
- out = await stream.fetch_output(session_id)
94
- if out is None:
95
- await asyncio.sleep(0.01)
96
- continue
97
-
98
- if isinstance(out, AdditionalOutputs):
99
- payload = out.args[0] if out.args else None
100
- await ws.send_text(json.dumps({"type": "output", "data": payload}))
101
- continue
102
-
103
- sr, audio = out
104
- audio = np.asarray(audio)
105
- if audio.ndim == 2:
106
- audio = audio.squeeze()
107
- if audio.dtype != np.int16:
108
- audio = audio.astype(np.int16)
109
-
110
- await ws.send_text(json.dumps({
111
- "type": "audio_delta",
112
- "rate": int(sr),
113
- "data": int16_to_b64(audio),
114
- }))
115
- except Exception:
116
- return
117
 
118
  try:
119
  while True:
120
  raw = await ws.receive_text()
121
- msg = json.loads(raw)
122
- t = msg.get("type")
 
123
 
124
- if t == "connect":
125
  session_id = msg.get("session_id") or str(uuid.uuid4())
 
 
 
 
 
 
 
126
 
127
- # start pump once
128
- if pump_task is None:
129
- pump_task = asyncio.create_task(pump_outputs())
130
 
131
- await ws.send_text(json.dumps({"type": "ready", "session_id": session_id}))
 
 
 
 
 
 
 
 
132
  continue
133
 
134
- if t == "audio":
135
- if not session_id:
136
- await ws.send_text(json.dumps({"type": "error", "message": "Send {type:'connect'} first."}))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  continue
 
 
 
 
 
 
 
 
138
 
139
- b64 = msg.get("data")
140
- rate = int(msg.get("rate") or 16000)
141
- if not isinstance(b64, str) or not b64:
 
142
  continue
143
 
144
- audio = b64_to_int16(b64)
145
- await stream.send_input(session_id, (rate, audio.reshape(1, -1)))
 
 
 
 
 
 
 
 
146
  continue
147
 
148
- if t == "close":
149
- await ws.close()
150
- return
151
-
152
- await ws.send_text(json.dumps({"type": "error", "message": f"Unknown type: {t}"}))
153
 
154
  except WebSocketDisconnect:
155
- pass
156
- finally:
157
  try:
158
- if pump_task:
159
- pump_task.cancel()
160
  except Exception:
161
  pass
 
1
+ from __future__ import annotations
2
+
3
  import asyncio
 
4
  import json
5
+ import os
6
  import uuid
7
+ from typing import Any, Dict, Optional
8
 
9
  import numpy as np
10
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
11
  from fastapi.responses import JSONResponse
12
 
13
+ from fastrtc import Stream, ReplyOnPause, get_stt_model, get_tts_model
14
+
15
+ from .gemini_text import (
16
+ gemini_chat_turn,
17
+ get_session,
18
+ deliver_function_result,
19
+ )
20
 
21
  app = FastAPI()
22
 
23
 
24
# ----------------------------
# FastRTC Voice Chat (VAD + STT + TTS)
# ----------------------------

# These are CPU-friendly, but still heavy on Spaces. Keep them global.
# Both model names are overridable via environment variables so a deployment
# can swap models without a code change.
STT_MODEL_NAME = os.getenv("FASTRTC_STT_MODEL", "moonshine/tiny")
TTS_MODEL_NAME = os.getenv("FASTRTC_TTS_MODEL", "kokoro")

# Loaded once at import time and shared by every voice turn.
stt = get_stt_model(model=STT_MODEL_NAME)
tts = get_tts_model(model=TTS_MODEL_NAME)
34
+
35
+
36
def _voice_reply_fn(audio: tuple[int, np.ndarray]):
    """
    FastRTC ReplyOnPause callback: fired when the user pauses (VAD).

    Transcribes the captured utterance, runs one Gemini chat turn, and
    yields streamed TTS audio frames back to FastRTC.

    Args:
        audio: (sample_rate, ndarray) pair as delivered by FastRTC —
            presumably int16 mono per the docs examples; TODO confirm.

    Yields:
        Audio chunks produced by ``tts.stream_tts_sync`` (already in a
        FastRTC-compatible frame format).
    """
    user_text = stt.stt(audio).strip()
    if not user_text:
        return

    # Voice sessions share one synthetic session id because ReplyOnPause
    # does not expose the underlying RTC session id. Conversation state is
    # therefore per-process, not per-user; switch to a stateful
    # StreamHandler later if per-user voice memory is needed.
    voice_session_id = "voice-global"

    async def run() -> str:
        # No tool bounce for voice by default: tool-call events are dropped.
        async def noop_emit(_evt: dict):
            return

        return await gemini_chat_turn(
            session_id=voice_session_id,
            user_text=user_text,
            emit_event=noop_emit,
            model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"),
        )

    # BUG FIX: the original used asyncio.get_event_loop().run_until_complete().
    # That raises "This event loop is already running" when invoked on a
    # thread that owns a running loop, and RuntimeError on worker threads
    # without one (Python 3.10+). asyncio.run() creates and tears down a
    # private loop, which is safe in the sync context where ReplyOnPause
    # invokes this callback.
    text = asyncio.run(run())

    # Guard against an empty/None reply before handing it to TTS.
    if not text:
        return

    # Stream TTS back, frame by frame.
    for chunk in tts.stream_tts_sync(text):
        yield chunk
 
73
 
74
+ voice_stream = Stream(
 
 
75
  modality="audio",
76
  mode="send-receive",
77
+ handler=ReplyOnPause(_voice_reply_fn),
78
  )
79
 
80
+ # Mount FastRTC endpoints (WebRTC + WebSocket) under /rtc
81
+ voice_stream.mount(app, path="/rtc")
 
82
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ # ----------------------------
85
+ # Scratch-friendly WebSocket API (text + function calling)
86
+ # ----------------------------
87
 
88
@app.get("/")
async def root():
    """Service index: advertises the WebSocket and FastRTC entry points."""
    payload = {
        "ok": True,
        "service": "salexai-api",
        "ws": "/ws",
        "fastrtc": "/rtc",
        "notes": [
            "Use /ws for Scratch JSON chat + function calling.",
            "Use /rtc for FastRTC voice chat endpoints (VAD/STT/TTS handled by FastRTC).",
        ],
    }
    return JSONResponse(payload)
102
 
103
 
 
 
 
104
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """
    Scratch-friendly JSON WebSocket API (text chat + function calling).

    Protocol: one JSON object per text frame, dispatched on its "type":
      - "connect"          -> allocate/reuse a session id, reply "ready".
      - "add_function"     -> register a client-side tool schema by name.
      - "remove_function"  -> unregister a tool.
      - "list_functions"   -> reply with the registered tool names.
      - "function_result"  -> deliver a tool result for a pending call_id.
      - "send"             -> run one Gemini chat turn; tool-call events are
                              emitted back to the client via `emit`.
    Any other type gets an "error" reply. Every handler except "connect"
    requires a prior "connect".
    """
    await ws.accept()

    # Set by the "connect" message; gates all later handlers.
    session_id: Optional[str] = None

    async def emit(evt: dict) -> None:
        # Single funnel for server -> client messages (JSON text frames).
        await ws.send_text(json.dumps(evt))

    try:
        while True:
            raw = await ws.receive_text()
            msg = json.loads(raw) if raw else {}

            mtype = msg.get("type")

            if mtype == "connect":
                session_id = msg.get("session_id") or str(uuid.uuid4())
                get_session(session_id)  # ensure exists
                await emit({"type": "ready", "session_id": session_id})
                continue

            if not session_id:
                await emit({"type": "error", "message": "Not connected. Send {type:'connect'} first."})
                continue

            # -------- function registry --------

            if mtype == "add_function":
                name = str(msg.get("name") or "").strip()
                schema = msg.get("schema") or {}
                if not name:
                    await emit({"type": "error", "message": "add_function missing name"})
                    continue
                s = get_session(session_id)
                s.functions[name] = schema
                await emit({"type": "function_added", "name": name})
                continue

            if mtype == "remove_function":
                name = str(msg.get("name") or "").strip()
                s = get_session(session_id)
                if name in s.functions:
                    s.functions.pop(name, None)
                    await emit({"type": "function_removed", "name": name})
                else:
                    await emit({"type": "warning", "message": f"Function not found: {name}"})
                continue

            if mtype == "list_functions":
                s = get_session(session_id)
                await emit({"type": "functions", "items": list(s.functions.keys())})
                continue

            # Client returns tool results
            if mtype == "function_result":
                call_id = msg.get("call_id")
                result = msg.get("result")
                if not call_id:
                    await emit({"type": "error", "message": "function_result missing call_id"})
                    continue
                ok = deliver_function_result(session_id, call_id, result)
                if not ok:
                    await emit({"type": "warning", "message": f"No pending call_id: {call_id}"})
                else:
                    await emit({"type": "function_result_ack", "call_id": call_id})
                continue

            # -------- chat --------

            if mtype == "send":
                text = str(msg.get("text") or "")
                if not text.strip():
                    await emit({"type": "error", "message": "Empty text"})
                    continue

                try:
                    assistant_text = await gemini_chat_turn(
                        session_id=session_id,
                        user_text=text,
                        emit_event=emit,  # this is where tool calls get emitted
                        model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"),
                    )
                    await emit({"type": "assistant", "text": assistant_text})
                except Exception as e:
                    await emit({"type": "error", "message": f"Gemini error: {e}"})
                continue

            await emit({"type": "error", "message": f"Unknown type: {mtype}"})

    except WebSocketDisconnect:
        return
    except Exception as e:
        # Best-effort crash report; the socket may already be closed.
        try:
            await emit({"type": "error", "message": f"WS crashed: {e}"})
        except Exception:
            pass