Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| """ | |
| OpenEnv Validation CLI Tool | |
| Usage: | |
| python -m src.adaptive_alert_triage.validate | |
| openenv validate (if registered as entry point) | |
| Validates that the Adaptive Alert Triage environment meets the full OpenEnv | |
| interface specification: | |
| 1. Typed Observation, Action, and Reward Pydantic models | |
| 2. step(action) → returns (observation, reward, done, info) | |
| 3. reset() → returns initial observation | |
| 4. state() → returns current state | |
| 5. openenv.yaml with metadata | |
| """ | |
| import sys | |
| import json | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple | |
| import yaml | |
| from adaptive_alert_triage.env import AdaptiveAlertTriageEnv | |
| from adaptive_alert_triage.models import ( | |
| Action, | |
| Observation, | |
| Reward, | |
| Alert, | |
| EpisodeState, | |
| ) | |
class OpenEnvValidator:
    """Validates OpenEnv compliance of the Adaptive Alert Triage environment.

    Every ``validate_*`` method runs one group of related checks and records
    each individual result through :meth:`check`: names of passing checks are
    appended to ``checks_passed`` and ``(name, details)`` pairs of failing
    checks to ``checks_failed``. :meth:`run_all_checks` runs every group and
    prints a summary.

    Args:
        verbose: When True, progress and results are printed to stdout.
    """

    def __init__(self, verbose: bool = True):
        self.verbose = verbose
        self.checks_passed: List[str] = []
        self.checks_failed: List[Tuple[str, str]] = []

    def log(self, message: str, level: str = "INFO"):
        """Print ``message`` as ``[level] message`` when verbose output is on."""
        if self.verbose:
            print(f"[{level}] {message}")

    def check(self, name: str, condition: bool, details: str = "") -> bool:
        """Record and log a single check result.

        Args:
            name: Human-readable name of the check.
            condition: Truthy if the check passed.
            details: Optional context logged beneath the result line; for a
                failing check it is also stored in ``checks_failed``.

        Returns:
            True if ``condition`` is truthy, False otherwise.
        """
        if condition:
            self.checks_passed.append(name)
            self.log(f"✓ {name}", "PASS")
            if details:
                self.log(f" {details}", "INFO")
            return True
        self.checks_failed.append((name, details))
        self.log(f"✗ {name}", "FAIL")
        if details:
            self.log(f" {details}", "ERROR")
        return False

    def _check_all(self, checks: List[tuple]) -> bool:
        """Record every check in ``checks``; return True only if all passed.

        Each item is ``(name, condition)`` or ``(name, condition, details)``.
        Results are collected into a list *before* ``all()`` is applied so
        that checks after the first failure are still logged and recorded
        (``all()`` over a generator would stop at the first failure).
        """
        results = [self.check(*item) for item in checks]
        return all(results)

    def validate_pydantic_models(self) -> bool:
        """1. Check that the public models are Pydantic BaseModels."""
        self.log("\n=== Validating Pydantic Models ===", "INFO")
        from pydantic import BaseModel

        models = [
            ("Observation", Observation),
            ("Action", Action),
            ("Reward", Reward),
            ("EpisodeState", EpisodeState),
            ("Alert", Alert),
        ]
        return self._check_all(
            [
                (f"{label} is Pydantic BaseModel", issubclass(model, BaseModel))
                for label, model in models
            ]
        )

    def validate_required_fields(self) -> bool:
        """Check that the core models declare the fields this validator requires."""
        self.log("\n=== Validating Model Fields ===", "INFO")
        requirements = [
            (
                "Observation",
                Observation,
                {"alerts", "system_load", "queue_length", "time_remaining", "episode_step"},
            ),
            ("Action", Action, {"alert_id", "action_type"}),
            ("Reward", Reward, {"value", "components"}),
        ]
        checks = []
        for label, model, required in requirements:
            declared = set(model.model_fields)
            checks.append(
                (
                    f"{label} has required fields",
                    required.issubset(declared),
                    f"Fields: {', '.join(sorted(declared))}",
                )
            )
        return self._check_all(checks)

    def validate_serialization(self) -> bool:
        """Check that Action and Reward survive a JSON round-trip unchanged.

        Each model is checked independently, so an exception while building
        or round-tripping one model does not hide the result for the other.
        """
        self.log("\n=== Validating Serialization ===", "INFO")
        samples = [
            ("Action", Action, {"alert_id": "test", "action_type": "INVESTIGATE"}),
            ("Reward", Reward, {"value": 10.0, "components": {"test": 10.0}}),
        ]
        all_ok = True
        for label, model, kwargs in samples:
            name = f"{label} serialization round-trip"
            try:
                original = model(**kwargs)
                restored = model.model_validate_json(original.model_dump_json())
                # Compare every field, not just a single identifying attribute.
                ok = restored.model_dump() == original.model_dump()
                self.check(name, ok)
            except Exception as e:
                self.check(name, False, str(e))
                ok = False
            all_ok = all_ok and ok
        return all_ok

    def validate_reset_method(self) -> bool:
        """2. Check that reset() exists, returns an Observation, and is seedable."""
        self.log("\n=== Validating reset() Method ===", "INFO")
        try:
            env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
            has_method = hasattr(env, "reset")
            self.check("reset() method exists", has_method)
            if not has_method:
                return False

            obs = env.reset()
            returns_observation = isinstance(obs, Observation)
            self.check("reset() returns Observation", returns_observation)

            # Compare an env seeded through the constructor with one seeded
            # through reset(seed=...).
            # NOTE(review): only the alert counts in `env.alerts` (an
            # implementation attribute, not part of the OpenEnv interface) are
            # compared; comparing full observations would be stronger if alert
            # ids/timestamps are seed-derived -- confirm before tightening.
            env2 = AdaptiveAlertTriageEnv(task_id="easy")
            env2.reset(seed=42)
            is_reproducible = len(env.alerts) == len(env2.alerts)
            self.check("reset() is reproducible with seed", is_reproducible)

            return has_method and returns_observation and is_reproducible
        except Exception as e:
            self.check("reset() validation", False, str(e))
            return False

    def validate_step_method(self) -> bool:
        """3. Check that step(action) returns (Observation, Reward, bool, dict)."""
        self.log("\n=== Validating step() Method ===", "INFO")
        try:
            env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
            obs = env.reset()

            has_method = hasattr(env, "step")
            self.check("step() method exists", has_method)
            if not has_method:
                return False
            if not obs.alerts:
                # step() needs a real alert id to act on. Record the gap as a
                # failure rather than returning False without any trace.
                self.check(
                    "step() could be exercised",
                    False,
                    "reset() returned an observation with no alerts",
                )
                return False

            action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE")
            result = env.step(action)

            is_tuple = isinstance(result, tuple)
            self.check("step() returns tuple", is_tuple)
            if not is_tuple:
                return False
            correct_length = len(result) == 4
            self.check("step() returns 4-tuple", correct_length, f"Got {len(result)} elements")
            if not correct_length:
                return False

            next_obs, reward, done, info = result
            obs_ok = self.check("step() returns Observation", isinstance(next_obs, Observation))
            reward_ok = self.check("step() returns Reward", isinstance(reward, Reward))
            done_ok = self.check("step() returns bool (done)", isinstance(done, bool))
            info_ok = self.check("step() returns dict (info)", isinstance(info, dict))

            if info_ok:
                self.check(
                    "info contains 'processed_alerts'",
                    "processed_alerts" in info,
                    # str() first: info keys are not guaranteed to be strings.
                    f"Keys: {', '.join(sorted(map(str, info)))}",
                )
                self.check("info contains 'correlation_groups'", "correlation_groups" in info)

            return obs_ok and reward_ok and done_ok and info_ok
        except Exception as e:
            self.check("step() validation", False, str(e))
            return False

    def validate_state_method(self) -> bool:
        """4. Check that state() returns an EpisodeState with the expected parts."""
        self.log("\n=== Validating state() Method ===", "INFO")
        try:
            env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
            env.reset()

            has_method = hasattr(env, "state")
            self.check("state() method exists", has_method)
            if not has_method:
                return False

            state = env.state()
            is_episode_state = isinstance(state, EpisodeState)
            self.check("state() returns EpisodeState", is_episode_state)
            if not is_episode_state:
                return False

            has_observation = hasattr(state, "observation") and isinstance(
                state.observation, Observation
            )
            self.check("EpisodeState has observation (Observation)", has_observation)

            has_hidden_state = hasattr(state, "hidden_state") and isinstance(
                state.hidden_state, dict
            )
            self.check("EpisodeState has hidden_state (dict)", has_hidden_state)
            if has_hidden_state:
                self.check(
                    "hidden_state contains true_severities",
                    "true_severities" in state.hidden_state,
                )
                self.check(
                    "hidden_state contains correlation_groups",
                    "correlation_groups" in state.hidden_state,
                )

            self.check("EpisodeState has cumulative_reward", hasattr(state, "cumulative_reward"))

            return is_episode_state and has_observation and has_hidden_state
        except Exception as e:
            self.check("state() validation", False, str(e))
            return False

    def validate_openenv_yaml(self, yaml_path: str = "openenv.yaml") -> bool:
        """5. Check the openenv.yaml metadata file.

        Args:
            yaml_path: Location of the metadata file; relative paths are
                resolved against the current working directory.

        Returns:
            True if the file exists, parses to a mapping, and contains every
            required top-level field. Task and config checks are recorded but
            do not affect the return value.
        """
        self.log("\n=== Validating openenv.yaml ===", "INFO")
        try:
            path = Path(yaml_path)
            exists = path.exists()
            self.check("openenv.yaml exists", exists, str(path.absolute()))
            if not exists:
                return False

            with open(path, encoding="utf-8") as f:
                data = yaml.safe_load(f)
            is_dict = isinstance(data, dict)
            self.check("openenv.yaml is valid YAML dict", is_dict)
            if not is_dict:
                return False

            # A tuple (not a set) so the checks are reported in a stable order;
            # set iteration order over strings varies with hash randomization.
            required_fields = (
                ("name", "Environment name"),
                ("version", "Version string"),
                ("description", "Description"),
                ("tasks", "Task definitions"),
            )
            all_present = True
            for field, description in required_fields:
                present = field in data
                self.check(f"'{field}' present ({description})", present)
                all_present = all_present and present

            if "tasks" in data:
                tasks = data["tasks"]
                is_list = isinstance(tasks, list)
                self.check("tasks is a list", is_list, f"Got {type(tasks)}")
                if is_list:
                    self.check("tasks list is not empty", len(tasks) > 0, f"{len(tasks)} tasks defined")
                    # Non-mapping entries count as missing an id: `"id" in task`
                    # is a substring test for strings (e.g. "valid").
                    all_have_ids = all(isinstance(task, dict) and "id" in task for task in tasks)
                    task_ids = [
                        str(task.get("id", "?")) if isinstance(task, dict) else "?"
                        for task in tasks
                    ]
                    self.check("all tasks have 'id'", all_have_ids, f"IDs: {', '.join(task_ids)}")

            has_config = "config" in data
            self.check("'config' section present", has_config)
            config = data.get("config")
            if isinstance(config, dict) and "actions" in config:
                expected_actions = {"INVESTIGATE", "IGNORE", "ESCALATE", "DELAY"}
                # `or []` tolerates an empty `actions:` key (parsed as None).
                yaml_actions = {str(a) for a in config["actions"] or []}
                has_all_actions = expected_actions.issubset(yaml_actions)
                self.check(
                    "config.actions includes all required actions",
                    has_all_actions,
                    f"Found: {', '.join(sorted(yaml_actions))}",
                )

            return all_present
        except Exception as e:
            self.check("openenv.yaml validation", False, str(e))
            return False

    def validate_all_tasks(
        self, task_ids: Tuple[str, ...] = ("easy", "medium", "hard")
    ) -> bool:
        """Smoke-test reset(), one step(), and state() on every task.

        Args:
            task_ids: Task identifiers passed to ``AdaptiveAlertTriageEnv``.

        Returns:
            True if every task passed; each task is checked even after a
            failure in an earlier one.
        """
        self.log("\n=== Validating All Tasks ===", "INFO")
        all_ok = True
        for task_id in task_ids:
            name = f"Task '{task_id}' is OpenEnv compliant"
            try:
                env = AdaptiveAlertTriageEnv(task_id=task_id, seed=42)
                obs = env.reset()
                obs_ok = isinstance(obs, Observation)

                # When no alert is visible there is nothing to step on, so the
                # step part of the check is treated as passing.
                step_ok = True
                if obs.alerts:
                    action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE")
                    next_obs, reward, done, info = env.step(action)
                    step_ok = (
                        isinstance(next_obs, Observation)
                        and isinstance(reward, Reward)
                        and isinstance(done, bool)
                        and isinstance(info, dict)
                    )

                state_ok = isinstance(env.state(), EpisodeState)
                task_ok = obs_ok and step_ok and state_ok
                self.check(name, task_ok)
            except Exception as e:
                self.check(name, False, str(e))
                task_ok = False
            all_ok = all_ok and task_ok
        return all_ok

    def run_all_checks(self) -> bool:
        """Run every validation group, print a summary, and report success.

        Returns:
            True only if no check failed and every validation group succeeded.
        """
        # Start from a clean slate so repeated runs on one instance do not
        # double-count earlier results.
        self.checks_passed = []
        self.checks_failed = []

        self.log("=" * 60)
        self.log("OpenEnv Compliance Validator", "INFO")
        self.log("=" * 60)

        results = [
            self.validate_pydantic_models(),
            self.validate_required_fields(),
            self.validate_serialization(),
            self.validate_reset_method(),
            self.validate_step_method(),
            self.validate_state_method(),
            self.validate_openenv_yaml(),
            self.validate_all_tasks(),
        ]

        self.log("\n" + "=" * 60, "INFO")
        self.log("VALIDATION SUMMARY", "INFO")
        self.log("=" * 60, "INFO")
        total_passed = len(self.checks_passed)
        total_failed = len(self.checks_failed)
        total_checks = total_passed + total_failed
        self.log(f"Passed: {total_passed}/{total_checks}", "INFO")
        if self.checks_failed:
            self.log(f"Failed: {total_failed}/{total_checks}", "ERROR")
            for name, details in self.checks_failed:
                self.log(f" - {name}", "ERROR")
                if details:
                    self.log(f" {details}", "ERROR")
        else:
            self.log("All checks passed! ✓", "PASS")
        self.log("=" * 60 + "\n", "INFO")

        # Also honour the per-group results so a validator that returns False
        # without recording a failed check cannot be reported as success.
        return not self.checks_failed and all(results)
def main():
    """Run every OpenEnv compliance check from the command line.

    Terminates the process via ``sys.exit``: status 0 when all checks
    passed, status 1 otherwise.
    """
    passed = OpenEnvValidator(verbose=True).run_all_checks()
    exit_code = 0 if passed else 1
    sys.exit(exit_code)


if __name__ == "__main__":
    main()