File size: 3,289 Bytes
96939ad
 
d416acc
 
 
 
 
96939ad
d416acc
 
 
96939ad
d416acc
 
96939ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d416acc
96939ad
d416acc
96939ad
 
 
 
 
 
 
 
 
 
 
 
 
 
d416acc
 
 
 
96939ad
 
 
 
 
 
 
 
 
 
d416acc
 
 
 
96939ad
 
9fecec8
 
96939ad
 
 
 
 
 
 
 
 
 
9fecec8
96939ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d416acc
 
 
 
 
96939ad
d416acc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import importlib
import yaml

from fastapi import FastAPI
from pydantic import BaseModel
from environment.api_triage_env import APITriageEnv

# The FastAPI application, plus the single module-level environment instance
# that every endpoint in this module reads and mutates.
# NOTE(review): the endpoints are plain `def` functions, which FastAPI runs in
# a threadpool, so concurrent requests could touch `env` at the same time;
# presumably one client drives an episode at a time — confirm.
app = FastAPI(version="1.0.0", title="API Triage Agent")
env = APITriageEnv()

class ActionRequest(BaseModel):
    """JSON request body accepted by ``POST /step``.

    ``action`` is the agent's action, passed verbatim to ``env.step``.
    """

    action: str


def _load_tasks(path: str = "openenv.yaml") -> list:
    """Load the task definitions listed under ``tasks`` in an OpenEnv YAML file.

    Args:
        path: Location of the YAML config. The default ``"openenv.yaml"`` is
            resolved against the current working directory, as before.

    Returns:
        The ``tasks`` value from the config, or ``[]`` when the file is empty
        or the ``tasks`` key is missing or null.
    """
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # safe_load yields None for an empty document, and a bare `tasks:` line
    # parses to None as well; normalise both to an empty list so callers can
    # always iterate the result.
    return (cfg or {}).get("tasks") or []


@app.get("/")
def root():
    """Identify the service: report an ``ok`` status and the environment name."""
    payload = {"status": "ok"}
    payload["environment"] = "api-triage-agent"
    return payload


@app.get("/health")
def health():
    """Health check: unconditionally reports ``healthy`` while the app is serving."""
    return dict(status="healthy")


@app.post("/reset")
def reset():
    """Reset the environment and return its initial state.

    The state is wrapped in an ``observation`` envelope. ``reward`` is
    ``None`` and ``done`` is ``False`` because no action has been taken yet.
    """
    initial = env.reset()
    # Copy the public state fields, in a fixed order, into a plain dict.
    observation = {
        name: getattr(initial, name)
        for name in (
            "step",
            "max_steps",
            "incident_summary",
            "logs",
            "response_code",
            "fix_applied",
            "is_resolved",
        )
    }
    return {"observation": observation, "reward": None, "done": False}


@app.get("/state")
def state():
    """Return the environment's current state as a flat JSON object.

    Unlike ``/reset`` and ``/step``, the fields are returned directly rather
    than inside an ``observation`` envelope, and no reward or done flag is included.
    """
    snapshot = env.state()
    payload = {}
    for field in (
        "step",
        "max_steps",
        "incident_summary",
        "logs",
        "response_code",
        "fix_applied",
        "is_resolved",
    ):
        payload[field] = getattr(snapshot, field)
    return payload


@app.post("/step")
def step(request: ActionRequest):
    """Apply the requested action to the environment and report the outcome.

    The raw reward from ``env.step`` is divided by 20.5 and then clamped to
    the range [0.001, 0.999], keeping it strictly inside (0, 1) for OpenEnv
    compliance. ``info`` is forwarded unchanged.
    """
    obs, raw_reward, done, info = env.step(request.action)

    # Scale first, then pin to the bounds (same result as the nested
    # min(max(...)) form, including pass-through of NaN).
    reward = raw_reward / 20.5
    if reward < 0.001:
        reward = 0.001
    elif reward > 0.999:
        reward = 0.999

    observation_payload = {
        "step": obs.step,
        "max_steps": obs.max_steps,
        "incident_summary": obs.incident_summary,
        "logs": obs.logs,
        "response_code": obs.response_code,
        "fix_applied": obs.fix_applied,
        "is_resolved": obs.is_resolved,
    }
    return {
        "observation": observation_payload,
        "reward": reward,
        "done": done,
        "info": info,
    }


@app.get("/tasks")
def list_tasks():
    """List every task from openenv.yaml along with the grader it uses.

    Each task is reduced to its id, name, description, difficulty and grader.
    Any other keys are dropped, and a task missing one of these five keys
    raises KeyError.
    """
    exposed_keys = ("id", "name", "description", "difficulty", "grader")
    return {
        "tasks": [
            {key: task[key] for key in exposed_keys}
            for task in _load_tasks()
        ]
    }


@app.post("/grade/{task_id}")
def grade_task(task_id: str):
    """Run the grader for a specific task and return the score.

    The task is looked up in openenv.yaml (via ``_load_tasks``). Its
    ``grader`` field must be a ``"module.path:function"`` reference; that
    function is imported and called with no arguments.

    Returns:
        ``{"task_id": ..., "score": ...}`` on success. If the task is unknown,
        or its grader reference is malformed or cannot be imported, returns
        ``{"error": <message>, "score": 0.0}``. This is the same shape the
        unknown-task case already used, so a bad config entry no longer
        surfaces as an HTTP 500.
    """
    tasks = _load_tasks()
    # The path parameter is always a str, but YAML may parse ids such as
    # `1` as int, so compare string forms. Skip entries that lack an id.
    task = next(
        (t for t in tasks if "id" in t and str(t["id"]) == task_id),
        None,
    )
    if task is None:
        return {"error": f"Task '{task_id}' not found", "score": 0.0}

    grader_ref = task.get("grader")
    # rpartition never raises, unlike unpacking rsplit(":", 1) on a ref
    # that has no colon.
    module_path, sep, func_name = str(grader_ref or "").rpartition(":")
    if not (sep and module_path and func_name):
        return {
            "error": f"Task '{task_id}' has an invalid grader reference: {grader_ref!r}",
            "score": 0.0,
        }

    try:
        grade_fn = getattr(importlib.import_module(module_path), func_name)
    except (ImportError, AttributeError) as exc:
        return {
            "error": f"Could not load grader '{grader_ref}' for task '{task_id}': {exc}",
            "score": 0.0,
        }

    # Exceptions raised by the grader itself are deliberately not caught:
    # they indicate a grading bug rather than a config problem.
    score = grade_fn()
    return {"task_id": task_id, "score": score}


def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve this module's ``app`` with uvicorn.

    Args:
        host: Interface to bind. Defaults to all interfaces.
        port: TCP port to listen on. Defaults to 7860.

    The ``app`` object is passed directly instead of the ``"app:app"`` import
    string. With the string form, uvicorn imports a top-level module named
    ``app``. Under ``python app.py`` that re-executes this file a second time,
    building a second FastAPI app and APITriageEnv. It also fails outright
    when no top-level ``app`` module is importable, e.g. when this file is
    loaded as part of a package.
    """
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()