100XZX001 commited on
Commit
73f8ffa
·
verified ·
1 Parent(s): 354bd58

Update environment.py

Browse files
Files changed (1) hide show
  1. environment.py +419 -115
environment.py CHANGED
@@ -1,4 +1,4 @@
1
- # environment.py – Final integrated environment (multi-turn, gated, continuous scoring)
2
 
3
  import sys
4
  import subprocess
@@ -6,9 +6,9 @@ import tempfile
6
  import os
7
  import re
8
  from dataclasses import dataclass, field
9
- from typing import Tuple, Dict, Any, Optional
 
10
 
11
- # Uncomment these imports – ensure the files are in the same directory
12
  from models import (
13
  AnyAction, WriteComment, ProposeFix, Execute, Inspect,
14
  RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
@@ -20,9 +20,57 @@ from test_runner import TestRunner
20
  from author import PersonaAuthor
21
  from rltool import ToolBox
22
 
23
- # ----------------------------------------------------------------------
24
- # Helper: execute arbitrary Python code
25
- # ----------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
27
  if not code.strip():
28
  return False, "", "Error: Empty code"
@@ -40,13 +88,10 @@ def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
40
  )
41
  success = (result.returncode == 0)
42
  return success, result.stdout, result.stderr
43
-
44
  except subprocess.TimeoutExpired:
45
  return False, "", f"Timeout after {timeout_sec}s"
46
-
47
  except Exception as e:
48
  return False, "", f"Execution error: {str(e)}"
49
-
50
  finally:
51
  try:
52
  os.unlink(tmp_path)
@@ -54,15 +99,24 @@ def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
54
  pass
55
 
56
 
57
- # ----------------------------------------------------------------------
58
- # Main Environment
59
- # ----------------------------------------------------------------------
60
  @dataclass
61
  class CodeReviewEnv:
62
  task: str = "easy"
63
  max_steps: int = 10
64
- step_penalty: float = 0.02
65
-
 
 
 
 
 
 
 
 
 
66
  _red_team: Optional[RedTeam] = field(init=False, default=None)
67
  _author: Optional[PersonaAuthor] = field(init=False, default=None)
68
 
@@ -78,23 +132,49 @@ class CodeReviewEnv:
78
 
79
  _step_count: int = field(init=False, default=0)
80
  _done: bool = field(init=False, default=False)
81
-
82
- # ------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  def __post_init__(self):
84
  self.set_task(self.task)
85
 
86
- # ------------------------------------------------------------------
87
  def set_task(self, task: str):
88
  if task not in ["easy", "medium", "hard", "harder", "hardest"]:
89
  raise ValueError(f"Unknown task: {task}")
90
 
91
  self.task = task
92
  self._red_team = RedTeam(task)
93
- self._author = PersonaAuthor() # uses default personality "defensive"
94
-
 
 
 
 
 
 
95
  self._reset_internal()
96
 
97
- # ------------------------------------------------------------------
98
  def _reset_internal(self):
99
  self._step_count = 0
100
  self._comments = []
@@ -102,96 +182,301 @@ class CodeReviewEnv:
102
  self._lint_results = None
103
  self._doc_results = None
104
  self._done = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  self._author.reset()
107
 
108
- # --- Base tasks ---
109
  if self.task == "easy":
110
  original = "def get_user(id):\n if id in users:\n return users[id]"
111
-
112
  elif self.task == "medium":
113
  original = "def process_items(items):\n for item in items:\n print(item)"
114
-
115
  elif self.task == "hard":
116
  original = "def average(data):\n if not data:\n return 0\n return sum(data) / len(data)"
117
-
118
  elif self.task == "harder":
119
  original = "counter = 0\ndef increment():\n global counter\n with lock:\n counter += 1"
120
-
121
  else:
122
  original = "def safe_work():\n with lock1:\n with lock2:\n do_work()"
123
 
124
- # --- Inject bug ---
125
  buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
126
-
127
  self._current_code = buggy_code
128
  self._current_bug_id = bug_id
129
  self._bug_description = desc
130
  self._oracle_fix = oracle
131
-
132
  self._comments.append(f"[RedTeam] {desc}")
133
 
134
- # ------------------------------------------------------------------
135
- def reset(self) -> Observation:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  self._reset_internal()
137
  return self._get_observation()
138
 
139
- # ------------------------------------------------------------------
140
- def _get_observation(self) -> Observation:
141
- # Observation as defined in models.py (no conversation_history)
142
- return Observation(
 
 
 
143
  code_snippet=self._current_code,
144
  last_tool_output=self._test_results or "",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  step=self._step_count,
146
- done=self._done
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  )
148
 
149
- # ------------------------------------------------------------------
150
- def step(self, action: AnyAction) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
151
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  if self._done:
153
  raise RuntimeError("Episode already finished")
154
 
155
- reward_val = 0.0
 
 
 
 
156
  info = {}
 
 
 
 
 
157
 
158
- # ================================================================
159
  # TOOL ACTIONS
160
- # ================================================================
161
  if isinstance(action, Execute):
162
  success, stdout, stderr = execute_code(self._current_code)
163
- self._test_results = (stdout + stderr).strip() or "No output"
164
- reward_val = -self.step_penalty
 
165
 
166
  elif isinstance(action, Inspect):
167
- self._test_results = self._current_code
168
- reward_val = -self.step_penalty
169
 
170
  elif isinstance(action, RunLinter):
171
  lint_output = ToolBox.run_linter(self._current_code)
172
  self._lint_results = lint_output[:500]
173
- self._test_results = self._lint_results
174
- reward_val = -self.step_penalty
 
 
 
175
 
176
  elif isinstance(action, RunTests):
177
  runner = TestRunner(self._current_bug_id)
178
  score, output = runner.run_tests(self._current_code)
179
-
180
- self._test_results = f"Test score: {score:.2f}\n{output[:500]}"
181
- reward_val = -self.step_penalty
 
 
 
 
 
 
182
 
183
  elif isinstance(action, QueryDocs):
184
  doc = ToolBox.query_docs(action.query_topic)
185
  self._doc_results = doc
186
- self._test_results = doc
187
- reward_val = -self.step_penalty
 
188
 
189
- # ================================================================
190
- # COMMUNICATION (MULTI-TURN)
191
- # ================================================================
192
  elif isinstance(action, WriteComment):
193
  self._comments.append(f"Agent: {action.comment_text}")
194
-
195
  response = self._author.respond(
196
  agent_comment=action.comment_text,
197
  test_results=self._test_results,
@@ -200,13 +485,14 @@ class CodeReviewEnv:
200
  proposed_fix=None,
201
  original_code=self._current_code
202
  )
203
-
204
  self._comments.append(f"Author: {response}")
205
- reward_val = -self.step_penalty
 
206
 
207
  elif isinstance(action, AskQuestion):
208
  self._comments.append(f"Agent: {action.question}")
209
-
210
  response = self._author.respond(
211
  agent_question=action.question,
212
  test_results=self._test_results,
@@ -215,96 +501,117 @@ class CodeReviewEnv:
215
  proposed_fix=None,
216
  original_code=self._current_code
217
  )
218
-
219
  self._comments.append(f"Author: {response}")
220
- reward_val = -self.step_penalty
 
221
 
222
- # ================================================================
223
- # FINAL FIX
224
- # ================================================================
225
  elif isinstance(action, ProposeFix):
226
-
227
  if not action.fix_code:
228
- reward_val = -0.5
229
  self._done = True
230
-
231
  else:
232
  self._current_code = action.fix_code
233
-
234
  runner = TestRunner(self._current_bug_id)
235
  test_score, test_output = runner.run_tests(self._current_code)
236
-
237
  lint_score = self._run_linter_score(self._current_code)
238
  negotiation_score = self._author.get_negotiation_score()
239
-
240
- step_cost = self.step_penalty * self._step_count
241
-
242
- reward_val = (
243
- 0.6 * test_score +
244
- 0.2 * lint_score +
245
- 0.2 * negotiation_score -
246
- step_cost
 
 
247
  )
248
-
249
- # -------------------------
250
- # Cross-signal penalties
251
- # -------------------------
 
252
  if test_score > 0.8 and lint_score < 0.3:
253
- reward_val *= 0.8
254
-
255
  if test_score < 0.3 and lint_score > 0.8:
256
- reward_val *= 0.7
257
-
258
  if test_score > 0.8 and negotiation_score < 0.3:
259
- reward_val *= 0.75
260
-
261
- # -------------------------
262
- # Author gating (only if not already convinced)
263
- # -------------------------
264
  threshold = self._author.thresholds.get(self._author.personality, 0.5)
265
  if self._author._confidence < threshold:
266
- reward_val = max(0.0, reward_val - 0.3)
267
- # Allow continuation if steps left
268
  if self._step_count < self.max_steps:
269
  self._done = False
270
  else:
271
  self._done = True
272
  else:
273
  self._done = True
 
 
 
274
 
275
- reward_val = max(0.0, min(1.0, reward_val))
276
-
277
- self._test_results = f"Test score: {test_score:.2f}\n{test_output[:300]}"
278
-
279
- # ================================================================
280
- # TERMINATION
281
- # ================================================================
282
  elif isinstance(action, Skip):
283
- reward_val = -0.2
284
  self._done = True
285
 
286
  elif isinstance(action, Done):
287
- reward_val = -0.5
 
 
 
288
  self._done = True
289
 
290
  else:
291
- reward_val = -0.2
292
  self._done = True
293
 
294
- # ================================================================
 
 
 
 
 
 
 
 
295
  # STEP UPDATE
296
- # ================================================================
297
  self._step_count += 1
298
-
299
  if self._step_count >= self.max_steps:
300
  self._done = True
301
-
 
 
 
 
302
  obs = self._get_observation()
303
-
304
- return obs, Reward(value=reward_val), self._done, info
305
-
306
- # ------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
307
  def _run_linter_score(self, code: str) -> float:
 
308
  try:
309
  with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
310
  f.write(code)
@@ -318,23 +625,20 @@ class CodeReviewEnv:
318
  )
319
 
320
  match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
321
-
322
  if match:
323
  return float(match.group(1)) / 10.0
324
-
325
  return 0.0
326
-
327
  except:
328
  return 0.0
329
-
330
  finally:
331
  try:
332
  os.unlink(tmp_path)
333
  except:
334
  pass
335
 
336
- # ------------------------------------------------------------------
337
  def state(self) -> State:
 
338
  return State(
339
  pr_title="Code Review",
340
  pr_description=self._bug_description,
 
1
+ # environment.py – FULLY CORRECTED RL Environment (TRUE Markov + Fixed Bugs)
2
 
3
  import sys
4
  import subprocess
 
6
  import os
7
  import re
8
  from dataclasses import dataclass, field
9
+ from typing import Tuple, Dict, Any, Optional, List
10
+ from collections import Counter
11
 
 
12
  from models import (
13
  AnyAction, WriteComment, ProposeFix, Execute, Inspect,
14
  RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
 
20
  from author import PersonaAuthor
21
  from rltool import ToolBox
22
 
23
+ # ======================================================================
24
+ # FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
25
+ # ======================================================================
26
+ @dataclass
27
+ class EnhancedObservation:
28
+ """
29
+ Complete Markov state - agent has ALL information needed for optimal decisions.
30
+ Reward function depends ONLY on (state, action), not hidden variables.
31
+ """
32
+ # Code state
33
+ code_snippet: str
34
+ last_tool_output: str
35
+
36
+ # Current metrics
37
+ current_test_score: float
38
+ current_lint_score: float
39
+ negotiation_score: float
40
+
41
+ # CRITICAL: Previous metrics (for understanding deltas)
42
+ previous_test_score: float
43
+ previous_lint_score: float
44
+
45
+ # CRITICAL: Author internal state (affects reward gating)
46
+ author_confidence: float
47
+ author_threshold: float # When author accepts
48
+
49
+ # Progress tracking
50
+ step: int
51
+ max_steps: int
52
+ progress_ratio: float
53
+
54
+ # Tool usage flags
55
+ tests_run: bool
56
+ linter_run: bool
57
+ docs_queried: bool
58
+
59
+ # Action history (with outcomes)
60
+ last_action_type: str
61
+ action_history: List[str] # Last 5 actions
62
+
63
+ # Terminal flag
64
+ done: bool
65
+
66
+ # Additional context
67
+ bug_description: str
68
+ comments_count: int
69
+
70
+
71
+ # ======================================================================
72
+ # HELPER FUNCTIONS
73
+ # ======================================================================
74
  def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
75
  if not code.strip():
76
  return False, "", "Error: Empty code"
 
88
  )
89
  success = (result.returncode == 0)
90
  return success, result.stdout, result.stderr
 
91
  except subprocess.TimeoutExpired:
92
  return False, "", f"Timeout after {timeout_sec}s"
 
93
  except Exception as e:
94
  return False, "", f"Execution error: {str(e)}"
 
95
  finally:
96
  try:
97
  os.unlink(tmp_path)
 
99
  pass
100
 
101
 
102
+ # ======================================================================
103
+ # ENHANCED CODE REVIEW ENVIRONMENT
104
+ # ======================================================================
105
  @dataclass
106
  class CodeReviewEnv:
107
  task: str = "easy"
108
  max_steps: int = 10
109
+ step_penalty: float = 0.01
110
+
111
+ # Curriculum learning
112
+ auto_difficulty: bool = False
113
+ success_threshold: float = 0.7
114
+
115
+ # Reward shaping parameters
116
+ delta_weight: float = 0.3
117
+ tool_usage_bonus: float = 0.05
118
+ diversity_bonus: float = 0.03
119
+
120
  _red_team: Optional[RedTeam] = field(init=False, default=None)
121
  _author: Optional[PersonaAuthor] = field(init=False, default=None)
122
 
 
132
 
133
  _step_count: int = field(init=False, default=0)
134
  _done: bool = field(init=False, default=False)
135
+
136
+ # State tracking for dense rewards
137
+ _previous_test_score: float = field(init=False, default=0.0)
138
+ _previous_lint_score: float = field(init=False, default=0.0)
139
+ _current_test_score: float = field(init=False, default=0.0)
140
+ _current_lint_score: float = field(init=False, default=0.0)
141
+
142
+ # Tool usage tracking
143
+ _tests_run: bool = field(init=False, default=False)
144
+ _linter_run: bool = field(init=False, default=False)
145
+ _docs_queried: bool = field(init=False, default=False)
146
+
147
+ # Action history
148
+ _action_history: List[str] = field(init=False, default_factory=list)
149
+ _last_action_type: str = field(init=False, default="none")
150
+
151
+ # FIXED: Track CUMULATIVE episode reward
152
+ _episode_total_reward: float = field(init=False, default=0.0)
153
+ _episode_rewards: List[float] = field(init=False, default_factory=list)
154
+ _difficulty_level: int = field(init=False, default=0)
155
+
156
+ # ===================================================================
157
  def __post_init__(self):
158
  self.set_task(self.task)
159
 
160
+ # ===================================================================
161
  def set_task(self, task: str):
162
  if task not in ["easy", "medium", "hard", "harder", "hardest"]:
163
  raise ValueError(f"Unknown task: {task}")
164
 
165
  self.task = task
166
  self._red_team = RedTeam(task)
167
+ self._author = PersonaAuthor()
168
+
169
+ task_to_level = {
170
+ "easy": 0, "medium": 1, "hard": 2,
171
+ "harder": 3, "hardest": 4
172
+ }
173
+ self._difficulty_level = task_to_level[task]
174
+
175
  self._reset_internal()
176
 
177
+ # ===================================================================
178
  def _reset_internal(self):
179
  self._step_count = 0
180
  self._comments = []
 
182
  self._lint_results = None
183
  self._doc_results = None
184
  self._done = False
185
+
186
+ # Reset state tracking
187
+ self._previous_test_score = 0.0
188
+ self._previous_lint_score = 0.0
189
+ self._current_test_score = 0.0
190
+ self._current_lint_score = 0.0
191
+
192
+ self._tests_run = False
193
+ self._linter_run = False
194
+ self._docs_queried = False
195
+
196
+ self._action_history = []
197
+ self._last_action_type = "none"
198
+
199
+ # FIXED: Reset episode cumulative reward
200
+ self._episode_total_reward = 0.0
201
 
202
  self._author.reset()
203
 
204
+ # Base tasks
205
  if self.task == "easy":
206
  original = "def get_user(id):\n if id in users:\n return users[id]"
 
207
  elif self.task == "medium":
208
  original = "def process_items(items):\n for item in items:\n print(item)"
 
209
  elif self.task == "hard":
210
  original = "def average(data):\n if not data:\n return 0\n return sum(data) / len(data)"
 
211
  elif self.task == "harder":
212
  original = "counter = 0\ndef increment():\n global counter\n with lock:\n counter += 1"
 
213
  else:
214
  original = "def safe_work():\n with lock1:\n with lock2:\n do_work()"
215
 
 
216
  buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
 
217
  self._current_code = buggy_code
218
  self._current_bug_id = bug_id
219
  self._bug_description = desc
220
  self._oracle_fix = oracle
 
221
  self._comments.append(f"[RedTeam] {desc}")
222
 
223
+ # ===================================================================
224
+ def reset(self) -> EnhancedObservation:
225
+ """Reset with optional curriculum adjustment."""
226
+ if self.auto_difficulty and len(self._episode_rewards) > 0:
227
+ recent_performance = sum(self._episode_rewards[-5:]) / min(5, len(self._episode_rewards))
228
+
229
+ if recent_performance > self.success_threshold and self._difficulty_level < 4:
230
+ self._difficulty_level += 1
231
+ print(f"[Curriculum] Increasing difficulty to level {self._difficulty_level}")
232
+ elif recent_performance < 0.3 and self._difficulty_level > 0:
233
+ self._difficulty_level -= 1
234
+ print(f"[Curriculum] Decreasing difficulty to level {self._difficulty_level}")
235
+
236
+ level_to_task = {0: "easy", 1: "medium", 2: "hard", 3: "harder", 4: "hardest"}
237
+ self.task = level_to_task[self._difficulty_level]
238
+ self._red_team = RedTeam(self.task)
239
+
240
  self._reset_internal()
241
  return self._get_observation()
242
 
243
+ # ===================================================================
244
+ def _get_observation(self) -> EnhancedObservation:
245
+ """
246
+ Return COMPLETE Markov state.
247
+ NOTHING is hidden - reward depends ONLY on (state, action).
248
+ """
249
+ return EnhancedObservation(
250
  code_snippet=self._current_code,
251
  last_tool_output=self._test_results or "",
252
+
253
+ # Current metrics
254
+ current_test_score=self._current_test_score,
255
+ current_lint_score=self._current_lint_score,
256
+ negotiation_score=self._author.get_negotiation_score(),
257
+
258
+ # EXPOSED: Previous metrics (for delta understanding)
259
+ previous_test_score=self._previous_test_score,
260
+ previous_lint_score=self._previous_lint_score,
261
+
262
+ # EXPOSED: Author internal state (affects gating)
263
+ author_confidence=self._author._confidence,
264
+ author_threshold=self._author.thresholds.get(self._author.personality, 0.5),
265
+
266
+ # Progress
267
  step=self._step_count,
268
+ max_steps=self.max_steps,
269
+ progress_ratio=self._step_count / self.max_steps,
270
+
271
+ # Tool usage
272
+ tests_run=self._tests_run,
273
+ linter_run=self._linter_run,
274
+ docs_queried=self._docs_queried,
275
+
276
+ # Action history
277
+ last_action_type=self._last_action_type,
278
+ action_history=self._action_history[-5:],
279
+
280
+ # Terminal
281
+ done=self._done,
282
+
283
+ # Context
284
+ bug_description=self._bug_description,
285
+ comments_count=len(self._comments),
286
  )
287
 
288
+ # ===================================================================
289
+ def _compute_dense_reward(
290
+ self,
291
+ action: AnyAction,
292
+ base_reward: float,
293
+ action_type: str
294
+ ) -> float:
295
+ """
296
+ Compute dense reward with:
297
+ 1. Delta-based improvement rewards
298
+ 2. Tool usage bonuses
299
+ 3. Exploration incentives
300
+ 4. Anti-hacking penalties
301
+
302
+ FIXED: Reduced delta weight for ProposeFix to avoid double-counting
303
+ """
304
+ reward = base_reward
305
+
306
+ # FIXED: Reduce delta impact for ProposeFix (already includes test_score in base)
307
+ effective_delta_weight = self.delta_weight
308
+ if action_type == "propose_fix":
309
+ effective_delta_weight *= 0.5 # Prevent double-counting
310
+
311
+ # ============================================================
312
+ # 1. DELTA-BASED REWARDS (credit assignment)
313
+ # ============================================================
314
+ test_delta = self._current_test_score - self._previous_test_score
315
+ lint_delta = self._current_lint_score - self._previous_lint_score
316
+
317
+ if test_delta > 0:
318
+ reward += effective_delta_weight * test_delta
319
+ elif test_delta < 0:
320
+ reward += effective_delta_weight * test_delta * 0.5
321
+
322
+ if lint_delta > 0:
323
+ reward += effective_delta_weight * 0.5 * lint_delta
324
+
325
+ # ============================================================
326
+ # 2. TOOL USAGE BONUSES
327
+ # ============================================================
328
+ if action_type == "run_tests":
329
+ if not self._tests_run:
330
+ reward += self.tool_usage_bonus
331
+ reward += 0.02
332
+
333
+ elif action_type == "run_linter":
334
+ if not self._linter_run:
335
+ reward += self.tool_usage_bonus
336
+ reward += 0.02
337
+
338
+ elif action_type == "query_docs":
339
+ if not self._docs_queried:
340
+ reward += self.tool_usage_bonus * 0.5
341
+
342
+ elif action_type == "ask_question":
343
+ if 1 <= self._step_count <= 5:
344
+ reward += 0.03
345
+
346
+ # ============================================================
347
+ # 3. EXPLORATION INCENTIVES
348
+ # ============================================================
349
+ if len(self._action_history) >= 3:
350
+ recent_actions = self._action_history[-3:]
351
+ action_counts = Counter(recent_actions)
352
+ most_common_count = action_counts.most_common(1)[0][1]
353
+
354
+ if most_common_count >= 3:
355
+ reward -= 0.05 # Repetition penalty
356
+ elif len(set(recent_actions)) == 3:
357
+ reward += self.diversity_bonus # Diversity bonus
358
+
359
+ # ============================================================
360
+ # 4. ANTI-HACKING PENALTIES
361
+ # ============================================================
362
+ if action_type == "propose_fix":
363
+ if not self._tests_run:
364
+ reward -= 0.2
365
+ if self._step_count < 2:
366
+ reward -= 0.15
367
+ if self._tests_run and self._linter_run:
368
+ reward += 0.1
369
+
370
+ # ============================================================
371
+ # 5. STEP PENALTY
372
+ # ============================================================
373
+ reward -= self.step_penalty
374
+
375
+ # ============================================================
376
+ # 6. NORMALIZE TO [-1, 1]
377
+ # ============================================================
378
+ reward = max(-1.0, min(1.0, reward))
379
+
380
+ return reward
381
+
382
+ # ===================================================================
383
+ def _get_action_type(self, action: AnyAction) -> str:
384
+ """Extract action type as string."""
385
+ if isinstance(action, RunTests):
386
+ return "run_tests"
387
+ elif isinstance(action, RunLinter):
388
+ return "run_linter"
389
+ elif isinstance(action, QueryDocs):
390
+ return "query_docs"
391
+ elif isinstance(action, Execute):
392
+ return "execute"
393
+ elif isinstance(action, Inspect):
394
+ return "inspect"
395
+ elif isinstance(action, WriteComment):
396
+ return "write_comment"
397
+ elif isinstance(action, AskQuestion):
398
+ return "ask_question"
399
+ elif isinstance(action, ProposeFix):
400
+ return "propose_fix"
401
+ elif isinstance(action, Done):
402
+ return "done"
403
+ elif isinstance(action, Skip):
404
+ return "skip"
405
+ else:
406
+ return "unknown"
407
+
408
+ # ===================================================================
409
+ def step(self, action: AnyAction) -> Tuple[EnhancedObservation, Reward, bool, Dict[str, Any]]:
410
+ """
411
+ TRUE RL STEP with:
412
+ - Complete Markov observations (no hidden state)
413
+ - Dense intermediate rewards
414
+ - Delta-based credit assignment (no double-counting)
415
+ - Proper episode reward tracking
416
+ """
417
  if self._done:
418
  raise RuntimeError("Episode already finished")
419
 
420
+ # Store previous metrics for delta computation
421
+ self._previous_test_score = self._current_test_score
422
+ self._previous_lint_score = self._current_lint_score
423
+
424
+ base_reward = 0.0
425
  info = {}
426
+ action_type = self._get_action_type(action)
427
+
428
+ # Update action history
429
+ self._action_history.append(action_type)
430
+ self._last_action_type = action_type
431
 
432
+ # ==============================================================
433
  # TOOL ACTIONS
434
+ # ==============================================================
435
  if isinstance(action, Execute):
436
  success, stdout, stderr = execute_code(self._current_code)
437
+ output = (stdout + stderr).strip() or "No output"
438
+ self._test_results = f"[Execute] {'Success' if success else 'Failed'}\n{output[:300]}"
439
+ base_reward = 0.01 if success else -0.05
440
 
441
  elif isinstance(action, Inspect):
442
+ self._test_results = f"[Inspect]\n{self._current_code[:500]}"
443
+ base_reward = 0.01
444
 
445
  elif isinstance(action, RunLinter):
446
  lint_output = ToolBox.run_linter(self._current_code)
447
  self._lint_results = lint_output[:500]
448
+ self._test_results = f"[Linter]\n{self._lint_results}"
449
+
450
+ self._current_lint_score = self._run_linter_score(self._current_code)
451
+ self._linter_run = True
452
+ base_reward = 0.02
453
 
454
  elif isinstance(action, RunTests):
455
  runner = TestRunner(self._current_bug_id)
456
  score, output = runner.run_tests(self._current_code)
457
+
458
+ self._current_test_score = score
459
+ self._tests_run = True
460
+
461
+ self._test_results = f"[Tests] Score: {score:.2f}\n{output[:300]}"
462
+ base_reward = 0.02
463
+
464
+ if score > 0.8:
465
+ base_reward += 0.05
466
 
467
  elif isinstance(action, QueryDocs):
468
  doc = ToolBox.query_docs(action.query_topic)
469
  self._doc_results = doc
470
+ self._test_results = f"[Docs]\n{doc[:400]}"
471
+ self._docs_queried = True
472
+ base_reward = 0.01
473
 
474
+ # ==============================================================
475
+ # COMMUNICATION ACTIONS
476
+ # ==============================================================
477
  elif isinstance(action, WriteComment):
478
  self._comments.append(f"Agent: {action.comment_text}")
479
+
480
  response = self._author.respond(
481
  agent_comment=action.comment_text,
482
  test_results=self._test_results,
 
485
  proposed_fix=None,
486
  original_code=self._current_code
487
  )
488
+
489
  self._comments.append(f"Author: {response}")
490
+ self._test_results = f"[Comment] Author: {response[:200]}"
491
+ base_reward = 0.01
492
 
493
  elif isinstance(action, AskQuestion):
494
  self._comments.append(f"Agent: {action.question}")
495
+
496
  response = self._author.respond(
497
  agent_question=action.question,
498
  test_results=self._test_results,
 
501
  proposed_fix=None,
502
  original_code=self._current_code
503
  )
504
+
505
  self._comments.append(f"Author: {response}")
506
+ self._test_results = f"[Question] Author: {response[:200]}"
507
+ base_reward = 0.02
508
 
509
+ # ==============================================================
510
+ # FINAL FIX ACTION
511
+ # ==============================================================
512
  elif isinstance(action, ProposeFix):
 
513
  if not action.fix_code:
514
+ base_reward = -0.5
515
  self._done = True
 
516
  else:
517
  self._current_code = action.fix_code
518
+
519
  runner = TestRunner(self._current_bug_id)
520
  test_score, test_output = runner.run_tests(self._current_code)
 
521
  lint_score = self._run_linter_score(self._current_code)
522
  negotiation_score = self._author.get_negotiation_score()
523
+
524
+ # Update current scores
525
+ self._current_test_score = test_score
526
+ self._current_lint_score = lint_score
527
+
528
+ # Component reward (scaled down to allow delta distribution)
529
+ component_reward = (
530
+ 0.4 * test_score +
531
+ 0.15 * lint_score +
532
+ 0.15 * negotiation_score
533
  )
534
+
535
+ efficiency = 1.0 - (self._step_count / self.max_steps)
536
+ component_reward += 0.1 * efficiency
537
+
538
+ # Cross-signal consistency
539
  if test_score > 0.8 and lint_score < 0.3:
540
+ component_reward *= 0.85
 
541
  if test_score < 0.3 and lint_score > 0.8:
542
+ component_reward *= 0.75
 
543
  if test_score > 0.8 and negotiation_score < 0.3:
544
+ component_reward *= 0.8
545
+
546
+ # Author gating
 
 
547
  threshold = self._author.thresholds.get(self._author.personality, 0.5)
548
  if self._author._confidence < threshold:
549
+ component_reward = max(0.0, component_reward - 0.2)
 
550
  if self._step_count < self.max_steps:
551
  self._done = False
552
  else:
553
  self._done = True
554
  else:
555
  self._done = True
556
+
557
+ base_reward = component_reward
558
+ self._test_results = f"[Fix] Test: {test_score:.2f}, Lint: {lint_score:.2f}\n{test_output[:200]}"
559
 
560
+ # ==============================================================
561
+ # TERMINATION ACTIONS
562
+ # ==============================================================
 
 
 
 
563
  elif isinstance(action, Skip):
564
+ base_reward = -0.3
565
  self._done = True
566
 
567
  elif isinstance(action, Done):
568
+ if self._tests_run:
569
+ base_reward = self._current_test_score * 0.5 - 0.2
570
+ else:
571
+ base_reward = -0.4
572
  self._done = True
573
 
574
  else:
575
+ base_reward = -0.2
576
  self._done = True
577
 
578
+ # ==============================================================
579
+ # COMPUTE FINAL DENSE REWARD (with action_type for fix detection)
580
+ # ==============================================================
581
+ final_reward = self._compute_dense_reward(action, base_reward, action_type)
582
+
583
+ # FIXED: Track CUMULATIVE episode reward
584
+ self._episode_total_reward += final_reward
585
+
586
+ # ==============================================================
587
  # STEP UPDATE
588
+ # ==============================================================
589
  self._step_count += 1
590
+
591
  if self._step_count >= self.max_steps:
592
  self._done = True
593
+
594
+ # FIXED: Store TOTAL episode reward, not just last step
595
+ if self._done:
596
+ self._episode_rewards.append(self._episode_total_reward)
597
+
598
  obs = self._get_observation()
599
+
600
+ info = {
601
+ "test_score": self._current_test_score,
602
+ "lint_score": self._current_lint_score,
603
+ "test_delta": self._current_test_score - self._previous_test_score,
604
+ "lint_delta": self._current_lint_score - self._previous_lint_score,
605
+ "base_reward": base_reward,
606
+ "final_reward": final_reward,
607
+ "episode_total": self._episode_total_reward,
608
+ }
609
+
610
+ return obs, Reward(value=final_reward), self._done, info
611
+
612
+ # ===================================================================
613
  def _run_linter_score(self, code: str) -> float:
614
+ """Run pylint and return normalized score [0, 1]."""
615
  try:
616
  with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
617
  f.write(code)
 
625
  )
626
 
627
  match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
 
628
  if match:
629
  return float(match.group(1)) / 10.0
 
630
  return 0.0
 
631
  except:
632
  return 0.0
 
633
  finally:
634
  try:
635
  os.unlink(tmp_path)
636
  except:
637
  pass
638
 
639
+ # ===================================================================
640
  def state(self) -> State:
641
+ """Legacy compatibility."""
642
  return State(
643
  pr_title="Code Review",
644
  pr_description=self._bug_description,