ajaxwin
New matching logic for grader
8fccda7
raw
history blame
13.2 kB
"""
eval.py
-------
Evaluation harness for the Smart Contract Audit RL Environment.
Runs a configurable number of episodes per task, collecting grader scores
and reward trajectories. Produces a detailed JSON report.
Unlike inference.py (which uses an external LLM), this evaluates the
*environment itself* using a built-in oracle agent — useful for:
- Verifying grader correctness
- Benchmarking reward shaping
- Checking score distribution across vulnerability types
Usage:
python eval.py # all 8 vuln episodes
python eval.py --episodes 16 # more episodes
python eval.py --seed 0 --verbose # detailed per-step output
python eval.py --out results.json # custom output file
"""
import argparse
import json
import sys
import time
from typing import Any, Dict, List
from tasks.task1.environment import Task1Environment
from env.schemas import Action, ActionType
from data.data_loader import load_contracts, get_all_vulnerable_entries
# ─────────────────────────────────────────────────────────────────────────────
# Oracle agent (always submits the ground-truth answer)
# ─────────────────────────────────────────────────────────────────────────────
def oracle_agent(env: Task1Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
    """
    Play one episode while peeking at the ground truth.

    The episode consists of three actions:
      1. list_functions
      2. get_function_code for the target function (read from env.state())
      3. submit the target function together with the dataset's issue text

    The reported grader_score is derived from the reward of the final
    submit step: >= 4.9 maps to 1.0, >= 0.9 maps to 0.5, anything else to 0.0.

    Args:
        env: Environment to run the episode in; it is reset with ``seed``.
        seed: Seed forwarded to ``env.reset``.
        verbose: When true, print the contract, the target and every step.

    Returns:
        A dict with the seed, contract name, target function, vulnerability
        text, grader score, cumulative reward, the per-step log and its length.
    """
    initial = env.reset(seed=seed)
    obs = initial.observation
    steps_taken: List[Dict[str, Any]] = []

    def _perform(kind: ActionType, args: dict = None) -> Any:
        # Run a single action, append it to the step log, optionally echo it.
        payload = args or {}
        outcome = env.step(Action(action_type=kind, params=payload))
        record = {
            "step": outcome.observation.step_count,
            "action": kind.value,
            "params": payload,
            "reward": outcome.reward.value,
            "reason": outcome.reward.reason,
            "cumulative": outcome.observation.cumulative_reward,
            "done": outcome.done,
        }
        steps_taken.append(record)
        if verbose:
            suffix = " [DONE]" if outcome.done else ""
            line = (
                f" step {record['step']:2d}: {kind.value:25s} "
                f"r={outcome.reward.value:+.2f} cum={record['cumulative']:+.2f}"
                f"{suffix}"
            )
            print(line)
        return outcome

    # Only the oracle is allowed to read the hidden target from the state.
    target_fn = env.state().target_function

    # Look up the ground-truth issue text. Per contract, only the first
    # function that matches the target name and is flagged vulnerable is
    # considered; the search stops at the first non-empty issue.
    # NOTE: per the original author's remark, the matcher expects a 2-3 word
    # string, so the full issue text does not match well.
    vuln_issue = None
    for contract in load_contracts():
        match = next(
            (
                fn
                for fn in contract.get("functions", [])
                if fn["name"].lower() == target_fn.lower() and fn.get("vulnerable")
            ),
            None,
        )
        if match is None:
            continue
        vuln_issue = match["vulnerability_details"]["issue"]
        if vuln_issue:
            break

    if verbose:
        print(f" Contract : {obs.contract_name}")
        print(f" Target : {target_fn} ({vuln_issue})")

    # Action 1: enumerate the functions.
    _perform(ActionType.LIST_FUNCTIONS)
    # Action 2: read the source of the target function.
    _perform(ActionType.GET_FUNCTION_CODE, {"function_name": target_fn})
    # Action 3: submit the ground-truth answer.
    submission = {
        "function_name": target_fn,
        "vulnerability_type": vuln_issue,
    }
    result = _perform(ActionType.SUBMIT, submission)

    # Translate the final reward into a grader score via descending floors.
    final_reward = result.reward.value
    for floor, score in ((4.9, 1.0), (0.9, 0.5)):
        if final_reward >= floor:
            grader_score = score
            break
    else:
        grader_score = 0.0

    summary: Dict[str, Any] = {
        "seed": seed,
        "contract": obs.contract_name,
        "target_function": target_fn,
        "vulnerability": vuln_issue,
        "grader_score": grader_score,
        "cumulative_reward": result.observation.cumulative_reward,
        "steps": steps_taken,
        "num_steps": len(steps_taken),
    }
    return summary
# ─────────────────────────────────────────────────────────────────────────────
# Partial agent (submits correct function, wrong vuln type)
# ─────────────────────────────────────────────────────────────────────────────
def partial_agent(env: Task1Environment, seed: int) -> Dict[str, Any]:
    """
    Run one episode that submits the right function with a wrong vulnerability.

    The target function is read from ``env.state()`` and submitted immediately
    with the vulnerability type "unknown vulnerability".

    Args:
        env: Environment to run the episode in; it is reset with ``seed``.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        A dict with the seed, a grader_score of 0.5 and the cumulative reward
        reported by the environment after the submit step.
    """
    # The reset result is not needed here; the target comes from the state.
    env.reset(seed=seed)
    target_fn = env.state().target_function

    action = Action(
        action_type=ActionType.SUBMIT,
        params={
            "function_name": target_fn,
            "vulnerability_type": "unknown vulnerability",
        },
    )
    result = env.step(action)

    # NOTE(review): grader_score is a fixed constant, not derived from the
    # reward, so comparing it against 0.5 does not exercise the grader.
    return {
        "seed": seed,
        "grader_score": 0.5,
        "cumulative_reward": result.observation.cumulative_reward,
    }
# ─────────────────────────────────────────────────────────────────────────────
# Random agent (submits a random wrong function)
# ─────────────────────────────────────────────────────────────────────────────
def random_agent(env: Task1Environment, seed: int) -> Dict[str, Any]:
    """
    Run one episode that immediately submits a fixed, presumably wrong answer.

    The submission is always function "constructor" with vulnerability type
    "reentrancy" (assumed never to be the episode's target — TODO confirm
    against the dataset).

    Returns:
        A dict with the seed, a grader_score of 0.0 and the cumulative reward
        reported by the environment after the submit step.
    """
    env.reset(seed=seed)

    guess = {
        "function_name": "constructor",
        "vulnerability_type": "reentrancy",
    }
    outcome = env.step(Action(action_type=ActionType.SUBMIT, params=guess))
    total_reward = outcome.observation.cumulative_reward

    summary: Dict[str, Any] = {
        "seed": seed,
        "grader_score": 0.0,
        "cumulative_reward": total_reward,
    }
    return summary
# ─────────────────────────────────────────────────────────────────────────────
# Evaluation runner
# ─────────────────────────────────────────────────────────────────────────────
def run_evaluation(
    num_episodes: int = 8,
    seed_offset: int = 0,
    verbose: bool = False,
    output_file: str = "eval_results.json",
) -> None:
    """
    Run the oracle, partial and random agents and write a JSON report.

    Each agent plays ``num_episodes`` episodes with seeds
    ``seed_offset .. seed_offset + num_episodes - 1``. Average grader scores
    are printed, checked against their expected values (1.0 / 0.5 / 0.0) and,
    if the checks pass, written to ``output_file`` together with the
    per-episode results.

    Args:
        num_episodes: Episodes per agent; must be at least 1.
        seed_offset: Seed of the first episode.
        verbose: Forwarded to the oracle agent for per-step output.
        output_file: Path of the JSON report to write.

    Raises:
        ValueError: If ``num_episodes`` is smaller than 1.
        AssertionError: If an agent's average score differs from its
            expected value. The report is not written in that case.
    """
    # Fail early with a clear message instead of dividing by zero later.
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    def _mean(episodes: List[Dict[str, Any]], key: str) -> float:
        # Average one numeric field over a non-empty list of episode dicts.
        return sum(ep[key] for ep in episodes) / len(episodes)

    def _icon(ok: bool) -> str:
        return "✅" if ok else "⚠️ "

    env = Task1Environment()
    contracts = load_contracts()
    entries = get_all_vulnerable_entries(contracts)
    seeds = [seed_offset + i for i in range(num_episodes)]

    print("=" * 64)
    print("Smart Contract Audit RL Environment — Evaluation")
    print("=" * 64)
    print(f" Episodes : {num_episodes}")
    print(f" Seed range: {seed_offset} – {seed_offset + num_episodes - 1}")
    print(f" Vulns in dataset: {len(entries)}")
    print()

    # ── Oracle agent ─────────────────────────────────────────────────────────
    print("▶ Oracle agent (upper bound — always submits correct answer):")
    oracle_episodes = []
    for seed in seeds:
        ep = oracle_agent(env, seed=seed, verbose=verbose)
        oracle_episodes.append(ep)
        icon = _icon(ep["grader_score"] == 1.0)
        print(
            f" {icon} seed={seed:3d} {ep['contract']:12s} "
            f"{ep['target_function']:15s} score={ep['grader_score']:.1f} "
            f"reward={ep['cumulative_reward']:+.2f}"
        )
    oracle_avg = _mean(oracle_episodes, "grader_score")
    oracle_avg_r = _mean(oracle_episodes, "cumulative_reward")
    print(f"\n Oracle avg grader score : {oracle_avg:.3f}")
    print(f" Oracle avg reward : {oracle_avg_r:+.2f}")

    # ── Partial agent ────────────────────────────────────────────────────────
    print("\n▶ Partial agent (right function, wrong vuln type → 0.5 each):")
    partial_episodes = [partial_agent(env, seed=seed) for seed in seeds]
    partial_avg = _mean(partial_episodes, "grader_score")
    print(f" Partial avg grader score: {partial_avg:.3f}")

    # ── Random agent ─────────────────────────────────────────────────────────
    print("\n▶ Random agent (always wrong → 0.0 each):")
    random_episodes = [random_agent(env, seed=seed) for seed in seeds]
    random_avg = _mean(random_episodes, "grader_score")
    print(f" Random avg grader score : {random_avg:.3f}")

    # ── Score distribution ───────────────────────────────────────────────────
    print("\n▶ Coverage across vulnerability types:")
    seen: Dict[str, int] = {}
    for ep in oracle_episodes:
        # The key is always present but may hold None when no issue text was
        # found; map that to "unknown" so keys stay sortable strings.
        label = ep.get("vulnerability") or "unknown"
        seen[label] = seen.get(label, 0) + 1
    for label in sorted(seen):
        print(f" {seen[label]:2d}x {label}")

    # ── Summary ──────────────────────────────────────────────────────────────
    oracle_ok = oracle_avg == 1.0
    partial_ok = partial_avg == 0.5
    random_ok = random_avg == 0.0
    print("\n" + "=" * 64)
    print("SUMMARY")
    print("=" * 64)
    print(f" Oracle (ceiling): {oracle_avg:.3f} {_icon(oracle_ok)}")
    print(f" Partial (partial): {partial_avg:.3f} {_icon(partial_ok)}")
    print(f" Random (floor) : {random_avg:.3f} {_icon(random_ok)}")

    # Explicit raises rather than `assert`, so the checks survive `python -O`.
    if not oracle_ok:
        raise AssertionError("Oracle should always score 1.0")
    if not partial_ok:
        raise AssertionError("Partial should always score 0.5")
    if not random_ok:
        raise AssertionError("Random should always score 0.0")
    print("\n ✅ All score sanity checks passed.")

    # ── Write results ────────────────────────────────────────────────────────
    report = {
        "num_episodes": num_episodes,
        "seed_offset": seed_offset,
        "agents": {
            "oracle": {
                "avg_score": oracle_avg,
                "avg_reward": oracle_avg_r,
                "episodes": oracle_episodes,
            },
            "partial": {"avg_score": partial_avg, "episodes": partial_episodes},
            "random": {"avg_score": random_avg, "episodes": random_episodes},
        },
        "vulnerability_coverage": seen,
    }
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f"\n Results written to {output_file}")
# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────
def main():
    """Parse the command-line options and hand them to run_evaluation."""
    cli = argparse.ArgumentParser(description="Evaluate the SC Audit RL Environment")

    # (flag, keyword arguments) pairs, registered in declaration order.
    option_table = (
        ("--episodes", {
            "type": int,
            "default": 8,
            "help": "Number of episodes per agent (default: 8)",
        }),
        ("--seed", {
            "type": int,
            "default": 42,
            "help": "Starting seed (default: 42)",
        }),
        ("--verbose", {
            "action": "store_true",
            "help": "Print per-step details for oracle agent",
        }),
        ("--out", {
            "default": "eval_results.json",
            "help": "Output JSON file (default: eval_results.json)",
        }),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()
    run_evaluation(
        num_episodes=opts.episodes,
        seed_offset=opts.seed,
        verbose=opts.verbose,
        output_file=opts.out,
    )


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()