File size: 4,242 Bytes
ec8c511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Multi-objective reward engine for the Adaptive AI Firewall environment.

Computes R = α·security + β·availability + γ·efficiency + δ·timeliness
with careful balance to prevent degenerate policies (block-all / allow-all).
"""
from __future__ import annotations

from typing import Dict, Tuple
import math


# Discrete action space: integer action id -> human-readable action name.
# Ids are assigned in tuple order, starting at 0.
ACTIONS = dict(
    enumerate(
        (
            "ALLOW",
            "BLOCK",
            "INSPECT",
            "SANDBOX",
            "RATE_LIMIT",
            "QUARANTINE",
        )
    )
)

# Per-action cost components, keyed by action id.
# Costs tuned so total episode cost stays well within budget range
ACTION_COSTS = {
    0: dict(latency=0.0, compute=0.0),
    1: dict(latency=0.0, compute=0.005),
    2: dict(latency=0.08, compute=0.05),
    3: dict(latency=0.20, compute=0.12),
    4: dict(latency=0.02, compute=0.015),
    5: dict(latency=0.05, compute=0.025),
}

# Actions that are considered "blocking" (remove traffic from the network):
# BLOCK, SANDBOX and QUARANTINE.
BLOCKING_ACTIONS = frozenset((1, 3, 5))
# Actions that are considered "inspection" (gather more info): INSPECT only.
INSPECTION_ACTIONS = frozenset((2,))


class RewardEngine:
    """Weighted multi-objective reward with anti-degeneracy safeguards.

    The scalar reward is::

        alpha * security + beta * availability
        + gamma * efficiency + delta * timeliness

    Raw (unweighted) component values:

    - Security (malicious traffic only): +1.0 for a blocking action,
      -2.0 when the traffic is neither blocked nor inspected (this
      includes ALLOW and RATE_LIMIT), and for INSPECT +0.15 when
      ``inspect_correct`` is true, otherwise -0.5.
    - Availability (benign traffic only): -1.2 for a blocking action
      (false positive), -0.15 for INSPECT, +0.25 for ALLOW and -0.4 for
      RATE_LIMIT.
    - Efficiency: ``-cost / max(budget_remaining, 0.1)``, so the same
      action is penalized more heavily as the budget shrinks.
    - Timeliness: ``exp(-max(attack_phase, 0))`` when malicious traffic
      is blocked, otherwise 0.

    With the default weights a missed attack costs 0.35 * 2.0 = 0.70 and
    a false positive costs 0.30 * 1.2 = 0.36 (ratio ~1.9:1), while a
    correct ALLOW on benign traffic earns 0.30 * 0.25 = 0.075. The
    false-positive penalty is what makes a block-all policy unattractive;
    the miss penalty does the same for allow-all.

    INSPECT keeps the session alive (handled by the environment, not by
    this class).
    """

    # Action ids that are special-cased by the availability component.
    ALLOW_ACTION = 0
    RATE_LIMIT_ACTION = 4

    # Raw security magnitudes (applied only to malicious traffic).
    BLOCKED_ATTACK_REWARD = 1.0
    MISSED_ATTACK_PENALTY = 2.0
    INSPECT_CORRECT_REWARD = 0.15
    INSPECT_INCORRECT_PENALTY = 0.5

    # Raw availability magnitudes (applied only to benign traffic).
    FALSE_POSITIVE_PENALTY = 1.2
    BENIGN_INSPECT_PENALTY = 0.15
    BENIGN_ALLOW_REWARD = 0.25
    BENIGN_RATE_LIMIT_PENALTY = 0.4

    # Floor for the efficiency denominator; avoids division by zero and a
    # sign flip when the remaining budget is zero or negative.
    MIN_BUDGET = 0.1

    def __init__(
        self,
        alpha: float = 0.35,
        beta: float = 0.30,
        gamma: float = 0.20,
        delta: float = 0.15,
    ) -> None:
        """Store the component weights.

        Args:
            alpha: Weight of the security component.
            beta: Weight of the availability component.
            gamma: Weight of the efficiency component.
            delta: Weight of the timeliness component.
        """
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.delta = delta

    def action_cost(self, action: int) -> float:
        """Return the combined latency + compute cost of ``action``.

        Action ids missing from ``ACTION_COSTS`` fall back to the ALLOW
        entry.
        """
        costs = ACTION_COSTS.get(action, ACTION_COSTS[self.ALLOW_ACTION])
        return float(costs["latency"] + costs["compute"])

    def reward(
        self,
        action: int,
        is_malicious: bool,
        budget_remaining: float,
        attack_phase: int,
        inspect_correct: bool = False,
    ) -> Tuple[float, Dict[str, float]]:
        """Score a single action taken on a single session.

        Args:
            action: Action id (see ``ACTIONS``).
            is_malicious: Ground-truth label of the session.
            budget_remaining: Remaining budget; values below 0.1 are
                clamped to 0.1 when scaling the efficiency penalty.
            attack_phase: Kill-chain phase index; negative values are
                treated as 0.
            inspect_correct: Only consulted when ``action`` is an
                inspection action on malicious traffic; selects the
                inspection bonus over the inspection penalty.

        Returns:
            ``(total, components)`` where ``total`` is the weighted sum
            and ``components`` holds the raw ``"security"``,
            ``"availability"``, ``"efficiency"`` and ``"timeliness"``
            values plus the raw action ``"cost"``.
        """
        blocked = action in BLOCKING_ACTIONS
        inspected = action in INSPECTION_ACTIONS

        # --- Security component (malicious traffic only) ---
        r_security = 0.0
        if is_malicious:
            if blocked:
                r_security = self.BLOCKED_ATTACK_REWARD
            elif inspected:
                # Inspect alone doesn't stop the attack but is a valid step.
                if inspect_correct:
                    r_security = self.INSPECT_CORRECT_REWARD
                else:
                    r_security = -self.INSPECT_INCORRECT_PENALTY
            else:
                # Neither blocked nor inspected: the attack got through.
                # RATE_LIMIT lands here too, since it is not a blocking action.
                r_security = -self.MISSED_ATTACK_PENALTY

        # --- Availability component (benign traffic only) ---
        r_availability = 0.0
        if not is_malicious:
            if blocked:
                r_availability -= self.FALSE_POSITIVE_PENALTY
            elif inspected:
                # Unnecessary inspection delay.
                r_availability -= self.BENIGN_INSPECT_PENALTY
            elif action == self.ALLOW_ACTION:
                r_availability += self.BENIGN_ALLOW_REWARD

            # Rate-limiting benign traffic is a milder false positive. Kept
            # outside the chain above so it is applied independently of it.
            if action == self.RATE_LIMIT_ACTION:
                r_availability -= self.BENIGN_RATE_LIMIT_PENALTY

        # --- Efficiency component ---
        cost = self.action_cost(action)
        # Penalize cost relative to remaining budget (bigger penalty as the
        # budget shrinks).
        r_efficiency = -cost / max(budget_remaining, self.MIN_BUDGET)

        # --- Timeliness component ---
        # Bonus decays exponentially with kill-chain phase: 1.0 at phase 0.
        early_bonus = math.exp(-max(attack_phase, 0))
        r_timeliness = early_bonus if (is_malicious and blocked) else 0.0

        total = (
            self.alpha * r_security
            + self.beta * r_availability
            + self.gamma * r_efficiency
            + self.delta * r_timeliness
        )
        components = {
            "security": r_security,
            "availability": r_availability,
            "efficiency": r_efficiency,
            "timeliness": r_timeliness,
            "cost": cost,
        }
        return total, components