100XZX001 commited on
Commit
1588266
·
verified ·
1 Parent(s): a9cad0e

Upload 16 files

Browse files
Files changed (15) hide show
  1. LICENSE +21 -0
  2. README.md +118 -103
  3. __init__.py +16 -16
  4. author.py +219 -219
  5. bugs.json +127 -127
  6. client.py +4 -4
  7. environment.py +613 -556
  8. grader.py +141 -147
  9. models.py +87 -114
  10. openenv.yaml +135 -135
  11. pyproject.toml +29 -29
  12. redteam.py +274 -274
  13. rubrics.py +136 -123
  14. test_runner.py +208 -181
  15. training.py +792 -708
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 YUTA
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,103 +1,118 @@
1
- ---
2
- title: Code Review Professional Workflow
3
- emoji: 🔥
4
- colorFrom: blue
5
- colorTo: purple
6
- sdk: docker
7
- app_port: 7860
8
- pinned: false
9
- ---
10
-
11
- # Code Review Professional Workflow
12
-
13
-
14
- "Multi‑turn code review environment for professional‑level bug fixing. "
15
- "The agent must inspect, test, lint, query documentation, and negotiate with "
16
- "a simulated (persona‑driven) author to get a fix accepted. "
17
- "Includes 25 bugs across 5 difficulty levels, AST‑based injection, "
18
- "a reward‑shaping system, and curriculum learning. "
19
- "Designed for RL training (PPO, DPO, or any policy‑gradient method)
20
-
21
- ## Quick Start
22
-
23
- ```python
24
- from environment import CodeReviewEnv
25
- env = CodeReviewEnv()
26
- obs = env.reset()
27
- print(obs.code_snippet)
28
- ```
29
-
30
- ## Environment Endpoints
31
-
32
- - `POST /reset` – reset environment (optional `task` parameter)
33
- - `POST /step` – take an action (JSON)
34
- - `GET /state` – get full environment state
35
- - `GET /health` health check
36
- - `GET /metadata` – environment metadata
37
- - `GET /schema` action/observation schemas
38
- - `POST /mcp` minimal MCP endpoint
39
-
40
- ## Tasks
41
- ## 🐛 Bug Taxonomy (25 bugs across 5 difficulty levels)
42
-
43
- The **RedTeam** randomly selects one bug from the current difficulty level at the start of every episode.
44
- Your agent must figure out what’s broken, gather evidence, and convince the simulated author – or it won’t stick.
45
-
46
- ### 🟢 Easy – Null‑Checks & Simple Logic Errors
47
-
48
- | # | Bug ID | What’s wrong | Injection method |
49
- |---|--------|--------------|------------------|
50
- | 1 | `null_check` | Missing `if key in dict:` guard → KeyError | AST: remove the if‑statement |
51
- | 2 | `simple_typo` | Misspelled variable `users` `usres` | AST: rename variable |
52
- | 3 | `string_index` | String index shifted by +1 | AST: change constant in index |
53
- | 4 | `default_value` | `dict.get(key)` used without a fallback | AST: replace `dict.get(key)` with `dict[key]` |
54
- | 5 | `empty_return` | Function returns `None` prematurely | AST: insert `return None` early |
55
-
56
- ### 🟡 Medium Off‑By‑One, Loop Logic & Simple Arithmetic
57
-
58
- | # | Bug ID | What’s wrong | Injection method |
59
- |---|--------|--------------|------------------|
60
- | 6 | `off_by_one` | `range(x)` becomes `range(1, x-1)` – skips first & last | AST: modify range arguments |
61
- | 7 | `loop_skip` | `range(len(arr))` becomes `range(len(arr)-1)` misses last element | AST: change range length |
62
- | 8 | `sign_error` | `sum += item` turned into `sum -= item` | AST: swap Add / Sub |
63
- | 9 | `swap_args` | Function arguments swapped | AST: swap first two arguments |
64
- |10 | `uninitialised_var` | Variable used before assignment in a loop | AST: remove the assignment statement |
65
-
66
- ### 🟠 Hard Division‑By‑Zero, Floating‑Point & Edge Cases
67
-
68
- | # | Bug ID | What’s wrong | Injection method |
69
- |---|--------|--------------|------------------|
70
- |11 | `division_by_zero_empty` | Empty‑list guard removed before averaging | AST: delete `if not data:` |
71
- |12 | `division_by_zero_zero` | Denominator check removed | AST: remove the zero‑check |
72
- |13 | `float_precision` | True division `/` replaced by integer division `//` | AST: change Div → FloorDiv |
73
- |14 | `abs_usage` | `abs()` call removed when comparing differences | AST: delete `abs()` wrapper |
74
- |15 | `round_error` | `round()` placed too early, causing precision drift | AST: inject `round()` prematurely |
75
-
76
- ### 🔴 HarderRace Conditions & Atomicity Bugs
77
-
78
- | # | Bug ID | What’s wrong | Injection method |
79
- |---|--------|--------------|------------------|
80
- |16 | `missing_lock` | Shared counter incremented without a lock | Template: remove `with lock:` |
81
- |17 | `double_lock` | Acquiring the same lock twice → deadlock risk | Template: add extra `lock.acquire()` |
82
- |18 | `global_nonatomic` | `count = count + 1` (read‑modify‑write) instead of `+=` | AST: modify assignment node |
83
- |19 | `thread_safe_list` | List append across threads without synchronisation | Template: remove lock from list operation |
84
- |20 | `volatile_read` | Shared flag read outside a lock → stale value | Template: remove synchronisation block |
85
-
86
- ### Hardest Deadlocks, Ordering & Complex Concurrency
87
-
88
- | # | Bug ID | What’s wrong | Injection method |
89
- |---|--------|--------------|------------------|
90
- |21 | `deadlock_order` | Locks acquired in opposite order in two threads | Template: swap lock order |
91
- |22 | `nested_lock_timeout` | `lock.acquire()` without a timeout → permanent hang | Template: remove timeout logic |
92
- |23 | `fork_join` | Thread started but not joined (`join()` missing) | AST: remove `thread.join()` |
93
- |24 | `mutex_release` | Lock released by a thread that never acquired it | Template: incorrect release logic |
94
- |25 | `race_on_init` | Shared resource initialised after threads have started | Template: move initialisation after `join()` |
95
- ## Deployment
96
-
97
- ```bash
98
- openenv push
99
- ```
100
-
101
- ## License
102
-
103
- MIT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Code Review Professional Workflow
3
+ emoji: 🔥
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ ---
10
+
11
+ # Code Review Professional Workflow
12
+
13
+ This project is a multi-turn RL environment where an agent plays the role of a senior code reviewer.
14
+ Instead of just patching code, the agent must gather evidence (`inspect`, `run_tests`, `run_linter`,
15
+ `query_docs`) and convince a simulated developer persona to accept the fix.
16
+
17
+ ### Why this environment is interesting
18
+
19
+ - It combines **technical correctness** (tests/lint) with **human acceptance** (negotiation).
20
+ - It includes **25 injected bug types** across 5 difficulty levels via `RedTeam`.
21
+ - It supports both a **full reward profile** (rich shaping) and a **core reward profile**
22
+ (minimal, baseline-friendly signal for ablations).
23
+
24
+ ## Quick Start
25
+
26
+ ```python
27
+ from environment import CodeReviewEnv
28
+ env = CodeReviewEnv(task="easy", reward_profile="full")
29
+ obs = env.reset()
30
+ print(obs.code_snippet)
31
+ ```
32
+
33
+ ## Demo Script (Non-Technical Friendly)
34
+
35
+ Use this 60-90 second flow in a demo:
36
+
37
+ 1. Reset on `easy` and show the buggy snippet.
38
+ 2. Take `inspect` and `run_tests` actions to show evidence gathering.
39
+ 3. Ask `query_docs` once to show retrieval-assisted reasoning.
40
+ 4. Propose a fix and show accepted/denied feedback from the author persona.
41
+ 5. Repeat once on `harder` to show increased challenge.
42
+
43
+ Message for audience: "The agent is learning not only to fix code, but to justify and communicate the fix."
44
+
45
+ ## Environment Endpoints
46
+
47
+ - `POST /reset` – reset environment (optional `task` parameter)
48
+ - `POST /step` take an action (JSON)
49
+ - `GET /state` – get full environment state
50
+ - `GET /health` health check
51
+ - `GET /metadata` environment metadata
52
+ - `GET /schema` action/observation schemas
53
+ - `POST /mcp` minimal MCP endpoint
54
+
55
+ ## Tasks
56
+ ## 🐛 Bug Taxonomy (25 bugs across 5 difficulty levels)
57
+
58
+ The **RedTeam** randomly selects one bug from the current difficulty level at the start of every episode.
59
+ Your agent must figure out what’s broken, gather evidence, and convince the simulated author – or it won’t stick.
60
+
61
+ ### 🟢 EasyNull‑Checks & Simple Logic Errors
62
+
63
+ | # | Bug ID | What’s wrong | Injection method |
64
+ |---|--------|--------------|------------------|
65
+ | 1 | `null_check` | Missing `if key in dict:` guard → KeyError | AST: remove the if‑statement |
66
+ | 2 | `simple_typo` | Misspelled variable `users` → `usres` | AST: rename variable |
67
+ | 3 | `string_index` | String index shifted by +1 | AST: change constant in index |
68
+ | 4 | `default_value` | `dict.get(key)` used without a fallback | AST: replace `dict.get(key)` with `dict[key]` |
69
+ | 5 | `empty_return` | Function returns `None` prematurely | AST: insert `return None` early |
70
+
71
+ ### 🟡 Medium Off‑By‑One, Loop Logic & Simple Arithmetic
72
+
73
+ | # | Bug ID | What’s wrong | Injection method |
74
+ |---|--------|--------------|------------------|
75
+ | 6 | `off_by_one` | `range(x)` becomes `range(1, x-1)` – skips first & last | AST: modify range arguments |
76
+ | 7 | `loop_skip` | `range(len(arr))` becomes `range(len(arr)-1)` misses last element | AST: change range length |
77
+ | 8 | `sign_error` | `sum += item` turned into `sum -= item` | AST: swap Add / Sub |
78
+ | 9 | `swap_args` | Function arguments swapped | AST: swap first two arguments |
79
+ |10 | `uninitialised_var` | Variable used before assignment in a loop | AST: remove the assignment statement |
80
+
81
+ ### 🟠 Hard Division‑By‑Zero, Floating‑Point & Edge Cases
82
+
83
+ | # | Bug ID | What’s wrong | Injection method |
84
+ |---|--------|--------------|------------------|
85
+ |11 | `division_by_zero_empty` | Empty‑list guard removed before averaging | AST: delete `if not data:` |
86
+ |12 | `division_by_zero_zero` | Denominator check removed | AST: remove the zero‑check |
87
+ |13 | `float_precision` | True division `/` replaced by integer division `//` | AST: change Div → FloorDiv |
88
+ |14 | `abs_usage` | `abs()` call removed when comparing differences | AST: delete `abs()` wrapper |
89
+ |15 | `round_error` | `round()` placed too early, causing precision drift | AST: inject `round()` prematurely |
90
+
91
+ ### 🔴 Harder Race Conditions & Atomicity Bugs
92
+
93
+ | # | Bug ID | What’s wrong | Injection method |
94
+ |---|--------|--------------|------------------|
95
+ |16 | `missing_lock` | Shared counter incremented without a lock | Template: remove `with lock:` |
96
+ |17 | `double_lock` | Acquiring the same lock twice → deadlock risk | Template: add extra `lock.acquire()` |
97
+ |18 | `global_nonatomic` | `count = count + 1` (read‑modify‑write) instead of `+=` | AST: modify assignment node |
98
+ |19 | `thread_safe_list` | List append across threads without synchronisation | Template: remove lock from list operation |
99
+ |20 | `volatile_read` | Shared flag read outside a lock → stale value | Template: remove synchronisation block |
100
+
101
+ ### ⚫ Hardest – Deadlocks, Ordering & Complex Concurrency
102
+
103
+ | # | Bug ID | What’s wrong | Injection method |
104
+ |---|--------|--------------|------------------|
105
+ |21 | `deadlock_order` | Locks acquired in opposite order in two threads | Template: swap lock order |
106
+ |22 | `nested_lock_timeout` | `lock.acquire()` without a timeout → permanent hang | Template: remove timeout logic |
107
+ |23 | `fork_join` | Thread started but not joined (`join()` missing) | AST: remove `thread.join()` |
108
+ |24 | `mutex_release` | Lock released by a thread that never acquired it | Template: incorrect release logic |
109
+ |25 | `race_on_init` | Shared resource initialised after threads have started | Template: move initialisation after `join()` |
110
+ ## Deployment
111
+
112
+ ```bash
113
+ openenv push
114
+ ```
115
+
116
+ ## License
117
+
118
+ MIT
__init__.py CHANGED
@@ -1,16 +1,16 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """Criticrl Environment."""
8
-
9
- from .client import CriticrlEnv
10
- from .models import CriticrlAction, CriticrlObservation
11
-
12
- __all__ = [
13
- "CriticrlAction",
14
- "CriticrlObservation",
15
- "CriticrlEnv",
16
- ]
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Criticrl Environment."""
8
+
9
+ from .client import CriticrlEnv
10
+ from .models import CriticrlAction, CriticrlObservation
11
+
12
+ __all__ = [
13
+ "CriticrlAction",
14
+ "CriticrlObservation",
15
+ "CriticrlEnv",
16
+ ]
author.py CHANGED
@@ -1,219 +1,219 @@
1
- # cell 7 author.py – Final production version: stateful, evidence-driven, belief tracking
2
-
3
- import re
4
- import ast
5
- from dataclasses import dataclass, field
6
- from typing import List, Dict, Any, Optional
7
-
8
- @dataclass
9
- class PersonaAuthor:
10
- """
11
- Simulates a human developer with:
12
- - Continuous belief (confidence)
13
- - Evidence-based reasoning
14
- - Conversation memory
15
- - Code inspection awareness
16
- """
17
-
18
- personality: str = "defensive" # defensive | junior | collaborative
19
- max_persuasion_rounds: int = 5
20
-
21
- # Evidence weights
22
- weight_test_pass: float = 0.5
23
- weight_lint_clean: float = 0.2
24
- weight_doc_found: float = 0.15
25
- weight_explanation_quality: float = 0.15
26
-
27
- # Personality thresholds
28
- thresholds: Dict[str, float] = field(default_factory=lambda: {
29
- "defensive": 0.7,
30
- "junior": 0.3,
31
- "collaborative": 0.5,
32
- })
33
-
34
- # Internal state
35
- _confidence: float = 0.0
36
- _conversation: List[Dict[str, Any]] = field(default_factory=list)
37
- _pushback_count: int = 0
38
- _last_evidence_score: float = 0.0
39
- _stagnation_counter: int = 0
40
-
41
- # ------------------------------------------------------------------
42
- # Lifecycle
43
- # ------------------------------------------------------------------
44
- def __post_init__(self):
45
- self.reset()
46
-
47
- def reset(self):
48
- self._confidence = 0.0
49
- self._conversation.clear()
50
- self._pushback_count = 0
51
- self._last_evidence_score = 0.0
52
- self._stagnation_counter = 0
53
-
54
- # ------------------------------------------------------------------
55
- # Main interaction
56
- # ------------------------------------------------------------------
57
- # Added weight for code change magnitude
58
- weight_code_change: float = 0.1 # small change is better
59
-
60
- def respond(self,
61
- agent_comment: str = "",
62
- agent_question: str = "",
63
- test_results: Optional[str] = None,
64
- lint_results: Optional[str] = None,
65
- doc_results: Optional[str] = None,
66
- proposed_fix: Optional[str] = None,
67
- original_code: Optional[str] = None) -> str:
68
-
69
- # Store conversation
70
- self._conversation.append({
71
- "comment": agent_comment,
72
- "question": agent_question,
73
- "test": test_results,
74
- "lint": lint_results,
75
- "docs": doc_results
76
- })
77
-
78
- # Extract structured evidence
79
- evidence = self._extract_evidence(test_results, lint_results, doc_results)
80
-
81
- # Code inspection
82
- code_change = 0.0
83
- if proposed_fix and original_code:
84
- code_change = self._inspect_code(proposed_fix, original_code)
85
- evidence["code_change"] = code_change
86
-
87
- # Explanation score
88
- text = (agent_comment + " " + agent_question).lower()
89
- explanation_score = self._score_explanation(text)
90
-
91
- # Compute evidence score – now includes code change penalty (1 - change)
92
- evidence_score = (
93
- self.weight_test_pass * evidence.get("test_pass_ratio", 0.0) +
94
- self.weight_lint_clean * (1 - min(1.0, evidence.get("lint_errors", 0)/10)) +
95
- self.weight_doc_found * (1.0 if evidence.get("doc_found") else 0.0) +
96
- self.weight_explanation_quality * explanation_score +
97
- self.weight_code_change * (1.0 - code_change) # surgical fix rewarded
98
- )
99
-
100
- evidence_score = max(0.0, min(1.0, evidence_score))
101
-
102
- # Detect improvement
103
- delta = evidence_score - self._last_evidence_score
104
- self._last_evidence_score = evidence_score
105
-
106
- if delta > 0.05:
107
- self._stagnation_counter = 0
108
- else:
109
- self._stagnation_counter += 1
110
-
111
- # Update belief (momentum)
112
- lr = 0.3
113
- self._confidence = (1 - lr) * self._confidence + lr * evidence_score
114
-
115
- # Penalise stagnation
116
- if self._stagnation_counter >= 2:
117
- self._confidence *= 0.9
118
-
119
- # Decision
120
- threshold = self.thresholds.get(self.personality, 0.5)
121
-
122
- if self._confidence >= threshold or self._pushback_count >= self.max_persuasion_rounds:
123
- return "Alright, I'm convinced. Let's proceed with your fix."
124
-
125
- # Otherwise push back
126
- self._pushback_count += 1
127
- return self._generate_pushback(evidence, text)
128
-
129
- # ------------------------------------------------------------------
130
- # Evidence extraction
131
- # ------------------------------------------------------------------
132
- def _extract_evidence(self, test_results, lint_results, doc_results):
133
- evidence = {
134
- "test_pass_ratio": 0.0,
135
- "lint_errors": 0,
136
- "doc_found": False
137
- }
138
-
139
- # Parse test results
140
- if test_results:
141
- match = re.search(r'(\d+)\s*/\s*(\d+)', test_results)
142
- if match:
143
- p, t = int(match.group(1)), int(match.group(2))
144
- evidence["test_pass_ratio"] = p / t if t else 0.0
145
- elif "true" in test_results.lower():
146
- evidence["test_pass_ratio"] = 1.0
147
- elif "false" in test_results.lower():
148
- evidence["test_pass_ratio"] = 0.0
149
-
150
- # Lint errors
151
- if lint_results:
152
- evidence["lint_errors"] = len(re.findall(r'error', lint_results.lower()))
153
-
154
- # Docs
155
- if doc_results and "no relevant" not in doc_results.lower():
156
- evidence["doc_found"] = True
157
-
158
- return evidence
159
-
160
- # ------------------------------------------------------------------
161
- # Explanation scoring
162
- # ------------------------------------------------------------------
163
- def _score_explanation(self, text: str) -> float:
164
- score = 0.0
165
-
166
- if "because" in text or "therefore" in text:
167
- score += 0.3
168
- if "test" in text or "example" in text:
169
- score += 0.2
170
- if len(text.split()) > 30:
171
- score += 0.2
172
- if "error" in text or "fix" in text:
173
- score += 0.1
174
-
175
- return min(1.0, score)
176
-
177
- # ------------------------------------------------------------------
178
- # Code inspection
179
- # ------------------------------------------------------------------
180
- def _inspect_code(self, new_code: str, old_code: str) -> float:
181
- try:
182
- t1 = ast.parse(old_code)
183
- t2 = ast.parse(new_code)
184
-
185
- n1 = len(list(ast.walk(t1)))
186
- n2 = len(list(ast.walk(t2)))
187
-
188
- change = abs(n2 - n1) / max(n1, 1)
189
- return min(1.0, change)
190
- except:
191
- return 0.0
192
-
193
- # ------------------------------------------------------------------
194
- # Pushback generator
195
- # ------------------------------------------------------------------
196
- def _generate_pushback(self, evidence, text):
197
- if evidence["test_pass_ratio"] < 0.5:
198
- return "Tests are still failing. Show a passing case."
199
-
200
- if evidence["lint_errors"] > 0:
201
- return f"There are {evidence['lint_errors']} lint errors. Fix them."
202
-
203
- if not evidence["doc_found"]:
204
- return "Provide documentation or reference."
205
-
206
- if "because" not in text:
207
- return "Explain why this works."
208
-
209
- if len(text.split()) < 20:
210
- return "Too brief. Expand your reasoning."
211
-
212
- return "Not convinced yet. Give a concrete example."
213
-
214
- # ------------------------------------------------------------------
215
- # Score
216
- # ------------------------------------------------------------------
217
- def get_negotiation_score(self) -> float:
218
- penalty = 0.1 * min(3, self._pushback_count)
219
- return max(0.0, min(1.0, self._confidence - penalty))
 
1
+ # cell 7 author.py – Final production version: stateful, evidence-driven, belief tracking
2
+
3
+ import re
4
+ import ast
5
+ from dataclasses import dataclass, field
6
+ from typing import List, Dict, Any, Optional
7
+
8
+ @dataclass
9
+ class PersonaAuthor:
10
+ """
11
+ Simulates a human developer with:
12
+ - Continuous belief (confidence)
13
+ - Evidence-based reasoning
14
+ - Conversation memory
15
+ - Code inspection awareness
16
+ """
17
+
18
+ personality: str = "defensive" # defensive | junior | collaborative
19
+ max_persuasion_rounds: int = 5
20
+
21
+ # Evidence weights
22
+ weight_test_pass: float = 0.5
23
+ weight_lint_clean: float = 0.2
24
+ weight_doc_found: float = 0.15
25
+ weight_explanation_quality: float = 0.15
26
+
27
+ # Personality thresholds
28
+ thresholds: Dict[str, float] = field(default_factory=lambda: {
29
+ "defensive": 0.7,
30
+ "junior": 0.3,
31
+ "collaborative": 0.5,
32
+ })
33
+
34
+ # Internal state
35
+ _confidence: float = 0.0
36
+ _conversation: List[Dict[str, Any]] = field(default_factory=list)
37
+ _pushback_count: int = 0
38
+ _last_evidence_score: float = 0.0
39
+ _stagnation_counter: int = 0
40
+
41
+ # ------------------------------------------------------------------
42
+ # Lifecycle
43
+ # ------------------------------------------------------------------
44
+ def __post_init__(self):
45
+ self.reset()
46
+
47
+ def reset(self):
48
+ self._confidence = 0.0
49
+ self._conversation.clear()
50
+ self._pushback_count = 0
51
+ self._last_evidence_score = 0.0
52
+ self._stagnation_counter = 0
53
+
54
+ # ------------------------------------------------------------------
55
+ # Main interaction
56
+ # ------------------------------------------------------------------
57
+ # Added weight for code change magnitude
58
+ weight_code_change: float = 0.1 # small change is better
59
+
60
+ def respond(self,
61
+ agent_comment: str = "",
62
+ agent_question: str = "",
63
+ test_results: Optional[str] = None,
64
+ lint_results: Optional[str] = None,
65
+ doc_results: Optional[str] = None,
66
+ proposed_fix: Optional[str] = None,
67
+ original_code: Optional[str] = None) -> str:
68
+
69
+ # Store conversation
70
+ self._conversation.append({
71
+ "comment": agent_comment,
72
+ "question": agent_question,
73
+ "test": test_results,
74
+ "lint": lint_results,
75
+ "docs": doc_results
76
+ })
77
+
78
+ # Extract structured evidence
79
+ evidence = self._extract_evidence(test_results, lint_results, doc_results)
80
+
81
+ # Code inspection
82
+ code_change = 0.0
83
+ if proposed_fix and original_code:
84
+ code_change = self._inspect_code(proposed_fix, original_code)
85
+ evidence["code_change"] = code_change
86
+
87
+ # Explanation score
88
+ text = (agent_comment + " " + agent_question).lower()
89
+ explanation_score = self._score_explanation(text)
90
+
91
+ # Compute evidence score – now includes code change penalty (1 - change)
92
+ evidence_score = (
93
+ self.weight_test_pass * evidence.get("test_pass_ratio", 0.0) +
94
+ self.weight_lint_clean * (1 - min(1.0, evidence.get("lint_errors", 0)/10)) +
95
+ self.weight_doc_found * (1.0 if evidence.get("doc_found") else 0.0) +
96
+ self.weight_explanation_quality * explanation_score +
97
+ self.weight_code_change * (1.0 - code_change) # surgical fix rewarded
98
+ )
99
+
100
+ evidence_score = max(0.0, min(1.0, evidence_score))
101
+
102
+ # Detect improvement
103
+ delta = evidence_score - self._last_evidence_score
104
+ self._last_evidence_score = evidence_score
105
+
106
+ if delta > 0.05:
107
+ self._stagnation_counter = 0
108
+ else:
109
+ self._stagnation_counter += 1
110
+
111
+ # Update belief (momentum)
112
+ lr = 0.3
113
+ self._confidence = (1 - lr) * self._confidence + lr * evidence_score
114
+
115
+ # Penalise stagnation
116
+ if self._stagnation_counter >= 2:
117
+ self._confidence *= 0.9
118
+
119
+ # Decision
120
+ threshold = self.thresholds.get(self.personality, 0.5)
121
+
122
+ if self._confidence >= threshold or self._pushback_count >= self.max_persuasion_rounds:
123
+ return "Alright, I'm convinced. Let's proceed with your fix."
124
+
125
+ # Otherwise push back
126
+ self._pushback_count += 1
127
+ return self._generate_pushback(evidence, text)
128
+
129
+ # ------------------------------------------------------------------
130
+ # Evidence extraction
131
+ # ------------------------------------------------------------------
132
+ def _extract_evidence(self, test_results, lint_results, doc_results):
133
+ evidence = {
134
+ "test_pass_ratio": 0.0,
135
+ "lint_errors": 0,
136
+ "doc_found": False
137
+ }
138
+
139
+ # Parse test results
140
+ if test_results:
141
+ match = re.search(r'(\d+)\s*/\s*(\d+)', test_results)
142
+ if match:
143
+ p, t = int(match.group(1)), int(match.group(2))
144
+ evidence["test_pass_ratio"] = p / t if t else 0.0
145
+ elif "true" in test_results.lower():
146
+ evidence["test_pass_ratio"] = 1.0
147
+ elif "false" in test_results.lower():
148
+ evidence["test_pass_ratio"] = 0.0
149
+
150
+ # Lint errors
151
+ if lint_results:
152
+ evidence["lint_errors"] = len(re.findall(r'error', lint_results.lower()))
153
+
154
+ # Docs
155
+ if doc_results and "no relevant" not in doc_results.lower():
156
+ evidence["doc_found"] = True
157
+
158
+ return evidence
159
+
160
+ # ------------------------------------------------------------------
161
+ # Explanation scoring
162
+ # ------------------------------------------------------------------
163
+ def _score_explanation(self, text: str) -> float:
164
+ score = 0.0
165
+
166
+ if "because" in text or "therefore" in text:
167
+ score += 0.3
168
+ if "test" in text or "example" in text:
169
+ score += 0.2
170
+ if len(text.split()) > 30:
171
+ score += 0.2
172
+ if "error" in text or "fix" in text:
173
+ score += 0.1
174
+
175
+ return min(1.0, score)
176
+
177
+ # ------------------------------------------------------------------
178
+ # Code inspection
179
+ # ------------------------------------------------------------------
180
+ def _inspect_code(self, new_code: str, old_code: str) -> float:
181
+ try:
182
+ t1 = ast.parse(old_code)
183
+ t2 = ast.parse(new_code)
184
+
185
+ n1 = len(list(ast.walk(t1)))
186
+ n2 = len(list(ast.walk(t2)))
187
+
188
+ change = abs(n2 - n1) / max(n1, 1)
189
+ return min(1.0, change)
190
+ except:
191
+ return 0.0
192
+
193
+ # ------------------------------------------------------------------
194
+ # Pushback generator
195
+ # ------------------------------------------------------------------
196
+ def _generate_pushback(self, evidence, text):
197
+ if evidence["test_pass_ratio"] < 0.5:
198
+ return "Tests are still failing. Show a passing case."
199
+
200
+ if evidence["lint_errors"] > 0:
201
+ return f"There are {evidence['lint_errors']} lint errors. Fix them."
202
+
203
+ if not evidence["doc_found"]:
204
+ return "Provide documentation or reference."
205
+
206
+ if "because" not in text:
207
+ return "Explain why this works."
208
+
209
+ if len(text.split()) < 20:
210
+ return "Too brief. Expand your reasoning."
211
+
212
+ return "Not convinced yet. Give a concrete example."
213
+
214
+ # ------------------------------------------------------------------
215
+ # Score
216
+ # ------------------------------------------------------------------
217
+ def get_negotiation_score(self) -> float:
218
+ penalty = 0.1 * min(3, self._pushback_count)
219
+ return max(0.0, min(1.0, self._confidence - penalty))
bugs.json CHANGED
@@ -1,127 +1,127 @@
1
- {
2
- "easy": {
3
- "null_check": {
4
- "type": "ast",
5
- "bug_type": "null_check",
6
- "oracle_hint": "Add back the if-guard that was removed"
7
- },
8
- "simple_typo": {
9
- "type": "ast",
10
- "bug_type": "simple_typo",
11
- "oracle_hint": "Fix the misspelled variable name"
12
- },
13
- "string_index": {
14
- "type": "ast",
15
- "bug_type": "string_index",
16
- "oracle_hint": "Correct the index offset"
17
- },
18
- "default_value": {
19
- "type": "ast",
20
- "bug_type": "default_value",
21
- "oracle_hint": "Restore dict.get() with proper default"
22
- },
23
- "empty_return": {
24
- "type": "ast",
25
- "bug_type": "empty_return",
26
- "oracle_hint": "Remove the premature return None"
27
- }
28
- },
29
- "medium": {
30
- "off_by_one": {
31
- "type": "ast",
32
- "bug_type": "off_by_one"
33
- },
34
- "loop_skip": {
35
- "type": "ast",
36
- "bug_type": "loop_skip"
37
- },
38
- "sign_error": {
39
- "type": "ast",
40
- "bug_type": "sign_error"
41
- },
42
- "swap_args": {
43
- "type": "ast",
44
- "bug_type": "swap_args"
45
- },
46
- "uninitialised_var": {
47
- "type": "ast",
48
- "bug_type": "uninitialised_var"
49
- }
50
- },
51
- "hard": {
52
- "division_by_zero_empty": {
53
- "type": "ast",
54
- "bug_type": "division_by_zero_empty"
55
- },
56
- "division_by_zero_zero": {
57
- "type": "ast",
58
- "bug_type": "division_by_zero_zero"
59
- },
60
- "float_precision": {
61
- "type": "ast",
62
- "bug_type": "float_precision"
63
- },
64
- "abs_usage": {
65
- "type": "ast",
66
- "bug_type": "abs_usage"
67
- },
68
- "round_error": {
69
- "type": "ast",
70
- "bug_type": "round_error"
71
- }
72
- },
73
- "harder": {
74
- "missing_lock": {
75
- "type": "template",
76
- "buggy": "counter = 0\ndef increment():\n global counter\n counter += 1",
77
- "oracle": "counter = 0\nimport threading\nlock = threading.Lock()\ndef increment():\n global counter\n with lock:\n counter += 1"
78
- },
79
- "double_lock": {
80
- "type": "template",
81
- "buggy": "import threading\nlock = threading.Lock()\ndef do_work():\n lock.acquire()\n lock.acquire()\n print('working')\n lock.release()",
82
- "oracle": "import threading\nlock = threading.Lock()\ndef do_work():\n with lock:\n print('working')"
83
- },
84
- "global_nonatomic": {
85
- "type": "template",
86
- "buggy": "count = 0\ndef add():\n global count\n count = count + 1",
87
- "oracle": "count = 0\ndef add():\n global count\n count += 1"
88
- },
89
- "thread_safe_list": {
90
- "type": "template",
91
- "buggy": "import threading\nitems = []\ndef append_item(item):\n items.append(item)",
92
- "oracle": "import threading\nitems = []\nlock = threading.Lock()\ndef append_item(item):\n with lock:\n items.append(item)"
93
- },
94
- "volatile_read": {
95
- "type": "template",
96
- "buggy": "import threading\nstop = False\ndef worker():\n while not stop:\n pass",
97
- "oracle": "import threading\nstop = False\nlock = threading.Lock()\ndef worker():\n while True:\n with lock:\n if stop:\n break"
98
- }
99
- },
100
- "hardest": {
101
- "deadlock_order": {
102
- "type": "template",
103
- "buggy": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock2:\n with lock1:\n pass",
104
- "oracle": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock1:\n with lock2:\n pass"
105
- },
106
- "nested_lock_timeout": {
107
- "type": "template",
108
- "buggy": "import threading\nlock = threading.Lock()\ndef work():\n lock.acquire()\n # critical section\n lock.release()",
109
- "oracle": "import threading\nlock = threading.Lock()\ndef work():\n if lock.acquire(timeout=1):\n try:\n # critical section\n finally:\n lock.release()"
110
- },
111
- "fork_join": {
112
- "type": "template",
113
- "buggy": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()",
114
- "oracle": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()\nt.join()"
115
- },
116
- "mutex_release": {
117
- "type": "template",
118
- "buggy": "import threading\nlock = threading.Lock()\ndef thread_A():\n lock.acquire()\n lock.release()\ndef thread_B():\n lock.release()",
119
- "oracle": "import threading\nlock = threading.Lock()\ndef thread_A():\n with lock:\n pass\ndef thread_B():\n with lock:\n pass"
120
- },
121
- "race_on_init": {
122
- "type": "template",
123
- "buggy": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nprint(items)",
124
- "oracle": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nt.join()\nprint(items)"
125
- }
126
- }
127
- }
 
1
+ {
2
+ "easy": {
3
+ "null_check": {
4
+ "type": "ast",
5
+ "bug_type": "null_check",
6
+ "oracle_hint": "Add back the if-guard that was removed"
7
+ },
8
+ "simple_typo": {
9
+ "type": "ast",
10
+ "bug_type": "simple_typo",
11
+ "oracle_hint": "Fix the misspelled variable name"
12
+ },
13
+ "string_index": {
14
+ "type": "ast",
15
+ "bug_type": "string_index",
16
+ "oracle_hint": "Correct the index offset"
17
+ },
18
+ "default_value": {
19
+ "type": "ast",
20
+ "bug_type": "default_value",
21
+ "oracle_hint": "Restore dict.get() with proper default"
22
+ },
23
+ "empty_return": {
24
+ "type": "ast",
25
+ "bug_type": "empty_return",
26
+ "oracle_hint": "Remove the premature return None"
27
+ }
28
+ },
29
+ "medium": {
30
+ "off_by_one": {
31
+ "type": "ast",
32
+ "bug_type": "off_by_one"
33
+ },
34
+ "loop_skip": {
35
+ "type": "ast",
36
+ "bug_type": "loop_skip"
37
+ },
38
+ "sign_error": {
39
+ "type": "ast",
40
+ "bug_type": "sign_error"
41
+ },
42
+ "swap_args": {
43
+ "type": "ast",
44
+ "bug_type": "swap_args"
45
+ },
46
+ "uninitialised_var": {
47
+ "type": "ast",
48
+ "bug_type": "uninitialised_var"
49
+ }
50
+ },
51
+ "hard": {
52
+ "division_by_zero_empty": {
53
+ "type": "ast",
54
+ "bug_type": "division_by_zero_empty"
55
+ },
56
+ "division_by_zero_zero": {
57
+ "type": "ast",
58
+ "bug_type": "division_by_zero_zero"
59
+ },
60
+ "float_precision": {
61
+ "type": "ast",
62
+ "bug_type": "float_precision"
63
+ },
64
+ "abs_usage": {
65
+ "type": "ast",
66
+ "bug_type": "abs_usage"
67
+ },
68
+ "round_error": {
69
+ "type": "ast",
70
+ "bug_type": "round_error"
71
+ }
72
+ },
73
+ "harder": {
74
+ "missing_lock": {
75
+ "type": "template",
76
+ "buggy": "counter = 0\ndef increment():\n global counter\n counter += 1",
77
+ "oracle": "counter = 0\nimport threading\nlock = threading.Lock()\ndef increment():\n global counter\n with lock:\n counter += 1"
78
+ },
79
+ "double_lock": {
80
+ "type": "template",
81
+ "buggy": "import threading\nlock = threading.Lock()\ndef do_work():\n lock.acquire()\n lock.acquire()\n print('working')\n lock.release()",
82
+ "oracle": "import threading\nlock = threading.Lock()\ndef do_work():\n with lock:\n print('working')"
83
+ },
84
+ "global_nonatomic": {
85
+ "type": "template",
86
+ "buggy": "count = 0\ndef add():\n global count\n count = count + 1",
87
+ "oracle": "count = 0\ndef add():\n global count\n count += 1"
88
+ },
89
+ "thread_safe_list": {
90
+ "type": "template",
91
+ "buggy": "import threading\nitems = []\ndef append_item(item):\n items.append(item)",
92
+ "oracle": "import threading\nitems = []\nlock = threading.Lock()\ndef append_item(item):\n with lock:\n items.append(item)"
93
+ },
94
+ "volatile_read": {
95
+ "type": "template",
96
+ "buggy": "import threading\nstop = False\ndef worker():\n while not stop:\n pass",
97
+ "oracle": "import threading\nstop = False\nlock = threading.Lock()\ndef worker():\n while True:\n with lock:\n if stop:\n break"
98
+ }
99
+ },
100
+ "hardest": {
101
+ "deadlock_order": {
102
+ "type": "template",
103
+ "buggy": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock2:\n with lock1:\n pass",
104
+ "oracle": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock1:\n with lock2:\n pass"
105
+ },
106
+ "nested_lock_timeout": {
107
+ "type": "template",
108
+ "buggy": "import threading\nlock = threading.Lock()\ndef work():\n lock.acquire()\n # critical section\n lock.release()",
109
+ "oracle": "import threading\nlock = threading.Lock()\ndef work():\n if lock.acquire(timeout=1):\n try:\n # critical section\n finally:\n lock.release()"
110
+ },
111
+ "fork_join": {
112
+ "type": "template",
113
+ "buggy": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()",
114
+ "oracle": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()\nt.join()"
115
+ },
116
+ "mutex_release": {
117
+ "type": "template",
118
+ "buggy": "import threading\nlock = threading.Lock()\ndef thread_A():\n lock.acquire()\n lock.release()\ndef thread_B():\n lock.release()",
119
+ "oracle": "import threading\nlock = threading.Lock()\ndef thread_A():\n with lock:\n pass\ndef thread_B():\n with lock:\n pass"
120
+ },
121
+ "race_on_init": {
122
+ "type": "template",
123
+ "buggy": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nprint(items)",
124
+ "oracle": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nt.join()\nprint(items)"
125
+ }
126
+ }
127
+ }
client.py CHANGED
@@ -1,5 +1,5 @@
1
- # client.py – OpenEnv client entry point
2
- from environment import CodeReviewEnv
3
-
4
- # The OpenEnv framework will import this class as the environment.
5
  __all__ = ["CodeReviewEnv"]
 
1
+ # client.py – OpenEnv client entry point
2
+ from environment import CodeReviewEnv
3
+
4
+ # The OpenEnv framework will import this class as the environment.
5
  __all__ = ["CodeReviewEnv"]
environment.py CHANGED
@@ -1,556 +1,613 @@
1
- # environment.py – FULLY CORRECTED RL Environment (TRUE Markov + Fixed Bugs)
2
-
3
- import sys
4
- import subprocess
5
- import tempfile
6
- import os
7
- import re
8
- from dataclasses import dataclass, field
9
- from typing import Tuple, Dict, Any, Optional, List
10
-
11
- from models import (
12
- AnyAction, WriteComment, ProposeFix, Execute, Inspect,
13
- RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
14
- Observation, Reward, State
15
- )
16
- from redteam import RedTeam
17
- from test_runner import TestRunner
18
- from author import PersonaAuthor
19
- from rltool import ToolBox
20
- from rubrics import (
21
- ToolUsageRubric,
22
- TestDeltaRubric,
23
- LintDeltaRubric,
24
- TerminalSuccessRubric,
25
- ExplorationRubric,
26
- AntiHackingRubric,
27
- StepPenaltyRubric,
28
- )
29
-
30
- # ======================================================================
31
- # FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
32
- # ======================================================================
33
- @dataclass
34
- class EnhancedObservation:
35
- code_snippet: str
36
- last_tool_output: str
37
-
38
- current_test_score: float
39
- current_lint_score: float
40
- negotiation_score: float
41
-
42
- previous_test_score: float
43
- previous_lint_score: float
44
-
45
- author_confidence: float
46
- author_threshold: float
47
-
48
- step: int
49
- max_steps: int
50
- progress_ratio: float
51
-
52
- tests_run: bool
53
- linter_run: bool
54
- docs_queried: bool
55
-
56
- last_action_type: str
57
- action_history: List[str]
58
-
59
- done: bool
60
-
61
- bug_description: str
62
- comments_count: int
63
-
64
- # default fields must be at the very end
65
- author_response: str = ""
66
-
67
-
68
- # ======================================================================
69
- # HELPER FUNCTIONS
70
- # ======================================================================
71
- def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
72
- if not code.strip():
73
- return False, "", "Error: Empty code"
74
-
75
- with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
76
- f.write(code)
77
- tmp_path = f.name
78
-
79
- try:
80
- result = subprocess.run(
81
- [sys.executable, tmp_path],
82
- capture_output=True,
83
- text=True,
84
- timeout=timeout_sec
85
- )
86
- success = (result.returncode == 0)
87
- return success, result.stdout, result.stderr
88
- except subprocess.TimeoutExpired:
89
- return False, "", f"Timeout after {timeout_sec}s"
90
- except Exception as e:
91
- return False, "", f"Execution error: {str(e)}"
92
- finally:
93
- try:
94
- os.unlink(tmp_path)
95
- except:
96
- pass
97
-
98
-
99
- # ======================================================================
100
- # ENHANCED CODE REVIEW ENVIRONMENT
101
- # ======================================================================
102
- @dataclass
103
- class CodeReviewEnv:
104
- task: str = "easy"
105
- max_steps: int = 10
106
- step_penalty: float = 0.01
107
-
108
- # Curriculum learning
109
- auto_difficulty: bool = False
110
- success_threshold: float = 0.7
111
-
112
- # Reward shaping parameters
113
- delta_weight: float = 0.3
114
- tool_usage_bonus: float = 0.05
115
- diversity_bonus: float = 0.03
116
-
117
- _red_team: Optional[RedTeam] = field(init=False, default=None)
118
- _author: Optional[PersonaAuthor] = field(init=False, default=None)
119
-
120
- _current_code: str = field(init=False, default="")
121
- _current_bug_id: str = field(init=False, default="")
122
- _bug_description: str = field(init=False, default="")
123
- _oracle_fix: str = field(init=False, default="")
124
-
125
- _comments: list = field(init=False, default_factory=list)
126
- _test_results: Optional[str] = field(init=False, default=None)
127
- _lint_results: Optional[str] = field(init=False, default=None)
128
- _doc_results: Optional[str] = field(init=False, default=None)
129
-
130
- _step_count: int = field(init=False, default=0)
131
- _done: bool = field(init=False, default=False)
132
-
133
- # State tracking for dense rewards
134
- _previous_test_score: float = field(init=False, default=0.0)
135
- _previous_lint_score: float = field(init=False, default=0.0)
136
- _current_test_score: float = field(init=False, default=0.0)
137
- _current_lint_score: float = field(init=False, default=0.0)
138
-
139
- # Tool usage tracking
140
- _tests_run: bool = field(init=False, default=False)
141
- _linter_run: bool = field(init=False, default=False)
142
- _docs_queried: bool = field(init=False, default=False)
143
-
144
- # Action history
145
- _action_history: List[str] = field(init=False, default_factory=list)
146
- _last_action_type: str = field(init=False, default="none")
147
-
148
- # FIXED: Track CUMULATIVE episode reward
149
- _episode_total_reward: float = field(init=False, default=0.0)
150
- _episode_rewards: List[float] = field(init=False, default_factory=list)
151
- _difficulty_level: int = field(init=False, default=0)
152
-
153
- # ===================================================================
154
- def __post_init__(self):
155
- self.set_task(self.task)
156
-
157
- # ===================================================================
158
- def set_task(self, task: str):
159
- if task not in ["easy", "medium", "hard", "harder", "hardest"]:
160
- raise ValueError(f"Unknown task: {task}")
161
-
162
- self.task = task
163
- self._red_team = RedTeam(task)
164
- self._author = PersonaAuthor()
165
- self.rubrics = [
166
- TestDeltaRubric(weight=self.delta_weight),
167
- LintDeltaRubric(weight=self.delta_weight),
168
- ToolUsageRubric(bonus=self.tool_usage_bonus),
169
- TerminalSuccessRubric(),
170
- ExplorationRubric(penalty=-0.05, bonus=self.diversity_bonus * 0.7),
171
- AntiHackingRubric(),
172
- StepPenaltyRubric(penalty=self.step_penalty),
173
- ]
174
-
175
- task_to_level = {
176
- "easy": 0, "medium": 1, "hard": 2,
177
- "harder": 3, "hardest": 4
178
- }
179
- self._difficulty_level = task_to_level[task]
180
-
181
- self._reset_internal()
182
-
183
- # ===================================================================
184
- def _reset_internal(self):
185
- self._step_count = 0 # FIXED
186
- self._comments = []
187
- self._test_results = None
188
- self._lint_results = None
189
- self._doc_results = None
190
- self._done = False
191
-
192
- # Reset state tracking
193
- self._previous_test_score = 0.0
194
- self._previous_lint_score = 0.0
195
- self._current_test_score = 0.0
196
- self._current_lint_score = 0.0
197
-
198
- self._tests_run = False
199
- self._linter_run = False
200
- self._docs_queried = False
201
-
202
- self._action_history = []
203
- self._last_action_type = "none"
204
-
205
- # FIXED: Reset episode cumulative reward
206
- self._episode_total_reward = 0.0
207
-
208
- self._author.reset()
209
-
210
- # Base tasks
211
- if self.task == "easy":
212
- original = "def get_user(id):\n if id in users:\n return users[id]"
213
- elif self.task == "medium":
214
- original = "def process_items(items):\n for item in items:\n print(item)"
215
- elif self.task == "hard":
216
- original = "def average(data):\n if not data:\n return 0\n return sum(data) / len(data)"
217
- elif self.task == "harder":
218
- original = "counter = 0\ndef increment():\n global counter\n with lock:\n counter += 1"
219
- else:
220
- original = "def safe_work():\n with lock1:\n with lock2:\n do_work()"
221
-
222
- buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
223
- self._current_code = buggy_code
224
- self._current_bug_id = bug_id
225
- self._bug_description = desc
226
- self._oracle_fix = oracle
227
- self._comments.append(f"[RedTeam] {desc}")
228
-
229
- # ===================================================================
230
- def reset(self) -> EnhancedObservation:
231
- """Reset with optional curriculum adjustment."""
232
- if self.auto_difficulty and len(self._episode_rewards) > 0:
233
- recent_performance = sum(self._episode_rewards[-5:]) / min(5, len(self._episode_rewards))
234
-
235
- if recent_performance > self.success_threshold and self._difficulty_level < 4:
236
- self._difficulty_level += 1
237
- print(f"[Curriculum] Increasing difficulty to level {self._difficulty_level}")
238
- elif recent_performance < 0.3 and self._difficulty_level > 0:
239
- self._difficulty_level -= 1
240
- print(f"[Curriculum] Decreasing difficulty to level {self._difficulty_level}")
241
-
242
- level_to_task = {0: "easy", 1: "medium", 2: "hard", 3: "harder", 4: "hardest"}
243
- self.task = level_to_task[self._difficulty_level]
244
- self._red_team = RedTeam(self.task)
245
-
246
- self._reset_internal()
247
- return self._get_observation()
248
-
249
- # ===================================================================
250
- def _get_observation(self) -> EnhancedObservation:
251
- """Return COMPLETE Markov state."""
252
- # Compute author response: only after comment/question/fix does the author actually speak
253
- if self._last_action_type in ("write_comment", "ask_question", "propose_fix"):
254
- author_response = self._test_results or ""
255
- else:
256
- author_response = ""
257
-
258
- return EnhancedObservation(
259
- code_snippet=self._current_code,
260
- last_tool_output=self._test_results or "",
261
- author_response=author_response, # now field exists
262
-
263
- current_test_score=self._current_test_score,
264
- current_lint_score=self._current_lint_score,
265
- negotiation_score=self._author.get_negotiation_score(),
266
-
267
- previous_test_score=self._previous_test_score,
268
- previous_lint_score=self._previous_lint_score,
269
-
270
- author_confidence=self._author._confidence,
271
- author_threshold=self._author.thresholds.get(self._author.personality, 0.5),
272
-
273
- step=self._step_count,
274
- max_steps=self.max_steps,
275
- progress_ratio=self._step_count / self.max_steps,
276
-
277
- tests_run=self._tests_run,
278
- linter_run=self._linter_run,
279
- docs_queried=self._docs_queried,
280
-
281
- last_action_type=self._last_action_type,
282
- action_history=self._action_history[-5:],
283
-
284
- done=self._done,
285
-
286
- bug_description=self._bug_description,
287
- comments_count=len(self._comments),
288
- )
289
-
290
- # ===================================================================
291
- def _get_action_type(self, action: AnyAction) -> str:
292
- """Extract action type as string."""
293
- if isinstance(action, RunTests):
294
- return "run_tests"
295
- elif isinstance(action, RunLinter):
296
- return "run_linter"
297
- elif isinstance(action, QueryDocs):
298
- return "query_docs"
299
- elif isinstance(action, Execute):
300
- return "execute"
301
- elif isinstance(action, Inspect):
302
- return "inspect"
303
- elif isinstance(action, WriteComment):
304
- return "write_comment"
305
- elif isinstance(action, AskQuestion):
306
- return "ask_question"
307
- elif isinstance(action, ProposeFix):
308
- return "propose_fix"
309
- elif isinstance(action, Done):
310
- return "done"
311
- elif isinstance(action, Skip):
312
- return "skip"
313
- else:
314
- return "unknown"
315
-
316
- # ===================================================================
317
- def step(self, action: AnyAction) -> Tuple[EnhancedObservation, Reward, bool, Dict[str, Any]]:
318
- """
319
- TRUE RL STEP with:
320
- - Complete Markov observations (no hidden state)
321
- - Dense intermediate rewards
322
- - Delta-based credit assignment (no double-counting)
323
- - Proper episode reward tracking
324
- """
325
- if self._done:
326
- raise RuntimeError("Episode already finished")
327
-
328
- # Store previous metrics for delta computation
329
- self._previous_test_score = self._current_test_score
330
- self._previous_lint_score = self._current_lint_score
331
-
332
- base_reward = 0.0
333
- action_type = self._get_action_type(action)
334
-
335
- # Update action history
336
- self._action_history.append(action_type)
337
- self._last_action_type = action_type
338
-
339
- # ==============================================================
340
- # TOOL ACTIONS
341
- # ==============================================================
342
- if isinstance(action, Execute):
343
- success, stdout, stderr = execute_code(self._current_code)
344
- output = (stdout + stderr).strip() or "No output"
345
- self._test_results = f"[Execute] {'Success' if success else 'Failed'}\n{output[:300]}"
346
- base_reward = 0.001 if success else -0.05
347
-
348
- elif isinstance(action, Inspect):
349
- self._test_results = f"[Inspect]\n{self._current_code[:500]}"
350
- base_reward = 0.001
351
-
352
- elif isinstance(action, RunLinter):
353
- lint_output = ToolBox.run_linter(self._current_code)
354
- self._lint_results = lint_output[:500]
355
- self._test_results = f"[Linter]\n{self._lint_results}"
356
-
357
- self._current_lint_score = self._run_linter_score(self._current_code)
358
- self._linter_run = True
359
- base_reward = 0.002
360
-
361
- elif isinstance(action, RunTests):
362
- runner = TestRunner(self._current_bug_id)
363
- score, output = runner.run_tests(self._current_code)
364
-
365
- self._current_test_score = score
366
- self._tests_run = True
367
-
368
- self._test_results = f"[Tests] Score: {score:.2f}\n{output[:300]}"
369
- base_reward = 0.002
370
-
371
- if score > 0.8:
372
- base_reward += 0.005
373
-
374
- elif isinstance(action, QueryDocs):
375
- doc = ToolBox.query_docs(action.query_topic)
376
- self._doc_results = doc
377
- self._test_results = f"[Docs]\n{doc[:400]}"
378
- self._docs_queried = True
379
- base_reward = 0.001
380
-
381
- # ==============================================================
382
- # COMMUNICATION ACTIONS
383
- # ==============================================================
384
- elif isinstance(action, WriteComment):
385
- self._comments.append(f"Agent: {action.comment_text}")
386
-
387
- response = self._author.respond(
388
- agent_comment=action.comment_text,
389
- test_results=self._test_results,
390
- lint_results=self._lint_results,
391
- doc_results=self._doc_results,
392
- proposed_fix=None,
393
- original_code=self._current_code
394
- )
395
-
396
- self._comments.append(f"Author: {response}")
397
- self._test_results = f"[Comment] Author: {response[:200]}"
398
- base_reward = 0.001
399
-
400
- elif isinstance(action, AskQuestion):
401
- self._comments.append(f"Agent: {action.question}")
402
-
403
- response = self._author.respond(
404
- agent_question=action.question,
405
- test_results=self._test_results,
406
- lint_results=self._lint_results,
407
- doc_results=self._doc_results,
408
- proposed_fix=None,
409
- original_code=self._current_code # FIXED
410
- )
411
-
412
- self._comments.append(f"Author: {response}")
413
- self._test_results = f"[Question] Author: {response[:200]}"
414
- base_reward = 0.002
415
-
416
- # ==============================================================
417
- # FINAL FIX ACTION
418
- # ==============================================================
419
- elif isinstance(action, ProposeFix):
420
- if not action.fix_code:
421
- base_reward = -0.05
422
- self._done = True
423
- else:
424
- # Save original code BEFORE overwriting (for author.respond)
425
- original_buggy = self._current_code
426
- self._current_code = action.fix_code
427
-
428
- runner = TestRunner(self._current_bug_id)
429
- test_score, test_output = runner.run_tests(self._current_code)
430
- lint_score = self._run_linter_score(self._current_code)
431
- negotiation_score = self._author.get_negotiation_score()
432
-
433
- self._current_test_score = test_score
434
- self._current_lint_score = lint_score
435
-
436
- # Author gating – determines if the episode ends, reward is separate
437
- threshold = self._author.thresholds.get(self._author.personality, 0.5)
438
- if self._author._confidence < threshold:
439
- if self._step_count < self.max_steps:
440
- self._done = False
441
- else:
442
- self._done = True
443
- else:
444
- self._done = True
445
-
446
- # Get author's verbal feedback (pushback/acceptance)
447
- author_feedback = self._author.respond(
448
- agent_comment=f"Proposed fix:\n{action.fix_code}",
449
- test_results=f"Score: {test_score:.2f}",
450
- lint_results=f"Score: {lint_score:.2f}",
451
- doc_results=self._doc_results,
452
- proposed_fix=action.fix_code,
453
- original_code=original_buggy # now correctly the buggy code, not the fix
454
- )
455
- self._test_results = f"[Fix] Author: {author_feedback[:200]}"
456
- self._comments.append(f"Author: {author_feedback}")
457
-
458
- base_reward = 0.001 # rubrics provide the real signal
459
-
460
- # ==============================================================
461
- # TERMINATION ACTIONS
462
- # ==============================================================
463
- elif isinstance(action, Skip):
464
- base_reward = -0.03
465
- self._done = True
466
-
467
- elif isinstance(action, Done):
468
- if self._tests_run:
469
- base_reward = self._current_test_score * 0.5 - 0.2
470
- else:
471
- base_reward = -0.04
472
- self._done = True
473
-
474
- else:
475
- base_reward = -0.02
476
- self._done = True
477
-
478
- # ==============================================================
479
- # STEP UPDATE (before rubric computation so info contains final step)
480
- # ==============================================================
481
- self._step_count += 1
482
- if self._step_count >= self.max_steps:
483
- self._done = True
484
-
485
- # Get fresh observation (needed for rubrics that may read obs)
486
- obs = self._get_observation()
487
-
488
- # Prepare info dict (rubrics may need action_type and deltas)
489
- info = {
490
- "action_type": action_type,
491
- "test_score": self._current_test_score,
492
- "lint_score": self._current_lint_score,
493
- "test_delta": self._current_test_score - self._previous_test_score,
494
- "lint_delta": self._current_lint_score - self._previous_lint_score,
495
- "base_reward": base_reward,
496
- }
497
-
498
- # ==============================================================
499
- # COMPUTE FINAL REWARD USING RUBRICS
500
- # ==============================================================
501
- rubric_score = sum(r(self, action, obs, None, self._done, info) for r in self.rubrics)
502
- final_reward = 0.4 * base_reward + rubric_score
503
- final_reward = max(-1.0, min(1.0, final_reward)) # safety clip
504
-
505
- # Track cumulative episode reward
506
- self._episode_total_reward += final_reward
507
-
508
- # Store episode total if done
509
- if self._done:
510
- self._episode_rewards.append(self._episode_total_reward)
511
-
512
- # Complete info
513
- info["final_reward"] = final_reward
514
- info["episode_total"] = self._episode_total_reward
515
-
516
- return obs, Reward(value=final_reward), self._done, info
517
-
518
- # ===================================================================
519
- def _run_linter_score(self, code: str) -> float:
520
- """Run pylint and return normalized score [0, 1]."""
521
- try:
522
- with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
523
- f.write(code)
524
- tmp_path = f.name
525
-
526
- result = subprocess.run(
527
- ['pylint', tmp_path, '--score=y', '--exit-zero'],
528
- capture_output=True,
529
- text=True,
530
- timeout=5
531
- )
532
-
533
- match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
534
- if match:
535
- return float(match.group(1)) / 10.0
536
- return 0.0
537
- except:
538
- return 0.0
539
- finally:
540
- try:
541
- os.unlink(tmp_path)
542
- except:
543
- pass
544
-
545
- # ===================================================================
546
- def state(self) -> State:
547
- """Legacy compatibility."""
548
- return State(
549
- pr_title="Code Review",
550
- pr_description=self._bug_description,
551
- code_snippet=self._current_code,
552
- comments=self._comments.copy(),
553
- test_results=self._test_results,
554
- step=self._step_count,
555
- done=self._done
556
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # environment.py – FULLY CORRECTED RL Environment (TRUE Markov + Fixed Bugs)
2
+
3
+ import sys
4
+ import subprocess
5
+ import tempfile
6
+ import os
7
+ import re
8
+ from dataclasses import dataclass, field
9
+ from typing import Tuple, Dict, Any, Optional, List
10
+
11
+ from models import (
12
+ AnyAction, WriteComment, ProposeFix, Execute, Inspect,
13
+ RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
14
+ Observation, Reward, State
15
+ )
16
+ from redteam import RedTeam
17
+ from test_runner import TestRunner
18
+ from author import PersonaAuthor
19
+ from rltool import ToolBox
20
+ from rubrics import (
21
+ ToolUsageRubric,
22
+ TestDeltaRubric,
23
+ LintDeltaRubric,
24
+ TerminalSuccessRubric,
25
+ ExplorationRubric,
26
+ AntiHackingRubric,
27
+ StepPenaltyRubric,
28
+ )
29
+
30
+ # ======================================================================
31
+ # FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
32
+ # ======================================================================
33
+ @dataclass
34
+ class EnhancedObservation:
35
+ code_snippet: str
36
+ last_tool_output: str
37
+
38
+ current_test_score: float
39
+ current_lint_score: float
40
+ negotiation_score: float
41
+
42
+ previous_test_score: float
43
+ previous_lint_score: float
44
+
45
+ author_confidence: float
46
+ author_threshold: float
47
+
48
+ step: int
49
+ max_steps: int
50
+ progress_ratio: float
51
+
52
+ tests_run: bool
53
+ linter_run: bool
54
+ docs_queried: bool
55
+
56
+ last_action_type: str
57
+ action_history: List[str]
58
+
59
+ done: bool
60
+
61
+ bug_description: str
62
+ comments_count: int
63
+
64
+ # default fields must be at the very end
65
+ author_response: str = ""
66
+
67
+ # ======================================================================
68
+ # HELPER FUNCTIONS
69
+ # ======================================================================
70
+ def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
71
+ if not code.strip():
72
+ return False, "", "Error: Empty code"
73
+
74
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
75
+ f.write(code)
76
+ tmp_path = f.name
77
+
78
+ try:
79
+ result = subprocess.run(
80
+ [sys.executable, tmp_path],
81
+ capture_output=True,
82
+ text=True,
83
+ timeout=timeout_sec
84
+ )
85
+ success = (result.returncode == 0)
86
+ return success, result.stdout, result.stderr
87
+ except subprocess.TimeoutExpired:
88
+ return False, "", f"Timeout after {timeout_sec}s"
89
+ except Exception as e:
90
+ return False, "", f"Execution error: {str(e)}"
91
+ finally:
92
+ try:
93
+ os.unlink(tmp_path)
94
+ except:
95
+ pass
96
+
97
+
98
+ # ======================================================================
99
+ # ENHANCED CODE REVIEW ENVIRONMENT
100
+ # ======================================================================
101
+ @dataclass
102
+ class CodeReviewEnv:
103
+ task: str = "easy"
104
+ max_steps: int = 10
105
+ step_penalty: float = 0.01
106
+ reward_profile: str = "full" # "full" or "core"
107
+
108
+ # Curriculum learning
109
+ auto_difficulty: bool = False
110
+ success_threshold: float = 0.7
111
+
112
+ # Reward shaping parameters
113
+ delta_weight: float = 0.3
114
+ tool_usage_bonus: float = 0.05
115
+ diversity_bonus: float = 0.03
116
+
117
+ _red_team: Optional[RedTeam] = field(init=False, default=None)
118
+ _author: Optional[PersonaAuthor] = field(init=False, default=None)
119
+
120
+ _current_code: str = field(init=False, default="")
121
+ _current_bug_id: str = field(init=False, default="")
122
+ _bug_description: str = field(init=False, default="")
123
+ _oracle_fix: str = field(init=False, default="")
124
+
125
+ _comments: list = field(init=False, default_factory=list)
126
+ _test_results: Optional[str] = field(init=False, default=None)
127
+ _lint_results: Optional[str] = field(init=False, default=None)
128
+ _doc_results: Optional[str] = field(init=False, default=None)
129
+
130
+ _step_count: int = field(init=False, default=0)
131
+ _done: bool = field(init=False, default=False)
132
+
133
+ # State tracking for dense rewards
134
+ _previous_test_score: float = field(init=False, default=0.0)
135
+ _previous_lint_score: float = field(init=False, default=0.0)
136
+ _current_test_score: float = field(init=False, default=0.0)
137
+ _current_lint_score: float = field(init=False, default=0.0)
138
+
139
+ # Tool usage tracking
140
+ _tests_run: bool = field(init=False, default=False)
141
+ _linter_run: bool = field(init=False, default=False)
142
+ _docs_queried: bool = field(init=False, default=False)
143
+
144
+ # Action history
145
+ _action_history: List[str] = field(init=False, default_factory=list)
146
+ _last_action_type: str = field(init=False, default="none")
147
+ _last_author_response: str = field(init=False, default="")
148
+
149
+ # FIXED: Track CUMULATIVE episode reward
150
+ _episode_total_reward: float = field(init=False, default=0.0)
151
+ _episode_rewards: List[float] = field(init=False, default_factory=list)
152
+ _difficulty_level: int = field(init=False, default=0)
153
+
154
+ # Bug-id bridge:
155
+ # RedTeam has fine-grained IDs, while TestRunner currently expects a
156
+ # smaller canonical set. Keep this mapping here so both modules can evolve
157
+ # independently without breaking evaluation.
158
+ _BUG_ID_CANONICAL_MAP = {
159
+ "division_by_zero_empty": "division_by_zero",
160
+ "division_by_zero_zero": "division_by_zero",
161
+ "sign_error": "wrong_operator",
162
+ }
163
+
164
+ # ===================================================================
165
+ def __post_init__(self):
166
+ self.set_task(self.task)
167
+
168
+ # ===================================================================
169
+ def _build_rubrics(self):
170
+ """
171
+ Build rubric stack from a named reward profile.
172
+ - full: richer shaping for exploration/tool-use behavior
173
+ - core: minimal stable signal for quick ablations/baselines
174
+ """
175
+ core_rubrics = [
176
+ TestDeltaRubric(weight=self.delta_weight),
177
+ LintDeltaRubric(weight=self.delta_weight),
178
+ TerminalSuccessRubric(),
179
+ StepPenaltyRubric(penalty=self.step_penalty),
180
+ ]
181
+ if self.reward_profile == "core":
182
+ return core_rubrics
183
+ if self.reward_profile == "full":
184
+ return [
185
+ *core_rubrics[:-1], # step penalty appended at end for consistent ordering
186
+ ToolUsageRubric(bonus=self.tool_usage_bonus),
187
+ ExplorationRubric(penalty=-0.05, bonus=self.diversity_bonus * 0.7),
188
+ AntiHackingRubric(),
189
+ core_rubrics[-1],
190
+ ]
191
+ raise ValueError(f"Unknown reward_profile: {self.reward_profile}")
192
+
193
+ # ===================================================================
194
+ def set_task(self, task: str):
195
+ if task not in ["easy", "medium", "hard", "harder", "hardest"]:
196
+ raise ValueError(f"Unknown task: {task}")
197
+
198
+ self.task = task
199
+ # Use stochastic bug sampling across episodes; fixed seed here would
200
+ # repeatedly select the same bug and weaken training diversity.
201
+ self._red_team = RedTeam(task, seed=None)
202
+ self._author = PersonaAuthor()
203
+ self.rubrics = self._build_rubrics()
204
+
205
+ task_to_level = {
206
+ "easy": 0, "medium": 1, "hard": 2,
207
+ "harder": 3, "hardest": 4
208
+ }
209
+ self._difficulty_level = task_to_level[task]
210
+
211
+ self._reset_internal()
212
+
213
+ # ===================================================================
214
+ def _reset_internal(self):
215
+ self._step_count = 0 # ← FIXED
216
+ self._comments = []
217
+ self._test_results = None
218
+ self._lint_results = None
219
+ self._doc_results = None
220
+ self._done = False
221
+
222
+ # Reset state tracking
223
+ self._previous_test_score = 0.0
224
+ self._previous_lint_score = 0.0
225
+ self._current_test_score = 0.0
226
+ self._current_lint_score = 0.0
227
+
228
+ self._tests_run = False
229
+ self._linter_run = False
230
+ self._docs_queried = False
231
+
232
+ self._action_history = []
233
+ self._last_action_type = "none"
234
+ self._last_author_response = ""
235
+
236
+ # FIXED: Reset episode cumulative reward
237
+ self._episode_total_reward = 0.0
238
+
239
+ self._author.reset()
240
+
241
+ # Base tasks
242
+ if self.task == "easy":
243
+ original = "def get_user(id):\n if id in users:\n return users[id]"
244
+ elif self.task == "medium":
245
+ original = "def process_items(items):\n for item in items:\n print(item)"
246
+ elif self.task == "hard":
247
+ original = "def average(data):\n if not data:\n return 0\n return sum(data) / len(data)"
248
+ elif self.task == "harder":
249
+ original = "counter = 0\ndef increment():\n global counter\n with lock:\n counter += 1"
250
+ else:
251
+ original = "def safe_work():\n with lock1:\n with lock2:\n do_work()"
252
+
253
+ buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
254
+ self._current_code = buggy_code
255
+ self._current_bug_id = bug_id
256
+ self._bug_description = desc
257
+ self._oracle_fix = oracle
258
+ self._comments.append(f"[RedTeam] {desc}")
259
+
260
+ # ===================================================================
261
+ def reset(self) -> EnhancedObservation:
262
+ """Reset with optional curriculum adjustment."""
263
+ if self.auto_difficulty and len(self._episode_rewards) > 0:
264
+ recent_performance = sum(self._episode_rewards[-5:]) / min(5, len(self._episode_rewards))
265
+
266
+ if recent_performance > self.success_threshold and self._difficulty_level < 4:
267
+ self._difficulty_level += 1
268
+ print(f"[Curriculum] Increasing difficulty to level {self._difficulty_level}")
269
+ elif recent_performance < 0.3 and self._difficulty_level > 0:
270
+ self._difficulty_level -= 1
271
+ print(f"[Curriculum] Decreasing difficulty to level {self._difficulty_level}")
272
+
273
+ level_to_task = {0: "easy", 1: "medium", 2: "hard", 3: "harder", 4: "hardest"}
274
+ self.task = level_to_task[self._difficulty_level]
275
+ # Keep curriculum stochastic for better coverage within each level.
276
+ self._red_team = RedTeam(self.task, seed=None)
277
+
278
+ self._reset_internal()
279
+ return self._get_observation()
280
+
281
+ # ===================================================================
282
+ def _get_observation(self) -> EnhancedObservation:
283
+ """Return COMPLETE Markov state."""
284
+ # Keep the author's message separate from tool output.
285
+ # Using `_test_results` here can leak unrelated outputs (tests/linter/docs)
286
+ # and gives the policy a noisy signal for dialogue actions.
287
+ if self._last_action_type in ("write_comment", "ask_question", "propose_fix"):
288
+ author_response = self._last_author_response
289
+ else:
290
+ author_response = ""
291
+
292
+ return EnhancedObservation(
293
+ code_snippet=self._current_code,
294
+ last_tool_output=self._test_results or "",
295
+ author_response=author_response, # ← now field exists
296
+
297
+ current_test_score=self._current_test_score,
298
+ current_lint_score=self._current_lint_score,
299
+ negotiation_score=self._author.get_negotiation_score(),
300
+
301
+ previous_test_score=self._previous_test_score,
302
+ previous_lint_score=self._previous_lint_score,
303
+
304
+ author_confidence=self._author._confidence,
305
+ author_threshold=self._author.thresholds.get(self._author.personality, 0.5),
306
+
307
+ step=self._step_count,
308
+ max_steps=self.max_steps,
309
+ # Guard against accidental `max_steps=0` configs.
310
+ progress_ratio=(self._step_count / self.max_steps) if self.max_steps > 0 else 1.0,
311
+
312
+ tests_run=self._tests_run,
313
+ linter_run=self._linter_run,
314
+ docs_queried=self._docs_queried,
315
+
316
+ last_action_type=self._last_action_type,
317
+ action_history=self._action_history[-5:],
318
+
319
+ done=self._done,
320
+
321
+ bug_description=self._bug_description,
322
+ comments_count=len(self._comments),
323
+ )
324
+
325
+ # ===================================================================
326
+ def _get_action_type(self, action: AnyAction) -> str:
327
+ """Extract action type as string."""
328
+ if isinstance(action, RunTests):
329
+ return "run_tests"
330
+ elif isinstance(action, RunLinter):
331
+ return "run_linter"
332
+ elif isinstance(action, QueryDocs):
333
+ return "query_docs"
334
+ elif isinstance(action, Execute):
335
+ return "execute"
336
+ elif isinstance(action, Inspect):
337
+ return "inspect"
338
+ elif isinstance(action, WriteComment):
339
+ return "write_comment"
340
+ elif isinstance(action, AskQuestion):
341
+ return "ask_question"
342
+ elif isinstance(action, ProposeFix):
343
+ return "propose_fix"
344
+ elif isinstance(action, Done):
345
+ return "done"
346
+ elif isinstance(action, Skip):
347
+ return "skip"
348
+ else:
349
+ return "unknown"
350
+
351
+ # ===================================================================
352
+ def _get_test_runner_bug_id(self) -> str:
353
+ """
354
+ Normalize RedTeam bug ids to the canonical ids understood by TestRunner.
355
+ Falls back to the original id for known direct matches.
356
+ """
357
+ return self._BUG_ID_CANONICAL_MAP.get(self._current_bug_id, self._current_bug_id)
358
+
359
+ # ===================================================================
360
+ def step(self, action: AnyAction) -> Tuple[EnhancedObservation, Reward, bool, Dict[str, Any]]:
361
+ """
362
+ TRUE RL STEP with:
363
+ - Complete Markov observations (no hidden state)
364
+ - Dense intermediate rewards
365
+ - Delta-based credit assignment (no double-counting)
366
+ - Proper episode reward tracking
367
+ """
368
+ if self._done:
369
+ raise RuntimeError("Episode already finished")
370
+
371
+ # Store previous metrics for delta computation
372
+ self._previous_test_score = self._current_test_score
373
+ self._previous_lint_score = self._current_lint_score
374
+ # Snapshot tool-usage flags BEFORE action mutates them.
375
+ # Rubrics use these to detect true "first-use" behavior.
376
+ prev_tests_run = self._tests_run
377
+ prev_linter_run = self._linter_run
378
+ prev_docs_queried = self._docs_queried
379
+
380
+ base_reward = 0.0
381
+ action_type = self._get_action_type(action)
382
+
383
+ # Update action history
384
+ self._action_history.append(action_type)
385
+ self._last_action_type = action_type
386
+
387
+ # ==============================================================
388
+ # TOOL ACTIONS
389
+ # ==============================================================
390
+ if isinstance(action, Execute):
391
+ success, stdout, stderr = execute_code(self._current_code)
392
+ output = (stdout + stderr).strip() or "No output"
393
+ self._test_results = f"[Execute] {'Success' if success else 'Failed'}\n{output[:300]}"
394
+ base_reward = 0.001 if success else -0.05
395
+
396
+ elif isinstance(action, Inspect):
397
+ self._test_results = f"[Inspect]\n{self._current_code[:500]}"
398
+ base_reward = 0.001
399
+
400
+ elif isinstance(action, RunLinter):
401
+ lint_output = ToolBox.run_linter(self._current_code)
402
+ self._lint_results = lint_output[:500]
403
+ self._test_results = f"[Linter]\n{self._lint_results}"
404
+
405
+ self._current_lint_score = self._run_linter_score(self._current_code)
406
+ self._linter_run = True
407
+ base_reward = 0.002
408
+
409
+ elif isinstance(action, RunTests):
410
+ runner = TestRunner(self._get_test_runner_bug_id())
411
+ score, output = runner.run_tests(self._current_code)
412
+
413
+ self._current_test_score = score
414
+ self._tests_run = True
415
+
416
+ self._test_results = f"[Tests] Score: {score:.2f}\n{output[:300]}"
417
+ base_reward = 0.002
418
+
419
+ if score > 0.8:
420
+ base_reward += 0.005
421
+
422
+ elif isinstance(action, QueryDocs):
423
+ # Normalize query to avoid rewarding empty/noisy requests.
424
+ query_topic = (action.query_topic or "").strip()
425
+ doc = ToolBox.query_docs(query_topic if query_topic else "general bug fixing")
426
+ self._doc_results = doc
427
+ self._test_results = f"[Docs]\n{doc[:400]}"
428
+ self._docs_queried = True
429
+ base_reward = 0.001
430
+
431
+ # ==============================================================
432
+ # COMMUNICATION ACTIONS
433
+ # ==============================================================
434
+ elif isinstance(action, WriteComment):
435
+ self._comments.append(f"Agent: {action.comment_text}")
436
+
437
+ response = self._author.respond(
438
+ agent_comment=action.comment_text,
439
+ test_results=self._test_results,
440
+ lint_results=self._lint_results,
441
+ doc_results=self._doc_results,
442
+ proposed_fix=None,
443
+ original_code=self._current_code
444
+ )
445
+
446
+ self._comments.append(f"Author: {response}")
447
+ self._last_author_response = response
448
+ self._test_results = f"[Comment] Author: {response[:200]}"
449
+ base_reward = 0.001
450
+
451
+ elif isinstance(action, AskQuestion):
452
+ self._comments.append(f"Agent: {action.question}")
453
+
454
+ response = self._author.respond(
455
+ agent_question=action.question,
456
+ test_results=self._test_results,
457
+ lint_results=self._lint_results,
458
+ doc_results=self._doc_results,
459
+ proposed_fix=None,
460
+ original_code=self._current_code # ← FIXED
461
+ )
462
+
463
+ self._comments.append(f"Author: {response}")
464
+ self._last_author_response = response
465
+ self._test_results = f"[Question] Author: {response[:200]}"
466
+ base_reward = 0.002
467
+
468
+ # ==============================================================
469
+ # FINAL FIX ACTION
470
+ # ==============================================================
471
+ elif isinstance(action, ProposeFix):
472
+ if not action.fix_code:
473
+ base_reward = -0.05
474
+ self._done = True
475
+ else:
476
+ # Save original code BEFORE overwriting (for author.respond)
477
+ original_buggy = self._current_code
478
+ self._current_code = action.fix_code
479
+
480
+ runner = TestRunner(self._get_test_runner_bug_id())
481
+ test_score, test_output = runner.run_tests(self._current_code)
482
+ lint_score = self._run_linter_score(self._current_code)
483
+ negotiation_score = self._author.get_negotiation_score()
484
+
485
+ self._current_test_score = test_score
486
+ self._current_lint_score = lint_score
487
+
488
+ # Author gating determines if the episode ends, reward is separate
489
+ threshold = self._author.thresholds.get(self._author.personality, 0.5)
490
+ if self._author._confidence < threshold:
491
+ if self._step_count < self.max_steps:
492
+ self._done = False
493
+ else:
494
+ self._done = True
495
+ else:
496
+ self._done = True
497
+
498
+ # Get author's verbal feedback (pushback/acceptance)
499
+ author_feedback = self._author.respond(
500
+ agent_comment=f"Proposed fix:\n{action.fix_code}",
501
+ test_results=f"Score: {test_score:.2f}",
502
+ lint_results=f"Score: {lint_score:.2f}",
503
+ doc_results=self._doc_results,
504
+ proposed_fix=action.fix_code,
505
+ original_code=original_buggy # now correctly the buggy code, not the fix
506
+ )
507
+ self._test_results = f"[Fix] Author: {author_feedback[:200]}"
508
+ self._comments.append(f"Author: {author_feedback}")
509
+ self._last_author_response = author_feedback
510
+
511
+ base_reward = 0.001 # rubrics provide the real signal
512
+
513
+ # ==============================================================
514
+ # TERMINATION ACTIONS
515
+ # ==============================================================
516
+ elif isinstance(action, Skip):
517
+ base_reward = -0.03
518
+ self._done = True
519
+
520
+ elif isinstance(action, Done):
521
+ if self._tests_run:
522
+ base_reward = self._current_test_score * 0.5 - 0.2
523
+ else:
524
+ base_reward = -0.04
525
+ self._done = True
526
+
527
+ else:
528
+ base_reward = -0.02
529
+ self._done = True
530
+
531
+ # ==============================================================
532
+ # STEP UPDATE (before rubric computation so info contains final step)
533
+ # ==============================================================
534
+ self._step_count += 1
535
+ if self._step_count >= self.max_steps:
536
+ self._done = True
537
+
538
+ # Get fresh observation (needed for rubrics that may read obs)
539
+ obs = self._get_observation()
540
+
541
+ # Prepare info dict (rubrics may need action_type and deltas)
542
+ info = {
543
+ "action_type": action_type,
544
+ "test_score": self._current_test_score,
545
+ "lint_score": self._current_lint_score,
546
+ "test_delta": self._current_test_score - self._previous_test_score,
547
+ "lint_delta": self._current_lint_score - self._previous_lint_score,
548
+ "prev_tests_run": prev_tests_run,
549
+ "prev_linter_run": prev_linter_run,
550
+ "prev_docs_queried": prev_docs_queried,
551
+ "docs_query_len": len((action.query_topic or "").strip()) if isinstance(action, QueryDocs) else 0,
552
+ "base_reward": base_reward,
553
+ }
554
+
555
+ # ==============================================================
556
+ # COMPUTE FINAL REWARD USING RUBRICS
557
+ # ==============================================================
558
+ rubric_score = sum(r(self, action, obs, None, self._done, info) for r in self.rubrics)
559
+ final_reward = 0.4 * base_reward + rubric_score
560
+ final_reward = max(-1.0, min(1.0, final_reward)) # safety clip
561
+
562
+ # Track cumulative episode reward
563
+ self._episode_total_reward += final_reward
564
+
565
+ # Store episode total if done
566
+ if self._done:
567
+ self._episode_rewards.append(self._episode_total_reward)
568
+
569
+ # Complete info
570
+ info["final_reward"] = final_reward
571
+ info["episode_total"] = self._episode_total_reward
572
+
573
+ return obs, Reward(value=final_reward), self._done, info
574
+
575
+ # ===================================================================
576
+ def _run_linter_score(self, code: str) -> float:
577
+ """Run pylint and return normalized score [0, 1]."""
578
+ try:
579
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
580
+ f.write(code)
581
+ tmp_path = f.name
582
+
583
+ result = subprocess.run(
584
+ ['pylint', tmp_path, '--score=y', '--exit-zero'],
585
+ capture_output=True,
586
+ text=True,
587
+ timeout=5
588
+ )
589
+
590
+ match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
591
+ if match:
592
+ return float(match.group(1)) / 10.0
593
+ return 0.0
594
+ except:
595
+ return 0.0
596
+ finally:
597
+ try:
598
+ os.unlink(tmp_path)
599
+ except:
600
+ pass
601
+
602
+ # ===================================================================
603
+ def state(self) -> State:
604
+ """Legacy compatibility."""
605
+ return State(
606
+ pr_title="Code Review",
607
+ pr_description=self._bug_description,
608
+ code_snippet=self._current_code,
609
+ comments=self._comments.copy(),
610
+ test_results=self._test_results,
611
+ step=self._step_count,
612
+ done=self._done
613
+ )
grader.py CHANGED
@@ -1,148 +1,142 @@
1
- # grader.py – Production‑grade, continuous reward, exploit‑aware
2
- import ast
3
- import subprocess
4
- import tempfile
5
- import os
6
- import re
7
- import sys
8
- import json
9
- from dataclasses import dataclass
10
- from typing import Optional
11
-
12
- @dataclass
13
- class RigorousGrader:
14
- bug_id: str
15
- oracle_code: Optional[str] = None
16
-
17
- def grade_fix(self, proposed_fix: str) -> float:
18
- """Returns a smooth reward in [0,1]."""
19
- # Syntax check
20
- try:
21
- ast.parse(proposed_fix)
22
- except SyntaxError:
23
- return 0.0
24
-
25
- # Exploit detection (optional)
26
- if self._is_exploit(proposed_fix):
27
- return 0.0
28
-
29
- # Continuous test score
30
- test_score = self._run_continuous_tests(proposed_fix)
31
-
32
- # Lint score
33
- lint_score = self._get_lint_score(proposed_fix)
34
-
35
- # Oracle similarity
36
- oracle_score = self._ast_similarity(proposed_fix) if self.oracle_code else 0.0
37
-
38
- # Weighted combination
39
- final = (0.5 * test_score) + (0.3 * lint_score) + (0.2 * oracle_score)
40
- return max(0.0, min(1.0, final))
41
-
42
- def _run_continuous_tests(self, code: str) -> float:
43
- """Proportion of passed test cases."""
44
- test_cases = self._get_test_cases()
45
- if not test_cases:
46
- return 0.0
47
- passed = 0
48
- for test_input, expected in test_cases:
49
- if self._run_single_test(code, test_input, expected):
50
- passed += 1
51
- return passed / len(test_cases)
52
-
53
- def _get_test_cases(self) -> list:
54
- """Define multiple test cases per bug type."""
55
- if self.bug_id == "null_check":
56
- return [
57
- ({"users": {"alice": "Alice"}, "id": "bob"}, None),
58
- ({"users": {"alice": "Alice"}, "id": "alice"}, "Alice"),
59
- ]
60
- elif self.bug_id == "off_by_one":
61
- return [
62
- ([1, 2, 3, 4], 4),
63
- ([], 0),
64
- ]
65
- # Extend for other bugs …
66
- return []
67
-
68
- def _run_single_test(self, code: str, test_input, expected) -> bool:
69
- """Execute the agent's code with test_input and compare to expected."""
70
- harness = f"""
71
- import json
72
- {code}
73
- try:
74
- result = fix({json.dumps(test_input)})
75
- print(json.dumps({{"ok": True, "result": result}}))
76
- except Exception as e:
77
- print(json.dumps({{"ok": False, "error": str(e)}}))
78
- """
79
- with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
80
- f.write(harness)
81
- tmp_path = f.name
82
- try:
83
- result = subprocess.run(
84
- [sys.executable, tmp_path],
85
- capture_output=True, text=True, timeout=5
86
- )
87
- data = json.loads(result.stdout.strip())
88
- if data.get("ok") and data["result"] == expected:
89
- return True
90
- return False
91
- except (json.JSONDecodeError, subprocess.TimeoutExpired, Exception):
92
- return False
93
- finally:
94
- try:
95
- os.unlink(tmp_path)
96
- except:
97
- pass
98
-
99
- def _is_exploit(self, code: str) -> bool:
100
- """Detect hardcoded returns or trivial bypasses."""
101
- lower = code.lower()
102
- if "return 0" in lower and "if" not in lower:
103
- return True
104
- if code.strip() == "":
105
- return True
106
- return False
107
-
108
- def _get_lint_score(self, code: str) -> float:
109
- """Continuous lint score, fallback 0.0 on error."""
110
- try:
111
- with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
112
- f.write(code)
113
- f.flush()
114
- tmp_path = f.name
115
- result = subprocess.run(
116
- ['pylint', tmp_path, '--score=y', '--exit-zero'],
117
- capture_output=True,
118
- text=True,
119
- timeout=5
120
- )
121
- match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
122
- if match:
123
- score = float(match.group(1)) / 10.0
124
- else:
125
- score = 0.0
126
- return max(0.0, min(1.0, score))
127
- except Exception:
128
- return 0.0
129
- finally:
130
- try:
131
- os.unlink(tmp_path)
132
- except:
133
- pass
134
-
135
- def _ast_similarity(self, proposed_code: str) -> float:
136
- """Structural similarity to oracle."""
137
- if not self.oracle_code:
138
- return 0.0
139
- try:
140
- tree_prop = ast.parse(proposed_code)
141
- tree_oracle = ast.parse(self.oracle_code)
142
- nodes_prop = [type(n) for n in ast.walk(tree_prop)]
143
- nodes_oracle = [type(n) for n in ast.walk(tree_oracle)]
144
- common = sum(1 for n in nodes_prop if n in nodes_oracle)
145
- total = max(len(nodes_prop), len(nodes_oracle))
146
- return common / total if total > 0 else 0.0
147
- except:
148
  return 0.0
 
1
+ # grader.py – Production‑grade, continuous reward, exploit‑aware, example of monolithic scoring
2
+ import ast
3
+ import subprocess
4
+ import tempfile
5
+ import os
6
+ import re
7
+ from dataclasses import dataclass
8
+ from typing import Optional
9
+
10
+ @dataclass
11
+ class RigorousGrader:
12
+ bug_id: str
13
+ oracle_code: Optional[str] = None
14
+
15
+ def grade_fix(self, proposed_fix: str) -> float:
16
+ """
17
+ Returns a smooth reward in [0,1] based on:
18
+ - Syntax validity
19
+ - Proportion of tests passed (continuous, not binary)
20
+ - Lint quality (with conservative fallback)
21
+ - Structural similarity to oracle (anti‑gaming)
22
+ - Exploit detection (hardcoded outputs / no real change)
23
+ """
24
+ # 1. Syntax check (binary – non‑negotiable)
25
+ try:
26
+ ast.parse(proposed_fix)
27
+ except SyntaxError:
28
+ return 0.0 # hard zero, not negative (RL stable)
29
+
30
+ # 2. Exploit detection: trivial or hardcoded fixes
31
+ if self._is_exploit(proposed_fix):
32
+ return 0.0
33
+
34
+ # 3. Continuous test score (proportion of passed test cases)
35
+ test_score = self._run_continuous_tests(proposed_fix)
36
+
37
+ # 4. Lint score (continuous, fallback 0.0 not 0.5)
38
+ lint_score = self._get_lint_score(proposed_fix)
39
+
40
+ # 5. Oracle similarity (structural, not gameable)
41
+ oracle_score = self._ast_similarity(proposed_fix) if self.oracle_code else 0.0
42
+
43
+ # Weighted combination (all continuous)
44
+ final = (0.5 * test_score) + (0.3 * lint_score) + (0.2 * oracle_score)
45
+ return max(0.0, min(1.0, final))
46
+
47
+ def _run_continuous_tests(self, code: str) -> float:
48
+ """
49
+ Returns proportion of passed tests (0.0 to 1.0).
50
+ Uses multiple test cases per bug type.
51
+ """
52
+ test_cases = self._get_test_cases()
53
+ if not test_cases:
54
+ return 0.0
55
+
56
+ passed = 0
57
+ for test_input, expected in test_cases:
58
+ if self._run_single_test(code, test_input, expected):
59
+ passed += 1
60
+ return passed / len(test_cases)
61
+
62
+ def _get_test_cases(self) -> list:
63
+ """Define multiple test cases for each bug type."""
64
+ if self.bug_id == "null_check":
65
+ return [
66
+ ({"users": {"alice": "Alice"}, "id": "bob"}, None), # should not crash
67
+ ({"users": {"alice": "Alice"}, "id": "alice"}, "Alice"),
68
+ ]
69
+ elif self.bug_id == "off_by_one":
70
+ return [
71
+ ([1,2,3,4], 4), # should count all elements
72
+ ([], 0),
73
+ ]
74
+ # Add more for other bugs...
75
+ return []
76
+
77
+ def _run_single_test(self, code: str, test_input, expected) -> bool:
78
+ """Execute code with given input and compare output."""
79
+ # Simplified – in production, use a safe sandbox
80
+ try:
81
+ # Inject test harness (this is a placeholder)
82
+ exec_globals = {}
83
+ exec(code, exec_globals)
84
+ # Call the function (assume it's named appropriately)
85
+ # This is highly simplified; real implementation would need more care.
86
+ return True # placeholder
87
+ except:
88
+ return False
89
+
90
+ def _is_exploit(self, code: str) -> bool:
91
+ """Detect hardcoded returns or trivial bypasses."""
92
+ lower = code.lower()
93
+ # Hardcoded return for a specific input
94
+ if "return 0" in lower and "if" not in lower:
95
+ return True
96
+ # No change at all (same as original placeholder)
97
+ if code.strip() == "":
98
+ return True
99
+ return False
100
+
101
+ def _get_lint_score(self, code: str) -> float:
102
+ """Continuous lint score, fallback 0.0 on error."""
103
+ try:
104
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
105
+ f.write(code)
106
+ f.flush()
107
+ tmp_path = f.name
108
+ result = subprocess.run(
109
+ ['pylint', tmp_path, '--score=y', '--exit-zero'],
110
+ capture_output=True,
111
+ text=True,
112
+ timeout=5
113
+ )
114
+ match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
115
+ if match:
116
+ score = float(match.group(1)) / 10.0
117
+ else:
118
+ score = 0.0 # was 0.5 – now conservative
119
+ return max(0.0, min(1.0, score))
120
+ except Exception:
121
+ return 0.0
122
+ finally:
123
+ try:
124
+ os.unlink(tmp_path)
125
+ except:
126
+ pass
127
+
128
+ def _ast_similarity(self, proposed_code: str) -> float:
129
+ """Structural similarity – penalizes structure‑only changes without logic change."""
130
+ if not self.oracle_code:
131
+ return 0.0
132
+ try:
133
+ tree_prop = ast.parse(proposed_code)
134
+ tree_oracle = ast.parse(self.oracle_code)
135
+ # Count matching node types (crude but simple)
136
+ nodes_prop = [type(n) for n in ast.walk(tree_prop)]
137
+ nodes_oracle = [type(n) for n in ast.walk(tree_oracle)]
138
+ common = sum(1 for n in nodes_prop if n in nodes_oracle)
139
+ total = max(len(nodes_prop), len(nodes_oracle))
140
+ return common / total if total > 0 else 0.0
141
+ except:
 
 
 
 
 
 
142
  return 0.0
models.py CHANGED
@@ -1,115 +1,88 @@
1
- # models.py – Typed Models (Discriminated Unions, POMDP Separation)
2
- from typing import Literal, Union, Annotated, Optional
3
- from dataclasses import dataclass
4
- from pydantic import BaseModel, Field, TypeAdapter, field_validator
5
-
6
- # ----------------------------------------------------------------------
7
- # Action classes (discriminated union – short names as used by agent)
8
- # ----------------------------------------------------------------------
9
- class Action(BaseModel):
10
- action_type: Literal["comment", "skip", "done", "question",
11
- "fix", "execute", "inspect", "run_linter",
12
- "run_tests", "query_docs"]
13
-
14
- class WriteComment(Action):
15
- action_type: Literal["comment"] = "comment"
16
- comment_text: str = Field(..., min_length=1)
17
-
18
- class Skip(Action):
19
- action_type: Literal["skip"] = "skip"
20
-
21
- class Done(Action):
22
- action_type: Literal["done"] = "done"
23
-
24
- class AskQuestion(Action):
25
- action_type: Literal["question"] = "question"
26
- question: str = Field(..., min_length=1)
27
-
28
- class ProposeFix(Action):
29
- action_type: Literal["fix"] = "fix"
30
- fix_code: str = Field(..., min_length=1)
31
- @field_validator('fix_code')
32
- @classmethod
33
- def not_empty(cls, v: str) -> str:
34
- if not v.strip():
35
- raise ValueError('fix_code cannot be empty')
36
- return v
37
-
38
- class Execute(Action):
39
- action_type: Literal["execute"] = "execute"
40
-
41
- class Inspect(Action):
42
- action_type: Literal["inspect"] = "inspect"
43
-
44
- class RunLinter(Action):
45
- action_type: Literal["run_linter"] = "run_linter"
46
-
47
- class RunTests(Action):
48
- action_type: Literal["run_tests"] = "run_tests"
49
-
50
- class QueryDocs(Action):
51
- action_type: Literal["query_docs"] = "query_docs"
52
- query_topic: str = Field(..., min_length=1)
53
-
54
- # ----------------------------------------------------------------------
55
- # FREE FUNCTION map agent action strings to typed actions
56
- # (to be called from training.py where AgentAction is available)
57
- # ----------------------------------------------------------------------
58
- def map_to_env(action_type: str, content: str) -> Action:
59
- """Convert a parsed agent action (type + content) into a Pydantic Action."""
60
- if action_type == "run_tests":
61
- return RunTests()
62
- elif action_type == "run_linter":
63
- return RunLinter()
64
- elif action_type == "inspect":
65
- return Inspect()
66
- elif action_type == "fix":
67
- return ProposeFix(fix_code=content or "")
68
- elif action_type == "comment":
69
- return WriteComment(comment_text=content or "")
70
- elif action_type == "question":
71
- return AskQuestion(question=content or "")
72
- elif action_type == "query_docs":
73
- return QueryDocs(query_topic=content or "")
74
- elif action_type == "done":
75
- return Done()
76
- else:
77
- return Skip()
78
-
79
- # Discriminated union for one‑line polymorphic deserialization
80
- AnyAction = Annotated[
81
- Union[WriteComment, Skip, Done, AskQuestion, ProposeFix,
82
- Execute, Inspect, RunLinter, RunTests, QueryDocs],
83
- Field(discriminator='action_type')
84
- ]
85
- action_adapter = TypeAdapter(AnyAction)
86
-
87
- # ----------------------------------------------------------------------
88
- # Observation (POMDP – what the agent sees)
89
- # ----------------------------------------------------------------------
90
- @dataclass(slots=True)
91
- class Observation:
92
- code_snippet: str
93
- last_tool_output: str = ""
94
- step: int = 0
95
- done: bool = False
96
-
97
- # ----------------------------------------------------------------------
98
- # Reward (lightweight)
99
- # ----------------------------------------------------------------------
100
- @dataclass(slots=True)
101
- class Reward:
102
- value: float
103
-
104
- # ----------------------------------------------------------------------
105
- # State (full environment state – not exposed to agent)
106
- # ----------------------------------------------------------------------
107
- @dataclass(slots=True)
108
- class State:
109
- pr_title: str
110
- pr_description: str
111
- code_snippet: str
112
- comments: list[str]
113
- test_results: Optional[str]
114
- step: int
115
  done: bool
 
1
+ # models.py – Typed Models (Discriminated Unions, POMDP Separation)
2
+ from typing import Literal, Union, Annotated, Optional
3
+ from pydantic import BaseModel, Field, TypeAdapter, field_validator
4
+
5
+ # ----------------------------------------------------------------------
6
+ # Action classes (discriminated union)
7
+ # ----------------------------------------------------------------------
8
+ class Action(BaseModel):
9
+ action_type: Literal["write_comment", "skip", "done", "ask_question",
10
+ "propose_fix", "execute", "inspect", "run_linter",
11
+ "run_tests", "query_docs"]
12
+
13
+ class WriteComment(Action):
14
+ action_type: Literal["write_comment"] = "write_comment"
15
+ comment_text: str = Field(..., min_length=1)
16
+
17
+ class Skip(Action):
18
+ action_type: Literal["skip"] = "skip"
19
+
20
+ class Done(Action):
21
+ action_type: Literal["done"] = "done"
22
+
23
+ class AskQuestion(Action):
24
+ action_type: Literal["ask_question"] = "ask_question"
25
+ question: str = Field(..., min_length=1)
26
+
27
+ class ProposeFix(Action):
28
+ action_type: Literal["propose_fix"] = "propose_fix"
29
+ fix_code: str = Field(..., min_length=1)
30
+ @field_validator('fix_code')
31
+ @classmethod
32
+ def not_empty(cls, v: str) -> str:
33
+ if not v.strip():
34
+ raise ValueError('fix_code cannot be empty')
35
+ return v
36
+
37
+ class Execute(Action):
38
+ action_type: Literal["execute"] = "execute"
39
+
40
+ class Inspect(Action):
41
+ action_type: Literal["inspect"] = "inspect"
42
+
43
+ class RunLinter(Action):
44
+ action_type: Literal["run_linter"] = "run_linter"
45
+
46
+ class RunTests(Action):
47
+ action_type: Literal["run_tests"] = "run_tests"
48
+
49
+ class QueryDocs(Action):
50
+ action_type: Literal["query_docs"] = "query_docs"
51
+ query_topic: str = Field(..., min_length=1)
52
+
53
+ # Discriminated union for one‑line polymorphic deserialization
54
+ AnyAction = Annotated[
55
+ Union[WriteComment, Skip, Done, AskQuestion, ProposeFix,
56
+ Execute, Inspect, RunLinter, RunTests, QueryDocs],
57
+ Field(discriminator='action_type')
58
+ ]
59
+ action_adapter = TypeAdapter(AnyAction)
60
+
61
+ # ----------------------------------------------------------------------
62
+ # Observation (POMDP – what the agent sees)
63
+ # ----------------------------------------------------------------------
64
+ class Observation(BaseModel):
65
+ # Base schema model used by API metadata endpoints.
66
+ # Keep this lightweight for compatibility with legacy callers.
67
+ code_snippet: str
68
+ last_tool_output: str = ""
69
+ step: int = 0
70
+ done: bool = False
71
+
72
+ # ----------------------------------------------------------------------
73
+ # Reward (lightweight)
74
+ # ----------------------------------------------------------------------
75
+ class Reward(BaseModel):
76
+ value: float
77
+
78
+ # ----------------------------------------------------------------------
79
+ # State (full environment state not exposed to agent)
80
+ # ----------------------------------------------------------------------
81
+ class State(BaseModel):
82
+ pr_title: str
83
+ pr_description: str
84
+ code_snippet: str
85
+ comments: list[str]
86
+ test_results: Optional[str]
87
+ step: int
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  done: bool
openenv.yaml CHANGED
@@ -1,135 +1,135 @@
1
- # openenv.yaml – Environment metadata for OpenEnv
2
- name: CodeReview-Professional-Workflow
3
- version: 1.0.0
4
- description: |
5
- Multi‑turn code review environment for professional tasks.
6
- Agent must inspect, test, lint, query docs, and negotiate with a simulated author
7
- to fix injected bugs. Supports DPO training on full trajectories.
8
- author: yuvraj gupta
9
- license: MIT
10
-
11
- # ----------------------------------------------------------------------
12
- # Tasks (difficulty progression)
13
- # ----------------------------------------------------------------------
14
- tasks:
15
- - id: easy
16
- description: "Fix missing null check in a dictionary lookup"
17
- - id: medium
18
- description: "Improve loop efficiency (replace range(len) with direct iteration)"
19
- - id: hard
20
- description: "Handle division by zero in average calculation"
21
- - id: harder
22
- description: "Fix race condition by adding a lock"
23
- - id: hardest
24
- description: "Resolve potential deadlock by standardising lock order"
25
-
26
- # ----------------------------------------------------------------------
27
- # Observation space (complete Markov state – agent sees everything)
28
- # ----------------------------------------------------------------------
29
- observation_space:
30
- type: object
31
- properties:
32
- code_snippet:
33
- type: string
34
- description: "Current code snippet (may contain injected bug)"
35
- last_tool_output:
36
- type: string
37
- description: "Raw output from last tool (test runner, linter, etc.)"
38
- author_response:
39
- type: string
40
- description: "Latest feedback from the simulated human developer"
41
- current_test_score:
42
- type: number
43
- description: "Proportion of tests passed (0.0–1.0)"
44
- current_lint_score:
45
- type: number
46
- description: "Normalised pylint score (0.0–1.0)"
47
- negotiation_score:
48
- type: number
49
- description: "Author's confidence minus pushback penalty"
50
- previous_test_score:
51
- type: number
52
- description: "Test score before the last action"
53
- previous_lint_score:
54
- type: number
55
- description: "Lint score before the last action"
56
- author_confidence:
57
- type: number
58
- description: "Internal belief of the author (0.0–1.0)"
59
- author_threshold:
60
- type: number
61
- description: "Confidence threshold for this personality"
62
- step:
63
- type: integer
64
- description: "Current step number"
65
- max_steps:
66
- type: integer
67
- description: "Maximum steps allowed in the episode"
68
- progress_ratio:
69
- type: number
70
- description: "step / max_steps"
71
- tests_run:
72
- type: boolean
73
- description: "Whether the agent has run tests at least once"
74
- linter_run:
75
- type: boolean
76
- description: "Whether the agent has run the linter at least once"
77
- docs_queried:
78
- type: boolean
79
- description: "Whether the agent has queried documentation"
80
- last_action_type:
81
- type: string
82
- description: "String name of the last executed action"
83
- action_history:
84
- type: array
85
- items:
86
- type: string
87
- description: "Last 5 action types"
88
- done:
89
- type: boolean
90
- description: "Whether the episode has finished"
91
- bug_description:
92
- type: string
93
- description: "Short description of the injected bug"
94
- comments_count:
95
- type: integer
96
- description: "Number of comments exchanged so far"
97
-
98
- # ----------------------------------------------------------------------
99
- # Action space (short names as produced by the agent)
100
- # ----------------------------------------------------------------------
101
- action_space:
102
- type: object
103
- properties:
104
- action_type:
105
- type: string
106
- enum:
107
- - comment
108
- - skip
109
- - done
110
- - question
111
- - fix
112
- - execute
113
- - inspect
114
- - run_linter
115
- - run_tests
116
- - query_docs
117
- comment_text:
118
- type: string
119
- description: "Required for comment"
120
- question:
121
- type: string
122
- description: "Required for question"
123
- fix_code:
124
- type: string
125
- description: "Required for fix"
126
- query_topic:
127
- type: string
128
- description: "Required for query_docs"
129
-
130
- # ----------------------------------------------------------------------
131
- # (Optional) Server configuration – used by openenv serve
132
- # ----------------------------------------------------------------------
133
- server:
134
- app: server.app:app
135
- port: 7860
 
1
+ # openenv.yaml – Environment metadata for OpenEnv
2
+ name: CodeReview-Professional-Workflow
3
+ version: 1.0.0
4
+ description: |
5
+ Multi‑turn code review environment for professional tasks.
6
+ Agent must inspect, test, lint, query docs, and negotiate with a simulated author
7
+ to fix injected bugs. Supports DPO training on full trajectories.
8
+ author: yuvraj gupta
9
+ license: MIT
10
+
11
+ # ----------------------------------------------------------------------
12
+ # Tasks (difficulty progression)
13
+ # ----------------------------------------------------------------------
14
+ tasks:
15
+ - id: easy
16
+ description: "Fix missing null check in a dictionary lookup"
17
+ - id: medium
18
+ description: "Improve loop efficiency (replace range(len) with direct iteration)"
19
+ - id: hard
20
+ description: "Handle division by zero in average calculation"
21
+ - id: harder
22
+ description: "Fix race condition by adding a lock"
23
+ - id: hardest
24
+ description: "Resolve potential deadlock by standardising lock order"
25
+
26
+ # ----------------------------------------------------------------------
27
+ # Observation space (complete Markov state – agent sees everything)
28
+ # ----------------------------------------------------------------------
29
+ observation_space:
30
+ type: object
31
+ properties:
32
+ code_snippet:
33
+ type: string
34
+ description: "Current code snippet (may contain injected bug)"
35
+ last_tool_output:
36
+ type: string
37
+ description: "Raw output from last tool (test runner, linter, etc.)"
38
+ author_response:
39
+ type: string
40
+ description: "Latest feedback from the simulated human developer"
41
+ current_test_score:
42
+ type: number
43
+ description: "Proportion of tests passed (0.0–1.0)"
44
+ current_lint_score:
45
+ type: number
46
+ description: "Normalised pylint score (0.0–1.0)"
47
+ negotiation_score:
48
+ type: number
49
+ description: "Author's confidence minus pushback penalty"
50
+ previous_test_score:
51
+ type: number
52
+ description: "Test score before the last action"
53
+ previous_lint_score:
54
+ type: number
55
+ description: "Lint score before the last action"
56
+ author_confidence:
57
+ type: number
58
+ description: "Internal belief of the author (0.0–1.0)"
59
+ author_threshold:
60
+ type: number
61
+ description: "Confidence threshold for this personality"
62
+ step:
63
+ type: integer
64
+ description: "Current step number"
65
+ max_steps:
66
+ type: integer
67
+ description: "Maximum steps allowed in the episode"
68
+ progress_ratio:
69
+ type: number
70
+ description: "step / max_steps"
71
+ tests_run:
72
+ type: boolean
73
+ description: "Whether the agent has run tests at least once"
74
+ linter_run:
75
+ type: boolean
76
+ description: "Whether the agent has run the linter at least once"
77
+ docs_queried:
78
+ type: boolean
79
+ description: "Whether the agent has queried documentation"
80
+ last_action_type:
81
+ type: string
82
+ description: "String name of the last executed action"
83
+ action_history:
84
+ type: array
85
+ items:
86
+ type: string
87
+ description: "Last 5 action types"
88
+ done:
89
+ type: boolean
90
+ description: "Whether the episode has finished"
91
+ bug_description:
92
+ type: string
93
+ description: "Short description of the injected bug"
94
+ comments_count:
95
+ type: integer
96
+ description: "Number of comments exchanged so far"
97
+
98
+ # ----------------------------------------------------------------------
99
+ # Action space (short names as produced by the agent)
100
+ # ----------------------------------------------------------------------
101
+ action_space:
102
+ type: object
103
+ properties:
104
+ action_type:
105
+ type: string
106
+ enum:
107
+ - comment
108
+ - skip
109
+ - done
110
+ - question
111
+ - fix
112
+ - execute
113
+ - inspect
114
+ - run_linter
115
+ - run_tests
116
+ - query_docs
117
+ comment_text:
118
+ type: string
119
+ description: "Required for comment"
120
+ question:
121
+ type: string
122
+ description: "Required for question"
123
+ fix_code:
124
+ type: string
125
+ description: "Required for fix"
126
+ query_topic:
127
+ type: string
128
+ description: "Required for query_docs"
129
+
130
+ # ----------------------------------------------------------------------
131
+ # (Optional) Server configuration – used by openenv serve
132
+ # ----------------------------------------------------------------------
133
+ server:
134
+ app: server.app:app
135
+ port: 7860
pyproject.toml CHANGED
@@ -1,30 +1,30 @@
1
- [build-system]
2
- requires = ["setuptools>=61.0", "wheel"]
3
- build-backend = "setuptools.build_meta"
4
-
5
- [project]
6
- name = "code_review_professional"
7
- version = "1.0.0"
8
- description = "Multi‑turn code review environment with AST injection, DPO training, and author negotiation"
9
- authors = [{name = "yuvraj gupta", email = "yuvraj467229@gmail.com"}]
10
- license = {text = "MIT"}
11
- readme = "README.md"
12
- requires-python = ">=3.10"
13
- dependencies = [
14
- "openenv-core>=0.2.0",
15
- "fastapi>=0.115.0",
16
- "uvicorn>=0.24.0",
17
- "unsloth>=2025.3.1",
18
- "trl>=0.15.0",
19
- "accelerate>=1.2.0",
20
- "pylint>=3.3.0",
21
- "sentence-transformers>=3.3.0",
22
- "datasets>=3.3.0",
23
- "chromadb>=0.5.0",
24
- ]
25
-
26
- [project.optional-dependencies]
27
- dev = ["pytest>=7.0", "black>=23.0", "isort>=5.0"]
28
-
29
- [tool.openenv]
30
  server = "server.app:app"
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "code_review_professional"
7
+ version = "1.0.0"
8
+ description = "Multi‑turn code review environment with AST injection, DPO training, and author negotiation"
9
+ authors = [{name = "yuvraj gupta", email = "yuvraj467229@gmail.com"}]
10
+ license = {text = "MIT"}
11
+ readme = "README.md"
12
+ requires-python = ">=3.10"
13
+ dependencies = [
14
+ "openenv-core>=0.2.0",
15
+ "fastapi>=0.115.0",
16
+ "uvicorn>=0.24.0",
17
+ "unsloth>=2025.3.1",
18
+ "trl>=0.15.0",
19
+ "accelerate>=1.2.0",
20
+ "pylint>=3.3.0",
21
+ "sentence-transformers>=3.3.0",
22
+ "datasets>=3.3.0",
23
+ "chromadb>=0.5.0",
24
+ ]
25
+
26
+ [project.optional-dependencies]
27
+ dev = ["pytest>=7.0", "black>=23.0", "isort>=5.0"]
28
+
29
+ [tool.openenv]
30
  server = "server.app:app"
redteam.py CHANGED
@@ -1,274 +1,274 @@
1
- # redteam.py – Task‑aware bug injection (25 bugs, 5 difficulty levels)
2
- import ast
3
- import random
4
- from dataclasses import dataclass, field
5
- from typing import Tuple, Optional, List, Dict
6
-
7
- # ----------------------------------------------------------------------
8
- # 1. AST Bug Injector (extended for all simple bugs)
9
- # ----------------------------------------------------------------------
10
- class ASTBugInjector(ast.NodeTransformer):
11
- def __init__(self, bug_type: str):
12
- super().__init__()
13
- self.bug_type = bug_type
14
- self.modified = False
15
-
16
- # --- Easy: null_check, simple_typo, string_index, default_value, empty_return ---
17
- def visit_If(self, node: ast.If):
18
- # null_check: remove the if-guard
19
- if self.bug_type == "null_check" and not self.modified:
20
- if node.body and len(node.body) == 1:
21
- self.modified = True
22
- return node.body[0]
23
- # division_by_zero_empty: remove the empty check
24
- if self.bug_type == "division_by_zero_empty" and not self.modified:
25
- # pattern: if not data: return 0 – we delete the entire if
26
- if (isinstance(node.test, ast.UnaryOp) and
27
- isinstance(node.test.op, ast.Not) and
28
- isinstance(node.test.operand, ast.Name)):
29
- self.modified = True
30
- return None # signal to remove this node from parent
31
- return self.generic_visit(node)
32
-
33
- def visit_Name(self, node: ast.Name):
34
- if self.bug_type == "simple_typo" and not self.modified:
35
- if node.id == "users":
36
- self.modified = True
37
- return ast.Name(id="usres", ctx=node.ctx)
38
- return self.generic_visit(node)
39
-
40
- def visit_Subscript(self, node: ast.Subscript):
41
- if self.bug_type == "string_index" and not self.modified:
42
- if isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Constant):
43
- old_val = node.slice.value.value
44
- if isinstance(old_val, int):
45
- self.modified = True
46
- node.slice = ast.Index(value=ast.Constant(value=old_val + 1))
47
- return self.generic_visit(node)
48
-
49
- def visit_Call(self, node: ast.Call):
50
- # default_value: change dict.get(key) to dict[key] (no default)
51
- if self.bug_type == "default_value" and not self.modified:
52
- if (isinstance(node.func, ast.Attribute) and
53
- node.func.attr == "get" and len(node.args) == 1):
54
- self.modified = True
55
- return ast.Subscript(
56
- value=node.func.value,
57
- slice=ast.Index(value=node.args[0]),
58
- ctx=node.ctx
59
- )
60
- # abs_usage: remove abs()
61
- if self.bug_type == "abs_usage" and not self.modified:
62
- if isinstance(node.func, ast.Name) and node.func.id == "abs":
63
- self.modified = True
64
- return node.args[0]
65
- return self.generic_visit(node)
66
-
67
- def visit_FunctionDef(self, node: ast.FunctionDef):
68
- # empty_return: insert a premature return None
69
- if self.bug_type == "empty_return" and not self.modified:
70
- self.modified = True
71
- node.body.insert(0, ast.Return(value=ast.Constant(value=None)))
72
- return self.generic_visit(node)
73
-
74
- # --- Medium: off_by_one, loop_skip, sign_error, swap_args, uninitialised_var ---
75
- def visit_For(self, node: ast.For):
76
- if (self.bug_type in ("off_by_one", "loop_skip")) and not self.modified:
77
- if (isinstance(node.iter, ast.Call) and
78
- isinstance(node.iter.func, ast.Name) and
79
- node.iter.func.id == "range"):
80
- if self.bug_type == "off_by_one":
81
- new_iter = ast.Call(
82
- func=ast.Name(id='range', ctx=ast.Load()),
83
- args=[
84
- ast.Constant(value=1),
85
- ast.BinOp(left=node.iter.args[0], op=ast.Sub(), right=ast.Constant(value=1))
86
- ],
87
- keywords=[]
88
- )
89
- node.iter = new_iter
90
- self.modified = True
91
- elif self.bug_type == "loop_skip" and len(node.iter.args) == 1:
92
- new_iter = ast.Call(
93
- func=ast.Name(id='range', ctx=ast.Load()),
94
- args=[ast.BinOp(left=node.iter.args[0], op=ast.Sub(), right=ast.Constant(value=1))],
95
- keywords=[]
96
- )
97
- node.iter = new_iter
98
- self.modified = True
99
- return self.generic_visit(node)
100
-
101
- def visit_BinOp(self, node: ast.BinOp):
102
- # sign_error: flip Add/Sub, wrong_operator: Add->Sub, float_precision: Div->FloorDiv
103
- if not self.modified:
104
- if self.bug_type in ("wrong_operator", "sign_error"):
105
- if isinstance(node.op, ast.Add):
106
- node.op = ast.Sub()
107
- self.modified = True
108
- elif isinstance(node.op, ast.Sub):
109
- node.op = ast.Add()
110
- self.modified = True
111
- elif self.bug_type == "float_precision" and isinstance(node.op, ast.Div):
112
- node.op = ast.FloorDiv()
113
- self.modified = True
114
- return self.generic_visit(node)
115
-
116
- def visit_arguments(self, node: ast.arguments):
117
- # swap_args: swap first two arguments of a function
118
- if self.bug_type == "swap_args" and not self.modified and len(node.args) >= 2:
119
- self.modified = True
120
- node.args[0], node.args[1] = node.args[1], node.args[0]
121
- return self.generic_visit(node)
122
-
123
- def visit_Assign(self, node: ast.Assign):
124
- # uninitialised_var: remove an assignment statement (replaced with Pass)
125
- if self.bug_type == "uninitialised_var" and not self.modified:
126
- self.modified = True
127
- return ast.Pass()
128
- return self.generic_visit(node)
129
-
130
- # ----------------------------------------------------------------------
131
- # 2. Bug database (25 bugs, categorized by difficulty)
132
- # ----------------------------------------------------------------------
133
- BUG_DB = {
134
- "easy": {
135
- "null_check": {"type": "ast", "bug_type": "null_check"},
136
- "simple_typo": {"type": "ast", "bug_type": "simple_typo"},
137
- "string_index": {"type": "ast", "bug_type": "string_index"},
138
- "default_value": {"type": "ast", "bug_type": "default_value"},
139
- "empty_return": {"type": "ast", "bug_type": "empty_return"},
140
- },
141
- "medium": {
142
- "off_by_one": {"type": "ast", "bug_type": "off_by_one"},
143
- "loop_skip": {"type": "ast", "bug_type": "loop_skip"},
144
- "sign_error": {"type": "ast", "bug_type": "sign_error"},
145
- "swap_args": {"type": "ast", "bug_type": "swap_args"},
146
- "uninitialised_var": {"type": "ast", "bug_type": "uninitialised_var"},
147
- },
148
- "hard": {
149
- "division_by_zero_empty": {"type": "ast", "bug_type": "division_by_zero_empty"},
150
- "division_by_zero_zero": {"type": "ast", "bug_type": "division_by_zero_empty"}, # same injector
151
- "float_precision": {"type": "ast", "bug_type": "float_precision"},
152
- "abs_usage": {"type": "ast", "bug_type": "abs_usage"},
153
- "round_error": {"type": "ast", "bug_type": "round_error"}, # can be extended
154
- },
155
- "harder": {
156
- "missing_lock": {
157
- "type": "template",
158
- "buggy": "counter = 0\ndef increment():\n global counter\n counter += 1",
159
- "oracle": "counter = 0\nimport threading\nlock = threading.Lock()\ndef increment():\n global counter\n with lock:\n counter += 1",
160
- },
161
- "double_lock": {
162
- "type": "template",
163
- "buggy": "import threading\nlock = threading.Lock()\ndef do_work():\n lock.acquire()\n lock.acquire()\n print('working')\n lock.release()",
164
- "oracle": "import threading\nlock = threading.Lock()\ndef do_work():\n with lock:\n print('working')",
165
- },
166
- "global_nonatomic": {
167
- "type": "template",
168
- "buggy": "count = 0\ndef add():\n global count\n count = count + 1",
169
- "oracle": "count = 0\ndef add():\n global count\n count += 1",
170
- },
171
- "thread_safe_list": {
172
- "type": "template",
173
- "buggy": "import threading\nitems = []\ndef append_item(item):\n items.append(item)",
174
- "oracle": "import threading\nitems = []\nlock = threading.Lock()\ndef append_item(item):\n with lock:\n items.append(item)",
175
- },
176
- "volatile_read": {
177
- "type": "template",
178
- "buggy": "import threading\nstop = False\ndef worker():\n while not stop:\n pass",
179
- "oracle": "import threading\nstop = False\nlock = threading.Lock()\ndef worker():\n while True:\n with lock:\n if stop:\n break",
180
- },
181
- },
182
- "hardest": {
183
- "deadlock_order": {
184
- "type": "template",
185
- "buggy": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock2:\n with lock1:\n pass",
186
- "oracle": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock1:\n with lock2:\n pass",
187
- },
188
- "nested_lock_timeout": {
189
- "type": "template",
190
- "buggy": "import threading\nlock = threading.Lock()\ndef work():\n lock.acquire()\n # critical section\n lock.release()",
191
- "oracle": "import threading\nlock = threading.Lock()\ndef work():\n if lock.acquire(timeout=1):\n try:\n # critical section\n finally:\n lock.release()",
192
- },
193
- "fork_join": {
194
- "type": "template",
195
- "buggy": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()",
196
- "oracle": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()\nt.join()",
197
- },
198
- "mutex_release": {
199
- "type": "template",
200
- "buggy": "import threading\nlock = threading.Lock()\ndef thread_A():\n lock.acquire()\n lock.release()\ndef thread_B():\n lock.release()",
201
- "oracle": "import threading\nlock = threading.Lock()\ndef thread_A():\n with lock:\n pass\ndef thread_B():\n with lock:\n pass",
202
- },
203
- "race_on_init": {
204
- "type": "template",
205
- "buggy": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nprint(items)",
206
- "oracle": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nt.join()\nprint(items)",
207
- },
208
- },
209
- }
210
-
211
- # ----------------------------------------------------------------------
212
- # 3. Derived helpers
213
- # ----------------------------------------------------------------------
214
- TASK_BUG_MAP = {level: list(bugs.keys()) for level, bugs in BUG_DB.items()}
215
-
216
- TEMPLATE_BUGS = {}
217
- for level, bugs in BUG_DB.items():
218
- for bug_id, bug in bugs.items():
219
- if bug["type"] == "template":
220
- TEMPLATE_BUGS[bug_id] = (bug["buggy"], bug["oracle"])
221
-
222
- # ----------------------------------------------------------------------
223
- # 4. RedTeam Controller (task‑aware)
224
- # ----------------------------------------------------------------------
225
- @dataclass
226
- class RedTeam:
227
- task: str
228
- seed: Optional[int] = 42
229
- noise_prob: float = 0.2
230
- _random: random.Random = field(init=False)
231
-
232
- def __post_init__(self):
233
- self._random = random.Random(self.seed)
234
-
235
- def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
236
- """
237
- Returns: (buggy_code, bug_type, description, oracle_fix)
238
- Selects a bug appropriate for the task difficulty.
239
- """
240
- bug_list = TASK_BUG_MAP.get(self.task, ["null_check"])
241
- bug_type = self._random.choice(bug_list)
242
-
243
- # Template bug: return hardcoded buggy + oracle
244
- if bug_type in TEMPLATE_BUGS:
245
- buggy_code, oracle_code = TEMPLATE_BUGS[bug_type]
246
- description = f"Template bug: {bug_type}"
247
- if self._random.random() < self.noise_prob:
248
- buggy_code += "\n# TODO: refactor later"
249
- return buggy_code, bug_type, description, oracle_code
250
-
251
- # AST injection
252
- try:
253
- tree = ast.parse(original_code)
254
- except SyntaxError:
255
- return original_code, "parse_error", "Syntax error in original code", original_code
256
-
257
- injector = ASTBugInjector(bug_type)
258
- modified_tree = injector.visit(tree)
259
- ast.fix_missing_locations(modified_tree)
260
-
261
- if injector.modified:
262
- buggy_code = ast.unparse(modified_tree)
263
- oracle_fix = original_code
264
- description = f"AST bug: {bug_type}"
265
- else:
266
- buggy_code = original_code
267
- oracle_fix = original_code
268
- bug_type = "no_op"
269
- description = "No suitable code structure found for injection"
270
-
271
- if self._random.random() < self.noise_prob:
272
- buggy_code += "\n# TODO: refactor later"
273
-
274
- return buggy_code, bug_type, description, oracle_fix
 
1
+ # redteam.py – Task‑aware bug injection (25 bugs, 5 difficulty levels)
2
+ import ast
3
+ import random
4
+ from dataclasses import dataclass, field
5
+ from typing import Tuple, Optional, List, Dict
6
+
7
+ # ----------------------------------------------------------------------
8
+ # 1. AST Bug Injector (extended for all simple bugs)
9
+ # ----------------------------------------------------------------------
10
+ class ASTBugInjector(ast.NodeTransformer):
11
+ def __init__(self, bug_type: str):
12
+ super().__init__()
13
+ self.bug_type = bug_type
14
+ self.modified = False
15
+
16
+ # --- Easy: null_check, simple_typo, string_index, default_value, empty_return ---
17
+ def visit_If(self, node: ast.If):
18
+ # null_check: remove the if-guard
19
+ if self.bug_type == "null_check" and not self.modified:
20
+ if node.body and len(node.body) == 1:
21
+ self.modified = True
22
+ return node.body[0]
23
+ # division_by_zero_empty: remove the empty check
24
+ if self.bug_type == "division_by_zero_empty" and not self.modified:
25
+ # pattern: if not data: return 0 – we delete the entire if
26
+ if (isinstance(node.test, ast.UnaryOp) and
27
+ isinstance(node.test.op, ast.Not) and
28
+ isinstance(node.test.operand, ast.Name)):
29
+ self.modified = True
30
+ return None # signal to remove this node from parent
31
+ return self.generic_visit(node)
32
+
33
+ def visit_Name(self, node: ast.Name):
34
+ if self.bug_type == "simple_typo" and not self.modified:
35
+ if node.id == "users":
36
+ self.modified = True
37
+ return ast.Name(id="usres", ctx=node.ctx)
38
+ return self.generic_visit(node)
39
+
40
+ def visit_Subscript(self, node: ast.Subscript):
41
+ if self.bug_type == "string_index" and not self.modified:
42
+ if isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Constant):
43
+ old_val = node.slice.value.value
44
+ if isinstance(old_val, int):
45
+ self.modified = True
46
+ node.slice = ast.Index(value=ast.Constant(value=old_val + 1))
47
+ return self.generic_visit(node)
48
+
49
+ def visit_Call(self, node: ast.Call):
50
+ # default_value: change dict.get(key) to dict[key] (no default)
51
+ if self.bug_type == "default_value" and not self.modified:
52
+ if (isinstance(node.func, ast.Attribute) and
53
+ node.func.attr == "get" and len(node.args) == 1):
54
+ self.modified = True
55
+ return ast.Subscript(
56
+ value=node.func.value,
57
+ slice=ast.Index(value=node.args[0]),
58
+ ctx=node.ctx
59
+ )
60
+ # abs_usage: remove abs()
61
+ if self.bug_type == "abs_usage" and not self.modified:
62
+ if isinstance(node.func, ast.Name) and node.func.id == "abs":
63
+ self.modified = True
64
+ return node.args[0]
65
+ return self.generic_visit(node)
66
+
67
+ def visit_FunctionDef(self, node: ast.FunctionDef):
68
+ # empty_return: insert a premature return None
69
+ if self.bug_type == "empty_return" and not self.modified:
70
+ self.modified = True
71
+ node.body.insert(0, ast.Return(value=ast.Constant(value=None)))
72
+ return self.generic_visit(node)
73
+
74
+ # --- Medium: off_by_one, loop_skip, sign_error, swap_args, uninitialised_var ---
75
+ def visit_For(self, node: ast.For):
76
+ if (self.bug_type in ("off_by_one", "loop_skip")) and not self.modified:
77
+ if (isinstance(node.iter, ast.Call) and
78
+ isinstance(node.iter.func, ast.Name) and
79
+ node.iter.func.id == "range"):
80
+ if self.bug_type == "off_by_one":
81
+ new_iter = ast.Call(
82
+ func=ast.Name(id='range', ctx=ast.Load()),
83
+ args=[
84
+ ast.Constant(value=1),
85
+ ast.BinOp(left=node.iter.args[0], op=ast.Sub(), right=ast.Constant(value=1))
86
+ ],
87
+ keywords=[]
88
+ )
89
+ node.iter = new_iter
90
+ self.modified = True
91
+ elif self.bug_type == "loop_skip" and len(node.iter.args) == 1:
92
+ new_iter = ast.Call(
93
+ func=ast.Name(id='range', ctx=ast.Load()),
94
+ args=[ast.BinOp(left=node.iter.args[0], op=ast.Sub(), right=ast.Constant(value=1))],
95
+ keywords=[]
96
+ )
97
+ node.iter = new_iter
98
+ self.modified = True
99
+ return self.generic_visit(node)
100
+
101
+ def visit_BinOp(self, node: ast.BinOp):
102
+ # sign_error: flip Add/Sub, wrong_operator: Add->Sub, float_precision: Div->FloorDiv
103
+ if not self.modified:
104
+ if self.bug_type in ("wrong_operator", "sign_error"):
105
+ if isinstance(node.op, ast.Add):
106
+ node.op = ast.Sub()
107
+ self.modified = True
108
+ elif isinstance(node.op, ast.Sub):
109
+ node.op = ast.Add()
110
+ self.modified = True
111
+ elif self.bug_type == "float_precision" and isinstance(node.op, ast.Div):
112
+ node.op = ast.FloorDiv()
113
+ self.modified = True
114
+ return self.generic_visit(node)
115
+
116
+ def visit_arguments(self, node: ast.arguments):
117
+ # swap_args: swap first two arguments of a function
118
+ if self.bug_type == "swap_args" and not self.modified and len(node.args) >= 2:
119
+ self.modified = True
120
+ node.args[0], node.args[1] = node.args[1], node.args[0]
121
+ return self.generic_visit(node)
122
+
123
+ def visit_Assign(self, node: ast.Assign):
124
+ # uninitialised_var: remove an assignment statement (replaced with Pass)
125
+ if self.bug_type == "uninitialised_var" and not self.modified:
126
+ self.modified = True
127
+ return ast.Pass()
128
+ return self.generic_visit(node)
129
+
130
+ # ----------------------------------------------------------------------
131
+ # 2. Bug database (25 bugs, categorized by difficulty)
132
+ # ----------------------------------------------------------------------
133
+ BUG_DB = {
134
+ "easy": {
135
+ "null_check": {"type": "ast", "bug_type": "null_check"},
136
+ "simple_typo": {"type": "ast", "bug_type": "simple_typo"},
137
+ "string_index": {"type": "ast", "bug_type": "string_index"},
138
+ "default_value": {"type": "ast", "bug_type": "default_value"},
139
+ "empty_return": {"type": "ast", "bug_type": "empty_return"},
140
+ },
141
+ "medium": {
142
+ "off_by_one": {"type": "ast", "bug_type": "off_by_one"},
143
+ "loop_skip": {"type": "ast", "bug_type": "loop_skip"},
144
+ "sign_error": {"type": "ast", "bug_type": "sign_error"},
145
+ "swap_args": {"type": "ast", "bug_type": "swap_args"},
146
+ "uninitialised_var": {"type": "ast", "bug_type": "uninitialised_var"},
147
+ },
148
+ "hard": {
149
+ "division_by_zero_empty": {"type": "ast", "bug_type": "division_by_zero_empty"},
150
+ "division_by_zero_zero": {"type": "ast", "bug_type": "division_by_zero_empty"}, # same injector
151
+ "float_precision": {"type": "ast", "bug_type": "float_precision"},
152
+ "abs_usage": {"type": "ast", "bug_type": "abs_usage"},
153
+ "round_error": {"type": "ast", "bug_type": "round_error"}, # can be extended
154
+ },
155
+ "harder": {
156
+ "missing_lock": {
157
+ "type": "template",
158
+ "buggy": "counter = 0\ndef increment():\n global counter\n counter += 1",
159
+ "oracle": "counter = 0\nimport threading\nlock = threading.Lock()\ndef increment():\n global counter\n with lock:\n counter += 1",
160
+ },
161
+ "double_lock": {
162
+ "type": "template",
163
+ "buggy": "import threading\nlock = threading.Lock()\ndef do_work():\n lock.acquire()\n lock.acquire()\n print('working')\n lock.release()",
164
+ "oracle": "import threading\nlock = threading.Lock()\ndef do_work():\n with lock:\n print('working')",
165
+ },
166
+ "global_nonatomic": {
167
+ "type": "template",
168
+ "buggy": "count = 0\ndef add():\n global count\n count = count + 1",
169
+ "oracle": "count = 0\ndef add():\n global count\n count += 1",
170
+ },
171
+ "thread_safe_list": {
172
+ "type": "template",
173
+ "buggy": "import threading\nitems = []\ndef append_item(item):\n items.append(item)",
174
+ "oracle": "import threading\nitems = []\nlock = threading.Lock()\ndef append_item(item):\n with lock:\n items.append(item)",
175
+ },
176
+ "volatile_read": {
177
+ "type": "template",
178
+ "buggy": "import threading\nstop = False\ndef worker():\n while not stop:\n pass",
179
+ "oracle": "import threading\nstop = False\nlock = threading.Lock()\ndef worker():\n while True:\n with lock:\n if stop:\n break",
180
+ },
181
+ },
182
+ "hardest": {
183
+ "deadlock_order": {
184
+ "type": "template",
185
+ "buggy": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock2:\n with lock1:\n pass",
186
+ "oracle": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock1:\n with lock2:\n pass",
187
+ },
188
+ "nested_lock_timeout": {
189
+ "type": "template",
190
+ "buggy": "import threading\nlock = threading.Lock()\ndef work():\n lock.acquire()\n # critical section\n lock.release()",
191
+ "oracle": "import threading\nlock = threading.Lock()\ndef work():\n if lock.acquire(timeout=1):\n try:\n # critical section\n finally:\n lock.release()",
192
+ },
193
+ "fork_join": {
194
+ "type": "template",
195
+ "buggy": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()",
196
+ "oracle": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()\nt.join()",
197
+ },
198
+ "mutex_release": {
199
+ "type": "template",
200
+ "buggy": "import threading\nlock = threading.Lock()\ndef thread_A():\n lock.acquire()\n lock.release()\ndef thread_B():\n lock.release()",
201
+ "oracle": "import threading\nlock = threading.Lock()\ndef thread_A():\n with lock:\n pass\ndef thread_B():\n with lock:\n pass",
202
+ },
203
+ "race_on_init": {
204
+ "type": "template",
205
+ "buggy": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nprint(items)",
206
+ "oracle": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nt.join()\nprint(items)",
207
+ },
208
+ },
209
+ }
210
+
211
+ # ----------------------------------------------------------------------
212
+ # 3. Derived helpers
213
+ # ----------------------------------------------------------------------
214
+ TASK_BUG_MAP = {level: list(bugs.keys()) for level, bugs in BUG_DB.items()}
215
+
216
+ TEMPLATE_BUGS = {}
217
+ for level, bugs in BUG_DB.items():
218
+ for bug_id, bug in bugs.items():
219
+ if bug["type"] == "template":
220
+ TEMPLATE_BUGS[bug_id] = (bug["buggy"], bug["oracle"])
221
+
222
+ # ----------------------------------------------------------------------
223
+ # 4. RedTeam Controller (task‑aware)
224
+ # ----------------------------------------------------------------------
225
+ @dataclass
226
+ class RedTeam:
227
+ task: str
228
+ seed: Optional[int] = 42
229
+ noise_prob: float = 0.2
230
+ _random: random.Random = field(init=False)
231
+
232
+ def __post_init__(self):
233
+ self._random = random.Random(self.seed)
234
+
235
+ def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
236
+ """
237
+ Returns: (buggy_code, bug_type, description, oracle_fix)
238
+ Selects a bug appropriate for the task difficulty.
239
+ """
240
+ bug_list = TASK_BUG_MAP.get(self.task, ["null_check"])
241
+ bug_type = self._random.choice(bug_list)
242
+
243
+ # Template bug: return hardcoded buggy + oracle
244
+ if bug_type in TEMPLATE_BUGS:
245
+ buggy_code, oracle_code = TEMPLATE_BUGS[bug_type]
246
+ description = f"Template bug: {bug_type}"
247
+ if self._random.random() < self.noise_prob:
248
+ buggy_code += "\n# TODO: refactor later"
249
+ return buggy_code, bug_type, description, oracle_code
250
+
251
+ # AST injection
252
+ try:
253
+ tree = ast.parse(original_code)
254
+ except SyntaxError:
255
+ return original_code, "parse_error", "Syntax error in original code", original_code
256
+
257
+ injector = ASTBugInjector(bug_type)
258
+ modified_tree = injector.visit(tree)
259
+ ast.fix_missing_locations(modified_tree)
260
+
261
+ if injector.modified:
262
+ buggy_code = ast.unparse(modified_tree)
263
+ oracle_fix = original_code
264
+ description = f"AST bug: {bug_type}"
265
+ else:
266
+ buggy_code = original_code
267
+ oracle_fix = original_code
268
+ bug_type = "no_op"
269
+ description = "No suitable code structure found for injection"
270
+
271
+ if self._random.random() < self.noise_prob:
272
+ buggy_code += "\n# TODO: refactor later"
273
+
274
+ return buggy_code, bug_type, description, oracle_fix
rubrics.py CHANGED
@@ -1,123 +1,136 @@
1
- # rubrics.py – Self-contained Rubrics (no external OpenEnv dependency)
2
-
3
- class Rubric:
4
- """Minimal Rubric base – compatible with OpenEnv but self‑contained."""
5
- def __call__(self, env, action, obs, reward, done, info):
6
- return 0.0
7
-
8
-
9
- # --------------------------------------------------------------------------------
10
- # 1. TOOL‑USAGE BONUS
11
- # --------------------------------------------------------------------------------
12
- class ToolUsageRubric(Rubric):
13
- def __init__(self, bonus: float = 0.05):
14
- self.bonus = bonus
15
-
16
- def __call__(self, env, action, obs, reward, done, info):
17
- score = 0.0
18
- action_type = info.get("action_type", "")
19
-
20
- if action_type == "run_tests":
21
- if not env._tests_run:
22
- score += self.bonus
23
- score += 0.015
24
- elif action_type == "run_linter":
25
- if not env._linter_run:
26
- score += self.bonus
27
- score += 0.015
28
- elif action_type == "query_docs":
29
- if not env._docs_queried:
30
- score += self.bonus * 0.5
31
- elif action_type == "ask_question" and env._step_count <= 3:
32
- score += 0.02
33
- return score
34
-
35
-
36
- # --------------------------------------------------------------------------------
37
- # 2. DELTA‑BASED REWARDS
38
- # --------------------------------------------------------------------------------
39
- class TestDeltaRubric(Rubric):
40
- def __init__(self, weight: float = 0.3):
41
- self.weight = weight
42
-
43
- def __call__(self, env, action, obs, reward, done, info):
44
- delta = env._current_test_score - env._previous_test_score
45
- effective = self.weight
46
- if info.get("action_type") == "propose_fix":
47
- effective *= 0.4
48
- return effective * delta
49
-
50
-
51
- class LintDeltaRubric(Rubric):
52
- def __init__(self, weight: float = 0.3):
53
- self.weight = weight
54
-
55
- def __call__(self, env, action, obs, reward, done, info):
56
- delta = env._current_lint_score - env._previous_lint_score
57
- effective = self.weight * 0.5
58
- if info.get("action_type") == "propose_fix":
59
- effective *= 0.4
60
- return effective * delta
61
-
62
-
63
- # --------------------------------------------------------------------------------
64
- # 3. TERMINAL SUCCESS BONUS
65
- # --------------------------------------------------------------------------------
66
- class TerminalSuccessRubric(Rubric):
67
- def __call__(self, env, action, obs, reward, done, info):
68
- if info.get("action_type") != "propose_fix":
69
- return 0.0
70
- score = 0.0
71
- if env._current_test_score > 0.95:
72
- score += 0.4
73
- elif env._current_test_score > 0.85:
74
- score += 0.2
75
- return score
76
-
77
-
78
- # --------------------------------------------------------------------------------
79
- # 4. EXPLORATION & DIVERSITY
80
- # --------------------------------------------------------------------------------
81
- class ExplorationRubric(Rubric):
82
- def __init__(self, penalty: float = -0.05, bonus: float = 0.021):
83
- self.penalty = penalty
84
- self.bonus = bonus
85
-
86
- def __call__(self, env, action, obs, reward, done, info):
87
- if len(env._action_history) < 3:
88
- return 0.0
89
- recent = env._action_history[-3:]
90
- unique = len(set(recent))
91
- if unique == 1:
92
- return self.penalty
93
- elif unique == 3:
94
- return self.bonus
95
- return 0.0
96
-
97
-
98
- # --------------------------------------------------------------------------------
99
- # 5. ANTI‑HACKING & CONSISTENCY
100
- # --------------------------------------------------------------------------------
101
- class AntiHackingRubric(Rubric):
102
- def __call__(self, env, action, obs, reward, done, info):
103
- if info.get("action_type") != "propose_fix":
104
- return 0.0
105
- score = 0.0
106
- if not env._tests_run:
107
- score -= 0.25
108
- if env._step_count < 2:
109
- score -= 0.1
110
- if env._tests_run and env._linter_run:
111
- score += 0.02
112
- return score
113
-
114
-
115
- # --------------------------------------------------------------------------------
116
- # 6. STEP PENALTY
117
- # --------------------------------------------------------------------------------
118
- class StepPenaltyRubric(Rubric):
119
- def __init__(self, penalty: float = -0.01):
120
- self.penalty = penalty
121
-
122
- def __call__(self, env, action, obs, reward, done, info):
123
- return self.penalty
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rubrics.py – Self-contained Rubrics (no external OpenEnv dependency)
2
+
3
+ class Rubric:
4
+ """Minimal Rubric base – compatible with OpenEnv but self‑contained."""
5
+ def __call__(self, env, action, obs, reward, done, info):
6
+ return 0.0
7
+
8
+
9
+ # --------------------------------------------------------------------------------
10
+ # 1. TOOL‑USAGE BONUS
11
+ # --------------------------------------------------------------------------------
12
+ class ToolUsageRubric(Rubric):
13
+ def __init__(self, bonus: float = 0.05):
14
+ self.bonus = bonus
15
+
16
+ def __call__(self, env, action, obs, reward, done, info):
17
+ score = 0.0
18
+ action_type = info.get("action_type", "")
19
+ # Use pre-action flags from `info` so first-use bonuses are
20
+ # computed correctly even though env flags are mutated in-step.
21
+ prev_tests_run = info.get("prev_tests_run", env._tests_run)
22
+ prev_linter_run = info.get("prev_linter_run", env._linter_run)
23
+ prev_docs_queried = info.get("prev_docs_queried", env._docs_queried)
24
+
25
+ if action_type == "run_tests":
26
+ if not prev_tests_run:
27
+ score += self.bonus
28
+ score += 0.015
29
+ elif action_type == "run_linter":
30
+ if not prev_linter_run:
31
+ score += self.bonus
32
+ score += 0.015
33
+ elif action_type == "query_docs":
34
+ if not prev_docs_queried:
35
+ score += self.bonus * 0.5
36
+ # Encourage docs usage when it is likely useful:
37
+ # - early exploration phase
38
+ # - non-trivial query text
39
+ if env._step_count <= 4 and info.get("docs_query_len", 0) >= 8:
40
+ score += 0.01
41
+ # Discourage repeated docs calls after the first-use signal.
42
+ if prev_docs_queried:
43
+ score -= 0.01
44
+ elif action_type == "ask_question" and env._step_count <= 3:
45
+ score += 0.02
46
+ return score
47
+
48
+
49
+ # --------------------------------------------------------------------------------
50
+ # 2. DELTA‑BASED REWARDS
51
+ # --------------------------------------------------------------------------------
52
+ class TestDeltaRubric(Rubric):
53
+ def __init__(self, weight: float = 0.3):
54
+ self.weight = weight
55
+
56
+ def __call__(self, env, action, obs, reward, done, info):
57
+ delta = env._current_test_score - env._previous_test_score
58
+ effective = self.weight
59
+ if info.get("action_type") == "propose_fix":
60
+ effective *= 0.4
61
+ return effective * delta
62
+
63
+
64
+ class LintDeltaRubric(Rubric):
65
+ def __init__(self, weight: float = 0.3):
66
+ self.weight = weight
67
+
68
+ def __call__(self, env, action, obs, reward, done, info):
69
+ delta = env._current_lint_score - env._previous_lint_score
70
+ effective = self.weight * 0.5
71
+ if info.get("action_type") == "propose_fix":
72
+ effective *= 0.4
73
+ return effective * delta
74
+
75
+
76
+ # --------------------------------------------------------------------------------
77
+ # 3. TERMINAL SUCCESS BONUS
78
+ # --------------------------------------------------------------------------------
79
+ class TerminalSuccessRubric(Rubric):
80
+ def __call__(self, env, action, obs, reward, done, info):
81
+ if info.get("action_type") != "propose_fix":
82
+ return 0.0
83
+ score = 0.0
84
+ if env._current_test_score > 0.95:
85
+ score += 0.4
86
+ elif env._current_test_score > 0.85:
87
+ score += 0.2
88
+ return score
89
+
90
+
91
+ # --------------------------------------------------------------------------------
92
+ # 4. EXPLORATION & DIVERSITY
93
+ # --------------------------------------------------------------------------------
94
+ class ExplorationRubric(Rubric):
95
+ def __init__(self, penalty: float = -0.05, bonus: float = 0.021):
96
+ self.penalty = penalty
97
+ self.bonus = bonus
98
+
99
+ def __call__(self, env, action, obs, reward, done, info):
100
+ if len(env._action_history) < 3:
101
+ return 0.0
102
+ recent = env._action_history[-3:]
103
+ unique = len(set(recent))
104
+ if unique == 1:
105
+ return self.penalty
106
+ elif unique == 3:
107
+ return self.bonus
108
+ return 0.0
109
+
110
+
111
+ # --------------------------------------------------------------------------------
112
+ # 5. ANTI‑HACKING & CONSISTENCY
113
+ # --------------------------------------------------------------------------------
114
+ class AntiHackingRubric(Rubric):
115
+ def __call__(self, env, action, obs, reward, done, info):
116
+ if info.get("action_type") != "propose_fix":
117
+ return 0.0
118
+ score = 0.0
119
+ if not env._tests_run:
120
+ score -= 0.25
121
+ if env._step_count < 2:
122
+ score -= 0.1
123
+ if env._tests_run and env._linter_run:
124
+ score += 0.02
125
+ return score
126
+
127
+
128
+ # --------------------------------------------------------------------------------
129
+ # 6. STEP PENALTY
130
+ # --------------------------------------------------------------------------------
131
+ class StepPenaltyRubric(Rubric):
132
+ def __init__(self, penalty: float = -0.01):
133
+ self.penalty = penalty
134
+
135
+ def __call__(self, env, action, obs, reward, done, info):
136
+ return self.penalty
test_runner.py CHANGED
@@ -1,181 +1,208 @@
1
- # test_runner.py – Full production version with continuous scoring, dynamic function detection, and randomised tests
2
- import subprocess
3
- import tempfile
4
- import os
5
- import json
6
- import ast
7
- import random
8
- import sys
9
- from typing import Tuple, List, Any, Optional
10
- from dataclasses import dataclass
11
-
12
- @dataclass
13
- class TestRunner:
14
- bug_id: str
15
- timeout_sec: int = 5
16
- max_memory_mb: int = 256
17
- fuzz_rounds: int = 3 # number of random test cases per bug
18
-
19
- def run_tests(self, fix_code: str) -> Tuple[float, str]:
20
- """
21
- Returns (score, output_message) where score is proportion of passed tests (0.0–1.0).
22
- """
23
- # 1. Detect the function defined in the agent's code (dynamic)
24
- func_name = self._get_defined_function_name(fix_code)
25
- if not func_name:
26
- return 0.0, "No function definition found in agent code"
27
-
28
- # 2. Generate the test script (includes fixed + fuzzed test cases)
29
- test_script = self._generate_test_script(fix_code, func_name)
30
- with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
31
- f.write(test_script)
32
- tmp_path = f.name
33
-
34
- try:
35
- # Resource limiting (Linux only; fallback otherwise)
36
- try:
37
- import resource
38
- resource.setrlimit(resource.RLIMIT_AS, (self.max_memory_mb * 1024 * 1024, self.max_memory_mb * 1024 * 1024))
39
- except Exception:
40
- pass
41
-
42
- result = subprocess.run(
43
- [sys.executable, tmp_path],
44
- capture_output=True,
45
- text=True,
46
- timeout=self.timeout_sec,
47
- encoding='utf-8'
48
- )
49
- # Parse JSON output
50
- try:
51
- data = json.loads(result.stdout.strip())
52
- passed = data.get("passed", 0)
53
- total = data.get("total", 1)
54
- score = passed / total if total > 0 else 0.0
55
- return score, result.stdout.strip()
56
- except json.JSONDecodeError:
57
- # Fallback: look for "True" (legacy)
58
- if "True" in result.stdout:
59
- return 1.0, result.stdout
60
- return 0.0, result.stdout
61
- except subprocess.TimeoutExpired:
62
- return 0.0, "Test execution timed out"
63
- except Exception as e:
64
- return 0.0, f"Unexpected error: {str(e)}"
65
- finally:
66
- try:
67
- os.unlink(tmp_path)
68
- except:
69
- pass
70
-
71
- def _get_defined_function_name(self, code: str) -> Optional[str]:
72
- """Extract the target function name from the code.
73
- Looks for a function named 'fix' first; otherwise returns the first function found.
74
- """
75
- try:
76
- tree = ast.parse(code)
77
- first_func = None
78
- for node in ast.walk(tree):
79
- if isinstance(node, ast.FunctionDef):
80
- if node.name == "fix":
81
- return "fix"
82
- if first_func is None:
83
- first_func = node.name
84
- return first_func # fallback if no 'fix' function exists
85
- except SyntaxError:
86
- pass
87
- return None
88
-
89
- def _generate_test_script(self, fix_code: str, func_name: str) -> str:
90
- """Generate a test script that runs fixed + fuzzed test cases and outputs JSON."""
91
- test_cases = self._get_test_cases(func_name)
92
- fuzzed_cases = self._generate_fuzzed_cases(func_name)
93
- all_cases = test_cases + fuzzed_cases
94
-
95
- lines = []
96
- lines.append(fix_code)
97
- lines.append("")
98
- lines.append("import json")
99
- lines.append("")
100
- lines.append("def run_tests():")
101
- lines.append(f" test_cases = {json.dumps(all_cases)}")
102
- lines.append(" passed = 0")
103
- lines.append(" total = len(test_cases)")
104
- lines.append(" for args, expected in test_cases:")
105
- lines.append(f" try:")
106
- lines.append(f" result = {func_name}(*args) if isinstance(args, list) else {func_name}(args)")
107
- lines.append(f" if result == expected:")
108
- lines.append(f" passed += 1")
109
- lines.append(f" except Exception:")
110
- lines.append(f" pass")
111
- lines.append(" return {'passed': passed, 'total': total}")
112
- lines.append("")
113
- lines.append("if __name__ == '__main__':")
114
- lines.append(" result = run_tests()")
115
- lines.append(" print(json.dumps(result))")
116
- return "\n".join(lines)
117
-
118
- def _get_test_cases(self, func_name: str) -> List[Tuple[List[Any], Any]]:
119
- """
120
- Return a list of (arguments, expected_output) for the given bug_id.
121
- Uses the actual function name (dynamic) for consistency.
122
- """
123
- if self.bug_id == "null_check":
124
- return [
125
- ([{"users": {"alice": "Alice"}, "id": "bob"}], None), # missing key should not crash
126
- ([{"users": {"alice": "Alice"}, "id": "alice"}], "Alice"),
127
- ]
128
- elif self.bug_id == "off_by_one":
129
- return [
130
- ([[1,2,3,4]], 4),
131
- ([[]], 0),
132
- ]
133
- elif self.bug_id == "division_by_zero":
134
- return [
135
- ([[]], 0),
136
- ([[1,2,3]], 2.0),
137
- ]
138
- elif self.bug_id == "wrong_operator":
139
- return [
140
- ([5,3], 8),
141
- ([-1,1], 0),
142
- ]
143
- else:
144
- # For missing_lock, deadlock_order, etc., return empty list (will be handled gracefully)
145
- return []
146
-
147
- def _generate_fuzzed_cases(self, func_name: str) -> List[Tuple[List[Any], Any]]:
148
- """
149
- Generate random test cases to prevent memorisation.
150
- Only for bugs where meaningful fuzzing is possible.
151
- """
152
- cases = []
153
- if self.bug_id == "null_check":
154
- # Random users dictionary and random ids
155
- for _ in range(self.fuzz_rounds):
156
- users = {f"user_{i}": f"name_{i}" for i in range(random.randint(1, 5))}
157
- # Pick existing or missing key
158
- if random.random() > 0.5:
159
- key = random.choice(list(users.keys()))
160
- expected = users[key]
161
- else:
162
- key = "missing_" + str(random.randint(100, 999))
163
- expected = None
164
- cases.append(([{"users": users, "id": key}], expected))
165
- elif self.bug_id == "off_by_one":
166
- for _ in range(self.fuzz_rounds):
167
- length = random.randint(0, 10)
168
- arr = list(range(length))
169
- cases.append(([arr], length))
170
- elif self.bug_id == "division_by_zero":
171
- for _ in range(self.fuzz_rounds):
172
- length = random.randint(0, 10)
173
- data = [random.randint(-100, 100) for _ in range(length)]
174
- expected = sum(data)/length if length else 0
175
- cases.append(([data], expected))
176
- elif self.bug_id == "wrong_operator":
177
- for _ in range(self.fuzz_rounds):
178
- a = random.randint(-100, 100)
179
- b = random.randint(-100, 100)
180
- cases.append(([a, b], a + b))
181
- return cases
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # test_runner.py – Full production version with continuous scoring, dynamic function detection, and randomised tests
2
+ import subprocess
3
+ import tempfile
4
+ import os
5
+ import json
6
+ import ast
7
+ import random
8
+ import sys
9
+ from typing import Tuple, List, Any, Optional
10
+ from dataclasses import dataclass
11
+
12
+ # Bridge fine-grained RedTeam ids to canonical TestRunner families.
13
+ # This keeps evaluation stable even when bug generators use richer labels.
14
+ BUG_ID_CANONICAL_MAP = {
15
+ # Easy-family bugs on `get_user`-style behavior.
16
+ "simple_typo": "null_check",
17
+ "default_value": "null_check",
18
+ "empty_return": "null_check",
19
+
20
+ # Medium arithmetic/control-flow aliases.
21
+ "loop_skip": "off_by_one",
22
+ "sign_error": "wrong_operator",
23
+
24
+ # Hard numeric-safety aliases.
25
+ "division_by_zero_empty": "division_by_zero",
26
+ "division_by_zero_zero": "division_by_zero",
27
+ "float_precision": "division_by_zero",
28
+ "abs_usage": "division_by_zero",
29
+ "round_error": "division_by_zero",
30
+ }
31
+
32
+ @dataclass
33
+ class TestRunner:
34
+ bug_id: str
35
+ timeout_sec: int = 5
36
+ max_memory_mb: int = 256
37
+ fuzz_rounds: int = 3 # number of random test cases per bug
38
+
39
+ def run_tests(self, fix_code: str) -> Tuple[float, str]:
40
+ """
41
+ Returns (score, output_message) where score is proportion of passed tests (0.0–1.0).
42
+ """
43
+ # 1. Detect the function defined in the agent's code (dynamic)
44
+ func_name = self._get_defined_function_name(fix_code)
45
+ if not func_name:
46
+ return 0.0, "No function definition found in agent code"
47
+
48
+ # 2. Normalize bug id so broader RedTeam ids still hit meaningful tests.
49
+ canonical_bug_id = self._canonical_bug_id()
50
+
51
+ # 3. Generate the test script (includes fixed + fuzzed test cases)
52
+ test_script = self._generate_test_script(fix_code, func_name, canonical_bug_id)
53
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
54
+ f.write(test_script)
55
+ tmp_path = f.name
56
+
57
+ try:
58
+ # Resource limiting (Linux only; fallback otherwise)
59
+ try:
60
+ import resource
61
+ resource.setrlimit(resource.RLIMIT_AS, (self.max_memory_mb * 1024 * 1024, self.max_memory_mb * 1024 * 1024))
62
+ except Exception:
63
+ pass
64
+
65
+ result = subprocess.run(
66
+ [sys.executable, tmp_path],
67
+ capture_output=True,
68
+ text=True,
69
+ timeout=self.timeout_sec,
70
+ encoding='utf-8'
71
+ )
72
+ # Parse JSON output
73
+ try:
74
+ data = json.loads(result.stdout.strip())
75
+ passed = data.get("passed", 0)
76
+ total = data.get("total", 1)
77
+ score = passed / total if total > 0 else 0.0
78
+ return score, result.stdout.strip()
79
+ except json.JSONDecodeError:
80
+ # Fallback: look for "True" (legacy)
81
+ if "True" in result.stdout:
82
+ return 1.0, result.stdout
83
+ return 0.0, result.stdout
84
+ except subprocess.TimeoutExpired:
85
+ return 0.0, "Test execution timed out"
86
+ except Exception as e:
87
+ return 0.0, f"Unexpected error: {str(e)}"
88
+ finally:
89
+ try:
90
+ os.unlink(tmp_path)
91
+ except:
92
+ pass
93
+
94
+ def _get_defined_function_name(self, code: str) -> Optional[str]:
95
+ """Extract the target function name from the code.
96
+ Looks for a function named 'fix' first; otherwise returns the first function found.
97
+ """
98
+ try:
99
+ tree = ast.parse(code)
100
+ first_func = None
101
+ for node in ast.walk(tree):
102
+ if isinstance(node, ast.FunctionDef):
103
+ if node.name == "fix":
104
+ return "fix"
105
+ if first_func is None:
106
+ first_func = node.name
107
+ return first_func # fallback if no 'fix' function exists
108
+ except SyntaxError:
109
+ pass
110
+ return None
111
+
112
+ def _canonical_bug_id(self) -> str:
113
+ """Return canonical bug family used by this test harness."""
114
+ return BUG_ID_CANONICAL_MAP.get(self.bug_id, self.bug_id)
115
+
116
+ def _generate_test_script(self, fix_code: str, func_name: str, canonical_bug_id: str) -> str:
117
+ """Generate a test script that runs fixed + fuzzed test cases and outputs JSON."""
118
+ test_cases = self._get_test_cases(canonical_bug_id, func_name)
119
+ fuzzed_cases = self._generate_fuzzed_cases(canonical_bug_id, func_name)
120
+ all_cases = test_cases + fuzzed_cases
121
+
122
+ lines = []
123
+ lines.append(fix_code)
124
+ lines.append("")
125
+ lines.append("import json")
126
+ lines.append("")
127
+ lines.append("def run_tests():")
128
+ lines.append(f" test_cases = {json.dumps(all_cases)}")
129
+ lines.append(" passed = 0")
130
+ lines.append(" total = len(test_cases)")
131
+ lines.append(" for args, expected in test_cases:")
132
+ lines.append(f" try:")
133
+ lines.append(f" result = {func_name}(*args) if isinstance(args, list) else {func_name}(args)")
134
+ lines.append(f" if result == expected:")
135
+ lines.append(f" passed += 1")
136
+ lines.append(f" except Exception:")
137
+ lines.append(f" pass")
138
+ lines.append(" return {'passed': passed, 'total': total}")
139
+ lines.append("")
140
+ lines.append("if __name__ == '__main__':")
141
+ lines.append(" result = run_tests()")
142
+ lines.append(" print(json.dumps(result))")
143
+ return "\n".join(lines)
144
+
145
+ def _get_test_cases(self, canonical_bug_id: str, func_name: str) -> List[Tuple[List[Any], Any]]:
146
+ """
147
+ Return a list of (arguments, expected_output) for the given bug_id.
148
+ Uses the actual function name (dynamic) for consistency.
149
+ """
150
+ if canonical_bug_id == "null_check":
151
+ return [
152
+ ([{"users": {"alice": "Alice"}, "id": "bob"}], None), # missing key should not crash
153
+ ([{"users": {"alice": "Alice"}, "id": "alice"}], "Alice"),
154
+ ]
155
+ elif canonical_bug_id == "off_by_one":
156
+ return [
157
+ ([[1,2,3,4]], 4),
158
+ ([[]], 0),
159
+ ]
160
+ elif canonical_bug_id == "division_by_zero":
161
+ return [
162
+ ([[]], 0),
163
+ ([[1,2,3]], 2.0),
164
+ ]
165
+ elif canonical_bug_id == "wrong_operator":
166
+ return [
167
+ ([5,3], 8),
168
+ ([-1,1], 0),
169
+ ]
170
+ else:
171
+ # For missing_lock, deadlock_order, etc., return empty list (will be handled gracefully)
172
+ return []
173
+
174
+ def _generate_fuzzed_cases(self, canonical_bug_id: str, func_name: str) -> List[Tuple[List[Any], Any]]:
175
+ """
176
+ Generate random test cases to prevent memorisation.
177
+ Only for bugs where meaningful fuzzing is possible.
178
+ """
179
+ cases = []
180
+ if canonical_bug_id == "null_check":
181
+ # Random users dictionary and random ids
182
+ for _ in range(self.fuzz_rounds):
183
+ users = {f"user_{i}": f"name_{i}" for i in range(random.randint(1, 5))}
184
+ # Pick existing or missing key
185
+ if random.random() > 0.5:
186
+ key = random.choice(list(users.keys()))
187
+ expected = users[key]
188
+ else:
189
+ key = "missing_" + str(random.randint(100, 999))
190
+ expected = None
191
+ cases.append(([{"users": users, "id": key}], expected))
192
+ elif canonical_bug_id == "off_by_one":
193
+ for _ in range(self.fuzz_rounds):
194
+ length = random.randint(0, 10)
195
+ arr = list(range(length))
196
+ cases.append(([arr], length))
197
+ elif canonical_bug_id == "division_by_zero":
198
+ for _ in range(self.fuzz_rounds):
199
+ length = random.randint(0, 10)
200
+ data = [random.randint(-100, 100) for _ in range(length)]
201
+ expected = sum(data)/length if length else 0
202
+ cases.append(([data], expected))
203
+ elif canonical_bug_id == "wrong_operator":
204
+ for _ in range(self.fuzz_rounds):
205
+ a = random.randint(-100, 100)
206
+ b = random.randint(-100, 100)
207
+ cases.append(([a, b], a + b))
208
+ return cases
training.py CHANGED
@@ -1,708 +1,792 @@
1
- # training.py
2
- import json
3
- import torch
4
- import torch.nn.functional as F
5
- from torch.optim import AdamW
6
- from dataclasses import dataclass
7
- from typing import List, Dict, Tuple, Optional
8
- import numpy as np
9
- import re
10
- import random
11
-
12
- from unsloth import FastLanguageModel
13
- from transformers import TrainingArguments
14
- from trl import SFTTrainer
15
- from datasets import Dataset
16
-
17
- # Import your environment and actions (unchanged)
18
- from environment import CodeReviewEnv
19
- from models import (
20
- RunTests, RunLinter, Inspect,
21
- ProposeFix, WriteComment, AskQuestion,
22
- Done, Skip , QueryDocs
23
- )
24
-
25
- # ======================================================================
26
- # 1. ACTION PARSING (improved with fallback)
27
- # ======================================================================
28
- @dataclass
29
- class AgentAction:
30
- action_type: str
31
- content: Optional[str] = None
32
-
33
- def parse_action(output: str) -> AgentAction:
34
- """Robust JSON parsing with regex fallback and keyword detection."""
35
- # Try strict JSON first
36
- try:
37
- data = json.loads(output)
38
- return AgentAction(
39
- action_type=data.get("action_type", "").lower(),
40
- content=data.get("content")
41
- )
42
- except:
43
- pass
44
-
45
- # Try to extract JSON from markdown blocks
46
- json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', output, re.DOTALL)
47
- if json_match:
48
- try:
49
- data = json.loads(json_match.group(1))
50
- return AgentAction(
51
- action_type=data.get("action_type", "").lower(),
52
- content=data.get("content")
53
- )
54
- except:
55
- pass
56
-
57
- # Try to find "action_type" field with regex
58
- action_pattern = r'"action_type"\s*:\s*"(\w+)"'
59
- match = re.search(action_pattern, output)
60
- if match:
61
- return AgentAction(action_type=match.group(1).lower())
62
-
63
- # Keyword detection as last resort
64
- output_lower = output.lower()
65
- if "test" in output_lower:
66
- return AgentAction("run_tests")
67
- if "lint" in output_lower:
68
- return AgentAction("run_linter")
69
- if "inspect" in output_lower:
70
- return AgentAction("inspect")
71
-
72
- return AgentAction("invalid", output)
73
-
74
- def map_to_env(action: AgentAction):
75
- if action.action_type == "run_tests":
76
- return RunTests()
77
- elif action.action_type == "run_linter":
78
- return RunLinter()
79
- elif action.action_type == "inspect":
80
- return Inspect()
81
- elif action.action_type == "fix":
82
- return ProposeFix(fix_code=action.content or "")
83
- elif action.action_type == "comment":
84
- return WriteComment(comment_text=action.content or "")
85
- elif action.action_type == "question":
86
- return AskQuestion(question=action.content or "")
87
- elif action.action_type == "query_docs": # <-- new
88
- return QueryDocs(query_topic=action.content or "")
89
- elif action.action_type == "done":
90
- return Done()
91
- else:
92
- return Skip()
93
-
94
- # ======================================================================
95
- # 2. MODEL SETUP (stabilised LoRA)
96
- # ======================================================================
97
- def load_model():
98
- model, tokenizer = FastLanguageModel.from_pretrained(
99
- model_name="unsloth/gemma-2-2b-it-bnb-4bit",
100
- max_seq_length=2048,
101
- load_in_4bit=True,
102
- )
103
- # FIXED: Lower rank (16), dropout=0 for stability
104
- model = FastLanguageModel.get_peft_model(
105
- model,
106
- r=16, # was 64 → causes collapse
107
- target_modules=[
108
- "q_proj", "k_proj", "v_proj", "o_proj",
109
- "gate_proj", "up_proj", "down_proj"
110
- ],
111
- lora_alpha=32, # adjusted for r=16
112
- lora_dropout=0.0, # dropout can cause empty outputs
113
- )
114
- # Ensure tokenizer has correct chat template for Gemma-2
115
- if tokenizer.chat_template is None:
116
- tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}<start_of_turn>user\n{{ message['content'] }}<end_of_turn>\n<start_of_turn>model\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}<end_of_turn>\n{% endif %}{% endfor %}"
117
- return model, tokenizer
118
-
119
- # ======================================================================
120
- # 3. MODEL SANITY CHECK (new ensures model can generate text)
121
- # ======================================================================
122
- def test_model_sanity(model, tokenizer) -> bool:
123
- print("\n" + "="*60)
124
- print("SANITY CHECK: Testing base model generation")
125
- print("="*60)
126
- test_prompt = "Hello, how are you?"
127
- messages = [{"role": "user", "content": test_prompt}]
128
- formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
129
- inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
130
- with torch.no_grad():
131
- outputs = model.generate(
132
- **inputs,
133
- max_new_tokens=30,
134
- do_sample=True,
135
- temperature=0.7,
136
- min_new_tokens=1,
137
- eos_token_id=tokenizer.eos_token_id,
138
- pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
139
- )
140
- generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
141
- response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
142
- print(f"Prompt: {test_prompt}")
143
- print(f"Response: {repr(response)}")
144
- if len(response) == 0:
145
- print("❌ Model produces empty output – cannot train.")
146
- return False
147
- print("✓ Model sanity check PASSED\n")
148
- return True
149
-
150
- # ======================================================================
151
- # 4. SUPERVISED WARM-UP (teaches JSON output)
152
- # ======================================================================
153
- def supervised_warmup(model, tokenizer, n_examples=500, epochs=2):
154
- print("\n" + "="*60)
155
- print("SUPERVISED WARM-UP: Teaching JSON format")
156
- print("="*60)
157
-
158
- examples = []
159
- action_templates = [
160
- '{"action_type": "run_tests"}',
161
- '{"action_type": "run_linter"}',
162
- '{"action_type": "inspect"}',
163
- '{"action_type": "fix", "content": "def corrected():\n pass"}',
164
- '{"action_type": "comment", "content": "This looks good."}',
165
- '{"action_type": "question", "content": "Why is this variable used?"}',
166
- '{"action_type": "done"}',
167
- ]
168
-
169
- for i in range(n_examples):
170
- code = f"def example_{i}():\n return {i % 10}"
171
- last_outputs = [
172
- "Tests passed: 2/3",
173
- "Linter found 1 error",
174
- "Inspection complete",
175
- "No previous action",
176
- ]
177
- last_output = random.choice(last_outputs)
178
- # Use same prompt structure as build_prompt
179
- prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
180
-
181
- The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
182
- - Tests pass (high pass ratio)
183
- - Lint is clean (zero errors)
184
- - Documentation or references are provided
185
- - Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
186
-
187
- Workflow:
188
- 1. Use `inspect` to understand the code.
189
- 2. Use `run_tests` and `run_linter` to gather evidence.
190
- 3. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
191
- 4. If the developer pushes back, read their response carefully and address their specific concern.
192
- 5. Once convinced, use `done` to finish.
193
-
194
- Code:
195
- {obs.code_snippet}
196
-
197
- Author says:
198
- {author_msg if author_msg else "(no response yet start with inspection)"}
199
-
200
- Last tool output:
201
- {tool_output if tool_output else "(none)"}
202
-
203
- Available actions:
204
- run_tests, run_linter, inspect, fix, comment, question, done
205
-
206
- Respond ONLY in JSON:
207
- {{"action_type": "...", "content": "..."}}"""
208
-
209
- action_json = random.choice(action_templates)
210
- messages = [
211
- {"role": "user", "content": prompt},
212
- {"role": "assistant", "content": action_json}
213
- ]
214
- full_text = tokenizer.apply_chat_template(messages, tokenize=False)
215
- examples.append({"text": full_text})
216
-
217
- dataset = Dataset.from_list(examples)
218
- trainer = SFTTrainer(
219
- model=model,
220
- tokenizer=tokenizer,
221
- train_dataset=dataset,
222
- dataset_text_field="text",
223
- max_seq_length=512,
224
- args=TrainingArguments(
225
- output_dir="warmup_output",
226
- num_train_epochs=epochs,
227
- per_device_train_batch_size=4,
228
- gradient_accumulation_steps=2,
229
- learning_rate=2e-5,
230
- logging_steps=50,
231
- save_strategy="no",
232
- fp16=True,
233
- ),
234
- )
235
- print(f"Training on {n_examples} examples for {epochs} epochs...")
236
- trainer.train()
237
- print("✓ Warm-up complete\n")
238
-
239
- # ======================================================================
240
- # 5. ACTION GENERATION WITH LOGPROB TRACKING (fixed)
241
- # ======================================================================
242
- def generate_action_with_logprob(
243
- prompt: str,
244
- model,
245
- tokenizer,
246
- temperature: float = 0.0, # changed: greedy by default for stability
247
- max_retries: int = 2
248
- ) -> Tuple[str, float]:
249
- """Generate action using correct chat template, with fallback."""
250
- messages = [{"role": "user", "content": prompt}]
251
- formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
252
- inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
253
-
254
- for attempt in range(max_retries):
255
- with torch.no_grad():
256
- outputs = model.generate(
257
- **inputs,
258
- max_new_tokens=128,
259
- do_sample=(temperature > 0),
260
- temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
261
- min_new_tokens=1,
262
- return_dict_in_generate=True,
263
- output_scores=True,
264
- )
265
-
266
- generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
267
- action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
268
-
269
- # Compute logprob
270
- logprobs = []
271
- for idx, token_id in enumerate(generated_ids):
272
- if idx < len(outputs.scores):
273
- token_logits = outputs.scores[idx][0]
274
- token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
275
- logprobs.append(token_logprob)
276
- total_logprob = sum(logprobs) if logprobs else -100.0
277
-
278
- # If empty, use fallback
279
- if not action_text:
280
- fallback_actions = [
281
- '{"action_type": "run_tests"}',
282
- '{"action_type": "run_linter"}',
283
- '{"action_type": "inspect"}',
284
- '{"action_type": "skip"}',
285
- ]
286
- action_text = random.choice(fallback_actions)
287
- total_logprob = -50.0
288
- print(f"[WARN] Empty generation → using fallback: {action_text}")
289
- return action_text, total_logprob
290
-
291
- # Validate JSON
292
- try:
293
- json.loads(action_text)
294
- return action_text, total_logprob
295
- except:
296
- if attempt == max_retries - 1:
297
- return '{"action_type":"skip"}', -100.0
298
- continue
299
-
300
- return '{"action_type":"skip"}', -100.0
301
-
302
- # ======================================================================
303
- # 6. PROMPT BUILDER (unchanged – exactly as you wrote)
304
- # ======================================================================
305
- def build_prompt(obs, history_lines: List[str]) -> str:
306
- author_msg = getattr(obs, "author_response", "") or ""
307
- tool_output = getattr(obs, "last_tool_output", "") or ""
308
-
309
- # Personality hint (optional but helpful)
310
- author_personality = getattr(obs, "author_personality", "defensive") # e.g., from env
311
-
312
- prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
313
-
314
- The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
315
- - Tests pass (high pass ratio)
316
- - Lint is clean (zero errors)
317
- - Documentation or references are provided
318
- - Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
319
-
320
- Workflow:
321
- 1. Use `inspect` to understand the code.
322
- 2. Use `run_tests` and `run_linter` to gather evidence.
323
- 3. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
324
- 4. If the developer pushes back, read their response carefully and address their specific concern.
325
- 5. Once convinced, use `done` to finish.
326
-
327
- Code:
328
- {obs.code_snippet}
329
-
330
- Author says:
331
- {author_msg if author_msg else "(no response yet start with inspection)"}
332
-
333
- Last tool output:
334
- {tool_output if tool_output else "(none)"}
335
-
336
- Available actions:
337
- run_tests, run_linter, inspect, fix, comment, question, done
338
-
339
- Respond ONLY in JSON:
340
- {{"action_type": "...", "content": "..."}}"""
341
-
342
- if history_lines:
343
- history = "\n".join(history_lines[-6:])
344
- prompt += f"\n\nPrevious steps:\n{history}"
345
- return prompt
346
-
347
- # ======================================================================
348
- # 7. TRAJECTORY STORAGE (unchanged)
349
- # ======================================================================
350
- @dataclass
351
- class Trajectory:
352
- states: List[str]
353
- actions: List[str]
354
- rewards: List[float]
355
- logprobs: List[float]
356
- dones: List[bool]
357
-
358
- def __len__(self):
359
- return len(self.states)
360
-
361
- def to_dict(self):
362
- return {
363
- "states": self.states,
364
- "actions": self.actions,
365
- "rewards": self.rewards,
366
- "logprobs": self.logprobs,
367
- "dones": self.dones,
368
- }
369
-
370
- # ======================================================================
371
- # 8. ROLLOUT COLLECTION (uses fixed generate)
372
- # ======================================================================
373
- def collect_trajectory(
374
- env: CodeReviewEnv,
375
- model,
376
- tokenizer,
377
- max_steps: int = 10,
378
- temperature: float = 0.0 # changed to greedy
379
- ) -> Trajectory:
380
- obs = env.reset()
381
- history_lines = []
382
-
383
- states = []
384
- actions = []
385
- rewards = []
386
- logprobs = []
387
- dones = []
388
-
389
- for step in range(max_steps):
390
- prompt = build_prompt(obs, history_lines)
391
- states.append(prompt)
392
-
393
- action_text, logprob = generate_action_with_logprob(
394
- prompt, model, tokenizer, temperature
395
- )
396
- actions.append(action_text)
397
- logprobs.append(logprob)
398
-
399
- action = parse_action(action_text)
400
- env_action = map_to_env(action)
401
- next_obs, reward, done, _ = env.step(env_action)
402
-
403
- rewards.append(reward.value)
404
- dones.append(done)
405
-
406
- history_lines.append(f"Agent: {action_text}")
407
- history_lines.append(f"Env: {next_obs.last_tool_output}")
408
-
409
- obs = next_obs
410
- if done:
411
- break
412
-
413
- return Trajectory(states, actions, rewards, logprobs, dones)
414
-
415
- def collect_trajectories(
416
- env: CodeReviewEnv,
417
- model,
418
- tokenizer,
419
- n_trajectories: int,
420
- max_steps: int = 10
421
- ) -> List[Trajectory]:
422
- trajectories = []
423
- for i in range(n_trajectories):
424
- traj = collect_trajectory(env, model, tokenizer, max_steps)
425
- total_reward = sum(traj.rewards)
426
- print(f"Trajectory {i+1}/{n_trajectories}: "
427
- f"steps={len(traj)}, reward={total_reward:.3f}")
428
- trajectories.append(traj)
429
- return trajectories
430
-
431
- # ======================================================================
432
- # 9. ADVANTAGE ESTIMATION (unchanged)
433
- # ======================================================================
434
- def compute_returns_and_advantages(
435
- rewards: List[float],
436
- dones: List[bool],
437
- gamma: float = 0.99,
438
- standardize: bool = True
439
- ) -> Tuple[List[float], List[float]]:
440
- """
441
- Computes discounted returns and normalised advantages (no critic).
442
- Advantages = returns - mean(returns) (or zero baseline).
443
- """
444
- n = len(rewards)
445
- returns = [0.0] * n
446
- running_return = 0.0
447
- for t in reversed(range(n)):
448
- if dones[t]:
449
- running_return = 0.0
450
- running_return = rewards[t] + gamma * running_return
451
- returns[t] = running_return
452
-
453
- if standardize:
454
- advantages = np.array(returns) - np.mean(returns)
455
- adv_std = np.std(advantages) + 1e-8
456
- advantages = (advantages / adv_std).tolist()
457
- else:
458
- advantages = returns.copy()
459
-
460
- return advantages, returns
461
- # ======================================================================
462
- # 10. COMPUTE NEW LOGPROBS (unchanged)
463
- # ======================================================================
464
- def compute_logprob(prompt: str, action: str, model, tokenizer) -> float:
465
- messages = [{"role": "user", "content": prompt}]
466
- formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
467
- full_text = formatted + action
468
- inputs = tokenizer(full_text, return_tensors="pt").to("cuda")
469
-
470
- with torch.no_grad():
471
- outputs = model(**inputs)
472
- logits = outputs.logits
473
-
474
- action_ids = tokenizer.encode(action, add_special_tokens=False)
475
- prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
476
- action_start = len(prefix_ids)
477
-
478
- logprobs = []
479
- for idx, token_id in enumerate(action_ids):
480
- position = action_start + idx - 1
481
- if 0 <= position < logits.shape[1]:
482
- token_logits = logits[0, position]
483
- token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
484
- logprobs.append(token_logprob)
485
- return sum(logprobs) if logprobs else -100.0
486
-
487
- # ======================================================================
488
- # 11. PPO UPDATE (unchanged except uses compute_logprob correctly)
489
- # ======================================================================
490
- def ppo_update(
491
- trajectories: List[Trajectory],
492
- model,
493
- tokenizer,
494
- optimizer,
495
- n_epochs: int = 4,
496
- clip_epsilon: float = 0.2,
497
- entropy_coef: float = 0.01,
498
- gamma: float = 0.99,
499
- ) -> Dict[str, float]:
500
- model.train()
501
-
502
- all_states = []
503
- all_actions = []
504
- all_old_logprobs = []
505
- all_advantages = []
506
- all_returns = []
507
-
508
- for traj in trajectories:
509
- advantages, returns = compute_returns_and_advantages(
510
- traj.rewards, traj.dones, gamma=gamma, standardize=True
511
- )
512
- all_states.extend(traj.states)
513
- all_actions.extend(traj.actions)
514
- all_old_logprobs.extend(traj.logprobs)
515
- all_advantages.extend(advantages)
516
- all_returns.extend(returns)
517
-
518
- n_samples = len(all_states)
519
- total_loss = 0.0
520
- total_policy_loss = 0.0
521
- total_entropy = 0.0
522
- n_updates = 0
523
-
524
- for epoch in range(n_epochs):
525
- indices = np.random.permutation(n_samples)
526
- for i in indices:
527
- state = all_states[i]
528
- action = all_actions[i]
529
- old_logprob = all_old_logprobs[i]
530
- advantage = all_advantages[i]
531
-
532
- # Use the same chat template for PPO update
533
- messages = [{"role": "user", "content": state}]
534
- formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
535
- full_text = formatted + action
536
- inputs = tokenizer(full_text, return_tensors="pt").to("cuda")
537
-
538
- outputs = model(**inputs)
539
- logits = outputs.logits
540
-
541
- action_ids = tokenizer.encode(action, add_special_tokens=False)
542
- prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
543
- action_start = len(prefix_ids)
544
-
545
- logprobs = []
546
- entropy = 0.0
547
- for idx, token_id in enumerate(action_ids):
548
- position = action_start + idx - 1
549
- if 0 <= position < logits.shape[1]:
550
- token_logits = logits[0, position]
551
- log_probs = F.log_softmax(token_logits, dim=-1)
552
- token_logprob = log_probs[token_id]
553
- logprobs.append(token_logprob)
554
-
555
- probs = F.softmax(token_logits, dim=-1)
556
- entropy += -(probs * log_probs).sum()
557
-
558
- if not logprobs:
559
- continue
560
-
561
- new_logprob = sum(logprobs)
562
- avg_entropy = entropy / len(logprobs) if logprobs else 0.0
563
-
564
- ratio = torch.exp(new_logprob - old_logprob)
565
- surr1 = ratio * advantage
566
- surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
567
- policy_loss = -torch.min(surr1, surr2)
568
- loss = policy_loss - entropy_coef * avg_entropy
569
-
570
- optimizer.zero_grad()
571
- loss.backward()
572
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
573
- optimizer.step()
574
-
575
- total_loss += loss.item()
576
- total_policy_loss += policy_loss.item()
577
- total_entropy += avg_entropy.item()
578
- n_updates += 1
579
-
580
- return {
581
- "loss": total_loss / n_updates if n_updates > 0 else 0.0,
582
- "policy_loss": total_policy_loss / n_updates if n_updates > 0 else 0.0,
583
- "entropy": total_entropy / n_updates if n_updates > 0 else 0.0,
584
- }
585
- # ======================================================================
586
- # 12. EVALUATION (unchanged)
587
- # ======================================================================
588
- def evaluate_policy(
589
- env: CodeReviewEnv,
590
- model,
591
- tokenizer,
592
- n_episodes: int = 10,
593
- max_steps: int = 10
594
- ) -> Dict[str, float]:
595
- model.eval()
596
- total_rewards = []
597
- episode_lengths = []
598
- success_count = 0
599
-
600
- for _ in range(n_episodes):
601
- traj = collect_trajectory(env, model, tokenizer, max_steps, temperature=0.0)
602
- total_reward = sum(traj.rewards)
603
- total_rewards.append(total_reward)
604
- episode_lengths.append(len(traj))
605
- if total_reward > 0.5:
606
- success_count += 1
607
-
608
- return {
609
- "avg_reward": np.mean(total_rewards),
610
- "std_reward": np.std(total_rewards),
611
- "avg_length": np.mean(episode_lengths),
612
- "success_rate": success_count / n_episodes,
613
- }
614
-
615
- # ======================================================================
616
- # 13. MAIN TRAINING LOOP (added sanity check and warm-up)
617
- # ======================================================================
618
- def train_ppo(
619
- n_iterations: int = 50,
620
- trajectories_per_iter: int = 10,
621
- n_epochs: int = 4,
622
- max_steps: int = 10,
623
- learning_rate: float = 3e-5,
624
- clip_epsilon: float = 0.2,
625
- entropy_coef: float = 0.01,
626
- gamma: float = 0.99,
627
- eval_every: int = 5,
628
- ):
629
- print("Loading model...")
630
- model, tokenizer = load_model()
631
-
632
- # NEW: Sanity check before any training
633
- if not test_model_sanity(model, tokenizer):
634
- print("\n❌ Model sanity check failed – cannot proceed.")
635
- return
636
-
637
- # NEW: Supervised warm-up to teach JSON format
638
- supervised_warmup(model, tokenizer, n_examples=500, epochs=2)
639
-
640
- optimizer = AdamW(model.parameters(), lr=learning_rate)
641
- env = CodeReviewEnv()
642
-
643
- print(f"\n{'='*60}")
644
- print(f"Starting PPO Training")
645
- print(f"Iterations: {n_iterations}")
646
- print(f"Trajectories per iteration: {trajectories_per_iter}")
647
- print(f"PPO epochs: {n_epochs}")
648
- print(f"{'='*60}\n")
649
-
650
- for iteration in range(n_iterations):
651
- print(f"\n--- Iteration {iteration + 1}/{n_iterations} ---")
652
-
653
- print("Collecting trajectories...")
654
- trajectories = collect_trajectories(
655
- env, model, tokenizer, trajectories_per_iter, max_steps
656
- )
657
-
658
- avg_reward = np.mean([sum(t.rewards) for t in trajectories])
659
- avg_length = np.mean([len(t) for t in trajectories])
660
-
661
- print(f"Avg reward: {avg_reward:.3f}")
662
- print(f"Avg length: {avg_length:.1f}")
663
-
664
- print("Updating policy...")
665
- metrics = ppo_update(
666
- trajectories,
667
- model,
668
- tokenizer,
669
- optimizer,
670
- n_epochs=n_epochs,
671
- clip_epsilon=clip_epsilon,
672
- entropy_coef=entropy_coef,
673
- gamma=gamma,
674
- )
675
-
676
- print(f"Loss: {metrics['loss']:.4f}")
677
- print(f"Policy loss: {metrics['policy_loss']:.4f}")
678
- print(f"Entropy: {metrics['entropy']:.4f}")
679
-
680
- if (iteration + 1) % eval_every == 0:
681
- print("\nEvaluating policy...")
682
- eval_metrics = evaluate_policy(env, model, tokenizer, n_episodes=10)
683
- print(f"Eval avg reward: {eval_metrics['avg_reward']:.3f} ± {eval_metrics['std_reward']:.3f}")
684
- print(f"Eval success rate: {eval_metrics['success_rate']:.2%}")
685
- print(f"Eval avg length: {eval_metrics['avg_length']:.1f}")
686
-
687
- print("\n" + "="*60)
688
- print("Training complete. Saving model...")
689
- model.save_pretrained("ppo_final_model")
690
- tokenizer.save_pretrained("ppo_final_model")
691
- print("Model saved to ppo_final_model/")
692
- print("="*60)
693
-
694
- # ======================================================================
695
- # 14. ENTRY POINT (unchanged)
696
- # ======================================================================
697
- if __name__ == "__main__":
698
- train_ppo(
699
- n_iterations=50,
700
- trajectories_per_iter=10,
701
- n_epochs=4,
702
- max_steps=10,
703
- learning_rate=3e-5,
704
- clip_epsilon=0.2,
705
- entropy_coef=0.01,
706
- gamma=0.99,
707
- eval_every=5,
708
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # training.py
2
+ import json
3
+ import os
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch.optim import AdamW
7
+ from dataclasses import dataclass
8
+ from typing import List, Dict, Tuple, Optional
9
+ import numpy as np
10
+ import re
11
+ import random
12
+ import matplotlib.pyplot as plt
13
+
14
+ from unsloth import FastLanguageModel
15
+ from transformers import TrainingArguments
16
+ from trl import SFTTrainer
17
+ from datasets import Dataset
18
+
19
+ # Import your environment and actions (unchanged)
20
+ from environment import CodeReviewEnv
21
+ from redteam import BUG_DB
22
+ from models import (
23
+ RunTests, RunLinter, Inspect,
24
+ ProposeFix, WriteComment, AskQuestion,
25
+ Done, Skip , QueryDocs
26
+ )
27
+
28
+ # ======================================================================
29
+ # 1. ACTION PARSING (improved with fallback)
30
+ # ======================================================================
31
+ @dataclass
32
+ class AgentAction:
33
+ action_type: str
34
+ content: Optional[str] = None
35
+
36
+ def parse_action(output: str) -> AgentAction:
37
+ """Robust JSON parsing with regex fallback and keyword detection."""
38
+ # Try strict JSON first
39
+ try:
40
+ data = json.loads(output)
41
+ return AgentAction(
42
+ action_type=data.get("action_type", "").lower(),
43
+ content=data.get("content")
44
+ )
45
+ except:
46
+ pass
47
+
48
+ # Try to extract JSON from markdown blocks
49
+ json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', output, re.DOTALL)
50
+ if json_match:
51
+ try:
52
+ data = json.loads(json_match.group(1))
53
+ return AgentAction(
54
+ action_type=data.get("action_type", "").lower(),
55
+ content=data.get("content")
56
+ )
57
+ except:
58
+ pass
59
+
60
+ # Try to find "action_type" field with regex
61
+ action_pattern = r'"action_type"\s*:\s*"(\w+)"'
62
+ match = re.search(action_pattern, output)
63
+ if match:
64
+ return AgentAction(action_type=match.group(1).lower())
65
+
66
+ # Keyword detection as last resort
67
+ output_lower = output.lower()
68
+ if "test" in output_lower:
69
+ return AgentAction("run_tests")
70
+ if "lint" in output_lower:
71
+ return AgentAction("run_linter")
72
+ if "inspect" in output_lower:
73
+ return AgentAction("inspect")
74
+ if "doc" in output_lower or "documentation" in output_lower:
75
+ # Bridge natural language mentions to rltool-backed retrieval action.
76
+ return AgentAction("query_docs", "bug fix guidance")
77
+
78
+ return AgentAction("invalid", output)
79
+
80
+ def map_to_env(action: AgentAction):
81
+ if action.action_type == "run_tests":
82
+ return RunTests()
83
+ elif action.action_type == "run_linter":
84
+ return RunLinter()
85
+ elif action.action_type == "inspect":
86
+ return Inspect()
87
+ elif action.action_type == "fix":
88
+ return ProposeFix(fix_code=action.content or "")
89
+ elif action.action_type == "comment":
90
+ return WriteComment(comment_text=action.content or "")
91
+ elif action.action_type == "question":
92
+ return AskQuestion(question=action.content or "")
93
+ elif action.action_type == "query_docs": # <-- new
94
+ return QueryDocs(query_topic=action.content or "")
95
+ elif action.action_type == "done":
96
+ return Done()
97
+ else:
98
+ return Skip()
99
+
100
+ # ======================================================================
101
+ # 2. MODEL SETUP (stabilised LoRA)
102
+ # ======================================================================
103
+ def load_model():
104
+ model, tokenizer = FastLanguageModel.from_pretrained(
105
+ model_name="unsloth/gemma-2-2b-it-bnb-4bit",
106
+ max_seq_length=2048,
107
+ load_in_4bit=True,
108
+ )
109
+ # FIXED: Lower rank (16), dropout=0 for stability
110
+ model = FastLanguageModel.get_peft_model(
111
+ model,
112
+ r=16, # was 64 causes collapse
113
+ target_modules=[
114
+ "q_proj", "k_proj", "v_proj", "o_proj",
115
+ "gate_proj", "up_proj", "down_proj"
116
+ ],
117
+ lora_alpha=32, # adjusted for r=16
118
+ lora_dropout=0.0, # dropout can cause empty outputs
119
+ )
120
+ # Ensure tokenizer has correct chat template for Gemma-2
121
+ if tokenizer.chat_template is None:
122
+ tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}<start_of_turn>user\n{{ message['content'] }}<end_of_turn>\n<start_of_turn>model\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}<end_of_turn>\n{% endif %}{% endfor %}"
123
+ return model, tokenizer
124
+
125
+ # ======================================================================
126
+ # 3. MODEL SANITY CHECK (new – ensures model can generate text)
127
+ # ======================================================================
128
+ def test_model_sanity(model, tokenizer) -> bool:
129
+ print("\n" + "="*60)
130
+ print("SANITY CHECK: Testing base model generation")
131
+ print("="*60)
132
+ test_prompt = "Hello, how are you?"
133
+ messages = [{"role": "user", "content": test_prompt}]
134
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
135
+ inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
136
+ with torch.no_grad():
137
+ outputs = model.generate(
138
+ **inputs,
139
+ max_new_tokens=30,
140
+ do_sample=True,
141
+ temperature=0.7,
142
+ min_new_tokens=1,
143
+ eos_token_id=tokenizer.eos_token_id,
144
+ pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
145
+ )
146
+ generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
147
+ response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
148
+ print(f"Prompt: {test_prompt}")
149
+ print(f"Response: {repr(response)}")
150
+ if len(response) == 0:
151
+ print("❌ Model produces empty output cannot train.")
152
+ return False
153
+ print("✓ Model sanity check PASSED\n")
154
+ return True
155
+
156
+ # ======================================================================
157
+ # 4. SUPERVISED WARM-UP (teaches JSON output)
158
+ # ======================================================================
159
+ def supervised_warmup(model, tokenizer, n_examples=500, epochs=2):
160
+ print("\n" + "="*60)
161
+ print("SUPERVISED WARM-UP: Teaching JSON format")
162
+ print("="*60)
163
+
164
+ examples = []
165
+ action_templates = [
166
+ '{"action_type": "run_tests"}',
167
+ '{"action_type": "run_linter"}',
168
+ '{"action_type": "inspect"}',
169
+ '{"action_type": "query_docs", "content": "python keyerror handling"}',
170
+ '{"action_type": "fix", "content": "def corrected():\n pass"}',
171
+ '{"action_type": "comment", "content": "This looks good."}',
172
+ '{"action_type": "question", "content": "Why is this variable used?"}',
173
+ '{"action_type": "done"}',
174
+ ]
175
+
176
+ for i in range(n_examples):
177
+ code = f"def example_{i}():\n return {i % 10}"
178
+ last_outputs = [
179
+ "Tests passed: 2/3",
180
+ "Linter found 1 error",
181
+ "Inspection complete",
182
+ "No previous action",
183
+ ]
184
+ last_output = random.choice(last_outputs)
185
+ # Use same prompt structure as build_prompt
186
+ prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
187
+
188
+ The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
189
+ - Tests pass (high pass ratio)
190
+ - Lint is clean (zero errors)
191
+ - Documentation or references are provided
192
+ - Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
193
+
194
+ Workflow:
195
+ 1. Use `inspect` to understand the code.
196
+ 2. Use `run_tests` and `run_linter` to gather evidence.
197
+ 3. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
198
+ 4. If the developer pushes back, read their response carefully and address their specific concern.
199
+ 5. Once convinced, use `done` to finish.
200
+
201
+ Code:
202
+ {obs.code_snippet}
203
+
204
+ Author says:
205
+ {author_msg if author_msg else "(no response yet – start with inspection)"}
206
+
207
+ Last tool output:
208
+ {tool_output if tool_output else "(none)"}
209
+
210
+ Available actions:
211
+ run_tests, run_linter, inspect, query_docs, fix, comment, question, done
212
+
213
+ Respond ONLY in JSON:
214
+ {{"action_type": "...", "content": "..."}}"""
215
+
216
+ action_json = random.choice(action_templates)
217
+ messages = [
218
+ {"role": "user", "content": prompt},
219
+ {"role": "assistant", "content": action_json}
220
+ ]
221
+ full_text = tokenizer.apply_chat_template(messages, tokenize=False)
222
+ examples.append({"text": full_text})
223
+
224
+ dataset = Dataset.from_list(examples)
225
+ trainer = SFTTrainer(
226
+ model=model,
227
+ tokenizer=tokenizer,
228
+ train_dataset=dataset,
229
+ dataset_text_field="text",
230
+ max_seq_length=512,
231
+ args=TrainingArguments(
232
+ output_dir="warmup_output",
233
+ num_train_epochs=epochs,
234
+ per_device_train_batch_size=4,
235
+ gradient_accumulation_steps=2,
236
+ learning_rate=2e-5,
237
+ logging_steps=50,
238
+ save_strategy="no",
239
+ fp16=True,
240
+ ),
241
+ )
242
+ print(f"Training on {n_examples} examples for {epochs} epochs...")
243
+ trainer.train()
244
+ print("✓ Warm-up complete\n")
245
+
246
+ # ======================================================================
247
+ # 5. ACTION GENERATION WITH LOGPROB TRACKING (fixed)
248
+ # ======================================================================
249
+ def generate_action_with_logprob(
250
+ prompt: str,
251
+ model,
252
+ tokenizer,
253
+ temperature: float = 0.0, # changed: greedy by default for stability
254
+ max_retries: int = 2
255
+ ) -> Tuple[str, float]:
256
+ """Generate action using correct chat template, with fallback."""
257
+ messages = [{"role": "user", "content": prompt}]
258
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
259
+ inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
260
+
261
+ for attempt in range(max_retries):
262
+ with torch.no_grad():
263
+ outputs = model.generate(
264
+ **inputs,
265
+ max_new_tokens=128,
266
+ do_sample=(temperature > 0),
267
+ temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
268
+ min_new_tokens=1,
269
+ return_dict_in_generate=True,
270
+ output_scores=True,
271
+ )
272
+
273
+ generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
274
+ action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
275
+
276
+ # Compute logprob
277
+ logprobs = []
278
+ for idx, token_id in enumerate(generated_ids):
279
+ if idx < len(outputs.scores):
280
+ token_logits = outputs.scores[idx][0]
281
+ token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
282
+ logprobs.append(token_logprob)
283
+ total_logprob = sum(logprobs) if logprobs else -100.0
284
+
285
+ # If empty, use fallback
286
+ if not action_text:
287
+ fallback_actions = [
288
+ '{"action_type": "run_tests"}',
289
+ '{"action_type": "run_linter"}',
290
+ '{"action_type": "inspect"}',
291
+ '{"action_type": "skip"}',
292
+ ]
293
+ action_text = random.choice(fallback_actions)
294
+ total_logprob = -50.0
295
+ print(f"[WARN] Empty generation → using fallback: {action_text}")
296
+ return action_text, total_logprob
297
+
298
+ # Validate JSON
299
+ try:
300
+ json.loads(action_text)
301
+ return action_text, total_logprob
302
+ except:
303
+ if attempt == max_retries - 1:
304
+ return '{"action_type":"skip"}', -100.0
305
+ continue
306
+
307
+ return '{"action_type":"skip"}', -100.0
308
+
309
+ # ======================================================================
310
+ # 6. PROMPT BUILDER (unchanged exactly as you wrote)
311
+ # ======================================================================
312
+ def build_prompt(obs, history_lines: List[str]) -> str:
313
+ author_msg = getattr(obs, "author_response", "") or ""
314
+ tool_output = getattr(obs, "last_tool_output", "") or ""
315
+
316
+ # Personality hint (optional but helpful)
317
+ author_personality = getattr(obs, "author_personality", "defensive") # e.g., from env
318
+
319
+ prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
320
+
321
+ The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
322
+ - Tests pass (high pass ratio)
323
+ - Lint is clean (zero errors)
324
+ - Documentation or references are provided
325
+ - Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
326
+
327
+ Workflow:
328
+ 1. Use `inspect` to understand the code.
329
+ 2. Use `run_tests` and `run_linter` to gather evidence.
330
+ 3. Use `query_docs` when you need references or language-specific guidance.
331
+ 4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
332
+ 5. If the developer pushes back, read their response carefully and address their specific concern.
333
+ 6. Once convinced, use `done` to finish.
334
+
335
+ Code:
336
+ {obs.code_snippet}
337
+
338
+ Author says:
339
+ {author_msg if author_msg else "(no response yet – start with inspection)"}
340
+
341
+ Last tool output:
342
+ {tool_output if tool_output else "(none)"}
343
+
344
+ Available actions:
345
+ run_tests, run_linter, inspect, query_docs, fix, comment, question, done
346
+
347
+ Respond ONLY in JSON:
348
+ {{"action_type": "...", "content": "..."}}"""
349
+
350
+ if history_lines:
351
+ history = "\n".join(history_lines[-6:])
352
+ prompt += f"\n\nPrevious steps:\n{history}"
353
+ return prompt
354
+
355
+ # ======================================================================
356
+ # 7. TRAJECTORY STORAGE (unchanged)
357
+ # ======================================================================
358
+ @dataclass
359
+ class Trajectory:
360
+ states: List[str]
361
+ actions: List[str]
362
+ rewards: List[float]
363
+ logprobs: List[float]
364
+ dones: List[bool]
365
+
366
+ def __len__(self):
367
+ return len(self.states)
368
+
369
+ def to_dict(self):
370
+ return {
371
+ "states": self.states,
372
+ "actions": self.actions,
373
+ "rewards": self.rewards,
374
+ "logprobs": self.logprobs,
375
+ "dones": self.dones,
376
+ }
377
+
378
+ # ======================================================================
379
+ # 8. ROLLOUT COLLECTION (uses fixed generate)
380
+ # ======================================================================
381
+ def collect_trajectory(
382
+ env: CodeReviewEnv,
383
+ model,
384
+ tokenizer,
385
+ max_steps: int = 10,
386
+ temperature: float = 0.0 # changed to greedy
387
+ ) -> Trajectory:
388
+ obs = env.reset()
389
+ history_lines = []
390
+
391
+ states = []
392
+ actions = []
393
+ rewards = []
394
+ logprobs = []
395
+ dones = []
396
+
397
+ for step in range(max_steps):
398
+ prompt = build_prompt(obs, history_lines)
399
+ states.append(prompt)
400
+
401
+ action_text, logprob = generate_action_with_logprob(
402
+ prompt, model, tokenizer, temperature
403
+ )
404
+ actions.append(action_text)
405
+ logprobs.append(logprob)
406
+
407
+ action = parse_action(action_text)
408
+ env_action = map_to_env(action)
409
+ next_obs, reward, done, _ = env.step(env_action)
410
+
411
+ rewards.append(reward.value)
412
+ dones.append(done)
413
+
414
+ history_lines.append(f"Agent: {action_text}")
415
+ history_lines.append(f"Env: {next_obs.last_tool_output}")
416
+
417
+ obs = next_obs
418
+ if done:
419
+ break
420
+
421
+ return Trajectory(states, actions, rewards, logprobs, dones)
422
+
423
+ def collect_trajectories(
424
+ env: CodeReviewEnv,
425
+ model,
426
+ tokenizer,
427
+ n_trajectories: int,
428
+ max_steps: int = 10,
429
+ task_levels: Optional[List[str]] = None,
430
+ task_weights: Optional[List[float]] = None,
431
+ ) -> List[Trajectory]:
432
+ # Link training to RedTeam's full bug distribution by sampling tasks
433
+ # per trajectory instead of training only on env default ("easy").
434
+ if task_levels is None:
435
+ task_levels = list(BUG_DB.keys())
436
+ if task_weights is not None and len(task_weights) != len(task_levels):
437
+ raise ValueError("task_weights must match task_levels length")
438
+ if task_weights is not None and sum(task_weights) <= 0:
439
+ raise ValueError("task_weights must have a positive total")
440
+
441
+ trajectories = []
442
+ for i in range(n_trajectories):
443
+ # Weighted sampling supports curriculum-style training schedules.
444
+ sampled_task = random.choices(task_levels, weights=task_weights, k=1)[0]
445
+ env.set_task(sampled_task)
446
+ traj = collect_trajectory(env, model, tokenizer, max_steps)
447
+ total_reward = sum(traj.rewards)
448
+ print(f"Trajectory {i+1}/{n_trajectories}: "
449
+ f"task={sampled_task}, steps={len(traj)}, reward={total_reward:.3f}")
450
+ trajectories.append(traj)
451
+ return trajectories
452
+
453
+ # ======================================================================
454
+ # 9. ADVANTAGE ESTIMATION (unchanged)
455
+ # ======================================================================
456
+ def compute_returns_and_advantages(
457
+ rewards: List[float],
458
+ dones: List[bool],
459
+ gamma: float = 0.99,
460
+ standardize: bool = True
461
+ ) -> Tuple[List[float], List[float]]:
462
+ """
463
+ Computes discounted returns and normalised advantages (no critic).
464
+ Advantages = returns - mean(returns) (or zero baseline).
465
+ """
466
+ n = len(rewards)
467
+ returns = [0.0] * n
468
+ running_return = 0.0
469
+ for t in reversed(range(n)):
470
+ if dones[t]:
471
+ running_return = 0.0
472
+ running_return = rewards[t] + gamma * running_return
473
+ returns[t] = running_return
474
+
475
+ if standardize:
476
+ advantages = np.array(returns) - np.mean(returns)
477
+ adv_std = np.std(advantages) + 1e-8
478
+ advantages = (advantages / adv_std).tolist()
479
+ else:
480
+ advantages = returns.copy()
481
+
482
+ return advantages, returns
483
+ # ======================================================================
484
+ # 10. COMPUTE NEW LOGPROBS (unchanged)
485
+ # ======================================================================
486
+ def compute_logprob(prompt: str, action: str, model, tokenizer) -> float:
487
+ messages = [{"role": "user", "content": prompt}]
488
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
489
+ full_text = formatted + action
490
+ inputs = tokenizer(full_text, return_tensors="pt").to("cuda")
491
+
492
+ with torch.no_grad():
493
+ outputs = model(**inputs)
494
+ logits = outputs.logits
495
+
496
+ action_ids = tokenizer.encode(action, add_special_tokens=False)
497
+ prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
498
+ action_start = len(prefix_ids)
499
+
500
+ logprobs = []
501
+ for idx, token_id in enumerate(action_ids):
502
+ position = action_start + idx - 1
503
+ if 0 <= position < logits.shape[1]:
504
+ token_logits = logits[0, position]
505
+ token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
506
+ logprobs.append(token_logprob)
507
+ return sum(logprobs) if logprobs else -100.0
508
+
509
+ # ======================================================================
510
+ # 11. PPO UPDATE (unchanged except uses compute_logprob correctly)
511
+ # ======================================================================
512
+ def ppo_update(
513
+ trajectories: List[Trajectory],
514
+ model,
515
+ tokenizer,
516
+ optimizer,
517
+ n_epochs: int = 4,
518
+ clip_epsilon: float = 0.2,
519
+ entropy_coef: float = 0.01,
520
+ gamma: float = 0.99,
521
+ ) -> Dict[str, float]:
522
+ model.train()
523
+
524
+ all_states = []
525
+ all_actions = []
526
+ all_old_logprobs = []
527
+ all_advantages = []
528
+ all_returns = []
529
+
530
+ for traj in trajectories:
531
+ advantages, returns = compute_returns_and_advantages(
532
+ traj.rewards, traj.dones, gamma=gamma, standardize=True
533
+ )
534
+ all_states.extend(traj.states)
535
+ all_actions.extend(traj.actions)
536
+ all_old_logprobs.extend(traj.logprobs)
537
+ all_advantages.extend(advantages)
538
+ all_returns.extend(returns)
539
+
540
+ n_samples = len(all_states)
541
+ total_loss = 0.0
542
+ total_policy_loss = 0.0
543
+ total_entropy = 0.0
544
+ n_updates = 0
545
+
546
+ for epoch in range(n_epochs):
547
+ indices = np.random.permutation(n_samples)
548
+ for i in indices:
549
+ state = all_states[i]
550
+ action = all_actions[i]
551
+ old_logprob = all_old_logprobs[i]
552
+ advantage = all_advantages[i]
553
+
554
+ # Use the same chat template for PPO update
555
+ messages = [{"role": "user", "content": state}]
556
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
557
+ full_text = formatted + action
558
+ inputs = tokenizer(full_text, return_tensors="pt").to("cuda")
559
+
560
+ outputs = model(**inputs)
561
+ logits = outputs.logits
562
+
563
+ action_ids = tokenizer.encode(action, add_special_tokens=False)
564
+ prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
565
+ action_start = len(prefix_ids)
566
+
567
+ logprobs = []
568
+ entropy = 0.0
569
+ for idx, token_id in enumerate(action_ids):
570
+ position = action_start + idx - 1
571
+ if 0 <= position < logits.shape[1]:
572
+ token_logits = logits[0, position]
573
+ log_probs = F.log_softmax(token_logits, dim=-1)
574
+ token_logprob = log_probs[token_id]
575
+ logprobs.append(token_logprob)
576
+
577
+ probs = F.softmax(token_logits, dim=-1)
578
+ entropy += -(probs * log_probs).sum()
579
+
580
+ if not logprobs:
581
+ continue
582
+
583
+ new_logprob = sum(logprobs)
584
+ avg_entropy = entropy / len(logprobs) if logprobs else 0.0
585
+
586
+ ratio = torch.exp(new_logprob - old_logprob)
587
+ surr1 = ratio * advantage
588
+ surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
589
+ policy_loss = -torch.min(surr1, surr2)
590
+ loss = policy_loss - entropy_coef * avg_entropy
591
+
592
+ optimizer.zero_grad()
593
+ loss.backward()
594
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
595
+ optimizer.step()
596
+
597
+ total_loss += loss.item()
598
+ total_policy_loss += policy_loss.item()
599
+ total_entropy += avg_entropy.item()
600
+ n_updates += 1
601
+
602
+ return {
603
+ "loss": total_loss / n_updates if n_updates > 0 else 0.0,
604
+ "policy_loss": total_policy_loss / n_updates if n_updates > 0 else 0.0,
605
+ "entropy": total_entropy / n_updates if n_updates > 0 else 0.0,
606
+ }
607
+ # ======================================================================
608
+ # 12. EVALUATION (unchanged)
609
+ # ======================================================================
610
+ def evaluate_policy(
611
+ env: CodeReviewEnv,
612
+ model,
613
+ tokenizer,
614
+ n_episodes: int = 10,
615
+ max_steps: int = 10
616
+ ) -> Dict[str, float]:
617
+ model.eval()
618
+ total_rewards = []
619
+ episode_lengths = []
620
+ success_count = 0
621
+
622
+ for _ in range(n_episodes):
623
+ traj = collect_trajectory(env, model, tokenizer, max_steps, temperature=0.0)
624
+ total_reward = sum(traj.rewards)
625
+ total_rewards.append(total_reward)
626
+ episode_lengths.append(len(traj))
627
+ if total_reward > 0.5:
628
+ success_count += 1
629
+
630
+ return {
631
+ "avg_reward": np.mean(total_rewards),
632
+ "std_reward": np.std(total_rewards),
633
+ "avg_length": np.mean(episode_lengths),
634
+ "success_rate": success_count / n_episodes,
635
+ }
636
+
637
+ # ======================================================================
638
+ # 13. MAIN TRAINING LOOP (added sanity check and warm-up)
639
+ # ======================================================================
640
+ def train_ppo(
641
+ n_iterations: int = 50,
642
+ trajectories_per_iter: int = 10,
643
+ n_epochs: int = 4,
644
+ max_steps: int = 10,
645
+ learning_rate: float = 3e-5,
646
+ clip_epsilon: float = 0.2,
647
+ entropy_coef: float = 0.01,
648
+ gamma: float = 0.99,
649
+ eval_every: int = 5,
650
+ task_levels: Optional[List[str]] = None,
651
+ curriculum_weighted_sampling: bool = True,
652
+ reward_profile: str = "full",
653
+ ):
654
+ print("Loading model...")
655
+ model, tokenizer = load_model()
656
+
657
+ # NEW: Sanity check before any training
658
+ if not test_model_sanity(model, tokenizer):
659
+ print("\n❌ Model sanity check failed cannot proceed.")
660
+ return
661
+
662
+ # NEW: Supervised warm-up to teach JSON format
663
+ supervised_warmup(model, tokenizer, n_examples=500, epochs=2)
664
+
665
+ optimizer = AdamW(model.parameters(), lr=learning_rate)
666
+ env = CodeReviewEnv(reward_profile=reward_profile)
667
+ if task_levels is None:
668
+ task_levels = list(BUG_DB.keys())
669
+
670
+ print(f"\n{'='*60}")
671
+ print(f"Starting PPO Training")
672
+ print(f"Iterations: {n_iterations}")
673
+ print(f"Trajectories per iteration: {trajectories_per_iter}")
674
+ print(f"PPO epochs: {n_epochs}")
675
+ print(f"Reward profile: {reward_profile}")
676
+ print(f"{'='*60}\n")
677
+ reward_history: List[float] = []
678
+ loss_history: List[float] = []
679
+
680
+ for iteration in range(n_iterations):
681
+ print(f"\n--- Iteration {iteration + 1}/{n_iterations} ---")
682
+ # Optional weighted curriculum:
683
+ # start with easier tasks and smoothly ramp difficulty over training.
684
+ if curriculum_weighted_sampling:
685
+ progress = (iteration + 1) / max(n_iterations, 1)
686
+ easy_w = max(0.15, 0.55 - 0.40 * progress)
687
+ medium_w = max(0.15, 0.25 - 0.10 * progress)
688
+ hard_w = 0.10 + 0.05 * progress
689
+ harder_w = 0.05 + 0.20 * progress
690
+ hardest_w = 0.05 + 0.25 * progress
691
+ task_weight_map = {
692
+ "easy": easy_w,
693
+ "medium": medium_w,
694
+ "hard": hard_w,
695
+ "harder": harder_w,
696
+ "hardest": hardest_w,
697
+ }
698
+ task_weights = [task_weight_map.get(level, 1.0) for level in task_levels]
699
+ else:
700
+ task_weights = None
701
+
702
+ print("Collecting trajectories...")
703
+ trajectories = collect_trajectories(
704
+ env,
705
+ model,
706
+ tokenizer,
707
+ trajectories_per_iter,
708
+ max_steps,
709
+ task_levels=task_levels,
710
+ task_weights=task_weights,
711
+ )
712
+
713
+ avg_reward = np.mean([sum(t.rewards) for t in trajectories])
714
+ avg_length = np.mean([len(t) for t in trajectories])
715
+ reward_history.append(float(avg_reward))
716
+
717
+ print(f"Avg reward: {avg_reward:.3f}")
718
+ print(f"Avg length: {avg_length:.1f}")
719
+
720
+ print("Updating policy...")
721
+ metrics = ppo_update(
722
+ trajectories,
723
+ model,
724
+ tokenizer,
725
+ optimizer,
726
+ n_epochs=n_epochs,
727
+ clip_epsilon=clip_epsilon,
728
+ entropy_coef=entropy_coef,
729
+ gamma=gamma,
730
+ )
731
+
732
+ print(f"Loss: {metrics['loss']:.4f}")
733
+ print(f"Policy loss: {metrics['policy_loss']:.4f}")
734
+ print(f"Entropy: {metrics['entropy']:.4f}")
735
+ loss_history.append(float(metrics["loss"]))
736
+
737
+ if (iteration + 1) % eval_every == 0:
738
+ print("\nEvaluating policy...")
739
+ eval_metrics = evaluate_policy(env, model, tokenizer, n_episodes=10)
740
+ print(f"Eval avg reward: {eval_metrics['avg_reward']:.3f} ± {eval_metrics['std_reward']:.3f}")
741
+ print(f"Eval success rate: {eval_metrics['success_rate']:.2%}")
742
+ print(f"Eval avg length: {eval_metrics['avg_length']:.1f}")
743
+
744
+ print("\n" + "="*60)
745
+ print("Training complete. Saving model...")
746
+ model.save_pretrained("ppo_final_model")
747
+ tokenizer.save_pretrained("ppo_final_model")
748
+ print("Model saved to ppo_final_model/")
749
+
750
+ # Save training curves for quick before/after comparisons.
751
+ # These are intentionally simple line plots to avoid extra dependencies.
752
+ if reward_history:
753
+ plt.figure(figsize=(8, 4))
754
+ plt.plot(range(1, len(reward_history) + 1), reward_history, marker="o")
755
+ plt.title("Average Reward per Iteration")
756
+ plt.xlabel("Iteration")
757
+ plt.ylabel("Average Reward")
758
+ plt.grid(alpha=0.3)
759
+ plt.tight_layout()
760
+ plt.savefig("reward_curve.png", dpi=150)
761
+ plt.close()
762
+
763
+ if loss_history:
764
+ plt.figure(figsize=(8, 4))
765
+ plt.plot(range(1, len(loss_history) + 1), loss_history, marker="o", color="tab:red")
766
+ plt.title("Training Loss per Iteration")
767
+ plt.xlabel("Iteration")
768
+ plt.ylabel("Loss")
769
+ plt.grid(alpha=0.3)
770
+ plt.tight_layout()
771
+ plt.savefig("loss_curve.png", dpi=150)
772
+ plt.close()
773
+
774
+ if os.path.exists("reward_curve.png") and os.path.exists("loss_curve.png"):
775
+ print("Saved reward_curve.png and loss_curve.png")
776
+ print("="*60)
777
+
778
+ # ======================================================================
779
+ # 14. ENTRY POINT (unchanged)
780
+ # ======================================================================
781
+ if __name__ == "__main__":
782
+ train_ppo(
783
+ n_iterations=50,
784
+ trajectories_per_iter=10,
785
+ n_epochs=4,
786
+ max_steps=10,
787
+ learning_rate=3e-5,
788
+ clip_epsilon=0.2,
789
+ entropy_coef=0.01,
790
+ gamma=0.99,
791
+ eval_every=5,
792
+ )