Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """Deterministic baseline agent for WebSec Repair Env.""" | |
| from __future__ import annotations | |
| import argparse | |
| import sys | |
| from pathlib import Path | |
# Directory containing this script, and the directory one level above it.
ROOT = Path(__file__).resolve().parent
PARENT = ROOT.parent

# Put PARENT at the front of the module search path unless it is already
# listed -- presumably so the `websec_repair_env` import that follows resolves
# when this file is run directly (TODO confirm package location).
if not any(entry == str(PARENT) for entry in sys.path):
    sys.path.insert(0, str(PARENT))
| from websec_repair_env import WebSecRepairAction, WebSecRepairEnv | |
# Vulnerability label used for each supported task when the scanner hint
# contains none of the phrases in HINT_KEYWORDS.
TASK_TO_VULNERABILITY = {
    "sqli_login": "sql_injection",
    "xss_comments": "xss",
    "broken_auth_admin": "broken_auth",
}

# Patch identifier associated with each supported task.
TASK_TO_PATCH = {
    "sqli_login": "parameterized_query",
    "xss_comments": "html_escape",
    "broken_auth_admin": "require_admin_role",
}

# Lower-case phrases searched for in the scanner hint, mapped to the label
# they imply. Entries are checked in insertion order; the first match wins.
HINT_KEYWORDS = {
    "sql injection": "sql_injection",
    "cross-site scripting": "xss",
    "xss": "xss",
    "access control": "broken_auth",
    "authorization": "broken_auth",
    "admin route": "broken_auth",
}


def choose_vulnerability(task_id: str, scanner_hint: str | None) -> str:
    """Pick the deterministic vulnerability label for the baseline.

    The scanner hint is lower-cased and searched for each phrase in
    ``HINT_KEYWORDS``; the label of the first phrase found is returned.
    When no phrase matches, the label is looked up by task in
    ``TASK_TO_VULNERABILITY``.

    Args:
        task_id: Task identifier; only consulted when the hint has no match.
        scanner_hint: Free-text hint. ``None`` or an empty string is treated
            as "no hint" so the per-task fallback is used.

    Returns:
        The chosen vulnerability label.

    Raises:
        KeyError: If no keyword matches and ``task_id`` is not a known task.
    """
    # A missing hint must not abort the run: fall through to the task table
    # instead of failing on ``None.lower()``.
    lowered = (scanner_hint or "").lower()
    for keyword, label in HINT_KEYWORDS.items():
        if keyword in lowered:
            return label
    return TASK_TO_VULNERABILITY[task_id]
def run_baseline(base_url: str, task_id: str) -> int:
    """Drive one episode through the fixed baseline action sequence.

    Opens a synchronous session with the environment at ``base_url``, resets
    it for ``task_id``, then issues inspect, classify, apply_patch, verify
    and submit in that order, printing a one-line summary after every call.

    Args:
        base_url: URL of the running environment.
        task_id: Task to request on reset.

    Returns:
        ``0`` when the final observation's ``grader_passed`` is truthy,
        otherwise ``1``.
    """
    client = WebSecRepairEnv(base_url=base_url)
    with client.sync() as env:
        outcome = env.reset(task_id=task_id)
        print(f"reset: task={outcome.observation.task_id}")

        inspect_action = WebSecRepairAction(action_type="inspect")
        outcome = env.step(inspect_action)
        print(f"inspect: reward={outcome.reward} status={outcome.observation.status_message}")

        # The label is derived from the task id and scanner hint reported by
        # the inspect observation, not from the ``task_id`` argument.
        label = choose_vulnerability(
            outcome.observation.task_id,
            outcome.observation.scanner_hint,
        )
        classify_action = WebSecRepairAction(
            action_type="classify",
            vulnerability_type=label,
        )
        outcome = env.step(classify_action)
        print(f"classify: reward={outcome.reward} selected={outcome.observation.selected_vulnerability}")

        chosen_patch = TASK_TO_PATCH[outcome.observation.task_id]
        patch_action = WebSecRepairAction(
            action_type="apply_patch",
            patch_id=chosen_patch,
        )
        outcome = env.step(patch_action)
        print(f"apply_patch: reward={outcome.reward} patch={outcome.observation.applied_patch_id}")

        verify_action = WebSecRepairAction(action_type="verify")
        outcome = env.step(verify_action)
        verify_line = (
            "verify: "
            f"reward={outcome.reward} exploit={outcome.observation.exploit_test_passed} "
            f"functionality={outcome.observation.functionality_test_passed}"
        )
        print(verify_line)

        submit_action = WebSecRepairAction(action_type="submit")
        outcome = env.step(submit_action)
        submit_line = (
            "submit: "
            f"reward={outcome.reward} done={outcome.done} passed={outcome.observation.grader_passed} "
            f"score={outcome.observation.reward}"
        )
        print(submit_line)

        # Translate the grader verdict into a status code for the caller.
        if outcome.observation.grader_passed:
            return 0
        return 1
def main() -> None:
    """Parse command-line options, run the baseline and exit with its status.

    Options:
        --base-url: Environment URL; defaults to ``http://127.0.0.1:8000``.
        --task: One of the keys of ``TASK_TO_VULNERABILITY``; defaults to
            ``sqli_login``.

    Raises:
        SystemExit: Always, carrying the value returned by ``run_baseline``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--base-url", default="http://127.0.0.1:8000")
    task_choices = sorted(TASK_TO_VULNERABILITY)
    cli.add_argument("--task", default="sqli_login", choices=task_choices)
    options = cli.parse_args()

    exit_code = run_baseline(options.base_url, options.task)
    raise SystemExit(exit_code)


if __name__ == "__main__":
    main()