File size: 4,985 Bytes
dcee3d3
13507d6
dcee3d3
 
 
13507d6
dcee3d3
13507d6
 
dcee3d3
 
 
 
 
13507d6
 
 
 
 
0c8a432
 
 
 
 
13507d6
 
 
 
 
 
dcee3d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c8a432
 
 
 
dcee3d3
0c8a432
 
 
13507d6
 
dcee3d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13507d6
0c8a432
 
 
 
 
 
 
 
13507d6
 
dcee3d3
 
 
13507d6
 
 
dcee3d3
13507d6
 
dcee3d3
 
 
 
 
 
13507d6
 
 
 
 
 
 
b358ab7
aa10350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b358ab7
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import json
from typing import Optional

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from models import PRAction, PRObservation, PRState
from server.action_normalizer import normalize_action_payload
from server.environment import PRReviewEnvironment

# FastAPI application exposing the PR review environment over HTTP.
app = FastAPI(
    title="PR Review Negotiation Environment",
    version="1.0.0",
    description="A multi-turn pull request review negotiation benchmark for engineering judgment.",
)
# Single process-wide environment instance. /reset, /step and /state all read
# and mutate this one object, so every client shares the same episode.
# NOTE(review): fine for a single-agent benchmark; confirm before serving
# several clients concurrently.
env = PRReviewEnvironment()

class ResetRequest(BaseModel):
    # Body of POST /reset. Omitting task_name selects "single-pass-review".
    # Because the type is Optional, an explicit null is accepted here and is
    # then checked against the task registry by /reset like any other name.
    task_name: Optional[str] = "single-pass-review"

class CustomTaskConfig(BaseModel):
    # Body of POST /config/custom: the diff to review (required) plus an
    # optional PR title and description. The Optional types mean an explicit
    # null is accepted and copied into the custom task as None instead of
    # being replaced by these defaults.
    diff: str
    pr_title: Optional[str] = "Custom Review Session"
    pr_description: Optional[str] = "User-provided code snippet for review."

class StepResponse(BaseModel):
    # Response model for POST /step: the (observation, reward, done, info)
    # tuple returned by env.step, packaged as named fields.
    observation: PRObservation
    reward: float
    done: bool
    info: dict


def _task_metadata(task_name: str, task: dict) -> dict:
    return {
        "name": task_name,
        "pr_title": task.get("pr_title", task_name),
        "pr_description": task.get("pr_description", ""),
        "max_turns": task.get("max_turns", 1),
    }


def _model_schema(model: type[BaseModel]) -> dict:
    if hasattr(model, "model_json_schema"):
        return model.model_json_schema()
    return model.schema()


async def _read_payload(request: Request):
    body = await request.body()
    if not body:
        raise HTTPException(status_code=400, detail="Request body is required.")
    try:
        return json.loads(body)
    except json.JSONDecodeError:
        return body.decode("utf-8", errors="replace")

@app.get("/")
def index():
    """Landing route: confirm the API is up and point at the dashboard and docs."""
    return dict(
        message="Backend API is running!",
        action="Visit the dashboard at http://localhost:3000 locally or http://localhost:7860 in Docker.",
        api_docs="/docs",
    )

@app.get("/health")
def health():
    """Liveness probe: always reports healthy while the process is serving."""
    status = "healthy"
    return dict(status=status)

@app.get("/metadata")
def metadata():
    """Describe this environment: name, description, version and author."""
    return dict(
        name="pr-review-env",
        description=(
            "A multi-turn pull request review negotiation benchmark for root-cause depth, "
            "false-fix resistance, and escalation judgment."
        ),
        version="1.0.0",
        author="Levi710",
    )

@app.get("/schema")
def schema():
    """Expose the JSON schemas of the action, observation and state models."""
    sections = (
        ("action", PRAction),
        ("observation", PRObservation),
        ("state", PRState),
    )
    return {section: _model_schema(model) for section, model in sections}

@app.post("/mcp")
async def mcp(request: Request):
    """Placeholder MCP endpoint that answers every call with a JSON-RPC error.

    The request ``id`` is echoed back when the body is a JSON object that
    carries one. Unreadable or non-object bodies get ``"id": None``. The
    error code -32601 is JSON-RPC 2.0's "method not found".
    """
    request_id = None
    try:
        body = await request.json()
    except Exception:
        body = None
    if isinstance(body, dict):
        request_id = body.get("id")
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {
            "code": -32601,
            "message": "MCP tools are not implemented for this environment.",
        },
    }

@app.get("/tasks")
def tasks():
    """List every registered task using the summary fields of ``_task_metadata``."""
    from server.environment import TASKS
    listing = [_task_metadata(*entry) for entry in TASKS.items()]
    return {"tasks": listing}

@app.post("/config/custom")
def set_custom_task(config: CustomTaskConfig):
    """Overwrite the custom task's diff, title and description in place.

    The shared ``custom.TASK`` mapping is mutated directly. Presumably the
    environment reads it when the custom task is next reset; confirm in
    server.tasks. Optional fields sent as null are stored as None.
    """
    from server.tasks import custom
    overrides = {
        "diff": config.diff,
        "pr_title": config.pr_title,
        "pr_description": config.pr_description,
    }
    for key, value in overrides.items():
        custom.TASK[key] = value
    return {"status": "success"}

@app.post("/reset", response_model=PRObservation)
def reset(req: ResetRequest = ResetRequest()):
    """Start a new episode for ``req.task_name`` and return its first observation.

    Raises:
        HTTPException: 404 when the task name is not a key of TASKS.
    """
    from server.environment import TASKS
    task_name = req.task_name
    if task_name in TASKS:
        return env.reset(task_name=task_name)
    raise HTTPException(status_code=404, detail=f"Unknown task: {task_name}")

@app.post("/step", response_model=StepResponse)
async def step(request: Request):
    """Apply one action to the active episode and report the outcome.

    The body may be JSON or plain text. Either form is handed to
    ``normalize_action_payload`` to build the action.

    Raises:
        HTTPException: 400 in any of these cases: no episode is active, the
            episode has finished, the body is empty, or the payload cannot
            be normalized.
    """
    current = env._state
    episode_open = current is not None and not current.done
    if not episode_open:
        raise HTTPException(status_code=400, detail="Call /reset first or episode is done.")
    payload = await _read_payload(request)
    try:
        action = normalize_action_payload(payload)
    except Exception as exc:
        # Any normalizer failure is reported as malformed client input.
        raise HTTPException(status_code=400, detail=f"Invalid action payload: {exc}") from exc
    observation, reward, done, info = env.step(action)
    return StepResponse(observation=observation, reward=reward, done=done, info=info)

@app.get("/state", response_model=PRState)
def state():
    """Return the state of the current episode.

    Raises:
        HTTPException: 400 before the first /reset.
    """
    if env._state is not None:
        return env.state()
    raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")

@app.post("/diff")
async def generate_diff(payload: dict):
    """Build a unified diff between two versions of a file.

    Reads ``old_code``, ``new_code`` and ``filename`` from ``payload``. Their
    defaults are ``""``, ``""`` and ``"file.py"``, and an explicit ``null``
    is treated like a missing key. Returns ``{"diff": str}`` with
    ``a/``/``b/`` path prefixes. The diff is empty when the inputs match.

    When either input does not end with a line break, its final diff line is
    terminated and followed by the ``\\ No newline at end of file`` marker
    that git and patch use. This keeps every diff line on its own line.

    Raises:
        HTTPException: 400 when ``old_code`` or ``new_code`` is not a string.
    """
    import difflib

    old_code = payload.get("old_code")
    new_code = payload.get("new_code")
    filename = payload.get("filename")
    if old_code is None:
        old_code = ""
    if new_code is None:
        new_code = ""
    if filename is None:
        filename = "file.py"
    if not isinstance(old_code, str) or not isinstance(new_code, str):
        raise HTTPException(status_code=400, detail="old_code and new_code must be strings.")

    diff_lines = difflib.unified_diff(
        old_code.splitlines(keepends=True),
        new_code.splitlines(keepends=True),
        fromfile=f"a/{filename}",
        tofile=f"b/{filename}",
    )
    parts = []
    for line in diff_lines:
        parts.append(line)
        # unified_diff copies input lines verbatim. splitlines() strips a
        # line's own terminator, so an unchanged result means this line had
        # none: it was the last line of an input without a trailing newline.
        # Without this marker it would run into the next line ("-old+new").
        if line.splitlines()[0] == line:
            parts.append("\n\\ No newline at end of file\n")
    return {"diff": "".join(parts)}

def main():
    """Serve ``app`` with uvicorn on 0.0.0.0, on port $PORT (default 8000)."""
    import os
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))

# Running this module as a script starts the HTTP server via main().
if __name__ == "__main__":
    main()