Spaces:
Sleeping
Sleeping
Commit ·
d6fcbb0
1
Parent(s): 7cd2458
ee
Browse files- tasks/easy.py +5 -11
- tasks/hard.py +4 -14
- tasks/medium.py +5 -15
- validate.py +0 -431
tasks/easy.py
CHANGED
|
@@ -161,21 +161,15 @@ class EasyTaskGrader:
|
|
| 161 |
|
| 162 |
def get_episode_score(self) -> float:
|
| 163 |
"""
|
| 164 |
-
Return final normalised score in (0, 1).
|
| 165 |
-
|
| 166 |
-
Formula: 0.01 + 0.98 * (correct_actions / total_actions)
|
| 167 |
-
This ensures the score is always strictly between 0 and 1 as
|
| 168 |
-
required by the grading system.
|
| 169 |
"""
|
| 170 |
if self.total_actions == 0:
|
| 171 |
-
return 0.
|
| 172 |
|
| 173 |
raw = self.correct_actions / self.total_actions
|
| 174 |
-
#
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
# Ensure no rounding to boundaries (0.0 or 1.0)
|
| 178 |
-
return max(0.01, min(rounded, 0.99))
|
| 179 |
|
| 180 |
|
| 181 |
def passed(self) -> bool:
|
|
|
|
| 161 |
|
| 162 |
def get_episode_score(self) -> float:
|
| 163 |
"""
|
| 164 |
+
Return final normalised score strictly in (0, 1) — never 0.0 or 1.0.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
"""
|
| 166 |
if self.total_actions == 0:
|
| 167 |
+
return 0.5
|
| 168 |
|
| 169 |
raw = self.correct_actions / self.total_actions
|
| 170 |
+
# Map [0,1] -> (0,1) with a small epsilon margin, no rounding
|
| 171 |
+
score = 0.001 + 0.998 * float(raw)
|
| 172 |
+
return max(0.001, min(0.999, score))
|
|
|
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
def passed(self) -> bool:
|
tasks/hard.py
CHANGED
|
@@ -372,19 +372,11 @@ class HardTaskGrader:
|
|
| 372 |
|
| 373 |
def get_episode_score(self) -> float:
|
| 374 |
"""
|
| 375 |
-
Return final normalised score in (0, 1).
|
| 376 |
-
|
| 377 |
-
Formula:
|
| 378 |
-
chain_score = Σ chain.outcome_score()
|
| 379 |
-
stability = _stability_score(system_failures)
|
| 380 |
-
base = (raw * stability)
|
| 381 |
-
clamped = 0.01 + 0.98 * base
|
| 382 |
"""
|
| 383 |
-
# Chain component
|
| 384 |
chain_score = sum(c.outcome_score() for c in self._chains.values())
|
| 385 |
max_chain = sum(c.max_possible() for c in self._chains.values())
|
| 386 |
|
| 387 |
-
# Isolation bonus (capped)
|
| 388 |
isolation = min(
|
| 389 |
self._isolation_correct * _ISOLATION_BONUS_PER_ALERT,
|
| 390 |
_ISOLATION_BONUS_CAP,
|
|
@@ -396,11 +388,9 @@ class HardTaskGrader:
|
|
| 396 |
stability = self._stability_score(self._system_failures)
|
| 397 |
final_base = max(0.0, min(raw * stability, 1.0))
|
| 398 |
|
| 399 |
-
#
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
# Ensure no rounding to boundaries (0.0 or 1.0)
|
| 403 |
-
return max(0.01, min(rounded, 0.99))
|
| 404 |
|
| 405 |
|
| 406 |
def passed(self) -> bool:
|
|
|
|
| 372 |
|
| 373 |
def get_episode_score(self) -> float:
|
| 374 |
"""
|
| 375 |
+
Return final normalised score strictly in (0, 1) — never 0.0 or 1.0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
"""
|
|
|
|
| 377 |
chain_score = sum(c.outcome_score() for c in self._chains.values())
|
| 378 |
max_chain = sum(c.max_possible() for c in self._chains.values())
|
| 379 |
|
|
|
|
| 380 |
isolation = min(
|
| 381 |
self._isolation_correct * _ISOLATION_BONUS_PER_ALERT,
|
| 382 |
_ISOLATION_BONUS_CAP,
|
|
|
|
| 388 |
stability = self._stability_score(self._system_failures)
|
| 389 |
final_base = max(0.0, min(raw * stability, 1.0))
|
| 390 |
|
| 391 |
+
# Map [0,1] -> (0,1) with a small epsilon margin, no rounding
|
| 392 |
+
score = 0.001 + 0.998 * float(final_base)
|
| 393 |
+
return max(0.001, min(0.999, score))
|
|
|
|
|
|
|
| 394 |
|
| 395 |
|
| 396 |
def passed(self) -> bool:
|
tasks/medium.py
CHANGED
|
@@ -192,27 +192,19 @@ class MediumTaskGrader:
|
|
| 192 |
|
| 193 |
def get_episode_score(self) -> float:
|
| 194 |
"""
|
| 195 |
-
Return final normalised score in (0, 1).
|
| 196 |
-
|
| 197 |
-
Formula:
|
| 198 |
-
raw = resolved_score / max_possible_score
|
| 199 |
-
base = max(0.0, raw − fp_penalty − miss_penalty)
|
| 200 |
-
clamped = 0.01 + 0.98 * base
|
| 201 |
"""
|
| 202 |
if self._max_possible_score <= 0.0:
|
| 203 |
-
return 0.
|
| 204 |
|
| 205 |
-
# Normalised resolved quality
|
| 206 |
raw = min(self._resolved_score / self._max_possible_score, 1.0)
|
| 207 |
|
| 208 |
-
# Penalty 1: wasted investigation budget on false positives
|
| 209 |
if self._total_investigations > 0:
|
| 210 |
fp_rate = self._unnecessary_invest / self._total_investigations
|
| 211 |
else:
|
| 212 |
fp_rate = 0.0
|
| 213 |
fp_penalty = _FP_PENALTY_WEIGHT * fp_rate
|
| 214 |
|
| 215 |
-
# Penalty 2: missed critical alerts
|
| 216 |
if self._critical_total > 0:
|
| 217 |
miss_rate = min(self._critical_missed / self._critical_total, 1.0)
|
| 218 |
else:
|
|
@@ -220,11 +212,9 @@ class MediumTaskGrader:
|
|
| 220 |
miss_penalty = _CRITICAL_MISS_PENALTY_WEIGHT * miss_rate
|
| 221 |
|
| 222 |
base_score = max(0.0, raw - fp_penalty - miss_penalty)
|
| 223 |
-
#
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
# Ensure no rounding to boundaries (0.0 or 1.0)
|
| 227 |
-
return max(0.01, min(rounded, 0.99))
|
| 228 |
|
| 229 |
|
| 230 |
def passed(self) -> bool:
|
|
|
|
| 192 |
|
| 193 |
def get_episode_score(self) -> float:
|
| 194 |
"""
|
| 195 |
+
Return final normalised score strictly in (0, 1) — never 0.0 or 1.0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
"""
|
| 197 |
if self._max_possible_score <= 0.0:
|
| 198 |
+
return 0.5
|
| 199 |
|
|
|
|
| 200 |
raw = min(self._resolved_score / self._max_possible_score, 1.0)
|
| 201 |
|
|
|
|
| 202 |
if self._total_investigations > 0:
|
| 203 |
fp_rate = self._unnecessary_invest / self._total_investigations
|
| 204 |
else:
|
| 205 |
fp_rate = 0.0
|
| 206 |
fp_penalty = _FP_PENALTY_WEIGHT * fp_rate
|
| 207 |
|
|
|
|
| 208 |
if self._critical_total > 0:
|
| 209 |
miss_rate = min(self._critical_missed / self._critical_total, 1.0)
|
| 210 |
else:
|
|
|
|
| 212 |
miss_penalty = _CRITICAL_MISS_PENALTY_WEIGHT * miss_rate
|
| 213 |
|
| 214 |
base_score = max(0.0, raw - fp_penalty - miss_penalty)
|
| 215 |
+
# Map [0,1] -> (0,1) with a small epsilon margin, no rounding
|
| 216 |
+
score = 0.001 + 0.998 * float(base_score)
|
| 217 |
+
return max(0.001, min(0.999, score))
|
|
|
|
|
|
|
| 218 |
|
| 219 |
|
| 220 |
def passed(self) -> bool:
|
validate.py
DELETED
|
@@ -1,431 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
"""
|
| 3 |
-
OpenEnv Validation CLI Tool
|
| 4 |
-
|
| 5 |
-
Usage:
|
| 6 |
-
python -m src.adaptive_alert_triage.validate
|
| 7 |
-
openenv validate (if registered as entry point)
|
| 8 |
-
|
| 9 |
-
Validates that the Adaptive Alert Triage environment meets the full OpenEnv
|
| 10 |
-
interface specification:
|
| 11 |
-
1. Typed Observation, Action, and Reward Pydantic models
|
| 12 |
-
2. step(action) → returns (observation, reward, done, info)
|
| 13 |
-
3. reset() → returns initial observation
|
| 14 |
-
4. state() → returns current state
|
| 15 |
-
5. openenv.yaml with metadata
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
import sys
|
| 19 |
-
import json
|
| 20 |
-
from pathlib import Path
|
| 21 |
-
from typing import Dict, List, Tuple
|
| 22 |
-
import yaml
|
| 23 |
-
|
| 24 |
-
from adaptive_alert_triage.env import AdaptiveAlertTriageEnv
|
| 25 |
-
from adaptive_alert_triage.models import (
|
| 26 |
-
Action,
|
| 27 |
-
Observation,
|
| 28 |
-
Reward,
|
| 29 |
-
Alert,
|
| 30 |
-
EpisodeState,
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
class OpenEnvValidator:
|
| 35 |
-
"""Validates OpenEnv compliance of the environment."""
|
| 36 |
-
|
| 37 |
-
def __init__(self, verbose: bool = True):
|
| 38 |
-
self.verbose = verbose
|
| 39 |
-
self.checks_passed = []
|
| 40 |
-
self.checks_failed = []
|
| 41 |
-
|
| 42 |
-
def log(self, message: str, level: str = "INFO"):
|
| 43 |
-
"""Log a message with level."""
|
| 44 |
-
if self.verbose:
|
| 45 |
-
print(f"[{level}] {message}")
|
| 46 |
-
|
| 47 |
-
def check(self, name: str, condition: bool, details: str = "") -> bool:
|
| 48 |
-
"""Record a check result."""
|
| 49 |
-
if condition:
|
| 50 |
-
self.checks_passed.append(name)
|
| 51 |
-
self.log(f"✓ {name}", "PASS")
|
| 52 |
-
if details:
|
| 53 |
-
self.log(f" {details}", "INFO")
|
| 54 |
-
return True
|
| 55 |
-
else:
|
| 56 |
-
self.checks_failed.append((name, details))
|
| 57 |
-
self.log(f"✗ {name}", "FAIL")
|
| 58 |
-
if details:
|
| 59 |
-
self.log(f" {details}", "ERROR")
|
| 60 |
-
return False
|
| 61 |
-
|
| 62 |
-
def validate_pydantic_models(self) -> bool:
|
| 63 |
-
"""1. Check that models are Pydantic BaseModels."""
|
| 64 |
-
self.log("\n=== Validating Pydantic Models ===", "INFO")
|
| 65 |
-
|
| 66 |
-
from pydantic import BaseModel
|
| 67 |
-
|
| 68 |
-
checks = [
|
| 69 |
-
("Observation is Pydantic BaseModel", issubclass(Observation, BaseModel)),
|
| 70 |
-
("Action is Pydantic BaseModel", issubclass(Action, BaseModel)),
|
| 71 |
-
("Reward is Pydantic BaseModel", issubclass(Reward, BaseModel)),
|
| 72 |
-
("EpisodeState is Pydantic BaseModel", issubclass(EpisodeState, BaseModel)),
|
| 73 |
-
("Alert is Pydantic BaseModel", issubclass(Alert, BaseModel)),
|
| 74 |
-
]
|
| 75 |
-
|
| 76 |
-
return all(self.check(name, cond) for name, cond in checks)
|
| 77 |
-
|
| 78 |
-
def validate_required_fields(self) -> bool:
|
| 79 |
-
"""Check that models have required fields."""
|
| 80 |
-
self.log("\n=== Validating Model Fields ===", "INFO")
|
| 81 |
-
|
| 82 |
-
checks = [
|
| 83 |
-
(
|
| 84 |
-
"Observation has required fields",
|
| 85 |
-
{"alerts", "system_load", "queue_length", "time_remaining", "episode_step"}.issubset(
|
| 86 |
-
set(Observation.model_fields.keys())
|
| 87 |
-
),
|
| 88 |
-
f"Fields: {', '.join(sorted(Observation.model_fields.keys()))}"
|
| 89 |
-
),
|
| 90 |
-
(
|
| 91 |
-
"Action has required fields",
|
| 92 |
-
{"alert_id", "action_type"}.issubset(set(Action.model_fields.keys())),
|
| 93 |
-
f"Fields: {', '.join(sorted(Action.model_fields.keys()))}"
|
| 94 |
-
),
|
| 95 |
-
(
|
| 96 |
-
"Reward has required fields",
|
| 97 |
-
{"value", "components"}.issubset(set(Reward.model_fields.keys())),
|
| 98 |
-
f"Fields: {', '.join(sorted(Reward.model_fields.keys()))}"
|
| 99 |
-
),
|
| 100 |
-
]
|
| 101 |
-
|
| 102 |
-
return all(self.check(name, cond, details) for name, cond, details in checks)
|
| 103 |
-
|
| 104 |
-
def validate_serialization(self) -> bool:
|
| 105 |
-
"""Check that models can be serialized/deserialized."""
|
| 106 |
-
self.log("\n=== Validating Serialization ===", "INFO")
|
| 107 |
-
|
| 108 |
-
try:
|
| 109 |
-
# Test Action serialization
|
| 110 |
-
action = Action(alert_id="test", action_type="INVESTIGATE")
|
| 111 |
-
json_str = action.model_dump_json()
|
| 112 |
-
restored = Action.model_validate_json(json_str)
|
| 113 |
-
action_ok = restored.alert_id == action.alert_id
|
| 114 |
-
self.check("Action serialization round-trip", action_ok)
|
| 115 |
-
|
| 116 |
-
# Test Reward serialization
|
| 117 |
-
reward = Reward(value=10.0, components={"test": 10.0})
|
| 118 |
-
json_str = reward.model_dump_json()
|
| 119 |
-
restored = Reward.model_validate_json(json_str)
|
| 120 |
-
reward_ok = restored.value == reward.value
|
| 121 |
-
self.check("Reward serialization round-trip", reward_ok)
|
| 122 |
-
|
| 123 |
-
return action_ok and reward_ok
|
| 124 |
-
except Exception as e:
|
| 125 |
-
self.check("Serialization", False, str(e))
|
| 126 |
-
return False
|
| 127 |
-
|
| 128 |
-
def validate_reset_method(self) -> bool:
|
| 129 |
-
"""2. Check reset() method."""
|
| 130 |
-
self.log("\n=== Validating reset() Method ===", "INFO")
|
| 131 |
-
|
| 132 |
-
try:
|
| 133 |
-
env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
|
| 134 |
-
|
| 135 |
-
# Check method exists
|
| 136 |
-
has_method = hasattr(env, "reset")
|
| 137 |
-
self.check("reset() method exists", has_method)
|
| 138 |
-
if not has_method:
|
| 139 |
-
return False
|
| 140 |
-
|
| 141 |
-
# Check return type
|
| 142 |
-
obs = env.reset()
|
| 143 |
-
returns_observation = isinstance(obs, Observation)
|
| 144 |
-
self.check("reset() returns Observation", returns_observation)
|
| 145 |
-
|
| 146 |
-
# Check reproducibility
|
| 147 |
-
env2 = AdaptiveAlertTriageEnv(task_id="easy")
|
| 148 |
-
obs2 = env2.reset(seed=42)
|
| 149 |
-
is_reproducible = len(env.alerts) == len(env2.alerts)
|
| 150 |
-
self.check("reset() is reproducible with seed", is_reproducible)
|
| 151 |
-
|
| 152 |
-
return has_method and returns_observation and is_reproducible
|
| 153 |
-
except Exception as e:
|
| 154 |
-
self.check("reset() validation", False, str(e))
|
| 155 |
-
return False
|
| 156 |
-
|
| 157 |
-
def validate_step_method(self) -> bool:
|
| 158 |
-
"""3. Check step(action) method."""
|
| 159 |
-
self.log("\n=== Validating step() Method ===", "INFO")
|
| 160 |
-
|
| 161 |
-
try:
|
| 162 |
-
env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
|
| 163 |
-
obs = env.reset()
|
| 164 |
-
|
| 165 |
-
# Check method exists
|
| 166 |
-
has_method = hasattr(env, "step")
|
| 167 |
-
self.check("step() method exists", has_method)
|
| 168 |
-
if not has_method or not obs.alerts:
|
| 169 |
-
return False
|
| 170 |
-
|
| 171 |
-
# Take a step
|
| 172 |
-
action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE")
|
| 173 |
-
result = env.step(action)
|
| 174 |
-
|
| 175 |
-
# Check return type is tuple
|
| 176 |
-
is_tuple = isinstance(result, tuple)
|
| 177 |
-
self.check("step() returns tuple", is_tuple)
|
| 178 |
-
|
| 179 |
-
if not is_tuple:
|
| 180 |
-
return False
|
| 181 |
-
|
| 182 |
-
# Check tuple length
|
| 183 |
-
correct_length = len(result) == 4
|
| 184 |
-
self.check("step() returns 4-tuple", correct_length, f"Got {len(result)} elements")
|
| 185 |
-
|
| 186 |
-
if not correct_length:
|
| 187 |
-
return False
|
| 188 |
-
|
| 189 |
-
next_obs, reward, done, info = result
|
| 190 |
-
|
| 191 |
-
# Check return types
|
| 192 |
-
obs_ok = isinstance(next_obs, Observation)
|
| 193 |
-
self.check("step() returns Observation", obs_ok)
|
| 194 |
-
|
| 195 |
-
reward_ok = isinstance(reward, Reward)
|
| 196 |
-
self.check("step() returns Reward", reward_ok)
|
| 197 |
-
|
| 198 |
-
done_ok = isinstance(done, bool)
|
| 199 |
-
self.check("step() returns bool (done)", done_ok)
|
| 200 |
-
|
| 201 |
-
info_ok = isinstance(info, dict)
|
| 202 |
-
self.check("step() returns dict (info)", info_ok)
|
| 203 |
-
|
| 204 |
-
# Check info contents
|
| 205 |
-
if info_ok:
|
| 206 |
-
has_processed_alerts = "processed_alerts" in info
|
| 207 |
-
self.check(
|
| 208 |
-
"info contains 'processed_alerts'",
|
| 209 |
-
has_processed_alerts,
|
| 210 |
-
f"Keys: {', '.join(sorted(info.keys()))}"
|
| 211 |
-
)
|
| 212 |
-
|
| 213 |
-
has_correlation_groups = "correlation_groups" in info
|
| 214 |
-
self.check("info contains 'correlation_groups'", has_correlation_groups)
|
| 215 |
-
|
| 216 |
-
return obs_ok and reward_ok and done_ok and info_ok
|
| 217 |
-
except Exception as e:
|
| 218 |
-
self.check("step() validation", False, str(e))
|
| 219 |
-
return False
|
| 220 |
-
|
| 221 |
-
def validate_state_method(self) -> bool:
|
| 222 |
-
"""4. Check state() method."""
|
| 223 |
-
self.log("\n=== Validating state() Method ===", "INFO")
|
| 224 |
-
|
| 225 |
-
try:
|
| 226 |
-
env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
|
| 227 |
-
env.reset()
|
| 228 |
-
|
| 229 |
-
# Check method exists
|
| 230 |
-
has_method = hasattr(env, "state")
|
| 231 |
-
self.check("state() method exists", has_method)
|
| 232 |
-
if not has_method:
|
| 233 |
-
return False
|
| 234 |
-
|
| 235 |
-
# Get state
|
| 236 |
-
state = env.state()
|
| 237 |
-
|
| 238 |
-
# Check return type
|
| 239 |
-
is_episode_state = isinstance(state, EpisodeState)
|
| 240 |
-
self.check("state() returns EpisodeState", is_episode_state)
|
| 241 |
-
|
| 242 |
-
if not is_episode_state:
|
| 243 |
-
return False
|
| 244 |
-
|
| 245 |
-
# Check required attributes
|
| 246 |
-
has_observation = hasattr(state, "observation") and isinstance(state.observation, Observation)
|
| 247 |
-
self.check("EpisodeState has observation (Observation)", has_observation)
|
| 248 |
-
|
| 249 |
-
has_hidden_state = hasattr(state, "hidden_state") and isinstance(state.hidden_state, dict)
|
| 250 |
-
self.check("EpisodeState has hidden_state (dict)", has_hidden_state)
|
| 251 |
-
|
| 252 |
-
if has_hidden_state:
|
| 253 |
-
has_true_severities = "true_severities" in state.hidden_state
|
| 254 |
-
self.check("hidden_state contains true_severities", has_true_severities)
|
| 255 |
-
|
| 256 |
-
has_correlation_groups = "correlation_groups" in state.hidden_state
|
| 257 |
-
self.check("hidden_state contains correlation_groups", has_correlation_groups)
|
| 258 |
-
|
| 259 |
-
has_cumulative_reward = hasattr(state, "cumulative_reward")
|
| 260 |
-
self.check("EpisodeState has cumulative_reward", has_cumulative_reward)
|
| 261 |
-
|
| 262 |
-
return is_episode_state and has_observation and has_hidden_state
|
| 263 |
-
except Exception as e:
|
| 264 |
-
self.check("state() validation", False, str(e))
|
| 265 |
-
return False
|
| 266 |
-
|
| 267 |
-
def validate_openenv_yaml(self) -> bool:
|
| 268 |
-
"""5. Check openenv.yaml metadata."""
|
| 269 |
-
self.log("\n=== Validating openenv.yaml ===", "INFO")
|
| 270 |
-
|
| 271 |
-
try:
|
| 272 |
-
yaml_path = Path("openenv.yaml")
|
| 273 |
-
|
| 274 |
-
# Check file exists
|
| 275 |
-
exists = yaml_path.exists()
|
| 276 |
-
self.check("openenv.yaml exists", exists, str(yaml_path.absolute()))
|
| 277 |
-
|
| 278 |
-
if not exists:
|
| 279 |
-
return False
|
| 280 |
-
|
| 281 |
-
# Check valid YAML
|
| 282 |
-
with open(yaml_path) as f:
|
| 283 |
-
data = yaml.safe_load(f)
|
| 284 |
-
|
| 285 |
-
is_dict = isinstance(data, dict)
|
| 286 |
-
self.check("openenv.yaml is valid YAML dict", is_dict)
|
| 287 |
-
|
| 288 |
-
if not is_dict:
|
| 289 |
-
return False
|
| 290 |
-
|
| 291 |
-
# Check required fields
|
| 292 |
-
required_fields = {
|
| 293 |
-
("name", "Environment name"),
|
| 294 |
-
("version", "Version string"),
|
| 295 |
-
("description", "Description"),
|
| 296 |
-
("tasks", "Task definitions"),
|
| 297 |
-
}
|
| 298 |
-
|
| 299 |
-
all_present = True
|
| 300 |
-
for field, description in required_fields:
|
| 301 |
-
present = field in data
|
| 302 |
-
self.check(f"'{field}' present ({description})", present)
|
| 303 |
-
all_present = all_present and present
|
| 304 |
-
|
| 305 |
-
# Check tasks structure
|
| 306 |
-
if "tasks" in data:
|
| 307 |
-
tasks = data["tasks"]
|
| 308 |
-
is_list = isinstance(tasks, list)
|
| 309 |
-
self.check("tasks is a list", is_list, f"Got {type(tasks)}")
|
| 310 |
-
|
| 311 |
-
if is_list:
|
| 312 |
-
has_tasks = len(tasks) > 0
|
| 313 |
-
self.check("tasks list is not empty", has_tasks, f"{len(tasks)} tasks defined")
|
| 314 |
-
|
| 315 |
-
# Check each task has ID
|
| 316 |
-
all_have_ids = all("id" in task for task in tasks)
|
| 317 |
-
task_ids = [task.get("id", "?") for task in tasks]
|
| 318 |
-
self.check("all tasks have 'id'", all_have_ids, f"IDs: {', '.join(task_ids)}")
|
| 319 |
-
|
| 320 |
-
# Check config section
|
| 321 |
-
has_config = "config" in data
|
| 322 |
-
self.check("'config' section present", has_config)
|
| 323 |
-
|
| 324 |
-
if has_config and "actions" in data["config"]:
|
| 325 |
-
expected_actions = {"INVESTIGATE", "IGNORE", "ESCALATE", "DELAY"}
|
| 326 |
-
yaml_actions = set(data["config"]["actions"])
|
| 327 |
-
has_all_actions = expected_actions.issubset(yaml_actions)
|
| 328 |
-
self.check("config.actions includes all required actions", has_all_actions,
|
| 329 |
-
f"Found: {', '.join(sorted(yaml_actions))}")
|
| 330 |
-
|
| 331 |
-
return all_present
|
| 332 |
-
except Exception as e:
|
| 333 |
-
self.check("openenv.yaml validation", False, str(e))
|
| 334 |
-
return False
|
| 335 |
-
|
| 336 |
-
def validate_all_tasks(self) -> bool:
|
| 337 |
-
"""Verify all tasks work correctly."""
|
| 338 |
-
self.log("\n=== Validating All Tasks ===", "INFO")
|
| 339 |
-
|
| 340 |
-
try:
|
| 341 |
-
all_ok = True
|
| 342 |
-
for task_id in ["easy", "medium", "hard"]:
|
| 343 |
-
try:
|
| 344 |
-
env = AdaptiveAlertTriageEnv(task_id=task_id, seed=42)
|
| 345 |
-
obs = env.reset()
|
| 346 |
-
|
| 347 |
-
# Verify structure
|
| 348 |
-
obs_ok = isinstance(obs, Observation)
|
| 349 |
-
|
| 350 |
-
# Take one step
|
| 351 |
-
if obs.alerts:
|
| 352 |
-
action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE")
|
| 353 |
-
next_obs, reward, done, info = env.step(action)
|
| 354 |
-
|
| 355 |
-
step_ok = (
|
| 356 |
-
isinstance(next_obs, Observation) and
|
| 357 |
-
isinstance(reward, Reward) and
|
| 358 |
-
isinstance(done, bool) and
|
| 359 |
-
isinstance(info, dict)
|
| 360 |
-
)
|
| 361 |
-
else:
|
| 362 |
-
step_ok = True
|
| 363 |
-
|
| 364 |
-
# Get state
|
| 365 |
-
state_ok = isinstance(env.state(), EpisodeState)
|
| 366 |
-
|
| 367 |
-
task_ok = obs_ok and step_ok and state_ok
|
| 368 |
-
self.check(f"Task '{task_id}' is OpenEnv compliant", task_ok)
|
| 369 |
-
all_ok = all_ok and task_ok
|
| 370 |
-
except Exception as e:
|
| 371 |
-
self.check(f"Task '{task_id}' is OpenEnv compliant", False, str(e))
|
| 372 |
-
all_ok = False
|
| 373 |
-
|
| 374 |
-
return all_ok
|
| 375 |
-
except Exception as e:
|
| 376 |
-
self.check("Task validation", False, str(e))
|
| 377 |
-
return False
|
| 378 |
-
|
| 379 |
-
def run_all_checks(self) -> bool:
|
| 380 |
-
"""Run all validation checks."""
|
| 381 |
-
self.log("=" * 60)
|
| 382 |
-
self.log("OpenEnv Compliance Validator", "INFO")
|
| 383 |
-
self.log("=" * 60)
|
| 384 |
-
|
| 385 |
-
results = [
|
| 386 |
-
self.validate_pydantic_models(),
|
| 387 |
-
self.validate_required_fields(),
|
| 388 |
-
self.validate_serialization(),
|
| 389 |
-
self.validate_reset_method(),
|
| 390 |
-
self.validate_step_method(),
|
| 391 |
-
self.validate_state_method(),
|
| 392 |
-
self.validate_openenv_yaml(),
|
| 393 |
-
self.validate_all_tasks(),
|
| 394 |
-
]
|
| 395 |
-
|
| 396 |
-
# Print summary
|
| 397 |
-
self.log("\n" + "=" * 60, "INFO")
|
| 398 |
-
self.log("VALIDATION SUMMARY", "INFO")
|
| 399 |
-
self.log("=" * 60, "INFO")
|
| 400 |
-
|
| 401 |
-
total_passed = len(self.checks_passed)
|
| 402 |
-
total_failed = len(self.checks_failed)
|
| 403 |
-
total_checks = total_passed + total_failed
|
| 404 |
-
|
| 405 |
-
self.log(f"Passed: {total_passed}/{total_checks}", "INFO")
|
| 406 |
-
|
| 407 |
-
if self.checks_failed:
|
| 408 |
-
self.log(f"Failed: {total_failed}/{total_checks}", "ERROR")
|
| 409 |
-
for name, details in self.checks_failed:
|
| 410 |
-
self.log(f" - {name}", "ERROR")
|
| 411 |
-
if details:
|
| 412 |
-
self.log(f" {details}", "ERROR")
|
| 413 |
-
else:
|
| 414 |
-
self.log("All checks passed! ✓", "PASS")
|
| 415 |
-
|
| 416 |
-
self.log("=" * 60 + "\n", "INFO")
|
| 417 |
-
|
| 418 |
-
return len(self.checks_failed) == 0
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
def main():
|
| 422 |
-
"""Entry point for CLI."""
|
| 423 |
-
validator = OpenEnvValidator(verbose=True)
|
| 424 |
-
success = validator.run_all_checks()
|
| 425 |
-
|
| 426 |
-
# Return appropriate exit code
|
| 427 |
-
sys.exit(0 if success else 1)
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
if __name__ == "__main__":
|
| 431 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|