|
|
|
|
| import sys
|
| import subprocess
|
| import tempfile
|
| import os
|
| import re
|
| from dataclasses import dataclass, field
|
| from typing import Tuple, Dict, Any, Optional, List
|
|
|
| from models import (
|
| AnyAction, WriteComment, ProposeFix, Execute, Inspect,
|
| RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
|
| Observation, Reward, State
|
| )
|
| from redteam import RedTeam
|
| from test_runner import TestRunner
|
| from author import PersonaAuthor
|
| from rltool import ToolBox
|
| from rubrics import (
|
| ToolUsageRubric,
|
| TestDeltaRubric,
|
| LintDeltaRubric,
|
| TerminalSuccessRubric,
|
| ExplorationRubric,
|
| AntiHackingRubric,
|
| StepPenaltyRubric,
|
| )
|
|
|
|
|
|
|
|
|
@dataclass
class EnhancedObservation:
    """Complete Markov observation handed to the agent each step.

    Bundles the code under review, the latest tool/author output, current
    and previous quality scores, author negotiation state, episode
    progress, tool-usage flags, and recent action history, so the agent
    needs no hidden state to choose its next action.
    """

    # Code under review and the most recent tool output shown to the agent.
    code_snippet: str
    last_tool_output: str

    # Current quality signals.
    current_test_score: float
    current_lint_score: float
    negotiation_score: float

    # Scores from the previous step (enables delta-based credit assignment).
    previous_test_score: float
    previous_lint_score: float

    # Simulated PR author's state (confidence vs. acceptance threshold).
    author_confidence: float
    author_threshold: float

    # Episode progress.
    step: int
    max_steps: int
    progress_ratio: float

    # Which diagnostic tools have been used so far this episode.
    tests_run: bool
    linter_run: bool
    docs_queried: bool

    # Recent behavior summary (history is truncated by the producer).
    last_action_type: str
    action_history: List[str]

    # Whether the episode has terminated.
    done: bool

    # Task context.
    bug_description: str
    comments_count: int

    # Author's reply to the last comment/question/fix ("" when the last
    # action was not a social one).
    author_response: str = ""
|
|
|
|
|
def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
    """Run *code* in a fresh Python subprocess and capture its output.

    Args:
        code: Python source text to execute.
        timeout_sec: Maximum wall-clock seconds before the run is aborted.

    Returns:
        ``(success, stdout, stderr)`` where ``success`` is True iff the
        subprocess exited with return code 0. Timeouts and launch failures
        are reported as ``(False, "", <message>)`` rather than raised.
    """
    if not code.strip():
        return False, "", "Error: Empty code"

    # delete=False so the subprocess can open the file after we close it
    # (required on Windows); we remove it ourselves in the finally block.
    with tempfile.NamedTemporaryFile(
        mode='w', suffix='.py', delete=False, encoding='utf-8'
    ) as f:
        f.write(code)
        tmp_path = f.name

    try:
        result = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            timeout=timeout_sec,
        )
        return result.returncode == 0, result.stdout, result.stderr
    except subprocess.TimeoutExpired:
        return False, "", f"Timeout after {timeout_sec}s"
    except Exception as e:  # e.g. OSError if the interpreter cannot start
        return False, "", f"Execution error: {str(e)}"
    finally:
        # Best-effort cleanup: narrow to OSError so we never swallow
        # KeyboardInterrupt/SystemExit (the old bare `except:` did).
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
|
|
|
|
|
|
|
|
|
@dataclass
class CodeReviewEnv:
    """RL environment for an interactive code-review task.

    A :class:`RedTeam` injects a bug into a small code sample; the agent
    may inspect/execute the code, run tests or the linter, query docs,
    talk to a simulated :class:`PersonaAuthor`, and finally propose a fix.
    :meth:`step` returns dense rewards composed from a small base reward
    per action plus a configurable rubric stack, with delta-based credit
    assignment between consecutive test/lint scores.
    """

    task: str = "easy"
    max_steps: int = 10
    step_penalty: float = 0.01
    reward_profile: str = "full"  # "full" (rich shaping) or "core" (minimal)

    # Curriculum: when enabled, reset() moves the difficulty level based on
    # the average of recent episode returns.
    auto_difficulty: bool = False
    success_threshold: float = 0.7

    # Reward-shaping knobs.
    delta_weight: float = 0.3
    tool_usage_bonus: float = 0.05
    diversity_bonus: float = 0.03

    # Collaborators (created in set_task / reset).
    _red_team: Optional[RedTeam] = field(init=False, default=None)
    _author: Optional[PersonaAuthor] = field(init=False, default=None)

    # Current task artifacts.
    _current_code: str = field(init=False, default="")
    _current_bug_id: str = field(init=False, default="")
    _bug_description: str = field(init=False, default="")
    _oracle_fix: str = field(init=False, default="")

    # Review transcript and cached tool outputs.
    _comments: list = field(init=False, default_factory=list)
    _test_results: Optional[str] = field(init=False, default=None)
    _lint_results: Optional[str] = field(init=False, default=None)
    _doc_results: Optional[str] = field(init=False, default=None)

    _step_count: int = field(init=False, default=0)
    _done: bool = field(init=False, default=False)

    # Score tracking for delta-based rewards.
    _previous_test_score: float = field(init=False, default=0.0)
    _previous_lint_score: float = field(init=False, default=0.0)
    _current_test_score: float = field(init=False, default=0.0)
    _current_lint_score: float = field(init=False, default=0.0)

    # Tool-usage flags for the current episode.
    _tests_run: bool = field(init=False, default=False)
    _linter_run: bool = field(init=False, default=False)
    _docs_queried: bool = field(init=False, default=False)

    # Behavior history.
    _action_history: List[str] = field(init=False, default_factory=list)
    _last_action_type: str = field(init=False, default="none")
    _last_author_response: str = field(init=False, default="")

    # Episode reward bookkeeping and curriculum level (0=easy .. 4=hardest).
    _episode_total_reward: float = field(init=False, default=0.0)
    _episode_rewards: List[float] = field(init=False, default_factory=list)
    _difficulty_level: int = field(init=False, default=0)

    # Maps RedTeam bug ids to the canonical ids TestRunner understands.
    # Deliberately left unannotated so the dataclass machinery treats it as
    # a plain class attribute, not a field.
    _BUG_ID_CANONICAL_MAP = {
        "simple_typo": "null_check",
        "default_value": "null_check",
        "empty_return": "null_check",
        "string_index": "off_by_one",
        "loop_skip": "off_by_one",
        "sign_error": "wrong_operator",
        "swap_args": "wrong_operator",
        "uninitialised_var": "null_check",
        "division_by_zero_empty": "division_by_zero",
        "division_by_zero_zero": "division_by_zero",
        "float_precision": "division_by_zero",
        "abs_usage": "division_by_zero",
        "round_error": "division_by_zero",
    }

    def __post_init__(self):
        # set_task builds the red team, author, rubrics, and the first episode.
        self.set_task(self.task)

    def _build_rubrics(self):
        """
        Build rubric stack from a named reward profile.
        - full: richer shaping for exploration/tool-use behavior
        - core: minimal stable signal for quick ablations/baselines

        Raises:
            ValueError: if ``reward_profile`` is not "full" or "core".
        """
        core_rubrics = [
            TestDeltaRubric(weight=self.delta_weight),
            LintDeltaRubric(weight=self.delta_weight),
            TerminalSuccessRubric(),
            StepPenaltyRubric(penalty=self.step_penalty),
        ]
        if self.reward_profile == "core":
            return core_rubrics
        if self.reward_profile == "full":
            # Keep the step penalty last so it applies after all bonuses.
            return [
                *core_rubrics[:-1],
                ToolUsageRubric(bonus=self.tool_usage_bonus),
                ExplorationRubric(penalty=-0.05, bonus=self.diversity_bonus * 0.7),
                AntiHackingRubric(),
                core_rubrics[-1],
            ]
        raise ValueError(f"Unknown reward_profile: {self.reward_profile}")

    def set_task(self, task: str):
        """Switch to *task* difficulty and rebuild collaborators, rubrics,
        and the current episode.

        Raises:
            ValueError: for an unrecognized task name.
        """
        if task not in ["easy", "medium", "hard", "harder", "hardest"]:
            raise ValueError(f"Unknown task: {task}")

        self.task = task

        self._red_team = RedTeam(task, seed=None)
        self._author = PersonaAuthor()
        self.rubrics = self._build_rubrics()

        task_to_level = {
            "easy": 0, "medium": 1, "hard": 2,
            "harder": 3, "hardest": 4
        }
        self._difficulty_level = task_to_level[task]

        self._reset_internal()

    def _reset_internal(self):
        """Reset per-episode state and inject a fresh bug for the task."""
        self._step_count = 0
        self._comments = []
        self._test_results = None
        self._lint_results = None
        self._doc_results = None
        self._done = False

        self._previous_test_score = 0.0
        self._previous_lint_score = 0.0
        self._current_test_score = 0.0
        self._current_lint_score = 0.0

        self._tests_run = False
        self._linter_run = False
        self._docs_queried = False

        self._action_history = []
        self._last_action_type = "none"
        self._last_author_response = ""

        self._episode_total_reward = 0.0

        self._author.reset()

        # Pristine code sample per difficulty; the red team then injects a bug.
        if self.task == "easy":
            original = "def get_user(id):\n    if id in users:\n        return users[id]"
        elif self.task == "medium":
            original = "def process_items(items):\n    for item in items:\n        print(item)"
        elif self.task == "hard":
            original = "def average(data):\n    if not data:\n        return 0\n    return sum(data) / len(data)"
        elif self.task == "harder":
            original = "counter = 0\ndef increment():\n    global counter\n    with lock:\n        counter += 1"
        else:
            original = "def safe_work():\n    with lock1:\n        with lock2:\n            do_work()"

        buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
        self._current_code = buggy_code
        self._current_bug_id = bug_id
        self._bug_description = desc
        self._oracle_fix = oracle
        self._comments.append(f"[RedTeam] {desc}")

    def reset(self) -> EnhancedObservation:
        """Start a new episode, with optional curriculum adjustment."""
        if self.auto_difficulty and len(self._episode_rewards) > 0:
            # Average of up to the last 5 episode returns.
            recent_performance = sum(self._episode_rewards[-5:]) / min(5, len(self._episode_rewards))

            if recent_performance > self.success_threshold and self._difficulty_level < 4:
                self._difficulty_level += 1
                print(f"[Curriculum] Increasing difficulty to level {self._difficulty_level}")
            elif recent_performance < 0.3 and self._difficulty_level > 0:
                self._difficulty_level -= 1
                print(f"[Curriculum] Decreasing difficulty to level {self._difficulty_level}")

            level_to_task = {0: "easy", 1: "medium", 2: "hard", 3: "harder", 4: "hardest"}
            self.task = level_to_task[self._difficulty_level]

            self._red_team = RedTeam(self.task, seed=None)

        self._reset_internal()
        return self._get_observation()

    def _get_observation(self) -> EnhancedObservation:
        """Return the COMPLETE Markov state as an EnhancedObservation."""
        # Author responses are only meaningful right after a social action.
        if self._last_action_type in ("comment", "question", "fix"):
            author_response = self._last_author_response
        else:
            author_response = ""

        return EnhancedObservation(
            code_snippet=self._current_code,
            last_tool_output=self._test_results or "",
            author_response=author_response,

            current_test_score=self._current_test_score,
            current_lint_score=self._current_lint_score,
            negotiation_score=self._author.get_negotiation_score(),

            previous_test_score=self._previous_test_score,
            previous_lint_score=self._previous_lint_score,

            # NOTE(review): reaches into the author's private _confidence —
            # consider exposing a public accessor on PersonaAuthor.
            author_confidence=self._author._confidence,
            author_threshold=self._author.thresholds.get(self._author.personality, 0.5),

            step=self._step_count,
            max_steps=self.max_steps,
            progress_ratio=(self._step_count / self.max_steps) if self.max_steps > 0 else 1.0,

            tests_run=self._tests_run,
            linter_run=self._linter_run,
            docs_queried=self._docs_queried,

            last_action_type=self._last_action_type,
            action_history=self._action_history[-5:],

            done=self._done,

            bug_description=self._bug_description,
            comments_count=len(self._comments),
        )

    def _get_action_type(self, action: AnyAction) -> str:
        """Extract the action type as a short string tag."""
        if isinstance(action, RunTests):
            return "run_tests"
        elif isinstance(action, RunLinter):
            return "run_linter"
        elif isinstance(action, QueryDocs):
            return "query_docs"
        elif isinstance(action, Execute):
            return "execute"
        elif isinstance(action, Inspect):
            return "inspect"
        elif isinstance(action, WriteComment):
            return "comment"
        elif isinstance(action, AskQuestion):
            return "question"
        elif isinstance(action, ProposeFix):
            return "fix"
        elif isinstance(action, Done):
            return "done"
        elif isinstance(action, Skip):
            return "skip"
        else:
            return "unknown"

    def _get_test_runner_bug_id(self) -> str:
        """
        Normalize RedTeam bug ids to the canonical ids understood by TestRunner.
        Falls back to the original id for known direct matches.
        """
        return self._BUG_ID_CANONICAL_MAP.get(self._current_bug_id, self._current_bug_id)

    def step(self, action: AnyAction) -> Tuple[EnhancedObservation, Reward, bool, Dict[str, Any]]:
        """
        TRUE RL STEP with:
        - Complete Markov observations (no hidden state)
        - Dense intermediate rewards
        - Delta-based credit assignment (no double-counting)
        - Proper episode reward tracking

        Returns:
            (observation, reward, done, info)

        Raises:
            RuntimeError: if called after the episode has finished.
        """
        if self._done:
            raise RuntimeError("Episode already finished")

        # Snapshot scores so rubrics can reward improvement (deltas).
        self._previous_test_score = self._current_test_score
        self._previous_lint_score = self._current_lint_score

        # Snapshot tool flags before this action mutates them.
        prev_tests_run = self._tests_run
        prev_linter_run = self._linter_run
        prev_docs_queried = self._docs_queried

        base_reward = 0.0
        action_type = self._get_action_type(action)

        self._action_history.append(action_type)
        self._last_action_type = action_type

        if isinstance(action, Execute):
            success, stdout, stderr = execute_code(self._current_code)
            output = (stdout + stderr).strip() or "No output"
            self._test_results = f"[Execute] {'Success' if success else 'Failed'}\n{output[:300]}"
            base_reward = 0.001 if success else -0.05

        elif isinstance(action, Inspect):
            self._test_results = f"[Inspect]\n{self._current_code[:500]}"
            base_reward = 0.001

        elif isinstance(action, RunLinter):
            lint_output = ToolBox.run_linter(self._current_code)
            self._lint_results = lint_output[:500]
            self._test_results = f"[Linter]\n{self._lint_results}"

            self._current_lint_score = self._run_linter_score(self._current_code)
            self._linter_run = True
            base_reward = 0.002

        elif isinstance(action, RunTests):
            runner = TestRunner(self._get_test_runner_bug_id())
            score, output = runner.run_tests(self._current_code)

            self._current_test_score = score
            self._tests_run = True

            self._test_results = f"[Tests] Score: {score:.2f}\n{output[:300]}"
            base_reward = 0.002

            # Small extra nudge for already-strong test performance.
            if score > 0.8:
                base_reward += 0.005

        elif isinstance(action, QueryDocs):
            query_topic = (action.query_topic or "").strip()
            doc = ToolBox.query_docs(query_topic if query_topic else "general bug fixing")
            self._doc_results = doc
            self._test_results = f"[Docs]\n{doc[:400]}"
            self._docs_queried = True
            base_reward = 0.001

        elif isinstance(action, WriteComment):
            self._comments.append(f"Agent: {action.comment_text}")

            response = self._author.respond(
                agent_comment=action.comment_text,
                test_results=self._test_results,
                lint_results=self._lint_results,
                doc_results=self._doc_results,
                proposed_fix=None,
                original_code=self._current_code
            )

            self._comments.append(f"Author: {response}")
            self._last_author_response = response
            self._test_results = f"[Comment] Author: {response[:200]}"
            base_reward = 0.001

        elif isinstance(action, AskQuestion):
            self._comments.append(f"Agent: {action.question}")

            response = self._author.respond(
                agent_question=action.question,
                test_results=self._test_results,
                lint_results=self._lint_results,
                doc_results=self._doc_results,
                proposed_fix=None,
                original_code=self._current_code
            )

            self._comments.append(f"Author: {response}")
            self._last_author_response = response
            self._test_results = f"[Question] Author: {response[:200]}"
            base_reward = 0.002

        elif isinstance(action, ProposeFix):
            if not action.fix_code:
                base_reward = -0.05
                self._done = True
            else:
                original_buggy = self._current_code
                self._current_code = action.fix_code

                runner = TestRunner(self._get_test_runner_bug_id())
                test_score, _ = runner.run_tests(self._current_code)
                lint_score = self._run_linter_score(self._current_code)

                self._current_test_score = test_score
                self._current_lint_score = lint_score

                # The author accepts the fix only at/above their personality
                # threshold; otherwise the episode continues if steps remain.
                threshold = self._author.thresholds.get(self._author.personality, 0.5)
                if self._author._confidence < threshold:
                    if self._step_count < self.max_steps:
                        self._done = False
                    else:
                        self._done = True
                else:
                    self._done = True

                author_feedback = self._author.respond(
                    agent_comment=f"Proposed fix:\n{action.fix_code}",
                    test_results=f"Score: {test_score:.2f}",
                    lint_results=f"Score: {lint_score:.2f}",
                    doc_results=self._doc_results,
                    proposed_fix=action.fix_code,
                    original_code=original_buggy
                )
                self._test_results = f"[Fix] Author: {author_feedback[:200]}"
                self._comments.append(f"Author: {author_feedback}")
                self._last_author_response = author_feedback

                base_reward = 0.001

        elif isinstance(action, Skip):
            base_reward = -0.03
            self._done = True

        elif isinstance(action, Done):
            # Terminal self-assessment: only rewarded if tests were run.
            if self._tests_run:
                base_reward = self._current_test_score * 0.5 - 0.2
            else:
                base_reward = -0.04
            self._done = True

        else:
            base_reward = -0.02
            self._done = True

        self._step_count += 1
        if self._step_count >= self.max_steps:
            self._done = True

        obs = self._get_observation()

        info = {
            "action_type": action_type,
            "test_score": self._current_test_score,
            "lint_score": self._current_lint_score,
            "test_delta": self._current_test_score - self._previous_test_score,
            "lint_delta": self._current_lint_score - self._previous_lint_score,
            "prev_tests_run": prev_tests_run,
            "prev_linter_run": prev_linter_run,
            "prev_docs_queried": prev_docs_queried,
            "docs_query_len": len((action.query_topic or "").strip()) if isinstance(action, QueryDocs) else 0,
            "base_reward": base_reward,
        }

        # Combine the base action reward with the rubric stack, then clip.
        rubric_score = sum(r(self, action, obs, None, self._done, info) for r in self.rubrics)
        final_reward = 0.4 * base_reward + rubric_score
        final_reward = max(-1.0, min(1.0, final_reward))

        self._episode_total_reward += final_reward

        if self._done:
            self._episode_rewards.append(self._episode_total_reward)

        info["final_reward"] = final_reward
        info["episode_total"] = self._episode_total_reward

        return obs, Reward(value=final_reward), self._done, info

    def _run_linter_score(self, code: str) -> float:
        """Run pylint on *code* and return a score normalized to [0, 1].

        Any failure (pylint missing, timeout, unparsable output) yields 0.0.
        """
        tmp_path: Optional[str] = None
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
                f.write(code)
                tmp_path = f.name

            result = subprocess.run(
                [sys.executable, '-m', 'pylint', tmp_path, '--score=y', '--exit-zero'],
                capture_output=True,
                text=True,
                timeout=5
            )

            # pylint can emit negative ratings (e.g. "rated at -1.50/10"),
            # which the old pattern r"(\d+\.\d+)/10" silently missed; accept
            # a sign and clamp so the result stays within [0, 1].
            match = re.search(r"rated at (-?\d+(?:\.\d+)?)/10", result.stdout)
            if match:
                return max(0.0, min(1.0, float(match.group(1)) / 10.0))
            return 0.0
        except Exception:
            # Treat any linter failure as a zero score rather than crashing
            # the environment (was a bare `except:` before).
            return 0.0
        finally:
            # tmp_path stays None if NamedTemporaryFile itself failed; the
            # old code would hit a NameError here, masked by a bare except.
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    def state(self) -> State:
        """Legacy compatibility: project the environment into a State."""
        return State(
            pr_title="Code Review",
            pr_description=self._bug_description,
            code_snippet=self._current_code,
            comments=self._comments.copy(),
            test_results=self._test_results,
            step=self._step_count,
            done=self._done
        )
|
|