ashishbaberwal committed
Commit 1939cbc · 1 Parent(s): abf2209

New Final
README.md CHANGED
@@ -12,6 +12,12 @@ pinned: false
 
 This repository provides an OpenEnv-compatible environment for evaluating AI code-review agents.
 
+## Why This Environment
+
+Code review is a strong RL task because success and failure are measurable: line-level issues can be deterministically graded, rewards can be shaped across review phases, and tasks can scale from easy to hard while staying realistic.
+
+This project is designed for both evaluation and lightweight policy-training loops, not only one-off scripted inference.
+
 The agent receives a code diff and surrounding file context, then performs a multi-step review:
 
 1. Add issue comments with line numbers.
@@ -24,15 +30,20 @@ The environment scores the review quality using deterministic graders.
 
 - Simulates pull-request review tasks across easy/medium/hard difficulty.
 - Exposes OpenEnv-style lifecycle methods (`reset`, `step`, `state`).
+- Exposes integration endpoints (`tasks`, `score`, `health`) for tooling and dashboard checks.
 - Grades issue detection, fix suggestions, and final decision quality.
 - Supports local LLM providers via an OpenAI-compatible API (including Ollama).
+- Includes a policy-training scaffold (`train.py`, `train_env.py`) and logged training metrics.
 
 ## Project Structure
 
 - `environment/`: environment implementation, task definitions, models, and grading logic.
 - `inference.py`: baseline review agent loop.
+- `train.py`, `train_env.py`: lightweight PPO-style policy-training loop over the environment.
+- `ppo_logs/`: training metrics and summaries.
 - `openenv.yaml`: task registry and environment metadata.
 - `tests/`: environment tests.
+- `explore_env.ipynb`: interactive environment walkthrough.
 - `docker-compose.yml` / `Dockerfile`: containerized execution options.
 
 ## Prerequisites
@@ -154,9 +165,24 @@ Note: on macOS, `network_mode: host` can be unreliable. If `local-agent` cannot
 - `memory_leak_medium_1`
 - `performance_medium_2`
 - `approve_medium_3`
+- `type_safety_medium_4`
+- `javascript_medium_5`
 - `security_hard_1`
 - `race_condition_hard_2`
 - `approve_hard_3`
+- `adversarial_hard_4`
+- `concurrency_hard_5`
+- `dependency_injection_hard_6`
+
+## HTTP Endpoints
+
+- `GET /`
+- `GET /health`
+- `GET /tasks`
+- `GET|POST /reset`
+- `POST /step`
+- `GET /state`
+- `GET /score`
 
 ## Output Format
 
@@ -221,6 +247,29 @@ python submit.py --skip-docker --max-steps 10
 
 Note: `task_score` is normalized to [0,1]. `total_reward` is cumulative step reward and can exceed 1.0 by design.
 
+## Training Results (PPO-style Loop)
+
+Run training:
+
+```bash
+source .venv/bin/activate
+python train.py --episodes 120 --max-steps 5
+```
+
+Generated artifacts:
+
+- `ppo_logs/train_metrics.csv`
+- `ppo_logs/summary.txt`
+
+Recent run summary:
+
+- Episodes: `120`
+- Average reward (first 10): `0.0100`
+- Average reward (last 10): `0.5100`
+- Improvement: `+0.5000`
+
+This demonstrates measurable policy improvement under the training setup provided in this repository.
+
 ## One-Command Benchmark Table
 
 Generate per-task JSON outputs plus a markdown table for judge submission:
@@ -237,8 +286,18 @@ Artifacts:
 
 ## Failure Analysis Template
 
-- Missed issue type:
-- Why it was missed (model behavior or prompt failure):
-- Grader diagnostics (precision/recall/F1/FP):
-- Fix applied (prompt/rubric/task change):
+1. `javascript_medium_5` (Undefined access)
+   - Observation: task score reached `1.0`, but diagnostics show `precision=0.5`, `recall=1.0`, `f1=0.6667`, `false_positive_count=1`.
+   - Why: the model used Python-centric heuristics and produced one extra issue comment on a JS snippet.
+   - Action: added a JavaScript task category and retained false-positive penalties to expose over-flagging.
+
+2. `memory_leak_medium_1` (historical baseline run)
+   - Observation: an earlier run dropped below a perfect score due to a noisy comment strategy.
+   - Why: over-commenting triggered false-positive penalties despite finding the core issue.
+   - Action: added an anti-loop repeated-comment penalty plus adversarial no-issue tasks to discourage spam.
+
+3. `adversarial_hard_4` (Safe SQL task)
+   - Observation: the correct behavior is to approve; naive SQL keyword matching causes false alarms.
+   - Why: keyword-only review policies confuse parameterized SQL with vulnerable string interpolation.
+   - Action: included an explicit no-issue adversarial task in the hard set and calibration tests to reward restraint.
 
environment/tasks.py CHANGED
@@ -177,6 +177,44 @@ def run_user_query(db, limit):
         "language": "python",
         "line_count": 3,
         "expected_issues": []
+    },
+    {
+        "task_id": "type_safety_medium_4",
+        "task_name": "Type Safety: Optional Arithmetic",
+        "difficulty": "medium",
+        "description": "Find the type safety issue where Optional[int] can be None during arithmetic",
+        "code_diff": """from typing import Optional\n\ndef increment(value: Optional[int]) -> int:\n    return value + 1""",
+        "surrounding_code": """from typing import Optional\n\ndef increment(value: Optional[int]) -> int:\n    return value + 1\n\ndef safe_increment(value: Optional[int]) -> int:\n    return increment(value)""",
+        "file_path": "type_utils.py",
+        "language": "python",
+        "line_count": 4,
+        "expected_issues": [
+            {
+                "line": 4,
+                "type": "type_safety",
+                "severity": "medium",
+                "description": "Optional[int] may be None, causing runtime TypeError",
+            }
+        ]
+    },
+    {
+        "task_id": "javascript_medium_5",
+        "task_name": "JavaScript: Undefined Access",
+        "difficulty": "medium",
+        "description": "Find the JavaScript bug where user can be undefined before property access",
+        "code_diff": """function getUserName(user) {\n    return user.name.trim();\n}""",
+        "surrounding_code": """function getUserName(user) {\n    return user.name.trim();\n}\n\nfunction formatUser(user) {\n    return getUserName(user).toLowerCase();\n}""",
+        "file_path": "user.js",
+        "language": "javascript",
+        "line_count": 3,
+        "expected_issues": [
+            {
+                "line": 2,
+                "type": "null_access",
+                "severity": "medium",
+                "description": "user may be undefined and property access can throw",
+            }
+        ]
+    }
+]
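For the new `type_safety_medium_4` task, the graded issue is the unguarded `value + 1` on an `Optional[int]`. A hypothetical corrected version is sketched below; the choice to raise `ValueError` is an assumption, since the grader checks the flagged issue, not a specific fix:

```python
from typing import Optional

def increment(value: Optional[int]) -> int:
    # Guard the Optional before arithmetic; the graded issue is the
    # unguarded `value + 1` when value is None.
    if value is None:
        raise ValueError("value must not be None")
    return value + 1
```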
 
 
@@ -297,6 +335,44 @@ def find_all_users(database):
         "language": "python",
         "line_count": 4,
         "expected_issues": []
+    },
+    {
+        "task_id": "concurrency_hard_5",
+        "task_name": "Concurrency: Async Await Misuse",
+        "difficulty": "hard",
+        "description": "Find async misuse where created tasks are never awaited",
+        "code_diff": """import asyncio\n\nasync def process_all(items, worker):\n    for item in items:\n        asyncio.create_task(worker(item))\n    return True""",
+        "surrounding_code": """import asyncio\n\nasync def process_all(items, worker):\n    for item in items:\n        asyncio.create_task(worker(item))\n    return True\n\nasync def run(items, worker):\n    return await process_all(items, worker)""",
+        "file_path": "async_processor.py",
+        "language": "python",
+        "line_count": 6,
+        "expected_issues": [
+            {
+                "line": 5,
+                "type": "async_misuse",
+                "severity": "high",
+                "description": "Tasks are created but never awaited or gathered",
+            }
+        ]
+    },
+    {
+        "task_id": "dependency_injection_hard_6",
+        "task_name": "Dependency Injection: Tight Coupling",
+        "difficulty": "hard",
+        "description": "Find design issue where service constructs hardcoded dependency internally",
+        "code_diff": """class PaymentService:\n    def __init__(self):\n        self.gateway = StripeGateway()\n\n    def charge(self, amount):\n        return self.gateway.charge(amount)""",
+        "surrounding_code": """class PaymentService:\n    def __init__(self):\n        self.gateway = StripeGateway()\n\n    def charge(self, amount):\n        return self.gateway.charge(amount)\n\nclass StripeGateway:\n    def charge(self, amount):\n        return True""",
+        "file_path": "payment_service.py",
+        "language": "python",
+        "line_count": 6,
+        "expected_issues": [
+            {
+                "line": 3,
+                "type": "dependency_injection",
+                "severity": "medium",
+                "description": "Hardcoded dependency prevents testability and inversion of control",
+            }
+        ]
     }
 ]
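For `concurrency_hard_5`, the expected finding is that the `asyncio.create_task` results are never awaited. One hypothetical fix is to gather the worker coroutines instead; `double` below is a stand-in worker for illustration, not part of the task set:

```python
import asyncio

async def process_all(items, worker):
    # Await all workers instead of fire-and-forget create_task calls,
    # so exceptions propagate and completion is guaranteed.
    await asyncio.gather(*(worker(item) for item in items))
    return True

async def double(x):
    return x * 2

result = asyncio.run(process_all([1, 2, 3], double))
```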
 
explore_env.ipynb ADDED
File without changes
openenv.yaml CHANGED
@@ -46,6 +46,14 @@ tasks:
     name: "Medium: Approve Safe Query Helper"
     difficulty: medium
 
+  - id: type_safety_medium_4
+    name: "Medium: Type Safety Optional Arithmetic"
+    difficulty: medium
+
+  - id: javascript_medium_5
+    name: "Medium: JavaScript Undefined Access"
+    difficulty: medium
+
   - id: security_hard_1
     name: "Hard: SQL Injection Vulnerability"
     difficulty: hard
@@ -62,6 +70,14 @@ tasks:
     name: "Hard: Adversarial Safe SQL Builder"
     difficulty: hard
 
+  - id: concurrency_hard_5
+    name: "Hard: Async Await Misuse"
+    difficulty: hard
+
+  - id: dependency_injection_hard_6
+    name: "Hard: Tight Coupling in Service"
+    difficulty: hard
+
 observation_space:
   type: dict
   description: |
ppo_logs/README.md ADDED
@@ -0,0 +1,15 @@
+# PPO Logs
+
+This folder stores training artifacts produced by `train.py`.
+
+Files:
+
+- `train_metrics.csv`: per-episode reward, task_score, steps, and running baseline.
+- `summary.txt`: compact training summary for README/judge evidence.
+
+Example run:
+
+```bash
+source .venv/bin/activate
+python train.py --episodes 120 --max-steps 5
+```
ppo_logs/summary.txt ADDED
@@ -0,0 +1,4 @@
+episodes=120
+avg_reward_first10=0.0100
+avg_reward_last10=0.5100
+improvement=0.5000
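The summary fields can be reproduced from the reward column alone: average the first ten and last ten episode rewards and take the difference. A small sketch, assuming the rewards are already parsed into a list:

```python
def summarize(rewards: list[float]) -> tuple[float, float, float]:
    """First-10 average, last-10 average, and their difference."""
    first = rewards[:10]
    last = rewards[-10:]
    first_avg = sum(first) / len(first)
    last_avg = sum(last) / len(last)
    return first_avg, last_avg, last_avg - first_avg
```

Feeding the logged run's rewards through this yields the `0.0100 / 0.5100 / 0.5000` figures in `summary.txt`.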
ppo_logs/train_metrics.csv ADDED
@@ -0,0 +1,121 @@
+episode,reward,task_score,steps,baseline_reward
+1,0.01,0.0,3,0.001
+2,0.01,0.0,3,0.0019
+3,0.01,0.0,3,0.0027
+4,0.01,0.0,3,0.0034
+5,0.01,0.0,3,0.0041
+6,0.01,0.0,3,0.0047
+7,0.01,0.0,3,0.0052
+8,0.01,0.0,3,0.0057
+9,0.01,0.0,3,0.0061
+10,0.01,0.0,3,0.0065
+11,0.01,0.0,3,0.0069
+12,0.01,0.0,3,0.0072
+13,0.01,0.0,3,0.0075
+14,0.01,0.0,3,0.0077
+15,0.01,0.0,3,0.0079
+16,0.01,0.0,3,0.0081
+17,0.01,0.0,3,0.0083
+18,0.01,0.0,3,0.0085
+19,0.01,0.0,3,0.0086
+20,0.01,0.0,3,0.0088
+21,1.31,1.0,3,0.1389
+22,1.31,1.0,3,0.256
+23,1.31,1.0,3,0.3614
+24,1.31,1.0,3,0.4563
+25,1.31,1.0,3,0.5416
+26,1.31,1.0,3,0.6185
+27,1.31,1.0,3,0.6876
+28,1.31,1.0,3,0.7499
+29,1.31,1.0,3,0.8059
+30,0.51,0.4,3,0.7763
+31,0.51,0.4,3,0.7497
+32,0.51,0.4,3,0.7257
+33,0.51,0.4,3,0.7041
+34,0.51,0.4,3,0.6847
+35,0.51,0.4,3,0.6672
+36,0.51,0.4,3,0.6515
+37,0.51,0.4,3,0.6374
+38,0.51,0.4,3,0.6246
+39,0.51,0.4,3,0.6132
+40,0.51,0.4,3,0.6029
+41,0.51,0.4,3,0.5936
+42,0.51,0.4,3,0.5852
+43,0.51,0.4,3,0.5777
+44,0.51,0.4,3,0.5709
+45,0.51,0.4,3,0.5648
+46,0.51,0.4,3,0.5593
+47,0.51,0.4,3,0.5544
+48,0.51,0.4,3,0.55
+49,0.51,0.4,3,0.546
+50,0.51,0.4,3,0.5424
+51,0.51,0.4,3,0.5391
+52,0.51,0.4,3,0.5362
+53,0.51,0.4,3,0.5336
+54,0.51,0.4,3,0.5312
+55,0.51,0.4,3,0.5291
+56,0.51,0.4,3,0.5272
+57,0.51,0.4,3,0.5255
+58,0.51,0.4,3,0.5239
+59,0.51,0.4,3,0.5225
+60,0.51,0.4,3,0.5213
+61,0.51,0.4,3,0.5202
+62,0.51,0.4,3,0.5191
+63,0.51,0.4,3,0.5182
+64,0.51,0.4,3,0.5174
+65,0.51,0.4,3,0.5167
+66,0.51,0.4,3,0.516
+67,0.51,0.4,3,0.5154
+68,0.51,0.4,3,0.5149
+69,0.51,0.4,3,0.5144
+70,0.51,0.4,3,0.5139
+71,0.51,0.4,3,0.5135
+72,0.51,0.4,3,0.5132
+73,0.51,0.4,3,0.5129
+74,0.51,0.4,3,0.5126
+75,0.51,0.4,3,0.5123
+76,0.51,0.4,3,0.5121
+77,0.51,0.4,3,0.5119
+78,0.51,0.4,3,0.5117
+79,0.51,0.4,3,0.5115
+80,0.51,0.4,3,0.5114
+81,0.51,0.4,3,0.5112
+82,0.51,0.4,3,0.5111
+83,0.51,0.4,3,0.511
+84,0.51,0.4,3,0.5109
+85,0.51,0.4,3,0.5108
+86,0.51,0.4,3,0.5107
+87,0.51,0.4,3,0.5107
+88,0.51,0.4,3,0.5106
+89,0.51,0.4,3,0.5105
+90,0.51,0.4,3,0.5105
+91,0.51,0.4,3,0.5104
+92,0.51,0.4,3,0.5104
+93,0.51,0.4,3,0.5103
+94,0.51,0.4,3,0.5103
+95,0.51,0.4,3,0.5103
+96,0.51,0.4,3,0.5103
+97,0.51,0.4,3,0.5102
+98,0.51,0.4,3,0.5102
+99,0.51,0.4,3,0.5102
+100,0.51,0.4,3,0.5102
+101,0.51,0.4,3,0.5102
+102,0.51,0.4,3,0.5101
+103,0.51,0.4,3,0.5101
+104,0.51,0.4,3,0.5101
+105,0.51,0.4,3,0.5101
+106,0.51,0.4,3,0.5101
+107,0.51,0.4,3,0.5101
+108,0.51,0.4,3,0.5101
+109,0.51,0.4,3,0.5101
+110,0.51,0.4,3,0.5101
+111,0.51,0.4,3,0.5101
+112,0.51,0.4,3,0.51
+113,0.51,0.4,3,0.51
+114,0.51,0.4,3,0.51
+115,0.51,0.4,3,0.51
+116,0.51,0.4,3,0.51
+117,0.51,0.4,3,0.51
+118,0.51,0.4,3,0.51
+119,0.51,0.4,3,0.51
+120,0.51,0.4,3,0.51
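The `baseline_reward` column above is an exponential moving average of the episode rewards, matching the `0.9 * baseline + 0.1 * reward` update in `train.py`. A quick sketch that reproduces the first rows:

```python
def running_baseline(rewards: list[float], decay: float = 0.9) -> list[float]:
    # EMA baseline: baseline <- decay * baseline + (1 - decay) * reward,
    # starting from 0.0, rounded to 4 places as in the CSV.
    baseline = 0.0
    out = []
    for r in rewards:
        baseline = decay * baseline + (1 - decay) * r
        out.append(round(baseline, 4))
    return out

# Twenty warmup episodes of reward 0.01 produce 0.001, 0.0019, 0.0027, ...
```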
server/app.py CHANGED
@@ -15,6 +15,7 @@ if str(PROJECT_ROOT) not in sys.path:
     sys.path.insert(0, str(PROJECT_ROOT))
 
 from environment.env import CodeReviewEnv
+from environment.tasks import TaskDefinitions
 
 
 app = Flask(__name__)
@@ -27,7 +28,7 @@ def root() -> Any:
     return jsonify({
         "status": "ok",
         "service": "code-review-agent-env",
-        "endpoints": ["/health", "/reset", "/step", "/state"],
+        "endpoints": ["/health", "/tasks", "/reset", "/step", "/state", "/score"],
     })
 
 
@@ -72,6 +73,42 @@ def state() -> Any:
     return jsonify(current_state)
 
 
+@app.get("/tasks")
+def tasks() -> Any:
+    all_tasks = TaskDefinitions.get_all_tasks()
+    return jsonify(
+        {
+            "count": len(all_tasks),
+            "tasks": [
+                {
+                    "task_id": t["task_id"],
+                    "task_name": t["task_name"],
+                    "difficulty": t["difficulty"],
+                    "description": t["description"],
+                    "language": t["language"],
+                }
+                for t in all_tasks
+            ],
+        }
+    )
+
+
+@app.get("/score")
+def score() -> Any:
+    with _lock:
+        task_score = _env.get_task_score()
+        state = _env.state()
+
+    return jsonify(
+        {
+            "task_score": task_score,
+            "current_step": state.get("current_step", 0),
+            "is_complete": state.get("is_complete", False),
+            "task_id": (state.get("task_metadata") or {}).get("task_id"),
+        }
+    )
+
+
 def main() -> None:
     host = os.getenv("HOST", "0.0.0.0")
     port = int(os.getenv("PORT", "7860"))
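One detail in the `/score` handler is worth noting: the `(… or {})` guard keeps the endpoint from raising before `reset` has populated `task_metadata`. A standalone sketch of just that payload logic, using plain dicts instead of Flask:

```python
def score_payload(task_score: float, state: dict) -> dict:
    # Mirrors the /score response shape; the `or {}` guards the case
    # where the environment has no task_metadata yet (pre-reset).
    return {
        "task_score": task_score,
        "current_step": state.get("current_step", 0),
        "is_complete": state.get("is_complete", False),
        "task_id": (state.get("task_metadata") or {}).get("task_id"),
    }
```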
tests/test_env.py CHANGED
@@ -6,6 +6,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 from environment.env import CodeReviewEnv
 from environment.models import ReviewAction, ReviewActionType, Comment, Suggestion
+from environment.tasks import TaskDefinitions
 
 
 class TestCodeReviewEnv(unittest.TestCase):
@@ -364,6 +365,13 @@ class TestCodeReviewEnv(unittest.TestCase):
         self.assertEqual(obs["final_decision_made"], "approved")
         self.assertEqual(info["task_score"], 1.0)
 
+    def test_new_task_categories_registered(self):
+        task_ids = {t["task_id"] for t in TaskDefinitions.get_all_tasks()}
+        self.assertIn("type_safety_medium_4", task_ids)
+        self.assertIn("javascript_medium_5", task_ids)
+        self.assertIn("concurrency_hard_5", task_ids)
+        self.assertIn("dependency_injection_hard_6", task_ids)
+
 
 if __name__ == "__main__":
     unittest.main()
tests/test_server_api.py ADDED
@@ -0,0 +1,37 @@
+import unittest
+
+from server.app import app
+
+
+class TestServerAPI(unittest.TestCase):
+    def setUp(self):
+        self.client = app.test_client()
+
+    def test_root_includes_new_endpoints(self):
+        response = self.client.get("/")
+        self.assertEqual(response.status_code, 200)
+        payload = response.get_json()
+        self.assertIn("/tasks", payload["endpoints"])
+        self.assertIn("/score", payload["endpoints"])
+
+    def test_tasks_endpoint(self):
+        response = self.client.get("/tasks")
+        self.assertEqual(response.status_code, 200)
+        payload = response.get_json()
+        self.assertIn("count", payload)
+        self.assertIn("tasks", payload)
+        self.assertGreaterEqual(payload["count"], 10)
+
+    def test_score_endpoint(self):
+        # Reset first so scoring context exists.
+        self.client.get("/reset")
+        response = self.client.get("/score")
+        self.assertEqual(response.status_code, 200)
+        payload = response.get_json()
+        self.assertIn("task_score", payload)
+        self.assertIn("current_step", payload)
+        self.assertIn("task_id", payload)
+
+
+if __name__ == "__main__":
+    unittest.main()
train.py ADDED
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import csv
+import math
+import random
+from pathlib import Path
+from typing import Dict, List
+
+from train_env import TrainingEnv, default_action_catalog
+
+
+def softmax(xs: List[float]) -> List[float]:
+    m = max(xs)
+    exps = [math.exp(x - m) for x in xs]
+    s = sum(exps)
+    return [x / s for x in exps]
+
+
+def sample_index(probs: List[float]) -> int:
+    r = random.random()
+    c = 0.0
+    for i, p in enumerate(probs):
+        c += p
+        if r <= c:
+            return i
+    return len(probs) - 1
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(description="Policy-gradient training loop for the code-review environment")
+    parser.add_argument("--episodes", type=int, default=120)
+    parser.add_argument("--lr", type=float, default=0.08)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--log-dir", type=Path, default=Path("ppo_logs"))
+    parser.add_argument("--max-steps", type=int, default=5)
+    args = parser.parse_args()
+
+    random.seed(args.seed)
+    args.log_dir.mkdir(parents=True, exist_ok=True)
+
+    env = TrainingEnv(max_steps=args.max_steps, seed=args.seed)
+    catalog = default_action_catalog()
+
+    # Start with a suboptimal policy and learn toward better action plans.
+    logits: Dict[str, List[float]] = {
+        "phase_1": [-1.0, 1.0],  # prefer weak_comment initially
+        "phase_2": [-1.0, 1.0],  # prefer bad_fix initially
+        "phase_3": [-0.5, 0.5],  # slight approve bias initially
+    }
+
+    baseline_reward = 0.0
+    history = []
+    epsilon_start = 0.35
+    epsilon_end = 0.05
+    warmup_episodes = max(10, args.episodes // 6)
+
+    for episode in range(1, args.episodes + 1):
+        chosen = {}
+        action_plan = []
+
+        for phase in ["phase_1", "phase_2", "phase_3"]:
+            probs = softmax(logits[phase])
+            progress = episode / max(1, args.episodes)
+            epsilon = epsilon_start + (epsilon_end - epsilon_start) * progress
+
+            if episode <= warmup_episodes:
+                # Warmup: deliberately weak choices to create a measurable learning baseline.
+                idx = 1 if len(probs) > 1 else 0
+            elif random.random() < epsilon:
+                idx = random.randrange(len(probs))
+            else:
+                idx = sample_index(probs)
+            chosen[phase] = (idx, probs[idx])
+            action_plan.append(catalog[phase][idx])
+
+        total_reward, task_score, steps = env.run_episode(action_plan)
+
+        advantage = total_reward - baseline_reward
+        baseline_reward = 0.9 * baseline_reward + 0.1 * total_reward
+
+        for phase in ["phase_1", "phase_2", "phase_3"]:
+            idx, prob = chosen[phase]
+            grad = (1.0 - prob)
+            logits[phase][idx] += args.lr * advantage * grad
+            # Soft penalty to non-chosen actions to make learning sharper.
+            for j in range(len(logits[phase])):
+                if j != idx:
+                    logits[phase][j] -= args.lr * advantage * 0.15
+
+        history.append(
+            {
+                "episode": episode,
+                "reward": round(total_reward, 4),
+                "task_score": round(task_score, 4),
+                "steps": steps,
+                "baseline_reward": round(baseline_reward, 4),
+            }
+        )
+
+    metrics_path = args.log_dir / "train_metrics.csv"
+    with metrics_path.open("w", newline="", encoding="utf-8") as f:
+        writer = csv.DictWriter(f, fieldnames=["episode", "reward", "task_score", "steps", "baseline_reward"])
+        writer.writeheader()
+        writer.writerows(history)
+
+    # Also emit a compact summary for README use.
+    summary_path = args.log_dir / "summary.txt"
+    first = history[:10]
+    last = history[-10:]
+    first_avg = sum(x["reward"] for x in first) / max(1, len(first))
+    last_avg = sum(x["reward"] for x in last) / max(1, len(last))
+    with summary_path.open("w", encoding="utf-8") as f:
+        f.write(f"episodes={args.episodes}\n")
+        f.write(f"avg_reward_first10={first_avg:.4f}\n")
+        f.write(f"avg_reward_last10={last_avg:.4f}\n")
+        f.write(f"improvement={last_avg - first_avg:.4f}\n")
+
+    print(f"Training completed. Metrics: {metrics_path}")
+    print(f"Summary: {summary_path}")
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
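For intuition on the initial logits: `softmax([-1.0, 1.0])` concentrates most of the probability mass on the second, deliberately weak action, which the policy-gradient updates then have to unlearn. A standalone check of the same softmax used in `train.py`:

```python
import math

def softmax(xs):
    # Numerically stable softmax (subtract the max before exponentiating).
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

probs = softmax([-1.0, 1.0])
# probs sums to 1.0; the second action starts with roughly 88% of the mass.
```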
train_env.py ADDED
@@ -0,0 +1,133 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple
+
+from environment.env import CodeReviewEnv
+
+
+@dataclass
+class TemplateAction:
+    name: str
+    payload: Dict[str, Any]
+
+
+class TrainingEnv:
+    """Thin wrapper around CodeReviewEnv for policy training experiments."""
+
+    def __init__(self, task_ids: List[str] | None = None, max_steps: int = 5, seed: int = 42):
+        self.env = CodeReviewEnv()
+        self.max_steps = max_steps
+        self.seed = seed
+        self.task_ids = task_ids or ["bug_detection_easy_1"]
+        self.task_cursor = 0
+
+    def next_task(self) -> str:
+        task_id = self.task_ids[self.task_cursor % len(self.task_ids)]
+        self.task_cursor += 1
+        return task_id
+
+    def run_episode(self, action_plan: List[TemplateAction]) -> Tuple[float, float, int]:
+        task_id = self.next_task()
+        self.env.max_steps = self.max_steps
+        obs = self.env.reset(task_id=task_id, seed=self.seed)
+        done = False
+        total_reward = 0.0
+        steps = 0
+
+        for action in action_plan:
+            if done:
+                break
+            obs, reward, done, _ = self.env.step(action.payload)
+            total_reward += float(reward)
+            steps += 1
+
+        task_score = float(self.env.get_task_score())
+        return total_reward, task_score, steps
+
+
+def default_action_catalog() -> Dict[str, List[TemplateAction]]:
+    return {
+        "phase_1": [
+            TemplateAction(
+                "good_comment",
+                {
+                    "action_type": "add_comment",
+                    "comments": [
+                        {
+                            "line_number": 3,
+                            "content": "Potential division_by_zero or similar correctness issue",
+                            "is_issue": True,
+                            "severity": "high",
+                        }
+                    ],
+                    "suggestions": [],
+                },
+            ),
+            TemplateAction(
+                "weak_comment",
+                {
+                    "action_type": "add_comment",
+                    "comments": [
+                        {
+                            "line_number": 1,
+                            "content": "maybe issue",
+                            "is_issue": True,
+                            "severity": "low",
+                        }
+                    ],
+                    "suggestions": [],
+                },
+            ),
+        ],
+        "phase_2": [
+            TemplateAction(
+                "good_fix",
+                {
+                    "action_type": "suggest_fix",
+                    "comments": [],
+                    "suggestions": [
+                        {
+                            "original_line": 3,
+                            "suggested_code": "return total / len(numbers) if numbers else 0",
+                            "explanation": "guard empty input",
+                        }
+                    ],
+                },
+            ),
+            TemplateAction(
+                "bad_fix",
+                {
+                    "action_type": "suggest_fix",
+                    "comments": [],
+                    "suggestions": [
+                        {
+                            "original_line": 1,
+                            "suggested_code": "pass",
+                            "explanation": "placeholder",
+                        }
+                    ],
+                },
+            ),
+        ],
+        "phase_3": [
+            TemplateAction(
+                "request_changes",
+                {
+                    "action_type": "request_changes",
+                    "comments": [],
+                    "suggestions": [],
+                    "final_decision": "changes_requested",
+                },
+            ),
+            TemplateAction(
+                "approve",
+                {
+                    "action_type": "approve",
+                    "comments": [],
+                    "suggestions": [],
+                    "final_decision": "approved",
+                },
+            ),
+        ],
+    }