ShubhanKamat committed
Commit 670f19f · 1 Parent(s): 7be72df

MLOps Firefighter - OpenEnv environment

Files changed (13)
  1. Dockerfile +18 -0
  2. README.md +287 -7
  3. baseline_inference.py +245 -0
  4. models.py +117 -0
  5. openenv.yaml +60 -0
  6. pyproject.toml +24 -0
  7. requirements.txt +5 -0
  8. server/__init__.py +1 -0
  9. server/app.py +259 -0
  10. server/environment.py +334 -0
  11. tasks.py +526 -0
  12. tests/test_environment.py +251 -0
  13. validate.py +219 -0
Dockerfile ADDED
@@ -0,0 +1,18 @@
+ FROM python:3.11-slim
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+ ENV PYTHONUNBUFFERED=1
+ ENV PORT=7860
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+
+ EXPOSE 7860
+
+ CMD ["python", "server/app.py"]
README.md CHANGED
@@ -1,12 +1,292 @@
  ---
- title: Mlops Firefighter
- emoji: 👀
- colorFrom: blue
- colorTo: red
+ title: MLOps Firefighter
+ colorFrom: red
+ colorTo: orange
  sdk: docker
+ app_port: 7860
+ tags:
+ - openenv
+ - rl
+ - mlops
+ - production-ml
  pinned: false
- license: mit
- short_description: AI agents act as on-call MLOps engineers
+ license: bsd-3-clause
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # MLOps Firefighter OpenEnv Environment
+
+ **Debug and fix ML models failing in production.**
+
+ An OpenEnv-compliant reinforcement learning environment where AI agents act as on-call MLOps engineers. When a production ML model starts misbehaving, the agent must diagnose the root cause and apply the correct fix, just like a real engineer would at 3 AM.
+
+ ---
+
+ ## Why This Environment?
+
+ Every ML team eventually faces the "model broke in prod" moment. Data drifts. Thresholds get misconfigured during deployments. Training pipelines get poisoned. These are real incidents that cost companies millions of dollars and require skilled human reasoning to resolve.
+
+ This environment captures that challenge:
+ - **Real-world task**: MLOps incident response is performed daily by thousands of engineers
+ - **Rich reasoning required**: agents must investigate before acting, weigh evidence, and avoid destructive actions
+ - **Novel domain**: no existing OpenEnv covers production ML debugging
+ - **Meaningful difficulty progression**: from a simple config error to a subtle adversarial data poisoning attack
+
+ ---
+
+ ## Environment Overview
+
+ ### How It Works
+
+ 1. **An incident fires**: the agent receives alerts about a failing ML model (e.g., precision dropped, complaints spiked)
+ 2. **The agent investigates**: run diagnostics like checking metrics, querying logs, inspecting data distributions
+ 3. **The agent identifies the root cause**: data drift? threshold misconfiguration? poisoned training data?
+ 4. **The agent applies a fix**: roll back the model, adjust thresholds, fix the data pipeline, add guardrails
+ 5. **The agent submits a diagnosis**: declare the root cause and close the incident
+
+ The grader evaluates: diagnostic thoroughness (30%), diagnosis accuracy (30%), remediation correctness (25%), and efficiency (15%).
+
+ ### Reward Shaping
+
+ Rewards provide signal throughout the episode, not just at the end:
+
+ | Action | Reward |
+ |--------|--------|
+ | Required diagnostic (first time) | +0.3 |
+ | Useful but non-critical diagnostic | +0.1 |
+ | Redundant diagnostic (already done) | +0.05 |
+ | Correct remediation (after investigation) | +1.0 |
+ | Hasty remediation (without diagnosis) | −0.5 |
+ | Wrong remediation | −0.3 |
+ | Final score bonus (on submit_diagnosis) | 0–5.0 scaled by grader |
+ | Timeout penalty | Reduced final multiplier |
+
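The reward table above can be read as a small decision function. The sketch below is only an illustration of that reading: the function name `shaped_reward` and its arguments are hypothetical and are not taken from `server/environment.py`, and it assumes that the "hasty" penalty is checked before the "wrong" penalty, which the table does not specify. Only the numeric values come from the table.

```python
# Illustrative sketch; reward values are copied from the README table.
# The function, its arguments, and the precedence of the two penalties
# are assumptions, not the environment's actual implementation.

def shaped_reward(kind: str, *, required: bool = False, seen_before: bool = False,
                  correct: bool = False, investigated: bool = False) -> float:
    """Return the per-step reward for a diagnostic or remediation action."""
    if kind == "diagnostic":
        if seen_before:
            return 0.05              # redundant diagnostic (already done)
        return 0.3 if required else 0.1
    if kind == "remediation":
        if not investigated:
            return -0.5              # hasty remediation (assumed to take precedence)
        return 1.0 if correct else -0.3
    raise ValueError(f"unknown action kind: {kind}")
```

Under these assumptions, an agent that runs two required diagnostics and then applies the correct fix would collect 0.3 + 0.3 + 1.0 in step rewards before the final score bonus.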
+ ---
+
+ ## Action Space
+
+ Actions the agent can take, specified via `action_type` string:
+
+ ### Diagnostic Actions (investigation)
+ | Action | Description |
+ |--------|-------------|
+ | `inspect_metrics` | View model performance dashboard (accuracy, precision, recall, latency) |
+ | `query_logs` | Search production logs for deployment events, errors, warnings |
+ | `check_data_dist` | Compare training vs. serving data distributions, detect drift |
+ | `check_feature_importance` | Examine which features the model relies on |
+ | `run_prediction_sample` | Test the model on samples with known ground-truth labels |
+ | `check_infrastructure` | Check CPU, memory, GPU, latency — rule out infra issues |
+ | `check_upstream_pipeline` | Inspect data pipeline health, connectors, schema changes |
+
+ ### Remediation Actions (fixes)
+ | Action | Parameters | Description |
+ |--------|-----------|-------------|
+ | `rollback_model` | `target_version: str` | Revert to a previous model version |
+ | `adjust_threshold` | `new_threshold: float` | Tune the decision threshold |
+ | `retrain_model` | — | Trigger model retraining on corrected data |
+ | `fix_data_pipeline` | — | Repair data ingestion / feature pipeline |
+ | `scale_infrastructure` | — | Add compute resources |
+ | `add_feature_guard` | — | Add input validation and monitoring guardrails |
+
+ ### Episode Control
+ | Action | Parameters | Description |
+ |--------|-----------|-------------|
+ | `submit_diagnosis` | `root_cause: str, summary: str` | Declare root cause and end the episode |
+
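As a rough illustration of the action tables above, the following sketch builds a `/step` request body and checks it client-side. The helper `build_step_body` is hypothetical and not part of this repository; the action and parameter names are taken from the tables, but whether the server itself rejects a missing parameter is not shown in this diff.

```python
# Hypothetical client-side helper; action and parameter names come from
# the README tables, everything else is an assumption for illustration.

VALID_ACTIONS = {
    "inspect_metrics", "query_logs", "check_data_dist",
    "check_feature_importance", "run_prediction_sample",
    "check_infrastructure", "check_upstream_pipeline",
    "rollback_model", "adjust_threshold", "retrain_model",
    "fix_data_pipeline", "scale_infrastructure", "add_feature_guard",
    "submit_diagnosis",
}

# Parameters the tables list for each action (others take none).
LISTED_PARAMS = {
    "rollback_model": {"target_version"},
    "adjust_threshold": {"new_threshold"},
    "submit_diagnosis": {"root_cause", "summary"},
}


def build_step_body(action_type: str, **parameters) -> dict:
    """Return a JSON-serializable body for POST /step, or raise ValueError."""
    if action_type not in VALID_ACTIONS:
        raise ValueError(f"unknown action_type: {action_type!r}")
    missing = LISTED_PARAMS.get(action_type, set()) - set(parameters)
    if missing:
        raise ValueError(f"{action_type} is missing parameters: {sorted(missing)}")
    return {"action_type": action_type, "parameters": parameters}
```

For example, `build_step_body("adjust_threshold", new_threshold=0.55)` produces the same shape as the `curl` bodies used in the Quick Start section.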
+ ---
+
+ ## Observation Space
+
+ After each action, the agent receives an `MLOpsObservation` with:
+
+ | Field | Type | Description |
+ |-------|------|-------------|
+ | `done` | bool | Whether the episode has ended |
+ | `reward` | float | Reward for the last action |
+ | `step_number` | int | Current step (starts at 0) |
+ | `max_steps` | int | Steps before timeout |
+ | `task_id` | str | Task identifier |
+ | `task_description` | str | Natural language incident description |
+ | `alerts` | list[Alert] | Active production alerts with severity, metric, value |
+ | `model_info` | ModelInfo | Deployed model name, version, framework, endpoint |
+ | `action_result` | str | Detailed textual result of the last action |
+ | `action_success` | bool | Whether the action executed successfully |
+ | `diagnostics_gathered` | list[str] | Summary of all diagnostics collected so far |
+ | `available_actions` | list[str] | Valid action types |
+
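To show how an agent loop might consume these fields, here is a minimal sketch that turns an observation (as a plain dict, the shape a JSON response would have) into a one-line status string. The function `summarize_observation` is hypothetical; the field names come from the table above and the `Alert` model in `models.py`.

```python
# Hypothetical helper for logging; field names follow the observation table.

def summarize_observation(obs: dict) -> str:
    """Render a one-line status string from an MLOpsObservation-shaped dict."""
    status = "done" if obs.get("done") else "running"
    critical = [
        alert["metric_name"]
        for alert in obs.get("alerts", [])
        if alert.get("severity") == "critical"
    ]
    return (
        f"[{obs.get('task_id', '?')}] "
        f"step {obs.get('step_number', 0)}/{obs.get('max_steps', 0)} {status}, "
        f"reward={obs.get('reward', 0.0):+.2f}, "
        f"critical alerts: {', '.join(critical) or 'none'}"
    )


example = {
    "task_id": "task_threshold_drift",
    "step_number": 1,
    "max_steps": 15,
    "done": False,
    "reward": 0.3,
    "alerts": [{"severity": "critical", "metric_name": "precision"}],
}
print(summarize_observation(example))
# [task_threshold_drift] step 1/15 running, reward=+0.30, critical alerts: precision
```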
+ ---
+
+ ## Tasks
+
+ ### Task 1: Threshold Misconfiguration (Easy)
+ **Scenario**: A fraud detection model's precision dropped from 0.94 to 0.61 after redeployment. The decision threshold was accidentally changed from 0.55 to 0.30.
+
+ - **Root cause**: Threshold misconfiguration during deployment
+ - **Expected fix**: Adjust threshold back to ~0.55
+ - **Key signal**: Metrics show threshold changed; prediction samples confirm threshold is the lever
+ - **Max steps**: 15
+
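The mechanism behind Task 1 can be reproduced on a toy example: lowering the decision threshold admits more low-score predictions as positives, and if most of those are false positives, precision falls. The scores and labels below are made up for illustration and are unrelated to the task's actual data; only the two thresholds (0.55 and 0.30) come from the scenario.

```python
# Toy illustration with invented scores and labels (1 = fraud, 0 = legitimate).

def precision_at(threshold: float, scores: list, labels: list) -> float:
    """Precision of the rule `score >= threshold` against ground-truth labels."""
    flagged = [y for s, y in zip(scores, labels) if s >= threshold]
    if not flagged:
        return 0.0
    return sum(flagged) / len(flagged)


scores = [0.95, 0.80, 0.62, 0.58, 0.45, 0.40, 0.35, 0.32, 0.20, 0.10]
labels = [1,    1,    1,    0,    0,    1,    0,    0,    0,    0]

print(precision_at(0.55, scores, labels))  # 0.75 (3 true positives out of 4 flagged)
print(precision_at(0.30, scores, labels))  # 0.5  (4 true positives out of 8 flagged)
```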
+ ### Task 2: Data Drift with Stale Feature Pipeline (Medium)
+ **Scenario**: A loan default model's AUC degraded from 0.91 to 0.74. A new credit bureau connector was onboarded 10 days ago, introducing null values and unit changes (dollars → cents) in key features.
+
+ - **Root causes**: Data drift + broken feature pipeline
+ - **Expected fixes**: Fix data pipeline + retrain model
+ - **Key signals**: Distribution comparison shows 100x scale change in income; upstream pipeline shows connector issues
+ - **Max steps**: 20
+
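The two signals named in Task 2, a 100x scale change and newly introduced nulls, are the kind of thing a simple distribution check can surface. The sketch below uses invented income values and hypothetical helper names; it is not the logic behind the environment's `check_data_dist` action, which is not shown in this diff.

```python
from statistics import median

# Invented sample data: training incomes in dollars, serving incomes in
# cents with one null, mimicking the unit change described in the scenario.
training_income = [42_000, 55_000, 61_000, 78_000, 90_000]
serving_income = [4_200_000, 5_500_000, None, 6_100_000, 7_800_000, 9_000_000]


def scale_ratio(training: list, serving: list) -> float:
    """Ratio of the serving median to the training median, ignoring nulls."""
    clean = [v for v in serving if v is not None]
    return median(clean) / median(training)


def null_rate(values: list) -> float:
    """Fraction of entries that are None."""
    return sum(v is None for v in values) / len(values)


print(scale_ratio(training_income, serving_income))  # 100.0
print(round(null_rate(serving_income), 3))           # 0.167
```

A ratio near a power of ten is a strong hint of a unit mismatch rather than organic drift, which is why the expected fix repairs the pipeline before retraining.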
+ ### Task 3: Silent Model Regression with Data Poisoning (Hard)
+ **Scenario**: A content moderation model has 0.96 overall accuracy (looks fine!), but hate speech and violence recall collapsed to ~40%. An automated retraining pipeline ingested a poisoned crowd-source batch that systematically mislabeled harmful content as "safe." Aggregate metrics didn't catch it because the safe class dominates traffic (92%).
+
+ - **Root causes**: Training data poisoning / label corruption
+ - **Expected fixes**: Roll back to safe version + add per-class recall guardrails
+ - **Key signals**: Per-class metrics reveal the hidden regression; data audit shows contaminated batch
+ - **Difficulty**: Requires the agent to look beyond aggregate metrics and think about class imbalance
+ - **Max steps**: 25
+
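The reason aggregate accuracy hides this regression is simple arithmetic: for single-label classification, overall accuracy is the traffic-weighted average of per-class recall. The sketch below works through a simplified two-class version. The 92% safe share and ~40% harmful recall come from the scenario; lumping all harmful content into one class and assuming perfect recall on the safe class are simplifications made here for illustration.

```python
# Simplified two-class illustration; the class grouping and the 1.00 recall
# on "safe" are assumptions, the 0.92 share and 0.40 recall are from the scenario.

def overall_accuracy(class_share: dict, class_recall: dict) -> float:
    """Accuracy as the traffic-weighted average of per-class recall."""
    return sum(class_share[c] * class_recall[c] for c in class_share)


share = {"safe": 0.92, "harmful": 0.08}
recall = {"safe": 1.00, "harmful": 0.40}

print(round(overall_accuracy(share, recall), 3))  # 0.952
```

Even with 60% of harmful content slipping through, accuracy stays around 0.95, roughly in line with the 0.96 quoted in the scenario, which is why per-class recall guardrails are part of the expected fix.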
+ ### Expected Difficulty
+ | Task | Difficulty | Baseline Score | Notes |
+ |------|-----------|---------------|-------|
+ | Threshold Misconfiguration | Easy | 1.000 | Single clear root cause |
+ | Data Drift | Medium | 0.970 | Multiple issues to identify |
+ | Silent Regression | Hard | 0.970 | Requires looking past surface metrics |
+
+ ---
+
+ ## Setup & Usage
+
+ ### Quick Start (Local Python)
+
+ ```bash
+ # Clone and install
+ git clone <repo-url>
+ cd mlops_firefighter
+ pip install -r requirements.txt
+
+ # Run the server
+ python server/app.py
+ # → Server running at http://localhost:7860
+
+ # In another terminal, test it
+ curl http://localhost:7860/health
+ curl http://localhost:7860/tasks
+ curl -X POST http://localhost:7860/reset -H "Content-Type: application/json" -d '{"task_id": "task_threshold_drift"}'
+ curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d '{"action_type": "inspect_metrics"}'
+ ```
+
+ ### Docker
+
+ ```bash
+ # Build
+ docker build -t mlops-firefighter .
+
+ # Run
+ docker run -p 7860:7860 mlops-firefighter
+
+ # Test
+ curl http://localhost:7860/health
+ ```
+
+ ### Run Baseline Inference (requires OpenAI API key)
+
+ ```bash
+ export OPENAI_API_KEY=sk-...
+ python baseline_inference.py --base-url http://localhost:7860
+ ```
+
+ ### Run the Built-in Baseline (no API key needed)
+
+ ```bash
+ curl -X POST http://localhost:7860/baseline
+ ```
+
+ ### Run Tests
+
+ ```bash
+ python tests/test_environment.py
+ ```
+
+ ---
+
+ ## API Endpoints
+
+ | Endpoint | Method | Description |
+ |----------|--------|-------------|
+ | `/health` | GET | Health check — returns `{"status": "healthy"}` |
+ | `/reset` | POST | Start new episode. Body: `{"task_id": "..."}` |
+ | `/step` | POST | Take action. Body: `{"action_type": "...", "parameters": {...}}` |
+ | `/state` | GET | Current environment state |
+ | `/tasks` | GET | List all tasks with action schema |
+ | `/grader` | POST | Get grader score for completed episode |
+ | `/baseline` | POST | Run scripted baseline on all tasks, return scores |
+
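The endpoints above compose into a reset, step, grade cycle. The sketch below shows that cycle with the HTTP layer injected as a callable, so it can be exercised without a running server. The function `run_scripted_episode` and the `Transport` alias are hypothetical; the top-level `done` key on `/step` responses and the `score` key on `/grader` responses follow how `baseline_inference.py` in this commit reads them.

```python
from typing import Callable

# (method, path, json_body) -> parsed JSON response. Hypothetical alias.
Transport = Callable[[str, str, dict], dict]


def run_scripted_episode(call: Transport, task_id: str, actions: list) -> float:
    """Reset the environment, play actions until done, return the grader score."""
    call("POST", "/reset", {"task_id": task_id})
    for action in actions:
        result = call("POST", "/step", action)
        if result.get("done"):
            break  # the episode ended, e.g. after submit_diagnosis
    return call("POST", "/grader", {}).get("score", 0.0)
```

Against a live server, `call` could be a thin wrapper around `requests.get` and `requests.post`, much like `call_env` in `baseline_inference.py`; in tests it can be a stub that returns canned responses.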
+ ---
+
+ ## Grader Scoring Breakdown
+
+ Each episode is scored 0.0–1.0 based on four components:
+
+ | Component | Weight | Description |
+ |-----------|--------|-------------|
+ | Diagnostic thoroughness | 30% | Did the agent run the key diagnostics? |
+ | Diagnosis accuracy | 30% | Did the agent identify the correct root cause? |
+ | Remediation accuracy | 25% | Did the agent apply the right fix(es)? Penalizes wrong fixes. |
+ | Efficiency | 15% | How many steps did the agent use? Fewer = better. |
+
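One plausible reading of the weight table is a weighted sum of four component scores, each in [0, 1]. The sketch below implements that reading; the component key names and the `combine` function are hypothetical, and the actual `grade_episode` in `tasks.py` (not shown in this part of the diff) may compute or clamp the components differently.

```python
# Weights are from the table above; key names and the clamping are assumptions.
WEIGHTS = {
    "diagnostic_thoroughness": 0.30,
    "diagnosis_accuracy": 0.30,
    "remediation_accuracy": 0.25,
    "efficiency": 0.15,
}


def combine(components: dict) -> float:
    """Weighted sum of component scores, each clamped to the range [0, 1]."""
    return sum(
        weight * min(1.0, max(0.0, components.get(name, 0.0)))
        for name, weight in WEIGHTS.items()
    )


# A hypothetical episode: everything perfect except an efficiency of 0.8.
episode = {
    "diagnostic_thoroughness": 1.0,
    "diagnosis_accuracy": 1.0,
    "remediation_accuracy": 1.0,
    "efficiency": 0.8,
}
print(round(combine(episode), 3))  # 0.97
```

Because the weights sum to 1.0, the combined score stays within the 0.0–1.0 range stated above as long as each component does.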
+ ---
+
+ ## Baseline Scores
+
+ Using the built-in scripted baseline (perfect knowledge):
+
+ | Task | Score | Steps |
+ |------|-------|-------|
+ | task_threshold_drift (Easy) | 1.000 | 4 |
+ | task_data_drift (Medium) | 0.970 | 6 |
+ | task_silent_regression (Hard) | 0.970 | 7 |
+ | **Average** | **0.980** | — |
+
+ These represent near-optimal play. An LLM agent without perfect knowledge will score lower, especially on the hard task where aggregate metrics are misleading.
+
+ ---
+
+ ## Project Structure
+
+ ```
+ mlops_firefighter/
+ ├── models.py                # Pydantic Action/Observation models
+ ├── tasks.py                 # Task definitions + grader functions
+ ├── openenv.yaml             # OpenEnv manifest
+ ├── requirements.txt         # Python dependencies
+ ├── pyproject.toml           # Package config
+ ├── Dockerfile               # Container image
+ ├── baseline_inference.py    # LLM baseline script (OpenAI API)
+ ├── README.md                # This file
+ ├── server/
+ │   ├── __init__.py
+ │   ├── app.py               # FastAPI server with all endpoints
+ │   └── environment.py       # Core environment logic
+ └── tests/
+     └── test_environment.py  # Comprehensive test suite (13 tests)
+ ```
+
+ ---
+
+ ## OpenEnv Spec Compliance
+
+ - ✅ Typed `Action` and `Observation` Pydantic models
+ - ✅ `step(action)` → returns observation, reward, done, info
+ - ✅ `reset()` → returns initial observation
+ - ✅ `state()` → returns current state
+ - ✅ `openenv.yaml` with metadata
+ - ✅ 3 tasks with difficulty progression (easy/medium/hard)
+ - ✅ Programmatic graders scoring 0.0–1.0
+ - ✅ Meaningful reward shaping (not just sparse end-of-episode)
+ - ✅ Baseline inference script using OpenAI API
+ - ✅ Dockerfile that builds and runs
+ - ✅ `/health`, `/tasks`, `/grader`, `/baseline` endpoints
+
+ ---
+
+ ## License
+
+ BSD-3-Clause
+
baseline_inference.py ADDED
@@ -0,0 +1,245 @@
+ #!/usr/bin/env python3
+ # Copyright (c) 2026 MLOps Firefighter Contributors
+ # Licensed under the BSD-3-Clause License
+
+ """
+ Baseline inference script for the MLOps Firefighter environment.
+
+ Uses the OpenAI API client to run a model (e.g. GPT-4o) against all 3 tasks,
+ producing a reproducible baseline score.
+
+ Usage:
+     export OPENAI_API_KEY=sk-...
+     python baseline_inference.py [--base-url http://localhost:7860]
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import os
+ import sys
+
+ import requests
+
+ try:
+     from openai import OpenAI
+ except ImportError:
+     print("Install openai: pip install openai")
+     sys.exit(1)
+
+
+ # ── Configuration ────────────────────────────────────────────────────────────
+
+ DEFAULT_ENV_URL = "http://localhost:7860"
+ MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o")
+
+ SYSTEM_PROMPT = """\
+ You are an expert MLOps engineer on-call. A production ML model is failing and
+ you must diagnose the root cause and fix it.
+
+ You interact with the environment by choosing actions. Each action returns
+ diagnostic information or applies a remediation.
+
+ STRATEGY:
+ 1. First, investigate: run diagnostics to understand what's wrong
+ 2. Identify the root cause from the evidence
+ 3. Apply the correct remediation
+ 4. Submit your diagnosis with a clear root cause label and summary
+
+ Available actions (use exact strings):
+ - inspect_metrics: View model performance metrics dashboard
+ - query_logs: Search production logs for anomalies
+ - check_data_dist: Compare training vs serving data distributions
+ - check_feature_importance: Examine feature weights and importance
+ - run_prediction_sample: Test model on sample inputs with known labels
+ - check_infrastructure: Check latency, memory, GPU, compute
+ - check_upstream_pipeline: Inspect data pipeline health
+ - rollback_model: Revert to a previous model version (params: target_version)
+ - adjust_threshold: Tune decision threshold (params: new_threshold)
+ - retrain_model: Trigger model retraining
+ - fix_data_pipeline: Repair data ingestion issues
+ - scale_infrastructure: Add compute resources
+ - add_feature_guard: Add input validation / guardrails
+ - submit_diagnosis: Declare root cause and close (params: root_cause, summary)
+
+ Respond with ONLY valid JSON: {"action_type": "...", "parameters": {...}}
+ """
+
+
+ def call_env(base_url: str, endpoint: str, method: str = "POST", data: dict | None = None) -> dict:
+     """Call an environment endpoint."""
+     url = f"{base_url}{endpoint}"
+     if method == "GET":
+         resp = requests.get(url, timeout=30)
+     else:
+         resp = requests.post(url, json=data or {}, timeout=30)
+     resp.raise_for_status()
+     return resp.json()
+
+
+ def run_agent_on_task(client: OpenAI, base_url: str, task_id: str, task_name: str) -> dict:
+     """Run the LLM agent on a single task."""
+     print(f"\n{'='*60}")
+     print(f" Task: {task_name} ({task_id})")
+     print(f"{'='*60}")
+
+     # Reset environment
+     reset_result = call_env(base_url, "/reset", data={"task_id": task_id})
+     obs = reset_result["observation"]
+     print(f" Incident: {obs['task_description'][:100]}...")
+
+     messages = [
+         {"role": "system", "content": SYSTEM_PROMPT},
+         {
+             "role": "user",
+             "content": (
+                 f"INCIDENT ALERT:\n{obs['task_description']}\n\n"
+                 f"ALERTS:\n{json.dumps(obs['alerts'], indent=2)}\n\n"
+                 f"MODEL INFO:\n{json.dumps(obs['model_info'], indent=2)}\n\n"
+                 f"Available actions: {obs['available_actions']}\n\n"
+                 "What is your first action? Respond with JSON only."
+             ),
+         },
+     ]
+
+     total_reward = 0.0
+     steps = 0
+     max_steps = obs.get("max_steps", 20)
+
+     while steps < max_steps:
+         # Get LLM decision
+         try:
+             response = client.chat.completions.create(
+                 model=MODEL,
+                 messages=messages,
+                 temperature=0.2,
+                 max_tokens=500,
+             )
+             llm_output = response.choices[0].message.content.strip()
+         except Exception as e:
+             print(f" LLM error: {e}")
+             break
+
+         # Parse action from LLM
+         try:
+             # Handle markdown code blocks
+             if "```" in llm_output:
+                 llm_output = llm_output.split("```")[1]
+                 if llm_output.startswith("json"):
+                     llm_output = llm_output[4:]
+                 llm_output = llm_output.strip()
+             action_data = json.loads(llm_output)
+         except json.JSONDecodeError:
+             # Try to extract JSON from the response
+             import re
+             match = re.search(r'\{[^}]+\}', llm_output)
+             if match:
+                 try:
+                     action_data = json.loads(match.group())
+                 except json.JSONDecodeError:
+                     print(f" Failed to parse LLM output: {llm_output[:100]}")
+                     break
+             else:
+                 print(f" Failed to parse LLM output: {llm_output[:100]}")
+                 break
+
+         action_type = action_data.get("action_type", "")
+         parameters = action_data.get("parameters", {})
+
+         # Step environment
+         step_result = call_env(
+             base_url, "/step",
+             data={"action_type": action_type, "parameters": parameters},
+         )
+         obs = step_result["observation"]
+         reward = step_result["reward"]
+         done = step_result["done"]
+         total_reward += reward
+         steps += 1
+
+         print(f" Step {steps}: {action_type} → reward={reward:.2f}")
+
+         if done:
+             print(f" Episode done. Total reward: {total_reward:.2f}")
+             break
+
+         # Feed result back to LLM
+         messages.append({"role": "assistant", "content": llm_output})
+         messages.append({
+             "role": "user",
+             "content": (
+                 f"ACTION RESULT:\n{obs['action_result']}\n\n"
+                 f"Step {obs['step_number']}/{obs['max_steps']}\n"
+                 f"Diagnostics gathered so far: {len(obs['diagnostics_gathered'])}\n\n"
+                 "What is your next action? Respond with JSON only."
+             ),
+         })
+
+     # Get grader score
+     grader_result = call_env(base_url, "/grader", data={})
+     score = grader_result.get("score", 0.0)
+     breakdown = grader_result.get("breakdown", {})
+
+     print(f" Grader score: {score:.3f}")
+     print(f" Breakdown: {json.dumps(breakdown, indent=2)}")
+
+     return {
+         "task_id": task_id,
+         "task_name": task_name,
+         "score": score,
+         "breakdown": breakdown,
+         "steps": steps,
+         "total_reward": round(total_reward, 3),
+     }
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="MLOps Firefighter Baseline")
+     parser.add_argument("--base-url", default=DEFAULT_ENV_URL, help="Environment URL")
+     args = parser.parse_args()
+
+     api_key = os.environ.get("OPENAI_API_KEY")
+     if not api_key:
+         print("ERROR: Set OPENAI_API_KEY environment variable")
+         sys.exit(1)
+
+     client = OpenAI(api_key=api_key)
+
+     # Get tasks
+     tasks_info = call_env(args.base_url, "/tasks", method="GET")
+     tasks = tasks_info["tasks"]
+
+     print("\n" + "=" * 60)
+     print(" MLOps Firefighter — Baseline Inference")
+     print(f" Model: {MODEL}")
+     print(f" Environment: {args.base_url}")
+     print(f" Tasks: {len(tasks)}")
+     print("=" * 60)
+
+     results = []
+     for task in tasks:
+         result = run_agent_on_task(
+             client, args.base_url, task["task_id"], task["name"]
+         )
+         results.append(result)
+
+     # Summary
+     print("\n" + "=" * 60)
+     print(" BASELINE RESULTS SUMMARY")
+     print("=" * 60)
+     for r in results:
+         print(f" [{r['task_id']}] {r['task_name']}")
+         print(f" Score: {r['score']:.3f} | Steps: {r['steps']} | Reward: {r['total_reward']}")
+     avg = sum(r["score"] for r in results) / len(results) if results else 0
+     print(f"\n Average Score: {avg:.3f}")
+     print("=" * 60)
+
+     # Write results to file
+     with open("baseline_results.json", "w") as f:
+         json.dump({"results": results, "average_score": round(avg, 3)}, f, indent=2)
+     print(" Results saved to baseline_results.json")
+
+
+ if __name__ == "__main__":
+     main()
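A note on the fallback parser in `run_agent_on_task`: the pattern `\{[^}]+\}` stops at the first closing brace, so it cannot recover a response whose `parameters` value is a nested object, which is exactly the shape the system prompt asks for. One possible alternative, sketched below, scans for the first position where the standard library's `json.JSONDecoder.raw_decode` can parse a complete object. The function name `extract_first_json_object` is hypothetical and is not part of this commit.

```python
import json
from typing import Optional


def extract_first_json_object(text: str) -> Optional[dict]:
    """Return the first JSON object embedded in text, or None if there is none."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start` and
            # ignores whatever follows it, so trailing prose is harmless.
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return obj
        start = text.find("{", start + 1)
    return None


reply = 'Next step: {"action_type": "adjust_threshold", "parameters": {"new_threshold": 0.55}} as planned.'
print(extract_first_json_object(reply))
# {'action_type': 'adjust_threshold', 'parameters': {'new_threshold': 0.55}}
```

Because the decoder tracks nesting and string escapes itself, this approach also copes with braces inside string values, which a character-class regex cannot do.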
models.py ADDED
@@ -0,0 +1,117 @@
+ # Copyright (c) 2026 MLOps Firefighter Contributors
+ # Licensed under the BSD-3-Clause License
+
+ """
+ Typed Pydantic models for the MLOps Firefighter environment.
+ Defines Action, Observation, and related types for the OpenEnv spec.
+ """
+
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import Any, Optional
+
+ from pydantic import BaseModel, Field
+
+
+ # ── Action Space ─────────────────────────────────────────────────────────────
+
+ class ActionType(str, Enum):
+     """All actions an agent can take to diagnose and fix a production ML model."""
+
+     # Diagnostic actions
+     INSPECT_METRICS = "inspect_metrics"  # View model performance metrics
+     QUERY_LOGS = "query_logs"  # Search production logs
+     CHECK_DATA_DISTRIBUTION = "check_data_dist"  # Compare train vs. serving data
+     CHECK_FEATURE_IMPORTANCE = "check_feature_importance"  # Examine feature weights
+     RUN_PREDICTION_SAMPLE = "run_prediction_sample"  # Test model on sample inputs
+     CHECK_INFRASTRUCTURE = "check_infrastructure"  # Check latency, memory, GPU
+     CHECK_UPSTREAM_PIPELINE = "check_upstream_pipeline"  # Inspect data pipeline health
+
+     # Remediation actions
+     ROLLBACK_MODEL = "rollback_model"  # Revert to previous model version
+     ADJUST_THRESHOLD = "adjust_threshold"  # Tune decision threshold
+     RETRAIN_MODEL = "retrain_model"  # Trigger retraining
+     FIX_DATA_PIPELINE = "fix_data_pipeline"  # Repair data ingestion issue
+     SCALE_INFRASTRUCTURE = "scale_infrastructure"  # Add compute resources
+     ADD_FEATURE_GUARD = "add_feature_guard"  # Add input validation / guardrail
+     SUBMIT_DIAGNOSIS = "submit_diagnosis"  # Declare root cause & close episode
+
+
+ class MLOpsAction(BaseModel):
+     """An action taken by the agent in the MLOps Firefighter environment."""
+
+     action_type: ActionType = Field(
+         ..., description="The type of action to perform"
+     )
+     parameters: dict[str, Any] = Field(
+         default_factory=dict,
+         description=(
+             "Action-specific parameters. E.g. for adjust_threshold: "
+             "{'new_threshold': 0.6}. For submit_diagnosis: "
+             "{'root_cause': 'data_drift', 'summary': '...'}."
+         ),
+     )
+
+
+ # ── Observation Space ────────────────────────────────────────────────────────
+
+ class AlertSeverity(str, Enum):
+     CRITICAL = "critical"
+     HIGH = "high"
+     MEDIUM = "medium"
+     LOW = "low"
+
+
+ class Alert(BaseModel):
+     """A production alert that triggered the incident."""
+     severity: AlertSeverity
+     message: str
+     metric_name: str
+     current_value: float
+     threshold: float
+     timestamp: str
+
+
+ class ModelInfo(BaseModel):
+     """Information about the deployed model."""
+     model_name: str
+     model_version: str
+     deployed_at: str
+     framework: str
+     endpoint: str
+     previous_versions: list[str] = Field(default_factory=list)
+
+
+ class MLOpsObservation(BaseModel):
+     """What the agent sees after each action."""
+
+     # Episode metadata
+     done: bool = Field(False, description="Whether the episode has ended")
+     reward: float = Field(0.0, description="Reward for the last action")
+     step_number: int = Field(0, description="Current step in the episode")
+     max_steps: int = Field(20, description="Maximum steps before timeout")
+
+     # Incident context (always visible)
+     task_id: str = Field("", description="Task identifier")
+     task_description: str = Field("", description="Natural language task description")
+     alerts: list[Alert] = Field(default_factory=list, description="Active alerts")
+     model_info: ModelInfo | None = Field(None, description="Deployed model details")
+
+     # Action result (populated after each step)
+     action_result: str = Field(
+         "", description="Textual result of the last action taken"
+     )
+     action_success: bool = Field(True, description="Whether the action executed OK")
+
+     # Accumulated diagnostic context
+     diagnostics_gathered: list[str] = Field(
+         default_factory=list,
+         description="Summary of diagnostics the agent has collected so far",
+     )
+
+     # Hints / guidance
+     available_actions: list[str] = Field(
+         default_factory=list,
+         description="List of valid action types the agent can take",
+     )
openenv.yaml ADDED
@@ -0,0 +1,60 @@
+ name: mlops_firefighter
+ version: "1.0.0"
+ description: >
+   MLOps Firefighter: An OpenEnv environment where AI agents debug and fix
+   ML models failing in production. Agents diagnose root causes (data drift,
+   threshold misconfiguration, training data poisoning) and apply the correct
+   remediation. 3 tasks from easy to hard with programmatic graders.
+
+ author: "MLOps Firefighter Contributors"
+ license: "BSD-3-Clause"
+ tags:
+   - openenv
+   - rl
+   - mlops
+   - production-ml
+   - debugging
+   - real-world
+
+ environment:
+   entrypoint: server/app.py
+   port: 7860
+   python_version: "3.11"
+
+ action_model: models.MLOpsAction
+ observation_model: models.MLOpsObservation
+
+ endpoints:
+   reset: /reset
+   step: /step
+   state: /state
+   health: /health
+   tasks: /tasks
+   grader: /grader
+   baseline: /baseline
+
+ tasks:
+   - id: task_threshold_drift
+     name: "Threshold Misconfiguration After Redeployment"
+     difficulty: easy
+     description: >
+       Fraud detection model precision dropped after redeployment.
+       Agent must diagnose threshold misconfiguration and adjust it.
+     max_steps: 15
+
+   - id: task_data_drift
+     name: "Data Drift with Stale Feature Pipeline"
+     difficulty: medium
+     description: >
+       Loan default model degraded due to upstream data pipeline change.
+       Agent must identify feature drift/corruption and fix the pipeline + retrain.
+     max_steps: 20
+
+   - id: task_silent_regression
+     name: "Silent Model Regression with Adversarial Inputs"
+     difficulty: hard
+     description: >
+       Content moderation model looks fine on aggregate metrics but is missing
+       hate/violence due to poisoned training data. Agent must find the silent
+       regression, roll back safely, and add guards.
+     max_steps: 25
pyproject.toml ADDED
@@ -0,0 +1,24 @@
+ [build-system]
+ requires = ["setuptools>=68.0", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "mlops-firefighter"
+ version = "1.0.0"
+ description = "OpenEnv environment: Debug and fix ML models failing in production"
+ readme = "README.md"
+ license = {text = "BSD-3-Clause"}
+ requires-python = ">=3.10"
+ dependencies = [
+     "fastapi>=0.104.0",
+     "uvicorn[standard]>=0.24.0",
+     "pydantic>=2.5.0",
+     "requests>=2.31.0",
+ ]
+
+ [project.optional-dependencies]
+ baseline = ["openai>=1.0.0"]
+ dev = ["pytest>=7.0", "httpx>=0.25.0"]
+
+ [tool.setuptools.packages.find]
+ include = ["mlops_firefighter*", "server*"]
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ fastapi>=0.104.0
+ uvicorn[standard]>=0.24.0
+ pydantic>=2.5.0
+ requests>=2.31.0
+ openai>=1.0.0
server/__init__.py ADDED
@@ -0,0 +1 @@
+ # Server package
server/app.py ADDED
@@ -0,0 +1,259 @@
+ # Copyright (c) 2026 MLOps Firefighter Contributors
+ # Licensed under the BSD-3-Clause License
+
+ """
+ FastAPI server for the MLOps Firefighter OpenEnv environment.
+
+ Endpoints:
+     POST /reset — Start a new episode
+     POST /step — Take an action
+     GET /state — Current environment state
+     GET /health — Health check
+     GET /tasks — List tasks and action schema
+     POST /grader — Score a completed episode
+     POST /baseline — Run baseline inference on all tasks
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import sys
+ import traceback
+ from typing import Any
+
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel, Field
+
+ # Add parent dir to path for imports
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+ from models import ActionType, MLOpsAction, MLOpsObservation
+ from tasks import ALL_TASKS, grade_episode
+
+ from environment import MLOpsFirefighterEnvironment
+
+ # ── App Setup ────────────────────────────────────────────────────────────────
+
+ app = FastAPI(
+     title="MLOps Firefighter — OpenEnv Environment",
+     description=(
+         "An AI agent environment for debugging and fixing ML models in production. "
+         "The agent diagnoses root causes of model failures (data drift, threshold "
+         "misconfiguration, training data poisoning) and applies the correct fix."
+     ),
+     version="1.0.0",
+ )
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Per-session environments (simple in-memory for single container)
+ _environments: dict[str, MLOpsFirefighterEnvironment] = {}
+ _default_env = MLOpsFirefighterEnvironment()
+
+
+ def _get_env(session_id: str | None = None) -> MLOpsFirefighterEnvironment:
+     if session_id and session_id in _environments:
+         return _environments[session_id]
+     return _default_env
+
+
+ # ── Request / Response Models ────────────────────────────────────────────────
+
+ class ResetRequest(BaseModel):
+     task_id: str | None = Field(None, description="Task to load (optional)")
+     session_id: str | None = Field(None, description="Session ID (optional)")
+
+
+ class StepRequest(BaseModel):
+     action_type: str = Field(..., description="Action type to perform")
+     parameters: dict[str, Any] = Field(default_factory=dict)
+     session_id: str | None = Field(None, description="Session ID (optional)")
+
+
81
+ class GraderRequest(BaseModel):
82
+ session_id: str | None = Field(None)
83
+
84
+
85
+ # ── Endpoints ────────────────────────────────────────────────────────────────
86
+
87
+ @app.get("/health")
88
+ async def health():
89
+ return {"status": "healthy"}
90
+
91
+
92
+ @app.post("/reset")
93
+ async def reset(req: ResetRequest) -> dict:
94
+ env = _get_env(req.session_id)
95
+ if req.session_id and req.session_id not in _environments:
96
+ env = _environments[req.session_id] = MLOpsFirefighterEnvironment()  # isolate each new session
97
+
98
+ obs = env.reset(task_id=req.task_id)
99
+ return {
100
+ "observation": obs.model_dump(),
101
+ "reward": obs.reward,
102
+ "done": obs.done,
103
+ "info": {"episode_id": env.state()["episode_id"]},
104
+ }
105
+
106
+
107
+ @app.post("/step")
108
+ async def step(req: StepRequest) -> dict:
109
+ env = _get_env(req.session_id)
110
+
111
+ try:
112
+ action_type = ActionType(req.action_type)
113
+ except ValueError:
114
+ raise HTTPException(
115
+ status_code=400,
116
+ detail=f"Invalid action_type: '{req.action_type}'. "
117
+ f"Valid actions: {[a.value for a in ActionType]}",
118
+ )
119
+
120
+ action = MLOpsAction(action_type=action_type, parameters=req.parameters)
121
+ obs = env.step(action)
122
+
123
+ return {
124
+ "observation": obs.model_dump(),
125
+ "reward": obs.reward,
126
+ "done": obs.done,
127
+ "info": env.state(),
128
+ }
129
+
130
+
131
+ @app.get("/state")
132
+ async def state(session_id: str | None = None) -> dict:
133
+ env = _get_env(session_id)
134
+ return env.state()
135
+
136
+
137
+ @app.get("/tasks")
138
+ async def list_tasks() -> dict:
139
+ """List all tasks with their action schema."""
140
+ tasks = []
141
+ for tid, t in ALL_TASKS.items():
142
+ tasks.append({
143
+ "task_id": t.task_id,
144
+ "name": t.name,
145
+ "difficulty": t.difficulty,
146
+ "description": t.description,
147
+ "max_steps": t.max_steps,
148
+ })
149
+
150
+ action_schema = {
151
+ "action_type": {
152
+ "type": "string",
153
+ "required": True,
154
+ "enum": [a.value for a in ActionType],
155
+ "description": "The action to perform",
156
+ },
157
+ "parameters": {
158
+ "type": "object",
159
+ "required": False,
160
+ "description": (
161
+ "Action-specific parameters. For adjust_threshold: "
162
+ "{'new_threshold': float}. For submit_diagnosis: "
163
+ "{'root_cause': str, 'summary': str}. For rollback_model: "
164
+ "{'target_version': str}."
165
+ ),
166
+ },
167
+ }
168
+
169
+ return {"tasks": tasks, "action_schema": action_schema}
170
+
171
+
172
+ @app.post("/grader")
173
+ async def grader(req: GraderRequest) -> dict:
174
+ """Return grader score for a completed episode."""
175
+ env = _get_env(req.session_id)
176
+ st = env.state()
177
+
178
+ if not st["done"]:
179
+ raise HTTPException(
180
+ status_code=400,
181
+ detail="Episode not complete. Finish the episode first.",
182
+ )
183
+
184
+ if st["grader_result"]:
185
+ return {
186
+ "score": st["grader_result"]["total"],
187
+ "breakdown": st["grader_result"],
188
+ "task_id": st["task_id"],
189
+ }
190
+
191
+ return {
192
+ "score": 0.0,
193
+ "breakdown": {},
194
+ "task_id": st["task_id"],
195
+ "message": "No grader result available.",
196
+ }
197
+
198
+
199
+ @app.post("/baseline")
200
+ async def run_baseline() -> dict:
201
+ """Run a scripted baseline agent on all 3 tasks and return scores."""
202
+ results = {}
203
+
204
+ for task_id, task_def in ALL_TASKS.items():
205
+ env = MLOpsFirefighterEnvironment()
206
+ env.reset(task_id=task_id)
207
+
208
+ # Baseline strategy: run all required diagnostics, then apply
209
+ # correct remediation, then submit diagnosis
210
+ for diag in task_def.required_diagnostics:
211
+ action = MLOpsAction(action_type=diag, parameters={})
212
+ env.step(action)
213
+
214
+ # Apply correct remediations
215
+ for rem in task_def.correct_remediations:
216
+ params = {}
217
+ if rem == ActionType.ADJUST_THRESHOLD:
218
+ params = {"new_threshold": task_def.extra_state.get("optimal_threshold", 0.5)}
219
+ elif rem == ActionType.ROLLBACK_MODEL:
220
+ params = {"target_version": task_def.model_info.previous_versions[0]}
221
+ action = MLOpsAction(action_type=rem, parameters=params)
222
+ env.step(action)
223
+
224
+ # Submit diagnosis
225
+ action = MLOpsAction(
226
+ action_type=ActionType.SUBMIT_DIAGNOSIS,
227
+ parameters={
228
+ "root_cause": task_def.root_causes[0],
229
+ "summary": f"Baseline diagnosis for {task_def.name}",
230
+ },
231
+ )
232
+ env.step(action)
233
+
234
+ st = env.state()
235
+ results[task_id] = {
236
+ "task_name": task_def.name,
237
+ "difficulty": task_def.difficulty,
238
+ "score": st["grader_result"]["total"] if st["grader_result"] else 0.0,
239
+ "breakdown": st["grader_result"],
240
+ "steps_taken": st["step_count"],
241
+ }
242
+
243
+ # Average score
244
+ scores = [r["score"] for r in results.values()]
245
+ avg = sum(scores) / len(scores) if scores else 0.0
246
+
247
+ return {
248
+ "baseline_results": results,
249
+ "average_score": round(avg, 3),
250
+ }
251
+
252
+
253
+ # ── Main ─────────────────────────────────────────────────────────────────────
254
+
255
+ if __name__ == "__main__":
256
+ import uvicorn
257
+
258
+ port = int(os.environ.get("PORT", 7860))
259
+ uvicorn.run(app, host="0.0.0.0", port=port)
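The endpoints listed in the `server/app.py` docstring can be exercised with a small client. The sketch below is an illustration, not part of this commit: `http_post`, `run_episode`, and `EASY_TASK_ACTIONS` are hypothetical names, and `"inspect_metrics"` is assumed to be the string value of `ActionType.INSPECT_METRICS` (`models.py` defines the real values). It uses only the standard library and takes the transport as a parameter, so the episode loop can be tried without a running server.

```python
import json
from typing import Any, Callable, Optional
from urllib import request as urlrequest

# A transport is any callable taking (path, payload) and returning parsed JSON.
PostFn = Callable[[str, dict], dict]


def http_post(base_url: str) -> PostFn:
    """Build a transport that POSTs JSON to base_url + path (hypothetical helper)."""

    def post(path: str, payload: dict) -> dict:
        req = urlrequest.Request(
            base_url + path,
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with urlrequest.urlopen(req) as resp:
            return json.loads(resp.read().decode("utf-8"))

    return post


def run_episode(
    post: PostFn,
    task_id: str,
    actions: list,
    session_id: Optional[str] = None,
) -> dict:
    """Reset the environment, then send actions until the server reports done."""
    result: dict = post("/reset", {"task_id": task_id, "session_id": session_id})
    total_reward = 0.0
    steps = 0
    for action_type, parameters in actions:
        if result["done"]:
            break  # the server ignores actions after the episode ends
        result = post(
            "/step",
            {
                "action_type": action_type,
                "parameters": parameters,
                "session_id": session_id,
            },
        )
        total_reward += result["reward"]
        steps += 1
    return {"done": result["done"], "steps": steps, "total_reward": total_reward}


# Action sequence for the easy task. "inspect_metrics" is an assumed enum value;
# the other two names and their parameters appear in the /tasks action schema.
EASY_TASK_ACTIONS: list = [
    ("inspect_metrics", {}),
    ("adjust_threshold", {"new_threshold": 0.55}),
    (
        "submit_diagnosis",
        {
            "root_cause": "threshold_misconfiguration",
            "summary": "Threshold lowered to 0.30 at redeploy",
        },
    ),
]
```

Against a running container, and assuming the default `PORT=7860` on localhost, a call would look like `run_episode(http_post("http://localhost:7860"), "task_threshold_drift", EASY_TASK_ACTIONS)`. Passing the transport in keeps the loop testable with a stub.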
server/environment.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright (c) 2026 MLOps Firefighter Contributors
2
+ # Licensed under the BSD-3-Clause License
3
+
4
+ """
5
+ Core environment logic for the MLOps Firefighter.
6
+
7
+ Implements the OpenEnv Environment interface: reset(), step(), state().
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import uuid
13
+ from typing import Any
14
+
15
+ from models import (
16
+ ActionType,
17
+ AlertSeverity,
18
+ MLOpsAction,
19
+ MLOpsObservation,
20
+ )
21
+ from tasks import ALL_TASKS, TaskDefinition, grade_episode
22
+
23
+
24
+ class MLOpsFirefighterEnvironment:
25
+ """
26
+ Simulates an ML model failing in production.
27
+
28
+ The agent must diagnose the root cause and apply the correct fix
29
+ through a sequence of diagnostic and remediation actions.
30
+ """
31
+
32
+ def __init__(self) -> None:
33
+ self._episode_id: str = ""
34
+ self._step_count: int = 0
35
+ self._task: TaskDefinition | None = None
36
+ self._done: bool = False
37
+ self._cumulative_reward: float = 0.0
38
+
39
+ # Tracking for grading
40
+ self._actions_taken: list[dict[str, Any]] = []
41
+ self._diagnostics_gathered: list[str] = []
42
+ self._diagnosis_submitted: dict[str, Any] | None = None
43
+ self._remediations_applied: list[str] = []
44
+ self._last_grader_result: dict | None = None
45
+
46
+ # ── OpenEnv Interface ────────────────────────────────────────────────
47
+
48
+ def reset(self, task_id: str | None = None) -> MLOpsObservation:
49
+ """Initialize a new incident episode."""
50
+ self._episode_id = str(uuid.uuid4())
51
+ self._step_count = 0
52
+ self._done = False
53
+ self._cumulative_reward = 0.0
54
+ self._actions_taken = []
55
+ self._diagnostics_gathered = []
56
+ self._diagnosis_submitted = None
57
+ self._remediations_applied = []
58
+ self._last_grader_result = None
59
+
60
+ # Pick task
61
+ if task_id and task_id in ALL_TASKS:
62
+ self._task = ALL_TASKS[task_id]
63
+ else:
64
+ # No task_id given, or task_id not found: fall back to the first (easy) task
65
+ self._task = list(ALL_TASKS.values())[0]
66
+
67
+ return MLOpsObservation(
68
+ done=False,
69
+ reward=0.0,
70
+ step_number=0,
71
+ max_steps=self._task.max_steps,
72
+ task_id=self._task.task_id,
73
+ task_description=self._task.description,
74
+ alerts=self._task.alerts,
75
+ model_info=self._task.model_info,
76
+ action_result="Incident assigned to you. Begin investigation.",
77
+ action_success=True,
78
+ diagnostics_gathered=[],
79
+ available_actions=[a.value for a in ActionType],
80
+ )
81
+
82
+ def step(self, action: MLOpsAction) -> MLOpsObservation:
83
+ """Execute an action and return the resulting observation."""
84
+ if self._task is None:
85
+ return MLOpsObservation(
86
+ done=True,
87
+ reward=-1.0,
88
+ action_result="Error: No task loaded. Call reset() first.",
89
+ action_success=False,
90
+ )
91
+
92
+ if self._done:
93
+ return MLOpsObservation(
94
+ done=True,
95
+ reward=0.0,
96
+ step_number=self._step_count,
97
+ max_steps=self._task.max_steps,
98
+ task_id=self._task.task_id,
99
+ task_description=self._task.description,
100
+ action_result="Episode already ended.",
101
+ action_success=False,
102
+ diagnostics_gathered=self._diagnostics_gathered,
103
+ available_actions=[],
104
+ )
105
+
106
+ self._step_count += 1
107
+ self._actions_taken.append({
108
+ "action_type": action.action_type.value,
109
+ "parameters": action.parameters,
110
+ "step": self._step_count,
111
+ })
112
+
113
+ # Process the action
114
+ result_text, reward, action_success = self._process_action(action)
115
+
116
+ # Check for episode end
117
+ if action.action_type == ActionType.SUBMIT_DIAGNOSIS:
118
+ self._done = True
119
+ # Run grader for final reward
120
+ final_score, breakdown = grade_episode(
121
+ task=self._task,
122
+ actions_taken=self._actions_taken,
123
+ diagnosis_submitted=self._diagnosis_submitted,
124
+ remediation_applied=self._remediations_applied,
125
+ total_steps=self._step_count,
126
+ )
127
+ self._last_grader_result = breakdown
128
+ # Scale final reward: bonus for good diagnosis
129
+ reward = final_score * 5.0 # 0–5 range for final step
130
+ result_text += f"\n\n── EPISODE COMPLETE ──\nFinal Score: {final_score:.3f}\nBreakdown: {breakdown}"
131
+
132
+ # Check timeout
133
+ if self._step_count >= self._task.max_steps and not self._done:
134
+ self._done = True
135
+ final_score, breakdown = grade_episode(
136
+ task=self._task,
137
+ actions_taken=self._actions_taken,
138
+ diagnosis_submitted=self._diagnosis_submitted,
139
+ remediation_applied=self._remediations_applied,
140
+ total_steps=self._step_count,
141
+ )
142
+ self._last_grader_result = breakdown
143
+ reward = final_score * 3.0 # Lower multiplier for timeout
144
+ result_text += (
145
+ f"\n\n── TIMEOUT — Episode ended (max {self._task.max_steps} steps) ──\n"
146
+ f"Score: {final_score:.3f}\nBreakdown: {breakdown}"
147
+ )
148
+
149
+ self._cumulative_reward += reward
150
+
151
+ return MLOpsObservation(
152
+ done=self._done,
153
+ reward=round(reward, 4),
154
+ step_number=self._step_count,
155
+ max_steps=self._task.max_steps,
156
+ task_id=self._task.task_id,
157
+ task_description=self._task.description,
158
+ alerts=self._task.alerts,
159
+ model_info=self._task.model_info,
160
+ action_result=result_text,
161
+ action_success=action_success,
162
+ diagnostics_gathered=self._diagnostics_gathered,
163
+ available_actions=(
164
+ [a.value for a in ActionType] if not self._done else []
165
+ ),
166
+ )
167
+
168
+ def state(self) -> dict[str, Any]:
169
+ """Return current environment state."""
170
+ return {
171
+ "episode_id": self._episode_id,
172
+ "step_count": self._step_count,
173
+ "task_id": self._task.task_id if self._task else None,
174
+ "done": self._done,
175
+ "cumulative_reward": round(self._cumulative_reward, 4),
176
+ "actions_taken_count": len(self._actions_taken),
177
+ "diagnostics_gathered": self._diagnostics_gathered,
178
+ "remediations_applied": self._remediations_applied,
179
+ "diagnosis_submitted": self._diagnosis_submitted is not None,
180
+ "grader_result": self._last_grader_result,
181
+ }
182
+
183
+ # ── Action Processing ────────────────────────────────────────────────
184
+
185
+ def _process_action(self, action: MLOpsAction) -> tuple[str, float, bool]:
186
+ """
187
+ Process an action and return (result_text, step_reward, success).
188
+
189
+ Reward shaping:
190
+ - Required diagnostic (first run): +0.3; other diagnostic (first run): +0.1
190
+ - Redundant diagnostic: +0.05
191
+ - Correct remediation: +1.0
192
+ - Wrong remediation: -0.3
193
+ - Remediation before two diagnostics are gathered: -0.5
195
+ - Submit diagnosis: handled in step() via grader
196
+ """
197
+ task = self._task
198
+ assert task is not None
199
+
200
+ at = action.action_type
201
+
202
+ # ── Diagnostic actions ───────────────────────────────────────────
203
+ if at in (
204
+ ActionType.INSPECT_METRICS,
205
+ ActionType.QUERY_LOGS,
206
+ ActionType.CHECK_DATA_DISTRIBUTION,
207
+ ActionType.CHECK_FEATURE_IMPORTANCE,
208
+ ActionType.RUN_PREDICTION_SAMPLE,
209
+ ActionType.CHECK_INFRASTRUCTURE,
210
+ ActionType.CHECK_UPSTREAM_PIPELINE,
211
+ ):
212
+ result_text = task.diagnostic_results.get(
213
+ at, "No additional information available for this diagnostic."
214
+ )
215
+ # Reward: useful if it's a required diagnostic we haven't done
216
+ diag_label = f"[{at.value}] {result_text[:80]}..."
217
+ if diag_label in self._diagnostics_gathered:
218
+ # Redundant — already gathered
219
+ reward = 0.05
220
+ result_text = "(You already ran this diagnostic.)\n\n" + result_text
221
+ else:
222
+ self._diagnostics_gathered.append(diag_label)
223
+ if at in task.required_diagnostics:
224
+ reward = 0.3 # Useful diagnostic
225
+ else:
226
+ reward = 0.1 # Valid but not critical
227
+ return result_text, reward, True
228
+
229
+ # ── Remediation actions ──────────────────────────────────────────
230
+ if at in (
231
+ ActionType.ROLLBACK_MODEL,
232
+ ActionType.ADJUST_THRESHOLD,
233
+ ActionType.RETRAIN_MODEL,
234
+ ActionType.FIX_DATA_PIPELINE,
235
+ ActionType.SCALE_INFRASTRUCTURE,
236
+ ActionType.ADD_FEATURE_GUARD,
237
+ ):
238
+ return self._process_remediation(at, action.parameters)
239
+
240
+ # ── Submit diagnosis ─────────────────────────────────────────────
241
+ if at == ActionType.SUBMIT_DIAGNOSIS:
242
+ self._diagnosis_submitted = action.parameters
243
+ root_cause = action.parameters.get("root_cause", "not specified")
244
+ summary = action.parameters.get("summary", "no summary")
245
+ return (
246
+ f"Diagnosis submitted.\n"
247
+ f" Root cause: {root_cause}\n"
248
+ f" Summary: {summary}",
249
+ 0.0, # Actual reward computed in step() via grader
250
+ True,
251
+ )
252
+
253
+ return "Unknown action type.", -0.2, False
254
+
255
+ def _process_remediation(
256
+ self, action_type: ActionType, params: dict[str, Any]
257
+ ) -> tuple[str, float, bool]:
258
+ """Process a remediation action."""
259
+ task = self._task
260
+ assert task is not None
261
+
262
+ self._remediations_applied.append(action_type.value)
263
+
264
+ is_correct = action_type in task.correct_remediations
265
+ has_diagnosed = len(self._diagnostics_gathered) >= 2
266
+
267
+ # Penalize hasty remediation without diagnosis
268
+ if not has_diagnosed:
269
+ return (
270
+ f"⚠ WARNING: Applying {action_type.value} without sufficient "
271
+ f"diagnostic investigation. This is risky in production.\n"
272
+ f"Action applied, but confidence is low.",
273
+ -0.5,
274
+ True,
275
+ )
276
+
277
+ if is_correct:
278
+ # Specific feedback per remediation
279
+ feedback = self._get_remediation_feedback(action_type, params)
280
+ return feedback, 1.0, True
281
+ else:
282
+ return (
283
+ f"Applied {action_type.value}, but this doesn't address the "
284
+ f"root cause. The issue persists. Consider more investigation.",
285
+ -0.3,
286
+ True,
287
+ )
288
+
289
+ def _get_remediation_feedback(
290
+ self, action_type: ActionType, params: dict[str, Any]
291
+ ) -> str:
292
+ """Generate specific feedback for correct remediations."""
293
+ task = self._task
294
+ assert task is not None
295
+
296
+ if action_type == ActionType.ADJUST_THRESHOLD:
297
+ new_t = params.get("new_threshold", "unspecified")
298
+ optimal = task.extra_state.get("optimal_threshold", 0.5)
299
+ return (
300
+ f"Threshold adjusted to {new_t}.\n"
301
+ f"Optimal threshold was {optimal}.\n"
302
+ f"Precision recovering. False positive rate decreasing."
303
+ )
304
+ elif action_type == ActionType.ROLLBACK_MODEL:
305
+ target = params.get("target_version", "previous")
306
+ return (
307
+ f"Model rolled back to {target}.\n"
308
+ f"Previous model restored. Monitoring metrics.\n"
309
+ f"Harmful content detection recovering."
310
+ )
311
+ elif action_type == ActionType.FIX_DATA_PIPELINE:
312
+ return (
313
+ "Data pipeline fix initiated.\n"
314
+ " - Credit bureau connector patched\n"
315
+ " - Unit conversion (cents→dollars) applied to annual_income\n"
316
+ " - Null handling added for credit_utilization_ratio\n"
317
+ " Pipeline revalidation in progress."
318
+ )
319
+ elif action_type == ActionType.RETRAIN_MODEL:
320
+ return (
321
+ "Model retraining triggered with corrected data.\n"
322
+ " - Stale features refreshed\n"
323
+ " - Estimated completion: 2 hours\n"
324
+ " - Will auto-deploy after quality gate passes."
325
+ )
326
+ elif action_type == ActionType.ADD_FEATURE_GUARD:
327
+ return (
328
+ "Feature guard / input validation added.\n"
329
+ " - Per-class recall monitoring enabled\n"
330
+ " - Crowd-source batch validation gate added\n"
331
+ " - Anomalous label distributions will trigger review."
332
+ )
333
+ else:
334
+ return f"{action_type.value} applied successfully."
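The per-step reward shaping in `_process_action` and `_process_remediation` can be condensed into one pure function, which makes the values easy to check in isolation. This is a sketch rather than code from the commit: `step_reward`, `final_reward`, and their parameters are hypothetical names, and the constants are copied from the branches in `server/environment.py` above.

```python
from typing import Literal

ActionKind = Literal["diagnostic", "remediation", "submit"]


def step_reward(
    kind: ActionKind,
    *,
    is_required: bool = False,
    already_run: bool = False,
    diagnostics_gathered: int = 0,
    is_correct: bool = False,
) -> float:
    """Per-step reward mirroring the branches in server/environment.py."""
    if kind == "diagnostic":
        if already_run:
            return 0.05  # redundant diagnostic
        # First run: required diagnostics earn more than optional ones.
        return 0.3 if is_required else 0.1
    if kind == "remediation":
        if diagnostics_gathered < 2:
            return -0.5  # remediation applied before enough investigation
        return 1.0 if is_correct else -0.3
    # submit_diagnosis: the step reward is 0.0 here; step() then replaces it
    # with the scaled grader score (see final_reward).
    return 0.0


def final_reward(grader_score: float, timed_out: bool) -> float:
    """Final-step reward: 5x the grader score on submission, 3x on timeout."""
    return grader_score * (3.0 if timed_out else 5.0)
```

For example, a run on the easy task that performs both required diagnostics and then the correct remediation accumulates 0.3 + 0.3 + 1.0 = 1.6 before the final grader-scaled reward is added.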
tasks.py ADDED
@@ -0,0 +1,526 @@
1
+ # Copyright (c) 2026 MLOps Firefighter Contributors
2
+ # Licensed under the BSD-3-Clause License
3
+
4
+ """
5
+ Task definitions for the MLOps Firefighter environment.
6
+
7
+ Each task defines:
8
+ - A production ML incident scenario
9
+ - The ground-truth root cause(s)
10
+ - Required diagnostic steps (for partial credit)
11
+ - Correct remediation action(s)
12
+ - A grader function that scores agent performance 0.0–1.0
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from dataclasses import dataclass, field
18
+ from typing import Any
19
+
20
+ from models import (
21
+ ActionType,
22
+ Alert,
23
+ AlertSeverity,
24
+ ModelInfo,
25
+ )
26
+
27
+
28
+ # ── Task Data Structure ─────────────────────────────────────────────────────
29
+
30
+ @dataclass
31
+ class TaskDefinition:
32
+ task_id: str
33
+ name: str
34
+ difficulty: str # "easy", "medium", "hard"
35
+ description: str
36
+ alerts: list[Alert]
37
+ model_info: ModelInfo
38
+ root_causes: list[str] # acceptable root cause labels
39
+ required_diagnostics: list[ActionType] # expected investigation steps
40
+ correct_remediations: list[ActionType] # accepted fix actions
41
+ diagnostic_results: dict[ActionType, str] # what each diagnostic reveals
42
+ max_steps: int = 20
43
+ extra_state: dict[str, Any] = field(default_factory=dict)
44
+
45
+
46
+ # ── Task 1: Easy — Threshold Misconfiguration ───────────────────────────────
47
+
48
+ TASK_EASY = TaskDefinition(
49
+ task_id="task_threshold_drift",
50
+ name="Threshold Misconfiguration After Redeployment",
51
+ difficulty="easy",
52
+ description=(
53
+ "INCIDENT: The fraud detection model's precision has dropped from 0.94 to 0.61 "
54
+ "over the last 2 hours. False positive rate spiked 3x. The model was redeployed "
55
+ "yesterday with a new version. No changes to training data or features were made. "
56
+ "Investigate and fix the issue."
57
+ ),
58
+ alerts=[
59
+ Alert(
60
+ severity=AlertSeverity.HIGH,
61
+ message="Precision drop detected on fraud-detection-v3",
62
+ metric_name="precision",
63
+ current_value=0.61,
64
+ threshold=0.85,
65
+ timestamp="2026-03-27T10:15:00Z",
66
+ ),
67
+ Alert(
68
+ severity=AlertSeverity.MEDIUM,
69
+ message="False positive rate above acceptable range",
70
+ metric_name="false_positive_rate",
71
+ current_value=0.18,
72
+ threshold=0.05,
73
+ timestamp="2026-03-27T10:17:00Z",
74
+ ),
75
+ ],
76
+ model_info=ModelInfo(
77
+ model_name="fraud-detection",
78
+ model_version="v3.1.0",
79
+ deployed_at="2026-03-26T14:00:00Z",
80
+ framework="XGBoost",
81
+ endpoint="/api/v1/predict/fraud",
82
+ previous_versions=["v3.0.2", "v2.9.1"],
83
+ ),
84
+ root_causes=["threshold_misconfiguration", "threshold_too_low", "bad_threshold"],
85
+ required_diagnostics=[
86
+ ActionType.INSPECT_METRICS,
87
+ ActionType.RUN_PREDICTION_SAMPLE,
88
+ ],
89
+ correct_remediations=[ActionType.ADJUST_THRESHOLD],
90
+ diagnostic_results={
91
+ ActionType.INSPECT_METRICS: (
92
+ "METRICS DASHBOARD:\n"
93
+ " Model: fraud-detection v3.1.0\n"
94
+ " Precision: 0.61 (was 0.94 on v3.0.2)\n"
95
+ " Recall: 0.99 (was 0.87 on v3.0.2)\n"
96
+ " F1: 0.76 (was 0.90 on v3.0.2)\n"
97
+ " Decision threshold: 0.30 (was 0.55 on v3.0.2)\n"
98
+ " NOTE: Threshold was changed during redeployment.\n"
99
+ " Requests/sec: 1,240 (normal)\n"
100
+ " P99 latency: 45ms (normal)"
101
+ ),
102
+ ActionType.QUERY_LOGS: (
103
+ "LOG SEARCH RESULTS (last 4 hours):\n"
104
+ " [10:00] Deployment v3.1.0 started\n"
105
+ " [10:01] Config applied: threshold=0.30 (previous: 0.55)\n"
106
+ " [10:02] Health check passed\n"
107
+ " [10:15] Alert: precision below 0.85\n"
108
+ " No error logs. No OOM events. Pipeline healthy."
109
+ ),
110
+ ActionType.CHECK_DATA_DISTRIBUTION: (
111
+ "DATA DISTRIBUTION COMPARISON:\n"
112
+ " Training data: 2.1M samples, fraud_rate=3.2%\n"
113
+ " Serving data (last 24h): 148K requests, fraud_rate=3.1%\n"
114
+ " Feature distributions: no significant drift detected\n"
115
+ " KL divergence: 0.003 (threshold: 0.05)\n"
116
+ " All features within expected ranges."
117
+ ),
118
+ ActionType.CHECK_FEATURE_IMPORTANCE: (
119
+ "FEATURE IMPORTANCE (top 5):\n"
120
+ " 1. transaction_amount: 0.31\n"
121
+ " 2. merchant_risk_score: 0.22\n"
122
+ " 3. velocity_1h: 0.18\n"
123
+ " 4. distance_from_home: 0.14\n"
124
+ " 5. device_fingerprint_match: 0.09\n"
125
+ " No anomalies in feature importance vs. training."
126
+ ),
127
+ ActionType.RUN_PREDICTION_SAMPLE: (
128
+ "PREDICTION SAMPLE TEST (100 known-labeled samples):\n"
129
+ " At threshold=0.30: precision=0.62, recall=0.98\n"
130
+ " At threshold=0.55: precision=0.93, recall=0.88\n"
131
+ " At threshold=0.50: precision=0.91, recall=0.90\n"
132
+ " Conclusion: threshold is the primary lever for precision."
133
+ ),
134
+ ActionType.CHECK_INFRASTRUCTURE: (
135
+ "INFRASTRUCTURE STATUS:\n"
136
+ " CPU: 34% (normal)\n Memory: 4.2GB/8GB (normal)\n"
137
+ " GPU: N/A (CPU-only model)\n Latency P99: 45ms (normal)\n"
138
+ " No infrastructure issues detected."
139
+ ),
140
+ ActionType.CHECK_UPSTREAM_PIPELINE: (
141
+ "UPSTREAM PIPELINE:\n"
142
+ " Feature store: healthy, last refresh 12 min ago\n"
143
+ " Data ingestion: nominal, 0 failures in 24h\n"
144
+ " Schema validation: passing\n"
145
+ " No upstream issues."
146
+ ),
147
+ },
148
+ max_steps=15,
149
+ extra_state={"optimal_threshold": 0.55},
150
+ )
151
+
152
+
153
+ # ── Task 2: Medium — Data Drift + Stale Feature ─────────────────────────────
154
+
155
+ TASK_MEDIUM = TaskDefinition(
156
+ task_id="task_data_drift",
157
+ name="Data Drift with Stale Feature Pipeline",
158
+ difficulty="medium",
159
+ description=(
160
+ "INCIDENT: The loan default prediction model's AUC has degraded from 0.91 to "
161
+ "0.74 over the past week. Customer complaints about incorrect rejections have "
162
+ "tripled. The model was last retrained 4 months ago. A new credit bureau data "
163
+ "provider was onboarded 10 days ago. Diagnose the issue and fix it."
164
+ ),
165
+ alerts=[
166
+ Alert(
167
+ severity=AlertSeverity.CRITICAL,
168
+ message="AUC degradation on loan-default-predictor",
169
+ metric_name="auc_roc",
170
+ current_value=0.74,
171
+ threshold=0.85,
172
+ timestamp="2026-03-27T08:00:00Z",
173
+ ),
174
+ Alert(
175
+ severity=AlertSeverity.HIGH,
176
+ message="Customer complaint rate above threshold",
177
+ metric_name="complaint_rate",
178
+ current_value=0.12,
179
+ threshold=0.03,
180
+ timestamp="2026-03-27T09:30:00Z",
181
+ ),
182
+ ],
183
+ model_info=ModelInfo(
184
+ model_name="loan-default-predictor",
185
+ model_version="v2.4.0",
186
+ deployed_at="2025-11-15T10:00:00Z",
187
+ framework="LightGBM",
188
+ endpoint="/api/v1/predict/loan-default",
189
+ previous_versions=["v2.3.1", "v2.2.0"],
190
+ ),
191
+ root_causes=[
192
+ "data_drift",
193
+ "stale_feature",
194
+ "feature_pipeline_broken",
195
+ "data_distribution_shift",
196
+ ],
197
+ required_diagnostics=[
198
+ ActionType.CHECK_DATA_DISTRIBUTION,
199
+ ActionType.CHECK_UPSTREAM_PIPELINE,
200
+ ActionType.INSPECT_METRICS,
201
+ ],
202
+ correct_remediations=[
203
+ ActionType.FIX_DATA_PIPELINE,
204
+ ActionType.RETRAIN_MODEL,
205
+ ],
206
+ diagnostic_results={
207
+ ActionType.INSPECT_METRICS: (
208
+ "METRICS DASHBOARD:\n"
209
+ " Model: loan-default-predictor v2.4.0\n"
210
+ " AUC: 0.74 (was 0.91 at deploy, 0.88 last week)\n"
211
+ " Accuracy: 0.69 (was 0.84)\n"
212
+ " False rejection rate: 14% (was 4%)\n"
213
+ " Threshold: 0.50 (unchanged)\n"
214
+ " Requests/sec: 320 (normal)\n"
215
+ " Degradation trend: gradual decline starting ~10 days ago"
216
+ ),
217
+ ActionType.QUERY_LOGS: (
218
+ "LOG SEARCH RESULTS (last 14 days):\n"
219
+ " [Mar 17] New credit bureau connector deployed (v1.2)\n"
220
+ " [Mar 17] Feature 'credit_utilization_ratio' source changed\n"
221
+ " [Mar 18] Warning: 2,340 null values in 'credit_utilization_ratio'\n"
222
+ " [Mar 19-27] Recurring nulls in credit_utilization_ratio (avg 8%/day)\n"
223
+ " [Mar 22] Warning: feature 'annual_income' format changed (dollars→cents)\n"
224
+ " No deployment events. No model changes."
225
+ ),
226
+ ActionType.CHECK_DATA_DISTRIBUTION: (
227
+ "DATA DISTRIBUTION COMPARISON:\n"
228
+ " Training data vs. Serving data (last 7 days):\n"
229
+ " ─────────────────────────────────────────────\n"
230
+ " credit_utilization_ratio:\n"
231
+ " Train: mean=0.34, std=0.21, null_rate=0.1%\n"
232
+ " Serve: mean=0.08, std=0.42, null_rate=8.3% ⚠ DRIFT DETECTED\n"
233
+ " KL divergence: 0.87 (threshold: 0.05)\n"
234
+ " annual_income:\n"
235
+ " Train: mean=65,400, std=28,000\n"
236
+ " Serve: mean=6,540,000, std=2,800,000 ⚠ DRIFT DETECTED\n"
237
+ " KL divergence: 12.4 (threshold: 0.05)\n"
238
+ " NOTE: Values appear 100x larger — possible unit change\n"
239
+ " Other features: within normal ranges"
240
+ ),
241
+ ActionType.CHECK_FEATURE_IMPORTANCE: (
242
+ "FEATURE IMPORTANCE (top 5):\n"
243
+ " 1. credit_utilization_ratio: 0.28 ← AFFECTED\n"
244
+ " 2. annual_income: 0.22 ← AFFECTED\n"
245
+ " 3. debt_to_income: 0.19\n"
246
+ " 4. payment_history_score: 0.15\n"
247
+ " 5. employment_length: 0.08\n"
248
+ " The top 2 features (50% of importance) have drifted."
249
+ ),
250
+ ActionType.RUN_PREDICTION_SAMPLE: (
251
+ "PREDICTION SAMPLE (50 recent applications with known outcomes):\n"
252
+ " Model predictions vs. actual outcomes:\n"
253
+ " - Correct: 31/50 (62%)\n"
254
+ " - False rejections: 14/50 (28%) — most involve high-income applicants\n"
255
+ " - Missed defaults: 5/50 (10%)\n"
256
+ " Annual income feature appears corrupted (values in cents not dollars)"
257
+ ),
258
+ ActionType.CHECK_INFRASTRUCTURE: (
259
+ "INFRASTRUCTURE STATUS:\n"
260
+ " CPU: 22% (normal)\n Memory: 3.1GB/8GB (normal)\n"
261
+ " Latency P99: 38ms (normal)\n"
262
+ " No infrastructure issues."
263
+ ),
264
+ ActionType.CHECK_UPSTREAM_PIPELINE: (
265
+ "UPSTREAM PIPELINE STATUS:\n"
266
+ " Feature store: HEALTHY (but stale for 2 features)\n"
267
+ " Data ingestion: DEGRADED\n"
268
+ " - credit_bureau_connector_v1.2: returning nulls 8% of time\n"
269
+ " - Schema mismatch: 'annual_income' was USD→now returned in cents\n"
270
+ " - Last successful full refresh: 11 days ago\n"
271
+ " Other connectors: nominal\n"
272
+ " RECOMMENDATION: Fix credit bureau connector, add unit conversion"
273
+ ),
274
+ },
275
+ max_steps=20,
276
+ extra_state={
277
+ "requires_pipeline_fix": True,
278
+ "requires_retrain": True,
279
+ },
280
+ )
281
+
282
+
283
+ # ── Task 3: Hard — Adversarial + Silent Model Regression ────────────────────
284
+
285
+ TASK_HARD = TaskDefinition(
286
+ task_id="task_silent_regression",
287
+ name="Silent Model Regression with Adversarial Inputs",
288
+ difficulty="hard",
289
+ description=(
290
+ "INCIDENT: The content moderation model appears to be working normally — overall "
291
+ "accuracy is 0.96. However, trust & safety reports show a 5x increase in "
292
+ "user-reported harmful content slipping through. The model was retrained 3 days ago "
293
+ "with an automated pipeline. No alerts have fired because aggregate metrics look "
294
+ "fine. A community moderator escalated this manually. Investigate what's happening "
295
+ "and fix it. Be careful — hasty actions could take down moderation for millions of "
296
+ "users."
297
+ ),
298
+ alerts=[
299
+ Alert(
300
+ severity=AlertSeverity.LOW,
301
+ message="Manual escalation: T&S reports 5x harmful content increase",
302
+ metric_name="user_reports_harmful",
303
+ current_value=847.0,
304
+ threshold=200.0,
305
+ timestamp="2026-03-27T06:00:00Z",
306
+ ),
307
+ ],
308
+ model_info=ModelInfo(
309
+ model_name="content-moderator",
310
+ model_version="v5.2.0",
311
+ deployed_at="2026-03-24T03:00:00Z",
312
+ framework="PyTorch (transformer)",
313
+ endpoint="/api/v1/moderate/content",
314
+ previous_versions=["v5.1.3", "v5.0.0", "v4.8.2"],
315
+ ),
316
+ root_causes=[
317
+ "training_data_poisoning",
318
+ "adversarial_inputs",
319
+ "label_corruption",
320
+ "data_poisoning",
321
+ "silent_regression",
322
+ "class_specific_regression",
323
+ ],
324
+ required_diagnostics=[
325
+ ActionType.INSPECT_METRICS,
326
+ ActionType.RUN_PREDICTION_SAMPLE,
327
+ ActionType.CHECK_DATA_DISTRIBUTION,
328
+ ActionType.CHECK_UPSTREAM_PIPELINE,
329
+ ],
330
+ correct_remediations=[
331
+ ActionType.ROLLBACK_MODEL,
332
+ ActionType.ADD_FEATURE_GUARD,
333
+ ],
334
+ diagnostic_results={
335
+ ActionType.INSPECT_METRICS: (
336
+ "METRICS DASHBOARD:\n"
337
+ " Model: content-moderator v5.2.0\n"
338
+ " Overall accuracy: 0.96 (looks normal!)\n"
339
+ " Overall precision: 0.94, recall: 0.93\n"
340
+ " Latency P99: 120ms (normal)\n"
341
+ " Requests/sec: 45,000 (normal)\n"
342
+ " ── Per-class breakdown ──\n"
343
+ " spam: precision=0.97, recall=0.98 (normal)\n"
344
+ " hate: precision=0.95, recall=0.42 ⚠ (was 0.91)\n"
345
+ " violence: precision=0.93, recall=0.38 ⚠ (was 0.89)\n"
346
+ " safe: precision=0.96, recall=0.99 (normal)\n"
347
+ " NOTE: Hate and violence recall collapsed but aggregate looks OK\n"
348
+ " because 'safe' class dominates (92% of traffic)."
349
+ ),
350
+ ActionType.QUERY_LOGS: (
351
+ "LOG SEARCH RESULTS:\n"
352
+ " [Mar 24 02:30] Auto-retrain pipeline triggered\n"
353
+ " [Mar 24 02:45] Training dataset assembled: 2.4M samples\n"
354
+ " [Mar 24 02:46] Warning: 34,000 new labels from crowd-source batch #847\n"
355
+ " [Mar 24 03:00] Model v5.2.0 deployed (passed aggregate quality gate)\n"
356
+ " [Mar 24 03:01] A/B test: v5.2.0 vs v5.1.3 — aggregate acc comparable\n"
357
+ " [Mar 25] T&S team flags increase in harmful content reports\n"
358
+ " [Mar 26] Reports continue to climb\n"
359
+ " [Mar 27] Manual escalation from community moderator"
360
+ ),
361
+ ActionType.CHECK_DATA_DISTRIBUTION: (
362
+ "DATA DISTRIBUTION COMPARISON:\n"
363
+ " Training data v5.2.0 vs v5.1.3:\n"
364
+ " ─────────────────────────────────────────────\n"
365
+ " Total samples: 2.4M (v5.2.0) vs 2.1M (v5.1.3)\n"
366
+ " New batch #847: 34,000 samples added\n"
367
+ " Label distribution in batch #847:\n"
368
+ " safe: 89% | hate: 7% | violence: 4% | spam: 0%\n"
369
+ " Cross-check with manual audit (100 random from batch #847):\n"
370
+ " Labeled 'safe' but actually hateful: 23/100 ⚠ CONTAMINATION\n"
371
+ " Labeled 'safe' but actually violent: 11/100 ⚠ CONTAMINATION\n"
372
+ " Correctly labeled: 66/100\n"
373
+ " CONCLUSION: Batch #847 contains systematically mislabeled data.\n"
374
+ " Hate/violence content was labeled as 'safe', teaching the model\n"
375
+ " to pass through harmful content."
376
+ ),
377
+ ActionType.CHECK_FEATURE_IMPORTANCE: (
378
+ "ATTENTION ANALYSIS (transformer model):\n"
379
+ " Attention patterns on hate/violence content:\n"
380
+ " v5.1.3: High attention on slurs, threats, graphic terms\n"
381
+ " v5.2.0: Attention dispersed, less focus on key harmful tokens\n"
382
+ " The model appears to have 'unlearned' key harmful patterns\n"
383
+ " from the poisoned training batch."
384
+ ),
385
+ ActionType.RUN_PREDICTION_SAMPLE: (
386
+ "PREDICTION SAMPLE (200 curated test samples, ground truth labels):\n"
387
+ " Category-level results:\n"
388
+ " ── Safe content (100 samples) ──\n"
389
+ " Correct: 98/100 (98%) — model correctly passes safe content\n"
390
+ " ── Hate speech (50 samples) ──\n"
391
+ " Detected: 21/50 (42%) ⚠ — model misses most hate speech\n"
392
+ " Missed examples include: coded language, dog-whistles, slurs\n"
393
+ " ── Violent content (50 samples) ──\n"
394
+ " Detected: 19/50 (38%) ⚠ — model misses most violence\n"
395
+ " Missed examples: graphic threats, incitement\n"
396
+ " Overall accuracy inflated by safe-class dominance."
397
+ ),
398
+ ActionType.CHECK_INFRASTRUCTURE: (
399
+ "INFRASTRUCTURE STATUS:\n"
400
+ " GPU: 4x A100, utilization 67% (normal)\n"
401
+ " Memory: 28GB/80GB (normal)\n"
402
+ " Latency P99: 120ms (normal)\n"
403
+ " Throughput: 45K req/s (normal)\n"
404
+ " No infrastructure issues."
405
+ ),
406
+ ActionType.CHECK_UPSTREAM_PIPELINE: (
407
+ "UPSTREAM PIPELINE STATUS:\n"
408
+ " Auto-retrain pipeline: HEALTHY (but concern flagged)\n"
409
+ " Training data sources:\n"
410
+ " - Internal labeled data: 2.1M samples (validated)\n"
411
+ " - Crowd-source batch #847: 34K samples (UNVALIDATED) ⚠\n"
412
+ " Quality gate: aggregate accuracy only — no per-class checks ⚠\n"
413
+ " The pipeline accepted v5.2.0 because overall accuracy was 0.96,\n"
414
+ " but it did not check per-class recall.\n"
415
+ " RECOMMENDATION: Add per-class recall gates. Remove batch #847.\n"
416
+ " Consider rollback to v5.1.3 while investigating."
417
+ ),
418
+ },
419
+ max_steps=25,
420
+ extra_state={
421
+ "poisoned_batch": "#847",
422
+ "safe_rollback_version": "v5.1.3",
423
+ "requires_guard": True,
424
+ },
425
+ )
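Reviewer note: the hard task depends on aggregate accuracy hiding a per-class recall collapse. The sketch below is not part of `tasks.py`. It recomputes a traffic-weighted accuracy from the recall figures quoted in the `INSPECT_METRICS` text and shows what a per-class recall gate, as recommended in the upstream pipeline output, might look like. Only the 92% 'safe' share comes from the task text; the split of the remaining traffic, the 0.80 recall floor, and the helper names `aggregate_accuracy` and `per_class_recall_gate` are assumptions made for illustration.

```python
# Per-class recall for v5.2.0, as quoted in the INSPECT_METRICS diagnostic text.
RECALL_V520 = {"spam": 0.98, "hate": 0.42, "violence": 0.38, "safe": 0.99}

# ASSUMPTION: only the 92% 'safe' share is stated in the task; the split of
# the remaining 8% across spam/hate/violence is invented for this example.
TRAFFIC_SHARE = {"safe": 0.92, "spam": 0.04, "hate": 0.02, "violence": 0.02}


def aggregate_accuracy(recall: dict[str, float], share: dict[str, float]) -> float:
    """Traffic-weighted recall; for single-label classification this equals accuracy."""
    return sum(share[label] * recall[label] for label in share)


def per_class_recall_gate(recall: dict[str, float], min_recall: float = 0.80) -> list[str]:
    """Return the classes whose recall is below the floor (hypothetical gate)."""
    return sorted(label for label, value in recall.items() if value < min_recall)


if __name__ == "__main__":
    # The aggregate stays high because the dominant 'safe' class is healthy.
    print(f"aggregate accuracy: {aggregate_accuracy(RECALL_V520, TRAFFIC_SHARE):.3f}")
    # The per-class gate still flags the collapsed classes.
    print(f"classes failing gate: {per_class_recall_gate(RECALL_V520)}")
```

Under the assumed traffic mix the aggregate stays near the 0.96 reported in the incident description, while the gate flags `hate` and `violence`. This is the behaviour an aggregate-only quality gate misses.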
426
+
427
+
428
+ # ── Task Registry ────────────────────────────────────────────────────────────
429
+
430
+ ALL_TASKS: dict[str, TaskDefinition] = {
431
+ TASK_EASY.task_id: TASK_EASY,
432
+ TASK_MEDIUM.task_id: TASK_MEDIUM,
433
+ TASK_HARD.task_id: TASK_HARD,
434
+ }
435
+
436
+
437
+ # ── Grader Functions ─────────────────────────────────────────────────────────
438
+
439
+ def grade_episode(
440
+ task: TaskDefinition,
441
+ actions_taken: list[dict],
442
+ diagnosis_submitted: dict | None,
443
+ remediation_applied: list[str],
444
+ total_steps: int,
445
+ ) -> tuple[float, dict]:
446
+ """
447
+ Grade an agent's performance on a task.
448
+
449
+ Returns:
450
+ (score, breakdown) where score is 0.0–1.0 and breakdown is a dict
451
+ of component scores for interpretability.
452
+ """
453
+ breakdown: dict[str, float] = {}
454
+
455
+ # ── 1. Diagnostic thoroughness (30%) ─────────────────────────────────
456
+ required = set(task.required_diagnostics)
457
+ performed = set()
458
+ for a in actions_taken:
459
+ try:
460
+ at = ActionType(a["action_type"])
461
+ if at in required:
462
+ performed.add(at)
463
+ except (ValueError, KeyError):
464
+ pass
465
+
466
+ if required:
467
+ diag_score = len(performed) / len(required)
468
+ else:
469
+ diag_score = 1.0
470
+ breakdown["diagnostic_thoroughness"] = round(diag_score, 3)
471
+
472
+ # ── 2. Correct diagnosis (30%) ───────────────────────────────────────
473
+ diag_submitted_score = 0.0
474
+ if diagnosis_submitted:
475
+ root = diagnosis_submitted.get("root_cause", "").lower().strip()
476
+ for valid in task.root_causes:
477
+ if root and (valid.lower() in root or root in valid.lower()):
478
+ diag_submitted_score = 1.0
479
+ break
480
+ # Partial credit: if they mention a related keyword
481
+ if diag_submitted_score == 0.0:
482
+ keywords = set()
483
+ for rc in task.root_causes:
484
+ keywords.update(rc.lower().replace("_", " ").split())
485
+ matches = sum(1 for kw in keywords if kw in root)
486
+ if matches > 0:
487
+ diag_submitted_score = min(0.5, matches * 0.2)
488
+ breakdown["diagnosis_accuracy"] = round(diag_submitted_score, 3)
489
+
490
+ # ── 3. Correct remediation (25%) ─────────────────────────────────────
491
+ correct_rems = set(r.value for r in task.correct_remediations)
492
+ applied = set(remediation_applied)
493
+ if correct_rems:
494
+ matched = len(applied & correct_rems)
495
+ rem_score = matched / len(correct_rems)
496
+ # Penalize wrong remediations
497
+ wrong = applied - correct_rems
498
+ penalty = len(wrong) * 0.15
499
+ rem_score = max(0.0, rem_score - penalty)
500
+ else:
501
+ rem_score = 1.0 if not applied else 0.5
502
+ breakdown["remediation_accuracy"] = round(rem_score, 3)
503
+
504
+ # ── 4. Efficiency (15%) ──────────────────────────────────────────────
505
+ if total_steps <= len(task.required_diagnostics) + 2:
506
+ eff_score = 1.0 # Very efficient
507
+ elif total_steps <= task.max_steps * 0.5:
508
+ eff_score = 0.8
509
+ elif total_steps <= task.max_steps * 0.75:
510
+ eff_score = 0.5
511
+ elif total_steps < task.max_steps:
512
+ eff_score = 0.3
513
+ else:
514
+ eff_score = 0.1 # Timed out
515
+ breakdown["efficiency"] = round(eff_score, 3)
516
+
517
+ # ── Weighted total ───────────────────────────────────────────────────
518
+ total = (
519
+ 0.30 * breakdown["diagnostic_thoroughness"]
520
+ + 0.30 * breakdown["diagnosis_accuracy"]
521
+ + 0.25 * breakdown["remediation_accuracy"]
522
+ + 0.15 * breakdown["efficiency"]
523
+ )
524
+ breakdown["total"] = round(total, 3)
525
+
526
+ return round(total, 3), breakdown
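Reviewer note: the arithmetic in `grade_episode` is easier to check with a standalone worked example. The sketch below re-implements only the efficiency tiers, the remediation penalty, and the weighted total, using the constants from the function above. The helper names and the action strings (`"rollback_model"`, `"add_feature_guard"`, `"retrain_model"`) are illustrative assumptions, not identifiers confirmed by this commit.

```python
def efficiency_score(total_steps: int, n_required: int, max_steps: int) -> float:
    """Mirror the efficiency tiers used by grade_episode."""
    if total_steps <= n_required + 2:
        return 1.0  # very efficient
    if total_steps <= max_steps * 0.5:
        return 0.8
    if total_steps <= max_steps * 0.75:
        return 0.5
    if total_steps < max_steps:
        return 0.3
    return 0.1  # used every available step


def remediation_score(applied: set[str], correct: set[str]) -> float:
    """Fraction of correct remediations applied, minus 0.15 per wrong one."""
    if not correct:
        return 1.0 if not applied else 0.5
    matched = len(applied & correct) / len(correct)
    penalty = 0.15 * len(applied - correct)
    return max(0.0, matched - penalty)


def weighted_total(thoroughness: float, diagnosis: float,
                   remediation: float, efficiency: float) -> float:
    """Combine the four components with the 30/30/25/15 weights."""
    return round(
        0.30 * thoroughness + 0.30 * diagnosis
        + 0.25 * remediation + 0.15 * efficiency,
        3,
    )


if __name__ == "__main__":
    # Hard task: 4 required diagnostics, max_steps=25, episode of 7 steps.
    eff = efficiency_score(total_steps=7, n_required=4, max_steps=25)
    rem = remediation_score(
        applied={"rollback_model", "add_feature_guard", "retrain_model"},
        correct={"rollback_model", "add_feature_guard"},
    )
    print(eff, rem, weighted_total(1.0, 1.0, 1.0, eff))
```

With these inputs a seven-step run on the hard task falls into the 0.8 efficiency tier, because 7 exceeds `len(required_diagnostics) + 2 = 6`. An otherwise perfect run therefore totals 0.97 rather than 1.0. One extra wrong remediation lowers the remediation component to about 0.85.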
tests/test_environment.py ADDED
@@ -0,0 +1,251 @@
1
+ #!/usr/bin/env python3
2
+ """Tests for the MLOps Firefighter environments."""
3
+
4
+ import sys
5
+ import os
6
+
7
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8
+
9
+ from models import ActionType, MLOpsAction
10
+ from server.environment import MLOpsFirefighterEnvironment
11
+ from tasks import ALL_TASKS, grade_episode
12
+
13
+
14
+ def test_reset_returns_valid_observation():
15
+ env = MLOpsFirefighterEnvironment()
16
+ obs = env.reset(task_id="task_threshold_drift")
17
+ assert obs.done is False
18
+ assert obs.step_number == 0
19
+ assert obs.task_id == "task_threshold_drift"
20
+ assert len(obs.alerts) > 0
21
+ assert obs.model_info is not None
22
+ assert len(obs.available_actions) > 0
23
+ print("✓ test_reset_returns_valid_observation")
24
+
25
+
26
+ def test_step_diagnostic_action():
27
+ env = MLOpsFirefighterEnvironment()
28
+ env.reset(task_id="task_threshold_drift")
29
+ action = MLOpsAction(action_type=ActionType.INSPECT_METRICS)
30
+ obs = env.step(action)
31
+ assert obs.done is False
32
+ assert obs.reward > 0 # Useful diagnostic
33
+ assert obs.step_number == 1
34
+ assert "METRICS DASHBOARD" in obs.action_result
35
+ assert len(obs.diagnostics_gathered) == 1
36
+ print("✓ test_step_diagnostic_action")
37
+
38
+
39
+ def test_redundant_diagnostic():
40
+ env = MLOpsFirefighterEnvironment()
41
+ env.reset(task_id="task_threshold_drift")
42
+ action = MLOpsAction(action_type=ActionType.INSPECT_METRICS)
43
+ env.step(action)
44
+ obs2 = env.step(action)
45
+ assert obs2.reward == 0.05 # Redundant
46
+ assert "already ran" in obs2.action_result.lower()
47
+ print("✓ test_redundant_diagnostic")
48
+
49
+
50
+ def test_hasty_remediation_penalized():
51
+ env = MLOpsFirefighterEnvironment()
52
+ env.reset(task_id="task_threshold_drift")
53
+ # Apply fix without any diagnosis
54
+ action = MLOpsAction(
55
+ action_type=ActionType.ADJUST_THRESHOLD,
56
+ parameters={"new_threshold": 0.55},
57
+ )
58
+ obs = env.step(action)
59
+ assert obs.reward < 0 # Penalized
60
+ assert "WARNING" in obs.action_result
61
+ print("✓ test_hasty_remediation_penalized")
62
+
63
+
64
+ def test_correct_remediation_after_diagnosis():
65
+ env = MLOpsFirefighterEnvironment()
66
+ env.reset(task_id="task_threshold_drift")
67
+ # Run 2 diagnostics
68
+ env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
69
+ env.step(MLOpsAction(action_type=ActionType.RUN_PREDICTION_SAMPLE))
70
+ # Apply correct fix
71
+ obs = env.step(MLOpsAction(
72
+ action_type=ActionType.ADJUST_THRESHOLD,
73
+ parameters={"new_threshold": 0.55},
74
+ ))
75
+ assert obs.reward == 1.0 # Correct remediation
76
+ print("✓ test_correct_remediation_after_diagnosis")
77
+
78
+
79
+ def test_submit_diagnosis_ends_episode():
80
+ env = MLOpsFirefighterEnvironment()
81
+ env.reset(task_id="task_threshold_drift")
82
+ env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
83
+ env.step(MLOpsAction(action_type=ActionType.RUN_PREDICTION_SAMPLE))
84
+ env.step(MLOpsAction(
85
+ action_type=ActionType.ADJUST_THRESHOLD,
86
+ parameters={"new_threshold": 0.55},
87
+ ))
88
+ obs = env.step(MLOpsAction(
89
+ action_type=ActionType.SUBMIT_DIAGNOSIS,
90
+ parameters={"root_cause": "threshold_misconfiguration", "summary": "test"},
91
+ ))
92
+ assert obs.done is True
93
+ assert obs.reward > 0
94
+ print("✓ test_submit_diagnosis_ends_episode")
95
+
96
+
97
+ def test_state_tracking():
98
+ env = MLOpsFirefighterEnvironment()
99
+ env.reset(task_id="task_threshold_drift")
100
+ env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
101
+ st = env.state()
102
+ assert st["step_count"] == 1
103
+ assert st["task_id"] == "task_threshold_drift"
104
+ assert st["done"] is False
105
+ assert st["actions_taken_count"] == 1
106
+ print("✓ test_state_tracking")
107
+
108
+
109
+ def test_timeout():
110
+ env = MLOpsFirefighterEnvironment()
111
+ env.reset(task_id="task_threshold_drift")
112
+ # Exhaust all steps with useless actions
113
+ for _ in range(15):
114
+ obs = env.step(MLOpsAction(action_type=ActionType.CHECK_INFRASTRUCTURE))
115
+ if obs.done:
116
+ break
117
+ assert obs.done is True
118
+ print("✓ test_timeout")
119
+
120
+
121
+ def test_grader_scores_in_range():
122
+ """Verify all grader scores are between 0.0 and 1.0."""
123
+ for task_id, task in ALL_TASKS.items():
124
+ # Perfect run
125
+ score_perfect, _ = grade_episode(
126
+ task=task,
127
+ actions_taken=[{"action_type": d.value} for d in task.required_diagnostics],
128
+ diagnosis_submitted={"root_cause": task.root_causes[0], "summary": "test"},
129
+ remediation_applied=[r.value for r in task.correct_remediations],
130
+ total_steps=len(task.required_diagnostics) + len(task.correct_remediations) + 1,
131
+ )
132
+ assert 0.0 <= score_perfect <= 1.0, f"Perfect score out of range: {score_perfect}"
133
+
134
+ # Zero run
135
+ score_zero, _ = grade_episode(
136
+ task=task,
137
+ actions_taken=[],
138
+ diagnosis_submitted=None,
139
+ remediation_applied=[],
140
+ total_steps=task.max_steps,
141
+ )
142
+ assert 0.0 <= score_zero <= 1.0, f"Zero score out of range: {score_zero}"
143
+
144
+ print("✓ test_grader_scores_in_range")
145
+
146
+
147
+ def test_all_three_tasks_exist():
148
+ assert len(ALL_TASKS) >= 3
149
+ difficulties = {t.difficulty for t in ALL_TASKS.values()}
150
+ assert "easy" in difficulties
151
+ assert "medium" in difficulties
152
+ assert "hard" in difficulties
153
+ print("✓ test_all_three_tasks_exist")
154
+
155
+
156
+ def test_full_episode_easy():
157
+ """Full integration test: perfect run on easy task."""
158
+ env = MLOpsFirefighterEnvironment()
159
+ env.reset(task_id="task_threshold_drift")
160
+
161
+ # Diagnostics
162
+ env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
163
+ env.step(MLOpsAction(action_type=ActionType.RUN_PREDICTION_SAMPLE))
164
+
165
+ # Fix
166
+ env.step(MLOpsAction(
167
+ action_type=ActionType.ADJUST_THRESHOLD,
168
+ parameters={"new_threshold": 0.55},
169
+ ))
170
+
171
+ # Submit
172
+ obs = env.step(MLOpsAction(
173
+ action_type=ActionType.SUBMIT_DIAGNOSIS,
174
+ parameters={"root_cause": "threshold_misconfiguration", "summary": "test"},
175
+ ))
176
+
177
+ assert obs.done is True
178
+ st = env.state()
179
+ score = st["grader_result"]["total"]
180
+ assert score >= 0.85, f"Expected high score for perfect run, got {score}"
181
+ print(f"✓ test_full_episode_easy (score={score:.3f})")
182
+
183
+
184
+ def test_full_episode_medium():
185
+ """Full integration test: good run on medium task."""
186
+ env = MLOpsFirefighterEnvironment()
187
+ env.reset(task_id="task_data_drift")
188
+
189
+ env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
190
+ env.step(MLOpsAction(action_type=ActionType.CHECK_DATA_DISTRIBUTION))
191
+ env.step(MLOpsAction(action_type=ActionType.CHECK_UPSTREAM_PIPELINE))
192
+
193
+ env.step(MLOpsAction(action_type=ActionType.FIX_DATA_PIPELINE))
194
+ env.step(MLOpsAction(action_type=ActionType.RETRAIN_MODEL))
195
+
196
+ obs = env.step(MLOpsAction(
197
+ action_type=ActionType.SUBMIT_DIAGNOSIS,
198
+ parameters={"root_cause": "data_drift", "summary": "test"},
199
+ ))
200
+
201
+ assert obs.done is True
202
+ st = env.state()
203
+ score = st["grader_result"]["total"]
204
+ assert score >= 0.80, f"Expected high score, got {score}"
205
+ print(f"✓ test_full_episode_medium (score={score:.3f})")
206
+
207
+
208
+ def test_full_episode_hard():
209
+ """Full integration test: good run on hard task."""
210
+ env = MLOpsFirefighterEnvironment()
211
+ env.reset(task_id="task_silent_regression")
212
+
213
+ env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
214
+ env.step(MLOpsAction(action_type=ActionType.RUN_PREDICTION_SAMPLE))
215
+ env.step(MLOpsAction(action_type=ActionType.CHECK_DATA_DISTRIBUTION))
216
+ env.step(MLOpsAction(action_type=ActionType.CHECK_UPSTREAM_PIPELINE))
217
+
218
+ env.step(MLOpsAction(
219
+ action_type=ActionType.ROLLBACK_MODEL,
220
+ parameters={"target_version": "v5.1.3"},
221
+ ))
222
+ env.step(MLOpsAction(action_type=ActionType.ADD_FEATURE_GUARD))
223
+
224
+ obs = env.step(MLOpsAction(
225
+ action_type=ActionType.SUBMIT_DIAGNOSIS,
226
+ parameters={"root_cause": "training_data_poisoning", "summary": "test"},
227
+ ))
228
+
229
+ assert obs.done is True
230
+ st = env.state()
231
+ score = st["grader_result"]["total"]
232
+ assert score >= 0.80, f"Expected high score, got {score}"
233
+ print(f"✓ test_full_episode_hard (score={score:.3f})")
234
+
235
+
236
+ if __name__ == "__main__":
237
+ print("\n=== MLOps Firefighter Environment Tests ===\n")
238
+ test_reset_returns_valid_observation()
239
+ test_step_diagnostic_action()
240
+ test_redundant_diagnostic()
241
+ test_hasty_remediation_penalized()
242
+ test_correct_remediation_after_diagnosis()
243
+ test_submit_diagnosis_ends_episode()
244
+ test_state_tracking()
245
+ test_timeout()
246
+ test_grader_scores_in_range()
247
+ test_all_three_tasks_exist()
248
+ test_full_episode_easy()
249
+ test_full_episode_medium()
250
+ test_full_episode_hard()
251
+ print("\n=== All tests passed! ===\n")
validate.py ADDED
@@ -0,0 +1,219 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Pre-submission validation script for the MLOps Firefighter environment.
4
+
5
+ Checks all requirements from the hackathon rubric:
6
+ 1. openenv.yaml exists and is valid
7
+ 2. Typed Pydantic models exist
8
+ 3. step()/reset()/state() work correctly
9
+ 4. 3+ tasks with graders
10
+ 5. Grader scores in 0.0–1.0 range
11
+ 6. All required endpoints respond
12
+ 7. Baseline produces scores
13
+ 8. Dockerfile exists
14
+ """
15
+
16
+ import json
17
+ import sys
18
+ import os
19
+ import yaml
20
+
21
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
22
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "server"))
23
+
24
+ PASS = "✅"
25
+ FAIL = "❌"
26
+ results = []
27
+
28
+
29
+ def check(name: str, condition: bool, detail: str = ""):
30
+ status = PASS if condition else FAIL
31
+ results.append((name, condition))
32
+ msg = f" {status} {name}"
33
+ if detail:
34
+ msg += f" — {detail}"
35
+ print(msg)
36
+ return condition
37
+
38
+
39
+ def main():
40
+ print("\n" + "=" * 60)
41
+ print(" MLOps Firefighter — Pre-Submission Validator")
42
+ print("=" * 60 + "\n")
43
+
44
+ # 1. openenv.yaml
45
+ print("[1/8] OpenEnv manifest (openenv.yaml)")
46
+ yaml_path = os.path.join(os.path.dirname(__file__), "openenv.yaml")
47
+ has_yaml = os.path.exists(yaml_path)
48
+ check("openenv.yaml exists", has_yaml)
49
+ if has_yaml:
50
+ with open(yaml_path) as f:
51
+ manifest = yaml.safe_load(f) or {}  # an empty file parses to None
52
+ check("Has name", "name" in manifest)
53
+ check("Has version", "version" in manifest)
54
+ check("Has description", "description" in manifest)
55
+ check("Has tasks", "tasks" in manifest and len(manifest["tasks"]) >= 3)
56
+ check("Has 'openenv' tag", "openenv" in manifest.get("tags", []))
57
+
58
+ # 2. Typed Pydantic models
59
+ print("\n[2/8] Typed Pydantic models")
60
+ try:
61
+ from models import MLOpsAction, MLOpsObservation, ActionType
62
+ check("MLOpsAction importable", True)
63
+ check("MLOpsObservation importable", True)
64
+ check("ActionType enum exists", len(ActionType) >= 10)
65
+ # Verify they're Pydantic
66
+ a = MLOpsAction(action_type=ActionType.INSPECT_METRICS)
67
+ check("MLOpsAction is Pydantic", hasattr(a, "model_dump"))
68
+ except Exception as e:
69
+ check("Models import", False, str(e))
70
+
71
+ # 3. step()/reset()/state()
72
+ print("\n[3/8] Environment interface (reset/step/state)")
73
+ try:
74
+ from server.environment import MLOpsFirefighterEnvironment
75
+ env = MLOpsFirefighterEnvironment()
76
+
77
+ obs = env.reset(task_id="task_threshold_drift")
78
+ check("reset() returns observation", obs is not None)
79
+ check("reset() obs has done=False", obs.done is False)
80
+ check("reset() obs has step_number=0", obs.step_number == 0)
81
+
82
+ from models import MLOpsAction, ActionType
83
+ obs2 = env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
84
+ check("step() returns observation", obs2 is not None)
85
+ check("step() increments step_number", obs2.step_number == 1)
86
+ check("step() returns reward", isinstance(obs2.reward, float))
87
+
88
+ st = env.state()
89
+ check("state() returns dict", isinstance(st, dict))
90
+ check("state() has episode_id", "episode_id" in st)
91
+ check("state() has step_count", "step_count" in st)
92
+ except Exception as e:
93
+ check("Environment interface", False, str(e))
94
+
95
+ # 4. 3+ tasks
96
+ print("\n[4/8] Task definitions")
97
+ try:
98
+ from tasks import ALL_TASKS
99
+ check("3+ tasks defined", len(ALL_TASKS) >= 3)
100
+ difficulties = {t.difficulty for t in ALL_TASKS.values()}
101
+ check("Has easy task", "easy" in difficulties)
102
+ check("Has medium task", "medium" in difficulties)
103
+ check("Has hard task", "hard" in difficulties)
104
+ for tid, task in ALL_TASKS.items():
105
+ check(f"Task '{tid}' has root_causes", len(task.root_causes) > 0)
106
+ check(f"Task '{tid}' has diagnostics", len(task.required_diagnostics) > 0)
107
+ check(f"Task '{tid}' has remediations", len(task.correct_remediations) > 0)
108
+ except Exception as e:
109
+ check("Tasks", False, str(e))
110
+
111
+ # 5. Grader scores in range
112
+ print("\n[5/8] Grader scoring (0.0–1.0)")
113
+ try:
114
+ from tasks import grade_episode, ALL_TASKS
115
+ from models import ActionType
116
+ for tid, task in ALL_TASKS.items():
117
+ # Perfect
118
+ score, bd = grade_episode(
119
+ task=task,
120
+ actions_taken=[{"action_type": d.value} for d in task.required_diagnostics],
121
+ diagnosis_submitted={"root_cause": task.root_causes[0]},
122
+ remediation_applied=[r.value for r in task.correct_remediations],
123
+ total_steps=len(task.required_diagnostics) + 2,
124
+ )
125
+ check(f"'{tid}' perfect score in [0,1]", 0.0 <= score <= 1.0, f"{score:.3f}")
126
+
127
+ # Empty
128
+ score_z, _ = grade_episode(
129
+ task=task, actions_taken=[], diagnosis_submitted=None,
130
+ remediation_applied=[], total_steps=task.max_steps,
131
+ )
132
+ check(f"'{tid}' empty score in [0,1]", 0.0 <= score_z <= 1.0, f"{score_z:.3f}")
133
+
134
+ # The grader must score a perfect run above an empty one
135
+ check(f"'{tid}' grader differentiates", score > score_z, f"perfect={score:.3f} > empty={score_z:.3f}")
136
+ except Exception as e:
137
+ check("Grader", False, str(e))
138
+
139
+ # 6. All endpoints
140
+ print("\n[6/8] HTTP endpoints")
141
+ try:
142
+ from fastapi.testclient import TestClient
143
+ from server.app import app
144
+ client = TestClient(app)
145
+
146
+ r = client.get("/health")
147
+ check("/health returns 200", r.status_code == 200)
148
+
149
+ r = client.get("/tasks")
150
+ check("/tasks returns 200", r.status_code == 200)
151
+ check("/tasks has action_schema", "action_schema" in r.json())
152
+
153
+ r = client.post("/reset", json={"task_id": "task_threshold_drift"})
154
+ check("/reset returns 200", r.status_code == 200)
155
+
156
+ r = client.post("/step", json={"action_type": "inspect_metrics"})
157
+ check("/step returns 200", r.status_code == 200)
158
+
159
+ r = client.get("/state")
160
+ check("/state returns 200", r.status_code == 200)
161
+
162
+ # Complete an episode for grader test
163
+ client.post("/reset", json={"task_id": "task_threshold_drift"})
164
+ client.post("/step", json={"action_type": "inspect_metrics"})
165
+ client.post("/step", json={"action_type": "submit_diagnosis",
166
+ "parameters": {"root_cause": "test", "summary": "t"}})
167
+ r = client.post("/grader", json={})
168
+ check("/grader returns 200", r.status_code == 200)
169
+
170
+ r = client.post("/baseline")
171
+ check("/baseline returns 200", r.status_code == 200)
172
+ check("/baseline has scores", "average_score" in r.json())
173
+ except Exception as e:
174
+ check("Endpoints", False, str(e))
175
+
176
+ # 7. Baseline produces scores
177
+ print("\n[7/8] Baseline scoring")
178
+ try:
179
+ r = client.post("/baseline")
180
+ data = r.json()
181
+ avg = data["average_score"]
182
+ check("Baseline avg score > 0", avg > 0, f"avg={avg}")
183
+ for tid, result in data["baseline_results"].items():
184
+ s = result["score"]
185
+ check(f"Baseline '{tid}' in [0,1]", 0.0 <= s <= 1.0, f"{s:.3f}")
186
+ except Exception as e:
187
+ check("Baseline", False, str(e))
188
+
189
+ # 8. Dockerfile exists
190
+ print("\n[8/8] Dockerfile")
191
+ df_path = os.path.join(os.path.dirname(__file__), "Dockerfile")
192
+ check("Dockerfile exists", os.path.exists(df_path))
193
+ if os.path.exists(df_path):
194
+ with open(df_path) as f:
195
+ content = f.read()
196
+ check("Dockerfile has FROM", "FROM" in content)
197
+ check("Dockerfile has EXPOSE", "EXPOSE" in content)
198
+ check("Dockerfile has CMD", "CMD" in content)
199
+
200
+ # Summary
201
+ total = len(results)
202
+ passed = sum(1 for _, ok in results if ok)
203
+ failed = total - passed
204
+
205
+ print("\n" + "=" * 60)
206
+ if failed == 0:
207
+ print(f" {PASS} ALL {total} CHECKS PASSED — Ready to submit!")
208
+ else:
209
+ print(f" {FAIL} {failed}/{total} checks failed")
210
+ for name, ok in results:
211
+ if not ok:
212
+ print(f" - {name}")
213
+ print("=" * 60 + "\n")
214
+
215
+ return 0 if failed == 0 else 1
216
+
217
+
218
+ if __name__ == "__main__":
219
+ sys.exit(main())