| """Multi-objective reward engine for the Adaptive AI Firewall environment. |
| |
| Computes R = α·security + β·availability + γ·efficiency + δ·timeliness |
| with careful balance to prevent degenerate policies (block-all / allow-all). |
| """ |
| from __future__ import annotations |
|
|
| from typing import Dict, Tuple |
| import math |
|
|
|
|
# Discrete action space: integer action id -> human-readable action name.
# Ids are assigned by position in the tuple below (ALLOW is 0, QUARANTINE is 5).
ACTIONS = dict(
    enumerate(
        (
            "ALLOW",
            "BLOCK",
            "INSPECT",
            "SANDBOX",
            "RATE_LIMIT",
            "QUARANTINE",
        )
    )
)
|
|
| |
# Per-action overhead, keyed by action id. Each entry has a "latency" and a
# "compute" figure; RewardEngine.action_cost adds the two together.
ACTION_COSTS = {
    action_id: {"latency": latency, "compute": compute}
    for action_id, (latency, compute) in enumerate(
        (
            (0.0, 0.0),  # ALLOW
            (0.0, 0.005),  # BLOCK
            (0.08, 0.05),  # INSPECT
            (0.20, 0.12),  # SANDBOX
            (0.02, 0.015),  # RATE_LIMIT
            (0.05, 0.025),  # QUARANTINE
        )
    )
}
|
|
| |
# Action ids that RewardEngine.reward scores as blocking the session:
# BLOCK (1), SANDBOX (3) and QUARANTINE (5).
BLOCKING_ACTIONS = frozenset((1, 3, 5))

# Action ids that RewardEngine.reward scores as inspection: INSPECT (2).
INSPECTION_ACTIONS = frozenset((2,))
|
|
|
|
class RewardEngine:
    """Weighted multi-objective reward with anti-degeneracy safeguards.

    The scalar reward is::

        R = alpha * security + beta * availability
            + gamma * efficiency + delta * timeliness

    Raw (unweighted) component values per step, taken from the class
    constants below:

    - security (malicious traffic only):
        * blocking action (BLOCK / SANDBOX / QUARANTINE): +1.0
        * INSPECT: +0.15 if ``inspect_correct`` else -0.5
        * anything else (ALLOW, RATE_LIMIT, unknown ids): -2.0 (a miss)
    - availability (benign traffic only):
        * blocking action: -1.2 (false positive)
        * INSPECT: -0.15
        * ALLOW: +0.25
        * RATE_LIMIT: -0.4
    - efficiency: ``-(latency + compute) / max(budget_remaining, 0.1)``
    - timeliness: ``exp(-max(attack_phase, 0))`` when malicious traffic is
      blocked, otherwise 0.

    Design notes (with the default weights):

    - A miss contributes 0.35 * -2.0 = -0.70. A false positive contributes
      0.30 * -1.2 = -0.36. Missing an attack therefore costs roughly twice as
      much as wrongly blocking a benign session, and a block-all policy still
      pays about -0.36 on every benign session.
    - A correct ALLOW of benign traffic contributes 0.30 * 0.25 = +0.075, a
      dense positive signal.
    - INSPECT is a hedge. On malicious traffic it scores far better than a
      miss but worse than a block. On benign traffic it costs only a small
      availability penalty. (The original design notes say the environment
      keeps inspected sessions alive; that is not visible in this module.)
    - Timeliness decays exponentially with kill-chain phase, so earlier
      blocks earn more.

    NOTE(review): RATE_LIMIT scores a full miss on malicious traffic, the
    same as ALLOW. It scores -0.4 on benign traffic, where ALLOW earns +0.25.
    It also has a higher action cost. Under this reward it therefore never
    scores better than ALLOW. Confirm this is intended.
    """

    # Raw component values. They are kept as named constants so the docstring
    # and the arithmetic cannot silently drift apart again. Subclasses may
    # override them.
    DETECT_REWARD: float = 1.0  # malicious traffic blocked
    MISS_PENALTY: float = 2.0  # malicious traffic let through
    INSPECT_HIT_REWARD: float = 0.15  # malicious, INSPECT, correct
    INSPECT_MISS_PENALTY: float = 0.5  # malicious, INSPECT, incorrect
    FALSE_POSITIVE_PENALTY: float = 1.2  # benign traffic blocked
    BENIGN_INSPECT_PENALTY: float = 0.15  # benign traffic inspected
    BENIGN_ALLOW_REWARD: float = 0.25  # benign traffic allowed
    BENIGN_RATE_LIMIT_PENALTY: float = 0.4  # benign traffic rate-limited
    MIN_BUDGET: float = 0.1  # floor on the efficiency-term denominator

    def __init__(
        self,
        alpha: float = 0.35,
        beta: float = 0.30,
        gamma: float = 0.20,
        delta: float = 0.15,
    ) -> None:
        """Store the component weights.

        Args:
            alpha: Weight of the security component.
            beta: Weight of the availability component.
            gamma: Weight of the efficiency component.
            delta: Weight of the timeliness component.

        The defaults sum to 1.0. No normalisation or validation is applied.
        """
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.delta = delta

    def action_cost(self, action: int) -> float:
        """Return the latency plus compute cost of ``action``.

        Unknown action ids fall back to the cost of ALLOW (id 0).
        """
        costs = ACTION_COSTS.get(action, ACTION_COSTS[0])
        return float(costs["latency"] + costs["compute"])

    def reward(
        self,
        action: int,
        is_malicious: bool,
        budget_remaining: float,
        attack_phase: int,
        inspect_correct: bool = False,
    ) -> Tuple[float, Dict[str, float]]:
        """Score a single firewall decision.

        Args:
            action: Action id. Ids in neither ``BLOCKING_ACTIONS`` nor
                ``INSPECTION_ACTIONS`` are treated as letting the traffic
                through.
            is_malicious: Ground-truth label of the session.
            budget_remaining: Remaining resource budget. It is clamped below
                at ``MIN_BUDGET`` in the efficiency term, so a zero or
                negative budget does not divide by zero.
            attack_phase: Kill-chain phase. Negative values are treated as 0.
            inspect_correct: Only consulted when ``action`` is an inspection
                action on malicious traffic.

        Returns:
            ``(total, components)``. ``total`` is the weighted sum.
            ``components`` holds the raw ``security``, ``availability``,
            ``efficiency`` and ``timeliness`` values plus the action
            ``cost``.
        """
        blocked = action in BLOCKING_ACTIONS
        inspected = action in INSPECTION_ACTIONS

        # Exactly one of security / availability can be non-zero, depending
        # on the ground-truth label.
        r_security = 0.0
        r_availability = 0.0
        if is_malicious:
            if blocked:
                r_security += self.DETECT_REWARD
            elif inspected:
                r_security += (
                    self.INSPECT_HIT_REWARD
                    if inspect_correct
                    else -self.INSPECT_MISS_PENALTY
                )
            else:
                # ALLOW, RATE_LIMIT and unknown ids all let the attack through.
                r_security -= self.MISS_PENALTY
        else:
            if blocked:
                r_availability -= self.FALSE_POSITIVE_PENALTY
            elif inspected:
                r_availability -= self.BENIGN_INSPECT_PENALTY
            elif action == 0:  # ALLOW
                r_availability += self.BENIGN_ALLOW_REWARD
            elif action == 4:  # RATE_LIMIT
                r_availability -= self.BENIGN_RATE_LIMIT_PENALTY

        cost = self.action_cost(action)
        # Cost is scaled up as the budget shrinks. The floor avoids division
        # by zero or by a negative budget.
        r_efficiency = -cost / max(budget_remaining, self.MIN_BUDGET)

        # Earlier kill-chain phases earn a larger bonus for a successful block.
        early_bonus = math.exp(-max(attack_phase, 0))
        r_timeliness = early_bonus if (is_malicious and blocked) else 0.0

        total = (
            self.alpha * r_security
            + self.beta * r_availability
            + self.gamma * r_efficiency
            + self.delta * r_timeliness
        )
        return total, {
            "security": r_security,
            "availability": r_availability,
            "efficiency": r_efficiency,
            "timeliness": r_timeliness,
            "cost": cost,
        }
|
|