VD10 committed on
Commit 59ef264 · verified · 1 Parent(s): 60cec59

Upload patchjudge/judge.py with huggingface_hub

Files changed (1)
  1. patchjudge/judge.py +441 -0
patchjudge/judge.py ADDED
@@ -0,0 +1,441 @@
+"""PatchJudge LLM Judge — Core evaluation engine.
+
+Evaluates AI-generated code patches on 5 dimensions using an LLM,
+producing a MergeScore (0-100) that indicates merge-readiness.
+
+Uses HuggingFace Inference API with structured JSON output.
+"""
+
+import json
+import os
+import re
+import time
+import logging
+from typing import Optional
+
+from huggingface_hub import InferenceClient
+
+from patchjudge.models import (
+    PatchExample, PatchFeatures, DimensionScore, JudgeResult
+)
+from patchjudge.feature_extractor import FeatureExtractor
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Prompt Templates
+# ============================================================================
+
+JUDGE_SYSTEM_PROMPT = """You are PatchJudge, an expert senior software engineer evaluating whether an AI-generated code patch is truly merge-worthy — not just whether it passes tests.
+
+You must be HARSH and PRECISE. A patch that "works but is bad code" should score low. A patch that is clean, complete, and genuinely solves the root cause should score high.
+
+Most AI-generated patches that pass tests are NOT merge-worthy. Average scores should be 3-5, not 7-8. A score of 7+ means genuinely good, publishable code."""
+
+
+JUDGE_USER_PROMPT = """Evaluate this AI-generated code patch.
+
+## THE ISSUE:
+{problem_statement}
+
+## THE PATCH (diff):
+```diff
+{agent_patch}
+```
+
+## REFERENCE GOLD PATCH (human-written):
+```diff
+{gold_patch}
+```
+
+## EXTRACTED FEATURES:
+{features_summary}
+
+## TEST RESULT: {test_result}
+
+---
+
+Score the patch on each of these 5 dimensions (0-10 integer each):
+
+1. **CORRECTNESS** (weight: 30%): Does the patch address the ROOT CAUSE described in the issue? Would the issue be genuinely resolved for all described scenarios, not just the test cases?
+
+2. **COMPLETENESS** (weight: 20%): Does it handle edge cases? Is error handling added where appropriate? Are there TODO comments or placeholder logic left behind?
+
+3. **CODE QUALITY** (weight: 20%): Does the code follow the project's existing style? Is it readable, well-structured? No unnecessary complexity?
+
+4. **NON-REGRESSION RISK** (weight: 15%): Is the change scope appropriate? Could it break unrelated functionality? Does it modify shared interfaces unnecessarily?
+
+5. **MERGE-READINESS** (weight: 15%): Would a senior engineer approve this PR as-is? Score 8+ = approve, 5-7 = request changes, below 5 = reject.
+
+---
+
+Respond with ONLY this JSON (no other text):
+```json
+{{
+  "correctness": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}},
+  "completeness": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}},
+  "code_quality": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}},
+  "non_regression_risk": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}},
+  "merge_readiness": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}}
+}}
+```"""
+
+
+FEATURES_TEMPLATE = """- Files changed: {num_files_changed}
+- Lines added: {num_lines_added}, removed: {num_lines_removed}
+- Hunks: {num_hunks}
+- Change scope: {change_scope}
+- Added functions: {added_functions}
+- Modified functions: {modified_functions}
+- Error handling present: {has_error_handling}
+- Edge case handling: {has_edge_case_handling}
+- Has TODOs/FIXMEs: {has_todos}
+- Has hardcoded values: {has_hardcoded_values}
+- Has debug statements: {has_debug_statements}
+- Modifies core files: {modifies_core_files}
+- New imports: {new_imports}
+- Issue keyword coverage: {keyword_coverage_ratio:.0%}
+- Touches test files: {touches_tests}
+- Style violations: {style_violations}"""
+
+
+# ============================================================================
+# PatchJudge Class
+# ============================================================================
+
+class PatchJudge:
+    """LLM-based judge for evaluating AI-generated code patches."""
+
+    WEIGHTS = {
+        "correctness": 0.30,
+        "completeness": 0.20,
+        "code_quality": 0.20,
+        "non_regression_risk": 0.15,
+        "merge_readiness": 0.15,
+    }
+
+    DIMENSIONS = list(WEIGHTS.keys())
+
+    def __init__(
+        self,
+        model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
+        provider: str = "auto",
+        temperature: float = 0.1,
+        max_tokens: int = 2000,
+        max_retries: int = 3,
+        retry_delay: float = 2.0,
+        max_context_chars: int = 12000,
+    ):
+        """Initialize PatchJudge.
+
+        Args:
+            model_id: HF model ID to use for judging.
+            provider: Inference provider ('auto', 'cerebras', 'novita', etc.)
+            temperature: Low for consistency (0.1 recommended).
+            max_tokens: Max tokens for LLM response.
+            max_retries: Retries on API/parse failures.
+            retry_delay: Base delay in seconds between retries (multiplied by the attempt number).
+            max_context_chars: Max chars for patch/context in prompt.
+        """
+        token = os.environ.get("HF_TOKEN")
+        self.client = InferenceClient(
+            provider=provider,
+            api_key=token,
+        )
+        self.model_id = model_id
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+        self.max_context_chars = max_context_chars
+        self.feature_extractor = FeatureExtractor()
+
+    def judge(
+        self,
+        example: PatchExample,
+        features: Optional[PatchFeatures] = None,
+    ) -> JudgeResult:
+        """Evaluate a single patch example.
+
+        Args:
+            example: The patch to evaluate.
+            features: Pre-extracted features (extracted automatically if None).
+
+        Returns:
+            JudgeResult with MergeScore, dimension scores, and reasoning.
+        """
+        # Extract features if not provided
+        if features is None:
+            features = self.feature_extractor.extract(example)
+
+        # Format the prompt
+        features_summary = self._format_features(features)
+
+        # Truncate patches if needed
+        agent_patch = self._truncate(example.agent_patch, self.max_context_chars // 2)
+        gold_patch = self._truncate(example.gold_patch, self.max_context_chars // 4)
+        problem_stmt = self._truncate(example.problem_statement, self.max_context_chars // 4)
+
+        user_prompt = JUDGE_USER_PROMPT.format(
+            problem_statement=problem_stmt,
+            agent_patch=agent_patch,
+            gold_patch=gold_patch,
+            features_summary=features_summary,
+            test_result="PASSED ✓" if example.test_passed else "FAILED ✗",
+        )
+
+        # Call LLM with retries
+        raw_output = None
+        scores = None
+
+        for attempt in range(self.max_retries):
+            try:
+                raw_output = self._call_llm(user_prompt)
+                scores = self._parse_json_output(raw_output)
+                self._validate_scores(scores)
+                break
+            except Exception as e:
+                # Discard parsed-but-invalid output so the failure fallback below applies
+                scores = None
+                logger.warning(f"Attempt {attempt+1}/{self.max_retries} failed: {e}")
+                if attempt < self.max_retries - 1:
+                    time.sleep(self.retry_delay * (attempt + 1))
+
+        if scores is None:
+            # Return a failure result
+            logger.error(
+                f"Failed to judge {example.instance_id} after {self.max_retries} attempts"
+            )
+            scores = {
+                dim: {"score": 0, "reasoning": "Judge failed to produce valid output", "flags": ["JUDGE_ERROR"]}
+                for dim in self.DIMENSIONS
+            }
+            raw_output = raw_output or "ERROR: No output from LLM"
+
+        # Compute MergeScore
+        merge_score = self._compute_merge_score(scores)
+
+        return JudgeResult(
+            merge_score=merge_score,
+            dimension_scores=scores,
+            raw_output=raw_output,
+            features=features,
+            model_used=self.model_id,
+        )
+
+    def judge_batch(
+        self,
+        examples: list[PatchExample],
+        features_list: Optional[list[PatchFeatures]] = None,
+        show_progress: bool = True,
+    ) -> list[JudgeResult]:
+        """Evaluate a batch of patches.
+
+        Args:
+            examples: List of PatchExamples to evaluate.
+            features_list: Pre-extracted features (one per example). Optional.
+            show_progress: Print progress.
+
+        Returns:
+            List of JudgeResults in same order as input.
+        """
+        results = []
+
+        for i, example in enumerate(examples):
+            if show_progress:
+                print(f" Judging [{i+1}/{len(examples)}] {example.instance_id} "
+                      f"({example.agent_name})...")
+
+            features = features_list[i] if features_list else None
+
+            try:
+                result = self.judge(example, features)
+                results.append(result)
+
+                if show_progress:
+                    print(f" MergeScore: {result.merge_score:.1f}/100")
+
+            except Exception as e:
+                logger.error(f"Failed to judge {example.instance_id}: {e}")
+                # Append error result
+                results.append(JudgeResult(
+                    merge_score=0.0,
+                    dimension_scores={
+                        dim: {"score": 0, "reasoning": f"Error: {str(e)}", "flags": ["ERROR"]}
+                        for dim in self.DIMENSIONS
+                    },
+                    raw_output=f"ERROR: {str(e)}",
+                    model_used=self.model_id,
+                ))
+
+            # Rate limiting
+            time.sleep(0.5)
+
+        return results
+
+    # =========================================================================
+    # Internal methods
+    # =========================================================================
+
+    def _call_llm(self, user_prompt: str) -> str:
+        """Call the LLM and return raw text response."""
+        response = self.client.chat_completion(
+            model=self.model_id,
+            messages=[
+                {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
+                {"role": "user", "content": user_prompt},
+            ],
+            max_tokens=self.max_tokens,
+            temperature=self.temperature,
+        )
+        return response.choices[0].message.content
+
+    def _compute_merge_score(self, scores: dict) -> float:
+        """Compute weighted MergeScore (0-100) from dimension scores."""
+        weighted_sum = 0.0
+        for dim, weight in self.WEIGHTS.items():
+            dim_score = scores.get(dim, {}).get("score", 0)
+            weighted_sum += dim_score * weight
+        return round(weighted_sum * 10, 1)  # Scale 0-10 → 0-100
+
+    def _parse_json_output(self, raw: str) -> dict:
+        """Extract JSON from LLM output, handling markdown code blocks."""
+        # Try to find JSON in code blocks
+        json_match = re.search(r'```(?:json)?\s*([\{][\s\S]*?[\}])\s*```', raw)
+        if json_match:
+            return json.loads(json_match.group(1))
+
+        # Try to find raw JSON object
+        json_match = re.search(r'(\{[\s\S]*\})', raw)
+        if json_match:
+            # Try the whole greedy match first, then fall back to the first balanced object
+            text = json_match.group(1)
+            try:
+                return json.loads(text)
+            except json.JSONDecodeError:
+                pass
+
+            # Try to find balanced braces
+            depth = 0
+            for i, ch in enumerate(text):
+                if ch == '{':
+                    depth += 1
+                elif ch == '}':
+                    depth -= 1
+                    if depth == 0:
+                        try:
+                            return json.loads(text[:i+1])
+                        except json.JSONDecodeError:
+                            continue
+
+        raise ValueError(f"Could not parse JSON from LLM output: {raw[:200]}...")
+
+    def _validate_scores(self, scores: dict) -> None:
+        """Validate that all required dimensions are present with valid scores."""
+        for dim in self.DIMENSIONS:
+            if dim not in scores:
+                raise ValueError(f"Missing dimension: {dim}")
+            if "score" not in scores[dim]:
+                raise ValueError(f"Missing score for {dim}")
+            score = scores[dim]["score"]
+            if isinstance(score, bool) or not isinstance(score, (int, float)) or not 0 <= score <= 10:
+                raise ValueError(f"Invalid score for {dim}: {score}")
+            # Ensure score is int
+            scores[dim]["score"] = int(round(score))
+            # Ensure flags is a list
+            if "flags" not in scores[dim]:
+                scores[dim]["flags"] = []
+            if isinstance(scores[dim]["flags"], str):
+                scores[dim]["flags"] = [scores[dim]["flags"]]
+            # Ensure reasoning exists
+            if "reasoning" not in scores[dim]:
+                scores[dim]["reasoning"] = ""
+
+    def _format_features(self, features: PatchFeatures) -> str:
+        """Format features into a readable summary for the prompt."""
+        d = features.to_dict()
+        # Format lists as comma-separated
+        for key in ['added_functions', 'modified_functions', 'new_imports',
+                    'style_violations', 'issue_keywords_addressed',
+                    'issue_components_mentioned']:
+            if isinstance(d.get(key), list):
+                d[key] = ', '.join(str(x) for x in d[key][:10]) or 'none'
+
+        return FEATURES_TEMPLATE.format(**d)
+
+    def _truncate(self, text: str, max_chars: int) -> str:
+        """Truncate text, keeping beginning and end."""
+        if len(text) <= max_chars:
+            return text
+        half = max_chars // 2
+        return text[:half] + "\n\n... [truncated] ...\n\n" + text[-half:]
+
+
+# ============================================================================
+# Convenience functions
+# ============================================================================
+
+def quick_judge(
+    problem_statement: str,
+    agent_patch: str,
+    gold_patch: str = "",
+    test_passed: bool = True,
+    model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
+) -> JudgeResult:
+    """Quick one-shot evaluation of a patch.
+
+    Args:
+        problem_statement: The GitHub issue text.
+        agent_patch: The AI-generated diff.
+        gold_patch: Optional reference patch.
+        test_passed: Whether tests passed.
+        model_id: LLM to use.
+
+    Returns:
+        JudgeResult with MergeScore and breakdown.
+    """
+    example = PatchExample(
+        instance_id="quick-judge",
+        repo="unknown",
+        problem_statement=problem_statement,
+        gold_patch=gold_patch,
+        agent_patch=agent_patch,
+        agent_name="unknown",
+        test_passed=test_passed,
+        base_commit="",
+    )
+
+    judge = PatchJudge(model_id=model_id)
+    return judge.judge(example)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+
+    # Quick test with a sample
+    result = quick_judge(
+        problem_statement="Fix the divide by zero error in calculate_average when the list is empty",
+        agent_patch="""diff --git a/utils.py b/utils.py
+--- a/utils.py
++++ b/utils.py
+@@ -10,4 +10,6 @@
+ def calculate_average(numbers):
+-    return sum(numbers) / len(numbers)
++    if not numbers:
++        return 0.0
++    return sum(numbers) / len(numbers)
+""",
+        gold_patch="""diff --git a/utils.py b/utils.py
+--- a/utils.py
++++ b/utils.py
+@@ -10,4 +10,7 @@
+ def calculate_average(numbers):
++    if not numbers:
++        raise ValueError("Cannot calculate average of empty list")
+     return sum(numbers) / len(numbers)
+""",
+        test_passed=True,
+    )
+
+    print(result.summary())
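
To make the MergeScore arithmetic concrete, the following is a minimal standalone sketch of the weighting used by `_compute_merge_score` in the diff above. The weights are copied from `PatchJudge.WEIGHTS`; the function name `merge_score` and the sample scores are hypothetical and exist only for illustration. With dimension scores of 8, 6, 7, 5 and 6, the weighted sum is 0.30·8 + 0.20·6 + 0.20·7 + 0.15·5 + 0.15·6 = 6.65, which scales to 66.5 out of 100.

```python
# Illustrative sketch only: mirrors the weighting in PatchJudge._compute_merge_score.
# WEIGHTS is copied from the diff; merge_score and sample_scores are hypothetical names.
WEIGHTS = {
    "correctness": 0.30,
    "completeness": 0.20,
    "code_quality": 0.20,
    "non_regression_risk": 0.15,
    "merge_readiness": 0.15,
}


def merge_score(scores: dict) -> float:
    """Return the weighted sum of 0-10 dimension scores, scaled to 0-100."""
    weighted_sum = 0.0
    for dim, weight in WEIGHTS.items():
        # A missing dimension contributes 0, as in the original method.
        weighted_sum += scores.get(dim, {}).get("score", 0) * weight
    return round(weighted_sum * 10, 1)


# Made-up judge output in the shape requested by JUDGE_USER_PROMPT.
sample_scores = {
    "correctness": {"score": 8, "reasoning": "", "flags": []},
    "completeness": {"score": 6, "reasoning": "", "flags": []},
    "code_quality": {"score": 7, "reasoning": "", "flags": []},
    "non_regression_risk": {"score": 5, "reasoning": "", "flags": []},
    "merge_readiness": {"score": 6, "reasoning": "", "flags": []},
}

if __name__ == "__main__":
    print(merge_score(sample_scores))
```

Because the weights sum to 1.0, a patch scoring 10 on every dimension reaches exactly 100, and the failure fallback (all dimensions scored 0) yields 0.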