File size: 6,781 Bytes
13507d6 8d8959a 13507d6 0c8a432 13507d6 0c8a432 13507d6 8d8959a 13507d6 8d8959a 13507d6 8d8959a 13507d6 7d08a88 13507d6 dcee3d3 13507d6 7d08a88 13507d6 7d08a88 dcee3d3 13507d6 7d08a88 13507d6 8d8959a 13507d6 7d08a88 13507d6 8d8959a 13507d6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 | import uuid
import difflib
import re
from models import PRAction, PRObservation, PRState, ReviewDecision
from server.tasks import single_pass, iterative, escalation, custom
from server import graders
# Registry of the review tasks, keyed by the public task name that callers
# pass to PRReviewEnvironment.reset(). Each value is the TASK object exported
# by the corresponding module under server.tasks.
TASKS = {
    task_name: task_module.TASK
    for task_name, task_module in (
        ("single-pass-review", single_pass),
        ("iterative-negotiation", iterative),
        ("escalation-judgment", escalation),
        ("custom-review", custom),
    )
}
class PRReviewEnvironment:
    """Drive a multi-turn pull-request review episode.

    ``reset()`` selects an entry from ``TASKS`` and returns the first
    observation. Each ``step()`` scores the reviewer's action through
    ``graders``, records it (and the scripted author reply, if any) in the
    review history, and returns the next observation.
    """

    # Fenced block explicitly tagged as Python. The tag is matched
    # case-insensitively and may be followed by trailing blanks or a CRLF.
    _PYTHON_FENCE = re.compile(
        r"```(?:python|py)[ \t]*\r?\n(.*?)\n```", re.DOTALL | re.IGNORECASE
    )
    # Untagged fenced block, used only when no Python-tagged block exists.
    _PLAIN_FENCE = re.compile(r"```[ \t]*\r?\n(.*?)\n```", re.DOTALL)

    # Substrings whose presence makes _get_base_code treat text as a diff.
    _DIFF_MARKERS = ("--- ", "+++ ", "@@ ")
    # Line prefixes that are diff bookkeeping rather than file content.
    _DIFF_METADATA_PREFIXES = (
        "--- ",
        "+++ ",
        "@@ ",
        "index ",
        "diff --git ",
        "\\ ",  # e.g. "\ No newline at end of file"
    )

    def __init__(self):
        self._state = None  # PRState of the current episode; None until reset()
        self._task = None  # task definition selected by reset()
        self._rewards = []  # one reward per completed step, in order
        self._current_diff = ""  # diff shown in the next observation
        self._initial_diff = ""  # the task's original diff, kept as the base

    def _extract_code(self, text: str) -> "str | None":
        """Return the code inside the first fenced block of *text*.

        A block tagged as Python is preferred; otherwise the first untagged
        fenced block is used. The extracted code is stripped of surrounding
        whitespace. Returns ``None`` when *text* is empty/``None`` or holds no
        matching fenced block.
        """
        if not text:
            return None
        for fence in (self._PYTHON_FENCE, self._PLAIN_FENCE):
            match = fence.search(text)
            if match:
                return match.group(1).strip()
        return None

    def _generate_unified_diff(self, old_code: str, new_code: str, filename: str = "file.py") -> str:
        """Return a unified diff between two versions of code.

        The result is newline-terminated, or the empty string when the two
        versions have identical lines.
        """
        # Split WITHOUT line endings and let the join add them back. Callers
        # pass stripped text whose final line has no newline; with
        # keepends=True difflib would emit that line unterminated and fuse it
        # with the next diff line (e.g. "-old+new").
        diff_lines = list(
            difflib.unified_diff(
                old_code.splitlines(),
                new_code.splitlines(),
                fromfile=f"a/{filename}",
                tofile=f"b/{filename}",
                lineterm="",
            )
        )
        return ("\n".join(diff_lines) + "\n") if diff_lines else ""

    def _get_base_code(self, diff_text: str) -> str:
        """Heuristically recover the post-change code described by *diff_text*.

        Text containing none of the unified-diff markers is treated as a plain
        snippet and returned stripped. Otherwise the NEW side of the diff is
        rebuilt from its context and added lines; removed lines and diff
        metadata are dropped.
        """
        if not any(marker in diff_text for marker in self._DIFF_MARKERS):
            return diff_text.strip()

        rebuilt = []
        for line in diff_text.splitlines():
            if line.startswith(self._DIFF_METADATA_PREFIXES):
                continue  # headers, hunk markers, git bookkeeping
            if line.startswith("-"):
                continue  # removed line: not part of the new state
            if line.startswith(("+", " ")):
                rebuilt.append(line[1:])  # drop the one-character diff prefix
            else:
                # Unprefixed line (e.g. a blank context line whose leading
                # space was trimmed): keep verbatim.
                rebuilt.append(line)
        return "\n".join(rebuilt).strip()

    def reset(self, task_name: str = "single-pass-review") -> "PRObservation":
        """Start a new episode for *task_name* and return its first observation.

        Raises:
            KeyError: if *task_name* is not a key of ``TASKS``.
        """
        try:
            task = TASKS[task_name]
        except KeyError:
            raise KeyError(
                f"Unknown task {task_name!r}; expected one of {sorted(TASKS)}"
            ) from None

        self._task = task
        self._rewards = []
        self._initial_diff = task["diff"]
        self._current_diff = task["diff"]
        self._state = PRState(
            episode_id=str(uuid.uuid4()),
            task_name=task_name,
            turn=0,
            max_turns=task["max_turns"],
            review_history=[],
            done=False,
            success=False,
            cumulative_reward=0.0,
        )
        return PRObservation(
            turn=0,
            diff=self._current_diff,
            pr_title=task["pr_title"],
            pr_description=task["pr_description"],
            review_history=[],
            author_response=None,
            done=False,
            message="New PR ready for review. Read the diff carefully. Identify the root cause of any issues, not just the symptom. Submit your decision.",
        )

    def step(self, action: "PRAction") -> "tuple[PRObservation, float, bool, dict]":
        """Apply one reviewer action and advance the episode by a turn.

        Returns:
            ``(observation, reward, done, info)`` where *info* carries the
            episode id and task name.

        Raises:
            AssertionError: if ``reset()`` has not been called or the episode
                has already finished.
        """
        # Raised explicitly (not via `assert`) so the guards survive
        # `python -O`; the exception type and messages are unchanged.
        if self._state is None:
            raise AssertionError("Call reset() first")
        if self._state.done:
            raise AssertionError("Episode is already done")

        state = self._state
        task = self._task
        gt = task["ground_truth"]
        turn = state.turn + 1

        # A per-turn expected decision wins over the task-wide one.
        correct_decision = gt.get(
            f"correct_decision_turn_{turn}",
            gt.get("correct_decision", "request_changes"),
        )
        author_responses = task.get("author_responses", [])
        bug_still_present = correct_decision != ReviewDecision.APPROVE.value

        reward = graders.compute_step_reward(
            action=action,
            correct_decision=correct_decision,
            root_cause_keywords=gt.get("root_cause_keywords", []),
            correct_issue_category=gt.get("correct_issue_category", "logic"),
            # True only when the reviewer approves although approval is not
            # the expected decision for this turn.
            bug_still_present=bug_still_present and action.decision == ReviewDecision.APPROVE,
            turn=turn,
            max_turns=task["max_turns"],
            symptom_only_keywords=gt.get("symptom_only_keywords"),
            false_fix_keywords=gt.get("false_fix_keywords"),
            escalation_required=gt.get("escalation_required", False) and correct_decision == ReviewDecision.ESCALATE.value,
        )
        self._rewards.append(reward)
        state.cumulative_reward = round(sum(self._rewards), 2)
        state.turn = turn
        state.review_history.append(
            {"role": "reviewer", "content": f"[{action.decision.value}] {action.comment}"}
        )

        # The episode ends on the last turn, or as soon as the reviewer
        # approves or escalates.
        done = (
            turn >= task["max_turns"]
            or action.decision == ReviewDecision.APPROVE
            or action.decision == ReviewDecision.ESCALATE
        )
        state.done = done

        author_resp = None
        if not done and turn <= len(author_responses):
            author_resp = author_responses[turn - 1]
            state.review_history.append({"role": "author", "content": author_resp})
            # --- DYNAMIC DIFF INJECTION ---
            proposed_fix = self._extract_code(author_resp)
            if proposed_fix:
                # Diff the author's fix against the code from the INITIAL
                # diff (not the previous turn's) to show a fresh diff.
                base_code = self._get_base_code(self._initial_diff)
                self._current_diff = self._generate_unified_diff(base_code, proposed_fix)

        if done:
            final_score = graders.compute_final_score(self._rewards, task["max_turns"])
            state.success = final_score >= 0.5
            message = f"Episode complete. Final score: {final_score:.3f}"
        else:
            message = "Author has responded. Re-read the diff. Has the actual root cause been addressed, or just the symptom?"

        observation = PRObservation(
            turn=turn,
            diff=self._current_diff,
            pr_title=task["pr_title"],
            pr_description=task["pr_description"],
            review_history=list(state.review_history),  # snapshot, not the live list
            author_response=author_resp,
            done=done,
            message=message,
        )
        return observation, reward, done, {"episode_id": state.episode_id, "task": state.task_name}

    def state(self) -> "PRState":
        """Return the current episode state (``None`` before ``reset()``)."""
        return self._state

    def get_rewards(self):
        """Return the live list of per-step rewards for the current episode."""
        return self._rewards

    def get_final_score(self):
        """Return the graders' final score for the rewards collected so far.

        Requires a prior ``reset()``, since the task's ``max_turns`` is read.
        """
        return graders.compute_final_score(self._rewards, self._task["max_turns"])
|