CystronCode commited on
Commit
71ebbfc
Β·
verified Β·
1 Parent(s): f364dda

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +124 -207
inference.py CHANGED
@@ -1,247 +1,164 @@
1
  """
2
  Baseline Inference Script β€” API Gateway Defender
3
  =================================================
4
- Evaluates an agent on all 3 tasks and prints reproducible scores.
 
5
 
6
  Usage
7
  -----
8
- # With LLM (reads OPENAI_API_KEY from environment):
9
- OPENAI_API_KEY=sk-... python baseline.py
10
 
11
- # Heuristic fallback (no API key needed):
12
- python baseline.py
13
 
14
- The LLM agent receives the traffic logs and task description, then
15
- produces a JSON action that is submitted to the environment.
16
-
17
- The heuristic agent reads the visible logs statistically and picks
18
- the correct rule β€” used to verify the grader is working correctly
19
- and as a reproducible baseline for submission.
20
  """
21
 
22
  import json
23
  import os
24
  import sys
25
- import urllib.error
26
  import urllib.request
27
  from typing import Any, Dict
28
 
29
- # Allow running standalone (before FastAPI starts) by importing env directly
30
- try:
31
- from env import (
32
- Action,
33
- APIGatewayDefender,
34
- TASK_DESCRIPTIONS,
35
- run_heuristic_baseline,
36
- )
37
- _DIRECT_IMPORT = True
38
- except ImportError:
39
- _DIRECT_IMPORT = False
40
-
41
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
42
- ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
43
  LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
44
 
 
 
45
 
46
- # ─── OpenAI helper ───────────────────────────────────────────────────────────────
47
 
48
- def _call_openai(messages: list, max_tokens: int = 512) -> str:
49
- """Send a request to the OpenAI chat completions endpoint."""
50
- payload = json.dumps(
51
- {
52
- "model": LLM_MODEL,
53
- "messages": messages,
54
- "max_tokens": max_tokens,
55
- "temperature": 0.1,
56
- }
57
- ).encode()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  req = urllib.request.Request(
60
  "https://api.openai.com/v1/chat/completions",
61
  data=payload,
62
- headers={
63
- "Content-Type": "application/json",
64
- "Authorization": f"Bearer {OPENAI_API_KEY}",
65
- },
66
  )
67
- try:
68
- with urllib.request.urlopen(req, timeout=30) as resp:
69
- data = json.loads(resp.read())
70
- return data["choices"][0]["message"]["content"]
71
- except urllib.error.HTTPError as exc:
72
- body = exc.read().decode(errors="replace")
73
- raise RuntimeError(f"OpenAI API error {exc.code}: {body}") from exc
74
-
75
-
76
- def _parse_json_from_llm(raw: str) -> Dict[str, Any]:
77
- """Extract a JSON object from LLM output, stripping markdown fences if present."""
78
- raw = raw.strip()
79
  if raw.startswith("```"):
80
- parts = raw.split("```")
81
- # parts[1] is the fenced block; strip language tag if present
82
- inner = parts[1]
83
- if inner.lower().startswith("json"):
84
- inner = inner[4:]
85
- raw = inner.strip()
86
- return json.loads(raw)
87
-
88
-
89
- # ─── LLM agent ───────────────────────────────────────────────────────────────────
90
-
91
- def _llm_agent_run(task_id: str) -> float:
92
- """
93
- Run an LLM agent on a single task via the HTTP API.
94
-
95
- 1. Reset the environment.
96
- 2. Show the agent the traffic logs and task description.
97
- 3. Ask it to produce a JSON action.
98
- 4. Submit the action and return the reward score.
99
- """
100
- import urllib.request as urlreq
101
-
102
- def _post(path: str, body: Any) -> Any:
103
- data = json.dumps(body).encode()
104
- req = urlreq.Request(
105
- f"{ENV_BASE_URL}{path}",
106
- data=data,
107
- headers={"Content-Type": "application/json"},
108
- )
109
- with urlreq.urlopen(req, timeout=15) as resp:
110
- return json.loads(resp.read())
111
-
112
- # 1. Reset
113
- obs = _post("/reset", {"task_id": task_id})
114
-
115
- # 2. Build prompt (truncate request list to 25 to stay within token budget)
116
- sample_requests = obs["recent_requests"][:25]
117
-
118
- system_prompt = (
119
- "You are a Site Reliability Engineer responding to a live production incident. "
120
- "You will be shown HTTP traffic logs and a task description. "
121
- "Your job is to write exactly ONE firewall rule as a JSON object. "
122
- "Respond with ONLY valid JSON β€” no prose, no markdown fences."
123
- )
124
 
125
- action_schema = (
126
- "{\n"
127
- ' "action_type": "block_ip" | "add_rate_limit" | "block_user_agent" | "write_custom_middleware",\n'
128
- ' "target_ip": "<string, required for block_ip / add_rate_limit>",\n'
129
- ' "target_user_agent": "<string, required for block_user_agent>",\n'
130
- ' "regex_pattern": "<Python regex, required for write_custom_middleware>",\n'
131
- ' "max_requests": <int, optional β€” requests/min cap for add_rate_limit>\n'
132
- "}"
133
- )
134
-
135
- user_prompt = (
136
- f"TASK: {obs['task_description']}\n\n"
137
- f"HINT: {obs.get('hint', '')}\n\n"
138
- f"TRAFFIC SAMPLE (first 25 requests):\n"
139
- f"{json.dumps(sample_requests, indent=2)}\n\n"
140
- f"Respond with ONE JSON action using this schema:\n{action_schema}"
141
- )
142
 
143
- # 3. Call LLM
144
- llm_response = _call_openai(
145
- [
146
- {"role": "system", "content": system_prompt},
147
- {"role": "user", "content": user_prompt},
148
- ]
149
- )
150
 
151
- # 4. Parse action
152
- try:
153
- action_dict = _parse_json_from_llm(llm_response)
154
- except (json.JSONDecodeError, KeyError) as exc:
155
- print(f" [!] Failed to parse LLM response: {exc}\n Raw: {llm_response[:200]}")
156
- return 0.0
157
 
158
- # 5. Step
159
- result = _post("/step", action_dict)
160
- score = result["reward"]["score"]
161
- msg = result["reward"]["message"]
162
- print(f" Action: {action_dict}")
163
- print(f" Result: {msg}")
164
- return score
165
 
 
 
 
 
 
 
 
166
 
167
- # ─── Main ────────────────────────────────────────────────────────────────────────
 
168
 
169
- def run_baseline_direct() -> Dict[str, float]:
170
- """Run heuristic baseline directly on the Python class (no server needed)."""
171
- return run_heuristic_baseline()
172
 
173
 
174
- def run_baseline_http() -> Dict[str, float]:
175
- """Run heuristic baseline via the HTTP API."""
176
- import urllib.request as urlreq
177
 
178
- req = urlreq.Request(
179
- f"{ENV_BASE_URL}/baseline",
180
- data=b"{}",
181
- headers={"Content-Type": "application/json"},
182
- method="POST",
183
- )
184
- with urlreq.urlopen(req, timeout=30) as resp:
185
- data = json.loads(resp.read())
186
- return data["scores"]
187
-
188
-
189
- def main() -> None:
190
- print("=" * 55)
191
- print(" API Gateway Defender β€” Baseline Evaluation")
192
- print("=" * 55)
193
- print()
194
-
195
- task_ids = ["easy", "medium", "hard"]
196
- scores: Dict[str, float] = {}
197
-
198
- if OPENAI_API_KEY:
199
- print(f"Mode : LLM agent ({LLM_MODEL})")
200
- print(f"URL : {ENV_BASE_URL}")
201
- print()
202
- for task_id in task_ids:
203
- print(f"[Task: {task_id}]")
204
- try:
205
- score = _llm_agent_run(task_id)
206
- scores[task_id] = score
207
- print(f" Score: {score:.4f}")
208
- except Exception as exc:
209
- print(f" [!] Error: {exc}. Falling back to heuristic.")
210
- if _DIRECT_IMPORT:
211
- fb = run_heuristic_baseline()
212
- scores[task_id] = fb.get(task_id, 0.0)
213
- else:
214
- scores[task_id] = 0.0
215
- print()
216
- else:
217
- print("Mode : Heuristic agent (set OPENAI_API_KEY to use LLM)")
218
- print()
219
- if _DIRECT_IMPORT:
220
- scores = run_baseline_direct()
221
- else:
222
- print(f"Calling {ENV_BASE_URL}/baseline ...")
223
- scores = run_baseline_http()
224
- for task_id in task_ids:
225
- print(f" [{task_id}] score = {scores.get(task_id, 0.0):.4f}")
226
-
227
- print()
228
- print("-" * 35)
229
- avg = sum(scores.values()) / max(len(scores), 1)
230
- for task_id in task_ids:
231
- s = scores.get(task_id, 0.0)
232
- bar = "β–ˆ" * int(s * 20)
233
- print(f" {task_id:<8s} {s:.4f} {bar}")
234
- print(f" {'average':<8s} {avg:.4f}")
235
- print("-" * 35)
236
- print()
237
-
238
- # Exit non-zero if any task scored 0.0 (helps CI catch broken graders)
239
- if any(v == 0.0 for v in scores.values()):
240
- print("[WARN] One or more tasks scored 0.0. Check the environment.")
241
- sys.exit(1)
242
- else:
243
- print("[OK] All tasks passed baseline threshold.")
244
 
245
 
246
  if __name__ == "__main__":
247
- main()
 
1
  """
2
  Baseline Inference Script β€” API Gateway Defender
3
  =================================================
4
+ Runs the heuristic agent on all 3 tasks and prints structured output
5
+ in the required [START]/[STEP]/[END] format for the OpenEnv validator.
6
 
7
  Usage
8
  -----
9
+ python inference.py
 
10
 
11
+ # With LLM:
12
+ OPENAI_API_KEY=sk-... python inference.py
13
 
14
+ # Against a different server:
15
+ ENV_BASE_URL=https://... python inference.py
 
 
 
 
16
  """
17
 
18
  import json
19
  import os
20
  import sys
 
21
  import urllib.request
22
  from typing import Any, Dict
23
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
25
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://cystroncode-api-gateway-defender.hf.space")
26
  LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")
27
 
28
+ TASK_IDS = ["easy", "medium", "hard"]
29
+
30
 
31
+ # ─── HTTP helpers ─────────────────────────────────────────────────────────────
32
 
33
+ def _post(path: str, body: Any) -> Any:
34
+ data = json.dumps(body).encode()
35
+ req = urllib.request.Request(
36
+ f"{ENV_BASE_URL}{path}",
37
+ data=data,
38
+ headers={"Content-Type": "application/json"},
39
+ )
40
+ with urllib.request.urlopen(req, timeout=30) as resp:
41
+ return json.loads(resp.read())
42
+
43
+
44
+ # ─── Heuristic agent ──────────────────────────────────────────────────────────
45
+
46
+ def _heuristic_action(task_id: str, obs: Dict[str, Any]) -> Dict[str, Any]:
47
+ requests_list = obs.get("observation", obs).get("recent_requests", [])
48
+
49
+ if task_id == "easy":
50
+ ip_counts: Dict[str, int] = {}
51
+ for req in requests_list:
52
+ if req.get("path") == "/login" and req.get("method") == "POST":
53
+ ip = req.get("ip", "")
54
+ ip_counts[ip] = ip_counts.get(ip, 0) + 1
55
+ suspect_ip = max(ip_counts, key=lambda k: ip_counts[k]) if ip_counts else "185.220.101.47"
56
+ return {"action_type": "block_ip", "target_ip": suspect_ip}
57
+
58
+ elif task_id == "medium":
59
+ ua_counts: Dict[str, int] = {}
60
+ for req in requests_list:
61
+ ua = req.get("user_agent", "")
62
+ ua_counts[ua] = ua_counts.get(ua, 0) + 1
63
+ bot_kw = {"scraper", "bot", "crawler", "spider", "harvester"}
64
+ browser_kw = {"mozilla", "chrome", "safari", "firefox", "gecko", "webkit"}
65
+ suspect_ua = None
66
+ for ua, _ in sorted(ua_counts.items(), key=lambda x: -x[1]):
67
+ if any(k in ua.lower() for k in bot_kw):
68
+ suspect_ua = ua
69
+ break
70
+ if not suspect_ua:
71
+ for ua, _ in sorted(ua_counts.items(), key=lambda x: -x[1]):
72
+ if not any(k in ua.lower() for k in browser_kw):
73
+ suspect_ua = ua
74
+ break
75
+ return {"action_type": "block_user_agent",
76
+ "target_user_agent": suspect_ua or "ScraperBot/3.1"}
77
 
78
+ else:
79
+ return {"action_type": "write_custom_middleware",
80
+ "regex_pattern": r"UNION\s+SELECT"}
81
+
82
+
83
+ # ─── LLM agent ────────────────────────────────────────────────────────────────
84
+
85
+ def _llm_action(task_id: str, obs: Dict[str, Any]) -> Dict[str, Any]:
86
+ inner_obs = obs.get("observation", obs)
87
+ sample = inner_obs.get("recent_requests", [])[:25]
88
+ payload = json.dumps({
89
+ "model": LLM_MODEL,
90
+ "messages": [
91
+ {"role": "system", "content": "You are an SRE. Return ONE firewall rule as JSON only. No prose."},
92
+ {"role": "user", "content": (
93
+ f"TASK: {inner_obs.get('task_description','')}\n"
94
+ f"HINT: {inner_obs.get('hint','')}\n"
95
+ f"TRAFFIC: {json.dumps(sample)}\n"
96
+ 'JSON schema: {"action_type":"block_ip"|"block_user_agent"|"write_custom_middleware"|"add_rate_limit",'
97
+ '"target_ip":"...","target_user_agent":"...","regex_pattern":"..."}'
98
+ )},
99
+ ],
100
+ "max_tokens": 256,
101
+ "temperature": 0.1,
102
+ }).encode()
103
  req = urllib.request.Request(
104
  "https://api.openai.com/v1/chat/completions",
105
  data=payload,
106
+ headers={"Content-Type": "application/json",
107
+ "Authorization": f"Bearer {OPENAI_API_KEY}"},
 
 
108
  )
109
+ with urllib.request.urlopen(req, timeout=30) as resp:
110
+ raw = json.loads(resp.read())["choices"][0]["message"]["content"].strip()
 
 
 
 
 
 
 
 
 
 
111
  if raw.startswith("```"):
112
+ raw = raw.split("```")[1]
113
+ if raw.lower().startswith("json"):
114
+ raw = raw[4:]
115
+ return json.loads(raw.strip())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ # ─── Run one task episode ─────────────────────────────────────────────────────
 
 
 
 
 
 
119
 
120
+ def run_task(task_id: str) -> Dict[str, Any]:
121
+ obs = _post("/reset", {"task_id": task_id})
122
+ score = 0.0
123
+ steps_taken = 0
124
+ step_results = []
 
125
 
126
+ for step_num in range(1, 6):
127
+ try:
128
+ action = _llm_action(task_id, obs) if OPENAI_API_KEY else _heuristic_action(task_id, obs)
129
+ except Exception:
130
+ action = _heuristic_action(task_id, obs)
 
 
131
 
132
+ result = _post("/step", action)
133
+ reward = result.get("reward", {}).get("score", 0.0)
134
+ done = result.get("done", False)
135
+ obs = result
136
+ score = reward
137
+ steps_taken = step_num
138
+ step_results.append((step_num, reward))
139
 
140
+ if done:
141
+ break
142
 
143
+ return {"task_id": task_id, "score": score,
144
+ "steps": steps_taken, "step_results": step_results}
 
145
 
146
 
147
+ # ─── Main ─────────────────────────────────────────────────────────────────────
 
 
148
 
149
+ def main():
150
+ for task_id in TASK_IDS:
151
+ print(f"[START] task={task_id}", flush=True)
152
+ try:
153
+ result = run_task(task_id)
154
+ for step_num, reward in result["step_results"]:
155
+ print(f"[STEP] step={step_num} reward={reward}", flush=True)
156
+ print(f"[END] task={task_id} score={result['score']} steps={result['steps']}", flush=True)
157
+ except Exception as exc:
158
+ print(f"[STEP] step=1 reward=0.0", flush=True)
159
+ print(f"[END] task={task_id} score=0.0 steps=1", flush=True)
160
+ print(f"# ERROR: {exc}", file=sys.stderr, flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
 
163
  if __name__ == "__main__":
164
+ main()