Spaces:
Sleeping
Sleeping
| """ | |
API Gateway Defender — OpenEnv Environment
==========================================
A simulated HTTP traffic monitoring environment where an AI agent must
configure middleware rules to block malicious web traffic while preserving
legitimate user requests. Models a real production incident faced by SREs.

Tasks:
    easy   — Single IP volumetric flood on /login
    medium — Scraper bot using identical User-Agent across many IPs
    hard   — SQL injection attack with rotating IPs and User-Agents
| """ | |
| import re | |
| import random | |
| from typing import Any, Dict, List, Optional | |
| from pydantic import BaseModel, Field | |
# ─── Constants ────────────────────────────────────────────────────────────────

# Episode and traffic sizing.
MAX_STEPS = 5                    # rules an agent may submit before the episode ends
TRAFFIC_SIZE = 1000              # total requests in each generated traffic set
MALICIOUS_COUNT = 200            # malicious requests inside each traffic set
FALSE_POSITIVE_THRESHOLD = 0.10  # a false-positive rate above 10% floors the score

# Fixed attack artifacts (kept constant so grading is reproducible).
ATTACK_IP_EASY = "185.220.101.47"
SCRAPER_UA = "ScraperBot/3.1 (DataHarvester; +http://scrape.evil.io/bot)"
SQLI_KEYWORD = "UNION SELECT"

# Pools that benign requests are sampled from.
LEGIT_USER_AGENTS = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 Version/17.0 Safari/605.1.15",
    "Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/115.0",
    "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 Mobile/15E148",
    "Mozilla/5.0 (Android 13; Mobile; rv:109.0) Gecko/109.0 Firefox/115.0",
]
LEGIT_PATHS = [
    "/",
    "/home",
    "/about",
    "/products",
    "/contact",
    "/blog",
    "/faq",
    "/pricing",
    "/docs",
    "/api/users",
    "/api/products",
]
LEGIT_QUERY_STRINGS = [
    "",
    "page=1",
    "q=hello",
    "id=42",
    "sort=asc&filter=new",
    "utm_source=google",
    "lang=en",
    "category=tech",
    "ref=nav",
]
# GET appears three times so uniform sampling yields a GET-heavy, realistic mix.
LEGIT_METHODS = ["GET"] * 3 + ["POST", "PUT"]

# Every payload embeds SQLI_KEYWORD, which is the signature the hard task relies on.
SQLI_PAYLOADS = [
    template.format(kw=SQLI_KEYWORD)
    for template in (
        "id=1' {kw} username,password FROM users--",
        "search=admin' {kw} 1,2,3--",
        "q=x' {kw} table_name FROM information_schema.tables--",
        "cat=1' {kw} NULL,NULL,NULL--",
        "item=5' {kw} version()--",
    )
]
# ─── Pydantic Models ──────────────────────────────────────────────────────────

class Action(BaseModel):
    """
    One firewall rule submitted by the agent; each action adds a rule to the gateway.

    Supported ``action_type`` values:
        block_ip                → Drop all traffic from target_ip
        add_rate_limit          → Allow target_ip only max_requests/min
        block_user_agent        → Drop all traffic matching target_user_agent exactly
        write_custom_middleware → Drop requests where regex_pattern matches path?query_string
    """

    # Kept as a free-form string: unknown values are rejected by the
    # environment's step() with an error reward rather than at construction.
    action_type: str = Field(
        default=...,
        description=(
            "Rule type: 'block_ip', 'add_rate_limit', 'block_user_agent', "
            "'write_custom_middleware'"
        ),
    )
    target_ip: Optional[str] = Field(
        default=None,
        description="IP address (required for block_ip / add_rate_limit)",
    )
    target_user_agent: Optional[str] = Field(
        default=None,
        description="Exact User-Agent string (required for block_user_agent)",
    )
    regex_pattern: Optional[str] = Field(
        default=None,
        description=(
            "Python regex matched against '{path}?{query_string}' (required for "
            "write_custom_middleware)"
        ),
    )
    # NOTE: the rule engine in this file matches add_rate_limit purely on the
    # IP, exactly like block_ip; max_requests only shows up in the rule text.
    max_requests: Optional[int] = Field(
        default=60,
        description="Requests-per-minute cap for add_rate_limit (default 60)",
    )
class Observation(BaseModel):
    """The agent's view of the environment at a given step."""

    # Request dicts from the visible traffic sample.
    recent_requests: List[Dict[str, Any]] = Field(
        default=...,
        description=(
            "Last 100 HTTP requests in the traffic stream. Fields: "
            "ip, method, path, user_agent, query_string, status_code."
        ),
    )
    active_rules: List[str] = Field(
        default=...,
        description="Human-readable list of rules currently active on the gateway.",
    )
    current_task: str = Field(
        default=...,
        description="Task ID: 'easy', 'medium', or 'hard'",
    )
    task_description: str = Field(
        default=...,
        description="Natural language description of the attack the agent must repel.",
    )
    step_count: int = Field(
        default=...,
        description="Number of rules submitted so far this episode.",
    )
    hint: str = Field(
        default="",
        description="Statistical hint derived from the visible traffic sample.",
    )
class Reward(BaseModel):
    """Feedback returned after each step()."""

    # The range separator in this description was a mis-encoded character;
    # it is restored to an en dash so the generated schema text reads correctly.
    score: float = Field(..., ge=0.0, le=1.0, description="Task performance score 0.0–1.0")
    malicious_blocked: int = Field(..., description="Malicious requests blocked by active rules")
    legitimate_blocked: int = Field(..., description="Legitimate requests incorrectly blocked")
    # Denominators for the two counters above.
    total_malicious: int = Field(..., description="Malicious requests in the graded traffic set")
    total_legitimate: int = Field(..., description="Legitimate requests in the graded traffic set")
    false_positive_rate: float = Field(..., description="Fraction of legit requests blocked")
    message: str = Field(..., description="Human-readable explanation of the score")
class StepResult(BaseModel):
    """Bundle of everything a single step() call hands back."""

    # What the agent sees after the submitted rule has been processed.
    observation: Observation
    # Grading feedback for the rule set currently in force.
    reward: Reward
    # True once the episode is over; step() raises if called again before reset().
    done: bool
    # Diagnostic extras (step counter, best score, applied rules, or an error tag).
    info: Dict[str, Any]
class EnvironmentState(BaseModel):
    """Serialisable snapshot of the environment, as produced by state()."""

    # Identifier of the task being played ('easy', 'medium' or 'hard').
    task_id: str
    # Number of step() calls made in the current episode.
    step_count: int
    # One dict per active rule, in the order the rules were submitted.
    active_rules: List[Dict[str, Any]]
    # Whether the current episode has finished.
    episode_done: bool
    # Highest score achieved so far in the current episode.
    best_score: float
    # Size of the agent-visible traffic set.
    traffic_sample_size: int
# ─── Traffic Generators ───────────────────────────────────────────────────────

def _rand_ip(rng: random.Random, exclude: str = "") -> str:
    """
    Draw a random public-looking dotted-quad IPv4 address from ``rng``.

    The first octet is sampled from 10..220 and the remaining three from
    1..254. Addresses equal to ``exclude`` are discarded and redrawn, so the
    returned value never matches it.
    """
    octet_ranges = ((10, 220), (1, 254), (1, 254), (1, 254))
    while True:
        # Octets are drawn left to right, one randint call each, so the RNG
        # stream is consumed in a fixed order.
        candidate = ".".join(str(rng.randint(low, high)) for low, high in octet_ranges)
        if candidate == exclude:
            continue
        return candidate
def _legit_request(rng: random.Random) -> Dict[str, Any]:
    """
    Build one benign request record sampled from the LEGIT_* pools.

    The source IP is never ATTACK_IP_EASY, the status code is always 200 and
    the record is labelled ``is_malicious=False``.
    """
    # Draw order (ip, method, path, user agent, query string) is kept fixed so
    # a given seed always produces the same traffic.
    source_ip = _rand_ip(rng, exclude=ATTACK_IP_EASY)
    http_method = rng.choice(LEGIT_METHODS)
    url_path = rng.choice(LEGIT_PATHS)
    agent = rng.choice(LEGIT_USER_AGENTS)
    query = rng.choice(LEGIT_QUERY_STRINGS)
    return {
        "ip": source_ip,
        "method": http_method,
        "path": url_path,
        "user_agent": agent,
        "query_string": query,
        "status_code": 200,
        "is_malicious": False,
    }
def generate_easy_traffic(seed: int) -> List[Dict[str, Any]]:
    """
    Build the 'easy' traffic set: one IP flooding /login with POST requests.

    MALICIOUS_COUNT flood requests from ATTACK_IP_EASY are mixed with benign
    requests up to TRAFFIC_SIZE and shuffled; ``seed`` makes the result
    reproducible. The intended counter-measure is block_ip or add_rate_limit
    on ATTACK_IP_EASY.
    """
    rng = random.Random(seed)

    # The attacker borrows ordinary browser User-Agents, so only the IP gives it away.
    flood: List[Dict[str, Any]] = [
        {
            "ip": ATTACK_IP_EASY,
            "method": "POST",
            "path": "/login",
            "user_agent": rng.choice(LEGIT_USER_AGENTS),
            "query_string": "",
            "status_code": 200,
            "is_malicious": True,
        }
        for _ in range(MALICIOUS_COUNT)
    ]
    background = [_legit_request(rng) for _ in range(TRAFFIC_SIZE - MALICIOUS_COUNT)]

    traffic = flood + background
    rng.shuffle(traffic)
    return traffic
def generate_medium_traffic(seed: int) -> List[Dict[str, Any]]:
    """
    Build the 'medium' traffic set: a scraper spread over 50 IPs hitting /api/data.

    Every malicious request carries the same SCRAPER_UA User-Agent, so the
    intended counter-measure is block_user_agent with SCRAPER_UA. ``seed``
    makes the result reproducible.
    """
    rng = random.Random(seed)
    bot_pool = [_rand_ip(rng) for _ in range(50)]

    traffic: List[Dict[str, Any]] = []
    for _ in range(MALICIOUS_COUNT):
        # Pick the IP before the page number to keep the RNG draw order fixed.
        origin = rng.choice(bot_pool)
        page_number = rng.randint(1, 500)
        traffic.append(
            {
                "ip": origin,
                "method": "GET",
                "path": "/api/data",
                "user_agent": SCRAPER_UA,  # identical on every malicious request
                "query_string": f"page={page_number}",
                "status_code": 200,
                "is_malicious": True,
            }
        )
    traffic.extend(_legit_request(rng) for _ in range(TRAFFIC_SIZE - MALICIOUS_COUNT))

    rng.shuffle(traffic)
    return traffic
def generate_hard_traffic(seed: int) -> List[Dict[str, Any]]:
    """
    Build the 'hard' traffic set: SQL injection with rotating IPs and User-Agents.

    Each malicious request uses a fresh random IP and an ordinary browser
    User-Agent, but its query string is always one of SQLI_PAYLOADS. The
    intended counter-measure is write_custom_middleware with a regex matching
    'UNION.SELECT'. ``seed`` makes the result reproducible.
    """
    rng = random.Random(seed)
    targeted_paths = ["/search", "/products", "/api/items", "/catalog"]

    traffic: List[Dict[str, Any]] = []
    for _ in range(MALICIOUS_COUNT):
        # Draw order (ip, path, user agent, payload) is kept fixed for reproducibility.
        attacker_ip = _rand_ip(rng)
        probed_path = rng.choice(targeted_paths)
        disguise = rng.choice(LEGIT_USER_AGENTS)
        payload = rng.choice(SQLI_PAYLOADS)
        traffic.append(
            {
                "ip": attacker_ip,
                "method": "GET",
                "path": probed_path,
                "user_agent": disguise,
                "query_string": payload,
                "status_code": 200,
                "is_malicious": True,
            }
        )

    for _ in range(TRAFFIC_SIZE - MALICIOUS_COUNT):
        benign = _legit_request(rng)
        # Safety net: a benign request must never carry the injection keyword.
        if SQLI_KEYWORD in benign["query_string"].upper():
            benign["query_string"] = ""
        traffic.append(benign)

    rng.shuffle(traffic)
    return traffic
# Maps each task id to the function that builds its traffic set from a seed.
TASK_GENERATORS = dict(
    easy=generate_easy_traffic,
    medium=generate_medium_traffic,
    hard=generate_hard_traffic,
)
# Natural-language briefing shown to the agent for each task id.
TASK_DESCRIPTIONS = {
    "easy": " ".join(
        (
            "A single IP address is flooding your /login endpoint with POST requests at high volume.",
            "Inspect the traffic logs to identify the offending IP and block it or apply a rate limit.",
        )
    ),
    "medium": " ".join(
        (
            "A scraper bot is harvesting your /api/data endpoint from many different IP addresses.",
            "All malicious requests share a single, unusual User-Agent string.",
            "Identify the User-Agent and block it.",
        )
    ),
    "hard": " ".join(
        (
            "An attacker is probing your database via SQL injection.",
            "They rotate IP addresses and User-Agents to evade simple rules, "
            "but every malicious request contains a SQL injection payload in the query string.",
            "Write a regex middleware rule to detect and drop these requests.",
        )
    ),
}
# ─── Rule Engine ──────────────────────────────────────────────────────────────

class _Rule:
    """
    Internal wrapper that applies one ``Action`` to individual request dicts.

    Attributes
    ----------
    action
        The action this rule was built from.
    regex_error
        Message from the regex compiler when a ``write_custom_middleware``
        pattern is invalid, otherwise ``None``. A rule with an invalid pattern
        is kept but never blocks anything.
    """

    def __init__(self, action: "Action") -> None:
        self.action = action
        self._compiled_re = None
        self.regex_error: Optional[str] = None
        if action.action_type == "write_custom_middleware" and action.regex_pattern:
            try:
                # Matching is case-insensitive so payload casing cannot evade the rule.
                self._compiled_re = re.compile(action.regex_pattern, re.IGNORECASE)
            except re.error as exc:
                # Deliberately best-effort: the rule stays in place and matches
                # nothing, but the reason is kept instead of being discarded.
                self.regex_error = str(exc)

    def blocks(self, request: Dict[str, Any]) -> bool:
        """
        Return True when this rule would drop ``request``.

        ``request`` must provide the key the rule type inspects: ``ip`` for
        block_ip / add_rate_limit, ``user_agent`` for block_user_agent, and
        ``path`` plus ``query_string`` for write_custom_middleware.
        """
        a = self.action
        # add_rate_limit is matched on the IP alone, exactly like block_ip;
        # max_requests plays no part in the decision.
        if a.action_type in ("block_ip", "add_rate_limit"):
            return bool(a.target_ip and request["ip"] == a.target_ip)
        if a.action_type == "block_user_agent":
            return bool(
                a.target_user_agent
                and request["user_agent"] == a.target_user_agent
            )
        if a.action_type == "write_custom_middleware" and self._compiled_re:
            target = f"{request['path']}?{request['query_string']}"
            return bool(self._compiled_re.search(target))
        return False

    def describe(self) -> str:
        """Return a short human-readable label for this rule."""
        a = self.action
        if a.action_type == "block_ip":
            return f"BLOCK_IP({a.target_ip})"
        if a.action_type == "add_rate_limit":
            return f"RATE_LIMIT({a.target_ip}, max={a.max_requests}/min)"
        if a.action_type == "block_user_agent":
            return f"BLOCK_UA({a.target_user_agent!r})"
        if a.action_type == "write_custom_middleware":
            return f"MIDDLEWARE(regex={a.regex_pattern!r})"
        return f"RULE({a.action_type})"

    def to_dict(self) -> Dict[str, Any]:
        """
        Return a plain-dict snapshot of the rule.

        The original five keys are unchanged. ``max_requests`` is included so
        a rate-limit rule keeps its cap when serialised, and ``regex_error``
        reports an invalid middleware pattern (``None`` otherwise).
        """
        a = self.action
        return {
            "action_type": a.action_type,
            "target_ip": a.target_ip,
            "target_user_agent": a.target_user_agent,
            "regex_pattern": a.regex_pattern,
            "description": self.describe(),
            "max_requests": a.max_requests,
            "regex_error": self.regex_error,
        }
# ─── Environment ──────────────────────────────────────────────────────────────

# Action types accepted by step(); anything else yields an error reward.
VALID_ACTION_TYPES = {
    "block_ip",
    "add_rate_limit",
    "block_user_agent",
    "write_custom_middleware",
}
class APIGatewayDefender:
    """
    OpenEnv-compliant RL environment — API Gateway Defender.

    The agent monitors a simulated stream of HTTP requests and must apply
    firewall middleware rules to block malicious traffic while preserving
    legitimate requests.

    Usage
    -----
        env = APIGatewayDefender()
        obs = env.reset(task_id="easy")
        action = Action(action_type="block_ip", target_ip="185.220.101.47")
        result = env.step(action)
        print(result.reward.score)
    """

    # Validator requires scores strictly between 0 and 1 (exclusive)
    _SCORE_MIN = 0.001
    _SCORE_MAX = 0.999

    # Number of requests from the visible traffic set exposed per observation.
    _VISIBLE_WINDOW = 100

    def __init__(self) -> None:
        self._task_id: str = "easy"
        self._rules: List[_Rule] = []
        self._train_traffic: List[Dict[str, Any]] = []
        self._test_traffic: List[Dict[str, Any]] = []
        self._step_count: int = 0
        self._done: bool = False
        self._best_score: float = 0.0

    # ── OpenEnv Interface ─────────────────────────────────────────────────────

    def reset(self, task_id: str = "easy") -> Observation:
        """
        Start a new episode on the given task.

        Parameters
        ----------
        task_id : str
            One of 'easy', 'medium', 'hard'.

        Returns
        -------
        Observation
            Initial observation containing the first 100 traffic samples.

        Raises
        ------
        ValueError
            If ``task_id`` is not a known task.
        """
        if task_id not in TASK_GENERATORS:
            raise ValueError(
                f"Unknown task_id '{task_id}'. Choose from: {sorted(TASK_GENERATORS)}"
            )
        self._task_id = task_id
        self._rules = []
        self._step_count = 0
        self._done = False
        self._best_score = 0.0
        gen = TASK_GENERATORS[task_id]
        # Fixed seeds keep both sets reproducible across episodes.
        self._train_traffic = gen(seed=42)   # agent can see this
        self._test_traffic = gen(seed=137)   # grading set (hidden from agent)
        return self._make_observation()

    def step(self, action: Action) -> StepResult:
        """
        Submit one firewall rule and receive a reward signal.

        The rule is evaluated against a hidden test traffic set to prevent
        overfitting to the visible sample. Partial credit is awarded for
        partial detection; false positives incur a penalty.

        An action with an unknown ``action_type`` adds no rule but still
        consumes a step, so it counts towards MAX_STEPS and can end the
        episode.

        Parameters
        ----------
        action : Action
            The rule to apply.

        Returns
        -------
        StepResult
            observation, reward, done flag, and diagnostic info.

        Raises
        ------
        RuntimeError
            If the episode has already finished.
        """
        if self._done:
            raise RuntimeError("Episode is finished. Call reset() to start a new episode.")
        self._step_count += 1

        # ── Validate action type ──────────────────────────────────────────────
        if action.action_type not in VALID_ACTION_TYPES:
            # The step was consumed, so the step budget must be enforced here
            # too; otherwise a stream of invalid actions never terminates and
            # step_count grows past MAX_STEPS.
            self._done = self._step_count >= MAX_STEPS
            err_reward = Reward(
                # Floor value rather than 0.0 so every score this environment
                # emits stays strictly inside (0, 1), like _grade().
                score=self._SCORE_MIN,
                malicious_blocked=0,
                legitimate_blocked=0,
                total_malicious=MALICIOUS_COUNT,
                total_legitimate=TRAFFIC_SIZE - MALICIOUS_COUNT,
                false_positive_rate=0.0,
                message=(
                    f"Invalid action_type '{action.action_type}'. "
                    f"Must be one of {sorted(VALID_ACTION_TYPES)}."
                ),
            )
            return StepResult(
                observation=self._make_observation(),
                reward=err_reward,
                done=self._done,
                info={
                    "error": "invalid_action_type",
                    "step": self._step_count,
                    "best_score": self._best_score,
                    "rules_applied": [r.describe() for r in self._rules],
                    "max_steps": MAX_STEPS,
                },
            )

        # ── Apply rule ────────────────────────────────────────────────────────
        self._rules.append(_Rule(action))

        # ── Grade on hidden test traffic ──────────────────────────────────────
        reward = self._grade()
        self._best_score = max(self._best_score, reward.score)

        # Episode ends at MAX_STEPS or when the agent achieves near-perfect score
        self._done = self._step_count >= MAX_STEPS or reward.score >= 0.95
        return StepResult(
            observation=self._make_observation(),
            reward=reward,
            done=self._done,
            info={
                "step": self._step_count,
                "best_score": self._best_score,
                "rules_applied": [r.describe() for r in self._rules],
                "max_steps": MAX_STEPS,
            },
        )

    def state(self) -> EnvironmentState:
        """Return a full serialisable snapshot of the current environment state."""
        return EnvironmentState(
            task_id=self._task_id,
            step_count=self._step_count,
            active_rules=[r.to_dict() for r in self._rules],
            episode_done=self._done,
            best_score=self._best_score,
            traffic_sample_size=len(self._train_traffic),
        )

    def get_task_grader_score(self) -> float:
        """
        Programmatic grader — returns score strictly in (0, 1) for the current episode.

        Returns the minimum non-zero score if no rules have been applied yet.
        """
        if not self._rules:
            return self._SCORE_MIN
        return self._grade().score

    # ── Private Helpers ───────────────────────────────────────────────────────

    def _make_observation(self) -> Observation:
        """Build an Observation from the current state (no is_malicious flag exposed)."""
        visible = [
            {k: v for k, v in req.items() if k != "is_malicious"}
            for req in self._train_traffic[: self._VISIBLE_WINDOW]
        ]
        return Observation(
            recent_requests=visible,
            active_rules=[r.describe() for r in self._rules],
            current_task=self._task_id,
            task_description=TASK_DESCRIPTIONS[self._task_id],
            step_count=self._step_count,
            hint=self._build_hint(),
        )

    def _build_hint(self) -> str:
        """
        Generate a statistical hint from the visible traffic sample.

        Only aggregate counts derived from the malicious requests in the
        visible window are revealed, never the per-request labels. Returns an
        empty string when no traffic has been generated yet.
        """
        if not self._train_traffic:
            return ""
        sample = self._train_traffic[: self._VISIBLE_WINDOW]
        malicious_in_sample = [r for r in sample if r.get("is_malicious")]
        n = len(malicious_in_sample)
        if n == 0:
            return "Traffic looks normal in this window."

        if self._task_id == "easy":
            ips = {r["ip"] for r in malicious_in_sample}
            return (
                f"Warning: {n} POST requests to /login detected in this window "
                f"from {len(ips)} unique IP(s). Possible brute-force or flood."
            )
        if self._task_id == "medium":
            uas = {r["user_agent"] for r in malicious_in_sample}
            return (
                f"Warning: {n} requests to /api/data share {len(uas)} unique User-Agent(s) "
                f"in this window. Possible scraper activity."
            )
        return (
            f"Warning: {n} requests in this window contain unusual query string patterns. "
            f"Check for injection payloads."
        )

    def _grade(self) -> Reward:
        """
        Apply all active rules to the hidden test traffic set and compute a score.

        Score formula:
            detection_rate = malicious_blocked / total_malicious
            fp_rate        = legitimate_blocked / total_legitimate
            if fp_rate > FALSE_POSITIVE_THRESHOLD:
                score = _SCORE_MIN   (too many false positives)
            else:
                score = clamp(detection_rate - fp_rate * 5.0, _SCORE_MIN, _SCORE_MAX)

        The final score is always strictly in (0, 1) as required by the validator.
        """
        total_mal = total_legit = 0
        mal_blocked = legit_blocked = 0
        # Single pass: a request counts as blocked when any active rule matches it.
        for req in self._test_traffic:
            blocked = any(rule.blocks(req) for rule in self._rules)
            if req["is_malicious"]:
                total_mal += 1
                mal_blocked += int(blocked)
            else:
                total_legit += 1
                legit_blocked += int(blocked)

        # Guard the divisions so grading before reset() does not raise.
        detection_rate = mal_blocked / total_mal if total_mal > 0 else 0.0
        fp_rate = legit_blocked / total_legit if total_legit > 0 else 0.0

        if fp_rate > FALSE_POSITIVE_THRESHOLD:
            score = self._SCORE_MIN
            message = (
                f"Score floored: {fp_rate:.1%} false positive rate exceeds "
                f"{FALSE_POSITIVE_THRESHOLD:.0%} threshold. Rules are too broad — "
                f"legitimate users are being blocked."
            )
        else:
            raw = detection_rate - fp_rate * 5.0
            score = max(self._SCORE_MIN, min(self._SCORE_MAX, raw))
            message = (
                f"Blocked {mal_blocked}/{total_mal} malicious requests "
                f"({detection_rate:.1%} detection rate) with "
                f"{fp_rate:.1%} false positive rate."
            )
        return Reward(
            score=round(score, 4),
            malicious_blocked=mal_blocked,
            legitimate_blocked=legit_blocked,
            total_malicious=total_mal,
            total_legitimate=total_legit,
            false_positive_rate=round(fp_rate, 4),
            message=message,
        )
# ─── Convenience: heuristic baseline that runs directly on the class ──────────

def run_heuristic_baseline() -> Dict[str, float]:
    """
    Run a deterministic heuristic agent once on each of the three tasks.

    Used by the /baseline endpoint and as fallback in the inference script.

    Returns
    -------
    Dict[str, float]
        task_id → score obtained from the single rule submitted for that task.
    """
    env = APIGatewayDefender()
    scores: Dict[str, float] = {}

    # ── Easy: block the IP sending the most POST /login requests ─────────────
    obs = env.reset("easy")
    login_hits: Dict[str, int] = {}
    for entry in obs.recent_requests:
        if entry["path"] == "/login" and entry["method"] == "POST":
            login_hits[entry["ip"]] = login_hits.get(entry["ip"], 0) + 1
    if login_hits:
        suspect_ip = max(login_hits, key=login_hits.get)
    else:
        # Nothing matched in the visible window: fall back to the known attacker.
        suspect_ip = ATTACK_IP_EASY
    outcome = env.step(Action(action_type="block_ip", target_ip=suspect_ip))
    scores["easy"] = outcome.reward.score

    # ── Medium: block the most frequent bot-looking User-Agent ───────────────
    obs = env.reset("medium")
    ua_hits: Dict[str, int] = {}
    for entry in obs.recent_requests:
        agent = entry["user_agent"]
        ua_hits[agent] = ua_hits.get(agent, 0) + 1
    # Most frequent first; the sort is stable, so ties keep first-seen order.
    ranked_uas = [ua for ua, _ in sorted(ua_hits.items(), key=lambda pair: -pair[1])]

    bot_markers = ("scraper", "bot", "crawler", "spider", "harvester")
    browser_markers = ("mozilla", "chrome", "safari", "firefox", "gecko", "webkit")

    # Prefer a User-Agent that advertises itself as a bot.
    suspect_ua = next(
        (ua for ua in ranked_uas if any(marker in ua.lower() for marker in bot_markers)),
        None,
    )
    # Otherwise take the most common one that does not look like a browser.
    if not suspect_ua:
        suspect_ua = next(
            (
                ua
                for ua in ranked_uas
                if not any(marker in ua.lower() for marker in browser_markers)
            ),
            None,
        )
    outcome = env.step(
        Action(action_type="block_user_agent", target_user_agent=suspect_ua or "")
    )
    scores["medium"] = outcome.reward.score

    # ── Hard: drop anything whose path/query contains a UNION SELECT ─────────
    env.reset("hard")
    sqli_rule = Action(
        action_type="write_custom_middleware",
        regex_pattern=r"UNION\s+SELECT",
    )
    outcome = env.step(sqli_rule)
    scores["hard"] = outcome.reward.score

    return scores