# code-revieww-env / server /environment.py
# codemaverick2
# Add diversity/exploration bonuses, near-miss type check, context truncation
# 78f3eb2
"""
Core environment logic for the Code Review Environment.
"""
from __future__ import annotations
import random
import uuid
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from typing import Optional, List, Dict, Any, Set
from models import Issue, ReviewAction, ReviewObservation, ReviewState
from tasks.data import ALL_TASKS, TASK_IDS
from server.graders import (
grade_episode, compute_live_score, match_issue, match_quality,
compute_code_metadata, grade_episode_detailed,
graduated_near_reward, compute_potential, compute_code_state_features,
)
# openenv is an optional dependency: when it is importable the environment
# subclasses its Environment base class; otherwise a bare stand-in class is
# used so this module can still be imported.
try:
    from openenv.core.env_server import Environment as _BaseEnv
    _HAS_OPENENV = True
except ImportError:
    _HAS_OPENENV = False

    class _BaseEnv:  # type: ignore[no-redef]
        """Empty stand-in base class used when openenv is not installed."""
# Reward constants: per-step shaping values applied by CodeReviewEnvironment.
_BASE_TP_REWARD = 0.10  # base reward for a flag that exactly matches an unmatched ground-truth issue
_NEAR_MISS_REWARD = 0.03  # NOTE(review): not referenced in the visible code; near-miss credit comes from graduated_near_reward()
_BASE_FP_PENALTY = -0.05  # base penalty for a flag that matches nothing
_SEVERITY_EXACT_BONUS = 0.02  # added when the flagged severity equals the ground-truth severity
_TEMPORAL_BONUS = 0.02  # added when a correct flag lands within the first 40% of max_steps
_CONFIDENCE_TP_BONUS = 0.01  # added to a true positive flagged with confidence >= 0.7
_CONFIDENCE_FP_EXTRA = -0.03  # added (not multiplied) to a false positive flagged with confidence >= 0.7
_HINT_COST = -0.01  # charged on every request_hint action, even when no hint is left
_REMOVE_TP_PENALTY = -0.03  # charged when clear_flag removes a correct finding
_REMOVE_FP_REWARD = 0.03  # granted when clear_flag removes a flag that was not a correct finding
_VALIDATION_PENALTY = -0.02  # charged when flag_issue/clear_flag lack line_number or filename
# Flood protection: escalating FP penalty
_FP_FLOOD_THRESHOLD = 3  # number of FPs tolerated before the extra penalty kicks in
_FP_FLOOD_MULTIPLIER = 1.5  # extra penalty is -0.02 * (FPs beyond threshold) * this factor, i.e. it grows linearly
# Diversity bonus: reward for covering a new issue category
_DIVERSITY_BONUS = 0.02  # first TP in a new issue_type category
# Exploration bonus: first flag in a previously unflagged file
_FILE_EXPLORATION_BONUS = 0.01  # only granted on a true positive, and only for tasks with more than one file
# Ordinal ranking of severity labels.
# NOTE(review): not referenced in the visible code; presumably kept for callers elsewhere.
_SEV_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}
class CodeReviewEnvironment(_BaseEnv):
    """
    A code review and security audit RL environment.

    The agent receives code files and must identify bugs, security
    vulnerabilities, and performance issues by flagging them with
    exact line numbers, types, and severity ratings.

    Reward design:
    - True positive flag: +0.10 base, +0.02 severity exact match,
      +0.02 early (first 40% steps), +0.01 high-confidence TP,
      potential-based shaping (compute_potential difference),
      +0.02 first TP of a new issue type,
      +0.01 first flag in a new file (multi-file tasks only)
    - Near-miss: graduated partial credit (graduated_near_reward),
      only against issues that have not been found yet
    - False positive: -0.05 base, escalating penalty after 3rd FP,
      extra -0.03 for high-confidence FP
    - Clear false positive: +0.03
    - Clear true positive: -0.03
    - Hint: -0.01
    - Missing line_number/filename on flag or clear: -0.02
    - Submit: final F1+severity score (0.0–1.0)
    - Auto-end (max_steps): full grade score (no penalty)

    Bookkeeping: every accepted flag is classified once, at flag time, as
    true positive, near-miss, or false positive. That classification is
    remembered per (filename, line_number) so that clear_flag undoes exactly
    what flag_issue recorded.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    def __init__(self) -> None:
        """Create an environment with no active episode; call reset() to start one."""
        super().__init__()
        self._state = ReviewState()
        self._task: Optional[dict] = None
        self._ground_truth: List[Issue] = []
        self._hint_index: int = 0
        self._code_metadata: Dict[str, Any] = {}
        self._fp_count: int = 0  # false positives currently flagged this episode
        self._matched_gt_indices: Set[int] = set()  # GT indices already matched
        self._episode_rewards: List[float] = []  # for VL return normalization
        self._found_categories: Set[str] = set()  # issue types already found (for diversity bonus)
        self._flagged_files: Set[str] = set()  # files already flagged (for exploration bonus)
        # (filename, line_number) -> GT index, for flags that were scored as true positives
        self._tp_flag_to_gt: Dict[tuple, int] = {}
        # (filename, line_number) of flags that were scored as false positives
        self._fp_flag_keys: Set[tuple] = set()

    def reset(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> ReviewObservation:
        """Start a new review episode.

        Args:
            task_id: Key into ALL_TASKS. When None or unknown, a task is
                picked at random from TASK_IDS.
            seed: When given, seeds the global ``random`` module before the
                random task choice.
            episode_id: Identifier stored on the state; a UUID4 string is
                generated when omitted.
            **kwargs: Accepted and ignored.

        Returns:
            The initial observation, including the code files and metadata.
        """
        if seed is not None:
            random.seed(seed)
        if task_id is None or task_id not in ALL_TASKS:
            task_id = random.choice(TASK_IDS)

        self._task = ALL_TASKS[task_id]
        self._ground_truth = [
            Issue.from_dict(gt)
            for gt in self._task["ground_truth_issues"]
        ]

        # Clear all per-episode bookkeeping.
        self._hint_index = 0
        self._fp_count = 0
        self._matched_gt_indices = set()
        self._episode_rewards = []
        self._found_categories = set()
        self._flagged_files = set()
        self._tp_flag_to_gt = {}
        self._fp_flag_keys = set()

        self._state = ReviewState(
            task_id=task_id,
            difficulty=self._task["difficulty"],
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            flagged_issues=[],
            current_score=0.0,
            submitted=False,
        )

        # Sorted so the metadata handed to the agent is deterministic.
        issue_categories = sorted({gt.issue_type for gt in self._ground_truth})
        self._code_metadata = compute_code_metadata(
            self._task["code_files"],
            issue_categories=issue_categories,
        )
        # Pre-compute initial state features (progress=empty at reset)
        self._code_metadata["state_features"] = compute_code_state_features(
            self._code_metadata, progress={}
        )

        return ReviewObservation(
            task_id=task_id,
            task_description=self._task["description"],
            code_files=self._task["code_files"],
            language=self._task.get("language", "python"),
            flagged_issues=[],
            step_count=0,
            max_steps=self._task["max_steps"],
            hints_remaining=len(self._task.get("hints", [])),
            feedback=(
                f"New episode started. Task: {self._task['difficulty'].upper()}. "
                f"Review the code carefully and flag all issues you find. "
                f"Use 'submit_review' when done. "
                f"Issue categories present: {issue_categories}."
            ),
            current_score=0.0,
            done=False,
            reward=None,
            reward_breakdown={},
            progress={},
            flagged_summary={},
            code_metadata=self._code_metadata,
        )

    def step(
        self,
        action: ReviewAction,
        timeout_s: Optional[float] = None,
        **kwargs,
    ) -> ReviewObservation:
        """Process one agent action and return the new observation.

        Args:
            action: A ReviewAction, or a dict convertible via
                ReviewAction.from_dict.
            timeout_s: Accepted for interface compatibility; unused here.
            **kwargs: Accepted and ignored.

        Returns:
            The observation after the action. ``done`` is True once the
            review was submitted or the step budget is exhausted; in that
            case ``reward`` and ``current_score`` carry the final grade.
        """
        if self._task is None:
            return ReviewObservation(
                done=True,
                reward=0.0,
                feedback="Episode not initialized. Call reset() first.",
            )
        if self._state.submitted:
            return ReviewObservation(
                task_id=self._state.task_id,
                task_description="",
                code_files={},
                flagged_issues=list(self._state.flagged_issues),
                step_count=self._state.step_count,
                max_steps=self._task["max_steps"],
                hints_remaining=0,
                feedback="Episode already submitted. Call reset() to start a new episode.",
                current_score=self._state.current_score,
                done=True,
                reward=0.0,
            )
        if isinstance(action, dict):
            action = ReviewAction.from_dict(action)

        self._state.step_count += 1
        reward, feedback, reward_breakdown = self._process_action(action)

        max_steps = self._task["max_steps"]
        auto_end = self._state.step_count >= max_steps and not self._state.submitted
        if auto_end:
            # Auto-end: grade in full (no penalty for hitting step limit)
            final = grade_episode(self._state.flagged_issues, self._ground_truth)
            self._state.current_score = final
            reward = final  # full score, no 0.5x penalty
            reward_breakdown = {"auto_end_grade": final, "total": final}
            feedback += (
                f" Step budget exhausted — auto-graded: {final:.3f}. "
                f"Submit earlier next time for slightly cleaner feedback."
            )
            self._state.submitted = True
        done = self._state.submitted

        # Track episode rewards for VL return normalization. Recorded after
        # the auto-end override so the history holds the reward actually
        # returned to the agent.
        if reward is not None:
            self._episode_rewards.append(float(reward))

        if done:
            # The final grade was stored by _handle_submit or the auto-end
            # branch above; do not overwrite it with the live estimate.
            score = self._state.current_score
        else:
            score = compute_live_score(self._state.flagged_issues, self._ground_truth)
            self._state.current_score = score

        progress = self._compute_progress(max_steps)
        flagged_summary = self._compute_flagged_summary()

        # PRM-style dense signal: expected reward-to-go
        # Based on Process Reward Models research: give agent an estimate of
        # how much reward is still available, so it can plan remaining steps.
        tp_found = len(self._matched_gt_indices)
        total_gt = len(self._ground_truth)
        issues_remaining = total_gt - tp_found
        # Expected: each remaining TP gives ~0.12 (base + avg severity bonus)
        expected_reward_to_go = round(issues_remaining * 0.12, 3)

        return ReviewObservation(
            task_id=self._state.task_id,
            task_description="",
            code_files={},
            language=self._task.get("language", "python"),
            flagged_issues=list(self._state.flagged_issues),
            step_count=self._state.step_count,
            max_steps=max_steps,
            hints_remaining=max(0, len(self._task.get("hints", [])) - self._hint_index),
            feedback=feedback,
            current_score=score,
            done=done,
            reward=reward,
            reward_breakdown=reward_breakdown,
            progress=progress,
            flagged_summary=flagged_summary,
            code_metadata={},  # Only populated on reset
            metadata={
                "issues_remaining": issues_remaining,
                "expected_reward_to_go": expected_reward_to_go,
            },
        )

    @property
    def state(self) -> ReviewState:
        """The current episode state object."""
        return self._state

    # ------------------------------------------------------------------
    # Progress and summary helpers
    # ------------------------------------------------------------------
    def _compute_progress(self, max_steps: int) -> Dict[str, Any]:
        """Compute live precision/recall/f1, step stats, and unfound issue types.

        Flags are matched greedily against ground truth with match_issue;
        each ground-truth issue can be claimed by at most one flag.
        """
        flagged = self._state.flagged_issues
        gt = self._ground_truth

        tp = 0
        fp = 0
        matched: Set[int] = set()
        found_types: Set[str] = set()
        for flag in flagged:
            hit = False
            for i, g in enumerate(gt):
                if i not in matched and match_issue(flag, g):
                    tp += 1
                    matched.add(i)
                    found_types.add(g.issue_type)
                    hit = True
                    break
            if not hit:
                fp += 1

        fn = len(gt) - len(matched)
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0

        all_types = {g.issue_type for g in gt}
        unfound_types = sorted(all_types - found_types)
        steps_used = self._state.step_count
        steps_remaining = max(0, max_steps - steps_used)

        # Variable-Length Return Normalization (VL Norm 2025):
        # normalized_return = cumulative_reward / max(steps_used, 1)
        # This makes return comparable across episodes of different length,
        # which is key for multi-task RL where tasks have different max_steps.
        cumulative_reward = sum(self._episode_rewards)
        normalized_return = round(cumulative_reward / max(steps_used, 1), 4)

        progress = {
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "f1": round(f1, 4),
            "true_positives": float(tp),
            "false_positives": float(fp),
            "total_ground_truth": float(len(gt)),
            "steps_used": float(steps_used),
            "steps_remaining": float(steps_remaining),
            "unfound_issue_types": unfound_types,
            "normalized_return": normalized_return,
            "cumulative_reward": round(cumulative_reward, 4),
        }
        # 12-dim state feature vector for RL policy/value networks (code2vec/PBRS literature)
        progress["state_features"] = compute_code_state_features(
            self._code_metadata, progress=progress
        )
        return progress

    def _compute_flagged_summary(self) -> Dict[str, Any]:
        """Compute correct/incorrect/near_miss counts over the current flags.

        A flag is "correct" when match_issue pairs it with a not-yet-claimed
        ground-truth issue, a "near miss" when match_quality reports "near"
        for some not-yet-claimed issue, and "incorrect" otherwise.
        """
        flagged = self._state.flagged_issues
        gt = self._ground_truth

        correct = 0
        near_misses = 0
        incorrect = 0
        matched_gt: Set[int] = set()
        for flag in flagged:
            matched = False
            for i, g in enumerate(gt):
                if i in matched_gt:
                    continue
                if match_issue(flag, g):
                    correct += 1
                    matched_gt.add(i)
                    matched = True
                    break
            if matched:
                continue
            is_near = any(
                match_quality(flag, g) == "near"
                for i, g in enumerate(gt)
                if i not in matched_gt
            )
            if is_near:
                near_misses += 1
            else:
                incorrect += 1

        return {
            "total_flagged": len(flagged),
            "correct": correct,
            "incorrect": incorrect,
            "near_misses": near_misses,
        }

    # ------------------------------------------------------------------
    # Action handlers
    # ------------------------------------------------------------------
    def _process_action(self, action: ReviewAction):
        """Dispatch an action to its handler.

        Returns:
            A ``(reward, feedback, reward_breakdown)`` tuple. Unknown action
            types yield zero reward and an explanatory message.
        """
        atype = (action.action_type or "").strip().lower()
        if atype == "flag_issue":
            return self._handle_flag(action)
        elif atype == "clear_flag":
            return self._handle_clear(action)
        elif atype == "request_hint":
            return self._handle_hint()
        elif atype == "submit_review":
            return self._handle_submit()
        else:
            return 0.0, (
                f"Unknown action_type '{action.action_type}'. "
                "Use: flag_issue | clear_flag | request_hint | submit_review"
            ), {}

    def _handle_flag(self, action: ReviewAction):
        """Record a new flag and score it as TP, near-miss, or FP.

        Unknown issue_type/severity values are coerced to "bug"/"medium".
        A second flag on an already-flagged (filename, line_number) is
        rejected with zero reward and is not recorded.

        Returns:
            A ``(reward, feedback, reward_breakdown)`` tuple.
        """
        if action.line_number is None:
            return _VALIDATION_PENALTY, "flag_issue requires 'line_number'.", {"validation_penalty": _VALIDATION_PENALTY}
        if not action.filename:
            return _VALIDATION_PENALTY, "flag_issue requires 'filename'.", {"validation_penalty": _VALIDATION_PENALTY}
        if action.issue_type not in ("bug", "security", "performance", "logic", None):
            action.issue_type = "bug"
        if action.severity not in ("low", "medium", "high", "critical", None):
            action.severity = "medium"

        # Duplicate check
        for existing in self._state.flagged_issues:
            if (existing.line_number == action.line_number
                    and existing.filename == action.filename):
                return 0.0, (
                    f"Line {action.line_number} in {action.filename} already flagged. "
                    "Use clear_flag first to change it."
                ), {"duplicate": 0.0}

        new_issue = Issue(
            line_number=action.line_number,
            filename=action.filename or "",
            issue_type=action.issue_type or "bug",
            severity=action.severity or "medium",
            description=action.description or "",
            fix_suggestion=action.fix_suggestion,
        )
        flag_key = (action.filename, action.line_number)

        # Track file exploration
        is_new_file = action.filename not in self._flagged_files
        self._flagged_files.add(action.filename)

        # Classify: TP, near-miss (with line distance), or FP.
        # Only ground-truth issues that are still unfound are considered, so
        # a flag next to an already-confirmed issue earns no partial credit
        # and the agent is not told to "adjust" onto a line it already has.
        matched_gt_issue: Optional[Issue] = None
        matched_gt_idx: Optional[int] = None
        near_line_diff: Optional[int] = None
        for i, gt in enumerate(self._ground_truth):
            if i in self._matched_gt_indices:
                continue
            q = match_quality(new_issue, gt)
            if q == "exact":
                matched_gt_issue = gt
                matched_gt_idx = i
                break
            if q == "near":
                # Keep the distance to the closest near issue.
                diff = abs(new_issue.line_number - gt.line_number)
                if near_line_diff is None or diff < near_line_diff:
                    near_line_diff = diff

        self._state.flagged_issues.append(new_issue)

        # PBRS: compute potential before and after this flag
        tp_before = len(self._matched_gt_indices)
        total_gt = len(self._ground_truth)
        reward_breakdown: Dict[str, float] = {}

        if matched_gt_issue is not None and matched_gt_idx is not None:
            self._matched_gt_indices.add(matched_gt_idx)
            self._tp_flag_to_gt[flag_key] = matched_gt_idx
            tp_after = len(self._matched_gt_indices)

            base_reward = _BASE_TP_REWARD
            reward_breakdown["base_tp"] = base_reward

            # Severity exact match bonus
            severity_bonus = 0.0
            if new_issue.severity == matched_gt_issue.severity:
                severity_bonus = _SEVERITY_EXACT_BONUS
                reward_breakdown["severity_exact"] = severity_bonus

            # Temporal bonus: TP caught in first 40% of max_steps
            max_steps = self._task["max_steps"]
            early_threshold = max(1, int(max_steps * 0.4))
            temporal_bonus = 0.0
            if self._state.step_count <= early_threshold:
                temporal_bonus = _TEMPORAL_BONUS
                reward_breakdown["temporal_bonus"] = temporal_bonus

            # Confidence calibration: high confidence TP → small bonus
            confidence_bonus = 0.0
            if action.confidence is not None and action.confidence >= 0.7:
                confidence_bonus = _CONFIDENCE_TP_BONUS
                reward_breakdown["confidence_bonus"] = confidence_bonus

            # PBRS: Φ(s') - Φ(s) (potential-based shaping, policy-invariant)
            phi_before = compute_potential(tp_before, total_gt)
            phi_after = compute_potential(tp_after, total_gt)
            pbrs_bonus = round(phi_after - phi_before, 4)
            reward_breakdown["pbrs_shaping"] = pbrs_bonus

            # Diversity bonus: first TP in a new issue category
            diversity_bonus = 0.0
            gt_type = matched_gt_issue.issue_type
            if gt_type not in self._found_categories:
                self._found_categories.add(gt_type)
                diversity_bonus = _DIVERSITY_BONUS
                reward_breakdown["diversity_bonus"] = diversity_bonus

            # Exploration bonus: first flag in a new file (multi-file tasks)
            exploration_bonus = 0.0
            if is_new_file and len(self._task.get("code_files", {})) > 1:
                exploration_bonus = _FILE_EXPLORATION_BONUS
                reward_breakdown["exploration_bonus"] = exploration_bonus

            reward = (base_reward + severity_bonus + temporal_bonus +
                      confidence_bonus + pbrs_bonus + diversity_bonus + exploration_bonus)
            reward_breakdown["total"] = round(reward, 4)

            sev_note = f", severity +{severity_bonus:.2f}" if severity_bonus else ""
            temp_note = f", early +{temporal_bonus:.2f}" if temporal_bonus else ""
            conf_note = f", conf +{confidence_bonus:.2f}" if confidence_bonus else ""
            pbrs_note = f", progress +{pbrs_bonus:.2f}" if pbrs_bonus > 0 else ""
            div_note = f", new-type +{diversity_bonus:.2f}" if diversity_bonus else ""
            file_note = f", new-file +{exploration_bonus:.2f}" if exploration_bonus else ""
            feedback = (
                f"Correct! Issue at {action.filename}:{action.line_number} confirmed. "
                f"[+{reward:.2f}{sev_note}{temp_note}{conf_note}{pbrs_note}{div_note}{file_note}]"
            )
        elif near_line_diff is not None:
            # Graduated near-miss: partial credit that depends on line distance
            near_reward = graduated_near_reward(near_line_diff)
            reward_breakdown["near_miss"] = near_reward
            reward_breakdown["line_diff"] = float(near_line_diff)
            reward_breakdown["total"] = near_reward
            feedback = (
                f"Close! Near a real issue at {action.filename}:{action.line_number}. "
                f"[+{near_reward:.3f}, {near_line_diff} lines off, adjust line number]"
            )
            reward = near_reward
        else:
            # False positive — with flood protection
            self._fp_count += 1
            self._fp_flag_keys.add(flag_key)
            base_penalty = _BASE_FP_PENALTY
            reward_breakdown["base_fp"] = base_penalty

            # Escalating penalty after FP_FLOOD_THRESHOLD FPs
            flood_penalty = 0.0
            if self._fp_count > _FP_FLOOD_THRESHOLD:
                extra = self._fp_count - _FP_FLOOD_THRESHOLD
                flood_penalty = round(-0.02 * extra * _FP_FLOOD_MULTIPLIER, 3)
                reward_breakdown["flood_penalty"] = flood_penalty

            # High-confidence FP: extra penalty
            confidence_penalty = 0.0
            if action.confidence is not None and action.confidence >= 0.7:
                confidence_penalty = _CONFIDENCE_FP_EXTRA
                reward_breakdown["confidence_penalty"] = confidence_penalty

            reward = base_penalty + flood_penalty + confidence_penalty
            reward_breakdown["total"] = round(reward, 4)

            flood_note = f", over-flagging -{abs(flood_penalty):.2f}" if flood_penalty else ""
            conf_note = f", high-confidence penalty {confidence_penalty:.2f}" if confidence_penalty else ""
            feedback = (
                f"No match at {action.filename}:{action.line_number}. "
                f"[{reward:.2f} — false positive{flood_note}{conf_note}]"
            )

        return reward, feedback, reward_breakdown

    def _handle_clear(self, action: ReviewAction):
        """Remove the flag at (filename, line_number) and score the removal.

        The removal is scored from the classification recorded when the flag
        was placed: removing a true positive is penalised and frees its
        ground-truth issue; removing anything else is rewarded. The FP
        counter is only decremented for flags that had incremented it.

        Returns:
            A ``(reward, feedback, reward_breakdown)`` tuple.
        """
        if action.line_number is None or not action.filename:
            return _VALIDATION_PENALTY, "clear_flag requires 'line_number' and 'filename'.", {"validation_penalty": _VALIDATION_PENALTY}

        remaining = [
            f for f in self._state.flagged_issues
            if not (f.line_number == action.line_number and f.filename == action.filename)
        ]
        if len(remaining) == len(self._state.flagged_issues):
            return 0.0, (
                f"No flagged issue found at {action.filename}:{action.line_number}."
            ), {"no_op": 0.0}
        self._state.flagged_issues = remaining

        flag_key = (action.filename, action.line_number)
        gt_idx = self._tp_flag_to_gt.pop(flag_key, None)
        if gt_idx is not None:
            # Un-track exactly the ground-truth issue this flag had matched
            self._matched_gt_indices.discard(gt_idx)
            reward = _REMOVE_TP_PENALTY
            reward_breakdown = {"removed_tp": reward, "total": reward}
            feedback = (
                f"Removed a correct finding at {action.filename}:{action.line_number}. "
                f"[{reward:.2f}]"
            )
        else:
            # Removing a FP — decrement counter. Near-misses never counted as
            # FPs, so they leave the counter alone.
            if flag_key in self._fp_flag_keys:
                self._fp_flag_keys.discard(flag_key)
                self._fp_count = max(0, self._fp_count - 1)
            reward = _REMOVE_FP_REWARD
            reward_breakdown = {"removed_fp": reward, "total": reward}
            feedback = (
                f"Removed a false positive at {action.filename}:{action.line_number}. "
                f"[+{reward:.2f} — good correction]"
            )

        return reward, feedback, reward_breakdown

    def _handle_hint(self):
        """Return a hint and charge the hint cost.

        A state-dependent hint is preferred when one applies; otherwise the
        next unused static hint from the task is served. The cost is charged
        even when no static hint is left.

        Returns:
            A ``(reward, feedback, reward_breakdown)`` tuple.
        """
        hints = self._task.get("hints", [])
        adaptive_hint = self._get_adaptive_hint()
        if adaptive_hint:
            return _HINT_COST, f"Hint: {adaptive_hint} ({_HINT_COST} reward)", {"hint_cost": _HINT_COST}
        if self._hint_index >= len(hints):
            return _HINT_COST, "No more hints available for this task.", {"hint_cost": _HINT_COST}
        hint = hints[self._hint_index]
        self._hint_index += 1
        remaining = len(hints) - self._hint_index
        return _HINT_COST, f"Hint {self._hint_index}/{len(hints)}: {hint} ({remaining} hints left)", {"hint_cost": _HINT_COST}

    def _get_adaptive_hint(self) -> Optional[str]:
        """Generate a context-aware hint based on current episode state.

        Returns None when no state-dependent hint applies, so the caller
        falls back to the task's static hints.
        """
        flagged = self._state.flagged_issues
        gt = self._ground_truth
        if not gt:
            return None

        tp_count = len(self._matched_gt_indices)
        # Use the counter maintained at flag/clear time. Re-deriving it as
        # "flags minus TPs minus near flags" subtracts a TP twice whenever
        # that TP also sits near another issue.
        fp_count = self._fp_count
        issue_categories = self._code_metadata.get("issue_categories", [])

        # Many false positives: over-flagging
        if fp_count > tp_count and fp_count >= 2:
            return (
                "You are over-flagging. Focus only on confident, concrete findings. "
                "Consider using clear_flag to remove uncertain flags."
            )

        # No correct flags at all yet
        if len(flagged) > 0 and tp_count == 0:
            if issue_categories:
                cats = ", ".join(sorted(set(issue_categories)))
                return (
                    f"Focus on [{cats}] issues. "
                    "None of your current flags match real issues. Re-examine carefully."
                )

        # Found some but missed whole categories
        if tp_count > 0 and issue_categories:
            found_types = {gt[i].issue_type for i in self._matched_gt_indices}
            missed = sorted(set(issue_categories) - found_types)
            if missed:
                missed_str = ", ".join(missed)
                return (
                    f"Good progress! You've found some issues but haven't flagged any "
                    f"[{missed_str}] issues yet — look again for those specifically."
                )

        return None  # Fall through to static hints

    def _handle_submit(self):
        """Finalize the episode and grade the flagged issues.

        Marks the state as submitted, stores the final grade as the current
        score, and builds a feedback summary.

        Returns:
            A ``(final_score, feedback, reward_breakdown)`` tuple.
        """
        self._state.submitted = True
        final_score = grade_episode(self._state.flagged_issues, self._ground_truth)
        self._state.current_score = final_score

        tp_count = len(self._matched_gt_indices)
        total_gt = len(self._ground_truth)
        total_flagged = len(self._state.flagged_issues)
        # Everything that is not a confirmed TP, near-misses included
        fp_count = total_flagged - tp_count

        # Breakdown for detailed feedback
        detailed = grade_episode_detailed(self._state.flagged_issues, self._ground_truth)
        feedback = (
            f"Review submitted! Final score: {final_score:.3f}. "
            f"Found {tp_count}/{total_gt} issues. "
            f"Precision: {detailed['precision']:.2f}, Recall: {detailed['recall']:.2f}, "
            f"F1: {detailed['f1']:.2f}. "
        )
        if fp_count > 0:
            feedback += f"{fp_count} false positive(s). "
        if detailed["false_negatives"] > 0:
            fn = detailed["false_negatives"]
            feedback += f"{fn} issue(s) missed."

        reward_breakdown = {
            "final_f1": detailed["f1"],
            "severity_accuracy": detailed["severity_accuracy"],
            "final_score": final_score,
            "total": final_score,
        }
        return final_score, feedback, reward_breakdown