File size: 5,677 Bytes
5fe9036
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c1d75c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b562fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fe9036
 
 
 
 
 
 
1c1d75c
5fe9036
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""
FastAPI server for the SRE Incident Response OpenEnv environment.
"""

import sys
import os

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Dict, List, Optional

from models import Action, Observation, State
from env.environment import IncidentResponseEnv
from tasks import SCENARIOS

# Application object; main() hands it to uvicorn.
app = FastAPI(
    version="1.0.0",
    title="SRE Incident Response Environment",
    description=(
        "An OpenEnv environment for training AI agents on production incident response."
    ),
)

# One environment instance shared by every request; individual episodes are
# told apart by the session ids it returns from reset().
env = IncidentResponseEnv()


# ── Request/Response models ────────────────────────────────────────────

class ResetRequest(BaseModel):
    """Request body for the reset endpoints: which task to start and with what seed."""

    # Task identifier forwarded to env.reset; presumably one of the keys of
    # SCENARIOS (as listed by /tasks). If env.reset raises ValueError for it,
    # the endpoint answers 400.
    task_id: str = "easy"
    # Seed forwarded to env.reset; presumably controls scenario randomness — TODO confirm.
    seed: int = 0


class ResetResponse(BaseModel):
    """Response of the reset endpoints: the first observation plus the new session id."""

    # Observation returned by env.reset for the freshly started episode.
    observation: Observation
    # Id to send back as StepRequest.session_id or in /state/{session_id}.
    session_id: str


class StepRequest(BaseModel):
    """Request body for the step endpoints: apply ``action`` inside an existing session."""

    # Session id obtained from a previous reset call.
    session_id: str
    # Agent action forwarded unchanged to env.step.
    action: Action


class StepResponse(BaseModel):
    """Result of one step, mirroring the (observation, reward, done, info) tuple of env.step."""

    observation: Observation
    reward: float
    # Presumably True once the episode has ended — TODO confirm against env.step.
    done: bool
    # Extra per-step details produced by env.step; keys are environment-defined.
    info: Dict


class TaskInfo(BaseModel):
    """Summary of one task scenario, as returned by the /tasks endpoints."""

    # Key of the scenario in SCENARIOS; presumably accepted as ResetRequest.task_id.
    task_id: str
    name: str
    difficulty: str
    # Copied from the scenario's max_steps.
    max_steps: int
    # Copied from the scenario's incident_summary.
    description: str


# ── OpenEnv spec endpoints ─────────────────────────────────────────────

@app.get("/health")
def health():
    """Liveness probe; unconditionally reports the service as healthy."""
    status_payload = {"status": "healthy"}
    return status_payload


@app.get("/metadata")
def metadata():
    """Return static descriptive metadata (name, description, version) for this environment."""
    return {
        "name": "sre-incident-response",
        # "\u2014" is an em dash. It is written as an escape so the literal cannot
        # be corrupted again by a wrong source-file encoding. A mis-decoded
        # UTF-8 sequence would be sent to clients verbatim.
        "description": (
            "SRE Incident Response environment \u2014 "
            "train AI agents to diagnose and fix production incidents"
        ),
        "version": "1.0.0",
    }


@app.get("/schema")
def schema():
    """Expose the JSON Schemas of the Action, Observation and State models."""
    named_models = (
        ("action", Action),
        ("observation", Observation),
        ("state", State),
    )
    return {key: model.model_json_schema() for key, model in named_models}


@app.get("/state")
def state_no_session():
    """Return the State of the last key in ``env.sessions``, or a default State.

    The last key is presumably the most recently created session, relying on
    dict insertion order — confirm against the environment implementation.
    """
    if not env.sessions:
        return State()
    *_, newest_session_id = env.sessions.keys()
    return env.state(newest_session_id)


@app.post("/mcp")
def mcp_endpoint(body: Optional[dict] = None):
    """Minimal MCP JSON-RPC endpoint for OpenEnv spec compliance.

    Answers ``initialize`` with protocol and server info. Every other method
    gets an empty ``result``.

    Args:
        body: Parsed JSON-RPC request object. Optional; a missing body is
            treated as an empty request.

    Returns:
        A JSON-RPC 2.0 response dict echoing the request ``id`` (1 if absent).
    """
    # A None sentinel avoids a shared mutable ``{}`` default argument.
    request = body if body is not None else {}
    method = request.get("method", "")
    req_id = request.get("id", 1)
    if method == "initialize":
        result = {
            "protocolVersion": "2024-11-05",
            "serverInfo": {"name": "sre-incident-response", "version": "1.0.0"},
            "capabilities": {},
        }
    else:
        # Any other method is acknowledged with an empty result.
        result = {}
    return {
        "jsonrpc": "2.0",
        "id": req_id,
        "result": result,
    }


# ── Endpoints ──────────────────────────────────────────────────────────

@app.get("/")
def root():
    """Describe the service and list its primary endpoints."""
    advertised_endpoints = [
        "/reset",
        "/step",
        "/state/{session_id}",
        "/tasks",
        "/health",
        "/metadata",
        "/schema",
    ]
    return {
        "name": "SRE Incident Response Environment",
        "version": "1.0.0",
        "endpoints": advertised_endpoints,
    }


@app.post("/reset", response_model=ResetResponse)
def reset(request: ResetRequest):
    """Start a new episode and return its first observation and session id.

    Raises:
        HTTPException: 400 when ``env.reset`` raises ``ValueError``
            (presumably e.g. an unknown ``task_id``).
    """
    # Only env.reset is guarded. pydantic's ValidationError subclasses
    # ValueError, so also wrapping the response construction would report a
    # server-side bug as a 400 client error.
    try:
        obs, session_id = env.reset(task_id=request.task_id, seed=request.seed)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return ResetResponse(observation=obs, session_id=session_id)


@app.post("/step", response_model=StepResponse)
def step(request: StepRequest):
    """Apply one action to an existing session and return the resulting transition.

    Raises:
        HTTPException: 400 when ``env.step`` raises ``ValueError``
            (presumably e.g. an unknown session id — TODO confirm).
    """
    # Only env.step is guarded. pydantic's ValidationError subclasses
    # ValueError, so also wrapping StepResponse construction would report a
    # server-side bug as a 400 client error.
    try:
        obs, reward, done, info = env.step(request.session_id, request.action)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    # Shallow copy so the response does not alias the environment's own info
    # dict. Note this does not convert values to JSON-safe types.
    return StepResponse(observation=obs, reward=reward, done=done, info=dict(info))


@app.get("/state/{session_id}", response_model=State)
def state(session_id: str):
    """Return the current State of the given session.

    Raises:
        HTTPException: 404 when ``env.state`` raises ``ValueError`` for the id.
    """
    try:
        current = env.state(session_id)
    except ValueError as err:
        raise HTTPException(status_code=404, detail=str(err))
    return current


@app.get("/tasks", response_model=List[TaskInfo])
def tasks():
    """List every scenario in SCENARIOS as a TaskInfo, in SCENARIOS' iteration order."""
    return [
        TaskInfo(
            task_id=task_key,
            name=spec.name,
            difficulty=spec.difficulty,
            max_steps=spec.max_steps,
            description=spec.incident_summary,
        )
        for task_key, spec in SCENARIOS.items()
    ]


# ── OpenEnv-prefixed aliases ───────────────────────────────────────────

@app.post("/openenv/reset", response_model=ResetResponse)
def openenv_reset(request: ResetRequest):
    """Alias of POST /reset under the /openenv prefix; delegates to reset()."""
    response = reset(request)
    return response


@app.post("/openenv/step", response_model=StepResponse)
def openenv_step(request: StepRequest):
    """Alias of POST /step under the /openenv prefix; delegates to step()."""
    response = step(request)
    return response


@app.get("/openenv/state/{session_id}", response_model=State)
def openenv_state(session_id: str):
    """Alias of GET /state/{session_id} under the /openenv prefix; delegates to state()."""
    session_state = state(session_id)
    return session_state


@app.get("/openenv/tasks", response_model=List[TaskInfo])
def openenv_tasks():
    """Alias of GET /tasks under the /openenv prefix; delegates to tasks()."""
    task_list = tasks()
    return task_list


# ── Main ───────────────────────────────────────────────────────────────

def main():
    """Serve ``app`` with uvicorn on all interfaces; the port comes from $PORT (default 8000)."""
    import uvicorn

    listen_port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")


# Start the server only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()