"""PatchJudge LLM Judge — Core evaluation engine.

Evaluates AI-generated code patches on 5 dimensions using an LLM,
producing a MergeScore (0-100) that indicates merge-readiness.

Uses the Hugging Face Inference API and parses the model's JSON response.
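
Worked example of the MergeScore arithmetic. This is an illustrative,
standalone sketch: `merge_score` and `example` are hypothetical names, and
only the weights are taken from `PatchJudge.WEIGHTS` in this module.

```python
# Standalone sketch: weights copied from PatchJudge.WEIGHTS.
WEIGHTS = {
    "correctness": 0.30,
    "completeness": 0.20,
    "code_quality": 0.20,
    "non_regression_risk": 0.15,
    "merge_readiness": 0.15,
}


def merge_score(scores: dict) -> float:
    # Weighted sum of the 0-10 dimension scores, scaled to 0-100.
    weighted = sum(scores[dim] * w for dim, w in WEIGHTS.items())
    return round(weighted * 10, 1)


# Hypothetical scores for a patch that passes tests but is mediocre.
example = {
    "correctness": 6,
    "completeness": 4,
    "code_quality": 5,
    "non_regression_risk": 7,
    "merge_readiness": 4,
}
print(merge_score(example))  # 52.5
```

Because the weights sum to 1.0, a patch scoring 10 on every dimension
reaches exactly 100.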
"""

import json
import os
import re
import time
import logging
from typing import Optional

from huggingface_hub import InferenceClient

from patchjudge.models import (
    PatchExample, PatchFeatures, DimensionScore, JudgeResult
)
from patchjudge.feature_extractor import FeatureExtractor

logger = logging.getLogger(__name__)


# ============================================================================
# Prompt Templates
# ============================================================================

JUDGE_SYSTEM_PROMPT = """You are PatchJudge, an expert senior software engineer evaluating whether an AI-generated code patch is truly merge-worthy — not just whether it passes tests.

You must be HARSH and PRECISE. A patch that "works but is bad code" should score low. A patch that is clean, complete, and genuinely solves the root cause should score high.

Most AI-generated patches that pass tests are NOT merge-worthy. Average scores should be 3-5, not 7-8. A score of 7+ means genuinely good, publishable code."""


JUDGE_USER_PROMPT = """Evaluate this AI-generated code patch.

## THE ISSUE:
{problem_statement}

## THE PATCH (diff):
```diff
{agent_patch}
```

## REFERENCE GOLD PATCH (human-written):
```diff
{gold_patch}
```

## EXTRACTED FEATURES:
{features_summary}

## TEST RESULT: {test_result}

---

Score the patch on each of these 5 dimensions (0-10 integer each):

1. **CORRECTNESS** (weight: 30%): Does the patch address the ROOT CAUSE described in the issue? Would the issue be genuinely resolved for all described scenarios, not just the test cases?

2. **COMPLETENESS** (weight: 20%): Does it handle edge cases? Is error handling added where appropriate? Are there TODO comments or placeholder logic left behind?

3. **CODE QUALITY** (weight: 20%): Does the code follow the project's existing style? Is it readable, well-structured? No unnecessary complexity?

4. **NON-REGRESSION RISK** (weight: 15%): Is the change scope appropriate? Could it break unrelated functionality? Does it modify shared interfaces unnecessarily?

5. **MERGE-READINESS** (weight: 15%): Would a senior engineer approve this PR as-is? Score 8+ = approve, 5-7 = request changes, below 5 = reject.

---

Respond with ONLY this JSON (no other text):
```json
{{
  "correctness": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}},
  "completeness": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}},
  "code_quality": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}},
  "non_regression_risk": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}},
  "merge_readiness": {{"score": <0-10>, "reasoning": "<2-4 sentences>", "flags": ["<issue1>", ...]}}
}}
```"""


FEATURES_TEMPLATE = """- Files changed: {num_files_changed}
- Lines added: {num_lines_added}, removed: {num_lines_removed}
- Hunks: {num_hunks}
- Change scope: {change_scope}
- Added functions: {added_functions}
- Modified functions: {modified_functions}
- Error handling present: {has_error_handling}
- Edge case handling: {has_edge_case_handling}
- Has TODOs/FIXMEs: {has_todos}
- Has hardcoded values: {has_hardcoded_values}
- Has debug statements: {has_debug_statements}
- Modifies core files: {modifies_core_files}
- New imports: {new_imports}
- Issue keyword coverage: {keyword_coverage_ratio:.0%}
- Touches test files: {touches_tests}
- Style violations: {style_violations}"""


# ============================================================================
# PatchJudge Class
# ============================================================================

class PatchJudge:
    """LLM-based judge for evaluating AI-generated code patches."""
    
    WEIGHTS = {
        "correctness": 0.30,
        "completeness": 0.20,
        "code_quality": 0.20,
        "non_regression_risk": 0.15,
        "merge_readiness": 0.15,
    }
    
    DIMENSIONS = list(WEIGHTS.keys())
    
    def __init__(
        self,
        model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
        provider: str = "auto",
        temperature: float = 0.1,
        max_tokens: int = 2000,
        max_retries: int = 3,
        retry_delay: float = 2.0,
        max_context_chars: int = 12000,
    ):
        """Initialize PatchJudge.
        
        Args:
            model_id: HF model ID to use for judging.
            provider: Inference provider ('auto', 'cerebras', 'novita', etc.)
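            temperature: Sampling temperature; keep it low for consistent
                scores (0.1 recommended).
            max_tokens: Max tokens for LLM response.
            max_retries: Total attempts per patch on API/parse failures.
            retry_delay: Base delay in seconds between attempts; the wait
                grows linearly (retry_delay * attempt number).
            max_context_chars: Character budget for the prompt: half for the
                agent patch, a quarter each for the gold patch and the
                problem statement.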
        """
        token = os.environ.get("HF_TOKEN")
        self.client = InferenceClient(
            provider=provider,
            api_key=token,
        )
        self.model_id = model_id
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.max_context_chars = max_context_chars
        self.feature_extractor = FeatureExtractor()
    
    def judge(
        self,
        example: PatchExample,
        features: Optional[PatchFeatures] = None,
    ) -> JudgeResult:
        """Evaluate a single patch example.
        
        Args:
            example: The patch to evaluate.
            features: Pre-extracted features (extracted automatically if None).
            
        Returns:
            JudgeResult with MergeScore, dimension scores, and reasoning.
        """
        # Extract features if not provided
        if features is None:
            features = self.feature_extractor.extract(example)
        
        # Format the prompt
        features_summary = self._format_features(features)
        
        # Truncate patches if needed
        agent_patch = self._truncate(example.agent_patch, self.max_context_chars // 2)
        gold_patch = self._truncate(example.gold_patch, self.max_context_chars // 4)
        problem_stmt = self._truncate(example.problem_statement, self.max_context_chars // 4)
        
        user_prompt = JUDGE_USER_PROMPT.format(
            problem_statement=problem_stmt,
            agent_patch=agent_patch,
            gold_patch=gold_patch,
            features_summary=features_summary,
            test_result="PASSED ✓" if example.test_passed else "FAILED ✗",
        )
        
        # Call LLM with retries
        raw_output = None
        scores = None
        
        for attempt in range(self.max_retries):
            try:
                raw_output = self._call_llm(user_prompt)
                scores = self._parse_json_output(raw_output)
                self._validate_scores(scores)
                break
            except Exception as e:
                logger.warning(
                    f"Attempt {attempt+1}/{self.max_retries} failed: {e}"
                )
                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay * (attempt + 1))
        
        if scores is None:
            # Return a failure result
            logger.error(
                f"Failed to judge {example.instance_id} after {self.max_retries} attempts"
            )
            scores = {
                dim: {"score": 0, "reasoning": "Judge failed to produce valid output", "flags": ["JUDGE_ERROR"]}
                for dim in self.DIMENSIONS
            }
            raw_output = raw_output or "ERROR: No output from LLM"
        
        # Compute MergeScore
        merge_score = self._compute_merge_score(scores)
        
        return JudgeResult(
            merge_score=merge_score,
            dimension_scores=scores,
            raw_output=raw_output,
            features=features,
            model_used=self.model_id,
        )
    
    def judge_batch(
        self,
        examples: list[PatchExample],
        features_list: Optional[list[PatchFeatures]] = None,
        show_progress: bool = True,
    ) -> list[JudgeResult]:
        """Evaluate a batch of patches.
        
        Args:
            examples: List of PatchExamples to evaluate.
            features_list: Pre-extracted features (one per example). Optional.
            show_progress: Print progress.
            
        Returns:
            List of JudgeResults in same order as input.
        """
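        # Fail fast on a length mismatch instead of raising IndexError mid-batch
        if features_list and len(features_list) != len(examples):
            raise ValueError("features_list must have one entry per example")

        results = []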
        
        for i, example in enumerate(examples):
            if show_progress:
                print(f"  Judging [{i+1}/{len(examples)}] {example.instance_id} "
                      f"({example.agent_name})...")
            
            features = features_list[i] if features_list else None
            
            try:
                result = self.judge(example, features)
                results.append(result)
                
                if show_progress:
                    print(f"    MergeScore: {result.merge_score:.1f}/100")
                    
            except Exception as e:
                logger.error(f"Failed to judge {example.instance_id}: {e}")
                # Append error result
                results.append(JudgeResult(
                    merge_score=0.0,
                    dimension_scores={
                        dim: {"score": 0, "reasoning": f"Error: {str(e)}", "flags": ["ERROR"]}
                        for dim in self.DIMENSIONS
                    },
                    raw_output=f"ERROR: {str(e)}",
                    model_used=self.model_id,
                ))
            
            # Rate limiting
            time.sleep(0.5)
        
        return results
    
    # =========================================================================
    # Internal methods
    # =========================================================================
    
    def _call_llm(self, user_prompt: str) -> str:
        """Call the LLM and return raw text response."""
        response = self.client.chat_completion(
            model=self.model_id,
            messages=[
                {"role": "system", "content": JUDGE_SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=self.max_tokens,
            temperature=self.temperature,
        )
        return response.choices[0].message.content
    
    def _compute_merge_score(self, scores: dict) -> float:
        """Compute weighted MergeScore (0-100) from dimension scores."""
        weighted_sum = 0.0
        for dim, weight in self.WEIGHTS.items():
            dim_score = scores.get(dim, {}).get("score", 0)
            weighted_sum += dim_score * weight
        return round(weighted_sum * 10, 1)  # Scale 0-10 → 0-100
    
    def _parse_json_output(self, raw: str) -> dict:
        """Extract JSON from LLM output, handling markdown code blocks."""
        # Try to find JSON in code blocks
        json_match = re.search(r'```(?:json)?\s*([\{][\s\S]*?[\}])\s*```', raw)
        if json_match:
            return json.loads(json_match.group(1))
        
        # Otherwise take everything from the first '{' to the last '}'
        json_match = re.search(r'(\{[\s\S]*\})', raw)
        if json_match:
            # Try the whole span first
            text = json_match.group(1)
            try:
                return json.loads(text)
            except json.JSONDecodeError:
                pass

            # Fall back to the shortest brace-balanced prefix that parses.
            # Braces inside JSON strings are counted too, so this is best-effort.
            depth = 0
            for i, ch in enumerate(text):
                if ch == '{':
                    depth += 1
                elif ch == '}':
                    depth -= 1
                    if depth == 0:
                        try:
                            return json.loads(text[:i+1])
                        except json.JSONDecodeError:
                            continue
        
        raise ValueError(f"Could not parse JSON from LLM output: {raw[:200]}...")
    
    def _validate_scores(self, scores: dict) -> None:
        """Validate that all required dimensions are present with valid scores."""
        for dim in self.DIMENSIONS:
            if dim not in scores:
                raise ValueError(f"Missing dimension: {dim}")
            if not isinstance(scores[dim], dict) or "score" not in scores[dim]:
                raise ValueError(f"Missing score for {dim}")
            score = scores[dim]["score"]
            # bool is a subclass of int, so reject it explicitly
            if (isinstance(score, bool) or not isinstance(score, (int, float))
                    or score < 0 or score > 10):
                raise ValueError(f"Invalid score for {dim}: {score}")
            # Ensure score is int
            scores[dim]["score"] = int(round(score))
            # Ensure flags is a list
            if "flags" not in scores[dim]:
                scores[dim]["flags"] = []
            if isinstance(scores[dim]["flags"], str):
                scores[dim]["flags"] = [scores[dim]["flags"]]
            # Ensure reasoning exists
            if "reasoning" not in scores[dim]:
                scores[dim]["reasoning"] = ""
    
    def _format_features(self, features: PatchFeatures) -> str:
        """Format features into a readable summary for the prompt."""
        d = features.to_dict()
        # Format lists as comma-separated
        for key in ['added_functions', 'modified_functions', 'new_imports',
                     'style_violations', 'issue_keywords_addressed',
                     'issue_components_mentioned']:
            if isinstance(d.get(key), list):
                d[key] = ', '.join(str(x) for x in d[key][:10]) or 'none'
        
        return FEATURES_TEMPLATE.format(**d)
    
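    def _truncate(self, text: str, max_chars: int) -> str:
        """Truncate text to at most max_chars, keeping the beginning and end."""
        if len(text) <= max_chars:
            return text
        marker = "\n\n... [truncated] ...\n\n"
        # Reserve room for the marker so the result never exceeds max_chars
        half = max((max_chars - len(marker)) // 2, 0)
        if half == 0:
            return text[:max_chars]
        return text[:half] + marker + text[-half:]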


# ============================================================================
# Convenience functions
# ============================================================================

def quick_judge(
    problem_statement: str,
    agent_patch: str,
    gold_patch: str = "",
    test_passed: bool = True,
    model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
) -> JudgeResult:
    """Quick one-shot evaluation of a patch.
    
    Args:
        problem_statement: The GitHub issue text.
        agent_patch: The AI-generated diff.
        gold_patch: Optional reference patch.
        test_passed: Whether tests passed.
        model_id: LLM to use.
        
    Returns:
        JudgeResult with MergeScore and breakdown.
    """
    example = PatchExample(
        instance_id="quick-judge",
        repo="unknown",
        problem_statement=problem_statement,
        gold_patch=gold_patch,
        agent_patch=agent_patch,
        agent_name="unknown",
        test_passed=test_passed,
        base_commit="",
    )
    
    judge = PatchJudge(model_id=model_id)
    return judge.judge(example)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    
    # Quick test with a sample
    result = quick_judge(
        problem_statement="Fix the divide by zero error in calculate_average when the list is empty",
        agent_patch="""diff --git a/utils.py b/utils.py
--- a/utils.py
+++ b/utils.py
@@ -10,4 +10,6 @@
 def calculate_average(numbers):
-    return sum(numbers) / len(numbers)
+    if not numbers:
+        return 0.0
+    return sum(numbers) / len(numbers)
""",
        gold_patch="""diff --git a/utils.py b/utils.py
--- a/utils.py
+++ b/utils.py
@@ -10,4 +10,7 @@
 def calculate_average(numbers):
+    if not numbers:
+        raise ValueError("Cannot calculate average of empty list")
     return sum(numbers) / len(numbers)
""",
        test_passed=True,
    )
    
    print(result.summary())