"""PR-triage task definitions and grading utilities for coding_env.""" from __future__ import annotations from dataclasses import dataclass import json import os from typing import Dict, List, Tuple try: from coding_env.models import CodeAction except ImportError: from ..models import CodeAction @dataclass(frozen=True) class CodeReviewTask: task_id: str difficulty: str pr_title: str pr_description: str changed_files: Tuple[str, ...] unified_diff: str expected_file_path: str expected_issue_type: str expected_bug_type: str expected_severity: str expected_line_number: int expected_keywords: Tuple[str, ...] TASKS: Dict[str, CodeReviewTask] = { "task_easy_1": CodeReviewTask( task_id="task_easy_1", difficulty="easy", pr_title="Fix average response time aggregation in metrics service", pr_description="Refactor aggregation logic and simplify average calculation.", changed_files=("services/metrics/aggregation.py",), unified_diff=( "diff --git a/services/metrics/aggregation.py b/services/metrics/aggregation.py\n" "@@ -10,6 +10,6 @@ def compute_avg(latencies):\n" " total = 0\n" " for latency in latencies:\n" " total += latency\n" "- return total / len(total)\n" "+ return total / len(total)\n" ), expected_file_path="services/metrics/aggregation.py", expected_issue_type="logic", expected_bug_type="logic", expected_severity="medium", expected_line_number=4, expected_keywords=("len(total)", "len(latencies)", "typeerror"), ), "task_medium_1": CodeReviewTask( task_id="task_medium_1", difficulty="medium", pr_title="Optimize login query path in auth service", pr_description="Use direct SQL construction for faster username/password checks.", changed_files=("services/auth/login.py",), unified_diff=( "diff --git a/services/auth/login.py b/services/auth/login.py\n" "@@ -21,7 +21,7 @@ def login(conn, username, password):\n" "- query = \"SELECT * FROM users WHERE name=? 
            "-    return conn.execute(query, (username, password)).fetchone() is not None\n"
            "+    query = f\"SELECT * FROM users WHERE name='{username}' AND pw='{password}'\"\n"
            "+    return conn.execute(query).fetchone() is not None\n"
        ),
        expected_file_path="services/auth/login.py",
        expected_issue_type="security",
        expected_bug_type="security",
        expected_severity="high",
        expected_line_number=2,
        expected_keywords=("sql injection", "parameterized", "prepared statement"),
    ),
    "task_hard_1": CodeReviewTask(
        task_id="task_hard_1",
        difficulty="hard",
        pr_title="Add cache layer to user profile fetch endpoint",
        pr_description="Protect cache updates with lock to avoid races and keep data coherent.",
        changed_files=("services/profile/cache_layer.py",),
        unified_diff=(
            "diff --git a/services/profile/cache_layer.py b/services/profile/cache_layer.py\n"
            "@@ -4,12 +4,12 @@ lock = Lock()\n"
            " cache = {}\n"
            "\n"
            " def get_user(user_id, db):\n"
            "     with lock:\n"
            "         if user_id in cache:\n"
            "             return cache[user_id]\n"
            "         data = db.fetch_user(user_id)\n"
            "         cache[user_id] = data\n"
            "         return data\n"
        ),
        expected_file_path="services/profile/cache_layer.py",
        expected_issue_type="performance",
        expected_bug_type="logic",
        expected_severity="high",
        expected_line_number=7,
        expected_keywords=("lock contention", "critical section", "latency"),
    ),
}

EPISODE_SCORES: Dict[Tuple[str, str], float] = {}
SCORE_FILE = os.getenv("CODING_ENV_SCORE_FILE", "/tmp/coding_env_episode_scores.json")
MIN_STRICT_SCORE = 0.01
MAX_STRICT_SCORE = 0.99


def list_tasks() -> List[Dict[str, str]]:
    """Return public task metadata."""
    return [
        {"task_id": t.task_id, "difficulty": t.difficulty, "pr_title": t.pr_title}
        for t in sorted(TASKS.values(), key=lambda item: item.task_id)
    ]


def get_task(task_id: str) -> CodeReviewTask:
    """Resolve task by id."""
    if task_id not in TASKS:
        raise ValueError(
            f"Unknown task_id '{task_id}'. Available tasks: {', '.join(sorted(TASKS))}"
        )
    return TASKS[task_id]


def format_task_prompt(task: CodeReviewTask) -> str:
    """Format a realistic PR-review prompt."""
    files = "\n".join(f"- {path}" for path in task.changed_files)
    return (
        f"PR Title: {task.pr_title}\n"
        f"PR Description: {task.pr_description}\n"
        f"Changed Files:\n{files}\n\n"
        f"Unified Diff:\n{task.unified_diff}\n\n"
        "Review objective: identify the highest-impact issue and provide "
        "file path, issue type, severity, and exact line."
    )


def _normalize(value: str) -> str:
    return value.strip().lower().replace("-", "_")


def _action_issue_type(action: CodeAction) -> str:
    issue_type = getattr(action, "issue_type", "")
    if issue_type:
        return str(issue_type)
    return str(action.bug_type)


def grade_action(action: CodeAction, task: CodeReviewTask) -> Tuple[float, str]:
    """Score PR-triage action in strict (0, 1) with partial credit."""
    score = 0.0
    parts: List[str] = []

    file_path = str(getattr(action, "file_path", "") or "")
    if _normalize(file_path) == _normalize(task.expected_file_path):
        score += 0.30
        parts.append("file_path matched (+0.30)")
    else:
        parts.append(
            f"file_path mismatch (expected {task.expected_file_path}, got {file_path or 'none'})"
        )

    issue_type = _action_issue_type(action)
    if _normalize(issue_type) == _normalize(task.expected_issue_type):
        score += 0.30
        parts.append("issue_type matched (+0.30)")
    elif _normalize(action.bug_type) == _normalize(task.expected_bug_type):
        score += 0.20
        parts.append("bug_type matched (+0.20)")
    else:
        parts.append(
            f"issue mismatch (expected {task.expected_issue_type}/{task.expected_bug_type}, got {issue_type}/{action.bug_type})"
        )

    severity = str(getattr(action, "severity", "") or "")
    if _normalize(severity) == _normalize(task.expected_severity):
        score += 0.20
        parts.append("severity matched (+0.20)")
    elif severity:
        score += 0.10
        parts.append("severity provided but not exact (+0.10)")
    else:
        parts.append("severity missing (+0.00)")

    if action.line_number == task.expected_line_number:
        score += 0.10
        parts.append("line_number matched (+0.10)")
    elif abs(action.line_number - task.expected_line_number) <= 2:
        score += 0.05
        parts.append("line_number near miss (+0.05)")
    else:
        parts.append(
            f"line_number mismatch (expected {task.expected_line_number}, got {action.line_number})"
        )

    review_text = (action.review or "").lower()
    keyword_hits = sum(1 for kw in task.expected_keywords if kw.lower() in review_text)
    if keyword_hits > 0:
        keyword_bonus = min(0.09, keyword_hits * 0.03)
        score += keyword_bonus
        parts.append(f"evidence quality matched (+{keyword_bonus:.2f})")
    else:
        parts.append("insufficient evidence language (+0.00)")

    score = _to_strict_open_score(score)
    return score, "; ".join(parts)


def record_episode_score(task_id: str, episode_id: str, score: float) -> None:
    """Persist normalized score for grader endpoint."""
    normalized = _to_strict_open_score(score)
    EPISODE_SCORES[(task_id, episode_id)] = normalized
    _persist_score(task_id, episode_id, normalized)


def get_episode_score(task_id: str, episode_id: str) -> float:
    """Read score for task/episode pair."""
    in_memory = EPISODE_SCORES.get((task_id, episode_id))
    if in_memory is not None:
        return in_memory
    return _load_persisted_score(task_id, episode_id)


def _persist_score(task_id: str, episode_id: str, score: float) -> None:
    key = f"{task_id}::{episode_id}"
    payload: Dict[str, float] = {}
    if os.path.exists(SCORE_FILE):
        try:
            with open(SCORE_FILE, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            if isinstance(loaded, dict):
                payload = {
                    str(k): float(v)
                    for k, v in loaded.items()
                    if isinstance(v, (int, float))
                }
        except Exception:
            payload = {}
    payload[key] = float(score)
    with open(SCORE_FILE, "w", encoding="utf-8") as f:
        json.dump(payload, f)


def _load_persisted_score(task_id: str, episode_id: str) -> float:
    if not os.path.exists(SCORE_FILE):
        return MIN_STRICT_SCORE
    try:
        with open(SCORE_FILE, "r", encoding="utf-8") as f:
            loaded = json.load(f)
        key = f"{task_id}::{episode_id}"
        value = loaded.get(key, 0.0) if isinstance(loaded, dict) else 0.0
        return _to_strict_open_score(value)
    except Exception:
        return MIN_STRICT_SCORE


def _to_strict_open_score(value: float) -> float:
    """Clamp to strict open interval (0, 1)."""
    return max(MIN_STRICT_SCORE, min(MAX_STRICT_SCORE, round(float(value), 4)))
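

# Usage sketch: grade_action only reads attributes off the action object, so a
# SimpleNamespace stand-in is enough for a quick local smoke test. The field
# values below are hypothetical sample data; a real CodeAction from
# coding_env.models may have a different constructor or additional required fields.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_task = get_task("task_medium_1")
    demo_action = SimpleNamespace(
        file_path="services/auth/login.py",
        issue_type="security",
        bug_type="security",
        severity="high",
        line_number=2,
        review="String-formatted SQL allows SQL injection; use parameterized queries.",
    )
    demo_score, demo_breakdown = grade_action(demo_action, demo_task)  # type: ignore[arg-type]
    print(f"score={demo_score:.2f} breakdown={demo_breakdown}")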