File size: 6,795 Bytes
1175c0b
 
 
19f7f7b
 
 
 
 
 
 
 
 
 
1175c0b
 
 
 
19f7f7b
 
1175c0b
19f7f7b
1175c0b
 
 
 
 
19f7f7b
1175c0b
 
 
 
 
19f7f7b
 
 
1175c0b
 
 
 
19f7f7b
 
 
1175c0b
 
 
 
 
 
19f7f7b
1175c0b
 
19f7f7b
 
 
 
1175c0b
 
19f7f7b
 
 
 
1175c0b
 
 
 
 
 
 
 
 
 
 
 
9ef85f9
1175c0b
 
19f7f7b
 
 
 
 
 
1175c0b
9ef85f9
 
 
 
 
 
19f7f7b
 
 
 
 
1175c0b
19f7f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1175c0b
 
 
 
19f7f7b
 
 
1175c0b
19f7f7b
 
1175c0b
 
 
 
19f7f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
1175c0b
19f7f7b
 
 
 
 
1175c0b
19f7f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1175c0b
 
 
 
 
19f7f7b
1175c0b
 
19f7f7b
 
 
 
1175c0b
19f7f7b
 
1175c0b
19f7f7b
bbbb404
19f7f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbbb404
 
 
19f7f7b
bbbb404
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""
Thin FastAPI server — marshals JSON in/out.
No simulation logic lives here.

Endpoints:
  GET  /health                      health check
  GET  /tasks                       list available scenarios
  GET  /pools                       list training pools and their task names
  POST /reset    {task_name, seed, pool, mode}
                                    start a new episode (all fields optional)
  POST /step     {action_type, target_service, parameters}
                                    execute one action (phase-aware)
  GET  /state                       per-episode metadata
  GET  /trajectory                  full P1+P2 step records
  POST /score    {declared_patch, declared_no_change, belief_history}
                                    unified grader breakdown plus r_cross
"""

from __future__ import annotations

from dataclasses import asdict
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from .incident_environment import IncidentEnvironment


# ------------------------------------------------------------------
# App
# ------------------------------------------------------------------

# The ASGI application object; the route decorators in this module attach to it.
app = FastAPI(
    title="SRE Incident Response Environment",
    description="Two-phase OpenEnv environment (P1 ops + P2 code attribution).",
    version="0.2.0",
)

# CORS is fully open: every origin, method, and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# One environment instance is created at import time and referenced by the
# handlers in this module, so all requests operate on the same object.
env = IncidentEnvironment()


# ------------------------------------------------------------------
# Request models
# ------------------------------------------------------------------

class StepRequest(BaseModel):
    # JSON body accepted by POST /step: one agent action.
    # Only ``action_type`` is required; the other two fields have defaults.
    action_type: str
    # Optional target of the action; None when the client omits it.
    target_service: Optional[str] = None
    # Free-form action arguments; defaults to an empty mapping.
    parameters: Dict[str, Any] = {}


class ScoreRequest(BaseModel):
    # JSON body accepted by POST /score. Every field is optional.
    # Patch text declared by the client, if any.
    declared_patch: Optional[str] = None
    # Flag declaring that no code change is needed; defaults to False.
    declared_no_change: bool = False
    # List of belief-state dicts; defaults to an empty list.
    belief_history: List[Dict[str, Any]] = []


# ------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------

@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe: always answers with a fixed ``healthy`` status."""
    return dict(status="healthy")


@app.post("/reset")
async def reset(request: Request) -> Dict[str, Any]:
    """
    Start a new incident episode.

    The JSON body is optional, and so is every key in it:
        task_name : str       specific scenario, otherwise sampled from pool
        seed      : int       RNG seed for deterministic replay
        pool      : "A"|"B"|"C"|"D"   selects training pool (sets default mode)
        mode      : "p1_only"|"p2_only"|"joint"   force episode mode

    A missing, unparsable, or non-object body is treated as ``{}``, so all
    four options are passed to the environment as ``None``.

    Returns:
        Whatever ``env.reset`` returns.
    """
    try:
        payload = await request.json()
    except Exception:
        # No body or malformed JSON: fall back to defaults instead of failing.
        payload = {}

    # Valid JSON that is not an object (list, number, ...) is ignored as well.
    options = payload if isinstance(payload, dict) else {}

    reset_kwargs = {
        key: options.get(key) for key in ("task_name", "seed", "pool", "mode")
    }
    return env.reset(**reset_kwargs)


@app.get("/pools")
def list_pools() -> Dict[str, Any]:
    """
    Describe every registered pool so training runners can discover task names.

    Returns:
        A mapping from each key in ``POOLS`` to a dict holding that pool's
        name, description, task names (as a list), mode, and
        ``inject_oracle_belief`` setting.
    """
    # Imported lazily, at call time rather than at module import.
    from ..pools import POOLS

    registry: Dict[str, Any] = {}
    for key, pool in POOLS.items():
        registry[key] = {
            "name": pool.name,
            "description": pool.description,
            # Copied into a plain list so the value is JSON-serializable.
            "task_names": list(pool.task_names),
            "mode": pool.mode,
            "inject_oracle_belief": pool.inject_oracle_belief,
        }
    return registry


@app.post("/step")
def step(request: StepRequest) -> Dict[str, Any]:
    """
    Run a single agent action through the environment (phase-aware dispatch).

    The validated request is flattened into a plain dict and handed to
    ``env.step``; its return value is sent back unchanged.
    """
    action = dict(
        action_type=request.action_type,
        target_service=request.target_service,
        # A falsy ``parameters`` value (e.g. an explicit null) becomes {}.
        parameters=request.parameters or {},
    )
    return env.step(action)


@app.get("/state")
def state() -> Dict[str, Any]:
    """Return the environment's current state exactly as ``env.get_state()`` reports it."""
    snapshot = env.get_state()
    return snapshot


@app.get("/trajectory")
def trajectory() -> Dict[str, Any]:
    """
    Return the current episode's step records for both phases.

    Returns:
        ``{"p1": [...], "p2": [...]}`` where each list holds the JSON-safe
        form of the records from ``env.get_p1_trajectory()`` and
        ``env.get_p2_trajectory()`` respectively.
    """
    p1_steps = list(map(_serialize_step, env.get_p1_trajectory()))
    p2_steps = list(map(_serialize_step, env.get_p2_trajectory()))
    return {"p1": p1_steps, "p2": p2_steps}


@app.post("/score")
def score(req: ScoreRequest) -> Dict[str, Any]:
    """
    Unified grader breakdown + counterfactual r_cross.

    Args:
        req: Scoring request. Only ``req.belief_history`` is read here; the
            declared patch and no-change flag are taken from the
            environment state instead.

    Returns:
        The dict returned by ``env.score_unified`` (expected keys: final,
        p1_rca, p1_efficiency, patch_quality, no_change_detection,
        p2_efficiency) extended with:
            r_cross               : float rounded to 4 decimals
            null_context_p2_score : float rounded to 4 decimals
        Both extras are 0.0 when the state has no task name, and keep the
        value reached so far if their computation raises.
    """
    import logging

    from ..tasks import compute_r_cross

    breakdown = env.score_unified(belief_history=req.belief_history)

    # Named ``episode_state`` so it does not shadow the ``state`` endpoint.
    episode_state = env.get_state()
    task = episode_state.get("task_name")

    # NOTE(review): req.declared_patch / req.declared_no_change are never
    # used; the values come from the environment state. Confirm this is
    # intended.
    r_cross = 0.0
    null_baseline = 0.0
    if task:
        try:
            r_cross = compute_r_cross(
                task_name          = task,
                declared_patch     = episode_state.get("declared_patch"),
                declared_no_change = bool(episode_state.get("declared_no_change")),
                p2_trajectory      = env.get_p2_trajectory(),
            )
            from ..tasks import get_scenario
            ctx = get_scenario(task).code_context
            if ctx is not None:
                null_baseline = float(ctx.null_context_p2_score)
        except Exception:
            # Best-effort extras: the grader breakdown is still returned,
            # but the failure is logged rather than silently discarded.
            logging.getLogger(__name__).exception(
                "Failed to compute r_cross / null-context baseline for task %r",
                task,
            )
    return {
        **breakdown,
        "r_cross":               round(r_cross, 4),
        "null_context_p2_score": round(null_baseline, 4),
    }


@app.get("/tasks")
def list_tasks() -> Dict[str, Any]:
    """
    List every registered scenario with its headline attributes.

    Each class in ``TASK_REGISTRY`` is instantiated with no arguments and
    summarized.

    Returns:
        ``{"tasks": {name: summary}}`` where each summary holds the display
        name, severity, step and time budgets, whether a code context is
        attached (``has_phase2``), and the fault class.
    """
    from ..tasks import TASK_REGISTRY

    def _summarize(scenario_cls: Any) -> Dict[str, Any]:
        # Build one throwaway instance to read its descriptive attributes.
        scenario = scenario_cls()
        return {
            "display_name": scenario.display_name,
            "severity": scenario.severity,
            "max_steps": scenario.max_steps,
            "time_budget_minutes": scenario.time_budget_minutes,
            "has_phase2": scenario.code_context is not None,
            "fault_class": scenario.fault_class,
        }

    catalog = {name: _summarize(cls) for name, cls in TASK_REGISTRY.items()}
    return {"tasks": catalog}


# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------

def _serialize_step(r) -> Dict[str, Any]:
    """
    Flatten a step record into a plain dict.

    Args:
        r: Record exposing ``step_number``, ``phase``, ``action`` (itself
            exposing ``action_type``, ``target_service``, ``parameters``),
            ``reward``, ``observation_summary``, ``service_statuses_after``,
            ``timestamp_minutes`` and ``belief_state_snapshot``.

    Returns:
        A dict with those attributes under same-named keys, with the action
        nested as its own dict. Attribute values are passed through as-is.
    """
    payload: Dict[str, Any] = {
        "step_number": r.step_number,
        "phase": r.phase,
    }

    act = r.action
    payload["action"] = {
        key: getattr(act, key)
        for key in ("action_type", "target_service", "parameters")
    }

    # The remaining fields are copied verbatim, in output order.
    for field_name in (
        "reward",
        "observation_summary",
        "service_statuses_after",
        "timestamp_minutes",
        "belief_state_snapshot",
    ):
        payload[field_name] = getattr(r, field_name)
    return payload


def main() -> None:
    """Serve the app with uvicorn on 0.0.0.0:8000, auto-reload disabled."""
    from uvicorn import run as serve

    serve(
        "incident_env.server.app:app",
        host="0.0.0.0",
        port=8000,
        reload=False,
    )


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()