shivam2k3's picture
new updated proj
8f4672f
"""
CodeReviewEnv — OpenEnv environment for AI-driven code review and bug triage.
An agent is given a pull request diff and must:
1. Identify all bugs / issues present
2. Classify each issue by severity (critical / high / medium / low)
3. Suggest a concrete fix for each issue
The environment tracks which issues have been found, rewards partial progress,
and penalizes hallucinated (non-existent) issues.
"""
from __future__ import annotations
import copy
import json
import re
import time
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from tasks.registry import TASK_REGISTRY
# ---------------------------------------------------------------------------
# Typed OpenEnv models
# ---------------------------------------------------------------------------
class ReviewComment(BaseModel):
    """
    One agent-authored review comment describing a single suspected issue.

    ``severity`` is a plain string; the allowed values are documented in the
    field description but are not validated by this model.
    """

    issue_id: str = Field(default=..., description="Agent-assigned identifier, e.g. 'bug-1'")
    severity: str = Field(default=..., description="One of: critical, high, medium, low")
    description: str = Field(default=..., description="Human-readable description of the issue")
    line_hint: Optional[int] = Field(default=None, description="Approximate line number in the diff")
    fix_suggestion: str = Field(default=..., description="Concrete fix or recommendation")
class Action(BaseModel):
    """
    A single agent move. Callers are expected to populate exactly one field per
    step (this model does not enforce that):

    - submit_comment : attach one review comment describing one issue
    - approve        : declare the PR safe to merge (ends the episode)
    - request_changes: declare that the PR needs changes (ends the episode)
    - pass_step      : skip this step (costs a small time penalty)
    """

    submit_comment: Optional[ReviewComment] = Field(default=None)
    # Set to True to approve the PR.
    approve: Optional[bool] = Field(default=None)
    # Set to True to request changes on the PR.
    request_changes: Optional[bool] = Field(default=None)
    # Set to True to do nothing this step.
    pass_step: Optional[bool] = Field(default=None)
class Observation(BaseModel):
    """Everything the agent is shown at a given step of the episode."""

    # --- the pull request under review ---
    diff: str = Field(default=..., description="The full PR diff text")
    pr_title: str
    pr_description: str
    # --- episode progress ---
    step: int
    max_steps: int
    comments_so_far: List[ReviewComment] = Field(default_factory=list)
    last_action_feedback: str = Field(default="", description="Feedback on last action")
    done: bool = Field(default=False)
class Reward(BaseModel):
    """A scalar reward together with a named per-component breakdown."""

    value: float = Field(default=...)
    # Component name -> contribution to ``value``.
    breakdown: Dict[str, float] = Field(default_factory=dict)
class EpisodeState(BaseModel):
    """Complete internal bookkeeping for one episode (serialised by ``state()``)."""

    # --- task identity and progress ---
    task_id: str
    step: int
    max_steps: int
    done: bool
    # --- the pull request under review ---
    pr_title: str
    pr_description: str
    diff: str
    # --- agent activity ---
    comments: List[ReviewComment] = Field(default_factory=list)
    # Ground-truth issue IDs the agent has been credited with finding.
    found_issue_ids: List[str] = Field(default_factory=list)
    false_positives: int = Field(default=0)
    # "approve", "request_changes", or None while no decision has been made.
    final_decision: Optional[str] = Field(default=None)
    cumulative_reward: float = Field(default=0.0)
    # Wall-clock timestamp (time.time()) taken when this object is constructed.
    started_at: float = Field(default_factory=time.time)
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
class CodeReviewEnv:
    """
    OpenEnv-compliant code review environment.

    Lifecycle
    ---------
        env = CodeReviewEnv(task_id="easy")
        obs = env.reset()
        while not obs.done:
            action = agent_policy(obs)
            obs, reward, done, info = env.step(action)
        score = env.grade()
    """

    MAX_STEPS = 12

    def __init__(self, task_id: str = "easy"):
        """
        Bind the environment to a task from TASK_REGISTRY.

        Raises
        ------
        ValueError
            If ``task_id`` is not a key of TASK_REGISTRY.
        """
        if task_id not in TASK_REGISTRY:
            raise ValueError(f"Unknown task '{task_id}'. Choose from: {list(TASK_REGISTRY)}")
        self.task_id = task_id
        self._task = TASK_REGISTRY[task_id]
        self._state: Optional[EpisodeState] = None

    # ------------------------------------------------------------------
    # OpenEnv API
    # ------------------------------------------------------------------
    def reset(self) -> Observation:
        """Start a fresh episode. Returns initial observation."""
        self._state = EpisodeState(
            task_id=self.task_id,
            step=0,
            max_steps=self.MAX_STEPS,
            done=False,
            pr_title=self._task["pr_title"],
            pr_description=self._task["pr_description"],
            diff=self._task["diff"],
        )
        return self._make_obs("Welcome. Review the diff and submit comments for each issue you find.")

    def step(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """
        Execute one agent action.

        Returns
        -------
        observation, reward_value, done, info_dict

        Raises
        ------
        RuntimeError
            If reset() was never called, or the episode has already ended.
        """
        if self._state is None:
            raise RuntimeError("Call reset() before step()")
        if self._state.done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")
        s = self._state
        s.step += 1
        reward_val, breakdown, feedback = self._process_action(action)
        s.cumulative_reward += reward_val

        # Terminal condition: an approve / request_changes decision.
        if action.approve or action.request_changes:
            s.done = True
            s.final_decision = "approve" if action.approve else "request_changes"
            # Bonus/penalty for the final decision, added on top of the step reward.
            bonus, bonus_info = self._final_decision_bonus()
            reward_val += bonus
            s.cumulative_reward += bonus
            breakdown.update(bonus_info)
            feedback += f" | Final decision reward: {bonus:+.2f}"

        # Terminal condition: step budget exhausted without a decision.
        if s.step >= s.max_steps and not s.done:
            s.done = True
            feedback += " | Max steps reached — episode ended."

        obs = self._make_obs(feedback)
        info = {
            "breakdown": breakdown,
            # Copy so callers cannot mutate the episode's internal state.
            "found_issues": list(s.found_issue_ids),
            "false_positives": s.false_positives,
            "cumulative_reward": s.cumulative_reward,
        }
        return obs, reward_val, s.done, info

    def state(self) -> dict:
        """Return full internal state as a dict ({} before the first reset())."""
        if self._state is None:
            return {}
        return self._state.model_dump()

    # ------------------------------------------------------------------
    # Grader — call after episode ends to get normalised 0-1 score
    # ------------------------------------------------------------------
    def grade(self) -> float:
        """
        Compute a normalised score in [0, 1] for the episode.

        Score components:
        - Issue recall     : fraction of ground-truth issues found           (50 %)
        - Severity accuracy: issues whose *crediting* comment had the correct
                             severity, over all ground-truth issues          (20 %)
        - Fix quality      : keyword-match proxy on the crediting comment    (20 %)
        - Decision         : final decision equals the expected decision     (10 %)
        minus a false-positive penalty of 0.05 per false positive, capped at 0.3.

        When the task has no ground-truth issues, recall/severity/fix count as
        perfect, but the decision and false-positive terms still apply.

        Returns 0.0 if reset() has never been called.
        """
        if self._state is None:
            return 0.0
        s = self._state
        gt_issues = self._task["ground_truth_issues"]  # list of dicts
        n_gt = len(gt_issues)

        if n_gt == 0:
            recall = sev_score = fix_score = 1.0
        else:
            recall = len(set(s.found_issue_ids)) / n_gt
            # Use the comment that actually earned each issue, not merely the
            # first comment whose keywords happen to overlap with it.
            credited = self._credited_comments()
            severity_hits = 0
            fix_hits = 0
            for gt in gt_issues:
                c = credited.get(gt["id"])
                if c is None:
                    continue
                if c.severity.lower() == gt["severity"].lower():
                    severity_hits += 1
                fix_kws = gt.get("fix_keywords", [])
                if not fix_kws:
                    fix_hits += 1  # no keywords required → full credit
                else:
                    agent_fix = (c.fix_suggestion + " " + c.description).lower()
                    if any(kw.lower() in agent_fix for kw in fix_kws):
                        fix_hits += 1
            sev_score = severity_hits / n_gt
            fix_score = fix_hits / n_gt

        # --- false positive penalty ---
        fp_penalty = min(0.3, s.false_positives * 0.05)

        # --- decision component (no decision counts as wrong) ---
        expected_decision = self._task.get("expected_decision", "request_changes")
        dec_score = 1.0 if s.final_decision == expected_decision else 0.0

        raw = (
            0.50 * recall
            + 0.20 * sev_score
            + 0.20 * fix_score
            + 0.10 * dec_score
            - fp_penalty
        )
        return float(max(0.0, min(1.0, raw)))

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _make_obs(self, feedback: str) -> Observation:
        """Build an Observation from the current episode state plus ``feedback``."""
        s = self._state
        return Observation(
            diff=s.diff,
            pr_title=s.pr_title,
            pr_description=s.pr_description,
            step=s.step,
            max_steps=s.max_steps,
            comments_so_far=list(s.comments),
            last_action_feedback=feedback,
            done=s.done,
        )

    def _process_action(self, action: Action) -> tuple[float, dict, str]:
        """
        Apply the non-terminal part of ``action``; return (reward, breakdown, feedback).

        When several fields are set, precedence is pass_step, then
        submit_comment, then approve/request_changes. Ending the episode on a
        decision is handled by step().
        """
        s = self._state
        reward = 0.0
        breakdown: dict[str, float] = {}
        feedback = ""
        if action.pass_step:
            reward = -0.02  # small cost for wasting a step
            breakdown["pass_penalty"] = reward
            feedback = "Pass noted. No progress made this step."
        elif action.submit_comment is not None:
            c = action.submit_comment
            s.comments.append(c)
            r, bd, fb = self._score_comment(c)
            reward += r
            breakdown.update(bd)
            feedback = fb
        elif action.approve is True or action.request_changes is True:
            feedback = "Decision submitted."
        else:
            reward = -0.05
            breakdown["invalid_action"] = reward
            feedback = "Invalid action structure. Use submit_comment, approve, request_changes, or pass_step."
        breakdown["step_reward"] = reward
        return reward, breakdown, feedback

    def _match_issue(self, comment: ReviewComment, found_ids: Any) -> tuple[Optional[dict], bool]:
        """
        Resolve which ground-truth issue ``comment`` refers to.

        ``found_ids`` is any container supporting ``in`` with issue IDs that are
        already credited. The first matching issue NOT yet found wins, so a
        comment that also overlaps an earlier, already-found issue is not
        mistaken for a duplicate. Returns (issue, False) for a new find,
        (issue, True) when every matching issue was already found, and
        (None, False) when nothing matches.
        """
        first_duplicate: Optional[dict] = None
        for gt in self._task["ground_truth_issues"]:
            if not _comment_matches_issue(comment, gt):
                continue
            if gt["id"] not in found_ids:
                return gt, False
            if first_duplicate is None:
                first_duplicate = gt
        return first_duplicate, first_duplicate is not None

    def _credited_comments(self) -> dict[str, ReviewComment]:
        """Replay the comment history, mapping each found issue ID to the comment that earned it."""
        credited: dict[str, ReviewComment] = {}
        for c in self._state.comments:
            gt, is_duplicate = self._match_issue(c, credited)
            if gt is not None and not is_duplicate:
                credited[gt["id"]] = c
        return credited

    def _score_comment(self, comment: ReviewComment) -> tuple[float, dict, str]:
        """Reward a single comment against ground truth, updating found/false-positive counters."""
        s = self._state
        gt, is_duplicate = self._match_issue(comment, s.found_issue_ids)

        if gt is None:
            # No match → false positive
            s.false_positives += 1
            return -0.08, {"false_positive": -0.08}, f"✗ No matching ground-truth issue for comment '{comment.issue_id}'. False positive penalised."

        if is_duplicate:
            return -0.05, {"duplicate_penalty": -0.05}, f"Issue '{gt['id']}' already found — duplicate comment penalised."

        s.found_issue_ids.append(gt["id"])
        r = 0.3  # base for finding the issue
        bd: dict[str, float] = {"issue_found": 0.3}
        # severity bonus / penalty
        if comment.severity.lower() == gt["severity"].lower():
            r += 0.1
            bd["severity_correct"] = 0.1
        else:
            r -= 0.05
            bd["severity_wrong"] = -0.05
        # fix quality bonus (only when the task defines fix keywords)
        fix_kws = gt.get("fix_keywords", [])
        agent_text = (comment.fix_suggestion + " " + comment.description).lower()
        if fix_kws and any(kw.lower() in agent_text for kw in fix_kws):
            r += 0.1
            bd["fix_quality"] = 0.1
        return r, bd, f"✓ Found real issue '{gt['id']}' (severity: {gt['severity']}). Reward: {r:+.2f}"

    def _final_decision_bonus(self) -> tuple[float, dict]:
        """
        Reward for the final decision stored on the state.

        A correct decision earns 0.2 scaled by recall (recall is 1.0 when the
        task has no ground-truth issues, so a correct approve of a clean PR is
        rewarded). A wrong decision costs a flat 0.1.
        """
        expected = self._task.get("expected_decision", "request_changes")
        s = self._state
        n_gt = len(self._task["ground_truth_issues"])
        recall = len(s.found_issue_ids) / n_gt if n_gt else 1.0
        if s.final_decision == expected:
            bonus = 0.2 * recall  # scales with how many issues were found
            return bonus, {"decision_correct_bonus": bonus}
        return -0.1, {"decision_wrong_penalty": -0.1}
# ---------------------------------------------------------------------------
# Utility
# ---------------------------------------------------------------------------
def _comment_matches_issue(comment: ReviewComment, gt: dict) -> bool:
    """
    Heuristic match between an agent comment and a ground-truth issue.

    Searches the comment's description, issue_id and fix_suggestion
    (case-insensitively, as substrings) for the issue's ``match_keywords``.
    At least 2 distinct keywords must hit, or all of them when fewer than 2
    usable keywords are defined. Returns False when no usable keywords exist.
    """
    raw_keywords = gt.get("match_keywords") or []
    # Normalise the keyword list:
    # - drop empty / whitespace-only / non-string entries, because "" is a
    #   substring of every text and would count as a free hit;
    # - de-duplicate case-insensitively (order preserved), so a repeated
    #   keyword cannot satisfy the two-hit threshold on its own.
    # Surrounding spaces are kept on purpose: " or " may be used to approximate
    # a word-boundary match.
    keywords = list(dict.fromkeys(
        kw.lower() for kw in raw_keywords if isinstance(kw, str) and kw.strip()
    ))
    if not keywords:
        return False
    text = " ".join((comment.description, comment.issue_id, comment.fix_suggestion)).lower()
    hits = sum(1 for kw in keywords if kw in text)
    threshold = min(2, len(keywords))
    return hits >= threshold