# websec-repair-env/tests/test_websec_repair_env.py
# Daksh Verma
# Upload folder using huggingface_hub
# 57c1397 verified
"""Tests for the deterministic WebSec Repair environment."""
from __future__ import annotations
from fastapi.testclient import TestClient
from websec_repair_env.models import WebSecRepairAction
from websec_repair_env.server.app import app
from websec_repair_env.server.websec_repair_environment import WebSecRepairEnvironment
# One row per task, consumed by the loop-based tests below.
# Columns: (task_id, vulnerability type, correct patch id, wrong patch id).
TASK_CASES = [
    (
        "sqli_login",
        "sql_injection",
        "parameterized_query",
        "strip_quotes",
    ),
    (
        "xss_comments",
        "xss",
        "html_escape",
        "remove_script_substring",
    ),
    (
        "broken_auth_admin",
        "broken_auth",
        "require_admin_role",
        "hide_admin_link",
    ),
]
def _solve_task(env: WebSecRepairEnvironment, task_id: str, vulnerability: str, patch_id: str):
    """Play one full episode on *env* and return the result of the final step.

    The environment is reset to *task_id* and then stepped through the fixed
    action sequence inspect -> classify -> apply_patch -> verify -> submit.
    *vulnerability* is sent with the classify action and *patch_id* with the
    apply_patch action. Only the value returned by the submit step is returned.
    """
    env.reset(task_id=task_id)

    # Keyword arguments for each action, in the order they are played. Every
    # action object is built right before it is stepped.
    action_script = (
        {"action_type": "inspect"},
        {"action_type": "classify", "vulnerability_type": vulnerability},
        {"action_type": "apply_patch", "patch_id": patch_id},
        {"action_type": "verify"},
        {"action_type": "submit"},
    )

    outcome = None
    for action_kwargs in action_script:
        outcome = env.step(WebSecRepairAction(**action_kwargs))
    return outcome
def test_reset_and_state_smoke() -> None:
    """Smoke-test ``reset`` and the state it exposes.

    Covers three things: the first reset returns an observation for the
    requested task with an empty code snippet, an ``inspect`` step flips
    ``state.inspected``, and a second reset to another task clears the
    per-episode fields again.
    """
    environment = WebSecRepairEnvironment()

    # Fresh episode: right task, nothing revealed or inspected yet.
    first_observation = environment.reset(task_id="xss_comments")
    assert first_observation.task_id == "xss_comments"
    assert first_observation.code_snippet == ""
    assert environment.state.inspected is False

    # Inspecting is recorded in the state and keeps the task id.
    inspect_action = WebSecRepairAction(action_type="inspect")
    environment.step(inspect_action)
    assert environment.state.inspected is True
    assert environment.state.task_id == "xss_comments"

    # Resetting to a different task wipes the progress made above.
    second_observation = environment.reset(task_id="broken_auth_admin")
    assert second_observation.task_id == "broken_auth_admin"
    assert environment.state.inspected is False
    assert environment.state.selected_vulnerability == ""
    assert environment.state.applied_patch_id == ""
    assert environment.state.exploit_test_passed is False
def test_happy_path_for_each_task() -> None:
    """Check that the correct patch solves every task in ``TASK_CASES``.

    Each case runs on a fresh environment through ``_solve_task``. The final
    step must report a finished episode with the grader, exploit test and
    functionality test all passing, and the state score must equal 1.0.

    All cases share this one test function, so every assertion message names
    the task; a bare failure would not say which case broke.
    """
    for task_id, vulnerability, correct_patch, _ in TASK_CASES:
        env = WebSecRepairEnvironment()
        result = _solve_task(env, task_id, vulnerability, correct_patch)

        assert result.done is True, f"{task_id}: episode did not finish after submit"
        assert result.grader_passed is True, (
            f"{task_id}: grader rejected correct patch {correct_patch!r}"
        )
        assert result.exploit_test_passed is True, (
            f"{task_id}: exploit test did not pass with patch {correct_patch!r}"
        )
        assert result.functionality_test_passed is True, (
            f"{task_id}: functionality test did not pass with patch {correct_patch!r}"
        )
        assert env.state.score == 1.0, (
            f"{task_id}: expected score 1.0, got {env.state.score!r}"
        )
def test_wrong_patch_failure_for_each_task() -> None:
    """Check that the wrong patch fails every task in ``TASK_CASES``.

    Each case runs on a fresh environment through ``_solve_task`` using the
    wrong patch id from the table. The episode must still finish, but the
    grader must not pass, the score must stay below 1.0, and at least one of
    the exploit test and the functionality test must be reported as failed.

    All cases share this one test function, so every assertion message names
    the task and the patch; a bare failure would not say which case broke.
    """
    for task_id, vulnerability, _, wrong_patch in TASK_CASES:
        env = WebSecRepairEnvironment()
        result = _solve_task(env, task_id, vulnerability, wrong_patch)

        assert result.done is True, f"{task_id}: episode did not finish after submit"
        assert result.grader_passed is False, (
            f"{task_id}: grader accepted wrong patch {wrong_patch!r}"
        )
        assert env.state.score < 1.0, (
            f"{task_id}: wrong patch {wrong_patch!r} scored {env.state.score!r}"
        )
        assert (
            result.exploit_test_passed is False
            or result.functionality_test_passed is False
        ), (
            f"{task_id}: exploit and functionality tests both passed "
            f"with wrong patch {wrong_patch!r}"
        )
def test_http_routes_return_expected_shapes() -> None:
    """Check status codes and top-level JSON shapes of the HTTP helper routes.

    Hits ``/tasks``, ``/baseline`` and ``/grader`` through a ``TestClient``
    bound to the app; the last two are queried for the ``sqli_login`` task.
    """
    http = TestClient(app)
    sqli_query = {"task_id": "sqli_login"}

    # GET /tasks: environment name plus the three tasks.
    response = http.get("/tasks")
    assert response.status_code == 200
    body = response.json()
    assert body["environment"] == "websec_repair_env"
    assert len(body["tasks"]) == 3

    # GET /baseline: exactly one baseline entry, for the requested task.
    response = http.get("/baseline", params=sqli_query)
    assert response.status_code == 200
    body = response.json()
    assert len(body["baselines"]) == 1
    assert body["baselines"][0]["task_id"] == "sqli_login"

    # GET /grader: echoes the task id and carries a "checks" entry.
    response = http.get("/grader", params=sqli_query)
    assert response.status_code == 200
    body = response.json()
    assert body["task_id"] == "sqli_login"
    assert "checks" in body