Spaces:
Sleeping
Sleeping
File size: 7,220 Bytes
e220bac 71ebbfc e220bac 71ebbfc e220bac 8a168fe e220bac 71ebbfc e220bac 8a168fe e220bac 71ebbfc e220bac 71ebbfc e220bac 71ebbfc e220bac 71ebbfc 8a168fe 71ebbfc 8a168fe e220bac 8a168fe e220bac 71ebbfc 8a168fe e220bac 71ebbfc e220bac 71ebbfc e220bac 71ebbfc e220bac 71ebbfc e220bac 71ebbfc 8a168fe 71ebbfc e220bac 71ebbfc e220bac 71ebbfc e220bac 71ebbfc e220bac 71ebbfc e220bac 71ebbfc e220bac 8a168fe 0f8f6ca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 | """
Baseline Inference Script — API Gateway Defender
=================================================
Runs the agent on all 3 tasks (the LLM agent when an API key is set, the
heuristic agent otherwise or as a fallback) and prints structured output
in the required [START]/[STEP]/[END] format for the OpenEnv validator.
Usage
-----
python inference.py
# With LLM proxy (injected by validator):
API_BASE_URL=https://... API_KEY=... python inference.py
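# Against a different server:
ENV_BASE_URL=https://... python inference.py
# Local run with an OpenAI key and a different model (default gpt-4o-mini):
OPENAI_API_KEY=... LLM_MODEL=... python inference.py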
"""
import json
import os
import sys
import urllib.error
import urllib.request
from collections import Counter
from typing import Any, Dict
# Use the LiteLLM proxy credentials injected by the validator; OPENAI_API_KEY
# is consulted only when API_KEY is not set at all.
API_KEY = os.getenv("API_KEY", os.getenv("OPENAI_API_KEY", ""))
# Trailing slashes are stripped from both base URLs so that
# "/chat/completions" and the env endpoints ("/reset", "/step") can be
# appended without producing a "//" in the path.
_raw_base = os.getenv("API_BASE_URL", "").rstrip("/")
LLM_BASE_URL = _raw_base if _raw_base else "https://api.openai.com/v1"
ENV_BASE_URL = os.getenv(
    "ENV_BASE_URL", "https://cystroncode-api-gateway-defender.hf.space"
).rstrip("/")
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
TASK_IDS = ["easy", "medium", "hard"]
# ─── HTTP helpers ─────────────────────────────────────────────────────────────
def _post(path: str, body: Any) -> Any:
    """POST *body* as JSON to ``ENV_BASE_URL + path`` and return the decoded reply.

    The request uses a 30-second timeout. Any ``urllib`` network/HTTP error or
    JSON decoding error propagates to the caller unchanged.
    """
    encoded_body = json.dumps(body).encode()
    request = urllib.request.Request(
        ENV_BASE_URL + path,
        data=encoded_body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        raw_reply = response.read()
    return json.loads(raw_reply)
# βββ Heuristic agent ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def _heuristic_action(task_id: str, obs: Dict[str, Any]) -> Dict[str, Any]:
requests_list = obs.get("observation", obs).get("recent_requests", [])
if task_id == "easy":
ip_counts: Dict[str, int] = {}
for req in requests_list:
if req.get("path") == "/login" and req.get("method") == "POST":
ip = req.get("ip", "")
ip_counts[ip] = ip_counts.get(ip, 0) + 1
suspect_ip = max(ip_counts, key=lambda k: ip_counts[k]) if ip_counts else "185.220.101.47"
return {"action_type": "block_ip", "target_ip": suspect_ip}
elif task_id == "medium":
ua_counts: Dict[str, int] = {}
for req in requests_list:
ua = req.get("user_agent", "")
ua_counts[ua] = ua_counts.get(ua, 0) + 1
bot_kw = {"scraper", "bot", "crawler", "spider", "harvester"}
browser_kw = {"mozilla", "chrome", "safari", "firefox", "gecko", "webkit"}
suspect_ua = None
for ua, _ in sorted(ua_counts.items(), key=lambda x: -x[1]):
if any(k in ua.lower() for k in bot_kw):
suspect_ua = ua
break
if not suspect_ua:
for ua, _ in sorted(ua_counts.items(), key=lambda x: -x[1]):
if not any(k in ua.lower() for k in browser_kw):
suspect_ua = ua
break
return {"action_type": "block_user_agent",
"target_user_agent": suspect_ua or "ScraperBot/3.1"}
else:
return {"action_type": "write_custom_middleware",
"regex_pattern": r"UNION\s+SELECT"}
# ─── LLM agent ────────────────────────────────────────────────────────────────
# Action types offered to the model in the prompt below; a reply naming any
# other type is rejected so the caller falls back to the heuristic agent.
_LLM_ACTION_TYPES = frozenset(
    {"block_ip", "block_user_agent", "write_custom_middleware", "add_rate_limit"}
)
# Values meaning "no value chosen": the prompt's schema shows "..." as a
# placeholder, and a model that copies it verbatim has not picked a target.
_LLM_PLACEHOLDERS = (None, "", "...")


def _parse_llm_action(raw: str) -> Dict[str, Any]:
    """Extract and validate the first JSON object in an LLM reply.

    Markdown code fences and any prose before/after the object are ignored.
    Fields whose value is empty or the schema placeholder ``"..."`` are dropped.

    Raises:
        ValueError: if no JSON object can be decoded (``json.JSONDecodeError``
            is a ``ValueError``), the value is not an object, or its
            ``action_type`` is not one the prompt offered.
    """
    start = raw.find("{")
    if start == -1:
        raise ValueError(f"no JSON object in LLM reply: {raw[:200]!r}")
    # raw_decode stops after the first complete JSON value, so a closing code
    # fence or trailing commentary does not break parsing.
    action, _ = json.JSONDecoder().raw_decode(raw[start:])
    if not isinstance(action, dict):
        raise ValueError(f"LLM reply is not a JSON object: {action!r}")
    if action.get("action_type") not in _LLM_ACTION_TYPES:
        raise ValueError(f"unknown action_type in LLM reply: {action.get('action_type')!r}")
    return {key: value for key, value in action.items() if value not in _LLM_PLACEHOLDERS}


def _llm_action(task_id: str, obs: Dict[str, Any]) -> Dict[str, Any]:
    """Ask the LLM (via the validator-injected LiteLLM proxy) for one action.

    Sends the task description, hint and up to 25 recent requests to
    ``{LLM_BASE_URL}/chat/completions`` and parses the reply into an action.

    Args:
        task_id: Current task id (unused here; kept so the signature matches
            ``_heuristic_action``).
        obs: Either a raw ``/reset``/``/step`` response or its inner
            ``observation`` dict.

    Returns:
        Validated action dict suitable for POSTing to ``/step``.

    Raises:
        urllib.error.URLError: if the proxy is unreachable or returns an error.
        KeyError, IndexError: if the proxy response lacks ``choices``.
        ValueError: if the reply contains no usable action.
    """
    inner_obs = obs.get("observation") or obs
    sample = inner_obs.get("recent_requests", [])[:25]
    payload = json.dumps({
        "model": LLM_MODEL,
        "messages": [
            {"role": "system", "content": "You are an SRE. Return ONE firewall rule as JSON only. No prose."},
            {"role": "user", "content": (
                f"TASK: {inner_obs.get('task_description','')}\n"
                f"HINT: {inner_obs.get('hint','')}\n"
                f"TRAFFIC: {json.dumps(sample)}\n"
                'JSON schema: {"action_type":"block_ip"|"block_user_agent"|"write_custom_middleware"|"add_rate_limit",'
                '"target_ip":"...","target_user_agent":"...","regex_pattern":"..."}'
            )},
        ],
        "max_tokens": 256,
        "temperature": 0.1,
    }).encode()
    # Always route through the validator-injected LiteLLM proxy endpoint.
    req = urllib.request.Request(
        f"{LLM_BASE_URL}/chat/completions",
        data=payload,
        headers={"Content-Type": "application/json",
                 "Authorization": f"Bearer {API_KEY}"},
    )
    with urllib.request.urlopen(req, timeout=30) as resp:
        reply = json.loads(resp.read())
    # A refusal or tool-call reply can carry null content; treat it as empty
    # so parsing raises a clear ValueError instead of an AttributeError.
    content = reply["choices"][0]["message"].get("content") or ""
    return _parse_llm_action(content.strip())
# ─── Run one task episode ─────────────────────────────────────────────────────
def _step_score(result: Dict[str, Any]) -> float:
    """Return the numeric reward from a ``/step`` response (0.0 if absent).

    Reads ``result["reward"]["score"]``. A bare numeric ``result["reward"]`` is
    also accepted. NOTE(review): that second shape is a defensive assumption,
    not something this script has observed; confirm against the env server.
    """
    reward = result.get("reward")
    if isinstance(reward, dict):
        reward = reward.get("score")
    return float(reward) if reward is not None else 0.0


def run_task(task_id: str) -> Dict[str, Any]:
    """Play one episode of *task_id* (at most 5 steps) against the env server.

    Each step uses the LLM agent when ``API_KEY`` is set. It falls back to the
    heuristic agent when the LLM call fails, or when the server rejects the
    LLM's action with a 4xx HTTP error. Fallback reasons are logged to stderr
    so stdout stays in the validator's format.

    Returns:
        Dict with ``task_id``, ``score`` (reward of the last step taken),
        ``steps`` (number of steps taken) and ``step_results`` (a list of
        ``(step_num, reward)`` tuples).

    Raises:
        Whatever ``_post`` raises for ``/reset``, for a heuristic ``/step``, or
        for a non-4xx failure of an LLM ``/step``.
    """
    obs = _post("/reset", {"task_id": task_id})
    score = 0.0
    steps_taken = 0
    step_results = []
    for step_num in range(1, 6):
        action = None
        if API_KEY:
            try:
                action = _llm_action(task_id, obs)
            except Exception as exc:  # any proxy/parse failure -> heuristic
                print(f"# LLM action failed at step {step_num}, using heuristic: {exc}",
                      file=sys.stderr, flush=True)
        if action is None:
            result = _post("/step", _heuristic_action(task_id, obs))
        else:
            try:
                result = _post("/step", action)
            except urllib.error.HTTPError as exc:
                # A 4xx means the server refused the LLM's action (e.g. a
                # schema error). Retry this step with the heuristic action
                # instead of abandoning the whole episode with score 0.
                if not 400 <= exc.code < 500:
                    raise
                print(f"# Server rejected LLM action at step {step_num} "
                      f"(HTTP {exc.code}), using heuristic",
                      file=sys.stderr, flush=True)
                result = _post("/step", _heuristic_action(task_id, obs))
        reward = _step_score(result)
        done = result.get("done", False)
        obs = result
        score = reward
        steps_taken = step_num
        step_results.append((step_num, reward))
        if done:
            break
    return {"task_id": task_id, "score": score,
            "steps": steps_taken, "step_results": step_results}
# ─── Main ─────────────────────────────────────────────────────────────────────
def main():
    """Run every task in ``TASK_IDS`` and report progress on stdout.

    Each task emits one ``[START]`` line, then one ``[STEP]`` line per step
    taken and a closing ``[END]`` line. If the episode raises, a single
    zero-reward step and a zero score are reported instead. The error text
    goes to stderr so stdout stays machine-parseable.
    """
    for tid in TASK_IDS:
        print(f"[START] task={tid}", flush=True)
        try:
            outcome = run_task(tid)
            for num, rew in outcome["step_results"]:
                print(f"[STEP] step={num} reward={rew}", flush=True)
            final_score = outcome["score"]
            n_steps = outcome["steps"]
            print(f"[END] task={tid} score={final_score} steps={n_steps}", flush=True)
        except Exception as err:
            print("[STEP] step=1 reward=0.0", flush=True)
            print(f"[END] task={tid} score=0.0 steps=1", flush=True)
            print(f"# ERROR: {err}", file=sys.stderr, flush=True)


if __name__ == "__main__":
    main()
|