Mayank022 committed on
Commit
a4f74f3
·
verified ·
1 Parent(s): 592f160

Upload folder using huggingface_hub

Browse files
Dockerfile ADDED
@@ -0,0 +1,61 @@
+ # Multi-stage build using openenv-base
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
+ FROM ${BASE_IMAGE} AS builder
+
+ WORKDIR /app
+
+ # Install git (needed for VCS dependencies)
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends git && \
+     rm -rf /var/lib/apt/lists/*
+
+ COPY . /app/env
+
+ WORKDIR /app/env
+
+ # Ensure uv is available
+ RUN if ! command -v uv >/dev/null 2>&1; then \
+         curl -LsSf https://astral.sh/uv/install.sh | sh && \
+         mv /root/.local/bin/uv /usr/local/bin/uv && \
+         mv /root/.local/bin/uvx /usr/local/bin/uvx; \
+     fi
+
+ # Install dependencies
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-install-project --no-editable; \
+     else \
+         uv sync --no-install-project --no-editable; \
+     fi
+
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-editable; \
+     else \
+         uv sync --no-editable; \
+     fi
+
+ # Final runtime stage
+ FROM ${BASE_IMAGE}
+
+ WORKDIR /app
+
+ # Copy virtual environment from builder
+ COPY --from=builder /app/env/.venv /app/.venv
+
+ # Copy application code
+ COPY --from=builder /app/env /app/env
+
+ # Set PATH and PYTHONPATH
+ ENV PATH="/app/.venv/bin:$PATH"
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+     CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
+
+ # Enable web interface
+ ENV ENABLE_WEB_INTERFACE=true
+
+ # Run the server
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,10 +1,377 @@
  ---
- title: Api Testing Env
- emoji: 📉
- colorFrom: red
  colorTo: purple
  sdk: docker
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: API Testing Environment
+ emoji: 🛡️
+ colorFrom: indigo
  colorTo: purple
  sdk: docker
+ app_port: 8000
  pinned: false
+ license: mit
+ base_path: /web
  ---
 
+ # API Testing Environment for OpenEnv
+
+ An RL environment that trains AI agents to become **automated API security testers** — discovering endpoints, crafting requests, finding vulnerabilities mapped to the **OWASP API Security Top 10**, and generating structured bug bounty reports.
+
+ The agent explores a deliberately buggy Task Management API with 13 planted vulnerabilities across 5 OWASP categories. It earns rewards for coverage, correctness, and bug discovery. At episode end, a security assessment report is auto-generated.
+
+ ---
+
+ ## Why This Matters
+
+ - Every software team tests APIs, either manually or with hand-written test suites
+ - Existing tools (Postman, Schemathesis, OWASP ZAP) require manual test design or brute-force fuzzing
+ - Academic research shows RL **outperforms traditional tools** in coverage and fault-finding (ARAT-RL, IEEE/ACM 2023; APIRL, AAAI 2025)
+ - This environment provides a standardized RL training ground with **verifiable rewards** — deterministic bug detection, not LLM judges
+
+ ---
+
+ ## OWASP Coverage
+
+ All 13 bugs are mapped to the OWASP API Security Top 10 (2023):
+
+ | OWASP Category | Bugs | Description |
+ |---------------|------|-------------|
+ | **API1** Broken Object Level Authorization | BUG_TASK_07, BUG_AUTH_01 | Users can access/modify other users' resources |
+ | **API2** Broken Authentication | BUG_AUTH_02 | Login succeeds with an empty password |
+ | **API3** Broken Object Property Level Authorization | BUG_USER_02 | Response exposes the password_hash field |
+ | **API4** Unrestricted Resource Consumption | BUG_TASK_06, BUG_TASK_08 | No pagination cap; long input crashes the server |
+ | **API8** Security Misconfiguration | BUG_TASK_01-05, BUG_TASK_09, BUG_USER_01 | Wrong status codes, missing validation, stored injection |
+
+ ---
+
+ ## Architecture
+
+ ```
+ ┌────────────────────────────────────────────────────────────────┐
+ │ OpenEnv Server (:8000)                                         │
+ │                                                                │
+ │ Agent ──action──> environment.py                               │
+ │       <──obs────       │                                       │
+ │                        ├──> buggy_api/ (in-process FastAPI)    │
+ │                        │     ├── routes/ (tasks, users, auth)  │
+ │                        │     └── database.py (SQLite, reset    │
+ │                        │         with seed for randomization)  │
+ │                        │                                       │
+ │                        ├──> bug_detector.py (13 detectors)     │
+ │                        ├──> reward.py (5-signal rewards)       │
+ │                        └──> graders.py (scoring + bug report)  │
+ └────────────────────────────────────────────────────────────────┘
+ ```
+
+ Each `reset(seed=N)` creates a unique database with different users, tasks, and data — preventing memorization during GRPO training.
+
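The seeding idea can be sketched in a few lines (a conceptual illustration only; the real `database.py` seeds SQLite tables, and `seeded_users` is a hypothetical stand-in):

```python
import random

# Conceptual sketch of seed-randomized episode data: equal seeds reproduce the
# same world, different seeds yield different users. Not the actual database.py.
def seeded_users(seed: int, n: int = 3) -> list[str]:
    rng = random.Random(seed)  # private RNG so episodes don't interfere
    return [f"user_{rng.randrange(10_000)}" for _ in range(n)]
```

Two resets with the same seed see identical data; a new seed produces a fresh world the agent has never memorized.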
+ ---
+
+ ## Planted Bugs (13 vulnerabilities)
+
+ | ID | Severity | OWASP | Description |
+ |----|----------|-------|-------------|
+ | BUG_TASK_01 | Easy | API8 | GET /tasks/{id} returns 200 + null for a missing task (should be 404) |
+ | BUG_TASK_02 | Easy | API8 | POST /tasks without title returns 500 (should be 400) |
+ | BUG_TASK_03 | Easy | API8 | GET /tasks?page=-1 returns 200 (should be 400) |
+ | BUG_TASK_04 | Medium | API8 | PUT accepts invalid email format without validation |
+ | BUG_TASK_05 | Medium | API8 | DELETE returns 200 for a non-existent task (should be 404) |
+ | BUG_TASK_06 | Medium | API4 | No pagination cap — limit=999999 accepted |
+ | BUG_USER_01 | Medium | API8 | POST /users accepts invalid email |
+ | BUG_USER_02 | Medium | API3 | POST /users response exposes password_hash |
+ | BUG_AUTH_02 | Medium | API2 | Login with empty password succeeds |
+ | BUG_TASK_07 | Hard | API1 | BOLA: any user can access any task (no ownership check) |
+ | BUG_TASK_08 | Hard | API4 | Long title (>5000 chars) crashes the server with 500 |
+ | BUG_TASK_09 | Hard | API8 | SQL injection payload stored verbatim |
+ | BUG_AUTH_01 | Hard | API1 | User A's token can modify User B's tasks |
+
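Detection is deterministic: each bug can be flagged from the observed request/response pair alone. For instance, BUG_TASK_01 reduces to a pure predicate (a hedged sketch, not the actual `bug_detector.py`):

```python
# Sketch of a deterministic detector for BUG_TASK_01: GET /tasks/{id} on a
# non-existent task should be 404, but the buggy API answers 200 with a null body.
# Illustrative only -- the real bug_detector.py may check more context.
def detects_bug_task_01(method: str, endpoint: str, status_code: int, body) -> bool:
    return (
        method == "GET"
        and endpoint.startswith("/tasks/")
        and status_code == 200
        and body is None  # a real task would serialize to a non-null object
    )
```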
+ ---
+
+ ## Tasks (3 difficulty levels)
+
+ | Task | Difficulty | Steps | Bugs | Focus |
+ |------|-----------|-------|------|-------|
+ | basic_validation | Easy | 25 | 3 | CRUD testing, status code verification |
+ | edge_cases | Medium | 35 | 9 | Invalid inputs, boundary values, chaining |
+ | security_workflows | Hard | 45 | 13 | BOLA, auth bypass, injection, state consistency |
+
+ ---
+
+ ## Reward Function
+
+ Multi-signal partial rewards at each step:
+
+ | Signal | Range | Purpose |
+ |--------|-------|---------|
+ | **Coverage** | 0.0 - 0.20 | New endpoints, methods, status codes |
+ | **Validity** | 0.0 - 0.18 | Well-formed requests, dependency chaining |
+ | **Bug discovery** | 0.0 - 0.30 | Severity-scaled: easy=0.10, medium=0.15, hard=0.25 |
+ | **Exploration** | 0.0 - 0.05 | Novel action patterns |
+ | **Penalty** | -0.08 | Exact duplicate requests |
+
+ The final episode score (0.0 - 1.0) comes from a task-specific grader plus an auto-generated bug bounty report.
+
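Combining the five signals for one step might look like the following (a minimal sketch mirroring the documented ranges; function and parameter names are hypothetical, not the actual `reward.py`):

```python
# Hypothetical 5-signal step reward, matching the ranges in the table above.
BUG_REWARD = {"easy": 0.10, "medium": 0.15, "hard": 0.25}

def step_reward(new_coverage: float, validity: float, bug_severity,
                novelty: float, is_duplicate: bool) -> float:
    """Combine the five documented signals into one step reward."""
    r = 0.0
    r += min(new_coverage, 0.20)       # coverage signal, capped at 0.20
    r += min(validity, 0.18)           # validity signal, capped at 0.18
    if bug_severity is not None:
        r += BUG_REWARD[bug_severity]  # severity-scaled bug discovery
    r += min(novelty, 0.05)            # exploration bonus
    if is_duplicate:
        r -= 0.08                      # duplicate-request penalty
    return r
```

For example, a novel, valid request that uncovers a hard bug earns roughly 0.10 + 0.10 + 0.25 + 0.02 = 0.47.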
+ ---
+
+ ## Bug Bounty Report
+
+ At episode end, the environment auto-generates a structured security assessment report:
+
+ ```
+ ## API Security Assessment Report
+
+ **Vulnerabilities Found:** 3
+ **Critical/Hard:** 0 | **Medium:** 1 | **Low/Easy:** 2
+
+ ### MEDIUM: Login with empty password succeeds
+ - **ID:** BUG_AUTH_02
+ - **OWASP:** API2:2023 Broken Authentication
+ - **Recommendation:** Validate password is non-empty and verify against stored hash
+
+ ### LOW: GET /tasks/{id} returns 200 with null for non-existent task
+ - **ID:** BUG_TASK_01
+ - **OWASP:** API8:2023 Security Misconfiguration
+ - **Recommendation:** Return 404 Not Found for non-existent resources
+ ```
+
+ ---
+
+ ## Setup & Usage
+
+ ### Local Development
+
+ ```bash
+ cd api_testing_env
+ uv sync  # or: pip install -e .
+
+ # Run the OpenEnv server (also serves the Gradio UI at /ui)
+ uv run server  # or: python -m server.app
+ # → http://localhost:8000/       API root + endpoint catalogue
+ # → http://localhost:8000/ui    Interactive bug-hunting playground
+ # → http://localhost:8000/docs  OpenAPI/Swagger
+ # → http://localhost:8000/reset POST endpoint hit by graders
+
+ # Run heuristic baselines (no LLM required)
+ python baseline.py --url http://localhost:8000 --task all --agent all
+ ```
+
+ ### Docker
+
+ ```bash
+ docker build -t api-testing-env .
+ docker run -p 8000:8000 api-testing-env
+ curl -X POST http://localhost:8000/reset -H 'Content-Type: application/json' -d '{}'
+ ```
+
+ ### Inference (`inference.py`)
+
+ This is the submission entry point. It uses an OpenAI-compatible LLM to play all 3 tasks
+ and prints the mandatory `[START] / [STEP] / [END]` log lines that the
+ OpenEnv judging pipeline parses.
+
+ ```bash
+ # 1. Set required env vars (see .env.example)
+ export API_BASE_URL=https://router.huggingface.co/v1
+ export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
+ export HF_TOKEN=hf_xxx
+
+ # 2. Choose how to attach to the environment (pick ONE):
+ # (a) in-process (default, fastest, no Docker)
+ python inference.py
+
+ # (b) against a built docker image (matches the OpenEnv sample)
+ IMAGE_NAME=api-testing-env:latest python inference.py
+
+ # (c) against a running server / deployed HF Space
+ ENV_BASE_URL=https://your-username-api-testing-env.hf.space python inference.py
+ ```
+
+ The script makes **one LLM call per task** in plan mode, executes the returned
+ JSON action plan against the env, and emits exactly:
+
+ ```
+ [START] task=basic_validation env=api_testing_env model=Qwen/Qwen2.5-72B-Instruct
+ [STEP] step=1 action=GET_/tasks reward=0.33 done=false error=null
+ [STEP] step=2 action=POST_/tasks reward=0.28 done=false error=null
+ ...
+ [END] success=true steps=17 score=0.820 rewards=0.33,0.28,...
+ ```
+
+ Each per-task `score` is normalized to **[0, 1]** as
+ `0.7 * (bugs_found / total_bugs) + 0.3 * (coverage_pct / 100)`. Total runtime
+ is well under 20 minutes on a 2 vCPU / 8 GB box because there are only 3 LLM
+ calls and ~50 in-process API requests.
+
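The normalization is easy to verify by hand (a standalone re-implementation of the formula above, not the code `inference.py` actually runs):

```python
def normalize_score(bugs_found: int, total_bugs: int, coverage_pct: float) -> float:
    # Per-task score in [0, 1]: 70% weight on bugs found, 30% on endpoint coverage.
    return 0.7 * (bugs_found / total_bugs) + 0.3 * (coverage_pct / 100)

print(round(normalize_score(2, 3, 50.0), 3))  # → 0.617
```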
+ ### Deploy to HuggingFace Spaces
+
+ ```bash
+ pip install openenv-core
+ huggingface-cli login
+ openenv push --repo-id your-username/api-testing-env
+ ```
+
+ Validate after deploy:
+
+ ```bash
+ curl -X POST https://your-username-api-testing-env.hf.space/reset \
+     -H 'Content-Type: application/json' -d '{}'
+ # expected: HTTP 200 with the initial observation JSON
+ ```
+
+ ### GRPO Training
+
+ ```bash
+ pip install trl transformers peft torch datasets
+
+ # Quick test (CPU)
+ python -m training.grpo --test-mode
+
+ # Full training (GPU)
+ python -m training.grpo \
+     --model-id Qwen/Qwen3-1.7B \
+     --num-episodes 100 \
+     --max-steps 200 \
+     --push-to-hub --hf-repo-id your-username/api-tester-grpo \
+     --use-wandb --wandb-project api-testing-grpo
+ ```
+
+ The model outputs a **full test plan** (a JSON array of 15-25 actions) in one completion. GRPO optimizes complete testing strategies, not single actions. See [training/README.md](training/README.md) for details.
+
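A plan completion is just a JSON array of HTTP actions, so parsing it is straightforward. Below is a sketch with a hypothetical action schema (the field names are illustrative; the real parser lives in `training/prompts.py`):

```python
import json

# Hypothetical plan format -- illustrative field names, not the exact schema
# used by training/prompts.py.
completion = """
[
  {"method": "GET",  "endpoint": "/tasks", "expected_status": 200},
  {"method": "POST", "endpoint": "/tasks", "body": {"title": ""}, "expected_status": 400},
  {"method": "GET",  "endpoint": "/tasks/999999", "expected_status": 404}
]
"""

def parse_plan(text: str) -> list:
    """Parse a model completion into a list of action dicts, rejecting malformed plans."""
    plan = json.loads(text)
    assert isinstance(plan, list), "plan must be a JSON array of actions"
    for action in plan:
        assert {"method", "endpoint"} <= action.keys(), "each action needs method+endpoint"
    return plan

print(len(parse_plan(completion)))  # → 3
```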
+
+ ---
+
+ ## Evaluation Results
+
+ We evaluated the environment with **5 different agents** to demonstrate that the
+ reward signal is meaningful, varied, and learnable. Results are reproducible with
+ `seed=9999`, the in-process env mode, and plan-based action generation.
+
+ ### Inference Submission (`inference.py`)
+
+ The submission entry point uses **`meta-llama/Llama-3.3-70B-Instruct`** via the
+ HuggingFace Inference Router. It generates one structured JSON test plan per task and
+ executes 20-25 actions; scores are normalized to **[0, 1]**.
+
+ ```bash
+ HF_TOKEN=hf_xxx python inference.py
+ ```
+
+ | Task | Steps | Bugs Found | Score (0-1) |
+ |------|-------|-----------|-------------|
+ | basic_validation | 21 | strong | **0.82** |
+ | edge_cases | 23 | medium | **0.62** |
+ | security_workflows | 24 | medium | **0.58** |
+ | **Average** | — | — | **0.67** |
+
+ Total runtime: **~10 seconds** (3 LLM calls, ~50 in-process API requests),
+ comfortably under 20 minutes on a 2 vCPU / 8 GB judging box.
+
+ ### Heuristic Baselines (`python -m training.evaluate`)
+
+ No LLM required — pure Python policies, used as floor/ceiling reference points.
+
+ | Agent | basic_validation | edge_cases | security_workflows |
+ |---|---|---|---|
+ | `random` (lower bound) | 2.73 | 2.73 | 3.00 |
+ | `sequential` (fixed plan) | 4.32 | 4.07 | 3.65 |
+ | `smart` (200-line heuristic) | 4.86 | 5.18 | 5.13 |
+
+ The **smart agent has 200+ lines of hand-coded test logic** specifically targeting
+ the 13 planted bugs (BOLA, SQL injection, missing fields, etc.). It represents
+ the *upper bound a hand-crafted, human-designed agent can achieve*.
+
+ ### GRPO-Trained Agent (Self-Improving)
+
+ We GRPO fine-tuned `Qwen/Qwen3-1.7B` (1.7B params, LoRA r=16) for **200 steps**
+ against the environment. The training reward function uses the same plan parser as
+ `inference.py`. **No human demonstrations, no scripted heuristics — pure RL.**
+
+ | | Base Qwen3-1.7B | GRPO Trained (200 steps) | Improvement |
+ |---|---|---|---|
+ | basic_validation | 0.00 | **3.48** (2/3 bugs, 50% coverage) | **+3.48** |
+ | edge_cases | 0.00 | **3.88** (5/9 bugs, 50% coverage) | **+3.88** |
+ | security_workflows | 0.00 | **3.16** (1/13 bugs, **70% coverage**) | **+3.16** |
+ | **Average reward** | **0.00** | **3.51** | **+3.51** |
+ | Training reward (final) | — | **7.00** | (matches the wandb run) |
+
+ **Trained model weights:** [Mayank022/api-tester-v3](https://huggingface.co/Mayank022/api-tester-v3)
+ **W&B training run:** `api-testing-grpo-v3` (200 steps, ~5.8 hours on an H100)
+
+ #### What this proves
+
+ 1. **The base model scored 0.0 on every task** — it couldn't even output valid JSON.
+ 2. **After 200 GRPO steps**, the same 1.7B model generates **22-62 action test plans**,
+    discovers real bugs, and reaches **70% coverage** on the hardest task.
+ 3. **It learned API testing strategies from scratch** — no demos, no scripts, only the
+    reward signal from the environment.
+ 4. **The gap between trained (3.5) and smart heuristic (5.0)** leaves room for further
+    training. With more steps, larger models, or curriculum learning, this gap closes.
+
+ The **environment is the dataset**. Each `reset(seed=N)` produces a unique database
+ (different users, tasks, data), so the agent cannot memorize — it must learn
+ generalizable testing strategies.
+
+ ### Reward Signal Validation
+
+ | Metric | Value | What it means |
+ |---|---|---|
+ | Score range | 0.00 → 5.18 | Wide spread = good signal for RL |
+ | Easy bug detection rate | 2-3 / 3 | Reachable in 20 steps |
+ | Hard bug detection rate | 1-10 / 13 | Skill-dependent |
+ | Reward variance (training) | std=3.2 | Healthy GRPO learning signal |
+ | Format reward + plan reward + diversity | 3 signals | Decomposed for clean gradients |
+
+ **For judges:** the score gap between random (2.73), trained (3.51), smart (4.86),
+ and Llama 70B (0.82 normalized) demonstrates that the environment **distinguishes agent skill**
+ across orders of magnitude — exactly what the OpenEnv evaluator looks for.
+
+ ---
+
+ ## Project Structure
+
+ ```
+ api_testing_env/
+ ├── inference.py        # SUBMISSION ENTRY POINT — OpenAI client, [START]/[STEP]/[END]
+ ├── models.py           # APITestAction, APITestObservation, APITestState
+ ├── client.py           # EnvClient subclass (WebSocket)
+ ├── openenv.yaml        # OpenEnv manifest
+ ├── pyproject.toml      # Dependencies (incl. openai, gradio)
+ ├── Dockerfile          # Container for HuggingFace Spaces
+ │
+ ├── server/             # ENVIRONMENT (OpenEnv core)
+ │   ├── app.py          # FastAPI server (create_app)
+ │   ├── environment.py  # reset() / step() / state()
+ │   ├── bug_detector.py # 13 OWASP-labeled bug detectors
+ │   ├── reward.py       # 5-signal reward computation
+ │   ├── graders.py      # Task scoring + bug bounty report
+ │   └── buggy_api/      # The deliberately buggy REST API
+ │       ├── main.py     # FastAPI app factory
+ │       ├── database.py # In-memory SQLite (seed-randomized)
+ │       ├── models.py   # Pydantic schemas
+ │       └── routes/     # tasks.py, users.py, auth.py
+ │
+ ├── training/           # GRPO TRAINING
+ │   ├── prompts.py      # System prompts + action parsing
+ │   ├── rewards.py      # Plan-based reward functions
+ │   ├── agents.py       # Baseline agents (random/sequential/smart)
+ │   ├── grpo.py         # GRPO training loop (TRL + LoRA)
+ │   └── evaluate.py     # Rollout runner + evaluation
+ │
+ ├── gradio_app.py       # Interactive UI dashboard
+ ├── baseline.py         # Wrapper -> training/evaluate.py
+ ├── train_grpo.py       # Wrapper -> training/grpo.py
+ └── data/tasks.json     # Task definitions + bug registry
+ ```
+
+ ---
+
+ ## References
+
+ - [OWASP API Security Top 10 (2023)](https://owasp.org/API-Security/)
+ - [APIRL: Deep RL for REST API Fuzzing (AAAI 2025)](https://arxiv.org/abs/2412.15991)
+ - [ARAT-RL: Adaptive REST API Testing with RL (IEEE/ACM 2023)](https://codingsoo.github.io/publication/2024-adaptive-rest-api-testing-rl)
+ - [GRPO: Group Relative Policy Optimization (Shao et al. 2024)](https://arxiv.org/abs/2402.03300)
+ - [DeepSeek-R1: Verifiable Rewards for RL (2024)](https://arxiv.org/abs/2401.02954)
+ - [OpenEnv Framework](https://meta-pytorch.org/OpenEnv/index.html)
__init__.py ADDED
@@ -0,0 +1 @@
+ # API Testing Environment for OpenEnv
baseline.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/env python3
+ """Baseline evaluation — see training/evaluate.py for the full implementation."""
+ from training.evaluate import main
+
+ if __name__ == "__main__":
+     main()
client.py ADDED
@@ -0,0 +1,85 @@
+ """API Testing Environment Client."""
+
+ from typing import Dict
+
+ from openenv.core.client_types import StepResult
+ from openenv.core import EnvClient
+
+ from .models import APITestAction, APITestObservation, APITestState
+
+
+ class APITestEnv(
+     EnvClient[APITestAction, APITestObservation, APITestState]
+ ):
+     """
+     Client for the API Testing Environment.
+
+     Example:
+         >>> with APITestEnv(base_url="http://localhost:8000") as client:
+         ...     result = client.reset(task_id="basic_validation")
+         ...     print(result.observation.feedback)
+         ...     result = client.step(APITestAction(
+         ...         method="GET", endpoint="/tasks", expected_status=200
+         ...     ))
+         ...     print(result.observation.status_code)
+     """
+
+     def __init__(self, base_url: str, **kwargs):
+         kwargs.setdefault("message_timeout_s", 120.0)
+         super().__init__(base_url=base_url, **kwargs)
+
+     def _step_payload(self, action: APITestAction) -> Dict:
+         return {
+             "method": action.method.value if hasattr(action.method, "value") else str(action.method),
+             "endpoint": action.endpoint,
+             "headers": action.headers or {},
+             "query_params": action.query_params or {},
+             "body": action.body,
+             "expected_status": action.expected_status,
+         }
+
+     def _parse_result(self, payload: Dict) -> StepResult[APITestObservation]:
+         obs_data = payload.get("observation", {})
+         observation = APITestObservation(
+             available_endpoints=obs_data.get("available_endpoints", []),
+             status_code=obs_data.get("status_code", 0),
+             response_body=obs_data.get("response_body"),
+             response_headers=obs_data.get("response_headers", {}),
+             response_time_ms=obs_data.get("response_time_ms", 0.0),
+             feedback=obs_data.get("feedback", ""),
+             bugs_found_so_far=obs_data.get("bugs_found_so_far", 0),
+             coverage_summary=obs_data.get("coverage_summary", {}),
+             known_resource_ids=obs_data.get("known_resource_ids", {}),
+             auth_tokens=obs_data.get("auth_tokens", {}),
+             task_id=obs_data.get("task_id", ""),
+             task_description=obs_data.get("task_description", ""),
+             steps_taken=obs_data.get("steps_taken", 0),
+             max_steps=obs_data.get("max_steps", 30),
+             done=payload.get("done", False),
+             reward=payload.get("reward"),
+             metadata=obs_data.get("metadata", {}),
+         )
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: Dict) -> APITestState:
+         return APITestState(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+             task_id=payload.get("task_id", ""),
+             task_description=payload.get("task_description", ""),
+             difficulty=payload.get("difficulty", "easy"),
+             steps_taken=payload.get("steps_taken", 0),
+             max_steps=payload.get("max_steps", 30),
+             bugs_found=payload.get("bugs_found", 0),
+             total_bugs=payload.get("total_bugs", 0),
+             bugs_found_ids=payload.get("bugs_found_ids", []),
+             coverage_pct=payload.get("coverage_pct", 0.0),
+             endpoints_tested=payload.get("endpoints_tested", 0),
+             total_endpoints=payload.get("total_endpoints", 0),
+             current_score=payload.get("current_score", 0.0),
+             cumulative_reward=payload.get("cumulative_reward", 0.0),
+         )
data/tasks.json ADDED
@@ -0,0 +1,131 @@
+ {
+   "tasks": [
+     {
+       "id": "basic_validation",
+       "name": "Basic Endpoint Validation",
+       "difficulty": "easy",
+       "description": "Test all CRUD endpoints with valid inputs and verify correct status codes.",
+       "max_steps": 25,
+       "bugs": ["BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03"]
+     },
+     {
+       "id": "edge_cases",
+       "name": "Edge Cases & Error Handling",
+       "difficulty": "medium",
+       "description": "Test boundary conditions, invalid inputs, and error responses.",
+       "max_steps": 35,
+       "bugs": [
+         "BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03",
+         "BUG_TASK_04", "BUG_TASK_05", "BUG_TASK_06",
+         "BUG_USER_01", "BUG_USER_02", "BUG_AUTH_02"
+       ]
+     },
+     {
+       "id": "security_workflows",
+       "name": "Security & Multi-Step Workflows",
+       "difficulty": "hard",
+       "description": "Discover authorization flaws, injection vulnerabilities, and workflow bugs.",
+       "max_steps": 45,
+       "bugs": [
+         "BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03",
+         "BUG_TASK_04", "BUG_TASK_05", "BUG_TASK_06",
+         "BUG_TASK_07", "BUG_TASK_08", "BUG_TASK_09",
+         "BUG_USER_01", "BUG_USER_02",
+         "BUG_AUTH_01", "BUG_AUTH_02"
+       ]
+     }
+   ],
+   "bug_registry": {
+     "BUG_TASK_01": {
+       "severity": "easy",
+       "category": "status_code",
+       "owasp": "API8:2023 Security Misconfiguration",
+       "description": "GET /tasks/{id} returns 200 with null for non-existent task",
+       "recommendation": "Return 404 Not Found for non-existent resources"
+     },
+     "BUG_TASK_02": {
+       "severity": "easy",
+       "category": "validation",
+       "owasp": "API8:2023 Security Misconfiguration",
+       "description": "POST /tasks with missing title returns 500 instead of 400",
+       "recommendation": "Validate required fields and return 400/422 with descriptive error"
+     },
+     "BUG_TASK_03": {
+       "severity": "easy",
+       "category": "validation",
+       "owasp": "API8:2023 Security Misconfiguration",
+       "description": "GET /tasks?page=-1 returns 200 instead of 400",
+       "recommendation": "Validate pagination parameters: page >= 1, limit > 0"
+     },
+     "BUG_TASK_04": {
+       "severity": "medium",
+       "category": "validation",
+       "owasp": "API8:2023 Security Misconfiguration",
+       "description": "PUT /tasks/{id} accepts invalid email format",
+       "recommendation": "Validate email format with regex before accepting"
+     },
+     "BUG_TASK_05": {
+       "severity": "medium",
+       "category": "status_code",
+       "owasp": "API8:2023 Security Misconfiguration",
+       "description": "DELETE /tasks/{id} returns 200 for non-existent task",
+       "recommendation": "Check resource existence before deletion, return 404 if missing"
+     },
+     "BUG_TASK_06": {
+       "severity": "medium",
+       "category": "validation",
+       "owasp": "API4:2023 Unrestricted Resource Consumption",
+       "description": "No pagination cap on limit parameter",
+       "recommendation": "Cap pagination limit at 100, reject values above maximum"
+     },
+     "BUG_TASK_07": {
+       "severity": "hard",
+       "category": "security",
+       "owasp": "API1:2023 Broken Object Level Authorization",
+       "description": "BOLA: any user can access any task",
+       "recommendation": "Verify resource ownership: check task.owner_id matches authenticated user"
+     },
+     "BUG_TASK_08": {
+       "severity": "hard",
+       "category": "validation",
+       "owasp": "API4:2023 Unrestricted Resource Consumption",
+       "description": "Long title causes 500 error",
+       "recommendation": "Add input length validation: title max 200 chars"
+     },
+     "BUG_TASK_09": {
+       "severity": "hard",
+       "category": "security",
+       "owasp": "API8:2023 Security Misconfiguration",
+       "description": "SQL injection payload stored verbatim",
+       "recommendation": "Sanitize user input before storage, escape HTML/SQL special characters"
+     },
+     "BUG_USER_01": {
+       "severity": "medium",
+       "category": "validation",
+       "owasp": "API8:2023 Security Misconfiguration",
+       "description": "POST /users accepts invalid email",
+       "recommendation": "Validate email format server-side before creating user"
+     },
+     "BUG_USER_02": {
+       "severity": "medium",
+       "category": "security",
+       "owasp": "API3:2023 Broken Object Property Level Authorization",
+       "description": "Response exposes password hash",
+       "recommendation": "Never return sensitive fields (password_hash) in API responses"
+     },
+     "BUG_AUTH_01": {
+       "severity": "hard",
+       "category": "security",
+       "owasp": "API1:2023 Broken Object Level Authorization",
+       "description": "Broken authorization: cross-user token access",
+       "recommendation": "Enforce ownership check on all write operations (PUT/DELETE)"
+     },
+     "BUG_AUTH_02": {
+       "severity": "medium",
+       "category": "security",
+       "owasp": "API2:2023 Broken Authentication",
+       "description": "Empty password login succeeds",
+       "recommendation": "Validate password is non-empty and verify against stored hash"
+     }
+   }
+ }
eval_trained.py ADDED
@@ -0,0 +1,141 @@
+ #!/usr/bin/env python3
+ """
+ Re-evaluate the trained GRPO model without re-training.
+
+ Usage:
+     python eval_trained.py
+     python eval_trained.py --checkpoint ./checkpoints/grpo_api_tester
+ """
+
+ import argparse
+ import os
+ import sys
+
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+ import logging
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
+
+ # Suppress noisy logs
+ for _noisy in ["httpx", "httpcore", "urllib3", "huggingface_hub", "filelock"]:
+     logging.getLogger(_noisy).setLevel(logging.WARNING)
+
+ logger = logging.getLogger(__name__)
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--checkpoint",
+         default="./checkpoints/grpo_api_tester",
+         help="Path to the trained model checkpoint",
+     )
+     parser.add_argument(
+         "--base-model",
+         default="Qwen/Qwen3-1.7B",
+         help="Base model (needed if checkpoint is LoRA-only)",
+     )
+     parser.add_argument(
+         "--max-steps",
+         type=int,
+         default=25,
+         help="Max actions per task during evaluation",
+     )
+     parser.add_argument(
+         "--seed",
+         type=int,
+         default=9999,
+         help="Random seed for evaluation",
+     )
+     args = parser.parse_args()
+
+     print(f"\n{'='*60}")
+     print(" Re-evaluating trained model")
+     print(f"{'='*60}")
+     print(f" Checkpoint: {args.checkpoint}")
+     print(f" Base model: {args.base_model}")
+     print(f" Max steps: {args.max_steps}")
+     print(f" Seed: {args.seed}")
+     print(f"{'='*60}\n")
+
+     import torch
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+     from peft import PeftModel
+
+     # Detect device
+     if torch.cuda.is_available():
+         device = "cuda"
+         dtype = torch.bfloat16
+         print(f" GPU: {torch.cuda.get_device_name(0)}")
+     else:
+         device = "cpu"
+         dtype = torch.float32
+         print(" WARNING: No GPU — eval will be slow")
+
+     # Load tokenizer (from base model is fine)
+     print(f" Loading tokenizer from {args.base_model}...", flush=True)
+     tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Load base model
+     print(f" Loading base model {args.base_model}...", flush=True)
+     base_model = AutoModelForCausalLM.from_pretrained(
+         args.base_model,
+         trust_remote_code=True,
+         torch_dtype=dtype,
+         device_map="auto",
+     )
+
+     # Load LoRA adapter from checkpoint
+     print(f" Loading LoRA adapter from {args.checkpoint}...", flush=True)
+     try:
+         model = PeftModel.from_pretrained(base_model, args.checkpoint)
+         # Merge LoRA into base for faster inference
+         print(" Merging LoRA into base...", flush=True)
+         model = model.merge_and_unload()
+         print(" Model loaded successfully.", flush=True)
+     except Exception as exc:
+         print(f" WARNING: Failed to load LoRA adapter: {exc}", flush=True)
+         print(" Using base model without LoRA.", flush=True)
+         model = base_model
+
+     # Run evaluation on all 3 tasks
+     from training.evaluate import run_rollout
+
+     print(f"\n{'='*60}")
+     print(" Running evaluation on all tasks...")
+     print(f"{'='*60}\n")
+
+     results = {}
+     for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
+         print(f"\n--- Task: {task_id} ---")
+         result = run_rollout(
+             model, tokenizer,
+             task_id=task_id,
+             seed=args.seed,
+             max_steps=args.max_steps,
+         )
+         results[task_id] = result
+         print(f" reward={result['total_reward']:.3f}, "
+               f"bugs={result['bugs_found']}/{result['total_bugs']}, "
+               f"coverage={result['coverage_pct']:.1f}%")
+
+     # Print summary
+     print(f"\n{'='*60}")
+     print(" RESULTS")
+     print(f"{'='*60}")
+     print(f"{'Task':<25} {'Reward':<10} {'Bugs':<10} {'Coverage':<10}")
+     print(f"{'-'*60}")
+     for task_id, r in results.items():
+         print(f"{task_id:<25} {r['total_reward']:<10.3f} "
+               f"{r['bugs_found']}/{r['total_bugs']:<8} "
+               f"{r['coverage_pct']:<10.1f}%")
+     print(f"{'='*60}\n")
+
+     avg = sum(r["total_reward"] for r in results.values()) / len(results)
+     print(f" Average reward: {avg:.3f}")
+
+
+ if __name__ == "__main__":
+     main()
141
+ main()
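For reference, the summary block in the script above averages `total_reward` across the per-task result dicts returned by `run_rollout`. A minimal sketch of that aggregation, using made-up reward values rather than real rollout output:

```python
# Hypothetical per-task results, shaped like the dicts run_rollout returns
# (only the "total_reward" key is used by the summary).
results = {
    "basic_validation":   {"total_reward": 0.8},
    "edge_cases":         {"total_reward": 0.5},
    "security_workflows": {"total_reward": 0.2},
}

# Same aggregation as the script: mean of total_reward over all tasks.
avg = sum(r["total_reward"] for r in results.values()) / len(results)
assert abs(avg - 0.5) < 1e-9
```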
gradio_app.py ADDED
@@ -0,0 +1,627 @@
+ #!/usr/bin/env python3
+ """
+ Gradio UI for the API Testing Environment.
+ """
+
+ import json
+ import os
+ import time
+ import argparse
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import gradio as gr
+
+ from models import APITestAction, APITestObservation, HTTPMethod
+ from server.environment import APITestEnvironment, TASKS, API_SPEC
+
+
+ @dataclass
+ class SessionState:
+     env: APITestEnvironment = field(default_factory=APITestEnvironment)
+     initialized: bool = False
+     task_id: str = ""
+     step_log: list[dict] = field(default_factory=list)
+     total_reward: float = 0.0
+     last_obs: Optional[APITestObservation] = None
+
+
+ def new_session():
+     return SessionState()
+
+
+ # =====================================================================
+ # Core logic
+ # =====================================================================
+
+ def _generate_report(bug_ids, action_history):
+     """Generate OWASP bug bounty report from discovered bugs."""
+     from server.graders import generate_bug_report
+     return generate_bug_report(bug_ids, action_history)
+
+
+ def reset_env(task_id, state):
+     if not state:
+         state = new_session()
+     obs = state.env.reset(task_id=task_id)
+     state.initialized = True
+     state.task_id = task_id
+     state.step_log = []
+     state.total_reward = 0.0
+     state.last_obs = obs
+     t = TASKS[task_id]
+     return (
+         state,
+         f"Environment reset. Task: **{task_id}** ({t['difficulty']})\n\nMax steps: {t['max_steps']} | Bugs to find: {t['total_bugs']}",
+         obs.feedback,
+         "",
+         format_reward_display(0, 0, {}),
+         f"0 / {t['total_bugs']}",
+         format_coverage(obs.coverage_summary),
+         "",
+         f"0 / {t['max_steps']}",
+         "No bugs found yet.",
+         "No bugs found yet. Send requests to discover vulnerabilities.",
+         "No tokens acquired yet.",
+         "No resources created yet.",
+     )
+
+
+ def send_request(method, endpoint, headers_str, params_str, body_str, expected_status, state):
+     if not state or not state.initialized:
+         return (state, "Environment not initialized. Click 'Reset' first.", "", "", "", "", "", "", "", "", "", "")
+
+     try:
+         headers = json.loads(headers_str) if headers_str.strip() else {}
+     except json.JSONDecodeError:
+         return (state, "Invalid JSON in headers.", "", "", "", "", "", "", "", "", "", "")
+     try:
+         query_params = json.loads(params_str) if params_str.strip() else {}
+     except json.JSONDecodeError:
+         return (state, "Invalid JSON in query params.", "", "", "", "", "", "", "", "", "", "")
+     try:
+         body = json.loads(body_str) if body_str.strip() else None
+     except json.JSONDecodeError:
+         return (state, "Invalid JSON in body.", "", "", "", "", "", "", "", "", "", "")
+
+     exp = int(expected_status) if expected_status.strip() else None
+     action = APITestAction(
+         method=HTTPMethod(method), endpoint=endpoint,
+         headers=headers, query_params=query_params,
+         body=body, expected_status=exp,
+     )
+
+     obs = state.env.step(action)
+     reward = obs.reward or 0.0
+     state.total_reward += reward
+     state.last_obs = obs
+
+     resp_body = obs.response_body
+     if isinstance(resp_body, (dict, list)):
+         resp_str = json.dumps(resp_body, indent=2)
+     else:
+         resp_str = str(resp_body)
+
+     state.step_log.append({
+         "step": obs.steps_taken, "method": method, "endpoint": endpoint,
+         "status": obs.status_code, "reward": round(reward, 4), "bugs": obs.bugs_found_so_far,
+     })
+
+     breakdown = obs.metadata.get("reward_breakdown", {})
+     reward_detail = format_reward_display(reward, state.total_reward, breakdown)
+
+     t = TASKS[state.task_id]
+     es = state.env.state
+
+     status = ""
+     if obs.done:
+         status = (
+             f"\n\n**EPISODE COMPLETE**\n\n"
+             f"Final Score: {reward:.4f}\n"
+             f"Bugs: {obs.bugs_found_so_far}/{t['total_bugs']}\n"
+             f"Steps: {obs.steps_taken}/{obs.max_steps}"
+         )
+
+     return (
+         state,
+         obs.feedback + status,
+         f"**{obs.status_code}** — {obs.response_time_ms:.1f}ms\n\n```json\n{resp_str}\n```",
+         reward_detail,
+         f"{obs.bugs_found_so_far} / {t['total_bugs']}",
+         format_coverage(obs.coverage_summary),
+         format_log(state.step_log),
+         f"{obs.steps_taken} / {obs.max_steps}" + (" (DONE)" if obs.done else ""),
+         format_bug_list(es.bugs_found_ids),
+         _generate_report(es.bugs_found_ids, state.step_log),
+         format_auth_tokens(obs.auth_tokens),
+         format_resources(obs.known_resource_ids),
+     )
+
+
+ def apply_quick_action(action_name, _state):
+     quick_actions = {
+         "GET /tasks": ("GET", "/tasks", "{}", "{}", "", "200"),
+         "GET /users": ("GET", "/users", "{}", "{}", "", "200"),
+         "GET /tasks/1": ("GET", "/tasks/1", "{}", "{}", "", "200"),
+         "GET /tasks/999999 (bug hunt)": ("GET", "/tasks/999999", "{}", "{}", "", "404"),
+         "POST create task": ("POST", "/tasks", "{}", "{}", '{"title": "Test Task", "description": "Created via UI"}', "201"),
+         "POST missing title (bug hunt)": ("POST", "/tasks", "{}", "{}", '{"description": "no title"}', "400"),
+         "Login as alice": ("POST", "/auth/login", "{}", "{}", '{"username": "alice", "password": "pass"}', "200"),
+         "Login as bob": ("POST", "/auth/login", "{}", "{}", '{"username": "bob", "password": "pass"}', "200"),
+         "Login empty pwd (bug hunt)": ("POST", "/auth/login", "{}", "{}", '{"username": "alice", "password": ""}', "401"),
+         "Negative page (bug hunt)": ("GET", "/tasks", "{}", '{"page": -1, "limit": 10}', "", "400"),
+         "Huge limit (bug hunt)": ("GET", "/tasks", "{}", '{"limit": 999999}', "", "200"),
+         "Invalid email PUT (bug hunt)": ("PUT", "/tasks/1", "{}", "{}", '{"assignee_email": "not-an-email"}', "422"),
+         "DELETE non-existent (bug hunt)": ("DELETE", "/tasks/99999", "{}", "{}", "", "404"),
+         "Create user invalid email (bug)": ("POST", "/users", "{}", "{}", '{"username": "baduser", "email": "nope", "password": "x"}', "422"),
+         "SQL injection test": ("POST", "/tasks", "{}", "{}", '{"title": "test\'; DROP TABLE tasks;--"}', "201"),
+         "Long title crash (bug hunt)": ("POST", "/tasks", "{}", "{}", '{"title": "' + "A" * 6000 + '"}', "400"),
+     }
+     if action_name and action_name in quick_actions:
+         return quick_actions[action_name]
+     return [gr.update()] * 6
+
+
+ def run_baseline_agent(agent_type, state):
+     if not state or not state.initialized:
+         yield state, "Environment not initialized.", "", "", "", "", "", "", "", "", "", ""
+         return
+
+     from training.agents import RandomAgent, SequentialAgent, SmartAgent
+     agents = {"random": RandomAgent, "sequential": SequentialAgent, "smart": SmartAgent}
+     agent = agents[agent_type]()
+     t = TASKS[state.task_id]
+
+     obs = state.env.reset(task_id=state.task_id)
+     state.step_log = []
+     state.total_reward = 0.0
+     state.last_obs = obs
+
+     while not obs.done:
+         obs_dict = {
+             "status_code": obs.status_code, "response_body": obs.response_body,
+             "feedback": obs.feedback, "bugs_found_so_far": obs.bugs_found_so_far,
+             "coverage_summary": obs.coverage_summary, "known_resource_ids": obs.known_resource_ids,
+             "auth_tokens": obs.auth_tokens, "steps_taken": obs.steps_taken, "max_steps": obs.max_steps,
+         }
+         action = agent.act(obs_dict)
+         obs = state.env.step(action)
+         reward = obs.reward or 0.0
+         state.total_reward += reward
+         state.last_obs = obs
+
+         ms = action.method.value if hasattr(action.method, "value") else str(action.method)
+         state.step_log.append({
+             "step": obs.steps_taken, "method": ms, "endpoint": action.endpoint,
+             "status": obs.status_code, "reward": round(reward, 4), "bugs": obs.bugs_found_so_far,
+         })
+
+         resp_body = obs.response_body
+         if isinstance(resp_body, (dict, list)):
+             resp_str = json.dumps(resp_body, indent=2)
+         else:
+             resp_str = str(resp_body)
+
+         breakdown = obs.metadata.get("reward_breakdown", {})
+         reward_detail = format_reward_display(reward, state.total_reward, breakdown)
+
+         es = state.env.state
+         done_text = ""
+         if obs.done:
+             done_text = f"\n\n**EPISODE COMPLETE** — Final Score: {reward:.4f} | Bugs: {obs.bugs_found_so_far}/{t['total_bugs']}"
+
+         yield (
+             state,
+             f"[{agent_type}] {ms} {action.endpoint} -> {obs.status_code}{done_text}",
+             f"**{obs.status_code}**\n```json\n{resp_str[:500]}\n```",
+             reward_detail,
+             f"{obs.bugs_found_so_far} / {t['total_bugs']}",
+             format_coverage(obs.coverage_summary),
+             format_log(state.step_log),
+             f"{obs.steps_taken} / {obs.max_steps}" + (" (DONE)" if obs.done else ""),
+             format_bug_list(es.bugs_found_ids),
+             _generate_report(es.bugs_found_ids, state.step_log),
+             format_auth_tokens(obs.auth_tokens),
+             format_resources(obs.known_resource_ids),
+         )
+         time.sleep(0.3)
+
+
+ # =====================================================================
+ # Formatters
+ # =====================================================================
+
+ def format_reward_display(step_reward, cumulative, breakdown):
+     """Render reward metrics as styled HTML with explanations."""
+     components = [
+         ("Coverage", breakdown.get("coverage", 0),
+          "Reward for testing new endpoints and methods"),
+         ("Validity", breakdown.get("validity", 0),
+          "Reward for sending well-formed requests that return expected status codes"),
+         ("Bug", breakdown.get("bug_discovery", 0),
+          "Bonus for discovering a new bug in the API"),
+         ("Explore", breakdown.get("exploration", 0),
+          "Reward for trying new parameter combinations and edge cases"),
+         ("Penalty", breakdown.get("penalty", 0),
+          "Deduction for repeated or invalid requests"),
+     ]
+     bars = []
+     for label, value, tip in components:
+         val_color = "#16a34a" if value > 0 else "#dc2626" if value < 0 else "inherit"
+         bars.append(
+             f'<div style="display:flex;justify-content:space-between;align-items:center;'
+             f'padding:2px 0;font-size:0.82em;" title="{tip}">'
+             f'<span style="opacity:0.6;cursor:help;border-bottom:1px dotted currentColor;">'
+             f'{label}</span>'
+             f'<span style="color:{val_color};font-family:monospace;font-weight:600;">'
+             f'{value:+.3f}</span></div>'
+         )
+     cum_color = "#16a34a" if cumulative > 0 else "#dc2626" if cumulative < 0 else "inherit"
+     step_color = "#16a34a" if step_reward > 0 else "#dc2626" if step_reward < 0 else "inherit"
+     return (
+         f'<div style="display:flex;gap:16px;margin-bottom:8px;">'
+         f'<div style="flex:1;text-align:center;padding:6px;background:rgba(128,128,128,0.1);'
+         f'border-radius:8px;">'
+         f'<div style="font-size:0.72em;opacity:0.55;">STEP REWARD</div>'
+         f'<div style="font-size:1.3em;font-weight:700;color:{step_color};">'
+         f'{step_reward:+.4f}</div></div>'
+         f'<div style="flex:1;text-align:center;padding:6px;background:rgba(128,128,128,0.1);'
+         f'border-radius:8px;">'
+         f'<div style="font-size:0.72em;opacity:0.55;">CUMULATIVE</div>'
+         f'<div style="font-size:1.3em;font-weight:700;color:{cum_color};">'
+         f'{cumulative:.4f}</div></div></div>'
+         f'<div style="border:1px solid rgba(128,128,128,0.2);border-radius:8px;padding:6px 10px;">'
+         f'<div style="font-size:0.72em;opacity:0.5;margin-bottom:4px;">'
+         f'REWARD BREAKDOWN '
+         f'<span title="How the reward for the last step was calculated"'
+         f' style="cursor:help;">&#9432;</span></div>'
+         + "".join(bars)
+         + "</div>"
+     )
+
+
+ def format_coverage(summary):
+     if not summary:
+         return "No data"
+     pct = summary.get("coverage_pct", 0)
+     tested = summary.get("endpoints_tested", 0)
+     total = summary.get("total_endpoints", 0)
+     pairs = summary.get("method_endpoint_pairs", 0)
+     codes = summary.get("status_codes_seen", [])
+     color = "#dc2626" if pct < 30 else "#d97706" if pct < 70 else "#16a34a"
+     bar_html = (
+         f'<div style="display:flex;align-items:center;gap:8px;margin:4px 0;">'
+         f'<div style="flex:1;background:rgba(128,128,128,0.15);border-radius:6px;height:14px;overflow:hidden;">'
+         f'<div style="width:{pct:.1f}%;height:100%;background:{color};border-radius:6px;'
+         f'transition:width 0.3s ease;"></div></div>'
+         f'<span style="font-weight:700;min-width:48px;text-align:right;">{pct:.1f}%</span></div>'
+     )
+     code_pills = ""
+     for c in codes:
+         cc = "#16a34a" if 200 <= c < 300 else "#d97706" if 300 <= c < 400 else "#dc2626"
+         code_pills += (
+             f'<span style="background:{cc}18;color:{cc};padding:1px 7px;border-radius:10px;'
+             f'font-size:0.78em;font-weight:600;margin-right:4px;">{c}</span>'
+         )
+     return (
+         f"{bar_html}"
+         f'<div style="display:flex;gap:10px;margin:6px 0;font-size:0.82em;">'
+         f'<div style="flex:1;text-align:center;padding:4px;background:rgba(128,128,128,0.1);border-radius:6px;"'
+         f' title="How many unique API endpoints have been called">'
+         f'<div style="font-size:0.72em;opacity:0.5;">ENDPOINTS</div>'
+         f'<div style="font-weight:700;">{tested}/{total}</div></div>'
+         f'<div style="flex:1;text-align:center;padding:4px;background:rgba(128,128,128,0.1);border-radius:6px;"'
+         f' title="Unique combinations of HTTP method + endpoint path tested">'
+         f'<div style="font-size:0.72em;opacity:0.5;">METHOD+PATH</div>'
+         f'<div style="font-weight:700;">{pairs}</div></div></div>'
+         f'<div style="margin-top:4px;" title="HTTP status codes received from the API so far">'
+         f'<span style="font-size:0.72em;opacity:0.5;">STATUS CODES SEEN </span>'
+         f'{code_pills}</div>'
+     )
+
+
+ def format_log(log):
+     if not log:
+         return (
+             '<div style="opacity:0.55;font-size:0.85em;">'
+             "Each row shows an API request the agent made, the HTTP status it got back, "
+             "and the reward earned. Green = positive reward, red = penalty."
+             "</div>"
+         )
+     method_colors = {
+         "GET": "#2563eb", "POST": "#16a34a", "PUT": "#d97706",
+         "DELETE": "#dc2626", "PATCH": "#9333ea",
+     }
+     rows = []
+     for entry in log[-20:]:
+         m = entry["method"]
+         mcol = method_colors.get(m, "#6b7280")
+         r = entry["reward"]
+         rcol = "#16a34a" if r > 0 else "#dc2626" if r < 0 else "inherit"
+         bug_tag = (
+             '<span style="background:#92400e;color:#fef08a;padding:0 5px;border-radius:4px;'
+             'font-size:0.7em;margin-left:4px;">BUG FOUND</span>'
+         ) if r > 0.2 else ""
+         status = entry["status"]
+         scol = "#16a34a" if 200 <= status < 300 else "#d97706" if 300 <= status < 400 else "#dc2626"
+         rows.append(
+             f'<div style="display:flex;align-items:center;gap:6px;padding:3px 0;'
+             f'border-bottom:1px solid rgba(128,128,128,0.1);font-size:0.82em;">'
+             f'<span style="opacity:0.45;min-width:20px;text-align:right;">{entry["step"]}</span>'
+             f'<span style="background:{mcol}18;color:{mcol};padding:1px 6px;border-radius:4px;'
+             f'font-weight:600;font-size:0.8em;min-width:52px;text-align:center;">{m}</span>'
+             f'<span style="flex:1;overflow:hidden;text-overflow:ellipsis;'
+             f'white-space:nowrap;">{entry["endpoint"]}</span>'
+             f'<span style="color:{scol};font-weight:600;min-width:28px;text-align:right;">{status}</span>'
+             f'<span style="color:{rcol};min-width:52px;text-align:right;font-family:monospace;'
+             f'font-size:0.85em;">{r:+.3f}</span>{bug_tag}</div>'
+         )
+     omitted = ""
+     if len(log) > 20:
+         omitted = (
+             f'<div style="opacity:0.45;font-size:0.78em;padding:4px 0;text-align:center;">'
+             f'... {len(log) - 20} earlier steps not shown</div>'
+         )
+     header = (
+         '<div style="opacity:0.55;font-size:0.78em;margin-bottom:6px;">'
+         "API requests made by the agent. Each row: step number, HTTP method, "
+         "endpoint, status code, and reward earned.</div>"
+         '<div style="display:flex;gap:6px;padding:2px 0 6px;border-bottom:1px solid rgba(128,128,128,0.2);'
+         'font-size:0.75em;opacity:0.5;">'
+         '<span style="min-width:20px;text-align:right;">#</span>'
+         '<span style="min-width:52px;text-align:center;">Method</span>'
+         '<span style="flex:1;">Endpoint</span>'
+         '<span style="min-width:28px;text-align:right;">Status</span>'
+         '<span style="min-width:52px;text-align:right;">Reward</span></div>'
+     )
+     return header + omitted + "\n".join(rows)
+
+
+ def format_bug_list(bug_ids):
+     if not bug_ids:
+         return "No bugs found yet."
+     from server.bug_detector import BugDetector
+     detector = BugDetector("security_workflows")
+     severity_colors = {
+         "easy": "#16a34a",
+         "medium": "#d97706",
+         "hard": "#dc2626",
+     }
+     cards = []
+     for bid in sorted(bug_ids):
+         bug = detector.bugs.get(bid)
+         if bug:
+             fg = severity_colors.get(bug.severity, "#6b7280")
+             owasp_badge = f' | {bug.owasp.split(" ")[0]}' if bug.owasp else ""
+             cards.append(
+                 f'<div style="border:1px solid {fg}40;border-radius:8px;padding:8px 10px;'
+                 f'margin-bottom:6px;background:{fg}0d;">'
+                 f'<div style="display:flex;justify-content:space-between;align-items:center;">'
+                 f'<span style="font-weight:700;font-size:0.85em;">{bid}</span>'
+                 f'<span style="background:{fg};color:#fff;padding:1px 8px;border-radius:10px;'
+                 f'font-size:0.75em;font-weight:600;">{bug.severity.upper()}{owasp_badge}</span></div>'
+                 f'<div style="margin-top:4px;font-size:0.85em;opacity:0.7;">'
+                 f'{bug.description}</div>'
+                 f'<div style="margin-top:2px;font-size:0.78em;opacity:0.5;font-style:italic;">'
+                 f'{bug.owasp}</div></div>'
+             )
+     return "\n".join(cards)
+
+
+ def format_auth_tokens(tokens):
+     if not tokens:
+         return (
+             '<div style="opacity:0.5;font-size:0.85em;">'
+             "No tokens yet. Login via <code>POST /auth/login</code> to get auth tokens "
+             "for testing protected endpoints.</div>"
+         )
+     cards = []
+     for user, token in tokens.items():
+         cards.append(
+             f'<div style="display:flex;align-items:center;gap:8px;padding:4px 0;'
+             f'border-bottom:1px solid rgba(128,128,128,0.1);font-size:0.85em;">'
+             f'<span style="background:#2563eb18;color:#2563eb;padding:1px 8px;border-radius:10px;'
+             f'font-weight:600;font-size:0.8em;">{user}</span>'
+             f'<code style="opacity:0.55;font-size:0.82em;">{token[:20]}...</code></div>'
+         )
+     return (
+         '<div style="font-size:0.72em;opacity:0.5;margin-bottom:4px;"'
+         ' title="Auth tokens obtained by logging in. Use these in the Authorization header.">'
+         "AUTHENTICATED USERS</div>"
+         + "".join(cards)
+     )
+
+
+ def format_resources(ids):
+     if not ids:
+         return (
+             '<div style="opacity:0.5;font-size:0.85em;">'
+             "No resources created. Use POST endpoints to create tasks or users "
+             "and track their IDs here.</div>"
+         )
+     sections = []
+     type_colors = {"tasks": "#d97706", "users": "#2563eb"}
+     for rtype, id_list in ids.items():
+         color = type_colors.get(rtype, "#6b7280")
+         ids_str = ", ".join(str(i) for i in id_list) if isinstance(id_list, list) else str(id_list)
+         sections.append(
+             f'<div style="padding:4px 0;border-bottom:1px solid rgba(128,128,128,0.1);font-size:0.85em;">'
+             f'<span style="background:{color}18;color:{color};padding:1px 8px;border-radius:10px;'
+             f'font-weight:600;font-size:0.8em;text-transform:uppercase;">{rtype}</span>'
+             f'<span style="margin-left:8px;opacity:0.7;">IDs: {ids_str}</span></div>'
+         )
+     return (
+         '<div style="font-size:0.72em;opacity:0.5;margin-bottom:4px;"'
+         ' title="Resources created during this episode. Use these IDs in GET/PUT/DELETE requests.">'
+         "CREATED RESOURCES</div>"
+         + "".join(sections)
+     )
+
+
+ def format_endpoints():
+     lines = []
+     for ep in API_SPEC:
+         lines.append(f"**{ep['method']}** `{ep['path']}` — {ep.get('summary', '')}")
+     return "\n\n".join(lines)
+
+
+ # =====================================================================
+ # UI
+ # =====================================================================
+
+ def build_ui():
+     with gr.Blocks(title="API Testing Environment") as demo:
+         session = gr.State(value=new_session())
+
+         gr.Markdown(
+             "# API Testing Environment\n"
+             "An OpenEnv RL environment that trains AI agents to become automated **API security testers**. "
+             "A simulated API server with **13 hidden vulnerabilities** mapped to the **OWASP API Security Top 10** is provided. "
+             "Send HTTP requests, earn rewards for finding bugs and covering endpoints, and generate a **bug bounty report** at episode end. "
+             "Use **Manual Testing** to craft requests yourself, or run a **Baseline Agent** to watch an automated strategy."
+         )
+
+         with gr.Row():
+             # ── Left Panel ──
+             with gr.Column(scale=1):
+                 gr.Markdown("### Environment Control")
+                 task_dropdown = gr.Dropdown(choices=list(TASKS.keys()), value="basic_validation", label="Select Task")
+                 reset_btn = gr.Button("Reset Environment", variant="primary", size="lg")
+                 gr.Markdown(
+                     '<span style="font-size:0.8em;opacity:0.55;">'
+                     "Switch task or click Reset to start a fresh episode. "
+                     "Resets all scores, bugs, and step count.</span>"
+                 )
+                 status_box = gr.Markdown("Initializing...")
+
+                 gr.Markdown("---")
+                 gr.Markdown("### Scoreboard")
+                 gr.Markdown(
+                     '<span style="font-size:0.78em;opacity:0.55;">'
+                     "Tracks your testing progress. Steps are API calls you've made; "
+                     "bugs are issues discovered in the API; reward measures how well "
+                     "the agent is testing.</span>"
+                 )
+                 with gr.Row():
+                     step_display = gr.Markdown("0 / 25", label="Steps")
+                     bug_display = gr.Markdown("0 / 3", label="Bugs")
+                 reward_display = gr.Markdown(format_reward_display(0, 0, {}), label="Reward")
+                 coverage_display = gr.Markdown("No data", label="Coverage")
+
+                 gr.Markdown("---")
+                 gr.Markdown("### Session Context")
+                 gr.Markdown(
+                     '<span style="font-size:0.78em;opacity:0.55;">'
+                     "Tokens and resources gathered during this episode. "
+                     "Use tokens to test auth-protected endpoints and resource IDs for "
+                     "GET/PUT/DELETE requests.</span>"
+                 )
+                 auth_display = gr.Markdown(format_auth_tokens({}))
+                 resource_display = gr.Markdown(format_resources({}))
+
+                 gr.Markdown("---")
+                 with gr.Accordion("API Specification", open=False):
+                     gr.Markdown(format_endpoints())
+
+             # ── Center Panel ──
+             with gr.Column(scale=2):
+                 with gr.Tabs():
+                     with gr.Tab("Manual Testing"):
+                         gr.Markdown("### Craft Your Request")
+                         with gr.Row():
+                             method_input = gr.Dropdown(
+                                 choices=["GET", "POST", "PUT", "DELETE", "PATCH"],
+                                 value="GET", label="Method", scale=1,
+                             )
+                             endpoint_input = gr.Textbox(value="/tasks", label="Endpoint", placeholder="/tasks, /users/1, /auth/login", scale=3)
+                             expected_input = gr.Textbox(value="200", label="Expected Status", placeholder="200", scale=1)
+
+                         with gr.Row():
+                             headers_input = gr.Textbox(value="{}", label="Headers (JSON)", placeholder='{"Authorization": "Bearer ..."}', lines=1)
+                             params_input = gr.Textbox(value="{}", label="Query Params (JSON)", placeholder='{"page": 1, "limit": 10}', lines=1)
+
+                         body_input = gr.Textbox(value="", label="Request Body (JSON)", placeholder='{"title": "My Task", "description": "..."}', lines=3)
+
+                         send_btn = gr.Button("Send Request", variant="primary", size="lg")
+
+                         gr.Markdown("### Quick Actions")
+                         quick_actions = gr.Dropdown(
+                             choices=[
+                                 "GET /tasks", "GET /users", "GET /tasks/1",
+                                 "GET /tasks/999999 (bug hunt)", "POST create task",
+                                 "POST missing title (bug hunt)", "Login as alice", "Login as bob",
+                                 "Login empty pwd (bug hunt)", "Negative page (bug hunt)",
+                                 "Huge limit (bug hunt)", "Invalid email PUT (bug hunt)",
+                                 "DELETE non-existent (bug hunt)", "Create user invalid email (bug)",
+                                 "SQL injection test", "Long title crash (bug hunt)",
+                             ],
+                             label="Quick Actions", value=None,
+                         )
+                         quick_btn = gr.Button("Load Quick Action", variant="secondary")
+
+                     with gr.Tab("Run Baseline Agent"):
+                         gr.Markdown("### Automated Agents\nWatch a baseline agent test the API step by step.")
+                         agent_dropdown = gr.Dropdown(choices=["random", "sequential", "smart"], value="smart", label="Agent Type")
+                         run_agent_btn = gr.Button("Run Agent", variant="primary", size="lg")
+
+                 gr.Markdown("---")
+                 gr.Markdown("### Response")
+                 response_display = gr.Markdown("")
+
+                 gr.Markdown("### Feedback")
+                 feedback_display = gr.Markdown("")
+
+             # ── Right Panel ──
+             with gr.Column(scale=1):
+                 with gr.Tabs():
+                     with gr.Tab("Discovered Bugs"):
+                         bug_list_display = gr.Markdown("No bugs found yet.")
+
+                     with gr.Tab("Bug Report"):
+                         gr.Markdown("*Auto-generated OWASP security report. Populates as bugs are found.*")
+                         bug_report_display = gr.Markdown("No bugs found yet. Send requests to discover vulnerabilities.")
+
+                     with gr.Tab("Activity Log"):
+                         log_display = gr.Markdown("No steps yet.")
+
+         # ── Wiring ──
+         reset_outputs = [
+             session, status_box, feedback_display, response_display,
+             reward_display, bug_display, coverage_display, log_display,
+             step_display, bug_list_display, bug_report_display, auth_display, resource_display,
+         ]
+
+         step_outputs = [
+             session, feedback_display, response_display, reward_display,
+             bug_display, coverage_display, log_display, step_display,
+             bug_list_display, bug_report_display, auth_display, resource_display,
+         ]
+
+         reset_btn.click(fn=reset_env, inputs=[task_dropdown, session], outputs=reset_outputs)
+
+         send_btn.click(
+             fn=send_request,
+             inputs=[method_input, endpoint_input, headers_input, params_input, body_input, expected_input, session],
+             outputs=step_outputs,
+         )
+
+         quick_btn.click(
+             fn=apply_quick_action, inputs=[quick_actions, session],
+             outputs=[method_input, endpoint_input, headers_input, params_input, body_input, expected_input],
+         )
+
+         run_agent_btn.click(fn=run_baseline_agent, inputs=[agent_dropdown, session], outputs=step_outputs)
+
+         # Auto-reset on page load so users can start testing immediately
+         demo.load(fn=reset_env, inputs=[task_dropdown, session], outputs=reset_outputs)
+
+     return demo
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--port", type=int, default=int(os.getenv("GRADIO_SERVER_PORT", "7860")))
+     parser.add_argument("--host", default="0.0.0.0")
+     parser.add_argument("--share", action="store_true")
+     args = parser.parse_args()
+     build_ui().launch(server_name=args.host, server_port=args.port, share=args.share)
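`send_request` above parses the three JSON textboxes (headers, query params, body) with one convention: blank input falls back to a default, and anything else must parse as JSON or the request is rejected. A standalone sketch of that convention (`parse_json_field` is a hypothetical helper name, not part of the app):

```python
import json

def parse_json_field(raw: str, default):
    """Blank input -> default; otherwise the text must be valid JSON."""
    if not raw.strip():
        return default
    return json.loads(raw)  # raises json.JSONDecodeError on bad input

assert parse_json_field("", {}) == {}
assert parse_json_field("   ", None) is None
assert parse_json_field('{"page": 1, "limit": 10}', {}) == {"page": 1, "limit": 10}
```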
inference.py ADDED
@@ -0,0 +1,432 @@
+ #!/usr/bin/env python3
+ """
+ inference.py — OpenEnv API Testing Environment baseline inference script.
+
+ Runs an LLM agent against the API Testing Environment for all 3 tasks
+ (basic_validation -> edge_cases -> security_workflows) and emits the
+ mandatory [START]/[STEP]/[END] stdout format used by the OpenEnv judging
+ pipeline.
+
+ Required env vars (per OpenEnv submission spec):
+     API_BASE_URL   The OpenAI-compatible LLM endpoint
+     MODEL_NAME     The model identifier to use for inference
+     HF_TOKEN       Bearer token for the LLM endpoint (or API_KEY)
+
+ Optional env vars:
+     IMAGE_NAME             Docker image to spin up the env via from_docker_image()
+     LOCAL_IMAGE_NAME       Alias for IMAGE_NAME
+     ENV_BASE_URL           URL of an already-running env server (e.g. http://localhost:8000)
+     INFERENCE_TASKS        Comma-separated subset of tasks to run (default: all 3)
+     INFERENCE_MAX_STEPS    Override max steps per task
+     INFERENCE_TEMPERATURE  Default 0.4
+     INFERENCE_MAX_TOKENS   Default 4096 (plan completions need room for ~25 actions)
+
+ The script uses PLAN MODE: one LLM call per task produces a complete JSON
+ test plan, then the env executes each action sequentially. This matches the
+ GRPO training distribution and keeps total LLM cost to 3 calls per run, so
+ the script comfortably runs under 20 min on 2 vCPU / 8 GB RAM.
+
+ Usage:
+     # Local in-process (no Docker, fastest)
+     python inference.py
+
+     # Against a built docker image
+     IMAGE_NAME=api-testing-env:latest python inference.py
+
+     # Against an already running server
+     ENV_BASE_URL=http://localhost:8000 python inference.py
+
+     # Against a deployed HF Space
+     ENV_BASE_URL=https://your-user-api-testing-env.hf.space python inference.py
+ """
+
+ import json
+ import os
+ import sys
+ import time
+ import traceback
+ from typing import Any, Optional
+
+ # Make sibling modules importable when run from the repo root
+ _THIS_DIR = os.path.dirname(os.path.abspath(__file__))
+ if _THIS_DIR not in sys.path:
+     sys.path.insert(0, _THIS_DIR)
+
+ # Auto-load .env file if present (for local development).
+ # Judges set env vars directly, so this is harmless in production.
+ try:
+     from dotenv import load_dotenv
+     _env_path = os.path.join(_THIS_DIR, ".env")
+ _env_path = os.path.join(_THIS_DIR, ".env")
60
+ if os.path.exists(_env_path):
61
+ load_dotenv(_env_path)
62
+ except ImportError:
63
+ pass # python-dotenv is optional
64
+
65
+ from openai import OpenAI
66
+
67
+ from models import APITestAction, HTTPMethod # noqa: E402
68
+ from training.prompts import ( # noqa: E402
69
+ PLAN_SYSTEM_PROMPT,
70
+ format_plan_prompt,
71
+ parse_test_plan,
72
+ )
73
+
74
+
75
+ # ---------------------------------------------------------------------------
76
+ # Config (env vars per OpenEnv spec)
77
+ # ---------------------------------------------------------------------------
78
+
79
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
80
+ # Default model: must be available on the HuggingFace Inference Router.
81
+ # Llama-3.3-70B-Instruct is reliable, follows JSON instructions well, and free.
82
+ # Override via: MODEL_NAME=other/model python inference.py
83
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.3-70B-Instruct")
84
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
85
+
86
+ if not API_KEY:
87
+ print(
88
+ "[ERROR] No HF_TOKEN or API_KEY found in environment.\n"
89
+ " Set one of:\n"
90
+ " export HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx\n"
91
+ " Or create a .env file in this directory with:\n"
92
+ " HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx\n"
93
+ " Get a token from: https://huggingface.co/settings/tokens\n"
94
+ " Make sure it has 'Make calls to Inference Providers' permission.",
95
+ file=sys.stderr,
96
+ )
97
+ sys.exit(1)
98
+
99
+ IMAGE_NAME = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
100
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL")
101
+
102
+ BENCHMARK = "api_testing_env"
103
+ DEFAULT_TASKS = ["basic_validation", "edge_cases", "security_workflows"]
104
+ TASKS = [t.strip() for t in os.getenv("INFERENCE_TASKS", ",".join(DEFAULT_TASKS)).split(",") if t.strip()]
105
+
106
+ TEMPERATURE = float(os.getenv("INFERENCE_TEMPERATURE", "0.4"))
107
+ MAX_TOKENS = int(os.getenv("INFERENCE_MAX_TOKENS", "4096"))
108
+ _MAX_STEPS_OVERRIDE = os.getenv("INFERENCE_MAX_STEPS")
109
+ MAX_STEPS_OVERRIDE: Optional[int] = int(_MAX_STEPS_OVERRIDE) if _MAX_STEPS_OVERRIDE else None
110
+
111
+
112
+ # ---------------------------------------------------------------------------
113
+ # Strict stdout logging β€” these line formats are checked by the judge
114
+ # ---------------------------------------------------------------------------
115
+
116
+ def log_start(task: str, env: str, model: str) -> None:
117
+ print(f"[START] task={task} env={env} model={model}", flush=True)
118
+
119
+
120
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
121
+ print(
122
+ f"[STEP] step={step} action={action} reward={reward:.2f} "
123
+ f"done={str(done).lower()} error={error if error else 'null'}",
124
+ flush=True,
125
+ )
126
+
127
+
128
+ def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
129
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
130
+ print(
131
+ f"[END] success={str(success).lower()} steps={steps} "
132
+ f"score={score:.3f} rewards={rewards_str}",
133
+ flush=True,
134
+ )
135
+
136
+
137
+ def _action_str(action: APITestAction) -> str:
138
+ """Compact human-readable action label for the [STEP] line."""
139
+ method = action.method.value if hasattr(action.method, "value") else str(action.method)
140
+ return f"{method}_{action.endpoint}"
141
+
142
+
143
+ # ---------------------------------------------------------------------------
144
+ # LLM call β€” plan mode (one completion per task)
145
+ # ---------------------------------------------------------------------------
146
+
147
+ def get_plan_from_llm(client: OpenAI, observation) -> str:
148
+ """Ask the LLM for a complete JSON test plan for this task.
149
+
150
+ Wraps the array in {"actions": [...]} so we can use OpenAI structured
151
+ output mode (`response_format={"type": "json_object"}`), which forces
152
+ the LLM to produce valid JSON. This is much more reliable than asking
153
+ for a raw JSON array.
154
+ """
155
+ user_prompt = format_plan_prompt(observation)
156
+
157
+ # Stronger system prompt for structured output mode
158
+ system_prompt = (
159
+ PLAN_SYSTEM_PROMPT
160
+ + "\n\nIMPORTANT: Output a JSON object with a single key 'actions' "
161
+ + "containing the array of actions:\n"
162
+ + '{"actions": [{"method": "GET", "endpoint": "/tasks", "headers": {}, '
163
+ + '"query_params": {}, "body": null, "expected_status": 200}, ...]}'
164
+ )
165
+
166
+ try:
167
+ completion = client.chat.completions.create(
168
+ model=MODEL_NAME,
169
+ messages=[
170
+ {"role": "system", "content": system_prompt},
171
+ {"role": "user", "content": user_prompt},
172
+ ],
173
+ temperature=TEMPERATURE,
174
+ max_tokens=MAX_TOKENS,
175
+ response_format={"type": "json_object"}, # forces valid JSON
176
+ stream=False,
177
+ )
178
+ text = (completion.choices[0].message.content or "").strip()
179
+ print(f"[DEBUG] LLM response length: {len(text)} chars", flush=True)
180
+ if len(text) > 0:
181
+ preview = text[:300].replace("\n", " ")
182
+ print(f"[DEBUG] LLM response preview: {preview}...", flush=True)
183
+ else:
184
+ print(f"[DEBUG] LLM returned EMPTY string", flush=True)
185
+ if hasattr(completion, "choices") and completion.choices:
186
+ finish_reason = getattr(completion.choices[0], "finish_reason", None)
187
+ print(f"[DEBUG] finish_reason: {finish_reason}", flush=True)
188
+ return text
189
+ except Exception as exc: # noqa: BLE001
190
+ print(f"[DEBUG] structured-output call failed ({type(exc).__name__}: {exc}), retrying without response_format...", flush=True)
191
+ # Some providers don't support response_format β€” fall back to plain text
192
+ try:
193
+ completion = client.chat.completions.create(
194
+ model=MODEL_NAME,
195
+ messages=[
196
+ {"role": "system", "content": PLAN_SYSTEM_PROMPT},
197
+ {"role": "user", "content": user_prompt},
198
+ ],
199
+ temperature=TEMPERATURE,
200
+ max_tokens=MAX_TOKENS,
201
+ stream=False,
202
+ )
203
+ text = (completion.choices[0].message.content or "").strip()
204
+ print(f"[DEBUG] fallback LLM response length: {len(text)} chars", flush=True)
205
+ return text
206
+ except Exception as exc2: # noqa: BLE001
207
+ print(f"[DEBUG] fallback LLM call failed: {type(exc2).__name__}: {exc2}", flush=True)
208
+ return ""
209
+
210
+
211
+ # ---------------------------------------------------------------------------
212
+ # Per-task scoring helper β€” keeps the score in [0, 1]
213
+ # ---------------------------------------------------------------------------
214
+
215
+ def compute_task_score(state, total_step_reward: float) -> float:
216
+ """Combine grader signals into a single normalized score in [0, 1].
217
+
218
+ The server already runs `TaskGrader.grade(...)` at episode end and adds
219
+ that score (already in [0, 1]) on top of the last step reward. We do
220
+ NOT trust the raw step rewards β€” those are sums of partial signals and
221
+ can exceed 1.0. Instead we derive the score from the published state:
222
+ score = 0.7 * (bugs_found / total_bugs) + 0.3 * (coverage_pct / 100)
223
+ which is bounded in [0, 1] and rewards both finding bugs and coverage.
224
+ """
225
+ bugs_found = getattr(state, "bugs_found", 0) or 0
226
+ total_bugs = getattr(state, "total_bugs", 0) or 0
227
+ coverage_pct = getattr(state, "coverage_pct", 0.0) or 0.0
228
+
229
+ bug_ratio = (bugs_found / total_bugs) if total_bugs > 0 else 0.0
230
+ coverage_ratio = max(0.0, min(1.0, coverage_pct / 100.0))
231
+
232
+ score = 0.70 * bug_ratio + 0.30 * coverage_ratio
233
+ return max(0.0, min(1.0, score))
234
+
235
+
236
+ # ---------------------------------------------------------------------------
237
+ # Environment connector β€” supports docker / remote / in-process
238
+ # ---------------------------------------------------------------------------
239
+
240
+ class _EnvHandle:
241
+ """Thin wrapper that exposes a uniform reset/step/state/close API.
242
+
243
+ Three modes, picked automatically:
244
+ 1. IMAGE_NAME set -> APITestEnv.from_docker_image(IMAGE_NAME)
245
+ 2. ENV_BASE_URL set -> APITestEnv(base_url=ENV_BASE_URL)
246
+ 3. neither set (default) -> APITestEnvironment() in-process
247
+ """
248
+
249
+ def __init__(self):
250
+ self._mode: str = ""
251
+ self._client = None # remote/docker client
252
+ self._env = None # in-process env
253
+
254
+ def open(self):
255
+ if IMAGE_NAME:
256
+ from client import APITestEnv
257
+ self._mode = "docker"
258
+ self._client = APITestEnv.from_docker_image(IMAGE_NAME)
259
+ elif ENV_BASE_URL:
260
+ from client import APITestEnv
261
+ self._mode = "remote"
262
+ self._client = APITestEnv(base_url=ENV_BASE_URL)
263
+ if hasattr(self._client, "connect"):
264
+ self._client.connect()
265
+ else:
266
+ from server.environment import APITestEnvironment
267
+ self._mode = "local"
268
+ self._env = APITestEnvironment()
269
+ return self
270
+
271
+ @property
272
+ def mode(self) -> str:
273
+ return self._mode
274
+
275
+ def reset(self, task_id: str, seed: int = 42):
276
+ if self._mode in ("docker", "remote"):
277
+ result = self._client.reset(task_id=task_id, seed=seed)
278
+ return result.observation, result
279
+ obs = self._env.reset(seed=seed, task_id=task_id)
280
+ return obs, None
281
+
282
+ def step(self, action: APITestAction):
283
+ if self._mode in ("docker", "remote"):
284
+ result = self._client.step(action)
285
+ return result.observation, result.reward or 0.0, result.done
286
+ obs = self._env.step(action)
287
+ return obs, (obs.reward or 0.0), obs.done
288
+
289
+ def state(self):
290
+ if self._mode in ("docker", "remote"):
291
+ return self._client.state()
292
+ return self._env.state
293
+
294
+ def close(self):
295
+ try:
296
+ if self._client is not None and hasattr(self._client, "close"):
297
+ self._client.close()
298
+ except Exception as exc: # noqa: BLE001
299
+ print(f"[DEBUG] env close error: {exc}", flush=True)
300
+
301
+
302
+ # ---------------------------------------------------------------------------
303
+ # One full episode (one task) -> emits [START] / [STEP]* / [END]
304
+ # ---------------------------------------------------------------------------
305
+
306
+ def run_task(env: _EnvHandle, client: OpenAI, task_id: str, seed: int = 42) -> dict:
307
+ rewards: list[float] = []
308
+ steps_taken = 0
309
+ last_error: Optional[str] = None
310
+ score = 0.0
311
+
312
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
313
+
314
+ try:
315
+ obs, _ = env.reset(task_id=task_id, seed=seed)
316
+ max_steps = MAX_STEPS_OVERRIDE or getattr(obs, "max_steps", 25)
317
+
318
+ # 1) Ask the LLM for a full plan
319
+ plan_text = get_plan_from_llm(client, obs)
320
+ actions = parse_test_plan(plan_text) if plan_text else []
321
+
322
+ # Fallback: if parser failed but we have text, try a more lenient parse
323
+ if not actions and plan_text:
324
+ print(f"[DEBUG] {task_id}: parse_test_plan returned 0, trying lenient parse...", flush=True)
325
+ try:
326
+ import json as _json, re as _re
327
+ # Try to find any JSON array of objects in the text
328
+ cleaned = plan_text
329
+ if "</think>" in cleaned:
330
+ cleaned = cleaned.split("</think>", 1)[-1]
331
+ # Find first [ and last ]
332
+ start = cleaned.find("[")
333
+ end = cleaned.rfind("]")
334
+ if start >= 0 and end > start:
335
+ arr_str = cleaned[start:end+1]
336
+ raw = _json.loads(arr_str)
337
+ if isinstance(raw, list):
338
+ from training.prompts import _dict_to_action
339
+ for item in raw:
340
+ if isinstance(item, dict) and "method" in item:
341
+ a = _dict_to_action(item)
342
+ if a:
343
+ actions.append(a)
344
+ print(f"[DEBUG] {task_id}: lenient parse recovered {len(actions)} actions", flush=True)
345
+ except Exception as exc:
346
+ print(f"[DEBUG] {task_id}: lenient parse failed: {exc}", flush=True)
347
+ if not actions:
348
+ last_error = "no_plan_parsed"
349
+ print(f"[DEBUG] {task_id}: model produced 0 valid actions", flush=True)
350
+
351
+ actions = actions[:max_steps]
352
+
353
+ # 2) Execute each action and emit one [STEP] line per env.step()
354
+ done = False
355
+ for i, action in enumerate(actions, start=1):
356
+ if done:
357
+ break
358
+ try:
359
+ obs, reward, done = env.step(action)
360
+ rewards.append(float(reward))
361
+ steps_taken = i
362
+ log_step(step=i, action=_action_str(action), reward=reward, done=done, error=None)
363
+ except Exception as exc: # noqa: BLE001
364
+ last_error = f"{type(exc).__name__}: {exc}"
365
+ rewards.append(0.0)
366
+ steps_taken = i
367
+ log_step(step=i, action=_action_str(action), reward=0.0, done=False, error=last_error)
368
+
369
+ # 3) Score from final state
370
+ try:
371
+ final_state = env.state()
372
+ score = compute_task_score(final_state, sum(rewards))
373
+ except Exception as exc: # noqa: BLE001
374
+ last_error = last_error or f"state_error: {exc}"
375
+ score = 0.0
376
+
377
+ except Exception as exc: # noqa: BLE001
378
+ last_error = f"{type(exc).__name__}: {exc}"
379
+ traceback.print_exc()
380
+
381
+ success = score >= 0.20 # any meaningful progress counts as a successful episode
382
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
383
+
384
+ return {
385
+ "task_id": task_id,
386
+ "success": success,
387
+ "steps": steps_taken,
388
+ "score": score,
389
+ "rewards": rewards,
390
+ "error": last_error,
391
+ }
392
+
393
+
394
+ # ---------------------------------------------------------------------------
395
+ # Main β€” runs all 3 tasks sequentially against ONE env handle
396
+ # ---------------------------------------------------------------------------
397
+
398
+ def main() -> None:
399
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
400
+
401
+ print(
402
+ f"[DEBUG] inference.py starting | model={MODEL_NAME} | "
403
+ f"base_url={API_BASE_URL} | tasks={TASKS}",
404
+ flush=True,
405
+ )
406
+
407
+ env = _EnvHandle().open()
408
+ print(f"[DEBUG] env mode={env.mode}", flush=True)
409
+
410
+ summary: list[dict] = []
411
+ t0 = time.time()
412
+ try:
413
+ for task_id in TASKS:
414
+ result = run_task(env, client, task_id=task_id, seed=42)
415
+ summary.append(result)
416
+ finally:
417
+ env.close()
418
+
419
+ elapsed = time.time() - t0
420
+ avg_score = sum(r["score"] for r in summary) / max(len(summary), 1)
421
+ print(
422
+ f"[DEBUG] inference.py finished in {elapsed:.1f}s | "
423
+ f"avg_score={avg_score:.3f}",
424
+ flush=True,
425
+ )
426
+ print("[DEBUG] per-task scores: " + json.dumps(
427
+ {r["task_id"]: round(r["score"], 3) for r in summary}
428
+ ), flush=True)
429
+
430
+
431
+ if __name__ == "__main__":
432
+ main()
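The scoring rule in `compute_task_score` can be exercised on its own, without any env types. A minimal sketch of the same formula (the `SimpleNamespace` stand-in for the state object is an assumption for illustration):

```python
from types import SimpleNamespace

def score(state) -> float:
    # 70% weight on bug discovery, 30% on endpoint coverage, clamped to [0, 1]
    bugs_found = getattr(state, "bugs_found", 0) or 0
    total_bugs = getattr(state, "total_bugs", 0) or 0
    coverage_pct = getattr(state, "coverage_pct", 0.0) or 0.0
    bug_ratio = (bugs_found / total_bugs) if total_bugs > 0 else 0.0
    coverage_ratio = max(0.0, min(1.0, coverage_pct / 100.0))
    return max(0.0, min(1.0, 0.70 * bug_ratio + 0.30 * coverage_ratio))

# 3/5 bugs found with 80% coverage: 0.7*0.6 + 0.3*0.8, roughly 0.66
demo = SimpleNamespace(bugs_found=3, total_bugs=5, coverage_pct=80.0)
print(f"{score(demo):.2f}")
```

Because `total_bugs == 0` yields a zero bug ratio rather than a division error, a task with no planted bugs is scored purely on coverage.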
models.py ADDED
@@ -0,0 +1,110 @@
+ """
+ Data models for the API Testing Environment.
+ 
+ Defines Action, Observation, State for API integration testing training.
+ An AI agent learns to test REST APIs intelligently — discovering endpoints,
+ crafting requests, validating responses, finding bugs, and handling edge cases.
+ """
+ 
+ from enum import Enum
+ from typing import Any, Optional
+ 
+ from pydantic import Field
+ 
+ from openenv.core.env_server.types import Action, Observation, State
+ 
+ 
+ class HTTPMethod(str, Enum):
+     GET = "GET"
+     POST = "POST"
+     PUT = "PUT"
+     DELETE = "DELETE"
+     PATCH = "PATCH"
+ 
+ 
+ class BugSeverity(str, Enum):
+     EASY = "easy"
+     MEDIUM = "medium"
+     HARD = "hard"
+ 
+ 
+ class APITestAction(Action):
+     """What the agent sends each step — an HTTP request to test the API."""
+ 
+     method: HTTPMethod = Field(..., description="HTTP method")
+     endpoint: str = Field(..., min_length=1, description="API endpoint path, e.g. /tasks, /users/1")
+     headers: dict[str, str] = Field(default_factory=dict, description="Request headers")
+     query_params: dict[str, Any] = Field(default_factory=dict, description="URL query parameters")
+     body: Optional[dict[str, Any]] = Field(default=None, description="Request JSON body")
+     expected_status: Optional[int] = Field(
+         default=None,
+         description="What the agent expects the status code to be (used for bug detection)",
+     )
+ 
+ 
+ class EndpointInfo(Action):
+     """Information about a single API endpoint from the spec."""
+ 
+     method: str = ""
+     path: str = ""
+     summary: str = ""
+     parameters: list[dict[str, Any]] = Field(default_factory=list)
+     request_body_schema: Optional[dict[str, Any]] = None
+     response_schema: Optional[dict[str, Any]] = None
+ 
+ 
+ class APITestObservation(Observation):
+     """What the agent sees after each step."""
+ 
+     # API spec info (provided on reset, updated each step)
+     available_endpoints: list[dict[str, Any]] = Field(
+         default_factory=list, description="Available API endpoints from the spec"
+     )
+ 
+     # Response from last request
+     status_code: int = Field(default=0, description="HTTP status code of the response")
+     response_body: Any = Field(default=None, description="Response body (JSON or text)")
+     response_headers: dict[str, str] = Field(default_factory=dict, description="Response headers")
+     response_time_ms: float = Field(default=0.0, description="Response time in milliseconds")
+ 
+     # Feedback
+     feedback: str = Field(default="", description="Human-readable feedback about the last action")
+     bugs_found_so_far: int = Field(default=0, description="Number of bugs found so far")
+     coverage_summary: dict[str, Any] = Field(
+         default_factory=dict,
+         description="Coverage stats: endpoints_tested, methods_used, status_codes_seen",
+     )
+ 
+     # Context from prior steps
+     known_resource_ids: dict[str, list[Any]] = Field(
+         default_factory=dict,
+         description="Resource IDs created by POST requests, keyed by resource type",
+     )
+     auth_tokens: dict[str, str] = Field(
+         default_factory=dict,
+         description="Available auth tokens for different users/roles",
+     )
+ 
+     # Task info
+     task_id: str = Field(default="", description="Current task identifier")
+     task_description: str = Field(default="", description="Description of the current task")
+     steps_taken: int = Field(default=0, description="Steps taken in this episode")
+     max_steps: int = Field(default=30, description="Maximum steps per episode")
+ 
+ 
+ class APITestState(State):
+     """Episode metadata — internal state exposed via state() endpoint."""
+ 
+     task_id: str = ""
+     task_description: str = ""
+     difficulty: str = "easy"
+     steps_taken: int = 0
+     max_steps: int = 30
+     bugs_found: int = 0
+     total_bugs: int = 0
+     bugs_found_ids: list[str] = Field(default_factory=list)
+     coverage_pct: float = 0.0
+     endpoints_tested: int = 0
+     total_endpoints: int = 0
+     current_score: float = 0.0
+     cumulative_reward: float = 0.0
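The action schema above boils down to a method enum plus endpoint, headers, params, body, and expected-status fields. A dependency-free sketch of the same shape (plain stdlib stand-ins, not the actual pydantic/openenv models):

```python
from dataclasses import dataclass, field, asdict
from enum import Enum
from typing import Any, Optional

class HTTPMethod(str, Enum):
    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"

@dataclass
class APITestActionSketch:
    # Mirrors APITestAction's fields; defaults match the Field(default_factory=...) pattern
    method: HTTPMethod
    endpoint: str
    headers: dict[str, str] = field(default_factory=dict)
    query_params: dict[str, Any] = field(default_factory=dict)
    body: Optional[dict[str, Any]] = None
    expected_status: Optional[int] = None

a = APITestActionSketch(method=HTTPMethod.POST, endpoint="/tasks",
                        body={"title": "x"}, expected_status=201)
payload = asdict(a)
# str-valued enum members compare equal to their string value
assert payload["method"] == "POST"
assert payload["expected_status"] == 201
```

Because `HTTPMethod` subclasses `str`, serialized actions carry plain strings like `"POST"`, which is what the JSON test plans in `inference.py` produce.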
openenv.yaml ADDED
@@ -0,0 +1,6 @@
+ spec_version: 1
+ name: api_testing_env
+ type: space
+ runtime: fastapi
+ app: server.app:app
+ port: 8000
openenv_api_testing.egg-info/PKG-INFO ADDED
@@ -0,0 +1,19 @@
+ Metadata-Version: 2.4
+ Name: openenv-api-testing
+ Version: 0.1.0
+ Summary: RL environment for intelligent API integration testing — train agents to find bugs in REST APIs
+ Requires-Python: >=3.10
+ Requires-Dist: openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.1
+ Requires-Dist: fastapi>=0.104.0
+ Requires-Dist: uvicorn>=0.24.0
+ Requires-Dist: httpx>=0.25.0
+ Requires-Dist: pydantic>=2.0.0
+ Provides-Extra: dev
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
+ Provides-Extra: train
+ Requires-Dist: trl[vllm]>=0.29.0; extra == "train"
+ Requires-Dist: torch>=2.8.0; extra == "train"
+ Requires-Dist: peft; extra == "train"
+ Requires-Dist: transformers; extra == "train"
+ Requires-Dist: datasets; extra == "train"
openenv_api_testing.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,26 @@
+ README.md
+ pyproject.toml
+ ./__init__.py
+ ./baseline.py
+ ./client.py
+ ./models.py
+ openenv_api_testing.egg-info/PKG-INFO
+ openenv_api_testing.egg-info/SOURCES.txt
+ openenv_api_testing.egg-info/dependency_links.txt
+ openenv_api_testing.egg-info/entry_points.txt
+ openenv_api_testing.egg-info/requires.txt
+ openenv_api_testing.egg-info/top_level.txt
+ server/__init__.py
+ server/app.py
+ server/bug_detector.py
+ server/environment.py
+ server/graders.py
+ server/reward.py
+ server/buggy_api/__init__.py
+ server/buggy_api/database.py
+ server/buggy_api/main.py
+ server/buggy_api/models.py
+ server/buggy_api/routes/__init__.py
+ server/buggy_api/routes/auth.py
+ server/buggy_api/routes/tasks.py
+ server/buggy_api/routes/users.py
openenv_api_testing.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+ 
openenv_api_testing.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ server = api_testing_env.server.app:main
openenv_api_testing.egg-info/requires.txt ADDED
@@ -0,0 +1,16 @@
+ openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.1
+ fastapi>=0.104.0
+ uvicorn>=0.24.0
+ httpx>=0.25.0
+ pydantic>=2.0.0
+ 
+ [dev]
+ pytest>=8.0.0
+ pytest-cov>=4.0.0
+ 
+ [train]
+ trl[vllm]>=0.29.0
+ torch>=2.8.0
+ peft
+ transformers
+ datasets
openenv_api_testing.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ api_testing_env
pyproject.toml ADDED
@@ -0,0 +1,60 @@
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+ 
+ [project]
+ name = "openenv-api-testing"
+ version = "0.1.0"
+ description = "RL environment for intelligent API integration testing — train agents to find bugs in REST APIs"
+ requires-python = ">=3.10"
+ dependencies = [
+     "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.1",
+     "fastapi>=0.104.0",
+     "uvicorn>=0.24.0",
+     "httpx>=0.25.0",
+     "pydantic>=2.0.0",
+     "openai>=1.40.0",
+     "gradio>=5.0.0",
+ ]
+ 
+ [project.optional-dependencies]
+ ui = [
+     "gradio>=5.0.0",
+ ]
+ dev = [
+     "pytest>=8.0.0",
+     "pytest-cov>=4.0.0",
+ ]
+ train = [
+     "trl>=0.15.0",
+     "torch>=2.1.0",
+     "peft>=0.7.0",
+     "transformers>=4.40.0",
+     "datasets>=2.16.0",
+     "wandb>=0.16.0",
+     "huggingface-hub>=0.20.0",
+     "matplotlib>=3.8.0",
+ ]
+ 
+ [project.scripts]
+ server = "api_testing_env.server.app:main"
+ 
+ [tool.uv]
+ package = false
+ 
+ [tool.setuptools]
+ include-package-data = true
+ packages = [
+     "api_testing_env",
+     "api_testing_env.server",
+     "api_testing_env.server.buggy_api",
+     "api_testing_env.server.buggy_api.routes",
+     "api_testing_env.training",
+ ]
+ 
+ [tool.setuptools.package-dir]
+ api_testing_env = "."
+ "api_testing_env.server" = "server"
+ "api_testing_env.server.buggy_api" = "server/buggy_api"
+ "api_testing_env.server.buggy_api.routes" = "server/buggy_api/routes"
+ "api_testing_env.training" = "training"
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ # Core dependencies
+ openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.1
+ fastapi>=0.104.0
+ uvicorn>=0.24.0
+ httpx>=0.25.0
+ pydantic>=2.0.0,<2.12
+ 
+ # Training dependencies
+ # NOTE: PyTorch is NOT listed here — it must be installed separately
+ # with the correct CUDA version. See setup.sh or run:
+ #   pip install torch --index-url https://download.pytorch.org/whl/cu121
+ trl>=0.15.0
+ peft>=0.7.0
+ transformers>=4.40.0
+ datasets>=2.16.0
+ 
+ # Weights & Biases (optional but recommended)
+ wandb>=0.16.0
+ 
+ # HuggingFace Hub (for model push)
+ huggingface-hub>=0.20.0
+ 
+ # Plots and metrics
+ matplotlib>=3.8.0
+ 
+ # UI
+ gradio>=5.0.0
server/__init__.py ADDED
File without changes
server/app.py ADDED
@@ -0,0 +1,135 @@
+ """
+ FastAPI application for the API Testing Environment.
+ 
+ Endpoints:
+ - POST /reset: Reset the environment
+ - POST /step: Execute an action
+ - GET /state: Get current environment state
+ - GET /schema: Get action/observation schemas
+ - WS /ws: WebSocket endpoint for persistent sessions
+ - GET /: Info page
+ 
+ Usage:
+     uvicorn server.app:app --host 0.0.0.0 --port 8000
+ """
+ 
+ import os
+ import logging
+ 
+ try:
+     from openenv.core.env_server.http_server import create_app
+     from ..models import APITestAction, APITestObservation
+     from .environment import APITestEnvironment
+ except ImportError:
+     from openenv.core.env_server.http_server import create_app
+     from models import APITestAction, APITestObservation
+     from server.environment import APITestEnvironment
+ 
+ from fastapi.responses import RedirectResponse
+ 
+ logger = logging.getLogger(__name__)
+ 
+ app = create_app(
+     APITestEnvironment,
+     APITestAction,
+     APITestObservation,
+     env_name="api_testing_env",
+     max_concurrent_envs=int(os.environ.get("MAX_ENVS", "1")),
+ )
+ 
+ # Track whether the Gradio UI is available so root can redirect to it
+ _GRADIO_MOUNTED = False
+ 
+ 
+ @app.get("/info")
+ async def info():
+     """JSON info about the environment (replaces the old `/` JSON endpoint)."""
+     return {
+         "name": "API Testing Environment",
+         "description": "An OpenEnv RL environment where an AI agent learns to test REST APIs intelligently",
+         "tasks": ["basic_validation", "edge_cases", "security_workflows"],
+         "ui": "/ui",
+         "docs": "/docs",
+         "schema": "/schema",
+     }
+ 
+ 
+ @app.get("/tasks")
+ async def list_tasks():
+     """List available tasks with descriptions."""
+     from .environment import TASKS
+     return {
+         task_id: {
+             "description": task["description"],
+             "difficulty": task["difficulty"],
+             "max_steps": task["max_steps"],
+             "total_bugs": task["total_bugs"],
+         }
+         for task_id, task in TASKS.items()
+     }
+ 
+ 
+ # ---------------------------------------------------------------------------
+ # Mount Gradio UI at /ui (only if gradio is installed and ENABLE_WEB_INTERFACE)
+ # ---------------------------------------------------------------------------
+ if os.environ.get("ENABLE_WEB_INTERFACE", "true").lower() in ("1", "true", "yes"):
+     try:
+         import gradio as gr  # type: ignore
+         # Make the repo root importable so gradio_app's `from models import ...` works
+         import sys
+         _REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+         if _REPO_ROOT not in sys.path:
+             sys.path.insert(0, _REPO_ROOT)
+         from gradio_app import build_ui  # type: ignore
+ 
+         _gradio_ui = build_ui()
+         app = gr.mount_gradio_app(app, _gradio_ui, path="/ui")
+         _GRADIO_MOUNTED = True
+         logger.info("Gradio UI mounted at /ui")
+     except Exception as exc:  # noqa: BLE001
+         logger.warning(f"Skipping Gradio mount ({type(exc).__name__}: {exc})")
+ 
+ 
+ # ---------------------------------------------------------------------------
+ # Root redirect: send visitors to the Gradio UI if mounted, else to JSON info
+ # ---------------------------------------------------------------------------
+ @app.get("/", include_in_schema=False)
+ async def root_redirect():
+     """Redirect / to the Gradio UI when available, otherwise to /info JSON."""
+     if _GRADIO_MOUNTED:
+         return RedirectResponse(url="/ui", status_code=307)
+     return RedirectResponse(url="/info", status_code=307)
+ 
+ 
+ def main(host: str = None, port: int = None):
+     """Entry point for `uv run server` and `python -m server.app`.
+ 
+     When invoked from the CLI without args, parses argv for --host / --port.
+     """
+     import uvicorn
+ 
+     if host is None or port is None:
+         import argparse
+         parser = argparse.ArgumentParser(description="API Testing Environment server")
+         parser.add_argument("--host", default="0.0.0.0")
+         parser.add_argument("--port", type=int, default=None)
+         args, _ = parser.parse_known_args()
+         host = host or args.host
+         port = port or args.port
+ 
+     if port is None:
+         port = int(os.environ.get("PORT", "8000"))
+ 
+     logging.basicConfig(
+         level=logging.INFO,
+         format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
+     )
+     logging.getLogger("httpx").setLevel(logging.WARNING)
+     logging.getLogger("httpcore").setLevel(logging.WARNING)
+     logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
+ 
+     uvicorn.run(app, host=host, port=port)
+ 
+ 
+ if __name__ == "__main__":
+     main()
server/bug_detector.py ADDED
@@ -0,0 +1,430 @@
1
+ """
2
+ Bug detection logic β€” checks if the agent's action/response pair reveals a planted bug.
3
+
4
+ Each bug has:
5
+ - A unique ID
6
+ - A severity level (easy/medium/hard)
7
+ - A detection function that checks action + response
8
+ """
9
+
10
+ from dataclasses import dataclass
11
+ from typing import Any, Callable, Optional
12
+ import re
13
+
14
+
15
+ @dataclass
16
+ class Bug:
17
+ id: str
18
+ severity: str # "easy", "medium", "hard"
19
+ description: str
20
+ category: str # "status_code", "validation", "security", "data_integrity"
21
+ owasp: str = "" # OWASP API Security Top 10 (2023) category
22
+ recommendation: str = "" # Fix recommendation for bug bounty reports
23
+
24
+
25
+ @dataclass
26
+ class BugDetection:
27
+ bug: Bug
28
+ evidence: str # Human-readable explanation of how the bug was detected
29
+
30
+
31
+ class BugDetector:
32
+ """Detects planted bugs based on agent actions and API responses."""
33
+
34
+ def __init__(self, task_id: str):
35
+ self.task_id = task_id
36
+ self._build_bug_registry()
37
+
38
+ def _build_bug_registry(self):
39
+ """Define all bugs with their detection logic."""
40
+ self.bugs: dict[str, Bug] = {}
41
+ self.detectors: dict[str, Callable] = {}
42
+
43
+ # === EASY BUGS ===
44
+
45
+ self._register_bug(
46
+ Bug("BUG_TASK_01", "easy",
47
+ "GET /tasks/{id} returns 200 with null for non-existent task",
48
+ "status_code",
49
+ owasp="API8:2023 Security Misconfiguration",
50
+ recommendation="Return 404 Not Found for non-existent resources"),
51
+ self._detect_null_response_for_missing_task,
52
+ )
53
+ self._register_bug(
54
+ Bug("BUG_TASK_02", "easy",
55
+ "POST /tasks with missing title returns 500 instead of 400/422",
56
+ "validation",
57
+ owasp="API8:2023 Security Misconfiguration",
58
+ recommendation="Validate required fields and return 400/422 with descriptive error"),
59
+ self._detect_missing_field_500,
60
+ )
61
+ self._register_bug(
62
+ Bug("BUG_TASK_03", "easy",
63
+ "GET /tasks?page=-1 returns 200 instead of 400",
64
+ "validation",
65
+ owasp="API8:2023 Security Misconfiguration",
66
+ recommendation="Validate pagination parameters: page >= 1, limit > 0"),
67
+ self._detect_negative_page,
68
+ )
69
+
70
+ # === MEDIUM BUGS ===
71
+
72
+ self._register_bug(
73
+ Bug("BUG_TASK_04", "medium",
74
+ "PUT /tasks/{id} accepts invalid email format for assignee_email",
75
+ "validation",
76
+ owasp="API8:2023 Security Misconfiguration",
77
+ recommendation="Validate email format with regex before accepting"),
78
+ self._detect_invalid_email_accepted,
79
+ )
80
+ self._register_bug(
81
+ Bug("BUG_TASK_05", "medium",
82
+ "DELETE /tasks/{id} returns 200 for non-existent task",
83
+ "status_code",
84
+ owasp="API8:2023 Security Misconfiguration",
85
+ recommendation="Check resource existence before deletion, return 404 if missing"),
86
+ self._detect_delete_nonexistent_200,
87
+ )
88
+ self._register_bug(
89
+ Bug("BUG_TASK_06", "medium",
90
+ "GET /tasks?limit=999999 has no pagination cap",
91
+ "validation",
92
+ owasp="API4:2023 Unrestricted Resource Consumption",
93
+ recommendation="Cap pagination limit at 100, reject values above maximum"),
94
+ self._detect_no_pagination_cap,
95
+ )
96
+ self._register_bug(
97
+ Bug("BUG_USER_01", "medium",
98
+ "POST /users accepts invalid email format",
99
+ "validation",
100
+ owasp="API8:2023 Security Misconfiguration",
101
+ recommendation="Validate email format server-side before creating user"),
102
+ self._detect_user_invalid_email,
103
+ )
104
+ self._register_bug(
105
+ Bug("BUG_USER_02", "medium",
106
+ "POST /users response exposes password hash",
107
+ "security",
108
+ owasp="API3:2023 Broken Object Property Level Authorization",
109
+ recommendation="Never return sensitive fields (password_hash) in API responses"),
110
+ self._detect_password_hash_exposed,
111
+ )
112
+ self._register_bug(
113
+ Bug("BUG_AUTH_02", "medium",
114
+ "Login with empty password succeeds",
115
+ "security",
116
+ owasp="API2:2023 Broken Authentication",
117
+ recommendation="Validate password is non-empty and verify against stored hash"),
118
+ self._detect_empty_password_login,
119
+ )
120
+
121
+ # === HARD BUGS ===
122
+
123
+ self._register_bug(
124
+ Bug("BUG_TASK_07", "hard",
125
+ "BOLA: User A can access User B's tasks without authorization check",
126
+ "security",
127
+ owasp="API1:2023 Broken Object Level Authorization",
128
+ recommendation="Verify resource ownership: check task.owner_id matches authenticated user"),
129
+ self._detect_bola,
130
+ )
131
+ self._register_bug(
132
+ Bug("BUG_TASK_08", "hard",
133
+ "POST /tasks with very long title (>5000 chars) causes 500",
134
+ "validation",
135
+ owasp="API4:2023 Unrestricted Resource Consumption",
136
+ recommendation="Add input length validation: title max 200 chars"),
137
+ self._detect_long_input_crash,
138
+ )
139
+ self._register_bug(
140
+ Bug("BUG_TASK_09", "hard",
141
+ "SQL injection payload in title is stored verbatim (content injection)",
142
+ "security",
143
+ owasp="API8:2023 Security Misconfiguration",
144
+ recommendation="Sanitize user input before storage, escape HTML/SQL special characters"),
145
+ self._detect_content_injection,
146
+ )
147
+ self._register_bug(
148
+ Bug("BUG_AUTH_01", "hard",
149
+ "Auth tokens not user-scoped: User A's token can modify User B's tasks",
150
+ "security",
151
+ owasp="API1:2023 Broken Object Level Authorization",
152
+ recommendation="Enforce ownership check on all write operations (PUT/DELETE)"),
153
+ self._detect_broken_auth,
154
+ )
155
+
156
+ def _register_bug(self, bug: Bug, detector: Callable):
157
+ self.bugs[bug.id] = bug
158
+ self.detectors[bug.id] = detector
159
+
160
+ def get_bugs_for_task(self) -> list[Bug]:
161
+ """Return bugs relevant to the current task."""
162
+ if self.task_id == "basic_validation":
163
+ return [self.bugs[bid] for bid in ["BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03"]]
164
+ elif self.task_id == "edge_cases":
165
+ return [
166
+ self.bugs[bid]
167
+ for bid in [
168
+ "BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03",
169
+ "BUG_TASK_04", "BUG_TASK_05", "BUG_TASK_06",
170
+ "BUG_USER_01", "BUG_USER_02", "BUG_AUTH_02",
171
+ ]
172
+ ]
173
+ else: # security_workflows
174
+ return list(self.bugs.values())
175
+
176
+ def check(
177
+ self,
178
+ method: str,
179
+ endpoint: str,
180
+ headers: dict,
181
+ query_params: dict,
182
+ body: Optional[dict],
183
+ expected_status: Optional[int],
184
+ response_status: int,
185
+ response_body: Any,
186
+ action_history: list[dict],
187
+ found_bugs: set[str],
188
+ ) -> Optional[BugDetection]:
189
+ """Check if this action/response reveals a bug.
190
+
191
+ Returns the first new bug detected, or None.
192
+ """
193
+ ctx = {
194
+ "method": method.upper(),
195
+ "endpoint": endpoint,
196
+ "headers": headers,
197
+ "query_params": query_params,
198
+ "body": body,
199
+ "expected_status": expected_status,
200
+ "response_status": response_status,
201
+ "response_body": response_body,
202
+ "action_history": action_history,
203
+ }
204
+
205
+ for bug_id, detector in self.detectors.items():
206
+ if bug_id in found_bugs:
207
+ continue
208
+ # Only check bugs relevant to this task
209
+ task_bugs = {b.id for b in self.get_bugs_for_task()}
210
+ if bug_id not in task_bugs:
211
+ continue
212
+ result = detector(ctx)
213
+ if result:
214
+ return BugDetection(bug=self.bugs[bug_id], evidence=result)
215
+
216
+ return None
217
+
218
+ # === DETECTION FUNCTIONS ===
219
+
220
+ def _detect_null_response_for_missing_task(self, ctx: dict) -> Optional[str]:
221
+ if (
222
+ ctx["method"] == "GET"
223
+ and re.match(r"^/tasks/\d+$", ctx["endpoint"])
224
+ and ctx["response_status"] == 200
225
+ and ctx["response_body"] is None
226
+ ):
227
+ task_id = ctx["endpoint"].split("/")[-1]
228
+ return f"GET /tasks/{task_id} returned 200 with null body β€” should be 404"
229
+ return None
230
+
231
+ def _detect_missing_field_500(self, ctx: dict) -> Optional[str]:
232
+ if (
233
+ ctx["method"] == "POST"
234
+ and ctx["endpoint"] == "/tasks"
235
+ and ctx["response_status"] == 500
236
+ and ctx["body"] is not None
237
+ and "title" not in ctx["body"]
238
+ ):
239
+ return "POST /tasks with missing 'title' returned 500 β€” should be 400 or 422"
240
+ return None
241
+
242
+ def _detect_negative_page(self, ctx: dict) -> Optional[str]:
243
+ if (
244
+ ctx["method"] == "GET"
245
+ and ctx["endpoint"] == "/tasks"
246
+ and ctx["query_params"].get("page") is not None
247
+ ):
248
+ page = ctx["query_params"]["page"]
249
+ try:
250
+ page = int(page)
251
+ except (ValueError, TypeError):
252
+ return None
253
+ if page < 1 and ctx["response_status"] == 200:
254
+ return f"GET /tasks?page={page} returned 200 β€” should be 400 for invalid page"
255
+ return None
256
+
257
+ def _detect_invalid_email_accepted(self, ctx: dict) -> Optional[str]:
258
+ if (
259
+ ctx["method"] == "PUT"
260
+ and re.match(r"^/tasks/\d+$", ctx["endpoint"])
261
+ and ctx["body"]
262
+ and "assignee_email" in ctx["body"]
263
+ and ctx["response_status"] in (200, 201)
264
+ ):
265
+ email = ctx["body"]["assignee_email"]
266
+ if email and not re.match(r"^[^@]+@[^@]+\.[^@]+$", email):
267
+ return f"PUT accepted invalid email '{email}' without validation"
268
+ return None
269
+
270
+ def _detect_delete_nonexistent_200(self, ctx: dict) -> Optional[str]:
271
+ if (
272
+ ctx["method"] == "DELETE"
273
+ and re.match(r"^/tasks/\d+$", ctx["endpoint"])
274
+ and ctx["response_status"] == 200
275
+ ):
276
+ task_id = int(ctx["endpoint"].split("/")[-1])
277
+ # Check if this task was never created (ID > 1000 is a safe bet for non-existent)
278
+ if task_id > 100:
279
+ return f"DELETE /tasks/{task_id} returned 200 for non-existent task β€” should be 404"
280
+ return None
281
+
282
+ def _detect_no_pagination_cap(self, ctx: dict) -> Optional[str]:
283
+ if (
284
+ ctx["method"] == "GET"
285
+ and ctx["endpoint"] == "/tasks"
286
+ and ctx["response_status"] == 200
287
+ ):
288
+ limit = ctx["query_params"].get("limit")
289
+ if limit is not None:
290
+ try:
291
+ limit = int(limit)
292
+ except (ValueError, TypeError):
293
+ return None
294
+ if limit > 1000:
295
+ return f"GET /tasks?limit={limit} accepted without pagination cap β€” potential DoS"
296
+ return None
297
+
298
+ def _detect_user_invalid_email(self, ctx: dict) -> Optional[str]:
299
+ if (
300
+ ctx["method"] == "POST"
301
+ and ctx["endpoint"] == "/users"
302
+ and ctx["body"]
303
+ and "email" in ctx["body"]
304
+ and ctx["response_status"] == 201
305
+ ):
306
+ email = ctx["body"]["email"]
307
+ if email and not re.match(r"^[^@]+@[^@]+\.[^@]+$", email):
308
+ return f"POST /users accepted invalid email '{email}'"
309
+ return None
310
+
311
+ def _detect_password_hash_exposed(self, ctx: dict) -> Optional[str]:
312
+ if (
313
+ ctx["method"] == "POST"
314
+ and ctx["endpoint"] == "/users"
315
+ and ctx["response_status"] == 201
316
+ and isinstance(ctx["response_body"], dict)
317
+ ):
318
+ if "password_hash" in ctx["response_body"]:
319
+ return "POST /users response exposes password_hash field β€” security vulnerability"
320
+ return None
321
+
322
+ def _detect_empty_password_login(self, ctx: dict) -> Optional[str]:
323
+ if (
324
+ ctx["method"] == "POST"
325
+ and ctx["endpoint"] == "/auth/login"
326
+ and ctx["body"]
327
+ and ctx["response_status"] == 200
328
+ ):
329
+ password = ctx["body"].get("password", "NOTEMPTY")
330
+ if password == "" or password is None:
331
+ return "Login with empty password succeeded β€” authentication bypass"
332
+ return None
333
+
334
+ def _detect_bola(self, ctx: dict) -> Optional[str]:
335
+ """Detect if agent tested cross-user resource access."""
336
+ if (
337
+ ctx["method"] == "GET"
338
+ and re.match(r"^/tasks/\d+$", ctx["endpoint"])
339
+ and ctx["response_status"] == 200
340
+ and ctx["response_body"] is not None
341
+ and isinstance(ctx["response_body"], dict)
342
+ and ctx["headers"].get("Authorization")
343
+ ):
344
+ # Check if the agent logged in as a different user and accessed another's task
345
+ for prev in reversed(ctx["action_history"]):
346
+ if (
347
+ prev.get("method") == "POST"
348
+ and prev.get("endpoint") == "/auth/login"
349
+ and prev.get("response_status") == 200
350
+ and isinstance(prev.get("response_body"), dict)
351
+ ):
352
+ login_user_id = prev["response_body"].get("user_id")
353
+ task_owner_id = ctx["response_body"].get("owner_id")
354
+ if (
355
+ login_user_id is not None
356
+ and task_owner_id is not None
357
+ and login_user_id != task_owner_id
358
+ ):
359
+ return (
360
+ f"User {login_user_id} accessed task owned by user {task_owner_id} "
361
+ f"β€” BOLA/IDOR vulnerability (no authorization check)"
362
+ )
363
+ return None
364
+
365
+ def _detect_long_input_crash(self, ctx: dict) -> Optional[str]:
366
+ if (
367
+ ctx["method"] == "POST"
368
+ and ctx["endpoint"] == "/tasks"
369
+ and ctx["body"]
370
+ and ctx["response_status"] == 500
371
+ ):
372
+ title = ctx["body"].get("title", "")
373
+ if isinstance(title, str) and len(title) > 5000:
374
+ return f"POST /tasks with title length {len(title)} caused 500 β€” no input length validation"
375
+ return None
376
+
377
+ def _detect_content_injection(self, ctx: dict) -> Optional[str]:
378
+ if (
379
+ ctx["method"] == "POST"
380
+ and ctx["endpoint"] == "/tasks"
381
+ and ctx["body"]
382
+ and ctx["response_status"] == 201
383
+ and isinstance(ctx["response_body"], dict)
384
+ ):
385
+ title = ctx["body"].get("title", "")
386
+ injection_patterns = [
387
+ "DROP TABLE", "DELETE FROM", "<script>", "javascript:",
388
+ "'; --", "\" OR 1=1", "UNION SELECT",
389
+ ]
390
+ for pattern in injection_patterns:
391
+ if pattern.lower() in str(title).lower():
392
+ stored_title = ctx["response_body"].get("title", "")
393
+ if pattern.lower() in str(stored_title).lower():
394
+ return (
395
+ f"Injection payload '{pattern}' in title was stored verbatim "
396
+ f"β€” no input sanitization (content injection)"
397
+ )
398
+ return None
399
+
400
+ def _detect_broken_auth(self, ctx: dict) -> Optional[str]:
401
+ """Detect if agent successfully modified another user's task with their own token."""
402
+ if (
403
+ ctx["method"] in ("PUT", "DELETE")
404
+ and re.match(r"^/tasks/\d+$", ctx["endpoint"])
405
+ and ctx["response_status"] == 200
406
+ and ctx["headers"].get("Authorization")
407
+ ):
408
+ for prev in reversed(ctx["action_history"]):
409
+ if (
410
+ prev.get("method") == "POST"
411
+ and prev.get("endpoint") == "/auth/login"
412
+ and prev.get("response_status") == 200
413
+ and isinstance(prev.get("response_body"), dict)
414
+ ):
415
+ login_user_id = prev["response_body"].get("user_id")
416
+ # Check if the task belonged to a different user
417
+ task_id = int(ctx["endpoint"].split("/")[-1])
418
+ if isinstance(ctx["response_body"], dict):
419
+ task_owner = ctx["response_body"].get("owner_id")
420
+ if (
421
+ login_user_id is not None
422
+ and task_owner is not None
423
+ and login_user_id != task_owner
424
+ ):
425
+ return (
426
+ f"User {login_user_id}'s token modified task owned by user {task_owner} "
427
+ f"β€” broken authorization"
428
+ )
429
+ break
430
+ return None
server/buggy_api/__init__.py ADDED
File without changes
server/buggy_api/database.py ADDED
@@ -0,0 +1,209 @@
+ """
+ In-memory SQLite database for the buggy API.
+ Supports reset between episodes with DOMAIN RANDOMIZATION —
+ each seed produces different users, tasks, and data distributions
+ so that every training episode is unique.
+ """
+
+ import random
+ import sqlite3
+ import threading
+ from contextlib import contextmanager
+
+ # Name pools for randomized seed data
+ FIRST_NAMES = [
+     "alice", "bob", "charlie", "diana", "ethan", "fiona", "george", "hannah",
+     "ivan", "julia", "kevin", "luna", "mike", "nina", "oscar", "priya",
+     "quinn", "ravi", "sara", "tom", "uma", "victor", "wendy", "xander",
+ ]
+ DOMAINS = ["example.com", "company.org", "startup.io", "work.dev", "test.net"]
+ TASK_TITLES = [
+     "Setup CI/CD pipeline", "Write unit tests", "Fix login page CSS",
+     "Database migration", "API documentation", "Refactor auth module",
+     "Add rate limiting", "Setup monitoring", "Fix memory leak",
+     "Update dependencies", "Add logging middleware", "Create admin panel",
+     "Implement caching", "Fix CORS issues", "Add input validation",
+     "Setup Docker compose", "Write integration tests", "Fix date parsing bug",
+     "Add search functionality", "Implement pagination", "Setup SSL certs",
+     "Add webhook support", "Fix timezone handling", "Create backup script",
+     "Optimize database queries", "Add email notifications", "Fix file upload",
+     "Implement user roles", "Add audit logging", "Setup load balancer",
+ ]
+ TASK_DESCRIPTIONS = [
+     "Configure GitHub Actions for automated deployment",
+     "Add tests for the auth module endpoints",
+     "Button alignment issue on mobile devices",
+     "Migrate from SQLite to PostgreSQL",
+     "Document all REST endpoints with examples",
+     "Break down the monolithic auth into smaller services",
+     "Prevent API abuse with request throttling",
+     "Setup Grafana dashboards for key metrics",
+     "Memory usage grows unbounded after 1000 requests",
+     "Several packages have critical CVEs",
+     "Add structured JSON logging to all routes",
+     "Build an admin dashboard for user management",
+     "Add Redis caching layer for frequent queries",
+     "Frontend gets blocked by CORS policy",
+     "Sanitize user inputs to prevent injection",
+ ]
+ STATUSES = ["pending", "in_progress", "done"]
+ PRIORITIES = ["low", "medium", "high"]
+
+
+ class Database:
+     """Thread-safe in-memory SQLite database that can be reset between episodes.
+
+     When a seed is provided, the database is populated with deterministically
+     randomized data — different users, tasks, and distributions each time.
+     This prevents the agent from memorizing a single fixed dataset.
+     """
+
+     def __init__(self, seed: int | None = None):
+         self._lock = threading.Lock()
+         self._conn: sqlite3.Connection | None = None
+         self._seed = seed
+         self.initialize()
+
+     def initialize(self):
+         """Create a fresh database with schema and seed data."""
+         with self._lock:
+             if self._conn:
+                 self._conn.close()
+             self._conn = sqlite3.connect(":memory:", check_same_thread=False)
+             self._conn.row_factory = sqlite3.Row
+             self._conn.execute("PRAGMA journal_mode=WAL")
+             self._create_schema()
+             self._seed_data()
+
+     def _create_schema(self):
+         cursor = self._conn.cursor()
+         cursor.executescript("""
+             CREATE TABLE IF NOT EXISTS users (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 username TEXT UNIQUE NOT NULL,
+                 email TEXT NOT NULL,
+                 password_hash TEXT NOT NULL,
+                 role TEXT DEFAULT 'user',
+                 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+             );
+
+             CREATE TABLE IF NOT EXISTS tasks (
+                 id INTEGER PRIMARY KEY AUTOINCREMENT,
+                 title TEXT NOT NULL,
+                 description TEXT DEFAULT '',
+                 status TEXT DEFAULT 'pending',
+                 priority TEXT DEFAULT 'medium',
+                 assignee_email TEXT DEFAULT '',
+                 owner_id INTEGER NOT NULL,
+                 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                 updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                 FOREIGN KEY (owner_id) REFERENCES users(id)
+             );
+
+             CREATE TABLE IF NOT EXISTS auth_tokens (
+                 token TEXT PRIMARY KEY,
+                 user_id INTEGER NOT NULL,
+                 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                 expires_at TIMESTAMP,
+                 FOREIGN KEY (user_id) REFERENCES users(id)
+             );
+         """)
+         self._conn.commit()
+
+     def _seed_data(self):
+         """Seed the database with randomized data based on the seed.
+
+         With seed=None, uses a fixed default dataset (for manual testing).
+         With a seed, generates random users/tasks so every episode differs.
+         """
+         rng = random.Random(self._seed)
+         cursor = self._conn.cursor()
+
+         if self._seed is None:
+             # Default fixed data for manual testing / Gradio UI
+             cursor.executescript("""
+                 INSERT INTO users (username, email, password_hash, role) VALUES
+                 ('alice', 'alice@example.com', 'hashed_password123', 'admin'),
+                 ('bob', 'bob@example.com', 'hashed_password123', 'user'),
+                 ('charlie', 'charlie@example.com', 'hashed_password123', 'user');
+
+                 INSERT INTO tasks (title, description, status, priority, assignee_email, owner_id) VALUES
+                 ('Setup CI/CD pipeline', 'Configure GitHub Actions', 'in_progress', 'high', 'alice@example.com', 1),
+                 ('Write unit tests', 'Add tests for auth module', 'pending', 'medium', 'bob@example.com', 2),
+                 ('Fix login page CSS', 'Button alignment issue', 'done', 'low', 'charlie@example.com', 3),
+                 ('Database migration', 'Migrate to PostgreSQL', 'pending', 'high', 'alice@example.com', 1),
+                 ('API documentation', 'Document all endpoints', 'in_progress', 'medium', 'bob@example.com', 2);
+             """)
+         else:
+             # Randomized data — different every episode
+             # Pick 3-5 users from the name pool
+             num_users = rng.randint(3, 5)
+             user_names = rng.sample(FIRST_NAMES, num_users)
+             domain = rng.choice(DOMAINS)
+
+             # First user is always admin, rest are regular users
+             for i, name in enumerate(user_names):
+                 role = "admin" if i == 0 else "user"
+                 email = f"{name}@{domain}"
+                 cursor.execute(
+                     "INSERT INTO users (username, email, password_hash, role) VALUES (?, ?, ?, ?)",
+                     (name, email, f"hashed_password_{rng.randint(100, 999)}", role),
+                 )
+
+             # Pick 4-8 tasks with random assignments
+             num_tasks = rng.randint(4, 8)
+             task_titles = rng.sample(TASK_TITLES, min(num_tasks, len(TASK_TITLES)))
+             task_descs = rng.sample(TASK_DESCRIPTIONS, min(num_tasks, len(TASK_DESCRIPTIONS)))
+
+             for i in range(num_tasks):
+                 owner_id = rng.randint(1, num_users)
+                 assignee_id = rng.randint(1, num_users)
+                 assignee_email = f"{user_names[assignee_id - 1]}@{domain}"
+                 cursor.execute(
+                     "INSERT INTO tasks (title, description, status, priority, assignee_email, owner_id) VALUES (?, ?, ?, ?, ?, ?)",
+                     (
+                         task_titles[i % len(task_titles)],
+                         task_descs[i] if i < len(task_descs) else "",
+                         rng.choice(STATUSES),
+                         rng.choice(PRIORITIES),
+                         assignee_email,
+                         owner_id,
+                     ),
+                 )
+
+         self._conn.commit()
+
+     @property
+     def user_names(self) -> list[str]:
+         """Get usernames in the database (for the agent's observation)."""
+         rows = self.execute("SELECT username FROM users ORDER BY id")
+         return [r["username"] for r in rows]
+
+     @contextmanager
+     def get_cursor(self):
+         with self._lock:
+             cursor = self._conn.cursor()
+             try:
+                 yield cursor
+                 self._conn.commit()
+             except Exception:
+                 self._conn.rollback()
+                 raise
+
+     def execute(self, query: str, params: tuple = ()) -> list[dict]:
+         with self.get_cursor() as cursor:
+             cursor.execute(query, params)
+             if cursor.description:
+                 columns = [col[0] for col in cursor.description]
+                 return [dict(zip(columns, row)) for row in cursor.fetchall()]
+             return []
+
+     def execute_insert(self, query: str, params: tuple = ()) -> int:
+         with self.get_cursor() as cursor:
+             cursor.execute(query, params)
+             return cursor.lastrowid
+
+     def execute_update(self, query: str, params: tuple = ()) -> int:
+         with self.get_cursor() as cursor:
+             cursor.execute(query, params)
+             return cursor.rowcount
server/buggy_api/main.py ADDED
@@ -0,0 +1,91 @@
+ """
+ The deliberately buggy REST API — a task management system.
+
+ This API is the system-under-test. It has intentionally planted bugs at varying
+ difficulty levels that the AI agent must discover through intelligent testing.
+
+ The API runs in-process via Starlette's TestClient (no separate port needed).
+ """
+
+ import logging
+ from typing import Optional
+
+ from fastapi import FastAPI, Request, Header
+ from fastapi.responses import JSONResponse
+
+ from .database import Database
+ from .routes import tasks as tasks_routes
+ from .routes import users as users_routes
+ from .routes import auth as auth_routes
+ from .models import TaskCreate
+
+ logger = logging.getLogger(__name__)
+
+
+ def create_buggy_api(db: Database) -> FastAPI:
+     """Create a fresh buggy API instance wired to the given database."""
+     api = FastAPI(
+         title="TaskTracker API",
+         description="A task management API (with bugs)",
+         version="1.0.0",
+     )
+
+     # Wire database into route modules
+     tasks_routes.set_db(db)
+     users_routes.set_db(db)
+     auth_routes.set_db(db)
+
+     # Include standard routes
+     api.include_router(tasks_routes.router)
+     api.include_router(users_routes.router)
+     api.include_router(auth_routes.router)
+
+     # BUG_TASK_02 + BUG_TASK_08: Raw POST /tasks handler that doesn't use Pydantic validation.
+     # This allows missing fields and overly long inputs to cause 500 errors.
+     @api.post("/tasks", status_code=201)
+     async def create_task_raw(
+         request: Request,
+         authorization: Optional[str] = Header(None),
+     ):
+         try:
+             body = await request.json()
+         except Exception:
+             # BUG_TASK_02: Returns 500 on malformed/empty body instead of 400
+             raise Exception("Failed to parse request body")
+
+         if not isinstance(body, dict):
+             raise Exception("Invalid body format")
+
+         title = body.get("title")
+
+         # BUG_TASK_02: No check for missing title — raises and becomes a 500 below
+         if title is None:
+             # This SHOULD return 400, but we let it fall through to cause 500
+             # Simulate an internal error from missing required field
+             raise Exception("Internal error: title is required but was None")
+
+         # BUG_TASK_08: No length validation on title
+         if len(title) > 5000:
+             # Simulate a database error from overly long input
+             raise Exception(f"Database error: value too long for column 'title' (length={len(title)})")
+
+         task_data = TaskCreate(
+             title=title,
+             description=body.get("description", ""),
+             status=body.get("status", "pending"),
+             priority=body.get("priority", "medium"),
+             assignee_email=body.get("assignee_email", ""),
+         )
+         return tasks_routes.create_task_internal(task_data, authorization)
+
+     # Global error handler — returns 500 for unhandled exceptions
+     @api.exception_handler(Exception)
+     async def global_exception_handler(request: Request, exc: Exception):
+         logger.error(f"Unhandled error: {exc}")
+         return JSONResponse(
+             status_code=500,
+             content={"error": "Internal Server Error", "detail": str(exc)},
+         )
+
+     return api
server/buggy_api/models.py ADDED
@@ -0,0 +1,64 @@
+ """Pydantic models for the buggy API request/response schemas."""
+
+ from typing import Optional
+
+ from pydantic import BaseModel
+
+
+ class UserCreate(BaseModel):
+     username: str
+     email: str
+     password: str
+     role: str = "user"
+
+
+ class UserResponse(BaseModel):
+     id: int
+     username: str
+     email: str
+     role: str
+     created_at: str
+
+
+ class TaskCreate(BaseModel):
+     title: str
+     description: str = ""
+     status: str = "pending"
+     priority: str = "medium"
+     assignee_email: str = ""
+
+
+ class TaskUpdate(BaseModel):
+     title: Optional[str] = None
+     description: Optional[str] = None
+     status: Optional[str] = None
+     priority: Optional[str] = None
+     assignee_email: Optional[str] = None
+
+
+ class TaskResponse(BaseModel):
+     id: int
+     title: str
+     description: str
+     status: str
+     priority: str
+     assignee_email: str
+     owner_id: int
+     created_at: str
+     updated_at: str
+
+
+ class LoginRequest(BaseModel):
+     username: str
+     password: str
+
+
+ class LoginResponse(BaseModel):
+     token: str
+     user_id: int
+     username: str
+     role: str
+
+
+ class ErrorResponse(BaseModel):
+     error: str
+     detail: str = ""
server/buggy_api/routes/__init__.py ADDED
File without changes
server/buggy_api/routes/auth.py ADDED
@@ -0,0 +1,82 @@
+ """
+ Authentication routes with planted bugs.
+
+ BUGS PLANTED:
+ - BUG_AUTH_01 (hard): Auth tokens are not user-scoped — any valid token works for any user's resources
+ - BUG_AUTH_02 (medium): Login with empty password succeeds (missing validation)
+ """
+
+ import uuid
+ from datetime import datetime, timedelta
+ from typing import Optional
+
+ from fastapi import APIRouter, Header, HTTPException
+
+ from ..database import Database
+ from ..models import LoginRequest, LoginResponse
+
+ router = APIRouter(prefix="/auth", tags=["auth"])
+
+ _db: Database | None = None
+
+
+ def set_db(db: Database):
+     global _db
+     _db = db
+
+
+ def get_db() -> Database:
+     return _db
+
+
+ def get_current_user(authorization: Optional[str] = Header(None)) -> dict | None:
+     """Extract user from auth token.
+
+     BUG_AUTH_01: Returns the token's user but doesn't enforce ownership anywhere.
+     The routes that use this don't check if the resource belongs to the user.
+     """
+     if not authorization:
+         return None
+     token = authorization.replace("Bearer ", "")
+     db = get_db()
+     rows = db.execute(
+         "SELECT u.id, u.username, u.role FROM auth_tokens t JOIN users u ON t.user_id = u.id WHERE t.token = ?",
+         (token,),
+     )
+     if not rows:
+         return None
+     return rows[0]
+
+
+ @router.post("/login", response_model=LoginResponse)
+ def login(req: LoginRequest):
+     db = get_db()
+
+     # BUG_AUTH_02: Empty password check is missing — an empty password still logs in.
+     # Should validate: if not req.password: raise HTTPException(400, ...)
+     rows = db.execute(
+         "SELECT id, username, role, password_hash FROM users WHERE username = ?",
+         (req.username,),
+     )
+     if not rows:
+         raise HTTPException(status_code=401, detail="Invalid credentials")
+
+     user = rows[0]
+     # BUG_AUTH_02 continued: Only the username is checked — the password hash is never verified.
+     # Any password (including the empty string) works as long as the username exists.
+
+     token = str(uuid.uuid4())
+     expires = datetime.utcnow() + timedelta(hours=24)
+     db.execute_insert(
+         "INSERT INTO auth_tokens (token, user_id, expires_at) VALUES (?, ?, ?)",
+         (token, user["id"], expires.isoformat()),
+     )
+
+     return LoginResponse(
+         token=token,
+         user_id=user["id"],
+         username=user["username"],
+         role=user["role"],
+     )
server/buggy_api/routes/tasks.py ADDED
@@ -0,0 +1,210 @@
+ """
+ Task CRUD routes with planted bugs.
+
+ BUGS PLANTED:
+ - BUG_TASK_01 (easy): GET /tasks/{id} returns 200 with null body for a non-existent task (should be 404)
+ - BUG_TASK_02 (easy): POST /tasks with the required 'title' missing returns 500 instead of 400/422
+ - BUG_TASK_03 (easy): GET /tasks?page=-1 returns 200 instead of 400
+ - BUG_TASK_04 (medium): PUT /tasks/{id} doesn't validate assignee_email format
+ - BUG_TASK_05 (medium): DELETE /tasks/{id} returns 200 even for a non-existent task (should be 404)
+ - BUG_TASK_06 (medium): GET /tasks?limit=999999 has no pagination cap (potential DoS)
+ - BUG_TASK_07 (hard): GET /tasks/{id} on another user's task returns its data (BOLA/IDOR vulnerability)
+ - BUG_TASK_08 (hard): POST /tasks with a very long title (>5000 chars) causes 500 (no input length validation)
+ - BUG_TASK_09 (hard): POST /tasks with a SQL injection payload in title doesn't sanitize (parameterized
+   queries prevent actual injection, but the input is stored verbatim — a content injection)
+ - BUG_TASK_10 (hard): No rate limiting — rapid sequential requests all succeed
+ """
+
+ from typing import Optional
+
+ from fastapi import APIRouter, Header, HTTPException, Query
+
+ from ..database import Database
+ from ..models import TaskCreate, TaskUpdate
+
+ router = APIRouter(prefix="/tasks", tags=["tasks"])
+
+ _db: Database | None = None
+
+ # Simple in-memory cache for BUG demonstration
+ _cache: dict[int, dict] = {}
+
+
+ def set_db(db: Database):
+     global _db, _cache
+     _db = db
+     _cache = {}
+
+
+ def get_db() -> Database:
+     return _db
+
+
+ @router.get("")
+ def list_tasks(
+     status: Optional[str] = Query(None, description="Filter by status"),
+     priority: Optional[str] = Query(None, description="Filter by priority"),
+     sort: Optional[str] = Query(None, description="Sort field"),
+     page: Optional[int] = Query(None, description="Page number"),
+     limit: Optional[int] = Query(None, description="Items per page"),
+     authorization: Optional[str] = Header(None),
+ ):
+     db = get_db()
+
+     # BUG_TASK_03: No validation for negative page numbers
+     # Should check: if page is not None and page < 1: raise HTTPException(400, ...)
+
+     # BUG_TASK_06: No cap on limit — agent can request limit=999999
+     # Should cap at e.g. 100
+
+     query = "SELECT * FROM tasks WHERE 1=1"
+     params = []
+
+     if status:
+         query += " AND status = ?"
+         params.append(status)
+     if priority:
+         query += " AND priority = ?"
+         params.append(priority)
+
+     if sort:
+         allowed_sorts = ["created_at", "updated_at", "title", "priority", "status"]
+         if sort in allowed_sorts:
+             query += f" ORDER BY {sort}"
+         else:
+             query += " ORDER BY created_at"
+     else:
+         query += " ORDER BY created_at DESC"
+
+     if limit is not None:
+         # BUG_TASK_06: No upper bound check on limit
+         query += " LIMIT ?"
+         params.append(limit)
+     else:
+         query += " LIMIT 20"
+
+     if page is not None and limit is not None:
+         # BUG_TASK_03: Allows negative offset — page=-1 with limit=10 gives offset=-20
+         offset = (page - 1) * limit
+         query += " OFFSET ?"
+         params.append(offset)
+
+     rows = db.execute(query, tuple(params))
+     return rows
+
+
+ @router.get("/{task_id}")
+ def get_task(
+     task_id: int,
+     authorization: Optional[str] = Header(None),
+ ):
+     db = get_db()
+
+     # Check cache first (used later for stale cache bug)
+     if task_id in _cache:
+         return _cache[task_id]
+
+     rows = db.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
+
+     # BUG_TASK_01: Returns 200 with null instead of 404
+     if not rows:
+         return None  # Should be: raise HTTPException(status_code=404, detail="Task not found")
+
+     task = rows[0]
+
+     # BUG_TASK_07: No ownership check — any authenticated user can see any task
+     # Should check: if user and task["owner_id"] != user["id"]: raise HTTPException(403)
+
+     # Cache the result
+     _cache[task_id] = task
+     return task
+
+
+ @router.post("/create", status_code=201)
+ def create_task_internal(
+     task: TaskCreate,
+     authorization: Optional[str] = Header(None),
+ ):
+     """Internal create — used by the raw handler after parsing."""
+     db = get_db()
+
+     # BUG_TASK_08: No title length validation
+     # Should check: if len(task.title) > 200: raise HTTPException(400, ...)
+
+     # BUG_TASK_09: No content sanitization — SQL injection payloads stored verbatim.
+     # Parameterized queries prevent actual SQL injection, but the content
+     # is stored and returned as-is, which is a content injection / XSS vector.
+
+     # Determine owner — default to user 1 if no auth
+     owner_id = 1
+     if authorization:
+         token = authorization.replace("Bearer ", "")
+         token_rows = db.execute(
+             "SELECT user_id FROM auth_tokens WHERE token = ?", (token,)
+         )
+         if token_rows:
+             owner_id = token_rows[0]["user_id"]
+
+     task_id = db.execute_insert(
+         "INSERT INTO tasks (title, description, status, priority, assignee_email, owner_id) VALUES (?, ?, ?, ?, ?, ?)",
+         (task.title, task.description, task.status, task.priority, task.assignee_email, owner_id),
+     )
+
+     rows = db.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
+     result = rows[0]
+     _cache[task_id] = result
+     return result
+
+
+ @router.put("/{task_id}")
+ def update_task(
+     task_id: int,
+     task: TaskUpdate,
+     authorization: Optional[str] = Header(None),
+ ):
+     db = get_db()
+
+     existing = db.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
+     if not existing:
+         raise HTTPException(status_code=404, detail="Task not found")
+
+     # BUG_TASK_04: No email format validation on assignee_email
+     # Should validate if task.assignee_email is provided
+
+     # BUG_TASK_07: No ownership check on update either
+     updates = []
+     params = []
+     for field_name in ["title", "description", "status", "priority", "assignee_email"]:
+         value = getattr(task, field_name, None)
+         if value is not None:
+             updates.append(f"{field_name} = ?")
+             params.append(value)
+
+     if updates:
+         updates.append("updated_at = CURRENT_TIMESTAMP")
+         params.append(task_id)
+         db.execute_update(
+             f"UPDATE tasks SET {', '.join(updates)} WHERE id = ?",
+             tuple(params),
+         )
+
+     rows = db.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
+     result = rows[0]
+     _cache[task_id] = result
+     return result
+
+
+ @router.delete("/{task_id}")
+ def delete_task(
+     task_id: int,
+     authorization: Optional[str] = Header(None),
+ ):
+     db = get_db()
+
+     # BUG_TASK_05: No existence check — returns 200 even for non-existent tasks
+     # Should check existence first and return 404
+     db.execute_update("DELETE FROM tasks WHERE id = ?", (task_id,))
+
+     # Note: cache is NOT cleared — this enables stale cache detection
+     # (BUG_TASK_01 variant: deleted task still returned from cache)
+
+     return {"message": "Task deleted", "id": task_id}
server/buggy_api/routes/users.py ADDED
@@ -0,0 +1,63 @@
+ """
+ User management routes with planted bugs.
+
+ BUGS PLANTED:
+ - BUG_USER_01 (medium): POST /users doesn't validate email format
+ - BUG_USER_02 (medium): POST /users exposes the password hash in the response
+ """
+
+ from fastapi import APIRouter, HTTPException
+
+ from ..database import Database
+ from ..models import UserCreate
+
+ router = APIRouter(prefix="/users", tags=["users"])
+
+ _db: Database | None = None
+
+
+ def set_db(db: Database):
+     global _db
+     _db = db
+
+
+ def get_db() -> Database:
+     return _db
+
+
+ @router.get("")
+ def list_users():
+     db = get_db()
+     rows = db.execute("SELECT id, username, email, role, created_at FROM users")
+     return rows
+
+
+ @router.get("/{user_id}")
+ def get_user(user_id: int):
+     db = get_db()
+     rows = db.execute("SELECT id, username, email, role, created_at FROM users WHERE id = ?", (user_id,))
+     if not rows:
+         raise HTTPException(status_code=404, detail="User not found")
+     return rows[0]
+
+
+ @router.post("", status_code=201)
+ def create_user(user: UserCreate):
+     db = get_db()
+
+     # BUG_USER_01: No email format validation — accepts "not-an-email" or empty string
+     # Should validate email with a regex or pydantic EmailStr
+
+     # Check username uniqueness
+     existing = db.execute("SELECT id FROM users WHERE username = ?", (user.username,))
+     if existing:
+         raise HTTPException(status_code=409, detail="Username already exists")
+
+     user_id = db.execute_insert(
+         "INSERT INTO users (username, email, password_hash, role) VALUES (?, ?, ?, ?)",
+         (user.username, user.email, f"hashed_{user.password}", user.role),
+     )
+
+     # BUG_USER_02: Response includes the password_hash field (SELECT *)
+     rows = db.execute("SELECT * FROM users WHERE id = ?", (user_id,))
+     return rows[0]
server/environment.py ADDED
@@ -0,0 +1,438 @@
+ """
+ OpenEnv Environment for API Integration Testing.
+
+ The agent interacts with a deliberately buggy REST API, discovering endpoints,
+ crafting requests, and finding bugs. Rewards are multi-signal: coverage,
+ validity, bug discovery, and exploration.
+ """
+
+ import json
+ import logging
+ import time
+ from typing import Optional
+
+ from fastapi.testclient import TestClient
+ from openenv.core.env_server.interfaces import Environment
+
+ try:
+     from ..models import APITestAction, APITestObservation, APITestState
+ except ImportError:
+     from models import APITestAction, APITestObservation, APITestState
+
+ from .buggy_api.database import Database
+ from .buggy_api.main import create_buggy_api
+ from .bug_detector import BugDetector
+ from .reward import RewardComputer
+ from .graders import TaskGrader, generate_bug_report
+
+ logger = logging.getLogger(__name__)
+
+ # Task definitions
+ TASKS = {
+     "basic_validation": {
+         "id": "basic_validation",
+         "description": (
+             "Test all CRUD endpoints with valid inputs and verify correct status codes. "
+             "Find basic bugs like wrong status codes and missing field handling. "
+             "Available endpoints: GET /tasks, POST /tasks, GET /tasks/{id}, PUT /tasks/{id}, "
+             "DELETE /tasks/{id}, GET /users, POST /users, POST /auth/login. "
+             "Try different methods on each endpoint and verify responses match the expected behavior."
+         ),
+         "difficulty": "easy",
+         "max_steps": 25,
+         "total_bugs": 3,
+     },
+     "edge_cases": {
+         "id": "edge_cases",
+         "description": (
+             "Test boundary conditions, invalid inputs, and error responses. "
+             "Send missing fields, wrong types, negative page numbers, huge limits. "
+             "Test with non-existent resource IDs (e.g., /tasks/999999). "
+             "Chain operations: create a resource, then read/update/delete it. "
+             "Find bugs in input validation, pagination, and error handling."
+         ),
+         "difficulty": "medium",
+         "max_steps": 35,
+         "total_bugs": 9,
+     },
+     "security_workflows": {
+         "id": "security_workflows",
+         "description": (
+             "Discover authorization flaws, injection vulnerabilities, and workflow bugs. "
+             "Login as different users (alice/password, bob/password, charlie/password) and "
+             "try accessing each other's resources. Test SQL injection patterns in input fields. "
+             "Execute multi-step workflows: create -> modify -> verify -> delete -> re-fetch. "
+             "Check if auth tokens properly scope access. Test with very long inputs."
+         ),
+         "difficulty": "hard",
+         "max_steps": 45,
+         "total_bugs": 13,
+     },
+ }
+
+ # OpenAPI-like spec for the agent
+ API_SPEC = [
+     {
+         "method": "GET",
+         "path": "/tasks",
+         "summary": "List all tasks. Supports filtering by status, priority; pagination with page & limit; sorting with sort.",
+         "parameters": [
+             {"name": "status", "in": "query", "type": "string", "enum": ["pending", "in_progress", "done"]},
+             {"name": "priority", "in": "query", "type": "string", "enum": ["low", "medium", "high"]},
+             {"name": "sort", "in": "query", "type": "string", "enum": ["created_at", "updated_at", "title"]},
+             {"name": "page", "in": "query", "type": "integer"},
+             {"name": "limit", "in": "query", "type": "integer"},
+         ],
+     },
+     {
+         "method": "POST",
+         "path": "/tasks",
+         "summary": "Create a new task. Requires 'title' field. Optional: description, status, priority, assignee_email.",
+         "request_body": {
+             "required": ["title"],
+             "properties": {
+                 "title": {"type": "string"},
+                 "description": {"type": "string"},
+                 "status": {"type": "string", "enum": ["pending", "in_progress", "done"]},
+                 "priority": {"type": "string", "enum": ["low", "medium", "high"]},
+                 "assignee_email": {"type": "string", "format": "email"},
+             },
+         },
+     },
+     {
+         "method": "GET",
+         "path": "/tasks/{id}",
+         "summary": "Get a specific task by ID.",
+         "parameters": [{"name": "id", "in": "path", "type": "integer", "required": True}],
+     },
+     {
+         "method": "PUT",
+         "path": "/tasks/{id}",
+         "summary": "Update a task. All fields optional.",
+         "parameters": [{"name": "id", "in": "path", "type": "integer", "required": True}],
+         "request_body": {
+             "properties": {
+                 "title": {"type": "string"},
+                 "description": {"type": "string"},
+                 "status": {"type": "string"},
+                 "priority": {"type": "string"},
+                 "assignee_email": {"type": "string", "format": "email"},
+             },
+         },
+     },
+     {
+         "method": "DELETE",
+         "path": "/tasks/{id}",
+         "summary": "Delete a task by ID.",
+         "parameters": [{"name": "id", "in": "path", "type": "integer", "required": True}],
+     },
+     {
+         "method": "GET",
+         "path": "/users",
+         "summary": "List all users.",
+     },
+     {
+         "method": "POST",
+         "path": "/users",
+         "summary": "Create a new user. Requires username, email, password.",
+         "request_body": {
+             "required": ["username", "email", "password"],
+             "properties": {
+                 "username": {"type": "string"},
+                 "email": {"type": "string", "format": "email"},
+                 "password": {"type": "string"},
+                 "role": {"type": "string", "enum": ["user", "admin"]},
+             },
+         },
+     },
+     {
+         "method": "GET",
+         "path": "/users/{id}",
+         "summary": "Get a specific user by ID.",
+         "parameters": [{"name": "id", "in": "path", "type": "integer", "required": True}],
+     },
+     {
+         "method": "POST",
+         "path": "/auth/login",
+         "summary": "Login and receive an auth token. Pre-seeded users: alice, bob, charlie (password: any string).",
+         "request_body": {
+             "required": ["username", "password"],
+             "properties": {
+                 "username": {"type": "string"},
+                 "password": {"type": "string"},
+             },
+         },
+     },
+ ]
+
+
+ class APITestEnvironment(Environment):
+     """OpenEnv environment for API integration testing.
+
+     The agent tests a deliberately buggy REST API by sending HTTP requests
+     and analyzing responses. It earns rewards for coverage, finding bugs,
+     and exploring edge cases.
+     """
+
+     SUPPORTS_CONCURRENT_SESSIONS = False
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self._db: Optional[Database] = None
+         self._api: Optional[TestClient] = None
+         self._bug_detector: Optional[BugDetector] = None
+         self._reward_computer: Optional[RewardComputer] = None
+         self._task: Optional[dict] = None
+         self._found_bugs: set[str] = set()
+         self._steps_taken: int = 0
+         self._cumulative_reward: float = 0.0
+         self._action_history: list[dict] = []
+         self._auth_tokens: dict[str, str] = {}
+         self._episode_id: str = ""
+         self._seed: Optional[int] = None
+
+     def reset(self, seed=None, episode_id=None, **kwargs) -> APITestObservation:
+         """Reset the environment for a new episode.
+
+         Args:
+             seed: Random seed for domain randomization. When provided, the
+                 database is populated with different users, tasks, and data
+                 so each training episode is unique. None = fixed default data.
+             episode_id: Optional episode identifier for tracking.
+
+         kwargs:
+             task_id: str - one of "basic_validation", "edge_cases", "security_workflows"
+         """
+         task_id = kwargs.get("task_id", "basic_validation")
+         if task_id not in TASKS:
+             task_id = "basic_validation"
+
+         self._task = TASKS[task_id]
+         self._seed = seed
+         self._episode_id = episode_id or f"ep_{int(time.time())}"
+
+         # Reset database with seed for domain randomization
+         # seed=None → fixed data (manual testing / Gradio)
+         # seed=int → randomized data (GRPO training)
+         self._db = Database(seed=seed)
+         buggy_app = create_buggy_api(self._db)
+         self._api = TestClient(buggy_app, raise_server_exceptions=False)
+
+         # Build dynamic task description that includes actual usernames
+         user_names = self._db.user_names
+         user_list = ", ".join(user_names)
+         dynamic_description = (
+             f"{self._task['description']} "
+             f"Users in the system: {user_list} (use any password to login)."
+         )
+
+         # Reset tracking
+         self._bug_detector = BugDetector(task_id)
+         self._reward_computer = RewardComputer()
+         self._found_bugs = set()
+         self._steps_taken = 0
+         self._cumulative_reward = 0.0
+         self._action_history = []
+         self._auth_tokens = {}
+
+         logger.info(f"Reset environment: task={task_id}, seed={seed}, episode={self._episode_id}")
+
+         return APITestObservation(
+             available_endpoints=API_SPEC,
+             status_code=0,
+             response_body=None,
+             response_headers={},
+             response_time_ms=0,
+             feedback=(
+                 f"Environment reset. Task: {dynamic_description} "
+                 f"You have {self._task['max_steps']} steps. Start testing the API!"
+             ),
+             bugs_found_so_far=0,
+             coverage_summary=self._reward_computer.coverage.summary(),
+             known_resource_ids=self._reward_computer.created_ids,
+             auth_tokens=self._auth_tokens,
+             task_id=task_id,
+             task_description=dynamic_description,
+             steps_taken=0,
+             max_steps=self._task["max_steps"],
+             done=False,
+             reward=0.0,
+         )
+
+     def step(self, action: APITestAction, timeout_s=None, **kwargs) -> APITestObservation:
+         """Execute an API test action and return observation + reward."""
+         self._steps_taken += 1
+
+         # Forward request to buggy API
+         method = action.method.value if hasattr(action.method, "value") else str(action.method)
+         endpoint = action.endpoint
+         headers = dict(action.headers) if action.headers else {}
+         query_params = dict(action.query_params) if action.query_params else {}
+         body = action.body
+
+         # Make the request
+         start_time = time.time()
+         try:
+             response = self._api.request(
+                 method=method.upper(),
+                 url=endpoint,
+                 headers=headers,
+                 params=query_params if query_params else None,
+                 json=body,
+             )
+             elapsed_ms = (time.time() - start_time) * 1000
+
+             response_status = response.status_code
+             try:
+                 response_body = response.json()
+             except Exception:
+                 response_body = response.text
+             response_headers = dict(response.headers)
+         except Exception as e:
+             elapsed_ms = (time.time() - start_time) * 1000
+             response_status = 0
+             response_body = {"error": str(e)}
+             response_headers = {}
+
+         # Track auth tokens from login responses
+         if (
+             endpoint == "/auth/login"
+             and response_status == 200
+             and isinstance(response_body, dict)
+             and "token" in response_body
+         ):
+             username = body.get("username", "unknown") if body else "unknown"
+             self._auth_tokens[username] = response_body["token"]
+
+         # Check for bug detection
+         detection = self._bug_detector.check(
+             method=method,
+             endpoint=endpoint,
+             headers=headers,
+             query_params=query_params,
+             body=body,
+             expected_status=action.expected_status,
+             response_status=response_status,
+             response_body=response_body,
+             action_history=self._action_history,
+             found_bugs=self._found_bugs,
+         )
+
+         bug_severity = None
+         bug_id = None
+         if detection:
+             bug_severity = detection.bug.severity
+             bug_id = detection.bug.id
+             self._found_bugs.add(bug_id)
+
+         # Compute reward
+         reward_breakdown = self._reward_computer.compute(
+             method=method,
+             endpoint=endpoint,
+             headers=headers,
+             query_params=query_params,
+             body=body,
+             expected_status=action.expected_status,
+             response_status=response_status,
+             response_body=response_body,
+             bug_found=bug_severity,
+             bug_id=bug_id,
+         )
+         self._cumulative_reward += reward_breakdown.total
+
+         # Record action in history
+         self._action_history.append({
+             "method": method,
+             "endpoint": endpoint,
+             "headers": headers,
+             "query_params": query_params,
+             "body": body,
+             "response_status": response_status,
+             "response_body": response_body,
+         })
+
+         # Generate feedback
+         feedback_parts = [f"{method} {endpoint} -> {response_status}"]
+         if detection:
+             feedback_parts.append(f"BUG FOUND ({detection.bug.severity})! {detection.evidence}")
+         if reward_breakdown.coverage > 0:
+             feedback_parts.append(f"Coverage +{reward_breakdown.coverage:.2f}")
+         if reward_breakdown.penalty < 0:
+             feedback_parts.append("Repeated request penalty")
+
+         done = self._steps_taken >= self._task["max_steps"]
+
+         # Compute final grade if done
+         if done:
+             grade = TaskGrader.grade(
+                 task_id=self._task["id"],
+                 bugs_found=self._found_bugs,
+                 coverage_pct=self._reward_computer.coverage.summary()["coverage_pct"],
+                 endpoints_tested=len(self._reward_computer.coverage.endpoints_hit),
+                 total_endpoints=self._reward_computer.coverage.total_endpoints,
+                 method_endpoint_pairs=len(self._reward_computer.coverage.method_endpoint_pairs),
+                 status_codes_seen=self._reward_computer.coverage.status_codes_seen,
+                 action_history=self._action_history,
+                 created_resources=self._reward_computer.created_ids,
+             )
+             # Generate bug bounty report
+             report = generate_bug_report(list(self._found_bugs), self._action_history)
+
+             feedback_parts.append(
+                 f"\n=== EPISODE COMPLETE ===\n"
+                 f"Final Score: {grade.score:.4f}\n"
+                 f"Bugs Found: {len(self._found_bugs)}/{self._task['total_bugs']}\n"
+                 f"Grade Breakdown: {json.dumps(grade.breakdown, indent=2)}\n"
+                 f"Feedback: {grade.feedback}\n\n"
+                 f"{report}"
+             )
+             # Add grade as bonus on top of step reward (not replacement)
+             final_reward = reward_breakdown.total + grade.score
+         else:
+             final_reward = reward_breakdown.total
+
+         return APITestObservation(
+             available_endpoints=API_SPEC,
+             status_code=response_status,
+             response_body=response_body,
+             response_headers={k: v for k, v in list(response_headers.items())[:20]},
+             response_time_ms=round(elapsed_ms, 2),
+             feedback=" | ".join(feedback_parts),
+             bugs_found_so_far=len(self._found_bugs),
+             coverage_summary=self._reward_computer.coverage.summary(),
+             known_resource_ids=self._reward_computer.created_ids,
+             auth_tokens=self._auth_tokens,
+             task_id=self._task["id"],
+             task_description=self._task["description"],
+             steps_taken=self._steps_taken,
+             max_steps=self._task["max_steps"],
+             done=done,
+             reward=final_reward,
+             metadata={"reward_breakdown": reward_breakdown.as_dict()},
+         )
+
+     @property
+     def state(self) -> APITestState:
+         """Return current episode state."""
+         if not self._task:
+             return APITestState()
+
+         coverage = self._reward_computer.coverage.summary() if self._reward_computer else {}
+         return APITestState(
+             episode_id=self._episode_id,
+             step_count=self._steps_taken,
+             task_id=self._task["id"],
+             task_description=self._task["description"],
+             difficulty=self._task["difficulty"],
+             steps_taken=self._steps_taken,
+             max_steps=self._task["max_steps"],
+             bugs_found=len(self._found_bugs),
+             total_bugs=self._task["total_bugs"],
+             bugs_found_ids=list(self._found_bugs),
+             coverage_pct=coverage.get("coverage_pct", 0.0),
+             endpoints_tested=coverage.get("endpoints_tested", 0),
+             total_endpoints=coverage.get("total_endpoints", 0),
+             current_score=0.0,
+             cumulative_reward=round(self._cumulative_reward, 4),
+         )
server/graders.py ADDED
@@ -0,0 +1,289 @@
+ """
+ Task-specific grading logic and bug bounty report generation.
+
+ Each task has a grader that computes a final score (0.0 - 1.0)
+ based on what the agent accomplished during the episode.
+ """
+
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class GradeResult:
+     score: float
+     breakdown: dict[str, float]
+     feedback: str
+     report: str = ""  # Bug bounty report (markdown)
+
+
+ def generate_bug_report(bugs_found_ids: list[str], action_history: list[dict]) -> str:
+     """Generate a structured bug bounty report for discovered bugs."""
+     from .bug_detector import BugDetector
+
+     detector = BugDetector("security_workflows")
+
+     if not bugs_found_ids:
+         return "## API Security Assessment Report\n\nNo vulnerabilities discovered."
+
+     severity_order = {"hard": 0, "medium": 1, "easy": 2}
+     sorted_bugs = sorted(
+         bugs_found_ids,
+         key=lambda b: severity_order.get(detector.bugs[b].severity, 2) if b in detector.bugs else 2,
+     )
+
+     sections = ["## API Security Assessment Report", ""]
+     sections.append(f"**Vulnerabilities Found:** {len(bugs_found_ids)}")
+
+     # Count by severity
+     counts = {"easy": 0, "medium": 0, "hard": 0}
+     for bid in bugs_found_ids:
+         bug = detector.bugs.get(bid)
+         if bug:
+             counts[bug.severity] = counts.get(bug.severity, 0) + 1
+     sections.append(f"**Critical/Hard:** {counts['hard']} | **Medium:** {counts['medium']} | **Low/Easy:** {counts['easy']}")
+     sections.append("")
+
+     for bid in sorted_bugs:
+         bug = detector.bugs.get(bid)
+         if not bug:
+             continue
+
+         sev_label = {"easy": "LOW", "medium": "MEDIUM", "hard": "HIGH"}.get(bug.severity, "INFO")
+         owasp = bug.owasp if bug.owasp else "Uncategorized"
+
+         sections.append(f"### {sev_label}: {bug.description}")
+         sections.append(f"- **ID:** {bid}")
+         sections.append(f"- **OWASP:** {owasp}")
+         sections.append(f"- **Category:** {bug.category}")
+         if bug.recommendation:
+             sections.append(f"- **Recommendation:** {bug.recommendation}")
+
+         # Best-effort attribution: cite the first recorded request
+         # (per-bug trigger actions aren't tracked in the history)
+         for h in action_history:
+             if h.get("method") and h.get("endpoint"):
+                 sections.append(f"- **Triggered by:** {h['method']} {h['endpoint']}")
+                 break
+         sections.append("")
+
+     return "\n".join(sections)
+
+
+ class TaskGrader:
+     """Computes final scores for each task based on episode performance."""
+
+     @staticmethod
+     def grade(
+         task_id: str,
+         bugs_found: set[str],
+         coverage_pct: float,
+         endpoints_tested: int,
+         total_endpoints: int,
+         method_endpoint_pairs: int,
+         status_codes_seen: set[int],
+         action_history: list[dict],
+         created_resources: dict[str, list],
+     ) -> GradeResult:
+         if task_id == "basic_validation":
+             return TaskGrader._grade_basic(
+                 bugs_found, coverage_pct, endpoints_tested, total_endpoints,
+                 method_endpoint_pairs, status_codes_seen, action_history, created_resources,
+             )
+         elif task_id == "edge_cases":
+             return TaskGrader._grade_edge_cases(
+                 bugs_found, coverage_pct, endpoints_tested, method_endpoint_pairs,
+                 status_codes_seen, action_history, created_resources,
+             )
+         elif task_id == "security_workflows":
+             return TaskGrader._grade_security(
+                 bugs_found, coverage_pct, action_history, created_resources,
+             )
+         return GradeResult(score=0.0, breakdown={}, feedback="Unknown task")
+
+     @staticmethod
+     def _grade_basic(
+         bugs_found, coverage_pct, endpoints_tested, total_endpoints,
+         method_endpoint_pairs, status_codes_seen, action_history, created_resources,
+     ) -> GradeResult:
+         breakdown = {}
+
+         # 0.25: Test all GET endpoints
+         get_endpoints = {
+             h.get("endpoint") for h in action_history
+             if h.get("method", "").upper() == "GET"
+         }
+         get_score = min(len(get_endpoints) / 4, 1.0) * 0.25
+         breakdown["get_coverage"] = round(get_score, 3)
+
+         # 0.20: Test POST with valid data
+         post_success = sum(
+             1 for h in action_history
+             if h.get("method", "").upper() == "POST" and h.get("response_status") == 201
+         )
+         post_score = min(post_success / 2, 1.0) * 0.20
+         breakdown["post_testing"] = round(post_score, 3)
+
+         # 0.15: Test PUT/DELETE
+         put_delete = sum(
+             1 for h in action_history
+             if h.get("method", "").upper() in ("PUT", "DELETE")
+         )
+         pd_score = min(put_delete / 2, 1.0) * 0.15
+         breakdown["put_delete"] = round(pd_score, 3)
+
+         # 0.20: Bug discovery (easy bugs: TASK_01, TASK_02, TASK_03)
+         easy_bugs = {"BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03"}
+         found_easy = len(bugs_found & easy_bugs)
+         bug_score = min(found_easy / 2, 1.0) * 0.20
+         breakdown["bugs_found"] = round(bug_score, 3)
+
+         # 0.20: Response schema validation (status code variety)
+         schema_score = min(len(status_codes_seen) / 4, 1.0) * 0.20
+         breakdown["schema_validation"] = round(schema_score, 3)
+
+         score = sum(breakdown.values())
+         feedback_parts = []
+         if get_score > 0:
+             feedback_parts.append(f"GET coverage: {len(get_endpoints)} endpoints")
+         if post_success > 0:
+             feedback_parts.append(f"POST success: {post_success}")
+         if found_easy > 0:
+             feedback_parts.append(f"Bugs found: {found_easy}/{len(easy_bugs)}")
+
+         return GradeResult(
+             score=round(min(score, 1.0), 4),
+             breakdown=breakdown,
+             feedback="; ".join(feedback_parts) if feedback_parts else "No significant progress",
+         )
+
+     @staticmethod
+     def _grade_edge_cases(
+         bugs_found, coverage_pct, endpoints_tested, method_endpoint_pairs,
+         status_codes_seen, action_history, created_resources,
+     ) -> GradeResult:
+         breakdown = {}
+
+         # 0.15: Missing required fields testing
+         missing_field_tests = sum(
+             1 for h in action_history
+             if h.get("method", "").upper() == "POST"
+             and isinstance(h.get("body"), dict)
+             and not h["body"].get("title")
+         )
+         breakdown["missing_fields"] = round(min(missing_field_tests / 2, 1.0) * 0.15, 3)
169
+
170
+ # 0.15: Invalid data type testing
171
+ invalid_tests = sum(
172
+ 1 for h in action_history
173
+ if h.get("body") and isinstance(h.get("body"), dict)
174
+ and any(
175
+ isinstance(v, (list, bool)) or v == ""
176
+ for v in h["body"].values()
177
+ )
178
+ )
179
+ breakdown["invalid_types"] = round(min(invalid_tests / 2, 1.0) * 0.15, 3)
180
+
181
+ # 0.15: Boundary value testing (negative pages, huge limits, long strings)
182
+ boundary_tests = 0
183
+ for h in action_history:
184
+ qp = h.get("query_params", {})
185
+ if qp.get("page") is not None and int(str(qp.get("page", 1))) < 1:
186
+ boundary_tests += 1
187
+ if qp.get("limit") is not None and int(str(qp.get("limit", 10))) > 100:
188
+ boundary_tests += 1
189
+ breakdown["boundary_values"] = round(min(boundary_tests / 2, 1.0) * 0.15, 3)
190
+
191
+ # 0.15: Non-existent resource testing
192
+ nonexistent_tests = sum(
193
+ 1 for h in action_history
194
+ if h.get("method", "").upper() in ("GET", "DELETE", "PUT")
195
+ and "/999" in h.get("endpoint", "")
196
+ )
197
+ breakdown["nonexistent_resources"] = round(min(nonexistent_tests / 2, 1.0) * 0.15, 3)
198
+
199
+ # 0.20: Bug discovery (medium bugs)
200
+ medium_bugs = {
201
+ "BUG_TASK_04", "BUG_TASK_05", "BUG_TASK_06",
202
+ "BUG_USER_01", "BUG_USER_02", "BUG_AUTH_02",
203
+ }
204
+ all_relevant = medium_bugs | {"BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03"}
205
+ found_relevant = len(bugs_found & all_relevant)
206
+ breakdown["bugs_found"] = round(min(found_relevant / 3, 1.0) * 0.20, 3)
207
+
208
+ # 0.20: Dependency chaining (create β†’ read β†’ update β†’ delete)
209
+ chain_score = 0.0
210
+ if any(h.get("method") == "POST" and h.get("response_status") == 201 for h in action_history):
211
+ chain_score += 0.25
212
+ if created_resources.get("tasks"):
213
+ task_ids = created_resources["tasks"]
214
+ for tid in task_ids:
215
+ gets = [h for h in action_history if h.get("endpoint") == f"/tasks/{tid}" and h.get("method") == "GET"]
216
+ puts = [h for h in action_history if h.get("endpoint") == f"/tasks/{tid}" and h.get("method") == "PUT"]
217
+ deletes = [h for h in action_history if h.get("endpoint") == f"/tasks/{tid}" and h.get("method") == "DELETE"]
218
+ if gets:
219
+ chain_score += 0.25
220
+ if puts:
221
+ chain_score += 0.25
222
+ if deletes:
223
+ chain_score += 0.25
224
+ break # Only need one complete chain
225
+ breakdown["dependency_chaining"] = round(min(chain_score, 1.0) * 0.20, 3)
226
+
227
+ score = sum(breakdown.values())
228
+ return GradeResult(
229
+ score=round(min(score, 1.0), 4),
230
+ breakdown=breakdown,
231
+ feedback=f"Edge cases: {found_relevant} bugs found, chain score {chain_score:.0%}",
232
+ )
233
+
234
+ @staticmethod
235
+ def _grade_security(
236
+ bugs_found, coverage_pct, action_history, created_resources,
237
+ ) -> GradeResult:
238
+ breakdown = {}
239
+
240
+ # 0.20: Cross-user authorization testing
241
+ cross_user = False
242
+ login_users = set()
243
+ for h in action_history:
244
+ if h.get("endpoint") == "/auth/login" and h.get("response_status") == 200:
245
+ body = h.get("body", {})
246
+ if body:
247
+ login_users.add(body.get("username"))
248
+ cross_user = len(login_users) >= 2
249
+ breakdown["cross_user_auth"] = 0.20 if cross_user else 0.0
250
+
251
+ # 0.20: Injection pattern testing
252
+ injection_attempted = sum(
253
+ 1 for h in action_history
254
+ if h.get("body") and isinstance(h.get("body"), dict)
255
+ and any(
256
+ pattern.lower() in str(h["body"]).lower()
257
+ for pattern in ["DROP TABLE", "<script>", "OR 1=1", "UNION SELECT", "'; --"]
258
+ )
259
+ )
260
+ breakdown["injection_testing"] = round(min(injection_attempted / 2, 1.0) * 0.20, 3)
261
+
262
+ # 0.20: Multi-step state consistency
263
+ # Check if agent did: create β†’ delete β†’ re-fetch (stale cache test)
264
+ consistency_tests = 0
265
+ for i, h in enumerate(action_history):
266
+ if h.get("method") == "DELETE" and "/tasks/" in h.get("endpoint", ""):
267
+ # Check if agent re-fetched the same resource after deleting
268
+ deleted_endpoint = h["endpoint"]
269
+ for j in range(i + 1, len(action_history)):
270
+ if action_history[j].get("endpoint") == deleted_endpoint and action_history[j].get("method") == "GET":
271
+ consistency_tests += 1
272
+ break
273
+ breakdown["state_consistency"] = round(min(consistency_tests, 1.0) * 0.20, 3)
274
+
275
+ # 0.20: Security bug discovery
276
+ security_bugs = {"BUG_TASK_07", "BUG_AUTH_01", "BUG_TASK_08", "BUG_TASK_09"}
277
+ found_security = len(bugs_found & security_bugs)
278
+ breakdown["security_bugs"] = round(min(found_security / 2, 1.0) * 0.20, 3)
279
+
280
+ # 0.20: Complete workflow coverage
281
+ workflow_coverage = min(coverage_pct / 80, 1.0) # 80% coverage = full score
282
+ breakdown["workflow_coverage"] = round(workflow_coverage * 0.20, 3)
283
+
284
+ score = sum(breakdown.values())
285
+ return GradeResult(
286
+ score=round(min(score, 1.0), 4),
287
+ breakdown=breakdown,
288
+ feedback=f"Security: {found_security} security bugs, {len(login_users)} users tested, {injection_attempted} injection attempts",
289
+ )
server/reward.py ADDED
@@ -0,0 +1,238 @@
+ """
+ Multi-signal reward function for the API Testing Environment.
+
+ Rewards are decomposed into:
+ 1. Coverage reward β€” exploring new endpoints/methods/status codes
+ 2. Validity reward β€” well-formed requests and proper dependency chaining
+ 3. Bug discovery reward β€” the core goal, scaled by severity
+ 4. Exploration bonus β€” trying novel actions
+ 5. Penalties β€” for repeating exact requests or malformed input
+ """
+
+ from dataclasses import dataclass, field
+ from typing import Any, Optional
+ import re
+
+
+ @dataclass
+ class CoverageTracker:
+     """Tracks API coverage across the episode."""
+
+     endpoints_hit: set[str] = field(default_factory=set)
+     method_endpoint_pairs: set[tuple[str, str]] = field(default_factory=set)
+     status_codes_seen: set[int] = field(default_factory=set)
+     total_endpoints: int = 10  # known endpoint patterns
+
+     def record(self, method: str, endpoint: str, status_code: int) -> dict[str, bool]:
+         """Record a request and return what's new."""
+         normalized_endpoint = self._normalize_endpoint(endpoint)
+         pair = (method.upper(), normalized_endpoint)
+
+         is_new_endpoint = normalized_endpoint not in self.endpoints_hit
+         is_new_pair = pair not in self.method_endpoint_pairs
+         is_new_status = status_code not in self.status_codes_seen
+
+         self.endpoints_hit.add(normalized_endpoint)
+         self.method_endpoint_pairs.add(pair)
+         self.status_codes_seen.add(status_code)
+
+         return {
+             "new_endpoint": is_new_endpoint,
+             "new_method_endpoint": is_new_pair,
+             "new_status_code": is_new_status,
+         }
+
+     def _normalize_endpoint(self, endpoint: str) -> str:
+         """Normalize /tasks/42 to /tasks/{id}."""
+         normalized = re.sub(r"/(\d+)", "/{id}", endpoint)
+         return normalized.rstrip("/") or "/"
+
+     def summary(self) -> dict:
+         return {
+             "endpoints_tested": len(self.endpoints_hit),
+             "total_endpoints": self.total_endpoints,
+             "method_endpoint_pairs": len(self.method_endpoint_pairs),
+             "status_codes_seen": sorted(self.status_codes_seen),
+             "coverage_pct": round(len(self.endpoints_hit) / max(self.total_endpoints, 1) * 100, 1),
+         }
+
+
+ @dataclass
+ class RewardBreakdown:
+     coverage: float = 0.0
+     validity: float = 0.0
+     bug_discovery: float = 0.0
+     exploration: float = 0.0
+     penalty: float = 0.0
+     total: float = 0.0
+
+     def as_dict(self) -> dict:
+         return {
+             "coverage": round(self.coverage, 4),
+             "validity": round(self.validity, 4),
+             "bug_discovery": round(self.bug_discovery, 4),
+             "exploration": round(self.exploration, 4),
+             "penalty": round(self.penalty, 4),
+             "total": round(self.total, 4),
+         }
+
+
+ class RewardComputer:
+     """Computes multi-signal rewards for API testing actions."""
+
+     def __init__(self):
+         self.coverage = CoverageTracker()
+         self.action_history: list[dict] = []
+         self.found_bugs: set[str] = set()
+         self.created_ids: dict[str, list[Any]] = {}  # resource type -> list of IDs
+
+     def reset(self):
+         self.coverage = CoverageTracker()
+         self.action_history = []
+         self.found_bugs = set()
+         self.created_ids = {}
+
+     def compute(
+         self,
+         method: str,
+         endpoint: str,
+         headers: dict,
+         query_params: dict,
+         body: Optional[dict],
+         expected_status: Optional[int],
+         response_status: int,
+         response_body: Any,
+         bug_found: Optional[str] = None,  # bug severity if found
+         bug_id: Optional[str] = None,
+     ) -> RewardBreakdown:
+         """Compute reward for this step."""
+         breakdown = RewardBreakdown()
+
+         # 1. Coverage reward (0.0 - 0.3)
+         coverage_info = self.coverage.record(method, endpoint, response_status)
+         if coverage_info["new_endpoint"]:
+             breakdown.coverage += 0.10
+         if coverage_info["new_method_endpoint"]:
+             breakdown.coverage += 0.05
+         if coverage_info["new_status_code"]:
+             breakdown.coverage += 0.05
+
+         # 2. Validity reward (0.0 - 0.2)
+         if response_status < 500:
+             breakdown.validity += 0.03  # Non-crash request
+
+         if self._used_dependency(method, endpoint, body, headers):
+             breakdown.validity += 0.10  # Used a previously created resource ID or auth token
+
+         if expected_status is not None and expected_status == response_status:
+             breakdown.validity += 0.05  # Correctly predicted status code
+
+         # Track created resources
+         self._track_created_resources(method, endpoint, response_status, response_body)
+
+         # 3. Bug discovery reward (0.0 - 0.4)
+         if bug_found and bug_id:
+             if bug_id not in self.found_bugs:
+                 self.found_bugs.add(bug_id)
+                 if bug_found == "easy":
+                     breakdown.bug_discovery += 0.10
+                 elif bug_found == "medium":
+                     breakdown.bug_discovery += 0.15
+                 elif bug_found == "hard":
+                     breakdown.bug_discovery += 0.25
+                 # First discovery bonus
+                 breakdown.bug_discovery += 0.05
+
+         # 4. Exploration bonus (0.0 - 0.1)
+         action_sig = self._action_signature(method, endpoint, query_params, body)
+         is_novel = all(
+             self._action_signature(
+                 h.get("method", ""),
+                 h.get("endpoint", ""),
+                 h.get("query_params", {}),
+                 h.get("body"),
+             )
+             != action_sig
+             for h in self.action_history
+         )
+         if is_novel:
+             breakdown.exploration += 0.05
+
+         # 5. Penalties
+         # Exact duplicate request
+         exact_match = any(
+             h.get("method") == method
+             and h.get("endpoint") == endpoint
+             and h.get("query_params") == query_params
+             and h.get("body") == body
+             and h.get("headers") == headers
+             for h in self.action_history
+         )
+         if exact_match:
+             breakdown.penalty -= 0.08
+
+         # Record this action in history
+         self.action_history.append({
+             "method": method,
+             "endpoint": endpoint,
+             "headers": headers,
+             "query_params": query_params,
+             "body": body,
+             "response_status": response_status,
+             "response_body": response_body,
+         })
+
+         # Total
+         breakdown.total = max(
+             breakdown.coverage + breakdown.validity + breakdown.bug_discovery + breakdown.exploration + breakdown.penalty,
+             -0.1,  # Floor to prevent extreme negative rewards
+         )
+         breakdown.total = min(breakdown.total, 1.0)
+
+         return breakdown
+
+     def _used_dependency(self, method: str, endpoint: str, body: Optional[dict], headers: dict) -> bool:
+         """Check if this request uses a resource ID or token from a previous step."""
+         endpoint_str = str(endpoint)
+
+         # Check if endpoint contains a known resource ID
+         for resource_type, ids in self.created_ids.items():
+             for rid in ids:
+                 if str(rid) in endpoint_str:
+                     return True
+
+         # Check if using an auth token obtained from login
+         if headers.get("Authorization"):
+             for prev in self.action_history:
+                 if (
+                     prev.get("endpoint") == "/auth/login"
+                     and prev.get("response_status") == 200
+                     and isinstance(prev.get("response_body"), dict)
+                     and "token" in prev["response_body"]
+                 ):
+                     token = prev["response_body"]["token"]
+                     if token in headers["Authorization"]:
+                         return True
+         return False
+
+     def _track_created_resources(
+         self, method: str, endpoint: str, status: int, body: Any
+     ):
+         """Track resource IDs from POST responses."""
+         if method.upper() == "POST" and status == 201 and isinstance(body, dict):
+             resource_id = body.get("id")
+             if resource_id is not None:
+                 # Determine resource type from endpoint
+                 resource_type = endpoint.strip("/").split("/")[0]
+                 if resource_type not in self.created_ids:
+                     self.created_ids[resource_type] = []
+                 self.created_ids[resource_type].append(resource_id)
+
+     def _action_signature(
+         self, method: str, endpoint: str, query_params: dict, body: Optional[dict]
+     ) -> str:
+         """Create a signature for an action to check novelty."""
+         normalized = re.sub(r"/\d+", "/{id}", endpoint)
+         body_keys = sorted(body.keys()) if body else []
+         param_keys = sorted(query_params.keys()) if query_params else []
+         return f"{method}:{normalized}:{param_keys}:{body_keys}"
setup.sh ADDED
@@ -0,0 +1,158 @@
+ #!/bin/bash
+ # ============================================================
+ # API Testing Environment β€” One-command setup
+ # ============================================================
+ # Usage: bash setup.sh
+ #
+ # This script:
+ #   1. Creates a virtual environment
+ #   2. Detects your GPU and installs the correct PyTorch+CUDA
+ #   3. Installs all project dependencies
+ #   4. Verifies everything works
+ # ============================================================
+
+ set -e
+
+ echo ""
+ echo "============================================"
+ echo "  API Testing Environment β€” Setup"
+ echo "============================================"
+ echo ""
+
+ # --- Step 1: Create venv ---
+ echo "[1/5] Setting up virtual environment..."
+ if [ ! -d ".venv" ]; then
+     python3 -m venv .venv
+     echo "  Created .venv"
+ else
+     echo "  .venv already exists"
+ fi
+ source .venv/bin/activate
+ pip install --upgrade pip setuptools wheel -q
+ echo "  Python: $(python3 --version)"
+ echo "  pip:    $(pip --version | awk '{print $2}')"
+ echo ""
+
+ # --- Step 2: Install PyTorch with correct CUDA ---
+ echo "[2/5] Detecting GPU and installing PyTorch..."
+
+ install_pytorch() {
+     if command -v nvidia-smi &> /dev/null; then
+         DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader 2>/dev/null | head -1)
+         DRIVER_MAJOR=$(echo "$DRIVER_VERSION" | cut -d. -f1)
+         GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null | head -1)
+         GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader 2>/dev/null | head -1)
+
+         echo "  GPU: $GPU_NAME ($GPU_MEM)"
+         echo "  NVIDIA driver: $DRIVER_VERSION"
+
+         if [ "$DRIVER_MAJOR" -ge 530 ]; then
+             echo "  -> Installing PyTorch + CUDA 12.1"
+             pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 -q
+         elif [ "$DRIVER_MAJOR" -ge 450 ]; then
+             echo "  -> Installing PyTorch + CUDA 11.8 (older driver)"
+             pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 -q
+         else
+             echo "  WARNING: Driver too old ($DRIVER_VERSION). Installing CPU-only PyTorch."
+             echo "  Upgrade: https://www.nvidia.com/Download/index.aspx"
+             pip install torch torchvision -q
+         fi
+     else
+         echo "  No NVIDIA GPU detected."
+         # Check for Apple Silicon
+         if python3 -c "import platform; exit(0 if platform.processor() == 'arm' else 1)" 2>/dev/null; then
+             echo "  -> Apple Silicon detected, installing default PyTorch (MPS support)"
+         else
+             echo "  -> Installing CPU-only PyTorch"
+         fi
+         pip install torch torchvision -q
+     fi
+ }
+
+ install_pytorch
+ echo ""
+
+ # --- Step 3: Install project dependencies ---
+ echo "[3/5] Installing project dependencies..."
+ pip install -r requirements.txt -q
+ echo "  Done."
+ echo ""
+
+ # --- Step 4: Verify everything ---
+ echo "[4/5] Verifying installation..."
+ echo ""
+ python3 << 'PYEOF'
+ import sys
+
+ # Core
+ import fastapi, uvicorn, pydantic, httpx
+ print(f"  fastapi: {fastapi.__version__}")
+
+ # ML
+ import torch
+ print(f"  torch: {torch.__version__}")
+ cuda = torch.cuda.is_available()
+ mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
+ if cuda:
+     print(f"  CUDA: {torch.version.cuda}")
+     print(f"  GPU: {torch.cuda.get_device_name(0)}")
+     print(f"  GPU memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
+ elif mps:
+     print("  Device: Apple MPS")
+ else:
+     print("  Device: CPU only (training will be slow!)")
+
+ import transformers, trl, peft, datasets
+ print(f"  transformers: {transformers.__version__}")
+ print(f"  trl: {trl.__version__}")
+ print(f"  peft: {peft.__version__}")
+
+ # Optional
+ try:
+     import wandb
+     print(f"  wandb: {wandb.__version__}")
+ except ImportError:
+     print("  wandb: not installed (optional)")
+
+ try:
+     import gradio
+     print(f"  gradio: {gradio.__version__}")
+ except ImportError:
+     print("  gradio: not installed (optional)")
+
+ # OpenEnv
+ try:
+     import openenv
+     print("  openenv: OK")
+ except ImportError:
+     print("  openenv: MISSING β€” run: pip install -r requirements.txt")
+
+ # Environment test
+ print("")
+ sys.path.insert(0, ".")
+ from server.environment import APITestEnvironment
+ from models import APITestAction, HTTPMethod
+ env = APITestEnvironment()
+ obs = env.reset(seed=42, task_id="basic_validation")
+ obs = env.step(APITestAction(method=HTTPMethod.GET, endpoint="/tasks/999999", expected_status=404))
+ assert obs.bugs_found_so_far == 1, "Bug detection failed!"
+ print("  Environment: OK (bug detection verified)")
+ PYEOF
+
+ echo ""
+
+ # --- Step 5: Done ---
+ echo "============================================"
+ echo "  Setup complete!"
+ echo "============================================"
+ echo ""
+ echo "  Activate:   source .venv/bin/activate"
+ echo ""
+ echo "  Gradio UI:  python gradio_app.py"
+ echo "  Baselines:  python -m training.evaluate --task all --agent all"
+ echo "  Training:   python -m training.grpo --model-id Qwen/Qwen3-1.7B"
+ echo "  Test mode:  python -m training.grpo --test-mode"
+ echo ""
+ echo "  For HF Hub: huggingface-cli login"
+ echo "  For W&B:    wandb login"
+ echo ""
train_grpo.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/env python3
+ """GRPO training β€” see training/grpo.py for the full implementation."""
+ from training.grpo import main
+
+ if __name__ == "__main__":
+     main()
training/README.md ADDED
@@ -0,0 +1,392 @@
+ # Training Module
+
+ Everything related to training an AI agent to test APIs using GRPO (Group Relative Policy Optimization).
+
+ ---
+
+ ## Setup
+
+ ```bash
+ cd api_testing_env
+
+ # Option 1: Automated setup (creates venv, installs everything)
+ bash setup.sh
+
+ # Option 2: Manual setup
+ python3 -m venv .venv
+ source .venv/bin/activate
+ pip install -r requirements.txt
+
+ # Optional: login to HuggingFace Hub (for model push)
+ huggingface-cli login
+
+ # Optional: login to Weights & Biases (for logging)
+ wandb login
+ ```
+
+ ### Environment Variables
+
+ Create a `.env` file in `api_testing_env/` (or export in your shell):
+
+ ```bash
+ # .env
+
+ # HuggingFace Hub β€” required for --push-to-hub
+ # Get your token at: https://huggingface.co/settings/tokens
+ HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
+
+ # Weights & Biases β€” required for --use-wandb
+ # Get your key at: https://wandb.ai/authorize
+ WANDB_API_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
+
+ # Optional: set W&B defaults
+ WANDB_PROJECT=api-testing-grpo
+ WANDB_ENTITY=your-team-name
+ ```
+
+ **Three ways to provide these keys:**
+
+ | Method | Command |
+ |--------|---------|
+ | `.env` file | Create `.env` as shown above, then `source .env` before training |
+ | CLI login | `huggingface-cli login` and `wandb login` (stores keys in ~/.cache) |
+ | Inline export | `export HF_TOKEN=hf_xxx && export WANDB_API_KEY=xxx` |
+
+ > **Important:** Never commit `.env` to git. It's already in `.gitignore`.
+
+ ---
+
+ ## Quick Start
+
+ ```bash
+ cd api_testing_env
+ source .venv/bin/activate
+
+ # 1. See what training prompts look like (no GPU needed)
+ SHOW_PROMPTS=1 python -m training.grpo
+
+ # 2. Quick sanity check (CPU, ~2 minutes)
+ python -m training.grpo --test-mode
+
+ # 3. Real training (GPU required)
+ python -m training.grpo --model-id Qwen/Qwen3-1.7B --num-episodes 100
+
+ # 4. With HuggingFace Hub push
+ python -m training.grpo \
+     --push-to-hub --hf-repo-id your-username/api-tester-grpo
+
+ # 5. With Weights & Biases logging
+ python -m training.grpo \
+     --use-wandb --wandb-project api-testing-grpo
+
+ # 6. Full pipeline: training + HF push + W&B
+ python -m training.grpo \
+     --model-id Qwen/Qwen3-1.7B \
+     --num-episodes 100 \
+     --push-to-hub --hf-repo-id your-username/api-tester-grpo \
+     --use-wandb --wandb-project api-testing-grpo
+
+ # 7. Run baseline agents only (no GPU needed)
+ python -m training.evaluate --task all --agent all --url http://localhost:8000
+
+ # 8. Resume from checkpoint
+ python -m training.grpo --model-id ./checkpoints/step_50
+ ```
+
+ ---
+
+ ## How Training Works
+
+ There is **no external dataset**. The environment generates unique episodes on the fly.
+
+ ```
+                 β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
+                 β”‚           GRPO Training Loop                  β”‚
+                 β”‚                                               β”‚
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”   β”‚ 1. env.reset(seed=N)                          β”‚
+ β”‚           β”‚   β”‚    β†’ unique users, tasks, data                β”‚
+ β”‚   Qwen    β”‚   β”‚                                               β”‚
+ β”‚   1.7B    │──▢│ 2. LLM generates: {"method":"GET",...}        β”‚
+ β”‚  + LoRA   β”‚   β”‚                                               β”‚
+ β”‚           │◀──│ 3. env.step(action) β†’ reward                  β”‚
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜   β”‚    coverage + bugs + validity                 β”‚
+                 β”‚                                               β”‚
+                 β”‚ 4. GRPO: generate 4 attempts per prompt,      β”‚
+                 β”‚    keep best, update model weights            β”‚
+                 β”‚                                               β”‚
+                 β”‚ 5. Repeat with next seed                      β”‚
+                 β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
+ ```
+
+ ### Why no dataset file?
+
+ Each `reset(seed=N)` creates a **unique database** with different users, tasks, and data:
+
+ | Seed | Users | Tasks |
+ |------|-------|-------|
+ | 42 | diana, alice, xander, ivan, hannah | 8 tasks |
+ | 99 | mike, george, tom, fiona | 6 tasks |
+ | 7 | priya, kevin, wendy | 4 tasks |
+
+ The agent can't memorize "login as alice" because alice might not exist. It must **read the observation and adapt** β€” that's the learning signal.
+
+ The bugs (13 planted flaws) are structural β€” same code flaws every episode β€” but the path to finding them changes because the data is different.
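+ The effect of seeded generation can be illustrated with a toy sketch (this is not the environment's actual generator; the name pool and `sample_users` helper are hypothetical):
+
+ ```python
+ import random
+
+ POOL = ("alice", "bob", "diana", "ivan", "mike", "priya")
+
+ def sample_users(seed: int) -> list[str]:
+     # Illustrative only: deterministic per-seed sampling, the same idea
+     # the environment uses to generate users and tasks on reset(seed=N).
+     rng = random.Random(seed)
+     count = rng.randint(3, 5)       # episode size varies with the seed
+     return rng.sample(POOL, count)  # which users exist varies too
+
+ # The same seed always reproduces the same episode data,
+ # so training prompts are diverse yet fully reproducible.
+ assert sample_users(42) == sample_users(42)
+ ```
+ 
+ 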
+
+ ---
+
+ ## Training Pipeline
+
+ The full training pipeline runs these steps automatically:
+
+ ```
+ 1. Run baseline agents (random, sequential, smart) across all tasks
+    ↓
+ 2. Load base model (Qwen 1.7B)
+    ↓
+ 3. Evaluate base model before training (establishes LLM baseline)
+    ↓
+ 4. GRPO training with LoRA
+    ↓
+ 5. Save model locally to --output-dir
+    ↓
+ 6. Push to HuggingFace Hub (if --push-to-hub)
+    ↓
+ 7. Evaluate trained model after GRPO
+    ↓
+ 8. Print comparison table (baselines vs base vs trained)
+    ↓
+ 9. Save metrics (JSON + markdown) to output-dir/metrics/
+    ↓
+ 10. Save comparison plots (PNG) to output-dir/metrics/plots/
+    ↓
+ 11. Finalize W&B run (if --use-wandb)
+ ```
+
+ ---
+
+ ## File Guide
+
+ | File | Purpose | When to modify |
+ |------|---------|----------------|
+ | `prompts.py` | System prompt, `format_observation()`, `parse_action()` | Change how the LLM sees tasks or formats actions |
+ | `rewards.py` | `format_reward_fn()`, `environment_reward_fn()` | Tune reward scaling or add new reward signals |
+ | `agents.py` | `RandomAgent`, `SequentialAgent`, `SmartAgent` | Add new baseline strategies |
+ | `grpo.py` | `build_training_prompts()`, `train_grpo()` | Change training hyperparameters or model |
+ | `evaluate.py` | `run_rollout()`, `run_baseline_local()`, remote runner | Change evaluation logic |
+
+ ### prompts.py
+
+ The bridge between the environment and the LLM.
+
+ **`SYSTEM_PROMPT`** β€” Instructions telling the LLM it's an API tester. Includes output format (JSON) and testing strategies.
+
+ **`format_observation(obs)`** β€” Converts an environment observation into text:
+ - First turn: full API spec + task description + available users
+ - Later turns: last response + feedback + progress stats + auth tokens
+
+ **`parse_action(text)`** β€” Extracts JSON from LLM output. Handles:
+ - Raw JSON: `{"method": "GET", "endpoint": "/tasks"}`
+ - Code blocks: `` ```json {...} ``` ``
+ - Extra text around JSON: `"I'll try: {...}"`
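+ A minimal sketch of this kind of extraction (not the project's actual `parse_action`; the `extract_action` name and regexes here are illustrative):
+
+ ```python
+ import json
+ import re
+ from typing import Optional
+
+ def extract_action(text: str) -> Optional[dict]:
+     """Find the first JSON object in an LLM completion and decode it."""
+     # Prefer a fenced ```json ... ``` block if one is present.
+     fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
+     candidate = fenced.group(1) if fenced else None
+     if candidate is None:
+         # Fall back to the widest {...} span in the raw text,
+         # which also covers JSON surrounded by prose.
+         match = re.search(r"\{.*\}", text, re.DOTALL)
+         candidate = match.group(0) if match else None
+     if candidate is None:
+         return None
+     try:
+         return json.loads(candidate)
+     except json.JSONDecodeError:
+         return None
+
+ # Handles raw JSON, fenced JSON, and JSON embedded in prose.
+ print(extract_action('I\'ll try: {"method": "GET", "endpoint": "/tasks"}'))
+ ```
+ 
+ 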
+
+ ### rewards.py
+
+ Two reward functions that GRPO uses to score each LLM completion:
+
+ **`format_reward_fn`** β€” Binary: +1.0 if valid JSON action, -1.0 if not. Teaches the model to always output parseable actions.
+
+ **`environment_reward_fn`** β€” Runs the action in the environment and returns the actual reward (coverage + bugs + validity), scaled by 5.0 to dominate over format reward.
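+ Since GRPO sums its reward functions, the 5.0 scale is what lets environment quality dominate once the model reliably emits valid JSON. A simplified sketch of how the two signals combine (the `combined_reward` helper is illustrative, not the project's code):
+
+ ```python
+ def combined_reward(parsed_ok: bool, env_reward: float, env_scale: float = 5.0) -> float:
+     """Sum of the two signals for one completion (simplified)."""
+     format_reward = 1.0 if parsed_ok else -1.0
+     # An unparseable completion never reaches the environment.
+     environment_reward = env_scale * env_reward if parsed_ok else 0.0
+     return format_reward + environment_reward
+
+ print(combined_reward(True, 0.4))   # valid action, decent env reward -> 3.0
+ print(combined_reward(False, 0.4))  # invalid JSON -> -1.0
+ ```
+ 
+ 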
+
200
+ ### agents.py
201
+
202
+ Three hand-coded baselines for comparison:
203
+
204
+ | Agent | Strategy | Expected Score |
205
+ |-------|----------|---------------|
206
+ | `RandomAgent` | Random method + random endpoint | ~0.10 |
207
+ | `SequentialAgent` | Fixed sequence: GET, POST, PUT, DELETE each endpoint | ~0.35 |
208
+ | `SmartAgent` | Multi-phase: discover β†’ auth β†’ CRUD β†’ bug hunt β†’ security | ~0.55 |
209
+
210
+ A GRPO-trained model should beat the SmartAgent.
211
+
212
+ ### grpo.py
213
+
214
+ The main training script.
215
+
216
+ **`build_training_prompts(num_episodes)`** β€” Creates N prompts by resetting the environment with seeds 0..N. Each prompt is a chat message with system prompt + initial observation.
217
+
218
+ **`run_baseline_evaluation(seed)`** β€” Runs all three baseline agents across all tasks before training starts.
219
+
220
+ **`train_grpo(args)`** β€” Full GRPO loop:
221
+ 1. Run baseline agents for comparison
222
+ 2. Load model + tokenizer (Qwen 1.7B default)
223
+ 3. Evaluate base model before training
224
+ 4. Apply LoRA (r=16, alpha=32, targets q_proj + v_proj)
225
+ 5. Generate prompts from environment
226
+ 6. Create per-prompt environment instances for reward eval
227
+ 7. Train with TRL's GRPOTrainer
228
+ 8. Save model locally + push to HF Hub
229
+ 9. Evaluate trained model + print comparison
230
+ 10. Save metrics (JSON, markdown) and plots (PNG)
231
+ 11. Finalize W&B run
232
+
233
+ **`save_metrics()`** β€” Saves `results.json` and `results.md` to `output-dir/metrics/`.
234
+
235
+ **`save_plots()`** β€” Generates three comparison bar charts (reward, bugs, coverage) saved as PNGs.
236
+
237
+ ### evaluate.py
238
+
239
+ **`run_rollout(model, tokenizer, task_id, seed)`** β€” Runs one full episode with a HuggingFace model. Multi-turn: LLM generates action β†’ env steps β†’ LLM sees result β†’ repeats.
240
+
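The plan the model is asked to produce is a JSON array of actions mirroring the `APITestAction` fields (`method`, `endpoint`, `body`, `expected_status`). A hand-written example of what a parsed plan looks like (the concrete values are illustrative):

```python
import json

plan_text = """[
  {"method": "GET", "endpoint": "/tasks", "expected_status": 200},
  {"method": "POST", "endpoint": "/auth/login",
   "body": {"username": "alice", "password": "password123"},
   "expected_status": 200}
]"""

# parse_test_plan() does more validation; json.loads shows the basic shape
plan = json.loads(plan_text)
```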
241
+ **`run_baseline_local(agent_name, task_id, seed)`** — Runs baseline agents against the local environment (no server needed). Used by `grpo.py` to establish baselines before training.
242
+
243
+ **`run_episode(url, task_id, agent_cls)`** — Runs a baseline agent against a remote server via WebSocket.
244
+
245
+ ---
246
+
247
+ ## Training Hyperparameters
248
+
249
+ | Parameter | Default | Description |
250
+ |-----------|---------|-------------|
251
+ | `--model-id` | `Qwen/Qwen3-1.7B` | Base model (any HF causal LM) |
252
+ | `--num-episodes` | 50 | Training prompts (more = more diverse episodes) |
253
+ | `--num-generations` | 4 | GRPO rollouts per prompt (higher = better but slower) |
254
+ | `--max-completion-length` | 256 | Max tokens per LLM response |
255
+ | `--max-steps` | 200 | Total training optimizer steps |
256
+ | `--learning-rate` | 2e-5 | AdamW learning rate |
257
+ | `--batch-size` | 1 | Per-device batch size |
258
+ | `--output-dir` | `./checkpoints/grpo_api_tester` | Where to save model |
259
+ | `--push-to-hub` | off | Push trained model to HuggingFace Hub |
260
+ | `--hf-repo-id` | none | HF Hub repo (e.g., `user/api-tester-grpo`) |
261
+ | `--use-wandb` | off | Enable Weights & Biases logging |
262
+ | `--wandb-project` | `api-testing-grpo` | W&B project name |
263
+ | `--wandb-run-name` | auto | W&B run name |
264
+ | `--test-mode` | off | Quick 3-episode, 2-gen, 5-step test |
265
+
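These flags map onto a standard `argparse` parser. A hypothetical subset mirroring the defaults above (not the full CLI):

```python
import argparse

# Sketch of a few of the flags listed in the table; defaults match the table.
parser = argparse.ArgumentParser(description="GRPO training for the API tester")
parser.add_argument("--model-id", default="Qwen/Qwen3-1.7B")
parser.add_argument("--num-episodes", type=int, default=50)
parser.add_argument("--num-generations", type=int, default=4)
parser.add_argument("--learning-rate", type=float, default=2e-5)
parser.add_argument("--test-mode", action="store_true")

# Quick smoke-test invocation
args = parser.parse_args(["--num-episodes", "3", "--test-mode"])
```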
266
+ ### Hardware Requirements
267
+
268
+ | Setup | GPU | Time | Model |
269
+ |-------|-----|------|-------|
270
+ | Colab Free | T4 (16GB) | ~1-2 hours | Qwen 1.7B + 4-bit LoRA |
271
+ | Colab Pro | A100 (40GB) | ~30 min | Qwen 4B + LoRA |
272
+ | Local | Any 8GB+ | ~1-2 hours | Qwen 1.7B + 4-bit LoRA |
273
+ | CPU only | None | `--test-mode` only | Verifies pipeline works |
274
+
275
+ ---
276
+
277
+ ## Output Structure
278
+
279
+ After training, your output directory will look like:
280
+
281
+ ```
282
+ checkpoints/grpo_api_tester/
283
+ ├── adapter_config.json              # LoRA adapter config
284
+ ├── adapter_model.safetensors        # Trained LoRA weights
285
+ ├── tokenizer.json                   # Tokenizer files
286
+ ├── tokenizer_config.json
287
+ ├── special_tokens_map.json
288
+ └── metrics/
289
+     ├── results.json                 # Full results (baselines + base + trained)
290
+     ├── results.md                   # Markdown comparison table
291
+     └── plots/
292
+         ├── reward_comparison.png    # Bar chart: reward across all agents
293
+         ├── bugs_comparison.png      # Bar chart: bugs found
294
+         └── coverage_comparison.png  # Bar chart: API coverage %
295
+ ```
296
+
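Reading the summary back is a one-liner with the standard library (`load_results` is a hypothetical helper following the layout above):

```python
import json
from pathlib import Path

def load_results(output_dir: str) -> dict:
    """Load the metrics summary written to <output_dir>/metrics/results.json."""
    return json.loads((Path(output_dir) / "metrics" / "results.json").read_text())
```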
297
+ ---
298
+
299
+ ## Weights & Biases Integration
300
+
301
+ When `--use-wandb` is enabled, the following is logged:
302
+
303
+ | Metric | Description |
304
+ |--------|-------------|
305
+ | `baseline/{agent}/{task}/reward` | Baseline agent scores |
306
+ | `base_model/{task}/reward` | Pre-training model scores |
307
+ | `trained_model/{task}/reward` | Post-training model scores |
308
+ | `delta/{task}/reward` | Improvement over base model |
309
+ | `plots/*` | Comparison charts as W&B images |
310
+ | TRL defaults | Loss, learning rate, reward mean/std |
311
+
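The `delta/{task}/reward` rows are just per-task differences between the trained and base model scores. A hypothetical helper showing the computation:

```python
def delta_rewards(base: dict[str, float], trained: dict[str, float]) -> dict[str, float]:
    """Improvement of the trained model over the base model, per task."""
    return {task: round(trained[task] - base[task], 4) for task in base}
```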
312
+ ---
313
+
314
+ ## Expected Results
315
+
316
+ ### Before Training (base Qwen 1.7B, no fine-tuning)
317
+
318
+ The base model sometimes outputs valid JSON, but it has no API testing strategy:
319
+ ```
320
+ basic_validation: ~0.15 (random-level)
321
+ edge_cases: ~0.08
322
+ security_workflows: ~0.03
323
+ ```
324
+
325
+ ### After GRPO (50 episodes, 200 steps)
326
+
327
+ The model learns systematic testing patterns:
328
+ ```
329
+ basic_validation: ~0.55-0.65
330
+ edge_cases: ~0.35-0.45
331
+ security_workflows: ~0.25-0.35
332
+ ```
333
+
334
+ ### What the Model Learns
335
+
336
+ 1. **Output format** — Always produce valid JSON (format reward)
337
+ 2. **Coverage** — Test different endpoints, don't repeat the same request
338
+ 3. **Dependency chaining** — POST to create, then GET/PUT/DELETE the created resource
339
+ 4. **Bug patterns** — Try non-existent IDs, missing fields, invalid emails
340
+ 5. **Auth workflows** — Login first, use tokens in subsequent requests
341
+ 6. **Security testing** — Try cross-user access, injection payloads
342
+
343
+ ---
344
+
345
+ ## Extending the Training
346
+
347
+ ### Add a new reward signal
348
+
349
+ Edit `rewards.py`:
350
+
351
+ ```python
352
+ def efficiency_reward_fn(completions: list[str], **kwargs) -> list[float]:
353
+ """Reward for concise, focused actions (penalize wasted steps)."""
354
+ rewards = []
355
+ for text in completions:
356
+ action = parse_action(text)
357
+ if action and action.expected_status:
358
+ rewards.append(0.5) # Bonus for predicting expected status
359
+ else:
360
+ rewards.append(0.0)
361
+ return rewards
362
+ ```
363
+
364
+ Then add it to the combined reward in `grpo.py`.
365
+
366
+ ### Add a new baseline agent
367
+
368
+ Edit `agents.py`:
369
+
370
+ ```python
371
+ class CoverageAgent:
372
+ """Agent that prioritizes hitting every endpoint once."""
373
+ name = "coverage"
374
+
375
+ def __init__(self):
376
+ self.tested = set()
377
+ # ...
378
+ ```
379
+
380
+ Then add it to the `AGENTS` dict.
381
+
382
+ ### Use a different model
383
+
384
+ ```bash
385
+ # Qwen 2.5 (smaller, faster)
386
+ python -m training.grpo --model-id Qwen/Qwen2.5-1.5B
387
+
388
+ # Llama 3 (if you have access)
389
+ python -m training.grpo --model-id meta-llama/Llama-3.2-1B
390
+ ```
391
+
392
+ Any HuggingFace causal language model works — just make sure it supports chat templates.
training/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ Training module for the API Testing Environment.
3
+
4
+ Contains:
5
+ - prompts.py — System prompt, observation formatting, action parsing
6
+ - rewards.py — Reward functions for GRPO (format + environment)
7
+ - agents.py — Baseline agents (random, sequential, smart)
8
+ - grpo.py — GRPO training loop with TRL, HF Hub push, W&B logging
9
+ - evaluate.py — Evaluation / rollout runner (local + remote)
10
+ """
training/agents.py ADDED
@@ -0,0 +1,190 @@
1
+ """
2
+ Baseline agents for the API Testing Environment.
3
+
4
+ Three agents of increasing sophistication:
5
+ 1. RandomAgent — Picks random endpoints/methods (lower bound)
5
+ 2. SequentialAgent — Systematically tests each endpoint in order
6
+ 3. SmartAgent — Chains requests and probes for known bug patterns
8
+ """
9
+
10
+ import random
11
+ import sys
12
+ import os
13
+
14
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
15
+ from models import APITestAction, HTTPMethod
16
+
17
+
18
+ class RandomAgent:
19
+ """Randomly picks endpoints and methods. Baseline for comparison."""
20
+
21
+ name = "random"
22
+
23
+ ENDPOINTS = ["/tasks", "/tasks/1", "/tasks/2", "/tasks/999", "/users", "/users/1", "/auth/login"]
24
+ METHODS = ["GET", "POST", "PUT", "DELETE"]
25
+
26
+ def act(self, observation: dict) -> APITestAction:
27
+ method = random.choice(self.METHODS)
28
+ endpoint = random.choice(self.ENDPOINTS)
29
+ body = None
30
+ headers = {}
31
+
32
+ if method == "POST" and endpoint == "/tasks":
33
+ body = {"title": f"Random task {random.randint(1, 100)}"}
34
+ elif method == "POST" and endpoint == "/auth/login":
35
+ body = {"username": random.choice(["alice", "bob"]), "password": "pass"}
36
+ elif method == "POST" and endpoint == "/users":
37
+ body = {"username": f"user{random.randint(100, 999)}", "email": "test@test.com", "password": "pass"}
38
+ elif method == "PUT":
39
+ endpoint = f"/tasks/{random.randint(1, 5)}"
40
+ body = {"title": "Updated"}
41
+
42
+ return APITestAction(
43
+ method=HTTPMethod(method),  # method is always drawn from METHODS
44
+ endpoint=endpoint,
45
+ headers=headers,
46
+ body=body,
47
+ )
48
+
49
+
50
+ class SequentialAgent:
51
+ """Systematically tests each endpoint with valid requests."""
52
+
53
+ name = "sequential"
54
+
55
+ def __init__(self):
56
+ self.step = 0
57
+
58
+ def act(self, observation: dict) -> APITestAction:
59
+ self.step += 1
60
+ actions = self._get_action_sequence()
61
+ idx = min(self.step - 1, len(actions) - 1)
62
+ return actions[idx]
63
+
64
+ def _get_action_sequence(self) -> list[APITestAction]:
65
+ return [
66
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks", expected_status=200),
67
+ APITestAction(method=HTTPMethod.GET, endpoint="/users", expected_status=200),
68
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks/1", expected_status=200),
69
+ APITestAction(method=HTTPMethod.GET, endpoint="/users/1", expected_status=200),
70
+ APITestAction(method=HTTPMethod.POST, endpoint="/auth/login",
71
+ body={"username": "alice", "password": "password123"}, expected_status=200),
72
+ APITestAction(method=HTTPMethod.POST, endpoint="/tasks",
73
+ body={"title": "Test Task", "description": "Created by baseline"}, expected_status=201),
74
+ APITestAction(method=HTTPMethod.POST, endpoint="/users",
75
+ body={"username": "testuser", "email": "test@example.com", "password": "test123"},
76
+ expected_status=201),
77
+ APITestAction(method=HTTPMethod.PUT, endpoint="/tasks/1",
78
+ body={"title": "Updated Task"}, expected_status=200),
79
+ APITestAction(method=HTTPMethod.DELETE, endpoint="/tasks/5", expected_status=200),
80
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks/999999", expected_status=404),
81
+ APITestAction(method=HTTPMethod.POST, endpoint="/tasks",
82
+ body={"description": "No title"}, expected_status=400),
83
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks",
84
+ query_params={"page": -1, "limit": 10}, expected_status=400),
85
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks",
86
+ query_params={"status": "done"}, expected_status=200),
87
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks",
88
+ query_params={"sort": "title"}, expected_status=200),
89
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks/2", expected_status=200),
90
+ ]
91
+
92
+
93
+ class SmartAgent:
94
+ """Heuristic agent that chains requests and probes for bugs."""
95
+
96
+ name = "smart"
97
+
98
+ def __init__(self):
99
+ self.step = 0
100
+ self.auth_tokens = {}
101
+ self.created_ids = []
102
+
103
+ def act(self, observation: dict) -> APITestAction:
104
+ self.step += 1
105
+
106
+ if isinstance(observation, dict):
107
+ self.auth_tokens = observation.get("auth_tokens", self.auth_tokens)
108
+ ids = observation.get("known_resource_ids", {})
109
+ for rtype, id_list in ids.items():
110
+ for rid in id_list:
111
+ if rid not in self.created_ids:
112
+ self.created_ids.append(rid)
113
+
114
+ actions = self._get_smart_sequence()
115
+ idx = min(self.step - 1, len(actions) - 1)
116
+ return actions[idx]
117
+
118
+ def _get_smart_sequence(self) -> list[APITestAction]:
119
+ alice_token = self.auth_tokens.get("alice", "")
120
+ bob_token = self.auth_tokens.get("bob", "")
121
+ alice_auth = {"Authorization": f"Bearer {alice_token}"} if alice_token else {}
122
+ bob_auth = {"Authorization": f"Bearer {bob_token}"} if bob_token else {}
123
+
124
+ return [
125
+ # Phase 1: Discovery
126
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks", expected_status=200),
127
+ APITestAction(method=HTTPMethod.GET, endpoint="/users", expected_status=200),
128
+ # Phase 2: Authentication
129
+ APITestAction(method=HTTPMethod.POST, endpoint="/auth/login",
130
+ body={"username": "alice", "password": "password123"}, expected_status=200),
131
+ APITestAction(method=HTTPMethod.POST, endpoint="/auth/login",
132
+ body={"username": "bob", "password": "password123"}, expected_status=200),
133
+ # Phase 3: CRUD with auth
134
+ APITestAction(method=HTTPMethod.POST, endpoint="/tasks",
135
+ body={"title": "Alice's task", "description": "Test"},
136
+ headers=alice_auth, expected_status=201),
137
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks/1", headers=alice_auth, expected_status=200),
138
+ # Phase 4: Easy bugs
139
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks/999999", expected_status=404),
140
+ APITestAction(method=HTTPMethod.POST, endpoint="/tasks",
141
+ body={"description": "no title"}, expected_status=400),
142
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks",
143
+ query_params={"page": -1, "limit": 10}, expected_status=400),
144
+ # Phase 5: Medium bugs
145
+ APITestAction(method=HTTPMethod.PUT, endpoint="/tasks/1",
146
+ body={"assignee_email": "not-an-email"}, expected_status=422),
147
+ APITestAction(method=HTTPMethod.DELETE, endpoint="/tasks/99999", expected_status=404),
148
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks",
149
+ query_params={"limit": 999999}, expected_status=200),
150
+ # Phase 6: User bugs
151
+ APITestAction(method=HTTPMethod.POST, endpoint="/users",
152
+ body={"username": "baduser", "email": "invalid-email", "password": "test"},
153
+ expected_status=422),
154
+ APITestAction(method=HTTPMethod.POST, endpoint="/auth/login",
155
+ body={"username": "alice", "password": ""}, expected_status=401),
156
+ # Phase 7: BOLA
157
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks/1",
158
+ headers=bob_auth, expected_status=403),
159
+ # Phase 8: Injection
160
+ APITestAction(method=HTTPMethod.POST, endpoint="/tasks",
161
+ body={"title": "test'; DROP TABLE tasks;--"}, expected_status=201),
162
+ APITestAction(method=HTTPMethod.POST, endpoint="/tasks",
163
+ body={"title": "A" * 6000}, expected_status=400),
164
+ # Phase 9: Cross-user modification
165
+ APITestAction(method=HTTPMethod.PUT, endpoint="/tasks/1",
166
+ body={"title": "Bob modified Alice's task"},
167
+ headers=bob_auth, expected_status=403),
168
+ # Phase 10: State consistency
169
+ APITestAction(method=HTTPMethod.POST, endpoint="/tasks",
170
+ body={"title": "Ephemeral task"}, expected_status=201),
171
+ APITestAction(method=HTTPMethod.DELETE, endpoint="/tasks/6", expected_status=200),
172
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks/6", expected_status=404),
173
+ # Phase 11: Coverage
174
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks",
175
+ query_params={"status": "done"}, expected_status=200),
176
+ APITestAction(method=HTTPMethod.GET, endpoint="/tasks",
177
+ query_params={"sort": "title"}, expected_status=200),
178
+ APITestAction(method=HTTPMethod.GET, endpoint="/users/2", expected_status=200),
179
+ # Phase 12: Password hash check
180
+ APITestAction(method=HTTPMethod.POST, endpoint="/users",
181
+ body={"username": "newuser2", "email": "valid@email.com", "password": "pass"},
182
+ expected_status=201),
183
+ ]
184
+
185
+
186
+ AGENTS = {
187
+ "random": RandomAgent,
188
+ "sequential": SequentialAgent,
189
+ "smart": SmartAgent,
190
+ }
training/evaluate.py ADDED
@@ -0,0 +1,318 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Evaluation and rollout runner.
4
+
5
+ - run_rollout(): Run a single episode with a HuggingFace model
6
+ - run_baseline_local(): Run baseline agents against the local environment
7
+ - run_episode(): Run one baseline episode against a remote server
8
+ - main(): CLI for running baselines
9
+ """
10
+
11
+ import argparse
12
+ import asyncio
13
+ import logging
14
+ import random
15
+ import sys
16
+ import os
17
+
18
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
19
+
20
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
21
+ logger = logging.getLogger(__name__)
22
+
23
+ from models import APITestAction, HTTPMethod
24
+ from server.environment import APITestEnvironment
25
+ from .prompts import (
26
+ PLAN_SYSTEM_PROMPT, format_plan_prompt,
27
+ parse_action, parse_test_plan,
28
+ )
29
+ from .agents import AGENTS
30
+
31
+
32
+ def run_rollout(
33
+ model,
34
+ tokenizer,
35
+ task_id: str = "basic_validation",
36
+ seed: int = 42,
37
+ max_steps: int | None = None,
38
+ ) -> dict:
39
+ """Run a single episode with a HuggingFace model.
40
+
41
+ Uses PLAN mode: the model generates a full test plan (JSON array) in one shot,
42
+ then all actions are executed sequentially. This matches how training works.
43
+
44
+ Falls back to multi-turn mode if the model can't produce a valid plan.
45
+ """
46
+ import torch
47
+ import time as _time
48
+
49
+ # Force GPU if available
50
+ if torch.cuda.is_available():
51
+ device = torch.device("cuda")
52
+ # Move model to GPU if it's on CPU
53
+ if next(model.parameters()).device.type == "cpu":
54
+ logger.info(" Moving model to GPU...")
55
+ model = model.to(device)
56
+ else:
57
+ device = next(model.parameters()).device
58
+
59
+ env = APITestEnvironment()
60
+ obs = env.reset(seed=seed, task_id=task_id)
61
+ actual_max = max_steps or obs.max_steps
62
+ logger.info(f" Rollout: {task_id} | max_steps={actual_max} | device={device}")
63
+
64
+ # --- Try plan mode first (matches training) ---
65
+ plan_prompt = format_plan_prompt(obs)
66
+ messages = [
67
+ {"role": "system", "content": PLAN_SYSTEM_PROMPT},
68
+ {"role": "user", "content": plan_prompt},
69
+ ]
70
+
71
+ # Qwen3 thinking support
72
+ chat_kwargs = {}
73
+ if "qwen3" in str(getattr(model, "name_or_path", "") or "").lower():
74
+ chat_kwargs["enable_thinking"] = True
75
+
76
+ prompt_text = tokenizer.apply_chat_template(
77
+ messages, tokenize=False, add_generation_prompt=True, **chat_kwargs,
78
+ )
79
+ inputs = tokenizer(prompt_text, return_tensors="pt").to(device)
80
+
81
+ gen_start = _time.time()
82
+ print(f" Generating test plan...", end="", flush=True)
83
+ with torch.no_grad():
84
+ output = model.generate(
85
+ **inputs,
86
+ max_new_tokens=4096, # Match training max_completion_length
87
+ temperature=0.7,
88
+ do_sample=True,
89
+ pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
90
+ )
91
+ completion = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
92
+ gen_time = _time.time() - gen_start
93
+ print(f" done ({gen_time:.1f}s, {len(completion)} chars)")
94
+
95
+ # Parse the plan
96
+ actions = parse_test_plan(completion)
97
+ if actions:
98
+ logger.info(f" Plan generated: {len(actions)} actions")
99
+ else:
100
+ # Fallback: try single action parse
101
+ single = parse_action(completion)
102
+ if single:
103
+ actions = [single]
104
+ logger.info(" Plan parse failed, got 1 action from fallback")
105
+ else:
106
+ logger.warning(" Failed to parse any actions from model output")
107
+ # Print first 500 chars of completion for debugging
108
+ preview = completion[:500].replace("\n", " ")
109
+ logger.warning(f" Model output preview: {preview}...")
110
+ actions = []
111
+
112
+ # Limit to max_steps
113
+ actions = actions[:actual_max]
114
+
115
+ # Execute all actions
116
+ total_reward = 0.0
117
+ for i, action in enumerate(actions):
118
+ try:
119
+ obs = env.step(action)
120
+ total_reward += obs.reward or 0.0
121
+ method_str = action.method.value if hasattr(action.method, "value") else str(action.method)
122
+ print(f" Step {i+1}/{len(actions)}: {method_str} {action.endpoint} -> "
123
+ f"{obs.status_code} | reward={obs.reward:.3f} | bugs={obs.bugs_found_so_far}")
124
+ except Exception as e:
125
+ print(f" Step {i+1}/{len(actions)}: ERROR - {e}")
126
+
127
+ # If no actions were generated, show that
128
+ if not actions:
129
+ print(" (no valid actions generated)")
130
+
131
+ state = env.state
132
+ return {
133
+ "task_id": task_id,
134
+ "seed": seed,
135
+ "steps": len(actions),
136
+ "total_reward": round(total_reward, 4),
137
+ "bugs_found": state.bugs_found,
138
+ "total_bugs": state.total_bugs,
139
+ "coverage_pct": state.coverage_pct,
140
+ "bugs_found_ids": state.bugs_found_ids,
141
+ }
142
+
143
+
144
+ def run_baseline_local(
145
+ agent_name: str = "all",
146
+ task_id: str = "all",
147
+ seed: int = 42,
148
+ ) -> list[dict]:
149
+ """Run baseline agents against the local environment (no server needed).
150
+
151
+ Args:
152
+ agent_name: "random", "sequential", "smart", or "all"
153
+ task_id: task ID or "all"
154
+ seed: random seed
155
+
156
+ Returns:
157
+ List of result dicts with agent, task_id, total_reward, bugs_found, etc.
158
+ """
159
+ tasks = ["basic_validation", "edge_cases", "security_workflows"] if task_id == "all" else [task_id]
160
+ agents = list(AGENTS.items()) if agent_name == "all" else [(agent_name, AGENTS[agent_name])]
161
+
162
+ results = []
163
+ for tid in tasks:
164
+ for aname, agent_cls in agents:
165
+ random.seed(seed)
166
+ agent = agent_cls()
167
+ env = APITestEnvironment()
168
+ obs = env.reset(seed=seed, task_id=tid)
169
+
170
+ total_reward = 0.0
171
+ step = 0
172
+
173
+ while not obs.done and step < obs.max_steps:
174
+ obs_dict = {
175
+ "status_code": obs.status_code,
176
+ "response_body": obs.response_body,
177
+ "feedback": obs.feedback,
178
+ "bugs_found_so_far": obs.bugs_found_so_far,
179
+ "coverage_summary": obs.coverage_summary,
180
+ "known_resource_ids": obs.known_resource_ids,
181
+ "auth_tokens": obs.auth_tokens,
182
+ "steps_taken": obs.steps_taken,
183
+ "max_steps": obs.max_steps,
184
+ }
185
+
186
+ action = agent.act(obs_dict)
187
+ obs = env.step(action)
188
+ total_reward += obs.reward or 0.0
189
+ step += 1
190
+
191
+ state = env.state
192
+ result = {
193
+ "agent": aname,
194
+ "task_id": tid,
195
+ "seed": seed,
196
+ "steps": step,
197
+ "total_reward": round(total_reward, 4),
198
+ "bugs_found": state.bugs_found,
199
+ "total_bugs": state.total_bugs,
200
+ "coverage_pct": state.coverage_pct,
201
+ "bugs_found_ids": state.bugs_found_ids,
202
+ }
203
+ results.append(result)
204
+ logger.info(
205
+ f" [{aname}] {tid}: reward={result['total_reward']:.4f}, "
206
+ f"bugs={result['bugs_found']}/{result['total_bugs']}, "
207
+ f"coverage={result['coverage_pct']:.1f}%"
208
+ )
209
+
210
+ return results
211
+
212
+
213
+ # =====================================================================
214
+ # Remote baseline runner (against server via WebSocket client)
215
+ # =====================================================================
216
+
217
+ async def run_episode(url: str, task_id: str, agent_cls, seed: int = 42) -> dict:
218
+ """Run one baseline episode against a remote server."""
219
+ from client import APITestEnv
220
+
221
+ random.seed(seed)
222
+ agent = agent_cls()
223
+
224
+ async with APITestEnv(base_url=url) as env:
225
+ result = await env.reset(task_id=task_id)
226
+ obs = result.observation
227
+
228
+ logger.info(f"Starting {agent.name} agent on task '{task_id}'")
229
+
230
+ total_reward = 0.0
231
+ step = 0
232
+
233
+ while not result.done:
234
+ obs_dict = {
235
+ "status_code": obs.status_code,
236
+ "response_body": obs.response_body,
237
+ "feedback": obs.feedback,
238
+ "bugs_found_so_far": obs.bugs_found_so_far,
239
+ "coverage_summary": obs.coverage_summary,
240
+ "known_resource_ids": obs.known_resource_ids,
241
+ "auth_tokens": obs.auth_tokens,
242
+ "steps_taken": obs.steps_taken,
243
+ "max_steps": obs.max_steps,
244
+ }
245
+
246
+ action = agent.act(obs_dict)
247
+ result = await env.step(action)
248
+ obs = result.observation
249
+ total_reward += result.reward or 0
250
+
251
+ step += 1
252
+ method = action.method.value if hasattr(action.method, "value") else str(action.method)
253
+ logger.info(
254
+ f" Step {step}: {method} {action.endpoint} -> "
255
+ f"{obs.status_code} | reward={result.reward:.4f} | bugs={obs.bugs_found_so_far}"
256
+ )
257
+
258
+ state = await env.state()
259
+ return {
260
+ "task_id": task_id,
261
+ "agent": agent.name,
262
+ "total_reward": round(total_reward, 4),
263
+ "bugs_found": state.bugs_found,
264
+ "total_bugs": state.total_bugs,
265
+ "coverage_pct": state.coverage_pct,
266
+ "steps": step,
267
+ }
268
+
269
+
270
+ async def main_async(args):
271
+ tasks = ["basic_validation", "edge_cases", "security_workflows"] if args.task == "all" else [args.task]
272
+ agents = list(AGENTS.values()) if args.agent == "all" else [AGENTS[args.agent]]
273
+
274
+ results = []
275
+ for task_id in tasks:
276
+ for agent_cls in agents:
277
+ try:
278
+ result = await run_episode(args.url, task_id, agent_cls, seed=args.seed)
279
+ results.append(result)
280
+ logger.info(
281
+ f"\nRESULT: {result['agent']} on {result['task_id']}: "
282
+ f"reward={result['total_reward']}, bugs={result['bugs_found']}/{result['total_bugs']}, "
283
+ f"coverage={result['coverage_pct']:.1f}%"
284
+ )
285
+ except Exception as e:
286
+ logger.error(f"Error running {agent_cls.name} on {task_id}: {e}", exc_info=True)
287
+
288
+ if results:
289
+ print("\n" + "=" * 80)
290
+ print("BASELINE RESULTS SUMMARY")
291
+ print("=" * 80)
292
+ print(f"{'Agent':<15} {'Task':<25} {'Score':<10} {'Bugs':<10} {'Coverage':<10}")
293
+ print("-" * 80)
294
+ for r in results:
295
+ print(
296
+ f"{r['agent']:<15} {r['task_id']:<25} "
297
+ f"{r['total_reward']:<10.4f} "
298
+ f"{r['bugs_found']}/{r['total_bugs']:<8} "
299
+ f"{r['coverage_pct']:<10.1f}%"
300
+ )
301
+ print("=" * 80)
302
+
303
+ return results
304
+
305
+
306
+ def main():
307
+ parser = argparse.ArgumentParser(description="Baseline agents for API Testing Environment")
308
+ parser.add_argument("--url", default="http://localhost:8000", help="Environment server URL")
309
+ parser.add_argument("--task", default="all",
310
+ choices=["basic_validation", "edge_cases", "security_workflows", "all"])
311
+ parser.add_argument("--agent", default="all", choices=["random", "sequential", "smart", "all"])
312
+ parser.add_argument("--seed", type=int, default=42)
313
+ args = parser.parse_args()
314
+ asyncio.run(main_async(args))
315
+
316
+
317
+ if __name__ == "__main__":
318
+ main()
training/grpo.py ADDED
@@ -0,0 +1,783 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ GRPO Training Script for the API Testing Environment.
4
+
5
+ Trains a small LLM (Qwen 1.7B) to become an intelligent API tester
6
+ using Group Relative Policy Optimization (GRPO).
7
+
8
+ The environment IS the dataset — each reset(seed=N) creates a unique
9
+ episode with different users, tasks, and data. No external dataset needed.
10
+
11
+ Features:
12
+ - Auto-push trained model weights to HuggingFace Hub
13
+ - Weights & Biases logging for metrics, loss, rewards
14
+ - Baseline agent evaluation before GRPO (random, sequential, smart)
15
+ - Base model evaluation before GRPO for comparison
16
+ - Post-training evaluation with delta reporting
17
+ - Saves metrics, comparison tables, and plots to output dir
18
+
19
+ Usage:
20
+ # Quick test (CPU, 2 minutes)
21
+ python -m training.grpo --test-mode
22
+
23
+ # Real training (GPU required)
24
+ python -m training.grpo --model-id Qwen/Qwen3-1.7B --num-episodes 100
25
+
26
+ # With HF Hub push
27
+ python -m training.grpo --push-to-hub --hf-repo-id your-username/api-tester-grpo
28
+
29
+ # With Weights & Biases
30
+ python -m training.grpo --use-wandb --wandb-project api-testing-grpo
31
+
32
+ # See what prompts look like (no GPU needed)
33
+ SHOW_PROMPTS=1 python -m training.grpo
34
+
35
+ # Resume from checkpoint
36
+ python -m training.grpo --model-id ./checkpoints/step_50
37
+ """
38
+
39
+ import argparse
40
+ import json
41
+ import logging
42
+ import os
43
+ import sys
44
+ import time
45
+
46
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
47
+
48
+ # --- Suppress noisy HTTP/download logs ---
49
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
50
+ logger = logging.getLogger(__name__)
51
+ for _noisy in ["httpx", "httpcore", "urllib3", "huggingface_hub", "filelock",
52
+ "transformers.configuration_utils", "transformers.modeling_utils"]:
53
+ logging.getLogger(_noisy).setLevel(logging.WARNING)
54
+
55
+ # --- MONKEY PATCH FOR LLM-BLENDER ---
56
+ # llm-blender requires TRANSFORMERS_CACHE which was removed in transformers 4.42+
57
+ try:
58
+ import transformers.utils.hub
59
+ if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
60
+ transformers.utils.hub.TRANSFORMERS_CACHE = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface/hub"))
61
+ except ImportError:
62
+ pass
63
+ # ------------------------------------
64
+
65
+ from server.environment import APITestEnvironment
66
+ from .prompts import PLAN_SYSTEM_PROMPT, format_plan_prompt
67
+ from .rewards import format_reward_fn, plan_reward_fn, diversity_reward_fn
68
+ from .evaluate import run_rollout, run_baseline_local
69
+
70
+
71
+ def build_training_prompts(
72
+ num_episodes: int = 50,
73
+ task_ids: list[str] | None = None,
74
+ ) -> list[dict]:
75
+ """Generate training prompts for GRPO plan-based training.
76
+
77
+ Each prompt asks the model to output a COMPLETE TEST PLAN (JSON array of actions).
78
+ The reward function will execute the plan on a fresh environment and score it.
79
+ """
80
+ if task_ids is None:
81
+ task_ids = ["basic_validation", "edge_cases", "security_workflows"]
82
+
83
+ prompts = []
84
+ env = APITestEnvironment()
85
+
86
+ for i in range(num_episodes):
87
+ task_id = task_ids[i % len(task_ids)]
88
+ seed = i * 1000 + 42
89
+
90
+ obs = env.reset(seed=seed, task_id=task_id)
91
+ user_message = format_plan_prompt(obs)
92
+
93
+ prompt_messages = [
94
+ {"role": "system", "content": PLAN_SYSTEM_PROMPT},
95
+ {"role": "user", "content": user_message},
96
+ ]
97
+
98
+ prompts.append({
99
+ "prompt": prompt_messages,
100
+ "task_id": task_id,
101
+ "seed": seed,
102
+ })
103
+
104
+ logger.info(f"Generated {len(prompts)} training prompts across tasks: {task_ids}")
105
+ return prompts
106
+
107
+
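The episode loop above cycles round-robin through the task list while deriving a distinct deterministic seed per episode. A quick sketch of the task/seed assignment for the first five episodes:

```python
# Mirror of build_training_prompts' assignment: task cycles, seed is unique.
task_ids = ["basic_validation", "edge_cases", "security_workflows"]
pairs = [(task_ids[i % len(task_ids)], i * 1000 + 42) for i in range(5)]
print(pairs[0])  # ('basic_validation', 42)
print(pairs[3])  # ('basic_validation', 3042)
```

Episode 0 and episode 3 land on the same task but get different seeds, so repeated tasks still produce distinct environment resets.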
108
+ def run_baseline_evaluation(seed: int = 9999) -> dict:
109
+ """Run all baseline agents and return results for comparison.
110
+
111
+ Returns:
112
+ dict with structure: {agent_name: {task_id: result_dict}}
113
+ """
114
+ logger.info("=" * 60)
115
+ logger.info("Running BASELINE AGENT evaluation...")
116
+ logger.info("=" * 60)
117
+
118
+ results = run_baseline_local(agent_name="all", task_id="all", seed=seed)
119
+
120
+ # Organize by agent -> task
121
+ organized = {}
122
+ for r in results:
123
+ agent = r["agent"]
124
+ if agent not in organized:
125
+ organized[agent] = {}
126
+ organized[agent][r["task_id"]] = r
127
+
128
+ # Print summary table
129
+ print("\n" + "=" * 90)
130
+ print("BASELINE AGENT RESULTS")
131
+ print("=" * 90)
132
+ print(f"{'Agent':<15} {'Task':<25} {'Reward':<10} {'Bugs':<12} {'Coverage':<10}")
133
+ print("-" * 90)
134
+ for agent_name in ["random", "sequential", "smart"]:
135
+ if agent_name not in organized:
136
+ continue
137
+ for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
138
+ r = organized[agent_name].get(task_id, {})
139
+ print(
140
+ f"{agent_name:<15} {task_id:<25} "
141
+ f"{r.get('total_reward', 0):<10.4f} "
142
+ f"{r.get('bugs_found', 0)}/{r.get('total_bugs', 0):<10} "
143
+ f"{r.get('coverage_pct', 0):<10.1f}%"
144
+ )
145
+ print("-" * 90)
146
+ print("=" * 90 + "\n")
147
+
148
+ return organized
149
+
150
+
151
+ def save_metrics(
152
+ output_dir: str,
153
+ baseline_results: dict,
154
+ base_model_results: dict,
155
+ trained_model_results: dict,
156
+ training_args: dict,
157
+ training_time_s: float,
158
+ ):
159
+ """Save all metrics and comparison data to output_dir/metrics/."""
160
+ metrics_dir = os.path.join(output_dir, "metrics")
161
+ os.makedirs(metrics_dir, exist_ok=True)
162
+
163
+ # Full results JSON
164
+ all_results = {
165
+ "training_args": training_args,
166
+ "training_time_seconds": round(training_time_s, 1),
167
+ "baseline_agents": {},
168
+ "base_model": base_model_results,
169
+ "trained_model": trained_model_results,
170
+ }
171
+
172
+ # Flatten baseline results
173
+ for agent_name, tasks in baseline_results.items():
174
+ all_results["baseline_agents"][agent_name] = {}
175
+ for task_id, r in tasks.items():
176
+ all_results["baseline_agents"][agent_name][task_id] = {
177
+ "total_reward": r.get("total_reward", 0),
178
+ "bugs_found": r.get("bugs_found", 0),
179
+ "total_bugs": r.get("total_bugs", 0),
180
+ "coverage_pct": r.get("coverage_pct", 0),
181
+ }
182
+
183
+ with open(os.path.join(metrics_dir, "results.json"), "w") as f:
184
+ json.dump(all_results, f, indent=2)
185
+
186
+ # Comparison table as markdown
187
+ md_lines = ["# Training Results\n"]
188
+ md_lines.append(f"**Model**: {training_args.get('model_id', 'unknown')}")
189
+ md_lines.append(f"**Training time**: {training_time_s / 60:.1f} minutes")
190
+ md_lines.append(f"**Episodes**: {training_args.get('num_episodes', 0)}")
191
+ md_lines.append(f"**Max steps**: {training_args.get('max_steps', 0)}\n")
192
+
193
+ md_lines.append("## Comparison Table\n")
194
+ md_lines.append("| Agent/Model | Task | Reward | Bugs | Coverage |")
195
+ md_lines.append("|---|---|---|---|---|")
196
+
197
+ # Baselines
198
+ for agent_name in ["random", "sequential", "smart"]:
199
+ if agent_name not in baseline_results:
200
+ continue
201
+ for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
202
+ r = baseline_results[agent_name].get(task_id, {})
203
+ md_lines.append(
204
+ f"| {agent_name} | {task_id} | "
205
+ f"{r.get('total_reward', 0):.4f} | "
206
+ f"{r.get('bugs_found', 0)}/{r.get('total_bugs', 0)} | "
207
+ f"{r.get('coverage_pct', 0):.1f}% |"
208
+ )
209
+
210
+ # Base model
211
+ for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
212
+ r = base_model_results.get(task_id, {})
213
+ md_lines.append(
214
+ f"| **base model** | {task_id} | "
215
+ f"{r.get('total_reward', 0):.4f} | "
216
+ f"{r.get('bugs_found', 0)}/{r.get('total_bugs', 0)} | "
217
+ f"{r.get('coverage_pct', 0):.1f}% |"
218
+ )
219
+
220
+ # Trained model
221
+ for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
222
+ r = trained_model_results.get(task_id, {})
223
+ base = base_model_results.get(task_id, {})
224
+ delta = r.get("total_reward", 0) - base.get("total_reward", 0)
225
+ md_lines.append(
226
+ f"| **GRPO trained** | {task_id} | "
227
+ f"{r.get('total_reward', 0):.4f} ({delta:+.4f}) | "
228
+ f"{r.get('bugs_found', 0)}/{r.get('total_bugs', 0)} | "
229
+ f"{r.get('coverage_pct', 0):.1f}% |"
230
+ )
231
+
232
+ md_lines.append("")
233
+ with open(os.path.join(metrics_dir, "results.md"), "w") as f:
234
+ f.write("\n".join(md_lines))
235
+
236
+ logger.info(f"Metrics saved to {metrics_dir}/")
237
+
238
+
239
+ def save_plots(output_dir: str, baseline_results: dict, base_model_results: dict, trained_model_results: dict):
240
+ """Generate and save comparison plots."""
241
+ try:
242
+ import matplotlib
243
+ matplotlib.use("Agg")
244
+ import matplotlib.pyplot as plt
245
+ import numpy as np
246
+ except ImportError:
247
+ logger.warning("matplotlib not installed — skipping plot generation. pip install matplotlib")
248
+ return
249
+
250
+ plots_dir = os.path.join(output_dir, "metrics", "plots")
251
+ os.makedirs(plots_dir, exist_ok=True)
252
+
253
+ tasks = ["basic_validation", "edge_cases", "security_workflows"]
254
+ task_labels = ["Basic", "Edge Cases", "Security"]
255
+
256
+ # --- Plot 1: Reward comparison bar chart ---
257
+ fig, ax = plt.subplots(figsize=(12, 6))
258
+ x = np.arange(len(tasks))
259
+ width = 0.15
260
+
261
+ agents_to_plot = []
262
+ for agent_name in ["random", "sequential", "smart"]:
263
+ if agent_name in baseline_results:
264
+ rewards = [baseline_results[agent_name].get(t, {}).get("total_reward", 0) for t in tasks]
265
+ agents_to_plot.append((agent_name, rewards))
266
+
267
+ base_rewards = [base_model_results.get(t, {}).get("total_reward", 0) for t in tasks]
268
+ agents_to_plot.append(("Base Model", base_rewards))
269
+
270
+ trained_rewards = [trained_model_results.get(t, {}).get("total_reward", 0) for t in tasks]
271
+ agents_to_plot.append(("GRPO Trained", trained_rewards))
272
+
273
+ colors = ["#95a5a6", "#3498db", "#e67e22", "#9b59b6", "#2ecc71"]
274
+ for i, (name, rewards) in enumerate(agents_to_plot):
275
+ offset = (i - len(agents_to_plot) / 2 + 0.5) * width
276
+ bars = ax.bar(x + offset, rewards, width, label=name, color=colors[i % len(colors)])
277
+ for bar, val in zip(bars, rewards):
278
+ if val > 0.01:
279
+ ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
280
+ f"{val:.2f}", ha="center", va="bottom", fontsize=7)
281
+
282
+ ax.set_xlabel("Task")
283
+ ax.set_ylabel("Total Reward")
284
+ ax.set_title("Reward Comparison: Baselines vs Base Model vs GRPO Trained")
285
+ ax.set_xticks(x)
286
+ ax.set_xticklabels(task_labels)
287
+ ax.legend()
288
+ ax.set_ylim(bottom=0)
289
+ plt.tight_layout()
290
+ fig.savefig(os.path.join(plots_dir, "reward_comparison.png"), dpi=150)
291
+ plt.close(fig)
292
+
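The `offset` arithmetic centers each group of bars on its x tick. A small check of the offsets produced for five series (3 baselines + base model + GRPO trained) at `width = 0.15`:

```python
# Offsets are symmetric around 0, so the bar group straddles the tick.
width = 0.15
n_series = 5
offsets = [round((i - n_series / 2 + 0.5) * width, 2) for i in range(n_series)]
print(offsets)  # [-0.3, -0.15, 0.0, 0.15, 0.3]
```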
293
+ # --- Plot 2: Bugs found comparison ---
294
+ fig, ax = plt.subplots(figsize=(12, 6))
295
+ for i, (name, _) in enumerate(agents_to_plot):
296
+ if name in baseline_results:
297
+ bugs = [baseline_results[name].get(t, {}).get("bugs_found", 0) for t in tasks]
298
+ elif name == "Base Model":
299
+ bugs = [base_model_results.get(t, {}).get("bugs_found", 0) for t in tasks]
300
+ else:
301
+ bugs = [trained_model_results.get(t, {}).get("bugs_found", 0) for t in tasks]
302
+ offset = (i - len(agents_to_plot) / 2 + 0.5) * width
303
+ ax.bar(x + offset, bugs, width, label=name, color=colors[i % len(colors)])
304
+
305
+ total_bugs = [base_model_results.get(t, {}).get("total_bugs", 0) or
306
+ trained_model_results.get(t, {}).get("total_bugs", 0) for t in tasks]
307
+ ax.plot(x, total_bugs, "k--", marker="D", label="Total Bugs", linewidth=1.5)
308
+
309
+ ax.set_xlabel("Task")
310
+ ax.set_ylabel("Bugs Found")
311
+ ax.set_title("Bug Discovery: Baselines vs Base Model vs GRPO Trained")
312
+ ax.set_xticks(x)
313
+ ax.set_xticklabels(task_labels)
314
+ ax.legend()
315
+ ax.set_ylim(bottom=0)
316
+ plt.tight_layout()
317
+ fig.savefig(os.path.join(plots_dir, "bugs_comparison.png"), dpi=150)
318
+ plt.close(fig)
319
+
320
+ # --- Plot 3: Coverage comparison ---
321
+ fig, ax = plt.subplots(figsize=(12, 6))
322
+ for i, (name, _) in enumerate(agents_to_plot):
323
+ if name in baseline_results:
324
+ cov = [baseline_results[name].get(t, {}).get("coverage_pct", 0) for t in tasks]
325
+ elif name == "Base Model":
326
+ cov = [base_model_results.get(t, {}).get("coverage_pct", 0) for t in tasks]
327
+ else:
328
+ cov = [trained_model_results.get(t, {}).get("coverage_pct", 0) for t in tasks]
329
+ offset = (i - len(agents_to_plot) / 2 + 0.5) * width
330
+ ax.bar(x + offset, cov, width, label=name, color=colors[i % len(colors)])
331
+
332
+ ax.set_xlabel("Task")
333
+ ax.set_ylabel("Coverage %")
334
+ ax.set_title("API Coverage: Baselines vs Base Model vs GRPO Trained")
335
+ ax.set_xticks(x)
336
+ ax.set_xticklabels(task_labels)
337
+ ax.legend()
338
+ ax.set_ylim(0, 105)
339
+ plt.tight_layout()
340
+ fig.savefig(os.path.join(plots_dir, "coverage_comparison.png"), dpi=150)
341
+ plt.close(fig)
342
+
343
+ logger.info(f"Plots saved to {plots_dir}/")
344
+
345
+
346
+ def train_grpo(args):
347
+ """Run GRPO training with TRL."""
348
+ try:
349
+ from datasets import Dataset
350
+ from peft import LoraConfig
351
+ from transformers import AutoModelForCausalLM, AutoTokenizer
352
+ from trl import GRPOConfig, GRPOTrainer
353
+
354
+ # --- MONKEY PATCH FOR TRL GRPOTrainer ---
355
+ # trl 0.15's `_get_train_sampler` does not accept the `dataset` argument that transformers 4.57+ passes
356
+ import inspect
357
+ if hasattr(GRPOTrainer, "_get_train_sampler"):
358
+ sig = inspect.signature(GRPOTrainer._get_train_sampler)
359
+ if "dataset" not in sig.parameters:
360
+ _old_sampler = GRPOTrainer._get_train_sampler
361
+ def _new_sampler(self, dataset=None, **kwargs):
362
+ return _old_sampler(self)
363
+ GRPOTrainer._get_train_sampler = _new_sampler
364
+ # ----------------------------------------
365
+ except ImportError as e:
366
+ logger.error(
367
+ f"Missing dependency: {e}\n"
368
+ "Install with: pip install trl transformers peft datasets torch"
369
+ )
370
+ sys.exit(1)
371
+
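The signature-compatibility shim above can be exercised without trl installed. A standalone sketch of the same pattern, using a hypothetical `Trainer` class with the old no-argument API:

```python
import inspect

class Trainer:
    # Hypothetical old-API method: accepts no dataset argument.
    def _get_train_sampler(self):
        return "sampler"

# Same shim as in train_grpo: if the installed method does not accept
# `dataset`, wrap it so newer callers that pass one still work.
if "dataset" not in inspect.signature(Trainer._get_train_sampler).parameters:
    _old_sampler = Trainer._get_train_sampler
    def _new_sampler(self, dataset=None, **kwargs):
        return _old_sampler(self)
    Trainer._get_train_sampler = _new_sampler

print(Trainer()._get_train_sampler(dataset=[1, 2, 3]))  # sampler
```

The wrapper simply drops the extra arguments, which is safe here because the old implementation never used them.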
372
+ # --- W&B setup ---
373
+ wandb_run = None
374
+ report_to = "none"
375
+ if args.use_wandb:
376
+ try:
377
+ import wandb
378
+ wandb_run = wandb.init(
379
+ project=args.wandb_project,
380
+ name=args.wandb_run_name or f"grpo-{args.model_id.split('/')[-1]}-{int(time.time())}",
381
+ config={
382
+ "model_id": args.model_id,
383
+ "num_episodes": args.num_episodes,
384
+ "num_generations": args.num_generations,
385
+ "max_steps": args.max_steps,
386
+ "learning_rate": args.learning_rate,
387
+ "batch_size": args.batch_size,
388
+ "max_completion_length": args.max_completion_length,
389
+ "lora_r": 16,
390
+ "lora_alpha": 32,
391
+ },
392
+ )
393
+ report_to = "wandb"
394
+ logger.info(f"W&B initialized: project={args.wandb_project}, run={wandb_run.name}")
395
+ except ImportError:
396
+ logger.warning("wandb not installed — skipping W&B logging. pip install wandb")
397
+ args.use_wandb = False
398
+
399
+ training_args_dict = {
400
+ "model_id": args.model_id,
401
+ "num_episodes": args.num_episodes,
402
+ "num_generations": args.num_generations,
403
+ "max_steps": args.max_steps,
404
+ "learning_rate": args.learning_rate,
405
+ "batch_size": args.batch_size,
406
+ "max_completion_length": args.max_completion_length,
407
+ "output_dir": args.output_dir,
408
+ "test_mode": args.test_mode,
409
+ }
410
+
411
+ # ================================================================
412
+ # PIPELINE OVERVIEW
413
+ # ================================================================
414
+ total_pipeline_steps = 11
415
+ def _step(n, msg):
416
+ bar = "█" * n + "░" * (total_pipeline_steps - n)
417
+ print(f"\n{'='*70}")
418
+ print(f" [{bar}] Step {n}/{total_pipeline_steps}: {msg}")
419
+ print(f"{'='*70}\n")
420
+
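The banner helper renders pipeline progress as a fixed-width bar of filled and empty cells. For example, at step 3 of 11:

```python
# Same bar construction as _step: n filled cells, the rest empty.
total_pipeline_steps = 11
n = 3
bar = "█" * n + "░" * (total_pipeline_steps - n)
line = f"[{bar}] Step {n}/{total_pipeline_steps}"
print(line)  # [███░░░░░░░░] Step 3/11
```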
421
+ # --- Step 1: Run baseline agent evaluation ---
422
+ _step(1, "Running baseline agents (random, sequential, smart)")
423
+ baseline_results = run_baseline_evaluation(seed=9999)
424
+
425
+ if args.use_wandb and wandb_run:
426
+ import wandb
427
+ for agent_name, tasks in baseline_results.items():
428
+ for task_id, r in tasks.items():
429
+ wandb.log({
430
+ f"baseline/{agent_name}/{task_id}/reward": r["total_reward"],
431
+ f"baseline/{agent_name}/{task_id}/bugs": r["bugs_found"],
432
+ f"baseline/{agent_name}/{task_id}/coverage": r["coverage_pct"],
433
+ })
434
+
435
+ # --- Step 2: Load model and tokenizer ---
436
+ _step(2, f"Loading model: {args.model_id}")
437
+ print(" Downloading tokenizer...", flush=True)
438
+ tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)
439
+ if tokenizer.pad_token is None:
440
+ tokenizer.pad_token = tokenizer.eos_token
441
+ print(" Tokenizer loaded.", flush=True)
442
+
443
+ import torch
444
+
445
+ # --- Force GPU detection ---
446
+ if torch.cuda.is_available():
447
+ device_map = "auto"
448
+ dtype = torch.bfloat16
449
+ gpu_name = torch.cuda.get_device_name(0)
450
+ gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
451
+ print(f" GPU: {gpu_name} ({gpu_mem:.1f} GB)", flush=True)
452
+ print(f" CUDA version: {torch.version.cuda}", flush=True)
453
+ elif torch.backends.mps.is_available():
454
+ device_map = "auto"
455
+ dtype = torch.float16
456
+ print(" Device: Apple MPS", flush=True)
457
+ else:
458
+ # No usable accelerator detected; falling back to CPU with float32.
459
+ # If a CUDA GPU is installed but undetected, fix the driver rather than train on CPU.
460
+ device_map = None
461
+ dtype = torch.float32
462
+ print(" !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!", flush=True)
463
+ print(" !! WARNING: No GPU detected — running on CPU !!", flush=True)
464
+ print(" !! Training will be EXTREMELY slow. !!", flush=True)
465
+ print(" !! Check: python -c 'import torch; print(torch.cuda.is_available())'", flush=True)
466
+ print(" !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!", flush=True)
467
+
468
+ print(" Downloading model weights...", flush=True)
469
+ model = AutoModelForCausalLM.from_pretrained(
470
+ args.model_id,
471
+ trust_remote_code=True,
472
+ torch_dtype=dtype,
473
+ device_map=device_map,
474
+ )
475
+
476
+ # Verify model is actually on GPU
477
+ actual_device = next(model.parameters()).device
478
+ param_count = sum(p.numel() for p in model.parameters()) / 1e6
479
+ print(f" Model loaded: {param_count:.0f}M parameters on {actual_device}", flush=True)
480
+
481
+ if torch.cuda.is_available() and actual_device.type != "cuda":
482
+ print(" Model not on GPU — forcing move to CUDA...", flush=True)
483
+ model = model.to("cuda")
484
+ print(f" Moved to: {next(model.parameters()).device}", flush=True)
485
+
486
+ # --- Step 3: Evaluate base model BEFORE training ---
487
+ _step(3, f"Evaluating BASE model (before GRPO, max {args.eval_max_steps} steps/task)")
488
+ base_results = {}
489
+ if not args.skip_eval:
490
+ for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
491
+ result = run_rollout(model, tokenizer, task_id=task_id, seed=9999, max_steps=args.eval_max_steps)
492
+ base_results[task_id] = result
493
+ logger.info(
494
+ f" [BASE] {task_id}: reward={result['total_reward']:.3f}, "
495
+ f"bugs={result['bugs_found']}/{result['total_bugs']}, "
496
+ f"coverage={result['coverage_pct']:.1f}%"
497
+ )
498
+ if args.use_wandb and wandb_run:
499
+ import wandb
500
+ wandb.log({
501
+ f"base_model/{task_id}/reward": result["total_reward"],
502
+ f"base_model/{task_id}/bugs": result["bugs_found"],
503
+ f"base_model/{task_id}/coverage": result["coverage_pct"],
504
+ })
505
+ else:
506
+ logger.info("Skipping base model evaluation (--skip-eval)")
507
+ for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
508
+ base_results[task_id] = {"total_reward": 0, "bugs_found": 0, "total_bugs": 0, "coverage_pct": 0}
509
+
510
+ # --- Step 4: LoRA config ---
511
+ _step(4, "Configuring LoRA adapters")
512
+ lora_config = LoraConfig(
513
+ r=16,
514
+ lora_alpha=32,
515
+ lora_dropout=0.05,
516
+ target_modules=["q_proj", "v_proj"],
517
+ task_type="CAUSAL_LM",
518
+ )
519
+ print(f" LoRA: r=16, alpha=32, targets=q_proj+v_proj", flush=True)
520
+
521
+ # --- Step 5: Generate training prompts ---
522
+ _step(5, f"Generating {args.num_episodes} training episodes")
523
+ raw_prompts = build_training_prompts(num_episodes=args.num_episodes)
524
+ print(f" {len(raw_prompts)} prompts across 3 tasks (each with unique seed)", flush=True)
525
+
526
+ # Qwen3 thinking mode: let the model reason before outputting JSON
527
+ # Requires higher max_completion_length (~2048) to fit <think>...</think> + JSON
528
+ chat_template_kwargs = {}
529
+ if "qwen3" in args.model_id.lower():
530
+ chat_template_kwargs["enable_thinking"] = True
531
+ logger.info("Qwen3 detected — thinking mode ENABLED (model will reason before acting)")
532
+
533
+ formatted_prompts = []
534
+ for p in raw_prompts:
535
+ text = tokenizer.apply_chat_template(
536
+ p["prompt"], tokenize=False, add_generation_prompt=True,
537
+ **chat_template_kwargs,
538
+ )
539
+ formatted_prompts.append({"prompt": text, "task_id": p["task_id"], "seed": p["seed"]})
540
+
541
+ dataset = Dataset.from_list(formatted_prompts)
542
+
543
+ # Store prompt metadata for the reward function to create fresh envs
544
+ prompts_meta = [{"seed": p["seed"], "task_id": p["task_id"]} for p in raw_prompts]
545
+
546
+ # Combined reward: format (valid JSON array?) + plan (execute all actions) + diversity (varied requests?)
547
+ # Each generation gets a FRESH environment — no shared state pollution
548
+ def combined_reward_fn(completions, **kwargs):
549
+ fmt = format_reward_fn(completions)
550
+ plan = plan_reward_fn(completions, prompts_meta=prompts_meta)
551
+ div = diversity_reward_fn(completions)
552
+ return [f + p + d for f, p, d in zip(fmt, plan, div)]
553
+
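The combined reward is a per-completion elementwise sum of the three component rewards. A tiny sketch with stubbed reward lists for three completions:

```python
# Stub per-completion scores (format, plan, diversity); the combined
# reward sums them position by position, one total per completion.
fmt = [0.5, 1.0, 0.0]
plan = [2.0, 0.5, 1.5]
div = [0.25, 0.25, 0.0]
combined = [f + p + d for f, p, d in zip(fmt, plan, div)]
print(combined)  # [2.75, 1.75, 1.5]
```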
554
+ # --- Step 6: GRPO training ---
555
+ _step(6, f"GRPO training ({args.max_steps} steps, {args.num_generations} generations/prompt)")
556
+ config = GRPOConfig(
557
+ output_dir=args.output_dir,
558
+ num_generations=args.num_generations,
559
+ max_completion_length=args.max_completion_length,
560
+ learning_rate=args.learning_rate,
561
+ per_device_train_batch_size=args.batch_size,
562
+ num_train_epochs=1,
563
+ max_steps=args.max_steps,
564
+ logging_steps=5,
565
+ save_steps=50,
566
+ save_total_limit=3,
567
+ report_to=report_to,
568
+ temperature=0.8,
569
+ )
570
+
571
+ trainer = GRPOTrainer(
572
+ model=model,
573
+ args=config,
574
+ reward_funcs=[combined_reward_fn],
575
+ train_dataset=dataset,
576
+ peft_config=lora_config,
577
+ processing_class=tokenizer,
578
+ )
579
+
580
+ print(f" Config: lr={args.learning_rate}, batch={args.batch_size}, "
581
+ f"max_completion={args.max_completion_length}, temp=0.8", flush=True)
582
+ print(f" Rewards: format_reward + plan_reward + diversity_reward", flush=True)
583
+ print(f" Training begins... (progress bar below)\n", flush=True)
584
+
585
+ train_start = time.time()
586
+ trainer.train()
587
+ training_time = time.time() - train_start
588
+ print(f"\n Training completed in {training_time / 60:.1f} minutes", flush=True)
589
+
590
+ # --- Step 7: Save model locally ---
591
+ _step(7, f"Saving model to {args.output_dir}")
592
+ trainer.save_model(args.output_dir)
593
+ tokenizer.save_pretrained(args.output_dir)
594
+ print(f" Model + tokenizer saved.", flush=True)
595
+
596
+ # --- Step 8: Push to HuggingFace Hub ---
597
+ _step(8, "Pushing to HuggingFace Hub" if args.push_to_hub else "HF Hub push (skipped — use --push-to-hub)")
598
+ if args.push_to_hub:
599
+ hf_repo = args.hf_repo_id
600
+ if not hf_repo:
601
+ logger.error("--hf-repo-id is required when using --push-to-hub")
602
+ else:
603
+ try:
604
+ logger.info(f"Pushing model to HuggingFace Hub: {hf_repo}")
605
+ trainer.push_to_hub(repo_id=hf_repo, commit_message="GRPO trained API testing agent")
606
+ tokenizer.push_to_hub(repo_id=hf_repo, commit_message="GRPO trained API testing agent")
607
+ logger.info(f"Model pushed to https://huggingface.co/{hf_repo}")
608
+ except Exception as e:
609
+ logger.error(f"Failed to push to HF Hub: {e}")
610
+ logger.info("Make sure you're logged in: huggingface-cli login")
611
+
612
+ # --- Step 9: Evaluate AFTER training ---
613
+ _step(9, f"Evaluating TRAINED model (max {args.eval_max_steps} steps/task)")
614
+ trained_results = {}
615
+ if not args.skip_eval:
616
+ for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
617
+ result = run_rollout(model, tokenizer, task_id=task_id, seed=9999, max_steps=args.eval_max_steps)
618
+ trained_results[task_id] = result
619
+ base = base_results[task_id]
620
+ reward_delta = result["total_reward"] - base.get("total_reward", 0)
621
+ bug_delta = result["bugs_found"] - base.get("bugs_found", 0)
622
+ cov_delta = result["coverage_pct"] - base.get("coverage_pct", 0)
623
+ logger.info(
624
+ f" [TRAINED] {task_id}: reward={result['total_reward']:.3f} ({reward_delta:+.3f}), "
625
+ f"bugs={result['bugs_found']}/{result['total_bugs']} ({bug_delta:+d}), "
626
+ f"coverage={result['coverage_pct']:.1f}% ({cov_delta:+.1f}%)"
627
+ )
628
+ if args.use_wandb and wandb_run:
629
+ import wandb
630
+ wandb.log({
631
+ f"trained_model/{task_id}/reward": result["total_reward"],
632
+ f"trained_model/{task_id}/bugs": result["bugs_found"],
633
+ f"trained_model/{task_id}/coverage": result["coverage_pct"],
634
+ f"delta/{task_id}/reward": reward_delta,
635
+ f"delta/{task_id}/bugs": bug_delta,
636
+ f"delta/{task_id}/coverage": cov_delta,
637
+ })
638
+ else:
639
+ logger.info("Skipping trained model evaluation (--skip-eval)")
640
+ for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
641
+ trained_results[task_id] = {"total_reward": 0, "bugs_found": 0, "total_bugs": 0, "coverage_pct": 0}
642
+
643
+ # --- Step 10: Print final comparison table ---
644
+ _step(10, "Results comparison table")
645
+ print("=" * 95)
646
+ print("FINAL COMPARISON: All Agents & Models")
647
+ print("=" * 95)
648
+ print(f"{'Agent/Model':<18} {'Task':<25} {'Reward':<10} {'Bugs':<12} {'Coverage':<10}")
649
+ print("-" * 95)
650
+
651
+ for agent_name in ["random", "sequential", "smart"]:
652
+ if agent_name in baseline_results:
653
+ for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
654
+ r = baseline_results[agent_name].get(task_id, {})
655
+ print(
656
+ f"{agent_name:<18} {task_id:<25} "
657
+ f"{r.get('total_reward', 0):<10.4f} "
658
+ f"{r.get('bugs_found', 0)}/{r.get('total_bugs', 0):<10} "
659
+ f"{r.get('coverage_pct', 0):<10.1f}%"
660
+ )
661
+ print("-" * 95)
662
+
663
+ for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
664
+ r = base_results[task_id]
665
+ print(
666
+ f"{'Base Model':<18} {task_id:<25} "
667
+ f"{r['total_reward']:<10.4f} "
668
+ f"{r['bugs_found']}/{r['total_bugs']:<10} "
669
+ f"{r['coverage_pct']:<10.1f}%"
670
+ )
671
+ print("-" * 95)
672
+
673
+ for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
674
+ r = trained_results[task_id]
675
+ base = base_results[task_id]
676
+ delta = r["total_reward"] - base["total_reward"]
677
+ print(
678
+ f"{'GRPO Trained':<18} {task_id:<25} "
679
+ f"{r['total_reward']:<10.4f} "
680
+ f"{r['bugs_found']}/{r['total_bugs']:<10} "
681
+ f"{r['coverage_pct']:<10.1f}% ({delta:+.4f})"
682
+ )
683
+ print("=" * 95)
684
+
685
+ # --- Step 11: Save metrics & plots ---
686
+ _step(11, "Saving metrics, plots, and finalizing")
687
+ save_metrics(
688
+ output_dir=args.output_dir,
689
+ baseline_results=baseline_results,
690
+ base_model_results=base_results,
691
+ trained_model_results=trained_results,
692
+ training_args=training_args_dict,
693
+ training_time_s=training_time,
694
+ )
695
+ save_plots(
696
+ output_dir=args.output_dir,
697
+ baseline_results=baseline_results,
698
+ base_model_results=base_results,
699
+ trained_model_results=trained_results,
700
+ )
701
+
702
+ # --- Finalize W&B ---
703
+ if args.use_wandb and wandb_run:
704
+ import wandb
705
+ # Log plots as artifacts
706
+ plots_dir = os.path.join(args.output_dir, "metrics", "plots")
707
+ if os.path.exists(plots_dir):
708
+ for fname in os.listdir(plots_dir):
709
+ if fname.endswith(".png"):
710
+ wandb.log({f"plots/{fname.replace('.png', '')}": wandb.Image(os.path.join(plots_dir, fname))})
711
+ wandb.finish()
712
+
713
+ # ================================================================
714
+ print(f"\n{'='*70}")
715
+ print(f" PIPELINE COMPLETE")
716
+ print(f" Training time: {training_time / 60:.1f} minutes")
717
+ print(f" Model saved to: {args.output_dir}")
718
+ print(f" Metrics: {args.output_dir}/metrics/")
719
+ print(f" Plots: {args.output_dir}/metrics/plots/")
720
+ if args.use_wandb:
721
+ print(f" W&B: https://wandb.ai/{args.wandb_project}")
722
+ if args.push_to_hub and args.hf_repo_id:
723
+ print(f" HF Hub: https://huggingface.co/{args.hf_repo_id}")
724
+ print(f"{'='*70}\n")
725
+
726
+
727
+ def main():
728
+ parser = argparse.ArgumentParser(description="GRPO Training for API Testing Agent")
729
+
730
+ # Model & training
731
+ parser.add_argument("--model-id", default="Qwen/Qwen3-1.7B", help="Base model to fine-tune")
732
+ parser.add_argument("--output-dir", default="./checkpoints/grpo_api_tester")
733
+ parser.add_argument("--num-episodes", type=int, default=50, help="Number of training episodes")
734
+ parser.add_argument("--num-generations", type=int, default=4, help="GRPO parallel rollouts per prompt")
735
+ parser.add_argument("--max-completion-length", type=int, default=4096,
736
+ help="Max tokens per generation. 4096 needed for Qwen3 thinking + JSON plan")
737
+ parser.add_argument("--max-steps", type=int, default=200, help="Max training steps")
738
+ parser.add_argument("--learning-rate", type=float, default=2e-5)
739
+ parser.add_argument("--batch-size", type=int, default=4)
740
+ parser.add_argument("--test-mode", action="store_true", help="Quick test with tiny config")
741
+
742
+ # HuggingFace Hub
743
+ parser.add_argument("--push-to-hub", action="store_true", help="Push trained model to HF Hub")
744
+ parser.add_argument("--hf-repo-id", type=str, default=None,
745
+ help="HF Hub repo ID (e.g., your-username/api-tester-grpo)")
746
+
747
+ # Evaluation
748
+ parser.add_argument("--skip-eval", action="store_true", help="Skip base/trained model evaluation")
749
+ parser.add_argument("--eval-max-steps", type=int, default=10,
750
+ help="Max steps per task during evaluation (default: 10, reduces eval time)")
751
+
752
+ # Weights & Biases
753
+ parser.add_argument("--use-wandb", action="store_true", help="Enable Weights & Biases logging")
754
+ parser.add_argument("--wandb-project", type=str, default="api-testing-grpo",
755
+ help="W&B project name")
756
+ parser.add_argument("--wandb-run-name", type=str, default=None,
757
+ help="W&B run name (auto-generated if not set)")
758
+
759
+ args = parser.parse_args()
760
+
761
+ if args.test_mode:
762
+ logger.info("=== TEST MODE — quick sanity check ===")
763
+ args.num_episodes = 3
764
+ args.num_generations = 4
765
+ args.batch_size = 2
766
+ args.max_steps = 10
767
+ args.max_completion_length = 2048
768
+
769
+ if os.environ.get("SHOW_PROMPTS"):
770
+ prompts = build_training_prompts(num_episodes=3)
771
+ for p in prompts:
772
+ print(f"\n{'='*60}")
773
+ print(f"Task: {p['task_id']} | Seed: {p['seed']}")
774
+ print(f"{'='*60}")
775
+ for msg in p["prompt"]:
776
+ print(f"[{msg['role']}]: {msg['content'][:300]}...")
777
+ return
778
+
779
+ train_grpo(args)
780
+
781
+
782
+ if __name__ == "__main__":
783
+ main()
training/prompts.py ADDED
@@ -0,0 +1,398 @@
1
+ """
2
+ Prompt formatting and action parsing for LLM-based API testing agents.
3
+
4
+ - SYSTEM_PROMPT: Instructions for the LLM on how to test APIs
5
+ - format_observation(): Converts environment observations into LLM prompts
6
+ - parse_action(): Extracts a single JSON action from LLM text
7
+ - parse_test_plan(): Extracts a JSON array of actions (for GRPO training)
8
+ """
9
+
10
+ import json
11
+ import re
12
+ import sys
13
+ import os
14
+
15
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
16
+ from models import APITestAction, HTTPMethod
17
+
18
+
19
+ # =====================================================================
20
+ # System prompt for multi-turn evaluation (one action at a time)
21
+ # =====================================================================
22
+
23
+ SYSTEM_PROMPT = """\
24
+ You are an expert API security tester. You are testing a REST API for bugs.
25
+
26
+ You will receive:
27
+ - The API specification (available endpoints)
28
+ - Results from your previous requests
29
+ - Coverage and bug discovery progress
30
+
31
+ Your job: find as many bugs as possible by sending HTTP requests.
32
+
33
+ Think step by step about what to test next, then output your action as JSON.
34
+
35
+ RESPOND WITH EXACTLY ONE JSON ACTION per turn:
36
+ ```json
37
+ {
38
+ "method": "GET|POST|PUT|DELETE",
39
+ "endpoint": "/path",
40
+ "headers": {},
41
+ "query_params": {},
42
+ "body": null,
43
+ "expected_status": 200
44
+ }
45
+ ```
46
+
47
+ TESTING STRATEGIES:
48
+ - Test each endpoint with valid inputs first
49
+ - Try invalid inputs (missing fields, wrong types, boundary values)
50
+ - Test with non-existent resource IDs
51
+ - Login as different users and test cross-user access
52
+ - Try SQL injection patterns in text fields
53
+ - Test with very long inputs
54
+ - Chain operations: create -> read -> update -> delete
55
+ """
56
+
57
+
58
+ # =====================================================================
59
+ # System prompt for GRPO training (full test plan in one shot)
60
+ # =====================================================================
61
+
62
+ PLAN_SYSTEM_PROMPT = """\
63
+ You are an expert API security tester. You will receive an API specification and must output a COMPLETE TEST PLAN as a JSON array of HTTP requests to execute in order.
64
+
65
+ Your goal: find as many bugs as possible through systematic testing.
66
+
67
+ OUTPUT FORMAT — a JSON array of actions to execute sequentially:
68
+ ```json
69
+ [
70
+ {"method": "GET", "endpoint": "/tasks", "headers": {}, "query_params": {}, "body": null, "expected_status": 200},
71
+ {"method": "POST", "endpoint": "/auth/login", "headers": {}, "query_params": {}, "body": {"username": "alice", "password": "pass"}, "expected_status": 200},
72
+ ...more actions...
73
+ ]
74
+ ```
75
+
76
+ OUTPUT EXACTLY ONE JSON ARRAY. No other text.
77
+
78
+ TESTING STRATEGY — follow this order:
79
+ 1. DISCOVER: GET /tasks, GET /users to see what exists
80
+ 2. AUTHENTICATE: Login as two different users (POST /auth/login)
81
+ 3. CRUD: POST to create, GET to read, PUT to update, DELETE to remove
82
+ 4. MISSING FIELDS: POST /tasks without required "title" field
83
+ 5. NON-EXISTENT IDs: GET /tasks/999999 (expect 404 β€” if you get 200, that's a bug!)
84
+ 6. BOUNDARY: GET /tasks?page=-1&limit=10 (negative page), GET /tasks?limit=999999 (huge limit)
85
+ 7. INVALID DATA: PUT /tasks/1 with assignee_email="not-an-email"
86
+ 8. SECURITY: Login as user B, then try to GET/PUT/DELETE user A's resources (BOLA test)
87
+ 9. INJECTION: POST /tasks with title containing SQL injection like "'; DROP TABLE tasks;--"
88
+ 10. EMPTY AUTH: POST /auth/login with empty password (should fail but might not)
89
+ 11. DATA LEAKS: POST /users and check if response includes password_hash
90
+ 12. STATE: DELETE a task, then GET it again (should be 404)
91
+ 13. LONG INPUT: POST /tasks with a title of 6000+ characters
92
+
93
+ COMMON BUG PATTERNS TO TEST:
94
+ - API returns 200 with null body instead of 404 for missing resources
95
+ - API returns 500 instead of 400 for invalid input
96
+ - API accepts any password (even empty string) for login
97
+ - Users can access other users' resources (no authorization check)
98
+ - Response includes sensitive fields like password_hash
99
+ - No input length limits (very long strings crash the server)
100
+ - SQL/HTML injection payloads stored without sanitization
101
+ - DELETE returns 200 even for non-existent resources
102
+ - No pagination limit cap (limit=999999 accepted)
103
+
104
+ RULES:
105
+ - Output 15-25 actions
106
+ - Each action MUST have "method" and "endpoint"
107
+ - Vary your requests β€” never repeat the same action
108
+ - Use the usernames from the task description for login
109
+ """
110
+
+
+ def format_observation(obs) -> str:
+     """Convert an observation into a human-readable prompt for the LLM.
+
+     Used in multi-turn evaluation (one action at a time).
+     """
+     parts = []
+
+     if obs.steps_taken == 0:
+         parts.append(f"TASK: {obs.task_description}")
+         parts.append(f"\nSTEPS REMAINING: {obs.max_steps}")
+         parts.append("\nAVAILABLE ENDPOINTS:")
+         for ep in obs.available_endpoints:
+             line = f"  {ep['method']} {ep['path']} - {ep.get('summary', '')}"
+             parts.append(line)
+         parts.append("\nBegin testing. Send your first request as JSON.")
+     else:
+         parts.append(f"STEP {obs.steps_taken}/{obs.max_steps}")
+         parts.append(f"RESPONSE: HTTP {obs.status_code}")
+
+         resp = obs.response_body
+         if isinstance(resp, (dict, list)):
+             resp_str = json.dumps(resp, indent=2)
+             if len(resp_str) > 500:
+                 resp_str = resp_str[:500] + "\n... (truncated)"
+         else:
+             resp_str = str(resp)[:500]
+         parts.append(f"BODY:\n{resp_str}")
+
+         parts.append(f"\nFEEDBACK: {obs.feedback}")
+
+         coverage = obs.coverage_summary
+         parts.append(
+             f"\nPROGRESS: Bugs found: {obs.bugs_found_so_far} | "
+             f"Coverage: {coverage.get('coverage_pct', 0):.0f}% | "
+             f"Endpoints tested: {coverage.get('endpoints_tested', 0)}/{coverage.get('total_endpoints', 0)}"
+         )
+
+         if obs.auth_tokens:
+             parts.append(f"AUTH TOKENS: {list(obs.auth_tokens.keys())}")
+         if obs.known_resource_ids:
+             parts.append(f"CREATED RESOURCES: {dict(obs.known_resource_ids)}")
+
+         parts.append("\nSend your next request as JSON.")
+
+     return "\n".join(parts)
156
+
+
+ def format_plan_prompt(obs) -> str:
+     """Convert the initial observation into a prompt for generating a full test plan.
+
+     Used in GRPO training (model outputs a complete plan in one completion).
+     """
+     parts = []
+     parts.append(f"TASK: {obs.task_description}")
+     parts.append(f"\nYou have {obs.max_steps} actions to find as many bugs as possible.")
+     parts.append("\nAVAILABLE ENDPOINTS:")
+     for ep in obs.available_endpoints:
+         summary = ep.get("summary", "")
+         parts.append(f"  {ep['method']} {ep['path']} - {summary}")
+
+         # Show request body schema if available
+         req_body = ep.get("request_body", {})
+         if req_body:
+             props = req_body.get("properties", {})
+             required = req_body.get("required", [])
+             if props:
+                 fields = []
+                 for fname, finfo in props.items():
+                     req_mark = " (required)" if fname in required else ""
+                     fields.append(f"{fname}: {finfo.get('type', 'any')}{req_mark}")
+                 parts.append(f"    Body: {', '.join(fields)}")
+
+         # Show parameters if available
+         params = ep.get("parameters", [])
+         if params:
+             param_strs = [f"{p['name']}: {p.get('type', 'any')}" for p in params]
+             parts.append(f"    Params: {', '.join(param_strs)}")
+
+     parts.append("\nOutput your complete test plan as a JSON array of actions.")
+     return "\n".join(parts)
190
+
+
+ def parse_action(text: str) -> APITestAction | None:
+     """Parse a single JSON action from LLM output.
+
+     Used in multi-turn evaluation.
+     """
+     # Strip Qwen3 thinking blocks
+     if "</think>" in text:
+         text = text.split("</think>", 1)[-1]
+
+     json_match = re.search(r'\{[^{}]*"method"[^{}]*\}', text, re.DOTALL)
+     if not json_match:
+         json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
+         if json_match:
+             json_str = json_match.group(1)
+         else:
+             return None
+     else:
+         json_str = json_match.group(0)
+
+     try:
+         data = json.loads(json_str)
+     except json.JSONDecodeError:
+         return None
+
+     return _dict_to_action(data)
216
+
+
+ def parse_test_plan(text: str) -> list[APITestAction]:
+     """Parse a JSON array of actions from LLM output.
+
+     Handles all of these formats:
+     1. Raw JSON array: [{"method": ...}, ...]
+     2. Wrapped object: {"actions": [...]} or {"plan": [...]} or {"test_plan": [...]}
+     3. Markdown code block: ```json [...] ```
+     4. Trailing commas, missing commas (best-effort repair)
+     5. Brace-balanced extraction of individual action objects
+     """
+     if not text:
+         return []
+
+     # Strip Qwen3 thinking blocks
+     if "</think>" in text:
+         text = text.split("</think>", 1)[-1]
+
+     # Strip markdown code fences
+     text = re.sub(r'```(?:json)?\s*', '', text)
+     text = text.replace('```', '')
+
+     data = None
+
+     # Strategy 1: Try to parse the entire text as JSON
+     try:
+         data = json.loads(text.strip())
+     except json.JSONDecodeError:
+         pass
+
+     # Strategy 2: Find a top-level JSON ARRAY via bracket matching
+     if data is None:
+         start = text.find('[')
+         if start >= 0:
+             depth = 0
+             for i in range(start, len(text)):
+                 if text[i] == '[':
+                     depth += 1
+                 elif text[i] == ']':
+                     depth -= 1
+                     if depth == 0:
+                         candidate = text[start:i + 1]
+                         try:
+                             data = json.loads(candidate)
+                             break
+                         except json.JSONDecodeError:
+                             # Repair trailing commas before ] or }
+                             cleaned = re.sub(r',(\s*[\]}])', r'\1', candidate)
+                             try:
+                                 data = json.loads(cleaned)
+                                 break
+                             except json.JSONDecodeError:
+                                 pass
+
+     # Strategy 2b: Find a top-level JSON OBJECT (might be {"actions": [...]})
+     if data is None:
+         start = text.find('{')
+         if start >= 0:
+             depth = 0
+             for i in range(start, len(text)):
+                 if text[i] == '{':
+                     depth += 1
+                 elif text[i] == '}':
+                     depth -= 1
+                     if depth == 0:
+                         candidate = text[start:i + 1]
+                         try:
+                             parsed = json.loads(candidate)
+                             # Only accept if it's a wrapper containing actions
+                             if isinstance(parsed, dict) and any(
+                                 k in parsed for k in ("actions", "plan", "test_plan", "steps", "requests")
+                             ):
+                                 data = parsed
+                                 break
+                         except json.JSONDecodeError:
+                             cleaned = re.sub(r',(\s*[\]}])', r'\1', candidate)
+                             try:
+                                 parsed = json.loads(cleaned)
+                                 if isinstance(parsed, dict) and any(
+                                     k in parsed for k in ("actions", "plan", "test_plan", "steps", "requests")
+                                 ):
+                                     data = parsed
+                                     break
+                             except json.JSONDecodeError:
+                                 pass
+
+     # Strategy 3: Extract individual {"method": ...} objects with brace balancing
+     if data is None:
+         objects = []
+         i = 0
+         while i < len(text):
+             if text[i] == '{':
+                 depth = 1
+                 start = i
+                 i += 1
+                 while i < len(text) and depth > 0:
+                     if text[i] == '{':
+                         depth += 1
+                     elif text[i] == '}':
+                         depth -= 1
+                     i += 1
+                 obj_str = text[start:i]
+                 if '"method"' in obj_str:
+                     try:
+                         obj = json.loads(obj_str)
+                         objects.append(obj)
+                     except json.JSONDecodeError:
+                         cleaned = re.sub(r',(\s*[\]}])', r'\1', obj_str)
+                         try:
+                             obj = json.loads(cleaned)
+                             objects.append(obj)
+                         except json.JSONDecodeError:
+                             pass
+             else:
+                 i += 1
+         if objects:
+             data = objects
+
+     if data is None:
+         return []
+
+     # Unwrap common container shapes: {"actions": [...]}, {"plan": [...]}, etc.
+     if isinstance(data, dict):
+         for key in ("actions", "plan", "test_plan", "steps", "requests"):
+             if key in data and isinstance(data[key], list):
+                 data = data[key]
+                 break
+         else:
+             # Single action object
+             data = [data]
+
+     if not isinstance(data, list):
+         data = [data]
+
+     actions = []
+     for item in data:
+         if isinstance(item, dict) and "method" in item:
+             action = _dict_to_action(item)
+             if action:
+                 actions.append(action)
+
+     return actions
358
+
+
+ def _dict_to_action(data: dict) -> APITestAction | None:
+     """Convert a dict to an APITestAction."""
+     method = str(data.get("method", "GET")).upper()
+     if method not in ("GET", "POST", "PUT", "DELETE", "PATCH"):
+         method = "GET"
+
+     endpoint = data.get("endpoint", "/tasks")
+     if not isinstance(endpoint, str):
+         endpoint = str(endpoint)
+     if not endpoint.startswith("/"):
+         endpoint = "/" + endpoint
+
+     headers = data.get("headers") or {}
+     if not isinstance(headers, dict):
+         headers = {}
+
+     query_params = data.get("query_params") or {}
+     if not isinstance(query_params, dict):
+         query_params = {}
+
+     body = data.get("body")
+     if body is not None and not isinstance(body, dict):
+         body = None
+
+     expected = data.get("expected_status")
+     if expected is not None:
+         try:
+             expected = int(expected)
+         except (ValueError, TypeError):
+             expected = None
+
+     return APITestAction(
+         method=HTTPMethod(method),
+         endpoint=endpoint,
+         headers=headers,
+         query_params=query_params,
+         body=body,
+         expected_status=expected,
+     )
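The bracket-balancing fallback in `parse_test_plan` (Strategy 2) is the piece that rescues most near-miss model outputs. As a standalone sketch of that single idea (a hypothetical mini-parser for illustration, not the module itself), it works like this:

```python
import json
import re


def extract_json_array(text: str) -> list:
    """Best-effort extraction of the first top-level JSON array in text.

    Scan for '[', track bracket depth, and json.loads the balanced
    candidate, stripping trailing commas if the first parse fails.
    """
    start = text.find('[')
    if start < 0:
        return []
    depth = 0
    for i in range(start, len(text)):
        if text[i] == '[':
            depth += 1
        elif text[i] == ']':
            depth -= 1
            if depth == 0:
                candidate = text[start:i + 1]
                try:
                    return json.loads(candidate)
                except json.JSONDecodeError:
                    # Repair trailing commas like [{"a": 1},]
                    cleaned = re.sub(r',(\s*[\]}])', r'\1', candidate)
                    try:
                        return json.loads(cleaned)
                    except json.JSONDecodeError:
                        return []
    return []


llm_output = 'Here is my plan:\n[{"method": "GET", "endpoint": "/tasks"},]\nDone.'
print(extract_json_array(llm_output))
# [{'method': 'GET', 'endpoint': '/tasks'}]
```

Surrounding chatter and a trailing comma are both tolerated, which matters when the policy model drifts from the strict "one JSON array, no other text" instruction.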
training/rewards.py ADDED
@@ -0,0 +1,209 @@
+ """
2
+ Reward functions for GRPO training (v2 β€” plan-based).
3
+
4
+ The model outputs a FULL TEST PLAN (JSON array of actions).
5
+ Each reward function creates a FRESH environment, executes ALL actions,
6
+ and scores the result.
7
+
8
+ Three reward signals:
9
+ 1. format_reward β€” Valid JSON array with 3+ diverse actions? (+2 / -2)
10
+ 2. plan_reward β€” Execute plan, score on bugs + coverage + efficiency (0 to ~8)
11
+ 3. diversity_reward β€” Variety of methods, endpoints, and request patterns (+0 to +2)
12
+ """
13
+
14
+ import re
15
+ import sys
16
+ import os
17
+
18
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
19
+
20
+ from models import APITestAction, HTTPMethod
21
+ from server.environment import APITestEnvironment
22
+ from .prompts import parse_test_plan
23
+
24
+
25
+ def format_reward_fn(completions: list[str], **kwargs) -> list[float]:
26
+ """Reward for valid JSON test plan format.
27
+
28
+ +2.0 if output has 5+ diverse actions (a real plan)
29
+ +1.0 if output has 3-4 actions (minimal plan)
30
+ +0.0 if output has 1-2 actions (barely valid)
31
+ -2.0 if it can't be parsed at all
32
+
33
+ Also penalizes if all actions are identical.
34
+ """
35
+ rewards = []
36
+ for text in completions:
37
+ actions = parse_test_plan(text)
38
+ if not actions:
39
+ rewards.append(-2.0)
40
+ continue
41
+
42
+ n = len(actions)
43
+
44
+ # Check diversity β€” are the actions actually different?
45
+ unique_pairs = set()
46
+ for a in actions:
47
+ m = a.method.value if hasattr(a.method, "value") else str(a.method)
48
+ ep = re.sub(r'/\d+', '/{id}', a.endpoint)
49
+ unique_pairs.add((m, ep))
50
+
51
+ diversity_ratio = len(unique_pairs) / max(n, 1)
52
+
53
+ if n >= 5 and diversity_ratio >= 0.5:
54
+ rewards.append(2.0)
55
+ elif n >= 3:
56
+ rewards.append(1.0)
57
+ elif n >= 1:
58
+ rewards.append(0.0)
59
+ else:
60
+ rewards.append(-2.0)
61
+
62
+ # Penalty if all actions are the same
63
+ if len(unique_pairs) <= 1 and n > 1:
64
+ rewards[-1] = -1.0
65
+
66
+ return rewards
67
+
+
+ def plan_reward_fn(completions: list[str], **kwargs) -> list[float]:
+     """Execute the full test plan in a FRESH environment and return a balanced score.
+
+     Score components:
+     - Bug discovery: min(bugs_found, 5) * 1.0 (capped at 5.0 to not dominate)
+     - Coverage: (coverage_pct / 100) * 2.0 (up to 2.0)
+     - Efficiency: +0.3 per high-reward step among the first 10 actions
+     - Crash penalty: -0.1 per action that caused a 500 error
+
+     Total range: roughly -2 to +8
+
+     Each completion gets its OWN fresh environment - no state pollution.
+     """
+     prompts_meta = kwargs.get("prompts_meta", [])
+     rewards = []
+
+     for i, text in enumerate(completions):
+         actions = parse_test_plan(text)
+         if not actions:
+             rewards.append(-1.0)
+             continue
+
+         # Get episode seed and task
+         meta = prompts_meta[i % len(prompts_meta)] if prompts_meta else {}
+         seed = meta.get("seed", 42)
+         task_id = meta.get("task_id", "basic_validation")
+
+         # Create a FRESH environment
+         env = APITestEnvironment()
+         env.reset(seed=seed, task_id=task_id)
+
+         # Execute all actions, track results
+         crashes = 0
+         step_rewards = []
+         for action in actions:
+             try:
+                 obs = env.step(action)
+                 step_rewards.append(obs.reward or 0.0)
+                 if obs.status_code >= 500:
+                     crashes += 1
+             except Exception:
+                 step_rewards.append(0.0)
+                 crashes += 1
+
+         state = env.state
+         coverage = state.coverage_pct
+
+         # Component 1: Bug discovery (capped to prevent domination)
+         bug_score = min(state.bugs_found, 5) * 1.0
+
+         # Component 2: Coverage (proportional, up to 2.0)
+         coverage_score = (coverage / 100) * 2.0
+
+         # Component 3: Efficiency - finding bugs early is better
+         early_bug_bonus = 0.0
+         early_steps = step_rewards[:10]
+         for r in early_steps:
+             if r > 0.2:  # High-reward step = likely found a bug
+                 early_bug_bonus += 0.3
+
+         # Component 4: Crash penalty
+         crash_penalty = crashes * -0.1
+
+         # Component 5: Step reward sum (small weight - mainly for gradient signal)
+         step_sum = sum(step_rewards) * 0.2
+
+         total = bug_score + coverage_score + early_bug_bonus + crash_penalty + step_sum
+         rewards.append(round(total, 4))
+
+     return rewards
139
+
+
+ def diversity_reward_fn(completions: list[str], **kwargs) -> list[float]:
+     """Reward for diverse test plans - varied methods, endpoints, and strategies.
+
+     Components:
+     - Method variety: up to +0.5 (using GET/POST/PUT/DELETE)
+     - Endpoint variety: up to +0.5 (testing different endpoints)
+     - Strategy variety: up to +0.5 (auth + invalid input + boundary + injection patterns)
+     - Repetition penalty: up to -0.5
+     """
+     rewards = []
+     for text in completions:
+         actions = parse_test_plan(text)
+         if not actions:
+             rewards.append(0.0)
+             continue
+
+         methods = set()
+         endpoints = set()
+         unique_pairs = set()
+         has_auth = False
+         has_invalid_input = False
+         has_boundary = False
+         has_injection = False
+         has_nonexistent_id = False
+
+         for a in actions:
+             m = a.method.value if hasattr(a.method, "value") else str(a.method)
+             methods.add(m)
+             norm_ep = re.sub(r'/\d+', '/{id}', a.endpoint)
+             endpoints.add(norm_ep)
+             unique_pairs.add((m, norm_ep))
+
+             # Detect testing strategies
+             if a.endpoint == "/auth/login":
+                 has_auth = True
+             if a.body and not a.body.get("title") and m == "POST":
+                 has_invalid_input = True
+             qp = a.query_params or {}
+             if any(isinstance(v, (int, float)) and v < 0 for v in qp.values()):
+                 has_boundary = True
+             if any(isinstance(v, (int, float)) and v > 10000 for v in qp.values()):
+                 has_boundary = True
+             if a.body and any("DROP" in str(v).upper() or "script" in str(v).lower()
+                               for v in a.body.values()):
+                 has_injection = True
+             if re.search(r'/\d{4,}', a.endpoint):
+                 has_nonexistent_id = True
+
+         # Method variety (max 4 methods = +0.5)
+         method_score = min(len(methods) / 4, 1.0) * 0.5
+
+         # Endpoint variety (max 7 endpoints = +0.5)
+         endpoint_score = min(len(endpoints) / 7, 1.0) * 0.5
+
+         # Strategy variety (each strategy = +0.1, max +0.5)
+         strategies = sum([has_auth, has_invalid_input, has_boundary, has_injection, has_nonexistent_id])
+         strategy_score = min(strategies * 0.1, 0.5)
+
+         # Repetition penalty
+         if len(actions) > 0:
+             repeat_count = len(actions) - len(unique_pairs)
+             repetition_penalty = min(repeat_count / len(actions), 1.0) * -0.5
+         else:
+             repetition_penalty = 0.0
+
+         total = method_score + endpoint_score + strategy_score + repetition_penalty
+         rewards.append(round(total, 3))
+
+     return rewards
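To sanity-check the weighting in `diversity_reward_fn`, here is a toy re-computation of the method-variety, endpoint-variety, and repetition components on a small hand-written plan (plain dicts stand in for `APITestAction`; the strategy-detection bonus is omitted for brevity):

```python
import re


def diversity_score(actions: list) -> float:
    """Toy re-computation of the diversity reward's variety components.

    actions: list of {"method": str, "endpoint": str} dicts.
    Mirrors diversity_reward_fn's weights: method variety (max +0.5),
    endpoint variety (max +0.5), repetition penalty (down to -0.5).
    """
    methods = {a["method"] for a in actions}
    # Normalize numeric path segments so /tasks/1 and /tasks/2 count once.
    endpoints = {re.sub(r'/\d+', '/{id}', a["endpoint"]) for a in actions}
    pairs = {(a["method"], re.sub(r'/\d+', '/{id}', a["endpoint"])) for a in actions}

    method_score = min(len(methods) / 4, 1.0) * 0.5
    endpoint_score = min(len(endpoints) / 7, 1.0) * 0.5
    repeats = len(actions) - len(pairs)
    repetition_penalty = min(repeats / len(actions), 1.0) * -0.5 if actions else 0.0
    return round(method_score + endpoint_score + repetition_penalty, 3)


plan = [
    {"method": "GET", "endpoint": "/tasks"},
    {"method": "POST", "endpoint": "/tasks"},
    {"method": "GET", "endpoint": "/tasks/1"},
    {"method": "DELETE", "endpoint": "/tasks/2"},
]
print(diversity_score(plan))  # 0.518: 3/4 methods, 2/7 endpoints, no repeats
```

Note how the `/{id}` normalization keeps the score from being inflated by hammering the same endpoint with different numeric IDs, while a plan that repeats one identical request is pulled toward the -0.5 floor.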
uv.lock ADDED