File size: 3,619 Bytes
57c1397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Tests for the deterministic WebSec Repair environment."""

from __future__ import annotations

from fastapi.testclient import TestClient

from websec_repair_env.models import WebSecRepairAction
from websec_repair_env.server.app import app
from websec_repair_env.server.websec_repair_environment import WebSecRepairEnvironment


# One tuple per task, in the order
# (task_id, vulnerability_type, correct_patch_id, wrong_patch_id).
# The happy-path test applies the third element and ignores the fourth;
# the wrong-patch test applies the fourth and ignores the third.
TASK_CASES = [
    ("sqli_login", "sql_injection", "parameterized_query", "strip_quotes"),
    ("xss_comments", "xss", "html_escape", "remove_script_substring"),
    ("broken_auth_admin", "broken_auth", "require_admin_role", "hide_admin_link"),
]


def _solve_task(env: WebSecRepairEnvironment, task_id: str, vulnerability: str, patch_id: str):
    """Drive ``env`` through one complete episode and return the submit result.

    The environment is reset onto ``task_id`` and then stepped with the
    actions inspect, classify (using ``vulnerability``), apply_patch (using
    ``patch_id``), verify and submit, in that order. Only the value returned
    by the final ``submit`` step is handed back; earlier step results are
    discarded.
    """
    env.reset(task_id=task_id)

    # Keyword sets for every action that precedes the submit. Each action is
    # built right before it is stepped, so construction and stepping stay
    # interleaved.
    lead_up_fields = (
        {"action_type": "inspect"},
        {"action_type": "classify", "vulnerability_type": vulnerability},
        {"action_type": "apply_patch", "patch_id": patch_id},
        {"action_type": "verify"},
    )
    for fields in lead_up_fields:
        env.step(WebSecRepairAction(**fields))

    final_action = WebSecRepairAction(action_type="submit")
    return env.step(final_action)


def test_reset_and_state_smoke() -> None:
    """Smoke-check ``reset`` and ``state`` across two consecutive episodes."""
    environment = WebSecRepairEnvironment()

    # First reset: the observation names the requested task, carries an empty
    # code snippet, and the state reports that nothing has been inspected.
    first_observation = environment.reset(task_id="xss_comments")
    assert first_observation.task_id == "xss_comments"
    assert first_observation.code_snippet == ""
    assert environment.state.inspected is False

    # An inspect action flips the inspected flag; the task id is unchanged.
    inspect_action = WebSecRepairAction(action_type="inspect")
    environment.step(inspect_action)
    assert environment.state.inspected is True
    assert environment.state.task_id == "xss_comments"

    # Second reset onto a different task: the state fields checked below are
    # back at their empty / False values.
    second_observation = environment.reset(task_id="broken_auth_admin")
    assert second_observation.task_id == "broken_auth_admin"
    assert environment.state.inspected is False
    assert environment.state.selected_vulnerability == ""
    assert environment.state.applied_patch_id == ""
    assert environment.state.exploit_test_passed is False


def test_happy_path_for_each_task() -> None:
    """Check that the correct patch fully solves every task in ``TASK_CASES``.

    For each case a fresh environment is driven through ``_solve_task`` with
    the case's correct patch. The submit result must be done, pass the grader
    and pass both the exploit and functionality tests, and the environment
    score must equal 1.0.

    All cases run inside one test function, so every assertion carries a
    message with the task id; a failure then names the case that broke
    instead of only the line number.
    """
    for task_id, vulnerability, correct_patch, _ in TASK_CASES:
        env = WebSecRepairEnvironment()
        result = _solve_task(env, task_id, vulnerability, correct_patch)
        assert result.done is True, f"{task_id}: episode not done after submit"
        assert result.grader_passed is True, (
            f"{task_id}: grader rejected patch {correct_patch!r}"
        )
        assert result.exploit_test_passed is True, (
            f"{task_id}: exploit test did not pass with patch {correct_patch!r}"
        )
        assert result.functionality_test_passed is True, (
            f"{task_id}: functionality test did not pass with patch {correct_patch!r}"
        )
        assert env.state.score == 1.0, (
            f"{task_id}: expected score 1.0, got {env.state.score!r}"
        )


def test_wrong_patch_failure_for_each_task() -> None:
    """Check that the wrong patch is rejected for every task in ``TASK_CASES``.

    For each case a fresh environment is driven through ``_solve_task`` with
    the case's wrong patch. The episode must still finish, the grader must
    not pass, the score must stay below 1.0, and at least one of the exploit
    and functionality tests must have failed.

    All cases run inside one test function, so every assertion carries a
    message with the task id; a failure then names the case that broke
    instead of only the line number.
    """
    for task_id, vulnerability, _, wrong_patch in TASK_CASES:
        env = WebSecRepairEnvironment()
        result = _solve_task(env, task_id, vulnerability, wrong_patch)
        assert result.done is True, f"{task_id}: episode not done after submit"
        assert result.grader_passed is False, (
            f"{task_id}: grader accepted wrong patch {wrong_patch!r}"
        )
        assert env.state.score < 1.0, (
            f"{task_id}: expected score below 1.0, got {env.state.score!r}"
        )
        assert (
            result.exploit_test_passed is False
            or result.functionality_test_passed is False
        ), (
            f"{task_id}: wrong patch {wrong_patch!r} passed both tests "
            f"(exploit={result.exploit_test_passed!r}, "
            f"functionality={result.functionality_test_passed!r})"
        )


def test_http_routes_return_expected_shapes() -> None:
    """Hit ``/tasks``, ``/baseline`` and ``/grader`` and check their JSON bodies."""
    http = TestClient(app)
    target_task = "sqli_login"

    # /tasks: the body names the environment and lists three tasks.
    tasks_resp = http.get("/tasks")
    assert tasks_resp.status_code == 200
    tasks_body = tasks_resp.json()
    assert tasks_body["environment"] == "websec_repair_env"
    assert len(tasks_body["tasks"]) == 3

    # /baseline filtered by task id: a single baseline entry for that task.
    baseline_resp = http.get("/baseline", params={"task_id": target_task})
    assert baseline_resp.status_code == 200
    baseline_body = baseline_resp.json()
    assert len(baseline_body["baselines"]) == 1
    assert baseline_body["baselines"][0]["task_id"] == "sqli_login"

    # /grader for the same task: echoes the task id and includes "checks".
    grader_resp = http.get("/grader", params={"task_id": target_task})
    assert grader_resp.status_code == 200
    grader_body = grader_resp.json()
    assert grader_body["task_id"] == "sqli_login"
    assert "checks" in grader_body