Spaces:
No application file
No application file
| # environment.py – Final integrated environment (multi-turn, gated, continuous scoring) | |
| import sys | |
| import subprocess | |
| import tempfile | |
| import os | |
| import re | |
| from dataclasses import dataclass, field | |
| from typing import Tuple, Dict, Any, Optional | |
| # Project-local modules – these files must be in the same directory as this one | |
| from models import ( | |
| AnyAction, WriteComment, ProposeFix, Execute, Inspect, | |
| RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion, | |
| Observation, Reward, State | |
| ) | |
| from grader import RigorousGrader | |
| from redteam import RedTeam | |
| from test_runner import TestRunner | |
| from author import PersonaAuthor | |
| from rltool import ToolBox | |
| # ---------------------------------------------------------------------- | |
| # Helper: execute arbitrary Python code | |
| # ---------------------------------------------------------------------- | |
def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
    """Run *code* as a standalone script in a fresh Python subprocess.

    The code is written to a temporary ``.py`` file, executed with the current
    interpreter (``sys.executable``) and the temporary file is removed again
    afterwards, whatever the outcome.

    Args:
        code: Python source to execute.
        timeout_sec: Maximum time, in seconds, the subprocess may run.

    Returns:
        A ``(success, stdout, stderr)`` tuple. ``success`` is True only when
        the subprocess exits with return code 0. Blank input, a timeout, or
        any failure to write or launch the script yields ``success=False``,
        an empty ``stdout`` and a description of the problem in ``stderr``.
    """
    if not code.strip():
        return False, "", "Error: Empty code"

    tmp_path = None
    try:
        # Writing happens inside the try block so that an un-encodable source
        # string (e.g. one containing a lone surrogate) is reported through
        # the normal error tuple instead of escaping and leaking the file.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as script:
            tmp_path = script.name
            script.write(code)

        result = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            timeout=timeout_sec,
        )
        return result.returncode == 0, result.stdout, result.stderr
    except subprocess.TimeoutExpired:
        return False, "", f"Timeout after {timeout_sec}s"
    except Exception as exc:
        # Deliberate catch-all: callers rely on always getting a tuple back.
        return False, "", f"Execution error: {exc}"
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask the result.
                pass
| # ---------------------------------------------------------------------- | |
| # Main Environment | |
| # ---------------------------------------------------------------------- | |
@dataclass(eq=False)
class CodeReviewEnv:
    """Multi-turn code-review environment with tools, an author persona and gated scoring.

    Each episode starts from a small task-specific snippet into which the
    ``RedTeam`` injects a bug. The agent then spends steps on tool actions
    (execute, inspect, lint, test, docs), on talking to the simulated author
    (comments, questions), and finally on proposing a fix, which is scored
    from test, lint and negotiation signals.

    ``eq=False`` keeps identity-based equality and hashing: the environment is
    a mutable, stateful object, not a value.

    Attributes:
        task: Difficulty name; one of the keys of ``_BASE_SNIPPETS``.
        max_steps: Step budget; the episode is forced to end once reached.
        step_penalty: Cost charged for every tool/communication step, and per
            elapsed step when a fix is scored.
    """

    task: str = "easy"
    max_steps: int = 10
    step_penalty: float = 0.02
    _red_team: Optional[RedTeam] = field(init=False, default=None)
    _author: Optional[PersonaAuthor] = field(init=False, default=None)
    _current_code: str = field(init=False, default="")
    _current_bug_id: str = field(init=False, default="")
    _bug_description: str = field(init=False, default="")
    _oracle_fix: str = field(init=False, default="")
    _comments: list = field(init=False, default_factory=list)
    _test_results: Optional[str] = field(init=False, default=None)
    _lint_results: Optional[str] = field(init=False, default=None)
    _doc_results: Optional[str] = field(init=False, default=None)
    _step_count: int = field(init=False, default=0)
    _done: bool = field(init=False, default=False)

    # Clean starting snippet per task. Left un-annotated on purpose so that
    # @dataclass does not turn it into an instance field.
    # The snippets use 4-space indentation so that they parse as Python; with
    # the previous single-space form they raised IndentationError.
    # NOTE(review): confirm RedTeam.inject_bug does not pattern-match on the
    # old whitespace.
    _BASE_SNIPPETS = {
        "easy": (
            "def get_user(id):\n"
            "    if id in users:\n"
            "        return users[id]"
        ),
        "medium": (
            "def process_items(items):\n"
            "    for item in items:\n"
            "        print(item)"
        ),
        "hard": (
            "def average(data):\n"
            "    if not data:\n"
            "        return 0\n"
            "    return sum(data) / len(data)"
        ),
        "harder": (
            "counter = 0\n"
            "def increment():\n"
            "    global counter\n"
            "    with lock:\n"
            "        counter += 1"
        ),
        "hardest": (
            "def safe_work():\n"
            "    with lock1:\n"
            "        with lock2:\n"
            "            do_work()"
        ),
    }

    def __post_init__(self):
        """Build the collaborators and start the first episode for ``self.task``."""
        self.set_task(self.task)

    def set_task(self, task: str):
        """Switch to *task*, recreate the red team and author, and reset the episode.

        Raises:
            ValueError: If *task* is not one of the supported difficulty names.
        """
        if task not in self._BASE_SNIPPETS:
            raise ValueError(f"Unknown task: {task}")
        self.task = task
        self._red_team = RedTeam(task)
        # Per the original note this uses the default "defensive" personality
        # -- TODO confirm against author.py.
        self._author = PersonaAuthor()
        self._reset_internal()

    def _reset_internal(self):
        """Clear per-episode state and inject a fresh bug into the task snippet."""
        self._step_count = 0
        self._comments = []
        self._test_results = None
        self._lint_results = None
        self._doc_results = None
        self._done = False
        self._author.reset()

        # Any task name outside the table falls back to the "hardest" snippet,
        # as the original if/elif chain did.
        original = self._BASE_SNIPPETS.get(self.task, self._BASE_SNIPPETS["hardest"])

        buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
        self._current_code = buggy_code
        self._current_bug_id = bug_id
        self._bug_description = desc
        self._oracle_fix = oracle
        self._comments.append(f"[RedTeam] {desc}")

    def reset(self) -> Observation:
        """Start a new episode for the current task and return its first observation."""
        self._reset_internal()
        return self._get_observation()

    def _get_observation(self) -> Observation:
        """Build the agent-facing observation from the current state.

        ``last_tool_output`` carries whatever was last stored in
        ``_test_results`` (empty string when nothing has been stored yet).
        """
        return Observation(
            code_snippet=self._current_code,
            last_tool_output=self._test_results or "",
            step=self._step_count,
            done=self._done,
        )

    def step(self, action: AnyAction) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply one agent action and advance the episode by one step.

        Rewards:
            * Tool actions (``Execute``, ``Inspect``, ``RunLinter``,
              ``RunTests``, ``QueryDocs``) and communication actions
              (``WriteComment``, ``AskQuestion``) cost ``step_penalty``.
            * ``ProposeFix`` with empty code: -0.5 and the episode ends.
              Otherwise the fix is scored by ``_score_fix`` (value in [0, 1]).
            * ``Skip`` and unrecognised actions: -0.2; ``Done``: -0.5. All
              three end the episode.

        The episode also ends once ``max_steps`` steps have been taken.

        Returns:
            ``(observation, reward, done, info)``; ``info`` is currently
            always an empty dict.

        Raises:
            RuntimeError: If called after the episode has finished.
        """
        if self._done:
            raise RuntimeError("Episode already finished")

        info: Dict[str, Any] = {}
        flat_cost = -self.step_penalty

        # ---------------- tool actions ----------------
        if isinstance(action, Execute):
            _, stdout, stderr = execute_code(self._current_code)
            self._test_results = (stdout + stderr).strip() or "No output"
            reward_val = flat_cost
        elif isinstance(action, Inspect):
            self._test_results = self._current_code
            reward_val = flat_cost
        elif isinstance(action, RunLinter):
            lint_output = ToolBox.run_linter(self._current_code)
            self._lint_results = lint_output[:500]
            self._test_results = self._lint_results
            reward_val = flat_cost
        elif isinstance(action, RunTests):
            runner = TestRunner(self._current_bug_id)
            score, output = runner.run_tests(self._current_code)
            self._test_results = f"Test score: {score:.2f}\n{output[:500]}"
            reward_val = flat_cost
        elif isinstance(action, QueryDocs):
            doc = ToolBox.query_docs(action.query_topic)
            self._doc_results = doc
            self._test_results = doc
            reward_val = flat_cost

        # ---------------- communication (multi-turn) ----------------
        elif isinstance(action, WriteComment):
            self._comments.append(f"Agent: {action.comment_text}")
            self._consult_author(agent_comment=action.comment_text)
            reward_val = flat_cost
        elif isinstance(action, AskQuestion):
            self._comments.append(f"Agent: {action.question}")
            self._consult_author(agent_question=action.question)
            reward_val = flat_cost

        # ---------------- final fix ----------------
        elif isinstance(action, ProposeFix):
            if not action.fix_code:
                reward_val = -0.5
                self._done = True
            else:
                reward_val = self._score_fix(action.fix_code)

        # ---------------- termination ----------------
        elif isinstance(action, Skip):
            reward_val = -0.2
            self._done = True
        elif isinstance(action, Done):
            reward_val = -0.5
            self._done = True
        else:
            # Unknown action types are treated like a skip.
            reward_val = -0.2
            self._done = True

        # ---------------- step bookkeeping ----------------
        self._step_count += 1
        if self._step_count >= self.max_steps:
            self._done = True

        return self._get_observation(), Reward(value=reward_val), self._done, info

    def _consult_author(self, **message: str) -> None:
        """Send the agent's message plus current tool results to the author; log the reply."""
        response = self._author.respond(
            test_results=self._test_results,
            lint_results=self._lint_results,
            doc_results=self._doc_results,
            proposed_fix=None,
            original_code=self._current_code,
            **message,
        )
        self._comments.append(f"Author: {response}")

    def _score_fix(self, fix_code: str) -> float:
        """Adopt *fix_code* as the current code and return its reward in [0, 1].

        The base reward is ``0.6*tests + 0.2*lint + 0.2*negotiation`` minus
        ``step_penalty`` per step already taken, then scaled down when the
        signals disagree. If the author's confidence is below the threshold
        for their personality, a further 0.3 is deducted and the episode is
        left open; otherwise the episode ends.
        """
        self._current_code = fix_code
        runner = TestRunner(self._current_bug_id)
        test_score, test_output = runner.run_tests(self._current_code)
        lint_score = self._run_linter_score(self._current_code)
        negotiation_score = self._author.get_negotiation_score()
        step_cost = self.step_penalty * self._step_count

        reward = (
            0.6 * test_score
            + 0.2 * lint_score
            + 0.2 * negotiation_score
            - step_cost
        )

        # Cross-signal penalties: discount rewards whose signals contradict each other.
        if test_score > 0.8 and lint_score < 0.3:
            reward *= 0.8
        if test_score < 0.3 and lint_score > 0.8:
            reward *= 0.7
        if test_score > 0.8 and negotiation_score < 0.3:
            reward *= 0.75

        # Author gating (only if not already convinced).
        threshold = self._author.thresholds.get(self._author.personality, 0.5)
        if self._author._confidence < threshold:
            reward = max(0.0, reward - 0.3)
            # The episode stays open so the agent can keep negotiating; the
            # step-budget check at the end of step() still terminates it.
        else:
            self._done = True

        self._test_results = f"Test score: {test_score:.2f}\n{test_output[:300]}"
        return max(0.0, min(1.0, reward))

    def _run_linter_score(self, code: str) -> float:
        """Return pylint's rating of *code* scaled to [0, 1], or 0.0 on any failure.

        The code is written to a temporary file, ``pylint`` is run on it with
        a 5 second timeout, and the "rated at X/10" figure is parsed from its
        output. A missing pylint executable, a timeout, an unwritable file or
        an unparsable report all result in 0.0.
        """
        tmp_path = None
        try:
            # UTF-8 is Python's default source encoding, so write the file that way.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".py", delete=False, encoding="utf-8"
            ) as handle:
                tmp_path = handle.name
                handle.write(code)

            result = subprocess.run(
                ["pylint", tmp_path, "--score=y", "--exit-zero"],
                capture_output=True,
                text=True,
                timeout=5,
            )
        except (OSError, subprocess.SubprocessError, ValueError, TypeError):
            # Linting is a best-effort signal; failures simply earn no lint credit.
            return 0.0
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

        match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
        return float(match.group(1)) / 10.0 if match else 0.0

    def state(self) -> State:
        """Return a snapshot of the full episode state, including a copy of the comment log."""
        return State(
            pr_title="Code Review",
            pr_description=self._bug_description,
            code_snippet=self._current_code,
            comments=self._comments.copy(),
            test_results=self._test_results,
            step=self._step_count,
            done=self._done,
        )