"""Integration tests for HTTP endpoints.
Covers: /health, /tasks, /grader, /baseline, /dashboard.
Also tests the internal _run_heuristic_episode and _run_baseline_sync.
"""
from __future__ import annotations
import pytest
from fastapi.testclient import TestClient
from server.app import ALL_TASKS, app
from server._heuristic import (
    _get_score,
    _run_heuristic_episode,
    run_baseline_all_tasks as _run_baseline_sync,
)
from server.environment import MLTrainingEnvironment


@pytest.fixture
def client():
    return TestClient(app)
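
# TestClient exercises the ASGI app in-process: requests below are dispatched
# straight to the route handlers, so no live server or network socket is needed.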


# ---------- /health ----------

class TestHealthEndpoint:
    def test_returns_healthy(self, client):
        resp = client.get("/health")
        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] == "healthy"
        assert data["tasks"] == 7

    def test_task_count_matches_all_tasks(self, client):
        resp = client.get("/health")
        assert resp.json()["tasks"] == len(ALL_TASKS)
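
# The checks above pin down the payload shape
# {"status": "healthy", "tasks": <len(ALL_TASKS)>}; extra fields the endpoint
# may return are not asserted here.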


# ---------- /tasks ----------

class TestTasksEndpoint:
    def test_returns_seven_tasks(self, client):
        resp = client.get("/tasks")
        assert resp.status_code == 200
        tasks = resp.json()
        assert len(tasks) == 7
        ids = [t["id"] for t in tasks]
        assert "task_001" in ids
        assert "task_006" in ids

    def test_tasks_have_action_schema(self, client):
        resp = client.get("/tasks")
        tasks = resp.json()
        for task in tasks:
            assert "action_schema" in task
            assert "properties" in task["action_schema"]

    def test_tasks_have_difficulty_and_max_steps(self, client):
        resp = client.get("/tasks")
        for task in resp.json():
            assert "difficulty" in task
            assert task["difficulty"] in ("easy", "medium", "hard", "medium-hard")
            assert "max_steps" in task
            assert task["max_steps"] > 0


# ---------- /grader ----------

class TestGraderEndpoint:
    def test_no_completed_episode(self, client):
        import server._baseline_results as br
        br._last_results.clear()
        resp = client.post("/grader")
        assert resp.status_code == 200
        data = resp.json()
        assert data["score"] is None
        assert data["error"] == "no_completed_episode"

    def test_grader_after_completed_episode(self, client):
        """Run a quick episode then verify /grader returns a score."""
        import server._baseline_results as br
        br._last_results.clear()
        # Run a minimal episode via the internal function
        env = MLTrainingEnvironment()
        env.reset(seed=42, episode_id="grader_test", task_id="task_001")
        score = _run_heuristic_episode(env, "task_001")
        assert 0.0 <= score <= 1.0
        # Now the grader endpoint should return the stored result
        resp = client.post("/grader")
        data = resp.json()
        assert data["score"] is not None
        assert 0.0 <= data["score"] <= 1.0

    def test_grader_with_session_id(self, client):
        """Grader can filter by session_id."""
        import server._baseline_results as br
        br._last_results.clear()
        resp = client.post("/grader?session_id=nonexistent_session")
        data = resp.json()
        assert data["score"] is None


# ---------- /baseline ----------

class TestBaselineEndpoint:
    def test_baseline_returns_scores(self, client):
        resp = client.post("/baseline")
        assert resp.status_code == 200
        data = resp.json()
        assert "scores" in data
        scores = data["scores"]
        assert len(scores) == 7
        for task_id, score in scores.items():
            assert 0.0 <= score <= 1.0, f"{task_id}: {score}"

    def test_baseline_scores_in_valid_range(self, client):
        resp = client.post("/baseline")
        scores = resp.json()["scores"]
        values = list(scores.values())
        assert all(0.0 <= v <= 1.0 for v in values), "Scores must be in [0.0, 1.0]"
        assert len(values) >= 3, "Need at least 3 tasks"


# ---------- /dashboard ----------

class TestDashboardEndpoint:
    def test_returns_html(self, client):
        resp = client.get("/dashboard")
        assert resp.status_code == 200
        assert "Plotly" in resp.text
        assert "WebSocket" in resp.text


# ---------- Internal heuristic functions ----------

class TestRunHeuristicEpisode:
    """Test the internal baseline heuristic logic in server/_heuristic.py."""

    def test_task_001_exploding(self):
        env = MLTrainingEnvironment()
        env.reset(seed=42, episode_id="h_001", task_id="task_001")
        score = _run_heuristic_episode(env, "task_001")
        assert score == 1.0

    def test_task_002_vanishing(self):
        env = MLTrainingEnvironment()
        env.reset(seed=42, episode_id="h_002", task_id="task_002")
        score = _run_heuristic_episode(env, "task_002")
        assert score == 1.0

    def test_task_003_leakage(self):
        env = MLTrainingEnvironment()
        env.reset(seed=42, episode_id="h_003", task_id="task_003")
        score = _run_heuristic_episode(env, "task_003")
        assert score >= 0.9

    def test_task_004_overfitting(self):
        env = MLTrainingEnvironment()
        env.reset(seed=42, episode_id="h_004", task_id="task_004")
        score = _run_heuristic_episode(env, "task_004")
        assert 0.0 < score <= 1.0

    def test_task_005_batchnorm(self):
        env = MLTrainingEnvironment()
        env.reset(seed=42, episode_id="h_005", task_id="task_005")
        score = _run_heuristic_episode(env, "task_005")
        assert 0.0 < score <= 1.0

    def test_task_006_code_bug(self):
        env = MLTrainingEnvironment()
        env.reset(seed=42, episode_id="h_006", task_id="task_006")
        score = _run_heuristic_episode(env, "task_006")
        assert score >= 0.4
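
    # task_007 has no per-task expectation above; a loose sketch, assuming only
    # the [0.0, 1.0] score range that the baseline tests below already pin down.
    def test_task_007_in_range(self):
        env = MLTrainingEnvironment()
        env.reset(seed=42, episode_id="h_007", task_id="task_007")
        score = _run_heuristic_episode(env, "task_007")
        assert 0.0 <= score <= 1.0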


class TestGetScore:
    def test_no_session(self):
        env = MLTrainingEnvironment()
        assert _get_score(env) == 0.0

    def test_with_session(self):
        env = MLTrainingEnvironment()
        env.reset(seed=42, episode_id="gs_test", task_id="task_001")
        _run_heuristic_episode(env, "task_001")
        assert _get_score(env) >= 0.0


class TestRunBaselineSync:
    def test_returns_all_tasks(self):
        scores = _run_baseline_sync()
        assert len(scores) == 7
        for task_id in [
            "task_001",
            "task_002",
            "task_003",
            "task_004",
            "task_005",
            "task_006",
            "task_007",
        ]:
            assert task_id in scores
            assert 0.0 <= scores[task_id] <= 1.0

    def test_reproducible(self):
        scores1 = _run_baseline_sync()
        scores2 = _run_baseline_sync()
        assert scores1 == scores2
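
# test_reproducible holds only if run_baseline_all_tasks reseeds each
# environment identically on every call; that determinism is assumed from the
# fixed seed=42 used throughout, not verified against the server code.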