# meta_ai_hackathon/server/firewall_environment.py
# GOOD CAT
# Final submission prep
# ec8c511
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple
import numpy as np
# Updated imports to reflect new structure
from server.utils.reward_engine import (
ACTIONS, BLOCKING_ACTIONS, RewardEngine,
)
from server.utils.threat_engine import ThreatEngine
from server.utils.data_loader import (
FEATURE_ORDER, TrafficGenerator,
)
# Per-difficulty episode settings, keyed by task name.
#   max_steps          - step limit after which an episode is done
#   benign_ratio       - multiplied by traffic_lambda to get the Poisson mean
#                        of benign sessions spawned per tick
#   threat_probability - value handed to the threat engine when it decides
#                        whether to spawn an attacker
#   traffic_lambda     - base traffic intensity per tick
#   budget             - starting action budget for the episode
TASK_CONFIGS = {
    "easy": dict(
        max_steps=200,
        benign_ratio=0.80,
        threat_probability=0.12,
        traffic_lambda=5,
        budget=100.0,  # ~0.50 budget per step
    ),
    "medium": dict(
        max_steps=500,
        benign_ratio=0.65,
        threat_probability=0.22,
        traffic_lambda=6,
        budget=300.0,  # ~0.60 budget per step
    ),
    "hard": dict(
        max_steps=1000,
        benign_ratio=0.70,
        threat_probability=0.30,
        traffic_lambda=7,
        budget=600.0,  # ~0.60 budget per step
    ),
}
# Number of entries in ACTIONS; reported to agents as "num_actions".
NUM_ACTIONS = len(ACTIONS)
# Number of entries in FEATURE_ORDER; reported as "observation_dim" and used
# as the length of the all-zero fallback observation.
OBS_DIM = len(FEATURE_ORDER)
@dataclass
class EpisodeMetrics:
    """Tracks all metrics needed for grading.

    Every counter starts at zero. FirewallEnvironment creates a fresh
    instance on reset and updates it while actions are applied and
    sessions expire; get_network_stats() turns the counters into rates.
    """

    # Malicious sessions that received a blocking action.
    detections: int = 0
    # Malicious sessions that reached a final decision via an action.
    malicious_seen: int = 0
    # Benign sessions that received a blocking action.
    false_positives: int = 0
    # Benign sessions that reached a final decision via an action.
    benign_seen: int = 0
    # Sum of exp(-attack_phase) over all detections (earlier phase => larger).
    early_detection_sum: float = 0.0
    # Malicious sessions at attack phase >= 2 that were not blocked or expired.
    cascade_failures: int = 0
    # Accumulated "cost" component reported by the reward engine.
    total_cost: float = 0.0
    # Malicious sessions that expired without a final decision.
    sessions_expired_malicious: int = 0
    # Benign sessions that expired without a final decision.
    sessions_expired_benign: int = 0
    # Benign sessions that were resolved with action 0 (ALLOW).
    correct_allows: int = 0
    # Number of INSPECT actions that kept a session alive.
    inspections: int = 0
class FirewallEnvironment:
    """Adaptive AI Firewall RL environment.

    OpenEnv-compatible: reset(), step(), state()

    Key design (from RL perspective):
    - Observation: 22-dim normalized [0,1] vector per session
    - Action: Discrete(6) — ALLOW, BLOCK, INSPECT, SANDBOX, RATE_LIMIT, QUARANTINE
    - Reward: multi-objective (security + availability + efficiency + timeliness)
    - Done: when max_steps reached or budget depleted
    - INSPECT keeps session alive for a second action (two-phase decision);
      a second INSPECT on an already-inspected session is a final decision.

    Bookkeeping: every live session is stored in ``pending_sessions``; a
    session that has been inspected is *additionally* stored in
    ``inspected_sessions`` (same dict object) until it is resolved or expires.
    """

    # Action codes that get dedicated handling in this class.
    _ACTION_ALLOW = 0
    _ACTION_INSPECT = 2
    # Attack phase from which an unblocked malicious session counts as a
    # cascade failure.
    _CASCADE_PHASE = 2
    # Reward deducted for every malicious session that expires unhandled.
    _EXPIRED_MALICIOUS_PENALTY = 1.5

    def __init__(self, seed: int = 0, budget: Optional[float] = None) -> None:
        """Create an environment in its pre-reset state (task ``"easy"``).

        Args:
            seed: Base seed. The traffic generator, threat engine and the
                environment's own RNG are seeded with ``seed``, ``seed + 1``
                and ``seed + 2`` respectively.
            budget: Optional budget floor. When given, the episode budget is
                the larger of this value and the task's configured budget.

        No sessions exist until ``reset()`` is called.
        """
        self.base_seed = seed
        self.base_budget_override = budget
        self.generator = TrafficGenerator(seed=seed)
        self.threat_engine = ThreatEngine(seed=seed + 1)
        self.reward_engine = RewardEngine()
        self.rng = np.random.default_rng(seed + 2)

        self.episode_id = 0
        self.step_count = 0
        self.current_tick = 0
        self.task = "easy"
        self.max_steps = TASK_CONFIGS[self.task]["max_steps"]
        self.initial_budget = self._resolve_budget(self.task)
        self.budget_remaining = self.initial_budget
        self.total_reward = 0.0

        self.pending_sessions: Dict[str, Dict] = {}
        self.inspected_sessions: Dict[str, Dict] = {}  # sessions awaiting 2nd action
        self.action_log: List[Dict] = []
        self._blocked_attacker_ids: Set[str] = set()
        self.metrics = EpisodeMetrics()
        # For single-session mode
        self._session_queue: List[str] = []

    # ══════════════════════════════════════════════════════════════════
    # OpenEnv API
    # ══════════════════════════════════════════════════════════════════
    def reset(self, task: str = "easy", seed: Optional[int] = None) -> Dict:
        """Reset environment for a new episode.

        Args:
            task: Key of ``TASK_CONFIGS`` selecting the difficulty.
            seed: Seed for this episode; falls back to the constructor seed.

        Returns:
            The state dict produced by ``state()`` after the first batch of
            sessions has been spawned.

        Raises:
            ValueError: If ``task`` is not a key of ``TASK_CONFIGS``.
        """
        if task not in TASK_CONFIGS:
            raise ValueError(f"unknown task: {task}")

        used_seed = self.base_seed if seed is None else seed
        self.generator = TrafficGenerator(seed=used_seed)
        self.threat_engine = ThreatEngine(seed=used_seed + 1)
        self.rng = np.random.default_rng(used_seed + 2)

        self.episode_id += 1
        self.step_count = 0
        self.current_tick = 0
        self.task = task
        self.max_steps = TASK_CONFIGS[task]["max_steps"]
        self.initial_budget = self._resolve_budget(task)
        self.budget_remaining = self.initial_budget
        self.total_reward = 0.0

        self.pending_sessions = {}
        self.inspected_sessions = {}
        self.action_log = []
        self._blocked_attacker_ids = set()
        self.metrics = EpisodeMetrics()
        self._session_queue = []

        # Spawn initial sessions
        self._spawn_sessions()
        self._rebuild_queue()
        return self.state()

    def step(self, action_map: Optional[Dict[str, int]] = None) -> Dict:
        """Multi-session step: agent provides actions for multiple sessions at once.

        Args:
            action_map: Mapping of session id to action code. Ids that are
                not currently live are ignored.

        Returns:
            Dict with ``reward`` (sum of action rewards plus the expiry
            penalty), ``done``, ``state`` and ``info`` (expiry penalty,
            attacker outcomes, grader ``score`` and ``passed``).

        Raises:
            ValueError: If any action for a live session is not in
                ``ACTIONS``. Validation happens before anything is applied,
                so a bad map leaves the environment untouched.
        """
        action_map = action_map or {}

        # Only sessions present in either pool are acted upon.
        applicable = [
            (session_id, action)
            for session_id, action in action_map.items()
            if session_id in self.pending_sessions
            or session_id in self.inspected_sessions
        ]
        # Validate first so an invalid action cannot leave a half-applied step.
        for _, action in applicable:
            if action not in ACTIONS:
                raise ValueError(f"invalid action: {action}")

        step_reward = 0.0
        for session_id, action in applicable:
            reward, _ = self._apply_action(session_id, action)
            step_reward += reward

        expired_penalty = self._expire_sessions()
        step_reward += expired_penalty
        self.total_reward += step_reward
        self.step_count += 1
        self.current_tick += 1

        done = self._is_done()
        if not done:
            self._spawn_sessions()
            self._rebuild_queue()

        # Calculate score using the deterministic grader logic.
        # Imported here rather than at module level (kept from the original).
        final_stats = self.get_network_stats()
        from server.graders import grade_stats
        grade = grade_stats(self.task, final_stats)

        return {
            "reward": step_reward,
            "done": done,
            "state": self.state(),
            "info": {
                "expired_penalty": expired_penalty,
                "attacker_outcomes": self.threat_engine.attacker_outcomes(),
                "score": grade["score"],
                "passed": grade["passed"],
            },
        }

    def step_single(self, action: int) -> Dict:
        """Single-session step: present one session, agent picks one action.

        Compatible with Gymnasium Discrete(6). The action is applied to the
        session at the head of the queue. Once the queue is drained the tick
        advances, expired sessions are penalised and, unless the episode is
        done, new sessions are spawned and the queue is rebuilt.

        Returns:
            Dict with the observation of the NEXT queued session (all zeros
            when the queue is empty), ``reward``, ``done``, ``state`` and
            ``info`` (``action_record`` when an action was applied).

        Raises:
            ValueError: If ``action`` is not in ``ACTIONS``.
        """
        if action not in ACTIONS:
            raise ValueError(f"invalid action: {action}")

        step_reward = 0.0
        info: Dict[str, Any] = {}

        # Act on the current session (skipped if it vanished in the meantime).
        if self._session_queue:
            session_id = self._session_queue.pop(0)
            if (
                session_id in self.pending_sessions
                or session_id in self.inspected_sessions
            ):
                reward, record = self._apply_action(session_id, action)
                step_reward += reward
                info["action_record"] = record

        self.total_reward = round(self.total_reward + step_reward, 4)
        self.step_count += 1

        # If queue is empty, advance tick
        tick_finished = not self._session_queue
        if tick_finished:
            self.current_tick += 1
            expired_penalty = self._expire_sessions()
            # step_reward for the final session in tick includes the
            # expiration penalty
            step_reward += expired_penalty
            self.total_reward = round(self.total_reward + expired_penalty, 4)

        done = self._is_done()
        if tick_finished and not done:
            self._spawn_sessions()
            self._rebuild_queue()

        # Build next observation
        next_obs = self._current_observation()
        return {
            "observation": next_obs,
            "reward": step_reward,
            "done": done,
            "state": {
                **self.state(),
                "focus_observation": next_obs,
                "focus_session_id": (
                    self._session_queue[0] if self._session_queue else None
                ),
            },
            "info": info,
        }

    def state(self) -> Dict:
        """Return current environment state (OpenEnv API).

        Id lists are truncated to 10 entries. ``pending_session_ids`` is
        taken from the union of both pools.
        """
        all_sessions = {**self.pending_sessions, **self.inspected_sessions}
        top_ids = list(all_sessions.keys())[:10]
        focus_session_id = self._session_queue[0] if self._session_queue else None
        return {
            "episode_id": self.episode_id,
            "task": self.task,
            "step_count": self.step_count,
            "current_tick": self.current_tick,
            "observation_dim": OBS_DIM,
            "num_actions": NUM_ACTIONS,
            "budget_remaining": round(self.budget_remaining, 4),
            "total_reward": round(self.total_reward, 4),
            "pending_session_count": len(self.pending_sessions),
            "inspected_session_count": len(self.inspected_sessions),
            "pending_session_ids": top_ids,
            "inspected_session_ids": list(self.inspected_sessions.keys())[:10],
            "queue_length": len(self._session_queue),
            "focus_session_id": focus_session_id,
            "focus_observation": self._current_observation(),
        }

    # ══════════════════════════════════════════════════════════════════
    # Tool API (for MCP/HTTP interface)
    # ══════════════════════════════════════════════════════════════════
    def evaluate_session(self, session_id: str) -> Dict:
        """Get observation vector and metadata for a session.

        ``revealed_malicious`` is the ground-truth label only when the
        session has been revealed by an inspection, otherwise ``None``.

        Raises:
            KeyError: If the session is in neither pool.
        """
        session = (
            self.pending_sessions.get(session_id)
            or self.inspected_sessions.get(session_id)
        )
        if session is None:
            raise KeyError(f"session not found: {session_id}")

        metadata = session["metadata"]
        # A session that was never inspected may not carry the flag at all;
        # treat a missing flag as "not revealed", as _apply_action does.
        revealed = metadata.get("revealed", False)
        return {
            "session_id": session_id,
            "features": dict(session["features"]),
            "observation": self.generator.to_observation_vector(session),
            "is_inspected": session_id in self.inspected_sessions,
            "revealed_malicious": metadata["malicious"] if revealed else None,
            "expires_tick": session["expires_tick"],
        }

    def take_action(self, session_id: str, action: int) -> Tuple[float, Dict]:
        """Apply an action to a specific session.

        Returns the ``(reward, record)`` pair; raises ``ValueError`` for an
        unknown action and ``KeyError`` for an unknown session.
        """
        return self._apply_action(session_id, action)

    def get_network_stats(self) -> Dict:
        """Aggregate episode statistics for grading.

        Totals include both sessions resolved by an action and sessions that
        expired. Denominators are clamped to avoid division by zero.
        """
        m = self.metrics
        total_malicious = m.malicious_seen + m.sessions_expired_malicious
        total_benign = m.benign_seen + m.sessions_expired_benign
        budget_base = max(self.initial_budget, 1e-6)

        detection_rate = m.detections / max(total_malicious, 1)
        false_positive_rate = m.false_positives / max(total_benign, 1)
        efficiency = 1.0 - min(1.0, m.total_cost / budget_base)
        early_detection_bonus = m.early_detection_sum / max(m.detections, 1)
        cascade_prevention = 1.0 - (m.cascade_failures / max(total_malicious, 1))

        return {
            "episode_id": self.episode_id,
            "task": self.task,
            "tick": self.current_tick,
            "step_count": self.step_count,
            "total_reward": round(self.total_reward, 4),
            "budget_remaining": round(self.budget_remaining, 4),
            # Despite the name this is a fraction in [0, 1], not a percentage.
            "budget_used_pct": round(1.0 - self.budget_remaining / budget_base, 4),
            "total_malicious": total_malicious,
            "total_benign": total_benign,
            "detection_rate": round(detection_rate, 6),
            "false_positive_rate": round(false_positive_rate, 6),
            "efficiency": round(efficiency, 6),
            "early_detection_bonus": round(early_detection_bonus, 6),
            "cascade_prevention": round(cascade_prevention, 6),
            "correct_allows": m.correct_allows,
            "inspections": m.inspections,
            "expired_malicious": m.sessions_expired_malicious,
            "expired_benign": m.sessions_expired_benign,
        }

    def get_threat_intelligence(self) -> Dict:
        """Return the threat engine's intelligence feed unchanged."""
        return self.threat_engine.intelligence_feed()

    def list_tools(self) -> List[str]:
        """Return the names of the methods exposed through the tool API."""
        return [
            "evaluate_session", "take_action",
            "get_network_stats", "get_threat_intelligence",
        ]

    # ══════════════════════════════════════════════════════════════════
    # Internal mechanics
    # ══════════════════════════════════════════════════════════════════
    def _resolve_budget(self, task: str) -> float:
        """Return the starting budget for ``task``.

        The constructor's budget argument acts as a floor: it can raise the
        configured task budget but never lower it.
        """
        task_budget = float(TASK_CONFIGS[task]["budget"])
        if self.base_budget_override is not None:
            task_budget = max(task_budget, float(self.base_budget_override))
        return task_budget

    def _is_done(self) -> bool:
        """Return True once the step limit is hit or the budget is exhausted."""
        return self.step_count >= self.max_steps or self.budget_remaining <= 0.0

    def _charge(self, cost: float) -> None:
        """Deduct ``cost`` from the budget (clamped at zero) and record it."""
        self.budget_remaining = max(0.0, self.budget_remaining - cost)
        self.metrics.total_cost += cost

    def _apply_action(self, session_id: str, action: int) -> Tuple[float, Dict]:
        """Core action application logic.

        A first INSPECT marks the session as revealed and keeps it alive in
        both pools. Every other case, including INSPECT on a session that was
        already inspected, is a final decision: the session is removed and
        the detection/false-positive metrics are updated.

        Returns:
            ``(reward, record)`` where ``record`` comes from ``_make_record``.
            Only final decisions are appended to ``action_log``.

        Raises:
            ValueError: If ``action`` is not in ``ACTIONS``.
            KeyError: If the session is in neither pool.
        """
        if action not in ACTIONS:
            raise ValueError(f"invalid action: {action}")

        # Find the session in either pool; the inspected pool takes priority
        # because inspected sessions are mirrored in the pending pool.
        if session_id in self.inspected_sessions:
            session = self.inspected_sessions.pop(session_id)
            already_inspected = True
        elif session_id in self.pending_sessions:
            session = self.pending_sessions.pop(session_id)
            already_inspected = False
        else:
            raise KeyError(f"session not found: {session_id}")

        metadata = session["metadata"]
        malicious = bool(metadata["malicious"])
        attack_phase = metadata.get("attack_phase", 0)
        blocked = action in BLOCKING_ACTIONS

        # ── INSPECT keeps the session alive for a second decision ──
        # The pool the session came from decides this; checking membership in
        # inspected_sessions here would always pass, since it was just popped.
        if action == self._ACTION_INSPECT and not already_inspected:
            metadata["revealed"] = True
            self.inspected_sessions[session_id] = session
            self.pending_sessions[session_id] = session
            self.metrics.inspections += 1
            # Compute reward for the inspection itself
            reward, components = self.reward_engine.reward(
                action=action,
                is_malicious=malicious,
                budget_remaining=self.budget_remaining,
                attack_phase=attack_phase,
                inspect_correct=malicious,
            )
            self._charge(components["cost"])
            return reward, self._make_record(
                session_id, action, malicious, reward, components,
            )

        # ── Terminal action (ALLOW, BLOCK, SANDBOX, RATE_LIMIT, QUARANTINE) ──
        inspect_correct = malicious and metadata.get("revealed", False)
        reward, components = self.reward_engine.reward(
            action=action,
            is_malicious=malicious,
            budget_remaining=self.budget_remaining,
            attack_phase=attack_phase,
            inspect_correct=inspect_correct,
        )
        self._charge(components["cost"])
        if already_inspected:
            # Drop the mirrored copy so the session is gone from both pools.
            self.pending_sessions.pop(session_id, None)

        # ── Update metrics ──
        metrics = self.metrics
        if malicious:
            metrics.malicious_seen += 1
            if blocked:
                metrics.detections += 1
                # Earlier phases earn a larger share of the early-detection bonus.
                metrics.early_detection_sum += float(np.exp(-attack_phase))
                attacker_id = metadata.get("attacker_id")
                if attacker_id:
                    self._blocked_attacker_ids.add(attacker_id)
            elif attack_phase >= self._CASCADE_PHASE:
                metrics.cascade_failures += 1
        else:
            metrics.benign_seen += 1
            if blocked:
                metrics.false_positives += 1
            elif action == self._ACTION_ALLOW:
                metrics.correct_allows += 1

        record = self._make_record(session_id, action, malicious, reward, components)
        self.action_log.append(record)
        return reward, record

    def _make_record(self, session_id: str, action: int, malicious: bool,
                     reward: float, components: Dict) -> Dict:
        """Build the log/info record for one applied action.

        The reward and every reward component are rounded to 6 decimals.
        """
        return {
            "tick": self.current_tick,
            "session_id": session_id,
            "action": action,
            "action_name": ACTIONS[action],
            "malicious": malicious,
            "reward": round(reward, 6),
            "components": {k: round(v, 6) for k, v in components.items()},
        }

    def _spawn_sessions(self) -> None:
        """Generate new benign and malicious sessions for current tick.

        The benign count is Poisson-distributed with mean
        ``traffic_lambda * benign_ratio`` and is at least 1. The set of
        attacker ids blocked since the last spawn is handed to the threat
        engine and then cleared.
        """
        config = TASK_CONFIGS[self.task]
        benign_mean = config["traffic_lambda"] * config["benign_ratio"]
        benign_count = int(max(1, self.rng.poisson(benign_mean)))
        benign = self.generator.generate_benign_sessions(
            tick=self.current_tick, count=benign_count,
        )

        self.threat_engine.maybe_spawn_attacker(config["threat_probability"])
        malicious = self.threat_engine.generate_attack_sessions(
            tick=self.current_tick,
            generator=self.generator,
            blocked_attackers=self._blocked_attacker_ids,
        )
        self._blocked_attacker_ids = set()

        for session in benign + malicious:
            self.pending_sessions[session["session_id"]] = session

    def _expire_sessions(self) -> float:
        """Remove expired sessions and apply penalties. Count in metrics.

        A session is expired when its ``expires_tick`` is at or before the
        current tick. It is removed from both pools so a mirrored inspected
        session is penalised only once.

        Returns:
            The (non-positive) total penalty: ``-1.5`` per expired malicious
            session; benign expiries are counted but not penalised.
        """
        now = self.current_tick
        expired_ids: Set[str] = set()
        for pool in (self.pending_sessions, self.inspected_sessions):
            for sid, session in pool.items():
                if session["expires_tick"] <= now:
                    expired_ids.add(sid)

        penalty = 0.0
        for session_id in expired_ids:
            inspected = self.inspected_sessions.pop(session_id, None)
            pending = self.pending_sessions.pop(session_id, None)
            session = inspected if inspected is not None else pending
            if session is None:
                continue

            metadata = session["metadata"]
            if metadata["malicious"]:
                penalty -= self._EXPIRED_MALICIOUS_PENALTY
                self.metrics.sessions_expired_malicious += 1
                if metadata.get("attack_phase", 0) >= self._CASCADE_PHASE:
                    self.metrics.cascade_failures += 1
            else:
                self.metrics.sessions_expired_benign += 1
        return penalty

    def _rebuild_queue(self) -> None:
        """Rebuild the single-session queue from pending + inspected.

        Inspected sessions get priority (they need a follow-up action); ids
        present in both pools appear once, at their first position.
        """
        ordered = [*self.inspected_sessions, *self.pending_sessions]
        # dict.fromkeys keeps first-seen order while dropping duplicates.
        self._session_queue = list(dict.fromkeys(ordered))

    def _current_observation(self) -> List[float]:
        """Get normalized observation for the next session in queue.

        Returns a zero vector of length ``OBS_DIM`` when the queue is empty
        or its head no longer refers to a live session.
        """
        if self._session_queue:
            sid = self._session_queue[0]
            session = (
                self.inspected_sessions.get(sid)
                or self.pending_sessions.get(sid)
            )
            if session:
                return self.generator.to_observation_vector(session)
        return [0.0] * OBS_DIM