SalexAI commited on
Commit
0a9dfed
·
verified ·
1 Parent(s): ce52252

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +180 -15
app/main.py CHANGED
@@ -1,33 +1,198 @@
1
- from fastapi import FastAPI
2
- from fastapi.responses import JSONResponse
 
 
 
 
3
 
4
  import numpy as np
5
- from fastrtc import Stream, ReplyOnPause
 
 
 
 
 
 
 
 
 
6
 
7
  app = FastAPI()
8
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Simple echo handler to verify your stream works end-to-end
11
- def echo(audio: tuple[int, np.ndarray]):
12
- # audio is (sample_rate, int16 numpy array)
13
- yield audio
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  stream = Stream(
17
- handler=ReplyOnPause(echo), # VAD-ish turn-taking
18
  modality="audio",
19
  mode="send-receive",
 
20
  )
21
 
22
- # Mount FastRTC endpoints onto this FastAPI app
23
  stream.mount(app)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
 
 
 
26
  @app.get("/")
27
  async def root():
28
- return JSONResponse(
29
- {
30
- "ok": True,
31
- "message": "FastRTC mounted. Use the mounted endpoints for WebRTC/WebSocket.",
32
- }
33
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import json
4
+ import os
5
+ import uuid
6
+ from typing import AsyncGenerator, Literal, Optional
7
 
8
  import numpy as np
9
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
10
+ from fastapi.responses import JSONResponse, StreamingResponse
11
+ from dotenv import load_dotenv
12
+
13
+ from fastrtc import AdditionalOutputs, AsyncStreamHandler, Stream, wait_for_item
14
+
15
+ # ---- Gemini (optional for later; right now we keep your echo handler working) ----
16
+ # You can plug Gemini back in once bridge works.
17
+
18
+ load_dotenv()
19
 
20
  app = FastAPI()
21
 
22
+ # ---------------------------
23
+ # Minimal VAD echo handler (server is already booting with this)
24
+ # ---------------------------
25
class EchoHandler(AsyncStreamHandler):
    """Minimal echo handler: every received audio frame is queued back out unchanged.

    Used to verify the FastRTC send-receive pipeline end-to-end before a real
    model (e.g. Gemini) is plugged in.
    """

    def __init__(self, expected_layout: Literal["mono"] = "mono", output_sample_rate: int = 24000):
        super().__init__(expected_layout=expected_layout, output_sample_rate=output_sample_rate, input_sample_rate=16000)
        # Keep the constructor args so copy() can clone an identically-configured
        # handler instead of silently reverting to defaults.
        self._expected_layout: Literal["mono"] = expected_layout
        self._output_sample_rate: int = output_sample_rate
        # Frames (sample_rate, int16 ndarray) or AdditionalOutputs waiting to be emitted.
        self.out_q: asyncio.Queue[tuple[int, np.ndarray] | AdditionalOutputs] = asyncio.Queue()

    def copy(self):
        # BUG FIX: previously returned EchoHandler() with default arguments,
        # dropping any non-default layout / output sample rate configuration.
        return EchoHandler(self._expected_layout, self._output_sample_rate)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Normalize an incoming frame to a mono int16 array and enqueue it for echo."""
        sr, audio = frame
        audio = np.asarray(audio)
        if audio.ndim == 2:
            audio = audio.squeeze()
        if audio.dtype != np.int16:
            audio = audio.astype(np.int16)

        # Echo back immediately as "audio"; reshape(1, -1) matches the
        # (channels, samples) layout the rest of this file uses.
        self.out_q.put_nowait((sr, audio.reshape(1, -1)))

    async def emit(self):
        """Return the next queued item, waiting until one is available."""
        return await wait_for_item(self.out_q)
46
 
47
 
48
# Global FastRTC stream: the handler instance is cloned per connection
# (via its .copy() method) by FastRTC itself.
stream = Stream(
    handler=EchoHandler(),
    modality="audio",
    mode="send-receive",
    # NOTE(review): FastRTC's additional_inputs normally takes UI components,
    # not bare strings — confirm the installed build accepts this placeholder.
    additional_inputs=["voice_name"],  # placeholder for later
)

# Attach FastRTC's WebRTC/WebSocket endpoints onto this FastAPI app.
stream.mount(app)
56
 
57
+ # ---------------------------
58
+ # Helpers
59
+ # ---------------------------
60
def b64_to_int16(b64: str) -> np.ndarray:
    """Decode base64-encoded PCM16 bytes into a 1-D int16 numpy array."""
    return np.frombuffer(base64.b64decode(b64), dtype=np.int16)
63
+
64
def int16_to_b64(audio: np.ndarray) -> str:
    """Encode an audio array as base64 PCM16, casting to int16 if necessary."""
    pcm = audio if audio.dtype == np.int16 else audio.astype(np.int16)
    return base64.b64encode(pcm.tobytes()).decode("utf-8")
68
+
69
 
70
+ # ---------------------------
71
+ # Basic endpoints
72
+ # ---------------------------
73
@app.get("/")
async def root():
    """Liveness message confirming the FastRTC endpoints are mounted."""
    payload = {
        "ok": True,
        "message": "FastRTC mounted. Use the mounted endpoints for WebRTC/WebSocket.",
    }
    return payload
76
+
77
@app.get("/health")
async def health():
    """Trivial health-check endpoint."""
    return dict(ok=True)
80
+
81
@app.get("/webrtc/new")
async def webrtc_new():
    """
    Mint a webrtc_id to use with /outputs or /ws bridge.
    """
    # FastRTC creates its connection state lazily on first use; all we do here
    # is hand the client a stable identifier to reference later.
    fresh_id = uuid.uuid4()
    return {"webrtc_id": str(fresh_id)}
90
+
91
@app.get("/outputs")
async def outputs(webrtc_id: str):
    """Server-sent-events feed forwarding AdditionalOutputs for one webrtc_id."""

    async def event_stream():
        async for item in stream.output_stream(webrtc_id):
            data = item.args[0] if item.args else None
            yield "event: output\ndata: " + json.dumps(data) + "\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")
98
+
99
+
100
+ # ---------------------------
101
+ # Scratch-friendly WebSocket bridge
102
+ # ---------------------------
103
@app.websocket("/ws")
async def ws_bridge(ws: WebSocket):
    """Scratch-friendly WebSocket bridge in front of the FastRTC stream.

    Client -> server JSON frames:
      {"type": "connect", "webrtc_id"?: str, "voice"?: str} — bind a connection,
          start the output-forwarding loops, reply with {"type": "ready", ...}.
      {"type": "audio", "data": <b64 pcm16>, "rate"?: int} — feed audio in.
      {"type": "close"} — close the socket.

    Server -> client JSON frames:
      {"type": "audio_delta", "rate": int, "data": <b64 pcm16>}
      {"type": "output", "data": ...}  (AdditionalOutputs payloads)
      {"type": "ready" | "error", ...}
    """
    await ws.accept()

    webrtc_id: Optional[str] = None
    # BUG FIX: asyncio.gather() returns a Future, not a Task — the previous
    # Optional[asyncio.Task] annotation was wrong.
    out_task: Optional[asyncio.Future] = None

    async def send_outputs_loop(conn_id: str):
        # Forward AdditionalOutputs coming out of FastRTC as JSON.
        try:
            async for item in stream.output_stream(conn_id):
                msg = item.args[0] if item.args else None
                await ws.send_text(json.dumps({"type": "output", "data": msg}))
        except Exception:
            # Best-effort: the socket may already be gone; ws_bridge's finally
            # block handles cleanup.
            pass

    async def send_audio_loop(conn_id: str):
        # output_stream yields AdditionalOutputs only, so raw audio tuples are
        # read from stream.stream_output(...) instead.
        try:
            async for out in stream.stream_output(conn_id):
                # out can be (sr, np.ndarray) or AdditionalOutputs
                if isinstance(out, AdditionalOutputs):
                    continue
                sr, audio = out
                audio = np.asarray(audio)
                if audio.ndim == 2:
                    audio = audio.squeeze()
                if audio.dtype != np.int16:
                    audio = audio.astype(np.int16)
                await ws.send_text(json.dumps({
                    "type": "audio_delta",
                    "rate": int(sr),
                    "data": int16_to_b64(audio)
                }))
        except Exception:
            pass

    try:
        while True:
            raw = await ws.receive_text()
            msg = json.loads(raw)
            t = msg.get("type")

            if t == "connect":
                # create or use provided webrtc_id
                webrtc_id = msg.get("webrtc_id") or str(uuid.uuid4())

                # optionally set voice / other inputs (stored for handler)
                voice = msg.get("voice") or "Puck"
                try:
                    await stream.set_input(webrtc_id, voice)
                except Exception:
                    # if set_input isn't supported in your exact FastRTC build, ignore
                    pass

                # (Re)start output loops. BUG FIX: a second "connect" used to
                # be ignored (loops stayed bound to the old webrtc_id); now the
                # old loops are cancelled and fresh ones started, with the id
                # passed explicitly so closures can't see a stale rebind.
                if out_task is not None:
                    out_task.cancel()
                out_task = asyncio.gather(
                    send_audio_loop(webrtc_id),
                    send_outputs_loop(webrtc_id),
                    return_exceptions=True,
                )

                await ws.send_text(json.dumps({"type": "ready", "webrtc_id": webrtc_id}))
                continue

            if t == "audio":
                if not webrtc_id:
                    await ws.send_text(json.dumps({"type": "error", "message": "Not connected. Send {type:'connect'} first."}))
                    continue

                b64 = msg.get("data")
                rate = int(msg.get("rate") or 16000)

                if not isinstance(b64, str) or not b64:
                    continue

                audio = b64_to_int16(b64)

                # FastRTC expects (sample_rate, np.ndarray); (1, n) = mono frame.
                await stream.send_input(webrtc_id, (rate, audio.reshape(1, -1)))
                continue

            if t == "close":
                await ws.close()
                return

            await ws.send_text(json.dumps({"type": "error", "message": f"Unknown type: {t}"}))

    except WebSocketDisconnect:
        pass
    finally:
        # Stop the forwarding loops when the socket goes away.
        try:
            if out_task:
                out_task.cancel()
        except Exception:
            pass