saadpie committed on
Commit
06aae43
·
verified ·
1 Parent(s): 07104ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -41
app.py CHANGED
@@ -2,15 +2,13 @@ import os
2
  import asyncio
3
  from quart import Quart, websocket
4
  from google import genai
5
- from google.genai import types
6
 
7
  app = Quart(__name__)
8
 
9
  # Ensure your HF Space has GEMINI_API_KEY set in its secrets/environment variables
10
  client = genai.Client()
11
 
12
- # Note: The official live model name is currently gemini-2.0-flash-exp.
13
- # Update this if you have specific access to a 3.1 live preview endpoint.
14
  MODEL = "gemini-2.0-flash-exp"
15
 
16
  VOICE_MODES = {
@@ -34,36 +32,31 @@ async def index():
34
  async def ws_stream():
35
  """
36
  WebSocket endpoint for the Termux client.
37
- Connect via: ws://<hf-space-url>/stream?voice=Zephyr
38
  """
39
- # Grab the requested voice from the URL parameter, default to Zephyr
40
  requested_voice = websocket.args.get("voice", "Zephyr")
41
  voice_name = VOICE_MODES.get(requested_voice, "Zephyr")
42
 
43
- # Mirroring your TS configuration
44
- config = types.LiveConnectConfig(
45
- response_modalities=[types.LiveModality.AUDIO],
46
- speech_config=types.SpeechConfig(
47
- voice_config=types.VoiceConfig(
48
- prebuilt_voice_config=types.PrebuiltVoiceConfig(
49
- voice_name=voice_name
50
- )
51
- )
52
- ),
53
- tools=[{"google_search": {}}],
54
- system_instruction=types.Content(
55
- parts=[types.Part.from_text(
56
- "You are ASH-BAND, a high-fidelity AI wearable companion. "
57
- "Speak in a professional, concise, and helpful tone. "
58
- "You have access to Google Search. Keep responses brief to minimize latency. "
59
- "Your responses are spoken aloud."
60
- )]
61
  )
62
- )
63
 
64
  print(f"Connecting to Gemini Live API with voice: {voice_name}...")
65
 
66
  try:
 
67
  async with client.aio.live.connect(model=MODEL, config=config) as session:
68
  print("Live session established.")
69
 
@@ -71,10 +64,9 @@ async def ws_stream():
71
  async def client_to_gemini():
72
  try:
73
  while True:
74
- # Receive audio chunks from the client
75
  data = await websocket.receive()
76
  if isinstance(data, bytes):
77
- # The TS file was downsampling to 16000Hz PCM
78
  await session.send(
79
  input={"data": data, "mime_type": "audio/pcm;rate=16000"}
80
  )
@@ -88,18 +80,15 @@ async def ws_stream():
88
  try:
89
  async for message in session.receive():
90
  server_content = message.server_content
91
- if server_content is not None:
92
- # Handle Interruption
93
  if server_content.interrupted:
94
- print("AI Interrupted by user.")
95
- # In a more complex setup, send a control message to client to clear audio queue
96
 
97
  model_turn = server_content.model_turn
98
- if model_turn is not None:
99
  for part in model_turn.parts:
100
- # Output raw audio back to the client
101
  if part.inline_data and part.inline_data.data:
102
- # Gemini returns 24kHz PCM audio
103
  await websocket.send(part.inline_data.data)
104
  except asyncio.CancelledError:
105
  pass
@@ -110,20 +99,17 @@ async def ws_stream():
110
  task1 = asyncio.create_task(client_to_gemini())
111
  task2 = asyncio.create_task(gemini_to_client())
112
 
113
- # Wait until one of the connections drops
114
- done, pending = await asyncio.wait(
115
  [task1, task2],
116
  return_when=asyncio.FIRST_COMPLETED,
117
  )
118
 
119
- # Clean up the remaining task
120
- for p in pending:
121
- p.cancel()
122
 
123
  except Exception as e:
124
  print(f"Connection failed: {e}")
125
 
126
- # Hugging Face Spaces standard port is 7860
127
  if __name__ == "__main__":
128
- app.run(host="0.0.0.0", port=7860)
129
-
 
2
  import asyncio
3
  from quart import Quart, websocket
4
  from google import genai
 
5
 
6
  app = Quart(__name__)
7
 
8
  # Ensure your HF Space has GEMINI_API_KEY set in its secrets/environment variables
9
  client = genai.Client()
10
 
11
+ # Note: Using gemini-2.0-flash-exp as it is the most stable for the Live SDK currently
 
12
  MODEL = "gemini-2.0-flash-exp"
13
 
14
  VOICE_MODES = {
 
32
  async def ws_stream():
33
  """
34
  WebSocket endpoint for the Termux client.
35
+ Connect via: wss://<hf-space-url>/stream?voice=Zephyr
36
  """
 
37
  requested_voice = websocket.args.get("voice", "Zephyr")
38
  voice_name = VOICE_MODES.get(requested_voice, "Zephyr")
39
 
40
+ # Using a dictionary for config prevents AttributeError on specific SDK versions
41
+ config = {
42
+ "response_modalities": ["AUDIO"],
43
+ "speech_config": {
44
+ "voice_config": {
45
+ "prebuilt_voice_config": {"voice_name": voice_name}
46
+ }
47
+ },
48
+ "tools": [{"google_search": {}}],
49
+ "system_instruction": (
50
+ "You are ASH-BAND, a high-fidelity AI wearable companion. "
51
+ "Speak in a professional, concise, and helpful tone. "
52
+ "Keep responses brief to minimize latency. Your responses are spoken aloud."
 
 
 
 
 
53
  )
54
+ }
55
 
56
  print(f"Connecting to Gemini Live API with voice: {voice_name}...")
57
 
58
  try:
59
+ # Pass the dictionary directly to the config parameter
60
  async with client.aio.live.connect(model=MODEL, config=config) as session:
61
  print("Live session established.")
62
 
 
64
  async def client_to_gemini():
65
  try:
66
  while True:
 
67
  data = await websocket.receive()
68
  if isinstance(data, bytes):
69
+ # Sending 16kHz PCM data from client to Gemini
70
  await session.send(
71
  input={"data": data, "mime_type": "audio/pcm;rate=16000"}
72
  )
 
80
  try:
81
  async for message in session.receive():
82
  server_content = message.server_content
83
+ if server_content:
 
84
  if server_content.interrupted:
85
+ print("AI Interrupted.")
 
86
 
87
  model_turn = server_content.model_turn
88
+ if model_turn:
89
  for part in model_turn.parts:
 
90
  if part.inline_data and part.inline_data.data:
91
+ # Sending 24kHz PCM data back to client
92
  await websocket.send(part.inline_data.data)
93
  except asyncio.CancelledError:
94
  pass
 
99
  task1 = asyncio.create_task(client_to_gemini())
100
  task2 = asyncio.create_task(gemini_to_client())
101
 
102
+ await asyncio.wait(
 
103
  [task1, task2],
104
  return_when=asyncio.FIRST_COMPLETED,
105
  )
106
 
107
+ task1.cancel()
108
+ task2.cancel()
 
109
 
110
  except Exception as e:
111
  print(f"Connection failed: {e}")
112
 
 
113
  if __name__ == "__main__":
114
+ # HF Spaces standard port is 7860
115
+ app.run(host="0.0.0.0", port=7860)