CystronCode commited on
Commit
c3fbc01
·
verified ·
1 Parent(s): 022430a

Upload 5 files

Browse files
Files changed (5) hide show
  1. Dockerfile +36 -0
  2. baseline.py +247 -0
  3. env.py +641 -0
  4. main.py +199 -0
  5. openenv.yaml +172 -0
Dockerfile ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ─── Build stage ─────────────────────────────────────────────────────────────────
2
+ FROM python:3.11-slim AS builder
3
+
4
+ WORKDIR /build
5
+
6
+ # Install dependencies into a prefix we'll copy to the final image
7
+ COPY requirements.txt .
8
+ RUN pip install --no-cache-dir --prefix=/install -r requirements.txt
9
+
10
+
11
+ # ─── Runtime stage ────────────────────────────────────────────────────────────────
12
+ FROM python:3.11-slim
13
+
14
+ # Hugging Face Spaces expects the app on port 7860
15
+ ENV PORT=7860
16
+ ENV PYTHONUNBUFFERED=1
17
+ ENV PYTHONDONTWRITEBYTECODE=1
18
+
19
+ WORKDIR /app
20
+
21
+ # Copy pre-installed packages from builder
22
+ COPY --from=builder /install /usr/local
23
+
24
+ # Copy application code
25
+ COPY env.py .
26
+ COPY main.py .
27
+ COPY baseline.py .
28
+
29
+ # HF Spaces: non-root user for safety
30
+ RUN useradd -m -u 1000 appuser && chown -R appuser /app
31
+ USER appuser
32
+
33
+ EXPOSE 7860
34
+
35
+ # Increase workers for concurrent evaluation runs
36
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
baseline.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Baseline Inference Script — API Gateway Defender
3
+ =================================================
4
+ Evaluates an agent on all 3 tasks and prints reproducible scores.
5
+
6
+ Usage
7
+ -----
8
+ # With LLM (reads OPENAI_API_KEY from environment):
9
+ OPENAI_API_KEY=sk-... python baseline.py
10
+
11
+ # Heuristic fallback (no API key needed):
12
+ python baseline.py
13
+
14
+ The LLM agent receives the traffic logs and task description, then
15
+ produces a JSON action that is submitted to the environment.
16
+
17
+ The heuristic agent reads the visible logs statistically and picks
18
+ the correct rule — used to verify the grader is working correctly
19
+ and as a reproducible baseline for submission.
20
+ """
21
+
22
+ import json
23
+ import os
24
+ import sys
25
+ import urllib.error
26
+ import urllib.request
27
+ from typing import Any, Dict
28
+
29
+ # Allow running standalone (before FastAPI starts) by importing env directly
30
+ try:
31
+ from env import (
32
+ Action,
33
+ APIGatewayDefender,
34
+ TASK_DESCRIPTIONS,
35
+ run_heuristic_baseline,
36
+ )
37
+ _DIRECT_IMPORT = True
38
+ except ImportError:
39
+ _DIRECT_IMPORT = False
40
+
41
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
42
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
43
+ LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
44
+
45
+
46
+ # ─── OpenAI helper ───────────────────────────────────────────────────────────────
47
+
48
+ def _call_openai(messages: list, max_tokens: int = 512) -> str:
49
+ """Send a request to the OpenAI chat completions endpoint."""
50
+ payload = json.dumps(
51
+ {
52
+ "model": LLM_MODEL,
53
+ "messages": messages,
54
+ "max_tokens": max_tokens,
55
+ "temperature": 0.1,
56
+ }
57
+ ).encode()
58
+
59
+ req = urllib.request.Request(
60
+ "https://api.openai.com/v1/chat/completions",
61
+ data=payload,
62
+ headers={
63
+ "Content-Type": "application/json",
64
+ "Authorization": f"Bearer {OPENAI_API_KEY}",
65
+ },
66
+ )
67
+ try:
68
+ with urllib.request.urlopen(req, timeout=30) as resp:
69
+ data = json.loads(resp.read())
70
+ return data["choices"][0]["message"]["content"]
71
+ except urllib.error.HTTPError as exc:
72
+ body = exc.read().decode(errors="replace")
73
+ raise RuntimeError(f"OpenAI API error {exc.code}: {body}") from exc
74
+
75
+
76
+ def _parse_json_from_llm(raw: str) -> Dict[str, Any]:
77
+ """Extract a JSON object from LLM output, stripping markdown fences if present."""
78
+ raw = raw.strip()
79
+ if raw.startswith("```"):
80
+ parts = raw.split("```")
81
+ # parts[1] is the fenced block; strip language tag if present
82
+ inner = parts[1]
83
+ if inner.lower().startswith("json"):
84
+ inner = inner[4:]
85
+ raw = inner.strip()
86
+ return json.loads(raw)
87
+
88
+
89
+ # ─── LLM agent ───────────────────────────────────────────────────────────────────
90
+
91
+ def _llm_agent_run(task_id: str) -> float:
92
+ """
93
+ Run an LLM agent on a single task via the HTTP API.
94
+
95
+ 1. Reset the environment.
96
+ 2. Show the agent the traffic logs and task description.
97
+ 3. Ask it to produce a JSON action.
98
+ 4. Submit the action and return the reward score.
99
+ """
100
+ import urllib.request as urlreq
101
+
102
+ def _post(path: str, body: Any) -> Any:
103
+ data = json.dumps(body).encode()
104
+ req = urlreq.Request(
105
+ f"{ENV_BASE_URL}{path}",
106
+ data=data,
107
+ headers={"Content-Type": "application/json"},
108
+ )
109
+ with urlreq.urlopen(req, timeout=15) as resp:
110
+ return json.loads(resp.read())
111
+
112
+ # 1. Reset
113
+ obs = _post("/reset", {"task_id": task_id})
114
+
115
+ # 2. Build prompt (truncate request list to 25 to stay within token budget)
116
+ sample_requests = obs["recent_requests"][:25]
117
+
118
+ system_prompt = (
119
+ "You are a Site Reliability Engineer responding to a live production incident. "
120
+ "You will be shown HTTP traffic logs and a task description. "
121
+ "Your job is to write exactly ONE firewall rule as a JSON object. "
122
+ "Respond with ONLY valid JSON — no prose, no markdown fences."
123
+ )
124
+
125
+ action_schema = (
126
+ "{\n"
127
+ ' "action_type": "block_ip" | "add_rate_limit" | "block_user_agent" | "write_custom_middleware",\n'
128
+ ' "target_ip": "<string, required for block_ip / add_rate_limit>",\n'
129
+ ' "target_user_agent": "<string, required for block_user_agent>",\n'
130
+ ' "regex_pattern": "<Python regex, required for write_custom_middleware>",\n'
131
+ ' "max_requests": <int, optional — requests/min cap for add_rate_limit>\n'
132
+ "}"
133
+ )
134
+
135
+ user_prompt = (
136
+ f"TASK: {obs['task_description']}\n\n"
137
+ f"HINT: {obs.get('hint', '')}\n\n"
138
+ f"TRAFFIC SAMPLE (first 25 requests):\n"
139
+ f"{json.dumps(sample_requests, indent=2)}\n\n"
140
+ f"Respond with ONE JSON action using this schema:\n{action_schema}"
141
+ )
142
+
143
+ # 3. Call LLM
144
+ llm_response = _call_openai(
145
+ [
146
+ {"role": "system", "content": system_prompt},
147
+ {"role": "user", "content": user_prompt},
148
+ ]
149
+ )
150
+
151
+ # 4. Parse action
152
+ try:
153
+ action_dict = _parse_json_from_llm(llm_response)
154
+ except (json.JSONDecodeError, KeyError) as exc:
155
+ print(f" [!] Failed to parse LLM response: {exc}\n Raw: {llm_response[:200]}")
156
+ return 0.0
157
+
158
+ # 5. Step
159
+ result = _post("/step", action_dict)
160
+ score = result["reward"]["score"]
161
+ msg = result["reward"]["message"]
162
+ print(f" Action: {action_dict}")
163
+ print(f" Result: {msg}")
164
+ return score
165
+
166
+
167
+ # ─── Main ────────────────────────────────────────────────────────────────────────
168
+
169
+ def run_baseline_direct() -> Dict[str, float]:
170
+ """Run heuristic baseline directly on the Python class (no server needed)."""
171
+ return run_heuristic_baseline()
172
+
173
+
174
+ def run_baseline_http() -> Dict[str, float]:
175
+ """Run heuristic baseline via the HTTP API."""
176
+ import urllib.request as urlreq
177
+
178
+ req = urlreq.Request(
179
+ f"{ENV_BASE_URL}/baseline",
180
+ data=b"{}",
181
+ headers={"Content-Type": "application/json"},
182
+ method="POST",
183
+ )
184
+ with urlreq.urlopen(req, timeout=30) as resp:
185
+ data = json.loads(resp.read())
186
+ return data["scores"]
187
+
188
+
189
+ def main() -> None:
190
+ print("=" * 55)
191
+ print(" API Gateway Defender — Baseline Evaluation")
192
+ print("=" * 55)
193
+ print()
194
+
195
+ task_ids = ["easy", "medium", "hard"]
196
+ scores: Dict[str, float] = {}
197
+
198
+ if OPENAI_API_KEY:
199
+ print(f"Mode : LLM agent ({LLM_MODEL})")
200
+ print(f"URL : {ENV_BASE_URL}")
201
+ print()
202
+ for task_id in task_ids:
203
+ print(f"[Task: {task_id}]")
204
+ try:
205
+ score = _llm_agent_run(task_id)
206
+ scores[task_id] = score
207
+ print(f" Score: {score:.4f}")
208
+ except Exception as exc:
209
+ print(f" [!] Error: {exc}. Falling back to heuristic.")
210
+ if _DIRECT_IMPORT:
211
+ fb = run_heuristic_baseline()
212
+ scores[task_id] = fb.get(task_id, 0.0)
213
+ else:
214
+ scores[task_id] = 0.0
215
+ print()
216
+ else:
217
+ print("Mode : Heuristic agent (set OPENAI_API_KEY to use LLM)")
218
+ print()
219
+ if _DIRECT_IMPORT:
220
+ scores = run_baseline_direct()
221
+ else:
222
+ print(f"Calling {ENV_BASE_URL}/baseline ...")
223
+ scores = run_baseline_http()
224
+ for task_id in task_ids:
225
+ print(f" [{task_id}] score = {scores.get(task_id, 0.0):.4f}")
226
+
227
+ print()
228
+ print("-" * 35)
229
+ avg = sum(scores.values()) / max(len(scores), 1)
230
+ for task_id in task_ids:
231
+ s = scores.get(task_id, 0.0)
232
+ bar = "█" * int(s * 20)
233
+ print(f" {task_id:<8s} {s:.4f} {bar}")
234
+ print(f" {'average':<8s} {avg:.4f}")
235
+ print("-" * 35)
236
+ print()
237
+
238
+ # Exit non-zero if any task scored 0.0 (helps CI catch broken graders)
239
+ if any(v == 0.0 for v in scores.values()):
240
+ print("[WARN] One or more tasks scored 0.0. Check the environment.")
241
+ sys.exit(1)
242
+ else:
243
+ print("[OK] All tasks passed baseline threshold.")
244
+
245
+
246
+ if __name__ == "__main__":
247
+ main()
env.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API Gateway Defender — OpenEnv Environment
3
+ ==========================================
4
+ A simulated HTTP traffic monitoring environment where an AI agent must
5
+ configure middleware rules to block malicious web traffic while preserving
6
+ legitimate user requests. Models a real production incident faced by SREs.
7
+
8
+ Tasks:
9
+ easy — Single IP volumetric flood on /login
10
+ medium — Scraper bot using identical User-Agent across many IPs
11
+ hard — SQL injection attack with rotating IPs and User-Agents
12
+ """
13
+
14
+ import re
15
+ import random
16
+ from typing import Any, Dict, List, Optional
17
+
18
+ from pydantic import BaseModel, Field
19
+
20
+ # ─── Constants ──────────────────────────────────────────────────────────────────
21
+
22
+ MAX_STEPS = 5
23
+ TRAFFIC_SIZE = 1000
24
+ MALICIOUS_COUNT = 200
25
+ FALSE_POSITIVE_THRESHOLD = 0.10 # >10% FP rate → score zeroed
26
+
27
+ # Fixed attack artifacts (deterministic so graders are reproducible)
28
+ ATTACK_IP_EASY = "185.220.101.47"
29
+ SCRAPER_UA = "ScraperBot/3.1 (DataHarvester; +http://scrape.evil.io/bot)"
30
+ SQLI_KEYWORD = "UNION SELECT"
31
+
32
+ LEGIT_USER_AGENTS = [
33
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36",
34
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 Version/17.0 Safari/605.1.15",
35
+ "Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/115.0",
36
+ "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 Mobile/15E148",
37
+ "Mozilla/5.0 (Android 13; Mobile; rv:109.0) Gecko/109.0 Firefox/115.0",
38
+ ]
39
+
40
+ LEGIT_PATHS = [
41
+ "/", "/home", "/about", "/products", "/contact",
42
+ "/blog", "/faq", "/pricing", "/docs", "/api/users", "/api/products",
43
+ ]
44
+
45
+ LEGIT_QUERY_STRINGS = [
46
+ "", "page=1", "q=hello", "id=42", "sort=asc&filter=new",
47
+ "utm_source=google", "lang=en", "category=tech", "ref=nav",
48
+ ]
49
+
50
+ LEGIT_METHODS = ["GET", "GET", "GET", "POST", "PUT"] # GET-heavy, realistic
51
+
52
+ SQLI_PAYLOADS = [
53
+ f"id=1' {SQLI_KEYWORD} username,password FROM users--",
54
+ f"search=admin' {SQLI_KEYWORD} 1,2,3--",
55
+ f"q=x' {SQLI_KEYWORD} table_name FROM information_schema.tables--",
56
+ f"cat=1' {SQLI_KEYWORD} NULL,NULL,NULL--",
57
+ f"item=5' {SQLI_KEYWORD} version()--",
58
+ ]
59
+
60
+ # ─── Pydantic Models ─────────────────────────────────────────────────────────────
61
+
62
+ class Action(BaseModel):
63
+ """
64
+ An action the agent can take — adds one firewall rule to the gateway.
65
+
66
+ action_type choices:
67
+ block_ip — Drop all traffic from target_ip
68
+ add_rate_limit — Allow target_ip only max_requests/min
69
+ block_user_agent — Drop all traffic matching target_user_agent exactly
70
+ write_custom_middleware — Drop requests where regex_pattern matches path?query_string
71
+ """
72
+ action_type: str = Field(
73
+ ...,
74
+ description=(
75
+ "Rule type: 'block_ip', 'add_rate_limit', "
76
+ "'block_user_agent', 'write_custom_middleware'"
77
+ ),
78
+ )
79
+ target_ip: Optional[str] = Field(
80
+ None, description="IP address (required for block_ip / add_rate_limit)"
81
+ )
82
+ target_user_agent: Optional[str] = Field(
83
+ None, description="Exact User-Agent string (required for block_user_agent)"
84
+ )
85
+ regex_pattern: Optional[str] = Field(
86
+ None,
87
+ description=(
88
+ "Python regex matched against '{path}?{query_string}' "
89
+ "(required for write_custom_middleware)"
90
+ ),
91
+ )
92
+ max_requests: Optional[int] = Field(
93
+ 60, description="Requests-per-minute cap for add_rate_limit (default 60)"
94
+ )
95
+
96
+
97
+ class Observation(BaseModel):
98
+ """What the agent sees at each step."""
99
+ recent_requests: List[Dict[str, Any]] = Field(
100
+ ...,
101
+ description=(
102
+ "Last 100 HTTP requests in the traffic stream. "
103
+ "Fields: ip, method, path, user_agent, query_string, status_code."
104
+ ),
105
+ )
106
+ active_rules: List[str] = Field(
107
+ ..., description="Human-readable list of rules currently active on the gateway."
108
+ )
109
+ current_task: str = Field(..., description="Task ID: 'easy', 'medium', or 'hard'")
110
+ task_description: str = Field(
111
+ ..., description="Natural language description of the attack the agent must repel."
112
+ )
113
+ step_count: int = Field(..., description="Number of rules submitted so far this episode.")
114
+ hint: str = Field("", description="Statistical hint derived from the visible traffic sample.")
115
+
116
+
117
+ class Reward(BaseModel):
118
+ """Feedback returned after each step()."""
119
+ score: float = Field(..., ge=0.0, le=1.0, description="Task performance score 0.0–1.0")
120
+ malicious_blocked: int = Field(..., description="Malicious requests blocked by active rules")
121
+ legitimate_blocked: int = Field(..., description="Legitimate requests incorrectly blocked")
122
+ total_malicious: int
123
+ total_legitimate: int
124
+ false_positive_rate: float = Field(..., description="Fraction of legit requests blocked")
125
+ message: str = Field(..., description="Human-readable explanation of the score")
126
+
127
+
128
+ class StepResult(BaseModel):
129
+ """Full return value of step()."""
130
+ observation: Observation
131
+ reward: Reward
132
+ done: bool
133
+ info: Dict[str, Any]
134
+
135
+
136
+ class EnvironmentState(BaseModel):
137
+ """Full serialisable snapshot returned by state()."""
138
+ task_id: str
139
+ step_count: int
140
+ active_rules: List[Dict[str, Any]]
141
+ episode_done: bool
142
+ best_score: float
143
+ traffic_sample_size: int
144
+
145
+
146
+ # ─── Traffic Generators ──────────────────────────────────────────────────────────
147
+
148
+ def _rand_ip(rng: random.Random, exclude: str = "") -> str:
149
+ """Generate a random public-looking IPv4 address."""
150
+ while True:
151
+ ip = (
152
+ f"{rng.randint(10, 220)}."
153
+ f"{rng.randint(1, 254)}."
154
+ f"{rng.randint(1, 254)}."
155
+ f"{rng.randint(1, 254)}"
156
+ )
157
+ if ip != exclude:
158
+ return ip
159
+
160
+
161
+ def _legit_request(rng: random.Random) -> Dict[str, Any]:
162
+ return {
163
+ "ip": _rand_ip(rng, exclude=ATTACK_IP_EASY),
164
+ "method": rng.choice(LEGIT_METHODS),
165
+ "path": rng.choice(LEGIT_PATHS),
166
+ "user_agent": rng.choice(LEGIT_USER_AGENTS),
167
+ "query_string": rng.choice(LEGIT_QUERY_STRINGS),
168
+ "status_code": 200,
169
+ "is_malicious": False,
170
+ }
171
+
172
+
173
+ def generate_easy_traffic(seed: int) -> List[Dict[str, Any]]:
174
+ """
175
+ Easy: one IP floods /login with POST requests.
176
+ Correct action: block_ip or add_rate_limit on ATTACK_IP_EASY.
177
+ """
178
+ rng = random.Random(seed)
179
+ traffic: List[Dict[str, Any]] = []
180
+
181
+ for _ in range(MALICIOUS_COUNT):
182
+ traffic.append({
183
+ "ip": ATTACK_IP_EASY,
184
+ "method": "POST",
185
+ "path": "/login",
186
+ "user_agent": rng.choice(LEGIT_USER_AGENTS), # UA blends in
187
+ "query_string": "",
188
+ "status_code": 200,
189
+ "is_malicious": True,
190
+ })
191
+
192
+ for _ in range(TRAFFIC_SIZE - MALICIOUS_COUNT):
193
+ traffic.append(_legit_request(rng))
194
+
195
+ rng.shuffle(traffic)
196
+ return traffic
197
+
198
+
199
+ def generate_medium_traffic(seed: int) -> List[Dict[str, Any]]:
200
+ """
201
+ Medium: 50 IPs all share an identical unusual User-Agent, hitting /api/data.
202
+ Correct action: block_user_agent with SCRAPER_UA.
203
+ """
204
+ rng = random.Random(seed)
205
+ traffic: List[Dict[str, Any]] = []
206
+
207
+ scraper_ips = [_rand_ip(rng) for _ in range(50)]
208
+ for _ in range(MALICIOUS_COUNT):
209
+ traffic.append({
210
+ "ip": rng.choice(scraper_ips),
211
+ "method": "GET",
212
+ "path": "/api/data",
213
+ "user_agent": SCRAPER_UA, # constant across all malicious requests
214
+ "query_string": f"page={rng.randint(1, 500)}",
215
+ "status_code": 200,
216
+ "is_malicious": True,
217
+ })
218
+
219
+ for _ in range(TRAFFIC_SIZE - MALICIOUS_COUNT):
220
+ traffic.append(_legit_request(rng))
221
+
222
+ rng.shuffle(traffic)
223
+ return traffic
224
+
225
+
226
+ def generate_hard_traffic(seed: int) -> List[Dict[str, Any]]:
227
+ """
228
+ Hard: attacker rotates IPs and UAs but always carries a SQLi payload.
229
+ Correct action: write_custom_middleware with regex matching 'UNION.SELECT'.
230
+ """
231
+ rng = random.Random(seed)
232
+ traffic: List[Dict[str, Any]] = []
233
+
234
+ for _ in range(MALICIOUS_COUNT):
235
+ traffic.append({
236
+ "ip": _rand_ip(rng),
237
+ "method": "GET",
238
+ "path": rng.choice(["/search", "/products", "/api/items", "/catalog"]),
239
+ "user_agent": rng.choice(LEGIT_USER_AGENTS),
240
+ "query_string": rng.choice(SQLI_PAYLOADS),
241
+ "status_code": 200,
242
+ "is_malicious": True,
243
+ })
244
+
245
+ for _ in range(TRAFFIC_SIZE - MALICIOUS_COUNT):
246
+ req = _legit_request(rng)
247
+ # Guarantee legit requests never accidentally contain the payload
248
+ if SQLI_KEYWORD in req["query_string"].upper():
249
+ req["query_string"] = ""
250
+ traffic.append(req)
251
+
252
+ rng.shuffle(traffic)
253
+ return traffic
254
+
255
+
256
+ TASK_GENERATORS = {
257
+ "easy": generate_easy_traffic,
258
+ "medium": generate_medium_traffic,
259
+ "hard": generate_hard_traffic,
260
+ }
261
+
262
+ TASK_DESCRIPTIONS = {
263
+ "easy": (
264
+ "A single IP address is flooding your /login endpoint with POST requests at high volume. "
265
+ "Inspect the traffic logs to identify the offending IP and block it or apply a rate limit."
266
+ ),
267
+ "medium": (
268
+ "A scraper bot is harvesting your /api/data endpoint from many different IP addresses. "
269
+ "All malicious requests share a single, unusual User-Agent string. "
270
+ "Identify the User-Agent and block it."
271
+ ),
272
+ "hard": (
273
+ "An attacker is probing your database via SQL injection. They rotate IP addresses and "
274
+ "User-Agents to evade simple rules, but every malicious request contains a SQL injection "
275
+ "payload in the query string. Write a regex middleware rule to detect and drop these requests."
276
+ ),
277
+ }
278
+
279
+
280
+ # ─── Rule Engine ─────────────────────────────────────────────────────────────────
281
+
282
+ class _Rule:
283
+ """Internal class: wraps an Action and applies it to individual requests."""
284
+
285
+ def __init__(self, action: Action) -> None:
286
+ self.action = action
287
+ self._compiled_re = None
288
+ if action.action_type == "write_custom_middleware" and action.regex_pattern:
289
+ try:
290
+ self._compiled_re = re.compile(action.regex_pattern, re.IGNORECASE)
291
+ except re.error:
292
+ pass # invalid regex → rule matches nothing
293
+
294
+ def blocks(self, request: Dict[str, Any]) -> bool:
295
+ a = self.action
296
+ if a.action_type in ("block_ip", "add_rate_limit"):
297
+ return bool(a.target_ip and request["ip"] == a.target_ip)
298
+ if a.action_type == "block_user_agent":
299
+ return bool(
300
+ a.target_user_agent
301
+ and request["user_agent"] == a.target_user_agent
302
+ )
303
+ if a.action_type == "write_custom_middleware" and self._compiled_re:
304
+ target = f"{request['path']}?{request['query_string']}"
305
+ return bool(self._compiled_re.search(target))
306
+ return False
307
+
308
+ def describe(self) -> str:
309
+ a = self.action
310
+ if a.action_type == "block_ip":
311
+ return f"BLOCK_IP({a.target_ip})"
312
+ if a.action_type == "add_rate_limit":
313
+ return f"RATE_LIMIT({a.target_ip}, max={a.max_requests}/min)"
314
+ if a.action_type == "block_user_agent":
315
+ return f"BLOCK_UA({a.target_user_agent!r})"
316
+ if a.action_type == "write_custom_middleware":
317
+ return f"MIDDLEWARE(regex={a.regex_pattern!r})"
318
+ return f"RULE({a.action_type})"
319
+
320
+ def to_dict(self) -> Dict[str, Any]:
321
+ a = self.action
322
+ return {
323
+ "action_type": a.action_type,
324
+ "target_ip": a.target_ip,
325
+ "target_user_agent": a.target_user_agent,
326
+ "regex_pattern": a.regex_pattern,
327
+ "description": self.describe(),
328
+ }
329
+
330
+
331
+ # ─── Environment ─────────────────────────────────────────────────────────────────
332
+
333
+ VALID_ACTION_TYPES = {"block_ip", "add_rate_limit", "block_user_agent", "write_custom_middleware"}
334
+
335
+
336
+ class APIGatewayDefender:
337
+ """
338
+ OpenEnv-compliant RL environment — API Gateway Defender.
339
+
340
+ The agent monitors a simulated stream of HTTP requests and must apply
341
+ firewall middleware rules to block malicious traffic while preserving
342
+ legitimate requests.
343
+
344
+ Usage
345
+ -----
346
+ env = APIGatewayDefender()
347
+ obs = env.reset(task_id="easy")
348
+ action = Action(action_type="block_ip", target_ip="185.220.101.47")
349
+ result = env.step(action)
350
+ print(result.reward.score)
351
+ """
352
+
353
+ def __init__(self) -> None:
354
+ self._task_id: str = "easy"
355
+ self._rules: List[_Rule] = []
356
+ self._train_traffic: List[Dict[str, Any]] = []
357
+ self._test_traffic: List[Dict[str, Any]] = []
358
+ self._step_count: int = 0
359
+ self._done: bool = False
360
+ self._best_score: float = 0.0
361
+
362
+ # ── OpenEnv Interface ──────────────────────────────────────────────────────
363
+
364
+ def reset(self, task_id: str = "easy") -> Observation:
365
+ """
366
+ Start a new episode on the given task.
367
+
368
+ Parameters
369
+ ----------
370
+ task_id : str
371
+ One of 'easy', 'medium', 'hard'.
372
+
373
+ Returns
374
+ -------
375
+ Observation
376
+ Initial observation containing the first 100 traffic samples.
377
+ """
378
+ if task_id not in TASK_GENERATORS:
379
+ raise ValueError(
380
+ f"Unknown task_id '{task_id}'. Choose from: {sorted(TASK_GENERATORS)}"
381
+ )
382
+ self._task_id = task_id
383
+ self._rules = []
384
+ self._step_count = 0
385
+ self._done = False
386
+ self._best_score = 0.0
387
+
388
+ gen = TASK_GENERATORS[task_id]
389
+ self._train_traffic = gen(seed=42) # agent can see this
390
+ self._test_traffic = gen(seed=137) # grading set (hidden from agent)
391
+
392
+ return self._make_observation()
393
+
394
+ def step(self, action: Action) -> StepResult:
395
+ """
396
+ Submit one firewall rule and receive a reward signal.
397
+
398
+ The rule is evaluated against a hidden test traffic set to prevent
399
+ overfitting to the visible sample. Partial credit is awarded for
400
+ partial detection; false positives incur a penalty.
401
+
402
+ Parameters
403
+ ----------
404
+ action : Action
405
+ The rule to apply.
406
+
407
+ Returns
408
+ -------
409
+ StepResult
410
+ observation, reward, done flag, and diagnostic info.
411
+ """
412
+ if self._done:
413
+ raise RuntimeError("Episode is finished. Call reset() to start a new episode.")
414
+
415
+ self._step_count += 1
416
+
417
+ # ── Validate action type ──────────────────────────────────────────────
418
+ if action.action_type not in VALID_ACTION_TYPES:
419
+ err_reward = Reward(
420
+ score=0.0,
421
+ malicious_blocked=0,
422
+ legitimate_blocked=0,
423
+ total_malicious=MALICIOUS_COUNT,
424
+ total_legitimate=TRAFFIC_SIZE - MALICIOUS_COUNT,
425
+ false_positive_rate=0.0,
426
+ message=(
427
+ f"Invalid action_type '{action.action_type}'. "
428
+ f"Must be one of {sorted(VALID_ACTION_TYPES)}."
429
+ ),
430
+ )
431
+ return StepResult(
432
+ observation=self._make_observation(),
433
+ reward=err_reward,
434
+ done=False,
435
+ info={"error": "invalid_action_type"},
436
+ )
437
+
438
+ # ── Apply rule ────────────────────────────────────────────────────────
439
+ self._rules.append(_Rule(action))
440
+
441
+ # ── Grade on hidden test traffic ──────────────────────────────────────
442
+ reward = self._grade()
443
+ self._best_score = max(self._best_score, reward.score)
444
+
445
+ # Episode ends at MAX_STEPS or when the agent achieves near-perfect score
446
+ self._done = self._step_count >= MAX_STEPS or reward.score >= 0.95
447
+
448
+ return StepResult(
449
+ observation=self._make_observation(),
450
+ reward=reward,
451
+ done=self._done,
452
+ info={
453
+ "step": self._step_count,
454
+ "best_score": self._best_score,
455
+ "rules_applied": [r.describe() for r in self._rules],
456
+ "max_steps": MAX_STEPS,
457
+ },
458
+ )
459
+
460
+ def state(self) -> EnvironmentState:
461
+ """Return a full serialisable snapshot of the current environment state."""
462
+ return EnvironmentState(
463
+ task_id=self._task_id,
464
+ step_count=self._step_count,
465
+ active_rules=[r.to_dict() for r in self._rules],
466
+ episode_done=self._done,
467
+ best_score=self._best_score,
468
+ traffic_sample_size=len(self._train_traffic),
469
+ )
470
+
471
+ def get_task_grader_score(self) -> float:
472
+ """
473
+ Programmatic grader — returns score 0.0–1.0 for the current episode.
474
+ Returns 0.0 if no rules have been applied yet.
475
+ """
476
+ if not self._rules:
477
+ return 0.0
478
+ return self._grade().score
479
+
480
+ # ── Private Helpers ────────────────────────────────────────────────────────
481
+
482
+ def _make_observation(self) -> Observation:
483
+ """Build an Observation from the current state (no is_malicious flag exposed)."""
484
+ visible = [
485
+ {k: v for k, v in req.items() if k != "is_malicious"}
486
+ for req in self._train_traffic[:100]
487
+ ]
488
+ return Observation(
489
+ recent_requests=visible,
490
+ active_rules=[r.describe() for r in self._rules],
491
+ current_task=self._task_id,
492
+ task_description=TASK_DESCRIPTIONS[self._task_id],
493
+ step_count=self._step_count,
494
+ hint=self._build_hint(),
495
+ )
496
+
497
+ def _build_hint(self) -> str:
498
+ """Generate a statistical hint from the visible traffic sample."""
499
+ if not self._train_traffic:
500
+ return ""
501
+ sample = self._train_traffic[:100]
502
+ malicious_in_sample = [r for r in sample if r.get("is_malicious")]
503
+ n = len(malicious_in_sample)
504
+
505
+ if self._task_id == "easy":
506
+ if n == 0:
507
+ return "Traffic looks normal in this window."
508
+ ips = {r["ip"] for r in malicious_in_sample}
509
+ return (
510
+ f"Warning: {n} POST requests to /login detected in this window "
511
+ f"from {len(ips)} unique IP(s). Possible brute-force or flood."
512
+ )
513
+ elif self._task_id == "medium":
514
+ if n == 0:
515
+ return "Traffic looks normal in this window."
516
+ uas = {r["user_agent"] for r in malicious_in_sample}
517
+ return (
518
+ f"Warning: {n} requests to /api/data share {len(uas)} unique User-Agent(s) "
519
+ f"in this window. Possible scraper activity."
520
+ )
521
+ else:
522
+ if n == 0:
523
+ return "Traffic looks normal in this window."
524
+ return (
525
+ f"Warning: {n} requests in this window contain unusual query string patterns. "
526
+ f"Check for injection payloads."
527
+ )
528
+
529
+ def _grade(self) -> Reward:
530
+ """
531
+ Apply all active rules to the hidden test traffic set and compute a score.
532
+
533
+ Score formula:
534
+ detection_rate = malicious_blocked / total_malicious
535
+ fp_rate = legitimate_blocked / total_legitimate
536
+ if fp_rate > FALSE_POSITIVE_THRESHOLD:
537
+ score = 0.0 ← too many false positives
538
+ else:
539
+ score = clamp(detection_rate - fp_rate * 5.0, 0.0, 1.0)
540
+ """
541
+ malicious = [r for r in self._test_traffic if r["is_malicious"]]
542
+ legit = [r for r in self._test_traffic if not r["is_malicious"]]
543
+
544
+ mal_blocked = sum(1 for r in malicious if any(rule.blocks(r) for rule in self._rules))
545
+ legit_blocked = sum(1 for r in legit if any(rule.blocks(r) for rule in self._rules))
546
+
547
+ total_mal = len(malicious)
548
+ total_legit = len(legit)
549
+
550
+ detection_rate = mal_blocked / total_mal if total_mal > 0 else 0.0
551
+ fp_rate = legit_blocked / total_legit if total_legit > 0 else 0.0
552
+
553
+ if fp_rate > FALSE_POSITIVE_THRESHOLD:
554
+ score = 0.0
555
+ message = (
556
+ f"Score zeroed: {fp_rate:.1%} false positive rate exceeds "
557
+ f"{FALSE_POSITIVE_THRESHOLD:.0%} threshold. Rules are too broad — "
558
+ f"legitimate users are being blocked."
559
+ )
560
+ else:
561
+ score = max(0.0, min(1.0, detection_rate - fp_rate * 5.0))
562
+ message = (
563
+ f"Blocked {mal_blocked}/{total_mal} malicious requests "
564
+ f"({detection_rate:.1%} detection rate) with "
565
+ f"{fp_rate:.1%} false positive rate."
566
+ )
567
+
568
+ return Reward(
569
+ score=round(score, 4),
570
+ malicious_blocked=mal_blocked,
571
+ legitimate_blocked=legit_blocked,
572
+ total_malicious=total_mal,
573
+ total_legitimate=total_legit,
574
+ false_positive_rate=round(fp_rate, 4),
575
+ message=message,
576
+ )
577
+
578
+
579
+ # ─── Convenience: heuristic baseline that runs directly on the class ────────────
580
+
581
+ def run_heuristic_baseline() -> Dict[str, float]:
582
+ """
583
+ A deterministic heuristic agent that solves all 3 tasks correctly.
584
+ Used by the /baseline endpoint and as fallback in the inference script.
585
+
586
+ Returns
587
+ -------
588
+ Dict[str, float]
589
+ task_id → score
590
+ """
591
+ env = APIGatewayDefender()
592
+ scores: Dict[str, float] = {}
593
+
594
+ # ── Easy: identify the IP flooding /login ──────────────────────────────────
595
+ obs = env.reset("easy")
596
+ ip_counts: Dict[str, int] = {}
597
+ for req in obs.recent_requests:
598
+ if req["path"] == "/login" and req["method"] == "POST":
599
+ ip_counts[req["ip"]] = ip_counts.get(req["ip"], 0) + 1
600
+ suspect_ip = (
601
+ max(ip_counts, key=lambda k: ip_counts[k]) if ip_counts else ATTACK_IP_EASY
602
+ )
603
+ result = env.step(Action(action_type="block_ip", target_ip=suspect_ip))
604
+ scores["easy"] = result.reward.score
605
+
606
+ # ── Medium: identify the unusual User-Agent ────────────────────────────────
607
+ obs = env.reset("medium")
608
+ ua_counts: Dict[str, int] = {}
609
+ for req in obs.recent_requests:
610
+ ua_counts[req["user_agent"]] = ua_counts.get(req["user_agent"], 0) + 1
611
+
612
+ bot_keywords = {"scraper", "bot", "crawler", "spider", "harvester"}
613
+ browser_keywords = {"mozilla", "chrome", "safari", "firefox", "gecko", "webkit"}
614
+ suspect_ua = None
615
+
616
+ # Prefer UAs that look like bots
617
+ for ua, _ in sorted(ua_counts.items(), key=lambda x: -x[1]):
618
+ if any(k in ua.lower() for k in bot_keywords):
619
+ suspect_ua = ua
620
+ break
621
+ # Fallback: most common UA that doesn't look like a browser
622
+ if not suspect_ua:
623
+ for ua, _ in sorted(ua_counts.items(), key=lambda x: -x[1]):
624
+ if not any(k in ua.lower() for k in browser_keywords):
625
+ suspect_ua = ua
626
+ break
627
+
628
+ result = env.step(Action(action_type="block_user_agent", target_user_agent=suspect_ua or ""))
629
+ scores["medium"] = result.reward.score
630
+
631
+ # ── Hard: write a regex to catch SQLi payloads ────────────────────────────
632
+ env.reset("hard")
633
+ result = env.step(
634
+ Action(
635
+ action_type="write_custom_middleware",
636
+ regex_pattern=r"UNION\s+SELECT",
637
+ )
638
+ )
639
+ scores["hard"] = result.reward.score
640
+
641
+ return scores
main.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API Gateway Defender — FastAPI Server
3
+ =====================================
4
+ Exposes the OpenEnv-compliant HTTP API for the environment.
5
+
6
+ Endpoints
7
+ ---------
8
+ POST /reset — Start a new episode
9
+ POST /step — Submit a firewall rule, receive reward
10
+ GET /state — Inspect current environment state
11
+ GET /tasks — List tasks and action schema
12
+ GET /grader — Get grader score for current episode
13
+ POST /baseline — Run heuristic baseline across all 3 tasks
14
+ GET /health — Liveness probe (required for HF Spaces ping)
15
+ """
16
+
17
+ from fastapi import FastAPI, HTTPException
18
+ from fastapi.middleware.cors import CORSMiddleware
19
+ from typing import Any, Dict
20
+
21
+ from env import (
22
+ Action,
23
+ APIGatewayDefender,
24
+ Observation,
25
+ TASK_DESCRIPTIONS,
26
+ run_heuristic_baseline,
27
+ )
28
+
29
+ # ─── App setup ───────────────────────────────────────────────────────────────────
30
+
31
+ app = FastAPI(
32
+ title="API Gateway Defender",
33
+ description=(
34
+ "An OpenEnv RL environment where an AI agent defends a simulated web backend "
35
+ "against volumetric floods, scraper bots, and SQL injection attacks."
36
+ ),
37
+ version="1.0.0",
38
+ )
39
+
40
+ app.add_middleware(
41
+ CORSMiddleware,
42
+ allow_origins=["*"],
43
+ allow_methods=["*"],
44
+ allow_headers=["*"],
45
+ )
46
+
47
+ # Single shared environment instance (stateful, per-session)
48
+ _env = APIGatewayDefender()
49
+
50
+
51
+ # ─── Routes ──────────────────────────────────────────────────────────────────────
52
+
53
+ @app.get("/health")
54
+ def health() -> Dict[str, str]:
55
+ """Liveness probe — returns 200 and confirms the environment is running."""
56
+ return {"status": "ok", "environment": "api-gateway-defender"}
57
+
58
+
59
+ @app.get("/")
60
+ def root() -> Dict[str, Any]:
61
+ """Overview of the environment and available endpoints."""
62
+ return {
63
+ "name": "API Gateway Defender",
64
+ "description": (
65
+ "OpenEnv RL environment: configure firewall rules to block malicious "
66
+ "HTTP traffic while preserving legitimate requests."
67
+ ),
68
+ "version": "1.0.0",
69
+ "tasks": list(TASK_DESCRIPTIONS.keys()),
70
+ "endpoints": {
71
+ "POST /reset": "Start a new episode. Body: {task_id: 'easy'|'medium'|'hard'}",
72
+ "POST /step": "Submit a firewall rule. Body: Action schema.",
73
+ "GET /state": "Current environment state snapshot.",
74
+ "GET /tasks": "Task descriptions + action/observation schemas.",
75
+ "GET /grader": "Current grader score for the active episode.",
76
+ "POST /baseline": "Run heuristic baseline agent across all 3 tasks.",
77
+ "GET /health": "Liveness probe.",
78
+ },
79
+ }
80
+
81
+
82
+ @app.post("/reset")
83
+ def reset(body: Dict[str, str] = None) -> Dict[str, Any]:
84
+ """
85
+ Start a new episode.
86
+
87
+ Request body (JSON):
88
+ {"task_id": "easy" | "medium" | "hard"}
89
+
90
+ Returns the initial Observation.
91
+ """
92
+ task_id = (body or {}).get("task_id", "easy")
93
+ try:
94
+ obs: Observation = _env.reset(task_id=task_id)
95
+ except ValueError as exc:
96
+ raise HTTPException(status_code=422, detail=str(exc))
97
+ return obs.model_dump()
98
+
99
+
100
+ @app.post("/step")
101
+ def step(action: Action) -> Dict[str, Any]:
102
+ """
103
+ Submit one firewall rule.
104
+
105
+ Returns StepResult: {observation, reward, done, info}
106
+
107
+ Reward score: 0.0–1.0
108
+ = detection_rate − (false_positive_rate × 5)
109
+ = 0.0 if false positive rate > 10%
110
+ """
111
+ try:
112
+ result = _env.step(action)
113
+ except RuntimeError as exc:
114
+ raise HTTPException(status_code=400, detail=str(exc))
115
+ return result.model_dump()
116
+
117
+
118
+ @app.get("/state")
119
+ def state() -> Dict[str, Any]:
120
+ """Return the full serialisable state of the current episode."""
121
+ return _env.state().model_dump()
122
+
123
+
124
+ @app.get("/tasks")
125
+ def tasks() -> Dict[str, Any]:
126
+ """
127
+ List all tasks and their descriptions, plus the action and observation schemas.
128
+ Required by the OpenEnv spec.
129
+ """
130
+ return {
131
+ "tasks": [
132
+ {
133
+ "id": "easy",
134
+ "name": "Volumetric IP Flood Defense",
135
+ "difficulty": "easy",
136
+ "description": TASK_DESCRIPTIONS["easy"],
137
+ "hint": "One IP is responsible for all malicious traffic.",
138
+ },
139
+ {
140
+ "id": "medium",
141
+ "name": "Scraper Bot Detection",
142
+ "difficulty": "medium",
143
+ "description": TASK_DESCRIPTIONS["medium"],
144
+ "hint": "Many IPs, but a single shared User-Agent string.",
145
+ },
146
+ {
147
+ "id": "hard",
148
+ "name": "SQL Injection Middleware Defense",
149
+ "difficulty": "hard",
150
+ "description": TASK_DESCRIPTIONS["hard"],
151
+ "hint": "Rotating IPs and UAs, but a consistent payload pattern.",
152
+ },
153
+ ],
154
+ "action_schema": Action.model_json_schema(),
155
+ "observation_schema": {
156
+ "recent_requests": "list[dict] — last 100 requests: ip, method, path, user_agent, query_string, status_code",
157
+ "active_rules": "list[str] — human-readable active firewall rules",
158
+ "current_task": "str — 'easy', 'medium', or 'hard'",
159
+ "task_description":"str — natural language goal",
160
+ "step_count": "int — steps taken this episode",
161
+ "hint": "str — statistical hint from visible traffic",
162
+ },
163
+ }
164
+
165
+
166
+ @app.get("/grader")
167
+ def grader() -> Dict[str, Any]:
168
+ """
169
+ Return the programmatic grader score for the current episode.
170
+ Score is 0.0–1.0; reflects detection rate minus false-positive penalty.
171
+ """
172
+ score = _env.get_task_grader_score()
173
+ state_info = _env.state()
174
+ return {
175
+ "task_id": state_info.task_id,
176
+ "score": score,
177
+ "best_score": state_info.best_score,
178
+ "rules_applied":[r["description"] for r in state_info.active_rules],
179
+ "episode_done": state_info.episode_done,
180
+ "max_steps": 5,
181
+ }
182
+
183
+
184
+ @app.post("/baseline")
185
+ def baseline() -> Dict[str, Any]:
186
+ """
187
+ Run the heuristic baseline agent across all 3 tasks and return scores.
188
+ Does not affect the shared episode state.
189
+ """
190
+ scores = run_heuristic_baseline()
191
+ avg = sum(scores.values()) / len(scores)
192
+ return {
193
+ "scores": scores,
194
+ "average": round(avg, 4),
195
+ "message": (
196
+ "Heuristic baseline: reads visible logs, identifies the attack pattern, "
197
+ "applies the optimal rule. No LLM required."
198
+ ),
199
+ }
openenv.yaml ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: api-gateway-defender
2
+ version: "1.0.0"
3
+ description: >
4
+ A simulated HTTP traffic monitoring environment where an AI agent acts as
5
+ a Site Reliability Engineer defending a web backend. The agent inspects a
6
+ stream of incoming HTTP requests and must configure middleware firewall rules
7
+ to block malicious traffic while preserving legitimate user requests.
8
+
9
+ Models a real production incident domain: rate-limiting, WAF rule authoring,
10
+ and pattern-based traffic filtering — skills that are highly valued in DevOps,
11
+ SRE, and cybersecurity engineering.
12
+
13
+ author: "API Gateway Defender Team"
14
+ license: "Apache-2.0"
15
+
16
+ tags:
17
+ - openenv
18
+ - cybersecurity
19
+ - web-security
20
+ - sre
21
+ - real-world
22
+ - devops
23
+ - rate-limiting
24
+ - waf
25
+
26
+ tasks:
27
+ - id: easy
28
+ name: "Volumetric IP Flood Defense"
29
+ difficulty: easy
30
+ max_score: 1.0
31
+ description: >
32
+ A single IP address is flooding the /login endpoint with POST requests.
33
+ The agent must identify the malicious IP from traffic logs and block it
34
+ (or apply a rate limit). Tests pattern recognition under high-volume noise.
35
+ success_criteria: >
36
+ block_ip or add_rate_limit action targeting the flooding IP address,
37
+ achieving ≥0.95 detection rate with <10% false positive rate.
38
+
39
+ - id: medium
40
+ name: "Scraper Bot Detection"
41
+ difficulty: medium
42
+ max_score: 1.0
43
+ description: >
44
+ A scraper bot harvests the /api/data endpoint from 50 different IP addresses,
45
+ rotating them to evade IP-based blocks. All malicious requests share one
46
+ identical unusual User-Agent string. The agent must identify and block it.
47
+ success_criteria: >
48
+ block_user_agent action with the exact malicious User-Agent string,
49
+ achieving ≥0.95 detection rate with <10% false positive rate.
50
+
51
+ - id: hard
52
+ name: "SQL Injection Middleware Defense"
53
+ difficulty: hard
54
+ max_score: 1.0
55
+ description: >
56
+ An attacker probes the database via SQL injection. They rotate IP addresses
57
+ AND User-Agents on every request to evade simple rules. Every malicious
58
+ request contains a SQL injection payload in the query string. The agent
59
+ must write a regex-based middleware rule to detect and block all payloads.
60
+ success_criteria: >
61
+ write_custom_middleware action with a regex that matches 'UNION SELECT'
62
+ pattern (case-insensitive), achieving ≥0.95 detection rate with <10% FP rate.
63
+
64
+ observation_space:
65
+ type: structured
66
+ description: "Snapshot of recent HTTP traffic and active gateway configuration."
67
+ fields:
68
+ - name: recent_requests
69
+ type: "list[dict]"
70
+ description: "Last 100 HTTP requests. Each has: ip, method, path, user_agent, query_string, status_code."
71
+ - name: active_rules
72
+ type: "list[str]"
73
+ description: "Human-readable list of firewall rules currently active."
74
+ - name: current_task
75
+ type: string
76
+ description: "Task ID: 'easy', 'medium', or 'hard'."
77
+ - name: task_description
78
+ type: string
79
+ description: "Natural language description of the attack to defend against."
80
+ - name: step_count
81
+ type: integer
82
+ description: "Number of rules submitted in the current episode."
83
+ - name: hint
84
+ type: string
85
+ description: "Statistical hint about suspicious patterns in the visible traffic window."
86
+
87
+ action_space:
88
+ type: discrete_parameterized
89
+ description: "Submit one firewall rule to the gateway middleware."
90
+ fields:
91
+ - name: action_type
92
+ type: string
93
+ required: true
94
+ choices:
95
+ - block_ip
96
+ - add_rate_limit
97
+ - block_user_agent
98
+ - write_custom_middleware
99
+ description: "Which type of rule to apply."
100
+ - name: target_ip
101
+ type: string
102
+ required: false
103
+ description: "IP address. Required for block_ip and add_rate_limit."
104
+ - name: target_user_agent
105
+ type: string
106
+ required: false
107
+ description: "Exact User-Agent string. Required for block_user_agent."
108
+ - name: regex_pattern
109
+ type: string
110
+ required: false
111
+ description: "Python regex matched against '{path}?{query_string}'. Required for write_custom_middleware."
112
+ - name: max_requests
113
+ type: integer
114
+ required: false
115
+ default: 60
116
+ description: "Requests per minute cap. Used with add_rate_limit."
117
+
118
+ reward:
119
+ range: [0.0, 1.0]
120
+ type: continuous
121
+ formula: >
122
+ detection_rate = malicious_blocked / total_malicious
123
+ false_positive_rate = legitimate_blocked / total_legitimate
124
+ if false_positive_rate > 0.10:
125
+ score = 0.0
126
+ else:
127
+ score = clamp(detection_rate - false_positive_rate * 5.0, 0.0, 1.0)
128
+ description: >
129
+ Rewards accurate detection of malicious traffic. Penalises false positives
130
+ (blocking legitimate users) with a 5x multiplier. Zeroed entirely if
131
+ false positive rate exceeds 10% — models real operational constraints
132
+ where blocking paying customers is unacceptable.
133
+
134
+ episode:
135
+ max_steps: 5
136
+ termination_conditions:
137
+ - "score >= 0.95 (success)"
138
+ - "step_count >= 5 (step limit)"
139
+ reset_required: true
140
+
141
+ evaluation:
142
+ grader_type: programmatic
143
+ deterministic: true
144
+ train_seed: 42
145
+ test_seed: 137
146
+ description: >
147
+ Rules are graded against a hidden test traffic set (seed 137) distinct from
148
+ the visible training sample (seed 42). This prevents agents from overfitting
149
+ to specific IPs/UAs in the observation window.
150
+
151
+ api:
152
+ framework: FastAPI
153
+ port: 7860
154
+ endpoints:
155
+ - "POST /reset"
156
+ - "POST /step"
157
+ - "GET /state"
158
+ - "GET /tasks"
159
+ - "GET /grader"
160
+ - "POST /baseline"
161
+ - "GET /health"
162
+
163
+ baseline:
164
+ agent_type: heuristic
165
+ scores:
166
+ easy: 1.0
167
+ medium: 1.0
168
+ hard: 1.0
169
+ note: >
170
+ Heuristic agent reads the visible traffic sample, identifies the attack
171
+ pattern statistically, and applies the optimal rule. Scores are fully
172
+ reproducible with fixed seeds.