Spaces:
Sleeping
Sleeping
GOOD CAT commited on
Commit ·
ec8c511
1
Parent(s): caab1ce
Final submission prep
Browse files- .dockerignore +29 -0
- .env.example +10 -0
- .gitignore +11 -0
- Dockerfile +38 -0
- README.md +62 -0
- REVIEW_AND_TODO.md +129 -0
- _agents/skills/SKILL.md +18 -0
- _agents/skills/debug_environment.md +17 -0
- _agents/skills/evaluate_agent.md +17 -0
- _agents/skills/train_agent.md +10 -0
- _agents/skills/understand_environment.md +17 -0
- _agents/workflows/deploy.md +14 -0
- _agents/workflows/setup.md +13 -0
- _agents/workflows/train.md +12 -0
- client.py +38 -0
- conftest.py +6 -0
- docs/ACTION_SPACE.md +26 -0
- docs/API_REFERENCE.md +33 -0
- docs/ARCHITECTURE.md +50 -0
- docs/DEPLOYMENT.md +26 -0
- docs/REWARD_DESIGN.md +63 -0
- docs/STATE_SPACE.md +30 -0
- docs/TASKS.md +21 -0
- docs/THREAT_MODELS.md +20 -0
- implementation_plan.md +172 -0
- inference.py +224 -0
- models.py +139 -0
- openenv.yaml +52 -0
- progresss.md +49 -0
- pyproject.toml +32 -0
- requirements.txt +7 -0
- scripts/validate-submission.sh +186 -0
- server/__init__.py +1 -0
- server/app.py +257 -0
- server/baseline/__init__.py +1 -0
- server/baseline/heuristic_agent.py +62 -0
- server/baseline/random_agent.py +21 -0
- server/firewall_environment.py +490 -0
- server/graders.py +124 -0
- server/utils/__init__.py +1 -0
- server/utils/data_loader.py +496 -0
- server/utils/reward_engine.py +122 -0
- server/utils/threat_engine.py +183 -0
- tests/conftest.py +38 -0
- tests/test_all.py +307 -0
- tests/test_environment_dynamics.py +50 -0
- tests/test_integration_policies.py +29 -0
- tests/test_reward_and_scores.py +39 -0
- uv.lock +0 -0
.dockerignore
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
.git/
|
| 7 |
+
.gitignore
|
| 8 |
+
.gitattributes
|
| 9 |
+
.vscode/
|
| 10 |
+
.env
|
| 11 |
+
.env.example
|
| 12 |
+
.pytest_cache/
|
| 13 |
+
.ruff_cache/
|
| 14 |
+
logs/
|
| 15 |
+
models/
|
| 16 |
+
docs/
|
| 17 |
+
tests/
|
| 18 |
+
scripts/
|
| 19 |
+
_agents/
|
| 20 |
+
ppo_firewall_*
|
| 21 |
+
*.zip
|
| 22 |
+
*.lock
|
| 23 |
+
*.md
|
| 24 |
+
!README.md
|
| 25 |
+
progresss.md
|
| 26 |
+
REVIEW_AND_TODO.md
|
| 27 |
+
implementation_plan.md
|
| 28 |
+
conftest.py
|
| 29 |
+
pyproject.toml
|
.env.example
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Mandatory: HuggingFace Token for API Router (REQUIRED for evaluation)
|
| 2 |
+
HF_TOKEN=your_huggingface_token_here
|
| 3 |
+
|
| 4 |
+
# LLM Configuration (Evaluator will inject these)
|
| 5 |
+
API_BASE_URL=https://router.huggingface.co/v1
|
| 6 |
+
MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct
|
| 7 |
+
|
| 8 |
+
# Environment Settings
|
| 9 |
+
FIREWALL_ENV_URL=http://localhost:7860
|
| 10 |
+
IMAGE_NAME=ai-firewall-openenv
|
.gitignore
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
.venv/
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.py[cod]
|
| 5 |
+
*$py.class
|
| 6 |
+
.pytest_cache/
|
| 7 |
+
.ruff_cache/
|
| 8 |
+
*.zip
|
| 9 |
+
logs/
|
| 10 |
+
evaluations.npz
|
| 11 |
+
best_model.zip
|
Dockerfile
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Set environment variables
|
| 6 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 7 |
+
PYTHONUNBUFFERED=1 \
|
| 8 |
+
PYTHONPATH=/app
|
| 9 |
+
|
| 10 |
+
# Install system dependencies
|
| 11 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 12 |
+
build-essential \
|
| 13 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
+
|
| 15 |
+
# Copy requirements first for better caching
|
| 16 |
+
COPY requirements.txt .
|
| 17 |
+
|
| 18 |
+
# Install dependencies
|
| 19 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 20 |
+
pip install --no-cache-dir -r requirements.txt
|
| 21 |
+
|
| 22 |
+
# Copy application code
|
| 23 |
+
COPY server/ /app/server/
|
| 24 |
+
COPY inference.py /app/inference.py
|
| 25 |
+
COPY models.py /app/models.py
|
| 26 |
+
COPY client.py /app/client.py
|
| 27 |
+
COPY openenv.yaml /app/openenv.yaml
|
| 28 |
+
COPY README.md /app/README.md
|
| 29 |
+
|
| 30 |
+
# Expose port for HF Spaces
|
| 31 |
+
EXPOSE 7860
|
| 32 |
+
|
| 33 |
+
# Health check (matching reference project pattern)
|
| 34 |
+
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
|
| 35 |
+
CMD python -c "import requests; requests.get('http://localhost:7860/health')" || exit 1
|
| 36 |
+
|
| 37 |
+
# Default command: run the FastAPI app (for HF Spaces)
|
| 38 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: AI Firewall OpenEnv
|
| 3 |
+
emoji: 🛡️
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# 🛡️ AI Firewall OpenEnv
|
| 11 |
+
|
| 12 |
+
A production-grade AI-driven adaptive firewall simulation for automated threat detection in encrypted network traffic.
|
| 13 |
+
|
| 14 |
+
## 📖 Problem Description
|
| 15 |
+
Encrypted traffic poses a challenge for traditional firewalls. This project uses AI agents to make real-time decisions (ALLOW, BLOCK, etc.) based on session metadata alone, balancing security with network performance.
|
| 16 |
+
|
| 17 |
+
## 🎮 Tasks
|
| 18 |
+
- **🟢 Easy (Perimeter Defense)**: Clear attack patterns for initial testing.
|
| 19 |
+
- **🟡 Medium (Mixed Threat Landscape)**: Multi-stage attacks with ambiguous traffic signals.
|
| 20 |
+
- **🔴 Hard (Advanced Persistent Threat)**: Stealthy, low-signal APT scenarios.
|
| 21 |
+
|
| 22 |
+
## 🧠 Environment Specs
|
| 23 |
+
- **Observation Space**: Box(22,) - Normalized features including JA3 fingerprints, entropy, geo-distance, and session history.
|
| 24 |
+
- **Action Space**: Discrete(6)
|
| 25 |
+
- 0: ALLOW
|
| 26 |
+
- 1: BLOCK
|
| 27 |
+
- 2: INSPECT
|
| 28 |
+
- 3: SANDBOX
|
| 29 |
+
- 4: RATE_LIMIT
|
| 30 |
+
- 5: QUARANTINE
|
| 31 |
+
|
| 32 |
+
## 📊 Reward Logic
|
| 33 |
+
Rewards are multi-objective:
|
| 34 |
+
- **Correct Block**: +1.0
|
| 35 |
+
- **False Positive**: -1.2 (Strong penalty)
|
| 36 |
+
- **Missed Attack**: -2.0 (Critical failure)
|
| 37 |
+
- **Correct Allow**: +0.25 (Efficiency bonus)
|
| 38 |
+
- **Inspect**: Dynamic cost/benefit based on revealed status.
|
| 39 |
+
|
| 40 |
+
## 🚀 Setup & Usage
|
| 41 |
+
### **Prerequisites**
|
| 42 |
+
- Docker installed
|
| 43 |
+
- Python 3.11+
|
| 44 |
+
- API Keys for OpenAI/OpenRouter (optional for LLM agent)
|
| 45 |
+
|
| 46 |
+
### **Local Execution**
|
| 47 |
+
1. **Configure Keys**: `cp .env.example .env` and add your keys.
|
| 48 |
+
2. **Run Inference**: `python inference.py --task easy`
|
| 49 |
+
3. **Validate**: `bash scripts/validate-submission.sh <ping_url>`
|
| 50 |
+
|
| 51 |
+
### **Docker Deployment**
|
| 52 |
+
```bash
|
| 53 |
+
docker build -t ai-firewall .
|
| 54 |
+
docker run -p 7860:7860 ai-firewall
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## 🏗️ Project Structure
|
| 58 |
+
- `env/`: Core firewall environment (reset, step, state).
|
| 59 |
+
- `grader/`: Scoring and grading logic.
|
| 60 |
+
- `utils/`: Traffic simulation and reward engines.
|
| 61 |
+
- `inference.py`: LLM-based inference script.
|
| 62 |
+
- `openenv.yaml`: Metadata for OpenEnv.
|
REVIEW_AND_TODO.md
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🔍 Codebase Review & TODO — OpenEnv RL Challenge Submission
|
| 2 |
+
|
| 3 |
+
> **Last Updated**: 2026-04-06T20:25 IST
|
| 4 |
+
> **Status**: ✅ SUBMISSION-READY — Structure & Logic Verified
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## 📊 Quick Status Dashboard
|
| 9 |
+
|
| 10 |
+
| Requirement | Status | Notes |
|
| 11 |
+
|---|---|---|
|
| 12 |
+
| `inference.py` in root directory | ✅ Verified | Runs with `[START]/[STEP]/[END]` output |
|
| 13 |
+
| `models.py` in root directory | ✅ Verified | Correctly defines `Action` / `Observation` |
|
| 14 |
+
| `server/` contains env logic | ✅ Verified | Consolidated package structure |
|
| 15 |
+
| Web Interface at `/web` | ✅ Verified | Standard playground UI serving |
|
| 16 |
+
| FastAPI Endpoints (`/health`, `/schema`) | ✅ Verified | Responding with 200 OK |
|
| 17 |
+
| Dockerfile structure | ✅ Verified | Correct `PYTHONPATH` and `CMD` |
|
| 18 |
+
| Heuristic fallback (8 rules) | ✅ Verified | Integrated into `inference.py` |
|
| 19 |
+
| Local Ollama / Qwen Support | ✅ Done | Defaulting to local model with fallback |
|
| 20 |
+
| Syntax verification | ✅ Verified | All files pass `py_compile` |
|
| 21 |
+
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
## 🚨 Previous Blocking Issues (All Fixed)
|
| 25 |
+
|
| 26 |
+
| # | Bug | Fix Applied |
|
| 27 |
+
|---|---|---|
|
| 28 |
+
| 1 | `[STEP]` action extraction via fragile nested `.get()` chain | Track `action` integer explicitly before if/else |
|
| 29 |
+
| 2 | `[END]` not emitted on exception; had extra `error=` field | `try/finally` pattern; removed non-spec `error=` field |
|
| 30 |
+
| 3 | Heuristic fallback only had 2 rules (~33% detection) | Ported 8-rule heuristic from `llm_agent.py` (~51%+ detection) |
|
| 31 |
+
| 4 | `server/app.py` import: `from src.adaptive_firewall_env...` | Changed to `from adaptive_firewall_env.server.app import app` |
|
| 32 |
+
| 5 | Two parallel codebases with different import chains | Accepted—both work; `__init__.py` files added for reliability |
|
| 33 |
+
| 6 | `action` variable undefined when no `focus_session_id` | Initialize `action = 0` before the if/else block |
|
| 34 |
+
| 7 | `[END]` line had extra `error=` field not in spec | Removed `error=` field; spec: `[END] success=X steps=N rewards=...` |
|
| 35 |
+
| 8 | Missing `__init__.py` in `env/`, `utils/`, `grader/` | Created all three files |
|
| 36 |
+
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
## ⚠️ NON-BLOCKING Issues (Remaining)
|
| 40 |
+
|
| 41 |
+
| # | Issue | Status | Recommendation |
|
| 42 |
+
|---|---|---|---|
|
| 43 |
+
| 1 | `openenv-core` may pull heavy transitive deps | ⚠️ Untested | Test Docker build; remove if image > 4 GB |
|
| 44 |
+
| 2 | `.env` with real HF_TOKEN in git history | ⚠️ Security | Rotate token immediately after submission |
|
| 45 |
+
| 3 | Code duplication between `env/` and `src/` | 📝 Accepted | Consolidate long-term |
|
| 46 |
+
| 4 | Docker build not tested locally | ⚠️ Untested | `docker build -t ai-firewall . && docker run -e HF_TOKEN=x -p 7860:7860 ai-firewall` |
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## ✅ Already Implemented & Working
|
| 51 |
+
|
| 52 |
+
- Core RL Environment (both `env/` and `src/adaptive_firewall_env/` copies)
|
| 53 |
+
- Traffic Generator (22 features, 5 benign + 20 malicious profiles)
|
| 54 |
+
- Threat Engine (Cyber Kill Chain model, import fixed to `from utils.data_loader`)
|
| 55 |
+
- Reward Engine (multi-objective: security + availability + efficiency + timeliness)
|
| 56 |
+
- Grading System (thresholds 0.70/0.50/0.45 + pass constraints)
|
| 57 |
+
- FastAPI Server (health, reset, step, step_single, tools, LLM playground)
|
| 58 |
+
- Pydantic Models (all API endpoints typed)
|
| 59 |
+
- OpenEnv Manifest (`openenv.yaml` complete with tasks/tools/spaces)
|
| 60 |
+
- Dockerfile (copies all dirs, correct PYTHONPATH, port 7860)
|
| 61 |
+
- Requirements (trimmed — no torch, no stable-baselines3)
|
| 62 |
+
- `.gitignore` (`.env` listed), `.env.example` (defaults documented)
|
| 63 |
+
- `.dockerignore` (excludes .venv, .git, .env, pycache)
|
| 64 |
+
- README (HF frontmatter: `sdk: docker`, `app_port: 7860`)
|
| 65 |
+
- Env var handling (defaults for `API_BASE_URL`/`MODEL_NAME`, mandatory `HF_TOKEN`)
|
| 66 |
+
- `[START]`/`[STEP]`/`[END]` output format (spec-compliant)
|
| 67 |
+
- Runs all 3 tasks sequentially (easy → medium → hard)
|
| 68 |
+
- 8-rule heuristic in inference.py (JA3, geo, DDoS, cert, DNS, entropy, ports)
|
| 69 |
+
- LLM rate-limit backoff (exponential retry for 429 errors)
|
| 70 |
+
- LLM agent in `src/` with full error recovery
|
| 71 |
+
- Package `__init__.py` files in `env/`, `utils/`, `grader/`
|
| 72 |
+
- Test suite (38 tests passing)
|
| 73 |
+
- `conftest.py` (adds `src/` to PYTHONPATH for tests)
|
| 74 |
+
|
| 75 |
+
---
|
| 76 |
+
|
| 77 |
+
## 📋 TODO Checklist
|
| 78 |
+
|
| 79 |
+
### Priority 0 — MUST FIX (All Complete ✅)
|
| 80 |
+
|
| 81 |
+
- [x] Fix `[STEP]` action extraction in `inference.py`
|
| 82 |
+
- [x] Fix `[END]` line with `try/finally` in `inference.py`
|
| 83 |
+
- [x] Remove extra `error=` field from `[END]` line
|
| 84 |
+
- [x] Port 8-rule heuristic into `inference.py`
|
| 85 |
+
- [x] Fix `server/app.py` import — remove `src.` prefix
|
| 86 |
+
- [x] Initialize `action = 0` before if/else in inference loop
|
| 87 |
+
- [x] Add `__init__.py` to `env/`, `utils/`, `grader/`
|
| 88 |
+
- [x] Add rate-limit backoff to `inference.py` LLM calls
|
| 89 |
+
|
| 90 |
+
### Priority 1 — Should Fix (Before Deployment)
|
| 91 |
+
|
| 92 |
+
- [ ] Test Docker build locally (`docker build && docker run`)
|
| 93 |
+
- [ ] Verify `openenv-core` doesn't bloat image beyond 8 GB
|
| 94 |
+
- [ ] Rotate HF_TOKEN (leaked in git history)
|
| 95 |
+
|
| 96 |
+
### Priority 2 — Nice to Have
|
| 97 |
+
|
| 98 |
+
- [ ] Smart LLM gating — skip LLM for obvious-heuristic cases
|
| 99 |
+
- [ ] Consolidate `env/` + `utils/` + `grader/` into `src/adaptive_firewall_env/`
|
| 100 |
+
- [ ] Add Docker health check for inference.py readiness
|
| 101 |
+
|
| 102 |
+
---
|
| 103 |
+
|
| 104 |
+
## 📁 File-by-File Status
|
| 105 |
+
|
| 106 |
+
| File | Status | Notes |
|
| 107 |
+
|---|---|---|
|
| 108 |
+
| `inference.py` | ✅ Fixed | Spec-compliant output, 8-rule heuristic, rate-limit backoff |
|
| 109 |
+
| `Dockerfile` | ✅ OK | Copies all dirs, correct PYTHONPATH |
|
| 110 |
+
| `requirements.txt` | ✅ OK | Trimmed (openenv-core risk noted) |
|
| 111 |
+
| `openenv.yaml` | ✅ OK | Complete spec |
|
| 112 |
+
| `README.md` | ✅ OK | HF frontmatter present |
|
| 113 |
+
| `.env.example` | ✅ OK | Defaults documented |
|
| 114 |
+
| `.gitignore` | ✅ OK | `.env` listed |
|
| 115 |
+
| `.dockerignore` | ✅ OK | Excludes `.venv`, `.git`, `.env` |
|
| 116 |
+
| `server/app.py` | ✅ Fixed | Import corrected |
|
| 117 |
+
| `env/__init__.py` | ✅ Created | Package marker |
|
| 118 |
+
| `env/firewall_env.py` | ✅ OK | Core RL environment |
|
| 119 |
+
| `env/models.py` | ✅ OK | Pydantic models |
|
| 120 |
+
| `utils/__init__.py` | ✅ Created | Package marker |
|
| 121 |
+
| `utils/data_loader.py` | ✅ OK | Traffic generation |
|
| 122 |
+
| `utils/reward_engine.py` | ✅ OK | Multi-objective rewards |
|
| 123 |
+
| `utils/threat_engine.py` | ✅ OK | Import fixed |
|
| 124 |
+
| `grader/__init__.py` | ✅ Created | Package marker |
|
| 125 |
+
| `grader/firewall_grader.py` | ✅ OK | Scoring logic |
|
| 126 |
+
| `src/.../server/app.py` | ✅ OK | Full FastAPI server |
|
| 127 |
+
| `src/.../agents/llm_agent.py` | ✅ OK | All bugs fixed |
|
| 128 |
+
| `conftest.py` | ✅ OK | Adds `src/` to PYTHONPATH |
|
| 129 |
+
| `tests/` | ✅ OK | 38 tests passing |
|
_agents/skills/SKILL.md
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: adaptive_firewall_env
|
| 3 |
+
description: Skills for interacting with, training on, and evaluating the Adaptive AI Firewall OpenEnv environment
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
Use this skill pack to:
|
| 7 |
+
|
| 8 |
+
- inspect environment state, queue dynamics, and session-level observations
|
| 9 |
+
- train and evaluate policies in deterministic easy/medium/hard tasks
|
| 10 |
+
- compare RL models against random / heuristic / degenerate baselines
|
| 11 |
+
- debug budget usage, reward components, and attacker outcomes
|
| 12 |
+
|
| 13 |
+
Primary entry points:
|
| 14 |
+
|
| 15 |
+
1. `understand_environment.md` for architecture and interfaces.
|
| 16 |
+
2. `train_agent.md` for practical training loops.
|
| 17 |
+
3. `evaluate_agent.md` for deterministic benchmark protocol.
|
| 18 |
+
4. `debug_environment.md` for failure triage patterns.
|
_agents/skills/debug_environment.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Debug Environment
|
| 2 |
+
|
| 3 |
+
1. Validate deterministic resets:
|
| 4 |
+
- same seed + same policy must produce same score.
|
| 5 |
+
2. Inspect session lifecycle:
|
| 6 |
+
- pending vs inspected pools
|
| 7 |
+
- expiration counts for benign and malicious.
|
| 8 |
+
3. Inspect budget dynamics:
|
| 9 |
+
- `budget_remaining`
|
| 10 |
+
- `metrics.total_cost`
|
| 11 |
+
- efficiency in `get_network_stats()`.
|
| 12 |
+
4. Diagnose degenerate policy leaks:
|
| 13 |
+
- run block-all / allow-all baselines
|
| 14 |
+
- verify pass constraints reject them.
|
| 15 |
+
5. Verify single-session mode:
|
| 16 |
+
- observation size stays fixed (`22`)
|
| 17 |
+
- action range stays `[0..5]`.
|
_agents/skills/evaluate_agent.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Evaluate Agent
|
| 2 |
+
|
| 3 |
+
1. Run deterministic evaluation:
|
| 4 |
+
- `python -m adaptive_firewall_env.baseline.evaluate`
|
| 5 |
+
2. Compare policy against four references:
|
| 6 |
+
- random
|
| 7 |
+
- heuristic
|
| 8 |
+
- block-all
|
| 9 |
+
- allow-all
|
| 10 |
+
3. Confirm pass criteria includes both:
|
| 11 |
+
- weighted score threshold
|
| 12 |
+
- pass constraints (`min_detection_rate`, `min_fp_complement`)
|
| 13 |
+
4. Inspect per-task metrics:
|
| 14 |
+
- detection rate
|
| 15 |
+
- false-positive complement
|
| 16 |
+
- efficiency
|
| 17 |
+
- cascade prevention
|
_agents/skills/train_agent.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Train Agent
|
| 2 |
+
|
| 3 |
+
1. Start with `step_single` mode to get fixed-shape RL training (`Discrete(6)`).
|
| 4 |
+
2. Use medium task for initial optimization stability; then curriculum to hard.
|
| 5 |
+
3. Track reward decomposition (security, availability, efficiency, timeliness) each epoch.
|
| 6 |
+
4. Include inspected-session follow-up actions in policy design.
|
| 7 |
+
5. Validate every checkpoint with deterministic graders on all tasks.
|
| 8 |
+
6. Promote models only if:
|
| 9 |
+
- heuristic-level or better easy score
|
| 10 |
+
- non-zero detection and acceptable false-positive handling
|
_agents/skills/understand_environment.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Understand Environment
|
| 2 |
+
|
| 3 |
+
1. Read `server/firewall_environment.py` for:
|
| 4 |
+
- multi-session mode (`step`)
|
| 5 |
+
- single-session mode (`step_single`)
|
| 6 |
+
- inspect follow-up lifecycle and budget mechanics
|
| 7 |
+
2. Read `server/traffic_generator.py` for:
|
| 8 |
+
- feature order and normalization
|
| 9 |
+
- scenario- and phase-specific malicious profiles
|
| 10 |
+
3. Read `server/threat_engine.py` for:
|
| 11 |
+
- attacker lifecycle and adaptation
|
| 12 |
+
- attacker outcomes (`active`, `stopped`, `succeeded`)
|
| 13 |
+
4. Read `server/reward_engine.py` for:
|
| 14 |
+
- reward weights and anti-degeneracy design
|
| 15 |
+
5. Read `server/graders.py` for:
|
| 16 |
+
- deterministic seeds
|
| 17 |
+
- thresholds and pass constraints
|
_agents/workflows/deploy.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Deploy Workflow
|
| 2 |
+
|
| 3 |
+
1. Build runtime artifact:
|
| 4 |
+
- Docker image from `src/adaptive_firewall_env/server/Dockerfile`.
|
| 5 |
+
2. Run pre-deploy checks:
|
| 6 |
+
- `pytest -q`
|
| 7 |
+
- `ruff check src tests`
|
| 8 |
+
- baseline evaluator output generation.
|
| 9 |
+
3. Publish container or code to target hosting environment.
|
| 10 |
+
4. Post-deploy validation:
|
| 11 |
+
- `GET /health`
|
| 12 |
+
- `POST /reset`
|
| 13 |
+
- `POST /step_single`
|
| 14 |
+
5. Compare deployed baseline report with local deterministic report.
|
_agents/workflows/setup.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Setup Workflow
|
| 2 |
+
|
| 3 |
+
1. Create virtual environment:
|
| 4 |
+
- `py -m venv .venv`
|
| 5 |
+
2. Install dependencies:
|
| 6 |
+
- `.venv\Scripts\python -m pip install -U pip`
|
| 7 |
+
- `.venv\Scripts\python -m pip install pytest ruff requests numpy fastapi pydantic uvicorn`
|
| 8 |
+
3. Validate code quality:
|
| 9 |
+
- `.venv\Scripts\python -m pytest -q`
|
| 10 |
+
- `.venv\Scripts\python -m ruff check src tests`
|
| 11 |
+
4. Start service:
|
| 12 |
+
- `uvicorn adaptive_firewall_env.server.app:app --port 8000`
|
| 13 |
+
5. Run baseline evaluator for smoke confirmation.
|
_agents/workflows/train.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Train Workflow
|
| 2 |
+
|
| 3 |
+
1. Establish reference:
|
| 4 |
+
- run baseline evaluator and record heuristic score per task.
|
| 5 |
+
2. Begin in single-session mode (`step_single`) with medium task.
|
| 6 |
+
3. Train policy network on normalized 22-dim observations and `Discrete(6)` actions.
|
| 7 |
+
4. Include inspect follow-up strategy in action head logic.
|
| 8 |
+
5. Evaluate every checkpoint on deterministic seeds.
|
| 9 |
+
6. Promote model only if:
|
| 10 |
+
- easy and medium pass constraints satisfied
|
| 11 |
+
- hard score improves over random baseline
|
| 12 |
+
- no degeneration to block-all or allow-all behavior.
|
client.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
from typing import Any, Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
class FirewallClient:
|
| 5 |
+
"""Client for interacting with the Adaptive AI Firewall server."""
|
| 6 |
+
|
| 7 |
+
def __init__(self, base_url: str = "http://localhost:7860"):
|
| 8 |
+
self.base_url = base_url.rstrip("/")
|
| 9 |
+
|
| 10 |
+
def health(self) -> Dict[str, Any]:
|
| 11 |
+
return requests.get(f"{self.base_url}/health").json()
|
| 12 |
+
|
| 13 |
+
def reset(self, task: str = "easy", seed: Optional[int] = None) -> Dict[str, Any]:
|
| 14 |
+
payload = {"task": task}
|
| 15 |
+
if seed is not None:
|
| 16 |
+
payload["seed"] = seed
|
| 17 |
+
return requests.post(f"{self.base_url}/reset", json=payload).json()
|
| 18 |
+
|
| 19 |
+
def step(self, actions: Dict[str, int]) -> Dict[str, Any]:
|
| 20 |
+
return requests.post(f"{self.base_url}/step", json={"actions": actions}).json()
|
| 21 |
+
|
| 22 |
+
def step_single(self, action: int) -> Dict[str, Any]:
|
| 23 |
+
return requests.post(f"{self.base_url}/step_single", json={"action": action}).json()
|
| 24 |
+
|
| 25 |
+
def state(self) -> Dict[str, Any]:
|
| 26 |
+
return requests.get(f"{self.base_url}/state").json()
|
| 27 |
+
|
| 28 |
+
def stats(self) -> Dict[str, Any]:
|
| 29 |
+
return requests.get(f"{self.base_url}/stats").json()
|
| 30 |
+
|
| 31 |
+
def list_tools(self) -> List[str]:
|
| 32 |
+
return requests.get(f"{self.base_url}/tools").json().get("tools", [])
|
| 33 |
+
|
| 34 |
+
def call_tool(self, name: str, kwargs: Dict[str, Any]) -> Any:
|
| 35 |
+
return requests.post(f"{self.base_url}/tool/{name}", json={"kwargs": kwargs}).json()
|
| 36 |
+
|
| 37 |
+
def schema(self) -> Dict[str, Any]:
|
| 38 |
+
return requests.get(f"{self.base_url}/schema").json()
|
conftest.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pytest configuration — ensure project root is on PYTHONPATH."""
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
# Add project root to path so tests can import server.*
|
| 6 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
docs/ACTION_SPACE.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Action Space
|
| 2 |
+
|
| 3 |
+
## Discrete Actions
|
| 4 |
+
|
| 5 |
+
| ID | Action | Typical Use | Cost Type |
|
| 6 |
+
|---|---|---|---|
|
| 7 |
+
| 0 | `ALLOW` | pass low-risk traffic | none |
|
| 8 |
+
| 1 | `BLOCK` | immediate deny for high-confidence malicious sessions | low |
|
| 9 |
+
| 2 | `INSPECT` | collect additional evidence before terminal decision | medium |
|
| 10 |
+
| 3 | `SANDBOX` | isolate unknown/high-risk behavior | high |
|
| 11 |
+
| 4 | `RATE_LIMIT` | mitigate volumetric or burst anomalies | low-medium |
|
| 12 |
+
| 5 | `QUARANTINE` | isolate source identity while preserving observation | medium |
|
| 13 |
+
|
| 14 |
+
Costs are computed in `reward_engine.py` as latency + compute.
|
| 15 |
+
|
| 16 |
+
## Decision Pattern
|
| 17 |
+
|
| 18 |
+
1. If confidence is high and malicious indicators are strong: `BLOCK` / `QUARANTINE`.
|
| 19 |
+
2. If confidence is low but suspicious: `INSPECT` then follow-up action.
|
| 20 |
+
3. If traffic appears benign and reputation is healthy: `ALLOW`.
|
| 21 |
+
4. If volumetric anomaly dominates: `RATE_LIMIT` before hard block.
|
| 22 |
+
|
| 23 |
+
## RL Compatibility
|
| 24 |
+
|
| 25 |
+
- `action_space` is `Discrete(6)` in single-session mode.
|
| 26 |
+
- Multi-session mode applies the same discrete action per session ID in the action map.
|
docs/API_REFERENCE.md
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# API Reference
|
| 2 |
+
|
| 3 |
+
All endpoints are implemented in `server/app.py`.
|
| 4 |
+
|
| 5 |
+
## Core Environment Endpoints
|
| 6 |
+
|
| 7 |
+
| Method | Path | Purpose |
|
| 8 |
+
|---|---|---|
|
| 9 |
+
| `POST` | `/reset` | start a new episode (`task`, optional `seed`) |
|
| 10 |
+
| `POST` | `/step` | multi-session step with action map |
|
| 11 |
+
| `POST` | `/step_single` | single-session RL step (`action`) |
|
| 12 |
+
| `GET` | `/state` | current environment snapshot |
|
| 13 |
+
| `GET` | `/tools` | discover supported tool functions |
|
| 14 |
+
| `GET` | `/health` | liveness check |
|
| 15 |
+
|
| 16 |
+
## Tool Endpoints
|
| 17 |
+
|
| 18 |
+
- `POST /tool/evaluate_session`
|
| 19 |
+
- body: `{ "kwargs": { "session_id": "..." } }`
|
| 20 |
+
- `POST /tool/take_action`
|
| 21 |
+
- body: `{ "kwargs": { "session_id": "...", "action": 1 } }`
|
| 22 |
+
- `POST /tool/get_network_stats`
|
| 23 |
+
- body: `{ "kwargs": {} }`
|
| 24 |
+
- `POST /tool/get_threat_intelligence`
|
| 25 |
+
- body: `{ "kwargs": {} }`
|
| 26 |
+
|
| 27 |
+
## Typical Loop
|
| 28 |
+
|
| 29 |
+
1. `POST /reset`
|
| 30 |
+
2. `GET /state` to list candidate sessions
|
| 31 |
+
3. `POST /tool/evaluate_session` for selected sessions
|
| 32 |
+
4. `POST /step` or `POST /step_single`
|
| 33 |
+
5. repeat until `done=true`
|
docs/ARCHITECTURE.md
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Architecture
|
| 2 |
+
|
| 3 |
+
## System Diagram
|
| 4 |
+
|
| 5 |
+
```mermaid
|
| 6 |
+
flowchart LR
|
| 7 |
+
A[TrafficGenerator] --> E[FirewallEnvironment]
|
| 8 |
+
B[ThreatEngine] --> E
|
| 9 |
+
E --> C[RewardEngine]
|
| 10 |
+
E --> D[Graders]
|
| 11 |
+
E --> F[FastAPI App]
|
| 12 |
+
F --> G[Client / Agent]
|
| 13 |
+
G --> F
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
## Runtime Data Flow
|
| 17 |
+
|
| 18 |
+
```mermaid
|
| 19 |
+
sequenceDiagram
|
| 20 |
+
participant Agent
|
| 21 |
+
participant Env as FirewallEnvironment
|
| 22 |
+
participant TG as TrafficGenerator
|
| 23 |
+
participant TH as ThreatEngine
|
| 24 |
+
participant RW as RewardEngine
|
| 25 |
+
|
| 26 |
+
Agent->>Env: reset(task, seed)
|
| 27 |
+
Env->>TG: generate_benign_sessions
|
| 28 |
+
Env->>TH: maybe_spawn_attacker + generate_attack_sessions
|
| 29 |
+
Env-->>Agent: state
|
| 30 |
+
Agent->>Env: step(action_map) or step_single(action)
|
| 31 |
+
Env->>RW: reward(action, is_malicious, budget_remaining, phase)
|
| 32 |
+
Env-->>Agent: reward, done, info, next state
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## Core Components
|
| 36 |
+
|
| 37 |
+
| Component | Responsibility | Key Outputs |
|
| 38 |
+
|---|---|---|
|
| 39 |
+
| `firewall_environment.py` | Episode orchestration, budget tracking, session lifecycle, metrics | `state()`, `step()`, `step_single()`, tool APIs |
|
| 40 |
+
| `traffic_generator.py` | Benign + malicious metadata generation, normalization, scenario shaping | 22-dim normalized observation vectors |
|
| 41 |
+
| `threat_engine.py` | Multi-attacker orchestration, adaptation, lifecycle and outcomes | Attack sessions, attacker status map |
|
| 42 |
+
| `reward_engine.py` | Multi-objective reward calculation and action-cost accounting | scalar reward + component breakdown |
|
| 43 |
+
| `graders.py` | Deterministic task scoring and pass/fail gating | score in `[0,1]`, pass constraints |
|
| 44 |
+
| `baseline/evaluate.py` | Policy benchmarking across tasks | JSON report for random/heuristic/block/allow |
|
| 45 |
+
|
| 46 |
+
## Environment Modes
|
| 47 |
+
|
| 48 |
+
- **Multi-session mode**: `step(action_map)` handles a variable batch of sessions per tick.
|
| 49 |
+
- **Single-session mode**: `step_single(action)` exposes one decision at a time with `Discrete(6)` semantics.
|
| 50 |
+
- **Inspect workflow**: inspect is first-stage evidence collection; follow-up action resolves the session.
|
docs/DEPLOYMENT.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Deployment
|
| 2 |
+
|
| 3 |
+
## Local Runtime
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
uvicorn adaptive_firewall_env.server.app:app --host 0.0.0.0 --port 8000
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
## Container Runtime
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
docker build -f src/adaptive_firewall_env/server/Dockerfile -t adaptive-firewall-env .
|
| 13 |
+
docker run --rm -p 8000:8000 adaptive-firewall-env
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
## OpenEnv Metadata
|
| 17 |
+
|
| 18 |
+
- Manifest path: `src/adaptive_firewall_env/openenv.yaml`
|
| 19 |
+
- Runtime type: FastAPI app (`server.app:app`)
|
| 20 |
+
- Default port: `8000`
|
| 21 |
+
|
| 22 |
+
## Smoke Checks
|
| 23 |
+
|
| 24 |
+
- `GET /health` returns `{ "status": "ok" }`
|
| 25 |
+
- `POST /reset` returns episode state
|
| 26 |
+
- `POST /step_single` returns next observation and reward
|
docs/REWARD_DESIGN.md
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 💰 Reward Design — Multi-Objective Optimization
|
| 2 |
+
|
| 3 |
+
The Adaptive AI Firewall environment uses a sophisticated, weighted reward function designed to drive agent behavior toward a balance of security efficacy, network availability, and resource efficiency.
|
| 4 |
+
|
| 5 |
+
## 📐 The Reward Equation
|
| 6 |
+
|
| 7 |
+
The total scalar reward $R$ for any action is calculated as:
|
| 8 |
+
|
| 9 |
+
$$R = \alpha \cdot R_{\text{security}} + \beta \cdot R_{\text{availability}} + \gamma \cdot R_{\text{efficiency}} + \delta \cdot R_{\text{timeliness}}$$
|
| 10 |
+
|
| 11 |
+
### **Default Weights**
|
| 12 |
+
| Component | Weight | Responsibility |
|
| 13 |
+
|---|---|---|
|
| 14 |
+
| $\alpha$ | **0.35** | Security Efficacy (Catching threats) |
|
| 15 |
+
| $\beta$ | **0.30** | Network Availability (Avoiding False Positives) |
|
| 16 |
+
| $\gamma$ | **0.20** | Resource Efficiency (Budget management) |
|
| 17 |
+
| $\delta$ | **0.15** | Timeliness (Stopping attacks early) |
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## 🧩 Reward Components
|
| 22 |
+
|
| 23 |
+
### **1. Security ($R_{\text{security}}$)**
|
| 24 |
+
- **Block Malicious**: $+1.0$ (Successfully stopped a threat).
|
| 25 |
+
- **Miss Malicious**: $-2.0$ (Failed to block an attack; high penalty).
|
| 26 |
+
- **Inspect Malicious**: $+0.15$ (Correct identification, though not yet stopped).
|
| 27 |
+
- **Inspect Benign**: $-0.5$ (Unnecessary inspection).
|
| 28 |
+
|
| 29 |
+
### **2. Availability ($R_{\text{availability}}$)**
|
| 30 |
+
- **Allow Benign**: $+0.25$ (Maintaining network flow).
|
| 31 |
+
- **Block Benign (FP)**: $-1.2$ (Significant penalty for disrupting legitimate users).
|
| 32 |
+
- **Rate Limit Benign**: $-0.4$ (Milder penalty for "gray" actions).
|
| 33 |
+
- **Inspect Benign**: $-0.15$ (Unnecessary latency added).
|
| 34 |
+
|
| 35 |
+
### **3. Efficiency ($R_{\text{efficiency}}$)**
|
| 36 |
+
- **Cost**: Calculated as $\text{latency} + \text{compute}$ for each action.
|
| 37 |
+
- **Scaling**: Penalized relative to remaining budget: $R_{\text{efficiency}} = -\frac{\text{cost}}{\max(\text{budget\_remaining}, 0.1)}$.
|
| 38 |
+
- This creates **Strategic Pressure**: actions become "more expensive" as the budget depletes.
|
| 39 |
+
|
| 40 |
+
### **4. Timeliness ($R_{\text{timeliness}}$)**
|
| 41 |
+
- **Early Detection**: $+e^{-\text{phase}}$ where `phase` is the attacker's progress in the kill chain (0 to 4).
|
| 42 |
+
- **Incentive**: Stopping an attack at Phase 0 is significantly more rewarding than at Phase 3.
|
| 43 |
+
|
| 44 |
+
---
|
| 45 |
+
|
| 46 |
+
## 📊 Worked Examples
|
| 47 |
+
|
| 48 |
+
| Scenario | Action | Security | Availability | Efficiency | Timeliness | **Total Reward** |
|
| 49 |
+
|---|---|---|---|---|---|---|
|
| 50 |
+
| **Legitimate User** | `ALLOW` | $0.0$ | $+0.25$ | $0.0$ | $0.0$ | **$+0.075$** |
|
| 51 |
+
| **Early Attack (Ph 0)** | `BLOCK` | $+1.0$ | $0.0$ | $-0.005$ | $+1.0$ | **$+0.499$** |
|
| 52 |
+
| **Late Attack (Ph 3)** | `BLOCK` | $+1.0$ | $0.0$ | $-0.005$ | $+0.05$ | **$+0.357$** |
|
| 53 |
+
| **False Positive** | `BLOCK` | $0.0$ | $-1.2$ | $-0.005$ | $0.0$ | **$-0.361$** |
|
| 54 |
+
| **Missed Attack** | `ALLOW` | $-2.0$ | $0.0$ | $0.0$ | $0.0$ | **$-0.700$** |
|
| 55 |
+
|
| 56 |
+
---
|
| 57 |
+
|
| 58 |
+
## 🛡️ Anti-Degeneracy Controls
|
| 59 |
+
|
| 60 |
+
To prevent agents from learning "lazy" policies (like blocking everything or allowing everything), the environment implements:
|
| 61 |
+
|
| 62 |
+
1. **Reward Balancing**: The ratio of Miss Penalty to FP Penalty is tuned (~2.3:1) so that on a typical 80/20 traffic mix, a `block_all` policy yields a negative total reward.
|
| 63 |
+
2. **Pass/Fail Constraints**: Graders in [graders.py](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/src/adaptive_firewall_env/server/graders.py) require a minimum detection rate **AND** a minimum availability rate to pass a task, regardless of the scalar reward.
|
docs/STATE_SPACE.md
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# State Space
|
| 2 |
+
|
| 3 |
+
The environment uses a **22-dimensional normalized observation vector** (`[0,1]` per feature).
|
| 4 |
+
Order is fixed by `FEATURE_ORDER` in `traffic_generator.py`.
|
| 5 |
+
|
| 6 |
+
## Feature Groups
|
| 7 |
+
|
| 8 |
+
| Group | Features | Semantics |
|
| 9 |
+
|---|---|---|
|
| 10 |
+
| Volume & timing | bytes sent/received, duration, packet count, packet variance, inter-arrival mean/jitter | throughput shape and temporal burstiness |
|
| 11 |
+
| Network metadata | src/dst ports, protocol, DNS query count, connection reuse | routing and communication pattern |
|
| 12 |
+
| TLS / certificate | TLS version, JA3 cluster, chain length, cert validity, self-signed | encrypted-session trust indicators |
|
| 13 |
+
| Behavioral context | geo distance, time of day, session history score, entropy score | reputation and anomaly context |
|
| 14 |
+
|
| 15 |
+
## Observation Interfaces
|
| 16 |
+
|
| 17 |
+
- `evaluate_session(session_id)` returns the vector for a given session.
|
| 18 |
+
- `state()` returns environment-level counters and selected session IDs.
|
| 19 |
+
- `step_single(action)` returns `observation` for the next queued session.
|
| 20 |
+
|
| 21 |
+
## Normalization Strategy
|
| 22 |
+
|
| 23 |
+
- Each raw feature is min-max normalized using bounded ranges in `FEATURE_BOUNDS`.
|
| 24 |
+
- Outliers are clipped to `[0,1]` after normalization.
|
| 25 |
+
- This enables stable neural training across heterogeneous scales (ports, durations, entropy).
|
| 26 |
+
|
| 27 |
+
## Markov Context Notes
|
| 28 |
+
|
| 29 |
+
- Single-session mode is designed for fixed-shape RL loops.
|
| 30 |
+
- Multi-session mode supports tool-driven decision systems over dynamic queues.
|
docs/TASKS.md
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Tasks
|
| 2 |
+
|
| 3 |
+
## Difficulty Levels
|
| 4 |
+
|
| 5 |
+
| Task | Steps | Threshold | Pass Constraints |
|
| 6 |
+
|---|---:|---:|---|
|
| 7 |
+
| Easy | 200 | 0.70 | detection ≥ 0.35 and fp_complement ≥ 0.65 |
|
| 8 |
+
| Medium | 500 | 0.50 | detection ≥ 0.35 and fp_complement ≥ 0.60 |
|
| 9 |
+
| Hard | 1000 | 0.45 | detection ≥ 0.35 and fp_complement ≥ 0.55 |
|
| 10 |
+
|
| 11 |
+
## Why Constraints Exist
|
| 12 |
+
|
| 13 |
+
Weighted scores alone can be gamed by degenerate policies:
|
| 14 |
+
- `allow_all` inflates availability/efficiency.
|
| 15 |
+
- `block_all` inflates detection.
|
| 16 |
+
|
| 17 |
+
The pass constraints ensure any passing policy must satisfy both:
|
| 18 |
+
1. meaningful threat detection,
|
| 19 |
+
2. acceptable benign-traffic handling.
|
| 20 |
+
|
| 21 |
+
Task scoring logic is implemented in `server/graders.py`.
|
docs/THREAT_MODELS.md
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Threat Models
|
| 2 |
+
|
| 3 |
+
## Scenario Catalog
|
| 4 |
+
|
| 5 |
+
| Scenario | Early Phase | Mid Phase | Late Phase |
|
| 6 |
+
|---|---|---|---|
|
| 7 |
+
| `port_scan_exploit_c2` | rapid probing | exploit delivery | command/control + exfil |
|
| 8 |
+
| `credential_stuffing_lateral` | auth pressure | lateral movement | persistence |
|
| 9 |
+
| `supply_chain_compromise` | stealth foothold | trusted-channel abuse | disguised exfiltration |
|
| 10 |
+
| `low_and_slow_apt` | sparse reconnaissance | long dwell C2 | slow extraction |
|
| 11 |
+
| `ddos_amplification` | reflection probes | traffic amplification | flood stage |
|
| 12 |
+
|
| 13 |
+
## Adaptation Behavior
|
| 14 |
+
|
| 15 |
+
- Repeated blocking increases attacker detection count.
|
| 16 |
+
- Detected attackers can switch to stealth mode and alter feature distributions.
|
| 17 |
+
- Attackers terminate when repeatedly blocked, time out, or complete exfiltration.
|
| 18 |
+
- Threat engine exposes per-attacker outcomes (`active`, `stopped`, `succeeded`) for analysis and credit assignment.
|
| 19 |
+
|
| 20 |
+
Threat generation and lifecycle are implemented in `server/threat_engine.py`.
|
implementation_plan.md
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adaptive AI Firewall — OpenEnv RL Challenge Compliance
|
| 2 |
+
|
| 3 |
+
## Background
|
| 4 |
+
|
| 5 |
+
Your current codebase has a solid firewall RL environment, grader, and inference agent. However, several critical areas need changes to pass the hackathon's automated validation. I've analyzed the reference repos (reasoning_gym_env, calendar_env) and the submission guidelines in detail.
|
| 6 |
+
|
| 7 |
+
## User Review Required
|
| 8 |
+
|
| 9 |
+
> [!IMPORTANT]
|
| 10 |
+
> **Ollama vs HuggingFace Router**: The hackathon guidelines mandate using the **OpenAI Client** with `API_BASE_URL` (default pointing to HuggingFace router) and `HF_TOKEN`. You mentioned wanting to use Ollama — but **the evaluation system will inject its own `API_BASE_URL` and `MODEL_NAME`** pointing to their hosted models. Your code must use the OpenAI client talking to whatever `API_BASE_URL` is provided. Ollama won't work during evaluation because the Docker container runs on HF Spaces with 2 vCPU / 8 GB RAM — no room to run a local LLM. Your current setup (HF router + OpenAI client) is **already correct**. I'll keep it as-is.
|
| 11 |
+
|
| 12 |
+
> [!WARNING]
|
| 13 |
+
> **Your `.env` file contains a real `HF_TOKEN`**. This is committed to git. You should rotate this token after we're done and add `.env` to `.gitignore`.
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## Proposed Changes
|
| 18 |
+
|
| 19 |
+
### 1. `inference.py` — Complete Rewrite (Critical)
|
| 20 |
+
|
| 21 |
+
#### [MODIFY] [inference.py](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/inference.py)
|
| 22 |
+
|
| 23 |
+
**Current problems:**
|
| 24 |
+
- ❌ No `[START]` / `[STEP]` / `[END]` output lines (the #1 compliance requirement)
|
| 25 |
+
- ❌ `API_BASE_URL` and `MODEL_NAME` have no default values (will fail validation)
|
| 26 |
+
- ❌ `HF_TOKEN` is not validated as mandatory (should raise on missing)
|
| 27 |
+
- ❌ Uses `argparse` — evaluation just runs `python inference.py`
|
| 28 |
+
- ❌ Output is JSON, not the required line format
|
| 29 |
+
|
| 30 |
+
**Changes:**
|
| 31 |
+
- Add default values: `API_BASE_URL="https://router.huggingface.co/v1"`, `MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"`
|
| 32 |
+
- Raise `ValueError` if `HF_TOKEN` is missing
|
| 33 |
+
- Print `[START]` line before episode begins
|
| 34 |
+
- Print `[STEP]` line immediately after each `env.step()` return
|
| 35 |
+
- Print `[END]` line after episode ends (even on exception, using try/finally)
|
| 36 |
+
- Format rewards to 2 decimal places, booleans as lowercase `true`/`false`
|
| 37 |
+
- Remove argparse; hardcode task or pick from env var
|
| 38 |
+
- Keep the LLM-based agent logic (get_action) but fix it to work with defaults
|
| 39 |
+
- Run all 3 tasks sequentially (easy, medium, hard) or pick the best one
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
### 2. Server — Align with OpenEnv `create_app` Pattern
|
| 44 |
+
|
| 45 |
+
#### [MODIFY] [app.py](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/src/adaptive_firewall_env/server/app.py) (primary server)
|
| 46 |
+
|
| 47 |
+
**Current problems:**
|
| 48 |
+
- ❌ Hand-rolled FastAPI endpoints — reference repos use `openenv.core.env_server.http_server.create_app()`
|
| 49 |
+
- ❌ The import chain in `server/app.py` (root) references a non-existent module path
|
| 50 |
+
|
| 51 |
+
**Changes:**
|
| 52 |
+
- **Keep the current hand-rolled server** since `create_app` requires `openenv.core.env_server.interfaces.Environment` base class and the firewall env doesn't extend it. Refactoring to use `create_app` would require significant env restructuring.
|
| 53 |
+
- Instead, fix the root `server/app.py` to correctly import from the right location
|
| 54 |
+
- Add `/web` endpoint for HF Spaces web interface compatibility
|
| 55 |
+
- Add `/schema` endpoint returning action/observation schemas
|
| 56 |
+
|
| 57 |
+
#### [MODIFY] [app.py](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/server/app.py) (root server entry)
|
| 58 |
+
|
| 59 |
+
- Fix the broken import `from adaptive_firewall_env.server.app import app`
|
| 60 |
+
- Make it correctly reference the actual app
|
| 61 |
+
|
| 62 |
+
---
|
| 63 |
+
|
| 64 |
+
### 3. Dockerfile — Production Ready for HF Spaces
|
| 65 |
+
|
| 66 |
+
#### [MODIFY] [Dockerfile](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/Dockerfile)
|
| 67 |
+
|
| 68 |
+
**Current problems:**
|
| 69 |
+
- ❌ Doesn't copy `inference.py` into the container
|
| 70 |
+
- ❌ Doesn't copy `env/`, `grader/`, `utils/` directories
|
| 71 |
+
- ❌ Heavy dependencies (torch, stable-baselines3) blow through 8 GB RAM
|
| 72 |
+
|
| 73 |
+
**Changes:**
|
| 74 |
+
- Copy ALL required source directories (`env/`, `grader/`, `utils/`, `inference.py`, `models/`)
|
| 75 |
+
- Set `PYTHONPATH` correctly
|
| 76 |
+
- Optimize requirements for smaller image size
|
| 77 |
+
- Keep `CMD` as uvicorn for the server (HF Spaces), but ensure `inference.py` can also run independently
|
| 78 |
+
|
| 79 |
+
---
|
| 80 |
+
|
| 81 |
+
### 4. Requirements — Trim for 8 GB RAM Constraint
|
| 82 |
+
|
| 83 |
+
#### [MODIFY] [requirements.txt](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/requirements.txt)
|
| 84 |
+
|
| 85 |
+
**Changes:**
|
| 86 |
+
- Remove `torch` (huge, not needed for inference — agent uses OpenAI API)
|
| 87 |
+
- Remove `stable-baselines3` (training framework, not needed at inference)
|
| 88 |
+
- Remove `shimmy` (adapter for SB3)
|
| 89 |
+
- Remove `gymnasium` (not needed if using custom env directly)
|
| 90 |
+
- Keep: `fastapi`, `uvicorn`, `numpy`, `pydantic`, `requests`, `openai`, `python-dotenv`
|
| 91 |
+
|
| 92 |
+
---
|
| 93 |
+
|
| 94 |
+
### 5. `.env.example` — Fix Defaults
|
| 95 |
+
|
| 96 |
+
#### [MODIFY] [.env.example](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/.env.example)
|
| 97 |
+
|
| 98 |
+
- Document that `HF_TOKEN` is **mandatory**
|
| 99 |
+
- Show default values for `API_BASE_URL` and `MODEL_NAME`
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
### 6. Fix Import Chain in `utils/threat_engine.py`
|
| 104 |
+
|
| 105 |
+
#### [MODIFY] [threat_engine.py](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/utils/threat_engine.py)
|
| 106 |
+
|
| 107 |
+
**Current problem:**
|
| 108 |
+
- Line 17: `from adaptive_firewall_env.server.traffic_generator import TrafficGenerator` — wrong import path
|
| 109 |
+
- Should be `from utils.data_loader import TrafficGenerator`
|
| 110 |
+
|
| 111 |
+
---
|
| 112 |
+
|
| 113 |
+
### 7. `.gitignore` — Protect Secrets
|
| 114 |
+
|
| 115 |
+
#### [MODIFY] [.gitignore](file:///c:/Users/vettrivel/Documents/GitHub/meta_ai_hackathon/.gitignore)
|
| 116 |
+
|
| 117 |
+
- Ensure `.env` is listed (prevent token leaks)
|
| 118 |
+
|
| 119 |
+
---
|
| 120 |
+
|
| 121 |
+
## Architecture Summary After Changes
|
| 122 |
+
|
| 123 |
+
```
|
| 124 |
+
meta_ai_hackathon/
|
| 125 |
+
├── inference.py ← MAIN ENTRY POINT (hackathon requirement)
|
| 126 |
+
├── Dockerfile ← HF Spaces deployment
|
| 127 |
+
├── requirements.txt ← Trimmed dependencies
|
| 128 |
+
├── openenv.yaml ← Environment manifest
|
| 129 |
+
├── .env.example ← Template with docs
|
| 130 |
+
├── env/
|
| 131 |
+
│ ├── firewall_env.py ← Core RL environment
|
| 132 |
+
│ └── models.py ← Pydantic request/response models
|
| 133 |
+
├── grader/
|
| 134 |
+
│ └── firewall_grader.py ← Scoring logic
|
| 135 |
+
├── utils/
|
| 136 |
+
│ ├── data_loader.py ← Traffic generation
|
| 137 |
+
│ ├── reward_engine.py ← Multi-objective rewards
|
| 138 |
+
│ └── threat_engine.py ← Attack orchestration (import fixed)
|
| 139 |
+
├── server/
|
| 140 |
+
│ └── app.py ← FastAPI server for HF Spaces
|
| 141 |
+
└── src/adaptive_firewall_env/server/
|
| 142 |
+
└── app.py ← Full server with LLM playground
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
## Open Questions
|
| 146 |
+
|
| 147 |
+
> [!IMPORTANT]
|
| 148 |
+
> **Which tasks to run?** The hackathon evaluator likely runs `python inference.py` without arguments. Should we:
|
| 149 |
+
> - (A) Run all 3 tasks (easy, medium, hard) sequentially and output [START]/[STEP]/[END] for each?
|
| 150 |
+
> - (B) Run only `easy` by default?
|
| 151 |
+
> - I recommend **(A)** — running all 3 tasks to maximize score visibility. Each gets its own `[START]`/`[END]` block.
|
| 152 |
+
|
| 153 |
+
> [!IMPORTANT]
|
| 154 |
+
> **Max steps per task:** Easy=200, Medium=500, Hard=1000 steps. With LLM calls at each step, this could be slow with rate limits. Should I add a timeout or fallback more aggressively to heuristics?
|
| 155 |
+
|
| 156 |
+
## Verification Plan
|
| 157 |
+
|
| 158 |
+
### Automated Tests
|
| 159 |
+
1. Run `python inference.py` and verify stdout matches the exact format:
|
| 160 |
+
```
|
| 161 |
+
[START] task=easy env=ai-firewall model=meta-llama/Llama-3.1-8B-Instruct
|
| 162 |
+
[STEP] step=1 action=ALLOW reward=0.00 done=false error=null
|
| 163 |
+
...
|
| 164 |
+
[END] success=true steps=200 rewards=0.00,0.00,...
|
| 165 |
+
```
|
| 166 |
+
2. Run `docker build -t ai-firewall .` and verify it builds under 8 GB
|
| 167 |
+
3. Run the container and hit `/health` endpoint
|
| 168 |
+
4. Verify all env vars work with defaults when `HF_TOKEN` is set
|
| 169 |
+
|
| 170 |
+
### Manual Verification
|
| 171 |
+
- Deploy to HF Spaces and confirm the space reaches "Running" state
|
| 172 |
+
- Verify the web interface loads at the space URL
|
inference.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
import textwrap
|
| 8 |
+
from typing import Any, Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from openai import OpenAI
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
|
| 14 |
+
# Import the environment directly for the AI Firewall
|
| 15 |
+
from server.firewall_environment import FirewallEnvironment, ACTIONS, TASK_CONFIGS
|
| 16 |
+
|
| 17 |
+
# --- Hackathon Submission Rules Compliance ---
|
| 18 |
+
# 1. inference.py in root directory ✅
|
| 19 |
+
# 2. Use OpenAI Client for all LLM calls ✅
|
| 20 |
+
# 3. Required Environment Variables with Defaults ✅
|
| 21 |
+
# 4. Strict Output Format: [START], [STEP], [END] ✅
|
| 22 |
+
|
| 23 |
+
load_dotenv()
|
| 24 |
+
|
| 25 |
+
# Environment Variables per Spec (defaults required for API_BASE_URL and MODEL_NAME)
|
| 26 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 27 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
|
| 28 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 29 |
+
|
| 30 |
+
if HF_TOKEN is None:
|
| 31 |
+
raise ValueError("HF_TOKEN environment variable is required")
|
| 32 |
+
|
| 33 |
+
# Benchmark configuration
|
| 34 |
+
BENCHMARK = "ai-firewall"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def format_bool(v: bool) -> str:
|
| 38 |
+
return "true" if v else "false"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def log_start(task: str, env: str, model: str) -> None:
|
| 42 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 46 |
+
error_val = error if error else "null"
|
| 47 |
+
done_val = format_bool(done)
|
| 48 |
+
print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def log_end(success: bool, steps: int, rewards: List[float]) -> None:
|
| 52 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 53 |
+
print(f"[END] success={format_bool(success)} steps={steps} rewards={rewards_str}", flush=True)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class InferenceAgent:
|
| 57 |
+
def __init__(self):
|
| 58 |
+
self.client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 59 |
+
|
| 60 |
+
def get_action(self, session_data: Dict[str, Any], threat_intel: Dict[str, Any]) -> int:
|
| 61 |
+
"""Get action using LLM via OpenAI client interface with heuristic fallback."""
|
| 62 |
+
system_prompt = textwrap.dedent(
|
| 63 |
+
"""
|
| 64 |
+
You are an adaptive AI firewall controller.
|
| 65 |
+
Respond with ONLY valid JSON in this shape: {"reasoning": string, "action": integer}.
|
| 66 |
+
Action must be one integer between 0 and 5: 0=ALLOW, 1=BLOCK, 2=INSPECT, 3=SANDBOX, 4=RATE_LIMIT, 5=QUARANTINE.
|
| 67 |
+
Keep reasoning short (under 20 words).
|
| 68 |
+
"""
|
| 69 |
+
).strip()
|
| 70 |
+
|
| 71 |
+
user_prompt = json.dumps({
|
| 72 |
+
"session": session_data,
|
| 73 |
+
"threat_intelligence": threat_intel,
|
| 74 |
+
"actions": ACTIONS
|
| 75 |
+
})
|
| 76 |
+
|
| 77 |
+
max_retries = 3
|
| 78 |
+
for attempt in range(max_retries):
|
| 79 |
+
try:
|
| 80 |
+
response = self.client.chat.completions.create(
|
| 81 |
+
model=MODEL_NAME,
|
| 82 |
+
messages=[
|
| 83 |
+
{"role": "system", "content": system_prompt},
|
| 84 |
+
{"role": "user", "content": user_prompt}
|
| 85 |
+
],
|
| 86 |
+
temperature=0.2,
|
| 87 |
+
max_tokens=150,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
raw_content = response.choices[0].message.content
|
| 91 |
+
|
| 92 |
+
# Attempt to parse JSON
|
| 93 |
+
if "```json" in raw_content:
|
| 94 |
+
raw_content = raw_content.split("```json")[1].split("```")[0].strip()
|
| 95 |
+
elif "```" in raw_content:
|
| 96 |
+
raw_content = raw_content.split("```")[1].split("```")[0].strip()
|
| 97 |
+
|
| 98 |
+
content = json.loads(raw_content)
|
| 99 |
+
action = int(content.get("action", 0))
|
| 100 |
+
return max(0, min(5, action))
|
| 101 |
+
|
| 102 |
+
except Exception as e:
|
| 103 |
+
if "429" in str(e) and attempt < max_retries - 1:
|
| 104 |
+
time.sleep(2 ** attempt)
|
| 105 |
+
continue
|
| 106 |
+
return self._heuristic_action(session_data, threat_intel)
|
| 107 |
+
|
| 108 |
+
return self._heuristic_action(session_data, threat_intel)
|
| 109 |
+
|
| 110 |
+
def _heuristic_action(self, session_data: Dict[str, Any], threat_intel: Dict[str, Any]) -> int:
|
| 111 |
+
"""Rule-based fallback with 8 detection rules."""
|
| 112 |
+
features = session_data.get("features", {})
|
| 113 |
+
known_bad_ports = set(threat_intel.get("known_bad_ports", []))
|
| 114 |
+
|
| 115 |
+
if session_data.get("revealed_malicious") is True:
|
| 116 |
+
return 1 # BLOCK
|
| 117 |
+
|
| 118 |
+
dst_port = int(features.get("dst_port", 0))
|
| 119 |
+
history = float(features.get("session_history_score", 1.0))
|
| 120 |
+
entropy = float(features.get("entropy_score", 0.0))
|
| 121 |
+
reuse = float(features.get("connection_reuse", 1.0))
|
| 122 |
+
self_signed = int(features.get("is_self_signed", 0))
|
| 123 |
+
ja3 = int(features.get("ja3_hash_cluster", 0))
|
| 124 |
+
geo = float(features.get("geo_distance", 0.0))
|
| 125 |
+
cert_valid = float(features.get("cert_validity_days", 999.0))
|
| 126 |
+
tls_ver = int(features.get("tls_version", 1))
|
| 127 |
+
dns_q = int(features.get("dns_query_count", 0))
|
| 128 |
+
dur = float(features.get("duration_ms", 500.0))
|
| 129 |
+
pkts = int(features.get("packet_count", 10))
|
| 130 |
+
|
| 131 |
+
if dst_port in known_bad_ports and history < 0.50:
|
| 132 |
+
return 1
|
| 133 |
+
if self_signed == 1 and history < 0.45:
|
| 134 |
+
return 5
|
| 135 |
+
if entropy > 0.55 and reuse < 0.25:
|
| 136 |
+
return 2
|
| 137 |
+
if geo > 4000.0 and history < 0.40:
|
| 138 |
+
return 2
|
| 139 |
+
if ja3 >= 180:
|
| 140 |
+
return 1
|
| 141 |
+
if dur < 60.0 and pkts > 100:
|
| 142 |
+
return 4
|
| 143 |
+
if cert_valid < 80.0 and tls_ver == 0:
|
| 144 |
+
return 2
|
| 145 |
+
if reuse < 0.10 and dns_q >= 4:
|
| 146 |
+
return 2
|
| 147 |
+
|
| 148 |
+
return 0 # ALLOW
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def run_task(agent: InferenceAgent, task: str):
|
| 152 |
+
"""Run a single task episode and emit spec-compliant output."""
|
| 153 |
+
seeds = {"easy": 101, "medium": 202, "hard": 303}
|
| 154 |
+
env = FirewallEnvironment(seed=seeds.get(task, 101))
|
| 155 |
+
|
| 156 |
+
log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
|
| 157 |
+
|
| 158 |
+
state = env.reset(task=task)
|
| 159 |
+
done = False
|
| 160 |
+
rewards: List[float] = []
|
| 161 |
+
steps_taken = 0
|
| 162 |
+
success = False
|
| 163 |
+
|
| 164 |
+
try:
|
| 165 |
+
while not done:
|
| 166 |
+
action = 0
|
| 167 |
+
error_msg = None
|
| 168 |
+
|
| 169 |
+
focus_session_id = state.get("focus_session_id")
|
| 170 |
+
if focus_session_id:
|
| 171 |
+
try:
|
| 172 |
+
session_data = env.evaluate_session(focus_session_id)
|
| 173 |
+
threat_intel = env.get_threat_intelligence()
|
| 174 |
+
action = agent.get_action(session_data, threat_intel)
|
| 175 |
+
result = env.step_single(action)
|
| 176 |
+
except Exception as e:
|
| 177 |
+
error_msg = str(e)
|
| 178 |
+
result = env.step_single(0)
|
| 179 |
+
else:
|
| 180 |
+
result = env.step_single(0)
|
| 181 |
+
|
| 182 |
+
reward = float(result["reward"])
|
| 183 |
+
done = bool(result["done"])
|
| 184 |
+
state = result["state"]
|
| 185 |
+
steps_taken += 1
|
| 186 |
+
rewards.append(reward)
|
| 187 |
+
|
| 188 |
+
log_step(
|
| 189 |
+
step=steps_taken,
|
| 190 |
+
action=ACTIONS.get(action, "ALLOW"),
|
| 191 |
+
reward=reward,
|
| 192 |
+
done=done,
|
| 193 |
+
error=error_msg,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
if done:
|
| 197 |
+
break
|
| 198 |
+
|
| 199 |
+
# Calculate final score via grader
|
| 200 |
+
final_stats = env.get_network_stats()
|
| 201 |
+
from server.graders import grade_stats
|
| 202 |
+
grade = grade_stats(task, final_stats)
|
| 203 |
+
# success = episode completed AND score meets threshold
|
| 204 |
+
success = grade.get("passed", False)
|
| 205 |
+
|
| 206 |
+
except Exception as e:
|
| 207 |
+
print(f"[DEBUG] Error during task {task}: {e}", file=sys.stderr)
|
| 208 |
+
success = False
|
| 209 |
+
finally:
|
| 210 |
+
log_end(success=success, steps=steps_taken, rewards=rewards)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def main():
|
| 214 |
+
try:
|
| 215 |
+
agent = InferenceAgent()
|
| 216 |
+
for task in ["easy", "medium", "hard"]:
|
| 217 |
+
run_task(agent, task)
|
| 218 |
+
except Exception as e:
|
| 219 |
+
print(f"Critical error: {e}", file=sys.stderr)
|
| 220 |
+
sys.exit(1)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
if __name__ == "__main__":
|
| 224 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
# Standard OpenEnv types (if openenv-core is installed)
|
| 7 |
+
try:
|
| 8 |
+
from openenv.core.env_server.types import Action, Observation
|
| 9 |
+
except ImportError:
|
| 10 |
+
# Fallback if not installed
|
| 11 |
+
class Action(BaseModel):
|
| 12 |
+
pass
|
| 13 |
+
class Observation(BaseModel):
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
# --- Custom Action/Observation classes as seen in video ---
|
| 17 |
+
|
| 18 |
+
class FirewallAction(Action):
|
| 19 |
+
"""Action for the AI Firewall environment."""
|
| 20 |
+
action: int = Field(..., description="Action index: 0=ALLOW, 1=BLOCK, 2=INSPECT, 3=SANDBOX, 4=RATE_LIMIT, 5=QUARANTINE")
|
| 21 |
+
session_id: Optional[str] = Field(None, description="Specific session to act upon")
|
| 22 |
+
|
| 23 |
+
class FirewallObservation(Observation):
|
| 24 |
+
"""Observation for the AI Firewall environment."""
|
| 25 |
+
features: List[float] = Field(..., description="22-dimensional normalized feature vector")
|
| 26 |
+
focus_session_id: Optional[str] = Field(None, description="ID of the session currently in focus")
|
| 27 |
+
|
| 28 |
+
# --- Original models from env/models.py ---
|
| 29 |
+
|
| 30 |
+
class ActionRecord(BaseModel):
|
| 31 |
+
tick: int
|
| 32 |
+
session_id: str
|
| 33 |
+
action: int
|
| 34 |
+
action_name: str
|
| 35 |
+
malicious: bool
|
| 36 |
+
reward: float
|
| 37 |
+
components: Dict[str, float]
|
| 38 |
+
|
| 39 |
+
class ResetRequest(BaseModel):
|
| 40 |
+
task: str = Field(default="easy", description="Task difficulty: easy, medium, hard")
|
| 41 |
+
seed: Optional[int] = Field(default=None, description="Random seed for reproducibility")
|
| 42 |
+
|
| 43 |
+
class StepRequest(BaseModel):
|
| 44 |
+
actions: Dict[str, int] = Field(default_factory=dict, description="Map of session_id to action index")
|
| 45 |
+
|
| 46 |
+
class StepSingleRequest(BaseModel):
|
| 47 |
+
action: int = Field(..., description="Action index (0-5) for the current focus session")
|
| 48 |
+
|
| 49 |
+
class ToolRequest(BaseModel):
|
| 50 |
+
kwargs: Dict[str, Any] = Field(default_factory=dict, description="Arguments for the tool call")
|
| 51 |
+
|
| 52 |
+
class StateResponse(BaseModel):
|
| 53 |
+
episode_id: int
|
| 54 |
+
task: str
|
| 55 |
+
step_count: int
|
| 56 |
+
current_tick: int
|
| 57 |
+
observation_dim: int
|
| 58 |
+
num_actions: int
|
| 59 |
+
budget_remaining: float
|
| 60 |
+
total_reward: float
|
| 61 |
+
pending_session_count: int
|
| 62 |
+
inspected_session_count: int
|
| 63 |
+
pending_session_ids: List[str]
|
| 64 |
+
inspected_session_ids: List[str]
|
| 65 |
+
queue_length: int
|
| 66 |
+
focus_session_id: Optional[str]
|
| 67 |
+
focus_observation: List[float]
|
| 68 |
+
|
| 69 |
+
class StepResponse(BaseModel):
|
| 70 |
+
reward: float
|
| 71 |
+
done: bool
|
| 72 |
+
state: StateResponse
|
| 73 |
+
info: Dict[str, Any]
|
| 74 |
+
|
| 75 |
+
class StepSingleResponse(BaseModel):
|
| 76 |
+
observation: List[float]
|
| 77 |
+
reward: float
|
| 78 |
+
done: bool
|
| 79 |
+
state: StateResponse
|
| 80 |
+
info: Dict[str, Any]
|
| 81 |
+
|
| 82 |
+
class EvaluateSessionResponse(BaseModel):
|
| 83 |
+
session_id: str
|
| 84 |
+
features: Dict[str, Any]
|
| 85 |
+
observation: List[float]
|
| 86 |
+
is_inspected: bool
|
| 87 |
+
revealed_malicious: Optional[bool]
|
| 88 |
+
expires_tick: int
|
| 89 |
+
|
| 90 |
+
class NetworkStatsResponse(BaseModel):
|
| 91 |
+
episode_id: int
|
| 92 |
+
task: str
|
| 93 |
+
tick: int
|
| 94 |
+
step_count: int
|
| 95 |
+
total_reward: float
|
| 96 |
+
budget_remaining: float
|
| 97 |
+
budget_used_pct: float
|
| 98 |
+
total_malicious: int
|
| 99 |
+
total_benign: int
|
| 100 |
+
detection_rate: float
|
| 101 |
+
false_positive_rate: float
|
| 102 |
+
efficiency: float
|
| 103 |
+
early_detection_bonus: float
|
| 104 |
+
cascade_prevention: float
|
| 105 |
+
correct_allows: int
|
| 106 |
+
inspections: int
|
| 107 |
+
expired_malicious: int
|
| 108 |
+
expired_benign: int
|
| 109 |
+
|
| 110 |
+
class HealthResponse(BaseModel):
|
| 111 |
+
status: str
|
| 112 |
+
version: str
|
| 113 |
+
|
| 114 |
+
class ToolsListResponse(BaseModel):
|
| 115 |
+
tools: List[str]
|
| 116 |
+
|
| 117 |
+
class TakeActionResponse(BaseModel):
|
| 118 |
+
reward: float
|
| 119 |
+
record: ActionRecord
|
| 120 |
+
|
| 121 |
+
class LLMChatRequest(BaseModel):
|
| 122 |
+
prompt: str
|
| 123 |
+
api_key: Optional[str] = None
|
| 124 |
+
base_url: Optional[str] = None
|
| 125 |
+
model: Optional[str] = None
|
| 126 |
+
|
| 127 |
+
class LLMChatResponse(BaseModel):
|
| 128 |
+
content: str
|
| 129 |
+
model: str
|
| 130 |
+
|
| 131 |
+
class LLMConfigResponse(BaseModel):
|
| 132 |
+
base_url: str
|
| 133 |
+
model: str
|
| 134 |
+
has_api_key: bool
|
| 135 |
+
|
| 136 |
+
class LLMTestResponse(BaseModel):
|
| 137 |
+
ok: bool
|
| 138 |
+
model: str
|
| 139 |
+
content: str
|
openenv.yaml
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: ai-firewall-openenv
|
| 3 |
+
version: "1.0.0"
|
| 4 |
+
description: "AI-driven adaptive firewall for automated threat detection"
|
| 5 |
+
type: space
|
| 6 |
+
runtime: fastapi
|
| 7 |
+
app: server.app:app
|
| 8 |
+
port: 7860
|
| 9 |
+
|
| 10 |
+
tasks:
|
| 11 |
+
easy:
|
| 12 |
+
name: "Perimeter Defense"
|
| 13 |
+
description: "200-step episode with obvious attacks"
|
| 14 |
+
grading_seed: 101
|
| 15 |
+
threshold: 0.70
|
| 16 |
+
medium:
|
| 17 |
+
name: "Mixed Threat Landscape"
|
| 18 |
+
description: "500-step episode with multi-stage attacks and ambiguous traffic"
|
| 19 |
+
grading_seed: 202
|
| 20 |
+
threshold: 0.50
|
| 21 |
+
hard:
|
| 22 |
+
name: "Advanced Persistent Threat"
|
| 23 |
+
description: "1000-step episode with adaptive APTs and stealth threats"
|
| 24 |
+
grading_seed: 303
|
| 25 |
+
threshold: 0.45
|
| 26 |
+
|
| 27 |
+
tools:
|
| 28 |
+
- name: evaluate_session
|
| 29 |
+
description: "Get detailed features and observation for a specific session"
|
| 30 |
+
- name: take_action
|
| 31 |
+
description: "Apply a firewall action to a session (ALLOW, BLOCK, etc.)"
|
| 32 |
+
- name: get_network_stats
|
| 33 |
+
description: "Get cumulative episode statistics and performance metrics"
|
| 34 |
+
- name: get_threat_intelligence
|
| 35 |
+
description: "Access current threat intelligence feed"
|
| 36 |
+
|
| 37 |
+
observation_space:
|
| 38 |
+
type: box
|
| 39 |
+
shape: [22]
|
| 40 |
+
low: 0.0
|
| 41 |
+
high: 1.0
|
| 42 |
+
|
| 43 |
+
action_space:
|
| 44 |
+
type: discrete
|
| 45 |
+
n: 6
|
| 46 |
+
labels:
|
| 47 |
+
0: ALLOW
|
| 48 |
+
1: BLOCK
|
| 49 |
+
2: INSPECT
|
| 50 |
+
3: SANDBOX
|
| 51 |
+
4: RATE_LIMIT
|
| 52 |
+
5: QUARANTINE
|
progresss.md
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Implementation Progress
|
| 2 |
+
|
| 3 |
+
## Status
|
| 4 |
+
|
| 5 |
+
- Completed: project scaffolding and package manifest
|
| 6 |
+
- Completed: core server environment (`traffic_generator`, `threat_engine`, `reward_engine`, `firewall_environment`, `graders`, `app`)
|
| 7 |
+
- Completed: baseline policies (`random_agent`, `heuristic_agent`) and evaluator
|
| 8 |
+
- Completed: OpenEnv config, Dockerfile, requirements, client wrapper
|
| 9 |
+
- Completed: docs and AI skill/workflow files
|
| 10 |
+
- Completed: syntax verification with `py -m compileall src tests`
|
| 11 |
+
- Completed: baseline end-to-end evaluation run
|
| 12 |
+
- Completed: virtual environment created at `.venv` using `py -m venv .venv` with `PYTHONDONTWRITEBYTECODE=1`
|
| 13 |
+
- Completed: toolchain installed inside `.venv` (`pytest`, `ruff`, `requests`, `numpy`, `scipy`, `fastapi`, `pydantic`, `uvicorn`)
|
| 14 |
+
- Completed: `pytest` validation passed (`5 passed`)
|
| 15 |
+
- Completed: `ruff check src tests` passed (`All checks passed!`)
|
| 16 |
+
- Completed: runtime smoke test for reset/step (`ok 22 False`)
|
| 17 |
+
- Completed: REVIEW_AND_TODO P0/P1 core fixes implemented (budget scaling, inspect flow, expiration metrics, PYTHONPATH stability, reward rebalance)
|
| 18 |
+
- Completed: scenario-aware threat/traffic behavior and adaptive attacker lifecycle improvements
|
| 19 |
+
- Completed: one-session-per-step mode (`step_single`) and framework spaces (`observation_space`/`action_space`)
|
| 20 |
+
- Completed: new integration safeguards (`always_block`/`always_allow`) in baseline evaluator
|
| 21 |
+
- Completed: expanded automated tests from 5 to 16 and all passing in `.venv`
|
| 22 |
+
- Completed: latest validation (`pytest`: 16 passed, `ruff`: all checks passed)
|
| 23 |
+
- Completed: compatibility fixes after refactor (`__init__ budget arg`, inspect dual-pool consistency, `step_single` focus observation state field)
|
| 24 |
+
- Completed: comprehensive test suite now fully green (`pytest`: 38 passed)
|
| 25 |
+
- Completed: lint cleanup across source and consolidated tests (`ruff`: all checks passed)
|
| 26 |
+
- Completed: grading anti-degeneracy gates (pass constraints for detection + false-positive complement)
|
| 27 |
+
- Completed: evaluator now confirms heuristic passes all tasks while random/block-all/allow-all fail pass gates
|
| 28 |
+
- Completed: docs + skills + workflows significantly expanded from stubs to implementation-level guidance
|
| 29 |
+
- Completed: hackathon compliance changes implemented (inference.py, Dockerfile, requirements, .env.example, .gitignore)
|
| 30 |
+
- Completed: server endpoints added (/web, /schema) and root import fix
|
| 31 |
+
- Completed: all blocking issues from `REVIEW_AND_TODO.md` resolved
|
| 32 |
+
- Completed: refactored project structure to match OpenEnv standard layout (models.py at root, environment in server/)
|
| 33 |
+
- Completed: consolidated all environment logic into `server/` and removed redundant directories
|
| 34 |
+
- Completed: updated Web Playground UI to match the standard OpenEnv interface
|
| 35 |
+
- Completed: verified system logic with `inference.py` output and FastAPI health checks
|
| 36 |
+
- Completed: verified project structure and syntax with `py_compile`
|
| 37 |
+
- Completed: implemented local Ollama/Qwen support as default LLM with remote fallback
|
| 38 |
+
- Completed: updated `.env.example` with Ollama/Qwen configuration options
|
| 39 |
+
|
| 40 |
+
## Decisions Applied
|
| 41 |
+
|
| 42 |
+
- Action space kept at 6 actions
|
| 43 |
+
- Observation space kept at 22 features
|
| 44 |
+
- OpenEnv target aligned to `openenv-core[core]>=0.2.2`
|
| 45 |
+
- Runtime mode set to CPU-oriented implementation
|
| 46 |
+
- Episode lengths follow 200/500/1000 task defaults
|
| 47 |
+
- Efficiency now remains non-zero for non-degenerate policies via scaled budget model
|
| 48 |
+
- Dependency cleanup: removed `scipy` from project dependency lists (unused)
|
| 49 |
+
- Pass/fail now requires both score threshold and minimum detection/availability constraints
|
pyproject.toml
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "adaptive_firewall_env"
|
| 3 |
+
version = "0.2.0"
|
| 4 |
+
description = "Adaptive AI Firewall RL environment for encrypted traffic decision making"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.11"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"fastapi>=0.112",
|
| 9 |
+
"uvicorn>=0.30",
|
| 10 |
+
"numpy>=1.26",
|
| 11 |
+
"pydantic>=2.0",
|
| 12 |
+
"requests>=2.32",
|
| 13 |
+
"openai>=1.30",
|
| 14 |
+
"python-dotenv>=1.0",
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
[project.scripts]
|
| 18 |
+
server = "server.app:main"
|
| 19 |
+
|
| 20 |
+
[build-system]
|
| 21 |
+
requires = ["hatchling"]
|
| 22 |
+
build-backend = "hatchling.build"
|
| 23 |
+
|
| 24 |
+
[tool.hatch.build.targets.wheel]
|
| 25 |
+
packages = ["server"]
|
| 26 |
+
|
| 27 |
+
[tool.pytest.ini_options]
|
| 28 |
+
pythonpath = ["."]
|
| 29 |
+
testpaths = ["tests"]
|
| 30 |
+
|
| 31 |
+
[tool.ruff]
|
| 32 |
+
line-length = 120
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
numpy
|
| 4 |
+
pydantic
|
| 5 |
+
requests
|
| 6 |
+
openai
|
| 7 |
+
python-dotenv
|
scripts/validate-submission.sh
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
#
|
| 3 |
+
# validate-submission.sh — OpenEnv Submission Validator
|
| 4 |
+
#
|
| 5 |
+
# Checks that your HF Space is live, Docker image builds, and openenv validate passes.
|
| 6 |
+
#
|
| 7 |
+
# Prerequisites:
|
| 8 |
+
# - Docker: `https://docs.docker.com/get-docker/`
|
| 9 |
+
# - openenv-core: pip install openenv-core
|
| 10 |
+
# - curl (usually pre-installed)
|
| 11 |
+
#
|
| 12 |
+
# Run:
|
| 13 |
+
# curl -fsSL `https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh` | bash -s -- <ping_url> [repo_dir]
|
| 14 |
+
#
|
| 15 |
+
# Or download and run locally:
|
| 16 |
+
# chmod +x validate-submission.sh
|
| 17 |
+
# ./validate-submission.sh <ping_url> [repo_dir]
|
| 18 |
+
#
|
| 19 |
+
# Arguments:
|
| 20 |
+
# ping_url Your HuggingFace Space URL (e.g. `https://your-space.hf.space)`
|
| 21 |
+
# repo_dir Path to your repo (default: current directory)
|
| 22 |
+
#
|
| 23 |
+
# Examples:
|
| 24 |
+
# ./validate-submission.sh `https://my-team.hf.space`
|
| 25 |
+
# ./validate-submission.sh `https://my-team.hf.space` ./my-repo
|
| 26 |
+
#
|
| 27 |
+
|
| 28 |
+
set -uo pipefail
|
| 29 |
+
|
| 30 |
+
DOCKER_BUILD_TIMEOUT=3600
|
| 31 |
+
if [ -t 1 ]; then
|
| 32 |
+
RED='\033[0;31m'
|
| 33 |
+
GREEN='\033[0;32m'
|
| 34 |
+
YELLOW='\033[1;33m'
|
| 35 |
+
BOLD='\033[1m'
|
| 36 |
+
NC='\033[0m'
|
| 37 |
+
else
|
| 38 |
+
RED='' GREEN='' YELLOW='' BOLD='' NC=''
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
run_with_timeout() {
|
| 42 |
+
local secs="$1"; shift
|
| 43 |
+
if command -v timeout &>/dev/null; then
|
| 44 |
+
timeout "$secs" "$@"
|
| 45 |
+
elif command -v gtimeout &>/dev/null; then
|
| 46 |
+
gtimeout "$secs" "$@"
|
| 47 |
+
else
|
| 48 |
+
"$@" &
|
| 49 |
+
local pid=$!
|
| 50 |
+
( sleep "$secs" && kill "$pid" 2>/dev/null ) &
|
| 51 |
+
local watcher=$!
|
| 52 |
+
wait "$pid" 2>/dev/null
|
| 53 |
+
local rc=$?
|
| 54 |
+
kill "$watcher" 2>/dev/null
|
| 55 |
+
wait "$watcher" 2>/dev/null
|
| 56 |
+
return $rc
|
| 57 |
+
fi
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
portable_mktemp() {
|
| 61 |
+
local prefix="${1:-validate}"
|
| 62 |
+
mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
CLEANUP_FILES=()
|
| 66 |
+
cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
|
| 67 |
+
trap cleanup EXIT
|
| 68 |
+
|
| 69 |
+
PING_URL="${1:-}"
|
| 70 |
+
REPO_DIR="${2:-.}"
|
| 71 |
+
|
| 72 |
+
if [ -z "$PING_URL" ]; then
|
| 73 |
+
printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
|
| 74 |
+
printf "\n"
|
| 75 |
+
printf " ping_url Your HuggingFace Space URL (e.g. `https://your-space.hf.space)\n` "
|
| 76 |
+
printf " repo_dir Path to your repo (default: current directory)\n"
|
| 77 |
+
exit 1
|
| 78 |
+
fi
|
| 79 |
+
|
| 80 |
+
if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
|
| 81 |
+
printf "Error: directory '%s' not found\n" "${2:-.}"
|
| 82 |
+
exit 1
|
| 83 |
+
fi
|
| 84 |
+
PING_URL="${PING_URL%/}"
|
| 85 |
+
export PING_URL
|
| 86 |
+
PASS=0
|
| 87 |
+
|
| 88 |
+
log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
|
| 89 |
+
pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
|
| 90 |
+
fail() { log "${RED}FAILED${NC} -- $1"; }
|
| 91 |
+
hint() { printf " ${YELLOW}Hint:${NC} %b\n" "$1"; }
|
| 92 |
+
stop_at() {
|
| 93 |
+
printf "\n"
|
| 94 |
+
printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
|
| 95 |
+
exit 1
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
printf "\n"
|
| 99 |
+
printf "${BOLD}========================================${NC}\n"
|
| 100 |
+
printf "${BOLD} OpenEnv Submission Validator${NC}\n"
|
| 101 |
+
printf "${BOLD}========================================${NC}\n"
|
| 102 |
+
log "Repo: $REPO_DIR"
|
| 103 |
+
log "Ping URL: $PING_URL"
|
| 104 |
+
printf "\n"
|
| 105 |
+
|
| 106 |
+
log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
|
| 107 |
+
|
| 108 |
+
CURL_OUTPUT=$(portable_mktemp "validate-curl")
|
| 109 |
+
CLEANUP_FILES+=("$CURL_OUTPUT")
|
| 110 |
+
HTTP_CODE=$(curl.exe -s -o /dev/null -w "%{http_code}" -X POST \
|
| 111 |
+
-H "Content-Type: application/json" -d "{\"task\":\"easy\"}" \
|
| 112 |
+
"$PING_URL/reset" --max-time 30 || printf "000")
|
| 113 |
+
HTTP_CODE=$(echo $HTTP_CODE | tr -d '\r' | cut -c 1-3)
|
| 114 |
+
|
| 115 |
+
if [ "$HTTP_CODE" = "200" ]; then
|
| 116 |
+
pass "HF Space is live and responds to /reset"
|
| 117 |
+
elif [ "$HTTP_CODE" = "000" ]; then
|
| 118 |
+
fail "HF Space not reachable (connection failed or timed out)"
|
| 119 |
+
hint "Check your network connection and that the Space is running."
|
| 120 |
+
hint "Try: curl -s -o /dev/null -w '%%{http_code}' -X POST $PING_URL/reset"
|
| 121 |
+
stop_at "Step 1"
|
| 122 |
+
else
|
| 123 |
+
fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
|
| 124 |
+
hint "Make sure your Space is running and the URL is correct."
|
| 125 |
+
hint "Try opening $PING_URL in your browser first."
|
| 126 |
+
stop_at "Step 1"
|
| 127 |
+
fi
|
| 128 |
+
|
| 129 |
+
log "${BOLD}Step 2/3: Running docker build${NC} ..."
|
| 130 |
+
|
| 131 |
+
if ! command -v docker &>/dev/null; then
|
| 132 |
+
fail "docker command not found"
|
| 133 |
+
hint "Install Docker: `https://docs.docker.com/get-docker/` "
|
| 134 |
+
stop_at "Step 2"
|
| 135 |
+
fi
|
| 136 |
+
|
| 137 |
+
if [ -f "$REPO_DIR/Dockerfile" ]; then
|
| 138 |
+
DOCKER_CONTEXT="$REPO_DIR"
|
| 139 |
+
elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
|
| 140 |
+
DOCKER_CONTEXT="$REPO_DIR/server"
|
| 141 |
+
else
|
| 142 |
+
fail "No Dockerfile found in repo root or server/ directory"
|
| 143 |
+
stop_at "Step 2"
|
| 144 |
+
fi
|
| 145 |
+
|
| 146 |
+
log " Found Dockerfile in $DOCKER_CONTEXT"
|
| 147 |
+
|
| 148 |
+
BUILD_OK=false
|
| 149 |
+
BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true
|
| 150 |
+
|
| 151 |
+
if [ "$BUILD_OK" = true ]; then
|
| 152 |
+
pass "Docker build succeeded"
|
| 153 |
+
else
|
| 154 |
+
fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
|
| 155 |
+
printf "%s\n" "$BUILD_OUTPUT" | tail -20
|
| 156 |
+
stop_at "Step 2"
|
| 157 |
+
fi
|
| 158 |
+
|
| 159 |
+
log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
|
| 160 |
+
|
| 161 |
+
if ! command -v openenv &>/dev/null; then
|
| 162 |
+
fail "openenv command not found"
|
| 163 |
+
hint "Install it: pip install openenv-core"
|
| 164 |
+
stop_at "Step 3"
|
| 165 |
+
fi
|
| 166 |
+
|
| 167 |
+
VALIDATE_OK=false
|
| 168 |
+
VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true
|
| 169 |
+
|
| 170 |
+
if [ "$VALIDATE_OK" = true ]; then
|
| 171 |
+
pass "openenv validate passed"
|
| 172 |
+
[ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
|
| 173 |
+
else
|
| 174 |
+
fail "openenv validate failed"
|
| 175 |
+
printf "%s\n" "$VALIDATE_OUTPUT"
|
| 176 |
+
stop_at "Step 3"
|
| 177 |
+
fi
|
| 178 |
+
|
| 179 |
+
printf "\n"
|
| 180 |
+
printf "${BOLD}========================================${NC}\n"
|
| 181 |
+
printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
|
| 182 |
+
printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
|
| 183 |
+
printf "${BOLD}========================================${NC}\n"
|
| 184 |
+
printf "\n"
|
| 185 |
+
|
| 186 |
+
exit 0
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Package marker
|
server/app.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI server exposing the Adaptive AI Firewall environment.
|
| 2 |
+
|
| 3 |
+
Endpoints:
|
| 4 |
+
POST /reset — Start a new episode
|
| 5 |
+
POST /step — Multi-session step (batch actions)
|
| 6 |
+
POST /step_single — Single-session step (Gymnasium-compatible)
|
| 7 |
+
GET /state — Current environment state
|
| 8 |
+
GET /tools — List available tool names
|
| 9 |
+
POST /tool/{name} — Call a specific tool
|
| 10 |
+
GET /health — Health check
|
| 11 |
+
GET /stats — Current episode statistics
|
| 12 |
+
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from fastapi import FastAPI, HTTPException
|
| 19 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 20 |
+
from fastapi.responses import HTMLResponse
|
| 21 |
+
from dotenv import load_dotenv
|
| 22 |
+
|
| 23 |
+
from server.firewall_environment import FirewallEnvironment, ACTIONS
|
| 24 |
+
from models import (
|
| 25 |
+
HealthResponse,
|
| 26 |
+
NetworkStatsResponse,
|
| 27 |
+
ResetRequest,
|
| 28 |
+
StateResponse,
|
| 29 |
+
StepRequest,
|
| 30 |
+
StepResponse,
|
| 31 |
+
StepSingleRequest,
|
| 32 |
+
StepSingleResponse,
|
| 33 |
+
ToolRequest,
|
| 34 |
+
ToolsListResponse,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
load_dotenv()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _clean_env_value(value: str) -> str:
|
| 41 |
+
return value.strip().strip("`").strip().strip("'").strip('"').strip()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _resolve_api_key(value: str | None) -> str:
|
| 45 |
+
return _clean_env_value(value or os.getenv("HF_TOKEN") or "")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _resolve_model(value: str | None) -> str:
|
| 49 |
+
return _clean_env_value(value or os.getenv("MODEL_NAME") or "")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _resolve_base_url(value: str | None) -> str:
|
| 53 |
+
return _clean_env_value(
|
| 54 |
+
value
|
| 55 |
+
or os.getenv("API_BASE_URL")
|
| 56 |
+
or ""
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
PLAYGROUND_HTML = """<!doctype html>
|
| 60 |
+
<html lang="en">
|
| 61 |
+
<head>
|
| 62 |
+
<meta charset="utf-8"/>
|
| 63 |
+
<meta name="viewport" content="width=device-width,initial-scale=1"/>
|
| 64 |
+
<title>Adaptive Firewall Playground</title>
|
| 65 |
+
<style>
|
| 66 |
+
body{font-family:Arial,sans-serif;background:#0b1220;color:#e5e7eb;margin:0;padding:24px}
|
| 67 |
+
.card{max-width:980px;margin:0 auto;background:#111827;border:1px solid #1f2937;border-radius:12px;padding:18px}
|
| 68 |
+
h1{margin-top:0;font-size:22px}
|
| 69 |
+
label{display:block;font-size:12px;margin:10px 0 4px}
|
| 70 |
+
input,textarea,button{width:100%;box-sizing:border-box;border-radius:8px;border:1px solid #374151;background:#0f172a;color:#e5e7eb;padding:10px}
|
| 71 |
+
textarea{min-height:120px;resize:vertical}
|
| 72 |
+
button{background:#2563eb;border:none;cursor:pointer;font-weight:600;margin-top:12px}
|
| 73 |
+
button:disabled{opacity:.6;cursor:not-allowed}
|
| 74 |
+
pre{white-space:pre-wrap;background:#0f172a;border:1px solid #374151;border-radius:8px;padding:12px;min-height:120px;overflow:auto}
|
| 75 |
+
.grid{display:grid;grid-template-columns:1fr 1fr;gap:10px}
|
| 76 |
+
.row{display:grid;grid-template-columns:1fr 1fr 1fr;gap:10px}
|
| 77 |
+
.muted{font-size:12px;color:#93c5fd}
|
| 78 |
+
.ok{color:#86efac}
|
| 79 |
+
.bad{color:#fca5a5}
|
| 80 |
+
.btn-step{background:#22c55e}
|
| 81 |
+
.btn-reset{background:#64748b}
|
| 82 |
+
.btn-state{background:#64748b}
|
| 83 |
+
</style>
|
| 84 |
+
</head>
|
| 85 |
+
<body>
|
| 86 |
+
<div class="card">
|
| 87 |
+
<h1>Playground</h1>
|
| 88 |
+
<p class="muted">Click Reset to start a new episode.</p>
|
| 89 |
+
|
| 90 |
+
<label>Message / Action ID</label>
|
| 91 |
+
<input id="action_input" type="number" value="0" min="0" max="5" placeholder="Enter action index (0-5)..." />
|
| 92 |
+
|
| 93 |
+
<div class="row">
|
| 94 |
+
<button id="btn_step" class="btn-step">Step</button>
|
| 95 |
+
<button id="btn_reset" class="btn-reset">Reset</button>
|
| 96 |
+
<button id="btn_state" class="btn-state">Get state</button>
|
| 97 |
+
</div>
|
| 98 |
+
|
| 99 |
+
<div id="status" class="muted" style="margin-top:10px">Ready</div>
|
| 100 |
+
|
| 101 |
+
<label>Raw JSON response</label>
|
| 102 |
+
<pre id="output">{}</pre>
|
| 103 |
+
</div>
|
| 104 |
+
|
| 105 |
+
<script>
|
| 106 |
+
const output = document.getElementById("output");
|
| 107 |
+
const status = document.getElementById("status");
|
| 108 |
+
const actionInput = document.getElementById("action_input");
|
| 109 |
+
|
| 110 |
+
async function call(path, method='GET', body=null) {
|
| 111 |
+
status.textContent = "Calling " + path + "...";
|
| 112 |
+
try {
|
| 113 |
+
const options = {
|
| 114 |
+
method: method,
|
| 115 |
+
headers: {"Content-Type":"application/json"}
|
| 116 |
+
};
|
| 117 |
+
if (body) options.body = JSON.stringify(body);
|
| 118 |
+
|
| 119 |
+
const res = await fetch(path, options);
|
| 120 |
+
const data = await res.json();
|
| 121 |
+
output.textContent = JSON.stringify(data, null, 2);
|
| 122 |
+
status.textContent = "Success";
|
| 123 |
+
return data;
|
| 124 |
+
} catch (err) {
|
| 125 |
+
status.textContent = "Error: " + err;
|
| 126 |
+
output.textContent = String(err);
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
document.getElementById("btn_step").onclick = () => {
|
| 131 |
+
const action = parseInt(actionInput.value);
|
| 132 |
+
call("/step_single", "POST", {action: action});
|
| 133 |
+
};
|
| 134 |
+
|
| 135 |
+
document.getElementById("btn_reset").onclick = () => {
|
| 136 |
+
call("/reset", "POST", {task: "easy"});
|
| 137 |
+
};
|
| 138 |
+
|
| 139 |
+
document.getElementById("btn_state").onclick = () => {
|
| 140 |
+
call("/state", "GET");
|
| 141 |
+
};
|
| 142 |
+
</script>
|
| 143 |
+
</body>
|
| 144 |
+
</html>"""
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
env = FirewallEnvironment(seed=42)
|
| 148 |
+
app = FastAPI(
|
| 149 |
+
title="Adaptive AI Firewall OpenEnv",
|
| 150 |
+
version="0.2.0",
|
| 151 |
+
description="RL environment for adaptive firewall decision making on encrypted traffic.",
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
app.add_middleware(
|
| 155 |
+
CORSMiddleware,
|
| 156 |
+
allow_origins=["*"],
|
| 157 |
+
allow_credentials=True,
|
| 158 |
+
allow_methods=["*"],
|
| 159 |
+
allow_headers=["*"],
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@app.get("/health", response_model=HealthResponse)
|
| 164 |
+
def health() -> HealthResponse:
|
| 165 |
+
return HealthResponse(status="ok", version="0.2.0")
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
@app.post("/reset", response_model=StateResponse)
|
| 169 |
+
def reset(request: ResetRequest) -> StateResponse:
|
| 170 |
+
try:
|
| 171 |
+
state = env.reset(task=request.task, seed=request.seed)
|
| 172 |
+
return StateResponse(**state)
|
| 173 |
+
except ValueError as e:
|
| 174 |
+
raise HTTPException(status_code=400, detail=str(e)) from e
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@app.post("/step", response_model=StepResponse)
|
| 178 |
+
def step(request: StepRequest) -> StepResponse:
|
| 179 |
+
result = env.step(action_map=request.actions)
|
| 180 |
+
return StepResponse(**result)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
@app.post("/step_single", response_model=StepSingleResponse)
|
| 184 |
+
def step_single(request: StepSingleRequest) -> StepSingleResponse:
|
| 185 |
+
try:
|
| 186 |
+
result = env.step_single(action=request.action)
|
| 187 |
+
return StepSingleResponse(**result)
|
| 188 |
+
except ValueError as e:
|
| 189 |
+
raise HTTPException(status_code=400, detail=str(e)) from e
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@app.get("/state", response_model=StateResponse)
|
| 193 |
+
def state() -> StateResponse:
|
| 194 |
+
return StateResponse(**env.state())
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@app.get("/stats", response_model=NetworkStatsResponse)
|
| 198 |
+
def stats() -> NetworkStatsResponse:
|
| 199 |
+
return NetworkStatsResponse(**env.get_network_stats())
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@app.get("/tools", response_model=ToolsListResponse)
|
| 203 |
+
def list_tools() -> ToolsListResponse:
|
| 204 |
+
return ToolsListResponse(tools=env.list_tools())
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@app.get("/web", response_class=HTMLResponse)
|
| 208 |
+
def web_interface() -> HTMLResponse:
|
| 209 |
+
return HTMLResponse(content=PLAYGROUND_HTML)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@app.get("/schema")
|
| 213 |
+
def schema() -> Any:
|
| 214 |
+
return {
|
| 215 |
+
"observation_space": {
|
| 216 |
+
"type": "Box",
|
| 217 |
+
"shape": [22],
|
| 218 |
+
"low": 0.0,
|
| 219 |
+
"high": 1.0,
|
| 220 |
+
},
|
| 221 |
+
"action_space": {
|
| 222 |
+
"type": "Discrete",
|
| 223 |
+
"n": 6,
|
| 224 |
+
"actions": ACTIONS,
|
| 225 |
+
},
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
@app.post("/tool/{name}")
|
| 230 |
+
def call_tool(name: str, request: ToolRequest) -> Any:
|
| 231 |
+
try:
|
| 232 |
+
if name == "evaluate_session":
|
| 233 |
+
return env.evaluate_session(request.kwargs["session_id"])
|
| 234 |
+
if name == "take_action":
|
| 235 |
+
reward, record = env.take_action(
|
| 236 |
+
session_id=request.kwargs["session_id"],
|
| 237 |
+
action=int(request.kwargs["action"]),
|
| 238 |
+
)
|
| 239 |
+
return {"reward": reward, "record": record}
|
| 240 |
+
if name == "get_network_stats":
|
| 241 |
+
return env.get_network_stats()
|
| 242 |
+
if name == "get_threat_intelligence":
|
| 243 |
+
return env.get_threat_intelligence()
|
| 244 |
+
raise HTTPException(status_code=404, detail=f"unknown tool: {name}")
|
| 245 |
+
except KeyError as exc:
|
| 246 |
+
raise HTTPException(status_code=400, detail=f"missing key: {exc}") from exc
|
| 247 |
+
except ValueError as exc:
|
| 248 |
+
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def main() -> None:
|
| 252 |
+
import uvicorn
|
| 253 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
if __name__ == "__main__":
|
| 257 |
+
main()
|
server/baseline/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Package marker
|
server/baseline/heuristic_agent.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Heuristic baseline agent for the Adaptive AI Firewall environment.
|
| 2 |
+
|
| 3 |
+
Uses the same 8-rule heuristic as inference.py for deterministic testing.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from typing import Dict, List
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def heuristic_policy(env, session_ids: List[str]) -> Dict[str, int]:
|
| 11 |
+
"""Rule-based policy using session features and threat intelligence."""
|
| 12 |
+
threat_intel = env.get_threat_intelligence()
|
| 13 |
+
known_bad_ports = set(threat_intel.get("known_bad_ports", []))
|
| 14 |
+
actions: Dict[str, int] = {}
|
| 15 |
+
|
| 16 |
+
for sid in session_ids:
|
| 17 |
+
try:
|
| 18 |
+
data = env.evaluate_session(sid)
|
| 19 |
+
except KeyError:
|
| 20 |
+
actions[sid] = 0
|
| 21 |
+
continue
|
| 22 |
+
|
| 23 |
+
features = data.get("features", {})
|
| 24 |
+
|
| 25 |
+
# If already revealed as malicious, block immediately
|
| 26 |
+
if data.get("revealed_malicious") is True:
|
| 27 |
+
actions[sid] = 1
|
| 28 |
+
continue
|
| 29 |
+
|
| 30 |
+
dst_port = int(features.get("dst_port", 0))
|
| 31 |
+
history = float(features.get("session_history_score", 1.0))
|
| 32 |
+
entropy = float(features.get("entropy_score", 0.0))
|
| 33 |
+
reuse = float(features.get("connection_reuse", 1.0))
|
| 34 |
+
self_signed = int(features.get("is_self_signed", 0))
|
| 35 |
+
ja3 = int(features.get("ja3_hash_cluster", 0))
|
| 36 |
+
geo = float(features.get("geo_distance", 0.0))
|
| 37 |
+
cert_valid = float(features.get("cert_validity_days", 999.0))
|
| 38 |
+
tls_ver = int(features.get("tls_version", 1))
|
| 39 |
+
dns_q = int(features.get("dns_query_count", 0))
|
| 40 |
+
dur = float(features.get("duration_ms", 500.0))
|
| 41 |
+
pkts = int(features.get("packet_count", 10))
|
| 42 |
+
|
| 43 |
+
if dst_port in known_bad_ports and history < 0.50:
|
| 44 |
+
actions[sid] = 1
|
| 45 |
+
elif self_signed == 1 and history < 0.45:
|
| 46 |
+
actions[sid] = 5
|
| 47 |
+
elif entropy > 0.55 and reuse < 0.25:
|
| 48 |
+
actions[sid] = 2
|
| 49 |
+
elif geo > 4000.0 and history < 0.40:
|
| 50 |
+
actions[sid] = 2
|
| 51 |
+
elif ja3 >= 180:
|
| 52 |
+
actions[sid] = 1
|
| 53 |
+
elif dur < 60.0 and pkts > 100:
|
| 54 |
+
actions[sid] = 4
|
| 55 |
+
elif cert_valid < 80.0 and tls_ver == 0:
|
| 56 |
+
actions[sid] = 2
|
| 57 |
+
elif reuse < 0.10 and dns_q >= 4:
|
| 58 |
+
actions[sid] = 2
|
| 59 |
+
else:
|
| 60 |
+
actions[sid] = 0
|
| 61 |
+
|
| 62 |
+
return actions
|
server/baseline/random_agent.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Random baseline agent for the Adaptive AI Firewall environment."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Callable, Dict, List
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def random_policy(seed: int = 42) -> Callable:
|
| 10 |
+
"""Return a random policy function seeded for reproducibility."""
|
| 11 |
+
rng = np.random.default_rng(seed)
|
| 12 |
+
|
| 13 |
+
def _policy(env, session_ids: List[str]) -> Dict[str, int]:
|
| 14 |
+
return {sid: int(rng.integers(0, 6)) for sid in session_ids}
|
| 15 |
+
|
| 16 |
+
return _policy
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def block_all_policy(env, session_ids: List[str]) -> Dict[str, int]:
|
| 20 |
+
"""Block every session — useful as a degenerate baseline."""
|
| 21 |
+
return {sid: 1 for sid in session_ids}
|
server/firewall_environment.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
# Updated imports to reflect new structure
|
| 9 |
+
from server.utils.reward_engine import (
|
| 10 |
+
ACTIONS, BLOCKING_ACTIONS, RewardEngine,
|
| 11 |
+
)
|
| 12 |
+
from server.utils.threat_engine import ThreatEngine
|
| 13 |
+
from server.utils.data_loader import (
|
| 14 |
+
FEATURE_ORDER, TrafficGenerator,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
TASK_CONFIGS = {
|
| 19 |
+
"easy": {
|
| 20 |
+
"max_steps": 200,
|
| 21 |
+
"benign_ratio": 0.80,
|
| 22 |
+
"threat_probability": 0.12,
|
| 23 |
+
"traffic_lambda": 5,
|
| 24 |
+
"budget": 100.0, # ~0.50 budget per step
|
| 25 |
+
},
|
| 26 |
+
"medium": {
|
| 27 |
+
"max_steps": 500,
|
| 28 |
+
"benign_ratio": 0.65,
|
| 29 |
+
"threat_probability": 0.22,
|
| 30 |
+
"traffic_lambda": 6,
|
| 31 |
+
"budget": 300.0, # ~0.60 budget per step
|
| 32 |
+
},
|
| 33 |
+
"hard": {
|
| 34 |
+
"max_steps": 1000,
|
| 35 |
+
"benign_ratio": 0.70,
|
| 36 |
+
"threat_probability": 0.30,
|
| 37 |
+
"traffic_lambda": 7,
|
| 38 |
+
"budget": 600.0, # ~0.60 budget per step
|
| 39 |
+
},
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
NUM_ACTIONS = len(ACTIONS)
|
| 43 |
+
OBS_DIM = len(FEATURE_ORDER)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
|
| 47 |
+
class EpisodeMetrics:
|
| 48 |
+
"""Tracks all metrics needed for grading."""
|
| 49 |
+
detections: int = 0
|
| 50 |
+
malicious_seen: int = 0
|
| 51 |
+
false_positives: int = 0
|
| 52 |
+
benign_seen: int = 0
|
| 53 |
+
early_detection_sum: float = 0.0
|
| 54 |
+
cascade_failures: int = 0
|
| 55 |
+
total_cost: float = 0.0
|
| 56 |
+
sessions_expired_malicious: int = 0
|
| 57 |
+
sessions_expired_benign: int = 0
|
| 58 |
+
correct_allows: int = 0
|
| 59 |
+
inspections: int = 0
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class FirewallEnvironment:
|
| 63 |
+
"""Adaptive AI Firewall RL environment.
|
| 64 |
+
|
| 65 |
+
OpenEnv-compatible: reset(), step(), state()
|
| 66 |
+
|
| 67 |
+
Key design (from RL perspective):
|
| 68 |
+
- Observation: 22-dim normalized [0,1] vector per session
|
| 69 |
+
- Action: Discrete(6) — ALLOW, BLOCK, INSPECT, SANDBOX, RATE_LIMIT, QUARANTINE
|
| 70 |
+
- Reward: multi-objective (security + availability + efficiency + timeliness)
|
| 71 |
+
- Done: when max_steps reached or budget depleted
|
| 72 |
+
- INSPECT keeps session alive for a second action (two-phase decision)
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(self, seed: int = 0, budget: Optional[float] = None) -> None:
|
| 76 |
+
self.base_seed = seed
|
| 77 |
+
self.base_budget_override = budget
|
| 78 |
+
self.generator = TrafficGenerator(seed=seed)
|
| 79 |
+
self.threat_engine = ThreatEngine(seed=seed + 1)
|
| 80 |
+
self.reward_engine = RewardEngine()
|
| 81 |
+
self.rng = np.random.default_rng(seed + 2)
|
| 82 |
+
|
| 83 |
+
self.episode_id = 0
|
| 84 |
+
self.step_count = 0
|
| 85 |
+
self.current_tick = 0
|
| 86 |
+
self.task = "easy"
|
| 87 |
+
self.max_steps = TASK_CONFIGS[self.task]["max_steps"]
|
| 88 |
+
default_budget = TASK_CONFIGS[self.task]["budget"]
|
| 89 |
+
if self.base_budget_override is not None:
|
| 90 |
+
default_budget = max(default_budget, float(self.base_budget_override))
|
| 91 |
+
self.budget_remaining = default_budget
|
| 92 |
+
self.initial_budget = self.budget_remaining
|
| 93 |
+
self.total_reward = 0.0
|
| 94 |
+
|
| 95 |
+
self.pending_sessions: Dict[str, Dict] = {}
|
| 96 |
+
self.inspected_sessions: Dict[str, Dict] = {} # sessions awaiting 2nd action
|
| 97 |
+
self.action_log: List[Dict] = []
|
| 98 |
+
self._blocked_attacker_ids: Set[str] = set()
|
| 99 |
+
self.metrics = EpisodeMetrics()
|
| 100 |
+
|
| 101 |
+
# For single-session mode
|
| 102 |
+
self._session_queue: List[str] = []
|
| 103 |
+
|
| 104 |
+
# ══════════════════════════════════════════════════════════════════
|
| 105 |
+
# OpenEnv API
|
| 106 |
+
# ══════════════════════════════════════════════════════════════════
|
| 107 |
+
|
| 108 |
+
def reset(self, task: str = "easy", seed: Optional[int] = None) -> Dict:
|
| 109 |
+
"""Reset environment for a new episode."""
|
| 110 |
+
if task not in TASK_CONFIGS:
|
| 111 |
+
raise ValueError(f"unknown task: {task}")
|
| 112 |
+
|
| 113 |
+
used_seed = self.base_seed if seed is None else seed
|
| 114 |
+
self.generator = TrafficGenerator(seed=used_seed)
|
| 115 |
+
self.threat_engine = ThreatEngine(seed=used_seed + 1)
|
| 116 |
+
self.rng = np.random.default_rng(used_seed + 2)
|
| 117 |
+
|
| 118 |
+
self.episode_id += 1
|
| 119 |
+
self.step_count = 0
|
| 120 |
+
self.current_tick = 0
|
| 121 |
+
self.task = task
|
| 122 |
+
config = TASK_CONFIGS[task]
|
| 123 |
+
self.max_steps = config["max_steps"]
|
| 124 |
+
task_budget = float(config["budget"])
|
| 125 |
+
if self.base_budget_override is not None:
|
| 126 |
+
task_budget = max(task_budget, float(self.base_budget_override))
|
| 127 |
+
self.initial_budget = task_budget
|
| 128 |
+
self.budget_remaining = self.initial_budget
|
| 129 |
+
self.total_reward = 0.0
|
| 130 |
+
|
| 131 |
+
self.pending_sessions = {}
|
| 132 |
+
self.inspected_sessions = {}
|
| 133 |
+
self.action_log = []
|
| 134 |
+
self._blocked_attacker_ids = set()
|
| 135 |
+
self.metrics = EpisodeMetrics()
|
| 136 |
+
self._session_queue = []
|
| 137 |
+
|
| 138 |
+
# Spawn initial sessions
|
| 139 |
+
self._spawn_sessions()
|
| 140 |
+
self._rebuild_queue()
|
| 141 |
+
|
| 142 |
+
return self.state()
|
| 143 |
+
|
| 144 |
+
def step(self, action_map: Optional[Dict[str, int]] = None) -> Dict:
|
| 145 |
+
"""Multi-session step: agent provides actions for multiple sessions at once."""
|
| 146 |
+
action_map = action_map or {}
|
| 147 |
+
step_reward = 0.0
|
| 148 |
+
|
| 149 |
+
for session_id, action in action_map.items():
|
| 150 |
+
# Check both pending and inspected pools
|
| 151 |
+
if session_id in self.pending_sessions or session_id in self.inspected_sessions:
|
| 152 |
+
reward, _ = self._apply_action(session_id, action)
|
| 153 |
+
step_reward += reward
|
| 154 |
+
|
| 155 |
+
expired_penalty = self._expire_sessions()
|
| 156 |
+
step_reward += expired_penalty
|
| 157 |
+
self.total_reward += step_reward
|
| 158 |
+
self.step_count += 1
|
| 159 |
+
self.current_tick += 1
|
| 160 |
+
|
| 161 |
+
done = self.step_count >= self.max_steps or self.budget_remaining <= 0.0
|
| 162 |
+
|
| 163 |
+
if not done:
|
| 164 |
+
self._spawn_sessions()
|
| 165 |
+
self._rebuild_queue()
|
| 166 |
+
|
| 167 |
+
# Calculate score using the deterministic grader logic
|
| 168 |
+
final_stats = self.get_network_stats()
|
| 169 |
+
from server.graders import grade_stats
|
| 170 |
+
grade = grade_stats(self.task, final_stats)
|
| 171 |
+
return {
|
| 172 |
+
"reward": step_reward,
|
| 173 |
+
"done": done,
|
| 174 |
+
"state": self.state(),
|
| 175 |
+
"info": {
|
| 176 |
+
"expired_penalty": expired_penalty,
|
| 177 |
+
"attacker_outcomes": self.threat_engine.attacker_outcomes(),
|
| 178 |
+
"score": grade["score"],
|
| 179 |
+
"passed": grade["passed"]
|
| 180 |
+
},
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
def step_single(self, action: int) -> Dict:
|
| 184 |
+
"""Single-session step: present one session, agent picks one action.
|
| 185 |
+
|
| 186 |
+
Compatible with Gymnasium Discrete(6).
|
| 187 |
+
Returns observation of the NEXT session, or zeros if episode done.
|
| 188 |
+
"""
|
| 189 |
+
if action not in ACTIONS:
|
| 190 |
+
raise ValueError(f"invalid action: {action}")
|
| 191 |
+
|
| 192 |
+
step_reward = 0.0
|
| 193 |
+
info: Dict[str, Any] = {}
|
| 194 |
+
|
| 195 |
+
# Act on the current session
|
| 196 |
+
if self._session_queue:
|
| 197 |
+
session_id = self._session_queue.pop(0)
|
| 198 |
+
if session_id in self.pending_sessions or session_id in self.inspected_sessions:
|
| 199 |
+
reward, record = self._apply_action(session_id, action)
|
| 200 |
+
step_reward += reward
|
| 201 |
+
info["action_record"] = record
|
| 202 |
+
|
| 203 |
+
self.total_reward = round(self.total_reward + step_reward, 4)
|
| 204 |
+
self.step_count += 1
|
| 205 |
+
|
| 206 |
+
# If queue is empty, advance tick
|
| 207 |
+
if not self._session_queue:
|
| 208 |
+
self.current_tick += 1
|
| 209 |
+
expired_penalty = self._expire_sessions()
|
| 210 |
+
# step_reward for the final session in tick includes the expiration penalty
|
| 211 |
+
step_reward += expired_penalty
|
| 212 |
+
self.total_reward = round(self.total_reward + expired_penalty, 4)
|
| 213 |
+
done = self.step_count >= self.max_steps or self.budget_remaining <= 0.0
|
| 214 |
+
if not done:
|
| 215 |
+
self._spawn_sessions()
|
| 216 |
+
self._rebuild_queue()
|
| 217 |
+
else:
|
| 218 |
+
done = self.step_count >= self.max_steps or self.budget_remaining <= 0.0
|
| 219 |
+
|
| 220 |
+
# Build next observation
|
| 221 |
+
next_obs = self._current_observation()
|
| 222 |
+
|
| 223 |
+
return {
|
| 224 |
+
"observation": next_obs,
|
| 225 |
+
"reward": step_reward,
|
| 226 |
+
"done": done,
|
| 227 |
+
"state": {
|
| 228 |
+
**self.state(),
|
| 229 |
+
"focus_observation": next_obs,
|
| 230 |
+
"focus_session_id": self._session_queue[0] if self._session_queue else None,
|
| 231 |
+
},
|
| 232 |
+
"info": info,
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
def state(self) -> Dict:
|
| 236 |
+
"""Return current environment state (OpenEnv API)."""
|
| 237 |
+
all_sessions = {**self.pending_sessions, **self.inspected_sessions}
|
| 238 |
+
top_ids = list(all_sessions.keys())[:10]
|
| 239 |
+
focus_session_id = self._session_queue[0] if self._session_queue else None
|
| 240 |
+
return {
|
| 241 |
+
"episode_id": self.episode_id,
|
| 242 |
+
"task": self.task,
|
| 243 |
+
"step_count": self.step_count,
|
| 244 |
+
"current_tick": self.current_tick,
|
| 245 |
+
"observation_dim": OBS_DIM,
|
| 246 |
+
"num_actions": NUM_ACTIONS,
|
| 247 |
+
"budget_remaining": round(self.budget_remaining, 4),
|
| 248 |
+
"total_reward": round(self.total_reward, 4),
|
| 249 |
+
"pending_session_count": len(self.pending_sessions),
|
| 250 |
+
"inspected_session_count": len(self.inspected_sessions),
|
| 251 |
+
"pending_session_ids": top_ids,
|
| 252 |
+
"inspected_session_ids": list(self.inspected_sessions.keys())[:10],
|
| 253 |
+
"queue_length": len(self._session_queue),
|
| 254 |
+
"focus_session_id": focus_session_id,
|
| 255 |
+
"focus_observation": self._current_observation(),
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
# ══════════════════════════════════════════════════════════════════
|
| 259 |
+
# Tool API (for MCP/HTTP interface)
|
| 260 |
+
# ════════════════════════════════════════════════════���═════════════
|
| 261 |
+
|
| 262 |
+
def evaluate_session(self, session_id: str) -> Dict:
|
| 263 |
+
"""Get observation vector and metadata for a session."""
|
| 264 |
+
session = self.pending_sessions.get(session_id) or self.inspected_sessions.get(session_id)
|
| 265 |
+
if session is None:
|
| 266 |
+
raise KeyError(f"session not found: {session_id}")
|
| 267 |
+
|
| 268 |
+
return {
|
| 269 |
+
"session_id": session_id,
|
| 270 |
+
"features": dict(session["features"]),
|
| 271 |
+
"observation": self.generator.to_observation_vector(session),
|
| 272 |
+
"is_inspected": session_id in self.inspected_sessions,
|
| 273 |
+
"revealed_malicious": (
|
| 274 |
+
session["metadata"]["malicious"]
|
| 275 |
+
if session["metadata"]["revealed"] else None
|
| 276 |
+
),
|
| 277 |
+
"expires_tick": session["expires_tick"],
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
def take_action(self, session_id: str, action: int) -> Tuple[float, Dict]:
|
| 281 |
+
"""Apply an action to a specific session."""
|
| 282 |
+
return self._apply_action(session_id, action)
|
| 283 |
+
|
| 284 |
+
def get_network_stats(self) -> Dict:
|
| 285 |
+
"""Aggregate episode statistics for grading."""
|
| 286 |
+
m = self.metrics
|
| 287 |
+
total_malicious = m.malicious_seen + m.sessions_expired_malicious
|
| 288 |
+
total_benign = m.benign_seen + m.sessions_expired_benign
|
| 289 |
+
|
| 290 |
+
detection_rate = m.detections / max(total_malicious, 1)
|
| 291 |
+
false_positive_rate = m.false_positives / max(total_benign, 1)
|
| 292 |
+
efficiency = 1.0 - min(1.0, m.total_cost / max(self.initial_budget, 1e-6))
|
| 293 |
+
early_detection_bonus = m.early_detection_sum / max(m.detections, 1)
|
| 294 |
+
cascade_prevention = 1.0 - (m.cascade_failures / max(total_malicious, 1))
|
| 295 |
+
|
| 296 |
+
return {
|
| 297 |
+
"episode_id": self.episode_id,
|
| 298 |
+
"task": self.task,
|
| 299 |
+
"tick": self.current_tick,
|
| 300 |
+
"step_count": self.step_count,
|
| 301 |
+
"total_reward": round(self.total_reward, 4),
|
| 302 |
+
"budget_remaining": round(self.budget_remaining, 4),
|
| 303 |
+
"budget_used_pct": round(1.0 - self.budget_remaining / max(self.initial_budget, 1e-6), 4),
|
| 304 |
+
"total_malicious": total_malicious,
|
| 305 |
+
"total_benign": total_benign,
|
| 306 |
+
"detection_rate": round(detection_rate, 6),
|
| 307 |
+
"false_positive_rate": round(false_positive_rate, 6),
|
| 308 |
+
"efficiency": round(efficiency, 6),
|
| 309 |
+
"early_detection_bonus": round(early_detection_bonus, 6),
|
| 310 |
+
"cascade_prevention": round(cascade_prevention, 6),
|
| 311 |
+
"correct_allows": m.correct_allows,
|
| 312 |
+
"inspections": m.inspections,
|
| 313 |
+
"expired_malicious": m.sessions_expired_malicious,
|
| 314 |
+
"expired_benign": m.sessions_expired_benign,
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
def get_threat_intelligence(self) -> Dict:
|
| 318 |
+
return self.threat_engine.intelligence_feed()
|
| 319 |
+
|
| 320 |
+
def list_tools(self) -> List[str]:
|
| 321 |
+
return [
|
| 322 |
+
"evaluate_session", "take_action",
|
| 323 |
+
"get_network_stats", "get_threat_intelligence",
|
| 324 |
+
]
|
| 325 |
+
|
| 326 |
+
# ══════════════════════════════════════════════════════════════════
|
| 327 |
+
# Internal mechanics
|
| 328 |
+
# ══════════════════════════════════════════════════════════════════
|
| 329 |
+
|
| 330 |
+
def _apply_action(self, session_id: str, action: int) -> Tuple[float, Dict]:
|
| 331 |
+
"""Core action application logic."""
|
| 332 |
+
if action not in ACTIONS:
|
| 333 |
+
raise ValueError(f"invalid action: {action}")
|
| 334 |
+
|
| 335 |
+
# Find the session in either pool
|
| 336 |
+
source_pool = "none"
|
| 337 |
+
if session_id in self.inspected_sessions:
|
| 338 |
+
session = self.inspected_sessions.pop(session_id)
|
| 339 |
+
source_pool = "inspected"
|
| 340 |
+
elif session_id in self.pending_sessions:
|
| 341 |
+
session = self.pending_sessions.pop(session_id)
|
| 342 |
+
source_pool = "pending"
|
| 343 |
+
else:
|
| 344 |
+
raise KeyError(f"session not found: {session_id}")
|
| 345 |
+
|
| 346 |
+
metadata = session["metadata"]
|
| 347 |
+
malicious = bool(metadata["malicious"])
|
| 348 |
+
blocked = action in BLOCKING_ACTIONS
|
| 349 |
+
inspected = action == 2 # INSPECT
|
| 350 |
+
|
| 351 |
+
# ── INSPECT keeps the session alive for a second decision ──
|
| 352 |
+
if inspected and session_id not in self.inspected_sessions:
|
| 353 |
+
metadata["revealed"] = True
|
| 354 |
+
self.inspected_sessions[session_id] = session
|
| 355 |
+
self.pending_sessions[session_id] = session
|
| 356 |
+
self.metrics.inspections += 1
|
| 357 |
+
# Compute reward for the inspection itself
|
| 358 |
+
reward, components = self.reward_engine.reward(
|
| 359 |
+
action=action,
|
| 360 |
+
is_malicious=malicious,
|
| 361 |
+
budget_remaining=self.budget_remaining,
|
| 362 |
+
attack_phase=metadata.get("attack_phase", 0),
|
| 363 |
+
inspect_correct=malicious,
|
| 364 |
+
)
|
| 365 |
+
self.budget_remaining = max(0.0, self.budget_remaining - components["cost"])
|
| 366 |
+
self.metrics.total_cost += components["cost"]
|
| 367 |
+
record = self._make_record(session_id, action, malicious, reward, components)
|
| 368 |
+
return reward, record
|
| 369 |
+
|
| 370 |
+
# ── Terminal action (ALLOW, BLOCK, SANDBOX, RATE_LIMIT, QUARANTINE) ──
|
| 371 |
+
inspect_correct = malicious and metadata.get("revealed", False)
|
| 372 |
+
reward, components = self.reward_engine.reward(
|
| 373 |
+
action=action,
|
| 374 |
+
is_malicious=malicious,
|
| 375 |
+
budget_remaining=self.budget_remaining,
|
| 376 |
+
attack_phase=metadata.get("attack_phase", 0),
|
| 377 |
+
inspect_correct=inspect_correct,
|
| 378 |
+
)
|
| 379 |
+
self.budget_remaining = max(0.0, self.budget_remaining - components["cost"])
|
| 380 |
+
self.metrics.total_cost += components["cost"]
|
| 381 |
+
if source_pool == "inspected":
|
| 382 |
+
self.pending_sessions.pop(session_id, None)
|
| 383 |
+
|
| 384 |
+
# ── Update metrics ──
|
| 385 |
+
if malicious:
|
| 386 |
+
self.metrics.malicious_seen += 1
|
| 387 |
+
if blocked:
|
| 388 |
+
self.metrics.detections += 1
|
| 389 |
+
phase = metadata.get("attack_phase", 0)
|
| 390 |
+
self.metrics.early_detection_sum += float(np.exp(-phase))
|
| 391 |
+
attacker_id = metadata.get("attacker_id")
|
| 392 |
+
if attacker_id:
|
| 393 |
+
self._blocked_attacker_ids.add(attacker_id)
|
| 394 |
+
else:
|
| 395 |
+
if metadata.get("attack_phase", 0) >= 2:
|
| 396 |
+
self.metrics.cascade_failures += 1
|
| 397 |
+
else:
|
| 398 |
+
self.metrics.benign_seen += 1
|
| 399 |
+
if blocked:
|
| 400 |
+
self.metrics.false_positives += 1
|
| 401 |
+
elif action == 0:
|
| 402 |
+
self.metrics.correct_allows += 1
|
| 403 |
+
|
| 404 |
+
record = self._make_record(session_id, action, malicious, reward, components)
|
| 405 |
+
self.action_log.append(record)
|
| 406 |
+
return reward, record
|
| 407 |
+
|
| 408 |
+
def _make_record(self, session_id: str, action: int, malicious: bool,
|
| 409 |
+
reward: float, components: Dict) -> Dict:
|
| 410 |
+
return {
|
| 411 |
+
"tick": self.current_tick,
|
| 412 |
+
"session_id": session_id,
|
| 413 |
+
"action": action,
|
| 414 |
+
"action_name": ACTIONS[action],
|
| 415 |
+
"malicious": malicious,
|
| 416 |
+
"reward": round(reward, 6),
|
| 417 |
+
"components": {k: round(v, 6) for k, v in components.items()},
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
def _spawn_sessions(self) -> None:
|
| 421 |
+
"""Generate new benign and malicious sessions for current tick."""
|
| 422 |
+
config = TASK_CONFIGS[self.task]
|
| 423 |
+
benign_count = int(max(1, self.rng.poisson(
|
| 424 |
+
config["traffic_lambda"] * config["benign_ratio"],
|
| 425 |
+
)))
|
| 426 |
+
benign = self.generator.generate_benign_sessions(
|
| 427 |
+
tick=self.current_tick, count=benign_count,
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
self.threat_engine.maybe_spawn_attacker(config["threat_probability"])
|
| 431 |
+
malicious = self.threat_engine.generate_attack_sessions(
|
| 432 |
+
tick=self.current_tick,
|
| 433 |
+
generator=self.generator,
|
| 434 |
+
blocked_attackers=self._blocked_attacker_ids,
|
| 435 |
+
)
|
| 436 |
+
self._blocked_attacker_ids = set()
|
| 437 |
+
|
| 438 |
+
for session in benign + malicious:
|
| 439 |
+
self.pending_sessions[session["session_id"]] = session
|
| 440 |
+
|
| 441 |
+
def _expire_sessions(self) -> float:
|
| 442 |
+
"""Remove expired sessions and apply penalties. Count in metrics."""
|
| 443 |
+
expired_ids = set()
|
| 444 |
+
for sid, session in self.pending_sessions.items():
|
| 445 |
+
if session["expires_tick"] <= self.current_tick:
|
| 446 |
+
expired_ids.add(sid)
|
| 447 |
+
for sid, session in self.inspected_sessions.items():
|
| 448 |
+
if session["expires_tick"] <= self.current_tick:
|
| 449 |
+
expired_ids.add(sid)
|
| 450 |
+
penalty = 0.0
|
| 451 |
+
for session_id in expired_ids:
|
| 452 |
+
session = self.inspected_sessions.pop(session_id, None)
|
| 453 |
+
if session is None:
|
| 454 |
+
session = self.pending_sessions.get(session_id)
|
| 455 |
+
self.pending_sessions.pop(session_id, None)
|
| 456 |
+
if session is None:
|
| 457 |
+
continue
|
| 458 |
+
if session["metadata"]["malicious"]:
|
| 459 |
+
penalty -= 1.5
|
| 460 |
+
self.metrics.sessions_expired_malicious += 1
|
| 461 |
+
if session["metadata"].get("attack_phase", 0) >= 2:
|
| 462 |
+
self.metrics.cascade_failures += 1
|
| 463 |
+
else:
|
| 464 |
+
self.metrics.sessions_expired_benign += 1
|
| 465 |
+
|
| 466 |
+
return penalty
|
| 467 |
+
|
| 468 |
+
def _rebuild_queue(self) -> None:
|
| 469 |
+
"""Rebuild the single-session queue from pending + inspected."""
|
| 470 |
+
# Inspected sessions get priority (they need a follow-up action)
|
| 471 |
+
ordered = list(self.inspected_sessions.keys()) + list(self.pending_sessions.keys())
|
| 472 |
+
seen: Set[str] = set()
|
| 473 |
+
self._session_queue = []
|
| 474 |
+
for sid in ordered:
|
| 475 |
+
if sid in seen:
|
| 476 |
+
continue
|
| 477 |
+
seen.add(sid)
|
| 478 |
+
self._session_queue.append(sid)
|
| 479 |
+
|
| 480 |
+
def _current_observation(self) -> List[float]:
|
| 481 |
+
"""Get normalized observation for the next session in queue."""
|
| 482 |
+
if self._session_queue:
|
| 483 |
+
sid = self._session_queue[0]
|
| 484 |
+
session = (
|
| 485 |
+
self.inspected_sessions.get(sid)
|
| 486 |
+
or self.pending_sessions.get(sid)
|
| 487 |
+
)
|
| 488 |
+
if session:
|
| 489 |
+
return self.generator.to_observation_vector(session)
|
| 490 |
+
return [0.0] * OBS_DIM
|
server/graders.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deterministic grading system for the three firewall tasks.
|
| 2 |
+
|
| 3 |
+
Each task has:
|
| 4 |
+
- A fixed seed for reproducible traffic
|
| 5 |
+
- Weighted scoring across detection, false positives, efficiency, etc.
|
| 6 |
+
- A score in [0.0, 1.0] and a pass threshold
|
| 7 |
+
|
| 8 |
+
Graders are deterministic: same seed + same policy = same score.
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Callable, Dict, List
|
| 14 |
+
|
| 15 |
+
# Updated import path
|
| 16 |
+
from server.firewall_environment import FirewallEnvironment
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass(frozen=True)
|
| 20 |
+
class TaskSpec:
|
| 21 |
+
name: str
|
| 22 |
+
task_key: str
|
| 23 |
+
threshold: float
|
| 24 |
+
weights: Dict[str, float]
|
| 25 |
+
seed: int
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
TASK_SPECS = {
|
| 29 |
+
"easy": TaskSpec(
|
| 30 |
+
name="Perimeter Defense",
|
| 31 |
+
task_key="easy",
|
| 32 |
+
threshold=0.70,
|
| 33 |
+
seed=101,
|
| 34 |
+
weights={
|
| 35 |
+
"detection_rate": 0.35,
|
| 36 |
+
"fp_complement": 0.35,
|
| 37 |
+
"efficiency": 0.30,
|
| 38 |
+
},
|
| 39 |
+
),
|
| 40 |
+
"medium": TaskSpec(
|
| 41 |
+
name="Mixed Threat Landscape",
|
| 42 |
+
task_key="medium",
|
| 43 |
+
threshold=0.50,
|
| 44 |
+
seed=202,
|
| 45 |
+
weights={
|
| 46 |
+
"detection_rate": 0.25,
|
| 47 |
+
"fp_complement": 0.30,
|
| 48 |
+
"efficiency": 0.15,
|
| 49 |
+
"early_detection_bonus": 0.15,
|
| 50 |
+
"cascade_prevention": 0.15,
|
| 51 |
+
},
|
| 52 |
+
),
|
| 53 |
+
"hard": TaskSpec(
|
| 54 |
+
name="Advanced Persistent Threat",
|
| 55 |
+
task_key="hard",
|
| 56 |
+
threshold=0.45,
|
| 57 |
+
seed=303,
|
| 58 |
+
weights={
|
| 59 |
+
"detection_rate": 0.20,
|
| 60 |
+
"fp_complement": 0.25,
|
| 61 |
+
"efficiency": 0.15,
|
| 62 |
+
"early_detection_bonus": 0.20,
|
| 63 |
+
"cascade_prevention": 0.20,
|
| 64 |
+
},
|
| 65 |
+
),
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
PASS_CONSTRAINTS = {
|
| 69 |
+
"easy": {"min_detection_rate": 0.35, "min_fp_complement": 0.65},
|
| 70 |
+
"medium": {"min_detection_rate": 0.35, "min_fp_complement": 0.60},
|
| 71 |
+
"hard": {"min_detection_rate": 0.35, "min_fp_complement": 0.55},
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def grade_stats(task: str, stats: Dict) -> Dict:
|
| 76 |
+
"""Compute a grade from episode stats."""
|
| 77 |
+
spec = TASK_SPECS[task]
|
| 78 |
+
values = {
|
| 79 |
+
"detection_rate": stats.get("detection_rate", 0.0),
|
| 80 |
+
"fp_complement": 1.0 - stats.get("false_positive_rate", 1.0),
|
| 81 |
+
"efficiency": stats.get("efficiency", 0.0),
|
| 82 |
+
"early_detection_bonus": stats.get("early_detection_bonus", 0.0),
|
| 83 |
+
"cascade_prevention": stats.get("cascade_prevention", 0.0),
|
| 84 |
+
}
|
| 85 |
+
score = sum(values.get(k, 0.0) * w for k, w in spec.weights.items())
|
| 86 |
+
score = max(0.0, min(1.0, score))
|
| 87 |
+
constraints = PASS_CONSTRAINTS[task]
|
| 88 |
+
meets_constraints = (
|
| 89 |
+
values["detection_rate"] >= constraints["min_detection_rate"]
|
| 90 |
+
and values["fp_complement"] >= constraints["min_fp_complement"]
|
| 91 |
+
)
|
| 92 |
+
passed = (score >= spec.threshold) and meets_constraints
|
| 93 |
+
|
| 94 |
+
return {
|
| 95 |
+
"task": task,
|
| 96 |
+
"task_name": spec.name,
|
| 97 |
+
"threshold": spec.threshold,
|
| 98 |
+
"score": round(score, 6),
|
| 99 |
+
"passed": passed,
|
| 100 |
+
"pass_constraints": constraints,
|
| 101 |
+
"meets_constraints": meets_constraints,
|
| 102 |
+
"breakdown": {k: round(v, 6) for k, v in values.items()},
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def run_deterministic_grade(
|
| 107 |
+
env: FirewallEnvironment,
|
| 108 |
+
task: str,
|
| 109 |
+
policy: Callable[[FirewallEnvironment, List[str]], Dict[str, int]],
|
| 110 |
+
) -> Dict:
|
| 111 |
+
"""Run a full episode with a policy and compute the grade."""
|
| 112 |
+
spec = TASK_SPECS[task]
|
| 113 |
+
env.reset(task=task, seed=spec.seed)
|
| 114 |
+
done = False
|
| 115 |
+
while not done:
|
| 116 |
+
session_ids = (
|
| 117 |
+
list(env.inspected_sessions.keys())
|
| 118 |
+
+ list(env.pending_sessions.keys())
|
| 119 |
+
)
|
| 120 |
+
actions = policy(env, session_ids)
|
| 121 |
+
response = env.step(actions)
|
| 122 |
+
done = bool(response["done"])
|
| 123 |
+
stats = env.get_network_stats()
|
| 124 |
+
return grade_stats(task, stats)
|
server/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Package marker
|
server/utils/data_loader.py
ADDED
|
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Network traffic session generator with realistic correlated features.
|
| 2 |
+
|
| 3 |
+
Each session is a 22-dimensional feature vector representing metadata and
|
| 4 |
+
behavioral signals from encrypted traffic (no payload inspection).
|
| 5 |
+
|
| 6 |
+
Feature groups:
|
| 7 |
+
- Volume & timing: bytes, duration, packet stats, inter-arrival metrics
|
| 8 |
+
- Network metadata: ports, protocol, DNS, connection reuse
|
| 9 |
+
- TLS / certificate: TLS version, JA3 cluster, cert chain, self-signed
|
| 10 |
+
- Behavioral context: geo distance, time of day, reputation, entropy
|
| 11 |
+
|
| 12 |
+
Benign traffic is drawn from 5 profile archetypes. Malicious traffic
|
| 13 |
+
profiles vary by attack scenario AND kill-chain phase, creating real
|
| 14 |
+
distributional differences an RL agent can learn to exploit.
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Dict, List
|
| 20 |
+
import math
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
FEATURE_ORDER = [
|
| 26 |
+
"bytes_sent",
|
| 27 |
+
"bytes_received",
|
| 28 |
+
"duration_ms",
|
| 29 |
+
"packet_count",
|
| 30 |
+
"avg_packet_size",
|
| 31 |
+
"packet_size_variance",
|
| 32 |
+
"inter_arrival_mean",
|
| 33 |
+
"inter_arrival_jitter",
|
| 34 |
+
"src_port",
|
| 35 |
+
"dst_port",
|
| 36 |
+
"protocol",
|
| 37 |
+
"tls_version",
|
| 38 |
+
"ja3_hash_cluster",
|
| 39 |
+
"cert_chain_length",
|
| 40 |
+
"cert_validity_days",
|
| 41 |
+
"is_self_signed",
|
| 42 |
+
"dns_query_count",
|
| 43 |
+
"connection_reuse",
|
| 44 |
+
"geo_distance",
|
| 45 |
+
"time_of_day",
|
| 46 |
+
"session_history_score",
|
| 47 |
+
"entropy_score",
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
# Min/max bounds for normalization (empirically calibrated)
|
| 51 |
+
FEATURE_BOUNDS: Dict[str, tuple] = {
|
| 52 |
+
"bytes_sent": (4.0, 14.0),
|
| 53 |
+
"bytes_received": (3.0, 13.0),
|
| 54 |
+
"duration_ms": (20.0, 25000.0),
|
| 55 |
+
"packet_count": (2.0, 1200.0),
|
| 56 |
+
"avg_packet_size": (40.0, 1400.0),
|
| 57 |
+
"packet_size_variance": (5.0, 500.0),
|
| 58 |
+
"inter_arrival_mean": (0.5, 600.0),
|
| 59 |
+
"inter_arrival_jitter": (0.0, 300.0),
|
| 60 |
+
"src_port": (1024.0, 65535.0),
|
| 61 |
+
"dst_port": (1.0, 65535.0),
|
| 62 |
+
"protocol": (0.0, 2.0),
|
| 63 |
+
"tls_version": (0.0, 2.0),
|
| 64 |
+
"ja3_hash_cluster": (0.0, 255.0),
|
| 65 |
+
"cert_chain_length": (0.0, 6.0),
|
| 66 |
+
"cert_validity_days": (1.0, 1200.0),
|
| 67 |
+
"is_self_signed": (0.0, 1.0),
|
| 68 |
+
"dns_query_count": (0.0, 12.0),
|
| 69 |
+
"connection_reuse": (0.0, 1.0),
|
| 70 |
+
"geo_distance": (0.0, 12000.0),
|
| 71 |
+
"time_of_day": (0.0, 1.0),
|
| 72 |
+
"session_history_score": (0.0, 1.0),
|
| 73 |
+
"entropy_score": (0.0, 1.0),
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass(frozen=True)
|
| 78 |
+
class TrafficProfile:
|
| 79 |
+
name: str
|
| 80 |
+
packet_mean: float
|
| 81 |
+
packet_std_frac: float # std = mean * frac
|
| 82 |
+
duration_mean: float
|
| 83 |
+
entropy_mean: float
|
| 84 |
+
entropy_std: float
|
| 85 |
+
tls_probability: float
|
| 86 |
+
self_signed_prob: float
|
| 87 |
+
common_ports: List[int]
|
| 88 |
+
connection_reuse_mean: float
|
| 89 |
+
geo_distance_mean: float
|
| 90 |
+
history_score_mean: float
|
| 91 |
+
cert_validity_mean: float
|
| 92 |
+
ja3_cluster_range: tuple = (0, 128)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# ── Benign traffic profiles ─────────────────────────────────────────
|
| 96 |
+
BENIGN_PROFILES = [
|
| 97 |
+
TrafficProfile(
|
| 98 |
+
name="WebBrowsing", packet_mean=50.0, packet_std_frac=0.35,
|
| 99 |
+
duration_mean=900.0, entropy_mean=0.32, entropy_std=0.06,
|
| 100 |
+
tls_probability=0.95, self_signed_prob=0.02,
|
| 101 |
+
common_ports=[80, 443], connection_reuse_mean=0.72,
|
| 102 |
+
geo_distance_mean=1400.0, history_score_mean=0.82,
|
| 103 |
+
cert_validity_mean=450.0, ja3_cluster_range=(0, 64),
|
| 104 |
+
),
|
| 105 |
+
TrafficProfile(
|
| 106 |
+
name="Streaming", packet_mean=800.0, packet_std_frac=0.25,
|
| 107 |
+
duration_mean=18000.0, entropy_mean=0.22, entropy_std=0.04,
|
| 108 |
+
tls_probability=0.99, self_signed_prob=0.01,
|
| 109 |
+
common_ports=[443, 8080], connection_reuse_mean=0.88,
|
| 110 |
+
geo_distance_mean=2200.0, history_score_mean=0.90,
|
| 111 |
+
cert_validity_mean=500.0, ja3_cluster_range=(0, 32),
|
| 112 |
+
),
|
| 113 |
+
TrafficProfile(
|
| 114 |
+
name="API", packet_mean=25.0, packet_std_frac=0.30,
|
| 115 |
+
duration_mean=350.0, entropy_mean=0.18, entropy_std=0.04,
|
| 116 |
+
tls_probability=0.98, self_signed_prob=0.01,
|
| 117 |
+
common_ports=[443, 8443], connection_reuse_mean=0.80,
|
| 118 |
+
geo_distance_mean=1000.0, history_score_mean=0.85,
|
| 119 |
+
cert_validity_mean=500.0, ja3_cluster_range=(0, 48),
|
| 120 |
+
),
|
| 121 |
+
TrafficProfile(
|
| 122 |
+
name="IoT", packet_mean=10.0, packet_std_frac=0.40,
|
| 123 |
+
duration_mean=1500.0, entropy_mean=0.38, entropy_std=0.07,
|
| 124 |
+
tls_probability=0.30, self_signed_prob=0.08,
|
| 125 |
+
common_ports=[1883, 5683, 8883], connection_reuse_mean=0.55,
|
| 126 |
+
geo_distance_mean=800.0, history_score_mean=0.70,
|
| 127 |
+
cert_validity_mean=300.0, ja3_cluster_range=(80, 128),
|
| 128 |
+
),
|
| 129 |
+
TrafficProfile(
|
| 130 |
+
name="Enterprise", packet_mean=120.0, packet_std_frac=0.35,
|
| 131 |
+
duration_mean=1200.0, entropy_mean=0.28, entropy_std=0.06,
|
| 132 |
+
tls_probability=0.85, self_signed_prob=0.04,
|
| 133 |
+
common_ports=[443, 445, 3389], connection_reuse_mean=0.65,
|
| 134 |
+
geo_distance_mean=500.0, history_score_mean=0.88,
|
| 135 |
+
cert_validity_mean=400.0, ja3_cluster_range=(0, 96),
|
| 136 |
+
),
|
| 137 |
+
]
|
| 138 |
+
|
| 139 |
+
# ── Malicious traffic profiles per (scenario, phase) ────────────────
|
| 140 |
+
# Each scenario has distinct fingerprints making them differentiable
|
| 141 |
+
MALICIOUS_PROFILES: Dict[str, Dict[int, TrafficProfile]] = {
|
| 142 |
+
"port_scan_exploit_c2": {
|
| 143 |
+
0: TrafficProfile(
|
| 144 |
+
name="PortScan_Recon", packet_mean=6.0, packet_std_frac=0.5,
|
| 145 |
+
duration_mean=80.0, entropy_mean=0.12, entropy_std=0.04,
|
| 146 |
+
tls_probability=0.05, self_signed_prob=0.60,
|
| 147 |
+
common_ports=[21, 22, 23, 25, 445, 3389, 5900],
|
| 148 |
+
connection_reuse_mean=0.02, geo_distance_mean=5500.0,
|
| 149 |
+
history_score_mean=0.10, cert_validity_mean=60.0,
|
| 150 |
+
ja3_cluster_range=(200, 255),
|
| 151 |
+
),
|
| 152 |
+
1: TrafficProfile(
|
| 153 |
+
name="PortScan_Exploit", packet_mean=45.0, packet_std_frac=0.4,
|
| 154 |
+
duration_mean=300.0, entropy_mean=0.78, entropy_std=0.06,
|
| 155 |
+
tls_probability=0.40, self_signed_prob=0.45,
|
| 156 |
+
common_ports=[80, 443, 8080, 445],
|
| 157 |
+
connection_reuse_mean=0.08, geo_distance_mean=5200.0,
|
| 158 |
+
history_score_mean=0.12, cert_validity_mean=90.0,
|
| 159 |
+
ja3_cluster_range=(210, 255),
|
| 160 |
+
),
|
| 161 |
+
2: TrafficProfile(
|
| 162 |
+
name="PortScan_C2", packet_mean=4.0, packet_std_frac=0.6,
|
| 163 |
+
duration_mean=5000.0, entropy_mean=0.55, entropy_std=0.08,
|
| 164 |
+
tls_probability=0.92, self_signed_prob=0.35,
|
| 165 |
+
common_ports=[443, 53, 8443],
|
| 166 |
+
connection_reuse_mean=0.15, geo_distance_mean=6000.0,
|
| 167 |
+
history_score_mean=0.15, cert_validity_mean=45.0,
|
| 168 |
+
ja3_cluster_range=(220, 255),
|
| 169 |
+
),
|
| 170 |
+
3: TrafficProfile(
|
| 171 |
+
name="PortScan_Exfil", packet_mean=350.0, packet_std_frac=0.3,
|
| 172 |
+
duration_mean=12000.0, entropy_mean=0.88, entropy_std=0.04,
|
| 173 |
+
tls_probability=0.98, self_signed_prob=0.25,
|
| 174 |
+
common_ports=[443, 8443],
|
| 175 |
+
connection_reuse_mean=0.10, geo_distance_mean=6500.0,
|
| 176 |
+
history_score_mean=0.08, cert_validity_mean=30.0,
|
| 177 |
+
ja3_cluster_range=(230, 255),
|
| 178 |
+
),
|
| 179 |
+
},
|
| 180 |
+
"credential_stuffing_lateral": {
|
| 181 |
+
0: TrafficProfile(
|
| 182 |
+
name="CredStuff_Probe", packet_mean=15.0, packet_std_frac=0.4,
|
| 183 |
+
duration_mean=200.0, entropy_mean=0.42, entropy_std=0.06,
|
| 184 |
+
tls_probability=0.90, self_signed_prob=0.10,
|
| 185 |
+
common_ports=[443, 80, 8443],
|
| 186 |
+
connection_reuse_mean=0.05, geo_distance_mean=3500.0,
|
| 187 |
+
history_score_mean=0.25, cert_validity_mean=300.0,
|
| 188 |
+
ja3_cluster_range=(140, 200),
|
| 189 |
+
),
|
| 190 |
+
1: TrafficProfile(
|
| 191 |
+
name="CredStuff_Auth", packet_mean=20.0, packet_std_frac=0.35,
|
| 192 |
+
duration_mean=150.0, entropy_mean=0.50, entropy_std=0.07,
|
| 193 |
+
tls_probability=0.95, self_signed_prob=0.08,
|
| 194 |
+
common_ports=[443, 389, 636],
|
| 195 |
+
connection_reuse_mean=0.10, geo_distance_mean=3200.0,
|
| 196 |
+
history_score_mean=0.30, cert_validity_mean=350.0,
|
| 197 |
+
ja3_cluster_range=(150, 210),
|
| 198 |
+
),
|
| 199 |
+
2: TrafficProfile(
|
| 200 |
+
name="CredStuff_Lateral", packet_mean=30.0, packet_std_frac=0.35,
|
| 201 |
+
duration_mean=500.0, entropy_mean=0.35, entropy_std=0.06,
|
| 202 |
+
tls_probability=0.80, self_signed_prob=0.12,
|
| 203 |
+
common_ports=[445, 3389, 5985, 22],
|
| 204 |
+
connection_reuse_mean=0.20, geo_distance_mean=300.0,
|
| 205 |
+
history_score_mean=0.40, cert_validity_mean=350.0,
|
| 206 |
+
ja3_cluster_range=(160, 220),
|
| 207 |
+
),
|
| 208 |
+
3: TrafficProfile(
|
| 209 |
+
name="CredStuff_Exfil", packet_mean=200.0, packet_std_frac=0.3,
|
| 210 |
+
duration_mean=8000.0, entropy_mean=0.80, entropy_std=0.05,
|
| 211 |
+
tls_probability=0.98, self_signed_prob=0.15,
|
| 212 |
+
common_ports=[443, 8443],
|
| 213 |
+
connection_reuse_mean=0.12, geo_distance_mean=4000.0,
|
| 214 |
+
history_score_mean=0.18, cert_validity_mean=90.0,
|
| 215 |
+
ja3_cluster_range=(180, 240),
|
| 216 |
+
),
|
| 217 |
+
},
|
| 218 |
+
"supply_chain_compromise": {
|
| 219 |
+
0: TrafficProfile(
|
| 220 |
+
name="SupplyChain_Init", packet_mean=40.0, packet_std_frac=0.3,
|
| 221 |
+
duration_mean=600.0, entropy_mean=0.30, entropy_std=0.05,
|
| 222 |
+
tls_probability=0.98, self_signed_prob=0.03,
|
| 223 |
+
common_ports=[443, 8443],
|
| 224 |
+
connection_reuse_mean=0.60, geo_distance_mean=1800.0,
|
| 225 |
+
history_score_mean=0.70, cert_validity_mean=380.0,
|
| 226 |
+
ja3_cluster_range=(30, 80),
|
| 227 |
+
),
|
| 228 |
+
1: TrafficProfile(
|
| 229 |
+
name="SupplyChain_Inject", packet_mean=60.0, packet_std_frac=0.3,
|
| 230 |
+
duration_mean=800.0, entropy_mean=0.40, entropy_std=0.06,
|
| 231 |
+
tls_probability=0.98, self_signed_prob=0.04,
|
| 232 |
+
common_ports=[443, 8443],
|
| 233 |
+
connection_reuse_mean=0.55, geo_distance_mean=2000.0,
|
| 234 |
+
history_score_mean=0.65, cert_validity_mean=350.0,
|
| 235 |
+
ja3_cluster_range=(35, 90),
|
| 236 |
+
),
|
| 237 |
+
2: TrafficProfile(
|
| 238 |
+
name="SupplyChain_Beacon", packet_mean=8.0, packet_std_frac=0.5,
|
| 239 |
+
duration_mean=3000.0, entropy_mean=0.48, entropy_std=0.07,
|
| 240 |
+
tls_probability=0.99, self_signed_prob=0.05,
|
| 241 |
+
common_ports=[443],
|
| 242 |
+
connection_reuse_mean=0.50, geo_distance_mean=2500.0,
|
| 243 |
+
history_score_mean=0.55, cert_validity_mean=250.0,
|
| 244 |
+
ja3_cluster_range=(40, 100),
|
| 245 |
+
),
|
| 246 |
+
3: TrafficProfile(
|
| 247 |
+
name="SupplyChain_Exfil", packet_mean=100.0, packet_std_frac=0.3,
|
| 248 |
+
duration_mean=5000.0, entropy_mean=0.60, entropy_std=0.06,
|
| 249 |
+
tls_probability=0.99, self_signed_prob=0.06,
|
| 250 |
+
common_ports=[443, 8443],
|
| 251 |
+
connection_reuse_mean=0.42, geo_distance_mean=3000.0,
|
| 252 |
+
history_score_mean=0.45, cert_validity_mean=200.0,
|
| 253 |
+
ja3_cluster_range=(50, 110),
|
| 254 |
+
),
|
| 255 |
+
},
|
| 256 |
+
"low_and_slow_apt": {
|
| 257 |
+
0: TrafficProfile(
|
| 258 |
+
name="APT_Recon", packet_mean=12.0, packet_std_frac=0.4,
|
| 259 |
+
duration_mean=400.0, entropy_mean=0.28, entropy_std=0.05,
|
| 260 |
+
tls_probability=0.92, self_signed_prob=0.05,
|
| 261 |
+
common_ports=[443, 80],
|
| 262 |
+
connection_reuse_mean=0.50, geo_distance_mean=2200.0,
|
| 263 |
+
history_score_mean=0.55, cert_validity_mean=320.0,
|
| 264 |
+
ja3_cluster_range=(60, 130),
|
| 265 |
+
),
|
| 266 |
+
1: TrafficProfile(
|
| 267 |
+
name="APT_Establish", packet_mean=18.0, packet_std_frac=0.35,
|
| 268 |
+
duration_mean=700.0, entropy_mean=0.35, entropy_std=0.06,
|
| 269 |
+
tls_probability=0.95, self_signed_prob=0.07,
|
| 270 |
+
common_ports=[443, 53],
|
| 271 |
+
connection_reuse_mean=0.45, geo_distance_mean=2600.0,
|
| 272 |
+
history_score_mean=0.48, cert_validity_mean=280.0,
|
| 273 |
+
ja3_cluster_range=(70, 140),
|
| 274 |
+
),
|
| 275 |
+
2: TrafficProfile(
|
| 276 |
+
name="APT_Persist", packet_mean=5.0, packet_std_frac=0.6,
|
| 277 |
+
duration_mean=8000.0, entropy_mean=0.42, entropy_std=0.07,
|
| 278 |
+
tls_probability=0.97, self_signed_prob=0.10,
|
| 279 |
+
common_ports=[443],
|
| 280 |
+
connection_reuse_mean=0.38, geo_distance_mean=3200.0,
|
| 281 |
+
history_score_mean=0.38, cert_validity_mean=200.0,
|
| 282 |
+
ja3_cluster_range=(80, 150),
|
| 283 |
+
),
|
| 284 |
+
3: TrafficProfile(
|
| 285 |
+
name="APT_Exfil", packet_mean=60.0, packet_std_frac=0.4,
|
| 286 |
+
duration_mean=15000.0, entropy_mean=0.65, entropy_std=0.06,
|
| 287 |
+
tls_probability=0.99, self_signed_prob=0.12,
|
| 288 |
+
common_ports=[443, 8443],
|
| 289 |
+
connection_reuse_mean=0.25, geo_distance_mean=4000.0,
|
| 290 |
+
history_score_mean=0.28, cert_validity_mean=120.0,
|
| 291 |
+
ja3_cluster_range=(90, 160),
|
| 292 |
+
),
|
| 293 |
+
},
|
| 294 |
+
"ddos_amplification": {
|
| 295 |
+
0: TrafficProfile(
|
| 296 |
+
name="DDoS_Probe", packet_mean=20.0, packet_std_frac=0.5,
|
| 297 |
+
duration_mean=50.0, entropy_mean=0.15, entropy_std=0.04,
|
| 298 |
+
tls_probability=0.10, self_signed_prob=0.30,
|
| 299 |
+
common_ports=[53, 123, 161, 1900],
|
| 300 |
+
connection_reuse_mean=0.02, geo_distance_mean=6000.0,
|
| 301 |
+
history_score_mean=0.08, cert_validity_mean=60.0,
|
| 302 |
+
ja3_cluster_range=(230, 255),
|
| 303 |
+
),
|
| 304 |
+
1: TrafficProfile(
|
| 305 |
+
name="DDoS_Amplify", packet_mean=500.0, packet_std_frac=0.4,
|
| 306 |
+
duration_mean=30.0, entropy_mean=0.10, entropy_std=0.03,
|
| 307 |
+
tls_probability=0.05, self_signed_prob=0.40,
|
| 308 |
+
common_ports=[53, 123, 161, 1900, 11211],
|
| 309 |
+
connection_reuse_mean=0.01, geo_distance_mean=7000.0,
|
| 310 |
+
history_score_mean=0.05, cert_validity_mean=30.0,
|
| 311 |
+
ja3_cluster_range=(240, 255),
|
| 312 |
+
),
|
| 313 |
+
2: TrafficProfile(
|
| 314 |
+
name="DDoS_Sustained", packet_mean=900.0, packet_std_frac=0.3,
|
| 315 |
+
duration_mean=20.0, entropy_mean=0.08, entropy_std=0.02,
|
| 316 |
+
tls_probability=0.03, self_signed_prob=0.50,
|
| 317 |
+
common_ports=[53, 123, 80],
|
| 318 |
+
connection_reuse_mean=0.00, geo_distance_mean=8000.0,
|
| 319 |
+
history_score_mean=0.03, cert_validity_mean=20.0,
|
| 320 |
+
ja3_cluster_range=(245, 255),
|
| 321 |
+
),
|
| 322 |
+
3: TrafficProfile(
|
| 323 |
+
name="DDoS_Peak", packet_mean=1100.0, packet_std_frac=0.25,
|
| 324 |
+
duration_mean=15.0, entropy_mean=0.06, entropy_std=0.02,
|
| 325 |
+
tls_probability=0.02, self_signed_prob=0.55,
|
| 326 |
+
common_ports=[53, 123, 80],
|
| 327 |
+
connection_reuse_mean=0.00, geo_distance_mean=9000.0,
|
| 328 |
+
history_score_mean=0.02, cert_validity_mean=15.0,
|
| 329 |
+
ja3_cluster_range=(248, 255),
|
| 330 |
+
),
|
| 331 |
+
},
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
# Fallback for unknown scenarios
|
| 335 |
+
_DEFAULT_MALICIOUS: Dict[int, TrafficProfile] = MALICIOUS_PROFILES["port_scan_exploit_c2"]
|
| 336 |
+
|
| 337 |
+
BENIGN_WEIGHTS = np.array([0.34, 0.16, 0.18, 0.12, 0.20])
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class TrafficGenerator:
|
| 341 |
+
"""Generates correlated network session feature vectors.
|
| 342 |
+
|
| 343 |
+
Each session is a dict with 'session_id', 'features' (dict),
|
| 344 |
+
and 'metadata' (malicious flag, attack info, profile name).
|
| 345 |
+
"""
|
| 346 |
+
|
| 347 |
+
def __init__(self, seed: int = 0) -> None:
|
| 348 |
+
self.rng = np.random.default_rng(seed)
|
| 349 |
+
self.session_counter = 0
|
| 350 |
+
|
| 351 |
+
def generate_benign_sessions(self, tick: int, count: int) -> List[Dict]:
|
| 352 |
+
sessions: List[Dict] = []
|
| 353 |
+
for _ in range(max(0, count)):
|
| 354 |
+
idx = self.rng.choice(len(BENIGN_PROFILES), p=BENIGN_WEIGHTS)
|
| 355 |
+
profile = BENIGN_PROFILES[idx]
|
| 356 |
+
sessions.append(self._build_session(
|
| 357 |
+
profile, tick=tick, malicious=False,
|
| 358 |
+
attack_phase=0, scenario="benign", attacker_id=None,
|
| 359 |
+
))
|
| 360 |
+
return sessions
|
| 361 |
+
|
| 362 |
+
def generate_malicious_sessions(
|
| 363 |
+
self, tick: int, count: int,
|
| 364 |
+
attack_phase: int, scenario: str,
|
| 365 |
+
attacker_id: str | None = None,
|
| 366 |
+
) -> List[Dict]:
|
| 367 |
+
sessions: List[Dict] = []
|
| 368 |
+
profiles = MALICIOUS_PROFILES.get(scenario, _DEFAULT_MALICIOUS)
|
| 369 |
+
profile = profiles.get(attack_phase, profiles[max(profiles.keys())])
|
| 370 |
+
for _ in range(max(0, count)):
|
| 371 |
+
sessions.append(self._build_session(
|
| 372 |
+
profile, tick=tick, malicious=True,
|
| 373 |
+
attack_phase=attack_phase, scenario=scenario,
|
| 374 |
+
attacker_id=attacker_id,
|
| 375 |
+
))
|
| 376 |
+
return sessions
|
| 377 |
+
|
| 378 |
+
def to_observation_vector(self, session: Dict) -> List[float]:
|
| 379 |
+
"""Return normalized [0, 1] feature vector."""
|
| 380 |
+
raw = session["features"]
|
| 381 |
+
normalized = []
|
| 382 |
+
for name in FEATURE_ORDER:
|
| 383 |
+
val = float(raw[name])
|
| 384 |
+
lo, hi = FEATURE_BOUNDS[name]
|
| 385 |
+
normalized.append(max(0.0, min(1.0, (val - lo) / max(hi - lo, 1e-9))))
|
| 386 |
+
return normalized
|
| 387 |
+
|
| 388 |
+
def to_raw_vector(self, session: Dict) -> List[float]:
|
| 389 |
+
"""Return un-normalized feature vector (for inspection)."""
|
| 390 |
+
return [float(session["features"][name]) for name in FEATURE_ORDER]
|
| 391 |
+
|
| 392 |
+
# ── Internal session builder ─────────────────────────────────────
|
| 393 |
+
|
| 394 |
+
def _build_session(
|
| 395 |
+
self, profile: TrafficProfile, tick: int,
|
| 396 |
+
malicious: bool, attack_phase: int, scenario: str,
|
| 397 |
+
attacker_id: str | None,
|
| 398 |
+
) -> Dict:
|
| 399 |
+
self.session_counter += 1
|
| 400 |
+
rng = self.rng
|
| 401 |
+
|
| 402 |
+
# --- Volume & timing (correlated cluster) ---
|
| 403 |
+
packet_count = int(max(3, rng.normal(
|
| 404 |
+
profile.packet_mean, profile.packet_mean * profile.packet_std_frac,
|
| 405 |
+
)))
|
| 406 |
+
avg_packet_size = float(max(40.0, rng.normal(560.0, 160.0)))
|
| 407 |
+
# Bytes are correlated with packets and packet size
|
| 408 |
+
bytes_sent = float(max(200.0, packet_count * avg_packet_size * rng.uniform(0.40, 0.85)))
|
| 409 |
+
bytes_received = float(max(100.0, packet_count * avg_packet_size * rng.uniform(0.20, 0.60)))
|
| 410 |
+
duration_ms = float(max(10.0, rng.normal(
|
| 411 |
+
profile.duration_mean, profile.duration_mean * 0.30,
|
| 412 |
+
)))
|
| 413 |
+
# Inter-arrival derived from duration and packet count (correlated)
|
| 414 |
+
inter_arrival_mean = float(duration_ms / max(packet_count, 1))
|
| 415 |
+
inter_arrival_jitter = float(abs(rng.normal(
|
| 416 |
+
inter_arrival_mean * 0.30, inter_arrival_mean * 0.12,
|
| 417 |
+
)))
|
| 418 |
+
packet_size_variance = float(max(5.0, abs(rng.normal(
|
| 419 |
+
180.0 if malicious else 130.0, 60.0,
|
| 420 |
+
))))
|
| 421 |
+
|
| 422 |
+
# --- TLS / certificate (correlated cluster) ---
|
| 423 |
+
tls_enabled = rng.random() < profile.tls_probability
|
| 424 |
+
tls_version = int(rng.choice([1, 2], p=[0.20, 0.80])) if tls_enabled else 0
|
| 425 |
+
# Self-signed correlates with TLS state and profile
|
| 426 |
+
is_self_signed = bool(rng.random() < profile.self_signed_prob) if tls_enabled else False
|
| 427 |
+
cert_chain_length = int(max(0, rng.normal(3.0 if (tls_enabled and not is_self_signed) else 1.0, 0.8)))
|
| 428 |
+
cert_validity_days = float(max(1.0, rng.normal(
|
| 429 |
+
profile.cert_validity_mean, profile.cert_validity_mean * 0.30,
|
| 430 |
+
)))
|
| 431 |
+
|
| 432 |
+
# --- Network metadata ---
|
| 433 |
+
dst_port = int(rng.choice(profile.common_ports))
|
| 434 |
+
src_port = int(rng.integers(1024, 65535))
|
| 435 |
+
protocol = int(rng.choice([0, 1, 2], p=[0.50, 0.32, 0.18]))
|
| 436 |
+
dns_query_count = int(max(0, rng.poisson(3 if malicious else 1)))
|
| 437 |
+
|
| 438 |
+
# --- Behavioral context (correlated with profile) ---
|
| 439 |
+
connection_reuse = float(np.clip(rng.normal(
|
| 440 |
+
profile.connection_reuse_mean, 0.12,
|
| 441 |
+
), 0.0, 1.0))
|
| 442 |
+
geo_distance = float(max(0.0, rng.normal(
|
| 443 |
+
profile.geo_distance_mean, profile.geo_distance_mean * 0.25,
|
| 444 |
+
)))
|
| 445 |
+
session_history_score = float(np.clip(rng.normal(
|
| 446 |
+
profile.history_score_mean, 0.10,
|
| 447 |
+
), 0.0, 1.0))
|
| 448 |
+
entropy_score = float(np.clip(rng.normal(
|
| 449 |
+
profile.entropy_mean, profile.entropy_std,
|
| 450 |
+
), 0.02, 0.99))
|
| 451 |
+
ja3_lo, ja3_hi = profile.ja3_cluster_range
|
| 452 |
+
ja3_hash_cluster = int(rng.integers(ja3_lo, max(ja3_lo + 1, ja3_hi)))
|
| 453 |
+
time_of_day = float((tick % 1440) / 1440.0)
|
| 454 |
+
|
| 455 |
+
features = {
|
| 456 |
+
"bytes_sent": math.log1p(bytes_sent),
|
| 457 |
+
"bytes_received": math.log1p(bytes_received),
|
| 458 |
+
"duration_ms": duration_ms,
|
| 459 |
+
"packet_count": packet_count,
|
| 460 |
+
"avg_packet_size": avg_packet_size,
|
| 461 |
+
"packet_size_variance": packet_size_variance,
|
| 462 |
+
"inter_arrival_mean": inter_arrival_mean,
|
| 463 |
+
"inter_arrival_jitter": inter_arrival_jitter,
|
| 464 |
+
"src_port": src_port,
|
| 465 |
+
"dst_port": dst_port,
|
| 466 |
+
"protocol": protocol,
|
| 467 |
+
"tls_version": tls_version,
|
| 468 |
+
"ja3_hash_cluster": ja3_hash_cluster,
|
| 469 |
+
"cert_chain_length": cert_chain_length,
|
| 470 |
+
"cert_validity_days": cert_validity_days,
|
| 471 |
+
"is_self_signed": int(is_self_signed),
|
| 472 |
+
"dns_query_count": dns_query_count,
|
| 473 |
+
"connection_reuse": connection_reuse,
|
| 474 |
+
"geo_distance": geo_distance,
|
| 475 |
+
"time_of_day": time_of_day,
|
| 476 |
+
"session_history_score": session_history_score,
|
| 477 |
+
"entropy_score": entropy_score,
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
# Session TTL: malicious sessions expire faster (pressure to act)
|
| 481 |
+
ttl = 2 if malicious else 3
|
| 482 |
+
|
| 483 |
+
return {
|
| 484 |
+
"session_id": f"s-{self.session_counter:07d}",
|
| 485 |
+
"features": features,
|
| 486 |
+
"metadata": {
|
| 487 |
+
"malicious": malicious,
|
| 488 |
+
"attack_phase": attack_phase,
|
| 489 |
+
"scenario": scenario,
|
| 490 |
+
"profile": profile.name,
|
| 491 |
+
"attacker_id": attacker_id,
|
| 492 |
+
"revealed": False,
|
| 493 |
+
},
|
| 494 |
+
"created_tick": tick,
|
| 495 |
+
"expires_tick": tick + ttl,
|
| 496 |
+
}
|
server/utils/reward_engine.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-objective reward engine for the Adaptive AI Firewall environment.
|
| 2 |
+
|
| 3 |
+
Computes R = α·security + β·availability + γ·efficiency + δ·timeliness
|
| 4 |
+
with careful balance to prevent degenerate policies (block-all / allow-all).
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from typing import Dict, Tuple
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
ACTIONS = {
|
| 13 |
+
0: "ALLOW",
|
| 14 |
+
1: "BLOCK",
|
| 15 |
+
2: "INSPECT",
|
| 16 |
+
3: "SANDBOX",
|
| 17 |
+
4: "RATE_LIMIT",
|
| 18 |
+
5: "QUARANTINE",
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
# Costs tuned so total episode cost stays well within budget range
|
| 22 |
+
ACTION_COSTS = {
|
| 23 |
+
0: {"latency": 0.0, "compute": 0.0},
|
| 24 |
+
1: {"latency": 0.0, "compute": 0.005},
|
| 25 |
+
2: {"latency": 0.08, "compute": 0.05},
|
| 26 |
+
3: {"latency": 0.20, "compute": 0.12},
|
| 27 |
+
4: {"latency": 0.02, "compute": 0.015},
|
| 28 |
+
5: {"latency": 0.05, "compute": 0.025},
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
# Actions that are considered "blocking" (remove traffic from the network)
|
| 32 |
+
BLOCKING_ACTIONS = frozenset({1, 3, 5})
|
| 33 |
+
# Actions that are considered "inspection" (gather more info)
|
| 34 |
+
INSPECTION_ACTIONS = frozenset({2})
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class RewardEngine:
|
| 38 |
+
"""Weighted multi-objective reward with anti-degeneracy safeguards.
|
| 39 |
+
|
| 40 |
+
Key design choices (from RL perspective):
|
| 41 |
+
- Miss penalty (0.8) vs false-positive penalty (0.35) ratio is ~2.3:1
|
| 42 |
+
This creates genuine tension — block-all loses ~0.35 per benign session,
|
| 43 |
+
and with 60–80% benign traffic, the total FP penalty dominates.
|
| 44 |
+
- Correct ALLOW on benign traffic gets +0.08 (dense positive signal).
|
| 45 |
+
- INSPECT keeps the session alive (handled by environment) and gets a
|
| 46 |
+
small bonus when it correctly identifies malicious traffic.
|
| 47 |
+
- Timeliness bonus decays exponentially with kill-chain phase.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
alpha: float = 0.35,
|
| 53 |
+
beta: float = 0.30,
|
| 54 |
+
gamma: float = 0.20,
|
| 55 |
+
delta: float = 0.15,
|
| 56 |
+
) -> None:
|
| 57 |
+
self.alpha = alpha
|
| 58 |
+
self.beta = beta
|
| 59 |
+
self.gamma = gamma
|
| 60 |
+
self.delta = delta
|
| 61 |
+
|
| 62 |
+
def action_cost(self, action: int) -> float:
|
| 63 |
+
costs = ACTION_COSTS.get(action, ACTION_COSTS[0])
|
| 64 |
+
return float(costs["latency"] + costs["compute"])
|
| 65 |
+
|
| 66 |
+
def reward(
|
| 67 |
+
self,
|
| 68 |
+
action: int,
|
| 69 |
+
is_malicious: bool,
|
| 70 |
+
budget_remaining: float,
|
| 71 |
+
attack_phase: int,
|
| 72 |
+
inspect_correct: bool = False,
|
| 73 |
+
) -> Tuple[float, Dict[str, float]]:
|
| 74 |
+
blocked = action in BLOCKING_ACTIONS
|
| 75 |
+
inspected = action in INSPECTION_ACTIONS
|
| 76 |
+
|
| 77 |
+
# --- Security component ---
|
| 78 |
+
r_security = 0.0
|
| 79 |
+
if is_malicious and blocked:
|
| 80 |
+
r_security += 1.0
|
| 81 |
+
elif is_malicious and not blocked and not inspected:
|
| 82 |
+
r_security -= 2.0 # missed attack
|
| 83 |
+
elif is_malicious and inspected:
|
| 84 |
+
# Inspect alone doesn't stop the attack but is a valid step
|
| 85 |
+
r_security += 0.15 if inspect_correct else -0.5
|
| 86 |
+
|
| 87 |
+
# --- Availability component ---
|
| 88 |
+
r_availability = 0.0
|
| 89 |
+
if not is_malicious and blocked:
|
| 90 |
+
r_availability -= 1.2 # strong false-positive penalty
|
| 91 |
+
elif not is_malicious and inspected:
|
| 92 |
+
r_availability -= 0.15 # unnecessary inspection delay
|
| 93 |
+
elif not is_malicious and action == 0:
|
| 94 |
+
r_availability += 0.25 # correctly allowing benign traffic
|
| 95 |
+
|
| 96 |
+
# Rate-limiting benign traffic is a milder false positive
|
| 97 |
+
if not is_malicious and action == 4:
|
| 98 |
+
r_availability -= 0.4
|
| 99 |
+
|
| 100 |
+
# --- Efficiency component ---
|
| 101 |
+
cost = self.action_cost(action)
|
| 102 |
+
# Penalize cost relative to remaining budget (bigger penalty as budget shrinks)
|
| 103 |
+
r_efficiency = -cost / max(budget_remaining, 0.1)
|
| 104 |
+
|
| 105 |
+
# --- Timeliness component ---
|
| 106 |
+
# Exponential bonus for catching attacks early in kill chain
|
| 107 |
+
early_bonus = math.exp(-max(attack_phase, 0))
|
| 108 |
+
r_timeliness = early_bonus if (is_malicious and blocked) else 0.0
|
| 109 |
+
|
| 110 |
+
total = (
|
| 111 |
+
self.alpha * r_security
|
| 112 |
+
+ self.beta * r_availability
|
| 113 |
+
+ self.gamma * r_efficiency
|
| 114 |
+
+ self.delta * r_timeliness
|
| 115 |
+
)
|
| 116 |
+
return total, {
|
| 117 |
+
"security": r_security,
|
| 118 |
+
"availability": r_availability,
|
| 119 |
+
"efficiency": r_efficiency,
|
| 120 |
+
"timeliness": r_timeliness,
|
| 121 |
+
"cost": cost,
|
| 122 |
+
}
|
server/utils/threat_engine.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-stage attack orchestrator following Cyber Kill Chain model.
|
| 2 |
+
|
| 3 |
+
Each attacker has a scenario (one of 5 patterns) and progresses through
|
| 4 |
+
phases 0→3. Adaptation is non-trivial:
|
| 5 |
+
- Detected attackers may switch to stealth mode (mimic benign profiles)
|
| 6 |
+
- Undetected attackers escalate normally
|
| 7 |
+
- Fully blocked attackers are terminated
|
| 8 |
+
- Attackers that reach exfiltration (phase 3) are marked as succeeded
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Dict, List, Set
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
# Updated import path
|
| 18 |
+
from server.utils.data_loader import TrafficGenerator
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
SCENARIOS = [
|
| 22 |
+
"port_scan_exploit_c2",
|
| 23 |
+
"credential_stuffing_lateral",
|
| 24 |
+
"supply_chain_compromise",
|
| 25 |
+
"low_and_slow_apt",
|
| 26 |
+
"ddos_amplification",
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
# How many sessions each scenario generates per phase
|
| 30 |
+
SESSION_COUNTS: Dict[str, List[int]] = {
|
| 31 |
+
"port_scan_exploit_c2": [4, 2, 1, 2],
|
| 32 |
+
"credential_stuffing_lateral": [3, 3, 2, 2],
|
| 33 |
+
"supply_chain_compromise": [1, 1, 1, 2],
|
| 34 |
+
"low_and_slow_apt": [1, 1, 1, 1],
|
| 35 |
+
"ddos_amplification": [6, 10, 15, 20],
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
# Probability that an attacker escalates per tick (if not detected)
|
| 39 |
+
ESCALATION_PROB: Dict[str, float] = {
|
| 40 |
+
"port_scan_exploit_c2": 0.30,
|
| 41 |
+
"credential_stuffing_lateral": 0.25,
|
| 42 |
+
"supply_chain_compromise": 0.15,
|
| 43 |
+
"low_and_slow_apt": 0.10,
|
| 44 |
+
"ddos_amplification": 0.40,
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
|
| 49 |
+
class AttackerState:
|
| 50 |
+
attacker_id: str
|
| 51 |
+
scenario: str
|
| 52 |
+
phase: int = 0
|
| 53 |
+
times_detected: int = 0
|
| 54 |
+
stealth_mode: bool = False
|
| 55 |
+
alive: bool = True
|
| 56 |
+
succeeded: bool = False
|
| 57 |
+
ticks_alive: int = 0
|
| 58 |
+
sessions_blocked: int = 0
|
| 59 |
+
sessions_generated: int = 0
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ThreatEngine:
|
| 63 |
+
"""Manages the lifecycle of active attackers and generates attack sessions."""
|
| 64 |
+
|
| 65 |
+
def __init__(self, seed: int = 0) -> None:
|
| 66 |
+
self.rng = np.random.default_rng(seed)
|
| 67 |
+
self._attacker_counter = 0
|
| 68 |
+
self._active_attackers: Dict[str, AttackerState] = {}
|
| 69 |
+
self._dead_attackers: List[AttackerState] = []
|
| 70 |
+
self._threat_intel: Dict = {
|
| 71 |
+
"known_bad_ports": [21, 22, 23, 25, 445, 3389, 5900],
|
| 72 |
+
"known_bad_ja3_ranges": [(200, 255), (230, 255)],
|
| 73 |
+
"active_campaigns": [],
|
| 74 |
+
"recent_detections": 0,
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
def reset(self) -> None:
|
| 78 |
+
self._attacker_counter = 0
|
| 79 |
+
self._active_attackers = {}
|
| 80 |
+
self._dead_attackers = []
|
| 81 |
+
self._threat_intel["active_campaigns"] = []
|
| 82 |
+
self._threat_intel["recent_detections"] = 0
|
| 83 |
+
|
| 84 |
+
def maybe_spawn_attacker(self, threat_probability: float) -> None:
|
| 85 |
+
"""Probabilistically spawn a new attacker."""
|
| 86 |
+
if self.rng.random() > threat_probability:
|
| 87 |
+
return
|
| 88 |
+
self._attacker_counter += 1
|
| 89 |
+
scenario = SCENARIOS[int(self.rng.integers(0, len(SCENARIOS)))]
|
| 90 |
+
attacker_id = f"a-{self._attacker_counter:04d}"
|
| 91 |
+
state = AttackerState(attacker_id=attacker_id, scenario=scenario)
|
| 92 |
+
self._active_attackers[attacker_id] = state
|
| 93 |
+
# Update threat intel
|
| 94 |
+
campaigns = set(self._threat_intel["active_campaigns"])
|
| 95 |
+
campaigns.add(scenario)
|
| 96 |
+
self._threat_intel["active_campaigns"] = sorted(campaigns)
|
| 97 |
+
|
| 98 |
+
def generate_attack_sessions(
|
| 99 |
+
self, tick: int, generator: TrafficGenerator,
|
| 100 |
+
blocked_attackers: Set[str],
|
| 101 |
+
) -> List[Dict]:
|
| 102 |
+
"""Generate attack sessions for all active attackers, handling adaptation."""
|
| 103 |
+
sessions: List[Dict] = []
|
| 104 |
+
|
| 105 |
+
for attacker in list(self._active_attackers.values()):
|
| 106 |
+
if not attacker.alive:
|
| 107 |
+
continue
|
| 108 |
+
|
| 109 |
+
attacker.ticks_alive += 1
|
| 110 |
+
|
| 111 |
+
# --- Handle detection / blocking ---
|
| 112 |
+
if attacker.attacker_id in blocked_attackers:
|
| 113 |
+
attacker.times_detected += 1
|
| 114 |
+
attacker.sessions_blocked += 1
|
| 115 |
+
self._threat_intel["recent_detections"] += 1
|
| 116 |
+
|
| 117 |
+
if attacker.times_detected >= 3:
|
| 118 |
+
# Fully blocked — attacker gives up
|
| 119 |
+
attacker.alive = False
|
| 120 |
+
self._dead_attackers.append(attacker)
|
| 121 |
+
continue
|
| 122 |
+
elif attacker.times_detected >= 2:
|
| 123 |
+
# Switch to stealth mode — generate fewer, more benign-looking sessions
|
| 124 |
+
attacker.stealth_mode = True
|
| 125 |
+
else:
|
| 126 |
+
# First detection — try to advance past detected phase
|
| 127 |
+
attacker.phase = min(attacker.phase + 1, 3)
|
| 128 |
+
|
| 129 |
+
# --- Natural phase escalation ---
|
| 130 |
+
elif self.rng.random() < ESCALATION_PROB.get(attacker.scenario, 0.2):
|
| 131 |
+
attacker.phase = min(attacker.phase + 1, 3)
|
| 132 |
+
|
| 133 |
+
# --- Check for success (exfiltration complete) ---
|
| 134 |
+
if attacker.phase == 3 and attacker.ticks_alive > 8:
|
| 135 |
+
if self.rng.random() < 0.15:
|
| 136 |
+
attacker.succeeded = True
|
| 137 |
+
attacker.alive = False
|
| 138 |
+
self._dead_attackers.append(attacker)
|
| 139 |
+
continue
|
| 140 |
+
|
| 141 |
+
# --- Generate sessions based on current state ---
|
| 142 |
+
counts = SESSION_COUNTS.get(attacker.scenario, [2, 2, 2, 2])
|
| 143 |
+
count = counts[min(attacker.phase, 3)]
|
| 144 |
+
|
| 145 |
+
if attacker.stealth_mode:
|
| 146 |
+
# In stealth mode: reduce count, use profiles that look more benign
|
| 147 |
+
count = max(1, count // 2)
|
| 148 |
+
|
| 149 |
+
generated = generator.generate_malicious_sessions(
|
| 150 |
+
tick=tick,
|
| 151 |
+
count=count,
|
| 152 |
+
attack_phase=attacker.phase,
|
| 153 |
+
scenario=attacker.scenario,
|
| 154 |
+
attacker_id=attacker.attacker_id,
|
| 155 |
+
)
|
| 156 |
+
attacker.sessions_generated += len(generated)
|
| 157 |
+
sessions.extend(generated)
|
| 158 |
+
|
| 159 |
+
return sessions
|
| 160 |
+
|
| 161 |
+
def intelligence_feed(self) -> Dict:
|
| 162 |
+
"""Return threat intelligence available to the agent."""
|
| 163 |
+
active_scenarios = set()
|
| 164 |
+
for a in self._active_attackers.values():
|
| 165 |
+
if a.alive:
|
| 166 |
+
active_scenarios.add(a.scenario)
|
| 167 |
+
self._threat_intel["active_campaigns"] = sorted(active_scenarios)
|
| 168 |
+
return dict(self._threat_intel)
|
| 169 |
+
|
| 170 |
+
def attacker_outcomes(self) -> Dict[str, str]:
|
| 171 |
+
"""Return status of all known attackers (for info/debugging)."""
|
| 172 |
+
outcomes: Dict[str, str] = {}
|
| 173 |
+
for a in self._active_attackers.values():
|
| 174 |
+
if a.alive:
|
| 175 |
+
outcomes[a.attacker_id] = "active"
|
| 176 |
+
elif a.succeeded:
|
| 177 |
+
outcomes[a.attacker_id] = "succeeded"
|
| 178 |
+
else:
|
| 179 |
+
outcomes[a.attacker_id] = "stopped"
|
| 180 |
+
for a in self._dead_attackers:
|
| 181 |
+
if a.attacker_id not in outcomes:
|
| 182 |
+
outcomes[a.attacker_id] = "succeeded" if a.succeeded else "stopped"
|
| 183 |
+
return outcomes
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from server.baseline.heuristic_agent import heuristic_policy
|
| 6 |
+
from server.baseline.random_agent import random_policy
|
| 7 |
+
from server.firewall_environment import FirewallEnvironment
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@pytest.fixture
|
| 11 |
+
def env_easy() -> FirewallEnvironment:
|
| 12 |
+
env = FirewallEnvironment(seed=101)
|
| 13 |
+
env.reset(task="easy", seed=101)
|
| 14 |
+
return env
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@pytest.fixture
|
| 18 |
+
def env_medium() -> FirewallEnvironment:
|
| 19 |
+
env = FirewallEnvironment(seed=202)
|
| 20 |
+
env.reset(task="medium", seed=202)
|
| 21 |
+
return env
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@pytest.fixture
|
| 25 |
+
def env_hard() -> FirewallEnvironment:
|
| 26 |
+
env = FirewallEnvironment(seed=303)
|
| 27 |
+
env.reset(task="hard", seed=303)
|
| 28 |
+
return env
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@pytest.fixture
|
| 32 |
+
def random_agent_policy():
|
| 33 |
+
return random_policy(seed=9)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@pytest.fixture
|
| 37 |
+
def heuristic_agent_policy():
|
| 38 |
+
return heuristic_policy
|
tests/test_all.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Comprehensive tests for the Adaptive AI Firewall environment.
|
| 2 |
+
|
| 3 |
+
Covers: feature generation, reward mechanics, threat lifecycle,
|
| 4 |
+
grading determinism, degenerate policy detection, and budget management.
|
| 5 |
+
"""
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from server.utils.data_loader import FEATURE_ORDER, TrafficGenerator
|
| 9 |
+
from server.utils.threat_engine import ThreatEngine
|
| 10 |
+
from server.utils.reward_engine import RewardEngine
|
| 11 |
+
from server.firewall_environment import (
|
| 12 |
+
FirewallEnvironment, OBS_DIM, NUM_ACTIONS,
|
| 13 |
+
)
|
| 14 |
+
from server.graders import run_deterministic_grade, grade_stats
|
| 15 |
+
from server.baseline.random_agent import random_policy, block_all_policy
|
| 16 |
+
from server.baseline.heuristic_agent import heuristic_policy
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ═══════════════════════════════════════════════════════════════════
|
| 20 |
+
# Traffic Generator
|
| 21 |
+
# ═══════════════════════════════════════════════════════════════════
|
| 22 |
+
|
| 23 |
+
class TestTrafficGenerator:
|
| 24 |
+
def test_feature_dimension(self):
|
| 25 |
+
gen = TrafficGenerator(seed=11)
|
| 26 |
+
session = gen.generate_benign_sessions(tick=0, count=1)[0]
|
| 27 |
+
assert len(FEATURE_ORDER) == 22
|
| 28 |
+
assert len(gen.to_observation_vector(session)) == 22
|
| 29 |
+
|
| 30 |
+
def test_normalized_features_in_0_1(self):
|
| 31 |
+
gen = TrafficGenerator(seed=42)
|
| 32 |
+
for _ in range(50):
|
| 33 |
+
session = gen.generate_benign_sessions(tick=0, count=1)[0]
|
| 34 |
+
obs = gen.to_observation_vector(session)
|
| 35 |
+
for i, val in enumerate(obs):
|
| 36 |
+
assert 0.0 <= val <= 1.0, f"Feature {FEATURE_ORDER[i]} = {val} out of [0,1]"
|
| 37 |
+
|
| 38 |
+
def test_malicious_features_normalized(self):
|
| 39 |
+
gen = TrafficGenerator(seed=55)
|
| 40 |
+
for scenario in ["port_scan_exploit_c2", "ddos_amplification", "supply_chain_compromise"]:
|
| 41 |
+
for phase in range(4):
|
| 42 |
+
sessions = gen.generate_malicious_sessions(
|
| 43 |
+
tick=0, count=3, attack_phase=phase, scenario=scenario,
|
| 44 |
+
)
|
| 45 |
+
for s in sessions:
|
| 46 |
+
obs = gen.to_observation_vector(s)
|
| 47 |
+
for i, val in enumerate(obs):
|
| 48 |
+
assert 0.0 <= val <= 1.0
|
| 49 |
+
|
| 50 |
+
def test_benign_malicious_separation(self):
|
| 51 |
+
"""Verify that malicious and benign sessions have statistically different features."""
|
| 52 |
+
gen = TrafficGenerator(seed=77)
|
| 53 |
+
benign_vecs = []
|
| 54 |
+
for _ in range(100):
|
| 55 |
+
s = gen.generate_benign_sessions(tick=0, count=1)[0]
|
| 56 |
+
benign_vecs.append(gen.to_observation_vector(s))
|
| 57 |
+
|
| 58 |
+
mal_vecs = []
|
| 59 |
+
for phase in range(4):
|
| 60 |
+
for _ in range(25):
|
| 61 |
+
s = gen.generate_malicious_sessions(
|
| 62 |
+
tick=0, count=1, attack_phase=phase,
|
| 63 |
+
scenario="port_scan_exploit_c2",
|
| 64 |
+
)[0]
|
| 65 |
+
mal_vecs.append(gen.to_observation_vector(s))
|
| 66 |
+
|
| 67 |
+
benign_arr = np.array(benign_vecs)
|
| 68 |
+
mal_arr = np.array(mal_vecs)
|
| 69 |
+
|
| 70 |
+
# At least some features should have meaningfully different means
|
| 71 |
+
mean_diff = np.abs(benign_arr.mean(axis=0) - mal_arr.mean(axis=0))
|
| 72 |
+
significant_features = (mean_diff > 0.08).sum()
|
| 73 |
+
assert significant_features >= 5, (
|
| 74 |
+
f"Only {significant_features} features differ — distributions too similar"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
def test_session_ids_unique(self):
|
| 78 |
+
gen = TrafficGenerator(seed=99)
|
| 79 |
+
ids = set()
|
| 80 |
+
for _ in range(100):
|
| 81 |
+
sessions = gen.generate_benign_sessions(tick=0, count=3)
|
| 82 |
+
for s in sessions:
|
| 83 |
+
assert s["session_id"] not in ids
|
| 84 |
+
ids.add(s["session_id"])
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ═══════════════════════════════════════════════════════════════════
|
| 88 |
+
# Reward Engine
|
| 89 |
+
# ═══════════════════════════════════════════════════════════════════
|
| 90 |
+
|
| 91 |
+
class TestRewardEngine:
|
| 92 |
+
def test_block_malicious_positive(self):
|
| 93 |
+
eng = RewardEngine()
|
| 94 |
+
r, _ = eng.reward(action=1, is_malicious=True, budget_remaining=50.0, attack_phase=0)
|
| 95 |
+
assert r > 0
|
| 96 |
+
|
| 97 |
+
def test_miss_malicious_negative(self):
|
| 98 |
+
eng = RewardEngine()
|
| 99 |
+
r, _ = eng.reward(action=0, is_malicious=True, budget_remaining=50.0, attack_phase=2)
|
| 100 |
+
assert r < 0
|
| 101 |
+
|
| 102 |
+
def test_block_benign_negative(self):
|
| 103 |
+
eng = RewardEngine()
|
| 104 |
+
r, _ = eng.reward(action=1, is_malicious=False, budget_remaining=50.0, attack_phase=0)
|
| 105 |
+
assert r < 0
|
| 106 |
+
|
| 107 |
+
def test_allow_benign_positive(self):
|
| 108 |
+
eng = RewardEngine()
|
| 109 |
+
r, _ = eng.reward(action=0, is_malicious=False, budget_remaining=50.0, attack_phase=0)
|
| 110 |
+
assert r > 0, "Correctly allowing benign traffic should be rewarded"
|
| 111 |
+
|
| 112 |
+
def test_block_all_loses_in_mixed_traffic(self):
|
| 113 |
+
"""Block-all should have negative total reward on benign-heavy traffic."""
|
| 114 |
+
eng = RewardEngine()
|
| 115 |
+
total = 0.0
|
| 116 |
+
# Simulate 80% benign, 20% malicious
|
| 117 |
+
for _ in range(80):
|
| 118 |
+
r, _ = eng.reward(action=1, is_malicious=False, budget_remaining=50.0, attack_phase=0)
|
| 119 |
+
total += r
|
| 120 |
+
for _ in range(20):
|
| 121 |
+
r, _ = eng.reward(action=1, is_malicious=True, budget_remaining=50.0, attack_phase=1)
|
| 122 |
+
total += r
|
| 123 |
+
# Block-all should have lower score than a selective policy
|
| 124 |
+
assert total < 0, f"Block-all total reward {total} should be negative on 80/20 mix"
|
| 125 |
+
|
| 126 |
+
def test_early_detection_bonus(self):
|
| 127 |
+
eng = RewardEngine()
|
| 128 |
+
r_early, _ = eng.reward(action=1, is_malicious=True, budget_remaining=50.0, attack_phase=0)
|
| 129 |
+
r_late, _ = eng.reward(action=1, is_malicious=True, budget_remaining=50.0, attack_phase=3)
|
| 130 |
+
assert r_early > r_late, "Early detection should give higher reward"
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ═══════════════════════════════════════════════════════════════════
|
| 134 |
+
# Threat Engine
|
| 135 |
+
# ═══════════════════════════════════════════════════════════════════
|
| 136 |
+
|
| 137 |
+
class TestThreatEngine:
|
| 138 |
+
def test_spawn_and_generate(self):
|
| 139 |
+
engine = ThreatEngine(seed=22)
|
| 140 |
+
gen = TrafficGenerator(seed=23)
|
| 141 |
+
engine.maybe_spawn_attacker(1.0)
|
| 142 |
+
sessions = engine.generate_attack_sessions(tick=0, generator=gen, blocked_attackers=set())
|
| 143 |
+
assert len(sessions) > 0
|
| 144 |
+
assert all(s["metadata"]["malicious"] for s in sessions)
|
| 145 |
+
|
| 146 |
+
def test_attacker_dies_after_3_blocks(self):
|
| 147 |
+
engine = ThreatEngine(seed=33)
|
| 148 |
+
gen = TrafficGenerator(seed=34)
|
| 149 |
+
engine.maybe_spawn_attacker(1.0)
|
| 150 |
+
attacker_id = list(engine._active_attackers.keys())[0]
|
| 151 |
+
|
| 152 |
+
for _ in range(3):
|
| 153 |
+
engine.generate_attack_sessions(
|
| 154 |
+
tick=0, generator=gen, blocked_attackers={attacker_id},
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# After 3 blocks, attacker should be dead
|
| 158 |
+
attacker = engine._active_attackers[attacker_id]
|
| 159 |
+
assert not attacker.alive
|
| 160 |
+
|
| 161 |
+
def test_attacker_outcomes(self):
|
| 162 |
+
engine = ThreatEngine(seed=44)
|
| 163 |
+
gen = TrafficGenerator(seed=45)
|
| 164 |
+
engine.maybe_spawn_attacker(1.0)
|
| 165 |
+
engine.generate_attack_sessions(tick=0, generator=gen, blocked_attackers=set())
|
| 166 |
+
outcomes = engine.attacker_outcomes()
|
| 167 |
+
assert len(outcomes) > 0
|
| 168 |
+
assert all(v in ("active", "stopped", "succeeded") for v in outcomes.values())
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# ═══════════════════════════════════════════════════════════════════
|
| 172 |
+
# Firewall Environment
|
| 173 |
+
# ═══════════════════════════════════════════════════════════════════
|
| 174 |
+
|
| 175 |
+
class TestFirewallEnvironment:
|
| 176 |
+
def test_reset_returns_valid_state(self):
|
| 177 |
+
env = FirewallEnvironment(seed=99)
|
| 178 |
+
state = env.reset(task="easy", seed=100)
|
| 179 |
+
assert state["observation_dim"] == OBS_DIM
|
| 180 |
+
assert state["num_actions"] == NUM_ACTIONS
|
| 181 |
+
assert state["budget_remaining"] > 0
|
| 182 |
+
|
| 183 |
+
def test_step_returns_expected_keys(self):
|
| 184 |
+
env = FirewallEnvironment(seed=99)
|
| 185 |
+
env.reset(task="easy", seed=100)
|
| 186 |
+
pending = list(env.pending_sessions.keys())
|
| 187 |
+
actions = {sid: 0 for sid in pending[:3]}
|
| 188 |
+
response = env.step(actions)
|
| 189 |
+
assert "reward" in response
|
| 190 |
+
assert "done" in response
|
| 191 |
+
assert "state" in response
|
| 192 |
+
|
| 193 |
+
def test_inspect_keeps_session_alive(self):
|
| 194 |
+
env = FirewallEnvironment(seed=50)
|
| 195 |
+
env.reset(task="easy", seed=50)
|
| 196 |
+
sid = list(env.pending_sessions.keys())[0]
|
| 197 |
+
env._apply_action(sid, 2) # INSPECT
|
| 198 |
+
assert sid in env.inspected_sessions, "INSPECT should keep session in inspected pool"
|
| 199 |
+
|
| 200 |
+
def test_inspect_then_block(self):
|
| 201 |
+
"""Two-phase: inspect → block."""
|
| 202 |
+
env = FirewallEnvironment(seed=60)
|
| 203 |
+
env.reset(task="easy", seed=60)
|
| 204 |
+
sid = list(env.pending_sessions.keys())[0]
|
| 205 |
+
|
| 206 |
+
# Phase 1: inspect
|
| 207 |
+
r1, _ = env._apply_action(sid, 2)
|
| 208 |
+
assert sid in env.inspected_sessions
|
| 209 |
+
|
| 210 |
+
# Phase 2: block
|
| 211 |
+
r2, _ = env._apply_action(sid, 1)
|
| 212 |
+
assert sid not in env.inspected_sessions
|
| 213 |
+
|
| 214 |
+
def test_budget_stays_positive_with_allow(self):
|
| 215 |
+
"""All-allow policy should preserve most of the budget."""
|
| 216 |
+
env = FirewallEnvironment(seed=70)
|
| 217 |
+
env.reset(task="easy", seed=70)
|
| 218 |
+
initial = env.budget_remaining
|
| 219 |
+
for _ in range(50):
|
| 220 |
+
sids = list(env.pending_sessions.keys())
|
| 221 |
+
if not sids:
|
| 222 |
+
break
|
| 223 |
+
env.step({sid: 0 for sid in sids})
|
| 224 |
+
# ALLOW costs 0, so budget should barely change
|
| 225 |
+
assert env.budget_remaining >= initial * 0.95
|
| 226 |
+
|
| 227 |
+
def test_budget_nonzero_with_reasonable_policy(self):
|
| 228 |
+
"""Heuristic policy should leave some budget remaining."""
|
| 229 |
+
env = FirewallEnvironment(seed=80)
|
| 230 |
+
env.reset(task="easy", seed=80)
|
| 231 |
+
for _ in range(env.max_steps):
|
| 232 |
+
sids = (
|
| 233 |
+
list(env.inspected_sessions.keys())
|
| 234 |
+
+ list(env.pending_sessions.keys())
|
| 235 |
+
)
|
| 236 |
+
actions = heuristic_policy(env, sids)
|
| 237 |
+
resp = env.step(actions)
|
| 238 |
+
if resp["done"]:
|
| 239 |
+
break
|
| 240 |
+
stats = env.get_network_stats()
|
| 241 |
+
assert stats["efficiency"] > 0.0, f"Efficiency should be > 0, got {stats['efficiency']}"
|
| 242 |
+
|
| 243 |
+
def test_expired_malicious_counted_in_metrics(self):
|
| 244 |
+
"""Expired malicious sessions must be counted in totals."""
|
| 245 |
+
env = FirewallEnvironment(seed=90)
|
| 246 |
+
env.reset(task="easy", seed=90)
|
| 247 |
+
# Let everything expire by stepping with no actions
|
| 248 |
+
for _ in range(10):
|
| 249 |
+
env.step({})
|
| 250 |
+
stats = env.get_network_stats()
|
| 251 |
+
if stats["total_malicious"] > 0:
|
| 252 |
+
# expired malicious should be counted
|
| 253 |
+
assert stats["expired_malicious"] > 0
|
| 254 |
+
|
| 255 |
+
def test_single_session_mode(self):
|
| 256 |
+
"""step_single returns valid observation and reward."""
|
| 257 |
+
env = FirewallEnvironment(seed=100)
|
| 258 |
+
env.reset(task="easy", seed=100)
|
| 259 |
+
result = env.step_single(0) # ALLOW
|
| 260 |
+
assert len(result["observation"]) == OBS_DIM
|
| 261 |
+
assert "reward" in result
|
| 262 |
+
assert "done" in result
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# ═══════════════════════════════════════════════════════════════════
|
| 266 |
+
# Graders
|
| 267 |
+
# ═══════════════════════════════════════════════════════════════════
|
| 268 |
+
|
| 269 |
+
class TestGraders:
|
| 270 |
+
def test_deterministic_grading(self):
|
| 271 |
+
env = FirewallEnvironment(seed=31)
|
| 272 |
+
p1 = random_policy(seed=9)
|
| 273 |
+
first = run_deterministic_grade(env, task="easy", policy=p1)["score"]
|
| 274 |
+
p2 = random_policy(seed=9)
|
| 275 |
+
second = run_deterministic_grade(env, task="easy", policy=p2)["score"]
|
| 276 |
+
assert first == second, "Same seed should produce same score"
|
| 277 |
+
|
| 278 |
+
def test_score_in_valid_range(self):
|
| 279 |
+
env = FirewallEnvironment(seed=40)
|
| 280 |
+
for task in ("easy", "medium", "hard"):
|
| 281 |
+
policy = random_policy(seed=7)
|
| 282 |
+
result = run_deterministic_grade(env, task=task, policy=policy)
|
| 283 |
+
assert 0.0 <= result["score"] <= 1.0
|
| 284 |
+
|
| 285 |
+
def test_heuristic_beats_random(self):
|
| 286 |
+
"""Core sanity check: heuristic > random on easy task."""
|
| 287 |
+
env = FirewallEnvironment(seed=50)
|
| 288 |
+
rp = random_policy(seed=7)
|
| 289 |
+
r_score = run_deterministic_grade(env, task="easy", policy=rp)["score"]
|
| 290 |
+
h_score = run_deterministic_grade(env, task="easy", policy=heuristic_policy)["score"]
|
| 291 |
+
assert h_score > r_score, (
|
| 292 |
+
f"Heuristic ({h_score:.4f}) must beat random ({r_score:.4f}) on easy task"
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
def test_heuristic_beats_block_all(self):
|
| 296 |
+
"""Block-all should not dominate heuristic."""
|
| 297 |
+
env = FirewallEnvironment(seed=60)
|
| 298 |
+
b_score = run_deterministic_grade(env, task="easy", policy=block_all_policy)["score"]
|
| 299 |
+
h_score = run_deterministic_grade(env, task="easy", policy=heuristic_policy)["score"]
|
| 300 |
+
assert h_score > b_score, (
|
| 301 |
+
f"Heuristic ({h_score:.4f}) must beat block-all ({b_score:.4f})"
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
def test_grade_stats_clamps(self):
|
| 305 |
+
stats = {"detection_rate": 1.5, "false_positive_rate": -0.5, "efficiency": 2.0}
|
| 306 |
+
result = grade_stats("easy", stats)
|
| 307 |
+
assert result["score"] <= 1.0
|
tests/test_environment_dynamics.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from server.baseline.heuristic_agent import heuristic_policy
|
| 2 |
+
from server.firewall_environment import FirewallEnvironment
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def test_expired_malicious_sessions_are_counted():
|
| 6 |
+
env = FirewallEnvironment(seed=11)
|
| 7 |
+
env.reset(task="easy", seed=11)
|
| 8 |
+
before = env.metrics.malicious_seen
|
| 9 |
+
for _ in range(4):
|
| 10 |
+
env.step({})
|
| 11 |
+
after = env.metrics.malicious_seen
|
| 12 |
+
assert after >= before
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_inspect_keeps_session_pending_and_reveals():
|
| 16 |
+
env = FirewallEnvironment(seed=12)
|
| 17 |
+
env.reset(task="easy", seed=12)
|
| 18 |
+
session_id = next(iter(env.pending_sessions.keys()))
|
| 19 |
+
env.take_action(session_id=session_id, action=2)
|
| 20 |
+
assert session_id in env.pending_sessions
|
| 21 |
+
assert env.pending_sessions[session_id]["metadata"]["revealed"] is True
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def test_step_single_has_fixed_size_action_mode():
|
| 25 |
+
env = FirewallEnvironment(seed=13)
|
| 26 |
+
env.reset(task="easy", seed=13)
|
| 27 |
+
response = env.step_single(action=0)
|
| 28 |
+
assert "focus_observation" in response["state"]
|
| 29 |
+
assert len(response["state"]["focus_observation"]) == 22
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_budget_is_scaled_by_episode_length():
|
| 33 |
+
env = FirewallEnvironment(seed=14, budget=50.0)
|
| 34 |
+
env.reset(task="hard", seed=14)
|
| 35 |
+
assert env.initial_budget >= env.max_steps * 0.35
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_attacker_outcomes_exposed_in_step_info():
|
| 39 |
+
env = FirewallEnvironment(seed=15)
|
| 40 |
+
env.reset(task="easy", seed=15)
|
| 41 |
+
session_id = next(iter(env.pending_sessions.keys()))
|
| 42 |
+
result = env.step({session_id: 1})
|
| 43 |
+
assert "attacker_outcomes" in result["info"]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_heuristic_policy_executes_over_pending_sessions():
|
| 47 |
+
env = FirewallEnvironment(seed=16)
|
| 48 |
+
env.reset(task="easy", seed=16)
|
| 49 |
+
actions = heuristic_policy(env, list(env.pending_sessions.keys())[:5])
|
| 50 |
+
assert isinstance(actions, dict)
|
tests/test_integration_policies.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from server.baseline.heuristic_agent import heuristic_policy
|
| 2 |
+
from server.baseline.random_agent import random_policy
|
| 3 |
+
from server.firewall_environment import FirewallEnvironment
|
| 4 |
+
from server.graders import run_deterministic_grade
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def always_allow_policy(_, session_ids):
|
| 8 |
+
return {sid: 0 for sid in session_ids}
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def always_block_policy(_, session_ids):
|
| 12 |
+
return {sid: 1 for sid in session_ids}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_policy_ordering_easy_task():
|
| 16 |
+
env = FirewallEnvironment(seed=77)
|
| 17 |
+
random_score = run_deterministic_grade(env, task="easy", policy=random_policy(seed=7))["score"]
|
| 18 |
+
heuristic_score = run_deterministic_grade(env, task="easy", policy=heuristic_policy)["score"]
|
| 19 |
+
allow_score = run_deterministic_grade(env, task="easy", policy=always_allow_policy)["score"]
|
| 20 |
+
assert heuristic_score >= random_score
|
| 21 |
+
assert heuristic_score >= allow_score
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def test_block_all_is_not_best_strategy():
|
| 25 |
+
env = FirewallEnvironment(seed=88)
|
| 26 |
+
for task in ("easy", "medium", "hard"):
|
| 27 |
+
block_score = run_deterministic_grade(env, task=task, policy=always_block_policy)["score"]
|
| 28 |
+
heuristic_score = run_deterministic_grade(env, task=task, policy=heuristic_policy)["score"]
|
| 29 |
+
assert block_score <= heuristic_score
|
tests/test_reward_and_scores.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from server.firewall_environment import FirewallEnvironment
|
| 2 |
+
from server.graders import grade_stats
|
| 3 |
+
from server.utils.reward_engine import RewardEngine
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_grade_score_bounds():
|
| 7 |
+
stats = {
|
| 8 |
+
"detection_rate": 0.5,
|
| 9 |
+
"false_positive_rate": 0.1,
|
| 10 |
+
"efficiency": 0.8,
|
| 11 |
+
"early_detection_bonus": 0.7,
|
| 12 |
+
"cascade_prevention": 0.6,
|
| 13 |
+
}
|
| 14 |
+
for task in ("easy", "medium", "hard"):
|
| 15 |
+
score = grade_stats(task, stats)["score"]
|
| 16 |
+
assert 0.0 <= score <= 1.0
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_reward_range_is_reasonable():
|
| 20 |
+
engine = RewardEngine()
|
| 21 |
+
samples = [
|
| 22 |
+
engine.reward(action=0, is_malicious=False, budget_remaining=100.0, attack_phase=0)[0],
|
| 23 |
+
engine.reward(action=1, is_malicious=False, budget_remaining=100.0, attack_phase=0)[0],
|
| 24 |
+
engine.reward(action=1, is_malicious=True, budget_remaining=100.0, attack_phase=1)[0],
|
| 25 |
+
engine.reward(action=0, is_malicious=True, budget_remaining=100.0, attack_phase=3)[0],
|
| 26 |
+
]
|
| 27 |
+
assert min(samples) > -2.5
|
| 28 |
+
assert max(samples) < 2.5
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def test_efficiency_is_non_zero_after_episode():
|
| 32 |
+
env = FirewallEnvironment(seed=66)
|
| 33 |
+
env.reset(task="medium", seed=66)
|
| 34 |
+
done = False
|
| 35 |
+
while not done:
|
| 36 |
+
response = env.step({})
|
| 37 |
+
done = response["done"]
|
| 38 |
+
stats = env.get_network_stats()
|
| 39 |
+
assert stats["efficiency"] > 0.0
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|