# environment.py – RL code review environment (fully Markov observations)

import sys
import subprocess
import tempfile
import os
import re
from dataclasses import dataclass, field
from typing import Tuple, Dict, Any, Optional, List

from models import (
    AnyAction, WriteComment, ProposeFix, Execute, Inspect,
    RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
    Observation, Reward, State,
)
from redteam import RedTeam
from test_runner import TestRunner
from author import PersonaAuthor
from rltool import ToolBox
from rubrics import (
    ToolUsageRubric,
    TestDeltaRubric,
    LintDeltaRubric,
    TerminalSuccessRubric,
    ExplorationRubric,
    AntiHackingRubric,
    StepPenaltyRubric,
)


# ======================================================================
# FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
# ======================================================================

@dataclass
class EnhancedObservation:
    code_snippet: str
    last_tool_output: str
    current_test_score: float
    current_lint_score: float
    negotiation_score: float
    previous_test_score: float
    previous_lint_score: float
    author_confidence: float
    author_threshold: float
    step: int
    max_steps: int
    progress_ratio: float
    tests_run: bool
    linter_run: bool
    docs_queried: bool
    last_action_type: str
    action_history: List[str]
    done: bool
    bug_description: str
    comments_count: int
    # Fields with defaults must come after all non-default fields.
    author_response: str = ""


# ======================================================================
# HELPER FUNCTIONS
# ======================================================================

def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
    """Run `code` in a subprocess; return (success, stdout, stderr)."""
    if not code.strip():
        return False, "", "Error: Empty code"
    with tempfile.NamedTemporaryFile(mode='w', suffix='.py',
                                     delete=False, encoding='utf-8') as f:
        f.write(code)
        tmp_path = f.name
    try:
        result = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True, text=True, timeout=timeout_sec,
        )
        success = (result.returncode == 0)
        return success, result.stdout, result.stderr
    except subprocess.TimeoutExpired:
        return False, "", f"Timeout after {timeout_sec}s"
    except Exception as e:
        return False, "", f"Execution error: {str(e)}"
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
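
# Usage sketch for execute_code (illustrative; exact output depends on the
# host interpreter and platform):
#   ok, out, err = execute_code("print('hello')")    # -> (True, "hello\n", "")
#   ok, out, err = execute_code("1/0")               # -> (False, "", traceback text)
#   ok, out, err = execute_code("while True: pass")  # -> (False, "", "Timeout after 5s")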

# ======================================================================
# ENHANCED CODE REVIEW ENVIRONMENT
# ======================================================================

@dataclass
class CodeReviewEnv:
    task: str = "easy"
    max_steps: int = 10
    step_penalty: float = 0.01
    reward_profile: str = "full"  # "full" or "core"

    # Curriculum learning
    auto_difficulty: bool = False
    success_threshold: float = 0.7

    # Reward shaping parameters
    delta_weight: float = 0.3
    tool_usage_bonus: float = 0.05
    diversity_bonus: float = 0.03

    _red_team: Optional[RedTeam] = field(init=False, default=None)
    _author: Optional[PersonaAuthor] = field(init=False, default=None)
    _current_code: str = field(init=False, default="")
    _current_bug_id: str = field(init=False, default="")
    _bug_description: str = field(init=False, default="")
    _oracle_fix: str = field(init=False, default="")
    _comments: list = field(init=False, default_factory=list)
    _test_results: Optional[str] = field(init=False, default=None)
    _lint_results: Optional[str] = field(init=False, default=None)
    _doc_results: Optional[str] = field(init=False, default=None)
    _step_count: int = field(init=False, default=0)
    _done: bool = field(init=False, default=False)

    # State tracking for dense rewards
    _previous_test_score: float = field(init=False, default=0.0)
    _previous_lint_score: float = field(init=False, default=0.0)
    _current_test_score: float = field(init=False, default=0.0)
    _current_lint_score: float = field(init=False, default=0.0)

    # Tool usage tracking
    _tests_run: bool = field(init=False, default=False)
    _linter_run: bool = field(init=False, default=False)
    _docs_queried: bool = field(init=False, default=False)

    # Action history
    _action_history: List[str] = field(init=False, default_factory=list)
    _last_action_type: str = field(init=False, default="none")
    _last_author_response: str = field(init=False, default="")

    # Cumulative episode reward tracking
    _episode_total_reward: float = field(init=False, default=0.0)
    _episode_rewards: List[float] = field(init=False, default_factory=list)
    _difficulty_level: int = field(init=False, default=0)

    # Bug-id bridge:
    # RedTeam has fine-grained IDs, while TestRunner currently expects a
    # smaller canonical set. Keep this mapping here so both modules can
    # evolve independently without breaking evaluation.
    # (No annotation, so the dataclass machinery treats this as a plain
    # class attribute, not a field.)
    _BUG_ID_CANONICAL_MAP = {
        # Easy family
        "simple_typo": "null_check",
        "default_value": "null_check",
        "empty_return": "null_check",
        "string_index": "off_by_one",
        # Medium family
        "loop_skip": "off_by_one",
        "sign_error": "wrong_operator",
        "swap_args": "wrong_operator",
        "uninitialised_var": "null_check",
        # Hard family
        "division_by_zero_empty": "division_by_zero",
        "division_by_zero_zero": "division_by_zero",
        "float_precision": "division_by_zero",
        "abs_usage": "division_by_zero",
        "round_error": "division_by_zero",
    }

    # ===================================================================
    def __post_init__(self):
        self.set_task(self.task)

    # ===================================================================
    def _build_rubrics(self):
        """
        Build the rubric stack from a named reward profile.

        - full: richer shaping for exploration/tool-use behavior
        - core: minimal stable signal for quick ablations/baselines
        """
        core_rubrics = [
            TestDeltaRubric(weight=self.delta_weight),
            LintDeltaRubric(weight=self.delta_weight),
            TerminalSuccessRubric(),
            StepPenaltyRubric(penalty=self.step_penalty),
        ]
        if self.reward_profile == "core":
            return core_rubrics
        if self.reward_profile == "full":
            return [
                *core_rubrics[:-1],  # step penalty appended last for consistent ordering
                ToolUsageRubric(bonus=self.tool_usage_bonus),
                ExplorationRubric(penalty=-0.05, bonus=self.diversity_bonus * 0.7),
                AntiHackingRubric(),
                core_rubrics[-1],
            ]
        raise ValueError(f"Unknown reward_profile: {self.reward_profile}")
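
    # Resulting rubric order, for reference:
    #   "full" -> [TestDelta, LintDelta, TerminalSuccess,
    #              ToolUsage, Exploration, AntiHacking, StepPenalty]
    #   "core" -> the first three plus StepPenalty.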

    # ===================================================================
    def set_task(self, task: str):
        if task not in ["easy", "medium", "hard", "harder", "hardest"]:
            raise ValueError(f"Unknown task: {task}")
        self.task = task
        # Use stochastic bug sampling across episodes; a fixed seed here would
        # repeatedly select the same bug and weaken training diversity.
        self._red_team = RedTeam(task, seed=None)
        self._author = PersonaAuthor()
        self.rubrics = self._build_rubrics()
        task_to_level = {"easy": 0, "medium": 1, "hard": 2, "harder": 3, "hardest": 4}
        self._difficulty_level = task_to_level[task]
        self._reset_internal()

    # ===================================================================
    def _reset_internal(self):
        self._step_count = 0
        self._comments = []
        self._test_results = None
        self._lint_results = None
        self._doc_results = None
        self._done = False

        # Reset state tracking
        self._previous_test_score = 0.0
        self._previous_lint_score = 0.0
        self._current_test_score = 0.0
        self._current_lint_score = 0.0
        self._tests_run = False
        self._linter_run = False
        self._docs_queried = False
        self._action_history = []
        self._last_action_type = "none"
        self._last_author_response = ""

        # Reset the episode's cumulative reward
        self._episode_total_reward = 0.0

        self._author.reset()

        # Base tasks
        if self.task == "easy":
            original = "def get_user(id):\n    if id in users:\n        return users[id]"
        elif self.task == "medium":
            original = "def process_items(items):\n    for item in items:\n        print(item)"
        elif self.task == "hard":
            original = "def average(data):\n    if not data:\n        return 0\n    return sum(data) / len(data)"
        elif self.task == "harder":
            original = "counter = 0\ndef increment():\n    global counter\n    with lock:\n        counter += 1"
        else:
            original = "def safe_work():\n    with lock1:\n        with lock2:\n            do_work()"

        buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
        self._current_code = buggy_code
        self._current_bug_id = bug_id
        self._bug_description = desc
        self._oracle_fix = oracle
        self._comments.append(f"[RedTeam] {desc}")

    # ===================================================================
    def reset(self) -> EnhancedObservation:
        """Reset with optional curriculum adjustment."""
        if self.auto_difficulty and len(self._episode_rewards) > 0:
            recent_performance = sum(self._episode_rewards[-5:]) / min(5, len(self._episode_rewards))
            if recent_performance > self.success_threshold and self._difficulty_level < 4:
                self._difficulty_level += 1
                print(f"[Curriculum] Increasing difficulty to level {self._difficulty_level}")
            elif recent_performance < 0.3 and self._difficulty_level > 0:
                self._difficulty_level -= 1
                print(f"[Curriculum] Decreasing difficulty to level {self._difficulty_level}")
            level_to_task = {0: "easy", 1: "medium", 2: "hard", 3: "harder", 4: "hardest"}
            self.task = level_to_task[self._difficulty_level]
            # Keep the curriculum stochastic for better coverage within each level.
            self._red_team = RedTeam(self.task, seed=None)
        self._reset_internal()
        return self._get_observation()
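
    # Worked example of the curriculum rule (hypothetical reward values):
    #   last five episode totals [0.8, 0.9, 0.7, 0.75, 0.8] -> mean 0.79;
    #   0.79 > success_threshold (0.7), so the level rises by one
    #   (e.g. "easy" -> "medium"); a mean below 0.3 lowers it instead.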

    # ===================================================================
    def _get_observation(self) -> EnhancedObservation:
        """Return the COMPLETE Markov state."""
        # Keep the author's message separate from tool output.
        # Reusing `_test_results` here could leak unrelated outputs
        # (tests/linter/docs) and give the policy a noisy signal for
        # dialogue actions.
        if self._last_action_type in ("comment", "question", "fix"):
            author_response = self._last_author_response
        else:
            author_response = ""

        return EnhancedObservation(
            code_snippet=self._current_code,
            last_tool_output=self._test_results or "",
            author_response=author_response,
            current_test_score=self._current_test_score,
            current_lint_score=self._current_lint_score,
            negotiation_score=self._author.get_negotiation_score(),
            previous_test_score=self._previous_test_score,
            previous_lint_score=self._previous_lint_score,
            author_confidence=self._author._confidence,
            author_threshold=self._author.thresholds.get(self._author.personality, 0.5),
            step=self._step_count,
            max_steps=self.max_steps,
            # Guard against accidental `max_steps=0` configs.
            progress_ratio=(self._step_count / self.max_steps) if self.max_steps > 0 else 1.0,
            tests_run=self._tests_run,
            linter_run=self._linter_run,
            docs_queried=self._docs_queried,
            last_action_type=self._last_action_type,
            action_history=self._action_history[-5:],
            done=self._done,
            bug_description=self._bug_description,
            comments_count=len(self._comments),
        )

    # ===================================================================
    def _get_action_type(self, action: AnyAction) -> str:
        """Extract the action type as a string."""
        if isinstance(action, RunTests):
            return "run_tests"
        elif isinstance(action, RunLinter):
            return "run_linter"
        elif isinstance(action, QueryDocs):
            return "query_docs"
        elif isinstance(action, Execute):
            return "execute"
        elif isinstance(action, Inspect):
            return "inspect"
        elif isinstance(action, WriteComment):
            return "comment"
        elif isinstance(action, AskQuestion):
            return "question"
        elif isinstance(action, ProposeFix):
            return "fix"
        elif isinstance(action, Done):
            return "done"
        elif isinstance(action, Skip):
            return "skip"
        else:
            return "unknown"

    # ===================================================================
    def _get_test_runner_bug_id(self) -> str:
        """
        Normalize RedTeam bug ids to the canonical ids understood by
        TestRunner; ids without a mapping pass through unchanged.
        """
        return self._BUG_ID_CANONICAL_MAP.get(self._current_bug_id, self._current_bug_id)
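
    # Mapping example: a RedTeam id "loop_skip" is normalized to
    # "off_by_one" before TestRunner sees it, while an already-canonical
    # id such as "null_check" falls through `.get()` unchanged.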

    # ===================================================================
    def step(self, action: AnyAction) -> Tuple[EnhancedObservation, Reward, bool, Dict[str, Any]]:
        """
        True RL step with:
        - complete Markov observations (no hidden state)
        - dense intermediate rewards
        - delta-based credit assignment (no double counting)
        - proper episode reward tracking
        """
        if self._done:
            raise RuntimeError("Episode already finished")

        # Store previous metrics for delta computation
        self._previous_test_score = self._current_test_score
        self._previous_lint_score = self._current_lint_score

        # Snapshot tool-usage flags BEFORE the action mutates them.
        # Rubrics use these to detect true "first use" behavior.
        prev_tests_run = self._tests_run
        prev_linter_run = self._linter_run
        prev_docs_queried = self._docs_queried

        base_reward = 0.0
        action_type = self._get_action_type(action)

        # Update action history
        self._action_history.append(action_type)
        self._last_action_type = action_type

        # ==============================================================
        # TOOL ACTIONS
        # ==============================================================
        if isinstance(action, Execute):
            success, stdout, stderr = execute_code(self._current_code)
            output = (stdout + stderr).strip() or "No output"
            self._test_results = f"[Execute] {'Success' if success else 'Failed'}\n{output[:300]}"
            base_reward = 0.001 if success else -0.05

        elif isinstance(action, Inspect):
            self._test_results = f"[Inspect]\n{self._current_code[:500]}"
            base_reward = 0.001

        elif isinstance(action, RunLinter):
            lint_output = ToolBox.run_linter(self._current_code)
            self._lint_results = lint_output[:500]
            self._test_results = f"[Linter]\n{self._lint_results}"
            self._current_lint_score = self._run_linter_score(self._current_code)
            self._linter_run = True
            base_reward = 0.002

        elif isinstance(action, RunTests):
            runner = TestRunner(self._get_test_runner_bug_id())
            score, output = runner.run_tests(self._current_code)
            self._current_test_score = score
            self._tests_run = True
            self._test_results = f"[Tests] Score: {score:.2f}\n{output[:300]}"
            base_reward = 0.002
            if score > 0.8:
                base_reward += 0.005

        elif isinstance(action, QueryDocs):
            # Normalize the query to avoid rewarding empty/noisy requests.
            query_topic = (action.query_topic or "").strip()
            doc = ToolBox.query_docs(query_topic if query_topic else "general bug fixing")
            self._doc_results = doc
            self._test_results = f"[Docs]\n{doc[:400]}"
            self._docs_queried = True
            base_reward = 0.001

        # ==============================================================
        # COMMUNICATION ACTIONS
        # ==============================================================
        elif isinstance(action, WriteComment):
            self._comments.append(f"Agent: {action.comment_text}")
            response = self._author.respond(
                agent_comment=action.comment_text,
                test_results=self._test_results,
                lint_results=self._lint_results,
                doc_results=self._doc_results,
                proposed_fix=None,
                original_code=self._current_code,
            )
            self._comments.append(f"Author: {response}")
            self._last_author_response = response
            self._test_results = f"[Comment] Author: {response[:200]}"
            base_reward = 0.001

        elif isinstance(action, AskQuestion):
            self._comments.append(f"Agent: {action.question}")
            response = self._author.respond(
                agent_question=action.question,
                test_results=self._test_results,
                lint_results=self._lint_results,
                doc_results=self._doc_results,
                proposed_fix=None,
                original_code=self._current_code,
            )
            self._comments.append(f"Author: {response}")
            self._last_author_response = response
            self._test_results = f"[Question] Author: {response[:200]}"
            base_reward = 0.002

        # ==============================================================
        # FINAL FIX ACTION
        # ==============================================================
        elif isinstance(action, ProposeFix):
            if not action.fix_code:
                base_reward = -0.05
                self._done = True
            else:
                # Save the original code BEFORE overwriting (for author.respond)
                original_buggy = self._current_code
                self._current_code = action.fix_code

                runner = TestRunner(self._get_test_runner_bug_id())
                test_score, test_output = runner.run_tests(self._current_code)
                lint_score = self._run_linter_score(self._current_code)

                self._current_test_score = test_score
                self._current_lint_score = lint_score
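
                # Gating example (hypothetical persona values): with a
                # per-persona threshold of 0.7 and current confidence 0.55,
                # the author is unconvinced and the episode continues while
                # step budget remains; confidence >= threshold ends it.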
                # Author gating – determines whether the episode ends; the
                # reward is computed separately by the rubrics.
                threshold = self._author.thresholds.get(self._author.personality, 0.5)
                if self._author._confidence < threshold:
                    # Unconvinced author: keep negotiating while budget remains.
                    self._done = self._step_count >= self.max_steps
                else:
                    self._done = True

                # Get the author's verbal feedback (pushback/acceptance)
                author_feedback = self._author.respond(
                    agent_comment=f"Proposed fix:\n{action.fix_code}",
                    test_results=f"Score: {test_score:.2f}",
                    lint_results=f"Score: {lint_score:.2f}",
                    doc_results=self._doc_results,
                    proposed_fix=action.fix_code,
                    original_code=original_buggy,  # the buggy code, not the fix
                )
                self._test_results = f"[Fix] Author: {author_feedback[:200]}"
                self._comments.append(f"Author: {author_feedback}")
                self._last_author_response = author_feedback

                base_reward = 0.001  # rubrics provide the real signal

        # ==============================================================
        # TERMINATION ACTIONS
        # ==============================================================
        elif isinstance(action, Skip):
            base_reward = -0.03
            self._done = True

        elif isinstance(action, Done):
            if self._tests_run:
                base_reward = self._current_test_score * 0.5 - 0.2
            else:
                base_reward = -0.04
            self._done = True

        else:
            base_reward = -0.02
            self._done = True

        # ==============================================================
        # STEP UPDATE (before rubric computation so info contains the final step)
        # ==============================================================
        self._step_count += 1
        if self._step_count >= self.max_steps:
            self._done = True

        # Get a fresh observation (needed for rubrics that may read obs)
        obs = self._get_observation()

        # Prepare the info dict (rubrics may need action_type and deltas)
        info = {
            "action_type": action_type,
            "test_score": self._current_test_score,
            "lint_score": self._current_lint_score,
            "test_delta": self._current_test_score - self._previous_test_score,
            "lint_delta": self._current_lint_score - self._previous_lint_score,
            "prev_tests_run": prev_tests_run,
            "prev_linter_run": prev_linter_run,
            "prev_docs_queried": prev_docs_queried,
            "docs_query_len": len((action.query_topic or "").strip()) if isinstance(action, QueryDocs) else 0,
            "base_reward": base_reward,
        }

        # ==============================================================
        # COMPUTE FINAL REWARD USING RUBRICS
        # ==============================================================
        rubric_score = sum(r(self, action, obs, None, self._done, info) for r in self.rubrics)
        final_reward = 0.4 * base_reward + rubric_score
        final_reward = max(-1.0, min(1.0, final_reward))  # safety clip

        # Track cumulative episode reward
        self._episode_total_reward += final_reward

        # Store the episode total if done
        if self._done:
            self._episode_rewards.append(self._episode_total_reward)

        # Complete the info dict
        info["final_reward"] = final_reward
        info["episode_total"] = self._episode_total_reward

        return obs, Reward(value=final_reward), self._done, info

    # ===================================================================
    def _run_linter_score(self, code: str) -> float:
        """Run pylint and return a normalized score in [0, 1]."""
        tmp_path = None  # so `finally` is safe even if the write fails
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
                f.write(code)
                tmp_path = f.name
            result = subprocess.run(
                [sys.executable, '-m', 'pylint', tmp_path, '--score=y', '--exit-zero'],
                capture_output=True, text=True, timeout=5,
            )
            # Pylint scores can be negative; clamp into [0, 1].
            match = re.search(r"rated at (-?\d+\.\d+)/10", result.stdout)
            if match:
                return max(0.0, min(1.0, float(match.group(1)) / 10.0))
            return 0.0
        except Exception:
            return 0.0
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass
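
    # Parsing example: a pylint summary line such as
    #   "Your code has been rated at 7.50/10"
    # yields 0.75; unparseable output or a negative rating maps to 0.0.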

    # ===================================================================
    def state(self) -> State:
        """Legacy compatibility."""
        return State(
            pr_title="Code Review",
            pr_description=self._bug_description,
            code_snippet=self._current_code,
            comments=self._comments.copy(),
            test_results=self._test_results,
            step=self._step_count,
            done=self._done,
        )
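
# Minimal smoke-test sketch. Assumes the action classes in models.py can be
# constructed as shown (RunTests/RunLinter with no arguments,
# ProposeFix(fix_code=...)); adjust to the actual constructors if they differ.
if __name__ == "__main__":
    env = CodeReviewEnv(task="easy", max_steps=10)
    obs = env.reset()
    print(f"Bug under review: {obs.bug_description}")

    # Probe the code, then end the episode with the oracle fix.
    for action in [RunTests(), RunLinter(), ProposeFix(fix_code=env._oracle_fix)]:
        obs, reward, done, info = env.step(action)
        print(f"{info['action_type']:>10} reward={reward.value:+.3f} done={done}")
        if done:
            break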