Spaces:
Sleeping
Sleeping
File size: 5,643 Bytes
8d96200 7e9c2fa 80454a1 8d96200 7e9c2fa 8d96200 7e9c2fa d60a64c 8d96200 80454a1 8d96200 7e9c2fa 8d96200 7e9c2fa d60a64c 8d96200 7e9c2fa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | """PR review simulation environment (gym-style reset/step API)."""
from __future__ import annotations

import glob
import json
import math
import os
import random
from typing import Optional

from .grader import check_comment, grade
from .models import PRReviewAction, PRReviewObservation, PRReviewReward
# Reward shaping constants.
_BUG_POOL = 0.68          # total reward split evenly across a scenario's ground-truth bugs
_FALSE_POS = 0.02         # reward for a comment that credits no new bug (or is empty)
_DECISION_CORRECT = 0.31  # terminal reward when the grader marks the decision correct
_DECISION_WRONG = 0.02    # terminal reward otherwise

# Scenario JSON files are read from ../data/scenarios relative to this module.
_SCENARIOS_DIR = os.path.join(os.path.dirname(__file__), "..", "data", "scenarios")

# One row per task: (scenario-id prefix, step budget, threshold stored on the env).
_TASK_TABLE = {
    "easy": ("easy_", 5, 0.7),
    "medium": ("medium_", 10, 0.6),
    "hard": ("hard_", 15, 0.5),
}
TASK_PREFIXES = {task: row[0] for task, row in _TASK_TABLE.items()}
TASK_MAX_STEPS = {task: row[1] for task, row in _TASK_TABLE.items()}
TASK_THRESHOLDS = {task: row[2] for task, row in _TASK_TABLE.items()}
def clamp_value(v: float) -> float:
    """Clamp ``v`` into [0.02, 0.98] and round it to 4 decimal places.

    Keeps rewards and scores strictly inside (0, 1). NaN is mapped to the
    lower bound (0.02): every comparison with NaN is False, so without the
    explicit check ``min``/``max`` would let NaN through as the upper bound
    0.98, i.e. the best possible value.

    Raises:
        ValueError, TypeError: if ``v`` cannot be converted with ``float()``.
    """
    x = float(v)
    if math.isnan(x):
        return 0.02
    return round(max(0.02, min(0.98, x)), 4)
def _load_all() -> dict[str, dict]:
    """Load and validate every ``*.json`` scenario file in ``_SCENARIOS_DIR``.

    Files are read in sorted path order and keyed by their file name without
    the ``.json`` extension (the scenario id).

    Returns:
        Mapping of scenario id -> parsed scenario object.

    Raises:
        RuntimeError: if the directory contains no ``*.json`` files.
        ValueError: if a file is not valid JSON, its top level is not a JSON
            object, it lacks one of the required fields, or its
            ``ground_truth`` is not a JSON object.
    """
    paths = glob.glob(os.path.join(_SCENARIOS_DIR, "*.json"))
    if not paths:
        raise RuntimeError(f"No scenario JSON files found in {_SCENARIOS_DIR}")
    store: dict[str, dict] = {}
    for path in sorted(paths):
        sid = os.path.splitext(os.path.basename(path))[0]
        with open(path, encoding="utf-8") as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError as err:
                # JSONDecodeError is already a ValueError; re-raise with the
                # scenario id so the broken file is identifiable.
                raise ValueError(f"Scenario '{sid}' is not valid JSON: {err}") from err
        # With a list or string top level, the `in` checks below would test
        # element membership or substrings instead of keys.
        if not isinstance(data, dict):
            raise ValueError(
                f"Scenario '{sid}' must be a JSON object, got {type(data).__name__}"
            )
        for field in ("pr_title", "pr_description", "diff", "ground_truth"):
            if field not in data:
                raise ValueError(f"Scenario '{sid}' missing field '{field}'")
        # The environment calls ground_truth.get("bugs", ...), so it must be a mapping.
        if not isinstance(data["ground_truth"], dict):
            raise ValueError(f"Scenario '{sid}' field 'ground_truth' must be a JSON object")
        store[sid] = data
    return store
# All scenarios, keyed by scenario id, loaded once at import time.
# Importing this module therefore reads every scenario file and raises if
# none are found or if one is missing a required field.
_STORE: dict[str, dict] = _load_all()
class PRReviewEnv:
    """Pull-request review environment with a gym-style reset/step API.

    :meth:`reset` picks a random scenario for the configured task. The agent
    then calls :meth:`step` with ``comment`` actions, each rewarded
    immediately. It ends the episode with an ``approve`` or
    ``request_changes`` action, which is graded against the scenario's ground
    truth. Once ``max_steps`` actions have been taken, the next :meth:`step`
    call ends the episode as a ``reject`` regardless of the action supplied.
    """

    def __init__(self, task: str = "easy") -> None:
        """Configure the environment for ``task`` (a key of ``TASK_PREFIXES``).

        Raises:
            ValueError: if ``task`` is not a known task name.
        """
        if task not in TASK_PREFIXES:
            raise ValueError(f"Unknown task '{task}'. Valid: {sorted(TASK_PREFIXES)}")
        self.task = task
        self.max_steps: int = TASK_MAX_STEPS[task]
        self.threshold: float = TASK_THRESHOLDS[task]
        # Per-episode state; (re)initialised by reset().
        self._scenario_id: Optional[str] = None
        self._scenario: Optional[dict] = None
        self._comments: list[str] = []
        self._step_count: int = 0
        self._done: bool = False
        self._score: Optional[float] = None
        # Indices of ground-truth bugs already credited, so each pays out once.
        self._rewarded_bugs: set[int] = set()

    def reset(self) -> PRReviewObservation:
        """Start a new episode and return its first observation.

        The scenario is drawn with the module-level ``random.choice`` from all
        loaded scenario ids that start with this task's prefix.

        Raises:
            RuntimeError: if no loaded scenario matches the task prefix.
        """
        prefix = TASK_PREFIXES[self.task]
        candidates = [sid for sid in _STORE if sid.startswith(prefix)]
        if not candidates:
            raise RuntimeError(f"No scenarios with prefix '{prefix}'")
        self._scenario_id = random.choice(candidates)
        self._scenario = _STORE[self._scenario_id]
        self._comments = []
        self._step_count = 0
        self._done = False
        self._score = None
        self._rewarded_bugs = set()
        return self._obs()

    def step(self, action: PRReviewAction) -> tuple[PRReviewObservation, PRReviewReward, bool, dict]:
        """Apply one agent action.

        Returns:
            ``(observation, reward, done, info)``. A comment returns
            ``done=False`` and an empty ``info``. A decision action, or the
            forced timeout once the step budget is spent, returns ``done=True``
            with the grading result as ``info``.

        Raises:
            RuntimeError: if called before reset() or after the episode ended.
            ValueError: for an unknown ``action.action_type``. No episode state
                (including the step counter) is changed in that case.
        """
        if self._scenario is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode done. Call reset() to start a new one.")
        # Step budget exhausted: end the episode as a rejection, ignoring `action`.
        if self._step_count >= self.max_steps:
            return self._terminal_step("reject")
        # Validate before mutating state so an invalid action does not burn a step.
        if action.action_type not in ("comment", "approve", "request_changes"):
            raise ValueError(f"Unknown action_type '{action.action_type}'.")
        self._step_count += 1
        if action.action_type == "comment":
            reward_val = self._comment_reward(action.body)
            if action.body:
                self._comments.append(action.body)
            clipped = clamp_value(reward_val)
            return self._obs(), PRReviewReward(value=clipped), False, {}
        # "request_changes" is graded as a "reject" decision.
        decision = "approve" if action.action_type == "approve" else "reject"
        return self._terminal_step(decision)

    def state(self) -> dict:
        """Return a snapshot of the episode state (the comments list is copied)."""
        return {
            "task": self.task,
            "scenario_id": self._scenario_id,
            "step_count": self._step_count,
            "max_steps": self.max_steps,
            "done": self._done,
            "score": self._score,
            "comments": list(self._comments),
        }

    def _obs(self) -> PRReviewObservation:
        """Build the observation for the current scenario and episode progress."""
        assert self._scenario is not None
        return PRReviewObservation(
            diff=self._scenario["diff"],
            pr_description=self._scenario["pr_description"],
            pr_title=self._scenario["pr_title"],
            comments_so_far=[{"body": c} for c in self._comments],
            step_count=self._step_count,
            done=self._done,
            scenario_id=self._scenario_id or "",
        )

    def _comment_reward(self, body: str) -> float:
        """Return the unclamped reward for a review comment with text ``body``.

        Each ground-truth bug is worth ``_BUG_POOL / len(bugs)`` and is credited
        at most once per episode. The following earn ``_FALSE_POS``:

        - an empty comment,
        - a scenario without bugs,
        - a comment that matches no bug not yet credited.
        """
        if not body:
            return _FALSE_POS
        assert self._scenario is not None
        bugs: list = self._scenario["ground_truth"].get("bugs", [])
        if not bugs:
            return _FALSE_POS
        # dict.fromkeys de-duplicates while preserving order, so an index that
        # check_comment reports twice for one comment is only paid out once.
        matched = dict.fromkeys(check_comment(body, bugs))
        newly_found = [i for i in matched if i not in self._rewarded_bugs]
        if not newly_found:
            return _FALSE_POS
        per_bug = _BUG_POOL / len(bugs)
        self._rewarded_bugs.update(newly_found)
        return len(newly_found) * per_bug

    def _terminal_step(self, decision: str) -> tuple[PRReviewObservation, PRReviewReward, bool, dict]:
        """Grade the episode with ``decision`` ("approve" or "reject") and end it.

        The grading score is clamped, stored as the episode score and written
        back into the grading result. That same result dict is returned both
        as the reward breakdown and as ``info``. The reward value is only the
        decision component: ``_DECISION_CORRECT`` when
        ``result["decision_correct"]`` is truthy, otherwise
        ``_DECISION_WRONG``.
        """
        assert self._scenario is not None
        result = grade(
            ground_truth=self._scenario["ground_truth"],
            comments=self._comments,
            decision=decision,
        )
        self._done = True
        self._score = clamp_value(result["score"])
        result["score"] = self._score
        decision_reward = _DECISION_CORRECT if result["decision_correct"] else _DECISION_WRONG
        clipped_reward = clamp_value(decision_reward)
        return self._obs(), PRReviewReward(value=clipped_reward, breakdown=result), True, result