File size: 20,219 Bytes
ec8c511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np

# Updated imports to reflect new structure
from server.utils.reward_engine import (
    ACTIONS, BLOCKING_ACTIONS, RewardEngine,
)
from server.utils.threat_engine import ThreatEngine
from server.utils.data_loader import (
    FEATURE_ORDER, TrafficGenerator,
)


# Per-difficulty episode settings read by FirewallEnvironment:
#   max_steps          -- number of steps before the episode is done
#   benign_ratio       -- multiplies traffic_lambda to give the Poisson mean of
#                         benign sessions spawned per tick
#   threat_probability -- passed to ThreatEngine.maybe_spawn_attacker each tick
#   traffic_lambda     -- base Poisson rate of sessions per tick
#   budget             -- starting budget (a constructor override may only raise it)
TASK_CONFIGS = {
    "easy": dict(
        max_steps=200,
        benign_ratio=0.80,
        threat_probability=0.12,
        traffic_lambda=5,
        budget=100.0,  # 100 / 200 steps  -> ~0.50 budget per step
    ),
    "medium": dict(
        max_steps=500,
        benign_ratio=0.65,
        threat_probability=0.22,
        traffic_lambda=6,
        budget=300.0,  # 300 / 500 steps  -> ~0.60 budget per step
    ),
    "hard": dict(
        max_steps=1000,
        benign_ratio=0.70,
        threat_probability=0.30,
        traffic_lambda=7,
        budget=600.0,  # 600 / 1000 steps -> ~0.60 budget per step
    ),
}

# Size of the discrete action space and length of each per-session
# observation vector, derived from the shared reward / data-loader tables.
NUM_ACTIONS, OBS_DIM = len(ACTIONS), len(FEATURE_ORDER)


@dataclass()
class EpisodeMetrics:
    """Per-episode counters used for grading.

    Every counter starts at zero; the environment resets them by creating a
    fresh instance. FirewallEnvironment.get_network_stats derives its rates
    from these fields.
    """

    # -- malicious traffic resolved by a terminal action --
    detections: int = 0  # malicious sessions resolved with a blocking action
    malicious_seen: int = 0  # all malicious sessions resolved terminally
    # -- benign traffic resolved by a terminal action --
    false_positives: int = 0  # benign sessions resolved with a blocking action
    benign_seen: int = 0  # all benign sessions resolved terminally
    # -- timeliness / impact --
    early_detection_sum: float = 0.0  # sum of exp(-attack_phase) over detections
    cascade_failures: int = 0  # attacks at phase >= 2 left unblocked or expired
    # -- spend --
    total_cost: float = 0.0  # cumulative action cost charged against the budget
    # -- sessions that timed out before receiving a terminal action --
    sessions_expired_malicious: int = 0
    sessions_expired_benign: int = 0
    # -- other outcomes --
    correct_allows: int = 0  # benign sessions explicitly ALLOWed (action 0)
    inspections: int = 0  # INSPECT actions that kept a session alive


class FirewallEnvironment:
    """Adaptive AI Firewall RL environment.

    OpenEnv-compatible: reset(), step(), state()

    Key design (from RL perspective):
      - Observation: normalized feature vector per session of length OBS_DIM
        (== len(FEATURE_ORDER); documented as 22 features scaled to [0, 1])
      - Action: Discrete(NUM_ACTIONS) — ALLOW, BLOCK, INSPECT, SANDBOX,
        RATE_LIMIT, QUARANTINE
      - Reward: multi-objective (security + availability + efficiency + timeliness)
      - Done: when max_steps reached or budget depleted
      - INSPECT keeps session alive for a second action (two-phase decision).
        Only the first INSPECT does so; a repeat INSPECT on an already
        inspected session is resolved like any other terminal action.

    Two driving modes are offered:
      - ``step(action_map)`` acts on many sessions per tick.
      - ``step_single(action)`` acts on one queued session per call and
        advances the tick only once the queue has been drained.
    """

    def __init__(self, seed: int = 0, budget: Optional[float] = None) -> None:
        """Create the environment with deterministic sub-component seeds.

        Args:
            seed: Base seed. The traffic generator uses ``seed``, the threat
                engine ``seed + 1`` and the spawn RNG ``seed + 2``.
            budget: Optional budget override. It can only raise a task's
                configured budget (the larger of the two is used).
        """
        self.base_seed = seed
        self.base_budget_override = budget
        self.generator = TrafficGenerator(seed=seed)
        self.threat_engine = ThreatEngine(seed=seed + 1)
        self.reward_engine = RewardEngine()
        self.rng = np.random.default_rng(seed + 2)

        self.episode_id = 0
        self.step_count = 0
        self.current_tick = 0
        self.task = "easy"
        self.max_steps = TASK_CONFIGS[self.task]["max_steps"]
        default_budget = TASK_CONFIGS[self.task]["budget"]
        if self.base_budget_override is not None:
            default_budget = max(default_budget, float(self.base_budget_override))
        self.budget_remaining = default_budget
        self.initial_budget = self.budget_remaining
        self.total_reward = 0.0

        # Every inspected session is also kept in pending_sessions (see
        # _apply_action), so inspected ids are a subset of pending ids.
        self.pending_sessions: Dict[str, Dict] = {}
        self.inspected_sessions: Dict[str, Dict] = {}  # sessions awaiting 2nd action
        self.action_log: List[Dict] = []
        # Attackers blocked since the last spawn; handed to the threat engine
        # by the next _spawn_sessions() call and then cleared.
        self._blocked_attacker_ids: Set[str] = set()
        self.metrics = EpisodeMetrics()

        # For single-session mode: ids still to be presented by step_single().
        self._session_queue: List[str] = []

    # ══════════════════════════════════════════════════════════════════
    # OpenEnv API
    # ══════════════════════════════════════════════════════════════════

    def reset(self, task: str = "easy", seed: Optional[int] = None) -> Dict:
        """Reset environment for a new episode.

        Args:
            task: One of the keys of ``TASK_CONFIGS``.
            seed: Seed for this episode; defaults to the constructor seed.
                Generator, threat engine and RNG are all re-created from it.

        Returns:
            The initial state dict (see :meth:`state`).

        Raises:
            ValueError: if ``task`` is not a known task name.
        """
        if task not in TASK_CONFIGS:
            raise ValueError(f"unknown task: {task}")

        used_seed = self.base_seed if seed is None else seed
        self.generator = TrafficGenerator(seed=used_seed)
        self.threat_engine = ThreatEngine(seed=used_seed + 1)
        self.rng = np.random.default_rng(used_seed + 2)

        self.episode_id += 1
        self.step_count = 0
        self.current_tick = 0
        self.task = task
        config = TASK_CONFIGS[task]
        self.max_steps = config["max_steps"]
        task_budget = float(config["budget"])
        if self.base_budget_override is not None:
            task_budget = max(task_budget, float(self.base_budget_override))
        self.initial_budget = task_budget
        self.budget_remaining = self.initial_budget
        self.total_reward = 0.0

        self.pending_sessions = {}
        self.inspected_sessions = {}
        self.action_log = []
        self._blocked_attacker_ids = set()
        self.metrics = EpisodeMetrics()
        self._session_queue = []

        # Spawn initial sessions
        self._spawn_sessions()
        self._rebuild_queue()

        return self.state()

    def step(self, action_map: Optional[Dict[str, int]] = None) -> Dict:
        """Multi-session step: agent provides actions for multiple sessions at once.

        Args:
            action_map: Mapping of session id -> action id. Ids that are not
                currently pending or inspected are silently ignored.

        Returns:
            Dict with ``reward`` (this step's reward including expiry
            penalties), ``done``, ``state`` (see :meth:`state`) and ``info``
            (expiry penalty, attacker outcomes, running grade score/pass).

        Raises:
            ValueError: if an action aimed at a live session is invalid. All
                such actions are validated before any is applied, so a bad
                entry does not leave the environment half-stepped.
        """
        action_map = action_map or {}
        step_reward = 0.0

        # Actions for unknown ids are ignored, so only live targets are
        # validated -- all of them, before any state is mutated. Applying an
        # action never adds or removes *other* sessions, so liveness computed
        # here still holds when each action is applied below.
        live_actions = [
            (session_id, action)
            for session_id, action in action_map.items()
            if session_id in self.pending_sessions or session_id in self.inspected_sessions
        ]
        for _, action in live_actions:
            if action not in ACTIONS:
                raise ValueError(f"invalid action: {action}")

        for session_id, action in live_actions:
            reward, _ = self._apply_action(session_id, action)
            step_reward += reward

        # NOTE(review): here sessions are expired *before* the tick advances,
        # whereas step_single() advances the tick first; confirm the two
        # modes are meant to differ by one tick in expiry timing.
        expired_penalty = self._expire_sessions()
        step_reward += expired_penalty
        self.total_reward += step_reward
        self.step_count += 1
        self.current_tick += 1

        done = self.step_count >= self.max_steps or self.budget_remaining <= 0.0

        if not done:
            self._spawn_sessions()
            self._rebuild_queue()

        # Calculate score using the deterministic grader logic
        final_stats = self.get_network_stats()
        # Imported lazily (presumably to avoid an import cycle with the grader).
        from server.graders import grade_stats
        grade = grade_stats(self.task, final_stats)
        return {
            "reward": step_reward,
            "done": done,
            "state": self.state(),
            "info": {
                "expired_penalty": expired_penalty,
                "attacker_outcomes": self.threat_engine.attacker_outcomes(),
                "score": grade["score"],
                "passed": grade["passed"]
            },
        }

    def step_single(self, action: int) -> Dict:
        """Single-session step: present one session, agent picks one action.

        Compatible with Gymnasium Discrete(6).

        The action is applied to the session at the head of the queue. When
        that empties the queue, the tick advances, expired sessions are
        penalized (the penalty is folded into this call's reward) and, unless
        the episode is done, new sessions are spawned and queued.

        Returns:
            Dict with the next ``observation`` (all zeros when no session is
            queued, e.g. after the final tick), ``reward``, ``done``,
            ``state`` and ``info`` (holding ``action_record`` when an action
            was applied).

        Raises:
            ValueError: if ``action`` is invalid.
        """
        if action not in ACTIONS:
            raise ValueError(f"invalid action: {action}")

        step_reward = 0.0
        info: Dict[str, Any] = {}

        # Act on the current session (it may already have been resolved via
        # the tool API, in which case it is skipped without reward).
        if self._session_queue:
            session_id = self._session_queue.pop(0)
            if session_id in self.pending_sessions or session_id in self.inspected_sessions:
                reward, record = self._apply_action(session_id, action)
                step_reward += reward
                info["action_record"] = record

        self.total_reward = round(self.total_reward + step_reward, 4)
        self.step_count += 1

        # If queue is empty, advance tick
        if not self._session_queue:
            self.current_tick += 1
            expired_penalty = self._expire_sessions()
            # step_reward for the final session in tick includes the expiration penalty
            step_reward += expired_penalty
            self.total_reward = round(self.total_reward + expired_penalty, 4)
            done = self.step_count >= self.max_steps or self.budget_remaining <= 0.0
            if not done:
                self._spawn_sessions()
                self._rebuild_queue()
        else:
            done = self.step_count >= self.max_steps or self.budget_remaining <= 0.0

        # Build next observation
        next_obs = self._current_observation()

        return {
            "observation": next_obs,
            "reward": step_reward,
            "done": done,
            "state": {
                **self.state(),
                "focus_observation": next_obs,
                "focus_session_id": self._session_queue[0] if self._session_queue else None,
            },
            "info": info,
        }

    def state(self) -> Dict:
        """Return current environment state (OpenEnv API).

        ``pending_session_ids`` / ``inspected_session_ids`` list at most 10
        ids each; the ``focus_*`` entries describe the head of the
        single-session queue.
        """
        all_sessions = {**self.pending_sessions, **self.inspected_sessions}
        top_ids = list(all_sessions.keys())[:10]
        focus_session_id = self._session_queue[0] if self._session_queue else None
        return {
            "episode_id": self.episode_id,
            "task": self.task,
            "step_count": self.step_count,
            "current_tick": self.current_tick,
            "observation_dim": OBS_DIM,
            "num_actions": NUM_ACTIONS,
            "budget_remaining": round(self.budget_remaining, 4),
            "total_reward": round(self.total_reward, 4),
            "pending_session_count": len(self.pending_sessions),
            "inspected_session_count": len(self.inspected_sessions),
            "pending_session_ids": top_ids,
            "inspected_session_ids": list(self.inspected_sessions.keys())[:10],
            "queue_length": len(self._session_queue),
            "focus_session_id": focus_session_id,
            "focus_observation": self._current_observation(),
        }

    # ══════════════════════════════════════════════════════════════════
    # Tool API (for MCP/HTTP interface)
    # ══════════════════════════════════════════════════════════════════

    def evaluate_session(self, session_id: str) -> Dict:
        """Get observation vector and metadata for a session.

        ``revealed_malicious`` is the ground-truth label only once the
        session has been inspected (``metadata["revealed"]``), else None.

        Raises:
            KeyError: if the session is neither pending nor inspected.
        """
        session = self.pending_sessions.get(session_id) or self.inspected_sessions.get(session_id)
        if session is None:
            raise KeyError(f"session not found: {session_id}")

        return {
            "session_id": session_id,
            "features": dict(session["features"]),
            "observation": self.generator.to_observation_vector(session),
            "is_inspected": session_id in self.inspected_sessions,
            "revealed_malicious": (
                session["metadata"]["malicious"]
                if session["metadata"]["revealed"] else None
            ),
            "expires_tick": session["expires_tick"],
        }

    def take_action(self, session_id: str, action: int) -> Tuple[float, Dict]:
        """Apply an action to a specific session (see _apply_action)."""
        return self._apply_action(session_id, action)

    def get_network_stats(self) -> Dict:
        """Aggregate episode statistics for grading.

        Expired sessions count towards the malicious/benign totals, so an
        attack that times out lowers the detection rate. Denominators are
        clamped to avoid division by zero.
        """
        m = self.metrics
        total_malicious = m.malicious_seen + m.sessions_expired_malicious
        total_benign = m.benign_seen + m.sessions_expired_benign

        detection_rate = m.detections / max(total_malicious, 1)
        false_positive_rate = m.false_positives / max(total_benign, 1)
        efficiency = 1.0 - min(1.0, m.total_cost / max(self.initial_budget, 1e-6))
        early_detection_bonus = m.early_detection_sum / max(m.detections, 1)
        cascade_prevention = 1.0 - (m.cascade_failures / max(total_malicious, 1))

        return {
            "episode_id": self.episode_id,
            "task": self.task,
            "tick": self.current_tick,
            "step_count": self.step_count,
            "total_reward": round(self.total_reward, 4),
            "budget_remaining": round(self.budget_remaining, 4),
            "budget_used_pct": round(1.0 - self.budget_remaining / max(self.initial_budget, 1e-6), 4),
            "total_malicious": total_malicious,
            "total_benign": total_benign,
            "detection_rate": round(detection_rate, 6),
            "false_positive_rate": round(false_positive_rate, 6),
            "efficiency": round(efficiency, 6),
            "early_detection_bonus": round(early_detection_bonus, 6),
            "cascade_prevention": round(cascade_prevention, 6),
            "correct_allows": m.correct_allows,
            "inspections": m.inspections,
            "expired_malicious": m.sessions_expired_malicious,
            "expired_benign": m.sessions_expired_benign,
        }

    def get_threat_intelligence(self) -> Dict:
        """Return the threat engine's intelligence feed."""
        return self.threat_engine.intelligence_feed()

    def list_tools(self) -> List[str]:
        """Names of the methods exposed through the tool API."""
        return [
            "evaluate_session", "take_action",
            "get_network_stats", "get_threat_intelligence",
        ]

    # ══════════════════════════════════════════════════════════════════
    # Internal mechanics
    # ══════════════════════════════════════════════════════════════════

    def _apply_action(self, session_id: str, action: int) -> Tuple[float, Dict]:
        """Core action application logic.

        Removes the session from its pool, scores the action with the reward
        engine, charges its cost against the budget and updates metrics.

        A first INSPECT reveals the session and re-registers it in both pools
        so it can receive a follow-up action; that record is not appended to
        ``action_log``. Any other action -- including a repeat INSPECT on an
        already inspected session -- resolves the session terminally and is
        appended to ``action_log``.

        Returns:
            ``(reward, record)`` where ``record`` is built by _make_record.

        Raises:
            ValueError: if ``action`` is not in ACTIONS.
            KeyError: if the session is in neither pool.
        """
        if action not in ACTIONS:
            raise ValueError(f"invalid action: {action}")

        # Find the session in either pool. Inspected is checked first because
        # an inspected session is also present in pending_sessions.
        if session_id in self.inspected_sessions:
            session = self.inspected_sessions.pop(session_id)
            source_pool = "inspected"
        elif session_id in self.pending_sessions:
            session = self.pending_sessions.pop(session_id)
            source_pool = "pending"
        else:
            raise KeyError(f"session not found: {session_id}")

        metadata = session["metadata"]
        malicious = bool(metadata["malicious"])
        blocked = action in BLOCKING_ACTIONS
        inspected = action == 2  # INSPECT

        # ── INSPECT keeps the session alive for a second decision ──
        # Only a *first* inspection does this. The session has already been
        # popped above, so membership in inspected_sessions can no longer tell
        # whether it was inspected before; the pool it came from can. Without
        # this, an agent could re-inspect the same session every tick (being
        # charged and scored each time) until it expired.
        if inspected and source_pool != "inspected":
            metadata["revealed"] = True
            self.inspected_sessions[session_id] = session
            self.pending_sessions[session_id] = session
            self.metrics.inspections += 1
            # Compute reward for the inspection itself
            reward, components = self.reward_engine.reward(
                action=action,
                is_malicious=malicious,
                budget_remaining=self.budget_remaining,
                attack_phase=metadata.get("attack_phase", 0),
                inspect_correct=malicious,
            )
            self.budget_remaining = max(0.0, self.budget_remaining - components["cost"])
            self.metrics.total_cost += components["cost"]
            record = self._make_record(session_id, action, malicious, reward, components)
            return reward, record

        # ── Terminal action (ALLOW, BLOCK, SANDBOX, RATE_LIMIT, QUARANTINE,
        #    or a repeat INSPECT) ──
        inspect_correct = malicious and metadata.get("revealed", False)
        reward, components = self.reward_engine.reward(
            action=action,
            is_malicious=malicious,
            budget_remaining=self.budget_remaining,
            attack_phase=metadata.get("attack_phase", 0),
            inspect_correct=inspect_correct,
        )
        self.budget_remaining = max(0.0, self.budget_remaining - components["cost"])
        self.metrics.total_cost += components["cost"]
        if source_pool == "inspected":
            # Drop the duplicate entry an earlier INSPECT left in pending.
            self.pending_sessions.pop(session_id, None)

        # ── Update metrics ──
        if malicious:
            self.metrics.malicious_seen += 1
            if blocked:
                self.metrics.detections += 1
                phase = metadata.get("attack_phase", 0)
                # Earlier attack phases earn a larger early-detection credit.
                self.metrics.early_detection_sum += float(np.exp(-phase))
                attacker_id = metadata.get("attacker_id")
                if attacker_id:
                    self._blocked_attacker_ids.add(attacker_id)
            else:
                if metadata.get("attack_phase", 0) >= 2:
                    self.metrics.cascade_failures += 1
        else:
            self.metrics.benign_seen += 1
            if blocked:
                self.metrics.false_positives += 1
            elif action == 0:
                self.metrics.correct_allows += 1

        record = self._make_record(session_id, action, malicious, reward, components)
        self.action_log.append(record)
        return reward, record

    def _make_record(self, session_id: str, action: int, malicious: bool,
                     reward: float, components: Dict) -> Dict:
        """Build a loggable record of one applied action (values rounded to 6 dp)."""
        return {
            "tick": self.current_tick,
            "session_id": session_id,
            "action": action,
            "action_name": ACTIONS[action],
            "malicious": malicious,
            "reward": round(reward, 6),
            "components": {k: round(v, 6) for k, v in components.items()},
        }

    def _spawn_sessions(self) -> None:
        """Generate new benign and malicious sessions for current tick.

        At least one benign session is spawned per tick; malicious sessions
        come from the threat engine, which is told which attackers were
        blocked since the previous spawn.
        """
        config = TASK_CONFIGS[self.task]
        benign_count = int(max(1, self.rng.poisson(
            config["traffic_lambda"] * config["benign_ratio"],
        )))
        benign = self.generator.generate_benign_sessions(
            tick=self.current_tick, count=benign_count,
        )

        self.threat_engine.maybe_spawn_attacker(config["threat_probability"])
        malicious = self.threat_engine.generate_attack_sessions(
            tick=self.current_tick,
            generator=self.generator,
            blocked_attackers=self._blocked_attacker_ids,
        )
        # Blocked attackers are reported once, then forgotten.
        self._blocked_attacker_ids = set()

        for session in benign + malicious:
            self.pending_sessions[session["session_id"]] = session

    def _expire_sessions(self) -> float:
        """Remove expired sessions and apply penalties. Count in metrics.

        A session expires once ``expires_tick <= current_tick``. Each expired
        malicious session costs -1.5 reward (and counts as a cascade failure
        at attack phase >= 2); expired benign sessions are only counted.

        Returns:
            The total (non-positive) expiry penalty.
        """
        expired_ids = set()
        for sid, session in self.pending_sessions.items():
            if session["expires_tick"] <= self.current_tick:
                expired_ids.add(sid)
        for sid, session in self.inspected_sessions.items():
            if session["expires_tick"] <= self.current_tick:
                expired_ids.add(sid)
        penalty = 0.0
        for session_id in expired_ids:
            # Remove from both pools; an inspected session lives in both.
            session = self.inspected_sessions.pop(session_id, None)
            if session is None:
                session = self.pending_sessions.get(session_id)
            self.pending_sessions.pop(session_id, None)
            if session is None:
                continue
            if session["metadata"]["malicious"]:
                penalty -= 1.5
                self.metrics.sessions_expired_malicious += 1
                if session["metadata"].get("attack_phase", 0) >= 2:
                    self.metrics.cascade_failures += 1
            else:
                self.metrics.sessions_expired_benign += 1

        return penalty

    def _rebuild_queue(self) -> None:
        """Rebuild the single-session queue from pending + inspected."""
        # Inspected sessions get priority (they need a follow-up action).
        # dict.fromkeys keeps the first occurrence of each id in order, which
        # drops the duplicate pending entry of every inspected session.
        self._session_queue = list(dict.fromkeys(
            [*self.inspected_sessions, *self.pending_sessions],
        ))

    def _current_observation(self) -> List[float]:
        """Get normalized observation for the next session in queue.

        Falls back to an all-zero vector of length OBS_DIM when the queue is
        empty or its head is no longer a live session.
        """
        if self._session_queue:
            sid = self._session_queue[0]
            session = (
                self.inspected_sessions.get(sid)
                or self.pending_sessions.get(sid)
            )
            if session:
                return self.generator.to_observation_vector(session)
        return [0.0] * OBS_DIM