"""FastAPI bridge between a browser WebSocket and the Gemini Live API.

The ``/ws`` endpoint accepts binary audio frames from the client, forwards
them to a Gemini Live session, and streams the model's audio (binary frames)
and text/transcripts (text frames) back over the same WebSocket.  The ``/``
route serves the built frontend from ``static/`` when it is present.
"""

import asyncio
import logging
import os
import sys
import traceback  # noqa: F401  (retained: file-level import, may be used elsewhere)

from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from google import genai
from google.api_core import exceptions as google_exceptions

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv()

# asyncio.TaskGroup only exists from Python 3.11; fall back to the backports.
if sys.version_info < (3, 11, 0):
    import exceptiongroup
    import taskgroup

    asyncio.TaskGroup = taskgroup.TaskGroup
    asyncio.ExceptionGroup = exceptiongroup.ExceptionGroup

# NOTE(review): the removed pyaudio client used 16-bit mono PCM, 16 kHz for
# input and 24 kHz for output, in 1024-sample chunks. The browser client is
# presumably expected to match; confirm against the frontend.

# Load configuration from environment variables
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
MODEL = os.environ.get("GEMINI_MODEL", "models/gemini-2.0-flash-live-001")

# A WebSocket close frame is a control frame (payload <= 125 bytes): two bytes
# carry the status code, leaving at most 123 bytes for the UTF-8 reason.
_MAX_CLOSE_REASON_BYTES = 123

# Configure the client with the API key.  ``client`` stays ``None`` when
# construction fails so that the module still imports (e.g. during an image
# build without secrets) and the failure is reported per connection instead.
client = None
try:
    if not GOOGLE_API_KEY or GOOGLE_API_KEY == "YOUR_API_KEY_HERE":
        # In HF Spaces the key may be injected later via secrets, so only warn.
        logger.warning("GOOGLE_API_KEY environment variable not set or is a placeholder.")
    client = genai.Client(api_key=GOOGLE_API_KEY)
except (KeyError, ValueError) as e:
    # Deliberately no sys.exit(): let it fail at runtime so the build passes.
    logger.critical("Error: %s. Please set the GOOGLE_API_KEY environment variable.", e)

CONFIG = {
    "response_modalities": ["AUDIO"],
    "output_audio_transcription": {},
    "generation_config": {
        "temperature": 1.0,
    },
}

app = FastAPI()

# Mount static files.  The frontend build is expected to be copied into the
# 'static' directory in the container.  Check the directory that is actually
# mounted: StaticFiles raises at construction time if it does not exist.
if os.path.isdir("static/assets"):
    app.mount("/assets", StaticFiles(directory="static/assets"), name="assets")


@app.get("/")
async def get():
    """Serve the frontend's ``index.html``, or a placeholder page if it is missing."""
    if os.path.exists("static/index.html"):
        return FileResponse("static/index.html")
    return HTMLResponse("\n\nFrontend not found. Please build the frontend.\n\n")


class AudioLoop:
    """Relays one WebSocket client's audio to and from a Gemini Live session."""

    def __init__(self, websocket: WebSocket):
        self.websocket = websocket
        # Set by run() once the Gemini Live connection is established.
        self.session = None

    async def run(self):
        """Open a Gemini Live session and pump data in both directions.

        Returns when either direction stops (client disconnect or an error in
        one of the pumps).  API or unexpected errors while connecting are
        logged and reported to the client with close code 1011.

        Raises:
            asyncio.CancelledError: re-raised after logging so that the caller's
                cancellation is not silently swallowed.
        """
        if client is None:
            logger.error("Gemini client is not configured; rejecting connection.")
            await self._close_websocket(1011, "Internal Server Error")
            return

        try:
            async with client.aio.live.connect(model=MODEL, config=CONFIG) as session:
                self.session = session
                logger.info("Gemini Live API session started.")
                async with asyncio.TaskGroup() as tg:
                    receiver = tg.create_task(self.receive_from_gemini())
                    sender = tg.create_task(self.send_to_gemini())
                    # Each pump swallows its own errors and simply returns, so
                    # the group would otherwise keep waiting on the surviving
                    # pump (and keep the Gemini session open) after the client
                    # has gone.  Cancelling an already finished task is a no-op.
                    receiver.add_done_callback(lambda _task: sender.cancel())
                    sender.add_done_callback(lambda _task: receiver.cancel())
        except asyncio.CancelledError:
            logger.info("Audio loop cancelled.")
            raise
        except google_exceptions.GoogleAPICallError as e:
            logger.error("Google API call error in audio loop: %s", e)
            await self._close_websocket(1011, f"Google API Error: {e}")
        except Exception as e:
            logger.exception("An error occurred in the audio loop: %s", e)
            await self._close_websocket(1011, "Internal Server Error")

    async def _close_websocket(self, code, reason):
        """Best-effort close that never raises and respects the reason size limit."""
        trimmed = (
            reason.encode("utf-8")[:_MAX_CLOSE_REASON_BYTES].decode("utf-8", errors="ignore")
        )
        try:
            await self.websocket.close(code=code, reason=trimmed)
        except Exception:
            # The peer may already be gone; there is nothing left to notify.
            logger.debug("Could not send WebSocket close frame.", exc_info=True)

    async def send_to_gemini(self):
        """Receives audio from the WebSocket and sends it to the Gemini API."""
        session = self.session
        if session is None:
            logger.error("send_to_gemini started without a Gemini session.")
            return

        while True:
            try:
                data = await self.websocket.receive_bytes()
                await session.send(input={"data": data, "mime_type": "audio/pcm"})
            except WebSocketDisconnect:
                logger.info("Client disconnected from WebSocket.")
                break
            except Exception as e:
                logger.error("Error receiving from websocket or sending to Gemini: %s", e)
                break

    @staticmethod
    def _split_response(response):
        """Return ``(audio_chunks, texts)`` extracted from one Gemini message.

        Audio comes from the ``inline_data`` of model-turn parts.  Text is
        gathered, in order, from part texts, the output and input
        transcriptions, ``output_text`` and ``text``.  Empty and non-string
        entries are dropped so a stray value cannot break the receive loop.
        """
        audio_chunks = []
        texts = []

        server_content = getattr(response, "server_content", None)
        model_turn = getattr(server_content, "model_turn", None)
        for part in getattr(model_turn, "parts", None) or []:
            if chunk := getattr(getattr(part, "inline_data", None), "data", None):
                audio_chunks.append(chunk)
            if part_text := getattr(part, "text", None):
                texts.append(part_text)

        for field in ("output_transcription", "input_transcription"):
            transcription = getattr(server_content, field, None)
            if trans_text := getattr(transcription, "text", None):
                texts.append(trans_text)

        if output_text := getattr(response, "output_text", None):
            if isinstance(output_text, (list, tuple)):
                texts.extend(output_text)
            else:
                texts.append(output_text)

        # NOTE(review): response.text may repeat the part texts collected
        # above, which would send duplicates to the client; confirm with SDK.
        if response_text := getattr(response, "text", None):
            texts.append(response_text)

        return audio_chunks, [t for t in texts if isinstance(t, str) and t]

    async def receive_from_gemini(self):
        """Receives audio and text from the Gemini API and forwards it to the WebSocket."""
        session = self.session
        if session is None:
            # Returning avoids spinning in a loop that never awaits.
            logger.error("receive_from_gemini started without a Gemini session.")
            return

        while True:
            try:
                async for response in session.receive():
                    # Handle audio data directly from response.data
                    if data := response.data:
                        await self.websocket.send_bytes(data)
                        continue

                    # Nested audio first, then any text/transcripts.
                    audio_chunks, texts = self._split_response(response)
                    for chunk in audio_chunks:
                        await self.websocket.send_bytes(chunk)

                    for text_chunk in texts:
                        normalized = text_chunk.replace("\r", "").replace("\n", " ")
                        stripped = normalized.strip()
                        if stripped:
                            logger.info("Received text: %s", stripped)
                            await self.websocket.send_text(normalized)
            except WebSocketDisconnect:
                logger.info("Client disconnected. Stopping receive loop.")
                break
            except Exception as e:
                logger.error("Error receiving from Gemini or sending to websocket: %s", e)
                break


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a WebSocket client and run its audio loop until it finishes."""
    await websocket.accept()
    logger.info("WebSocket connection accepted.")
    audio_loop = AudioLoop(websocket)
    try:
        await audio_loop.run()
    except WebSocketDisconnect:
        logger.info("Client disconnected.")
    except Exception as e:
        logger.exception("Error in websocket endpoint: %s", e)
    finally:
        logger.info("WebSocket connection closed.")