File size: 5,688 Bytes
3b074e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from __future__ import annotations
import json
from typing import Any, Dict, Optional
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Request
from openenv_state import OPENENV_STATE, OpenEnvState
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from models import MLOpsAction, MLOpsObservation, MLOpsState
from mlops_environment import MLOpsEnvironment

app = FastAPI(
    title="MLOps Pipeline Debugger",
    version="1.0.0",
    description="OpenEnv environment: AI agent diagnoses broken ML training runs.",
)

# Permissive CORS: any origin, method and header is allowed to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Environment backing the plain-HTTP endpoints (/reset, /step, /state).
# It is module-global, so all HTTP clients share one episode; each POST /reset
# replaces it. WebSocket connections create their own per-connection env.
_http_env: Optional[MLOpsEnvironment] = None


class ResetRequest(BaseModel):
    """Schema describing the body of a reset request.

    The task may be named with ``task_id`` (default ``"easy"``) or with the
    alias ``task``; ``seed`` is optional.

    NOTE(review): the POST /reset handler in this module parses the raw JSON
    body itself and does not reference this model.
    """

    task_id: Optional[str] = "easy"
    seed: Optional[int] = None
    # Alias for ``task_id``; clients may send either key.
    task: Optional[str] = None


class StepResponse(BaseModel):
    """Response body of POST /step: the tuple returned by ``env.step``."""

    # Observation produced after the action was applied.
    observation: MLOpsObservation
    # Scalar reward for this step.
    reward: float
    # True once the episode has ended.
    done: bool
    # Free-form diagnostic information returned by the environment.
    info: Dict[str, Any]


@app.post("/reset", response_model=MLOpsObservation)
async def reset(request: Request):
    """Start a new HTTP episode, replacing the shared one, and return its first observation.

    The body is parsed leniently:

    * ``task_id`` or its alias ``task`` names the task; a missing or empty
      value falls back to ``"easy"``.
    * ``seed`` is passed through to ``env.reset``.

    An empty or unparseable body, or a JSON value that is not an object,
    is treated as ``{}``. The defaults therefore apply instead of failing.
    """
    global _http_env
    try:
        body = await request.json()
    except Exception:
        # Empty/malformed body: deliberately fall back to defaults.
        body = {}
    if not isinstance(body, dict):
        # A JSON array/string/number/null has no `.get`; previously this
        # raised AttributeError and surfaced as an HTTP 500.
        body = {}

    task_id = body.get("task_id") or body.get("task") or "easy"
    seed = body.get("seed")

    # Build and reset the new env locally so that a failure in reset() does
    # not leave a half-initialized environment installed as the session.
    env = MLOpsEnvironment(task_id=task_id)
    observation = env.reset(seed=seed)
    _http_env = env
    return observation


@app.get("/")
async def root():
    """Describe the API and list its routes.

    The ``endpoints`` map is maintained by hand; keep it in sync with the
    route decorators in this module (it previously omitted the WebSocket
    route).
    """
    return {
        "message": "MLOps Pipeline Debugger API",
        "version": "1.0.0",
        "docs": "This is an OpenEnv-compatible RL environment",
        "endpoints": {
            "GET /": "This message",
            "GET /health": "Health check",
            "GET /tasks": "List available tasks",
            "GET /openenv/state": "OpenEnv state",
            "POST /reset": "Start a new episode",
            "POST /step": "Take an action",
            "GET /state": "Get current state",
            "WS /ws": "WebSocket session (reset / step / state messages)",
        },
    }


@app.get("/health")
async def health():
    """Liveness probe: confirm the service is up and identify it."""
    payload = dict(
        status="ok",
        environment="mlops_debug_env",
        version="1.0.0",
    )
    return payload


@app.get("/openenv/state", response_model=OpenEnvState)
def openenv_state():
    """Return the ``OPENENV_STATE`` object imported from ``openenv_state``."""
    current = OPENENV_STATE
    return current


@app.get("/tasks")
async def list_tasks():
    """List the tasks a client can pass as ``task_id`` to reset.

    Each task's difficulty label equals its id. NOTE(review): the
    ``max_steps`` values are static metadata here; confirm that they match
    the limits enforced by ``MLOpsEnvironment``.
    """
    # (task_id / difficulty, human-readable name, max_steps)
    catalog = (
        ("easy", "Config Error Diagnosis", 20),
        ("medium", "Data Leakage Detection", 30),
        ("hard", "Silent Evaluation Bug", 40),
    )
    entries = [
        {"task_id": tid, "name": title, "difficulty": tid, "max_steps": limit}
        for tid, title, limit in catalog
    ]
    return {"tasks": entries}


@app.post("/step", response_model=StepResponse)
async def step(request: Request):
    """Apply one action to the shared HTTP episode.

    Accepted body shapes (the first matching key wins):

    * ``{"action_type": ..., ...}``: the action fields at the top level.
    * ``{"action": {...}}``: the action fields nested under ``"action"``.
    * ``{"message": ...}``: shorthand; the value becomes ``action_type``.

    Raises:
        HTTPException(400): no episode has been started via /reset.
        HTTPException(422): the body carries no action, or the action fields
            fail ``MLOpsAction`` construction.
        HTTPException(500): stepping the environment (or building the
            response) raised.
    """
    if _http_env is None:
        raise HTTPException(400, "Call /reset first.")

    try:
        body = await request.json()
    except Exception:
        # Empty/malformed body: handled below as "no action supplied".
        body = {}
    if not isinstance(body, dict):
        # A JSON array/string/number/null is not an action. Without this,
        # `"action" in body` does substring matching on a JSON string, and
        # `body["action"]` then fails with a confusing TypeError.
        body = {}

    action = None
    try:
        if "action_type" in body:
            action = MLOpsAction(**body)
        elif "action" in body:
            action = MLOpsAction(**body["action"])
        elif "message" in body:
            action = MLOpsAction(action_type=body["message"])
    except Exception as e:
        raise HTTPException(422, f"Invalid action: {e}") from e

    if action is None or action.action_type is None:
        raise HTTPException(422, "Field required: action_type")

    try:
        obs, reward, done, info = _http_env.step(action)
        return StepResponse(observation=obs, reward=reward, done=done, info=info)
    except Exception as e:
        raise HTTPException(500, f"Step error: {e}") from e


@app.get("/state", response_model=MLOpsState)
async def state():
    """Return the state of the shared HTTP episode (400 before any /reset)."""
    env = _http_env
    if env is not None:
        return env.state
    raise HTTPException(400, "Call /reset first.")


@app.websocket("/ws")
async def ws_endpoint(websocket: WebSocket):
    """Run an RL session over a WebSocket with a per-connection environment.

    Each text frame must be a JSON object with a ``"method"`` key:

    * ``{"method": "reset", "task_id"|"task": ..., "seed": ...}``: start a
      new episode. The task defaults to ``"easy"``, as in POST /reset.
    * ``{"method": "step", "action": {...}}``: apply one action.
    * ``{"method": "state"}``: return the current environment state.

    Every frame gets exactly one JSON reply. Problems such as malformed JSON,
    a non-object message, an unknown method, an invalid action or an
    environment error are answered with ``{"error": ...}``. The connection
    stays open in those cases; previously the exception escaped the loop and
    closed the socket, or an unknown method got no reply at all.
    """
    await websocket.accept()
    env: Optional[MLOpsEnvironment] = None
    try:
        while True:
            raw = await websocket.receive_text()
            method = None
            try:
                msg = json.loads(raw)
                if not isinstance(msg, dict):
                    raise ValueError("message must be a JSON object")
                method = msg.get("method")

                if method == "reset":
                    # Mirror POST /reset: accept "task_id" or "task"; a missing
                    # or null value falls back to "easy".
                    task_id = msg.get("task_id") or msg.get("task") or "easy"
                    new_env = MLOpsEnvironment(task_id=task_id)
                    obs = new_env.reset(seed=msg.get("seed"))
                    # Install the env only after reset() succeeded.
                    env = new_env
                    reply = {"method": "reset", "observation": obs.model_dump()}
                elif method in ("step", "state") and env is None:
                    reply = {"error": "Call reset first"}
                elif method == "step":
                    action = MLOpsAction(**(msg.get("action") or {}))
                    obs, reward, done, info = env.step(action)
                    reply = {
                        "method": "step",
                        "observation": obs.model_dump(),
                        "reward": reward,
                        "done": done,
                        "info": info,
                    }
                elif method == "state":
                    reply = {"method": "state", "state": env.state.model_dump()}
                else:
                    reply = {"error": f"Unknown method: {method!r}"}

                # Serialize inside the try so non-JSON-serializable data is
                # reported to the client rather than killing the connection.
                payload = json.dumps(reply)
            except Exception as exc:
                label = method if isinstance(method, str) else "request"
                payload = json.dumps({"error": f"{label} error: {exc}"})
            await websocket.send_text(payload)
    except WebSocketDisconnect:
        pass