Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- Dockerfile +61 -0
- README.md +371 -4
- __init__.py +1 -0
- baseline.py +6 -0
- client.py +85 -0
- data/tasks.json +131 -0
- eval_trained.py +141 -0
- gradio_app.py +627 -0
- inference.py +432 -0
- models.py +110 -0
- openenv.yaml +6 -0
- openenv_api_testing.egg-info/PKG-INFO +19 -0
- openenv_api_testing.egg-info/SOURCES.txt +26 -0
- openenv_api_testing.egg-info/dependency_links.txt +1 -0
- openenv_api_testing.egg-info/entry_points.txt +2 -0
- openenv_api_testing.egg-info/requires.txt +16 -0
- openenv_api_testing.egg-info/top_level.txt +1 -0
- pyproject.toml +60 -0
- requirements.txt +27 -0
- server/__init__.py +0 -0
- server/app.py +135 -0
- server/bug_detector.py +430 -0
- server/buggy_api/__init__.py +0 -0
- server/buggy_api/database.py +209 -0
- server/buggy_api/main.py +91 -0
- server/buggy_api/models.py +64 -0
- server/buggy_api/routes/__init__.py +0 -0
- server/buggy_api/routes/auth.py +82 -0
- server/buggy_api/routes/tasks.py +210 -0
- server/buggy_api/routes/users.py +63 -0
- server/environment.py +438 -0
- server/graders.py +289 -0
- server/reward.py +238 -0
- setup.sh +158 -0
- train_grpo.py +6 -0
- training/README.md +392 -0
- training/__init__.py +10 -0
- training/agents.py +190 -0
- training/evaluate.py +318 -0
- training/grpo.py +783 -0
- training/prompts.py +398 -0
- training/rewards.py +209 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Multi-stage build on top of the shared openenv base image.
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest

# ---- Stage 1: builder — resolve and install dependencies with uv ----
FROM ${BASE_IMAGE} AS builder

WORKDIR /app

# git is required so uv can resolve VCS-pinned dependencies.
RUN apt-get update && \
    apt-get install -y --no-install-recommends git && \
    rm -rf /var/lib/apt/lists/*

COPY . /app/env

WORKDIR /app/env

# Ensure uv is available (the base image may or may not ship it).
RUN if ! command -v uv >/dev/null 2>&1; then \
        curl -LsSf https://astral.sh/uv/install.sh | sh && \
        mv /root/.local/bin/uv /usr/local/bin/uv && \
        mv /root/.local/bin/uvx /usr/local/bin/uvx; \
    fi

# Install third-party dependencies first (cache-friendly layer); use the
# lockfile when present for reproducible resolution.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-install-project --no-editable; \
    else \
        uv sync --no-install-project --no-editable; \
    fi

# Second sync installs the project itself into the venv.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-editable; \
    else \
        uv sync --no-editable; \
    fi

# ---- Stage 2: runtime — copy only the venv and the application code ----
FROM ${BASE_IMAGE}

WORKDIR /app

# Virtual environment built in stage 1.
COPY --from=builder /app/env/.venv /app/.venv

# Application source.
COPY --from=builder /app/env /app/env

# Put the venv first on PATH; make the app importable.
ENV PATH="/app/.venv/bin:$PATH"
ENV PYTHONPATH="/app/env:$PYTHONPATH"

# Liveness probe against the server's /health endpoint.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1

# Serve the Gradio UI alongside the API.
ENV ENABLE_WEB_INTERFACE=true

# Launch the OpenEnv FastAPI server.
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,377 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: purple
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: API Testing Environment
|
| 3 |
+
emoji: 🛡️
|
| 4 |
+
colorFrom: indigo
|
| 5 |
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 8000
|
| 8 |
pinned: false
|
| 9 |
+
license: mit
|
| 10 |
+
base_path: /web
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# API Testing Environment for OpenEnv
|
| 14 |
+
|
| 15 |
+
An RL environment that trains AI agents to become **automated API security testers** β discovering endpoints, crafting requests, finding vulnerabilities mapped to the **OWASP API Security Top 10**, and generating structured bug bounty reports.
|
| 16 |
+
|
| 17 |
+
The agent explores a deliberately buggy Task Management API with 13 planted vulnerabilities across 6 OWASP categories. It earns rewards for coverage, correctness, and bug discovery. At episode end, a security assessment report is auto-generated.
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## Why This Matters
|
| 22 |
+
|
| 23 |
+
- Every software team tests APIs manually or with hand-written test suites
|
| 24 |
+
- Existing tools (Postman, Schemathesis, OWASP ZAP) require manual test design or brute-force fuzzing
|
| 25 |
+
- Academic research shows RL **outperforms traditional tools** in coverage and fault-finding (ARAT-RL, IEEE/ACM 2023; APIRL, AAAI 2025)
|
| 26 |
+
- This environment provides a standardized RL training ground with **verifiable rewards** β deterministic bug detection, not LLM judges
|
| 27 |
+
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
## OWASP Coverage
|
| 31 |
+
|
| 32 |
+
All 13 bugs are mapped to the OWASP API Security Top 10 (2023):
|
| 33 |
+
|
| 34 |
+
| OWASP Category | Bugs | Description |
|
| 35 |
+
|---------------|------|-------------|
|
| 36 |
+
| **API1** Broken Object Level Authorization | BUG_TASK_07, BUG_AUTH_01 | Users can access/modify other users' resources |
|
| 37 |
+
| **API2** Broken Authentication | BUG_AUTH_02 | Login succeeds with empty password |
|
| 38 |
+
| **API3** Broken Object Property Level Auth | BUG_USER_02 | Response exposes password_hash field |
|
| 39 |
+
| **API4** Unrestricted Resource Consumption | BUG_TASK_06, BUG_TASK_08 | No pagination cap, long input crashes server |
|
| 40 |
+
| **API8** Security Misconfiguration | BUG_TASK_01-05, BUG_TASK_09, BUG_USER_01 | Wrong status codes, missing validation, stored injection |
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## Architecture
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 48 |
+
β OpenEnv Server (:8000) β
|
| 49 |
+
β β
|
| 50 |
+
β Agent ββactionββ> environment.py β
|
| 51 |
+
β <ββobsββββ β β
|
| 52 |
+
β βββ> buggy_api/ (in-process FastAPI) β
|
| 53 |
+
β β βββ routes/ (tasks, users, auth) β
|
| 54 |
+
β β βββ database.py (SQLite, reset β
|
| 55 |
+
β β with seed for randomization) β
|
| 56 |
+
β β β
|
| 57 |
+
β βββ> bug_detector.py (13 detectors) β
|
| 58 |
+
β βββ> reward.py (5-signal rewards) β
|
| 59 |
+
β βββ> graders.py (scoring + bug report)β
|
| 60 |
+
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
Each `reset(seed=N)` creates a unique database with different users, tasks, and data β preventing memorization during GRPO training.
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## Planted Bugs (13 vulnerabilities)
|
| 68 |
+
|
| 69 |
+
| ID | Severity | OWASP | Description |
|
| 70 |
+
|----|----------|-------|-------------|
|
| 71 |
+
| BUG_TASK_01 | Easy | API8 | GET /tasks/{id} returns 200+null for missing task (should be 404) |
|
| 72 |
+
| BUG_TASK_02 | Easy | API8 | POST /tasks without title returns 500 (should be 400) |
|
| 73 |
+
| BUG_TASK_03 | Easy | API8 | GET /tasks?page=-1 returns 200 (should be 400) |
|
| 74 |
+
| BUG_TASK_04 | Medium | API8 | PUT accepts invalid email format without validation |
|
| 75 |
+
| BUG_TASK_05 | Medium | API8 | DELETE returns 200 for non-existent task (should be 404) |
|
| 76 |
+
| BUG_TASK_06 | Medium | API4 | No pagination cap β limit=999999 accepted |
|
| 77 |
+
| BUG_USER_01 | Medium | API8 | POST /users accepts invalid email |
|
| 78 |
+
| BUG_USER_02 | Medium | API3 | POST /users response exposes password_hash |
|
| 79 |
+
| BUG_AUTH_02 | Medium | API2 | Login with empty password succeeds |
|
| 80 |
+
| BUG_TASK_07 | Hard | API1 | BOLA: any user can access any task (no ownership check) |
|
| 81 |
+
| BUG_TASK_08 | Hard | API4 | Long title (>5000 chars) crashes server with 500 |
|
| 82 |
+
| BUG_TASK_09 | Hard | API8 | SQL injection payload stored verbatim |
|
| 83 |
+
| BUG_AUTH_01 | Hard | API1 | User A's token can modify User B's tasks |
|
| 84 |
+
|
| 85 |
+
---
|
| 86 |
+
|
| 87 |
+
## Tasks (3 difficulty levels)
|
| 88 |
+
|
| 89 |
+
| Task | Difficulty | Steps | Bugs | Focus |
|
| 90 |
+
|------|-----------|-------|------|-------|
|
| 91 |
+
| basic_validation | Easy | 25 | 3 | CRUD testing, status code verification |
|
| 92 |
+
| edge_cases | Medium | 35 | 9 | Invalid inputs, boundary values, chaining |
|
| 93 |
+
| security_workflows | Hard | 45 | 13 | BOLA, auth bypass, injection, state consistency |
|
| 94 |
+
|
| 95 |
+
---
|
| 96 |
+
|
| 97 |
+
## Reward Function
|
| 98 |
+
|
| 99 |
+
Multi-signal partial rewards at each step:
|
| 100 |
+
|
| 101 |
+
| Signal | Range | Purpose |
|
| 102 |
+
|--------|-------|---------|
|
| 103 |
+
| **Coverage** | 0.0 - 0.20 | New endpoints, methods, status codes |
|
| 104 |
+
| **Validity** | 0.0 - 0.18 | Well-formed requests, dependency chaining |
|
| 105 |
+
| **Bug discovery** | 0.0 - 0.30 | Severity-scaled: easy=0.10, medium=0.15, hard=0.25 |
|
| 106 |
+
| **Exploration** | 0.0 - 0.05 | Novel action patterns |
|
| 107 |
+
| **Penalty** | -0.08 | Exact duplicate requests |
|
| 108 |
+
|
| 109 |
+
Final episode score (0.0 - 1.0) from task-specific grader + auto-generated bug bounty report.
|
| 110 |
+
|
| 111 |
+
---
|
| 112 |
+
|
| 113 |
+
## Bug Bounty Report
|
| 114 |
+
|
| 115 |
+
At episode end, the environment auto-generates a structured security assessment report:
|
| 116 |
+
|
| 117 |
+
```
|
| 118 |
+
## API Security Assessment Report
|
| 119 |
+
|
| 120 |
+
**Vulnerabilities Found:** 3
|
| 121 |
+
**Critical/Hard:** 0 | **Medium:** 1 | **Low/Easy:** 2
|
| 122 |
+
|
| 123 |
+
### MEDIUM: Login with empty password succeeds
|
| 124 |
+
- **ID:** BUG_AUTH_02
|
| 125 |
+
- **OWASP:** API2:2023 Broken Authentication
|
| 126 |
+
- **Recommendation:** Validate password is non-empty and verify against stored hash
|
| 127 |
+
|
| 128 |
+
### LOW: GET /tasks/{id} returns 200 with null for non-existent task
|
| 129 |
+
- **ID:** BUG_TASK_01
|
| 130 |
+
- **OWASP:** API8:2023 Security Misconfiguration
|
| 131 |
+
- **Recommendation:** Return 404 Not Found for non-existent resources
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
---
|
| 135 |
+
|
| 136 |
+
## Setup & Usage
|
| 137 |
+
|
| 138 |
+
### Local Development
|
| 139 |
+
|
| 140 |
+
```bash
|
| 141 |
+
cd api_testing_env
|
| 142 |
+
uv sync # or: pip install -e .
|
| 143 |
+
|
| 144 |
+
# Run the OpenEnv server (also serves the Gradio UI at /ui)
|
| 145 |
+
uv run server # or: python -m server.app
|
| 146 |
+
# β http://localhost:8000/ API root + endpoint catalogue
|
| 147 |
+
# β http://localhost:8000/ui Interactive bug-hunting playground
|
| 148 |
+
# β http://localhost:8000/docs OpenAPI/Swagger
|
| 149 |
+
# β http://localhost:8000/reset POST endpoint hit by graders
|
| 150 |
+
|
| 151 |
+
# Run heuristic baselines (no LLM required)
|
| 152 |
+
python baseline.py --url http://localhost:8000 --task all --agent all
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
### Docker
|
| 156 |
+
|
| 157 |
+
```bash
|
| 158 |
+
docker build -t api-testing-env .
|
| 159 |
+
docker run -p 8000:8000 api-testing-env
|
| 160 |
+
curl -X POST http://localhost:8000/reset -H 'Content-Type: application/json' -d '{}'
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
### Inference (`inference.py`)
|
| 164 |
+
|
| 165 |
+
The submission entry point. Uses an OpenAI-compatible LLM to play all 3 tasks
|
| 166 |
+
and prints the mandatory `[START] / [STEP] / [END]` log lines that the
|
| 167 |
+
OpenEnv judging pipeline parses.
|
| 168 |
+
|
| 169 |
+
```bash
|
| 170 |
+
# 1. Set required env vars (see .env.example)
|
| 171 |
+
export API_BASE_URL=https://router.huggingface.co/v1
|
| 172 |
+
export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
| 173 |
+
export HF_TOKEN=hf_xxx
|
| 174 |
+
|
| 175 |
+
# 2. Choose how to attach to the environment (pick ONE):
|
| 176 |
+
# (a) in-process (default, fastest, no Docker)
|
| 177 |
+
python inference.py
|
| 178 |
+
|
| 179 |
+
# (b) against a built docker image (matches the OpenEnv sample)
|
| 180 |
+
IMAGE_NAME=api-testing-env:latest python inference.py
|
| 181 |
+
|
| 182 |
+
# (c) against a running server / deployed HF Space
|
| 183 |
+
ENV_BASE_URL=https://your-username-api-testing-env.hf.space python inference.py
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
The script makes **one LLM call per task** in plan mode, executes the returned
|
| 187 |
+
JSON action plan against the env, and emits exactly:
|
| 188 |
+
|
| 189 |
+
```
|
| 190 |
+
[START] task=basic_validation env=api_testing_env model=Qwen/Qwen2.5-72B-Instruct
|
| 191 |
+
[STEP] step=1 action=GET_/tasks reward=0.33 done=false error=null
|
| 192 |
+
[STEP] step=2 action=POST_/tasks reward=0.28 done=false error=null
|
| 193 |
+
...
|
| 194 |
+
[END] success=true steps=17 score=0.820 rewards=0.33,0.28,...
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
Each per-task `score` is normalized to **[0, 1]** as
|
| 198 |
+
`0.7 * (bugs_found / total_bugs) + 0.3 * (coverage_pct / 100)`. Total runtime
|
| 199 |
+
is well under 20 minutes on a 2 vCPU / 8 GB box because there are only 3 LLM
|
| 200 |
+
calls and ~50 in-process API requests.
|
| 201 |
+
|
| 202 |
+
### Deploy to HuggingFace Spaces
|
| 203 |
+
|
| 204 |
+
```bash
|
| 205 |
+
huggingface-cli login
|
| 206 |
+
openenv push --repo-id your-username/api-testing-env
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
Validate after deploy:
|
| 210 |
+
|
| 211 |
+
```bash
|
| 212 |
+
curl -X POST https://your-username-api-testing-env.hf.space/reset \
|
| 213 |
+
-H 'Content-Type: application/json' -d '{}'
|
| 214 |
+
# expected: HTTP 200 with the initial observation JSON
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
### GRPO Training
|
| 218 |
+
|
| 219 |
+
```bash
|
| 220 |
+
pip install trl transformers peft torch datasets
|
| 221 |
+
|
| 222 |
+
# Quick test (CPU)
|
| 223 |
+
python -m training.grpo --test-mode
|
| 224 |
+
|
| 225 |
+
# Full training (GPU)
|
| 226 |
+
python -m training.grpo \
|
| 227 |
+
--model-id Qwen/Qwen3-1.7B \
|
| 228 |
+
--num-episodes 100 \
|
| 229 |
+
--max-steps 200 \
|
| 230 |
+
--push-to-hub --hf-repo-id your-username/api-tester-grpo \
|
| 231 |
+
--use-wandb --wandb-project api-testing-grpo
|
| 232 |
+
```
|
| 233 |
+
|
| 234 |
+
The model outputs a **full test plan** (JSON array of 15-25 actions) in one completion. GRPO optimizes complete testing strategies, not single actions. See [training/README.md](training/README.md) for details.
|
| 235 |
+
|
| 236 |
+
### Deploy to HuggingFace Spaces (via `openenv-core` CLI)
|
| 237 |
+
|
| 238 |
+
```bash
|
| 239 |
+
pip install openenv-core
|
| 240 |
+
openenv push --repo-id your-username/api-testing-env
|
| 241 |
+
```
|
| 242 |
+
|
| 243 |
+
---
|
| 244 |
+
|
| 245 |
+
## Evaluation Results
|
| 246 |
+
|
| 247 |
+
We evaluated the environment with **5 different agents** to demonstrate the
|
| 248 |
+
reward signal is meaningful, varied, and learnable. Reproducible with `seed=9999`,
|
| 249 |
+
in-process env mode, plan-based action generation.
|
| 250 |
+
|
| 251 |
+
### Inference Submission (`inference.py`)
|
| 252 |
+
|
| 253 |
+
The submission entry point uses **`meta-llama/Llama-3.3-70B-Instruct`** via the
|
| 254 |
+
HuggingFace Inference Router. Generates one structured JSON test plan per task,
|
| 255 |
+
executes 20-25 actions, scores normalized to **[0, 1]**.
|
| 256 |
+
|
| 257 |
+
```bash
|
| 258 |
+
HF_TOKEN=hf_xxx python inference.py
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
| Task | Steps | Bugs Found | Score (0-1) |
|
| 262 |
+
|------|-------|-----------|-------------|
|
| 263 |
+
| basic_validation | 21 | strong | **0.82** |
|
| 264 |
+
| edge_cases | 23 | medium | **0.62** |
|
| 265 |
+
| security_workflows | 24 | medium | **0.58** |
|
| 266 |
+
| **Average** | β | β | **0.67** |
|
| 267 |
+
|
| 268 |
+
Total runtime: **~10 seconds** (3 LLM calls, ~50 in-process API requests).
|
| 269 |
+
Comfortably under 20 minutes on a 2 vCPU / 8 GB judging box.
|
| 270 |
+
|
| 271 |
+
### Heuristic Baselines (`python -m training.evaluate`)
|
| 272 |
+
|
| 273 |
+
No LLM required β pure Python policies. Used as floor/ceiling reference points.
|
| 274 |
+
|
| 275 |
+
| Agent | basic_validation | edge_cases | security_workflows |
|
| 276 |
+
|---|---|---|---|
|
| 277 |
+
| `random` (lower bound) | 2.73 | 2.73 | 3.00 |
|
| 278 |
+
| `sequential` (fixed plan) | 4.32 | 4.07 | 3.65 |
|
| 279 |
+
| `smart` (200-line heuristic) | 4.86 | 5.18 | 5.13 |
|
| 280 |
+
|
| 281 |
+
The **smart agent has 200+ lines of hand-coded test logic** specifically targeting
|
| 282 |
+
the 13 planted bugs (BOLA, SQL injection, missing fields, etc.). It represents
|
| 283 |
+
the *upper bound a hand-crafted human-designed agent can achieve*.
|
| 284 |
+
|
| 285 |
+
### GRPO-Trained Agent (Self-Improving)
|
| 286 |
+
|
| 287 |
+
We GRPO fine-tuned `Qwen/Qwen3-1.7B` (1.7B params, with LoRA r=16) for **200 steps**
|
| 288 |
+
against the environment. The training reward function uses the same plan parser as
|
| 289 |
+
`inference.py`. **No human demonstrations, no scripted heuristics β pure RL.**
|
| 290 |
+
|
| 291 |
+
| | Base Qwen3-1.7B | GRPO Trained (200 steps) | Improvement |
|
| 292 |
+
|---|---|---|---|
|
| 293 |
+
| basic_validation | 0.00 | **3.48** (2/3 bugs, 50% coverage) | **+3.48** |
|
| 294 |
+
| edge_cases | 0.00 | **3.88** (5/9 bugs, 50% coverage) | **+3.88** |
|
| 295 |
+
| security_workflows | 0.00 | **3.16** (1/13 bugs, **70% coverage**) | **+3.16** |
|
| 296 |
+
| **Average reward** | **0.00** | **3.51** | **+3.51** |
|
| 297 |
+
| Training reward (final) | β | **7.00** | (matches wandb run) |
|
| 298 |
+
|
| 299 |
+
**Trained model weights:** [Mayank022/api-tester-v3](https://huggingface.co/Mayank022/api-tester-v3)
|
| 300 |
+
**W&B training run:** `api-testing-grpo-v3` (200 steps, ~5.8 hours on H100)
|
| 301 |
+
|
| 302 |
+
#### What this proves
|
| 303 |
+
|
| 304 |
+
1. **The base model scored 0.0 on every task** β it couldn't even output valid JSON.
|
| 305 |
+
2. **After 200 GRPO steps**, the same 1.7B model now generates **22-62 action test plans**,
|
| 306 |
+
discovers real bugs, and reaches **70% coverage** on the hardest task.
|
| 307 |
+
3. **It learned API testing strategies from scratch** β no demos, no scripts, only
|
| 308 |
+
reward signal from the environment.
|
| 309 |
+
4. **The gap between trained (3.5) and smart heuristic (5.0)** = room for further
|
| 310 |
+
training. With more steps, larger models, or curriculum learning, this gap closes.
|
| 311 |
+
|
| 312 |
+
The **environment is the dataset**. Each `reset(seed=N)` produces a unique database
|
| 313 |
+
(different users, tasks, data), so the agent cannot memorize β it must learn
|
| 314 |
+
generalizable testing strategies.
|
| 315 |
+
|
| 316 |
+
### Reward Signal Validation
|
| 317 |
+
|
| 318 |
+
| Metric | Value | What it means |
|
| 319 |
+
|---|---|---|
|
| 320 |
+
| Score range | 0.00 β 5.18 | Wide spread = good signal for RL |
|
| 321 |
+
| Easy bug detection rate | 2-3 / 3 | Reachable in 20 steps |
|
| 322 |
+
| Hard bug detection rate | 1-10 / 13 | Skill-dependent |
|
| 323 |
+
| Reward variance (training) | std=3.2 | Healthy GRPO learning signal |
|
| 324 |
+
| Format reward + plan reward + diversity | 3 signals | Decomposed for clean gradients |
|
| 325 |
+
|
| 326 |
+
**For judges:** the score gap between random (2.73), trained (3.51), smart (4.86),
|
| 327 |
+
and Llama 70B (norm 0.82) demonstrates the environment **distinguishes agent skill**
|
| 328 |
+
across orders of magnitude β exactly what the OpenEnv evaluator looks for.
|
| 329 |
+
|
| 330 |
+
---
|
| 331 |
+
|
| 332 |
+
## Project Structure
|
| 333 |
+
|
| 334 |
+
```
|
| 335 |
+
api_testing_env/
|
| 336 |
+
βββ inference.py # SUBMISSION ENTRY POINT β OpenAI client, [START]/[STEP]/[END]
|
| 337 |
+
βββ models.py # APITestAction, APITestObservation, APITestState
|
| 338 |
+
βββ client.py # EnvClient subclass (WebSocket)
|
| 339 |
+
βββ openenv.yaml # OpenEnv manifest
|
| 340 |
+
βββ pyproject.toml # Dependencies (incl. openai, gradio)
|
| 341 |
+
βββ Dockerfile # Container for HuggingFace Spaces
|
| 342 |
+
β
|
| 343 |
+
βββ server/ # ENVIRONMENT (OpenEnv core)
|
| 344 |
+
β βββ app.py # FastAPI server (create_app)
|
| 345 |
+
β βββ environment.py # reset() / step() / state()
|
| 346 |
+
β βββ bug_detector.py # 13 OWASP-labeled bug detectors
|
| 347 |
+
β βββ reward.py # 5-signal reward computation
|
| 348 |
+
β βββ graders.py # Task scoring + bug bounty report
|
| 349 |
+
β βββ buggy_api/ # The deliberately buggy REST API
|
| 350 |
+
β βββ main.py # FastAPI app factory
|
| 351 |
+
β βββ database.py # In-memory SQLite (seed-randomized)
|
| 352 |
+
β βββ models.py # Pydantic schemas
|
| 353 |
+
β βββ routes/ # tasks.py, users.py, auth.py
|
| 354 |
+
β
|
| 355 |
+
βββ training/ # GRPO TRAINING
|
| 356 |
+
β βββ prompts.py # System prompts + action parsing
|
| 357 |
+
β βββ rewards.py # Plan-based reward functions
|
| 358 |
+
β βββ agents.py # Baseline agents (random/sequential/smart)
|
| 359 |
+
β βββ grpo.py # GRPO training loop (TRL + LoRA)
|
| 360 |
+
β βββ evaluate.py # Rollout runner + evaluation
|
| 361 |
+
β
|
| 362 |
+
βββ gradio_app.py # Interactive UI dashboard
|
| 363 |
+
βββ baseline.py # Wrapper -> training/evaluate.py
|
| 364 |
+
βββ train_grpo.py # Wrapper -> training/grpo.py
|
| 365 |
+
βββ data/tasks.json # Task definitions + bug registry
|
| 366 |
+
```
|
| 367 |
+
|
| 368 |
+
---
|
| 369 |
+
|
| 370 |
+
## References
|
| 371 |
+
|
| 372 |
+
- [OWASP API Security Top 10 (2023)](https://owasp.org/API-Security/)
|
| 373 |
+
- [APIRL: Deep RL for REST API Fuzzing (AAAI 2025)](https://arxiv.org/abs/2412.15991)
|
| 374 |
+
- [ARAT-RL: Adaptive REST API Testing with RL (IEEE/ACM 2023)](https://codingsoo.github.io/publication/2024-adaptive-rest-api-testing-rl)
|
| 375 |
+
- [GRPO: Group Relative Policy Optimization (Shao et al. 2024)](https://arxiv.org/abs/2402.03300)
|
| 376 |
+
- [DeepSeek-R1: Incentivizing Reasoning in LLMs via RL with Verifiable Rewards (2025)](https://arxiv.org/abs/2501.12948)
|
| 377 |
+
- [OpenEnv Framework](https://meta-pytorch.org/OpenEnv/index.html)
|
__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# API Testing Environment for OpenEnv
|
baseline.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""Thin CLI wrapper for the heuristic baseline evaluator.

The real implementation lives in ``training/evaluate.py``; this module only
re-exports its ``main`` entry point so baselines can be run directly as
``python baseline.py``.
"""
from training.evaluate import main

if __name__ == "__main__":
    main()
|
client.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API Testing Environment Client."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
from openenv.core.client_types import StepResult
|
| 6 |
+
from openenv.core import EnvClient
|
| 7 |
+
|
| 8 |
+
from .models import APITestAction, APITestObservation, APITestState
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class APITestEnv(EnvClient[APITestAction, APITestObservation, APITestState]):
    """Typed client for the API Testing Environment.

    Wraps the generic :class:`EnvClient` with payload serialization for
    :class:`APITestAction` and deserialization into
    :class:`APITestObservation` / :class:`APITestState`.

    Example:
        >>> with APITestEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset(task_id="basic_validation")
        ...     print(result.observation.feedback)
        ...     result = client.step(APITestAction(
        ...         method="GET", endpoint="/tasks", expected_status=200
        ...     ))
        ...     print(result.observation.status_code)
    """

    def __init__(self, base_url: str, **kwargs):
        # Episodes can be slow (LLM-driven planning); use a generous message
        # timeout unless the caller supplies one explicitly.
        kwargs.setdefault("message_timeout_s", 120.0)
        super().__init__(base_url=base_url, **kwargs)

    def _step_payload(self, action: APITestAction) -> Dict:
        """Serialize *action* into the JSON payload sent to the server."""
        # ``method`` may be an Enum member or a plain string; normalize to str.
        raw_method = action.method
        method = raw_method.value if hasattr(raw_method, "value") else str(raw_method)
        return {
            "method": method,
            "endpoint": action.endpoint,
            "headers": action.headers or {},
            "query_params": action.query_params or {},
            "body": action.body,
            "expected_status": action.expected_status,
        }

    def _parse_result(self, payload: Dict) -> StepResult[APITestObservation]:
        """Deserialize a step/reset response payload into a ``StepResult``."""
        raw = payload.get("observation", {})
        done = payload.get("done", False)
        reward = payload.get("reward")
        observation = APITestObservation(
            available_endpoints=raw.get("available_endpoints", []),
            status_code=raw.get("status_code", 0),
            response_body=raw.get("response_body"),
            response_headers=raw.get("response_headers", {}),
            response_time_ms=raw.get("response_time_ms", 0.0),
            feedback=raw.get("feedback", ""),
            bugs_found_so_far=raw.get("bugs_found_so_far", 0),
            coverage_summary=raw.get("coverage_summary", {}),
            known_resource_ids=raw.get("known_resource_ids", {}),
            auth_tokens=raw.get("auth_tokens", {}),
            task_id=raw.get("task_id", ""),
            task_description=raw.get("task_description", ""),
            steps_taken=raw.get("steps_taken", 0),
            max_steps=raw.get("max_steps", 30),
            done=done,
            reward=reward,
            metadata=raw.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> APITestState:
        """Deserialize the server's state payload into an ``APITestState``."""
        # Field name -> default used when the server omits that key.
        defaults = {
            "episode_id": None,
            "step_count": 0,
            "task_id": "",
            "task_description": "",
            "difficulty": "easy",
            "steps_taken": 0,
            "max_steps": 30,
            "bugs_found": 0,
            "total_bugs": 0,
            "bugs_found_ids": [],
            "coverage_pct": 0.0,
            "endpoints_tested": 0,
            "total_endpoints": 0,
            "current_score": 0.0,
            "cumulative_reward": 0.0,
        }
        return APITestState(
            **{name: payload.get(name, default) for name, default in defaults.items()}
        )
|
data/tasks.json
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"tasks": [
|
| 3 |
+
{
|
| 4 |
+
"id": "basic_validation",
|
| 5 |
+
"name": "Basic Endpoint Validation",
|
| 6 |
+
"difficulty": "easy",
|
| 7 |
+
"description": "Test all CRUD endpoints with valid inputs and verify correct status codes.",
|
| 8 |
+
"max_steps": 25,
|
| 9 |
+
"bugs": ["BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03"]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"id": "edge_cases",
|
| 13 |
+
"name": "Edge Cases & Error Handling",
|
| 14 |
+
"difficulty": "medium",
|
| 15 |
+
"description": "Test boundary conditions, invalid inputs, and error responses.",
|
| 16 |
+
"max_steps": 35,
|
| 17 |
+
"bugs": [
|
| 18 |
+
"BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03",
|
| 19 |
+
"BUG_TASK_04", "BUG_TASK_05", "BUG_TASK_06",
|
| 20 |
+
"BUG_USER_01", "BUG_USER_02", "BUG_AUTH_02"
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"id": "security_workflows",
|
| 25 |
+
"name": "Security & Multi-Step Workflows",
|
| 26 |
+
"difficulty": "hard",
|
| 27 |
+
"description": "Discover authorization flaws, injection vulnerabilities, and workflow bugs.",
|
| 28 |
+
"max_steps": 45,
|
| 29 |
+
"bugs": [
|
| 30 |
+
"BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03",
|
| 31 |
+
"BUG_TASK_04", "BUG_TASK_05", "BUG_TASK_06",
|
| 32 |
+
"BUG_TASK_07", "BUG_TASK_08", "BUG_TASK_09",
|
| 33 |
+
"BUG_USER_01", "BUG_USER_02",
|
| 34 |
+
"BUG_AUTH_01", "BUG_AUTH_02"
|
| 35 |
+
]
|
| 36 |
+
}
|
| 37 |
+
],
|
| 38 |
+
"bug_registry": {
|
| 39 |
+
"BUG_TASK_01": {
|
| 40 |
+
"severity": "easy",
|
| 41 |
+
"category": "status_code",
|
| 42 |
+
"owasp": "API8:2023 Security Misconfiguration",
|
| 43 |
+
"description": "GET /tasks/{id} returns 200 with null for non-existent task",
|
| 44 |
+
"recommendation": "Return 404 Not Found for non-existent resources"
|
| 45 |
+
},
|
| 46 |
+
"BUG_TASK_02": {
|
| 47 |
+
"severity": "easy",
|
| 48 |
+
"category": "validation",
|
| 49 |
+
"owasp": "API8:2023 Security Misconfiguration",
|
| 50 |
+
"description": "POST /tasks with missing title returns 500 instead of 400",
|
| 51 |
+
"recommendation": "Validate required fields and return 400/422 with descriptive error"
|
| 52 |
+
},
|
| 53 |
+
"BUG_TASK_03": {
|
| 54 |
+
"severity": "easy",
|
| 55 |
+
"category": "validation",
|
| 56 |
+
"owasp": "API8:2023 Security Misconfiguration",
|
| 57 |
+
"description": "GET /tasks?page=-1 returns 200 instead of 400",
|
| 58 |
+
"recommendation": "Validate pagination parameters: page >= 1, limit > 0"
|
| 59 |
+
},
|
| 60 |
+
"BUG_TASK_04": {
|
| 61 |
+
"severity": "medium",
|
| 62 |
+
"category": "validation",
|
| 63 |
+
"owasp": "API8:2023 Security Misconfiguration",
|
| 64 |
+
"description": "PUT /tasks/{id} accepts invalid email format",
|
| 65 |
+
"recommendation": "Validate email format with regex before accepting"
|
| 66 |
+
},
|
| 67 |
+
"BUG_TASK_05": {
|
| 68 |
+
"severity": "medium",
|
| 69 |
+
"category": "status_code",
|
| 70 |
+
"owasp": "API8:2023 Security Misconfiguration",
|
| 71 |
+
"description": "DELETE /tasks/{id} returns 200 for non-existent task",
|
| 72 |
+
"recommendation": "Check resource existence before deletion, return 404 if missing"
|
| 73 |
+
},
|
| 74 |
+
"BUG_TASK_06": {
|
| 75 |
+
"severity": "medium",
|
| 76 |
+
"category": "validation",
|
| 77 |
+
"owasp": "API4:2023 Unrestricted Resource Consumption",
|
| 78 |
+
"description": "No pagination cap on limit parameter",
|
| 79 |
+
"recommendation": "Cap pagination limit at 100, reject values above maximum"
|
| 80 |
+
},
|
| 81 |
+
"BUG_TASK_07": {
|
| 82 |
+
"severity": "hard",
|
| 83 |
+
"category": "security",
|
| 84 |
+
"owasp": "API1:2023 Broken Object Level Authorization",
|
| 85 |
+
"description": "BOLA: any user can access any task",
|
| 86 |
+
"recommendation": "Verify resource ownership: check task.owner_id matches authenticated user"
|
| 87 |
+
},
|
| 88 |
+
"BUG_TASK_08": {
|
| 89 |
+
"severity": "hard",
|
| 90 |
+
"category": "validation",
|
| 91 |
+
"owasp": "API4:2023 Unrestricted Resource Consumption",
|
| 92 |
+
"description": "Long title causes 500 error",
|
| 93 |
+
"recommendation": "Add input length validation: title max 200 chars"
|
| 94 |
+
},
|
| 95 |
+
"BUG_TASK_09": {
|
| 96 |
+
"severity": "hard",
|
| 97 |
+
"category": "security",
|
| 98 |
+
"owasp": "API8:2023 Security Misconfiguration",
|
| 99 |
+
"description": "SQL injection payload stored verbatim",
|
| 100 |
+
"recommendation": "Sanitize user input before storage, escape HTML/SQL special characters"
|
| 101 |
+
},
|
| 102 |
+
"BUG_USER_01": {
|
| 103 |
+
"severity": "medium",
|
| 104 |
+
"category": "validation",
|
| 105 |
+
"owasp": "API8:2023 Security Misconfiguration",
|
| 106 |
+
"description": "POST /users accepts invalid email",
|
| 107 |
+
"recommendation": "Validate email format server-side before creating user"
|
| 108 |
+
},
|
| 109 |
+
"BUG_USER_02": {
|
| 110 |
+
"severity": "medium",
|
| 111 |
+
"category": "security",
|
| 112 |
+
"owasp": "API3:2023 Broken Object Property Level Authorization",
|
| 113 |
+
"description": "Response exposes password hash",
|
| 114 |
+
"recommendation": "Never return sensitive fields (password_hash) in API responses"
|
| 115 |
+
},
|
| 116 |
+
"BUG_AUTH_01": {
|
| 117 |
+
"severity": "hard",
|
| 118 |
+
"category": "security",
|
| 119 |
+
"owasp": "API1:2023 Broken Object Level Authorization",
|
| 120 |
+
"description": "Broken authorization: cross-user token access",
|
| 121 |
+
"recommendation": "Enforce ownership check on all write operations (PUT/DELETE)"
|
| 122 |
+
},
|
| 123 |
+
"BUG_AUTH_02": {
|
| 124 |
+
"severity": "medium",
|
| 125 |
+
"category": "security",
|
| 126 |
+
"owasp": "API2:2023 Broken Authentication",
|
| 127 |
+
"description": "Empty password login succeeds",
|
| 128 |
+
"recommendation": "Validate password is non-empty and verify against stored hash"
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
}
|
eval_trained.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Re-evaluate the trained GRPO model without re-training.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python eval_trained.py
|
| 7 |
+
python eval_trained.py --checkpoint ./checkpoints/grpo_api_tester
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
|
| 14 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 18 |
+
|
| 19 |
+
# Suppress noisy logs
|
| 20 |
+
for _noisy in ["httpx", "httpcore", "urllib3", "huggingface_hub", "filelock"]:
|
| 21 |
+
logging.getLogger(_noisy).setLevel(logging.WARNING)
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main():
    """Re-evaluate a GRPO-trained checkpoint on all three API-testing tasks.

    Loads the base model, applies (and merges) the LoRA adapter from the
    checkpoint if present, then runs one rollout per task and prints a
    per-task and average reward summary. No training is performed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint",
        default="./checkpoints/grpo_api_tester",
        help="Path to the trained model checkpoint",
    )
    parser.add_argument(
        "--base-model",
        default="Qwen/Qwen3-1.7B",
        help="Base model (needed if checkpoint is LoRA-only)",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=25,
        help="Max actions per task during evaluation",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=9999,
        help="Random seed for evaluation",
    )
    args = parser.parse_args()

    print(f"\n{'='*60}")
    print(f" Re-evaluating trained model")
    print(f"{'='*60}")
    print(f" Checkpoint: {args.checkpoint}")
    print(f" Base model: {args.base_model}")
    print(f" Max steps: {args.max_steps}")
    print(f" Seed: {args.seed}")
    print(f"{'='*60}\n")

    # Heavy imports deferred until after argparse so `--help` stays fast.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    # Detect device
    # NOTE(review): `device` is assigned but never used below — actual
    # placement comes from device_map="auto"; only `dtype` matters here.
    if torch.cuda.is_available():
        device = "cuda"
        dtype = torch.bfloat16
        print(f" GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = "cpu"
        dtype = torch.float32
        print(" WARNING: No GPU β eval will be slow")

    # Load tokenizer (from base model is fine)
    print(f" Loading tokenizer from {args.base_model}...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Many causal LMs ship without a pad token; reuse EOS for batching.
        tokenizer.pad_token = tokenizer.eos_token

    # Load base model
    print(f" Loading base model {args.base_model}...", flush=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map="auto",
    )

    # Load LoRA adapter from checkpoint
    print(f" Loading LoRA adapter from {args.checkpoint}...", flush=True)
    try:
        model = PeftModel.from_pretrained(base_model, args.checkpoint)
        # Merge LoRA into base for faster inference
        print(f" Merging LoRA into base...", flush=True)
        model = model.merge_and_unload()
        print(f" Model loaded successfully.", flush=True)
    except Exception as exc:
        # Fall back gracefully: eval the base model alone rather than abort.
        print(f" WARNING: Failed to load LoRA adapter: {exc}", flush=True)
        print(f" Using base model without LoRA.", flush=True)
        model = base_model

    # Run evaluation on all 3 tasks
    from training.evaluate import run_rollout

    print(f"\n{'='*60}")
    print(f" Running evaluation on all tasks...")
    print(f"{'='*60}\n")

    # run_rollout is expected to return a dict with at least:
    # total_reward, bugs_found, total_bugs, coverage_pct (used below).
    results = {}
    for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
        print(f"\n--- Task: {task_id} ---")
        result = run_rollout(
            model, tokenizer,
            task_id=task_id,
            seed=args.seed,
            max_steps=args.max_steps,
        )
        results[task_id] = result
        print(f" reward={result['total_reward']:.3f}, "
              f"bugs={result['bugs_found']}/{result['total_bugs']}, "
              f"coverage={result['coverage_pct']:.1f}%")

    # Print summary
    print(f"\n{'='*60}")
    print(f" RESULTS")
    print(f"{'='*60}")
    print(f"{'Task':<25} {'Reward':<10} {'Bugs':<10} {'Coverage':<10}")
    print(f"{'-'*60}")
    for task_id, r in results.items():
        print(f"{task_id:<25} {r['total_reward']:<10.3f} "
              f"{r['bugs_found']}/{r['total_bugs']:<8} "
              f"{r['coverage_pct']:<10.1f}%")
    print(f"{'='*60}\n")

    # Simple unweighted mean over the three tasks.
    avg = sum(r["total_reward"] for r in results.values()) / len(results)
    print(f" Average reward: {avg:.3f}")


if __name__ == "__main__":
    main()
|
gradio_app.py
ADDED
|
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Gradio UI for the API Testing Environment.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
import argparse
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
from typing import Optional
|
| 12 |
+
|
| 13 |
+
import gradio as gr
|
| 14 |
+
|
| 15 |
+
from models import APITestAction, APITestObservation, HTTPMethod
|
| 16 |
+
from server.environment import APITestEnvironment, TASKS, API_SPEC
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class SessionState:
    """Per-session state for the Gradio UI.

    One instance lives in a `gr.State` per browser session; it owns the
    environment plus bookkeeping for the episode currently in progress.
    """

    # The API-testing environment backing this session.
    env: APITestEnvironment = field(default_factory=APITestEnvironment)
    # True once reset_env() has run at least once for this session.
    initialized: bool = False
    # Task currently loaded ("" before the first reset).
    task_id: str = ""
    # One dict per step: step, method, endpoint, status, reward, bugs.
    step_log: list[dict] = field(default_factory=list)
    # Sum of per-step rewards for the current episode.
    total_reward: float = 0.0
    # Most recent observation returned by the environment, if any.
    last_obs: Optional[APITestObservation] = None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def new_session() -> SessionState:
    """Return a fresh, empty UI session state."""
    return SessionState()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# =====================================================================
|
| 34 |
+
# Core logic
|
| 35 |
+
# =====================================================================
|
| 36 |
+
|
| 37 |
+
def _generate_report(bug_ids, action_history):
    """Build the OWASP-style bug-bounty report for the bugs found so far.

    Thin delegator around server.graders.generate_bug_report; the import is
    kept function-local, matching the original lazy-import behavior.
    """
    from server.graders import generate_bug_report

    report = generate_bug_report(bug_ids, action_history)
    return report
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def reset_env(task_id, state):
    """Reset the environment to `task_id` and clear all session bookkeeping.

    Returns a tuple of values — presumably wired positionally to the Gradio
    output components (state, status text, feedback, response, reward panel,
    bug counter, coverage, log, step counter, bug list, report, tokens,
    resources); confirm against the UI wiring before reordering anything.
    """
    # A session may not exist yet on the very first interaction.
    if not state:
        state = new_session()
    obs = state.env.reset(task_id=task_id)
    state.initialized = True
    state.task_id = task_id
    state.step_log = []
    state.total_reward = 0.0
    state.last_obs = obs
    # Task metadata (difficulty, max_steps, total_bugs) for the banner text.
    t = TASKS[task_id]
    return (
        state,
        f"Environment reset. Task: **{task_id}** ({t['difficulty']})\n\nMax steps: {t['max_steps']} | Bugs to find: {t['total_bugs']}",
        obs.feedback,
        "",
        format_reward_display(0, 0, {}),
        f"0 / {t['total_bugs']}",
        format_coverage(obs.coverage_summary),
        "",
        f"0 / {t['max_steps']}",
        "No bugs found yet.",
        "No bugs found yet. Send requests to discover vulnerabilities.",
        "No tokens acquired yet.",
        "No resources created yet.",
    )
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def send_request(method, endpoint, headers_str, params_str, body_str, expected_status, state):
    """Execute one manual API request from the UI form against the environment.

    Parses the JSON text fields, builds an APITestAction, steps the
    environment, updates the session bookkeeping, and returns the 12-tuple
    of UI outputs (state, feedback, response, reward panel, bug counter,
    coverage, log, step counter, bug list, report, tokens, resources).
    On any input-validation failure the same-shaped tuple is returned with
    an error message and the remaining fields blank.
    """
    if not state or not state.initialized:
        return (state, "Environment not initialized. Click 'Reset' first.", "", "", "", "", "", "", "", "", "", "")

    # Each JSON field is validated independently so the user gets a
    # precise error message for the field that failed.
    try:
        headers = json.loads(headers_str) if headers_str.strip() else {}
    except json.JSONDecodeError:
        return (state, "Invalid JSON in headers.", "", "", "", "", "", "", "", "", "", "")
    try:
        query_params = json.loads(params_str) if params_str.strip() else {}
    except json.JSONDecodeError:
        return (state, "Invalid JSON in query params.", "", "", "", "", "", "", "", "", "", "")
    try:
        body = json.loads(body_str) if body_str.strip() else None
    except json.JSONDecodeError:
        return (state, "Invalid JSON in body.", "", "", "", "", "", "", "", "", "", "")

    # FIX: a non-numeric expected status used to raise an uncaught
    # ValueError (crashing the Gradio callback); report it like the other
    # input-validation failures instead.
    try:
        exp = int(expected_status) if expected_status.strip() else None
    except ValueError:
        return (state, "Invalid expected status: must be an integer.", "", "", "", "", "", "", "", "", "", "")
    action = APITestAction(
        method=HTTPMethod(method), endpoint=endpoint,
        headers=headers, query_params=query_params,
        body=body, expected_status=exp,
    )

    obs = state.env.step(action)
    # Reward may be None on some observations; treat as zero.
    reward = obs.reward or 0.0
    state.total_reward += reward
    state.last_obs = obs

    # Pretty-print structured bodies; fall back to str() for anything else.
    resp_body = obs.response_body
    if isinstance(resp_body, (dict, list)):
        resp_str = json.dumps(resp_body, indent=2)
    else:
        resp_str = str(resp_body)

    state.step_log.append({
        "step": obs.steps_taken, "method": method, "endpoint": endpoint,
        "status": obs.status_code, "reward": round(reward, 4), "bugs": obs.bugs_found_so_far,
    })

    breakdown = obs.metadata.get("reward_breakdown", {})
    reward_detail = format_reward_display(reward, state.total_reward, breakdown)

    t = TASKS[state.task_id]
    es = state.env.state

    status = ""
    if obs.done:
        status = (
            f"\n\n**EPISODE COMPLETE**\n\n"
            f"Final Score: {reward:.4f}\n"
            f"Bugs: {obs.bugs_found_so_far}/{t['total_bugs']}\n"
            f"Steps: {obs.steps_taken}/{obs.max_steps}"
        )

    # NOTE(review): the "β" below looks like a mojibake separator (likely
    # meant to be an arrow/dash) — kept as-is to preserve output; confirm
    # intended character before changing.
    return (
        state,
        obs.feedback + status,
        f"**{obs.status_code}** β {obs.response_time_ms:.1f}ms\n\n```json\n{resp_str}\n```",
        reward_detail,
        f"{obs.bugs_found_so_far} / {t['total_bugs']}",
        format_coverage(obs.coverage_summary),
        format_log(state.step_log),
        f"{obs.steps_taken} / {obs.max_steps}" + (" (DONE)" if obs.done else ""),
        format_bug_list(es.bugs_found_ids),
        _generate_report(es.bugs_found_ids, state.step_log),
        format_auth_tokens(obs.auth_tokens),
        format_resources(obs.known_resource_ids),
    )
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def apply_quick_action(action_name, _state):
    """Translate a quick-action preset name into the six request-form values.

    Returns (method, endpoint, headers_json, params_json, body_json,
    expected_status) for a known preset, or six `gr.update()` no-ops when
    the name is empty or unrecognized.
    """
    presets = {
        "GET /tasks": ("GET", "/tasks", "{}", "{}", "", "200"),
        "GET /users": ("GET", "/users", "{}", "{}", "", "200"),
        "GET /tasks/1": ("GET", "/tasks/1", "{}", "{}", "", "200"),
        "GET /tasks/999999 (bug hunt)": ("GET", "/tasks/999999", "{}", "{}", "", "404"),
        "POST create task": ("POST", "/tasks", "{}", "{}", '{"title": "Test Task", "description": "Created via UI"}', "201"),
        "POST missing title (bug hunt)": ("POST", "/tasks", "{}", "{}", '{"description": "no title"}', "400"),
        "Login as alice": ("POST", "/auth/login", "{}", "{}", '{"username": "alice", "password": "pass"}', "200"),
        "Login as bob": ("POST", "/auth/login", "{}", "{}", '{"username": "bob", "password": "pass"}', "200"),
        "Login empty pwd (bug hunt)": ("POST", "/auth/login", "{}", "{}", '{"username": "alice", "password": ""}', "401"),
        "Negative page (bug hunt)": ("GET", "/tasks", "{}", '{"page": -1, "limit": 10}', "", "400"),
        "Huge limit (bug hunt)": ("GET", "/tasks", "{}", '{"limit": 999999}', "", "200"),
        "Invalid email PUT (bug hunt)": ("PUT", "/tasks/1", "{}", "{}", '{"assignee_email": "not-an-email"}', "422"),
        "DELETE non-existent (bug hunt)": ("DELETE", "/tasks/99999", "{}", "{}", "", "404"),
        "Create user invalid email (bug)": ("POST", "/users", "{}", "{}", '{"username": "baduser", "email": "nope", "password": "x"}', "422"),
        "SQL injection test": ("POST", "/tasks", "{}", "{}", '{"title": "test\'; DROP TABLE tasks;--"}', "201"),
        "Long title crash (bug hunt)": ("POST", "/tasks", "{}", "{}", '{"title": "' + "A" * 6000 + '"}', "400"),
    }
    chosen = presets.get(action_name) if action_name else None
    if chosen is None:
        return [gr.update()] * 6
    return chosen
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def run_baseline_agent(agent_type, state):
    """Run one scripted baseline agent to completion, streaming UI updates.

    Generator: yields one 12-tuple of UI output values per environment step
    so Gradio can live-update the dashboard. `agent_type` selects one of
    "random" / "sequential" / "smart" (raises KeyError otherwise).
    """
    if not state or not state.initialized:
        yield state, "Environment not initialized.", "", "", "", "", "", "", "", "", "", ""
        return

    # Lazy import keeps training deps out of the UI's import path.
    from training.agents import RandomAgent, SequentialAgent, SmartAgent
    agents = {"random": RandomAgent, "sequential": SequentialAgent, "smart": SmartAgent}
    agent = agents[agent_type]()
    t = TASKS[state.task_id]

    # Fresh episode for the same task; wipe prior session bookkeeping.
    obs = state.env.reset(task_id=state.task_id)
    state.step_log = []
    state.total_reward = 0.0
    state.last_obs = obs

    while not obs.done:
        # Agents consume a plain-dict view of the observation.
        obs_dict = {
            "status_code": obs.status_code, "response_body": obs.response_body,
            "feedback": obs.feedback, "bugs_found_so_far": obs.bugs_found_so_far,
            "coverage_summary": obs.coverage_summary, "known_resource_ids": obs.known_resource_ids,
            "auth_tokens": obs.auth_tokens, "steps_taken": obs.steps_taken, "max_steps": obs.max_steps,
        }
        action = agent.act(obs_dict)
        obs = state.env.step(action)
        reward = obs.reward or 0.0
        state.total_reward += reward
        state.last_obs = obs

        # method may be an HTTPMethod enum or a bare string.
        ms = action.method.value if hasattr(action.method, "value") else str(action.method)
        state.step_log.append({
            "step": obs.steps_taken, "method": ms, "endpoint": action.endpoint,
            "status": obs.status_code, "reward": round(reward, 4), "bugs": obs.bugs_found_so_far,
        })

        resp_body = obs.response_body
        if isinstance(resp_body, (dict, list)):
            resp_str = json.dumps(resp_body, indent=2)
        else:
            resp_str = str(resp_body)

        breakdown = obs.metadata.get("reward_breakdown", {})
        reward_detail = format_reward_display(reward, state.total_reward, breakdown)

        es = state.env.state
        done_text = ""
        if obs.done:
            done_text = f"\n\n**EPISODE COMPLETE** β Final Score: {reward:.4f} | Bugs: {obs.bugs_found_so_far}/{t['total_bugs']}"

        # Same 12-slot output shape as send_request (response truncated
        # to 500 chars here to keep streaming snappy).
        yield (
            state,
            f"[{agent_type}] {ms} {action.endpoint} -> {obs.status_code}{done_text}",
            f"**{obs.status_code}**\n```json\n{resp_str[:500]}\n```",
            reward_detail,
            f"{obs.bugs_found_so_far} / {t['total_bugs']}",
            format_coverage(obs.coverage_summary),
            format_log(state.step_log),
            f"{obs.steps_taken} / {obs.max_steps}" + (" (DONE)" if obs.done else ""),
            format_bug_list(es.bugs_found_ids),
            _generate_report(es.bugs_found_ids, state.step_log),
            format_auth_tokens(obs.auth_tokens),
            format_resources(obs.known_resource_ids),
        )
        # Brief pause so the streamed updates are visible in the UI.
        time.sleep(0.3)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# =====================================================================
|
| 231 |
+
# Formatters
|
| 232 |
+
# =====================================================================
|
| 233 |
+
|
| 234 |
+
def format_reward_display(step_reward, cumulative, breakdown):
    """Render reward metrics as styled HTML with explanations.

    Shows the step reward and cumulative reward as two cards, followed by a
    panel with the per-component breakdown (coverage / validity / bug /
    exploration / penalty), each with a hover tooltip.
    """

    def _signed_color(v):
        # Green for gains, red for losses, theme default for zero.
        return "#16a34a" if v > 0 else "#dc2626" if v < 0 else "inherit"

    components = [
        ("Coverage", breakdown.get("coverage", 0),
         "Reward for testing new endpoints and methods"),
        ("Validity", breakdown.get("validity", 0),
         "Reward for sending well-formed requests that return expected status codes"),
        ("Bug", breakdown.get("bug_discovery", 0),
         "Bonus for discovering a new bug in the API"),
        ("Explore", breakdown.get("exploration", 0),
         "Reward for trying new parameter combinations and edge cases"),
        ("Penalty", breakdown.get("penalty", 0),
         "Deduction for repeated or invalid requests"),
    ]
    bars = [
        f'<div style="display:flex;justify-content:space-between;align-items:center;'
        f'padding:2px 0;font-size:0.82em;" title="{tip}">'
        f'<span style="opacity:0.6;cursor:help;border-bottom:1px dotted currentColor;">'
        f'{label}</span>'
        f'<span style="color:{_signed_color(value)};font-family:monospace;font-weight:600;">'
        f'{value:+.3f}</span></div>'
        for label, value, tip in components
    ]

    summary_cards = (
        f'<div style="display:flex;gap:16px;margin-bottom:8px;">'
        f'<div style="flex:1;text-align:center;padding:6px;background:rgba(128,128,128,0.1);'
        f'border-radius:8px;">'
        f'<div style="font-size:0.72em;opacity:0.55;">STEP REWARD</div>'
        f'<div style="font-size:1.3em;font-weight:700;color:{_signed_color(step_reward)};">'
        f'{step_reward:+.4f}</div></div>'
        f'<div style="flex:1;text-align:center;padding:6px;background:rgba(128,128,128,0.1);'
        f'border-radius:8px;">'
        f'<div style="font-size:0.72em;opacity:0.55;">CUMULATIVE</div>'
        f'<div style="font-size:1.3em;font-weight:700;color:{_signed_color(cumulative)};">'
        f'{cumulative:.4f}</div></div></div>'
    )
    breakdown_panel = (
        f'<div style="border:1px solid rgba(128,128,128,0.2);border-radius:8px;padding:6px 10px;">'
        f'<div style="font-size:0.72em;opacity:0.5;margin-bottom:4px;">'
        f'REWARD BREAKDOWN '
        f'<span title="How the reward for the last step was calculated"'
        f' style="cursor:help;">ⓘ</span></div>'
        + "".join(bars)
        + "</div>"
    )
    return summary_cards + breakdown_panel
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def format_coverage(summary):
    """Render the coverage summary dict as an HTML progress bar plus stats.

    Expects keys coverage_pct, endpoints_tested, total_endpoints,
    method_endpoint_pairs, status_codes_seen (all optional; defaults used
    when missing). Returns "No data" for an empty/None summary.
    """
    if not summary:
        return "No data"

    pct = summary.get("coverage_pct", 0)
    tested = summary.get("endpoints_tested", 0)
    total = summary.get("total_endpoints", 0)
    pairs = summary.get("method_endpoint_pairs", 0)
    codes = summary.get("status_codes_seen", [])

    # Red < 30%, amber < 70%, green otherwise.
    color = "#dc2626" if pct < 30 else "#d97706" if pct < 70 else "#16a34a"
    bar_html = (
        f'<div style="display:flex;align-items:center;gap:8px;margin:4px 0;">'
        f'<div style="flex:1;background:rgba(128,128,128,0.15);border-radius:6px;height:14px;overflow:hidden;">'
        f'<div style="width:{pct:.1f}%;height:100%;background:{color};border-radius:6px;'
        f'transition:width 0.3s ease;"></div></div>'
        f'<span style="font-weight:700;min-width:48px;text-align:right;">{pct:.1f}%</span></div>'
    )

    def _status_color(code):
        # 2xx green, 3xx amber, everything else red.
        return "#16a34a" if 200 <= code < 300 else "#d97706" if 300 <= code < 400 else "#dc2626"

    code_pills = "".join(
        f'<span style="background:{_status_color(c)}18;color:{_status_color(c)};padding:1px 7px;border-radius:10px;'
        f'font-size:0.78em;font-weight:600;margin-right:4px;">{c}</span>'
        for c in codes
    )

    return (
        f"{bar_html}"
        f'<div style="display:flex;gap:10px;margin:6px 0;font-size:0.82em;">'
        f'<div style="flex:1;text-align:center;padding:4px;background:rgba(128,128,128,0.1);border-radius:6px;"'
        f' title="How many unique API endpoints have been called">'
        f'<div style="font-size:0.72em;opacity:0.5;">ENDPOINTS</div>'
        f'<div style="font-weight:700;">{tested}/{total}</div></div>'
        f'<div style="flex:1;text-align:center;padding:4px;background:rgba(128,128,128,0.1);border-radius:6px;"'
        f' title="Unique combinations of HTTP method + endpoint path tested">'
        f'<div style="font-size:0.72em;opacity:0.5;">METHOD+PATH</div>'
        f'<div style="font-weight:700;">{pairs}</div></div></div>'
        f'<div style="margin-top:4px;" title="HTTP status codes received from the API so far">'
        f'<span style="font-size:0.72em;opacity:0.5;">STATUS CODES SEEN </span>'
        f'{code_pills}</div>'
    )
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def format_log(log):
    """Render the step log as an HTML table: last 20 requests, newest last.

    Each row shows step number, color-coded HTTP method, endpoint, status
    code, and reward; rows whose reward exceeds 0.2 get a "BUG FOUND" tag.
    """
    if not log:
        # Empty state doubles as an explanation of what the table shows.
        return (
            '<div style="opacity:0.55;font-size:0.85em;">'
            "Each row shows an API request the agent made, the HTTP status it got back, "
            "and the reward earned. Green = positive reward, red = penalty."
            "</div>"
        )
    # Conventional REST verb colors (fallback grey for anything else).
    method_colors = {
        "GET": "#2563eb", "POST": "#16a34a", "PUT": "#d97706",
        "DELETE": "#dc2626", "PATCH": "#9333ea",
    }
    rows = []
    # Only render the 20 most recent entries to keep the DOM small.
    for entry in log[-20:]:
        m = entry["method"]
        mcol = method_colors.get(m, "#6b7280")
        r = entry["reward"]
        rcol = "#16a34a" if r > 0 else "#dc2626" if r < 0 else "inherit"
        # Heuristic: rewards above 0.2 are assumed to include a bug bonus
        # — confirm against the reward shaping if thresholds change.
        bug_tag = (
            '<span style="background:#92400e;color:#fef08a;padding:0 5px;border-radius:4px;'
            'font-size:0.7em;margin-left:4px;">BUG FOUND</span>'
        ) if r > 0.2 else ""
        status = entry["status"]
        scol = "#16a34a" if 200 <= status < 300 else "#d97706" if 300 <= status < 400 else "#dc2626"
        rows.append(
            f'<div style="display:flex;align-items:center;gap:6px;padding:3px 0;'
            f'border-bottom:1px solid rgba(128,128,128,0.1);font-size:0.82em;">'
            f'<span style="opacity:0.45;min-width:20px;text-align:right;">{entry["step"]}</span>'
            f'<span style="background:{mcol}18;color:{mcol};padding:1px 6px;border-radius:4px;'
            f'font-weight:600;font-size:0.8em;min-width:52px;text-align:center;">{m}</span>'
            f'<span style="flex:1;overflow:hidden;text-overflow:ellipsis;'
            f'white-space:nowrap;">{entry["endpoint"]}</span>'
            f'<span style="color:{scol};font-weight:600;min-width:28px;text-align:right;">{status}</span>'
            f'<span style="color:{rcol};min-width:52px;text-align:right;font-family:monospace;'
            f'font-size:0.85em;">{r:+.3f}</span>{bug_tag}</div>'
        )
    # Note how many older entries were truncated, if any.
    omitted = ""
    if len(log) > 20:
        omitted = (
            f'<div style="opacity:0.45;font-size:0.78em;padding:4px 0;text-align:center;">'
            f'... {len(log) - 20} earlier steps not shown</div>'
        )
    header = (
        '<div style="opacity:0.55;font-size:0.78em;margin-bottom:6px;">'
        "API requests made by the agent. Each row: step number, HTTP method, "
        "endpoint, status code, and reward earned.</div>"
        '<div style="display:flex;gap:6px;padding:2px 0 6px;border-bottom:1px solid rgba(128,128,128,0.2);'
        'font-size:0.75em;opacity:0.5;">'
        '<span style="min-width:20px;text-align:right;">#</span>'
        '<span style="min-width:52px;text-align:center;">Method</span>'
        '<span style="flex:1;">Endpoint</span>'
        '<span style="min-width:28px;text-align:right;">Status</span>'
        '<span style="min-width:52px;text-align:right;">Reward</span></div>'
    )
    return header + omitted + "\n".join(rows)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def format_bug_list(bug_ids):
    """Render the set of discovered bug IDs as styled HTML cards.

    Each card shows the bug ID, a severity pill (color-coded), the OWASP
    category badge, and the description. Unknown IDs are silently skipped.
    """
    if not bug_ids:
        return "No bugs found yet."
    from server.bug_detector import BugDetector

    # Resolve IDs against the full catalog (security_workflows has all bugs).
    detector = BugDetector("security_workflows")
    palette = {
        "easy": "#16a34a",
        "medium": "#d97706",
        "hard": "#dc2626",
    }
    rendered = []
    for bug_id in sorted(bug_ids):
        info = detector.bugs.get(bug_id)
        if not info:
            # ID not in catalog — nothing to render for it.
            continue
        accent = palette.get(info.severity, "#6b7280")
        owasp_badge = f' | {info.owasp.split(" ")[0]}' if info.owasp else ""
        rendered.append(
            f'<div style="border:1px solid {accent}40;border-radius:8px;padding:8px 10px;'
            f'margin-bottom:6px;background:{accent}0d;">'
            f'<div style="display:flex;justify-content:space-between;align-items:center;">'
            f'<span style="font-weight:700;font-size:0.85em;">{bug_id}</span>'
            f'<span style="background:{accent};color:#fff;padding:1px 8px;border-radius:10px;'
            f'font-size:0.75em;font-weight:600;">{info.severity.upper()}{owasp_badge}</span></div>'
            f'<div style="margin-top:4px;font-size:0.85em;opacity:0.7;">'
            f'{info.description}</div>'
            f'<div style="margin-top:2px;font-size:0.78em;opacity:0.5;font-style:italic;">'
            f'{info.owasp}</div></div>'
        )
    return "\n".join(rendered)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def format_auth_tokens(tokens):
|
| 412 |
+
if not tokens:
|
| 413 |
+
return (
|
| 414 |
+
'<div style="opacity:0.5;font-size:0.85em;">'
|
| 415 |
+
"No tokens yet. Login via <code>POST /auth/login</code> to get auth tokens "
|
| 416 |
+
"for testing protected endpoints.</div>"
|
| 417 |
+
)
|
| 418 |
+
cards = []
|
| 419 |
+
for user, token in tokens.items():
|
| 420 |
+
cards.append(
|
| 421 |
+
f'<div style="display:flex;align-items:center;gap:8px;padding:4px 0;'
|
| 422 |
+
f'border-bottom:1px solid rgba(128,128,128,0.1);font-size:0.85em;">'
|
| 423 |
+
f'<span style="background:#2563eb18;color:#2563eb;padding:1px 8px;border-radius:10px;'
|
| 424 |
+
f'font-weight:600;font-size:0.8em;">{user}</span>'
|
| 425 |
+
f'<code style="opacity:0.55;font-size:0.82em;">{token[:20]}...</code></div>'
|
| 426 |
+
)
|
| 427 |
+
return (
|
| 428 |
+
'<div style="font-size:0.72em;opacity:0.5;margin-bottom:4px;"'
|
| 429 |
+
' title="Auth tokens obtained by logging in. Use these in the Authorization header.">'
|
| 430 |
+
"AUTHENTICATED USERS</div>"
|
| 431 |
+
+ "".join(cards)
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def format_resources(ids):
|
| 436 |
+
if not ids:
|
| 437 |
+
return (
|
| 438 |
+
'<div style="opacity:0.5;font-size:0.85em;">'
|
| 439 |
+
"No resources created. Use POST endpoints to create tasks or users "
|
| 440 |
+
"and track their IDs here.</div>"
|
| 441 |
+
)
|
| 442 |
+
sections = []
|
| 443 |
+
type_colors = {"tasks": "#d97706", "users": "#2563eb"}
|
| 444 |
+
for rtype, id_list in ids.items():
|
| 445 |
+
color = type_colors.get(rtype, "#6b7280")
|
| 446 |
+
ids_str = ", ".join(str(i) for i in id_list) if isinstance(id_list, list) else str(id_list)
|
| 447 |
+
sections.append(
|
| 448 |
+
f'<div style="padding:4px 0;border-bottom:1px solid rgba(128,128,128,0.1);font-size:0.85em;">'
|
| 449 |
+
f'<span style="background:{color}18;color:{color};padding:1px 8px;border-radius:10px;'
|
| 450 |
+
f'font-weight:600;font-size:0.8em;text-transform:uppercase;">{rtype}</span>'
|
| 451 |
+
f'<span style="margin-left:8px;opacity:0.7;">IDs: {ids_str}</span></div>'
|
| 452 |
+
)
|
| 453 |
+
return (
|
| 454 |
+
'<div style="font-size:0.72em;opacity:0.5;margin-bottom:4px;"'
|
| 455 |
+
' title="Resources created during this episode. Use these IDs in GET/PUT/DELETE requests.">'
|
| 456 |
+
"CREATED RESOURCES</div>"
|
| 457 |
+
+ "".join(sections)
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def format_endpoints():
    """Render the module-level API_SPEC as a markdown list of endpoints."""
    entries = [
        f"**{spec['method']}** `{spec['path']}` β {spec.get('summary', '')}"
        for spec in API_SPEC
    ]
    return "\n\n".join(entries)
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
# =====================================================================
|
| 469 |
+
# UI
|
| 470 |
+
# =====================================================================
|
| 471 |
+
|
| 472 |
+
def build_ui():
    """Assemble the full Gradio Blocks UI and return the (unlaunched) demo.

    Layout: three columns — environment control / scoreboard (left), manual
    testing + baseline agent tabs (center), and bug / report / log tabs
    (right). All event wiring happens at the bottom of this function.
    """
    with gr.Blocks(title="API Testing Environment") as demo:
        # Per-browser-session episode state (see new_session for its shape).
        session = gr.State(value=new_session())

        gr.Markdown(
            "# API Testing Environment\n"
            "An OpenEnv RL environment that trains AI agents to become automated **API security testers**. "
            "A simulated API server with **13 hidden vulnerabilities** mapped to the **OWASP API Security Top 10** is provided. "
            "Send HTTP requests, earn rewards for finding bugs and covering endpoints, and generate a **bug bounty report** at episode end. "
            "Use **Manual Testing** to craft requests yourself, or run a **Baseline Agent** to watch an automated strategy."
        )

        with gr.Row():
            # ── Left Panel: task selection, scoreboard, session context ──
            with gr.Column(scale=1):
                gr.Markdown("### Environment Control")
                task_dropdown = gr.Dropdown(choices=list(TASKS.keys()), value="basic_validation", label="Select Task")
                reset_btn = gr.Button("Reset Environment", variant="primary", size="lg")
                gr.Markdown(
                    '<span style="font-size:0.8em;opacity:0.55;">'
                    "Switch task or click Reset to start a fresh episode. "
                    "Resets all scores, bugs, and step count.</span>"
                )
                status_box = gr.Markdown("Initializing...")

                gr.Markdown("---")
                gr.Markdown("### Scoreboard")
                gr.Markdown(
                    '<span style="font-size:0.78em;opacity:0.55;">'
                    "Tracks your testing progress. Steps are API calls you've made; "
                    "bugs are issues discovered in the API; reward measures how well "
                    "the agent is testing.</span>"
                )
                with gr.Row():
                    step_display = gr.Markdown("0 / 25", label="Steps")
                    bug_display = gr.Markdown("0 / 3", label="Bugs")
                    reward_display = gr.Markdown(format_reward_display(0, 0, {}), label="Reward")
                    coverage_display = gr.Markdown("No data", label="Coverage")

                gr.Markdown("---")
                gr.Markdown("### Session Context")
                gr.Markdown(
                    '<span style="font-size:0.78em;opacity:0.55;">'
                    "Tokens and resources gathered during this episode. "
                    "Use tokens to test auth-protected endpoints and resource IDs for "
                    "GET/PUT/DELETE requests.</span>"
                )
                auth_display = gr.Markdown(format_auth_tokens({}))
                resource_display = gr.Markdown(format_resources({}))

                gr.Markdown("---")
                with gr.Accordion("API Specification", open=False):
                    gr.Markdown(format_endpoints())

            # ── Center Panel: manual request builder + baseline agent runner ──
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.Tab("Manual Testing"):
                        gr.Markdown("### Craft Your Request")
                        with gr.Row():
                            method_input = gr.Dropdown(
                                choices=["GET", "POST", "PUT", "DELETE", "PATCH"],
                                value="GET", label="Method", scale=1,
                            )
                            endpoint_input = gr.Textbox(value="/tasks", label="Endpoint", placeholder="/tasks, /users/1, /auth/login", scale=3)
                            expected_input = gr.Textbox(value="200", label="Expected Status", placeholder="200", scale=1)

                        with gr.Row():
                            headers_input = gr.Textbox(value="{}", label="Headers (JSON)", placeholder='{"Authorization": "Bearer ..."}', lines=1)
                            params_input = gr.Textbox(value="{}", label="Query Params (JSON)", placeholder='{"page": 1, "limit": 10}', lines=1)

                        body_input = gr.Textbox(value="", label="Request Body (JSON)", placeholder='{"title": "My Task", "description": "..."}', lines=3)

                        send_btn = gr.Button("Send Request", variant="primary", size="lg")

                        gr.Markdown("### Quick Actions")
                        # Preset requests; labels are parsed by apply_quick_action.
                        quick_actions = gr.Dropdown(
                            choices=[
                                "GET /tasks", "GET /users", "GET /tasks/1",
                                "GET /tasks/999999 (bug hunt)", "POST create task",
                                "POST missing title (bug hunt)", "Login as alice", "Login as bob",
                                "Login empty pwd (bug hunt)", "Negative page (bug hunt)",
                                "Huge limit (bug hunt)", "Invalid email PUT (bug hunt)",
                                "DELETE non-existent (bug hunt)", "Create user invalid email (bug)",
                                "SQL injection test", "Long title crash (bug hunt)",
                            ],
                            label="Quick Actions", value=None,
                        )
                        quick_btn = gr.Button("Load Quick Action", variant="secondary")

                    with gr.Tab("Run Baseline Agent"):
                        gr.Markdown("### Automated Agents\nWatch a baseline agent test the API step by step.")
                        agent_dropdown = gr.Dropdown(choices=["random", "sequential", "smart"], value="smart", label="Agent Type")
                        run_agent_btn = gr.Button("Run Agent", variant="primary", size="lg")

                gr.Markdown("---")
                gr.Markdown("### Response")
                response_display = gr.Markdown("")

                gr.Markdown("### Feedback")
                feedback_display = gr.Markdown("")

            # ── Right Panel: discovered bugs, generated report, activity log ──
            with gr.Column(scale=1):
                with gr.Tabs():
                    with gr.Tab("Discovered Bugs"):
                        bug_list_display = gr.Markdown("No bugs found yet.")

                    with gr.Tab("Bug Report"):
                        gr.Markdown("*Auto-generated OWASP security report. Populates as bugs are found.*")
                        bug_report_display = gr.Markdown("No bugs found yet. Send requests to discover vulnerabilities.")

                    with gr.Tab("Activity Log"):
                        log_display = gr.Markdown("No steps yet.")

        # ── Wiring ──
        # NOTE: output order must match the tuple order returned by the
        # corresponding handler functions (reset_env / send_request / ...).
        reset_outputs = [
            session, status_box, feedback_display, response_display,
            reward_display, bug_display, coverage_display, log_display,
            step_display, bug_list_display, bug_report_display, auth_display, resource_display,
        ]

        step_outputs = [
            session, feedback_display, response_display, reward_display,
            bug_display, coverage_display, log_display, step_display,
            bug_list_display, bug_report_display, auth_display, resource_display,
        ]

        reset_btn.click(fn=reset_env, inputs=[task_dropdown, session], outputs=reset_outputs)

        send_btn.click(
            fn=send_request,
            inputs=[method_input, endpoint_input, headers_input, params_input, body_input, expected_input, session],
            outputs=step_outputs,
        )

        quick_btn.click(
            fn=apply_quick_action, inputs=[quick_actions, session],
            outputs=[method_input, endpoint_input, headers_input, params_input, body_input, expected_input],
        )

        run_agent_btn.click(fn=run_baseline_agent, inputs=[agent_dropdown, session], outputs=step_outputs)

        # Auto-reset on page load so users can start testing immediately
        demo.load(fn=reset_env, inputs=[task_dropdown, session], outputs=reset_outputs)

    return demo
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
if __name__ == "__main__":
    # CLI entry point: parse server options, then build and launch the UI.
    cli = argparse.ArgumentParser()
    cli.add_argument("--port", type=int, default=int(os.getenv("GRADIO_SERVER_PORT", "7860")))
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--share", action="store_true")
    opts = cli.parse_args()
    demo = build_ui()
    demo.launch(server_name=opts.host, server_port=opts.port, share=opts.share)
|
inference.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
inference.py β OpenEnv API Testing Environment baseline inference script.
|
| 4 |
+
|
| 5 |
+
Runs an LLM agent against the API Testing Environment for all 3 tasks
|
| 6 |
+
(basic_validation -> edge_cases -> security_workflows) and emits the
|
| 7 |
+
mandatory [START]/[STEP]/[END] stdout format used by the OpenEnv judging
|
| 8 |
+
pipeline.
|
| 9 |
+
|
| 10 |
+
Required env vars (per OpenEnv submission spec):
|
| 11 |
+
API_BASE_URL The OpenAI-compatible LLM endpoint
|
| 12 |
+
MODEL_NAME The model identifier to use for inference
|
| 13 |
+
HF_TOKEN Bearer token for the LLM endpoint (or API_KEY)
|
| 14 |
+
|
| 15 |
+
Optional env vars:
|
| 16 |
+
IMAGE_NAME Docker image to spin up the env via from_docker_image()
|
| 17 |
+
LOCAL_IMAGE_NAME Alias for IMAGE_NAME
|
| 18 |
+
ENV_BASE_URL URL of an already-running env server (e.g. http://localhost:8000)
|
| 19 |
+
INFERENCE_TASKS Comma-separated subset of tasks to run (default: all 3)
|
| 20 |
+
INFERENCE_MAX_STEPS Override max steps per task
|
| 21 |
+
INFERENCE_TEMPERATURE Default 0.4
|
| 22 |
+
INFERENCE_MAX_TOKENS Default 4096 (plan completions need room for ~25 actions)
|
| 23 |
+
|
| 24 |
+
The script uses PLAN MODE: one LLM call per task produces a complete JSON
|
| 25 |
+
test plan, then the env executes each action sequentially. This matches the
|
| 26 |
+
GRPO training distribution and keeps total LLM cost to 3 calls per run, so
|
| 27 |
+
the script comfortably runs under 20 min on 2 vCPU / 8 GB RAM.
|
| 28 |
+
|
| 29 |
+
Usage:
|
| 30 |
+
# Local in-process (no Docker, fastest)
|
| 31 |
+
python inference.py
|
| 32 |
+
|
| 33 |
+
# Against a built docker image
|
| 34 |
+
IMAGE_NAME=api-testing-env:latest python inference.py
|
| 35 |
+
|
| 36 |
+
# Against an already running server
|
| 37 |
+
ENV_BASE_URL=http://localhost:8000 python inference.py
|
| 38 |
+
|
| 39 |
+
# Against a deployed HF Space
|
| 40 |
+
ENV_BASE_URL=https://your-user-api-testing-env.hf.space python inference.py
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
import json
|
| 44 |
+
import os
|
| 45 |
+
import sys
|
| 46 |
+
import time
|
| 47 |
+
import traceback
|
| 48 |
+
from typing import Any, Optional
|
| 49 |
+
|
| 50 |
+
# Make sibling modules importable when run from the repo root
|
| 51 |
+
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 52 |
+
if _THIS_DIR not in sys.path:
|
| 53 |
+
sys.path.insert(0, _THIS_DIR)
|
| 54 |
+
|
| 55 |
+
# Auto-load .env file if present (for local development)
|
| 56 |
+
# Judges set env vars directly so this is harmless in production
|
| 57 |
+
try:
|
| 58 |
+
from dotenv import load_dotenv
|
| 59 |
+
_env_path = os.path.join(_THIS_DIR, ".env")
|
| 60 |
+
if os.path.exists(_env_path):
|
| 61 |
+
load_dotenv(_env_path)
|
| 62 |
+
except ImportError:
|
| 63 |
+
pass # python-dotenv is optional
|
| 64 |
+
|
| 65 |
+
from openai import OpenAI
|
| 66 |
+
|
| 67 |
+
from models import APITestAction, HTTPMethod # noqa: E402
|
| 68 |
+
from training.prompts import ( # noqa: E402
|
| 69 |
+
PLAN_SYSTEM_PROMPT,
|
| 70 |
+
format_plan_prompt,
|
| 71 |
+
parse_test_plan,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# ---------------------------------------------------------------------------
# Config (env vars per OpenEnv spec)
# ---------------------------------------------------------------------------

# OpenAI-compatible LLM endpoint; defaults to the HuggingFace Inference Router.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# Default model: must be available on the HuggingFace Inference Router.
# Llama-3.3-70B-Instruct is reliable, follows JSON instructions well, and free.
# Override via: MODEL_NAME=other/model python inference.py
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.3-70B-Instruct")
# HF_TOKEN takes precedence; API_KEY is accepted as a fallback alias.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")

# Fail fast with actionable instructions — nothing downstream works without a token.
if not API_KEY:
    print(
        "[ERROR] No HF_TOKEN or API_KEY found in environment.\n"
        "    Set one of:\n"
        "      export HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx\n"
        "    Or create a .env file in this directory with:\n"
        "      HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx\n"
        "    Get a token from: https://huggingface.co/settings/tokens\n"
        "    Make sure it has 'Make calls to Inference Providers' permission.",
        file=sys.stderr,
    )
    sys.exit(1)

# Environment connection mode selectors (see _EnvHandle.open for precedence).
IMAGE_NAME = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
ENV_BASE_URL = os.getenv("ENV_BASE_URL")

BENCHMARK = "api_testing_env"
DEFAULT_TASKS = ["basic_validation", "edge_cases", "security_workflows"]
# INFERENCE_TASKS may select a comma-separated subset; blanks are dropped.
TASKS = [t.strip() for t in os.getenv("INFERENCE_TASKS", ",".join(DEFAULT_TASKS)).split(",") if t.strip()]

# LLM sampling parameters; plan completions need room for ~25 JSON actions.
TEMPERATURE = float(os.getenv("INFERENCE_TEMPERATURE", "0.4"))
MAX_TOKENS = int(os.getenv("INFERENCE_MAX_TOKENS", "4096"))
# Optional cap on steps per task; None means use the env's own max_steps.
_MAX_STEPS_OVERRIDE = os.getenv("INFERENCE_MAX_STEPS")
MAX_STEPS_OVERRIDE: Optional[int] = int(_MAX_STEPS_OVERRIDE) if _MAX_STEPS_OVERRIDE else None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ---------------------------------------------------------------------------
|
| 113 |
+
# Strict stdout logging β these line formats are checked by the judge
|
| 114 |
+
# ---------------------------------------------------------------------------
|
| 115 |
+
|
| 116 |
+
def log_start(task: str, env: str, model: str) -> None:
    """Emit the mandatory [START] banner line checked by the judge."""
    banner = f"[START] task={task} env={env} model={model}"
    print(banner, flush=True)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one [STEP] line: booleans lowercased, absent errors rendered as 'null'."""
    err_repr = error if error else "null"
    fields = (
        f"[STEP] step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={err_repr}",
    )
    print(" ".join(fields), flush=True)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Emit the final [END] summary line with comma-joined per-step rewards."""
    joined = ",".join(f"{r:.2f}" for r in rewards)
    summary = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={joined}"
    )
    print(summary, flush=True)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _action_str(action: APITestAction) -> str:
|
| 138 |
+
"""Compact human-readable action label for the [STEP] line."""
|
| 139 |
+
method = action.method.value if hasattr(action.method, "value") else str(action.method)
|
| 140 |
+
return f"{method}_{action.endpoint}"
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ---------------------------------------------------------------------------
|
| 144 |
+
# LLM call β plan mode (one completion per task)
|
| 145 |
+
# ---------------------------------------------------------------------------
|
| 146 |
+
|
| 147 |
+
def get_plan_from_llm(client: OpenAI, observation) -> str:
    """Ask the LLM for a complete JSON test plan for this task.

    Wraps the array in {"actions": [...]} so we can use OpenAI structured
    output mode (`response_format={"type": "json_object"}`), which forces
    the LLM to produce valid JSON. This is much more reliable than asking
    for a raw JSON array.

    Args:
        client: Configured OpenAI-compatible client.
        observation: Initial env observation passed to format_plan_prompt
            (opaque here — its shape is defined by the env; see training.prompts).

    Returns:
        The raw LLM completion text (stripped), or "" if both the
        structured-output call and the plain-text fallback fail.
    """
    user_prompt = format_plan_prompt(observation)

    # Stronger system prompt for structured output mode
    system_prompt = (
        PLAN_SYSTEM_PROMPT
        + "\n\nIMPORTANT: Output a JSON object with a single key 'actions' "
        + "containing the array of actions:\n"
        + '{"actions": [{"method": "GET", "endpoint": "/tasks", "headers": {}, '
        + '"query_params": {}, "body": null, "expected_status": 200}, ...]}'
    )

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            response_format={"type": "json_object"},  # forces valid JSON
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()
        print(f"[DEBUG] LLM response length: {len(text)} chars", flush=True)
        if len(text) > 0:
            preview = text[:300].replace("\n", " ")
            print(f"[DEBUG] LLM response preview: {preview}...", flush=True)
        else:
            # Empty content usually means truncation or provider error —
            # surface the finish_reason to make that diagnosable from logs.
            print(f"[DEBUG] LLM returned EMPTY string", flush=True)
            if hasattr(completion, "choices") and completion.choices:
                finish_reason = getattr(completion.choices[0], "finish_reason", None)
                print(f"[DEBUG] finish_reason: {finish_reason}", flush=True)
        return text
    except Exception as exc:  # noqa: BLE001
        print(f"[DEBUG] structured-output call failed ({type(exc).__name__}: {exc}), retrying without response_format...", flush=True)
        # Some providers don't support response_format — fall back to plain text
        try:
            completion = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": PLAN_SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
                stream=False,
            )
            text = (completion.choices[0].message.content or "").strip()
            print(f"[DEBUG] fallback LLM response length: {len(text)} chars", flush=True)
            return text
        except Exception as exc2:  # noqa: BLE001
            # Both attempts failed: return "" and let the caller treat it
            # as "no plan" rather than crashing the whole run.
            print(f"[DEBUG] fallback LLM call failed: {type(exc2).__name__}: {exc2}", flush=True)
            return ""
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# ---------------------------------------------------------------------------
|
| 212 |
+
# Per-task scoring helper β keeps the score in [0, 1]
|
| 213 |
+
# ---------------------------------------------------------------------------
|
| 214 |
+
|
| 215 |
+
def compute_task_score(state, total_step_reward: float) -> float:
    """Combine grader signals into a single normalized score in [0, 1].

    Raw step rewards are sums of partial signals and can exceed 1.0, so they
    are deliberately ignored (the parameter is kept for interface stability).
    The score is derived from the published final state instead:

        score = 0.7 * (bugs_found / total_bugs) + 0.3 * (coverage_pct / 100)

    which stays bounded in [0, 1] and rewards both bug finding and coverage.
    Missing or None state attributes are treated as zero.
    """
    found = getattr(state, "bugs_found", 0) or 0
    total = getattr(state, "total_bugs", 0) or 0
    coverage = getattr(state, "coverage_pct", 0.0) or 0.0

    bug_part = found / total if total > 0 else 0.0
    cov_part = min(1.0, max(0.0, coverage / 100.0))

    blended = 0.70 * bug_part + 0.30 * cov_part
    return min(1.0, max(0.0, blended))
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# ---------------------------------------------------------------------------
|
| 237 |
+
# Environment connector β supports docker / remote / in-process
|
| 238 |
+
# ---------------------------------------------------------------------------
|
| 239 |
+
|
| 240 |
+
class _EnvHandle:
    """Thin wrapper that exposes a uniform reset/step/state/close API.

    Three modes, picked automatically from module-level config:
    1. IMAGE_NAME set -> APITestEnv.from_docker_image(IMAGE_NAME)
    2. ENV_BASE_URL set -> APITestEnv(base_url=ENV_BASE_URL)
    3. neither set (default) -> APITestEnvironment() in-process

    Docker/remote modes go through the HTTP client (returns result objects
    with .observation/.reward/.done); local mode calls the environment
    directly (returns observations that carry .reward/.done themselves).
    """

    def __init__(self):
        # Which mode open() selected: "docker", "remote", or "local".
        self._mode: str = ""
        self._client = None  # remote/docker client
        self._env = None  # in-process env

    def open(self):
        """Pick a mode, create the backing client/env, and return self."""
        if IMAGE_NAME:
            # Imports are deferred so each mode only needs its own deps.
            from client import APITestEnv
            self._mode = "docker"
            self._client = APITestEnv.from_docker_image(IMAGE_NAME)
        elif ENV_BASE_URL:
            from client import APITestEnv
            self._mode = "remote"
            self._client = APITestEnv(base_url=ENV_BASE_URL)
            # Some client versions need an explicit connect(); call if present.
            if hasattr(self._client, "connect"):
                self._client.connect()
        else:
            from server.environment import APITestEnvironment
            self._mode = "local"
            self._env = APITestEnvironment()
        return self

    @property
    def mode(self) -> str:
        # Selected connection mode ("" until open() has run).
        return self._mode

    def reset(self, task_id: str, seed: int = 42):
        """Reset the episode; returns (observation, raw_result_or_None)."""
        if self._mode in ("docker", "remote"):
            result = self._client.reset(task_id=task_id, seed=seed)
            return result.observation, result
        obs = self._env.reset(seed=seed, task_id=task_id)
        return obs, None

    def step(self, action: APITestAction):
        """Execute one action; returns (observation, reward, done)."""
        if self._mode in ("docker", "remote"):
            result = self._client.step(action)
            return result.observation, result.reward or 0.0, result.done
        obs = self._env.step(action)
        return obs, (obs.reward or 0.0), obs.done

    def state(self):
        """Return the current environment state object."""
        if self._mode in ("docker", "remote"):
            return self._client.state()
        return self._env.state

    def close(self):
        """Best-effort teardown of the remote/docker client (local env needs none)."""
        try:
            if self._client is not None and hasattr(self._client, "close"):
                self._client.close()
        except Exception as exc:  # noqa: BLE001
            # Closing is best-effort at shutdown — log and keep going.
            print(f"[DEBUG] env close error: {exc}", flush=True)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# ---------------------------------------------------------------------------
|
| 303 |
+
# One full episode (one task) -> emits [START] / [STEP]* / [END]
|
| 304 |
+
# ---------------------------------------------------------------------------
|
| 305 |
+
|
| 306 |
+
def _lenient_parse_actions(plan_text: str, task_id: str) -> list:
    """Fallback parser: recover actions when `parse_test_plan` found none.

    Strips any chain-of-thought prefix (text before `</think>`), extracts the
    outermost JSON array from the remaining text, and converts each dict with
    a "method" key via `_dict_to_action`.  Never raises; returns a (possibly
    empty) list of actions.
    """
    import json as _json  # local import keeps the fallback self-contained

    actions: list = []
    try:
        cleaned = plan_text
        # Drop any reasoning-model chain-of-thought prefix.
        if "</think>" in cleaned:
            cleaned = cleaned.split("</think>", 1)[-1]
        # Find the outermost [ ... ] span and try to decode it as JSON.
        start = cleaned.find("[")
        end = cleaned.rfind("]")
        if start >= 0 and end > start:
            raw = _json.loads(cleaned[start:end + 1])
            if isinstance(raw, list):
                from training.prompts import _dict_to_action

                for item in raw:
                    if isinstance(item, dict) and "method" in item:
                        action = _dict_to_action(item)
                        if action:
                            actions.append(action)
        print(f"[DEBUG] {task_id}: lenient parse recovered {len(actions)} actions", flush=True)
    except Exception as exc:  # noqa: BLE001
        print(f"[DEBUG] {task_id}: lenient parse failed: {exc}", flush=True)
    return actions


def run_task(env: _EnvHandle, client: OpenAI, task_id: str, seed: int = 42) -> dict:
    """Run one full episode (one task) and return a summary dict.

    Emits one [START] line, one [STEP] line per env.step() call, and one
    [END] line.  Never raises: all failures are captured and reported in the
    returned ``error`` field.

    Returns a dict with keys: task_id, success, steps, score, rewards, error.
    """
    rewards: list[float] = []
    steps_taken = 0
    last_error: Optional[str] = None
    score = 0.0

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    try:
        obs, _ = env.reset(task_id=task_id, seed=seed)
        max_steps = MAX_STEPS_OVERRIDE or getattr(obs, "max_steps", 25)

        # 1) Ask the LLM for a full plan, then parse it into env actions.
        plan_text = get_plan_from_llm(client, obs)
        actions = parse_test_plan(plan_text) if plan_text else []

        # Fallback: strict parser failed but we do have text — retry with a
        # more lenient JSON-array extraction.
        if not actions and plan_text:
            print(f"[DEBUG] {task_id}: parse_test_plan returned 0, trying lenient parse...", flush=True)
            actions = _lenient_parse_actions(plan_text, task_id)
        if not actions:
            last_error = "no_plan_parsed"
            print(f"[DEBUG] {task_id}: model produced 0 valid actions", flush=True)

        # Never execute more actions than the episode allows.
        actions = actions[:max_steps]

        # 2) Execute each action and emit one [STEP] line per env.step().
        done = False
        for i, action in enumerate(actions, start=1):
            if done:
                break
            try:
                obs, reward, done = env.step(action)
                rewards.append(float(reward))
                steps_taken = i
                log_step(step=i, action=_action_str(action), reward=reward, done=done, error=None)
            except Exception as exc:  # noqa: BLE001
                # A failed step scores 0 but does not abort the episode.
                last_error = f"{type(exc).__name__}: {exc}"
                rewards.append(0.0)
                steps_taken = i
                log_step(step=i, action=_action_str(action), reward=0.0, done=False, error=last_error)

        # 3) Score from the final environment state.
        try:
            final_state = env.state()
            score = compute_task_score(final_state, sum(rewards))
        except Exception as exc:  # noqa: BLE001
            # Keep the earlier error if one exists; the state error is secondary.
            last_error = last_error or f"state_error: {exc}"
            score = 0.0

    except Exception as exc:  # noqa: BLE001
        last_error = f"{type(exc).__name__}: {exc}"
        traceback.print_exc()

    success = score >= 0.20  # any meaningful progress counts as a successful episode
    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return {
        "task_id": task_id,
        "success": success,
        "steps": steps_taken,
        "score": score,
        "rewards": rewards,
        "error": last_error,
    }
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# ---------------------------------------------------------------------------
# Main — runs all 3 tasks sequentially against ONE env handle
# ---------------------------------------------------------------------------
|
| 397 |
+
|
| 398 |
+
def main() -> None:
    """Run every benchmark task sequentially against a single env handle."""
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    print(
        f"[DEBUG] inference.py starting | model={MODEL_NAME} | "
        f"base_url={API_BASE_URL} | tasks={TASKS}",
        flush=True,
    )

    env = _EnvHandle().open()
    print(f"[DEBUG] env mode={env.mode}", flush=True)

    results: list[dict] = []
    started_at = time.time()
    try:
        for current_task in TASKS:
            episode = run_task(env, client, task_id=current_task, seed=42)
            results.append(episode)
    finally:
        # Always release the env handle, even if a task blew up.
        env.close()

    elapsed = time.time() - started_at
    avg_score = sum(item["score"] for item in results) / max(len(results), 1)
    print(
        f"[DEBUG] inference.py finished in {elapsed:.1f}s | "
        f"avg_score={avg_score:.3f}",
        flush=True,
    )
    per_task = {item["task_id"]: round(item["score"], 3) for item in results}
    print("[DEBUG] per-task scores: " + json.dumps(per_task), flush=True)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
# Script entry point: run all benchmark tasks when executed directly.
if __name__ == "__main__":
    main()
|
models.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data models for the API Testing Environment.
|
| 3 |
+
|
| 4 |
+
Defines Action, Observation, State for API integration testing training.
|
| 5 |
+
An AI agent learns to test REST APIs intelligently β discovering endpoints,
|
| 6 |
+
crafting requests, validating responses, finding bugs, and handling edge cases.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from enum import Enum
|
| 10 |
+
from typing import Any, Optional
|
| 11 |
+
|
| 12 |
+
from pydantic import Field
|
| 13 |
+
|
| 14 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class HTTPMethod(str, Enum):
    """HTTP verbs the agent may use when crafting a test request."""

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BugSeverity(str, Enum):
    """Difficulty tier of a planted bug."""

    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class APITestAction(Action):
    """What the agent sends each step — an HTTP request to test the API."""

    # Required fields: the verb (validated against HTTPMethod) and a
    # non-empty path relative to the API root.
    method: HTTPMethod = Field(..., description="HTTP method")
    endpoint: str = Field(..., min_length=1, description="API endpoint path, e.g. /tasks, /users/1")
    headers: dict[str, str] = Field(default_factory=dict, description="Request headers")
    query_params: dict[str, Any] = Field(default_factory=dict, description="URL query parameters")
    body: Optional[dict[str, Any]] = Field(default=None, description="Request JSON body")
    # Optional assertion made by the agent; the environment can compare it
    # against the actual status code when deciding whether a bug was found.
    expected_status: Optional[int] = Field(
        default=None,
        description="What the agent expects the status code to be (used for bug detection)",
    )
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class EndpointInfo(Action):
    """Information about a single API endpoint from the spec.

    NOTE(review): this subclasses Action even though it is descriptive data,
    presumably just to reuse the serializable base model — confirm intent.
    """

    method: str = ""
    path: str = ""
    summary: str = ""
    # OpenAPI-style parameter descriptors and (optional) body/response schemas.
    parameters: list[dict[str, Any]] = Field(default_factory=list)
    request_body_schema: Optional[dict[str, Any]] = None
    response_schema: Optional[dict[str, Any]] = None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class APITestObservation(Observation):
    """What the agent sees after each step.

    Populated by the environment from the last HTTP exchange plus running
    episode statistics; `Field` descriptions also surface in the schema.
    """

    # API spec info (provided on reset, updated each step)
    available_endpoints: list[dict[str, Any]] = Field(
        default_factory=list, description="Available API endpoints from the spec"
    )

    # Response from last request
    status_code: int = Field(default=0, description="HTTP status code of the response")
    response_body: Any = Field(default=None, description="Response body (JSON or text)")
    response_headers: dict[str, str] = Field(default_factory=dict, description="Response headers")
    response_time_ms: float = Field(default=0.0, description="Response time in milliseconds")

    # Feedback
    feedback: str = Field(default="", description="Human-readable feedback about the last action")
    bugs_found_so_far: int = Field(default=0, description="Number of bugs found so far")
    coverage_summary: dict[str, Any] = Field(
        default_factory=dict,
        description="Coverage stats: endpoints_tested, methods_used, status_codes_seen",
    )

    # Context from prior steps
    known_resource_ids: dict[str, list[Any]] = Field(
        default_factory=dict,
        description="Resource IDs created by POST requests, keyed by resource type",
    )
    auth_tokens: dict[str, str] = Field(
        default_factory=dict,
        description="Available auth tokens for different users/roles",
    )

    # Task info
    task_id: str = Field(default="", description="Current task identifier")
    task_description: str = Field(default="", description="Description of the current task")
    steps_taken: int = Field(default=0, description="Steps taken in this episode")
    max_steps: int = Field(default=30, description="Maximum steps per episode")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class APITestState(State):
    """Episode metadata — internal state exposed via state() endpoint."""

    task_id: str = ""
    task_description: str = ""
    difficulty: str = "easy"  # one of "easy" / "medium" / "hard"
    steps_taken: int = 0
    max_steps: int = 30
    bugs_found: int = 0  # bugs discovered so far this episode
    total_bugs: int = 0  # bugs planted for the current task
    bugs_found_ids: list[str] = Field(default_factory=list)
    # Name suggests an endpoint-coverage percentage — confirm range (0-100
    # vs 0-1) in the environment implementation.
    coverage_pct: float = 0.0
    endpoints_tested: int = 0
    total_endpoints: int = 0
    current_score: float = 0.0
    cumulative_reward: float = 0.0
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: api_testing_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
openenv_api_testing.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-api-testing
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: RL environment for intelligent API integration testing β train agents to find bugs in REST APIs
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Requires-Dist: openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.1
|
| 7 |
+
Requires-Dist: fastapi>=0.104.0
|
| 8 |
+
Requires-Dist: uvicorn>=0.24.0
|
| 9 |
+
Requires-Dist: httpx>=0.25.0
|
| 10 |
+
Requires-Dist: pydantic>=2.0.0
|
| 11 |
+
Provides-Extra: dev
|
| 12 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 13 |
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
| 14 |
+
Provides-Extra: train
|
| 15 |
+
Requires-Dist: trl[vllm]>=0.29.0; extra == "train"
|
| 16 |
+
Requires-Dist: torch>=2.8.0; extra == "train"
|
| 17 |
+
Requires-Dist: peft; extra == "train"
|
| 18 |
+
Requires-Dist: transformers; extra == "train"
|
| 19 |
+
Requires-Dist: datasets; extra == "train"
|
openenv_api_testing.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
./__init__.py
|
| 4 |
+
./baseline.py
|
| 5 |
+
./client.py
|
| 6 |
+
./models.py
|
| 7 |
+
openenv_api_testing.egg-info/PKG-INFO
|
| 8 |
+
openenv_api_testing.egg-info/SOURCES.txt
|
| 9 |
+
openenv_api_testing.egg-info/dependency_links.txt
|
| 10 |
+
openenv_api_testing.egg-info/entry_points.txt
|
| 11 |
+
openenv_api_testing.egg-info/requires.txt
|
| 12 |
+
openenv_api_testing.egg-info/top_level.txt
|
| 13 |
+
server/__init__.py
|
| 14 |
+
server/app.py
|
| 15 |
+
server/bug_detector.py
|
| 16 |
+
server/environment.py
|
| 17 |
+
server/graders.py
|
| 18 |
+
server/reward.py
|
| 19 |
+
server/buggy_api/__init__.py
|
| 20 |
+
server/buggy_api/database.py
|
| 21 |
+
server/buggy_api/main.py
|
| 22 |
+
server/buggy_api/models.py
|
| 23 |
+
server/buggy_api/routes/__init__.py
|
| 24 |
+
server/buggy_api/routes/auth.py
|
| 25 |
+
server/buggy_api/routes/tasks.py
|
| 26 |
+
server/buggy_api/routes/users.py
|
openenv_api_testing.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_api_testing.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = api_testing_env.server.app:main
|
openenv_api_testing.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.1
|
| 2 |
+
fastapi>=0.104.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
httpx>=0.25.0
|
| 5 |
+
pydantic>=2.0.0
|
| 6 |
+
|
| 7 |
+
[dev]
|
| 8 |
+
pytest>=8.0.0
|
| 9 |
+
pytest-cov>=4.0.0
|
| 10 |
+
|
| 11 |
+
[train]
|
| 12 |
+
trl[vllm]>=0.29.0
|
| 13 |
+
torch>=2.8.0
|
| 14 |
+
peft
|
| 15 |
+
transformers
|
| 16 |
+
datasets
|
openenv_api_testing.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
api_testing_env
|
pyproject.toml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-api-testing"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "RL environment for intelligent API integration testing β train agents to find bugs in REST APIs"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.1",
|
| 12 |
+
"fastapi>=0.104.0",
|
| 13 |
+
"uvicorn>=0.24.0",
|
| 14 |
+
"httpx>=0.25.0",
|
| 15 |
+
"pydantic>=2.0.0",
|
| 16 |
+
"openai>=1.40.0",
|
| 17 |
+
"gradio>=5.0.0",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
[project.optional-dependencies]
|
| 21 |
+
ui = [
|
| 22 |
+
"gradio>=5.0.0",
|
| 23 |
+
]
|
| 24 |
+
dev = [
|
| 25 |
+
"pytest>=8.0.0",
|
| 26 |
+
"pytest-cov>=4.0.0",
|
| 27 |
+
]
|
| 28 |
+
train = [
|
| 29 |
+
"trl>=0.15.0",
|
| 30 |
+
"torch>=2.1.0",
|
| 31 |
+
"peft>=0.7.0",
|
| 32 |
+
"transformers>=4.40.0",
|
| 33 |
+
"datasets>=2.16.0",
|
| 34 |
+
"wandb>=0.16.0",
|
| 35 |
+
"huggingface-hub>=0.20.0",
|
| 36 |
+
"matplotlib>=3.8.0",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
[project.scripts]
|
| 40 |
+
server = "api_testing_env.server.app:main"
|
| 41 |
+
|
| 42 |
+
[tool.uv]
|
| 43 |
+
package = false
|
| 44 |
+
|
| 45 |
+
[tool.setuptools]
|
| 46 |
+
include-package-data = true
|
| 47 |
+
packages = [
|
| 48 |
+
"api_testing_env",
|
| 49 |
+
"api_testing_env.server",
|
| 50 |
+
"api_testing_env.server.buggy_api",
|
| 51 |
+
"api_testing_env.server.buggy_api.routes",
|
| 52 |
+
"api_testing_env.training",
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
[tool.setuptools.package-dir]
|
| 56 |
+
api_testing_env = "."
|
| 57 |
+
"api_testing_env.server" = "server"
|
| 58 |
+
"api_testing_env.server.buggy_api" = "server/buggy_api"
|
| 59 |
+
"api_testing_env.server.buggy_api.routes" = "server/buggy_api/routes"
|
| 60 |
+
"api_testing_env.training" = "training"
|
requirements.txt
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies
|
| 2 |
+
openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.1
|
| 3 |
+
fastapi>=0.104.0
|
| 4 |
+
uvicorn>=0.24.0
|
| 5 |
+
httpx>=0.25.0
|
| 6 |
+
pydantic>=2.0.0,<2.12
|
| 7 |
+
|
| 8 |
+
# Training dependencies
|
| 9 |
+
# NOTE: PyTorch is NOT listed here β it must be installed separately
|
| 10 |
+
# with the correct CUDA version. See setup.sh or run:
|
| 11 |
+
# pip install torch --index-url https://download.pytorch.org/whl/cu121
|
| 12 |
+
trl>=0.15.0
|
| 13 |
+
peft>=0.7.0
|
| 14 |
+
transformers>=4.40.0
|
| 15 |
+
datasets>=2.16.0
|
| 16 |
+
|
| 17 |
+
# Weights & Biases (optional but recommended)
|
| 18 |
+
wandb>=0.16.0
|
| 19 |
+
|
| 20 |
+
# HuggingFace Hub (for model push)
|
| 21 |
+
huggingface-hub>=0.20.0
|
| 22 |
+
|
| 23 |
+
# Plots and metrics
|
| 24 |
+
matplotlib>=3.8.0
|
| 25 |
+
|
| 26 |
+
# UI
|
| 27 |
+
gradio>=5.0.0
|
server/__init__.py
ADDED
|
File without changes
|
server/app.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI application for the API Testing Environment.
|
| 3 |
+
|
| 4 |
+
Endpoints:
|
| 5 |
+
- POST /reset: Reset the environment
|
| 6 |
+
- POST /step: Execute an action
|
| 7 |
+
- GET /state: Get current environment state
|
| 8 |
+
- GET /schema: Get action/observation schemas
|
| 9 |
+
- WS /ws: WebSocket endpoint for persistent sessions
|
| 10 |
+
- GET / Info page
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import logging
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from openenv.core.env_server.http_server import create_app
|
| 21 |
+
from ..models import APITestAction, APITestObservation
|
| 22 |
+
from .environment import APITestEnvironment
|
| 23 |
+
except ImportError:
|
| 24 |
+
from openenv.core.env_server.http_server import create_app
|
| 25 |
+
from models import APITestAction, APITestObservation
|
| 26 |
+
from server.environment import APITestEnvironment
|
| 27 |
+
|
| 28 |
+
from fastapi.responses import RedirectResponse
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
# Build the standard OpenEnv HTTP app (reset/step/state/schema endpoints)
# around our environment and its action/observation models.
app = create_app(
    APITestEnvironment,
    APITestAction,
    APITestObservation,
    env_name="api_testing_env",
    # MAX_ENVS bounds how many environment instances may run concurrently.
    max_concurrent_envs=int(os.environ.get("MAX_ENVS", "1")),
)

# Track whether the Gradio UI is available so root can redirect to it
_GRADIO_MOUNTED = False
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@app.get("/info")
|
| 45 |
+
async def info():
|
| 46 |
+
"""JSON info about the environment (replaces the old `/` JSON endpoint)."""
|
| 47 |
+
return {
|
| 48 |
+
"name": "API Testing Environment",
|
| 49 |
+
"description": "An OpenEnv RL environment where an AI agent learns to test REST APIs intelligently",
|
| 50 |
+
"tasks": ["basic_validation", "edge_cases", "security_workflows"],
|
| 51 |
+
"ui": "/ui",
|
| 52 |
+
"docs": "/docs",
|
| 53 |
+
"schema": "/schema",
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@app.get("/tasks")
|
| 58 |
+
async def list_tasks():
|
| 59 |
+
"""List available tasks with descriptions."""
|
| 60 |
+
from .environment import TASKS
|
| 61 |
+
return {
|
| 62 |
+
task_id: {
|
| 63 |
+
"description": task["description"],
|
| 64 |
+
"difficulty": task["difficulty"],
|
| 65 |
+
"max_steps": task["max_steps"],
|
| 66 |
+
"total_bugs": task["total_bugs"],
|
| 67 |
+
}
|
| 68 |
+
for task_id, task in TASKS.items()
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
# Mount Gradio UI at /ui (only if gradio is installed and ENABLE_WEB_INTERFACE)
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
if os.environ.get("ENABLE_WEB_INTERFACE", "true").lower() in ("1", "true", "yes"):
    try:
        import gradio as gr  # type: ignore
        # Make the repo root importable so gradio_app's `from models import ...` works
        import sys
        _REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        if _REPO_ROOT not in sys.path:
            sys.path.insert(0, _REPO_ROOT)
        from gradio_app import build_ui  # type: ignore

        _gradio_ui = build_ui()
        app = gr.mount_gradio_app(app, _gradio_ui, path="/ui")
        _GRADIO_MOUNTED = True
        logger.info("Gradio UI mounted at /ui")
    except Exception as exc:  # noqa: BLE001
        # Best-effort: the env server works without the UI, so any failure
        # here (gradio missing, build_ui error) only logs a warning.
        logger.warning(f"Skipping Gradio mount ({type(exc).__name__}: {exc})")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
# Root redirect: send visitors to the Gradio UI if mounted, else to JSON info
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
@app.get("/", include_in_schema=False)
|
| 97 |
+
async def root_redirect():
|
| 98 |
+
"""Redirect / to the Gradio UI when available, otherwise to /info JSON."""
|
| 99 |
+
if _GRADIO_MOUNTED:
|
| 100 |
+
return RedirectResponse(url="/ui", status_code=307)
|
| 101 |
+
return RedirectResponse(url="/info", status_code=307)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def main(host: str | None = None, port: int | None = None):
    """Entry point for `uv run server` and `python -m server.app`.

    When invoked from the CLI without args, parses argv for --host / --port.
    """
    import uvicorn

    if host is None or port is None:
        import argparse
        parser = argparse.ArgumentParser(description="API Testing Environment server")
        parser.add_argument("--host", default="0.0.0.0")
        parser.add_argument("--port", type=int, default=None)
        # parse_known_args tolerates unrelated argv entries.
        args, _ = parser.parse_known_args()
        host = host or args.host
        port = port or args.port

    # Fall back to the PORT env var, then 8000.
    if port is None:
        port = int(os.environ.get("PORT", "8000"))

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    )
    # Quiet per-request noise from the HTTP client/server libraries.
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)
    logging.getLogger("uvicorn.access").setLevel(logging.WARNING)

    uvicorn.run(app, host=host, port=port)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# Allow direct execution: `python -m server.app` or `python server/app.py`.
if __name__ == "__main__":
    main()
|
server/bug_detector.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Bug detection logic β checks if the agent's action/response pair reveals a planted bug.
|
| 3 |
+
|
| 4 |
+
Each bug has:
|
| 5 |
+
- A unique ID
|
| 6 |
+
- A severity level (easy/medium/hard)
|
| 7 |
+
- A detection function that checks action + response
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from typing import Any, Callable, Optional
|
| 12 |
+
import re
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class Bug:
    """A planted bug that the agent can discover."""

    id: str
    severity: str  # "easy", "medium", "hard"
    description: str
    category: str  # "status_code", "validation", "security", "data_integrity"
    owasp: str = ""  # OWASP API Security Top 10 (2023) category
    recommendation: str = ""  # Fix recommendation for bug bounty reports
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class BugDetection:
    """A successful detection: the bug plus how it was revealed."""

    bug: Bug
    evidence: str  # Human-readable explanation of how the bug was detected
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BugDetector:
|
| 32 |
+
"""Detects planted bugs based on agent actions and API responses."""
|
| 33 |
+
|
| 34 |
+
def __init__(self, task_id: str):
    # The task id selects which subset of bugs is in play
    # (see get_bugs_for_task).
    self.task_id = task_id
    self._build_bug_registry()
|
| 37 |
+
|
| 38 |
+
def _build_bug_registry(self):
    """Define all bugs with their detection logic.

    Registration (insertion) order matters: detectors are iterated in this
    order when checking an action/response pair, so earlier bugs take
    priority when multiple would match.
    """
    self.bugs: dict[str, Bug] = {}
    self.detectors: dict[str, Callable] = {}

    # === EASY BUGS ===

    self._register_bug(
        Bug("BUG_TASK_01", "easy",
            "GET /tasks/{id} returns 200 with null for non-existent task",
            "status_code",
            owasp="API8:2023 Security Misconfiguration",
            recommendation="Return 404 Not Found for non-existent resources"),
        self._detect_null_response_for_missing_task,
    )
    self._register_bug(
        Bug("BUG_TASK_02", "easy",
            "POST /tasks with missing title returns 500 instead of 400/422",
            "validation",
            owasp="API8:2023 Security Misconfiguration",
            recommendation="Validate required fields and return 400/422 with descriptive error"),
        self._detect_missing_field_500,
    )
    self._register_bug(
        Bug("BUG_TASK_03", "easy",
            "GET /tasks?page=-1 returns 200 instead of 400",
            "validation",
            owasp="API8:2023 Security Misconfiguration",
            recommendation="Validate pagination parameters: page >= 1, limit > 0"),
        self._detect_negative_page,
    )

    # === MEDIUM BUGS ===

    self._register_bug(
        Bug("BUG_TASK_04", "medium",
            "PUT /tasks/{id} accepts invalid email format for assignee_email",
            "validation",
            owasp="API8:2023 Security Misconfiguration",
            recommendation="Validate email format with regex before accepting"),
        self._detect_invalid_email_accepted,
    )
    self._register_bug(
        Bug("BUG_TASK_05", "medium",
            "DELETE /tasks/{id} returns 200 for non-existent task",
            "status_code",
            owasp="API8:2023 Security Misconfiguration",
            recommendation="Check resource existence before deletion, return 404 if missing"),
        self._detect_delete_nonexistent_200,
    )
    self._register_bug(
        Bug("BUG_TASK_06", "medium",
            "GET /tasks?limit=999999 has no pagination cap",
            "validation",
            owasp="API4:2023 Unrestricted Resource Consumption",
            recommendation="Cap pagination limit at 100, reject values above maximum"),
        self._detect_no_pagination_cap,
    )
    self._register_bug(
        Bug("BUG_USER_01", "medium",
            "POST /users accepts invalid email format",
            "validation",
            owasp="API8:2023 Security Misconfiguration",
            recommendation="Validate email format server-side before creating user"),
        self._detect_user_invalid_email,
    )
    self._register_bug(
        Bug("BUG_USER_02", "medium",
            "POST /users response exposes password hash",
            "security",
            owasp="API3:2023 Broken Object Property Level Authorization",
            recommendation="Never return sensitive fields (password_hash) in API responses"),
        self._detect_password_hash_exposed,
    )
    self._register_bug(
        Bug("BUG_AUTH_02", "medium",
            "Login with empty password succeeds",
            "security",
            owasp="API2:2023 Broken Authentication",
            recommendation="Validate password is non-empty and verify against stored hash"),
        self._detect_empty_password_login,
    )

    # === HARD BUGS ===

    self._register_bug(
        Bug("BUG_TASK_07", "hard",
            "BOLA: User A can access User B's tasks without authorization check",
            "security",
            owasp="API1:2023 Broken Object Level Authorization",
            recommendation="Verify resource ownership: check task.owner_id matches authenticated user"),
        self._detect_bola,
    )
    self._register_bug(
        Bug("BUG_TASK_08", "hard",
            "POST /tasks with very long title (>5000 chars) causes 500",
            "validation",
            owasp="API4:2023 Unrestricted Resource Consumption",
            recommendation="Add input length validation: title max 200 chars"),
        self._detect_long_input_crash,
    )
    self._register_bug(
        Bug("BUG_TASK_09", "hard",
            "SQL injection payload in title is stored verbatim (content injection)",
            "security",
            owasp="API8:2023 Security Misconfiguration",
            recommendation="Sanitize user input before storage, escape HTML/SQL special characters"),
        self._detect_content_injection,
    )
    self._register_bug(
        Bug("BUG_AUTH_01", "hard",
            "Auth tokens not user-scoped: User A's token can modify User B's tasks",
            "security",
            owasp="API1:2023 Broken Object Level Authorization",
            recommendation="Enforce ownership check on all write operations (PUT/DELETE)"),
        self._detect_broken_auth,
    )
|
| 155 |
+
|
| 156 |
+
def _register_bug(self, bug: Bug, detector: Callable):
|
| 157 |
+
self.bugs[bug.id] = bug
|
| 158 |
+
self.detectors[bug.id] = detector
|
| 159 |
+
|
| 160 |
+
def get_bugs_for_task(self) -> list[Bug]:
|
| 161 |
+
"""Return bugs relevant to the current task."""
|
| 162 |
+
if self.task_id == "basic_validation":
|
| 163 |
+
return [self.bugs[bid] for bid in ["BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03"]]
|
| 164 |
+
elif self.task_id == "edge_cases":
|
| 165 |
+
return [
|
| 166 |
+
self.bugs[bid]
|
| 167 |
+
for bid in [
|
| 168 |
+
"BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03",
|
| 169 |
+
"BUG_TASK_04", "BUG_TASK_05", "BUG_TASK_06",
|
| 170 |
+
"BUG_USER_01", "BUG_USER_02", "BUG_AUTH_02",
|
| 171 |
+
]
|
| 172 |
+
]
|
| 173 |
+
else: # security_workflows
|
| 174 |
+
return list(self.bugs.values())
|
| 175 |
+
|
| 176 |
+
def check(
|
| 177 |
+
self,
|
| 178 |
+
method: str,
|
| 179 |
+
endpoint: str,
|
| 180 |
+
headers: dict,
|
| 181 |
+
query_params: dict,
|
| 182 |
+
body: Optional[dict],
|
| 183 |
+
expected_status: Optional[int],
|
| 184 |
+
response_status: int,
|
| 185 |
+
response_body: Any,
|
| 186 |
+
action_history: list[dict],
|
| 187 |
+
found_bugs: set[str],
|
| 188 |
+
) -> Optional[BugDetection]:
|
| 189 |
+
"""Check if this action/response reveals a bug.
|
| 190 |
+
|
| 191 |
+
Returns the first new bug detected, or None.
|
| 192 |
+
"""
|
| 193 |
+
ctx = {
|
| 194 |
+
"method": method.upper(),
|
| 195 |
+
"endpoint": endpoint,
|
| 196 |
+
"headers": headers,
|
| 197 |
+
"query_params": query_params,
|
| 198 |
+
"body": body,
|
| 199 |
+
"expected_status": expected_status,
|
| 200 |
+
"response_status": response_status,
|
| 201 |
+
"response_body": response_body,
|
| 202 |
+
"action_history": action_history,
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
for bug_id, detector in self.detectors.items():
|
| 206 |
+
if bug_id in found_bugs:
|
| 207 |
+
continue
|
| 208 |
+
# Only check bugs relevant to this task
|
| 209 |
+
task_bugs = {b.id for b in self.get_bugs_for_task()}
|
| 210 |
+
if bug_id not in task_bugs:
|
| 211 |
+
continue
|
| 212 |
+
result = detector(ctx)
|
| 213 |
+
if result:
|
| 214 |
+
return BugDetection(bug=self.bugs[bug_id], evidence=result)
|
| 215 |
+
|
| 216 |
+
return None
|
| 217 |
+
|
| 218 |
+
# === DETECTION FUNCTIONS ===
|
| 219 |
+
|
| 220 |
+
def _detect_null_response_for_missing_task(self, ctx: dict) -> Optional[str]:
|
| 221 |
+
if (
|
| 222 |
+
ctx["method"] == "GET"
|
| 223 |
+
and re.match(r"^/tasks/\d+$", ctx["endpoint"])
|
| 224 |
+
and ctx["response_status"] == 200
|
| 225 |
+
and ctx["response_body"] is None
|
| 226 |
+
):
|
| 227 |
+
task_id = ctx["endpoint"].split("/")[-1]
|
| 228 |
+
return f"GET /tasks/{task_id} returned 200 with null body β should be 404"
|
| 229 |
+
return None
|
| 230 |
+
|
| 231 |
+
def _detect_missing_field_500(self, ctx: dict) -> Optional[str]:
|
| 232 |
+
if (
|
| 233 |
+
ctx["method"] == "POST"
|
| 234 |
+
and ctx["endpoint"] == "/tasks"
|
| 235 |
+
and ctx["response_status"] == 500
|
| 236 |
+
and ctx["body"] is not None
|
| 237 |
+
and "title" not in ctx["body"]
|
| 238 |
+
):
|
| 239 |
+
return "POST /tasks with missing 'title' returned 500 β should be 400 or 422"
|
| 240 |
+
return None
|
| 241 |
+
|
| 242 |
+
def _detect_negative_page(self, ctx: dict) -> Optional[str]:
|
| 243 |
+
if (
|
| 244 |
+
ctx["method"] == "GET"
|
| 245 |
+
and ctx["endpoint"] == "/tasks"
|
| 246 |
+
and ctx["query_params"].get("page") is not None
|
| 247 |
+
):
|
| 248 |
+
page = ctx["query_params"]["page"]
|
| 249 |
+
try:
|
| 250 |
+
page = int(page)
|
| 251 |
+
except (ValueError, TypeError):
|
| 252 |
+
return None
|
| 253 |
+
if page < 1 and ctx["response_status"] == 200:
|
| 254 |
+
return f"GET /tasks?page={page} returned 200 β should be 400 for invalid page"
|
| 255 |
+
return None
|
| 256 |
+
|
| 257 |
+
def _detect_invalid_email_accepted(self, ctx: dict) -> Optional[str]:
|
| 258 |
+
if (
|
| 259 |
+
ctx["method"] == "PUT"
|
| 260 |
+
and re.match(r"^/tasks/\d+$", ctx["endpoint"])
|
| 261 |
+
and ctx["body"]
|
| 262 |
+
and "assignee_email" in ctx["body"]
|
| 263 |
+
and ctx["response_status"] in (200, 201)
|
| 264 |
+
):
|
| 265 |
+
email = ctx["body"]["assignee_email"]
|
| 266 |
+
if email and not re.match(r"^[^@]+@[^@]+\.[^@]+$", email):
|
| 267 |
+
return f"PUT accepted invalid email '{email}' without validation"
|
| 268 |
+
return None
|
| 269 |
+
|
| 270 |
+
def _detect_delete_nonexistent_200(self, ctx: dict) -> Optional[str]:
|
| 271 |
+
if (
|
| 272 |
+
ctx["method"] == "DELETE"
|
| 273 |
+
and re.match(r"^/tasks/\d+$", ctx["endpoint"])
|
| 274 |
+
and ctx["response_status"] == 200
|
| 275 |
+
):
|
| 276 |
+
task_id = int(ctx["endpoint"].split("/")[-1])
|
| 277 |
+
# Check if this task was never created (ID > 1000 is a safe bet for non-existent)
|
| 278 |
+
if task_id > 100:
|
| 279 |
+
return f"DELETE /tasks/{task_id} returned 200 for non-existent task β should be 404"
|
| 280 |
+
return None
|
| 281 |
+
|
| 282 |
+
def _detect_no_pagination_cap(self, ctx: dict) -> Optional[str]:
|
| 283 |
+
if (
|
| 284 |
+
ctx["method"] == "GET"
|
| 285 |
+
and ctx["endpoint"] == "/tasks"
|
| 286 |
+
and ctx["response_status"] == 200
|
| 287 |
+
):
|
| 288 |
+
limit = ctx["query_params"].get("limit")
|
| 289 |
+
if limit is not None:
|
| 290 |
+
try:
|
| 291 |
+
limit = int(limit)
|
| 292 |
+
except (ValueError, TypeError):
|
| 293 |
+
return None
|
| 294 |
+
if limit > 1000:
|
| 295 |
+
return f"GET /tasks?limit={limit} accepted without pagination cap β potential DoS"
|
| 296 |
+
return None
|
| 297 |
+
|
| 298 |
+
def _detect_user_invalid_email(self, ctx: dict) -> Optional[str]:
|
| 299 |
+
if (
|
| 300 |
+
ctx["method"] == "POST"
|
| 301 |
+
and ctx["endpoint"] == "/users"
|
| 302 |
+
and ctx["body"]
|
| 303 |
+
and "email" in ctx["body"]
|
| 304 |
+
and ctx["response_status"] == 201
|
| 305 |
+
):
|
| 306 |
+
email = ctx["body"]["email"]
|
| 307 |
+
if email and not re.match(r"^[^@]+@[^@]+\.[^@]+$", email):
|
| 308 |
+
return f"POST /users accepted invalid email '{email}'"
|
| 309 |
+
return None
|
| 310 |
+
|
| 311 |
+
def _detect_password_hash_exposed(self, ctx: dict) -> Optional[str]:
|
| 312 |
+
if (
|
| 313 |
+
ctx["method"] == "POST"
|
| 314 |
+
and ctx["endpoint"] == "/users"
|
| 315 |
+
and ctx["response_status"] == 201
|
| 316 |
+
and isinstance(ctx["response_body"], dict)
|
| 317 |
+
):
|
| 318 |
+
if "password_hash" in ctx["response_body"]:
|
| 319 |
+
return "POST /users response exposes password_hash field β security vulnerability"
|
| 320 |
+
return None
|
| 321 |
+
|
| 322 |
+
def _detect_empty_password_login(self, ctx: dict) -> Optional[str]:
|
| 323 |
+
if (
|
| 324 |
+
ctx["method"] == "POST"
|
| 325 |
+
and ctx["endpoint"] == "/auth/login"
|
| 326 |
+
and ctx["body"]
|
| 327 |
+
and ctx["response_status"] == 200
|
| 328 |
+
):
|
| 329 |
+
password = ctx["body"].get("password", "NOTEMPTY")
|
| 330 |
+
if password == "" or password is None:
|
| 331 |
+
return "Login with empty password succeeded β authentication bypass"
|
| 332 |
+
return None
|
| 333 |
+
|
| 334 |
+
    def _detect_bola(self, ctx: dict) -> Optional[str]:
        """Detect if agent tested cross-user resource access.

        Fires when a GET /tasks/{id} with an Authorization header returns a
        task dict whose owner_id differs from the user_id of a prior
        successful /auth/login found in the action history.
        """
        if (
            ctx["method"] == "GET"
            and re.match(r"^/tasks/\d+$", ctx["endpoint"])
            and ctx["response_status"] == 200
            and ctx["response_body"] is not None
            and isinstance(ctx["response_body"], dict)
            and ctx["headers"].get("Authorization")
        ):
            # Check if the agent logged in as a different user and accessed another's task
            # Scans history newest-first so the most recent login is tried first.
            for prev in reversed(ctx["action_history"]):
                if (
                    prev.get("method") == "POST"
                    and prev.get("endpoint") == "/auth/login"
                    and prev.get("response_status") == 200
                    and isinstance(prev.get("response_body"), dict)
                ):
                    login_user_id = prev["response_body"].get("user_id")
                    task_owner_id = ctx["response_body"].get("owner_id")
                    # NOTE(review): no break when ids match — older logins keep
                    # being scanned, so a stale login by a different user could
                    # also trigger this. Confirm whether that is intended.
                    if (
                        login_user_id is not None
                        and task_owner_id is not None
                        and login_user_id != task_owner_id
                    ):
                        return (
                            f"User {login_user_id} accessed task owned by user {task_owner_id} "
                            f"β BOLA/IDOR vulnerability (no authorization check)"
                        )
        return None
|
| 364 |
+
|
| 365 |
+
def _detect_long_input_crash(self, ctx: dict) -> Optional[str]:
|
| 366 |
+
if (
|
| 367 |
+
ctx["method"] == "POST"
|
| 368 |
+
and ctx["endpoint"] == "/tasks"
|
| 369 |
+
and ctx["body"]
|
| 370 |
+
and ctx["response_status"] == 500
|
| 371 |
+
):
|
| 372 |
+
title = ctx["body"].get("title", "")
|
| 373 |
+
if isinstance(title, str) and len(title) > 5000:
|
| 374 |
+
return f"POST /tasks with title length {len(title)} caused 500 β no input length validation"
|
| 375 |
+
return None
|
| 376 |
+
|
| 377 |
+
def _detect_content_injection(self, ctx: dict) -> Optional[str]:
|
| 378 |
+
if (
|
| 379 |
+
ctx["method"] == "POST"
|
| 380 |
+
and ctx["endpoint"] == "/tasks"
|
| 381 |
+
and ctx["body"]
|
| 382 |
+
and ctx["response_status"] == 201
|
| 383 |
+
and isinstance(ctx["response_body"], dict)
|
| 384 |
+
):
|
| 385 |
+
title = ctx["body"].get("title", "")
|
| 386 |
+
injection_patterns = [
|
| 387 |
+
"DROP TABLE", "DELETE FROM", "<script>", "javascript:",
|
| 388 |
+
"'; --", "\" OR 1=1", "UNION SELECT",
|
| 389 |
+
]
|
| 390 |
+
for pattern in injection_patterns:
|
| 391 |
+
if pattern.lower() in str(title).lower():
|
| 392 |
+
stored_title = ctx["response_body"].get("title", "")
|
| 393 |
+
if pattern.lower() in str(stored_title).lower():
|
| 394 |
+
return (
|
| 395 |
+
f"Injection payload '{pattern}' in title was stored verbatim "
|
| 396 |
+
f"β no input sanitization (content injection)"
|
| 397 |
+
)
|
| 398 |
+
return None
|
| 399 |
+
|
| 400 |
+
    def _detect_broken_auth(self, ctx: dict) -> Optional[str]:
        """Detect if agent successfully modified another user's task with their own token.

        Fires on a 200 from PUT/DELETE /tasks/{id} (with an Authorization
        header) when the most recent successful /auth/login in the history was
        by a user other than the task's owner_id in the response body.
        """
        if (
            ctx["method"] in ("PUT", "DELETE")
            and re.match(r"^/tasks/\d+$", ctx["endpoint"])
            and ctx["response_status"] == 200
            and ctx["headers"].get("Authorization")
        ):
            # Newest-first scan; `break` below stops at the most recent login.
            for prev in reversed(ctx["action_history"]):
                if (
                    prev.get("method") == "POST"
                    and prev.get("endpoint") == "/auth/login"
                    and prev.get("response_status") == 200
                    and isinstance(prev.get("response_body"), dict)
                ):
                    login_user_id = prev["response_body"].get("user_id")
                    # Check if the task belonged to a different user
                    # NOTE(review): task_id is computed but never used below.
                    task_id = int(ctx["endpoint"].split("/")[-1])
                    # NOTE(review): DELETE responses may not carry an owner_id
                    # dict body, in which case this detector cannot fire for
                    # DELETE — confirm against the route implementation.
                    if isinstance(ctx["response_body"], dict):
                        task_owner = ctx["response_body"].get("owner_id")
                        if (
                            login_user_id is not None
                            and task_owner is not None
                            and login_user_id != task_owner
                        ):
                            return (
                                f"User {login_user_id}'s token modified task owned by user {task_owner} "
                                f"β broken authorization"
                            )
                    break
        return None
|
server/buggy_api/__init__.py
ADDED
|
File without changes
|
server/buggy_api/database.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
In-memory SQLite database for the buggy API.
|
| 3 |
+
Supports reset between episodes with DOMAIN RANDOMIZATION β
|
| 4 |
+
each seed produces different users, tasks, and data distributions
|
| 5 |
+
so that every training episode is unique.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import random
|
| 9 |
+
import sqlite3
|
| 10 |
+
import threading
|
| 11 |
+
from contextlib import contextmanager
|
| 12 |
+
|
| 13 |
+
# Name pools for randomized seed data
|
| 14 |
+
FIRST_NAMES = [
|
| 15 |
+
"alice", "bob", "charlie", "diana", "ethan", "fiona", "george", "hannah",
|
| 16 |
+
"ivan", "julia", "kevin", "luna", "mike", "nina", "oscar", "priya",
|
| 17 |
+
"quinn", "ravi", "sara", "tom", "uma", "victor", "wendy", "xander",
|
| 18 |
+
]
|
| 19 |
+
DOMAINS = ["example.com", "company.org", "startup.io", "work.dev", "test.net"]
|
| 20 |
+
TASK_TITLES = [
|
| 21 |
+
"Setup CI/CD pipeline", "Write unit tests", "Fix login page CSS",
|
| 22 |
+
"Database migration", "API documentation", "Refactor auth module",
|
| 23 |
+
"Add rate limiting", "Setup monitoring", "Fix memory leak",
|
| 24 |
+
"Update dependencies", "Add logging middleware", "Create admin panel",
|
| 25 |
+
"Implement caching", "Fix CORS issues", "Add input validation",
|
| 26 |
+
"Setup Docker compose", "Write integration tests", "Fix date parsing bug",
|
| 27 |
+
"Add search functionality", "Implement pagination", "Setup SSL certs",
|
| 28 |
+
"Add webhook support", "Fix timezone handling", "Create backup script",
|
| 29 |
+
"Optimize database queries", "Add email notifications", "Fix file upload",
|
| 30 |
+
"Implement user roles", "Add audit logging", "Setup load balancer",
|
| 31 |
+
]
|
| 32 |
+
TASK_DESCRIPTIONS = [
|
| 33 |
+
"Configure GitHub Actions for automated deployment",
|
| 34 |
+
"Add tests for the auth module endpoints",
|
| 35 |
+
"Button alignment issue on mobile devices",
|
| 36 |
+
"Migrate from SQLite to PostgreSQL",
|
| 37 |
+
"Document all REST endpoints with examples",
|
| 38 |
+
"Break down the monolithic auth into smaller services",
|
| 39 |
+
"Prevent API abuse with request throttling",
|
| 40 |
+
"Setup Grafana dashboards for key metrics",
|
| 41 |
+
"Memory usage grows unbounded after 1000 requests",
|
| 42 |
+
"Several packages have critical CVEs",
|
| 43 |
+
"Add structured JSON logging to all routes",
|
| 44 |
+
"Build an admin dashboard for user management",
|
| 45 |
+
"Add Redis caching layer for frequent queries",
|
| 46 |
+
"Frontend gets blocked by CORS policy",
|
| 47 |
+
"Sanitize user inputs to prevent injection",
|
| 48 |
+
]
|
| 49 |
+
STATUSES = ["pending", "in_progress", "done"]
|
| 50 |
+
PRIORITIES = ["low", "medium", "high"]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Database:
    """Thread-safe in-memory SQLite database that can be reset between episodes.

    When a seed is provided, the database is populated with deterministically
    randomized data β different users, tasks, and distributions each time.
    This prevents the agent from memorizing a single fixed dataset.
    """

    def __init__(self, seed: int | None = None):
        # Serializes all access to the single shared connection.
        self._lock = threading.Lock()
        self._conn: sqlite3.Connection | None = None
        # None -> fixed demo dataset; int -> deterministic randomized dataset.
        self._seed = seed
        self.initialize()

    def initialize(self):
        """Create a fresh database with schema and seed data."""
        with self._lock:
            if self._conn:
                self._conn.close()
            # check_same_thread=False: the connection is shared across threads
            # and guarded by self._lock instead of sqlite3's own thread check.
            self._conn = sqlite3.connect(":memory:", check_same_thread=False)
            self._conn.row_factory = sqlite3.Row
            self._conn.execute("PRAGMA journal_mode=WAL")
            # These helpers use self._conn directly (not get_cursor), so
            # calling them while holding the lock cannot deadlock.
            self._create_schema()
            self._seed_data()

    def _create_schema(self):
        # Creates the three tables; idempotent via IF NOT EXISTS.
        cursor = self._conn.cursor()
        cursor.executescript("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                email TEXT NOT NULL,
                password_hash TEXT NOT NULL,
                role TEXT DEFAULT 'user',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );

            CREATE TABLE IF NOT EXISTS tasks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                title TEXT NOT NULL,
                description TEXT DEFAULT '',
                status TEXT DEFAULT 'pending',
                priority TEXT DEFAULT 'medium',
                assignee_email TEXT DEFAULT '',
                owner_id INTEGER NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (owner_id) REFERENCES users(id)
            );

            CREATE TABLE IF NOT EXISTS auth_tokens (
                token TEXT PRIMARY KEY,
                user_id INTEGER NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                expires_at TIMESTAMP,
                FOREIGN KEY (user_id) REFERENCES users(id)
            );
        """)
        self._conn.commit()

    def _seed_data(self):
        """Seed the database with randomized data based on the seed.

        With seed=None, uses a fixed default dataset (for manual testing).
        With a seed, generates random users/tasks so every episode differs.
        """
        rng = random.Random(self._seed)
        cursor = self._conn.cursor()

        if self._seed is None:
            # Default fixed data for manual testing / Gradio UI
            cursor.executescript("""
                INSERT INTO users (username, email, password_hash, role) VALUES
                ('alice', 'alice@example.com', 'hashed_password123', 'admin'),
                ('bob', 'bob@example.com', 'hashed_password123', 'user'),
                ('charlie', 'charlie@example.com', 'hashed_password123', 'user');

                INSERT INTO tasks (title, description, status, priority, assignee_email, owner_id) VALUES
                ('Setup CI/CD pipeline', 'Configure GitHub Actions', 'in_progress', 'high', 'alice@example.com', 1),
                ('Write unit tests', 'Add tests for auth module', 'pending', 'medium', 'bob@example.com', 2),
                ('Fix login page CSS', 'Button alignment issue', 'done', 'low', 'charlie@example.com', 3),
                ('Database migration', 'Migrate to PostgreSQL', 'pending', 'high', 'alice@example.com', 1),
                ('API documentation', 'Document all endpoints', 'in_progress', 'medium', 'bob@example.com', 2);
            """)
        else:
            # Randomized data β different every episode
            # Pick 3-5 users from the name pool
            num_users = rng.randint(3, 5)
            user_names = rng.sample(FIRST_NAMES, num_users)
            domain = rng.choice(DOMAINS)

            # First user is always admin, rest are regular users
            for i, name in enumerate(user_names):
                role = "admin" if i == 0 else "user"
                email = f"{name}@{domain}"
                cursor.execute(
                    "INSERT INTO users (username, email, password_hash, role) VALUES (?, ?, ?, ?)",
                    (name, email, f"hashed_password_{rng.randint(100, 999)}", role),
                )

            # Pick 4-8 tasks with random assignments
            num_tasks = rng.randint(4, 8)
            task_titles = rng.sample(TASK_TITLES, min(num_tasks, len(TASK_TITLES)))
            task_descs = rng.sample(TASK_DESCRIPTIONS, min(num_tasks, len(TASK_DESCRIPTIONS)))

            for i in range(num_tasks):
                # Owner and assignee are drawn independently, so they may differ
                # (which is what makes BOLA-style testing possible downstream).
                owner_id = rng.randint(1, num_users)
                assignee_id = rng.randint(1, num_users)
                assignee_email = f"{user_names[assignee_id - 1]}@{domain}"
                cursor.execute(
                    "INSERT INTO tasks (title, description, status, priority, assignee_email, owner_id) VALUES (?, ?, ?, ?, ?, ?)",
                    (
                        task_titles[i % len(task_titles)],
                        task_descs[i % len(task_descs)] if i < len(task_descs) else "",
                        rng.choice(STATUSES),
                        rng.choice(PRIORITIES),
                        assignee_email,
                        owner_id,
                    ),
                )

        self._conn.commit()

    @property
    def user_names(self) -> list[str]:
        """Get usernames in the database (for the agent's observation)."""
        rows = self.execute("SELECT username FROM users ORDER BY id")
        return [r["username"] for r in rows]

    @contextmanager
    def get_cursor(self):
        # Holds the lock for the whole transaction; commits on success,
        # rolls back and re-raises on any exception.
        with self._lock:
            cursor = self._conn.cursor()
            try:
                yield cursor
                self._conn.commit()
            except Exception:
                self._conn.rollback()
                raise

    def execute(self, query: str, params: tuple = ()) -> list[dict]:
        """Run a query and return rows as plain dicts ([] for non-SELECT)."""
        with self.get_cursor() as cursor:
            cursor.execute(query, params)
            if cursor.description:
                columns = [col[0] for col in cursor.description]
                return [dict(zip(columns, row)) for row in cursor.fetchall()]
            return []

    def execute_insert(self, query: str, params: tuple = ()) -> int:
        """Run an INSERT and return the new row's id (cursor.lastrowid)."""
        with self.get_cursor() as cursor:
            cursor.execute(query, params)
            return cursor.lastrowid

    def execute_update(self, query: str, params: tuple = ()) -> int:
        """Run an UPDATE/DELETE and return the number of affected rows."""
        with self.get_cursor() as cursor:
            cursor.execute(query, params)
            return cursor.rowcount
|
server/buggy_api/main.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The deliberately buggy REST API β a task management system.
|
| 3 |
+
|
| 4 |
+
This API is the system-under-test. It has intentionally planted bugs at varying
|
| 5 |
+
difficulty levels that the AI agent must discover through intelligent testing.
|
| 6 |
+
|
| 7 |
+
The API runs in-process via Starlette's TestClient (no separate port needed).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
from fastapi import FastAPI, Request, Header
|
| 14 |
+
from fastapi.responses import JSONResponse
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
from .database import Database
|
| 18 |
+
from .routes import tasks as tasks_routes
|
| 19 |
+
from .routes import users as users_routes
|
| 20 |
+
from .routes import auth as auth_routes
|
| 21 |
+
from .models import TaskCreate
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def create_buggy_api(db: Database) -> FastAPI:
    """Create a fresh buggy API instance wired to the given database.

    The returned app is the system-under-test: the planted bugs below are
    intentional and must not be "fixed" here.
    """
    api = FastAPI(
        title="TaskTracker API",
        description="A task management API (with bugs)",
        version="1.0.0",
    )

    # Wire database into route modules
    tasks_routes.set_db(db)
    users_routes.set_db(db)
    auth_routes.set_db(db)

    # Include standard routes
    api.include_router(tasks_routes.router)
    api.include_router(users_routes.router)
    api.include_router(auth_routes.router)

    # BUG_TASK_02 + BUG_TASK_08: Raw POST /tasks handler that doesn't use Pydantic validation
    # This allows missing fields and overly long inputs to cause 500 errors
    # NOTE(review): registered after include_router — assumes the tasks router
    # does not also declare POST /tasks, otherwise route precedence applies.
    @api.post("/tasks", status_code=201)
    async def create_task_raw(
        request: Request,
        authorization: Optional[str] = Header(None),
    ):
        try:
            body = await request.json()
        except Exception:
            # BUG_TASK_02: Returns 500 on malformed/empty body instead of 400
            raise Exception("Failed to parse request body")

        if not isinstance(body, dict):
            raise Exception("Invalid body format")

        title = body.get("title")

        # BUG_TASK_02: No check for missing title β causes KeyError/500 below
        if title is None:
            # This SHOULD return 400, but we let it fall through to cause 500
            # Simulate an internal error from missing required field
            raise Exception("Internal error: title is required but was None")

        # BUG_TASK_08: No length validation on title
        if len(title) > 5000:
            # Simulate a database error from overly long input
            raise Exception(f"Database error: value too long for column 'title' (length={len(title)})")

        task_data = TaskCreate(
            title=title,
            description=body.get("description", ""),
            status=body.get("status", "pending"),
            priority=body.get("priority", "medium"),
            assignee_email=body.get("assignee_email", ""),
        )
        return tasks_routes.create_task_internal(task_data, authorization)

    # Global error handler β returns 500 for unhandled exceptions
    # (this is what converts the deliberate raises above into 500 responses)
    @api.exception_handler(Exception)
    async def global_exception_handler(request: Request, exc: Exception):
        logger.error(f"Unhandled error: {exc}")
        return JSONResponse(
            status_code=500,
            content={"error": "Internal Server Error", "detail": str(exc)},
        )

    return api
|
server/buggy_api/models.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic models for the buggy API request/response schemas."""
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class UserCreate(BaseModel):
    """Request body for POST /users."""

    username: str
    email: str  # plain str — no format validation at the model level
    password: str
    role: str = "user"  # defaults to a regular (non-admin) account
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class UserResponse(BaseModel):
    """User as returned by the API (no password fields by design)."""

    id: int
    username: str
    email: str
    role: str
    created_at: str
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TaskCreate(BaseModel):
    """Request body for creating a task; only title is required."""

    title: str
    description: str = ""
    status: str = "pending"
    priority: str = "medium"
    assignee_email: str = ""  # plain str — no email-format validation here
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class TaskUpdate(BaseModel):
    """Partial-update body for PUT /tasks/{id}; None means 'leave unchanged'."""

    title: Optional[str] = None
    description: Optional[str] = None
    status: Optional[str] = None
    priority: Optional[str] = None
    assignee_email: Optional[str] = None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TaskResponse(BaseModel):
    """Task as returned by the API, including ownership and timestamps."""

    id: int
    title: str
    description: str
    status: str
    priority: str
    assignee_email: str
    owner_id: int
    created_at: str
    updated_at: str
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class LoginRequest(BaseModel):
    """Request body for POST /auth/login."""

    username: str
    password: str  # note: may be empty — the login route does not reject it
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class LoginResponse(BaseModel):
    """Successful login payload: bearer token plus the user's identity."""

    token: str
    user_id: int
    username: str
    role: str
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ErrorResponse(BaseModel):
    """Generic error envelope used by the API's error responses."""

    error: str
    detail: str = ""
|
server/buggy_api/routes/__init__.py
ADDED
|
File without changes
|
server/buggy_api/routes/auth.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Authentication routes with planted bugs.
|
| 3 |
+
|
| 4 |
+
BUGS PLANTED:
|
| 5 |
+
- BUG_AUTH_01 (hard): Auth tokens are not user-scoped β any valid token works for any user's resources
|
| 6 |
+
- BUG_AUTH_02 (medium): Login with empty password succeeds (missing validation)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import uuid
|
| 10 |
+
from datetime import datetime, timedelta
|
| 11 |
+
|
| 12 |
+
from fastapi import APIRouter, Depends, Header, HTTPException
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
from ..database import Database
|
| 16 |
+
from ..models import LoginRequest, LoginResponse
|
| 17 |
+
|
| 18 |
+
router = APIRouter(prefix="/auth", tags=["auth"])
|
| 19 |
+
|
| 20 |
+
_db: Database | None = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def set_db(db: Database):
    """Install the module-level Database instance these routes operate on."""
    global _db
    _db = db
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_db() -> Database:
    """Return the Database previously installed via set_db()."""
    return _db
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_current_user(authorization: Optional[str] = Header(None)) -> dict | None:
    """Extract user from auth token.

    BUG_AUTH_01: Returns the token's user but doesn't enforce ownership anywhere.
    The routes that use this don't check if the resource belongs to the user.

    Returns the row {id, username, role} for a valid token, else None.
    """
    if not authorization:
        return None
    # NOTE(review): str.replace removes every "Bearer " occurrence, not just
    # a leading prefix — harmless for well-formed headers, but worth noting.
    token = authorization.replace("Bearer ", "")
    db = get_db()
    # NOTE(review): expires_at is not checked here, so tokens never expire in
    # practice — confirm whether that is part of the intended bug surface.
    rows = db.execute(
        "SELECT u.id, u.username, u.role FROM auth_tokens t JOIN users u ON t.user_id = u.id WHERE t.token = ?",
        (token,),
    )
    if not rows:
        return None
    return rows[0]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@router.post("/login", response_model=LoginResponse)
def login(req: LoginRequest):
    """Authenticate a user and mint a 24-hour bearer token.

    BUG_AUTH_02 (intentional): the password is never verified — any password,
    including the empty string, succeeds for an existing username.
    """
    db = get_db()

    # BUG_AUTH_02: Empty password check is missing β empty password matches hash
    # Should validate: if not req.password: raise HTTPException(400, ...)
    rows = db.execute(
        "SELECT id, username, role, password_hash FROM users WHERE username = ?",
        (req.username,),
    )
    if not rows:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    user = rows[0]
    # BUG_AUTH_02 continued: Only checks username, not password properly
    # In a real system we'd verify the password hash
    # Here we just check if password is non-empty... but we don't!
    # Any password (including empty string) works as long as username exists.

    token = str(uuid.uuid4())
    # NOTE(review): datetime.utcnow() is naive (and deprecated in 3.12+);
    # expires_at is stored but not enforced by get_current_user above.
    expires = datetime.utcnow() + timedelta(hours=24)
    db.execute_insert(
        "INSERT INTO auth_tokens (token, user_id, expires_at) VALUES (?, ?, ?)",
        (token, user["id"], expires.isoformat()),
    )

    return LoginResponse(
        token=token,
        user_id=user["id"],
        username=user["username"],
        role=user["role"],
    )
|
server/buggy_api/routes/tasks.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Task CRUD routes with planted bugs.
|
| 3 |
+
|
| 4 |
+
BUGS PLANTED:
|
| 5 |
+
- BUG_TASK_01 (easy): GET /tasks/{id} returns 200 with null body for non-existent task (should be 404)
|
| 6 |
+
- BUG_TASK_02 (easy): POST /tasks with missing required 'title' returns 500 instead of 400/422
|
| 7 |
+
- BUG_TASK_03 (easy): GET /tasks?page=-1 returns 200 instead of 400
|
| 8 |
+
- BUG_TASK_04 (medium): PUT /tasks/{id} doesn't validate assignee_email format
|
| 9 |
+
- BUG_TASK_05 (medium): DELETE /tasks/{id} returns 200 even for non-existent task (should be 404)
|
| 10 |
+
- BUG_TASK_06 (medium): GET /tasks?limit=999999 has no pagination cap (potential DoS)
|
| 11 |
+
- BUG_TASK_07 (hard): GET /tasks/{id} of another user's task returns data (BOLA/IDOR vulnerability)
|
| 12 |
+
- BUG_TASK_08 (hard): POST /tasks with very long title (>5000 chars) causes 500 (no input length validation)
|
| 13 |
+
- BUG_TASK_09 (hard): POST /tasks with SQL injection payload in title doesn't sanitize (uses parameterized
|
| 14 |
+
queries so no actual injection, but the input is stored verbatim β a content injection)
|
| 15 |
+
- BUG_TASK_10 (hard): No rate limiting β rapid sequential requests all succeed
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from fastapi import APIRouter, HTTPException, Header, Query
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
from ..database import Database
|
| 22 |
+
from ..models import TaskCreate, TaskUpdate
|
| 23 |
+
|
| 24 |
+
router = APIRouter(prefix="/tasks", tags=["tasks"])
|
| 25 |
+
|
| 26 |
+
_db: Database | None = None
|
| 27 |
+
|
| 28 |
+
# Simple in-memory cache for BUG demonstration
|
| 29 |
+
_cache: dict[int, dict] = {}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def set_db(db: Database):
    """Install the shared Database instance and reset the response cache.

    Called once per episode so stale cached task rows from a previous
    database do not leak into the new one.
    """
    global _db, _cache
    _cache = {}
    _db = db
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_db() -> Database:
    """Return the module-level Database handle installed by ``set_db``."""
    handle = _db
    return handle
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@router.get("")
def list_tasks(
    status: Optional[str] = Query(None, description="Filter by status"),
    priority: Optional[str] = Query(None, description="Filter by priority"),
    sort: Optional[str] = Query(None, description="Sort field"),
    page: Optional[int] = Query(None, description="Page number"),
    limit: Optional[int] = Query(None, description="Items per page"),
    authorization: Optional[str] = Header(None),
):
    """List tasks with optional filtering, sorting, and pagination.

    Planted bugs (intentional — do not fix):
      BUG_TASK_03: negative ``page`` values are accepted; the computed
        OFFSET goes negative instead of returning 400.
      BUG_TASK_06: ``limit`` has no upper cap, so e.g. limit=999999 is
        honored (potential DoS surface).
    """
    db = get_db()

    sql = "SELECT * FROM tasks WHERE 1=1"
    bind_values = []

    # Equality filters — parameterized, so no injection risk here.
    if status:
        sql += " AND status = ?"
        bind_values.append(status)
    if priority:
        sql += " AND priority = ?"
        bind_values.append(priority)

    # Sorting: only whitelisted column names are ever interpolated into SQL.
    if sort:
        allowed_sorts = ["created_at", "updated_at", "title", "priority", "status"]
        sql += f" ORDER BY {sort}" if sort in allowed_sorts else " ORDER BY created_at"
    else:
        sql += " ORDER BY created_at DESC"

    # BUG_TASK_06: the client-supplied limit is passed through unbounded.
    if limit is not None:
        sql += " LIMIT ?"
        bind_values.append(limit)
    else:
        sql += " LIMIT 20"

    # BUG_TASK_03: (page - 1) * limit can be negative; no validation.
    if page is not None and limit is not None:
        sql += " OFFSET ?"
        bind_values.append((page - 1) * limit)

    return db.execute(sql, tuple(bind_values))
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@router.get("/{task_id}")
def get_task(
    task_id: int,
    authorization: Optional[str] = Header(None),
):
    """Fetch a single task by id, serving from the in-memory cache when possible.

    Planted bugs (intentional — do not fix):
      BUG_TASK_01: a missing task yields 200 with a null body instead of 404.
      BUG_TASK_07: no ownership check — any caller can read any task (BOLA/IDOR).
      The cache is never invalidated on delete, enabling stale-cache detection.
    """
    db = get_db()

    # Serve a cached copy first if we have one.
    if task_id in _cache:
        return _cache[task_id]

    rows = db.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
    if not rows:
        # BUG_TASK_01: should raise HTTPException(404, "Task not found").
        return None

    record = rows[0]

    # BUG_TASK_07: owner_id is never compared against the caller's identity.
    _cache[task_id] = record
    return record
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@router.post("/create", status_code=201)
def create_task_internal(
    task: TaskCreate,
    authorization: Optional[str] = Header(None),
):
    """Internal create — used by the raw handler after parsing."""
    db = get_db()

    # Planted bugs (intentional — do not fix):
    #   BUG_TASK_08: title length is never validated (>5000 chars accepted).
    #   BUG_TASK_09: content is stored verbatim — SQL-looking payloads are
    #     harmless to the DB (parameterized queries) but become a content
    #     injection / XSS vector on read-back.

    # Resolve the owner from the bearer token when present; default to user 1.
    owner_id = 1
    if authorization:
        bearer = authorization.replace("Bearer ", "")
        token_rows = db.execute(
            "SELECT user_id FROM auth_tokens WHERE token = ?", (bearer,)
        )
        if token_rows:
            owner_id = token_rows[0]["user_id"]

    new_id = db.execute_insert(
        "INSERT INTO tasks (title, description, status, priority, assignee_email, owner_id) VALUES (?, ?, ?, ?, ?, ?)",
        (task.title, task.description, task.status, task.priority, task.assignee_email, owner_id),
    )

    created = db.execute("SELECT * FROM tasks WHERE id = ?", (new_id,))[0]
    _cache[new_id] = created
    return created
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@router.put("/{task_id}")
def update_task(
    task_id: int,
    task: TaskUpdate,
    authorization: Optional[str] = Header(None),
):
    """Partially update a task; only fields present in the payload are written.

    Planted bugs (intentional — do not fix):
      BUG_TASK_04: assignee_email format is never validated.
      BUG_TASK_07: no ownership check — any caller can update any task.
    """
    db = get_db()

    if not db.execute("SELECT * FROM tasks WHERE id = ?", (task_id,)):
        raise HTTPException(status_code=404, detail="Task not found")

    # Collect SET clauses for every field the client actually supplied.
    # Column names come from a fixed whitelist, so the f-string is safe.
    set_clauses = []
    bind_values = []
    for column in ["title", "description", "status", "priority", "assignee_email"]:
        supplied = getattr(task, column, None)
        if supplied is not None:
            set_clauses.append(f"{column} = ?")
            bind_values.append(supplied)

    if set_clauses:
        set_clauses.append("updated_at = CURRENT_TIMESTAMP")
        bind_values.append(task_id)
        db.execute_update(
            f"UPDATE tasks SET {', '.join(set_clauses)} WHERE id = ?",
            tuple(bind_values),
        )

    refreshed = db.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))[0]
    _cache[task_id] = refreshed
    return refreshed
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@router.delete("/{task_id}")
def delete_task(
    task_id: int,
    authorization: Optional[str] = Header(None),
):
    """Delete a task by id.

    Planted bugs (intentional — do not fix):
      BUG_TASK_05: no existence check, so deleting a missing id still
        returns 200 instead of 404.
      The cache is deliberately NOT invalidated here, so a deleted task can
      still be served from ``get_task`` — the stale-cache bug variant.
    """
    db = get_db()
    db.execute_update("DELETE FROM tasks WHERE id = ?", (task_id,))
    return {"message": "Task deleted", "id": task_id}
|
server/buggy_api/routes/users.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
User management routes with planted bugs.
|
| 3 |
+
|
| 4 |
+
BUGS PLANTED:
|
| 5 |
+
- BUG_USER_01 (medium): POST /users doesn't validate email format
|
| 6 |
+
- BUG_USER_02 (medium): GET /users exposes password hashes in response
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from fastapi import APIRouter, HTTPException
|
| 10 |
+
|
| 11 |
+
from ..database import Database
|
| 12 |
+
from ..models import UserCreate
|
| 13 |
+
|
| 14 |
+
router = APIRouter(prefix="/users", tags=["users"])
|
| 15 |
+
|
| 16 |
+
_db: Database | None = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def set_db(db: Database):
    """Install the shared Database instance used by all user routes."""
    global _db
    _db = db
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_db() -> Database:
    """Return the module-level Database handle installed by ``set_db``."""
    handle = _db
    return handle
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@router.get("")
def list_users():
    """List all users. Password hashes are excluded from this listing."""
    db = get_db()
    return db.execute("SELECT id, username, email, role, created_at FROM users")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@router.get("/{user_id}")
def get_user(user_id: int):
    """Fetch a single user by id; 404 when no such user exists."""
    db = get_db()
    found = db.execute("SELECT id, username, email, role, created_at FROM users WHERE id = ?", (user_id,))
    if not found:
        raise HTTPException(status_code=404, detail="User not found")
    return found[0]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@router.post("", status_code=201)
def create_user(user: UserCreate):
    """Create a user.

    Planted bugs (intentional — do not fix):
      BUG_USER_01: email format is never validated ("not-an-email" accepted).
      BUG_USER_02: the response row includes password_hash.
    """
    db = get_db()

    # Username uniqueness is the only validation performed.
    if db.execute("SELECT id FROM users WHERE username = ?", (user.username,)):
        raise HTTPException(status_code=409, detail="Username already exists")

    new_id = db.execute_insert(
        "INSERT INTO users (username, email, password_hash, role) VALUES (?, ?, ?, ?)",
        (user.username, user.email, f"hashed_{user.password}", user.role),
    )

    # BUG_USER_02: SELECT * leaks password_hash in the 201 response body.
    return db.execute("SELECT * FROM users WHERE id = ?", (new_id,))[0]
|
server/environment.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenEnv Environment for API Integration Testing.
|
| 3 |
+
|
| 4 |
+
The agent interacts with a deliberately buggy REST API, discovering endpoints,
|
| 5 |
+
crafting requests, and finding bugs. Rewards are multi-signal: coverage,
|
| 6 |
+
validity, bug discovery, and exploration.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import random
|
| 11 |
+
import time
|
| 12 |
+
import json
|
| 13 |
+
from typing import Any, Optional
|
| 14 |
+
|
| 15 |
+
from fastapi.testclient import TestClient
|
| 16 |
+
from openenv.core.env_server.interfaces import Environment
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from ..models import APITestAction, APITestObservation, APITestState
|
| 20 |
+
except ImportError:
|
| 21 |
+
from models import APITestAction, APITestObservation, APITestState
|
| 22 |
+
|
| 23 |
+
from .buggy_api.database import Database
|
| 24 |
+
from .buggy_api.main import create_buggy_api
|
| 25 |
+
from .bug_detector import BugDetector
|
| 26 |
+
from .reward import RewardComputer
|
| 27 |
+
from .graders import TaskGrader, generate_bug_report
|
| 28 |
+
from .graders import TaskGrader
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
# Task definitions
|
| 33 |
+
# Task definitions, keyed by task_id. Each entry carries:
#   description - the prompt shown to the agent (augmented at reset() with
#                 the actual seeded usernames),
#   difficulty  - label only; does not affect mechanics,
#   max_steps   - per-episode step budget,
#   total_bugs  - number of planted bugs reachable within this task.
TASKS = {
    "basic_validation": {
        "id": "basic_validation",
        "description": (
            "Test all CRUD endpoints with valid inputs and verify correct status codes. "
            "Find basic bugs like wrong status codes and missing field handling. "
            "Available endpoints: GET /tasks, POST /tasks, GET /tasks/{id}, PUT /tasks/{id}, "
            "DELETE /tasks/{id}, GET /users, POST /users, POST /auth/login. "
            "Try different methods on each endpoint and verify responses match the expected behavior."
        ),
        "difficulty": "easy",
        "max_steps": 25,
        "total_bugs": 3,
    },
    "edge_cases": {
        "id": "edge_cases",
        "description": (
            "Test boundary conditions, invalid inputs, and error responses. "
            "Send missing fields, wrong types, negative page numbers, huge limits. "
            "Test with non-existent resource IDs (e.g., /tasks/999999). "
            "Chain operations: create a resource, then read/update/delete it. "
            "Find bugs in input validation, pagination, and error handling."
        ),
        "difficulty": "medium",
        "max_steps": 35,
        "total_bugs": 9,
    },
    "security_workflows": {
        "id": "security_workflows",
        "description": (
            "Discover authorization flaws, injection vulnerabilities, and workflow bugs. "
            "Login as different users (alice/password, bob/password, charlie/password) and "
            "try accessing each other's resources. Test SQL injection patterns in input fields. "
            "Execute multi-step workflows: create -> modify -> verify -> delete -> re-fetch. "
            "Check if auth tokens properly scope access. Test with very long inputs."
        ),
        "difficulty": "hard",
        "max_steps": 45,
        "total_bugs": 13,
    },
}
|
| 74 |
+
|
| 75 |
+
# OpenAPI-like spec for the agent
|
| 76 |
+
# Minimal OpenAPI-like spec handed to the agent in every observation
# (available_endpoints). Note this is the *advertised* contract — the buggy
# API deliberately deviates from it; finding those deviations is the task.
API_SPEC = [
    {
        "method": "GET",
        "path": "/tasks",
        "summary": "List all tasks. Supports filtering by status, priority; pagination with page & limit; sorting with sort.",
        "parameters": [
            {"name": "status", "in": "query", "type": "string", "enum": ["pending", "in_progress", "done"]},
            {"name": "priority", "in": "query", "type": "string", "enum": ["low", "medium", "high"]},
            {"name": "sort", "in": "query", "type": "string", "enum": ["created_at", "updated_at", "title"]},
            {"name": "page", "in": "query", "type": "integer"},
            {"name": "limit", "in": "query", "type": "integer"},
        ],
    },
    {
        "method": "POST",
        "path": "/tasks",
        "summary": "Create a new task. Requires 'title' field. Optional: description, status, priority, assignee_email.",
        "request_body": {
            "required": ["title"],
            "properties": {
                "title": {"type": "string"},
                "description": {"type": "string"},
                "status": {"type": "string", "enum": ["pending", "in_progress", "done"]},
                "priority": {"type": "string", "enum": ["low", "medium", "high"]},
                "assignee_email": {"type": "string", "format": "email"},
            },
        },
    },
    {
        "method": "GET",
        "path": "/tasks/{id}",
        "summary": "Get a specific task by ID.",
        "parameters": [{"name": "id", "in": "path", "type": "integer", "required": True}],
    },
    {
        "method": "PUT",
        "path": "/tasks/{id}",
        "summary": "Update a task. All fields optional.",
        "parameters": [{"name": "id", "in": "path", "type": "integer", "required": True}],
        "request_body": {
            "properties": {
                "title": {"type": "string"},
                "description": {"type": "string"},
                "status": {"type": "string"},
                "priority": {"type": "string"},
                "assignee_email": {"type": "string", "format": "email"},
            },
        },
    },
    {
        "method": "DELETE",
        "path": "/tasks/{id}",
        "summary": "Delete a task by ID.",
        "parameters": [{"name": "id", "in": "path", "type": "integer", "required": True}],
    },
    {
        "method": "GET",
        "path": "/users",
        "summary": "List all users.",
    },
    {
        "method": "POST",
        "path": "/users",
        "summary": "Create a new user. Requires username, email, password.",
        "request_body": {
            "required": ["username", "email", "password"],
            "properties": {
                "username": {"type": "string"},
                "email": {"type": "string", "format": "email"},
                "password": {"type": "string"},
                "role": {"type": "string", "enum": ["user", "admin"]},
            },
        },
    },
    {
        "method": "GET",
        "path": "/users/{id}",
        "summary": "Get a specific user by ID.",
        "parameters": [{"name": "id", "in": "path", "type": "integer", "required": True}],
    },
    {
        "method": "POST",
        "path": "/auth/login",
        "summary": "Login and receive an auth token. Pre-seeded users: alice, bob, charlie (password: any string).",
        "request_body": {
            "required": ["username", "password"],
            "properties": {
                "username": {"type": "string"},
                "password": {"type": "string"},
            },
        },
    },
]
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class APITestEnvironment(Environment):
|
| 172 |
+
"""OpenEnv environment for API integration testing.
|
| 173 |
+
|
| 174 |
+
The agent tests a deliberately buggy REST API by sending HTTP requests
|
| 175 |
+
and analyzing responses. It earns rewards for coverage, finding bugs,
|
| 176 |
+
and exploring edge cases.
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
SUPPORTS_CONCURRENT_SESSIONS = False
|
| 180 |
+
|
| 181 |
+
def __init__(self, **kwargs):
|
| 182 |
+
super().__init__(**kwargs)
|
| 183 |
+
self._db: Optional[Database] = None
|
| 184 |
+
self._api: Optional[TestClient] = None
|
| 185 |
+
self._bug_detector: Optional[BugDetector] = None
|
| 186 |
+
self._reward_computer: Optional[RewardComputer] = None
|
| 187 |
+
self._task: Optional[dict] = None
|
| 188 |
+
self._found_bugs: set[str] = set()
|
| 189 |
+
self._steps_taken: int = 0
|
| 190 |
+
self._cumulative_reward: float = 0.0
|
| 191 |
+
self._action_history: list[dict] = []
|
| 192 |
+
self._auth_tokens: dict[str, str] = {}
|
| 193 |
+
self._episode_id: str = ""
|
| 194 |
+
|
| 195 |
+
def reset(self, seed=None, episode_id=None, **kwargs) -> APITestObservation:
|
| 196 |
+
"""Reset the environment for a new episode.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
seed: Random seed for domain randomization. When provided, the
|
| 200 |
+
database is populated with different users, tasks, and data
|
| 201 |
+
so each training episode is unique. None = fixed default data.
|
| 202 |
+
episode_id: Optional episode identifier for tracking.
|
| 203 |
+
|
| 204 |
+
kwargs:
|
| 205 |
+
task_id: str - one of "basic_validation", "edge_cases", "security_workflows"
|
| 206 |
+
"""
|
| 207 |
+
task_id = kwargs.get("task_id", "basic_validation")
|
| 208 |
+
if task_id not in TASKS:
|
| 209 |
+
task_id = "basic_validation"
|
| 210 |
+
|
| 211 |
+
self._task = TASKS[task_id]
|
| 212 |
+
self._seed = seed
|
| 213 |
+
self._episode_id = episode_id or f"ep_{int(time.time())}"
|
| 214 |
+
|
| 215 |
+
# Reset database with seed for domain randomization
|
| 216 |
+
# seed=None β fixed data (manual testing / Gradio)
|
| 217 |
+
# seed=int β randomized data (GRPO training)
|
| 218 |
+
self._db = Database(seed=seed)
|
| 219 |
+
buggy_app = create_buggy_api(self._db)
|
| 220 |
+
self._api = TestClient(buggy_app, raise_server_exceptions=False)
|
| 221 |
+
|
| 222 |
+
# Build dynamic task description that includes actual usernames
|
| 223 |
+
user_names = self._db.user_names
|
| 224 |
+
user_list = ", ".join(user_names)
|
| 225 |
+
dynamic_description = (
|
| 226 |
+
f"{self._task['description']} "
|
| 227 |
+
f"Users in the system: {user_list} (use any password to login)."
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Reset tracking
|
| 231 |
+
self._bug_detector = BugDetector(task_id)
|
| 232 |
+
self._reward_computer = RewardComputer()
|
| 233 |
+
self._found_bugs = set()
|
| 234 |
+
self._steps_taken = 0
|
| 235 |
+
self._cumulative_reward = 0.0
|
| 236 |
+
self._action_history = []
|
| 237 |
+
self._auth_tokens = {}
|
| 238 |
+
|
| 239 |
+
logger.info(f"Reset environment: task={task_id}, seed={seed}, episode={self._episode_id}")
|
| 240 |
+
|
| 241 |
+
return APITestObservation(
|
| 242 |
+
available_endpoints=API_SPEC,
|
| 243 |
+
status_code=0,
|
| 244 |
+
response_body=None,
|
| 245 |
+
response_headers={},
|
| 246 |
+
response_time_ms=0,
|
| 247 |
+
feedback=(
|
| 248 |
+
f"Environment reset. Task: {dynamic_description} "
|
| 249 |
+
f"You have {self._task['max_steps']} steps. Start testing the API!"
|
| 250 |
+
),
|
| 251 |
+
bugs_found_so_far=0,
|
| 252 |
+
coverage_summary=self._reward_computer.coverage.summary(),
|
| 253 |
+
known_resource_ids=self._reward_computer.created_ids,
|
| 254 |
+
auth_tokens=self._auth_tokens,
|
| 255 |
+
task_id=task_id,
|
| 256 |
+
task_description=dynamic_description,
|
| 257 |
+
steps_taken=0,
|
| 258 |
+
max_steps=self._task["max_steps"],
|
| 259 |
+
done=False,
|
| 260 |
+
reward=0.0,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
def step(self, action: APITestAction, timeout_s=None, **kwargs) -> APITestObservation:
|
| 264 |
+
"""Execute an API test action and return observation + reward."""
|
| 265 |
+
self._steps_taken += 1
|
| 266 |
+
|
| 267 |
+
# Forward request to buggy API
|
| 268 |
+
method = action.method.value if hasattr(action.method, "value") else str(action.method)
|
| 269 |
+
endpoint = action.endpoint
|
| 270 |
+
headers = dict(action.headers) if action.headers else {}
|
| 271 |
+
query_params = dict(action.query_params) if action.query_params else {}
|
| 272 |
+
body = action.body
|
| 273 |
+
|
| 274 |
+
# Make the request
|
| 275 |
+
start_time = time.time()
|
| 276 |
+
try:
|
| 277 |
+
response = self._api.request(
|
| 278 |
+
method=method.upper(),
|
| 279 |
+
url=endpoint,
|
| 280 |
+
headers=headers,
|
| 281 |
+
params=query_params if query_params else None,
|
| 282 |
+
json=body,
|
| 283 |
+
)
|
| 284 |
+
elapsed_ms = (time.time() - start_time) * 1000
|
| 285 |
+
|
| 286 |
+
response_status = response.status_code
|
| 287 |
+
try:
|
| 288 |
+
response_body = response.json()
|
| 289 |
+
except Exception:
|
| 290 |
+
response_body = response.text
|
| 291 |
+
response_headers = dict(response.headers)
|
| 292 |
+
except Exception as e:
|
| 293 |
+
elapsed_ms = (time.time() - start_time) * 1000
|
| 294 |
+
response_status = 0
|
| 295 |
+
response_body = {"error": str(e)}
|
| 296 |
+
response_headers = {}
|
| 297 |
+
|
| 298 |
+
# Track auth tokens from login responses
|
| 299 |
+
if (
|
| 300 |
+
endpoint == "/auth/login"
|
| 301 |
+
and response_status == 200
|
| 302 |
+
and isinstance(response_body, dict)
|
| 303 |
+
and "token" in response_body
|
| 304 |
+
):
|
| 305 |
+
username = body.get("username", "unknown") if body else "unknown"
|
| 306 |
+
self._auth_tokens[username] = response_body["token"]
|
| 307 |
+
|
| 308 |
+
# Check for bug detection
|
| 309 |
+
detection = self._bug_detector.check(
|
| 310 |
+
method=method,
|
| 311 |
+
endpoint=endpoint,
|
| 312 |
+
headers=headers,
|
| 313 |
+
query_params=query_params,
|
| 314 |
+
body=body,
|
| 315 |
+
expected_status=action.expected_status,
|
| 316 |
+
response_status=response_status,
|
| 317 |
+
response_body=response_body,
|
| 318 |
+
action_history=self._action_history,
|
| 319 |
+
found_bugs=self._found_bugs,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
bug_severity = None
|
| 323 |
+
bug_id = None
|
| 324 |
+
if detection:
|
| 325 |
+
bug_severity = detection.bug.severity
|
| 326 |
+
bug_id = detection.bug.id
|
| 327 |
+
self._found_bugs.add(bug_id)
|
| 328 |
+
|
| 329 |
+
# Compute reward
|
| 330 |
+
reward_breakdown = self._reward_computer.compute(
|
| 331 |
+
method=method,
|
| 332 |
+
endpoint=endpoint,
|
| 333 |
+
headers=headers,
|
| 334 |
+
query_params=query_params,
|
| 335 |
+
body=body,
|
| 336 |
+
expected_status=action.expected_status,
|
| 337 |
+
response_status=response_status,
|
| 338 |
+
response_body=response_body,
|
| 339 |
+
bug_found=bug_severity,
|
| 340 |
+
bug_id=bug_id,
|
| 341 |
+
)
|
| 342 |
+
self._cumulative_reward += reward_breakdown.total
|
| 343 |
+
|
| 344 |
+
# Record action in history
|
| 345 |
+
self._action_history.append({
|
| 346 |
+
"method": method,
|
| 347 |
+
"endpoint": endpoint,
|
| 348 |
+
"headers": headers,
|
| 349 |
+
"query_params": query_params,
|
| 350 |
+
"body": body,
|
| 351 |
+
"response_status": response_status,
|
| 352 |
+
"response_body": response_body,
|
| 353 |
+
})
|
| 354 |
+
|
| 355 |
+
# Generate feedback
|
| 356 |
+
feedback_parts = [f"{method} {endpoint} -> {response_status}"]
|
| 357 |
+
if detection:
|
| 358 |
+
feedback_parts.append(f"BUG FOUND ({detection.bug.severity})! {detection.evidence}")
|
| 359 |
+
if reward_breakdown.coverage > 0:
|
| 360 |
+
feedback_parts.append(f"Coverage +{reward_breakdown.coverage:.2f}")
|
| 361 |
+
if reward_breakdown.penalty < 0:
|
| 362 |
+
feedback_parts.append("Repeated request penalty")
|
| 363 |
+
|
| 364 |
+
done = self._steps_taken >= self._task["max_steps"]
|
| 365 |
+
|
| 366 |
+
# Compute final grade if done
|
| 367 |
+
if done:
|
| 368 |
+
grade = TaskGrader.grade(
|
| 369 |
+
task_id=self._task["id"],
|
| 370 |
+
bugs_found=self._found_bugs,
|
| 371 |
+
coverage_pct=self._reward_computer.coverage.summary()["coverage_pct"],
|
| 372 |
+
endpoints_tested=len(self._reward_computer.coverage.endpoints_hit),
|
| 373 |
+
total_endpoints=self._reward_computer.coverage.total_endpoints,
|
| 374 |
+
method_endpoint_pairs=len(self._reward_computer.coverage.method_endpoint_pairs),
|
| 375 |
+
status_codes_seen=self._reward_computer.coverage.status_codes_seen,
|
| 376 |
+
action_history=self._action_history,
|
| 377 |
+
created_resources=self._reward_computer.created_ids,
|
| 378 |
+
)
|
| 379 |
+
# Generate bug bounty report
|
| 380 |
+
report = generate_bug_report(list(self._found_bugs), self._action_history)
|
| 381 |
+
|
| 382 |
+
feedback_parts.append(
|
| 383 |
+
f"\n=== EPISODE COMPLETE ===\n"
|
| 384 |
+
f"Final Score: {grade.score:.4f}\n"
|
| 385 |
+
f"Bugs Found: {len(self._found_bugs)}/{self._task['total_bugs']}\n"
|
| 386 |
+
f"Grade Breakdown: {json.dumps(grade.breakdown, indent=2)}\n"
|
| 387 |
+
f"Feedback: {grade.feedback}\n\n"
|
| 388 |
+
f"{report}"
|
| 389 |
+
)
|
| 390 |
+
# Add grade as bonus on top of step reward (not replacement)
|
| 391 |
+
final_reward = reward_breakdown.total + grade.score
|
| 392 |
+
else:
|
| 393 |
+
final_reward = reward_breakdown.total
|
| 394 |
+
|
| 395 |
+
return APITestObservation(
|
| 396 |
+
available_endpoints=API_SPEC,
|
| 397 |
+
status_code=response_status,
|
| 398 |
+
response_body=response_body,
|
| 399 |
+
response_headers={k: v for k, v in list(response_headers.items())[:20]},
|
| 400 |
+
response_time_ms=round(elapsed_ms, 2),
|
| 401 |
+
feedback=" | ".join(feedback_parts),
|
| 402 |
+
bugs_found_so_far=len(self._found_bugs),
|
| 403 |
+
coverage_summary=self._reward_computer.coverage.summary(),
|
| 404 |
+
known_resource_ids=self._reward_computer.created_ids,
|
| 405 |
+
auth_tokens=self._auth_tokens,
|
| 406 |
+
task_id=self._task["id"],
|
| 407 |
+
task_description=self._task["description"],
|
| 408 |
+
steps_taken=self._steps_taken,
|
| 409 |
+
max_steps=self._task["max_steps"],
|
| 410 |
+
done=done,
|
| 411 |
+
reward=final_reward,
|
| 412 |
+
metadata={"reward_breakdown": reward_breakdown.as_dict()},
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
@property
def state(self) -> APITestState:
    """Snapshot the current episode as an APITestState.

    Returns a default (empty) state when no task has been loaded yet;
    otherwise mirrors task metadata, step counters, bug tallies, and the
    coverage summary from the reward computer.
    """
    if not self._task:
        return APITestState()

    task = self._task
    cov = self._reward_computer.coverage.summary() if self._reward_computer else {}
    found = self._found_bugs
    return APITestState(
        episode_id=self._episode_id,
        step_count=self._steps_taken,
        task_id=task["id"],
        task_description=task["description"],
        difficulty=task["difficulty"],
        steps_taken=self._steps_taken,
        max_steps=task["max_steps"],
        bugs_found=len(found),
        total_bugs=task["total_bugs"],
        bugs_found_ids=list(found),
        coverage_pct=cov.get("coverage_pct", 0.0),
        endpoints_tested=cov.get("endpoints_tested", 0),
        total_endpoints=cov.get("total_endpoints", 0),
        current_score=0.0,
        cumulative_reward=round(self._cumulative_reward, 4),
    )
|
server/graders.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Task-specific grading logic and bug bounty report generation.
|
| 3 |
+
|
| 4 |
+
Each task has a grader that computes a final score (0.0 - 1.0)
|
| 5 |
+
based on what the agent accomplished during the episode.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class GradeResult:
    """Outcome of grading one episode for a single task."""

    # Final episode score in [0.0, 1.0].
    score: float
    # Per-criterion contribution to `score` (criterion name -> points).
    breakdown: dict[str, float]
    # One-line human-readable summary of what the agent accomplished.
    feedback: str
    report: str = ""  # Bug bounty report (markdown)
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def generate_bug_report(bugs_found_ids: list[str], action_history: list[dict]) -> str:
    """Generate a structured bug bounty report for discovered bugs.

    Args:
        bugs_found_ids: IDs of the bugs the agent discovered this episode.
        action_history: Per-step request records (dicts with at least
            "method" and "endpoint" keys) from the episode.

    Returns:
        A markdown report listing discovered vulnerabilities, ordered from
        most to least severe, or a "no vulnerabilities" report when the
        list is empty.
    """
    # Early-return BEFORE constructing the detector: no work (and no import)
    # is needed when nothing was found.
    if not bugs_found_ids:
        return "## API Security Assessment Report\n\nNo vulnerabilities discovered."

    # Imported lazily to avoid a circular import at module load time.
    from .bug_detector import BugDetector

    detector = BugDetector("security_workflows")

    def _severity(bug_id: str) -> str:
        """Severity of a known bug; unknown ids fall back to 'easy'."""
        bug = detector.bugs.get(bug_id)
        return bug.severity if bug else "easy"

    # Hard (most severe) first.
    severity_order = {"hard": 0, "medium": 1, "easy": 2}
    sorted_bugs = sorted(bugs_found_ids, key=lambda b: severity_order.get(_severity(b), 2))

    sections = ["## API Security Assessment Report", ""]
    sections.append(f"**Vulnerabilities Found:** {len(bugs_found_ids)}")

    # Count by severity for the summary line.
    counts = {"easy": 0, "medium": 0, "hard": 0}
    for bid in bugs_found_ids:
        bug = detector.bugs.get(bid)
        if bug:
            counts[bug.severity] = counts.get(bug.severity, 0) + 1
    sections.append(
        f"**Critical/Hard:** {counts['hard']} | **Medium:** {counts['medium']} | **Low/Easy:** {counts['easy']}"
    )
    sections.append("")

    for bid in sorted_bugs:
        bug = detector.bugs.get(bid)
        if not bug:
            continue

        sev_label = {"easy": "LOW", "medium": "MEDIUM", "hard": "HIGH"}.get(bug.severity, "INFO")
        owasp = bug.owasp if bug.owasp else "Uncategorized"

        sections.append(f"### {sev_label}: {bug.description}")
        sections.append(f"- **ID:** {bid}")
        sections.append(f"- **OWASP:** {owasp}")
        sections.append(f"- **Category:** {bug.category}")
        if bug.recommendation:
            # Only emit the line when a recommendation exists; the previous
            # version appended an empty string, leaving stray blank lines.
            sections.append(f"- **Recommendation:** {bug.recommendation}")

        # NOTE(review): history entries do not record which bug they
        # triggered, so this attributes the FIRST recorded request — not
        # necessarily the one that surfaced this bug. TODO: record bug ids
        # per step so the report can attribute correctly.
        for h in action_history:
            if h.get("method") and h.get("endpoint"):
                sections.append(f"- **Triggered by:** {h['method']} {h['endpoint']}")
                break
        sections.append("")

    return "\n".join(sections)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class TaskGrader:
    """Computes final scores for each task based on episode performance.

    Each ``_grade_*`` method returns a GradeResult whose breakdown weights
    sum to 1.0; :meth:`grade` dispatches on the task id.
    """

    @staticmethod
    def grade(
        task_id: str,
        bugs_found: set[str],
        coverage_pct: float,
        endpoints_tested: int,
        total_endpoints: int,
        method_endpoint_pairs: int,
        status_codes_seen: set[int],
        action_history: list[dict],
        created_resources: dict[str, list],
    ) -> GradeResult:
        """Dispatch to the task-specific grader; unknown tasks score 0.0."""
        if task_id == "basic_validation":
            return TaskGrader._grade_basic(
                bugs_found, coverage_pct, endpoints_tested, total_endpoints,
                method_endpoint_pairs, status_codes_seen, action_history, created_resources,
            )
        elif task_id == "edge_cases":
            return TaskGrader._grade_edge_cases(
                bugs_found, coverage_pct, endpoints_tested, method_endpoint_pairs,
                status_codes_seen, action_history, created_resources,
            )
        elif task_id == "security_workflows":
            return TaskGrader._grade_security(
                bugs_found, coverage_pct, action_history, created_resources,
            )
        return GradeResult(score=0.0, breakdown={}, feedback="Unknown task")

    @staticmethod
    def _safe_int(value, default: int) -> int:
        """Parse a query-param value that may be a non-numeric string.

        Agents deliberately send junk values while fuzzing; grading must
        never crash on them (the old code called int() directly and could
        raise ValueError).
        """
        try:
            return int(str(value))
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _grade_basic(
        bugs_found, coverage_pct, endpoints_tested, total_endpoints,
        method_endpoint_pairs, status_codes_seen, action_history, created_resources,
    ) -> GradeResult:
        """Grade basic CRUD coverage plus easy-bug discovery."""
        breakdown = {}

        # 0.25: Test all GET endpoints (full credit at 4 distinct endpoints)
        get_endpoints = {
            h.get("endpoint") for h in action_history
            if h.get("method", "").upper() == "GET"
        }
        get_score = min(len(get_endpoints) / 4, 1.0) * 0.25
        breakdown["get_coverage"] = round(get_score, 3)

        # 0.20: Test POST with valid data (count 201 Created responses)
        post_success = sum(
            1 for h in action_history
            if h.get("method", "").upper() == "POST" and h.get("response_status") == 201
        )
        post_score = min(post_success / 2, 1.0) * 0.20
        breakdown["post_testing"] = round(post_score, 3)

        # 0.15: Test PUT/DELETE
        put_delete = sum(
            1 for h in action_history
            if h.get("method", "").upper() in ("PUT", "DELETE")
        )
        pd_score = min(put_delete / 2, 1.0) * 0.15
        breakdown["put_delete"] = round(pd_score, 3)

        # 0.20: Bug discovery (easy bugs; full credit at 2 of 3)
        easy_bugs = {"BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03"}
        found_easy = len(bugs_found & easy_bugs)
        bug_score = min(found_easy / 2, 1.0) * 0.20
        breakdown["bugs_found"] = round(bug_score, 3)

        # 0.20: Response schema validation (variety of status codes observed)
        schema_score = min(len(status_codes_seen) / 4, 1.0) * 0.20
        breakdown["schema_validation"] = round(schema_score, 3)

        score = sum(breakdown.values())
        feedback_parts = []
        if get_score > 0:
            feedback_parts.append(f"GET coverage: {len(get_endpoints)} endpoints")
        if post_success > 0:
            feedback_parts.append(f"POST success: {post_success}")
        if found_easy > 0:
            feedback_parts.append(f"Bugs found: {found_easy}/{len(easy_bugs)}")

        return GradeResult(
            score=round(min(score, 1.0), 4),
            breakdown=breakdown,
            feedback="; ".join(feedback_parts) if feedback_parts else "No significant progress",
        )

    @staticmethod
    def _grade_edge_cases(
        bugs_found, coverage_pct, endpoints_tested, method_endpoint_pairs,
        status_codes_seen, action_history, created_resources,
    ) -> GradeResult:
        """Grade negative/boundary testing and dependency chaining."""
        breakdown = {}

        # 0.15: Missing required fields testing (POST bodies without "title")
        missing_field_tests = sum(
            1 for h in action_history
            if h.get("method", "").upper() == "POST"
            and h.get("body") is not None
            and isinstance(h.get("body"), dict)
            and not h["body"].get("title")
        )
        breakdown["missing_fields"] = round(min(missing_field_tests / 2, 1.0) * 0.15, 3)

        # 0.15: Invalid data type testing (lists/bools/empty strings in bodies)
        invalid_tests = sum(
            1 for h in action_history
            if h.get("body") and isinstance(h.get("body"), dict)
            and any(
                isinstance(v, (list, bool)) or v == ""
                for v in h["body"].values()
            )
        )
        breakdown["invalid_types"] = round(min(invalid_tests / 2, 1.0) * 0.15, 3)

        # 0.15: Boundary value testing (negative pages, huge limits)
        boundary_tests = 0
        for h in action_history:
            # `or {}` also covers a "query_params" key explicitly set to None,
            # which the old `h.get("query_params", {})` did not.
            qp = h.get("query_params") or {}
            if qp.get("page") is not None and TaskGrader._safe_int(qp.get("page"), 1) < 1:
                boundary_tests += 1
            if qp.get("limit") is not None and TaskGrader._safe_int(qp.get("limit"), 10) > 100:
                boundary_tests += 1
        breakdown["boundary_values"] = round(min(boundary_tests / 2, 1.0) * 0.15, 3)

        # 0.15: Non-existent resource testing (IDs in the /999... range)
        nonexistent_tests = sum(
            1 for h in action_history
            if h.get("method", "").upper() in ("GET", "DELETE", "PUT")
            and "/999" in h.get("endpoint", "")
        )
        breakdown["nonexistent_resources"] = round(min(nonexistent_tests / 2, 1.0) * 0.15, 3)

        # 0.20: Bug discovery (medium bugs, easy bugs also count)
        medium_bugs = {
            "BUG_TASK_04", "BUG_TASK_05", "BUG_TASK_06",
            "BUG_USER_01", "BUG_USER_02", "BUG_AUTH_02",
        }
        all_relevant = medium_bugs | {"BUG_TASK_01", "BUG_TASK_02", "BUG_TASK_03"}
        found_relevant = len(bugs_found & all_relevant)
        breakdown["bugs_found"] = round(min(found_relevant / 3, 1.0) * 0.20, 3)

        # 0.20: Dependency chaining (create -> read -> update -> delete)
        chain_score = 0.0
        if any(h.get("method") == "POST" and h.get("response_status") == 201 for h in action_history):
            chain_score += 0.25
        if created_resources.get("tasks"):
            for tid in created_resources["tasks"]:
                target = f"/tasks/{tid}"
                gets = [h for h in action_history if h.get("endpoint") == target and h.get("method") == "GET"]
                puts = [h for h in action_history if h.get("endpoint") == target and h.get("method") == "PUT"]
                deletes = [h for h in action_history if h.get("endpoint") == target and h.get("method") == "DELETE"]
                if gets:
                    chain_score += 0.25
                if puts:
                    chain_score += 0.25
                if deletes:
                    chain_score += 0.25
                break  # Only need one complete chain
        breakdown["dependency_chaining"] = round(min(chain_score, 1.0) * 0.20, 3)

        score = sum(breakdown.values())
        return GradeResult(
            score=round(min(score, 1.0), 4),
            breakdown=breakdown,
            feedback=f"Edge cases: {found_relevant} bugs found, chain score {chain_score:.0%}",
        )

    @staticmethod
    def _grade_security(
        bugs_found, coverage_pct, action_history, created_resources,
    ) -> GradeResult:
        """Grade security probing: authz, injection, state consistency."""
        breakdown = {}

        # 0.20: Cross-user authorization testing (logged in as >= 2 users)
        login_users = set()
        for h in action_history:
            if h.get("endpoint") == "/auth/login" and h.get("response_status") == 200:
                body = h.get("body") or {}
                username = body.get("username")
                # Skip logins without a username so a lone malformed login
                # cannot count toward the two-user requirement.
                if username:
                    login_users.add(username)
        cross_user = len(login_users) >= 2
        breakdown["cross_user_auth"] = 0.20 if cross_user else 0.0

        # 0.20: Injection pattern testing (SQLi/XSS payloads in bodies)
        injection_attempted = sum(
            1 for h in action_history
            if h.get("body") and isinstance(h.get("body"), dict)
            and any(
                pattern.lower() in str(h["body"]).lower()
                for pattern in ["DROP TABLE", "<script>", "OR 1=1", "UNION SELECT", "'; --"]
            )
        )
        breakdown["injection_testing"] = round(min(injection_attempted / 2, 1.0) * 0.20, 3)

        # 0.20: Multi-step state consistency:
        # did the agent DELETE a task and then re-GET the same endpoint?
        consistency_tests = 0
        for i, h in enumerate(action_history):
            if h.get("method") == "DELETE" and "/tasks/" in h.get("endpoint", ""):
                deleted_endpoint = h["endpoint"]
                for j in range(i + 1, len(action_history)):
                    later = action_history[j]
                    if later.get("endpoint") == deleted_endpoint and later.get("method") == "GET":
                        consistency_tests += 1
                        break
        breakdown["state_consistency"] = round(min(consistency_tests, 1.0) * 0.20, 3)

        # 0.20: Security bug discovery (full credit at 2 of 4)
        security_bugs = {"BUG_TASK_07", "BUG_AUTH_01", "BUG_TASK_08", "BUG_TASK_09"}
        found_security = len(bugs_found & security_bugs)
        breakdown["security_bugs"] = round(min(found_security / 2, 1.0) * 0.20, 3)

        # 0.20: Complete workflow coverage (80% endpoint coverage = full score)
        workflow_coverage = min(coverage_pct / 80, 1.0)
        breakdown["workflow_coverage"] = round(workflow_coverage * 0.20, 3)

        score = sum(breakdown.values())
        return GradeResult(
            score=round(min(score, 1.0), 4),
            breakdown=breakdown,
            feedback=f"Security: {found_security} security bugs, {len(login_users)} users tested, {injection_attempted} injection attempts",
        )
|
server/reward.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-signal reward function for the API Testing Environment.
|
| 3 |
+
|
| 4 |
+
Rewards are decomposed into:
|
| 5 |
+
1. Coverage reward β exploring new endpoints/methods/status codes
|
| 6 |
+
2. Validity reward β well-formed requests and proper dependency chaining
|
| 7 |
+
3. Bug discovery reward β the core goal, scaled by severity
|
| 8 |
+
4. Exploration bonus β trying novel actions
|
| 9 |
+
5. Penalties β for repeating exact requests or malformed input
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
from typing import Any, Optional
|
| 14 |
+
import re
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class CoverageTracker:
    """Tracks API coverage across the episode."""

    endpoints_hit: set[str] = field(default_factory=set)
    method_endpoint_pairs: set[tuple[str, str]] = field(default_factory=set)
    status_codes_seen: set[int] = field(default_factory=set)
    total_endpoints: int = 10  # known endpoint patterns

    def record(self, method: str, endpoint: str, status_code: int) -> dict[str, bool]:
        """Record one request and report which aspects of it were novel."""
        route = self._normalize_endpoint(endpoint)
        combo = (method.upper(), route)

        # Compute novelty BEFORE mutating the sets.
        novelty = {
            "new_endpoint": route not in self.endpoints_hit,
            "new_method_endpoint": combo not in self.method_endpoint_pairs,
            "new_status_code": status_code not in self.status_codes_seen,
        }

        self.endpoints_hit.add(route)
        self.method_endpoint_pairs.add(combo)
        self.status_codes_seen.add(status_code)
        return novelty

    def _normalize_endpoint(self, endpoint: str) -> str:
        """Collapse numeric path segments: /tasks/42 -> /tasks/{id}."""
        collapsed = re.sub(r"/\d+", "/{id}", endpoint)
        return collapsed.rstrip("/") or "/"

    def summary(self) -> dict:
        """Aggregate coverage statistics for feedback and observations."""
        denominator = max(self.total_endpoints, 1)
        return {
            "endpoints_tested": len(self.endpoints_hit),
            "total_endpoints": self.total_endpoints,
            "method_endpoint_pairs": len(self.method_endpoint_pairs),
            "status_codes_seen": sorted(self.status_codes_seen),
            "coverage_pct": round(len(self.endpoints_hit) / denominator * 100, 1),
        }
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
class RewardBreakdown:
    """Per-signal decomposition of one step's reward."""

    coverage: float = 0.0
    validity: float = 0.0
    bug_discovery: float = 0.0
    exploration: float = 0.0
    penalty: float = 0.0
    total: float = 0.0

    def as_dict(self) -> dict:
        """Return every component rounded to 4 decimal places."""
        components = ("coverage", "validity", "bug_discovery", "exploration", "penalty", "total")
        return {name: round(getattr(self, name), 4) for name in components}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class RewardComputer:
    """Computes multi-signal rewards for API testing actions.

    Holds per-episode state: coverage tracker, full action history,
    discovered bug ids, and resource IDs created via POST. Call
    :meth:`compute` once per step and :meth:`reset` between episodes.
    """

    def __init__(self):
        self.coverage = CoverageTracker()
        self.action_history: list[dict] = []
        self.found_bugs: set[str] = set()
        self.created_ids: dict[str, list[Any]] = {}  # resource type -> list of IDs

    def reset(self):
        """Clear all episode state for a fresh episode."""
        self.coverage = CoverageTracker()
        self.action_history = []
        self.found_bugs = set()
        self.created_ids = {}

    def compute(
        self,
        method: str,
        endpoint: str,
        headers: dict,
        query_params: dict,
        body: Optional[dict],
        expected_status: Optional[int],
        response_status: int,
        response_body: Any,
        bug_found: Optional[str] = None,  # bug severity if found
        bug_id: Optional[str] = None,
    ) -> RewardBreakdown:
        """Compute the reward breakdown for this step.

        Signals: coverage (up to 0.20), validity (up to 0.18), bug
        discovery (up to 0.30), exploration (0.05), duplicate penalty
        (-0.08). The total is clamped to [-0.1, 1.0].
        """
        # Normalize so None headers/params cannot crash the checks below.
        headers = headers or {}
        query_params = query_params or {}
        breakdown = RewardBreakdown()

        # 1. Coverage reward
        coverage_info = self.coverage.record(method, endpoint, response_status)
        if coverage_info["new_endpoint"]:
            breakdown.coverage += 0.10
        if coverage_info["new_method_endpoint"]:
            breakdown.coverage += 0.05
        if coverage_info["new_status_code"]:
            breakdown.coverage += 0.05

        # 2. Validity reward
        if response_status < 500:
            breakdown.validity += 0.03  # Non-crash request
        if self._used_dependency(method, endpoint, body, headers):
            breakdown.validity += 0.10  # Used a previously created resource ID or auth token
        if expected_status is not None and expected_status == response_status:
            breakdown.validity += 0.05  # Correctly predicted status code

        # Track created resources for later dependency credit
        self._track_created_resources(method, endpoint, response_status, response_body)

        # 3. Bug discovery reward (only the FIRST discovery of each bug pays)
        if bug_found and bug_id:
            if bug_id not in self.found_bugs:
                self.found_bugs.add(bug_id)
                if bug_found == "easy":
                    breakdown.bug_discovery += 0.10
                elif bug_found == "medium":
                    breakdown.bug_discovery += 0.15
                elif bug_found == "hard":
                    breakdown.bug_discovery += 0.25
                # First discovery bonus
                breakdown.bug_discovery += 0.05

        # 4. Exploration bonus: signature (method/route/param-keys/body-keys)
        # never seen before in this episode.
        action_sig = self._action_signature(method, endpoint, query_params, body)
        is_novel = all(
            self._action_signature(
                h.get("method", ""),
                h.get("endpoint", ""),
                h.get("query_params", {}),
                h.get("body"),
            )
            != action_sig
            for h in self.action_history
        )
        if is_novel:
            breakdown.exploration += 0.05

        # 5. Penalty for repeating an exact duplicate request
        exact_match = any(
            h.get("method") == method
            and h.get("endpoint") == endpoint
            and h.get("query_params") == query_params
            and h.get("body") == body
            and h.get("headers") == headers
            for h in self.action_history
        )
        if exact_match:
            breakdown.penalty -= 0.08

        # Record this action in history (after all history-based checks)
        self.action_history.append({
            "method": method,
            "endpoint": endpoint,
            "headers": headers,
            "query_params": query_params,
            "body": body,
            "response_status": response_status,
            "response_body": response_body,
        })

        # Total, floored to prevent extreme negative rewards
        breakdown.total = max(
            breakdown.coverage + breakdown.validity + breakdown.bug_discovery + breakdown.exploration + breakdown.penalty,
            -0.1,
        )
        breakdown.total = min(breakdown.total, 1.0)

        return breakdown

    def _used_dependency(self, method: str, endpoint: str, body: Optional[dict], headers: dict) -> bool:
        """Check if this request uses a resource ID or token from a previous step.

        True when the endpoint path contains a previously created resource
        ID as a complete path segment, or when the Authorization header
        carries a token returned by an earlier successful /auth/login.
        """
        # Match against whole path segments so created id 1 does not falsely
        # match /tasks/10 or /tasks/21 (the old substring check did).
        segments = set(str(endpoint).strip("/").split("/"))
        for ids in self.created_ids.values():
            if any(str(rid) in segments for rid in ids):
                return True

        auth_header = headers.get("Authorization") if headers else None
        if auth_header:
            for prev in self.action_history:
                if (
                    prev.get("endpoint") == "/auth/login"
                    and prev.get("response_status") == 200
                    and isinstance(prev.get("response_body"), dict)
                    and "token" in prev["response_body"]
                ):
                    if prev["response_body"]["token"] in auth_header:
                        return True
        return False

    def _track_created_resources(
        self, method: str, endpoint: str, status: int, body: Any
    ):
        """Track resource IDs from successful POST (201) responses."""
        if method.upper() == "POST" and status == 201 and isinstance(body, dict):
            resource_id = body.get("id")
            if resource_id is not None:
                # Resource type is the first path segment, e.g. /tasks -> "tasks"
                resource_type = endpoint.strip("/").split("/")[0]
                self.created_ids.setdefault(resource_type, []).append(resource_id)

    def _action_signature(
        self, method: str, endpoint: str, query_params: dict, body: Optional[dict]
    ) -> str:
        """Create a novelty signature: method, normalized route, key sets."""
        normalized = re.sub(r"/\d+", "/{id}", endpoint)
        body_keys = sorted(body.keys()) if body else []
        param_keys = sorted(query_params.keys()) if query_params else []
        return f"{method}:{normalized}:{param_keys}:{body_keys}"
|
setup.sh
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# ============================================================
|
| 3 |
+
# API Testing Environment β One-command setup
|
| 4 |
+
# ============================================================
|
| 5 |
+
# Usage: bash setup.sh
|
| 6 |
+
#
|
| 7 |
+
# This script:
|
| 8 |
+
# 1. Creates a virtual environment
|
| 9 |
+
# 2. Detects your GPU and installs the correct PyTorch+CUDA
|
| 10 |
+
# 3. Installs all project dependencies
|
| 11 |
+
# 4. Verifies everything works
|
| 12 |
+
# ============================================================
|
| 13 |
+
|
| 14 |
+
set -e
|
| 15 |
+
|
| 16 |
+
echo ""
|
| 17 |
+
echo "============================================"
|
| 18 |
+
echo " API Testing Environment β Setup"
|
| 19 |
+
echo "============================================"
|
| 20 |
+
echo ""
|
| 21 |
+
|
| 22 |
+
# --- Step 1: Create venv ---
|
| 23 |
+
echo "[1/5] Setting up virtual environment..."
|
| 24 |
+
if [ ! -d ".venv" ]; then
|
| 25 |
+
python3 -m venv .venv
|
| 26 |
+
echo " Created .venv"
|
| 27 |
+
else
|
| 28 |
+
echo " .venv already exists"
|
| 29 |
+
fi
|
| 30 |
+
source .venv/bin/activate
|
| 31 |
+
pip install --upgrade pip setuptools wheel -q
|
| 32 |
+
echo " Python: $(python3 --version)"
|
| 33 |
+
echo " pip: $(pip --version | awk '{print $2}')"
|
| 34 |
+
echo ""
|
| 35 |
+
|
| 36 |
+
# --- Step 2: Install PyTorch with correct CUDA ---
|
| 37 |
+
echo "[2/5] Detecting GPU and installing PyTorch..."
|
| 38 |
+
|
| 39 |
+
install_pytorch() {
|
| 40 |
+
if command -v nvidia-smi &> /dev/null; then
|
| 41 |
+
DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader 2>/dev/null | head -1)
|
| 42 |
+
DRIVER_MAJOR=$(echo "$DRIVER_VERSION" | cut -d. -f1)
|
| 43 |
+
GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null | head -1)
|
| 44 |
+
GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader 2>/dev/null | head -1)
|
| 45 |
+
|
| 46 |
+
echo " GPU: $GPU_NAME ($GPU_MEM)"
|
| 47 |
+
echo " NVIDIA driver: $DRIVER_VERSION"
|
| 48 |
+
|
| 49 |
+
if [ "$DRIVER_MAJOR" -ge 530 ]; then
|
| 50 |
+
echo " -> Installing PyTorch + CUDA 12.1"
|
| 51 |
+
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 -q
|
| 52 |
+
elif [ "$DRIVER_MAJOR" -ge 450 ]; then
|
| 53 |
+
echo " -> Installing PyTorch + CUDA 11.8 (older driver)"
|
| 54 |
+
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 -q
|
| 55 |
+
else
|
| 56 |
+
echo " WARNING: Driver too old ($DRIVER_VERSION). Install CPU PyTorch."
|
| 57 |
+
echo " Upgrade: https://www.nvidia.com/Download/index.aspx"
|
| 58 |
+
pip install torch torchvision -q
|
| 59 |
+
fi
|
| 60 |
+
else
|
| 61 |
+
echo " No NVIDIA GPU detected."
|
| 62 |
+
# Check for Apple Silicon
|
| 63 |
+
if python3 -c "import platform; exit(0 if platform.processor() == 'arm' else 1)" 2>/dev/null; then
|
| 64 |
+
echo " -> Apple Silicon detected, installing default PyTorch (MPS support)"
|
| 65 |
+
else
|
| 66 |
+
echo " -> Installing CPU-only PyTorch"
|
| 67 |
+
fi
|
| 68 |
+
pip install torch torchvision -q
|
| 69 |
+
fi
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
install_pytorch
|
| 73 |
+
echo ""
|
| 74 |
+
|
| 75 |
+
# --- Step 3: Install project dependencies ---
|
| 76 |
+
echo "[3/5] Installing project dependencies..."
|
| 77 |
+
pip install -r requirements.txt -q
|
| 78 |
+
echo " Done."
|
| 79 |
+
echo ""
|
| 80 |
+
|
| 81 |
+
# --- Step 4: Verify everything ---
|
| 82 |
+
echo "[4/5] Verifying installation..."
|
| 83 |
+
echo ""
|
| 84 |
+
python3 << 'PYEOF'
|
| 85 |
+
import sys
|
| 86 |
+
|
| 87 |
+
# Core
|
| 88 |
+
import fastapi, uvicorn, pydantic, httpx
|
| 89 |
+
print(f" fastapi: {fastapi.__version__}")
|
| 90 |
+
|
| 91 |
+
# ML
|
| 92 |
+
import torch
|
| 93 |
+
print(f" torch: {torch.__version__}")
|
| 94 |
+
cuda = torch.cuda.is_available()
|
| 95 |
+
mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
| 96 |
+
if cuda:
|
| 97 |
+
print(f" CUDA: {torch.version.cuda}")
|
| 98 |
+
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
| 99 |
+
print(f" GPU memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
|
| 100 |
+
elif mps:
|
| 101 |
+
print(f" Device: Apple MPS")
|
| 102 |
+
else:
|
| 103 |
+
print(f" Device: CPU only (training will be slow!)")
|
| 104 |
+
|
| 105 |
+
import transformers, trl, peft, datasets
|
| 106 |
+
print(f" transformers: {transformers.__version__}")
|
| 107 |
+
print(f" trl: {trl.__version__}")
|
| 108 |
+
print(f" peft: {peft.__version__}")
|
| 109 |
+
|
| 110 |
+
# Optional
|
| 111 |
+
try:
|
| 112 |
+
import wandb
|
| 113 |
+
print(f" wandb: {wandb.__version__}")
|
| 114 |
+
except ImportError:
|
| 115 |
+
print(f" wandb: not installed (optional)")
|
| 116 |
+
|
| 117 |
+
try:
|
| 118 |
+
import gradio
|
| 119 |
+
print(f" gradio: {gradio.__version__}")
|
| 120 |
+
except ImportError:
|
| 121 |
+
print(f" gradio: not installed (optional)")
|
| 122 |
+
|
| 123 |
+
# OpenEnv
|
| 124 |
+
try:
|
| 125 |
+
import openenv
|
| 126 |
+
print(f" openenv: OK")
|
| 127 |
+
except ImportError:
|
| 128 |
+
print(f" openenv: MISSING β run: pip install -r requirements.txt")
|
| 129 |
+
|
| 130 |
+
# Environment test
|
| 131 |
+
print("")
|
| 132 |
+
sys.path.insert(0, ".")
|
| 133 |
+
from server.environment import APITestEnvironment
|
| 134 |
+
from models import APITestAction, HTTPMethod
|
| 135 |
+
env = APITestEnvironment()
|
| 136 |
+
obs = env.reset(seed=42, task_id="basic_validation")
|
| 137 |
+
obs = env.step(APITestAction(method=HTTPMethod.GET, endpoint="/tasks/999999", expected_status=404))
|
| 138 |
+
assert obs.bugs_found_so_far == 1, "Bug detection failed!"
|
| 139 |
+
print(f" Environment: OK (bug detection verified)")
|
| 140 |
+
PYEOF
|
| 141 |
+
|
| 142 |
+
echo ""
|
| 143 |
+
|
| 144 |
+
# --- Step 5: Done ---
|
| 145 |
+
echo "============================================"
|
| 146 |
+
echo " Setup complete!"
|
| 147 |
+
echo "============================================"
|
| 148 |
+
echo ""
|
| 149 |
+
echo " Activate: source .venv/bin/activate"
|
| 150 |
+
echo ""
|
| 151 |
+
echo " Gradio UI: python gradio_app.py"
|
| 152 |
+
echo " Baselines: python -m training.evaluate --task all --agent all"
|
| 153 |
+
echo " Training: python -m training.grpo --model-id Qwen/Qwen3-1.7B"
|
| 154 |
+
echo " Test mode: python -m training.grpo --test-mode"
|
| 155 |
+
echo ""
|
| 156 |
+
echo " For HF Hub: huggingface-cli login"
|
| 157 |
+
echo " For W&B: wandb login"
|
| 158 |
+
echo ""
|
train_grpo.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""CLI entry point for GRPO training.

The actual implementation lives in ``training/grpo.py``; this file only
re-exports its ``main`` so the trainer can be launched as a plain script
(``python train_grpo.py``) as well as via ``python -m training.grpo``.
"""
from training.grpo import main

if __name__ == "__main__":
    main()
|
training/README.md
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training Module
|
| 2 |
+
|
| 3 |
+
Everything related to training an AI agent to test APIs using GRPO (Group Relative Policy Optimization).
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## Setup
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
cd api_testing_env
|
| 11 |
+
|
| 12 |
+
# Option 1: Automated setup (creates venv, installs everything)
|
| 13 |
+
bash setup.sh
|
| 14 |
+
|
| 15 |
+
# Option 2: Manual setup
|
| 16 |
+
python3 -m venv .venv
|
| 17 |
+
source .venv/bin/activate
|
| 18 |
+
pip install -r requirements.txt
|
| 19 |
+
|
| 20 |
+
# Optional: login to HuggingFace Hub (for model push)
|
| 21 |
+
huggingface-cli login
|
| 22 |
+
|
| 23 |
+
# Optional: login to Weights & Biases (for logging)
|
| 24 |
+
wandb login
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
### Environment Variables
|
| 28 |
+
|
| 29 |
+
Create a `.env` file in `api_testing_env/` (or export in your shell):
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
# .env
|
| 33 |
+
|
| 34 |
+
# HuggingFace Hub — required for --push-to-hub
|
| 35 |
+
# Get your token at: https://huggingface.co/settings/tokens
|
| 36 |
+
HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
| 37 |
+
|
| 38 |
+
# Weights & Biases — required for --use-wandb
|
| 39 |
+
# Get your key at: https://wandb.ai/authorize
|
| 40 |
+
WANDB_API_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
| 41 |
+
|
| 42 |
+
# Optional: set W&B defaults
|
| 43 |
+
WANDB_PROJECT=api-testing-grpo
|
| 44 |
+
WANDB_ENTITY=your-team-name
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
**Three ways to provide these keys:**
|
| 48 |
+
|
| 49 |
+
| Method | Command |
|
| 50 |
+
|--------|---------|
|
| 51 |
+
| `.env` file | Create `.env` as shown above, then `source .env` before training |
|
| 52 |
+
| CLI login | `huggingface-cli login` and `wandb login` (stores keys in ~/.cache) |
|
| 53 |
+
| Inline export | `export HF_TOKEN=hf_xxx && export WANDB_API_KEY=xxx` |
|
| 54 |
+
|
| 55 |
+
> **Important:** Never commit `.env` to git. It's already in `.gitignore`.
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
## Quick Start
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
cd api_testing_env
|
| 63 |
+
source .venv/bin/activate
|
| 64 |
+
|
| 65 |
+
# 1. See what training prompts look like (no GPU needed)
|
| 66 |
+
SHOW_PROMPTS=1 python -m training.grpo
|
| 67 |
+
|
| 68 |
+
# 2. Quick sanity check (CPU, ~2 minutes)
|
| 69 |
+
python -m training.grpo --test-mode
|
| 70 |
+
|
| 71 |
+
# 3. Real training (GPU required)
|
| 72 |
+
python -m training.grpo --model-id Qwen/Qwen3-1.7B --num-episodes 100
|
| 73 |
+
|
| 74 |
+
# 4. With HuggingFace Hub push
|
| 75 |
+
python -m training.grpo \
|
| 76 |
+
--push-to-hub --hf-repo-id your-username/api-tester-grpo
|
| 77 |
+
|
| 78 |
+
# 5. With Weights & Biases logging
|
| 79 |
+
python -m training.grpo \
|
| 80 |
+
--use-wandb --wandb-project api-testing-grpo
|
| 81 |
+
|
| 82 |
+
# 6. Full pipeline: training + HF push + W&B
|
| 83 |
+
python -m training.grpo \
|
| 84 |
+
--model-id Qwen/Qwen3-1.7B \
|
| 85 |
+
--num-episodes 100 \
|
| 86 |
+
--push-to-hub --hf-repo-id your-username/api-tester-grpo \
|
| 87 |
+
--use-wandb --wandb-project api-testing-grpo
|
| 88 |
+
|
| 89 |
+
# 7. Run baseline agents only (no GPU needed)
|
| 90 |
+
python -m training.evaluate --task all --agent all --url http://localhost:8000
|
| 91 |
+
|
| 92 |
+
# 8. Resume from checkpoint
|
| 93 |
+
python -m training.grpo --model-id ./checkpoints/step_50
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
## How Training Works
|
| 99 |
+
|
| 100 |
+
There is **no external dataset**. The environment generates unique episodes on the fly.
|
| 101 |
+
|
| 102 |
+
```
|
| 103 |
+
βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 104 |
+
β GRPO Training Loop β
|
| 105 |
+
β β
|
| 106 |
+
βββββββββββββ β 1. env.reset(seed=N) β
|
| 107 |
+
β β β β unique users, tasks, data β
|
| 108 |
+
β Qwen β β β
|
| 109 |
+
β 1.7B ββββΆβ 2. LLM generates: {"method":"GET",...} β
|
| 110 |
+
β + LoRA β β β
|
| 111 |
+
β βββββ 3. env.step(action) β reward β
|
| 112 |
+
βββββββββββββ β coverage + bugs + validity β
|
| 113 |
+
β β
|
| 114 |
+
β 4. GRPO: generate 4 attempts per prompt, β
|
| 115 |
+
β keep best, update model weights β
|
| 116 |
+
β β
|
| 117 |
+
β 5. Repeat with next seed β
|
| 118 |
+
βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
### Why no dataset file?
|
| 122 |
+
|
| 123 |
+
Each `reset(seed=N)` creates a **unique database** with different users, tasks, and data:
|
| 124 |
+
|
| 125 |
+
| Seed | Users | Tasks |
|
| 126 |
+
|------|-------|-------|
|
| 127 |
+
| 42 | diana, alice, xander, ivan, hannah | 8 tasks |
|
| 128 |
+
| 99 | mike, george, tom, fiona | 6 tasks |
|
| 129 |
+
| 7 | priya, kevin, wendy | 4 tasks |
|
| 130 |
+
|
| 131 |
+
The agent can't memorize "login as alice" because alice might not exist. It must **read the observation and adapt** — that's the learning signal.
|
| 132 |
+
|
| 133 |
+
The bugs (13 planted flaws) are structural — the same code flaws appear every episode — but the path to finding them changes because the data is different.
|
| 134 |
+
|
| 135 |
+
---
|
| 136 |
+
|
| 137 |
+
## Training Pipeline
|
| 138 |
+
|
| 139 |
+
The full training pipeline runs these steps automatically:
|
| 140 |
+
|
| 141 |
+
```
|
| 142 |
+
1. Run baseline agents (random, sequential, smart) across all tasks
|
| 143 |
+
β
|
| 144 |
+
2. Load base model (Qwen 1.7B)
|
| 145 |
+
β
|
| 146 |
+
3. Evaluate base model before training (establishes LLM baseline)
|
| 147 |
+
β
|
| 148 |
+
4. GRPO training with LoRA
|
| 149 |
+
β
|
| 150 |
+
5. Save model locally to --output-dir
|
| 151 |
+
β
|
| 152 |
+
6. Push to HuggingFace Hub (if --push-to-hub)
|
| 153 |
+
β
|
| 154 |
+
7. Evaluate trained model after GRPO
|
| 155 |
+
β
|
| 156 |
+
8. Print comparison table (baselines vs base vs trained)
|
| 157 |
+
β
|
| 158 |
+
9. Save metrics (JSON + markdown) to output-dir/metrics/
|
| 159 |
+
β
|
| 160 |
+
10. Save comparison plots (PNG) to output-dir/metrics/plots/
|
| 161 |
+
β
|
| 162 |
+
11. Finalize W&B run (if --use-wandb)
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
---
|
| 166 |
+
|
| 167 |
+
## File Guide
|
| 168 |
+
|
| 169 |
+
| File | Purpose | When to modify |
|
| 170 |
+
|------|---------|----------------|
|
| 171 |
+
| `prompts.py` | System prompt, `format_observation()`, `parse_action()` | Change how the LLM sees tasks or formats actions |
|
| 172 |
+
| `rewards.py` | `format_reward_fn()`, `environment_reward_fn()` | Tune reward scaling or add new reward signals |
|
| 173 |
+
| `agents.py` | `RandomAgent`, `SequentialAgent`, `SmartAgent` | Add new baseline strategies |
|
| 174 |
+
| `grpo.py` | `build_training_prompts()`, `train_grpo()` | Change training hyperparameters or model |
|
| 175 |
+
| `evaluate.py` | `run_rollout()`, `run_baseline_local()`, remote runner | Change evaluation logic |
|
| 176 |
+
|
| 177 |
+
### prompts.py
|
| 178 |
+
|
| 179 |
+
The bridge between the environment and the LLM.
|
| 180 |
+
|
| 181 |
+
**`SYSTEM_PROMPT`** β Instructions telling the LLM it's an API tester. Includes output format (JSON) and testing strategies.
|
| 182 |
+
|
| 183 |
+
**`format_observation(obs)`** β Converts an environment observation into text:
|
| 184 |
+
- First turn: full API spec + task description + available users
|
| 185 |
+
- Later turns: last response + feedback + progress stats + auth tokens
|
| 186 |
+
|
| 187 |
+
**`parse_action(text)`** β Extracts JSON from LLM output. Handles:
|
| 188 |
+
- Raw JSON: `{"method": "GET", "endpoint": "/tasks"}`
|
| 189 |
+
- Code blocks: `` ```json {...} ``` ``
|
| 190 |
+
- Extra text around JSON: `"I'll try: {...}"`
|
| 191 |
+
|
| 192 |
+
### rewards.py
|
| 193 |
+
|
| 194 |
+
Two reward functions that GRPO uses to score each LLM completion:
|
| 195 |
+
|
| 196 |
+
**`format_reward_fn`** β Binary: +1.0 if valid JSON action, -1.0 if not. Teaches the model to always output parseable actions.
|
| 197 |
+
|
| 198 |
+
**`environment_reward_fn`** β Runs the action in the environment and returns the actual reward (coverage + bugs + validity), scaled by 5.0 to dominate over format reward.
|
| 199 |
+
|
| 200 |
+
### agents.py
|
| 201 |
+
|
| 202 |
+
Three hand-coded baselines for comparison:
|
| 203 |
+
|
| 204 |
+
| Agent | Strategy | Expected Score |
|
| 205 |
+
|-------|----------|---------------|
|
| 206 |
+
| `RandomAgent` | Random method + random endpoint | ~0.10 |
|
| 207 |
+
| `SequentialAgent` | Fixed sequence: GET, POST, PUT, DELETE each endpoint | ~0.35 |
|
| 208 |
+
| `SmartAgent` | Multi-phase: discover → auth → CRUD → bug hunt → security | ~0.55 |
|
| 209 |
+
|
| 210 |
+
A GRPO-trained model should beat the SmartAgent.
|
| 211 |
+
|
| 212 |
+
### grpo.py
|
| 213 |
+
|
| 214 |
+
The main training script.
|
| 215 |
+
|
| 216 |
+
**`build_training_prompts(num_episodes)`** β Creates N prompts by resetting the environment with seeds 0..N. Each prompt is a chat message with system prompt + initial observation.
|
| 217 |
+
|
| 218 |
+
**`run_baseline_evaluation(seed)`** β Runs all three baseline agents across all tasks before training starts.
|
| 219 |
+
|
| 220 |
+
**`train_grpo(args)`** β Full GRPO loop:
|
| 221 |
+
1. Run baseline agents for comparison
|
| 222 |
+
2. Load model + tokenizer (Qwen 1.7B default)
|
| 223 |
+
3. Evaluate base model before training
|
| 224 |
+
4. Apply LoRA (r=16, alpha=32, targets q_proj + v_proj)
|
| 225 |
+
5. Generate prompts from environment
|
| 226 |
+
6. Create per-prompt environment instances for reward eval
|
| 227 |
+
7. Train with TRL's GRPOTrainer
|
| 228 |
+
8. Save model locally + push to HF Hub
|
| 229 |
+
9. Evaluate trained model + print comparison
|
| 230 |
+
10. Save metrics (JSON, markdown) and plots (PNG)
|
| 231 |
+
11. Finalize W&B run
|
| 232 |
+
|
| 233 |
+
**`save_metrics()`** β Saves `results.json` and `results.md` to `output-dir/metrics/`.
|
| 234 |
+
|
| 235 |
+
**`save_plots()`** β Generates three comparison bar charts (reward, bugs, coverage) saved as PNGs.
|
| 236 |
+
|
| 237 |
+
### evaluate.py
|
| 238 |
+
|
| 239 |
+
**`run_rollout(model, tokenizer, task_id, seed)`** β Runs one full episode with a HuggingFace model. Multi-turn: LLM generates action β env steps β LLM sees result β repeats.
|
| 240 |
+
|
| 241 |
+
**`run_baseline_local(agent_name, task_id, seed)`** β Runs baseline agents against the local environment (no server needed). Used by `grpo.py` to establish baselines before training.
|
| 242 |
+
|
| 243 |
+
**`run_episode(url, task_id, agent_cls)`** β Runs a baseline agent against a remote server via WebSocket.
|
| 244 |
+
|
| 245 |
+
---
|
| 246 |
+
|
| 247 |
+
## Training Hyperparameters
|
| 248 |
+
|
| 249 |
+
| Parameter | Default | Description |
|
| 250 |
+
|-----------|---------|-------------|
|
| 251 |
+
| `--model-id` | `Qwen/Qwen3-1.7B` | Base model (any HF causal LM) |
|
| 252 |
+
| `--num-episodes` | 50 | Training prompts (more = more diverse episodes) |
|
| 253 |
+
| `--num-generations` | 4 | GRPO rollouts per prompt (higher = better but slower) |
|
| 254 |
+
| `--max-completion-length` | 256 | Max tokens per LLM response |
|
| 255 |
+
| `--max-steps` | 200 | Total training optimizer steps |
|
| 256 |
+
| `--learning-rate` | 2e-5 | AdamW learning rate |
|
| 257 |
+
| `--batch-size` | 1 | Per-device batch size |
|
| 258 |
+
| `--output-dir` | `./checkpoints/grpo_api_tester` | Where to save model |
|
| 259 |
+
| `--push-to-hub` | off | Push trained model to HuggingFace Hub |
|
| 260 |
+
| `--hf-repo-id` | none | HF Hub repo (e.g., `user/api-tester-grpo`) |
|
| 261 |
+
| `--use-wandb` | off | Enable Weights & Biases logging |
|
| 262 |
+
| `--wandb-project` | `api-testing-grpo` | W&B project name |
|
| 263 |
+
| `--wandb-run-name` | auto | W&B run name |
|
| 264 |
+
| `--test-mode` | off | Quick 3-episode, 2-gen, 5-step test |
|
| 265 |
+
|
| 266 |
+
### Hardware Requirements
|
| 267 |
+
|
| 268 |
+
| Setup | GPU | Time | Model |
|
| 269 |
+
|-------|-----|------|-------|
|
| 270 |
+
| Colab Free | T4 (16GB) | ~1-2 hours | Qwen 1.7B + 4-bit LoRA |
|
| 271 |
+
| Colab Pro | A100 (40GB) | ~30 min | Qwen 4B + LoRA |
|
| 272 |
+
| Local | Any 8GB+ | ~1-2 hours | Qwen 1.7B + 4-bit LoRA |
|
| 273 |
+
| CPU only | None | `--test-mode` only | Verifies pipeline works |
|
| 274 |
+
|
| 275 |
+
---
|
| 276 |
+
|
| 277 |
+
## Output Structure
|
| 278 |
+
|
| 279 |
+
After training, your output directory will look like:
|
| 280 |
+
|
| 281 |
+
```
|
| 282 |
+
checkpoints/grpo_api_tester/
|
| 283 |
+
βββ adapter_config.json # LoRA adapter config
|
| 284 |
+
βββ adapter_model.safetensors # Trained LoRA weights
|
| 285 |
+
βββ tokenizer.json # Tokenizer files
|
| 286 |
+
βββ tokenizer_config.json
|
| 287 |
+
βββ special_tokens_map.json
|
| 288 |
+
βββ metrics/
|
| 289 |
+
βββ results.json # Full results (baselines + base + trained)
|
| 290 |
+
βββ results.md # Markdown comparison table
|
| 291 |
+
βββ plots/
|
| 292 |
+
βββ reward_comparison.png # Bar chart: reward across all agents
|
| 293 |
+
βββ bugs_comparison.png # Bar chart: bugs found
|
| 294 |
+
βββ coverage_comparison.png # Bar chart: API coverage %
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
---
|
| 298 |
+
|
| 299 |
+
## Weights & Biases Integration
|
| 300 |
+
|
| 301 |
+
When `--use-wandb` is enabled, the following is logged:
|
| 302 |
+
|
| 303 |
+
| Metric | Description |
|
| 304 |
+
|--------|-------------|
|
| 305 |
+
| `baseline/{agent}/{task}/reward` | Baseline agent scores |
|
| 306 |
+
| `base_model/{task}/reward` | Pre-training model scores |
|
| 307 |
+
| `trained_model/{task}/reward` | Post-training model scores |
|
| 308 |
+
| `delta/{task}/reward` | Improvement over base model |
|
| 309 |
+
| `plots/*` | Comparison charts as W&B images |
|
| 310 |
+
| TRL defaults | Loss, learning rate, reward mean/std |
|
| 311 |
+
|
| 312 |
+
---
|
| 313 |
+
|
| 314 |
+
## Expected Results
|
| 315 |
+
|
| 316 |
+
### Before Training (base Qwen 1.7B, no fine-tuning)
|
| 317 |
+
|
| 318 |
+
The base model can output JSON sometimes, but has no API testing strategy:
|
| 319 |
+
```
|
| 320 |
+
basic_validation: ~0.15 (random-level)
|
| 321 |
+
edge_cases: ~0.08
|
| 322 |
+
security_workflows: ~0.03
|
| 323 |
+
```
|
| 324 |
+
|
| 325 |
+
### After GRPO (50 episodes, 200 steps)
|
| 326 |
+
|
| 327 |
+
The model learns systematic testing patterns:
|
| 328 |
+
```
|
| 329 |
+
basic_validation: ~0.55-0.65
|
| 330 |
+
edge_cases: ~0.35-0.45
|
| 331 |
+
security_workflows: ~0.25-0.35
|
| 332 |
+
```
|
| 333 |
+
|
| 334 |
+
### What the Model Learns
|
| 335 |
+
|
| 336 |
+
1. **Output format** β Always produce valid JSON (format reward)
|
| 337 |
+
2. **Coverage** β Test different endpoints, don't repeat the same request
|
| 338 |
+
3. **Dependency chaining** β POST to create, then GET/PUT/DELETE the created resource
|
| 339 |
+
4. **Bug patterns** β Try non-existent IDs, missing fields, invalid emails
|
| 340 |
+
5. **Auth workflows** β Login first, use tokens in subsequent requests
|
| 341 |
+
6. **Security testing** β Try cross-user access, injection payloads
|
| 342 |
+
|
| 343 |
+
---
|
| 344 |
+
|
| 345 |
+
## Extending the Training
|
| 346 |
+
|
| 347 |
+
### Add a new reward signal
|
| 348 |
+
|
| 349 |
+
Edit `rewards.py`:
|
| 350 |
+
|
| 351 |
+
```python
|
| 352 |
+
def efficiency_reward_fn(completions: list[str], **kwargs) -> list[float]:
|
| 353 |
+
"""Reward for concise, focused actions (penalize wasted steps)."""
|
| 354 |
+
rewards = []
|
| 355 |
+
for text in completions:
|
| 356 |
+
action = parse_action(text)
|
| 357 |
+
if action and action.expected_status:
|
| 358 |
+
rewards.append(0.5) # Bonus for predicting expected status
|
| 359 |
+
else:
|
| 360 |
+
rewards.append(0.0)
|
| 361 |
+
return rewards
|
| 362 |
+
```
|
| 363 |
+
|
| 364 |
+
Then add it to the combined reward in `grpo.py`.
|
| 365 |
+
|
| 366 |
+
### Add a new baseline agent
|
| 367 |
+
|
| 368 |
+
Edit `agents.py`:
|
| 369 |
+
|
| 370 |
+
```python
|
| 371 |
+
class CoverageAgent:
|
| 372 |
+
"""Agent that prioritizes hitting every endpoint once."""
|
| 373 |
+
name = "coverage"
|
| 374 |
+
|
| 375 |
+
def __init__(self):
|
| 376 |
+
self.tested = set()
|
| 377 |
+
# ...
|
| 378 |
+
```
|
| 379 |
+
|
| 380 |
+
Then add it to the `AGENTS` dict.
|
| 381 |
+
|
| 382 |
+
### Use a different model
|
| 383 |
+
|
| 384 |
+
```bash
|
| 385 |
+
# Qwen 2.5 (smaller, faster)
|
| 386 |
+
python -m training.grpo --model-id Qwen/Qwen2.5-1.5B
|
| 387 |
+
|
| 388 |
+
# Llama 3 (if you have access)
|
| 389 |
+
python -m training.grpo --model-id meta-llama/Llama-3.2-1B
|
| 390 |
+
```
|
| 391 |
+
|
| 392 |
+
Any HuggingFace causal language model works β just make sure it supports chat templates.
|
training/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training module for the API Testing Environment.
|
| 3 |
+
|
| 4 |
+
Contains:
|
| 5 |
+
- prompts.py β System prompt, observation formatting, action parsing
|
| 6 |
+
- rewards.py β Reward functions for GRPO (format + environment)
|
| 7 |
+
- agents.py β Baseline agents (random, sequential, smart)
|
| 8 |
+
- grpo.py β GRPO training loop with TRL, HF Hub push, W&B logging
|
| 9 |
+
- evaluate.py β Evaluation / rollout runner (local + remote)
|
| 10 |
+
"""
|
training/agents.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Baseline agents for the API Testing Environment.
|
| 3 |
+
|
| 4 |
+
Three agents of increasing sophistication:
|
| 5 |
+
1. RandomAgent β Picks random endpoints/methods (lower bound)
|
| 6 |
+
2. SequentialAgent β Systematically tests each endpoint in order
|
| 7 |
+
3. SmartAgent β Chains requests and probes for known bug patterns
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import random
|
| 11 |
+
import sys
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 15 |
+
from models import APITestAction, HTTPMethod
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RandomAgent:
    """Randomly picks endpoints and methods. Baseline for comparison.

    Serves as the lower bound in agent evaluations: it has no memory and
    ignores the observation entirely.
    """

    name = "random"

    # Static pools the agent samples from on every step.
    ENDPOINTS = ["/tasks", "/tasks/1", "/tasks/2", "/tasks/999", "/users", "/users/1", "/auth/login"]
    METHODS = ["GET", "POST", "PUT", "DELETE"]

    def act(self, observation: dict) -> APITestAction:
        """Return a uniformly random API test action.

        Args:
            observation: Current environment observation (unused).

        Returns:
            An APITestAction with a request body filled in only for the
            method/endpoint combinations that need one (task creation,
            login, user creation, task update).
        """
        method = random.choice(self.METHODS)
        endpoint = random.choice(self.ENDPOINTS)
        body = None
        headers = {}

        if method == "POST" and endpoint == "/tasks":
            body = {"title": f"Random task {random.randint(1, 100)}"}
        elif method == "POST" and endpoint == "/auth/login":
            body = {"username": random.choice(["alice", "bob"]), "password": "pass"}
        elif method == "POST" and endpoint == "/users":
            body = {"username": f"user{random.randint(100, 999)}", "email": "test@test.com", "password": "pass"}
        elif method == "PUT":
            # PUT always retargets to a concrete task id so the update has
            # a plausible resource to hit.
            endpoint = f"/tasks/{random.randint(1, 5)}"
            body = {"title": "Updated"}

        # `method` is always drawn from METHODS, so the original
        # `HTTPMethod(method) if method in (...) else HTTPMethod.GET` guard
        # had an unreachable else-branch; construct the enum directly.
        return APITestAction(
            method=HTTPMethod(method),
            endpoint=endpoint,
            headers=headers,
            body=body,
        )
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class SequentialAgent:
    """Systematically tests each endpoint with valid requests."""

    name = "sequential"

    def __init__(self):
        # Number of act() calls so far; position step-1 of the fixed
        # script is returned on each call.
        self.step = 0

    def act(self, observation: dict) -> APITestAction:
        """Return the next scripted action, repeating the final one once
        the script is exhausted. *observation* is ignored."""
        self.step += 1
        script = self._get_action_sequence()
        pos = self.step - 1
        if pos > len(script) - 1:
            pos = len(script) - 1
        return script[pos]

    def _get_action_sequence(self) -> list[APITestAction]:
        """Fixed test script: reads, auth, writes, then error probes."""
        get, post, put, delete = (
            HTTPMethod.GET,
            HTTPMethod.POST,
            HTTPMethod.PUT,
            HTTPMethod.DELETE,
        )
        return [
            # Read every collection/detail endpoint first.
            APITestAction(method=get, endpoint="/tasks", expected_status=200),
            APITestAction(method=get, endpoint="/users", expected_status=200),
            APITestAction(method=get, endpoint="/tasks/1", expected_status=200),
            APITestAction(method=get, endpoint="/users/1", expected_status=200),
            # Authenticate, then exercise the write paths.
            APITestAction(method=post, endpoint="/auth/login",
                          body={"username": "alice", "password": "password123"}, expected_status=200),
            APITestAction(method=post, endpoint="/tasks",
                          body={"title": "Test Task", "description": "Created by baseline"}, expected_status=201),
            APITestAction(method=post, endpoint="/users",
                          body={"username": "testuser", "email": "test@example.com", "password": "test123"},
                          expected_status=201),
            APITestAction(method=put, endpoint="/tasks/1",
                          body={"title": "Updated Task"}, expected_status=200),
            APITestAction(method=delete, endpoint="/tasks/5", expected_status=200),
            # Deliberate error probes: missing resource, missing field, bad paging.
            APITestAction(method=get, endpoint="/tasks/999999", expected_status=404),
            APITestAction(method=post, endpoint="/tasks",
                          body={"description": "No title"}, expected_status=400),
            APITestAction(method=get, endpoint="/tasks",
                          query_params={"page": -1, "limit": 10}, expected_status=400),
            # Query-parameter coverage on the task listing.
            APITestAction(method=get, endpoint="/tasks",
                          query_params={"status": "done"}, expected_status=200),
            APITestAction(method=get, endpoint="/tasks",
                          query_params={"sort": "title"}, expected_status=200),
            APITestAction(method=get, endpoint="/tasks/2", expected_status=200),
        ]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class SmartAgent:
    """Heuristic agent that chains requests and probes for bugs.

    Strongest hand-coded baseline: it logs in, carries the returned auth
    tokens into later requests, and walks a fixed multi-phase script
    (discovery -> auth -> CRUD -> bug probes -> security checks).
    """

    name = "smart"

    def __init__(self):
        # Number of act() calls so far; indexes into the scripted sequence.
        self.step = 0
        # username -> bearer token, harvested from observations after login.
        self.auth_tokens = {}
        # Resource ids the agent has seen/created, deduplicated.
        self.created_ids = []

    def act(self, observation: dict) -> APITestAction:
        """Return the next scripted action, updating auth/resource state
        from *observation* first (the final action repeats once the
        script is exhausted)."""
        self.step += 1

        if isinstance(observation, dict):
            # Refresh tokens and known resource ids from the environment
            # so later phases can build authenticated requests.
            self.auth_tokens = observation.get("auth_tokens", self.auth_tokens)
            ids = observation.get("known_resource_ids", {})
            for rtype, id_list in ids.items():
                for rid in id_list:
                    if rid not in self.created_ids:
                        self.created_ids.append(rid)

        # The sequence is rebuilt every call on purpose: auth headers are
        # baked into the actions, so they must reflect the latest tokens.
        actions = self._get_smart_sequence()
        idx = min(self.step - 1, len(actions) - 1)
        return actions[idx]

    def _get_smart_sequence(self) -> list[APITestAction]:
        """Build the full scripted sequence using whatever auth tokens
        have been harvested so far (empty headers before login succeeds)."""
        alice_token = self.auth_tokens.get("alice", "")
        bob_token = self.auth_tokens.get("bob", "")
        alice_auth = {"Authorization": f"Bearer {alice_token}"} if alice_token else {}
        bob_auth = {"Authorization": f"Bearer {bob_token}"} if bob_token else {}

        return [
            # Phase 1: Discovery
            APITestAction(method=HTTPMethod.GET, endpoint="/tasks", expected_status=200),
            APITestAction(method=HTTPMethod.GET, endpoint="/users", expected_status=200),
            # Phase 2: Authentication
            APITestAction(method=HTTPMethod.POST, endpoint="/auth/login",
                          body={"username": "alice", "password": "password123"}, expected_status=200),
            APITestAction(method=HTTPMethod.POST, endpoint="/auth/login",
                          body={"username": "bob", "password": "password123"}, expected_status=200),
            # Phase 3: CRUD with auth
            APITestAction(method=HTTPMethod.POST, endpoint="/tasks",
                          body={"title": "Alice's task", "description": "Test"},
                          headers=alice_auth, expected_status=201),
            APITestAction(method=HTTPMethod.GET, endpoint="/tasks/1", headers=alice_auth, expected_status=200),
            # Phase 4: Easy bugs (missing resource, missing field, bad paging)
            APITestAction(method=HTTPMethod.GET, endpoint="/tasks/999999", expected_status=404),
            APITestAction(method=HTTPMethod.POST, endpoint="/tasks",
                          body={"description": "no title"}, expected_status=400),
            APITestAction(method=HTTPMethod.GET, endpoint="/tasks",
                          query_params={"page": -1, "limit": 10}, expected_status=400),
            # Phase 5: Medium bugs (validation, delete-missing, huge limit)
            APITestAction(method=HTTPMethod.PUT, endpoint="/tasks/1",
                          body={"assignee_email": "not-an-email"}, expected_status=422),
            APITestAction(method=HTTPMethod.DELETE, endpoint="/tasks/99999", expected_status=404),
            APITestAction(method=HTTPMethod.GET, endpoint="/tasks",
                          query_params={"limit": 999999}, expected_status=200),
            # Phase 6: User bugs (invalid email, empty password)
            APITestAction(method=HTTPMethod.POST, endpoint="/users",
                          body={"username": "baduser", "email": "invalid-email", "password": "test"},
                          expected_status=422),
            APITestAction(method=HTTPMethod.POST, endpoint="/auth/login",
                          body={"username": "alice", "password": ""}, expected_status=401),
            # Phase 7: BOLA (bob reading alice's task)
            APITestAction(method=HTTPMethod.GET, endpoint="/tasks/1",
                          headers=bob_auth, expected_status=403),
            # Phase 8: Injection (SQL payload, oversized title)
            APITestAction(method=HTTPMethod.POST, endpoint="/tasks",
                          body={"title": "test'; DROP TABLE tasks;--"}, expected_status=201),
            APITestAction(method=HTTPMethod.POST, endpoint="/tasks",
                          body={"title": "A" * 6000}, expected_status=400),
            # Phase 9: Cross-user modification
            APITestAction(method=HTTPMethod.PUT, endpoint="/tasks/1",
                          body={"title": "Bob modified Alice's task"},
                          headers=bob_auth, expected_status=403),
            # Phase 10: State consistency (create, delete, verify gone)
            APITestAction(method=HTTPMethod.POST, endpoint="/tasks",
                          body={"title": "Ephemeral task"}, expected_status=201),
            APITestAction(method=HTTPMethod.DELETE, endpoint="/tasks/6", expected_status=200),
            APITestAction(method=HTTPMethod.GET, endpoint="/tasks/6", expected_status=404),
            # Phase 11: Coverage (query params, second user)
            APITestAction(method=HTTPMethod.GET, endpoint="/tasks",
                          query_params={"status": "done"}, expected_status=200),
            APITestAction(method=HTTPMethod.GET, endpoint="/tasks",
                          query_params={"sort": "title"}, expected_status=200),
            APITestAction(method=HTTPMethod.GET, endpoint="/users/2", expected_status=200),
            # Phase 12: Password hash check (create user, then inspect response)
            APITestAction(method=HTTPMethod.POST, endpoint="/users",
                          body={"username": "newuser2", "email": "valid@email.com", "password": "pass"},
                          expected_status=201),
        ]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# Registry mapping each agent's CLI name to its class.  Keys come from the
# classes' own ``name`` attributes ("random", "sequential", "smart"), which
# match the original literal mapping exactly.
AGENTS = {cls.name: cls for cls in (RandomAgent, SequentialAgent, SmartAgent)}
|
training/evaluate.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Evaluation and rollout runner.
|
| 4 |
+
|
| 5 |
+
- run_rollout(): Run a single episode with a HuggingFace model
|
| 6 |
+
- run_baseline_local(): Run baseline agents against the local environment
|
| 7 |
+
- run_baseline(): Run baseline agents against a remote server
|
| 8 |
+
- main(): CLI for running baselines
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import asyncio
|
| 13 |
+
import logging
|
| 14 |
+
import random
|
| 15 |
+
import sys
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 19 |
+
|
| 20 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
from models import APITestAction, HTTPMethod
|
| 24 |
+
from server.environment import APITestEnvironment
|
| 25 |
+
from .prompts import (
|
| 26 |
+
PLAN_SYSTEM_PROMPT, format_plan_prompt,
|
| 27 |
+
parse_action, parse_test_plan,
|
| 28 |
+
)
|
| 29 |
+
from .agents import AGENTS
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def run_rollout(
    model,
    tokenizer,
    task_id: str = "basic_validation",
    seed: int = 42,
    max_steps: int | None = None,
) -> dict:
    """Run a single episode with a HuggingFace model.

    Uses PLAN mode: the model generates a full test plan (JSON array) in one shot,
    then all actions are executed sequentially. This matches how training works.

    Falls back to a single-action parse if the model can't produce a valid plan.

    Args:
        model: HuggingFace causal LM used to generate the plan.
        tokenizer: matching tokenizer (must support apply_chat_template).
        task_id: environment task to reset into.
        seed: environment seed.
        max_steps: cap on executed actions; defaults to the env's own max.

    Returns:
        Summary dict: task_id, seed, steps, total_reward, bugs_found,
        total_bugs, coverage_pct, bugs_found_ids.
    """
    import torch
    import time as _time

    # Force GPU if available
    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Move model to GPU if it's on CPU
        if next(model.parameters()).device.type == "cpu":
            logger.info(" Moving model to GPU...")
            model = model.to(device)
    else:
        device = next(model.parameters()).device

    env = APITestEnvironment()
    obs = env.reset(seed=seed, task_id=task_id)
    actual_max = max_steps or obs.max_steps
    logger.info(f" Rollout: {task_id} | max_steps={actual_max} | device={device}")

    # --- Try plan mode first (matches training) ---
    plan_prompt = format_plan_prompt(obs)
    messages = [
        {"role": "system", "content": PLAN_SYSTEM_PROMPT},
        {"role": "user", "content": plan_prompt},
    ]

    # Qwen3 thinking support
    chat_kwargs = {}
    if "qwen3" in str(getattr(model, "name_or_path", "") or "").lower():
        chat_kwargs["enable_thinking"] = True

    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, **chat_kwargs,
    )
    inputs = tokenizer(prompt_text, return_tensors="pt").to(device)

    gen_start = _time.time()
    # Plain string: no placeholders, so no f-prefix needed.
    print(" Generating test plan...", end="", flush=True)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=4096,  # Match training max_completion_length
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
    completion = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    gen_time = _time.time() - gen_start
    print(f" done ({gen_time:.1f}s, {len(completion)} chars)")

    # Parse the plan
    actions = parse_test_plan(completion)
    if actions:
        logger.info(f" Plan generated: {len(actions)} actions")
    else:
        # Fallback: try single action parse
        single = parse_action(completion)
        if single:
            actions = [single]
            logger.info(" Plan parse failed, got 1 action from fallback")
        else:
            logger.warning(" Failed to parse any actions from model output")
            # Print first 500 chars of completion for debugging
            preview = completion[:500].replace("\n", " ")
            logger.warning(f" Model output preview: {preview}...")
            actions = []

    # Limit to max_steps
    actions = actions[:actual_max]

    # Execute all actions
    total_reward = 0.0
    for i, action in enumerate(actions):
        try:
            obs = env.step(action)
            # obs.reward may be None on some transitions; normalize once and
            # use the normalized value both for the sum and the format spec
            # (formatting None with :.3f would raise TypeError).
            step_reward = obs.reward or 0.0
            total_reward += step_reward
            method_str = action.method.value if hasattr(action.method, "value") else str(action.method)
            print(f"  Step {i+1}/{len(actions)}: {method_str} {action.endpoint} -> "
                  f"{obs.status_code} | reward={step_reward:.3f} | bugs={obs.bugs_found_so_far}")
        except Exception as e:
            print(f"  Step {i+1}/{len(actions)}: ERROR - {e}")

    # If no actions were generated, show that
    if not actions:
        print("  (no valid actions generated)")

    state = env.state
    return {
        "task_id": task_id,
        "seed": seed,
        "steps": len(actions),
        "total_reward": round(total_reward, 4),
        "bugs_found": state.bugs_found,
        "total_bugs": state.total_bugs,
        "coverage_pct": state.coverage_pct,
        "bugs_found_ids": state.bugs_found_ids,
    }
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def run_baseline_local(
    agent_name: str = "all",
    task_id: str = "all",
    seed: int = 42,
) -> list[dict]:
    """Run baseline agents against the local environment (no server needed).

    Args:
        agent_name: "random", "sequential", "smart", or "all"
        task_id: task ID or "all"
        seed: random seed

    Returns:
        List of result dicts with agent, task_id, total_reward, bugs_found, etc.
    """
    # Expand "all" into the full cross-product of tasks x agents.
    tasks = ["basic_validation", "edge_cases", "security_workflows"] if task_id == "all" else [task_id]
    agents = list(AGENTS.items()) if agent_name == "all" else [(agent_name, AGENTS[agent_name])]

    results = []
    for tid in tasks:
        for aname, agent_cls in agents:
            # Re-seed the global RNG per episode so each agent/task pair is
            # reproducible and independent of iteration order.
            random.seed(seed)
            agent = agent_cls()
            env = APITestEnvironment()
            obs = env.reset(seed=seed, task_id=tid)

            total_reward = 0.0
            step = 0

            while not obs.done and step < obs.max_steps:
                # Flatten the observation into the plain-dict shape that the
                # baseline agents' act() method expects.
                obs_dict = {
                    "status_code": obs.status_code,
                    "response_body": obs.response_body,
                    "feedback": obs.feedback,
                    "bugs_found_so_far": obs.bugs_found_so_far,
                    "coverage_summary": obs.coverage_summary,
                    "known_resource_ids": obs.known_resource_ids,
                    "auth_tokens": obs.auth_tokens,
                    "steps_taken": obs.steps_taken,
                    "max_steps": obs.max_steps,
                }

                action = agent.act(obs_dict)
                obs = env.step(action)
                # Reward may be None on some transitions; count it as 0.
                total_reward += obs.reward or 0.0
                step += 1

            state = env.state
            result = {
                "agent": aname,
                "task_id": tid,
                "seed": seed,
                "steps": step,
                "total_reward": round(total_reward, 4),
                "bugs_found": state.bugs_found,
                "total_bugs": state.total_bugs,
                "coverage_pct": state.coverage_pct,
                "bugs_found_ids": state.bugs_found_ids,
            }
            results.append(result)
            logger.info(
                f"  [{aname}] {tid}: reward={result['total_reward']:.4f}, "
                f"bugs={result['bugs_found']}/{result['total_bugs']}, "
                f"coverage={result['coverage_pct']:.1f}%"
            )

    return results
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# =====================================================================
|
| 214 |
+
# Remote baseline runner (against server via WebSocket client)
|
| 215 |
+
# =====================================================================
|
| 216 |
+
|
| 217 |
+
async def run_episode(url: str, task_id: str, agent_cls, seed: int = 42) -> dict:
    """Run one baseline episode against a remote server.

    Args:
        url: base URL of the environment server.
        task_id: task to reset the environment to.
        agent_cls: baseline agent class; instantiated fresh for this episode.
        seed: seed for the module-level RNG the agents draw from.

    Returns:
        Summary dict: task_id, agent, total_reward, bugs_found, total_bugs,
        coverage_pct, steps.
    """
    from client import APITestEnv

    random.seed(seed)
    agent = agent_cls()

    async with APITestEnv(base_url=url) as env:
        result = await env.reset(task_id=task_id)
        obs = result.observation

        logger.info(f"Starting {agent.name} agent on task '{task_id}'")

        total_reward = 0.0
        step = 0

        while not result.done:
            # Flatten the observation into the plain-dict shape agents expect.
            obs_dict = {
                "status_code": obs.status_code,
                "response_body": obs.response_body,
                "feedback": obs.feedback,
                "bugs_found_so_far": obs.bugs_found_so_far,
                "coverage_summary": obs.coverage_summary,
                "known_resource_ids": obs.known_resource_ids,
                "auth_tokens": obs.auth_tokens,
                "steps_taken": obs.steps_taken,
                "max_steps": obs.max_steps,
            }

            action = agent.act(obs_dict)
            result = await env.step(action)
            obs = result.observation
            # result.reward may be None on some transitions; normalize once so
            # both the accumulator and the :.4f format spec below are safe
            # (formatting None would raise TypeError).
            step_reward = result.reward or 0.0
            total_reward += step_reward

            step += 1
            method = action.method.value if hasattr(action.method, "value") else str(action.method)
            logger.info(
                f"  Step {step}: {method} {action.endpoint} -> "
                f"{obs.status_code} | reward={step_reward:.4f} | bugs={obs.bugs_found_so_far}"
            )

        state = await env.state()
        return {
            "task_id": task_id,
            "agent": agent.name,
            "total_reward": round(total_reward, 4),
            "bugs_found": state.bugs_found,
            "total_bugs": state.total_bugs,
            "coverage_pct": state.coverage_pct,
            "steps": step,
        }
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
async def main_async(args):
    """Run the requested baseline agents over the requested tasks.

    Expands "all" selections, runs each (task, agent) episode against the
    remote server, logs per-episode results, and prints a summary table.

    Args:
        args: parsed argparse namespace with url, task, agent, seed.

    Returns:
        List of per-episode result dicts (possibly empty if every run failed).
    """
    tasks = ["basic_validation", "edge_cases", "security_workflows"] if args.task == "all" else [args.task]
    agents = list(AGENTS.values()) if args.agent == "all" else [AGENTS[args.agent]]

    results = []
    for task_id in tasks:
        for agent_cls in agents:
            try:
                result = await run_episode(args.url, task_id, agent_cls, seed=args.seed)
                results.append(result)
                logger.info(
                    f"\nRESULT: {result['agent']} on {result['task_id']}: "
                    f"reward={result['total_reward']}, bugs={result['bugs_found']}/{result['total_bugs']}, "
                    f"coverage={result['coverage_pct']:.1f}%"
                )
            except Exception as e:
                # "name" may only exist on agent *instances*; fall back to the
                # class name so this error path cannot itself raise.
                agent_label = getattr(agent_cls, "name", agent_cls.__name__)
                logger.error(f"Error running {agent_label} on {task_id}: {e}", exc_info=True)

    if results:
        print("\n" + "=" * 80)
        print("BASELINE RESULTS SUMMARY")
        print("=" * 80)
        print(f"{'Agent':<15} {'Task':<25} {'Score':<10} {'Bugs':<10} {'Coverage':<10}")
        print("-" * 80)
        for r in results:
            print(
                f"{r['agent']:<15} {r['task_id']:<25} "
                f"{r['total_reward']:<10.4f} "
                f"{r['bugs_found']}/{r['total_bugs']:<8} "
                f"{r['coverage_pct']:<10.1f}%"
            )
        print("=" * 80)

    return results
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def main():
    """CLI entry point: parse arguments and hand off to the async driver."""
    task_choices = ["basic_validation", "edge_cases", "security_workflows", "all"]
    agent_choices = ["random", "sequential", "smart", "all"]

    parser = argparse.ArgumentParser(description="Baseline agents for API Testing Environment")
    parser.add_argument("--url", default="http://localhost:8000", help="Environment server URL")
    parser.add_argument("--task", default="all", choices=task_choices)
    parser.add_argument("--agent", default="all", choices=agent_choices)
    parser.add_argument("--seed", type=int, default=42)

    asyncio.run(main_async(parser.parse_args()))


if __name__ == "__main__":
    main()
|
training/grpo.py
ADDED
|
@@ -0,0 +1,783 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GRPO Training Script for the API Testing Environment.
|
| 4 |
+
|
| 5 |
+
Trains a small LLM (Qwen 1.7B) to become an intelligent API tester
|
| 6 |
+
using Group Relative Policy Optimization (GRPO).
|
| 7 |
+
|
| 8 |
+
The environment IS the dataset β each reset(seed=N) creates a unique
|
| 9 |
+
episode with different users, tasks, and data. No external dataset needed.
|
| 10 |
+
|
| 11 |
+
Features:
|
| 12 |
+
- Auto-push trained model weights to HuggingFace Hub
|
| 13 |
+
- Weights & Biases logging for metrics, loss, rewards
|
| 14 |
+
- Baseline agent evaluation before GRPO (random, sequential, smart)
|
| 15 |
+
- Base model evaluation before GRPO for comparison
|
| 16 |
+
- Post-training evaluation with delta reporting
|
| 17 |
+
- Saves metrics, comparison tables, and plots to output dir
|
| 18 |
+
|
| 19 |
+
Usage:
|
| 20 |
+
# Quick test (CPU, 2 minutes)
|
| 21 |
+
python -m training.grpo --test-mode
|
| 22 |
+
|
| 23 |
+
# Real training (GPU required)
|
| 24 |
+
python -m training.grpo --model-id Qwen/Qwen3-1.7B --num-episodes 100
|
| 25 |
+
|
| 26 |
+
# With HF Hub push
|
| 27 |
+
python -m training.grpo --push-to-hub --hf-repo-id your-username/api-tester-grpo
|
| 28 |
+
|
| 29 |
+
# With Weights & Biases
|
| 30 |
+
python -m training.grpo --use-wandb --wandb-project api-testing-grpo
|
| 31 |
+
|
| 32 |
+
# See what prompts look like (no GPU needed)
|
| 33 |
+
SHOW_PROMPTS=1 python -m training.grpo
|
| 34 |
+
|
| 35 |
+
# Resume from checkpoint
|
| 36 |
+
python -m training.grpo --model-id ./checkpoints/step_50
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
import argparse
|
| 40 |
+
import json
|
| 41 |
+
import logging
|
| 42 |
+
import os
|
| 43 |
+
import sys
|
| 44 |
+
import time
|
| 45 |
+
|
| 46 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 47 |
+
|
| 48 |
+
# --- Suppress noisy HTTP/download logs ---
|
| 49 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 50 |
+
logger = logging.getLogger(__name__)
|
| 51 |
+
for _noisy in ["httpx", "httpcore", "urllib3", "huggingface_hub", "filelock",
|
| 52 |
+
"transformers.configuration_utils", "transformers.modeling_utils"]:
|
| 53 |
+
logging.getLogger(_noisy).setLevel(logging.WARNING)
|
| 54 |
+
|
| 55 |
+
# --- MONKEY PATCH FOR LLM-BLENDER ---
|
| 56 |
+
# llm-blender requires TRANSFORMERS_CACHE which was removed in transformers 4.42+
|
| 57 |
+
try:
|
| 58 |
+
import transformers.utils.hub
|
| 59 |
+
if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
|
| 60 |
+
transformers.utils.hub.TRANSFORMERS_CACHE = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface/hub"))
|
| 61 |
+
except ImportError:
|
| 62 |
+
pass
|
| 63 |
+
# ------------------------------------
|
| 64 |
+
|
| 65 |
+
from server.environment import APITestEnvironment
|
| 66 |
+
from .prompts import PLAN_SYSTEM_PROMPT, format_plan_prompt
|
| 67 |
+
from .rewards import format_reward_fn, plan_reward_fn, diversity_reward_fn
|
| 68 |
+
from .evaluate import run_rollout, run_baseline_local
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def build_training_prompts(
    num_episodes: int = 50,
    task_ids: list[str] | None = None,
) -> list[dict]:
    """Generate training prompts for GRPO plan-based training.

    Each prompt asks the model to output a COMPLETE TEST PLAN (JSON array of actions).
    The reward function will execute the plan on a fresh environment and score it.

    Args:
        num_episodes: number of prompts to generate.
        task_ids: tasks to cycle through; defaults to all three built-ins.

    Returns:
        List of dicts, each with "prompt" (chat messages), "task_id", "seed".
    """
    if task_ids is None:
        task_ids = ["basic_validation", "edge_cases", "security_workflows"]

    env = APITestEnvironment()
    prompts: list[dict] = []

    for episode_idx in range(num_episodes):
        # Round-robin over tasks; deterministic, well-separated per-episode seeds.
        tid = task_ids[episode_idx % len(task_ids)]
        episode_seed = episode_idx * 1000 + 42

        observation = env.reset(seed=episode_seed, task_id=tid)

        prompts.append({
            "prompt": [
                {"role": "system", "content": PLAN_SYSTEM_PROMPT},
                {"role": "user", "content": format_plan_prompt(observation)},
            ],
            "task_id": tid,
            "seed": episode_seed,
        })

    logger.info(f"Generated {len(prompts)} training prompts across tasks: {task_ids}")
    return prompts
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def run_baseline_evaluation(seed: int = 9999) -> dict:
    """Run all baseline agents and return results for comparison.

    Runs every baseline agent on every task locally (no server), prints a
    fixed-width summary table to stdout, and returns the results keyed for
    easy lookup.

    Returns:
        dict with structure: {agent_name: {task_id: result_dict}}
    """
    logger.info("=" * 60)
    logger.info("Running BASELINE AGENT evaluation...")
    logger.info("=" * 60)

    results = run_baseline_local(agent_name="all", task_id="all", seed=seed)

    # Organize by agent -> task
    organized = {}
    for r in results:
        agent = r["agent"]
        if agent not in organized:
            organized[agent] = {}
        organized[agent][r["task_id"]] = r

    # Print summary table
    print("\n" + "=" * 90)
    print("BASELINE AGENT RESULTS")
    print("=" * 90)
    print(f"{'Agent':<15} {'Task':<25} {'Reward':<10} {'Bugs':<12} {'Coverage':<10}")
    print("-" * 90)
    for agent_name in ["random", "sequential", "smart"]:
        # Skip agents that produced no results (e.g. filtered out upstream).
        if agent_name not in organized:
            continue
        for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
            # .get with {} default keeps the table printable even for
            # missing (agent, task) combinations.
            r = organized[agent_name].get(task_id, {})
            print(
                f"{agent_name:<15} {task_id:<25} "
                f"{r.get('total_reward', 0):<10.4f} "
                f"{r.get('bugs_found', 0)}/{r.get('total_bugs', 0):<10} "
                f"{r.get('coverage_pct', 0):<10.1f}%"
            )
    print("-" * 90)
    print("=" * 90 + "\n")

    return organized
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def save_metrics(
    output_dir: str,
    baseline_results: dict,
    base_model_results: dict,
    trained_model_results: dict,
    training_args: dict,
    training_time_s: float,
):
    """Save all metrics and comparison data to output_dir/metrics/.

    Writes two artifacts:
      - results.json: full structured results (baselines, base model,
        trained model, training args, wall-clock time).
      - results.md: human-readable markdown comparison table, including
        per-task reward deltas of the trained model vs the base model.

    Args:
        output_dir: root directory; a "metrics" subdir is created inside it.
        baseline_results: {agent_name: {task_id: result_dict}}.
        base_model_results: {task_id: result_dict} for the untrained model.
        trained_model_results: {task_id: result_dict} after GRPO.
        training_args: dict of CLI/training hyperparameters to record.
        training_time_s: total training wall-clock time in seconds.
    """
    metrics_dir = os.path.join(output_dir, "metrics")
    os.makedirs(metrics_dir, exist_ok=True)

    # Full results JSON
    all_results = {
        "training_args": training_args,
        "training_time_seconds": round(training_time_s, 1),
        "baseline_agents": {},
        "base_model": base_model_results,
        "trained_model": trained_model_results,
    }

    # Flatten baseline results (keep only the comparison-relevant fields).
    for agent_name, tasks in baseline_results.items():
        all_results["baseline_agents"][agent_name] = {}
        for task_id, r in tasks.items():
            all_results["baseline_agents"][agent_name][task_id] = {
                "total_reward": r.get("total_reward", 0),
                "bugs_found": r.get("bugs_found", 0),
                "total_bugs": r.get("total_bugs", 0),
                "coverage_pct": r.get("coverage_pct", 0),
            }

    with open(os.path.join(metrics_dir, "results.json"), "w") as f:
        json.dump(all_results, f, indent=2)

    # Comparison table as markdown
    md_lines = ["# Training Results\n"]
    md_lines.append(f"**Model**: {training_args.get('model_id', 'unknown')}")
    md_lines.append(f"**Training time**: {training_time_s / 60:.1f} minutes")
    md_lines.append(f"**Episodes**: {training_args.get('num_episodes', 0)}")
    md_lines.append(f"**Max steps**: {training_args.get('max_steps', 0)}\n")

    md_lines.append("## Comparison Table\n")
    md_lines.append("| Agent/Model | Task | Reward | Bugs | Coverage |")
    md_lines.append("|---|---|---|---|---|")

    # Baselines
    for agent_name in ["random", "sequential", "smart"]:
        if agent_name not in baseline_results:
            continue
        for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
            r = baseline_results[agent_name].get(task_id, {})
            md_lines.append(
                f"| {agent_name} | {task_id} | "
                f"{r.get('total_reward', 0):.4f} | "
                f"{r.get('bugs_found', 0)}/{r.get('total_bugs', 0)} | "
                f"{r.get('coverage_pct', 0):.1f}% |"
            )

    # Base model
    for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
        r = base_model_results.get(task_id, {})
        md_lines.append(
            f"| **base model** | {task_id} | "
            f"{r.get('total_reward', 0):.4f} | "
            f"{r.get('bugs_found', 0)}/{r.get('total_bugs', 0)} | "
            f"{r.get('coverage_pct', 0):.1f}% |"
        )

    # Trained model (reward column shows the delta vs the base model).
    for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
        r = trained_model_results.get(task_id, {})
        base = base_model_results.get(task_id, {})
        delta = r.get("total_reward", 0) - base.get("total_reward", 0)
        md_lines.append(
            f"| **GRPO trained** | {task_id} | "
            f"{r.get('total_reward', 0):.4f} ({delta:+.4f}) | "
            f"{r.get('bugs_found', 0)}/{r.get('total_bugs', 0)} | "
            f"{r.get('coverage_pct', 0):.1f}% |"
        )

    md_lines.append("")
    with open(os.path.join(metrics_dir, "results.md"), "w") as f:
        f.write("\n".join(md_lines))

    logger.info(f"Metrics saved to {metrics_dir}/")
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def save_plots(output_dir: str, baseline_results: dict, base_model_results: dict, trained_model_results: dict):
    """Generate and save comparison plots (reward, bugs found, coverage).

    Produces three grouped bar charts comparing the baseline agents, the
    base model, and the GRPO-trained model across the three evaluation
    tasks, written as PNGs under ``<output_dir>/metrics/plots/``.

    Args:
        output_dir: Training output directory; plots go in a subfolder.
        baseline_results: {agent_name: {task_id: metrics_dict}}.
        base_model_results: {task_id: metrics_dict} for the untrained model.
        trained_model_results: {task_id: metrics_dict} after GRPO.

    Silently returns (with a warning) if matplotlib is not installed.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")  # headless backend: no display required
        import matplotlib.pyplot as plt
        import numpy as np
    except ImportError:
        logger.warning("matplotlib not installed β skipping plot generation. pip install matplotlib")
        return

    plots_dir = os.path.join(output_dir, "metrics", "plots")
    os.makedirs(plots_dir, exist_ok=True)

    tasks = ["basic_validation", "edge_cases", "security_workflows"]
    task_labels = ["Basic", "Edge Cases", "Security"]
    x = np.arange(len(tasks))
    width = 0.15
    colors = ["#95a5a6", "#3498db", "#e67e22", "#9b59b6", "#2ecc71"]

    # Series order: available baseline agents first, then the two models.
    names = [a for a in ["random", "sequential", "smart"] if a in baseline_results]
    names += ["Base Model", "GRPO Trained"]

    def _metric(name: str, metric: str) -> list:
        # One per-task value series for a given agent/model; missing data -> 0.
        if name in baseline_results:
            src = baseline_results[name]
        elif name == "Base Model":
            src = base_model_results
        else:  # "GRPO Trained"
            src = trained_model_results
        return [src.get(t, {}).get(metric, 0) for t in tasks]

    def _grouped_bars(ax, metric: str, annotate: bool = False):
        # Draw one bar group per agent/model, offset around each task tick.
        for i, name in enumerate(names):
            vals = _metric(name, metric)
            offset = (i - len(names) / 2 + 0.5) * width
            bars = ax.bar(x + offset, vals, width, label=name, color=colors[i % len(colors)])
            if annotate:
                # Label bars with their value; skip near-zero bars to avoid clutter.
                for bar, val in zip(bars, vals):
                    if val > 0.01:
                        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                                f"{val:.2f}", ha="center", va="bottom", fontsize=7)

    def _finish(fig, ax, ylabel: str, title: str, fname: str, ylim=None):
        # Shared axis labelling, legend, save, and cleanup for every plot.
        ax.set_xlabel("Task")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(task_labels)
        ax.legend()
        if ylim is None:
            ax.set_ylim(bottom=0)
        else:
            ax.set_ylim(*ylim)
        plt.tight_layout()
        fig.savefig(os.path.join(plots_dir, fname), dpi=150)
        plt.close(fig)

    # --- Plot 1: Reward comparison bar chart (with value labels) ---
    fig, ax = plt.subplots(figsize=(12, 6))
    _grouped_bars(ax, "total_reward", annotate=True)
    _finish(fig, ax, "Total Reward",
            "Reward Comparison: Baselines vs Base Model vs GRPO Trained",
            "reward_comparison.png")

    # --- Plot 2: Bugs found comparison, with total-bugs reference line ---
    fig, ax = plt.subplots(figsize=(12, 6))
    _grouped_bars(ax, "bugs_found")
    # Total bug count per task: prefer the base-model report, fall back to
    # the trained-model report when the former is 0/missing.
    total_bugs = [base_model_results.get(t, {}).get("total_bugs", 0) or
                  trained_model_results.get(t, {}).get("total_bugs", 0) for t in tasks]
    ax.plot(x, total_bugs, "k--", marker="D", label="Total Bugs", linewidth=1.5)
    _finish(fig, ax, "Bugs Found",
            "Bug Discovery: Baselines vs Base Model vs GRPO Trained",
            "bugs_comparison.png")

    # --- Plot 3: Coverage comparison (percentage axis capped at 105) ---
    fig, ax = plt.subplots(figsize=(12, 6))
    _grouped_bars(ax, "coverage_pct")
    _finish(fig, ax, "Coverage %",
            "API Coverage: Baselines vs Base Model vs GRPO Trained",
            "coverage_comparison.png", ylim=(0, 105))

    logger.info(f"Plots saved to {plots_dir}/")
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def train_grpo(args):
    """Run the full GRPO training pipeline with TRL.

    Eleven stages: baseline-agent evaluation, model/tokenizer loading,
    pre-training evaluation, LoRA configuration, prompt generation, GRPO
    training, local save, optional HF Hub push, post-training evaluation,
    a printed comparison table, and metrics/plot export.

    Args:
        args: Parsed CLI namespace from main() (model_id, output_dir,
            num_episodes, num_generations, max_steps, learning_rate,
            batch_size, max_completion_length, eval/W&B/Hub flags).

    Side effects: writes checkpoints, metrics and plots under
    args.output_dir; may log to Weights & Biases and push to the HF Hub;
    calls sys.exit(1) when the training dependencies are missing.

    NOTE(review): the "β" characters in several strings below look like
    mojibake (probably "—"/"→", and "█"/"░" in the progress bar) β confirm
    the source file encoding before relying on the console output.
    """
    try:
        from datasets import Dataset
        from peft import LoraConfig
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from trl import GRPOConfig, GRPOTrainer

        # --- MONKEY PATCH FOR TRL GRPOTrainer ---
        # trl 0.15 lacks `dataset` argument in `_get_train_sampler` required by transformers 4.57+
        # The shim accepts (and ignores) the new argument and delegates to the
        # original zero-argument implementation.
        import inspect
        if hasattr(GRPOTrainer, "_get_train_sampler"):
            sig = inspect.signature(GRPOTrainer._get_train_sampler)
            if "dataset" not in sig.parameters:
                _old_sampler = GRPOTrainer._get_train_sampler
                def _new_sampler(self, dataset=None, **kwargs):
                    return _old_sampler(self)
                GRPOTrainer._get_train_sampler = _new_sampler
        # ----------------------------------------
    except ImportError as e:
        logger.error(
            f"Missing dependency: {e}\n"
            "Install with: pip install trl transformers peft datasets torch"
        )
        sys.exit(1)

    # --- W&B setup ---
    # W&B is optional: a missing package downgrades to local-only logging.
    wandb_run = None
    report_to = "none"
    if args.use_wandb:
        try:
            import wandb
            wandb_run = wandb.init(
                project=args.wandb_project,
                name=args.wandb_run_name or f"grpo-{args.model_id.split('/')[-1]}-{int(time.time())}",
                config={
                    "model_id": args.model_id,
                    "num_episodes": args.num_episodes,
                    "num_generations": args.num_generations,
                    "max_steps": args.max_steps,
                    "learning_rate": args.learning_rate,
                    "batch_size": args.batch_size,
                    "max_completion_length": args.max_completion_length,
                    "lora_r": 16,
                    "lora_alpha": 32,
                },
            )
            report_to = "wandb"
            logger.info(f"W&B initialized: project={args.wandb_project}, run={wandb_run.name}")
        except ImportError:
            logger.warning("wandb not installed β skipping W&B logging. pip install wandb")
            args.use_wandb = False

    # Snapshot of the run configuration, persisted later by save_metrics().
    training_args_dict = {
        "model_id": args.model_id,
        "num_episodes": args.num_episodes,
        "num_generations": args.num_generations,
        "max_steps": args.max_steps,
        "learning_rate": args.learning_rate,
        "batch_size": args.batch_size,
        "max_completion_length": args.max_completion_length,
        "output_dir": args.output_dir,
        "test_mode": args.test_mode,
    }

    # ================================================================
    # PIPELINE OVERVIEW
    # ================================================================
    total_pipeline_steps = 11
    def _step(n, msg):
        # Print a banner with a textual progress bar for pipeline stage n.
        bar = "β" * n + "β" * (total_pipeline_steps - n)
        print(f"\n{'='*70}")
        print(f" [{bar}] Step {n}/{total_pipeline_steps}: {msg}")
        print(f"{'='*70}\n")

    # --- Step 1: Run baseline agent evaluation ---
    # Fixed seed so baseline numbers are comparable across runs.
    _step(1, "Running baseline agents (random, sequential, smart)")
    baseline_results = run_baseline_evaluation(seed=9999)

    if args.use_wandb and wandb_run:
        import wandb
        for agent_name, tasks in baseline_results.items():
            for task_id, r in tasks.items():
                wandb.log({
                    f"baseline/{agent_name}/{task_id}/reward": r["total_reward"],
                    f"baseline/{agent_name}/{task_id}/bugs": r["bugs_found"],
                    f"baseline/{agent_name}/{task_id}/coverage": r["coverage_pct"],
                })

    # --- Step 2: Load model and tokenizer ---
    _step(2, f"Loading model: {args.model_id}")
    print(" Downloading tokenizer...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Some causal-LM tokenizers ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token
    print(" Tokenizer loaded.", flush=True)

    import torch

    # --- Force GPU detection ---
    if torch.cuda.is_available():
        device_map = "auto"
        dtype = torch.bfloat16
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f" GPU: {gpu_name} ({gpu_mem:.1f} GB)", flush=True)
        print(f" CUDA version: {torch.version.cuda}", flush=True)
    elif torch.backends.mps.is_available():
        device_map = "auto"
        dtype = torch.float16
        print(" Device: Apple MPS", flush=True)
    else:
        # Still try to use GPU β sometimes torch.cuda.is_available() is False
        # because of driver issues but CUDA can still work
        device_map = None
        dtype = torch.float32
        print(" !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!", flush=True)
        print(" !! WARNING: No GPU detected β running on CPU  !!", flush=True)
        print(" !! Training will be EXTREMELY slow.           !!", flush=True)
        print(" !! Check: python -c 'import torch; print(torch.cuda.is_available())'", flush=True)
        print(" !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!", flush=True)

    print(" Downloading model weights...", flush=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map=device_map,
    )

    # Verify model is actually on GPU
    actual_device = next(model.parameters()).device
    param_count = sum(p.numel() for p in model.parameters()) / 1e6
    print(f" Model loaded: {param_count:.0f}M parameters on {actual_device}", flush=True)

    if torch.cuda.is_available() and actual_device.type != "cuda":
        # device_map can occasionally leave weights on CPU; force the move.
        print(" Model not on GPU β forcing move to CUDA...", flush=True)
        model = model.to("cuda")
        print(f" Moved to: {next(model.parameters()).device}", flush=True)

    # --- Step 3: Evaluate base model BEFORE training ---
    # Same seed (9999) as the baselines so all rows of the final table are
    # evaluated on identical environment instances.
    _step(3, f"Evaluating BASE model (before GRPO, max {args.eval_max_steps} steps/task)")
    base_results = {}
    if not args.skip_eval:
        for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
            result = run_rollout(model, tokenizer, task_id=task_id, seed=9999, max_steps=args.eval_max_steps)
            base_results[task_id] = result
            logger.info(
                f"  [BASE] {task_id}: reward={result['total_reward']:.3f}, "
                f"bugs={result['bugs_found']}/{result['total_bugs']}, "
                f"coverage={result['coverage_pct']:.1f}%"
            )
            if args.use_wandb and wandb_run:
                import wandb
                wandb.log({
                    f"base_model/{task_id}/reward": result["total_reward"],
                    f"base_model/{task_id}/bugs": result["bugs_found"],
                    f"base_model/{task_id}/coverage": result["coverage_pct"],
                })
    else:
        logger.info("Skipping base model evaluation (--skip-eval)")
        # Zero-filled placeholders keep the downstream table/plot code simple.
        for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
            base_results[task_id] = {"total_reward": 0, "bugs_found": 0, "total_bugs": 0, "coverage_pct": 0}

    # --- Step 4: LoRA config ---
    _step(4, "Configuring LoRA adapters")
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
    )
    print(f" LoRA: r=16, alpha=32, targets=q_proj+v_proj", flush=True)

    # --- Step 5: Generate training prompts ---
    _step(5, f"Generating {args.num_episodes} training episodes")
    raw_prompts = build_training_prompts(num_episodes=args.num_episodes)
    print(f" {len(raw_prompts)} prompts across 3 tasks (each with unique seed)", flush=True)

    # Qwen3 thinking mode: let the model reason before outputting JSON
    # Requires higher max_completion_length (~2048) to fit <think>...</think> + JSON
    chat_template_kwargs = {}
    if "qwen3" in args.model_id.lower():
        chat_template_kwargs["enable_thinking"] = True
        logger.info("Qwen3 detected β thinking mode ENABLED (model will reason before acting)")

    # Flatten chat messages into plain text with the model's chat template.
    formatted_prompts = []
    for p in raw_prompts:
        text = tokenizer.apply_chat_template(
            p["prompt"], tokenize=False, add_generation_prompt=True,
            **chat_template_kwargs,
        )
        formatted_prompts.append({"prompt": text, "task_id": p["task_id"], "seed": p["seed"]})

    dataset = Dataset.from_list(formatted_prompts)

    # Store prompt metadata for the reward function to create fresh envs
    prompts_meta = [{"seed": p["seed"], "task_id": p["task_id"]} for p in raw_prompts]

    # Combined reward: format (valid JSON array?) + plan (execute all actions) + diversity (varied requests?)
    # Each generation gets a FRESH environment β no shared state pollution
    def combined_reward_fn(completions, **kwargs):
        # Sum the three reward components element-wise per completion.
        fmt = format_reward_fn(completions)
        plan = plan_reward_fn(completions, prompts_meta=prompts_meta)
        div = diversity_reward_fn(completions)
        return [f + p + d for f, p, d in zip(fmt, plan, div)]

    # --- Step 6: GRPO training ---
    _step(6, f"GRPO training ({args.max_steps} steps, {args.num_generations} generations/prompt)")
    config = GRPOConfig(
        output_dir=args.output_dir,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        num_train_epochs=1,
        max_steps=args.max_steps,
        logging_steps=5,
        save_steps=50,
        save_total_limit=3,
        report_to=report_to,
        temperature=0.8,
    )

    trainer = GRPOTrainer(
        model=model,
        args=config,
        reward_funcs=[combined_reward_fn],
        train_dataset=dataset,
        peft_config=lora_config,
        processing_class=tokenizer,
    )

    print(f" Config: lr={args.learning_rate}, batch={args.batch_size}, "
          f"max_completion={args.max_completion_length}, temp=0.8", flush=True)
    print(f" Rewards: format_reward + plan_reward + diversity_reward", flush=True)
    print(f" Training begins... (progress bar below)\n", flush=True)

    train_start = time.time()
    trainer.train()
    training_time = time.time() - train_start
    print(f"\n Training completed in {training_time / 60:.1f} minutes", flush=True)

    # --- Step 7: Save model locally ---
    _step(7, f"Saving model to {args.output_dir}")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f" Model + tokenizer saved.", flush=True)

    # --- Step 8: Push to HuggingFace Hub ---
    _step(8, "Pushing to HuggingFace Hub" if args.push_to_hub else "HF Hub push (skipped β use --push-to-hub)")
    if args.push_to_hub:
        hf_repo = args.hf_repo_id
        if not hf_repo:
            logger.error("--hf-repo-id is required when using --push-to-hub")
        else:
            try:
                logger.info(f"Pushing model to HuggingFace Hub: {hf_repo}")
                trainer.push_to_hub(repo_id=hf_repo, commit_message="GRPO trained API testing agent")
                tokenizer.push_to_hub(repo_id=hf_repo, commit_message="GRPO trained API testing agent")
                logger.info(f"Model pushed to https://huggingface.co/{hf_repo}")
            except Exception as e:
                # Push failure is non-fatal: the model is already saved locally.
                logger.error(f"Failed to push to HF Hub: {e}")
                logger.info("Make sure you're logged in: huggingface-cli login")

    # --- Step 9: Evaluate AFTER training ---
    _step(9, f"Evaluating TRAINED model (max {args.eval_max_steps} steps/task)")
    trained_results = {}
    if not args.skip_eval:
        for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
            result = run_rollout(model, tokenizer, task_id=task_id, seed=9999, max_steps=args.eval_max_steps)
            trained_results[task_id] = result
            base = base_results[task_id]
            # Deltas vs the pre-training evaluation of the same task/seed.
            reward_delta = result["total_reward"] - base.get("total_reward", 0)
            bug_delta = result["bugs_found"] - base.get("bugs_found", 0)
            cov_delta = result["coverage_pct"] - base.get("coverage_pct", 0)
            logger.info(
                f"  [TRAINED] {task_id}: reward={result['total_reward']:.3f} ({reward_delta:+.3f}), "
                f"bugs={result['bugs_found']}/{result['total_bugs']} ({bug_delta:+d}), "
                f"coverage={result['coverage_pct']:.1f}% ({cov_delta:+.1f}%)"
            )
            if args.use_wandb and wandb_run:
                import wandb
                wandb.log({
                    f"trained_model/{task_id}/reward": result["total_reward"],
                    f"trained_model/{task_id}/bugs": result["bugs_found"],
                    f"trained_model/{task_id}/coverage": result["coverage_pct"],
                    f"delta/{task_id}/reward": reward_delta,
                    f"delta/{task_id}/bugs": bug_delta,
                    f"delta/{task_id}/coverage": cov_delta,
                })
    else:
        logger.info("Skipping trained model evaluation (--skip-eval)")
        for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
            trained_results[task_id] = {"total_reward": 0, "bugs_found": 0, "total_bugs": 0, "coverage_pct": 0}

    # --- Step 10: Print final comparison table ---
    _step(10, "Results comparison table")
    print("=" * 95)
    print("FINAL COMPARISON: All Agents & Models")
    print("=" * 95)
    print(f"{'Agent/Model':<18} {'Task':<25} {'Reward':<10} {'Bugs':<12} {'Coverage':<10}")
    print("-" * 95)

    for agent_name in ["random", "sequential", "smart"]:
        if agent_name in baseline_results:
            for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
                r = baseline_results[agent_name].get(task_id, {})
                print(
                    f"{agent_name:<18} {task_id:<25} "
                    f"{r.get('total_reward', 0):<10.4f} "
                    f"{r.get('bugs_found', 0)}/{r.get('total_bugs', 0):<10} "
                    f"{r.get('coverage_pct', 0):<10.1f}%"
                )
    print("-" * 95)

    for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
        r = base_results[task_id]
        print(
            f"{'Base Model':<18} {task_id:<25} "
            f"{r['total_reward']:<10.4f} "
            f"{r['bugs_found']}/{r['total_bugs']:<10} "
            f"{r['coverage_pct']:<10.1f}%"
        )
    print("-" * 95)

    for task_id in ["basic_validation", "edge_cases", "security_workflows"]:
        r = trained_results[task_id]
        base = base_results[task_id]
        delta = r["total_reward"] - base["total_reward"]
        print(
            f"{'GRPO Trained':<18} {task_id:<25} "
            f"{r['total_reward']:<10.4f} "
            f"{r['bugs_found']}/{r['total_bugs']:<10} "
            f"{r['coverage_pct']:<10.1f}% ({delta:+.4f})"
        )
    print("=" * 95)

    # --- Step 11: Save metrics & plots ---
    _step(11, "Saving metrics, plots, and finalizing")
    save_metrics(
        output_dir=args.output_dir,
        baseline_results=baseline_results,
        base_model_results=base_results,
        trained_model_results=trained_results,
        training_args=training_args_dict,
        training_time_s=training_time,
    )
    save_plots(
        output_dir=args.output_dir,
        baseline_results=baseline_results,
        base_model_results=base_results,
        trained_model_results=trained_results,
    )

    # --- Finalize W&B ---
    if args.use_wandb and wandb_run:
        import wandb
        # Log plots as artifacts
        plots_dir = os.path.join(args.output_dir, "metrics", "plots")
        if os.path.exists(plots_dir):
            for fname in os.listdir(plots_dir):
                if fname.endswith(".png"):
                    wandb.log({f"plots/{fname.replace('.png', '')}": wandb.Image(os.path.join(plots_dir, fname))})
        wandb.finish()

    # ================================================================
    print(f"\n{'='*70}")
    print(f" PIPELINE COMPLETE")
    print(f" Training time: {training_time / 60:.1f} minutes")
    print(f" Model saved to: {args.output_dir}")
    print(f" Metrics: {args.output_dir}/metrics/")
    print(f" Plots: {args.output_dir}/metrics/plots/")
    if args.use_wandb:
        print(f" W&B: https://wandb.ai/{args.wandb_project}")
    if args.push_to_hub and args.hf_repo_id:
        print(f" HF Hub: https://huggingface.co/{args.hf_repo_id}")
    print(f"{'='*70}\n")
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def main():
    """CLI entry point: parse arguments and launch the GRPO pipeline."""
    cli = argparse.ArgumentParser(description="GRPO Training for API Testing Agent")

    # Model & training
    cli.add_argument("--model-id", default="Qwen/Qwen3-1.7B", help="Base model to fine-tune")
    cli.add_argument("--output-dir", default="./checkpoints/grpo_api_tester")
    cli.add_argument("--num-episodes", type=int, default=50, help="Number of training episodes")
    cli.add_argument("--num-generations", type=int, default=4, help="GRPO parallel rollouts per prompt")
    cli.add_argument("--max-completion-length", type=int, default=4096,
                     help="Max tokens per generation. 4096 needed for Qwen3 thinking + JSON plan")
    cli.add_argument("--max-steps", type=int, default=200, help="Max training steps")
    cli.add_argument("--learning-rate", type=float, default=2e-5)
    cli.add_argument("--batch-size", type=int, default=4)
    cli.add_argument("--test-mode", action="store_true", help="Quick test with tiny config")

    # HuggingFace Hub
    cli.add_argument("--push-to-hub", action="store_true", help="Push trained model to HF Hub")
    cli.add_argument("--hf-repo-id", type=str, default=None,
                     help="HF Hub repo ID (e.g., your-username/api-tester-grpo)")

    # Evaluation
    cli.add_argument("--skip-eval", action="store_true", help="Skip base/trained model evaluation")
    cli.add_argument("--eval-max-steps", type=int, default=10,
                     help="Max steps per task during evaluation (default: 10, reduces eval time)")

    # Weights & Biases
    cli.add_argument("--use-wandb", action="store_true", help="Enable Weights & Biases logging")
    cli.add_argument("--wandb-project", type=str, default="api-testing-grpo",
                     help="W&B project name")
    cli.add_argument("--wandb-run-name", type=str, default=None,
                     help="W&B run name (auto-generated if not set)")

    args = cli.parse_args()

    # Test mode shrinks every knob so the whole pipeline finishes quickly.
    if args.test_mode:
        logger.info("=== TEST MODE β quick sanity check ===")
        overrides = {
            "num_episodes": 3,
            "num_generations": 4,
            "batch_size": 2,
            "max_steps": 10,
            "max_completion_length": 2048,
        }
        for key, value in overrides.items():
            setattr(args, key, value)

    # Debug aid: dump the generated training prompts and exit without training.
    if os.environ.get("SHOW_PROMPTS"):
        for episode in build_training_prompts(num_episodes=3):
            print(f"\n{'='*60}")
            print(f"Task: {episode['task_id']} | Seed: {episode['seed']}")
            print(f"{'='*60}")
            for msg in episode["prompt"]:
                print(f"[{msg['role']}]: {msg['content'][:300]}...")
        return

    train_grpo(args)
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
# Script entry point: parse CLI arguments and run the GRPO pipeline.
if __name__ == "__main__":
    main()
|
training/prompts.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prompt formatting and action parsing for LLM-based API testing agents.
|
| 3 |
+
|
| 4 |
+
- SYSTEM_PROMPT: Instructions for the LLM on how to test APIs
|
| 5 |
+
- format_observation(): Converts environment observations into LLM prompts
|
| 6 |
+
- parse_action(): Extracts a single JSON action from LLM text
|
| 7 |
+
- parse_test_plan(): Extracts a JSON array of actions (for GRPO training)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import re
|
| 12 |
+
import sys
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 16 |
+
from models import APITestAction, HTTPMethod
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# =====================================================================
|
| 20 |
+
# System prompt for multi-turn evaluation (one action at a time)
|
| 21 |
+
# =====================================================================
|
| 22 |
+
|
| 23 |
+
# Multi-turn prompt: the model sees one observation per turn and must reply
# with exactly one JSON action.
# NOTE(review): the "β" characters inside these strings look like mojibake
# (probably "—" or "→" originally) β confirm the file encoding.
SYSTEM_PROMPT = """\
You are an expert API security tester. You are testing a REST API for bugs.

You will receive:
- The API specification (available endpoints)
- Results from your previous requests
- Coverage and bug discovery progress

Your job: find as many bugs as possible by sending HTTP requests.

Think step by step about what to test next, then output your action as JSON.

RESPOND WITH EXACTLY ONE JSON ACTION per turn:
```json
{
  "method": "GET|POST|PUT|DELETE",
  "endpoint": "/path",
  "headers": {},
  "query_params": {},
  "body": null,
  "expected_status": 200
}
```

TESTING STRATEGIES:
- Test each endpoint with valid inputs first
- Try invalid inputs (missing fields, wrong types, boundary values)
- Test with non-existent resource IDs
- Login as different users and test cross-user access
- Try SQL injection patterns in text fields
- Test with very long inputs
- Chain operations: create -> read -> update -> delete
"""


# =====================================================================
# System prompt for GRPO training (full test plan in one shot)
# =====================================================================

# Single-shot variant: the model must emit a whole JSON array of requests,
# which the reward functions then execute against a fresh environment.
PLAN_SYSTEM_PROMPT = """\
You are an expert API security tester. You will receive an API specification and must output a COMPLETE TEST PLAN as a JSON array of HTTP requests to execute in order.

Your goal: find as many bugs as possible through systematic testing.

OUTPUT FORMAT β a JSON array of actions to execute sequentially:
```json
[
  {"method": "GET", "endpoint": "/tasks", "headers": {}, "query_params": {}, "body": null, "expected_status": 200},
  {"method": "POST", "endpoint": "/auth/login", "headers": {}, "query_params": {}, "body": {"username": "alice", "password": "pass"}, "expected_status": 200},
  ...more actions...
]
```

OUTPUT EXACTLY ONE JSON ARRAY. No other text.

TESTING STRATEGY β follow this order:
1. DISCOVER: GET /tasks, GET /users to see what exists
2. AUTHENTICATE: Login as two different users (POST /auth/login)
3. CRUD: POST to create, GET to read, PUT to update, DELETE to remove
4. MISSING FIELDS: POST /tasks without required "title" field
5. NON-EXISTENT IDs: GET /tasks/999999 (expect 404 β if you get 200, that's a bug!)
6. BOUNDARY: GET /tasks?page=-1&limit=10 (negative page), GET /tasks?limit=999999 (huge limit)
7. INVALID DATA: PUT /tasks/1 with assignee_email="not-an-email"
8. SECURITY: Login as user B, then try to GET/PUT/DELETE user A's resources (BOLA test)
9. INJECTION: POST /tasks with title containing SQL injection like "'; DROP TABLE tasks;--"
10. EMPTY AUTH: POST /auth/login with empty password (should fail but might not)
11. DATA LEAKS: POST /users and check if response includes password_hash
12. STATE: DELETE a task, then GET it again (should be 404)
13. LONG INPUT: POST /tasks with a title of 6000+ characters

COMMON BUG PATTERNS TO TEST:
- API returns 200 with null body instead of 404 for missing resources
- API returns 500 instead of 400 for invalid input
- API accepts any password (even empty string) for login
- Users can access other users' resources (no authorization check)
- Response includes sensitive fields like password_hash
- No input length limits (very long strings crash the server)
- SQL/HTML injection payloads stored without sanitization
- DELETE returns 200 even for non-existent resources
- No pagination limit cap (limit=999999 accepted)

RULES:
- Output 15-25 actions
- Each action MUST have "method" and "endpoint"
- Vary your requests β never repeat the same action
- Use the usernames from the task description for login
"""
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def format_observation(obs) -> str:
    """Render an observation as a human-readable prompt for the LLM.

    Used in multi-turn evaluation (one action at a time).

    The first turn (steps_taken == 0) describes the task and the API
    surface; every later turn shows the last HTTP response plus running
    progress bookkeeping.
    """
    if obs.steps_taken == 0:
        # First turn: task briefing + endpoint catalogue.
        endpoint_lines = [
            f"  {ep['method']} {ep['path']} β {ep.get('summary', '')}"
            for ep in obs.available_endpoints
        ]
        sections = [
            f"TASK: {obs.task_description}",
            f"\nSTEPS REMAINING: {obs.max_steps}",
            "\nAVAILABLE ENDPOINTS:",
            *endpoint_lines,
            "\nBegin testing. Send your first request as JSON.",
        ]
        return "\n".join(sections)

    # Subsequent turns: render the last response body, truncated to 500 chars.
    raw = obs.response_body
    if isinstance(raw, (dict, list)):
        rendered = json.dumps(raw, indent=2)
        if len(rendered) > 500:
            rendered = rendered[:500] + "\n... (truncated)"
    else:
        rendered = str(raw)[:500]

    cov = obs.coverage_summary
    sections = [
        f"STEP {obs.steps_taken}/{obs.max_steps}",
        f"RESPONSE: HTTP {obs.status_code}",
        f"BODY:\n{rendered}",
        f"\nFEEDBACK: {obs.feedback}",
        f"\nPROGRESS: Bugs found: {obs.bugs_found_so_far} | "
        f"Coverage: {cov.get('coverage_pct', 0):.0f}% | "
        f"Endpoints tested: {cov.get('endpoints_tested', 0)}/{cov.get('total_endpoints', 0)}",
    ]
    # Only surface auth/resource state once something has been collected.
    if obs.auth_tokens:
        sections.append(f"AUTH TOKENS: {list(obs.auth_tokens.keys())}")
    if obs.known_resource_ids:
        sections.append(f"CREATED RESOURCES: {dict(obs.known_resource_ids)}")
    sections.append("\nSend your next request as JSON.")

    return "\n".join(sections)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def format_plan_prompt(obs) -> str:
    """Render the initial observation as a prompt asking for a full test plan.

    Used in GRPO training: the model emits a complete plan (a JSON array
    of actions) in a single completion.
    """
    lines = [
        f"TASK: {obs.task_description}",
        f"\nYou have {obs.max_steps} actions to find as many bugs as possible.",
        "\nAVAILABLE ENDPOINTS:",
    ]

    for ep in obs.available_endpoints:
        lines.append(f"  {ep['method']} {ep['path']} β {ep.get('summary', '')}")

        # Describe the request-body schema, if the endpoint spec carries one.
        schema = ep.get("request_body", {})
        if schema:
            props = schema.get("properties", {})
            required = schema.get("required", [])
            if props:
                described = [
                    f"{name}: {info.get('type', 'any')}"
                    + (" (required)" if name in required else "")
                    for name, info in props.items()
                ]
                lines.append(f"    Body: {', '.join(described)}")

        # Describe declared query/path parameters, if any.
        declared = ep.get("parameters", [])
        if declared:
            rendered = [f"{p['name']}: {p.get('type', 'any')}" for p in declared]
            lines.append(f"    Params: {', '.join(rendered)}")

    lines.append("\nOutput your complete test plan as a JSON array of actions.")
    return "\n".join(lines)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def parse_action(text: str) -> APITestAction | None:
    """Parse a single JSON action from LLM output.

    Used in multi-turn evaluation (one action per completion).

    Args:
        text: Raw model output, possibly containing a Qwen3 ``<think>``
            block, markdown fencing, and prose around the JSON object.

    Returns:
        The first parseable object containing a ``"method"`` key,
        converted via ``_dict_to_action``, or ``None`` if nothing parses.

    BUGFIX: the previous implementation used the flat regex
    ``\\{[^{}]*"method"[^{}]*\\}`` (and a lazy ``\\{.*?\\}`` fallback),
    neither of which can match an action whose "body" field is itself a
    nested JSON object β so every POST/PUT carrying a body failed to
    parse. We now scan with explicit brace balancing, the same technique
    parse_test_plan uses.
    """
    # Strip Qwen3 thinking blocks β only the text after </think> is the answer.
    if "</think>" in text:
        text = text.split("</think>", 1)[-1]

    # Brace-balanced scan: extract each top-level {...} span (nested braces
    # included) and try the first one that mentions "method".
    i = 0
    n = len(text)
    while i < n:
        if text[i] != '{':
            i += 1
            continue
        depth = 1
        start = i
        i += 1
        while i < n and depth > 0:
            if text[i] == '{':
                depth += 1
            elif text[i] == '}':
                depth -= 1
            i += 1
        candidate = text[start:i]
        if '"method"' not in candidate:
            continue
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            # Best-effort repair: drop trailing commas before ] or }.
            cleaned = re.sub(r',(\s*[\]}])', r'\1', candidate)
            try:
                data = json.loads(cleaned)
            except json.JSONDecodeError:
                continue
        if isinstance(data, dict):
            return _dict_to_action(data)

    return None
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def parse_test_plan(text: str) -> list[APITestAction]:
    """Parse a JSON array of actions from LLM output.

    Handles all of these formats:
    1. Raw JSON array: [{"method": ...}, ...]
    2. Wrapped object: {"actions": [...]} or {"plan": [...]} or {"test_plan": [...]}
    3. Markdown code block: ```json [...] ```
    4. Trailing commas, missing commas (best-effort repair)
    5. Brace-balanced extraction of individual action objects

    Strategies are tried strictly in order; the first one that yields
    parseable data wins. Returns an empty list when nothing parses β the
    caller (reward functions) treats that as a format failure.
    """
    if not text:
        return []

    # Strip Qwen3 thinking blocks β only the text after </think> is the answer.
    if "</think>" in text:
        text = text.split("</think>", 1)[-1]

    # Strip markdown code fences (opening ```json / ``` and closing ```).
    text = re.sub(r'```(?:json)?\s*', '', text)
    text = text.replace('```', '')

    data = None

    # Strategy 1: Try to parse the entire text as JSON.
    try:
        data = json.loads(text.strip())
    except json.JSONDecodeError:
        pass

    # Strategy 2: Find a top-level JSON ARRAY via brace matching.
    # Scans from the first '[' and tries to parse each span that closes
    # the bracket depth back to zero.
    if data is None:
        start = text.find('[')
        if start >= 0:
            depth = 0
            for i in range(start, len(text)):
                if text[i] == '[':
                    depth += 1
                elif text[i] == ']':
                    depth -= 1
                    if depth == 0:
                        candidate = text[start:i+1]
                        try:
                            data = json.loads(candidate)
                            break
                        except json.JSONDecodeError:
                            # Repair attempt: drop trailing commas before ] or }.
                            cleaned = re.sub(r',(\s*[\]}])', r'\1', candidate)
                            try:
                                data = json.loads(cleaned)
                                break
                            except json.JSONDecodeError:
                                pass

    # Strategy 2b: Find a top-level JSON OBJECT (might be {"actions": [...]}).
    # Only accepted if it looks like a known wrapper; a bare action object
    # is left for Strategy 3 so we don't swallow the rest of the text.
    if data is None:
        start = text.find('{')
        if start >= 0:
            depth = 0
            for i in range(start, len(text)):
                if text[i] == '{':
                    depth += 1
                elif text[i] == '}':
                    depth -= 1
                    if depth == 0:
                        candidate = text[start:i+1]
                        try:
                            parsed = json.loads(candidate)
                            # Only accept if it's a wrapper containing actions
                            if isinstance(parsed, dict) and any(
                                k in parsed for k in ("actions", "plan", "test_plan", "steps", "requests")
                            ):
                                data = parsed
                                break
                        except json.JSONDecodeError:
                            # Same trailing-comma repair as Strategy 2.
                            cleaned = re.sub(r',(\s*[\]}])', r'\1', candidate)
                            try:
                                parsed = json.loads(cleaned)
                                if isinstance(parsed, dict) and any(
                                    k in parsed for k in ("actions", "plan", "test_plan", "steps", "requests")
                                ):
                                    data = parsed
                                    break
                            except json.JSONDecodeError:
                                pass

    # Strategy 3: Extract individual {"method": ...} objects with brace
    # balancing β last resort when the surrounding array/object is broken.
    if data is None:
        objects = []
        i = 0
        while i < len(text):
            if text[i] == '{':
                depth = 1
                start = i
                i += 1
                while i < len(text) and depth > 0:
                    if text[i] == '{':
                        depth += 1
                    elif text[i] == '}':
                        depth -= 1
                    i += 1
                obj_str = text[start:i]
                if '"method"' in obj_str:
                    try:
                        obj = json.loads(obj_str)
                        objects.append(obj)
                    except json.JSONDecodeError:
                        cleaned = re.sub(r',(\s*[\]}])', r'\1', obj_str)
                        try:
                            obj = json.loads(cleaned)
                            objects.append(obj)
                        except json.JSONDecodeError:
                            pass
            else:
                i += 1
        if objects:
            data = objects

    if data is None:
        return []

    # Unwrap common container shapes: {"actions": [...]}, {"plan": [...]}, etc.
    if isinstance(data, dict):
        for key in ("actions", "plan", "test_plan", "steps", "requests"):
            if key in data and isinstance(data[key], list):
                data = data[key]
                break
        else:
            # Single action object (dict with no known wrapper key).
            data = [data]

    # Scalar or other unexpected shape β normalize to a one-element list.
    if not isinstance(data, list):
        data = [data]

    # Keep only dicts that look like actions; _dict_to_action normalizes each.
    actions = []
    for item in data:
        if isinstance(item, dict) and "method" in item:
            action = _dict_to_action(item)
            if action:
                actions.append(action)

    return actions
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def _dict_to_action(data: dict) -> APITestAction | None:
    """Convert a raw dict (parsed from LLM JSON) into an APITestAction.

    Every field is normalized so malformed model output degrades
    gracefully instead of raising:
    - unknown or missing HTTP method falls back to GET
    - endpoint is coerced to a string with a leading "/"
    - headers / query_params must be dicts, otherwise dropped
    - body must be a dict or None
    - expected_status must be int-coercible, otherwise None
    """
    # .strip() fixes methods the model pads with whitespace (" POST ");
    # the old code rejected those and silently fell back to GET.
    method = str(data.get("method", "GET")).strip().upper()
    if method not in ("GET", "POST", "PUT", "DELETE", "PATCH"):
        method = "GET"

    # BUGFIX: `or "/tasks"` also covers an explicit "endpoint": null (or ""),
    # which data.get("endpoint", "/tasks") would have turned into "/None" / "/".
    endpoint = data.get("endpoint") or "/tasks"
    if not isinstance(endpoint, str):
        endpoint = str(endpoint)
    if not endpoint.startswith("/"):
        endpoint = "/" + endpoint

    headers = data.get("headers") or {}
    if not isinstance(headers, dict):
        headers = {}

    query_params = data.get("query_params") or {}
    if not isinstance(query_params, dict):
        query_params = {}

    # Non-dict bodies (e.g. a bare string) are dropped rather than sent.
    body = data.get("body")
    if body is not None and not isinstance(body, dict):
        body = None

    expected = data.get("expected_status")
    if expected is not None:
        try:
            expected = int(expected)
        except (ValueError, TypeError):
            expected = None

    return APITestAction(
        method=HTTPMethod(method),
        endpoint=endpoint,
        headers=headers,
        query_params=query_params,
        body=body,
        expected_status=expected,
    )
|
training/rewards.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reward functions for GRPO training (v2 β plan-based).
|
| 3 |
+
|
| 4 |
+
The model outputs a FULL TEST PLAN (JSON array of actions).
|
| 5 |
+
Each reward function creates a FRESH environment, executes ALL actions,
|
| 6 |
+
and scores the result.
|
| 7 |
+
|
| 8 |
+
Three reward signals:
|
| 9 |
+
1. format_reward β Valid JSON array with 3+ diverse actions? (+2 / -2)
|
| 10 |
+
2. plan_reward β Execute plan, score on bugs + coverage + efficiency (0 to ~8)
|
| 11 |
+
3. diversity_reward β Variety of methods, endpoints, and request patterns (+0 to +2)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import re
|
| 15 |
+
import sys
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 19 |
+
|
| 20 |
+
from models import APITestAction, HTTPMethod
|
| 21 |
+
from server.environment import APITestEnvironment
|
| 22 |
+
from .prompts import parse_test_plan
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def format_reward_fn(completions: list[str], **kwargs) -> list[float]:
    """Reward for valid JSON test plan format.

    +2.0 if output has 5+ diverse actions (a real plan)
    +1.0 if output has 3-4 actions (minimal plan)
    +0.0 if output has 1-2 actions (barely valid)
    -2.0 if it can't be parsed at all
    -1.0 if it parses but every action is identical

    Args:
        completions: Raw model outputs, one per rollout.
        **kwargs: Ignored (TRL passes extra columns here).

    Returns:
        One float per completion.

    Note: the old code had an unreachable `else: -2.0` branch after the
    n-based ladder β once the empty-plan guard passes, n >= 1 is
    guaranteed β so that dead branch has been removed; behavior is
    unchanged.
    """
    rewards = []
    for text in completions:
        actions = parse_test_plan(text)
        if not actions:
            # Unparseable output is the worst case.
            rewards.append(-2.0)
            continue

        n = len(actions)

        # Diversity check: count distinct (method, normalized-endpoint)
        # pairs; numeric path segments are collapsed to "{id}" so
        # /tasks/1 and /tasks/2 count as the same request shape.
        unique_pairs = set()
        for a in actions:
            m = a.method.value if hasattr(a.method, "value") else str(a.method)
            ep = re.sub(r'/\d+', '/{id}', a.endpoint)
            unique_pairs.add((m, ep))

        diversity_ratio = len(unique_pairs) / n

        if n >= 5 and diversity_ratio >= 0.5:
            score = 2.0
        elif n >= 3:
            score = 1.0
        else:
            score = 0.0

        # Penalty override if all actions are the same request.
        if len(unique_pairs) <= 1 and n > 1:
            score = -1.0

        rewards.append(score)

    return rewards
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def plan_reward_fn(completions: list[str], **kwargs) -> list[float]:
    """Execute the full test plan in a FRESH environment and return a balanced score.

    Score components:
    - Bug discovery: min(bugs_found, 5) * 1.0 (capped at 5.0 to not dominate)
    - Coverage: (coverage_pct / 100) * 2.0 (up to 2.0)
    - Efficiency: +0.3 for each of the first 10 actions whose step reward
      exceeds 0.2 (a high-reward step is treated as a likely bug find)
    - Crash penalty: -0.1 per action that caused a 500 error or raised
    - Step reward sum: sum of per-step rewards * 0.2 (small gradient signal)

    Total range: roughly -2 to +8

    Each completion gets its OWN fresh environment β no state pollution.

    Args:
        completions: Raw model outputs (full test plans as JSON).
        **kwargs: Expected to carry "prompts_meta", a list of dicts with
            "seed" and "task_id" per prompt; indexed modulo its length
            because GRPO generates multiple completions per prompt.

    Returns:
        One float per completion; -1.0 when the plan is unparseable.
    """
    prompts_meta = kwargs.get("prompts_meta", [])
    rewards = []

    for i, text in enumerate(completions):
        actions = parse_test_plan(text)
        if not actions:
            rewards.append(-1.0)
            continue

        # Get episode seed and task (modulo: several completions share one prompt)
        meta = prompts_meta[i % len(prompts_meta)] if prompts_meta else {}
        seed = meta.get("seed", 42)
        task_id = meta.get("task_id", "basic_validation")

        # Create a FRESH environment per completion
        env = APITestEnvironment()
        env.reset(seed=seed, task_id=task_id)

        # Execute all actions, track results
        crashes = 0
        step_rewards = []
        for action in actions:
            try:
                obs = env.step(action)
                step_rewards.append(obs.reward or 0.0)
                # NOTE(review): assumes obs.status_code is always an int here β
                # confirm env.step never returns None for status_code.
                if obs.status_code >= 500:
                    crashes += 1
            except Exception:
                # An action that blows up the environment counts as a crash
                # but must not abort scoring of the rest of the plan.
                step_rewards.append(0.0)
                crashes += 1

        state = env.state
        coverage = state.coverage_pct

        # Component 1: Bug discovery (capped to prevent domination)
        bug_score = min(state.bugs_found, 5) * 1.0

        # Component 2: Coverage (proportional, up to 2.0)
        coverage_score = (coverage / 100) * 2.0

        # Component 3: Efficiency β finding bugs early is better
        early_bug_bonus = 0.0
        early_steps = step_rewards[:10]
        for r in early_steps:
            if r > 0.2:  # High reward step = likely found a bug
                early_bug_bonus += 0.3

        # Component 4: Crash penalty
        crash_penalty = crashes * -0.1

        # Component 5: Step reward sum (small weight β mainly for gradient signal)
        step_sum = sum(step_rewards) * 0.2

        total = bug_score + coverage_score + early_bug_bonus + crash_penalty + step_sum
        rewards.append(round(total, 4))

    return rewards
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def diversity_reward_fn(completions: list[str], **kwargs) -> list[float]:
    """Reward for diverse test plans β varied methods, endpoints, and strategies.

    Components:
    - Method variety: up to +0.5 (using GET/POST/PUT/DELETE)
    - Endpoint variety: up to +0.5 (testing different endpoints)
    - Strategy variety: up to +0.5 (auth + invalid input + boundary + injection patterns)
    - Repetition penalty: up to -0.5

    Args:
        completions: Raw model outputs (full test plans as JSON).
        **kwargs: Ignored.

    Returns:
        One float per completion in roughly [-0.5, +1.5]; 0.0 for
        unparseable plans (format_reward_fn already punishes those).
    """
    rewards = []
    for text in completions:
        actions = parse_test_plan(text)
        if not actions:
            rewards.append(0.0)
            continue

        methods = set()
        endpoints = set()
        unique_pairs = set()
        has_auth = False
        has_invalid_input = False
        has_boundary = False
        has_injection = False
        has_nonexistent_id = False

        for a in actions:
            m = a.method.value if hasattr(a.method, "value") else str(a.method)
            methods.add(m)
            # Collapse numeric path segments so /tasks/1 and /tasks/2
            # count as the same endpoint shape.
            norm_ep = re.sub(r'/\d+', '/{id}', a.endpoint)
            endpoints.add(norm_ep)
            unique_pairs.add((m, norm_ep))

            # Detect testing strategies
            if a.endpoint == "/auth/login":
                has_auth = True
            # BUGFIX: compare against the normalized `m` computed above.
            # The old code read a.method.value directly here, bypassing the
            # hasattr guard and raising AttributeError for string methods.
            if a.body and not a.body.get("title") and m == "POST":
                has_invalid_input = True
            qp = a.query_params or {}
            if any(isinstance(v, (int, float)) and v < 0 for v in qp.values()):
                has_boundary = True
            if any(isinstance(v, (int, float)) and v > 10000 for v in qp.values()):
                has_boundary = True
            if a.body and any("DROP" in str(v).upper() or "script" in str(v).lower()
                              for v in (a.body or {}).values()):
                has_injection = True
            # A 4+ digit id in the path is treated as a non-existent-ID probe.
            if re.search(r'/\d{4,}', a.endpoint):
                has_nonexistent_id = True

        # Method variety (max 4 methods = +0.5)
        method_score = min(len(methods) / 4, 1.0) * 0.5

        # Endpoint variety (max 7 endpoints = +0.5)
        endpoint_score = min(len(endpoints) / 7, 1.0) * 0.5

        # Strategy variety (each strategy = +0.1, max +0.5)
        strategies = sum([has_auth, has_invalid_input, has_boundary, has_injection, has_nonexistent_id])
        strategy_score = min(strategies * 0.1, 0.5)

        # Repetition penalty: fraction of actions that repeat an earlier
        # (method, endpoint) pair, scaled to at most -0.5.
        if len(actions) > 0:
            repeat_count = len(actions) - len(unique_pairs)
            repetition_penalty = min(repeat_count / len(actions), 1.0) * -0.5
        else:
            repetition_penalty = 0.0

        total = method_score + endpoint_score + strategy_score + repetition_penalty
        rewards.append(round(total, 3))

    return rewards
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|