File size: 6,781 Bytes
13507d6
8d8959a
 
13507d6
0c8a432
13507d6
 
 
 
 
 
0c8a432
13507d6
 
 
 
 
 
 
8d8959a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13507d6
 
 
 
8d8959a
 
 
13507d6
 
 
 
 
 
 
 
 
 
 
 
8d8959a
13507d6
 
 
 
 
7d08a88
13507d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcee3d3
13507d6
 
 
 
7d08a88
 
13507d6
 
 
7d08a88
 
dcee3d3
13507d6
 
 
 
 
 
7d08a88
 
 
 
 
13507d6
 
 
 
 
 
8d8959a
 
 
 
 
 
 
13507d6
 
 
 
 
 
7d08a88
13507d6
 
 
8d8959a
13507d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import difflib
import re
import textwrap
import uuid

from models import PRAction, PRObservation, PRState, ReviewDecision
from server.tasks import single_pass, iterative, escalation, custom
from server import graders

# Registry of the review tasks this environment can run, keyed by the task
# name accepted by PRReviewEnvironment.reset().
TASKS = {
    task_name: task_module.TASK
    for task_name, task_module in (
        ("single-pass-review", single_pass),
        ("iterative-negotiation", iterative),
        ("escalation-judgment", escalation),
        ("custom-review", custom),
    )
}

class PRReviewEnvironment:
    """Multi-turn pull-request review environment.

    ``reset()`` loads one of the tasks in ``TASKS`` and returns the first
    observation. Each ``step()`` scores one reviewer action with
    ``graders.compute_step_reward`` and, while the episode continues, replays
    the next scripted author response. When that response contains a fenced
    code block, the observed diff is regenerated against the code taken from
    the task's initial diff.
    """

    # Fenced block explicitly tagged as Python (```python, ```python3, ```py).
    _PYTHON_FENCE_RE = re.compile(
        r"```(?:python3?|py)[ \t]*\r?\n(.*?)\r?\n[ \t]*```", re.DOTALL
    )
    # Fallback: the first untagged fenced block.
    _PLAIN_FENCE_RE = re.compile(r"```[ \t]*\r?\n(.*?)\r?\n[ \t]*```", re.DOTALL)
    # Diff lines that carry metadata rather than code.
    _DIFF_METADATA_PREFIXES = (
        "diff --git ",
        "index ",
        "new file mode ",
        "deleted file mode ",
        "--- ",
        "+++ ",
        "@@ ",
        "\\ ",  # "\ No newline at end of file"
    )

    def __init__(self):
        self._state = None  # PRState of the current episode; None until reset()
        self._task = None  # task definition selected from TASKS by reset()
        self._rewards = []  # per-step rewards of the current episode
        self._current_diff = ""  # diff shown in the most recent observation
        self._initial_diff = ""  # the task's original diff, used as the base for regenerated diffs

    @staticmethod
    def _normalize_code(code: str) -> str:
        """Dedent *code*, drop leading blank lines and trailing whitespace.

        Unlike ``str.strip()``, this does not strip indentation from the first
        line alone, so code cut from inside a class or function keeps a
        consistent shape across all of its lines.
        """
        return textwrap.dedent(code).lstrip("\r\n").rstrip()

    def _extract_code(self, text: str) -> "str | None":
        """Return the code inside the first fenced block of *text*, or None.

        A block tagged ``python``/``python3``/``py`` is preferred; otherwise the
        first untagged block is used. The extracted code is normalized with
        ``_normalize_code``.
        """
        for pattern in (self._PYTHON_FENCE_RE, self._PLAIN_FENCE_RE):
            match = pattern.search(text)
            if match:
                return self._normalize_code(match.group(1))
        return None

    def _generate_unified_diff(self, old_code: str, new_code: str, filename: str = "file.py") -> str:
        """Return a unified diff that turns *old_code* into *new_code*.

        Lines are compared without their terminators, emitted with
        ``lineterm=""`` and joined with newlines. The previous ``keepends=True``
        approach produced a malformed diff whenever the last line had no
        trailing newline (always true for normalized input): the final removed
        and added lines were glued together as ``"-old+new"``.

        Returns an empty string when both inputs have identical lines.
        """
        diff_lines = list(
            difflib.unified_diff(
                old_code.splitlines(),
                new_code.splitlines(),
                fromfile=f"a/{filename}",
                tofile=f"b/{filename}",
                lineterm="",
            )
        )
        return ("\n".join(diff_lines) + "\n") if diff_lines else ""

    def _get_base_code(self, diff_text: str) -> str:
        """Heuristically reconstruct the post-change code from *diff_text*.

        If *diff_text* contains no diff markers it is treated as a plain code
        snippet. Otherwise, the "new" side of the diff is rebuilt: context and
        added lines are kept (without their one-character prefix), removed
        lines and diff/git metadata lines are dropped, and any other line is
        kept verbatim. The result is normalized with ``_normalize_code``.
        """
        if not any(marker in diff_text for marker in ("--- ", "+++ ", "@@ ")):
            return self._normalize_code(diff_text)

        new_side = []
        for line in diff_text.splitlines():
            if line.startswith(self._DIFF_METADATA_PREFIXES):
                continue
            if line.startswith("-"):
                continue  # removed line: not part of the new version
            if line.startswith(("+", " ")):
                new_side.append(line[1:])
            else:
                new_side.append(line)  # e.g. blank context lines with no prefix
        return self._normalize_code("\n".join(new_side))

    def reset(self, task_name: str = "single-pass-review") -> "PRObservation":
        """Start a new episode on *task_name* and return the first observation.

        Raises:
            KeyError: if *task_name* is not a key of ``TASKS``.
        """
        try:
            task = TASKS[task_name]
        except KeyError:
            raise KeyError(
                f"Unknown task {task_name!r}; expected one of: {', '.join(TASKS)}"
            ) from None
        self._task = task
        self._rewards = []
        self._initial_diff = task["diff"]
        self._current_diff = task["diff"]

        self._state = PRState(
            episode_id=str(uuid.uuid4()),
            task_name=task_name,
            turn=0,
            max_turns=task["max_turns"],
            review_history=[],
            done=False,
            success=False,
            cumulative_reward=0.0,
        )
        return PRObservation(
            turn=0,
            diff=self._current_diff,
            pr_title=task["pr_title"],
            pr_description=task["pr_description"],
            review_history=[],
            author_response=None,
            done=False,
            message="New PR ready for review. Read the diff carefully. Identify the root cause of any issues, not just the symptom. Submit your decision.",
        )

    def step(self, action: "PRAction") -> "tuple[PRObservation, float, bool, dict]":
        """Score one reviewer action and advance the episode by one turn.

        The episode ends when the turn limit is reached or the reviewer
        approves or escalates. Otherwise the next scripted author response (if
        any) is appended to the history. If that response contains a fenced
        code block, the observed diff is replaced by one generated against
        the code reconstructed from the task's initial diff.

        Returns:
            ``(observation, reward, done, info)``, where *info* holds the
            episode id and the task name.

        Raises:
            AssertionError: if ``reset()`` has not been called, or the episode
                is already done.
        """
        # Explicit raises instead of ``assert`` so the guards also hold under
        # ``python -O``; the exception type is unchanged for existing callers.
        if self._state is None:
            raise AssertionError("Call reset() first")
        if self._state.done:
            raise AssertionError("Episode is already done")

        t = self._state
        task = self._task
        gt = task["ground_truth"]
        turn = t.turn + 1

        # A turn-specific expected decision overrides the episode-wide one.
        correct_decision = gt.get(
            f"correct_decision_turn_{turn}",
            gt.get("correct_decision", "request_changes"),
        )

        author_responses = task.get("author_responses", [])
        bug_still_present = correct_decision != ReviewDecision.APPROVE.value

        reward = graders.compute_step_reward(
            action=action,
            correct_decision=correct_decision,
            root_cause_keywords=gt.get("root_cause_keywords", []),
            correct_issue_category=gt.get("correct_issue_category", "logic"),
            # True only when the reviewer approved although approval was not expected.
            bug_still_present=bug_still_present and action.decision == ReviewDecision.APPROVE,
            turn=turn,
            max_turns=task["max_turns"],
            symptom_only_keywords=gt.get("symptom_only_keywords"),
            false_fix_keywords=gt.get("false_fix_keywords"),
            escalation_required=gt.get("escalation_required", False) and correct_decision == ReviewDecision.ESCALATE.value,
        )
        self._rewards.append(reward)
        t.cumulative_reward = round(sum(self._rewards), 2)
        t.turn = turn
        t.review_history.append({"role": "reviewer", "content": f"[{action.decision.value}] {action.comment}"})

        done = (
            turn >= task["max_turns"]
            or action.decision == ReviewDecision.APPROVE
            or action.decision == ReviewDecision.ESCALATE
        )
        t.done = done

        author_resp = None
        if not done and turn <= len(author_responses):
            author_resp = author_responses[turn - 1]
            t.review_history.append({"role": "author", "content": author_resp})

            # Dynamic diff injection: if the author proposes code, show a fresh
            # diff of that code against the INITIAL buggy state (not against the
            # previously shown diff).
            proposed_fix = self._extract_code(author_resp)
            if proposed_fix:
                base_code = self._get_base_code(self._initial_diff)
                self._current_diff = self._generate_unified_diff(base_code, proposed_fix)

        if done:
            final_score = graders.compute_final_score(self._rewards, task["max_turns"])
            t.success = final_score >= 0.5
            message = f"Episode complete. Final score: {final_score:.3f}"
        else:
            message = "Author has responded. Re-read the diff. Has the actual root cause been addressed, or just the symptom?"

        return PRObservation(
            turn=turn,
            diff=self._current_diff,
            pr_title=task["pr_title"],
            pr_description=task["pr_description"],
            review_history=list(t.review_history),
            author_response=author_resp,
            done=done,
            message=message,
        ), reward, done, {"episode_id": t.episode_id, "task": t.task_name}

    def state(self) -> "PRState":
        """Return the live state object of the current episode (None before reset())."""
        return self._state

    def get_rewards(self):
        """Return the list of per-step rewards recorded in the current episode."""
        return self._rewards

    def get_final_score(self):
        """Return ``graders.compute_final_score`` over this episode's rewards.

        Requires ``reset()`` to have been called, since it reads the task's
        ``max_turns``.
        """
        return graders.compute_final_score(self._rewards, self._task["max_turns"])