""" Comprehensive OpenEnv Compliance Test Suite Validates that all OpenEnv interface requirements are met: 1. Typed Observation, Action, and Reward Pydantic models 2. step(action) → returns (observation, reward, done, info) 3. reset() → returns initial observation 4. state() → returns current state 5. openenv.yaml with metadata 6. Tested via openenv validate Run with: pytest tests/test_openenv_compliance.py -v """ import json from pathlib import Path from typing import Any, Dict import pytest import yaml from pydantic import BaseModel from adaptive_alert_triage.env import AdaptiveAlertTriageEnv from adaptive_alert_triage.models import ( Action, Observation, Reward, Alert, EpisodeState, ActionType, AlertType, ) # ============================================================================ # REQUIREMENT 1: Typed Pydantic Models # ============================================================================ class TestPydanticModels: """Verify Observation, Action, and Reward are properly typed Pydantic models.""" def test_observation_is_pydantic_model(self): """Observation must be a Pydantic BaseModel.""" assert issubclass(Observation, BaseModel), "Observation must inherit from Pydantic BaseModel" def test_action_is_pydantic_model(self): """Action must be a Pydantic BaseModel.""" assert issubclass(Action, BaseModel), "Action must inherit from Pydantic BaseModel" def test_reward_is_pydantic_model(self): """Reward must be a Pydantic BaseModel.""" assert issubclass(Reward, BaseModel), "Reward must inherit from Pydantic BaseModel" def test_episode_state_is_pydantic_model(self): """EpisodeState must be a Pydantic BaseModel.""" assert issubclass(EpisodeState, BaseModel), "EpisodeState must inherit from Pydantic BaseModel" def test_alert_is_pydantic_model(self): """Alert must be a Pydantic BaseModel.""" assert issubclass(Alert, BaseModel), "Alert must inherit from Pydantic BaseModel" def test_observation_has_required_fields(self): """Observation must have all required fields.""" required_fields = {"alerts", "system_load", "queue_length", "time_remaining", "episode_step", "resource_budget"} model_fields = set(Observation.model_fields.keys()) assert required_fields.issubset(model_fields), f"Missing fields: {required_fields - model_fields}" def test_action_has_required_fields(self): """Action must have alert_id and action_type.""" required_fields = {"alert_id", "action_type"} model_fields = set(Action.model_fields.keys()) assert required_fields.issubset(model_fields), f"Missing fields: {required_fields - model_fields}" def test_reward_has_required_fields(self): """Reward must have value and components.""" required_fields = {"value", "components"} model_fields = set(Reward.model_fields.keys()) assert required_fields.issubset(model_fields), f"Missing fields: {required_fields - model_fields}" def test_action_type_is_literal(self): """Validate ActionType literal values.""" valid_actions = {"INVESTIGATE", "IGNORE", "ESCALATE", "DELAY"} # Create an action with each type to verify validation for action_type in valid_actions: action = Action(alert_id="test", action_type=action_type) assert action.action_type == action_type def test_alert_type_is_literal(self): """Validate AlertType literal values.""" valid_types = {"CPU", "MEMORY", "DISK", "NETWORK", "APPLICATION", "SECURITY"} # Create an alert with each type for alert_type in valid_types: alert = Alert( id="test", visible_severity=0.5, confidence=0.8, alert_type=alert_type, age=0, ) assert alert.alert_type == alert_type def test_observation_serialization(self): """Observation must be JSON serializable.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset() # Should be able to serialize to model_dump_json json_str = obs.model_dump_json() assert isinstance(json_str, str) # Should be able to parse back parsed = json.loads(json_str) assert isinstance(parsed, dict) def test_action_serialization(self): """Action must be JSON serializable.""" action = Action(alert_id="alert_001", action_type="INVESTIGATE") json_str = action.model_dump_json() assert isinstance(json_str, str) parsed = json.loads(json_str) assert parsed["alert_id"] == "alert_001" assert parsed["action_type"] == "INVESTIGATE" def test_reward_serialization(self): """Reward must be JSON serializable.""" reward = Reward( value=10.0, components={"critical_handled": 10.0}, info={"alert_id": "alert_001"} ) json_str = reward.model_dump_json() assert isinstance(json_str, str) parsed = json.loads(json_str) assert parsed["value"] == 10.0 # ============================================================================ # REQUIREMENT 2: step(action) → (observation, reward, done, info) # ============================================================================ class TestStepInterface: """Verify step() method signature and return types.""" def test_step_exists(self): """Environment must have a step method.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) assert hasattr(env, "step"), "Environment must have step() method" def test_step_accepts_action(self): """step() must accept an Action parameter.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset() action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") result = env.step(action) assert result is not None, "step() should return a value" def test_step_returns_tuple(self): """step() must return a tuple of 4 elements.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset() action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") result = env.step(action) assert isinstance(result, tuple), "step() must return a tuple" assert len(result) == 4, "step() must return exactly 4 values" def test_step_returns_observation(self): """First return value must be Observation.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset() action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") next_obs, _, _, _ = env.step(action) assert isinstance(next_obs, Observation), "First return must be Observation" def test_step_returns_reward(self): """Second return value must be Reward.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset() action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") _, reward, _, _ = env.step(action) assert isinstance(reward, Reward), "Second return must be Reward" def test_step_returns_done(self): """Third return value must be bool (done flag).""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset() action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") _, _, done, _ = env.step(action) assert isinstance(done, bool), "Third return must be boolean (done flag)" def test_step_returns_info(self): """Fourth return value must be dict (info).""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset() action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") _, _, _, info = env.step(action) assert isinstance(info, dict), "Fourth return must be a dictionary (info)" def test_info_contains_processed_alerts(self): """info dict must contain processed_alerts.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset() action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") _, _, _, info = env.step(action) assert "processed_alerts" in info, "info must contain 'processed_alerts'" assert isinstance(info["processed_alerts"], list), "processed_alerts must be a list" def test_info_contains_correlation_groups(self): """info dict must contain correlation_groups.""" env = AdaptiveAlertTriageEnv(task_id="hard", seed=42) obs = env.reset() action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") _, _, _, info = env.step(action) assert "correlation_groups" in info, "info must contain 'correlation_groups'" assert isinstance(info["correlation_groups"], list), "correlation_groups must be a list" def test_info_contains_system_failure(self): """info dict should indicate system failure state.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset() action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") _, _, _, info = env.step(action) assert "system_failure" in info, "info should contain 'system_failure'" def test_reward_has_value(self): """Reward must have a numeric value.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset() action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") _, reward, _, _ = env.step(action) assert isinstance(reward.value, (int, float)), "Reward.value must be numeric" def test_observation_updated_after_step(self): """Observation should normally change after step().""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs_before = env.reset() action = Action(alert_id=obs_before.alerts[0].id, action_type="INVESTIGATE") obs_after, _, _, _ = env.step(action) # Episode step should have incremented assert obs_after.episode_step == obs_before.episode_step + 1 # ============================================================================ # REQUIREMENT 3: reset() → Observation # ============================================================================ class TestResetInterface: """Verify reset() method signature and return type.""" def test_reset_exists(self): """Environment must have a reset method.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) assert hasattr(env, "reset"), "Environment must have reset() method" def test_reset_returns_observation(self): """reset() must return an Observation.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset() assert isinstance(obs, Observation), "reset() must return an Observation" def test_reset_accepts_seed(self): """reset() should accept optional seed parameter.""" env = AdaptiveAlertTriageEnv(task_id="easy") obs = env.reset(seed=42) assert isinstance(obs, Observation), "reset(seed=...) should return Observation" def test_reset_accepts_options(self): """reset() should accept optional options parameter.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset(options={}) assert isinstance(obs, Observation), "reset(options=...) should return Observation" def test_reset_reproducibility(self): """Same seed should produce same initial observation.""" env1 = AdaptiveAlertTriageEnv(task_id="easy") obs1 = env1.reset(seed=42) env2 = AdaptiveAlertTriageEnv(task_id="easy") obs2 = env2.reset(seed=42) assert len(obs1.alerts) == len(obs2.alerts), "Same seed should produce same number of alerts" def test_reset_clears_episode_state(self): """reset() should clear episode state between calls.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs1 = env.reset() assert obs1.episode_step == 0, "Initial episode_step should be 0" # Take a step if obs1.alerts: action = Action(alert_id=obs1.alerts[0].id, action_type="INVESTIGATE") _, _, _, _ = env.step(action) # Reset again obs2 = env.reset(seed=99) assert obs2.episode_step == 0, "After reset, episode_step should be 0 again" # ============================================================================ # REQUIREMENT 4: state() → EpisodeState # ============================================================================ class TestStateInterface: """Verify state() method and return type.""" def test_state_exists(self): """Environment must have a state method.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) assert hasattr(env, "state"), "Environment must have state() method" def test_state_returns_episode_state(self): """state() must return an EpisodeState.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) env.reset() state = env.state() assert isinstance(state, EpisodeState), "state() must return an EpisodeState" def test_episode_state_contains_observation(self): """EpisodeState must contain current observation.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) env.reset() state = env.state() assert hasattr(state, "observation"), "EpisodeState must have observation" assert isinstance(state.observation, Observation), "observation must be an Observation" def test_episode_state_contains_hidden_state(self): """EpisodeState must contain hidden_state dict.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) env.reset() state = env.state() assert hasattr(state, "hidden_state"), "EpisodeState must have hidden_state" assert isinstance(state.hidden_state, dict), "hidden_state must be a dict" def test_hidden_state_contains_true_severities(self): """hidden_state must contain true_severities mapping.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) env.reset() state = env.state() assert "true_severities" in state.hidden_state, "hidden_state must contain true_severities" def test_hidden_state_contains_correlation_groups(self): """hidden_state must contain correlation_groups.""" env = AdaptiveAlertTriageEnv(task_id="hard", seed=42) env.reset() state = env.state() assert "correlation_groups" in state.hidden_state, "hidden_state must contain correlation_groups" def test_episode_state_contains_cumulative_reward(self): """EpisodeState must track cumulative_reward.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) env.reset() state = env.state() assert hasattr(state, "cumulative_reward"), "EpisodeState must have cumulative_reward" assert isinstance(state.cumulative_reward, (int, float)), "cumulative_reward must be numeric" def test_episode_state_contains_failures_count(self): """EpisodeState must track failures count.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) env.reset() state = env.state() assert hasattr(state, "failures_count"), "EpisodeState must have failures_count" assert isinstance(state.failures_count, int), "failures_count must be an integer" def test_episode_state_tracks_actions_taken(self): """EpisodeState should track actions taken.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) obs = env.reset() state_before = env.state() initial_action_count = len(state_before.actions_taken) # Take an action if obs.alerts: action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") _, _, _, _ = env.step(action) state_after = env.state() assert len(state_after.actions_taken) >= initial_action_count, "actions_taken should accumulate" # ============================================================================ # REQUIREMENT 5: openenv.yaml with metadata # ============================================================================ class TestOpenEnvYAML: """Verify openenv.yaml provides required metadata.""" def test_openenv_yaml_exists(self): """openenv.yaml must exist in project root.""" yaml_path = Path("openenv.yaml") assert yaml_path.exists(), f"openenv.yaml must exist at {yaml_path}" def test_openenv_yaml_is_valid_yaml(self): """openenv.yaml must be valid YAML.""" with open("openenv.yaml") as f: data = yaml.safe_load(f) assert isinstance(data, dict), "openenv.yaml must parse to a dictionary" def test_openenv_yaml_has_name(self): """openenv.yaml must have a 'name' field.""" with open("openenv.yaml") as f: data = yaml.safe_load(f) assert "name" in data, "openenv.yaml must have 'name' field" def test_openenv_yaml_has_version(self): """openenv.yaml must have a 'version' field.""" with open("openenv.yaml") as f: data = yaml.safe_load(f) assert "version" in data, "openenv.yaml must have 'version' field" def test_openenv_yaml_has_description(self): """openenv.yaml must have a 'description' field.""" with open("openenv.yaml") as f: data = yaml.safe_load(f) assert "description" in data, "openenv.yaml must have 'description' field" def test_openenv_yaml_has_tasks(self): """openenv.yaml must define tasks.""" with open("openenv.yaml") as f: data = yaml.safe_load(f) assert "tasks" in data, "openenv.yaml must have 'tasks' section" assert isinstance(data["tasks"], list), "tasks must be a list" assert len(data["tasks"]) > 0, "tasks list must not be empty" def test_openenv_yaml_tasks_have_ids(self): """Each task must have an 'id' field.""" with open("openenv.yaml") as f: data = yaml.safe_load(f) for task in data["tasks"]: assert "id" in task, f"Task missing 'id' field: {task}" def test_openenv_yaml_has_config(self): """openenv.yaml should have a 'config' section.""" with open("openenv.yaml") as f: data = yaml.safe_load(f) assert "config" in data, "openenv.yaml should have 'config' section" def test_openenv_yaml_config_has_actions(self): """config should define available actions.""" with open("openenv.yaml") as f: data = yaml.safe_load(f) assert "actions" in data["config"], "config must define 'actions'" expected_actions = {"INVESTIGATE", "IGNORE", "ESCALATE", "DELAY"} yaml_actions = set(data["config"]["actions"]) assert expected_actions.issubset(yaml_actions), f"config must include all standard actions" # ============================================================================ # REQUIREMENT 6: Validation Testing # ============================================================================ class TestOpenEnvValidation: """End-to-end OpenEnv compliance validation.""" def test_full_episode_workflow(self): """Complete episode following OpenEnv spec should work.""" env = AdaptiveAlertTriageEnv(task_id="easy", seed=42) # 1. Reset to get initial observation obs = env.reset() assert isinstance(obs, Observation) # 2. Run episode steps done = False episode_steps = 0 max_allowed_steps = env.max_steps + 5 # Allow some buffer while not done and episode_steps < max_allowed_steps: if not obs.alerts: break # Take an action action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") obs, reward, done, info = env.step(action) # Validate return types assert isinstance(obs, Observation) assert isinstance(reward, Reward) assert isinstance(done, bool) assert isinstance(info, dict) episode_steps += 1 # 3. Get final state final_state = env.state() assert isinstance(final_state, EpisodeState) def test_all_task_difficulties(self): """All task difficulties should be OpenEnv compliant.""" for task_id in ["easy", "medium", "hard"]: env = AdaptiveAlertTriageEnv(task_id=task_id, seed=42) # Reset obs = env.reset() assert isinstance(obs, Observation), f"reset() failed for {task_id}" # Step if obs.alerts: action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE") obs, reward, done, info = env.step(action) assert isinstance(obs, Observation) assert isinstance(reward, Reward) assert isinstance(done, bool) assert isinstance(info, dict) # State state = env.state() assert isinstance(state, EpisodeState), f"state() failed for {task_id}" def test_pydantic_validation(self): """Pydantic models should validate their fields.""" # Invalid action type should fail validation with pytest.raises(Exception): Action(alert_id="test", action_type="INVALID_ACTION") # Invalid alert type should fail validation with pytest.raises(Exception): Alert( id="test", visible_severity=0.5, confidence=0.8, alert_type="INVALID_TYPE", age=0, ) def test_serialization_round_trip(self): """Models should serialize/deserialize without data loss.""" action = Action( alert_id="alert_123", action_type="INVESTIGATE", metadata={"reason": "high severity"} ) # Serialize json_str = action.model_dump_json() # Deserialize restored = Action.model_validate_json(json_str) assert restored.alert_id == action.alert_id assert restored.action_type == action.action_type assert restored.metadata == action.metadata if __name__ == "__main__": pytest.main([__file__, "-v"])