# 3v324v23's picture
# feat: refresh dashboard UX and backend integration
# dcee3d3
import json
from typing import Optional
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from models import PRAction, PRObservation, PRState
from server.action_normalizer import normalize_action_payload
from server.environment import PRReviewEnvironment
# HTTP application exposing the PR review environment; title, description and
# version appear in the generated OpenAPI docs.
app = FastAPI(
    title="PR Review Negotiation Environment",
    description="A multi-turn pull request review negotiation benchmark for engineering judgment.",
    version="1.0.0",
)

# One process-wide environment instance; the /reset, /step and /state
# handlers below all read and mutate its episode state.
env = PRReviewEnvironment()
class ResetRequest(BaseModel):
    # Request body for POST /reset.
    # Name of the task to start; the /reset handler rejects names that are not
    # keys of server.environment.TASKS with a 404.
    task_name: Optional[str] = "single-pass-review"
class CustomTaskConfig(BaseModel):
    # Request body for POST /config/custom; the handler copies these values
    # into server.tasks.custom.TASK.
    # Diff text the custom review session is built around (required).
    diff: str
    # NOTE(review): Optional[...] lets a client send an explicit null, which
    # replaces these defaults with None — confirm null should be accepted.
    pr_title: Optional[str] = "Custom Review Session"
    pr_description: Optional[str] = "User-provided code snippet for review."
class StepResponse(BaseModel):
    # Response body for POST /step, built from the 4-tuple returned by
    # env.step(action).
    observation: PRObservation
    reward: float
    done: bool
    # Extra per-step details passed through from env.step().
    info: dict
def _task_metadata(task_name: str, task: dict) -> dict:
return {
"name": task_name,
"pr_title": task.get("pr_title", task_name),
"pr_description": task.get("pr_description", ""),
"max_turns": task.get("max_turns", 1),
}
def _model_schema(model: type[BaseModel]) -> dict:
if hasattr(model, "model_json_schema"):
return model.model_json_schema()
return model.schema()
async def _read_payload(request: Request):
body = await request.body()
if not body:
raise HTTPException(status_code=400, detail="Request body is required.")
try:
return json.loads(body)
except json.JSONDecodeError:
return body.decode("utf-8", errors="replace")
@app.get("/")
def index():
    """Return a landing message pointing at the dashboard and the API docs."""
    landing = dict(
        message="Backend API is running!",
        action="Visit the dashboard at http://localhost:3000 locally or http://localhost:7860 in Docker.",
        api_docs="/docs",
    )
    return landing
@app.get("/health")
def health():
    """Report that the server is up; the response is always the same."""
    return dict(status="healthy")
@app.get("/metadata")
def metadata():
    """Return static identifying information about this environment."""
    info = dict(
        name="pr-review-env",
        description="A multi-turn pull request review negotiation benchmark for root-cause depth, false-fix resistance, and escalation judgment.",
        version="1.0.0",
        author="Levi710",
    )
    return info
@app.get("/schema")
def schema():
    """Return the JSON schemas of the action, observation and state models."""
    models = (
        ("action", PRAction),
        ("observation", PRObservation),
        ("state", PRState),
    )
    return {key: _model_schema(model) for key, model in models}
@app.post("/mcp")
async def mcp(request: Request):
    """Placeholder MCP endpoint that always answers with a JSON-RPC error.

    The request ``id`` is echoed back when the body is a JSON object;
    otherwise ``id`` is ``None``. The error code is always -32601.
    """
    try:
        payload = await request.json()
    except Exception:
        # Best-effort: any unreadable body is treated as an empty request so
        # the stub still returns a well-formed error envelope.
        payload = {}
    request_id = payload.get("id") if isinstance(payload, dict) else None
    error = {
        "code": -32601,
        "message": "MCP tools are not implemented for this environment.",
    }
    return {"jsonrpc": "2.0", "id": request_id, "error": error}
@app.get("/tasks")
def tasks():
    """List every registered task with its display metadata."""
    from server.environment import TASKS

    listing = []
    for task_name, task in TASKS.items():
        listing.append(_task_metadata(task_name, task))
    return {"tasks": listing}
@app.post("/config/custom")
def set_custom_task(config: CustomTaskConfig):
    """Overwrite the diff and PR metadata of the ``custom`` task in place.

    The ``pr_title``/``pr_description`` fields are ``Optional``, so a client
    may send an explicit ``null``. Storing ``None`` would bypass the model's
    defaults (and ``_task_metadata``'s ``.get`` fallbacks), so a null value is
    replaced by the same default the model declares.

    Args:
        config: The validated custom task settings.

    Returns:
        ``{"status": "success"}`` once the task dict has been updated.
    """
    from server.tasks import custom

    # Keep these literals in sync with the CustomTaskConfig field defaults.
    title = config.pr_title
    description = config.pr_description
    custom.TASK["diff"] = config.diff
    custom.TASK["pr_title"] = (
        title if title is not None else "Custom Review Session"
    )
    custom.TASK["pr_description"] = (
        description
        if description is not None
        else "User-provided code snippet for review."
    )
    return {"status": "success"}
@app.post("/reset", response_model=PRObservation)
def reset(req: ResetRequest = ResetRequest()):
    """Start a new episode for the requested task.

    Returns:
        The initial observation produced by ``env.reset``.

    Raises:
        HTTPException: 404 when ``task_name`` is not a registered task.
    """
    from server.environment import TASKS

    requested = req.task_name
    if requested in TASKS:
        return env.reset(task_name=requested)
    raise HTTPException(status_code=404, detail=f"Unknown task: {requested}")
@app.post("/step", response_model=StepResponse)
async def step(request: Request):
    """Apply one action to the active episode.

    The body may be JSON or free-form text; either form is passed to
    ``normalize_action_payload``.

    Returns:
        A ``StepResponse`` built from the tuple ``env.step`` returns.

    Raises:
        HTTPException: 400 when no episode is active or it is already done,
            when the body is empty, or when the payload cannot be normalized.
    """
    episode = env._state
    if episode is None or episode.done:
        raise HTTPException(status_code=400, detail="Call /reset first or episode is done.")

    raw = await _read_payload(request)
    try:
        action = normalize_action_payload(raw)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid action payload: {exc}") from exc

    observation, reward, done, info = env.step(action)
    return StepResponse(observation=observation, reward=reward, done=done, info=info)
@app.get("/state", response_model=PRState)
def state():
    """Return the current episode state.

    Raises:
        HTTPException: 400 when ``/reset`` has not been called yet.
    """
    if env._state is not None:
        return env.state()
    raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")
@app.post("/diff")
async def generate_diff(payload: dict):
    """Return a unified diff between two snippets of source text.

    Expects a JSON object with optional ``old_code``, ``new_code`` and
    ``filename`` keys (defaults ``""``, ``""`` and ``"file.py"``). An explicit
    ``null`` is treated like a missing key.

    Returns:
        ``{"diff": <text>}``; the text is empty when the snippets are equal.
        A snippet that does not end with a line break gets the conventional
        ``\\ No newline at end of file`` marker, so its last line is not glued
        onto the following diff line.

    Raises:
        HTTPException: 400 when ``old_code`` or ``new_code`` is not a string.
    """
    import difflib

    old_code = payload.get("old_code")
    new_code = payload.get("new_code")
    filename = payload.get("filename")
    old_code = "" if old_code is None else old_code
    new_code = "" if new_code is None else new_code
    filename = "file.py" if filename is None else filename

    for key, value in (("old_code", old_code), ("new_code", new_code)):
        if not isinstance(value, str):
            # Previously this surfaced as an AttributeError (HTTP 500).
            raise HTTPException(status_code=400, detail=f"'{key}' must be a string.")

    diff_lines = difflib.unified_diff(
        old_code.splitlines(keepends=True),
        new_code.splitlines(keepends=True),
        fromfile=f"a/{filename}",
        tofile=f"b/{filename}",
    )
    parts = []
    for line in diff_lines:
        parts.append(line)
        # splitlines() strips a trailing line boundary, so an unchanged result
        # means this line had none (the last line of an input lacking a final
        # newline). Terminate it and add the marker used by diff/git.
        if line.splitlines() == [line]:
            parts.append("\n\\ No newline at end of file\n")
    return {"diff": "".join(parts)}
def main():
    """Serve ``app`` with uvicorn on all interfaces.

    The port is read from the ``PORT`` environment variable (default 8000).
    """
    import os
    import uvicorn

    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)


if __name__ == "__main__":
    main()