SmartContractAudit / validate.py
ajaxwin
Initial Commit
08c19c7
raw
history blame
11 kB
"""
validate.py
-----------
Pre-submission validation script.
Checks all OpenEnv spec requirements locally before submitting.
Usage:
python validate.py
Exit code 0 = all checks pass.
Exit code 1 = one or more checks failed.
"""
import json
import sys
import traceback
from typing import Callable, List, Tuple
# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────
# Status glyphs printed in front of each check name in the report.
PASS = "✅"
FAIL = "❌"
SKIP = "⏭ "

# One (check name, passed?, failure reason) tuple per executed check, in run order.
results: List[Tuple[str, bool, str]] = []


def check(name: str, fn: Callable[[], None]) -> None:
    """Run one validation check, print its status line and record the outcome.

    Args:
        name: Human-readable label shown in the report.
        fn: Zero-argument callable that raises (typically ``AssertionError``)
            when the requirement it verifies is not met.

    Any ``Exception`` raised by ``fn`` is caught and recorded as a failure in
    ``results`` so that the remaining checks still run.
    """
    try:
        fn()
    except Exception as e:
        # A bare ``assert cond`` carries no message, so str(e) is empty; fall
        # back to the exception type so the report never shows a blank reason.
        detail = str(e) or type(e).__name__
        results.append((name, False, detail))
        print(f" {FAIL} {name}")
        print(f" {detail}")
    else:
        # Success bookkeeping lives outside the ``try`` so that an error while
        # reporting is not mistaken for a failure of the check itself.
        results.append((name, True, ""))
        print(f" {PASS} {name}")
# ─────────────────────────────────────────────────────────────────────────────
# Checks
# ─────────────────────────────────────────────────────────────────────────────
def check_imports():
    """Import the core project modules; an import error fails the check."""
    from env.schemas import (  # noqa: F401
        Observation,
        Action,
        Reward,
        StepResult,
        ResetResult,
        StateResult,
    )
    from tasks.task1.environment import Task1Environment  # noqa: F401
    from tasks.task1.grader import Task1Grader  # noqa: F401
    from data.data_loader import load_contracts  # noqa: F401
def check_openenv_yaml():
    """Verify that ``openenv.yaml`` parses and declares the required sections.

    Required top-level keys: ``name``, ``tasks`` (with at least three entries),
    ``observation_space``, ``action_space`` and ``reward``.

    Raises:
        AssertionError: if the document is not a mapping, a required key is
            missing, or fewer than three tasks are defined.
    """
    import yaml

    # Explicit encoding so the read does not depend on the platform default.
    with open("openenv.yaml", encoding="utf-8") as f:
        spec = yaml.safe_load(f)

    # An empty file loads as None; report that clearly instead of letting the
    # membership tests below fail with an opaque TypeError.
    assert isinstance(spec, dict), "openenv.yaml must contain a top-level mapping"

    missing = "openenv.yaml is missing required key '{}'"
    for key in ("name", "tasks"):
        assert key in spec, missing.format(key)
    assert len(spec["tasks"]) >= 3, "Need at least 3 tasks defined"
    for key in ("observation_space", "action_space", "reward"):
        assert key in spec, missing.format(key)
def check_pydantic_models():
    """Construct each schema model once and read one field back from it."""
    from env.schemas import (
        Observation,
        Action,
        ActionType,
        Reward,
        StepResult,
        ResetResult,
        StateResult,
    )

    # Observation is built first because StepResult and ResetResult embed it.
    observation = Observation(
        task_id="t1",
        contract_name="C",
        contract_description="D",
        available_actions=["submit"],
    )
    assert observation.task_id == "t1"

    list_action = Action(action_type=ActionType.LIST_FUNCTIONS)
    assert list_action.action_type == ActionType.LIST_FUNCTIONS

    unit_reward = Reward(value=1.0, reason="test")
    assert unit_reward.value == 1.0

    step_result = StepResult(observation=observation, reward=unit_reward, done=False)
    assert not step_result.done

    reset_result = ResetResult(observation=observation)
    assert reset_result.observation.task_id == "t1"

    state_result = StateResult(
        task_id="t1",
        contract_name="C",
        step_count=0,
        cumulative_reward=0.0,
        done=False,
    )
    assert state_result.step_count == 0
def check_data_loading():
    """Load the contract dataset and sanity-check its vulnerable entries.

    Requires at least one contract and at least three vulnerable-function
    entries, each flagged ``vulnerable`` and carrying ``vulnerability_details``
    with an ``issue`` key.
    """
    from data.data_loader import load_contracts, get_all_vulnerable_entries

    contracts = load_contracts()
    assert len(contracts) >= 1, "No contracts loaded"

    vulnerable = get_all_vulnerable_entries(contracts)
    found = len(vulnerable)
    assert found >= 3, f"Need >= 3 vulnerable functions, got {found}"

    # Each entry is a (contract, function) pair; only the function is inspected.
    for _contract, func in vulnerable:
        assert func.get("vulnerable") is True
        assert func.get("vulnerability_details") is not None
        assert "issue" in func["vulnerability_details"]
def check_env_reset():
    """Reset a fresh ``Task1Environment`` and validate the initial observation."""
    from tasks.task1.environment import Task1Environment

    reset_result = Task1Environment().reset(seed=42)
    assert reset_result.observation is not None

    initial = reset_result.observation
    assert initial.task_id == "task1_vuln_detection"
    assert initial.contract_name != ""
    # A freshly reset episode is neither finished nor advanced.
    assert not initial.done
    assert initial.step_count == 0
def check_env_step():
    """Take one LIST_FUNCTIONS step and check the shape of the step result."""
    from tasks.task1.environment import Task1Environment
    from env.schemas import Action, ActionType

    env = Task1Environment()
    env.reset(seed=42)

    list_functions = Action(action_type=ActionType.LIST_FUNCTIONS)
    outcome = env.step(list_functions)

    assert outcome.observation is not None
    assert isinstance(outcome.reward.value, float)
    assert isinstance(outcome.done, bool)
    # The serialised result must carry an "info" entry.
    assert "info" in outcome.model_dump()
def check_env_state():
    """Reset the environment and check the fields reported by ``state()``."""
    from tasks.task1.environment import Task1Environment

    env = Task1Environment()
    env.reset(seed=42)
    snapshot = env.state()

    assert snapshot.task_id == "task1_vuln_detection"
    assert snapshot.contract_name != ""
    # The target function is exposed by state() for debugging purposes.
    assert snapshot.target_function is not None
def check_grader_scores_in_range():
    """Grade three fixed submissions and check each score.

    Every score must lie in [0.0, 1.0] and be within 0.01 of its expected
    value.
    """
    from tasks.task1.grader import Task1Grader

    # (target function, target issue, submitted function, submitted vuln, expected score)
    scenarios = [
        ("withdraw", "Reentrancy vulnerability", "withdraw", "reentrancy", 1.0),
        ("withdraw", "Reentrancy vulnerability", "withdraw", "something else", 0.5),
        ("withdraw", "Reentrancy vulnerability", "deposit", "reentrancy", 0.0),
    ]
    for scenario in scenarios:
        target_fn, target_issue, submitted_fn, submitted_vuln, expected = scenario
        grader = Task1Grader(target_fn, target_issue)
        score = grader.grade_submission(submitted_fn, submitted_vuln)
        assert 0.0 <= score <= 1.0, f"Score {score} out of range"
        assert abs(score - expected) < 0.01, f"Expected {expected}, got {score}"
def check_grader_deterministic():
    """Grade the same correct submission twice; both scores must equal 1.0."""
    from tasks.task1.grader import Task1Grader

    grader = Task1Grader("withdraw", "Reentrancy vulnerability")
    scores = [grader.grade_submission("withdraw", "reentrancy") for _ in range(2)]
    assert scores[0] == scores[1] == 1.0, "Grader must be deterministic"
def check_reward_shaping():
    """Check that three different query actions yield at least two distinct rewards."""
    from tasks.task1.environment import Task1Environment
    from env.schemas import Action, ActionType

    env = Task1Environment()
    env.reset(seed=1)

    probes = (
        ActionType.LIST_FUNCTIONS,
        ActionType.GET_FILE_METADATA,
        ActionType.GET_CALL_GRAPH,
    )
    # Rewards are rounded to 4 places before being compared for distinctness.
    rewards = {
        round(env.step(Action(action_type=kind)).reward.value, 4)
        for kind in probes
    }
    assert len(rewards) >= 2, f"Expected multiple reward values, got {rewards}"
def check_episode_boundary():
    """Stepping after a SUBMIT action must raise ``RuntimeError``."""
    from tasks.task1.environment import Task1Environment
    from env.schemas import Action, ActionType

    env = Task1Environment()
    env.reset(seed=2)

    submission = Action(
        action_type=ActionType.SUBMIT,
        params={"function_name": "withdraw", "vulnerability_type": "test"},
    )
    # NOTE(review): the SUBMIT result's ``done`` flag is not inspected here;
    # confirm whether this check should also assert it.
    env.step(submission)

    try:
        env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
    except RuntimeError:
        return  # Expected: the episode is over.
    raise AssertionError("Should have raised RuntimeError after episode end")
def check_repeated_query_penalty():
    """Check that repeating the same query yields a reward of -0.40.

    LIST_FUNCTIONS is issued twice in a row and the reward of the second step
    is compared against -0.40.

    Raises:
        AssertionError: if the second step's reward is not (approximately) -0.40.
    """
    import math

    from tasks.task1.environment import Task1Environment
    from env.schemas import Action, ActionType

    env = Task1Environment()
    env.reset(seed=3)
    env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
    repeated = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))

    penalty = repeated.reward.value
    # Tolerant comparison: exact ``==`` on floats breaks if the penalty is ever
    # produced by arithmetic rather than a literal. Any value that passed the
    # exact comparison still passes here.
    assert math.isclose(penalty, -0.40, abs_tol=1e-9), (
        f"Expected -0.40 for repeated query, got {penalty}"
    )
def check_tasks_list():
    """Tasks 2 and 3 must be importable and define ``__all__`` (placeholders are fine)."""
    from tasks.task2 import __all__ as task2_exports  # noqa
    from tasks.task3 import __all__ as task3_exports  # noqa
def check_dockerfile_exists():
    """Verify that a ``Dockerfile`` exists in the working directory and looks deployable.

    The file must mention ``7860`` and contain either ``uvicorn`` or ``CMD``.
    These are plain substring checks, not a Dockerfile parse.

    Raises:
        AssertionError: if the file is missing or one of the markers is absent.
    """
    import os

    assert os.path.exists("Dockerfile"), "Dockerfile is missing"
    # Explicit encoding so the read does not depend on the platform default.
    with open("Dockerfile", encoding="utf-8") as f:
        content = f.read()
    assert "7860" in content, "Dockerfile must EXPOSE 7860 (HF Spaces)"
    assert "uvicorn" in content or "CMD" in content, (
        "Dockerfile must start the server (no 'uvicorn' or 'CMD' found)"
    )
def check_inference_script():
    """Verify that ``inference.py`` exists and mentions the expected settings.

    The script text must contain ``OPENAI_API_KEY`` or ``HF_TOKEN``, plus
    ``API_BASE_URL`` and ``MODEL_NAME``. These are substring checks on the
    source text; the script is not executed.

    Raises:
        AssertionError: if the file is missing or a required name is absent.
    """
    import os

    assert os.path.exists("inference.py"), "inference.py is missing"
    # Explicit encoding so the read does not depend on the platform default.
    with open("inference.py", encoding="utf-8") as f:
        content = f.read()
    assert "OPENAI_API_KEY" in content or "HF_TOKEN" in content, \
        "inference.py must read API credentials from env vars"
    assert "API_BASE_URL" in content, "inference.py must reference API_BASE_URL"
    assert "MODEL_NAME" in content, "inference.py must reference MODEL_NAME"
def check_baseline_json_schema():
    """Validate ``baseline_scores.json`` when present; a missing file is accepted.

    The document must have a ``tasks`` entry, and every task must carry an
    ``avg_grader_score`` between 0.0 and 1.0 inclusive.

    Raises:
        AssertionError: if ``tasks`` or a task's ``avg_grader_score`` is
            missing, or a score lies outside [0.0, 1.0].
    """
    import os

    if not os.path.exists("baseline_scores.json"):
        return  # OK — file is generated at runtime
    # Explicit encoding so the read does not depend on the platform default.
    with open("baseline_scores.json", encoding="utf-8") as f:
        data = json.load(f)
    assert "tasks" in data, "baseline_scores.json is missing the 'tasks' key"
    for task in data["tasks"]:
        # Name the missing field instead of surfacing a bare KeyError.
        assert "avg_grader_score" in task, (
            "baseline_scores.json task entry is missing 'avg_grader_score'"
        )
        score = task["avg_grader_score"]
        assert 0.0 <= score <= 1.0, f"Score {score} out of range"
# ─────────────────────────────────────────────────────────────────────────────
# Runner
# ─────────────────────────────────────────────────────────────────────────────
def main():
    """Run every validation check in order, print a summary and exit.

    Exits the process with status 0 when all checks pass and 1 when at least
    one check fails.
    """
    print("=" * 60)
    print("OpenEnv Pre-Submission Validation")
    print("=" * 60)

    # (report label, check function) pairs, executed top to bottom.
    all_checks = [
        ("Python imports", check_imports),
        ("openenv.yaml format", check_openenv_yaml),
        ("Pydantic model types", check_pydantic_models),
        ("Dataset loading (3+ vulns)", check_data_loading),
        ("env.reset() → ResetResult", check_env_reset),
        ("env.step() → StepResult", check_env_step),
        ("env.state() → StateResult", check_env_state),
        ("Grader scores in [0.0, 1.0]", check_grader_scores_in_range),
        ("Grader is deterministic", check_grader_deterministic),
        ("Reward shaping (non-binary)", check_reward_shaping),
        ("Episode boundary (done=True)", check_episode_boundary),
        ("Repeated query penalty", check_repeated_query_penalty),
        ("Task 2 & 3 placeholders", check_tasks_list),
        ("Dockerfile exists + port", check_dockerfile_exists),
        ("inference.py exists + vars", check_inference_script),
        ("baseline_scores.json schema", check_baseline_json_schema),
    ]
    print()
    for name, fn in all_checks:
        check(name, fn)
    print()

    # ``results`` holds one (name, ok, message) tuple per executed check.
    passed = sum(1 for _, ok, _ in results if ok)
    total = len(results)
    failed = [(n, msg) for n, ok, msg in results if not ok]

    print("=" * 60)
    print(f"Results: {passed}/{total} checks passed")
    if failed:
        print("\nFailed checks:")
        for name, msg in failed:
            print(f" {FAIL} {name}: {msg}")
        print()
        print("❌ VALIDATION FAILED — fix the issues above before submitting.")
        sys.exit(1)
    else:
        print()
        print("✅ ALL CHECKS PASSED — ready to submit!")
        sys.exit(0)


# Run the validation suite only when executed as a script, not on import.
if __name__ == "__main__":
    main()