# environment.py – Code-review RL environment (multi-turn, gated, continuous scoring).
#
# The project modules imported below (models, grader, redteam, test_runner,
# author, rltool) must be importable, e.g. by living in the same directory.
import os
import re
import subprocess
import sys
import tempfile
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

from models import (
    AnyAction, WriteComment, ProposeFix, Execute, Inspect, RunLinter,
    RunTests, QueryDocs, Skip, Done, AskQuestion,
    Observation, Reward, State,
)
from grader import RigorousGrader  # noqa: F401  (not referenced in this module)
from redteam import RedTeam
from test_runner import TestRunner
from author import PersonaAuthor
from rltool import ToolBox


# ----------------------------------------------------------------------
# Task definitions
# ----------------------------------------------------------------------
# Clean snippet for each difficulty level; RedTeam.inject_bug() turns it
# into the buggy code the agent reviews.
_BASE_TASK_CODE: Dict[str, str] = {
    "easy": "def get_user(id):\n    if id in users:\n        return users[id]",
    "medium": "def process_items(items):\n    for item in items:\n        print(item)",
    "hard": (
        "def average(data):\n    if not data:\n        return 0\n"
        "    return sum(data) / len(data)"
    ),
    "harder": (
        "counter = 0\ndef increment():\n    global counter\n"
        "    with lock:\n        counter += 1"
    ),
    "hardest": (
        "def safe_work():\n    with lock1:\n        with lock2:\n"
        "            do_work()"
    ),
}
_TASK_LEVELS: Tuple[str, ...] = tuple(_BASE_TASK_CODE)

# Matches pylint's summary line, e.g. "Your code has been rated at 7.50/10".
# Negative ratings ("-3.00/10") deliberately do not match.
_PYLINT_SCORE_RE = re.compile(r"rated at (\d+\.\d+)/10")


# ----------------------------------------------------------------------
# Temp-file helpers shared by code execution and linting
# ----------------------------------------------------------------------
def _remove_quietly(path: str) -> None:
    """Best-effort delete of a temp file; OS errors (missing/locked) are ignored."""
    try:
        os.unlink(path)
    except OSError:
        pass


def _write_temp_script(code: str) -> str:
    """Write *code* to a new UTF-8 ``.py`` temp file and return its path.

    The caller owns the file and must delete it. If writing fails, the
    partially written file is removed before the exception propagates.
    """
    handle = tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False, encoding="utf-8"
    )
    try:
        with handle:  # closes the file before any cleanup below
            handle.write(code)
    except BaseException:
        _remove_quietly(handle.name)
        raise
    return handle.name


# ----------------------------------------------------------------------
# Helper: execute arbitrary Python code
# ----------------------------------------------------------------------
def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
    """Run *code* as a standalone script in a fresh Python subprocess.

    NOTE: there is no sandboxing beyond the timeout; the script runs with
    this process's permissions.

    Args:
        code: Python source to execute.
        timeout_sec: Wall-clock limit passed to ``subprocess.run``.

    Returns:
        ``(success, stdout, stderr)`` where ``success`` is True iff the process
        exited with status 0. Empty/whitespace-only code, a timeout, or any
        failure to write or launch the script is reported through the third
        element with ``success=False`` instead of being raised.
    """
    if not code.strip():
        return False, "", "Error: Empty code"

    tmp_path: Optional[str] = None
    try:
        tmp_path = _write_temp_script(code)
        result = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            errors="replace",  # undecodable child output must not lose the result
            timeout=timeout_sec,
        )
    except subprocess.TimeoutExpired:
        return False, "", f"Timeout after {timeout_sec}s"
    except Exception as e:  # boundary: surface any write/launch failure as output
        return False, "", f"Execution error: {str(e)}"
    finally:
        if tmp_path is not None:
            _remove_quietly(tmp_path)

    return result.returncode == 0, result.stdout, result.stderr


# ----------------------------------------------------------------------
# Main Environment
# ----------------------------------------------------------------------
@dataclass
class CodeReviewEnv:
    """Multi-turn code-review environment.

    Each episode starts from a per-task snippet into which RedTeam injects a
    bug. The agent can call tools (Execute, Inspect, RunLinter, RunTests,
    QueryDocs), talk to a simulated author (WriteComment, AskQuestion), and
    submit a ProposeFix that is scored from test, lint and negotiation
    signals. Skip/Done end the episode with a penalty. ``step`` returns an
    ``(observation, reward, done, info)`` tuple.
    """

    task: str = "easy"
    max_steps: int = 10
    step_penalty: float = 0.02

    _red_team: Optional[RedTeam] = field(init=False, default=None)
    _author: Optional[PersonaAuthor] = field(init=False, default=None)
    _current_code: str = field(init=False, default="")
    _current_bug_id: str = field(init=False, default="")
    _bug_description: str = field(init=False, default="")
    _oracle_fix: str = field(init=False, default="")
    _comments: list = field(init=False, default_factory=list)
    _test_results: Optional[str] = field(init=False, default=None)
    _lint_results: Optional[str] = field(init=False, default=None)
    _doc_results: Optional[str] = field(init=False, default=None)
    _step_count: int = field(init=False, default=0)
    _done: bool = field(init=False, default=False)

    # ------------------------------------------------------------------
    def __post_init__(self) -> None:
        self.set_task(self.task)

    # ------------------------------------------------------------------
    def set_task(self, task: str) -> None:
        """Switch difficulty level, rebuild RedTeam/author and start a new episode.

        Raises:
            ValueError: if *task* is not one of the known difficulty levels.
        """
        if task not in _TASK_LEVELS:
            raise ValueError(f"Unknown task: {task}")
        self.task = task
        self._red_team = RedTeam(task)
        self._author = PersonaAuthor()  # uses default personality "defensive"
        self._reset_internal()

    # ------------------------------------------------------------------
    def _reset_internal(self) -> None:
        """Clear per-episode state and inject a fresh bug into the base snippet."""
        self._step_count = 0
        self._comments = []
        self._test_results = None
        self._lint_results = None
        self._doc_results = None
        self._done = False
        self._author.reset()

        # set_task() validates the level; anything else (e.g. a directly
        # assigned self.task) falls back to the "hardest" snippet.
        original = _BASE_TASK_CODE.get(self.task, _BASE_TASK_CODE["hardest"])

        buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
        self._current_code = buggy_code
        self._current_bug_id = bug_id
        self._bug_description = desc
        self._oracle_fix = oracle
        self._comments.append(f"[RedTeam] {desc}")

    # ------------------------------------------------------------------
    def reset(self) -> Observation:
        """Start a new episode on the current task and return the first observation."""
        self._reset_internal()
        return self._get_observation()

    # ------------------------------------------------------------------
    def _get_observation(self) -> Observation:
        """Build the agent-facing observation (Observation has no conversation field).

        NOTE(review): author replies are only recorded in ``_comments`` and are
        therefore visible via ``state()``, not via this observation.
        """
        return Observation(
            code_snippet=self._current_code,
            last_tool_output=self._test_results or "",
            step=self._step_count,
            done=self._done,
        )

    # ------------------------------------------------------------------
    def step(self, action: AnyAction) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply one agent action.

        Returns:
            ``(observation, reward, done, info)``. For a non-empty ProposeFix,
            ``info`` holds ``test_score``, ``lint_score`` and
            ``negotiation_score``; otherwise it is empty.

        Raises:
            RuntimeError: if the episode has already finished.
        """
        if self._done:
            raise RuntimeError("Episode already finished")

        info: Dict[str, Any] = {}

        # -- Tool actions: each costs one step penalty --------------------
        if isinstance(action, Execute):
            _, stdout, stderr = execute_code(self._current_code)
            self._test_results = (stdout + stderr).strip() or "No output"
            reward_val = -self.step_penalty
        elif isinstance(action, Inspect):
            self._test_results = self._current_code
            reward_val = -self.step_penalty
        elif isinstance(action, RunLinter):
            self._lint_results = ToolBox.run_linter(self._current_code)[:500]
            self._test_results = self._lint_results
            reward_val = -self.step_penalty
        elif isinstance(action, RunTests):
            score, output = TestRunner(self._current_bug_id).run_tests(self._current_code)
            self._test_results = f"Test score: {score:.2f}\n{output[:500]}"
            reward_val = -self.step_penalty
        elif isinstance(action, QueryDocs):
            doc = ToolBox.query_docs(action.query_topic)
            self._doc_results = doc
            self._test_results = doc
            reward_val = -self.step_penalty

        # -- Communication (multi-turn) ------------------------------------
        elif isinstance(action, WriteComment):
            self._converse(action.comment_text, "agent_comment")
            reward_val = -self.step_penalty
        elif isinstance(action, AskQuestion):
            self._converse(action.question, "agent_question")
            reward_val = -self.step_penalty

        # -- Final fix ------------------------------------------------------
        elif isinstance(action, ProposeFix):
            reward_val = self._apply_fix(action.fix_code, info)

        # -- Termination ----------------------------------------------------
        elif isinstance(action, Skip):
            reward_val = -0.2
            self._done = True
        elif isinstance(action, Done):
            reward_val = -0.5
            self._done = True
        else:  # unrecognised action type
            reward_val = -0.2
            self._done = True

        # -- Step update -----------------------------------------------------
        self._step_count += 1
        if self._step_count >= self.max_steps:
            self._done = True

        obs = self._get_observation()
        return obs, Reward(value=reward_val), self._done, info

    # ------------------------------------------------------------------
    def _converse(self, agent_text: str, message_kw: str) -> None:
        """Log an agent message, ask the author to respond, and log the reply.

        *message_kw* is the keyword under which the text is passed to
        ``PersonaAuthor.respond`` (``"agent_comment"`` or ``"agent_question"``).
        """
        self._comments.append(f"Agent: {agent_text}")
        response = self._author.respond(
            **{message_kw: agent_text},
            test_results=self._test_results,
            lint_results=self._lint_results,
            doc_results=self._doc_results,
            proposed_fix=None,
            original_code=self._current_code,
        )
        self._comments.append(f"Author: {response}")

    # ------------------------------------------------------------------
    def _apply_fix(self, fix_code: Optional[str], info: Dict[str, Any]) -> float:
        """Score a proposed fix, update episode state, and return the reward.

        An empty or whitespace-only fix ends the episode with -0.5. Otherwise
        the fix replaces the current code and is scored as
        ``0.6*tests + 0.2*lint + 0.2*negotiation - step_penalty*steps_so_far``,
        discounted when signals disagree, reduced by 0.3 while the author is
        below its confidence threshold (which keeps the episode open), and
        finally clamped to [0, 1]. Component scores are written into *info*.
        """
        if not fix_code or not fix_code.strip():
            self._done = True
            return -0.5

        self._current_code = fix_code
        test_score, test_output = TestRunner(self._current_bug_id).run_tests(fix_code)
        lint_score = self._run_linter_score(fix_code)
        negotiation_score = self._author.get_negotiation_score()
        step_cost = self.step_penalty * self._step_count

        reward = (
            0.6 * test_score
            + 0.2 * lint_score
            + 0.2 * negotiation_score
            - step_cost
        )

        # Cross-signal penalties: discount rewards whose signals contradict.
        if test_score > 0.8 and lint_score < 0.3:
            reward *= 0.8
        if test_score < 0.3 and lint_score > 0.8:
            reward *= 0.7
        if test_score > 0.8 and negotiation_score < 0.3:
            reward *= 0.75

        # Author gating: an unconvinced author costs 0.3 and keeps the episode
        # open while steps remain (step() still enforces max_steps).
        threshold = self._author.thresholds.get(self._author.personality, 0.5)
        # NOTE(review): reads PersonaAuthor's private _confidence attribute.
        if self._author._confidence < threshold:
            reward = max(0.0, reward - 0.3)
            self._done = self._step_count >= self.max_steps
        else:
            self._done = True

        self._test_results = f"Test score: {test_score:.2f}\n{test_output[:300]}"
        info.update(
            test_score=test_score,
            lint_score=lint_score,
            negotiation_score=negotiation_score,
        )
        return max(0.0, min(1.0, reward))

    # ------------------------------------------------------------------
    def _run_linter_score(self, code: str) -> float:
        """Return pylint's rating for *code* scaled from 0-10 to 0-1.

        Returns 0.0 when pylint is unavailable, times out, the code cannot be
        written to disk, or the output has no (non-negative) rating line.
        """
        tmp_path: Optional[str] = None
        try:
            tmp_path = _write_temp_script(code)
            result = subprocess.run(
                ["pylint", tmp_path, "--score=y", "--exit-zero"],
                capture_output=True,
                text=True,
                errors="replace",
                timeout=5,
            )
        except (OSError, ValueError, subprocess.SubprocessError):
            return 0.0
        finally:
            if tmp_path is not None:
                _remove_quietly(tmp_path)

        match = _PYLINT_SCORE_RE.search(result.stdout or "")
        return float(match.group(1)) / 10.0 if match else 0.0

    # ------------------------------------------------------------------
    def state(self) -> State:
        """Return a full snapshot of the episode, including the comment thread."""
        return State(
            pr_title="Code Review",
            pr_description=self._bug_description,
            code_snippet=self._current_code,
            comments=self._comments.copy(),
            test_results=self._test_results,
            step=self._step_count,
            done=self._done,
        )