Spaces:
Sleeping
Sleeping
| """ | |
| Unit Tests for Reward Calculation | |
| Tests reward shaping logic and component breakdown. | |
| """ | |
| import pytest | |
| from adaptive_alert_triage.models import Action, Alert, Reward | |
| from rewards.reward import ( | |
| calculate_reward, | |
| calculate_system_failure_penalty, | |
| calculate_episode_bonus, | |
| get_reward_range, | |
| create_reward_summary, | |
| REWARD_CRITICAL_HANDLED, | |
| REWARD_FAILURE_PREVENTED, | |
| REWARD_FALSE_POSITIVE_IGNORED, | |
| PENALTY_MISSED_CRITICAL, | |
| ) | |
class TestRewardCalculation:
    """Test core reward calculation logic.

    Each test builds one ``Alert``/``Action`` pair, runs ``calculate_reward``
    and checks the resulting ``Reward``. Float-valued rewards are compared
    with ``pytest.approx`` rather than ``==`` so the tests do not depend on
    the exact sequence of floating-point operations inside the reward code.
    """

    def test_critical_investigated_reward(self):
        """Test reward for correctly investigating critical alert."""
        alert = Alert(
            id="alert_001",
            visible_severity=0.85,
            confidence=0.9,
            alert_type="CPU",
            age=1,
            true_severity=0.90,  # Critical
            is_correlated=False,
        )
        action = Action(alert_id="alert_001", action_type="INVESTIGATE")

        reward = calculate_reward(action, alert)

        assert reward.value == pytest.approx(REWARD_CRITICAL_HANDLED), \
            f"Expected {REWARD_CRITICAL_HANDLED}, got {reward.value}"
        assert reward.components["critical_handled"] == pytest.approx(
            REWARD_CRITICAL_HANDLED
        )
        assert reward.info["action_correct"] is True

    def test_critical_escalated_reward(self):
        """Test reward for escalating critical alert."""
        alert = Alert(
            id="alert_002",
            visible_severity=0.9,
            confidence=0.95,
            alert_type="SECURITY",
            age=2,
            true_severity=0.95,
            is_correlated=False,
        )
        action = Action(alert_id="alert_002", action_type="ESCALATE")

        reward = calculate_reward(action, alert)

        # Escalate gets 90% of investigate reward
        expected = REWARD_CRITICAL_HANDLED * 0.9
        # Keep the original 0.01 absolute tolerance for this scaled value.
        assert reward.value == pytest.approx(expected, abs=0.01), \
            f"Expected ~{expected}, got {reward.value}"
        assert reward.info["action_correct"] is True

    def test_false_positive_ignored_reward(self):
        """Test reward for correctly ignoring false positive."""
        alert = Alert(
            id="alert_003",
            visible_severity=0.3,
            confidence=0.4,
            alert_type="DISK",
            age=0,
            true_severity=0.15,  # False positive
            is_correlated=False,
        )
        action = Action(alert_id="alert_003", action_type="IGNORE")

        reward = calculate_reward(action, alert)

        assert reward.value == pytest.approx(REWARD_FALSE_POSITIVE_IGNORED), \
            f"Expected {REWARD_FALSE_POSITIVE_IGNORED}, got {reward.value}"
        assert reward.components["false_positive_ignored"] == pytest.approx(
            REWARD_FALSE_POSITIVE_IGNORED
        )
        assert reward.info["action_correct"] is True

    def test_critical_ignored_penalty(self):
        """Test penalty for ignoring critical alert."""
        alert = Alert(
            id="alert_004",
            visible_severity=0.7,
            confidence=0.8,
            alert_type="SECURITY",
            age=2,
            true_severity=0.95,  # Critical
            is_correlated=False,
        )
        action = Action(alert_id="alert_004", action_type="IGNORE")

        reward = calculate_reward(action, alert)

        assert reward.value == pytest.approx(PENALTY_MISSED_CRITICAL), \
            f"Expected {PENALTY_MISSED_CRITICAL}, got {reward.value}"
        assert reward.components["missed_critical"] == pytest.approx(
            PENALTY_MISSED_CRITICAL
        )
        assert reward.info["action_correct"] is False

    def test_unnecessary_investigation_penalty(self):
        """Test penalty for investigating false positive."""
        alert = Alert(
            id="alert_005",
            visible_severity=0.35,
            confidence=0.45,
            alert_type="NETWORK",
            age=0,
            true_severity=0.20,  # False positive
            is_correlated=False,
        )
        action = Action(alert_id="alert_005", action_type="INVESTIGATE")

        reward = calculate_reward(action, alert)

        # Only the sign is pinned here; the magnitude is left to the
        # implementation.
        assert reward.value < 0.0, "Should be negative for wasted resources"
        assert reward.components["unnecessary_invest"] < 0.0

    def test_correlated_alert_bonus(self):
        """Test bonus for handling correlated alerts."""
        alert = Alert(
            id="alert_006",
            visible_severity=0.8,
            confidence=0.85,
            alert_type="CPU",
            age=1,
            true_severity=0.85,
            is_correlated=True,  # Correlated
        )
        action = Action(alert_id="alert_006", action_type="INVESTIGATE")

        reward = calculate_reward(action, alert)

        # Should get critical + failure prevention bonus
        expected = REWARD_CRITICAL_HANDLED + REWARD_FAILURE_PREVENTED
        assert reward.value == pytest.approx(expected), \
            f"Expected {expected}, got {reward.value}"
        assert reward.components["failure_prevented"] == pytest.approx(
            REWARD_FAILURE_PREVENTED
        )

    def test_medium_alert_handling(self):
        """Test reward for medium severity alerts."""
        alert = Alert(
            id="alert_007",
            visible_severity=0.6,
            confidence=0.7,
            alert_type="MEMORY",
            age=1,
            true_severity=0.55,  # Medium
            is_correlated=False,
        )
        action = Action(alert_id="alert_007", action_type="INVESTIGATE")

        reward = calculate_reward(action, alert)

        # Medium alerts get scaled reward based on severity
        assert 0.0 < reward.value < REWARD_CRITICAL_HANDLED, \
            "Medium alert should get moderate positive reward"
        assert "medium_handled" in reward.components

    def test_delay_action_rewards(self):
        """Test rewards for DELAY action."""
        # Delaying medium alert (acceptable)
        alert_medium = Alert(
            id="alert_008",
            visible_severity=0.5,
            confidence=0.6,
            alert_type="DISK",
            age=0,
            true_severity=0.50,
            is_correlated=False,
        )
        action_delay = Action(alert_id="alert_008", action_type="DELAY")

        # This call passes an explicit context dict as the third argument.
        reward_medium = calculate_reward(
            action_delay, alert_medium, {"max_investigations": 3}
        )

        assert reward_medium.value >= 0.0, "Delaying medium alert should be acceptable"

        # Delaying critical alert (risky)
        alert_critical = Alert(
            id="alert_009",
            visible_severity=0.85,
            confidence=0.9,
            alert_type="CPU",
            age=2,
            true_severity=0.90,
            is_correlated=False,
        )
        action_delay_crit = Action(alert_id="alert_009", action_type="DELAY")

        reward_critical = calculate_reward(action_delay_crit, alert_critical)

        assert reward_critical.value < 0.0, "Delaying critical alert should be penalized"
class TestRewardComponents:
    """Test reward component breakdown and the debug info on a ``Reward``."""

    def test_reward_has_components(self):
        """Test that all rewards include component breakdown."""
        alert = Alert(
            id="a1", visible_severity=0.9, confidence=0.9, alert_type="CPU",
            age=1, true_severity=0.9
        )
        action = Action(alert_id="a1", action_type="INVESTIGATE")

        reward = calculate_reward(action, alert)

        assert isinstance(reward.components, dict)
        assert len(reward.components) > 0
        # Re-summing the float components can differ from the stored total by
        # rounding error, so compare approximately instead of with ``==``.
        assert sum(reward.components.values()) == pytest.approx(reward.value)

    def test_reward_info_fields(self):
        """Test that reward info contains useful debugging information."""
        alert = Alert(
            id="a1", visible_severity=0.9, confidence=0.9, alert_type="CPU",
            age=1, true_severity=0.9
        )
        action = Action(alert_id="a1", action_type="INVESTIGATE")

        reward = calculate_reward(action, alert)

        # Only the presence of each key is checked, not its value.
        assert "alert_id" in reward.info
        assert "true_severity" in reward.info
        assert "is_critical" in reward.info
        assert "is_false_positive" in reward.info
        assert "action_correct" in reward.info
class TestAuxiliaryFunctions:
    """Test auxiliary reward functions.

    Covers the system-failure penalty, the end-of-episode bonus, the
    advertised reward range and the reward summary helper.
    """

    def test_system_failure_penalty(self):
        """Test system failure penalty calculation."""
        penalty_1 = calculate_system_failure_penalty(1)
        penalty_3 = calculate_system_failure_penalty(3)

        assert penalty_1 < 0.0
        # More failures must produce a strictly larger (more negative) penalty.
        assert penalty_3 < penalty_1

    def test_episode_bonus_high_accuracy(self):
        """Test episode bonus for high accuracy."""
        bonus = calculate_episode_bonus(correct_actions=85, total_actions=100, failures_count=0)

        assert bonus > 0.0, "High accuracy should give bonus"

    def test_episode_bonus_perfect(self):
        """Test episode bonus for perfect performance."""
        bonus_perfect = calculate_episode_bonus(
            correct_actions=100, total_actions=100, failures_count=0
        )
        bonus_high = calculate_episode_bonus(
            correct_actions=85, total_actions=100, failures_count=0
        )

        assert bonus_perfect > bonus_high, "Perfect should get higher bonus"

    def test_episode_bonus_with_failures(self):
        """Test that failures reduce episode bonus."""
        bonus_no_fail = calculate_episode_bonus(
            correct_actions=80, total_actions=100, failures_count=0
        )
        bonus_with_fail = calculate_episode_bonus(
            correct_actions=80, total_actions=100, failures_count=2
        )

        assert bonus_no_fail > bonus_with_fail, "Failures should reduce bonus"

    def test_reward_range(self):
        """Test reward range calculation."""
        min_r, max_r = get_reward_range()

        assert min_r < 0.0, "Min reward should be negative (penalty)"
        assert max_r > 0.0, "Max reward should be positive"
        assert max_r >= abs(min_r) - 0.01, "Max reward magnitude should be similar or exceed penalty"

    def test_reward_summary_empty(self):
        """Test reward summary with empty list."""
        summary = create_reward_summary([])

        assert summary["total_reward"] == 0.0
        assert summary["num_steps"] == 0

    def test_reward_summary_aggregation(self):
        """Test reward summary aggregates correctly."""
        rewards = [
            Reward(value=10.0, components={"critical_handled": 10.0},
                   info={"action_correct": True}),
            Reward(value=3.0, components={"false_positive_ignored": 3.0},
                   info={"action_correct": True}),
            Reward(value=-2.0, components={"unnecessary_investigation": -2.0},
                   info={"action_correct": False}),
        ]

        summary = create_reward_summary(rewards)

        # Totals, means and ratios are computed floats: compare approximately
        # so the test does not depend on how the summary performs the
        # arithmetic. Counts stay exact.
        assert summary["total_reward"] == pytest.approx(11.0)
        assert summary["mean_reward"] == pytest.approx(11.0 / 3)
        assert summary["num_steps"] == 3
        assert summary["correct_actions"] == 2
        assert summary["accuracy"] == pytest.approx(2 / 3)
        assert "critical_handled" in summary["components"]
class TestRewardConsistency:
    """Test consistency and edge cases."""

    def test_same_input_same_reward(self):
        """Test deterministic reward calculation."""
        alert = Alert(
            id="a1", visible_severity=0.9, confidence=0.9, alert_type="CPU",
            age=1, true_severity=0.9
        )
        action = Action(alert_id="a1", action_type="INVESTIGATE")

        reward1 = calculate_reward(action, alert)
        reward2 = calculate_reward(action, alert)

        assert reward1.value == reward2.value
        # The breakdown must repeat as well as the total.
        assert reward1.components == reward2.components

    # Parametrized so each action type is reported as its own test case; a
    # plain loop would stop at the first failing type and hide the others.
    @pytest.mark.parametrize(
        "action_type", ["INVESTIGATE", "IGNORE", "ESCALATE", "DELAY"]
    )
    def test_all_action_types_covered(self, action_type):
        """Test that all action types produce rewards."""
        alert = Alert(
            id="a1", visible_severity=0.6, confidence=0.7, alert_type="CPU",
            age=1, true_severity=0.6
        )
        action = Action(alert_id="a1", action_type=action_type)

        reward = calculate_reward(action, alert)

        assert isinstance(reward.value, float), \
            f"Action {action_type} should return numeric reward"
if __name__ == "__main__":
    # pytest.main() returns the run's exit code. Propagate it so that running
    # this file directly exits non-zero when a test fails, instead of always
    # exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))