"""PR review simulation environment (gym-style reset/step API).""" from __future__ import annotations import glob import json import os import random from typing import Optional from .grader import check_comment, grade from .models import PRReviewAction, PRReviewObservation, PRReviewReward _BUG_POOL = 0.68 _FALSE_POS = 0.02 _DECISION_CORRECT = 0.31 _DECISION_WRONG = 0.02 _SCENARIOS_DIR = os.path.join(os.path.dirname(__file__), "..", "data", "scenarios") TASK_PREFIXES = {"easy": "easy_", "medium": "medium_", "hard": "hard_"} TASK_MAX_STEPS = {"easy": 5, "medium": 10, "hard": 15} TASK_THRESHOLDS = {"easy": 0.7, "medium": 0.6, "hard": 0.5} def clamp_value(v: float) -> float: """Ensure values are strictly within (0, 1).""" return round(max(0.02, min(0.98, float(v))), 4) def _load_all() -> dict[str, dict]: paths = glob.glob(os.path.join(_SCENARIOS_DIR, "*.json")) if not paths: raise RuntimeError(f"No scenario JSON files found in {_SCENARIOS_DIR}") store: dict[str, dict] = {} for path in sorted(paths): sid = os.path.splitext(os.path.basename(path))[0] with open(path, encoding="utf-8") as f: data = json.load(f) for field in ("pr_title", "pr_description", "diff", "ground_truth"): if field not in data: raise ValueError(f"Scenario '{sid}' missing field '{field}'") store[sid] = data return store _STORE: dict[str, dict] = _load_all() class PRReviewEnv: def __init__(self, task: str = "easy") -> None: if task not in TASK_PREFIXES: raise ValueError(f"Unknown task '{task}'. Valid: {sorted(TASK_PREFIXES)}") self.task = task self.max_steps: int = TASK_MAX_STEPS[task] self.threshold: float = TASK_THRESHOLDS[task] self._scenario_id: Optional[str] = None self._scenario: Optional[dict] = None self._comments: list[str] = [] self._step_count: int = 0 self._done: bool = False self._score: Optional[float] = None self._rewarded_bugs: set[int] = set() def reset(self) -> PRReviewObservation: prefix = TASK_PREFIXES[self.task] candidates = [sid for sid in _STORE if sid.startswith(prefix)] if not candidates: raise RuntimeError(f"No scenarios with prefix '{prefix}'") self._scenario_id = random.choice(candidates) self._scenario = _STORE[self._scenario_id] self._comments = [] self._step_count = 0 self._done = False self._score = None self._rewarded_bugs = set() return self._obs() def step(self, action: PRReviewAction) -> tuple[PRReviewObservation, PRReviewReward, bool, dict]: if self._scenario is None: raise RuntimeError("Call reset() before step().") if self._done: raise RuntimeError("Episode done. Call reset() to start a new one.") if self._step_count >= self.max_steps: return self._terminal_step("reject") self._step_count += 1 if action.action_type == "comment": reward_val = self._comment_reward(action.body) if action.body: self._comments.append(action.body) clipped = clamp_value(reward_val) return self._obs(), PRReviewReward(value=clipped), False, {} if action.action_type in ("approve", "request_changes"): decision = "approve" if action.action_type == "approve" else "reject" return self._terminal_step(decision) raise ValueError(f"Unknown action_type '{action.action_type}'.") def state(self) -> dict: return { "task": self.task, "scenario_id": self._scenario_id, "step_count": self._step_count, "max_steps": self.max_steps, "done": self._done, "score": self._score, "comments": list(self._comments), } def _obs(self) -> PRReviewObservation: assert self._scenario is not None return PRReviewObservation( diff=self._scenario["diff"], pr_description=self._scenario["pr_description"], pr_title=self._scenario["pr_title"], comments_so_far=[{"body": c} for c in self._comments], step_count=self._step_count, done=self._done, scenario_id=self._scenario_id or "", ) def _comment_reward(self, body: str) -> float: if not body: return _FALSE_POS assert self._scenario is not None bugs: list = self._scenario["ground_truth"].get("bugs", []) if not bugs: return _FALSE_POS newly_found = [i for i in check_comment(body, bugs) if i not in self._rewarded_bugs] if newly_found: per_bug = _BUG_POOL / len(bugs) self._rewarded_bugs.update(newly_found) return len(newly_found) * per_bug return _FALSE_POS def _terminal_step(self, decision: str) -> tuple[PRReviewObservation, PRReviewReward, bool, dict]: assert self._scenario is not None result = grade( ground_truth=self._scenario["ground_truth"], comments=self._comments, decision=decision, ) self._done = True self._score = clamp_value(result["score"]) result["score"] = self._score decision_reward = _DECISION_CORRECT if result["decision_correct"] else _DECISION_WRONG clipped_reward = clamp_value(decision_reward) return self._obs(), PRReviewReward(value=clipped_reward, breakdown=result), True, result