Spaces:
Sleeping
Sleeping
codemaverick2 commited on
Commit ·
e48a1e4
1
Parent(s): 8b75c34
Add 7-task RL env with PBRS, CAMRL curriculum, VL norm, RC-GRPO inference
Browse files- README.md +109 -10
- inference.py +293 -51
- models.py +20 -0
- openenv.yaml +50 -3
- server/app.py +191 -14
- server/environment.py +414 -66
- server/graders.py +446 -6
- tasks/data.py +523 -0
- tests/test_environment.py +526 -0
- tests/test_graders.py +403 -1
README.md
CHANGED
|
@@ -117,6 +117,71 @@ Comprehensive review of a Django e-commerce API across two files (`views.py`, `m
|
|
| 117 |
|
| 118 |
**Max steps:** 30
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
## Scoring
|
| 121 |
|
| 122 |
```
|
|
@@ -129,8 +194,36 @@ where:
|
|
| 129 |
severity_accuracy = avg(1 − |flag_sev_rank − gt_sev_rank| × 0.34) for matched issues
|
| 130 |
|
| 131 |
Matching tolerance: ±2 lines, same filename, compatible issue type
|
|
|
|
| 132 |
```
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
## API Endpoints
|
| 135 |
|
| 136 |
| Method | Endpoint | Description |
|
|
@@ -234,14 +327,18 @@ pytest tests/ -v
|
|
| 234 |
|
| 235 |
## Baseline Scores
|
| 236 |
|
| 237 |
-
| Task | Keyword heuristic |
|
| 238 |
-
|------|-------------------|
|
| 239 |
-
| bug-detection | 1.00 |
|
| 240 |
-
| security-audit | 0.75 |
|
| 241 |
-
|
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
-
Keyword heuristic runs via `inference.py` with no API key. LLM scores use `API_BASE_URL` + `HF_TOKEN`.
|
| 245 |
|
| 246 |
## Project Structure
|
| 247 |
|
|
@@ -258,11 +355,13 @@ code-review-env/
|
|
| 258 |
├── client.py ← HTTP client
|
| 259 |
├── models.py ← ReviewAction, ReviewObservation, ReviewState, Issue
|
| 260 |
├── tasks/
|
| 261 |
-
│ └── data.py ←
|
|
|
|
|
|
|
| 262 |
├── server/
|
| 263 |
│ ├── app.py ← FastAPI application
|
| 264 |
-
│ ├── environment.py ← Core environment logic
|
| 265 |
-
│ └── graders.py ← F1 grading + keyword baseline
|
| 266 |
└── tests/
|
| 267 |
├── test_environment.py
|
| 268 |
└── test_graders.py
|
|
|
|
| 117 |
|
| 118 |
**Max steps:** 30
|
| 119 |
|
| 120 |
+
### Task 4: `async-review` — Medium-Hard
|
| 121 |
+
|
| 122 |
+
Review an async Python module (`async.py`) for concurrency bugs, resource leaks, and performance issues with `asyncio` and `aiohttp`.
|
| 123 |
+
|
| 124 |
+
| Line | Issue | Severity |
|
| 125 |
+
|------|-------|----------|
|
| 126 |
+
| 5 | Shared mutable cache dict without `asyncio.Lock` — race condition | High |
|
| 127 |
+
| 9 | `timeout=5` wrong type for aiohttp; requires `ClientTimeout(total=5)` | Medium |
|
| 128 |
+
| 22 | `ClientSession` created but never closed — resource leak | High |
|
| 129 |
+
| 24 | Sequential `await` in loop — use `asyncio.gather()` for concurrency | High |
|
| 130 |
+
| 37 | Off-by-one in retry condition: `attempt == retries` never true | High |
|
| 131 |
+
| 48 | Tasks awaited sequentially; `self.results` accumulates across calls | Medium |
|
| 132 |
+
|
| 133 |
+
**Max steps:** 20
|
| 134 |
+
|
| 135 |
+
### Task 5: `data-pipeline` — Hard
|
| 136 |
+
|
| 137 |
+
Security and correctness audit of a SQLite data pipeline module (`pipeline.py`).
|
| 138 |
+
|
| 139 |
+
| Line | Issue | Severity |
|
| 140 |
+
|------|-------|----------|
|
| 141 |
+
| 20 | MD5 for password hashing — cryptographically broken | High |
|
| 142 |
+
| 27 | SQL injection via f-string in `INSERT` query | Critical |
|
| 143 |
+
| 35 | SQL injection via f-string in `LIKE` query | Critical |
|
| 144 |
+
| 41 | One transaction per row in `bulk_load` — severe performance issue | High |
|
| 145 |
+
| 46 | `float()` conversion without error handling — crashes on bad input | Medium |
|
| 146 |
+
| 52 | `export_records` leaks `password_hash` field in JSON output | High |
|
| 147 |
+
| 59 | SQL injection: `limit` interpolated into `LIMIT` clause | Critical |
|
| 148 |
+
|
| 149 |
+
**Max steps:** 25
|
| 150 |
+
|
| 151 |
+
### Task 6: `api-security` — Hard
|
| 152 |
+
|
| 153 |
+
Security audit of a FastAPI REST API (`api.py`) with authentication, authorization, and injection vulnerabilities.
|
| 154 |
+
|
| 155 |
+
| Line | Issue | Severity |
|
| 156 |
+
|------|-------|----------|
|
| 157 |
+
| 12 | Hardcoded `SECRET_KEY` in source | High |
|
| 158 |
+
| 13 | Hardcoded `ADMIN_TOKEN` in source | High |
|
| 159 |
+
| 16 | MD5 for password hashing | High |
|
| 160 |
+
| 27 | JWT issued without `exp` expiry claim | Medium |
|
| 161 |
+
| 33 | IDOR — any user can fetch any other user's data | Critical |
|
| 162 |
+
| 38 | SQL injection via f-string in `SELECT` query | Critical |
|
| 163 |
+
| 47 | Command injection via `os.system()` with env-interpolated path | Critical |
|
| 164 |
+
| 53 | `pickle.loads()` on untrusted user bytes — RCE | Critical |
|
| 165 |
+
|
| 166 |
+
**Max steps:** 25
|
| 167 |
+
|
| 168 |
+
### Task 7: `js-security` — Hard
|
| 169 |
+
|
| 170 |
+
Security audit of an Express.js REST API (`server.js`) in JavaScript/Node.js.
|
| 171 |
+
|
| 172 |
+
| Line | Issue | Severity |
|
| 173 |
+
|------|-------|----------|
|
| 174 |
+
| 11 | Hardcoded `JWT_SECRET` in source | High |
|
| 175 |
+
| 16 | SQL injection via template literal in `prepare()` | Critical |
|
| 176 |
+
| 18 | JWT issued without `expiresIn` — tokens valid forever | Medium |
|
| 177 |
+
| 25 | IDOR + SQL injection: unauthenticated user access + unparameterized query | Critical |
|
| 178 |
+
| 31 | XSS: user query param reflected directly in HTML response | High |
|
| 179 |
+
| 36 | Command injection via `execSync()` with user-supplied filename | Critical |
|
| 180 |
+
| 42 | Path traversal: `path.join` with user-supplied filename | High |
|
| 181 |
+
| 48 | `new Function()` with user template — arbitrary code execution | Critical |
|
| 182 |
+
|
| 183 |
+
**Max steps:** 25
|
| 184 |
+
|
| 185 |
## Scoring
|
| 186 |
|
| 187 |
```
|
|
|
|
| 194 |
severity_accuracy = avg(1 − |flag_sev_rank − gt_sev_rank| × 0.34) for matched issues
|
| 195 |
|
| 196 |
Matching tolerance: ±2 lines, same filename, compatible issue type
|
| 197 |
+
Near-miss (±3-5 lines): graduated partial credit via exponential decay
|
| 198 |
```
|
| 199 |
|
| 200 |
+
## Reward Design
|
| 201 |
+
|
| 202 |
+
### Per-step rewards
|
| 203 |
+
|
| 204 |
+
| Event | Reward |
|
| 205 |
+
|-------|--------|
|
| 206 |
+
| True positive (TP) | +0.10 base |
|
| 207 |
+
| TP + severity exact match | +0.02 bonus |
|
| 208 |
+
| TP + early (first 40% of steps) | +0.02 bonus |
|
| 209 |
+
| TP + high confidence (≥0.7) | +0.01 bonus |
|
| 210 |
+
| PBRS potential shaping (Φ(s')−Φ(s)) | +0.03–0.08 |
|
| 211 |
+
| Near-miss (±3-5 lines, exponential decay) | +0.020–0.055 |
|
| 212 |
+
| False positive | −0.05 |
|
| 213 |
+
| False positive flood (4th+ FP) | escalating −0.03 extra |
|
| 214 |
+
| High-confidence FP | −0.03 extra |
|
| 215 |
+
| Clear TP | −0.03 |
|
| 216 |
+
| Clear FP | +0.03 |
|
| 217 |
+
| Hint | −0.01 |
|
| 218 |
+
| Submit / auto-end | Final F1 score |
|
| 219 |
+
|
| 220 |
+
### Reward shaping foundations
|
| 221 |
+
|
| 222 |
+
- **Potential-Based Reward Shaping** (Ng et al. 1999): Φ(s) = (tp/total_gt) × 0.5. Policy-invariant shaping that improves sample efficiency without changing the optimal policy.
|
| 223 |
+
- **Graduated near-miss** (exponential decay): reward = 0.10 × e^(−0.6 × (line_diff − 2)) for lines 3-5 off. Gives smooth gradient signal for line-number refinement.
|
| 224 |
+
- **Variable-Length Return Normalization** (VL Norm 2025): normalized_return = cumulative_reward / steps_used. Makes return comparable across tasks of different lengths.
|
| 225 |
+
- **Flood protection**: escalating FP penalty prevents reward hacking via flag-spamming.
|
| 226 |
+
|
| 227 |
## API Endpoints
|
| 228 |
|
| 229 |
| Method | Endpoint | Description |
|
|
|
|
| 327 |
|
| 328 |
## Baseline Scores
|
| 329 |
|
| 330 |
+
| Task | Keyword heuristic |
|
| 331 |
+
|------|-------------------|
|
| 332 |
+
| bug-detection | 1.00 |
|
| 333 |
+
| security-audit | 0.75 |
|
| 334 |
+
| async-review | 0.71 |
|
| 335 |
+
| comprehensive-review | 0.66 |
|
| 336 |
+
| api-security | 0.83 |
|
| 337 |
+
| js-security | 0.70 |
|
| 338 |
+
| data-pipeline | 0.55 |
|
| 339 |
+
| **Overall (7 tasks)** | **0.74** |
|
| 340 |
|
| 341 |
+
Keyword heuristic runs via `inference.py` with no API key (uses `/baseline` endpoint). LLM scores use `API_BASE_URL` + `HF_TOKEN`.
|
| 342 |
|
| 343 |
## Project Structure
|
| 344 |
|
|
|
|
| 355 |
├── client.py ← HTTP client
|
| 356 |
├── models.py ← ReviewAction, ReviewObservation, ReviewState, Issue
|
| 357 |
├── tasks/
|
| 358 |
+
│ └── data.py ← 5 task definitions + ground truth
|
| 359 |
+
│ (bug-detection, security-audit, comprehensive-review,
|
| 360 |
+
│ async-review, data-pipeline)
|
| 361 |
├── server/
|
| 362 |
│ ├── app.py ← FastAPI application
|
| 363 |
+
│ ├── environment.py ← Core environment logic (adaptive hints, rich rewards)
|
| 364 |
+
│ └── graders.py ← F1 grading + detailed grading + keyword baseline
|
| 365 |
└── tests/
|
| 366 |
├── test_environment.py
|
| 367 |
└── test_graders.py
|
inference.py
CHANGED
|
@@ -4,7 +4,7 @@ Inference script for the Code Review Environment.
|
|
| 4 |
Environment variables:
|
| 5 |
API_BASE_URL — LLM API endpoint (e.g. https://openrouter.ai/api/v1)
|
| 6 |
MODEL_NAME — Model identifier (e.g. openai/gpt-4o-mini)
|
| 7 |
-
HF_TOKEN — API key for the LLM provider
|
| 8 |
ENV_URL — Environment base URL (default: localhost:7860)
|
| 9 |
|
| 10 |
Usage:
|
|
@@ -19,6 +19,7 @@ import os
|
|
| 19 |
import sys
|
| 20 |
import json
|
| 21 |
import time
|
|
|
|
| 22 |
|
| 23 |
import httpx
|
| 24 |
|
|
@@ -27,24 +28,76 @@ MODEL_NAME: str = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
|
| 27 |
HF_TOKEN: str = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY", "")
|
| 28 |
ENV_URL: str = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
|
| 29 |
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
SYSTEM_PROMPT = """\
|
| 33 |
-
You are an expert software engineer performing a thorough code review.
|
| 34 |
-
|
| 35 |
-
Your
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
- One action per response
|
| 46 |
-
-
|
| 47 |
-
- Only flag
|
|
|
|
|
|
|
| 48 |
"""
|
| 49 |
|
| 50 |
|
|
@@ -59,13 +112,17 @@ def chat_completion(messages: list) -> str:
|
|
| 59 |
kwargs["base_url"] = API_BASE_URL
|
| 60 |
|
| 61 |
client = OpenAI(**kwargs)
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
def parse_action(text: str) -> dict:
|
|
@@ -100,45 +157,217 @@ def parse_action(text: str) -> dict:
|
|
| 100 |
|
| 101 |
def run_keyword_fallback(base_url: str, task_id: str) -> dict:
|
| 102 |
"""Fallback: use the built-in /baseline endpoint (no LLM needed)."""
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
def run_task(task_id: str, http_client: httpx.Client) -> dict:
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
code_display = "\n\n".join(
|
| 117 |
-
f"=== {fname} ===\n{code}"
|
| 118 |
for fname, code in obs.get("code_files", {}).items()
|
| 119 |
)
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
messages = [
|
| 122 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 123 |
{
|
| 124 |
"role": "user",
|
| 125 |
"content": (
|
| 126 |
-
f"
|
| 127 |
-
f"{
|
| 128 |
-
f"
|
| 129 |
-
f"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
),
|
| 131 |
},
|
| 132 |
]
|
| 133 |
|
| 134 |
done = False
|
| 135 |
step_count = 0
|
| 136 |
-
max_steps = obs.get("max_steps", 20)
|
| 137 |
final_score = 0.0
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
while not done and step_count < max_steps:
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
try:
|
| 144 |
step_resp = http_client.post(f"{ENV_URL}/step", json=action, timeout=30)
|
|
@@ -150,20 +379,33 @@ def run_task(task_id: str, http_client: httpx.Client) -> dict:
|
|
| 150 |
|
| 151 |
done = obs.get("done", False)
|
| 152 |
step_count += 1
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
messages.append({"role": "assistant", "content": action_text})
|
| 157 |
-
|
| 158 |
-
"role": "user",
|
| 159 |
-
"content": (
|
| 160 |
-
f"Feedback: {obs.get('feedback', '')} "
|
| 161 |
-
f"(step {step_count}/{max_steps}, score: {obs.get('current_score', 0.0):.3f})"
|
| 162 |
-
),
|
| 163 |
-
})
|
| 164 |
|
| 165 |
atype = action.get("action_type", "")
|
| 166 |
-
print(f" Step {step_count:2d}: {atype:20s} | reward={str(
|
| 167 |
|
| 168 |
if atype == "submit_review":
|
| 169 |
final_score = obs.get("reward", obs.get("current_score", 0.0)) or 0.0
|
|
@@ -205,7 +447,7 @@ def main():
|
|
| 205 |
print(f"Running task: {task_id}")
|
| 206 |
result = run_task(task_id, client)
|
| 207 |
results[task_id] = result
|
| 208 |
-
print(f" → score: {result['score']:.4f} ({result['steps']} steps)\n")
|
| 209 |
else:
|
| 210 |
print("HF_TOKEN / API_BASE_URL not set — using built-in keyword heuristic baseline.\n")
|
| 211 |
for task_id in TASK_IDS:
|
|
|
|
| 4 |
Environment variables:
|
| 5 |
API_BASE_URL — LLM API endpoint (e.g. https://openrouter.ai/api/v1)
|
| 6 |
MODEL_NAME — Model identifier (e.g. openai/gpt-4o-mini)
|
| 7 |
+
HF_TOKEN — API key for the LLM provider (also accepts OPENAI_API_KEY)
|
| 8 |
ENV_URL — Environment base URL (default: localhost:7860)
|
| 9 |
|
| 10 |
Usage:
|
|
|
|
| 19 |
import sys
|
| 20 |
import json
|
| 21 |
import time
|
| 22 |
+
from typing import Optional
|
| 23 |
|
| 24 |
import httpx
|
| 25 |
|
|
|
|
| 28 |
HF_TOKEN: str = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY", "")
|
| 29 |
ENV_URL: str = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
|
| 30 |
|
| 31 |
+
# Curriculum ordering: easy → medium → medium-hard → hard
|
| 32 |
+
# Research (CAMRL, Curriculum RL): start with simpler tasks to build
|
| 33 |
+
# foundational skills, progress to harder multi-file and multi-language tasks.
|
| 34 |
+
TASK_IDS = [
|
| 35 |
+
"bug-detection", # easy: pure logic bugs, single file
|
| 36 |
+
"security-audit", # medium: OWASP Top-10, single file
|
| 37 |
+
"async-review", # medium-hard: async concurrency, subtle bugs
|
| 38 |
+
"data-pipeline", # hard: SQL injection + crypto + performance
|
| 39 |
+
"comprehensive-review", # hard: multi-file Django, mixed issue types
|
| 40 |
+
"api-security", # hard: FastAPI auth/authz/injection
|
| 41 |
+
"js-security", # hard: JavaScript (cross-language generalization)
|
| 42 |
+
]
|
| 43 |
|
| 44 |
SYSTEM_PROMPT = """\
|
| 45 |
+
You are an expert software engineer performing a thorough, methodical code review.
|
| 46 |
+
|
| 47 |
+
Your mission: identify ALL real bugs, security vulnerabilities, and performance issues.
|
| 48 |
+
|
| 49 |
+
## REVIEW CHECKLIST — work through EVERY category for EVERY function:
|
| 50 |
+
|
| 51 |
+
### Security (check EVERY function for these)
|
| 52 |
+
- Hardcoded secrets / API keys / passwords / tokens
|
| 53 |
+
- SQL injection: f-strings/template literals/string concat in queries
|
| 54 |
+
- Command injection: shell=True, os.system(), execSync() with user input
|
| 55 |
+
- XSS: unsanitized user input in HTML templates / res.send()
|
| 56 |
+
- Path traversal: path.join/os.path.join with user-supplied paths
|
| 57 |
+
- IDOR: missing authorization — authenticated vs authorized
|
| 58 |
+
- Insecure deserialization: pickle.loads(), new Function(), eval() on user input
|
| 59 |
+
- Broken crypto: MD5/SHA1 for passwords; missing salt; weak PRNG
|
| 60 |
+
- JWT issues: missing expiry ('exp'), algorithm confusion, hardcoded secret
|
| 61 |
+
- Missing authentication on sensitive endpoints
|
| 62 |
+
|
| 63 |
+
### Bugs & Logic Errors (check EVERY function for these)
|
| 64 |
+
- Off-by-one errors in ranges, slices, loop bounds, retry conditions
|
| 65 |
+
- Wrong initial values (counters starting at 0 instead of 1)
|
| 66 |
+
- Race conditions (shared mutable state without locks/atomicity)
|
| 67 |
+
- Missing transaction atomicity (partial writes to DB)
|
| 68 |
+
- Wrong type arguments (int where object required, e.g. aiohttp timeout)
|
| 69 |
+
- State that accumulates across calls (class fields not reset)
|
| 70 |
+
|
| 71 |
+
### Performance (check EVERY loop and DB call)
|
| 72 |
+
- N+1 database queries (DB call inside a loop)
|
| 73 |
+
- Sequential async where gather() should be used
|
| 74 |
+
- One transaction per row in bulk operations
|
| 75 |
+
- Uncapped pagination (no max limit on per_page)
|
| 76 |
+
|
| 77 |
+
### Resource Management
|
| 78 |
+
- Unclosed sessions/connections/file handles
|
| 79 |
+
- Missing context managers (async with, with)
|
| 80 |
+
|
| 81 |
+
## RESPONSE FORMAT
|
| 82 |
+
|
| 83 |
+
For each issue you find, respond with ONE raw JSON object:
|
| 84 |
+
{"action_type": "flag_issue", "line_number": <int>, "filename": "<file>",
|
| 85 |
+
"issue_type": "bug|security|performance|logic",
|
| 86 |
+
"severity": "low|medium|high|critical",
|
| 87 |
+
"description": "<specific explanation>",
|
| 88 |
+
"fix_suggestion": "<concrete fix>",
|
| 89 |
+
"confidence": <0.0-1.0>}
|
| 90 |
+
|
| 91 |
+
When finished, respond with:
|
| 92 |
+
{"action_type": "submit_review"}
|
| 93 |
+
|
| 94 |
+
## RULES
|
| 95 |
+
- Raw JSON only — no markdown fences, no extra text
|
| 96 |
- One action per response
|
| 97 |
+
- Count lines carefully from line 1 (including blank lines and comments)
|
| 98 |
+
- Only flag REAL issues — no style preferences, no hypothetical issues
|
| 99 |
+
- Be precise: "SQL injection at line 19 via f-string in SELECT query" not just "SQL injection"
|
| 100 |
+
- Flag the EXACT line where the problem code is (the f-string line, not the function def)
|
| 101 |
"""
|
| 102 |
|
| 103 |
|
|
|
|
| 112 |
kwargs["base_url"] = API_BASE_URL
|
| 113 |
|
| 114 |
client = OpenAI(**kwargs)
|
| 115 |
+
try:
|
| 116 |
+
response = client.chat.completions.create(
|
| 117 |
+
model=MODEL_NAME,
|
| 118 |
+
messages=messages,
|
| 119 |
+
temperature=0.0,
|
| 120 |
+
max_tokens=500,
|
| 121 |
+
)
|
| 122 |
+
return response.choices[0].message.content.strip()
|
| 123 |
+
except Exception as e:
|
| 124 |
+
print(f" LLM call failed: {e}")
|
| 125 |
+
raise
|
| 126 |
|
| 127 |
|
| 128 |
def parse_action(text: str) -> dict:
|
|
|
|
| 157 |
|
| 158 |
def run_keyword_fallback(base_url: str, task_id: str) -> dict:
|
| 159 |
"""Fallback: use the built-in /baseline endpoint (no LLM needed)."""
|
| 160 |
+
try:
|
| 161 |
+
with httpx.Client(timeout=30) as client:
|
| 162 |
+
resp = client.post(f"{base_url}/baseline")
|
| 163 |
+
resp.raise_for_status()
|
| 164 |
+
results = resp.json()
|
| 165 |
+
score = results["baseline_scores"].get(task_id, {}).get("score", 0.0)
|
| 166 |
+
return {"task_id": task_id, "score": score, "steps": 0, "method": "keyword_heuristic"}
|
| 167 |
+
except Exception as e:
|
| 168 |
+
print(f" Keyword fallback failed: {e}")
|
| 169 |
+
return {"task_id": task_id, "score": 0.0, "steps": 0, "method": "error"}
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _build_progress_feedback(obs: dict) -> str:
|
| 173 |
+
"""Build a rich feedback string from observation progress data."""
|
| 174 |
+
progress = obs.get("progress") or {}
|
| 175 |
+
flagged_summary = obs.get("flagged_summary") or {}
|
| 176 |
+
|
| 177 |
+
parts = []
|
| 178 |
+
if progress:
|
| 179 |
+
f1 = progress.get("f1", 0)
|
| 180 |
+
precision = progress.get("precision", 0)
|
| 181 |
+
recall = progress.get("recall", 0)
|
| 182 |
+
tp = int(progress.get("true_positives", 0))
|
| 183 |
+
total_gt = int(progress.get("total_ground_truth", 0))
|
| 184 |
+
steps_rem = int(progress.get("steps_remaining", 0))
|
| 185 |
+
unfound = progress.get("unfound_issue_types", [])
|
| 186 |
+
|
| 187 |
+
parts.append(
|
| 188 |
+
f"Score progress: {tp}/{total_gt} issues confirmed | "
|
| 189 |
+
f"F1={f1:.2f} Precision={precision:.2f} Recall={recall:.2f} | "
|
| 190 |
+
f"{steps_rem} steps remaining"
|
| 191 |
+
)
|
| 192 |
+
if unfound:
|
| 193 |
+
parts.append(
|
| 194 |
+
f"IMPORTANT — still need to find: {unfound}. "
|
| 195 |
+
f"Search specifically for those issue types."
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
if flagged_summary:
|
| 199 |
+
incorrect = flagged_summary.get("incorrect", 0)
|
| 200 |
+
near = flagged_summary.get("near_misses", 0)
|
| 201 |
+
if incorrect > 0:
|
| 202 |
+
parts.append(
|
| 203 |
+
f"WARNING: {incorrect} false positive(s) hurting precision. "
|
| 204 |
+
f"Consider using clear_flag to remove uncertain flags."
|
| 205 |
+
)
|
| 206 |
+
if near > 0:
|
| 207 |
+
parts.append(
|
| 208 |
+
f"NOTE: {near} near-miss(es) — you're close on line numbers, "
|
| 209 |
+
f"but slightly off. Re-check exact line and try reflagging."
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
return "\n".join(parts) if parts else ""
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _should_submit(obs: dict, step_count: int, max_steps: int) -> bool:
|
| 216 |
+
"""
|
| 217 |
+
Smart submission: submit when recall is high or steps are nearly exhausted.
|
| 218 |
+
Avoids wasting steps after all real issues are found.
|
| 219 |
+
"""
|
| 220 |
+
progress = obs.get("progress", {})
|
| 221 |
+
recall = progress.get("recall", 0.0)
|
| 222 |
+
tp = int(progress.get("true_positives", 0))
|
| 223 |
+
total_gt = int(progress.get("total_ground_truth", 0))
|
| 224 |
+
steps_rem = int(progress.get("steps_remaining", 0))
|
| 225 |
+
unfound = progress.get("unfound_issue_types", [])
|
| 226 |
+
fp = int(progress.get("false_positives", 0))
|
| 227 |
+
|
| 228 |
+
# All issues found
|
| 229 |
+
if total_gt > 0 and tp >= total_gt:
|
| 230 |
+
return True
|
| 231 |
+
|
| 232 |
+
# No unfound categories and high recall
|
| 233 |
+
if not unfound and recall >= 0.85:
|
| 234 |
+
return True
|
| 235 |
+
|
| 236 |
+
# High recall overall (≥80%) and precision is decent (not too many FPs)
|
| 237 |
+
if recall >= 0.80 and (fp <= 2 or tp / max(tp + fp, 1) >= 0.6):
|
| 238 |
+
return True
|
| 239 |
+
|
| 240 |
+
# Very few steps left and we've done a reasonable scan
|
| 241 |
+
if steps_rem <= 2 and step_count >= 5:
|
| 242 |
+
return True
|
| 243 |
+
|
| 244 |
+
return False
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def _should_clear_flag(obs: dict, last_reward: float, last_action: dict) -> Optional[dict]:
|
| 248 |
+
"""
|
| 249 |
+
Recovery strategy: if the last flag was a false positive with high penalty,
|
| 250 |
+
suggest clearing it to recover partial reward and improve precision.
|
| 251 |
+
|
| 252 |
+
Returns a clear_flag action dict if we should recover, else None.
|
| 253 |
+
"""
|
| 254 |
+
if last_reward is None or last_reward >= 0:
|
| 255 |
+
return None
|
| 256 |
+
if last_action.get("action_type") != "flag_issue":
|
| 257 |
+
return None
|
| 258 |
+
|
| 259 |
+
# Only clear if it was a clear FP (no near-miss indicator in feedback)
|
| 260 |
+
# and we've got too many false positives
|
| 261 |
+
progress = obs.get("progress", {})
|
| 262 |
+
fp = int(progress.get("false_positives", 0))
|
| 263 |
+
tp = int(progress.get("true_positives", 0))
|
| 264 |
+
|
| 265 |
+
# If FP > TP and last reward was notably negative, clear the bad flag
|
| 266 |
+
if fp > tp and last_reward <= -0.05:
|
| 267 |
+
return {
|
| 268 |
+
"action_type": "clear_flag",
|
| 269 |
+
"line_number": last_action.get("line_number"),
|
| 270 |
+
"filename": last_action.get("filename"),
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
return None
|
| 274 |
|
| 275 |
|
| 276 |
def run_task(task_id: str, http_client: httpx.Client) -> dict:
|
| 277 |
+
try:
|
| 278 |
+
resp = http_client.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30)
|
| 279 |
+
resp.raise_for_status()
|
| 280 |
+
obs = resp.json()
|
| 281 |
+
except Exception as e:
|
| 282 |
+
print(f" Reset failed: {e} — falling back to keyword heuristic")
|
| 283 |
+
return run_keyword_fallback(ENV_URL, task_id)
|
| 284 |
|
| 285 |
code_display = "\n\n".join(
|
| 286 |
+
f"=== {fname} (starting at line 1) ===\n{code}"
|
| 287 |
for fname, code in obs.get("code_files", {}).items()
|
| 288 |
)
|
| 289 |
|
| 290 |
+
# Include function map hint if available
|
| 291 |
+
code_metadata = obs.get("code_metadata") or {}
|
| 292 |
+
function_ranges = code_metadata.get("function_ranges", [])
|
| 293 |
+
fn_map_hint = ""
|
| 294 |
+
if function_ranges:
|
| 295 |
+
fn_lines = [f" {fr['name']}() in {fr['file']} (lines {fr['start']}-{fr['end']})"
|
| 296 |
+
for fr in function_ranges]
|
| 297 |
+
fn_map_hint = "\n\nFunction map:\n" + "\n".join(fn_lines)
|
| 298 |
+
|
| 299 |
+
task_desc = obs.get("task_description", "")
|
| 300 |
+
max_steps = obs.get("max_steps", 20)
|
| 301 |
+
issue_categories = code_metadata.get("issue_categories", [])
|
| 302 |
+
n_gt = len(obs.get("code_files", {})) # rough complexity hint
|
| 303 |
+
category_hint = ""
|
| 304 |
+
if issue_categories:
|
| 305 |
+
category_hint = f"\nIssue categories to look for: {sorted(set(issue_categories))}"
|
| 306 |
+
|
| 307 |
+
# RC-GRPO style reward conditioning (2025): tell the agent what quality level
|
| 308 |
+
# it should aim for, so it calibrates confidence appropriately.
|
| 309 |
+
state_features = code_metadata.get("state_features", [])
|
| 310 |
+
complexity_label = "medium"
|
| 311 |
+
if state_features and len(state_features) >= 4:
|
| 312 |
+
complexity_score = state_features[3]
|
| 313 |
+
complexity_label = "high" if complexity_score >= 1.0 else "medium" if complexity_score >= 0.5 else "low"
|
| 314 |
+
|
| 315 |
+
reward_conditioning = (
|
| 316 |
+
f"[TARGET: high-quality review, score ≥ 0.85. "
|
| 317 |
+
f"Code complexity: {complexity_label}. "
|
| 318 |
+
f"Be thorough — missing issues costs more than a single FP.]"
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
messages = [
|
| 322 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 323 |
{
|
| 324 |
"role": "user",
|
| 325 |
"content": (
|
| 326 |
+
f"{reward_conditioning}\n\n"
|
| 327 |
+
f"Task: {task_desc}\n\n"
|
| 328 |
+
f"{code_display}"
|
| 329 |
+
f"{fn_map_hint}"
|
| 330 |
+
f"{category_hint}\n\n"
|
| 331 |
+
f"You have {max_steps} steps total. "
|
| 332 |
+
f"Work through the checklist systematically, function by function. "
|
| 333 |
+
f"Flag each issue one at a time as a raw JSON object."
|
| 334 |
),
|
| 335 |
},
|
| 336 |
]
|
| 337 |
|
| 338 |
done = False
|
| 339 |
step_count = 0
|
|
|
|
| 340 |
final_score = 0.0
|
| 341 |
+
last_action: dict = {}
|
| 342 |
+
last_reward: Optional[float] = None
|
| 343 |
+
consecutive_fp = 0
|
| 344 |
|
| 345 |
while not done and step_count < max_steps:
|
| 346 |
+
# --- Auto clear_flag recovery: undo recent FP if hurting precision ---
|
| 347 |
+
recovery_action = _should_clear_flag(obs, last_reward, last_action)
|
| 348 |
+
if recovery_action and step_count < max_steps - 1:
|
| 349 |
+
action = recovery_action
|
| 350 |
+
action_text = json.dumps(action)
|
| 351 |
+
print(f" Auto-recovery: clearing FP at {action.get('filename')}:{action.get('line_number')}")
|
| 352 |
+
else:
|
| 353 |
+
# --- Normal LLM action ---
|
| 354 |
+
try:
|
| 355 |
+
action_text = chat_completion(messages)
|
| 356 |
+
except Exception as e:
|
| 357 |
+
print(f" LLM unavailable ({e}) — submitting and falling back to keyword heuristic")
|
| 358 |
+
try:
|
| 359 |
+
http_client.post(f"{ENV_URL}/step", json={"action_type": "submit_review"}, timeout=30)
|
| 360 |
+
except Exception:
|
| 361 |
+
pass
|
| 362 |
+
return run_keyword_fallback(ENV_URL, task_id)
|
| 363 |
+
|
| 364 |
+
action = parse_action(action_text)
|
| 365 |
+
|
| 366 |
+
# Smart submission: inject submit if progress shows we're done
|
| 367 |
+
if action.get("action_type") != "submit_review" and _should_submit(obs, step_count, max_steps):
|
| 368 |
+
print(f" Smart submit at step {step_count + 1} (recall target met)")
|
| 369 |
+
action = {"action_type": "submit_review"}
|
| 370 |
+
action_text = json.dumps(action)
|
| 371 |
|
| 372 |
try:
|
| 373 |
step_resp = http_client.post(f"{ENV_URL}/step", json=action, timeout=30)
|
|
|
|
| 379 |
|
| 380 |
done = obs.get("done", False)
|
| 381 |
step_count += 1
|
| 382 |
+
last_reward = obs.get("reward")
|
| 383 |
+
# Use terminal reward (final grade) when done, else intermediate score
|
| 384 |
+
if done:
|
| 385 |
+
final_score = last_reward or obs.get("current_score", 0.0)
|
| 386 |
+
else:
|
| 387 |
+
final_score = obs.get("current_score", 0.0)
|
| 388 |
+
last_action = action
|
| 389 |
+
|
| 390 |
+
# Track consecutive FPs for logging
|
| 391 |
+
if last_reward is not None and last_reward < 0 and action.get("action_type") == "flag_issue":
|
| 392 |
+
consecutive_fp += 1
|
| 393 |
+
else:
|
| 394 |
+
consecutive_fp = 0
|
| 395 |
+
|
| 396 |
+
# Build rich feedback for next LLM turn
|
| 397 |
+
progress_feedback = _build_progress_feedback(obs)
|
| 398 |
+
env_feedback = obs.get("feedback", "")
|
| 399 |
+
combined_feedback = env_feedback
|
| 400 |
+
if progress_feedback:
|
| 401 |
+
combined_feedback += f"\n{progress_feedback}"
|
| 402 |
|
| 403 |
messages.append({"role": "assistant", "content": action_text})
|
| 404 |
+
if combined_feedback:
|
| 405 |
+
messages.append({"role": "user", "content": combined_feedback})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
|
| 407 |
atype = action.get("action_type", "")
|
| 408 |
+
print(f" Step {step_count:2d}: {atype:20s} | reward={str(last_reward):8s} | score={obs.get('current_score', 0.0):.3f}")
|
| 409 |
|
| 410 |
if atype == "submit_review":
|
| 411 |
final_score = obs.get("reward", obs.get("current_score", 0.0)) or 0.0
|
|
|
|
| 447 |
print(f"Running task: {task_id}")
|
| 448 |
result = run_task(task_id, client)
|
| 449 |
results[task_id] = result
|
| 450 |
+
print(f" → score: {result['score']:.4f} ({result['steps']} steps, method={result['method']})\n")
|
| 451 |
else:
|
| 452 |
print("HF_TOKEN / API_BASE_URL not set — using built-in keyword heuristic baseline.\n")
|
| 453 |
for task_id in TASK_IDS:
|
models.py
CHANGED
|
@@ -80,6 +80,8 @@ class ReviewAction(_BaseAction):
|
|
| 80 |
severity: Optional[str] = None
|
| 81 |
description: str = ""
|
| 82 |
fix_suggestion: Optional[str] = None
|
|
|
|
|
|
|
| 83 |
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 84 |
|
| 85 |
def to_dict(self) -> dict:
|
|
@@ -91,6 +93,8 @@ class ReviewAction(_BaseAction):
|
|
| 91 |
"severity": self.severity,
|
| 92 |
"description": self.description,
|
| 93 |
"fix_suggestion": self.fix_suggestion,
|
|
|
|
|
|
|
| 94 |
}
|
| 95 |
|
| 96 |
@classmethod
|
|
@@ -103,6 +107,8 @@ class ReviewAction(_BaseAction):
|
|
| 103 |
severity=d.get("severity"),
|
| 104 |
description=str(d.get("description", "")),
|
| 105 |
fix_suggestion=d.get("fix_suggestion"),
|
|
|
|
|
|
|
| 106 |
)
|
| 107 |
|
| 108 |
|
|
@@ -125,6 +131,11 @@ class ReviewObservation(_BaseObservation):
|
|
| 125 |
done: bool = False
|
| 126 |
reward: Optional[float] = None
|
| 127 |
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
def to_dict(self) -> dict:
|
| 130 |
return {
|
|
@@ -141,6 +152,10 @@ class ReviewObservation(_BaseObservation):
|
|
| 141 |
"done": self.done,
|
| 142 |
"reward": self.reward,
|
| 143 |
"metadata": self.metadata,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
}
|
| 145 |
|
| 146 |
@classmethod
|
|
@@ -158,6 +173,11 @@ class ReviewObservation(_BaseObservation):
|
|
| 158 |
current_score=d.get("current_score", 0.0),
|
| 159 |
done=d.get("done", False),
|
| 160 |
reward=d.get("reward"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
)
|
| 162 |
|
| 163 |
|
|
|
|
| 80 |
severity: Optional[str] = None
|
| 81 |
description: str = ""
|
| 82 |
fix_suggestion: Optional[str] = None
|
| 83 |
+
confidence: Optional[float] = None # agent's confidence 0.0–1.0
|
| 84 |
+
related_lines: Optional[List[int]] = None # multi-line issues
|
| 85 |
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 86 |
|
| 87 |
def to_dict(self) -> dict:
|
|
|
|
| 93 |
"severity": self.severity,
|
| 94 |
"description": self.description,
|
| 95 |
"fix_suggestion": self.fix_suggestion,
|
| 96 |
+
"confidence": self.confidence,
|
| 97 |
+
"related_lines": self.related_lines,
|
| 98 |
}
|
| 99 |
|
| 100 |
@classmethod
|
|
|
|
| 107 |
severity=d.get("severity"),
|
| 108 |
description=str(d.get("description", "")),
|
| 109 |
fix_suggestion=d.get("fix_suggestion"),
|
| 110 |
+
confidence=d.get("confidence"),
|
| 111 |
+
related_lines=d.get("related_lines"),
|
| 112 |
)
|
| 113 |
|
| 114 |
|
|
|
|
| 131 |
done: bool = False
|
| 132 |
reward: Optional[float] = None
|
| 133 |
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 134 |
+
# New fields
|
| 135 |
+
reward_breakdown: Dict[str, float] = field(default_factory=dict)
|
| 136 |
+
progress: Dict[str, float] = field(default_factory=dict)
|
| 137 |
+
flagged_summary: Dict[str, Any] = field(default_factory=dict)
|
| 138 |
+
code_metadata: Dict[str, Any] = field(default_factory=dict)
|
| 139 |
|
| 140 |
def to_dict(self) -> dict:
|
| 141 |
return {
|
|
|
|
| 152 |
"done": self.done,
|
| 153 |
"reward": self.reward,
|
| 154 |
"metadata": self.metadata,
|
| 155 |
+
"reward_breakdown": self.reward_breakdown,
|
| 156 |
+
"progress": self.progress,
|
| 157 |
+
"flagged_summary": self.flagged_summary,
|
| 158 |
+
"code_metadata": self.code_metadata,
|
| 159 |
}
|
| 160 |
|
| 161 |
@classmethod
|
|
|
|
| 173 |
current_score=d.get("current_score", 0.0),
|
| 174 |
done=d.get("done", False),
|
| 175 |
reward=d.get("reward"),
|
| 176 |
+
metadata=d.get("metadata", {}),
|
| 177 |
+
reward_breakdown=d.get("reward_breakdown", {}),
|
| 178 |
+
progress=d.get("progress", {}),
|
| 179 |
+
flagged_summary=d.get("flagged_summary", {}),
|
| 180 |
+
code_metadata=d.get("code_metadata", {}),
|
| 181 |
)
|
| 182 |
|
| 183 |
|
openenv.yaml
CHANGED
|
@@ -1,11 +1,58 @@
|
|
| 1 |
spec_version: 1
|
| 2 |
name: code_review_env
|
| 3 |
-
version: "
|
| 4 |
description: >
|
| 5 |
-
A code review and security audit environment for training AI agents.
|
| 6 |
The agent identifies bugs, security vulnerabilities, and performance issues
|
| 7 |
-
across
|
|
|
|
|
|
|
| 8 |
type: space
|
| 9 |
runtime: fastapi
|
| 10 |
app: server.app:app
|
|
|
|
| 11 |
port: 7860
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
spec_version: 1
|
| 2 |
name: code_review_env
|
| 3 |
+
version: "2.0.0"
|
| 4 |
description: >
|
| 5 |
+
A code review and security audit RL environment for training AI agents.
|
| 6 |
The agent identifies bugs, security vulnerabilities, and performance issues
|
| 7 |
+
across 7 tasks of increasing difficulty (easy → medium → medium-hard → hard).
|
| 8 |
+
Features: PBRS reward shaping, graduated near-miss rewards, flood protection,
|
| 9 |
+
CAMRL curriculum selector, VL return normalization, and cross-language tasks.
|
| 10 |
type: space
|
| 11 |
runtime: fastapi
|
| 12 |
app: server.app:app
|
| 13 |
+
entry_point: server
|
| 14 |
port: 7860
|
| 15 |
+
|
| 16 |
+
tasks:
|
| 17 |
+
- id: bug-detection
|
| 18 |
+
difficulty: easy
|
| 19 |
+
language: python
|
| 20 |
+
num_issues: 3
|
| 21 |
+
max_steps: 15
|
| 22 |
+
- id: security-audit
|
| 23 |
+
difficulty: medium
|
| 24 |
+
language: python
|
| 25 |
+
num_issues: 7
|
| 26 |
+
max_steps: 20
|
| 27 |
+
- id: async-review
|
| 28 |
+
difficulty: medium-hard
|
| 29 |
+
language: python
|
| 30 |
+
num_issues: 6
|
| 31 |
+
max_steps: 20
|
| 32 |
+
- id: data-pipeline
|
| 33 |
+
difficulty: hard
|
| 34 |
+
language: python
|
| 35 |
+
num_issues: 7
|
| 36 |
+
max_steps: 25
|
| 37 |
+
- id: comprehensive-review
|
| 38 |
+
difficulty: hard
|
| 39 |
+
language: python
|
| 40 |
+
num_issues: 9
|
| 41 |
+
max_steps: 30
|
| 42 |
+
- id: api-security
|
| 43 |
+
difficulty: hard
|
| 44 |
+
language: python
|
| 45 |
+
num_issues: 8
|
| 46 |
+
max_steps: 25
|
| 47 |
+
- id: js-security
|
| 48 |
+
difficulty: hard
|
| 49 |
+
language: javascript
|
| 50 |
+
num_issues: 8
|
| 51 |
+
max_steps: 25
|
| 52 |
+
|
| 53 |
+
reward_design:
|
| 54 |
+
terminal: "0.70 * F1 + 0.30 * severity_accuracy"
|
| 55 |
+
shaping: "PBRS (Ng et al. 1999): phi(s) = (tp/total_gt) * 0.5"
|
| 56 |
+
near_miss: "exponential decay: 0.10 * exp(-0.6 * (line_diff - 2))"
|
| 57 |
+
flood_protection: "escalating FP penalty after 3rd false positive"
|
| 58 |
+
normalization: "VL Norm (2025): normalized_return = cumulative / steps_used"
|
server/app.py
CHANGED
|
@@ -21,6 +21,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| 21 |
import json
|
| 22 |
import asyncio
|
| 23 |
import dataclasses
|
|
|
|
| 24 |
from typing import Optional, List, Dict, Any
|
| 25 |
|
| 26 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
|
@@ -29,7 +30,10 @@ from pydantic import BaseModel
|
|
| 29 |
|
| 30 |
from models import ReviewAction, Issue
|
| 31 |
from server.environment import CodeReviewEnvironment
|
| 32 |
-
from server.graders import
|
|
|
|
|
|
|
|
|
|
| 33 |
from tasks.data import ALL_TASKS, TASK_IDS
|
| 34 |
|
| 35 |
|
|
@@ -45,6 +49,7 @@ def _serialize(obj) -> dict:
|
|
| 45 |
|
| 46 |
|
| 47 |
_env_instance = CodeReviewEnvironment()
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
def _make_app() -> FastAPI:
|
|
@@ -245,27 +250,25 @@ async def run_grader(request: GraderRequest):
|
|
| 245 |
|
| 246 |
flagged = [Issue.from_dict(i) for i in request.flagged_issues]
|
| 247 |
ground_truth = [Issue.from_dict(gt) for gt in task["ground_truth_issues"]]
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
tp = sum(
|
| 251 |
-
1 for f in flagged
|
| 252 |
-
if any(
|
| 253 |
-
True for gt in ground_truth
|
| 254 |
-
if abs(f.line_number - gt.line_number) <= 2
|
| 255 |
-
and f.filename == gt.filename
|
| 256 |
-
)
|
| 257 |
-
)
|
| 258 |
|
| 259 |
return {
|
| 260 |
"task_id": request.task_id,
|
| 261 |
"difficulty": task["difficulty"],
|
| 262 |
-
"score": score,
|
| 263 |
"max_score": 1.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
"details": {
|
| 265 |
"total_flagged": len(flagged),
|
| 266 |
-
"true_positives":
|
| 267 |
-
"false_positives":
|
|
|
|
|
|
|
| 268 |
"total_ground_truth": len(ground_truth),
|
|
|
|
| 269 |
},
|
| 270 |
}
|
| 271 |
|
|
@@ -296,6 +299,180 @@ async def run_baseline():
|
|
| 296 |
}
|
| 297 |
|
| 298 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
def main():
|
| 300 |
import uvicorn
|
| 301 |
port = int(os.environ.get("PORT", 7860))
|
|
|
|
| 21 |
import json
|
| 22 |
import asyncio
|
| 23 |
import dataclasses
|
| 24 |
+
import random
|
| 25 |
from typing import Optional, List, Dict, Any
|
| 26 |
|
| 27 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
|
|
|
| 30 |
|
| 31 |
from models import ReviewAction, Issue
|
| 32 |
from server.environment import CodeReviewEnvironment
|
| 33 |
+
from server.graders import (
|
| 34 |
+
grade_episode, grade_episode_detailed, run_keyword_baseline,
|
| 35 |
+
compute_code_state_features, RewardNormalizer,
|
| 36 |
+
)
|
| 37 |
from tasks.data import ALL_TASKS, TASK_IDS
|
| 38 |
|
| 39 |
|
|
|
|
| 49 |
|
| 50 |
|
| 51 |
_env_instance = CodeReviewEnvironment()
|
| 52 |
+
_reward_normalizer = RewardNormalizer(window_size=100)
|
| 53 |
|
| 54 |
|
| 55 |
def _make_app() -> FastAPI:
|
|
|
|
| 250 |
|
| 251 |
flagged = [Issue.from_dict(i) for i in request.flagged_issues]
|
| 252 |
ground_truth = [Issue.from_dict(gt) for gt in task["ground_truth_issues"]]
|
| 253 |
+
detailed = grade_episode_detailed(flagged, ground_truth)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
return {
|
| 256 |
"task_id": request.task_id,
|
| 257 |
"difficulty": task["difficulty"],
|
| 258 |
+
"score": detailed["score"],
|
| 259 |
"max_score": 1.0,
|
| 260 |
+
"f1": detailed["f1"],
|
| 261 |
+
"precision": detailed["precision"],
|
| 262 |
+
"recall": detailed["recall"],
|
| 263 |
+
"severity_accuracy": detailed["severity_accuracy"],
|
| 264 |
"details": {
|
| 265 |
"total_flagged": len(flagged),
|
| 266 |
+
"true_positives": detailed["true_positives"],
|
| 267 |
+
"false_positives": detailed["false_positives"],
|
| 268 |
+
"false_negatives": detailed["false_negatives"],
|
| 269 |
+
"near_misses": detailed["near_misses"],
|
| 270 |
"total_ground_truth": len(ground_truth),
|
| 271 |
+
"per_file": detailed["per_file"],
|
| 272 |
},
|
| 273 |
}
|
| 274 |
|
|
|
|
| 299 |
}
|
| 300 |
|
| 301 |
|
| 302 |
+
class CurriculumRequest(BaseModel):
|
| 303 |
+
agent_performance: Optional[Dict[str, Any]] = None
|
| 304 |
+
easy_threshold: float = 0.30
|
| 305 |
+
hard_threshold: float = 0.70
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
@app.post("/curriculum")
|
| 309 |
+
async def curriculum_task_selector(request: CurriculumRequest):
|
| 310 |
+
"""
|
| 311 |
+
CAMRL-style curriculum task selector (Curriculum-based Asymmetric Multi-Task RL, TPAMI 2023).
|
| 312 |
+
|
| 313 |
+
Given agent performance metrics per task, returns the recommended next task_id
|
| 314 |
+
based on curriculum phase:
|
| 315 |
+
- easy phase (avg_success < 0.30): focus on task with fewest issues
|
| 316 |
+
- medium phase (0.30-0.70): mix easy/hard (70% easy, 30% hard)
|
| 317 |
+
- hard phase (avg_success > 0.70): focus on least-solved hard tasks
|
| 318 |
+
|
| 319 |
+
Body:
|
| 320 |
+
agent_performance: {task_id: {success_rate: 0.5, episodes: 10, avg_score: 0.4}}
|
| 321 |
+
easy_threshold: float (default 0.3)
|
| 322 |
+
hard_threshold: float (default 0.7)
|
| 323 |
+
"""
|
| 324 |
+
perf = request.agent_performance or {}
|
| 325 |
+
easy_thresh = request.easy_threshold
|
| 326 |
+
hard_thresh = request.hard_threshold
|
| 327 |
+
|
| 328 |
+
# Build difficulty estimate per task: (1 - success_rate) × complexity
|
| 329 |
+
task_difficulty: Dict[str, float] = {}
|
| 330 |
+
for task_id, task in ALL_TASKS.items():
|
| 331 |
+
n_issues = len(task["ground_truth_issues"])
|
| 332 |
+
complexity = min(1.0, n_issues / 10.0)
|
| 333 |
+
task_perf = perf.get(task_id, {})
|
| 334 |
+
success_rate = float(task_perf.get("success_rate", task_perf.get("avg_score", 0.0)))
|
| 335 |
+
task_difficulty[task_id] = round((1.0 - success_rate) * complexity, 4)
|
| 336 |
+
|
| 337 |
+
# Determine curriculum phase
|
| 338 |
+
if perf:
|
| 339 |
+
all_success = [float(p.get("success_rate", p.get("avg_score", 0.0))) for p in perf.values()]
|
| 340 |
+
avg_success = sum(all_success) / len(all_success)
|
| 341 |
+
else:
|
| 342 |
+
avg_success = 0.0
|
| 343 |
+
|
| 344 |
+
if avg_success < easy_thresh:
|
| 345 |
+
phase = "easy"
|
| 346 |
+
# Focus on task with lowest ground truth issue count (most approachable)
|
| 347 |
+
recommended = min(ALL_TASKS.keys(), key=lambda t: len(ALL_TASKS[t]["ground_truth_issues"]))
|
| 348 |
+
elif avg_success > hard_thresh:
|
| 349 |
+
phase = "hard"
|
| 350 |
+
# Focus on hardest unsolved task (highest difficulty score)
|
| 351 |
+
recommended = max(task_difficulty, key=task_difficulty.get)
|
| 352 |
+
else:
|
| 353 |
+
phase = "medium"
|
| 354 |
+
# Mix: pick a task proportional to difficulty (harder = more likely)
|
| 355 |
+
import random
|
| 356 |
+
weights = list(task_difficulty.values())
|
| 357 |
+
total_w = sum(weights) or 1.0
|
| 358 |
+
probs = [w / total_w for w in weights]
|
| 359 |
+
recommended = random.choices(list(task_difficulty.keys()), weights=probs, k=1)[0]
|
| 360 |
+
|
| 361 |
+
return {
|
| 362 |
+
"recommended_task_id": recommended,
|
| 363 |
+
"curriculum_phase": phase,
|
| 364 |
+
"avg_success_rate": round(avg_success, 4),
|
| 365 |
+
"task_difficulty_scores": task_difficulty,
|
| 366 |
+
"thresholds": {"easy": easy_thresh, "hard": hard_thresh},
|
| 367 |
+
"method": "CAMRL",
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@app.get("/reward_normalizer")
|
| 372 |
+
async def get_reward_normalizer_stats():
|
| 373 |
+
"""
|
| 374 |
+
Return current RewardNormalizer statistics for the running environment.
|
| 375 |
+
Useful for monitoring VL Norm across training runs.
|
| 376 |
+
"""
|
| 377 |
+
return _reward_normalizer.to_dict()
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
@app.post("/record_episode")
|
| 381 |
+
async def record_episode(body: Dict[str, Any]):
|
| 382 |
+
"""
|
| 383 |
+
Record a completed episode's return and length for VL Norm statistics.
|
| 384 |
+
Body: {"episode_return": 0.72, "episode_length": 12}
|
| 385 |
+
"""
|
| 386 |
+
episode_return = float(body.get("episode_return", 0.0))
|
| 387 |
+
episode_length = int(body.get("episode_length", 1))
|
| 388 |
+
_reward_normalizer.update(episode_return, episode_length)
|
| 389 |
+
normalized = _reward_normalizer.normalize(episode_return, episode_length)
|
| 390 |
+
return {
|
| 391 |
+
"normalized_return": normalized,
|
| 392 |
+
"stats": _reward_normalizer.to_dict(),
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
class TRLRolloutRequest(BaseModel):
|
| 397 |
+
task_id: Optional[str] = None
|
| 398 |
+
seed: Optional[int] = None
|
| 399 |
+
actions: List[Dict[str, Any]] # Pre-generated action sequence from LLM
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
@app.post("/trl_rollout")
|
| 403 |
+
async def trl_rollout(request: TRLRolloutRequest):
|
| 404 |
+
"""
|
| 405 |
+
Run a full episode from a pre-generated action sequence.
|
| 406 |
+
|
| 407 |
+
Designed for TRL GRPOTrainer custom rollout_fn integration:
|
| 408 |
+
- Takes a sequence of LLM-generated actions
|
| 409 |
+
- Runs them through the environment
|
| 410 |
+
- Returns trajectory dict with per-step rewards and final score
|
| 411 |
+
|
| 412 |
+
This enables offline rollout: LLM generates all actions first,
|
| 413 |
+
then this endpoint evaluates them, matching TRL's batch-rollout pattern.
|
| 414 |
+
|
| 415 |
+
Body:
|
| 416 |
+
task_id: str (optional, random if not set)
|
| 417 |
+
seed: int (optional)
|
| 418 |
+
actions: [{action_type, line_number, filename, ...}, ...]
|
| 419 |
+
|
| 420 |
+
Returns:
|
| 421 |
+
trajectory: [{step, action, reward, feedback, done}]
|
| 422 |
+
episode_return: float (sum of step rewards)
|
| 423 |
+
final_score: float (terminal grade)
|
| 424 |
+
normalized_return: float (episode_return / num_steps)
|
| 425 |
+
state_features: [float] (12-dim feature vector at end of episode)
|
| 426 |
+
"""
|
| 427 |
+
rollout_env = CodeReviewEnvironment()
|
| 428 |
+
obs = rollout_env.reset(task_id=request.task_id, seed=request.seed)
|
| 429 |
+
|
| 430 |
+
trajectory = []
|
| 431 |
+
episode_return = 0.0
|
| 432 |
+
final_score = 0.0
|
| 433 |
+
|
| 434 |
+
for step_idx, action_dict in enumerate(request.actions):
|
| 435 |
+
action = ReviewAction.from_dict(action_dict)
|
| 436 |
+
obs_step = rollout_env.step(action)
|
| 437 |
+
step_data = _serialize(obs_step)
|
| 438 |
+
|
| 439 |
+
reward = step_data.get("reward") or 0.0
|
| 440 |
+
episode_return += reward
|
| 441 |
+
|
| 442 |
+
trajectory.append({
|
| 443 |
+
"step": step_idx + 1,
|
| 444 |
+
"action": action_dict,
|
| 445 |
+
"reward": reward,
|
| 446 |
+
"reward_breakdown": step_data.get("reward_breakdown", {}),
|
| 447 |
+
"feedback": step_data.get("feedback", ""),
|
| 448 |
+
"current_score": step_data.get("current_score", 0.0),
|
| 449 |
+
"done": step_data.get("done", False),
|
| 450 |
+
})
|
| 451 |
+
|
| 452 |
+
if step_data.get("done", False):
|
| 453 |
+
final_score = step_data.get("reward", step_data.get("current_score", 0.0)) or 0.0
|
| 454 |
+
break
|
| 455 |
+
|
| 456 |
+
n_steps = max(len(trajectory), 1)
|
| 457 |
+
# Record in global normalizer for VL Norm statistics
|
| 458 |
+
_reward_normalizer.update(episode_return, n_steps)
|
| 459 |
+
normalized = _reward_normalizer.normalize(episode_return, n_steps)
|
| 460 |
+
|
| 461 |
+
# Get final state features
|
| 462 |
+
final_progress = rollout_env._compute_progress(rollout_env._task["max_steps"] if rollout_env._task else 20)
|
| 463 |
+
|
| 464 |
+
return {
|
| 465 |
+
"task_id": request.task_id,
|
| 466 |
+
"trajectory": trajectory,
|
| 467 |
+
"episode_return": round(episode_return, 4),
|
| 468 |
+
"final_score": round(final_score, 4),
|
| 469 |
+
"normalized_return": normalized,
|
| 470 |
+
"num_steps": n_steps,
|
| 471 |
+
"state_features": final_progress.get("state_features", []),
|
| 472 |
+
"final_progress": {k: v for k, v in final_progress.items() if k != "state_features"},
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
|
| 476 |
def main():
|
| 477 |
import uvicorn
|
| 478 |
port = int(os.environ.get("PORT", 7860))
|
server/environment.py
CHANGED
|
@@ -9,11 +9,15 @@ import sys
|
|
| 9 |
import os
|
| 10 |
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 11 |
|
| 12 |
-
from typing import Optional, List
|
| 13 |
|
| 14 |
from models import Issue, ReviewAction, ReviewObservation, ReviewState
|
| 15 |
from tasks.data import ALL_TASKS, TASK_IDS
|
| 16 |
-
from server.graders import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
try:
|
| 19 |
from openenv.core.env_server import Environment as _BaseEnv
|
|
@@ -25,21 +29,44 @@ except ImportError:
|
|
| 25 |
pass
|
| 26 |
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
class CodeReviewEnvironment(_BaseEnv):
|
| 29 |
"""
|
| 30 |
-
A code review and security audit environment.
|
| 31 |
|
| 32 |
The agent receives code files and must identify bugs, security
|
| 33 |
vulnerabilities, and performance issues by flagging them with
|
| 34 |
exact line numbers, types, and severity ratings.
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
"""
|
| 44 |
|
| 45 |
SUPPORTS_CONCURRENT_SESSIONS = False
|
|
@@ -49,6 +76,10 @@ class CodeReviewEnvironment(_BaseEnv):
|
|
| 49 |
self._task: Optional[dict] = None
|
| 50 |
self._ground_truth: List[Issue] = []
|
| 51 |
self._hint_index: int = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
def reset(
|
| 54 |
self,
|
|
@@ -70,6 +101,9 @@ class CodeReviewEnvironment(_BaseEnv):
|
|
| 70 |
for gt in self._task["ground_truth_issues"]
|
| 71 |
]
|
| 72 |
self._hint_index = 0
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
self._state = ReviewState(
|
| 75 |
task_id=task_id,
|
|
@@ -81,6 +115,16 @@ class CodeReviewEnvironment(_BaseEnv):
|
|
| 81 |
submitted=False,
|
| 82 |
)
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
return ReviewObservation(
|
| 85 |
task_id=task_id,
|
| 86 |
task_description=self._task["description"],
|
|
@@ -93,11 +137,16 @@ class CodeReviewEnvironment(_BaseEnv):
|
|
| 93 |
feedback=(
|
| 94 |
f"New episode started. Task: {self._task['difficulty'].upper()}. "
|
| 95 |
f"Review the code carefully and flag all issues you find. "
|
| 96 |
-
f"Use 'submit_review' when done."
|
|
|
|
| 97 |
),
|
| 98 |
current_score=0.0,
|
| 99 |
done=False,
|
| 100 |
reward=None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
)
|
| 102 |
|
| 103 |
def step(
|
|
@@ -133,26 +182,43 @@ class CodeReviewEnvironment(_BaseEnv):
|
|
| 133 |
action = ReviewAction.from_dict(action)
|
| 134 |
|
| 135 |
self._state.step_count += 1
|
| 136 |
-
reward, feedback = self._process_action(action)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
max_steps = self._task["max_steps"]
|
| 139 |
auto_end = self._state.step_count >= max_steps and not self._state.submitted
|
| 140 |
done = self._state.submitted or auto_end
|
| 141 |
|
| 142 |
if auto_end and not self._state.submitted:
|
| 143 |
-
#
|
| 144 |
final = grade_episode(self._state.flagged_issues, self._ground_truth)
|
| 145 |
self._state.current_score = final
|
| 146 |
-
reward = final
|
|
|
|
| 147 |
feedback += (
|
| 148 |
-
f"
|
| 149 |
-
f"Submit earlier for
|
| 150 |
)
|
| 151 |
self._state.submitted = True
|
| 152 |
|
| 153 |
live = compute_live_score(self._state.flagged_issues, self._ground_truth)
|
| 154 |
self._state.current_score = live
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
return ReviewObservation(
|
| 157 |
task_id=self._state.task_id,
|
| 158 |
task_description="",
|
|
@@ -166,12 +232,130 @@ class CodeReviewEnvironment(_BaseEnv):
|
|
| 166 |
current_score=live,
|
| 167 |
done=done,
|
| 168 |
reward=reward,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
)
|
| 170 |
|
| 171 |
@property
|
| 172 |
def state(self) -> ReviewState:
|
| 173 |
return self._state
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
def _process_action(self, action: ReviewAction):
|
| 176 |
atype = (action.action_type or "").strip().lower()
|
| 177 |
|
|
@@ -187,25 +371,26 @@ class CodeReviewEnvironment(_BaseEnv):
|
|
| 187 |
return 0.0, (
|
| 188 |
f"Unknown action_type '{action.action_type}'. "
|
| 189 |
"Use: flag_issue | clear_flag | request_hint | submit_review"
|
| 190 |
-
)
|
| 191 |
|
| 192 |
def _handle_flag(self, action: ReviewAction):
|
| 193 |
if action.line_number is None:
|
| 194 |
-
return
|
| 195 |
if not action.filename:
|
| 196 |
-
return
|
| 197 |
if action.issue_type not in ("bug", "security", "performance", "logic", None):
|
| 198 |
action.issue_type = "bug"
|
| 199 |
if action.severity not in ("low", "medium", "high", "critical", None):
|
| 200 |
action.severity = "medium"
|
| 201 |
|
|
|
|
| 202 |
for existing in self._state.flagged_issues:
|
| 203 |
if (existing.line_number == action.line_number
|
| 204 |
and existing.filename == action.filename):
|
| 205 |
return 0.0, (
|
| 206 |
f"Line {action.line_number} in {action.filename} already flagged. "
|
| 207 |
-
"Use clear_flag first
|
| 208 |
-
)
|
| 209 |
|
| 210 |
new_issue = Issue(
|
| 211 |
line_number=action.line_number,
|
|
@@ -216,95 +401,258 @@ class CodeReviewEnvironment(_BaseEnv):
|
|
| 216 |
fix_suggestion=action.fix_suggestion,
|
| 217 |
)
|
| 218 |
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
self._state.flagged_issues.append(new_issue)
|
| 225 |
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
feedback = (
|
| 229 |
-
f"
|
| 230 |
-
f"[+
|
| 231 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
else:
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
feedback = (
|
| 235 |
-
f"
|
| 236 |
-
f"[
|
| 237 |
)
|
| 238 |
|
| 239 |
-
return reward, feedback
|
| 240 |
|
| 241 |
def _handle_clear(self, action: ReviewAction):
|
| 242 |
if action.line_number is None or not action.filename:
|
| 243 |
-
return
|
| 244 |
-
|
| 245 |
-
before = len(self._state.flagged_issues)
|
| 246 |
-
removed = None
|
| 247 |
-
self._state.flagged_issues = [
|
| 248 |
-
f for f in self._state.flagged_issues
|
| 249 |
-
if not (f.line_number == action.line_number
|
| 250 |
-
and f.filename == action.filename)
|
| 251 |
-
]
|
| 252 |
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
return 0.0, (
|
| 255 |
f"No flagged issue found at {action.filename}:{action.line_number}."
|
| 256 |
-
)
|
| 257 |
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
issue_type="bug",
|
| 262 |
-
severity="medium",
|
| 263 |
-
)
|
| 264 |
was_tp = any(match_issue(removed_issue, gt) for gt in self._ground_truth)
|
| 265 |
|
| 266 |
if was_tp:
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
feedback = (
|
| 269 |
f"Removed a correct finding at {action.filename}:{action.line_number}. "
|
| 270 |
-
f"[
|
| 271 |
)
|
| 272 |
else:
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
| 274 |
feedback = (
|
| 275 |
f"Removed a false positive at {action.filename}:{action.line_number}. "
|
| 276 |
-
f"[+
|
| 277 |
)
|
| 278 |
|
| 279 |
-
return reward, feedback
|
| 280 |
|
| 281 |
def _handle_hint(self):
|
| 282 |
hints = self._task.get("hints", [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
if self._hint_index >= len(hints):
|
| 284 |
-
return
|
| 285 |
|
| 286 |
hint = hints[self._hint_index]
|
| 287 |
self._hint_index += 1
|
| 288 |
remaining = len(hints) - self._hint_index
|
| 289 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
def _handle_submit(self):
|
| 292 |
self._state.submitted = True
|
| 293 |
final_score = grade_episode(self._state.flagged_issues, self._ground_truth)
|
| 294 |
self._state.current_score = final_score
|
| 295 |
|
| 296 |
-
tp_count =
|
| 297 |
-
1 for f in self._state.flagged_issues
|
| 298 |
-
if any(match_issue(f, gt) for gt in self._ground_truth)
|
| 299 |
-
)
|
| 300 |
total_gt = len(self._ground_truth)
|
| 301 |
total_flagged = len(self._state.flagged_issues)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
feedback = (
|
| 304 |
f"Review submitted! Final score: {final_score:.3f}. "
|
| 305 |
-
f"Found {tp_count}/{total_gt}
|
| 306 |
-
f"
|
| 307 |
-
f"
|
| 308 |
)
|
| 309 |
-
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import os
|
| 10 |
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 11 |
|
| 12 |
+
from typing import Optional, List, Dict, Any, Set
|
| 13 |
|
| 14 |
from models import Issue, ReviewAction, ReviewObservation, ReviewState
|
| 15 |
from tasks.data import ALL_TASKS, TASK_IDS
|
| 16 |
+
from server.graders import (
|
| 17 |
+
grade_episode, compute_live_score, match_issue, match_quality,
|
| 18 |
+
compute_code_metadata, grade_episode_detailed,
|
| 19 |
+
graduated_near_reward, compute_potential, compute_code_state_features,
|
| 20 |
+
)
|
| 21 |
|
| 22 |
try:
|
| 23 |
from openenv.core.env_server import Environment as _BaseEnv
|
|
|
|
| 29 |
pass
|
| 30 |
|
| 31 |
|
| 32 |
+
# Reward constants
|
| 33 |
+
_BASE_TP_REWARD = 0.10
|
| 34 |
+
_NEAR_MISS_REWARD = 0.03
|
| 35 |
+
_BASE_FP_PENALTY = -0.05
|
| 36 |
+
_SEVERITY_EXACT_BONUS = 0.02 # when severity exactly matches GT
|
| 37 |
+
_TEMPORAL_BONUS = 0.02 # early correct flag (first 40% of steps)
|
| 38 |
+
_CONFIDENCE_TP_BONUS = 0.01 # high-confidence TP
|
| 39 |
+
_CONFIDENCE_FP_EXTRA = -0.03 # high-confidence FP (penalty multiplier)
|
| 40 |
+
_HINT_COST = -0.01
|
| 41 |
+
_REMOVE_TP_PENALTY = -0.03
|
| 42 |
+
_REMOVE_FP_REWARD = 0.03
|
| 43 |
+
_VALIDATION_PENALTY = -0.02
|
| 44 |
+
# Flood protection: escalating FP penalty
|
| 45 |
+
_FP_FLOOD_THRESHOLD = 3 # FPs before escalation kicks in
|
| 46 |
+
_FP_FLOOD_MULTIPLIER = 1.5 # each extra FP beyond threshold costs 1.5x more
|
| 47 |
+
|
| 48 |
+
_SEV_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
class CodeReviewEnvironment(_BaseEnv):
|
| 52 |
"""
|
| 53 |
+
A code review and security audit RL environment.
|
| 54 |
|
| 55 |
The agent receives code files and must identify bugs, security
|
| 56 |
vulnerabilities, and performance issues by flagging them with
|
| 57 |
exact line numbers, types, and severity ratings.
|
| 58 |
|
| 59 |
+
Reward design:
|
| 60 |
+
- True positive flag: +0.10 base, +0.02 severity exact match,
|
| 61 |
+
+0.02 early (first 40% steps), +0.01 high-confidence TP
|
| 62 |
+
- Near-miss (±3-5 lines): +0.03 partial credit
|
| 63 |
+
- False positive: -0.05 base, escalating penalty after 3rd FP,
|
| 64 |
+
extra -0.03 for high-confidence FP
|
| 65 |
+
- Clear false positive: +0.03
|
| 66 |
+
- Clear true positive: -0.03
|
| 67 |
+
- Hint: -0.01
|
| 68 |
+
- Submit: final F1+severity score (0.0–1.0)
|
| 69 |
+
- Auto-end (max_steps): full grade score (no penalty)
|
| 70 |
"""
|
| 71 |
|
| 72 |
SUPPORTS_CONCURRENT_SESSIONS = False
|
|
|
|
| 76 |
self._task: Optional[dict] = None
|
| 77 |
self._ground_truth: List[Issue] = []
|
| 78 |
self._hint_index: int = 0
|
| 79 |
+
self._code_metadata: Dict[str, Any] = {}
|
| 80 |
+
self._fp_count: int = 0 # total false positives this episode
|
| 81 |
+
self._matched_gt_indices: Set[int] = set() # GT indices already matched
|
| 82 |
+
self._episode_rewards: List[float] = [] # for VL return normalization
|
| 83 |
|
| 84 |
def reset(
|
| 85 |
self,
|
|
|
|
| 101 |
for gt in self._task["ground_truth_issues"]
|
| 102 |
]
|
| 103 |
self._hint_index = 0
|
| 104 |
+
self._fp_count = 0
|
| 105 |
+
self._matched_gt_indices = set()
|
| 106 |
+
self._episode_rewards = []
|
| 107 |
|
| 108 |
self._state = ReviewState(
|
| 109 |
task_id=task_id,
|
|
|
|
| 115 |
submitted=False,
|
| 116 |
)
|
| 117 |
|
| 118 |
+
issue_categories = list({gt.issue_type for gt in self._ground_truth})
|
| 119 |
+
self._code_metadata = compute_code_metadata(
|
| 120 |
+
self._task["code_files"],
|
| 121 |
+
issue_categories=issue_categories,
|
| 122 |
+
)
|
| 123 |
+
# Pre-compute initial state features (progress=empty at reset)
|
| 124 |
+
self._code_metadata["state_features"] = compute_code_state_features(
|
| 125 |
+
self._code_metadata, progress={}
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
return ReviewObservation(
|
| 129 |
task_id=task_id,
|
| 130 |
task_description=self._task["description"],
|
|
|
|
| 137 |
feedback=(
|
| 138 |
f"New episode started. Task: {self._task['difficulty'].upper()}. "
|
| 139 |
f"Review the code carefully and flag all issues you find. "
|
| 140 |
+
f"Use 'submit_review' when done. "
|
| 141 |
+
f"Issue categories present: {sorted(set(issue_categories))}."
|
| 142 |
),
|
| 143 |
current_score=0.0,
|
| 144 |
done=False,
|
| 145 |
reward=None,
|
| 146 |
+
reward_breakdown={},
|
| 147 |
+
progress={},
|
| 148 |
+
flagged_summary={},
|
| 149 |
+
code_metadata=self._code_metadata,
|
| 150 |
)
|
| 151 |
|
| 152 |
def step(
|
|
|
|
| 182 |
action = ReviewAction.from_dict(action)
|
| 183 |
|
| 184 |
self._state.step_count += 1
|
| 185 |
+
reward, feedback, reward_breakdown = self._process_action(action)
|
| 186 |
+
|
| 187 |
+
# Track episode rewards for VL return normalization
|
| 188 |
+
if reward is not None:
|
| 189 |
+
self._episode_rewards.append(float(reward))
|
| 190 |
|
| 191 |
max_steps = self._task["max_steps"]
|
| 192 |
auto_end = self._state.step_count >= max_steps and not self._state.submitted
|
| 193 |
done = self._state.submitted or auto_end
|
| 194 |
|
| 195 |
if auto_end and not self._state.submitted:
|
| 196 |
+
# Auto-end: grade in full (no penalty for hitting step limit)
|
| 197 |
final = grade_episode(self._state.flagged_issues, self._ground_truth)
|
| 198 |
self._state.current_score = final
|
| 199 |
+
reward = final # full score, no 0.5x penalty
|
| 200 |
+
reward_breakdown = {"auto_end_grade": final, "total": final}
|
| 201 |
feedback += (
|
| 202 |
+
f" Step budget exhausted — auto-graded: {final:.3f}. "
|
| 203 |
+
f"Submit earlier next time for slightly cleaner feedback."
|
| 204 |
)
|
| 205 |
self._state.submitted = True
|
| 206 |
|
| 207 |
live = compute_live_score(self._state.flagged_issues, self._ground_truth)
|
| 208 |
self._state.current_score = live
|
| 209 |
|
| 210 |
+
progress = self._compute_progress(max_steps)
|
| 211 |
+
flagged_summary = self._compute_flagged_summary()
|
| 212 |
+
|
| 213 |
+
# PRM-style dense signal: expected reward-to-go
|
| 214 |
+
# Based on Process Reward Models research: give agent an estimate of
|
| 215 |
+
# how much reward is still available, so it can plan remaining steps.
|
| 216 |
+
tp_found = len(self._matched_gt_indices)
|
| 217 |
+
total_gt = len(self._ground_truth)
|
| 218 |
+
issues_remaining = total_gt - tp_found
|
| 219 |
+
# Expected: each remaining TP gives ~0.12 (base + avg severity bonus)
|
| 220 |
+
expected_reward_to_go = round(issues_remaining * 0.12, 3)
|
| 221 |
+
|
| 222 |
return ReviewObservation(
|
| 223 |
task_id=self._state.task_id,
|
| 224 |
task_description="",
|
|
|
|
| 232 |
current_score=live,
|
| 233 |
done=done,
|
| 234 |
reward=reward,
|
| 235 |
+
reward_breakdown=reward_breakdown,
|
| 236 |
+
progress=progress,
|
| 237 |
+
flagged_summary=flagged_summary,
|
| 238 |
+
code_metadata={}, # Only populated on reset
|
| 239 |
+
metadata={
|
| 240 |
+
"issues_remaining": issues_remaining,
|
| 241 |
+
"expected_reward_to_go": expected_reward_to_go,
|
| 242 |
+
},
|
| 243 |
)
|
| 244 |
|
| 245 |
@property
|
| 246 |
def state(self) -> ReviewState:
|
| 247 |
return self._state
|
| 248 |
|
| 249 |
+
# ------------------------------------------------------------------
|
| 250 |
+
# Progress and summary helpers
|
| 251 |
+
# ------------------------------------------------------------------
|
| 252 |
+
|
| 253 |
+
def _compute_progress(self, max_steps: int) -> Dict[str, Any]:
|
| 254 |
+
"""Compute live precision/recall/f1, step stats, and unfound issue types."""
|
| 255 |
+
flagged = self._state.flagged_issues
|
| 256 |
+
gt = self._ground_truth
|
| 257 |
+
|
| 258 |
+
tp = 0
|
| 259 |
+
fp = 0
|
| 260 |
+
matched: Set[int] = set()
|
| 261 |
+
found_types: Set[str] = set()
|
| 262 |
+
|
| 263 |
+
for flag in flagged:
|
| 264 |
+
hit = False
|
| 265 |
+
for i, g in enumerate(gt):
|
| 266 |
+
if i not in matched and match_issue(flag, g):
|
| 267 |
+
tp += 1
|
| 268 |
+
matched.add(i)
|
| 269 |
+
found_types.add(g.issue_type)
|
| 270 |
+
hit = True
|
| 271 |
+
break
|
| 272 |
+
if not hit:
|
| 273 |
+
fp += 1
|
| 274 |
+
|
| 275 |
+
fn = len(gt) - len(matched)
|
| 276 |
+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
| 277 |
+
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
| 278 |
+
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
|
| 279 |
+
|
| 280 |
+
all_types = {g.issue_type for g in gt}
|
| 281 |
+
unfound_types = sorted(all_types - found_types)
|
| 282 |
+
|
| 283 |
+
steps_used = self._state.step_count
|
| 284 |
+
steps_remaining = max(0, max_steps - steps_used)
|
| 285 |
+
|
| 286 |
+
# Variable-Length Return Normalization (VL Norm 2025):
|
| 287 |
+
# normalized_return = cumulative_reward / max(steps_used, 1)
|
| 288 |
+
# This makes return comparable across episodes of different length,
|
| 289 |
+
# which is key for multi-task RL where tasks have different max_steps.
|
| 290 |
+
cumulative_reward = sum(self._episode_rewards)
|
| 291 |
+
normalized_return = round(cumulative_reward / max(steps_used, 1), 4)
|
| 292 |
+
|
| 293 |
+
progress = {
|
| 294 |
+
"precision": round(precision, 4),
|
| 295 |
+
"recall": round(recall, 4),
|
| 296 |
+
"f1": round(f1, 4),
|
| 297 |
+
"true_positives": float(tp),
|
| 298 |
+
"false_positives": float(fp),
|
| 299 |
+
"total_ground_truth": float(len(gt)),
|
| 300 |
+
"steps_used": float(steps_used),
|
| 301 |
+
"steps_remaining": float(steps_remaining),
|
| 302 |
+
"unfound_issue_types": unfound_types,
|
| 303 |
+
"normalized_return": normalized_return,
|
| 304 |
+
"cumulative_reward": round(cumulative_reward, 4),
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
# 12-dim state feature vector for RL policy/value networks (code2vec/PBRS literature)
|
| 308 |
+
progress["state_features"] = compute_code_state_features(
|
| 309 |
+
self._code_metadata, progress=progress
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
return progress
|
| 313 |
+
|
| 314 |
+
def _compute_flagged_summary(self) -> Dict[str, Any]:
|
| 315 |
+
"""Compute correct/incorrect/near_miss counts."""
|
| 316 |
+
flagged = self._state.flagged_issues
|
| 317 |
+
gt = self._ground_truth
|
| 318 |
+
|
| 319 |
+
correct = 0
|
| 320 |
+
near_misses = 0
|
| 321 |
+
incorrect = 0
|
| 322 |
+
matched_gt: Set[int] = set()
|
| 323 |
+
|
| 324 |
+
for flag in flagged:
|
| 325 |
+
matched = False
|
| 326 |
+
for i, g in enumerate(gt):
|
| 327 |
+
if i in matched_gt:
|
| 328 |
+
continue
|
| 329 |
+
if match_issue(flag, g):
|
| 330 |
+
correct += 1
|
| 331 |
+
matched_gt.add(i)
|
| 332 |
+
matched = True
|
| 333 |
+
break
|
| 334 |
+
|
| 335 |
+
if not matched:
|
| 336 |
+
is_near = False
|
| 337 |
+
for i, g in enumerate(gt):
|
| 338 |
+
if i in matched_gt:
|
| 339 |
+
continue
|
| 340 |
+
if match_quality(flag, g) == "near":
|
| 341 |
+
is_near = True
|
| 342 |
+
break
|
| 343 |
+
if is_near:
|
| 344 |
+
near_misses += 1
|
| 345 |
+
else:
|
| 346 |
+
incorrect += 1
|
| 347 |
+
|
| 348 |
+
return {
|
| 349 |
+
"total_flagged": len(flagged),
|
| 350 |
+
"correct": correct,
|
| 351 |
+
"incorrect": incorrect,
|
| 352 |
+
"near_misses": near_misses,
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
# ------------------------------------------------------------------
|
| 356 |
+
# Action handlers
|
| 357 |
+
# ------------------------------------------------------------------
|
| 358 |
+
|
| 359 |
def _process_action(self, action: ReviewAction):
|
| 360 |
atype = (action.action_type or "").strip().lower()
|
| 361 |
|
|
|
|
| 371 |
return 0.0, (
|
| 372 |
f"Unknown action_type '{action.action_type}'. "
|
| 373 |
"Use: flag_issue | clear_flag | request_hint | submit_review"
|
| 374 |
+
), {}
|
| 375 |
|
| 376 |
def _handle_flag(self, action: ReviewAction):
|
| 377 |
if action.line_number is None:
|
| 378 |
+
return _VALIDATION_PENALTY, "flag_issue requires 'line_number'.", {"validation_penalty": _VALIDATION_PENALTY}
|
| 379 |
if not action.filename:
|
| 380 |
+
return _VALIDATION_PENALTY, "flag_issue requires 'filename'.", {"validation_penalty": _VALIDATION_PENALTY}
|
| 381 |
if action.issue_type not in ("bug", "security", "performance", "logic", None):
|
| 382 |
action.issue_type = "bug"
|
| 383 |
if action.severity not in ("low", "medium", "high", "critical", None):
|
| 384 |
action.severity = "medium"
|
| 385 |
|
| 386 |
+
# Duplicate check
|
| 387 |
for existing in self._state.flagged_issues:
|
| 388 |
if (existing.line_number == action.line_number
|
| 389 |
and existing.filename == action.filename):
|
| 390 |
return 0.0, (
|
| 391 |
f"Line {action.line_number} in {action.filename} already flagged. "
|
| 392 |
+
"Use clear_flag first to change it."
|
| 393 |
+
), {"duplicate": 0.0}
|
| 394 |
|
| 395 |
new_issue = Issue(
|
| 396 |
line_number=action.line_number,
|
|
|
|
| 401 |
fix_suggestion=action.fix_suggestion,
|
| 402 |
)
|
| 403 |
|
| 404 |
+
# Classify: TP, near-miss (with line distance), or FP
|
| 405 |
+
is_tp = False
|
| 406 |
+
is_near = False
|
| 407 |
+
near_line_diff = 0
|
| 408 |
+
matched_gt_issue: Optional[Issue] = None
|
| 409 |
+
matched_gt_idx: Optional[int] = None
|
| 410 |
+
|
| 411 |
+
for i, gt in enumerate(self._ground_truth):
|
| 412 |
+
q = match_quality(new_issue, gt)
|
| 413 |
+
if q == "exact" and i not in self._matched_gt_indices:
|
| 414 |
+
is_tp = True
|
| 415 |
+
matched_gt_issue = gt
|
| 416 |
+
matched_gt_idx = i
|
| 417 |
+
break
|
| 418 |
+
elif q == "near" and not is_near:
|
| 419 |
+
is_near = True
|
| 420 |
+
near_line_diff = abs(new_issue.line_number - gt.line_number)
|
| 421 |
|
| 422 |
self._state.flagged_issues.append(new_issue)
|
| 423 |
|
| 424 |
+
# PBRS: compute potential before and after this flag
|
| 425 |
+
tp_before = len(self._matched_gt_indices)
|
| 426 |
+
total_gt = len(self._ground_truth)
|
| 427 |
+
|
| 428 |
+
reward_breakdown: Dict[str, float] = {}
|
| 429 |
+
|
| 430 |
+
if is_tp and matched_gt_issue is not None and matched_gt_idx is not None:
|
| 431 |
+
self._matched_gt_indices.add(matched_gt_idx)
|
| 432 |
+
tp_after = len(self._matched_gt_indices)
|
| 433 |
+
|
| 434 |
+
base_reward = _BASE_TP_REWARD
|
| 435 |
+
reward_breakdown["base_tp"] = base_reward
|
| 436 |
+
|
| 437 |
+
# Severity exact match bonus
|
| 438 |
+
severity_bonus = 0.0
|
| 439 |
+
if new_issue.severity == matched_gt_issue.severity:
|
| 440 |
+
severity_bonus = _SEVERITY_EXACT_BONUS
|
| 441 |
+
reward_breakdown["severity_exact"] = severity_bonus
|
| 442 |
+
|
| 443 |
+
# Temporal bonus: TP caught in first 40% of max_steps
|
| 444 |
+
max_steps = self._task["max_steps"]
|
| 445 |
+
early_threshold = max(1, int(max_steps * 0.4))
|
| 446 |
+
temporal_bonus = 0.0
|
| 447 |
+
if self._state.step_count <= early_threshold:
|
| 448 |
+
temporal_bonus = _TEMPORAL_BONUS
|
| 449 |
+
reward_breakdown["temporal_bonus"] = temporal_bonus
|
| 450 |
+
|
| 451 |
+
# Confidence calibration: high confidence TP → small bonus
|
| 452 |
+
confidence_bonus = 0.0
|
| 453 |
+
if action.confidence is not None and action.confidence >= 0.7:
|
| 454 |
+
confidence_bonus = _CONFIDENCE_TP_BONUS
|
| 455 |
+
reward_breakdown["confidence_bonus"] = confidence_bonus
|
| 456 |
+
|
| 457 |
+
# PBRS: Φ(s') - Φ(s) (potential-based shaping, policy-invariant)
|
| 458 |
+
phi_before = compute_potential(tp_before, total_gt)
|
| 459 |
+
phi_after = compute_potential(tp_after, total_gt)
|
| 460 |
+
pbrs_bonus = round(phi_after - phi_before, 4)
|
| 461 |
+
reward_breakdown["pbrs_shaping"] = pbrs_bonus
|
| 462 |
+
|
| 463 |
+
reward = base_reward + severity_bonus + temporal_bonus + confidence_bonus + pbrs_bonus
|
| 464 |
+
reward_breakdown["total"] = round(reward, 4)
|
| 465 |
+
|
| 466 |
+
sev_note = f", severity +{severity_bonus:.2f}" if severity_bonus else ""
|
| 467 |
+
temp_note = f", early +{temporal_bonus:.2f}" if temporal_bonus else ""
|
| 468 |
+
conf_note = f", conf +{confidence_bonus:.2f}" if confidence_bonus else ""
|
| 469 |
+
pbrs_note = f", progress +{pbrs_bonus:.2f}" if pbrs_bonus > 0 else ""
|
| 470 |
feedback = (
|
| 471 |
+
f"Correct! Issue at {action.filename}:{action.line_number} confirmed. "
|
| 472 |
+
f"[+{reward:.2f}{sev_note}{temp_note}{conf_note}{pbrs_note}]"
|
| 473 |
)
|
| 474 |
+
|
| 475 |
+
elif is_near:
|
| 476 |
+
# Graduated near-miss: smooth exponential decay by line distance
|
| 477 |
+
near_reward = graduated_near_reward(near_line_diff)
|
| 478 |
+
reward_breakdown["near_miss"] = near_reward
|
| 479 |
+
reward_breakdown["line_diff"] = float(near_line_diff)
|
| 480 |
+
reward_breakdown["total"] = near_reward
|
| 481 |
+
feedback = (
|
| 482 |
+
f"Close! Near a real issue at {action.filename}:{action.line_number}. "
|
| 483 |
+
f"[+{near_reward:.3f} — {near_line_diff} lines off, adjust line number]"
|
| 484 |
+
)
|
| 485 |
+
reward = near_reward
|
| 486 |
+
|
| 487 |
else:
|
| 488 |
+
# False positive — with flood protection
|
| 489 |
+
self._fp_count += 1
|
| 490 |
+
|
| 491 |
+
base_penalty = _BASE_FP_PENALTY
|
| 492 |
+
reward_breakdown["base_fp"] = base_penalty
|
| 493 |
+
|
| 494 |
+
# Escalating penalty after FP_FLOOD_THRESHOLD FPs
|
| 495 |
+
flood_penalty = 0.0
|
| 496 |
+
if self._fp_count > _FP_FLOOD_THRESHOLD:
|
| 497 |
+
extra = self._fp_count - _FP_FLOOD_THRESHOLD
|
| 498 |
+
flood_penalty = round(-0.02 * extra * _FP_FLOOD_MULTIPLIER, 3)
|
| 499 |
+
reward_breakdown["flood_penalty"] = flood_penalty
|
| 500 |
+
|
| 501 |
+
# High-confidence FP: extra penalty
|
| 502 |
+
confidence_penalty = 0.0
|
| 503 |
+
if action.confidence is not None and action.confidence >= 0.7:
|
| 504 |
+
confidence_penalty = _CONFIDENCE_FP_EXTRA
|
| 505 |
+
reward_breakdown["confidence_penalty"] = confidence_penalty
|
| 506 |
+
|
| 507 |
+
reward = base_penalty + flood_penalty + confidence_penalty
|
| 508 |
+
reward_breakdown["total"] = round(reward, 4)
|
| 509 |
+
|
| 510 |
+
flood_note = f", over-flagging -{abs(flood_penalty):.2f}" if flood_penalty else ""
|
| 511 |
+
conf_note = f", high-confidence penalty {confidence_penalty:.2f}" if confidence_penalty else ""
|
| 512 |
feedback = (
|
| 513 |
+
f"No match at {action.filename}:{action.line_number}. "
|
| 514 |
+
f"[{reward:.2f} — false positive{flood_note}{conf_note}]"
|
| 515 |
)
|
| 516 |
|
| 517 |
+
return reward, feedback, reward_breakdown
|
| 518 |
|
| 519 |
def _handle_clear(self, action: ReviewAction):
|
| 520 |
if action.line_number is None or not action.filename:
|
| 521 |
+
return _VALIDATION_PENALTY, "clear_flag requires 'line_number' and 'filename'.", {"validation_penalty": _VALIDATION_PENALTY}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
|
| 523 |
+
removed_issue = None
|
| 524 |
+
new_list = []
|
| 525 |
+
for f in self._state.flagged_issues:
|
| 526 |
+
if f.line_number == action.line_number and f.filename == action.filename:
|
| 527 |
+
removed_issue = f
|
| 528 |
+
else:
|
| 529 |
+
new_list.append(f)
|
| 530 |
+
|
| 531 |
+
if removed_issue is None:
|
| 532 |
return 0.0, (
|
| 533 |
f"No flagged issue found at {action.filename}:{action.line_number}."
|
| 534 |
+
), {"no_op": 0.0}
|
| 535 |
|
| 536 |
+
self._state.flagged_issues = new_list
|
| 537 |
+
|
| 538 |
+
# Check if removed issue was TP
|
|
|
|
|
|
|
|
|
|
| 539 |
was_tp = any(match_issue(removed_issue, gt) for gt in self._ground_truth)
|
| 540 |
|
| 541 |
if was_tp:
|
| 542 |
+
# Un-track it from matched set
|
| 543 |
+
for i, gt in enumerate(self._ground_truth):
|
| 544 |
+
if match_issue(removed_issue, gt):
|
| 545 |
+
self._matched_gt_indices.discard(i)
|
| 546 |
+
break
|
| 547 |
+
reward = _REMOVE_TP_PENALTY
|
| 548 |
+
reward_breakdown = {"removed_tp": reward, "total": reward}
|
| 549 |
feedback = (
|
| 550 |
f"Removed a correct finding at {action.filename}:{action.line_number}. "
|
| 551 |
+
f"[{reward:.2f}]"
|
| 552 |
)
|
| 553 |
else:
|
| 554 |
+
# Removing a FP — decrement counter
|
| 555 |
+
self._fp_count = max(0, self._fp_count - 1)
|
| 556 |
+
reward = _REMOVE_FP_REWARD
|
| 557 |
+
reward_breakdown = {"removed_fp": reward, "total": reward}
|
| 558 |
feedback = (
|
| 559 |
f"Removed a false positive at {action.filename}:{action.line_number}. "
|
| 560 |
+
f"[+{reward:.2f} — good correction]"
|
| 561 |
)
|
| 562 |
|
| 563 |
+
return reward, feedback, reward_breakdown
|
| 564 |
|
| 565 |
def _handle_hint(self):
|
| 566 |
hints = self._task.get("hints", [])
|
| 567 |
+
|
| 568 |
+
adaptive_hint = self._get_adaptive_hint()
|
| 569 |
+
if adaptive_hint:
|
| 570 |
+
return _HINT_COST, f"Hint: {adaptive_hint} ({_HINT_COST} reward)", {"hint_cost": _HINT_COST}
|
| 571 |
+
|
| 572 |
if self._hint_index >= len(hints):
|
| 573 |
+
return _HINT_COST, "No more hints available for this task.", {"hint_cost": _HINT_COST}
|
| 574 |
|
| 575 |
hint = hints[self._hint_index]
|
| 576 |
self._hint_index += 1
|
| 577 |
remaining = len(hints) - self._hint_index
|
| 578 |
+
return _HINT_COST, f"Hint {self._hint_index}/{len(hints)}: {hint} ({remaining} hints left)", {"hint_cost": _HINT_COST}
|
| 579 |
+
|
| 580 |
+
def _get_adaptive_hint(self) -> Optional[str]:
|
| 581 |
+
"""Generate a context-aware hint based on current episode state."""
|
| 582 |
+
flagged = self._state.flagged_issues
|
| 583 |
+
gt = self._ground_truth
|
| 584 |
+
|
| 585 |
+
if not gt:
|
| 586 |
+
return None
|
| 587 |
+
|
| 588 |
+
tp_count = len(self._matched_gt_indices)
|
| 589 |
+
fp_count = len(flagged) - tp_count - sum(
|
| 590 |
+
1 for f in flagged
|
| 591 |
+
if any(match_quality(f, g) == "near" for g in gt)
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
issue_categories = self._code_metadata.get("issue_categories", [])
|
| 595 |
+
|
| 596 |
+
# Many false positives: over-flagging
|
| 597 |
+
if fp_count > tp_count and fp_count >= 2:
|
| 598 |
+
return (
|
| 599 |
+
"You are over-flagging. Focus only on confident, concrete findings. "
|
| 600 |
+
"Consider using clear_flag to remove uncertain flags."
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
# No correct flags at all yet
|
| 604 |
+
if len(flagged) > 0 and tp_count == 0:
|
| 605 |
+
if issue_categories:
|
| 606 |
+
cats = ", ".join(sorted(set(issue_categories)))
|
| 607 |
+
return (
|
| 608 |
+
f"Focus on [{cats}] issues. "
|
| 609 |
+
"None of your current flags match real issues. Re-examine carefully."
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
# Found some but missed whole categories
|
| 613 |
+
if tp_count > 0 and issue_categories:
|
| 614 |
+
found_types: Set[str] = set()
|
| 615 |
+
for i in self._matched_gt_indices:
|
| 616 |
+
found_types.add(gt[i].issue_type)
|
| 617 |
+
missed = sorted(set(issue_categories) - found_types)
|
| 618 |
+
if missed:
|
| 619 |
+
missed_str = ", ".join(missed)
|
| 620 |
+
return (
|
| 621 |
+
f"Good progress! You've found some issues but haven't flagged any "
|
| 622 |
+
f"[{missed_str}] issues yet — look again for those specifically."
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
return None # Fall through to static hints
|
| 626 |
|
| 627 |
def _handle_submit(self):
|
| 628 |
self._state.submitted = True
|
| 629 |
final_score = grade_episode(self._state.flagged_issues, self._ground_truth)
|
| 630 |
self._state.current_score = final_score
|
| 631 |
|
| 632 |
+
tp_count = len(self._matched_gt_indices)
|
|
|
|
|
|
|
|
|
|
| 633 |
total_gt = len(self._ground_truth)
|
| 634 |
total_flagged = len(self._state.flagged_issues)
|
| 635 |
+
fp_count = total_flagged - tp_count
|
| 636 |
+
|
| 637 |
+
# Breakdown for detailed feedback
|
| 638 |
+
detailed = grade_episode_detailed(self._state.flagged_issues, self._ground_truth)
|
| 639 |
|
| 640 |
feedback = (
|
| 641 |
f"Review submitted! Final score: {final_score:.3f}. "
|
| 642 |
+
f"Found {tp_count}/{total_gt} issues. "
|
| 643 |
+
f"Precision: {detailed['precision']:.2f}, Recall: {detailed['recall']:.2f}, "
|
| 644 |
+
f"F1: {detailed['f1']:.2f}. "
|
| 645 |
)
|
| 646 |
+
if fp_count > 0:
|
| 647 |
+
feedback += f"{fp_count} false positive(s). "
|
| 648 |
+
if detailed["false_negatives"] > 0:
|
| 649 |
+
fn = detailed["false_negatives"]
|
| 650 |
+
feedback += f"{fn} issue(s) missed."
|
| 651 |
+
|
| 652 |
+
reward_breakdown = {
|
| 653 |
+
"final_f1": detailed["f1"],
|
| 654 |
+
"severity_accuracy": detailed["severity_accuracy"],
|
| 655 |
+
"final_score": final_score,
|
| 656 |
+
"total": final_score,
|
| 657 |
+
}
|
| 658 |
+
return final_score, feedback, reward_breakdown
|
server/graders.py
CHANGED
|
@@ -1,10 +1,21 @@
|
|
| 1 |
"""
|
| 2 |
Grading logic for the Code Review Environment.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
from __future__ import annotations
|
| 5 |
|
|
|
|
|
|
|
| 6 |
import re
|
| 7 |
-
from typing import List, Tuple, Set
|
| 8 |
|
| 9 |
import sys
|
| 10 |
import os
|
|
@@ -21,8 +32,18 @@ _TYPE_COMPAT = {
|
|
| 21 |
"performance": {"performance"},
|
| 22 |
}
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
| 26 |
if flagged.filename != gt.filename:
|
| 27 |
return False
|
| 28 |
if abs(flagged.line_number - gt.line_number) > line_tolerance:
|
|
@@ -33,6 +54,274 @@ def match_issue(flagged: Issue, gt: Issue, line_tolerance: int = 2) -> bool:
|
|
| 33 |
return True
|
| 34 |
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def grade_episode(
|
| 37 |
flagged: List[Issue],
|
| 38 |
ground_truth: List[Issue],
|
|
@@ -79,6 +368,105 @@ def grade_episode(
|
|
| 79 |
return round(min(1.0, max(0.0, final)), 4)
|
| 80 |
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
def compute_live_score(flagged: List[Issue], ground_truth: List[Issue]) -> float:
|
| 83 |
"""F1-only score for per-step feedback (no severity bonus)."""
|
| 84 |
if not ground_truth:
|
|
@@ -107,6 +495,7 @@ def compute_live_score(flagged: List[Issue], ground_truth: List[Issue]) -> float
|
|
| 107 |
|
| 108 |
|
| 109 |
_PATTERNS = [
|
|
|
|
| 110 |
(r"range\(len\(\w+\)\s*\+\s*1\)", None, "bug", "high",
|
| 111 |
"Off-by-one error: range(len(x) + 1) iterates one past the end"),
|
| 112 |
(r"left,\s*right\s*=\s*0,\s*len\(", None, "bug", "medium",
|
|
@@ -114,30 +503,81 @@ _PATTERNS = [
|
|
| 114 |
(r"counts\[word\]\s*=\s*0\b", None, "bug", "low",
|
| 115 |
"Counter initialized to 0 instead of 1"),
|
| 116 |
|
|
|
|
| 117 |
(r'SECRET_KEY\s*=\s*["\']', None, "security", "high",
|
| 118 |
"Hardcoded SECRET_KEY in source code"),
|
|
|
|
|
|
|
| 119 |
(r'PASSWORD\s*=\s*["\']', None, "security", "high",
|
| 120 |
"Hardcoded password in source code"),
|
|
|
|
|
|
|
| 121 |
(r"f['\"].*SELECT.*\{", None, "security", "critical",
|
| 122 |
"SQL injection via f-string query construction"),
|
|
|
|
|
|
|
| 123 |
(r"f['\"].*DELETE.*\{", None, "security", "critical",
|
| 124 |
"SQL injection via f-string DELETE query"),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
(r"render_template_string\(f['\"]", None, "security", "high",
|
| 126 |
"XSS: unsanitized user input in render_template_string"),
|
| 127 |
(r"shell\s*=\s*True", None, "security", "critical",
|
| 128 |
"Command injection risk: shell=True with user input"),
|
| 129 |
-
(r"
|
| 130 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
(r"expected\s*==\s*\w+_hash", None, "security", "medium",
|
| 132 |
"Timing attack: use hmac.compare_digest() for constant-time comparison"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
(r"password\s*=\s*models\.CharField", None, "security", "critical",
|
| 134 |
"Plaintext password storage in database"),
|
| 135 |
-
(r"os\.path\.join\(['\"]\/", None, "security", "high",
|
| 136 |
-
"Path traversal: os.path.join with absolute prefix doesn't prevent traversal"),
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
(r"\.objects\.get\(id=item\.", None, "performance", "high",
|
| 139 |
"N+1 query: database lookup inside a loop"),
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
(r"FloatField\(\)", None, "bug", "medium",
|
| 142 |
"FloatField for monetary values causes precision errors, use DecimalField"),
|
| 143 |
(r"BinaryField\(\)", None, "security", "high",
|
|
|
|
| 1 |
"""
|
| 2 |
Grading logic for the Code Review Environment.
|
| 3 |
+
|
| 4 |
+
Reward design is grounded in:
|
| 5 |
+
- Potential-Based Reward Shaping (PBRS): Ng et al. 1999
|
| 6 |
+
R_shaped(s,a,s') = R(s,a,s') + γ·Φ(s') - Φ(s)
|
| 7 |
+
where Φ(s) = (tp_found / total_gt) · POTENTIAL_SCALE
|
| 8 |
+
- Graduated line-proximity rewards: exponential decay over line distance
|
| 9 |
+
reward = BASE_TP · exp(-DECAY · max(0, line_diff - EXACT_TOLERANCE))
|
| 10 |
+
for 0 < line_diff ≤ NEAR_TOLERANCE
|
| 11 |
+
- F1-based terminal scoring: 0.70·F1 + 0.30·severity_accuracy
|
| 12 |
"""
|
| 13 |
from __future__ import annotations
|
| 14 |
|
| 15 |
+
import ast
|
| 16 |
+
import math
|
| 17 |
import re
|
| 18 |
+
from typing import List, Tuple, Set, Dict, Optional
|
| 19 |
|
| 20 |
import sys
|
| 21 |
import os
|
|
|
|
| 32 |
"performance": {"performance"},
|
| 33 |
}
|
| 34 |
|
| 35 |
+
# Tolerances
|
| 36 |
+
NEAR_TOLERANCE = 5
|
| 37 |
+
EXACT_TOLERANCE = 2
|
| 38 |
+
|
| 39 |
+
# Graduated reward constants (PBRS + smooth near-miss)
|
| 40 |
+
BASE_TP_REWARD = 0.10
|
| 41 |
+
NEAR_DECAY = 0.6 # exponential decay per line beyond EXACT_TOLERANCE
|
| 42 |
+
POTENTIAL_SCALE = 0.5 # Φ(s) = (tp/total_gt) * POTENTIAL_SCALE
|
| 43 |
|
| 44 |
+
|
| 45 |
+
def match_issue(flagged: Issue, gt: Issue, line_tolerance: int = EXACT_TOLERANCE, near_tolerance: int = NEAR_TOLERANCE) -> bool:
|
| 46 |
+
"""Return True if flagged matches gt within line_tolerance lines and same type."""
|
| 47 |
if flagged.filename != gt.filename:
|
| 48 |
return False
|
| 49 |
if abs(flagged.line_number - gt.line_number) > line_tolerance:
|
|
|
|
| 54 |
return True
|
| 55 |
|
| 56 |
|
| 57 |
+
def match_quality(flagged: Issue, gt: Issue) -> str:
|
| 58 |
+
"""
|
| 59 |
+
Return quality of match between flagged and gt:
|
| 60 |
+
"exact" — within ±2 lines and right issue type
|
| 61 |
+
"near" — within ±3-5 lines and same file (regardless of type)
|
| 62 |
+
"none" — no meaningful match
|
| 63 |
+
"""
|
| 64 |
+
if flagged.filename != gt.filename:
|
| 65 |
+
return "none"
|
| 66 |
+
|
| 67 |
+
line_diff = abs(flagged.line_number - gt.line_number)
|
| 68 |
+
|
| 69 |
+
if line_diff <= EXACT_TOLERANCE:
|
| 70 |
+
compat = _TYPE_COMPAT.get(gt.issue_type, {gt.issue_type})
|
| 71 |
+
if flagged.issue_type in compat:
|
| 72 |
+
return "exact"
|
| 73 |
+
|
| 74 |
+
if line_diff <= NEAR_TOLERANCE:
|
| 75 |
+
return "near"
|
| 76 |
+
|
| 77 |
+
return "none"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def graduated_near_reward(line_diff: int) -> float:
|
| 81 |
+
"""
|
| 82 |
+
Graduated reward for near-miss flags using exponential decay.
|
| 83 |
+
|
| 84 |
+
Implements continuous reward shaping based on proximity:
|
| 85 |
+
line_diff = 0-2 → 0.10 (full TP, handled separately)
|
| 86 |
+
line_diff = 3 → 0.10 * exp(-0.6*1) ≈ 0.055
|
| 87 |
+
line_diff = 4 → 0.10 * exp(-0.6*2) ≈ 0.033
|
| 88 |
+
line_diff = 5 → 0.10 * exp(-0.6*3) ≈ 0.020
|
| 89 |
+
|
| 90 |
+
This gives smooth gradient signal rather than a hard 0.03 step function,
|
| 91 |
+
encouraging the agent to refine line numbers progressively.
|
| 92 |
+
"""
|
| 93 |
+
if line_diff <= EXACT_TOLERANCE:
|
| 94 |
+
return BASE_TP_REWARD
|
| 95 |
+
extra = line_diff - EXACT_TOLERANCE
|
| 96 |
+
return round(BASE_TP_REWARD * math.exp(-NEAR_DECAY * extra), 4)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def compute_potential(tp_count: int, total_gt: int) -> float:
|
| 100 |
+
"""
|
| 101 |
+
Potential function Φ(s) for Potential-Based Reward Shaping (PBRS).
|
| 102 |
+
|
| 103 |
+
Φ(s) = (tp_found / total_gt) * POTENTIAL_SCALE
|
| 104 |
+
|
| 105 |
+
The shaped reward R_shaped = r + Φ(s') - Φ(s) ensures policy invariance
|
| 106 |
+
(Ng et al. 1999): the optimal policy under shaped rewards is the same as
|
| 107 |
+
under the original rewards, but with better intermediate gradient signal.
|
| 108 |
+
|
| 109 |
+
Here we compute just Φ(s); the caller computes Φ(s') - Φ(s).
|
| 110 |
+
"""
|
| 111 |
+
if total_gt <= 0:
|
| 112 |
+
return 0.0
|
| 113 |
+
return (tp_count / total_gt) * POTENTIAL_SCALE
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def compute_function_map(code: str) -> Dict[int, str]:
|
| 117 |
+
"""
|
| 118 |
+
Map each line number to the name of its enclosing function (or class method).
|
| 119 |
+
Lines outside any function map to "module". Non-parseable code returns empty dict.
|
| 120 |
+
"""
|
| 121 |
+
result: Dict[int, str] = {}
|
| 122 |
+
try:
|
| 123 |
+
tree = ast.parse(code)
|
| 124 |
+
for node in ast.walk(tree):
|
| 125 |
+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
| 126 |
+
end = getattr(node, "end_lineno", node.lineno)
|
| 127 |
+
for lineno in range(node.lineno, end + 1):
|
| 128 |
+
result[lineno] = node.name
|
| 129 |
+
except SyntaxError:
|
| 130 |
+
pass
|
| 131 |
+
return result
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def compute_code_metadata(code_files: Dict[str, str], issue_categories: Optional[List[str]] = None) -> Dict:
|
| 135 |
+
"""
|
| 136 |
+
Extract code structure metadata using Python's ast module.
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
total_lines, num_functions, function_names, num_classes, class_names,
|
| 140 |
+
imports, complexity_estimate, issue_categories, function_ranges
|
| 141 |
+
"""
|
| 142 |
+
total_lines = 0
|
| 143 |
+
num_functions = 0
|
| 144 |
+
function_names: List[str] = []
|
| 145 |
+
num_classes = 0
|
| 146 |
+
class_names: List[str] = []
|
| 147 |
+
imports: List[str] = []
|
| 148 |
+
branch_count = 0
|
| 149 |
+
function_ranges: List[Dict] = [] # [{name, file, start, end}]
|
| 150 |
+
|
| 151 |
+
for filename, code in code_files.items():
|
| 152 |
+
lines = code.splitlines()
|
| 153 |
+
total_lines += len(lines)
|
| 154 |
+
try:
|
| 155 |
+
tree = ast.parse(code)
|
| 156 |
+
for node in ast.walk(tree):
|
| 157 |
+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
| 158 |
+
num_functions += 1
|
| 159 |
+
function_names.append(node.name)
|
| 160 |
+
end = getattr(node, "end_lineno", node.lineno)
|
| 161 |
+
function_ranges.append({
|
| 162 |
+
"name": node.name,
|
| 163 |
+
"file": filename,
|
| 164 |
+
"start": node.lineno,
|
| 165 |
+
"end": end,
|
| 166 |
+
})
|
| 167 |
+
elif isinstance(node, ast.ClassDef):
|
| 168 |
+
num_classes += 1
|
| 169 |
+
class_names.append(node.name)
|
| 170 |
+
elif isinstance(node, ast.Import):
|
| 171 |
+
for alias in node.names:
|
| 172 |
+
imports.append(alias.name.split(".")[0])
|
| 173 |
+
elif isinstance(node, ast.ImportFrom):
|
| 174 |
+
if node.module:
|
| 175 |
+
imports.append(node.module.split(".")[0])
|
| 176 |
+
elif isinstance(node, (ast.If, ast.For, ast.While, ast.Try,
|
| 177 |
+
ast.ExceptHandler, ast.With)):
|
| 178 |
+
branch_count += 1
|
| 179 |
+
except SyntaxError:
|
| 180 |
+
# If ast can't parse (e.g. non-Python file), just count lines
|
| 181 |
+
pass
|
| 182 |
+
|
| 183 |
+
# Deduplicate imports
|
| 184 |
+
imports = list(dict.fromkeys(imports))
|
| 185 |
+
|
| 186 |
+
# Complexity estimate
|
| 187 |
+
if branch_count <= 5:
|
| 188 |
+
complexity_estimate = "low"
|
| 189 |
+
elif branch_count <= 15:
|
| 190 |
+
complexity_estimate = "medium"
|
| 191 |
+
else:
|
| 192 |
+
complexity_estimate = "high"
|
| 193 |
+
|
| 194 |
+
return {
|
| 195 |
+
"total_lines": total_lines,
|
| 196 |
+
"num_functions": num_functions,
|
| 197 |
+
"function_names": function_names,
|
| 198 |
+
"num_classes": num_classes,
|
| 199 |
+
"class_names": class_names,
|
| 200 |
+
"imports": imports,
|
| 201 |
+
"complexity_estimate": complexity_estimate,
|
| 202 |
+
"issue_categories": list(set(issue_categories)) if issue_categories else [],
|
| 203 |
+
"function_ranges": function_ranges,
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def compute_code_state_features(
|
| 208 |
+
code_metadata: Dict,
|
| 209 |
+
progress: Optional[Dict] = None,
|
| 210 |
+
) -> List[float]:
|
| 211 |
+
"""
|
| 212 |
+
Compute a normalized 12-dimensional feature vector for RL training.
|
| 213 |
+
|
| 214 |
+
Based on state representation research (code2vec, GraphCodeBERT, 2023-2024),
|
| 215 |
+
combining AST-derived structural features with episode progress metrics.
|
| 216 |
+
This vector is suitable as input to a policy network or value estimator.
|
| 217 |
+
|
| 218 |
+
Dimensions:
|
| 219 |
+
0: total_lines / 200 — code size (normalized)
|
| 220 |
+
1: num_functions / 20 — function count
|
| 221 |
+
2: num_classes / 10 — class count
|
| 222 |
+
3: complexity_score — 0=low, 0.5=medium, 1.0=high
|
| 223 |
+
4: has_bug_issues — 1 if "bug" in issue_categories
|
| 224 |
+
5: has_security_issues — 1 if "security" in issue_categories
|
| 225 |
+
6: has_performance_issues — 1 if "performance" in issue_categories
|
| 226 |
+
7: has_logic_issues — 1 if "logic" in issue_categories
|
| 227 |
+
8: progress_recall — tp / total_gt (0 if no progress yet)
|
| 228 |
+
9: progress_precision — precision so far
|
| 229 |
+
10: steps_used_frac — steps_used / max_steps
|
| 230 |
+
11: fp_pressure — false_positives / max(total_flagged, 1)
|
| 231 |
+
"""
|
| 232 |
+
if progress is None:
|
| 233 |
+
progress = {}
|
| 234 |
+
|
| 235 |
+
complexity_map = {"low": 0.0, "medium": 0.5, "high": 1.0}
|
| 236 |
+
cats = set(code_metadata.get("issue_categories", []))
|
| 237 |
+
|
| 238 |
+
total_gt = float(progress.get("total_ground_truth", 1.0)) or 1.0
|
| 239 |
+
tp = float(progress.get("true_positives", 0.0))
|
| 240 |
+
fp = float(progress.get("false_positives", 0.0))
|
| 241 |
+
total_flagged = tp + fp
|
| 242 |
+
steps_used = float(progress.get("steps_used", 0.0))
|
| 243 |
+
steps_rem = float(progress.get("steps_remaining", 1.0))
|
| 244 |
+
max_steps = steps_used + steps_rem or 1.0
|
| 245 |
+
|
| 246 |
+
features = [
|
| 247 |
+
min(1.0, code_metadata.get("total_lines", 0) / 200.0),
|
| 248 |
+
min(1.0, code_metadata.get("num_functions", 0) / 20.0),
|
| 249 |
+
min(1.0, code_metadata.get("num_classes", 0) / 10.0),
|
| 250 |
+
complexity_map.get(code_metadata.get("complexity_estimate", "low"), 0.0),
|
| 251 |
+
1.0 if "bug" in cats else 0.0,
|
| 252 |
+
1.0 if "security" in cats else 0.0,
|
| 253 |
+
1.0 if "performance" in cats else 0.0,
|
| 254 |
+
1.0 if "logic" in cats else 0.0,
|
| 255 |
+
min(1.0, tp / total_gt),
|
| 256 |
+
min(1.0, tp / total_flagged) if total_flagged > 0 else 0.0,
|
| 257 |
+
min(1.0, steps_used / max_steps),
|
| 258 |
+
min(1.0, fp / total_flagged) if total_flagged > 0 else 0.0,
|
| 259 |
+
]
|
| 260 |
+
return [round(f, 4) for f in features]
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class RewardNormalizer:
|
| 264 |
+
"""
|
| 265 |
+
Variable-Length Return Normalizer for multi-task RL training.
|
| 266 |
+
|
| 267 |
+
Based on VL Norm (2025) and Return-based Scaling (2021):
|
| 268 |
+
Normalizes episode returns accounting for variable episode lengths,
|
| 269 |
+
preventing long episodes from dominating gradient computation.
|
| 270 |
+
|
| 271 |
+
Usage:
|
| 272 |
+
normalizer = RewardNormalizer(window_size=100)
|
| 273 |
+
# After each episode:
|
| 274 |
+
normalizer.update(episode_return, episode_length)
|
| 275 |
+
normalized_r = normalizer.normalize(episode_return, episode_length)
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
def __init__(self, window_size: int = 100, eps: float = 1e-8) -> None:
|
| 279 |
+
self.window_size = window_size
|
| 280 |
+
self.eps = eps
|
| 281 |
+
self._returns: List[float] = []
|
| 282 |
+
self._lengths: List[int] = []
|
| 283 |
+
self.mean: float = 0.0
|
| 284 |
+
self.std: float = 1.0
|
| 285 |
+
|
| 286 |
+
def update(self, episode_return: float, episode_length: int) -> None:
|
| 287 |
+
"""Record a completed episode for running statistics."""
|
| 288 |
+
self._returns.append(episode_return)
|
| 289 |
+
self._lengths.append(max(1, episode_length))
|
| 290 |
+
if len(self._returns) > self.window_size:
|
| 291 |
+
self._returns.pop(0)
|
| 292 |
+
self._lengths.pop(0)
|
| 293 |
+
self._recompute()
|
| 294 |
+
|
| 295 |
+
def _recompute(self) -> None:
|
| 296 |
+
if len(self._returns) < 2:
|
| 297 |
+
return
|
| 298 |
+
returns = [r for r in self._returns]
|
| 299 |
+
lengths = [l for l in self._lengths]
|
| 300 |
+
mean_len = sum(lengths) / len(lengths)
|
| 301 |
+
# Length-adjusted std: longer episodes have proportionally less weight
|
| 302 |
+
self.mean = sum(returns) / len(returns)
|
| 303 |
+
raw_std = (sum((r - self.mean) ** 2 for r in returns) / len(returns)) ** 0.5
|
| 304 |
+
length_factors = [(l / mean_len) ** 0.5 for l in lengths]
|
| 305 |
+
avg_lf = sum(length_factors) / len(length_factors)
|
| 306 |
+
self.std = max(self.eps, raw_std * avg_lf)
|
| 307 |
+
|
| 308 |
+
def normalize(self, episode_return: float, episode_length: int) -> float:
|
| 309 |
+
"""Return the length-adjusted normalized return."""
|
| 310 |
+
if len(self._returns) < 2:
|
| 311 |
+
return episode_return
|
| 312 |
+
mean_len = sum(self._lengths) / len(self._lengths)
|
| 313 |
+
length_factor = (max(1, episode_length) / mean_len) ** 0.5
|
| 314 |
+
return round((episode_return - self.mean) / (self.std * length_factor + self.eps), 4)
|
| 315 |
+
|
| 316 |
+
def to_dict(self) -> Dict:
|
| 317 |
+
return {
|
| 318 |
+
"mean": round(self.mean, 4),
|
| 319 |
+
"std": round(self.std, 4),
|
| 320 |
+
"n_episodes": len(self._returns),
|
| 321 |
+
"window_size": self.window_size,
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
|
| 325 |
def grade_episode(
|
| 326 |
flagged: List[Issue],
|
| 327 |
ground_truth: List[Issue],
|
|
|
|
| 368 |
return round(min(1.0, max(0.0, final)), 4)
|
| 369 |
|
| 370 |
|
| 371 |
+
def grade_episode_detailed(
|
| 372 |
+
flagged: List[Issue],
|
| 373 |
+
ground_truth: List[Issue],
|
| 374 |
+
line_tolerance: int = 2,
|
| 375 |
+
) -> Dict:
|
| 376 |
+
"""
|
| 377 |
+
Full breakdown of grading results.
|
| 378 |
+
|
| 379 |
+
Returns:
|
| 380 |
+
score, f1, precision, recall, severity_accuracy,
|
| 381 |
+
true_positives, false_positives, false_negatives,
|
| 382 |
+
near_misses, per_file
|
| 383 |
+
"""
|
| 384 |
+
if not ground_truth:
|
| 385 |
+
score = 1.0 if not flagged else 0.0
|
| 386 |
+
return {
|
| 387 |
+
"score": score,
|
| 388 |
+
"f1": score,
|
| 389 |
+
"precision": score,
|
| 390 |
+
"recall": score,
|
| 391 |
+
"severity_accuracy": score,
|
| 392 |
+
"true_positives": 0,
|
| 393 |
+
"false_positives": len(flagged),
|
| 394 |
+
"false_negatives": 0,
|
| 395 |
+
"near_misses": 0,
|
| 396 |
+
"per_file": {},
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
tp = 0
|
| 400 |
+
fp = 0
|
| 401 |
+
near_misses = 0
|
| 402 |
+
matched_gt_indices: Set[int] = set()
|
| 403 |
+
severity_scores: List[float] = []
|
| 404 |
+
per_file: Dict[str, Dict] = {}
|
| 405 |
+
|
| 406 |
+
for flag in flagged:
|
| 407 |
+
fname = flag.filename
|
| 408 |
+
if fname not in per_file:
|
| 409 |
+
per_file[fname] = {"tp": 0, "fp": 0, "near_miss": 0}
|
| 410 |
+
|
| 411 |
+
matched = False
|
| 412 |
+
for i, gt in enumerate(ground_truth):
|
| 413 |
+
if i in matched_gt_indices:
|
| 414 |
+
continue
|
| 415 |
+
if match_issue(flag, gt, line_tolerance):
|
| 416 |
+
tp += 1
|
| 417 |
+
matched_gt_indices.add(i)
|
| 418 |
+
matched = True
|
| 419 |
+
per_file[fname]["tp"] += 1
|
| 420 |
+
flag_rank = _SEV_RANK.get(flag.severity, 1)
|
| 421 |
+
gt_rank = _SEV_RANK.get(gt.severity, 1)
|
| 422 |
+
distance = abs(flag_rank - gt_rank)
|
| 423 |
+
severity_scores.append(max(0.0, 1.0 - distance * 0.34))
|
| 424 |
+
break
|
| 425 |
+
|
| 426 |
+
if not matched:
|
| 427 |
+
# Check for near miss (3-5 lines off, same file)
|
| 428 |
+
is_near = False
|
| 429 |
+
for i, gt in enumerate(ground_truth):
|
| 430 |
+
if i in matched_gt_indices:
|
| 431 |
+
continue
|
| 432 |
+
q = match_quality(flag, gt)
|
| 433 |
+
if q == "near":
|
| 434 |
+
is_near = True
|
| 435 |
+
break
|
| 436 |
+
if is_near:
|
| 437 |
+
near_misses += 1
|
| 438 |
+
per_file[fname]["near_miss"] += 1
|
| 439 |
+
else:
|
| 440 |
+
fp += 1
|
| 441 |
+
per_file[fname]["fp"] += 1
|
| 442 |
+
|
| 443 |
+
fn = len(ground_truth) - len(matched_gt_indices)
|
| 444 |
+
|
| 445 |
+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
| 446 |
+
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
| 447 |
+
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
|
| 448 |
+
|
| 449 |
+
if severity_scores:
|
| 450 |
+
severity_accuracy = sum(severity_scores) / len(ground_truth)
|
| 451 |
+
else:
|
| 452 |
+
severity_accuracy = 0.0
|
| 453 |
+
|
| 454 |
+
score = round(min(1.0, max(0.0, 0.70 * f1 + 0.30 * severity_accuracy)), 4)
|
| 455 |
+
|
| 456 |
+
return {
|
| 457 |
+
"score": score,
|
| 458 |
+
"f1": round(f1, 4),
|
| 459 |
+
"precision": round(precision, 4),
|
| 460 |
+
"recall": round(recall, 4),
|
| 461 |
+
"severity_accuracy": round(severity_accuracy, 4),
|
| 462 |
+
"true_positives": tp,
|
| 463 |
+
"false_positives": fp,
|
| 464 |
+
"false_negatives": fn,
|
| 465 |
+
"near_misses": near_misses,
|
| 466 |
+
"per_file": per_file,
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
|
| 470 |
def compute_live_score(flagged: List[Issue], ground_truth: List[Issue]) -> float:
|
| 471 |
"""F1-only score for per-step feedback (no severity bonus)."""
|
| 472 |
if not ground_truth:
|
|
|
|
| 495 |
|
| 496 |
|
| 497 |
_PATTERNS = [
|
| 498 |
+
# --- Bug patterns ---
|
| 499 |
(r"range\(len\(\w+\)\s*\+\s*1\)", None, "bug", "high",
|
| 500 |
"Off-by-one error: range(len(x) + 1) iterates one past the end"),
|
| 501 |
(r"left,\s*right\s*=\s*0,\s*len\(", None, "bug", "medium",
|
|
|
|
| 503 |
(r"counts\[word\]\s*=\s*0\b", None, "bug", "low",
|
| 504 |
"Counter initialized to 0 instead of 1"),
|
| 505 |
|
| 506 |
+
# --- Hardcoded secrets ---
|
| 507 |
(r'SECRET_KEY\s*=\s*["\']', None, "security", "high",
|
| 508 |
"Hardcoded SECRET_KEY in source code"),
|
| 509 |
+
(r'ADMIN_TOKEN\s*=\s*["\']', None, "security", "high",
|
| 510 |
+
"Hardcoded ADMIN_TOKEN in source code"),
|
| 511 |
(r'PASSWORD\s*=\s*["\']', None, "security", "high",
|
| 512 |
"Hardcoded password in source code"),
|
| 513 |
+
|
| 514 |
+
# --- Injection attacks ---
|
| 515 |
(r"f['\"].*SELECT.*\{", None, "security", "critical",
|
| 516 |
"SQL injection via f-string query construction"),
|
| 517 |
+
(r"f['\"].*INSERT.*\{", None, "security", "critical",
|
| 518 |
+
"SQL injection via f-string INSERT query"),
|
| 519 |
(r"f['\"].*DELETE.*\{", None, "security", "critical",
|
| 520 |
"SQL injection via f-string DELETE query"),
|
| 521 |
+
(r"f['\"].*LIKE.*%\{", None, "security", "critical",
|
| 522 |
+
"SQL injection via f-string LIKE clause"),
|
| 523 |
+
(r"LIMIT\s*\{", None, "security", "critical",
|
| 524 |
+
"SQL injection: LIMIT clause uses unparameterized variable"),
|
| 525 |
(r"render_template_string\(f['\"]", None, "security", "high",
|
| 526 |
"XSS: unsanitized user input in render_template_string"),
|
| 527 |
(r"shell\s*=\s*True", None, "security", "critical",
|
| 528 |
"Command injection risk: shell=True with user input"),
|
| 529 |
+
(r"os\.system\(", None, "security", "critical",
|
| 530 |
+
"Command injection risk: os.system() executes shell commands"),
|
| 531 |
+
(r"os\.path\.join\(['\"]\/", None, "security", "high",
|
| 532 |
+
"Path traversal: os.path.join with absolute prefix doesn't prevent traversal"),
|
| 533 |
+
|
| 534 |
+
# --- Broken cryptography ---
|
| 535 |
+
(r"hashlib\.md5\(", None, "security", "high",
|
| 536 |
+
"MD5 is cryptographically broken for security use; use SHA-256 or bcrypt"),
|
| 537 |
+
(r"hashlib\.sha1\(", None, "security", "medium",
|
| 538 |
+
"SHA-1 is deprecated for security use; use SHA-256 or better"),
|
| 539 |
(r"expected\s*==\s*\w+_hash", None, "security", "medium",
|
| 540 |
"Timing attack: use hmac.compare_digest() for constant-time comparison"),
|
| 541 |
+
|
| 542 |
+
# --- Dangerous deserialization ---
|
| 543 |
+
(r"pickle\.loads\(", None, "security", "critical",
|
| 544 |
+
"Unsafe deserialization: pickle.loads() on untrusted data allows remote code execution"),
|
| 545 |
+
(r"yaml\.load\(", None, "security", "high",
|
| 546 |
+
"Unsafe YAML deserialization: use yaml.safe_load() instead"),
|
| 547 |
+
|
| 548 |
+
# --- Auth / access control ---
|
| 549 |
(r"password\s*=\s*models\.CharField", None, "security", "critical",
|
| 550 |
"Plaintext password storage in database"),
|
|
|
|
|
|
|
| 551 |
|
| 552 |
+
# --- Async / concurrency bugs ---
|
| 553 |
+
(r"aiohttp\.ClientSession\(\)", None, "bug", "high",
|
| 554 |
+
"ClientSession created outside 'async with' — may not be closed (resource leak)"),
|
| 555 |
+
(r"timeout\s*=\s*\d+\b", None, "bug", "medium",
|
| 556 |
+
"aiohttp timeout should be aiohttp.ClientTimeout(total=N), not a bare integer"),
|
| 557 |
+
(r"attempt\s*==\s*retries\b", None, "bug", "high",
|
| 558 |
+
"Off-by-one: range(retries) yields 0..retries-1, so attempt==retries is never true"),
|
| 559 |
+
(r"for\s+\w+\s+in\s+\w+_ids\s*:", None, "performance", "high",
|
| 560 |
+
"Sequential loop over IDs — consider asyncio.gather() for concurrent fetching"),
|
| 561 |
+
|
| 562 |
+
# --- Performance ---
|
| 563 |
(r"\.objects\.get\(id=item\.", None, "performance", "high",
|
| 564 |
"N+1 query: database lookup inside a loop"),
|
| 565 |
|
| 566 |
+
# --- JavaScript-specific patterns ---
|
| 567 |
+
(r"new\s+Function\(", None, "security", "critical",
|
| 568 |
+
"Unsafe dynamic code execution: new Function() with user input is equivalent to eval()"),
|
| 569 |
+
(r"\beval\(", None, "security", "critical",
|
| 570 |
+
"eval() with user-supplied input allows arbitrary code execution"),
|
| 571 |
+
(r"execSync\(", None, "security", "critical",
|
| 572 |
+
"Command injection risk: execSync() with user-supplied data"),
|
| 573 |
+
(r"jwt\.sign\(.*\{(?!.*expiresIn)", None, "security", "medium",
|
| 574 |
+
"JWT issued without expiry (expiresIn) — tokens are valid forever"),
|
| 575 |
+
(r"JWT_SECRET\s*=\s*['\"]", None, "security", "high",
|
| 576 |
+
"Hardcoded JWT secret in source code"),
|
| 577 |
+
(r"res\.send\(`.*\$\{", None, "security", "high",
|
| 578 |
+
"XSS: template literal with user input sent directly in response"),
|
| 579 |
+
|
| 580 |
+
# --- Data model bugs ---
|
| 581 |
(r"FloatField\(\)", None, "bug", "medium",
|
| 582 |
"FloatField for monetary values causes precision errors, use DecimalField"),
|
| 583 |
(r"BinaryField\(\)", None, "security", "high",
|
tasks/data.py
CHANGED
|
@@ -418,10 +418,533 @@ TASK_COMPREHENSIVE: Dict[str, Any] = {
|
|
| 418 |
}
|
| 419 |
|
| 420 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
ALL_TASKS: Dict[str, Dict[str, Any]] = {
|
| 422 |
TASK_BUG_DETECTION["task_id"]: TASK_BUG_DETECTION,
|
| 423 |
TASK_SECURITY_AUDIT["task_id"]: TASK_SECURITY_AUDIT,
|
| 424 |
TASK_COMPREHENSIVE["task_id"]: TASK_COMPREHENSIVE,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
}
|
| 426 |
|
| 427 |
TASK_IDS: List[str] = list(ALL_TASKS.keys())
|
|
|
|
| 418 |
}
|
| 419 |
|
| 420 |
|
| 421 |
+
_ASYNC_CODE = """\
|
| 422 |
+
import asyncio
|
| 423 |
+
import aiohttp
|
| 424 |
+
from typing import List, Optional
|
| 425 |
+
|
| 426 |
+
_cache: dict = {}
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
async def fetch_json(url: str, session: aiohttp.ClientSession) -> dict:
|
| 430 |
+
async with session.get(url, timeout=5) as resp:
|
| 431 |
+
return await resp.json()
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
async def get_user(user_id: int, session: aiohttp.ClientSession) -> dict:
|
| 435 |
+
if user_id in _cache:
|
| 436 |
+
return _cache[user_id]
|
| 437 |
+
data = await fetch_json(f"https://api.example.com/users/{user_id}", session)
|
| 438 |
+
_cache[user_id] = data
|
| 439 |
+
return data
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
async def process_users(user_ids: List[int]) -> List[dict]:
|
| 443 |
+
session = aiohttp.ClientSession()
|
| 444 |
+
results = []
|
| 445 |
+
for uid in user_ids:
|
| 446 |
+
result = await get_user(uid, session)
|
| 447 |
+
results.append(result)
|
| 448 |
+
return results
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
async def run_with_retry(url: str, retries: int = 3) -> Optional[str]:
|
| 452 |
+
for attempt in range(retries):
|
| 453 |
+
try:
|
| 454 |
+
async with aiohttp.ClientSession() as session:
|
| 455 |
+
async with session.get(url) as resp:
|
| 456 |
+
return await resp.text()
|
| 457 |
+
except Exception:
|
| 458 |
+
if attempt == retries:
|
| 459 |
+
raise
|
| 460 |
+
return None
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class TaskRunner:
|
| 464 |
+
def __init__(self, concurrency: int = 5):
|
| 465 |
+
self.concurrency = concurrency
|
| 466 |
+
self.results = []
|
| 467 |
+
|
| 468 |
+
async def run_all(self, tasks: List) -> List:
|
| 469 |
+
for task in tasks:
|
| 470 |
+
result = await task
|
| 471 |
+
self.results.append(result)
|
| 472 |
+
return self.results
|
| 473 |
+
"""
|
| 474 |
+
|
| 475 |
+
TASK_ASYNC_REVIEW: Dict[str, Any] = {
|
| 476 |
+
"task_id": "async-review",
|
| 477 |
+
"difficulty": "medium-hard",
|
| 478 |
+
"description": (
|
| 479 |
+
"Review this async Python module for concurrency bugs, resource leaks,\n"
|
| 480 |
+
"and performance issues with asyncio and aiohttp.\n"
|
| 481 |
+
"The code has subtle async-specific bugs that would cause failures or\n"
|
| 482 |
+
"degraded performance in production. Identify all issues with exact\n"
|
| 483 |
+
"line numbers, types, and severity.\n\n"
|
| 484 |
+
"File to review: async.py"
|
| 485 |
+
),
|
| 486 |
+
"language": "python",
|
| 487 |
+
"code_files": {
|
| 488 |
+
"async.py": _ASYNC_CODE,
|
| 489 |
+
},
|
| 490 |
+
"ground_truth_issues": [
|
| 491 |
+
_issue(
|
| 492 |
+
5, "async.py", "bug", "high",
|
| 493 |
+
"Shared mutable dict without asyncio.Lock; concurrent coroutines can read "
|
| 494 |
+
"stale data or overwrite each other's writes. Use async with _lock: around "
|
| 495 |
+
"cache check and write.",
|
| 496 |
+
"Add _lock = asyncio.Lock() and use: async with _lock: around cache check and write."
|
| 497 |
+
),
|
| 498 |
+
_issue(
|
| 499 |
+
9, "async.py", "bug", "medium",
|
| 500 |
+
"timeout=5 is wrong type for aiohttp; requires aiohttp.ClientTimeout(total=5). "
|
| 501 |
+
"Passing an int raises TypeError at runtime.",
|
| 502 |
+
"Use: timeout=aiohttp.ClientTimeout(total=5)"
|
| 503 |
+
),
|
| 504 |
+
_issue(
|
| 505 |
+
22, "async.py", "bug", "high",
|
| 506 |
+
"ClientSession created but never closed, causing resource leak. "
|
| 507 |
+
"Use: async with aiohttp.ClientSession() as session: and pass it in.",
|
| 508 |
+
"Replace with: async with aiohttp.ClientSession() as session:"
|
| 509 |
+
),
|
| 510 |
+
_issue(
|
| 511 |
+
24, "async.py", "performance", "high",
|
| 512 |
+
"Sequential for loop with await serializes all requests. "
|
| 513 |
+
"Use asyncio.gather(*[get_user(uid, session) for uid in user_ids]) "
|
| 514 |
+
"for true concurrency.",
|
| 515 |
+
"Replace loop with: results = await asyncio.gather(*[get_user(uid, session) for uid in user_ids])"
|
| 516 |
+
),
|
| 517 |
+
_issue(
|
| 518 |
+
37, "async.py", "bug", "high",
|
| 519 |
+
"Off-by-one: range(retries) yields 0..retries-1, so attempt==retries is never true. "
|
| 520 |
+
"Exception is never re-raised. Fix: attempt == retries - 1.",
|
| 521 |
+
"Change: if attempt == retries - 1: raise"
|
| 522 |
+
),
|
| 523 |
+
_issue(
|
| 524 |
+
48, "async.py", "performance", "medium",
|
| 525 |
+
"Tasks awaited sequentially instead of concurrently. "
|
| 526 |
+
"Use asyncio.gather(*tasks). Also self.results accumulates across multiple run_all calls.",
|
| 527 |
+
"Replace loop with: self.results.extend(await asyncio.gather(*tasks))"
|
| 528 |
+
),
|
| 529 |
+
],
|
| 530 |
+
"max_steps": 20,
|
| 531 |
+
"hints": [
|
| 532 |
+
"Check all places where ClientSession is created — are they properly closed?",
|
| 533 |
+
"Look for sequential awaits inside loops where gather() would be more appropriate.",
|
| 534 |
+
"The retry function has an off-by-one error in its condition.",
|
| 535 |
+
],
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
_PIPELINE_CODE = """\
|
| 540 |
+
import csv
|
| 541 |
+
import json
|
| 542 |
+
import hashlib
|
| 543 |
+
import sqlite3
|
| 544 |
+
from typing import List, Dict, Optional
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def init_db(path: str) -> sqlite3.Connection:
|
| 548 |
+
conn = sqlite3.connect(path)
|
| 549 |
+
conn.execute(
|
| 550 |
+
"CREATE TABLE IF NOT EXISTS records "
|
| 551 |
+
"(id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL, "
|
| 552 |
+
"email TEXT NOT NULL, password_hash TEXT, score REAL DEFAULT 0)"
|
| 553 |
+
)
|
| 554 |
+
conn.commit()
|
| 555 |
+
return conn
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def hash_password(password: str) -> str:
|
| 559 |
+
return hashlib.md5(password.encode()).hexdigest()
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
def insert_record(conn: sqlite3.Connection, username: str,
|
| 563 |
+
email: str, password: str, score: float) -> None:
|
| 564 |
+
pwd = hash_password(password)
|
| 565 |
+
conn.execute(
|
| 566 |
+
f"INSERT INTO records (username, email, password_hash, score) "
|
| 567 |
+
f"VALUES ('{username}', '{email}', '{pwd}', {score})"
|
| 568 |
+
)
|
| 569 |
+
conn.commit()
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
def search_records(conn: sqlite3.Connection, query: str) -> List[Dict]:
|
| 573 |
+
cursor = conn.execute(
|
| 574 |
+
f"SELECT id, username, email, score FROM records WHERE username LIKE '%{query}%'"
|
| 575 |
+
)
|
| 576 |
+
cols = [d[0] for d in cursor.description]
|
| 577 |
+
return [dict(zip(cols, row)) for row in cursor.fetchall()]
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def bulk_load(conn: sqlite3.Connection, filepath: str) -> int:
|
| 581 |
+
count = 0
|
| 582 |
+
with open(filepath, newline='') as f:
|
| 583 |
+
for row in csv.DictReader(f):
|
| 584 |
+
insert_record(conn, row['username'], row['email'],
|
| 585 |
+
row.get('password', ''), float(row.get('score', 0)))
|
| 586 |
+
count += 1
|
| 587 |
+
return count
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def export_records(conn: sqlite3.Connection, out_path: str) -> None:
|
| 591 |
+
rows = search_records(conn, '')
|
| 592 |
+
with open(out_path, 'w') as f:
|
| 593 |
+
json.dump(rows, f, indent=2)
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def get_top_scores(conn: sqlite3.Connection, limit: int) -> List[Dict]:
|
| 597 |
+
cursor = conn.execute(
|
| 598 |
+
f"SELECT username, score FROM records ORDER BY score DESC LIMIT {limit}"
|
| 599 |
+
)
|
| 600 |
+
return [{'username': r[0], 'score': r[1]} for r in cursor.fetchall()]
|
| 601 |
+
"""
|
| 602 |
+
|
| 603 |
+
TASK_DATA_PIPELINE: Dict[str, Any] = {
|
| 604 |
+
"task_id": "data-pipeline",
|
| 605 |
+
"difficulty": "hard",
|
| 606 |
+
"description": (
|
| 607 |
+
"Perform a security and correctness review of this data pipeline module.\n"
|
| 608 |
+
"The module handles user records in SQLite. It contains multiple critical\n"
|
| 609 |
+
"security vulnerabilities, a performance issue, and an error handling gap.\n"
|
| 610 |
+
"Find ALL issues across the file.\n\n"
|
| 611 |
+
"File to review: pipeline.py"
|
| 612 |
+
),
|
| 613 |
+
"language": "python",
|
| 614 |
+
"code_files": {
|
| 615 |
+
"pipeline.py": _PIPELINE_CODE,
|
| 616 |
+
},
|
| 617 |
+
"ground_truth_issues": [
|
| 618 |
+
_issue(
|
| 619 |
+
20, "pipeline.py", "security", "high",
|
| 620 |
+
"MD5 is cryptographically broken for password hashing. "
|
| 621 |
+
"Use bcrypt, argon2, or hashlib.pbkdf2_hmac instead.",
|
| 622 |
+
"Use: hashlib.pbkdf2_hmac('sha256', password.encode(), salt, 100000)"
|
| 623 |
+
),
|
| 624 |
+
_issue(
|
| 625 |
+
27, "pipeline.py", "security", "critical",
|
| 626 |
+
"SQL injection: username, email, and pwd interpolated directly into query string. "
|
| 627 |
+
"Use parameterized queries: conn.execute('INSERT INTO records ... VALUES (?,?,?,?)', "
|
| 628 |
+
"(username, email, pwd, score))",
|
| 629 |
+
"Use: conn.execute('INSERT INTO records (username, email, password_hash, score) VALUES (?,?,?,?)', (username, email, pwd, score))"
|
| 630 |
+
),
|
| 631 |
+
_issue(
|
| 632 |
+
35, "pipeline.py", "security", "critical",
|
| 633 |
+
"SQL injection in LIKE clause: user-supplied query interpolated directly. "
|
| 634 |
+
"Use: conn.execute('... WHERE username LIKE ?', (f'%{query}%',))",
|
| 635 |
+
"Use: conn.execute('SELECT ... WHERE username LIKE ?', (f'%{query}%',))"
|
| 636 |
+
),
|
| 637 |
+
_issue(
|
| 638 |
+
41, "pipeline.py", "performance", "high",
|
| 639 |
+
"bulk_load commits one transaction per row via insert_record. "
|
| 640 |
+
"Wrap entire loop in with conn: for a single transaction — 10-100x faster for large imports.",
|
| 641 |
+
"Wrap loop body with: with conn: conn.executemany(...)"
|
| 642 |
+
),
|
| 643 |
+
_issue(
|
| 644 |
+
46, "pipeline.py", "bug", "medium",
|
| 645 |
+
"float() conversion has no error handling. A single malformed score field "
|
| 646 |
+
"crashes the entire import. Wrap in try/except ValueError.",
|
| 647 |
+
"Use: float(row.get('score', 0) or 0) inside try/except ValueError"
|
| 648 |
+
),
|
| 649 |
+
_issue(
|
| 650 |
+
52, "pipeline.py", "security", "high",
|
| 651 |
+
"export_records calls search_records(conn, '') which returns all records including "
|
| 652 |
+
"password_hash field. Strip sensitive fields before export.",
|
| 653 |
+
"Filter out password_hash: rows = [{k: v for k, v in r.items() if k != 'password_hash'} for r in rows]"
|
| 654 |
+
),
|
| 655 |
+
_issue(
|
| 656 |
+
59, "pipeline.py", "security", "critical",
|
| 657 |
+
"SQL injection: limit value interpolated into query. Although limit is an int here, "
|
| 658 |
+
"use parameterized query: conn.execute('... LIMIT ?', (limit,))",
|
| 659 |
+
"Use: conn.execute('SELECT username, score FROM records ORDER BY score DESC LIMIT ?', (limit,))"
|
| 660 |
+
),
|
| 661 |
+
],
|
| 662 |
+
"max_steps": 25,
|
| 663 |
+
"hints": [
|
| 664 |
+
"Look for every place user-supplied values touch a SQL query string — are they parameterized?",
|
| 665 |
+
"The bulk_load function has both a performance issue and an error handling gap.",
|
| 666 |
+
"Check what fields export_records includes in its output — are any sensitive?",
|
| 667 |
+
],
|
| 668 |
+
}
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
_API_SECURITY_CODE = """\
|
| 672 |
+
from fastapi import FastAPI, Depends, HTTPException, Header
|
| 673 |
+
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
| 674 |
+
import jwt
|
| 675 |
+
import hashlib
|
| 676 |
+
import pickle
|
| 677 |
+
import os
|
| 678 |
+
import sqlite3
|
| 679 |
+
|
| 680 |
+
app = FastAPI()
|
| 681 |
+
security = HTTPBasic()
|
| 682 |
+
|
| 683 |
+
SECRET_KEY = "dev-secret-do-not-use-in-prod"
|
| 684 |
+
ADMIN_TOKEN = "admin-hardcoded-token-123"
|
| 685 |
+
|
| 686 |
+
users_db = {
|
| 687 |
+
"admin": hashlib.md5(b"password123").hexdigest(),
|
| 688 |
+
"user": hashlib.md5(b"user123").hexdigest(),
|
| 689 |
+
}
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
@app.post("/login")
|
| 693 |
+
def login(credentials: HTTPBasicCredentials = Depends(security)):
|
| 694 |
+
username = credentials.username
|
| 695 |
+
stored = users_db.get(username, "")
|
| 696 |
+
if stored != hashlib.md5(credentials.password.encode()).hexdigest():
|
| 697 |
+
raise HTTPException(status_code=401, detail="Invalid credentials")
|
| 698 |
+
token = jwt.encode({"user": username, "admin": username == "admin"},
|
| 699 |
+
SECRET_KEY, algorithm="HS256")
|
| 700 |
+
return {"token": token}
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
@app.get("/users/{user_id}")
|
| 704 |
+
def get_user(user_id: str, authorization: str = Header(None)):
|
| 705 |
+
if not authorization:
|
| 706 |
+
raise HTTPException(status_code=401, detail="Missing token")
|
| 707 |
+
payload = jwt.decode(authorization, SECRET_KEY, algorithms=["HS256"])
|
| 708 |
+
conn = sqlite3.connect("app.db")
|
| 709 |
+
cursor = conn.execute(f"SELECT * FROM users WHERE id = '{user_id}'")
|
| 710 |
+
return {"user": cursor.fetchone()}
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
@app.post("/admin/export")
|
| 714 |
+
def admin_export(authorization: str = Header(None)):
|
| 715 |
+
if authorization != ADMIN_TOKEN:
|
| 716 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
| 717 |
+
path = os.environ.get("EXPORT_PATH", "/tmp/export")
|
| 718 |
+
os.system(f"mysqldump mydb > {path}/dump.sql")
|
| 719 |
+
return {"status": "export complete", "path": path}
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
@app.post("/import")
|
| 723 |
+
def import_data(payload: bytes):
|
| 724 |
+
data = pickle.loads(payload)
|
| 725 |
+
return {"records": len(data)}
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
@app.get("/search")
|
| 729 |
+
def search_users(q: str, limit: int = 100):
|
| 730 |
+
conn = sqlite3.connect("app.db")
|
| 731 |
+
rows = conn.execute(
|
| 732 |
+
f"SELECT id, name, email FROM users WHERE name LIKE '%{q}%' LIMIT {limit}"
|
| 733 |
+
).fetchall()
|
| 734 |
+
return {"results": rows}
|
| 735 |
+
"""
|
| 736 |
+
|
| 737 |
+
TASK_API_SECURITY: Dict[str, Any] = {
|
| 738 |
+
"task_id": "api-security",
|
| 739 |
+
"difficulty": "hard",
|
| 740 |
+
"description": (
|
| 741 |
+
"Perform a security audit on this FastAPI REST API.\n"
|
| 742 |
+
"The service handles user authentication and data operations.\n"
|
| 743 |
+
"It contains multiple critical security flaws across authentication,\n"
|
| 744 |
+
"authorization, injection attacks, and cryptography.\n"
|
| 745 |
+
"Find ALL issues with exact line numbers and severity ratings.\n\n"
|
| 746 |
+
"File to review: api.py"
|
| 747 |
+
),
|
| 748 |
+
"language": "python",
|
| 749 |
+
"code_files": {
|
| 750 |
+
"api.py": _API_SECURITY_CODE,
|
| 751 |
+
},
|
| 752 |
+
"ground_truth_issues": [
|
| 753 |
+
_issue(
|
| 754 |
+
12, "api.py", "security", "high",
|
| 755 |
+
"Hardcoded SECRET_KEY in source code. Any developer with repo access can forge "
|
| 756 |
+
"JWT tokens and impersonate any user.",
|
| 757 |
+
"Use: SECRET_KEY = os.environ.get('SECRET_KEY') and rotate it as a secret."
|
| 758 |
+
),
|
| 759 |
+
_issue(
|
| 760 |
+
13, "api.py", "security", "high",
|
| 761 |
+
"Hardcoded ADMIN_TOKEN in source code. Static tokens in code are trivially "
|
| 762 |
+
"leaked via version control, logs, or error messages.",
|
| 763 |
+
"Use: ADMIN_TOKEN = os.environ.get('ADMIN_TOKEN') and generate it securely."
|
| 764 |
+
),
|
| 765 |
+
_issue(
|
| 766 |
+
16, "api.py", "security", "high",
|
| 767 |
+
"MD5 used for password hashing. MD5 is cryptographically broken; precomputed "
|
| 768 |
+
"rainbow tables can reverse any MD5 hash in seconds.",
|
| 769 |
+
"Use bcrypt, argon2, or hashlib.pbkdf2_hmac with a random salt."
|
| 770 |
+
),
|
| 771 |
+
_issue(
|
| 772 |
+
27, "api.py", "security", "medium",
|
| 773 |
+
"JWT token issued without an expiry claim ('exp'). Tokens are valid forever; "
|
| 774 |
+
"a stolen token can never be invalidated without rotating the secret.",
|
| 775 |
+
"Add: {'exp': datetime.utcnow() + timedelta(hours=1)} to the JWT payload."
|
| 776 |
+
),
|
| 777 |
+
_issue(
|
| 778 |
+
33, "api.py", "security", "critical",
|
| 779 |
+
"Missing authorization check: any authenticated user can fetch any user_id. "
|
| 780 |
+
"This is an Insecure Direct Object Reference (IDOR) — user A can read user B's data.",
|
| 781 |
+
"Check: if payload.get('user') != user_id and not payload.get('admin'): raise 403."
|
| 782 |
+
),
|
| 783 |
+
_issue(
|
| 784 |
+
38, "api.py", "security", "critical",
|
| 785 |
+
"SQL injection: user_id is interpolated directly into the query string. "
|
| 786 |
+
"An attacker can supply user_id = \"' OR '1'='1\" to dump the users table.",
|
| 787 |
+
"Use parameterized query: conn.execute('SELECT * FROM users WHERE id = ?', (user_id,))"
|
| 788 |
+
),
|
| 789 |
+
_issue(
|
| 790 |
+
47, "api.py", "security", "critical",
|
| 791 |
+
"Command injection: EXPORT_PATH from environment is interpolated into an "
|
| 792 |
+
"os.system() shell command. A misconfigured env var like '/tmp; rm -rf /' "
|
| 793 |
+
"executes arbitrary commands as the server process.",
|
| 794 |
+
"Use subprocess.run(['mysqldump', 'mydb'], stdout=open(path, 'w'), shell=False)."
|
| 795 |
+
),
|
| 796 |
+
_issue(
|
| 797 |
+
53, "api.py", "security", "critical",
|
| 798 |
+
"Unsafe deserialization: pickle.loads() on untrusted user-supplied bytes allows "
|
| 799 |
+
"remote code execution. Any client can craft a pickle payload that runs arbitrary code.",
|
| 800 |
+
"Use json.loads() or a schema-validated format. Never unpickle untrusted data."
|
| 801 |
+
),
|
| 802 |
+
],
|
| 803 |
+
"max_steps": 25,
|
| 804 |
+
"hints": [
|
| 805 |
+
"Check every hardcoded string assigned to variables like SECRET_KEY, TOKEN, PASSWORD.",
|
| 806 |
+
"Look at every endpoint: which ones verify the caller's identity vs just authentication?",
|
| 807 |
+
"Find all places user-supplied data touches: SQL queries, shell commands, deserialization.",
|
| 808 |
+
],
|
| 809 |
+
}
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
_JS_CODE = """\
|
| 813 |
+
const express = require('express');
|
| 814 |
+
const jwt = require('jsonwebtoken');
|
| 815 |
+
const { execSync } = require('child_process');
|
| 816 |
+
const path = require('path');
|
| 817 |
+
const fs = require('fs');
|
| 818 |
+
const sqlite3 = require('better-sqlite3');
|
| 819 |
+
|
| 820 |
+
const app = express();
|
| 821 |
+
app.use(express.json());
|
| 822 |
+
|
| 823 |
+
const JWT_SECRET = 'super-secret-key-hardcoded';
|
| 824 |
+
const db = new sqlite3('./data.db');
|
| 825 |
+
|
| 826 |
+
app.post('/login', (req, res) => {
|
| 827 |
+
const { username, password } = req.body;
|
| 828 |
+
const user = db.prepare(`SELECT * FROM users WHERE username = '${username}' AND password = '${password}'`).get();
|
| 829 |
+
if (!user) return res.status(401).json({ error: 'Invalid credentials' });
|
| 830 |
+
const token = jwt.sign({ id: user.id, role: user.role }, JWT_SECRET);
|
| 831 |
+
res.json({ token });
|
| 832 |
+
});
|
| 833 |
+
|
| 834 |
+
app.get('/user/:id', (req, res) => {
|
| 835 |
+
const token = req.headers.authorization;
|
| 836 |
+
const payload = jwt.verify(token, JWT_SECRET);
|
| 837 |
+
const user = db.prepare(`SELECT * FROM users WHERE id = ${req.params.id}`).get();
|
| 838 |
+
res.json(user);
|
| 839 |
+
});
|
| 840 |
+
|
| 841 |
+
app.get('/search', (req, res) => {
|
| 842 |
+
const q = req.query.q;
|
| 843 |
+
res.send(`<h1>Results for: ${q}</h1>`);
|
| 844 |
+
});
|
| 845 |
+
|
| 846 |
+
app.post('/run-report', (req, res) => {
|
| 847 |
+
const { filename } = req.body;
|
| 848 |
+
const output = execSync(`node reports/${filename}`);
|
| 849 |
+
res.send(output.toString());
|
| 850 |
+
});
|
| 851 |
+
|
| 852 |
+
app.get('/files', (req, res) => {
|
| 853 |
+
const name = req.query.name;
|
| 854 |
+
const filePath = path.join(__dirname, 'uploads', name);
|
| 855 |
+
res.send(fs.readFileSync(filePath, 'utf8'));
|
| 856 |
+
});
|
| 857 |
+
|
| 858 |
+
app.post('/template', (req, res) => {
|
| 859 |
+
const { template, data } = req.body;
|
| 860 |
+
const fn = new Function('data', `return \\`${template}\\``);
|
| 861 |
+
res.json({ result: fn(data) });
|
| 862 |
+
});
|
| 863 |
+
|
| 864 |
+
app.listen(3000);
|
| 865 |
+
"""
|
| 866 |
+
|
| 867 |
+
TASK_JS_SECURITY: Dict[str, Any] = {
|
| 868 |
+
"task_id": "js-security",
|
| 869 |
+
"difficulty": "hard",
|
| 870 |
+
"description": (
|
| 871 |
+
"Perform a security audit on this Express.js REST API.\n"
|
| 872 |
+
"The service handles authentication and user data operations in Node.js.\n"
|
| 873 |
+
"It contains critical security vulnerabilities common in JavaScript backends.\n"
|
| 874 |
+
"Identify ALL issues with exact line numbers, types, and severity.\n\n"
|
| 875 |
+
"File to review: server.js"
|
| 876 |
+
),
|
| 877 |
+
"language": "javascript",
|
| 878 |
+
"code_files": {
|
| 879 |
+
"server.js": _JS_CODE,
|
| 880 |
+
},
|
| 881 |
+
"ground_truth_issues": [
|
| 882 |
+
_issue(
|
| 883 |
+
11, "server.js", "security", "high",
|
| 884 |
+
"Hardcoded JWT secret 'super-secret-key-hardcoded' in source. "
|
| 885 |
+
"Anyone with code access can forge tokens for any user.",
|
| 886 |
+
"Use: const JWT_SECRET = process.env.JWT_SECRET and rotate it as an env secret."
|
| 887 |
+
),
|
| 888 |
+
_issue(
|
| 889 |
+
16, "server.js", "security", "critical",
|
| 890 |
+
"SQL injection: username and password are interpolated directly into a template "
|
| 891 |
+
"literal inside prepare(). An attacker can bypass authentication with username = ' OR '1'='1'--.",
|
| 892 |
+
"Use parameterized queries: db.prepare('SELECT * FROM users WHERE username = ? AND password = ?').get(username, password)"
|
| 893 |
+
),
|
| 894 |
+
_issue(
|
| 895 |
+
18, "server.js", "security", "medium",
|
| 896 |
+
"JWT issued without expiry ('expiresIn' option missing). Tokens are valid forever; "
|
| 897 |
+
"a stolen token can never be invalidated without rotating the secret.",
|
| 898 |
+
"Add: jwt.sign({ id: user.id, role: user.role }, JWT_SECRET, { expiresIn: '1h' })"
|
| 899 |
+
),
|
| 900 |
+
_issue(
|
| 901 |
+
25, "server.js", "security", "critical",
|
| 902 |
+
"Missing authorization + SQL injection: any authenticated user can fetch any "
|
| 903 |
+
"user by changing req.params.id (IDOR). Also id is interpolated directly into SQL.",
|
| 904 |
+
"Check payload.id === req.params.id (or admin role). Use parameterized: db.prepare('SELECT * FROM users WHERE id = ?').get(req.params.id)"
|
| 905 |
+
),
|
| 906 |
+
_issue(
|
| 907 |
+
31, "server.js", "security", "high",
|
| 908 |
+
"Cross-site scripting (XSS): user-supplied query parameter q is reflected "
|
| 909 |
+
"directly into HTML response without escaping.",
|
| 910 |
+
"Use a templating engine with auto-escaping, or: res.send(`<h1>Results for: ${escapeHtml(q)}</h1>`)"
|
| 911 |
+
),
|
| 912 |
+
_issue(
|
| 913 |
+
36, "server.js", "security", "critical",
|
| 914 |
+
"Command injection: user-supplied filename is passed directly to execSync() "
|
| 915 |
+
"in a shell command. An attacker can supply 'x; rm -rf /' as filename.",
|
| 916 |
+
"Validate filename against a strict allowlist. Use execFileSync(['node', 'reports/' + sanitizedName]) with shell:false."
|
| 917 |
+
),
|
| 918 |
+
_issue(
|
| 919 |
+
42, "server.js", "security", "high",
|
| 920 |
+
"Path traversal: user-supplied 'name' is joined to uploads directory with path.join. "
|
| 921 |
+
"An attacker can supply '../../../etc/passwd' to read arbitrary files.",
|
| 922 |
+
"Use: path.resolve(__dirname, 'uploads', path.basename(name)) and validate the result starts with the uploads dir."
|
| 923 |
+
),
|
| 924 |
+
_issue(
|
| 925 |
+
48, "server.js", "security", "critical",
|
| 926 |
+
"Unsafe dynamic code execution: new Function() with user-supplied template string "
|
| 927 |
+
"is equivalent to eval(). Any client can execute arbitrary JavaScript on the server.",
|
| 928 |
+
"Never use new Function() or eval() with user input. Use a safe template engine like Handlebars or Mustache."
|
| 929 |
+
),
|
| 930 |
+
],
|
| 931 |
+
"max_steps": 25,
|
| 932 |
+
"hints": [
|
| 933 |
+
"Check every place user input (req.body, req.params, req.query) touches a database query, shell command, or HTML response.",
|
| 934 |
+
"Look for hardcoded secrets at the top of the file.",
|
| 935 |
+
"The /template and /run-report endpoints have particularly dangerous patterns.",
|
| 936 |
+
],
|
| 937 |
+
}
|
| 938 |
+
|
| 939 |
+
|
| 940 |
ALL_TASKS: Dict[str, Dict[str, Any]] = {
|
| 941 |
TASK_BUG_DETECTION["task_id"]: TASK_BUG_DETECTION,
|
| 942 |
TASK_SECURITY_AUDIT["task_id"]: TASK_SECURITY_AUDIT,
|
| 943 |
TASK_COMPREHENSIVE["task_id"]: TASK_COMPREHENSIVE,
|
| 944 |
+
TASK_ASYNC_REVIEW["task_id"]: TASK_ASYNC_REVIEW,
|
| 945 |
+
TASK_DATA_PIPELINE["task_id"]: TASK_DATA_PIPELINE,
|
| 946 |
+
TASK_API_SECURITY["task_id"]: TASK_API_SECURITY,
|
| 947 |
+
TASK_JS_SECURITY["task_id"]: TASK_JS_SECURITY,
|
| 948 |
}
|
| 949 |
|
| 950 |
TASK_IDS: List[str] = list(ALL_TASKS.keys())
|
tests/test_environment.py
CHANGED
|
@@ -41,6 +41,18 @@ def env_hard(env):
|
|
| 41 |
return env
|
| 42 |
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
# ---------------------------------------------------------------------------
|
| 45 |
# reset() tests
|
| 46 |
# ---------------------------------------------------------------------------
|
|
@@ -106,6 +118,40 @@ class TestReset:
|
|
| 106 |
assert obs.flagged_issues == []
|
| 107 |
assert obs.step_count == 0
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
# ---------------------------------------------------------------------------
|
| 111 |
# step() — flag_issue tests
|
|
@@ -167,6 +213,148 @@ class TestFlagIssue:
|
|
| 167 |
obs = env_bug.state
|
| 168 |
assert len(obs.flagged_issues) == 3
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
# ---------------------------------------------------------------------------
|
| 172 |
# step() — clear_flag tests
|
|
@@ -312,3 +500,341 @@ class TestMaxSteps:
|
|
| 312 |
break
|
| 313 |
|
| 314 |
assert obs.done is True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
return env
|
| 42 |
|
| 43 |
|
| 44 |
+
@pytest.fixture
|
| 45 |
+
def env_async(env):
|
| 46 |
+
env.reset(task_id="async-review")
|
| 47 |
+
return env
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@pytest.fixture
|
| 51 |
+
def env_pipeline(env):
|
| 52 |
+
env.reset(task_id="data-pipeline")
|
| 53 |
+
return env
|
| 54 |
+
|
| 55 |
+
|
| 56 |
# ---------------------------------------------------------------------------
|
| 57 |
# reset() tests
|
| 58 |
# ---------------------------------------------------------------------------
|
|
|
|
| 118 |
assert obs.flagged_issues == []
|
| 119 |
assert obs.step_count == 0
|
| 120 |
|
| 121 |
+
def test_reset_has_code_metadata(self, env):
|
| 122 |
+
"""Reset observation should include code_metadata."""
|
| 123 |
+
obs = env.reset(task_id="bug-detection")
|
| 124 |
+
assert isinstance(obs.code_metadata, dict)
|
| 125 |
+
assert "total_lines" in obs.code_metadata
|
| 126 |
+
assert "num_functions" in obs.code_metadata
|
| 127 |
+
assert "complexity_estimate" in obs.code_metadata
|
| 128 |
+
|
| 129 |
+
def test_reset_code_metadata_has_issue_categories(self, env):
|
| 130 |
+
"""code_metadata should list the issue categories present in ground truth."""
|
| 131 |
+
obs = env.reset(task_id="bug-detection")
|
| 132 |
+
assert "issue_categories" in obs.code_metadata
|
| 133 |
+
# bug-detection has only bug type issues
|
| 134 |
+
assert "bug" in obs.code_metadata["issue_categories"]
|
| 135 |
+
|
| 136 |
+
def test_reset_has_empty_progress(self, env):
|
| 137 |
+
"""Reset observation progress may be empty or absent (populated on step)."""
|
| 138 |
+
obs = env.reset(task_id="bug-detection")
|
| 139 |
+
assert isinstance(obs.progress, dict)
|
| 140 |
+
|
| 141 |
+
def test_reset_has_empty_reward_breakdown(self, env):
|
| 142 |
+
obs = env.reset(task_id="bug-detection")
|
| 143 |
+
assert isinstance(obs.reward_breakdown, dict)
|
| 144 |
+
|
| 145 |
+
def test_reset_async_task(self, env):
|
| 146 |
+
obs = env.reset(task_id="async-review")
|
| 147 |
+
assert obs.task_id == "async-review"
|
| 148 |
+
assert "async.py" in obs.code_files
|
| 149 |
+
|
| 150 |
+
def test_reset_pipeline_task(self, env):
|
| 151 |
+
obs = env.reset(task_id="data-pipeline")
|
| 152 |
+
assert obs.task_id == "data-pipeline"
|
| 153 |
+
assert "pipeline.py" in obs.code_files
|
| 154 |
+
|
| 155 |
|
| 156 |
# ---------------------------------------------------------------------------
|
| 157 |
# step() — flag_issue tests
|
|
|
|
| 213 |
obs = env_bug.state
|
| 214 |
assert len(obs.flagged_issues) == 3
|
| 215 |
|
| 216 |
+
def test_flag_has_reward_breakdown(self, env_bug):
|
| 217 |
+
"""Every step should have a reward_breakdown dict."""
|
| 218 |
+
obs = env_bug.step(ReviewAction(
|
| 219 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 220 |
+
issue_type="bug", severity="high", description="test"
|
| 221 |
+
))
|
| 222 |
+
assert isinstance(obs.reward_breakdown, dict)
|
| 223 |
+
assert len(obs.reward_breakdown) > 0
|
| 224 |
+
|
| 225 |
+
def test_flag_has_progress(self, env_bug):
|
| 226 |
+
"""Every step should have a progress dict with required keys."""
|
| 227 |
+
obs = env_bug.step(ReviewAction(
|
| 228 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 229 |
+
issue_type="bug", severity="high", description="test"
|
| 230 |
+
))
|
| 231 |
+
assert isinstance(obs.progress, dict)
|
| 232 |
+
for key in ("precision", "recall", "f1", "true_positives", "steps_remaining"):
|
| 233 |
+
assert key in obs.progress, f"Missing key: {key}"
|
| 234 |
+
|
| 235 |
+
def test_flag_has_flagged_summary(self, env_bug):
|
| 236 |
+
"""Every step should have a flagged_summary dict."""
|
| 237 |
+
obs = env_bug.step(ReviewAction(
|
| 238 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 239 |
+
issue_type="bug", severity="high", description="test"
|
| 240 |
+
))
|
| 241 |
+
assert isinstance(obs.flagged_summary, dict)
|
| 242 |
+
assert "total_flagged" in obs.flagged_summary
|
| 243 |
+
assert "correct" in obs.flagged_summary
|
| 244 |
+
assert "incorrect" in obs.flagged_summary
|
| 245 |
+
assert "near_misses" in obs.flagged_summary
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# ---------------------------------------------------------------------------
|
| 249 |
+
# Near-miss tests
|
| 250 |
+
# ---------------------------------------------------------------------------
|
| 251 |
+
|
| 252 |
+
class TestNearMiss:
|
| 253 |
+
def test_near_miss_gives_partial_credit(self, env_bug):
|
| 254 |
+
"""A flag within 3-5 lines of a GT issue should give +0.03 not -0.05."""
|
| 255 |
+
# GT issue is at line 6 (off-by-one), so line 10 is 4 away = near miss
|
| 256 |
+
obs = env_bug.step(ReviewAction(
|
| 257 |
+
action_type="flag_issue", line_number=10, filename="utils.py",
|
| 258 |
+
issue_type="bug", severity="high", description="near miss test"
|
| 259 |
+
))
|
| 260 |
+
# Near miss gives +0.03
|
| 261 |
+
assert obs.reward is not None and obs.reward > 0, (
|
| 262 |
+
f"Expected near-miss +0.03 but got {obs.reward}"
|
| 263 |
+
)
|
| 264 |
+
assert obs.reward == pytest.approx(0.03, abs=0.01)
|
| 265 |
+
|
| 266 |
+
def test_near_miss_counted_in_summary(self, env_bug):
|
| 267 |
+
"""Near-miss flags should appear in flagged_summary.near_misses."""
|
| 268 |
+
# Line 10 is 4 lines from GT at line 6 → near miss
|
| 269 |
+
obs = env_bug.step(ReviewAction(
|
| 270 |
+
action_type="flag_issue", line_number=10, filename="utils.py",
|
| 271 |
+
issue_type="bug", severity="high", description="near miss"
|
| 272 |
+
))
|
| 273 |
+
assert obs.flagged_summary.get("near_misses", 0) >= 1
|
| 274 |
+
|
| 275 |
+
def test_true_positive_not_counted_as_near_miss(self, env_bug):
|
| 276 |
+
"""An exact TP should not be counted as a near miss."""
|
| 277 |
+
obs = env_bug.step(ReviewAction(
|
| 278 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 279 |
+
issue_type="bug", severity="high", description="exact match"
|
| 280 |
+
))
|
| 281 |
+
assert obs.flagged_summary.get("correct", 0) >= 1
|
| 282 |
+
assert obs.flagged_summary.get("near_misses", 0) == 0
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# ---------------------------------------------------------------------------
|
| 286 |
+
# Confidence field tests
|
| 287 |
+
# ---------------------------------------------------------------------------
|
| 288 |
+
|
| 289 |
+
class TestConfidenceField:
|
| 290 |
+
def test_action_with_confidence(self, env_bug):
|
| 291 |
+
"""ReviewAction should accept a confidence field."""
|
| 292 |
+
action = ReviewAction(
|
| 293 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 294 |
+
issue_type="bug", severity="high", description="test",
|
| 295 |
+
confidence=0.9
|
| 296 |
+
)
|
| 297 |
+
assert action.confidence == 0.9
|
| 298 |
+
|
| 299 |
+
def test_high_confidence_tp_gets_bonus(self, env_bug):
|
| 300 |
+
"""High confidence + TP should give more than base 0.10."""
|
| 301 |
+
obs = env_bug.step(ReviewAction(
|
| 302 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 303 |
+
issue_type="bug", severity="high", description="test",
|
| 304 |
+
confidence=0.9
|
| 305 |
+
))
|
| 306 |
+
assert obs.reward is not None and obs.reward > 0.10
|
| 307 |
+
|
| 308 |
+
def test_high_confidence_fp_gets_extra_penalty(self, env_bug):
|
| 309 |
+
"""High confidence + FP should give more penalty than -0.05."""
|
| 310 |
+
obs = env_bug.step(ReviewAction(
|
| 311 |
+
action_type="flag_issue", line_number=100, filename="utils.py",
|
| 312 |
+
issue_type="bug", severity="low", description="wrong",
|
| 313 |
+
confidence=0.9
|
| 314 |
+
))
|
| 315 |
+
assert obs.reward is not None and obs.reward < -0.05
|
| 316 |
+
|
| 317 |
+
def test_low_confidence_tp_base_reward_only(self, env_bug):
|
| 318 |
+
"""Low confidence + TP should give exactly base 0.10 (no bonus)."""
|
| 319 |
+
obs = env_bug.step(ReviewAction(
|
| 320 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 321 |
+
issue_type="bug", severity="high", description="test",
|
| 322 |
+
confidence=0.5
|
| 323 |
+
))
|
| 324 |
+
assert obs.reward is not None
|
| 325 |
+
# Should be 0.10 base + possible temporal bonus but no confidence bonus
|
| 326 |
+
assert obs.reward >= 0.10
|
| 327 |
+
|
| 328 |
+
def test_no_confidence_field_is_none(self):
|
| 329 |
+
"""ReviewAction without confidence defaults to None."""
|
| 330 |
+
action = ReviewAction(
|
| 331 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 332 |
+
)
|
| 333 |
+
assert action.confidence is None
|
| 334 |
+
|
| 335 |
+
def test_confidence_in_action_to_dict(self):
|
| 336 |
+
"""confidence should round-trip through to_dict/from_dict."""
|
| 337 |
+
action = ReviewAction(
|
| 338 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 339 |
+
confidence=0.75
|
| 340 |
+
)
|
| 341 |
+
d = action.to_dict()
|
| 342 |
+
assert d["confidence"] == 0.75
|
| 343 |
+
action2 = ReviewAction.from_dict(d)
|
| 344 |
+
assert action2.confidence == 0.75
|
| 345 |
+
|
| 346 |
+
def test_related_lines_field(self):
|
| 347 |
+
"""ReviewAction should accept a related_lines field."""
|
| 348 |
+
action = ReviewAction(
|
| 349 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 350 |
+
related_lines=[6, 7, 8]
|
| 351 |
+
)
|
| 352 |
+
assert action.related_lines == [6, 7, 8]
|
| 353 |
+
d = action.to_dict()
|
| 354 |
+
assert d["related_lines"] == [6, 7, 8]
|
| 355 |
+
action2 = ReviewAction.from_dict(d)
|
| 356 |
+
assert action2.related_lines == [6, 7, 8]
|
| 357 |
+
|
| 358 |
|
| 359 |
# ---------------------------------------------------------------------------
|
| 360 |
# step() — clear_flag tests
|
|
|
|
| 500 |
break
|
| 501 |
|
| 502 |
assert obs.done is True
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
# ---------------------------------------------------------------------------
|
| 506 |
+
# New task tests
|
| 507 |
+
# ---------------------------------------------------------------------------
|
| 508 |
+
|
| 509 |
+
class TestNewTasks:
|
| 510 |
+
def test_async_review_task_exists(self, env):
|
| 511 |
+
obs = env.reset(task_id="async-review")
|
| 512 |
+
assert obs.task_id == "async-review"
|
| 513 |
+
assert obs.done is False
|
| 514 |
+
|
| 515 |
+
def test_async_review_has_correct_issue_count(self):
|
| 516 |
+
from tasks.data import ALL_TASKS
|
| 517 |
+
task = ALL_TASKS["async-review"]
|
| 518 |
+
assert len(task["ground_truth_issues"]) == 6
|
| 519 |
+
|
| 520 |
+
def test_async_review_has_async_py(self, env):
|
| 521 |
+
obs = env.reset(task_id="async-review")
|
| 522 |
+
assert "async.py" in obs.code_files
|
| 523 |
+
code = obs.code_files["async.py"]
|
| 524 |
+
assert "asyncio" in code
|
| 525 |
+
assert "aiohttp" in code
|
| 526 |
+
|
| 527 |
+
def test_async_review_max_steps(self):
|
| 528 |
+
from tasks.data import ALL_TASKS
|
| 529 |
+
task = ALL_TASKS["async-review"]
|
| 530 |
+
assert task["max_steps"] == 20
|
| 531 |
+
|
| 532 |
+
def test_data_pipeline_task_exists(self, env):
|
| 533 |
+
obs = env.reset(task_id="data-pipeline")
|
| 534 |
+
assert obs.task_id == "data-pipeline"
|
| 535 |
+
assert obs.done is False
|
| 536 |
+
|
| 537 |
+
def test_data_pipeline_has_correct_issue_count(self):
|
| 538 |
+
from tasks.data import ALL_TASKS
|
| 539 |
+
task = ALL_TASKS["data-pipeline"]
|
| 540 |
+
assert len(task["ground_truth_issues"]) == 7
|
| 541 |
+
|
| 542 |
+
def test_data_pipeline_has_pipeline_py(self, env):
|
| 543 |
+
obs = env.reset(task_id="data-pipeline")
|
| 544 |
+
assert "pipeline.py" in obs.code_files
|
| 545 |
+
code = obs.code_files["pipeline.py"]
|
| 546 |
+
assert "sqlite3" in code
|
| 547 |
+
assert "hashlib" in code
|
| 548 |
+
|
| 549 |
+
def test_data_pipeline_max_steps(self):
|
| 550 |
+
from tasks.data import ALL_TASKS
|
| 551 |
+
task = ALL_TASKS["data-pipeline"]
|
| 552 |
+
assert task["max_steps"] == 25
|
| 553 |
+
|
| 554 |
+
def test_task_count(self):
|
| 555 |
+
from tasks.data import TASK_IDS
|
| 556 |
+
assert len(TASK_IDS) >= 6
|
| 557 |
+
|
| 558 |
+
def test_async_review_correct_tp_reward(self, env_async):
|
| 559 |
+
"""Flagging a known issue in async-review should give positive reward."""
|
| 560 |
+
obs = env_async.step(ReviewAction(
|
| 561 |
+
action_type="flag_issue", line_number=22, filename="async.py",
|
| 562 |
+
issue_type="bug", severity="high",
|
| 563 |
+
description="ClientSession not closed"
|
| 564 |
+
))
|
| 565 |
+
assert obs.reward is not None and obs.reward > 0
|
| 566 |
+
|
| 567 |
+
def test_data_pipeline_correct_tp_reward(self, env_pipeline):
|
| 568 |
+
"""Flagging a known SQL injection in pipeline.py should give positive reward."""
|
| 569 |
+
obs = env_pipeline.step(ReviewAction(
|
| 570 |
+
action_type="flag_issue", line_number=27, filename="pipeline.py",
|
| 571 |
+
issue_type="security", severity="critical",
|
| 572 |
+
description="SQL injection"
|
| 573 |
+
))
|
| 574 |
+
assert obs.reward is not None and obs.reward > 0
|
| 575 |
+
|
| 576 |
+
def test_all_tasks_have_hints(self):
|
| 577 |
+
from tasks.data import ALL_TASKS
|
| 578 |
+
for task_id, task in ALL_TASKS.items():
|
| 579 |
+
assert "hints" in task, f"Task {task_id} missing hints"
|
| 580 |
+
assert len(task["hints"]) >= 3, f"Task {task_id} has fewer than 3 hints"
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
# ---------------------------------------------------------------------------
|
| 584 |
+
# Observation serialization
|
| 585 |
+
# ---------------------------------------------------------------------------
|
| 586 |
+
|
| 587 |
+
class TestObservationSerialization:
|
| 588 |
+
def test_reset_obs_to_dict_has_new_fields(self, env):
|
| 589 |
+
"""to_dict() should include all new fields."""
|
| 590 |
+
obs = env.reset(task_id="bug-detection")
|
| 591 |
+
d = obs.to_dict()
|
| 592 |
+
assert "reward_breakdown" in d
|
| 593 |
+
assert "progress" in d
|
| 594 |
+
assert "flagged_summary" in d
|
| 595 |
+
assert "code_metadata" in d
|
| 596 |
+
|
| 597 |
+
def test_obs_from_dict_handles_missing_new_fields(self):
|
| 598 |
+
"""from_dict() should handle missing new fields gracefully."""
|
| 599 |
+
d = {
|
| 600 |
+
"task_id": "bug-detection",
|
| 601 |
+
"task_description": "test",
|
| 602 |
+
"code_files": {},
|
| 603 |
+
"language": "python",
|
| 604 |
+
"flagged_issues": [],
|
| 605 |
+
"step_count": 0,
|
| 606 |
+
"max_steps": 15,
|
| 607 |
+
"hints_remaining": 3,
|
| 608 |
+
"feedback": "",
|
| 609 |
+
"current_score": 0.0,
|
| 610 |
+
"done": False,
|
| 611 |
+
"reward": None,
|
| 612 |
+
# No reward_breakdown, progress, flagged_summary, code_metadata
|
| 613 |
+
}
|
| 614 |
+
obs = ReviewObservation.from_dict(d)
|
| 615 |
+
assert obs.reward_breakdown == {}
|
| 616 |
+
assert obs.progress == {}
|
| 617 |
+
assert obs.flagged_summary == {}
|
| 618 |
+
assert obs.code_metadata == {}
|
| 619 |
+
|
| 620 |
+
def test_step_obs_to_dict_round_trip(self, env_bug):
|
| 621 |
+
obs = env_bug.step(ReviewAction(
|
| 622 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 623 |
+
issue_type="bug", severity="high", description="test"
|
| 624 |
+
))
|
| 625 |
+
d = obs.to_dict()
|
| 626 |
+
obs2 = ReviewObservation.from_dict(d)
|
| 627 |
+
assert obs2.task_id == obs.task_id
|
| 628 |
+
assert obs2.step_count == obs.step_count
|
| 629 |
+
assert isinstance(obs2.reward_breakdown, dict)
|
| 630 |
+
assert isinstance(obs2.progress, dict)
|
| 631 |
+
assert isinstance(obs2.flagged_summary, dict)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
# ---------------------------------------------------------------------------
|
| 635 |
+
# Severity exact match bonus
|
| 636 |
+
# ---------------------------------------------------------------------------
|
| 637 |
+
|
| 638 |
+
class TestSeverityBonus:
|
| 639 |
+
def test_severity_match_gives_extra_reward(self, env_bug):
|
| 640 |
+
"""Exact severity match should give more than a severity mismatch."""
|
| 641 |
+
# GT at line 6 is "high"
|
| 642 |
+
obs_match = env_bug.step(ReviewAction(
|
| 643 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 644 |
+
issue_type="bug", severity="high", description="exact severity"
|
| 645 |
+
))
|
| 646 |
+
env_bug.reset(task_id="bug-detection")
|
| 647 |
+
obs_wrong = env_bug.step(ReviewAction(
|
| 648 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 649 |
+
issue_type="bug", severity="low", description="wrong severity"
|
| 650 |
+
))
|
| 651 |
+
assert obs_match.reward > obs_wrong.reward
|
| 652 |
+
|
| 653 |
+
def test_severity_bonus_in_reward_breakdown(self, env_bug):
|
| 654 |
+
"""reward_breakdown should include 'severity_exact' key on correct severity."""
|
| 655 |
+
obs = env_bug.step(ReviewAction(
|
| 656 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 657 |
+
issue_type="bug", severity="high", description="correct severity"
|
| 658 |
+
))
|
| 659 |
+
assert "severity_exact" in obs.reward_breakdown
|
| 660 |
+
|
| 661 |
+
def test_severity_mismatch_no_severity_bonus(self, env_bug):
|
| 662 |
+
"""Wrong severity should not include 'severity_exact' key."""
|
| 663 |
+
obs = env_bug.step(ReviewAction(
|
| 664 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 665 |
+
issue_type="bug", severity="low", description="wrong severity"
|
| 666 |
+
))
|
| 667 |
+
assert "severity_exact" not in obs.reward_breakdown
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
# ---------------------------------------------------------------------------
|
| 671 |
+
# Flood protection (escalating FP penalty)
|
| 672 |
+
# ---------------------------------------------------------------------------
|
| 673 |
+
|
| 674 |
+
class TestFloodProtection:
|
| 675 |
+
def test_many_fps_escalate_penalty(self, env_bug):
|
| 676 |
+
"""After 3 false positives, each subsequent FP should have larger penalty."""
|
| 677 |
+
rewards = []
|
| 678 |
+
for line in [101, 102, 103, 104, 105]:
|
| 679 |
+
obs = env_bug.step(ReviewAction(
|
| 680 |
+
action_type="flag_issue", line_number=line, filename="utils.py",
|
| 681 |
+
issue_type="bug", severity="low", description="fp"
|
| 682 |
+
))
|
| 683 |
+
if obs.reward is not None and obs.reward < 0:
|
| 684 |
+
rewards.append(obs.reward)
|
| 685 |
+
|
| 686 |
+
# The 4th and 5th FPs should have larger absolute penalty
|
| 687 |
+
if len(rewards) >= 4:
|
| 688 |
+
assert abs(rewards[-1]) >= abs(rewards[0]), (
|
| 689 |
+
f"Expected escalating penalty but got {rewards}"
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
def test_fp_below_threshold_normal_penalty(self, env_bug):
|
| 693 |
+
"""First FP should get standard -0.05 penalty."""
|
| 694 |
+
obs = env_bug.step(ReviewAction(
|
| 695 |
+
action_type="flag_issue", line_number=200, filename="utils.py",
|
| 696 |
+
issue_type="bug", severity="low", description="first fp"
|
| 697 |
+
))
|
| 698 |
+
assert obs.reward is not None
|
| 699 |
+
assert obs.reward == pytest.approx(-0.05, abs=0.01)
|
| 700 |
+
|
| 701 |
+
def test_clearing_fp_reduces_penalty_track(self, env_bug):
|
| 702 |
+
"""Clearing a FP should give positive reward."""
|
| 703 |
+
env_bug.step(ReviewAction(
|
| 704 |
+
action_type="flag_issue", line_number=200, filename="utils.py",
|
| 705 |
+
issue_type="bug", severity="low", description="fp"
|
| 706 |
+
))
|
| 707 |
+
obs = env_bug.step(ReviewAction(
|
| 708 |
+
action_type="clear_flag", line_number=200, filename="utils.py",
|
| 709 |
+
))
|
| 710 |
+
assert obs.reward is not None and obs.reward > 0
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
# ---------------------------------------------------------------------------
|
| 714 |
+
# Unfound issue types in progress
|
| 715 |
+
# ---------------------------------------------------------------------------
|
| 716 |
+
|
| 717 |
+
class TestUnfoundIssueTypes:
|
| 718 |
+
def test_unfound_types_present_at_start(self, env_bug):
|
| 719 |
+
"""Before flagging anything, all GT issue types should be in unfound_issue_types."""
|
| 720 |
+
obs = env_bug.step(ReviewAction(action_type="request_hint"))
|
| 721 |
+
unfound = obs.progress.get("unfound_issue_types", [])
|
| 722 |
+
assert "bug" in unfound
|
| 723 |
+
|
| 724 |
+
def test_unfound_types_shrinks_when_issue_found(self, env_bug):
|
| 725 |
+
"""Finding a bug should remove 'bug' from unfound_issue_types."""
|
| 726 |
+
obs_before = env_bug.step(ReviewAction(action_type="request_hint"))
|
| 727 |
+
unfound_before = set(obs_before.progress.get("unfound_issue_types", []))
|
| 728 |
+
|
| 729 |
+
env_bug.step(ReviewAction(
|
| 730 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 731 |
+
issue_type="bug", severity="high", description="found a bug"
|
| 732 |
+
))
|
| 733 |
+
obs_after = env_bug.step(ReviewAction(action_type="request_hint"))
|
| 734 |
+
unfound_after = set(obs_after.progress.get("unfound_issue_types", []))
|
| 735 |
+
|
| 736 |
+
# bug should now be gone from unfound
|
| 737 |
+
assert "bug" not in unfound_after or len(unfound_after) < len(unfound_before)
|
| 738 |
+
|
| 739 |
+
def test_unfound_types_is_list(self, env_bug):
|
| 740 |
+
obs = env_bug.step(ReviewAction(action_type="request_hint"))
|
| 741 |
+
assert isinstance(obs.progress.get("unfound_issue_types", []), list)
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
# ---------------------------------------------------------------------------
|
| 745 |
+
# API security task
|
| 746 |
+
# ---------------------------------------------------------------------------
|
| 747 |
+
|
| 748 |
+
class TestApiSecurityTask:
|
| 749 |
+
def test_api_security_task_exists(self, env):
|
| 750 |
+
obs = env.reset(task_id="api-security")
|
| 751 |
+
assert obs.task_id == "api-security"
|
| 752 |
+
assert obs.done is False
|
| 753 |
+
|
| 754 |
+
def test_api_security_has_api_py(self, env):
|
| 755 |
+
obs = env.reset(task_id="api-security")
|
| 756 |
+
assert "api.py" in obs.code_files
|
| 757 |
+
|
| 758 |
+
def test_api_security_has_8_issues(self):
|
| 759 |
+
from tasks.data import ALL_TASKS
|
| 760 |
+
task = ALL_TASKS["api-security"]
|
| 761 |
+
assert len(task["ground_truth_issues"]) == 8
|
| 762 |
+
|
| 763 |
+
def test_api_security_has_critical_issues(self):
|
| 764 |
+
from tasks.data import ALL_TASKS
|
| 765 |
+
task = ALL_TASKS["api-security"]
|
| 766 |
+
severities = {i["severity"] for i in task["ground_truth_issues"]}
|
| 767 |
+
assert "critical" in severities
|
| 768 |
+
|
| 769 |
+
def test_api_security_tp_reward(self, env):
|
| 770 |
+
env.reset(task_id="api-security")
|
| 771 |
+
obs = env.step(ReviewAction(
|
| 772 |
+
action_type="flag_issue", line_number=38, filename="api.py",
|
| 773 |
+
issue_type="security", severity="critical",
|
| 774 |
+
description="SQL injection via f-string"
|
| 775 |
+
))
|
| 776 |
+
assert obs.reward is not None and obs.reward > 0
|
| 777 |
+
|
| 778 |
+
def test_api_security_keyword_baseline_finds_issues(self):
|
| 779 |
+
from tasks.data import ALL_TASKS
|
| 780 |
+
from server.graders import run_keyword_baseline
|
| 781 |
+
task = ALL_TASKS["api-security"]
|
| 782 |
+
findings = run_keyword_baseline(task)
|
| 783 |
+
assert len(findings) >= 2
|
| 784 |
+
|
| 785 |
+
def test_api_security_difficulty_hard(self):
|
| 786 |
+
from tasks.data import ALL_TASKS
|
| 787 |
+
task = ALL_TASKS["api-security"]
|
| 788 |
+
assert task["difficulty"] == "hard"
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
# ---------------------------------------------------------------------------
|
| 792 |
+
# Auto-end gives full score (not 0.5x)
|
| 793 |
+
# ---------------------------------------------------------------------------
|
| 794 |
+
|
| 795 |
+
class TestAutoEndFullScore:
|
| 796 |
+
def test_auto_end_uses_full_grade(self, env_bug):
|
| 797 |
+
"""Auto-end should give full grade_episode score, not a penalized value."""
|
| 798 |
+
# Flag all 3 correct bugs first
|
| 799 |
+
for line, sev in [(6, "high"), (13, "medium"), (33, "low")]:
|
| 800 |
+
env_bug.step(ReviewAction(
|
| 801 |
+
action_type="flag_issue", line_number=line, filename="utils.py",
|
| 802 |
+
issue_type="bug", severity=sev, description=f"bug at {line}"
|
| 803 |
+
))
|
| 804 |
+
# Exhaust remaining steps with hints
|
| 805 |
+
max_steps = 15
|
| 806 |
+
for _ in range(max_steps - 3 - 1):
|
| 807 |
+
obs = env_bug.step(ReviewAction(action_type="request_hint"))
|
| 808 |
+
if obs.done:
|
| 809 |
+
break
|
| 810 |
+
|
| 811 |
+
obs = env_bug.step(ReviewAction(action_type="request_hint"))
|
| 812 |
+
if obs.done and obs.reward_breakdown.get("auto_end_grade") is not None:
|
| 813 |
+
# If auto-ended, score should be >= 0.7 since all 3 bugs found
|
| 814 |
+
assert obs.reward >= 0.7, f"Auto-end gave {obs.reward} instead of full grade"
|
| 815 |
+
|
| 816 |
+
|
| 817 |
+
# ---------------------------------------------------------------------------
|
| 818 |
+
# Function ranges in code_metadata
|
| 819 |
+
# ---------------------------------------------------------------------------
|
| 820 |
+
|
| 821 |
+
class TestFunctionRanges:
|
| 822 |
+
def test_reset_has_function_ranges(self, env):
|
| 823 |
+
obs = env.reset(task_id="bug-detection")
|
| 824 |
+
assert "function_ranges" in obs.code_metadata
|
| 825 |
+
|
| 826 |
+
def test_function_ranges_is_list(self, env):
|
| 827 |
+
obs = env.reset(task_id="bug-detection")
|
| 828 |
+
assert isinstance(obs.code_metadata["function_ranges"], list)
|
| 829 |
+
|
| 830 |
+
def test_function_ranges_have_required_fields(self, env):
|
| 831 |
+
obs = env.reset(task_id="bug-detection")
|
| 832 |
+
for fr in obs.code_metadata["function_ranges"]:
|
| 833 |
+
assert "name" in fr
|
| 834 |
+
assert "file" in fr
|
| 835 |
+
assert "start" in fr
|
| 836 |
+
assert "end" in fr
|
| 837 |
+
|
| 838 |
+
def test_function_ranges_nonempty_for_python(self, env):
|
| 839 |
+
obs = env.reset(task_id="bug-detection")
|
| 840 |
+
assert len(obs.code_metadata["function_ranges"]) > 0
|
tests/test_graders.py
CHANGED
|
@@ -7,7 +7,11 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| 7 |
|
| 8 |
import pytest
|
| 9 |
from models import Issue
|
| 10 |
-
from server.graders import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from tasks.data import ALL_TASKS, TASK_IDS
|
| 12 |
|
| 13 |
|
|
@@ -56,6 +60,231 @@ class TestMatchIssue:
|
|
| 56 |
gt = _issue(6, "utils.py", "bug", "high")
|
| 57 |
assert match_issue(f, gt) is False
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
# ---------------------------------------------------------------------------
|
| 61 |
# grade_episode()
|
|
@@ -177,6 +406,23 @@ class TestKeywordBaseline:
|
|
| 177 |
if task_id == "security-audit":
|
| 178 |
assert score > 0.0, f"Heuristic found nothing in {task_id}"
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
# ---------------------------------------------------------------------------
|
| 182 |
# Ground truth sanity checks
|
|
@@ -213,3 +459,159 @@ class TestGroundTruth:
|
|
| 213 |
files = {i["filename"] for i in task["ground_truth_issues"]}
|
| 214 |
assert "views.py" in files
|
| 215 |
assert "models.py" in files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
import pytest
|
| 9 |
from models import Issue
|
| 10 |
+
from server.graders import (
|
| 11 |
+
grade_episode, match_issue, run_keyword_baseline,
|
| 12 |
+
match_quality, compute_code_metadata, grade_episode_detailed,
|
| 13 |
+
NEAR_TOLERANCE,
|
| 14 |
+
)
|
| 15 |
from tasks.data import ALL_TASKS, TASK_IDS
|
| 16 |
|
| 17 |
|
|
|
|
| 60 |
gt = _issue(6, "utils.py", "bug", "high")
|
| 61 |
assert match_issue(f, gt) is False
|
| 62 |
|
| 63 |
+
def test_near_tolerance_param_accepted(self):
|
| 64 |
+
"""match_issue should accept near_tolerance param without error."""
|
| 65 |
+
f = _issue(6, "utils.py", "bug", "high")
|
| 66 |
+
gt = _issue(6, "utils.py", "bug", "high")
|
| 67 |
+
result = match_issue(f, gt, line_tolerance=2, near_tolerance=5)
|
| 68 |
+
assert result is True
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
# match_quality()
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
|
| 75 |
+
class TestMatchQuality:
|
| 76 |
+
def test_exact_match_within_2_lines(self):
|
| 77 |
+
f = _issue(7, "utils.py", "bug", "high")
|
| 78 |
+
gt = _issue(6, "utils.py", "bug", "high")
|
| 79 |
+
assert match_quality(f, gt) == "exact"
|
| 80 |
+
|
| 81 |
+
def test_near_match_3_to_5_lines(self):
|
| 82 |
+
# 4 lines away from GT at 6 → near
|
| 83 |
+
f = _issue(10, "utils.py", "bug", "high")
|
| 84 |
+
gt = _issue(6, "utils.py", "bug", "high")
|
| 85 |
+
assert match_quality(f, gt) == "near"
|
| 86 |
+
|
| 87 |
+
def test_near_match_exactly_3_lines(self):
|
| 88 |
+
f = _issue(9, "utils.py", "bug", "high")
|
| 89 |
+
gt = _issue(6, "utils.py", "bug", "high")
|
| 90 |
+
assert match_quality(f, gt) == "near"
|
| 91 |
+
|
| 92 |
+
def test_near_match_exactly_5_lines(self):
|
| 93 |
+
f = _issue(11, "utils.py", "bug", "high")
|
| 94 |
+
gt = _issue(6, "utils.py", "bug", "high")
|
| 95 |
+
assert match_quality(f, gt) == "near"
|
| 96 |
+
|
| 97 |
+
def test_no_match_beyond_5_lines(self):
|
| 98 |
+
f = _issue(12, "utils.py", "bug", "high")
|
| 99 |
+
gt = _issue(6, "utils.py", "bug", "high")
|
| 100 |
+
assert match_quality(f, gt) == "none"
|
| 101 |
+
|
| 102 |
+
def test_no_match_wrong_file(self):
|
| 103 |
+
f = _issue(6, "other.py", "bug", "high")
|
| 104 |
+
gt = _issue(6, "utils.py", "bug", "high")
|
| 105 |
+
assert match_quality(f, gt) == "none"
|
| 106 |
+
|
| 107 |
+
def test_near_ignores_type_difference(self):
|
| 108 |
+
"""Near match checks same file + line range, ignores type."""
|
| 109 |
+
f = _issue(10, "utils.py", "performance", "high")
|
| 110 |
+
gt = _issue(6, "utils.py", "bug", "high")
|
| 111 |
+
# 4 lines away → near
|
| 112 |
+
assert match_quality(f, gt) == "near"
|
| 113 |
+
|
| 114 |
+
def test_near_tolerance_constant(self):
|
| 115 |
+
assert NEAR_TOLERANCE == 5
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ---------------------------------------------------------------------------
|
| 119 |
+
# compute_code_metadata()
|
| 120 |
+
# ---------------------------------------------------------------------------
|
| 121 |
+
|
| 122 |
+
class TestComputeCodeMetadata:
|
| 123 |
+
def test_returns_dict(self):
|
| 124 |
+
code = {"test.py": "def foo(): pass\n"}
|
| 125 |
+
result = compute_code_metadata(code)
|
| 126 |
+
assert isinstance(result, dict)
|
| 127 |
+
|
| 128 |
+
def test_total_lines(self):
|
| 129 |
+
code = {"test.py": "line1\nline2\nline3\n"}
|
| 130 |
+
result = compute_code_metadata(code)
|
| 131 |
+
assert result["total_lines"] == 3
|
| 132 |
+
|
| 133 |
+
def test_num_functions(self):
|
| 134 |
+
code = {"test.py": "def foo():\n pass\n\ndef bar():\n pass\n"}
|
| 135 |
+
result = compute_code_metadata(code)
|
| 136 |
+
assert result["num_functions"] == 2
|
| 137 |
+
|
| 138 |
+
def test_function_names(self):
|
| 139 |
+
code = {"test.py": "def foo():\n pass\n\ndef bar():\n pass\n"}
|
| 140 |
+
result = compute_code_metadata(code)
|
| 141 |
+
assert "foo" in result["function_names"]
|
| 142 |
+
assert "bar" in result["function_names"]
|
| 143 |
+
|
| 144 |
+
def test_num_classes(self):
|
| 145 |
+
code = {"test.py": "class Foo:\n pass\n\nclass Bar:\n pass\n"}
|
| 146 |
+
result = compute_code_metadata(code)
|
| 147 |
+
assert result["num_classes"] == 2
|
| 148 |
+
|
| 149 |
+
def test_class_names(self):
|
| 150 |
+
code = {"test.py": "class Foo:\n pass\n"}
|
| 151 |
+
result = compute_code_metadata(code)
|
| 152 |
+
assert "Foo" in result["class_names"]
|
| 153 |
+
|
| 154 |
+
def test_imports(self):
|
| 155 |
+
code = {"test.py": "import os\nimport sys\nfrom typing import List\n"}
|
| 156 |
+
result = compute_code_metadata(code)
|
| 157 |
+
assert "os" in result["imports"]
|
| 158 |
+
assert "sys" in result["imports"]
|
| 159 |
+
assert "typing" in result["imports"]
|
| 160 |
+
|
| 161 |
+
def test_complexity_low(self):
|
| 162 |
+
code = {"test.py": "def foo():\n return 1\n"}
|
| 163 |
+
result = compute_code_metadata(code)
|
| 164 |
+
assert result["complexity_estimate"] == "low"
|
| 165 |
+
|
| 166 |
+
def test_complexity_medium(self):
|
| 167 |
+
# 6-15 branches — each if is top-level so indent is fine
|
| 168 |
+
lines = ["def foo(x):"]
|
| 169 |
+
for i in range(8):
|
| 170 |
+
lines.append(f" if x > {i}:")
|
| 171 |
+
lines.append(" pass")
|
| 172 |
+
code = {"test.py": "\n".join(lines) + "\n"}
|
| 173 |
+
result = compute_code_metadata(code)
|
| 174 |
+
assert result["complexity_estimate"] in ("medium", "high")
|
| 175 |
+
|
| 176 |
+
def test_complexity_high(self):
|
| 177 |
+
# 16+ branches
|
| 178 |
+
lines = ["def foo(x):"]
|
| 179 |
+
for i in range(20):
|
| 180 |
+
lines.append(f" if x > {i}:")
|
| 181 |
+
lines.append(" pass")
|
| 182 |
+
code = {"test.py": "\n".join(lines) + "\n"}
|
| 183 |
+
result = compute_code_metadata(code)
|
| 184 |
+
assert result["complexity_estimate"] == "high"
|
| 185 |
+
|
| 186 |
+
def test_issue_categories_passed_through(self):
|
| 187 |
+
code = {"test.py": "x = 1\n"}
|
| 188 |
+
result = compute_code_metadata(code, issue_categories=["bug", "security", "bug"])
|
| 189 |
+
# Should deduplicate
|
| 190 |
+
cats = result["issue_categories"]
|
| 191 |
+
assert "bug" in cats
|
| 192 |
+
assert "security" in cats
|
| 193 |
+
|
| 194 |
+
def test_syntax_error_no_crash(self):
|
| 195 |
+
"""Non-parseable code should not raise."""
|
| 196 |
+
code = {"bad.py": "this is not valid python !!!\n def broken("}
|
| 197 |
+
result = compute_code_metadata(code)
|
| 198 |
+
assert "total_lines" in result
|
| 199 |
+
assert result["total_lines"] >= 1
|
| 200 |
+
|
| 201 |
+
def test_multi_file(self):
|
| 202 |
+
code = {
|
| 203 |
+
"a.py": "def foo():\n pass\n",
|
| 204 |
+
"b.py": "def bar():\n pass\n",
|
| 205 |
+
}
|
| 206 |
+
result = compute_code_metadata(code)
|
| 207 |
+
assert result["num_functions"] == 2
|
| 208 |
+
assert result["total_lines"] == 4
|
| 209 |
+
|
| 210 |
+
def test_utils_task_metadata(self):
|
| 211 |
+
from tasks.data import ALL_TASKS
|
| 212 |
+
task = ALL_TASKS["bug-detection"]
|
| 213 |
+
result = compute_code_metadata(task["code_files"])
|
| 214 |
+
assert result["total_lines"] > 0
|
| 215 |
+
assert result["num_functions"] >= 4 # utils.py has 4 functions
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# ---------------------------------------------------------------------------
|
| 219 |
+
# grade_episode_detailed()
|
| 220 |
+
# ---------------------------------------------------------------------------
|
| 221 |
+
|
| 222 |
+
class TestGradeEpisodeDetailed:
|
| 223 |
+
def test_returns_dict(self):
|
| 224 |
+
gt = [_issue(6, "utils.py", "bug", "high")]
|
| 225 |
+
result = grade_episode_detailed(gt, gt)
|
| 226 |
+
assert isinstance(result, dict)
|
| 227 |
+
|
| 228 |
+
def test_required_keys(self):
|
| 229 |
+
gt = [_issue(6, "utils.py", "bug", "high")]
|
| 230 |
+
result = grade_episode_detailed(gt, gt)
|
| 231 |
+
for key in ("score", "f1", "precision", "recall", "severity_accuracy",
|
| 232 |
+
"true_positives", "false_positives", "false_negatives",
|
| 233 |
+
"near_misses", "per_file"):
|
| 234 |
+
assert key in result, f"Missing key: {key}"
|
| 235 |
+
|
| 236 |
+
def test_perfect_match(self):
|
| 237 |
+
gt = [_issue(6, "utils.py", "bug", "high")]
|
| 238 |
+
result = grade_episode_detailed(gt, gt)
|
| 239 |
+
assert result["true_positives"] == 1
|
| 240 |
+
assert result["false_positives"] == 0
|
| 241 |
+
assert result["false_negatives"] == 0
|
| 242 |
+
|
| 243 |
+
def test_false_positive_counted(self):
|
| 244 |
+
gt = [_issue(6, "utils.py", "bug", "high")]
|
| 245 |
+
flagged = [_issue(6, "utils.py", "bug", "high"),
|
| 246 |
+
_issue(100, "utils.py", "bug", "low")]
|
| 247 |
+
result = grade_episode_detailed(flagged, gt)
|
| 248 |
+
assert result["false_positives"] >= 1
|
| 249 |
+
|
| 250 |
+
def test_near_miss_counted(self):
|
| 251 |
+
gt = [_issue(6, "utils.py", "bug", "high")]
|
| 252 |
+
# 4 lines away = near miss
|
| 253 |
+
flagged = [_issue(10, "utils.py", "bug", "high")]
|
| 254 |
+
result = grade_episode_detailed(flagged, gt)
|
| 255 |
+
assert result["near_misses"] >= 1
|
| 256 |
+
|
| 257 |
+
def test_per_file_breakdown(self):
|
| 258 |
+
gt = [
|
| 259 |
+
_issue(6, "utils.py", "bug", "high"),
|
| 260 |
+
_issue(10, "other.py", "security", "critical"),
|
| 261 |
+
]
|
| 262 |
+
flagged = [_issue(6, "utils.py", "bug", "high")]
|
| 263 |
+
result = grade_episode_detailed(flagged, gt)
|
| 264 |
+
assert "utils.py" in result["per_file"]
|
| 265 |
+
|
| 266 |
+
def test_score_matches_grade_episode(self):
|
| 267 |
+
"""Detailed score should match grade_episode for simple cases."""
|
| 268 |
+
gt = [
|
| 269 |
+
_issue(6, "utils.py", "bug", "high"),
|
| 270 |
+
_issue(13, "utils.py", "bug", "medium"),
|
| 271 |
+
]
|
| 272 |
+
flagged = [_issue(6, "utils.py", "bug", "high")]
|
| 273 |
+
simple_score = grade_episode(flagged, gt)
|
| 274 |
+
detailed = grade_episode_detailed(flagged, gt)
|
| 275 |
+
# Scores may differ slightly (near_miss handling), but should be close
|
| 276 |
+
assert abs(detailed["score"] - simple_score) <= 0.15
|
| 277 |
+
|
| 278 |
+
def test_empty_ground_truth_perfect(self):
|
| 279 |
+
result = grade_episode_detailed([], [])
|
| 280 |
+
assert result["score"] == 1.0
|
| 281 |
+
|
| 282 |
+
def test_empty_flagged_zero(self):
|
| 283 |
+
gt = [_issue(6, "utils.py")]
|
| 284 |
+
result = grade_episode_detailed([], gt)
|
| 285 |
+
assert result["score"] == 0.0
|
| 286 |
+
assert result["false_negatives"] == 1
|
| 287 |
+
|
| 288 |
|
| 289 |
# ---------------------------------------------------------------------------
|
| 290 |
# grade_episode()
|
|
|
|
| 406 |
if task_id == "security-audit":
|
| 407 |
assert score > 0.0, f"Heuristic found nothing in {task_id}"
|
| 408 |
|
| 409 |
+
def test_baseline_finds_md5_in_pipeline(self):
|
| 410 |
+
"""Keyword baseline should find the MD5 issue in data-pipeline."""
|
| 411 |
+
from tasks.data import ALL_TASKS
|
| 412 |
+
task = ALL_TASKS["data-pipeline"]
|
| 413 |
+
findings = run_keyword_baseline(task)
|
| 414 |
+
md5_finds = [f for f in findings if "md5" in f.description.lower() or "MD5" in f.description]
|
| 415 |
+
assert len(md5_finds) >= 1
|
| 416 |
+
|
| 417 |
+
def test_baseline_finds_sql_injection_in_pipeline(self):
|
| 418 |
+
"""Keyword baseline should find SQL injection via f-string in pipeline.py."""
|
| 419 |
+
from tasks.data import ALL_TASKS
|
| 420 |
+
task = ALL_TASKS["data-pipeline"]
|
| 421 |
+
findings = run_keyword_baseline(task)
|
| 422 |
+
sql_finds = [f for f in findings if f.issue_type == "security"
|
| 423 |
+
and "sql" in f.description.lower()]
|
| 424 |
+
assert len(sql_finds) >= 1
|
| 425 |
+
|
| 426 |
|
| 427 |
# ---------------------------------------------------------------------------
|
| 428 |
# Ground truth sanity checks
|
|
|
|
| 459 |
files = {i["filename"] for i in task["ground_truth_issues"]}
|
| 460 |
assert "views.py" in files
|
| 461 |
assert "models.py" in files
|
| 462 |
+
|
| 463 |
+
def test_async_review_has_6_issues(self):
|
| 464 |
+
task = ALL_TASKS["async-review"]
|
| 465 |
+
assert len(task["ground_truth_issues"]) == 6
|
| 466 |
+
|
| 467 |
+
def test_data_pipeline_has_7_issues(self):
|
| 468 |
+
task = ALL_TASKS["data-pipeline"]
|
| 469 |
+
assert len(task["ground_truth_issues"]) == 7
|
| 470 |
+
|
| 471 |
+
def test_async_review_issues_in_async_py(self):
|
| 472 |
+
task = ALL_TASKS["async-review"]
|
| 473 |
+
for issue in task["ground_truth_issues"]:
|
| 474 |
+
assert issue["filename"] == "async.py"
|
| 475 |
+
|
| 476 |
+
def test_data_pipeline_issues_in_pipeline_py(self):
|
| 477 |
+
task = ALL_TASKS["data-pipeline"]
|
| 478 |
+
for issue in task["ground_truth_issues"]:
|
| 479 |
+
assert issue["filename"] == "pipeline.py"
|
| 480 |
+
|
| 481 |
+
def test_data_pipeline_has_security_and_performance(self):
|
| 482 |
+
task = ALL_TASKS["data-pipeline"]
|
| 483 |
+
types = {i["issue_type"] for i in task["ground_truth_issues"]}
|
| 484 |
+
assert "security" in types
|
| 485 |
+
assert "performance" in types
|
| 486 |
+
|
| 487 |
+
def test_async_review_has_bug_and_performance(self):
|
| 488 |
+
task = ALL_TASKS["async-review"]
|
| 489 |
+
types = {i["issue_type"] for i in task["ground_truth_issues"]}
|
| 490 |
+
assert "bug" in types
|
| 491 |
+
assert "performance" in types
|
| 492 |
+
|
| 493 |
+
def test_all_tasks_count(self):
|
| 494 |
+
assert len(ALL_TASKS) >= 6
|
| 495 |
+
|
| 496 |
+
def test_async_review_line_numbers_are_valid(self):
|
| 497 |
+
"""GT issue line numbers should be within the code file."""
|
| 498 |
+
from tasks.data import TASK_ASYNC_REVIEW
|
| 499 |
+
code = TASK_ASYNC_REVIEW["code_files"]["async.py"]
|
| 500 |
+
total_lines = len(code.splitlines())
|
| 501 |
+
for issue in TASK_ASYNC_REVIEW["ground_truth_issues"]:
|
| 502 |
+
assert 1 <= issue["line_number"] <= total_lines, (
|
| 503 |
+
f"Line {issue['line_number']} out of range (file has {total_lines} lines)"
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
def test_pipeline_line_numbers_are_valid(self):
|
| 507 |
+
"""GT issue line numbers should be within the code file."""
|
| 508 |
+
from tasks.data import TASK_DATA_PIPELINE
|
| 509 |
+
code = TASK_DATA_PIPELINE["code_files"]["pipeline.py"]
|
| 510 |
+
total_lines = len(code.splitlines())
|
| 511 |
+
for issue in TASK_DATA_PIPELINE["ground_truth_issues"]:
|
| 512 |
+
assert 1 <= issue["line_number"] <= total_lines, (
|
| 513 |
+
f"Line {issue['line_number']} out of range (file has {total_lines} lines)"
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
def test_api_security_has_8_issues(self):
|
| 517 |
+
from tasks.data import ALL_TASKS
|
| 518 |
+
task = ALL_TASKS["api-security"]
|
| 519 |
+
assert len(task["ground_truth_issues"]) == 8
|
| 520 |
+
|
| 521 |
+
def test_api_security_line_numbers_are_valid(self):
|
| 522 |
+
from tasks.data import ALL_TASKS
|
| 523 |
+
task = ALL_TASKS["api-security"]
|
| 524 |
+
code = task["code_files"]["api.py"]
|
| 525 |
+
total_lines = len(code.splitlines())
|
| 526 |
+
for issue in task["ground_truth_issues"]:
|
| 527 |
+
assert 1 <= issue["line_number"] <= total_lines, (
|
| 528 |
+
f"Line {issue['line_number']} out of range (file has {total_lines} lines)"
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
def test_api_security_has_security_issues(self):
|
| 532 |
+
from tasks.data import ALL_TASKS
|
| 533 |
+
task = ALL_TASKS["api-security"]
|
| 534 |
+
types = {i["issue_type"] for i in task["ground_truth_issues"]}
|
| 535 |
+
assert "security" in types
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
# ---------------------------------------------------------------------------
|
| 539 |
+
# compute_function_map and function_ranges in metadata
|
| 540 |
+
# ---------------------------------------------------------------------------
|
| 541 |
+
|
| 542 |
+
class TestFunctionRangesMetadata:
|
| 543 |
+
def test_function_ranges_in_metadata(self):
|
| 544 |
+
code = {"test.py": "def foo():\n return 1\n\ndef bar(x):\n return x\n"}
|
| 545 |
+
result = compute_code_metadata(code)
|
| 546 |
+
assert "function_ranges" in result
|
| 547 |
+
assert len(result["function_ranges"]) == 2
|
| 548 |
+
|
| 549 |
+
def test_function_ranges_have_correct_fields(self):
|
| 550 |
+
code = {"test.py": "def foo():\n return 1\n"}
|
| 551 |
+
result = compute_code_metadata(code)
|
| 552 |
+
fr = result["function_ranges"][0]
|
| 553 |
+
assert fr["name"] == "foo"
|
| 554 |
+
assert fr["file"] == "test.py"
|
| 555 |
+
assert "start" in fr
|
| 556 |
+
assert "end" in fr
|
| 557 |
+
assert fr["start"] <= fr["end"]
|
| 558 |
+
|
| 559 |
+
def test_function_ranges_empty_for_no_functions(self):
|
| 560 |
+
code = {"test.py": "x = 1\ny = 2\n"}
|
| 561 |
+
result = compute_code_metadata(code)
|
| 562 |
+
assert result["function_ranges"] == []
|
| 563 |
+
|
| 564 |
+
def test_function_ranges_multifile(self):
|
| 565 |
+
code = {
|
| 566 |
+
"a.py": "def foo():\n pass\n",
|
| 567 |
+
"b.py": "def bar():\n pass\n\ndef baz():\n pass\n",
|
| 568 |
+
}
|
| 569 |
+
result = compute_code_metadata(code)
|
| 570 |
+
names = {fr["name"] for fr in result["function_ranges"]}
|
| 571 |
+
assert names == {"foo", "bar", "baz"}
|
| 572 |
+
|
| 573 |
+
def test_function_ranges_correct_line_numbers(self):
|
| 574 |
+
code = {"test.py": "x = 1\n\ndef foo():\n return 1\n"}
|
| 575 |
+
result = compute_code_metadata(code)
|
| 576 |
+
assert len(result["function_ranges"]) == 1
|
| 577 |
+
assert result["function_ranges"][0]["start"] == 3 # line 3
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
# ---------------------------------------------------------------------------
|
| 581 |
+
# New keyword patterns
|
| 582 |
+
# ---------------------------------------------------------------------------
|
| 583 |
+
|
| 584 |
+
class TestNewKeywordPatterns:
|
| 585 |
+
def test_baseline_finds_hardcoded_admin_token(self):
|
| 586 |
+
from server.graders import run_keyword_baseline
|
| 587 |
+
from tasks.data import ALL_TASKS
|
| 588 |
+
task = ALL_TASKS["api-security"]
|
| 589 |
+
findings = run_keyword_baseline(task)
|
| 590 |
+
token_finds = [f for f in findings if "ADMIN_TOKEN" in f.description or "token" in f.description.lower()]
|
| 591 |
+
assert len(token_finds) >= 1
|
| 592 |
+
|
| 593 |
+
def test_baseline_finds_pickle_loads(self):
|
| 594 |
+
from server.graders import run_keyword_baseline
|
| 595 |
+
from tasks.data import ALL_TASKS
|
| 596 |
+
task = ALL_TASKS["api-security"]
|
| 597 |
+
findings = run_keyword_baseline(task)
|
| 598 |
+
pickle_finds = [f for f in findings if "pickle" in f.description.lower()]
|
| 599 |
+
assert len(pickle_finds) >= 1
|
| 600 |
+
|
| 601 |
+
def test_baseline_finds_os_system(self):
|
| 602 |
+
from server.graders import run_keyword_baseline
|
| 603 |
+
from tasks.data import ALL_TASKS
|
| 604 |
+
task = ALL_TASKS["api-security"]
|
| 605 |
+
findings = run_keyword_baseline(task)
|
| 606 |
+
sys_finds = [f for f in findings if "os.system" in f.description.lower() or "command" in f.description.lower()]
|
| 607 |
+
assert len(sys_finds) >= 1
|
| 608 |
+
|
| 609 |
+
def test_baseline_api_security_score_nonzero(self):
|
| 610 |
+
from server.graders import run_keyword_baseline, grade_episode
|
| 611 |
+
from models import Issue
|
| 612 |
+
from tasks.data import ALL_TASKS
|
| 613 |
+
task = ALL_TASKS["api-security"]
|
| 614 |
+
findings = run_keyword_baseline(task)
|
| 615 |
+
gt = [Issue.from_dict(i) for i in task["ground_truth_issues"]]
|
| 616 |
+
score = grade_episode(findings, gt)
|
| 617 |
+
assert score > 0.0, "Keyword baseline should find at least 1 issue in api-security"
|