Spaces:
Runtime error
Runtime error
| # webapp.py | |
| import asyncio | |
| import base64 | |
| import json | |
| import os | |
| import tempfile | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import uvicorn | |
| from handler import AudioLoop # Import your AudioLoop from above | |
# Directory that holds this module; it doubles as the static-file root below.
current_dir = os.path.dirname(os.path.realpath(__file__))

app = FastAPI()

# Expose the files that live next to this module under the /web_ui prefix.
# NOTE(review): this serves every file in the module's own directory, which
# includes the Python sources themselves - confirm that is intended.
app.mount("/web_ui", StaticFiles(directory=current_dir), name="web_ui")
async def get_index():
    """Return the ``index.html`` located in ``current_dir`` as an HTML response.

    The file is read as UTF-8 text on every call.

    NOTE(review): no route decorator is visible on this function in the
    source as received - confirm how it is registered with ``app``.
    """
    page_path = os.path.join(current_dir, "index.html")
    with open(page_path, "r", encoding="utf-8") as page_file:
        return HTMLResponse(content=page_file.read())
async def upload_json_file(file: UploadFile = File(...)):
    """Validate an uploaded JSON file and persist it to a temporary file.

    Args:
        file: The uploaded file; its full content is read into memory.

    Returns:
        On success, a dict with ``message``, the ``file_path`` of the saved
        temporary ``.json`` file, and the parsed ``content``.
        A 400 ``JSONResponse`` when the upload is not valid JSON.
        A 500 ``JSONResponse`` when any other error occurs.

    NOTE(review): no route decorator is visible on this function in the
    source as received - confirm how it is registered with ``app``.
    """
    try:
        content = await file.read()

        # Validate before touching the disk so a rejected upload does not
        # leave an orphaned temp file behind.
        try:
            json_content = json.loads(content)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # json.loads decodes bytes itself; undecodable bytes raise
            # UnicodeDecodeError, which is still a client error, not a 500.
            return JSONResponse(status_code=400, content={"message": "Invalid JSON file"})

        # delete=False keeps the file after the handle closes, because its
        # path is handed back to the caller for later retrieval.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as temp_file:
            temp_file.write(content)
            file_path = temp_file.name

        return {
            "message": "JSON file uploaded successfully",
            "file_path": file_path,
            "content": json_content,
        }
    except Exception as e:
        # Endpoint boundary: report the failure to the client instead of
        # letting the exception escape.
        return JSONResponse(status_code=500, content={"message": f"Error uploading file: {str(e)}"})
def _user_turn_message(text):
    """Wrap *text* as one completed user turn in the ``client_content`` envelope."""
    return {
        "client_content": {
            "turn_complete": True,
            "turns": [
                {
                    "role": "user",
                    "parts": [
                        {"text": text}
                    ]
                }
            ]
        }
    }


async def websocket_endpoint(websocket: WebSocket):
    """Bridge one WebSocket client to its own ``AudioLoop`` instance.

    Two pumps run concurrently for the lifetime of the connection:

    * client -> ``audio_loop.out_queue``: JSON text frames whose ``type`` is
      ``"audio"`` (base64 payload, optional ``seq`` for ordering), ``"text"``
      or ``"json"`` are converted and queued.
    * ``audio_loop.audio_in_queue`` -> client: each item is base64-encoded
      and sent as ``{"type": "audio", "payload": ...}``.

    As soon as either pump finishes (disconnect or error) the other pump and
    the ``AudioLoop`` task are cancelled so the handler always terminates.

    Args:
        websocket: The incoming connection; it is accepted here.

    NOTE(review): no route decorator is visible on this function in the
    source as received - confirm how it is registered with ``app``.
    """
    await websocket.accept()
    print("[websocket_endpoint] Client connected.")

    # One AudioLoop per client, plus the state used to re-order audio chunks.
    audio_loop = AudioLoop()
    audio_ordering_buffer = {}
    expected_audio_seq = 0

    loop_task = asyncio.create_task(audio_loop.run())
    print("[websocket_endpoint] Started new AudioLoop for client")

    async def from_client_to_gemini():
        """Handles incoming messages from the client and forwards them to Gemini."""
        # Only the counter is rebound; the buffer dict is mutated in place.
        nonlocal expected_audio_seq
        try:
            while True:
                data = await websocket.receive_text()
                msg = json.loads(data)
                msg_type = msg.get("type")

                if msg_type == "audio":
                    # Decode then re-encode the payload before forwarding it.
                    raw_pcm = base64.b64decode(msg["payload"])
                    forward_msg = {
                        "realtime_input": {
                            "media_chunks": [
                                {
                                    "data": base64.b64encode(raw_pcm).decode(),
                                    "mime_type": "audio/pcm"
                                }
                            ]
                        }
                    }
                    seq = msg.get("seq")
                    if seq is not None:
                        # Park the chunk, then flush every chunk that is now
                        # contiguous with the next expected sequence number.
                        # NOTE(review): a sequence number that never arrives
                        # stalls forwarding and lets this buffer grow.
                        audio_ordering_buffer[seq] = forward_msg
                        while expected_audio_seq in audio_ordering_buffer:
                            msg_to_forward = audio_ordering_buffer.pop(expected_audio_seq)
                            await audio_loop.out_queue.put(msg_to_forward)
                            expected_audio_seq += 1
                    else:
                        # No sequence number: forward immediately.
                        await audio_loop.out_queue.put(forward_msg)

                elif msg_type == "text":
                    user_text = msg.get("content", "")
                    print("[from_client_to_gemini] Forwarding user text to Gemini:", user_text)
                    await audio_loop.out_queue.put(_user_turn_message(user_text))

                elif msg_type == "json":
                    json_data = msg.get("content", {})
                    print("[from_client_to_gemini] Forwarding JSON data to Gemini:", json_data)
                    # Send the JSON as text, prefixed with an instruction prompt.
                    json_prompt = f"The user has shared the following JSON data with you. Please analyze it and respond appropriately:\n\n{json.dumps(json_data, indent=2)}"
                    await audio_loop.out_queue.put(_user_turn_message(json_prompt))

                else:
                    print("[from_client_to_gemini] Unknown message type:", msg_type)
        except WebSocketDisconnect:
            # The AudioLoop task is cancelled in the outer ``finally`` block.
            print("[from_client_to_gemini] Client disconnected.")
        except Exception as e:
            print("[from_client_to_gemini] Error:", e)

    async def from_gemini_to_client():
        """Reads PCM audio from Gemini and sends it back to the client."""
        try:
            while True:
                pcm_data = await audio_loop.audio_in_queue.get()
                out_msg = {
                    "type": "audio",
                    "payload": base64.b64encode(pcm_data).decode()
                }
                print("[from_gemini_to_client] Sending audio chunk to client. Size:", len(pcm_data))
                await websocket.send_text(json.dumps(out_msg))
        except WebSocketDisconnect:
            print("[from_gemini_to_client] Client disconnected.")
            # NOTE(review): AudioLoop is defined in another module; confirm
            # that it actually provides stop().
            audio_loop.stop()
        except Exception as e:
            print("[from_gemini_to_client] Error:", e)

    # Both pumps swallow their own errors and simply return, so waiting for
    # *both* (as asyncio.gather does) would block forever on the queue read
    # after a disconnect. Wait for the first one to finish instead.
    pump_tasks = [
        asyncio.create_task(from_client_to_gemini()),
        asyncio.create_task(from_gemini_to_client()),
    ]
    try:
        await asyncio.wait(pump_tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        print("[websocket_endpoint] WebSocket handler finished.")
        # Stop whichever pump is still running and collect both results so
        # no task is left pending or unretrieved.
        for pump in pump_tasks:
            pump.cancel()
        await asyncio.gather(*pump_tasks, return_exceptions=True)

        # Clean up the AudioLoop when the client disconnects.
        loop_task.cancel()
        try:
            await loop_task
        except asyncio.CancelledError:
            pass
        print("[websocket_endpoint] Cleaned up AudioLoop for client")
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces at port 7860 with
    # auto-reload switched on.
    uvicorn.run(
        "webapp:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )