Ajaxwin commited on
Commit
8493010
Β·
unverified Β·
1 Parent(s): cfae7a7

Switching to old inference.py

Browse files
Files changed (1) hide show
  1. inference.py +175 -270
inference.py CHANGED
@@ -35,43 +35,34 @@ from openai import OpenAI
35
  from server import Task1Environment, Task2Environment, Task3Environment
36
  from env.schemas import Action, ActionType
37
  from utils import T1_SYSTEM, T2_SYSTEM, T3_SYSTEM
38
- from dotenv import load_dotenv
39
 
40
  # ─────────────────────────────────────────────────────────────────────────────
41
  # Configuration
42
  # ─────────────────────────────────────────────────────────────────────────────
43
 
44
- load_dotenv() # Load from .env if available; otherwise rely on actual env vars
45
- API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
46
- MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o")
47
- HF_TOKEN = os.getenv("HF_TOKEN")
48
 
49
  if not HF_TOKEN:
50
  print("[WARN] HF_TOKEN not set β€” API calls may fail.", file=sys.stderr)
51
  exit(1)
52
 
53
- ENV_BENCHMARK = "smart-contract-audit"
54
- NUM_EPISODES = 2 # keep low on free tier; raise for full eval
55
- SEED_BASE = 42
56
 
57
- # Max LLM calls per episode (including the mandatory submit on last step).
58
- # Budget: free tier handles ~5-6 calls per episode before rate-limiting.
59
- MAX_STEPS_T1 = 5
60
- MAX_STEPS_T2 = 4
61
- MAX_STEPS_T3 = 4
62
 
63
- # How many steps before the end we start injecting "submit now" pressure.
64
- # E.g. PRESSURE_AT=2 means last 2 steps show a warning.
65
- PRESSURE_AT = 2
 
66
 
67
- # Sliding-window size: how many recent (user, assistant) pairs to keep.
68
- # system prompt + 2 exchanges = ~800 tokens max β€” safe for free tier.
69
- HISTORY_WINDOW = 2
70
-
71
- # Truncate action results to this many chars before inserting into the prompt.
72
- MAX_RESULT_CHARS = 400
73
-
74
- # A grader_score >= this threshold β†’ success=true in [END] line
75
  SUCCESS_SCORE_THRESHOLD = 0.5
76
 
77
  client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
@@ -117,115 +108,30 @@ def log_end(
117
  )
118
 
119
 
120
- # ─────────────────────────────────────────────────────────────────────────────
121
- # Shared utilities
122
- # ─────────────────────────────────────────────────────────────────────────────
123
-
124
- def _truncate(text: str, limit: int = MAX_RESULT_CHARS) -> str:
125
- """Truncate long action results to keep prompts small."""
126
- if len(text) <= limit:
127
- return text
128
- return text[:limit] + f"... [truncated, {len(text) - limit} chars omitted]"
129
-
130
-
131
- def _sliding_messages(system: str, history: List[Dict[str, str]]) -> List[Dict[str, str]]:
132
- """
133
- Return system prompt + the last HISTORY_WINDOW (user, assistant) pairs.
134
- This keeps total tokens bounded regardless of episode length.
135
- """
136
- # history = [..., user, assistant, user, assistant, ...]
137
- # We want the last HISTORY_WINDOW complete pairs (2 messages each).
138
- keep = HISTORY_WINDOW * 2
139
- windowed = history[-keep:] if len(history) > keep else history
140
- return [{"role": "system", "content": system}] + windowed
141
-
142
-
143
- def _call_llm(messages: List[Dict[str, str]], max_tokens: int = 150) -> tuple[str, Optional[str]]:
144
- """Call the LLM; return (raw_response, error_string_or_None)."""
145
- try:
146
- resp = client.chat.completions.create(
147
- model=MODEL_NAME, # type: ignore
148
- messages=messages, # type: ignore
149
- max_tokens=max_tokens,
150
- temperature=0.0,
151
- )
152
- return resp.choices[0].message.content.strip(), None # type: ignore
153
- except Exception as e:
154
- return "", str(e)[:80]
155
-
156
-
157
- def _parse_action(raw: str, fallback_at: ActionType,
158
- fallback_params: Dict[str, Any]) -> tuple[ActionType, Dict[str, Any]]:
159
- """Parse LLM JSON response into (ActionType, params). Use fallback on failure."""
160
- try:
161
- parsed = json.loads(raw)
162
- return ActionType(parsed["action"]), parsed.get("params", {})
163
- except Exception:
164
- return fallback_at, fallback_params
165
-
166
-
167
- def _pressure_suffix(steps_left: int) -> str:
168
- """Return an urgent suffix when the step budget is nearly exhausted."""
169
- if steps_left <= 0:
170
- return (
171
- "\n\n⚠️ FINAL STEP β€” you MUST submit your best answer RIGHT NOW.\n"
172
- "Do not browse further. Emit a submit action immediately."
173
- )
174
- if steps_left <= PRESSURE_AT:
175
- return (
176
- f"\n\n⚠️ Only {steps_left} step(s) remaining. "
177
- "You should submit your answer in the next step or two."
178
- )
179
- return ""
180
-
181
-
182
  # ─────────────────────────────────────────────────────────────────────────────
183
  # Task 1 β€” Targeted Vulnerability Detection
184
  # ─────────────────────────────────────────────────────────────────────────────
185
 
186
- def _t1_user(obs: Dict[str, Any], steps_left: int) -> str:
187
- result = _truncate(obs.get("last_action_result") or "Episode just started.")
188
  return (
189
- f"Contract: {obs['contract_name']} | {obs['contract_description'][:80]}\n"
190
- f"Step {obs['step_count']} | Reward: {obs['cumulative_reward']:.2f}\n"
191
- f"Last action: {obs['last_action'] or 'None'}\n"
192
- f"Result: {result}"
193
- + _pressure_suffix(steps_left)
194
- )
195
-
196
-
197
- def _t1_force_submit(obs: Dict[str, Any], history: List[Dict[str, str]]) -> tuple[ActionType, Dict[str, Any]]:
198
- """
199
- Build a forced submission from what we already know.
200
- Strategy: ask the LLM one more time with an explicit 'submit NOW' mandate.
201
- If that fails, fall back to a heuristic.
202
- """
203
- mandate = (
204
- "Based on everything you have seen, submit your best answer NOW.\n"
205
- "Respond ONLY with this JSON (fill in the values):\n"
206
- '{"action":"submit","params":{"function_name":"<best_guess>","vulnerability_type":"<best_guess>"}}'
207
  )
208
- messages = _sliding_messages(T1_SYSTEM, history) + [{"role": "user", "content": mandate}]
209
- raw, _ = _call_llm(messages, max_tokens=80)
210
- at, params = _parse_action(raw, ActionType.SUBMIT,
211
- {"function_name": "withdraw",
212
- "vulnerability_type": "reentrancy"})
213
- # Guarantee it's always a submit
214
- if at != ActionType.SUBMIT:
215
- at = ActionType.SUBMIT
216
- if "function_name" not in params:
217
- params["function_name"] = "withdraw"
218
- if "vulnerability_type" not in params:
219
- params["vulnerability_type"] = "reentrancy"
220
- return at, params
221
 
222
 
223
  def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str, Any]:
 
224
  r = env.reset(seed=seed)
225
  obs = r.observation.model_dump()
 
226
  log_start(task="task1_vuln_detection", env=ENV_BENCHMARK, model=MODEL_NAME) # type: ignore
227
 
228
- history: List[Dict[str, str]] = []
 
 
229
  step_rewards: List[float] = []
230
  grader_score = 0.0
231
  steps_taken = 0
@@ -233,24 +139,31 @@ def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str,
233
 
234
  try:
235
  for step in range(1, MAX_STEPS_T1 + 1):
236
- steps_left = MAX_STEPS_T1 - step
237
- is_last = (step == MAX_STEPS_T1)
238
-
239
- if is_last:
240
- # Never waste the last step on browsing β€” force a submission
241
- at, params = _t1_force_submit(obs, history)
242
- else:
243
- user_msg = _t1_user(obs, steps_left)
244
- history.append({"role": "user", "content": user_msg})
245
- messages = _sliding_messages(T1_SYSTEM, history)
246
- raw, error_msg = _call_llm(messages)
247
- history.append({"role": "assistant", "content": raw})
248
- at, params = _parse_action(raw, ActionType.LIST_FUNCTIONS, {})
249
-
250
- result = env.step(Action(action_type=at, params=params))
251
- obs = result.observation.model_dump()
252
- r_val = result.reward.value
253
- done = result.done
 
 
 
 
 
 
 
254
 
255
  step_rewards.append(r_val)
256
  steps_taken = step
@@ -258,18 +171,22 @@ def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str,
258
 
259
  if done:
260
  v = r_val
261
- grader_score = 0.999 if v >= 4.9 else (0.5 if v >= 0.9 else 0.0)
262
  break
263
 
264
- if not is_last:
265
- time.sleep(0.5)
266
 
267
  finally:
268
  success = grader_score >= SUCCESS_SCORE_THRESHOLD
269
  log_end(success=success, steps=steps_taken, score=grader_score, rewards=step_rewards)
270
 
271
- return {"episode": ep_num, "seed": seed, "contract": obs["contract_name"],
272
- "grader_score": grader_score, "cumulative_reward": obs["cumulative_reward"]}
 
 
 
 
 
273
 
274
 
275
  # ─────────────────────────────────────────────────────────────────────────────
@@ -277,48 +194,29 @@ def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str,
277
  # ─────────────────────────────────────────────────────────────────────────────
278
 
279
 
280
- def _t2_user(obs: Dict[str, Any], steps_left: int) -> str:
281
- extra = obs.get("extra", {})
282
- result = _truncate(obs.get("last_action_result") or "Episode just started.")
283
- return (
284
- f"Contract: {obs['contract_name']} | "
285
- f"Function: {extra.get('target_function','?')} ({extra.get('target_signature','')})\n"
286
- f"Step {obs['step_count']} | Reward: {obs['cumulative_reward']:.2f}\n"
287
- f"Last action: {obs['last_action'] or 'None'}\n"
288
- f"Result: {result}"
289
- + _pressure_suffix(steps_left)
290
- )
291
-
292
-
293
- def _t2_force_submit(obs: Dict[str, Any], history: List[Dict[str, str]]) -> tuple[ActionType, Dict[str, Any]]:
294
- """Force a submit_property based on everything seen so far."""
295
  extra = obs.get("extra", {})
296
- fn = extra.get("target_function", "this function")
297
- mandate = (
298
- f"You must now submit your best property for '{fn}'.\n"
299
- "Write 2-3 sentences covering: what state changes, what is transferred, revert conditions.\n"
300
- "Respond ONLY with:\n"
301
- '{"action":"submit_property","params":{"property":"<your property here>"}}'
 
302
  )
303
- messages = _sliding_messages(T2_SYSTEM, history) + [{"role": "user", "content": mandate}]
304
- raw, _ = _call_llm(messages, max_tokens=200)
305
- at, params = _parse_action(raw, ActionType.SUBMIT_PROPERTY, {})
306
- if at != ActionType.SUBMIT_PROPERTY or not params.get("property", "").strip():
307
- at = ActionType.SUBMIT_PROPERTY
308
- params = {"property": (
309
- f"After a successful call to {fn}, the contract updates its internal state "
310
- f"according to the function's logic. Reverts if input conditions are not met."
311
- )}
312
- return at, params
313
 
314
 
315
  def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str, Any]:
 
316
  r = env.reset(seed=seed)
317
  obs = r.observation.model_dump()
318
  fn = obs["extra"].get("target_function", "?")
 
319
  log_start(task="task2_property_discovery", env=ENV_BENCHMARK, model=MODEL_NAME) # type: ignore
320
 
321
- history: List[Dict[str, str]] = []
 
 
322
  step_rewards: List[float] = []
323
  grader_score = 0.0
324
  steps_taken = 0
@@ -326,23 +224,31 @@ def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str,
326
 
327
  try:
328
  for step in range(1, MAX_STEPS_T2 + 1):
329
- steps_left = MAX_STEPS_T2 - step
330
- is_last = (step == MAX_STEPS_T2)
331
-
332
- if is_last:
333
- at, params = _t2_force_submit(obs, history)
334
- else:
335
- user_msg = _t2_user(obs, steps_left)
336
- history.append({"role": "user", "content": user_msg})
337
- messages = _sliding_messages(T2_SYSTEM, history)
338
- raw, error_msg = _call_llm(messages, max_tokens=250)
339
- history.append({"role": "assistant", "content": raw})
340
- at, params = _parse_action(raw, ActionType.GET_FUNCTION_NATSPEC, {})
341
-
342
- result = env.step(Action(action_type=at, params=params))
343
- obs = result.observation.model_dump()
344
- r_val = result.reward.value
345
- done = result.done
 
 
 
 
 
 
 
 
346
 
347
  step_rewards.append(r_val)
348
  steps_taken = step
@@ -352,64 +258,48 @@ def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str,
352
  grader_score = round(r_val / 5.0, 3) if r_val > 0 else 0.0
353
  break
354
 
355
- if not is_last:
356
- time.sleep(0.5)
357
 
358
  finally:
359
  success = grader_score >= SUCCESS_SCORE_THRESHOLD
360
  log_end(success=success, steps=steps_taken, score=grader_score, rewards=step_rewards)
361
 
362
- return {"episode": ep_num, "seed": seed,
363
- "contract": obs["contract_name"], "function": fn,
364
- "grader_score": grader_score, "cumulative_reward": obs["cumulative_reward"]}
 
 
 
 
 
365
 
366
 
367
  # ─────────────────────────────────────────────────────────────────────────────
368
  # Task 3 β€” Rule Checker
369
  # ─────────────────────────────────────────────────────────────────────────────
370
 
371
- def _t3_user(obs: Dict[str, Any], steps_left: int) -> str:
372
- extra = obs.get("extra", {})
373
- result = _truncate(obs.get("last_action_result") or "Episode just started.")
374
- return (
375
- f"Contract: {obs['contract_name']}\n"
376
- f"Property: {extra.get('property_english', '(none)')[:200]}\n"
377
- f"Step {obs['step_count']} | Reward: {obs['cumulative_reward']:.2f}\n"
378
- f"Last action: {obs['last_action'] or 'None'}\n"
379
- f"Result: {result}"
380
- + _pressure_suffix(steps_left)
381
- )
382
-
383
 
384
- def _t3_force_submit(obs: Dict[str, Any], history: List[Dict[str, str]]) -> tuple[ActionType, Dict[str, Any]]:
385
- """Force a submit_function based on everything seen so far."""
386
- prop = obs.get("extra", {}).get("property_english", "")
387
- mandate = (
388
- f"Property: {prop[:200]}\n"
389
- "Based on everything you have seen, which function violates this property?\n"
390
- "Respond ONLY with:\n"
391
- '{"action":"submit_function","params":{"function_name":"<your_best_guess>"}}'
392
  )
393
- messages = _sliding_messages(T3_SYSTEM, history) + [{"role": "user", "content": mandate}]
394
- raw, _ = _call_llm(messages, max_tokens=80)
395
- at, params = _parse_action(raw, ActionType.SUBMIT_FUNCTION, {})
396
- if at != ActionType.SUBMIT_FUNCTION or not params.get("function_name", "").strip():
397
- # Heuristic fallback: scan property text for a function name mention
398
- fn_candidates = ["withdraw", "emergencyDrain", "buyTokens", "setPrice",
399
- "bid", "finalize", "stake", "claimRewards"]
400
- prop_lower = prop.lower()
401
- chosen = next((fn for fn in fn_candidates if fn.lower() in prop_lower), "withdraw")
402
- at = ActionType.SUBMIT_FUNCTION
403
- params = {"function_name": chosen}
404
- return at, params
405
 
406
 
407
  def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str, Any]:
 
408
  r = env.reset(seed=seed)
409
  obs = r.observation.model_dump()
 
410
  log_start(task="task3_rule_checker", env=ENV_BENCHMARK, model=MODEL_NAME) # type: ignore
411
 
412
- history: List[Dict[str, str]] = []
 
 
413
  step_rewards: List[float] = []
414
  grader_score = 0.0
415
  steps_taken = 0
@@ -417,23 +307,31 @@ def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str,
417
 
418
  try:
419
  for step in range(1, MAX_STEPS_T3 + 1):
420
- steps_left = MAX_STEPS_T3 - step
421
- is_last = (step == MAX_STEPS_T3)
422
-
423
- if is_last:
424
- at, params = _t3_force_submit(obs, history)
425
- else:
426
- user_msg = _t3_user(obs, steps_left)
427
- history.append({"role": "user", "content": user_msg})
428
- messages = _sliding_messages(T3_SYSTEM, history)
429
- raw, error_msg = _call_llm(messages)
430
- history.append({"role": "assistant", "content": raw})
431
- at, params = _parse_action(raw, ActionType.GET_PROPERTY_SPECIFICATION, {})
432
-
433
- result = env.step(Action(action_type=at, params=params))
434
- obs = result.observation.model_dump()
435
- r_val = result.reward.value
436
- done = result.done
 
 
 
 
 
 
 
 
437
 
438
  step_rewards.append(r_val)
439
  steps_taken = step
@@ -441,18 +339,22 @@ def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str,
441
 
442
  if done:
443
  v = r_val
444
- grader_score = 0.999 if v >= 4.9 else (0.3 if v >= 0.999 else 0.0)
445
  break
446
 
447
- if not is_last:
448
- time.sleep(0.5)
449
 
450
  finally:
451
  success = grader_score >= SUCCESS_SCORE_THRESHOLD
452
  log_end(success=success, steps=steps_taken, score=grader_score, rewards=step_rewards)
453
 
454
- return {"episode": ep_num, "seed": seed, "contract": obs["contract_name"],
455
- "grader_score": grader_score, "cumulative_reward": obs["cumulative_reward"]}
 
 
 
 
 
456
 
457
 
458
  # ─────────────────────────────────────────────────────────────────────────────
@@ -469,9 +371,11 @@ def run_task1(n: int = NUM_EPISODES) -> Dict[str, Any]:
469
  avg_r = sum(e["cumulative_reward"] for e in episodes) / n
470
  print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
471
  print(f" Avg cum reward : {avg_r:.2f}", flush=True)
472
- return {"task_id": "task1_vuln_detection", "name": "Targeted Vulnerability Detection",
473
- "status": "active", "num_episodes": n, "episodes": episodes,
474
- "avg_grader_score": avg_s, "avg_cumulative_reward": avg_r}
 
 
475
 
476
 
477
  def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
@@ -484,9 +388,11 @@ def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
484
  avg_r = sum(e["cumulative_reward"] for e in episodes) / n
485
  print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
486
  print(f" Avg cum reward : {avg_r:.2f}", flush=True)
487
- return {"task_id": "task2_property_discovery", "name": "Property Discovery",
488
- "status": "active", "num_episodes": n, "episodes": episodes,
489
- "avg_grader_score": avg_s, "avg_cumulative_reward": avg_r}
 
 
490
 
491
 
492
  def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
@@ -499,9 +405,11 @@ def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
499
  avg_r = sum(e["cumulative_reward"] for e in episodes) / n
500
  print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
501
  print(f" Avg cum reward : {avg_r:.2f}", flush=True)
502
- return {"task_id": "task3_rule_checker", "name": "Rule Checker",
503
- "status": "active", "num_episodes": n, "episodes": episodes,
504
- "avg_grader_score": avg_s, "avg_cumulative_reward": avg_r}
 
 
505
 
506
 
507
  # ─────────────────────────────────────────────────────────────────────────────
@@ -509,21 +417,18 @@ def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
509
  # ─────────────────────────────────────────────────────────────────────────────
510
 
511
  async def main() -> None:
 
512
  print("Smart Contract Audit RL Environment β€” Baseline Inference", flush=True)
513
- print(f"Model : {MODEL_NAME}", flush=True)
514
- print(f"Base URL : {API_BASE_URL}", flush=True)
515
- print(f"Episodes : {NUM_EPISODES} per task | "
516
- f"Max steps: T1={MAX_STEPS_T1} T2={MAX_STEPS_T2} T3={MAX_STEPS_T3}", flush=True)
517
- print(f"Hist window: last {HISTORY_WINDOW} exchanges | "
518
- f"Result truncation: {MAX_RESULT_CHARS} chars", flush=True)
519
 
520
  t1 = run_task1(NUM_EPISODES)
521
  t2 = run_task2(NUM_EPISODES)
522
  t3 = run_task3(NUM_EPISODES)
523
 
524
  results = {
525
- "model": MODEL_NAME, "base_url": API_BASE_URL,
526
- "tasks": [t1, t2, t3],
 
527
  }
528
  overall = sum(t["avg_grader_score"] for t in results["tasks"]) / 3
529
  results["overall_avg_score"] = overall
@@ -541,4 +446,4 @@ async def main() -> None:
541
 
542
 
543
  if __name__ == "__main__":
544
- asyncio.run(main())
 
35
  from server import Task1Environment, Task2Environment, Task3Environment
36
  from env.schemas import Action, ActionType
37
  from utils import T1_SYSTEM, T2_SYSTEM, T3_SYSTEM
38
+ from dotenv import dotenv_values
39
 
40
  # ─────────────────────────────────────────────────────────────────────────────
41
  # Configuration
42
  # ─────────────────────────────────────────────────────────────────────────────
43
 
44
+ config = dotenv_values(".env")
45
+ API_BASE_URL = config.get("API_BASE_URL", "https://api.openai.com/v1")
46
+ MODEL_NAME = config.get("MODEL_NAME", "gpt-4o")
47
+ HF_TOKEN = config.get("HF_TOKEN", "")
48
 
49
  if not HF_TOKEN:
50
  print("[WARN] HF_TOKEN not set β€” API calls may fail.", file=sys.stderr)
51
  exit(1)
52
 
53
+ # Benchmark / environment identifier (constant for this env)
54
+ ENV_BENCHMARK = "smart-contract-audit"
 
55
 
56
+ # Episodes per task
57
+ NUM_EPISODES = 3
58
+ SEED_BASE = 42
 
 
59
 
60
+ # Max steps per task
61
+ MAX_STEPS_T1 = 15
62
+ MAX_STEPS_T2 = 10
63
+ MAX_STEPS_T3 = 12
64
 
65
+ # A grader_score >= this is considered a "success" for the [END] line
 
 
 
 
 
 
 
66
  SUCCESS_SCORE_THRESHOLD = 0.5
67
 
68
  client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
 
108
  )
109
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # ─────────────────────────────────────────────────────────────────────────────
112
  # Task 1 β€” Targeted Vulnerability Detection
113
  # ─────────────────────────────────────────────────────────────────────────────
114
 
115
+ def _t1_user_msg(obs: Dict[str, Any]) -> str:
 
116
  return (
117
+ f"Contract: {obs['contract_name']}\n"
118
+ f"Description: {obs['contract_description']}\n"
119
+ f"Step: {obs['step_count']} | Reward so far: {obs['cumulative_reward']:.2f}\n\n"
120
+ f"Last action : {obs['last_action'] or 'None'}\n"
121
+ f"Last result : {obs['last_action_result'] or 'Episode just started.'}"
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
 
125
  def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str, Any]:
126
+ """Run one Task 1 episode; emit [START]/[STEP]/[END]."""
127
  r = env.reset(seed=seed)
128
  obs = r.observation.model_dump()
129
+
130
  log_start(task="task1_vuln_detection", env=ENV_BENCHMARK, model=MODEL_NAME) # type: ignore
131
 
132
+ messages: List[ChatCompletionMessageParam] = [ # type: ignore
133
+ {"role": "system", "content": T1_SYSTEM}
134
+ ]
135
  step_rewards: List[float] = []
136
  grader_score = 0.0
137
  steps_taken = 0
 
139
 
140
  try:
141
  for step in range(1, MAX_STEPS_T1 + 1):
142
+ messages.append({"role": "user", "content": _t1_user_msg(obs)})
143
+ try:
144
+ resp = client.chat.completions.create(
145
+ model=MODEL_NAME, messages=messages, # type: ignore
146
+ max_tokens=200, temperature=0.0,
147
+ )
148
+ raw = resp.choices[0].message.content.strip() # type: ignore
149
+ error_msg = None
150
+ except Exception as e:
151
+ raw = ""
152
+ error_msg = str(e)[:80]
153
+ print(f"[DEBUG] T1 LLM error ep={ep_num} step={step}: {e}", file=sys.stderr)
154
+
155
+ try:
156
+ parsed = json.loads(raw)
157
+ at = ActionType(parsed["action"])
158
+ params = parsed.get("params", {})
159
+ except Exception:
160
+ at, params = ActionType.LIST_FUNCTIONS, {}
161
+
162
+ messages.append({"role": "assistant", "content": raw})
163
+ result = env.step(Action(action_type=at, params=params))
164
+ obs = result.observation.model_dump()
165
+ r_val = result.reward.value
166
+ done = result.done
167
 
168
  step_rewards.append(r_val)
169
  steps_taken = step
 
171
 
172
  if done:
173
  v = r_val
174
+ grader_score = 1.0 if v >= 4.9 else (0.5 if v >= 0.9 else 0.0)
175
  break
176
 
177
+ time.sleep(0.3)
 
178
 
179
  finally:
180
  success = grader_score >= SUCCESS_SCORE_THRESHOLD
181
  log_end(success=success, steps=steps_taken, score=grader_score, rewards=step_rewards)
182
 
183
+ return {
184
+ "episode": ep_num,
185
+ "seed": seed,
186
+ "contract": obs["contract_name"],
187
+ "grader_score": grader_score,
188
+ "cumulative_reward": obs["cumulative_reward"],
189
+ }
190
 
191
 
192
  # ─────────────────────────────────────────────────────────────────────────────
 
194
  # ─────────────────────────────────────────────────────────────────────────────
195
 
196
 
197
+ def _t2_user_msg(obs: Dict[str, Any]) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  extra = obs.get("extra", {})
199
+ return (
200
+ f"Contract : {obs['contract_name']}\n"
201
+ f"Function : {extra.get('target_function', '?')} "
202
+ f"({extra.get('target_signature', '')})\n"
203
+ f"Step: {obs['step_count']} | Reward so far: {obs['cumulative_reward']:.2f}\n\n"
204
+ f"Last action : {obs['last_action'] or 'None'}\n"
205
+ f"Last result :\n{obs['last_action_result'] or 'Episode just started.'}"
206
  )
 
 
 
 
 
 
 
 
 
 
207
 
208
 
209
  def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str, Any]:
210
+ """Run one Task 2 episode; emit [START]/[STEP]/[END]."""
211
  r = env.reset(seed=seed)
212
  obs = r.observation.model_dump()
213
  fn = obs["extra"].get("target_function", "?")
214
+
215
  log_start(task="task2_property_discovery", env=ENV_BENCHMARK, model=MODEL_NAME) # type: ignore
216
 
217
+ messages: List[ChatCompletionMessageParam] = [ # type: ignore
218
+ {"role": "system", "content": T2_SYSTEM}
219
+ ]
220
  step_rewards: List[float] = []
221
  grader_score = 0.0
222
  steps_taken = 0
 
224
 
225
  try:
226
  for step in range(1, MAX_STEPS_T2 + 1):
227
+ messages.append({"role": "user", "content": _t2_user_msg(obs)})
228
+ try:
229
+ resp = client.chat.completions.create(
230
+ model=MODEL_NAME, messages=messages, # type: ignore
231
+ max_tokens=400, temperature=0.0,
232
+ )
233
+ raw = resp.choices[0].message.content.strip() # type: ignore
234
+ error_msg = None
235
+ except Exception as e:
236
+ raw = ""
237
+ error_msg = str(e)[:80]
238
+ print(f"[DEBUG] T2 LLM error ep={ep_num} step={step}: {e}", file=sys.stderr)
239
+
240
+ try:
241
+ parsed = json.loads(raw)
242
+ at = ActionType(parsed["action"])
243
+ params = parsed.get("params", {})
244
+ except Exception:
245
+ at, params = ActionType.GET_FUNCTION_CODE, {}
246
+
247
+ messages.append({"role": "assistant", "content": raw})
248
+ result = env.step(Action(action_type=at, params=params))
249
+ obs = result.observation.model_dump()
250
+ r_val = result.reward.value
251
+ done = result.done
252
 
253
  step_rewards.append(r_val)
254
  steps_taken = step
 
258
  grader_score = round(r_val / 5.0, 3) if r_val > 0 else 0.0
259
  break
260
 
261
+ time.sleep(0.3)
 
262
 
263
  finally:
264
  success = grader_score >= SUCCESS_SCORE_THRESHOLD
265
  log_end(success=success, steps=steps_taken, score=grader_score, rewards=step_rewards)
266
 
267
+ return {
268
+ "episode": ep_num,
269
+ "seed": seed,
270
+ "contract": obs["contract_name"],
271
+ "function": fn,
272
+ "grader_score": grader_score,
273
+ "cumulative_reward": obs["cumulative_reward"],
274
+ }
275
 
276
 
277
  # ─────────────────────────────────────────────────────────────────────────────
278
  # Task 3 β€” Rule Checker
279
  # ─────────────────────────────────────────────────────────────────────────────
280
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
+ def _t3_user_msg(obs: Dict[str, Any]) -> str:
283
+ extra = obs.get("extra", {})
284
+ return (
285
+ f"Contract : {obs['contract_name']}\n"
286
+ f"Property : {extra.get('property_english', '(none)')}\n"
287
+ f"Step: {obs['step_count']} | Reward so far: {obs['cumulative_reward']:.2f}\n\n"
288
+ f"Last action : {obs['last_action'] or 'None'}\n"
289
+ f"Last result :\n{obs['last_action_result'] or 'Episode just started.'}"
290
  )
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
 
293
  def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str, Any]:
294
+ """Run one Task 3 episode; emit [START]/[STEP]/[END]."""
295
  r = env.reset(seed=seed)
296
  obs = r.observation.model_dump()
297
+
298
  log_start(task="task3_rule_checker", env=ENV_BENCHMARK, model=MODEL_NAME) # type: ignore
299
 
300
+ messages: List[ChatCompletionMessageParam] = [ # type: ignore
301
+ {"role": "system", "content": T3_SYSTEM}
302
+ ]
303
  step_rewards: List[float] = []
304
  grader_score = 0.0
305
  steps_taken = 0
 
307
 
308
  try:
309
  for step in range(1, MAX_STEPS_T3 + 1):
310
+ messages.append({"role": "user", "content": _t3_user_msg(obs)})
311
+ try:
312
+ resp = client.chat.completions.create(
313
+ model=MODEL_NAME, messages=messages, # type: ignore
314
+ max_tokens=200, temperature=0.0,
315
+ )
316
+ raw = resp.choices[0].message.content.strip() # type: ignore
317
+ error_msg = None
318
+ except Exception as e:
319
+ raw = ""
320
+ error_msg = str(e)[:80]
321
+ print(f"[DEBUG] T3 LLM error ep={ep_num} step={step}: {e}", file=sys.stderr)
322
+
323
+ try:
324
+ parsed = json.loads(raw)
325
+ at = ActionType(parsed["action"])
326
+ params = parsed.get("params", {})
327
+ except Exception:
328
+ at, params = ActionType.LIST_FUNCTIONS, {}
329
+
330
+ messages.append({"role": "assistant", "content": raw})
331
+ result = env.step(Action(action_type=at, params=params))
332
+ obs = result.observation.model_dump()
333
+ r_val = result.reward.value
334
+ done = result.done
335
 
336
  step_rewards.append(r_val)
337
  steps_taken = step
 
339
 
340
  if done:
341
  v = r_val
342
+ grader_score = 1.0 if v >= 4.9 else (0.3 if v >= 1.0 else 0.0)
343
  break
344
 
345
+ time.sleep(0.3)
 
346
 
347
  finally:
348
  success = grader_score >= SUCCESS_SCORE_THRESHOLD
349
  log_end(success=success, steps=steps_taken, score=grader_score, rewards=step_rewards)
350
 
351
+ return {
352
+ "episode": ep_num,
353
+ "seed": seed,
354
+ "contract": obs["contract_name"],
355
+ "grader_score": grader_score,
356
+ "cumulative_reward": obs["cumulative_reward"],
357
+ }
358
 
359
 
360
  # ─────────────────────────────────────────────────────────────────────────────
 
371
  avg_r = sum(e["cumulative_reward"] for e in episodes) / n
372
  print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
373
  print(f" Avg cum reward : {avg_r:.2f}", flush=True)
374
+ return {
375
+ "task_id": "task1_vuln_detection", "name": "Targeted Vulnerability Detection",
376
+ "status": "active", "num_episodes": n, "episodes": episodes,
377
+ "avg_grader_score": avg_s, "avg_cumulative_reward": avg_r,
378
+ }
379
 
380
 
381
  def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
 
388
  avg_r = sum(e["cumulative_reward"] for e in episodes) / n
389
  print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
390
  print(f" Avg cum reward : {avg_r:.2f}", flush=True)
391
+ return {
392
+ "task_id": "task2_property_discovery", "name": "Property Discovery",
393
+ "status": "active", "num_episodes": n, "episodes": episodes,
394
+ "avg_grader_score": avg_s, "avg_cumulative_reward": avg_r,
395
+ }
396
 
397
 
398
  def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
 
405
  avg_r = sum(e["cumulative_reward"] for e in episodes) / n
406
  print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
407
  print(f" Avg cum reward : {avg_r:.2f}", flush=True)
408
+ return {
409
+ "task_id": "task3_rule_checker", "name": "Rule Checker",
410
+ "status": "active", "num_episodes": n, "episodes": episodes,
411
+ "avg_grader_score": avg_s, "avg_cumulative_reward": avg_r,
412
+ }
413
 
414
 
415
  # ─────────────────────────────────────────────────────────────────────────────
 
417
  # ─────────────────────────────────────────────────────────────────────────────
418
 
419
  async def main() -> None:
420
+ """Async entry point (wraps sync env calls; asyncio.run() expected by caller)."""
421
  print("Smart Contract Audit RL Environment β€” Baseline Inference", flush=True)
422
+ print(f"Model: {MODEL_NAME} | Base URL: {API_BASE_URL}", flush=True)
 
 
 
 
 
423
 
424
  t1 = run_task1(NUM_EPISODES)
425
  t2 = run_task2(NUM_EPISODES)
426
  t3 = run_task3(NUM_EPISODES)
427
 
428
  results = {
429
+ "model": MODEL_NAME,
430
+ "base_url": API_BASE_URL,
431
+ "tasks": [t1, t2, t3],
432
  }
433
  overall = sum(t["avg_grader_score"] for t in results["tasks"]) / 3
434
  results["overall_avg_score"] = overall
 
446
 
447
 
448
  if __name__ == "__main__":
449
+ asyncio.run(main())