# incident-triage-env / tests/test_env.py
# Author: XcodeAddy
# Last commit af2ccc5: "Keep reset and schema rewards inside unit interval"
import unittest
from fastapi.testclient import TestClient
from app import app, completed_states, sessions
from environment import IncidentEnv, validate_ticket_dataset
from incidents import TICKETS
from models import IncidentAction, IncidentState, TaskType
class IncidentEnvApiTests(unittest.TestCase):
    """End-to-end tests for the incident-triage FastAPI app and ``IncidentEnv``.

    The app keeps episodes in the module-level ``sessions`` and
    ``completed_states`` stores, so every test starts and ends with both
    stores cleared to stay independent of test order.
    """

    def setUp(self) -> None:
        sessions.clear()
        completed_states.clear()
        self.client = TestClient(app)

    def tearDown(self) -> None:
        sessions.clear()
        completed_states.clear()

    def _start_episode(self, task_type: str, ticket_id: str) -> dict:
        """POST /reset for ``ticket_id`` and return the decoded response body.

        The status code is asserted first so a failing reset is reported as a
        reset failure rather than as a KeyError when the caller reads the body.
        """
        response = self.client.post(
            "/reset",
            json={"task_type": task_type, "ticket_id": ticket_id},
        )
        self.assertEqual(response.status_code, 200)
        return response.json()

    def test_health_schema_and_mcp_helper_endpoints(self) -> None:
        health_response = self.client.get("/health")
        self.assertEqual(health_response.status_code, 200)
        self.assertEqual(health_response.json()["status"], "healthy")

        schema_response = self.client.get("/schema")
        self.assertEqual(schema_response.status_code, 200)
        schema_body = schema_response.json()
        self.assertIn("action", schema_body)
        self.assertIn("observation", schema_body)
        self.assertIn("state", schema_body)

        # The MCP endpoint must echo the JSON-RPC envelope (version and id).
        mcp_response = self.client.post("/mcp", json={"jsonrpc": "2.0", "id": 1, "method": "ping"})
        self.assertEqual(mcp_response.status_code, 200)
        mcp_body = mcp_response.json()
        self.assertEqual(mcp_body["jsonrpc"], "2.0")
        self.assertEqual(mcp_body["id"], 1)

        grader_response = self.client.get("/grader")
        self.assertEqual(grader_response.status_code, 200)
        grader_body = grader_response.json()
        self.assertIn("notes", grader_body)
        self.assertIn("task2", grader_body["notes"])

    def test_tickets_endpoint_returns_safe_ticket_inventory(self) -> None:
        response = self.client.get("/tickets")
        self.assertEqual(response.status_code, 200)
        body = response.json()
        self.assertEqual(body["count"], len(TICKETS))
        self.assertEqual(body["tickets"][0]["incident_id"], "INC-001")
        self.assertIn("expected_field", body["tickets"][0])
        # The public inventory must never leak the answers.
        self.assertNotIn("ground_truth", body["tickets"][0])

    def test_ui_routes_and_assets_are_served(self) -> None:
        home_response = self.client.get("/")
        self.assertEqual(home_response.status_code, 200)
        self.assertIn("Incident Triage Environment", home_response.text)

        status_response = self.client.get("/status")
        self.assertEqual(status_response.status_code, 200)
        self.assertIn("Environment readiness dashboard", status_response.text)

        playground_response = self.client.get("/playground")
        self.assertEqual(playground_response.status_code, 200)
        self.assertIn("Interactive playground", playground_response.text)

        api_response = self.client.get("/api")
        self.assertEqual(api_response.status_code, 200)
        self.assertIn("API Explorer", api_response.text)

        asset_response = self.client.get("/assets/app.js")
        self.assertEqual(asset_response.status_code, 200)
        self.assertIn("bootstrap", asset_response.text)

    def test_reset_returns_requested_ticket_and_session_state(self) -> None:
        body = self._start_episode("task3", "INC-014")
        self.assertEqual(body["observation"]["incident_id"], "INC-014")
        self.assertEqual(body["observation"]["task_type"], "task3")
        # Rewards travel through JSON as floats; compare with a tolerance
        # instead of exact equality.
        self.assertAlmostEqual(body["reward"]["value"], 0.01)
        self.assertFalse(body["done"])
        self.assertIn("session_id", body["info"])
        self.assertEqual(body["info"]["state"]["status"], "awaiting_action")

    def test_reset_without_seed_is_deterministic_for_same_task(self) -> None:
        first_response = self.client.post("/reset", json={"task_type": "task2"})
        second_response = self.client.post("/reset", json={"task_type": "task2"})
        self.assertEqual(first_response.status_code, 200)
        self.assertEqual(second_response.status_code, 200)
        self.assertEqual(
            first_response.json()["observation"]["incident_id"],
            second_response.json()["observation"]["incident_id"],
        )

    def test_step_completes_episode_and_state_endpoint_reflects_completion(self) -> None:
        session_id = self._start_episode("task3", "INC-014")["info"]["session_id"]
        failover_action = {
            "incident_id": "INC-014",
            "task_type": "task3",
            "action": "FAILOVER",
        }

        step_response = self.client.post(
            "/step", params={"session_id": session_id}, json=failover_action
        )
        self.assertEqual(step_response.status_code, 200)
        step_body = step_response.json()
        self.assertTrue(step_body["done"])
        self.assertAlmostEqual(step_body["reward"]["value"], 0.99)
        self.assertTrue(step_body["info"]["correct"])
        self.assertEqual(step_body["info"]["ground_truth"], "FAILOVER")

        state_response = self.client.get("/state", params={"session_id": session_id})
        self.assertEqual(state_response.status_code, 200)
        state_body = state_response.json()
        self.assertTrue(state_body["done"])
        self.assertEqual(state_body["status"], "completed")
        self.assertAlmostEqual(state_body["last_reward"], 0.99)

        # A finished episode moves from the live store to completed_states.
        self.assertNotIn(session_id, sessions)
        self.assertIn(session_id, completed_states)

        # Stepping an already-completed session must be rejected.
        repeated_step_response = self.client.post(
            "/step", params={"session_id": session_id}, json=failover_action
        )
        self.assertEqual(repeated_step_response.status_code, 400)
        self.assertIn("already completed", repeated_step_response.json()["detail"])

    def test_step_rejects_action_for_wrong_task_type(self) -> None:
        session_id = self._start_episode("task3", "INC-014")["info"]["session_id"]
        # The episode is task3, but the action claims task2.
        step_response = self.client.post(
            "/step",
            params={"session_id": session_id},
            json={
                "incident_id": "INC-014",
                "task_type": "task2",
                "root_cause": "NETWORK",
            },
        )
        self.assertEqual(step_response.status_code, 400)
        self.assertIn("does not match", step_response.json()["detail"])

    def test_dataset_validation_rejects_empty_ground_truth(self) -> None:
        with self.assertRaisesRegex(RuntimeError, "empty ground_truth"):
            validate_ticket_dataset(
                [
                    {
                        "incident_id": "INC-BAD",
                        "task_type": "task1",
                        "alert_text": "Broken test ticket",
                        "context": {},
                        "ground_truth": {},
                    }
                ]
            )

    def test_step_raises_clear_dataset_error_for_invalid_ground_truth(self) -> None:
        # Inject a malformed ticket directly, bypassing reset(), so step()
        # has to detect the empty ground_truth itself.
        env = IncidentEnv()
        env.current_ticket = {
            "incident_id": "INC-BAD",
            "task_type": "task1",
            "alert_text": "Broken test ticket",
            "context": {},
            "ground_truth": {},
        }
        env.episode_id = "episode-bad"
        with self.assertRaisesRegex(RuntimeError, "dataset integrity error"):
            env.step(
                IncidentAction(
                    incident_id="INC-BAD",
                    task_type="task1",
                    severity="SEV1",
                )
            )

    def test_lifespan_shutdown_clears_session_stores(self) -> None:
        sessions["active-session"] = IncidentEnv()
        completed_states["done-session"] = IncidentState(
            episode_id="episode-1",
            session_id="done-session",
            step_count=1,
            max_steps=1,
            total_reward=0.99,
            done=True,
            incident_id="INC-001",
            task_type=TaskType.TASK1,
            difficulty="easy",
            status="completed",
            last_reward=0.99,
        )
        # Using TestClient as a context manager runs the app's lifespan
        # startup on entry and shutdown on exit.
        with TestClient(app) as client:
            response = client.get("/health")
            self.assertEqual(response.status_code, 200)
        self.assertEqual(sessions, {})
        self.assertEqual(completed_states, {})
if __name__ == "__main__":
    # Script entry point: run this module's tests with the stdlib runner.
    # ``__name__`` is "__main__" here, matching unittest.main's default module.
    unittest.main(module=__name__)