SalexAI commited on
Commit
3e9ae46
·
verified ·
1 Parent(s): bf898ab

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +299 -0
main.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import asyncio
4
+ from typing import Any, Dict, Optional
5
+
6
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
7
+ from fastapi.responses import JSONResponse
8
+ from dotenv import load_dotenv
9
+ import websockets
10
+
11
+ load_dotenv()
12
+
13
+ app = FastAPI(title="Gemini Live WS Proxy", version="1.0.0")
14
+
15
+ # Gemini Live API WebSocket endpoint for BidiGenerateContent (v1beta)
16
+ # (Official endpoint in the Live API WebSockets reference.)
17
+ GEMINI_LIVE_WS_URL = (
18
+ "wss://generativelanguage.googleapis.com/ws/"
19
+ "google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent"
20
+ )
21
+
22
+ DEFAULT_MODEL = os.getenv("GEMINI_MODEL", "models/gemini-2.5-flash")
23
+ DEFAULT_SYSTEM = os.getenv("GEMINI_SYSTEM_INSTRUCTION", "You are a helpful assistant.")
24
+ DEFAULT_TEMPERATURE = float(os.getenv("GEMINI_TEMPERATURE", "0.7"))
25
+ DEFAULT_MAX_TOKENS = int(os.getenv("GEMINI_MAX_OUTPUT_TOKENS", "1024"))
26
+
27
+ API_KEY = os.getenv("GEMINI_API_KEY", "").strip()
28
+ if not API_KEY:
29
+ # Don't crash import-time on HF if they haven't set secrets yet;
30
+ # we will return a clear runtime error at connection time.
31
+ pass
32
+
33
+
34
+ @app.get("/health")
35
+ async def health():
36
+ ok = bool(API_KEY)
37
+ return JSONResponse(
38
+ {
39
+ "ok": ok,
40
+ "has_api_key": ok,
41
+ "model": DEFAULT_MODEL,
42
+ }
43
+ )
44
+
45
+
46
+ def _safe_get_text_from_content(content: Dict[str, Any]) -> str:
47
+ """
48
+ Gemini Content format typically includes:
49
+ {"role": "...", "parts": [{"text": "..."} , ...]}
50
+ We concatenate any text parts we see.
51
+ """
52
+ parts = content.get("parts") or []
53
+ out = []
54
+ for p in parts:
55
+ if isinstance(p, dict) and "text" in p and isinstance(p["text"], str):
56
+ out.append(p["text"])
57
+ return "".join(out)
58
+
59
+
60
+ async def _gemini_connect_and_setup(
61
+ model: str,
62
+ system_instruction: str,
63
+ temperature: float,
64
+ max_output_tokens: int,
65
+ response_modalities: Optional[list] = None,
66
+ ):
67
+ """
68
+ Opens a websocket to Gemini Live API and sends the required initial setup message.
69
+ Clients should wait for setupComplete before sending further messages.
70
+ """
71
+ headers = {
72
+ # Gemini API auth: x-goog-api-key header is required for requests. :contentReference[oaicite:2]{index=2}
73
+ "x-goog-api-key": API_KEY,
74
+ }
75
+
76
+ ws = await websockets.connect(
77
+ GEMINI_LIVE_WS_URL,
78
+ extra_headers=headers,
79
+ max_size=8 * 1024 * 1024, # allow larger payloads if needed later
80
+ ping_interval=20,
81
+ ping_timeout=20,
82
+ )
83
+
84
+ setup_payload = {
85
+ "setup": {
86
+ "model": model,
87
+ "generationConfig": {
88
+ "temperature": temperature,
89
+ "maxOutputTokens": max_output_tokens,
90
+ "responseModalities": response_modalities or ["TEXT"],
91
+ },
92
+ # Live API reference shows systemInstruction is Content; we send text-only Content.
93
+ # (Docs note text parts in system instruction.) :contentReference[oaicite:3]{index=3}
94
+ "systemInstruction": {
95
+ "role": "system",
96
+ "parts": [{"text": system_instruction}],
97
+ },
98
+ }
99
+ }
100
+
101
+ await ws.send(json.dumps(setup_payload))
102
+
103
+ # Wait for setupComplete
104
+ while True:
105
+ raw = await ws.recv()
106
+ msg = json.loads(raw)
107
+ if "setupComplete" in msg:
108
+ return ws
109
+ # Forward other early messages if they appear, but don't block setup forever.
110
+ # If Gemini returns an error-like structure, surface it.
111
+ if "error" in msg:
112
+ raise RuntimeError(f"Gemini setup error: {msg['error']}")
113
+
114
+
115
+ @app.websocket("/ws")
116
+ async def ws_proxy(client_ws: WebSocket):
117
+ """
118
+ Client protocol (simple):
119
+ -> {"type":"text","text":"hello"}
120
+ -> {"type":"configure", "model": "...", "system_instruction": "...", "temperature": 0.7, "max_output_tokens": 1024}
121
+ -> {"type":"close"}
122
+
123
+ Server sends:
124
+ <- {"type":"ready"}
125
+ <- {"type":"text_delta","text":"..."} (streaming)
126
+ <- {"type":"turn_complete"}
127
+ <- {"type":"gemini_raw","message":{...}} (debug passthrough)
128
+ <- {"type":"error","message":"..."}
129
+ """
130
+ await client_ws.accept()
131
+
132
+ if not API_KEY:
133
+ await client_ws.send_text(
134
+ json.dumps(
135
+ {
136
+ "type": "error",
137
+ "message": "Server missing GEMINI_API_KEY env var. Set it in your Space secrets.",
138
+ }
139
+ )
140
+ )
141
+ await client_ws.close(code=1011)
142
+ return
143
+
144
+ # Per-connection defaults (can be overridden by configure message)
145
+ model = DEFAULT_MODEL
146
+ system_instruction = DEFAULT_SYSTEM
147
+ temperature = DEFAULT_TEMPERATURE
148
+ max_output_tokens = DEFAULT_MAX_TOKENS
149
+
150
+ gemini_ws = None
151
+ stop_event = asyncio.Event()
152
+
153
+ async def ensure_gemini():
154
+ nonlocal gemini_ws
155
+ if gemini_ws is None:
156
+ gemini_ws = await _gemini_connect_and_setup(
157
+ model=model,
158
+ system_instruction=system_instruction,
159
+ temperature=temperature,
160
+ max_output_tokens=max_output_tokens,
161
+ response_modalities=["TEXT"],
162
+ )
163
+
164
+ async def forward_client_to_gemini():
165
+ """
166
+ Reads from your client WebSocket and sends appropriate Live API messages to Gemini.
167
+ Uses clientContent + turnComplete for clean text turns. :contentReference[oaicite:4]{index=4}
168
+ """
169
+ try:
170
+ while not stop_event.is_set():
171
+ raw = await client_ws.receive_text()
172
+ data = json.loads(raw)
173
+
174
+ msg_type = data.get("type")
175
+ if msg_type == "configure":
176
+ # Allow config BEFORE Gemini connection is created.
177
+ if gemini_ws is not None:
178
+ await client_ws.send_text(
179
+ json.dumps(
180
+ {
181
+ "type": "error",
182
+ "message": "Cannot configure after session started. Open a new WS connection.",
183
+ }
184
+ )
185
+ )
186
+ continue
187
+
188
+ model = data.get("model", model)
189
+ system_instruction = data.get("system_instruction", system_instruction)
190
+ temperature = float(data.get("temperature", temperature))
191
+ max_output_tokens = int(data.get("max_output_tokens", max_output_tokens))
192
+ await client_ws.send_text(json.dumps({"type": "configured"}))
193
+ continue
194
+
195
+ if msg_type == "close":
196
+ stop_event.set()
197
+ return
198
+
199
+ if msg_type == "text":
200
+ text = data.get("text", "")
201
+ if not isinstance(text, str) or not text.strip():
202
+ continue
203
+
204
+ await ensure_gemini()
205
+
206
+ # Send a single "turn" using clientContent.turns and turnComplete=true. :contentReference[oaicite:5]{index=5}
207
+ payload = {
208
+ "clientContent": {
209
+ "turns": [
210
+ {
211
+ "role": "user",
212
+ "parts": [{"text": text}],
213
+ }
214
+ ],
215
+ "turnComplete": True,
216
+ }
217
+ }
218
+ await gemini_ws.send(json.dumps(payload))
219
+ continue
220
+
221
+ # Optional: raw passthrough (advanced users)
222
+ if msg_type == "live_raw":
223
+ await ensure_gemini()
224
+ payload = data.get("payload")
225
+ if isinstance(payload, dict):
226
+ await gemini_ws.send(json.dumps(payload))
227
+ continue
228
+
229
+ await client_ws.send_text(
230
+ json.dumps({"type": "error", "message": f"Unknown message type: {msg_type}"})
231
+ )
232
+
233
+ except WebSocketDisconnect:
234
+ stop_event.set()
235
+ except Exception as e:
236
+ stop_event.set()
237
+ try:
238
+ await client_ws.send_text(json.dumps({"type": "error", "message": str(e)}))
239
+ except Exception:
240
+ pass
241
+
242
+ async def forward_gemini_to_client():
243
+ """
244
+ Reads Gemini Live API server messages and forwards useful pieces to your client.
245
+ We extract text from serverContent.modelTurn.parts[].text when present. :contentReference[oaicite:6]{index=6}
246
+ """
247
+ try:
248
+ await ensure_gemini()
249
+ await client_ws.send_text(json.dumps({"type": "ready"}))
250
+
251
+ while not stop_event.is_set():
252
+ raw = await gemini_ws.recv()
253
+ msg = json.loads(raw)
254
+
255
+ # Optional debug passthrough:
256
+ await client_ws.send_text(json.dumps({"type": "gemini_raw", "message": msg}))
257
+
258
+ # The main streaming content arrives under "serverContent"
259
+ server_content = msg.get("serverContent")
260
+ if isinstance(server_content, dict):
261
+ # modelTurn is Content (role+parts)
262
+ model_turn = server_content.get("modelTurn")
263
+ if isinstance(model_turn, dict):
264
+ delta = _safe_get_text_from_content(model_turn)
265
+ if delta:
266
+ await client_ws.send_text(json.dumps({"type": "text_delta", "text": delta}))
267
+
268
+ # When generationComplete true, we end the turn
269
+ if server_content.get("generationComplete") is True:
270
+ await client_ws.send_text(json.dumps({"type": "turn_complete"}))
271
+
272
+ # Tool calls (if you later enable tools in setup)
273
+ if "toolCall" in msg:
274
+ await client_ws.send_text(json.dumps({"type": "tool_call", "toolCall": msg["toolCall"]}))
275
+
276
+ if "goAway" in msg:
277
+ await client_ws.send_text(json.dumps({"type": "go_away", "goAway": msg["goAway"]}))
278
+
279
+ except Exception as e:
280
+ stop_event.set()
281
+ try:
282
+ await client_ws.send_text(json.dumps({"type": "error", "message": f"Gemini link error: {e}"}))
283
+ except Exception:
284
+ pass
285
+
286
+ try:
287
+ # Run both directions
288
+ await asyncio.gather(forward_client_to_gemini(), forward_gemini_to_client())
289
+ finally:
290
+ stop_event.set()
291
+ try:
292
+ if gemini_ws is not None:
293
+ await gemini_ws.close()
294
+ except Exception:
295
+ pass
296
+ try:
297
+ await client_ws.close()
298
+ except Exception:
299
+ pass