"""PatchJudge LLM Judge — Core evaluation engine.
Evaluates AI-generated code patches on 5 dimensions using an LLM,
producing a MergeScore (0-100) that indicates merge-readiness.
Uses the Hugging Face Inference API with structured JSON output.
"""
import json
import os
import re
import time
import logging
from typing import Optional
from huggingface_hub import InferenceClient
from patchjudge.models import (
PatchExample, PatchFeatures, DimensionScore, JudgeResult
)
from patchjudge.feature_extractor import FeatureExtractor
logger = logging.getLogger(__name__)
# ============================================================================
# Prompt Templates
# ============================================================================
JUDGE_SYSTEM_PROMPT = """You are PatchJudge, an expert senior software engineer evaluating whether an AI-generated code patch is truly merge-worthy — not just whether it passes tests.
You must be HARSH and PRECISE. A patch that "works but is bad code" should score low. A patch that is clean, complete, and genuinely solves the root cause should score high.
Most AI-generated patches that pass tests are NOT merge-worthy. Average scores should be 3-5, not 7-8. A score of 7+ means genuinely good, publishable code."""
JUDGE_USER_PROMPT = """Evaluate this AI-generated code patch.
## THE ISSUE:
{problem_statement}
## THE PATCH (diff):
```diff
{agent_patch}
```
## REFERENCE GOLD PATCH (human-written):
```diff
{gold_patch}
```
## EXTRACTED FEATURES:
{features_summary}
## TEST RESULT: {test_result}
---
Score the patch on each of these 5 dimensions (0-10 integer each):
1. **CORRECTNESS** (weight: 30%): Does the patch address the ROOT CAUSE described in the issue? Would the issue be genuinely resolved for all described scenarios, not just the test cases?
2. **COMPLETENESS** (weight: 20%): Does it handle edge cases? Is error handling added where appropriate? Are there TODO comments or placeholder logic left behind?
3. **CODE QUALITY** (weight: 20%): Does the code follow the project's existing style? Is it readable, well-structured? No unnecessary complexity?
4. **NON-REGRESSION RISK** (weight: 15%): Is the change scope appropriate? Could it break unrelated functionality? Does it modify shared interfaces unnecessarily?
5. **MERGE-READINESS** (weight: 15%): Would a senior engineer approve this PR as-is? Score 8+ = approve, 5-7 = request changes, below 5 = reject.
---
Respond with ONLY this JSON (no other text):
```json
{{
"correctness": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}},
"completeness": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}},
"code_quality": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}},
"non_regression_risk": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}},
"merge_readiness": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}}
}}
```"""
FEATURES_TEMPLATE = """- Files changed: {num_files_changed}
- Lines added: {num_lines_added}, removed: {num_lines_removed}
- Hunks: {num_hunks}
- Change scope: {change_scope}
- Added functions: {added_functions}
- Modified functions: {modified_functions}
- Error handling present: {has_error_handling}
- Edge case handling: {has_edge_case_handling}
- Has TODOs/FIXMEs: {has_todos}
- Has hardcoded values: {has_hardcoded_values}
- Has debug statements: {has_debug_statements}
- Modifies core files: {modifies_core_files}
- New imports: {new_imports}
- Issue keyword coverage: {keyword_coverage_ratio:.0%}
- Touches test files: {touches_tests}
- Style violations: {style_violations}"""
# ============================================================================
# PatchJudge Class
# ============================================================================
class PatchJudge:
"""LLM-based judge for evaluating AI-generated code patches."""
WEIGHTS = {
"correctness": 0.30,
"completeness": 0.20,
"code_quality": 0.20,
"non_regression_risk": 0.15,
"merge_readiness": 0.15,
}
DIMENSIONS = list(WEIGHTS.keys())
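    # MergeScore = 10 * sum(weight * dimension score). For example (illustrative
    # values only), scores of 7, 6, 5, 4, 8 across the five dimensions give
    # 10 * (7*0.30 + 6*0.20 + 5*0.20 + 4*0.15 + 8*0.15) = 61.0.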
def __init__(
self,
model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
provider: str = "auto",
temperature: float = 0.1,
max_tokens: int = 2000,
max_retries: int = 3,
retry_delay: float = 2.0,
max_context_chars: int = 12000,
):
"""Initialize PatchJudge.
Args:
model_id: HF model ID to use for judging.
provider: Inference provider ('auto', 'cerebras', 'novita', etc.)
temperature: Low for consistency (0.1 recommended).
max_tokens: Max tokens for LLM response.
max_retries: Retries on API/parse failures.
retry_delay: Seconds between retries.
max_context_chars: Max chars for patch/context in prompt.
"""
token = os.environ.get("HF_TOKEN")
self.client = InferenceClient(
provider=provider,
api_key=token,
)
self.model_id = model_id
self.temperature = temperature
self.max_tokens = max_tokens
self.max_retries = max_retries
self.retry_delay = retry_delay
self.max_context_chars = max_context_chars
self.feature_extractor = FeatureExtractor()
def judge(
self,
example: PatchExample,
features: Optional[PatchFeatures] = None,
) -> JudgeResult:
"""Evaluate a single patch example.
Args:
example: The patch to evaluate.
features: Pre-extracted features (extracted automatically if None).
Returns:
JudgeResult with MergeScore, dimension scores, and reasoning.
"""
# Extract features if not provided
if features is None:
features = self.feature_extractor.extract(example)
# Format the prompt
features_summary = self._format_features(features)
# Truncate patches if needed
agent_patch = self._truncate(example.agent_patch, self.max_context_chars // 2)
gold_patch = self._truncate(example.gold_patch, self.max_context_chars // 4)
problem_stmt = self._truncate(example.problem_statement, self.max_context_chars // 4)
user_prompt = JUDGE_USER_PROMPT.format(
problem_statement=problem_stmt,
agent_patch=agent_patch,
gold_patch=gold_patch,
features_summary=features_summary,
test_result="PASSED ✓" if example.test_passed else "FAILED ✗",
)
# Call LLM with retries
raw_output = None
scores = None
for attempt in range(self.max_retries):
try:
raw_output = self._call_llm(user_prompt)
scores = self._parse_json_output(raw_output)
self._validate_scores(scores)
break
            except Exception as e:
                # Discard any partially parsed or invalid scores so a failed
                # attempt cannot leak through to scoring.
                scores = None
                logger.warning(
                    f"Attempt {attempt+1}/{self.max_retries} failed: {e}"
                )
                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay * (attempt + 1))
if scores is None:
# Return a failure result
logger.error(
f"Failed to judge {example.instance_id} after {self.max_retries} attempts"
)
scores = {
dim: {"score": 0, "reasoning": "Judge failed to produce valid output", "flags": ["JUDGE_ERROR"]}
for dim in self.DIMENSIONS
}
raw_output = raw_output or "ERROR: No output from LLM"
# Compute MergeScore
merge_score = self._compute_merge_score(scores)
return JudgeResult(
merge_score=merge_score,
dimension_scores=scores,
raw_output=raw_output,
features=features,
model_used=self.model_id,
)
def judge_batch(
self,
examples: list[PatchExample],
features_list: Optional[list[PatchFeatures]] = None,
show_progress: bool = True,
) -> list[JudgeResult]:
"""Evaluate a batch of patches.
Args:
examples: List of PatchExamples to evaluate.
features_list: Pre-extracted features (one per example). Optional.
show_progress: Print progress.
Returns:
List of JudgeResults in same order as input.
"""
results = []
for i, example in enumerate(examples):
if show_progress:
print(f" Judging [{i+1}/{len(examples)}] {example.instance_id} "
f"({example.agent_name})...")
features = features_list[i] if features_list else None
try:
result = self.judge(example, features)
results.append(result)
if show_progress:
print(f" MergeScore: {result.merge_score:.1f}/100")
except Exception as e:
logger.error(f"Failed to judge {example.instance_id}: {e}")
# Append error result
results.append(JudgeResult(
merge_score=0.0,
dimension_scores={
dim: {"score": 0, "reasoning": f"Error: {str(e)}", "flags": ["ERROR"]}
for dim in self.DIMENSIONS
},
raw_output=f"ERROR: {str(e)}",
model_used=self.model_id,
))
# Rate limiting
time.sleep(0.5)
return results
# =========================================================================
# Internal methods
# =========================================================================
def _call_llm(self, user_prompt: str) -> str:
"""Call the LLM and return raw text response."""
response = self.client.chat_completion(
model=self.model_id,
messages=[
{"role": "system", "content": JUDGE_SYSTEM_PROMPT},
{"role": "user", "content": user_prompt},
],
max_tokens=self.max_tokens,
temperature=self.temperature,
)
return response.choices[0].message.content
def _compute_merge_score(self, scores: dict) -> float:
"""Compute weighted MergeScore (0-100) from dimension scores."""
weighted_sum = 0.0
for dim, weight in self.WEIGHTS.items():
dim_score = scores.get(dim, {}).get("score", 0)
weighted_sum += dim_score * weight
return round(weighted_sum * 10, 1) # Scale 0-10 → 0-100
def _parse_json_output(self, raw: str) -> dict:
"""Extract JSON from LLM output, handling markdown code blocks."""
# Try to find JSON in code blocks
json_match = re.search(r'```(?:json)?\s*([\{][\s\S]*?[\}])\s*```', raw)
if json_match:
return json.loads(json_match.group(1))
# Try to find raw JSON object
json_match = re.search(r'(\{[\s\S]*\})', raw)
if json_match:
            # First try the full greedy match, then fall back to progressively
            # larger balanced-brace prefixes.
text = json_match.group(1)
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# Try to find balanced braces
depth = 0
for i, ch in enumerate(text):
if ch == '{':
depth += 1
elif ch == '}':
depth -= 1
if depth == 0:
try:
return json.loads(text[:i+1])
except json.JSONDecodeError:
continue
raise ValueError(f"Could not parse JSON from LLM output: {raw[:200]}...")
def _validate_scores(self, scores: dict) -> None:
"""Validate that all required dimensions are present with valid scores."""
for dim in self.DIMENSIONS:
if dim not in scores:
raise ValueError(f"Missing dimension: {dim}")
if "score" not in scores[dim]:
raise ValueError(f"Missing score for {dim}")
score = scores[dim]["score"]
if not isinstance(score, (int, float)) or score < 0 or score > 10:
raise ValueError(f"Invalid score for {dim}: {score}")
# Ensure score is int
scores[dim]["score"] = int(round(score))
# Ensure flags is a list
if "flags" not in scores[dim]:
scores[dim]["flags"] = []
if isinstance(scores[dim]["flags"], str):
scores[dim]["flags"] = [scores[dim]["flags"]]
# Ensure reasoning exists
if "reasoning" not in scores[dim]:
scores[dim]["reasoning"] = ""
def _format_features(self, features: PatchFeatures) -> str:
"""Format features into a readable summary for the prompt."""
d = features.to_dict()
# Format lists as comma-separated
for key in ['added_functions', 'modified_functions', 'new_imports',
'style_violations', 'issue_keywords_addressed',
'issue_components_mentioned']:
if isinstance(d.get(key), list):
d[key] = ', '.join(str(x) for x in d[key][:10]) or 'none'
return FEATURES_TEMPLATE.format(**d)
def _truncate(self, text: str, max_chars: int) -> str:
"""Truncate text, keeping beginning and end."""
if len(text) <= max_chars:
return text
half = max_chars // 2
return text[:half] + "\n\n... [truncated] ...\n\n" + text[-half:]
# ============================================================================
# Convenience functions
# ============================================================================
def quick_judge(
problem_statement: str,
agent_patch: str,
gold_patch: str = "",
test_passed: bool = True,
model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
) -> JudgeResult:
"""Quick one-shot evaluation of a patch.
Args:
problem_statement: The GitHub issue text.
agent_patch: The AI-generated diff.
gold_patch: Optional reference patch.
test_passed: Whether tests passed.
model_id: LLM to use.
Returns:
JudgeResult with MergeScore and breakdown.
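
    Example (illustrative sketch; assumes HF_TOKEN is set and my_diff holds a
    unified diff string):

        result = quick_judge(
            problem_statement="calculate_average crashes on an empty list",
            agent_patch=my_diff,
        )
        print(result.merge_score)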
"""
example = PatchExample(
instance_id="quick-judge",
repo="unknown",
problem_statement=problem_statement,
gold_patch=gold_patch,
agent_patch=agent_patch,
agent_name="unknown",
test_passed=test_passed,
base_commit="",
)
judge = PatchJudge(model_id=model_id)
return judge.judge(example)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
# Quick test with a sample
result = quick_judge(
problem_statement="Fix the divide by zero error in calculate_average when the list is empty",
agent_patch="""diff --git a/utils.py b/utils.py
--- a/utils.py
+++ b/utils.py
@@ -10,2 +10,4 @@
def calculate_average(numbers):
- return sum(numbers) / len(numbers)
+ if not numbers:
+ return 0.0
+ return sum(numbers) / len(numbers)
""",
gold_patch="""diff --git a/utils.py b/utils.py
--- a/utils.py
+++ b/utils.py
@@ -10,2 +10,4 @@
def calculate_average(numbers):
+ if not numbers:
+ raise ValueError("Cannot calculate average of empty list")
return sum(numbers) / len(numbers)
""",
test_passed=True,
)
print(result.summary())