ethanrom's picture
Update app/main.py
4860283 verified
"""Main entry point for the FastAPI application."""
import asyncio
import os
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from app.routers import admin, aider
from app.services.session_manager import session_manager
# Slash commands that clients may not run through the WebSocket endpoint.
# Entries must be lowercase: the handler lowercases the command word before
# looking it up here. A frozenset keeps this blocklist immutable at runtime.
RESTRICTED_COMMANDS = frozenset(
    {
        "/model",  # Prevents model switching
        "/tokens",
    }
)
app = FastAPI(title="Aider WebSocket Server")

# Allowed CORS origins, read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (e.g. "https://app.example.com,http://localhost:3000").
# Defaults to "*" (any origin), matching the previous hard-coded behaviour.
_CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

# NOTE(review): CORSMiddleware only applies to HTTP requests; WebSocket
# handshakes such as /ws are not covered by it. With "*" plus
# allow_credentials=True, Starlette reflects the request's Origin in
# credentialed responses, so any site can make credentialed HTTP calls --
# set CORS_ALLOW_ORIGINS to the real frontend URL(s) outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOW_ORIGINS,
    allow_credentials=True,  # Allow cookies / Authorization headers on HTTP requests
    allow_methods=["*"],  # Allow all HTTP methods (GET, POST, etc.)
    allow_headers=["*"],  # Allow all headers
)
app.include_router(admin.router)
app.include_router(aider.router)
async def _reject_if_restricted(websocket: WebSocket, text: str) -> bool:
    """Reply with an error and return True if *text* invokes a restricted command.

    The first whitespace-separated word is lowercased before the lookup, so
    ``/MODEL`` and ``  /model gpt-4`` are refused just like ``/model``.
    Returns False, sending nothing, for any other text (including ``""``).
    """
    words = text.split()
    cmd = words[0].lower() if words else ""
    if cmd not in RESTRICTED_COMMANDS:
        return False
    await websocket.send_json(
        {
            "type": "error",
            "content": f"Command '{cmd}' is not allowed in this environment.",
        }
    )
    return True


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Websocket endpoint for handling commands and sessions.

    Protocol:
      1. The first JSON message may carry ``repo_url``. A session is created
         from it, and a ``status`` frame reports the new session ID. On
         failure, an ``error`` frame is sent and the handler returns.
      2. A background task runs ``aider_service.process_output(websocket)``.
         It presumably streams session output to the client -- confirm in
         the service.
      3. Each later JSON message is dispatched on its ``command`` key:
         ``send_to_aider``, ``list_files``, ``read_file``, ``download``, or a
         raw slash command (with optional ``args``). Commands listed in
         ``RESTRICTED_COMMANDS`` are refused with an ``error`` frame.

    On every exit path (disconnect, error, or failed session creation), the
    output task is cancelled and awaited. The session, if one was created,
    is then cleaned up.
    """
    await websocket.accept()
    # Bound up front so the ``finally`` clause is safe even when the handler
    # exits before a session or output task exists.
    session_id = None
    output_task = None
    try:
        data = await websocket.receive_json()
        repo_url = data.get("repo_url")
        try:
            session_id = await session_manager.create_session(repo_url=repo_url)
            await websocket.send_json(
                {"type": "status", "content": f"Session created with ID: {session_id}"}
            )
        except Exception as e:
            await websocket.send_json(
                {"type": "error", "content": f"Failed to create session: {str(e)}"}
            )
            return
        aider_service = session_manager.get_session(session_id)
        output_task = asyncio.create_task(aider_service.process_output(websocket))
        while True:
            try:
                data = await websocket.receive_json()
                command = data.get("command")
                if command == "send_to_aider":
                    message = data.get("message", "")
                    if await _reject_if_restricted(websocket, message):
                        continue
                    await aider_service.send_command(message)
                elif command == "list_files":
                    files = await aider_service.list_files()
                    await websocket.send_json({"type": "files", "content": files})
                elif command == "read_file":
                    filename = data.get("filename")
                    if filename:
                        content = await aider_service.read_file(filename)
                        await websocket.send_json(
                            {"type": "file_content", "content": content}
                        )
                    else:
                        await websocket.send_json(
                            {"type": "error", "content": "Filename not provided"}
                        )
                elif command == "download":
                    zip_path = await aider_service.create_download_zip()
                    if zip_path:
                        try:
                            with open(zip_path, "rb") as f:
                                zip_content = f.read()
                            await websocket.send_bytes(zip_content)
                        finally:
                            # The zip is a temporary artifact; remove it even
                            # if reading or sending failed.
                            if os.path.exists(zip_path):
                                os.remove(zip_path)
                    else:
                        await websocket.send_json(
                            {
                                "type": "error",
                                "content": "Failed to create download zip",
                            }
                        )
                elif isinstance(command, str) and command.startswith("/"):
                    if await _reject_if_restricted(websocket, command):
                        continue
                    args = data.get("args") or []
                    if isinstance(args, str):
                        # A bare string is one argument, not a list of chars.
                        args = [args]
                    await aider_service.send_command(
                        " ".join([command, *map(str, args)])
                    )
                else:
                    # Also covers a missing/non-string "command" key.
                    await websocket.send_json(
                        {"type": "error", "content": "Unknown command"}
                    )
            except WebSocketDisconnect:
                break
            except Exception as e:
                # Per-message failures are reported and the loop keeps serving.
                await websocket.send_json({"type": "error", "content": str(e)})
    except WebSocketDisconnect:
        # The client went away; there is nobody left to report an error to.
        pass
    except Exception as e:
        try:
            await websocket.send_json({"type": "error", "content": str(e)})
        except (RuntimeError, WebSocketDisconnect):
            # The socket is already closed, so the error cannot be delivered.
            pass
    finally:
        try:
            if output_task is not None:
                output_task.cancel()
                # Let the task finish unwinding before the session is torn
                # down. return_exceptions=True keeps its CancelledError (or
                # any failure inside it) from escaping this cleanup.
                await asyncio.gather(output_task, return_exceptions=True)
        finally:
            if session_id is not None:
                await session_manager.cleanup_session(session_id)
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Tear down every remaining session as the server stops.

    Registered as a FastAPI shutdown hook. The actual cleanup is delegated
    to ``session_manager.cleanup_all_sessions``.

    NOTE(review): FastAPI marks ``on_event`` as deprecated in favour of
    lifespan handlers; consider migrating when the app setup is revisited.
    """
    manager = session_manager
    await manager.cleanup_all_sessions()