Spaces:
Sleeping
Sleeping
ajaxwin committed on
Commit ·
1248d28
1
Parent(s): 2171069
refactor: Update inference.py to use AsyncOpenAI and make episode functions asynchronous
Browse files- inference.py +24 -34
inference.py
CHANGED
|
@@ -29,7 +29,7 @@ import sys
|
|
| 29 |
import time
|
| 30 |
from typing import Any, Dict, List, Optional
|
| 31 |
|
| 32 |
-
from openai import
|
| 33 |
|
| 34 |
from server import Task1Environment, Task2Environment, Task3Environment
|
| 35 |
from env.schemas import Action, ActionType
|
|
@@ -48,7 +48,7 @@ HF_TOKEN = os.getenv("HF_TOKEN", "")
|
|
| 48 |
if not HF_TOKEN:
|
| 49 |
raise RuntimeError("HF_TOKEN environment variable not set")
|
| 50 |
|
| 51 |
-
client =
|
| 52 |
|
| 53 |
# Benchmark / environment identifier (constant for this env)
|
| 54 |
ENV_BENCHMARK = "smart-contract-audit"
|
|
@@ -69,17 +69,17 @@ SUCCESS_SCORE_THRESHOLD = 0.5
|
|
| 69 |
# Unified LLM call function
|
| 70 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 71 |
|
| 72 |
-
def get_llm_response(
|
| 73 |
messages: List[Dict[str, str]],
|
| 74 |
max_tokens: int = 200,
|
| 75 |
temperature: float = 0.0,
|
| 76 |
) -> str:
|
| 77 |
"""
|
| 78 |
-
Call the
|
| 79 |
Returns the response content as a string.
|
| 80 |
Raises an exception on failure (to be caught by the caller).
|
| 81 |
"""
|
| 82 |
-
completion = client.chat.completions.create(
|
| 83 |
model=MODEL_NAME,
|
| 84 |
messages=messages, # type: ignore
|
| 85 |
)
|
|
@@ -137,7 +137,7 @@ def _t1_user_msg(obs: Dict[str, Any]) -> str:
|
|
| 137 |
)
|
| 138 |
|
| 139 |
|
| 140 |
-
def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str, Any]:
|
| 141 |
"""Run one Task 1 episode; emit [START]/[STEP]/[END]."""
|
| 142 |
r = env.reset(seed=seed)
|
| 143 |
obs = r.observation.model_dump()
|
|
@@ -156,7 +156,7 @@ def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str,
|
|
| 156 |
for step in range(1, MAX_STEPS_T1 + 1):
|
| 157 |
messages.append({"role": "user", "content": _t1_user_msg(obs)})
|
| 158 |
try:
|
| 159 |
-
raw = get_llm_response(messages, max_tokens=200, temperature=0.0)
|
| 160 |
error_msg = None
|
| 161 |
except Exception as e:
|
| 162 |
raw = ""
|
|
@@ -194,8 +194,7 @@ def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str,
|
|
| 194 |
"episode": ep_num,
|
| 195 |
"seed": seed,
|
| 196 |
"contract": obs["contract_name"],
|
| 197 |
-
"grader_score": grader_score
|
| 198 |
-
"cumulative_reward": obs["cumulative_reward"],
|
| 199 |
}
|
| 200 |
|
| 201 |
|
|
@@ -214,7 +213,7 @@ def _t2_user_msg(obs: Dict[str, Any]) -> str:
|
|
| 214 |
)
|
| 215 |
|
| 216 |
|
| 217 |
-
def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str, Any]:
|
| 218 |
"""Run one Task 2 episode; emit [START]/[STEP]/[END]."""
|
| 219 |
r = env.reset(seed=seed)
|
| 220 |
obs = r.observation.model_dump()
|
|
@@ -234,7 +233,7 @@ def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str,
|
|
| 234 |
for step in range(1, MAX_STEPS_T2 + 1):
|
| 235 |
messages.append({"role": "user", "content": _t2_user_msg(obs)})
|
| 236 |
try:
|
| 237 |
-
raw = get_llm_response(messages, max_tokens=400, temperature=0.0)
|
| 238 |
error_msg = None
|
| 239 |
except Exception as e:
|
| 240 |
raw = ""
|
|
@@ -273,8 +272,7 @@ def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str,
|
|
| 273 |
"seed": seed,
|
| 274 |
"contract": obs["contract_name"],
|
| 275 |
"function": fn,
|
| 276 |
-
"grader_score": grader_score
|
| 277 |
-
"cumulative_reward": obs["cumulative_reward"],
|
| 278 |
}
|
| 279 |
|
| 280 |
|
|
@@ -292,7 +290,7 @@ def _t3_user_msg(obs: Dict[str, Any]) -> str:
|
|
| 292 |
)
|
| 293 |
|
| 294 |
|
| 295 |
-
def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str, Any]:
|
| 296 |
"""Run one Task 3 episode; emit [START]/[STEP]/[END]."""
|
| 297 |
r = env.reset(seed=seed)
|
| 298 |
obs = r.observation.model_dump()
|
|
@@ -311,7 +309,7 @@ def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str,
|
|
| 311 |
for step in range(1, MAX_STEPS_T3 + 1):
|
| 312 |
messages.append({"role": "user", "content": _t3_user_msg(obs)})
|
| 313 |
try:
|
| 314 |
-
raw = get_llm_response(messages, max_tokens=200, temperature=0.0)
|
| 315 |
error_msg = None
|
| 316 |
except Exception as e:
|
| 317 |
raw = ""
|
|
@@ -349,8 +347,7 @@ def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str,
|
|
| 349 |
"episode": ep_num,
|
| 350 |
"seed": seed,
|
| 351 |
"contract": obs["contract_name"],
|
| 352 |
-
"grader_score": grader_score
|
| 353 |
-
"cumulative_reward": obs["cumulative_reward"],
|
| 354 |
}
|
| 355 |
|
| 356 |
|
|
@@ -358,12 +355,12 @@ def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str,
|
|
| 358 |
# Task runners
|
| 359 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 360 |
|
| 361 |
-
def run_task1(n: int = NUM_EPISODES) -> Dict[str, Any]:
|
| 362 |
print("\n" + "="*60, flush=True)
|
| 363 |
print("TASK 1: Targeted Vulnerability Detection", flush=True)
|
| 364 |
print("="*60, flush=True)
|
| 365 |
env = Task1Environment()
|
| 366 |
-
episodes = [_run_t1_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
|
| 367 |
avg_s = sum(e["grader_score"] for e in episodes) / n
|
| 368 |
print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
|
| 369 |
return {
|
|
@@ -373,14 +370,13 @@ def run_task1(n: int = NUM_EPISODES) -> Dict[str, Any]:
|
|
| 373 |
}
|
| 374 |
|
| 375 |
|
| 376 |
-
def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
|
| 377 |
print("\n" + "="*60, flush=True)
|
| 378 |
print("TASK 2: Property Discovery", flush=True)
|
| 379 |
print("="*60, flush=True)
|
| 380 |
env = Task2Environment()
|
| 381 |
-
episodes = [_run_t2_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
|
| 382 |
avg_s = sum(e["grader_score"] for e in episodes) / n
|
| 383 |
-
avg_r = sum(e["cumulative_reward"] for e in episodes) / n
|
| 384 |
print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
|
| 385 |
return {
|
| 386 |
"task_id": "task2_property_discovery", "name": "Property Discovery",
|
|
@@ -389,14 +385,13 @@ def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
|
|
| 389 |
}
|
| 390 |
|
| 391 |
|
| 392 |
-
def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
|
| 393 |
print("\n" + "="*60, flush=True)
|
| 394 |
print("TASK 3: Rule Checker", flush=True)
|
| 395 |
print("="*60, flush=True)
|
| 396 |
env = Task3Environment()
|
| 397 |
-
episodes = [_run_t3_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
|
| 398 |
avg_s = sum(e["grader_score"] for e in episodes) / n
|
| 399 |
-
avg_r = sum(e["cumulative_reward"] for e in episodes) / n
|
| 400 |
print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
|
| 401 |
return {
|
| 402 |
"task_id": "task3_rule_checker", "name": "Rule Checker",
|
|
@@ -412,17 +407,12 @@ def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
|
|
| 412 |
async def main() -> None:
|
| 413 |
"""Async entry point (wraps sync env calls; asyncio.run() expected by caller)."""
|
| 414 |
print("Smart Contract Audit RL Environment — Baseline Inference", flush=True)
|
| 415 |
-
print(f"Model: {MODEL_NAME} | Groq API", flush=True)
|
| 416 |
|
| 417 |
-
t1 = run_task1(NUM_EPISODES)
|
| 418 |
-
t2 = run_task2(NUM_EPISODES)
|
| 419 |
-
t3 = run_task3(NUM_EPISODES)
|
| 420 |
|
| 421 |
-
results = {
|
| 422 |
-
"model": MODEL_NAME,
|
| 423 |
-
"backend": "groq",
|
| 424 |
-
"tasks": [t1, t2, t3],
|
| 425 |
-
}
|
| 426 |
overall = sum(t["avg_grader_score"] for t in results["tasks"]) / 3
|
| 427 |
results["overall_avg_score"] = overall
|
| 428 |
|
|
|
|
| 29 |
import time
|
| 30 |
from typing import Any, Dict, List, Optional
|
| 31 |
|
| 32 |
+
from openai import AsyncOpenAI
|
| 33 |
|
| 34 |
from server import Task1Environment, Task2Environment, Task3Environment
|
| 35 |
from env.schemas import Action, ActionType
|
|
|
|
| 48 |
if not HF_TOKEN:
|
| 49 |
raise RuntimeError("HF_TOKEN environment variable not set")
|
| 50 |
|
| 51 |
+
client = AsyncOpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
|
| 52 |
|
| 53 |
# Benchmark / environment identifier (constant for this env)
|
| 54 |
ENV_BENCHMARK = "smart-contract-audit"
|
|
|
|
| 69 |
# Unified LLM call function
|
| 70 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 71 |
|
| 72 |
+
async def get_llm_response(
|
| 73 |
messages: List[Dict[str, str]],
|
| 74 |
max_tokens: int = 200,
|
| 75 |
temperature: float = 0.0,
|
| 76 |
) -> str:
|
| 77 |
"""
|
| 78 |
+
Call the LLM with the given messages and parameters.
|
| 79 |
Returns the response content as a string.
|
| 80 |
Raises an exception on failure (to be caught by the caller).
|
| 81 |
"""
|
| 82 |
+
completion = await client.chat.completions.create(
|
| 83 |
model=MODEL_NAME,
|
| 84 |
messages=messages, # type: ignore
|
| 85 |
)
|
|
|
|
| 137 |
)
|
| 138 |
|
| 139 |
|
| 140 |
+
async def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str, Any]:
|
| 141 |
"""Run one Task 1 episode; emit [START]/[STEP]/[END]."""
|
| 142 |
r = env.reset(seed=seed)
|
| 143 |
obs = r.observation.model_dump()
|
|
|
|
| 156 |
for step in range(1, MAX_STEPS_T1 + 1):
|
| 157 |
messages.append({"role": "user", "content": _t1_user_msg(obs)})
|
| 158 |
try:
|
| 159 |
+
raw = await get_llm_response(messages, max_tokens=200, temperature=0.0)
|
| 160 |
error_msg = None
|
| 161 |
except Exception as e:
|
| 162 |
raw = ""
|
|
|
|
| 194 |
"episode": ep_num,
|
| 195 |
"seed": seed,
|
| 196 |
"contract": obs["contract_name"],
|
| 197 |
+
"grader_score": grader_score
|
|
|
|
| 198 |
}
|
| 199 |
|
| 200 |
|
|
|
|
| 213 |
)
|
| 214 |
|
| 215 |
|
| 216 |
+
async def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str, Any]:
|
| 217 |
"""Run one Task 2 episode; emit [START]/[STEP]/[END]."""
|
| 218 |
r = env.reset(seed=seed)
|
| 219 |
obs = r.observation.model_dump()
|
|
|
|
| 233 |
for step in range(1, MAX_STEPS_T2 + 1):
|
| 234 |
messages.append({"role": "user", "content": _t2_user_msg(obs)})
|
| 235 |
try:
|
| 236 |
+
raw = await get_llm_response(messages, max_tokens=400, temperature=0.0)
|
| 237 |
error_msg = None
|
| 238 |
except Exception as e:
|
| 239 |
raw = ""
|
|
|
|
| 272 |
"seed": seed,
|
| 273 |
"contract": obs["contract_name"],
|
| 274 |
"function": fn,
|
| 275 |
+
"grader_score": grader_score
|
|
|
|
| 276 |
}
|
| 277 |
|
| 278 |
|
|
|
|
| 290 |
)
|
| 291 |
|
| 292 |
|
| 293 |
+
async def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str, Any]:
|
| 294 |
"""Run one Task 3 episode; emit [START]/[STEP]/[END]."""
|
| 295 |
r = env.reset(seed=seed)
|
| 296 |
obs = r.observation.model_dump()
|
|
|
|
| 309 |
for step in range(1, MAX_STEPS_T3 + 1):
|
| 310 |
messages.append({"role": "user", "content": _t3_user_msg(obs)})
|
| 311 |
try:
|
| 312 |
+
raw = await get_llm_response(messages, max_tokens=200, temperature=0.0)
|
| 313 |
error_msg = None
|
| 314 |
except Exception as e:
|
| 315 |
raw = ""
|
|
|
|
| 347 |
"episode": ep_num,
|
| 348 |
"seed": seed,
|
| 349 |
"contract": obs["contract_name"],
|
| 350 |
+
"grader_score": grader_score
|
|
|
|
| 351 |
}
|
| 352 |
|
| 353 |
|
|
|
|
| 355 |
# Task runners
|
| 356 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 357 |
|
| 358 |
+
async def run_task1(n: int = NUM_EPISODES) -> Dict[str, Any]:
|
| 359 |
print("\n" + "="*60, flush=True)
|
| 360 |
print("TASK 1: Targeted Vulnerability Detection", flush=True)
|
| 361 |
print("="*60, flush=True)
|
| 362 |
env = Task1Environment()
|
| 363 |
+
episodes = [await _run_t1_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
|
| 364 |
avg_s = sum(e["grader_score"] for e in episodes) / n
|
| 365 |
print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
|
| 366 |
return {
|
|
|
|
| 370 |
}
|
| 371 |
|
| 372 |
|
| 373 |
+
async def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
|
| 374 |
print("\n" + "="*60, flush=True)
|
| 375 |
print("TASK 2: Property Discovery", flush=True)
|
| 376 |
print("="*60, flush=True)
|
| 377 |
env = Task2Environment()
|
| 378 |
+
episodes = [await _run_t2_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
|
| 379 |
avg_s = sum(e["grader_score"] for e in episodes) / n
|
|
|
|
| 380 |
print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
|
| 381 |
return {
|
| 382 |
"task_id": "task2_property_discovery", "name": "Property Discovery",
|
|
|
|
| 385 |
}
|
| 386 |
|
| 387 |
|
| 388 |
+
async def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
|
| 389 |
print("\n" + "="*60, flush=True)
|
| 390 |
print("TASK 3: Rule Checker", flush=True)
|
| 391 |
print("="*60, flush=True)
|
| 392 |
env = Task3Environment()
|
| 393 |
+
episodes = [await _run_t3_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
|
| 394 |
avg_s = sum(e["grader_score"] for e in episodes) / n
|
|
|
|
| 395 |
print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
|
| 396 |
return {
|
| 397 |
"task_id": "task3_rule_checker", "name": "Rule Checker",
|
|
|
|
| 407 |
async def main() -> None:
|
| 408 |
"""Async entry point (wraps sync env calls; asyncio.run() expected by caller)."""
|
| 409 |
print("Smart Contract Audit RL Environment — Baseline Inference", flush=True)
|
|
|
|
| 410 |
|
| 411 |
+
t1 = await run_task1(NUM_EPISODES)
|
| 412 |
+
t2 = await run_task2(NUM_EPISODES)
|
| 413 |
+
t3 = await run_task3(NUM_EPISODES)
|
| 414 |
|
| 415 |
+
results: Dict[str, Any] = { "tasks": [t1, t2, t3] }
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
overall = sum(t["avg_grader_score"] for t in results["tasks"]) / 3
|
| 417 |
results["overall_avg_score"] = overall
|
| 418 |
|