import asyncio
import base64
import json
import os

import websockets
from google import genai

# Load API key from environment (do NOT overwrite it).
# Set GOOGLE_API_KEY in Hugging Face Space Secrets. The client below is built
# without an explicit key, so it presumably reads GOOGLE_API_KEY from the
# environment itself.
if not os.getenv("GOOGLE_API_KEY"):
    raise RuntimeError("GOOGLE_API_KEY is not set. Add it in Hugging Face Space Secrets.")

MODEL = "gemini-2.0-flash-exp"  # use your model ID

# Websocket bind address. Defaults match the previous hard-coded values;
# override via environment (e.g. WS_HOST=0.0.0.0) when the server must be
# reachable from outside the container.
HOST = os.getenv("WS_HOST", "localhost")
PORT = int(os.getenv("WS_PORT", "9083"))

# Media chunk types forwarded to Gemini; chunks of any other type are ignored.
SUPPORTED_MIME_TYPES = ("audio/pcm", "image/jpeg")

# Always overrides any system_instruction supplied in the client's setup message.
SYSTEM_INSTRUCTION = """You are a helpful assistant for screen sharing sessions. Your role is to:
1) Analyze and describe the content being shared on screen
2) Answer questions about the shared content
3) Provide relevant information and context about what's being shown
4) Assist with technical issues related to screen sharing
5) Maintain a professional and helpful tone.

Focus on being concise and clear in your responses."""

client = genai.Client(
    http_options={
        'api_version': 'v1alpha',
    }
)


async def _send_to_gemini(client_websocket, session) -> None:
    """Forward realtime media chunks from the websocket client to Gemini.

    Each client message is parsed as JSON; when it contains
    ``{"realtime_input": {"media_chunks": [{"mime_type": ..., "data": ...}, ...]}}``
    every chunk whose mime_type is in SUPPORTED_MIME_TYPES is sent to the
    Live session. A malformed message is reported and skipped rather than
    ending the relay. Returns when the client websocket stops yielding
    messages (i.e. the connection closes).
    """
    try:
        async for message in client_websocket:
            try:
                data = json.loads(message)
                if "realtime_input" in data:
                    for chunk in data["realtime_input"]["media_chunks"]:
                        if chunk["mime_type"] in SUPPORTED_MIME_TYPES:
                            # Pass ``input=`` by keyword: newer google-genai
                            # releases make AsyncSession.send() keyword-only.
                            await session.send(
                                input={"mime_type": chunk["mime_type"], "data": chunk["data"]}
                            )
            except Exception as e:
                print(f"Error sending to Gemini: {e}")
        print("Client connection closed (send)")
    except Exception as e:
        print(f"Error sending to Gemini: {e}")
    finally:
        print("send_to_gemini closed")


async def _receive_from_gemini(client_websocket, session) -> None:
    """Relay Gemini responses back to the websocket client.

    Text parts are sent as ``{"text": ...}``; inline-data parts are
    base64-encoded and sent as ``{"audio": ...}``. ``session.receive()`` is
    re-entered in a loop because it can finish between model turns. The
    relay stops on the first error, including the client connection closing.
    """
    try:
        while True:
            try:
                print("receiving from gemini")
                async for response in session.receive():
                    if response.server_content is None:
                        print(f'Unhandled server message! - {response}')
                        continue

                    model_turn = response.server_content.model_turn
                    if model_turn:
                        # ``parts`` may be None; treat that as "no parts" instead of
                        # letting a TypeError tear down the whole session.
                        for part in model_turn.parts or []:
                            if getattr(part, 'text', None) is not None:
                                await client_websocket.send(json.dumps({"text": part.text}))
                            elif getattr(part, 'inline_data', None) is not None:
                                print("audio mime_type:", part.inline_data.mime_type)
                                base64_audio = base64.b64encode(part.inline_data.data).decode('utf-8')
                                await client_websocket.send(json.dumps({"audio": base64_audio}))
                                print("audio received")

                    if response.server_content.turn_complete:
                        print('\n')
            except websockets.exceptions.ConnectionClosedOK:
                print("Client connection closed normally (receive)")
                break
            except Exception as e:
                print(f"Error receiving from Gemini: {e}")
                break
    finally:
        print("Gemini connection closed (receive)")


async def gemini_session_handler(client_websocket: websockets.WebSocketServerProtocol):
    """Bridge one websocket client connection to a Gemini Live session.

    The first client message must be JSON; its ``"setup"`` object (or ``{}``)
    is used as the Live connection config, with ``system_instruction`` forced
    to SYSTEM_INSTRUCTION. Media then flows in both directions until either
    relay finishes. The other relay is then cancelled, so leaving the
    ``connect`` context closes the Gemini session promptly. Errors are printed,
    not raised.
    """
    try:
        config_message = await client_websocket.recv()
        config_data = json.loads(config_message)
        config = config_data.get("setup", {})
        config["system_instruction"] = SYSTEM_INSTRUCTION

        async with client.aio.live.connect(model=MODEL, config=config) as session:
            print("Connected to Gemini API")
            tasks = [
                asyncio.create_task(_send_to_gemini(client_websocket, session)),
                asyncio.create_task(_receive_from_gemini(client_websocket, session)),
            ]
            try:
                # Not gather(): after the client disconnects the send relay ends,
                # but the receive relay can stay blocked waiting on Gemini, which
                # would keep this handler and the Live session open indefinitely.
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for task in tasks:
                    task.cancel()
                # Reap the tasks so cancellations/errors are not left unretrieved.
                await asyncio.gather(*tasks, return_exceptions=True)
    except Exception as e:
        print(f"Error in Gemini session: {e}")
    finally:
        print("Gemini session closed.")


async def main() -> None:
    """Serve gemini_session_handler on HOST:PORT until the process is stopped."""
    async with websockets.serve(gemini_session_handler, HOST, PORT):
        print(f"Running websocket server {HOST}:{PORT}...")
        await asyncio.Future()  # never completes: keeps the server running


if __name__ == "__main__":
    asyncio.run(main())