internationalscholarsprogram committed on
Commit
b2d838a
·
1 Parent(s): 8171b65

Fix HF Space: FastAPI health check, Gemini WS bridge, deps & Docker CMD

Browse files
Files changed (3) hide show
  1. Dockerfile +1 -1
  2. app.py +35 -28
  3. requirements.txt +2 -0
Dockerfile CHANGED
@@ -10,4 +10,4 @@ COPY . .
10
  ENV PORT=7860
11
  EXPOSE 7860
12
 
13
- CMD ["python", "app.py"]
 
10
  ENV PORT=7860
11
  EXPOSE 7860
12
 
13
+ CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT}"]
app.py CHANGED
@@ -2,20 +2,30 @@
2
  import json
3
  import os
4
  import base64
5
- import websockets
 
6
  from google import genai
7
 
8
  MODEL = os.environ.get("MODEL", "gemini-2.0-flash-exp")
9
 
10
  # IMPORTANT: don't hardcode keys in code.
11
  # Set GOOGLE_API_KEY as a Hugging Face Secret.
12
- client = genai.Client(
13
- http_options={"api_version": "v1alpha"}
14
- )
 
 
 
 
 
 
 
 
 
15
 
16
- async def gemini_session_handler(client_websocket: websockets.WebSocketServerProtocol):
17
  try:
18
- config_message = await client_websocket.recv()
 
19
  config_data = json.loads(config_message)
20
 
21
  config = config_data.get("setup", {})
@@ -29,9 +39,11 @@ async def gemini_session_handler(client_websocket: websockets.WebSocketServerPro
29
  )
30
 
31
  async with client.aio.live.connect(model=MODEL, config=config) as session:
 
32
  async def send_to_gemini():
33
  try:
34
- async for message in client_websocket:
 
35
  data = json.loads(message)
36
 
37
  if "realtime_input" in data:
@@ -40,6 +52,8 @@ async def gemini_session_handler(client_websocket: websockets.WebSocketServerPro
40
  payload = chunk.get("data")
41
  if mt in ("audio/pcm", "image/jpeg") and payload:
42
  await session.send({"mime_type": mt, "data": payload})
 
 
43
  except Exception as e:
44
  print(f"send_to_gemini error: {e}")
45
 
@@ -53,36 +67,29 @@ async def gemini_session_handler(client_websocket: websockets.WebSocketServerPro
53
  if model_turn:
54
  for part in model_turn.parts:
55
  if getattr(part, "text", None):
56
- await client_websocket.send(json.dumps({"text": part.text}))
57
  elif getattr(part, "inline_data", None):
58
  b64_audio = base64.b64encode(part.inline_data.data).decode("utf-8")
59
- await client_websocket.send(json.dumps({"audio": b64_audio}))
60
 
61
  if response.server_content.turn_complete:
62
- # optional: notify turn complete
63
- await client_websocket.send(json.dumps({"turn_complete": True}))
 
64
  except Exception as e:
65
  print(f"receive_from_gemini error: {e}")
66
 
67
- await asyncio.gather(
68
- asyncio.create_task(send_to_gemini()),
69
- asyncio.create_task(receive_from_gemini()),
70
- )
71
 
 
 
72
  except Exception as e:
73
- print(f"gemini_session_handler error: {e}")
74
  try:
75
- await client_websocket.send(json.dumps({"error": str(e)}))
 
 
 
 
76
  except Exception:
77
  pass
78
-
79
- async def main():
80
- port = int(os.environ.get("PORT", "7860"))
81
- host = "0.0.0.0"
82
- print(f"Starting WebSocket server on {host}:{port}")
83
-
84
- async with websockets.serve(gemini_session_handler, host, port):
85
- await asyncio.Future() # run forever
86
-
87
- if __name__ == "__main__":
88
- asyncio.run(main())
 
2
  import json
3
  import os
4
  import base64
5
+
6
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
7
  from google import genai
8
 
9
  MODEL = os.environ.get("MODEL", "gemini-2.0-flash-exp")
10
 
11
  # IMPORTANT: don't hardcode keys in code.
12
  # Set GOOGLE_API_KEY as a Hugging Face Secret.
13
+ client = genai.Client(http_options={"api_version": "v1alpha"})
14
+
15
+ app = FastAPI()
16
+
17
+ @app.get("/")
18
+ def health():
19
+ # Hugging Face health check hits this as plain HTTP.
20
+ return {"status": "ok"}
21
+
22
+ @app.websocket("/ws")
23
+ async def gemini_ws_bridge(ws: WebSocket):
24
+ await ws.accept()
25
 
 
26
  try:
27
+ # First message must be config (same as your old code)
28
+ config_message = await ws.receive_text()
29
  config_data = json.loads(config_message)
30
 
31
  config = config_data.get("setup", {})
 
39
  )
40
 
41
  async with client.aio.live.connect(model=MODEL, config=config) as session:
42
+
43
  async def send_to_gemini():
44
  try:
45
+ while True:
46
+ message = await ws.receive_text()
47
  data = json.loads(message)
48
 
49
  if "realtime_input" in data:
 
52
  payload = chunk.get("data")
53
  if mt in ("audio/pcm", "image/jpeg") and payload:
54
  await session.send({"mime_type": mt, "data": payload})
55
+ except WebSocketDisconnect:
56
+ pass
57
  except Exception as e:
58
  print(f"send_to_gemini error: {e}")
59
 
 
67
  if model_turn:
68
  for part in model_turn.parts:
69
  if getattr(part, "text", None):
70
+ await ws.send_text(json.dumps({"text": part.text}))
71
  elif getattr(part, "inline_data", None):
72
  b64_audio = base64.b64encode(part.inline_data.data).decode("utf-8")
73
+ await ws.send_text(json.dumps({"audio": b64_audio}))
74
 
75
  if response.server_content.turn_complete:
76
+ await ws.send_text(json.dumps({"turn_complete": True}))
77
+ except WebSocketDisconnect:
78
+ pass
79
  except Exception as e:
80
  print(f"receive_from_gemini error: {e}")
81
 
82
+ await asyncio.gather(send_to_gemini(), receive_from_gemini())
 
 
 
83
 
84
+ except WebSocketDisconnect:
85
+ pass
86
  except Exception as e:
87
+ print(f"gemini_ws_bridge error: {e}")
88
  try:
89
+ await ws.send_text(json.dumps({"error": str(e)}))
90
+ except Exception:
91
+ pass
92
+ try:
93
+ await ws.close()
94
  except Exception:
95
  pass
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  google-genai==0.3.0
2
  websockets==14.1
 
 
 
1
  google-genai==0.3.0
2
  websockets==14.1
3
+ fastapi==0.115.6
4
+ uvicorn[standard]==0.32.1