Spaces:
Sleeping
Sleeping
| """ | |
| Unit Tests for Task Graders | |
| Tests grading logic for easy, medium, and hard tasks. | |
| """ | |
| import pytest | |
| from adaptive_alert_triage.models import Action, Alert, Reward | |
| from tasks.easy import EasyTaskGrader | |
| from tasks.medium import MediumTaskGrader | |
| from tasks.hard import HardTaskGrader | |
class TestEasyTaskGrader:
    """Exercise EasyTaskGrader: per-action scores, counters, and metrics."""

    def test_critical_alert_correct(self):
        """INVESTIGATE on a critical alert scores 1.0 and bumps both counters."""
        grader = EasyTaskGrader()
        critical_alert = Alert(
            id="alert_001",
            alert_type="CPU",
            age=1,
            confidence=0.9,
            visible_severity=0.85,
            true_severity=0.90,  # Critical
        )

        score = grader.grade_action(
            Action(alert_id="alert_001", action_type="INVESTIGATE"),
            critical_alert,
            Reward(value=10.0),
        )

        assert score == 1.0, "Should get full score for correct action"
        assert grader.correct_actions == 1
        assert grader.total_actions == 1

    def test_critical_alert_incorrect(self):
        """IGNORE on a critical alert scores 0.0 and is not counted as correct."""
        grader = EasyTaskGrader()
        critical_alert = Alert(
            id="alert_002",
            alert_type="SECURITY",
            age=2,
            confidence=0.8,
            visible_severity=0.7,
            true_severity=0.95,  # Critical
        )

        score = grader.grade_action(
            Action(alert_id="alert_002", action_type="IGNORE"),
            critical_alert,
            Reward(value=-8.0),
        )

        assert score == 0.0, "Should get zero score for missed critical"
        assert grader.correct_actions == 0
        assert grader.total_actions == 1

    def test_false_positive_correct(self):
        """IGNORE on a false-positive alert scores 1.0."""
        grader = EasyTaskGrader()
        noise_alert = Alert(
            id="alert_003",
            alert_type="DISK",
            age=0,
            confidence=0.4,
            visible_severity=0.3,
            true_severity=0.15,  # False positive
        )

        score = grader.grade_action(
            Action(alert_id="alert_003", action_type="IGNORE"),
            noise_alert,
            Reward(value=3.0),
        )

        assert score == 1.0, "Should get full score for ignoring FP"
        assert grader.correct_actions == 1

    def test_episode_score_calculation(self):
        """Grade three actions (two marked correct) and expect an episode score near 2/3."""
        grader = EasyTaskGrader()

        # Each entry: (alert, action to take, whether that action is the right call).
        scenario = [
            (
                Alert(id="a1", visible_severity=0.9, confidence=0.9,
                      alert_type="CPU", age=1, true_severity=0.9),
                "INVESTIGATE",
                True,
            ),
            (
                Alert(id="a2", visible_severity=0.3, confidence=0.4,
                      alert_type="DISK", age=0, true_severity=0.2),
                "IGNORE",
                True,
            ),
            (
                Alert(id="a3", visible_severity=0.8, confidence=0.8,
                      alert_type="SECURITY", age=1, true_severity=0.95),
                "IGNORE",
                False,
            ),
        ]

        for current_alert, chosen_action, _is_correct in scenario:
            grader.grade_action(
                Action(alert_id=current_alert.id, action_type=chosen_action),
                current_alert,
                Reward(value=0.0),
            )

        score = grader.get_episode_score()
        assert abs(score - 2/3) < 0.01, f"Expected 0.667, got {score}"

    def test_metrics_breakdown(self):
        """After one graded action, get_metrics() exposes the expected keys."""
        grader = EasyTaskGrader()
        sample_alert = Alert(
            id="a1",
            alert_type="CPU",
            age=1,
            confidence=0.9,
            visible_severity=0.9,
            true_severity=0.9,
        )
        grader.grade_action(
            Action(alert_id="a1", action_type="INVESTIGATE"),
            sample_alert,
            Reward(value=10.0),
        )

        metrics = grader.get_metrics()

        assert "overall_score" in metrics
        assert "correct_actions" in metrics
        assert "critical_accuracy" in metrics
        assert "action_breakdown" in metrics
class TestMediumTaskGrader:
    """Exercise MediumTaskGrader, always built with max_investigations_per_step=3."""

    def test_productive_investigation(self):
        """Investigating a high-severity alert yields a positive contribution."""
        grader = MediumTaskGrader(max_investigations_per_step=3)
        severe_alert = Alert(
            id="alert_001",
            alert_type="CPU",
            age=1,
            confidence=0.9,
            visible_severity=0.85,
            true_severity=0.90,
        )

        contribution = grader.grade_action(
            Action(alert_id="alert_001", action_type="INVESTIGATE"),
            severe_alert,
            Reward(value=10.0),
        )

        assert contribution > 0.0, "High-value investigation should contribute positively"
        assert grader._total_investigations == 1

    def test_wasteful_investigation(self):
        """Investigating a false positive contributes 0.0 and is tallied as unnecessary."""
        grader = MediumTaskGrader(max_investigations_per_step=3)
        noise_alert = Alert(
            id="alert_002",
            alert_type="DISK",
            age=0,
            confidence=0.4,
            visible_severity=0.3,
            true_severity=0.15,  # False positive
        )

        contribution = grader.grade_action(
            Action(alert_id="alert_002", action_type="INVESTIGATE"),
            noise_alert,
            Reward(value=-2.0),
        )

        assert contribution == 0.0, "Wasteful investigation should give zero contribution"
        assert grader._unnecessary_invest == 1

    def test_resource_efficiency_calculation(self):
        """Two productive and one wasteful investigation give an efficiency near 2/3."""
        grader = MediumTaskGrader(max_investigations_per_step=3)

        # Each entry: (true severity, action to take, whether it is productive).
        scenario = [
            (0.9, "INVESTIGATE", True),    # Productive
            (0.8, "INVESTIGATE", True),    # Productive
            (0.15, "INVESTIGATE", False),  # Wasteful
        ]

        for true_sev, chosen_action, _is_productive in scenario:
            current_alert = Alert(
                id=f"a_{true_sev}",
                alert_type="CPU",
                age=1,
                confidence=0.8,
                visible_severity=true_sev,
                true_severity=true_sev,
            )
            grader.grade_action(
                Action(alert_id=current_alert.id, action_type=chosen_action),
                current_alert,
                Reward(value=0.0),
            )

        efficiency = grader.calculate_resource_efficiency()
        assert abs(efficiency - 2/3) < 0.01, f"Expected 0.667, got {efficiency}"

    def test_episode_score_with_efficiency(self):
        """After one graded investigation the episode score lies within [0, 1]."""
        grader = MediumTaskGrader(max_investigations_per_step=3)
        sample_alert = Alert(
            id="a1",
            alert_type="CPU",
            age=1,
            confidence=0.9,
            visible_severity=0.9,
            true_severity=0.9,
        )
        grader.grade_action(
            Action(alert_id="a1", action_type="INVESTIGATE"),
            sample_alert,
            Reward(value=10.0),
        )

        score = grader.get_episode_score()

        assert 0.0 <= score <= 1.0, "Score should be normalized"

    def test_critical_missed_penalty(self):
        """Ignoring a critical alert is recorded and pulls the episode score below 0.5."""
        grader = MediumTaskGrader(max_investigations_per_step=3)
        critical_alert = Alert(
            id="a1",
            alert_type="SECURITY",
            age=1,
            confidence=0.8,
            visible_severity=0.8,
            true_severity=0.95,  # Critical
        )
        grader.grade_action(
            Action(alert_id="a1", action_type="IGNORE"),
            critical_alert,
            Reward(value=-8.0),
        )

        assert grader._critical_missed == 1

        # The miss must also show up in the aggregate score.
        score = grader.get_episode_score()
        assert score < 0.5, "Missing critical should heavily impact score"
class TestHardTaskGrader:
    """Exercise HardTaskGrader: correlation chains, failures, and stability."""

    def test_correlation_detection(self):
        """Investigating a chain member contributes at least the alert's true severity."""
        grader = HardTaskGrader()
        grader.update_correlation_state([["alert_001", "alert_002", "alert_003"]])
        alert = Alert(
            id="alert_001",
            alert_type="CPU",
            age=1,
            confidence=0.85,
            visible_severity=0.8,
            true_severity=0.85,
            is_correlated=True,
        )

        contribution = grader.grade_action(
            Action(alert_id="alert_001", action_type="INVESTIGATE"),
            alert,
            Reward(value=10.0),
        )

        assert contribution >= alert.true_severity, "Should be rewarded proportionally for chain trigger"

    def test_failure_prevention_bonus(self):
        """Investigating the first alert of a chain registers at least one stopped chain."""
        grader = HardTaskGrader()
        grader.update_correlation_state([["alert_001", "alert_002", "alert_003"]])

        # Act on the first alert listed in the chain (early detection).
        chain_head = Alert(
            id="alert_001",
            alert_type="CPU",
            age=1,
            confidence=0.85,
            visible_severity=0.75,
            true_severity=0.80,
            is_correlated=True,
        )
        grader.grade_action(
            Action(alert_id="alert_001", action_type="INVESTIGATE"),
            chain_head,
            Reward(value=10.0),
        )

        m = grader.get_metrics()
        assert m["chains_stopped"] >= 1, "Should register failure prevention"

    def test_system_failure_penalty(self):
        """Recording one failure is counted and drops stability below 1.0."""
        grader = HardTaskGrader()

        grader.record_failures(1)
        assert grader._system_failures == 1

        # With a failure on record the stability score must be reduced.
        stability = grader.calculate_stability_score()
        assert stability < 1.0

    def test_missed_correlated_alert_penalty(self):
        """Ignoring a correlated high-severity alert contributes less than -0.2."""
        grader = HardTaskGrader()
        grader.update_correlation_state([["alert_001", "alert_002"]])
        correlated_alert = Alert(
            id="alert_001",
            alert_type="CPU",
            age=1,
            confidence=0.8,
            visible_severity=0.7,
            true_severity=0.85,
            is_correlated=True,
        )

        contribution = grader.grade_action(
            Action(alert_id="alert_001", action_type="IGNORE"),
            correlated_alert,
            Reward(value=-8.0),
        )

        # A missed correlated critical is expected to yield a negative contribution.
        assert contribution < -0.2, f"Should have extra penalty for correlated miss, got {contribution}"

    def test_correlation_detection_rate(self):
        """Handling one of two registered chains yields a detection rate near 0.5."""
        grader = HardTaskGrader()
        grader.update_correlation_state(
            [
                ["alert_001", "alert_002"],
                ["alert_003", "alert_004"],
            ]
        )

        # Only the first chain gets handled.
        handled_alert = Alert(
            id="alert_001",
            alert_type="CPU",
            age=1,
            confidence=0.85,
            visible_severity=0.8,
            true_severity=0.85,
            is_correlated=True,
        )
        grader.grade_action(
            Action(alert_id="alert_001", action_type="INVESTIGATE"),
            handled_alert,
            Reward(value=0),
        )

        rate = grader.calculate_correlation_detection_rate()
        assert abs(rate - 0.5) < 0.01, "Should detect 50% of chains"

    def test_stability_score_perfect(self):
        """A fresh grader with no recorded failures reports stability of exactly 1.0."""
        stability = HardTaskGrader().calculate_stability_score()
        assert stability == 1.0, "Zero failures should give perfect stability"

    def test_stability_score_degraded(self):
        """Three separately recorded failures push stability below 1.0."""
        grader = HardTaskGrader()

        # Record the failures one call at a time.
        for _attempt in range(3):
            grader.record_failures(1)

        stability = grader.calculate_stability_score()
        assert stability < 1.0, "Failures should reduce stability"
def test_grader_reset():
    """Grade one action with EasyTaskGrader, reset it, and check its state is cleared."""
    grader = EasyTaskGrader()

    # Put the grader into a non-empty state first.
    sample_alert = Alert(
        id="a1",
        alert_type="CPU",
        age=1,
        confidence=0.9,
        visible_severity=0.9,
        true_severity=0.9,
    )
    grader.grade_action(
        Action(alert_id="a1", action_type="INVESTIGATE"),
        sample_alert,
        Reward(value=10.0),
    )
    assert grader.total_actions == 1

    grader.reset()

    # Counters and the action history must all be back to empty.
    assert grader.total_actions == 0
    assert grader.correct_actions == 0
    assert len(grader.action_history) == 0
if __name__ == "__main__":
    # Run this file's tests verbosely when executed directly. pytest.main()
    # returns an exit code; pass it to SystemExit so a failing run gives a
    # non-zero process status instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))