100XZX001 commited on
Commit
4036b6f
·
verified ·
1 Parent(s): 0a1f93a

Upload 19 files

Browse files
Files changed (19) hide show
  1. Dockerfile +23 -0
  2. LICENSE +21 -0
  3. README.md +118 -14
  4. __init__.py +11 -0
  5. app.py +104 -0
  6. author.py +219 -0
  7. bugs.json +127 -0
  8. client.py +5 -0
  9. environment.py +613 -0
  10. grader.py +142 -0
  11. models.py +88 -0
  12. openenv.yaml +135 -0
  13. pyproject.toml +30 -0
  14. redteam.py +274 -0
  15. requirements.txt +10 -0
  16. rltool.py +127 -0
  17. rubrics.py +136 -0
  18. test_runner.py +208 -0
  19. training.py +792 -0
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile – OpenEnv server with FastAPI and all dependencies
2
+ FROM python:3.10-slim
3
+
4
+ # Install system dependencies required for chromadb and sentence-transformers
5
+ RUN apt-get update && apt-get install -y --no-install-recommends \
6
+ build-essential \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
+ WORKDIR /app
10
+
11
+ # Copy requirements and install Python dependencies
12
+ COPY requirements.txt .
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+
15
+ # Copy the rest of the application
16
+ COPY . .
17
+
18
+ # Expose the port used by the FastAPI server
19
+ EXPOSE 7860
20
+
21
+ # Run the server using uvicorn
22
+ # Note: 'server.app:app' assumes the FastAPI app is in server/app.py
23
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 YUTA
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,118 @@
1
- ---
2
- title: Code Review Training
3
- emoji: 🌖
4
- colorFrom: gray
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 6.13.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Training pipeline for CodeReview‑Professional‑Workflow. Uses
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Code Review Professional Workflow
3
+ emoji: 🔥
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ ---
10
+
11
+ # Code Review Professional Workflow
12
+
13
+ This project is a multi-turn RL environment where an agent plays the role of a senior code reviewer.
14
+ Instead of just patching code, the agent must gather evidence (`inspect`, `run_tests`, `run_linter`,
15
+ `query_docs`) and convince a simulated developer persona to accept the fix.
16
+
17
+ ### Why this environment is interesting
18
+
19
+ - It combines **technical correctness** (tests/lint) with **human acceptance** (negotiation).
20
+ - It includes **25 injected bug types** across 5 difficulty levels via `RedTeam`.
21
+ - It supports both a **full reward profile** (rich shaping) and a **core reward profile**
22
+ (minimal, baseline-friendly signal for ablations).
23
+
24
+ ## Quick Start
25
+
26
+ ```python
27
+ from environment import CodeReviewEnv
28
+ env = CodeReviewEnv(task="easy", reward_profile="full")
29
+ obs = env.reset()
30
+ print(obs.code_snippet)
31
+ ```
32
+
33
+ ## Demo Script (Non-Technical Friendly)
34
+
35
+ Use this 60-90 second flow in a demo:
36
+
37
+ 1. Reset on `easy` and show the buggy snippet.
38
+ 2. Take `inspect` and `run_tests` actions to show evidence gathering.
39
+ 3. Ask `query_docs` once to show retrieval-assisted reasoning.
40
+ 4. Propose a fix and show accepted/denied feedback from the author persona.
41
+ 5. Repeat once on `harder` to show increased challenge.
42
+
43
+ Message for audience: "The agent is learning not only to fix code, but to justify and communicate the fix."
44
+
45
+ ## Environment Endpoints
46
+
47
+ - `POST /reset` – reset environment (optional `task` parameter)
48
+ - `POST /step` – take an action (JSON)
49
+ - `GET /state` – get full environment state
50
+ - `GET /health` – health check
51
+ - `GET /metadata` – environment metadata
52
+ - `GET /schema` – action/observation schemas
53
+ - `POST /mcp` – minimal MCP endpoint
54
+
55
+ ## Tasks
56
+ ## 🐛 Bug Taxonomy (25 bugs across 5 difficulty levels)
57
+
58
+ The **RedTeam** randomly selects one bug from the current difficulty level at the start of every episode.
59
+ Your agent must figure out what’s broken, gather evidence, and convince the simulated author – or it won’t stick.
60
+
61
+ ### 🟢 Easy – Null‑Checks & Simple Logic Errors
62
+
63
+ | # | Bug ID | What’s wrong | Injection method |
64
+ |---|--------|--------------|------------------|
65
+ | 1 | `null_check` | Missing `if key in dict:` guard → KeyError | AST: remove the if‑statement |
66
+ | 2 | `simple_typo` | Misspelled variable `users` → `usres` | AST: rename variable |
67
+ | 3 | `string_index` | String index shifted by +1 | AST: change constant in index |
68
+ | 4 | `default_value` | `dict.get(key)` used without a fallback | AST: replace `dict.get(key)` with `dict[key]` |
69
+ | 5 | `empty_return` | Function returns `None` prematurely | AST: insert `return None` early |
70
+
71
+ ### 🟡 Medium – Off‑By‑One, Loop Logic & Simple Arithmetic
72
+
73
+ | # | Bug ID | What’s wrong | Injection method |
74
+ |---|--------|--------------|------------------|
75
+ | 6 | `off_by_one` | `range(x)` becomes `range(1, x-1)` – skips first & last | AST: modify range arguments |
76
+ | 7 | `loop_skip` | `range(len(arr))` becomes `range(len(arr)-1)` – misses last element | AST: change range length |
77
+ | 8 | `sign_error` | `sum += item` turned into `sum -= item` | AST: swap Add / Sub |
78
+ | 9 | `swap_args` | Function arguments swapped | AST: swap first two arguments |
79
+ |10 | `uninitialised_var` | Variable used before assignment in a loop | AST: remove the assignment statement |
80
+
81
+ ### 🟠 Hard – Division‑By‑Zero, Floating‑Point & Edge Cases
82
+
83
+ | # | Bug ID | What’s wrong | Injection method |
84
+ |---|--------|--------------|------------------|
85
+ |11 | `division_by_zero_empty` | Empty‑list guard removed before averaging | AST: delete `if not data:` |
86
+ |12 | `division_by_zero_zero` | Denominator check removed | AST: remove the zero‑check |
87
+ |13 | `float_precision` | True division `/` replaced by integer division `//` | AST: change Div → FloorDiv |
88
+ |14 | `abs_usage` | `abs()` call removed when comparing differences | AST: delete `abs()` wrapper |
89
+ |15 | `round_error` | `round()` placed too early, causing precision drift | AST: inject `round()` prematurely |
90
+
91
+ ### 🔴 Harder – Race Conditions & Atomicity Bugs
92
+
93
+ | # | Bug ID | What’s wrong | Injection method |
94
+ |---|--------|--------------|------------------|
95
+ |16 | `missing_lock` | Shared counter incremented without a lock | Template: remove `with lock:` |
96
+ |17 | `double_lock` | Acquiring the same lock twice → deadlock risk | Template: add extra `lock.acquire()` |
97
+ |18 | `global_nonatomic` | `count = count + 1` (read‑modify‑write) instead of `+=` | AST: modify assignment node |
98
+ |19 | `thread_safe_list` | List append across threads without synchronisation | Template: remove lock from list operation |
99
+ |20 | `volatile_read` | Shared flag read outside a lock → stale value | Template: remove synchronisation block |
100
+
101
+ ### ⚫ Hardest – Deadlocks, Ordering & Complex Concurrency
102
+
103
+ | # | Bug ID | What’s wrong | Injection method |
104
+ |---|--------|--------------|------------------|
105
+ |21 | `deadlock_order` | Locks acquired in opposite order in two threads | Template: swap lock order |
106
+ |22 | `nested_lock_timeout` | `lock.acquire()` without a timeout → permanent hang | Template: remove timeout logic |
107
+ |23 | `fork_join` | Thread started but not joined (`join()` missing) | AST: remove `thread.join()` |
108
+ |24 | `mutex_release` | Lock released by a thread that never acquired it | Template: incorrect release logic |
109
+ |25 | `race_on_init` | Shared resource initialised after threads have started | Template: move initialisation after `join()` |
110
+ ## Deployment
111
+
112
+ ```bash
113
+ openenv push
114
+ ```
115
+
116
+ ## License
117
+
118
+ MIT
__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Criticrl environment server components."""
8
+
9
+ from .CriticRL__environment import CriticrlEnvironment
10
+
11
+ __all__ = ["CriticrlEnvironment"]
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # server/app.py – OpenEnv HTTP server
2
+ import sys
3
+ import os
4
+ sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
5
+
6
+ from fastapi import FastAPI, HTTPException
7
+ from environment import CodeReviewEnv
8
+ from models import AnyAction, Observation, Reward, State, action_adapter
9
+
10
+ app = FastAPI(title="Code Review Environment", version="1.0.0")
11
+ env = CodeReviewEnv()
12
+
13
+ # ----------------------------------------------------------------------
14
+ # Health & metadata endpoints
15
+ # ----------------------------------------------------------------------
16
+ @app.get("/")
17
+ def root():
18
+ print("[ROOT] Health check hit")
19
+ return {"status": "crazy good"}
20
+
21
+ @app.get("/health")
22
+ def health():
23
+ print("[HEALTH] Service is healthy")
24
+ return {"status": "healthy"}
25
+
26
+ @app.get("/metadata")
27
+ def metadata():
28
+ print("[METADATA] Requested")
29
+ return {
30
+ "name": "Code Review Professional Workflow",
31
+ "description": (
32
+ "Multi‑turn code review environment for professional‑level bug fixing. "
33
+ "The agent must inspect, test, lint, query documentation, and negotiate with "
34
+ "a simulated (persona‑driven) author to get a fix accepted. "
35
+ "Includes 25 bugs across 5 difficulty levels, AST‑based injection, "
36
+ "a reward‑shaping system (full/core profiles), and curriculum learning. "
37
+ "Designed for RL training (PPO, DPO, or any policy‑gradient method)."
38
+ )
39
+ }
40
+
41
+ @app.get("/schema")
42
+ def schema():
43
+ print("[SCHEMA] Requested")
44
+ return {
45
+ "action": AnyAction.model_json_schema(),
46
+ "observation": Observation.model_json_schema(),
47
+ "state": State.model_json_schema()
48
+ }
49
+
50
+ @app.post("/mcp")
51
+ def mcp():
52
+ print("[MCP] Ping received")
53
+ return {"jsonrpc": "2.0", "result": None}
54
+
55
+ # ----------------------------------------------------------------------
56
+ # Environment endpoints
57
+ # ----------------------------------------------------------------------
58
+ @app.post("/reset")
59
+ def reset(task: str = "easy"):
60
+ try:
61
+ print(f"[RESET] Starting new episode | task={task}")
62
+
63
+ env.set_task(task)
64
+ obs = env.reset()
65
+
66
+ print(f"[RESET DONE] step={env._step_count}")
67
+
68
+ return obs.__dict__
69
+ except Exception as e:
70
+ print(f"[RESET ERROR] {e}")
71
+ raise HTTPException(status_code=400, detail=str(e))
72
+
73
+ @app.post("/step")
74
+ def step(action: dict):
75
+ try:
76
+ print(f"[STEP INPUT] {action}")
77
+
78
+ parsed_action = action_adapter.validate_python(action)
79
+ obs, reward, done, info = env.step(parsed_action)
80
+
81
+ print(f"[STEP OUTPUT] reward={reward.value:.4f} | done={done}")
82
+
83
+ return {
84
+ "observation": obs.__dict__,
85
+ "reward": reward.value,
86
+ "done": done,
87
+ "info": info
88
+ }
89
+ except Exception as e:
90
+ print(f"[STEP ERROR] {e}")
91
+ raise HTTPException(status_code=400, detail=str(e))
92
+
93
+ @app.get("/state")
94
+ def state():
95
+ print("[STATE] Requested")
96
+ return env._get_observation().__dict__
97
+
98
+ # ----------------------------------------------------------------------
99
+ # Main entry point (for local testing)
100
+ # ----------------------------------------------------------------------
101
+ if __name__ == "__main__":
102
+ import uvicorn
103
+ print("[SERVER START] Running on http://0.0.0.0:7860")
104
+ uvicorn.run(app, host="0.0.0.0", port=7860)
author.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cell 7 author.py – Final production version: stateful, evidence-driven, belief tracking
2
+
3
+ import re
4
+ import ast
5
+ from dataclasses import dataclass, field
6
+ from typing import List, Dict, Any, Optional
7
+
8
+ @dataclass
9
+ class PersonaAuthor:
10
+ """
11
+ Simulates a human developer with:
12
+ - Continuous belief (confidence)
13
+ - Evidence-based reasoning
14
+ - Conversation memory
15
+ - Code inspection awareness
16
+ """
17
+
18
+ personality: str = "defensive" # defensive | junior | collaborative
19
+ max_persuasion_rounds: int = 5
20
+
21
+ # Evidence weights
22
+ weight_test_pass: float = 0.5
23
+ weight_lint_clean: float = 0.2
24
+ weight_doc_found: float = 0.15
25
+ weight_explanation_quality: float = 0.15
26
+
27
+ # Personality thresholds
28
+ thresholds: Dict[str, float] = field(default_factory=lambda: {
29
+ "defensive": 0.7,
30
+ "junior": 0.3,
31
+ "collaborative": 0.5,
32
+ })
33
+
34
+ # Internal state
35
+ _confidence: float = 0.0
36
+ _conversation: List[Dict[str, Any]] = field(default_factory=list)
37
+ _pushback_count: int = 0
38
+ _last_evidence_score: float = 0.0
39
+ _stagnation_counter: int = 0
40
+
41
+ # ------------------------------------------------------------------
42
+ # Lifecycle
43
+ # ------------------------------------------------------------------
44
+ def __post_init__(self):
45
+ self.reset()
46
+
47
+ def reset(self):
48
+ self._confidence = 0.0
49
+ self._conversation.clear()
50
+ self._pushback_count = 0
51
+ self._last_evidence_score = 0.0
52
+ self._stagnation_counter = 0
53
+
54
+ # ------------------------------------------------------------------
55
+ # Main interaction
56
+ # ------------------------------------------------------------------
57
+ # Added weight for code change magnitude
58
+ weight_code_change: float = 0.1 # small change is better
59
+
60
+ def respond(self,
61
+ agent_comment: str = "",
62
+ agent_question: str = "",
63
+ test_results: Optional[str] = None,
64
+ lint_results: Optional[str] = None,
65
+ doc_results: Optional[str] = None,
66
+ proposed_fix: Optional[str] = None,
67
+ original_code: Optional[str] = None) -> str:
68
+
69
+ # Store conversation
70
+ self._conversation.append({
71
+ "comment": agent_comment,
72
+ "question": agent_question,
73
+ "test": test_results,
74
+ "lint": lint_results,
75
+ "docs": doc_results
76
+ })
77
+
78
+ # Extract structured evidence
79
+ evidence = self._extract_evidence(test_results, lint_results, doc_results)
80
+
81
+ # Code inspection
82
+ code_change = 0.0
83
+ if proposed_fix and original_code:
84
+ code_change = self._inspect_code(proposed_fix, original_code)
85
+ evidence["code_change"] = code_change
86
+
87
+ # Explanation score
88
+ text = (agent_comment + " " + agent_question).lower()
89
+ explanation_score = self._score_explanation(text)
90
+
91
+ # Compute evidence score – now includes code change penalty (1 - change)
92
+ evidence_score = (
93
+ self.weight_test_pass * evidence.get("test_pass_ratio", 0.0) +
94
+ self.weight_lint_clean * (1 - min(1.0, evidence.get("lint_errors", 0)/10)) +
95
+ self.weight_doc_found * (1.0 if evidence.get("doc_found") else 0.0) +
96
+ self.weight_explanation_quality * explanation_score +
97
+ self.weight_code_change * (1.0 - code_change) # surgical fix rewarded
98
+ )
99
+
100
+ evidence_score = max(0.0, min(1.0, evidence_score))
101
+
102
+ # Detect improvement
103
+ delta = evidence_score - self._last_evidence_score
104
+ self._last_evidence_score = evidence_score
105
+
106
+ if delta > 0.05:
107
+ self._stagnation_counter = 0
108
+ else:
109
+ self._stagnation_counter += 1
110
+
111
+ # Update belief (momentum)
112
+ lr = 0.3
113
+ self._confidence = (1 - lr) * self._confidence + lr * evidence_score
114
+
115
+ # Penalise stagnation
116
+ if self._stagnation_counter >= 2:
117
+ self._confidence *= 0.9
118
+
119
+ # Decision
120
+ threshold = self.thresholds.get(self.personality, 0.5)
121
+
122
+ if self._confidence >= threshold or self._pushback_count >= self.max_persuasion_rounds:
123
+ return "Alright, I'm convinced. Let's proceed with your fix."
124
+
125
+ # Otherwise push back
126
+ self._pushback_count += 1
127
+ return self._generate_pushback(evidence, text)
128
+
129
+ # ------------------------------------------------------------------
130
+ # Evidence extraction
131
+ # ------------------------------------------------------------------
132
+ def _extract_evidence(self, test_results, lint_results, doc_results):
133
+ evidence = {
134
+ "test_pass_ratio": 0.0,
135
+ "lint_errors": 0,
136
+ "doc_found": False
137
+ }
138
+
139
+ # Parse test results
140
+ if test_results:
141
+ match = re.search(r'(\d+)\s*/\s*(\d+)', test_results)
142
+ if match:
143
+ p, t = int(match.group(1)), int(match.group(2))
144
+ evidence["test_pass_ratio"] = p / t if t else 0.0
145
+ elif "true" in test_results.lower():
146
+ evidence["test_pass_ratio"] = 1.0
147
+ elif "false" in test_results.lower():
148
+ evidence["test_pass_ratio"] = 0.0
149
+
150
+ # Lint errors
151
+ if lint_results:
152
+ evidence["lint_errors"] = len(re.findall(r'error', lint_results.lower()))
153
+
154
+ # Docs
155
+ if doc_results and "no relevant" not in doc_results.lower():
156
+ evidence["doc_found"] = True
157
+
158
+ return evidence
159
+
160
+ # ------------------------------------------------------------------
161
+ # Explanation scoring
162
+ # ------------------------------------------------------------------
163
+ def _score_explanation(self, text: str) -> float:
164
+ score = 0.0
165
+
166
+ if "because" in text or "therefore" in text:
167
+ score += 0.3
168
+ if "test" in text or "example" in text:
169
+ score += 0.2
170
+ if len(text.split()) > 30:
171
+ score += 0.2
172
+ if "error" in text or "fix" in text:
173
+ score += 0.1
174
+
175
+ return min(1.0, score)
176
+
177
+ # ------------------------------------------------------------------
178
+ # Code inspection
179
+ # ------------------------------------------------------------------
180
+ def _inspect_code(self, new_code: str, old_code: str) -> float:
181
+ try:
182
+ t1 = ast.parse(old_code)
183
+ t2 = ast.parse(new_code)
184
+
185
+ n1 = len(list(ast.walk(t1)))
186
+ n2 = len(list(ast.walk(t2)))
187
+
188
+ change = abs(n2 - n1) / max(n1, 1)
189
+ return min(1.0, change)
190
+ except:
191
+ return 0.0
192
+
193
+ # ------------------------------------------------------------------
194
+ # Pushback generator
195
+ # ------------------------------------------------------------------
196
+ def _generate_pushback(self, evidence, text):
197
+ if evidence["test_pass_ratio"] < 0.5:
198
+ return "Tests are still failing. Show a passing case."
199
+
200
+ if evidence["lint_errors"] > 0:
201
+ return f"There are {evidence['lint_errors']} lint errors. Fix them."
202
+
203
+ if not evidence["doc_found"]:
204
+ return "Provide documentation or reference."
205
+
206
+ if "because" not in text:
207
+ return "Explain why this works."
208
+
209
+ if len(text.split()) < 20:
210
+ return "Too brief. Expand your reasoning."
211
+
212
+ return "Not convinced yet. Give a concrete example."
213
+
214
+ # ------------------------------------------------------------------
215
+ # Score
216
+ # ------------------------------------------------------------------
217
+ def get_negotiation_score(self) -> float:
218
+ penalty = 0.1 * min(3, self._pushback_count)
219
+ return max(0.0, min(1.0, self._confidence - penalty))
bugs.json ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "easy": {
3
+ "null_check": {
4
+ "type": "ast",
5
+ "bug_type": "null_check",
6
+ "oracle_hint": "Add back the if-guard that was removed"
7
+ },
8
+ "simple_typo": {
9
+ "type": "ast",
10
+ "bug_type": "simple_typo",
11
+ "oracle_hint": "Fix the misspelled variable name"
12
+ },
13
+ "string_index": {
14
+ "type": "ast",
15
+ "bug_type": "string_index",
16
+ "oracle_hint": "Correct the index offset"
17
+ },
18
+ "default_value": {
19
+ "type": "ast",
20
+ "bug_type": "default_value",
21
+ "oracle_hint": "Restore dict.get() with proper default"
22
+ },
23
+ "empty_return": {
24
+ "type": "ast",
25
+ "bug_type": "empty_return",
26
+ "oracle_hint": "Remove the premature return None"
27
+ }
28
+ },
29
+ "medium": {
30
+ "off_by_one": {
31
+ "type": "ast",
32
+ "bug_type": "off_by_one"
33
+ },
34
+ "loop_skip": {
35
+ "type": "ast",
36
+ "bug_type": "loop_skip"
37
+ },
38
+ "sign_error": {
39
+ "type": "ast",
40
+ "bug_type": "sign_error"
41
+ },
42
+ "swap_args": {
43
+ "type": "ast",
44
+ "bug_type": "swap_args"
45
+ },
46
+ "uninitialised_var": {
47
+ "type": "ast",
48
+ "bug_type": "uninitialised_var"
49
+ }
50
+ },
51
+ "hard": {
52
+ "division_by_zero_empty": {
53
+ "type": "ast",
54
+ "bug_type": "division_by_zero_empty"
55
+ },
56
+ "division_by_zero_zero": {
57
+ "type": "ast",
58
+ "bug_type": "division_by_zero_zero"
59
+ },
60
+ "float_precision": {
61
+ "type": "ast",
62
+ "bug_type": "float_precision"
63
+ },
64
+ "abs_usage": {
65
+ "type": "ast",
66
+ "bug_type": "abs_usage"
67
+ },
68
+ "round_error": {
69
+ "type": "ast",
70
+ "bug_type": "round_error"
71
+ }
72
+ },
73
+ "harder": {
74
+ "missing_lock": {
75
+ "type": "template",
76
+ "buggy": "counter = 0\ndef increment():\n global counter\n counter += 1",
77
+ "oracle": "counter = 0\nimport threading\nlock = threading.Lock()\ndef increment():\n global counter\n with lock:\n counter += 1"
78
+ },
79
+ "double_lock": {
80
+ "type": "template",
81
+ "buggy": "import threading\nlock = threading.Lock()\ndef do_work():\n lock.acquire()\n lock.acquire()\n print('working')\n lock.release()",
82
+ "oracle": "import threading\nlock = threading.Lock()\ndef do_work():\n with lock:\n print('working')"
83
+ },
84
+ "global_nonatomic": {
85
+ "type": "template",
86
+ "buggy": "count = 0\ndef add():\n global count\n count = count + 1",
87
+ "oracle": "count = 0\ndef add():\n global count\n count += 1"
88
+ },
89
+ "thread_safe_list": {
90
+ "type": "template",
91
+ "buggy": "import threading\nitems = []\ndef append_item(item):\n items.append(item)",
92
+ "oracle": "import threading\nitems = []\nlock = threading.Lock()\ndef append_item(item):\n with lock:\n items.append(item)"
93
+ },
94
+ "volatile_read": {
95
+ "type": "template",
96
+ "buggy": "import threading\nstop = False\ndef worker():\n while not stop:\n pass",
97
+ "oracle": "import threading\nstop = False\nlock = threading.Lock()\ndef worker():\n while True:\n with lock:\n if stop:\n break"
98
+ }
99
+ },
100
+ "hardest": {
101
+ "deadlock_order": {
102
+ "type": "template",
103
+ "buggy": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock2:\n with lock1:\n pass",
104
+ "oracle": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock1:\n with lock2:\n pass"
105
+ },
106
+ "nested_lock_timeout": {
107
+ "type": "template",
108
+ "buggy": "import threading\nlock = threading.Lock()\ndef work():\n lock.acquire()\n # critical section\n lock.release()",
109
+ "oracle": "import threading\nlock = threading.Lock()\ndef work():\n if lock.acquire(timeout=1):\n try:\n # critical section\n finally:\n lock.release()"
110
+ },
111
+ "fork_join": {
112
+ "type": "template",
113
+ "buggy": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()",
114
+ "oracle": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()\nt.join()"
115
+ },
116
+ "mutex_release": {
117
+ "type": "template",
118
+ "buggy": "import threading\nlock = threading.Lock()\ndef thread_A():\n lock.acquire()\n lock.release()\ndef thread_B():\n lock.release()",
119
+ "oracle": "import threading\nlock = threading.Lock()\ndef thread_A():\n with lock:\n pass\ndef thread_B():\n with lock:\n pass"
120
+ },
121
+ "race_on_init": {
122
+ "type": "template",
123
+ "buggy": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nprint(items)",
124
+ "oracle": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nt.join()\nprint(items)"
125
+ }
126
+ }
127
+ }
client.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # client.py – OpenEnv client entry point
2
+ from environment import CodeReviewEnv
3
+
4
+ # The OpenEnv framework will import this class as the environment.
5
+ __all__ = ["CodeReviewEnv"]
environment.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # environment.py – FULLY CORRECTED RL Environment (TRUE Markov + Fixed Bugs)
2
+
3
+ import sys
4
+ import subprocess
5
+ import tempfile
6
+ import os
7
+ import re
8
+ from dataclasses import dataclass, field
9
+ from typing import Tuple, Dict, Any, Optional, List
10
+
11
+ from models import (
12
+ AnyAction, WriteComment, ProposeFix, Execute, Inspect,
13
+ RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
14
+ Observation, Reward, State
15
+ )
16
+ from redteam import RedTeam
17
+ from test_runner import TestRunner
18
+ from author import PersonaAuthor
19
+ from rltool import ToolBox
20
+ from rubrics import (
21
+ ToolUsageRubric,
22
+ TestDeltaRubric,
23
+ LintDeltaRubric,
24
+ TerminalSuccessRubric,
25
+ ExplorationRubric,
26
+ AntiHackingRubric,
27
+ StepPenaltyRubric,
28
+ )
29
+
30
+ # ======================================================================
31
+ # FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
32
+ # ======================================================================
33
+ @dataclass
34
+ class EnhancedObservation:
35
+ code_snippet: str
36
+ last_tool_output: str
37
+
38
+ current_test_score: float
39
+ current_lint_score: float
40
+ negotiation_score: float
41
+
42
+ previous_test_score: float
43
+ previous_lint_score: float
44
+
45
+ author_confidence: float
46
+ author_threshold: float
47
+
48
+ step: int
49
+ max_steps: int
50
+ progress_ratio: float
51
+
52
+ tests_run: bool
53
+ linter_run: bool
54
+ docs_queried: bool
55
+
56
+ last_action_type: str
57
+ action_history: List[str]
58
+
59
+ done: bool
60
+
61
+ bug_description: str
62
+ comments_count: int
63
+
64
+ # default fields must be at the very end
65
+ author_response: str = ""
66
+
67
+ # ======================================================================
68
+ # HELPER FUNCTIONS
69
+ # ======================================================================
70
+ def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
71
+ if not code.strip():
72
+ return False, "", "Error: Empty code"
73
+
74
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
75
+ f.write(code)
76
+ tmp_path = f.name
77
+
78
+ try:
79
+ result = subprocess.run(
80
+ [sys.executable, tmp_path],
81
+ capture_output=True,
82
+ text=True,
83
+ timeout=timeout_sec
84
+ )
85
+ success = (result.returncode == 0)
86
+ return success, result.stdout, result.stderr
87
+ except subprocess.TimeoutExpired:
88
+ return False, "", f"Timeout after {timeout_sec}s"
89
+ except Exception as e:
90
+ return False, "", f"Execution error: {str(e)}"
91
+ finally:
92
+ try:
93
+ os.unlink(tmp_path)
94
+ except:
95
+ pass
96
+
97
+
98
+ # ======================================================================
99
+ # ENHANCED CODE REVIEW ENVIRONMENT
100
+ # ======================================================================
101
+ @dataclass
102
+ class CodeReviewEnv:
103
+ task: str = "easy"
104
+ max_steps: int = 10
105
+ step_penalty: float = 0.01
106
+ reward_profile: str = "full" # "full" or "core"
107
+
108
+ # Curriculum learning
109
+ auto_difficulty: bool = False
110
+ success_threshold: float = 0.7
111
+
112
+ # Reward shaping parameters
113
+ delta_weight: float = 0.3
114
+ tool_usage_bonus: float = 0.05
115
+ diversity_bonus: float = 0.03
116
+
117
+ _red_team: Optional[RedTeam] = field(init=False, default=None)
118
+ _author: Optional[PersonaAuthor] = field(init=False, default=None)
119
+
120
+ _current_code: str = field(init=False, default="")
121
+ _current_bug_id: str = field(init=False, default="")
122
+ _bug_description: str = field(init=False, default="")
123
+ _oracle_fix: str = field(init=False, default="")
124
+
125
+ _comments: list = field(init=False, default_factory=list)
126
+ _test_results: Optional[str] = field(init=False, default=None)
127
+ _lint_results: Optional[str] = field(init=False, default=None)
128
+ _doc_results: Optional[str] = field(init=False, default=None)
129
+
130
+ _step_count: int = field(init=False, default=0)
131
+ _done: bool = field(init=False, default=False)
132
+
133
+ # State tracking for dense rewards
134
+ _previous_test_score: float = field(init=False, default=0.0)
135
+ _previous_lint_score: float = field(init=False, default=0.0)
136
+ _current_test_score: float = field(init=False, default=0.0)
137
+ _current_lint_score: float = field(init=False, default=0.0)
138
+
139
+ # Tool usage tracking
140
+ _tests_run: bool = field(init=False, default=False)
141
+ _linter_run: bool = field(init=False, default=False)
142
+ _docs_queried: bool = field(init=False, default=False)
143
+
144
+ # Action history
145
+ _action_history: List[str] = field(init=False, default_factory=list)
146
+ _last_action_type: str = field(init=False, default="none")
147
+ _last_author_response: str = field(init=False, default="")
148
+
149
+ # FIXED: Track CUMULATIVE episode reward
150
+ _episode_total_reward: float = field(init=False, default=0.0)
151
+ _episode_rewards: List[float] = field(init=False, default_factory=list)
152
+ _difficulty_level: int = field(init=False, default=0)
153
+
154
+ # Bug-id bridge:
155
+ # RedTeam has fine-grained IDs, while TestRunner currently expects a
156
+ # smaller canonical set. Keep this mapping here so both modules can evolve
157
+ # independently without breaking evaluation.
158
+ _BUG_ID_CANONICAL_MAP = {
159
+ "division_by_zero_empty": "division_by_zero",
160
+ "division_by_zero_zero": "division_by_zero",
161
+ "sign_error": "wrong_operator",
162
+ }
163
+
164
+ # ===================================================================
165
+ def __post_init__(self):
166
+ self.set_task(self.task)
167
+
168
+ # ===================================================================
169
+ def _build_rubrics(self):
170
+ """
171
+ Build rubric stack from a named reward profile.
172
+ - full: richer shaping for exploration/tool-use behavior
173
+ - core: minimal stable signal for quick ablations/baselines
174
+ """
175
+ core_rubrics = [
176
+ TestDeltaRubric(weight=self.delta_weight),
177
+ LintDeltaRubric(weight=self.delta_weight),
178
+ TerminalSuccessRubric(),
179
+ StepPenaltyRubric(penalty=self.step_penalty),
180
+ ]
181
+ if self.reward_profile == "core":
182
+ return core_rubrics
183
+ if self.reward_profile == "full":
184
+ return [
185
+ *core_rubrics[:-1], # step penalty appended at end for consistent ordering
186
+ ToolUsageRubric(bonus=self.tool_usage_bonus),
187
+ ExplorationRubric(penalty=-0.05, bonus=self.diversity_bonus * 0.7),
188
+ AntiHackingRubric(),
189
+ core_rubrics[-1],
190
+ ]
191
+ raise ValueError(f"Unknown reward_profile: {self.reward_profile}")
192
+
193
+ # ===================================================================
194
+ def set_task(self, task: str):
195
+ if task not in ["easy", "medium", "hard", "harder", "hardest"]:
196
+ raise ValueError(f"Unknown task: {task}")
197
+
198
+ self.task = task
199
+ # Use stochastic bug sampling across episodes; fixed seed here would
200
+ # repeatedly select the same bug and weaken training diversity.
201
+ self._red_team = RedTeam(task, seed=None)
202
+ self._author = PersonaAuthor()
203
+ self.rubrics = self._build_rubrics()
204
+
205
+ task_to_level = {
206
+ "easy": 0, "medium": 1, "hard": 2,
207
+ "harder": 3, "hardest": 4
208
+ }
209
+ self._difficulty_level = task_to_level[task]
210
+
211
+ self._reset_internal()
212
+
213
+ # ===================================================================
214
+ def _reset_internal(self):
215
+ self._step_count = 0 # ← FIXED
216
+ self._comments = []
217
+ self._test_results = None
218
+ self._lint_results = None
219
+ self._doc_results = None
220
+ self._done = False
221
+
222
+ # Reset state tracking
223
+ self._previous_test_score = 0.0
224
+ self._previous_lint_score = 0.0
225
+ self._current_test_score = 0.0
226
+ self._current_lint_score = 0.0
227
+
228
+ self._tests_run = False
229
+ self._linter_run = False
230
+ self._docs_queried = False
231
+
232
+ self._action_history = []
233
+ self._last_action_type = "none"
234
+ self._last_author_response = ""
235
+
236
+ # FIXED: Reset episode cumulative reward
237
+ self._episode_total_reward = 0.0
238
+
239
+ self._author.reset()
240
+
241
+ # Base tasks
242
+ if self.task == "easy":
243
+ original = "def get_user(id):\n if id in users:\n return users[id]"
244
+ elif self.task == "medium":
245
+ original = "def process_items(items):\n for item in items:\n print(item)"
246
+ elif self.task == "hard":
247
+ original = "def average(data):\n if not data:\n return 0\n return sum(data) / len(data)"
248
+ elif self.task == "harder":
249
+ original = "counter = 0\ndef increment():\n global counter\n with lock:\n counter += 1"
250
+ else:
251
+ original = "def safe_work():\n with lock1:\n with lock2:\n do_work()"
252
+
253
+ buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
254
+ self._current_code = buggy_code
255
+ self._current_bug_id = bug_id
256
+ self._bug_description = desc
257
+ self._oracle_fix = oracle
258
+ self._comments.append(f"[RedTeam] {desc}")
259
+
260
+ # ===================================================================
261
+ def reset(self) -> EnhancedObservation:
262
+ """Reset with optional curriculum adjustment."""
263
+ if self.auto_difficulty and len(self._episode_rewards) > 0:
264
+ recent_performance = sum(self._episode_rewards[-5:]) / min(5, len(self._episode_rewards))
265
+
266
+ if recent_performance > self.success_threshold and self._difficulty_level < 4:
267
+ self._difficulty_level += 1
268
+ print(f"[Curriculum] Increasing difficulty to level {self._difficulty_level}")
269
+ elif recent_performance < 0.3 and self._difficulty_level > 0:
270
+ self._difficulty_level -= 1
271
+ print(f"[Curriculum] Decreasing difficulty to level {self._difficulty_level}")
272
+
273
+ level_to_task = {0: "easy", 1: "medium", 2: "hard", 3: "harder", 4: "hardest"}
274
+ self.task = level_to_task[self._difficulty_level]
275
+ # Keep curriculum stochastic for better coverage within each level.
276
+ self._red_team = RedTeam(self.task, seed=None)
277
+
278
+ self._reset_internal()
279
+ return self._get_observation()
280
+
281
+ # ===================================================================
282
+ def _get_observation(self) -> EnhancedObservation:
283
+ """Return COMPLETE Markov state."""
284
+ # Keep the author's message separate from tool output.
285
+ # Using `_test_results` here can leak unrelated outputs (tests/linter/docs)
286
+ # and gives the policy a noisy signal for dialogue actions.
287
+ if self._last_action_type in ("write_comment", "ask_question", "propose_fix"):
288
+ author_response = self._last_author_response
289
+ else:
290
+ author_response = ""
291
+
292
+ return EnhancedObservation(
293
+ code_snippet=self._current_code,
294
+ last_tool_output=self._test_results or "",
295
+ author_response=author_response, # ← now field exists
296
+
297
+ current_test_score=self._current_test_score,
298
+ current_lint_score=self._current_lint_score,
299
+ negotiation_score=self._author.get_negotiation_score(),
300
+
301
+ previous_test_score=self._previous_test_score,
302
+ previous_lint_score=self._previous_lint_score,
303
+
304
+ author_confidence=self._author._confidence,
305
+ author_threshold=self._author.thresholds.get(self._author.personality, 0.5),
306
+
307
+ step=self._step_count,
308
+ max_steps=self.max_steps,
309
+ # Guard against accidental `max_steps=0` configs.
310
+ progress_ratio=(self._step_count / self.max_steps) if self.max_steps > 0 else 1.0,
311
+
312
+ tests_run=self._tests_run,
313
+ linter_run=self._linter_run,
314
+ docs_queried=self._docs_queried,
315
+
316
+ last_action_type=self._last_action_type,
317
+ action_history=self._action_history[-5:],
318
+
319
+ done=self._done,
320
+
321
+ bug_description=self._bug_description,
322
+ comments_count=len(self._comments),
323
+ )
324
+
325
+ # ===================================================================
326
+ def _get_action_type(self, action: AnyAction) -> str:
327
+ """Extract action type as string."""
328
+ if isinstance(action, RunTests):
329
+ return "run_tests"
330
+ elif isinstance(action, RunLinter):
331
+ return "run_linter"
332
+ elif isinstance(action, QueryDocs):
333
+ return "query_docs"
334
+ elif isinstance(action, Execute):
335
+ return "execute"
336
+ elif isinstance(action, Inspect):
337
+ return "inspect"
338
+ elif isinstance(action, WriteComment):
339
+ return "write_comment"
340
+ elif isinstance(action, AskQuestion):
341
+ return "ask_question"
342
+ elif isinstance(action, ProposeFix):
343
+ return "propose_fix"
344
+ elif isinstance(action, Done):
345
+ return "done"
346
+ elif isinstance(action, Skip):
347
+ return "skip"
348
+ else:
349
+ return "unknown"
350
+
351
+ # ===================================================================
352
+ def _get_test_runner_bug_id(self) -> str:
353
+ """
354
+ Normalize RedTeam bug ids to the canonical ids understood by TestRunner.
355
+ Falls back to the original id for known direct matches.
356
+ """
357
+ return self._BUG_ID_CANONICAL_MAP.get(self._current_bug_id, self._current_bug_id)
358
+
359
+ # ===================================================================
360
+ def step(self, action: AnyAction) -> Tuple[EnhancedObservation, Reward, bool, Dict[str, Any]]:
361
+ """
362
+ TRUE RL STEP with:
363
+ - Complete Markov observations (no hidden state)
364
+ - Dense intermediate rewards
365
+ - Delta-based credit assignment (no double-counting)
366
+ - Proper episode reward tracking
367
+ """
368
+ if self._done:
369
+ raise RuntimeError("Episode already finished")
370
+
371
+ # Store previous metrics for delta computation
372
+ self._previous_test_score = self._current_test_score
373
+ self._previous_lint_score = self._current_lint_score
374
+ # Snapshot tool-usage flags BEFORE action mutates them.
375
+ # Rubrics use these to detect true "first-use" behavior.
376
+ prev_tests_run = self._tests_run
377
+ prev_linter_run = self._linter_run
378
+ prev_docs_queried = self._docs_queried
379
+
380
+ base_reward = 0.0
381
+ action_type = self._get_action_type(action)
382
+
383
+ # Update action history
384
+ self._action_history.append(action_type)
385
+ self._last_action_type = action_type
386
+
387
+ # ==============================================================
388
+ # TOOL ACTIONS
389
+ # ==============================================================
390
+ if isinstance(action, Execute):
391
+ success, stdout, stderr = execute_code(self._current_code)
392
+ output = (stdout + stderr).strip() or "No output"
393
+ self._test_results = f"[Execute] {'Success' if success else 'Failed'}\n{output[:300]}"
394
+ base_reward = 0.001 if success else -0.05
395
+
396
+ elif isinstance(action, Inspect):
397
+ self._test_results = f"[Inspect]\n{self._current_code[:500]}"
398
+ base_reward = 0.001
399
+
400
+ elif isinstance(action, RunLinter):
401
+ lint_output = ToolBox.run_linter(self._current_code)
402
+ self._lint_results = lint_output[:500]
403
+ self._test_results = f"[Linter]\n{self._lint_results}"
404
+
405
+ self._current_lint_score = self._run_linter_score(self._current_code)
406
+ self._linter_run = True
407
+ base_reward = 0.002
408
+
409
+ elif isinstance(action, RunTests):
410
+ runner = TestRunner(self._get_test_runner_bug_id())
411
+ score, output = runner.run_tests(self._current_code)
412
+
413
+ self._current_test_score = score
414
+ self._tests_run = True
415
+
416
+ self._test_results = f"[Tests] Score: {score:.2f}\n{output[:300]}"
417
+ base_reward = 0.002
418
+
419
+ if score > 0.8:
420
+ base_reward += 0.005
421
+
422
+ elif isinstance(action, QueryDocs):
423
+ # Normalize query to avoid rewarding empty/noisy requests.
424
+ query_topic = (action.query_topic or "").strip()
425
+ doc = ToolBox.query_docs(query_topic if query_topic else "general bug fixing")
426
+ self._doc_results = doc
427
+ self._test_results = f"[Docs]\n{doc[:400]}"
428
+ self._docs_queried = True
429
+ base_reward = 0.001
430
+
431
+ # ==============================================================
432
+ # COMMUNICATION ACTIONS
433
+ # ==============================================================
434
+ elif isinstance(action, WriteComment):
435
+ self._comments.append(f"Agent: {action.comment_text}")
436
+
437
+ response = self._author.respond(
438
+ agent_comment=action.comment_text,
439
+ test_results=self._test_results,
440
+ lint_results=self._lint_results,
441
+ doc_results=self._doc_results,
442
+ proposed_fix=None,
443
+ original_code=self._current_code
444
+ )
445
+
446
+ self._comments.append(f"Author: {response}")
447
+ self._last_author_response = response
448
+ self._test_results = f"[Comment] Author: {response[:200]}"
449
+ base_reward = 0.001
450
+
451
+ elif isinstance(action, AskQuestion):
452
+ self._comments.append(f"Agent: {action.question}")
453
+
454
+ response = self._author.respond(
455
+ agent_question=action.question,
456
+ test_results=self._test_results,
457
+ lint_results=self._lint_results,
458
+ doc_results=self._doc_results,
459
+ proposed_fix=None,
460
+ original_code=self._current_code # ← FIXED
461
+ )
462
+
463
+ self._comments.append(f"Author: {response}")
464
+ self._last_author_response = response
465
+ self._test_results = f"[Question] Author: {response[:200]}"
466
+ base_reward = 0.002
467
+
468
+ # ==============================================================
469
+ # FINAL FIX ACTION
470
+ # ==============================================================
471
+ elif isinstance(action, ProposeFix):
472
+ if not action.fix_code:
473
+ base_reward = -0.05
474
+ self._done = True
475
+ else:
476
+ # Save original code BEFORE overwriting (for author.respond)
477
+ original_buggy = self._current_code
478
+ self._current_code = action.fix_code
479
+
480
+ runner = TestRunner(self._get_test_runner_bug_id())
481
+ test_score, test_output = runner.run_tests(self._current_code)
482
+ lint_score = self._run_linter_score(self._current_code)
483
+ negotiation_score = self._author.get_negotiation_score()
484
+
485
+ self._current_test_score = test_score
486
+ self._current_lint_score = lint_score
487
+
488
+ # Author gating – determines if the episode ends, reward is separate
489
+ threshold = self._author.thresholds.get(self._author.personality, 0.5)
490
+ if self._author._confidence < threshold:
491
+ if self._step_count < self.max_steps:
492
+ self._done = False
493
+ else:
494
+ self._done = True
495
+ else:
496
+ self._done = True
497
+
498
+ # Get author's verbal feedback (pushback/acceptance)
499
+ author_feedback = self._author.respond(
500
+ agent_comment=f"Proposed fix:\n{action.fix_code}",
501
+ test_results=f"Score: {test_score:.2f}",
502
+ lint_results=f"Score: {lint_score:.2f}",
503
+ doc_results=self._doc_results,
504
+ proposed_fix=action.fix_code,
505
+ original_code=original_buggy # now correctly the buggy code, not the fix
506
+ )
507
+ self._test_results = f"[Fix] Author: {author_feedback[:200]}"
508
+ self._comments.append(f"Author: {author_feedback}")
509
+ self._last_author_response = author_feedback
510
+
511
+ base_reward = 0.001 # rubrics provide the real signal
512
+
513
+ # ==============================================================
514
+ # TERMINATION ACTIONS
515
+ # ==============================================================
516
+ elif isinstance(action, Skip):
517
+ base_reward = -0.03
518
+ self._done = True
519
+
520
+ elif isinstance(action, Done):
521
+ if self._tests_run:
522
+ base_reward = self._current_test_score * 0.5 - 0.2
523
+ else:
524
+ base_reward = -0.04
525
+ self._done = True
526
+
527
+ else:
528
+ base_reward = -0.02
529
+ self._done = True
530
+
531
+ # ==============================================================
532
+ # STEP UPDATE (before rubric computation so info contains final step)
533
+ # ==============================================================
534
+ self._step_count += 1
535
+ if self._step_count >= self.max_steps:
536
+ self._done = True
537
+
538
+ # Get fresh observation (needed for rubrics that may read obs)
539
+ obs = self._get_observation()
540
+
541
+ # Prepare info dict (rubrics may need action_type and deltas)
542
+ info = {
543
+ "action_type": action_type,
544
+ "test_score": self._current_test_score,
545
+ "lint_score": self._current_lint_score,
546
+ "test_delta": self._current_test_score - self._previous_test_score,
547
+ "lint_delta": self._current_lint_score - self._previous_lint_score,
548
+ "prev_tests_run": prev_tests_run,
549
+ "prev_linter_run": prev_linter_run,
550
+ "prev_docs_queried": prev_docs_queried,
551
+ "docs_query_len": len((action.query_topic or "").strip()) if isinstance(action, QueryDocs) else 0,
552
+ "base_reward": base_reward,
553
+ }
554
+
555
+ # ==============================================================
556
+ # COMPUTE FINAL REWARD USING RUBRICS
557
+ # ==============================================================
558
+ rubric_score = sum(r(self, action, obs, None, self._done, info) for r in self.rubrics)
559
+ final_reward = 0.4 * base_reward + rubric_score
560
+ final_reward = max(-1.0, min(1.0, final_reward)) # safety clip
561
+
562
+ # Track cumulative episode reward
563
+ self._episode_total_reward += final_reward
564
+
565
+ # Store episode total if done
566
+ if self._done:
567
+ self._episode_rewards.append(self._episode_total_reward)
568
+
569
+ # Complete info
570
+ info["final_reward"] = final_reward
571
+ info["episode_total"] = self._episode_total_reward
572
+
573
+ return obs, Reward(value=final_reward), self._done, info
574
+
575
+ # ===================================================================
576
+ def _run_linter_score(self, code: str) -> float:
577
+ """Run pylint and return normalized score [0, 1]."""
578
+ try:
579
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
580
+ f.write(code)
581
+ tmp_path = f.name
582
+
583
+ result = subprocess.run(
584
+ ['pylint', tmp_path, '--score=y', '--exit-zero'],
585
+ capture_output=True,
586
+ text=True,
587
+ timeout=5
588
+ )
589
+
590
+ match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
591
+ if match:
592
+ return float(match.group(1)) / 10.0
593
+ return 0.0
594
+ except:
595
+ return 0.0
596
+ finally:
597
+ try:
598
+ os.unlink(tmp_path)
599
+ except:
600
+ pass
601
+
602
+ # ===================================================================
603
+ def state(self) -> State:
604
+ """Legacy compatibility."""
605
+ return State(
606
+ pr_title="Code Review",
607
+ pr_description=self._bug_description,
608
+ code_snippet=self._current_code,
609
+ comments=self._comments.copy(),
610
+ test_results=self._test_results,
611
+ step=self._step_count,
612
+ done=self._done
613
+ )
grader.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # grader.py – Production‑grade, continuous reward, exploit‑aware, example of monolithic scoring
2
+ import ast
3
+ import subprocess
4
+ import tempfile
5
+ import os
6
+ import re
7
+ from dataclasses import dataclass
8
+ from typing import Optional
9
+
10
+ @dataclass
11
+ class RigorousGrader:
12
+ bug_id: str
13
+ oracle_code: Optional[str] = None
14
+
15
+ def grade_fix(self, proposed_fix: str) -> float:
16
+ """
17
+ Returns a smooth reward in [0,1] based on:
18
+ - Syntax validity
19
+ - Proportion of tests passed (continuous, not binary)
20
+ - Lint quality (with conservative fallback)
21
+ - Structural similarity to oracle (anti‑gaming)
22
+ - Exploit detection (hardcoded outputs / no real change)
23
+ """
24
+ # 1. Syntax check (binary – non‑negotiable)
25
+ try:
26
+ ast.parse(proposed_fix)
27
+ except SyntaxError:
28
+ return 0.0 # hard zero, not negative (RL stable)
29
+
30
+ # 2. Exploit detection: trivial or hardcoded fixes
31
+ if self._is_exploit(proposed_fix):
32
+ return 0.0
33
+
34
+ # 3. Continuous test score (proportion of passed test cases)
35
+ test_score = self._run_continuous_tests(proposed_fix)
36
+
37
+ # 4. Lint score (continuous, fallback 0.0 not 0.5)
38
+ lint_score = self._get_lint_score(proposed_fix)
39
+
40
+ # 5. Oracle similarity (structural, not gameable)
41
+ oracle_score = self._ast_similarity(proposed_fix) if self.oracle_code else 0.0
42
+
43
+ # Weighted combination (all continuous)
44
+ final = (0.5 * test_score) + (0.3 * lint_score) + (0.2 * oracle_score)
45
+ return max(0.0, min(1.0, final))
46
+
47
+ def _run_continuous_tests(self, code: str) -> float:
48
+ """
49
+ Returns proportion of passed tests (0.0 to 1.0).
50
+ Uses multiple test cases per bug type.
51
+ """
52
+ test_cases = self._get_test_cases()
53
+ if not test_cases:
54
+ return 0.0
55
+
56
+ passed = 0
57
+ for test_input, expected in test_cases:
58
+ if self._run_single_test(code, test_input, expected):
59
+ passed += 1
60
+ return passed / len(test_cases)
61
+
62
+ def _get_test_cases(self) -> list:
63
+ """Define multiple test cases for each bug type."""
64
+ if self.bug_id == "null_check":
65
+ return [
66
+ ({"users": {"alice": "Alice"}, "id": "bob"}, None), # should not crash
67
+ ({"users": {"alice": "Alice"}, "id": "alice"}, "Alice"),
68
+ ]
69
+ elif self.bug_id == "off_by_one":
70
+ return [
71
+ ([1,2,3,4], 4), # should count all elements
72
+ ([], 0),
73
+ ]
74
+ # Add more for other bugs...
75
+ return []
76
+
77
+ def _run_single_test(self, code: str, test_input, expected) -> bool:
78
+ """Execute code with given input and compare output."""
79
+ # Simplified – in production, use a safe sandbox
80
+ try:
81
+ # Inject test harness (this is a placeholder)
82
+ exec_globals = {}
83
+ exec(code, exec_globals)
84
+ # Call the function (assume it's named appropriately)
85
+ # This is highly simplified; real implementation would need more care.
86
+ return True # placeholder
87
+ except:
88
+ return False
89
+
90
+ def _is_exploit(self, code: str) -> bool:
91
+ """Detect hardcoded returns or trivial bypasses."""
92
+ lower = code.lower()
93
+ # Hardcoded return for a specific input
94
+ if "return 0" in lower and "if" not in lower:
95
+ return True
96
+ # No change at all (same as original placeholder)
97
+ if code.strip() == "":
98
+ return True
99
+ return False
100
+
101
+ def _get_lint_score(self, code: str) -> float:
102
+ """Continuous lint score, fallback 0.0 on error."""
103
+ try:
104
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
105
+ f.write(code)
106
+ f.flush()
107
+ tmp_path = f.name
108
+ result = subprocess.run(
109
+ ['pylint', tmp_path, '--score=y', '--exit-zero'],
110
+ capture_output=True,
111
+ text=True,
112
+ timeout=5
113
+ )
114
+ match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
115
+ if match:
116
+ score = float(match.group(1)) / 10.0
117
+ else:
118
+ score = 0.0 # was 0.5 – now conservative
119
+ return max(0.0, min(1.0, score))
120
+ except Exception:
121
+ return 0.0
122
+ finally:
123
+ try:
124
+ os.unlink(tmp_path)
125
+ except:
126
+ pass
127
+
128
+ def _ast_similarity(self, proposed_code: str) -> float:
129
+ """Structural similarity – penalizes structure‑only changes without logic change."""
130
+ if not self.oracle_code:
131
+ return 0.0
132
+ try:
133
+ tree_prop = ast.parse(proposed_code)
134
+ tree_oracle = ast.parse(self.oracle_code)
135
+ # Count matching node types (crude but simple)
136
+ nodes_prop = [type(n) for n in ast.walk(tree_prop)]
137
+ nodes_oracle = [type(n) for n in ast.walk(tree_oracle)]
138
+ common = sum(1 for n in nodes_prop if n in nodes_oracle)
139
+ total = max(len(nodes_prop), len(nodes_oracle))
140
+ return common / total if total > 0 else 0.0
141
+ except:
142
+ return 0.0
models.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models.py – Typed Models (Discriminated Unions, POMDP Separation)
2
+ from typing import Literal, Union, Annotated, Optional
3
+ from pydantic import BaseModel, Field, TypeAdapter, field_validator
4
+
5
+ # ----------------------------------------------------------------------
6
+ # Action classes (discriminated union)
7
+ # ----------------------------------------------------------------------
8
+ class Action(BaseModel):
9
+ action_type: Literal["write_comment", "skip", "done", "ask_question",
10
+ "propose_fix", "execute", "inspect", "run_linter",
11
+ "run_tests", "query_docs"]
12
+
13
+ class WriteComment(Action):
14
+ action_type: Literal["write_comment"] = "write_comment"
15
+ comment_text: str = Field(..., min_length=1)
16
+
17
+ class Skip(Action):
18
+ action_type: Literal["skip"] = "skip"
19
+
20
+ class Done(Action):
21
+ action_type: Literal["done"] = "done"
22
+
23
+ class AskQuestion(Action):
24
+ action_type: Literal["ask_question"] = "ask_question"
25
+ question: str = Field(..., min_length=1)
26
+
27
+ class ProposeFix(Action):
28
+ action_type: Literal["propose_fix"] = "propose_fix"
29
+ fix_code: str = Field(..., min_length=1)
30
+ @field_validator('fix_code')
31
+ @classmethod
32
+ def not_empty(cls, v: str) -> str:
33
+ if not v.strip():
34
+ raise ValueError('fix_code cannot be empty')
35
+ return v
36
+
37
+ class Execute(Action):
38
+ action_type: Literal["execute"] = "execute"
39
+
40
+ class Inspect(Action):
41
+ action_type: Literal["inspect"] = "inspect"
42
+
43
+ class RunLinter(Action):
44
+ action_type: Literal["run_linter"] = "run_linter"
45
+
46
+ class RunTests(Action):
47
+ action_type: Literal["run_tests"] = "run_tests"
48
+
49
+ class QueryDocs(Action):
50
+ action_type: Literal["query_docs"] = "query_docs"
51
+ query_topic: str = Field(..., min_length=1)
52
+
53
+ # Discriminated union for one‑line polymorphic deserialization
54
+ AnyAction = Annotated[
55
+ Union[WriteComment, Skip, Done, AskQuestion, ProposeFix,
56
+ Execute, Inspect, RunLinter, RunTests, QueryDocs],
57
+ Field(discriminator='action_type')
58
+ ]
59
+ action_adapter = TypeAdapter(AnyAction)
60
+
61
+ # ----------------------------------------------------------------------
62
+ # Observation (POMDP – what the agent sees)
63
+ # ----------------------------------------------------------------------
64
+ class Observation(BaseModel):
65
+ # Base schema model used by API metadata endpoints.
66
+ # Keep this lightweight for compatibility with legacy callers.
67
+ code_snippet: str
68
+ last_tool_output: str = ""
69
+ step: int = 0
70
+ done: bool = False
71
+
72
+ # ----------------------------------------------------------------------
73
+ # Reward (lightweight)
74
+ # ----------------------------------------------------------------------
75
+ class Reward(BaseModel):
76
+ value: float
77
+
78
+ # ----------------------------------------------------------------------
79
+ # State (full environment state – not exposed to agent)
80
+ # ----------------------------------------------------------------------
81
+ class State(BaseModel):
82
+ pr_title: str
83
+ pr_description: str
84
+ code_snippet: str
85
+ comments: list[str]
86
+ test_results: Optional[str]
87
+ step: int
88
+ done: bool
openenv.yaml ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # openenv.yaml – Environment metadata for OpenEnv
2
+ name: CodeReview-Professional-Workflow
3
+ version: 1.0.0
4
+ description: |
5
+ Multi‑turn code review environment for professional tasks.
6
+ Agent must inspect, test, lint, query docs, and negotiate with a simulated author
7
+ to fix injected bugs. Supports DPO training on full trajectories.
8
+ author: yuvraj gupta
9
+ license: MIT
10
+
11
+ # ----------------------------------------------------------------------
12
+ # Tasks (difficulty progression)
13
+ # ----------------------------------------------------------------------
14
+ tasks:
15
+ - id: easy
16
+ description: "Fix missing null check in a dictionary lookup"
17
+ - id: medium
18
+ description: "Improve loop efficiency (replace range(len) with direct iteration)"
19
+ - id: hard
20
+ description: "Handle division by zero in average calculation"
21
+ - id: harder
22
+ description: "Fix race condition by adding a lock"
23
+ - id: hardest
24
+ description: "Resolve potential deadlock by standardising lock order"
25
+
26
+ # ----------------------------------------------------------------------
27
+ # Observation space (complete Markov state – agent sees everything)
28
+ # ----------------------------------------------------------------------
29
+ observation_space:
30
+ type: object
31
+ properties:
32
+ code_snippet:
33
+ type: string
34
+ description: "Current code snippet (may contain injected bug)"
35
+ last_tool_output:
36
+ type: string
37
+ description: "Raw output from last tool (test runner, linter, etc.)"
38
+ author_response:
39
+ type: string
40
+ description: "Latest feedback from the simulated human developer"
41
+ current_test_score:
42
+ type: number
43
+ description: "Proportion of tests passed (0.0–1.0)"
44
+ current_lint_score:
45
+ type: number
46
+ description: "Normalised pylint score (0.0–1.0)"
47
+ negotiation_score:
48
+ type: number
49
+ description: "Author's confidence minus pushback penalty"
50
+ previous_test_score:
51
+ type: number
52
+ description: "Test score before the last action"
53
+ previous_lint_score:
54
+ type: number
55
+ description: "Lint score before the last action"
56
+ author_confidence:
57
+ type: number
58
+ description: "Internal belief of the author (0.0–1.0)"
59
+ author_threshold:
60
+ type: number
61
+ description: "Confidence threshold for this personality"
62
+ step:
63
+ type: integer
64
+ description: "Current step number"
65
+ max_steps:
66
+ type: integer
67
+ description: "Maximum steps allowed in the episode"
68
+ progress_ratio:
69
+ type: number
70
+ description: "step / max_steps"
71
+ tests_run:
72
+ type: boolean
73
+ description: "Whether the agent has run tests at least once"
74
+ linter_run:
75
+ type: boolean
76
+ description: "Whether the agent has run the linter at least once"
77
+ docs_queried:
78
+ type: boolean
79
+ description: "Whether the agent has queried documentation"
80
+ last_action_type:
81
+ type: string
82
+ description: "String name of the last executed action"
83
+ action_history:
84
+ type: array
85
+ items:
86
+ type: string
87
+ description: "Last 5 action types"
88
+ done:
89
+ type: boolean
90
+ description: "Whether the episode has finished"
91
+ bug_description:
92
+ type: string
93
+ description: "Short description of the injected bug"
94
+ comments_count:
95
+ type: integer
96
+ description: "Number of comments exchanged so far"
97
+
98
+ # ----------------------------------------------------------------------
99
+ # Action space (short names as produced by the agent)
100
+ # ----------------------------------------------------------------------
101
+ action_space:
102
+ type: object
103
+ properties:
104
+ action_type:
105
+ type: string
106
+ enum:
107
+ - comment
108
+ - skip
109
+ - done
110
+ - question
111
+ - fix
112
+ - execute
113
+ - inspect
114
+ - run_linter
115
+ - run_tests
116
+ - query_docs
117
+ comment_text:
118
+ type: string
119
+ description: "Required for comment"
120
+ question:
121
+ type: string
122
+ description: "Required for question"
123
+ fix_code:
124
+ type: string
125
+ description: "Required for fix"
126
+ query_topic:
127
+ type: string
128
+ description: "Required for query_docs"
129
+
130
+ # ----------------------------------------------------------------------
131
+ # (Optional) Server configuration – used by openenv serve
132
+ # ----------------------------------------------------------------------
133
+ server:
134
+ app: server.app:app
135
+ port: 7860
pyproject.toml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "code_review_professional"
7
+ version = "1.0.0"
8
+ description = "Multi‑turn code review environment with AST injection, DPO training, and author negotiation"
9
+ authors = [{name = "yuvraj gupta", email = "yuvraj467229@gmail.com"}]
10
+ license = {text = "MIT"}
11
+ readme = "README.md"
12
+ requires-python = ">=3.10"
13
+ dependencies = [
14
+ "openenv-core>=0.2.0",
15
+ "fastapi>=0.115.0",
16
+ "uvicorn>=0.24.0",
17
+ "unsloth>=2025.3.1",
18
+ "trl>=0.15.0",
19
+ "accelerate>=1.2.0",
20
+ "pylint>=3.3.0",
21
+ "sentence-transformers>=3.3.0",
22
+ "datasets>=3.3.0",
23
+ "chromadb>=0.5.0",
24
+ ]
25
+
26
+ [project.optional-dependencies]
27
+ dev = ["pytest>=7.0", "black>=23.0", "isort>=5.0"]
28
+
29
+ [tool.openenv]
30
+ server = "server.app:app"
redteam.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # redteam.py – Task‑aware bug injection (25 bugs, 5 difficulty levels)
2
+ import ast
3
+ import random
4
+ from dataclasses import dataclass, field
5
+ from typing import Tuple, Optional, List, Dict
6
+
7
+ # ----------------------------------------------------------------------
8
+ # 1. AST Bug Injector (extended for all simple bugs)
9
+ # ----------------------------------------------------------------------
10
+ class ASTBugInjector(ast.NodeTransformer):
11
+ def __init__(self, bug_type: str):
12
+ super().__init__()
13
+ self.bug_type = bug_type
14
+ self.modified = False
15
+
16
+ # --- Easy: null_check, simple_typo, string_index, default_value, empty_return ---
17
+ def visit_If(self, node: ast.If):
18
+ # null_check: remove the if-guard
19
+ if self.bug_type == "null_check" and not self.modified:
20
+ if node.body and len(node.body) == 1:
21
+ self.modified = True
22
+ return node.body[0]
23
+ # division_by_zero_empty: remove the empty check
24
+ if self.bug_type == "division_by_zero_empty" and not self.modified:
25
+ # pattern: if not data: return 0 – we delete the entire if
26
+ if (isinstance(node.test, ast.UnaryOp) and
27
+ isinstance(node.test.op, ast.Not) and
28
+ isinstance(node.test.operand, ast.Name)):
29
+ self.modified = True
30
+ return None # signal to remove this node from parent
31
+ return self.generic_visit(node)
32
+
33
+ def visit_Name(self, node: ast.Name):
34
+ if self.bug_type == "simple_typo" and not self.modified:
35
+ if node.id == "users":
36
+ self.modified = True
37
+ return ast.Name(id="usres", ctx=node.ctx)
38
+ return self.generic_visit(node)
39
+
40
+ def visit_Subscript(self, node: ast.Subscript):
41
+ if self.bug_type == "string_index" and not self.modified:
42
+ if isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Constant):
43
+ old_val = node.slice.value.value
44
+ if isinstance(old_val, int):
45
+ self.modified = True
46
+ node.slice = ast.Index(value=ast.Constant(value=old_val + 1))
47
+ return self.generic_visit(node)
48
+
49
+ def visit_Call(self, node: ast.Call):
50
+ # default_value: change dict.get(key) to dict[key] (no default)
51
+ if self.bug_type == "default_value" and not self.modified:
52
+ if (isinstance(node.func, ast.Attribute) and
53
+ node.func.attr == "get" and len(node.args) == 1):
54
+ self.modified = True
55
+ return ast.Subscript(
56
+ value=node.func.value,
57
+ slice=ast.Index(value=node.args[0]),
58
+ ctx=node.ctx
59
+ )
60
+ # abs_usage: remove abs()
61
+ if self.bug_type == "abs_usage" and not self.modified:
62
+ if isinstance(node.func, ast.Name) and node.func.id == "abs":
63
+ self.modified = True
64
+ return node.args[0]
65
+ return self.generic_visit(node)
66
+
67
+ def visit_FunctionDef(self, node: ast.FunctionDef):
68
+ # empty_return: insert a premature return None
69
+ if self.bug_type == "empty_return" and not self.modified:
70
+ self.modified = True
71
+ node.body.insert(0, ast.Return(value=ast.Constant(value=None)))
72
+ return self.generic_visit(node)
73
+
74
+ # --- Medium: off_by_one, loop_skip, sign_error, swap_args, uninitialised_var ---
75
+ def visit_For(self, node: ast.For):
76
+ if (self.bug_type in ("off_by_one", "loop_skip")) and not self.modified:
77
+ if (isinstance(node.iter, ast.Call) and
78
+ isinstance(node.iter.func, ast.Name) and
79
+ node.iter.func.id == "range"):
80
+ if self.bug_type == "off_by_one":
81
+ new_iter = ast.Call(
82
+ func=ast.Name(id='range', ctx=ast.Load()),
83
+ args=[
84
+ ast.Constant(value=1),
85
+ ast.BinOp(left=node.iter.args[0], op=ast.Sub(), right=ast.Constant(value=1))
86
+ ],
87
+ keywords=[]
88
+ )
89
+ node.iter = new_iter
90
+ self.modified = True
91
+ elif self.bug_type == "loop_skip" and len(node.iter.args) == 1:
92
+ new_iter = ast.Call(
93
+ func=ast.Name(id='range', ctx=ast.Load()),
94
+ args=[ast.BinOp(left=node.iter.args[0], op=ast.Sub(), right=ast.Constant(value=1))],
95
+ keywords=[]
96
+ )
97
+ node.iter = new_iter
98
+ self.modified = True
99
+ return self.generic_visit(node)
100
+
101
+ def visit_BinOp(self, node: ast.BinOp):
102
+ # sign_error: flip Add/Sub, wrong_operator: Add->Sub, float_precision: Div->FloorDiv
103
+ if not self.modified:
104
+ if self.bug_type in ("wrong_operator", "sign_error"):
105
+ if isinstance(node.op, ast.Add):
106
+ node.op = ast.Sub()
107
+ self.modified = True
108
+ elif isinstance(node.op, ast.Sub):
109
+ node.op = ast.Add()
110
+ self.modified = True
111
+ elif self.bug_type == "float_precision" and isinstance(node.op, ast.Div):
112
+ node.op = ast.FloorDiv()
113
+ self.modified = True
114
+ return self.generic_visit(node)
115
+
116
+ def visit_arguments(self, node: ast.arguments):
117
+ # swap_args: swap first two arguments of a function
118
+ if self.bug_type == "swap_args" and not self.modified and len(node.args) >= 2:
119
+ self.modified = True
120
+ node.args[0], node.args[1] = node.args[1], node.args[0]
121
+ return self.generic_visit(node)
122
+
123
+ def visit_Assign(self, node: ast.Assign):
124
+ # uninitialised_var: remove an assignment statement (replaced with Pass)
125
+ if self.bug_type == "uninitialised_var" and not self.modified:
126
+ self.modified = True
127
+ return ast.Pass()
128
+ return self.generic_visit(node)
129
+
130
+ # ----------------------------------------------------------------------
131
+ # 2. Bug database (25 bugs, categorized by difficulty)
132
+ # ----------------------------------------------------------------------
133
+ BUG_DB = {
134
+ "easy": {
135
+ "null_check": {"type": "ast", "bug_type": "null_check"},
136
+ "simple_typo": {"type": "ast", "bug_type": "simple_typo"},
137
+ "string_index": {"type": "ast", "bug_type": "string_index"},
138
+ "default_value": {"type": "ast", "bug_type": "default_value"},
139
+ "empty_return": {"type": "ast", "bug_type": "empty_return"},
140
+ },
141
+ "medium": {
142
+ "off_by_one": {"type": "ast", "bug_type": "off_by_one"},
143
+ "loop_skip": {"type": "ast", "bug_type": "loop_skip"},
144
+ "sign_error": {"type": "ast", "bug_type": "sign_error"},
145
+ "swap_args": {"type": "ast", "bug_type": "swap_args"},
146
+ "uninitialised_var": {"type": "ast", "bug_type": "uninitialised_var"},
147
+ },
148
+ "hard": {
149
+ "division_by_zero_empty": {"type": "ast", "bug_type": "division_by_zero_empty"},
150
+ "division_by_zero_zero": {"type": "ast", "bug_type": "division_by_zero_empty"}, # same injector
151
+ "float_precision": {"type": "ast", "bug_type": "float_precision"},
152
+ "abs_usage": {"type": "ast", "bug_type": "abs_usage"},
153
+ "round_error": {"type": "ast", "bug_type": "round_error"}, # can be extended
154
+ },
155
+ "harder": {
156
+ "missing_lock": {
157
+ "type": "template",
158
+ "buggy": "counter = 0\ndef increment():\n global counter\n counter += 1",
159
+ "oracle": "counter = 0\nimport threading\nlock = threading.Lock()\ndef increment():\n global counter\n with lock:\n counter += 1",
160
+ },
161
+ "double_lock": {
162
+ "type": "template",
163
+ "buggy": "import threading\nlock = threading.Lock()\ndef do_work():\n lock.acquire()\n lock.acquire()\n print('working')\n lock.release()",
164
+ "oracle": "import threading\nlock = threading.Lock()\ndef do_work():\n with lock:\n print('working')",
165
+ },
166
+ "global_nonatomic": {
167
+ "type": "template",
168
+ "buggy": "count = 0\ndef add():\n global count\n count = count + 1",
169
+ "oracle": "count = 0\ndef add():\n global count\n count += 1",
170
+ },
171
+ "thread_safe_list": {
172
+ "type": "template",
173
+ "buggy": "import threading\nitems = []\ndef append_item(item):\n items.append(item)",
174
+ "oracle": "import threading\nitems = []\nlock = threading.Lock()\ndef append_item(item):\n with lock:\n items.append(item)",
175
+ },
176
+ "volatile_read": {
177
+ "type": "template",
178
+ "buggy": "import threading\nstop = False\ndef worker():\n while not stop:\n pass",
179
+ "oracle": "import threading\nstop = False\nlock = threading.Lock()\ndef worker():\n while True:\n with lock:\n if stop:\n break",
180
+ },
181
+ },
182
+ "hardest": {
183
+ "deadlock_order": {
184
+ "type": "template",
185
+ "buggy": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock2:\n with lock1:\n pass",
186
+ "oracle": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock1:\n with lock2:\n pass",
187
+ },
188
+ "nested_lock_timeout": {
189
+ "type": "template",
190
+ "buggy": "import threading\nlock = threading.Lock()\ndef work():\n lock.acquire()\n # critical section\n lock.release()",
191
+ "oracle": "import threading\nlock = threading.Lock()\ndef work():\n if lock.acquire(timeout=1):\n try:\n # critical section\n finally:\n lock.release()",
192
+ },
193
+ "fork_join": {
194
+ "type": "template",
195
+ "buggy": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()",
196
+ "oracle": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()\nt.join()",
197
+ },
198
+ "mutex_release": {
199
+ "type": "template",
200
+ "buggy": "import threading\nlock = threading.Lock()\ndef thread_A():\n lock.acquire()\n lock.release()\ndef thread_B():\n lock.release()",
201
+ "oracle": "import threading\nlock = threading.Lock()\ndef thread_A():\n with lock:\n pass\ndef thread_B():\n with lock:\n pass",
202
+ },
203
+ "race_on_init": {
204
+ "type": "template",
205
+ "buggy": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nprint(items)",
206
+ "oracle": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nt.join()\nprint(items)",
207
+ },
208
+ },
209
+ }
210
+
211
+ # ----------------------------------------------------------------------
212
+ # 3. Derived helpers
213
+ # ----------------------------------------------------------------------
214
+ TASK_BUG_MAP = {level: list(bugs.keys()) for level, bugs in BUG_DB.items()}
215
+
216
+ TEMPLATE_BUGS = {}
217
+ for level, bugs in BUG_DB.items():
218
+ for bug_id, bug in bugs.items():
219
+ if bug["type"] == "template":
220
+ TEMPLATE_BUGS[bug_id] = (bug["buggy"], bug["oracle"])
221
+
222
+ # ----------------------------------------------------------------------
223
+ # 4. RedTeam Controller (task‑aware)
224
+ # ----------------------------------------------------------------------
225
+ @dataclass
226
+ class RedTeam:
227
+ task: str
228
+ seed: Optional[int] = 42
229
+ noise_prob: float = 0.2
230
+ _random: random.Random = field(init=False)
231
+
232
+ def __post_init__(self):
233
+ self._random = random.Random(self.seed)
234
+
235
+ def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
236
+ """
237
+ Returns: (buggy_code, bug_type, description, oracle_fix)
238
+ Selects a bug appropriate for the task difficulty.
239
+ """
240
+ bug_list = TASK_BUG_MAP.get(self.task, ["null_check"])
241
+ bug_type = self._random.choice(bug_list)
242
+
243
+ # Template bug: return hardcoded buggy + oracle
244
+ if bug_type in TEMPLATE_BUGS:
245
+ buggy_code, oracle_code = TEMPLATE_BUGS[bug_type]
246
+ description = f"Template bug: {bug_type}"
247
+ if self._random.random() < self.noise_prob:
248
+ buggy_code += "\n# TODO: refactor later"
249
+ return buggy_code, bug_type, description, oracle_code
250
+
251
+ # AST injection
252
+ try:
253
+ tree = ast.parse(original_code)
254
+ except SyntaxError:
255
+ return original_code, "parse_error", "Syntax error in original code", original_code
256
+
257
+ injector = ASTBugInjector(bug_type)
258
+ modified_tree = injector.visit(tree)
259
+ ast.fix_missing_locations(modified_tree)
260
+
261
+ if injector.modified:
262
+ buggy_code = ast.unparse(modified_tree)
263
+ oracle_fix = original_code
264
+ description = f"AST bug: {bug_type}"
265
+ else:
266
+ buggy_code = original_code
267
+ oracle_fix = original_code
268
+ bug_type = "no_op"
269
+ description = "No suitable code structure found for injection"
270
+
271
+ if self._random.random() < self.noise_prob:
272
+ buggy_code += "\n# TODO: refactor later"
273
+
274
+ return buggy_code, bug_type, description, oracle_fix
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ openenv-core>=0.2.0
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+ unsloth>=2025.3.1
5
+ trl>=0.15.0
6
+ accelerate>=1.2.0
7
+ pylint>=3.3.0
8
+ sentence-transformers>=3.3.0
9
+ datasets>=3.3.0
10
+ chromadb>=0.5.0
rltool.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tools.py – Real vector retrieval for query_docs, linter, and test runner
2
+ import subprocess
3
+ import tempfile
4
+ import os
5
+ from dataclasses import dataclass
6
+ from sentence_transformers import SentenceTransformer
7
+ import chromadb
8
+
9
+ @dataclass
10
+ class ToolBox:
11
+ _embedder = None
12
+ _client = None
13
+ _collection = None
14
+
15
+ @classmethod
16
+ def _get_embedder(cls):
17
+ if cls._embedder is None:
18
+ cls._embedder = SentenceTransformer('all-MiniLM-L6-v2')
19
+ return cls._embedder
20
+
21
+ @classmethod
22
+ def _get_collection(cls):
23
+ if cls._collection is None:
24
+ cls._client = chromadb.Client()
25
+ cls._collection = cls._client.create_collection("docs")
26
+ # Pre‑load real documentation snippets (can be extended)
27
+ docs = [
28
+ "KeyError occurs when a dictionary key is missing. Use dict.get() or check 'if key in dict'.",
29
+ "pylint error C0304: missing final newline. Add a newline at the end of file.",
30
+ "Deadlock happens when two threads acquire locks in opposite order. Always acquire locks in the same order.",
31
+ "Division by zero: check if list is empty before calculating average, or use try/except.",
32
+ "Threading.Lock: use 'with lock:' to automatically acquire and release.",
33
+ "Off‑by‑one errors: adjust loop ranges, e.g., range(1, len(arr)-1).",
34
+ ]
35
+ embedder = cls._get_embedder()
36
+ embeddings = embedder.encode(docs).tolist()
37
+ for i, doc in enumerate(docs):
38
+ cls._collection.add(ids=[str(i)], documents=[doc], embeddings=[embeddings[i]])
39
+ return cls._collection
40
+
41
+ @staticmethod
42
+ def run_linter(code: str) -> str:
43
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
44
+ f.write(code)
45
+ f.flush()
46
+ tmp_path = f.name
47
+ try:
48
+ result = subprocess.run(
49
+ ['pylint', tmp_path, '--exit-zero', '--output-format=text'],
50
+ capture_output=True,
51
+ text=True,
52
+ timeout=10,
53
+ encoding='utf-8'
54
+ )
55
+ output = result.stdout
56
+ if "Your code has been rated" in output:
57
+ output = output.split("Your code has been rated")[0]
58
+ output = output.strip()
59
+ if not output:
60
+ return "No linting issues found."
61
+ return output[:500]
62
+ except FileNotFoundError:
63
+ return "Linter (pylint) not installed."
64
+ except subprocess.TimeoutExpired:
65
+ return "Linter timed out."
66
+ except Exception as e:
67
+ return f"Linter error: {str(e)}"
68
+ finally:
69
+ try:
70
+ os.unlink(tmp_path)
71
+ except:
72
+ pass
73
+
74
+ @staticmethod
75
+ def run_tests(test_script: str) -> str:
76
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
77
+ f.write(test_script)
78
+ f.flush()
79
+ tmp_path = f.name
80
+ try:
81
+ result = subprocess.run(
82
+ ['python', tmp_path],
83
+ capture_output=True,
84
+ text=True,
85
+ timeout=10,
86
+ encoding='utf-8'
87
+ )
88
+ output = result.stdout + result.stderr
89
+ return output.strip() or "Test executed successfully (no output)."
90
+ except subprocess.TimeoutExpired:
91
+ return "Test execution timed out."
92
+ except Exception as e:
93
+ return f"Test runner error: {str(e)}"
94
+ finally:
95
+ try:
96
+ os.unlink(tmp_path)
97
+ except:
98
+ pass
99
+
100
+ @classmethod
101
+ def query_docs(cls, topic: str) -> str:
102
+ """Retrieve top 3 relevant docs. Forces agent to reason across multiple hints."""
103
+ try:
104
+ embedder = cls._get_embedder()
105
+ collection = cls._get_collection()
106
+ query_emb = embedder.encode([topic]).tolist()
107
+ # Get top 3 results (not just 1)
108
+ results = collection.query(query_embeddings=query_emb, n_results=3)
109
+ if results['documents'] and results['documents'][0]:
110
+ # Return concatenated snippets, labelled for clarity
111
+ snippets = []
112
+ for i, doc in enumerate(results['documents'][0]):
113
+ snippets.append(f"[{i+1}] {doc}")
114
+ return "Relevant documentation:\n" + "\n".join(snippets)
115
+ return "No relevant documentation found."
116
+ except Exception:
117
+ # Fallback to keyword matching
118
+ topic_lower = topic.lower()
119
+ fallback = {
120
+ "null check": "To avoid KeyError, use 'if key in dict:' before accessing.",
121
+ "keyerror": "Catch KeyError with try/except or use dict.get().",
122
+ "deadlock": "Always acquire locks in the same order to avoid deadlock.",
123
+ }
124
+ for key, value in fallback.items():
125
+ if key in topic_lower:
126
+ return value
127
+ return "No relevant documentation found. Try being more specific."
rubrics.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rubrics.py – Self-contained Rubrics (no external OpenEnv dependency)
2
+
3
+ class Rubric:
4
+ """Minimal Rubric base – compatible with OpenEnv but self‑contained."""
5
+ def __call__(self, env, action, obs, reward, done, info):
6
+ return 0.0
7
+
8
+
9
+ # --------------------------------------------------------------------------------
10
+ # 1. TOOL‑USAGE BONUS
11
+ # --------------------------------------------------------------------------------
12
+ class ToolUsageRubric(Rubric):
13
+ def __init__(self, bonus: float = 0.05):
14
+ self.bonus = bonus
15
+
16
+ def __call__(self, env, action, obs, reward, done, info):
17
+ score = 0.0
18
+ action_type = info.get("action_type", "")
19
+ # Use pre-action flags from `info` so first-use bonuses are
20
+ # computed correctly even though env flags are mutated in-step.
21
+ prev_tests_run = info.get("prev_tests_run", env._tests_run)
22
+ prev_linter_run = info.get("prev_linter_run", env._linter_run)
23
+ prev_docs_queried = info.get("prev_docs_queried", env._docs_queried)
24
+
25
+ if action_type == "run_tests":
26
+ if not prev_tests_run:
27
+ score += self.bonus
28
+ score += 0.015
29
+ elif action_type == "run_linter":
30
+ if not prev_linter_run:
31
+ score += self.bonus
32
+ score += 0.015
33
+ elif action_type == "query_docs":
34
+ if not prev_docs_queried:
35
+ score += self.bonus * 0.5
36
+ # Encourage docs usage when it is likely useful:
37
+ # - early exploration phase
38
+ # - non-trivial query text
39
+ if env._step_count <= 4 and info.get("docs_query_len", 0) >= 8:
40
+ score += 0.01
41
+ # Discourage repeated docs calls after the first-use signal.
42
+ if prev_docs_queried:
43
+ score -= 0.01
44
+ elif action_type == "ask_question" and env._step_count <= 3:
45
+ score += 0.02
46
+ return score
47
+
48
+
49
+ # --------------------------------------------------------------------------------
50
+ # 2. DELTA‑BASED REWARDS
51
+ # --------------------------------------------------------------------------------
52
+ class TestDeltaRubric(Rubric):
53
+ def __init__(self, weight: float = 0.3):
54
+ self.weight = weight
55
+
56
+ def __call__(self, env, action, obs, reward, done, info):
57
+ delta = env._current_test_score - env._previous_test_score
58
+ effective = self.weight
59
+ if info.get("action_type") == "propose_fix":
60
+ effective *= 0.4
61
+ return effective * delta
62
+
63
+
64
+ class LintDeltaRubric(Rubric):
65
+ def __init__(self, weight: float = 0.3):
66
+ self.weight = weight
67
+
68
+ def __call__(self, env, action, obs, reward, done, info):
69
+ delta = env._current_lint_score - env._previous_lint_score
70
+ effective = self.weight * 0.5
71
+ if info.get("action_type") == "propose_fix":
72
+ effective *= 0.4
73
+ return effective * delta
74
+
75
+
76
+ # --------------------------------------------------------------------------------
77
+ # 3. TERMINAL SUCCESS BONUS
78
+ # --------------------------------------------------------------------------------
79
+ class TerminalSuccessRubric(Rubric):
80
+ def __call__(self, env, action, obs, reward, done, info):
81
+ if info.get("action_type") != "propose_fix":
82
+ return 0.0
83
+ score = 0.0
84
+ if env._current_test_score > 0.95:
85
+ score += 0.4
86
+ elif env._current_test_score > 0.85:
87
+ score += 0.2
88
+ return score
89
+
90
+
91
+ # --------------------------------------------------------------------------------
92
+ # 4. EXPLORATION & DIVERSITY
93
+ # --------------------------------------------------------------------------------
94
+ class ExplorationRubric(Rubric):
95
+ def __init__(self, penalty: float = -0.05, bonus: float = 0.021):
96
+ self.penalty = penalty
97
+ self.bonus = bonus
98
+
99
+ def __call__(self, env, action, obs, reward, done, info):
100
+ if len(env._action_history) < 3:
101
+ return 0.0
102
+ recent = env._action_history[-3:]
103
+ unique = len(set(recent))
104
+ if unique == 1:
105
+ return self.penalty
106
+ elif unique == 3:
107
+ return self.bonus
108
+ return 0.0
109
+
110
+
111
+ # --------------------------------------------------------------------------------
112
+ # 5. ANTI‑HACKING & CONSISTENCY
113
+ # --------------------------------------------------------------------------------
114
+ class AntiHackingRubric(Rubric):
115
+ def __call__(self, env, action, obs, reward, done, info):
116
+ if info.get("action_type") != "propose_fix":
117
+ return 0.0
118
+ score = 0.0
119
+ if not env._tests_run:
120
+ score -= 0.25
121
+ if env._step_count < 2:
122
+ score -= 0.1
123
+ if env._tests_run and env._linter_run:
124
+ score += 0.02
125
+ return score
126
+
127
+
128
+ # --------------------------------------------------------------------------------
129
+ # 6. STEP PENALTY
130
+ # --------------------------------------------------------------------------------
131
+ class StepPenaltyRubric(Rubric):
132
+ def __init__(self, penalty: float = -0.01):
133
+ self.penalty = penalty
134
+
135
+ def __call__(self, env, action, obs, reward, done, info):
136
+ return self.penalty
test_runner.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # test_runner.py – Full production version with continuous scoring, dynamic function detection, and randomised tests
2
+ import subprocess
3
+ import tempfile
4
+ import os
5
+ import json
6
+ import ast
7
+ import random
8
+ import sys
9
+ from typing import Tuple, List, Any, Optional
10
+ from dataclasses import dataclass
11
+
12
+ # Bridge fine-grained RedTeam ids to canonical TestRunner families.
13
+ # This keeps evaluation stable even when bug generators use richer labels.
14
+ BUG_ID_CANONICAL_MAP = {
15
+ # Easy-family bugs on `get_user`-style behavior.
16
+ "simple_typo": "null_check",
17
+ "default_value": "null_check",
18
+ "empty_return": "null_check",
19
+
20
+ # Medium arithmetic/control-flow aliases.
21
+ "loop_skip": "off_by_one",
22
+ "sign_error": "wrong_operator",
23
+
24
+ # Hard numeric-safety aliases.
25
+ "division_by_zero_empty": "division_by_zero",
26
+ "division_by_zero_zero": "division_by_zero",
27
+ "float_precision": "division_by_zero",
28
+ "abs_usage": "division_by_zero",
29
+ "round_error": "division_by_zero",
30
+ }
31
+
32
+ @dataclass
33
+ class TestRunner:
34
+ bug_id: str
35
+ timeout_sec: int = 5
36
+ max_memory_mb: int = 256
37
+ fuzz_rounds: int = 3 # number of random test cases per bug
38
+
39
+ def run_tests(self, fix_code: str) -> Tuple[float, str]:
40
+ """
41
+ Returns (score, output_message) where score is proportion of passed tests (0.0–1.0).
42
+ """
43
+ # 1. Detect the function defined in the agent's code (dynamic)
44
+ func_name = self._get_defined_function_name(fix_code)
45
+ if not func_name:
46
+ return 0.0, "No function definition found in agent code"
47
+
48
+ # 2. Normalize bug id so broader RedTeam ids still hit meaningful tests.
49
+ canonical_bug_id = self._canonical_bug_id()
50
+
51
+ # 3. Generate the test script (includes fixed + fuzzed test cases)
52
+ test_script = self._generate_test_script(fix_code, func_name, canonical_bug_id)
53
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
54
+ f.write(test_script)
55
+ tmp_path = f.name
56
+
57
+ try:
58
+ # Resource limiting (Linux only; fallback otherwise)
59
+ try:
60
+ import resource
61
+ resource.setrlimit(resource.RLIMIT_AS, (self.max_memory_mb * 1024 * 1024, self.max_memory_mb * 1024 * 1024))
62
+ except Exception:
63
+ pass
64
+
65
+ result = subprocess.run(
66
+ [sys.executable, tmp_path],
67
+ capture_output=True,
68
+ text=True,
69
+ timeout=self.timeout_sec,
70
+ encoding='utf-8'
71
+ )
72
+ # Parse JSON output
73
+ try:
74
+ data = json.loads(result.stdout.strip())
75
+ passed = data.get("passed", 0)
76
+ total = data.get("total", 1)
77
+ score = passed / total if total > 0 else 0.0
78
+ return score, result.stdout.strip()
79
+ except json.JSONDecodeError:
80
+ # Fallback: look for "True" (legacy)
81
+ if "True" in result.stdout:
82
+ return 1.0, result.stdout
83
+ return 0.0, result.stdout
84
+ except subprocess.TimeoutExpired:
85
+ return 0.0, "Test execution timed out"
86
+ except Exception as e:
87
+ return 0.0, f"Unexpected error: {str(e)}"
88
+ finally:
89
+ try:
90
+ os.unlink(tmp_path)
91
+ except:
92
+ pass
93
+
94
+ def _get_defined_function_name(self, code: str) -> Optional[str]:
95
+ """Extract the target function name from the code.
96
+ Looks for a function named 'fix' first; otherwise returns the first function found.
97
+ """
98
+ try:
99
+ tree = ast.parse(code)
100
+ first_func = None
101
+ for node in ast.walk(tree):
102
+ if isinstance(node, ast.FunctionDef):
103
+ if node.name == "fix":
104
+ return "fix"
105
+ if first_func is None:
106
+ first_func = node.name
107
+ return first_func # fallback if no 'fix' function exists
108
+ except SyntaxError:
109
+ pass
110
+ return None
111
+
112
+ def _canonical_bug_id(self) -> str:
113
+ """Return canonical bug family used by this test harness."""
114
+ return BUG_ID_CANONICAL_MAP.get(self.bug_id, self.bug_id)
115
+
116
+ def _generate_test_script(self, fix_code: str, func_name: str, canonical_bug_id: str) -> str:
117
+ """Generate a test script that runs fixed + fuzzed test cases and outputs JSON."""
118
+ test_cases = self._get_test_cases(canonical_bug_id, func_name)
119
+ fuzzed_cases = self._generate_fuzzed_cases(canonical_bug_id, func_name)
120
+ all_cases = test_cases + fuzzed_cases
121
+
122
+ lines = []
123
+ lines.append(fix_code)
124
+ lines.append("")
125
+ lines.append("import json")
126
+ lines.append("")
127
+ lines.append("def run_tests():")
128
+ lines.append(f" test_cases = {json.dumps(all_cases)}")
129
+ lines.append(" passed = 0")
130
+ lines.append(" total = len(test_cases)")
131
+ lines.append(" for args, expected in test_cases:")
132
+ lines.append(f" try:")
133
+ lines.append(f" result = {func_name}(*args) if isinstance(args, list) else {func_name}(args)")
134
+ lines.append(f" if result == expected:")
135
+ lines.append(f" passed += 1")
136
+ lines.append(f" except Exception:")
137
+ lines.append(f" pass")
138
+ lines.append(" return {'passed': passed, 'total': total}")
139
+ lines.append("")
140
+ lines.append("if __name__ == '__main__':")
141
+ lines.append(" result = run_tests()")
142
+ lines.append(" print(json.dumps(result))")
143
+ return "\n".join(lines)
144
+
145
+ def _get_test_cases(self, canonical_bug_id: str, func_name: str) -> List[Tuple[List[Any], Any]]:
146
+ """
147
+ Return a list of (arguments, expected_output) for the given bug_id.
148
+ Uses the actual function name (dynamic) for consistency.
149
+ """
150
+ if canonical_bug_id == "null_check":
151
+ return [
152
+ ([{"users": {"alice": "Alice"}, "id": "bob"}], None), # missing key should not crash
153
+ ([{"users": {"alice": "Alice"}, "id": "alice"}], "Alice"),
154
+ ]
155
+ elif canonical_bug_id == "off_by_one":
156
+ return [
157
+ ([[1,2,3,4]], 4),
158
+ ([[]], 0),
159
+ ]
160
+ elif canonical_bug_id == "division_by_zero":
161
+ return [
162
+ ([[]], 0),
163
+ ([[1,2,3]], 2.0),
164
+ ]
165
+ elif canonical_bug_id == "wrong_operator":
166
+ return [
167
+ ([5,3], 8),
168
+ ([-1,1], 0),
169
+ ]
170
+ else:
171
+ # For missing_lock, deadlock_order, etc., return empty list (will be handled gracefully)
172
+ return []
173
+
174
+ def _generate_fuzzed_cases(self, canonical_bug_id: str, func_name: str) -> List[Tuple[List[Any], Any]]:
175
+ """
176
+ Generate random test cases to prevent memorisation.
177
+ Only for bugs where meaningful fuzzing is possible.
178
+ """
179
+ cases = []
180
+ if canonical_bug_id == "null_check":
181
+ # Random users dictionary and random ids
182
+ for _ in range(self.fuzz_rounds):
183
+ users = {f"user_{i}": f"name_{i}" for i in range(random.randint(1, 5))}
184
+ # Pick existing or missing key
185
+ if random.random() > 0.5:
186
+ key = random.choice(list(users.keys()))
187
+ expected = users[key]
188
+ else:
189
+ key = "missing_" + str(random.randint(100, 999))
190
+ expected = None
191
+ cases.append(([{"users": users, "id": key}], expected))
192
+ elif canonical_bug_id == "off_by_one":
193
+ for _ in range(self.fuzz_rounds):
194
+ length = random.randint(0, 10)
195
+ arr = list(range(length))
196
+ cases.append(([arr], length))
197
+ elif canonical_bug_id == "division_by_zero":
198
+ for _ in range(self.fuzz_rounds):
199
+ length = random.randint(0, 10)
200
+ data = [random.randint(-100, 100) for _ in range(length)]
201
+ expected = sum(data)/length if length else 0
202
+ cases.append(([data], expected))
203
+ elif canonical_bug_id == "wrong_operator":
204
+ for _ in range(self.fuzz_rounds):
205
+ a = random.randint(-100, 100)
206
+ b = random.randint(-100, 100)
207
+ cases.append(([a, b], a + b))
208
+ return cases
training.py ADDED
@@ -0,0 +1,792 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # training.py
2
+ import json
3
+ import os
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch.optim import AdamW
7
+ from dataclasses import dataclass
8
+ from typing import List, Dict, Tuple, Optional
9
+ import numpy as np
10
+ import re
11
+ import random
12
+ import matplotlib.pyplot as plt
13
+
14
+ from unsloth import FastLanguageModel
15
+ from transformers import TrainingArguments
16
+ from trl import SFTTrainer
17
+ from datasets import Dataset
18
+
19
+ # Import your environment and actions (unchanged)
20
+ from environment import CodeReviewEnv
21
+ from redteam import BUG_DB
22
+ from models import (
23
+ RunTests, RunLinter, Inspect,
24
+ ProposeFix, WriteComment, AskQuestion,
25
+ Done, Skip , QueryDocs
26
+ )
27
+
28
+ # ======================================================================
29
+ # 1. ACTION PARSING (improved with fallback)
30
+ # ======================================================================
31
+ @dataclass
32
+ class AgentAction:
33
+ action_type: str
34
+ content: Optional[str] = None
35
+
36
+ def parse_action(output: str) -> AgentAction:
37
+ """Robust JSON parsing with regex fallback and keyword detection."""
38
+ # Try strict JSON first
39
+ try:
40
+ data = json.loads(output)
41
+ return AgentAction(
42
+ action_type=data.get("action_type", "").lower(),
43
+ content=data.get("content")
44
+ )
45
+ except:
46
+ pass
47
+
48
+ # Try to extract JSON from markdown blocks
49
+ json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', output, re.DOTALL)
50
+ if json_match:
51
+ try:
52
+ data = json.loads(json_match.group(1))
53
+ return AgentAction(
54
+ action_type=data.get("action_type", "").lower(),
55
+ content=data.get("content")
56
+ )
57
+ except:
58
+ pass
59
+
60
+ # Try to find "action_type" field with regex
61
+ action_pattern = r'"action_type"\s*:\s*"(\w+)"'
62
+ match = re.search(action_pattern, output)
63
+ if match:
64
+ return AgentAction(action_type=match.group(1).lower())
65
+
66
+ # Keyword detection as last resort
67
+ output_lower = output.lower()
68
+ if "test" in output_lower:
69
+ return AgentAction("run_tests")
70
+ if "lint" in output_lower:
71
+ return AgentAction("run_linter")
72
+ if "inspect" in output_lower:
73
+ return AgentAction("inspect")
74
+ if "doc" in output_lower or "documentation" in output_lower:
75
+ # Bridge natural language mentions to rltool-backed retrieval action.
76
+ return AgentAction("query_docs", "bug fix guidance")
77
+
78
+ return AgentAction("invalid", output)
79
+
80
+ def map_to_env(action: AgentAction):
81
+ if action.action_type == "run_tests":
82
+ return RunTests()
83
+ elif action.action_type == "run_linter":
84
+ return RunLinter()
85
+ elif action.action_type == "inspect":
86
+ return Inspect()
87
+ elif action.action_type == "fix":
88
+ return ProposeFix(fix_code=action.content or "")
89
+ elif action.action_type == "comment":
90
+ return WriteComment(comment_text=action.content or "")
91
+ elif action.action_type == "question":
92
+ return AskQuestion(question=action.content or "")
93
+ elif action.action_type == "query_docs": # <-- new
94
+ return QueryDocs(query_topic=action.content or "")
95
+ elif action.action_type == "done":
96
+ return Done()
97
+ else:
98
+ return Skip()
99
+
100
+ # ======================================================================
101
+ # 2. MODEL SETUP (stabilised LoRA)
102
+ # ======================================================================
103
+ def load_model():
104
+ model, tokenizer = FastLanguageModel.from_pretrained(
105
+ model_name="unsloth/gemma-2-2b-it-bnb-4bit",
106
+ max_seq_length=2048,
107
+ load_in_4bit=True,
108
+ )
109
+ # FIXED: Lower rank (16), dropout=0 for stability
110
+ model = FastLanguageModel.get_peft_model(
111
+ model,
112
+ r=16, # was 64 → causes collapse
113
+ target_modules=[
114
+ "q_proj", "k_proj", "v_proj", "o_proj",
115
+ "gate_proj", "up_proj", "down_proj"
116
+ ],
117
+ lora_alpha=32, # adjusted for r=16
118
+ lora_dropout=0.0, # dropout can cause empty outputs
119
+ )
120
+ # Ensure tokenizer has correct chat template for Gemma-2
121
+ if tokenizer.chat_template is None:
122
+ tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}<start_of_turn>user\n{{ message['content'] }}<end_of_turn>\n<start_of_turn>model\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}<end_of_turn>\n{% endif %}{% endfor %}"
123
+ return model, tokenizer
124
+
125
+ # ======================================================================
126
+ # 3. MODEL SANITY CHECK (new – ensures model can generate text)
127
+ # ======================================================================
128
+ def test_model_sanity(model, tokenizer) -> bool:
129
+ print("\n" + "="*60)
130
+ print("SANITY CHECK: Testing base model generation")
131
+ print("="*60)
132
+ test_prompt = "Hello, how are you?"
133
+ messages = [{"role": "user", "content": test_prompt}]
134
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
135
+ inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
136
+ with torch.no_grad():
137
+ outputs = model.generate(
138
+ **inputs,
139
+ max_new_tokens=30,
140
+ do_sample=True,
141
+ temperature=0.7,
142
+ min_new_tokens=1,
143
+ eos_token_id=tokenizer.eos_token_id,
144
+ pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
145
+ )
146
+ generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
147
+ response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
148
+ print(f"Prompt: {test_prompt}")
149
+ print(f"Response: {repr(response)}")
150
+ if len(response) == 0:
151
+ print("❌ Model produces empty output – cannot train.")
152
+ return False
153
+ print("✓ Model sanity check PASSED\n")
154
+ return True
155
+
156
+ # ======================================================================
157
+ # 4. SUPERVISED WARM-UP (teaches JSON output)
158
+ # ======================================================================
159
+ def supervised_warmup(model, tokenizer, n_examples=500, epochs=2):
160
+ print("\n" + "="*60)
161
+ print("SUPERVISED WARM-UP: Teaching JSON format")
162
+ print("="*60)
163
+
164
+ examples = []
165
+ action_templates = [
166
+ '{"action_type": "run_tests"}',
167
+ '{"action_type": "run_linter"}',
168
+ '{"action_type": "inspect"}',
169
+ '{"action_type": "query_docs", "content": "python keyerror handling"}',
170
+ '{"action_type": "fix", "content": "def corrected():\n pass"}',
171
+ '{"action_type": "comment", "content": "This looks good."}',
172
+ '{"action_type": "question", "content": "Why is this variable used?"}',
173
+ '{"action_type": "done"}',
174
+ ]
175
+
176
+ for i in range(n_examples):
177
+ code = f"def example_{i}():\n return {i % 10}"
178
+ last_outputs = [
179
+ "Tests passed: 2/3",
180
+ "Linter found 1 error",
181
+ "Inspection complete",
182
+ "No previous action",
183
+ ]
184
+ last_output = random.choice(last_outputs)
185
+ # Use same prompt structure as build_prompt
186
+ prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
187
+
188
+ The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
189
+ - Tests pass (high pass ratio)
190
+ - Lint is clean (zero errors)
191
+ - Documentation or references are provided
192
+ - Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
193
+
194
+ Workflow:
195
+ 1. Use `inspect` to understand the code.
196
+ 2. Use `run_tests` and `run_linter` to gather evidence.
197
+ 3. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
198
+ 4. If the developer pushes back, read their response carefully and address their specific concern.
199
+ 5. Once convinced, use `done` to finish.
200
+
201
+ Code:
202
+ {obs.code_snippet}
203
+
204
+ Author says:
205
+ {author_msg if author_msg else "(no response yet – start with inspection)"}
206
+
207
+ Last tool output:
208
+ {tool_output if tool_output else "(none)"}
209
+
210
+ Available actions:
211
+ run_tests, run_linter, inspect, query_docs, fix, comment, question, done
212
+
213
+ Respond ONLY in JSON:
214
+ {{"action_type": "...", "content": "..."}}"""
215
+
216
+ action_json = random.choice(action_templates)
217
+ messages = [
218
+ {"role": "user", "content": prompt},
219
+ {"role": "assistant", "content": action_json}
220
+ ]
221
+ full_text = tokenizer.apply_chat_template(messages, tokenize=False)
222
+ examples.append({"text": full_text})
223
+
224
+ dataset = Dataset.from_list(examples)
225
+ trainer = SFTTrainer(
226
+ model=model,
227
+ tokenizer=tokenizer,
228
+ train_dataset=dataset,
229
+ dataset_text_field="text",
230
+ max_seq_length=512,
231
+ args=TrainingArguments(
232
+ output_dir="warmup_output",
233
+ num_train_epochs=epochs,
234
+ per_device_train_batch_size=4,
235
+ gradient_accumulation_steps=2,
236
+ learning_rate=2e-5,
237
+ logging_steps=50,
238
+ save_strategy="no",
239
+ fp16=True,
240
+ ),
241
+ )
242
+ print(f"Training on {n_examples} examples for {epochs} epochs...")
243
+ trainer.train()
244
+ print("✓ Warm-up complete\n")
245
+
246
+ # ======================================================================
247
+ # 5. ACTION GENERATION WITH LOGPROB TRACKING (fixed)
248
+ # ======================================================================
249
+ def generate_action_with_logprob(
250
+ prompt: str,
251
+ model,
252
+ tokenizer,
253
+ temperature: float = 0.0, # changed: greedy by default for stability
254
+ max_retries: int = 2
255
+ ) -> Tuple[str, float]:
256
+ """Generate action using correct chat template, with fallback."""
257
+ messages = [{"role": "user", "content": prompt}]
258
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
259
+ inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
260
+
261
+ for attempt in range(max_retries):
262
+ with torch.no_grad():
263
+ outputs = model.generate(
264
+ **inputs,
265
+ max_new_tokens=128,
266
+ do_sample=(temperature > 0),
267
+ temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
268
+ min_new_tokens=1,
269
+ return_dict_in_generate=True,
270
+ output_scores=True,
271
+ )
272
+
273
+ generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
274
+ action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
275
+
276
+ # Compute logprob
277
+ logprobs = []
278
+ for idx, token_id in enumerate(generated_ids):
279
+ if idx < len(outputs.scores):
280
+ token_logits = outputs.scores[idx][0]
281
+ token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
282
+ logprobs.append(token_logprob)
283
+ total_logprob = sum(logprobs) if logprobs else -100.0
284
+
285
+ # If empty, use fallback
286
+ if not action_text:
287
+ fallback_actions = [
288
+ '{"action_type": "run_tests"}',
289
+ '{"action_type": "run_linter"}',
290
+ '{"action_type": "inspect"}',
291
+ '{"action_type": "skip"}',
292
+ ]
293
+ action_text = random.choice(fallback_actions)
294
+ total_logprob = -50.0
295
+ print(f"[WARN] Empty generation → using fallback: {action_text}")
296
+ return action_text, total_logprob
297
+
298
+ # Validate JSON
299
+ try:
300
+ json.loads(action_text)
301
+ return action_text, total_logprob
302
+ except:
303
+ if attempt == max_retries - 1:
304
+ return '{"action_type":"skip"}', -100.0
305
+ continue
306
+
307
+ return '{"action_type":"skip"}', -100.0
308
+
309
+ # ======================================================================
310
+ # 6. PROMPT BUILDER (unchanged – exactly as you wrote)
311
+ # ======================================================================
312
+ def build_prompt(obs, history_lines: List[str]) -> str:
313
+ author_msg = getattr(obs, "author_response", "") or ""
314
+ tool_output = getattr(obs, "last_tool_output", "") or ""
315
+
316
+ # Personality hint (optional but helpful)
317
+ author_personality = getattr(obs, "author_personality", "defensive") # e.g., from env
318
+
319
+ prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
320
+
321
+ The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
322
+ - Tests pass (high pass ratio)
323
+ - Lint is clean (zero errors)
324
+ - Documentation or references are provided
325
+ - Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
326
+
327
+ Workflow:
328
+ 1. Use `inspect` to understand the code.
329
+ 2. Use `run_tests` and `run_linter` to gather evidence.
330
+ 3. Use `query_docs` when you need references or language-specific guidance.
331
+ 4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
332
+ 5. If the developer pushes back, read their response carefully and address their specific concern.
333
+ 6. Once convinced, use `done` to finish.
334
+
335
+ Code:
336
+ {obs.code_snippet}
337
+
338
+ Author says:
339
+ {author_msg if author_msg else "(no response yet – start with inspection)"}
340
+
341
+ Last tool output:
342
+ {tool_output if tool_output else "(none)"}
343
+
344
+ Available actions:
345
+ run_tests, run_linter, inspect, query_docs, fix, comment, question, done
346
+
347
+ Respond ONLY in JSON:
348
+ {{"action_type": "...", "content": "..."}}"""
349
+
350
+ if history_lines:
351
+ history = "\n".join(history_lines[-6:])
352
+ prompt += f"\n\nPrevious steps:\n{history}"
353
+ return prompt
354
+
355
+ # ======================================================================
356
+ # 7. TRAJECTORY STORAGE (unchanged)
357
+ # ======================================================================
358
+ @dataclass
359
+ class Trajectory:
360
+ states: List[str]
361
+ actions: List[str]
362
+ rewards: List[float]
363
+ logprobs: List[float]
364
+ dones: List[bool]
365
+
366
+ def __len__(self):
367
+ return len(self.states)
368
+
369
+ def to_dict(self):
370
+ return {
371
+ "states": self.states,
372
+ "actions": self.actions,
373
+ "rewards": self.rewards,
374
+ "logprobs": self.logprobs,
375
+ "dones": self.dones,
376
+ }
377
+
378
+ # ======================================================================
379
+ # 8. ROLLOUT COLLECTION (uses fixed generate)
380
+ # ======================================================================
381
+ def collect_trajectory(
382
+ env: CodeReviewEnv,
383
+ model,
384
+ tokenizer,
385
+ max_steps: int = 10,
386
+ temperature: float = 0.0 # changed to greedy
387
+ ) -> Trajectory:
388
+ obs = env.reset()
389
+ history_lines = []
390
+
391
+ states = []
392
+ actions = []
393
+ rewards = []
394
+ logprobs = []
395
+ dones = []
396
+
397
+ for step in range(max_steps):
398
+ prompt = build_prompt(obs, history_lines)
399
+ states.append(prompt)
400
+
401
+ action_text, logprob = generate_action_with_logprob(
402
+ prompt, model, tokenizer, temperature
403
+ )
404
+ actions.append(action_text)
405
+ logprobs.append(logprob)
406
+
407
+ action = parse_action(action_text)
408
+ env_action = map_to_env(action)
409
+ next_obs, reward, done, _ = env.step(env_action)
410
+
411
+ rewards.append(reward.value)
412
+ dones.append(done)
413
+
414
+ history_lines.append(f"Agent: {action_text}")
415
+ history_lines.append(f"Env: {next_obs.last_tool_output}")
416
+
417
+ obs = next_obs
418
+ if done:
419
+ break
420
+
421
+ return Trajectory(states, actions, rewards, logprobs, dones)
422
+
423
+ def collect_trajectories(
424
+ env: CodeReviewEnv,
425
+ model,
426
+ tokenizer,
427
+ n_trajectories: int,
428
+ max_steps: int = 10,
429
+ task_levels: Optional[List[str]] = None,
430
+ task_weights: Optional[List[float]] = None,
431
+ ) -> List[Trajectory]:
432
+ # Link training to RedTeam's full bug distribution by sampling tasks
433
+ # per trajectory instead of training only on env default ("easy").
434
+ if task_levels is None:
435
+ task_levels = list(BUG_DB.keys())
436
+ if task_weights is not None and len(task_weights) != len(task_levels):
437
+ raise ValueError("task_weights must match task_levels length")
438
+ if task_weights is not None and sum(task_weights) <= 0:
439
+ raise ValueError("task_weights must have a positive total")
440
+
441
+ trajectories = []
442
+ for i in range(n_trajectories):
443
+ # Weighted sampling supports curriculum-style training schedules.
444
+ sampled_task = random.choices(task_levels, weights=task_weights, k=1)[0]
445
+ env.set_task(sampled_task)
446
+ traj = collect_trajectory(env, model, tokenizer, max_steps)
447
+ total_reward = sum(traj.rewards)
448
+ print(f"Trajectory {i+1}/{n_trajectories}: "
449
+ f"task={sampled_task}, steps={len(traj)}, reward={total_reward:.3f}")
450
+ trajectories.append(traj)
451
+ return trajectories
452
+
453
+ # ======================================================================
454
+ # 9. ADVANTAGE ESTIMATION (unchanged)
455
+ # ======================================================================
456
+ def compute_returns_and_advantages(
457
+ rewards: List[float],
458
+ dones: List[bool],
459
+ gamma: float = 0.99,
460
+ standardize: bool = True
461
+ ) -> Tuple[List[float], List[float]]:
462
+ """
463
+ Computes discounted returns and normalised advantages (no critic).
464
+ Advantages = returns - mean(returns) (or zero baseline).
465
+ """
466
+ n = len(rewards)
467
+ returns = [0.0] * n
468
+ running_return = 0.0
469
+ for t in reversed(range(n)):
470
+ if dones[t]:
471
+ running_return = 0.0
472
+ running_return = rewards[t] + gamma * running_return
473
+ returns[t] = running_return
474
+
475
+ if standardize:
476
+ advantages = np.array(returns) - np.mean(returns)
477
+ adv_std = np.std(advantages) + 1e-8
478
+ advantages = (advantages / adv_std).tolist()
479
+ else:
480
+ advantages = returns.copy()
481
+
482
+ return advantages, returns
483
+ # ======================================================================
484
+ # 10. COMPUTE NEW LOGPROBS (unchanged)
485
+ # ======================================================================
486
+ def compute_logprob(prompt: str, action: str, model, tokenizer) -> float:
487
+ messages = [{"role": "user", "content": prompt}]
488
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
489
+ full_text = formatted + action
490
+ inputs = tokenizer(full_text, return_tensors="pt").to("cuda")
491
+
492
+ with torch.no_grad():
493
+ outputs = model(**inputs)
494
+ logits = outputs.logits
495
+
496
+ action_ids = tokenizer.encode(action, add_special_tokens=False)
497
+ prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
498
+ action_start = len(prefix_ids)
499
+
500
+ logprobs = []
501
+ for idx, token_id in enumerate(action_ids):
502
+ position = action_start + idx - 1
503
+ if 0 <= position < logits.shape[1]:
504
+ token_logits = logits[0, position]
505
+ token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
506
+ logprobs.append(token_logprob)
507
+ return sum(logprobs) if logprobs else -100.0
508
+
509
+ # ======================================================================
510
+ # 11. PPO UPDATE (unchanged except uses compute_logprob correctly)
511
+ # ======================================================================
512
+ def ppo_update(
513
+ trajectories: List[Trajectory],
514
+ model,
515
+ tokenizer,
516
+ optimizer,
517
+ n_epochs: int = 4,
518
+ clip_epsilon: float = 0.2,
519
+ entropy_coef: float = 0.01,
520
+ gamma: float = 0.99,
521
+ ) -> Dict[str, float]:
522
+ model.train()
523
+
524
+ all_states = []
525
+ all_actions = []
526
+ all_old_logprobs = []
527
+ all_advantages = []
528
+ all_returns = []
529
+
530
+ for traj in trajectories:
531
+ advantages, returns = compute_returns_and_advantages(
532
+ traj.rewards, traj.dones, gamma=gamma, standardize=True
533
+ )
534
+ all_states.extend(traj.states)
535
+ all_actions.extend(traj.actions)
536
+ all_old_logprobs.extend(traj.logprobs)
537
+ all_advantages.extend(advantages)
538
+ all_returns.extend(returns)
539
+
540
+ n_samples = len(all_states)
541
+ total_loss = 0.0
542
+ total_policy_loss = 0.0
543
+ total_entropy = 0.0
544
+ n_updates = 0
545
+
546
+ for epoch in range(n_epochs):
547
+ indices = np.random.permutation(n_samples)
548
+ for i in indices:
549
+ state = all_states[i]
550
+ action = all_actions[i]
551
+ old_logprob = all_old_logprobs[i]
552
+ advantage = all_advantages[i]
553
+
554
+ # Use the same chat template for PPO update
555
+ messages = [{"role": "user", "content": state}]
556
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
557
+ full_text = formatted + action
558
+ inputs = tokenizer(full_text, return_tensors="pt").to("cuda")
559
+
560
+ outputs = model(**inputs)
561
+ logits = outputs.logits
562
+
563
+ action_ids = tokenizer.encode(action, add_special_tokens=False)
564
+ prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
565
+ action_start = len(prefix_ids)
566
+
567
+ logprobs = []
568
+ entropy = 0.0
569
+ for idx, token_id in enumerate(action_ids):
570
+ position = action_start + idx - 1
571
+ if 0 <= position < logits.shape[1]:
572
+ token_logits = logits[0, position]
573
+ log_probs = F.log_softmax(token_logits, dim=-1)
574
+ token_logprob = log_probs[token_id]
575
+ logprobs.append(token_logprob)
576
+
577
+ probs = F.softmax(token_logits, dim=-1)
578
+ entropy += -(probs * log_probs).sum()
579
+
580
+ if not logprobs:
581
+ continue
582
+
583
+ new_logprob = sum(logprobs)
584
+ avg_entropy = entropy / len(logprobs) if logprobs else 0.0
585
+
586
+ ratio = torch.exp(new_logprob - old_logprob)
587
+ surr1 = ratio * advantage
588
+ surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
589
+ policy_loss = -torch.min(surr1, surr2)
590
+ loss = policy_loss - entropy_coef * avg_entropy
591
+
592
+ optimizer.zero_grad()
593
+ loss.backward()
594
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
595
+ optimizer.step()
596
+
597
+ total_loss += loss.item()
598
+ total_policy_loss += policy_loss.item()
599
+ total_entropy += avg_entropy.item()
600
+ n_updates += 1
601
+
602
+ return {
603
+ "loss": total_loss / n_updates if n_updates > 0 else 0.0,
604
+ "policy_loss": total_policy_loss / n_updates if n_updates > 0 else 0.0,
605
+ "entropy": total_entropy / n_updates if n_updates > 0 else 0.0,
606
+ }
607
+ # ======================================================================
608
+ # 12. EVALUATION (unchanged)
609
+ # ======================================================================
610
+ def evaluate_policy(
611
+ env: CodeReviewEnv,
612
+ model,
613
+ tokenizer,
614
+ n_episodes: int = 10,
615
+ max_steps: int = 10
616
+ ) -> Dict[str, float]:
617
+ model.eval()
618
+ total_rewards = []
619
+ episode_lengths = []
620
+ success_count = 0
621
+
622
+ for _ in range(n_episodes):
623
+ traj = collect_trajectory(env, model, tokenizer, max_steps, temperature=0.0)
624
+ total_reward = sum(traj.rewards)
625
+ total_rewards.append(total_reward)
626
+ episode_lengths.append(len(traj))
627
+ if total_reward > 0.5:
628
+ success_count += 1
629
+
630
+ return {
631
+ "avg_reward": np.mean(total_rewards),
632
+ "std_reward": np.std(total_rewards),
633
+ "avg_length": np.mean(episode_lengths),
634
+ "success_rate": success_count / n_episodes,
635
+ }
636
+
637
+ # ======================================================================
638
+ # 13. MAIN TRAINING LOOP (added sanity check and warm-up)
639
+ # ======================================================================
640
+ def train_ppo(
641
+ n_iterations: int = 50,
642
+ trajectories_per_iter: int = 10,
643
+ n_epochs: int = 4,
644
+ max_steps: int = 10,
645
+ learning_rate: float = 3e-5,
646
+ clip_epsilon: float = 0.2,
647
+ entropy_coef: float = 0.01,
648
+ gamma: float = 0.99,
649
+ eval_every: int = 5,
650
+ task_levels: Optional[List[str]] = None,
651
+ curriculum_weighted_sampling: bool = True,
652
+ reward_profile: str = "full",
653
+ ):
654
+ print("Loading model...")
655
+ model, tokenizer = load_model()
656
+
657
+ # NEW: Sanity check before any training
658
+ if not test_model_sanity(model, tokenizer):
659
+ print("\n❌ Model sanity check failed – cannot proceed.")
660
+ return
661
+
662
+ # NEW: Supervised warm-up to teach JSON format
663
+ supervised_warmup(model, tokenizer, n_examples=500, epochs=2)
664
+
665
+ optimizer = AdamW(model.parameters(), lr=learning_rate)
666
+ env = CodeReviewEnv(reward_profile=reward_profile)
667
+ if task_levels is None:
668
+ task_levels = list(BUG_DB.keys())
669
+
670
+ print(f"\n{'='*60}")
671
+ print(f"Starting PPO Training")
672
+ print(f"Iterations: {n_iterations}")
673
+ print(f"Trajectories per iteration: {trajectories_per_iter}")
674
+ print(f"PPO epochs: {n_epochs}")
675
+ print(f"Reward profile: {reward_profile}")
676
+ print(f"{'='*60}\n")
677
+ reward_history: List[float] = []
678
+ loss_history: List[float] = []
679
+
680
+ for iteration in range(n_iterations):
681
+ print(f"\n--- Iteration {iteration + 1}/{n_iterations} ---")
682
+ # Optional weighted curriculum:
683
+ # start with easier tasks and smoothly ramp difficulty over training.
684
+ if curriculum_weighted_sampling:
685
+ progress = (iteration + 1) / max(n_iterations, 1)
686
+ easy_w = max(0.15, 0.55 - 0.40 * progress)
687
+ medium_w = max(0.15, 0.25 - 0.10 * progress)
688
+ hard_w = 0.10 + 0.05 * progress
689
+ harder_w = 0.05 + 0.20 * progress
690
+ hardest_w = 0.05 + 0.25 * progress
691
+ task_weight_map = {
692
+ "easy": easy_w,
693
+ "medium": medium_w,
694
+ "hard": hard_w,
695
+ "harder": harder_w,
696
+ "hardest": hardest_w,
697
+ }
698
+ task_weights = [task_weight_map.get(level, 1.0) for level in task_levels]
699
+ else:
700
+ task_weights = None
701
+
702
+ print("Collecting trajectories...")
703
+ trajectories = collect_trajectories(
704
+ env,
705
+ model,
706
+ tokenizer,
707
+ trajectories_per_iter,
708
+ max_steps,
709
+ task_levels=task_levels,
710
+ task_weights=task_weights,
711
+ )
712
+
713
+ avg_reward = np.mean([sum(t.rewards) for t in trajectories])
714
+ avg_length = np.mean([len(t) for t in trajectories])
715
+ reward_history.append(float(avg_reward))
716
+
717
+ print(f"Avg reward: {avg_reward:.3f}")
718
+ print(f"Avg length: {avg_length:.1f}")
719
+
720
+ print("Updating policy...")
721
+ metrics = ppo_update(
722
+ trajectories,
723
+ model,
724
+ tokenizer,
725
+ optimizer,
726
+ n_epochs=n_epochs,
727
+ clip_epsilon=clip_epsilon,
728
+ entropy_coef=entropy_coef,
729
+ gamma=gamma,
730
+ )
731
+
732
+ print(f"Loss: {metrics['loss']:.4f}")
733
+ print(f"Policy loss: {metrics['policy_loss']:.4f}")
734
+ print(f"Entropy: {metrics['entropy']:.4f}")
735
+ loss_history.append(float(metrics["loss"]))
736
+
737
+ if (iteration + 1) % eval_every == 0:
738
+ print("\nEvaluating policy...")
739
+ eval_metrics = evaluate_policy(env, model, tokenizer, n_episodes=10)
740
+ print(f"Eval avg reward: {eval_metrics['avg_reward']:.3f} ± {eval_metrics['std_reward']:.3f}")
741
+ print(f"Eval success rate: {eval_metrics['success_rate']:.2%}")
742
+ print(f"Eval avg length: {eval_metrics['avg_length']:.1f}")
743
+
744
+ print("\n" + "="*60)
745
+ print("Training complete. Saving model...")
746
+ model.save_pretrained("ppo_final_model")
747
+ tokenizer.save_pretrained("ppo_final_model")
748
+ print("Model saved to ppo_final_model/")
749
+
750
+ # Save training curves for quick before/after comparisons.
751
+ # These are intentionally simple line plots to avoid extra dependencies.
752
+ if reward_history:
753
+ plt.figure(figsize=(8, 4))
754
+ plt.plot(range(1, len(reward_history) + 1), reward_history, marker="o")
755
+ plt.title("Average Reward per Iteration")
756
+ plt.xlabel("Iteration")
757
+ plt.ylabel("Average Reward")
758
+ plt.grid(alpha=0.3)
759
+ plt.tight_layout()
760
+ plt.savefig("reward_curve.png", dpi=150)
761
+ plt.close()
762
+
763
+ if loss_history:
764
+ plt.figure(figsize=(8, 4))
765
+ plt.plot(range(1, len(loss_history) + 1), loss_history, marker="o", color="tab:red")
766
+ plt.title("Training Loss per Iteration")
767
+ plt.xlabel("Iteration")
768
+ plt.ylabel("Loss")
769
+ plt.grid(alpha=0.3)
770
+ plt.tight_layout()
771
+ plt.savefig("loss_curve.png", dpi=150)
772
+ plt.close()
773
+
774
+ if os.path.exists("reward_curve.png") and os.path.exists("loss_curve.png"):
775
+ print("Saved reward_curve.png and loss_curve.png")
776
+ print("="*60)
777
+
778
+ # ======================================================================
779
+ # 14. ENTRY POINT (unchanged)
780
+ # ======================================================================
781
+ if __name__ == "__main__":
782
+ train_ppo(
783
+ n_iterations=50,
784
+ trajectories_per_iter=10,
785
+ n_epochs=4,
786
+ max_steps=10,
787
+ learning_rate=3e-5,
788
+ clip_epsilon=0.2,
789
+ entropy_coef=0.01,
790
+ gamma=0.99,
791
+ eval_every=5,
792
+ )