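# NOTE(review): the four lines below look like repository-page header text that
# was captured together with the file; they are not Python and are kept only
# as comments so the module parses.
#   RL-Improvement / src / env.py
#   Rushhaabhhh's picture
#   RL-Improvement-v1
#   5719d11 verified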
"""PR review simulation environment (gym-style reset/step API)."""
from __future__ import annotations
import glob
import json
import os
import random
from typing import Optional
from .grader import check_comment, grade
from .models import PRReviewAction, PRReviewObservation, PRReviewReward
# --- Reward weights --------------------------------------------------------
# Every reward built from these values is passed through clamp_value() before
# it is handed to the agent.

# Budget shared by all ground-truth bugs of a scenario: each newly matched bug
# pays _BUG_POOL / len(bugs).
_BUG_POOL = 0.68
# Paid for a comment that is incomplete or matches no not-yet-rewarded bug.
_FALSE_POS = 0.01
# Paid on the terminal step when the grader reports the decision as correct...
_DECISION_CORRECT = 0.31
# ...and when it does not.
_DECISION_WRONG = 0.01
# Paid for every read_file action, whether or not the file exists.
_NEUTRAL_EXPLORATION = 0.01

# --- Scenario location and per-task settings -------------------------------
# Scenario JSON files live in <this file's directory>/../data/scenarios.
_SCENARIOS_DIR = os.path.join(
    os.path.dirname(__file__), os.pardir, "data", "scenarios"
)

# Scenario ids are selected per task by this filename prefix.
TASK_PREFIXES = {level: f"{level}_" for level in ("easy", "medium", "hard")}

# Step budgets; sized with some headroom for navigating files.
TASK_MAX_STEPS = {
    "easy": 8,
    "medium": 15,
    "hard": 20,
}

# Per-task score threshold. This module only stores it on the environment;
# presumably the pass mark applied by callers -- TODO confirm.
TASK_THRESHOLDS = {
    "easy": 0.7,
    "medium": 0.6,
    "hard": 0.5,
}
def clamp_value(v: float) -> float:
    """Clamp *v* into the closed range [0.01, 0.99] and round to 4 decimals.

    The input is coerced with ``float()`` first, so ints and numeric strings
    are accepted as well.

    NaN is mapped to the lower bound. Comparisons against NaN are always
    False, so the plain ``max(lo, min(hi, x))`` expression would otherwise
    turn a NaN into the *upper* bound and hand out the maximum value.

    Args:
        v: Value to clamp; anything ``float()`` accepts.

    Returns:
        The clamped value, rounded to four decimal places.

    Raises:
        TypeError, ValueError: If ``float(v)`` cannot convert the input.
    """
    import math  # function-scope import keeps this block self-contained

    value = float(v)
    if math.isnan(value):
        return 0.01
    return round(max(0.01, min(0.99, value)), 4)
def _load_all() -> dict[str, dict]:
    """Read every ``*.json`` scenario in ``_SCENARIOS_DIR`` into a dict.

    Files are processed in sorted path order. The key of each entry is the
    file's base name without its extension; the value is the parsed JSON.

    Returns:
        Mapping of scenario id to scenario data.

    Raises:
        RuntimeError: If the directory contains no ``*.json`` file.
        ValueError: If a scenario lacks one of the required top-level fields
            (the first missing field, in declaration order, is reported).
    """
    scenario_files = sorted(glob.glob(os.path.join(_SCENARIOS_DIR, "*.json")))
    if not scenario_files:
        raise RuntimeError(f"No scenario JSON files found in {_SCENARIOS_DIR}")

    # Fields every scenario must provide.
    required = ("pr_title", "pr_description", "diff", "ground_truth", "repo_files")

    scenarios: dict[str, dict] = {}
    for scenario_file in scenario_files:
        scenario_id, _extension = os.path.splitext(os.path.basename(scenario_file))
        with open(scenario_file, encoding="utf-8") as handle:
            payload = json.load(handle)

        absent = next((name for name in required if name not in payload), None)
        if absent is not None:
            raise ValueError(f"Scenario '{scenario_id}' missing field '{absent}'")

        scenarios[scenario_id] = payload
    return scenarios
# Scenario cache keyed by scenario id, filled once when this module is
# imported. Because _load_all() runs at import time, importing the module
# raises if the scenarios directory holds no *.json file or if a scenario
# lacks a required field.
_STORE: dict[str, dict] = _load_all()
class PRReviewEnv:
    """Gym-style environment in which an agent reviews one pull request.

    An episode begins with :meth:`reset`, which picks a random scenario whose
    id starts with the task's prefix. The agent then calls :meth:`step` with
    ``read_file``, ``comment``, ``approve`` or ``request_changes`` actions
    until it makes a decision or the step budget runs out.
    """

    # Every action_type that step() knows how to handle.
    _ACTION_TYPES = ("read_file", "comment", "approve", "request_changes")

    def __init__(self, task: str = "easy") -> None:
        """Create an environment for one task tier.

        Args:
            task: A key of ``TASK_PREFIXES``; selects the scenario prefix,
                the step budget and the threshold.

        Raises:
            ValueError: If *task* is not a key of ``TASK_PREFIXES``.
        """
        if task not in TASK_PREFIXES:
            raise ValueError(f"Unknown task '{task}'. Valid: {sorted(TASK_PREFIXES)}")
        self.task = task
        self.max_steps: int = TASK_MAX_STEPS[task]
        self.threshold: float = TASK_THRESHOLDS[task]

        # Episode state; (re)initialised by reset().
        self._scenario_id: Optional[str] = None
        self._scenario: Optional[dict] = None
        self._comments: list[dict] = []
        self._step_count: int = 0
        self._done: bool = False
        self._score: Optional[float] = None
        # Indices (as reported by check_comment) of bugs already paid out.
        self._rewarded_bugs: set[int] = set()

        # Navigation state: the file most recently opened via read_file.
        self._current_file_path: str = ""
        self._current_file_content: str = ""

    def reset(self) -> PRReviewObservation:
        """Start a new episode on a randomly chosen scenario of this task.

        Returns:
            The initial observation.

        Raises:
            RuntimeError: If no loaded scenario id starts with the task prefix.
        """
        prefix = TASK_PREFIXES[self.task]
        candidates = [sid for sid in _STORE if sid.startswith(prefix)]
        if not candidates:
            raise RuntimeError(f"No scenarios with prefix '{prefix}'")

        self._scenario_id = random.choice(candidates)
        self._scenario = _STORE[self._scenario_id]
        self._comments = []
        self._step_count = 0
        self._done = False
        self._score = None
        self._rewarded_bugs = set()
        self._current_file_path = ""
        self._current_file_content = ""
        return self._obs()

    def step(self, action: PRReviewAction) -> tuple[PRReviewObservation, PRReviewReward, bool, dict]:
        """Apply one agent action.

        ``read_file`` and ``comment`` keep the episode running; ``approve``
        and ``request_changes`` end it. Once ``max_steps`` actions have been
        taken, the next call ignores *action* and ends the episode as a
        rejection.

        Args:
            action: Object exposing ``action_type`` plus ``file``, ``line``
                and ``body`` as needed by the action type.

        Returns:
            ``(observation, reward, done, info)``. ``info`` is empty for
            non-terminal steps and the grader result on the terminal step.

        Raises:
            RuntimeError: If called before reset() or after the episode ended.
            ValueError: If ``action.action_type`` is not recognised. The step
                counter is left untouched in that case.
        """
        if self._scenario is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode done. Call reset() to start a new one.")
        if self._step_count >= self.max_steps:
            # Budget exhausted: force a rejection regardless of the action.
            return self._terminal_step("reject")

        kind = action.action_type
        # Validate before touching any state so that a rejected call does not
        # silently consume one of the agent's steps.
        if kind not in self._ACTION_TYPES:
            raise ValueError(f"Unknown action_type '{action.action_type}'.")
        self._step_count += 1

        if kind == "read_file":
            repo_files = self._scenario.get("repo_files", {})
            target_file = action.file or ""
            if target_file in repo_files:
                self._current_file_path = target_file
                self._current_file_content = repo_files[target_file]
            else:
                # Only the content changes on a miss; the path keeps pointing
                # at the previously opened file.
                self._current_file_content = f"Error: File '{target_file}' not found."
            # Exploration earns the same small reward whether or not it hit.
            reward = PRReviewReward(value=clamp_value(_NEUTRAL_EXPLORATION))
            return self._obs(), reward, False, {}

        if kind == "comment":
            # Score first, then record; comments without a body are scored
            # (as false positives) but not stored.
            reward_val = self._comment_reward(action.body, action.file, action.line)
            if action.body:
                self._comments.append({
                    "file": action.file,
                    "line": action.line,
                    "body": action.body,
                })
            return self._obs(), PRReviewReward(value=clamp_value(reward_val)), False, {}

        # Remaining validated kinds are the two terminal decisions.
        decision = "approve" if kind == "approve" else "reject"
        return self._terminal_step(decision)

    def state(self) -> dict:
        """Return a plain-dict snapshot of the episode bookkeeping.

        The ``comments`` entry is a shallow copy of the internal list.
        """
        return {
            "task": self.task,
            "scenario_id": self._scenario_id,
            "step_count": self._step_count,
            "max_steps": self.max_steps,
            "done": self._done,
            "score": self._score,
            "comments": list(self._comments),
        }

    def _obs(self) -> PRReviewObservation:
        """Build the observation for the current episode state."""
        assert self._scenario is not None
        repo_files = self._scenario.get("repo_files", {})
        return PRReviewObservation(
            diff=self._scenario["diff"],
            pr_description=self._scenario["pr_description"],
            pr_title=self._scenario["pr_title"],
            file_tree=list(repo_files.keys()),
            current_file_path=self._current_file_path,
            current_file_content=self._current_file_content,
            # Copy, as state() does, so an observation that was already handed
            # out does not grow when later comments are appended.
            comments_so_far=list(self._comments),
            step_count=self._step_count,
            done=self._done,
            scenario_id=self._scenario_id or "",
        )

    def _comment_reward(self, body: str, file: Optional[str], line: Optional[int]) -> float:
        """Return the (unclamped) reward for one review comment.

        A comment that is missing its body, file or line, a scenario without
        ground-truth bugs, or a comment matching no not-yet-rewarded bug all
        yield ``_FALSE_POS``. Otherwise each newly matched bug pays an equal
        share of ``_BUG_POOL`` and is remembered so it is not paid again.
        """
        if not body or not file or line is None:
            return _FALSE_POS
        assert self._scenario is not None
        bugs: list = self._scenario["ground_truth"].get("bugs", [])
        if not bugs:
            return _FALSE_POS

        # check_comment lives in .grader; it is assumed to yield indices of
        # the bugs this comment matches -- verify against its definition.
        matched = check_comment(body, file, line, bugs)
        # The set difference drops already-rewarded bugs and also collapses
        # repeated indices, so one bug is never paid twice for one comment.
        newly_found = set(matched) - self._rewarded_bugs
        if not newly_found:
            return _FALSE_POS

        self._rewarded_bugs.update(newly_found)
        per_bug = _BUG_POOL / len(bugs)
        return len(newly_found) * per_bug

    def _terminal_step(self, decision: str) -> tuple[PRReviewObservation, PRReviewReward, bool, dict]:
        """Grade the episode, mark it done and build the final step tuple.

        Args:
            decision: ``"approve"`` or ``"reject"``.

        Returns:
            ``(observation, reward, True, result)`` where ``result`` is the
            grader's dict with its ``"score"`` replaced by the clamped score.
            The same dict is attached to the reward as ``breakdown``.
        """
        assert self._scenario is not None
        result = grade(
            ground_truth=self._scenario["ground_truth"],
            comments=self._comments,
            decision=decision,
        )
        self._done = True
        self._score = clamp_value(result["score"])
        result["score"] = self._score

        decision_reward = _DECISION_CORRECT if result["decision_correct"] else _DECISION_WRONG
        reward = PRReviewReward(value=clamp_value(decision_reward), breakdown=result)
        return self._obs(), reward, True, result