# meta-hackathon / src/env.py (author: Rushhaabhhh)
# Commit 7e9c2fa (verified): "Fixed range values and formatting"
"""PR review simulation environment (gym-style reset/step API)."""
from __future__ import annotations
import glob
import json
import os
import random
from typing import Optional
from .grader import check_comment, grade
from .models import PRReviewAction, PRReviewObservation, PRReviewReward
# Reward shaping. A scenario's bugs share _BUG_POOL equally (each bug is paid
# at most once per episode); comments that find nothing new earn _FALSE_POS;
# the closing approve/request-changes decision earns _DECISION_CORRECT or
# _DECISION_WRONG depending on the grader's verdict.
_BUG_POOL, _FALSE_POS = 0.68, 0.02
_DECISION_CORRECT, _DECISION_WRONG = 0.31, 0.02

# Scenario JSON files are read from ../data/scenarios relative to this module.
_SCENARIOS_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "data", "scenarios")

# Per-difficulty settings. Scenario ids (JSON file stems) are selected by prefix.
TASK_PREFIXES = {level: f"{level}_" for level in ("easy", "medium", "hard")}
# Step budget per episode, enforced in PRReviewEnv.step.
TASK_MAX_STEPS = dict(easy=5, medium=10, hard=15)
# Exposed as PRReviewEnv.threshold; presumably the pass mark for the episode
# score — confirm with callers.
TASK_THRESHOLDS = dict(easy=0.7, medium=0.6, hard=0.5)
def clamp_value(v: float, *, lo: float = 0.02, hi: float = 0.98) -> float:
    """Clamp *v* into ``[lo, hi]`` and round the result to 4 decimal places.

    With the default bounds the result lies strictly inside (0, 1), so
    rewards and scores never hit the exact endpoints.

    NaN (e.g. from a degenerate grader result) maps to *lo*. Without this
    check, ``max(lo, min(hi, nan))`` evaluates to *hi*, silently turning a
    broken score into a near-perfect one.

    Args:
        v: Value to clamp; anything accepted by ``float()``.
        lo: Lower bound (keyword-only).
        hi: Upper bound (keyword-only).

    Returns:
        ``v`` clamped into ``[lo, hi]`` and rounded to 4 places.

    Raises:
        ValueError: if ``lo > hi`` or ``v`` cannot be converted to float.
    """
    import math

    if lo > hi:
        raise ValueError(f"clamp_value: lo ({lo}) must not exceed hi ({hi})")
    x = float(v)
    if math.isnan(x):
        return round(lo, 4)
    return round(max(lo, min(hi, x)), 4)
def _load_all(scenarios_dir: Optional[str] = None) -> dict[str, dict]:
    """Load and validate every ``*.json`` scenario file in a directory.

    Args:
        scenarios_dir: Directory to scan. Defaults to ``_SCENARIOS_DIR``.

    Returns:
        Mapping of scenario id (file name without extension) to its parsed
        JSON object, inserted in sorted path order.

    Raises:
        RuntimeError: if the directory contains no ``*.json`` files.
        ValueError: if a file is not valid JSON, is not a JSON object, lacks
            one of the required fields, or has a non-object
            ``ground_truth``. The message names the offending scenario.
    """
    directory = _SCENARIOS_DIR if scenarios_dir is None else scenarios_dir
    paths = sorted(glob.glob(os.path.join(directory, "*.json")))
    if not paths:
        raise RuntimeError(f"No scenario JSON files found in {directory}")

    required = ("pr_title", "pr_description", "diff", "ground_truth")
    store: dict[str, dict] = {}
    for path in paths:
        sid = os.path.splitext(os.path.basename(path))[0]
        with open(path, encoding="utf-8") as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError as err:
                # JSONDecodeError is already a ValueError; re-raise with the scenario id.
                raise ValueError(f"Scenario '{sid}' is not valid JSON: {err}") from err
        # A top-level list/str would make the membership test below meaningless.
        if not isinstance(data, dict):
            raise ValueError(
                f"Scenario '{sid}' must be a JSON object, got {type(data).__name__}"
            )
        for field in required:
            if field not in data:
                raise ValueError(f"Scenario '{sid}' missing field '{field}'")
        # The environment calls ground_truth.get("bugs", ...), so it must be a mapping.
        if not isinstance(data["ground_truth"], dict):
            raise ValueError(f"Scenario '{sid}' field 'ground_truth' must be a JSON object")
        store[sid] = data
    return store
# Every scenario, keyed by JSON file stem (e.g. "easy_..."), loaded eagerly at
# import time. Importing this module therefore raises if the scenarios
# directory has no JSON files or a scenario fails validation in _load_all.
_STORE: dict[str, dict] = _load_all()
class PRReviewEnv:
    """Pull-request review simulator with a gym-style ``reset``/``step`` API.

    Each episode samples one scenario whose id starts with the task's prefix.
    The agent posts comments, which earn shaped rewards for newly identified
    ground-truth bugs. It then approves or requests changes, which ends the
    episode and is graded against the scenario's ``ground_truth``.
    """

    _ACTION_TYPES = ("comment", "approve", "request_changes")

    def __init__(self, task: str = "easy") -> None:
        """Create an environment for *task* (a key of ``TASK_PREFIXES``).

        Raises:
            ValueError: if *task* is not a known task name.
        """
        if task not in TASK_PREFIXES:
            raise ValueError(f"Unknown task '{task}'. Valid: {sorted(TASK_PREFIXES)}")
        self.task = task
        self.max_steps: int = TASK_MAX_STEPS[task]
        # Not read inside this class; presumably the pass mark for the
        # episode score used by callers — confirm.
        self.threshold: float = TASK_THRESHOLDS[task]
        self._scenario_id: Optional[str] = None
        self._scenario: Optional[dict] = None
        self._comments: list[str] = []
        self._step_count: int = 0
        self._done: bool = False
        self._score: Optional[float] = None
        # Indices of ground-truth bugs already paid out this episode, so
        # repeating a finding in a later comment earns nothing extra.
        self._rewarded_bugs: set[int] = set()

    def reset(self, *, seed: Optional[int] = None) -> PRReviewObservation:
        """Start a new episode on a randomly chosen scenario for this task.

        Args:
            seed: Optional seed for reproducible scenario selection. When
                omitted, the module-level ``random`` generator is used, so
                ``random.seed(...)`` keeps controlling sampling as before.

        Returns:
            The initial observation (no comments, ``step_count == 0``).

        Raises:
            RuntimeError: if no loaded scenario id starts with the task prefix.
        """
        prefix = TASK_PREFIXES[self.task]
        candidates = [sid for sid in _STORE if sid.startswith(prefix)]
        if not candidates:
            raise RuntimeError(f"No scenarios with prefix '{prefix}'")
        rng = random.Random(seed) if seed is not None else random
        self._scenario_id = rng.choice(candidates)
        self._scenario = _STORE[self._scenario_id]
        self._comments = []
        self._step_count = 0
        self._done = False
        self._score = None
        self._rewarded_bugs = set()
        return self._obs()

    def step(
        self, action: PRReviewAction
    ) -> tuple[PRReviewObservation, PRReviewReward, bool, dict]:
        """Apply one agent action and return ``(obs, reward, done, info)``.

        * ``"comment"`` records ``action.body`` (if non-empty) and returns a
          shaped reward with ``done=False``.
        * ``"approve"`` / ``"request_changes"`` end the episode and are graded
          as an "approve" / "reject" decision respectively.

        Once ``max_steps`` steps have been used, a further comment ends the
        episode with a forced "reject" decision. An explicit approve or
        request-changes decision is still honoured and not overridden.
        ``step_count`` never exceeds ``max_steps``.

        Returns:
            Observation, reward, done flag, and *info* (the grader result on
            terminal steps, ``{}`` otherwise).

        Raises:
            RuntimeError: if called before ``reset()`` or after the episode ended.
            ValueError: for an unknown ``action.action_type``; no step is consumed.
        """
        if self._scenario is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode done. Call reset() to start a new one.")

        kind = action.action_type
        # Validate before touching any state so a bad action neither consumes
        # a step nor gets masked as a forced "reject" at the step limit.
        if kind not in self._ACTION_TYPES:
            raise ValueError(f"Unknown action_type '{kind}'.")

        over_budget = self._step_count >= self.max_steps

        if kind == "comment":
            if over_budget:
                # Step budget exhausted: further commenting forces a verdict.
                return self._terminal_step("reject")
            self._step_count += 1
            reward_val = self._comment_reward(action.body)
            if action.body:
                self._comments.append(action.body)
            return self._obs(), PRReviewReward(value=clamp_value(reward_val)), False, {}

        if not over_budget:
            self._step_count += 1
        decision = "approve" if kind == "approve" else "reject"
        return self._terminal_step(decision)

    def state(self) -> dict:
        """Return a snapshot of the episode state (comments list is a copy)."""
        return {
            "task": self.task,
            "scenario_id": self._scenario_id,
            "step_count": self._step_count,
            "max_steps": self.max_steps,
            "done": self._done,
            "score": self._score,
            "comments": list(self._comments),
        }

    def _obs(self) -> PRReviewObservation:
        """Build the observation for the current scenario and episode progress."""
        assert self._scenario is not None
        return PRReviewObservation(
            diff=self._scenario["diff"],
            pr_description=self._scenario["pr_description"],
            pr_title=self._scenario["pr_title"],
            comments_so_far=[{"body": c} for c in self._comments],
            step_count=self._step_count,
            done=self._done,
            scenario_id=self._scenario_id or "",
        )

    def _comment_reward(self, body: str) -> float:
        """Return the unclamped shaped reward for a comment with text *body*.

        Each ground-truth bug is worth ``_BUG_POOL / len(bugs)`` and is paid
        at most once per episode. Empty comments, scenarios without bugs, and
        comments that identify nothing new earn ``_FALSE_POS``.
        """
        if not body:
            return _FALSE_POS
        assert self._scenario is not None
        bugs: list = self._scenario["ground_truth"].get("bugs", [])
        if not bugs:
            return _FALSE_POS
        # set() stops a bug index that check_comment reports twice from being paid twice.
        newly_found = set(check_comment(body, bugs)) - self._rewarded_bugs
        if not newly_found:
            return _FALSE_POS
        self._rewarded_bugs |= newly_found
        return len(newly_found) * (_BUG_POOL / len(bugs))

    def _terminal_step(
        self, decision: str
    ) -> tuple[PRReviewObservation, PRReviewReward, bool, dict]:
        """Grade *decision* ("approve" or "reject") and end the episode.

        The grader's score is clamped, stored as the episode score, and
        written back into the result dict. The step reward is
        ``_DECISION_CORRECT`` or ``_DECISION_WRONG`` based on the grader's
        ``decision_correct`` flag. The result dict is returned both as the
        reward breakdown and as *info*.
        """
        assert self._scenario is not None
        result = grade(
            ground_truth=self._scenario["ground_truth"],
            comments=self._comments,
            decision=decision,
        )
        self._done = True
        self._score = clamp_value(result["score"])
        result["score"] = self._score
        decision_reward = _DECISION_CORRECT if result["decision_correct"] else _DECISION_WRONG
        clipped_reward = clamp_value(decision_reward)
        return self._obs(), PRReviewReward(value=clipped_reward, breakdown=result), True, result