File size: 5,643 Bytes
8d96200
 
 
 
 
 
 
 
 
 
 
 
 
7e9c2fa
 
80454a1
 
8d96200
 
 
 
 
 
 
7e9c2fa
 
 
8d96200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e9c2fa
d60a64c
8d96200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80454a1
8d96200
 
 
7e9c2fa
8d96200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e9c2fa
d60a64c
8d96200
7e9c2fa
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""PR review simulation environment (gym-style reset/step API)."""

from __future__ import annotations

import glob
import json
import math
import os
import random
from typing import Optional

from .grader import check_comment, grade
from .models import PRReviewAction, PRReviewObservation, PRReviewReward

# Reward shaping. A comment that surfaces previously unrewarded ground-truth
# bugs earns _BUG_POOL split evenly across the scenario's bugs; any other
# comment earns _FALSE_POS.
_BUG_POOL = 0.68
_FALSE_POS = 0.02
# Terminal reward for the final approve / request_changes decision.
_DECISION_CORRECT = 0.31
_DECISION_WRONG = 0.02

# Scenario JSON files are read from ../data/scenarios relative to this module.
_SCENARIOS_DIR = os.path.join(os.path.dirname(__file__), "..", "data", "scenarios")

# Per-difficulty settings: scenario file-name prefix, step budget, and the
# score threshold stored on the environment as `threshold`.
TASK_PREFIXES = dict(easy="easy_", medium="medium_", hard="hard_")
TASK_MAX_STEPS = dict(easy=5, medium=10, hard=15)
TASK_THRESHOLDS = dict(easy=0.7, medium=0.6, hard=0.5)

def clamp_value(v: float, *, lo: float = 0.02, hi: float = 0.98) -> float:
    """Clamp *v* into ``[lo, hi]`` and round it to 4 decimal places.

    With the default bounds the result always lies strictly inside (0, 1).

    Args:
        v: Value to clamp; anything accepted by ``float()``.
        lo: Inclusive lower bound. Also returned for NaN input, so an
            undefined value can never be reported as a near-perfect one.
        hi: Inclusive upper bound.

    Returns:
        The clamped value, rounded to 4 decimals.

    Raises:
        ValueError: If ``lo > hi``.
    """
    if lo > hi:
        raise ValueError(f"lo ({lo}) must not exceed hi ({hi})")
    x = float(v)
    if math.isnan(x):
        # min()/max() do not order NaN: min(0.98, nan) returns 0.98, which
        # would turn a NaN into the best possible value. Pin it low instead.
        return round(lo, 4)
    return round(max(lo, min(hi, x)), 4)

def _load_all(directory: Optional[str] = None) -> dict[str, dict]:
    """Load and validate every ``*.json`` scenario file in *directory*.

    Args:
        directory: Folder to scan. Defaults to ``_SCENARIOS_DIR``.

    Returns:
        Mapping from scenario id (file name without its extension) to the
        parsed scenario object, inserted in sorted-path order.

    Raises:
        RuntimeError: If the directory contains no ``*.json`` files.
        ValueError: If a file is not valid JSON, its top level is not a JSON
            object, or it lacks one of the required fields.
    """
    if directory is None:
        directory = _SCENARIOS_DIR
    paths = sorted(glob.glob(os.path.join(directory, "*.json")))
    if not paths:
        raise RuntimeError(f"No scenario JSON files found in {directory}")
    required = ("pr_title", "pr_description", "diff", "ground_truth")
    store: dict[str, dict] = {}
    for path in paths:
        sid = os.path.splitext(os.path.basename(path))[0]
        with open(path, encoding="utf-8") as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError as err:
                # JSONDecodeError is already a ValueError; re-raise with the
                # scenario id so the broken file can be located.
                raise ValueError(f"Scenario '{sid}' is not valid JSON: {err}") from err
        # Without this check, a list or string top level would make the `in`
        # tests below check membership/substrings instead of object keys.
        if not isinstance(data, dict):
            raise ValueError(
                f"Scenario '{sid}' must be a JSON object, got {type(data).__name__}"
            )
        for key in required:
            if key not in data:
                raise ValueError(f"Scenario '{sid}' missing field '{key}'")
        store[sid] = data
    return store

# All scenarios are read and validated once, at import time. A missing,
# empty or malformed scenarios directory therefore makes the import itself
# raise (RuntimeError / ValueError from _load_all).
_STORE: dict[str, dict] = _load_all()

class PRReviewEnv:
    """Gym-style environment in which an agent reviews one pull request.

    Each episode serves a randomly chosen scenario for the configured task.
    The agent posts ``comment`` actions, which are rewarded per newly detected
    ground-truth bug. It ends the episode with ``approve`` or
    ``request_changes``, which is graded against the scenario's ground truth.
    """

    # Action types accepted by step(); anything else raises ValueError.
    _ACTION_TYPES = ("comment", "approve", "request_changes")

    def __init__(self, task: str = "easy") -> None:
        """Create an environment for *task*, which must be a key of TASK_PREFIXES.

        Raises:
            ValueError: If *task* is unknown.
        """
        if task not in TASK_PREFIXES:
            raise ValueError(f"Unknown task '{task}'. Valid: {sorted(TASK_PREFIXES)}")
        self.task = task
        self.max_steps: int = TASK_MAX_STEPS[task]
        self.threshold: float = TASK_THRESHOLDS[task]
        self._scenario_id: Optional[str] = None
        self._scenario: Optional[dict] = None
        self._comments: list[str] = []
        self._step_count: int = 0
        self._done: bool = False
        self._score: Optional[float] = None
        # Indices of ground-truth bugs already rewarded this episode, so that
        # repeating a finding does not earn the bug reward again.
        self._rewarded_bugs: set[int] = set()

    def reset(self) -> PRReviewObservation:
        """Start a new episode on a random scenario whose id has the task prefix.

        Returns:
            The initial observation (no comments, step count 0).

        Raises:
            RuntimeError: If no loaded scenario matches the task prefix.
        """
        prefix = TASK_PREFIXES[self.task]
        candidates = [sid for sid in _STORE if sid.startswith(prefix)]
        if not candidates:
            raise RuntimeError(f"No scenarios with prefix '{prefix}'")
        self._scenario_id = random.choice(candidates)
        self._scenario = _STORE[self._scenario_id]
        self._comments = []
        self._step_count = 0
        self._done = False
        self._score = None
        self._rewarded_bugs = set()
        return self._obs()

    def step(self, action: PRReviewAction) -> tuple[PRReviewObservation, PRReviewReward, bool, dict]:
        """Apply *action* and return ``(observation, reward, done, info)``.

        A ``comment`` records its body (when non-empty) and returns a clamped
        shaped reward with ``done=False``. ``approve`` / ``request_changes``
        end the episode with a graded decision. Once ``max_steps`` steps have
        been taken, the next call ends the episode as ``"reject"`` whatever
        *action* is.

        Raises:
            RuntimeError: If called before reset() or after the episode ended.
            ValueError: If ``action.action_type`` is unknown. Such an action
                does not consume a step.
        """
        if self._scenario is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode done. Call reset() to start a new one.")
        if self._step_count >= self.max_steps:
            # Step budget exhausted without a decision: grade as a rejection.
            return self._terminal_step("reject")
        if action.action_type not in self._ACTION_TYPES:
            # Validate before counting so an invalid action does not eat into
            # the agent's step budget.
            raise ValueError(f"Unknown action_type '{action.action_type}'.")

        self._step_count += 1

        if action.action_type == "comment":
            # Reward is computed before the comment joins the history.
            reward_val = self._comment_reward(action.body)
            if action.body:
                self._comments.append(action.body)
            return self._obs(), PRReviewReward(value=clamp_value(reward_val)), False, {}

        decision = "approve" if action.action_type == "approve" else "reject"
        return self._terminal_step(decision)

    def state(self) -> dict:
        """Return a plain-dict snapshot of the episode (comments are copied)."""
        return {
            "task": self.task,
            "scenario_id": self._scenario_id,
            "step_count": self._step_count,
            "max_steps": self.max_steps,
            "done": self._done,
            "score": self._score,
            "comments": list(self._comments),
        }

    def _obs(self) -> PRReviewObservation:
        """Build the observation for the current scenario and episode progress."""
        assert self._scenario is not None
        return PRReviewObservation(
            diff=self._scenario["diff"],
            pr_description=self._scenario["pr_description"],
            pr_title=self._scenario["pr_title"],
            comments_so_far=[{"body": c} for c in self._comments],
            step_count=self._step_count,
            done=self._done,
            scenario_id=self._scenario_id or "",
        )

    def _comment_reward(self, body: str) -> float:
        """Return the raw (unclamped) reward for a comment with text *body*.

        Each bug index reported by ``check_comment`` that was not rewarded
        earlier in the episode earns ``_BUG_POOL / len(bugs)``. Empty
        comments, scenarios without bugs, and comments that find nothing new
        earn ``_FALSE_POS``.
        """
        if not body:
            return _FALSE_POS
        assert self._scenario is not None
        bugs: list = self._scenario["ground_truth"].get("bugs", [])
        if not bugs:
            return _FALSE_POS
        # A set guards against check_comment reporting the same index twice,
        # which would otherwise pay that bug more than once.
        newly_found = set(check_comment(body, bugs)) - self._rewarded_bugs
        if not newly_found:
            return _FALSE_POS
        self._rewarded_bugs.update(newly_found)
        per_bug = _BUG_POOL / len(bugs)
        return len(newly_found) * per_bug

    def _terminal_step(self, decision: str) -> tuple[PRReviewObservation, PRReviewReward, bool, dict]:
        """Grade the episode with *decision* and return the final transition.

        The grader's result dict has its ``score`` replaced by the clamped
        score, and that same dict is returned both as the reward breakdown and
        as ``info``. The reward value depends only on whether the decision was
        correct (``_DECISION_CORRECT`` vs ``_DECISION_WRONG``).
        """
        assert self._scenario is not None
        result = grade(
            ground_truth=self._scenario["ground_truth"],
            comments=self._comments,
            decision=decision,
        )
        self._done = True
        self._score = clamp_value(result["score"])
        result["score"] = self._score
        decision_reward = _DECISION_CORRECT if result["decision_correct"] else _DECISION_WRONG
        clipped_reward = clamp_value(decision_reward)
        return self._obs(), PRReviewReward(value=clipped_reward, breakdown=result), True, result