Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- Dockerfile +36 -0
- baseline.py +247 -0
- env.py +641 -0
- main.py +199 -0
- openenv.yaml +172 -0
Dockerfile
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ─── Build stage ─────────────────────────────────────────────────────────────────
|
| 2 |
+
FROM python:3.11-slim AS builder
|
| 3 |
+
|
| 4 |
+
WORKDIR /build
|
| 5 |
+
|
| 6 |
+
# Install dependencies into a prefix we'll copy to the final image
|
| 7 |
+
COPY requirements.txt .
|
| 8 |
+
RUN pip install --no-cache-dir --prefix=/install -r requirements.txt
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ─── Runtime stage ────────────────────────────────────────────────────────────────
|
| 12 |
+
FROM python:3.11-slim
|
| 13 |
+
|
| 14 |
+
# Hugging Face Spaces expects the app on port 7860
|
| 15 |
+
ENV PORT=7860
|
| 16 |
+
ENV PYTHONUNBUFFERED=1
|
| 17 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 18 |
+
|
| 19 |
+
WORKDIR /app
|
| 20 |
+
|
| 21 |
+
# Copy pre-installed packages from builder
|
| 22 |
+
COPY --from=builder /install /usr/local
|
| 23 |
+
|
| 24 |
+
# Copy application code
|
| 25 |
+
COPY env.py .
|
| 26 |
+
COPY main.py .
|
| 27 |
+
COPY baseline.py .
|
| 28 |
+
|
| 29 |
+
# HF Spaces: non-root user for safety
|
| 30 |
+
RUN useradd -m -u 1000 appuser && chown -R appuser /app
|
| 31 |
+
USER appuser
|
| 32 |
+
|
| 33 |
+
EXPOSE 7860
|
| 34 |
+
|
| 35 |
+
# Increase workers for concurrent evaluation runs
|
| 36 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
baseline.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Baseline Inference Script — API Gateway Defender
|
| 3 |
+
=================================================
|
| 4 |
+
Evaluates an agent on all 3 tasks and prints reproducible scores.
|
| 5 |
+
|
| 6 |
+
Usage
|
| 7 |
+
-----
|
| 8 |
+
# With LLM (reads OPENAI_API_KEY from environment):
|
| 9 |
+
OPENAI_API_KEY=sk-... python baseline.py
|
| 10 |
+
|
| 11 |
+
# Heuristic fallback (no API key needed):
|
| 12 |
+
python baseline.py
|
| 13 |
+
|
| 14 |
+
The LLM agent receives the traffic logs and task description, then
|
| 15 |
+
produces a JSON action that is submitted to the environment.
|
| 16 |
+
|
| 17 |
+
The heuristic agent reads the visible logs statistically and picks
|
| 18 |
+
the correct rule — used to verify the grader is working correctly
|
| 19 |
+
and as a reproducible baseline for submission.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import sys
|
| 25 |
+
import urllib.error
|
| 26 |
+
import urllib.request
|
| 27 |
+
from typing import Any, Dict
|
| 28 |
+
|
| 29 |
+
# Allow running standalone (before FastAPI starts) by importing env directly
|
| 30 |
+
try:
|
| 31 |
+
from env import (
|
| 32 |
+
Action,
|
| 33 |
+
APIGatewayDefender,
|
| 34 |
+
TASK_DESCRIPTIONS,
|
| 35 |
+
run_heuristic_baseline,
|
| 36 |
+
)
|
| 37 |
+
_DIRECT_IMPORT = True
|
| 38 |
+
except ImportError:
|
| 39 |
+
_DIRECT_IMPORT = False
|
| 40 |
+
|
| 41 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
| 42 |
+
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
|
| 43 |
+
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ─── OpenAI helper ───────────────────────────────────────────────────────────────
|
| 47 |
+
|
| 48 |
+
def _call_openai(messages: list, max_tokens: int = 512) -> str:
|
| 49 |
+
"""Send a request to the OpenAI chat completions endpoint."""
|
| 50 |
+
payload = json.dumps(
|
| 51 |
+
{
|
| 52 |
+
"model": LLM_MODEL,
|
| 53 |
+
"messages": messages,
|
| 54 |
+
"max_tokens": max_tokens,
|
| 55 |
+
"temperature": 0.1,
|
| 56 |
+
}
|
| 57 |
+
).encode()
|
| 58 |
+
|
| 59 |
+
req = urllib.request.Request(
|
| 60 |
+
"https://api.openai.com/v1/chat/completions",
|
| 61 |
+
data=payload,
|
| 62 |
+
headers={
|
| 63 |
+
"Content-Type": "application/json",
|
| 64 |
+
"Authorization": f"Bearer {OPENAI_API_KEY}",
|
| 65 |
+
},
|
| 66 |
+
)
|
| 67 |
+
try:
|
| 68 |
+
with urllib.request.urlopen(req, timeout=30) as resp:
|
| 69 |
+
data = json.loads(resp.read())
|
| 70 |
+
return data["choices"][0]["message"]["content"]
|
| 71 |
+
except urllib.error.HTTPError as exc:
|
| 72 |
+
body = exc.read().decode(errors="replace")
|
| 73 |
+
raise RuntimeError(f"OpenAI API error {exc.code}: {body}") from exc
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _parse_json_from_llm(raw: str) -> Dict[str, Any]:
|
| 77 |
+
"""Extract a JSON object from LLM output, stripping markdown fences if present."""
|
| 78 |
+
raw = raw.strip()
|
| 79 |
+
if raw.startswith("```"):
|
| 80 |
+
parts = raw.split("```")
|
| 81 |
+
# parts[1] is the fenced block; strip language tag if present
|
| 82 |
+
inner = parts[1]
|
| 83 |
+
if inner.lower().startswith("json"):
|
| 84 |
+
inner = inner[4:]
|
| 85 |
+
raw = inner.strip()
|
| 86 |
+
return json.loads(raw)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ─── LLM agent ───────────────────────────────────────────────────────────────────
|
| 90 |
+
|
| 91 |
+
def _llm_agent_run(task_id: str) -> float:
|
| 92 |
+
"""
|
| 93 |
+
Run an LLM agent on a single task via the HTTP API.
|
| 94 |
+
|
| 95 |
+
1. Reset the environment.
|
| 96 |
+
2. Show the agent the traffic logs and task description.
|
| 97 |
+
3. Ask it to produce a JSON action.
|
| 98 |
+
4. Submit the action and return the reward score.
|
| 99 |
+
"""
|
| 100 |
+
import urllib.request as urlreq
|
| 101 |
+
|
| 102 |
+
def _post(path: str, body: Any) -> Any:
|
| 103 |
+
data = json.dumps(body).encode()
|
| 104 |
+
req = urlreq.Request(
|
| 105 |
+
f"{ENV_BASE_URL}{path}",
|
| 106 |
+
data=data,
|
| 107 |
+
headers={"Content-Type": "application/json"},
|
| 108 |
+
)
|
| 109 |
+
with urlreq.urlopen(req, timeout=15) as resp:
|
| 110 |
+
return json.loads(resp.read())
|
| 111 |
+
|
| 112 |
+
# 1. Reset
|
| 113 |
+
obs = _post("/reset", {"task_id": task_id})
|
| 114 |
+
|
| 115 |
+
# 2. Build prompt (truncate request list to 25 to stay within token budget)
|
| 116 |
+
sample_requests = obs["recent_requests"][:25]
|
| 117 |
+
|
| 118 |
+
system_prompt = (
|
| 119 |
+
"You are a Site Reliability Engineer responding to a live production incident. "
|
| 120 |
+
"You will be shown HTTP traffic logs and a task description. "
|
| 121 |
+
"Your job is to write exactly ONE firewall rule as a JSON object. "
|
| 122 |
+
"Respond with ONLY valid JSON — no prose, no markdown fences."
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
action_schema = (
|
| 126 |
+
"{\n"
|
| 127 |
+
' "action_type": "block_ip" | "add_rate_limit" | "block_user_agent" | "write_custom_middleware",\n'
|
| 128 |
+
' "target_ip": "<string, required for block_ip / add_rate_limit>",\n'
|
| 129 |
+
' "target_user_agent": "<string, required for block_user_agent>",\n'
|
| 130 |
+
' "regex_pattern": "<Python regex, required for write_custom_middleware>",\n'
|
| 131 |
+
' "max_requests": <int, optional — requests/min cap for add_rate_limit>\n'
|
| 132 |
+
"}"
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
user_prompt = (
|
| 136 |
+
f"TASK: {obs['task_description']}\n\n"
|
| 137 |
+
f"HINT: {obs.get('hint', '')}\n\n"
|
| 138 |
+
f"TRAFFIC SAMPLE (first 25 requests):\n"
|
| 139 |
+
f"{json.dumps(sample_requests, indent=2)}\n\n"
|
| 140 |
+
f"Respond with ONE JSON action using this schema:\n{action_schema}"
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# 3. Call LLM
|
| 144 |
+
llm_response = _call_openai(
|
| 145 |
+
[
|
| 146 |
+
{"role": "system", "content": system_prompt},
|
| 147 |
+
{"role": "user", "content": user_prompt},
|
| 148 |
+
]
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# 4. Parse action
|
| 152 |
+
try:
|
| 153 |
+
action_dict = _parse_json_from_llm(llm_response)
|
| 154 |
+
except (json.JSONDecodeError, KeyError) as exc:
|
| 155 |
+
print(f" [!] Failed to parse LLM response: {exc}\n Raw: {llm_response[:200]}")
|
| 156 |
+
return 0.0
|
| 157 |
+
|
| 158 |
+
# 5. Step
|
| 159 |
+
result = _post("/step", action_dict)
|
| 160 |
+
score = result["reward"]["score"]
|
| 161 |
+
msg = result["reward"]["message"]
|
| 162 |
+
print(f" Action: {action_dict}")
|
| 163 |
+
print(f" Result: {msg}")
|
| 164 |
+
return score
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# ─── Main ────────────────────────────────────────────────────────────────────────
|
| 168 |
+
|
| 169 |
+
def run_baseline_direct() -> Dict[str, float]:
|
| 170 |
+
"""Run heuristic baseline directly on the Python class (no server needed)."""
|
| 171 |
+
return run_heuristic_baseline()
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def run_baseline_http() -> Dict[str, float]:
|
| 175 |
+
"""Run heuristic baseline via the HTTP API."""
|
| 176 |
+
import urllib.request as urlreq
|
| 177 |
+
|
| 178 |
+
req = urlreq.Request(
|
| 179 |
+
f"{ENV_BASE_URL}/baseline",
|
| 180 |
+
data=b"{}",
|
| 181 |
+
headers={"Content-Type": "application/json"},
|
| 182 |
+
method="POST",
|
| 183 |
+
)
|
| 184 |
+
with urlreq.urlopen(req, timeout=30) as resp:
|
| 185 |
+
data = json.loads(resp.read())
|
| 186 |
+
return data["scores"]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def main() -> None:
|
| 190 |
+
print("=" * 55)
|
| 191 |
+
print(" API Gateway Defender — Baseline Evaluation")
|
| 192 |
+
print("=" * 55)
|
| 193 |
+
print()
|
| 194 |
+
|
| 195 |
+
task_ids = ["easy", "medium", "hard"]
|
| 196 |
+
scores: Dict[str, float] = {}
|
| 197 |
+
|
| 198 |
+
if OPENAI_API_KEY:
|
| 199 |
+
print(f"Mode : LLM agent ({LLM_MODEL})")
|
| 200 |
+
print(f"URL : {ENV_BASE_URL}")
|
| 201 |
+
print()
|
| 202 |
+
for task_id in task_ids:
|
| 203 |
+
print(f"[Task: {task_id}]")
|
| 204 |
+
try:
|
| 205 |
+
score = _llm_agent_run(task_id)
|
| 206 |
+
scores[task_id] = score
|
| 207 |
+
print(f" Score: {score:.4f}")
|
| 208 |
+
except Exception as exc:
|
| 209 |
+
print(f" [!] Error: {exc}. Falling back to heuristic.")
|
| 210 |
+
if _DIRECT_IMPORT:
|
| 211 |
+
fb = run_heuristic_baseline()
|
| 212 |
+
scores[task_id] = fb.get(task_id, 0.0)
|
| 213 |
+
else:
|
| 214 |
+
scores[task_id] = 0.0
|
| 215 |
+
print()
|
| 216 |
+
else:
|
| 217 |
+
print("Mode : Heuristic agent (set OPENAI_API_KEY to use LLM)")
|
| 218 |
+
print()
|
| 219 |
+
if _DIRECT_IMPORT:
|
| 220 |
+
scores = run_baseline_direct()
|
| 221 |
+
else:
|
| 222 |
+
print(f"Calling {ENV_BASE_URL}/baseline ...")
|
| 223 |
+
scores = run_baseline_http()
|
| 224 |
+
for task_id in task_ids:
|
| 225 |
+
print(f" [{task_id}] score = {scores.get(task_id, 0.0):.4f}")
|
| 226 |
+
|
| 227 |
+
print()
|
| 228 |
+
print("-" * 35)
|
| 229 |
+
avg = sum(scores.values()) / max(len(scores), 1)
|
| 230 |
+
for task_id in task_ids:
|
| 231 |
+
s = scores.get(task_id, 0.0)
|
| 232 |
+
bar = "█" * int(s * 20)
|
| 233 |
+
print(f" {task_id:<8s} {s:.4f} {bar}")
|
| 234 |
+
print(f" {'average':<8s} {avg:.4f}")
|
| 235 |
+
print("-" * 35)
|
| 236 |
+
print()
|
| 237 |
+
|
| 238 |
+
# Exit non-zero if any task scored 0.0 (helps CI catch broken graders)
|
| 239 |
+
if any(v == 0.0 for v in scores.values()):
|
| 240 |
+
print("[WARN] One or more tasks scored 0.0. Check the environment.")
|
| 241 |
+
sys.exit(1)
|
| 242 |
+
else:
|
| 243 |
+
print("[OK] All tasks passed baseline threshold.")
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
if __name__ == "__main__":
|
| 247 |
+
main()
|
env.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API Gateway Defender — OpenEnv Environment
|
| 3 |
+
==========================================
|
| 4 |
+
A simulated HTTP traffic monitoring environment where an AI agent must
|
| 5 |
+
configure middleware rules to block malicious web traffic while preserving
|
| 6 |
+
legitimate user requests. Models a real production incident faced by SREs.
|
| 7 |
+
|
| 8 |
+
Tasks:
|
| 9 |
+
easy — Single IP volumetric flood on /login
|
| 10 |
+
medium — Scraper bot using identical User-Agent across many IPs
|
| 11 |
+
hard — SQL injection attack with rotating IPs and User-Agents
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import re
|
| 15 |
+
import random
|
| 16 |
+
from typing import Any, Dict, List, Optional
|
| 17 |
+
|
| 18 |
+
from pydantic import BaseModel, Field
|
| 19 |
+
|
| 20 |
+
# ─── Constants ──────────────────────────────────────────────────────────────────
|
| 21 |
+
|
| 22 |
+
MAX_STEPS = 5
|
| 23 |
+
TRAFFIC_SIZE = 1000
|
| 24 |
+
MALICIOUS_COUNT = 200
|
| 25 |
+
FALSE_POSITIVE_THRESHOLD = 0.10 # >10% FP rate → score zeroed
|
| 26 |
+
|
| 27 |
+
# Fixed attack artifacts (deterministic so graders are reproducible)
|
| 28 |
+
ATTACK_IP_EASY = "185.220.101.47"
|
| 29 |
+
SCRAPER_UA = "ScraperBot/3.1 (DataHarvester; +http://scrape.evil.io/bot)"
|
| 30 |
+
SQLI_KEYWORD = "UNION SELECT"
|
| 31 |
+
|
| 32 |
+
LEGIT_USER_AGENTS = [
|
| 33 |
+
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36",
|
| 34 |
+
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 Version/17.0 Safari/605.1.15",
|
| 35 |
+
"Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/115.0",
|
| 36 |
+
"Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 Mobile/15E148",
|
| 37 |
+
"Mozilla/5.0 (Android 13; Mobile; rv:109.0) Gecko/109.0 Firefox/115.0",
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
LEGIT_PATHS = [
|
| 41 |
+
"/", "/home", "/about", "/products", "/contact",
|
| 42 |
+
"/blog", "/faq", "/pricing", "/docs", "/api/users", "/api/products",
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
LEGIT_QUERY_STRINGS = [
|
| 46 |
+
"", "page=1", "q=hello", "id=42", "sort=asc&filter=new",
|
| 47 |
+
"utm_source=google", "lang=en", "category=tech", "ref=nav",
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
LEGIT_METHODS = ["GET", "GET", "GET", "POST", "PUT"] # GET-heavy, realistic
|
| 51 |
+
|
| 52 |
+
SQLI_PAYLOADS = [
|
| 53 |
+
f"id=1' {SQLI_KEYWORD} username,password FROM users--",
|
| 54 |
+
f"search=admin' {SQLI_KEYWORD} 1,2,3--",
|
| 55 |
+
f"q=x' {SQLI_KEYWORD} table_name FROM information_schema.tables--",
|
| 56 |
+
f"cat=1' {SQLI_KEYWORD} NULL,NULL,NULL--",
|
| 57 |
+
f"item=5' {SQLI_KEYWORD} version()--",
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
# ─── Pydantic Models ─────────────────────────────────────────────────────────────
|
| 61 |
+
|
| 62 |
+
class Action(BaseModel):
|
| 63 |
+
"""
|
| 64 |
+
An action the agent can take — adds one firewall rule to the gateway.
|
| 65 |
+
|
| 66 |
+
action_type choices:
|
| 67 |
+
block_ip — Drop all traffic from target_ip
|
| 68 |
+
add_rate_limit — Allow target_ip only max_requests/min
|
| 69 |
+
block_user_agent — Drop all traffic matching target_user_agent exactly
|
| 70 |
+
write_custom_middleware — Drop requests where regex_pattern matches path?query_string
|
| 71 |
+
"""
|
| 72 |
+
action_type: str = Field(
|
| 73 |
+
...,
|
| 74 |
+
description=(
|
| 75 |
+
"Rule type: 'block_ip', 'add_rate_limit', "
|
| 76 |
+
"'block_user_agent', 'write_custom_middleware'"
|
| 77 |
+
),
|
| 78 |
+
)
|
| 79 |
+
target_ip: Optional[str] = Field(
|
| 80 |
+
None, description="IP address (required for block_ip / add_rate_limit)"
|
| 81 |
+
)
|
| 82 |
+
target_user_agent: Optional[str] = Field(
|
| 83 |
+
None, description="Exact User-Agent string (required for block_user_agent)"
|
| 84 |
+
)
|
| 85 |
+
regex_pattern: Optional[str] = Field(
|
| 86 |
+
None,
|
| 87 |
+
description=(
|
| 88 |
+
"Python regex matched against '{path}?{query_string}' "
|
| 89 |
+
"(required for write_custom_middleware)"
|
| 90 |
+
),
|
| 91 |
+
)
|
| 92 |
+
max_requests: Optional[int] = Field(
|
| 93 |
+
60, description="Requests-per-minute cap for add_rate_limit (default 60)"
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class Observation(BaseModel):
|
| 98 |
+
"""What the agent sees at each step."""
|
| 99 |
+
recent_requests: List[Dict[str, Any]] = Field(
|
| 100 |
+
...,
|
| 101 |
+
description=(
|
| 102 |
+
"Last 100 HTTP requests in the traffic stream. "
|
| 103 |
+
"Fields: ip, method, path, user_agent, query_string, status_code."
|
| 104 |
+
),
|
| 105 |
+
)
|
| 106 |
+
active_rules: List[str] = Field(
|
| 107 |
+
..., description="Human-readable list of rules currently active on the gateway."
|
| 108 |
+
)
|
| 109 |
+
current_task: str = Field(..., description="Task ID: 'easy', 'medium', or 'hard'")
|
| 110 |
+
task_description: str = Field(
|
| 111 |
+
..., description="Natural language description of the attack the agent must repel."
|
| 112 |
+
)
|
| 113 |
+
step_count: int = Field(..., description="Number of rules submitted so far this episode.")
|
| 114 |
+
hint: str = Field("", description="Statistical hint derived from the visible traffic sample.")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Reward(BaseModel):
|
| 118 |
+
"""Feedback returned after each step()."""
|
| 119 |
+
score: float = Field(..., ge=0.0, le=1.0, description="Task performance score 0.0–1.0")
|
| 120 |
+
malicious_blocked: int = Field(..., description="Malicious requests blocked by active rules")
|
| 121 |
+
legitimate_blocked: int = Field(..., description="Legitimate requests incorrectly blocked")
|
| 122 |
+
total_malicious: int
|
| 123 |
+
total_legitimate: int
|
| 124 |
+
false_positive_rate: float = Field(..., description="Fraction of legit requests blocked")
|
| 125 |
+
message: str = Field(..., description="Human-readable explanation of the score")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class StepResult(BaseModel):
|
| 129 |
+
"""Full return value of step()."""
|
| 130 |
+
observation: Observation
|
| 131 |
+
reward: Reward
|
| 132 |
+
done: bool
|
| 133 |
+
info: Dict[str, Any]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class EnvironmentState(BaseModel):
|
| 137 |
+
"""Full serialisable snapshot returned by state()."""
|
| 138 |
+
task_id: str
|
| 139 |
+
step_count: int
|
| 140 |
+
active_rules: List[Dict[str, Any]]
|
| 141 |
+
episode_done: bool
|
| 142 |
+
best_score: float
|
| 143 |
+
traffic_sample_size: int
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ─── Traffic Generators ──────────────────────────────────────────────────────────
|
| 147 |
+
|
| 148 |
+
def _rand_ip(rng: random.Random, exclude: str = "") -> str:
|
| 149 |
+
"""Generate a random public-looking IPv4 address."""
|
| 150 |
+
while True:
|
| 151 |
+
ip = (
|
| 152 |
+
f"{rng.randint(10, 220)}."
|
| 153 |
+
f"{rng.randint(1, 254)}."
|
| 154 |
+
f"{rng.randint(1, 254)}."
|
| 155 |
+
f"{rng.randint(1, 254)}"
|
| 156 |
+
)
|
| 157 |
+
if ip != exclude:
|
| 158 |
+
return ip
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _legit_request(rng: random.Random) -> Dict[str, Any]:
|
| 162 |
+
return {
|
| 163 |
+
"ip": _rand_ip(rng, exclude=ATTACK_IP_EASY),
|
| 164 |
+
"method": rng.choice(LEGIT_METHODS),
|
| 165 |
+
"path": rng.choice(LEGIT_PATHS),
|
| 166 |
+
"user_agent": rng.choice(LEGIT_USER_AGENTS),
|
| 167 |
+
"query_string": rng.choice(LEGIT_QUERY_STRINGS),
|
| 168 |
+
"status_code": 200,
|
| 169 |
+
"is_malicious": False,
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def generate_easy_traffic(seed: int) -> List[Dict[str, Any]]:
|
| 174 |
+
"""
|
| 175 |
+
Easy: one IP floods /login with POST requests.
|
| 176 |
+
Correct action: block_ip or add_rate_limit on ATTACK_IP_EASY.
|
| 177 |
+
"""
|
| 178 |
+
rng = random.Random(seed)
|
| 179 |
+
traffic: List[Dict[str, Any]] = []
|
| 180 |
+
|
| 181 |
+
for _ in range(MALICIOUS_COUNT):
|
| 182 |
+
traffic.append({
|
| 183 |
+
"ip": ATTACK_IP_EASY,
|
| 184 |
+
"method": "POST",
|
| 185 |
+
"path": "/login",
|
| 186 |
+
"user_agent": rng.choice(LEGIT_USER_AGENTS), # UA blends in
|
| 187 |
+
"query_string": "",
|
| 188 |
+
"status_code": 200,
|
| 189 |
+
"is_malicious": True,
|
| 190 |
+
})
|
| 191 |
+
|
| 192 |
+
for _ in range(TRAFFIC_SIZE - MALICIOUS_COUNT):
|
| 193 |
+
traffic.append(_legit_request(rng))
|
| 194 |
+
|
| 195 |
+
rng.shuffle(traffic)
|
| 196 |
+
return traffic
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def generate_medium_traffic(seed: int) -> List[Dict[str, Any]]:
|
| 200 |
+
"""
|
| 201 |
+
Medium: 50 IPs all share an identical unusual User-Agent, hitting /api/data.
|
| 202 |
+
Correct action: block_user_agent with SCRAPER_UA.
|
| 203 |
+
"""
|
| 204 |
+
rng = random.Random(seed)
|
| 205 |
+
traffic: List[Dict[str, Any]] = []
|
| 206 |
+
|
| 207 |
+
scraper_ips = [_rand_ip(rng) for _ in range(50)]
|
| 208 |
+
for _ in range(MALICIOUS_COUNT):
|
| 209 |
+
traffic.append({
|
| 210 |
+
"ip": rng.choice(scraper_ips),
|
| 211 |
+
"method": "GET",
|
| 212 |
+
"path": "/api/data",
|
| 213 |
+
"user_agent": SCRAPER_UA, # constant across all malicious requests
|
| 214 |
+
"query_string": f"page={rng.randint(1, 500)}",
|
| 215 |
+
"status_code": 200,
|
| 216 |
+
"is_malicious": True,
|
| 217 |
+
})
|
| 218 |
+
|
| 219 |
+
for _ in range(TRAFFIC_SIZE - MALICIOUS_COUNT):
|
| 220 |
+
traffic.append(_legit_request(rng))
|
| 221 |
+
|
| 222 |
+
rng.shuffle(traffic)
|
| 223 |
+
return traffic
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def generate_hard_traffic(seed: int) -> List[Dict[str, Any]]:
|
| 227 |
+
"""
|
| 228 |
+
Hard: attacker rotates IPs and UAs but always carries a SQLi payload.
|
| 229 |
+
Correct action: write_custom_middleware with regex matching 'UNION.SELECT'.
|
| 230 |
+
"""
|
| 231 |
+
rng = random.Random(seed)
|
| 232 |
+
traffic: List[Dict[str, Any]] = []
|
| 233 |
+
|
| 234 |
+
for _ in range(MALICIOUS_COUNT):
|
| 235 |
+
traffic.append({
|
| 236 |
+
"ip": _rand_ip(rng),
|
| 237 |
+
"method": "GET",
|
| 238 |
+
"path": rng.choice(["/search", "/products", "/api/items", "/catalog"]),
|
| 239 |
+
"user_agent": rng.choice(LEGIT_USER_AGENTS),
|
| 240 |
+
"query_string": rng.choice(SQLI_PAYLOADS),
|
| 241 |
+
"status_code": 200,
|
| 242 |
+
"is_malicious": True,
|
| 243 |
+
})
|
| 244 |
+
|
| 245 |
+
for _ in range(TRAFFIC_SIZE - MALICIOUS_COUNT):
|
| 246 |
+
req = _legit_request(rng)
|
| 247 |
+
# Guarantee legit requests never accidentally contain the payload
|
| 248 |
+
if SQLI_KEYWORD in req["query_string"].upper():
|
| 249 |
+
req["query_string"] = ""
|
| 250 |
+
traffic.append(req)
|
| 251 |
+
|
| 252 |
+
rng.shuffle(traffic)
|
| 253 |
+
return traffic
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
TASK_GENERATORS = {
|
| 257 |
+
"easy": generate_easy_traffic,
|
| 258 |
+
"medium": generate_medium_traffic,
|
| 259 |
+
"hard": generate_hard_traffic,
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
TASK_DESCRIPTIONS = {
|
| 263 |
+
"easy": (
|
| 264 |
+
"A single IP address is flooding your /login endpoint with POST requests at high volume. "
|
| 265 |
+
"Inspect the traffic logs to identify the offending IP and block it or apply a rate limit."
|
| 266 |
+
),
|
| 267 |
+
"medium": (
|
| 268 |
+
"A scraper bot is harvesting your /api/data endpoint from many different IP addresses. "
|
| 269 |
+
"All malicious requests share a single, unusual User-Agent string. "
|
| 270 |
+
"Identify the User-Agent and block it."
|
| 271 |
+
),
|
| 272 |
+
"hard": (
|
| 273 |
+
"An attacker is probing your database via SQL injection. They rotate IP addresses and "
|
| 274 |
+
"User-Agents to evade simple rules, but every malicious request contains a SQL injection "
|
| 275 |
+
"payload in the query string. Write a regex middleware rule to detect and drop these requests."
|
| 276 |
+
),
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# ─── Rule Engine ─────────────────────────────────────────────────────────────────
|
| 281 |
+
|
| 282 |
+
class _Rule:
|
| 283 |
+
"""Internal class: wraps an Action and applies it to individual requests."""
|
| 284 |
+
|
| 285 |
+
def __init__(self, action: Action) -> None:
|
| 286 |
+
self.action = action
|
| 287 |
+
self._compiled_re = None
|
| 288 |
+
if action.action_type == "write_custom_middleware" and action.regex_pattern:
|
| 289 |
+
try:
|
| 290 |
+
self._compiled_re = re.compile(action.regex_pattern, re.IGNORECASE)
|
| 291 |
+
except re.error:
|
| 292 |
+
pass # invalid regex → rule matches nothing
|
| 293 |
+
|
| 294 |
+
def blocks(self, request: Dict[str, Any]) -> bool:
|
| 295 |
+
a = self.action
|
| 296 |
+
if a.action_type in ("block_ip", "add_rate_limit"):
|
| 297 |
+
return bool(a.target_ip and request["ip"] == a.target_ip)
|
| 298 |
+
if a.action_type == "block_user_agent":
|
| 299 |
+
return bool(
|
| 300 |
+
a.target_user_agent
|
| 301 |
+
and request["user_agent"] == a.target_user_agent
|
| 302 |
+
)
|
| 303 |
+
if a.action_type == "write_custom_middleware" and self._compiled_re:
|
| 304 |
+
target = f"{request['path']}?{request['query_string']}"
|
| 305 |
+
return bool(self._compiled_re.search(target))
|
| 306 |
+
return False
|
| 307 |
+
|
| 308 |
+
def describe(self) -> str:
|
| 309 |
+
a = self.action
|
| 310 |
+
if a.action_type == "block_ip":
|
| 311 |
+
return f"BLOCK_IP({a.target_ip})"
|
| 312 |
+
if a.action_type == "add_rate_limit":
|
| 313 |
+
return f"RATE_LIMIT({a.target_ip}, max={a.max_requests}/min)"
|
| 314 |
+
if a.action_type == "block_user_agent":
|
| 315 |
+
return f"BLOCK_UA({a.target_user_agent!r})"
|
| 316 |
+
if a.action_type == "write_custom_middleware":
|
| 317 |
+
return f"MIDDLEWARE(regex={a.regex_pattern!r})"
|
| 318 |
+
return f"RULE({a.action_type})"
|
| 319 |
+
|
| 320 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 321 |
+
a = self.action
|
| 322 |
+
return {
|
| 323 |
+
"action_type": a.action_type,
|
| 324 |
+
"target_ip": a.target_ip,
|
| 325 |
+
"target_user_agent": a.target_user_agent,
|
| 326 |
+
"regex_pattern": a.regex_pattern,
|
| 327 |
+
"description": self.describe(),
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# ─── Environment ─────────────────────────────────────────────────────────────────
|
| 332 |
+
|
| 333 |
+
VALID_ACTION_TYPES = {"block_ip", "add_rate_limit", "block_user_agent", "write_custom_middleware"}
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class APIGatewayDefender:
|
| 337 |
+
"""
|
| 338 |
+
OpenEnv-compliant RL environment — API Gateway Defender.
|
| 339 |
+
|
| 340 |
+
The agent monitors a simulated stream of HTTP requests and must apply
|
| 341 |
+
firewall middleware rules to block malicious traffic while preserving
|
| 342 |
+
legitimate requests.
|
| 343 |
+
|
| 344 |
+
Usage
|
| 345 |
+
-----
|
| 346 |
+
env = APIGatewayDefender()
|
| 347 |
+
obs = env.reset(task_id="easy")
|
| 348 |
+
action = Action(action_type="block_ip", target_ip="185.220.101.47")
|
| 349 |
+
result = env.step(action)
|
| 350 |
+
print(result.reward.score)
|
| 351 |
+
"""
|
| 352 |
+
|
| 353 |
+
def __init__(self) -> None:
|
| 354 |
+
self._task_id: str = "easy"
|
| 355 |
+
self._rules: List[_Rule] = []
|
| 356 |
+
self._train_traffic: List[Dict[str, Any]] = []
|
| 357 |
+
self._test_traffic: List[Dict[str, Any]] = []
|
| 358 |
+
self._step_count: int = 0
|
| 359 |
+
self._done: bool = False
|
| 360 |
+
self._best_score: float = 0.0
|
| 361 |
+
|
| 362 |
+
# ── OpenEnv Interface ──────────────────────────────────────────────────────
|
| 363 |
+
|
| 364 |
+
def reset(self, task_id: str = "easy") -> Observation:
|
| 365 |
+
"""
|
| 366 |
+
Start a new episode on the given task.
|
| 367 |
+
|
| 368 |
+
Parameters
|
| 369 |
+
----------
|
| 370 |
+
task_id : str
|
| 371 |
+
One of 'easy', 'medium', 'hard'.
|
| 372 |
+
|
| 373 |
+
Returns
|
| 374 |
+
-------
|
| 375 |
+
Observation
|
| 376 |
+
Initial observation containing the first 100 traffic samples.
|
| 377 |
+
"""
|
| 378 |
+
if task_id not in TASK_GENERATORS:
|
| 379 |
+
raise ValueError(
|
| 380 |
+
f"Unknown task_id '{task_id}'. Choose from: {sorted(TASK_GENERATORS)}"
|
| 381 |
+
)
|
| 382 |
+
self._task_id = task_id
|
| 383 |
+
self._rules = []
|
| 384 |
+
self._step_count = 0
|
| 385 |
+
self._done = False
|
| 386 |
+
self._best_score = 0.0
|
| 387 |
+
|
| 388 |
+
gen = TASK_GENERATORS[task_id]
|
| 389 |
+
self._train_traffic = gen(seed=42) # agent can see this
|
| 390 |
+
self._test_traffic = gen(seed=137) # grading set (hidden from agent)
|
| 391 |
+
|
| 392 |
+
return self._make_observation()
|
| 393 |
+
|
| 394 |
+
def step(self, action: Action) -> StepResult:
|
| 395 |
+
"""
|
| 396 |
+
Submit one firewall rule and receive a reward signal.
|
| 397 |
+
|
| 398 |
+
The rule is evaluated against a hidden test traffic set to prevent
|
| 399 |
+
overfitting to the visible sample. Partial credit is awarded for
|
| 400 |
+
partial detection; false positives incur a penalty.
|
| 401 |
+
|
| 402 |
+
Parameters
|
| 403 |
+
----------
|
| 404 |
+
action : Action
|
| 405 |
+
The rule to apply.
|
| 406 |
+
|
| 407 |
+
Returns
|
| 408 |
+
-------
|
| 409 |
+
StepResult
|
| 410 |
+
observation, reward, done flag, and diagnostic info.
|
| 411 |
+
"""
|
| 412 |
+
if self._done:
|
| 413 |
+
raise RuntimeError("Episode is finished. Call reset() to start a new episode.")
|
| 414 |
+
|
| 415 |
+
self._step_count += 1
|
| 416 |
+
|
| 417 |
+
# ── Validate action type ──────────────────────────────────────────────
|
| 418 |
+
if action.action_type not in VALID_ACTION_TYPES:
|
| 419 |
+
err_reward = Reward(
|
| 420 |
+
score=0.0,
|
| 421 |
+
malicious_blocked=0,
|
| 422 |
+
legitimate_blocked=0,
|
| 423 |
+
total_malicious=MALICIOUS_COUNT,
|
| 424 |
+
total_legitimate=TRAFFIC_SIZE - MALICIOUS_COUNT,
|
| 425 |
+
false_positive_rate=0.0,
|
| 426 |
+
message=(
|
| 427 |
+
f"Invalid action_type '{action.action_type}'. "
|
| 428 |
+
f"Must be one of {sorted(VALID_ACTION_TYPES)}."
|
| 429 |
+
),
|
| 430 |
+
)
|
| 431 |
+
return StepResult(
|
| 432 |
+
observation=self._make_observation(),
|
| 433 |
+
reward=err_reward,
|
| 434 |
+
done=False,
|
| 435 |
+
info={"error": "invalid_action_type"},
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
# ── Apply rule ────────────────────────────────────────────────────────
|
| 439 |
+
self._rules.append(_Rule(action))
|
| 440 |
+
|
| 441 |
+
# ── Grade on hidden test traffic ──────────────────────────────────────
|
| 442 |
+
reward = self._grade()
|
| 443 |
+
self._best_score = max(self._best_score, reward.score)
|
| 444 |
+
|
| 445 |
+
# Episode ends at MAX_STEPS or when the agent achieves near-perfect score
|
| 446 |
+
self._done = self._step_count >= MAX_STEPS or reward.score >= 0.95
|
| 447 |
+
|
| 448 |
+
return StepResult(
|
| 449 |
+
observation=self._make_observation(),
|
| 450 |
+
reward=reward,
|
| 451 |
+
done=self._done,
|
| 452 |
+
info={
|
| 453 |
+
"step": self._step_count,
|
| 454 |
+
"best_score": self._best_score,
|
| 455 |
+
"rules_applied": [r.describe() for r in self._rules],
|
| 456 |
+
"max_steps": MAX_STEPS,
|
| 457 |
+
},
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
def state(self) -> EnvironmentState:
|
| 461 |
+
"""Return a full serialisable snapshot of the current environment state."""
|
| 462 |
+
return EnvironmentState(
|
| 463 |
+
task_id=self._task_id,
|
| 464 |
+
step_count=self._step_count,
|
| 465 |
+
active_rules=[r.to_dict() for r in self._rules],
|
| 466 |
+
episode_done=self._done,
|
| 467 |
+
best_score=self._best_score,
|
| 468 |
+
traffic_sample_size=len(self._train_traffic),
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
def get_task_grader_score(self) -> float:
|
| 472 |
+
"""
|
| 473 |
+
Programmatic grader — returns score 0.0–1.0 for the current episode.
|
| 474 |
+
Returns 0.0 if no rules have been applied yet.
|
| 475 |
+
"""
|
| 476 |
+
if not self._rules:
|
| 477 |
+
return 0.0
|
| 478 |
+
return self._grade().score
|
| 479 |
+
|
| 480 |
+
# ── Private Helpers ────────────────────────────────────────────────────────
|
| 481 |
+
|
| 482 |
+
def _make_observation(self) -> Observation:
|
| 483 |
+
"""Build an Observation from the current state (no is_malicious flag exposed)."""
|
| 484 |
+
visible = [
|
| 485 |
+
{k: v for k, v in req.items() if k != "is_malicious"}
|
| 486 |
+
for req in self._train_traffic[:100]
|
| 487 |
+
]
|
| 488 |
+
return Observation(
|
| 489 |
+
recent_requests=visible,
|
| 490 |
+
active_rules=[r.describe() for r in self._rules],
|
| 491 |
+
current_task=self._task_id,
|
| 492 |
+
task_description=TASK_DESCRIPTIONS[self._task_id],
|
| 493 |
+
step_count=self._step_count,
|
| 494 |
+
hint=self._build_hint(),
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
def _build_hint(self) -> str:
|
| 498 |
+
"""Generate a statistical hint from the visible traffic sample."""
|
| 499 |
+
if not self._train_traffic:
|
| 500 |
+
return ""
|
| 501 |
+
sample = self._train_traffic[:100]
|
| 502 |
+
malicious_in_sample = [r for r in sample if r.get("is_malicious")]
|
| 503 |
+
n = len(malicious_in_sample)
|
| 504 |
+
|
| 505 |
+
if self._task_id == "easy":
|
| 506 |
+
if n == 0:
|
| 507 |
+
return "Traffic looks normal in this window."
|
| 508 |
+
ips = {r["ip"] for r in malicious_in_sample}
|
| 509 |
+
return (
|
| 510 |
+
f"Warning: {n} POST requests to /login detected in this window "
|
| 511 |
+
f"from {len(ips)} unique IP(s). Possible brute-force or flood."
|
| 512 |
+
)
|
| 513 |
+
elif self._task_id == "medium":
|
| 514 |
+
if n == 0:
|
| 515 |
+
return "Traffic looks normal in this window."
|
| 516 |
+
uas = {r["user_agent"] for r in malicious_in_sample}
|
| 517 |
+
return (
|
| 518 |
+
f"Warning: {n} requests to /api/data share {len(uas)} unique User-Agent(s) "
|
| 519 |
+
f"in this window. Possible scraper activity."
|
| 520 |
+
)
|
| 521 |
+
else:
|
| 522 |
+
if n == 0:
|
| 523 |
+
return "Traffic looks normal in this window."
|
| 524 |
+
return (
|
| 525 |
+
f"Warning: {n} requests in this window contain unusual query string patterns. "
|
| 526 |
+
f"Check for injection payloads."
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
def _grade(self) -> Reward:
|
| 530 |
+
"""
|
| 531 |
+
Apply all active rules to the hidden test traffic set and compute a score.
|
| 532 |
+
|
| 533 |
+
Score formula:
|
| 534 |
+
detection_rate = malicious_blocked / total_malicious
|
| 535 |
+
fp_rate = legitimate_blocked / total_legitimate
|
| 536 |
+
if fp_rate > FALSE_POSITIVE_THRESHOLD:
|
| 537 |
+
score = 0.0 ← too many false positives
|
| 538 |
+
else:
|
| 539 |
+
score = clamp(detection_rate - fp_rate * 5.0, 0.0, 1.0)
|
| 540 |
+
"""
|
| 541 |
+
malicious = [r for r in self._test_traffic if r["is_malicious"]]
|
| 542 |
+
legit = [r for r in self._test_traffic if not r["is_malicious"]]
|
| 543 |
+
|
| 544 |
+
mal_blocked = sum(1 for r in malicious if any(rule.blocks(r) for rule in self._rules))
|
| 545 |
+
legit_blocked = sum(1 for r in legit if any(rule.blocks(r) for rule in self._rules))
|
| 546 |
+
|
| 547 |
+
total_mal = len(malicious)
|
| 548 |
+
total_legit = len(legit)
|
| 549 |
+
|
| 550 |
+
detection_rate = mal_blocked / total_mal if total_mal > 0 else 0.0
|
| 551 |
+
fp_rate = legit_blocked / total_legit if total_legit > 0 else 0.0
|
| 552 |
+
|
| 553 |
+
if fp_rate > FALSE_POSITIVE_THRESHOLD:
|
| 554 |
+
score = 0.0
|
| 555 |
+
message = (
|
| 556 |
+
f"Score zeroed: {fp_rate:.1%} false positive rate exceeds "
|
| 557 |
+
f"{FALSE_POSITIVE_THRESHOLD:.0%} threshold. Rules are too broad — "
|
| 558 |
+
f"legitimate users are being blocked."
|
| 559 |
+
)
|
| 560 |
+
else:
|
| 561 |
+
score = max(0.0, min(1.0, detection_rate - fp_rate * 5.0))
|
| 562 |
+
message = (
|
| 563 |
+
f"Blocked {mal_blocked}/{total_mal} malicious requests "
|
| 564 |
+
f"({detection_rate:.1%} detection rate) with "
|
| 565 |
+
f"{fp_rate:.1%} false positive rate."
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
return Reward(
|
| 569 |
+
score=round(score, 4),
|
| 570 |
+
malicious_blocked=mal_blocked,
|
| 571 |
+
legitimate_blocked=legit_blocked,
|
| 572 |
+
total_malicious=total_mal,
|
| 573 |
+
total_legitimate=total_legit,
|
| 574 |
+
false_positive_rate=round(fp_rate, 4),
|
| 575 |
+
message=message,
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
# ─── Convenience: heuristic baseline that runs directly on the class ────────────
|
| 580 |
+
|
| 581 |
+
def run_heuristic_baseline() -> Dict[str, float]:
|
| 582 |
+
"""
|
| 583 |
+
A deterministic heuristic agent that solves all 3 tasks correctly.
|
| 584 |
+
Used by the /baseline endpoint and as fallback in the inference script.
|
| 585 |
+
|
| 586 |
+
Returns
|
| 587 |
+
-------
|
| 588 |
+
Dict[str, float]
|
| 589 |
+
task_id → score
|
| 590 |
+
"""
|
| 591 |
+
env = APIGatewayDefender()
|
| 592 |
+
scores: Dict[str, float] = {}
|
| 593 |
+
|
| 594 |
+
# ── Easy: identify the IP flooding /login ──────────────────────────────────
|
| 595 |
+
obs = env.reset("easy")
|
| 596 |
+
ip_counts: Dict[str, int] = {}
|
| 597 |
+
for req in obs.recent_requests:
|
| 598 |
+
if req["path"] == "/login" and req["method"] == "POST":
|
| 599 |
+
ip_counts[req["ip"]] = ip_counts.get(req["ip"], 0) + 1
|
| 600 |
+
suspect_ip = (
|
| 601 |
+
max(ip_counts, key=lambda k: ip_counts[k]) if ip_counts else ATTACK_IP_EASY
|
| 602 |
+
)
|
| 603 |
+
result = env.step(Action(action_type="block_ip", target_ip=suspect_ip))
|
| 604 |
+
scores["easy"] = result.reward.score
|
| 605 |
+
|
| 606 |
+
# ── Medium: identify the unusual User-Agent ────────────────────────────────
|
| 607 |
+
obs = env.reset("medium")
|
| 608 |
+
ua_counts: Dict[str, int] = {}
|
| 609 |
+
for req in obs.recent_requests:
|
| 610 |
+
ua_counts[req["user_agent"]] = ua_counts.get(req["user_agent"], 0) + 1
|
| 611 |
+
|
| 612 |
+
bot_keywords = {"scraper", "bot", "crawler", "spider", "harvester"}
|
| 613 |
+
browser_keywords = {"mozilla", "chrome", "safari", "firefox", "gecko", "webkit"}
|
| 614 |
+
suspect_ua = None
|
| 615 |
+
|
| 616 |
+
# Prefer UAs that look like bots
|
| 617 |
+
for ua, _ in sorted(ua_counts.items(), key=lambda x: -x[1]):
|
| 618 |
+
if any(k in ua.lower() for k in bot_keywords):
|
| 619 |
+
suspect_ua = ua
|
| 620 |
+
break
|
| 621 |
+
# Fallback: most common UA that doesn't look like a browser
|
| 622 |
+
if not suspect_ua:
|
| 623 |
+
for ua, _ in sorted(ua_counts.items(), key=lambda x: -x[1]):
|
| 624 |
+
if not any(k in ua.lower() for k in browser_keywords):
|
| 625 |
+
suspect_ua = ua
|
| 626 |
+
break
|
| 627 |
+
|
| 628 |
+
result = env.step(Action(action_type="block_user_agent", target_user_agent=suspect_ua or ""))
|
| 629 |
+
scores["medium"] = result.reward.score
|
| 630 |
+
|
| 631 |
+
# ── Hard: write a regex to catch SQLi payloads ────────────────────────────
|
| 632 |
+
env.reset("hard")
|
| 633 |
+
result = env.step(
|
| 634 |
+
Action(
|
| 635 |
+
action_type="write_custom_middleware",
|
| 636 |
+
regex_pattern=r"UNION\s+SELECT",
|
| 637 |
+
)
|
| 638 |
+
)
|
| 639 |
+
scores["hard"] = result.reward.score
|
| 640 |
+
|
| 641 |
+
return scores
|
main.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API Gateway Defender — FastAPI Server
|
| 3 |
+
=====================================
|
| 4 |
+
Exposes the OpenEnv-compliant HTTP API for the environment.
|
| 5 |
+
|
| 6 |
+
Endpoints
|
| 7 |
+
---------
|
| 8 |
+
POST /reset — Start a new episode
|
| 9 |
+
POST /step — Submit a firewall rule, receive reward
|
| 10 |
+
GET /state — Inspect current environment state
|
| 11 |
+
GET /tasks — List tasks and action schema
|
| 12 |
+
GET /grader — Get grader score for current episode
|
| 13 |
+
POST /baseline — Run heuristic baseline across all 3 tasks
|
| 14 |
+
GET /health — Liveness probe (required for HF Spaces ping)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from fastapi import FastAPI, HTTPException
|
| 18 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 19 |
+
from typing import Any, Dict
|
| 20 |
+
|
| 21 |
+
from env import (
|
| 22 |
+
Action,
|
| 23 |
+
APIGatewayDefender,
|
| 24 |
+
Observation,
|
| 25 |
+
TASK_DESCRIPTIONS,
|
| 26 |
+
run_heuristic_baseline,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# ─── App setup ───────────────────────────────────────────────────────────────────
|
| 30 |
+
|
| 31 |
+
app = FastAPI(
|
| 32 |
+
title="API Gateway Defender",
|
| 33 |
+
description=(
|
| 34 |
+
"An OpenEnv RL environment where an AI agent defends a simulated web backend "
|
| 35 |
+
"against volumetric floods, scraper bots, and SQL injection attacks."
|
| 36 |
+
),
|
| 37 |
+
version="1.0.0",
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
app.add_middleware(
|
| 41 |
+
CORSMiddleware,
|
| 42 |
+
allow_origins=["*"],
|
| 43 |
+
allow_methods=["*"],
|
| 44 |
+
allow_headers=["*"],
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# Single shared environment instance (stateful, per-session)
|
| 48 |
+
_env = APIGatewayDefender()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ─── Routes ──────────────────────────────────────────────────────────────────────
|
| 52 |
+
|
| 53 |
+
@app.get("/health")
|
| 54 |
+
def health() -> Dict[str, str]:
|
| 55 |
+
"""Liveness probe — returns 200 and confirms the environment is running."""
|
| 56 |
+
return {"status": "ok", "environment": "api-gateway-defender"}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@app.get("/")
|
| 60 |
+
def root() -> Dict[str, Any]:
|
| 61 |
+
"""Overview of the environment and available endpoints."""
|
| 62 |
+
return {
|
| 63 |
+
"name": "API Gateway Defender",
|
| 64 |
+
"description": (
|
| 65 |
+
"OpenEnv RL environment: configure firewall rules to block malicious "
|
| 66 |
+
"HTTP traffic while preserving legitimate requests."
|
| 67 |
+
),
|
| 68 |
+
"version": "1.0.0",
|
| 69 |
+
"tasks": list(TASK_DESCRIPTIONS.keys()),
|
| 70 |
+
"endpoints": {
|
| 71 |
+
"POST /reset": "Start a new episode. Body: {task_id: 'easy'|'medium'|'hard'}",
|
| 72 |
+
"POST /step": "Submit a firewall rule. Body: Action schema.",
|
| 73 |
+
"GET /state": "Current environment state snapshot.",
|
| 74 |
+
"GET /tasks": "Task descriptions + action/observation schemas.",
|
| 75 |
+
"GET /grader": "Current grader score for the active episode.",
|
| 76 |
+
"POST /baseline": "Run heuristic baseline agent across all 3 tasks.",
|
| 77 |
+
"GET /health": "Liveness probe.",
|
| 78 |
+
},
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@app.post("/reset")
|
| 83 |
+
def reset(body: Dict[str, str] = None) -> Dict[str, Any]:
|
| 84 |
+
"""
|
| 85 |
+
Start a new episode.
|
| 86 |
+
|
| 87 |
+
Request body (JSON):
|
| 88 |
+
{"task_id": "easy" | "medium" | "hard"}
|
| 89 |
+
|
| 90 |
+
Returns the initial Observation.
|
| 91 |
+
"""
|
| 92 |
+
task_id = (body or {}).get("task_id", "easy")
|
| 93 |
+
try:
|
| 94 |
+
obs: Observation = _env.reset(task_id=task_id)
|
| 95 |
+
except ValueError as exc:
|
| 96 |
+
raise HTTPException(status_code=422, detail=str(exc))
|
| 97 |
+
return obs.model_dump()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@app.post("/step")
|
| 101 |
+
def step(action: Action) -> Dict[str, Any]:
|
| 102 |
+
"""
|
| 103 |
+
Submit one firewall rule.
|
| 104 |
+
|
| 105 |
+
Returns StepResult: {observation, reward, done, info}
|
| 106 |
+
|
| 107 |
+
Reward score: 0.0–1.0
|
| 108 |
+
= detection_rate − (false_positive_rate × 5)
|
| 109 |
+
= 0.0 if false positive rate > 10%
|
| 110 |
+
"""
|
| 111 |
+
try:
|
| 112 |
+
result = _env.step(action)
|
| 113 |
+
except RuntimeError as exc:
|
| 114 |
+
raise HTTPException(status_code=400, detail=str(exc))
|
| 115 |
+
return result.model_dump()
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@app.get("/state")
|
| 119 |
+
def state() -> Dict[str, Any]:
|
| 120 |
+
"""Return the full serialisable state of the current episode."""
|
| 121 |
+
return _env.state().model_dump()
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@app.get("/tasks")
|
| 125 |
+
def tasks() -> Dict[str, Any]:
|
| 126 |
+
"""
|
| 127 |
+
List all tasks and their descriptions, plus the action and observation schemas.
|
| 128 |
+
Required by the OpenEnv spec.
|
| 129 |
+
"""
|
| 130 |
+
return {
|
| 131 |
+
"tasks": [
|
| 132 |
+
{
|
| 133 |
+
"id": "easy",
|
| 134 |
+
"name": "Volumetric IP Flood Defense",
|
| 135 |
+
"difficulty": "easy",
|
| 136 |
+
"description": TASK_DESCRIPTIONS["easy"],
|
| 137 |
+
"hint": "One IP is responsible for all malicious traffic.",
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"id": "medium",
|
| 141 |
+
"name": "Scraper Bot Detection",
|
| 142 |
+
"difficulty": "medium",
|
| 143 |
+
"description": TASK_DESCRIPTIONS["medium"],
|
| 144 |
+
"hint": "Many IPs, but a single shared User-Agent string.",
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"id": "hard",
|
| 148 |
+
"name": "SQL Injection Middleware Defense",
|
| 149 |
+
"difficulty": "hard",
|
| 150 |
+
"description": TASK_DESCRIPTIONS["hard"],
|
| 151 |
+
"hint": "Rotating IPs and UAs, but a consistent payload pattern.",
|
| 152 |
+
},
|
| 153 |
+
],
|
| 154 |
+
"action_schema": Action.model_json_schema(),
|
| 155 |
+
"observation_schema": {
|
| 156 |
+
"recent_requests": "list[dict] — last 100 requests: ip, method, path, user_agent, query_string, status_code",
|
| 157 |
+
"active_rules": "list[str] — human-readable active firewall rules",
|
| 158 |
+
"current_task": "str — 'easy', 'medium', or 'hard'",
|
| 159 |
+
"task_description":"str — natural language goal",
|
| 160 |
+
"step_count": "int — steps taken this episode",
|
| 161 |
+
"hint": "str — statistical hint from visible traffic",
|
| 162 |
+
},
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@app.get("/grader")
|
| 167 |
+
def grader() -> Dict[str, Any]:
|
| 168 |
+
"""
|
| 169 |
+
Return the programmatic grader score for the current episode.
|
| 170 |
+
Score is 0.0–1.0; reflects detection rate minus false-positive penalty.
|
| 171 |
+
"""
|
| 172 |
+
score = _env.get_task_grader_score()
|
| 173 |
+
state_info = _env.state()
|
| 174 |
+
return {
|
| 175 |
+
"task_id": state_info.task_id,
|
| 176 |
+
"score": score,
|
| 177 |
+
"best_score": state_info.best_score,
|
| 178 |
+
"rules_applied":[r["description"] for r in state_info.active_rules],
|
| 179 |
+
"episode_done": state_info.episode_done,
|
| 180 |
+
"max_steps": 5,
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@app.post("/baseline")
|
| 185 |
+
def baseline() -> Dict[str, Any]:
|
| 186 |
+
"""
|
| 187 |
+
Run the heuristic baseline agent across all 3 tasks and return scores.
|
| 188 |
+
Does not affect the shared episode state.
|
| 189 |
+
"""
|
| 190 |
+
scores = run_heuristic_baseline()
|
| 191 |
+
avg = sum(scores.values()) / len(scores)
|
| 192 |
+
return {
|
| 193 |
+
"scores": scores,
|
| 194 |
+
"average": round(avg, 4),
|
| 195 |
+
"message": (
|
| 196 |
+
"Heuristic baseline: reads visible logs, identifies the attack pattern, "
|
| 197 |
+
"applies the optimal rule. No LLM required."
|
| 198 |
+
),
|
| 199 |
+
}
|
openenv.yaml
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: api-gateway-defender
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
A simulated HTTP traffic monitoring environment where an AI agent acts as
|
| 5 |
+
a Site Reliability Engineer defending a web backend. The agent inspects a
|
| 6 |
+
stream of incoming HTTP requests and must configure middleware firewall rules
|
| 7 |
+
to block malicious traffic while preserving legitimate user requests.
|
| 8 |
+
|
| 9 |
+
Models a real production incident domain: rate-limiting, WAF rule authoring,
|
| 10 |
+
and pattern-based traffic filtering — skills that are highly valued in DevOps,
|
| 11 |
+
SRE, and cybersecurity engineering.
|
| 12 |
+
|
| 13 |
+
author: "API Gateway Defender Team"
|
| 14 |
+
license: "Apache-2.0"
|
| 15 |
+
|
| 16 |
+
tags:
|
| 17 |
+
- openenv
|
| 18 |
+
- cybersecurity
|
| 19 |
+
- web-security
|
| 20 |
+
- sre
|
| 21 |
+
- real-world
|
| 22 |
+
- devops
|
| 23 |
+
- rate-limiting
|
| 24 |
+
- waf
|
| 25 |
+
|
| 26 |
+
tasks:
|
| 27 |
+
- id: easy
|
| 28 |
+
name: "Volumetric IP Flood Defense"
|
| 29 |
+
difficulty: easy
|
| 30 |
+
max_score: 1.0
|
| 31 |
+
description: >
|
| 32 |
+
A single IP address is flooding the /login endpoint with POST requests.
|
| 33 |
+
The agent must identify the malicious IP from traffic logs and block it
|
| 34 |
+
(or apply a rate limit). Tests pattern recognition under high-volume noise.
|
| 35 |
+
success_criteria: >
|
| 36 |
+
block_ip or add_rate_limit action targeting the flooding IP address,
|
| 37 |
+
achieving ≥0.95 detection rate with <10% false positive rate.
|
| 38 |
+
|
| 39 |
+
- id: medium
|
| 40 |
+
name: "Scraper Bot Detection"
|
| 41 |
+
difficulty: medium
|
| 42 |
+
max_score: 1.0
|
| 43 |
+
description: >
|
| 44 |
+
A scraper bot harvests the /api/data endpoint from 50 different IP addresses,
|
| 45 |
+
rotating them to evade IP-based blocks. All malicious requests share one
|
| 46 |
+
identical unusual User-Agent string. The agent must identify and block it.
|
| 47 |
+
success_criteria: >
|
| 48 |
+
block_user_agent action with the exact malicious User-Agent string,
|
| 49 |
+
achieving ≥0.95 detection rate with <10% false positive rate.
|
| 50 |
+
|
| 51 |
+
- id: hard
|
| 52 |
+
name: "SQL Injection Middleware Defense"
|
| 53 |
+
difficulty: hard
|
| 54 |
+
max_score: 1.0
|
| 55 |
+
description: >
|
| 56 |
+
An attacker probes the database via SQL injection. They rotate IP addresses
|
| 57 |
+
AND User-Agents on every request to evade simple rules. Every malicious
|
| 58 |
+
request contains a SQL injection payload in the query string. The agent
|
| 59 |
+
must write a regex-based middleware rule to detect and block all payloads.
|
| 60 |
+
success_criteria: >
|
| 61 |
+
write_custom_middleware action with a regex that matches 'UNION SELECT'
|
| 62 |
+
pattern (case-insensitive), achieving ≥0.95 detection rate with <10% FP rate.
|
| 63 |
+
|
| 64 |
+
observation_space:
|
| 65 |
+
type: structured
|
| 66 |
+
description: "Snapshot of recent HTTP traffic and active gateway configuration."
|
| 67 |
+
fields:
|
| 68 |
+
- name: recent_requests
|
| 69 |
+
type: "list[dict]"
|
| 70 |
+
description: "Last 100 HTTP requests. Each has: ip, method, path, user_agent, query_string, status_code."
|
| 71 |
+
- name: active_rules
|
| 72 |
+
type: "list[str]"
|
| 73 |
+
description: "Human-readable list of firewall rules currently active."
|
| 74 |
+
- name: current_task
|
| 75 |
+
type: string
|
| 76 |
+
description: "Task ID: 'easy', 'medium', or 'hard'."
|
| 77 |
+
- name: task_description
|
| 78 |
+
type: string
|
| 79 |
+
description: "Natural language description of the attack to defend against."
|
| 80 |
+
- name: step_count
|
| 81 |
+
type: integer
|
| 82 |
+
description: "Number of rules submitted in the current episode."
|
| 83 |
+
- name: hint
|
| 84 |
+
type: string
|
| 85 |
+
description: "Statistical hint about suspicious patterns in the visible traffic window."
|
| 86 |
+
|
| 87 |
+
action_space:
|
| 88 |
+
type: discrete_parameterized
|
| 89 |
+
description: "Submit one firewall rule to the gateway middleware."
|
| 90 |
+
fields:
|
| 91 |
+
- name: action_type
|
| 92 |
+
type: string
|
| 93 |
+
required: true
|
| 94 |
+
choices:
|
| 95 |
+
- block_ip
|
| 96 |
+
- add_rate_limit
|
| 97 |
+
- block_user_agent
|
| 98 |
+
- write_custom_middleware
|
| 99 |
+
description: "Which type of rule to apply."
|
| 100 |
+
- name: target_ip
|
| 101 |
+
type: string
|
| 102 |
+
required: false
|
| 103 |
+
description: "IP address. Required for block_ip and add_rate_limit."
|
| 104 |
+
- name: target_user_agent
|
| 105 |
+
type: string
|
| 106 |
+
required: false
|
| 107 |
+
description: "Exact User-Agent string. Required for block_user_agent."
|
| 108 |
+
- name: regex_pattern
|
| 109 |
+
type: string
|
| 110 |
+
required: false
|
| 111 |
+
description: "Python regex matched against '{path}?{query_string}'. Required for write_custom_middleware."
|
| 112 |
+
- name: max_requests
|
| 113 |
+
type: integer
|
| 114 |
+
required: false
|
| 115 |
+
default: 60
|
| 116 |
+
description: "Requests per minute cap. Used with add_rate_limit."
|
| 117 |
+
|
| 118 |
+
reward:
|
| 119 |
+
range: [0.0, 1.0]
|
| 120 |
+
type: continuous
|
| 121 |
+
formula: >
|
| 122 |
+
detection_rate = malicious_blocked / total_malicious
|
| 123 |
+
false_positive_rate = legitimate_blocked / total_legitimate
|
| 124 |
+
if false_positive_rate > 0.10:
|
| 125 |
+
score = 0.0
|
| 126 |
+
else:
|
| 127 |
+
score = clamp(detection_rate - false_positive_rate * 5.0, 0.0, 1.0)
|
| 128 |
+
description: >
|
| 129 |
+
Rewards accurate detection of malicious traffic. Penalises false positives
|
| 130 |
+
(blocking legitimate users) with a 5x multiplier. Zeroed entirely if
|
| 131 |
+
false positive rate exceeds 10% — models real operational constraints
|
| 132 |
+
where blocking paying customers is unacceptable.
|
| 133 |
+
|
| 134 |
+
episode:
|
| 135 |
+
max_steps: 5
|
| 136 |
+
termination_conditions:
|
| 137 |
+
- "score >= 0.95 (success)"
|
| 138 |
+
- "step_count >= 5 (step limit)"
|
| 139 |
+
reset_required: true
|
| 140 |
+
|
| 141 |
+
evaluation:
|
| 142 |
+
grader_type: programmatic
|
| 143 |
+
deterministic: true
|
| 144 |
+
train_seed: 42
|
| 145 |
+
test_seed: 137
|
| 146 |
+
description: >
|
| 147 |
+
Rules are graded against a hidden test traffic set (seed 137) distinct from
|
| 148 |
+
the visible training sample (seed 42). This prevents agents from overfitting
|
| 149 |
+
to specific IPs/UAs in the observation window.
|
| 150 |
+
|
| 151 |
+
api:
|
| 152 |
+
framework: FastAPI
|
| 153 |
+
port: 7860
|
| 154 |
+
endpoints:
|
| 155 |
+
- "POST /reset"
|
| 156 |
+
- "POST /step"
|
| 157 |
+
- "GET /state"
|
| 158 |
+
- "GET /tasks"
|
| 159 |
+
- "GET /grader"
|
| 160 |
+
- "POST /baseline"
|
| 161 |
+
- "GET /health"
|
| 162 |
+
|
| 163 |
+
baseline:
|
| 164 |
+
agent_type: heuristic
|
| 165 |
+
scores:
|
| 166 |
+
easy: 1.0
|
| 167 |
+
medium: 1.0
|
| 168 |
+
hard: 1.0
|
| 169 |
+
note: >
|
| 170 |
+
Heuristic agent reads the visible traffic sample, identifies the attack
|
| 171 |
+
pattern statistically, and applies the optimal rule. Scores are fully
|
| 172 |
+
reproducible with fixed seeds.
|