import difflib
import re
import uuid
from typing import Optional

from models import PRAction, PRObservation, PRState, ReviewDecision
from server import graders
from server.tasks import custom, escalation, iterative, single_pass
|
|
# Registry of available review scenarios: task name -> task definition.
# PRReviewEnvironment.reset() looks the requested task up here by name.
TASKS = dict([
    ("single-pass-review", single_pass.TASK),
    ("iterative-negotiation", iterative.TASK),
    ("escalation-judgment", escalation.TASK),
    ("custom-review", custom.TASK),
])
|
|
class PRReviewEnvironment:
    """Multi-turn pull-request review environment.

    ``reset`` loads a task from ``TASKS`` and returns the first observation.
    Each ``step`` does the following:

    * scores one reviewer action via ``graders``;
    * records the action in the review history;
    * while the episode is still running, replays the next scripted author
      response, regenerating the displayed diff when that response contains
      a fenced code block.
    """

    # Final score at or above which a finished episode is marked successful.
    SUCCESS_THRESHOLD = 0.5

    # Unified-diff / git metadata lines that carry no source code.
    _DIFF_METADATA_PREFIXES = (
        "--- ",
        "+++ ",
        "@@ ",
        "index ",
        "diff --git ",
        "new file mode ",
        "deleted file mode ",
        "\\ ",  # "\ No newline at end of file"
    )

    # Fenced block explicitly tagged as Python (```python, ```python3 or ```py).
    _PYTHON_FENCE_RE = re.compile(r"```(?:python3?|py)[ \t]*\r?\n(.*?)\r?\n```", re.DOTALL)
    # Untagged fenced block (```).
    _BARE_FENCE_RE = re.compile(r"```[ \t]*\r?\n(.*?)\r?\n```", re.DOTALL)

    def __init__(self):
        self._state: Optional["PRState"] = None  # set by reset()
        self._task: Optional[dict] = None  # entry of TASKS for the current episode
        self._rewards: list = []  # per-step rewards of the current episode
        self._current_diff = ""  # diff currently shown to the reviewer
        self._initial_diff = ""  # diff the task started with

    def _extract_code(self, text: Optional[str]) -> Optional[str]:
        """Return the contents of the first fenced code block in ``text``.

        A block tagged as Python is preferred over an untagged one. Returns
        ``None`` when ``text`` is ``None`` or contains no such block. The
        returned code is stripped of surrounding whitespace.
        """
        if text is None:
            return None
        for pattern in (self._PYTHON_FENCE_RE, self._BARE_FENCE_RE):
            match = pattern.search(text)
            if match:
                return match.group(1).strip()
        return None

    def _generate_unified_diff(self, old_code: str, new_code: str, filename: str = "file.py") -> str:
        """Return a unified diff (``a/<filename>`` -> ``b/<filename>``) of two code versions.

        Every line is normalised to end with a newline before diffing. In this
        class both inputs come from ``.strip()``-ed text, so their last line has
        no terminator. Without the normalisation, ``difflib`` would glue the
        final removed and added lines into one malformed line.

        Returns ``""`` when the two versions are identical.
        """
        old_lines = [line + "\n" for line in old_code.splitlines()]
        new_lines = [line + "\n" for line in new_code.splitlines()]
        return "".join(
            difflib.unified_diff(
                old_lines, new_lines,
                fromfile=f"a/{filename}", tofile=f"b/{filename}",
            )
        )

    def _get_base_code(self, diff_text: str) -> str:
        """Return the post-change code described by ``diff_text``.

        Text with no ``--- ``/``+++ ``/``@@ `` marker is treated as a plain
        snippet and returned stripped. Otherwise:

        * diff/git metadata lines and removed (``-``) lines are dropped;
        * added (``+``) and context (`` ``) lines are kept without their
          one-character prefix;
        * lines with no recognised prefix are kept verbatim.
        """
        if not any(marker in diff_text for marker in ("--- ", "+++ ", "@@ ")):
            return diff_text.strip()

        kept = []
        for line in diff_text.splitlines():
            # Metadata is checked first so "--- "/"+++ " headers are not
            # mistaken for removed/added lines.
            if line.startswith(self._DIFF_METADATA_PREFIXES) or line.startswith("-"):
                continue
            if line.startswith(("+", " ")):
                kept.append(line[1:])
            else:
                kept.append(line)
        return "\n".join(kept).strip()

    def reset(self, task_name: str = "single-pass-review") -> "PRObservation":
        """Start a new episode on ``task_name`` and return the turn-0 observation.

        Raises:
            KeyError: if ``task_name`` is not a key of ``TASKS``.
        """
        try:
            task = TASKS[task_name]
        except KeyError:
            raise KeyError(
                f"Unknown task {task_name!r}; expected one of {sorted(TASKS)}"
            ) from None

        self._task = task
        self._rewards = []
        self._initial_diff = task["diff"]
        self._current_diff = task["diff"]

        self._state = PRState(
            episode_id=str(uuid.uuid4()),
            task_name=task_name,
            turn=0,
            max_turns=task["max_turns"],
            review_history=[],
            done=False,
            success=False,
            cumulative_reward=0.0,
        )
        return PRObservation(
            turn=0,
            diff=self._current_diff,
            pr_title=task["pr_title"],
            pr_description=task["pr_description"],
            review_history=[],
            author_response=None,
            done=False,
            message="New PR ready for review. Read the diff carefully. Identify the root cause of any issues, not just the symptom. Submit your decision.",
        )

    @staticmethod
    def _correct_decision_for_turn(gt: dict, turn: int) -> str:
        """Return the expected reviewer decision for ``turn``.

        A per-turn ``correct_decision_turn_<n>`` entry wins over the task-wide
        ``correct_decision``. ``"request_changes"`` is the fallback when
        neither is present.
        """
        return gt.get(f"correct_decision_turn_{turn}", gt.get("correct_decision", "request_changes"))

    def _next_author_response(self, turn: int) -> Optional[str]:
        """Replay the scripted author reply for ``turn`` (1-based), if there is one.

        The reply is appended to the review history. If it contains a fenced
        code block, the displayed diff is regenerated as the difference between
        the code reconstructed from the *initial* PR diff and the proposed fix.

        Returns ``None`` when the task's script has no reply for this turn.
        """
        author_responses = self._task.get("author_responses", [])
        if turn > len(author_responses):
            return None

        author_resp = author_responses[turn - 1]
        self._state.review_history.append({"role": "author", "content": author_resp})

        proposed_fix = self._extract_code(author_resp)
        if proposed_fix:
            base_code = self._get_base_code(self._initial_diff)
            self._current_diff = self._generate_unified_diff(base_code, proposed_fix)
        return author_resp

    def step(self, action: "PRAction") -> "tuple[PRObservation, float, bool, dict]":
        """Apply one reviewer action and advance the episode by one turn.

        The action is scored with ``graders.compute_step_reward`` against this
        turn's ground truth and appended to the review history. The episode
        ends on ``APPROVE``, on ``ESCALATE``, or once ``max_turns`` is reached.
        Otherwise the next scripted author response (if any) is replayed.

        Returns:
            ``(observation, reward, done, info)``, where ``info`` holds the
            ``episode_id`` and ``task`` name.

        Raises:
            AssertionError: if ``reset()`` has not been called or the episode
                is already done.
        """
        # NOTE(review): these guards disappear under ``python -O``. They are
        # kept as asserts so the exception type callers see stays unchanged.
        assert self._state is not None, "Call reset() first"
        assert not self._state.done, "Episode is already done"

        t = self._state
        task = self._task
        gt = task["ground_truth"]
        turn = t.turn + 1

        correct_decision = self._correct_decision_for_turn(gt, turn)
        # Any expected decision other than "approve" means the PR is still buggy.
        bug_still_present = correct_decision != ReviewDecision.APPROVE.value

        reward = graders.compute_step_reward(
            action=action,
            correct_decision=correct_decision,
            root_cause_keywords=gt.get("root_cause_keywords", []),
            correct_issue_category=gt.get("correct_issue_category", "logic"),
            # Flagged only when the reviewer approved a PR that is still buggy.
            bug_still_present=bug_still_present and action.decision == ReviewDecision.APPROVE,
            turn=turn,
            max_turns=task["max_turns"],
            symptom_only_keywords=gt.get("symptom_only_keywords"),
            false_fix_keywords=gt.get("false_fix_keywords"),
            # Required only if the task asks for escalation AND this turn's
            # expected decision is "escalate".
            escalation_required=gt.get("escalation_required", False)
            and correct_decision == ReviewDecision.ESCALATE.value,
        )
        self._rewards.append(reward)
        t.cumulative_reward = round(sum(self._rewards), 2)
        t.turn = turn
        t.review_history.append({"role": "reviewer", "content": f"[{action.decision.value}] {action.comment}"})

        done = (
            turn >= task["max_turns"]
            or action.decision == ReviewDecision.APPROVE
            or action.decision == ReviewDecision.ESCALATE
        )
        t.done = done

        author_resp = None if done else self._next_author_response(turn)

        if done:
            final_score = graders.compute_final_score(self._rewards, task["max_turns"])
            t.success = final_score >= self.SUCCESS_THRESHOLD
            message = f"Episode complete. Final score: {final_score:.3f}"
        else:
            message = "Author has responded. Re-read the diff. Has the actual root cause been addressed, or just the symptom?"

        observation = PRObservation(
            turn=turn,
            diff=self._current_diff,
            pr_title=task["pr_title"],
            pr_description=task["pr_description"],
            review_history=list(t.review_history),
            author_response=author_resp,
            done=done,
            message=message,
        )
        return observation, reward, done, {"episode_id": t.episode_id, "task": t.task_name}

    def state(self) -> Optional["PRState"]:
        """Return the current episode state (``None`` before ``reset()``)."""
        return self._state

    def get_rewards(self) -> list:
        """Return a copy of the per-step rewards recorded so far this episode."""
        return list(self._rewards)

    def get_final_score(self):
        """Return ``graders.compute_final_score`` over this episode's rewards.

        Raises:
            AssertionError: if ``reset()`` has not been called.
        """
        assert self._task is not None, "Call reset() first"
        return graders.compute_final_score(self._rewards, self._task["max_turns"])
|
|