File size: 7,124 Bytes
76fb0a8
cc7c705
 
 
 
 
 
 
4fac3a3
cc7c705
4ee8f0b
4fac3a3
 
cc7c705
 
 
 
 
 
 
 
4fac3a3
 
76fb0a8
 
cc7c705
 
 
4fac3a3
cc7c705
4fac3a3
 
 
 
 
 
 
 
76fb0a8
4fac3a3
 
76fb0a8
4fac3a3
 
cc7c705
 
 
 
 
 
 
 
 
4fac3a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc7c705
 
76fb0a8
4fac3a3
 
 
 
 
 
76fb0a8
 
4fac3a3
76fb0a8
4fac3a3
 
76fb0a8
 
 
 
4fac3a3
 
 
 
 
 
76fb0a8
 
4fac3a3
76fb0a8
4fac3a3
76fb0a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fac3a3
 
 
 
 
 
 
 
 
 
 
 
 
cc7c705
 
 
 
4fac3a3
76fb0a8
4fac3a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b731f8
4fac3a3
 
 
 
cc7c705
4fac3a3
cc7c705
4fac3a3
cc7c705
4fac3a3
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# basic_handler.py
import asyncio
import base64
import json
import os
import traceback
from websockets.asyncio.client import connect

# --- Gemini Live API connection settings -----------------------------------
# Endpoint host for the Generative Language API.
host = "generativelanguage.googleapis.com"
# Live model to open the session with (you can change this to a different model if needed).
model = "gemini-2.0-flash-live-001"
# Fallback API key, captured once from the environment when this module is imported.
api_key_env = os.environ.get("GOOGLE_API_KEY", "")
# WebSocket URL of the bidirectional streaming RPC. The literal "{api_key}"
# placeholder is filled in later with str.format(api_key=...).
uri_template = (
    "wss://" + host
    + "/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent"
    + "?key={api_key}"
)

class AudioLoop:
    """Streaming session with the Gemini Live API over a single WebSocket.

    Messages put on ``out_queue`` are JSON-encoded and sent to Gemini.
    ``inlineData`` payloads received from Gemini are base64-decoded and put on
    ``audio_in_queue`` as bytes. The session started by ``run()`` ends when
    ``stop()`` is called or when either I/O task exits (for example because
    the connection dropped).
    """

    def __init__(self):
        # Open WebSocket connection; None until startup() has connected.
        self.ws = None
        # Queue for messages to be sent *to* Gemini (each item is passed to json.dumps)
        self.out_queue = asyncio.Queue()
        # Queue for audio bytes received *from* Gemini (decoded inlineData;
        # presumably PCM -- the mimeType is not checked)
        self.audio_in_queue = asyncio.Queue()
        # Flag to signal shutdown; set by stop() or when an I/O task exits
        self.shutdown_event = asyncio.Event()

    async def startup(self, api_key=None):
        """Connect to Gemini and perform the setup handshake.

        Sends a minimal ``setup`` message naming the model, then waits for
        and prints the server's first reply.

        Args:
            api_key: API key to use (overrides the environment variable).

        Raises:
            ValueError: If no API key is provided and GOOGLE_API_KEY is unset.
        """
        # Use the provided key, else the import-time environment value; re-read
        # the environment so a key exported after import is still honoured.
        key = api_key or api_key_env or os.environ.get("GOOGLE_API_KEY", "")
        if not key:
            raise ValueError("No API key provided and GOOGLE_API_KEY environment variable not set")

        uri = uri_template.format(api_key=key)
        self.ws = await connect(uri, additional_headers={"Content-Type": "application/json"})

        # Absolutely minimal setup message
        setup_msg = {
            "setup": {
                "model": f"models/{model}"
            }
        }

        await self.ws.send(json.dumps(setup_msg))

        raw_response = await self.ws.recv()
        setup_response = json.loads(raw_response)
        print("[AudioLoop] Setup response from Gemini:", setup_response)

    async def send_realtime(self):
        """Forward messages from out_queue to Gemini until shutdown.

        Each queued message is JSON-encoded and sent as one WebSocket message.
        A send failure ends the task. On exit, for any reason, the task sets
        shutdown_event so run() does not keep waiting on a dead connection.
        """
        try:
            while not self.shutdown_event.is_set():
                # Poll with a short timeout so shutdown_event is re-checked even
                # when nothing is queued.
                try:
                    msg = await asyncio.wait_for(self.out_queue.get(), 0.5)
                except asyncio.TimeoutError:
                    continue
                await self.ws.send(json.dumps(msg))
        except asyncio.CancelledError:
            # NOTE(review): cancellation is swallowed (kept for compatibility).
            print("[AudioLoop] send_realtime task cancelled")
        except Exception as e:
            print(f"[AudioLoop] Error in send_realtime: {e}")
            traceback.print_exc()
        finally:
            self.shutdown_event.set()
            print("[AudioLoop] send_realtime task ended")

    async def receive_audio(self):
        """Read messages from Gemini until shutdown or the connection fails.

        Errors raised while *handling* one message are logged, and the loop
        moves on to the next message. Errors raised by ``recv()`` itself (for
        example a closed socket) end the task, because retrying would fail
        again immediately and spin forever. On exit the task sets
        shutdown_event so run() can clean up.
        """
        try:
            while not self.shutdown_event.is_set():
                try:
                    raw_response = await asyncio.wait_for(self.ws.recv(), 0.5)
                except asyncio.TimeoutError:
                    # No message received, re-check shutdown_event
                    continue
                try:
                    await self._handle_message(raw_response)
                except Exception as e:
                    print(f"[AudioLoop] Error processing message: {e}")
                    traceback.print_exc()
        except asyncio.CancelledError:
            # NOTE(review): cancellation is swallowed (kept for compatibility).
            print("[AudioLoop] receive_audio task cancelled")
        except Exception as e:
            print(f"[AudioLoop] Error in receive_audio: {e}")
            traceback.print_exc()
        finally:
            self.shutdown_event.set()
            print("[AudioLoop] receive_audio task ended")

    async def _handle_message(self, raw_response):
        """Decode one server message: queue inline audio and answer tool calls."""
        response = json.loads(raw_response)

        # Print for debugging
        print(f"[AudioLoop] Received response: {json.dumps(response)[:500]}...")

        # Process audio data if present. Extraction errors are logged and do
        # not prevent the tool-call handling below.
        try:
            # Check for inline data under serverContent.modelTurn.parts
            if ("serverContent" in response and
                    "modelTurn" in response["serverContent"] and
                    "parts" in response["serverContent"]["modelTurn"]):
                for part in response["serverContent"]["modelTurn"]["parts"]:
                    if "inlineData" in part and "data" in part["inlineData"]:
                        pcm_data = base64.b64decode(part["inlineData"]["data"])
                        await self.audio_in_queue.put(pcm_data)
        except Exception as e:
            print(f"[AudioLoop] Error extracting audio: {e}")

        # Handle tool calls if present: reply with a fixed "ok" result for
        # every function call, one tool_response message per call.
        tool_call = response.pop('toolCall', None)
        if tool_call:
            print(f"[AudioLoop] Tool call received: {tool_call}")
            for fc in tool_call.get('functionCalls', []):
                resp_msg = {
                    'tool_response': {
                        'function_responses': [{
                            'id': fc.get('id', ''),
                            'name': fc.get('name', ''),
                            'response': {'result': {'string_value': 'ok'}}
                        }]
                    }
                }
                await self.ws.send(json.dumps(resp_msg))

    async def _close_ws(self):
        """Close the WebSocket if one was opened; log failures instead of raising."""
        if self.ws is None:
            return
        try:
            await self.ws.close()
            print("[AudioLoop] Closed WebSocket connection")
        except Exception as e:
            print(f"[AudioLoop] Error closing Gemini connection: {e}")

    async def run(self, api_key=None):
        """Main entry point: connect to Gemini, run send/receive tasks, clean up.

        Returns once shutdown_event is set. That happens via stop(), or when
        either I/O task exits. Errors are logged, not raised.

        Args:
            api_key: Optional API key forwarded to startup(). When None,
                startup() falls back to the environment as before.
        """
        tasks = []
        try:
            try:
                # Initialize the connection with Gemini
                await self.startup(api_key=api_key)

                # Create tasks for sending and receiving data
                tasks = [
                    asyncio.create_task(self.send_realtime()),
                    asyncio.create_task(self.receive_audio()),
                ]

                # Wait for shutdown event
                await self.shutdown_event.wait()
            finally:
                # Always cancel the I/O tasks and close the socket, even if
                # startup failed half-way or run() itself was cancelled, so no
                # task is left running against a closed connection.
                for task in tasks:
                    task.cancel()
                if tasks:
                    await asyncio.gather(*tasks, return_exceptions=True)
                await self._close_ws()
        except asyncio.CancelledError:
            # NOTE(review): cancellation is swallowed (kept for compatibility).
            print("[AudioLoop] run task cancelled")
        except Exception as e:
            print(f"[AudioLoop] Error in run: {e}")
            traceback.print_exc()
        finally:
            print("[AudioLoop] run task ended")

    def stop(self):
        """Signal tasks to stop; run() then cancels them and closes the socket."""
        self.shutdown_event.set()