SalexAI committed on
Commit
176ee90
·
verified ·
1 Parent(s): 8ce42f6

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +68 -95
app/main.py CHANGED
@@ -11,9 +11,8 @@ import websockets
11
 
12
  load_dotenv()
13
 
14
- app = FastAPI(title="Gemini Live Native-Audio WS Proxy", version="2.0.0")
15
 
16
- # Gemini Live API WebSocket endpoint (v1beta, BidiGenerateContent)
17
  GEMINI_LIVE_WS_URL = (
18
  "wss://generativelanguage.googleapis.com/ws/"
19
  "google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent"
@@ -21,31 +20,54 @@ GEMINI_LIVE_WS_URL = (
21
 
22
  API_KEY = os.getenv("GEMINI_API_KEY", "").strip()
23
 
24
- # Defaults (override via HF Space variables)
25
- DEFAULT_MODEL = os.getenv("GEMINI_MODEL", "models/gemini-2.0-flash-live-001")
26
- DEFAULT_SYSTEM = os.getenv("GEMINI_SYSTEM_INSTRUCTION", "You are a helpful assistant for a school coding club.")
 
 
 
 
 
 
27
  DEFAULT_TEMPERATURE = float(os.getenv("GEMINI_TEMPERATURE", "0.7"))
28
  DEFAULT_MAX_TOKENS = int(os.getenv("GEMINI_MAX_OUTPUT_TOKENS", "1024"))
29
 
30
- # Native-audio config defaults
31
  DEFAULT_VOICE = os.getenv("GEMINI_VOICE_NAME", "Kore")
32
- # input audio: most common is 16k PCM16 mono
33
  DEFAULT_INPUT_RATE = int(os.getenv("GEMINI_INPUT_AUDIO_RATE", "16000"))
34
- # output audio: docs commonly mention 24k PCM16
35
  DEFAULT_OUTPUT_RATE = int(os.getenv("GEMINI_OUTPUT_AUDIO_RATE", "24000"))
36
 
37
- # Debug passthrough (set to "1" to enable)
38
  DEBUG_GEMINI_RAW = os.getenv("DEBUG_GEMINI_RAW", "0").strip() == "1"
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  @app.get("/health")
42
  async def health():
 
43
  ok = bool(API_KEY)
44
  return JSONResponse(
45
  {
46
  "ok": ok,
47
  "has_api_key": ok,
48
- "model": DEFAULT_MODEL,
49
  "voice": DEFAULT_VOICE,
50
  "input_rate": DEFAULT_INPUT_RATE,
51
  "output_rate": DEFAULT_OUTPUT_RATE,
@@ -64,9 +86,6 @@ def _extract_text_parts(content: Dict[str, Any]) -> str:
64
 
65
 
66
  def _extract_inline_audio_parts(content: Dict[str, Any]) -> List[Dict[str, str]]:
67
- """
68
- Returns list of {"mime": "...", "data": "base64..."} for any inlineData parts.
69
- """
70
  parts = content.get("parts") or []
71
  out: List[Dict[str, str]] = []
72
  for p in parts:
@@ -86,14 +105,13 @@ async def _gemini_ws_connect(setup_payload: Dict[str, Any]):
86
  ws = await websockets.connect(
87
  GEMINI_LIVE_WS_URL,
88
  extra_headers=headers,
89
- max_size=16 * 1024 * 1024,
90
  ping_interval=20,
91
  ping_timeout=20,
92
  )
93
 
94
  await ws.send(json.dumps(setup_payload))
95
 
96
- # wait for setupComplete
97
  while True:
98
  raw = await ws.recv()
99
  msg = json.loads(raw)
@@ -105,28 +123,6 @@ async def _gemini_ws_connect(setup_payload: Dict[str, Any]):
105
 
106
  @app.websocket("/ws")
107
  async def ws_proxy(client_ws: WebSocket):
108
- """
109
- Client protocol (native-audio + VAD friendly):
110
- -> {"type":"configure", "model": "...", "system_instruction": "...", "temperature": 0.7,
111
- "max_output_tokens": 1024, "voice": "Kore", "input_rate": 16000}
112
- (optional, must be first; else defaults are used)
113
-
114
- -> {"type":"audio","data":"<base64 pcm16 mono>","rate":16000}
115
- (send repeatedly while user is speaking)
116
-
117
- -> {"type":"audio_end"}
118
- (send when VAD decides user stopped speaking; triggers assistant response)
119
-
120
- -> {"type":"text","text":"..."} (optional helper; NOT the main mode for native audio)
121
-
122
- Server -> client:
123
- <- {"type":"ready"}
124
- <- {"type":"text_delta","text":"..."} (assistant text parts, if any)
125
- <- {"type":"audio_delta","mime":"...","data":"..."} (assistant audio chunks)
126
- <- {"type":"turn_complete"}
127
- <- {"type":"error","message":"..."}
128
- <- {"type":"gemini_raw","message":{...}} (only if DEBUG_GEMINI_RAW=1)
129
- """
130
  await client_ws.accept()
131
 
132
  if not API_KEY:
@@ -134,60 +130,57 @@ async def ws_proxy(client_ws: WebSocket):
134
  await client_ws.close(code=1011)
135
  return
136
 
137
- # --- Phase 1: accept optional configure before connecting to Gemini ---
138
  cfg = {
139
- "model": DEFAULT_MODEL,
140
- "system_instruction": DEFAULT_SYSTEM,
141
  "temperature": DEFAULT_TEMPERATURE,
142
  "max_output_tokens": DEFAULT_MAX_TOKENS,
143
- "voice": DEFAULT_VOICE,
144
  "input_rate": DEFAULT_INPUT_RATE,
145
  }
146
 
147
- async def _wait_for_optional_config(timeout_s: float = 1.2):
148
- try:
149
- raw = await asyncio.wait_for(client_ws.receive_text(), timeout=timeout_s)
150
- except asyncio.TimeoutError:
151
- return
152
- except Exception:
153
- return
154
-
155
- data = json.loads(raw)
156
- if data.get("type") != "configure":
157
- # if first message is not configure, we treat it as "not configure"
158
- # and stash it for later by putting it into a queue (simple: handle inline)
159
- return data
160
-
161
- # apply config
162
- if isinstance(data.get("model"), str) and data["model"].strip():
163
- cfg["model"] = data["model"].strip()
164
- if isinstance(data.get("system_instruction"), str) and data["system_instruction"].strip():
165
- cfg["system_instruction"] = data["system_instruction"].strip()
166
- if data.get("temperature") is not None:
167
  try:
168
- cfg["temperature"] = float(data["temperature"])
 
169
  except Exception:
170
  pass
171
- if data.get("max_output_tokens") is not None:
172
  try:
173
- cfg["max_output_tokens"] = int(data["max_output_tokens"])
 
174
  except Exception:
175
  pass
176
- if isinstance(data.get("voice"), str) and data["voice"].strip():
177
- cfg["voice"] = data["voice"].strip()
178
- if data.get("input_rate") is not None:
179
  try:
180
- cfg["input_rate"] = int(data["input_rate"])
 
181
  except Exception:
182
  pass
183
 
184
- await client_ws.send_text(json.dumps({"type": "configured"}))
185
- return None
 
 
 
 
 
186
 
187
- first_non_config = await _wait_for_optional_config()
 
188
 
189
- # --- Phase 2: connect to Gemini with native-audio setup ---
190
- # NOTE: For native-audio models, AUDIO modality is required.
191
  setup_payload = {
192
  "setup": {
193
  "model": cfg["model"],
@@ -203,7 +196,6 @@ async def ws_proxy(client_ws: WebSocket):
203
  }
204
  },
205
  },
206
- # Enable transcripts so Scratch can display text while audio plays
207
  "inputAudioTranscription": {},
208
  "outputAudioTranscription": {},
209
  "systemInstruction": {
@@ -218,15 +210,12 @@ async def ws_proxy(client_ws: WebSocket):
218
 
219
  try:
220
  gemini_ws = await _gemini_ws_connect(setup_payload)
221
- await client_ws.send_text(json.dumps({"type": "ready"}))
222
  except Exception as e:
223
  await client_ws.send_text(json.dumps({"type": "error", "message": f"Gemini setup failed: {e}"}))
224
  await client_ws.close(code=1011)
225
  return
226
 
227
- # If we consumed a non-config first message, we need to handle it.
228
- pending_first = first_non_config
229
-
230
  async def forward_client_to_gemini():
231
  nonlocal pending_first
232
  try:
@@ -245,7 +234,6 @@ async def ws_proxy(client_ws: WebSocket):
245
  return
246
 
247
  if t == "audio":
248
- # expects base64 PCM16 mono
249
  b64 = data.get("data")
250
  rate = data.get("rate", cfg["input_rate"])
251
  if not isinstance(b64, str) or not b64:
@@ -267,14 +255,10 @@ async def ws_proxy(client_ws: WebSocket):
267
  continue
268
 
269
  if t == "audio_end":
270
- # tell Gemini the input stream ended for this turn
271
- payload = {"realtimeInput": {"audioStreamEnd": True}}
272
- await gemini_ws.send(json.dumps(payload))
273
  continue
274
 
275
  if t == "text":
276
- # Optional helper: send text as a turn (some native-audio sessions still accept it),
277
- # but for voice-first you should mainly use audio.
278
  text = data.get("text", "")
279
  if isinstance(text, str) and text.strip():
280
  payload = {
@@ -286,11 +270,6 @@ async def ws_proxy(client_ws: WebSocket):
286
  await gemini_ws.send(json.dumps(payload))
287
  continue
288
 
289
- # Advanced passthrough
290
- if t == "live_raw" and isinstance(data.get("payload"), dict):
291
- await gemini_ws.send(json.dumps(data["payload"]))
292
- continue
293
-
294
  await client_ws.send_text(json.dumps({"type": "error", "message": f"Unknown message type: {t}"}))
295
 
296
  except WebSocketDisconnect:
@@ -315,19 +294,16 @@ async def ws_proxy(client_ws: WebSocket):
315
  if isinstance(server_content, dict):
316
  model_turn = server_content.get("modelTurn")
317
  if isinstance(model_turn, dict):
318
- # text parts
319
  txt = _extract_text_parts(model_turn)
320
  if txt:
321
  await client_ws.send_text(json.dumps({"type": "text_delta", "text": txt}))
322
 
323
- # audio parts (inlineData)
324
  audios = _extract_inline_audio_parts(model_turn)
325
  for a in audios:
326
  await client_ws.send_text(
327
  json.dumps({"type": "audio_delta", "mime": a["mime"], "data": a["data"]})
328
  )
329
 
330
- # Some implementations also include transcription fields; pass through if present
331
  out_tx = server_content.get("outputTranscription")
332
  if isinstance(out_tx, dict) and isinstance(out_tx.get("text"), str):
333
  await client_ws.send_text(
@@ -337,9 +313,6 @@ async def ws_proxy(client_ws: WebSocket):
337
  if server_content.get("generationComplete") is True:
338
  await client_ws.send_text(json.dumps({"type": "turn_complete"}))
339
 
340
- if "goAway" in msg:
341
- await client_ws.send_text(json.dumps({"type": "go_away", "goAway": msg["goAway"]}))
342
-
343
  except Exception as e:
344
  stop_event.set()
345
  try:
 
11
 
12
  load_dotenv()
13
 
14
+ app = FastAPI(title="Gemini Live Native-Audio WS Proxy", version="2.1.0")
15
 
 
16
  GEMINI_LIVE_WS_URL = (
17
  "wss://generativelanguage.googleapis.com/ws/"
18
  "google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent"
 
20
 
21
  API_KEY = os.getenv("GEMINI_API_KEY", "").strip()
22
 
23
+ # IMPORTANT: pick a REAL default model here (must support Live + native audio)
24
+ # Put your known-working native audio model id below:
25
+ FALLBACK_NATIVE_AUDIO_MODEL = "models/gemini-2.5-flash-native-audio-preview-12-2025"
26
+
27
+ DEFAULT_MODEL = os.getenv("GEMINI_MODEL", FALLBACK_NATIVE_AUDIO_MODEL)
28
+ DEFAULT_SYSTEM = os.getenv(
29
+ "GEMINI_SYSTEM_INSTRUCTION",
30
+ "You are a helpful assistant for a school coding club."
31
+ )
32
  DEFAULT_TEMPERATURE = float(os.getenv("GEMINI_TEMPERATURE", "0.7"))
33
  DEFAULT_MAX_TOKENS = int(os.getenv("GEMINI_MAX_OUTPUT_TOKENS", "1024"))
34
 
 
35
  DEFAULT_VOICE = os.getenv("GEMINI_VOICE_NAME", "Kore")
 
36
  DEFAULT_INPUT_RATE = int(os.getenv("GEMINI_INPUT_AUDIO_RATE", "16000"))
 
37
  DEFAULT_OUTPUT_RATE = int(os.getenv("GEMINI_OUTPUT_AUDIO_RATE", "24000"))
38
 
 
39
  DEBUG_GEMINI_RAW = os.getenv("DEBUG_GEMINI_RAW", "0").strip() == "1"
40
 
41
 
42
+ def _clean_str(x: Any) -> str:
43
+ if not isinstance(x, str):
44
+ return ""
45
+ return x.strip()
46
+
47
+
48
+ def _is_bad_model(s: str) -> bool:
49
+ s2 = (s or "").strip().lower()
50
+ return (not s2) or (s2 in {"undefined", "null", "none"})
51
+
52
+
53
def _safe_model(model: Any) -> str:
    """Resolve *model* to a usable model id string.

    Tries the given value first, then DEFAULT_MODEL; the first candidate
    that is not blank/nullish (per _is_bad_model) wins. Falls back to the
    hard-coded FALLBACK_NATIVE_AUDIO_MODEL when both are unusable.
    """
    for candidate in (model, DEFAULT_MODEL):
        cleaned = _clean_str(candidate)
        if not _is_bad_model(cleaned):
            return cleaned
    return FALLBACK_NATIVE_AUDIO_MODEL
60
+
61
+
62
  @app.get("/health")
63
  async def health():
64
+ model = _safe_model(DEFAULT_MODEL)
65
  ok = bool(API_KEY)
66
  return JSONResponse(
67
  {
68
  "ok": ok,
69
  "has_api_key": ok,
70
+ "model": model,
71
  "voice": DEFAULT_VOICE,
72
  "input_rate": DEFAULT_INPUT_RATE,
73
  "output_rate": DEFAULT_OUTPUT_RATE,
 
86
 
87
 
88
  def _extract_inline_audio_parts(content: Dict[str, Any]) -> List[Dict[str, str]]:
 
 
 
89
  parts = content.get("parts") or []
90
  out: List[Dict[str, str]] = []
91
  for p in parts:
 
105
  ws = await websockets.connect(
106
  GEMINI_LIVE_WS_URL,
107
  extra_headers=headers,
108
+ max_size=32 * 1024 * 1024,
109
  ping_interval=20,
110
  ping_timeout=20,
111
  )
112
 
113
  await ws.send(json.dumps(setup_payload))
114
 
 
115
  while True:
116
  raw = await ws.recv()
117
  msg = json.loads(raw)
 
123
 
124
  @app.websocket("/ws")
125
  async def ws_proxy(client_ws: WebSocket):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  await client_ws.accept()
127
 
128
  if not API_KEY:
 
130
  await client_ws.close(code=1011)
131
  return
132
 
133
+ # Defaults per connection
134
  cfg = {
135
+ "model": _safe_model(DEFAULT_MODEL),
136
+ "system_instruction": _clean_str(DEFAULT_SYSTEM) or "You are helpful.",
137
  "temperature": DEFAULT_TEMPERATURE,
138
  "max_output_tokens": DEFAULT_MAX_TOKENS,
139
+ "voice": _clean_str(DEFAULT_VOICE) or "Kore",
140
  "input_rate": DEFAULT_INPUT_RATE,
141
  }
142
 
143
+ # Wait briefly for optional configure (FIRST message)
144
+ pending_first: Optional[Dict[str, Any]] = None
145
+ try:
146
+ raw = await asyncio.wait_for(client_ws.receive_text(), timeout=1.2)
147
+ first = json.loads(raw)
148
+ if isinstance(first, dict) and first.get("type") == "configure":
149
+ cfg["model"] = _safe_model(first.get("model"))
150
+ si = _clean_str(first.get("system_instruction"))
151
+ if si:
152
+ cfg["system_instruction"] = si
 
 
 
 
 
 
 
 
 
 
153
  try:
154
+ if first.get("temperature") is not None:
155
+ cfg["temperature"] = float(first["temperature"])
156
  except Exception:
157
  pass
 
158
  try:
159
+ if first.get("max_output_tokens") is not None:
160
+ cfg["max_output_tokens"] = int(first["max_output_tokens"])
161
  except Exception:
162
  pass
163
+ v = _clean_str(first.get("voice"))
164
+ if v:
165
+ cfg["voice"] = v
166
  try:
167
+ if first.get("input_rate") is not None:
168
+ cfg["input_rate"] = int(first["input_rate"])
169
  except Exception:
170
  pass
171
 
172
+ await client_ws.send_text(json.dumps({"type": "configured"}))
173
+ else:
174
+ pending_first = first if isinstance(first, dict) else None
175
+ except asyncio.TimeoutError:
176
+ pass
177
+ except Exception:
178
+ pass
179
 
180
+ # FINAL guard (this prevents "undefined" ever reaching Gemini)
181
+ cfg["model"] = _safe_model(cfg["model"])
182
 
183
+ # Build native-audio session setup
 
184
  setup_payload = {
185
  "setup": {
186
  "model": cfg["model"],
 
196
  }
197
  },
198
  },
 
199
  "inputAudioTranscription": {},
200
  "outputAudioTranscription": {},
201
  "systemInstruction": {
 
210
 
211
  try:
212
  gemini_ws = await _gemini_ws_connect(setup_payload)
213
+ await client_ws.send_text(json.dumps({"type": "ready", "model": cfg["model"]}))
214
  except Exception as e:
215
  await client_ws.send_text(json.dumps({"type": "error", "message": f"Gemini setup failed: {e}"}))
216
  await client_ws.close(code=1011)
217
  return
218
 
 
 
 
219
  async def forward_client_to_gemini():
220
  nonlocal pending_first
221
  try:
 
234
  return
235
 
236
  if t == "audio":
 
237
  b64 = data.get("data")
238
  rate = data.get("rate", cfg["input_rate"])
239
  if not isinstance(b64, str) or not b64:
 
255
  continue
256
 
257
  if t == "audio_end":
258
+ await gemini_ws.send(json.dumps({"realtimeInput": {"audioStreamEnd": True}}))
 
 
259
  continue
260
 
261
  if t == "text":
 
 
262
  text = data.get("text", "")
263
  if isinstance(text, str) and text.strip():
264
  payload = {
 
270
  await gemini_ws.send(json.dumps(payload))
271
  continue
272
 
 
 
 
 
 
273
  await client_ws.send_text(json.dumps({"type": "error", "message": f"Unknown message type: {t}"}))
274
 
275
  except WebSocketDisconnect:
 
294
  if isinstance(server_content, dict):
295
  model_turn = server_content.get("modelTurn")
296
  if isinstance(model_turn, dict):
 
297
  txt = _extract_text_parts(model_turn)
298
  if txt:
299
  await client_ws.send_text(json.dumps({"type": "text_delta", "text": txt}))
300
 
 
301
  audios = _extract_inline_audio_parts(model_turn)
302
  for a in audios:
303
  await client_ws.send_text(
304
  json.dumps({"type": "audio_delta", "mime": a["mime"], "data": a["data"]})
305
  )
306
 
 
307
  out_tx = server_content.get("outputTranscription")
308
  if isinstance(out_tx, dict) and isinstance(out_tx.get("text"), str):
309
  await client_ws.send_text(
 
313
  if server_content.get("generationComplete") is True:
314
  await client_ws.send_text(json.dumps({"type": "turn_complete"}))
315
 
 
 
 
316
  except Exception as e:
317
  stop_event.set()
318
  try: