SmartContractAudit / validate.py
ajaxwin
task1, task2 evaluated
671787b
raw
history blame
13.3 kB
"""
validate.py
-----------
Pre-submission validation — 23 checks across all three tasks.
Usage: python validate.py
Exit 0 = all pass. Exit 1 = failures.
"""
import json, sys
from typing import Callable, List, Tuple
# Status glyphs printed next to each check name in the console report.
PASS = "✅"
FAIL = "❌"
# One (check name, passed?, error message) tuple per executed check;
# the message is the empty string for checks that passed.
results: List[Tuple[str, bool, str]] = []
def check(name: str, fn: Callable) -> None:
    """Run one validation callable and record whether it raised.

    Args:
        name: Label shown in the console and stored with the outcome.
        fn: Zero-argument callable; any ``Exception`` it raises marks the
            check as failed and its text is kept as the failure message.

    The outcome is appended to the module-level ``results`` list and a
    one-line status is printed.
    """
    try:
        fn()
        success_entry = (name, True, "")
        results.append(success_entry)
        print(f" {PASS} {name}")
    except Exception as exc:
        # Deliberately broad: a failing check must never abort the whole run.
        failure_text = str(exc)
        results.append((name, False, failure_text))
        print(f" {FAIL} {name}\n {exc}")
# ── Checks ────────────────────────────────────────────────────────────────────
def check_imports():
    """Import the schema, environment, grader and loader names used by this validator.

    Nothing is returned; an import failure propagates as an exception.
    """
    from env.schemas import (
        Observation,
        Action,
        Reward,
        StepResult,
        ResetResult,
        StateResult,
        ActionType,
    )
    from tasks.task1.environment import Task1Environment
    from tasks.task1.grader import Task1Grader
    from tasks.task2.environment import Task2Environment
    from tasks.task2.grader import Task2Grader
    from tasks.task3.environment import Task3Environment
    from tasks.task3.grader import Task3Grader
    from data.data_loader import load_contracts
def check_openenv_yaml():
    """Validate the top-level structure of ``openenv.yaml``.

    Requires a ``name`` key, at least three ``tasks``, the keys
    ``observation_space``, ``action_space`` and ``reward``, and at least two
    tasks whose ``status`` is ``"active"``.
    """
    import yaml

    with open("openenv.yaml") as handle:
        manifest = yaml.safe_load(handle)

    assert "name" in manifest and len(manifest.get("tasks", [])) >= 3
    assert all(
        key in manifest for key in ("observation_space", "action_space", "reward")
    )

    active_count = sum(
        1 for task in manifest["tasks"] if task.get("status") == "active"
    )
    assert active_count >= 2, f"Expected >=2 active tasks, got {active_count}"
def check_pydantic_models():
    """Construct one instance of each schema model to confirm they accept these arguments.

    Builds an ``Observation``, an ``Action`` for four action types, a
    negative-valued ``Reward`` and a ``StepResult``; any constructor error
    propagates to the caller.
    """
    from env.schemas import Observation, Action, ActionType, Reward, StepResult, ResetResult

    observation = Observation(
        task_id="t",
        contract_name="C",
        contract_description="D",
        available_actions=[],
    )

    sampled_types = (
        ActionType.LIST_FUNCTIONS,
        ActionType.SUBMIT_PROPERTY,
        ActionType.GET_PROPERTY_SPECIFICATION,
        ActionType.SUBMIT_FUNCTION,
    )
    for action_type in sampled_types:
        Action(action_type=action_type)

    Reward(value=-1.5, reason="test")
    StepResult(
        observation=observation,
        reward=Reward(value=0, reason=""),
        done=False,
    )
def check_data_loading():
    """Load the contract dataset and check it has enough entries for each task.

    Requires at least three vulnerable entries, three property entries and
    three task3 entries, and that every task3 entry carries non-empty
    ``property_english`` and ``property_formal`` values.
    """
    from data.data_loader import (
        load_contracts,
        get_all_vulnerable_entries,
        get_all_property_entries,
        get_all_task3_entries,
    )

    contracts = load_contracts()

    vulnerable = get_all_vulnerable_entries(contracts)
    assert len(vulnerable) >= 3
    properties = get_all_property_entries(contracts)
    assert len(properties) >= 3

    entries = get_all_task3_entries(contracts)
    assert len(entries) >= 3, f"Need >=3 task3 entries, got {len(entries)}"

    for _, fn in entries:
        task3_spec = fn.get("task3", {})
        assert task3_spec.get("property_english"), f"{fn['name']} missing property_english"
        assert task3_spec.get("property_formal"), f"{fn['name']} missing property_formal"
def check_t1_env():
    """Exercise reset, one step and state on the Task 1 environment.

    After ``reset(seed=42)`` the observation must report task id
    ``task1_vuln_detection``; a single ``LIST_FUNCTIONS`` step must yield
    reward ``-0.05`` with ``step_count`` 1; the state must expose a target
    function.
    """
    from tasks.task1.environment import Task1Environment
    from env.schemas import Action, ActionType

    environment = Task1Environment()

    initial = environment.reset(seed=42)
    assert initial.observation.task_id == "task1_vuln_detection"

    list_action = Action(action_type=ActionType.LIST_FUNCTIONS)
    outcome = environment.step(list_action)
    assert outcome.reward.value == -0.05 and outcome.observation.step_count == 1

    assert environment.state().target_function is not None
def check_t2_env():
    """Reset the Task 2 environment and run each of its six browse actions once.

    The reset observation must report task id ``task2_property_discovery``
    and include ``target_function`` in its ``extra`` mapping. The browse
    steps are only required not to raise.
    """
    from tasks.task2.environment import Task2Environment
    from env.schemas import Action, ActionType

    environment = Task2Environment()
    initial = environment.reset(seed=42)
    assert initial.observation.task_id == "task2_property_discovery"
    assert "target_function" in initial.observation.extra

    browse_actions = (
        ActionType.GET_FUNCTION_CODE,
        ActionType.GET_FUNCTION_NATSPEC,
        ActionType.GET_FILE_NATSPEC,
        ActionType.GET_SIGNATURE,
        ActionType.GET_RELATED_FUNCTIONS,
        ActionType.GET_SIMILAR_RULE,
    )
    for action_type in browse_actions:
        environment.step(Action(action_type=action_type))
def check_t3_env():
    """Reset the Task 3 environment and confirm browse actions cost reward.

    The reset observation must report task id ``task3_rule_checker`` and
    carry a ``property_english`` string longer than 10 characters. Each of
    four browse actions must then return a strictly negative reward.
    """
    from tasks.task3.environment import Task3Environment
    from env.schemas import Action, ActionType

    environment = Task3Environment()
    initial = environment.reset(seed=42)
    assert initial.observation.task_id == "task3_rule_checker"
    assert "property_english" in initial.observation.extra

    english_property = initial.observation.extra["property_english"]
    assert len(english_property) > 10, "property_english too short"

    browse_actions = (
        ActionType.LIST_FUNCTIONS,
        ActionType.GET_PROPERTY_SPECIFICATION,
        ActionType.GET_CALL_GRAPH,
        ActionType.GET_STATE_VARIABLE,
    )
    for at in browse_actions:
        outcome = environment.step(Action(action_type=at))
        assert outcome.reward.value < 0, f"{at.value} should have negative shaping reward"
def check_t3_action_costs():
    """Check the per-action reward cost of three Task 3 browse actions.

    Each action is measured as the first step of its own freshly reset
    environment (seed 42) and must match the expected cost to within 0.001.

    Raises:
        AssertionError: if any action's reward differs from its expected cost.
    """
    from tasks.task3.environment import Task3Environment
    from env.schemas import Action, ActionType

    # The original also built and reset one extra environment up front that
    # was never stepped; it has been dropped because every measurement below
    # uses its own environment.
    costs = {
        ActionType.GET_PROPERTY_SPECIFICATION: -0.03,
        ActionType.LIST_FUNCTIONS: -0.05,
        ActionType.GET_CALL_GRAPH: -0.08,
    }
    for at, expected in costs.items():
        # Fresh environment per action so each cost is observed as a first step.
        fresh_env = Task3Environment()
        fresh_env.reset(seed=42)
        s = fresh_env.step(Action(action_type=at))
        assert abs(s.reward.value - expected) < 0.001, \
            f"{at.value}: expected {expected}, got {s.reward.value}"
def check_t3_function_metadata():
    """Request metadata for ``withdraw`` in Task 3 and check the response and cost.

    With seed 43, the action result must contain the text ``Visibility`` and
    the step reward must equal ``-0.05``.
    """
    from tasks.task3.environment import Task3Environment
    from env.schemas import Action, ActionType

    environment = Task3Environment()
    environment.reset(seed=43)

    metadata_request = Action(
        action_type=ActionType.GET_FUNCTION_METADATA,
        params={"function_name": "withdraw"},
    )
    outcome = environment.step(metadata_request)

    assert "Visibility" in outcome.observation.last_action_result
    assert outcome.reward.value == -0.05
def check_t3_submit_correct():
    """Submit the environment's own target function and expect the full reward.

    The step must end the episode with a reward of exactly ``5.0``.
    """
    from tasks.task3.environment import Task3Environment
    from env.schemas import Action, ActionType

    environment = Task3Environment()
    environment.reset(seed=42)

    # Read the answer straight from the environment state.
    answer = environment.state().target_function
    submission = Action(
        action_type=ActionType.SUBMIT_FUNCTION,
        params={"function_name": answer},
    )
    outcome = environment.step(submission)

    assert outcome.done and outcome.reward.value == 5.0, \
        f"Expected reward=5.0, got {outcome.reward.value}"
def check_t3_submit_subfunction():
    """Submit ``getPrice`` when the target is ``bid`` and expect partial credit.

    The step must end the episode with a reward of exactly ``1.5``.
    """
    from tasks.task3.environment import Task3Environment
    from env.schemas import Action, ActionType

    # Seed 45 is expected to select `bid`; per the original note, getPrice
    # is one of its subfunctions.
    environment = Task3Environment()
    environment.reset(seed=45)
    assert environment.state().target_function == "bid"

    submission = Action(
        action_type=ActionType.SUBMIT_FUNCTION,
        params={"function_name": "getPrice"},
    )
    outcome = environment.step(submission)

    assert outcome.done and outcome.reward.value == 1.5, \
        f"Expected partial reward=1.5, got {outcome.reward.value}"
def check_t3_submit_wrong():
    """Submit ``constructor`` in Task 3 and expect the wrong-answer penalty.

    The step must end the episode with a reward of exactly ``-1.5``.
    """
    from tasks.task3.environment import Task3Environment
    from env.schemas import Action, ActionType

    environment = Task3Environment()
    environment.reset(seed=42)

    submission = Action(
        action_type=ActionType.SUBMIT_FUNCTION,
        params={"function_name": "constructor"},
    )
    outcome = environment.step(submission)

    assert outcome.done and outcome.reward.value == -1.5
def check_t3_one_submit_only():
    """Confirm that stepping after a submission raises ``RuntimeError``.

    Raises:
        AssertionError: if the follow-up step completes without raising.
    """
    from tasks.task3.environment import Task3Environment
    from env.schemas import Action, ActionType

    environment = Task3Environment()
    environment.reset(seed=42)

    submission = Action(
        action_type=ActionType.SUBMIT_FUNCTION,
        params={"function_name": "deposit"},
    )
    environment.step(submission)

    try:
        environment.step(Action(action_type=ActionType.LIST_FUNCTIONS))
    except RuntimeError:
        # Expected: the environment rejects further steps.
        pass
    else:
        raise AssertionError("Should raise RuntimeError after done")
def check_t3_repeated_penalty():
    """Issue ``LIST_FUNCTIONS`` twice and check the second call's reward.

    The reward of the repeated call must equal ``-0.40``.
    """
    from tasks.task3.environment import Task3Environment
    from env.schemas import Action, ActionType

    environment = Task3Environment()
    environment.reset(seed=42)

    # Same query twice in a row; only the second outcome is inspected.
    for _ in range(2):
        outcome = environment.step(Action(action_type=ActionType.LIST_FUNCTIONS))

    assert outcome.reward.value == -0.40
def check_t1_grader():
    """Check the three score levels returned by ``Task1Grader.grade_submission``.

    For a grader built with ``("withdraw", "Reentrancy vulnerability")`` the
    expected scores are 1.0, 0.5 and 0.0 for the three submissions below.
    """
    from tasks.task1.grader import Task1Grader

    grader = Task1Grader("withdraw", "Reentrancy vulnerability")

    # (submitted function, submitted description) -> expected score
    expectations = (
        (("withdraw", "reentrancy"), 1.0),
        (("withdraw", "vague"), 0.5),
        (("deposit", "reentrancy"), 0.0),
    )
    for (function_name, description), expected in expectations:
        assert grader.grade_submission(function_name, description) == expected
def check_t2_grader():
    """Grade each dataset property against itself with ``Task2Grader``.

    For every property entry, grading the reference property must score at
    least 0.65, grading an empty string is compared against 0.0, and grading
    the same text twice must give equal results.
    """
    from tasks.task2.grader import Task2Grader
    from data.data_loader import load_contracts, get_all_property_entries
    # NOTE(review): the source lost its indentation; all four statements are
    # assumed to belong to the loop body — confirm against the original file.
    for c, fn in get_all_property_entries(load_contracts()):
        g = Task2Grader(fn["name"], fn["property"])
        # NOTE(review): grade() is indexed with [0] here but compared directly
        # to 0.0 on the next line; unless grade() returns different shapes for
        # different inputs, one of these two asserts cannot pass — verify the
        # return type of Task2Grader.grade.
        assert g.grade(fn["property"])[0] >= 0.65
        assert g.grade("") == 0.0
        s = g.grade("test"); assert s == g.grade("test") # deterministic
def check_t3_grader():
    """Check ``Task3Grader`` scores and rewards for exact, partial and wrong answers.

    For a grader built with ``("withdraw", ["deposit"], "some rule")``:
    ``grade`` must return 1.0 for the target in either letter case, 0.3 for
    ``deposit`` and 0.0 for ``constructor``; ``grade_and_reward`` must return
    the (score, reward) pairs listed below.
    """
    from tasks.task3.grader import Task3Grader

    grader = Task3Grader("withdraw", ["deposit"], "some rule")

    expected_scores = (
        ("withdraw", 1.0),
        ("WITHDRAW", 1.0),  # case-insensitive
        ("deposit", 0.3),
        ("constructor", 0.0),
    )
    for submitted, expected in expected_scores:
        assert grader.grade(submitted) == expected

    expected_pairs = (
        ("withdraw", 1.0, 5.0),
        ("deposit", 0.3, 1.5),
        ("other", 0.0, -1.5),
    )
    for submitted, expected_score, expected_reward in expected_pairs:
        score, reward = grader.grade_and_reward(submitted)
        assert score == expected_score and reward == expected_reward
def check_reward_shaping():
    """Confirm that different Task 3 browse actions yield different rewards.

    Three actions are stepped in sequence on one environment (seed 1) and at
    least two distinct reward values must be observed.
    """
    from tasks.task3.environment import Task3Environment
    from env.schemas import Action, ActionType

    environment = Task3Environment()
    environment.reset(seed=1)

    distinct_rewards = set()
    for action_type in (
        ActionType.LIST_FUNCTIONS,
        ActionType.GET_PROPERTY_SPECIFICATION,
        ActionType.GET_CALL_GRAPH,
    ):
        outcome = environment.step(Action(action_type=action_type))
        distinct_rewards.add(outcome.reward.value)

    assert len(distinct_rewards) >= 2
def check_app_imports():
    """Import the FastAPI app and probe ``/health`` and ``/tasks`` via a test client.

    ``/health`` must answer 200 and ``/tasks`` must list exactly three tasks
    whose ``status`` is ``"active"``.
    """
    from app import app
    from fastapi.testclient import TestClient

    client = TestClient(app)

    health = client.get("/health")
    assert health.status_code == 200

    listed = client.get("/tasks").json()["tasks"]
    active = []
    for task in listed:
        if task["status"] == "active":
            active.append(task)

    assert len(active) == 3, f"Expected 3 active tasks, got {len(active)}: {active}"
def check_t3_http_reset():
    """POST a Task 3 reset to the FastAPI app and inspect the returned observation.

    The response must be 200, and its observation must report task id
    ``task3_rule_checker`` and include ``property_english`` under ``extra``.
    """
    from app import app
    from fastapi.testclient import TestClient

    client = TestClient(app)
    payload = {"task_id": "task3_rule_checker", "seed": 42}

    response = client.post("/reset", json=payload)
    assert response.status_code == 200

    observation = response.json()["observation"]
    assert observation["task_id"] == "task3_rule_checker"
    assert "property_english" in observation["extra"]
def check_dockerfile():
    """Check that a ``Dockerfile`` exists in the working directory and looks runnable.

    The file must contain the text ``7860`` and at least one of ``uvicorn``
    or ``CMD``.

    Raises:
        AssertionError: if the file is missing or a required marker is absent;
            the message names what was not found so the report is not blank.
    """
    import os

    assert os.path.exists("Dockerfile"), "Dockerfile not found"
    # Context manager so the handle is closed deterministically.
    with open("Dockerfile", encoding="utf-8") as handle:
        content = handle.read()
    assert "7860" in content, "Dockerfile does not mention port 7860"
    assert "uvicorn" in content or "CMD" in content, \
        "Dockerfile contains neither 'uvicorn' nor 'CMD'"
def check_inference_script():
    """Check that ``inference.py`` exists and mentions the expected identifiers.

    The file text must contain ``HF_TOKEN``, ``API_BASE_URL`` and
    ``MODEL_NAME``; at least one of ``Task3Environment`` or ``run_task3``;
    and ``submit_function``. This is a plain substring scan, not an import.

    Raises:
        AssertionError: if the file is missing or any marker is absent; the
            message names the missing marker so the report is not blank.
    """
    import os

    assert os.path.exists("inference.py"), "inference.py not found"
    # Context manager so the handle is closed deterministically.
    with open("inference.py", encoding="utf-8") as handle:
        content = handle.read()

    for marker in ("HF_TOKEN", "API_BASE_URL", "MODEL_NAME"):
        assert marker in content, f"inference.py does not reference {marker}"
    assert "Task3Environment" in content or "run_task3" in content, \
        "inference.py contains neither 'Task3Environment' nor 'run_task3'"
    assert "submit_function" in content, \
        "inference.py does not reference submit_function"
def check_baseline_json():
    """Validate score ranges in ``baseline_scores.json`` when the file is present.

    A missing file is not an error: the function returns without checking
    anything. Otherwise every item under the top-level ``tasks`` list must
    have an ``avg_grader_score`` within ``[0.0, 1.0]``.

    Raises:
        AssertionError: if a score lies outside ``[0.0, 1.0]``.
        KeyError: if a task item has no ``avg_grader_score`` key.
    """
    import os

    if not os.path.exists("baseline_scores.json"):
        return
    # Context manager so the handle is closed deterministically.
    with open("baseline_scores.json", encoding="utf-8") as handle:
        data = json.load(handle)
    for t in data.get("tasks", []):
        score = t["avg_grader_score"]
        assert 0.0 <= score <= 1.0, \
            f"avg_grader_score out of range [0.0, 1.0]: {score!r}"
# ── Runner ────────────────────────────────────────────────────────────────────
# Ordered registry of (display name, check callable) pairs; main() runs them
# top to bottom. Display names use the intended "≥" and "→" characters
# (they had been corrupted by a bad text decoding).
ALL_CHECKS = [
    ("Python imports (T1+T2+T3)", check_imports),
    ("openenv.yaml: 3 tasks, ≥2 active", check_openenv_yaml),
    ("Pydantic models (all ActionTypes)", check_pydantic_models),
    ("Dataset: vuln+property+task3 entries", check_data_loading),
    ("T1 env: reset/step/state", check_t1_env),
    ("T2 env: reset + 6 browse actions", check_t2_env),
    ("T3 env: reset + browse actions", check_t3_env),
    ("T3 action costs (formalized -0.03)", check_t3_action_costs),
    ("T3 get_function_metadata", check_t3_function_metadata),
    ("T3 submit correct → +5.0", check_t3_submit_correct),
    ("T3 submit subfunction → +1.5", check_t3_submit_subfunction),
    ("T3 submit wrong → -1.5", check_t3_submit_wrong),
    ("T3 one submit per episode", check_t3_one_submit_only),
    ("T3 repeated query → -0.40", check_t3_repeated_penalty),
    ("T1 grader: 0/0.5/1.0 rubric", check_t1_grader),
    ("T2 grader: all 11 properties", check_t2_grader),
    ("T3 grader: 1.0/0.3/0.0 + case-ins.", check_t3_grader),
    ("Reward shaping non-binary (T3)", check_reward_shaping),
    ("FastAPI: 3 active tasks", check_app_imports),
    ("FastAPI: T3 reset endpoint", check_t3_http_reset),
    ("Dockerfile + port 7860", check_dockerfile),
    ("inference.py: T3 code present", check_inference_script),
    ("baseline_scores.json schema", check_baseline_json),
]
def main():
    """Run every registered check, print a summary and exit with a status code.

    Each entry of ``ALL_CHECKS`` is executed through ``check``. The process
    then exits via ``sys.exit``: status 1 if any check failed (after listing
    the failures), status 0 otherwise.
    """
    print("=" * 64)
    print("OpenEnv Pre-Submission Validation (Task 1 + 2 + 3)")
    print("=" * 64)
    print()

    for name, fn in ALL_CHECKS:
        check(name, fn)

    passed = sum(1 for _, ok, _ in results if ok)
    total = len(results)
    failed = [(n, m) for n, ok, m in results if not ok]

    print()
    print("=" * 64)
    print(f"Results: {passed}/{total} checks passed")

    if failed:
        print("\nFailed checks:")
        for n, m in failed:
            print(f" {FAIL} {n}: {m}")
        print("\n❌ VALIDATION FAILED")
        sys.exit(1)
    else:
        # Banner text restored to the intended check mark and em dash
        # (previously corrupted by a bad text decoding).
        print("\n✅ ALL CHECKS PASSED — ready to submit!")
        sys.exit(0)
# Run the validation suite only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()