GOOD CAT commited on
Commit
ec8c511
·
1 Parent(s): caab1ce

Final submission prep

Browse files
.dockerignore ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .venv/
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ .git/
7
+ .gitignore
8
+ .gitattributes
9
+ .vscode/
10
+ .env
11
+ .env.example
12
+ .pytest_cache/
13
+ .ruff_cache/
14
+ logs/
15
+ models/
16
+ docs/
17
+ tests/
18
+ scripts/
19
+ _agents/
20
+ ppo_firewall_*
21
+ *.zip
22
+ *.lock
23
+ *.md
24
+ !README.md
25
+ progresss.md
26
+ REVIEW_AND_TODO.md
27
+ implementation_plan.md
28
+ conftest.py
29
+ pyproject.toml
.env.example ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mandatory: HuggingFace Token for API Router (REQUIRED for evaluation)
2
+ HF_TOKEN=your_huggingface_token_here
3
+
4
+ # LLM Configuration (Evaluator will inject these)
5
+ API_BASE_URL=https://router.huggingface.co/v1
6
+ MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct
7
+
8
+ # Environment Settings
9
+ FIREWALL_ENV_URL=http://localhost:7860
10
+ IMAGE_NAME=ai-firewall-openenv
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .env
2
+ .venv/
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+ .pytest_cache/
7
+ .ruff_cache/
8
+ *.zip
9
+ logs/
10
+ evaluations.npz
11
+ best_model.zip
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Set environment variables
6
+ ENV PYTHONDONTWRITEBYTECODE=1 \
7
+ PYTHONUNBUFFERED=1 \
8
+ PYTHONPATH=/app
9
+
10
+ # Install system dependencies
11
+ RUN apt-get update && apt-get install -y --no-install-recommends \
12
+ build-essential \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ # Copy requirements first for better caching
16
+ COPY requirements.txt .
17
+
18
+ # Install dependencies
19
+ RUN pip install --no-cache-dir --upgrade pip && \
20
+ pip install --no-cache-dir -r requirements.txt
21
+
22
+ # Copy application code
23
+ COPY server/ /app/server/
24
+ COPY inference.py /app/inference.py
25
+ COPY models.py /app/models.py
26
+ COPY client.py /app/client.py
27
+ COPY openenv.yaml /app/openenv.yaml
28
+ COPY README.md /app/README.md
29
+
30
+ # Expose port for HF Spaces
31
+ EXPOSE 7860
32
+
33
+ # Health check (matching reference project pattern)
34
+ HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
35
+ CMD python -c "import requests; requests.get('http://localhost:7860/health')" || exit 1
36
+
37
+ # Default command: run the FastAPI app (for HF Spaces)
38
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: AI Firewall OpenEnv
3
+ emoji: 🛡️
4
+ colorFrom: red
5
+ colorTo: blue
6
+ sdk: docker
7
+ app_port: 7860
8
+ ---
9
+
10
+ # 🛡️ AI Firewall OpenEnv
11
+
12
+ A production-grade AI-driven adaptive firewall simulation for automated threat detection in encrypted network traffic.
13
+
14
+ ## 📖 Problem Description
15
+ Encrypted traffic poses a challenge for traditional firewalls. This project uses AI agents to make real-time decisions (ALLOW, BLOCK, etc.) based on session metadata alone, balancing security with network performance.
16
+
17
+ ## 🎮 Tasks
18
+ - **🟢 Easy (Perimeter Defense)**: Clear attack patterns for initial testing.
19
+ - **🟡 Medium (Mixed Threat Landscape)**: Multi-stage attacks with ambiguous traffic signals.
20
+ - **🔴 Hard (Advanced Persistent Threat)**: Stealthy, low-signal APT scenarios.
21
+
22
+ ## 🧠 Environment Specs
23
+ - **Observation Space**: Box(22,) - Normalized features including JA3 fingerprints, entropy, geo-distance, and session history.
24
+ - **Action Space**: Discrete(6)
25
+ - 0: ALLOW
26
+ - 1: BLOCK
27
+ - 2: INSPECT
28
+ - 3: SANDBOX
29
+ - 4: RATE_LIMIT
30
+ - 5: QUARANTINE
31
+
32
+ ## 📊 Reward Logic
33
+ Rewards are multi-objective:
34
+ - **Correct Block**: +1.0
35
+ - **False Positive**: -1.2 (Strong penalty)
36
+ - **Missed Attack**: -2.0 (Critical failure)
37
+ - **Correct Allow**: +0.25 (Efficiency bonus)
38
+ - **Inspect**: Dynamic cost/benefit based on revealed status.
39
+
40
+ ## 🚀 Setup & Usage
41
+ ### **Prerequisites**
42
+ - Docker installed
43
+ - Python 3.11+
44
+ - API Keys for OpenAI/OpenRouter (optional for LLM agent)
45
+
46
+ ### **Local Execution**
47
+ 1. **Configure Keys**: `cp .env.example .env` and add your keys.
48
+ 2. **Run Inference**: `python inference.py --task easy`
49
+ 3. **Validate**: `bash scripts/validate-submission.sh <ping_url>`
50
+
51
+ ### **Docker Deployment**
52
+ ```bash
53
+ docker build -t ai-firewall .
54
+ docker run -p 7860:7860 ai-firewall
55
+ ```
56
+
57
+ ## 🏗️ Project Structure
58
+ - `env/`: Core firewall environment (reset, step, state).
59
+ - `grader/`: Scoring and grading logic.
60
+ - `utils/`: Traffic simulation and reward engines.
61
+ - `inference.py`: LLM-based inference script.
62
+ - `openenv.yaml`: Metadata for OpenEnv.
REVIEW_AND_TODO.md ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🔍 Codebase Review & TODO — OpenEnv RL Challenge Submission
2
+
3
+ > **Last Updated**: 2026-04-06T20:25 IST
4
+ > **Status**: ✅ SUBMISSION-READY — Structure & Logic Verified
5
+
6
+ ---
7
+
8
+ ## 📊 Quick Status Dashboard
9
+
10
+ | Requirement | Status | Notes |
11
+ |---|---|---|
12
+ | `inference.py` in root directory | ✅ Verified | Runs with `[START]/[STEP]/[END]` output |
13
+ | `models.py` in root directory | ✅ Verified | Correctly defines `Action` / `Observation` |
14
+ | `server/` contains env logic | ✅ Verified | Consolidated package structure |
15
+ | Web Interface at `/web` | ✅ Verified | Standard playground UI serving |
16
+ | FastAPI Endpoints (`/health`, `/schema`) | ✅ Verified | Responding with 200 OK |
17
+ | Dockerfile structure | ✅ Verified | Correct `PYTHONPATH` and `CMD` |
18
+ | Heuristic fallback (8 rules) | ✅ Verified | Integrated into `inference.py` |
19
+ | Local Ollama / Qwen Support | ✅ Done | Defaulting to local model with fallback |
20
+ | Syntax verification | ✅ Verified | All files pass `py_compile` |
21
+
22
+ ---
23
+
24
+ ## 🚨 Previous Blocking Issues (All Fixed)
25
+
26
+ | # | Bug | Fix Applied |
27
+ |---|---|---|
28
+ | 1 | `[STEP]` action extraction via fragile nested `.get()` chain | Track `action` integer explicitly before if/else |
29
+ | 2 | `[END]` not emitted on exception; had extra `error=` field | `try/finally` pattern; removed non-spec `error=` field |
30
+ | 3 | Heuristic fallback only had 2 rules (~33% detection) | Ported 8-rule heuristic from `llm_agent.py` (~51%+ detection) |
31
+ | 4 | `server/app.py` import: `from src.adaptive_firewall_env...` | Changed to `from adaptive_firewall_env.server.app import app` |
32
+ | 5 | Two parallel codebases with different import chains | Accepted—both work; `__init__.py` files added for reliability |
33
+ | 6 | `action` variable undefined when no `focus_session_id` | Initialize `action = 0` before the if/else block |
34
+ | 7 | `[END]` line had extra `error=` field not in spec | Removed `error=` field; spec: `[END] success=X steps=N rewards=...` |
35
+ | 8 | Missing `__init__.py` in `env/`, `utils/`, `grader/` | Created all three files |
36
+
37
+ ---
38
+
39
+ ## ⚠️ NON-BLOCKING Issues (Remaining)
40
+
41
+ | # | Issue | Status | Recommendation |
42
+ |---|---|---|---|
43
+ | 1 | `openenv-core` may pull heavy transitive deps | ⚠️ Untested | Test Docker build; remove if image > 4 GB |
44
+ | 2 | `.env` with real HF_TOKEN in git history | ⚠️ Security | Rotate token immediately after submission |
45
+ | 3 | Code duplication between `env/` and `src/` | 📝 Accepted | Consolidate long-term |
46
+ | 4 | Docker build not tested locally | ⚠️ Untested | `docker build -t ai-firewall . && docker run -e HF_TOKEN=x -p 7860:7860 ai-firewall` |
47
+
48
+ ---
49
+
50
+ ## ✅ Already Implemented & Working
51
+
52
+ - Core RL Environment (both `env/` and `src/adaptive_firewall_env/` copies)
53
+ - Traffic Generator (22 features, 5 benign + 20 malicious profiles)
54
+ - Threat Engine (Cyber Kill Chain model, import fixed to `from utils.data_loader`)
55
+ - Reward Engine (multi-objective: security + availability + efficiency + timeliness)
56
+ - Grading System (thresholds 0.70/0.50/0.45 + pass constraints)
57
+ - FastAPI Server (health, reset, step, step_single, tools, LLM playground)
58
+ - Pydantic Models (all API endpoints typed)
59
+ - OpenEnv Manifest (`openenv.yaml` complete with tasks/tools/spaces)
60
+ - Dockerfile (copies all dirs, correct PYTHONPATH, port 7860)
61
+ - Requirements (trimmed — no torch, no stable-baselines3)
62
+ - `.gitignore` (`.env` listed), `.env.example` (defaults documented)
63
+ - `.dockerignore` (excludes .venv, .git, .env, pycache)
64
+ - README (HF frontmatter: `sdk: docker`, `app_port: 7860`)
65
+ - Env var handling (defaults for `API_BASE_URL`/`MODEL_NAME`, mandatory `HF_TOKEN`)
66
+ - `[START]`/`[STEP]`/`[END]` output format (spec-compliant)
67
+ - Runs all 3 tasks sequentially (easy → medium → hard)
68
+ - 8-rule heuristic in inference.py (JA3, geo, DDoS, cert, DNS, entropy, ports)
69
+ - LLM rate-limit backoff (exponential retry for 429 errors)
70
+ - LLM agent in `src/` with full error recovery
71
+ - Package `__init__.py` files in `env/`, `utils/`, `grader/`
72
+ - Test suite (38 tests passing)
73
+ - `conftest.py` (adds `src/` to PYTHONPATH for tests)
74
+
75
+ ---
76
+
77
+ ## 📋 TODO Checklist
78
+
79
+ ### Priority 0 — MUST FIX (All Complete ✅)
80
+
81
+ - [x] Fix `[STEP]` action extraction in `inference.py`
82
+ - [x] Fix `[END]` line with `try/finally` in `inference.py`
83
+ - [x] Remove extra `error=` field from `[END]` line
84
+ - [x] Port 8-rule heuristic into `inference.py`
85
+ - [x] Fix `server/app.py` import — remove `src.` prefix
86
+ - [x] Initialize `action = 0` before if/else in inference loop
87
+ - [x] Add `__init__.py` to `env/`, `utils/`, `grader/`
88
+ - [x] Add rate-limit backoff to `inference.py` LLM calls
89
+
90
+ ### Priority 1 — Should Fix (Before Deployment)
91
+
92
+ - [ ] Test Docker build locally (`docker build && docker run`)
93
+ - [ ] Verify `openenv-core` doesn't bloat image beyond 8 GB
94
+ - [ ] Rotate HF_TOKEN (leaked in git history)
95
+
96
+ ### Priority 2 — Nice to Have
97
+
98
+ - [ ] Smart LLM gating — skip LLM for obvious-heuristic cases
99
+ - [ ] Consolidate `env/` + `utils/` + `grader/` into `src/adaptive_firewall_env/`
100
+ - [ ] Add Docker health check for inference.py readiness
101
+
102
+ ---
103
+
104
+ ## 📁 File-by-File Status
105
+
106
+ | File | Status | Notes |
107
+ |---|---|---|
108
+ | `inference.py` | ✅ Fixed | Spec-compliant output, 8-rule heuristic, rate-limit backoff |
109
+ | `Dockerfile` | ✅ OK | Copies all dirs, correct PYTHONPATH |
110
+ | `requirements.txt` | ✅ OK | Trimmed (openenv-core risk noted) |
111
+ | `openenv.yaml` | ✅ OK | Complete spec |
112
+ | `README.md` | ✅ OK | HF frontmatter present |
113
+ | `.env.example` | ✅ OK | Defaults documented |
114
+ | `.gitignore` | ✅ OK | `.env` listed |
115
+ | `.dockerignore` | ✅ OK | Excludes `.venv`, `.git`, `.env` |
116
+ | `server/app.py` | ✅ Fixed | Import corrected |
117
+ | `env/__init__.py` | ✅ Created | Package marker |
118
+ | `env/firewall_env.py` | ✅ OK | Core RL environment |
119
+ | `env/models.py` | ✅ OK | Pydantic models |
120
+ | `utils/__init__.py` | ✅ Created | Package marker |
121
+ | `utils/data_loader.py` | ✅ OK | Traffic generation |
122
+ | `utils/reward_engine.py` | ✅ OK | Multi-objective rewards |
123
+ | `utils/threat_engine.py` | ✅ OK | Import fixed |
124
+ | `grader/__init__.py` | ✅ Created | Package marker |
125
+ | `grader/firewall_grader.py` | ✅ OK | Scoring logic |
126
+ | `src/.../server/app.py` | ✅ OK | Full FastAPI server |
127
+ | `src/.../agents/llm_agent.py` | ✅ OK | All bugs fixed |
128
+ | `conftest.py` | ✅ OK | Adds `src/` to PYTHONPATH |
129
+ | `tests/` | ✅ OK | 38 tests passing |
_agents/skills/SKILL.md ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: adaptive_firewall_env
3
+ description: Skills for interacting with, training on, and evaluating the Adaptive AI Firewall OpenEnv environment
4
+ ---
5
+
6
+ Use this skill pack to:
7
+
8
+ - inspect environment state, queue dynamics, and session-level observations
9
+ - train and evaluate policies in deterministic easy/medium/hard tasks
10
+ - compare RL models against random / heuristic / degenerate baselines
11
+ - debug budget usage, reward components, and attacker outcomes
12
+
13
+ Primary entry points:
14
+
15
+ 1. `understand_environment.md` for architecture and interfaces.
16
+ 2. `train_agent.md` for practical training loops.
17
+ 3. `evaluate_agent.md` for deterministic benchmark protocol.
18
+ 4. `debug_environment.md` for failure triage patterns.
_agents/skills/debug_environment.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Debug Environment
2
+
3
+ 1. Validate deterministic resets:
4
+ - same seed + same policy must produce same score.
5
+ 2. Inspect session lifecycle:
6
+ - pending vs inspected pools
7
+ - expiration counts for benign and malicious.
8
+ 3. Inspect budget dynamics:
9
+ - `budget_remaining`
10
+ - `metrics.total_cost`
11
+ - efficiency in `get_network_stats()`.
12
+ 4. Diagnose degenerate policy leaks:
13
+ - run block-all / allow-all baselines
14
+ - verify pass constraints reject them.
15
+ 5. Verify single-session mode:
16
+ - observation size stays fixed (`22`)
17
+ - action range stays `[0..5]`.
_agents/skills/evaluate_agent.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluate Agent
2
+
3
+ 1. Run deterministic evaluation:
4
+ - `python -m adaptive_firewall_env.baseline.evaluate`
5
+ 2. Compare policy against four references:
6
+ - random
7
+ - heuristic
8
+ - block-all
9
+ - allow-all
10
+ 3. Confirm pass criteria includes both:
11
+ - weighted score threshold
12
+ - pass constraints (`min_detection_rate`, `min_fp_complement`)
13
+ 4. Inspect per-task metrics:
14
+ - detection rate
15
+ - false-positive complement
16
+ - efficiency
17
+ - cascade prevention
_agents/skills/train_agent.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Train Agent
2
+
3
+ 1. Start with `step_single` mode to get fixed-shape RL training (`Discrete(6)`).
4
+ 2. Use medium task for initial optimization stability; then curriculum to hard.
5
+ 3. Track reward decomposition (security, availability, efficiency, timeliness) each epoch.
6
+ 4. Include inspected-session follow-up actions in policy design.
7
+ 5. Validate every checkpoint with deterministic graders on all tasks.
8
+ 6. Promote models only if:
9
+ - heuristic-level or better easy score
10
+ - non-zero detection and acceptable false-positive handling
_agents/skills/understand_environment.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Understand Environment
2
+
3
+ 1. Read `server/firewall_environment.py` for:
4
+ - multi-session mode (`step`)
5
+ - single-session mode (`step_single`)
6
+ - inspect follow-up lifecycle and budget mechanics
7
+ 2. Read `server/traffic_generator.py` for:
8
+ - feature order and normalization
9
+ - scenario- and phase-specific malicious profiles
10
+ 3. Read `server/threat_engine.py` for:
11
+ - attacker lifecycle and adaptation
12
+ - attacker outcomes (`active`, `stopped`, `succeeded`)
13
+ 4. Read `server/reward_engine.py` for:
14
+ - reward weights and anti-degeneracy design
15
+ 5. Read `server/graders.py` for:
16
+ - deterministic seeds
17
+ - thresholds and pass constraints
_agents/workflows/deploy.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Deploy Workflow
2
+
3
+ 1. Build runtime artifact:
4
+ - Docker image from `src/adaptive_firewall_env/server/Dockerfile`.
5
+ 2. Run pre-deploy checks:
6
+ - `pytest -q`
7
+ - `ruff check src tests`
8
+ - baseline evaluator output generation.
9
+ 3. Publish container or code to target hosting environment.
10
+ 4. Post-deploy validation:
11
+ - `GET /health`
12
+ - `POST /reset`
13
+ - `POST /step_single`
14
+ 5. Compare deployed baseline report with local deterministic report.
_agents/workflows/setup.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Setup Workflow
2
+
3
+ 1. Create virtual environment:
4
+ - `py -m venv .venv`
5
+ 2. Install dependencies:
6
+ - `.venv\Scripts\python -m pip install -U pip`
7
+ - `.venv\Scripts\python -m pip install pytest ruff requests numpy fastapi pydantic uvicorn`
8
+ 3. Validate code quality:
9
+ - `.venv\Scripts\python -m pytest -q`
10
+ - `.venv\Scripts\python -m ruff check src tests`
11
+ 4. Start service:
12
+ - `uvicorn adaptive_firewall_env.server.app:app --port 8000`
13
+ 5. Run baseline evaluator for smoke confirmation.
_agents/workflows/train.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Train Workflow
2
+
3
+ 1. Establish reference:
4
+ - run baseline evaluator and record heuristic score per task.
5
+ 2. Begin in single-session mode (`step_single`) with medium task.
6
+ 3. Train policy network on normalized 22-dim observations and `Discrete(6)` actions.
7
+ 4. Include inspect follow-up strategy in action head logic.
8
+ 5. Evaluate every checkpoint on deterministic seeds.
9
+ 6. Promote model only if:
10
+ - easy and medium pass constraints satisfied
11
+ - hard score improves over random baseline
12
+ - no degeneration to block-all or allow-all behavior.
client.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ class FirewallClient:
5
+ """Client for interacting with the Adaptive AI Firewall server."""
6
+
7
+ def __init__(self, base_url: str = "http://localhost:7860"):
8
+ self.base_url = base_url.rstrip("/")
9
+
10
+ def health(self) -> Dict[str, Any]:
11
+ return requests.get(f"{self.base_url}/health").json()
12
+
13
+ def reset(self, task: str = "easy", seed: Optional[int] = None) -> Dict[str, Any]:
14
+ payload = {"task": task}
15
+ if seed is not None:
16
+ payload["seed"] = seed
17
+ return requests.post(f"{self.base_url}/reset", json=payload).json()
18
+
19
+ def step(self, actions: Dict[str, int]) -> Dict[str, Any]:
20
+ return requests.post(f"{self.base_url}/step", json={"actions": actions}).json()
21
+
22
+ def step_single(self, action: int) -> Dict[str, Any]:
23
+ return requests.post(f"{self.base_url}/step_single", json={"action": action}).json()
24
+
25
+ def state(self) -> Dict[str, Any]:
26
+ return requests.get(f"{self.base_url}/state").json()
27
+
28
+ def stats(self) -> Dict[str, Any]:
29
+ return requests.get(f"{self.base_url}/stats").json()
30
+
31
+ def list_tools(self) -> List[str]:
32
+ return requests.get(f"{self.base_url}/tools").json().get("tools", [])
33
+
34
+ def call_tool(self, name: str, kwargs: Dict[str, Any]) -> Any:
35
+ return requests.post(f"{self.base_url}/tool/{name}", json={"kwargs": kwargs}).json()
36
+
37
+ def schema(self) -> Dict[str, Any]:
38
+ return requests.get(f"{self.base_url}/schema").json()
conftest.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Pytest configuration — ensure project root is on PYTHONPATH."""
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ # Add project root to path so tests can import server.*
6
+ sys.path.insert(0, str(Path(__file__).parent))
docs/ACTION_SPACE.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Action Space
2
+
3
+ ## Discrete Actions
4
+
5
+ | ID | Action | Typical Use | Cost Type |
6
+ |---|---|---|---|
7
+ | 0 | `ALLOW` | pass low-risk traffic | none |
8
+ | 1 | `BLOCK` | immediate deny for high-confidence malicious sessions | low |
9
+ | 2 | `INSPECT` | collect additional evidence before terminal decision | medium |
10
+ | 3 | `SANDBOX` | isolate unknown/high-risk behavior | high |
11
+ | 4 | `RATE_LIMIT` | mitigate volumetric or burst anomalies | low-medium |
12
+ | 5 | `QUARANTINE` | isolate source identity while preserving observation | medium |
13
+
14
+ Costs are computed in `reward_engine.py` as latency + compute.
15
+
16
+ ## Decision Pattern
17
+
18
+ 1. If confidence is high and malicious indicators are strong: `BLOCK` / `QUARANTINE`.
19
+ 2. If confidence is low but suspicious: `INSPECT` then follow-up action.
20
+ 3. If traffic appears benign and reputation is healthy: `ALLOW`.
21
+ 4. If volumetric anomaly dominates: `RATE_LIMIT` before hard block.
22
+
23
+ ## RL Compatibility
24
+
25
+ - `action_space` is `Discrete(6)` in single-session mode.
26
+ - Multi-session mode applies the same discrete action per session ID in the action map.
docs/API_REFERENCE.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # API Reference
2
+
3
+ All endpoints are implemented in `server/app.py`.
4
+
5
+ ## Core Environment Endpoints
6
+
7
+ | Method | Path | Purpose |
8
+ |---|---|---|
9
+ | `POST` | `/reset` | start a new episode (`task`, optional `seed`) |
10
+ | `POST` | `/step` | multi-session step with action map |
11
+ | `POST` | `/step_single` | single-session RL step (`action`) |
12
+ | `GET` | `/state` | current environment snapshot |
13
+ | `GET` | `/tools` | discover supported tool functions |
14
+ | `GET` | `/health` | liveness check |
15
+
16
+ ## Tool Endpoints
17
+
18
+ - `POST /tool/evaluate_session`
19
+ - body: `{ "kwargs": { "session_id": "..." } }`
20
+ - `POST /tool/take_action`
21
+ - body: `{ "kwargs": { "session_id": "...", "action": 1 } }`
22
+ - `POST /tool/get_network_stats`
23
+ - body: `{ "kwargs": {} }`
24
+ - `POST /tool/get_threat_intelligence`
25
+ - body: `{ "kwargs": {} }`
26
+
27
+ ## Typical Loop
28
+
29
+ 1. `POST /reset`
30
+ 2. `GET /state` to list candidate sessions
31
+ 3. `POST /tool/evaluate_session` for selected sessions
32
+ 4. `POST /step` or `POST /step_single`
33
+ 5. repeat until `done=true`
docs/ARCHITECTURE.md ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Architecture
2
+
3
+ ## System Diagram
4
+
5
+ ```mermaid
6
+ flowchart LR
7
+ A[TrafficGenerator] --> E[FirewallEnvironment]
8
+ B[ThreatEngine] --> E
9
+ E --> C[RewardEngine]
10
+ E --> D[Graders]
11
+ E --> F[FastAPI App]
12
+ F --> G[Client / Agent]
13
+ G --> F
14
+ ```
15
+
16
+ ## Runtime Data Flow
17
+
18
+ ```mermaid
19
+ sequenceDiagram
20
+ participant Agent
21
+ participant Env as FirewallEnvironment
22
+ participant TG as TrafficGenerator
23
+ participant TH as ThreatEngine
24
+ participant RW as RewardEngine
25
+
26
+ Agent->>Env: reset(task, seed)
27
+ Env->>TG: generate_benign_sessions
28
+ Env->>TH: maybe_spawn_attacker + generate_attack_sessions
29
+ Env-->>Agent: state
30
+ Agent->>Env: step(action_map) or step_single(action)
31
+ Env->>RW: reward(action, is_malicious, budget_remaining, phase)
32
+ Env-->>Agent: reward, done, info, next state
33
+ ```
34
+
35
+ ## Core Components
36
+
37
+ | Component | Responsibility | Key Outputs |
38
+ |---|---|---|
39
+ | `firewall_environment.py` | Episode orchestration, budget tracking, session lifecycle, metrics | `state()`, `step()`, `step_single()`, tool APIs |
40
+ | `traffic_generator.py` | Benign + malicious metadata generation, normalization, scenario shaping | 22-dim normalized observation vectors |
41
+ | `threat_engine.py` | Multi-attacker orchestration, adaptation, lifecycle and outcomes | Attack sessions, attacker status map |
42
+ | `reward_engine.py` | Multi-objective reward calculation and action-cost accounting | scalar reward + component breakdown |
43
+ | `graders.py` | Deterministic task scoring and pass/fail gating | score in `[0,1]`, pass constraints |
44
+ | `baseline/evaluate.py` | Policy benchmarking across tasks | JSON report for random/heuristic/block/allow |
45
+
46
+ ## Environment Modes
47
+
48
+ - **Multi-session mode**: `step(action_map)` handles a variable batch of sessions per tick.
49
+ - **Single-session mode**: `step_single(action)` exposes one decision at a time with `Discrete(6)` semantics.
50
+ - **Inspect workflow**: inspect is first-stage evidence collection; follow-up action resolves the session.
docs/DEPLOYMENT.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Deployment
2
+
3
+ ## Local Runtime
4
+
5
+ ```bash
6
+ uvicorn adaptive_firewall_env.server.app:app --host 0.0.0.0 --port 8000
7
+ ```
8
+
9
+ ## Container Runtime
10
+
11
+ ```bash
12
+ docker build -f src/adaptive_firewall_env/server/Dockerfile -t adaptive-firewall-env .
13
+ docker run --rm -p 8000:8000 adaptive-firewall-env
14
+ ```
15
+
16
+ ## OpenEnv Metadata
17
+
18
+ - Manifest path: `src/adaptive_firewall_env/openenv.yaml`
19
+ - Runtime type: FastAPI app (`server.app:app`)
20
+ - Default port: `8000`
21
+
22
+ ## Smoke Checks
23
+
24
+ - `GET /health` returns `{ "status": "ok" }`
25
+ - `POST /reset` returns episode state
26
+ - `POST /step_single` returns next observation and reward
docs/REWARD_DESIGN.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 💰 Reward Design — Multi-Objective Optimization
2
+
3
+ The Adaptive AI Firewall environment uses a sophisticated, weighted reward function designed to drive agent behavior toward a balance of security efficacy, network availability, and resource efficiency.
4
+
5
+ ## 📐 The Reward Equation
6
+
7
+ The total scalar reward $R$ for any action is calculated as:
8
+
9
+ $$R = \alpha \cdot R_{\text{security}} + \beta \cdot R_{\text{availability}} + \gamma \cdot R_{\text{efficiency}} + \delta \cdot R_{\text{timeliness}}$$
10
+
11
+ ### **Default Weights**
12
+ | Component | Weight | Responsibility |
13
+ |---|---|---|
14
+ | $\alpha$ | **0.35** | Security Efficacy (Catching threats) |
15
+ | $\beta$ | **0.30** | Network Availability (Avoiding False Positives) |
16
+ | $\gamma$ | **0.20** | Resource Efficiency (Budget management) |
17
+ | $\delta$ | **0.15** | Timeliness (Stopping attacks early) |
18
+
19
+ ---
20
+
21
+ ## 🧩 Reward Components
22
+
23
+ ### **1. Security ($R_{\text{security}}$)**
24
+ - **Block Malicious**: $+1.0$ (Successfully stopped a threat).
25
+ - **Miss Malicious**: $-2.0$ (Failed to block an attack; high penalty).
26
+ - **Inspect Malicious**: $+0.15$ (Correct identification, though not yet stopped).
27
+ - **Inspect Benign**: $-0.5$ (Unnecessary inspection).
28
+
29
+ ### **2. Availability ($R_{\text{availability}}$)**
30
+ - **Allow Benign**: $+0.25$ (Maintaining network flow).
31
+ - **Block Benign (FP)**: $-1.2$ (Significant penalty for disrupting legitimate users).
32
+ - **Rate Limit Benign**: $-0.4$ (Milder penalty for "gray" actions).
33
+ - **Inspect Benign**: $-0.15$ (Unnecessary latency added).
34
+
35
+ ### **3. Efficiency ($R_{\text{efficiency}}$)**
36
+ - **Cost**: Calculated as $\text{latency} + \text{compute}$ for each action.
37
+ - **Scaling**: Penalized relative to remaining budget: $R_{\text{efficiency}} = -\frac{\text{cost}}{\max(\text{budget\_remaining}, 0.1)}$.
38
+ - This creates **Strategic Pressure**: actions become "more expensive" as the budget depletes.
39
+
40
+ ### **4. Timeliness ($R_{\text{timeliness}}$)**
41
+ - **Early Detection**: $+e^{-\text{phase}}$ where `phase` is the attacker's progress in the kill chain (0 to 4).
42
+ - **Incentive**: Stopping an attack at Phase 0 is significantly more rewarding than at Phase 3.
43
+
44
+ ---
45
+
46
+ ## 📊 Worked Examples
47
+
48
+ | Scenario | Action | Security | Availability | Efficiency | Timeliness | **Total Reward** |
49
+ |---|---|---|---|---|---|---|
50
+ | **Legitimate User** | `ALLOW` | $0.0$ | $+0.25$ | $0.0$ | $0.0$ | **$+0.075$** |
51
+ | **Early Attack (Ph 0)** | `BLOCK` | $+1.0$ | $0.0$ | $-0.005$ | $+1.0$ | **$+0.499$** |
52
+ | **Late Attack (Ph 3)** | `BLOCK` | $+1.0$ | $0.0$ | $-0.005$ | $+0.05$ | **$+0.357$** |
53
+ | **False Positive** | `BLOCK` | $0.0$ | $-1.2$ | $-0.005$ | $0.0$ | **$-0.361$** |
54
+ | **Missed Attack** | `ALLOW` | $-2.0$ | $0.0$ | $0.0$ | $0.0$ | **$-0.700$** |
55
+
56
+ ---
57
+
58
+ ## 🛡️ Anti-Degeneracy Controls
59
+
60
+ To prevent agents from learning "lazy" policies (like blocking everything or allowing everything), the environment implements:
61
+
62
+ 1. **Reward Balancing**: The ratio of Miss Penalty to FP Penalty is tuned (~2.3:1) so that on a typical 80/20 traffic mix, a `block_all` policy yields a negative total reward.
63
+ 2. **Pass/Fail Constraints**: Graders in [graders.py](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/src/adaptive_firewall_env/server/graders.py) require a minimum detection rate **AND** a minimum availability rate to pass a task, regardless of the scalar reward.
docs/STATE_SPACE.md ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # State Space
2
+
3
+ The environment uses a **22-dimensional normalized observation vector** (`[0,1]` per feature).
4
+ Order is fixed by `FEATURE_ORDER` in `traffic_generator.py`.
5
+
6
+ ## Feature Groups
7
+
8
+ | Group | Features | Semantics |
9
+ |---|---|---|
10
+ | Volume & timing | bytes sent/received, duration, packet count, packet variance, inter-arrival mean/jitter | throughput shape and temporal burstiness |
11
+ | Network metadata | src/dst ports, protocol, DNS query count, connection reuse | routing and communication pattern |
12
+ | TLS / certificate | TLS version, JA3 cluster, chain length, cert validity, self-signed | encrypted-session trust indicators |
13
+ | Behavioral context | geo distance, time of day, session history score, entropy score | reputation and anomaly context |
14
+
15
+ ## Observation Interfaces
16
+
17
+ - `evaluate_session(session_id)` returns the vector for a given session.
18
+ - `state()` returns environment-level counters and selected session IDs.
19
+ - `step_single(action)` returns `observation` for the next queued session.
20
+
21
+ ## Normalization Strategy
22
+
23
+ - Each raw feature is min-max normalized using bounded ranges in `FEATURE_BOUNDS`.
24
+ - Outliers are clipped to `[0,1]` after normalization.
25
+ - This enables stable neural training across heterogeneous scales (ports, durations, entropy).
26
+
27
+ ## Markov Context Notes
28
+
29
+ - Single-session mode is designed for fixed-shape RL loops.
30
+ - Multi-session mode supports tool-driven decision systems over dynamic queues.
docs/TASKS.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Tasks
2
+
3
+ ## Difficulty Levels
4
+
5
+ | Task | Steps | Threshold | Pass Constraints |
6
+ |---|---:|---:|---|
7
+ | Easy | 200 | 0.70 | detection ≥ 0.35 and fp_complement ≥ 0.65 |
8
+ | Medium | 500 | 0.50 | detection ≥ 0.35 and fp_complement ≥ 0.60 |
9
+ | Hard | 1000 | 0.45 | detection ≥ 0.35 and fp_complement ≥ 0.55 |
10
+
11
+ ## Why Constraints Exist
12
+
13
+ Weighted scores alone can be gamed by degenerate policies:
14
+ - `allow_all` inflates availability/efficiency.
15
+ - `block_all` inflates detection.
16
+
17
+ The pass constraints ensure any passing policy must satisfy both:
18
+ 1. meaningful threat detection,
19
+ 2. acceptable benign-traffic handling.
20
+
21
+ Task scoring logic is implemented in `server/graders.py`.
docs/THREAT_MODELS.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Threat Models
2
+
3
+ ## Scenario Catalog
4
+
5
+ | Scenario | Early Phase | Mid Phase | Late Phase |
6
+ |---|---|---|---|
7
+ | `port_scan_exploit_c2` | rapid probing | exploit delivery | command/control + exfil |
8
+ | `credential_stuffing_lateral` | auth pressure | lateral movement | persistence |
9
+ | `supply_chain_compromise` | stealth foothold | trusted-channel abuse | disguised exfiltration |
10
+ | `low_and_slow_apt` | sparse reconnaissance | long dwell C2 | slow extraction |
11
+ | `ddos_amplification` | reflection probes | traffic amplification | flood stage |
12
+
13
+ ## Adaptation Behavior
14
+
15
+ - Repeated blocking increases attacker detection count.
16
+ - Detected attackers can switch to stealth mode and alter feature distributions.
17
+ - Attackers terminate when repeatedly blocked, time out, or complete exfiltration.
18
+ - Threat engine exposes per-attacker outcomes (`active`, `stopped`, `succeeded`) for analysis and credit assignment.
19
+
20
+ Threat generation and lifecycle are implemented in `server/threat_engine.py`.
implementation_plan.md ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adaptive AI Firewall — OpenEnv RL Challenge Compliance
2
+
3
+ ## Background
4
+
5
+ Your current codebase has a solid firewall RL environment, grader, and inference agent. However, several critical areas need changes to pass the hackathon's automated validation. I've analyzed the reference repos (reasoning_gym_env, calendar_env) and the submission guidelines in detail.
6
+
7
+ ## User Review Required
8
+
9
+ > [!IMPORTANT]
10
+ > **Ollama vs HuggingFace Router**: The hackathon guidelines mandate using the **OpenAI Client** with `API_BASE_URL` (default pointing to HuggingFace router) and `HF_TOKEN`. You mentioned wanting to use Ollama — but **the evaluation system will inject its own `API_BASE_URL` and `MODEL_NAME`** pointing to their hosted models. Your code must use the OpenAI client talking to whatever `API_BASE_URL` is provided. Ollama won't work during evaluation because the Docker container runs on HF Spaces with 2 vCPU / 8 GB RAM — no room to run a local LLM. Your current setup (HF router + OpenAI client) is **already correct**. I'll keep it as-is.
11
+
12
+ > [!WARNING]
13
+ > **Your `.env` file contains a real `HF_TOKEN`**. This is committed to git. You should rotate this token after we're done and add `.env` to `.gitignore`.
14
+
15
+ ---
16
+
17
+ ## Proposed Changes
18
+
19
+ ### 1. `inference.py` — Complete Rewrite (Critical)
20
+
21
+ #### [MODIFY] [inference.py](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/inference.py)
22
+
23
+ **Current problems:**
24
+ - ❌ No `[START]` / `[STEP]` / `[END]` output lines (the #1 compliance requirement)
25
+ - ❌ `API_BASE_URL` and `MODEL_NAME` have no default values (will fail validation)
26
+ - ❌ `HF_TOKEN` is not validated as mandatory (should raise on missing)
27
+ - ❌ Uses `argparse` — evaluation just runs `python inference.py`
28
+ - ❌ Output is JSON, not the required line format
29
+
30
+ **Changes:**
31
+ - Add default values: `API_BASE_URL="https://router.huggingface.co/v1"`, `MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"`
32
+ - Raise `ValueError` if `HF_TOKEN` is missing
33
+ - Print `[START]` line before episode begins
34
+ - Print `[STEP]` line immediately after each `env.step()` return
35
+ - Print `[END]` line after episode ends (even on exception, using try/finally)
36
+ - Format rewards to 2 decimal places, booleans as lowercase `true`/`false`
37
+ - Remove argparse; hardcode task or pick from env var
38
+ - Keep the LLM-based agent logic (get_action) but fix it to work with defaults
39
+ - Run all 3 tasks sequentially (easy, medium, hard) or pick the best one
40
+
41
+ ---
42
+
43
+ ### 2. Server — Align with OpenEnv `create_app` Pattern
44
+
45
+ #### [MODIFY] [app.py](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/src/adaptive_firewall_env/server/app.py) (primary server)
46
+
47
+ **Current problems:**
48
+ - ❌ Hand-rolled FastAPI endpoints — reference repos use `openenv.core.env_server.http_server.create_app()`
49
+ - ❌ The import chain in `server/app.py` (root) references a non-existent module path
50
+
51
+ **Changes:**
52
+ - **Keep the current hand-rolled server** since `create_app` requires `openenv.core.env_server.interfaces.Environment` base class and the firewall env doesn't extend it. Refactoring to use `create_app` would require significant env restructuring.
53
+ - Instead, fix the root `server/app.py` to correctly import from the right location
54
+ - Add `/web` endpoint for HF Spaces web interface compatibility
55
+ - Add `/schema` endpoint returning action/observation schemas
56
+
57
+ #### [MODIFY] [app.py](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/server/app.py) (root server entry)
58
+
59
+ - Fix the broken import `from adaptive_firewall_env.server.app import app`
60
+ - Make it correctly reference the actual app
61
+
62
+ ---
63
+
64
+ ### 3. Dockerfile — Production Ready for HF Spaces
65
+
66
+ #### [MODIFY] [Dockerfile](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/Dockerfile)
67
+
68
+ **Current problems:**
69
+ - ❌ Doesn't copy `inference.py` into the container
70
+ - ❌ Doesn't copy `env/`, `grader/`, `utils/` directories
71
+ - ❌ Heavy dependencies (torch, stable-baselines3) blow through 8 GB RAM
72
+
73
+ **Changes:**
74
+ - Copy ALL required source directories (`env/`, `grader/`, `utils/`, `inference.py`, `models/`)
75
+ - Set `PYTHONPATH` correctly
76
+ - Optimize requirements for smaller image size
77
+ - Keep `CMD` as uvicorn for the server (HF Spaces), but ensure `inference.py` can also run independently
78
+
79
+ ---
80
+
81
+ ### 4. Requirements — Trim for 8 GB RAM Constraint
82
+
83
+ #### [MODIFY] [requirements.txt](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/requirements.txt)
84
+
85
+ **Changes:**
86
+ - Remove `torch` (huge, not needed for inference — agent uses OpenAI API)
87
+ - Remove `stable-baselines3` (training framework, not needed at inference)
88
+ - Remove `shimmy` (adapter for SB3)
89
+ - Remove `gymnasium` (not needed if using custom env directly)
90
+ - Keep: `fastapi`, `uvicorn`, `numpy`, `pydantic`, `requests`, `openai`, `python-dotenv`
91
+
92
+ ---
93
+
94
+ ### 5. `.env.example` — Fix Defaults
95
+
96
+ #### [MODIFY] [.env.example](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/.env.example)
97
+
98
+ - Document that `HF_TOKEN` is **mandatory**
99
+ - Show default values for `API_BASE_URL` and `MODEL_NAME`
100
+
101
+ ---
102
+
103
+ ### 6. Fix Import Chain in `utils/threat_engine.py`
104
+
105
+ #### [MODIFY] [threat_engine.py](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/utils/threat_engine.py)
106
+
107
+ **Current problem:**
108
+ - Line 17: `from adaptive_firewall_env.server.traffic_generator import TrafficGenerator` — wrong import path
109
+ - Should be `from utils.data_loader import TrafficGenerator`
110
+
111
+ ---
112
+
113
+ ### 7. `.gitignore` — Protect Secrets
114
+
115
+ #### [MODIFY] [.gitignore](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/.gitignore)
116
+
117
+ - Ensure `.env` is listed (prevent token leaks)
118
+
119
+ ---
120
+
121
+ ## Architecture Summary After Changes
122
+
123
+ ```
124
+ meta_ai_hackathon/
125
+ ├── inference.py ← MAIN ENTRY POINT (hackathon requirement)
126
+ ├── Dockerfile ← HF Spaces deployment
127
+ ├── requirements.txt ← Trimmed dependencies
128
+ ├── openenv.yaml ← Environment manifest
129
+ ├── .env.example ← Template with docs
130
+ ├── env/
131
+ │ ├── firewall_env.py ← Core RL environment
132
+ │ └── models.py ← Pydantic request/response models
133
+ ├── grader/
134
+ │ └── firewall_grader.py ← Scoring logic
135
+ ├── utils/
136
+ │ ├── data_loader.py ← Traffic generation
137
+ │ ├── reward_engine.py ← Multi-objective rewards
138
+ │ └── threat_engine.py ← Attack orchestration (import fixed)
139
+ ├── server/
140
+ │ └── app.py ← FastAPI server for HF Spaces
141
+ └── src/adaptive_firewall_env/server/
142
+ └── app.py ← Full server with LLM playground
143
+ ```
144
+
145
+ ## Open Questions
146
+
147
+ > [!IMPORTANT]
148
+ > **Which tasks to run?** The hackathon evaluator likely runs `python inference.py` without arguments. Should we:
149
+ > - (A) Run all 3 tasks (easy, medium, hard) sequentially and output [START]/[STEP]/[END] for each?
150
+ > - (B) Run only `easy` by default?
151
+ > - I recommend **(A)** — running all 3 tasks to maximize score visibility. Each gets its own `[START]`/`[END]` block.
152
+
153
+ > [!IMPORTANT]
154
+ > **Max steps per task:** Easy=200, Medium=500, Hard=1000 steps. With LLM calls at each step, this could be slow with rate limits. Should I add a timeout or fallback more aggressively to heuristics?
155
+
156
+ ## Verification Plan
157
+
158
+ ### Automated Tests
159
+ 1. Run `python inference.py` and verify stdout matches the exact format:
160
+ ```
161
+ [START] task=easy env=ai-firewall model=meta-llama/Llama-3.1-8B-Instruct
162
+ [STEP] step=1 action=ALLOW reward=0.00 done=false error=null
163
+ ...
164
+ [END] success=true steps=200 rewards=0.00,0.00,...
165
+ ```
166
+ 2. Run `docker build -t ai-firewall .` and verify it builds under 8 GB
167
+ 3. Run the container and hit `/health` endpoint
168
+ 4. Verify all env vars work with defaults when `HF_TOKEN` is set
169
+
170
+ ### Manual Verification
171
+ - Deploy to HF Spaces and confirm the space reaches "Running" state
172
+ - Verify the web interface loads at the space URL
inference.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import sys
6
+ import time
7
+ import textwrap
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ import numpy as np
11
+ from openai import OpenAI
12
+ from dotenv import load_dotenv
13
+
14
+ # Import the environment directly for the AI Firewall
15
+ from server.firewall_environment import FirewallEnvironment, ACTIONS, TASK_CONFIGS
16
+
17
+ # --- Hackathon Submission Rules Compliance ---
18
+ # 1. inference.py in root directory ✅
19
+ # 2. Use OpenAI Client for all LLM calls ✅
20
+ # 3. Required Environment Variables with Defaults ✅
21
+ # 4. Strict Output Format: [START], [STEP], [END] ✅
22
+
23
+ load_dotenv()
24
+
25
+ # Environment Variables per Spec (defaults required for API_BASE_URL and MODEL_NAME)
26
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
27
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
28
+ HF_TOKEN = os.getenv("HF_TOKEN")
29
+
30
+ if HF_TOKEN is None:
31
+ raise ValueError("HF_TOKEN environment variable is required")
32
+
33
+ # Benchmark configuration
34
+ BENCHMARK = "ai-firewall"
35
+
36
+
37
+ def format_bool(v: bool) -> str:
38
+ return "true" if v else "false"
39
+
40
+
41
+ def log_start(task: str, env: str, model: str) -> None:
42
+ print(f"[START] task={task} env={env} model={model}", flush=True)
43
+
44
+
45
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
46
+ error_val = error if error else "null"
47
+ done_val = format_bool(done)
48
+ print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
49
+
50
+
51
+ def log_end(success: bool, steps: int, rewards: List[float]) -> None:
52
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
53
+ print(f"[END] success={format_bool(success)} steps={steps} rewards={rewards_str}", flush=True)
54
+
55
+
56
+ class InferenceAgent:
57
+ def __init__(self):
58
+ self.client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
59
+
60
+ def get_action(self, session_data: Dict[str, Any], threat_intel: Dict[str, Any]) -> int:
61
+ """Get action using LLM via OpenAI client interface with heuristic fallback."""
62
+ system_prompt = textwrap.dedent(
63
+ """
64
+ You are an adaptive AI firewall controller.
65
+ Respond with ONLY valid JSON in this shape: {"reasoning": string, "action": integer}.
66
+ Action must be one integer between 0 and 5: 0=ALLOW, 1=BLOCK, 2=INSPECT, 3=SANDBOX, 4=RATE_LIMIT, 5=QUARANTINE.
67
+ Keep reasoning short (under 20 words).
68
+ """
69
+ ).strip()
70
+
71
+ user_prompt = json.dumps({
72
+ "session": session_data,
73
+ "threat_intelligence": threat_intel,
74
+ "actions": ACTIONS
75
+ })
76
+
77
+ max_retries = 3
78
+ for attempt in range(max_retries):
79
+ try:
80
+ response = self.client.chat.completions.create(
81
+ model=MODEL_NAME,
82
+ messages=[
83
+ {"role": "system", "content": system_prompt},
84
+ {"role": "user", "content": user_prompt}
85
+ ],
86
+ temperature=0.2,
87
+ max_tokens=150,
88
+ )
89
+
90
+ raw_content = response.choices[0].message.content
91
+
92
+ # Attempt to parse JSON
93
+ if "```json" in raw_content:
94
+ raw_content = raw_content.split("```json")[1].split("```")[0].strip()
95
+ elif "```" in raw_content:
96
+ raw_content = raw_content.split("```")[1].split("```")[0].strip()
97
+
98
+ content = json.loads(raw_content)
99
+ action = int(content.get("action", 0))
100
+ return max(0, min(5, action))
101
+
102
+ except Exception as e:
103
+ if "429" in str(e) and attempt < max_retries - 1:
104
+ time.sleep(2 ** attempt)
105
+ continue
106
+ return self._heuristic_action(session_data, threat_intel)
107
+
108
+ return self._heuristic_action(session_data, threat_intel)
109
+
110
+ def _heuristic_action(self, session_data: Dict[str, Any], threat_intel: Dict[str, Any]) -> int:
111
+ """Rule-based fallback with 8 detection rules."""
112
+ features = session_data.get("features", {})
113
+ known_bad_ports = set(threat_intel.get("known_bad_ports", []))
114
+
115
+ if session_data.get("revealed_malicious") is True:
116
+ return 1 # BLOCK
117
+
118
+ dst_port = int(features.get("dst_port", 0))
119
+ history = float(features.get("session_history_score", 1.0))
120
+ entropy = float(features.get("entropy_score", 0.0))
121
+ reuse = float(features.get("connection_reuse", 1.0))
122
+ self_signed = int(features.get("is_self_signed", 0))
123
+ ja3 = int(features.get("ja3_hash_cluster", 0))
124
+ geo = float(features.get("geo_distance", 0.0))
125
+ cert_valid = float(features.get("cert_validity_days", 999.0))
126
+ tls_ver = int(features.get("tls_version", 1))
127
+ dns_q = int(features.get("dns_query_count", 0))
128
+ dur = float(features.get("duration_ms", 500.0))
129
+ pkts = int(features.get("packet_count", 10))
130
+
131
+ if dst_port in known_bad_ports and history < 0.50:
132
+ return 1
133
+ if self_signed == 1 and history < 0.45:
134
+ return 5
135
+ if entropy > 0.55 and reuse < 0.25:
136
+ return 2
137
+ if geo > 4000.0 and history < 0.40:
138
+ return 2
139
+ if ja3 >= 180:
140
+ return 1
141
+ if dur < 60.0 and pkts > 100:
142
+ return 4
143
+ if cert_valid < 80.0 and tls_ver == 0:
144
+ return 2
145
+ if reuse < 0.10 and dns_q >= 4:
146
+ return 2
147
+
148
+ return 0 # ALLOW
149
+
150
+
151
+ def run_task(agent: InferenceAgent, task: str):
152
+ """Run a single task episode and emit spec-compliant output."""
153
+ seeds = {"easy": 101, "medium": 202, "hard": 303}
154
+ env = FirewallEnvironment(seed=seeds.get(task, 101))
155
+
156
+ log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
157
+
158
+ state = env.reset(task=task)
159
+ done = False
160
+ rewards: List[float] = []
161
+ steps_taken = 0
162
+ success = False
163
+
164
+ try:
165
+ while not done:
166
+ action = 0
167
+ error_msg = None
168
+
169
+ focus_session_id = state.get("focus_session_id")
170
+ if focus_session_id:
171
+ try:
172
+ session_data = env.evaluate_session(focus_session_id)
173
+ threat_intel = env.get_threat_intelligence()
174
+ action = agent.get_action(session_data, threat_intel)
175
+ result = env.step_single(action)
176
+ except Exception as e:
177
+ error_msg = str(e)
178
+ result = env.step_single(0)
179
+ else:
180
+ result = env.step_single(0)
181
+
182
+ reward = float(result["reward"])
183
+ done = bool(result["done"])
184
+ state = result["state"]
185
+ steps_taken += 1
186
+ rewards.append(reward)
187
+
188
+ log_step(
189
+ step=steps_taken,
190
+ action=ACTIONS.get(action, "ALLOW"),
191
+ reward=reward,
192
+ done=done,
193
+ error=error_msg,
194
+ )
195
+
196
+ if done:
197
+ break
198
+
199
+ # Calculate final score via grader
200
+ final_stats = env.get_network_stats()
201
+ from server.graders import grade_stats
202
+ grade = grade_stats(task, final_stats)
203
+ # success = episode completed AND score meets threshold
204
+ success = grade.get("passed", False)
205
+
206
+ except Exception as e:
207
+ print(f"[DEBUG] Error during task {task}: {e}", file=sys.stderr)
208
+ success = False
209
+ finally:
210
+ log_end(success=success, steps=steps_taken, rewards=rewards)
211
+
212
+
213
+ def main():
214
+ try:
215
+ agent = InferenceAgent()
216
+ for task in ["easy", "medium", "hard"]:
217
+ run_task(agent, task)
218
+ except Exception as e:
219
+ print(f"Critical error: {e}", file=sys.stderr)
220
+ sys.exit(1)
221
+
222
+
223
+ if __name__ == "__main__":
224
+ main()
models.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ from pydantic import BaseModel, Field
6
+ # Standard OpenEnv types (if openenv-core is installed)
7
+ try:
8
+ from openenv.core.env_server.types import Action, Observation
9
+ except ImportError:
10
+ # Fallback if not installed
11
+ class Action(BaseModel):
12
+ pass
13
+ class Observation(BaseModel):
14
+ pass
15
+
16
+ # --- Custom Action/Observation classes as seen in video ---
17
+
18
+ class FirewallAction(Action):
19
+ """Action for the AI Firewall environment."""
20
+ action: int = Field(..., description="Action index: 0=ALLOW, 1=BLOCK, 2=INSPECT, 3=SANDBOX, 4=RATE_LIMIT, 5=QUARANTINE")
21
+ session_id: Optional[str] = Field(None, description="Specific session to act upon")
22
+
23
+ class FirewallObservation(Observation):
24
+ """Observation for the AI Firewall environment."""
25
+ features: List[float] = Field(..., description="22-dimensional normalized feature vector")
26
+ focus_session_id: Optional[str] = Field(None, description="ID of the session currently in focus")
27
+
28
+ # --- Original models from env/models.py ---
29
+
30
+ class ActionRecord(BaseModel):
31
+ tick: int
32
+ session_id: str
33
+ action: int
34
+ action_name: str
35
+ malicious: bool
36
+ reward: float
37
+ components: Dict[str, float]
38
+
39
+ class ResetRequest(BaseModel):
40
+ task: str = Field(default="easy", description="Task difficulty: easy, medium, hard")
41
+ seed: Optional[int] = Field(default=None, description="Random seed for reproducibility")
42
+
43
+ class StepRequest(BaseModel):
44
+ actions: Dict[str, int] = Field(default_factory=dict, description="Map of session_id to action index")
45
+
46
+ class StepSingleRequest(BaseModel):
47
+ action: int = Field(..., description="Action index (0-5) for the current focus session")
48
+
49
+ class ToolRequest(BaseModel):
50
+ kwargs: Dict[str, Any] = Field(default_factory=dict, description="Arguments for the tool call")
51
+
52
+ class StateResponse(BaseModel):
53
+ episode_id: int
54
+ task: str
55
+ step_count: int
56
+ current_tick: int
57
+ observation_dim: int
58
+ num_actions: int
59
+ budget_remaining: float
60
+ total_reward: float
61
+ pending_session_count: int
62
+ inspected_session_count: int
63
+ pending_session_ids: List[str]
64
+ inspected_session_ids: List[str]
65
+ queue_length: int
66
+ focus_session_id: Optional[str]
67
+ focus_observation: List[float]
68
+
69
+ class StepResponse(BaseModel):
70
+ reward: float
71
+ done: bool
72
+ state: StateResponse
73
+ info: Dict[str, Any]
74
+
75
+ class StepSingleResponse(BaseModel):
76
+ observation: List[float]
77
+ reward: float
78
+ done: bool
79
+ state: StateResponse
80
+ info: Dict[str, Any]
81
+
82
+ class EvaluateSessionResponse(BaseModel):
83
+ session_id: str
84
+ features: Dict[str, Any]
85
+ observation: List[float]
86
+ is_inspected: bool
87
+ revealed_malicious: Optional[bool]
88
+ expires_tick: int
89
+
90
+ class NetworkStatsResponse(BaseModel):
91
+ episode_id: int
92
+ task: str
93
+ tick: int
94
+ step_count: int
95
+ total_reward: float
96
+ budget_remaining: float
97
+ budget_used_pct: float
98
+ total_malicious: int
99
+ total_benign: int
100
+ detection_rate: float
101
+ false_positive_rate: float
102
+ efficiency: float
103
+ early_detection_bonus: float
104
+ cascade_prevention: float
105
+ correct_allows: int
106
+ inspections: int
107
+ expired_malicious: int
108
+ expired_benign: int
109
+
110
+ class HealthResponse(BaseModel):
111
+ status: str
112
+ version: str
113
+
114
+ class ToolsListResponse(BaseModel):
115
+ tools: List[str]
116
+
117
+ class TakeActionResponse(BaseModel):
118
+ reward: float
119
+ record: ActionRecord
120
+
121
+ class LLMChatRequest(BaseModel):
122
+ prompt: str
123
+ api_key: Optional[str] = None
124
+ base_url: Optional[str] = None
125
+ model: Optional[str] = None
126
+
127
+ class LLMChatResponse(BaseModel):
128
+ content: str
129
+ model: str
130
+
131
+ class LLMConfigResponse(BaseModel):
132
+ base_url: str
133
+ model: str
134
+ has_api_key: bool
135
+
136
+ class LLMTestResponse(BaseModel):
137
+ ok: bool
138
+ model: str
139
+ content: str
openenv.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: ai-firewall-openenv
3
+ version: "1.0.0"
4
+ description: "AI-driven adaptive firewall for automated threat detection"
5
+ type: space
6
+ runtime: fastapi
7
+ app: server.app:app
8
+ port: 7860
9
+
10
+ tasks:
11
+ easy:
12
+ name: "Perimeter Defense"
13
+ description: "200-step episode with obvious attacks"
14
+ grading_seed: 101
15
+ threshold: 0.70
16
+ medium:
17
+ name: "Mixed Threat Landscape"
18
+ description: "500-step episode with multi-stage attacks and ambiguous traffic"
19
+ grading_seed: 202
20
+ threshold: 0.50
21
+ hard:
22
+ name: "Advanced Persistent Threat"
23
+ description: "1000-step episode with adaptive APTs and stealth threats"
24
+ grading_seed: 303
25
+ threshold: 0.45
26
+
27
+ tools:
28
+ - name: evaluate_session
29
+ description: "Get detailed features and observation for a specific session"
30
+ - name: take_action
31
+ description: "Apply a firewall action to a session (ALLOW, BLOCK, etc.)"
32
+ - name: get_network_stats
33
+ description: "Get cumulative episode statistics and performance metrics"
34
+ - name: get_threat_intelligence
35
+ description: "Access current threat intelligence feed"
36
+
37
+ observation_space:
38
+ type: box
39
+ shape: [22]
40
+ low: 0.0
41
+ high: 1.0
42
+
43
+ action_space:
44
+ type: discrete
45
+ n: 6
46
+ labels:
47
+ 0: ALLOW
48
+ 1: BLOCK
49
+ 2: INSPECT
50
+ 3: SANDBOX
51
+ 4: RATE_LIMIT
52
+ 5: QUARANTINE
progresss.md ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation Progress
2
+
3
+ ## Status
4
+
5
+ - Completed: project scaffolding and package manifest
6
+ - Completed: core server environment (`traffic_generator`, `threat_engine`, `reward_engine`, `firewall_environment`, `graders`, `app`)
7
+ - Completed: baseline policies (`random_agent`, `heuristic_agent`) and evaluator
8
+ - Completed: OpenEnv config, Dockerfile, requirements, client wrapper
9
+ - Completed: docs and AI skill/workflow files
10
+ - Completed: syntax verification with `py -m compileall src tests`
11
+ - Completed: baseline end-to-end evaluation run
12
+ - Completed: virtual environment created at `.venv` using `py -m venv .venv` with `PYTHONDONTWRITEBYTECODE=1`
13
+ - Completed: toolchain installed inside `.venv` (`pytest`, `ruff`, `requests`, `numpy`, `scipy`, `fastapi`, `pydantic`, `uvicorn`)
14
+ - Completed: `pytest` validation passed (`5 passed`)
15
+ - Completed: `ruff check src tests` passed (`All checks passed!`)
16
+ - Completed: runtime smoke test for reset/step (`ok 22 False`)
17
+ - Completed: REVIEW_AND_TODO P0/P1 core fixes implemented (budget scaling, inspect flow, expiration metrics, PYTHONPATH stability, reward rebalance)
18
+ - Completed: scenario-aware threat/traffic behavior and adaptive attacker lifecycle improvements
19
+ - Completed: one-session-per-step mode (`step_single`) and framework spaces (`observation_space`/`action_space`)
20
+ - Completed: new integration safeguards (`always_block`/`always_allow`) in baseline evaluator
21
+ - Completed: expanded automated tests from 5 to 16 and all passing in `.venv`
22
+ - Completed: latest validation (`pytest`: 16 passed, `ruff`: all checks passed)
23
+ - Completed: compatibility fixes after refactor (`__init__ budget arg`, inspect dual-pool consistency, `step_single` focus observation state field)
24
+ - Completed: comprehensive test suite now fully green (`pytest`: 38 passed)
25
+ - Completed: lint cleanup across source and consolidated tests (`ruff`: all checks passed)
26
+ - Completed: grading anti-degeneracy gates (pass constraints for detection + false-positive complement)
27
+ - Completed: evaluator now confirms heuristic passes all tasks while random/block-all/allow-all fail pass gates
28
+ - Completed: docs + skills + workflows significantly expanded from stubs to implementation-level guidance
29
+ - Completed: hackathon compliance changes implemented (inference.py, Dockerfile, requirements, .env.example, .gitignore)
30
+ - Completed: server endpoints added (/web, /schema) and root import fix
31
+ - Completed: all blocking issues from `REVIEW_AND_TODO.md` resolved
32
+ - Completed: refactored project structure to match OpenEnv standard layout (models.py at root, environment in server/)
33
+ - Completed: consolidated all environment logic into `server/` and removed redundant directories
34
+ - Completed: updated Web Playground UI to match the standard OpenEnv interface
35
+ - Completed: verified system logic with `inference.py` output and FastAPI health checks
36
+ - Completed: verified project structure and syntax with `py_compile`
37
+ - Completed: implemented local Ollama/Qwen support as default LLM with remote fallback
38
+ - Completed: updated `.env.example` with Ollama/Qwen configuration options
39
+
40
+ ## Decisions Applied
41
+
42
+ - Action space kept at 6 actions
43
+ - Observation space kept at 22 features
44
+ - OpenEnv target aligned to `openenv-core[core]>=0.2.2`
45
+ - Runtime mode set to CPU-oriented implementation
46
+ - Episode lengths follow 200/500/1000 task defaults
47
+ - Efficiency now remains non-zero for non-degenerate policies via scaled budget model
48
+ - Dependency cleanup: removed `scipy` from project dependency lists (unused)
49
+ - Pass/fail now requires both score threshold and minimum detection/availability constraints
pyproject.toml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "adaptive_firewall_env"
3
+ version = "0.2.0"
4
+ description = "Adaptive AI Firewall RL environment for encrypted traffic decision making"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "fastapi>=0.112",
9
+ "uvicorn>=0.30",
10
+ "numpy>=1.26",
11
+ "pydantic>=2.0",
12
+ "requests>=2.32",
13
+ "openai>=1.30",
14
+ "python-dotenv>=1.0",
15
+ ]
16
+
17
+ [project.scripts]
18
+ server = "server.app:main"
19
+
20
+ [build-system]
21
+ requires = ["hatchling"]
22
+ build-backend = "hatchling.build"
23
+
24
+ [tool.hatch.build.targets.wheel]
25
+ packages = ["server"]
26
+
27
+ [tool.pytest.ini_options]
28
+ pythonpath = ["."]
29
+ testpaths = ["tests"]
30
+
31
+ [tool.ruff]
32
+ line-length = 120
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ numpy
4
+ pydantic
5
+ requests
6
+ openai
7
+ python-dotenv
scripts/validate-submission.sh ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ #
3
+ # validate-submission.sh — OpenEnv Submission Validator
4
+ #
5
+ # Checks that your HF Space is live, Docker image builds, and openenv validate passes.
6
+ #
7
+ # Prerequisites:
8
+ # - Docker: `https://docs.docker.com/get-docker/`
9
+ # - openenv-core: pip install openenv-core
10
+ # - curl (usually pre-installed)
11
+ #
12
+ # Run:
13
+ # curl -fsSL `https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh` | bash -s -- <ping_url> [repo_dir]
14
+ #
15
+ # Or download and run locally:
16
+ # chmod +x validate-submission.sh
17
+ # ./validate-submission.sh <ping_url> [repo_dir]
18
+ #
19
+ # Arguments:
20
+ # ping_url Your HuggingFace Space URL (e.g. `https://your-space.hf.space)`
21
+ # repo_dir Path to your repo (default: current directory)
22
+ #
23
+ # Examples:
24
+ # ./validate-submission.sh `https://my-team.hf.space`
25
+ # ./validate-submission.sh `https://my-team.hf.space` ./my-repo
26
+ #
27
+
28
+ set -uo pipefail
29
+
30
+ DOCKER_BUILD_TIMEOUT=3600
31
+ if [ -t 1 ]; then
32
+ RED='\033[0;31m'
33
+ GREEN='\033[0;32m'
34
+ YELLOW='\033[1;33m'
35
+ BOLD='\033[1m'
36
+ NC='\033[0m'
37
+ else
38
+ RED='' GREEN='' YELLOW='' BOLD='' NC=''
39
+ fi
40
+
41
+ run_with_timeout() {
42
+ local secs="$1"; shift
43
+ if command -v timeout &>/dev/null; then
44
+ timeout "$secs" "$@"
45
+ elif command -v gtimeout &>/dev/null; then
46
+ gtimeout "$secs" "$@"
47
+ else
48
+ "$@" &
49
+ local pid=$!
50
+ ( sleep "$secs" && kill "$pid" 2>/dev/null ) &
51
+ local watcher=$!
52
+ wait "$pid" 2>/dev/null
53
+ local rc=$?
54
+ kill "$watcher" 2>/dev/null
55
+ wait "$watcher" 2>/dev/null
56
+ return $rc
57
+ fi
58
+ }
59
+
60
+ portable_mktemp() {
61
+ local prefix="${1:-validate}"
62
+ mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
63
+ }
64
+
65
+ CLEANUP_FILES=()
66
+ cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
67
+ trap cleanup EXIT
68
+
69
+ PING_URL="${1:-}"
70
+ REPO_DIR="${2:-.}"
71
+
72
+ if [ -z "$PING_URL" ]; then
73
+ printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
74
+ printf "\n"
75
+ printf " ping_url Your HuggingFace Space URL (e.g. `https://your-space.hf.space)\n` "
76
+ printf " repo_dir Path to your repo (default: current directory)\n"
77
+ exit 1
78
+ fi
79
+
80
+ if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
81
+ printf "Error: directory '%s' not found\n" "${2:-.}"
82
+ exit 1
83
+ fi
84
+ PING_URL="${PING_URL%/}"
85
+ export PING_URL
86
+ PASS=0
87
+
88
+ log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
89
+ pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
90
+ fail() { log "${RED}FAILED${NC} -- $1"; }
91
+ hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
92
+ stop_at() {
93
+ printf "\n"
94
+ printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
95
+ exit 1
96
+ }
97
+
98
+ printf "\n"
99
+ printf "${BOLD}========================================${NC}\n"
100
+ printf "${BOLD} OpenEnv Submission Validator${NC}\n"
101
+ printf "${BOLD}========================================${NC}\n"
102
+ log "Repo: $REPO_DIR"
103
+ log "Ping URL: $PING_URL"
104
+ printf "\n"
105
+
106
+ log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
107
+
108
+ CURL_OUTPUT=$(portable_mktemp "validate-curl")
109
+ CLEANUP_FILES+=("$CURL_OUTPUT")
110
+ HTTP_CODE=$(curl.exe -s -o /dev/null -w "%{http_code}" -X POST \
111
+ -H "Content-Type: application/json" -d "{\"task\":\"easy\"}" \
112
+ "$PING_URL/reset" --max-time 30 || printf "000")
113
+ HTTP_CODE=$(echo $HTTP_CODE | tr -d '\r' | cut -c 1-3)
114
+
115
+ if [ "$HTTP_CODE" = "200" ]; then
116
+ pass "HF Space is live and responds to /reset"
117
+ elif [ "$HTTP_CODE" = "000" ]; then
118
+ fail "HF Space not reachable (connection failed or timed out)"
119
+ hint "Check your network connection and that the Space is running."
120
+ hint "Try: curl -s -o /dev/null -w '%%{http_code}' -X POST $PING_URL/reset"
121
+ stop_at "Step 1"
122
+ else
123
+ fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
124
+ hint "Make sure your Space is running and the URL is correct."
125
+ hint "Try opening $PING_URL in your browser first."
126
+ stop_at "Step 1"
127
+ fi
128
+
129
+ log "${BOLD}Step 2/3: Running docker build${NC} ..."
130
+
131
+ if ! command -v docker &>/dev/null; then
132
+ fail "docker command not found"
133
+ hint "Install Docker: `https://docs.docker.com/get-docker/` "
134
+ stop_at "Step 2"
135
+ fi
136
+
137
+ if [ -f "$REPO_DIR/Dockerfile" ]; then
138
+ DOCKER_CONTEXT="$REPO_DIR"
139
+ elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
140
+ DOCKER_CONTEXT="$REPO_DIR/server"
141
+ else
142
+ fail "No Dockerfile found in repo root or server/ directory"
143
+ stop_at "Step 2"
144
+ fi
145
+
146
+ log " Found Dockerfile in $DOCKER_CONTEXT"
147
+
148
+ BUILD_OK=false
149
+ BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true
150
+
151
+ if [ "$BUILD_OK" = true ]; then
152
+ pass "Docker build succeeded"
153
+ else
154
+ fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
155
+ printf "%s\n" "$BUILD_OUTPUT" | tail -20
156
+ stop_at "Step 2"
157
+ fi
158
+
159
+ log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
160
+
161
+ if ! command -v openenv &>/dev/null; then
162
+ fail "openenv command not found"
163
+ hint "Install it: pip install openenv-core"
164
+ stop_at "Step 3"
165
+ fi
166
+
167
+ VALIDATE_OK=false
168
+ VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true
169
+
170
+ if [ "$VALIDATE_OK" = true ]; then
171
+ pass "openenv validate passed"
172
+ [ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
173
+ else
174
+ fail "openenv validate failed"
175
+ printf "%s\n" "$VALIDATE_OUTPUT"
176
+ stop_at "Step 3"
177
+ fi
178
+
179
+ printf "\n"
180
+ printf "${BOLD}========================================${NC}\n"
181
+ printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
182
+ printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
183
+ printf "${BOLD}========================================${NC}\n"
184
+ printf "\n"
185
+
186
+ exit 0
server/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Package marker
server/app.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI server exposing the Adaptive AI Firewall environment.
2
+
3
+ Endpoints:
4
+ POST /reset — Start a new episode
5
+ POST /step — Multi-session step (batch actions)
6
+ POST /step_single — Single-session step (Gymnasium-compatible)
7
+ GET /state — Current environment state
8
+ GET /tools — List available tool names
9
+ POST /tool/{name} — Call a specific tool
10
+ GET /health — Health check
11
+ GET /stats — Current episode statistics
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import os
16
+ from typing import Any
17
+
18
+ from fastapi import FastAPI, HTTPException
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+ from fastapi.responses import HTMLResponse
21
+ from dotenv import load_dotenv
22
+
23
+ from server.firewall_environment import FirewallEnvironment, ACTIONS
24
+ from models import (
25
+ HealthResponse,
26
+ NetworkStatsResponse,
27
+ ResetRequest,
28
+ StateResponse,
29
+ StepRequest,
30
+ StepResponse,
31
+ StepSingleRequest,
32
+ StepSingleResponse,
33
+ ToolRequest,
34
+ ToolsListResponse,
35
+ )
36
+
37
+ load_dotenv()
38
+
39
+
40
+ def _clean_env_value(value: str) -> str:
41
+ return value.strip().strip("`").strip().strip("'").strip('"').strip()
42
+
43
+
44
+ def _resolve_api_key(value: str | None) -> str:
45
+ return _clean_env_value(value or os.getenv("HF_TOKEN") or "")
46
+
47
+
48
+ def _resolve_model(value: str | None) -> str:
49
+ return _clean_env_value(value or os.getenv("MODEL_NAME") or "")
50
+
51
+
52
+ def _resolve_base_url(value: str | None) -> str:
53
+ return _clean_env_value(
54
+ value
55
+ or os.getenv("API_BASE_URL")
56
+ or ""
57
+ )
58
+
59
+ PLAYGROUND_HTML = """<!doctype html>
60
+ <html lang="en">
61
+ <head>
62
+ <meta charset="utf-8"/>
63
+ <meta name="viewport" content="width=device-width,initial-scale=1"/>
64
+ <title>Adaptive Firewall Playground</title>
65
+ <style>
66
+ body{font-family:Arial,sans-serif;background:#0b1220;color:#e5e7eb;margin:0;padding:24px}
67
+ .card{max-width:980px;margin:0 auto;background:#111827;border:1px solid #1f2937;border-radius:12px;padding:18px}
68
+ h1{margin-top:0;font-size:22px}
69
+ label{display:block;font-size:12px;margin:10px 0 4px}
70
+ input,textarea,button{width:100%;box-sizing:border-box;border-radius:8px;border:1px solid #374151;background:#0f172a;color:#e5e7eb;padding:10px}
71
+ textarea{min-height:120px;resize:vertical}
72
+ button{background:#2563eb;border:none;cursor:pointer;font-weight:600;margin-top:12px}
73
+ button:disabled{opacity:.6;cursor:not-allowed}
74
+ pre{white-space:pre-wrap;background:#0f172a;border:1px solid #374151;border-radius:8px;padding:12px;min-height:120px;overflow:auto}
75
+ .grid{display:grid;grid-template-columns:1fr 1fr;gap:10px}
76
+ .row{display:grid;grid-template-columns:1fr 1fr 1fr;gap:10px}
77
+ .muted{font-size:12px;color:#93c5fd}
78
+ .ok{color:#86efac}
79
+ .bad{color:#fca5a5}
80
+ .btn-step{background:#22c55e}
81
+ .btn-reset{background:#64748b}
82
+ .btn-state{background:#64748b}
83
+ </style>
84
+ </head>
85
+ <body>
86
+ <div class="card">
87
+ <h1>Playground</h1>
88
+ <p class="muted">Click Reset to start a new episode.</p>
89
+
90
+ <label>Message / Action ID</label>
91
+ <input id="action_input" type="number" value="0" min="0" max="5" placeholder="Enter action index (0-5)..." />
92
+
93
+ <div class="row">
94
+ <button id="btn_step" class="btn-step">Step</button>
95
+ <button id="btn_reset" class="btn-reset">Reset</button>
96
+ <button id="btn_state" class="btn-state">Get state</button>
97
+ </div>
98
+
99
+ <div id="status" class="muted" style="margin-top:10px">Ready</div>
100
+
101
+ <label>Raw JSON response</label>
102
+ <pre id="output">{}</pre>
103
+ </div>
104
+
105
+ <script>
106
+ const output = document.getElementById("output");
107
+ const status = document.getElementById("status");
108
+ const actionInput = document.getElementById("action_input");
109
+
110
+ async function call(path, method='GET', body=null) {
111
+ status.textContent = "Calling " + path + "...";
112
+ try {
113
+ const options = {
114
+ method: method,
115
+ headers: {"Content-Type":"application/json"}
116
+ };
117
+ if (body) options.body = JSON.stringify(body);
118
+
119
+ const res = await fetch(path, options);
120
+ const data = await res.json();
121
+ output.textContent = JSON.stringify(data, null, 2);
122
+ status.textContent = "Success";
123
+ return data;
124
+ } catch (err) {
125
+ status.textContent = "Error: " + err;
126
+ output.textContent = String(err);
127
+ }
128
+ }
129
+
130
+ document.getElementById("btn_step").onclick = () => {
131
+ const action = parseInt(actionInput.value);
132
+ call("/step_single", "POST", {action: action});
133
+ };
134
+
135
+ document.getElementById("btn_reset").onclick = () => {
136
+ call("/reset", "POST", {task: "easy"});
137
+ };
138
+
139
+ document.getElementById("btn_state").onclick = () => {
140
+ call("/state", "GET");
141
+ };
142
+ </script>
143
+ </body>
144
+ </html>"""
145
+
146
+
147
+ env = FirewallEnvironment(seed=42)
148
+ app = FastAPI(
149
+ title="Adaptive AI Firewall OpenEnv",
150
+ version="0.2.0",
151
+ description="RL environment for adaptive firewall decision making on encrypted traffic.",
152
+ )
153
+
154
+ app.add_middleware(
155
+ CORSMiddleware,
156
+ allow_origins=["*"],
157
+ allow_credentials=True,
158
+ allow_methods=["*"],
159
+ allow_headers=["*"],
160
+ )
161
+
162
+
163
+ @app.get("/health", response_model=HealthResponse)
164
+ def health() -> HealthResponse:
165
+ return HealthResponse(status="ok", version="0.2.0")
166
+
167
+
168
+ @app.post("/reset", response_model=StateResponse)
169
+ def reset(request: ResetRequest) -> StateResponse:
170
+ try:
171
+ state = env.reset(task=request.task, seed=request.seed)
172
+ return StateResponse(**state)
173
+ except ValueError as e:
174
+ raise HTTPException(status_code=400, detail=str(e)) from e
175
+
176
+
177
+ @app.post("/step", response_model=StepResponse)
178
+ def step(request: StepRequest) -> StepResponse:
179
+ result = env.step(action_map=request.actions)
180
+ return StepResponse(**result)
181
+
182
+
183
+ @app.post("/step_single", response_model=StepSingleResponse)
184
+ def step_single(request: StepSingleRequest) -> StepSingleResponse:
185
+ try:
186
+ result = env.step_single(action=request.action)
187
+ return StepSingleResponse(**result)
188
+ except ValueError as e:
189
+ raise HTTPException(status_code=400, detail=str(e)) from e
190
+
191
+
192
+ @app.get("/state", response_model=StateResponse)
193
+ def state() -> StateResponse:
194
+ return StateResponse(**env.state())
195
+
196
+
197
+ @app.get("/stats", response_model=NetworkStatsResponse)
198
+ def stats() -> NetworkStatsResponse:
199
+ return NetworkStatsResponse(**env.get_network_stats())
200
+
201
+
202
+ @app.get("/tools", response_model=ToolsListResponse)
203
+ def list_tools() -> ToolsListResponse:
204
+ return ToolsListResponse(tools=env.list_tools())
205
+
206
+
207
+ @app.get("/web", response_class=HTMLResponse)
208
+ def web_interface() -> HTMLResponse:
209
+ return HTMLResponse(content=PLAYGROUND_HTML)
210
+
211
+
212
+ @app.get("/schema")
213
+ def schema() -> Any:
214
+ return {
215
+ "observation_space": {
216
+ "type": "Box",
217
+ "shape": [22],
218
+ "low": 0.0,
219
+ "high": 1.0,
220
+ },
221
+ "action_space": {
222
+ "type": "Discrete",
223
+ "n": 6,
224
+ "actions": ACTIONS,
225
+ },
226
+ }
227
+
228
+
229
+ @app.post("/tool/{name}")
230
+ def call_tool(name: str, request: ToolRequest) -> Any:
231
+ try:
232
+ if name == "evaluate_session":
233
+ return env.evaluate_session(request.kwargs["session_id"])
234
+ if name == "take_action":
235
+ reward, record = env.take_action(
236
+ session_id=request.kwargs["session_id"],
237
+ action=int(request.kwargs["action"]),
238
+ )
239
+ return {"reward": reward, "record": record}
240
+ if name == "get_network_stats":
241
+ return env.get_network_stats()
242
+ if name == "get_threat_intelligence":
243
+ return env.get_threat_intelligence()
244
+ raise HTTPException(status_code=404, detail=f"unknown tool: {name}")
245
+ except KeyError as exc:
246
+ raise HTTPException(status_code=400, detail=f"missing key: {exc}") from exc
247
+ except ValueError as exc:
248
+ raise HTTPException(status_code=400, detail=str(exc)) from exc
249
+
250
+
251
+ def main() -> None:
252
+ import uvicorn
253
+ uvicorn.run(app, host="0.0.0.0", port=7860)
254
+
255
+
256
+ if __name__ == "__main__":
257
+ main()
server/baseline/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Package marker
server/baseline/heuristic_agent.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Heuristic baseline agent for the Adaptive AI Firewall environment.
2
+
3
+ Uses the same 8-rule heuristic as inference.py for deterministic testing.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from typing import Dict, List
8
+
9
+
10
+ def heuristic_policy(env, session_ids: List[str]) -> Dict[str, int]:
11
+ """Rule-based policy using session features and threat intelligence."""
12
+ threat_intel = env.get_threat_intelligence()
13
+ known_bad_ports = set(threat_intel.get("known_bad_ports", []))
14
+ actions: Dict[str, int] = {}
15
+
16
+ for sid in session_ids:
17
+ try:
18
+ data = env.evaluate_session(sid)
19
+ except KeyError:
20
+ actions[sid] = 0
21
+ continue
22
+
23
+ features = data.get("features", {})
24
+
25
+ # If already revealed as malicious, block immediately
26
+ if data.get("revealed_malicious") is True:
27
+ actions[sid] = 1
28
+ continue
29
+
30
+ dst_port = int(features.get("dst_port", 0))
31
+ history = float(features.get("session_history_score", 1.0))
32
+ entropy = float(features.get("entropy_score", 0.0))
33
+ reuse = float(features.get("connection_reuse", 1.0))
34
+ self_signed = int(features.get("is_self_signed", 0))
35
+ ja3 = int(features.get("ja3_hash_cluster", 0))
36
+ geo = float(features.get("geo_distance", 0.0))
37
+ cert_valid = float(features.get("cert_validity_days", 999.0))
38
+ tls_ver = int(features.get("tls_version", 1))
39
+ dns_q = int(features.get("dns_query_count", 0))
40
+ dur = float(features.get("duration_ms", 500.0))
41
+ pkts = int(features.get("packet_count", 10))
42
+
43
+ if dst_port in known_bad_ports and history < 0.50:
44
+ actions[sid] = 1
45
+ elif self_signed == 1 and history < 0.45:
46
+ actions[sid] = 5
47
+ elif entropy > 0.55 and reuse < 0.25:
48
+ actions[sid] = 2
49
+ elif geo > 4000.0 and history < 0.40:
50
+ actions[sid] = 2
51
+ elif ja3 >= 180:
52
+ actions[sid] = 1
53
+ elif dur < 60.0 and pkts > 100:
54
+ actions[sid] = 4
55
+ elif cert_valid < 80.0 and tls_ver == 0:
56
+ actions[sid] = 2
57
+ elif reuse < 0.10 and dns_q >= 4:
58
+ actions[sid] = 2
59
+ else:
60
+ actions[sid] = 0
61
+
62
+ return actions
server/baseline/random_agent.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Random baseline agent for the Adaptive AI Firewall environment."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Callable, Dict, List
5
+
6
+ import numpy as np
7
+
8
+
9
+ def random_policy(seed: int = 42) -> Callable:
10
+ """Return a random policy function seeded for reproducibility."""
11
+ rng = np.random.default_rng(seed)
12
+
13
+ def _policy(env, session_ids: List[str]) -> Dict[str, int]:
14
+ return {sid: int(rng.integers(0, 6)) for sid in session_ids}
15
+
16
+ return _policy
17
+
18
+
19
+ def block_all_policy(env, session_ids: List[str]) -> Dict[str, int]:
20
+ """Block every session — useful as a degenerate baseline."""
21
+ return {sid: 1 for sid in session_ids}
server/firewall_environment.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Set, Tuple
5
+
6
+ import numpy as np
7
+
8
+ # Updated imports to reflect new structure
9
+ from server.utils.reward_engine import (
10
+ ACTIONS, BLOCKING_ACTIONS, RewardEngine,
11
+ )
12
+ from server.utils.threat_engine import ThreatEngine
13
+ from server.utils.data_loader import (
14
+ FEATURE_ORDER, TrafficGenerator,
15
+ )
16
+
17
+
18
+ TASK_CONFIGS = {
19
+ "easy": {
20
+ "max_steps": 200,
21
+ "benign_ratio": 0.80,
22
+ "threat_probability": 0.12,
23
+ "traffic_lambda": 5,
24
+ "budget": 100.0, # ~0.50 budget per step
25
+ },
26
+ "medium": {
27
+ "max_steps": 500,
28
+ "benign_ratio": 0.65,
29
+ "threat_probability": 0.22,
30
+ "traffic_lambda": 6,
31
+ "budget": 300.0, # ~0.60 budget per step
32
+ },
33
+ "hard": {
34
+ "max_steps": 1000,
35
+ "benign_ratio": 0.70,
36
+ "threat_probability": 0.30,
37
+ "traffic_lambda": 7,
38
+ "budget": 600.0, # ~0.60 budget per step
39
+ },
40
+ }
41
+
42
+ NUM_ACTIONS = len(ACTIONS)
43
+ OBS_DIM = len(FEATURE_ORDER)
44
+
45
+
46
+ @dataclass
47
+ class EpisodeMetrics:
48
+ """Tracks all metrics needed for grading."""
49
+ detections: int = 0
50
+ malicious_seen: int = 0
51
+ false_positives: int = 0
52
+ benign_seen: int = 0
53
+ early_detection_sum: float = 0.0
54
+ cascade_failures: int = 0
55
+ total_cost: float = 0.0
56
+ sessions_expired_malicious: int = 0
57
+ sessions_expired_benign: int = 0
58
+ correct_allows: int = 0
59
+ inspections: int = 0
60
+
61
+
62
+ class FirewallEnvironment:
63
+ """Adaptive AI Firewall RL environment.
64
+
65
+ OpenEnv-compatible: reset(), step(), state()
66
+
67
+ Key design (from RL perspective):
68
+ - Observation: 22-dim normalized [0,1] vector per session
69
+ - Action: Discrete(6) — ALLOW, BLOCK, INSPECT, SANDBOX, RATE_LIMIT, QUARANTINE
70
+ - Reward: multi-objective (security + availability + efficiency + timeliness)
71
+ - Done: when max_steps reached or budget depleted
72
+ - INSPECT keeps session alive for a second action (two-phase decision)
73
+ """
74
+
75
+ def __init__(self, seed: int = 0, budget: Optional[float] = None) -> None:
76
+ self.base_seed = seed
77
+ self.base_budget_override = budget
78
+ self.generator = TrafficGenerator(seed=seed)
79
+ self.threat_engine = ThreatEngine(seed=seed + 1)
80
+ self.reward_engine = RewardEngine()
81
+ self.rng = np.random.default_rng(seed + 2)
82
+
83
+ self.episode_id = 0
84
+ self.step_count = 0
85
+ self.current_tick = 0
86
+ self.task = "easy"
87
+ self.max_steps = TASK_CONFIGS[self.task]["max_steps"]
88
+ default_budget = TASK_CONFIGS[self.task]["budget"]
89
+ if self.base_budget_override is not None:
90
+ default_budget = max(default_budget, float(self.base_budget_override))
91
+ self.budget_remaining = default_budget
92
+ self.initial_budget = self.budget_remaining
93
+ self.total_reward = 0.0
94
+
95
+ self.pending_sessions: Dict[str, Dict] = {}
96
+ self.inspected_sessions: Dict[str, Dict] = {} # sessions awaiting 2nd action
97
+ self.action_log: List[Dict] = []
98
+ self._blocked_attacker_ids: Set[str] = set()
99
+ self.metrics = EpisodeMetrics()
100
+
101
+ # For single-session mode
102
+ self._session_queue: List[str] = []
103
+
104
+ # ══════════════════════════════════════════════════════════════════
105
+ # OpenEnv API
106
+ # ══════════════════════════════════════════════════════════════════
107
+
108
+ def reset(self, task: str = "easy", seed: Optional[int] = None) -> Dict:
109
+ """Reset environment for a new episode."""
110
+ if task not in TASK_CONFIGS:
111
+ raise ValueError(f"unknown task: {task}")
112
+
113
+ used_seed = self.base_seed if seed is None else seed
114
+ self.generator = TrafficGenerator(seed=used_seed)
115
+ self.threat_engine = ThreatEngine(seed=used_seed + 1)
116
+ self.rng = np.random.default_rng(used_seed + 2)
117
+
118
+ self.episode_id += 1
119
+ self.step_count = 0
120
+ self.current_tick = 0
121
+ self.task = task
122
+ config = TASK_CONFIGS[task]
123
+ self.max_steps = config["max_steps"]
124
+ task_budget = float(config["budget"])
125
+ if self.base_budget_override is not None:
126
+ task_budget = max(task_budget, float(self.base_budget_override))
127
+ self.initial_budget = task_budget
128
+ self.budget_remaining = self.initial_budget
129
+ self.total_reward = 0.0
130
+
131
+ self.pending_sessions = {}
132
+ self.inspected_sessions = {}
133
+ self.action_log = []
134
+ self._blocked_attacker_ids = set()
135
+ self.metrics = EpisodeMetrics()
136
+ self._session_queue = []
137
+
138
+ # Spawn initial sessions
139
+ self._spawn_sessions()
140
+ self._rebuild_queue()
141
+
142
+ return self.state()
143
+
144
+ def step(self, action_map: Optional[Dict[str, int]] = None) -> Dict:
145
+ """Multi-session step: agent provides actions for multiple sessions at once."""
146
+ action_map = action_map or {}
147
+ step_reward = 0.0
148
+
149
+ for session_id, action in action_map.items():
150
+ # Check both pending and inspected pools
151
+ if session_id in self.pending_sessions or session_id in self.inspected_sessions:
152
+ reward, _ = self._apply_action(session_id, action)
153
+ step_reward += reward
154
+
155
+ expired_penalty = self._expire_sessions()
156
+ step_reward += expired_penalty
157
+ self.total_reward += step_reward
158
+ self.step_count += 1
159
+ self.current_tick += 1
160
+
161
+ done = self.step_count >= self.max_steps or self.budget_remaining <= 0.0
162
+
163
+ if not done:
164
+ self._spawn_sessions()
165
+ self._rebuild_queue()
166
+
167
+ # Calculate score using the deterministic grader logic
168
+ final_stats = self.get_network_stats()
169
+ from server.graders import grade_stats
170
+ grade = grade_stats(self.task, final_stats)
171
+ return {
172
+ "reward": step_reward,
173
+ "done": done,
174
+ "state": self.state(),
175
+ "info": {
176
+ "expired_penalty": expired_penalty,
177
+ "attacker_outcomes": self.threat_engine.attacker_outcomes(),
178
+ "score": grade["score"],
179
+ "passed": grade["passed"]
180
+ },
181
+ }
182
+
183
+ def step_single(self, action: int) -> Dict:
184
+ """Single-session step: present one session, agent picks one action.
185
+
186
+ Compatible with Gymnasium Discrete(6).
187
+ Returns observation of the NEXT session, or zeros if episode done.
188
+ """
189
+ if action not in ACTIONS:
190
+ raise ValueError(f"invalid action: {action}")
191
+
192
+ step_reward = 0.0
193
+ info: Dict[str, Any] = {}
194
+
195
+ # Act on the current session
196
+ if self._session_queue:
197
+ session_id = self._session_queue.pop(0)
198
+ if session_id in self.pending_sessions or session_id in self.inspected_sessions:
199
+ reward, record = self._apply_action(session_id, action)
200
+ step_reward += reward
201
+ info["action_record"] = record
202
+
203
+ self.total_reward = round(self.total_reward + step_reward, 4)
204
+ self.step_count += 1
205
+
206
+ # If queue is empty, advance tick
207
+ if not self._session_queue:
208
+ self.current_tick += 1
209
+ expired_penalty = self._expire_sessions()
210
+ # step_reward for the final session in tick includes the expiration penalty
211
+ step_reward += expired_penalty
212
+ self.total_reward = round(self.total_reward + expired_penalty, 4)
213
+ done = self.step_count >= self.max_steps or self.budget_remaining <= 0.0
214
+ if not done:
215
+ self._spawn_sessions()
216
+ self._rebuild_queue()
217
+ else:
218
+ done = self.step_count >= self.max_steps or self.budget_remaining <= 0.0
219
+
220
+ # Build next observation
221
+ next_obs = self._current_observation()
222
+
223
+ return {
224
+ "observation": next_obs,
225
+ "reward": step_reward,
226
+ "done": done,
227
+ "state": {
228
+ **self.state(),
229
+ "focus_observation": next_obs,
230
+ "focus_session_id": self._session_queue[0] if self._session_queue else None,
231
+ },
232
+ "info": info,
233
+ }
234
+
235
+ def state(self) -> Dict:
236
+ """Return current environment state (OpenEnv API)."""
237
+ all_sessions = {**self.pending_sessions, **self.inspected_sessions}
238
+ top_ids = list(all_sessions.keys())[:10]
239
+ focus_session_id = self._session_queue[0] if self._session_queue else None
240
+ return {
241
+ "episode_id": self.episode_id,
242
+ "task": self.task,
243
+ "step_count": self.step_count,
244
+ "current_tick": self.current_tick,
245
+ "observation_dim": OBS_DIM,
246
+ "num_actions": NUM_ACTIONS,
247
+ "budget_remaining": round(self.budget_remaining, 4),
248
+ "total_reward": round(self.total_reward, 4),
249
+ "pending_session_count": len(self.pending_sessions),
250
+ "inspected_session_count": len(self.inspected_sessions),
251
+ "pending_session_ids": top_ids,
252
+ "inspected_session_ids": list(self.inspected_sessions.keys())[:10],
253
+ "queue_length": len(self._session_queue),
254
+ "focus_session_id": focus_session_id,
255
+ "focus_observation": self._current_observation(),
256
+ }
257
+
258
+ # ══════════════════════════════════════════════════════════════════
259
+ # Tool API (for MCP/HTTP interface)
260
+ # ════════════════════════════════════════════════════���═════════════
261
+
262
+ def evaluate_session(self, session_id: str) -> Dict:
263
+ """Get observation vector and metadata for a session."""
264
+ session = self.pending_sessions.get(session_id) or self.inspected_sessions.get(session_id)
265
+ if session is None:
266
+ raise KeyError(f"session not found: {session_id}")
267
+
268
+ return {
269
+ "session_id": session_id,
270
+ "features": dict(session["features"]),
271
+ "observation": self.generator.to_observation_vector(session),
272
+ "is_inspected": session_id in self.inspected_sessions,
273
+ "revealed_malicious": (
274
+ session["metadata"]["malicious"]
275
+ if session["metadata"]["revealed"] else None
276
+ ),
277
+ "expires_tick": session["expires_tick"],
278
+ }
279
+
280
+ def take_action(self, session_id: str, action: int) -> Tuple[float, Dict]:
281
+ """Apply an action to a specific session."""
282
+ return self._apply_action(session_id, action)
283
+
284
+ def get_network_stats(self) -> Dict:
285
+ """Aggregate episode statistics for grading."""
286
+ m = self.metrics
287
+ total_malicious = m.malicious_seen + m.sessions_expired_malicious
288
+ total_benign = m.benign_seen + m.sessions_expired_benign
289
+
290
+ detection_rate = m.detections / max(total_malicious, 1)
291
+ false_positive_rate = m.false_positives / max(total_benign, 1)
292
+ efficiency = 1.0 - min(1.0, m.total_cost / max(self.initial_budget, 1e-6))
293
+ early_detection_bonus = m.early_detection_sum / max(m.detections, 1)
294
+ cascade_prevention = 1.0 - (m.cascade_failures / max(total_malicious, 1))
295
+
296
+ return {
297
+ "episode_id": self.episode_id,
298
+ "task": self.task,
299
+ "tick": self.current_tick,
300
+ "step_count": self.step_count,
301
+ "total_reward": round(self.total_reward, 4),
302
+ "budget_remaining": round(self.budget_remaining, 4),
303
+ "budget_used_pct": round(1.0 - self.budget_remaining / max(self.initial_budget, 1e-6), 4),
304
+ "total_malicious": total_malicious,
305
+ "total_benign": total_benign,
306
+ "detection_rate": round(detection_rate, 6),
307
+ "false_positive_rate": round(false_positive_rate, 6),
308
+ "efficiency": round(efficiency, 6),
309
+ "early_detection_bonus": round(early_detection_bonus, 6),
310
+ "cascade_prevention": round(cascade_prevention, 6),
311
+ "correct_allows": m.correct_allows,
312
+ "inspections": m.inspections,
313
+ "expired_malicious": m.sessions_expired_malicious,
314
+ "expired_benign": m.sessions_expired_benign,
315
+ }
316
+
317
+ def get_threat_intelligence(self) -> Dict:
318
+ return self.threat_engine.intelligence_feed()
319
+
320
+ def list_tools(self) -> List[str]:
321
+ return [
322
+ "evaluate_session", "take_action",
323
+ "get_network_stats", "get_threat_intelligence",
324
+ ]
325
+
326
+ # ══════════════════════════════════════════════════════════════════
327
+ # Internal mechanics
328
+ # ══════════════════════════════════════════════════════════════════
329
+
330
+ def _apply_action(self, session_id: str, action: int) -> Tuple[float, Dict]:
331
+ """Core action application logic."""
332
+ if action not in ACTIONS:
333
+ raise ValueError(f"invalid action: {action}")
334
+
335
+ # Find the session in either pool
336
+ source_pool = "none"
337
+ if session_id in self.inspected_sessions:
338
+ session = self.inspected_sessions.pop(session_id)
339
+ source_pool = "inspected"
340
+ elif session_id in self.pending_sessions:
341
+ session = self.pending_sessions.pop(session_id)
342
+ source_pool = "pending"
343
+ else:
344
+ raise KeyError(f"session not found: {session_id}")
345
+
346
+ metadata = session["metadata"]
347
+ malicious = bool(metadata["malicious"])
348
+ blocked = action in BLOCKING_ACTIONS
349
+ inspected = action == 2 # INSPECT
350
+
351
+ # ── INSPECT keeps the session alive for a second decision ──
352
+ if inspected and session_id not in self.inspected_sessions:
353
+ metadata["revealed"] = True
354
+ self.inspected_sessions[session_id] = session
355
+ self.pending_sessions[session_id] = session
356
+ self.metrics.inspections += 1
357
+ # Compute reward for the inspection itself
358
+ reward, components = self.reward_engine.reward(
359
+ action=action,
360
+ is_malicious=malicious,
361
+ budget_remaining=self.budget_remaining,
362
+ attack_phase=metadata.get("attack_phase", 0),
363
+ inspect_correct=malicious,
364
+ )
365
+ self.budget_remaining = max(0.0, self.budget_remaining - components["cost"])
366
+ self.metrics.total_cost += components["cost"]
367
+ record = self._make_record(session_id, action, malicious, reward, components)
368
+ return reward, record
369
+
370
+ # ── Terminal action (ALLOW, BLOCK, SANDBOX, RATE_LIMIT, QUARANTINE) ──
371
+ inspect_correct = malicious and metadata.get("revealed", False)
372
+ reward, components = self.reward_engine.reward(
373
+ action=action,
374
+ is_malicious=malicious,
375
+ budget_remaining=self.budget_remaining,
376
+ attack_phase=metadata.get("attack_phase", 0),
377
+ inspect_correct=inspect_correct,
378
+ )
379
+ self.budget_remaining = max(0.0, self.budget_remaining - components["cost"])
380
+ self.metrics.total_cost += components["cost"]
381
+ if source_pool == "inspected":
382
+ self.pending_sessions.pop(session_id, None)
383
+
384
+ # ── Update metrics ──
385
+ if malicious:
386
+ self.metrics.malicious_seen += 1
387
+ if blocked:
388
+ self.metrics.detections += 1
389
+ phase = metadata.get("attack_phase", 0)
390
+ self.metrics.early_detection_sum += float(np.exp(-phase))
391
+ attacker_id = metadata.get("attacker_id")
392
+ if attacker_id:
393
+ self._blocked_attacker_ids.add(attacker_id)
394
+ else:
395
+ if metadata.get("attack_phase", 0) >= 2:
396
+ self.metrics.cascade_failures += 1
397
+ else:
398
+ self.metrics.benign_seen += 1
399
+ if blocked:
400
+ self.metrics.false_positives += 1
401
+ elif action == 0:
402
+ self.metrics.correct_allows += 1
403
+
404
+ record = self._make_record(session_id, action, malicious, reward, components)
405
+ self.action_log.append(record)
406
+ return reward, record
407
+
408
+ def _make_record(self, session_id: str, action: int, malicious: bool,
409
+ reward: float, components: Dict) -> Dict:
410
+ return {
411
+ "tick": self.current_tick,
412
+ "session_id": session_id,
413
+ "action": action,
414
+ "action_name": ACTIONS[action],
415
+ "malicious": malicious,
416
+ "reward": round(reward, 6),
417
+ "components": {k: round(v, 6) for k, v in components.items()},
418
+ }
419
+
420
+ def _spawn_sessions(self) -> None:
421
+ """Generate new benign and malicious sessions for current tick."""
422
+ config = TASK_CONFIGS[self.task]
423
+ benign_count = int(max(1, self.rng.poisson(
424
+ config["traffic_lambda"] * config["benign_ratio"],
425
+ )))
426
+ benign = self.generator.generate_benign_sessions(
427
+ tick=self.current_tick, count=benign_count,
428
+ )
429
+
430
+ self.threat_engine.maybe_spawn_attacker(config["threat_probability"])
431
+ malicious = self.threat_engine.generate_attack_sessions(
432
+ tick=self.current_tick,
433
+ generator=self.generator,
434
+ blocked_attackers=self._blocked_attacker_ids,
435
+ )
436
+ self._blocked_attacker_ids = set()
437
+
438
+ for session in benign + malicious:
439
+ self.pending_sessions[session["session_id"]] = session
440
+
441
+ def _expire_sessions(self) -> float:
442
+ """Remove expired sessions and apply penalties. Count in metrics."""
443
+ expired_ids = set()
444
+ for sid, session in self.pending_sessions.items():
445
+ if session["expires_tick"] <= self.current_tick:
446
+ expired_ids.add(sid)
447
+ for sid, session in self.inspected_sessions.items():
448
+ if session["expires_tick"] <= self.current_tick:
449
+ expired_ids.add(sid)
450
+ penalty = 0.0
451
+ for session_id in expired_ids:
452
+ session = self.inspected_sessions.pop(session_id, None)
453
+ if session is None:
454
+ session = self.pending_sessions.get(session_id)
455
+ self.pending_sessions.pop(session_id, None)
456
+ if session is None:
457
+ continue
458
+ if session["metadata"]["malicious"]:
459
+ penalty -= 1.5
460
+ self.metrics.sessions_expired_malicious += 1
461
+ if session["metadata"].get("attack_phase", 0) >= 2:
462
+ self.metrics.cascade_failures += 1
463
+ else:
464
+ self.metrics.sessions_expired_benign += 1
465
+
466
+ return penalty
467
+
468
+ def _rebuild_queue(self) -> None:
469
+ """Rebuild the single-session queue from pending + inspected."""
470
+ # Inspected sessions get priority (they need a follow-up action)
471
+ ordered = list(self.inspected_sessions.keys()) + list(self.pending_sessions.keys())
472
+ seen: Set[str] = set()
473
+ self._session_queue = []
474
+ for sid in ordered:
475
+ if sid in seen:
476
+ continue
477
+ seen.add(sid)
478
+ self._session_queue.append(sid)
479
+
480
+ def _current_observation(self) -> List[float]:
481
+ """Get normalized observation for the next session in queue."""
482
+ if self._session_queue:
483
+ sid = self._session_queue[0]
484
+ session = (
485
+ self.inspected_sessions.get(sid)
486
+ or self.pending_sessions.get(sid)
487
+ )
488
+ if session:
489
+ return self.generator.to_observation_vector(session)
490
+ return [0.0] * OBS_DIM
server/graders.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Deterministic grading system for the three firewall tasks.
2
+
3
+ Each task has:
4
+ - A fixed seed for reproducible traffic
5
+ - Weighted scoring across detection, false positives, efficiency, etc.
6
+ - A score in [0.0, 1.0] and a pass threshold
7
+
8
+ Graders are deterministic: same seed + same policy = same score.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ from dataclasses import dataclass
13
+ from typing import Callable, Dict, List
14
+
15
+ # Updated import path
16
+ from server.firewall_environment import FirewallEnvironment
17
+
18
+
19
+ @dataclass(frozen=True)
20
+ class TaskSpec:
21
+ name: str
22
+ task_key: str
23
+ threshold: float
24
+ weights: Dict[str, float]
25
+ seed: int
26
+
27
+
28
+ TASK_SPECS = {
29
+ "easy": TaskSpec(
30
+ name="Perimeter Defense",
31
+ task_key="easy",
32
+ threshold=0.70,
33
+ seed=101,
34
+ weights={
35
+ "detection_rate": 0.35,
36
+ "fp_complement": 0.35,
37
+ "efficiency": 0.30,
38
+ },
39
+ ),
40
+ "medium": TaskSpec(
41
+ name="Mixed Threat Landscape",
42
+ task_key="medium",
43
+ threshold=0.50,
44
+ seed=202,
45
+ weights={
46
+ "detection_rate": 0.25,
47
+ "fp_complement": 0.30,
48
+ "efficiency": 0.15,
49
+ "early_detection_bonus": 0.15,
50
+ "cascade_prevention": 0.15,
51
+ },
52
+ ),
53
+ "hard": TaskSpec(
54
+ name="Advanced Persistent Threat",
55
+ task_key="hard",
56
+ threshold=0.45,
57
+ seed=303,
58
+ weights={
59
+ "detection_rate": 0.20,
60
+ "fp_complement": 0.25,
61
+ "efficiency": 0.15,
62
+ "early_detection_bonus": 0.20,
63
+ "cascade_prevention": 0.20,
64
+ },
65
+ ),
66
+ }
67
+
68
+ PASS_CONSTRAINTS = {
69
+ "easy": {"min_detection_rate": 0.35, "min_fp_complement": 0.65},
70
+ "medium": {"min_detection_rate": 0.35, "min_fp_complement": 0.60},
71
+ "hard": {"min_detection_rate": 0.35, "min_fp_complement": 0.55},
72
+ }
73
+
74
+
75
+ def grade_stats(task: str, stats: Dict) -> Dict:
76
+ """Compute a grade from episode stats."""
77
+ spec = TASK_SPECS[task]
78
+ values = {
79
+ "detection_rate": stats.get("detection_rate", 0.0),
80
+ "fp_complement": 1.0 - stats.get("false_positive_rate", 1.0),
81
+ "efficiency": stats.get("efficiency", 0.0),
82
+ "early_detection_bonus": stats.get("early_detection_bonus", 0.0),
83
+ "cascade_prevention": stats.get("cascade_prevention", 0.0),
84
+ }
85
+ score = sum(values.get(k, 0.0) * w for k, w in spec.weights.items())
86
+ score = max(0.0, min(1.0, score))
87
+ constraints = PASS_CONSTRAINTS[task]
88
+ meets_constraints = (
89
+ values["detection_rate"] >= constraints["min_detection_rate"]
90
+ and values["fp_complement"] >= constraints["min_fp_complement"]
91
+ )
92
+ passed = (score >= spec.threshold) and meets_constraints
93
+
94
+ return {
95
+ "task": task,
96
+ "task_name": spec.name,
97
+ "threshold": spec.threshold,
98
+ "score": round(score, 6),
99
+ "passed": passed,
100
+ "pass_constraints": constraints,
101
+ "meets_constraints": meets_constraints,
102
+ "breakdown": {k: round(v, 6) for k, v in values.items()},
103
+ }
104
+
105
+
106
+ def run_deterministic_grade(
107
+ env: FirewallEnvironment,
108
+ task: str,
109
+ policy: Callable[[FirewallEnvironment, List[str]], Dict[str, int]],
110
+ ) -> Dict:
111
+ """Run a full episode with a policy and compute the grade."""
112
+ spec = TASK_SPECS[task]
113
+ env.reset(task=task, seed=spec.seed)
114
+ done = False
115
+ while not done:
116
+ session_ids = (
117
+ list(env.inspected_sessions.keys())
118
+ + list(env.pending_sessions.keys())
119
+ )
120
+ actions = policy(env, session_ids)
121
+ response = env.step(actions)
122
+ done = bool(response["done"])
123
+ stats = env.get_network_stats()
124
+ return grade_stats(task, stats)
server/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Package marker
server/utils/data_loader.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Network traffic session generator with realistic correlated features.
2
+
3
+ Each session is a 22-dimensional feature vector representing metadata and
4
+ behavioral signals from encrypted traffic (no payload inspection).
5
+
6
+ Feature groups:
7
+ - Volume & timing: bytes, duration, packet stats, inter-arrival metrics
8
+ - Network metadata: ports, protocol, DNS, connection reuse
9
+ - TLS / certificate: TLS version, JA3 cluster, cert chain, self-signed
10
+ - Behavioral context: geo distance, time of day, reputation, entropy
11
+
12
+ Benign traffic is drawn from 5 profile archetypes. Malicious traffic
13
+ profiles vary by attack scenario AND kill-chain phase, creating real
14
+ distributional differences an RL agent can learn to exploit.
15
+ """
16
+ from __future__ import annotations
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Dict, List
20
+ import math
21
+
22
+ import numpy as np
23
+
24
+
25
+ FEATURE_ORDER = [
26
+ "bytes_sent",
27
+ "bytes_received",
28
+ "duration_ms",
29
+ "packet_count",
30
+ "avg_packet_size",
31
+ "packet_size_variance",
32
+ "inter_arrival_mean",
33
+ "inter_arrival_jitter",
34
+ "src_port",
35
+ "dst_port",
36
+ "protocol",
37
+ "tls_version",
38
+ "ja3_hash_cluster",
39
+ "cert_chain_length",
40
+ "cert_validity_days",
41
+ "is_self_signed",
42
+ "dns_query_count",
43
+ "connection_reuse",
44
+ "geo_distance",
45
+ "time_of_day",
46
+ "session_history_score",
47
+ "entropy_score",
48
+ ]
49
+
50
+ # Min/max bounds for normalization (empirically calibrated)
51
+ FEATURE_BOUNDS: Dict[str, tuple] = {
52
+ "bytes_sent": (4.0, 14.0),
53
+ "bytes_received": (3.0, 13.0),
54
+ "duration_ms": (20.0, 25000.0),
55
+ "packet_count": (2.0, 1200.0),
56
+ "avg_packet_size": (40.0, 1400.0),
57
+ "packet_size_variance": (5.0, 500.0),
58
+ "inter_arrival_mean": (0.5, 600.0),
59
+ "inter_arrival_jitter": (0.0, 300.0),
60
+ "src_port": (1024.0, 65535.0),
61
+ "dst_port": (1.0, 65535.0),
62
+ "protocol": (0.0, 2.0),
63
+ "tls_version": (0.0, 2.0),
64
+ "ja3_hash_cluster": (0.0, 255.0),
65
+ "cert_chain_length": (0.0, 6.0),
66
+ "cert_validity_days": (1.0, 1200.0),
67
+ "is_self_signed": (0.0, 1.0),
68
+ "dns_query_count": (0.0, 12.0),
69
+ "connection_reuse": (0.0, 1.0),
70
+ "geo_distance": (0.0, 12000.0),
71
+ "time_of_day": (0.0, 1.0),
72
+ "session_history_score": (0.0, 1.0),
73
+ "entropy_score": (0.0, 1.0),
74
+ }
75
+
76
+
77
+ @dataclass(frozen=True)
78
+ class TrafficProfile:
79
+ name: str
80
+ packet_mean: float
81
+ packet_std_frac: float # std = mean * frac
82
+ duration_mean: float
83
+ entropy_mean: float
84
+ entropy_std: float
85
+ tls_probability: float
86
+ self_signed_prob: float
87
+ common_ports: List[int]
88
+ connection_reuse_mean: float
89
+ geo_distance_mean: float
90
+ history_score_mean: float
91
+ cert_validity_mean: float
92
+ ja3_cluster_range: tuple = (0, 128)
93
+
94
+
95
+ # ── Benign traffic profiles ─────────────────────────────────────────
96
+ BENIGN_PROFILES = [
97
+ TrafficProfile(
98
+ name="WebBrowsing", packet_mean=50.0, packet_std_frac=0.35,
99
+ duration_mean=900.0, entropy_mean=0.32, entropy_std=0.06,
100
+ tls_probability=0.95, self_signed_prob=0.02,
101
+ common_ports=[80, 443], connection_reuse_mean=0.72,
102
+ geo_distance_mean=1400.0, history_score_mean=0.82,
103
+ cert_validity_mean=450.0, ja3_cluster_range=(0, 64),
104
+ ),
105
+ TrafficProfile(
106
+ name="Streaming", packet_mean=800.0, packet_std_frac=0.25,
107
+ duration_mean=18000.0, entropy_mean=0.22, entropy_std=0.04,
108
+ tls_probability=0.99, self_signed_prob=0.01,
109
+ common_ports=[443, 8080], connection_reuse_mean=0.88,
110
+ geo_distance_mean=2200.0, history_score_mean=0.90,
111
+ cert_validity_mean=500.0, ja3_cluster_range=(0, 32),
112
+ ),
113
+ TrafficProfile(
114
+ name="API", packet_mean=25.0, packet_std_frac=0.30,
115
+ duration_mean=350.0, entropy_mean=0.18, entropy_std=0.04,
116
+ tls_probability=0.98, self_signed_prob=0.01,
117
+ common_ports=[443, 8443], connection_reuse_mean=0.80,
118
+ geo_distance_mean=1000.0, history_score_mean=0.85,
119
+ cert_validity_mean=500.0, ja3_cluster_range=(0, 48),
120
+ ),
121
+ TrafficProfile(
122
+ name="IoT", packet_mean=10.0, packet_std_frac=0.40,
123
+ duration_mean=1500.0, entropy_mean=0.38, entropy_std=0.07,
124
+ tls_probability=0.30, self_signed_prob=0.08,
125
+ common_ports=[1883, 5683, 8883], connection_reuse_mean=0.55,
126
+ geo_distance_mean=800.0, history_score_mean=0.70,
127
+ cert_validity_mean=300.0, ja3_cluster_range=(80, 128),
128
+ ),
129
+ TrafficProfile(
130
+ name="Enterprise", packet_mean=120.0, packet_std_frac=0.35,
131
+ duration_mean=1200.0, entropy_mean=0.28, entropy_std=0.06,
132
+ tls_probability=0.85, self_signed_prob=0.04,
133
+ common_ports=[443, 445, 3389], connection_reuse_mean=0.65,
134
+ geo_distance_mean=500.0, history_score_mean=0.88,
135
+ cert_validity_mean=400.0, ja3_cluster_range=(0, 96),
136
+ ),
137
+ ]
138
+
139
+ # ── Malicious traffic profiles per (scenario, phase) ────────────────
140
+ # Each scenario has distinct fingerprints making them differentiable
141
+ MALICIOUS_PROFILES: Dict[str, Dict[int, TrafficProfile]] = {
142
+ "port_scan_exploit_c2": {
143
+ 0: TrafficProfile(
144
+ name="PortScan_Recon", packet_mean=6.0, packet_std_frac=0.5,
145
+ duration_mean=80.0, entropy_mean=0.12, entropy_std=0.04,
146
+ tls_probability=0.05, self_signed_prob=0.60,
147
+ common_ports=[21, 22, 23, 25, 445, 3389, 5900],
148
+ connection_reuse_mean=0.02, geo_distance_mean=5500.0,
149
+ history_score_mean=0.10, cert_validity_mean=60.0,
150
+ ja3_cluster_range=(200, 255),
151
+ ),
152
+ 1: TrafficProfile(
153
+ name="PortScan_Exploit", packet_mean=45.0, packet_std_frac=0.4,
154
+ duration_mean=300.0, entropy_mean=0.78, entropy_std=0.06,
155
+ tls_probability=0.40, self_signed_prob=0.45,
156
+ common_ports=[80, 443, 8080, 445],
157
+ connection_reuse_mean=0.08, geo_distance_mean=5200.0,
158
+ history_score_mean=0.12, cert_validity_mean=90.0,
159
+ ja3_cluster_range=(210, 255),
160
+ ),
161
+ 2: TrafficProfile(
162
+ name="PortScan_C2", packet_mean=4.0, packet_std_frac=0.6,
163
+ duration_mean=5000.0, entropy_mean=0.55, entropy_std=0.08,
164
+ tls_probability=0.92, self_signed_prob=0.35,
165
+ common_ports=[443, 53, 8443],
166
+ connection_reuse_mean=0.15, geo_distance_mean=6000.0,
167
+ history_score_mean=0.15, cert_validity_mean=45.0,
168
+ ja3_cluster_range=(220, 255),
169
+ ),
170
+ 3: TrafficProfile(
171
+ name="PortScan_Exfil", packet_mean=350.0, packet_std_frac=0.3,
172
+ duration_mean=12000.0, entropy_mean=0.88, entropy_std=0.04,
173
+ tls_probability=0.98, self_signed_prob=0.25,
174
+ common_ports=[443, 8443],
175
+ connection_reuse_mean=0.10, geo_distance_mean=6500.0,
176
+ history_score_mean=0.08, cert_validity_mean=30.0,
177
+ ja3_cluster_range=(230, 255),
178
+ ),
179
+ },
180
+ "credential_stuffing_lateral": {
181
+ 0: TrafficProfile(
182
+ name="CredStuff_Probe", packet_mean=15.0, packet_std_frac=0.4,
183
+ duration_mean=200.0, entropy_mean=0.42, entropy_std=0.06,
184
+ tls_probability=0.90, self_signed_prob=0.10,
185
+ common_ports=[443, 80, 8443],
186
+ connection_reuse_mean=0.05, geo_distance_mean=3500.0,
187
+ history_score_mean=0.25, cert_validity_mean=300.0,
188
+ ja3_cluster_range=(140, 200),
189
+ ),
190
+ 1: TrafficProfile(
191
+ name="CredStuff_Auth", packet_mean=20.0, packet_std_frac=0.35,
192
+ duration_mean=150.0, entropy_mean=0.50, entropy_std=0.07,
193
+ tls_probability=0.95, self_signed_prob=0.08,
194
+ common_ports=[443, 389, 636],
195
+ connection_reuse_mean=0.10, geo_distance_mean=3200.0,
196
+ history_score_mean=0.30, cert_validity_mean=350.0,
197
+ ja3_cluster_range=(150, 210),
198
+ ),
199
+ 2: TrafficProfile(
200
+ name="CredStuff_Lateral", packet_mean=30.0, packet_std_frac=0.35,
201
+ duration_mean=500.0, entropy_mean=0.35, entropy_std=0.06,
202
+ tls_probability=0.80, self_signed_prob=0.12,
203
+ common_ports=[445, 3389, 5985, 22],
204
+ connection_reuse_mean=0.20, geo_distance_mean=300.0,
205
+ history_score_mean=0.40, cert_validity_mean=350.0,
206
+ ja3_cluster_range=(160, 220),
207
+ ),
208
+ 3: TrafficProfile(
209
+ name="CredStuff_Exfil", packet_mean=200.0, packet_std_frac=0.3,
210
+ duration_mean=8000.0, entropy_mean=0.80, entropy_std=0.05,
211
+ tls_probability=0.98, self_signed_prob=0.15,
212
+ common_ports=[443, 8443],
213
+ connection_reuse_mean=0.12, geo_distance_mean=4000.0,
214
+ history_score_mean=0.18, cert_validity_mean=90.0,
215
+ ja3_cluster_range=(180, 240),
216
+ ),
217
+ },
218
+ "supply_chain_compromise": {
219
+ 0: TrafficProfile(
220
+ name="SupplyChain_Init", packet_mean=40.0, packet_std_frac=0.3,
221
+ duration_mean=600.0, entropy_mean=0.30, entropy_std=0.05,
222
+ tls_probability=0.98, self_signed_prob=0.03,
223
+ common_ports=[443, 8443],
224
+ connection_reuse_mean=0.60, geo_distance_mean=1800.0,
225
+ history_score_mean=0.70, cert_validity_mean=380.0,
226
+ ja3_cluster_range=(30, 80),
227
+ ),
228
+ 1: TrafficProfile(
229
+ name="SupplyChain_Inject", packet_mean=60.0, packet_std_frac=0.3,
230
+ duration_mean=800.0, entropy_mean=0.40, entropy_std=0.06,
231
+ tls_probability=0.98, self_signed_prob=0.04,
232
+ common_ports=[443, 8443],
233
+ connection_reuse_mean=0.55, geo_distance_mean=2000.0,
234
+ history_score_mean=0.65, cert_validity_mean=350.0,
235
+ ja3_cluster_range=(35, 90),
236
+ ),
237
+ 2: TrafficProfile(
238
+ name="SupplyChain_Beacon", packet_mean=8.0, packet_std_frac=0.5,
239
+ duration_mean=3000.0, entropy_mean=0.48, entropy_std=0.07,
240
+ tls_probability=0.99, self_signed_prob=0.05,
241
+ common_ports=[443],
242
+ connection_reuse_mean=0.50, geo_distance_mean=2500.0,
243
+ history_score_mean=0.55, cert_validity_mean=250.0,
244
+ ja3_cluster_range=(40, 100),
245
+ ),
246
+ 3: TrafficProfile(
247
+ name="SupplyChain_Exfil", packet_mean=100.0, packet_std_frac=0.3,
248
+ duration_mean=5000.0, entropy_mean=0.60, entropy_std=0.06,
249
+ tls_probability=0.99, self_signed_prob=0.06,
250
+ common_ports=[443, 8443],
251
+ connection_reuse_mean=0.42, geo_distance_mean=3000.0,
252
+ history_score_mean=0.45, cert_validity_mean=200.0,
253
+ ja3_cluster_range=(50, 110),
254
+ ),
255
+ },
256
+ "low_and_slow_apt": {
257
+ 0: TrafficProfile(
258
+ name="APT_Recon", packet_mean=12.0, packet_std_frac=0.4,
259
+ duration_mean=400.0, entropy_mean=0.28, entropy_std=0.05,
260
+ tls_probability=0.92, self_signed_prob=0.05,
261
+ common_ports=[443, 80],
262
+ connection_reuse_mean=0.50, geo_distance_mean=2200.0,
263
+ history_score_mean=0.55, cert_validity_mean=320.0,
264
+ ja3_cluster_range=(60, 130),
265
+ ),
266
+ 1: TrafficProfile(
267
+ name="APT_Establish", packet_mean=18.0, packet_std_frac=0.35,
268
+ duration_mean=700.0, entropy_mean=0.35, entropy_std=0.06,
269
+ tls_probability=0.95, self_signed_prob=0.07,
270
+ common_ports=[443, 53],
271
+ connection_reuse_mean=0.45, geo_distance_mean=2600.0,
272
+ history_score_mean=0.48, cert_validity_mean=280.0,
273
+ ja3_cluster_range=(70, 140),
274
+ ),
275
+ 2: TrafficProfile(
276
+ name="APT_Persist", packet_mean=5.0, packet_std_frac=0.6,
277
+ duration_mean=8000.0, entropy_mean=0.42, entropy_std=0.07,
278
+ tls_probability=0.97, self_signed_prob=0.10,
279
+ common_ports=[443],
280
+ connection_reuse_mean=0.38, geo_distance_mean=3200.0,
281
+ history_score_mean=0.38, cert_validity_mean=200.0,
282
+ ja3_cluster_range=(80, 150),
283
+ ),
284
+ 3: TrafficProfile(
285
+ name="APT_Exfil", packet_mean=60.0, packet_std_frac=0.4,
286
+ duration_mean=15000.0, entropy_mean=0.65, entropy_std=0.06,
287
+ tls_probability=0.99, self_signed_prob=0.12,
288
+ common_ports=[443, 8443],
289
+ connection_reuse_mean=0.25, geo_distance_mean=4000.0,
290
+ history_score_mean=0.28, cert_validity_mean=120.0,
291
+ ja3_cluster_range=(90, 160),
292
+ ),
293
+ },
294
+ "ddos_amplification": {
295
+ 0: TrafficProfile(
296
+ name="DDoS_Probe", packet_mean=20.0, packet_std_frac=0.5,
297
+ duration_mean=50.0, entropy_mean=0.15, entropy_std=0.04,
298
+ tls_probability=0.10, self_signed_prob=0.30,
299
+ common_ports=[53, 123, 161, 1900],
300
+ connection_reuse_mean=0.02, geo_distance_mean=6000.0,
301
+ history_score_mean=0.08, cert_validity_mean=60.0,
302
+ ja3_cluster_range=(230, 255),
303
+ ),
304
+ 1: TrafficProfile(
305
+ name="DDoS_Amplify", packet_mean=500.0, packet_std_frac=0.4,
306
+ duration_mean=30.0, entropy_mean=0.10, entropy_std=0.03,
307
+ tls_probability=0.05, self_signed_prob=0.40,
308
+ common_ports=[53, 123, 161, 1900, 11211],
309
+ connection_reuse_mean=0.01, geo_distance_mean=7000.0,
310
+ history_score_mean=0.05, cert_validity_mean=30.0,
311
+ ja3_cluster_range=(240, 255),
312
+ ),
313
+ 2: TrafficProfile(
314
+ name="DDoS_Sustained", packet_mean=900.0, packet_std_frac=0.3,
315
+ duration_mean=20.0, entropy_mean=0.08, entropy_std=0.02,
316
+ tls_probability=0.03, self_signed_prob=0.50,
317
+ common_ports=[53, 123, 80],
318
+ connection_reuse_mean=0.00, geo_distance_mean=8000.0,
319
+ history_score_mean=0.03, cert_validity_mean=20.0,
320
+ ja3_cluster_range=(245, 255),
321
+ ),
322
+ 3: TrafficProfile(
323
+ name="DDoS_Peak", packet_mean=1100.0, packet_std_frac=0.25,
324
+ duration_mean=15.0, entropy_mean=0.06, entropy_std=0.02,
325
+ tls_probability=0.02, self_signed_prob=0.55,
326
+ common_ports=[53, 123, 80],
327
+ connection_reuse_mean=0.00, geo_distance_mean=9000.0,
328
+ history_score_mean=0.02, cert_validity_mean=15.0,
329
+ ja3_cluster_range=(248, 255),
330
+ ),
331
+ },
332
+ }
333
+
334
+ # Fallback for unknown scenarios
335
+ _DEFAULT_MALICIOUS: Dict[int, TrafficProfile] = MALICIOUS_PROFILES["port_scan_exploit_c2"]
336
+
337
+ BENIGN_WEIGHTS = np.array([0.34, 0.16, 0.18, 0.12, 0.20])
338
+
339
+
340
+ class TrafficGenerator:
341
+ """Generates correlated network session feature vectors.
342
+
343
+ Each session is a dict with 'session_id', 'features' (dict),
344
+ and 'metadata' (malicious flag, attack info, profile name).
345
+ """
346
+
347
+ def __init__(self, seed: int = 0) -> None:
348
+ self.rng = np.random.default_rng(seed)
349
+ self.session_counter = 0
350
+
351
+ def generate_benign_sessions(self, tick: int, count: int) -> List[Dict]:
352
+ sessions: List[Dict] = []
353
+ for _ in range(max(0, count)):
354
+ idx = self.rng.choice(len(BENIGN_PROFILES), p=BENIGN_WEIGHTS)
355
+ profile = BENIGN_PROFILES[idx]
356
+ sessions.append(self._build_session(
357
+ profile, tick=tick, malicious=False,
358
+ attack_phase=0, scenario="benign", attacker_id=None,
359
+ ))
360
+ return sessions
361
+
362
+ def generate_malicious_sessions(
363
+ self, tick: int, count: int,
364
+ attack_phase: int, scenario: str,
365
+ attacker_id: str | None = None,
366
+ ) -> List[Dict]:
367
+ sessions: List[Dict] = []
368
+ profiles = MALICIOUS_PROFILES.get(scenario, _DEFAULT_MALICIOUS)
369
+ profile = profiles.get(attack_phase, profiles[max(profiles.keys())])
370
+ for _ in range(max(0, count)):
371
+ sessions.append(self._build_session(
372
+ profile, tick=tick, malicious=True,
373
+ attack_phase=attack_phase, scenario=scenario,
374
+ attacker_id=attacker_id,
375
+ ))
376
+ return sessions
377
+
378
+ def to_observation_vector(self, session: Dict) -> List[float]:
379
+ """Return normalized [0, 1] feature vector."""
380
+ raw = session["features"]
381
+ normalized = []
382
+ for name in FEATURE_ORDER:
383
+ val = float(raw[name])
384
+ lo, hi = FEATURE_BOUNDS[name]
385
+ normalized.append(max(0.0, min(1.0, (val - lo) / max(hi - lo, 1e-9))))
386
+ return normalized
387
+
388
+ def to_raw_vector(self, session: Dict) -> List[float]:
389
+ """Return un-normalized feature vector (for inspection)."""
390
+ return [float(session["features"][name]) for name in FEATURE_ORDER]
391
+
392
+ # ── Internal session builder ─────────────────────────────────────
393
+
394
+ def _build_session(
395
+ self, profile: TrafficProfile, tick: int,
396
+ malicious: bool, attack_phase: int, scenario: str,
397
+ attacker_id: str | None,
398
+ ) -> Dict:
399
+ self.session_counter += 1
400
+ rng = self.rng
401
+
402
+ # --- Volume & timing (correlated cluster) ---
403
+ packet_count = int(max(3, rng.normal(
404
+ profile.packet_mean, profile.packet_mean * profile.packet_std_frac,
405
+ )))
406
+ avg_packet_size = float(max(40.0, rng.normal(560.0, 160.0)))
407
+ # Bytes are correlated with packets and packet size
408
+ bytes_sent = float(max(200.0, packet_count * avg_packet_size * rng.uniform(0.40, 0.85)))
409
+ bytes_received = float(max(100.0, packet_count * avg_packet_size * rng.uniform(0.20, 0.60)))
410
+ duration_ms = float(max(10.0, rng.normal(
411
+ profile.duration_mean, profile.duration_mean * 0.30,
412
+ )))
413
+ # Inter-arrival derived from duration and packet count (correlated)
414
+ inter_arrival_mean = float(duration_ms / max(packet_count, 1))
415
+ inter_arrival_jitter = float(abs(rng.normal(
416
+ inter_arrival_mean * 0.30, inter_arrival_mean * 0.12,
417
+ )))
418
+ packet_size_variance = float(max(5.0, abs(rng.normal(
419
+ 180.0 if malicious else 130.0, 60.0,
420
+ ))))
421
+
422
+ # --- TLS / certificate (correlated cluster) ---
423
+ tls_enabled = rng.random() < profile.tls_probability
424
+ tls_version = int(rng.choice([1, 2], p=[0.20, 0.80])) if tls_enabled else 0
425
+ # Self-signed correlates with TLS state and profile
426
+ is_self_signed = bool(rng.random() < profile.self_signed_prob) if tls_enabled else False
427
+ cert_chain_length = int(max(0, rng.normal(3.0 if (tls_enabled and not is_self_signed) else 1.0, 0.8)))
428
+ cert_validity_days = float(max(1.0, rng.normal(
429
+ profile.cert_validity_mean, profile.cert_validity_mean * 0.30,
430
+ )))
431
+
432
+ # --- Network metadata ---
433
+ dst_port = int(rng.choice(profile.common_ports))
434
+ src_port = int(rng.integers(1024, 65535))
435
+ protocol = int(rng.choice([0, 1, 2], p=[0.50, 0.32, 0.18]))
436
+ dns_query_count = int(max(0, rng.poisson(3 if malicious else 1)))
437
+
438
+ # --- Behavioral context (correlated with profile) ---
439
+ connection_reuse = float(np.clip(rng.normal(
440
+ profile.connection_reuse_mean, 0.12,
441
+ ), 0.0, 1.0))
442
+ geo_distance = float(max(0.0, rng.normal(
443
+ profile.geo_distance_mean, profile.geo_distance_mean * 0.25,
444
+ )))
445
+ session_history_score = float(np.clip(rng.normal(
446
+ profile.history_score_mean, 0.10,
447
+ ), 0.0, 1.0))
448
+ entropy_score = float(np.clip(rng.normal(
449
+ profile.entropy_mean, profile.entropy_std,
450
+ ), 0.02, 0.99))
451
+ ja3_lo, ja3_hi = profile.ja3_cluster_range
452
+ ja3_hash_cluster = int(rng.integers(ja3_lo, max(ja3_lo + 1, ja3_hi)))
453
+ time_of_day = float((tick % 1440) / 1440.0)
454
+
455
+ features = {
456
+ "bytes_sent": math.log1p(bytes_sent),
457
+ "bytes_received": math.log1p(bytes_received),
458
+ "duration_ms": duration_ms,
459
+ "packet_count": packet_count,
460
+ "avg_packet_size": avg_packet_size,
461
+ "packet_size_variance": packet_size_variance,
462
+ "inter_arrival_mean": inter_arrival_mean,
463
+ "inter_arrival_jitter": inter_arrival_jitter,
464
+ "src_port": src_port,
465
+ "dst_port": dst_port,
466
+ "protocol": protocol,
467
+ "tls_version": tls_version,
468
+ "ja3_hash_cluster": ja3_hash_cluster,
469
+ "cert_chain_length": cert_chain_length,
470
+ "cert_validity_days": cert_validity_days,
471
+ "is_self_signed": int(is_self_signed),
472
+ "dns_query_count": dns_query_count,
473
+ "connection_reuse": connection_reuse,
474
+ "geo_distance": geo_distance,
475
+ "time_of_day": time_of_day,
476
+ "session_history_score": session_history_score,
477
+ "entropy_score": entropy_score,
478
+ }
479
+
480
+ # Session TTL: malicious sessions expire faster (pressure to act)
481
+ ttl = 2 if malicious else 3
482
+
483
+ return {
484
+ "session_id": f"s-{self.session_counter:07d}",
485
+ "features": features,
486
+ "metadata": {
487
+ "malicious": malicious,
488
+ "attack_phase": attack_phase,
489
+ "scenario": scenario,
490
+ "profile": profile.name,
491
+ "attacker_id": attacker_id,
492
+ "revealed": False,
493
+ },
494
+ "created_tick": tick,
495
+ "expires_tick": tick + ttl,
496
+ }
server/utils/reward_engine.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-objective reward engine for the Adaptive AI Firewall environment.
2
+
3
+ Computes R = α·security + β·availability + γ·efficiency + δ·timeliness
4
+ with careful balance to prevent degenerate policies (block-all / allow-all).
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from typing import Dict, Tuple
9
+ import math
10
+
11
+
12
+ ACTIONS = {
13
+ 0: "ALLOW",
14
+ 1: "BLOCK",
15
+ 2: "INSPECT",
16
+ 3: "SANDBOX",
17
+ 4: "RATE_LIMIT",
18
+ 5: "QUARANTINE",
19
+ }
20
+
21
+ # Costs tuned so total episode cost stays well within budget range
22
+ ACTION_COSTS = {
23
+ 0: {"latency": 0.0, "compute": 0.0},
24
+ 1: {"latency": 0.0, "compute": 0.005},
25
+ 2: {"latency": 0.08, "compute": 0.05},
26
+ 3: {"latency": 0.20, "compute": 0.12},
27
+ 4: {"latency": 0.02, "compute": 0.015},
28
+ 5: {"latency": 0.05, "compute": 0.025},
29
+ }
30
+
31
+ # Actions that are considered "blocking" (remove traffic from the network)
32
+ BLOCKING_ACTIONS = frozenset({1, 3, 5})
33
+ # Actions that are considered "inspection" (gather more info)
34
+ INSPECTION_ACTIONS = frozenset({2})
35
+
36
+
37
+ class RewardEngine:
38
+ """Weighted multi-objective reward with anti-degeneracy safeguards.
39
+
40
+ Key design choices (from RL perspective):
41
+ - Miss penalty (0.8) vs false-positive penalty (0.35) ratio is ~2.3:1
42
+ This creates genuine tension — block-all loses ~0.35 per benign session,
43
+ and with 60–80% benign traffic, the total FP penalty dominates.
44
+ - Correct ALLOW on benign traffic gets +0.08 (dense positive signal).
45
+ - INSPECT keeps the session alive (handled by environment) and gets a
46
+ small bonus when it correctly identifies malicious traffic.
47
+ - Timeliness bonus decays exponentially with kill-chain phase.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ alpha: float = 0.35,
53
+ beta: float = 0.30,
54
+ gamma: float = 0.20,
55
+ delta: float = 0.15,
56
+ ) -> None:
57
+ self.alpha = alpha
58
+ self.beta = beta
59
+ self.gamma = gamma
60
+ self.delta = delta
61
+
62
+ def action_cost(self, action: int) -> float:
63
+ costs = ACTION_COSTS.get(action, ACTION_COSTS[0])
64
+ return float(costs["latency"] + costs["compute"])
65
+
66
+ def reward(
67
+ self,
68
+ action: int,
69
+ is_malicious: bool,
70
+ budget_remaining: float,
71
+ attack_phase: int,
72
+ inspect_correct: bool = False,
73
+ ) -> Tuple[float, Dict[str, float]]:
74
+ blocked = action in BLOCKING_ACTIONS
75
+ inspected = action in INSPECTION_ACTIONS
76
+
77
+ # --- Security component ---
78
+ r_security = 0.0
79
+ if is_malicious and blocked:
80
+ r_security += 1.0
81
+ elif is_malicious and not blocked and not inspected:
82
+ r_security -= 2.0 # missed attack
83
+ elif is_malicious and inspected:
84
+ # Inspect alone doesn't stop the attack but is a valid step
85
+ r_security += 0.15 if inspect_correct else -0.5
86
+
87
+ # --- Availability component ---
88
+ r_availability = 0.0
89
+ if not is_malicious and blocked:
90
+ r_availability -= 1.2 # strong false-positive penalty
91
+ elif not is_malicious and inspected:
92
+ r_availability -= 0.15 # unnecessary inspection delay
93
+ elif not is_malicious and action == 0:
94
+ r_availability += 0.25 # correctly allowing benign traffic
95
+
96
+ # Rate-limiting benign traffic is a milder false positive
97
+ if not is_malicious and action == 4:
98
+ r_availability -= 0.4
99
+
100
+ # --- Efficiency component ---
101
+ cost = self.action_cost(action)
102
+ # Penalize cost relative to remaining budget (bigger penalty as budget shrinks)
103
+ r_efficiency = -cost / max(budget_remaining, 0.1)
104
+
105
+ # --- Timeliness component ---
106
+ # Exponential bonus for catching attacks early in kill chain
107
+ early_bonus = math.exp(-max(attack_phase, 0))
108
+ r_timeliness = early_bonus if (is_malicious and blocked) else 0.0
109
+
110
+ total = (
111
+ self.alpha * r_security
112
+ + self.beta * r_availability
113
+ + self.gamma * r_efficiency
114
+ + self.delta * r_timeliness
115
+ )
116
+ return total, {
117
+ "security": r_security,
118
+ "availability": r_availability,
119
+ "efficiency": r_efficiency,
120
+ "timeliness": r_timeliness,
121
+ "cost": cost,
122
+ }
server/utils/threat_engine.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-stage attack orchestrator following Cyber Kill Chain model.
2
+
3
+ Each attacker has a scenario (one of 5 patterns) and progresses through
4
+ phases 0→3. Adaptation is non-trivial:
5
+ - Detected attackers may switch to stealth mode (mimic benign profiles)
6
+ - Undetected attackers escalate normally
7
+ - Fully blocked attackers are terminated
8
+ - Attackers that reach exfiltration (phase 3) are marked as succeeded
9
+ """
10
+ from __future__ import annotations
11
+
12
+ from dataclasses import dataclass
13
+ from typing import Dict, List, Set
14
+
15
+ import numpy as np
16
+
17
+ # Updated import path
18
+ from server.utils.data_loader import TrafficGenerator
19
+
20
+
21
+ SCENARIOS = [
22
+ "port_scan_exploit_c2",
23
+ "credential_stuffing_lateral",
24
+ "supply_chain_compromise",
25
+ "low_and_slow_apt",
26
+ "ddos_amplification",
27
+ ]
28
+
29
+ # How many sessions each scenario generates per phase
30
+ SESSION_COUNTS: Dict[str, List[int]] = {
31
+ "port_scan_exploit_c2": [4, 2, 1, 2],
32
+ "credential_stuffing_lateral": [3, 3, 2, 2],
33
+ "supply_chain_compromise": [1, 1, 1, 2],
34
+ "low_and_slow_apt": [1, 1, 1, 1],
35
+ "ddos_amplification": [6, 10, 15, 20],
36
+ }
37
+
38
+ # Probability that an attacker escalates per tick (if not detected)
39
+ ESCALATION_PROB: Dict[str, float] = {
40
+ "port_scan_exploit_c2": 0.30,
41
+ "credential_stuffing_lateral": 0.25,
42
+ "supply_chain_compromise": 0.15,
43
+ "low_and_slow_apt": 0.10,
44
+ "ddos_amplification": 0.40,
45
+ }
46
+
47
+
48
+ @dataclass
49
+ class AttackerState:
50
+ attacker_id: str
51
+ scenario: str
52
+ phase: int = 0
53
+ times_detected: int = 0
54
+ stealth_mode: bool = False
55
+ alive: bool = True
56
+ succeeded: bool = False
57
+ ticks_alive: int = 0
58
+ sessions_blocked: int = 0
59
+ sessions_generated: int = 0
60
+
61
+
62
+ class ThreatEngine:
63
+ """Manages the lifecycle of active attackers and generates attack sessions."""
64
+
65
+ def __init__(self, seed: int = 0) -> None:
66
+ self.rng = np.random.default_rng(seed)
67
+ self._attacker_counter = 0
68
+ self._active_attackers: Dict[str, AttackerState] = {}
69
+ self._dead_attackers: List[AttackerState] = []
70
+ self._threat_intel: Dict = {
71
+ "known_bad_ports": [21, 22, 23, 25, 445, 3389, 5900],
72
+ "known_bad_ja3_ranges": [(200, 255), (230, 255)],
73
+ "active_campaigns": [],
74
+ "recent_detections": 0,
75
+ }
76
+
77
+ def reset(self) -> None:
78
+ self._attacker_counter = 0
79
+ self._active_attackers = {}
80
+ self._dead_attackers = []
81
+ self._threat_intel["active_campaigns"] = []
82
+ self._threat_intel["recent_detections"] = 0
83
+
84
+ def maybe_spawn_attacker(self, threat_probability: float) -> None:
85
+ """Probabilistically spawn a new attacker."""
86
+ if self.rng.random() > threat_probability:
87
+ return
88
+ self._attacker_counter += 1
89
+ scenario = SCENARIOS[int(self.rng.integers(0, len(SCENARIOS)))]
90
+ attacker_id = f"a-{self._attacker_counter:04d}"
91
+ state = AttackerState(attacker_id=attacker_id, scenario=scenario)
92
+ self._active_attackers[attacker_id] = state
93
+ # Update threat intel
94
+ campaigns = set(self._threat_intel["active_campaigns"])
95
+ campaigns.add(scenario)
96
+ self._threat_intel["active_campaigns"] = sorted(campaigns)
97
+
98
+ def generate_attack_sessions(
99
+ self, tick: int, generator: TrafficGenerator,
100
+ blocked_attackers: Set[str],
101
+ ) -> List[Dict]:
102
+ """Generate attack sessions for all active attackers, handling adaptation."""
103
+ sessions: List[Dict] = []
104
+
105
+ for attacker in list(self._active_attackers.values()):
106
+ if not attacker.alive:
107
+ continue
108
+
109
+ attacker.ticks_alive += 1
110
+
111
+ # --- Handle detection / blocking ---
112
+ if attacker.attacker_id in blocked_attackers:
113
+ attacker.times_detected += 1
114
+ attacker.sessions_blocked += 1
115
+ self._threat_intel["recent_detections"] += 1
116
+
117
+ if attacker.times_detected >= 3:
118
+ # Fully blocked — attacker gives up
119
+ attacker.alive = False
120
+ self._dead_attackers.append(attacker)
121
+ continue
122
+ elif attacker.times_detected >= 2:
123
+ # Switch to stealth mode — generate fewer, more benign-looking sessions
124
+ attacker.stealth_mode = True
125
+ else:
126
+ # First detection — try to advance past detected phase
127
+ attacker.phase = min(attacker.phase + 1, 3)
128
+
129
+ # --- Natural phase escalation ---
130
+ elif self.rng.random() < ESCALATION_PROB.get(attacker.scenario, 0.2):
131
+ attacker.phase = min(attacker.phase + 1, 3)
132
+
133
+ # --- Check for success (exfiltration complete) ---
134
+ if attacker.phase == 3 and attacker.ticks_alive > 8:
135
+ if self.rng.random() < 0.15:
136
+ attacker.succeeded = True
137
+ attacker.alive = False
138
+ self._dead_attackers.append(attacker)
139
+ continue
140
+
141
+ # --- Generate sessions based on current state ---
142
+ counts = SESSION_COUNTS.get(attacker.scenario, [2, 2, 2, 2])
143
+ count = counts[min(attacker.phase, 3)]
144
+
145
+ if attacker.stealth_mode:
146
+ # In stealth mode: reduce count, use profiles that look more benign
147
+ count = max(1, count // 2)
148
+
149
+ generated = generator.generate_malicious_sessions(
150
+ tick=tick,
151
+ count=count,
152
+ attack_phase=attacker.phase,
153
+ scenario=attacker.scenario,
154
+ attacker_id=attacker.attacker_id,
155
+ )
156
+ attacker.sessions_generated += len(generated)
157
+ sessions.extend(generated)
158
+
159
+ return sessions
160
+
161
+ def intelligence_feed(self) -> Dict:
162
+ """Return threat intelligence available to the agent."""
163
+ active_scenarios = set()
164
+ for a in self._active_attackers.values():
165
+ if a.alive:
166
+ active_scenarios.add(a.scenario)
167
+ self._threat_intel["active_campaigns"] = sorted(active_scenarios)
168
+ return dict(self._threat_intel)
169
+
170
+ def attacker_outcomes(self) -> Dict[str, str]:
171
+ """Return status of all known attackers (for info/debugging)."""
172
+ outcomes: Dict[str, str] = {}
173
+ for a in self._active_attackers.values():
174
+ if a.alive:
175
+ outcomes[a.attacker_id] = "active"
176
+ elif a.succeeded:
177
+ outcomes[a.attacker_id] = "succeeded"
178
+ else:
179
+ outcomes[a.attacker_id] = "stopped"
180
+ for a in self._dead_attackers:
181
+ if a.attacker_id not in outcomes:
182
+ outcomes[a.attacker_id] = "succeeded" if a.succeeded else "stopped"
183
+ return outcomes
tests/conftest.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import pytest
4
+
5
+ from server.baseline.heuristic_agent import heuristic_policy
6
+ from server.baseline.random_agent import random_policy
7
+ from server.firewall_environment import FirewallEnvironment
8
+
9
+
10
+ @pytest.fixture
11
+ def env_easy() -> FirewallEnvironment:
12
+ env = FirewallEnvironment(seed=101)
13
+ env.reset(task="easy", seed=101)
14
+ return env
15
+
16
+
17
+ @pytest.fixture
18
+ def env_medium() -> FirewallEnvironment:
19
+ env = FirewallEnvironment(seed=202)
20
+ env.reset(task="medium", seed=202)
21
+ return env
22
+
23
+
24
+ @pytest.fixture
25
+ def env_hard() -> FirewallEnvironment:
26
+ env = FirewallEnvironment(seed=303)
27
+ env.reset(task="hard", seed=303)
28
+ return env
29
+
30
+
31
+ @pytest.fixture
32
+ def random_agent_policy():
33
+ return random_policy(seed=9)
34
+
35
+
36
+ @pytest.fixture
37
+ def heuristic_agent_policy():
38
+ return heuristic_policy
tests/test_all.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Comprehensive tests for the Adaptive AI Firewall environment.
2
+
3
+ Covers: feature generation, reward mechanics, threat lifecycle,
4
+ grading determinism, degenerate policy detection, and budget management.
5
+ """
6
+ import numpy as np
7
+
8
+ from server.utils.data_loader import FEATURE_ORDER, TrafficGenerator
9
+ from server.utils.threat_engine import ThreatEngine
10
+ from server.utils.reward_engine import RewardEngine
11
+ from server.firewall_environment import (
12
+ FirewallEnvironment, OBS_DIM, NUM_ACTIONS,
13
+ )
14
+ from server.graders import run_deterministic_grade, grade_stats
15
+ from server.baseline.random_agent import random_policy, block_all_policy
16
+ from server.baseline.heuristic_agent import heuristic_policy
17
+
18
+
19
+ # ═══════════════════════════════════════════════════════════════════
20
+ # Traffic Generator
21
+ # ═══════════════════════════════════════════════════════════════════
22
+
23
+ class TestTrafficGenerator:
24
+ def test_feature_dimension(self):
25
+ gen = TrafficGenerator(seed=11)
26
+ session = gen.generate_benign_sessions(tick=0, count=1)[0]
27
+ assert len(FEATURE_ORDER) == 22
28
+ assert len(gen.to_observation_vector(session)) == 22
29
+
30
+ def test_normalized_features_in_0_1(self):
31
+ gen = TrafficGenerator(seed=42)
32
+ for _ in range(50):
33
+ session = gen.generate_benign_sessions(tick=0, count=1)[0]
34
+ obs = gen.to_observation_vector(session)
35
+ for i, val in enumerate(obs):
36
+ assert 0.0 <= val <= 1.0, f"Feature {FEATURE_ORDER[i]} = {val} out of [0,1]"
37
+
38
+ def test_malicious_features_normalized(self):
39
+ gen = TrafficGenerator(seed=55)
40
+ for scenario in ["port_scan_exploit_c2", "ddos_amplification", "supply_chain_compromise"]:
41
+ for phase in range(4):
42
+ sessions = gen.generate_malicious_sessions(
43
+ tick=0, count=3, attack_phase=phase, scenario=scenario,
44
+ )
45
+ for s in sessions:
46
+ obs = gen.to_observation_vector(s)
47
+ for i, val in enumerate(obs):
48
+ assert 0.0 <= val <= 1.0
49
+
50
+ def test_benign_malicious_separation(self):
51
+ """Verify that malicious and benign sessions have statistically different features."""
52
+ gen = TrafficGenerator(seed=77)
53
+ benign_vecs = []
54
+ for _ in range(100):
55
+ s = gen.generate_benign_sessions(tick=0, count=1)[0]
56
+ benign_vecs.append(gen.to_observation_vector(s))
57
+
58
+ mal_vecs = []
59
+ for phase in range(4):
60
+ for _ in range(25):
61
+ s = gen.generate_malicious_sessions(
62
+ tick=0, count=1, attack_phase=phase,
63
+ scenario="port_scan_exploit_c2",
64
+ )[0]
65
+ mal_vecs.append(gen.to_observation_vector(s))
66
+
67
+ benign_arr = np.array(benign_vecs)
68
+ mal_arr = np.array(mal_vecs)
69
+
70
+ # At least some features should have meaningfully different means
71
+ mean_diff = np.abs(benign_arr.mean(axis=0) - mal_arr.mean(axis=0))
72
+ significant_features = (mean_diff > 0.08).sum()
73
+ assert significant_features >= 5, (
74
+ f"Only {significant_features} features differ — distributions too similar"
75
+ )
76
+
77
+ def test_session_ids_unique(self):
78
+ gen = TrafficGenerator(seed=99)
79
+ ids = set()
80
+ for _ in range(100):
81
+ sessions = gen.generate_benign_sessions(tick=0, count=3)
82
+ for s in sessions:
83
+ assert s["session_id"] not in ids
84
+ ids.add(s["session_id"])
85
+
86
+
87
+ # ═══════════════════════════════════════════════════════════════════
88
+ # Reward Engine
89
+ # ═══════════════════════════════════════════════════════════════════
90
+
91
+ class TestRewardEngine:
92
+ def test_block_malicious_positive(self):
93
+ eng = RewardEngine()
94
+ r, _ = eng.reward(action=1, is_malicious=True, budget_remaining=50.0, attack_phase=0)
95
+ assert r > 0
96
+
97
+ def test_miss_malicious_negative(self):
98
+ eng = RewardEngine()
99
+ r, _ = eng.reward(action=0, is_malicious=True, budget_remaining=50.0, attack_phase=2)
100
+ assert r < 0
101
+
102
+ def test_block_benign_negative(self):
103
+ eng = RewardEngine()
104
+ r, _ = eng.reward(action=1, is_malicious=False, budget_remaining=50.0, attack_phase=0)
105
+ assert r < 0
106
+
107
+ def test_allow_benign_positive(self):
108
+ eng = RewardEngine()
109
+ r, _ = eng.reward(action=0, is_malicious=False, budget_remaining=50.0, attack_phase=0)
110
+ assert r > 0, "Correctly allowing benign traffic should be rewarded"
111
+
112
+ def test_block_all_loses_in_mixed_traffic(self):
113
+ """Block-all should have negative total reward on benign-heavy traffic."""
114
+ eng = RewardEngine()
115
+ total = 0.0
116
+ # Simulate 80% benign, 20% malicious
117
+ for _ in range(80):
118
+ r, _ = eng.reward(action=1, is_malicious=False, budget_remaining=50.0, attack_phase=0)
119
+ total += r
120
+ for _ in range(20):
121
+ r, _ = eng.reward(action=1, is_malicious=True, budget_remaining=50.0, attack_phase=1)
122
+ total += r
123
+ # Block-all should have lower score than a selective policy
124
+ assert total < 0, f"Block-all total reward {total} should be negative on 80/20 mix"
125
+
126
+ def test_early_detection_bonus(self):
127
+ eng = RewardEngine()
128
+ r_early, _ = eng.reward(action=1, is_malicious=True, budget_remaining=50.0, attack_phase=0)
129
+ r_late, _ = eng.reward(action=1, is_malicious=True, budget_remaining=50.0, attack_phase=3)
130
+ assert r_early > r_late, "Early detection should give higher reward"
131
+
132
+
133
+ # ═══════════════════════════════════════════════════════════════════
134
+ # Threat Engine
135
+ # ═══════════════════════════════════════════════════════════════════
136
+
137
+ class TestThreatEngine:
138
+ def test_spawn_and_generate(self):
139
+ engine = ThreatEngine(seed=22)
140
+ gen = TrafficGenerator(seed=23)
141
+ engine.maybe_spawn_attacker(1.0)
142
+ sessions = engine.generate_attack_sessions(tick=0, generator=gen, blocked_attackers=set())
143
+ assert len(sessions) > 0
144
+ assert all(s["metadata"]["malicious"] for s in sessions)
145
+
146
+ def test_attacker_dies_after_3_blocks(self):
147
+ engine = ThreatEngine(seed=33)
148
+ gen = TrafficGenerator(seed=34)
149
+ engine.maybe_spawn_attacker(1.0)
150
+ attacker_id = list(engine._active_attackers.keys())[0]
151
+
152
+ for _ in range(3):
153
+ engine.generate_attack_sessions(
154
+ tick=0, generator=gen, blocked_attackers={attacker_id},
155
+ )
156
+
157
+ # After 3 blocks, attacker should be dead
158
+ attacker = engine._active_attackers[attacker_id]
159
+ assert not attacker.alive
160
+
161
+ def test_attacker_outcomes(self):
162
+ engine = ThreatEngine(seed=44)
163
+ gen = TrafficGenerator(seed=45)
164
+ engine.maybe_spawn_attacker(1.0)
165
+ engine.generate_attack_sessions(tick=0, generator=gen, blocked_attackers=set())
166
+ outcomes = engine.attacker_outcomes()
167
+ assert len(outcomes) > 0
168
+ assert all(v in ("active", "stopped", "succeeded") for v in outcomes.values())
169
+
170
+
171
+ # ═══════════════════════════════════════════════════════════════════
172
+ # Firewall Environment
173
+ # ═══════════════════════════════════════════════════════════════════
174
+
175
+ class TestFirewallEnvironment:
176
+ def test_reset_returns_valid_state(self):
177
+ env = FirewallEnvironment(seed=99)
178
+ state = env.reset(task="easy", seed=100)
179
+ assert state["observation_dim"] == OBS_DIM
180
+ assert state["num_actions"] == NUM_ACTIONS
181
+ assert state["budget_remaining"] > 0
182
+
183
+ def test_step_returns_expected_keys(self):
184
+ env = FirewallEnvironment(seed=99)
185
+ env.reset(task="easy", seed=100)
186
+ pending = list(env.pending_sessions.keys())
187
+ actions = {sid: 0 for sid in pending[:3]}
188
+ response = env.step(actions)
189
+ assert "reward" in response
190
+ assert "done" in response
191
+ assert "state" in response
192
+
193
+ def test_inspect_keeps_session_alive(self):
194
+ env = FirewallEnvironment(seed=50)
195
+ env.reset(task="easy", seed=50)
196
+ sid = list(env.pending_sessions.keys())[0]
197
+ env._apply_action(sid, 2) # INSPECT
198
+ assert sid in env.inspected_sessions, "INSPECT should keep session in inspected pool"
199
+
200
+ def test_inspect_then_block(self):
201
+ """Two-phase: inspect → block."""
202
+ env = FirewallEnvironment(seed=60)
203
+ env.reset(task="easy", seed=60)
204
+ sid = list(env.pending_sessions.keys())[0]
205
+
206
+ # Phase 1: inspect
207
+ r1, _ = env._apply_action(sid, 2)
208
+ assert sid in env.inspected_sessions
209
+
210
+ # Phase 2: block
211
+ r2, _ = env._apply_action(sid, 1)
212
+ assert sid not in env.inspected_sessions
213
+
214
+ def test_budget_stays_positive_with_allow(self):
215
+ """All-allow policy should preserve most of the budget."""
216
+ env = FirewallEnvironment(seed=70)
217
+ env.reset(task="easy", seed=70)
218
+ initial = env.budget_remaining
219
+ for _ in range(50):
220
+ sids = list(env.pending_sessions.keys())
221
+ if not sids:
222
+ break
223
+ env.step({sid: 0 for sid in sids})
224
+ # ALLOW costs 0, so budget should barely change
225
+ assert env.budget_remaining >= initial * 0.95
226
+
227
+ def test_budget_nonzero_with_reasonable_policy(self):
228
+ """Heuristic policy should leave some budget remaining."""
229
+ env = FirewallEnvironment(seed=80)
230
+ env.reset(task="easy", seed=80)
231
+ for _ in range(env.max_steps):
232
+ sids = (
233
+ list(env.inspected_sessions.keys())
234
+ + list(env.pending_sessions.keys())
235
+ )
236
+ actions = heuristic_policy(env, sids)
237
+ resp = env.step(actions)
238
+ if resp["done"]:
239
+ break
240
+ stats = env.get_network_stats()
241
+ assert stats["efficiency"] > 0.0, f"Efficiency should be > 0, got {stats['efficiency']}"
242
+
243
+ def test_expired_malicious_counted_in_metrics(self):
244
+ """Expired malicious sessions must be counted in totals."""
245
+ env = FirewallEnvironment(seed=90)
246
+ env.reset(task="easy", seed=90)
247
+ # Let everything expire by stepping with no actions
248
+ for _ in range(10):
249
+ env.step({})
250
+ stats = env.get_network_stats()
251
+ if stats["total_malicious"] > 0:
252
+ # expired malicious should be counted
253
+ assert stats["expired_malicious"] > 0
254
+
255
+ def test_single_session_mode(self):
256
+ """step_single returns valid observation and reward."""
257
+ env = FirewallEnvironment(seed=100)
258
+ env.reset(task="easy", seed=100)
259
+ result = env.step_single(0) # ALLOW
260
+ assert len(result["observation"]) == OBS_DIM
261
+ assert "reward" in result
262
+ assert "done" in result
263
+
264
+
265
+ # ═══════════════════════════════════════════════════════════════════
266
+ # Graders
267
+ # ═══════════════════════════════════════════════════════════════════
268
+
269
+ class TestGraders:
270
+ def test_deterministic_grading(self):
271
+ env = FirewallEnvironment(seed=31)
272
+ p1 = random_policy(seed=9)
273
+ first = run_deterministic_grade(env, task="easy", policy=p1)["score"]
274
+ p2 = random_policy(seed=9)
275
+ second = run_deterministic_grade(env, task="easy", policy=p2)["score"]
276
+ assert first == second, "Same seed should produce same score"
277
+
278
+ def test_score_in_valid_range(self):
279
+ env = FirewallEnvironment(seed=40)
280
+ for task in ("easy", "medium", "hard"):
281
+ policy = random_policy(seed=7)
282
+ result = run_deterministic_grade(env, task=task, policy=policy)
283
+ assert 0.0 <= result["score"] <= 1.0
284
+
285
+ def test_heuristic_beats_random(self):
286
+ """Core sanity check: heuristic > random on easy task."""
287
+ env = FirewallEnvironment(seed=50)
288
+ rp = random_policy(seed=7)
289
+ r_score = run_deterministic_grade(env, task="easy", policy=rp)["score"]
290
+ h_score = run_deterministic_grade(env, task="easy", policy=heuristic_policy)["score"]
291
+ assert h_score > r_score, (
292
+ f"Heuristic ({h_score:.4f}) must beat random ({r_score:.4f}) on easy task"
293
+ )
294
+
295
+ def test_heuristic_beats_block_all(self):
296
+ """Block-all should not dominate heuristic."""
297
+ env = FirewallEnvironment(seed=60)
298
+ b_score = run_deterministic_grade(env, task="easy", policy=block_all_policy)["score"]
299
+ h_score = run_deterministic_grade(env, task="easy", policy=heuristic_policy)["score"]
300
+ assert h_score > b_score, (
301
+ f"Heuristic ({h_score:.4f}) must beat block-all ({b_score:.4f})"
302
+ )
303
+
304
+ def test_grade_stats_clamps(self):
305
+ stats = {"detection_rate": 1.5, "false_positive_rate": -0.5, "efficiency": 2.0}
306
+ result = grade_stats("easy", stats)
307
+ assert result["score"] <= 1.0
tests/test_environment_dynamics.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from server.baseline.heuristic_agent import heuristic_policy
2
+ from server.firewall_environment import FirewallEnvironment
3
+
4
+
5
+ def test_expired_malicious_sessions_are_counted():
6
+ env = FirewallEnvironment(seed=11)
7
+ env.reset(task="easy", seed=11)
8
+ before = env.metrics.malicious_seen
9
+ for _ in range(4):
10
+ env.step({})
11
+ after = env.metrics.malicious_seen
12
+ assert after >= before
13
+
14
+
15
+ def test_inspect_keeps_session_pending_and_reveals():
16
+ env = FirewallEnvironment(seed=12)
17
+ env.reset(task="easy", seed=12)
18
+ session_id = next(iter(env.pending_sessions.keys()))
19
+ env.take_action(session_id=session_id, action=2)
20
+ assert session_id in env.pending_sessions
21
+ assert env.pending_sessions[session_id]["metadata"]["revealed"] is True
22
+
23
+
24
+ def test_step_single_has_fixed_size_action_mode():
25
+ env = FirewallEnvironment(seed=13)
26
+ env.reset(task="easy", seed=13)
27
+ response = env.step_single(action=0)
28
+ assert "focus_observation" in response["state"]
29
+ assert len(response["state"]["focus_observation"]) == 22
30
+
31
+
32
+ def test_budget_is_scaled_by_episode_length():
33
+ env = FirewallEnvironment(seed=14, budget=50.0)
34
+ env.reset(task="hard", seed=14)
35
+ assert env.initial_budget >= env.max_steps * 0.35
36
+
37
+
38
+ def test_attacker_outcomes_exposed_in_step_info():
39
+ env = FirewallEnvironment(seed=15)
40
+ env.reset(task="easy", seed=15)
41
+ session_id = next(iter(env.pending_sessions.keys()))
42
+ result = env.step({session_id: 1})
43
+ assert "attacker_outcomes" in result["info"]
44
+
45
+
46
+ def test_heuristic_policy_executes_over_pending_sessions():
47
+ env = FirewallEnvironment(seed=16)
48
+ env.reset(task="easy", seed=16)
49
+ actions = heuristic_policy(env, list(env.pending_sessions.keys())[:5])
50
+ assert isinstance(actions, dict)
tests/test_integration_policies.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from server.baseline.heuristic_agent import heuristic_policy
2
+ from server.baseline.random_agent import random_policy
3
+ from server.firewall_environment import FirewallEnvironment
4
+ from server.graders import run_deterministic_grade
5
+
6
+
7
+ def always_allow_policy(_, session_ids):
8
+ return {sid: 0 for sid in session_ids}
9
+
10
+
11
+ def always_block_policy(_, session_ids):
12
+ return {sid: 1 for sid in session_ids}
13
+
14
+
15
+ def test_policy_ordering_easy_task():
16
+ env = FirewallEnvironment(seed=77)
17
+ random_score = run_deterministic_grade(env, task="easy", policy=random_policy(seed=7))["score"]
18
+ heuristic_score = run_deterministic_grade(env, task="easy", policy=heuristic_policy)["score"]
19
+ allow_score = run_deterministic_grade(env, task="easy", policy=always_allow_policy)["score"]
20
+ assert heuristic_score >= random_score
21
+ assert heuristic_score >= allow_score
22
+
23
+
24
+ def test_block_all_is_not_best_strategy():
25
+ env = FirewallEnvironment(seed=88)
26
+ for task in ("easy", "medium", "hard"):
27
+ block_score = run_deterministic_grade(env, task=task, policy=always_block_policy)["score"]
28
+ heuristic_score = run_deterministic_grade(env, task=task, policy=heuristic_policy)["score"]
29
+ assert block_score <= heuristic_score
tests/test_reward_and_scores.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from server.firewall_environment import FirewallEnvironment
2
+ from server.graders import grade_stats
3
+ from server.utils.reward_engine import RewardEngine
4
+
5
+
6
+ def test_grade_score_bounds():
7
+ stats = {
8
+ "detection_rate": 0.5,
9
+ "false_positive_rate": 0.1,
10
+ "efficiency": 0.8,
11
+ "early_detection_bonus": 0.7,
12
+ "cascade_prevention": 0.6,
13
+ }
14
+ for task in ("easy", "medium", "hard"):
15
+ score = grade_stats(task, stats)["score"]
16
+ assert 0.0 <= score <= 1.0
17
+
18
+
19
+ def test_reward_range_is_reasonable():
20
+ engine = RewardEngine()
21
+ samples = [
22
+ engine.reward(action=0, is_malicious=False, budget_remaining=100.0, attack_phase=0)[0],
23
+ engine.reward(action=1, is_malicious=False, budget_remaining=100.0, attack_phase=0)[0],
24
+ engine.reward(action=1, is_malicious=True, budget_remaining=100.0, attack_phase=1)[0],
25
+ engine.reward(action=0, is_malicious=True, budget_remaining=100.0, attack_phase=3)[0],
26
+ ]
27
+ assert min(samples) > -2.5
28
+ assert max(samples) < 2.5
29
+
30
+
31
+ def test_efficiency_is_non_zero_after_episode():
32
+ env = FirewallEnvironment(seed=66)
33
+ env.reset(task="medium", seed=66)
34
+ done = False
35
+ while not done:
36
+ response = env.step({})
37
+ done = response["done"]
38
+ stats = env.get_network_stats()
39
+ assert stats["efficiency"] > 0.0
uv.lock ADDED
The diff for this file is too large to render. See raw diff