Spaces:
Running
Running
File size: 6,571 Bytes
4b1a7ca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | import os
import json
import asyncio
import redis.asyncio as redis
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from review_orchestrator import CodeReviewOrchestrator
from pydantic import BaseModel
from load_dotenv import load_dotenv
# Load configuration (e.g. REDIS_HOST, REDIS_PORT, API keys) from a .env file.
load_dotenv()
# FastAPI application serving the review dashboard and its JSON API.
app = FastAPI()
# Jinja2 templates for the HTML front-end; index.html lives in ./templates.
templates = Jinja2Templates(directory="templates")
# Initialize Orchestrator — shared across all requests; it drives the PR review pipeline.
orchestrator = CodeReviewOrchestrator()
class ReviewRequest(BaseModel):
    """Request body for POST /review: identifies the pull request to review."""
    # Repository URL; its last path segment is used as the repo name elsewhere.
    repo_url: str
    # Pull-request number within that repository.
    pr_number: int
    # Optional OpenAI API key, passed through to the orchestrator.
    openai_api_key: str | None = None
    # Optional MCP server URL, passed through to the orchestrator.
    mcp_server_url: str | None = None
class MCPRequest(BaseModel):
    """Request body for POST /list-tools: points at an MCP server to inspect."""
    # Base URL of the MCP server; a trailing slash is appended if missing.
    mcp_server_url: str
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve the dashboard landing page rendered from templates/index.html."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/list-tools")
async def list_tools(request: MCPRequest):
    """Connect to the given MCP server and return a description of its tools.

    Failures are reported in-band as {"status": "error", ...} rather than as
    HTTP error codes, so the front-end can render them directly.
    """
    from fastmcp import Client
    from nmagents.command import ToolList
    try:
        # The client expects the server URL to end with a slash.
        server_url = request.mcp_server_url
        if not server_url.endswith("/"):
            server_url += "/"
        async with Client(server_url) as client:
            # fastmcp's client has no simple tool-listing call; nmagents'
            # ToolList command queries the server and formats the result,
            # so reuse it rather than talking to the client directly.
            listing = await ToolList(client, "List tools").execute(None)
            return {"status": "success", "tools": listing}
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
@app.post("/review")
async def trigger_review(request: ReviewRequest, background_tasks: BackgroundTasks):
    """Kick off a PR review in the background and return a handle to its stream.

    The run id (time_hash) is generated here rather than inside the
    orchestrator so it can be returned to the caller immediately; the
    orchestrator's review_pr_stream accepts it as a parameter.

    Returns a dict with the run's time_hash and the SSE stream URL to follow.
    """
    from datetime import datetime
    # Second-resolution timestamp doubles as the run identifier.
    # NOTE(review): two reviews of the same PR started within the same second
    # would collide on this id — confirm that's acceptable for this workload.
    time_hash = datetime.now().strftime("%Y%m%d%H%M%S")
    # Record the run in the per-PR history set *before* starting work so the
    # UI can list it even while the review is still running.
    redis_host = os.getenv("REDIS_HOST", "localhost")
    redis_port = int(os.getenv("REDIS_PORT", 6380))
    r = redis.Redis(host=redis_host, port=redis_port, db=0, decode_responses=True)
    repo_name = request.repo_url.rstrip('/').split('/')[-1]
    runs_key = f"review:runs:{repo_name}:{request.pr_number}"
    try:
        await r.sadd(runs_key, time_hash)
    finally:
        # Fix: always release the connection, even if SADD raises —
        # previously a failure here leaked the Redis connection.
        await r.close()
    background_tasks.add_task(run_review, request.repo_url, request.pr_number, time_hash, request.openai_api_key, request.mcp_server_url)
    return {"status": "Review started", "time_hash": time_hash, "stream_url": f"/stream/{repo_name}/{request.pr_number}/{time_hash}"}
async def run_review(repo_url: str, pr_number: int, time_hash: str, api_key: str | None = None, mcp_server_url: str | None = None):
    """Drive the orchestrator's review stream to completion.

    review_pr_stream is an async generator that only does its work while being
    iterated, so we must exhaust it here; the yielded events themselves are
    discarded (clients follow the run via the /stream endpoint instead).
    """
    stream = orchestrator.review_pr_stream(repo_url, pr_number, time_hash, api_key, mcp_server_url)
    async for _event in stream:
        pass
@app.get("/runs/{repo_name}/{pr_number}")
async def list_runs(repo_name: str, pr_number: int):
    """Return all recorded run time-hashes for one PR, newest first."""
    redis_host = os.getenv("REDIS_HOST", "localhost")
    redis_port = int(os.getenv("REDIS_PORT", 6380))
    r = redis.Redis(host=redis_host, port=redis_port, db=0, decode_responses=True)
    runs_key = f"review:runs:{repo_name}:{pr_number}"
    try:
        runs = await r.smembers(runs_key)
        # sorted() accepts the set directly — the intermediate list() copy was
        # unnecessary. time_hash is YYYYMMDDHHMMSS, so lexicographic descending
        # order is chronological, newest first.
        return {"runs": sorted(runs, reverse=True)}
    finally:
        await r.close()
@app.get("/runs")
async def list_all_runs():
    """Return every recorded review run across all repos/PRs, newest first."""
    redis_host = os.getenv("REDIS_HOST", "localhost")
    redis_port = int(os.getenv("REDIS_PORT", 6380))
    r = redis.Redis(host=redis_host, port=redis_port, db=0, decode_responses=True)
    try:
        all_runs = []
        # Fix: SCAN instead of KEYS — KEYS is O(N) over the whole keyspace and
        # blocks the Redis server; scan_iter walks it incrementally with the
        # same match semantics.
        async for key in r.scan_iter(match="review:runs:*:*"):
            # key format: review:runs:repo_name:pr_number
            parts = key.split(":")
            if len(parts) >= 4:
                repo_name = parts[2]
                pr_number = parts[3]
                for run in await r.smembers(key):
                    all_runs.append({
                        "repo_name": repo_name,
                        "pr_number": pr_number,
                        "time_hash": run,
                    })
        # time_hash is YYYYMMDDHHMMSS, so lexicographic descending order is
        # chronological, newest first.
        all_runs.sort(key=lambda x: x["time_hash"], reverse=True)
        return {"runs": all_runs}
    finally:
        await r.close()
@app.get("/stream/{repo_name}/{pr_number}/{time_hash}")
async def stream_events(repo_name: str, pr_number: int, time_hash: str):
    """Relay one review run's Redis stream to the client as Server-Sent Events."""
    redis_host = os.getenv("REDIS_HOST", "localhost")
    redis_port = int(os.getenv("REDIS_PORT", 6380))
    r = redis.Redis(host=redis_host, port=redis_port, db=0, decode_responses=True)
    stream_key = f"review:stream:{repo_name}:{pr_number}:{time_hash}"

    async def event_generator():
        # "0-0" replays the stream from its first entry, so a late subscriber
        # still receives the run's full history before live events.
        cursor = "0-0"
        try:
            while True:
                # Block up to 1s waiting for the next entry after `cursor`.
                batch = await r.xread({stream_key: cursor}, count=1, block=1000)
                if not batch:
                    # SSE comment line keeps proxies/browsers from timing out.
                    yield ": keep-alive\n\n"
                    continue
                for _name, entries in batch:
                    for entry_id, fields in entries:
                        cursor = entry_id
                        yield f"data: {json.dumps(fields)}\n\n"
        except asyncio.CancelledError:
            print("Stream cancelled")
        finally:
            await r.close()

    return StreamingResponse(event_generator(), media_type="text/event-stream")
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|