code-review-professional / environment.py
100XZX001's picture
Upload 13 files
29c6586 verified
# environment.py – Final integrated environment (multi-turn, gated, continuous scoring)
import sys
import subprocess
import tempfile
import os
import re
from dataclasses import dataclass, field
from typing import Tuple, Dict, Any, Optional
# Project-local modules – ensure these files are in the same directory
from models import (
AnyAction, WriteComment, ProposeFix, Execute, Inspect,
RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
Observation, Reward, State
)
from grader import RigorousGrader
from redteam import RedTeam
from test_runner import TestRunner
from author import PersonaAuthor
from rltool import ToolBox
# ----------------------------------------------------------------------
# Helper: execute arbitrary Python code
# ----------------------------------------------------------------------
def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
    """Run *code* as a standalone script in a fresh Python subprocess.

    The code is written to a temporary ``.py`` file (UTF-8) and executed with
    the interpreter that is running this module (``sys.executable``).

    NOTE: the snippet runs with the privileges of the current process; the
    only guard visible here is the timeout.

    Args:
        code: Python source to execute.
        timeout_sec: Seconds to wait before the run is reported as timed out.

    Returns:
        ``(success, stdout, stderr)``.  ``success`` is True only when the
        child exits with return code 0.  Every failure (empty input, timeout,
        inability to write or launch the script) is reported through the
        third element with ``success`` False and an empty ``stdout``.
    """
    if not code.strip():
        return False, "", "Error: Empty code"

    tmp_path = None
    try:
        # Record the path before writing so the file is still removed when the
        # write itself fails (e.g. text that cannot be encoded as UTF-8).
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.py', delete=False, encoding='utf-8'
        ) as f:
            tmp_path = f.name
            f.write(code)
        # The file is closed at this point, so the child can open it on every
        # platform.
        result = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            # Undecodable bytes in the child's output must not discard the
            # whole result.
            errors='replace',
            timeout=timeout_sec,
        )
        return result.returncode == 0, result.stdout, result.stderr
    except subprocess.TimeoutExpired:
        return False, "", f"Timeout after {timeout_sec}s"
    except Exception as e:
        # Best-effort boundary: callers expect a tuple, never an exception.
        return False, "", f"Execution error: {e}"
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Cleanup is best-effort; the file may already be gone.
                pass
# ----------------------------------------------------------------------
# Main Environment
# ----------------------------------------------------------------------
@dataclass
class CodeReviewEnv:
    """Multi-turn code-review environment with a gated, continuous final reward.

    Each episode starts from a task-specific base snippet that is handed to
    the red team for bug injection.  The agent can then run tools, talk to the
    simulated author, and finally propose a fix, which is scored from test,
    lint and negotiation signals.

    Attributes:
        task: Difficulty label; one of "easy", "medium", "hard", "harder",
            "hardest".
        max_steps: Step count at which the episode is forced to finish.
        step_penalty: Cost of each tool/communication step; multiplied by the
            step count it is also deducted from the reward of a proposed fix.
    """

    task: str = "easy"
    max_steps: int = 10
    step_penalty: float = 0.02
    _red_team: Optional[RedTeam] = field(init=False, default=None)
    _author: Optional[PersonaAuthor] = field(init=False, default=None)
    _current_code: str = field(init=False, default="")
    _current_bug_id: str = field(init=False, default="")
    _bug_description: str = field(init=False, default="")
    _oracle_fix: str = field(init=False, default="")
    _comments: list = field(init=False, default_factory=list)
    _test_results: Optional[str] = field(init=False, default=None)
    _lint_results: Optional[str] = field(init=False, default=None)
    _doc_results: Optional[str] = field(init=False, default=None)
    _step_count: int = field(init=False, default=0)
    _done: bool = field(init=False, default=False)

    # Deliberately unannotated so the dataclass machinery does not turn these
    # lookup tables into instance fields.
    _VALID_TASKS = ("easy", "medium", "hard", "harder", "hardest")
    # NOTE(review): the indentation inside these snippets looks collapsed to a
    # single space; they are kept exactly as found – confirm against upstream.
    _BASE_SNIPPETS = {
        "easy": "def get_user(id):\n if id in users:\n return users[id]",
        "medium": "def process_items(items):\n for item in items:\n print(item)",
        "hard": "def average(data):\n if not data:\n return 0\n return sum(data) / len(data)",
        "harder": "counter = 0\ndef increment():\n global counter\n with lock:\n counter += 1",
        "hardest": "def safe_work():\n with lock1:\n with lock2:\n do_work()",
    }

    # ------------------------------------------------------------------
    def __post_init__(self):
        """Validate the configured task and prepare the first episode."""
        self.set_task(self.task)

    # ------------------------------------------------------------------
    def set_task(self, task: str):
        """Switch to *task*, rebuilding the red team and author, then reset.

        Raises:
            ValueError: If *task* is not one of the supported labels.
        """
        if task not in self._VALID_TASKS:
            raise ValueError(f"Unknown task: {task}")
        self.task = task
        self._red_team = RedTeam(task)
        self._author = PersonaAuthor()  # uses default personality "defensive"
        self._reset_internal()

    # ------------------------------------------------------------------
    def _reset_internal(self):
        """Clear per-episode state and inject a fresh bug into the base snippet."""
        self._step_count = 0
        self._comments = []
        self._test_results = None
        self._lint_results = None
        self._doc_results = None
        self._done = False
        self._author.reset()

        # --- Base tasks ---
        # Anything that is not one of the first four labels uses the
        # "hardest" snippet.
        original = self._BASE_SNIPPETS.get(self.task, self._BASE_SNIPPETS["hardest"])

        # --- Inject bug ---
        buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
        self._current_code = buggy_code
        self._current_bug_id = bug_id
        self._bug_description = desc
        self._oracle_fix = oracle
        self._comments.append(f"[RedTeam] {desc}")

    # ------------------------------------------------------------------
    def reset(self) -> Observation:
        """Start a new episode for the current task and return its first observation."""
        self._reset_internal()
        return self._get_observation()

    # ------------------------------------------------------------------
    def _get_observation(self) -> Observation:
        """Build the agent-facing observation from the current state."""
        # Observation as defined in models.py (no conversation_history)
        return Observation(
            code_snippet=self._current_code,
            last_tool_output=self._test_results or "",
            step=self._step_count,
            done=self._done,
        )

    # ------------------------------------------------------------------
    def step(self, action: AnyAction) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Apply one agent action and advance the episode by one step.

        Rewards:
            * tool and communication actions: ``-step_penalty``
            * ``ProposeFix``: ``-0.5`` for an empty fix, otherwise a score in
              ``[0, 1]`` (see ``_score_proposed_fix``)
            * ``Skip`` and unrecognised actions: ``-0.2``; ``Done``: ``-0.5``

        Returns:
            ``(observation, reward, done, info)``; ``info`` is currently
            always an empty dict.

        Raises:
            RuntimeError: If the episode has already finished.
        """
        if self._done:
            raise RuntimeError("Episode already finished")

        info: Dict[str, Any] = {}

        # Branch order mirrors the action precedence Execute, Inspect,
        # RunLinter, RunTests, QueryDocs, WriteComment, AskQuestion,
        # ProposeFix, Skip, Done.
        if isinstance(action, (Execute, Inspect, RunLinter, RunTests, QueryDocs)):
            self._run_tool(action)
            reward_val = -self.step_penalty
        elif isinstance(action, WriteComment):
            self._consult_author(action.comment_text, as_question=False)
            reward_val = -self.step_penalty
        elif isinstance(action, AskQuestion):
            self._consult_author(action.question, as_question=True)
            reward_val = -self.step_penalty
        elif isinstance(action, ProposeFix):
            reward_val = self._score_proposed_fix(action.fix_code)
        elif isinstance(action, Skip):
            reward_val = -0.2
            self._done = True
        elif isinstance(action, Done):
            reward_val = -0.5
            self._done = True
        else:
            # Unrecognised action: same penalty as Skip, episode ends.
            reward_val = -0.2
            self._done = True

        # ================================================================
        # STEP UPDATE
        # ================================================================
        self._step_count += 1
        if self._step_count >= self.max_steps:
            self._done = True
        obs = self._get_observation()
        return obs, Reward(value=reward_val), self._done, info

    # ------------------------------------------------------------------
    def _run_tool(self, action) -> None:
        """Run the tool selected by *action* and store its output in the state."""
        if isinstance(action, Execute):
            _, stdout, stderr = execute_code(self._current_code)
            self._test_results = (stdout + stderr).strip() or "No output"
        elif isinstance(action, Inspect):
            self._test_results = self._current_code
        elif isinstance(action, RunLinter):
            # Linter output is capped at 500 characters.
            self._lint_results = ToolBox.run_linter(self._current_code)[:500]
            self._test_results = self._lint_results
        elif isinstance(action, RunTests):
            runner = TestRunner(self._current_bug_id)
            score, output = runner.run_tests(self._current_code)
            self._test_results = f"Test score: {score:.2f}\n{output[:500]}"
        elif isinstance(action, QueryDocs):
            doc = ToolBox.query_docs(action.query_topic)
            self._doc_results = doc
            self._test_results = doc

    # ------------------------------------------------------------------
    def _consult_author(self, text: str, *, as_question: bool) -> None:
        """Log the agent's message, get the author's reply, and log that too."""
        self._comments.append(f"Agent: {text}")
        # Comments and questions reach the author under different keywords.
        message_key = "agent_question" if as_question else "agent_comment"
        response = self._author.respond(
            **{message_key: text},
            test_results=self._test_results,
            lint_results=self._lint_results,
            doc_results=self._doc_results,
            proposed_fix=None,
            original_code=self._current_code,
        )
        self._comments.append(f"Author: {response}")

    # ------------------------------------------------------------------
    def _score_proposed_fix(self, fix_code) -> float:
        """Score a proposed fix and decide whether the episode ends.

        An empty fix ends the episode with ``-0.5``.  Otherwise the fix
        replaces the current code and the reward is
        ``0.6*tests + 0.2*lint + 0.2*negotiation - step_penalty*steps``,
        scaled down when the signals disagree, reduced by ``0.3`` while the
        author is below their confidence threshold, and clamped to ``[0, 1]``.
        """
        if not fix_code:
            self._done = True
            return -0.5

        self._current_code = fix_code
        runner = TestRunner(self._current_bug_id)
        test_score, test_output = runner.run_tests(self._current_code)
        lint_score = self._run_linter_score(self._current_code)
        negotiation_score = self._author.get_negotiation_score()
        step_cost = self.step_penalty * self._step_count
        reward_val = (
            0.6 * test_score +
            0.2 * lint_score +
            0.2 * negotiation_score -
            step_cost
        )

        # -------------------------
        # Cross-signal penalties
        # -------------------------
        if test_score > 0.8 and lint_score < 0.3:
            reward_val *= 0.8
        if test_score < 0.3 and lint_score > 0.8:
            reward_val *= 0.7
        if test_score > 0.8 and negotiation_score < 0.3:
            reward_val *= 0.75

        # -------------------------
        # Author gating (only if not already convinced)
        # -------------------------
        threshold = self._author.thresholds.get(self._author.personality, 0.5)
        # NOTE(review): reads the author's private ``_confidence`` attribute.
        if self._author._confidence < threshold:
            reward_val = max(0.0, reward_val - 0.3)
            # Allow continuation if steps left
            self._done = self._step_count >= self.max_steps
        else:
            self._done = True

        self._test_results = f"Test score: {test_score:.2f}\n{test_output[:300]}"
        return max(0.0, min(1.0, reward_val))

    # ------------------------------------------------------------------
    def _run_linter_score(self, code: str) -> float:
        """Return pylint's rating of *code* scaled to ``[0, 1]``.

        Returns ``0.0`` when pylint cannot be run, times out (5 s), or prints
        no line matching ``rated at <x.y>/10``.
        """
        tmp_path = None
        try:
            # UTF-8 keeps this consistent with execute_code and avoids a silent
            # zero score for non-ASCII fixes on non-UTF-8 locales.  The path is
            # recorded before writing so a failed write is still cleaned up.
            with tempfile.NamedTemporaryFile(
                mode='w', suffix='.py', delete=False, encoding='utf-8'
            ) as f:
                tmp_path = f.name
                f.write(code)
            result = subprocess.run(
                ['pylint', tmp_path, '--score=y', '--exit-zero'],
                capture_output=True,
                text=True,
                timeout=5,
            )
            match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
            if match:
                return float(match.group(1)) / 10.0
            return 0.0
        except Exception:
            # Best-effort scoring: any failure to lint counts as a zero score,
            # but KeyboardInterrupt/SystemExit are allowed to propagate.
            return 0.0
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    # Cleanup is best-effort; the file may already be gone.
                    pass

    # ------------------------------------------------------------------
    def state(self) -> State:
        """Return a snapshot of the episode, including a copy of the comment log."""
        return State(
            pr_title="Code Review",
            pr_description=self._bug_description,
            code_snippet=self._current_code,
            comments=self._comments.copy(),
            test_results=self._test_results,
            step=self._step_count,
            done=self._done,
        )