Spaces:
Sleeping
Sleeping
| """ | |
| eval.py | |
| ------- | |
| Evaluation harness for the Smart Contract Audit RL Environment. | |
| Runs a configurable number of episodes per task, collecting grader scores | |
| and reward trajectories. Produces a detailed JSON report. | |
| Unlike inference.py (which uses an external LLM), this evaluates the | |
| *environment itself* using a built-in oracle agent — useful for: | |
| - Verifying grader correctness | |
| - Benchmarking reward shaping | |
| - Checking score distribution across vulnerability types | |
| Usage: | |
| python eval.py # all 8 vuln episodes | |
| python eval.py --episodes 16 # more episodes | |
| python eval.py --seed 0 --verbose # detailed per-step output | |
| python eval.py --out results.json # custom output file | |
| """ | |
| import argparse | |
| import json | |
| import sys | |
| import time | |
| from typing import Any, Dict, List | |
| from tasks.task1.environment import Task1Environment | |
| from env.schemas import Action, ActionType | |
| from data.data_loader import load_contracts, get_all_vulnerable_entries | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Oracle agent (always submits the ground-truth answer) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
# Minimum terminal (SUBMIT) reward that maps to each grader score.
# NOTE(review): presumably the environment pays ~5.0 for a fully correct
# submission and ~1.0 for "right function, wrong vulnerability" - confirm
# against the Task1 grader before changing these.
_FULL_CREDIT_MIN_REWARD = 4.9
_PARTIAL_CREDIT_MIN_REWARD = 0.9


def _lookup_vulnerability_issue(contracts: List[Dict[str, Any]], target_fn: str):
    """Return the ground-truth issue text for *target_fn*, or None if absent.

    Scans every contract and returns ``vulnerability_details["issue"]`` of the
    first function whose name matches *target_fn* case-insensitively and whose
    ``vulnerable`` flag is truthy.

    NOTE(review): the match is by function name only, across all contracts;
    if two contracts share a vulnerable function name the first one wins.
    """
    wanted = target_fn.lower()
    for contract in contracts:
        for fn in contract.get("functions", []):
            if fn["name"].lower() == wanted and fn.get("vulnerable"):
                # NOTE (original author): the grader's matcher assumes the
                # expected string is 2-3 words long, so this full issue text
                # does not always match well.
                return fn["vulnerability_details"]["issue"]
    return None


def _grader_score_from_reward(final_reward: float) -> float:
    """Map the reward of the final SUBMIT step onto a 1.0 / 0.5 / 0.0 score."""
    if final_reward >= _FULL_CREDIT_MIN_REWARD:
        return 1.0
    if final_reward >= _PARTIAL_CREDIT_MIN_REWARD:
        return 0.5
    return 0.0


def oracle_agent(env: Task1Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
    """Run one episode with the oracle strategy and summarise it.

    The oracle peeks at the target function through ``env.state()``, looks up
    the matching ground-truth issue in the dataset, and then plays:

    1. ``LIST_FUNCTIONS``
    2. ``GET_FUNCTION_CODE`` for the target function
    3. ``SUBMIT`` with the ground-truth function and issue

    The grader score is *derived from the reward of the SUBMIT step*, so it is
    expected to be 1.0 but will be lower if the environment does not accept
    the ground-truth answer.

    Args:
        env: Environment to run the episode in; it is reset with *seed*.
        seed: Seed passed to ``env.reset``.
        verbose: When true, print the contract, the target and one line per
            step.

    Returns:
        A dict with keys ``seed``, ``contract``, ``target_function``,
        ``vulnerability`` (``None`` when no ground truth was found),
        ``grader_score``, ``cumulative_reward``, ``steps`` and ``num_steps``.
    """
    obs = env.reset(seed=seed).observation
    steps_taken: List[Dict[str, Any]] = []

    def _step(action_type: ActionType, params: Dict[str, Any]) -> Any:
        """Execute one action, record it in ``steps_taken`` and return the result."""
        result = env.step(Action(action_type=action_type, params=params))
        entry = {
            "step": result.observation.step_count,
            "action": action_type.value,
            "params": params,
            "reward": result.reward.value,
            "reason": result.reward.reason,
            "cumulative": result.observation.cumulative_reward,
            "done": result.done,
        }
        steps_taken.append(entry)
        if verbose:
            done_flag = " [DONE]" if result.done else ""
            print(
                f" step {entry['step']:2d}: {action_type.value:25s} "
                f"r={result.reward.value:+.2f} cum={entry['cumulative']:+.2f}"
                f"{done_flag}"
            )
        return result

    # Peek at the ground truth (only the oracle is allowed to do this).
    target_fn = env.state().target_function
    vuln_issue = _lookup_vulnerability_issue(load_contracts(), target_fn)

    if verbose:
        print(f" Contract : {obs.contract_name}")
        print(f" Target : {target_fn} ({vuln_issue})")

    # Step 1: list functions, as a realistic agent would start.
    _step(ActionType.LIST_FUNCTIONS, {})
    # Step 2: read the target function's code.
    _step(ActionType.GET_FUNCTION_CODE, {"function_name": target_fn})
    # Step 3: submit the ground-truth answer.
    result = _step(ActionType.SUBMIT, {
        "function_name": target_fn,
        "vulnerability_type": vuln_issue,
    })

    return {
        "seed": seed,
        "contract": obs.contract_name,
        "target_function": target_fn,
        "vulnerability": vuln_issue,
        "grader_score": _grader_score_from_reward(result.reward.value),
        "cumulative_reward": result.observation.cumulative_reward,
        "steps": steps_taken,
        "num_steps": len(steps_taken),
    }
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Partial agent (submits correct function, wrong vuln type) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def partial_agent(env: Task1Environment, seed: int) -> Dict[str, Any]:
    """Submit the right function with a deliberately wrong vulnerability type.

    The agent peeks at the target function through ``env.state()`` and
    immediately submits it with ``"unknown vulnerability"`` as the type.
    The expected grader score is 0.5.

    The score is derived from the reward of the SUBMIT step rather than
    hard-coded, so the caller's "partial == 0.5" sanity check actually
    exercises the environment's grader.

    Args:
        env: Environment to run the episode in; it is reset with *seed*.
        seed: Seed passed to ``env.reset``.

    Returns:
        A dict with keys ``seed``, ``grader_score`` and ``cumulative_reward``.
    """
    env.reset(seed=seed)
    target_fn = env.state().target_function
    result = env.step(Action(action_type=ActionType.SUBMIT, params={
        "function_name": target_fn,
        "vulnerability_type": "unknown vulnerability",
    }))

    # Same reward -> score thresholds that oracle_agent applies to its
    # terminal reward.
    final_reward = result.reward.value
    if final_reward >= 4.9:
        grader_score = 1.0
    elif final_reward >= 0.9:
        grader_score = 0.5
    else:
        grader_score = 0.0

    return {
        "seed": seed,
        "grader_score": grader_score,
        "cumulative_reward": result.observation.cumulative_reward,
    }
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Random agent (submits a random wrong function) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def random_agent(env: Task1Environment, seed: int) -> Dict[str, Any]:
    """Submit a fixed, presumably wrong answer as a floor baseline.

    Despite the name, nothing here is random: the agent always submits the
    function ``"constructor"`` with vulnerability type ``"reentrancy"``.
    The expected grader score is 0.0.

    The score is derived from the reward of the SUBMIT step rather than
    hard-coded, so an episode whose target really is ``constructor`` (or a
    grader that rewards wrong answers) shows up in the caller's sanity check
    instead of being masked.

    Args:
        env: Environment to run the episode in; it is reset with *seed*.
        seed: Seed passed to ``env.reset``.

    Returns:
        A dict with keys ``seed``, ``grader_score`` and ``cumulative_reward``.
    """
    env.reset(seed=seed)
    result = env.step(Action(action_type=ActionType.SUBMIT, params={
        "function_name": "constructor",
        "vulnerability_type": "reentrancy",
    }))

    # Same reward -> score thresholds that oracle_agent applies to its
    # terminal reward.
    final_reward = result.reward.value
    if final_reward >= 4.9:
        grader_score = 1.0
    elif final_reward >= 0.9:
        grader_score = 0.5
    else:
        grader_score = 0.0

    return {
        "seed": seed,
        "grader_score": grader_score,
        "cumulative_reward": result.observation.cumulative_reward,
    }
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Evaluation runner | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def run_evaluation(
    num_episodes: int = 8,
    seed_offset: int = 0,
    verbose: bool = False,
    output_file: str = "eval_results.json",
) -> None:
    """Run the oracle, partial and random agents and write a JSON report.

    Each agent plays ``num_episodes`` episodes with seeds
    ``seed_offset .. seed_offset + num_episodes - 1`` on one shared
    ``Task1Environment``. Progress and a summary are printed to stdout, the
    report is written to *output_file*, and finally the average grader scores
    are checked against the expected 1.0 / 0.5 / 0.0.

    Args:
        num_episodes: Episodes per agent; must be at least 1.
        seed_offset: First seed to use.
        verbose: Forwarded to ``oracle_agent`` for per-step output.
        output_file: Path of the JSON report to write.

    Raises:
        ValueError: If ``num_episodes`` is smaller than 1.
        AssertionError: If any agent's average grader score differs from its
            expected value. The report is written before this is raised so
            the failing run can be inspected.
    """
    # Averages divide by num_episodes, so reject empty runs up front instead
    # of failing later with ZeroDivisionError.
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    env = Task1Environment()
    entries = get_all_vulnerable_entries(load_contracts())
    seeds = range(seed_offset, seed_offset + num_episodes)

    def _mean(episodes: List[Dict[str, Any]], key: str) -> float:
        """Average ``episode[key]`` over the configured number of episodes."""
        return sum(ep[key] for ep in episodes) / num_episodes

    rule = "=" * 64
    print(rule)
    print("Smart Contract Audit RL Environment — Evaluation")
    print(rule)
    print(f" Episodes : {num_episodes}")
    print(f" Seed range: {seed_offset} – {seed_offset + num_episodes - 1}")
    print(f" Vulns in dataset: {len(entries)}")
    print()

    # -- Oracle agent ------------------------------------------------------
    print("▶ Oracle agent (upper bound — always submits correct answer):")
    oracle_episodes = []
    for seed in seeds:
        ep = oracle_agent(env, seed=seed, verbose=verbose)
        oracle_episodes.append(ep)
        icon = "✅" if ep["grader_score"] == 1.0 else "⚠️ "
        print(
            f" {icon} seed={seed:3d} {ep['contract']:12s} "
            f"{ep['target_function']:15s} score={ep['grader_score']:.1f} "
            f"reward={ep['cumulative_reward']:+.2f}"
        )
    oracle_avg = _mean(oracle_episodes, "grader_score")
    oracle_avg_r = _mean(oracle_episodes, "cumulative_reward")
    print(f"\n Oracle avg grader score : {oracle_avg:.3f}")
    print(f" Oracle avg reward : {oracle_avg_r:+.2f}")

    # -- Partial agent -----------------------------------------------------
    print("\n▶ Partial agent (right function, wrong vuln type → 0.5 each):")
    partial_episodes = [partial_agent(env, seed=seed) for seed in seeds]
    partial_avg = _mean(partial_episodes, "grader_score")
    print(f" Partial avg grader score: {partial_avg:.3f}")

    # -- Random agent ------------------------------------------------------
    print("\n▶ Random agent (always wrong → 0.0 each):")
    random_episodes = [random_agent(env, seed=seed) for seed in seeds]
    random_avg = _mean(random_episodes, "grader_score")
    print(f" Random avg grader score : {random_avg:.3f}")

    # -- Coverage across vulnerability types -------------------------------
    print("\n▶ Coverage across vulnerability types:")
    seen: Dict[str, int] = {}
    for ep in oracle_episodes:
        # The key is always present but may hold None when no ground truth
        # was found; fold that into "unknown" so sorting and JSON keys stay
        # string-only.
        label = ep.get("vulnerability") or "unknown"
        seen[label] = seen.get(label, 0) + 1
    for label in sorted(seen):
        print(f" {seen[label]:2d}x {label}")

    # -- Summary -----------------------------------------------------------
    print("\n" + rule)
    print("SUMMARY")
    print(rule)
    print(f" Oracle (ceiling): {oracle_avg:.3f} {'✅' if oracle_avg == 1.0 else '⚠️ '}")
    print(f" Partial (partial): {partial_avg:.3f} ✅")
    print(f" Random (floor) : {random_avg:.3f} ✅")

    # Collect failures explicitly instead of using `assert`, which is
    # stripped when Python runs with -O.
    expectations = (
        ("Oracle", oracle_avg, 1.0),
        ("Partial", partial_avg, 0.5),
        ("Random", random_avg, 0.0),
    )
    failures = [
        f"{name} should always score {expected}"
        for name, actual, expected in expectations
        if actual != expected
    ]
    if not failures:
        print("\n ✅ All score sanity checks passed.")

    # -- Write results -----------------------------------------------------
    report = {
        "num_episodes": num_episodes,
        "seed_offset": seed_offset,
        "agents": {
            "oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_episodes},
            "partial": {"avg_score": partial_avg, "episodes": partial_episodes},
            "random": {"avg_score": random_avg, "episodes": random_episodes},
        },
        "vulnerability_coverage": seen,
    }
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f"\n Results written to {output_file}")

    if failures:
        raise AssertionError("; ".join(failures))
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Entry point | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def main() -> None:
    """Parse command-line options and run the evaluation.

    Options: ``--episodes`` (default 8), ``--seed`` (default 42),
    ``--verbose`` and ``--out`` (default ``eval_results.json``).
    A non-positive ``--episodes`` is rejected with a usage error (exit
    status 2) before any episode is run.
    """
    parser = argparse.ArgumentParser(description="Evaluate the SC Audit RL Environment")
    parser.add_argument("--episodes", type=int, default=8,
                        help="Number of episodes per agent (default: 8)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Starting seed (default: 42)")
    parser.add_argument("--verbose", action="store_true",
                        help="Print per-step details for oracle agent")
    parser.add_argument("--out", default="eval_results.json",
                        help="Output JSON file (default: eval_results.json)")
    args = parser.parse_args()

    # Averages are computed per episode, so zero or negative counts cannot
    # produce a meaningful report; fail with a clear usage message instead.
    if args.episodes < 1:
        parser.error("--episodes must be a positive integer")

    run_evaluation(
        num_episodes=args.episodes,
        seed_offset=args.seed,
        verbose=args.verbose,
        output_file=args.out,
    )


if __name__ == "__main__":
    main()