File size: 12,529 Bytes
b14c6e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c18a9d1
b14c6e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c18a9d1
 
b14c6e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c18a9d1
b14c6e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c18a9d1
b14c6e3
 
 
 
 
 
c18a9d1
b14c6e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c18a9d1
b14c6e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
"""
Unit Tests for Reward Calculation

Tests reward shaping logic and component breakdown.
"""

import pytest
from adaptive_alert_triage.models import Action, Alert, Reward
from rewards.reward import (
    calculate_reward,
    calculate_system_failure_penalty,
    calculate_episode_bonus,
    get_reward_range,
    create_reward_summary,
    REWARD_CRITICAL_HANDLED,
    REWARD_FAILURE_PREVENTED,
    REWARD_FALSE_POSITIVE_IGNORED,
    PENALTY_MISSED_CRITICAL,
)


class TestRewardCalculation:
    """Exercise ``calculate_reward`` for individual action/alert pairs.

    Every test builds one ``Alert`` with a chosen ``true_severity`` and one
    ``Action`` that targets it, then inspects the ``value``, ``components``
    and ``info`` of the returned reward.  Equality checks on reward values go
    through ``pytest.approx`` so they do not depend on the exact order of
    floating-point operations inside the reward module.
    """

    def test_critical_investigated_reward(self):
        """Investigating a critical alert yields ``REWARD_CRITICAL_HANDLED``."""
        alert = Alert(
            id="alert_001",
            visible_severity=0.85,
            confidence=0.9,
            alert_type="CPU",
            age=1,
            true_severity=0.90,  # Critical
            is_correlated=False,
        )
        action = Action(alert_id="alert_001", action_type="INVESTIGATE")

        reward = calculate_reward(action, alert)

        assert reward.value == pytest.approx(REWARD_CRITICAL_HANDLED), \
            f"Expected {REWARD_CRITICAL_HANDLED}, got {reward.value}"
        assert reward.components["critical_handled"] == pytest.approx(
            REWARD_CRITICAL_HANDLED
        )
        assert reward.info["action_correct"] is True

    def test_critical_escalated_reward(self):
        """Escalating a critical alert yields roughly 90% of the investigate reward."""
        alert = Alert(
            id="alert_002",
            visible_severity=0.9,
            confidence=0.95,
            alert_type="SECURITY",
            age=2,
            true_severity=0.95,
            is_correlated=False,
        )
        action = Action(alert_id="alert_002", action_type="ESCALATE")

        reward = calculate_reward(action, alert)

        # Expectation encoded by this test: escalation earns 90% of the
        # investigate reward, checked to within 0.01.
        expected = REWARD_CRITICAL_HANDLED * 0.9
        assert abs(reward.value - expected) < 0.01, \
            f"Expected ~{expected}, got {reward.value}"
        assert reward.info["action_correct"] is True

    def test_false_positive_ignored_reward(self):
        """Ignoring a false positive yields ``REWARD_FALSE_POSITIVE_IGNORED``."""
        alert = Alert(
            id="alert_003",
            visible_severity=0.3,
            confidence=0.4,
            alert_type="DISK",
            age=0,
            true_severity=0.15,  # False positive
            is_correlated=False,
        )
        action = Action(alert_id="alert_003", action_type="IGNORE")

        reward = calculate_reward(action, alert)

        assert reward.value == pytest.approx(REWARD_FALSE_POSITIVE_IGNORED), \
            f"Expected {REWARD_FALSE_POSITIVE_IGNORED}, got {reward.value}"
        assert reward.components["false_positive_ignored"] == pytest.approx(
            REWARD_FALSE_POSITIVE_IGNORED
        )
        assert reward.info["action_correct"] is True

    def test_critical_ignored_penalty(self):
        """Ignoring a critical alert yields ``PENALTY_MISSED_CRITICAL``."""
        alert = Alert(
            id="alert_004",
            visible_severity=0.7,
            confidence=0.8,
            alert_type="SECURITY",
            age=2,
            true_severity=0.95,  # Critical
            is_correlated=False,
        )
        action = Action(alert_id="alert_004", action_type="IGNORE")

        reward = calculate_reward(action, alert)

        assert reward.value == pytest.approx(PENALTY_MISSED_CRITICAL), \
            f"Expected {PENALTY_MISSED_CRITICAL}, got {reward.value}"
        assert reward.components["missed_critical"] == pytest.approx(
            PENALTY_MISSED_CRITICAL
        )
        assert reward.info["action_correct"] is False

    def test_unnecessary_investigation_penalty(self):
        """Investigating a false positive yields a negative reward."""
        alert = Alert(
            id="alert_005",
            visible_severity=0.35,
            confidence=0.45,
            alert_type="NETWORK",
            age=0,
            true_severity=0.20,  # False positive
            is_correlated=False,
        )
        action = Action(alert_id="alert_005", action_type="INVESTIGATE")

        reward = calculate_reward(action, alert)

        assert reward.value < 0.0, "Should be negative for wasted resources"
        assert reward.components["unnecessary_invest"] < 0.0

    def test_correlated_alert_bonus(self):
        """Investigating a correlated critical alert adds the failure-prevention bonus."""
        alert = Alert(
            id="alert_006",
            visible_severity=0.8,
            confidence=0.85,
            alert_type="CPU",
            age=1,
            true_severity=0.85,
            is_correlated=True,  # Correlated
        )
        action = Action(alert_id="alert_006", action_type="INVESTIGATE")

        reward = calculate_reward(action, alert)

        # Expected total is the critical reward plus the failure-prevention
        # bonus.  The sum is compared approximately because the reward module
        # may add its components in a different order than this test does.
        expected = REWARD_CRITICAL_HANDLED + REWARD_FAILURE_PREVENTED
        assert reward.value == pytest.approx(expected), \
            f"Expected {expected}, got {reward.value}"
        assert reward.components["failure_prevented"] == pytest.approx(
            REWARD_FAILURE_PREVENTED
        )

    def test_medium_alert_handling(self):
        """Investigating a medium alert yields a positive reward below the critical one."""
        alert = Alert(
            id="alert_007",
            visible_severity=0.6,
            confidence=0.7,
            alert_type="MEMORY",
            age=1,
            true_severity=0.55,  # Medium
            is_correlated=False,
        )
        action = Action(alert_id="alert_007", action_type="INVESTIGATE")

        reward = calculate_reward(action, alert)

        # Only the range is pinned here, not the exact scaling.
        assert 0.0 < reward.value < REWARD_CRITICAL_HANDLED, \
            "Medium alert should get moderate positive reward"
        assert "medium_handled" in reward.components

    def test_delay_action_rewards(self):
        """DELAY is non-negative for a medium alert and negative for a critical one."""
        # Delaying medium alert (acceptable)
        alert_medium = Alert(
            id="alert_008",
            visible_severity=0.5,
            confidence=0.6,
            alert_type="DISK",
            age=0,
            true_severity=0.50,
            is_correlated=False,
        )
        action_delay = Action(alert_id="alert_008", action_type="DELAY")

        # NOTE(review): the third positional argument is supplied only for
        # this call; presumably it is a context/config mapping -- confirm
        # against the signature of calculate_reward.
        reward_medium = calculate_reward(action_delay, alert_medium, {"max_investigations": 3})
        assert reward_medium.value >= 0.0, "Delaying medium alert should be acceptable"

        # Delaying critical alert (risky)
        alert_critical = Alert(
            id="alert_009",
            visible_severity=0.85,
            confidence=0.9,
            alert_type="CPU",
            age=2,
            true_severity=0.90,
            is_correlated=False,
        )
        action_delay_crit = Action(alert_id="alert_009", action_type="DELAY")

        reward_critical = calculate_reward(action_delay_crit, alert_critical)
        assert reward_critical.value < 0.0, "Delaying critical alert should be penalized"


class TestRewardComponents:
    """Check the ``components`` and ``info`` payloads attached to a reward."""

    def test_reward_has_components(self):
        """The reward carries a non-empty component dict that sums to its value."""
        alert = Alert(
            id="a1", visible_severity=0.9, confidence=0.9, alert_type="CPU",
            age=1, true_severity=0.9
        )
        action = Action(alert_id="a1", action_type="INVESTIGATE")

        reward = calculate_reward(action, alert)

        assert isinstance(reward.components, dict)
        assert len(reward.components) > 0
        # Compare approximately: summing several float components can differ
        # from the stored total by rounding error even when both are correct.
        assert sum(reward.components.values()) == pytest.approx(reward.value)

    def test_reward_info_fields(self):
        """The reward's ``info`` dict exposes the expected debugging keys."""
        alert = Alert(
            id="a1", visible_severity=0.9, confidence=0.9, alert_type="CPU",
            age=1, true_severity=0.9
        )
        action = Action(alert_id="a1", action_type="INVESTIGATE")

        reward = calculate_reward(action, alert)

        expected_keys = (
            "alert_id",
            "true_severity",
            "is_critical",
            "is_false_positive",
            "action_correct",
        )
        for key in expected_keys:
            # Name the missing key so a failure is self-explanatory.
            assert key in reward.info, f"reward.info is missing {key!r}"


class TestAuxiliaryFunctions:
    """Cover the helper functions exported next to ``calculate_reward``.

    The tests pin relative ordering (more failures -> lower result, better
    accuracy -> higher bonus) and the aggregate fields returned by
    ``create_reward_summary``.  Derived float values are compared with
    ``pytest.approx``.
    """

    def test_system_failure_penalty(self):
        """The failure penalty is negative and more negative for 3 than for 1."""
        penalty_1 = calculate_system_failure_penalty(1)
        penalty_3 = calculate_system_failure_penalty(3)

        assert penalty_1 < 0.0
        assert penalty_3 < penalty_1

    def test_episode_bonus_high_accuracy(self):
        """85 correct out of 100 with no failures earns a positive bonus."""
        bonus = calculate_episode_bonus(correct_actions=85, total_actions=100, failures_count=0)

        assert bonus > 0.0, "High accuracy should give bonus"

    def test_episode_bonus_perfect(self):
        """A perfect episode earns a strictly larger bonus than an 85% one."""
        bonus_perfect = calculate_episode_bonus(
            correct_actions=100, total_actions=100, failures_count=0
        )

        bonus_high = calculate_episode_bonus(
            correct_actions=85, total_actions=100, failures_count=0
        )

        assert bonus_perfect > bonus_high, "Perfect should get higher bonus"

    def test_episode_bonus_with_failures(self):
        """At equal accuracy, an episode with failures earns a smaller bonus."""
        bonus_no_fail = calculate_episode_bonus(
            correct_actions=80, total_actions=100, failures_count=0
        )
        bonus_with_fail = calculate_episode_bonus(
            correct_actions=80, total_actions=100, failures_count=2
        )

        assert bonus_no_fail > bonus_with_fail, "Failures should reduce bonus"

    def test_reward_range(self):
        """The reward range spans negative to positive, with max at least |min| - 0.01."""
        min_r, max_r = get_reward_range()

        assert min_r < 0.0, "Min reward should be negative (penalty)"
        assert max_r > 0.0, "Max reward should be positive"
        assert max_r >= abs(min_r) - 0.01, "Max reward magnitude should be similar or exceed penalty"

    def test_reward_summary_empty(self):
        """Summarising an empty list reports zero total reward and zero steps."""
        summary = create_reward_summary([])

        assert summary["total_reward"] == 0.0
        assert summary["num_steps"] == 0

    def test_reward_summary_aggregation(self):
        """Summarising three rewards reports total, mean, step count and accuracy."""
        rewards = [
            Reward(value=10.0, components={"critical_handled": 10.0},
                   info={"action_correct": True}),
            Reward(value=3.0, components={"false_positive_ignored": 3.0},
                   info={"action_correct": True}),
            # Component key spelled as in test_unnecessary_investigation_penalty,
            # which asserts "unnecessary_invest" on a real calculate_reward result.
            Reward(value=-2.0, components={"unnecessary_invest": -2.0},
                   info={"action_correct": False}),
        ]

        summary = create_reward_summary(rewards)

        # Totals and ratios are derived floats; compare them approximately so
        # the test does not depend on how the summary computes them.
        assert summary["total_reward"] == pytest.approx(11.0)
        assert summary["mean_reward"] == pytest.approx(11.0 / 3)
        assert summary["num_steps"] == 3
        assert summary["correct_actions"] == 2
        assert summary["accuracy"] == pytest.approx(2 / 3)
        assert "critical_handled" in summary["components"]


class TestRewardConsistency:
    """Determinism and action-coverage checks for ``calculate_reward``."""

    def test_same_input_same_reward(self):
        """Two calls with the same action and alert return equal reward values."""
        cpu_alert = Alert(
            id="a1",
            visible_severity=0.9,
            confidence=0.9,
            alert_type="CPU",
            age=1,
            true_severity=0.9,
        )
        investigate = Action(alert_id="a1", action_type="INVESTIGATE")

        # Evaluate the identical pair twice, in order, and compare the values.
        values = [calculate_reward(investigate, cpu_alert).value for _ in range(2)]

        assert values[0] == values[1]

    def test_all_action_types_covered(self):
        """Each of the four action types yields a reward whose value is a float."""
        shared_alert = Alert(
            id="a1",
            visible_severity=0.6,
            confidence=0.7,
            alert_type="CPU",
            age=1,
            true_severity=0.6,
        )

        for action_type in ("INVESTIGATE", "IGNORE", "ESCALATE", "DELAY"):
            outcome = calculate_reward(
                Action(alert_id="a1", action_type=action_type),
                shared_alert,
            )

            assert isinstance(outcome.value, float), \
                f"Action {action_type} should return numeric reward"


if __name__ == "__main__":
    # pytest.main returns the exit code of the test run; pass it on to the
    # interpreter so that running this file as a script exits non-zero when
    # any test fails instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))