File size: 4,985 Bytes
dcee3d3 13507d6 dcee3d3 13507d6 dcee3d3 13507d6 dcee3d3 13507d6 0c8a432 13507d6 dcee3d3 0c8a432 dcee3d3 0c8a432 13507d6 dcee3d3 13507d6 0c8a432 13507d6 dcee3d3 13507d6 dcee3d3 13507d6 dcee3d3 13507d6 b358ab7 aa10350 b358ab7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 | import json
from typing import Optional
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from models import PRAction, PRObservation, PRState
from server.action_normalizer import normalize_action_payload
from server.environment import PRReviewEnvironment
# ASGI application object; every route in this module is registered on it.
app = FastAPI(
    description="A multi-turn pull request review negotiation benchmark for engineering judgment.",
    version="1.0.0",
    title="PR Review Negotiation Environment",
)

# Single module-level environment instance shared by the /reset, /step and
# /state handlers.
env = PRReviewEnvironment()
# Request body accepted by POST /reset.
class ResetRequest(BaseModel):
    # Name of the task to start; used when the client does not supply one.
    task_name: Optional[str] = "single-pass-review"
# Request body accepted by POST /config/custom.
class CustomTaskConfig(BaseModel):
    # Diff text to review; the only required field.
    diff: str
    # Title and description stored on the custom task alongside the diff.
    pr_title: Optional[str] = "Custom Review Session"
    pr_description: Optional[str] = "User-provided code snippet for review."
# Response body returned by POST /step; mirrors the four values that
# ``env.step`` hands back.
class StepResponse(BaseModel):
    observation: PRObservation
    reward: float
    done: bool
    info: dict
def _task_metadata(task_name: str, task: dict) -> dict:
    """Summarise one task definition for the ``/tasks`` listing.

    Args:
        task_name: Key under which the task is registered.
        task: The task definition mapping.

    Returns:
        A dict with ``name``, ``pr_title``, ``pr_description`` and
        ``max_turns``. Missing entries fall back to the task name, an empty
        string and ``1`` respectively.
    """
    title = task.get("pr_title", task_name)
    description = task.get("pr_description", "")
    turn_limit = task.get("max_turns", 1)
    return dict(
        name=task_name,
        pr_title=title,
        pr_description=description,
        max_turns=turn_limit,
    )
def _model_schema(model: type[BaseModel]) -> dict:
    """Return the JSON schema of *model*.

    Calls ``model.model_json_schema()`` when the class exposes it and falls
    back to ``model.schema()`` otherwise.
    """
    try:
        build_schema = model.model_json_schema
    except AttributeError:
        build_schema = model.schema
    return build_schema()
async def _read_payload(request: Request):
    """Read the request body as JSON, falling back to plain text.

    Args:
        request: Incoming request whose raw body is awaited.

    Returns:
        The value parsed by ``json.loads`` when the body is valid JSON,
        otherwise the body decoded as UTF-8 text with undecodable bytes
        replaced.

    Raises:
        HTTPException: 400 when the body is empty.
    """
    body = await request.body()
    if not body:
        raise HTTPException(status_code=400, detail="Request body is required.")
    try:
        return json.loads(body)
    except (json.JSONDecodeError, UnicodeDecodeError):
        # json.loads decodes bytes itself and raises UnicodeDecodeError (not
        # JSONDecodeError) on invalid byte sequences; both cases take the
        # lenient text fallback instead of surfacing as a server error.
        return body.decode("utf-8", errors="replace")
@app.get("/")
def index():
    # Landing response for the API root: a status message plus pointers to
    # the dashboard and the API docs path.
    landing = {}
    landing["message"] = "Backend API is running!"
    landing["action"] = "Visit the dashboard at http://localhost:3000 locally or http://localhost:7860 in Docker."
    landing["api_docs"] = "/docs"
    return landing
@app.get("/health")
def health():
    # Health probe: always returns a fixed "healthy" status.
    return dict(status="healthy")
@app.get("/metadata")
def metadata():
    # Static descriptive metadata about this environment.
    return dict(
        name="pr-review-env",
        description="A multi-turn pull request review negotiation benchmark for root-cause depth, false-fix resistance, and escalation judgment.",
        version="1.0.0",
        author="Levi710",
    )
@app.get("/schema")
def schema():
    # JSON schemas of the action, observation and state models, keyed by role.
    exposed_models = (
        ("action", PRAction),
        ("observation", PRObservation),
        ("state", PRState),
    )
    return {label: _model_schema(model) for label, model in exposed_models}
@app.post("/mcp")
async def mcp(request: Request):
    # Always answers with a JSON-RPC 2.0 error (code -32601) because MCP
    # tools are not implemented here. The request's "id" is echoed back when
    # the body parses to a JSON object; otherwise the id is None.
    try:
        parsed = await request.json()
    except Exception:
        # Deliberately best-effort: an unreadable body is treated as empty.
        parsed = {}

    request_id = None
    if isinstance(parsed, dict):
        request_id = parsed.get("id")

    error = {
        "code": -32601,
        "message": "MCP tools are not implemented for this environment.",
    }
    return {"jsonrpc": "2.0", "id": request_id, "error": error}
@app.get("/tasks")
def tasks():
    # Lists a metadata summary for every entry in the environment's TASKS
    # mapping, in the mapping's iteration order.
    from server.environment import TASKS

    catalog = []
    for name, task in TASKS.items():
        catalog.append(_task_metadata(name, task))
    return {"tasks": catalog}
@app.post("/config/custom")
def set_custom_task(config: CustomTaskConfig):
    # Overwrites the diff, title and description of the module-level custom
    # task definition in place with the values from the request.
    from server.tasks import custom

    for field in ("diff", "pr_title", "pr_description"):
        custom.TASK[field] = getattr(config, field)
    return {"status": "success"}
@app.post("/reset", response_model=PRObservation)
def reset(req: ResetRequest = ResetRequest()):
    # Starts a new episode for the requested task and returns whatever
    # ``env.reset`` produces. Responds 404 when the task name is not a key
    # of TASKS.
    from server.environment import TASKS

    if req.task_name in TASKS:
        return env.reset(task_name=req.task_name)
    raise HTTPException(status_code=404, detail=f"Unknown task: {req.task_name}")
@app.post("/step", response_model=StepResponse)
async def step(request: Request):
    # Applies one action to the current episode.
    # Responds 400 when there is no episode state or the episode is already
    # done, and 400 when the payload cannot be normalised into an action.
    # The episode check runs before the body is read.
    episode_unavailable = env._state is None or env._state.done
    if episode_unavailable:
        raise HTTPException(status_code=400, detail="Call /reset first or episode is done.")

    raw_action = await _read_payload(request)
    try:
        action = normalize_action_payload(raw_action)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid action payload: {exc}") from exc

    outcome = env.step(action)
    obs, reward, done, info = outcome
    return StepResponse(observation=obs, reward=reward, done=done, info=info)
@app.get("/state", response_model=PRState)
def state():
    # Returns the environment's current state, or 400 when no episode has
    # been started yet.
    if env._state is not None:
        return env.state()
    raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")
@app.post("/diff")
async def generate_diff(payload: dict):
    # Builds a unified diff between two code snippets.
    #
    # Payload keys (all optional):
    #   old_code - original text, defaults to "".
    #   new_code - updated text, defaults to "".
    #   filename - name used in the "a/..." and "b/..." headers,
    #              defaults to "file.py".
    #
    # Returns {"diff": <unified diff text>}; the text is empty when the two
    # snippets are identical. Responds 400 when old_code or new_code is not
    # a string.
    import difflib

    old_code = payload.get("old_code", "")
    new_code = payload.get("new_code", "")
    filename = payload.get("filename", "file.py")

    # The payload is an untyped dict, so a client can send null or a number;
    # reject that explicitly rather than failing on .splitlines below.
    for field, value in (("old_code", old_code), ("new_code", new_code)):
        if not isinstance(value, str):
            raise HTTPException(status_code=400, detail=f"'{field}' must be a string.")

    diff_lines = difflib.unified_diff(
        old_code.splitlines(keepends=True),
        new_code.splitlines(keepends=True),
        fromfile=f"a/{filename}",
        tofile=f"b/{filename}",
    )

    rendered = []
    for line in diff_lines:
        rendered.append(line)
        # A diff line equal to its own splitlines() result carries no line
        # terminator, which happens when a snippet's final line has no
        # trailing newline. Joined as-is it would fuse with the next diff
        # line ("-a+b"), so terminate it and add the marker that diff and
        # patch tools use for an incomplete last line.
        if line.splitlines() == [line]:
            rendered.append("\n\\ No newline at end of file\n")
    return {"diff": "".join(rendered)}
def main():
    """Serve ``app`` with uvicorn on all interfaces.

    The port is read from the ``PORT`` environment variable and defaults to
    8000 when it is unset.
    """
    import os

    import uvicorn

    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)


if __name__ == "__main__":
    main()
|