Daksh Verma committed
Commit 57c1397 · verified · 1 Parent(s): 3282be6

Upload folder using huggingface_hub

Dockerfile ADDED
@@ -0,0 +1,17 @@
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
+ FROM ${BASE_IMAGE}
+
+ WORKDIR /app/env
+
+ COPY . /app/env
+
+ RUN python -m pip install --no-cache-dir "openenv-core[core]>=0.2.1" && \
+     python -m pip install --no-cache-dir --no-deps -e /app/env
+
+ ENV PYTHONPATH="/app/env:${PYTHONPATH}"
+
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+     CMD curl -f http://localhost:8000/health || exit 1
+
+ ENV ENABLE_WEB_INTERFACE=true
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,10 +1,121 @@
  ---
- title: Websec Repair Env
- emoji: 😻
- colorFrom: indigo
- colorTo: red
+ title: WebSec Repair Env
+ emoji: 🛡️
+ colorFrom: red
+ colorTo: gray
  sdk: docker
  pinned: false
+ app_port: 8000
+ base_path: /web
+ tags:
+   - openenv
+   - security
+   - web
+   - training
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # WebSec Repair Env
+
+ `websec_repair_env` is a lean OpenEnv environment for AI vulnerability review and remediation.
+ Each episode presents one deterministic vulnerability scenario and asks the agent to:
+
+ 1. inspect the task
+ 2. classify the vulnerability
+ 3. apply one discrete patch template
+ 4. verify exploit blocking and functionality preservation
+ 5. submit
+
+ The environment ships with exactly three tasks:
+
+ - `sqli_login`
+ - `xss_comments`
+ - `broken_auth_admin`
+
+ ## Actions
+
+ - `inspect`
+ - `classify`
+ - `apply_patch`
+ - `verify`
+ - `submit`
+
+ ## Reward
+
+ The score is absolute in `[0.0, 1.0]`:
+
+ - `0.25` correct classification
+ - `0.35` correct patch
+ - `0.20` exploit blocked
+ - `0.15` functionality preserved
+ - `0.05` successful submit
+
+ `step()` returns reward as score delta from the previous state.
+
+ ## Extra Routes
+
+ - `GET /tasks`
+ - `GET /grader`
+ - `GET /baseline`
+
+ `/grader` accepts optional `task_id`.
+ `/baseline` accepts optional `task_id` and returns a filtered catalog.
+
+ ## Local Usage
+
+ Install and lock dependencies:
+
+ ```bash
+ uv sync
+ ```
+
+ Run the server:
+
+ ```bash
+ uv run server
+ ```
+
+ Run the baseline agent against a running server:
+
+ ```bash
+ uv run python inference.py --task sqli_login
+ ```
+
+ Run tests:
+
+ ```bash
+ uv run pytest
+ ```
+
+ Validate structure:
+
+ ```bash
+ /home/dux/.openclaw/workspace/OpenEnv/venv/bin/openenv validate . --verbose
+ ```
+
+ Validate a running server:
+
+ ```bash
+ /home/dux/.openclaw/workspace/OpenEnv/venv/bin/openenv validate --url http://127.0.0.1:8000
+ ```
+
+ ## Docker
+
+ Build:
+
+ ```bash
+ docker build -t websec-repair-env:latest -f server/Dockerfile .
+ ```
+
+ Run:
+
+ ```bash
+ docker run --rm -p 8000:8000 websec-repair-env:latest
+ ```
+
+ ## Hugging Face Spaces
+
+ From this environment directory:
+
+ ```bash
+ /home/dux/.openclaw/workspace/OpenEnv/venv/bin/openenv push
+ ```
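The README's Reward section describes an absolute score built from five weighted components, with `step()` returning the change in that score. The following is a minimal sketch of that arithmetic, assuming the components are simply additive; the helper names `absolute_score` and `step_rewards` are hypothetical and do not appear in the commit, and the real grader may combine the checks differently.

```python
# Weights copied from the README's Reward section.
WEIGHTS = {
    "classification": 0.25,
    "patch": 0.35,
    "exploit": 0.20,
    "functionality": 0.15,
    "submit": 0.05,
}


def absolute_score(achieved: set) -> float:
    """Sum the weights of the components achieved so far (assumed additive)."""
    return round(sum(WEIGHTS[name] for name in achieved), 2)


def step_rewards(trajectory: list) -> list:
    """Return per-step rewards as the score delta from the previous state."""
    rewards = []
    previous = 0.0
    for achieved in trajectory:
        current = absolute_score(achieved)
        rewards.append(round(current - previous, 2))
        previous = current
    return rewards


# Hypothetical perfect episode: inspect, classify, apply_patch, verify, submit.
perfect = [
    set(),
    {"classification"},
    {"classification", "patch"},
    {"classification", "patch", "exploit", "functionality"},
    {"classification", "patch", "exploit", "functionality", "submit"},
]
rewards = step_rewards(perfect)
```

Under these assumptions the per-step rewards of a perfect episode sum to the final absolute score, so an agent that tracks cumulative reward can recover the score without calling `/grader`.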
__init__.py ADDED
@@ -0,0 +1,11 @@
+ """OpenEnv environment for deterministic web security repair tasks."""
+
+ from .client import WebSecRepairEnv
+ from .models import WebSecRepairAction, WebSecRepairObservation, WebSecRepairState
+
+ __all__ = [
+     "WebSecRepairAction",
+     "WebSecRepairObservation",
+     "WebSecRepairState",
+     "WebSecRepairEnv",
+ ]
client.py ADDED
@@ -0,0 +1,66 @@
+ """Typed client for the WebSec Repair environment."""
+
+ from __future__ import annotations
+
+ from typing import Any, Dict
+
+ from openenv.core import EnvClient
+ from openenv.core.client_types import StepResult
+
+ from .models import PatchOption, WebSecRepairAction, WebSecRepairObservation, WebSecRepairState
+
+
+ class WebSecRepairEnv(
+     EnvClient[WebSecRepairAction, WebSecRepairObservation, WebSecRepairState]
+ ):
+     """WebSocket client for the deterministic web security repair environment."""
+
+     def _step_payload(self, action: WebSecRepairAction) -> Dict[str, Any]:
+         return {
+             "action_type": action.action_type,
+             "vulnerability_type": action.vulnerability_type,
+             "patch_id": action.patch_id,
+             "metadata": action.metadata,
+         }
+
+     def _parse_result(
+         self,
+         payload: Dict[str, Any],
+     ) -> StepResult[WebSecRepairObservation]:
+         obs_data = payload.get("observation", {})
+         observation = WebSecRepairObservation(
+             task_id=obs_data.get("task_id", ""),
+             instruction=obs_data.get("instruction", ""),
+             code_snippet=obs_data.get("code_snippet", ""),
+             scanner_hint=obs_data.get("scanner_hint", ""),
+             status_message=obs_data.get("status_message", ""),
+             selected_vulnerability=obs_data.get("selected_vulnerability", ""),
+             applied_patch_id=obs_data.get("applied_patch_id", ""),
+             patch_options=[PatchOption(**item) for item in obs_data.get("patch_options", [])],
+             exploit_test_passed=obs_data.get("exploit_test_passed", False),
+             functionality_test_passed=obs_data.get("functionality_test_passed", False),
+             grader_passed=obs_data.get("grader_passed", False),
+             done=payload.get("done", False),
+             reward=payload.get("reward"),
+             metadata=obs_data.get("metadata", {}),
+         )
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: Dict[str, Any]) -> WebSecRepairState:
+         return WebSecRepairState(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+             task_id=payload.get("task_id", "sqli_login"),
+             difficulty=payload.get("difficulty", "easy"),
+             inspected=payload.get("inspected", False),
+             selected_vulnerability=payload.get("selected_vulnerability", ""),
+             applied_patch_id=payload.get("applied_patch_id", ""),
+             exploit_test_passed=payload.get("exploit_test_passed", False),
+             functionality_test_passed=payload.get("functionality_test_passed", False),
+             submitted=payload.get("submitted", False),
+             score=payload.get("score", 0.0),
+         )
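`_parse_result` in client.py reads every field with `.get(...)` and a default, so a partial or empty payload still yields a valid observation. The sketch below illustrates that pattern with the standard library only. `ObservationSketch` is a hypothetical, trimmed-down stand-in for `WebSecRepairObservation`, not the real pydantic model, and it carries only a few of the fields.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ObservationSketch:
    """Hypothetical, trimmed-down stand-in for WebSecRepairObservation."""

    task_id: str = ""
    status_message: str = ""
    exploit_test_passed: bool = False
    done: bool = False
    reward: Optional[float] = None


def parse_result(payload: Dict[str, Any]) -> ObservationSketch:
    """Build an observation, falling back to defaults for missing keys."""
    obs_data = payload.get("observation", {})
    return ObservationSketch(
        task_id=obs_data.get("task_id", ""),
        status_message=obs_data.get("status_message", ""),
        exploit_test_passed=obs_data.get("exploit_test_passed", False),
        # done and reward are read from the top level of the payload,
        # not from the nested "observation" object.
        done=payload.get("done", False),
        reward=payload.get("reward"),
    )


empty = parse_result({})
partial = parse_result({"observation": {"task_id": "sqli_login"}, "reward": 0.25})
```

One consequence of this design is that a malformed server response is not rejected at parse time; a caller that needs strict validation would have to check required fields itself.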
inference.py ADDED
@@ -0,0 +1,108 @@
+ #!/usr/bin/env python3
+ """Deterministic baseline agent for WebSec Repair Env."""
+
+ from __future__ import annotations
+
+ import argparse
+ import sys
+ from pathlib import Path
+
+ ROOT = Path(__file__).resolve().parent
+ PARENT = ROOT.parent
+ if str(PARENT) not in sys.path:
+     sys.path.insert(0, str(PARENT))
+
+ from websec_repair_env import WebSecRepairAction, WebSecRepairEnv
+
+
+ TASK_TO_VULNERABILITY = {
+     "sqli_login": "sql_injection",
+     "xss_comments": "xss",
+     "broken_auth_admin": "broken_auth",
+ }
+
+ TASK_TO_PATCH = {
+     "sqli_login": "parameterized_query",
+     "xss_comments": "html_escape",
+     "broken_auth_admin": "require_admin_role",
+ }
+
+ HINT_KEYWORDS = {
+     "sql injection": "sql_injection",
+     "cross-site scripting": "xss",
+     "xss": "xss",
+     "access control": "broken_auth",
+     "authorization": "broken_auth",
+     "admin route": "broken_auth",
+ }
+
+
+ def choose_vulnerability(task_id: str, scanner_hint: str) -> str:
+     """Pick the deterministic vulnerability label for the baseline."""
+     lowered = scanner_hint.lower()
+     for keyword, label in HINT_KEYWORDS.items():
+         if keyword in lowered:
+             return label
+     return TASK_TO_VULNERABILITY[task_id]
+
+
+ def run_baseline(base_url: str, task_id: str) -> int:
+     """Run the deterministic baseline policy against a running environment."""
+     with WebSecRepairEnv(base_url=base_url).sync() as env:
+         result = env.reset(task_id=task_id)
+         print(f"reset: task={result.observation.task_id}")
+
+         result = env.step(WebSecRepairAction(action_type="inspect"))
+         print(f"inspect: reward={result.reward} status={result.observation.status_message}")
+
+         vulnerability = choose_vulnerability(
+             result.observation.task_id,
+             result.observation.scanner_hint,
+         )
+         result = env.step(
+             WebSecRepairAction(
+                 action_type="classify",
+                 vulnerability_type=vulnerability,
+             )
+         )
+         print(f"classify: reward={result.reward} selected={result.observation.selected_vulnerability}")
+
+         patch_id = TASK_TO_PATCH[result.observation.task_id]
+         result = env.step(
+             WebSecRepairAction(
+                 action_type="apply_patch",
+                 patch_id=patch_id,
+             )
+         )
+         print(f"apply_patch: reward={result.reward} patch={result.observation.applied_patch_id}")
+
+         result = env.step(WebSecRepairAction(action_type="verify"))
+         print(
+             "verify: "
+             f"reward={result.reward} exploit={result.observation.exploit_test_passed} "
+             f"functionality={result.observation.functionality_test_passed}"
+         )
+
+         result = env.step(WebSecRepairAction(action_type="submit"))
+         print(
+             "submit: "
+             f"reward={result.reward} done={result.done} passed={result.observation.grader_passed} "
+             f"score={result.observation.reward}"
+         )
+         return 0 if result.observation.grader_passed else 1
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--base-url", default="http://127.0.0.1:8000")
+     parser.add_argument(
+         "--task",
+         default="sqli_login",
+         choices=sorted(TASK_TO_VULNERABILITY),
+     )
+     args = parser.parse_args()
+     raise SystemExit(run_baseline(args.base_url, args.task))
+
+
+ if __name__ == "__main__":
+     main()
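The keyword lookup in `choose_vulnerability` can be exercised without a running server. The sketch below copies the two lookup tables and the function from inference.py and feeds it the three scanner hints defined in server/challenge.py, plus an empty hint to show the per-task fallback. The `SAMPLE_HINTS` dictionary is an illustrative grouping that does not exist in the commit.

```python
TASK_TO_VULNERABILITY = {
    "sqli_login": "sql_injection",
    "xss_comments": "xss",
    "broken_auth_admin": "broken_auth",
}

HINT_KEYWORDS = {
    "sql injection": "sql_injection",
    "cross-site scripting": "xss",
    "xss": "xss",
    "access control": "broken_auth",
    "authorization": "broken_auth",
    "admin route": "broken_auth",
}


def choose_vulnerability(task_id: str, scanner_hint: str) -> str:
    """Return the first keyword match, else the per-task default label."""
    lowered = scanner_hint.lower()
    # Dicts preserve insertion order, so earlier keywords win on ties.
    for keyword, label in HINT_KEYWORDS.items():
        if keyword in lowered:
            return label
    return TASK_TO_VULNERABILITY[task_id]


# Scanner hints as they appear in server/challenge.py (grouped here for illustration).
SAMPLE_HINTS = {
    "sqli_login": (
        "Scanner warning: SQL injection risk detected in login query due to "
        "string concatenation with user-controlled input."
    ),
    "xss_comments": (
        "Scanner warning: Cross-site scripting risk detected because comment text is "
        "rendered into HTML without escaping."
    ),
    "broken_auth_admin": (
        "Scanner warning: Broken access control. Authenticated non-admin users can reach "
        "the admin route because role enforcement is missing."
    ),
}

labels = {task_id: choose_vulnerability(task_id, hint) for task_id, hint in SAMPLE_HINTS.items()}
fallback = choose_vulnerability("xss_comments", "")
```

Because the match is a plain substring test, a hint that mentions two vulnerability classes resolves to whichever keyword is listed first in `HINT_KEYWORDS`.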
models.py ADDED
@@ -0,0 +1,160 @@
+ """Typed models for the WebSec Repair environment."""
+
+ from __future__ import annotations
+
+ from typing import Literal
+
+ from openenv.core.env_server.types import Action, Observation, State
+ from pydantic import BaseModel, Field
+
+
+ class PatchOption(BaseModel):
+     """One available patch template for the active vulnerability task."""
+
+     id: str
+     label: str
+
+
+ class WebSecRepairAction(Action):
+     """Actions available to an agent inside the environment."""
+
+     action_type: Literal["inspect", "classify", "apply_patch", "verify", "submit"]
+     vulnerability_type: str | None = Field(
+         default=None,
+         description="Chosen vulnerability label for classify actions.",
+     )
+     patch_id: str | None = Field(
+         default=None,
+         description="Chosen patch option for apply_patch actions.",
+     )
+
+
+ class WebSecRepairObservation(Observation):
+     """Observation returned after each environment step."""
+
+     task_id: str = Field(default="", description="Active task id.")
+     instruction: str = Field(default="", description="Task instruction.")
+     code_snippet: str = Field(
+         default="",
+         description="Visible vulnerable code snippet after inspect.",
+     )
+     scanner_hint: str = Field(
+         default="",
+         description="Visible scanner hint after inspect.",
+     )
+     status_message: str = Field(
+         default="",
+         description="Human-readable result of the last action.",
+     )
+     selected_vulnerability: str = Field(
+         default="",
+         description="Currently selected vulnerability label.",
+     )
+     applied_patch_id: str = Field(
+         default="",
+         description="Currently applied patch option id.",
+     )
+     patch_options: list[PatchOption] = Field(
+         default_factory=list,
+         description="Patch options visible after inspect.",
+     )
+     exploit_test_passed: bool = Field(
+         default=False,
+         description="Whether the exploit is blocked.",
+     )
+     functionality_test_passed: bool = Field(
+         default=False,
+         description="Whether the legitimate behavior is preserved.",
+     )
+     grader_passed: bool = Field(
+         default=False,
+         description="Whether the task fully passes the grader.",
+     )
+
+
+ class WebSecRepairState(State):
+     """Persistent episode state tracked by the environment."""
+
+     task_id: str = Field(default="sqli_login", description="Active task id.")
+     difficulty: str = Field(default="easy", description="Difficulty bucket.")
+     inspected: bool = Field(
+         default=False,
+         description="Whether inspect has been called in this episode.",
+     )
+     selected_vulnerability: str = Field(
+         default="",
+         description="Chosen vulnerability classification.",
+     )
+     applied_patch_id: str = Field(default="", description="Chosen patch option id.")
+     exploit_test_passed: bool = Field(
+         default=False,
+         description="Whether verify blocked the exploit.",
+     )
+     functionality_test_passed: bool = Field(
+         default=False,
+         description="Whether verify preserved functionality.",
+     )
+     submitted: bool = Field(default=False, description="Whether submit was called.")
+     score: float = Field(default=0.0, description="Current absolute grader score.")
+
+
+ class TaskDefinition(BaseModel):
+     """Public task metadata."""
+
+     id: str
+     difficulty: str
+     title: str
+     instruction: str
+     code_snippet: str
+     scanner_hint: str
+     patch_options: list[PatchOption] = Field(default_factory=list)
+
+
+ class TaskCatalog(BaseModel):
+     """Response model for /tasks."""
+
+     environment: str
+     default_task_id: str
+     tasks: list[TaskDefinition]
+
+
+ class GraderCheck(BaseModel):
+     """One individual grader check."""
+
+     name: str
+     passed: bool
+     detail: str
+
+
+ class GraderReport(BaseModel):
+     """Current grader state for a task."""
+
+     task_id: str
+     passed: bool
+     score: float
+     message: str
+     checks: list[GraderCheck] = Field(default_factory=list)
+
+
+ class BaselineActionStep(BaseModel):
+     """One baseline action in the reference trajectory."""
+
+     action_type: Literal["inspect", "classify", "apply_patch", "verify", "submit"]
+     vulnerability_type: str | None = None
+     patch_id: str | None = None
+
+
+ class BaselineDefinition(BaseModel):
+     """Public baseline trajectory for one task."""
+
+     task_id: str
+     title: str
+     description: str
+     actions: list[BaselineActionStep] = Field(default_factory=list)
+
+
+ class BaselineCatalog(BaseModel):
+     """Response model for /baseline."""
+
+     environment: str
+     baselines: list[BaselineDefinition] = Field(default_factory=list)
openenv.yaml ADDED
@@ -0,0 +1,6 @@
+ spec_version: 1
+ name: websec_repair_env
+ type: space
+ runtime: fastapi
+ app: server.app:app
+ port: 8000
openenv_websec_repair_env.egg-info/PKG-INFO ADDED
@@ -0,0 +1,131 @@
+ Metadata-Version: 2.4
+ Name: openenv-websec-repair-env
+ Version: 0.1.0
+ Summary: Deterministic OpenEnv environment for web vulnerability repair tasks
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: openenv-core[core]>=0.2.1
+ Provides-Extra: dev
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
openenv_websec_repair_env.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,17 @@
+ README.md
+ pyproject.toml
+ ./__init__.py
+ ./client.py
+ ./inference.py
+ ./models.py
+ openenv_websec_repair_env.egg-info/PKG-INFO
+ openenv_websec_repair_env.egg-info/SOURCES.txt
+ openenv_websec_repair_env.egg-info/dependency_links.txt
+ openenv_websec_repair_env.egg-info/entry_points.txt
+ openenv_websec_repair_env.egg-info/requires.txt
+ openenv_websec_repair_env.egg-info/top_level.txt
+ server/__init__.py
+ server/app.py
+ server/challenge.py
+ server/websec_repair_environment.py
+ tests/test_websec_repair_env.py
openenv_websec_repair_env.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
openenv_websec_repair_env.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ server = websec_repair_env.server.app:main
openenv_websec_repair_env.egg-info/requires.txt ADDED
@@ -0,0 +1,4 @@
+ openenv-core[core]>=0.2.1
+
+ [dev]
+ pytest>=8.0.0
openenv_websec_repair_env.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ websec_repair_env
pyproject.toml ADDED
@@ -0,0 +1,26 @@
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-websec-repair-env"
+ version = "0.1.0"
+ description = "Deterministic OpenEnv environment for web vulnerability repair tasks"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+     "openenv-core[core]>=0.2.1",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+ ]
+
+ [project.scripts]
+ server = "websec_repair_env.server.app:main"
+
+ [tool.setuptools]
+ include-package-data = true
+ packages = ["websec_repair_env", "websec_repair_env.server"]
+ package-dir = { "websec_repair_env" = ".", "websec_repair_env.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1 @@
+ """Server package for the WebSec Repair environment."""
server/app.py ADDED
@@ -0,0 +1,79 @@
+ """FastAPI application for the WebSec Repair environment."""
+
+ from __future__ import annotations
+
+ import argparse
+
+ from fastapi import HTTPException
+
+ try:
+     from openenv.core.env_server.http_server import create_app
+
+     from ..models import BaselineCatalog, GraderReport, TaskCatalog, WebSecRepairAction, WebSecRepairObservation
+     from .challenge import grade_task, list_baselines, list_tasks
+     from .websec_repair_environment import WebSecRepairEnvironment
+ except ImportError:
+     from openenv.core.env_server.http_server import create_app
+
+     from models import (  # type: ignore
+         BaselineCatalog,
+         GraderReport,
+         TaskCatalog,
+         WebSecRepairAction,
+         WebSecRepairObservation,
+     )
+     from server.challenge import grade_task, list_baselines, list_tasks  # type: ignore
+     from server.websec_repair_environment import WebSecRepairEnvironment  # type: ignore
+
+
+ app = create_app(
+     WebSecRepairEnvironment,
+     WebSecRepairAction,
+     WebSecRepairObservation,
+     env_name="websec_repair_env",
+     max_concurrent_envs=1,
+ )
+
+
+ @app.get("/tasks", response_model=TaskCatalog, tags=["challenge"])
+ def tasks() -> TaskCatalog:
+     """List all deterministic vulnerability repair tasks."""
+     return list_tasks()
+
+
+ @app.get("/grader", response_model=GraderReport, tags=["challenge"])
+ def grader(task_id: str | None = None) -> GraderReport:
+     """Return the current grader state, optionally filtered to a task id."""
+     try:
+         return grade_task(task_id=task_id)
+     except ValueError as exc:
+         raise HTTPException(status_code=404, detail=str(exc)) from exc
+
+
+ @app.get("/baseline", response_model=BaselineCatalog, tags=["challenge"])
+ def baseline(task_id: str | None = None) -> BaselineCatalog:
+     """Return all baseline trajectories or one filtered task baseline."""
+     try:
+         return list_baselines(task_id=task_id)
+     except ValueError as exc:
+         raise HTTPException(status_code=404, detail=str(exc)) from exc
+
+
+ def serve(host: str = "0.0.0.0", port: int = 8000) -> None:
+     """Run the FastAPI app with uvicorn."""
+     import uvicorn
+
+     uvicorn.run(app, host=host, port=port)
+
+
+ def main() -> None:
+     """CLI entry point exposed via the project script."""
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", default="0.0.0.0")
+     parser.add_argument("--port", type=int, default=8000)
+     args = parser.parse_args()
+     serve(host=args.host, port=args.port)
+
+
+ if __name__ == "__main__":
+     main()
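The three extra routes in server/app.py are plain `GET` endpoints, and `/grader` and `/baseline` take `task_id` as an optional query parameter. The sketch below shows one way a caller might build those URLs with the standard library. The helper name `route_url` is hypothetical, and the base URL is the local default used elsewhere in this commit.

```python
from __future__ import annotations

from urllib.parse import urlencode


def route_url(base_url: str, route: str, task_id: str | None = None) -> str:
    """Join a base URL and route, adding ?task_id=... only when one is given."""
    url = f"{base_url.rstrip('/')}/{route.lstrip('/')}"
    if task_id is not None:
        # urlencode percent-escapes the value, so unusual ids stay valid in a URL.
        url += "?" + urlencode({"task_id": task_id})
    return url


BASE_URL = "http://127.0.0.1:8000"  # default host and port from server/app.py

tasks_url = route_url(BASE_URL, "/tasks")
grader_url = route_url(BASE_URL, "/grader", task_id="sqli_login")
baseline_url = route_url(BASE_URL + "/", "baseline", task_id="xss_comments")
```

The resulting strings can be passed to any HTTP client. According to the handlers above, an unknown `task_id` on `/grader` or `/baseline` is reported as a 404 whose detail carries the `ValueError` message.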
server/challenge.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Task definitions, baseline trajectories, and grader logic."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from copy import deepcopy
6
+
7
+ try:
8
+ from ..models import (
9
+ BaselineActionStep,
10
+ BaselineCatalog,
11
+ BaselineDefinition,
12
+ GraderCheck,
13
+ GraderReport,
14
+ PatchOption,
15
+ TaskCatalog,
16
+ TaskDefinition,
17
+ )
18
+ except ImportError:
19
+ from models import ( # type: ignore
20
+ BaselineActionStep,
21
+ BaselineCatalog,
22
+ BaselineDefinition,
23
+ GraderCheck,
24
+ GraderReport,
25
+ PatchOption,
26
+ TaskCatalog,
27
+ TaskDefinition,
28
+ )
29
+
30
+ ENVIRONMENT_NAME = "websec_repair_env"
31
+ DEFAULT_TASK_ID = "sqli_login"
32
+
33
+ _PATCH_WEIGHTS = {
34
+ "classification": 0.25,
35
+ "patch": 0.35,
36
+ "exploit": 0.20,
37
+ "functionality": 0.15,
38
+ "submit": 0.05,
39
+ }
40
+
41
+ TASKS = {
42
+ "sqli_login": {
43
+ "definition": TaskDefinition(
44
+ id="sqli_login",
45
+ difficulty="easy",
46
+ title="Patch the vulnerable login query",
47
+ instruction=(
48
+ "Inspect the login handler, classify the vulnerability, choose one patch, "
49
+ "verify exploit blocking and functionality preservation, then submit."
50
+ ),
51
+ code_snippet=(
52
+ "def login(username, password):\n"
53
+ " query = (\n"
54
+ " \"SELECT id FROM users WHERE username = '\" + username + \"' \"\n"
55
+ " \"AND password = '\" + password + \"'\"\n"
56
+ " )\n"
57
+ " return db.execute(query).fetchone()\n"
58
+ ),
59
+ scanner_hint=(
60
+ "Scanner warning: SQL injection risk detected in login query due to "
61
+ "string concatenation with user-controlled input."
62
+ ),
63
+ patch_options=[
64
+ PatchOption(id="parameterized_query", label="Use parameterized query"),
65
+ PatchOption(id="strip_quotes", label="Strip quotes from input"),
66
+ PatchOption(id="disable_login", label="Disable login route"),
67
+ ],
68
+ ),
69
+ "correct_vulnerability": "sql_injection",
70
+ "correct_patch": "parameterized_query",
71
+ "verify": {
72
+ "parameterized_query": (
73
+ True,
74
+ True,
75
+ "Exploit is blocked and valid logins still work with bound parameters.",
76
+ ),
77
+ "strip_quotes": (
78
+ False,
79
+ True,
80
+ "Input filtering is incomplete; the injection path is still reachable.",
81
+ ),
82
+ "disable_login": (
83
+ True,
84
+ False,
85
+ "Attack path is gone because login is disabled, but legitimate users are locked out.",
86
+ ),
87
+ },
88
+ },
89
+ "xss_comments": {
90
+ "definition": TaskDefinition(
91
+ id="xss_comments",
92
+ difficulty="medium",
93
+ title="Stop script execution in comment rendering",
94
+ instruction=(
95
+ "Inspect the comment renderer, classify the vulnerability, choose one patch, "
96
+ "verify exploit blocking and functionality preservation, then submit."
97
+ ),
98
+ code_snippet=(
99
+ "def render_comment(comment):\n"
100
+ ' return f"<div class=\\"comment\\">{comment}</div>"\n'
101
+ ),
102
+ scanner_hint=(
103
+ "Scanner warning: Cross-site scripting risk detected because comment text is "
104
+ "rendered into HTML without escaping."
105
+ ),
106
+ patch_options=[
107
+ PatchOption(id="html_escape", label="Escape comment before rendering"),
108
+ PatchOption(id="remove_script_substring", label="Remove 'script' substring"),
109
+ PatchOption(id="disable_comment_rendering", label="Hide comments completely"),
110
+ ],
111
+ ),
112
+ "correct_vulnerability": "xss",
113
+ "correct_patch": "html_escape",
114
+ "verify": {
115
+ "html_escape": (
116
+ True,
117
+ True,
118
+ "Malicious payloads render as text and normal comments still display.",
119
+ ),
120
+ "remove_script_substring": (
121
+ False,
122
+ True,
123
+ "Substring filtering misses alternate payloads; comments still render.",
124
+ ),
125
+ "disable_comment_rendering": (
126
+ True,
127
+ False,
128
+ "No script runs because comments are removed, but legitimate content is gone too.",
129
+ ),
130
+ },
131
+ },
132
+ "broken_auth_admin": {
133
+ "definition": TaskDefinition(
134
+ id="broken_auth_admin",
135
+ difficulty="hard",
136
+ title="Restore admin-only access control",
137
+ instruction=(
138
+ "Inspect the admin route, classify the vulnerability, choose one patch, "
139
+ "verify exploit blocking and functionality preservation, then submit."
140
+ ),
141
+ code_snippet=(
142
+ "@app.get('/admin')\n"
143
+ "def admin_panel(current_user):\n"
144
+ " if not current_user:\n"
145
+ " raise UnauthorizedError()\n"
146
+ " return render_admin_dashboard()\n"
147
+ ),
148
+ scanner_hint=(
149
+ "Scanner warning: Broken access control. Authenticated non-admin users can reach "
150
+ "the admin route because role enforcement is missing."
151
+ ),
152
+ patch_options=[
153
+ PatchOption(id="require_admin_role", label="Enforce admin role server-side"),
154
+ PatchOption(id="hide_admin_link", label="Hide the admin link in UI"),
155
+ PatchOption(id="deny_all_admin_access", label="Block all admin route traffic"),
156
+ ],
157
+ ),
158
+ "correct_vulnerability": "broken_auth",
159
+ "correct_patch": "require_admin_role",
160
+ "verify": {
161
+ "require_admin_role": (
162
+ True,
163
+ True,
164
+ "Non-admin users are denied while legitimate admins still reach the dashboard.",
165
+ ),
166
+ "hide_admin_link": (
167
+ False,
168
+ True,
169
+ "The UI hides the link, but direct requests still bypass authorization.",
170
+ ),
171
+ "deny_all_admin_access": (
172
+ True,
173
+ False,
174
+ "The bypass is gone, but valid admins are blocked too.",
175
+ ),
176
+ },
177
+ },
178
+ }
179
+
180
+ BASELINES = {
181
+ task_id: BaselineDefinition(
182
+ task_id=task_id,
183
+ title=f"{task['definition'].title} baseline",
184
+ description="Inspect, classify correctly, apply the safe patch, verify, then submit.",
185
+ actions=[
186
+ BaselineActionStep(action_type="inspect"),
187
+ BaselineActionStep(
188
+ action_type="classify",
189
+ vulnerability_type=task["correct_vulnerability"],
190
+ ),
191
+ BaselineActionStep(
192
+ action_type="apply_patch",
193
+ patch_id=task["correct_patch"],
194
+ ),
195
+ BaselineActionStep(action_type="verify"),
196
+ BaselineActionStep(action_type="submit"),
197
+ ],
198
+ )
199
+ for task_id, task in TASKS.items()
200
+ }
201
+
202
+
203
+ def _blank_progress(task_id: str) -> dict[str, object]:
204
+ try:
205
+ task = TASKS[task_id]["definition"]
206
+ except KeyError as exc:
207
+ valid = ", ".join(sorted(TASKS))
208
+ raise ValueError(f"Unknown task_id {task_id!r}. Expected one of: {valid}") from exc
209
+ return {
210
+ "task_id": task.id,
211
+ "difficulty": task.difficulty,
212
+ "inspected": False,
213
+ "selected_vulnerability": "",
214
+ "applied_patch_id": "",
215
+ "exploit_test_passed": False,
216
+ "functionality_test_passed": False,
217
+ "submitted": False,
218
+ "score": 0.0,
219
+ }
220
+
221
+
222
+ _CURRENT_PROGRESS = _blank_progress(DEFAULT_TASK_ID)
223
+
224
+
+ def get_task(task_id: str) -> TaskDefinition:
+     """Resolve one public task definition."""
+     try:
+         return TASKS[task_id]["definition"]
+     except KeyError as exc:
+         valid = ", ".join(sorted(TASKS))
+         raise ValueError(f"Unknown task_id {task_id!r}. Expected one of: {valid}") from exc
+
+
+ def list_tasks() -> TaskCatalog:
+     """Return the task catalog."""
+     return TaskCatalog(
+         environment=ENVIRONMENT_NAME,
+         default_task_id=DEFAULT_TASK_ID,
+         tasks=[TASKS["sqli_login"]["definition"], TASKS["xss_comments"]["definition"], TASKS["broken_auth_admin"]["definition"]],
+     )
+
+
+ def list_baselines(task_id: str | None = None) -> BaselineCatalog:
+     """Return all baseline trajectories or one filtered trajectory."""
+     if task_id is None:
+         baselines = [BASELINES["sqli_login"], BASELINES["xss_comments"], BASELINES["broken_auth_admin"]]
+     else:
+         get_task(task_id)
+         baselines = [BASELINES[task_id]]
+     return BaselineCatalog(environment=ENVIRONMENT_NAME, baselines=baselines)
+
+
+ def get_baseline(task_id: str) -> BaselineDefinition:
+     """Return one baseline trajectory."""
+     get_task(task_id)
+     return BASELINES[task_id]
+
+
+ def reset_runtime_progress(task_id: str) -> dict[str, object]:
+     """Reset the shared runtime snapshot for the requested task."""
+     global _CURRENT_PROGRESS
+     _CURRENT_PROGRESS = _blank_progress(task_id)
+     return deepcopy(_CURRENT_PROGRESS)
+
+
+ def set_runtime_progress(progress: dict[str, object]) -> dict[str, object]:
+     """Replace the shared runtime snapshot with the provided progress."""
+     global _CURRENT_PROGRESS
+     task_id = str(progress.get("task_id", DEFAULT_TASK_ID))
+     get_task(task_id)
+     merged = _blank_progress(task_id)
+     merged.update(progress)
+     _CURRENT_PROGRESS = merged
+     return deepcopy(_CURRENT_PROGRESS)
+
+
+ def current_runtime_progress() -> dict[str, object]:
+     """Return the current shared runtime snapshot."""
+     return deepcopy(_CURRENT_PROGRESS)
+
+
+ def verification_outcome(task_id: str, patch_id: str | None) -> tuple[bool, bool, str]:
+     """Evaluate exploit and functionality checks for the selected patch."""
+     get_task(task_id)
+     if not patch_id:
+         return False, False, "No patch has been applied yet, so both checks fail."
+     try:
+         return TASKS[task_id]["verify"][patch_id]
+     except KeyError as exc:
+         valid = ", ".join(option.id for option in TASKS[task_id]["definition"].patch_options)
+         raise ValueError(f"Unknown patch_id {patch_id!r}. Expected one of: {valid}") from exc
+
+
+ def grade_task(
+     task_id: str | None = None,
+     progress: dict[str, object] | None = None,
+ ) -> GraderReport:
+     """Grade either the provided progress snapshot or the shared runtime state."""
+     if progress is None:
+         if task_id is None:
+             working = current_runtime_progress()
+         else:
+             get_task(task_id)
+             working = current_runtime_progress()
+             if working["task_id"] != task_id:
+                 working = _blank_progress(task_id)
+     else:
+         working = deepcopy(progress)
+         task_id = str(working.get("task_id", task_id or DEFAULT_TASK_ID))
+         get_task(task_id)
+
+     resolved_task_id = str(task_id or working["task_id"])
+     task = TASKS[resolved_task_id]
+
+     classification_correct = working.get("selected_vulnerability") == task["correct_vulnerability"]
+     patch_correct = working.get("applied_patch_id") == task["correct_patch"]
+     exploit_blocked = bool(working.get("exploit_test_passed", False))
+     functionality_preserved = bool(working.get("functionality_test_passed", False))
+     successful_submit = (
+         bool(working.get("submitted", False))
+         and classification_correct
+         and patch_correct
+         and exploit_blocked
+         and functionality_preserved
+     )
+
+     score = 0.0
+     if classification_correct:
+         score += _PATCH_WEIGHTS["classification"]
+     if patch_correct:
+         score += _PATCH_WEIGHTS["patch"]
+     if exploit_blocked:
+         score += _PATCH_WEIGHTS["exploit"]
+     if functionality_preserved:
+         score += _PATCH_WEIGHTS["functionality"]
+     if successful_submit:
+         score += _PATCH_WEIGHTS["submit"]
+
+     checks = [
+         GraderCheck(
+             name="classification_correct",
+             passed=classification_correct,
+             detail=(
+                 f"Expected {task['correct_vulnerability']!r}; "
+                 f"received {working.get('selected_vulnerability', '')!r}."
+             ),
+         ),
+         GraderCheck(
+             name="patch_correct",
+             passed=patch_correct,
+             detail=(
+                 f"Expected {task['correct_patch']!r}; "
+                 f"received {working.get('applied_patch_id', '')!r}."
+             ),
+         ),
+         GraderCheck(
+             name="exploit_blocked",
+             passed=exploit_blocked,
+             detail=(
+                 "Verify step confirms the exploit is blocked."
+                 if exploit_blocked
+                 else "Exploit is still possible or verify has not been run."
+             ),
+         ),
+         GraderCheck(
+             name="functionality_preserved",
+             passed=functionality_preserved,
+             detail=(
+                 "Verify step confirms legitimate behavior still works."
+                 if functionality_preserved
+                 else "Legitimate behavior is broken or verify has not been run."
+             ),
+         ),
+         GraderCheck(
+             name="submitted",
+             passed=successful_submit,
+             detail=(
+                 "Task was submitted after all required checks passed."
+                 if successful_submit
+                 else "Submit bonus only applies after the correct verified repair is submitted."
+             ),
+         ),
+     ]
+
+     passed = all(check.passed for check in checks)
+     message = "Task solved." if passed else "Task is not solved yet."
+     return GraderReport(
+         task_id=resolved_task_id,
+         passed=passed,
+         score=round(score, 2),
+         message=message,
+         checks=checks,
+     )
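The additive scoring in `grade_task` can be sketched in isolation. This is a minimal sketch under stated assumptions, not the repository's code: `ASSUMED_WEIGHTS` and `sketch_score` are hypothetical names, and the equal 0.2 weights are an assumption, because the real `_PATCH_WEIGHTS` values are defined outside this excerpt. What is taken from the diff is the structure: four independent checks plus a submit bonus that only counts when all four hold.

```python
from __future__ import annotations

# ASSUMPTION: equal weights summing to 1.0. The real _PATCH_WEIGHTS
# values are not visible in this part of the diff.
ASSUMED_WEIGHTS = {
    "classification": 0.2,
    "patch": 0.2,
    "exploit": 0.2,
    "functionality": 0.2,
    "submit": 0.2,
}


def sketch_score(
    progress: dict[str, object],
    correct_vulnerability: str,
    correct_patch: str,
) -> float:
    """Hypothetical stand-in for the scoring part of grade_task."""
    classification_ok = progress.get("selected_vulnerability") == correct_vulnerability
    patch_ok = progress.get("applied_patch_id") == correct_patch
    exploit_ok = bool(progress.get("exploit_test_passed", False))
    functionality_ok = bool(progress.get("functionality_test_passed", False))
    # The submit bonus is gated on every other check, mirroring successful_submit.
    submit_ok = (
        bool(progress.get("submitted", False))
        and classification_ok
        and patch_ok
        and exploit_ok
        and functionality_ok
    )
    flags = {
        "classification": classification_ok,
        "patch": patch_ok,
        "exploit": exploit_ok,
        "functionality": functionality_ok,
        "submit": submit_ok,
    }
    return round(sum(ASSUMED_WEIGHTS[name] for name, ok in flags.items() if ok), 2)
```

Under these assumed weights, a trajectory that classifies correctly but applies `hide_admin_link` earns credit only for the classification and the preserved functionality; submitting adds nothing because the bonus is gated.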
server/websec_repair_environment.py ADDED
@@ -0,0 +1,215 @@
+ """Environment implementation for deterministic web security repair tasks."""
+
+ from __future__ import annotations
+
+ from typing import Any
+ from uuid import uuid4
+
+ from openenv.core.env_server.interfaces import Environment
+ from openenv.core.env_server.types import EnvironmentMetadata
+
+ try:
+     from ..models import WebSecRepairAction, WebSecRepairObservation, WebSecRepairState
+     from .challenge import (
+         DEFAULT_TASK_ID,
+         TASKS,
+         get_task,
+         grade_task,
+         reset_runtime_progress,
+         set_runtime_progress,
+         verification_outcome,
+     )
+ except ImportError:
+     from models import WebSecRepairAction, WebSecRepairObservation, WebSecRepairState  # type: ignore
+     from server.challenge import (  # type: ignore
+         DEFAULT_TASK_ID,
+         TASKS,
+         get_task,
+         grade_task,
+         reset_runtime_progress,
+         set_runtime_progress,
+         verification_outcome,
+     )
+
+ MAX_STEPS = 6
+
+
+ class WebSecRepairEnvironment(
+     Environment[WebSecRepairAction, WebSecRepairObservation, WebSecRepairState]
+ ):
+     """Lean deterministic environment for vulnerability classification and repair."""
+
+     SUPPORTS_CONCURRENT_SESSIONS: bool = False
+
+     def __init__(self) -> None:
+         super().__init__()
+         default_task = get_task(DEFAULT_TASK_ID)
+         self._state = WebSecRepairState(
+             episode_id=str(uuid4()),
+             step_count=0,
+             task_id=default_task.id,
+             difficulty=default_task.difficulty,
+             inspected=False,
+             selected_vulnerability="",
+             applied_patch_id="",
+             exploit_test_passed=False,
+             functionality_test_passed=False,
+             submitted=False,
+             score=0.0,
+         )
+         set_runtime_progress(self._state.model_dump())
+
+     def reset(
+         self,
+         seed: int | None = None,
+         episode_id: str | None = None,
+         **kwargs: Any,
+     ) -> WebSecRepairObservation:
+         del seed
+
+         task_id = kwargs.get("task_id", DEFAULT_TASK_ID)
+         task = get_task(task_id)
+         reset_runtime_progress(task_id)
+
+         self._state = WebSecRepairState(
+             episode_id=episode_id or str(uuid4()),
+             step_count=0,
+             task_id=task.id,
+             difficulty=task.difficulty,
+             inspected=False,
+             selected_vulnerability="",
+             applied_patch_id="",
+             exploit_test_passed=False,
+             functionality_test_passed=False,
+             submitted=False,
+             score=0.0,
+         )
+         set_runtime_progress(self._state.model_dump())
+
+         return self._build_observation(
+             status_message="Task loaded. Use inspect to reveal the snippet, hint, and patch options.",
+             reward=0.0,
+             done=False,
+         )
+
+     def step(
+         self,
+         action: WebSecRepairAction,
+         timeout_s: float | None = None,
+         **kwargs: Any,
+     ) -> WebSecRepairObservation:
+         del timeout_s, kwargs
+
+         previous_score = self._state.score
+         # Remember whether the episode was already closed before this action.
+         already_submitted = self._state.submitted
+         self._state.step_count += 1
+         status_message = ""
+
+         if already_submitted:
+             status_message = "Episode already submitted. Reset before taking more actions."
+         elif action.action_type == "inspect":
+             self._state.inspected = True
+             status_message = "Inspection complete. The vulnerable snippet, scanner hint, and patch options are now visible."
+         elif action.action_type == "classify":
+             if not action.vulnerability_type:
+                 status_message = "Classification failed: vulnerability_type is required."
+             else:
+                 self._state.selected_vulnerability = action.vulnerability_type
+                 status_message = f"Stored vulnerability classification {action.vulnerability_type!r}."
+         elif action.action_type == "apply_patch":
+             if not action.patch_id:
+                 status_message = "Patch application failed: patch_id is required."
+             else:
+                 valid_patch_ids = {option.id for option in TASKS[self._state.task_id]["definition"].patch_options}
+                 if action.patch_id not in valid_patch_ids:
+                     status_message = f"Patch application failed: unknown patch_id {action.patch_id!r}."
+                 else:
+                     # A new patch invalidates any earlier verification results.
+                     self._state.applied_patch_id = action.patch_id
+                     self._state.exploit_test_passed = False
+                     self._state.functionality_test_passed = False
+                     status_message = f"Applied patch template {action.patch_id!r}."
+         elif action.action_type == "verify":
+             exploit, functionality, verify_message = verification_outcome(
+                 self._state.task_id,
+                 self._state.applied_patch_id or None,
+             )
+             self._state.exploit_test_passed = exploit
+             self._state.functionality_test_passed = functionality
+             status_message = verify_message
+         elif action.action_type == "submit":
+             self._state.submitted = True
+             status_message = "Submission recorded."
+         else:
+             status_message = f"Unsupported action_type {action.action_type!r}."
+
+         report = grade_task(progress=self._state.model_dump())
+         self._state.score = report.score
+         set_runtime_progress(self._state.model_dump())
+
+         done = self._state.submitted or self._state.step_count >= MAX_STEPS
+         # Only a fresh submission gets the grader summary; actions taken after
+         # submitting keep the "already submitted" notice instead of losing it.
+         if self._state.submitted and not already_submitted:
+             if report.passed:
+                 status_message = "Submission recorded. Task solved."
+             else:
+                 status_message = "Submission recorded, but the grader still reports an incomplete repair."
+         elif self._state.step_count >= MAX_STEPS:
+             status_message = f"{status_message} Max steps reached."
+
+         reward = round(self._state.score - previous_score, 2)
+         return self._build_observation(
+             status_message=status_message.strip(),
+             reward=reward,
+             done=done,
+             grader_report=report,
+         )
+
+     @property
+     def state(self) -> WebSecRepairState:
+         """Return the current environment state."""
+         return self._state
+
+     def get_metadata(self) -> EnvironmentMetadata:
+         """Return environment metadata for the OpenEnv UI."""
+         return EnvironmentMetadata(
+             name="WebSecRepairEnvironment",
+             description=(
+                 "Deterministic OpenEnv environment with three web vulnerability repair tasks "
+                 "covering SQL injection, XSS, and broken access control."
+             ),
+             version="0.1.0",
+             author="Codex",
+         )
+
+     def _build_observation(
+         self,
+         status_message: str,
+         reward: float,
+         done: bool,
+         grader_report: Any | None = None,
+     ) -> WebSecRepairObservation:
+         task = get_task(self._state.task_id)
+         if grader_report is None:
+             grader_report = grade_task(progress=self._state.model_dump())
+         visible_patch_options = task.patch_options if self._state.inspected else []
+         visible_code = task.code_snippet if self._state.inspected else ""
+         visible_hint = task.scanner_hint if self._state.inspected else ""
+         return WebSecRepairObservation(
+             task_id=task.id,
+             instruction=task.instruction,
+             code_snippet=visible_code,
+             scanner_hint=visible_hint,
+             status_message=status_message,
+             selected_vulnerability=self._state.selected_vulnerability,
+             applied_patch_id=self._state.applied_patch_id,
+             patch_options=visible_patch_options,
+             exploit_test_passed=self._state.exploit_test_passed,
+             functionality_test_passed=self._state.functionality_test_passed,
+             grader_passed=grader_report.passed,
+             reward=reward,
+             done=done,
+             metadata={
+                 "step_count": self._state.step_count,
+                 "score": self._state.score,
+                 "max_steps": MAX_STEPS,
+                 "grader": grader_report.model_dump(),
+             },
+         )
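The reward returned by `step` is the change in graded score since the previous step, so per-step rewards telescope to the final score. A minimal sketch of that bookkeeping follows; `step_rewards` is a hypothetical helper, and the score sequence in the usage note assumes equal 0.2 weights per check, which this diff does not confirm.

```python
from __future__ import annotations


def step_rewards(scores: list[float], initial_score: float = 0.0) -> list[float]:
    """Turn a sequence of graded scores into per-step rewards (score deltas)."""
    rewards = []
    previous = initial_score
    for score in scores:
        # Same rounding as the environment: reward = round(score - previous, 2).
        rewards.append(round(score - previous, 2))
        previous = score
    return rewards


# ASSUMED scores after inspect, classify, apply_patch, verify, submit,
# using hypothetical equal weights of 0.2 per check.
baseline_scores = [0.0, 0.2, 0.4, 0.8, 1.0]
baseline_rewards = step_rewards(baseline_scores)
```

One consequence of this design is that rewards can be negative: applying a new patch clears both verification flags, so the score (and therefore the delta) drops until `verify` is run again.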
tests/test_websec_repair_env.py ADDED
@@ -0,0 +1,100 @@
+ """Tests for the deterministic WebSec Repair environment."""
+
+ from __future__ import annotations
+
+ from fastapi.testclient import TestClient
+
+ from websec_repair_env.models import WebSecRepairAction
+ from websec_repair_env.server.app import app
+ from websec_repair_env.server.websec_repair_environment import WebSecRepairEnvironment
+
+
+ TASK_CASES = [
+     ("sqli_login", "sql_injection", "parameterized_query", "strip_quotes"),
+     ("xss_comments", "xss", "html_escape", "remove_script_substring"),
+     ("broken_auth_admin", "broken_auth", "require_admin_role", "hide_admin_link"),
+ ]
+
+
+ def _solve_task(env: WebSecRepairEnvironment, task_id: str, vulnerability: str, patch_id: str):
+     env.reset(task_id=task_id)
+     env.step(WebSecRepairAction(action_type="inspect"))
+     env.step(
+         WebSecRepairAction(
+             action_type="classify",
+             vulnerability_type=vulnerability,
+         )
+     )
+     env.step(
+         WebSecRepairAction(
+             action_type="apply_patch",
+             patch_id=patch_id,
+         )
+     )
+     env.step(WebSecRepairAction(action_type="verify"))
+     return env.step(WebSecRepairAction(action_type="submit"))
+
+
+ def test_reset_and_state_smoke() -> None:
+     env = WebSecRepairEnvironment()
+     obs = env.reset(task_id="xss_comments")
+     assert obs.task_id == "xss_comments"
+     assert obs.code_snippet == ""
+     assert env.state.inspected is False
+
+     env.step(WebSecRepairAction(action_type="inspect"))
+     assert env.state.inspected is True
+     assert env.state.task_id == "xss_comments"
+
+     obs = env.reset(task_id="broken_auth_admin")
+     assert obs.task_id == "broken_auth_admin"
+     assert env.state.inspected is False
+     assert env.state.selected_vulnerability == ""
+     assert env.state.applied_patch_id == ""
+     assert env.state.exploit_test_passed is False
+
+
+ def test_happy_path_for_each_task() -> None:
+     for task_id, vulnerability, correct_patch, _ in TASK_CASES:
+         env = WebSecRepairEnvironment()
+         result = _solve_task(env, task_id, vulnerability, correct_patch)
+         assert result.done is True
+         assert result.grader_passed is True
+         assert result.exploit_test_passed is True
+         assert result.functionality_test_passed is True
+         assert env.state.score == 1.0
+
+
+ def test_wrong_patch_failure_for_each_task() -> None:
+     for task_id, vulnerability, _, wrong_patch in TASK_CASES:
+         env = WebSecRepairEnvironment()
+         result = _solve_task(env, task_id, vulnerability, wrong_patch)
+         assert result.done is True
+         assert result.grader_passed is False
+         assert env.state.score < 1.0
+         assert (
+             result.exploit_test_passed is False
+             or result.functionality_test_passed is False
+         )
+
+
+ def test_http_routes_return_expected_shapes() -> None:
+     client = TestClient(app)
+
+     tasks_response = client.get("/tasks")
+     assert tasks_response.status_code == 200
+     tasks_payload = tasks_response.json()
+     assert tasks_payload["environment"] == "websec_repair_env"
+     assert len(tasks_payload["tasks"]) == 3
+
+     baseline_response = client.get("/baseline", params={"task_id": "sqli_login"})
+     assert baseline_response.status_code == 200
+     baseline_payload = baseline_response.json()
+     assert len(baseline_payload["baselines"]) == 1
+     assert baseline_payload["baselines"][0]["task_id"] == "sqli_login"
+
+     grader_response = client.get("/grader", params={"task_id": "sqli_login"})
+     assert grader_response.status_code == 200
+     grader_payload = grader_response.json()
+     assert grader_payload["task_id"] == "sqli_login"
+     assert "checks" in grader_payload
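The tests above try one wrong patch per task. The same property can be checked for every patch option by walking a task's verification table directly. The sketch below copies the `broken_auth_admin` table from the diff into a standalone module; `VERIFY`, `failing_checks`, and `patches_passing_both` are hypothetical names introduced for illustration, not part of the package.

```python
from __future__ import annotations

# Verification table copied from the broken_auth_admin task in this diff:
# patch_id -> (exploit_blocked, functionality_preserved, message)
VERIFY = {
    "require_admin_role": (
        True,
        True,
        "Non-admin users are denied while legitimate admins still reach the dashboard.",
    ),
    "hide_admin_link": (
        False,
        True,
        "The UI hides the link, but direct requests still bypass authorization.",
    ),
    "deny_all_admin_access": (
        True,
        False,
        "The bypass is gone, but valid admins are blocked too.",
    ),
}


def failing_checks(patch_id: str) -> list[str]:
    """Name the verification checks that a given patch does not satisfy."""
    exploit_blocked, functionality_preserved, _message = VERIFY[patch_id]
    failures = []
    if not exploit_blocked:
        failures.append("exploit_blocked")
    if not functionality_preserved:
        failures.append("functionality_preserved")
    return failures


def patches_passing_both() -> list[str]:
    """Return every patch that blocks the exploit and keeps functionality."""
    return [patch_id for patch_id in VERIFY if not failing_checks(patch_id)]
```

For this task only `require_admin_role` passes both checks, and each distractor fails exactly one: `hide_admin_link` leaves the exploit open, while `deny_all_admin_access` breaks legitimate access.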
uv.lock ADDED
The diff for this file is too large to render. See raw diff