mlops-firefighter / validate.py
ShubhanKamat's picture
MLOps Firefighter - OpenEnv environment
670f19f
#!/usr/bin/env python3
"""
Pre-submission validation script for the MLOps Firefighter environment.
Checks all requirements from the hackathon rubric:
1. openenv.yaml exists and is valid
2. Typed Pydantic models exist
3. step()/reset()/state() work correctly
4. 3+ tasks with graders
5. Grader scores in 0.0–1.0 range
6. All required endpoints respond
7. Baseline produces scores
8. Dockerfile exists
"""
import json
import sys
import os
import yaml
# Put this script's own directory and its "server" subdirectory at the front
# of the import path so the project modules imported inside main()
# (`models`, `tasks`, `server.*`) resolve regardless of the current working
# directory. abspath() is used because __file__ may be a relative path.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "server"))
# Status markers printed in front of every check line.
PASS = "✅"
FAIL = "❌"

# One (check name, passed) tuple per check() call, in call order.
# main() reads this list to build the final summary and the exit code.
results = []


def check(name: str, condition: bool, detail: str = "") -> bool:
    """Record and print the outcome of a single validation check.

    Args:
        name: Human-readable label of the check; also stored in ``results``.
        condition: Outcome of the check. Any object is accepted and is
            reduced to its truth value.
        detail: Optional extra text appended to the printed line.

    Returns:
        ``True`` if the check passed, ``False`` otherwise (always a real
        ``bool``, never the raw ``condition`` object).
    """
    # Coerce once so both ``results`` and the return value hold genuine
    # booleans even when a caller passes a truthy/falsy object (e.g. a list).
    passed = bool(condition)
    status = PASS if passed else FAIL
    results.append((name, passed))
    msg = f" {status} {name}"
    if detail:
        msg += f" — {detail}"
    print(msg)
    return passed
def _check_manifest(base_dir):
    """[1/8] Verify that openenv.yaml exists, parses, and has the required keys."""
    print("[1/8] OpenEnv manifest (openenv.yaml)")
    yaml_path = os.path.join(base_dir, "openenv.yaml")
    has_yaml = os.path.exists(yaml_path)
    check("openenv.yaml exists", has_yaml)
    if not has_yaml:
        return

    # A malformed or unreadable manifest must be reported as a failed check,
    # not abort the whole validator with a traceback.
    try:
        with open(yaml_path, encoding="utf-8") as f:
            manifest = yaml.safe_load(f)
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
        check("openenv.yaml is valid YAML", False, str(e))
        return

    # safe_load() returns None for an empty file and may return a list or a
    # scalar; the key lookups below only make sense on a mapping.
    if not isinstance(manifest, dict):
        check("openenv.yaml is a mapping", False, f"got {type(manifest).__name__}")
        return

    check("Has name", "name" in manifest)
    check("Has version", "version" in manifest)
    check("Has description", "description" in manifest)
    # A key that is present but empty (`tasks:`) loads as None, so guard the
    # type before calling len().
    tasks = manifest.get("tasks")
    check("Has tasks", isinstance(tasks, (list, dict)) and len(tasks) >= 3)
    tags = manifest.get("tags") or []
    check("Has 'openenv' tag", "openenv" in tags)


def _check_models():
    """[2/8] Verify that the typed action/observation models import and are Pydantic."""
    print("\n[2/8] Typed Pydantic models")
    try:
        # MLOpsObservation is imported only to prove that it is importable.
        from models import MLOpsAction, MLOpsObservation, ActionType  # noqa: F401
        check("MLOpsAction importable", True)
        check("MLOpsObservation importable", True)
        check("ActionType enum exists", len(ActionType) >= 10)
        # The presence of model_dump() is used as the marker of a Pydantic model.
        action = MLOpsAction(action_type=ActionType.INSPECT_METRICS)
        check("MLOpsAction is Pydantic", hasattr(action, "model_dump"))
    except Exception as e:
        check("Models import", False, str(e))


def _check_environment():
    """[3/8] Exercise reset(), step() and state() on a fresh environment."""
    print("\n[3/8] Environment interface (reset/step/state)")
    try:
        from server.environment import MLOpsFirefighterEnvironment
        env = MLOpsFirefighterEnvironment()
        obs = env.reset(task_id="task_threshold_drift")
        check("reset() returns observation", obs is not None)
        check("reset() obs has done=False", obs.done is False)
        check("reset() obs has step_number=0", obs.step_number == 0)

        from models import MLOpsAction, ActionType
        obs2 = env.step(MLOpsAction(action_type=ActionType.INSPECT_METRICS))
        check("step() returns observation", obs2 is not None)
        check("step() increments step_number", obs2.step_number == 1)
        check("step() returns reward", isinstance(obs2.reward, float))

        st = env.state()
        check("state() returns dict", isinstance(st, dict))
        check("state() has episode_id", "episode_id" in st)
        check("state() has step_count", "step_count" in st)
    except Exception as e:
        check("Environment interface", False, str(e))


def _check_tasks():
    """[4/8] Verify that 3+ tasks exist, span all difficulties, and are fully specified."""
    print("\n[4/8] Task definitions")
    try:
        from tasks import ALL_TASKS
        check("3+ tasks defined", len(ALL_TASKS) >= 3)
        difficulties = {t.difficulty for t in ALL_TASKS.values()}
        check("Has easy task", "easy" in difficulties)
        check("Has medium task", "medium" in difficulties)
        check("Has hard task", "hard" in difficulties)
        for tid, task in ALL_TASKS.items():
            check(f"Task '{tid}' has root_causes", len(task.root_causes) > 0)
            check(f"Task '{tid}' has diagnostics", len(task.required_diagnostics) > 0)
            check(f"Task '{tid}' has remediations", len(task.correct_remediations) > 0)
    except Exception as e:
        check("Tasks", False, str(e))


def _check_grader():
    """[5/8] Verify grader scores stay in [0, 1] and separate a perfect run from an empty one."""
    print("\n[5/8] Grader scoring (0.0–1.0)")
    try:
        from tasks import grade_episode, ALL_TASKS
        for tid, task in ALL_TASKS.items():
            # Perfect episode: every required diagnostic, the first listed
            # root cause, and every correct remediation.
            score, _ = grade_episode(
                task=task,
                actions_taken=[{"action_type": d.value} for d in task.required_diagnostics],
                diagnosis_submitted={"root_cause": task.root_causes[0]},
                remediation_applied=[r.value for r in task.correct_remediations],
                total_steps=len(task.required_diagnostics) + 2,
            )
            check(f"'{tid}' perfect score in [0,1]", 0.0 <= score <= 1.0, f"{score:.3f}")

            # Empty episode: nothing done, step budget exhausted.
            score_z, _ = grade_episode(
                task=task, actions_taken=[], diagnosis_submitted=None,
                remediation_applied=[], total_steps=task.max_steps,
            )
            check(f"'{tid}' empty score in [0,1]", 0.0 <= score_z <= 1.0, f"{score_z:.3f}")

            # The grader must not hand out the same score regardless of behavior.
            check(f"'{tid}' grader differentiates", score > score_z,
                  f"perfect={score:.3f} > empty={score_z:.3f}")
    except Exception as e:
        check("Grader", False, str(e))


def _check_endpoints():
    """[6/8] Hit every required HTTP endpoint through an in-process test client.

    Returns:
        The test client if it could be constructed (even when a later
        endpoint check failed), otherwise ``None``.
    """
    print("\n[6/8] HTTP endpoints")
    client = None
    try:
        from fastapi.testclient import TestClient
        from server.app import app
        client = TestClient(app)

        r = client.get("/health")
        check("/health returns 200", r.status_code == 200)
        r = client.get("/tasks")
        check("/tasks returns 200", r.status_code == 200)
        check("/tasks has action_schema", "action_schema" in r.json())
        r = client.post("/reset", json={"task_id": "task_threshold_drift"})
        check("/reset returns 200", r.status_code == 200)
        r = client.post("/step", json={"action_type": "inspect_metrics"})
        check("/step returns 200", r.status_code == 200)
        r = client.get("/state")
        check("/state returns 200", r.status_code == 200)

        # Complete an episode for grader test
        client.post("/reset", json={"task_id": "task_threshold_drift"})
        client.post("/step", json={"action_type": "inspect_metrics"})
        client.post("/step", json={"action_type": "submit_diagnosis",
                                   "parameters": {"root_cause": "test", "summary": "t"}})
        r = client.post("/grader", json={})
        check("/grader returns 200", r.status_code == 200)

        r = client.post("/baseline")
        check("/baseline returns 200", r.status_code == 200)
        check("/baseline has scores", "average_score" in r.json())
    except Exception as e:
        check("Endpoints", False, str(e))
    return client


def _check_baseline(client):
    """[7/8] Verify that POST /baseline yields a positive average and per-task scores in [0, 1]."""
    print("\n[7/8] Baseline scoring")
    # Without a client the request below would die with a NameError-style
    # message that hides the real cause (already reported by the endpoint
    # checks), so report the dependency explicitly instead.
    if client is None:
        check("Baseline", False, "HTTP test client unavailable (see endpoint checks)")
        return
    try:
        r = client.post("/baseline")
        data = r.json()
        avg = data["average_score"]
        check("Baseline avg score > 0", avg > 0, f"avg={avg}")
        for tid, result in data["baseline_results"].items():
            s = result["score"]
            check(f"Baseline '{tid}' in [0,1]", 0.0 <= s <= 1.0, f"{s:.3f}")
    except Exception as e:
        check("Baseline", False, str(e))


def _check_dockerfile(base_dir):
    """[8/8] Verify that a Dockerfile exists and contains FROM, EXPOSE and CMD."""
    print("\n[8/8] Dockerfile")
    df_path = os.path.join(base_dir, "Dockerfile")
    has_dockerfile = os.path.exists(df_path)
    check("Dockerfile exists", has_dockerfile)
    if not has_dockerfile:
        return
    try:
        # errors="replace" keeps a stray non-UTF-8 byte from aborting the scan;
        # only ASCII keywords are searched for.
        with open(df_path, encoding="utf-8", errors="replace") as f:
            content = f.read()
    except OSError as e:
        check("Dockerfile readable", False, str(e))
        return
    check("Dockerfile has FROM", "FROM" in content)
    check("Dockerfile has EXPOSE", "EXPOSE" in content)
    check("Dockerfile has CMD", "CMD" in content)


def _print_summary():
    """Print the pass/fail summary of everything in ``results``; return the exit code."""
    total = len(results)
    failed_names = [name for name, ok in results if not ok]
    failed = len(failed_names)
    print("\n" + "=" * 60)
    if failed == 0:
        print(f" {PASS} ALL {total} CHECKS PASSED — Ready to submit!")
    else:
        print(f" {FAIL} {failed}/{total} checks failed")
        for name in failed_names:
            print(f" - {name}")
    print("=" * 60 + "\n")
    return 0 if failed == 0 else 1


def main() -> int:
    """Run all eight pre-submission validation stages and print a summary.

    Each stage records its outcomes through ``check()``. A stage that raises
    is reported as a single failed check rather than aborting the run, so
    every stage is always attempted.

    Returns:
        ``0`` if every recorded check passed, ``1`` otherwise (suitable as a
        process exit status).
    """
    print("\n" + "=" * 60)
    print(" MLOps Firefighter — Pre-Submission Validator")
    print("=" * 60 + "\n")

    # Resolve project files relative to this script, not the working directory.
    base_dir = os.path.dirname(os.path.abspath(__file__))

    _check_manifest(base_dir)
    _check_models()
    _check_environment()
    _check_tasks()
    _check_grader()
    client = _check_endpoints()
    _check_baseline(client)
    _check_dockerfile(base_dir)

    return _print_summary()
if __name__ == "__main__":
    # Run the validator and use main()'s return value as the process exit status.
    raise SystemExit(main())