ajaxwin commited on
Commit
1248d28
Β·
1 Parent(s): 2171069

refactor: Update inference.py to use AsyncOpenAI and make episode functions asynchronous

Browse files
Files changed (1) hide show
  1. inference.py +24 -34
inference.py CHANGED
@@ -29,7 +29,7 @@ import sys
29
  import time
30
  from typing import Any, Dict, List, Optional
31
 
32
- from openai import OpenAI
33
 
34
  from server import Task1Environment, Task2Environment, Task3Environment
35
  from env.schemas import Action, ActionType
@@ -48,7 +48,7 @@ HF_TOKEN = os.getenv("HF_TOKEN", "")
48
  if not HF_TOKEN:
49
  raise RuntimeError("HF_TOKEN environment variable not set")
50
 
51
- client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
52
 
53
  # Benchmark / environment identifier (constant for this env)
54
  ENV_BENCHMARK = "smart-contract-audit"
@@ -69,17 +69,17 @@ SUCCESS_SCORE_THRESHOLD = 0.5
69
  # Unified LLM call function
70
  # ─────────────────────────────────────────────────────────────────────────────
71
 
72
- def get_llm_response(
73
  messages: List[Dict[str, str]],
74
  max_tokens: int = 200,
75
  temperature: float = 0.0,
76
  ) -> str:
77
  """
78
- Call the Groq LLM with the given messages and parameters.
79
  Returns the response content as a string.
80
  Raises an exception on failure (to be caught by the caller).
81
  """
82
- completion = client.chat.completions.create(
83
  model=MODEL_NAME,
84
  messages=messages, # type: ignore
85
  )
@@ -137,7 +137,7 @@ def _t1_user_msg(obs: Dict[str, Any]) -> str:
137
  )
138
 
139
 
140
- def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str, Any]:
141
  """Run one Task 1 episode; emit [START]/[STEP]/[END]."""
142
  r = env.reset(seed=seed)
143
  obs = r.observation.model_dump()
@@ -156,7 +156,7 @@ def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str,
156
  for step in range(1, MAX_STEPS_T1 + 1):
157
  messages.append({"role": "user", "content": _t1_user_msg(obs)})
158
  try:
159
- raw = get_llm_response(messages, max_tokens=200, temperature=0.0)
160
  error_msg = None
161
  except Exception as e:
162
  raw = ""
@@ -194,8 +194,7 @@ def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str,
194
  "episode": ep_num,
195
  "seed": seed,
196
  "contract": obs["contract_name"],
197
- "grader_score": grader_score,
198
- "cumulative_reward": obs["cumulative_reward"],
199
  }
200
 
201
 
@@ -214,7 +213,7 @@ def _t2_user_msg(obs: Dict[str, Any]) -> str:
214
  )
215
 
216
 
217
- def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str, Any]:
218
  """Run one Task 2 episode; emit [START]/[STEP]/[END]."""
219
  r = env.reset(seed=seed)
220
  obs = r.observation.model_dump()
@@ -234,7 +233,7 @@ def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str,
234
  for step in range(1, MAX_STEPS_T2 + 1):
235
  messages.append({"role": "user", "content": _t2_user_msg(obs)})
236
  try:
237
- raw = get_llm_response(messages, max_tokens=400, temperature=0.0)
238
  error_msg = None
239
  except Exception as e:
240
  raw = ""
@@ -273,8 +272,7 @@ def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str,
273
  "seed": seed,
274
  "contract": obs["contract_name"],
275
  "function": fn,
276
- "grader_score": grader_score,
277
- "cumulative_reward": obs["cumulative_reward"],
278
  }
279
 
280
 
@@ -292,7 +290,7 @@ def _t3_user_msg(obs: Dict[str, Any]) -> str:
292
  )
293
 
294
 
295
- def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str, Any]:
296
  """Run one Task 3 episode; emit [START]/[STEP]/[END]."""
297
  r = env.reset(seed=seed)
298
  obs = r.observation.model_dump()
@@ -311,7 +309,7 @@ def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str,
311
  for step in range(1, MAX_STEPS_T3 + 1):
312
  messages.append({"role": "user", "content": _t3_user_msg(obs)})
313
  try:
314
- raw = get_llm_response(messages, max_tokens=200, temperature=0.0)
315
  error_msg = None
316
  except Exception as e:
317
  raw = ""
@@ -349,8 +347,7 @@ def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str,
349
  "episode": ep_num,
350
  "seed": seed,
351
  "contract": obs["contract_name"],
352
- "grader_score": grader_score,
353
- "cumulative_reward": obs["cumulative_reward"],
354
  }
355
 
356
 
@@ -358,12 +355,12 @@ def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str,
358
  # Task runners
359
  # ─────────────────────────────────────────────────────────────────────────────
360
 
361
- def run_task1(n: int = NUM_EPISODES) -> Dict[str, Any]:
362
  print("\n" + "="*60, flush=True)
363
  print("TASK 1: Targeted Vulnerability Detection", flush=True)
364
  print("="*60, flush=True)
365
  env = Task1Environment()
366
- episodes = [_run_t1_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
367
  avg_s = sum(e["grader_score"] for e in episodes) / n
368
  print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
369
  return {
@@ -373,14 +370,13 @@ def run_task1(n: int = NUM_EPISODES) -> Dict[str, Any]:
373
  }
374
 
375
 
376
- def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
377
  print("\n" + "="*60, flush=True)
378
  print("TASK 2: Property Discovery", flush=True)
379
  print("="*60, flush=True)
380
  env = Task2Environment()
381
- episodes = [_run_t2_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
382
  avg_s = sum(e["grader_score"] for e in episodes) / n
383
- avg_r = sum(e["cumulative_reward"] for e in episodes) / n
384
  print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
385
  return {
386
  "task_id": "task2_property_discovery", "name": "Property Discovery",
@@ -389,14 +385,13 @@ def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
389
  }
390
 
391
 
392
- def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
393
  print("\n" + "="*60, flush=True)
394
  print("TASK 3: Rule Checker", flush=True)
395
  print("="*60, flush=True)
396
  env = Task3Environment()
397
- episodes = [_run_t3_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
398
  avg_s = sum(e["grader_score"] for e in episodes) / n
399
- avg_r = sum(e["cumulative_reward"] for e in episodes) / n
400
  print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
401
  return {
402
  "task_id": "task3_rule_checker", "name": "Rule Checker",
@@ -412,17 +407,12 @@ def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
412
  async def main() -> None:
413
  """Async entry point (wraps sync env calls; asyncio.run() expected by caller)."""
414
  print("Smart Contract Audit RL Environment β€” Baseline Inference", flush=True)
415
- print(f"Model: {MODEL_NAME} | Groq API", flush=True)
416
 
417
- t1 = run_task1(NUM_EPISODES)
418
- t2 = run_task2(NUM_EPISODES)
419
- t3 = run_task3(NUM_EPISODES)
420
 
421
- results = {
422
- "model": MODEL_NAME,
423
- "backend": "groq",
424
- "tasks": [t1, t2, t3],
425
- }
426
  overall = sum(t["avg_grader_score"] for t in results["tasks"]) / 3
427
  results["overall_avg_score"] = overall
428
 
 
29
  import time
30
  from typing import Any, Dict, List, Optional
31
 
32
+ from openai import AsyncOpenAI
33
 
34
  from server import Task1Environment, Task2Environment, Task3Environment
35
  from env.schemas import Action, ActionType
 
48
  if not HF_TOKEN:
49
  raise RuntimeError("HF_TOKEN environment variable not set")
50
 
51
+ client = AsyncOpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
52
 
53
  # Benchmark / environment identifier (constant for this env)
54
  ENV_BENCHMARK = "smart-contract-audit"
 
69
  # Unified LLM call function
70
  # ─────────────────────────────────────────────────────────────────────────────
71
 
72
+ async def get_llm_response(
73
  messages: List[Dict[str, str]],
74
  max_tokens: int = 200,
75
  temperature: float = 0.0,
76
  ) -> str:
77
  """
78
+ Call the LLM with the given messages and parameters.
79
  Returns the response content as a string.
80
  Raises an exception on failure (to be caught by the caller).
81
  """
82
+ completion = await client.chat.completions.create(
83
  model=MODEL_NAME,
84
  messages=messages, # type: ignore
85
  )
 
137
  )
138
 
139
 
140
+ async def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str, Any]:
141
  """Run one Task 1 episode; emit [START]/[STEP]/[END]."""
142
  r = env.reset(seed=seed)
143
  obs = r.observation.model_dump()
 
156
  for step in range(1, MAX_STEPS_T1 + 1):
157
  messages.append({"role": "user", "content": _t1_user_msg(obs)})
158
  try:
159
+ raw = await get_llm_response(messages, max_tokens=200, temperature=0.0)
160
  error_msg = None
161
  except Exception as e:
162
  raw = ""
 
194
  "episode": ep_num,
195
  "seed": seed,
196
  "contract": obs["contract_name"],
197
+ "grader_score": grader_score
 
198
  }
199
 
200
 
 
213
  )
214
 
215
 
216
+ async def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str, Any]:
217
  """Run one Task 2 episode; emit [START]/[STEP]/[END]."""
218
  r = env.reset(seed=seed)
219
  obs = r.observation.model_dump()
 
233
  for step in range(1, MAX_STEPS_T2 + 1):
234
  messages.append({"role": "user", "content": _t2_user_msg(obs)})
235
  try:
236
+ raw = await get_llm_response(messages, max_tokens=400, temperature=0.0)
237
  error_msg = None
238
  except Exception as e:
239
  raw = ""
 
272
  "seed": seed,
273
  "contract": obs["contract_name"],
274
  "function": fn,
275
+ "grader_score": grader_score
 
276
  }
277
 
278
 
 
290
  )
291
 
292
 
293
+ async def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str, Any]:
294
  """Run one Task 3 episode; emit [START]/[STEP]/[END]."""
295
  r = env.reset(seed=seed)
296
  obs = r.observation.model_dump()
 
309
  for step in range(1, MAX_STEPS_T3 + 1):
310
  messages.append({"role": "user", "content": _t3_user_msg(obs)})
311
  try:
312
+ raw = await get_llm_response(messages, max_tokens=200, temperature=0.0)
313
  error_msg = None
314
  except Exception as e:
315
  raw = ""
 
347
  "episode": ep_num,
348
  "seed": seed,
349
  "contract": obs["contract_name"],
350
+ "grader_score": grader_score
 
351
  }
352
 
353
 
 
355
  # Task runners
356
  # ─────────────────────────────────────────────────────────────────────────────
357
 
358
+ async def run_task1(n: int = NUM_EPISODES) -> Dict[str, Any]:
359
  print("\n" + "="*60, flush=True)
360
  print("TASK 1: Targeted Vulnerability Detection", flush=True)
361
  print("="*60, flush=True)
362
  env = Task1Environment()
363
+ episodes = [await _run_t1_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
364
  avg_s = sum(e["grader_score"] for e in episodes) / n
365
  print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
366
  return {
 
370
  }
371
 
372
 
373
+ async def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
374
  print("\n" + "="*60, flush=True)
375
  print("TASK 2: Property Discovery", flush=True)
376
  print("="*60, flush=True)
377
  env = Task2Environment()
378
+ episodes = [await _run_t2_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
379
  avg_s = sum(e["grader_score"] for e in episodes) / n
 
380
  print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
381
  return {
382
  "task_id": "task2_property_discovery", "name": "Property Discovery",
 
385
  }
386
 
387
 
388
+ async def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
389
  print("\n" + "="*60, flush=True)
390
  print("TASK 3: Rule Checker", flush=True)
391
  print("="*60, flush=True)
392
  env = Task3Environment()
393
+ episodes = [await _run_t3_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
394
  avg_s = sum(e["grader_score"] for e in episodes) / n
 
395
  print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
396
  return {
397
  "task_id": "task3_rule_checker", "name": "Rule Checker",
 
407
  async def main() -> None:
408
  """Async entry point (wraps sync env calls; asyncio.run() expected by caller)."""
409
  print("Smart Contract Audit RL Environment β€” Baseline Inference", flush=True)
 
410
 
411
+ t1 = await run_task1(NUM_EPISODES)
412
+ t2 = await run_task2(NUM_EPISODES)
413
+ t3 = await run_task3(NUM_EPISODES)
414
 
415
+ results: Dict[str, Any] = { "tasks": [t1, t2, t3] }
 
 
 
 
416
  overall = sum(t["avg_grader_score"] for t in results["tasks"]) / 3
417
  results["overall_avg_score"] = overall
418