codemaverick2 committed on
Commit
e48a1e4
·
1 Parent(s): 8b75c34

Add 7-task RL env with PBRS, CAMRL curriculum, VL norm, RC-GRPO inference

Files changed (10)
  1. README.md +109 -10
  2. inference.py +293 -51
  3. models.py +20 -0
  4. openenv.yaml +50 -3
  5. server/app.py +191 -14
  6. server/environment.py +414 -66
  7. server/graders.py +446 -6
  8. tasks/data.py +523 -0
  9. tests/test_environment.py +526 -0
  10. tests/test_graders.py +403 -1
README.md CHANGED
@@ -117,6 +117,71 @@ Comprehensive review of a Django e-commerce API across two files (`views.py`, `m
117
 
118
  **Max steps:** 30
119
 
120
  ## Scoring
121
 
122
  ```
@@ -129,8 +194,36 @@ where:
129
  severity_accuracy = avg(1 − |flag_sev_rank − gt_sev_rank| × 0.34) for matched issues
130
 
131
  Matching tolerance: ±2 lines, same filename, compatible issue type
 
132
  ```
133
 
134
  ## API Endpoints
135
 
136
  | Method | Endpoint | Description |
@@ -234,14 +327,18 @@ pytest tests/ -v
234
 
235
  ## Baseline Scores
236
 
237
- | Task | Keyword heuristic | GPT-4o-mini |
238
- |------|-------------------|-------------|
239
- | bug-detection | 1.00 | ~0.52 |
240
- | security-audit | 0.75 | ~0.59 |
241
- | comprehensive-review | 0.67 | ~0.17 |
242
- | **Overall** | **0.81** | **~0.43** |
243
 
244
- Keyword heuristic runs via `inference.py` with no API key. LLM scores use `API_BASE_URL` + `HF_TOKEN`.
245
 
246
  ## Project Structure
247
 
@@ -258,11 +355,13 @@ code-review-env/
258
  ├── client.py ← HTTP client
259
  ├── models.py ← ReviewAction, ReviewObservation, ReviewState, Issue
260
  ├── tasks/
261
- │ └── data.py ← 3 task definitions + ground truth
262
  ├── server/
263
  │ ├── app.py ← FastAPI application
264
- │ ├── environment.py ← Core environment logic
265
- │ └── graders.py ← F1 grading + keyword baseline
266
  └── tests/
267
  ├── test_environment.py
268
  └── test_graders.py
 
117
 
118
  **Max steps:** 30
119
 
120
+ ### Task 4: `async-review` — Medium-Hard
121
+
122
+ Review an async Python module (`async.py`) for concurrency bugs, resource leaks, and performance issues with `asyncio` and `aiohttp`.
123
+
124
+ | Line | Issue | Severity |
125
+ |------|-------|----------|
126
+ | 5 | Shared mutable cache dict without `asyncio.Lock` — race condition | High |
127
+ | 9 | `timeout=5` wrong type for aiohttp; requires `ClientTimeout(total=5)` | Medium |
128
+ | 22 | `ClientSession` created but never closed — resource leak | High |
129
+ | 24 | Sequential `await` in loop — use `asyncio.gather()` for concurrency | High |
130
+ | 37 | Off-by-one in retry condition: `attempt == retries` never true | High |
131
+ | 48 | Tasks awaited sequentially; `self.results` accumulates across calls | Medium |
132
+
133
+ **Max steps:** 20
134
+
135
+ ### Task 5: `data-pipeline` — Hard
136
+
137
+ Security and correctness audit of a SQLite data pipeline module (`pipeline.py`).
138
+
139
+ | Line | Issue | Severity |
140
+ |------|-------|----------|
141
+ | 20 | MD5 for password hashing — cryptographically broken | High |
142
+ | 27 | SQL injection via f-string in `INSERT` query | Critical |
143
+ | 35 | SQL injection via f-string in `LIKE` query | Critical |
144
+ | 41 | One transaction per row in `bulk_load` — severe performance issue | High |
145
+ | 46 | `float()` conversion without error handling — crashes on bad input | Medium |
146
+ | 52 | `export_records` leaks `password_hash` field in JSON output | High |
147
+ | 59 | SQL injection: `limit` interpolated into `LIMIT` clause | Critical |
148
+
149
+ **Max steps:** 25
150
+
151
+ ### Task 6: `api-security` — Hard
152
+
153
+ Security audit of a FastAPI REST API (`api.py`) with authentication, authorization, and injection vulnerabilities.
154
+
155
+ | Line | Issue | Severity |
156
+ |------|-------|----------|
157
+ | 12 | Hardcoded `SECRET_KEY` in source | High |
158
+ | 13 | Hardcoded `ADMIN_TOKEN` in source | High |
159
+ | 16 | MD5 for password hashing | High |
160
+ | 27 | JWT issued without `exp` expiry claim | Medium |
161
+ | 33 | IDOR — any user can fetch any other user's data | Critical |
162
+ | 38 | SQL injection via f-string in `SELECT` query | Critical |
163
+ | 47 | Command injection via `os.system()` with env-interpolated path | Critical |
164
+ | 53 | `pickle.loads()` on untrusted user bytes — RCE | Critical |
165
+
166
+ **Max steps:** 25
167
+
168
+ ### Task 7: `js-security` — Hard
169
+
170
+ Security audit of an Express.js REST API (`server.js`) in JavaScript/Node.js.
171
+
172
+ | Line | Issue | Severity |
173
+ |------|-------|----------|
174
+ | 11 | Hardcoded `JWT_SECRET` in source | High |
175
+ | 16 | SQL injection via template literal in `prepare()` | Critical |
176
+ | 18 | JWT issued without `expiresIn` — tokens valid forever | Medium |
177
+ | 25 | IDOR + SQL injection: unauthenticated user access + unparameterized query | Critical |
178
+ | 31 | XSS: user query param reflected directly in HTML response | High |
179
+ | 36 | Command injection via `execSync()` with user-supplied filename | Critical |
180
+ | 42 | Path traversal: `path.join` with user-supplied filename | High |
181
+ | 48 | `new Function()` with user template — arbitrary code execution | Critical |
182
+
183
+ **Max steps:** 25
184
+
185
  ## Scoring
186
 
187
  ```
 
194
  severity_accuracy = avg(1 − |flag_sev_rank − gt_sev_rank| × 0.34) for matched issues
195
 
196
  Matching tolerance: ±2 lines, same filename, compatible issue type
197
+ Near-miss (±3-5 lines): graduated partial credit via exponential decay
198
  ```
199
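The terminal score combines detection F1 with severity accuracy, weighted as in `openenv.yaml` (`0.70 * F1 + 0.30 * severity_accuracy`). A minimal sketch of that arithmetic, assuming the formulas quoted above (function names are illustrative, not the graders' actual API; see `server/graders.py` for the real implementation):

```python
# Severity ranks used by the |flag_sev_rank - gt_sev_rank| * 0.34 penalty.
SEV_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}

def severity_accuracy(matches):
    """matches: (flagged_severity, ground_truth_severity) pairs for matched issues."""
    if not matches:
        return 0.0
    return sum(
        1 - abs(SEV_RANK[f] - SEV_RANK[g]) * 0.34 for f, g in matches
    ) / len(matches)

def terminal_score(tp, fp, fn, matches):
    # Standard F1 over flagged vs. ground-truth issues, then the 0.70/0.30 blend.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return 0.70 * f1 + 0.30 * severity_accuracy(matches)
```

Note that finding all issues but mislabeling one severity by a single rank on a 4-issue task still scores about 0.97, so severity errors are cheap relative to missed issues.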
 
200
+ ## Reward Design
201
+
202
+ ### Per-step rewards
203
+
204
+ | Event | Reward |
205
+ |-------|--------|
206
+ | True positive (TP) | +0.10 base |
207
+ | TP + severity exact match | +0.02 bonus |
208
+ | TP + early (first 40% of steps) | +0.02 bonus |
209
+ | TP + high confidence (≥0.7) | +0.01 bonus |
210
+ | PBRS potential shaping (Φ(s')−Φ(s)) | +0.03–0.08 |
211
+ | Near-miss (±3-5 lines, exponential decay) | +0.020–0.055 |
212
+ | False positive | −0.05 |
213
+ | False positive flood (4th+ FP) | escalating −0.03 extra |
214
+ | High-confidence FP | −0.03 extra |
215
+ | `clear_flag` on a true positive (undoes a correct flag) | −0.03 |
216
+ | `clear_flag` on a false positive | +0.03 |
217
+ | Hint | −0.01 |
218
+ | Submit / auto-end | Final F1 score |
219
+
220
+ ### Reward shaping foundations
221
+
222
+ - **Potential-Based Reward Shaping** (Ng et al. 1999): Φ(s) = (tp/total_gt) × 0.5. Policy-invariant shaping that improves sample efficiency without changing the optimal policy.
223
+ - **Graduated near-miss** (exponential decay): reward = 0.10 × e^(−0.6 × (line_diff − 2)) for lines 3-5 off. Gives smooth gradient signal for line-number refinement.
224
+ - **Variable-Length Return Normalization** (VL Norm 2025): normalized_return = cumulative_reward / steps_used. Makes return comparable across tasks of different lengths.
225
+ - **Flood protection**: escalating FP penalty prevents reward hacking via flag-spamming.
226
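The shaping terms above are small enough to sketch directly. A hedged reconstruction from the constants in this section (function names are mine, not the environment's):

```python
import math

def pbrs_term(tp_before, tp_after, total_gt, gamma=1.0):
    # Ng et al. 1999: adding gamma*phi(s') - phi(s) to the reward leaves the
    # optimal policy unchanged; here phi(s) = (tp / total_gt) * 0.5.
    phi = lambda tp: (tp / total_gt) * 0.5 if total_gt else 0.0
    return gamma * phi(tp_after) - phi(tp_before)

def near_miss_reward(line_diff):
    # Graduated credit for flags 3-5 lines off: 0.10 * exp(-0.6 * (diff - 2)).
    if 3 <= line_diff <= 5:
        return 0.10 * math.exp(-0.6 * (line_diff - 2))
    return 0.0

def vl_normalized_return(cumulative_reward, steps_used):
    # VL Norm: per-step return, comparable across tasks of different lengths.
    return cumulative_reward / steps_used if steps_used else 0.0
```

For example, confirming one more ground-truth issue on a 7-issue task yields a PBRS bump of 0.5/7 ≈ 0.071, consistent with the +0.03–0.08 row in the per-step table.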
+
227
  ## API Endpoints
228
 
229
  | Method | Endpoint | Description |
 
327
 
328
  ## Baseline Scores
329
 
330
+ | Task | Keyword heuristic |
331
+ |------|-------------------|
332
+ | bug-detection | 1.00 |
333
+ | security-audit | 0.75 |
334
+ | async-review | 0.71 |
335
+ | comprehensive-review | 0.66 |
336
+ | api-security | 0.83 |
337
+ | js-security | 0.70 |
338
+ | data-pipeline | 0.55 |
339
+ | **Overall (7 tasks)** | **0.74** |
340
 
341
+ Keyword heuristic runs via `inference.py` with no API key (uses `/baseline` endpoint). LLM scores use `API_BASE_URL` + `HF_TOKEN`.
342
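The baseline flow can also be reproduced with the standard library alone. A sketch mirroring `run_keyword_fallback` (the URL default and response shape come from `inference.py`; treat this as illustrative, not a supported client):

```python
import json
import urllib.request

def fetch_baseline(base_url="http://localhost:7860"):
    # POST /baseline, as run_keyword_fallback does (no API key required).
    req = urllib.request.Request(f"{base_url}/baseline", method="POST")
    with urllib.request.urlopen(req, timeout=30) as resp:
        return json.loads(resp.read())

def extract_score(payload, task_id):
    # Response shape: {"baseline_scores": {"<task_id>": {"score": <float>}, ...}}
    return payload.get("baseline_scores", {}).get(task_id, {}).get("score", 0.0)
```

Missing tasks fall back to 0.0, matching the inference script's behavior.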
 
343
  ## Project Structure
344
 
 
355
  ├── client.py ← HTTP client
356
  ├── models.py ← ReviewAction, ReviewObservation, ReviewState, Issue
357
  ├── tasks/
358
+ │ └── data.py ← 7 task definitions + ground truth
359
+ │ (bug-detection, security-audit, comprehensive-review,
360
+ │ async-review, data-pipeline, api-security, js-security)
361
  ├── server/
362
  │ ├── app.py ← FastAPI application
363
+ │ ├── environment.py ← Core environment logic (adaptive hints, rich rewards)
364
+ │ └── graders.py ← F1 grading + detailed grading + keyword baseline
365
  └── tests/
366
  ├── test_environment.py
367
  └── test_graders.py
inference.py CHANGED
@@ -4,7 +4,7 @@ Inference script for the Code Review Environment.
4
  Environment variables:
5
  API_BASE_URL — LLM API endpoint (e.g. https://openrouter.ai/api/v1)
6
  MODEL_NAME — Model identifier (e.g. openai/gpt-4o-mini)
7
- HF_TOKEN — API key for the LLM provider
8
  ENV_URL — Environment base URL (default: localhost:7860)
9
 
10
  Usage:
@@ -19,6 +19,7 @@ import os
19
  import sys
20
  import json
21
  import time
 
22
 
23
  import httpx
24
 
@@ -27,24 +28,76 @@ MODEL_NAME: str = os.environ.get("MODEL_NAME", "gpt-4o-mini")
27
  HF_TOKEN: str = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY", "")
28
  ENV_URL: str = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
29
 
30
- TASK_IDS = ["bug-detection", "security-audit", "comprehensive-review"]
31
 
32
  SYSTEM_PROMPT = """\
33
- You are an expert software engineer performing a thorough code review.
34
-
35
- Your job is to identify bugs, security vulnerabilities, and performance issues in code.
36
-
37
- For each issue you find, respond with a single JSON object:
38
- {"action_type": "flag_issue", "line_number": <int>, "filename": "<file>", "issue_type": "bug|security|performance|logic", "severity": "low|medium|high|critical", "description": "<explanation>", "fix_suggestion": "<fix>"}
39
-
40
- When done, respond with:
41
- {"action_type": "submit_review"}
42
-
43
- Rules:
44
- - Respond with raw JSON only — no markdown fences, no extra text
45
  - One action per response
46
- - Be precise with line numbers (count from line 1)
47
- - Only flag real issues, not style preferences
48
  """
49
 
50
 
@@ -59,13 +112,17 @@ def chat_completion(messages: list) -> str:
59
  kwargs["base_url"] = API_BASE_URL
60
 
61
  client = OpenAI(**kwargs)
62
- response = client.chat.completions.create(
63
- model=MODEL_NAME,
64
- messages=messages,
65
- temperature=0.0,
66
- max_tokens=400,
67
- )
68
- return response.choices[0].message.content.strip()
69
 
70
 
71
  def parse_action(text: str) -> dict:
@@ -100,45 +157,217 @@ def parse_action(text: str) -> dict:
100
 
101
  def run_keyword_fallback(base_url: str, task_id: str) -> dict:
102
  """Fallback: use the built-in /baseline endpoint (no LLM needed)."""
103
- with httpx.Client(timeout=30) as client:
104
- resp = client.post(f"{base_url}/baseline")
105
- resp.raise_for_status()
106
- results = resp.json()
107
- score = results["baseline_scores"].get(task_id, {}).get("score", 0.0)
108
- return {"task_id": task_id, "score": score, "steps": 0, "method": "keyword_heuristic"}
109
 
110
 
111
  def run_task(task_id: str, http_client: httpx.Client) -> dict:
112
- resp = http_client.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30)
113
- resp.raise_for_status()
114
- obs = resp.json()
115
 
116
  code_display = "\n\n".join(
117
- f"=== {fname} ===\n{code}"
118
  for fname, code in obs.get("code_files", {}).items()
119
  )
120
 
121
  messages = [
122
  {"role": "system", "content": SYSTEM_PROMPT},
123
  {
124
  "role": "user",
125
  "content": (
126
- f"Task: {obs.get('task_description', '')}\n\n"
127
- f"{code_display}\n\n"
128
- f"Review this code carefully. Flag every issue you find. "
129
- f"You have {obs.get('max_steps', 20)} steps total."
130
  ),
131
  },
132
  ]
133
 
134
  done = False
135
  step_count = 0
136
- max_steps = obs.get("max_steps", 20)
137
  final_score = 0.0
138
 
139
  while not done and step_count < max_steps:
140
- action_text = chat_completion(messages)
141
- action = parse_action(action_text)
142
 
143
  try:
144
  step_resp = http_client.post(f"{ENV_URL}/step", json=action, timeout=30)
@@ -150,20 +379,33 @@ def run_task(task_id: str, http_client: httpx.Client) -> dict:
150
 
151
  done = obs.get("done", False)
152
  step_count += 1
153
- final_score = obs.get("current_score", 0.0)
154
- reward = obs.get("reward")
155
 
156
  messages.append({"role": "assistant", "content": action_text})
157
- messages.append({
158
- "role": "user",
159
- "content": (
160
- f"Feedback: {obs.get('feedback', '')} "
161
- f"(step {step_count}/{max_steps}, score: {obs.get('current_score', 0.0):.3f})"
162
- ),
163
- })
164
 
165
  atype = action.get("action_type", "")
166
- print(f" Step {step_count:2d}: {atype:20s} | reward={str(reward):8s} | score={obs.get('current_score', 0.0):.3f}")
167
 
168
  if atype == "submit_review":
169
  final_score = obs.get("reward", obs.get("current_score", 0.0)) or 0.0
@@ -205,7 +447,7 @@ def main():
205
  print(f"Running task: {task_id}")
206
  result = run_task(task_id, client)
207
  results[task_id] = result
208
- print(f" → score: {result['score']:.4f} ({result['steps']} steps)\n")
209
  else:
210
  print("HF_TOKEN / API_BASE_URL not set — using built-in keyword heuristic baseline.\n")
211
  for task_id in TASK_IDS:
 
4
  Environment variables:
5
  API_BASE_URL — LLM API endpoint (e.g. https://openrouter.ai/api/v1)
6
  MODEL_NAME — Model identifier (e.g. openai/gpt-4o-mini)
7
+ HF_TOKEN — API key for the LLM provider (also accepts OPENAI_API_KEY)
8
  ENV_URL — Environment base URL (default: localhost:7860)
9
 
10
  Usage:
 
19
  import sys
20
  import json
21
  import time
22
+ from typing import Optional
23
 
24
  import httpx
25
 
 
28
  HF_TOKEN: str = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY", "")
29
  ENV_URL: str = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
30
 
31
+ # Curriculum ordering: easy → medium → medium-hard → hard
32
+ # Research (CAMRL, Curriculum RL): start with simpler tasks to build
33
+ # foundational skills, progress to harder multi-file and multi-language tasks.
34
+ TASK_IDS = [
35
+ "bug-detection", # easy: pure logic bugs, single file
36
+ "security-audit", # medium: OWASP Top-10, single file
37
+ "async-review", # medium-hard: async concurrency, subtle bugs
38
+ "data-pipeline", # hard: SQL injection + crypto + performance
39
+ "comprehensive-review", # hard: multi-file Django, mixed issue types
40
+ "api-security", # hard: FastAPI auth/authz/injection
41
+ "js-security", # hard: JavaScript (cross-language generalization)
42
+ ]
43
 
44
  SYSTEM_PROMPT = """\
45
+ You are an expert software engineer performing a thorough, methodical code review.
46
+
47
+ Your mission: identify ALL real bugs, security vulnerabilities, and performance issues.
48
+
49
+ ## REVIEW CHECKLIST: work through EVERY category for EVERY function
50
+
51
+ ### Security (check EVERY function for these)
52
+ - Hardcoded secrets / API keys / passwords / tokens
53
+ - SQL injection: f-strings/template literals/string concat in queries
54
+ - Command injection: shell=True, os.system(), execSync() with user input
55
+ - XSS: unsanitized user input in HTML templates / res.send()
56
+ - Path traversal: path.join/os.path.join with user-supplied paths
57
+ - IDOR: missing authorization — authenticated vs authorized
58
+ - Insecure deserialization: pickle.loads(), new Function(), eval() on user input
59
+ - Broken crypto: MD5/SHA1 for passwords; missing salt; weak PRNG
60
+ - JWT issues: missing expiry ('exp'), algorithm confusion, hardcoded secret
61
+ - Missing authentication on sensitive endpoints
62
+
63
+ ### Bugs & Logic Errors (check EVERY function for these)
64
+ - Off-by-one errors in ranges, slices, loop bounds, retry conditions
65
+ - Wrong initial values (counters starting at 0 instead of 1)
66
+ - Race conditions (shared mutable state without locks/atomicity)
67
+ - Missing transaction atomicity (partial writes to DB)
68
+ - Wrong type arguments (int where object required, e.g. aiohttp timeout)
69
+ - State that accumulates across calls (class fields not reset)
70
+
71
+ ### Performance (check EVERY loop and DB call)
72
+ - N+1 database queries (DB call inside a loop)
73
+ - Sequential async where gather() should be used
74
+ - One transaction per row in bulk operations
75
+ - Uncapped pagination (no max limit on per_page)
76
+
77
+ ### Resource Management
78
+ - Unclosed sessions/connections/file handles
79
+ - Missing context managers (async with, with)
80
+
81
+ ## RESPONSE FORMAT
82
+
83
+ For each issue you find, respond with ONE raw JSON object:
84
+ {"action_type": "flag_issue", "line_number": <int>, "filename": "<file>",
85
+ "issue_type": "bug|security|performance|logic",
86
+ "severity": "low|medium|high|critical",
87
+ "description": "<specific explanation>",
88
+ "fix_suggestion": "<concrete fix>",
89
+ "confidence": <0.0-1.0>}
90
+
91
+ When finished, respond with:
92
+ {"action_type": "submit_review"}
93
+
94
+ ## RULES
95
+ - Raw JSON only — no markdown fences, no extra text
96
  - One action per response
97
+ - Count lines carefully from line 1 (including blank lines and comments)
98
+ - Only flag REAL issues: no style preferences, no hypothetical issues
99
+ - Be precise: "SQL injection at line 19 via f-string in SELECT query" not just "SQL injection"
100
+ - Flag the EXACT line where the problem code is (the f-string line, not the function def)
101
  """
102
 
103
 
 
112
  kwargs["base_url"] = API_BASE_URL
113
 
114
  client = OpenAI(**kwargs)
115
+ try:
116
+ response = client.chat.completions.create(
117
+ model=MODEL_NAME,
118
+ messages=messages,
119
+ temperature=0.0,
120
+ max_tokens=500,
121
+ )
122
+ return response.choices[0].message.content.strip()
123
+ except Exception as e:
124
+ print(f" LLM call failed: {e}")
125
+ raise
126
 
127
 
128
  def parse_action(text: str) -> dict:
 
157
 
158
  def run_keyword_fallback(base_url: str, task_id: str) -> dict:
159
  """Fallback: use the built-in /baseline endpoint (no LLM needed)."""
160
+ try:
161
+ with httpx.Client(timeout=30) as client:
162
+ resp = client.post(f"{base_url}/baseline")
163
+ resp.raise_for_status()
164
+ results = resp.json()
165
+ score = results["baseline_scores"].get(task_id, {}).get("score", 0.0)
166
+ return {"task_id": task_id, "score": score, "steps": 0, "method": "keyword_heuristic"}
167
+ except Exception as e:
168
+ print(f" Keyword fallback failed: {e}")
169
+ return {"task_id": task_id, "score": 0.0, "steps": 0, "method": "error"}
170
+
171
+
172
+ def _build_progress_feedback(obs: dict) -> str:
173
+ """Build a rich feedback string from observation progress data."""
174
+ progress = obs.get("progress") or {}
175
+ flagged_summary = obs.get("flagged_summary") or {}
176
+
177
+ parts = []
178
+ if progress:
179
+ f1 = progress.get("f1", 0)
180
+ precision = progress.get("precision", 0)
181
+ recall = progress.get("recall", 0)
182
+ tp = int(progress.get("true_positives", 0))
183
+ total_gt = int(progress.get("total_ground_truth", 0))
184
+ steps_rem = int(progress.get("steps_remaining", 0))
185
+ unfound = progress.get("unfound_issue_types", [])
186
+
187
+ parts.append(
188
+ f"Score progress: {tp}/{total_gt} issues confirmed | "
189
+ f"F1={f1:.2f} Precision={precision:.2f} Recall={recall:.2f} | "
190
+ f"{steps_rem} steps remaining"
191
+ )
192
+ if unfound:
193
+ parts.append(
194
+ f"IMPORTANT — still need to find: {unfound}. "
195
+ f"Search specifically for those issue types."
196
+ )
197
+
198
+ if flagged_summary:
199
+ incorrect = flagged_summary.get("incorrect", 0)
200
+ near = flagged_summary.get("near_misses", 0)
201
+ if incorrect > 0:
202
+ parts.append(
203
+ f"WARNING: {incorrect} false positive(s) hurting precision. "
204
+ f"Consider using clear_flag to remove uncertain flags."
205
+ )
206
+ if near > 0:
207
+ parts.append(
208
+ f"NOTE: {near} near-miss(es) — you're close on line numbers, "
209
+ f"but slightly off. Re-check exact line and try reflagging."
210
+ )
211
+
212
+ return "\n".join(parts) if parts else ""
213
+
214
+
215
+ def _should_submit(obs: dict, step_count: int, max_steps: int) -> bool:
216
+ """
217
+ Smart submission: submit when recall is high or steps are nearly exhausted.
218
+ Avoids wasting steps after all real issues are found.
219
+ """
220
+ progress = obs.get("progress", {})
221
+ recall = progress.get("recall", 0.0)
222
+ tp = int(progress.get("true_positives", 0))
223
+ total_gt = int(progress.get("total_ground_truth", 0))
224
+ steps_rem = int(progress.get("steps_remaining", 0))
225
+ unfound = progress.get("unfound_issue_types", [])
226
+ fp = int(progress.get("false_positives", 0))
227
+
228
+ # All issues found
229
+ if total_gt > 0 and tp >= total_gt:
230
+ return True
231
+
232
+ # No unfound categories and high recall
233
+ if not unfound and recall >= 0.85:
234
+ return True
235
+
236
+ # High recall overall (≥80%) and precision is decent (not too many FPs)
237
+ if recall >= 0.80 and (fp <= 2 or tp / max(tp + fp, 1) >= 0.6):
238
+ return True
239
+
240
+ # Very few steps left and we've done a reasonable scan
241
+ if steps_rem <= 2 and step_count >= 5:
242
+ return True
243
+
244
+ return False
245
+
246
+
247
+ def _should_clear_flag(obs: dict, last_reward: float, last_action: dict) -> Optional[dict]:
248
+ """
249
+ Recovery strategy: if the last flag was a false positive with high penalty,
250
+ suggest clearing it to recover partial reward and improve precision.
251
+
252
+ Returns a clear_flag action dict if we should recover, else None.
253
+ """
254
+ if last_reward is None or last_reward >= 0:
255
+ return None
256
+ if last_action.get("action_type") != "flag_issue":
257
+ return None
258
+
259
+ # Only clear if it was a clear FP (no near-miss indicator in feedback)
260
+ # and we've got too many false positives
261
+ progress = obs.get("progress", {})
262
+ fp = int(progress.get("false_positives", 0))
263
+ tp = int(progress.get("true_positives", 0))
264
+
265
+ # If FP > TP and last reward was notably negative, clear the bad flag
266
+ if fp > tp and last_reward <= -0.05:
267
+ return {
268
+ "action_type": "clear_flag",
269
+ "line_number": last_action.get("line_number"),
270
+ "filename": last_action.get("filename"),
271
+ }
272
+
273
+ return None
274
 
275
 
276
  def run_task(task_id: str, http_client: httpx.Client) -> dict:
277
+ try:
278
+ resp = http_client.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30)
279
+ resp.raise_for_status()
280
+ obs = resp.json()
281
+ except Exception as e:
282
+ print(f" Reset failed: {e} — falling back to keyword heuristic")
283
+ return run_keyword_fallback(ENV_URL, task_id)
284
 
285
  code_display = "\n\n".join(
286
+ f"=== {fname} (starting at line 1) ===\n{code}"
287
  for fname, code in obs.get("code_files", {}).items()
288
  )
289
 
290
+ # Include function map hint if available
291
+ code_metadata = obs.get("code_metadata") or {}
292
+ function_ranges = code_metadata.get("function_ranges", [])
293
+ fn_map_hint = ""
294
+ if function_ranges:
295
+ fn_lines = [f" {fr['name']}() in {fr['file']} (lines {fr['start']}-{fr['end']})"
296
+ for fr in function_ranges]
297
+ fn_map_hint = "\n\nFunction map:\n" + "\n".join(fn_lines)
298
+
299
+ task_desc = obs.get("task_description", "")
300
+ max_steps = obs.get("max_steps", 20)
301
+ issue_categories = code_metadata.get("issue_categories", [])
302
+ n_gt = len(obs.get("code_files", {})) # file count, used only as a rough complexity hint
303
+ category_hint = ""
304
+ if issue_categories:
305
+ category_hint = f"\nIssue categories to look for: {sorted(set(issue_categories))}"
306
+
307
+ # RC-GRPO style reward conditioning (2025): tell the agent what quality level
308
+ # it should aim for, so it calibrates confidence appropriately.
309
+ state_features = code_metadata.get("state_features", [])
310
+ complexity_label = "medium"
311
+ if state_features and len(state_features) >= 4:
312
+ complexity_score = state_features[3]
313
+ complexity_label = "high" if complexity_score >= 1.0 else "medium" if complexity_score >= 0.5 else "low"
314
+
315
+ reward_conditioning = (
316
+ f"[TARGET: high-quality review, score ≥ 0.85. "
317
+ f"Code complexity: {complexity_label}. "
318
+ f"Be thorough — missing issues costs more than a single FP.]"
319
+ )
320
+
321
  messages = [
322
  {"role": "system", "content": SYSTEM_PROMPT},
323
  {
324
  "role": "user",
325
  "content": (
326
+ f"{reward_conditioning}\n\n"
327
+ f"Task: {task_desc}\n\n"
328
+ f"{code_display}"
329
+ f"{fn_map_hint}"
330
+ f"{category_hint}\n\n"
331
+ f"You have {max_steps} steps total. "
332
+ f"Work through the checklist systematically, function by function. "
333
+ f"Flag each issue one at a time as a raw JSON object."
334
  ),
335
  },
336
  ]
337
 
338
  done = False
339
  step_count = 0
 
340
  final_score = 0.0
341
+ last_action: dict = {}
342
+ last_reward: Optional[float] = None
343
+ consecutive_fp = 0
344
 
345
  while not done and step_count < max_steps:
346
+ # --- Auto clear_flag recovery: undo recent FP if hurting precision ---
347
+ recovery_action = _should_clear_flag(obs, last_reward, last_action)
348
+ if recovery_action and step_count < max_steps - 1:
349
+ action = recovery_action
350
+ action_text = json.dumps(action)
351
+ print(f" Auto-recovery: clearing FP at {action.get('filename')}:{action.get('line_number')}")
352
+ else:
353
+ # --- Normal LLM action ---
354
+ try:
355
+ action_text = chat_completion(messages)
356
+ except Exception as e:
357
+ print(f" LLM unavailable ({e}) — submitting and falling back to keyword heuristic")
358
+ try:
359
+ http_client.post(f"{ENV_URL}/step", json={"action_type": "submit_review"}, timeout=30)
360
+ except Exception:
361
+ pass
362
+ return run_keyword_fallback(ENV_URL, task_id)
363
+
364
+ action = parse_action(action_text)
365
+
366
+ # Smart submission: inject submit if progress shows we're done
367
+ if action.get("action_type") != "submit_review" and _should_submit(obs, step_count, max_steps):
368
+ print(f" Smart submit at step {step_count + 1} (recall target met)")
369
+ action = {"action_type": "submit_review"}
370
+ action_text = json.dumps(action)
371
 
372
  try:
373
  step_resp = http_client.post(f"{ENV_URL}/step", json=action, timeout=30)
 
379
 
380
  done = obs.get("done", False)
381
  step_count += 1
382
+ last_reward = obs.get("reward")
383
+ # Use terminal reward (final grade) when done, else intermediate score
384
+ if done:
385
+ final_score = last_reward or obs.get("current_score", 0.0)
386
+ else:
387
+ final_score = obs.get("current_score", 0.0)
388
+ last_action = action
389
+
390
+ # Track consecutive FPs for logging
391
+ if last_reward is not None and last_reward < 0 and action.get("action_type") == "flag_issue":
392
+ consecutive_fp += 1
393
+ else:
394
+ consecutive_fp = 0
395
+
396
+ # Build rich feedback for next LLM turn
397
+ progress_feedback = _build_progress_feedback(obs)
398
+ env_feedback = obs.get("feedback", "")
399
+ combined_feedback = env_feedback
400
+ if progress_feedback:
401
+ combined_feedback += f"\n{progress_feedback}"
402
 
403
  messages.append({"role": "assistant", "content": action_text})
404
+ if combined_feedback:
405
+ messages.append({"role": "user", "content": combined_feedback})
406
 
407
  atype = action.get("action_type", "")
408
+ print(f" Step {step_count:2d}: {atype:20s} | reward={str(last_reward):8s} | score={obs.get('current_score', 0.0):.3f}")
409
 
410
  if atype == "submit_review":
411
  final_score = obs.get("reward", obs.get("current_score", 0.0)) or 0.0
 
447
  print(f"Running task: {task_id}")
448
  result = run_task(task_id, client)
449
  results[task_id] = result
450
+ print(f" → score: {result['score']:.4f} ({result['steps']} steps, method={result['method']})\n")
451
  else:
452
  print("HF_TOKEN / API_BASE_URL not set — using built-in keyword heuristic baseline.\n")
453
  for task_id in TASK_IDS:
models.py CHANGED
@@ -80,6 +80,8 @@ class ReviewAction(_BaseAction):
80
  severity: Optional[str] = None
81
  description: str = ""
82
  fix_suggestion: Optional[str] = None
83
  metadata: Dict[str, Any] = field(default_factory=dict)
84
 
85
  def to_dict(self) -> dict:
@@ -91,6 +93,8 @@ class ReviewAction(_BaseAction):
91
  "severity": self.severity,
92
  "description": self.description,
93
  "fix_suggestion": self.fix_suggestion,
94
  }
95
 
96
  @classmethod
@@ -103,6 +107,8 @@ class ReviewAction(_BaseAction):
103
  severity=d.get("severity"),
104
  description=str(d.get("description", "")),
105
  fix_suggestion=d.get("fix_suggestion"),
106
  )
107
 
108
 
@@ -125,6 +131,11 @@ class ReviewObservation(_BaseObservation):
125
  done: bool = False
126
  reward: Optional[float] = None
127
  metadata: Dict[str, Any] = field(default_factory=dict)
128
 
129
  def to_dict(self) -> dict:
130
  return {
@@ -141,6 +152,10 @@ class ReviewObservation(_BaseObservation):
141
  "done": self.done,
142
  "reward": self.reward,
143
  "metadata": self.metadata,
144
  }
145
 
146
  @classmethod
@@ -158,6 +173,11 @@ class ReviewObservation(_BaseObservation):
158
  current_score=d.get("current_score", 0.0),
159
  done=d.get("done", False),
160
  reward=d.get("reward"),
161
  )
162
 
163
 
 
80
  severity: Optional[str] = None
81
  description: str = ""
82
  fix_suggestion: Optional[str] = None
83
+ confidence: Optional[float] = None # agent's confidence 0.0–1.0
84
+ related_lines: Optional[List[int]] = None # multi-line issues
85
  metadata: Dict[str, Any] = field(default_factory=dict)
86
 
87
  def to_dict(self) -> dict:
 
93
  "severity": self.severity,
94
  "description": self.description,
95
  "fix_suggestion": self.fix_suggestion,
96
+ "confidence": self.confidence,
97
+ "related_lines": self.related_lines,
98
  }
99
 
100
  @classmethod
 
107
  severity=d.get("severity"),
108
  description=str(d.get("description", "")),
109
  fix_suggestion=d.get("fix_suggestion"),
110
+ confidence=d.get("confidence"),
111
+ related_lines=d.get("related_lines"),
112
  )
113
 
114
 
 
131
  done: bool = False
132
  reward: Optional[float] = None
133
  metadata: Dict[str, Any] = field(default_factory=dict)
134
+ # New fields
135
+ reward_breakdown: Dict[str, float] = field(default_factory=dict)
136
+ progress: Dict[str, float] = field(default_factory=dict)
137
+ flagged_summary: Dict[str, Any] = field(default_factory=dict)
138
+ code_metadata: Dict[str, Any] = field(default_factory=dict)
139
 
140
  def to_dict(self) -> dict:
141
  return {
 
152
  "done": self.done,
153
  "reward": self.reward,
154
  "metadata": self.metadata,
155
+ "reward_breakdown": self.reward_breakdown,
156
+ "progress": self.progress,
157
+ "flagged_summary": self.flagged_summary,
158
+ "code_metadata": self.code_metadata,
159
  }
160
 
161
  @classmethod
 
173
  current_score=d.get("current_score", 0.0),
174
  done=d.get("done", False),
175
  reward=d.get("reward"),
176
+ metadata=d.get("metadata", {}),
177
+ reward_breakdown=d.get("reward_breakdown", {}),
178
+ progress=d.get("progress", {}),
179
+ flagged_summary=d.get("flagged_summary", {}),
180
+ code_metadata=d.get("code_metadata", {}),
181
  )
182
 
183
 
openenv.yaml CHANGED
@@ -1,11 +1,58 @@
1
  spec_version: 1
2
  name: code_review_env
3
- version: "1.0.0"
4
  description: >
5
- A code review and security audit environment for training AI agents.
6
  The agent identifies bugs, security vulnerabilities, and performance issues
7
- across three tasks of increasing difficulty (easy → medium → hard).
 
 
8
  type: space
9
  runtime: fastapi
10
  app: server.app:app
 
11
  port: 7860
1
  spec_version: 1
2
  name: code_review_env
3
+ version: "2.0.0"
4
  description: >
5
+ A code review and security audit RL environment for training AI agents.
6
  The agent identifies bugs, security vulnerabilities, and performance issues
7
+ across 7 tasks of increasing difficulty (easy → medium → medium-hard → hard).
8
+ Features: PBRS reward shaping, graduated near-miss rewards, flood protection,
9
+ CAMRL curriculum selector, VL return normalization, and cross-language tasks.
10
  type: space
11
  runtime: fastapi
12
  app: server.app:app
13
+ entry_point: server
14
  port: 7860
15
+
16
+ tasks:
17
+ - id: bug-detection
18
+ difficulty: easy
19
+ language: python
20
+ num_issues: 3
21
+ max_steps: 15
22
+ - id: security-audit
23
+ difficulty: medium
24
+ language: python
25
+ num_issues: 7
26
+ max_steps: 20
27
+ - id: async-review
28
+ difficulty: medium-hard
29
+ language: python
30
+ num_issues: 6
31
+ max_steps: 20
32
+ - id: data-pipeline
33
+ difficulty: hard
34
+ language: python
35
+ num_issues: 7
36
+ max_steps: 25
37
+ - id: comprehensive-review
38
+ difficulty: hard
39
+ language: python
40
+ num_issues: 9
41
+ max_steps: 30
42
+ - id: api-security
43
+ difficulty: hard
44
+ language: python
45
+ num_issues: 8
46
+ max_steps: 25
47
+ - id: js-security
48
+ difficulty: hard
49
+ language: javascript
50
+ num_issues: 8
51
+ max_steps: 25
52
+
53
+ reward_design:
54
+ terminal: "0.70 * F1 + 0.30 * severity_accuracy"
55
+ shaping: "PBRS (Ng et al. 1999): phi(s) = (tp/total_gt) * 0.5"
56
+ near_miss: "exponential decay: 0.10 * exp(-0.6 * (line_diff - 2))"
57
+ flood_protection: "escalating FP penalty after 3rd false positive"
58
+ normalization: "VL Norm (2025): normalized_return = cumulative / steps_used"
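The `reward_design` formulas above can be sketched in plain Python (a minimal illustration of the stated math; the function names here are assumptions, not identifiers from `server/graders.py`):

```python
import math

def potential(tp: int, total_gt: int) -> float:
    # phi(s) = (tp / total_gt) * 0.5, as stated under reward_design
    return 0.5 * tp / total_gt if total_gt else 0.0

def pbrs_shaping(tp_before: int, tp_after: int, total_gt: int,
                 gamma: float = 1.0) -> float:
    # Potential-based shaping term gamma * phi(s') - phi(s) (Ng et al. 1999),
    # which leaves the optimal policy unchanged
    return gamma * potential(tp_after, total_gt) - potential(tp_before, total_gt)

def near_miss_reward(line_diff: int) -> float:
    # Exponential decay outside the +/-2-line exact-match window
    return 0.10 * math.exp(-0.6 * (line_diff - 2))

def vl_normalized_return(step_rewards: list, steps_used: int) -> float:
    # VL Norm: cumulative return divided by episode length, making returns
    # comparable across tasks with different max_steps
    return sum(step_rewards) / max(steps_used, 1)
```

With the stated constants, `near_miss_reward(2)` returns the full 0.10 at the edge of the exact-match window and decays smoothly beyond it.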
server/app.py CHANGED
@@ -21,6 +21,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21
  import json
22
  import asyncio
23
  import dataclasses
 
24
  from typing import Optional, List, Dict, Any
25
 
26
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
@@ -29,7 +30,10 @@ from pydantic import BaseModel
29
 
30
  from models import ReviewAction, Issue
31
  from server.environment import CodeReviewEnvironment
32
- from server.graders import grade_episode, run_keyword_baseline
 
 
 
33
  from tasks.data import ALL_TASKS, TASK_IDS
34
 
35
 
@@ -45,6 +49,7 @@ def _serialize(obj) -> dict:
45
 
46
 
47
  _env_instance = CodeReviewEnvironment()
 
48
 
49
 
50
  def _make_app() -> FastAPI:
@@ -245,27 +250,25 @@ async def run_grader(request: GraderRequest):
245
 
246
  flagged = [Issue.from_dict(i) for i in request.flagged_issues]
247
  ground_truth = [Issue.from_dict(gt) for gt in task["ground_truth_issues"]]
248
- score = grade_episode(flagged, ground_truth)
249
-
250
- tp = sum(
251
- 1 for f in flagged
252
- if any(
253
- True for gt in ground_truth
254
- if abs(f.line_number - gt.line_number) <= 2
255
- and f.filename == gt.filename
256
- )
257
- )
258
 
259
  return {
260
  "task_id": request.task_id,
261
  "difficulty": task["difficulty"],
262
- "score": score,
263
  "max_score": 1.0,
 
 
 
 
264
  "details": {
265
  "total_flagged": len(flagged),
266
- "true_positives": tp,
267
- "false_positives": len(flagged) - tp,
 
 
268
  "total_ground_truth": len(ground_truth),
 
269
  },
270
  }
271
 
@@ -296,6 +299,180 @@ async def run_baseline():
296
  }
297
 
298
 
299
  def main():
300
  import uvicorn
301
  port = int(os.environ.get("PORT", 7860))
 
21
  import json
22
  import asyncio
23
  import dataclasses
24
+ import random
25
  from typing import Optional, List, Dict, Any
26
 
27
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
 
30
 
31
  from models import ReviewAction, Issue
32
  from server.environment import CodeReviewEnvironment
33
+ from server.graders import (
34
+ grade_episode, grade_episode_detailed, run_keyword_baseline,
35
+ compute_code_state_features, RewardNormalizer,
36
+ )
37
  from tasks.data import ALL_TASKS, TASK_IDS
38
 
39
 
 
49
 
50
 
51
  _env_instance = CodeReviewEnvironment()
52
+ _reward_normalizer = RewardNormalizer(window_size=100)
53
 
54
 
55
  def _make_app() -> FastAPI:
 
250
 
251
  flagged = [Issue.from_dict(i) for i in request.flagged_issues]
252
  ground_truth = [Issue.from_dict(gt) for gt in task["ground_truth_issues"]]
253
+ detailed = grade_episode_detailed(flagged, ground_truth)
 
254
 
255
  return {
256
  "task_id": request.task_id,
257
  "difficulty": task["difficulty"],
258
+ "score": detailed["score"],
259
  "max_score": 1.0,
260
+ "f1": detailed["f1"],
261
+ "precision": detailed["precision"],
262
+ "recall": detailed["recall"],
263
+ "severity_accuracy": detailed["severity_accuracy"],
264
  "details": {
265
  "total_flagged": len(flagged),
266
+ "true_positives": detailed["true_positives"],
267
+ "false_positives": detailed["false_positives"],
268
+ "false_negatives": detailed["false_negatives"],
269
+ "near_misses": detailed["near_misses"],
270
  "total_ground_truth": len(ground_truth),
271
+ "per_file": detailed["per_file"],
272
  },
273
  }
274
 
 
299
  }
300
 
301
 
302
+ class CurriculumRequest(BaseModel):
303
+ agent_performance: Optional[Dict[str, Any]] = None
304
+ easy_threshold: float = 0.30
305
+ hard_threshold: float = 0.70
306
+
307
+
308
+ @app.post("/curriculum")
309
+ async def curriculum_task_selector(request: CurriculumRequest):
310
+ """
311
+ CAMRL-style curriculum task selector (Curriculum-based Asymmetric Multi-Task RL, TPAMI 2023).
312
+
313
+ Given agent performance metrics per task, returns the recommended next task_id
314
+ based on curriculum phase:
315
+ - easy phase (avg_success < 0.30): focus on task with fewest issues
316
+ - medium phase (0.30-0.70): mix easy/hard (70% easy, 30% hard)
317
+ - hard phase (avg_success > 0.70): focus on least-solved hard tasks
318
+
319
+ Body:
320
+ agent_performance: {task_id: {success_rate: 0.5, episodes: 10, avg_score: 0.4}}
321
+ easy_threshold: float (default 0.3)
322
+ hard_threshold: float (default 0.7)
323
+ """
324
+ perf = request.agent_performance or {}
325
+ easy_thresh = request.easy_threshold
326
+ hard_thresh = request.hard_threshold
327
+
328
+ # Build difficulty estimate per task: (1 - success_rate) × complexity
329
+ task_difficulty: Dict[str, float] = {}
330
+ for task_id, task in ALL_TASKS.items():
331
+ n_issues = len(task["ground_truth_issues"])
332
+ complexity = min(1.0, n_issues / 10.0)
333
+ task_perf = perf.get(task_id, {})
334
+ success_rate = float(task_perf.get("success_rate", task_perf.get("avg_score", 0.0)))
335
+ task_difficulty[task_id] = round((1.0 - success_rate) * complexity, 4)
336
+
337
+ # Determine curriculum phase
338
+ if perf:
339
+ all_success = [float(p.get("success_rate", p.get("avg_score", 0.0))) for p in perf.values()]
340
+ avg_success = sum(all_success) / len(all_success)
341
+ else:
342
+ avg_success = 0.0
343
+
344
+ if avg_success < easy_thresh:
345
+ phase = "easy"
346
+ # Focus on task with lowest ground truth issue count (most approachable)
347
+ recommended = min(ALL_TASKS.keys(), key=lambda t: len(ALL_TASKS[t]["ground_truth_issues"]))
348
+ elif avg_success > hard_thresh:
349
+ phase = "hard"
350
+ # Focus on hardest unsolved task (highest difficulty score)
351
+ recommended = max(task_difficulty, key=task_difficulty.get)
352
+ else:
353
+ phase = "medium"
354
+ # Mix: pick a task proportional to difficulty (harder = more likely)
355
356
+ weights = list(task_difficulty.values())
357
+ total_w = sum(weights) or 1.0
358
+ probs = [w / total_w for w in weights]
359
+ recommended = random.choices(list(task_difficulty.keys()), weights=probs, k=1)[0]
360
+
361
+ return {
362
+ "recommended_task_id": recommended,
363
+ "curriculum_phase": phase,
364
+ "avg_success_rate": round(avg_success, 4),
365
+ "task_difficulty_scores": task_difficulty,
366
+ "thresholds": {"easy": easy_thresh, "hard": hard_thresh},
367
+ "method": "CAMRL",
368
+ }
369
+
370
+
371
+ @app.get("/reward_normalizer")
372
+ async def get_reward_normalizer_stats():
373
+ """
374
+ Return current RewardNormalizer statistics for the running environment.
375
+ Useful for monitoring VL Norm across training runs.
376
+ """
377
+ return _reward_normalizer.to_dict()
378
+
379
+
380
+ @app.post("/record_episode")
381
+ async def record_episode(body: Dict[str, Any]):
382
+ """
383
+ Record a completed episode's return and length for VL Norm statistics.
384
+ Body: {"episode_return": 0.72, "episode_length": 12}
385
+ """
386
+ episode_return = float(body.get("episode_return", 0.0))
387
+ episode_length = int(body.get("episode_length", 1))
388
+ _reward_normalizer.update(episode_return, episode_length)
389
+ normalized = _reward_normalizer.normalize(episode_return, episode_length)
390
+ return {
391
+ "normalized_return": normalized,
392
+ "stats": _reward_normalizer.to_dict(),
393
+ }
394
+
395
+
396
+ class TRLRolloutRequest(BaseModel):
397
+ task_id: Optional[str] = None
398
+ seed: Optional[int] = None
399
+ actions: List[Dict[str, Any]] # Pre-generated action sequence from LLM
400
+
401
+
402
+ @app.post("/trl_rollout")
403
+ async def trl_rollout(request: TRLRolloutRequest):
404
+ """
405
+ Run a full episode from a pre-generated action sequence.
406
+
407
+ Designed for TRL GRPOTrainer custom rollout_fn integration:
408
+ - Takes a sequence of LLM-generated actions
409
+ - Runs them through the environment
410
+ - Returns trajectory dict with per-step rewards and final score
411
+
412
+ This enables offline rollout: LLM generates all actions first,
413
+ then this endpoint evaluates them, matching TRL's batch-rollout pattern.
414
+
415
+ Body:
416
+ task_id: str (optional, random if not set)
417
+ seed: int (optional)
418
+ actions: [{action_type, line_number, filename, ...}, ...]
419
+
420
+ Returns:
421
+ trajectory: [{step, action, reward, feedback, done}]
422
+ episode_return: float (sum of step rewards)
423
+ final_score: float (terminal grade)
424
+ normalized_return: float (episode_return / num_steps)
425
+ state_features: [float] (12-dim feature vector at end of episode)
426
+ """
427
+ rollout_env = CodeReviewEnvironment()
428
+ obs = rollout_env.reset(task_id=request.task_id, seed=request.seed)
429
+
430
+ trajectory = []
431
+ episode_return = 0.0
432
+ final_score = 0.0
433
+
434
+ for step_idx, action_dict in enumerate(request.actions):
435
+ action = ReviewAction.from_dict(action_dict)
436
+ obs_step = rollout_env.step(action)
437
+ step_data = _serialize(obs_step)
438
+
439
+ reward = step_data.get("reward") or 0.0
440
+ episode_return += reward
441
+
442
+ trajectory.append({
443
+ "step": step_idx + 1,
444
+ "action": action_dict,
445
+ "reward": reward,
446
+ "reward_breakdown": step_data.get("reward_breakdown", {}),
447
+ "feedback": step_data.get("feedback", ""),
448
+ "current_score": step_data.get("current_score", 0.0),
449
+ "done": step_data.get("done", False),
450
+ })
451
+
452
+ if step_data.get("done", False):
453
+ final_score = step_data.get("reward", step_data.get("current_score", 0.0)) or 0.0
454
+ break
455
+
456
+ n_steps = max(len(trajectory), 1)
457
+ # Record in global normalizer for VL Norm statistics
458
+ _reward_normalizer.update(episode_return, n_steps)
459
+ normalized = _reward_normalizer.normalize(episode_return, n_steps)
460
+
461
+ # Get final state features
462
+ final_progress = rollout_env._compute_progress(rollout_env._task["max_steps"] if rollout_env._task else 20)
463
+
464
+ return {
465
+ "task_id": request.task_id,
466
+ "trajectory": trajectory,
467
+ "episode_return": round(episode_return, 4),
468
+ "final_score": round(final_score, 4),
469
+ "normalized_return": normalized,
470
+ "num_steps": n_steps,
471
+ "state_features": final_progress.get("state_features", []),
472
+ "final_progress": {k: v for k, v in final_progress.items() if k != "state_features"},
473
+ }
474
+
475
+
476
  def main():
477
  import uvicorn
478
  port = int(os.environ.get("PORT", 7860))
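The phase thresholds used by the `/curriculum` endpoint reduce to a small pure function (a sketch of the handler's branching with its default 0.30/0.70 thresholds; not the server's actual code):

```python
def curriculum_phase(avg_success: float,
                     easy_threshold: float = 0.30,
                     hard_threshold: float = 0.70) -> str:
    # Below easy_threshold: drill the most approachable task ("easy" phase).
    # Above hard_threshold: attack the least-solved hard task ("hard" phase).
    # In between: sample tasks weighted by difficulty ("medium" phase).
    if avg_success < easy_threshold:
        return "easy"
    if avg_success > hard_threshold:
        return "hard"
    return "medium"
```

Note that the comparisons are strict, so an agent sitting exactly at a threshold stays in the "medium" phase.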
server/environment.py CHANGED
@@ -9,11 +9,15 @@ import sys
9
  import os
10
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
 
12
- from typing import Optional, List
13
 
14
  from models import Issue, ReviewAction, ReviewObservation, ReviewState
15
  from tasks.data import ALL_TASKS, TASK_IDS
16
- from server.graders import grade_episode, compute_live_score, match_issue
 
 
 
 
17
 
18
  try:
19
  from openenv.core.env_server import Environment as _BaseEnv
@@ -25,21 +29,44 @@ except ImportError:
25
  pass
26
 
27
 
28
  class CodeReviewEnvironment(_BaseEnv):
29
  """
30
- A code review and security audit environment.
31
 
32
  The agent receives code files and must identify bugs, security
33
  vulnerabilities, and performance issues by flagging them with
34
  exact line numbers, types, and severity ratings.
35
 
36
- Episode flow:
37
- 1. reset(task_id) agent sees code files and task description
38
- 2. step(flag_issue) flag a problem; get per-step reward
39
- 3. step(clear_flag) remove an incorrectly flagged issue
40
- 4. step(request_hint) get a hint (costs -0.01 reward)
41
- 5. step(submit_review) episode ends, final grade is returned
42
- (or auto-ends when max_steps is reached)
 
 
 
 
43
  """
44
 
45
  SUPPORTS_CONCURRENT_SESSIONS = False
@@ -49,6 +76,10 @@ class CodeReviewEnvironment(_BaseEnv):
49
  self._task: Optional[dict] = None
50
  self._ground_truth: List[Issue] = []
51
  self._hint_index: int = 0
 
 
 
 
52
 
53
  def reset(
54
  self,
@@ -70,6 +101,9 @@ class CodeReviewEnvironment(_BaseEnv):
70
  for gt in self._task["ground_truth_issues"]
71
  ]
72
  self._hint_index = 0
 
 
 
73
 
74
  self._state = ReviewState(
75
  task_id=task_id,
@@ -81,6 +115,16 @@ class CodeReviewEnvironment(_BaseEnv):
81
  submitted=False,
82
  )
83
 
84
  return ReviewObservation(
85
  task_id=task_id,
86
  task_description=self._task["description"],
@@ -93,11 +137,16 @@ class CodeReviewEnvironment(_BaseEnv):
93
  feedback=(
94
  f"New episode started. Task: {self._task['difficulty'].upper()}. "
95
  f"Review the code carefully and flag all issues you find. "
96
- f"Use 'submit_review' when done."
 
97
  ),
98
  current_score=0.0,
99
  done=False,
100
  reward=None,
 
 
 
 
101
  )
102
 
103
  def step(
@@ -133,26 +182,43 @@ class CodeReviewEnvironment(_BaseEnv):
133
  action = ReviewAction.from_dict(action)
134
 
135
  self._state.step_count += 1
136
- reward, feedback = self._process_action(action)
 
 
 
 
137
 
138
  max_steps = self._task["max_steps"]
139
  auto_end = self._state.step_count >= max_steps and not self._state.submitted
140
  done = self._state.submitted or auto_end
141
 
142
  if auto_end and not self._state.submitted:
143
- # Grade what was submitted so far
144
  final = grade_episode(self._state.flagged_issues, self._ground_truth)
145
  self._state.current_score = final
146
- reward = final * 0.5 # partial credit for auto-end
 
147
  feedback += (
148
- f" Max steps reached. Auto-graded: {final:.3f}. "
149
- f"Submit earlier for best score."
150
  )
151
  self._state.submitted = True
152
 
153
  live = compute_live_score(self._state.flagged_issues, self._ground_truth)
154
  self._state.current_score = live
155
 
156
  return ReviewObservation(
157
  task_id=self._state.task_id,
158
  task_description="",
@@ -166,12 +232,130 @@ class CodeReviewEnvironment(_BaseEnv):
166
  current_score=live,
167
  done=done,
168
  reward=reward,
 
169
  )
170
 
171
  @property
172
  def state(self) -> ReviewState:
173
  return self._state
174
 
175
  def _process_action(self, action: ReviewAction):
176
  atype = (action.action_type or "").strip().lower()
177
 
@@ -187,25 +371,26 @@ class CodeReviewEnvironment(_BaseEnv):
187
  return 0.0, (
188
  f"Unknown action_type '{action.action_type}'. "
189
  "Use: flag_issue | clear_flag | request_hint | submit_review"
190
- )
191
 
192
  def _handle_flag(self, action: ReviewAction):
193
  if action.line_number is None:
194
- return -0.02, "flag_issue requires 'line_number'."
195
  if not action.filename:
196
- return -0.02, "flag_issue requires 'filename'."
197
  if action.issue_type not in ("bug", "security", "performance", "logic", None):
198
  action.issue_type = "bug"
199
  if action.severity not in ("low", "medium", "high", "critical", None):
200
  action.severity = "medium"
201
 
 
202
  for existing in self._state.flagged_issues:
203
  if (existing.line_number == action.line_number
204
  and existing.filename == action.filename):
205
  return 0.0, (
206
  f"Line {action.line_number} in {action.filename} already flagged. "
207
- "Use clear_flag first if you want to change the finding."
208
- )
209
 
210
  new_issue = Issue(
211
  line_number=action.line_number,
@@ -216,95 +401,258 @@ class CodeReviewEnvironment(_BaseEnv):
216
  fix_suggestion=action.fix_suggestion,
217
  )
218
 
219
- is_tp = any(
220
- match_issue(new_issue, gt)
221
- for gt in self._ground_truth
222
- )
223
 
224
  self._state.flagged_issues.append(new_issue)
225
 
226
- if is_tp:
227
- reward = 0.10
 
228
  feedback = (
229
- f"Good catch! Issue flagged at {action.filename}:{action.line_number}. "
230
- f"[+0.10 reward — correct finding]"
231
  )
 
232
  else:
233
- reward = -0.05
 
234
  feedback = (
235
- f"Issue flagged at {action.filename}:{action.line_number}. "
236
- f"[-0.05 reward no matching ground-truth issue nearby]"
237
  )
238
 
239
- return reward, feedback
240
 
241
  def _handle_clear(self, action: ReviewAction):
242
  if action.line_number is None or not action.filename:
243
- return -0.02, "clear_flag requires 'line_number' and 'filename'."
244
-
245
- before = len(self._state.flagged_issues)
246
- removed = None
247
- self._state.flagged_issues = [
248
- f for f in self._state.flagged_issues
249
- if not (f.line_number == action.line_number
250
- and f.filename == action.filename)
251
- ]
252
 
253
- if len(self._state.flagged_issues) == before:
 
254
  return 0.0, (
255
  f"No flagged issue found at {action.filename}:{action.line_number}."
256
- )
257
 
258
- removed_issue = Issue(
259
- line_number=action.line_number,
260
- filename=action.filename,
261
- issue_type="bug",
262
- severity="medium",
263
- )
264
  was_tp = any(match_issue(removed_issue, gt) for gt in self._ground_truth)
265
 
266
  if was_tp:
267
- reward = -0.03
 
 
 
 
 
 
268
  feedback = (
269
  f"Removed a correct finding at {action.filename}:{action.line_number}. "
270
- f"[-0.03 reward]"
271
  )
272
  else:
273
- reward = 0.03
 
 
 
274
  feedback = (
275
  f"Removed a false positive at {action.filename}:{action.line_number}. "
276
- f"[+0.03 reward — good correction]"
277
  )
278
 
279
- return reward, feedback
280
 
281
  def _handle_hint(self):
282
  hints = self._task.get("hints", [])
 
 
 
 
 
283
  if self._hint_index >= len(hints):
284
- return -0.01, "No more hints available for this task."
285
 
286
  hint = hints[self._hint_index]
287
  self._hint_index += 1
288
  remaining = len(hints) - self._hint_index
289
- return -0.01, f"Hint {self._hint_index}/{len(hints)}: {hint} ({remaining} hints left)"
 
290
 
291
  def _handle_submit(self):
292
  self._state.submitted = True
293
  final_score = grade_episode(self._state.flagged_issues, self._ground_truth)
294
  self._state.current_score = final_score
295
 
296
- tp_count = sum(
297
- 1 for f in self._state.flagged_issues
298
- if any(match_issue(f, gt) for gt in self._ground_truth)
299
- )
300
  total_gt = len(self._ground_truth)
301
  total_flagged = len(self._state.flagged_issues)
 
 
 
 
302
 
303
  feedback = (
304
  f"Review submitted! Final score: {final_score:.3f}. "
305
- f"Found {tp_count}/{total_gt} real issues. "
306
- f"Total flags: {total_flagged} "
307
- f"({'perfect' if total_flagged == tp_count else f'{total_flagged - tp_count} false positives'})."
308
  )
309
-
310
- return final_score, feedback
 
9
  import os
10
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
 
12
+ from typing import Optional, List, Dict, Any, Set
13
 
14
  from models import Issue, ReviewAction, ReviewObservation, ReviewState
15
  from tasks.data import ALL_TASKS, TASK_IDS
16
+ from server.graders import (
17
+ grade_episode, compute_live_score, match_issue, match_quality,
18
+ compute_code_metadata, grade_episode_detailed,
19
+ graduated_near_reward, compute_potential, compute_code_state_features,
20
+ )
21
 
22
  try:
23
  from openenv.core.env_server import Environment as _BaseEnv
 
29
  pass
30
 
31
 
32
+ # Reward constants
33
+ _BASE_TP_REWARD = 0.10
34
+ _NEAR_MISS_REWARD = 0.03
35
+ _BASE_FP_PENALTY = -0.05
36
+ _SEVERITY_EXACT_BONUS = 0.02 # when severity exactly matches GT
37
+ _TEMPORAL_BONUS = 0.02 # early correct flag (first 40% of steps)
38
+ _CONFIDENCE_TP_BONUS = 0.01 # high-confidence TP
39
+ _CONFIDENCE_FP_EXTRA = -0.03 # high-confidence FP (penalty multiplier)
40
+ _HINT_COST = -0.01
41
+ _REMOVE_TP_PENALTY = -0.03
42
+ _REMOVE_FP_REWARD = 0.03
43
+ _VALIDATION_PENALTY = -0.02
44
+ # Flood protection: escalating FP penalty
45
+ _FP_FLOOD_THRESHOLD = 3 # FPs before escalation kicks in
46
+ _FP_FLOOD_MULTIPLIER = 1.5 # each extra FP beyond threshold costs 1.5x more
47
+
48
+ _SEV_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}
49
+
50
+
51
  class CodeReviewEnvironment(_BaseEnv):
52
  """
53
+ A code review and security audit RL environment.
54
 
55
  The agent receives code files and must identify bugs, security
56
  vulnerabilities, and performance issues by flagging them with
57
  exact line numbers, types, and severity ratings.
58
 
59
+ Reward design:
60
+ - True positive flag: +0.10 base, +0.02 severity exact match,
61
+ +0.02 early (first 40% steps), +0.01 high-confidence TP
62
+ - Near-miss (±3-5 lines): +0.03 partial credit
63
+ - False positive: -0.05 base, escalating penalty after 3rd FP,
64
+ extra -0.03 for high-confidence FP
65
+ - Clear false positive: +0.03
66
+ - Clear true positive: -0.03
67
+ - Hint: -0.01
68
+ - Submit: final F1+severity score (0.0–1.0)
69
+ - Auto-end (max_steps): full grade score (no penalty)
70
  """
71
 
72
  SUPPORTS_CONCURRENT_SESSIONS = False
 
76
  self._task: Optional[dict] = None
77
  self._ground_truth: List[Issue] = []
78
  self._hint_index: int = 0
79
+ self._code_metadata: Dict[str, Any] = {}
80
+ self._fp_count: int = 0 # total false positives this episode
81
+ self._matched_gt_indices: Set[int] = set() # GT indices already matched
82
+ self._episode_rewards: List[float] = [] # for VL return normalization
83
 
84
  def reset(
85
  self,
 
101
  for gt in self._task["ground_truth_issues"]
102
  ]
103
  self._hint_index = 0
104
+ self._fp_count = 0
105
+ self._matched_gt_indices = set()
106
+ self._episode_rewards = []
107
 
108
  self._state = ReviewState(
109
  task_id=task_id,
 
115
  submitted=False,
116
  )
117
 
118
+ issue_categories = list({gt.issue_type for gt in self._ground_truth})
119
+ self._code_metadata = compute_code_metadata(
120
+ self._task["code_files"],
121
+ issue_categories=issue_categories,
122
+ )
123
+ # Pre-compute initial state features (progress=empty at reset)
124
+ self._code_metadata["state_features"] = compute_code_state_features(
125
+ self._code_metadata, progress={}
126
+ )
127
+
128
  return ReviewObservation(
129
  task_id=task_id,
130
  task_description=self._task["description"],
 
137
  feedback=(
138
  f"New episode started. Task: {self._task['difficulty'].upper()}. "
139
  f"Review the code carefully and flag all issues you find. "
140
+ f"Use 'submit_review' when done. "
141
+ f"Issue categories present: {sorted(set(issue_categories))}."
142
  ),
143
  current_score=0.0,
144
  done=False,
145
  reward=None,
146
+ reward_breakdown={},
147
+ progress={},
148
+ flagged_summary={},
149
+ code_metadata=self._code_metadata,
150
  )
151
 
152
  def step(
 
182
  action = ReviewAction.from_dict(action)
183
 
184
  self._state.step_count += 1
185
+ reward, feedback, reward_breakdown = self._process_action(action)
186
+
187
+ # Track episode rewards for VL return normalization
188
+ if reward is not None:
189
+ self._episode_rewards.append(float(reward))
190
 
191
  max_steps = self._task["max_steps"]
192
  auto_end = self._state.step_count >= max_steps and not self._state.submitted
193
  done = self._state.submitted or auto_end
194
 
195
  if auto_end and not self._state.submitted:
196
+ # Auto-end: grade in full (no penalty for hitting step limit)
197
  final = grade_episode(self._state.flagged_issues, self._ground_truth)
198
  self._state.current_score = final
199
+ reward = final # full score, no 0.5x penalty
200
+ reward_breakdown = {"auto_end_grade": final, "total": final}
201
  feedback += (
202
+ f" Step budget exhausted — auto-graded: {final:.3f}. "
203
+ f"Submit earlier next time for slightly cleaner feedback."
204
  )
205
  self._state.submitted = True
206
 
207
  live = compute_live_score(self._state.flagged_issues, self._ground_truth)
208
  self._state.current_score = live
209
 
210
+ progress = self._compute_progress(max_steps)
211
+ flagged_summary = self._compute_flagged_summary()
212
+
213
+ # PRM-style dense signal: expected reward-to-go
214
+ # Based on Process Reward Models research: give agent an estimate of
215
+ # how much reward is still available, so it can plan remaining steps.
216
+ tp_found = len(self._matched_gt_indices)
217
+ total_gt = len(self._ground_truth)
218
+ issues_remaining = total_gt - tp_found
219
+ # Expected: each remaining TP gives ~0.12 (base + avg severity bonus)
220
+ expected_reward_to_go = round(issues_remaining * 0.12, 3)
221
+
222
  return ReviewObservation(
223
  task_id=self._state.task_id,
224
  task_description="",
 
232
  current_score=live,
233
  done=done,
234
  reward=reward,
235
+ reward_breakdown=reward_breakdown,
236
+ progress=progress,
237
+ flagged_summary=flagged_summary,
238
+ code_metadata={}, # Only populated on reset
239
+ metadata={
240
+ "issues_remaining": issues_remaining,
241
+ "expected_reward_to_go": expected_reward_to_go,
242
+ },
243
  )
244
 
245
  @property
246
  def state(self) -> ReviewState:
247
  return self._state
248
 
249
+ # ------------------------------------------------------------------
250
+ # Progress and summary helpers
251
+ # ------------------------------------------------------------------
252
+
253
+ def _compute_progress(self, max_steps: int) -> Dict[str, Any]:
254
+ """Compute live precision/recall/f1, step stats, and unfound issue types."""
255
+ flagged = self._state.flagged_issues
256
+ gt = self._ground_truth
257
+
258
+ tp = 0
259
+ fp = 0
260
+ matched: Set[int] = set()
261
+ found_types: Set[str] = set()
262
+
263
+ for flag in flagged:
264
+ hit = False
265
+ for i, g in enumerate(gt):
266
+ if i not in matched and match_issue(flag, g):
267
+ tp += 1
268
+ matched.add(i)
269
+ found_types.add(g.issue_type)
270
+ hit = True
271
+ break
272
+ if not hit:
273
+ fp += 1
274
+
275
+ fn = len(gt) - len(matched)
276
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
277
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
278
+ f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
279
+
280
+ all_types = {g.issue_type for g in gt}
281
+ unfound_types = sorted(all_types - found_types)
282
+
283
+ steps_used = self._state.step_count
284
+ steps_remaining = max(0, max_steps - steps_used)
285
+
286
+ # Variable-Length Return Normalization (VL Norm 2025):
287
+ # normalized_return = cumulative_reward / max(steps_used, 1)
288
+ # This makes return comparable across episodes of different length,
289
+ # which is key for multi-task RL where tasks have different max_steps.
290
+ cumulative_reward = sum(self._episode_rewards)
291
+ normalized_return = round(cumulative_reward / max(steps_used, 1), 4)
292
+
293
+ progress = {
294
+ "precision": round(precision, 4),
295
+ "recall": round(recall, 4),
296
+ "f1": round(f1, 4),
297
+ "true_positives": float(tp),
298
+ "false_positives": float(fp),
299
+ "total_ground_truth": float(len(gt)),
300
+ "steps_used": float(steps_used),
301
+ "steps_remaining": float(steps_remaining),
302
+ "unfound_issue_types": unfound_types,
303
+ "normalized_return": normalized_return,
304
+ "cumulative_reward": round(cumulative_reward, 4),
305
+ }
306
+
307
+ # 12-dim state feature vector for RL policy/value networks (code2vec/PBRS literature)
308
+ progress["state_features"] = compute_code_state_features(
309
+ self._code_metadata, progress=progress
310
+ )
311
+
312
+ return progress
313
+
314
+ def _compute_flagged_summary(self) -> Dict[str, Any]:
315
+ """Compute correct/incorrect/near_miss counts."""
316
+ flagged = self._state.flagged_issues
317
+ gt = self._ground_truth
318
+
319
+ correct = 0
320
+ near_misses = 0
321
+ incorrect = 0
322
+ matched_gt: Set[int] = set()
323
+
324
+ for flag in flagged:
325
+ matched = False
326
+ for i, g in enumerate(gt):
327
+ if i in matched_gt:
328
+ continue
329
+ if match_issue(flag, g):
330
+                    correct += 1
+                    matched_gt.add(i)
+                    matched = True
+                    break
+
+            if not matched:
+                is_near = False
+                for i, g in enumerate(gt):
+                    if i in matched_gt:
+                        continue
+                    if match_quality(flag, g) == "near":
+                        is_near = True
+                        break
+                if is_near:
+                    near_misses += 1
+                else:
+                    incorrect += 1
+
+        return {
+            "total_flagged": len(flagged),
+            "correct": correct,
+            "incorrect": incorrect,
+            "near_misses": near_misses,
+        }
+
+    # ------------------------------------------------------------------
+    # Action handlers
+    # ------------------------------------------------------------------
+
     def _process_action(self, action: ReviewAction):
         atype = (action.action_type or "").strip().lower()

         return 0.0, (
             f"Unknown action_type '{action.action_type}'. "
             "Use: flag_issue | clear_flag | request_hint | submit_review"
+        ), {}

     def _handle_flag(self, action: ReviewAction):
         if action.line_number is None:
+            return _VALIDATION_PENALTY, "flag_issue requires 'line_number'.", {"validation_penalty": _VALIDATION_PENALTY}
         if not action.filename:
+            return _VALIDATION_PENALTY, "flag_issue requires 'filename'.", {"validation_penalty": _VALIDATION_PENALTY}
         if action.issue_type not in ("bug", "security", "performance", "logic", None):
             action.issue_type = "bug"
         if action.severity not in ("low", "medium", "high", "critical", None):
             action.severity = "medium"

+        # Duplicate check
         for existing in self._state.flagged_issues:
             if (existing.line_number == action.line_number
                     and existing.filename == action.filename):
                 return 0.0, (
                     f"Line {action.line_number} in {action.filename} already flagged. "
+                    "Use clear_flag first to change it."
+                ), {"duplicate": 0.0}

         new_issue = Issue(
             line_number=action.line_number,

             fix_suggestion=action.fix_suggestion,
         )

+        # Classify: TP, near-miss (with line distance), or FP
+        is_tp = False
+        is_near = False
+        near_line_diff = 0
+        matched_gt_issue: Optional[Issue] = None
+        matched_gt_idx: Optional[int] = None
+
+        for i, gt in enumerate(self._ground_truth):
+            q = match_quality(new_issue, gt)
+            if q == "exact" and i not in self._matched_gt_indices:
+                is_tp = True
+                matched_gt_issue = gt
+                matched_gt_idx = i
+                break
+            elif q == "near" and not is_near:
+                is_near = True
+                near_line_diff = abs(new_issue.line_number - gt.line_number)

         self._state.flagged_issues.append(new_issue)

+        # PBRS: compute potential before and after this flag
+        tp_before = len(self._matched_gt_indices)
+        total_gt = len(self._ground_truth)
+
+        reward_breakdown: Dict[str, float] = {}
+
+        if is_tp and matched_gt_issue is not None and matched_gt_idx is not None:
+            self._matched_gt_indices.add(matched_gt_idx)
+            tp_after = len(self._matched_gt_indices)
+
+            base_reward = _BASE_TP_REWARD
+            reward_breakdown["base_tp"] = base_reward
+
+            # Severity exact match bonus
+            severity_bonus = 0.0
+            if new_issue.severity == matched_gt_issue.severity:
+                severity_bonus = _SEVERITY_EXACT_BONUS
+            reward_breakdown["severity_exact"] = severity_bonus
+
+            # Temporal bonus: TP caught in first 40% of max_steps
+            max_steps = self._task["max_steps"]
+            early_threshold = max(1, int(max_steps * 0.4))
+            temporal_bonus = 0.0
+            if self._state.step_count <= early_threshold:
+                temporal_bonus = _TEMPORAL_BONUS
+            reward_breakdown["temporal_bonus"] = temporal_bonus
+
+            # Confidence calibration: high confidence TP → small bonus
+            confidence_bonus = 0.0
+            if action.confidence is not None and action.confidence >= 0.7:
+                confidence_bonus = _CONFIDENCE_TP_BONUS
+            reward_breakdown["confidence_bonus"] = confidence_bonus
+
+            # PBRS: Φ(s') - Φ(s) (potential-based shaping, policy-invariant)
+            phi_before = compute_potential(tp_before, total_gt)
+            phi_after = compute_potential(tp_after, total_gt)
+            pbrs_bonus = round(phi_after - phi_before, 4)
+            reward_breakdown["pbrs_shaping"] = pbrs_bonus
+
+            reward = base_reward + severity_bonus + temporal_bonus + confidence_bonus + pbrs_bonus
+            reward_breakdown["total"] = round(reward, 4)
+
+            sev_note = f", severity +{severity_bonus:.2f}" if severity_bonus else ""
+            temp_note = f", early +{temporal_bonus:.2f}" if temporal_bonus else ""
+            conf_note = f", conf +{confidence_bonus:.2f}" if confidence_bonus else ""
+            pbrs_note = f", progress +{pbrs_bonus:.2f}" if pbrs_bonus > 0 else ""
             feedback = (
+                f"Correct! Issue at {action.filename}:{action.line_number} confirmed. "
+                f"[+{reward:.2f}{sev_note}{temp_note}{conf_note}{pbrs_note}]"
             )
+
+        elif is_near:
+            # Graduated near-miss: smooth exponential decay by line distance
+            near_reward = graduated_near_reward(near_line_diff)
+            reward_breakdown["near_miss"] = near_reward
+            reward_breakdown["line_diff"] = float(near_line_diff)
+            reward_breakdown["total"] = near_reward
+            feedback = (
+                f"Close! Near a real issue at {action.filename}:{action.line_number}. "
+                f"[+{near_reward:.3f} — {near_line_diff} lines off, adjust line number]"
+            )
+            reward = near_reward
+
         else:
+            # False positive — with flood protection
+            self._fp_count += 1
+
+            base_penalty = _BASE_FP_PENALTY
+            reward_breakdown["base_fp"] = base_penalty
+
+            # Escalating penalty after _FP_FLOOD_THRESHOLD FPs
+            flood_penalty = 0.0
+            if self._fp_count > _FP_FLOOD_THRESHOLD:
+                extra = self._fp_count - _FP_FLOOD_THRESHOLD
+                flood_penalty = round(-0.02 * extra * _FP_FLOOD_MULTIPLIER, 3)
+            reward_breakdown["flood_penalty"] = flood_penalty
+
+            # High-confidence FP: extra penalty
+            confidence_penalty = 0.0
+            if action.confidence is not None and action.confidence >= 0.7:
+                confidence_penalty = _CONFIDENCE_FP_EXTRA
+            reward_breakdown["confidence_penalty"] = confidence_penalty
+
+            reward = base_penalty + flood_penalty + confidence_penalty
+            reward_breakdown["total"] = round(reward, 4)
+
+            flood_note = f", over-flagging -{abs(flood_penalty):.2f}" if flood_penalty else ""
+            conf_note = f", high-confidence penalty {confidence_penalty:.2f}" if confidence_penalty else ""
             feedback = (
+                f"No match at {action.filename}:{action.line_number}. "
+                f"[{reward:.2f} — false positive{flood_note}{conf_note}]"
             )

+        return reward, feedback, reward_breakdown

     def _handle_clear(self, action: ReviewAction):
         if action.line_number is None or not action.filename:
+            return _VALIDATION_PENALTY, "clear_flag requires 'line_number' and 'filename'.", {"validation_penalty": _VALIDATION_PENALTY}

+        removed_issue = None
+        new_list = []
+        for f in self._state.flagged_issues:
+            if f.line_number == action.line_number and f.filename == action.filename:
+                removed_issue = f
+            else:
+                new_list.append(f)
+
+        if removed_issue is None:
             return 0.0, (
                 f"No flagged issue found at {action.filename}:{action.line_number}."
+            ), {"no_op": 0.0}

+        self._state.flagged_issues = new_list
+
+        # Check if removed issue was TP
         was_tp = any(match_issue(removed_issue, gt) for gt in self._ground_truth)

         if was_tp:
+            # Un-track it from matched set
+            for i, gt in enumerate(self._ground_truth):
+                if match_issue(removed_issue, gt):
+                    self._matched_gt_indices.discard(i)
+                    break
+            reward = _REMOVE_TP_PENALTY
+            reward_breakdown = {"removed_tp": reward, "total": reward}
             feedback = (
                 f"Removed a correct finding at {action.filename}:{action.line_number}. "
+                f"[{reward:.2f}]"
             )
         else:
+            # Removing a FP — decrement counter
+            self._fp_count = max(0, self._fp_count - 1)
+            reward = _REMOVE_FP_REWARD
+            reward_breakdown = {"removed_fp": reward, "total": reward}
             feedback = (
                 f"Removed a false positive at {action.filename}:{action.line_number}. "
+                f"[+{reward:.2f} — good correction]"
             )

+        return reward, feedback, reward_breakdown

     def _handle_hint(self):
         hints = self._task.get("hints", [])
+
+        adaptive_hint = self._get_adaptive_hint()
+        if adaptive_hint:
+            return _HINT_COST, f"Hint: {adaptive_hint} ({_HINT_COST} reward)", {"hint_cost": _HINT_COST}
+
         if self._hint_index >= len(hints):
+            return _HINT_COST, "No more hints available for this task.", {"hint_cost": _HINT_COST}

         hint = hints[self._hint_index]
         self._hint_index += 1
         remaining = len(hints) - self._hint_index
+        return _HINT_COST, f"Hint {self._hint_index}/{len(hints)}: {hint} ({remaining} hints left)", {"hint_cost": _HINT_COST}
+
+    def _get_adaptive_hint(self) -> Optional[str]:
+        """Generate a context-aware hint based on current episode state."""
+        flagged = self._state.flagged_issues
+        gt = self._ground_truth
+
+        if not gt:
+            return None
+
+        tp_count = len(self._matched_gt_indices)
+        fp_count = len(flagged) - tp_count - sum(
+            1 for f in flagged
+            if any(match_quality(f, g) == "near" for g in gt)
+        )
+
+        issue_categories = self._code_metadata.get("issue_categories", [])
+
+        # Many false positives: over-flagging
+        if fp_count > tp_count and fp_count >= 2:
+            return (
+                "You are over-flagging. Focus only on confident, concrete findings. "
+                "Consider using clear_flag to remove uncertain flags."
+            )
+
+        # No correct flags at all yet
+        if len(flagged) > 0 and tp_count == 0:
+            if issue_categories:
+                cats = ", ".join(sorted(set(issue_categories)))
+                return (
+                    f"Focus on [{cats}] issues. "
+                    "None of your current flags match real issues. Re-examine carefully."
+                )
+
+        # Found some but missed whole categories
+        if tp_count > 0 and issue_categories:
+            found_types: Set[str] = set()
+            for i in self._matched_gt_indices:
+                found_types.add(gt[i].issue_type)
+            missed = sorted(set(issue_categories) - found_types)
+            if missed:
+                missed_str = ", ".join(missed)
+                return (
+                    f"Good progress! You've found some issues but haven't flagged any "
+                    f"[{missed_str}] issues yet — look again for those specifically."
+                )
+
+        return None  # Fall through to static hints

     def _handle_submit(self):
         self._state.submitted = True
         final_score = grade_episode(self._state.flagged_issues, self._ground_truth)
         self._state.current_score = final_score

+        tp_count = len(self._matched_gt_indices)
         total_gt = len(self._ground_truth)
         total_flagged = len(self._state.flagged_issues)
+        fp_count = total_flagged - tp_count
+
+        # Breakdown for detailed feedback
+        detailed = grade_episode_detailed(self._state.flagged_issues, self._ground_truth)

         feedback = (
             f"Review submitted! Final score: {final_score:.3f}. "
+            f"Found {tp_count}/{total_gt} issues. "
+            f"Precision: {detailed['precision']:.2f}, Recall: {detailed['recall']:.2f}, "
+            f"F1: {detailed['f1']:.2f}. "
         )
+        if fp_count > 0:
+            feedback += f"{fp_count} false positive(s). "
+        if detailed["false_negatives"] > 0:
+            fn = detailed["false_negatives"]
+            feedback += f"{fn} issue(s) missed."
+
+        reward_breakdown = {
+            "final_f1": detailed["f1"],
+            "severity_accuracy": detailed["severity_accuracy"],
+            "final_score": final_score,
+            "total": final_score,
+        }
+        return final_score, feedback, reward_breakdown
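
The per-flag shaping above reduces to two small pure functions. As a sanity check of the numbers quoted in the feedback strings, here is a standalone sketch mirroring `compute_potential` and `graduated_near_reward` from `server/graders.py` in this commit (constants taken from the diff; not the shipped module itself):

```python
import math

EXACT_TOLERANCE = 2    # ±2 lines counts as an exact match
BASE_TP_REWARD = 0.10  # base reward for a true positive
NEAR_DECAY = 0.6       # exponential decay per line beyond the exact tolerance
POTENTIAL_SCALE = 0.5  # Φ(s) = (tp_found / total_gt) * POTENTIAL_SCALE

def compute_potential(tp_count: int, total_gt: int) -> float:
    """PBRS potential Φ(s); the caller rewards Φ(s') - Φ(s)."""
    if total_gt <= 0:
        return 0.0
    return (tp_count / total_gt) * POTENTIAL_SCALE

def graduated_near_reward(line_diff: int) -> float:
    """Smooth decay of the TP reward beyond the exact-match tolerance."""
    if line_diff <= EXACT_TOLERANCE:
        return BASE_TP_REWARD
    extra = line_diff - EXACT_TOLERANCE
    return round(BASE_TP_REWARD * math.exp(-NEAR_DECAY * extra), 4)

# PBRS delta for finding the 2nd of 4 issues: Φ(2/4) - Φ(1/4)
delta = compute_potential(2, 4) - compute_potential(1, 4)
print(round(delta, 4))           # 0.125
print(graduated_near_reward(3))  # 0.0549
```

Because the shaping term is a pure difference of potentials, the bonuses telescope: over an episode the total PBRS contribution depends only on the final recall, which is what keeps the shaping policy-invariant.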
server/graders.py CHANGED
@@ -1,10 +1,21 @@
1
  """
2
  Grading logic for the Code Review Environment.
 
 
 
 
 
 
 
 
 
3
  """
4
  from __future__ import annotations
5
 
 
 
6
  import re
7
- from typing import List, Tuple, Set
8
 
9
  import sys
10
  import os
@@ -21,8 +32,18 @@ _TYPE_COMPAT = {
21
  "performance": {"performance"},
22
  }
23
 
 
 
 
 
 
 
 
 
24
 
25
- def match_issue(flagged: Issue, gt: Issue, line_tolerance: int = 2) -> bool:
 
 
26
  if flagged.filename != gt.filename:
27
  return False
28
  if abs(flagged.line_number - gt.line_number) > line_tolerance:
@@ -33,6 +54,274 @@ def match_issue(flagged: Issue, gt: Issue, line_tolerance: int = 2) -> bool:
33
  return True
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def grade_episode(
37
  flagged: List[Issue],
38
  ground_truth: List[Issue],
@@ -79,6 +368,105 @@ def grade_episode(
79
  return round(min(1.0, max(0.0, final)), 4)
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def compute_live_score(flagged: List[Issue], ground_truth: List[Issue]) -> float:
83
  """F1-only score for per-step feedback (no severity bonus)."""
84
  if not ground_truth:
@@ -107,6 +495,7 @@ def compute_live_score(flagged: List[Issue], ground_truth: List[Issue]) -> float
107
 
108
 
109
  _PATTERNS = [
 
110
  (r"range\(len\(\w+\)\s*\+\s*1\)", None, "bug", "high",
111
  "Off-by-one error: range(len(x) + 1) iterates one past the end"),
112
  (r"left,\s*right\s*=\s*0,\s*len\(", None, "bug", "medium",
@@ -114,30 +503,81 @@ _PATTERNS = [
114
  (r"counts\[word\]\s*=\s*0\b", None, "bug", "low",
115
  "Counter initialized to 0 instead of 1"),
116
 
 
117
  (r'SECRET_KEY\s*=\s*["\']', None, "security", "high",
118
  "Hardcoded SECRET_KEY in source code"),
 
 
119
  (r'PASSWORD\s*=\s*["\']', None, "security", "high",
120
  "Hardcoded password in source code"),
 
 
121
  (r"f['\"].*SELECT.*\{", None, "security", "critical",
122
  "SQL injection via f-string query construction"),
 
 
123
  (r"f['\"].*DELETE.*\{", None, "security", "critical",
124
  "SQL injection via f-string DELETE query"),
 
 
 
 
125
  (r"render_template_string\(f['\"]", None, "security", "high",
126
  "XSS: unsanitized user input in render_template_string"),
127
  (r"shell\s*=\s*True", None, "security", "critical",
128
  "Command injection risk: shell=True with user input"),
129
- (r"hashlib\.md5\(", None, "security", "medium",
130
- "MD5 is cryptographically broken, use SHA-256 or HMAC-SHA256"),
 
 
 
 
 
 
 
 
131
  (r"expected\s*==\s*\w+_hash", None, "security", "medium",
132
  "Timing attack: use hmac.compare_digest() for constant-time comparison"),
 
 
 
 
 
 
 
 
133
  (r"password\s*=\s*models\.CharField", None, "security", "critical",
134
  "Plaintext password storage in database"),
135
- (r"os\.path\.join\(['\"]\/", None, "security", "high",
136
- "Path traversal: os.path.join with absolute prefix doesn't prevent traversal"),
137
 
 
 
 
 
 
 
 
 
 
 
 
138
  (r"\.objects\.get\(id=item\.", None, "performance", "high",
139
  "N+1 query: database lookup inside a loop"),
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  (r"FloatField\(\)", None, "bug", "medium",
142
  "FloatField for monetary values causes precision errors, use DecimalField"),
143
  (r"BinaryField\(\)", None, "security", "high",
 
1
  """
2
  Grading logic for the Code Review Environment.
3
+
4
+ Reward design is grounded in:
5
+ - Potential-Based Reward Shaping (PBRS): Ng et al. 1999
6
+ R_shaped(s,a,s') = R(s,a,s') + γ·Φ(s') - Φ(s)
7
+ where Φ(s) = (tp_found / total_gt) · POTENTIAL_SCALE
8
+ - Graduated line-proximity rewards: exponential decay over line distance
9
+ reward = BASE_TP · exp(-DECAY · max(0, line_diff - EXACT_TOLERANCE))
10
+ for 0 < line_diff ≤ NEAR_TOLERANCE
11
+ - F1-based terminal scoring: 0.70·F1 + 0.30·severity_accuracy
12
  """
13
  from __future__ import annotations
14
 
15
+ import ast
16
+ import math
17
  import re
18
+ from typing import List, Tuple, Set, Dict, Optional
19
 
20
  import sys
21
  import os
 
32
  "performance": {"performance"},
33
  }
34
 
35
+ # Tolerances
36
+ NEAR_TOLERANCE = 5
37
+ EXACT_TOLERANCE = 2
38
+
39
+ # Graduated reward constants (PBRS + smooth near-miss)
40
+ BASE_TP_REWARD = 0.10
41
+ NEAR_DECAY = 0.6 # exponential decay per line beyond EXACT_TOLERANCE
42
+ POTENTIAL_SCALE = 0.5 # Φ(s) = (tp/total_gt) * POTENTIAL_SCALE
43
 
44
+
45
+ def match_issue(flagged: Issue, gt: Issue, line_tolerance: int = EXACT_TOLERANCE, near_tolerance: int = NEAR_TOLERANCE) -> bool:
46
+ """Return True if flagged matches gt within line_tolerance lines and same type."""
47
  if flagged.filename != gt.filename:
48
  return False
49
  if abs(flagged.line_number - gt.line_number) > line_tolerance:
 
54
  return True
55
 
56
 
57
+ def match_quality(flagged: Issue, gt: Issue) -> str:
58
+ """
59
+ Return quality of match between flagged and gt:
60
+ "exact" — within ±2 lines and right issue type
61
+ "near" — within ±3-5 lines and same file (regardless of type)
62
+ "none" — no meaningful match
63
+ """
64
+ if flagged.filename != gt.filename:
65
+ return "none"
66
+
67
+ line_diff = abs(flagged.line_number - gt.line_number)
68
+
69
+ if line_diff <= EXACT_TOLERANCE:
70
+ compat = _TYPE_COMPAT.get(gt.issue_type, {gt.issue_type})
71
+ if flagged.issue_type in compat:
72
+ return "exact"
73
+
74
+ if line_diff <= NEAR_TOLERANCE:
75
+ return "near"
76
+
77
+ return "none"
78
+
79
+
80
+ def graduated_near_reward(line_diff: int) -> float:
81
+ """
82
+ Graduated reward for near-miss flags using exponential decay.
83
+
84
+ Implements continuous reward shaping based on proximity:
85
+ line_diff = 0-2 → 0.10 (full TP, handled separately)
86
+ line_diff = 3 → 0.10 * exp(-0.6*1) ≈ 0.055
87
+ line_diff = 4 → 0.10 * exp(-0.6*2) ≈ 0.033
88
+ line_diff = 5 → 0.10 * exp(-0.6*3) ≈ 0.020
89
+
90
+ This gives smooth gradient signal rather than a hard 0.03 step function,
91
+ encouraging the agent to refine line numbers progressively.
92
+ """
93
+ if line_diff <= EXACT_TOLERANCE:
94
+ return BASE_TP_REWARD
95
+ extra = line_diff - EXACT_TOLERANCE
96
+ return round(BASE_TP_REWARD * math.exp(-NEAR_DECAY * extra), 4)
97
+
98
+
99
+ def compute_potential(tp_count: int, total_gt: int) -> float:
100
+ """
101
+ Potential function Φ(s) for Potential-Based Reward Shaping (PBRS).
102
+
103
+ Φ(s) = (tp_found / total_gt) * POTENTIAL_SCALE
104
+
105
+ The shaped reward R_shaped = r + Φ(s') - Φ(s) ensures policy invariance
106
+ (Ng et al. 1999): the optimal policy under shaped rewards is the same as
107
+ under the original rewards, but with better intermediate gradient signal.
108
+
109
+ Here we compute just Φ(s); the caller computes Φ(s') - Φ(s).
110
+ """
111
+ if total_gt <= 0:
112
+ return 0.0
113
+ return (tp_count / total_gt) * POTENTIAL_SCALE
114
+
115
+
116
+ def compute_function_map(code: str) -> Dict[int, str]:
117
+ """
118
+ Map each line number to the name of its enclosing function (or class method).
119
+ Lines outside any function map to "module". Non-parseable code returns empty dict.
120
+ """
121
+ result: Dict[int, str] = {}
122
+ try:
123
+ tree = ast.parse(code)
124
+ for node in ast.walk(tree):
125
+ if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
126
+ end = getattr(node, "end_lineno", node.lineno)
127
+ for lineno in range(node.lineno, end + 1):
128
+ result[lineno] = node.name
129
+ except SyntaxError:
130
+ pass
131
+ return result
132
+
133
+
134
+ def compute_code_metadata(code_files: Dict[str, str], issue_categories: Optional[List[str]] = None) -> Dict:
135
+ """
136
+ Extract code structure metadata using Python's ast module.
137
+
138
+ Returns:
139
+ total_lines, num_functions, function_names, num_classes, class_names,
140
+ imports, complexity_estimate, issue_categories, function_ranges
141
+ """
142
+ total_lines = 0
143
+ num_functions = 0
144
+ function_names: List[str] = []
145
+ num_classes = 0
146
+ class_names: List[str] = []
147
+ imports: List[str] = []
148
+ branch_count = 0
149
+ function_ranges: List[Dict] = [] # [{name, file, start, end}]
150
+
151
+ for filename, code in code_files.items():
152
+ lines = code.splitlines()
153
+ total_lines += len(lines)
154
+ try:
155
+ tree = ast.parse(code)
156
+ for node in ast.walk(tree):
157
+ if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
158
+ num_functions += 1
159
+ function_names.append(node.name)
160
+ end = getattr(node, "end_lineno", node.lineno)
161
+ function_ranges.append({
162
+ "name": node.name,
163
+ "file": filename,
164
+ "start": node.lineno,
165
+ "end": end,
166
+ })
167
+ elif isinstance(node, ast.ClassDef):
168
+ num_classes += 1
169
+ class_names.append(node.name)
170
+ elif isinstance(node, ast.Import):
171
+ for alias in node.names:
172
+ imports.append(alias.name.split(".")[0])
173
+ elif isinstance(node, ast.ImportFrom):
174
+ if node.module:
175
+ imports.append(node.module.split(".")[0])
176
+ elif isinstance(node, (ast.If, ast.For, ast.While, ast.Try,
177
+ ast.ExceptHandler, ast.With)):
178
+ branch_count += 1
179
+ except SyntaxError:
180
+ # If ast can't parse (e.g. non-Python file), just count lines
181
+ pass
182
+
183
+ # Deduplicate imports
184
+ imports = list(dict.fromkeys(imports))
185
+
186
+ # Complexity estimate
187
+ if branch_count <= 5:
188
+ complexity_estimate = "low"
189
+ elif branch_count <= 15:
190
+ complexity_estimate = "medium"
191
+ else:
192
+ complexity_estimate = "high"
193
+
194
+ return {
195
+ "total_lines": total_lines,
196
+ "num_functions": num_functions,
197
+ "function_names": function_names,
198
+ "num_classes": num_classes,
199
+ "class_names": class_names,
200
+ "imports": imports,
201
+ "complexity_estimate": complexity_estimate,
202
+ "issue_categories": list(set(issue_categories)) if issue_categories else [],
203
+ "function_ranges": function_ranges,
204
+ }
205
+
206
+
207
+ def compute_code_state_features(
208
+ code_metadata: Dict,
209
+ progress: Optional[Dict] = None,
210
+ ) -> List[float]:
211
+ """
212
+ Compute a normalized 12-dimensional feature vector for RL training.
213
+
214
+ Based on state representation research (code2vec, GraphCodeBERT, 2023-2024),
215
+ combining AST-derived structural features with episode progress metrics.
216
+ This vector is suitable as input to a policy network or value estimator.
217
+
218
+ Dimensions:
219
+ 0: total_lines / 200 — code size (normalized)
220
+ 1: num_functions / 20 — function count
221
+ 2: num_classes / 10 — class count
222
+ 3: complexity_score — 0=low, 0.5=medium, 1.0=high
223
+ 4: has_bug_issues — 1 if "bug" in issue_categories
224
+ 5: has_security_issues — 1 if "security" in issue_categories
225
+ 6: has_performance_issues — 1 if "performance" in issue_categories
226
+ 7: has_logic_issues — 1 if "logic" in issue_categories
227
+ 8: progress_recall — tp / total_gt (0 if no progress yet)
228
+ 9: progress_precision — precision so far
229
+ 10: steps_used_frac — steps_used / max_steps
230
+ 11: fp_pressure — false_positives / max(total_flagged, 1)
231
+ """
232
+ if progress is None:
233
+ progress = {}
234
+
235
+ complexity_map = {"low": 0.0, "medium": 0.5, "high": 1.0}
236
+ cats = set(code_metadata.get("issue_categories", []))
237
+
238
+ total_gt = float(progress.get("total_ground_truth", 1.0)) or 1.0
239
+ tp = float(progress.get("true_positives", 0.0))
240
+ fp = float(progress.get("false_positives", 0.0))
241
+ total_flagged = tp + fp
242
+ steps_used = float(progress.get("steps_used", 0.0))
243
+ steps_rem = float(progress.get("steps_remaining", 1.0))
244
+ max_steps = steps_used + steps_rem or 1.0
245
+
246
+ features = [
247
+ min(1.0, code_metadata.get("total_lines", 0) / 200.0),
248
+ min(1.0, code_metadata.get("num_functions", 0) / 20.0),
249
+ min(1.0, code_metadata.get("num_classes", 0) / 10.0),
250
+ complexity_map.get(code_metadata.get("complexity_estimate", "low"), 0.0),
251
+ 1.0 if "bug" in cats else 0.0,
252
+ 1.0 if "security" in cats else 0.0,
253
+ 1.0 if "performance" in cats else 0.0,
254
+ 1.0 if "logic" in cats else 0.0,
255
+ min(1.0, tp / total_gt),
256
+ min(1.0, tp / total_flagged) if total_flagged > 0 else 0.0,
257
+ min(1.0, steps_used / max_steps),
258
+ min(1.0, fp / total_flagged) if total_flagged > 0 else 0.0,
259
+ ]
260
+ return [round(f, 4) for f in features]
261
+
262
+
263
+ class RewardNormalizer:
264
+ """
265
+ Variable-Length Return Normalizer for multi-task RL training.
266
+
267
+ Based on VL Norm (2025) and Return-based Scaling (2021):
268
+ Normalizes episode returns accounting for variable episode lengths,
269
+ preventing long episodes from dominating gradient computation.
270
+
271
+ Usage:
272
+ normalizer = RewardNormalizer(window_size=100)
273
+ # After each episode:
274
+ normalizer.update(episode_return, episode_length)
275
+ normalized_r = normalizer.normalize(episode_return, episode_length)
276
+ """
277
+
278
+ def __init__(self, window_size: int = 100, eps: float = 1e-8) -> None:
279
+ self.window_size = window_size
280
+ self.eps = eps
281
+ self._returns: List[float] = []
282
+ self._lengths: List[int] = []
283
+ self.mean: float = 0.0
284
+ self.std: float = 1.0
285
+
286
+ def update(self, episode_return: float, episode_length: int) -> None:
287
+ """Record a completed episode for running statistics."""
288
+ self._returns.append(episode_return)
289
+ self._lengths.append(max(1, episode_length))
290
+ if len(self._returns) > self.window_size:
291
+ self._returns.pop(0)
292
+ self._lengths.pop(0)
293
+ self._recompute()
294
+
295
+ def _recompute(self) -> None:
296
+ if len(self._returns) < 2:
297
+ return
298
+ returns = [r for r in self._returns]
299
+ lengths = [l for l in self._lengths]
300
+ mean_len = sum(lengths) / len(lengths)
301
+ # Length-adjusted std: longer episodes have proportionally less weight
302
+ self.mean = sum(returns) / len(returns)
303
+ raw_std = (sum((r - self.mean) ** 2 for r in returns) / len(returns)) ** 0.5
304
+ length_factors = [(l / mean_len) ** 0.5 for l in lengths]
305
+ avg_lf = sum(length_factors) / len(length_factors)
306
+ self.std = max(self.eps, raw_std * avg_lf)
307
+
308
+ def normalize(self, episode_return: float, episode_length: int) -> float:
309
+ """Return the length-adjusted normalized return."""
310
+ if len(self._returns) < 2:
311
+ return episode_return
312
+ mean_len = sum(self._lengths) / len(self._lengths)
313
+ length_factor = (max(1, episode_length) / mean_len) ** 0.5
314
+ return round((episode_return - self.mean) / (self.std * length_factor + self.eps), 4)
315
+
316
+ def to_dict(self) -> Dict:
317
+ return {
318
+ "mean": round(self.mean, 4),
319
+ "std": round(self.std, 4),
320
+ "n_episodes": len(self._returns),
321
+ "window_size": self.window_size,
322
+ }
323
+
324
+
325
  def grade_episode(
326
  flagged: List[Issue],
327
  ground_truth: List[Issue],
 
368
  return round(min(1.0, max(0.0, final)), 4)
369
 
370
 
371
+ def grade_episode_detailed(
372
+ flagged: List[Issue],
373
+ ground_truth: List[Issue],
374
+ line_tolerance: int = 2,
375
+ ) -> Dict:
376
+ """
377
+ Full breakdown of grading results.
378
+
379
+ Returns:
380
+ score, f1, precision, recall, severity_accuracy,
381
+ true_positives, false_positives, false_negatives,
382
+ near_misses, per_file
383
+ """
384
+ if not ground_truth:
385
+ score = 1.0 if not flagged else 0.0
386
+ return {
387
+ "score": score,
388
+ "f1": score,
389
+ "precision": score,
390
+ "recall": score,
391
+ "severity_accuracy": score,
392
+ "true_positives": 0,
393
+ "false_positives": len(flagged),
394
+ "false_negatives": 0,
395
+ "near_misses": 0,
396
+ "per_file": {},
397
+ }
398
+
399
+ tp = 0
400
+ fp = 0
401
+ near_misses = 0
402
+ matched_gt_indices: Set[int] = set()
403
+ severity_scores: List[float] = []
404
+ per_file: Dict[str, Dict] = {}
405
+
406
+ for flag in flagged:
407
+ fname = flag.filename
408
+ if fname not in per_file:
409
+ per_file[fname] = {"tp": 0, "fp": 0, "near_miss": 0}
410
+
411
+ matched = False
412
+ for i, gt in enumerate(ground_truth):
413
+ if i in matched_gt_indices:
414
+ continue
415
+ if match_issue(flag, gt, line_tolerance):
416
+ tp += 1
417
+ matched_gt_indices.add(i)
418
+ matched = True
419
+ per_file[fname]["tp"] += 1
420
+ flag_rank = _SEV_RANK.get(flag.severity, 1)
421
+ gt_rank = _SEV_RANK.get(gt.severity, 1)
422
+ distance = abs(flag_rank - gt_rank)
423
+ severity_scores.append(max(0.0, 1.0 - distance * 0.34))
424
+ break
425
+
426
+ if not matched:
427
+ # Check for near miss (3-5 lines off, same file)
428
+ is_near = False
429
+ for i, gt in enumerate(ground_truth):
430
+ if i in matched_gt_indices:
431
+ continue
432
+ q = match_quality(flag, gt)
433
+ if q == "near":
434
+ is_near = True
435
+ break
436
+ if is_near:
437
+ near_misses += 1
438
+ per_file[fname]["near_miss"] += 1
439
+ else:
440
+ fp += 1
441
+ per_file[fname]["fp"] += 1
442
+
443
+ fn = len(ground_truth) - len(matched_gt_indices)
444
+
445
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
446
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
447
+ f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
448
+
449
+ if severity_scores:
450
+ severity_accuracy = sum(severity_scores) / len(ground_truth)
451
+ else:
452
+ severity_accuracy = 0.0
453
+
454
+ score = round(min(1.0, max(0.0, 0.70 * f1 + 0.30 * severity_accuracy)), 4)
455
+
456
+ return {
457
+ "score": score,
458
+ "f1": round(f1, 4),
459
+ "precision": round(precision, 4),
460
+ "recall": round(recall, 4),
461
+ "severity_accuracy": round(severity_accuracy, 4),
462
+ "true_positives": tp,
463
+ "false_positives": fp,
464
+ "false_negatives": fn,
465
+ "near_misses": near_misses,
466
+ "per_file": per_file,
467
+ }
468
+
469
+
470
  def compute_live_score(flagged: List[Issue], ground_truth: List[Issue]) -> float:
471
  """F1-only score for per-step feedback (no severity bonus)."""
472
  if not ground_truth:
 
495
 
496
 
497
  _PATTERNS = [
498
+ # --- Bug patterns ---
499
  (r"range\(len\(\w+\)\s*\+\s*1\)", None, "bug", "high",
500
  "Off-by-one error: range(len(x) + 1) iterates one past the end"),
501
  (r"left,\s*right\s*=\s*0,\s*len\(", None, "bug", "medium",
 
503
  (r"counts\[word\]\s*=\s*0\b", None, "bug", "low",
504
  "Counter initialized to 0 instead of 1"),
505
 
506
+ # --- Hardcoded secrets ---
507
  (r'SECRET_KEY\s*=\s*["\']', None, "security", "high",
508
  "Hardcoded SECRET_KEY in source code"),
509
+ (r'ADMIN_TOKEN\s*=\s*["\']', None, "security", "high",
510
+ "Hardcoded ADMIN_TOKEN in source code"),
511
  (r'PASSWORD\s*=\s*["\']', None, "security", "high",
512
  "Hardcoded password in source code"),
513
+
514
+ # --- Injection attacks ---
515
  (r"f['\"].*SELECT.*\{", None, "security", "critical",
516
  "SQL injection via f-string query construction"),
517
+ (r"f['\"].*INSERT.*\{", None, "security", "critical",
518
+ "SQL injection via f-string INSERT query"),
519
  (r"f['\"].*DELETE.*\{", None, "security", "critical",
520
  "SQL injection via f-string DELETE query"),
521
+ (r"f['\"].*LIKE.*%\{", None, "security", "critical",
522
+ "SQL injection via f-string LIKE clause"),
523
+ (r"LIMIT\s*\{", None, "security", "critical",
524
+ "SQL injection: LIMIT clause uses unparameterized variable"),
525
  (r"render_template_string\(f['\"]", None, "security", "high",
526
  "XSS: unsanitized user input in render_template_string"),
527
  (r"shell\s*=\s*True", None, "security", "critical",
528
  "Command injection risk: shell=True with user input"),
529
+ (r"os\.system\(", None, "security", "critical",
530
+ "Command injection risk: os.system() executes shell commands"),
531
+ (r"os\.path\.join\(['\"]\/", None, "security", "high",
532
+ "Path traversal: os.path.join with absolute prefix doesn't prevent traversal"),
533
+
534
+ # --- Broken cryptography ---
535
+ (r"hashlib\.md5\(", None, "security", "high",
536
+ "MD5 is cryptographically broken for security use; use SHA-256 or bcrypt"),
537
+ (r"hashlib\.sha1\(", None, "security", "medium",
538
+ "SHA-1 is deprecated for security use; use SHA-256 or better"),
539
  (r"expected\s*==\s*\w+_hash", None, "security", "medium",
540
  "Timing attack: use hmac.compare_digest() for constant-time comparison"),
541
+
542
+ # --- Dangerous deserialization ---
543
+ (r"pickle\.loads\(", None, "security", "critical",
544
+ "Unsafe deserialization: pickle.loads() on untrusted data allows remote code execution"),
545
+ (r"yaml\.load\(", None, "security", "high",
546
+ "Unsafe YAML deserialization: use yaml.safe_load() instead"),
547
+
548
+ # --- Auth / access control ---
549
  (r"password\s*=\s*models\.CharField", None, "security", "critical",
550
  "Plaintext password storage in database"),
 
 
551
 
552
+ # --- Async / concurrency bugs ---
553
+ (r"aiohttp\.ClientSession\(\)", None, "bug", "high",
554
+ "ClientSession created outside 'async with' — may not be closed (resource leak)"),
555
+ (r"timeout\s*=\s*\d+\b", None, "bug", "medium",
556
+ "aiohttp timeout should be aiohttp.ClientTimeout(total=N), not a bare integer"),
557
+ (r"attempt\s*==\s*retries\b", None, "bug", "high",
558
+ "Off-by-one: range(retries) yields 0..retries-1, so attempt==retries is never true"),
559
+ (r"for\s+\w+\s+in\s+\w+_ids\s*:", None, "performance", "high",
560
+ "Sequential loop over IDs — consider asyncio.gather() for concurrent fetching"),
561
+
562
+ # --- Performance ---
563
  (r"\.objects\.get\(id=item\.", None, "performance", "high",
564
  "N+1 query: database lookup inside a loop"),
565
 
566
+ # --- JavaScript-specific patterns ---
567
+ (r"new\s+Function\(", None, "security", "critical",
568
+ "Unsafe dynamic code execution: new Function() with user input is equivalent to eval()"),
569
+ (r"\beval\(", None, "security", "critical",
570
+ "eval() with user-supplied input allows arbitrary code execution"),
571
+ (r"execSync\(", None, "security", "critical",
572
+ "Command injection risk: execSync() with user-supplied data"),
573
+ (r"jwt\.sign\(.*\{(?!.*expiresIn)", None, "security", "medium",
574
+ "JWT issued without expiry (expiresIn) — tokens are valid forever"),
575
+ (r"JWT_SECRET\s*=\s*['\"]", None, "security", "high",
576
+ "Hardcoded JWT secret in source code"),
577
+ (r"res\.send\(`.*\$\{", None, "security", "high",
578
+ "XSS: template literal with user input sent directly in response"),
579
+
580
+ # --- Data model bugs ---
581
  (r"FloatField\(\)", None, "bug", "medium",
582
  "FloatField for monetary values causes precision errors, use DecimalField"),
583
  (r"BinaryField\(\)", None, "security", "high",
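The grader additions above are plain `re` patterns applied line-by-line to the submitted code. A minimal sketch of how such a keyword baseline might apply them (the pattern tuples are copied from the table above; the `scan` helper is illustrative, not the actual `graders.py` API):

```python
import re

# Two patterns copied from the grader table above; severity labels match theirs.
PATTERNS = [
    (re.compile(r"pickle\.loads\("), "security", "critical"),
    (re.compile(r"attempt\s*==\s*retries\b"), "bug", "high"),
]

def scan(source: str):
    """Return (line_number, issue_type, severity) for every pattern hit."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        for pat, issue_type, severity in PATTERNS:
            if pat.search(line):
                hits.append((lineno, issue_type, severity))
    return hits

code = "data = pickle.loads(payload)\nif attempt == retries:\n    raise\n"
print(scan(code))  # → [(1, 'security', 'critical'), (2, 'bug', 'high')]
```

Line-by-line scanning is what makes the flagged line numbers directly comparable to the ground-truth issue positions.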
tasks/data.py CHANGED
@@ -418,10 +418,533 @@ TASK_COMPREHENSIVE: Dict[str, Any] = {
418
  }
419
 
420
 
421
+ _ASYNC_CODE = """\
422
+ import asyncio
423
+ import aiohttp
424
+ from typing import List, Optional
425
+
426
+ _cache: dict = {}
427
+
428
+
429
+ async def fetch_json(url: str, session: aiohttp.ClientSession) -> dict:
430
+ async with session.get(url, timeout=5) as resp:
431
+ return await resp.json()
432
+
433
+
434
+ async def get_user(user_id: int, session: aiohttp.ClientSession) -> dict:
435
+ if user_id in _cache:
436
+ return _cache[user_id]
437
+ data = await fetch_json(f"https://api.example.com/users/{user_id}", session)
438
+ _cache[user_id] = data
439
+ return data
440
+
441
+
442
+ async def process_users(user_ids: List[int]) -> List[dict]:
443
+ session = aiohttp.ClientSession()
444
+ results = []
445
+ for uid in user_ids:
446
+ result = await get_user(uid, session)
447
+ results.append(result)
448
+ return results
449
+
450
+
451
+ async def run_with_retry(url: str, retries: int = 3) -> Optional[str]:
452
+ for attempt in range(retries):
453
+ try:
454
+ async with aiohttp.ClientSession() as session:
455
+ async with session.get(url) as resp:
456
+ return await resp.text()
457
+ except Exception:
458
+ if attempt == retries:
459
+ raise
460
+ return None
461
+
462
+
463
+ class TaskRunner:
464
+ def __init__(self, concurrency: int = 5):
465
+ self.concurrency = concurrency
466
+ self.results = []
467
+
468
+ async def run_all(self, tasks: List) -> List:
469
+ for task in tasks:
470
+ result = await task
471
+ self.results.append(result)
472
+ return self.results
473
+ """
474
+
475
+ TASK_ASYNC_REVIEW: Dict[str, Any] = {
476
+ "task_id": "async-review",
477
+ "difficulty": "medium-hard",
478
+ "description": (
479
+ "Review this async Python module for concurrency bugs, resource leaks,\n"
480
+ "and performance issues with asyncio and aiohttp.\n"
481
+ "The code has subtle async-specific bugs that would cause failures or\n"
482
+ "degraded performance in production. Identify all issues with exact\n"
483
+ "line numbers, types, and severity.\n\n"
484
+ "File to review: async.py"
485
+ ),
486
+ "language": "python",
487
+ "code_files": {
488
+ "async.py": _ASYNC_CODE,
489
+ },
490
+ "ground_truth_issues": [
491
+ _issue(
492
+ 5, "async.py", "bug", "high",
493
+ "Shared mutable dict without asyncio.Lock; concurrent coroutines can read "
494
+ "stale data or overwrite each other's writes. Use async with _lock: around "
495
+ "cache check and write.",
496
+ "Add _lock = asyncio.Lock() and use: async with _lock: around cache check and write."
497
+ ),
498
+ _issue(
499
+ 9, "async.py", "bug", "medium",
500
+ "timeout=5 is wrong type for aiohttp; requires aiohttp.ClientTimeout(total=5). "
501
+ "Passing an int raises TypeError at runtime.",
502
+ "Use: timeout=aiohttp.ClientTimeout(total=5)"
503
+ ),
504
+ _issue(
505
+ 22, "async.py", "bug", "high",
506
+ "ClientSession created but never closed, causing resource leak. "
507
+ "Use: async with aiohttp.ClientSession() as session: and pass it in.",
508
+ "Replace with: async with aiohttp.ClientSession() as session:"
509
+ ),
510
+ _issue(
511
+ 24, "async.py", "performance", "high",
512
+ "Sequential for loop with await serializes all requests. "
513
+ "Use asyncio.gather(*[get_user(uid, session) for uid in user_ids]) "
514
+ "for true concurrency.",
515
+ "Replace loop with: results = await asyncio.gather(*[get_user(uid, session) for uid in user_ids])"
516
+ ),
517
+ _issue(
518
+ 37, "async.py", "bug", "high",
519
+ "Off-by-one: range(retries) yields 0..retries-1, so attempt==retries is never true. "
520
+ "Exception is never re-raised. Fix: attempt == retries - 1.",
521
+ "Change: if attempt == retries - 1: raise"
522
+ ),
523
+ _issue(
524
+ 48, "async.py", "performance", "medium",
525
+ "Tasks awaited sequentially instead of concurrently. "
526
+ "Use asyncio.gather(*tasks). Also self.results accumulates across multiple run_all calls.",
527
+ "Replace loop with: self.results.extend(await asyncio.gather(*tasks))"
528
+ ),
529
+ ],
530
+ "max_steps": 20,
531
+ "hints": [
532
+ "Check all places where ClientSession is created — are they properly closed?",
533
+ "Look for sequential awaits inside loops where gather() would be more appropriate.",
534
+ "The retry function has an off-by-one error in its condition.",
535
+ ],
536
+ }
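The ground-truth fixes for `process_users` ask for a single shared session plus `asyncio.gather`. A runnable sketch of that corrected shape, with `asyncio.sleep` standing in for the aiohttp fetch so the example needs no network (in the real fix, one `async with aiohttp.ClientSession()` would wrap the gather):

```python
import asyncio
from typing import List

async def get_user(user_id: int) -> dict:
    # Stand-in for the aiohttp fetch; sleep simulates network latency.
    await asyncio.sleep(0.01)
    return {"id": user_id}

async def process_users(user_ids: List[int]) -> List[dict]:
    # Concurrent fan-out instead of a sequential `for ... await` loop.
    return await asyncio.gather(*[get_user(uid) for uid in user_ids])

results = asyncio.run(process_users([1, 2, 3]))
print(results)  # → [{'id': 1}, {'id': 2}, {'id': 3}]
```

`gather` preserves input order, so the results list lines up with `user_ids` exactly as the sequential loop did.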
537
+
538
+
539
+ _PIPELINE_CODE = """\
540
+ import csv
541
+ import json
542
+ import hashlib
543
+ import sqlite3
544
+ from typing import List, Dict, Optional
545
+
546
+
547
+ def init_db(path: str) -> sqlite3.Connection:
548
+ conn = sqlite3.connect(path)
549
+ conn.execute(
550
+ "CREATE TABLE IF NOT EXISTS records "
551
+ "(id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL, "
552
+ "email TEXT NOT NULL, password_hash TEXT, score REAL DEFAULT 0)"
553
+ )
554
+ conn.commit()
555
+ return conn
556
+
557
+
558
+ def hash_password(password: str) -> str:
559
+ return hashlib.md5(password.encode()).hexdigest()
560
+
561
+
562
+ def insert_record(conn: sqlite3.Connection, username: str,
563
+ email: str, password: str, score: float) -> None:
564
+ pwd = hash_password(password)
565
+ conn.execute(
566
+ f"INSERT INTO records (username, email, password_hash, score) "
567
+ f"VALUES ('{username}', '{email}', '{pwd}', {score})"
568
+ )
569
+ conn.commit()
570
+
571
+
572
+ def search_records(conn: sqlite3.Connection, query: str) -> List[Dict]:
573
+ cursor = conn.execute(
574
+ f"SELECT id, username, email, score FROM records WHERE username LIKE '%{query}%'"
575
+ )
576
+ cols = [d[0] for d in cursor.description]
577
+ return [dict(zip(cols, row)) for row in cursor.fetchall()]
578
+
579
+
580
+ def bulk_load(conn: sqlite3.Connection, filepath: str) -> int:
581
+ count = 0
582
+ with open(filepath, newline='') as f:
583
+ for row in csv.DictReader(f):
584
+ insert_record(conn, row['username'], row['email'],
585
+ row.get('password', ''), float(row.get('score', 0)))
586
+ count += 1
587
+ return count
588
+
589
+
590
+ def export_records(conn: sqlite3.Connection, out_path: str) -> None:
591
+ rows = search_records(conn, '')
592
+ with open(out_path, 'w') as f:
593
+ json.dump(rows, f, indent=2)
594
+
595
+
596
+ def get_top_scores(conn: sqlite3.Connection, limit: int) -> List[Dict]:
597
+ cursor = conn.execute(
598
+ f"SELECT username, score FROM records ORDER BY score DESC LIMIT {limit}"
599
+ )
600
+ return [{'username': r[0], 'score': r[1]} for r in cursor.fetchall()]
601
+ """
602
+
603
+ TASK_DATA_PIPELINE: Dict[str, Any] = {
604
+ "task_id": "data-pipeline",
605
+ "difficulty": "hard",
606
+ "description": (
607
+ "Perform a security and correctness review of this data pipeline module.\n"
608
+ "The module handles user records in SQLite. It contains multiple critical\n"
609
+ "security vulnerabilities, a performance issue, and an error handling gap.\n"
610
+ "Find ALL issues across the file.\n\n"
611
+ "File to review: pipeline.py"
612
+ ),
613
+ "language": "python",
614
+ "code_files": {
615
+ "pipeline.py": _PIPELINE_CODE,
616
+ },
617
+ "ground_truth_issues": [
618
+ _issue(
619
+ 20, "pipeline.py", "security", "high",
620
+ "MD5 is cryptographically broken for password hashing. "
621
+ "Use bcrypt, argon2, or hashlib.pbkdf2_hmac instead.",
622
+ "Use: hashlib.pbkdf2_hmac('sha256', password.encode(), salt, 100000)"
623
+ ),
624
+ _issue(
625
+ 27, "pipeline.py", "security", "critical",
626
+ "SQL injection: username, email, and pwd interpolated directly into query string. "
627
+ "Use parameterized queries: conn.execute('INSERT INTO records ... VALUES (?,?,?,?)', "
628
+ "(username, email, pwd, score))",
629
+ "Use: conn.execute('INSERT INTO records (username, email, password_hash, score) VALUES (?,?,?,?)', (username, email, pwd, score))"
630
+ ),
631
+ _issue(
632
+ 35, "pipeline.py", "security", "critical",
633
+ "SQL injection in LIKE clause: user-supplied query interpolated directly. "
634
+ "Use: conn.execute('... WHERE username LIKE ?', (f'%{query}%',))",
635
+ "Use: conn.execute('SELECT ... WHERE username LIKE ?', (f'%{query}%',))"
636
+ ),
637
+ _issue(
638
+ 41, "pipeline.py", "performance", "high",
639
+ "bulk_load commits one transaction per row via insert_record. "
640
+ "Wrap entire loop in with conn: for a single transaction — 10-100x faster for large imports.",
641
+ "Wrap loop body with: with conn: conn.executemany(...)"
642
+ ),
643
+ _issue(
644
+ 46, "pipeline.py", "bug", "medium",
645
+ "float() conversion has no error handling. A single malformed score field "
646
+ "crashes the entire import. Wrap in try/except ValueError.",
647
+ "Use: float(row.get('score', 0) or 0) inside try/except ValueError"
648
+ ),
649
+ _issue(
650
+ 52, "pipeline.py", "security", "high",
651
+ "export_records calls search_records(conn, '') which returns all records including "
652
+ "password_hash field. Strip sensitive fields before export.",
653
+ "Filter out password_hash: rows = [{k: v for k, v in r.items() if k != 'password_hash'} for r in rows]"
654
+ ),
655
+ _issue(
656
+ 59, "pipeline.py", "security", "critical",
657
+ "SQL injection: limit value interpolated into query. Although limit is an int here, "
658
+ "use parameterized query: conn.execute('... LIMIT ?', (limit,))",
659
+ "Use: conn.execute('SELECT username, score FROM records ORDER BY score DESC LIMIT ?', (limit,))"
660
+ ),
661
+ ],
662
+ "max_steps": 25,
663
+ "hints": [
664
+ "Look for every place user-supplied values touch a SQL query string — are they parameterized?",
665
+ "The bulk_load function has both a performance issue and an error handling gap.",
666
+ "Check what fields export_records includes in its output — are any sensitive?",
667
+ ],
668
+ }
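The pipeline task's SQL-injection fixes all reduce to one pattern: bind values with `?` placeholders instead of interpolating them. A sketch against an in-memory SQLite database, reusing the task's `records` schema:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE records (id INTEGER PRIMARY KEY AUTOINCREMENT, "
    "username TEXT NOT NULL, email TEXT NOT NULL, "
    "password_hash TEXT, score REAL DEFAULT 0)"
)

# Parameterized INSERT: values are bound by the driver, never spliced into SQL.
conn.execute(
    "INSERT INTO records (username, email, password_hash, score) VALUES (?, ?, ?, ?)",
    ("alice", "a@example.com", "hash", 9.5),
)

# Parameterized LIKE: the wildcard pattern is built in Python, bound as one value.
query = "' OR '1'='1"  # an injection attempt is treated as a literal string
rows = conn.execute(
    "SELECT username FROM records WHERE username LIKE ?", (f"%{query}%",)
).fetchall()
print(rows)  # → [] (the payload matches nothing)
```

Because the whole pattern arrives as a single bound value, the quote characters lose their SQL meaning and the classic `' OR '1'='1` bypass returns no rows.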
669
+
670
+
671
+ _API_SECURITY_CODE = """\
672
+ from fastapi import FastAPI, Depends, HTTPException, Header
673
+ from fastapi.security import HTTPBasic, HTTPBasicCredentials
674
+ import jwt
675
+ import hashlib
676
+ import pickle
677
+ import os
678
+ import sqlite3
679
+
680
+ app = FastAPI()
681
+ security = HTTPBasic()
682
+
683
+ SECRET_KEY = "dev-secret-do-not-use-in-prod"
684
+ ADMIN_TOKEN = "admin-hardcoded-token-123"
685
+
686
+ users_db = {
687
+ "admin": hashlib.md5(b"password123").hexdigest(),
688
+ "user": hashlib.md5(b"user123").hexdigest(),
689
+ }
690
+
691
+
692
+ @app.post("/login")
693
+ def login(credentials: HTTPBasicCredentials = Depends(security)):
694
+ username = credentials.username
695
+ stored = users_db.get(username, "")
696
+ if stored != hashlib.md5(credentials.password.encode()).hexdigest():
697
+ raise HTTPException(status_code=401, detail="Invalid credentials")
698
+ token = jwt.encode({"user": username, "admin": username == "admin"},
699
+ SECRET_KEY, algorithm="HS256")
700
+ return {"token": token}
701
+
702
+
703
+ @app.get("/users/{user_id}")
704
+ def get_user(user_id: str, authorization: str = Header(None)):
705
+ if not authorization:
706
+ raise HTTPException(status_code=401, detail="Missing token")
707
+ payload = jwt.decode(authorization, SECRET_KEY, algorithms=["HS256"])
708
+ conn = sqlite3.connect("app.db")
709
+ cursor = conn.execute(f"SELECT * FROM users WHERE id = '{user_id}'")
710
+ return {"user": cursor.fetchone()}
711
+
712
+
713
+ @app.post("/admin/export")
714
+ def admin_export(authorization: str = Header(None)):
715
+ if authorization != ADMIN_TOKEN:
716
+ raise HTTPException(status_code=403, detail="Forbidden")
717
+ path = os.environ.get("EXPORT_PATH", "/tmp/export")
718
+ os.system(f"mysqldump mydb > {path}/dump.sql")
719
+ return {"status": "export complete", "path": path}
720
+
721
+
722
+ @app.post("/import")
723
+ def import_data(payload: bytes):
724
+ data = pickle.loads(payload)
725
+ return {"records": len(data)}
726
+
727
+
728
+ @app.get("/search")
729
+ def search_users(q: str, limit: int = 100):
730
+ conn = sqlite3.connect("app.db")
731
+ rows = conn.execute(
732
+ f"SELECT id, name, email FROM users WHERE name LIKE '%{q}%' LIMIT {limit}"
733
+ ).fetchall()
734
+ return {"results": rows}
735
+ """
736
+
737
+ TASK_API_SECURITY: Dict[str, Any] = {
738
+ "task_id": "api-security",
739
+ "difficulty": "hard",
740
+ "description": (
741
+ "Perform a security audit on this FastAPI REST API.\n"
742
+ "The service handles user authentication and data operations.\n"
743
+ "It contains multiple critical security flaws across authentication,\n"
744
+ "authorization, injection attacks, and cryptography.\n"
745
+ "Find ALL issues with exact line numbers and severity ratings.\n\n"
746
+ "File to review: api.py"
747
+ ),
748
+ "language": "python",
749
+ "code_files": {
750
+ "api.py": _API_SECURITY_CODE,
751
+ },
752
+ "ground_truth_issues": [
753
+ _issue(
754
+ 12, "api.py", "security", "high",
755
+ "Hardcoded SECRET_KEY in source code. Any developer with repo access can forge "
756
+ "JWT tokens and impersonate any user.",
757
+ "Use: SECRET_KEY = os.environ.get('SECRET_KEY') and rotate it as a secret."
758
+ ),
759
+ _issue(
760
+ 13, "api.py", "security", "high",
761
+ "Hardcoded ADMIN_TOKEN in source code. Static tokens in code are trivially "
762
+ "leaked via version control, logs, or error messages.",
763
+ "Use: ADMIN_TOKEN = os.environ.get('ADMIN_TOKEN') and generate it securely."
764
+ ),
765
+ _issue(
766
+ 16, "api.py", "security", "high",
767
+ "MD5 used for password hashing. MD5 is cryptographically broken; precomputed "
768
+ "rainbow tables can reverse any MD5 hash in seconds.",
769
+ "Use bcrypt, argon2, or hashlib.pbkdf2_hmac with a random salt."
770
+ ),
771
+ _issue(
772
+ 27, "api.py", "security", "medium",
773
+ "JWT token issued without an expiry claim ('exp'). Tokens are valid forever; "
774
+ "a stolen token can never be invalidated without rotating the secret.",
775
+ "Add: {'exp': datetime.utcnow() + timedelta(hours=1)} to the JWT payload."
776
+ ),
777
+ _issue(
778
+ 33, "api.py", "security", "critical",
779
+ "Missing authorization check: any authenticated user can fetch any user_id. "
780
+ "This is an Insecure Direct Object Reference (IDOR) — user A can read user B's data.",
781
+ "Check: if payload.get('user') != user_id and not payload.get('admin'): raise 403."
782
+ ),
783
+ _issue(
784
+ 38, "api.py", "security", "critical",
785
+ "SQL injection: user_id is interpolated directly into the query string. "
786
+ "An attacker can supply user_id = \"' OR '1'='1\" to dump the users table.",
787
+ "Use parameterized query: conn.execute('SELECT * FROM users WHERE id = ?', (user_id,))"
788
+ ),
789
+ _issue(
790
+ 47, "api.py", "security", "critical",
791
+ "Command injection: EXPORT_PATH from environment is interpolated into an "
792
+ "os.system() shell command. A misconfigured env var like '/tmp; rm -rf /' "
793
+ "executes arbitrary commands as the server process.",
794
+ "Use subprocess.run(['mysqldump', 'mydb'], stdout=open(path, 'w'), shell=False)."
795
+ ),
796
+ _issue(
797
+ 53, "api.py", "security", "critical",
798
+ "Unsafe deserialization: pickle.loads() on untrusted user-supplied bytes allows "
799
+ "remote code execution. Any client can craft a pickle payload that runs arbitrary code.",
800
+ "Use json.loads() or a schema-validated format. Never unpickle untrusted data."
801
+ ),
802
+ ],
803
+ "max_steps": 25,
804
+ "hints": [
805
+ "Check every hardcoded string assigned to variables like SECRET_KEY, TOKEN, PASSWORD.",
806
+ "Look at every endpoint: which ones verify the caller's identity vs just authentication?",
807
+ "Find all places user-supplied data touches: SQL queries, shell commands, deserialization.",
808
+ ],
809
+ }
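Several ground-truth fixes above point at `hashlib.pbkdf2_hmac` and constant-time comparison. A self-contained sketch of salted hashing plus verification (the iteration count and the `(salt, digest)` layout are illustrative choices, not part of the task code):

```python
import hashlib
import hmac
import os

def hash_password(password: str):
    """Return (salt, digest) using PBKDF2-HMAC-SHA256 with a random salt."""
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, 100_000)
    return salt, digest

def verify_password(password: str, salt: bytes, digest: bytes) -> bool:
    candidate = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, 100_000)
    # Constant-time compare avoids the timing side channel the graders flag.
    return hmac.compare_digest(candidate, digest)

salt, digest = hash_password("password123")
print(verify_password("password123", salt, digest))  # → True
print(verify_password("wrong", salt, digest))        # → False
```

Unlike the MD5 lookup in `users_db`, the random per-user salt defeats precomputed rainbow tables, and `compare_digest` keeps comparison time independent of where the digests first differ.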
810
+
811
+
812
+ _JS_CODE = """\
813
+ const express = require('express');
814
+ const jwt = require('jsonwebtoken');
815
+ const { execSync } = require('child_process');
816
+ const path = require('path');
817
+ const fs = require('fs');
818
+ const sqlite3 = require('better-sqlite3');
819
+
820
+ const app = express();
821
+ app.use(express.json());
822
+
823
+ const JWT_SECRET = 'super-secret-key-hardcoded';
824
+ const db = new sqlite3('./data.db');
825
+
826
+ app.post('/login', (req, res) => {
827
+ const { username, password } = req.body;
828
+ const user = db.prepare(`SELECT * FROM users WHERE username = '${username}' AND password = '${password}'`).get();
829
+ if (!user) return res.status(401).json({ error: 'Invalid credentials' });
830
+ const token = jwt.sign({ id: user.id, role: user.role }, JWT_SECRET);
831
+ res.json({ token });
832
+ });
833
+
834
+ app.get('/user/:id', (req, res) => {
835
+ const token = req.headers.authorization;
836
+ const payload = jwt.verify(token, JWT_SECRET);
837
+ const user = db.prepare(`SELECT * FROM users WHERE id = ${req.params.id}`).get();
838
+ res.json(user);
839
+ });
840
+
841
+ app.get('/search', (req, res) => {
842
+ const q = req.query.q;
843
+ res.send(`<h1>Results for: ${q}</h1>`);
844
+ });
845
+
846
+ app.post('/run-report', (req, res) => {
847
+ const { filename } = req.body;
848
+ const output = execSync(`node reports/${filename}`);
849
+ res.send(output.toString());
850
+ });
851
+
852
+ app.get('/files', (req, res) => {
853
+ const name = req.query.name;
854
+ const filePath = path.join(__dirname, 'uploads', name);
855
+ res.send(fs.readFileSync(filePath, 'utf8'));
856
+ });
857
+
858
+ app.post('/template', (req, res) => {
859
+ const { template, data } = req.body;
860
+ const fn = new Function('data', `return \\`${template}\\``);
861
+ res.json({ result: fn(data) });
862
+ });
863
+
864
+ app.listen(3000);
865
+ """
866
+
867
+ TASK_JS_SECURITY: Dict[str, Any] = {
868
+ "task_id": "js-security",
869
+ "difficulty": "hard",
870
+ "description": (
871
+ "Perform a security audit on this Express.js REST API.\n"
872
+ "The service handles authentication and user data operations in Node.js.\n"
873
+ "It contains critical security vulnerabilities common in JavaScript backends.\n"
874
+ "Identify ALL issues with exact line numbers, types, and severity.\n\n"
875
+ "File to review: server.js"
876
+ ),
877
+ "language": "javascript",
878
+ "code_files": {
879
+ "server.js": _JS_CODE,
880
+ },
881
+ "ground_truth_issues": [
882
+ _issue(
883
+ 11, "server.js", "security", "high",
884
+ "Hardcoded JWT secret 'super-secret-key-hardcoded' in source. "
885
+ "Anyone with code access can forge tokens for any user.",
886
+ "Use: const JWT_SECRET = process.env.JWT_SECRET and rotate it as an env secret."
887
+ ),
888
+ _issue(
889
+ 16, "server.js", "security", "critical",
890
+ "SQL injection: username and password are interpolated directly into a template "
891
+ "literal inside prepare(). An attacker can bypass authentication with username = ' OR '1'='1'--.",
892
+ "Use parameterized queries: db.prepare('SELECT * FROM users WHERE username = ? AND password = ?').get(username, password)"
893
+ ),
894
+ _issue(
895
+ 18, "server.js", "security", "medium",
896
+ "JWT issued without expiry ('expiresIn' option missing). Tokens are valid forever; "
897
+ "a stolen token can never be invalidated without rotating the secret.",
898
+ "Add: jwt.sign({ id: user.id, role: user.role }, JWT_SECRET, { expiresIn: '1h' })"
899
+ ),
900
+ _issue(
901
+ 25, "server.js", "security", "critical",
902
+ "Missing authorization + SQL injection: any authenticated user can fetch any "
903
+ "user by changing req.params.id (IDOR). Also id is interpolated directly into SQL.",
904
+ "Check payload.id === req.params.id (or admin role). Use parameterized: db.prepare('SELECT * FROM users WHERE id = ?').get(req.params.id)"
905
+ ),
906
+ _issue(
907
+ 31, "server.js", "security", "high",
908
+ "Cross-site scripting (XSS): user-supplied query parameter q is reflected "
909
+ "directly into HTML response without escaping.",
910
+ "Use a templating engine with auto-escaping, or: res.send(`<h1>Results for: ${escapeHtml(q)}</h1>`)"
911
+ ),
912
+ _issue(
913
+ 36, "server.js", "security", "critical",
914
+ "Command injection: user-supplied filename is passed directly to execSync() "
915
+ "in a shell command. An attacker can supply 'x; rm -rf /' as filename.",
916
+ "Validate filename against a strict allowlist. Use execFileSync(['node', 'reports/' + sanitizedName]) with shell:false."
917
+ ),
918
+ _issue(
919
+ 42, "server.js", "security", "high",
920
+ "Path traversal: user-supplied 'name' is joined to uploads directory with path.join. "
921
+ "An attacker can supply '../../../etc/passwd' to read arbitrary files.",
922
+ "Use: path.resolve(__dirname, 'uploads', path.basename(name)) and validate the result starts with the uploads dir."
923
+ ),
924
+ _issue(
925
+ 48, "server.js", "security", "critical",
926
+ "Unsafe dynamic code execution: new Function() with user-supplied template string "
927
+ "is equivalent to eval(). Any client can execute arbitrary JavaScript on the server.",
928
+ "Never use new Function() or eval() with user input. Use a safe template engine like Handlebars or Mustache."
929
+ ),
930
+ ],
931
+ "max_steps": 25,
932
+ "hints": [
933
+ "Check every place user input (req.body, req.params, req.query) touches a database query, shell command, or HTML response.",
934
+ "Look for hardcoded secrets at the top of the file.",
935
+ "The /template and /run-report endpoints have particularly dangerous patterns.",
936
+ ],
937
+ }
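The path-traversal fix for `/files` is given in JavaScript above; the same containment check translated to Python (the `uploads` directory name is illustrative), using `realpath` plus `commonpath` rather than a string-prefix test, which a sibling directory like `uploads_evil` would fool:

```python
import os
from typing import Optional

UPLOADS = os.path.realpath("uploads")

def safe_path(name: str) -> Optional[str]:
    """Resolve name under UPLOADS; return None if it escapes the directory."""
    candidate = os.path.realpath(os.path.join(UPLOADS, name))
    # A contained path must resolve to somewhere inside the uploads directory.
    if os.path.commonpath([UPLOADS, candidate]) != UPLOADS:
        return None
    return candidate

print(safe_path("../../etc/passwd") is None)  # → True
print(safe_path("report.txt") is not None)    # → True
```

`realpath` collapses `..` segments before the check, so the traversal payload resolves outside `UPLOADS` and is rejected.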
938
+
939
+
940
  ALL_TASKS: Dict[str, Dict[str, Any]] = {
941
  TASK_BUG_DETECTION["task_id"]: TASK_BUG_DETECTION,
942
  TASK_SECURITY_AUDIT["task_id"]: TASK_SECURITY_AUDIT,
943
  TASK_COMPREHENSIVE["task_id"]: TASK_COMPREHENSIVE,
944
+ TASK_ASYNC_REVIEW["task_id"]: TASK_ASYNC_REVIEW,
945
+ TASK_DATA_PIPELINE["task_id"]: TASK_DATA_PIPELINE,
946
+ TASK_API_SECURITY["task_id"]: TASK_API_SECURITY,
947
+ TASK_JS_SECURITY["task_id"]: TASK_JS_SECURITY,
948
  }
949
 
950
  TASK_IDS: List[str] = list(ALL_TASKS.keys())
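With seven tasks now in the registry, a quick consistency check over the task shape is cheap insurance against ground-truth line numbers drifting out of sync with the embedded code. A sketch under assumed field names (the toy task mirrors the structure above; issues are shown as plain dicts, not the repo's `_issue` helper):

```python
from typing import Any, Dict

# Toy task in the same shape as the definitions above (field names illustrative).
TASK: Dict[str, Any] = {
    "task_id": "demo",
    "code_files": {"demo.py": "import os\n\nos.system('ls')\n"},
    "ground_truth_issues": [{"line": 3, "filename": "demo.py"}],
}

def check_task(task: Dict[str, Any]) -> bool:
    """Every ground-truth issue must point at an existing line of an existing file."""
    for issue in task["ground_truth_issues"]:
        code = task["code_files"].get(issue["filename"])
        if code is None:
            return False
        if not (1 <= issue["line"] <= len(code.splitlines())):
            return False
    return True

print(check_task(TASK))  # → True
```

Running such a check in the test suite would catch an edit to a `_CODE` string that shifts line numbers without updating the corresponding `_issue` entries.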
tests/test_environment.py CHANGED
41
  return env
42
 
43
 
44
+ @pytest.fixture
45
+ def env_async(env):
46
+ env.reset(task_id="async-review")
47
+ return env
48
+
49
+
50
+ @pytest.fixture
51
+ def env_pipeline(env):
52
+ env.reset(task_id="data-pipeline")
53
+ return env
54
+
55
+
56
  # ---------------------------------------------------------------------------
57
  # reset() tests
58
  # ---------------------------------------------------------------------------
 
118
  assert obs.flagged_issues == []
119
  assert obs.step_count == 0
120
 
121
+ def test_reset_has_code_metadata(self, env):
122
+ """Reset observation should include code_metadata."""
123
+ obs = env.reset(task_id="bug-detection")
124
+ assert isinstance(obs.code_metadata, dict)
125
+ assert "total_lines" in obs.code_metadata
126
+ assert "num_functions" in obs.code_metadata
127
+ assert "complexity_estimate" in obs.code_metadata
128
+
129
+ def test_reset_code_metadata_has_issue_categories(self, env):
130
+ """code_metadata should list the issue categories present in ground truth."""
131
+ obs = env.reset(task_id="bug-detection")
132
+ assert "issue_categories" in obs.code_metadata
133
+ # bug-detection has only bug type issues
134
+ assert "bug" in obs.code_metadata["issue_categories"]
135
+
136
+ def test_reset_has_empty_progress(self, env):
137
+ """Reset observation progress may be empty or absent (populated on step)."""
138
+ obs = env.reset(task_id="bug-detection")
139
+ assert isinstance(obs.progress, dict)
140
+
141
+ def test_reset_has_empty_reward_breakdown(self, env):
142
+ obs = env.reset(task_id="bug-detection")
143
+ assert isinstance(obs.reward_breakdown, dict)
144
+
145
+ def test_reset_async_task(self, env):
146
+ obs = env.reset(task_id="async-review")
147
+ assert obs.task_id == "async-review"
148
+ assert "async.py" in obs.code_files
149
+
150
+ def test_reset_pipeline_task(self, env):
151
+ obs = env.reset(task_id="data-pipeline")
152
+ assert obs.task_id == "data-pipeline"
153
+ assert "pipeline.py" in obs.code_files
154
+
155
 
156
  # ---------------------------------------------------------------------------
157
  # step() — flag_issue tests
 
213
  obs = env_bug.state
214
  assert len(obs.flagged_issues) == 3
215
 
216
+ def test_flag_has_reward_breakdown(self, env_bug):
217
+ """Every step should have a reward_breakdown dict."""
218
+ obs = env_bug.step(ReviewAction(
219
+ action_type="flag_issue", line_number=6, filename="utils.py",
220
+ issue_type="bug", severity="high", description="test"
221
+ ))
222
+ assert isinstance(obs.reward_breakdown, dict)
223
+ assert len(obs.reward_breakdown) > 0
224
+
225
+ def test_flag_has_progress(self, env_bug):
226
+ """Every step should have a progress dict with required keys."""
227
+ obs = env_bug.step(ReviewAction(
228
+ action_type="flag_issue", line_number=6, filename="utils.py",
229
+ issue_type="bug", severity="high", description="test"
230
+ ))
231
+ assert isinstance(obs.progress, dict)
232
+ for key in ("precision", "recall", "f1", "true_positives", "steps_remaining"):
233
+ assert key in obs.progress, f"Missing key: {key}"
234
+
235
+ def test_flag_has_flagged_summary(self, env_bug):
236
+ """Every step should have a flagged_summary dict."""
237
+ obs = env_bug.step(ReviewAction(
238
+ action_type="flag_issue", line_number=6, filename="utils.py",
239
+ issue_type="bug", severity="high", description="test"
240
+ ))
241
+ assert isinstance(obs.flagged_summary, dict)
242
+ assert "total_flagged" in obs.flagged_summary
243
+ assert "correct" in obs.flagged_summary
244
+ assert "incorrect" in obs.flagged_summary
245
+ assert "near_misses" in obs.flagged_summary
246
+
247
+
248
+ # ---------------------------------------------------------------------------
249
+ # Near-miss tests
250
+ # ---------------------------------------------------------------------------
251
+
252
+ class TestNearMiss:
253
+ def test_near_miss_gives_partial_credit(self, env_bug):
254
+        """A flag within 3-5 lines of a GT issue should give +0.03 not -0.05."""
+        # GT issue is at line 6 (off-by-one), so line 10 is 4 away = near miss
+        obs = env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=10, filename="utils.py",
+            issue_type="bug", severity="high", description="near miss test"
+        ))
+        # Near miss gives +0.03
+        assert obs.reward is not None and obs.reward > 0, (
+            f"Expected near-miss +0.03 but got {obs.reward}"
+        )
+        assert obs.reward == pytest.approx(0.03, abs=0.01)
+
+    def test_near_miss_counted_in_summary(self, env_bug):
+        """Near-miss flags should appear in flagged_summary.near_misses."""
+        # Line 10 is 4 lines from GT at line 6 → near miss
+        obs = env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=10, filename="utils.py",
+            issue_type="bug", severity="high", description="near miss"
+        ))
+        assert obs.flagged_summary.get("near_misses", 0) >= 1
+
+    def test_true_positive_not_counted_as_near_miss(self, env_bug):
+        """An exact TP should not be counted as a near miss."""
+        obs = env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+            issue_type="bug", severity="high", description="exact match"
+        ))
+        assert obs.flagged_summary.get("correct", 0) >= 1
+        assert obs.flagged_summary.get("near_misses", 0) == 0
+
+
+# ---------------------------------------------------------------------------
+# Confidence field tests
+# ---------------------------------------------------------------------------
+
+class TestConfidenceField:
+    def test_action_with_confidence(self, env_bug):
+        """ReviewAction should accept a confidence field."""
+        action = ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+            issue_type="bug", severity="high", description="test",
+            confidence=0.9
+        )
+        assert action.confidence == 0.9
+
+    def test_high_confidence_tp_gets_bonus(self, env_bug):
+        """High confidence + TP should give more than base 0.10."""
+        obs = env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+            issue_type="bug", severity="high", description="test",
+            confidence=0.9
+        ))
+        assert obs.reward is not None and obs.reward > 0.10
+
+    def test_high_confidence_fp_gets_extra_penalty(self, env_bug):
+        """High confidence + FP should give more penalty than -0.05."""
+        obs = env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=100, filename="utils.py",
+            issue_type="bug", severity="low", description="wrong",
+            confidence=0.9
+        ))
+        assert obs.reward is not None and obs.reward < -0.05
+
+    def test_low_confidence_tp_base_reward_only(self, env_bug):
+        """Low confidence + TP should give exactly base 0.10 (no bonus)."""
+        obs = env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+            issue_type="bug", severity="high", description="test",
+            confidence=0.5
+        ))
+        assert obs.reward is not None
+        # Should be 0.10 base + possible temporal bonus but no confidence bonus
+        assert obs.reward >= 0.10
+
+    def test_no_confidence_field_is_none(self):
+        """ReviewAction without confidence defaults to None."""
+        action = ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+        )
+        assert action.confidence is None
+
+    def test_confidence_in_action_to_dict(self):
+        """confidence should round-trip through to_dict/from_dict."""
+        action = ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+            confidence=0.75
+        )
+        d = action.to_dict()
+        assert d["confidence"] == 0.75
+        action2 = ReviewAction.from_dict(d)
+        assert action2.confidence == 0.75
+
+    def test_related_lines_field(self):
+        """ReviewAction should accept a related_lines field."""
+        action = ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+            related_lines=[6, 7, 8]
+        )
+        assert action.related_lines == [6, 7, 8]
+        d = action.to_dict()
+        assert d["related_lines"] == [6, 7, 8]
+        action2 = ReviewAction.from_dict(d)
+        assert action2.related_lines == [6, 7, 8]
+
 
 # ---------------------------------------------------------------------------
 # step() — clear_flag tests

             break
 
         assert obs.done is True
+
+
+# ---------------------------------------------------------------------------
+# New task tests
+# ---------------------------------------------------------------------------
+
+class TestNewTasks:
+    def test_async_review_task_exists(self, env):
+        obs = env.reset(task_id="async-review")
+        assert obs.task_id == "async-review"
+        assert obs.done is False
+
+    def test_async_review_has_correct_issue_count(self):
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["async-review"]
+        assert len(task["ground_truth_issues"]) == 6
+
+    def test_async_review_has_async_py(self, env):
+        obs = env.reset(task_id="async-review")
+        assert "async.py" in obs.code_files
+        code = obs.code_files["async.py"]
+        assert "asyncio" in code
+        assert "aiohttp" in code
+
+    def test_async_review_max_steps(self):
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["async-review"]
+        assert task["max_steps"] == 20
+
+    def test_data_pipeline_task_exists(self, env):
+        obs = env.reset(task_id="data-pipeline")
+        assert obs.task_id == "data-pipeline"
+        assert obs.done is False
+
+    def test_data_pipeline_has_correct_issue_count(self):
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["data-pipeline"]
+        assert len(task["ground_truth_issues"]) == 7
+
+    def test_data_pipeline_has_pipeline_py(self, env):
+        obs = env.reset(task_id="data-pipeline")
+        assert "pipeline.py" in obs.code_files
+        code = obs.code_files["pipeline.py"]
+        assert "sqlite3" in code
+        assert "hashlib" in code
+
+    def test_data_pipeline_max_steps(self):
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["data-pipeline"]
+        assert task["max_steps"] == 25
+
+    def test_task_count(self):
+        from tasks.data import TASK_IDS
+        assert len(TASK_IDS) >= 6
+
+    def test_async_review_correct_tp_reward(self, env_async):
+        """Flagging a known issue in async-review should give positive reward."""
+        obs = env_async.step(ReviewAction(
+            action_type="flag_issue", line_number=22, filename="async.py",
+            issue_type="bug", severity="high",
+            description="ClientSession not closed"
+        ))
+        assert obs.reward is not None and obs.reward > 0
+
+    def test_data_pipeline_correct_tp_reward(self, env_pipeline):
+        """Flagging a known SQL injection in pipeline.py should give positive reward."""
+        obs = env_pipeline.step(ReviewAction(
+            action_type="flag_issue", line_number=27, filename="pipeline.py",
+            issue_type="security", severity="critical",
+            description="SQL injection"
+        ))
+        assert obs.reward is not None and obs.reward > 0
+
+    def test_all_tasks_have_hints(self):
+        from tasks.data import ALL_TASKS
+        for task_id, task in ALL_TASKS.items():
+            assert "hints" in task, f"Task {task_id} missing hints"
+            assert len(task["hints"]) >= 3, f"Task {task_id} has fewer than 3 hints"
+
+
+# ---------------------------------------------------------------------------
+# Observation serialization
+# ---------------------------------------------------------------------------
+
+class TestObservationSerialization:
+    def test_reset_obs_to_dict_has_new_fields(self, env):
+        """to_dict() should include all new fields."""
+        obs = env.reset(task_id="bug-detection")
+        d = obs.to_dict()
+        assert "reward_breakdown" in d
+        assert "progress" in d
+        assert "flagged_summary" in d
+        assert "code_metadata" in d
+
+    def test_obs_from_dict_handles_missing_new_fields(self):
+        """from_dict() should handle missing new fields gracefully."""
+        d = {
+            "task_id": "bug-detection",
+            "task_description": "test",
+            "code_files": {},
+            "language": "python",
+            "flagged_issues": [],
+            "step_count": 0,
+            "max_steps": 15,
+            "hints_remaining": 3,
+            "feedback": "",
+            "current_score": 0.0,
+            "done": False,
+            "reward": None,
+            # No reward_breakdown, progress, flagged_summary, code_metadata
+        }
+        obs = ReviewObservation.from_dict(d)
+        assert obs.reward_breakdown == {}
+        assert obs.progress == {}
+        assert obs.flagged_summary == {}
+        assert obs.code_metadata == {}
+
+    def test_step_obs_to_dict_round_trip(self, env_bug):
+        obs = env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+            issue_type="bug", severity="high", description="test"
+        ))
+        d = obs.to_dict()
+        obs2 = ReviewObservation.from_dict(d)
+        assert obs2.task_id == obs.task_id
+        assert obs2.step_count == obs.step_count
+        assert isinstance(obs2.reward_breakdown, dict)
+        assert isinstance(obs2.progress, dict)
+        assert isinstance(obs2.flagged_summary, dict)
+
+
+# ---------------------------------------------------------------------------
+# Severity exact match bonus
+# ---------------------------------------------------------------------------
+
+class TestSeverityBonus:
+    def test_severity_match_gives_extra_reward(self, env_bug):
+        """Exact severity match should give more than a severity mismatch."""
+        # GT at line 6 is "high"
+        obs_match = env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+            issue_type="bug", severity="high", description="exact severity"
+        ))
+        env_bug.reset(task_id="bug-detection")
+        obs_wrong = env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+            issue_type="bug", severity="low", description="wrong severity"
+        ))
+        assert obs_match.reward > obs_wrong.reward
+
+    def test_severity_bonus_in_reward_breakdown(self, env_bug):
+        """reward_breakdown should include 'severity_exact' key on correct severity."""
+        obs = env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+            issue_type="bug", severity="high", description="correct severity"
+        ))
+        assert "severity_exact" in obs.reward_breakdown
+
+    def test_severity_mismatch_no_severity_bonus(self, env_bug):
+        """Wrong severity should not include 'severity_exact' key."""
+        obs = env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+            issue_type="bug", severity="low", description="wrong severity"
+        ))
+        assert "severity_exact" not in obs.reward_breakdown
+
+
+# ---------------------------------------------------------------------------
+# Flood protection (escalating FP penalty)
+# ---------------------------------------------------------------------------
+
+class TestFloodProtection:
+    def test_many_fps_escalate_penalty(self, env_bug):
+        """After 3 false positives, each subsequent FP should have larger penalty."""
+        rewards = []
+        for line in [101, 102, 103, 104, 105]:
+            obs = env_bug.step(ReviewAction(
+                action_type="flag_issue", line_number=line, filename="utils.py",
+                issue_type="bug", severity="low", description="fp"
+            ))
+            if obs.reward is not None and obs.reward < 0:
+                rewards.append(obs.reward)
+
+        # The 4th and 5th FPs should have larger absolute penalty
+        if len(rewards) >= 4:
+            assert abs(rewards[-1]) >= abs(rewards[0]), (
+                f"Expected escalating penalty but got {rewards}"
+            )
+
+    def test_fp_below_threshold_normal_penalty(self, env_bug):
+        """First FP should get standard -0.05 penalty."""
+        obs = env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=200, filename="utils.py",
+            issue_type="bug", severity="low", description="first fp"
+        ))
+        assert obs.reward is not None
+        assert obs.reward == pytest.approx(-0.05, abs=0.01)
+
+    def test_clearing_fp_reduces_penalty_track(self, env_bug):
+        """Clearing a FP should give positive reward."""
+        env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=200, filename="utils.py",
+            issue_type="bug", severity="low", description="fp"
+        ))
+        obs = env_bug.step(ReviewAction(
+            action_type="clear_flag", line_number=200, filename="utils.py",
+        ))
+        assert obs.reward is not None and obs.reward > 0
+
+
+# ---------------------------------------------------------------------------
+# Unfound issue types in progress
+# ---------------------------------------------------------------------------
+
+class TestUnfoundIssueTypes:
+    def test_unfound_types_present_at_start(self, env_bug):
+        """Before flagging anything, all GT issue types should be in unfound_issue_types."""
+        obs = env_bug.step(ReviewAction(action_type="request_hint"))
+        unfound = obs.progress.get("unfound_issue_types", [])
+        assert "bug" in unfound
+
+    def test_unfound_types_shrinks_when_issue_found(self, env_bug):
+        """Finding a bug should remove 'bug' from unfound_issue_types."""
+        obs_before = env_bug.step(ReviewAction(action_type="request_hint"))
+        unfound_before = set(obs_before.progress.get("unfound_issue_types", []))
+
+        env_bug.step(ReviewAction(
+            action_type="flag_issue", line_number=6, filename="utils.py",
+            issue_type="bug", severity="high", description="found a bug"
+        ))
+        obs_after = env_bug.step(ReviewAction(action_type="request_hint"))
+        unfound_after = set(obs_after.progress.get("unfound_issue_types", []))
+
+        # bug should now be gone from unfound
+        assert "bug" not in unfound_after or len(unfound_after) < len(unfound_before)
+
+    def test_unfound_types_is_list(self, env_bug):
+        obs = env_bug.step(ReviewAction(action_type="request_hint"))
+        assert isinstance(obs.progress.get("unfound_issue_types", []), list)
+
+
+# ---------------------------------------------------------------------------
+# API security task
+# ---------------------------------------------------------------------------
+
+class TestApiSecurityTask:
+    def test_api_security_task_exists(self, env):
+        obs = env.reset(task_id="api-security")
+        assert obs.task_id == "api-security"
+        assert obs.done is False
+
+    def test_api_security_has_api_py(self, env):
+        obs = env.reset(task_id="api-security")
+        assert "api.py" in obs.code_files
+
+    def test_api_security_has_8_issues(self):
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["api-security"]
+        assert len(task["ground_truth_issues"]) == 8
+
+    def test_api_security_has_critical_issues(self):
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["api-security"]
+        severities = {i["severity"] for i in task["ground_truth_issues"]}
+        assert "critical" in severities
+
+    def test_api_security_tp_reward(self, env):
+        env.reset(task_id="api-security")
+        obs = env.step(ReviewAction(
+            action_type="flag_issue", line_number=38, filename="api.py",
+            issue_type="security", severity="critical",
+            description="SQL injection via f-string"
+        ))
+        assert obs.reward is not None and obs.reward > 0
+
+    def test_api_security_keyword_baseline_finds_issues(self):
+        from tasks.data import ALL_TASKS
+        from server.graders import run_keyword_baseline
+        task = ALL_TASKS["api-security"]
+        findings = run_keyword_baseline(task)
+        assert len(findings) >= 2
+
+    def test_api_security_difficulty_hard(self):
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["api-security"]
+        assert task["difficulty"] == "hard"
+
+
+# ---------------------------------------------------------------------------
+# Auto-end gives full score (not 0.5x)
+# ---------------------------------------------------------------------------
+
+class TestAutoEndFullScore:
+    def test_auto_end_uses_full_grade(self, env_bug):
+        """Auto-end should give full grade_episode score, not a penalized value."""
+        # Flag all 3 correct bugs first
+        for line, sev in [(6, "high"), (13, "medium"), (33, "low")]:
+            env_bug.step(ReviewAction(
+                action_type="flag_issue", line_number=line, filename="utils.py",
+                issue_type="bug", severity=sev, description=f"bug at {line}"
+            ))
+        # Exhaust remaining steps with hints
+        max_steps = 15
+        for _ in range(max_steps - 3 - 1):
+            obs = env_bug.step(ReviewAction(action_type="request_hint"))
+            if obs.done:
+                break
+
+        obs = env_bug.step(ReviewAction(action_type="request_hint"))
+        if obs.done and obs.reward_breakdown.get("auto_end_grade") is not None:
+            # If auto-ended, score should be >= 0.7 since all 3 bugs found
+            assert obs.reward >= 0.7, f"Auto-end gave {obs.reward} instead of full grade"
+
+
+# ---------------------------------------------------------------------------
+# Function ranges in code_metadata
+# ---------------------------------------------------------------------------
+
+class TestFunctionRanges:
+    def test_reset_has_function_ranges(self, env):
+        obs = env.reset(task_id="bug-detection")
+        assert "function_ranges" in obs.code_metadata
+
+    def test_function_ranges_is_list(self, env):
+        obs = env.reset(task_id="bug-detection")
+        assert isinstance(obs.code_metadata["function_ranges"], list)
+
+    def test_function_ranges_have_required_fields(self, env):
+        obs = env.reset(task_id="bug-detection")
+        for fr in obs.code_metadata["function_ranges"]:
+            assert "name" in fr
+            assert "file" in fr
+            assert "start" in fr
+            assert "end" in fr
+
+    def test_function_ranges_nonempty_for_python(self, env):
+        obs = env.reset(task_id="bug-detection")
+        assert len(obs.code_metadata["function_ranges"]) > 0
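The per-flag reward constants these tests pin down (+0.10 for an exact true positive, +0.03 for a near miss within 3-5 lines, -0.05 for a false positive, with an escalating flood penalty after repeated FPs) can be summarized in a minimal sketch. This is an illustration of the shaping the tests assert, not the environment's actual implementation; `FLOOD_MULTIPLIER` and the helper name `flag_reward` are hypothetical, and the real environment layers further temporal and confidence terms on top.

```python
# Sketch of the per-flag reward shaping exercised by the tests above.
# Constants come from the test expectations; the escalation factor is assumed.
TP_REWARD = 0.10
NEAR_MISS_REWARD = 0.03
FP_PENALTY = -0.05
FLOOD_THRESHOLD = 3      # false positives tolerated before escalation
FLOOD_MULTIPLIER = 1.5   # hypothetical escalation factor

def flag_reward(distance, prior_fps, line_tolerance=2, near_tolerance=5):
    """Reward for one flag_issue action.

    distance:  abs(line gap) to the nearest same-file ground-truth issue,
               or None if the file has no GT issue.
    prior_fps: false positives already flagged this episode.
    """
    if distance is not None and distance <= line_tolerance:
        return TP_REWARD          # exact match (±2 lines)
    if distance is not None and distance <= near_tolerance:
        return NEAR_MISS_REWARD   # near miss (3-5 lines away)
    penalty = FP_PENALTY
    if prior_fps >= FLOOD_THRESHOLD:
        penalty *= FLOOD_MULTIPLIER   # flood protection kicks in
    return penalty

assert flag_reward(0, 0) == 0.10       # exact true positive
assert flag_reward(4, 0) == 0.03       # near miss
assert flag_reward(50, 0) == -0.05     # first false positive
assert abs(flag_reward(50, 4)) > 0.05  # escalated penalty after flooding
```

Keeping the near-miss band strictly positive while false positives stay negative is what makes the shaping informative without rewarding indiscriminate flagging.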
tests/test_graders.py CHANGED
@@ -7,7 +7,11 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
 
8
  import pytest
9
  from models import Issue
10
- from server.graders import grade_episode, match_issue, run_keyword_baseline
 
 
 
 
11
  from tasks.data import ALL_TASKS, TASK_IDS
12
 
13
 
@@ -56,6 +60,231 @@ class TestMatchIssue:
56
  gt = _issue(6, "utils.py", "bug", "high")
57
  assert match_issue(f, gt) is False
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  # ---------------------------------------------------------------------------
61
  # grade_episode()
@@ -177,6 +406,23 @@ class TestKeywordBaseline:
177
  if task_id == "security-audit":
178
  assert score > 0.0, f"Heuristic found nothing in {task_id}"
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  # ---------------------------------------------------------------------------
182
  # Ground truth sanity checks
@@ -213,3 +459,159 @@ class TestGroundTruth:
213
  files = {i["filename"] for i in task["ground_truth_issues"]}
214
  assert "views.py" in files
215
  assert "models.py" in files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  import pytest
9
  from models import Issue
10
+ from server.graders import (
11
+ grade_episode, match_issue, run_keyword_baseline,
12
+ match_quality, compute_code_metadata, grade_episode_detailed,
13
+ NEAR_TOLERANCE,
14
+ )
15
  from tasks.data import ALL_TASKS, TASK_IDS
16
 
17
 
 
60
  gt = _issue(6, "utils.py", "bug", "high")
61
  assert match_issue(f, gt) is False
62
 
63
+ def test_near_tolerance_param_accepted(self):
64
+ """match_issue should accept near_tolerance param without error."""
65
+ f = _issue(6, "utils.py", "bug", "high")
66
+ gt = _issue(6, "utils.py", "bug", "high")
67
+ result = match_issue(f, gt, line_tolerance=2, near_tolerance=5)
68
+ assert result is True
69
+
70
+
71
+ # ---------------------------------------------------------------------------
72
+ # match_quality()
73
+ # ---------------------------------------------------------------------------
74
+
75
+ class TestMatchQuality:
76
+ def test_exact_match_within_2_lines(self):
77
+ f = _issue(7, "utils.py", "bug", "high")
78
+ gt = _issue(6, "utils.py", "bug", "high")
79
+ assert match_quality(f, gt) == "exact"
80
+
81
+ def test_near_match_3_to_5_lines(self):
82
+ # 4 lines away from GT at 6 → near
83
+ f = _issue(10, "utils.py", "bug", "high")
84
+ gt = _issue(6, "utils.py", "bug", "high")
85
+ assert match_quality(f, gt) == "near"
86
+
87
+ def test_near_match_exactly_3_lines(self):
88
+ f = _issue(9, "utils.py", "bug", "high")
89
+ gt = _issue(6, "utils.py", "bug", "high")
90
+ assert match_quality(f, gt) == "near"
91
+
92
+ def test_near_match_exactly_5_lines(self):
93
+ f = _issue(11, "utils.py", "bug", "high")
94
+ gt = _issue(6, "utils.py", "bug", "high")
95
+ assert match_quality(f, gt) == "near"
96
+
97
+ def test_no_match_beyond_5_lines(self):
98
+ f = _issue(12, "utils.py", "bug", "high")
99
+ gt = _issue(6, "utils.py", "bug", "high")
100
+ assert match_quality(f, gt) == "none"
101
+
102
+ def test_no_match_wrong_file(self):
103
+ f = _issue(6, "other.py", "bug", "high")
104
+ gt = _issue(6, "utils.py", "bug", "high")
105
+ assert match_quality(f, gt) == "none"
106
+
107
+ def test_near_ignores_type_difference(self):
108
+ """Near match checks same file + line range, ignores type."""
109
+ f = _issue(10, "utils.py", "performance", "high")
110
+ gt = _issue(6, "utils.py", "bug", "high")
111
+ # 4 lines away → near
112
+ assert match_quality(f, gt) == "near"
113
+
114
+ def test_near_tolerance_constant(self):
115
+ assert NEAR_TOLERANCE == 5
116
+
117
+
118
+ # ---------------------------------------------------------------------------
119
+ # compute_code_metadata()
120
+ # ---------------------------------------------------------------------------
121
+
122
+ class TestComputeCodeMetadata:
123
+ def test_returns_dict(self):
124
+ code = {"test.py": "def foo(): pass\n"}
125
+ result = compute_code_metadata(code)
126
+ assert isinstance(result, dict)
127
+
128
+ def test_total_lines(self):
129
+ code = {"test.py": "line1\nline2\nline3\n"}
130
+ result = compute_code_metadata(code)
131
+ assert result["total_lines"] == 3
132
+
133
+ def test_num_functions(self):
134
+ code = {"test.py": "def foo():\n pass\n\ndef bar():\n pass\n"}
135
+ result = compute_code_metadata(code)
136
+ assert result["num_functions"] == 2
137
+
138
+ def test_function_names(self):
139
+ code = {"test.py": "def foo():\n pass\n\ndef bar():\n pass\n"}
140
+ result = compute_code_metadata(code)
141
+ assert "foo" in result["function_names"]
142
+ assert "bar" in result["function_names"]
143
+
144
+ def test_num_classes(self):
145
+ code = {"test.py": "class Foo:\n pass\n\nclass Bar:\n pass\n"}
146
+ result = compute_code_metadata(code)
147
+ assert result["num_classes"] == 2
148
+
149
+ def test_class_names(self):
150
+ code = {"test.py": "class Foo:\n pass\n"}
151
+ result = compute_code_metadata(code)
152
+ assert "Foo" in result["class_names"]
153
+
154
+ def test_imports(self):
155
+ code = {"test.py": "import os\nimport sys\nfrom typing import List\n"}
156
+ result = compute_code_metadata(code)
157
+ assert "os" in result["imports"]
158
+ assert "sys" in result["imports"]
159
+ assert "typing" in result["imports"]
160
+
161
+ def test_complexity_low(self):
162
+ code = {"test.py": "def foo():\n return 1\n"}
163
+ result = compute_code_metadata(code)
164
+ assert result["complexity_estimate"] == "low"
165
+
166
+ def test_complexity_medium(self):
167
+ # 6-15 branches — each if is top-level so indent is fine
168
+ lines = ["def foo(x):"]
169
+ for i in range(8):
170
+ lines.append(f" if x > {i}:")
171
+ lines.append(" pass")
172
+ code = {"test.py": "\n".join(lines) + "\n"}
173
+ result = compute_code_metadata(code)
174
+ assert result["complexity_estimate"] in ("medium", "high")
175
+
176
+ def test_complexity_high(self):
177
+ # 16+ branches
178
+ lines = ["def foo(x):"]
179
+ for i in range(20):
180
+ lines.append(f" if x > {i}:")
181
+ lines.append(" pass")
182
+ code = {"test.py": "\n".join(lines) + "\n"}
183
+ result = compute_code_metadata(code)
184
+ assert result["complexity_estimate"] == "high"
185
+
186
+ def test_issue_categories_passed_through(self):
187
+ code = {"test.py": "x = 1\n"}
188
+ result = compute_code_metadata(code, issue_categories=["bug", "security", "bug"])
189
+ # Should deduplicate
190
+ cats = result["issue_categories"]
191
+ assert "bug" in cats
192
+ assert "security" in cats
193
+
194
+ def test_syntax_error_no_crash(self):
195
+ """Non-parseable code should not raise."""
196
+ code = {"bad.py": "this is not valid python !!!\n def broken("}
197
+ result = compute_code_metadata(code)
198
+ assert "total_lines" in result
199
+ assert result["total_lines"] >= 1
200
+
201
+ def test_multi_file(self):
202
+ code = {
203
+ "a.py": "def foo():\n pass\n",
204
+ "b.py": "def bar():\n pass\n",
205
+ }
206
+ result = compute_code_metadata(code)
207
+ assert result["num_functions"] == 2
208
+ assert result["total_lines"] == 4
209
+
210
+ def test_utils_task_metadata(self):
211
+ from tasks.data import ALL_TASKS
212
+ task = ALL_TASKS["bug-detection"]
213
+ result = compute_code_metadata(task["code_files"])
214
+ assert result["total_lines"] > 0
215
+ assert result["num_functions"] >= 4 # utils.py has 4 functions
216
+
217
+
218
+ # ---------------------------------------------------------------------------
219
+ # grade_episode_detailed()
220
+ # ---------------------------------------------------------------------------
221
+
222
+ class TestGradeEpisodeDetailed:
223
+ def test_returns_dict(self):
224
+ gt = [_issue(6, "utils.py", "bug", "high")]
225
+ result = grade_episode_detailed(gt, gt)
226
+ assert isinstance(result, dict)
227
+
228
+ def test_required_keys(self):
229
+ gt = [_issue(6, "utils.py", "bug", "high")]
230
+ result = grade_episode_detailed(gt, gt)
231
+ for key in ("score", "f1", "precision", "recall", "severity_accuracy",
232
+ "true_positives", "false_positives", "false_negatives",
233
+ "near_misses", "per_file"):
234
+ assert key in result, f"Missing key: {key}"
235
+
236
+ def test_perfect_match(self):
237
+ gt = [_issue(6, "utils.py", "bug", "high")]
238
+ result = grade_episode_detailed(gt, gt)
239
+ assert result["true_positives"] == 1
240
+ assert result["false_positives"] == 0
241
+ assert result["false_negatives"] == 0
242
+
243
+ def test_false_positive_counted(self):
244
+ gt = [_issue(6, "utils.py", "bug", "high")]
245
+ flagged = [_issue(6, "utils.py", "bug", "high"),
246
+ _issue(100, "utils.py", "bug", "low")]
247
+ result = grade_episode_detailed(flagged, gt)
248
+ assert result["false_positives"] >= 1
249
+
250
+ def test_near_miss_counted(self):
251
+ gt = [_issue(6, "utils.py", "bug", "high")]
252
+ # 4 lines away = near miss
253
+ flagged = [_issue(10, "utils.py", "bug", "high")]
254
+ result = grade_episode_detailed(flagged, gt)
255
+ assert result["near_misses"] >= 1
256
+
257
+ def test_per_file_breakdown(self):
258
+ gt = [
259
+ _issue(6, "utils.py", "bug", "high"),
260
+ _issue(10, "other.py", "security", "critical"),
261
+ ]
262
+ flagged = [_issue(6, "utils.py", "bug", "high")]
263
+ result = grade_episode_detailed(flagged, gt)
264
+ assert "utils.py" in result["per_file"]
265
+
266
+ def test_score_matches_grade_episode(self):
267
+ """Detailed score should match grade_episode for simple cases."""
268
+ gt = [
269
+ _issue(6, "utils.py", "bug", "high"),
270
+ _issue(13, "utils.py", "bug", "medium"),
271
+ ]
272
+ flagged = [_issue(6, "utils.py", "bug", "high")]
273
+ simple_score = grade_episode(flagged, gt)
274
+ detailed = grade_episode_detailed(flagged, gt)
275
+ # Scores may differ slightly (near_miss handling), but should be close
276
+ assert abs(detailed["score"] - simple_score) <= 0.15
277
+
278
+ def test_empty_ground_truth_perfect(self):
279
+ result = grade_episode_detailed([], [])
280
+ assert result["score"] == 1.0
281
+
282
+ def test_empty_flagged_zero(self):
283
+ gt = [_issue(6, "utils.py")]
284
+ result = grade_episode_detailed([], gt)
285
+ assert result["score"] == 0.0
286
+ assert result["false_negatives"] == 1
287
+
288
 
289
  # ---------------------------------------------------------------------------
290
  # grade_episode()
 
406
  if task_id == "security-audit":
407
  assert score > 0.0, f"Heuristic found nothing in {task_id}"
408
 
409
+ def test_baseline_finds_md5_in_pipeline(self):
410
+ """Keyword baseline should find the MD5 issue in data-pipeline."""
411
+ from tasks.data import ALL_TASKS
412
+ task = ALL_TASKS["data-pipeline"]
413
+ findings = run_keyword_baseline(task)
414
+ md5_finds = [f for f in findings if "md5" in f.description.lower() or "MD5" in f.description]
415
+ assert len(md5_finds) >= 1
416
+
417
+ def test_baseline_finds_sql_injection_in_pipeline(self):
418
+ """Keyword baseline should find SQL injection via f-string in pipeline.py."""
419
+ from tasks.data import ALL_TASKS
420
+ task = ALL_TASKS["data-pipeline"]
421
+ findings = run_keyword_baseline(task)
422
+ sql_finds = [f for f in findings if f.issue_type == "security"
423
+ and "sql" in f.description.lower()]
424
+ assert len(sql_finds) >= 1
425
+
426
 
427
  # ---------------------------------------------------------------------------
428
  # Ground truth sanity checks
 
459
  files = {i["filename"] for i in task["ground_truth_issues"]}
460
  assert "views.py" in files
461
  assert "models.py" in files
462
+
463
+ def test_async_review_has_6_issues(self):
464
+ task = ALL_TASKS["async-review"]
465
+ assert len(task["ground_truth_issues"]) == 6
466
+
467
+ def test_data_pipeline_has_7_issues(self):
468
+ task = ALL_TASKS["data-pipeline"]
469
+ assert len(task["ground_truth_issues"]) == 7
470
+
471
+ def test_async_review_issues_in_async_py(self):
472
+ task = ALL_TASKS["async-review"]
473
+ for issue in task["ground_truth_issues"]:
474
+ assert issue["filename"] == "async.py"
475
+
476
+ def test_data_pipeline_issues_in_pipeline_py(self):
477
+ task = ALL_TASKS["data-pipeline"]
478
+ for issue in task["ground_truth_issues"]:
479
+ assert issue["filename"] == "pipeline.py"
480
+
481
+ def test_data_pipeline_has_security_and_performance(self):
482
+ task = ALL_TASKS["data-pipeline"]
483
+ types = {i["issue_type"] for i in task["ground_truth_issues"]}
484
+ assert "security" in types
485
+ assert "performance" in types
486
+
487
+ def test_async_review_has_bug_and_performance(self):
488
+ task = ALL_TASKS["async-review"]
489
+ types = {i["issue_type"] for i in task["ground_truth_issues"]}
490
+ assert "bug" in types
491
+ assert "performance" in types
492
+
493
+ def test_all_tasks_count(self):
494
+ assert len(ALL_TASKS) >= 6
495
+
496
+    def test_async_review_line_numbers_are_valid(self):
+        """GT issue line numbers should be within the code file."""
+        from tasks.data import TASK_ASYNC_REVIEW
+        code = TASK_ASYNC_REVIEW["code_files"]["async.py"]
+        total_lines = len(code.splitlines())
+        for issue in TASK_ASYNC_REVIEW["ground_truth_issues"]:
+            assert 1 <= issue["line_number"] <= total_lines, (
+                f"Line {issue['line_number']} out of range (file has {total_lines} lines)"
+            )
+
+    def test_pipeline_line_numbers_are_valid(self):
+        """GT issue line numbers should be within the code file."""
+        from tasks.data import TASK_DATA_PIPELINE
+        code = TASK_DATA_PIPELINE["code_files"]["pipeline.py"]
+        total_lines = len(code.splitlines())
+        for issue in TASK_DATA_PIPELINE["ground_truth_issues"]:
+            assert 1 <= issue["line_number"] <= total_lines, (
+                f"Line {issue['line_number']} out of range (file has {total_lines} lines)"
+            )
+
+    def test_api_security_has_8_issues(self):
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["api-security"]
+        assert len(task["ground_truth_issues"]) == 8
+
+    def test_api_security_line_numbers_are_valid(self):
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["api-security"]
+        code = task["code_files"]["api.py"]
+        total_lines = len(code.splitlines())
+        for issue in task["ground_truth_issues"]:
+            assert 1 <= issue["line_number"] <= total_lines, (
+                f"Line {issue['line_number']} out of range (file has {total_lines} lines)"
+            )
+
+    def test_api_security_has_security_issues(self):
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["api-security"]
+        types = {i["issue_type"] for i in task["ground_truth_issues"]}
+        assert "security" in types
+
+
+# ---------------------------------------------------------------------------
+# compute_function_map and function_ranges in metadata
+# ---------------------------------------------------------------------------
+
+class TestFunctionRangesMetadata:
+    def test_function_ranges_in_metadata(self):
+        code = {"test.py": "def foo():\n return 1\n\ndef bar(x):\n return x\n"}
+        result = compute_code_metadata(code)
+        assert "function_ranges" in result
+        assert len(result["function_ranges"]) == 2
+
+    def test_function_ranges_have_correct_fields(self):
+        code = {"test.py": "def foo():\n return 1\n"}
+        result = compute_code_metadata(code)
+        fr = result["function_ranges"][0]
+        assert fr["name"] == "foo"
+        assert fr["file"] == "test.py"
+        assert "start" in fr
+        assert "end" in fr
+        assert fr["start"] <= fr["end"]
+
+    def test_function_ranges_empty_for_no_functions(self):
+        code = {"test.py": "x = 1\ny = 2\n"}
+        result = compute_code_metadata(code)
+        assert result["function_ranges"] == []
+
+    def test_function_ranges_multifile(self):
+        code = {
+            "a.py": "def foo():\n pass\n",
+            "b.py": "def bar():\n pass\n\ndef baz():\n pass\n",
+        }
+        result = compute_code_metadata(code)
+        names = {fr["name"] for fr in result["function_ranges"]}
+        assert names == {"foo", "bar", "baz"}
+
+    def test_function_ranges_correct_line_numbers(self):
+        code = {"test.py": "x = 1\n\ndef foo():\n return 1\n"}
+        result = compute_code_metadata(code)
+        assert len(result["function_ranges"]) == 1
+        assert result["function_ranges"][0]["start"] == 3  # line 3
+
+
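The `TestFunctionRangesMetadata` cases pin down the contract these tests expect of `compute_code_metadata`: a `{filename: source}` dict in, a `function_ranges` list of `{name, file, start, end}` dicts out, with 1-based line numbers. A minimal sketch consistent with those assertions might parse each file with the stdlib `ast` module (this is an illustrative implementation, not necessarily the one in `server/environment.py`):

```python
import ast


def compute_code_metadata(code_files: dict[str, str]) -> dict:
    """Collect per-function line ranges across a dict of {filename: source}."""
    function_ranges = []
    for filename, source in code_files.items():
        tree = ast.parse(source)
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                function_ranges.append({
                    "name": node.name,
                    "file": filename,
                    "start": node.lineno,       # 1-based line of the `def`
                    "end": node.end_lineno,     # last line of the body (Py 3.8+)
                })
    return {"function_ranges": function_ranges}
```

`ast.walk` also visits nested functions and methods, which is usually what a reviewer-facing function map wants; restrict to `tree.body` if only top-level functions should count.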
+# ---------------------------------------------------------------------------
+# New keyword patterns
+# ---------------------------------------------------------------------------
+
+class TestNewKeywordPatterns:
+    def test_baseline_finds_hardcoded_admin_token(self):
+        from server.graders import run_keyword_baseline
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["api-security"]
+        findings = run_keyword_baseline(task)
+        token_finds = [f for f in findings if "ADMIN_TOKEN" in f.description or "token" in f.description.lower()]
+        assert len(token_finds) >= 1
+
+    def test_baseline_finds_pickle_loads(self):
+        from server.graders import run_keyword_baseline
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["api-security"]
+        findings = run_keyword_baseline(task)
+        pickle_finds = [f for f in findings if "pickle" in f.description.lower()]
+        assert len(pickle_finds) >= 1
+
+    def test_baseline_finds_os_system(self):
+        from server.graders import run_keyword_baseline
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["api-security"]
+        findings = run_keyword_baseline(task)
+        sys_finds = [f for f in findings if "os.system" in f.description.lower() or "command" in f.description.lower()]
+        assert len(sys_finds) >= 1
+
+    def test_baseline_api_security_score_nonzero(self):
+        from server.graders import run_keyword_baseline, grade_episode
+        from models import Issue
+        from tasks.data import ALL_TASKS
+        task = ALL_TASKS["api-security"]
+        findings = run_keyword_baseline(task)
+        gt = [Issue.from_dict(i) for i in task["ground_truth_issues"]]
+        score = grade_episode(findings, gt)
+        assert score > 0.0, "Keyword baseline should find at least 1 issue in api-security"
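The `TestNewKeywordPatterns` cases imply that `run_keyword_baseline` scans each task's code files line by line for known-risky patterns and emits findings with a `description` field. A minimal sketch under that assumption — the pattern list, `Finding` shape, and descriptions here are illustrative; the real list lives in `server/graders.py`:

```python
import re
from dataclasses import dataclass


@dataclass
class Finding:
    file: str
    line_number: int
    description: str


# Illustrative patterns only; the grader's actual list may differ.
KEYWORD_PATTERNS = [
    (re.compile(r"ADMIN_TOKEN\s*="), "Hardcoded admin token (ADMIN_TOKEN)"),
    (re.compile(r"pickle\.loads?\("), "Unsafe pickle deserialization"),
    (re.compile(r"os\.system\("), "Possible command injection via os.system"),
]


def run_keyword_baseline(task: dict) -> list[Finding]:
    """Grep every code file for risky patterns; one Finding per hit."""
    findings = []
    for filename, source in task["code_files"].items():
        for lineno, line in enumerate(source.splitlines(), start=1):
            for pattern, description in KEYWORD_PATTERNS:
                if pattern.search(line):
                    findings.append(Finding(filename, lineno, description))
    return findings
```

Because each hit carries the exact line number, such a baseline scores well under the ±2-line matching tolerance whenever the ground-truth issue sits on the flagged line, which is consistent with the nonzero-score assertion above.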