SalexAI committed on
Commit
9e04ed9
·
verified ·
1 Parent(s): 6a84946

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +34 -136
app/main.py CHANGED
@@ -1,5 +1,4 @@
1
  import asyncio
2
- import base64
3
  import json
4
  import os
5
  from typing import AsyncGenerator, Literal
@@ -9,91 +8,36 @@ from dotenv import load_dotenv
9
  from fastapi import FastAPI
10
  from fastapi.responses import StreamingResponse
11
 
12
- from fastrtc import (
13
- AdditionalOutputs,
14
- AsyncStreamHandler,
15
- Stream,
16
- wait_for_item,
17
- )
18
 
19
- from google import genai
20
- from google.genai.types import (
21
- LiveConnectConfig,
22
- PrebuiltVoiceConfig,
23
- SpeechConfig,
24
- VoiceConfig,
25
- )
26
 
27
- load_dotenv()
 
 
28
 
29
- # ---------------------------
30
- # Config (env vars)
31
- # ---------------------------
32
- # Put this in your HF Space "Secrets":
33
- # GEMINI_API_KEY = "..."
34
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
35
-
36
- # Gemini realtime model (this is the one FastRTC uses in their Gemini demo Space)
37
- # You can change this later to another Live-capable model.
38
  GEMINI_LIVE_MODEL = os.getenv("GEMINI_LIVE_MODEL", "gemini-2.0-flash-exp")
39
-
40
- # Voice name (FastRTC Gemini demo uses "Puck" by default)
41
  DEFAULT_VOICE = os.getenv("GEMINI_VOICE", "Puck")
42
 
43
- # Sample rates
44
  OUTPUT_SAMPLE_RATE = int(os.getenv("OUTPUT_SAMPLE_RATE", "24000"))
45
- INPUT_SAMPLE_RATE = int(os.getenv("INPUT_SAMPLE_RATE", "16000")) # matches the demo Space
46
-
47
-
48
- def _encode_pcm16_mono_to_b64(data: np.ndarray) -> str:
49
- """
50
- Encodes int16 mono PCM to base64 for any custom debug endpoints.
51
- """
52
- if data.dtype != np.int16:
53
- data = data.astype(np.int16)
54
- return base64.b64encode(data.tobytes()).decode("utf-8")
55
 
56
 
57
  class GeminiLiveAudioHandler(AsyncStreamHandler):
58
- """
59
- FastRTC AsyncStreamHandler that connects to Gemini Live and streams AUDIO back.
60
-
61
- This is adapted from the official FastRTC Gemini demo Space code. :contentReference[oaicite:5]{index=5}
62
- """
63
-
64
- def __init__(
65
- self,
66
- expected_layout: Literal["mono"] = "mono",
67
- output_sample_rate: int = OUTPUT_SAMPLE_RATE,
68
- ) -> None:
69
- super().__init__(
70
- expected_layout=expected_layout,
71
- output_sample_rate=output_sample_rate,
72
- input_sample_rate=INPUT_SAMPLE_RATE,
73
- )
74
-
75
  self.input_queue: asyncio.Queue[bytes] = asyncio.Queue()
76
  self.output_queue: asyncio.Queue[tuple[int, np.ndarray] | AdditionalOutputs] = asyncio.Queue()
77
  self.quit = asyncio.Event()
78
 
79
  def copy(self) -> "GeminiLiveAudioHandler":
80
- # FastRTC uses .copy() to clone per-connection handlers
81
- return GeminiLiveAudioHandler(
82
- expected_layout="mono",
83
- output_sample_rate=self.output_sample_rate,
84
- )
85
 
86
  async def start_up(self) -> None:
87
- """
88
- Connect to Gemini Live, then continuously:
89
- - read user audio from self.stream()
90
- - receive model audio chunks and push them to output_queue
91
- """
92
- # Optional: allow per-connection overrides via "additional_inputs"
93
- # We wait for args to be set (FastRTC API docs show wait_for_args usage). :contentReference[oaicite:6]{index=6}
94
  await self.wait_for_args()
95
- # latest_args includes metadata at [0]; any custom inputs start at [1]
96
- # We'll accept: voice_name (str) as the single custom arg, fallback to DEFAULT_VOICE.
97
  voice_name = DEFAULT_VOICE
98
  try:
99
  if len(self.latest_args) >= 2 and isinstance(self.latest_args[1], str) and self.latest_args[1].strip():
@@ -101,21 +45,14 @@ class GeminiLiveAudioHandler(AsyncStreamHandler):
101
  except Exception:
102
  pass
103
 
104
- api_key = GEMINI_API_KEY
105
- if not api_key:
106
- # Fail early with a helpful message in the client.
107
- await self.output_queue.put(
108
- AdditionalOutputs({"type": "error", "message": "Missing GEMINI_API_KEY env var on the server."})
109
- )
110
  return
111
 
112
- client = genai.Client(
113
- api_key=api_key,
114
- http_options={"api_version": "v1alpha"}, # matches FastRTC Gemini demo Space :contentReference[oaicite:7]{index=7}
115
- )
116
 
117
  config = LiveConnectConfig(
118
- response_modalities=["AUDIO"], # AUDIO-only mode :contentReference[oaicite:8]{index=8}
119
  speech_config=SpeechConfig(
120
  voice_config=VoiceConfig(
121
  prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
@@ -124,104 +61,65 @@ class GeminiLiveAudioHandler(AsyncStreamHandler):
124
  )
125
 
126
  async with client.aio.live.connect(model=GEMINI_LIVE_MODEL, config=config) as session:
127
- # session.start_stream takes an async generator of bytes
128
  async for audio in session.start_stream(stream=self._stream_pcm(), mime_type="audio/pcm"):
129
- if audio.data:
130
- # Gemini returns pcm16 bytes; convert to int16 array
131
  arr = np.frombuffer(audio.data, dtype=np.int16)
132
- # FastRTC expects (sample_rate, np.ndarray) shaped like (1, n) or (n,) depending on handler usage.
133
  self.output_queue.put_nowait((self.output_sample_rate, arr.reshape(1, -1)))
134
 
135
  async def _stream_pcm(self) -> AsyncGenerator[bytes, None]:
136
- """
137
- Provides PCM bytes to Gemini Live continuously.
138
- """
139
  while not self.quit.is_set():
140
  try:
141
  chunk = await asyncio.wait_for(self.input_queue.get(), timeout=0.1)
142
  yield chunk
143
- except (asyncio.TimeoutError, TimeoutError):
144
  pass
145
 
146
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
147
- """
148
- Called by FastRTC as audio frames arrive from the client.
149
- """
150
  _, audio = frame
151
- # Expect mono, int16-ish. Convert safely.
152
  audio = np.asarray(audio)
153
  if audio.ndim == 2:
154
  audio = audio.squeeze()
155
  if audio.dtype != np.int16:
156
  audio = audio.astype(np.int16)
157
-
158
- # Push raw PCM16 bytes to Gemini stream
159
  self.input_queue.put_nowait(audio.tobytes())
160
 
161
- async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
162
- """
163
- Called by FastRTC to get the next outbound chunk (audio or structured outputs).
164
- """
165
  return await wait_for_item(self.output_queue)
166
 
167
  async def shutdown(self) -> None:
168
  self.quit.set()
169
 
170
 
171
- # ---------------------------
172
- # FastRTC Stream + FastAPI
173
- # ---------------------------
174
-
175
- # We expose one additional input: voice name
176
- # Clients can set it via Stream.set_input(...) patterns described in the FastRTC API docs. :contentReference[oaicite:9]{index=9}
177
  stream = Stream(
178
  handler=GeminiLiveAudioHandler(),
179
  modality="audio",
180
  mode="send-receive",
181
- additional_inputs=[
182
- # Keep it simple: one string
183
- # (FastRTC examples often use Gradio components here; in API mode we’ll set via set_input)
184
- # We still define it so handler.wait_for_args() has something to wait on.
185
- "voice_name"
186
- ],
187
  )
188
 
189
  app = FastAPI()
190
-
191
- # Mount FastRTC endpoints onto FastAPI (this is the core feature). :contentReference[oaicite:10]{index=10}
192
  stream.mount(app)
193
 
194
 
195
- # ---------------------------
196
- # Optional: server-side outputs stream (SSE)
197
- # Works well for Scratch/JS clients that want text/meta without WebRTC.
198
- # FastRTC docs show using stream.output_stream(webrtc_id). :contentReference[oaicite:11]{index=11}
199
- # The talk-to-openai Space uses the same approach. :contentReference[oaicite:12]{index=12}
200
- # ---------------------------
 
 
 
 
 
 
 
 
 
201
  @app.get("/outputs")
202
  async def outputs(webrtc_id: str):
203
  async def event_stream():
204
  async for out in stream.output_stream(webrtc_id):
205
- # out is an AdditionalOutputs instance
206
- # Serialize it as SSE "output" events
207
  payload = json.dumps(out.args[0] if out.args else None)
208
  yield f"event: output\ndata: {payload}\n\n"
209
-
210
  return StreamingResponse(event_stream(), media_type="text/event-stream")
211
-
212
-
213
- @app.get("/health")
214
- async def health():
215
- return {
216
- "ok": True,
217
- "provider": "gemini_live_audio",
218
- "model": GEMINI_LIVE_MODEL,
219
- "output_sample_rate": OUTPUT_SAMPLE_RATE,
220
- "input_sample_rate": INPUT_SAMPLE_RATE,
221
- }
222
-
223
-
224
- if __name__ == "__main__":
225
- import uvicorn
226
-
227
- uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
 
1
  import asyncio
 
2
  import json
3
  import os
4
  from typing import AsyncGenerator, Literal
 
8
  from fastapi import FastAPI
9
  from fastapi.responses import StreamingResponse
10
 
11
+ load_dotenv()
 
 
 
 
 
12
 
13
+ # Import gradio first so if it fails, it fails loudly & early
14
+ import gradio as gr # noqa
 
 
 
 
 
15
 
16
+ from fastrtc import AdditionalOutputs, AsyncStreamHandler, Stream, wait_for_item
17
+ from google import genai
18
+ from google.genai.types import LiveConnectConfig, PrebuiltVoiceConfig, SpeechConfig, VoiceConfig
19
 
 
 
 
 
 
20
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
 
 
 
21
  GEMINI_LIVE_MODEL = os.getenv("GEMINI_LIVE_MODEL", "gemini-2.0-flash-exp")
 
 
22
  DEFAULT_VOICE = os.getenv("GEMINI_VOICE", "Puck")
23
 
 
24
  OUTPUT_SAMPLE_RATE = int(os.getenv("OUTPUT_SAMPLE_RATE", "24000"))
25
+ INPUT_SAMPLE_RATE = int(os.getenv("INPUT_SAMPLE_RATE", "16000"))
 
 
 
 
 
 
 
 
 
26
 
27
 
28
class GeminiLiveAudioHandler(AsyncStreamHandler):
    """FastRTC handler bridging client audio to Gemini Live and streaming audio replies back.

    Client PCM frames arrive via receive(), are forwarded to Gemini through
    session.start_stream(), and the model's PCM16 audio chunks are queued
    for FastRTC to pull via emit().
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = OUTPUT_SAMPLE_RATE,
    ) -> None:
        super().__init__(
            expected_layout=expected_layout,
            output_sample_rate=output_sample_rate,
            input_sample_rate=INPUT_SAMPLE_RATE,
        )
        # Raw PCM16 bytes from the client, pending upload to Gemini.
        self.input_queue: asyncio.Queue[bytes] = asyncio.Queue()
        # (sample_rate, audio) tuples or structured outputs for FastRTC to emit.
        self.output_queue: asyncio.Queue[tuple[int, np.ndarray] | AdditionalOutputs] = asyncio.Queue()
        self.quit = asyncio.Event()

    def copy(self) -> "GeminiLiveAudioHandler":
        # FastRTC clones the handler per connection via copy().
        return GeminiLiveAudioHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
        )

    async def start_up(self) -> None:
        """Connect to Gemini Live and pump model audio into the output queue."""
        await self.wait_for_args()

        # latest_args[0] is FastRTC metadata; [1] is the optional voice_name input.
        voice_name = DEFAULT_VOICE
        try:
            if (
                len(self.latest_args) >= 2
                and isinstance(self.latest_args[1], str)
                and self.latest_args[1].strip()
            ):
                voice_name = self.latest_args[1].strip()
        except Exception:
            # Best-effort: fall back to the default voice on malformed args.
            pass

        if not GEMINI_API_KEY:
            # Surface the misconfiguration to the client instead of crashing silently.
            await self.output_queue.put(
                AdditionalOutputs({"type": "error", "message": "Missing GEMINI_API_KEY on server."})
            )
            return

        client = genai.Client(api_key=GEMINI_API_KEY, http_options={"api_version": "v1alpha"})

        config = LiveConnectConfig(
            response_modalities=["AUDIO"],
            speech_config=SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
                )
            ),
        )

        async with client.aio.live.connect(model=GEMINI_LIVE_MODEL, config=config) as session:
            async for audio in session.start_stream(stream=self._stream_pcm(), mime_type="audio/pcm"):
                if getattr(audio, "data", None):
                    # Gemini returns PCM16 bytes; FastRTC expects (rate, (1, n) int16 array).
                    arr = np.frombuffer(audio.data, dtype=np.int16)
                    self.output_queue.put_nowait((self.output_sample_rate, arr.reshape(1, -1)))

    async def _stream_pcm(self) -> AsyncGenerator[bytes, None]:
        """Yield queued client PCM bytes until shutdown is requested."""
        while not self.quit.is_set():
            try:
                yield await asyncio.wait_for(self.input_queue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                # Timeout only serves to re-check the quit flag; keep polling.
                pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Normalize an inbound client frame to mono int16 and enqueue its bytes."""
        _, audio = frame
        audio = np.asarray(audio)
        if audio.ndim == 2:
            audio = audio.squeeze()
        if audio.dtype != np.int16:
            audio = audio.astype(np.int16)
        self.input_queue.put_nowait(audio.tobytes())

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        """Return the next outbound item (audio chunk or structured output), if any."""
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        # Signal _stream_pcm() to stop yielding, ending the Gemini upload stream.
        self.quit.set()
91
 
92
 
 
 
 
 
 
 
93
# Single additional text input ("voice_name") so handler.wait_for_args()
# has an argument to resolve per connection.
stream = Stream(
    handler=GeminiLiveAudioHandler(),
    modality="audio",
    mode="send-receive",
    additional_inputs=["voice_name"],
)

# Mount the FastRTC WebRTC endpoints onto the FastAPI application.
app = FastAPI()
stream.mount(app)
102
 
103
 
104
@app.get("/health")
async def health():
    """Liveness probe: report server status and the configured Gemini model."""
    status = {"ok": True, "model": GEMINI_LIVE_MODEL}
    return status
107
+
108
+
109
@app.get("/versions")
async def versions():
    """Report runtime dependency versions to help debug deployment issues."""
    # `os.sys` only works because os happens to import sys internally; that is
    # an undocumented implementation detail. Import sys explicitly instead.
    import sys
    import fastrtc

    vi = sys.version_info
    return {
        "gradio": getattr(gr, "__version__", "unknown"),
        "fastrtc": getattr(fastrtc, "__version__", "unknown"),
        "python": f"{vi.major}.{vi.minor}.{vi.micro}",
    }
117
+
118
+
119
@app.get("/outputs")
async def outputs(webrtc_id: str):
    """Server-sent-events endpoint relaying AdditionalOutputs for one connection.

    Lets non-WebRTC clients (plain JS, Scratch) poll structured outputs by
    webrtc_id without joining the media stream.
    """

    async def event_stream():
        async for out in stream.output_stream(webrtc_id):
            # Each item is an AdditionalOutputs; serialize its first arg as an SSE "output" event.
            body = out.args[0] if out.args else None
            yield f"event: output\ndata: {json.dumps(body)}\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")