Tusharp2006 committed on
Commit
d6fcbb0
·
1 Parent(s): 7cd2458
Files changed (4) hide show
  1. tasks/easy.py +5 -11
  2. tasks/hard.py +4 -14
  3. tasks/medium.py +5 -15
  4. validate.py +0 -431
tasks/easy.py CHANGED
@@ -161,21 +161,15 @@ class EasyTaskGrader:
161
 
162
  def get_episode_score(self) -> float:
163
  """
164
- Return final normalised score in (0, 1).
165
-
166
- Formula: 0.01 + 0.98 * (correct_actions / total_actions)
167
- This ensures the score is always strictly between 0 and 1 as
168
- required by the grading system.
169
  """
170
  if self.total_actions == 0:
171
- return 0.01
172
 
173
  raw = self.correct_actions / self.total_actions
174
- # Enforce strict (0, 1) range
175
- clamped = 0.01 + 0.98 * raw
176
- rounded = round(float(clamped), 2)
177
- # Ensure no rounding to boundaries (0.0 or 1.0)
178
- return max(0.01, min(rounded, 0.99))
179
 
180
 
181
  def passed(self) -> bool:
 
161
 
162
  def get_episode_score(self) -> float:
163
  """
164
+ Return final normalised score strictly in (0, 1) — never 0.0 or 1.0.
 
 
 
 
165
  """
166
  if self.total_actions == 0:
167
+ return 0.5
168
 
169
  raw = self.correct_actions / self.total_actions
170
+ # Map [0,1] -> (0,1) with a small epsilon margin, no rounding
171
+ score = 0.001 + 0.998 * float(raw)
172
+ return max(0.001, min(0.999, score))
 
 
173
 
174
 
175
  def passed(self) -> bool:
tasks/hard.py CHANGED
@@ -372,19 +372,11 @@ class HardTaskGrader:
372
 
373
  def get_episode_score(self) -> float:
374
  """
375
- Return final normalised score in (0, 1).
376
-
377
- Formula:
378
- chain_score = Σ chain.outcome_score()
379
- stability = _stability_score(system_failures)
380
- base = (raw * stability)
381
- clamped = 0.01 + 0.98 * base
382
  """
383
- # Chain component
384
  chain_score = sum(c.outcome_score() for c in self._chains.values())
385
  max_chain = sum(c.max_possible() for c in self._chains.values())
386
 
387
- # Isolation bonus (capped)
388
  isolation = min(
389
  self._isolation_correct * _ISOLATION_BONUS_PER_ALERT,
390
  _ISOLATION_BONUS_CAP,
@@ -396,11 +388,9 @@ class HardTaskGrader:
396
  stability = self._stability_score(self._system_failures)
397
  final_base = max(0.0, min(raw * stability, 1.0))
398
 
399
- # Enforce strict (0, 1) range
400
- clamped = 0.01 + 0.98 * final_base
401
- rounded = round(float(clamped), 2)
402
- # Ensure no rounding to boundaries (0.0 or 1.0)
403
- return max(0.01, min(rounded, 0.99))
404
 
405
 
406
  def passed(self) -> bool:
 
372
 
373
  def get_episode_score(self) -> float:
374
  """
375
+ Return final normalised score strictly in (0, 1) — never 0.0 or 1.0.
 
 
 
 
 
 
376
  """
 
377
  chain_score = sum(c.outcome_score() for c in self._chains.values())
378
  max_chain = sum(c.max_possible() for c in self._chains.values())
379
 
 
380
  isolation = min(
381
  self._isolation_correct * _ISOLATION_BONUS_PER_ALERT,
382
  _ISOLATION_BONUS_CAP,
 
388
  stability = self._stability_score(self._system_failures)
389
  final_base = max(0.0, min(raw * stability, 1.0))
390
 
391
+ # Map [0,1] -> (0,1) with a small epsilon margin, no rounding
392
+ score = 0.001 + 0.998 * float(final_base)
393
+ return max(0.001, min(0.999, score))
 
 
394
 
395
 
396
  def passed(self) -> bool:
tasks/medium.py CHANGED
@@ -192,27 +192,19 @@ class MediumTaskGrader:
192
 
193
  def get_episode_score(self) -> float:
194
  """
195
- Return final normalised score in (0, 1).
196
-
197
- Formula:
198
- raw = resolved_score / max_possible_score
199
- base = max(0.0, raw − fp_penalty − miss_penalty)
200
- clamped = 0.01 + 0.98 * base
201
  """
202
  if self._max_possible_score <= 0.0:
203
- return 0.01
204
 
205
- # Normalised resolved quality
206
  raw = min(self._resolved_score / self._max_possible_score, 1.0)
207
 
208
- # Penalty 1: wasted investigation budget on false positives
209
  if self._total_investigations > 0:
210
  fp_rate = self._unnecessary_invest / self._total_investigations
211
  else:
212
  fp_rate = 0.0
213
  fp_penalty = _FP_PENALTY_WEIGHT * fp_rate
214
 
215
- # Penalty 2: missed critical alerts
216
  if self._critical_total > 0:
217
  miss_rate = min(self._critical_missed / self._critical_total, 1.0)
218
  else:
@@ -220,11 +212,9 @@ class MediumTaskGrader:
220
  miss_penalty = _CRITICAL_MISS_PENALTY_WEIGHT * miss_rate
221
 
222
  base_score = max(0.0, raw - fp_penalty - miss_penalty)
223
- # Enforce strict (0, 1) range
224
- clamped = 0.01 + 0.98 * base_score
225
- rounded = round(float(clamped), 2)
226
- # Ensure no rounding to boundaries (0.0 or 1.0)
227
- return max(0.01, min(rounded, 0.99))
228
 
229
 
230
  def passed(self) -> bool:
 
192
 
193
  def get_episode_score(self) -> float:
194
  """
195
+ Return final normalised score strictly in (0, 1) — never 0.0 or 1.0.
 
 
 
 
 
196
  """
197
  if self._max_possible_score <= 0.0:
198
+ return 0.5
199
 
 
200
  raw = min(self._resolved_score / self._max_possible_score, 1.0)
201
 
 
202
  if self._total_investigations > 0:
203
  fp_rate = self._unnecessary_invest / self._total_investigations
204
  else:
205
  fp_rate = 0.0
206
  fp_penalty = _FP_PENALTY_WEIGHT * fp_rate
207
 
 
208
  if self._critical_total > 0:
209
  miss_rate = min(self._critical_missed / self._critical_total, 1.0)
210
  else:
 
212
  miss_penalty = _CRITICAL_MISS_PENALTY_WEIGHT * miss_rate
213
 
214
  base_score = max(0.0, raw - fp_penalty - miss_penalty)
215
+ # Map [0,1] -> (0,1) with a small epsilon margin, no rounding
216
+ score = 0.001 + 0.998 * float(base_score)
217
+ return max(0.001, min(0.999, score))
 
 
218
 
219
 
220
  def passed(self) -> bool:
validate.py DELETED
@@ -1,431 +0,0 @@
1
- #!/usr/bin/env python
2
- """
3
- OpenEnv Validation CLI Tool
4
-
5
- Usage:
6
- python -m src.adaptive_alert_triage.validate
7
- openenv validate (if registered as entry point)
8
-
9
- Validates that the Adaptive Alert Triage environment meets the full OpenEnv
10
- interface specification:
11
- 1. Typed Observation, Action, and Reward Pydantic models
12
- 2. step(action) → returns (observation, reward, done, info)
13
- 3. reset() → returns initial observation
14
- 4. state() → returns current state
15
- 5. openenv.yaml with metadata
16
- """
17
-
18
- import sys
19
- import json
20
- from pathlib import Path
21
- from typing import Dict, List, Tuple
22
- import yaml
23
-
24
- from adaptive_alert_triage.env import AdaptiveAlertTriageEnv
25
- from adaptive_alert_triage.models import (
26
- Action,
27
- Observation,
28
- Reward,
29
- Alert,
30
- EpisodeState,
31
- )
32
-
33
-
34
- class OpenEnvValidator:
35
- """Validates OpenEnv compliance of the environment."""
36
-
37
- def __init__(self, verbose: bool = True):
38
- self.verbose = verbose
39
- self.checks_passed = []
40
- self.checks_failed = []
41
-
42
- def log(self, message: str, level: str = "INFO"):
43
- """Log a message with level."""
44
- if self.verbose:
45
- print(f"[{level}] {message}")
46
-
47
- def check(self, name: str, condition: bool, details: str = "") -> bool:
48
- """Record a check result."""
49
- if condition:
50
- self.checks_passed.append(name)
51
- self.log(f"✓ {name}", "PASS")
52
- if details:
53
- self.log(f" {details}", "INFO")
54
- return True
55
- else:
56
- self.checks_failed.append((name, details))
57
- self.log(f"✗ {name}", "FAIL")
58
- if details:
59
- self.log(f" {details}", "ERROR")
60
- return False
61
-
62
- def validate_pydantic_models(self) -> bool:
63
- """1. Check that models are Pydantic BaseModels."""
64
- self.log("\n=== Validating Pydantic Models ===", "INFO")
65
-
66
- from pydantic import BaseModel
67
-
68
- checks = [
69
- ("Observation is Pydantic BaseModel", issubclass(Observation, BaseModel)),
70
- ("Action is Pydantic BaseModel", issubclass(Action, BaseModel)),
71
- ("Reward is Pydantic BaseModel", issubclass(Reward, BaseModel)),
72
- ("EpisodeState is Pydantic BaseModel", issubclass(EpisodeState, BaseModel)),
73
- ("Alert is Pydantic BaseModel", issubclass(Alert, BaseModel)),
74
- ]
75
-
76
- return all(self.check(name, cond) for name, cond in checks)
77
-
78
- def validate_required_fields(self) -> bool:
79
- """Check that models have required fields."""
80
- self.log("\n=== Validating Model Fields ===", "INFO")
81
-
82
- checks = [
83
- (
84
- "Observation has required fields",
85
- {"alerts", "system_load", "queue_length", "time_remaining", "episode_step"}.issubset(
86
- set(Observation.model_fields.keys())
87
- ),
88
- f"Fields: {', '.join(sorted(Observation.model_fields.keys()))}"
89
- ),
90
- (
91
- "Action has required fields",
92
- {"alert_id", "action_type"}.issubset(set(Action.model_fields.keys())),
93
- f"Fields: {', '.join(sorted(Action.model_fields.keys()))}"
94
- ),
95
- (
96
- "Reward has required fields",
97
- {"value", "components"}.issubset(set(Reward.model_fields.keys())),
98
- f"Fields: {', '.join(sorted(Reward.model_fields.keys()))}"
99
- ),
100
- ]
101
-
102
- return all(self.check(name, cond, details) for name, cond, details in checks)
103
-
104
- def validate_serialization(self) -> bool:
105
- """Check that models can be serialized/deserialized."""
106
- self.log("\n=== Validating Serialization ===", "INFO")
107
-
108
- try:
109
- # Test Action serialization
110
- action = Action(alert_id="test", action_type="INVESTIGATE")
111
- json_str = action.model_dump_json()
112
- restored = Action.model_validate_json(json_str)
113
- action_ok = restored.alert_id == action.alert_id
114
- self.check("Action serialization round-trip", action_ok)
115
-
116
- # Test Reward serialization
117
- reward = Reward(value=10.0, components={"test": 10.0})
118
- json_str = reward.model_dump_json()
119
- restored = Reward.model_validate_json(json_str)
120
- reward_ok = restored.value == reward.value
121
- self.check("Reward serialization round-trip", reward_ok)
122
-
123
- return action_ok and reward_ok
124
- except Exception as e:
125
- self.check("Serialization", False, str(e))
126
- return False
127
-
128
- def validate_reset_method(self) -> bool:
129
- """2. Check reset() method."""
130
- self.log("\n=== Validating reset() Method ===", "INFO")
131
-
132
- try:
133
- env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
134
-
135
- # Check method exists
136
- has_method = hasattr(env, "reset")
137
- self.check("reset() method exists", has_method)
138
- if not has_method:
139
- return False
140
-
141
- # Check return type
142
- obs = env.reset()
143
- returns_observation = isinstance(obs, Observation)
144
- self.check("reset() returns Observation", returns_observation)
145
-
146
- # Check reproducibility
147
- env2 = AdaptiveAlertTriageEnv(task_id="easy")
148
- obs2 = env2.reset(seed=42)
149
- is_reproducible = len(env.alerts) == len(env2.alerts)
150
- self.check("reset() is reproducible with seed", is_reproducible)
151
-
152
- return has_method and returns_observation and is_reproducible
153
- except Exception as e:
154
- self.check("reset() validation", False, str(e))
155
- return False
156
-
157
- def validate_step_method(self) -> bool:
158
- """3. Check step(action) method."""
159
- self.log("\n=== Validating step() Method ===", "INFO")
160
-
161
- try:
162
- env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
163
- obs = env.reset()
164
-
165
- # Check method exists
166
- has_method = hasattr(env, "step")
167
- self.check("step() method exists", has_method)
168
- if not has_method or not obs.alerts:
169
- return False
170
-
171
- # Take a step
172
- action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE")
173
- result = env.step(action)
174
-
175
- # Check return type is tuple
176
- is_tuple = isinstance(result, tuple)
177
- self.check("step() returns tuple", is_tuple)
178
-
179
- if not is_tuple:
180
- return False
181
-
182
- # Check tuple length
183
- correct_length = len(result) == 4
184
- self.check("step() returns 4-tuple", correct_length, f"Got {len(result)} elements")
185
-
186
- if not correct_length:
187
- return False
188
-
189
- next_obs, reward, done, info = result
190
-
191
- # Check return types
192
- obs_ok = isinstance(next_obs, Observation)
193
- self.check("step() returns Observation", obs_ok)
194
-
195
- reward_ok = isinstance(reward, Reward)
196
- self.check("step() returns Reward", reward_ok)
197
-
198
- done_ok = isinstance(done, bool)
199
- self.check("step() returns bool (done)", done_ok)
200
-
201
- info_ok = isinstance(info, dict)
202
- self.check("step() returns dict (info)", info_ok)
203
-
204
- # Check info contents
205
- if info_ok:
206
- has_processed_alerts = "processed_alerts" in info
207
- self.check(
208
- "info contains 'processed_alerts'",
209
- has_processed_alerts,
210
- f"Keys: {', '.join(sorted(info.keys()))}"
211
- )
212
-
213
- has_correlation_groups = "correlation_groups" in info
214
- self.check("info contains 'correlation_groups'", has_correlation_groups)
215
-
216
- return obs_ok and reward_ok and done_ok and info_ok
217
- except Exception as e:
218
- self.check("step() validation", False, str(e))
219
- return False
220
-
221
- def validate_state_method(self) -> bool:
222
- """4. Check state() method."""
223
- self.log("\n=== Validating state() Method ===", "INFO")
224
-
225
- try:
226
- env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
227
- env.reset()
228
-
229
- # Check method exists
230
- has_method = hasattr(env, "state")
231
- self.check("state() method exists", has_method)
232
- if not has_method:
233
- return False
234
-
235
- # Get state
236
- state = env.state()
237
-
238
- # Check return type
239
- is_episode_state = isinstance(state, EpisodeState)
240
- self.check("state() returns EpisodeState", is_episode_state)
241
-
242
- if not is_episode_state:
243
- return False
244
-
245
- # Check required attributes
246
- has_observation = hasattr(state, "observation") and isinstance(state.observation, Observation)
247
- self.check("EpisodeState has observation (Observation)", has_observation)
248
-
249
- has_hidden_state = hasattr(state, "hidden_state") and isinstance(state.hidden_state, dict)
250
- self.check("EpisodeState has hidden_state (dict)", has_hidden_state)
251
-
252
- if has_hidden_state:
253
- has_true_severities = "true_severities" in state.hidden_state
254
- self.check("hidden_state contains true_severities", has_true_severities)
255
-
256
- has_correlation_groups = "correlation_groups" in state.hidden_state
257
- self.check("hidden_state contains correlation_groups", has_correlation_groups)
258
-
259
- has_cumulative_reward = hasattr(state, "cumulative_reward")
260
- self.check("EpisodeState has cumulative_reward", has_cumulative_reward)
261
-
262
- return is_episode_state and has_observation and has_hidden_state
263
- except Exception as e:
264
- self.check("state() validation", False, str(e))
265
- return False
266
-
267
- def validate_openenv_yaml(self) -> bool:
268
- """5. Check openenv.yaml metadata."""
269
- self.log("\n=== Validating openenv.yaml ===", "INFO")
270
-
271
- try:
272
- yaml_path = Path("openenv.yaml")
273
-
274
- # Check file exists
275
- exists = yaml_path.exists()
276
- self.check("openenv.yaml exists", exists, str(yaml_path.absolute()))
277
-
278
- if not exists:
279
- return False
280
-
281
- # Check valid YAML
282
- with open(yaml_path) as f:
283
- data = yaml.safe_load(f)
284
-
285
- is_dict = isinstance(data, dict)
286
- self.check("openenv.yaml is valid YAML dict", is_dict)
287
-
288
- if not is_dict:
289
- return False
290
-
291
- # Check required fields
292
- required_fields = {
293
- ("name", "Environment name"),
294
- ("version", "Version string"),
295
- ("description", "Description"),
296
- ("tasks", "Task definitions"),
297
- }
298
-
299
- all_present = True
300
- for field, description in required_fields:
301
- present = field in data
302
- self.check(f"'{field}' present ({description})", present)
303
- all_present = all_present and present
304
-
305
- # Check tasks structure
306
- if "tasks" in data:
307
- tasks = data["tasks"]
308
- is_list = isinstance(tasks, list)
309
- self.check("tasks is a list", is_list, f"Got {type(tasks)}")
310
-
311
- if is_list:
312
- has_tasks = len(tasks) > 0
313
- self.check("tasks list is not empty", has_tasks, f"{len(tasks)} tasks defined")
314
-
315
- # Check each task has ID
316
- all_have_ids = all("id" in task for task in tasks)
317
- task_ids = [task.get("id", "?") for task in tasks]
318
- self.check("all tasks have 'id'", all_have_ids, f"IDs: {', '.join(task_ids)}")
319
-
320
- # Check config section
321
- has_config = "config" in data
322
- self.check("'config' section present", has_config)
323
-
324
- if has_config and "actions" in data["config"]:
325
- expected_actions = {"INVESTIGATE", "IGNORE", "ESCALATE", "DELAY"}
326
- yaml_actions = set(data["config"]["actions"])
327
- has_all_actions = expected_actions.issubset(yaml_actions)
328
- self.check("config.actions includes all required actions", has_all_actions,
329
- f"Found: {', '.join(sorted(yaml_actions))}")
330
-
331
- return all_present
332
- except Exception as e:
333
- self.check("openenv.yaml validation", False, str(e))
334
- return False
335
-
336
- def validate_all_tasks(self) -> bool:
337
- """Verify all tasks work correctly."""
338
- self.log("\n=== Validating All Tasks ===", "INFO")
339
-
340
- try:
341
- all_ok = True
342
- for task_id in ["easy", "medium", "hard"]:
343
- try:
344
- env = AdaptiveAlertTriageEnv(task_id=task_id, seed=42)
345
- obs = env.reset()
346
-
347
- # Verify structure
348
- obs_ok = isinstance(obs, Observation)
349
-
350
- # Take one step
351
- if obs.alerts:
352
- action = Action(alert_id=obs.alerts[0].id, action_type="INVESTIGATE")
353
- next_obs, reward, done, info = env.step(action)
354
-
355
- step_ok = (
356
- isinstance(next_obs, Observation) and
357
- isinstance(reward, Reward) and
358
- isinstance(done, bool) and
359
- isinstance(info, dict)
360
- )
361
- else:
362
- step_ok = True
363
-
364
- # Get state
365
- state_ok = isinstance(env.state(), EpisodeState)
366
-
367
- task_ok = obs_ok and step_ok and state_ok
368
- self.check(f"Task '{task_id}' is OpenEnv compliant", task_ok)
369
- all_ok = all_ok and task_ok
370
- except Exception as e:
371
- self.check(f"Task '{task_id}' is OpenEnv compliant", False, str(e))
372
- all_ok = False
373
-
374
- return all_ok
375
- except Exception as e:
376
- self.check("Task validation", False, str(e))
377
- return False
378
-
379
- def run_all_checks(self) -> bool:
380
- """Run all validation checks."""
381
- self.log("=" * 60)
382
- self.log("OpenEnv Compliance Validator", "INFO")
383
- self.log("=" * 60)
384
-
385
- results = [
386
- self.validate_pydantic_models(),
387
- self.validate_required_fields(),
388
- self.validate_serialization(),
389
- self.validate_reset_method(),
390
- self.validate_step_method(),
391
- self.validate_state_method(),
392
- self.validate_openenv_yaml(),
393
- self.validate_all_tasks(),
394
- ]
395
-
396
- # Print summary
397
- self.log("\n" + "=" * 60, "INFO")
398
- self.log("VALIDATION SUMMARY", "INFO")
399
- self.log("=" * 60, "INFO")
400
-
401
- total_passed = len(self.checks_passed)
402
- total_failed = len(self.checks_failed)
403
- total_checks = total_passed + total_failed
404
-
405
- self.log(f"Passed: {total_passed}/{total_checks}", "INFO")
406
-
407
- if self.checks_failed:
408
- self.log(f"Failed: {total_failed}/{total_checks}", "ERROR")
409
- for name, details in self.checks_failed:
410
- self.log(f" - {name}", "ERROR")
411
- if details:
412
- self.log(f" {details}", "ERROR")
413
- else:
414
- self.log("All checks passed! ✓", "PASS")
415
-
416
- self.log("=" * 60 + "\n", "INFO")
417
-
418
- return len(self.checks_failed) == 0
419
-
420
-
421
- def main():
422
- """Entry point for CLI."""
423
- validator = OpenEnvValidator(verbose=True)
424
- success = validator.run_all_checks()
425
-
426
- # Return appropriate exit code
427
- sys.exit(0 if success else 1)
428
-
429
-
430
- if __name__ == "__main__":
431
- main()