CystronCode's picture
Update env.py
5574a9a verified
Raw
History Blame Contribute Delete
25.3 kB
"""
API Gateway Defender β€” OpenEnv Environment
==========================================
A simulated HTTP traffic monitoring environment where an AI agent must
configure middleware rules to block malicious web traffic while preserving
legitimate user requests. Models a real production incident faced by SREs.
Tasks:
easy β€” Single IP volumetric flood on /login
medium β€” Scraper bot using identical User-Agent across many IPs
hard β€” SQL injection attack with rotating IPs and User-Agents
"""
import re
import random
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# ─── Constants ──────────────────────────────────────────────────────────────────
# Episode length and the shape of every generated traffic set.
MAX_STEPS = 5
TRAFFIC_SIZE = 1000
MALICIOUS_COUNT = 200
FALSE_POSITIVE_THRESHOLD = 0.10  # more than 10% of legit traffic blocked -> score floored

# Fixed attack artifacts (deterministic so graders are reproducible)
ATTACK_IP_EASY = "185.220.101.47"
SCRAPER_UA = "ScraperBot/3.1 (DataHarvester; +http://scrape.evil.io/bot)"
SQLI_KEYWORD = "UNION SELECT"

# Browser-like User-Agent strings used by benign (and some disguised) requests.
LEGIT_USER_AGENTS = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 Version/17.0 Safari/605.1.15",
    "Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/115.0",
    "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 Mobile/15E148",
    "Mozilla/5.0 (Android 13; Mobile; rv:109.0) Gecko/109.0 Firefox/115.0",
]

# Endpoints that benign requests are drawn from.
LEGIT_PATHS = [
    "/",
    "/home",
    "/about",
    "/products",
    "/contact",
    "/blog",
    "/faq",
    "/pricing",
    "/docs",
    "/api/users",
    "/api/products",
]

# Query strings that benign requests are drawn from (the empty string means "no query").
LEGIT_QUERY_STRINGS = [
    "",
    "page=1",
    "q=hello",
    "id=42",
    "sort=asc&filter=new",
    "utm_source=google",
    "lang=en",
    "category=tech",
    "ref=nav",
]

# GET is listed three times so a uniform choice is GET-heavy.
LEGIT_METHODS = ["GET"] * 3 + ["POST", "PUT"]

# Injection payloads; every one embeds SQLI_KEYWORD.
SQLI_PAYLOADS = [
    template.format(kw=SQLI_KEYWORD)
    for template in (
        "id=1' {kw} username,password FROM users--",
        "search=admin' {kw} 1,2,3--",
        "q=x' {kw} table_name FROM information_schema.tables--",
        "cat=1' {kw} NULL,NULL,NULL--",
        "item=5' {kw} version()--",
    )
]
# ─── Pydantic Models ─────────────────────────────────────────────────────────────
class Action(BaseModel):
    """
    One firewall rule submitted by the agent; each action adds a rule to the gateway.

    Supported ``action_type`` values:
        block_ip                — drop all traffic from ``target_ip``
        add_rate_limit          — allow ``target_ip`` only ``max_requests`` per minute
        block_user_agent        — drop all traffic whose User-Agent equals ``target_user_agent``
        write_custom_middleware — drop requests where ``regex_pattern`` matches ``path?query_string``
    """

    # Required discriminator selecting which of the optional fields below is used.
    action_type: str = Field(
        ...,
        description="Rule type: 'block_ip', 'add_rate_limit', 'block_user_agent', 'write_custom_middleware'",
    )

    target_ip: Optional[str] = Field(
        default=None,
        description="IP address (required for block_ip / add_rate_limit)",
    )

    target_user_agent: Optional[str] = Field(
        default=None,
        description="Exact User-Agent string (required for block_user_agent)",
    )

    regex_pattern: Optional[str] = Field(
        default=None,
        description="Python regex matched against '{path}?{query_string}' (required for write_custom_middleware)",
    )

    max_requests: Optional[int] = Field(
        default=60,
        description="Requests-per-minute cap for add_rate_limit (default 60)",
    )
class Observation(BaseModel):
    """The view of the environment handed to the agent at each step."""

    # Window of request records shown to the agent.
    recent_requests: List[Dict[str, Any]] = Field(
        ...,
        description="Last 100 HTTP requests in the traffic stream. Fields: ip, method, path, user_agent, query_string, status_code.",
    )

    # One display string per rule the agent has submitted so far.
    active_rules: List[str] = Field(
        ...,
        description="Human-readable list of rules currently active on the gateway.",
    )

    current_task: str = Field(
        ...,
        description="Task ID: 'easy', 'medium', or 'hard'",
    )

    task_description: str = Field(
        ...,
        description="Natural language description of the attack the agent must repel.",
    )

    step_count: int = Field(
        ...,
        description="Number of rules submitted so far this episode.",
    )

    # Optional; defaults to the empty string when no hint is available.
    hint: str = Field(
        default="",
        description="Statistical hint derived from the visible traffic sample.",
    )
class Reward(BaseModel):
    """Feedback returned after each step().

    ``score`` is validated to lie in [0.0, 1.0]; the counts describe how the
    active rules performed against the graded traffic set.
    """

    # The en dash in this description was mis-encoded in the schema text; repaired here.
    score: float = Field(..., ge=0.0, le=1.0, description="Task performance score 0.0–1.0")
    malicious_blocked: int = Field(..., description="Malicious requests blocked by active rules")
    legitimate_blocked: int = Field(..., description="Legitimate requests incorrectly blocked")
    # Denominators for the two counts above, documented like every other field.
    total_malicious: int = Field(..., description="Total malicious requests in the graded traffic")
    total_legitimate: int = Field(..., description="Total legitimate requests in the graded traffic")
    false_positive_rate: float = Field(..., description="Fraction of legit requests blocked")
    message: str = Field(..., description="Human-readable explanation of the score")
class StepResult(BaseModel):
    """Full return value of step().

    All four fields are required. Each carries a schema description so the
    generated JSON schema documents this model the same way it documents
    Action, Observation and Reward.
    """

    observation: Observation = Field(
        ..., description="Observation of the environment after the action was processed."
    )
    reward: Reward = Field(..., description="Reward feedback for this step.")
    done: bool = Field(..., description="True once the episode has ended.")
    info: Dict[str, Any] = Field(..., description="Diagnostic information about the step.")
class EnvironmentState(BaseModel):
    """Full serialisable snapshot returned by state().

    All fields are required. Each carries a schema description so the
    generated JSON schema documents this model the same way it documents
    Action, Observation and Reward.
    """

    task_id: str = Field(..., description="Task ID of the current episode: 'easy', 'medium', or 'hard'")
    step_count: int = Field(..., description="Number of actions submitted so far this episode.")
    active_rules: List[Dict[str, Any]] = Field(
        ..., description="Structured description of every rule currently active on the gateway."
    )
    episode_done: bool = Field(..., description="True once the episode has ended.")
    best_score: float = Field(..., description="Highest reward score achieved so far this episode.")
    traffic_sample_size: int = Field(
        ..., description="Number of requests in the agent-visible traffic set."
    )
# ─── Traffic Generators ──────────────────────────────────────────────────────────
def _rand_ip(rng: random.Random, exclude: str = "") -> str:
"""Generate a random public-looking IPv4 address."""
while True:
ip = (
f"{rng.randint(10, 220)}."
f"{rng.randint(1, 254)}."
f"{rng.randint(1, 254)}."
f"{rng.randint(1, 254)}"
)
if ip != exclude:
return ip
def _legit_request(rng: random.Random) -> Dict[str, Any]:
    """Build one benign request record using *rng*.

    The source IP is never ATTACK_IP_EASY; method, path, User-Agent and query
    string are each drawn uniformly from the corresponding LEGIT_* pool.
    """
    # Draws happen in this fixed order so seeded traffic stays reproducible.
    source_ip = _rand_ip(rng, exclude=ATTACK_IP_EASY)
    http_method = rng.choice(LEGIT_METHODS)
    endpoint = rng.choice(LEGIT_PATHS)
    agent = rng.choice(LEGIT_USER_AGENTS)
    query = rng.choice(LEGIT_QUERY_STRINGS)
    return dict(
        ip=source_ip,
        method=http_method,
        path=endpoint,
        user_agent=agent,
        query_string=query,
        status_code=200,
        is_malicious=False,
    )
def generate_easy_traffic(seed: int) -> List[Dict[str, Any]]:
    """
    Easy task: a single IP floods /login with POST requests.

    Produces MALICIOUS_COUNT flood requests from ATTACK_IP_EASY plus
    TRAFFIC_SIZE - MALICIOUS_COUNT benign requests, shuffled with a generator
    seeded by *seed*. Intended solution: block_ip or add_rate_limit on
    ATTACK_IP_EASY.
    """
    rng = random.Random(seed)
    flood = [
        {
            "ip": ATTACK_IP_EASY,
            "method": "POST",
            "path": "/login",
            # The attacker borrows ordinary browser User-Agents to blend in.
            "user_agent": rng.choice(LEGIT_USER_AGENTS),
            "query_string": "",
            "status_code": 200,
            "is_malicious": True,
        }
        for _ in range(MALICIOUS_COUNT)
    ]
    background = [_legit_request(rng) for _ in range(TRAFFIC_SIZE - MALICIOUS_COUNT)]
    stream = flood + background
    rng.shuffle(stream)
    return stream
def generate_medium_traffic(seed: int) -> List[Dict[str, Any]]:
    """
    Medium task: a pool of 50 IPs shares one unusual User-Agent, hitting /api/data.

    Produces MALICIOUS_COUNT scraper requests plus
    TRAFFIC_SIZE - MALICIOUS_COUNT benign requests, shuffled with a generator
    seeded by *seed*. Intended solution: block_user_agent with SCRAPER_UA.
    """
    rng = random.Random(seed)
    # The address pool is drawn before any request so the RNG order is fixed.
    botnet_ips = [_rand_ip(rng) for _ in range(50)]

    stream: List[Dict[str, Any]] = []
    for _ in range(MALICIOUS_COUNT):
        source_ip = rng.choice(botnet_ips)
        page_number = rng.randint(1, 500)
        stream.append(
            {
                "ip": source_ip,
                "method": "GET",
                "path": "/api/data",
                "user_agent": SCRAPER_UA,  # identical on every malicious request
                "query_string": f"page={page_number}",
                "status_code": 200,
                "is_malicious": True,
            }
        )

    stream.extend(_legit_request(rng) for _ in range(TRAFFIC_SIZE - MALICIOUS_COUNT))
    rng.shuffle(stream)
    return stream
def generate_hard_traffic(seed: int) -> List[Dict[str, Any]]:
    """
    Hard task: the attacker rotates IPs and User-Agents but always sends a SQLi payload.

    Produces MALICIOUS_COUNT injection requests plus
    TRAFFIC_SIZE - MALICIOUS_COUNT benign requests, shuffled with a generator
    seeded by *seed*. Intended solution: write_custom_middleware with a regex
    matching 'UNION.SELECT'.
    """
    rng = random.Random(seed)
    probed_paths = ("/search", "/products", "/api/items", "/catalog")

    stream: List[Dict[str, Any]] = []
    for _ in range(MALICIOUS_COUNT):
        # Fixed draw order: ip, path, user agent, payload.
        source_ip = _rand_ip(rng)
        endpoint = rng.choice(probed_paths)
        agent = rng.choice(LEGIT_USER_AGENTS)
        payload = rng.choice(SQLI_PAYLOADS)
        stream.append(
            {
                "ip": source_ip,
                "method": "GET",
                "path": endpoint,
                "user_agent": agent,
                "query_string": payload,
                "status_code": 200,
                "is_malicious": True,
            }
        )

    for _ in range(TRAFFIC_SIZE - MALICIOUS_COUNT):
        benign = _legit_request(rng)
        # Defensive guard: a benign request must never carry the injection keyword.
        if SQLI_KEYWORD in benign["query_string"].upper():
            benign["query_string"] = ""
        stream.append(benign)

    rng.shuffle(stream)
    return stream
# Registry mapping each task id to the function that builds its traffic set.
TASK_GENERATORS = dict(
    easy=generate_easy_traffic,
    medium=generate_medium_traffic,
    hard=generate_hard_traffic,
)
# Natural-language briefing shown to the agent for each task id.
# Each description is assembled from whole sentences joined by a single space.
TASK_DESCRIPTIONS = {
    "easy": " ".join(
        [
            "A single IP address is flooding your /login endpoint with POST requests at high volume.",
            "Inspect the traffic logs to identify the offending IP and block it or apply a rate limit.",
        ]
    ),
    "medium": " ".join(
        [
            "A scraper bot is harvesting your /api/data endpoint from many different IP addresses.",
            "All malicious requests share a single, unusual User-Agent string.",
            "Identify the User-Agent and block it.",
        ]
    ),
    "hard": " ".join(
        [
            "An attacker is probing your database via SQL injection.",
            "They rotate IP addresses and User-Agents to evade simple rules, but every malicious "
            "request contains a SQL injection payload in the query string.",
            "Write a regex middleware rule to detect and drop these requests.",
        ]
    ),
}
# ─── Rule Engine ─────────────────────────────────────────────────────────────────
class _Rule:
"""Internal class: wraps an Action and applies it to individual requests."""
def __init__(self, action: Action) -> None:
self.action = action
self._compiled_re = None
if action.action_type == "write_custom_middleware" and action.regex_pattern:
try:
self._compiled_re = re.compile(action.regex_pattern, re.IGNORECASE)
except re.error:
pass # invalid regex β†’ rule matches nothing
def blocks(self, request: Dict[str, Any]) -> bool:
a = self.action
if a.action_type in ("block_ip", "add_rate_limit"):
return bool(a.target_ip and request["ip"] == a.target_ip)
if a.action_type == "block_user_agent":
return bool(
a.target_user_agent
and request["user_agent"] == a.target_user_agent
)
if a.action_type == "write_custom_middleware" and self._compiled_re:
target = f"{request['path']}?{request['query_string']}"
return bool(self._compiled_re.search(target))
return False
def describe(self) -> str:
a = self.action
if a.action_type == "block_ip":
return f"BLOCK_IP({a.target_ip})"
if a.action_type == "add_rate_limit":
return f"RATE_LIMIT({a.target_ip}, max={a.max_requests}/min)"
if a.action_type == "block_user_agent":
return f"BLOCK_UA({a.target_user_agent!r})"
if a.action_type == "write_custom_middleware":
return f"MIDDLEWARE(regex={a.regex_pattern!r})"
return f"RULE({a.action_type})"
def to_dict(self) -> Dict[str, Any]:
a = self.action
return {
"action_type": a.action_type,
"target_ip": a.target_ip,
"target_user_agent": a.target_user_agent,
"regex_pattern": a.regex_pattern,
"description": self.describe(),
}
# ─── Environment ─────────────────────────────────────────────────────────────────
# Every action_type accepted by the environment; anything else is rejected in step().
VALID_ACTION_TYPES = {
    "block_ip",
    "add_rate_limit",
    "block_user_agent",
    "write_custom_middleware",
}
class APIGatewayDefender:
    """
    OpenEnv-compliant RL environment — API Gateway Defender.

    The agent monitors a simulated stream of HTTP requests and must apply
    firewall middleware rules to block malicious traffic while preserving
    legitimate requests.

    Usage
    -----
        env = APIGatewayDefender()
        obs = env.reset(task_id="easy")
        action = Action(action_type="block_ip", target_ip="185.220.101.47")
        result = env.step(action)
        print(result.reward.score)
    """

    # Validator requires scores strictly between 0 and 1 (exclusive)
    _SCORE_MIN = 0.001
    _SCORE_MAX = 0.999

    def __init__(self) -> None:
        """Create an environment with no traffic loaded; call reset() before stepping."""
        self._task_id: str = "easy"
        self._rules: List[_Rule] = []
        self._train_traffic: List[Dict[str, Any]] = []
        self._test_traffic: List[Dict[str, Any]] = []
        self._step_count: int = 0
        self._done: bool = False
        self._best_score: float = 0.0

    # ── OpenEnv Interface ──────────────────────────────────────────────────────

    def reset(self, task_id: str = "easy") -> Observation:
        """
        Start a new episode on the given task.

        Parameters
        ----------
        task_id : str
            One of 'easy', 'medium', 'hard'.

        Returns
        -------
        Observation
            Initial observation containing the first 100 traffic samples.

        Raises
        ------
        ValueError
            If ``task_id`` is not a known task.
        """
        if task_id not in TASK_GENERATORS:
            raise ValueError(
                f"Unknown task_id '{task_id}'. Choose from: {sorted(TASK_GENERATORS)}"
            )
        self._task_id = task_id
        self._rules = []
        self._step_count = 0
        self._done = False
        self._best_score = 0.0
        gen = TASK_GENERATORS[task_id]
        self._train_traffic = gen(seed=42)   # agent can see this
        self._test_traffic = gen(seed=137)   # grading set (hidden from agent)
        return self._make_observation()

    def step(self, action: Action) -> StepResult:
        """
        Submit one firewall rule and receive a reward signal.

        The rule is evaluated against a hidden test traffic set to prevent
        overfitting to the visible sample. Partial credit is awarded for
        partial detection; false positives incur a penalty.

        An action with an unknown ``action_type`` adds no rule but still
        consumes a step. It returns the minimum score and, like any other
        step, ends the episode once MAX_STEPS is reached.

        Parameters
        ----------
        action : Action
            The rule to apply.

        Returns
        -------
        StepResult
            observation, reward, done flag, and diagnostic info.

        Raises
        ------
        RuntimeError
            If the episode has already finished.
        """
        if self._done:
            raise RuntimeError("Episode is finished. Call reset() to start a new episode.")

        self._step_count += 1

        # ── Validate action type ──────────────────────────────────────────────
        if action.action_type not in VALID_ACTION_TYPES:
            # The step was counted above, so the step budget must apply here
            # too; otherwise repeated invalid actions would never terminate.
            self._done = self._step_count >= MAX_STEPS
            err_reward = Reward(
                # Floor rather than 0.0: scores must stay strictly inside (0, 1).
                score=self._SCORE_MIN,
                malicious_blocked=0,
                legitimate_blocked=0,
                total_malicious=MALICIOUS_COUNT,
                total_legitimate=TRAFFIC_SIZE - MALICIOUS_COUNT,
                false_positive_rate=0.0,
                message=(
                    f"Invalid action_type '{action.action_type}'. "
                    f"Must be one of {sorted(VALID_ACTION_TYPES)}."
                ),
            )
            return StepResult(
                observation=self._make_observation(),
                reward=err_reward,
                done=self._done,
                info={
                    "error": "invalid_action_type",
                    "step": self._step_count,
                    "max_steps": MAX_STEPS,
                },
            )

        # ── Apply rule ────────────────────────────────────────────────────────
        self._rules.append(_Rule(action))

        # ── Grade on hidden test traffic ──────────────────────────────────────
        reward = self._grade()
        self._best_score = max(self._best_score, reward.score)

        # Episode ends at MAX_STEPS or when the agent achieves near-perfect score
        self._done = self._step_count >= MAX_STEPS or reward.score >= 0.95

        return StepResult(
            observation=self._make_observation(),
            reward=reward,
            done=self._done,
            info={
                "step": self._step_count,
                "best_score": self._best_score,
                "rules_applied": [r.describe() for r in self._rules],
                "max_steps": MAX_STEPS,
            },
        )

    def state(self) -> EnvironmentState:
        """Return a full serialisable snapshot of the current environment state."""
        return EnvironmentState(
            task_id=self._task_id,
            step_count=self._step_count,
            active_rules=[r.to_dict() for r in self._rules],
            episode_done=self._done,
            best_score=self._best_score,
            traffic_sample_size=len(self._train_traffic),
        )

    def get_task_grader_score(self) -> float:
        """
        Programmatic grader — returns score strictly in (0, 1) for the current episode.

        Returns the minimum non-zero score if no rules have been applied yet.
        """
        if not self._rules:
            return self._SCORE_MIN
        return self._grade().score

    # ── Private Helpers ────────────────────────────────────────────────────────

    def _make_observation(self) -> Observation:
        """Build an Observation from the current state (no is_malicious flag exposed)."""
        visible = [
            {k: v for k, v in req.items() if k != "is_malicious"}
            for req in self._train_traffic[:100]
        ]
        return Observation(
            recent_requests=visible,
            active_rules=[r.describe() for r in self._rules],
            current_task=self._task_id,
            task_description=TASK_DESCRIPTIONS[self._task_id],
            step_count=self._step_count,
            hint=self._build_hint(),
        )

    def _build_hint(self) -> str:
        """Generate a statistical hint from the visible traffic sample.

        Returns an empty string when no traffic is loaded, and a "looks normal"
        message when the first 100 visible requests contain no malicious ones.
        """
        if not self._train_traffic:
            return ""
        sample = self._train_traffic[:100]
        malicious_in_sample = [r for r in sample if r.get("is_malicious")]
        n = len(malicious_in_sample)
        # Same message for every task when the window holds no attack traffic.
        if n == 0:
            return "Traffic looks normal in this window."
        if self._task_id == "easy":
            ips = {r["ip"] for r in malicious_in_sample}
            return (
                f"Warning: {n} POST requests to /login detected in this window "
                f"from {len(ips)} unique IP(s). Possible brute-force or flood."
            )
        if self._task_id == "medium":
            uas = {r["user_agent"] for r in malicious_in_sample}
            return (
                f"Warning: {n} requests to /api/data share {len(uas)} unique User-Agent(s) "
                f"in this window. Possible scraper activity."
            )
        return (
            f"Warning: {n} requests in this window contain unusual query string patterns. "
            f"Check for injection payloads."
        )

    def _grade(self) -> Reward:
        """
        Apply all active rules to the hidden test traffic set and compute a score.

        Score formula:
            detection_rate = malicious_blocked / total_malicious
            fp_rate        = legitimate_blocked / total_legitimate
            if fp_rate > FALSE_POSITIVE_THRESHOLD:
                score = _SCORE_MIN          <- too many false positives
            else:
                score = clamp(detection_rate - fp_rate * 5.0, _SCORE_MIN, _SCORE_MAX)

        The final score is always strictly in (0, 1) as required by the validator.
        """
        total_mal = total_legit = 0
        mal_blocked = legit_blocked = 0
        # One pass over the traffic: classify each request and test it once.
        for req in self._test_traffic:
            blocked = any(rule.blocks(req) for rule in self._rules)
            if req["is_malicious"]:
                total_mal += 1
                mal_blocked += int(blocked)
            else:
                total_legit += 1
                legit_blocked += int(blocked)

        # Guard the divisions so grading before reset() does not raise.
        detection_rate = mal_blocked / total_mal if total_mal > 0 else 0.0
        fp_rate = legit_blocked / total_legit if total_legit > 0 else 0.0

        if fp_rate > FALSE_POSITIVE_THRESHOLD:
            score = self._SCORE_MIN
            message = (
                f"Score floored: {fp_rate:.1%} false positive rate exceeds "
                f"{FALSE_POSITIVE_THRESHOLD:.0%} threshold. Rules are too broad — "
                f"legitimate users are being blocked."
            )
        else:
            raw = detection_rate - fp_rate * 5.0
            score = max(self._SCORE_MIN, min(self._SCORE_MAX, raw))
            message = (
                f"Blocked {mal_blocked}/{total_mal} malicious requests "
                f"({detection_rate:.1%} detection rate) with "
                f"{fp_rate:.1%} false positive rate."
            )

        return Reward(
            score=round(score, 4),
            malicious_blocked=mal_blocked,
            legitimate_blocked=legit_blocked,
            total_malicious=total_mal,
            total_legitimate=total_legit,
            false_positive_rate=round(fp_rate, 4),
            message=message,
        )
# ─── Convenience: heuristic baseline that runs directly on the class ────────────
def run_heuristic_baseline() -> Dict[str, float]:
    """
    Run a deterministic heuristic agent over all three tasks, one rule per task.

    Used by the /baseline endpoint and as fallback in the inference script.

    Returns
    -------
    Dict[str, float]
        task_id -> score
    """
    # Local import keeps the module's top-level imports unchanged.
    from collections import Counter

    env = APIGatewayDefender()
    scores: Dict[str, float] = {}

    # ── Easy: identify the IP flooding /login ──────────────────────────────────
    obs = env.reset("easy")
    login_hits = Counter(
        req["ip"]
        for req in obs.recent_requests
        if req["path"] == "/login" and req["method"] == "POST"
    )
    # max() returns the first-seen IP on ties; fall back to the known attacker
    # when the visible window contains no login POSTs at all.
    suspect_ip = max(login_hits, key=login_hits.get) if login_hits else ATTACK_IP_EASY
    result = env.step(Action(action_type="block_ip", target_ip=suspect_ip))
    scores["easy"] = result.reward.score

    # ── Medium: identify the unusual User-Agent ────────────────────────────────
    obs = env.reset("medium")
    ua_counts = Counter(req["user_agent"] for req in obs.recent_requests)
    # Most frequent first; equal counts keep first-seen order.
    ranked_uas = [ua for ua, _ in ua_counts.most_common()]
    bot_keywords = {"scraper", "bot", "crawler", "spider", "harvester"}
    browser_keywords = {"mozilla", "chrome", "safari", "firefox", "gecko", "webkit"}

    # Prefer UAs that look like bots
    suspect_ua = next(
        (ua for ua in ranked_uas if any(k in ua.lower() for k in bot_keywords)),
        None,
    )
    # Fallback: most common UA that doesn't look like a browser
    if not suspect_ua:
        suspect_ua = next(
            (ua for ua in ranked_uas if not any(k in ua.lower() for k in browser_keywords)),
            None,
        )
    # An empty target yields a rule that matches nothing.
    result = env.step(Action(action_type="block_user_agent", target_user_agent=suspect_ua or ""))
    scores["medium"] = result.reward.score

    # ── Hard: write a regex to catch SQLi payloads ────────────────────────────
    env.reset("hard")
    result = env.step(
        Action(
            action_type="write_custom_middleware",
            regex_pattern=r"UNION\s+SELECT",
        )
    )
    scores["hard"] = result.reward.score

    return scores