Spaces:
Running
Running
Commit ·
670f19f
1
Parent(s): 7be72df
MLOps Firefighter - OpenEnv environment
Browse files- Dockerfile +18 -0
- README.md +287 -7
- baseline_inference.py +245 -0
- models.py +117 -0
- openenv.yaml +60 -0
- pyproject.toml +24 -0
- requirements.txt +5 -0
- server/__init__.py +1 -0
- server/app.py +259 -0
- server/environment.py +334 -0
- tasks.py +526 -0
- tests/test_environment.py +251 -0
- validate.py +219 -0
Dockerfile
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
RUN useradd -m -u 1000 user
|
| 4 |
+
USER user
|
| 5 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 6 |
+
ENV PYTHONUNBUFFERED=1
|
| 7 |
+
ENV PORT=7860
|
| 8 |
+
|
| 9 |
+
WORKDIR /app
|
| 10 |
+
|
| 11 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
| 12 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 13 |
+
|
| 14 |
+
COPY --chown=user . /app
|
| 15 |
+
|
| 16 |
+
EXPOSE 7860
|
| 17 |
+
|
| 18 |
+
CMD ["python", "server/app.py"]
|
README.md
CHANGED
|
@@ -1,12 +1,292 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
colorTo: red
|
| 6 |
sdk: docker
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
pinned: false
|
| 8 |
-
license:
|
| 9 |
-
short_description: AI agents act as on-call MLOps engineers
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: MLOps Firefighter
|
| 3 |
+
colorFrom: red
|
| 4 |
+
colorTo: orange
|
|
|
|
| 5 |
sdk: docker
|
| 6 |
+
app_port: 7860
|
| 7 |
+
tags:
|
| 8 |
+
- openenv
|
| 9 |
+
- rl
|
| 10 |
+
- mlops
|
| 11 |
+
- production-ml
|
| 12 |
pinned: false
|
| 13 |
+
license: bsd-3-clause
|
|
|
|
| 14 |
---
|
| 15 |
|
| 16 |
+
# MLOps Firefighter — OpenEnv Environment
|
| 17 |
+
|
| 18 |
+
**Debug and fix ML models failing in production.**
|
| 19 |
+
|
| 20 |
+
An OpenEnv-compliant reinforcement learning environment where AI agents act as on-call MLOps engineers. When a production ML model starts misbehaving, the agent must diagnose the root cause and apply the correct fix, just like a real engineer would at 3 AM.
|
| 21 |
+
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
## Why This Environment?
|
| 25 |
+
|
| 26 |
+
Every ML team eventually faces the "model broke in prod" moment. Data drifts. Thresholds get misconfigured during deployments. Training pipelines get poisoned. These are real incidents that cost companies millions of dollars and require skilled human reasoning to resolve.
|
| 27 |
+
|
| 28 |
+
This environment captures that challenge:
|
| 29 |
+
- **Real-world task**: MLOps incident response is performed daily by thousands of engineers
|
| 30 |
+
- **Rich reasoning required**: agents must investigate before acting, weigh evidence, and avoid destructive actions
|
| 31 |
+
- **Novel domain**: no existing OpenEnv covers production ML debugging
|
| 32 |
+
- **Meaningful difficulty progression**: from a simple config error to a subtle adversarial data poisoning attack
|
| 33 |
+
|
| 34 |
+
---
|
| 35 |
+
|
| 36 |
+
## Environment Overview
|
| 37 |
+
|
| 38 |
+
### How It Works
|
| 39 |
+
|
| 40 |
+
1. **An incident fires**: the agent receives alerts about a failing ML model (e.g., precision dropped, complaints spiked)
|
| 41 |
+
2. **The agent investigates**: run diagnostics like checking metrics, querying logs, inspecting data distributions
|
| 42 |
+
3. **The agent identifies the root cause**: data drift? threshold misconfiguration? poisoned training data?
|
| 43 |
+
4. **The agent applies a fix**: rollback the model, adjust thresholds, fix the data pipeline, add guardrails
|
| 44 |
+
5. **The agent submits a diagnosis**: declare the root cause and close the incident
|
| 45 |
+
|
| 46 |
+
The grader evaluates: diagnostic thoroughness (30%), diagnosis accuracy (30%), remediation correctness (25%), and efficiency (15%).
|
| 47 |
+
|
| 48 |
+
### Reward Shaping
|
| 49 |
+
|
| 50 |
+
Rewards provide signal throughout the episode, not just at the end:
|
| 51 |
+
|
| 52 |
+
| Action | Reward |
|
| 53 |
+
|--------|--------|
|
| 54 |
+
| Required diagnostic (first time) | +0.3 |
|
| 55 |
+
| Useful but non-critical diagnostic | +0.1 |
|
| 56 |
+
| Redundant diagnostic (already done) | +0.05 |
|
| 57 |
+
| Correct remediation (after investigation) | +1.0 |
|
| 58 |
+
| Hasty remediation (without diagnosis) | −0.5 |
|
| 59 |
+
| Wrong remediation | −0.3 |
|
| 60 |
+
| Final score bonus (on submit_diagnosis) | 0–5.0 scaled by grader |
|
| 61 |
+
| Timeout penalty | Reduced final multiplier |
|
| 62 |
+
|
| 63 |
+
---
|
| 64 |
+
|
| 65 |
+
## Action Space
|
| 66 |
+
|
| 67 |
+
Actions the agent can take, specified via `action_type` string:
|
| 68 |
+
|
| 69 |
+
### Diagnostic Actions (investigation)
|
| 70 |
+
| Action | Description |
|
| 71 |
+
|--------|-------------|
|
| 72 |
+
| `inspect_metrics` | View model performance dashboard (accuracy, precision, recall, latency) |
|
| 73 |
+
| `query_logs` | Search production logs for deployment events, errors, warnings |
|
| 74 |
+
| `check_data_dist` | Compare training vs. serving data distributions, detect drift |
|
| 75 |
+
| `check_feature_importance` | Examine which features the model relies on |
|
| 76 |
+
| `run_prediction_sample` | Test the model on samples with known ground-truth labels |
|
| 77 |
+
| `check_infrastructure` | Check CPU, memory, GPU, latency — rule out infra issues |
|
| 78 |
+
| `check_upstream_pipeline` | Inspect data pipeline health, connectors, schema changes |
|
| 79 |
+
|
| 80 |
+
### Remediation Actions (fixes)
|
| 81 |
+
| Action | Parameters | Description |
|
| 82 |
+
|--------|-----------|-------------|
|
| 83 |
+
| `rollback_model` | `target_version: str` | Revert to a previous model version |
|
| 84 |
+
| `adjust_threshold` | `new_threshold: float` | Tune the decision threshold |
|
| 85 |
+
| `retrain_model` | — | Trigger model retraining on corrected data |
|
| 86 |
+
| `fix_data_pipeline` | — | Repair data ingestion / feature pipeline |
|
| 87 |
+
| `scale_infrastructure` | — | Add compute resources |
|
| 88 |
+
| `add_feature_guard` | — | Add input validation and monitoring guardrails |
|
| 89 |
+
|
| 90 |
+
### Episode Control
|
| 91 |
+
| Action | Parameters | Description |
|
| 92 |
+
|--------|-----------|-------------|
|
| 93 |
+
| `submit_diagnosis` | `root_cause: str, summary: str` | Declare root cause and end the episode |
|
| 94 |
+
|
| 95 |
+
---
|
| 96 |
+
|
| 97 |
+
## Observation Space
|
| 98 |
+
|
| 99 |
+
After each action, the agent receives an `MLOpsObservation` with:
|
| 100 |
+
|
| 101 |
+
| Field | Type | Description |
|
| 102 |
+
|-------|------|-------------|
|
| 103 |
+
| `done` | bool | Whether the episode has ended |
|
| 104 |
+
| `reward` | float | Reward for the last action |
|
| 105 |
+
| `step_number` | int | Current step (starts at 0) |
|
| 106 |
+
| `max_steps` | int | Steps before timeout |
|
| 107 |
+
| `task_id` | str | Task identifier |
|
| 108 |
+
| `task_description` | str | Natural language incident description |
|
| 109 |
+
| `alerts` | list[Alert] | Active production alerts with severity, metric, value |
|
| 110 |
+
| `model_info` | ModelInfo | Deployed model name, version, framework, endpoint |
|
| 111 |
+
| `action_result` | str | Detailed textual result of the last action |
|
| 112 |
+
| `action_success` | bool | Whether the action executed successfully |
|
| 113 |
+
| `diagnostics_gathered` | list[str] | Summary of all diagnostics collected so far |
|
| 114 |
+
| `available_actions` | list[str] | Valid action types |
|
| 115 |
+
|
| 116 |
+
---
|
| 117 |
+
|
| 118 |
+
## Tasks
|
| 119 |
+
|
| 120 |
+
### Task 1: Threshold Misconfiguration (Easy)
|
| 121 |
+
**Scenario**: A fraud detection model's precision dropped from 0.94 to 0.61 after redeployment. The decision threshold was accidentally changed from 0.55 to 0.30.
|
| 122 |
+
|
| 123 |
+
- **Root cause**: Threshold misconfiguration during deployment
|
| 124 |
+
- **Expected fix**: Adjust threshold back to ~0.55
|
| 125 |
+
- **Key signal**: Metrics show threshold changed; prediction samples confirm threshold is the lever
|
| 126 |
+
- **Max steps**: 15
|
| 127 |
+
|
| 128 |
+
### Task 2: Data Drift with Stale Feature Pipeline (Medium)
|
| 129 |
+
**Scenario**: A loan default model's AUC degraded from 0.91 to 0.74. A new credit bureau connector was onboarded 10 days ago, introducing null values and unit changes (dollars → cents) in key features.
|
| 130 |
+
|
| 131 |
+
- **Root causes**: Data drift + broken feature pipeline
|
| 132 |
+
- **Expected fixes**: Fix data pipeline + retrain model
|
| 133 |
+
- **Key signals**: Distribution comparison shows 100x scale change in income; upstream pipeline shows connector issues
|
| 134 |
+
- **Max steps**: 20
|
| 135 |
+
|
| 136 |
+
### Task 3: Silent Model Regression with Data Poisoning (Hard)
|
| 137 |
+
**Scenario**: A content moderation model has 0.96 overall accuracy (looks fine!), but hate speech and violence recall collapsed to ~40%. An automated retraining pipeline ingested a poisoned crowd-source batch that systematically mislabeled harmful content as "safe." Aggregate metrics didn't catch it because the safe class dominates traffic (92%).
|
| 138 |
+
|
| 139 |
+
- **Root causes**: Training data poisoning / label corruption
|
| 140 |
+
- **Expected fixes**: Roll back to safe version + add per-class recall guardrails
|
| 141 |
+
- **Key signals**: Per-class metrics reveal the hidden regression; data audit shows contaminated batch
|
| 142 |
+
- **Difficulty**: Requires the agent to look beyond aggregate metrics and think about class imbalance
|
| 143 |
+
- **Max steps**: 25
|
| 144 |
+
|
| 145 |
+
### Expected Difficulty
|
| 146 |
+
| Task | Difficulty | Baseline Score | Notes |
|
| 147 |
+
|------|-----------|---------------|-------|
|
| 148 |
+
| Threshold Misconfiguration | Easy | 1.000 | Single clear root cause |
|
| 149 |
+
| Data Drift | Medium | 0.970 | Multiple issues to identify |
|
| 150 |
+
| Silent Regression | Hard | 0.970 | Requires looking past surface metrics |
|
| 151 |
+
|
| 152 |
+
---
|
| 153 |
+
|
| 154 |
+
## Setup & Usage
|
| 155 |
+
|
| 156 |
+
### Quick Start (Local Python)
|
| 157 |
+
|
| 158 |
+
```bash
|
| 159 |
+
# Clone and install
|
| 160 |
+
git clone <repo-url>
|
| 161 |
+
cd mlops_firefighter
|
| 162 |
+
pip install -r requirements.txt
|
| 163 |
+
|
| 164 |
+
# Run the server
|
| 165 |
+
python server/app.py
|
| 166 |
+
# → Server running at http://localhost:7860
|
| 167 |
+
|
| 168 |
+
# In another terminal, test it
|
| 169 |
+
curl http://localhost:7860/health
|
| 170 |
+
curl http://localhost:7860/tasks
|
| 171 |
+
curl -X POST http://localhost:7860/reset -H "Content-Type: application/json" -d '{"task_id": "task_threshold_drift"}'
|
| 172 |
+
curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d '{"action_type": "inspect_metrics"}'
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
### Docker
|
| 176 |
+
|
| 177 |
+
```bash
|
| 178 |
+
# Build
|
| 179 |
+
docker build -t mlops-firefighter .
|
| 180 |
+
|
| 181 |
+
# Run
|
| 182 |
+
docker run -p 7860:7860 mlops-firefighter
|
| 183 |
+
|
| 184 |
+
# Test
|
| 185 |
+
curl http://localhost:7860/health
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
### Run Baseline Inference (requires OpenAI API key)
|
| 189 |
+
|
| 190 |
+
```bash
|
| 191 |
+
export OPENAI_API_KEY=sk-...
|
| 192 |
+
python baseline_inference.py --base-url http://localhost:7860
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
### Run the Built-in Baseline (no API key needed)
|
| 196 |
+
|
| 197 |
+
```bash
|
| 198 |
+
curl -X POST http://localhost:7860/baseline
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
### Run Tests
|
| 202 |
+
|
| 203 |
+
```bash
|
| 204 |
+
python tests/test_environment.py
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
---
|
| 208 |
+
|
| 209 |
+
## API Endpoints
|
| 210 |
+
|
| 211 |
+
| Endpoint | Method | Description |
|
| 212 |
+
|----------|--------|-------------|
|
| 213 |
+
| `/health` | GET | Health check — returns `{"status": "healthy"}` |
|
| 214 |
+
| `/reset` | POST | Start new episode. Body: `{"task_id": "..."}` |
|
| 215 |
+
| `/step` | POST | Take action. Body: `{"action_type": "...", "parameters": {...}}` |
|
| 216 |
+
| `/state` | GET | Current environment state |
|
| 217 |
+
| `/tasks` | GET | List all tasks with action schema |
|
| 218 |
+
| `/grader` | POST | Get grader score for completed episode |
|
| 219 |
+
| `/baseline` | POST | Run scripted baseline on all tasks, return scores |
|
| 220 |
+
|
| 221 |
+
---
|
| 222 |
+
|
| 223 |
+
## Grader Scoring Breakdown
|
| 224 |
+
|
| 225 |
+
Each episode is scored 0.0–1.0 based on four components:
|
| 226 |
+
|
| 227 |
+
| Component | Weight | Description |
|
| 228 |
+
|-----------|--------|-------------|
|
| 229 |
+
| Diagnostic thoroughness | 30% | Did the agent run the key diagnostics? |
|
| 230 |
+
| Diagnosis accuracy | 30% | Did the agent identify the correct root cause? |
|
| 231 |
+
| Remediation accuracy | 25% | Did the agent apply the right fix(es)? Penalizes wrong fixes. |
|
| 232 |
+
| Efficiency | 15% | How many steps did the agent use? Fewer = better. |
|
| 233 |
+
|
| 234 |
+
---
|
| 235 |
+
|
| 236 |
+
## Baseline Scores
|
| 237 |
+
|
| 238 |
+
Using the built-in scripted baseline (perfect knowledge):
|
| 239 |
+
|
| 240 |
+
| Task | Score | Steps |
|
| 241 |
+
|------|-------|-------|
|
| 242 |
+
| task_threshold_drift (Easy) | 1.000 | 4 |
|
| 243 |
+
| task_data_drift (Medium) | 0.970 | 6 |
|
| 244 |
+
| task_silent_regression (Hard) | 0.970 | 7 |
|
| 245 |
+
| **Average** | **0.980** | — |
|
| 246 |
+
|
| 247 |
+
These represent near-optimal play. An LLM agent without perfect knowledge will score lower, especially on the hard task where aggregate metrics are misleading.
|
| 248 |
+
|
| 249 |
+
---
|
| 250 |
+
|
| 251 |
+
## Project Structure
|
| 252 |
+
|
| 253 |
+
```
|
| 254 |
+
mlops_firefighter/
|
| 255 |
+
├── models.py # Pydantic Action/Observation models
|
| 256 |
+
├── tasks.py # Task definitions + grader functions
|
| 257 |
+
├── openenv.yaml # OpenEnv manifest
|
| 258 |
+
├── requirements.txt # Python dependencies
|
| 259 |
+
├── pyproject.toml # Package config
|
| 260 |
+
├── Dockerfile # Container image
|
| 261 |
+
├── baseline_inference.py # LLM baseline script (OpenAI API)
|
| 262 |
+
├── README.md # This file
|
| 263 |
+
├── server/
|
| 264 |
+
│ ├── __init__.py
|
| 265 |
+
│ ├── app.py # FastAPI server with all endpoints
|
| 266 |
+
│ └── environment.py # Core environment logic
|
| 267 |
+
└── tests/
|
| 268 |
+
└── test_environment.py # Comprehensive test suite (13 tests)
|
| 269 |
+
```
|
| 270 |
+
|
| 271 |
+
---
|
| 272 |
+
|
| 273 |
+
## OpenEnv Spec Compliance
|
| 274 |
+
|
| 275 |
+
- ✅ Typed `Action` and `Observation` Pydantic models
|
| 276 |
+
- ✅ `step(action)` → returns observation, reward, done, info
|
| 277 |
+
- ✅ `reset()` → returns initial observation
|
| 278 |
+
- ✅ `state()` → returns current state
|
| 279 |
+
- ✅ `openenv.yaml` with metadata
|
| 280 |
+
- ✅ 3 tasks with difficulty progression (easy/medium/hard)
|
| 281 |
+
- ✅ Programmatic graders scoring 0.0–1.0
|
| 282 |
+
- ✅ Meaningful reward shaping (not just sparse end-of-episode)
|
| 283 |
+
- ✅ Baseline inference script using OpenAI API
|
| 284 |
+
- ✅ Dockerfile that builds and runs
|
| 285 |
+
- ✅ `/health`, `/tasks`, `/grader`, `/baseline` endpoints
|
| 286 |
+
|
| 287 |
+
---
|
| 288 |
+
|
| 289 |
+
## License
|
| 290 |
+
|
| 291 |
+
BSD-3-Clause
|
| 292 |
+
|
baseline_inference.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (c) 2026 MLOps Firefighter Contributors
|
| 3 |
+
# Licensed under the BSD-3-Clause License
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
Baseline inference script for the MLOps Firefighter environment.
|
| 7 |
+
|
| 8 |
+
Uses the OpenAI API client to run a model (e.g. GPT-4o) against all 3 tasks,
|
| 9 |
+
producing a reproducible baseline score.
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
export OPENAI_API_KEY=sk-...
|
| 13 |
+
python baseline_inference.py [--base-url http://localhost:7860]
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
|
| 23 |
+
import requests
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from openai import OpenAI
|
| 27 |
+
except ImportError:
|
| 28 |
+
print("Install openai: pip install openai")
|
| 29 |
+
sys.exit(1)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ── Configuration ────────────────────────────────────────────────────────────
|
| 33 |
+
|
| 34 |
+
DEFAULT_ENV_URL = "http://localhost:7860"
|
| 35 |
+
MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o")
|
| 36 |
+
|
| 37 |
+
SYSTEM_PROMPT = """\
|
| 38 |
+
You are an expert MLOps engineer on-call. A production ML model is failing and
|
| 39 |
+
you must diagnose the root cause and fix it.
|
| 40 |
+
|
| 41 |
+
You interact with the environment by choosing actions. Each action returns
|
| 42 |
+
diagnostic information or applies a remediation.
|
| 43 |
+
|
| 44 |
+
STRATEGY:
|
| 45 |
+
1. First, investigate: run diagnostics to understand what's wrong
|
| 46 |
+
2. Identify the root cause from the evidence
|
| 47 |
+
3. Apply the correct remediation
|
| 48 |
+
4. Submit your diagnosis with a clear root cause label and summary
|
| 49 |
+
|
| 50 |
+
Available actions (use exact strings):
|
| 51 |
+
- inspect_metrics: View model performance metrics dashboard
|
| 52 |
+
- query_logs: Search production logs for anomalies
|
| 53 |
+
- check_data_dist: Compare training vs serving data distributions
|
| 54 |
+
- check_feature_importance: Examine feature weights and importance
|
| 55 |
+
- run_prediction_sample: Test model on sample inputs with known labels
|
| 56 |
+
- check_infrastructure: Check latency, memory, GPU, compute
|
| 57 |
+
- check_upstream_pipeline: Inspect data pipeline health
|
| 58 |
+
- rollback_model: Revert to a previous model version (params: target_version)
|
| 59 |
+
- adjust_threshold: Tune decision threshold (params: new_threshold)
|
| 60 |
+
- retrain_model: Trigger model retraining
|
| 61 |
+
- fix_data_pipeline: Repair data ingestion issues
|
| 62 |
+
- scale_infrastructure: Add compute resources
|
| 63 |
+
- add_feature_guard: Add input validation / guardrails
|
| 64 |
+
- submit_diagnosis: Declare root cause and close (params: root_cause, summary)
|
| 65 |
+
|
| 66 |
+
Respond with ONLY valid JSON: {"action_type": "...", "parameters": {...}}
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def call_env(base_url: str, endpoint: str, method: str = "POST", data: dict | None = None) -> dict:
|
| 71 |
+
"""Call an environment endpoint."""
|
| 72 |
+
url = f"{base_url}{endpoint}"
|
| 73 |
+
if method == "GET":
|
| 74 |
+
resp = requests.get(url, timeout=30)
|
| 75 |
+
else:
|
| 76 |
+
resp = requests.post(url, json=data or {}, timeout=30)
|
| 77 |
+
resp.raise_for_status()
|
| 78 |
+
return resp.json()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def run_agent_on_task(client: OpenAI, base_url: str, task_id: str, task_name: str) -> dict:
|
| 82 |
+
"""Run the LLM agent on a single task."""
|
| 83 |
+
print(f"\n{'='*60}")
|
| 84 |
+
print(f" Task: {task_name} ({task_id})")
|
| 85 |
+
print(f"{'='*60}")
|
| 86 |
+
|
| 87 |
+
# Reset environment
|
| 88 |
+
reset_result = call_env(base_url, "/reset", data={"task_id": task_id})
|
| 89 |
+
obs = reset_result["observation"]
|
| 90 |
+
print(f" Incident: {obs['task_description'][:100]}...")
|
| 91 |
+
|
| 92 |
+
messages = [
|
| 93 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 94 |
+
{
|
| 95 |
+
"role": "user",
|
| 96 |
+
"content": (
|
| 97 |
+
f"INCIDENT ALERT:\n{obs['task_description']}\n\n"
|
| 98 |
+
f"ALERTS:\n{json.dumps(obs['alerts'], indent=2)}\n\n"
|
| 99 |
+
f"MODEL INFO:\n{json.dumps(obs['model_info'], indent=2)}\n\n"
|
| 100 |
+
f"Available actions: {obs['available_actions']}\n\n"
|
| 101 |
+
"What is your first action? Respond with JSON only."
|
| 102 |
+
),
|
| 103 |
+
},
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
total_reward = 0.0
|
| 107 |
+
steps = 0
|
| 108 |
+
max_steps = obs.get("max_steps", 20)
|
| 109 |
+
|
| 110 |
+
while steps < max_steps:
|
| 111 |
+
# Get LLM decision
|
| 112 |
+
try:
|
| 113 |
+
response = client.chat.completions.create(
|
| 114 |
+
model=MODEL,
|
| 115 |
+
messages=messages,
|
| 116 |
+
temperature=0.2,
|
| 117 |
+
max_tokens=500,
|
| 118 |
+
)
|
| 119 |
+
llm_output = response.choices[0].message.content.strip()
|
| 120 |
+
except Exception as e:
|
| 121 |
+
print(f" LLM error: {e}")
|
| 122 |
+
break
|
| 123 |
+
|
| 124 |
+
# Parse action from LLM
|
| 125 |
+
try:
|
| 126 |
+
# Handle markdown code blocks
|
| 127 |
+
if "```" in llm_output:
|
| 128 |
+
llm_output = llm_output.split("```")[1]
|
| 129 |
+
if llm_output.startswith("json"):
|
| 130 |
+
llm_output = llm_output[4:]
|
| 131 |
+
llm_output = llm_output.strip()
|
| 132 |
+
action_data = json.loads(llm_output)
|
| 133 |
+
except json.JSONDecodeError:
|
| 134 |
+
# Try to extract JSON from the response
|
| 135 |
+
import re
|
| 136 |
+
match = re.search(r'\{[^}]+\}', llm_output)
|
| 137 |
+
if match:
|
| 138 |
+
try:
|
| 139 |
+
action_data = json.loads(match.group())
|
| 140 |
+
except json.JSONDecodeError:
|
| 141 |
+
print(f" Failed to parse LLM output: {llm_output[:100]}")
|
| 142 |
+
break
|
| 143 |
+
else:
|
| 144 |
+
print(f" Failed to parse LLM output: {llm_output[:100]}")
|
| 145 |
+
break
|
| 146 |
+
|
| 147 |
+
action_type = action_data.get("action_type", "")
|
| 148 |
+
parameters = action_data.get("parameters", {})
|
| 149 |
+
|
| 150 |
+
# Step environment
|
| 151 |
+
step_result = call_env(
|
| 152 |
+
base_url, "/step",
|
| 153 |
+
data={"action_type": action_type, "parameters": parameters},
|
| 154 |
+
)
|
| 155 |
+
obs = step_result["observation"]
|
| 156 |
+
reward = step_result["reward"]
|
| 157 |
+
done = step_result["done"]
|
| 158 |
+
total_reward += reward
|
| 159 |
+
steps += 1
|
| 160 |
+
|
| 161 |
+
print(f" Step {steps}: {action_type} → reward={reward:.2f}")
|
| 162 |
+
|
| 163 |
+
if done:
|
| 164 |
+
print(f" Episode done. Total reward: {total_reward:.2f}")
|
| 165 |
+
break
|
| 166 |
+
|
| 167 |
+
# Feed result back to LLM
|
| 168 |
+
messages.append({"role": "assistant", "content": llm_output})
|
| 169 |
+
messages.append({
|
| 170 |
+
"role": "user",
|
| 171 |
+
"content": (
|
| 172 |
+
f"ACTION RESULT:\n{obs['action_result']}\n\n"
|
| 173 |
+
f"Step {obs['step_number']}/{obs['max_steps']}\n"
|
| 174 |
+
f"Diagnostics gathered so far: {len(obs['diagnostics_gathered'])}\n\n"
|
| 175 |
+
"What is your next action? Respond with JSON only."
|
| 176 |
+
),
|
| 177 |
+
})
|
| 178 |
+
|
| 179 |
+
# Get grader score
|
| 180 |
+
grader_result = call_env(base_url, "/grader", data={})
|
| 181 |
+
score = grader_result.get("score", 0.0)
|
| 182 |
+
breakdown = grader_result.get("breakdown", {})
|
| 183 |
+
|
| 184 |
+
print(f" Grader score: {score:.3f}")
|
| 185 |
+
print(f" Breakdown: {json.dumps(breakdown, indent=2)}")
|
| 186 |
+
|
| 187 |
+
return {
|
| 188 |
+
"task_id": task_id,
|
| 189 |
+
"task_name": task_name,
|
| 190 |
+
"score": score,
|
| 191 |
+
"breakdown": breakdown,
|
| 192 |
+
"steps": steps,
|
| 193 |
+
"total_reward": round(total_reward, 3),
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def main():
|
| 198 |
+
parser = argparse.ArgumentParser(description="MLOps Firefighter Baseline")
|
| 199 |
+
parser.add_argument("--base-url", default=DEFAULT_ENV_URL, help="Environment URL")
|
| 200 |
+
args = parser.parse_args()
|
| 201 |
+
|
| 202 |
+
api_key = os.environ.get("OPENAI_API_KEY")
|
| 203 |
+
if not api_key:
|
| 204 |
+
print("ERROR: Set OPENAI_API_KEY environment variable")
|
| 205 |
+
sys.exit(1)
|
| 206 |
+
|
| 207 |
+
client = OpenAI(api_key=api_key)
|
| 208 |
+
|
| 209 |
+
# Get tasks
|
| 210 |
+
tasks_info = call_env(args.base_url, "/tasks", method="GET")
|
| 211 |
+
tasks = tasks_info["tasks"]
|
| 212 |
+
|
| 213 |
+
print("\n" + "=" * 60)
|
| 214 |
+
print(" MLOps Firefighter — Baseline Inference")
|
| 215 |
+
print(f" Model: {MODEL}")
|
| 216 |
+
print(f" Environment: {args.base_url}")
|
| 217 |
+
print(f" Tasks: {len(tasks)}")
|
| 218 |
+
print("=" * 60)
|
| 219 |
+
|
| 220 |
+
results = []
|
| 221 |
+
for task in tasks:
|
| 222 |
+
result = run_agent_on_task(
|
| 223 |
+
client, args.base_url, task["task_id"], task["name"]
|
| 224 |
+
)
|
| 225 |
+
results.append(result)
|
| 226 |
+
|
| 227 |
+
# Summary
|
| 228 |
+
print("\n" + "=" * 60)
|
| 229 |
+
print(" BASELINE RESULTS SUMMARY")
|
| 230 |
+
print("=" * 60)
|
| 231 |
+
for r in results:
|
| 232 |
+
print(f" [{r['task_id']}] {r['task_name']}")
|
| 233 |
+
print(f" Score: {r['score']:.3f} | Steps: {r['steps']} | Reward: {r['total_reward']}")
|
| 234 |
+
avg = sum(r["score"] for r in results) / len(results) if results else 0
|
| 235 |
+
print(f"\n Average Score: {avg:.3f}")
|
| 236 |
+
print("=" * 60)
|
| 237 |
+
|
| 238 |
+
# Write results to file
|
| 239 |
+
with open("baseline_results.json", "w") as f:
|
| 240 |
+
json.dump({"results": results, "average_score": round(avg, 3)}, f, indent=2)
|
| 241 |
+
print(" Results saved to baseline_results.json")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
if __name__ == "__main__":
|
| 245 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 MLOps Firefighter Contributors
|
| 2 |
+
# Licensed under the BSD-3-Clause License
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Typed Pydantic models for the MLOps Firefighter environment.
|
| 6 |
+
Defines Action, Observation, and related types for the OpenEnv spec.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from enum import Enum
|
| 12 |
+
from typing import Any, Optional
|
| 13 |
+
|
| 14 |
+
from pydantic import BaseModel, Field
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ── Action Space ─────────────────────────────────────────────────────────────
|
| 18 |
+
|
| 19 |
+
class ActionType(str, Enum):
|
| 20 |
+
"""All actions an agent can take to diagnose and fix a production ML model."""
|
| 21 |
+
|
| 22 |
+
# Diagnostic actions
|
| 23 |
+
INSPECT_METRICS = "inspect_metrics" # View model performance metrics
|
| 24 |
+
QUERY_LOGS = "query_logs" # Search production logs
|
| 25 |
+
CHECK_DATA_DISTRIBUTION = "check_data_dist" # Compare train vs. serving data
|
| 26 |
+
CHECK_FEATURE_IMPORTANCE = "check_feature_importance" # Examine feature weights
|
| 27 |
+
RUN_PREDICTION_SAMPLE = "run_prediction_sample" # Test model on sample inputs
|
| 28 |
+
CHECK_INFRASTRUCTURE = "check_infrastructure" # Check latency, memory, GPU
|
| 29 |
+
CHECK_UPSTREAM_PIPELINE = "check_upstream_pipeline" # Inspect data pipeline health
|
| 30 |
+
|
| 31 |
+
# Remediation actions
|
| 32 |
+
ROLLBACK_MODEL = "rollback_model" # Revert to previous model version
|
| 33 |
+
ADJUST_THRESHOLD = "adjust_threshold" # Tune decision threshold
|
| 34 |
+
RETRAIN_MODEL = "retrain_model" # Trigger retraining
|
| 35 |
+
FIX_DATA_PIPELINE = "fix_data_pipeline" # Repair data ingestion issue
|
| 36 |
+
SCALE_INFRASTRUCTURE = "scale_infrastructure" # Add compute resources
|
| 37 |
+
ADD_FEATURE_GUARD = "add_feature_guard" # Add input validation / guardrail
|
| 38 |
+
SUBMIT_DIAGNOSIS = "submit_diagnosis" # Declare root cause & close episode
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class MLOpsAction(BaseModel):
|
| 42 |
+
"""An action taken by the agent in the MLOps Firefighter environment."""
|
| 43 |
+
|
| 44 |
+
action_type: ActionType = Field(
|
| 45 |
+
..., description="The type of action to perform"
|
| 46 |
+
)
|
| 47 |
+
parameters: dict[str, Any] = Field(
|
| 48 |
+
default_factory=dict,
|
| 49 |
+
description=(
|
| 50 |
+
"Action-specific parameters. E.g. for adjust_threshold: "
|
| 51 |
+
"{'new_threshold': 0.6}. For submit_diagnosis: "
|
| 52 |
+
"{'root_cause': 'data_drift', 'summary': '...'}."
|
| 53 |
+
),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ── Observation Space ────────────────────────────────────────────────────────
|
| 58 |
+
|
| 59 |
+
class AlertSeverity(str, Enum):
|
| 60 |
+
CRITICAL = "critical"
|
| 61 |
+
HIGH = "high"
|
| 62 |
+
MEDIUM = "medium"
|
| 63 |
+
LOW = "low"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class Alert(BaseModel):
|
| 67 |
+
"""A production alert that triggered the incident."""
|
| 68 |
+
severity: AlertSeverity
|
| 69 |
+
message: str
|
| 70 |
+
metric_name: str
|
| 71 |
+
current_value: float
|
| 72 |
+
threshold: float
|
| 73 |
+
timestamp: str
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class ModelInfo(BaseModel):
|
| 77 |
+
"""Information about the deployed model."""
|
| 78 |
+
model_name: str
|
| 79 |
+
model_version: str
|
| 80 |
+
deployed_at: str
|
| 81 |
+
framework: str
|
| 82 |
+
endpoint: str
|
| 83 |
+
previous_versions: list[str] = Field(default_factory=list)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class MLOpsObservation(BaseModel):
|
| 87 |
+
"""What the agent sees after each action."""
|
| 88 |
+
|
| 89 |
+
# Episode metadata
|
| 90 |
+
done: bool = Field(False, description="Whether the episode has ended")
|
| 91 |
+
reward: float = Field(0.0, description="Reward for the last action")
|
| 92 |
+
step_number: int = Field(0, description="Current step in the episode")
|
| 93 |
+
max_steps: int = Field(20, description="Maximum steps before timeout")
|
| 94 |
+
|
| 95 |
+
# Incident context (always visible)
|
| 96 |
+
task_id: str = Field("", description="Task identifier")
|
| 97 |
+
task_description: str = Field("", description="Natural language task description")
|
| 98 |
+
alerts: list[Alert] = Field(default_factory=list, description="Active alerts")
|
| 99 |
+
model_info: ModelInfo | None = Field(None, description="Deployed model details")
|
| 100 |
+
|
| 101 |
+
# Action result (populated after each step)
|
| 102 |
+
action_result: str = Field(
|
| 103 |
+
"", description="Textual result of the last action taken"
|
| 104 |
+
)
|
| 105 |
+
action_success: bool = Field(True, description="Whether the action executed OK")
|
| 106 |
+
|
| 107 |
+
# Accumulated diagnostic context
|
| 108 |
+
diagnostics_gathered: list[str] = Field(
|
| 109 |
+
default_factory=list,
|
| 110 |
+
description="Summary of diagnostics the agent has collected so far",
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# Hints / guidance
|
| 114 |
+
available_actions: list[str] = Field(
|
| 115 |
+
default_factory=list,
|
| 116 |
+
description="List of valid action types the agent can take",
|
| 117 |
+
)
|
openenv.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: mlops_firefighter
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
MLOps Firefighter: An OpenEnv environment where AI agents debug and fix
|
| 5 |
+
ML models failing in production. Agents diagnose root causes (data drift,
|
| 6 |
+
threshold misconfiguration, training data poisoning) and apply the correct
|
| 7 |
+
remediation. 3 tasks from easy to hard with programmatic graders.
|
| 8 |
+
|
| 9 |
+
author: "MLOps Firefighter Contributors"
|
| 10 |
+
license: "BSD-3-Clause"
|
| 11 |
+
tags:
|
| 12 |
+
- openenv
|
| 13 |
+
- rl
|
| 14 |
+
- mlops
|
| 15 |
+
- production-ml
|
| 16 |
+
- debugging
|
| 17 |
+
- real-world
|
| 18 |
+
|
| 19 |
+
environment:
|
| 20 |
+
entrypoint: server/app.py
|
| 21 |
+
port: 7860
|
| 22 |
+
python_version: "3.11"
|
| 23 |
+
|
| 24 |
+
action_model: models.MLOpsAction
|
| 25 |
+
observation_model: models.MLOpsObservation
|
| 26 |
+
|
| 27 |
+
endpoints:
|
| 28 |
+
reset: /reset
|
| 29 |
+
step: /step
|
| 30 |
+
state: /state
|
| 31 |
+
health: /health
|
| 32 |
+
tasks: /tasks
|
| 33 |
+
grader: /grader
|
| 34 |
+
baseline: /baseline
|
| 35 |
+
|
| 36 |
+
tasks:
|
| 37 |
+
- id: task_threshold_drift
|
| 38 |
+
name: "Threshold Misconfiguration After Redeployment"
|
| 39 |
+
difficulty: easy
|
| 40 |
+
description: >
|
| 41 |
+
Fraud detection model precision dropped after redeployment.
|
| 42 |
+
Agent must diagnose threshold misconfiguration and adjust it.
|
| 43 |
+
max_steps: 15
|
| 44 |
+
|
| 45 |
+
- id: task_data_drift
|
| 46 |
+
name: "Data Drift with Stale Feature Pipeline"
|
| 47 |
+
difficulty: medium
|
| 48 |
+
description: >
|
| 49 |
+
Loan default model degraded due to upstream data pipeline change.
|
| 50 |
+
Agent must identify feature drift/corruption and fix the pipeline + retrain.
|
| 51 |
+
max_steps: 20
|
| 52 |
+
|
| 53 |
+
- id: task_silent_regression
|
| 54 |
+
name: "Silent Model Regression with Adversarial Inputs"
|
| 55 |
+
difficulty: hard
|
| 56 |
+
description: >
|
| 57 |
+
Content moderation model looks fine on aggregate metrics but is missing
|
| 58 |
+
hate/violence due to poisoned training data. Agent must find the silent
|
| 59 |
+
regression, roll back safely, and add guards.
|
| 60 |
+
max_steps: 25
|
pyproject.toml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "mlops-firefighter"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "OpenEnv environment: Debug and fix ML models failing in production"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
license = {text = "BSD-3-Clause"}
|
| 11 |
+
requires-python = ">=3.10"
|
| 12 |
+
dependencies = [
|
| 13 |
+
"fastapi>=0.104.0",
|
| 14 |
+
"uvicorn[standard]>=0.24.0",
|
| 15 |
+
"pydantic>=2.5.0",
|
| 16 |
+
"requests>=2.31.0",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
[project.optional-dependencies]
|
| 20 |
+
baseline = ["openai>=1.0.0"]
|
| 21 |
+
dev = ["pytest>=7.0", "httpx>=0.25.0"]
|
| 22 |
+
|
| 23 |
+
[tool.setuptools.packages.find]
|
| 24 |
+
include = ["mlops_firefighter*", "server*"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.104.0
|
| 2 |
+
uvicorn[standard]>=0.24.0
|
| 3 |
+
pydantic>=2.5.0
|
| 4 |
+
requests>=2.31.0
|
| 5 |
+
openai>=1.0.0
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Server package
|
server/app.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 MLOps Firefighter Contributors
|
| 2 |
+
# Licensed under the BSD-3-Clause License
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
FastAPI server for the MLOps Firefighter OpenEnv environment.
|
| 6 |
+
|
| 7 |
+
Endpoints:
|
| 8 |
+
POST /reset — Start a new episode
|
| 9 |
+
POST /step — Take an action
|
| 10 |
+
GET /state — Current environment state
|
| 11 |
+
GET /health — Health check
|
| 12 |
+
GET /tasks — List tasks and action schema
|
| 13 |
+
POST /grader — Score a completed episode
|
| 14 |
+
POST /baseline — Run baseline inference on all tasks
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
import traceback
|
| 23 |
+
from typing import Any
|
| 24 |
+
|
| 25 |
+
from fastapi import FastAPI, HTTPException
|
| 26 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 27 |
+
from pydantic import BaseModel, Field
|
| 28 |
+
|
| 29 |
+
# Add parent dir to path for imports
|
| 30 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 31 |
+
|
| 32 |
+
from models import ActionType, MLOpsAction, MLOpsObservation
|
| 33 |
+
from tasks import ALL_TASKS, grade_episode
|
| 34 |
+
|
| 35 |
+
from environment import MLOpsFirefighterEnvironment
|
| 36 |
+
|
| 37 |
+
# ── App Setup ────────────────────────────────────────────────────────────────
|
| 38 |
+
|
| 39 |
+
app = FastAPI(
|
| 40 |
+
title="MLOps Firefighter — OpenEnv Environment",
|
| 41 |
+
description=(
|
| 42 |
+
"An AI agent environment for debugging and fixing ML models in production. "
|
| 43 |
+
"The agent diagnoses root causes of model failures (data drift, threshold "
|
| 44 |
+
"misconfiguration, training data poisoning) and applies the correct fix."
|
| 45 |
+
),
|
| 46 |
+
version="1.0.0",
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
app.add_middleware(
|
| 50 |
+
CORSMiddleware,
|
| 51 |
+
allow_origins=["*"],
|
| 52 |
+
allow_credentials=True,
|
| 53 |
+
allow_methods=["*"],
|
| 54 |
+
allow_headers=["*"],
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Per-session environments (simple in-memory for single container)
|
| 58 |
+
_environments: dict[str, MLOpsFirefighterEnvironment] = {}
|
| 59 |
+
_default_env = MLOpsFirefighterEnvironment()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _get_env(session_id: str | None = None) -> MLOpsFirefighterEnvironment:
|
| 63 |
+
if session_id and session_id in _environments:
|
| 64 |
+
return _environments[session_id]
|
| 65 |
+
return _default_env
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# ── Request / Response Models ────────────────────────────────────────────────
|
| 69 |
+
|
| 70 |
+
class ResetRequest(BaseModel):
|
| 71 |
+
task_id: str | None = Field(None, description="Task to load (optional)")
|
| 72 |
+
session_id: str | None = Field(None, description="Session ID (optional)")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class StepRequest(BaseModel):
|
| 76 |
+
action_type: str = Field(..., description="Action type to perform")
|
| 77 |
+
parameters: dict[str, Any] = Field(default_factory=dict)
|
| 78 |
+
session_id: str | None = Field(None, description="Session ID (optional)")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class GraderRequest(BaseModel):
|
| 82 |
+
session_id: str | None = Field(None)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ── Endpoints ────────────────────────────────────────────────────────────────
|
| 86 |
+
|
| 87 |
+
@app.get("/health")
|
| 88 |
+
async def health():
|
| 89 |
+
return {"status": "healthy"}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@app.post("/reset")
|
| 93 |
+
async def reset(req: ResetRequest) -> dict:
|
| 94 |
+
env = _get_env(req.session_id)
|
| 95 |
+
if req.session_id:
|
| 96 |
+
_environments[req.session_id] = env
|
| 97 |
+
|
| 98 |
+
obs = env.reset(task_id=req.task_id)
|
| 99 |
+
return {
|
| 100 |
+
"observation": obs.model_dump(),
|
| 101 |
+
"reward": obs.reward,
|
| 102 |
+
"done": obs.done,
|
| 103 |
+
"info": {"episode_id": env.state()["episode_id"]},
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@app.post("/step")
|
| 108 |
+
async def step(req: StepRequest) -> dict:
|
| 109 |
+
env = _get_env(req.session_id)
|
| 110 |
+
|
| 111 |
+
try:
|
| 112 |
+
action_type = ActionType(req.action_type)
|
| 113 |
+
except ValueError:
|
| 114 |
+
raise HTTPException(
|
| 115 |
+
status_code=400,
|
| 116 |
+
detail=f"Invalid action_type: '{req.action_type}'. "
|
| 117 |
+
f"Valid actions: {[a.value for a in ActionType]}",
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
action = MLOpsAction(action_type=action_type, parameters=req.parameters)
|
| 121 |
+
obs = env.step(action)
|
| 122 |
+
|
| 123 |
+
return {
|
| 124 |
+
"observation": obs.model_dump(),
|
| 125 |
+
"reward": obs.reward,
|
| 126 |
+
"done": obs.done,
|
| 127 |
+
"info": env.state(),
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@app.get("/state")
|
| 132 |
+
async def state(session_id: str | None = None) -> dict:
|
| 133 |
+
env = _get_env(session_id)
|
| 134 |
+
return env.state()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@app.get("/tasks")
|
| 138 |
+
async def list_tasks() -> dict:
|
| 139 |
+
"""List all tasks with their action schema."""
|
| 140 |
+
tasks = []
|
| 141 |
+
for tid, t in ALL_TASKS.items():
|
| 142 |
+
tasks.append({
|
| 143 |
+
"task_id": t.task_id,
|
| 144 |
+
"name": t.name,
|
| 145 |
+
"difficulty": t.difficulty,
|
| 146 |
+
"description": t.description,
|
| 147 |
+
"max_steps": t.max_steps,
|
| 148 |
+
})
|
| 149 |
+
|
| 150 |
+
action_schema = {
|
| 151 |
+
"action_type": {
|
| 152 |
+
"type": "string",
|
| 153 |
+
"required": True,
|
| 154 |
+
"enum": [a.value for a in ActionType],
|
| 155 |
+
"description": "The action to perform",
|
| 156 |
+
},
|
| 157 |
+
"parameters": {
|
| 158 |
+
"type": "object",
|
| 159 |
+
"required": False,
|
| 160 |
+
"description": (
|
| 161 |
+
"Action-specific parameters. For adjust_threshold: "
|
| 162 |
+
"{'new_threshold': float}. For submit_diagnosis: "
|
| 163 |
+
"{'root_cause': str, 'summary': str}. For rollback_model: "
|
| 164 |
+
"{'target_version': str}."
|
| 165 |
+
),
|
| 166 |
+
},
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
return {"tasks": tasks, "action_schema": action_schema}
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
@app.post("/grader")
|
| 173 |
+
async def grader(req: GraderRequest) -> dict:
|
| 174 |
+
"""Return grader score for a completed episode."""
|
| 175 |
+
env = _get_env(req.session_id)
|
| 176 |
+
st = env.state()
|
| 177 |
+
|
| 178 |
+
if not st["done"]:
|
| 179 |
+
raise HTTPException(
|
| 180 |
+
status_code=400,
|
| 181 |
+
detail="Episode not complete. Finish the episode first.",
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
if st["grader_result"]:
|
| 185 |
+
return {
|
| 186 |
+
"score": st["grader_result"]["total"],
|
| 187 |
+
"breakdown": st["grader_result"],
|
| 188 |
+
"task_id": st["task_id"],
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
return {
|
| 192 |
+
"score": 0.0,
|
| 193 |
+
"breakdown": {},
|
| 194 |
+
"task_id": st["task_id"],
|
| 195 |
+
"message": "No grader result available.",
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@app.post("/baseline")
|
| 200 |
+
async def run_baseline() -> dict:
|
| 201 |
+
"""Run a scripted baseline agent on all 3 tasks and return scores."""
|
| 202 |
+
results = {}
|
| 203 |
+
|
| 204 |
+
for task_id, task_def in ALL_TASKS.items():
|
| 205 |
+
env = MLOpsFirefighterEnvironment()
|
| 206 |
+
env.reset(task_id=task_id)
|
| 207 |
+
|
| 208 |
+
# Baseline strategy: run all required diagnostics, then apply
|
| 209 |
+
# correct remediation, then submit diagnosis
|
| 210 |
+
for diag in task_def.required_diagnostics:
|
| 211 |
+
action = MLOpsAction(action_type=diag, parameters={})
|
| 212 |
+
env.step(action)
|
| 213 |
+
|
| 214 |
+
# Apply correct remediations
|
| 215 |
+
for rem in task_def.correct_remediations:
|
| 216 |
+
params = {}
|
| 217 |
+
if rem == ActionType.ADJUST_THRESHOLD:
|
| 218 |
+
params = {"new_threshold": task_def.extra_state.get("optimal_threshold", 0.5)}
|
| 219 |
+
elif rem == ActionType.ROLLBACK_MODEL:
|
| 220 |
+
params = {"target_version": task_def.model_info.previous_versions[0]}
|
| 221 |
+
action = MLOpsAction(action_type=rem, parameters=params)
|
| 222 |
+
env.step(action)
|
| 223 |
+
|
| 224 |
+
# Submit diagnosis
|
| 225 |
+
action = MLOpsAction(
|
| 226 |
+
action_type=ActionType.SUBMIT_DIAGNOSIS,
|
| 227 |
+
parameters={
|
| 228 |
+
"root_cause": task_def.root_causes[0],
|
| 229 |
+
"summary": f"Baseline diagnosis for {task_def.name}",
|
| 230 |
+
},
|
| 231 |
+
)
|
| 232 |
+
env.step(action)
|
| 233 |
+
|
| 234 |
+
st = env.state()
|
| 235 |
+
results[task_id] = {
|
| 236 |
+
"task_name": task_def.name,
|
| 237 |
+
"difficulty": task_def.difficulty,
|
| 238 |
+
"score": st["grader_result"]["total"] if st["grader_result"] else 0.0,
|
| 239 |
+
"breakdown": st["grader_result"],
|
| 240 |
+
"steps_taken": st["step_count"],
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
# Average score
|
| 244 |
+
scores = [r["score"] for r in results.values()]
|
| 245 |
+
avg = sum(scores) / len(scores) if scores else 0.0
|
| 246 |
+
|
| 247 |
+
return {
|
| 248 |
+
"baseline_results": results,
|
| 249 |
+
"average_score": round(avg, 3),
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
# ── Main ─────────────────────────────────────────────────────────────────────
|
| 254 |
+
|
| 255 |
+
if __name__ == "__main__":
|
| 256 |
+
import uvicorn
|
| 257 |
+
|
| 258 |
+
port = int(os.environ.get("PORT", 7860))
|
| 259 |
+
uvicorn.run(app, host="0.0.0.0", port=port)
|
server/environment.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 MLOps Firefighter Contributors
|
| 2 |
+
# Licensed under the BSD-3-Clause License
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Core environment logic for the MLOps Firefighter.
|
| 6 |
+
|
| 7 |
+
Implements the OpenEnv Environment interface: reset(), step(), state().
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import uuid
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
from models import (
|
| 16 |
+
ActionType,
|
| 17 |
+
AlertSeverity,
|
| 18 |
+
MLOpsAction,
|
| 19 |
+
MLOpsObservation,
|
| 20 |
+
)
|
| 21 |
+
from tasks import ALL_TASKS, TaskDefinition, grade_episode
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MLOpsFirefighterEnvironment:
|
| 25 |
+
"""
|
| 26 |
+
Simulates an ML model failing in production.
|
| 27 |
+
|
| 28 |
+
The agent must diagnose the root cause and apply the correct fix
|
| 29 |
+
through a sequence of diagnostic and remediation actions.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self) -> None:
|
| 33 |
+
self._episode_id: str = ""
|
| 34 |
+
self._step_count: int = 0
|
| 35 |
+
self._task: TaskDefinition | None = None
|
| 36 |
+
self._done: bool = False
|
| 37 |
+
self._cumulative_reward: float = 0.0
|
| 38 |
+
|
| 39 |
+
# Tracking for grading
|
| 40 |
+
self._actions_taken: list[dict[str, Any]] = []
|
| 41 |
+
self._diagnostics_gathered: list[str] = []
|
| 42 |
+
self._diagnosis_submitted: dict[str, Any] | None = None
|
| 43 |
+
self._remediations_applied: list[str] = []
|
| 44 |
+
self._last_grader_result: dict | None = None
|
| 45 |
+
|
| 46 |
+
# ── OpenEnv Interface ────────────────────────────────────────────────
|
| 47 |
+
|
| 48 |
+
def reset(self, task_id: str | None = None) -> MLOpsObservation:
|
| 49 |
+
"""Initialize a new incident episode."""
|
| 50 |
+
self._episode_id = str(uuid.uuid4())
|
| 51 |
+
self._step_count = 0
|
| 52 |
+
self._done = False
|
| 53 |
+
self._cumulative_reward = 0.0
|
| 54 |
+
self._actions_taken = []
|
| 55 |
+
self._diagnostics_gathered = []
|
| 56 |
+
self._diagnosis_submitted = None
|
| 57 |
+
self._remediations_applied = []
|
| 58 |
+
self._last_grader_result = None
|
| 59 |
+
|
| 60 |
+
# Pick task
|
| 61 |
+
if task_id and task_id in ALL_TASKS:
|
| 62 |
+
self._task = ALL_TASKS[task_id]
|
| 63 |
+
else:
|
| 64 |
+
# Default to easy task
|
| 65 |
+
self._task = list(ALL_TASKS.values())[0]
|
| 66 |
+
|
| 67 |
+
return MLOpsObservation(
|
| 68 |
+
done=False,
|
| 69 |
+
reward=0.0,
|
| 70 |
+
step_number=0,
|
| 71 |
+
max_steps=self._task.max_steps,
|
| 72 |
+
task_id=self._task.task_id,
|
| 73 |
+
task_description=self._task.description,
|
| 74 |
+
alerts=self._task.alerts,
|
| 75 |
+
model_info=self._task.model_info,
|
| 76 |
+
action_result="Incident assigned to you. Begin investigation.",
|
| 77 |
+
action_success=True,
|
| 78 |
+
diagnostics_gathered=[],
|
| 79 |
+
available_actions=[a.value for a in ActionType],
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
def step(self, action: MLOpsAction) -> MLOpsObservation:
|
| 83 |
+
"""Execute an action and return the resulting observation."""
|
| 84 |
+
if self._task is None:
|
| 85 |
+
return MLOpsObservation(
|
| 86 |
+
done=True,
|
| 87 |
+
reward=-1.0,
|
| 88 |
+
action_result="Error: No task loaded. Call reset() first.",
|
| 89 |
+
action_success=False,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
if self._done:
|
| 93 |
+
return MLOpsObservation(
|
| 94 |
+
done=True,
|
| 95 |
+
reward=0.0,
|
| 96 |
+
step_number=self._step_count,
|
| 97 |
+
max_steps=self._task.max_steps,
|
| 98 |
+
task_id=self._task.task_id,
|
| 99 |
+
task_description=self._task.description,
|
| 100 |
+
action_result="Episode already ended.",
|
| 101 |
+
action_success=False,
|
| 102 |
+
diagnostics_gathered=self._diagnostics_gathered,
|
| 103 |
+
available_actions=[],
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
self._step_count += 1
|
| 107 |
+
self._actions_taken.append({
|
| 108 |
+
"action_type": action.action_type.value,
|
| 109 |
+
"parameters": action.parameters,
|
| 110 |
+
"step": self._step_count,
|
| 111 |
+
})
|
| 112 |
+
|
| 113 |
+
# Process the action
|
| 114 |
+
result_text, reward, action_success = self._process_action(action)
|
| 115 |
+
|
| 116 |
+
# Check for episode end
|
| 117 |
+
if action.action_type == ActionType.SUBMIT_DIAGNOSIS:
|
| 118 |
+
self._done = True
|
| 119 |
+
# Run grader for final reward
|
| 120 |
+
final_score, breakdown = grade_episode(
|
| 121 |
+
task=self._task,
|
| 122 |
+
actions_taken=self._actions_taken,
|
| 123 |
+
diagnosis_submitted=self._diagnosis_submitted,
|
| 124 |
+
remediation_applied=self._remediations_applied,
|
| 125 |
+
total_steps=self._step_count,
|
| 126 |
+
)
|
| 127 |
+
self._last_grader_result = breakdown
|
| 128 |
+
# Scale final reward: bonus for good diagnosis
|
| 129 |
+
reward = final_score * 5.0 # 0–5 range for final step
|
| 130 |
+
result_text += f"\n\n── EPISODE COMPLETE ──\nFinal Score: {final_score:.3f}\nBreakdown: {breakdown}"
|
| 131 |
+
|
| 132 |
+
# Check timeout
|
| 133 |
+
if self._step_count >= self._task.max_steps and not self._done:
|
| 134 |
+
self._done = True
|
| 135 |
+
final_score, breakdown = grade_episode(
|
| 136 |
+
task=self._task,
|
| 137 |
+
actions_taken=self._actions_taken,
|
| 138 |
+
diagnosis_submitted=self._diagnosis_submitted,
|
| 139 |
+
remediation_applied=self._remediations_applied,
|
| 140 |
+
total_steps=self._step_count,
|
| 141 |
+
)
|
| 142 |
+
self._last_grader_result = breakdown
|
| 143 |
+
reward = final_score * 3.0 # Lower multiplier for timeout
|
| 144 |
+
result_text += (
|
| 145 |
+
f"\n\n── TIMEOUT — Episode ended (max {self._task.max_steps} steps) ──\n"
|
| 146 |
+
f"Score: {final_score:.3f}\nBreakdown: {breakdown}"
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
self._cumulative_reward += reward
|
| 150 |
+
|
| 151 |
+
return MLOpsObservation(
|
| 152 |
+
done=self._done,
|
| 153 |
+
reward=round(reward, 4),
|
| 154 |
+
step_number=self._step_count,
|
| 155 |
+
max_steps=self._task.max_steps,
|
| 156 |
+
task_id=self._task.task_id,
|
| 157 |
+
task_description=self._task.description,
|
| 158 |
+
alerts=self._task.alerts,
|
| 159 |
+
model_info=self._task.model_info,
|
| 160 |
+
action_result=result_text,
|
| 161 |
+
action_success=action_success,
|
| 162 |
+
diagnostics_gathered=self._diagnostics_gathered,
|
| 163 |
+
available_actions=(
|
| 164 |
+
[a.value for a in ActionType] if not self._done else []
|
| 165 |
+
),
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
def state(self) -> dict[str, Any]:
|
| 169 |
+
"""Return current environment state."""
|
| 170 |
+
return {
|
| 171 |
+
"episode_id": self._episode_id,
|
| 172 |
+
"step_count": self._step_count,
|
| 173 |
+
"task_id": self._task.task_id if self._task else None,
|
| 174 |
+
"done": self._done,
|
| 175 |
+
"cumulative_reward": round(self._cumulative_reward, 4),
|
| 176 |
+
"actions_taken_count": len(self._actions_taken),
|
| 177 |
+
"diagnostics_gathered": self._diagnostics_gathered,
|
| 178 |
+
"remediations_applied": self._remediations_applied,
|
| 179 |
+
"diagnosis_submitted": self._diagnosis_submitted is not None,
|
| 180 |
+
"grader_result": self._last_grader_result,
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
# ── Action Processing ────────────────────────────────────────────────
|
| 184 |
+
|
| 185 |
+
def _process_action(self, action: MLOpsAction) -> tuple[str, float, bool]:
|
| 186 |
+
"""
|
| 187 |
+
Process an action and return (result_text, step_reward, success).
|
| 188 |
+
|
| 189 |
+
Reward shaping:
|
| 190 |
+
- Useful diagnostic: +0.3
|
| 191 |
+
- Redundant diagnostic: +0.05
|
| 192 |
+
- Correct remediation: +1.0
|
| 193 |
+
- Wrong remediation: -0.5
|
| 194 |
+
- Destructive action without diagnosis: -1.0
|
| 195 |
+
- Submit diagnosis: handled in step() via grader
|
| 196 |
+
"""
|
| 197 |
+
task = self._task
|
| 198 |
+
assert task is not None
|
| 199 |
+
|
| 200 |
+
at = action.action_type
|
| 201 |
+
|
| 202 |
+
# ── Diagnostic actions ───────────────────────────────────────────
|
| 203 |
+
if at in (
|
| 204 |
+
ActionType.INSPECT_METRICS,
|
| 205 |
+
ActionType.QUERY_LOGS,
|
| 206 |
+
ActionType.CHECK_DATA_DISTRIBUTION,
|
| 207 |
+
ActionType.CHECK_FEATURE_IMPORTANCE,
|
| 208 |
+
ActionType.RUN_PREDICTION_SAMPLE,
|
| 209 |
+
ActionType.CHECK_INFRASTRUCTURE,
|
| 210 |
+
ActionType.CHECK_UPSTREAM_PIPELINE,
|
| 211 |
+
):
|
| 212 |
+
result_text = task.diagnostic_results.get(
|
| 213 |
+
at, "No additional information available for this diagnostic."
|
| 214 |
+
)
|
| 215 |
+
# Reward: useful if it's a required diagnostic we haven't done
|
| 216 |
+
diag_label = f"[{at.value}] {result_text[:80]}..."
|
| 217 |
+
if diag_label in self._diagnostics_gathered:
|
| 218 |
+
# Redundant — already gathered
|
| 219 |
+
reward = 0.05
|
| 220 |
+
result_text = "(You already ran this diagnostic.)\n\n" + result_text
|
| 221 |
+
else:
|
| 222 |
+
self._diagnostics_gathered.append(diag_label)
|
| 223 |
+
if at in task.required_diagnostics:
|
| 224 |
+
reward = 0.3 # Useful diagnostic
|
| 225 |
+
else:
|
| 226 |
+
reward = 0.1 # Valid but not critical
|
| 227 |
+
return result_text, reward, True
|
| 228 |
+
|
| 229 |
+
# ── Remediation actions ──────────────────────────────────────────
|
| 230 |
+
if at in (
|
| 231 |
+
ActionType.ROLLBACK_MODEL,
|
| 232 |
+
ActionType.ADJUST_THRESHOLD,
|
| 233 |
+
ActionType.RETRAIN_MODEL,
|
| 234 |
+
ActionType.FIX_DATA_PIPELINE,
|
| 235 |
+
ActionType.SCALE_INFRASTRUCTURE,
|
| 236 |
+
ActionType.ADD_FEATURE_GUARD,
|
| 237 |
+
):
|
| 238 |
+
return self._process_remediation(at, action.parameters)
|
| 239 |
+
|
| 240 |
+
# ── Submit diagnosis ─────────────────────────────────────────────
|
| 241 |
+
if at == ActionType.SUBMIT_DIAGNOSIS:
|
| 242 |
+
self._diagnosis_submitted = action.parameters
|
| 243 |
+
root_cause = action.parameters.get("root_cause", "not specified")
|
| 244 |
+
summary = action.parameters.get("summary", "no summary")
|
| 245 |
+
return (
|
| 246 |
+
f"Diagnosis submitted.\n"
|
| 247 |
+
f" Root cause: {root_cause}\n"
|
| 248 |
+
f" Summary: {summary}",
|
| 249 |
+
0.0, # Actual reward computed in step() via grader
|
| 250 |
+
True,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
return "Unknown action type.", -0.2, False
|
| 254 |
+
|
| 255 |
+
def _process_remediation(
|
| 256 |
+
self, action_type: ActionType, params: dict[str, Any]
|
| 257 |
+
) -> tuple[str, float, bool]:
|
| 258 |
+
"""Process a remediation action."""
|
| 259 |
+
task = self._task
|
| 260 |
+
assert task is not None
|
| 261 |
+
|
| 262 |
+
self._remediations_applied.append(action_type.value)
|
| 263 |
+
|
| 264 |
+
is_correct = action_type in task.correct_remediations
|
| 265 |
+
has_diagnosed = len(self._diagnostics_gathered) >= 2
|
| 266 |
+
|
| 267 |
+
# Penalize hasty remediation without diagnosis
|
| 268 |
+
if not has_diagnosed:
|
| 269 |
+
return (
|
| 270 |
+
f"⚠ WARNING: Applying {action_type.value} without sufficient "
|
| 271 |
+
f"diagnostic investigation. This is risky in production.\n"
|
| 272 |
+
f"Action applied, but confidence is low.",
|
| 273 |
+
-0.5,
|
| 274 |
+
True,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
if is_correct:
|
| 278 |
+
# Specific feedback per remediation
|
| 279 |
+
feedback = self._get_remediation_feedback(action_type, params)
|
| 280 |
+
return feedback, 1.0, True
|
| 281 |
+
else:
|
| 282 |
+
return (
|
| 283 |
+
f"Applied {action_type.value}, but this doesn't address the "
|
| 284 |
+
f"root cause. The issue persists. Consider more investigation.",
|
| 285 |
+
-0.3,
|
| 286 |
+
True,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
def _get_remediation_feedback(
|
| 290 |
+
self, action_type: ActionType, params: dict[str, Any]
|
| 291 |
+
) -> str:
|
| 292 |
+
"""Generate specific feedback for correct remediations."""
|
| 293 |
+
task = self._task
|
| 294 |
+
assert task is not None
|
| 295 |
+
|
| 296 |
+
if action_type == ActionType.ADJUST_THRESHOLD:
|
| 297 |
+
new_t = params.get("new_threshold", "unspecified")
|
| 298 |
+
optimal = task.extra_state.get("optimal_threshold", 0.5)
|
| 299 |
+
return (
|
| 300 |
+
f"Threshold adjusted to {new_t}.\n"
|
| 301 |
+
f"Optimal threshold was {optimal}.\n"
|
| 302 |
+
f"Precision recovering. False positive rate decreasing."
|
| 303 |
+
)
|
| 304 |
+
elif action_type == ActionType.ROLLBACK_MODEL:
|
| 305 |
+
target = params.get("target_version", "previous")
|
| 306 |
+
return (
|
| 307 |
+
f"Model rolled back to {target}.\n"
|
| 308 |
+
f"Previous model restored. Monitoring metrics.\n"
|
| 309 |
+
f"Harmful content detection recovering."
|
| 310 |
+
)
|
| 311 |
+
elif action_type == ActionType.FIX_DATA_PIPELINE:
|
| 312 |
+
return (
|
| 313 |
+
"Data pipeline fix initiated.\n"
|
| 314 |
+
" - Credit bureau connector patched\n"
|
| 315 |
+
" - Unit conversion (cents→dollars) applied to annual_income\n"
|
| 316 |
+
" - Null handling added for credit_utilization_ratio\n"
|
| 317 |
+
" Pipeline revalidation in progress."
|
| 318 |
+
)
|
| 319 |
+
elif action_type == ActionType.RETRAIN_MODEL:
|
| 320 |
+
return (
|
| 321 |
+
"Model retraining triggered with corrected data.\n"
|
| 322 |
+
" - Stale features refreshed\n"
|
| 323 |
+
" - Estimated completion: 2 hours\n"
|
| 324 |
+
" - Will auto-deploy after quality gate passes."
|
| 325 |
+
)
|
| 326 |
+
elif action_type == ActionType.ADD_FEATURE_GUARD:
|
| 327 |
+
return (
|
| 328 |
+
"Feature guard / input validation added.\n"
|
| 329 |
+
" - Per-class recall monitoring enabled\n"
|
| 330 |
+
" - Crowd-source batch validation gate added\n"
|
| 331 |
+
" - Anomalous label distributions will trigger review."
|
| 332 |
+
)
|
| 333 |
+
else:
|
| 334 |
+
return f"{action_type.value} applied successfully."
|
tasks.py
ADDED
|
@@ -0,0 +1,526 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 MLOps Firefighter Contributors
|
| 2 |
+
# Licensed under the BSD-3-Clause License
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Task definitions for the MLOps Firefighter environment.
|
| 6 |
+
|
| 7 |
+
Each task defines:
|
| 8 |
+
- A production ML incident scenario
|
| 9 |
+
- The ground-truth root cause(s)
|
| 10 |
+
- Required diagnostic steps (for partial credit)
|
| 11 |
+
- Correct remediation action(s)
|
| 12 |
+
- A grader function that scores agent performance 0.0–1.0
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
from models import (
|
| 21 |
+
ActionType,
|
| 22 |
+
Alert,
|
| 23 |
+
AlertSeverity,
|
| 24 |
+
ModelInfo,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ── Task Data Structure ─────────────────────────────────────────────────────
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class TaskDefinition:
|
| 32 |
+
task_id: str
|
| 33 |
+
name: str
|
| 34 |
+
difficulty: str # "easy", "medium", "hard"
|
| 35 |
+
description: str
|
| 36 |
+
alerts: list[Alert]
|
| 37 |
+
model_info: ModelInfo
|
| 38 |
+
root_causes: list[str] # acceptable root cause labels
|
| 39 |
+
required_diagnostics: list[ActionType] # expected investigation steps
|
| 40 |
+
correct_remediations: list[ActionType] # accepted fix actions
|
| 41 |
+
diagnostic_results: dict[ActionType, str] # what each diagnostic reveals
|
| 42 |
+
max_steps: int = 20
|
| 43 |
+
extra_state: dict[str, Any] = field(default_factory=dict)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ── Task 1: Easy — Threshold Misconfiguration ───────────────────────────────
|
| 47 |
+
|
| 48 |
+
TASK_EASY = TaskDefinition(
|
| 49 |
+
task_id="task_threshold_drift",
|
| 50 |
+
name="Threshold Misconfiguration After Redeployment",
|
| 51 |
+
difficulty="easy",
|
| 52 |
+
description=(
|
| 53 |
+
"INCIDENT: The fraud detection model's precision has dropped from 0.94 to 0.61 "
|
| 54 |
+
"over the last 2 hours. False positive rate spiked 3x. The model was redeployed "
|
| 55 |
+
"yesterday with a new version. No changes to training data or features were made. "
|
| 56 |
+
"Investigate and fix the issue."
|
| 57 |
+
),
|
| 58 |
+
alerts=[
|
| 59 |
+
Alert(
|
| 60 |
+
severity=AlertSeverity.HIGH,
|
| 61 |
+
message="Precision drop detected on fraud-detection-v3",
|
| 62 |
+
metric_name="precision",
|
| 63 |
+
current_value=0.61,
|
| 64 |
+
threshold=0.85,
|
| 65 |
+
timestamp="2026-03-27T10:15:00Z",
|
| 66 |
+
),
|
| 67 |
+
Alert(
|
| 68 |
+
severity=AlertSeverity.MEDIUM,
|
| 69 |
+
message="False positive rate above acceptable range",
|
| 70 |
+
metric_name="false_positive_rate",
|
| 71 |
+
current_value=0.18,
|
| 72 |
+
threshold=0.05,
|
| 73 |
+
timestamp="2026-03-27T10:17:00Z",
|
| 74 |
+
),
|
| 75 |
+
],
|
| 76 |
+
model_info=ModelInfo(
|
| 77 |
+
model_name="fraud-detection",
|
| 78 |
+
model_version="v3.1.0",
|
| 79 |
+
deployed_at="2026-03-26T14:00:00Z",
|
| 80 |
+
framework="XGBoost",
|
| 81 |
+
endpoint="/api/v1/predict/fraud",
|
| 82 |
+
previous_versions=["v3.0.2", "v2.9.1"],
|
| 83 |
+
),
|
| 84 |
+
root_causes=["threshold_misconfiguration", "threshold_too_low", "bad_threshold"],
|
| 85 |
+
required_diagnostics=[
|
| 86 |
+
ActionType.INSPECT_METRICS,
|
| 87 |
+
ActionType.RUN_PREDICTION_SAMPLE,
|
| 88 |
+
],
|
| 89 |
+
correct_remediations=[ActionType.ADJUST_THRESHOLD],
|
| 90 |
+
diagnostic_results={
|
| 91 |
+
ActionType.INSPECT_METRICS: (
|
| 92 |
+
"METRICS DASHBOARD:\n"
|
| 93 |
+
" Model: fraud-detection v3.1.0\n"
|
| 94 |
+
" Precision: 0.61 (was 0.94 on v3.0.2)\n"
|
| 95 |
+
" Recall: 0.99 (was 0.87 on v3.0.2)\n"
|
| 96 |
+
" F1: 0.76 (was 0.90 on v3.0.2)\n"
|
| 97 |
+
" Decision threshold: 0.30 (was 0.55 on v3.0.2)\n"
|
| 98 |
+
" NOTE: Threshold was changed during redeployment.\n"
|
| 99 |
+
" Requests/sec: 1,240 (normal)\n"
|
| 100 |
+
" P99 latency: 45ms (normal)"
|
| 101 |
+
),
|
| 102 |
+
ActionType.QUERY_LOGS: (
|
| 103 |
+
"LOG SEARCH RESULTS (last 4 hours):\n"
|
| 104 |
+
" [10:00] Deployment v3.1.0 started\n"
|
| 105 |
+
" [10:01] Config applied: threshold=0.30 (previous: 0.55)\n"
|
| 106 |
+
" [10:02] Health check passed\n"
|
| 107 |
+
" [10:15] Alert: precision below 0.85\n"
|
| 108 |
+
" No error logs. No OOM events. Pipeline healthy."
|
| 109 |
+
),
|
| 110 |
+
ActionType.CHECK_DATA_DISTRIBUTION: (
|
| 111 |
+
"DATA DISTRIBUTION COMPARISON:\n"
|
| 112 |
+
" Training data: 2.1M samples, fraud_rate=3.2%\n"
|
| 113 |
+
" Serving data (last 24h): 148K requests, fraud_rate=3.1%\n"
|
| 114 |
+
" Feature distributions: no significant drift detected\n"
|
| 115 |
+
" KL divergence: 0.003 (threshold: 0.05)\n"
|
| 116 |
+
" All features within expected ranges."
|
| 117 |
+
),
|
| 118 |
+
ActionType.CHECK_FEATURE_IMPORTANCE: (
|
| 119 |
+
"FEATURE IMPORTANCE (top 5):\n"
|
| 120 |
+
" 1. transaction_amount: 0.31\n"
|
| 121 |
+
" 2. merchant_risk_score: 0.22\n"
|
| 122 |
+
" 3. velocity_1h: 0.18\n"
|
| 123 |
+
" 4. distance_from_home: 0.14\n"
|
| 124 |
+
" 5. device_fingerprint_match: 0.09\n"
|
| 125 |
+
" No anomalies in feature importance vs. training."
|
| 126 |
+
),
|
| 127 |
+
ActionType.RUN_PREDICTION_SAMPLE: (
|
| 128 |
+
"PREDICTION SAMPLE TEST (100 known-labeled samples):\n"
|
| 129 |
+
" At threshold=0.30: precision=0.62, recall=0.98\n"
|
| 130 |
+
" At threshold=0.55: precision=0.93, recall=0.88\n"
|
| 131 |
+
" At threshold=0.50: precision=0.91, recall=0.90\n"
|
| 132 |
+
" Conclusion: threshold is the primary lever for precision."
|
| 133 |
+
),
|
| 134 |
+
ActionType.CHECK_INFRASTRUCTURE: (
|
| 135 |
+
"INFRASTRUCTURE STATUS:\n"
|
| 136 |
+
" CPU: 34% (normal)\n Memory: 4.2GB/8GB (normal)\n"
|
| 137 |
+
" GPU: N/A (CPU-only model)\n Latency P99: 45ms (normal)\n"
|
| 138 |
+
" No infrastructure issues detected."
|
| 139 |
+
),
|
| 140 |
+
ActionType.CHECK_UPSTREAM_PIPELINE: (
|
| 141 |
+
"UPSTREAM PIPELINE:\n"
|
| 142 |
+
" Feature store: healthy, last refresh 12 min ago\n"
|
| 143 |
+
" Data ingestion: nominal, 0 failures in 24h\n"
|
| 144 |
+
" Schema validation: passing\n"
|
| 145 |
+
" No upstream issues."
|
| 146 |
+
),
|
| 147 |
+
},
|
| 148 |
+
max_steps=15,
|
| 149 |
+
extra_state={"optimal_threshold": 0.55},
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ── Task 2: Medium — Data Drift + Stale Feature ─────────────────────────────
|
| 154 |
+
|
| 155 |
+
TASK_MEDIUM = TaskDefinition(
|
| 156 |
+
task_id="task_data_drift",
|
| 157 |
+
name="Data Drift with Stale Feature Pipeline",
|
| 158 |
+
difficulty="medium",
|
| 159 |
+
description=(
|
| 160 |
+
"INCIDENT: The loan default prediction model's AUC has degraded from 0.91 to "
|
| 161 |
+
"0.74 over the past week. Customer complaints about incorrect rejections have "
|
| 162 |
+
"tripled. The model was last retrained 4 months ago. A new credit bureau data "
|
| 163 |
+
"provider was onboarded 10 days ago. Diagnose the issue and fix it."
|
| 164 |
+
),
|
| 165 |
+
alerts=[
|
| 166 |
+
Alert(
|
| 167 |
+
severity=AlertSeverity.CRITICAL,
|
| 168 |
+
message="AUC degradation on loan-default-predictor",
|
| 169 |
+
metric_name="auc_roc",
|
| 170 |
+
current_value=0.74,
|
| 171 |
+
threshold=0.85,
|
| 172 |
+
timestamp="2026-03-27T08:00:00Z",
|
| 173 |
+
),
|
| 174 |
+
Alert(
|
| 175 |
+
severity=AlertSeverity.HIGH,
|
| 176 |
+
message="Customer complaint rate above threshold",
|
| 177 |
+
metric_name="complaint_rate",
|
| 178 |
+
current_value=0.12,
|
| 179 |
+
threshold=0.03,
|
| 180 |
+
timestamp="2026-03-27T09:30:00Z",
|
| 181 |
+
),
|
| 182 |
+
],
|
| 183 |
+
model_info=ModelInfo(
|
| 184 |
+
model_name="loan-default-predictor",
|
| 185 |
+
model_version="v2.4.0",
|
| 186 |
+
deployed_at="2025-11-15T10:00:00Z",
|
| 187 |
+
framework="LightGBM",
|
| 188 |
+
endpoint="/api/v1/predict/loan-default",
|
| 189 |
+
previous_versions=["v2.3.1", "v2.2.0"],
|
| 190 |
+
),
|
| 191 |
+
root_causes=[
|
| 192 |
+
"data_drift",
|
| 193 |
+
"stale_feature",
|
| 194 |
+
"feature_pipeline_broken",
|
| 195 |
+
"data_distribution_shift",
|
| 196 |
+
],
|
| 197 |
+
required_diagnostics=[
|
| 198 |
+
ActionType.CHECK_DATA_DISTRIBUTION,
|
| 199 |
+
ActionType.CHECK_UPSTREAM_PIPELINE,
|
| 200 |
+
ActionType.INSPECT_METRICS,
|
| 201 |
+
],
|
| 202 |
+
correct_remediations=[
|
| 203 |
+
ActionType.FIX_DATA_PIPELINE,
|
| 204 |
+
ActionType.RETRAIN_MODEL,
|
| 205 |
+
],
|
| 206 |
+
diagnostic_results={
|
| 207 |
+
ActionType.INSPECT_METRICS: (
|
| 208 |
+
"METRICS DASHBOARD:\n"
|
| 209 |
+
" Model: loan-default-predictor v2.4.0\n"
|
| 210 |
+
" AUC: 0.74 (was 0.91 at deploy, 0.88 last week)\n"
|
| 211 |
+
" Accuracy: 0.69 (was 0.84)\n"
|
| 212 |
+
" False rejection rate: 14% (was 4%)\n"
|
| 213 |
+
" Threshold: 0.50 (unchanged)\n"
|
| 214 |
+
" Requests/sec: 320 (normal)\n"
|
| 215 |
+
" Degradation trend: gradual decline starting ~10 days ago"
|
| 216 |
+
),
|
| 217 |
+
ActionType.QUERY_LOGS: (
|
| 218 |
+
"LOG SEARCH RESULTS (last 14 days):\n"
|
| 219 |
+
" [Mar 17] New credit bureau connector deployed (v1.2)\n"
|
| 220 |
+
" [Mar 17] Feature 'credit_utilization_ratio' source changed\n"
|
| 221 |
+
" [Mar 18] Warning: 2,340 null values in 'credit_utilization_ratio'\n"
|
| 222 |
+
" [Mar 19-27] Recurring nulls in credit_utilization_ratio (avg 8%/day)\n"
|
| 223 |
+
" [Mar 22] Warning: feature 'annual_income' format changed (cents→dollars)\n"
|
| 224 |
+
" No deployment events. No model changes."
|
| 225 |
+
),
|
| 226 |
+
ActionType.CHECK_DATA_DISTRIBUTION: (
|
| 227 |
+
"DATA DISTRIBUTION COMPARISON:\n"
|
| 228 |
+
" Training data vs. Serving data (last 7 days):\n"
|
| 229 |
+
" ─────────────────────────────────────────────\n"
|
| 230 |
+
" credit_utilization_ratio:\n"
|
| 231 |
+
" Train: mean=0.34, std=0.21, null_rate=0.1%\n"
|
| 232 |
+
" Serve: mean=0.08, std=0.42, null_rate=8.3% ⚠ DRIFT DETECTED\n"
|
| 233 |
+
" KL divergence: 0.87 (threshold: 0.05)\n"
|
| 234 |
+
" annual_income:\n"
|
| 235 |
+
" Train: mean=65,400, std=28,000\n"
|
| 236 |
+
" Serve: mean=6,540,000, std=2,800,000 ⚠ DRIFT DETECTED\n"
|
| 237 |
+
" KL divergence: 12.4 (threshold: 0.05)\n"
|
| 238 |
+
" NOTE: Values appear 100x larger — possible unit change\n"
|
| 239 |
+
" Other features: within normal ranges"
|
| 240 |
+
),
|
| 241 |
+
ActionType.CHECK_FEATURE_IMPORTANCE: (
|
| 242 |
+
"FEATURE IMPORTANCE (top 5):\n"
|
| 243 |
+
" 1. credit_utilization_ratio: 0.28 ← AFFECTED\n"
|
| 244 |
+
" 2. annual_income: 0.22 ← AFFECTED\n"
|
| 245 |
+
" 3. debt_to_income: 0.19\n"
|
| 246 |
+
" 4. payment_history_score: 0.15\n"
|
| 247 |
+
" 5. employment_length: 0.08\n"
|
| 248 |
+
" The top 2 features (50% of importance) are drifted."
|
| 249 |
+
),
|
| 250 |
+
ActionType.RUN_PREDICTION_SAMPLE: (
|
| 251 |
+
"PREDICTION SAMPLE (50 recent applications with known outcomes):\n"
|
| 252 |
+
" Model predictions vs. actual outcomes:\n"
|
| 253 |
+
" - Correct: 31/50 (62%)\n"
|
| 254 |
+
" - False rejections: 14/50 (28%) — most involve high-income applicants\n"
|
| 255 |
+
" - Missed defaults: 5/50 (10%)\n"
|
| 256 |
+
" Annual income feature appears corrupted (values in cents not dollars)"
|
| 257 |
+
),
|
| 258 |
+
ActionType.CHECK_INFRASTRUCTURE: (
|
| 259 |
+
"INFRASTRUCTURE STATUS:\n"
|
| 260 |
+
" CPU: 22% (normal)\n Memory: 3.1GB/8GB (normal)\n"
|
| 261 |
+
" Latency P99: 38ms (normal)\n"
|
| 262 |
+
" No infrastructure issues."
|
| 263 |
+
),
|
| 264 |
+
ActionType.CHECK_UPSTREAM_PIPELINE: (
|
| 265 |
+
"UPSTREAM PIPELINE STATUS:\n"
|
| 266 |
+
" Feature store: HEALTHY (but stale for 2 features)\n"
|
| 267 |
+
" Data ingestion: DEGRADED\n"
|
| 268 |
+
" - credit_bureau_connector_v1.2: returning nulls 8% of time\n"
|
| 269 |
+
" - Schema mismatch: 'annual_income' was USD→now returned in cents\n"
|
| 270 |
+
" - Last successful full refresh: 11 days ago\n"
|
| 271 |
+
" Other connectors: nominal\n"
|
| 272 |
+
" RECOMMENDATION: Fix credit bureau connector, add unit conversion"
|
| 273 |
+
),
|
| 274 |
+
},
|
| 275 |
+
max_steps=20,
|
| 276 |
+
extra_state={
|
| 277 |
+
"requires_pipeline_fix": True,
|
| 278 |
+
"requires_retrain": True,
|
| 279 |
+
},
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# ── Task 3: Hard — Adversarial + Silent Model Regression ────────────────────
|
| 284 |
+
|
| 285 |
+
TASK_HARD = TaskDefinition(
|
| 286 |
+
task_id="task_silent_regression",
|
| 287 |
+
name="Silent Model Regression with Adversarial Inputs",
|
| 288 |
+
difficulty="hard",
|
| 289 |
+
description=(
|
| 290 |
+
"INCIDENT: The content moderation model appears to be working normally — overall "
|
| 291 |
+
"accuracy is 0.96. However, trust & safety reports show a 5x increase in "
|
| 292 |
+
"user-reported harmful content slipping through. The model was retrained 3 days ago "
|
| 293 |
+
"with an automated pipeline. No alerts have fired because aggregate metrics look "
|
| 294 |
+
"fine. A community moderator escalated this manually. Investigate what's happening "
|
| 295 |
+
"and fix it. Be careful — hasty actions could take down moderation for millions of "
|
| 296 |
+
"users."
|
| 297 |
+
),
|
| 298 |
+
alerts=[
|
| 299 |
+
Alert(
|
| 300 |
+
severity=AlertSeverity.LOW,
|
| 301 |
+
message="Manual escalation: T&S reports 5x harmful content increase",
|
| 302 |
+
metric_name="user_reports_harmful",
|
| 303 |
+
current_value=847.0,
|
| 304 |
+
threshold=200.0,
|
| 305 |
+
timestamp="2026-03-27T06:00:00Z",
|
| 306 |
+
),
|
| 307 |
+
],
|
| 308 |
+
model_info=ModelInfo(
|
| 309 |
+
model_name="content-moderator",
|
| 310 |
+
model_version="v5.2.0",
|
| 311 |
+
deployed_at="2026-03-24T03:00:00Z",
|
| 312 |
+
framework="PyTorch (transformer)",
|
| 313 |
+
endpoint="/api/v1/moderate/content",
|
| 314 |
+
previous_versions=["v5.1.3", "v5.0.0", "v4.8.2"],
|
| 315 |
+
),
|
| 316 |
+
root_causes=[
|
| 317 |
+
"training_data_poisoning",
|
| 318 |
+
"adversarial_inputs",
|
| 319 |
+
"label_corruption",
|
| 320 |
+
"data_poisoning",
|
| 321 |
+
"silent_regression",
|
| 322 |
+
"class_specific_regression",
|
| 323 |
+
],
|
| 324 |
+
required_diagnostics=[
|
| 325 |
+
ActionType.INSPECT_METRICS,
|
| 326 |
+
ActionType.RUN_PREDICTION_SAMPLE,
|
| 327 |
+
ActionType.CHECK_DATA_DISTRIBUTION,
|
| 328 |
+
ActionType.CHECK_UPSTREAM_PIPELINE,
|
| 329 |
+
],
|
| 330 |
+
correct_remediations=[
|
| 331 |
+
ActionType.ROLLBACK_MODEL,
|
| 332 |
+
ActionType.ADD_FEATURE_GUARD,
|
| 333 |
+
],
|
| 334 |
+
diagnostic_results={
|
| 335 |
+
ActionType.INSPECT_METRICS: (
|
| 336 |
+
"METRICS DASHBOARD:\n"
|
| 337 |
+
" Model: content-moderator v5.2.0\n"
|
| 338 |
+
" Overall accuracy: 0.96 (looks normal!)\n"
|
| 339 |
+
" Overall precision: 0.94, recall: 0.93\n"
|
| 340 |
+
" Latency P99: 120ms (normal)\n"
|
| 341 |
+
" Requests/sec: 45,000 (normal)\n"
|
| 342 |
+
" ── Per-class breakdown ──\n"
|
| 343 |
+
" spam: precision=0.97, recall=0.98 (normal)\n"
|
| 344 |
+
" hate: precision=0.95, recall=0.42 ⚠ (was 0.91)\n"
|
| 345 |
+
" violence: precision=0.93, recall=0.38 ⚠ (was 0.89)\n"
|
| 346 |
+
" safe: precision=0.96, recall=0.99 (normal)\n"
|
| 347 |
+
" NOTE: Hate and violence recall collapsed but aggregate looks OK\n"
|
| 348 |
+
" because 'safe' class dominates (92% of traffic)."
|
| 349 |
+
),
|
| 350 |
+
ActionType.QUERY_LOGS: (
|
| 351 |
+
"LOG SEARCH RESULTS:\n"
|
| 352 |
+
" [Mar 24 02:30] Auto-retrain pipeline triggered\n"
|
| 353 |
+
" [Mar 24 02:45] Training dataset assembled: 2.4M samples\n"
|
| 354 |
+
" [Mar 24 02:46] Warning: 34,000 new labels from crowd-source batch #847\n"
|
| 355 |
+
" [Mar 24 03:00] Model v5.2.0 deployed (passed aggregate quality gate)\n"
|
| 356 |
+
" [Mar 24 03:01] A/B test: v5.2.0 vs v5.1.3 — aggregate acc comparable\n"
|
| 357 |
+
" [Mar 25] T&S team flags increase in harmful content reports\n"
|
| 358 |
+
" [Mar 26] Reports continue to climb\n"
|
| 359 |
+
" [Mar 27] Manual escalation from community moderator"
|
| 360 |
+
),
|
| 361 |
+
ActionType.CHECK_DATA_DISTRIBUTION: (
|
| 362 |
+
"DATA DISTRIBUTION COMPARISON:\n"
|
| 363 |
+
" Training data v5.2.0 vs v5.1.3:\n"
|
| 364 |
+
" ─────────────────────────────────────────────\n"
|
| 365 |
+
" Total samples: 2.4M (v5.2.0) vs 2.1M (v5.1.3)\n"
|
| 366 |
+
" New batch #847: 34,000 samples added\n"
|
| 367 |
+
" Label distribution in batch #847:\n"
|
| 368 |
+
" safe: 89% | hate: 7% | violence: 4% | spam: 0%\n"
|
| 369 |
+
" Cross-check with manual audit (100 random from batch #847):\n"
|
| 370 |
+
" Labeled 'safe' but actually hateful: 23/100 ⚠ CONTAMINATION\n"
|
| 371 |
+
" Labeled 'safe' but actually violent: 11/100 ⚠ CONTAMINATION\n"
|
| 372 |
+
" Correctly labeled: 66/100\n"
|
| 373 |
+
" CONCLUSION: Batch #847 contains systematically mislabeled data.\n"
|
| 374 |
+
" Hate/violence content was labeled as 'safe', teaching the model\n"
|
| 375 |
+
" to pass through harmful content."
|
| 376 |
+
),
|
| 377 |
+
ActionType.CHECK_FEATURE_IMPORTANCE: (
|
| 378 |
+
"ATTENTION ANALYSIS (transformer model):\n"
|
| 379 |
+
" Attention patterns on hate/violence content:\n"
|
| 380 |
+
" v5.1.3: High attention on slurs, threats, graphic terms\n"
|
| 381 |
+
" v5.2.0: Attention dispersed, less focus on key harmful tokens\n"
|
| 382 |
+
" The model appears to have 'unlearned' key harmful patterns\n"
|
| 383 |
+
" from the poisoned training batch."
|
| 384 |
+
),
|
| 385 |
+
ActionType.RUN_PREDICTION_SAMPLE: (
|
| 386 |
+
"PREDICTION SAMPLE (200 curated test samples, ground truth labels):\n"
|
| 387 |
+
" Category-level results:\n"
|
| 388 |
+
" ── Safe content (100 samples) ──\n"
|
| 389 |
+
" Correct: 98/100 (98%) — model correctly passes safe content\n"
|
| 390 |
+
" ── Hate speech (50 samples) ──\n"
|
| 391 |
+
" Detected: 21/50 (42%) ⚠ — model misses most hate speech\n"
|
| 392 |
+
" Missed examples include: coded language, dog-whistles, slurs\n"
|
| 393 |
+
" ── Violent content (50 samples) ──\n"
|
| 394 |
+
" Detected: 19/50 (38%) ⚠ — model misses most violence\n"
|
| 395 |
+
" Missed examples: graphic threats, incitement\n"
|
| 396 |
+
" Overall accuracy inflated by safe-class dominance."
|
| 397 |
+
),
|
| 398 |
+
ActionType.CHECK_INFRASTRUCTURE: (
|
| 399 |
+
"INFRASTRUCTURE STATUS:\n"
|
| 400 |
+
" GPU: 4x A100, utilization 67% (normal)\n"
|
| 401 |
+
" Memory: 28GB/80GB (normal)\n"
|
| 402 |
+
" Latency P99: 120ms (normal)\n"
|
| 403 |
+
" Throughput: 45K req/s (normal)\n"
|
| 404 |
+
" No infrastructure issues."
|
| 405 |
+
),
|
| 406 |
+
ActionType.CHECK_UPSTREAM_PIPELINE: (
|
| 407 |
+
"UPSTREAM PIPELINE STATUS:\n"
|
| 408 |
+
" Auto-retrain pipeline: HEALTHY (but concern flagged)\n"
|
| 409 |
+
" Training data sources:\n"
|
| 410 |
+
" - Internal labeled data: 2.1M samples (validated)\n"
|
| 411 |
+
" - Crowd-source batch #847: 34K samples (UNVALIDATED) ⚠\n"
|
| 412 |
+
" Quality gate: aggregate accuracy only — no per-class checks ⚠\n"
|
| 413 |
+
" The pipeline accepted v5.2.0 because overall accuracy was 0.96,\n"
|
| 414 |
+
" but it did not check per-class recall.\n"
|
| 415 |
+
" RECOMMENDATION: Add per-class recall gates. Remove batch #847.\n"
|
| 416 |
+
" Consider rollback to v5.1.3 while investigating."
|
| 417 |
+
),
|
| 418 |
+
},
|
| 419 |
+
max_steps=25,
|
| 420 |
+
extra_state={
|
| 421 |
+
"poisoned_batch": "#847",
|
| 422 |
+
"safe_rollback_version": "v5.1.3",
|
| 423 |
+
"requires_guard": True,
|
| 424 |
+
},
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# ── Task Registry ────────────────────────────────────────────────────────────
|
| 429 |
+
|
| 430 |
+
ALL_TASKS: dict[str, TaskDefinition] = {
|
| 431 |
+
TASK_EASY.task_id: TASK_EASY,
|
| 432 |
+
TASK_MEDIUM.task_id: TASK_MEDIUM,
|
| 433 |
+
TASK_HARD.task_id: TASK_HARD,
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
# ── Grader Functions ─────────────────────────────────────────────────────────
|
| 438 |
+
|
| 439 |
+
def grade_episode(
|
| 440 |
+
task: TaskDefinition,
|
| 441 |
+
actions_taken: list[dict],
|
| 442 |
+
diagnosis_submitted: dict | None,
|
| 443 |
+
remediation_applied: list[str],
|
| 444 |
+
total_steps: int,
|
| 445 |
+
) -> tuple[float, dict]:
|
| 446 |
+
"""
|
| 447 |
+
Grade an agent's performance on a task.
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
(score, breakdown) where score is 0.0–1.0 and breakdown is a dict
|
| 451 |
+
of component scores for interpretability.
|
| 452 |
+
"""
|
| 453 |
+
breakdown: dict[str, float] = {}
|
| 454 |
+
|
| 455 |
+
# ── 1. Diagnostic thoroughness (30%) ─────────────────────────────────
|
| 456 |
+
required = set(task.required_diagnostics)
|
| 457 |
+
performed = set()
|
| 458 |
+
for a in actions_taken:
|
| 459 |
+
try:
|
| 460 |
+
at = ActionType(a["action_type"])
|
| 461 |
+
if at in required:
|
| 462 |
+
performed.add(at)
|
| 463 |
+
except (ValueError, KeyError):
|
| 464 |
+
pass
|
| 465 |
+
|
| 466 |
+
if required:
|
| 467 |
+
diag_score = len(performed) / len(required)
|
| 468 |
+
else:
|
| 469 |
+
diag_score = 1.0
|
| 470 |
+
breakdown["diagnostic_thoroughness"] = round(diag_score, 3)
|
| 471 |
+
|
| 472 |
+
# ── 2. Correct diagnosis (30%) ───────────────────────────────────────
|
| 473 |
+
diag_submitted_score = 0.0
|
| 474 |
+
if diagnosis_submitted:
|
| 475 |
+
root = diagnosis_submitted.get("root_cause", "").lower().strip()
|
| 476 |
+
for valid in task.root_causes:
|
| 477 |
+
if valid.lower() in root or root in valid.lower():
|
| 478 |
+
diag_submitted_score = 1.0
|
| 479 |
+
break
|
| 480 |
+
# Partial credit: if they mention a related keyword
|
| 481 |
+
if diag_submitted_score == 0.0:
|
| 482 |
+
keywords = set()
|
| 483 |
+
for rc in task.root_causes:
|
| 484 |
+
keywords.update(rc.lower().replace("_", " ").split())
|
| 485 |
+
matches = sum(1 for kw in keywords if kw in root)
|
| 486 |
+
if matches > 0:
|
| 487 |
+
diag_submitted_score = min(0.5, matches * 0.2)
|
| 488 |
+
breakdown["diagnosis_accuracy"] = round(diag_submitted_score, 3)
|
| 489 |
+
|
| 490 |
+
# ── 3. Correct remediation (25%) ─────────────────────────────────────
|
| 491 |
+
correct_rems = set(r.value for r in task.correct_remediations)
|
| 492 |
+
applied = set(remediation_applied)
|
| 493 |
+
if correct_rems:
|
| 494 |
+
matched = len(applied & correct_rems)
|
| 495 |
+
rem_score = matched / len(correct_rems)
|
| 496 |
+
# Penalize wrong remediations
|
| 497 |
+
wrong = applied - correct_rems
|
| 498 |
+
penalty = len(wrong) * 0.15
|
| 499 |
+
rem_score = max(0.0, rem_score - penalty)
|
| 500 |
+
else:
|
| 501 |
+
rem_score = 1.0 if not applied else 0.5
|
| 502 |
+
breakdown["remediation_accuracy"] = round(rem_score, 3)
|
| 503 |
+
|
| 504 |
+
# ── 4. Efficiency (15%) ──────────────────────────────────────────────
|
| 505 |
+
if total_steps <= len(task.required_diagnostics) + 2:
|
| 506 |
+
eff_score = 1.0 # Very efficient
|
| 507 |
+
elif total_steps <= task.max_steps * 0.5:
|
| 508 |
+
eff_score = 0.8
|
| 509 |
+
elif total_steps <= task.max_steps * 0.75:
|
| 510 |
+
eff_score = 0.5
|
| 511 |
+
elif total_steps < task.max_steps:
|
| 512 |
+
eff_score = 0.3
|
| 513 |
+
else:
|
| 514 |
+
eff_score = 0.1 # Timed out
|
| 515 |
+
breakdown["efficiency"] = round(eff_score, 3)
|
| 516 |
+
|
| 517 |
+
# ── Weighted total ───────────────────────────────────────────────────
|
| 518 |
+
total = (
|
| 519 |
+
0.30 * breakdown["diagnostic_thoroughness"]
|
| 520 |
+
+ 0.30 * breakdown["diagnosis_accuracy"]
|
| 521 |
+
+ 0.25 * breakdown["remediation_accuracy"]
|
| 522 |
+
+ 0.15 * breakdown["efficiency"]
|
| 523 |
+
)
|
| 524 |
+
breakdown["total"] = round(total, 3)
|
| 525 |
+
|
| 526 |
+
return round(total, 3), breakdown
|
tests/test_environment.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Tests for the MLOps Firefighter environments."""
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
+
|
| 9 |
+
from models import ActionType, MLOpsAction
|
| 10 |
+
from server.environment import MLOpsFirefighterEnvironment
|
| 11 |
+
from tasks import ALL_TASKS, grade_episode
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_reset_returns_valid_observation():
|
| 15 |
+
env = MLOpsFirefighterEnvironment()
|
| 16 |
+
obs = env.reset(task_id="task_threshold_drift")
|
| 17 |
+
assert obs.done is False
|
| 18 |
+
assert obs.step_number == 0
|
| 19 |
+
assert obs.task_id == "task_threshold_drift"
|
| 20 |
+
assert len(obs.alerts) > 0
|
| 21 |
+
assert obs.model_info is not None
|
| 22 |
+
assert len(obs.available_actions) > 0
|
| 23 |
+
print("✓ test_reset_returns_valid_observation")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def test_step_diagnostic_action():
|
| 27 |
+
env = MLOpsFirefighterEnvironment()
|
| 28 |
+
env.reset(task_id="task_threshold_drift")
|
| 29 |
+
action = MLOpsAction(action_type=ActionType.INSPECT_METRICS)
|
| 30 |
+
obs = env.step(action)
|
| 31 |
+
assert obs.done is False
|
| 32 |
+
assert obs.reward > 0 # Useful diagnostic
|
| 33 |
+
assert obs.step_number == 1
|
| 34 |
+
assert "METRICS DASHBOARD" in obs.action_result
|
| 35 |
+
assert len(obs.diagnostics_gathered) == 1
|
| 36 |
+
print("✓ test_step_diagnostic_action")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def test_redundant_diagnostic():
|
| 40 |
+
env = MLOpsFirefighterEnvironment()
|
| 41 |
+
env.reset(task_id="task_threshold_drift")
|
| 42 |
+
action = MLOpsAction(action_type=ActionType.INSPECT_METRICS)
|
| 43 |
+
env.step(action)
|
| 44 |
+
obs2 = env.step(action)
|
| 45 |
+
assert obs2.reward == 0.05 # Redundant
|
| 46 |
+
assert "already ran" in obs2.action_result.lower()
|
| 47 |
+
print("✓ test_redundant_diagnostic")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_hasty_remediation_penalized():
|
| 51 |
+
env = MLOpsFirefighterEnvironment()
|
| 52 |
+
env.reset(task_id="task_threshold_drift")
|
| 53 |
+
# Apply fix without any diagnosis
|
| 54 |
+
action = MLOpsAction(
|
| 55 |
+
action_type=ActionType.ADJUST_THRESHOLD,
|
| 56 |
+
parameters={"new_threshold": 0.55},
|
| 57 |
+
)
|
| 58 |
+
obs = env.step(action)
|
| 59 |
+
assert obs.reward < 0 # Penalized
|
| 60 |
+
assert "WARNING" in obs.action_result
|
| 61 |
+
print("✓ test_hasty_remediation_penalized")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_correct_remediation_after_diagnosis():
|
| 65 |
+
env = MLOpsFirefighterEnvironment()
|
| 66 |
+
env.reset(task_id="task_threshold_drift")
|
| 67 |
+
# Run 2 diagnostics
|
| 68 |
+
env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
|
| 69 |
+
env.step(MLOpsAction(action_type=ActionType.RUN_PREDICTION_SAMPLE))
|
| 70 |
+
# Apply correct fix
|
| 71 |
+
obs = env.step(MLOpsAction(
|
| 72 |
+
action_type=ActionType.ADJUST_THRESHOLD,
|
| 73 |
+
parameters={"new_threshold": 0.55},
|
| 74 |
+
))
|
| 75 |
+
assert obs.reward == 1.0 # Correct remediation
|
| 76 |
+
print("✓ test_correct_remediation_after_diagnosis")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def test_submit_diagnosis_ends_episode():
|
| 80 |
+
env = MLOpsFirefighterEnvironment()
|
| 81 |
+
env.reset(task_id="task_threshold_drift")
|
| 82 |
+
env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
|
| 83 |
+
env.step(MLOpsAction(action_type=ActionType.RUN_PREDICTION_SAMPLE))
|
| 84 |
+
env.step(MLOpsAction(
|
| 85 |
+
action_type=ActionType.ADJUST_THRESHOLD,
|
| 86 |
+
parameters={"new_threshold": 0.55},
|
| 87 |
+
))
|
| 88 |
+
obs = env.step(MLOpsAction(
|
| 89 |
+
action_type=ActionType.SUBMIT_DIAGNOSIS,
|
| 90 |
+
parameters={"root_cause": "threshold_misconfiguration", "summary": "test"},
|
| 91 |
+
))
|
| 92 |
+
assert obs.done is True
|
| 93 |
+
assert obs.reward > 0
|
| 94 |
+
print("✓ test_submit_diagnosis_ends_episode")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def test_state_tracking():
|
| 98 |
+
env = MLOpsFirefighterEnvironment()
|
| 99 |
+
env.reset(task_id="task_threshold_drift")
|
| 100 |
+
env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
|
| 101 |
+
st = env.state()
|
| 102 |
+
assert st["step_count"] == 1
|
| 103 |
+
assert st["task_id"] == "task_threshold_drift"
|
| 104 |
+
assert st["done"] is False
|
| 105 |
+
assert st["actions_taken_count"] == 1
|
| 106 |
+
print("✓ test_state_tracking")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def test_timeout():
|
| 110 |
+
env = MLOpsFirefighterEnvironment()
|
| 111 |
+
env.reset(task_id="task_threshold_drift")
|
| 112 |
+
# Exhaust all steps with useless actions
|
| 113 |
+
for _ in range(15):
|
| 114 |
+
obs = env.step(MLOpsAction(action_type=ActionType.CHECK_INFRASTRUCTURE))
|
| 115 |
+
if obs.done:
|
| 116 |
+
break
|
| 117 |
+
assert obs.done is True
|
| 118 |
+
print("✓ test_timeout")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def test_grader_scores_in_range():
|
| 122 |
+
"""Verify all grader scores are between 0.0 and 1.0."""
|
| 123 |
+
for task_id, task in ALL_TASKS.items():
|
| 124 |
+
# Perfect run
|
| 125 |
+
score_perfect, _ = grade_episode(
|
| 126 |
+
task=task,
|
| 127 |
+
actions_taken=[{"action_type": d.value} for d in task.required_diagnostics],
|
| 128 |
+
diagnosis_submitted={"root_cause": task.root_causes[0], "summary": "test"},
|
| 129 |
+
remediation_applied=[r.value for r in task.correct_remediations],
|
| 130 |
+
total_steps=len(task.required_diagnostics) + len(task.correct_remediations) + 1,
|
| 131 |
+
)
|
| 132 |
+
assert 0.0 <= score_perfect <= 1.0, f"Perfect score out of range: {score_perfect}"
|
| 133 |
+
|
| 134 |
+
# Zero run
|
| 135 |
+
score_zero, _ = grade_episode(
|
| 136 |
+
task=task,
|
| 137 |
+
actions_taken=[],
|
| 138 |
+
diagnosis_submitted=None,
|
| 139 |
+
remediation_applied=[],
|
| 140 |
+
total_steps=task.max_steps,
|
| 141 |
+
)
|
| 142 |
+
assert 0.0 <= score_zero <= 1.0, f"Zero score out of range: {score_zero}"
|
| 143 |
+
|
| 144 |
+
print("✓ test_grader_scores_in_range")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def test_all_three_tasks_exist():
|
| 148 |
+
assert len(ALL_TASKS) >= 3
|
| 149 |
+
difficulties = {t.difficulty for t in ALL_TASKS.values()}
|
| 150 |
+
assert "easy" in difficulties
|
| 151 |
+
assert "medium" in difficulties
|
| 152 |
+
assert "hard" in difficulties
|
| 153 |
+
print("✓ test_all_three_tasks_exist")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def test_full_episode_easy():
|
| 157 |
+
"""Full integration test: perfect run on easy task."""
|
| 158 |
+
env = MLOpsFirefighterEnvironment()
|
| 159 |
+
env.reset(task_id="task_threshold_drift")
|
| 160 |
+
|
| 161 |
+
# Diagnostics
|
| 162 |
+
env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
|
| 163 |
+
env.step(MLOpsAction(action_type=ActionType.RUN_PREDICTION_SAMPLE))
|
| 164 |
+
|
| 165 |
+
# Fix
|
| 166 |
+
env.step(MLOpsAction(
|
| 167 |
+
action_type=ActionType.ADJUST_THRESHOLD,
|
| 168 |
+
parameters={"new_threshold": 0.55},
|
| 169 |
+
))
|
| 170 |
+
|
| 171 |
+
# Submit
|
| 172 |
+
obs = env.step(MLOpsAction(
|
| 173 |
+
action_type=ActionType.SUBMIT_DIAGNOSIS,
|
| 174 |
+
parameters={"root_cause": "threshold_misconfiguration", "summary": "test"},
|
| 175 |
+
))
|
| 176 |
+
|
| 177 |
+
assert obs.done is True
|
| 178 |
+
st = env.state()
|
| 179 |
+
score = st["grader_result"]["total"]
|
| 180 |
+
assert score >= 0.85, f"Expected high score for perfect run, got {score}"
|
| 181 |
+
print(f"✓ test_full_episode_easy (score={score:.3f})")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def test_full_episode_medium():
|
| 185 |
+
"""Full integration test: good run on medium task."""
|
| 186 |
+
env = MLOpsFirefighterEnvironment()
|
| 187 |
+
env.reset(task_id="task_data_drift")
|
| 188 |
+
|
| 189 |
+
env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
|
| 190 |
+
env.step(MLOpsAction(action_type=ActionType.CHECK_DATA_DISTRIBUTION))
|
| 191 |
+
env.step(MLOpsAction(action_type=ActionType.CHECK_UPSTREAM_PIPELINE))
|
| 192 |
+
|
| 193 |
+
env.step(MLOpsAction(action_type=ActionType.FIX_DATA_PIPELINE))
|
| 194 |
+
env.step(MLOpsAction(action_type=ActionType.RETRAIN_MODEL))
|
| 195 |
+
|
| 196 |
+
obs = env.step(MLOpsAction(
|
| 197 |
+
action_type=ActionType.SUBMIT_DIAGNOSIS,
|
| 198 |
+
parameters={"root_cause": "data_drift", "summary": "test"},
|
| 199 |
+
))
|
| 200 |
+
|
| 201 |
+
assert obs.done is True
|
| 202 |
+
st = env.state()
|
| 203 |
+
score = st["grader_result"]["total"]
|
| 204 |
+
assert score >= 0.80, f"Expected high score, got {score}"
|
| 205 |
+
print(f"✓ test_full_episode_medium (score={score:.3f})")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def test_full_episode_hard():
|
| 209 |
+
"""Full integration test: good run on hard task."""
|
| 210 |
+
env = MLOpsFirefighterEnvironment()
|
| 211 |
+
env.reset(task_id="task_silent_regression")
|
| 212 |
+
|
| 213 |
+
env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
|
| 214 |
+
env.step(MLOpsAction(action_type=ActionType.RUN_PREDICTION_SAMPLE))
|
| 215 |
+
env.step(MLOpsAction(action_type=ActionType.CHECK_DATA_DISTRIBUTION))
|
| 216 |
+
env.step(MLOpsAction(action_type=ActionType.CHECK_UPSTREAM_PIPELINE))
|
| 217 |
+
|
| 218 |
+
env.step(MLOpsAction(
|
| 219 |
+
action_type=ActionType.ROLLBACK_MODEL,
|
| 220 |
+
parameters={"target_version": "v5.1.3"},
|
| 221 |
+
))
|
| 222 |
+
env.step(MLOpsAction(action_type=ActionType.ADD_FEATURE_GUARD))
|
| 223 |
+
|
| 224 |
+
obs = env.step(MLOpsAction(
|
| 225 |
+
action_type=ActionType.SUBMIT_DIAGNOSIS,
|
| 226 |
+
parameters={"root_cause": "training_data_poisoning", "summary": "test"},
|
| 227 |
+
))
|
| 228 |
+
|
| 229 |
+
assert obs.done is True
|
| 230 |
+
st = env.state()
|
| 231 |
+
score = st["grader_result"]["total"]
|
| 232 |
+
assert score >= 0.80, f"Expected high score, got {score}"
|
| 233 |
+
print(f"✓ test_full_episode_hard (score={score:.3f})")
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
if __name__ == "__main__":
|
| 237 |
+
print("\n=== MLOps Firefighter Environment Tests ===\n")
|
| 238 |
+
test_reset_returns_valid_observation()
|
| 239 |
+
test_step_diagnostic_action()
|
| 240 |
+
test_redundant_diagnostic()
|
| 241 |
+
test_hasty_remediation_penalized()
|
| 242 |
+
test_correct_remediation_after_diagnosis()
|
| 243 |
+
test_submit_diagnosis_ends_episode()
|
| 244 |
+
test_state_tracking()
|
| 245 |
+
test_timeout()
|
| 246 |
+
test_grader_scores_in_range()
|
| 247 |
+
test_all_three_tasks_exist()
|
| 248 |
+
test_full_episode_easy()
|
| 249 |
+
test_full_episode_medium()
|
| 250 |
+
test_full_episode_hard()
|
| 251 |
+
print("\n=== All tests passed! ===\n")
|
validate.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Pre-submission validation script for the MLOps Firefighter environment.
|
| 4 |
+
|
| 5 |
+
Checks all requirements from the hackathon rubric:
|
| 6 |
+
1. openenv.yaml exists and is valid
|
| 7 |
+
2. Typed Pydantic models exist
|
| 8 |
+
3. step()/reset()/state() work correctly
|
| 9 |
+
4. 3+ tasks with graders
|
| 10 |
+
5. Grader scores in 0.0–1.0 range
|
| 11 |
+
6. All required endpoints respond
|
| 12 |
+
7. Baseline produces scores
|
| 13 |
+
8. Dockerfile exists
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import sys
|
| 18 |
+
import os
|
| 19 |
+
import yaml
|
| 20 |
+
|
| 21 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 22 |
+
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "server"))
|
| 23 |
+
|
| 24 |
+
PASS = "✅"
|
| 25 |
+
FAIL = "❌"
|
| 26 |
+
results = []
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def check(name: str, condition: bool, detail: str = ""):
|
| 30 |
+
status = PASS if condition else FAIL
|
| 31 |
+
results.append((name, condition))
|
| 32 |
+
msg = f" {status} {name}"
|
| 33 |
+
if detail:
|
| 34 |
+
msg += f" — {detail}"
|
| 35 |
+
print(msg)
|
| 36 |
+
return condition
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def main():
|
| 40 |
+
print("\n" + "=" * 60)
|
| 41 |
+
print(" MLOps Firefighter — Pre-Submission Validator")
|
| 42 |
+
print("=" * 60 + "\n")
|
| 43 |
+
|
| 44 |
+
# 1. openenv.yaml
|
| 45 |
+
print("[1/8] OpenEnv manifest (openenv.yaml)")
|
| 46 |
+
yaml_path = os.path.join(os.path.dirname(__file__), "openenv.yaml")
|
| 47 |
+
has_yaml = os.path.exists(yaml_path)
|
| 48 |
+
check("openenv.yaml exists", has_yaml)
|
| 49 |
+
if has_yaml:
|
| 50 |
+
with open(yaml_path) as f:
|
| 51 |
+
manifest = yaml.safe_load(f)
|
| 52 |
+
check("Has name", "name" in manifest)
|
| 53 |
+
check("Has version", "version" in manifest)
|
| 54 |
+
check("Has description", "description" in manifest)
|
| 55 |
+
check("Has tasks", "tasks" in manifest and len(manifest["tasks"]) >= 3)
|
| 56 |
+
check("Has 'openenv' tag", "openenv" in manifest.get("tags", []))
|
| 57 |
+
|
| 58 |
+
# 2. Typed Pydantic models
|
| 59 |
+
print("\n[2/8] Typed Pydantic models")
|
| 60 |
+
try:
|
| 61 |
+
from models import MLOpsAction, MLOpsObservation, ActionType
|
| 62 |
+
check("MLOpsAction importable", True)
|
| 63 |
+
check("MLOpsObservation importable", True)
|
| 64 |
+
check("ActionType enum exists", len(ActionType) >= 10)
|
| 65 |
+
# Verify they're Pydantic
|
| 66 |
+
a = MLOpsAction(action_type=ActionType.INSPECT_METRICS)
|
| 67 |
+
check("MLOpsAction is Pydantic", hasattr(a, "model_dump"))
|
| 68 |
+
except Exception as e:
|
| 69 |
+
check("Models import", False, str(e))
|
| 70 |
+
|
| 71 |
+
# 3. step()/reset()/state()
|
| 72 |
+
print("\n[3/8] Environment interface (reset/step/state)")
|
| 73 |
+
try:
|
| 74 |
+
from server.environment import MLOpsFirefighterEnvironment
|
| 75 |
+
env = MLOpsFirefighterEnvironment()
|
| 76 |
+
|
| 77 |
+
obs = env.reset(task_id="task_threshold_drift")
|
| 78 |
+
check("reset() returns observation", obs is not None)
|
| 79 |
+
check("reset() obs has done=False", obs.done is False)
|
| 80 |
+
check("reset() obs has step_number=0", obs.step_number == 0)
|
| 81 |
+
|
| 82 |
+
from models import MLOpsAction, ActionType
|
| 83 |
+
obs2 = env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
|
| 84 |
+
check("step() returns observation", obs2 is not None)
|
| 85 |
+
check("step() increments step_number", obs2.step_number == 1)
|
| 86 |
+
check("step() returns reward", isinstance(obs2.reward, float))
|
| 87 |
+
|
| 88 |
+
st = env.state()
|
| 89 |
+
check("state() returns dict", isinstance(st, dict))
|
| 90 |
+
check("state() has episode_id", "episode_id" in st)
|
| 91 |
+
check("state() has step_count", "step_count" in st)
|
| 92 |
+
except Exception as e:
|
| 93 |
+
check("Environment interface", False, str(e))
|
| 94 |
+
|
| 95 |
+
# 4. 3+ tasks
|
| 96 |
+
print("\n[4/8] Task definitions")
|
| 97 |
+
try:
|
| 98 |
+
from tasks import ALL_TASKS
|
| 99 |
+
check("3+ tasks defined", len(ALL_TASKS) >= 3)
|
| 100 |
+
difficulties = {t.difficulty for t in ALL_TASKS.values()}
|
| 101 |
+
check("Has easy task", "easy" in difficulties)
|
| 102 |
+
check("Has medium task", "medium" in difficulties)
|
| 103 |
+
check("Has hard task", "hard" in difficulties)
|
| 104 |
+
for tid, task in ALL_TASKS.items():
|
| 105 |
+
check(f"Task '{tid}' has root_causes", len(task.root_causes) > 0)
|
| 106 |
+
check(f"Task '{tid}' has diagnostics", len(task.required_diagnostics) > 0)
|
| 107 |
+
check(f"Task '{tid}' has remediations", len(task.correct_remediations) > 0)
|
| 108 |
+
except Exception as e:
|
| 109 |
+
check("Tasks", False, str(e))
|
| 110 |
+
|
| 111 |
+
# 5. Grader scores in range
|
| 112 |
+
print("\n[5/8] Grader scoring (0.0–1.0)")
|
| 113 |
+
try:
|
| 114 |
+
from tasks import grade_episode, ALL_TASKS
|
| 115 |
+
from models import ActionType
|
| 116 |
+
for tid, task in ALL_TASKS.items():
|
| 117 |
+
# Perfect
|
| 118 |
+
score, bd = grade_episode(
|
| 119 |
+
task=task,
|
| 120 |
+
actions_taken=[{"action_type": d.value} for d in task.required_diagnostics],
|
| 121 |
+
diagnosis_submitted={"root_cause": task.root_causes[0]},
|
| 122 |
+
remediation_applied=[r.value for r in task.correct_remediations],
|
| 123 |
+
total_steps=len(task.required_diagnostics) + 2,
|
| 124 |
+
)
|
| 125 |
+
check(f"'{tid}' perfect score in [0,1]", 0.0 <= score <= 1.0, f"{score:.3f}")
|
| 126 |
+
|
| 127 |
+
# Empty
|
| 128 |
+
score_z, _ = grade_episode(
|
| 129 |
+
task=task, actions_taken=[], diagnosis_submitted=None,
|
| 130 |
+
remediation_applied=[], total_steps=task.max_steps,
|
| 131 |
+
)
|
| 132 |
+
check(f"'{tid}' empty score in [0,1]", 0.0 <= score_z <= 1.0, f"{score_z:.3f}")
|
| 133 |
+
|
| 134 |
+
# Partial credit varies
|
| 135 |
+
check(f"'{tid}' grader differentiates", score > score_z, f"perfect={score:.3f} > empty={score_z:.3f}")
|
| 136 |
+
except Exception as e:
|
| 137 |
+
check("Grader", False, str(e))
|
| 138 |
+
|
| 139 |
+
# 6. All endpoints
|
| 140 |
+
print("\n[6/8] HTTP endpoints")
|
| 141 |
+
try:
|
| 142 |
+
from fastapi.testclient import TestClient
|
| 143 |
+
from server.app import app
|
| 144 |
+
client = TestClient(app)
|
| 145 |
+
|
| 146 |
+
r = client.get("/health")
|
| 147 |
+
check("/health returns 200", r.status_code == 200)
|
| 148 |
+
|
| 149 |
+
r = client.get("/tasks")
|
| 150 |
+
check("/tasks returns 200", r.status_code == 200)
|
| 151 |
+
check("/tasks has action_schema", "action_schema" in r.json())
|
| 152 |
+
|
| 153 |
+
r = client.post("/reset", json={"task_id": "task_threshold_drift"})
|
| 154 |
+
check("/reset returns 200", r.status_code == 200)
|
| 155 |
+
|
| 156 |
+
r = client.post("/step", json={"action_type": "inspect_metrics"})
|
| 157 |
+
check("/step returns 200", r.status_code == 200)
|
| 158 |
+
|
| 159 |
+
r = client.get("/state")
|
| 160 |
+
check("/state returns 200", r.status_code == 200)
|
| 161 |
+
|
| 162 |
+
# Complete an episode for grader test
|
| 163 |
+
client.post("/reset", json={"task_id": "task_threshold_drift"})
|
| 164 |
+
client.post("/step", json={"action_type": "inspect_metrics"})
|
| 165 |
+
client.post("/step", json={"action_type": "submit_diagnosis",
|
| 166 |
+
"parameters": {"root_cause": "test", "summary": "t"}})
|
| 167 |
+
r = client.post("/grader", json={})
|
| 168 |
+
check("/grader returns 200", r.status_code == 200)
|
| 169 |
+
|
| 170 |
+
r = client.post("/baseline")
|
| 171 |
+
check("/baseline returns 200", r.status_code == 200)
|
| 172 |
+
check("/baseline has scores", "average_score" in r.json())
|
| 173 |
+
except Exception as e:
|
| 174 |
+
check("Endpoints", False, str(e))
|
| 175 |
+
|
| 176 |
+
# 7. Baseline produces scores
|
| 177 |
+
print("\n[7/8] Baseline scoring")
|
| 178 |
+
try:
|
| 179 |
+
r = client.post("/baseline")
|
| 180 |
+
data = r.json()
|
| 181 |
+
avg = data["average_score"]
|
| 182 |
+
check("Baseline avg score > 0", avg > 0, f"avg={avg}")
|
| 183 |
+
for tid, result in data["baseline_results"].items():
|
| 184 |
+
s = result["score"]
|
| 185 |
+
check(f"Baseline '{tid}' in [0,1]", 0.0 <= s <= 1.0, f"{s:.3f}")
|
| 186 |
+
except Exception as e:
|
| 187 |
+
check("Baseline", False, str(e))
|
| 188 |
+
|
| 189 |
+
# 8. Dockerfile exists
|
| 190 |
+
print("\n[8/8] Dockerfile")
|
| 191 |
+
df_path = os.path.join(os.path.dirname(__file__), "Dockerfile")
|
| 192 |
+
check("Dockerfile exists", os.path.exists(df_path))
|
| 193 |
+
if os.path.exists(df_path):
|
| 194 |
+
with open(df_path) as f:
|
| 195 |
+
content = f.read()
|
| 196 |
+
check("Dockerfile has FROM", "FROM" in content)
|
| 197 |
+
check("Dockerfile has EXPOSE", "EXPOSE" in content)
|
| 198 |
+
check("Dockerfile has CMD", "CMD" in content)
|
| 199 |
+
|
| 200 |
+
# Summary
|
| 201 |
+
total = len(results)
|
| 202 |
+
passed = sum(1 for _, ok in results if ok)
|
| 203 |
+
failed = total - passed
|
| 204 |
+
|
| 205 |
+
print("\n" + "=" * 60)
|
| 206 |
+
if failed == 0:
|
| 207 |
+
print(f" {PASS} ALL {total} CHECKS PASSED — Ready to submit!")
|
| 208 |
+
else:
|
| 209 |
+
print(f" {FAIL} {failed}/{total} checks failed")
|
| 210 |
+
for name, ok in results:
|
| 211 |
+
if not ok:
|
| 212 |
+
print(f" - {name}")
|
| 213 |
+
print("=" * 60 + "\n")
|
| 214 |
+
|
| 215 |
+
return 0 if failed == 0 else 1
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
if __name__ == "__main__":
|
| 219 |
+
sys.exit(main())
|