# basic_handler.py
import asyncio
import base64
import json
import os
import traceback
from websockets.asyncio.client import connect
# --- Gemini Live API configuration ---------------------------------------
# Host serving the Generative Language API.
host = "generativelanguage.googleapis.com"
# Live model to request in the setup message; swap in another model name if needed.
model = "gemini-2.0-flash-live-001"
# API key captured from the environment once, at import time ("" when unset).
api_key_env = os.environ.get("GOOGLE_API_KEY", "")
# Websocket endpoint; the literal "{api_key}" placeholder is filled in later
# with str.format(api_key=...).
uri_template = (
    "wss://"
    + host
    + "/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent"
    + "?key={api_key}"
)
class AudioLoop:
    """Bidirectional bridge between local asyncio queues and a Gemini websocket.

    Objects put on ``out_queue`` are JSON-encoded and sent to Gemini; audio
    decoded from ``inlineData`` parts of Gemini's replies is put on
    ``audio_in_queue`` as raw bytes.  Call ``run()`` to connect and pump both
    directions, and ``stop()`` to shut everything down.
    """

    def __init__(self):
        # Websocket connection; set by startup()
        self.ws = None
        # Queue for messages to be sent *to* Gemini (JSON-serializable objects)
        self.out_queue = asyncio.Queue()
        # Queue for PCM audio received *from* Gemini (decoded bytes)
        self.audio_in_queue = asyncio.Queue()
        # Set by stop(), or by a send/receive loop on a fatal error, to wind down
        self.shutdown_event = asyncio.Event()

    async def startup(self, api_key=None):
        """Open the websocket and send the model setup message to Gemini.

        Args:
            api_key: API key to use (overrides the GOOGLE_API_KEY environment
                variable).

        Raises:
            ValueError: If no API key is available from any source.
        """
        # Explicit key first, then the value captured at import time, then the
        # current environment (in case GOOGLE_API_KEY was set after import).
        key = api_key or api_key_env or os.environ.get("GOOGLE_API_KEY", "")
        if not key:
            raise ValueError("No API key provided and GOOGLE_API_KEY environment variable not set")
        uri = uri_template.format(api_key=key)
        self.ws = await connect(uri, additional_headers={"Content-Type": "application/json"})
        # Absolutely minimal setup message: only the model is configured.
        setup_msg = {"setup": {"model": f"models/{model}"}}
        await self.ws.send(json.dumps(setup_msg))
        # Wait for the server's reply to the setup message before streaming.
        raw_response = await self.ws.recv()
        setup_response = json.loads(raw_response)
        print("[AudioLoop] Setup response from Gemini:", setup_response)

    async def send_realtime(self):
        """Forward messages from out_queue to Gemini until shutdown.

        A message that cannot be JSON-encoded is logged and dropped; a failure
        of the websocket send itself ends the loop and requests shutdown.
        Every dequeued item is marked done so ``out_queue.join()`` works.
        """
        try:
            while not self.shutdown_event.is_set():
                # Poll with a timeout so the shutdown flag is re-checked regularly.
                try:
                    msg = await asyncio.wait_for(self.out_queue.get(), 0.5)
                except asyncio.TimeoutError:
                    continue
                try:
                    payload = json.dumps(msg)
                except (TypeError, ValueError) as e:
                    # One bad message must not kill the whole sending loop.
                    self.out_queue.task_done()
                    print(f"[AudioLoop] Dropping unserializable outgoing message: {e}")
                    continue
                try:
                    await self.ws.send(payload)
                finally:
                    self.out_queue.task_done()
        except asyncio.CancelledError:
            print("[AudioLoop] send_realtime task cancelled")
        except Exception as e:
            print(f"[AudioLoop] Error in send_realtime: {e}")
            traceback.print_exc()
            # The connection is unusable; wake run() so it can clean up.
            self.shutdown_event.set()
        finally:
            print("[AudioLoop] send_realtime task ended")

    async def receive_audio(self):
        """Read and process Gemini messages until shutdown.

        Errors while processing a single message are logged and skipped.
        Errors raised by ``recv()`` itself (other than the polling timeout),
        e.g. the connection closing, end the loop and request shutdown instead
        of being retried forever.
        """
        try:
            while not self.shutdown_event.is_set():
                # Poll with a timeout so the shutdown flag is re-checked regularly.
                # Only the timeout is handled here; any other recv() error
                # propagates to the outer handler and ends the task.
                try:
                    raw_response = await asyncio.wait_for(self.ws.recv(), 0.5)
                except asyncio.TimeoutError:
                    continue
                try:
                    await self._handle_response(raw_response)
                except Exception as e:
                    print(f"[AudioLoop] Error processing message: {e}")
                    traceback.print_exc()
        except asyncio.CancelledError:
            print("[AudioLoop] receive_audio task cancelled")
        except Exception as e:
            print(f"[AudioLoop] Error in receive_audio: {e}")
            traceback.print_exc()
            # The connection is unusable; wake run() so it can clean up.
            self.shutdown_event.set()
        finally:
            print("[AudioLoop] receive_audio task ended")

    async def _handle_response(self, raw_response):
        """Decode one server message and act on any audio or tool call in it."""
        response = json.loads(raw_response)
        # Print for debugging
        print(f"[AudioLoop] Received response: {json.dumps(response)[:500]}...")
        # Audio extraction problems are logged but must not block tool handling.
        try:
            await self._queue_audio(response)
        except Exception as e:
            print(f"[AudioLoop] Error extracting audio: {e}")
        tool_call = response.get("toolCall")
        if tool_call:
            await self._answer_tool_call(tool_call)

    async def _queue_audio(self, response):
        """Base64-decode each inline data part of a modelTurn onto audio_in_queue."""
        if ("serverContent" in response
                and "modelTurn" in response["serverContent"]
                and "parts" in response["serverContent"]["modelTurn"]):
            for part in response["serverContent"]["modelTurn"]["parts"]:
                if "inlineData" in part and "data" in part["inlineData"]:
                    pcm_data = base64.b64decode(part["inlineData"]["data"])
                    await self.audio_in_queue.put(pcm_data)

    async def _answer_tool_call(self, tool_call):
        """Reply with a placeholder "ok" result to every function call in tool_call."""
        print(f"[AudioLoop] Tool call received: {tool_call}")
        # Send simple OK response for now, one message per function call.
        for fc in tool_call.get("functionCalls", []):
            resp_msg = {
                "tool_response": {
                    "function_responses": [{
                        "id": fc.get("id", ""),
                        "name": fc.get("name", ""),
                        "response": {"result": {"string_value": "ok"}},
                    }]
                }
            }
            await self.ws.send(json.dumps(resp_msg))

    async def _close_ws(self):
        """Close the websocket if one was opened, logging (not raising) failures."""
        if self.ws is None:
            return
        try:
            await self.ws.close()
            print("[AudioLoop] Closed WebSocket connection")
        except Exception as e:
            print(f"[AudioLoop] Error closing Gemini connection: {e}")

    async def run(self, api_key=None):
        """Main entry point: connect to Gemini and pump messages until shutdown.

        Args:
            api_key: Optional API key forwarded to ``startup()``; when omitted
                the GOOGLE_API_KEY environment variable is used.
        """
        send_task = None
        receive_task = None
        try:
            # Initialize the connection with Gemini
            await self.startup(api_key)
            # Create tasks for sending and receiving data
            send_task = asyncio.create_task(self.send_realtime())
            receive_task = asyncio.create_task(self.receive_audio())
            # Wait until stop() is called or a worker hits a fatal error
            await self.shutdown_event.wait()
        except asyncio.CancelledError:
            print("[AudioLoop] run task cancelled")
        except Exception as e:
            print(f"[AudioLoop] Error in run: {e}")
            traceback.print_exc()
        finally:
            # Runs on every exit path, including startup() failing after the
            # socket was opened and run() itself being cancelled, so neither
            # the workers nor the connection outlive this call.
            tasks = [t for t in (send_task, receive_task) if t is not None]
            for task in tasks:
                task.cancel()
            if tasks:
                await asyncio.gather(*tasks, return_exceptions=True)
            await self._close_ws()
            print("[AudioLoop] run task ended")

    def stop(self):
        """Signal run() and the send/receive loops to stop; run() then closes the socket."""
        self.shutdown_event.set()