Spaces:
Running
Running
| """ | |
| Core environment logic for the Code Review Environment. | |
| """ | |
| from __future__ import annotations | |
| import random | |
| import uuid | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from typing import Optional, List, Dict, Any, Set | |
| from models import Issue, ReviewAction, ReviewObservation, ReviewState | |
| from tasks.data import ALL_TASKS, TASK_IDS | |
| from server.graders import ( | |
| grade_episode, compute_live_score, match_issue, match_quality, | |
| compute_code_metadata, grade_episode_detailed, | |
| graduated_near_reward, compute_potential, compute_code_state_features, | |
| ) | |
| try: | |
| from openenv.core.env_server import Environment as _BaseEnv | |
| _HAS_OPENENV = True | |
| except ImportError: | |
| _HAS_OPENENV = False | |
| class _BaseEnv: # type: ignore[no-redef] | |
| pass | |
| # Reward constants | |
| _BASE_TP_REWARD = 0.10 | |
| _NEAR_MISS_REWARD = 0.03 | |
| _BASE_FP_PENALTY = -0.05 | |
| _SEVERITY_EXACT_BONUS = 0.02 # when severity exactly matches GT | |
| _TEMPORAL_BONUS = 0.02 # early correct flag (first 40% of steps) | |
| _CONFIDENCE_TP_BONUS = 0.01 # high-confidence TP | |
| _CONFIDENCE_FP_EXTRA = -0.03 # high-confidence FP (penalty multiplier) | |
| _HINT_COST = -0.01 | |
| _REMOVE_TP_PENALTY = -0.03 | |
| _REMOVE_FP_REWARD = 0.03 | |
| _VALIDATION_PENALTY = -0.02 | |
| # Flood protection: escalating FP penalty | |
| _FP_FLOOD_THRESHOLD = 3 # FPs before escalation kicks in | |
| _FP_FLOOD_MULTIPLIER = 1.5 # each extra FP beyond threshold costs 1.5x more | |
| # Diversity bonus: reward for covering a new issue category | |
| _DIVERSITY_BONUS = 0.02 # first TP in a new issue_type category | |
| # Exploration bonus: first flag in a previously unflagged file | |
| _FILE_EXPLORATION_BONUS = 0.01 | |
| _SEV_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3} | |
class CodeReviewEnvironment(_BaseEnv):
    """
    A code review and security audit RL environment.
    The agent receives code files and must identify bugs, security
    vulnerabilities, and performance issues by flagging them with
    exact line numbers, types, and severity ratings.
    Reward design:
    - True positive flag: +0.10 base, +0.02 severity exact match,
      +0.02 early (first 40% steps), +0.01 high-confidence TP
    - Near-miss (±3-5 lines): +0.03 partial credit
    - False positive: -0.05 base, escalating penalty after 3rd FP,
      extra -0.03 for high-confidence FP
    - Clear false positive: +0.03
    - Clear true positive: -0.03
    - Hint: -0.01
    - Submit: final F1+severity score (0.0–1.0)
    - Auto-end (max_steps): full grade score (no penalty)
    """
    # Only one episode's state is tracked at a time; sessions are not isolated.
    SUPPORTS_CONCURRENT_SESSIONS = False

    def __init__(self) -> None:
        """Initialize empty per-episode bookkeeping; call reset() to start."""
        self._state = ReviewState()
        # Active task dict from ALL_TASKS; None until reset() is called.
        self._task: Optional[dict] = None
        # Ground-truth issues for the active task.
        self._ground_truth: List[Issue] = []
        # Index of the next static hint to serve.
        self._hint_index: int = 0
        self._code_metadata: Dict[str, Any] = {}
        self._fp_count: int = 0  # total false positives this episode
        self._matched_gt_indices: Set[int] = set()  # GT indices already matched
        self._episode_rewards: List[float] = []  # for VL return normalization
        self._found_categories: Set[str] = set()  # issue types already found (for diversity bonus)
        self._flagged_files: Set[str] = set()  # files already flagged (for exploration bonus)
| def reset( | |
| self, | |
| task_id: Optional[str] = None, | |
| seed: Optional[int] = None, | |
| episode_id: Optional[str] = None, | |
| **kwargs, | |
| ) -> ReviewObservation: | |
| """Start a new review episode.""" | |
| if seed is not None: | |
| random.seed(seed) | |
| if task_id is None or task_id not in ALL_TASKS: | |
| task_id = random.choice(TASK_IDS) | |
| self._task = ALL_TASKS[task_id] | |
| self._ground_truth = [ | |
| Issue.from_dict(gt) | |
| for gt in self._task["ground_truth_issues"] | |
| ] | |
| self._hint_index = 0 | |
| self._fp_count = 0 | |
| self._matched_gt_indices = set() | |
| self._episode_rewards = [] | |
| self._found_categories = set() | |
| self._flagged_files = set() | |
| self._state = ReviewState( | |
| task_id=task_id, | |
| difficulty=self._task["difficulty"], | |
| episode_id=episode_id or str(uuid.uuid4()), | |
| step_count=0, | |
| flagged_issues=[], | |
| current_score=0.0, | |
| submitted=False, | |
| ) | |
| issue_categories = list({gt.issue_type for gt in self._ground_truth}) | |
| self._code_metadata = compute_code_metadata( | |
| self._task["code_files"], | |
| issue_categories=issue_categories, | |
| ) | |
| # Pre-compute initial state features (progress=empty at reset) | |
| self._code_metadata["state_features"] = compute_code_state_features( | |
| self._code_metadata, progress={} | |
| ) | |
| return ReviewObservation( | |
| task_id=task_id, | |
| task_description=self._task["description"], | |
| code_files=self._task["code_files"], | |
| language=self._task.get("language", "python"), | |
| flagged_issues=[], | |
| step_count=0, | |
| max_steps=self._task["max_steps"], | |
| hints_remaining=len(self._task.get("hints", [])), | |
| feedback=( | |
| f"New episode started. Task: {self._task['difficulty'].upper()}. " | |
| f"Review the code carefully and flag all issues you find. " | |
| f"Use 'submit_review' when done. " | |
| f"Issue categories present: {sorted(set(issue_categories))}." | |
| ), | |
| current_score=0.0, | |
| done=False, | |
| reward=None, | |
| reward_breakdown={}, | |
| progress={}, | |
| flagged_summary={}, | |
| code_metadata=self._code_metadata, | |
| ) | |
| def step( | |
| self, | |
| action: ReviewAction, | |
| timeout_s: Optional[float] = None, | |
| **kwargs, | |
| ) -> ReviewObservation: | |
| """Process one agent action and return the new observation.""" | |
| if self._task is None: | |
| return ReviewObservation( | |
| done=True, | |
| reward=0.0, | |
| feedback="Episode not initialized. Call reset() first.", | |
| ) | |
| if self._state.submitted: | |
| return ReviewObservation( | |
| task_id=self._state.task_id, | |
| task_description="", | |
| code_files={}, | |
| flagged_issues=list(self._state.flagged_issues), | |
| step_count=self._state.step_count, | |
| max_steps=self._task["max_steps"], | |
| hints_remaining=0, | |
| feedback="Episode already submitted. Call reset() to start a new episode.", | |
| current_score=self._state.current_score, | |
| done=True, | |
| reward=0.0, | |
| ) | |
| if isinstance(action, dict): | |
| action = ReviewAction.from_dict(action) | |
| self._state.step_count += 1 | |
| reward, feedback, reward_breakdown = self._process_action(action) | |
| # Track episode rewards for VL return normalization | |
| if reward is not None: | |
| self._episode_rewards.append(float(reward)) | |
| max_steps = self._task["max_steps"] | |
| auto_end = self._state.step_count >= max_steps and not self._state.submitted | |
| done = self._state.submitted or auto_end | |
| if auto_end and not self._state.submitted: | |
| # Auto-end: grade in full (no penalty for hitting step limit) | |
| final = grade_episode(self._state.flagged_issues, self._ground_truth) | |
| self._state.current_score = final | |
| reward = final # full score, no 0.5x penalty | |
| reward_breakdown = {"auto_end_grade": final, "total": final} | |
| feedback += ( | |
| f" Step budget exhausted — auto-graded: {final:.3f}. " | |
| f"Submit earlier next time for slightly cleaner feedback." | |
| ) | |
| self._state.submitted = True | |
| live = compute_live_score(self._state.flagged_issues, self._ground_truth) | |
| self._state.current_score = live | |
| progress = self._compute_progress(max_steps) | |
| flagged_summary = self._compute_flagged_summary() | |
| # PRM-style dense signal: expected reward-to-go | |
| # Based on Process Reward Models research: give agent an estimate of | |
| # how much reward is still available, so it can plan remaining steps. | |
| tp_found = len(self._matched_gt_indices) | |
| total_gt = len(self._ground_truth) | |
| issues_remaining = total_gt - tp_found | |
| # Expected: each remaining TP gives ~0.12 (base + avg severity bonus) | |
| expected_reward_to_go = round(issues_remaining * 0.12, 3) | |
| return ReviewObservation( | |
| task_id=self._state.task_id, | |
| task_description="", | |
| code_files={}, | |
| language=self._task.get("language", "python"), | |
| flagged_issues=list(self._state.flagged_issues), | |
| step_count=self._state.step_count, | |
| max_steps=max_steps, | |
| hints_remaining=max(0, len(self._task.get("hints", [])) - self._hint_index), | |
| feedback=feedback, | |
| current_score=live, | |
| done=done, | |
| reward=reward, | |
| reward_breakdown=reward_breakdown, | |
| progress=progress, | |
| flagged_summary=flagged_summary, | |
| code_metadata={}, # Only populated on reset | |
| metadata={ | |
| "issues_remaining": issues_remaining, | |
| "expected_reward_to_go": expected_reward_to_go, | |
| }, | |
| ) | |
    def state(self) -> ReviewState:
        """Return the internal episode state object (live reference, not a copy)."""
        return self._state
| # ------------------------------------------------------------------ | |
| # Progress and summary helpers | |
| # ------------------------------------------------------------------ | |
| def _compute_progress(self, max_steps: int) -> Dict[str, Any]: | |
| """Compute live precision/recall/f1, step stats, and unfound issue types.""" | |
| flagged = self._state.flagged_issues | |
| gt = self._ground_truth | |
| tp = 0 | |
| fp = 0 | |
| matched: Set[int] = set() | |
| found_types: Set[str] = set() | |
| for flag in flagged: | |
| hit = False | |
| for i, g in enumerate(gt): | |
| if i not in matched and match_issue(flag, g): | |
| tp += 1 | |
| matched.add(i) | |
| found_types.add(g.issue_type) | |
| hit = True | |
| break | |
| if not hit: | |
| fp += 1 | |
| fn = len(gt) - len(matched) | |
| precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 | |
| recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 | |
| f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0 | |
| all_types = {g.issue_type for g in gt} | |
| unfound_types = sorted(all_types - found_types) | |
| steps_used = self._state.step_count | |
| steps_remaining = max(0, max_steps - steps_used) | |
| # Variable-Length Return Normalization (VL Norm 2025): | |
| # normalized_return = cumulative_reward / max(steps_used, 1) | |
| # This makes return comparable across episodes of different length, | |
| # which is key for multi-task RL where tasks have different max_steps. | |
| cumulative_reward = sum(self._episode_rewards) | |
| normalized_return = round(cumulative_reward / max(steps_used, 1), 4) | |
| progress = { | |
| "precision": round(precision, 4), | |
| "recall": round(recall, 4), | |
| "f1": round(f1, 4), | |
| "true_positives": float(tp), | |
| "false_positives": float(fp), | |
| "total_ground_truth": float(len(gt)), | |
| "steps_used": float(steps_used), | |
| "steps_remaining": float(steps_remaining), | |
| "unfound_issue_types": unfound_types, | |
| "normalized_return": normalized_return, | |
| "cumulative_reward": round(cumulative_reward, 4), | |
| } | |
| # 12-dim state feature vector for RL policy/value networks (code2vec/PBRS literature) | |
| progress["state_features"] = compute_code_state_features( | |
| self._code_metadata, progress=progress | |
| ) | |
| return progress | |
| def _compute_flagged_summary(self) -> Dict[str, Any]: | |
| """Compute correct/incorrect/near_miss counts.""" | |
| flagged = self._state.flagged_issues | |
| gt = self._ground_truth | |
| correct = 0 | |
| near_misses = 0 | |
| incorrect = 0 | |
| matched_gt: Set[int] = set() | |
| for flag in flagged: | |
| matched = False | |
| for i, g in enumerate(gt): | |
| if i in matched_gt: | |
| continue | |
| if match_issue(flag, g): | |
| correct += 1 | |
| matched_gt.add(i) | |
| matched = True | |
| break | |
| if not matched: | |
| is_near = False | |
| for i, g in enumerate(gt): | |
| if i in matched_gt: | |
| continue | |
| if match_quality(flag, g) == "near": | |
| is_near = True | |
| break | |
| if is_near: | |
| near_misses += 1 | |
| else: | |
| incorrect += 1 | |
| return { | |
| "total_flagged": len(flagged), | |
| "correct": correct, | |
| "incorrect": incorrect, | |
| "near_misses": near_misses, | |
| } | |
| # ------------------------------------------------------------------ | |
| # Action handlers | |
| # ------------------------------------------------------------------ | |
| def _process_action(self, action: ReviewAction): | |
| atype = (action.action_type or "").strip().lower() | |
| if atype == "flag_issue": | |
| return self._handle_flag(action) | |
| elif atype == "clear_flag": | |
| return self._handle_clear(action) | |
| elif atype == "request_hint": | |
| return self._handle_hint() | |
| elif atype == "submit_review": | |
| return self._handle_submit() | |
| else: | |
| return 0.0, ( | |
| f"Unknown action_type '{action.action_type}'. " | |
| "Use: flag_issue | clear_flag | request_hint | submit_review" | |
| ), {} | |
| def _handle_flag(self, action: ReviewAction): | |
| if action.line_number is None: | |
| return _VALIDATION_PENALTY, "flag_issue requires 'line_number'.", {"validation_penalty": _VALIDATION_PENALTY} | |
| if not action.filename: | |
| return _VALIDATION_PENALTY, "flag_issue requires 'filename'.", {"validation_penalty": _VALIDATION_PENALTY} | |
| if action.issue_type not in ("bug", "security", "performance", "logic", None): | |
| action.issue_type = "bug" | |
| if action.severity not in ("low", "medium", "high", "critical", None): | |
| action.severity = "medium" | |
| # Duplicate check | |
| for existing in self._state.flagged_issues: | |
| if (existing.line_number == action.line_number | |
| and existing.filename == action.filename): | |
| return 0.0, ( | |
| f"Line {action.line_number} in {action.filename} already flagged. " | |
| "Use clear_flag first to change it." | |
| ), {"duplicate": 0.0} | |
| new_issue = Issue( | |
| line_number=action.line_number, | |
| filename=action.filename or "", | |
| issue_type=action.issue_type or "bug", | |
| severity=action.severity or "medium", | |
| description=action.description or "", | |
| fix_suggestion=action.fix_suggestion, | |
| ) | |
| # Track file exploration | |
| is_new_file = action.filename not in self._flagged_files | |
| if action.filename: | |
| self._flagged_files.add(action.filename) | |
| # Classify: TP, near-miss (with line distance), or FP | |
| is_tp = False | |
| is_near = False | |
| near_line_diff = 0 | |
| matched_gt_issue: Optional[Issue] = None | |
| matched_gt_idx: Optional[int] = None | |
| for i, gt in enumerate(self._ground_truth): | |
| q = match_quality(new_issue, gt) | |
| if q == "exact" and i not in self._matched_gt_indices: | |
| is_tp = True | |
| matched_gt_issue = gt | |
| matched_gt_idx = i | |
| break | |
| elif q == "near" and not is_near: | |
| is_near = True | |
| near_line_diff = abs(new_issue.line_number - gt.line_number) | |
| self._state.flagged_issues.append(new_issue) | |
| # PBRS: compute potential before and after this flag | |
| tp_before = len(self._matched_gt_indices) | |
| total_gt = len(self._ground_truth) | |
| reward_breakdown: Dict[str, float] = {} | |
| if is_tp and matched_gt_issue is not None and matched_gt_idx is not None: | |
| self._matched_gt_indices.add(matched_gt_idx) | |
| tp_after = len(self._matched_gt_indices) | |
| base_reward = _BASE_TP_REWARD | |
| reward_breakdown["base_tp"] = base_reward | |
| # Severity exact match bonus | |
| severity_bonus = 0.0 | |
| if new_issue.severity == matched_gt_issue.severity: | |
| severity_bonus = _SEVERITY_EXACT_BONUS | |
| reward_breakdown["severity_exact"] = severity_bonus | |
| # Temporal bonus: TP caught in first 40% of max_steps | |
| max_steps = self._task["max_steps"] | |
| early_threshold = max(1, int(max_steps * 0.4)) | |
| temporal_bonus = 0.0 | |
| if self._state.step_count <= early_threshold: | |
| temporal_bonus = _TEMPORAL_BONUS | |
| reward_breakdown["temporal_bonus"] = temporal_bonus | |
| # Confidence calibration: high confidence TP → small bonus | |
| confidence_bonus = 0.0 | |
| if action.confidence is not None and action.confidence >= 0.7: | |
| confidence_bonus = _CONFIDENCE_TP_BONUS | |
| reward_breakdown["confidence_bonus"] = confidence_bonus | |
| # PBRS: Φ(s') - Φ(s) (potential-based shaping, policy-invariant) | |
| phi_before = compute_potential(tp_before, total_gt) | |
| phi_after = compute_potential(tp_after, total_gt) | |
| pbrs_bonus = round(phi_after - phi_before, 4) | |
| reward_breakdown["pbrs_shaping"] = pbrs_bonus | |
| # Diversity bonus: first TP in a new issue category | |
| diversity_bonus = 0.0 | |
| gt_type = matched_gt_issue.issue_type | |
| if gt_type not in self._found_categories: | |
| self._found_categories.add(gt_type) | |
| diversity_bonus = _DIVERSITY_BONUS | |
| reward_breakdown["diversity_bonus"] = diversity_bonus | |
| # Exploration bonus: first flag in a new file (multi-file tasks) | |
| exploration_bonus = 0.0 | |
| if is_new_file and len(self._task.get("code_files", {})) > 1: | |
| exploration_bonus = _FILE_EXPLORATION_BONUS | |
| reward_breakdown["exploration_bonus"] = exploration_bonus | |
| reward = (base_reward + severity_bonus + temporal_bonus + | |
| confidence_bonus + pbrs_bonus + diversity_bonus + exploration_bonus) | |
| reward_breakdown["total"] = round(reward, 4) | |
| sev_note = f", severity +{severity_bonus:.2f}" if severity_bonus else "" | |
| temp_note = f", early +{temporal_bonus:.2f}" if temporal_bonus else "" | |
| conf_note = f", conf +{confidence_bonus:.2f}" if confidence_bonus else "" | |
| pbrs_note = f", progress +{pbrs_bonus:.2f}" if pbrs_bonus > 0 else "" | |
| div_note = f", new-type +{diversity_bonus:.2f}" if diversity_bonus else "" | |
| feedback = ( | |
| f"Correct! Issue at {action.filename}:{action.line_number} confirmed. " | |
| f"[+{reward:.2f}{sev_note}{temp_note}{conf_note}{pbrs_note}{div_note}]" | |
| ) | |
| elif is_near: | |
| # Graduated near-miss: smooth exponential decay by line distance | |
| near_reward = graduated_near_reward(near_line_diff) | |
| reward_breakdown["near_miss"] = near_reward | |
| reward_breakdown["line_diff"] = float(near_line_diff) | |
| reward_breakdown["total"] = near_reward | |
| feedback = ( | |
| f"Close! Near a real issue at {action.filename}:{action.line_number}. " | |
| f"[+{near_reward:.3f} — {near_line_diff} lines off, adjust line number]" | |
| ) | |
| reward = near_reward | |
| else: | |
| # False positive — with flood protection | |
| self._fp_count += 1 | |
| base_penalty = _BASE_FP_PENALTY | |
| reward_breakdown["base_fp"] = base_penalty | |
| # Escalating penalty after FP_FLOOD_THRESHOLD FPs | |
| flood_penalty = 0.0 | |
| if self._fp_count > _FP_FLOOD_THRESHOLD: | |
| extra = self._fp_count - _FP_FLOOD_THRESHOLD | |
| flood_penalty = round(-0.02 * extra * _FP_FLOOD_MULTIPLIER, 3) | |
| reward_breakdown["flood_penalty"] = flood_penalty | |
| # High-confidence FP: extra penalty | |
| confidence_penalty = 0.0 | |
| if action.confidence is not None and action.confidence >= 0.7: | |
| confidence_penalty = _CONFIDENCE_FP_EXTRA | |
| reward_breakdown["confidence_penalty"] = confidence_penalty | |
| reward = base_penalty + flood_penalty + confidence_penalty | |
| reward_breakdown["total"] = round(reward, 4) | |
| flood_note = f", over-flagging -{abs(flood_penalty):.2f}" if flood_penalty else "" | |
| conf_note = f", high-confidence penalty {confidence_penalty:.2f}" if confidence_penalty else "" | |
| feedback = ( | |
| f"No match at {action.filename}:{action.line_number}. " | |
| f"[{reward:.2f} — false positive{flood_note}{conf_note}]" | |
| ) | |
| return reward, feedback, reward_breakdown | |
| def _handle_clear(self, action: ReviewAction): | |
| if action.line_number is None or not action.filename: | |
| return _VALIDATION_PENALTY, "clear_flag requires 'line_number' and 'filename'.", {"validation_penalty": _VALIDATION_PENALTY} | |
| removed_issue = None | |
| new_list = [] | |
| for f in self._state.flagged_issues: | |
| if f.line_number == action.line_number and f.filename == action.filename: | |
| removed_issue = f | |
| else: | |
| new_list.append(f) | |
| if removed_issue is None: | |
| return 0.0, ( | |
| f"No flagged issue found at {action.filename}:{action.line_number}." | |
| ), {"no_op": 0.0} | |
| self._state.flagged_issues = new_list | |
| # Check if removed issue was TP | |
| was_tp = any(match_issue(removed_issue, gt) for gt in self._ground_truth) | |
| if was_tp: | |
| # Un-track it from matched set | |
| for i, gt in enumerate(self._ground_truth): | |
| if match_issue(removed_issue, gt): | |
| self._matched_gt_indices.discard(i) | |
| break | |
| reward = _REMOVE_TP_PENALTY | |
| reward_breakdown = {"removed_tp": reward, "total": reward} | |
| feedback = ( | |
| f"Removed a correct finding at {action.filename}:{action.line_number}. " | |
| f"[{reward:.2f}]" | |
| ) | |
| else: | |
| # Removing a FP — decrement counter | |
| self._fp_count = max(0, self._fp_count - 1) | |
| reward = _REMOVE_FP_REWARD | |
| reward_breakdown = {"removed_fp": reward, "total": reward} | |
| feedback = ( | |
| f"Removed a false positive at {action.filename}:{action.line_number}. " | |
| f"[+{reward:.2f} — good correction]" | |
| ) | |
| return reward, feedback, reward_breakdown | |
| def _handle_hint(self): | |
| hints = self._task.get("hints", []) | |
| adaptive_hint = self._get_adaptive_hint() | |
| if adaptive_hint: | |
| return _HINT_COST, f"Hint: {adaptive_hint} ({_HINT_COST} reward)", {"hint_cost": _HINT_COST} | |
| if self._hint_index >= len(hints): | |
| return _HINT_COST, "No more hints available for this task.", {"hint_cost": _HINT_COST} | |
| hint = hints[self._hint_index] | |
| self._hint_index += 1 | |
| remaining = len(hints) - self._hint_index | |
| return _HINT_COST, f"Hint {self._hint_index}/{len(hints)}: {hint} ({remaining} hints left)", {"hint_cost": _HINT_COST} | |
| def _get_adaptive_hint(self) -> Optional[str]: | |
| """Generate a context-aware hint based on current episode state.""" | |
| flagged = self._state.flagged_issues | |
| gt = self._ground_truth | |
| if not gt: | |
| return None | |
| tp_count = len(self._matched_gt_indices) | |
| fp_count = len(flagged) - tp_count - sum( | |
| 1 for f in flagged | |
| if any(match_quality(f, g) == "near" for g in gt) | |
| ) | |
| issue_categories = self._code_metadata.get("issue_categories", []) | |
| # Many false positives: over-flagging | |
| if fp_count > tp_count and fp_count >= 2: | |
| return ( | |
| "You are over-flagging. Focus only on confident, concrete findings. " | |
| "Consider using clear_flag to remove uncertain flags." | |
| ) | |
| # No correct flags at all yet | |
| if len(flagged) > 0 and tp_count == 0: | |
| if issue_categories: | |
| cats = ", ".join(sorted(set(issue_categories))) | |
| return ( | |
| f"Focus on [{cats}] issues. " | |
| "None of your current flags match real issues. Re-examine carefully." | |
| ) | |
| # Found some but missed whole categories | |
| if tp_count > 0 and issue_categories: | |
| found_types: Set[str] = set() | |
| for i in self._matched_gt_indices: | |
| found_types.add(gt[i].issue_type) | |
| missed = sorted(set(issue_categories) - found_types) | |
| if missed: | |
| missed_str = ", ".join(missed) | |
| return ( | |
| f"Good progress! You've found some issues but haven't flagged any " | |
| f"[{missed_str}] issues yet — look again for those specifically." | |
| ) | |
| return None # Fall through to static hints | |
| def _handle_submit(self): | |
| self._state.submitted = True | |
| final_score = grade_episode(self._state.flagged_issues, self._ground_truth) | |
| self._state.current_score = final_score | |
| tp_count = len(self._matched_gt_indices) | |
| total_gt = len(self._ground_truth) | |
| total_flagged = len(self._state.flagged_issues) | |
| fp_count = total_flagged - tp_count | |
| # Breakdown for detailed feedback | |
| detailed = grade_episode_detailed(self._state.flagged_issues, self._ground_truth) | |
| feedback = ( | |
| f"Review submitted! Final score: {final_score:.3f}. " | |
| f"Found {tp_count}/{total_gt} issues. " | |
| f"Precision: {detailed['precision']:.2f}, Recall: {detailed['recall']:.2f}, " | |
| f"F1: {detailed['f1']:.2f}. " | |
| ) | |
| if fp_count > 0: | |
| feedback += f"{fp_count} false positive(s). " | |
| if detailed["false_negatives"] > 0: | |
| fn = detailed["false_negatives"] | |
| feedback += f"{fn} issue(s) missed." | |
| reward_breakdown = { | |
| "final_f1": detailed["f1"], | |
| "severity_accuracy": detailed["severity_accuracy"], | |
| "final_score": final_score, | |
| "total": final_score, | |
| } | |
| return final_score, feedback, reward_breakdown | |