geminilivejson / webapp.py
nihalaninihal's picture
Update webapp.py
72c446d verified
# webapp.py
import asyncio
import base64
import json
import os
import tempfile
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
import uvicorn
from handler import AudioLoop # Import your AudioLoop from above
# Directory holding this module (symlinks resolved). index.html and the
# static web UI files are looked up here.
current_dir = os.path.dirname(os.path.realpath(__file__))

# The ASGI application; the routes below are registered on it.
app = FastAPI()

# Publish current_dir as static files under the /web_ui prefix.
# NOTE(review): StaticFiles(directory=current_dir) exposes *every* file in
# that directory, including this module's own source. Consider pointing it at
# a dedicated assets folder instead.
app.mount(
    path="/web_ui",
    app=StaticFiles(directory=current_dir),
    name="web_ui",
)
@app.get("/")
async def get_index():
    """Serve ``index.html`` from ``current_dir`` as an HTML response."""
    page_path = os.path.join(current_dir, "index.html")
    with open(page_path, encoding="utf-8") as page_file:
        markup = page_file.read()
    return HTMLResponse(content=markup)
@app.post("/upload_json")
async def upload_json_file(file: UploadFile = File(...)):
    """Validate an uploaded JSON file, persist a copy, and echo it back.

    Returns:
        On success, a dict with ``message``, ``file_path`` (the server-side
        temp file holding the upload) and ``content`` (the parsed JSON).
        A 400 ``JSONResponse`` if the upload is not valid JSON (including
        bytes json cannot decode), or a 500 ``JSONResponse`` for any other
        failure.
    """
    try:
        content = await file.read()
        # Validate *before* touching the disk so that rejected uploads never
        # leave an orphaned temp file behind.
        # json.loads() decodes bytes itself and raises UnicodeDecodeError
        # (not JSONDecodeError) on undecodable input; that is still a bad
        # upload, so it is reported as 400 rather than 500.
        try:
            json_content = json.loads(content)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return JSONResponse(status_code=400, content={"message": "Invalid JSON file"})
        # delete=False keeps the file after the `with` closes it, so its path
        # can be returned. NOTE(review): nothing ever removes these files, and
        # the response discloses a server filesystem path. A production app
        # should use a managed store (database / in-memory) instead.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as temp_file:
            temp_file.write(content)
            file_path = temp_file.name
        return {"message": "JSON file uploaded successfully", "file_path": file_path, "content": json_content}
    except Exception as e:
        # Request boundary: report any unexpected failure to the client.
        return JSONResponse(status_code=500, content={"message": f"Error uploading file: {str(e)}"})
def _user_turn_message(text):
    """Wrap *text* as one complete user turn (``client_content`` envelope).

    Shared by the "text" and "json" client message types before they are put
    on ``AudioLoop.out_queue``.
    """
    return {
        "client_content": {
            "turn_complete": True,
            "turns": [
                {
                    "role": "user",
                    "parts": [
                        {"text": text}
                    ],
                }
            ],
        }
    }


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Bridge one WebSocket client to its own ``AudioLoop`` session.

    Client -> Gemini: JSON text frames of type ``audio`` (base64 PCM with an
    optional integer ``seq`` used for re-ordering), ``text`` or ``json`` are
    translated and pushed onto ``audio_loop.out_queue``.
    Gemini -> client: bytes read from ``audio_loop.audio_in_queue`` are sent
    back as ``{"type": "audio", "payload": <base64>}`` frames.

    The session ends as soon as either direction stops (disconnect or error).
    The other direction and the AudioLoop task are then cancelled and awaited.
    """
    await websocket.accept()
    print("[websocket_endpoint] Client connected.")
    # A fresh AudioLoop per client, so sessions never share queues.
    audio_loop = AudioLoop()
    # Audio messages that arrived ahead of order, keyed by client ``seq``.
    audio_ordering_buffer = {}
    expected_audio_seq = 0
    loop_task = asyncio.create_task(audio_loop.run())
    print("[websocket_endpoint] Started new AudioLoop for client")

    async def from_client_to_gemini():
        """Handles incoming messages from the client and forwards them to Gemini."""
        nonlocal expected_audio_seq
        try:
            while True:
                data = await websocket.receive_text()
                msg = json.loads(data)
                msg_type = msg.get("type")
                if msg_type == "audio":
                    # The decode/re-encode round trip rejects malformed base64.
                    raw_pcm = base64.b64decode(msg["payload"])
                    forward_msg = {
                        "realtime_input": {
                            "media_chunks": [
                                {
                                    "data": base64.b64encode(raw_pcm).decode(),
                                    "mime_type": "audio/pcm",
                                }
                            ]
                        }
                    }
                    seq = msg.get("seq")
                    if seq is None:
                        # Unsequenced chunk: forward immediately.
                        await audio_loop.out_queue.put(forward_msg)
                    elif isinstance(seq, int) and seq < expected_audio_seq:
                        # Duplicate or too-late chunk. The flush loop below never
                        # looks backwards, so buffering it would leak forever.
                        print("[from_client_to_gemini] Dropping stale audio chunk, seq:", seq)
                    else:
                        audio_ordering_buffer[seq] = forward_msg
                        # Flush every chunk that is now contiguous with what was sent.
                        # NOTE(review): a seq that never arrives stalls all later
                        # audio and grows this buffer; consider a gap timeout/cap.
                        while expected_audio_seq in audio_ordering_buffer:
                            await audio_loop.out_queue.put(
                                audio_ordering_buffer.pop(expected_audio_seq)
                            )
                            expected_audio_seq += 1
                elif msg_type == "text":
                    user_text = msg.get("content", "")
                    print("[from_client_to_gemini] Forwarding user text to Gemini:", user_text)
                    await audio_loop.out_queue.put(_user_turn_message(user_text))
                elif msg_type == "json":
                    json_data = msg.get("content", {})
                    print("[from_client_to_gemini] Forwarding JSON data to Gemini:", json_data)
                    # Embed the JSON in a prompt asking the model to analyze it.
                    json_prompt = f"The user has shared the following JSON data with you. Please analyze it and respond appropriately:\n\n{json.dumps(json_data, indent=2)}"
                    await audio_loop.out_queue.put(_user_turn_message(json_prompt))
                else:
                    print("[from_client_to_gemini] Unknown message type:", msg_type)
        except WebSocketDisconnect:
            # Returning ends the session; cleanup happens in the finally below.
            print("[from_client_to_gemini] Client disconnected.")
        except Exception as e:
            print("[from_client_to_gemini] Error:", e)

    async def from_gemini_to_client():
        """Reads PCM audio from Gemini and sends it back to the client."""
        try:
            while True:
                pcm_data = await audio_loop.audio_in_queue.get()
                b64_pcm = base64.b64encode(pcm_data).decode()
                out_msg = {
                    "type": "audio",
                    "payload": b64_pcm,
                }
                print("[from_gemini_to_client] Sending audio chunk to client. Size:", len(pcm_data))
                await websocket.send_text(json.dumps(out_msg))
        except WebSocketDisconnect:
            print("[from_gemini_to_client] Client disconnected.")
            # NOTE(review): AudioLoop.stop() lives in handler.py (not visible
            # here); confirm it exists, otherwise this raises AttributeError.
            audio_loop.stop()
        except Exception as e:
            print("[from_gemini_to_client] Error:", e)

    client_task = asyncio.create_task(from_client_to_gemini())
    gemini_task = asyncio.create_task(from_gemini_to_client())
    try:
        # Stop when EITHER direction finishes. asyncio.gather() would wait for
        # both, and the survivor normally sits in receive_text() or
        # audio_in_queue.get() indefinitely, so the handler would never return.
        await asyncio.wait(
            {client_task, gemini_task}, return_when=asyncio.FIRST_COMPLETED
        )
    finally:
        print("[websocket_endpoint] WebSocket handler finished.")
        tasks = (client_task, gemini_task, loop_task)
        for task in tasks:
            task.cancel()
        # Let cancellation settle. Log genuine failures instead of raising.
        # CancelledError is a BaseException, so it is not reported here.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                print("[websocket_endpoint] Background task error:", result)
        print("[websocket_endpoint] Cleaned up AudioLoop for client")
if __name__ == "__main__":
    # Serve on all interfaces, port 7860. uvicorn needs the app given as an
    # import string ("webapp:app") when reload=True, which restarts the server
    # on code changes (a development setting).
    uvicorn.run(
        "webapp:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )