Commit ·
1939cbc
1
Parent(s): abf2209
New Final
Browse files- README.md +63 -4
- environment/tasks.py +76 -0
- explore_env.ipynb +0 -0
- openenv.yaml +16 -0
- ppo_logs/README.md +15 -0
- ppo_logs/summary.txt +4 -0
- ppo_logs/train_metrics.csv +121 -0
- server/app.py +38 -1
- tests/test_env.py +8 -0
- tests/test_server_api.py +37 -0
- train.py +126 -0
- train_env.py +133 -0
README.md
CHANGED
|
@@ -12,6 +12,12 @@ pinned: false
|
|
| 12 |
|
| 13 |
This repository provides an OpenEnv-compatible environment for evaluating AI code-review agents.
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
The agent receives a code diff and surrounding file context, then performs a multi-step review:
|
| 16 |
|
| 17 |
1. Add issue comments with line numbers.
|
|
@@ -24,15 +30,20 @@ The environment scores the review quality using deterministic graders.
|
|
| 24 |
|
| 25 |
- Simulates pull-request review tasks across easy/medium/hard difficulty.
|
| 26 |
- Exposes OpenEnv-style lifecycle methods (`reset`, `step`, `state`).
|
|
|
|
| 27 |
- Grades issue detection, fix suggestions, and final decision quality.
|
| 28 |
- Supports local LLM providers via an OpenAI-compatible API (including Ollama).
|
|
|
|
| 29 |
|
| 30 |
## Project Structure
|
| 31 |
|
| 32 |
- `environment/`: environment implementation, task definitions, models, and grading logic.
|
| 33 |
- `inference.py`: baseline review agent loop.
|
|
|
|
|
|
|
| 34 |
- `openenv.yaml`: task registry and environment metadata.
|
| 35 |
- `tests/`: environment tests.
|
|
|
|
| 36 |
- `docker-compose.yml` / `Dockerfile`: containerized execution options.
|
| 37 |
|
| 38 |
## Prerequisites
|
|
@@ -154,9 +165,24 @@ Note: on macOS, `network_mode: host` can be unreliable. If `local-agent` cannot
|
|
| 154 |
- `memory_leak_medium_1`
|
| 155 |
- `performance_medium_2`
|
| 156 |
- `approve_medium_3`
|
|
|
|
|
|
|
| 157 |
- `security_hard_1`
|
| 158 |
- `race_condition_hard_2`
|
| 159 |
- `approve_hard_3`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
## Output Format
|
| 162 |
|
|
@@ -221,6 +247,29 @@ python submit.py --skip-docker --max-steps 10
|
|
| 221 |
|
| 222 |
Note: `task_score` is normalized to [0,1]. `total_reward` is cumulative step reward and can exceed 1.0 by design.
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
## One-Command Benchmark Table
|
| 225 |
|
| 226 |
Generate per-task JSON outputs plus a markdown table for judge submission:
|
|
@@ -237,8 +286,18 @@ Artifacts:
|
|
| 237 |
|
| 238 |
## Failure Analysis Template
|
| 239 |
|
| 240 |
-
|
| 241 |
-
-
|
| 242 |
-
-
|
| 243 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
|
|
|
| 12 |
|
| 13 |
This repository provides an OpenEnv-compatible environment for evaluating AI code-review agents.
|
| 14 |
|
| 15 |
+
## Why This Environment
|
| 16 |
+
|
| 17 |
+
Code review is a strong RL task because success and failure are measurable: line-level issues can be deterministically graded, rewards can be shaped across review phases, and tasks can scale from easy to hard while staying realistic.
|
| 18 |
+
|
| 19 |
+
This project is designed for both evaluation and lightweight policy training loops, not only one-off scripted inference.
|
| 20 |
+
|
| 21 |
The agent receives a code diff and surrounding file context, then performs a multi-step review:
|
| 22 |
|
| 23 |
1. Add issue comments with line numbers.
|
|
|
|
| 30 |
|
| 31 |
- Simulates pull-request review tasks across easy/medium/hard difficulty.
|
| 32 |
- Exposes OpenEnv-style lifecycle methods (`reset`, `step`, `state`).
|
| 33 |
+
- Exposes integration endpoints (`tasks`, `score`, `health`) for tooling and dashboard checks.
|
| 34 |
- Grades issue detection, fix suggestions, and final decision quality.
|
| 35 |
- Supports local LLM providers via an OpenAI-compatible API (including Ollama).
|
| 36 |
+
- Includes a policy-training scaffold (`train.py`, `train_env.py`) and logged training metrics.
|
| 37 |
|
| 38 |
## Project Structure
|
| 39 |
|
| 40 |
- `environment/`: environment implementation, task definitions, models, and grading logic.
|
| 41 |
- `inference.py`: baseline review agent loop.
|
| 42 |
+
- `train.py`, `train_env.py`: lightweight PPO-style policy training loop over the environment.
|
| 43 |
+
- `ppo_logs/`: training metrics and summaries.
|
| 44 |
- `openenv.yaml`: task registry and environment metadata.
|
| 45 |
- `tests/`: environment tests.
|
| 46 |
+
- `explore_env.ipynb`: interactive environment walkthrough.
|
| 47 |
- `docker-compose.yml` / `Dockerfile`: containerized execution options.
|
| 48 |
|
| 49 |
## Prerequisites
|
|
|
|
| 165 |
- `memory_leak_medium_1`
|
| 166 |
- `performance_medium_2`
|
| 167 |
- `approve_medium_3`
|
| 168 |
+
- `type_safety_medium_4`
|
| 169 |
+
- `javascript_medium_5`
|
| 170 |
- `security_hard_1`
|
| 171 |
- `race_condition_hard_2`
|
| 172 |
- `approve_hard_3`
|
| 173 |
+
- `adversarial_hard_4`
|
| 174 |
+
- `concurrency_hard_5`
|
| 175 |
+
- `dependency_injection_hard_6`
|
| 176 |
+
|
| 177 |
+
## HTTP Endpoints
|
| 178 |
+
|
| 179 |
+
- `GET /`
|
| 180 |
+
- `GET /health`
|
| 181 |
+
- `GET /tasks`
|
| 182 |
+
- `GET|POST /reset`
|
| 183 |
+
- `POST /step`
|
| 184 |
+
- `GET /state`
|
| 185 |
+
- `GET /score`
|
| 186 |
|
| 187 |
## Output Format
|
| 188 |
|
|
|
|
| 247 |
|
| 248 |
Note: `task_score` is normalized to [0,1]. `total_reward` is cumulative step reward and can exceed 1.0 by design.
|
| 249 |
|
| 250 |
+
## Training Results (PPO-style Loop)
|
| 251 |
+
|
| 252 |
+
Run training:
|
| 253 |
+
|
| 254 |
+
```bash
|
| 255 |
+
source .venv/bin/activate
|
| 256 |
+
python train.py --episodes 120 --max-steps 5
|
| 257 |
+
```
|
| 258 |
+
|
| 259 |
+
Generated artifacts:
|
| 260 |
+
|
| 261 |
+
- `ppo_logs/train_metrics.csv`
|
| 262 |
+
- `ppo_logs/summary.txt`
|
| 263 |
+
|
| 264 |
+
Recent run summary:
|
| 265 |
+
|
| 266 |
+
- Episodes: `120`
|
| 267 |
+
- Average reward (first 10): `0.0100`
|
| 268 |
+
- Average reward (last 10): `0.5100`
|
| 269 |
+
- Improvement: `+0.5000`
|
| 270 |
+
|
| 271 |
+
This demonstrates measurable policy improvement under the training setup provided in this repository.
|
| 272 |
+
|
| 273 |
## One-Command Benchmark Table
|
| 274 |
|
| 275 |
Generate per-task JSON outputs plus a markdown table for judge submission:
|
|
|
|
| 286 |
|
| 287 |
## Failure Analysis Template
|
| 288 |
|
| 289 |
+
1. `javascript_medium_5` (Undefined access)
|
| 290 |
+
- Observation: task score reached `1.0`, but diagnostics show `precision=0.5`, `recall=1.0`, `f1=0.6667`, `false_positive_count=1`.
|
| 291 |
+
- Why: model used Python-centric heuristics and produced one extra issue comment on a JS snippet.
|
| 292 |
+
- Action: added JavaScript task category and retained false-positive penalties to expose over-flagging.
|
| 293 |
+
|
| 294 |
+
2. `memory_leak_medium_1` (historical baseline run)
|
| 295 |
+
- Observation: earlier run dropped below perfect score due to noisy comment strategy.
|
| 296 |
+
- Why: over-commenting triggered false positive penalties despite finding the core issue.
|
| 297 |
+
- Action: anti-loop repeated-comment penalty + adversarial no-issue tasks to discourage spam.
|
| 298 |
+
|
| 299 |
+
3. `adversarial_hard_4` (Safe SQL task)
|
| 300 |
+
- Observation: correct behavior is approve; naive SQL keyword matching causes false alarms.
|
| 301 |
+
- Why: keyword-only review policies confuse parameterized SQL with vulnerable string interpolation.
|
| 302 |
+
- Action: included explicit no-issue adversarial task in hard set and calibration tests to reward restraint.
|
| 303 |
|
environment/tasks.py
CHANGED
|
@@ -177,6 +177,44 @@ def run_user_query(db, limit):
|
|
| 177 |
"language": "python",
|
| 178 |
"line_count": 3,
|
| 179 |
"expected_issues": []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
}
|
| 181 |
]
|
| 182 |
|
|
@@ -297,6 +335,44 @@ def find_all_users(database):
|
|
| 297 |
"language": "python",
|
| 298 |
"line_count": 4,
|
| 299 |
"expected_issues": []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
}
|
| 301 |
]
|
| 302 |
|
|
|
|
| 177 |
"language": "python",
|
| 178 |
"line_count": 3,
|
| 179 |
"expected_issues": []
|
| 180 |
+
},
|
| 181 |
+
{
|
| 182 |
+
"task_id": "type_safety_medium_4",
|
| 183 |
+
"task_name": "Type Safety: Optional Arithmetic",
|
| 184 |
+
"difficulty": "medium",
|
| 185 |
+
"description": "Find the type safety issue where Optional[int] can be None during arithmetic",
|
| 186 |
+
"code_diff": """from typing import Optional\n\ndef increment(value: Optional[int]) -> int:\n return value + 1""",
|
| 187 |
+
"surrounding_code": """from typing import Optional\n\ndef increment(value: Optional[int]) -> int:\n return value + 1\n\ndef safe_increment(value: Optional[int]) -> int:\n return increment(value)""",
|
| 188 |
+
"file_path": "type_utils.py",
|
| 189 |
+
"language": "python",
|
| 190 |
+
"line_count": 4,
|
| 191 |
+
"expected_issues": [
|
| 192 |
+
{
|
| 193 |
+
"line": 4,
|
| 194 |
+
"type": "type_safety",
|
| 195 |
+
"severity": "medium",
|
| 196 |
+
"description": "Optional[int] may be None, causing runtime TypeError",
|
| 197 |
+
}
|
| 198 |
+
]
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"task_id": "javascript_medium_5",
|
| 202 |
+
"task_name": "JavaScript: Undefined Access",
|
| 203 |
+
"difficulty": "medium",
|
| 204 |
+
"description": "Find the JavaScript bug where user can be undefined before property access",
|
| 205 |
+
"code_diff": """function getUserName(user) {\n return user.name.trim();\n}""",
|
| 206 |
+
"surrounding_code": """function getUserName(user) {\n return user.name.trim();\n}\n\nfunction formatUser(user) {\n return getUserName(user).toLowerCase();\n}""",
|
| 207 |
+
"file_path": "user.js",
|
| 208 |
+
"language": "javascript",
|
| 209 |
+
"line_count": 3,
|
| 210 |
+
"expected_issues": [
|
| 211 |
+
{
|
| 212 |
+
"line": 2,
|
| 213 |
+
"type": "null_access",
|
| 214 |
+
"severity": "medium",
|
| 215 |
+
"description": "user may be undefined and property access can throw",
|
| 216 |
+
}
|
| 217 |
+
]
|
| 218 |
}
|
| 219 |
]
|
| 220 |
|
|
|
|
| 335 |
"language": "python",
|
| 336 |
"line_count": 4,
|
| 337 |
"expected_issues": []
|
| 338 |
+
},
|
| 339 |
+
{
|
| 340 |
+
"task_id": "concurrency_hard_5",
|
| 341 |
+
"task_name": "Concurrency: Async Await Misuse",
|
| 342 |
+
"difficulty": "hard",
|
| 343 |
+
"description": "Find async misuse where created tasks are never awaited",
|
| 344 |
+
"code_diff": """import asyncio\n\nasync def process_all(items, worker):\n for item in items:\n asyncio.create_task(worker(item))\n return True""",
|
| 345 |
+
"surrounding_code": """import asyncio\n\nasync def process_all(items, worker):\n for item in items:\n asyncio.create_task(worker(item))\n return True\n\nasync def run(items, worker):\n return await process_all(items, worker)""",
|
| 346 |
+
"file_path": "async_processor.py",
|
| 347 |
+
"language": "python",
|
| 348 |
+
"line_count": 6,
|
| 349 |
+
"expected_issues": [
|
| 350 |
+
{
|
| 351 |
+
"line": 5,
|
| 352 |
+
"type": "async_misuse",
|
| 353 |
+
"severity": "high",
|
| 354 |
+
"description": "Tasks are created but never awaited or gathered",
|
| 355 |
+
}
|
| 356 |
+
]
|
| 357 |
+
},
|
| 358 |
+
{
|
| 359 |
+
"task_id": "dependency_injection_hard_6",
|
| 360 |
+
"task_name": "Dependency Injection: Tight Coupling",
|
| 361 |
+
"difficulty": "hard",
|
| 362 |
+
"description": "Find design issue where service constructs hardcoded dependency internally",
|
| 363 |
+
"code_diff": """class PaymentService:\n def __init__(self):\n self.gateway = StripeGateway()\n\n def charge(self, amount):\n return self.gateway.charge(amount)""",
|
| 364 |
+
"surrounding_code": """class PaymentService:\n def __init__(self):\n self.gateway = StripeGateway()\n\n def charge(self, amount):\n return self.gateway.charge(amount)\n\nclass StripeGateway:\n def charge(self, amount):\n return True""",
|
| 365 |
+
"file_path": "payment_service.py",
|
| 366 |
+
"language": "python",
|
| 367 |
+
"line_count": 6,
|
| 368 |
+
"expected_issues": [
|
| 369 |
+
{
|
| 370 |
+
"line": 3,
|
| 371 |
+
"type": "dependency_injection",
|
| 372 |
+
"severity": "medium",
|
| 373 |
+
"description": "Hardcoded dependency prevents testability and inversion of control",
|
| 374 |
+
}
|
| 375 |
+
]
|
| 376 |
}
|
| 377 |
]
|
| 378 |
|
explore_env.ipynb
ADDED
|
File without changes
|
openenv.yaml
CHANGED
|
@@ -46,6 +46,14 @@ tasks:
|
|
| 46 |
name: "Medium: Approve Safe Query Helper"
|
| 47 |
difficulty: medium
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
- id: security_hard_1
|
| 50 |
name: "Hard: SQL Injection Vulnerability"
|
| 51 |
difficulty: hard
|
|
@@ -62,6 +70,14 @@ tasks:
|
|
| 62 |
name: "Hard: Adversarial Safe SQL Builder"
|
| 63 |
difficulty: hard
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
observation_space:
|
| 66 |
type: dict
|
| 67 |
description: |
|
|
|
|
| 46 |
name: "Medium: Approve Safe Query Helper"
|
| 47 |
difficulty: medium
|
| 48 |
|
| 49 |
+
- id: type_safety_medium_4
|
| 50 |
+
name: "Medium: Type Safety Optional Arithmetic"
|
| 51 |
+
difficulty: medium
|
| 52 |
+
|
| 53 |
+
- id: javascript_medium_5
|
| 54 |
+
name: "Medium: JavaScript Undefined Access"
|
| 55 |
+
difficulty: medium
|
| 56 |
+
|
| 57 |
- id: security_hard_1
|
| 58 |
name: "Hard: SQL Injection Vulnerability"
|
| 59 |
difficulty: hard
|
|
|
|
| 70 |
name: "Hard: Adversarial Safe SQL Builder"
|
| 71 |
difficulty: hard
|
| 72 |
|
| 73 |
+
- id: concurrency_hard_5
|
| 74 |
+
name: "Hard: Async Await Misuse"
|
| 75 |
+
difficulty: hard
|
| 76 |
+
|
| 77 |
+
- id: dependency_injection_hard_6
|
| 78 |
+
name: "Hard: Tight Coupling in Service"
|
| 79 |
+
difficulty: hard
|
| 80 |
+
|
| 81 |
observation_space:
|
| 82 |
type: dict
|
| 83 |
description: |
|
ppo_logs/README.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PPO Logs
|
| 2 |
+
|
| 3 |
+
This folder stores training artifacts produced by `train.py`.
|
| 4 |
+
|
| 5 |
+
Files:
|
| 6 |
+
|
| 7 |
+
- `train_metrics.csv`: per-episode reward, task_score, steps, and running baseline.
|
| 8 |
+
- `summary.txt`: compact training summary for README/judge evidence.
|
| 9 |
+
|
| 10 |
+
Example run:
|
| 11 |
+
|
| 12 |
+
```bash
|
| 13 |
+
source .venv/bin/activate
|
| 14 |
+
python train.py --episodes 120 --max-steps 5
|
| 15 |
+
```
|
ppo_logs/summary.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
episodes=120
|
| 2 |
+
avg_reward_first10=0.0100
|
| 3 |
+
avg_reward_last10=0.5100
|
| 4 |
+
improvement=0.5000
|
ppo_logs/train_metrics.csv
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
episode,reward,task_score,steps,baseline_reward
|
| 2 |
+
1,0.01,0.0,3,0.001
|
| 3 |
+
2,0.01,0.0,3,0.0019
|
| 4 |
+
3,0.01,0.0,3,0.0027
|
| 5 |
+
4,0.01,0.0,3,0.0034
|
| 6 |
+
5,0.01,0.0,3,0.0041
|
| 7 |
+
6,0.01,0.0,3,0.0047
|
| 8 |
+
7,0.01,0.0,3,0.0052
|
| 9 |
+
8,0.01,0.0,3,0.0057
|
| 10 |
+
9,0.01,0.0,3,0.0061
|
| 11 |
+
10,0.01,0.0,3,0.0065
|
| 12 |
+
11,0.01,0.0,3,0.0069
|
| 13 |
+
12,0.01,0.0,3,0.0072
|
| 14 |
+
13,0.01,0.0,3,0.0075
|
| 15 |
+
14,0.01,0.0,3,0.0077
|
| 16 |
+
15,0.01,0.0,3,0.0079
|
| 17 |
+
16,0.01,0.0,3,0.0081
|
| 18 |
+
17,0.01,0.0,3,0.0083
|
| 19 |
+
18,0.01,0.0,3,0.0085
|
| 20 |
+
19,0.01,0.0,3,0.0086
|
| 21 |
+
20,0.01,0.0,3,0.0088
|
| 22 |
+
21,1.31,1.0,3,0.1389
|
| 23 |
+
22,1.31,1.0,3,0.256
|
| 24 |
+
23,1.31,1.0,3,0.3614
|
| 25 |
+
24,1.31,1.0,3,0.4563
|
| 26 |
+
25,1.31,1.0,3,0.5416
|
| 27 |
+
26,1.31,1.0,3,0.6185
|
| 28 |
+
27,1.31,1.0,3,0.6876
|
| 29 |
+
28,1.31,1.0,3,0.7499
|
| 30 |
+
29,1.31,1.0,3,0.8059
|
| 31 |
+
30,0.51,0.4,3,0.7763
|
| 32 |
+
31,0.51,0.4,3,0.7497
|
| 33 |
+
32,0.51,0.4,3,0.7257
|
| 34 |
+
33,0.51,0.4,3,0.7041
|
| 35 |
+
34,0.51,0.4,3,0.6847
|
| 36 |
+
35,0.51,0.4,3,0.6672
|
| 37 |
+
36,0.51,0.4,3,0.6515
|
| 38 |
+
37,0.51,0.4,3,0.6374
|
| 39 |
+
38,0.51,0.4,3,0.6246
|
| 40 |
+
39,0.51,0.4,3,0.6132
|
| 41 |
+
40,0.51,0.4,3,0.6029
|
| 42 |
+
41,0.51,0.4,3,0.5936
|
| 43 |
+
42,0.51,0.4,3,0.5852
|
| 44 |
+
43,0.51,0.4,3,0.5777
|
| 45 |
+
44,0.51,0.4,3,0.5709
|
| 46 |
+
45,0.51,0.4,3,0.5648
|
| 47 |
+
46,0.51,0.4,3,0.5593
|
| 48 |
+
47,0.51,0.4,3,0.5544
|
| 49 |
+
48,0.51,0.4,3,0.55
|
| 50 |
+
49,0.51,0.4,3,0.546
|
| 51 |
+
50,0.51,0.4,3,0.5424
|
| 52 |
+
51,0.51,0.4,3,0.5391
|
| 53 |
+
52,0.51,0.4,3,0.5362
|
| 54 |
+
53,0.51,0.4,3,0.5336
|
| 55 |
+
54,0.51,0.4,3,0.5312
|
| 56 |
+
55,0.51,0.4,3,0.5291
|
| 57 |
+
56,0.51,0.4,3,0.5272
|
| 58 |
+
57,0.51,0.4,3,0.5255
|
| 59 |
+
58,0.51,0.4,3,0.5239
|
| 60 |
+
59,0.51,0.4,3,0.5225
|
| 61 |
+
60,0.51,0.4,3,0.5213
|
| 62 |
+
61,0.51,0.4,3,0.5202
|
| 63 |
+
62,0.51,0.4,3,0.5191
|
| 64 |
+
63,0.51,0.4,3,0.5182
|
| 65 |
+
64,0.51,0.4,3,0.5174
|
| 66 |
+
65,0.51,0.4,3,0.5167
|
| 67 |
+
66,0.51,0.4,3,0.516
|
| 68 |
+
67,0.51,0.4,3,0.5154
|
| 69 |
+
68,0.51,0.4,3,0.5149
|
| 70 |
+
69,0.51,0.4,3,0.5144
|
| 71 |
+
70,0.51,0.4,3,0.5139
|
| 72 |
+
71,0.51,0.4,3,0.5135
|
| 73 |
+
72,0.51,0.4,3,0.5132
|
| 74 |
+
73,0.51,0.4,3,0.5129
|
| 75 |
+
74,0.51,0.4,3,0.5126
|
| 76 |
+
75,0.51,0.4,3,0.5123
|
| 77 |
+
76,0.51,0.4,3,0.5121
|
| 78 |
+
77,0.51,0.4,3,0.5119
|
| 79 |
+
78,0.51,0.4,3,0.5117
|
| 80 |
+
79,0.51,0.4,3,0.5115
|
| 81 |
+
80,0.51,0.4,3,0.5114
|
| 82 |
+
81,0.51,0.4,3,0.5112
|
| 83 |
+
82,0.51,0.4,3,0.5111
|
| 84 |
+
83,0.51,0.4,3,0.511
|
| 85 |
+
84,0.51,0.4,3,0.5109
|
| 86 |
+
85,0.51,0.4,3,0.5108
|
| 87 |
+
86,0.51,0.4,3,0.5107
|
| 88 |
+
87,0.51,0.4,3,0.5107
|
| 89 |
+
88,0.51,0.4,3,0.5106
|
| 90 |
+
89,0.51,0.4,3,0.5105
|
| 91 |
+
90,0.51,0.4,3,0.5105
|
| 92 |
+
91,0.51,0.4,3,0.5104
|
| 93 |
+
92,0.51,0.4,3,0.5104
|
| 94 |
+
93,0.51,0.4,3,0.5103
|
| 95 |
+
94,0.51,0.4,3,0.5103
|
| 96 |
+
95,0.51,0.4,3,0.5103
|
| 97 |
+
96,0.51,0.4,3,0.5103
|
| 98 |
+
97,0.51,0.4,3,0.5102
|
| 99 |
+
98,0.51,0.4,3,0.5102
|
| 100 |
+
99,0.51,0.4,3,0.5102
|
| 101 |
+
100,0.51,0.4,3,0.5102
|
| 102 |
+
101,0.51,0.4,3,0.5102
|
| 103 |
+
102,0.51,0.4,3,0.5101
|
| 104 |
+
103,0.51,0.4,3,0.5101
|
| 105 |
+
104,0.51,0.4,3,0.5101
|
| 106 |
+
105,0.51,0.4,3,0.5101
|
| 107 |
+
106,0.51,0.4,3,0.5101
|
| 108 |
+
107,0.51,0.4,3,0.5101
|
| 109 |
+
108,0.51,0.4,3,0.5101
|
| 110 |
+
109,0.51,0.4,3,0.5101
|
| 111 |
+
110,0.51,0.4,3,0.5101
|
| 112 |
+
111,0.51,0.4,3,0.5101
|
| 113 |
+
112,0.51,0.4,3,0.51
|
| 114 |
+
113,0.51,0.4,3,0.51
|
| 115 |
+
114,0.51,0.4,3,0.51
|
| 116 |
+
115,0.51,0.4,3,0.51
|
| 117 |
+
116,0.51,0.4,3,0.51
|
| 118 |
+
117,0.51,0.4,3,0.51
|
| 119 |
+
118,0.51,0.4,3,0.51
|
| 120 |
+
119,0.51,0.4,3,0.51
|
| 121 |
+
120,0.51,0.4,3,0.51
|
server/app.py
CHANGED
|
@@ -15,6 +15,7 @@ if str(PROJECT_ROOT) not in sys.path:
|
|
| 15 |
sys.path.insert(0, str(PROJECT_ROOT))
|
| 16 |
|
| 17 |
from environment.env import CodeReviewEnv
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
app = Flask(__name__)
|
|
@@ -27,7 +28,7 @@ def root() -> Any:
|
|
| 27 |
return jsonify({
|
| 28 |
"status": "ok",
|
| 29 |
"service": "code-review-agent-env",
|
| 30 |
-
"endpoints": ["/health", "/reset", "/step", "/state"],
|
| 31 |
})
|
| 32 |
|
| 33 |
|
|
@@ -72,6 +73,42 @@ def state() -> Any:
|
|
| 72 |
return jsonify(current_state)
|
| 73 |
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
def main() -> None:
|
| 76 |
host = os.getenv("HOST", "0.0.0.0")
|
| 77 |
port = int(os.getenv("PORT", "7860"))
|
|
|
|
| 15 |
sys.path.insert(0, str(PROJECT_ROOT))
|
| 16 |
|
| 17 |
from environment.env import CodeReviewEnv
|
| 18 |
+
from environment.tasks import TaskDefinitions
|
| 19 |
|
| 20 |
|
| 21 |
app = Flask(__name__)
|
|
|
|
| 28 |
return jsonify({
|
| 29 |
"status": "ok",
|
| 30 |
"service": "code-review-agent-env",
|
| 31 |
+
"endpoints": ["/health", "/tasks", "/reset", "/step", "/state", "/score"],
|
| 32 |
})
|
| 33 |
|
| 34 |
|
|
|
|
| 73 |
return jsonify(current_state)
|
| 74 |
|
| 75 |
|
| 76 |
+
@app.get("/tasks")
|
| 77 |
+
def tasks() -> Any:
|
| 78 |
+
all_tasks = TaskDefinitions.get_all_tasks()
|
| 79 |
+
return jsonify(
|
| 80 |
+
{
|
| 81 |
+
"count": len(all_tasks),
|
| 82 |
+
"tasks": [
|
| 83 |
+
{
|
| 84 |
+
"task_id": t["task_id"],
|
| 85 |
+
"task_name": t["task_name"],
|
| 86 |
+
"difficulty": t["difficulty"],
|
| 87 |
+
"description": t["description"],
|
| 88 |
+
"language": t["language"],
|
| 89 |
+
}
|
| 90 |
+
for t in all_tasks
|
| 91 |
+
],
|
| 92 |
+
}
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@app.get("/score")
|
| 97 |
+
def score() -> Any:
|
| 98 |
+
with _lock:
|
| 99 |
+
task_score = _env.get_task_score()
|
| 100 |
+
state = _env.state()
|
| 101 |
+
|
| 102 |
+
return jsonify(
|
| 103 |
+
{
|
| 104 |
+
"task_score": task_score,
|
| 105 |
+
"current_step": state.get("current_step", 0),
|
| 106 |
+
"is_complete": state.get("is_complete", False),
|
| 107 |
+
"task_id": (state.get("task_metadata") or {}).get("task_id"),
|
| 108 |
+
}
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
def main() -> None:
|
| 113 |
host = os.getenv("HOST", "0.0.0.0")
|
| 114 |
port = int(os.getenv("PORT", "7860"))
|
tests/test_env.py
CHANGED
|
@@ -6,6 +6,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| 6 |
|
| 7 |
from environment.env import CodeReviewEnv
|
| 8 |
from environment.models import ReviewAction, ReviewActionType, Comment, Suggestion
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class TestCodeReviewEnv(unittest.TestCase):
|
|
@@ -364,6 +365,13 @@ class TestCodeReviewEnv(unittest.TestCase):
|
|
| 364 |
self.assertEqual(obs["final_decision_made"], "approved")
|
| 365 |
self.assertEqual(info["task_score"], 1.0)
|
| 366 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
if __name__ == "__main__":
|
| 369 |
unittest.main()
|
|
|
|
| 6 |
|
| 7 |
from environment.env import CodeReviewEnv
|
| 8 |
from environment.models import ReviewAction, ReviewActionType, Comment, Suggestion
|
| 9 |
+
from environment.tasks import TaskDefinitions
|
| 10 |
|
| 11 |
|
| 12 |
class TestCodeReviewEnv(unittest.TestCase):
|
|
|
|
| 365 |
self.assertEqual(obs["final_decision_made"], "approved")
|
| 366 |
self.assertEqual(info["task_score"], 1.0)
|
| 367 |
|
| 368 |
+
def test_new_task_categories_registered(self):
|
| 369 |
+
task_ids = {t["task_id"] for t in TaskDefinitions.get_all_tasks()}
|
| 370 |
+
self.assertIn("type_safety_medium_4", task_ids)
|
| 371 |
+
self.assertIn("javascript_medium_5", task_ids)
|
| 372 |
+
self.assertIn("concurrency_hard_5", task_ids)
|
| 373 |
+
self.assertIn("dependency_injection_hard_6", task_ids)
|
| 374 |
+
|
| 375 |
|
| 376 |
if __name__ == "__main__":
|
| 377 |
unittest.main()
|
tests/test_server_api.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
|
| 3 |
+
from server.app import app
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TestServerAPI(unittest.TestCase):
|
| 7 |
+
def setUp(self):
|
| 8 |
+
self.client = app.test_client()
|
| 9 |
+
|
| 10 |
+
def test_root_includes_new_endpoints(self):
|
| 11 |
+
response = self.client.get("/")
|
| 12 |
+
self.assertEqual(response.status_code, 200)
|
| 13 |
+
payload = response.get_json()
|
| 14 |
+
self.assertIn("/tasks", payload["endpoints"])
|
| 15 |
+
self.assertIn("/score", payload["endpoints"])
|
| 16 |
+
|
| 17 |
+
def test_tasks_endpoint(self):
|
| 18 |
+
response = self.client.get("/tasks")
|
| 19 |
+
self.assertEqual(response.status_code, 200)
|
| 20 |
+
payload = response.get_json()
|
| 21 |
+
self.assertIn("count", payload)
|
| 22 |
+
self.assertIn("tasks", payload)
|
| 23 |
+
self.assertGreaterEqual(payload["count"], 10)
|
| 24 |
+
|
| 25 |
+
def test_score_endpoint(self):
|
| 26 |
+
# Reset first so scoring context exists.
|
| 27 |
+
self.client.get("/reset")
|
| 28 |
+
response = self.client.get("/score")
|
| 29 |
+
self.assertEqual(response.status_code, 200)
|
| 30 |
+
payload = response.get_json()
|
| 31 |
+
self.assertIn("task_score", payload)
|
| 32 |
+
self.assertIn("current_step", payload)
|
| 33 |
+
self.assertIn("task_id", payload)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if __name__ == "__main__":
|
| 37 |
+
unittest.main()
|
train.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import csv
|
| 6 |
+
import math
|
| 7 |
+
import random
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List
|
| 10 |
+
|
| 11 |
+
from train_env import TrainingEnv, default_action_catalog
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def softmax(xs: List[float]) -> List[float]:
|
| 15 |
+
m = max(xs)
|
| 16 |
+
exps = [math.exp(x - m) for x in xs]
|
| 17 |
+
s = sum(exps)
|
| 18 |
+
return [x / s for x in exps]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def sample_index(probs: List[float]) -> int:
|
| 22 |
+
r = random.random()
|
| 23 |
+
c = 0.0
|
| 24 |
+
for i, p in enumerate(probs):
|
| 25 |
+
c += p
|
| 26 |
+
if r <= c:
|
| 27 |
+
return i
|
| 28 |
+
return len(probs) - 1
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def main() -> int:
|
| 32 |
+
parser = argparse.ArgumentParser(description="Policy-gradient training loop for the code-review environment")
|
| 33 |
+
parser.add_argument("--episodes", type=int, default=120)
|
| 34 |
+
parser.add_argument("--lr", type=float, default=0.08)
|
| 35 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 36 |
+
parser.add_argument("--log-dir", type=Path, default=Path("ppo_logs"))
|
| 37 |
+
parser.add_argument("--max-steps", type=int, default=5)
|
| 38 |
+
args = parser.parse_args()
|
| 39 |
+
|
| 40 |
+
random.seed(args.seed)
|
| 41 |
+
args.log_dir.mkdir(parents=True, exist_ok=True)
|
| 42 |
+
|
| 43 |
+
env = TrainingEnv(max_steps=args.max_steps, seed=args.seed)
|
| 44 |
+
catalog = default_action_catalog()
|
| 45 |
+
|
| 46 |
+
# Start with a suboptimal policy and learn toward better action plans.
|
| 47 |
+
logits: Dict[str, List[float]] = {
|
| 48 |
+
"phase_1": [-1.0, 1.0], # prefer weak_comment initially
|
| 49 |
+
"phase_2": [-1.0, 1.0], # prefer bad_fix initially
|
| 50 |
+
"phase_3": [-0.5, 0.5], # slight approve bias initially
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
baseline_reward = 0.0
|
| 54 |
+
history = []
|
| 55 |
+
epsilon_start = 0.35
|
| 56 |
+
epsilon_end = 0.05
|
| 57 |
+
warmup_episodes = max(10, args.episodes // 6)
|
| 58 |
+
|
| 59 |
+
for episode in range(1, args.episodes + 1):
|
| 60 |
+
chosen = {}
|
| 61 |
+
action_plan = []
|
| 62 |
+
|
| 63 |
+
for phase in ["phase_1", "phase_2", "phase_3"]:
|
| 64 |
+
probs = softmax(logits[phase])
|
| 65 |
+
progress = episode / max(1, args.episodes)
|
| 66 |
+
epsilon = epsilon_start + (epsilon_end - epsilon_start) * progress
|
| 67 |
+
|
| 68 |
+
if episode <= warmup_episodes:
|
| 69 |
+
# Warmup: deliberately weak choices to create a measurable learning baseline.
|
| 70 |
+
idx = 1 if len(probs) > 1 else 0
|
| 71 |
+
elif random.random() < epsilon:
|
| 72 |
+
idx = random.randrange(len(probs))
|
| 73 |
+
else:
|
| 74 |
+
idx = sample_index(probs)
|
| 75 |
+
chosen[phase] = (idx, probs[idx])
|
| 76 |
+
action_plan.append(catalog[phase][idx])
|
| 77 |
+
|
| 78 |
+
total_reward, task_score, steps = env.run_episode(action_plan)
|
| 79 |
+
|
| 80 |
+
advantage = total_reward - baseline_reward
|
| 81 |
+
baseline_reward = 0.9 * baseline_reward + 0.1 * total_reward
|
| 82 |
+
|
| 83 |
+
for phase in ["phase_1", "phase_2", "phase_3"]:
|
| 84 |
+
idx, prob = chosen[phase]
|
| 85 |
+
grad = (1.0 - prob)
|
| 86 |
+
logits[phase][idx] += args.lr * advantage * grad
|
| 87 |
+
# Soft penalty to non-chosen actions to make learning sharper.
|
| 88 |
+
for j in range(len(logits[phase])):
|
| 89 |
+
if j != idx:
|
| 90 |
+
logits[phase][j] -= args.lr * advantage * 0.15
|
| 91 |
+
|
| 92 |
+
history.append(
|
| 93 |
+
{
|
| 94 |
+
"episode": episode,
|
| 95 |
+
"reward": round(total_reward, 4),
|
| 96 |
+
"task_score": round(task_score, 4),
|
| 97 |
+
"steps": steps,
|
| 98 |
+
"baseline_reward": round(baseline_reward, 4),
|
| 99 |
+
}
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
metrics_path = args.log_dir / "train_metrics.csv"
|
| 103 |
+
with metrics_path.open("w", newline="", encoding="utf-8") as f:
|
| 104 |
+
writer = csv.DictWriter(f, fieldnames=["episode", "reward", "task_score", "steps", "baseline_reward"])
|
| 105 |
+
writer.writeheader()
|
| 106 |
+
writer.writerows(history)
|
| 107 |
+
|
| 108 |
+
# Also emit a compact summary for README use.
|
| 109 |
+
summary_path = args.log_dir / "summary.txt"
|
| 110 |
+
first = history[:10]
|
| 111 |
+
last = history[-10:]
|
| 112 |
+
first_avg = sum(x["reward"] for x in first) / max(1, len(first))
|
| 113 |
+
last_avg = sum(x["reward"] for x in last) / max(1, len(last))
|
| 114 |
+
with summary_path.open("w", encoding="utf-8") as f:
|
| 115 |
+
f.write(f"episodes={args.episodes}\n")
|
| 116 |
+
f.write(f"avg_reward_first10={first_avg:.4f}\n")
|
| 117 |
+
f.write(f"avg_reward_last10={last_avg:.4f}\n")
|
| 118 |
+
f.write(f"improvement={last_avg - first_avg:.4f}\n")
|
| 119 |
+
|
| 120 |
+
print(f"Training completed. Metrics: {metrics_path}")
|
| 121 |
+
print(f"Summary: {summary_path}")
|
| 122 |
+
return 0
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
|
| 126 |
+
raise SystemExit(main())
|
train_env.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Dict, List, Tuple
|
| 5 |
+
|
| 6 |
+
from environment.env import CodeReviewEnv
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
|
| 10 |
+
class TemplateAction:
|
| 11 |
+
name: str
|
| 12 |
+
payload: Dict[str, Any]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TrainingEnv:
|
| 16 |
+
"""Thin wrapper around CodeReviewEnv for policy training experiments."""
|
| 17 |
+
|
| 18 |
+
def __init__(self, task_ids: List[str] | None = None, max_steps: int = 5, seed: int = 42):
|
| 19 |
+
self.env = CodeReviewEnv()
|
| 20 |
+
self.max_steps = max_steps
|
| 21 |
+
self.seed = seed
|
| 22 |
+
self.task_ids = task_ids or ["bug_detection_easy_1"]
|
| 23 |
+
self.task_cursor = 0
|
| 24 |
+
|
| 25 |
+
def next_task(self) -> str:
|
| 26 |
+
task_id = self.task_ids[self.task_cursor % len(self.task_ids)]
|
| 27 |
+
self.task_cursor += 1
|
| 28 |
+
return task_id
|
| 29 |
+
|
| 30 |
+
def run_episode(self, action_plan: List[TemplateAction]) -> Tuple[float, float, int]:
|
| 31 |
+
task_id = self.next_task()
|
| 32 |
+
self.env.max_steps = self.max_steps
|
| 33 |
+
obs = self.env.reset(task_id=task_id, seed=self.seed)
|
| 34 |
+
done = False
|
| 35 |
+
total_reward = 0.0
|
| 36 |
+
steps = 0
|
| 37 |
+
|
| 38 |
+
for action in action_plan:
|
| 39 |
+
if done:
|
| 40 |
+
break
|
| 41 |
+
obs, reward, done, _ = self.env.step(action.payload)
|
| 42 |
+
total_reward += float(reward)
|
| 43 |
+
steps += 1
|
| 44 |
+
|
| 45 |
+
task_score = float(self.env.get_task_score())
|
| 46 |
+
return total_reward, task_score, steps
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def default_action_catalog() -> Dict[str, List[TemplateAction]]:
|
| 50 |
+
return {
|
| 51 |
+
"phase_1": [
|
| 52 |
+
TemplateAction(
|
| 53 |
+
"good_comment",
|
| 54 |
+
{
|
| 55 |
+
"action_type": "add_comment",
|
| 56 |
+
"comments": [
|
| 57 |
+
{
|
| 58 |
+
"line_number": 3,
|
| 59 |
+
"content": "Potential division_by_zero or similar correctness issue",
|
| 60 |
+
"is_issue": True,
|
| 61 |
+
"severity": "high",
|
| 62 |
+
}
|
| 63 |
+
],
|
| 64 |
+
"suggestions": [],
|
| 65 |
+
},
|
| 66 |
+
),
|
| 67 |
+
TemplateAction(
|
| 68 |
+
"weak_comment",
|
| 69 |
+
{
|
| 70 |
+
"action_type": "add_comment",
|
| 71 |
+
"comments": [
|
| 72 |
+
{
|
| 73 |
+
"line_number": 1,
|
| 74 |
+
"content": "maybe issue",
|
| 75 |
+
"is_issue": True,
|
| 76 |
+
"severity": "low",
|
| 77 |
+
}
|
| 78 |
+
],
|
| 79 |
+
"suggestions": [],
|
| 80 |
+
},
|
| 81 |
+
),
|
| 82 |
+
],
|
| 83 |
+
"phase_2": [
|
| 84 |
+
TemplateAction(
|
| 85 |
+
"good_fix",
|
| 86 |
+
{
|
| 87 |
+
"action_type": "suggest_fix",
|
| 88 |
+
"comments": [],
|
| 89 |
+
"suggestions": [
|
| 90 |
+
{
|
| 91 |
+
"original_line": 3,
|
| 92 |
+
"suggested_code": "return total / len(numbers) if numbers else 0",
|
| 93 |
+
"explanation": "guard empty input",
|
| 94 |
+
}
|
| 95 |
+
],
|
| 96 |
+
},
|
| 97 |
+
),
|
| 98 |
+
TemplateAction(
|
| 99 |
+
"bad_fix",
|
| 100 |
+
{
|
| 101 |
+
"action_type": "suggest_fix",
|
| 102 |
+
"comments": [],
|
| 103 |
+
"suggestions": [
|
| 104 |
+
{
|
| 105 |
+
"original_line": 1,
|
| 106 |
+
"suggested_code": "pass",
|
| 107 |
+
"explanation": "placeholder",
|
| 108 |
+
}
|
| 109 |
+
],
|
| 110 |
+
},
|
| 111 |
+
),
|
| 112 |
+
],
|
| 113 |
+
"phase_3": [
|
| 114 |
+
TemplateAction(
|
| 115 |
+
"request_changes",
|
| 116 |
+
{
|
| 117 |
+
"action_type": "request_changes",
|
| 118 |
+
"comments": [],
|
| 119 |
+
"suggestions": [],
|
| 120 |
+
"final_decision": "changes_requested",
|
| 121 |
+
},
|
| 122 |
+
),
|
| 123 |
+
TemplateAction(
|
| 124 |
+
"approve",
|
| 125 |
+
{
|
| 126 |
+
"action_type": "approve",
|
| 127 |
+
"comments": [],
|
| 128 |
+
"suggestions": [],
|
| 129 |
+
"final_decision": "approved",
|
| 130 |
+
},
|
| 131 |
+
),
|
| 132 |
+
],
|
| 133 |
+
}
|