Spaces:
Sleeping
Sleeping
Commit ·
c18a9d1
1
Parent(s): eea342f
grader fixes
Browse files- .agents/skills/openenv-cli/SKILL.md +18 -0
- openenv.yaml +7 -5
- pyproject.toml +3 -2
- rewards/reward.py +2 -4
- src/adaptive_alert_triage/env.py +2 -2
- src/adaptive_alert_triage/models.py +1 -7
- src/adaptive_alert_triage/server.py +2 -2
- tasks/easy.py +6 -6
- tasks/hard.py +7 -7
- tasks/medium.py +11 -9
- tests/test_env.py +9 -9
- tests/test_integration.py +9 -12
- tests/test_rewards.py +7 -7
- tests/test_tasks.py +5 -8
- uv.lock +0 -0
.agents/skills/openenv-cli/SKILL.md
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: openenv-cli
|
| 3 |
+
description: "OpenEnv CLI (`openenv`) for scaffolding, validating, building, and pushing OpenEnv environments."
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
Install: `pip install openenv-core`
|
| 7 |
+
|
| 8 |
+
The OpenEnv CLI command `openenv` is available.
|
| 9 |
+
Use `openenv --help` to view available commands.
|
| 10 |
+
|
| 11 |
+
Generated with `openenv-core v0.2.3`. Run `openenv skills add --force` to regenerate.
|
| 12 |
+
|
| 13 |
+
## Tips
|
| 14 |
+
|
| 15 |
+
- Start with `openenv init <env_name>` to scaffold a new environment
|
| 16 |
+
- Validate projects with `openenv validate`
|
| 17 |
+
- Build and deploy with `openenv build` and `openenv push`
|
| 18 |
+
- Use `openenv <command> --help` for command-specific options
|
openenv.yaml
CHANGED
|
@@ -140,7 +140,7 @@ tasks:
|
|
| 140 |
correlation_probability: 0.10
|
| 141 |
success_threshold: 0.70 # correct_actions / total_actions >= 0.70
|
| 142 |
grader: "tasks.easy.EasyTaskGrader"
|
| 143 |
-
grading_formula: "score =
|
| 144 |
|
| 145 |
- id: "medium"
|
| 146 |
name: "Resource-Constrained Triage"
|
|
@@ -158,8 +158,10 @@ tasks:
|
|
| 158 |
grader: "tasks.medium.MediumTaskGrader"
|
| 159 |
grading_formula: |
|
| 160 |
raw = resolved_score / max_possible_score
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
| 163 |
|
| 164 |
- id: "hard"
|
| 165 |
name: "Cascading Failure Prevention"
|
|
@@ -179,8 +181,8 @@ tasks:
|
|
| 179 |
grading_formula: |
|
| 180 |
chain_score = Σ stop_reward(position) × severity_weight
|
| 181 |
stability = {0 failures: 1.0, 1: 0.80, 2: 0.60, 3: 0.30, 4+: 0.00}
|
| 182 |
-
|
| 183 |
-
score =
|
| 184 |
|
| 185 |
# ── Evaluation metrics (produced by graders) ──────────────────────────────────
|
| 186 |
metrics:
|
|
|
|
| 140 |
correlation_probability: 0.10
|
| 141 |
success_threshold: 0.70 # correct_actions / total_actions >= 0.70
|
| 142 |
grader: "tasks.easy.EasyTaskGrader"
|
| 143 |
+
grading_formula: "score = (correct_actions / total_actions) * 0.98 + 0.01"
|
| 144 |
|
| 145 |
- id: "medium"
|
| 146 |
name: "Resource-Constrained Triage"
|
|
|
|
| 158 |
grader: "tasks.medium.MediumTaskGrader"
|
| 159 |
grading_formula: |
|
| 160 |
raw = resolved_score / max_possible_score
|
| 161 |
+
fp_penalty = 0.30 * (unnecessary_investigations / total_investigations)
|
| 162 |
+
miss_penalty = 0.20 * (critical_missed / max(critical_total, 1))
|
| 163 |
+
penalised = raw - fp_penalty - miss_penalty
|
| 164 |
+
score = (penalised * 0.6) + 0.35
|
| 165 |
|
| 166 |
- id: "hard"
|
| 167 |
name: "Cascading Failure Prevention"
|
|
|
|
| 181 |
grading_formula: |
|
| 182 |
chain_score = Σ stop_reward(position) × severity_weight
|
| 183 |
stability = {0 failures: 1.0, 1: 0.80, 2: 0.60, 3: 0.30, 4+: 0.00}
|
| 184 |
+
raw = (chain_score / max_possible) * stability
|
| 185 |
+
score = (raw * 0.98) + 0.01
|
| 186 |
|
| 187 |
# ── Evaluation metrics (produced by graders) ──────────────────────────────────
|
| 188 |
metrics:
|
pyproject.toml
CHANGED
|
@@ -27,7 +27,7 @@ classifiers = [
|
|
| 27 |
|
| 28 |
dependencies = [
|
| 29 |
"pydantic>=2.0.0",
|
| 30 |
-
"openenv>=0.1.0",
|
| 31 |
"numpy>=1.24.0",
|
| 32 |
"openai>=1.0.0",
|
| 33 |
"pyyaml>=6.0",
|
|
@@ -35,6 +35,7 @@ dependencies = [
|
|
| 35 |
"fastapi>=0.104.0",
|
| 36 |
"websockets>=12.0",
|
| 37 |
"requests>=2.31.0",
|
|
|
|
| 38 |
]
|
| 39 |
|
| 40 |
[project.optional-dependencies]
|
|
@@ -117,4 +118,4 @@ addopts = "-v --cov=src/adaptive_alert_triage --cov-report=term-missing"
|
|
| 117 |
dev = [
|
| 118 |
"pytest>=8.4.2",
|
| 119 |
"pytest-cov>=7.1.0",
|
| 120 |
-
]
|
|
|
|
| 27 |
|
| 28 |
dependencies = [
|
| 29 |
"pydantic>=2.0.0",
|
| 30 |
+
"openenv[cli]>=0.1.0",
|
| 31 |
"numpy>=1.24.0",
|
| 32 |
"openai>=1.0.0",
|
| 33 |
"pyyaml>=6.0",
|
|
|
|
| 35 |
"fastapi>=0.104.0",
|
| 36 |
"websockets>=12.0",
|
| 37 |
"requests>=2.31.0",
|
| 38 |
+
"openenv-core[cli]>=0.1.0",
|
| 39 |
]
|
| 40 |
|
| 41 |
[project.optional-dependencies]
|
|
|
|
| 118 |
dev = [
|
| 119 |
"pytest>=8.4.2",
|
| 120 |
"pytest-cov>=7.1.0",
|
| 121 |
+
]
|
rewards/reward.py
CHANGED
|
@@ -315,7 +315,6 @@ def calculate_reward(
|
|
| 315 |
components = {k: v * multiplier for k, v in components.items()}
|
| 316 |
|
| 317 |
total_reward: float = sum(components.values())
|
| 318 |
-
norm_reward: float = max(0.01, min(0.99, (total_reward + 40.0) / 80.0))
|
| 319 |
|
| 320 |
# -----------------------------------------------------------------------
|
| 321 |
# Info payload — consumed by graders and evaluation scripts
|
|
@@ -336,7 +335,7 @@ def calculate_reward(
|
|
| 336 |
}
|
| 337 |
|
| 338 |
return Reward(
|
| 339 |
-
value=
|
| 340 |
components=components,
|
| 341 |
info=info,
|
| 342 |
)
|
|
@@ -613,8 +612,7 @@ if __name__ == "__main__":
|
|
| 613 |
for desc, act, alert, cfg, expected in cases:
|
| 614 |
action = Action(alert_id=alert.id, action_type=act)
|
| 615 |
result = calculate_reward(action, alert, cfg)
|
| 616 |
-
|
| 617 |
-
ok = abs(result.value - normalized_expected) < 1e-4
|
| 618 |
status = "PASS" if ok else "FAIL"
|
| 619 |
if not ok:
|
| 620 |
all_pass = False
|
|
|
|
| 315 |
components = {k: v * multiplier for k, v in components.items()}
|
| 316 |
|
| 317 |
total_reward: float = sum(components.values())
|
|
|
|
| 318 |
|
| 319 |
# -----------------------------------------------------------------------
|
| 320 |
# Info payload — consumed by graders and evaluation scripts
|
|
|
|
| 335 |
}
|
| 336 |
|
| 337 |
return Reward(
|
| 338 |
+
value=total_reward,
|
| 339 |
components=components,
|
| 340 |
info=info,
|
| 341 |
)
|
|
|
|
| 612 |
for desc, act, alert, cfg, expected in cases:
|
| 613 |
action = Action(alert_id=alert.id, action_type=act)
|
| 614 |
result = calculate_reward(action, alert, cfg)
|
| 615 |
+
ok = abs(result.value - expected) < 1e-4
|
|
|
|
| 616 |
status = "PASS" if ok else "FAIL"
|
| 617 |
if not ok:
|
| 618 |
all_pass = False
|
src/adaptive_alert_triage/env.py
CHANGED
|
@@ -267,7 +267,7 @@ class AdaptiveAlertTriageEnv(gym.Env):
|
|
| 267 |
alert = self._get_alert_by_id(action.alert_id)
|
| 268 |
if alert is None:
|
| 269 |
reward = Reward(
|
| 270 |
-
value=
|
| 271 |
components={"invalid_action": -5.0},
|
| 272 |
info={"error": f"Alert ID '{action.alert_id}' not found in queue"},
|
| 273 |
)
|
|
@@ -284,7 +284,7 @@ class AdaptiveAlertTriageEnv(gym.Env):
|
|
| 284 |
):
|
| 285 |
if self.investigations_used >= self.max_investigations_per_step:
|
| 286 |
reward = Reward(
|
| 287 |
-
value=
|
| 288 |
components={"resource_budget_exceeded": -3.0},
|
| 289 |
info={
|
| 290 |
"error": "Investigation budget exhausted for this step",
|
|
|
|
| 267 |
alert = self._get_alert_by_id(action.alert_id)
|
| 268 |
if alert is None:
|
| 269 |
reward = Reward(
|
| 270 |
+
value=-5.0,
|
| 271 |
components={"invalid_action": -5.0},
|
| 272 |
info={"error": f"Alert ID '{action.alert_id}' not found in queue"},
|
| 273 |
)
|
|
|
|
| 284 |
):
|
| 285 |
if self.investigations_used >= self.max_investigations_per_step:
|
| 286 |
reward = Reward(
|
| 287 |
+
value=-3.0,
|
| 288 |
components={"resource_budget_exceeded": -3.0},
|
| 289 |
info={
|
| 290 |
"error": "Investigation budget exhausted for this step",
|
src/adaptive_alert_triage/models.py
CHANGED
|
@@ -222,13 +222,7 @@ class Reward(BaseModel):
|
|
| 222 |
info: Debugging / logging extras (ground-truth reveal, etc.).
|
| 223 |
"""
|
| 224 |
|
| 225 |
-
value: float = Field(...,
|
| 226 |
-
|
| 227 |
-
@field_validator("value", mode="before")
|
| 228 |
-
@classmethod
|
| 229 |
-
def clamp_reward_value(cls, v: float) -> float:
|
| 230 |
-
"""Silently clamp reward value to [0.01, 0.99] — strict (0, 1) bounds."""
|
| 231 |
-
return float(max(0.01, min(0.99, float(v))))
|
| 232 |
|
| 233 |
components: Dict[str, float] = Field(
|
| 234 |
default_factory=dict, description="Per-component reward breakdown"
|
|
|
|
| 222 |
info: Debugging / logging extras (ground-truth reveal, etc.).
|
| 223 |
"""
|
| 224 |
|
| 225 |
+
value: float = Field(..., description="Total scalar reward")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
components: Dict[str, float] = Field(
|
| 228 |
default_factory=dict, description="Per-component reward breakdown"
|
src/adaptive_alert_triage/server.py
CHANGED
|
@@ -131,7 +131,7 @@ def _tick(info: Dict) -> None:
|
|
| 131 |
|
| 132 |
def _score() -> float:
|
| 133 |
raw = _step_correct / _step_total if _step_total else 0.0
|
| 134 |
-
return
|
| 135 |
|
| 136 |
|
| 137 |
# ── PPO helpers ───────────────────────────────────────────────────────────────
|
|
@@ -604,7 +604,7 @@ async def ws_train(websocket: WebSocket):
|
|
| 604 |
lt += 1
|
| 605 |
if info.get("action_correct", False): lc += 1
|
| 606 |
raw_s = lc / lt if lt else 0.0
|
| 607 |
-
s =
|
| 608 |
if done: episode_scores.append(s)
|
| 609 |
info["task_score"] = s
|
| 610 |
await websocket.send_json({
|
|
|
|
| 131 |
|
| 132 |
def _score() -> float:
|
| 133 |
raw = _step_correct / _step_total if _step_total else 0.0
|
| 134 |
+
return round((raw * 0.98) + 0.01, 4)
|
| 135 |
|
| 136 |
|
| 137 |
# ── PPO helpers ───────────────────────────────────────────────────────────────
|
|
|
|
| 604 |
lt += 1
|
| 605 |
if info.get("action_correct", False): lc += 1
|
| 606 |
raw_s = lc / lt if lt else 0.0
|
| 607 |
+
s = round((raw_s * 0.98) + 0.01, 4)
|
| 608 |
if done: episode_scores.append(s)
|
| 609 |
info["task_score"] = s
|
| 610 |
await websocket.send_json({
|
tasks/easy.py
CHANGED
|
@@ -127,10 +127,10 @@ class EasyTaskGrader:
|
|
| 127 |
"alert_type": alert_data.get("alert_type", ""),
|
| 128 |
"is_false_positive":alert_data.get("is_false_positive", False),
|
| 129 |
"correct": is_correct,
|
| 130 |
-
"score":
|
| 131 |
})
|
| 132 |
|
| 133 |
-
return
|
| 134 |
|
| 135 |
# ------------------------------------------------------------------
|
| 136 |
# Legacy API (unit tests / backward compat)
|
|
@@ -167,10 +167,10 @@ class EasyTaskGrader:
|
|
| 167 |
return 0.5
|
| 168 |
|
| 169 |
raw = self.correct_actions / self.total_actions
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
return float(round(
|
| 174 |
|
| 175 |
|
| 176 |
def passed(self) -> bool:
|
|
|
|
| 127 |
"alert_type": alert_data.get("alert_type", ""),
|
| 128 |
"is_false_positive":alert_data.get("is_false_positive", False),
|
| 129 |
"correct": is_correct,
|
| 130 |
+
"score": 1.0 if is_correct else 0.0,
|
| 131 |
})
|
| 132 |
|
| 133 |
+
return 1.0 if is_correct else 0.0
|
| 134 |
|
| 135 |
# ------------------------------------------------------------------
|
| 136 |
# Legacy API (unit tests / backward compat)
|
|
|
|
| 167 |
return 0.5
|
| 168 |
|
| 169 |
raw = self.correct_actions / self.total_actions
|
| 170 |
+
|
| 171 |
+
# Linearly map exactly to [0.01, 0.99] without clipping
|
| 172 |
+
mapped = (raw * 0.98) + 0.01
|
| 173 |
+
return float(round(mapped, 4))
|
| 174 |
|
| 175 |
|
| 176 |
def passed(self) -> bool:
|
tasks/hard.py
CHANGED
|
@@ -383,14 +383,14 @@ class HardTaskGrader:
|
|
| 383 |
)
|
| 384 |
|
| 385 |
denominator = max(max_chain, 1.0)
|
| 386 |
-
raw =
|
| 387 |
-
|
| 388 |
stability = self._stability_score(self._system_failures)
|
| 389 |
-
|
| 390 |
-
#
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
return float(round(
|
| 394 |
|
| 395 |
|
| 396 |
def passed(self) -> bool:
|
|
|
|
| 383 |
)
|
| 384 |
|
| 385 |
denominator = max(max_chain, 1.0)
|
| 386 |
+
raw = (chain_score + isolation) / denominator
|
| 387 |
+
|
| 388 |
stability = self._stability_score(self._system_failures)
|
| 389 |
+
|
| 390 |
+
# Raw * stability is naturally in [0, 1].
|
| 391 |
+
# Map [0, 1] linearly to [0.01, 0.99] without clipping
|
| 392 |
+
mapped = (raw * stability * 0.98) + 0.01
|
| 393 |
+
return float(round(mapped, 4))
|
| 394 |
|
| 395 |
|
| 396 |
def passed(self) -> bool:
|
tasks/medium.py
CHANGED
|
@@ -197,25 +197,27 @@ class MediumTaskGrader:
|
|
| 197 |
if self._max_possible_score <= 0.0:
|
| 198 |
return 0.5
|
| 199 |
|
| 200 |
-
raw =
|
| 201 |
-
|
| 202 |
if self._total_investigations > 0:
|
| 203 |
fp_rate = self._unnecessary_invest / self._total_investigations
|
| 204 |
else:
|
| 205 |
fp_rate = 0.0
|
| 206 |
fp_penalty = _FP_PENALTY_WEIGHT * fp_rate
|
| 207 |
-
|
| 208 |
if self._critical_total > 0:
|
| 209 |
miss_rate = min(self._critical_missed / self._critical_total, 1.0)
|
| 210 |
else:
|
| 211 |
miss_rate = 0.0
|
| 212 |
miss_penalty = _CRITICAL_MISS_PENALTY_WEIGHT * miss_rate
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
#
|
| 218 |
-
|
|
|
|
|
|
|
| 219 |
|
| 220 |
|
| 221 |
def passed(self) -> bool:
|
|
|
|
| 197 |
if self._max_possible_score <= 0.0:
|
| 198 |
return 0.5
|
| 199 |
|
| 200 |
+
raw = self._resolved_score / self._max_possible_score
|
| 201 |
+
|
| 202 |
if self._total_investigations > 0:
|
| 203 |
fp_rate = self._unnecessary_invest / self._total_investigations
|
| 204 |
else:
|
| 205 |
fp_rate = 0.0
|
| 206 |
fp_penalty = _FP_PENALTY_WEIGHT * fp_rate
|
| 207 |
+
|
| 208 |
if self._critical_total > 0:
|
| 209 |
miss_rate = min(self._critical_missed / self._critical_total, 1.0)
|
| 210 |
else:
|
| 211 |
miss_rate = 0.0
|
| 212 |
miss_penalty = _CRITICAL_MISS_PENALTY_WEIGHT * miss_rate
|
| 213 |
+
|
| 214 |
+
# Penalised score is effectively between -0.50 and 1.00
|
| 215 |
+
penalised = raw - fp_penalty - miss_penalty
|
| 216 |
+
|
| 217 |
+
# Math map: penalised * 0.6 is [-0.3, 0.6]
|
| 218 |
+
# + 0.35 yields [0.05, 0.95] which guarantees (0, 1) bounds without clipping.
|
| 219 |
+
mapped = (penalised * 0.6) + 0.35
|
| 220 |
+
return float(round(mapped, 4))
|
| 221 |
|
| 222 |
|
| 223 |
def passed(self) -> bool:
|
tests/test_env.py
CHANGED
|
@@ -107,26 +107,26 @@ class TestTaskConfigurations:
|
|
| 107 |
def test_easy_task_config(self):
|
| 108 |
"""Test easy task has correct configuration."""
|
| 109 |
env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
|
| 110 |
-
|
| 111 |
-
assert env.max_steps ==
|
| 112 |
assert env.max_investigations_per_step is None # No resource constraint
|
| 113 |
-
assert env.failure_threshold ==
|
| 114 |
|
| 115 |
def test_medium_task_config(self):
|
| 116 |
"""Test medium task has resource constraints."""
|
| 117 |
env = AdaptiveAlertTriageEnv(task_id="medium", seed=42)
|
| 118 |
-
|
| 119 |
-
assert env.max_steps ==
|
| 120 |
assert env.max_investigations_per_step == 3 # Resource constrained
|
| 121 |
-
assert env.failure_threshold ==
|
| 122 |
|
| 123 |
def test_hard_task_config(self):
|
| 124 |
"""Test hard task has stricter failure tolerance."""
|
| 125 |
env = AdaptiveAlertTriageEnv(task_id="hard", seed=42)
|
| 126 |
-
|
| 127 |
-
assert env.max_steps ==
|
| 128 |
assert env.max_investigations_per_step == 3
|
| 129 |
-
assert env.failure_threshold ==
|
| 130 |
|
| 131 |
def test_resource_budget_tracking(self):
|
| 132 |
"""Test resource budget is tracked in medium/hard tasks."""
|
|
|
|
| 107 |
def test_easy_task_config(self):
|
| 108 |
"""Test easy task has correct configuration."""
|
| 109 |
env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
|
| 110 |
+
|
| 111 |
+
assert env.max_steps == 10
|
| 112 |
assert env.max_investigations_per_step is None # No resource constraint
|
| 113 |
+
assert env.failure_threshold == 2
|
| 114 |
|
| 115 |
def test_medium_task_config(self):
|
| 116 |
"""Test medium task has resource constraints."""
|
| 117 |
env = AdaptiveAlertTriageEnv(task_id="medium", seed=42)
|
| 118 |
+
|
| 119 |
+
assert env.max_steps == 15
|
| 120 |
assert env.max_investigations_per_step == 3 # Resource constrained
|
| 121 |
+
assert env.failure_threshold == 3
|
| 122 |
|
| 123 |
def test_hard_task_config(self):
|
| 124 |
"""Test hard task has stricter failure tolerance."""
|
| 125 |
env = AdaptiveAlertTriageEnv(task_id="hard", seed=42)
|
| 126 |
+
|
| 127 |
+
assert env.max_steps == 20
|
| 128 |
assert env.max_investigations_per_step == 3
|
| 129 |
+
assert env.failure_threshold == 2 # Stricter
|
| 130 |
|
| 131 |
def test_resource_budget_tracking(self):
|
| 132 |
"""Test resource budget is tracked in medium/hard tasks."""
|
tests/test_integration.py
CHANGED
|
@@ -202,7 +202,7 @@ class TestGraderWithProcessStep:
|
|
| 202 |
assert score == 1.0, "Should be correct for investigating critical"
|
| 203 |
|
| 204 |
final_score = grader.get_episode_score()
|
| 205 |
-
assert final_score ==
|
| 206 |
|
| 207 |
def test_medium_grader_process_step(self):
|
| 208 |
"""Test MediumTaskGrader.process_step() with alert data dict."""
|
|
@@ -237,9 +237,8 @@ class TestGraderWithProcessStep:
|
|
| 237 |
|
| 238 |
contribution = grader.process_step(alert_data, {})
|
| 239 |
|
| 240 |
-
# Should have correlation bonus
|
| 241 |
-
assert
|
| 242 |
-
assert contribution > 0.8, "Should get bonus for correlated alert"
|
| 243 |
|
| 244 |
|
| 245 |
class TestEvaluationIntegration:
|
|
@@ -331,15 +330,14 @@ class TestEvaluationIntegration:
|
|
| 331 |
metrics = grader.get_metrics()
|
| 332 |
|
| 333 |
# Verify grader tracked data
|
| 334 |
-
assert grader.
|
| 335 |
assert score >= 0.0, f"Score should be >= 0, got {score}"
|
| 336 |
|
| 337 |
# Log metrics for debugging
|
| 338 |
print(f"\nHard task metrics:")
|
| 339 |
print(f" Score: {score:.3f}")
|
| 340 |
print(f" Correlated alerts seen: {correlated_alerts_seen}")
|
| 341 |
-
print(f"
|
| 342 |
-
print(f" Total chains: {metrics['total_correlation_chains']}")
|
| 343 |
|
| 344 |
def test_full_evaluation_episode(self):
|
| 345 |
"""Full evaluation episode with all fixes."""
|
|
@@ -422,14 +420,13 @@ class TestCorrelationBonusFiring:
|
|
| 422 |
"correlation_group": 0,
|
| 423 |
}
|
| 424 |
|
| 425 |
-
initial_bonus = grader.correlation_bonus
|
| 426 |
grader.process_step(alert_data, {})
|
| 427 |
|
| 428 |
-
assert grader.
|
| 429 |
-
|
| 430 |
|
| 431 |
# Should also detect the correlation
|
| 432 |
-
assert grader.
|
| 433 |
|
| 434 |
def test_no_bonus_for_non_correlated(self):
|
| 435 |
"""Verify no correlation bonus for non-correlated alerts."""
|
|
@@ -445,7 +442,7 @@ class TestCorrelationBonusFiring:
|
|
| 445 |
|
| 446 |
grader.process_step(alert_data, {})
|
| 447 |
|
| 448 |
-
assert grader.
|
| 449 |
|
| 450 |
|
| 451 |
if __name__ == "__main__":
|
|
|
|
| 202 |
assert score == 1.0, "Should be correct for investigating critical"
|
| 203 |
|
| 204 |
final_score = grader.get_episode_score()
|
| 205 |
+
assert final_score == 0.99, "Episode score should be 0.99 mapped"
|
| 206 |
|
| 207 |
def test_medium_grader_process_step(self):
|
| 208 |
"""Test MediumTaskGrader.process_step() with alert data dict."""
|
|
|
|
| 237 |
|
| 238 |
contribution = grader.process_step(alert_data, {})
|
| 239 |
|
| 240 |
+
# Should have correlation bonus mapped to contribution
|
| 241 |
+
assert contribution >= 0.8, "Should get bonus for correlated alert"
|
|
|
|
| 242 |
|
| 243 |
|
| 244 |
class TestEvaluationIntegration:
|
|
|
|
| 330 |
metrics = grader.get_metrics()
|
| 331 |
|
| 332 |
# Verify grader tracked data
|
| 333 |
+
assert grader._total_actions > 0, "Should have processed actions"
|
| 334 |
assert score >= 0.0, f"Score should be >= 0, got {score}"
|
| 335 |
|
| 336 |
# Log metrics for debugging
|
| 337 |
print(f"\nHard task metrics:")
|
| 338 |
print(f" Score: {score:.3f}")
|
| 339 |
print(f" Correlated alerts seen: {correlated_alerts_seen}")
|
| 340 |
+
print(f" Total chains: {metrics['total_chains']}")
|
|
|
|
| 341 |
|
| 342 |
def test_full_evaluation_episode(self):
|
| 343 |
"""Full evaluation episode with all fixes."""
|
|
|
|
| 420 |
"correlation_group": 0,
|
| 421 |
}
|
| 422 |
|
|
|
|
| 423 |
grader.process_step(alert_data, {})
|
| 424 |
|
| 425 |
+
assert grader.get_metrics()["chain_score"] > 0, \
|
| 426 |
+
"Correlation bonus should increase!"
|
| 427 |
|
| 428 |
# Should also detect the correlation
|
| 429 |
+
assert grader.calculate_correlation_detection_rate() > 0.0, "Should detect correlation"
|
| 430 |
|
| 431 |
def test_no_bonus_for_non_correlated(self):
|
| 432 |
"""Verify no correlation bonus for non-correlated alerts."""
|
|
|
|
| 442 |
|
| 443 |
grader.process_step(alert_data, {})
|
| 444 |
|
| 445 |
+
assert grader.get_metrics()["chain_score"] == 0.0, "No bonus for non-correlated"
|
| 446 |
|
| 447 |
|
| 448 |
if __name__ == "__main__":
|
tests/test_rewards.py
CHANGED
|
@@ -119,7 +119,7 @@ class TestRewardCalculation:
|
|
| 119 |
reward = calculate_reward(action, alert)
|
| 120 |
|
| 121 |
assert reward.value < 0.0, "Should be negative for wasted resources"
|
| 122 |
-
assert reward.components["
|
| 123 |
|
| 124 |
def test_correlated_alert_bonus(self):
|
| 125 |
"""Test bonus for handling correlated alerts."""
|
|
@@ -175,8 +175,8 @@ class TestRewardCalculation:
|
|
| 175 |
is_correlated=False,
|
| 176 |
)
|
| 177 |
action_delay = Action(alert_id="alert_008", action_type="DELAY")
|
| 178 |
-
|
| 179 |
-
reward_medium = calculate_reward(action_delay, alert_medium)
|
| 180 |
assert reward_medium.value >= 0.0, "Delaying medium alert should be acceptable"
|
| 181 |
|
| 182 |
# Delaying critical alert (risky)
|
|
@@ -238,7 +238,7 @@ class TestAuxiliaryFunctions:
|
|
| 238 |
penalty_3 = calculate_system_failure_penalty(3)
|
| 239 |
|
| 240 |
assert penalty_1 < 0.0
|
| 241 |
-
assert penalty_3
|
| 242 |
|
| 243 |
def test_episode_bonus_high_accuracy(self):
|
| 244 |
"""Test episode bonus for high accuracy."""
|
|
@@ -275,14 +275,14 @@ class TestAuxiliaryFunctions:
|
|
| 275 |
|
| 276 |
assert min_r < 0.0, "Min reward should be negative (penalty)"
|
| 277 |
assert max_r > 0.0, "Max reward should be positive"
|
| 278 |
-
assert max_r > abs(min_r), "Max reward magnitude should exceed penalty"
|
| 279 |
|
| 280 |
def test_reward_summary_empty(self):
|
| 281 |
"""Test reward summary with empty list."""
|
| 282 |
summary = create_reward_summary([])
|
| 283 |
|
| 284 |
assert summary["total_reward"] == 0.0
|
| 285 |
-
assert summary["
|
| 286 |
|
| 287 |
def test_reward_summary_aggregation(self):
|
| 288 |
"""Test reward summary aggregates correctly."""
|
|
@@ -299,7 +299,7 @@ class TestAuxiliaryFunctions:
|
|
| 299 |
|
| 300 |
assert summary["total_reward"] == 11.0
|
| 301 |
assert summary["mean_reward"] == 11.0 / 3
|
| 302 |
-
assert summary["
|
| 303 |
assert summary["correct_actions"] == 2
|
| 304 |
assert summary["accuracy"] == 2/3
|
| 305 |
assert "critical_handled" in summary["components"]
|
|
|
|
| 119 |
reward = calculate_reward(action, alert)
|
| 120 |
|
| 121 |
assert reward.value < 0.0, "Should be negative for wasted resources"
|
| 122 |
+
assert reward.components["unnecessary_invest"] < 0.0
|
| 123 |
|
| 124 |
def test_correlated_alert_bonus(self):
|
| 125 |
"""Test bonus for handling correlated alerts."""
|
|
|
|
| 175 |
is_correlated=False,
|
| 176 |
)
|
| 177 |
action_delay = Action(alert_id="alert_008", action_type="DELAY")
|
| 178 |
+
|
| 179 |
+
reward_medium = calculate_reward(action_delay, alert_medium, {"max_investigations": 3})
|
| 180 |
assert reward_medium.value >= 0.0, "Delaying medium alert should be acceptable"
|
| 181 |
|
| 182 |
# Delaying critical alert (risky)
|
|
|
|
| 238 |
penalty_3 = calculate_system_failure_penalty(3)
|
| 239 |
|
| 240 |
assert penalty_1 < 0.0
|
| 241 |
+
assert penalty_3 < penalty_1
|
| 242 |
|
| 243 |
def test_episode_bonus_high_accuracy(self):
|
| 244 |
"""Test episode bonus for high accuracy."""
|
|
|
|
| 275 |
|
| 276 |
assert min_r < 0.0, "Min reward should be negative (penalty)"
|
| 277 |
assert max_r > 0.0, "Max reward should be positive"
|
| 278 |
+
assert max_r >= abs(min_r) - 0.01, "Max reward magnitude should be similar or exceed penalty"
|
| 279 |
|
| 280 |
def test_reward_summary_empty(self):
|
| 281 |
"""Test reward summary with empty list."""
|
| 282 |
summary = create_reward_summary([])
|
| 283 |
|
| 284 |
assert summary["total_reward"] == 0.0
|
| 285 |
+
assert summary["num_steps"] == 0
|
| 286 |
|
| 287 |
def test_reward_summary_aggregation(self):
|
| 288 |
"""Test reward summary aggregates correctly."""
|
|
|
|
| 299 |
|
| 300 |
assert summary["total_reward"] == 11.0
|
| 301 |
assert summary["mean_reward"] == 11.0 / 3
|
| 302 |
+
assert summary["num_steps"] == 3
|
| 303 |
assert summary["correct_actions"] == 2
|
| 304 |
assert summary["accuracy"] == 2/3
|
| 305 |
assert "critical_handled" in summary["components"]
|
tests/test_tasks.py
CHANGED
|
@@ -139,7 +139,7 @@ class TestMediumTaskGrader:
|
|
| 139 |
contribution = grader.grade_action(action, alert, reward)
|
| 140 |
|
| 141 |
assert contribution > 0.0, "High-value investigation should contribute positively"
|
| 142 |
-
assert grader.
|
| 143 |
|
| 144 |
def test_wasteful_investigation(self):
|
| 145 |
"""Test investigation on false positive is penalized."""
|
|
@@ -157,9 +157,8 @@ class TestMediumTaskGrader:
|
|
| 157 |
reward = Reward(value=-2.0)
|
| 158 |
|
| 159 |
contribution = grader.grade_action(action, alert, reward)
|
| 160 |
-
|
| 161 |
-
assert
|
| 162 |
-
assert grader.unnecessary_investigations == 1
|
| 163 |
|
| 164 |
def test_resource_efficiency_calculation(self):
|
| 165 |
"""Test resource efficiency metric."""
|
|
@@ -214,7 +213,7 @@ class TestMediumTaskGrader:
|
|
| 214 |
|
| 215 |
grader.grade_action(action, alert, reward)
|
| 216 |
|
| 217 |
-
assert grader.
|
| 218 |
# Score should be penalized
|
| 219 |
score = grader.get_episode_score()
|
| 220 |
assert score < 0.5, "Missing critical should heavily impact score"
|
|
@@ -243,9 +242,7 @@ class TestHardTaskGrader:
|
|
| 243 |
|
| 244 |
contribution = grader.grade_action(action, alert, reward)
|
| 245 |
|
| 246 |
-
|
| 247 |
-
assert contribution > alert.true_severity, "Should get correlation bonus"
|
| 248 |
-
assert grader.correlation_bonus > 0.0
|
| 249 |
|
| 250 |
def test_failure_prevention_bonus(self):
|
| 251 |
"""Test bonus for preventing cascading failures."""
|
|
|
|
| 139 |
contribution = grader.grade_action(action, alert, reward)
|
| 140 |
|
| 141 |
assert contribution > 0.0, "High-value investigation should contribute positively"
|
| 142 |
+
assert grader._total_investigations == 1
|
| 143 |
|
| 144 |
def test_wasteful_investigation(self):
|
| 145 |
"""Test investigation on false positive is penalized."""
|
|
|
|
| 157 |
reward = Reward(value=-2.0)
|
| 158 |
|
| 159 |
contribution = grader.grade_action(action, alert, reward)
|
| 160 |
+
assert contribution == 0.0, "Wasteful investigation should give zero contribution"
|
| 161 |
+
assert grader._unnecessary_invest == 1
|
|
|
|
| 162 |
|
| 163 |
def test_resource_efficiency_calculation(self):
|
| 164 |
"""Test resource efficiency metric."""
|
|
|
|
| 213 |
|
| 214 |
grader.grade_action(action, alert, reward)
|
| 215 |
|
| 216 |
+
assert grader._critical_missed == 1
|
| 217 |
# Score should be penalized
|
| 218 |
score = grader.get_episode_score()
|
| 219 |
assert score < 0.5, "Missing critical should heavily impact score"
|
|
|
|
| 242 |
|
| 243 |
contribution = grader.grade_action(action, alert, reward)
|
| 244 |
|
| 245 |
+
assert contribution >= alert.true_severity, "Should be rewarded proportionally for chain trigger"
|
|
|
|
|
|
|
| 246 |
|
| 247 |
def test_failure_prevention_bonus(self):
|
| 248 |
"""Test bonus for preventing cascading failures."""
|
uv.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|