"""Main entry point for the FastAPI application."""

import asyncio
import logging
import os
from contextlib import suppress
from typing import Any, Optional

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

from app.routers import admin, aider
from app.services.session_manager import session_manager

logger = logging.getLogger(__name__)

# Aider slash-commands that WebSocket clients may not run. They are matched
# case-insensitively against the first word of the text sent to Aider.
RESTRICTED_COMMANDS = {
    "/model",  # Prevents model switching
    "/tokens",
}

app = FastAPI(title="Aider WebSocket Server")

# CORSMiddleware only processes HTTP requests; WebSocket connections to "/ws"
# pass through it untouched, so these settings govern the HTTP routers only.
# NOTE(review): with allow_origins=["*"] and allow_credentials=True, Starlette
# answers credentialed requests by echoing the caller's Origin, i.e. any site
# may make credentialed requests. Restrict allow_origins to the real frontend
# URL(s) if credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],  # Allow all HTTP methods (GET, POST, etc.)
    allow_headers=["*"],  # Allow all headers
)
app.include_router(admin.router)
app.include_router(aider.router)


def _blocked_command(text: object) -> Optional[str]:
    """Return the restricted slash-command that *text* would invoke, if any.

    The first whitespace-separated word of *text* is lower-cased and compared
    against RESTRICTED_COMMANDS. Leading whitespace is ignored so input such
    as ``"  /model"`` cannot slip past the check.

    Returns:
        The offending lower-cased command word, or ``None`` when *text* is not
        a string, is not a slash-command, or names an allowed command.
    """
    if not isinstance(text, str):
        return None
    words = text.split()
    if not words or not words[0].startswith("/"):
        return None
    cmd = words[0].lower()
    if cmd in RESTRICTED_COMMANDS:
        return cmd
    # Also reject abbreviations of restricted commands: Aider appears to run a
    # command given an unambiguous prefix of its name (e.g. "/tok" -> "/tokens").
    # NOTE(review): confirm this prefix behavior for the pinned Aider version.
    if len(cmd) > 1 and any(name.startswith(cmd) for name in RESTRICTED_COMMANDS):
        return cmd
    return None


async def _send_error(websocket: WebSocket, content: str) -> None:
    """Send an ``{"type": "error", "content": ...}`` frame to the client."""
    await websocket.send_json({"type": "error", "content": content})


async def _forward_to_aider(websocket: WebSocket, aider_service: Any, text: str) -> None:
    """Send *text* to the Aider session unless it invokes a restricted command.

    A restricted command is answered with an error frame instead of being sent.
    """
    blocked = _blocked_command(text)
    if blocked:
        await _send_error(
            websocket,
            f"Command '{blocked}' is not allowed in this environment.",
        )
        return
    await aider_service.send_command(text)


def _read_bytes(path: str) -> bytes:
    """Return the full contents of the file at *path*, read in binary mode."""
    with open(path, "rb") as f:
        return f.read()


async def _send_download(websocket: WebSocket, aider_service: Any) -> None:
    """Ask the session for a download zip, send it as one binary frame, delete it.

    Sends an error frame if the session returns no zip path.
    """
    zip_path = await aider_service.create_download_zip()
    if not zip_path:
        await _send_error(websocket, "Failed to create download zip")
        return
    try:
        # Read in a worker thread so a large archive does not block the event
        # loop (and with it every other connected session) while it loads.
        loop = asyncio.get_running_loop()
        zip_content = await loop.run_in_executor(None, _read_bytes, zip_path)
        await websocket.send_bytes(zip_content)
    finally:
        # The zip is a temporary artifact; remove it whether or not the send
        # succeeded. suppress() avoids the exists()/remove() race.
        with suppress(FileNotFoundError):
            os.remove(zip_path)


async def _handle_client_message(
    websocket: WebSocket, aider_service: Any, data: Any
) -> None:
    """Dispatch one JSON message received from the client.

    Supported ``command`` values:

    * ``"send_to_aider"``: forward ``message`` (a string) to Aider.
    * ``"list_files"``: reply with a ``files`` frame.
    * ``"read_file"``: reply with a ``file_content`` frame (needs ``filename``).
    * ``"download"``: reply with the session's zip as a binary frame.
    * any string starting with ``"/"``: forward as an Aider command, with the
      optional ``args`` list appended space-separated.

    Anything else (including a missing or non-string ``command``, or a
    non-object message) yields an "Unknown command" error frame.
    """
    command = data.get("command") if isinstance(data, dict) else None

    if command == "send_to_aider":
        message = data.get("message", "")
        if not isinstance(message, str):
            await _send_error(websocket, "Message must be a string")
            return
        await _forward_to_aider(websocket, aider_service, message)
    elif command == "list_files":
        files = await aider_service.list_files()
        await websocket.send_json({"type": "files", "content": files})
    elif command == "read_file":
        filename = data.get("filename")
        if not filename:
            await _send_error(websocket, "Filename not provided")
            return
        content = await aider_service.read_file(filename)
        await websocket.send_json({"type": "file_content", "content": content})
    elif command == "download":
        await _send_download(websocket, aider_service)
    elif isinstance(command, str) and command.startswith("/"):
        args = data.get("args") or []
        if isinstance(args, str):
            # A bare string would otherwise be joined character by character.
            args = [args]
        command_str = " ".join([command, *map(str, args)])
        await _forward_to_aider(websocket, aider_service, command_str)
    else:
        await _send_error(websocket, "Unknown command")


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Websocket endpoint for handling commands and sessions.

    Protocol: the first JSON message must be an object with an optional
    ``repo_url``; a session is created from it and a ``status`` frame is sent.
    The session's output is then streamed to the client by a background task,
    while each further JSON message is dispatched by
    ``_handle_client_message``. Per-message failures are reported as error
    frames without closing the connection. When the client disconnects (or an
    unrecoverable error occurs), the output task is stopped and the session is
    cleaned up.
    """
    await websocket.accept()
    # Initialized up front so ``finally`` can tell what was actually set up.
    session_id = None
    output_task: Optional[asyncio.Task] = None
    try:
        data = await websocket.receive_json()
        repo_url = data.get("repo_url")
        try:
            session_id = await session_manager.create_session(repo_url=repo_url)
        except Exception as e:
            await _send_error(websocket, f"Failed to create session: {e}")
            return
        # Sent outside the try above so a failed send is not misreported as a
        # failed session creation (and the session still gets cleaned up).
        await websocket.send_json(
            {"type": "status", "content": f"Session created with ID: {session_id}"}
        )

        aider_service = session_manager.get_session(session_id)
        output_task = asyncio.create_task(aider_service.process_output(websocket))

        while True:
            try:
                data = await websocket.receive_json()
                await _handle_client_message(websocket, aider_service, data)
            except WebSocketDisconnect:
                break
            except Exception as e:
                await _send_error(websocket, str(e))
    except WebSocketDisconnect:
        # The client went away (possibly before sending its session request);
        # there is nobody left to report an error to.
        pass
    except Exception as e:
        logger.exception("WebSocket session %s failed", session_id)
        # Best effort: the connection may already be unusable.
        with suppress(Exception):
            await _send_error(websocket, str(e))
    finally:
        if output_task is not None:
            output_task.cancel()
            # Wait for the task to finish so it is not still touching the
            # session during cleanup, and so any exception it raised is
            # retrieved rather than reported as "never retrieved".
            await asyncio.gather(output_task, return_exceptions=True)
        if session_id is not None:
            await session_manager.cleanup_session(session_id)


# NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in favor
# of a ``lifespan`` handler; consider migrating.
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Clean up all sessions when the server shuts down."""
    await session_manager.cleanup_all_sessions()