File size: 5,370 Bytes
acd245a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4860283
acd245a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4860283
 
 
 
acd245a
 
 
 
 
 
 
 
 
 
 
4860283
acd245a
 
 
 
 
 
 
 
 
 
 
4860283
acd245a
 
 
 
4860283
acd245a
4860283
 
acd245a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""Main entry point for the FastAPI application."""

import asyncio
import os
from typing import Optional

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

from app.routers import admin, aider
from app.services.session_manager import session_manager

# Slash commands that clients may not forward to aider (compared
# case-insensitively against the command token).
RESTRICTED_COMMANDS = {
    "/model",  # Prevents model switching
    "/tokens",
}

app = FastAPI(title="Aider WebSocket Server")

# Browser origins allowed to call the HTTP routes, given as a comma-separated
# list in CORS_ALLOW_ORIGINS (e.g. "https://app.example.com,http://localhost:3000").
# Unset or empty falls back to "*" (any origin), which matches the previous
# hard-coded behaviour.
# NOTE(review): with a wildcard plus allow_credentials=True, Starlette appears to
# echo the caller's Origin back for cookie-bearing requests, which lets any site
# make credentialed calls. Confirm, and set an explicit list in production.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# CORSMiddleware only handles HTTP requests (the admin/aider routers). The "/ws"
# WebSocket handshake is not governed by CORS, so these settings do not affect it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,  # Allow cookies/Authorization on cross-origin HTTP calls
    allow_methods=["*"],  # Allow all HTTP methods (GET, POST, etc.)
    allow_headers=["*"],  # Allow all headers
)

app.include_router(admin.router)
app.include_router(aider.router)


def _find_restricted_command(text: str) -> Optional[str]:
    """Return the first restricted slash command found in *text*, or ``None``.

    Every line is inspected, after skipping leading whitespace, and its first
    whitespace-separated token is compared lower-cased against
    ``RESTRICTED_COMMANDS``. Checking each line, not just the start of the
    string, guards against a blocked command hidden behind indentation or on a
    later line. That matters if the downstream reader treats each line as a
    separate input, which cannot be verified from this file.

    NOTE(review): only exact command names are matched. If aider also accepts
    unambiguous prefixes (e.g. "/tok"), those would not be caught. Confirm
    against the aider version in use.
    """
    for line in text.splitlines():
        tokens = line.split()
        if tokens and tokens[0].startswith("/"):
            cmd = tokens[0].lower()
            if cmd in RESTRICTED_COMMANDS:
                return cmd
    return None


async def _reject_if_restricted(websocket: WebSocket, text: str) -> bool:
    """Send an error to the client and return True if *text* is restricted.

    Returns False, sending nothing, when *text* contains no restricted command.
    """
    cmd = _find_restricted_command(text)
    if cmd is None:
        return False
    await websocket.send_json(
        {
            "type": "error",
            "content": f"Command '{cmd}' is not allowed in this environment.",
        }
    )
    return True


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Websocket endpoint for handling commands and sessions.

    The first JSON message may carry ``repo_url``. A session is created for it
    and its id is reported back as a ``status`` message, or an ``error``
    message if creation fails. Each later JSON message is dispatched on its
    ``command`` field:

    * ``send_to_aider``: forward ``message`` to the session.
    * ``list_files``: reply with a ``files`` message.
    * ``read_file``: reply with ``file_content`` for ``filename``.
    * ``download``: send the archive from ``create_download_zip`` as a binary
      frame, then delete the archive file.
    * any string starting with ``/``: forward it, plus optional ``args``, to
      the session.

    Text containing a restricted slash command is rejected with an ``error``
    message. Per-message failures are reported as ``error`` messages without
    closing the loop. On exit, the ``process_output`` background task is
    cancelled and awaited, and the session (if one was created) is cleaned up.
    """
    await websocket.accept()

    # Bound before the ``try`` so ``finally`` only tears down what was actually
    # created. Previously a failed first receive or session creation left
    # these unbound, and cleanup raised UnboundLocalError.
    session_id = None
    output_task = None

    try:
        data = await websocket.receive_json()
        repo_url = data.get("repo_url")

        try:
            session_id = await session_manager.create_session(repo_url=repo_url)
            await websocket.send_json(
                {"type": "status", "content": f"Session created with ID: {session_id}"}
            )
        except Exception as e:
            await websocket.send_json(
                {"type": "error", "content": f"Failed to create session: {e}"}
            )
            return

        aider_service = session_manager.get_session(session_id)
        output_task = asyncio.create_task(aider_service.process_output(websocket))

        while True:
            try:
                data = await websocket.receive_json()
                command = data.get("command")

                if command == "send_to_aider":
                    message = data.get("message", "")
                    if await _reject_if_restricted(websocket, message):
                        continue
                    await aider_service.send_command(message)
                elif command == "list_files":
                    files = await aider_service.list_files()
                    await websocket.send_json({"type": "files", "content": files})
                elif command == "read_file":
                    filename = data.get("filename")
                    if filename:
                        content = await aider_service.read_file(filename)
                        await websocket.send_json(
                            {"type": "file_content", "content": content}
                        )
                    else:
                        await websocket.send_json(
                            {"type": "error", "content": "Filename not provided"}
                        )
                elif command == "download":
                    zip_path = await aider_service.create_download_zip()
                    if zip_path:
                        try:
                            # NOTE(review): synchronous read inside the event loop;
                            # a large archive stalls other connections meanwhile.
                            with open(zip_path, "rb") as f:
                                zip_content = f.read()
                            await websocket.send_bytes(zip_content)
                        finally:
                            # Delete the archive whether or not sending succeeded.
                            try:
                                os.remove(zip_path)
                            except FileNotFoundError:
                                pass
                    else:
                        await websocket.send_json(
                            {
                                "type": "error",
                                "content": "Failed to create download zip",
                            }
                        )
                elif isinstance(command, str) and command.startswith("/"):
                    command_str = command
                    if "args" in data:
                        args = data.get("args") or []
                        if isinstance(args, str):
                            # A bare string is one argument, not a list of characters.
                            args = [args]
                        command_str += " " + " ".join(str(arg) for arg in args)
                    # Check the full text that will be sent, so args cannot slip
                    # a restricted command in on a new line.
                    if await _reject_if_restricted(websocket, command_str):
                        continue
                    await aider_service.send_command(command_str)
                else:
                    # Also covers a missing or non-string "command" field.
                    await websocket.send_json(
                        {"type": "error", "content": "Unknown command"}
                    )

            except WebSocketDisconnect:
                break
            except Exception as e:
                # Report per-message failures (bad payloads, service errors)
                # and keep serving the connection.
                await websocket.send_json({"type": "error", "content": str(e)})

    except WebSocketDisconnect:
        # The client went away (e.g. before sending its first message); there
        # is nobody left to send an error to.
        pass
    except Exception as e:
        await websocket.send_json({"type": "error", "content": str(e)})
    finally:
        try:
            if output_task is not None:
                output_task.cancel()
                # Wait for the task to stop before its session is torn down.
                # return_exceptions keeps its CancelledError (or any failure it
                # raised) from escaping and skipping the cleanup below.
                await asyncio.gather(output_task, return_exceptions=True)
        finally:
            if session_id is not None:
                await session_manager.cleanup_session(session_id)


# NOTE(review): FastAPI documents ``on_event`` as deprecated in favour of a
# ``lifespan`` handler. Migrating also requires changing the ``FastAPI(...)``
# construction, so it is left as-is here.
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Tear down every session still held by the session manager on shutdown."""
    release_all_sessions = session_manager.cleanup_all_sessions
    await release_all_sessions()