File size: 5,275 Bytes
761f203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc990fa
761f203
 
 
dc990fa
761f203
 
 
 
 
 
 
 
 
 
dc990fa
761f203
 
 
 
 
 
 
dc990fa
761f203
 
 
 
dc990fa
761f203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc990fa
761f203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc990fa
 
761f203
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from __future__ import annotations

import json
import os
import subprocess
import tempfile
from pathlib import Path

from openai import OpenAI

from env.models import FlakySleuthAction

# Human-readable meaning of each flakiness category code. The LLM judge
# prompt embeds the entry for the task's category; unknown codes fall back to
# a generic description there.
CATEGORY_DESCRIPTIONS = dict(
    TD="Time-Dependent: fails due to wall-clock time assumptions",
    TZD="Timezone-Dependent: fails across timezone settings",
    NOD="Non-Deterministic: fails due to randomness/non-determinism",
    NIO="Non-Idempotent-Outcome: passes first run, fails on repeated run",
    ID="Implementation-Dependent: fails due to runtime implementation details",
)

# Keyword heuristics per flakiness category, used for the pattern component of
# the fixer score. They are matched case-insensitively as plain substrings of
# the proposed fix, so a list must not hold entries that differ only by case:
# such duplicates would count the same token twice and inflate the score.
EXPECTED_FIX_PATTERNS = {
    "TD": ["freeze_time", "mock", "patch", "utcnow", "datetime", "monkeypatch"],
    "TZD": ["timezone", "utc", "pytz", "zoneinfo", "tzinfo"],
    "NOD": ["seed", "mock", "patch", "deterministic", "sorted"],
    "NIO": ["setup", "teardown", "fixture", "yield", "cleanup", "autouse"],
    "ID": ["sorted(", "list(", "frozenset", "OrderedDict"],
}


def grade(action: FlakySleuthAction, task: dict) -> float:
    """Hybrid fixer grader: pattern + dry-run apply + LLM judge.

    The result is a weighted blend of three signals:

    * 35% keyword pattern score (see ``_pattern_score``);
    * 25% dry-run apply score from ``_check_diff_applies``;
    * 40% LLM-judge score from ``_llm_judge``.

    Args:
        action: The agent action. Only ``action_type == "propose_fix"`` is
            graded; ``argument`` carries the proposed fix (normally a
            unified diff).
        task: Task record. Reads ``category`` (only the first entry of a
            ``;``-separated list is used), ``known_fix_diff`` and
            ``test_code``; the whole dict is forwarded to the apply check.

    Returns:
        The blended score, clamped to ``[0.001, 0.999]`` and rounded to
        4 decimals. ``0.001`` for any other action type or an empty proposal.
    """
    if action.action_type != "propose_fix":
        return 0.001

    # A malformed action may carry ``argument=None``; grade it as empty
    # instead of crashing on ``.strip()``.
    proposed_fix = (action.argument or "").strip()
    if not proposed_fix:
        return 0.001

    # ``or ""`` so an explicit ``None`` category does not become "NONE".
    category = str(task.get("category") or "").split(";")[0].strip().upper()
    known_fix = task.get("known_fix_diff", "") or ""
    test_code = task.get("test_code", "") or ""

    pattern_score = _pattern_score(proposed_fix, category)
    apply_score = _check_diff_applies(proposed_fix, task)
    judge_score = _llm_judge(proposed_fix, known_fix, category, test_code)

    total = (0.35 * pattern_score) + (0.25 * apply_score) + (0.40 * judge_score)
    return round(min(0.999, max(0.001, total)), 4)


def _pattern_score(proposed_fix: str, category: str) -> float:
    """Score how many of the category's fix keywords the proposal adds.

    Returns 0.5 (neutral) for a category with no configured patterns.
    Otherwise returns ``matches / (0.4 * number_of_patterns)``, capped at
    0.999: hitting about 40% of the patterns earns the maximum.
    """
    patterns = EXPECTED_FIX_PATTERNS.get(category, [])
    if not patterns:
        return 0.5
    # Matching is case-insensitive, so collapse entries that differ only by
    # case (e.g. "utc"/"UTC"); otherwise one token would be counted twice.
    unique_patterns = {pattern.lower() for pattern in patterns}
    haystack = _added_text(proposed_fix).lower()
    matches = sum(1 for pattern in unique_patterns if pattern in haystack)
    return min(0.999, matches / max(1, len(unique_patterns) * 0.4))


def _added_text(proposed_fix: str) -> str:
    """Return the text a unified diff adds; non-diff text is returned as-is.

    Context and removed lines are excluded so a proposal is not credited for
    keywords it merely shows or deletes.
    """
    lines = proposed_fix.splitlines()
    if not any(line.startswith(("--- ", "+++ ", "@@")) for line in lines):
        return proposed_fix
    return "\n".join(
        line[1:]
        for line in lines
        if line.startswith("+") and not line.startswith("+++")
    )


def _check_diff_applies(diff_text: str, task: dict) -> float:
    """Check whether a proposed unified diff applies cleanly to the sandbox.

    Runs ``patch --dry-run -p1`` with ``task["sandbox_root"]`` as the working
    directory, so the sandbox itself is never modified. The diff is written
    to a temporary file that is always deleted afterwards.

    Args:
        diff_text: The proposed fix text.
        task: Task record; only ``sandbox_root`` is read.

    Returns:
        ``0.001`` if the text lacks ``---``/``+++`` markers or ``patch``
        rejects it; ``0.999`` if the dry run succeeds; ``0.3`` (neutral)
        when the check cannot be performed (no existing sandbox, ``patch``
        unavailable, timeout, or any other error).
    """
    if "+++" not in diff_text or "---" not in diff_text:
        return 0.001

    # ``or ""`` so an explicit ``None`` is treated as missing, not as "None".
    repo_root = str(task.get("sandbox_root") or "").strip()
    if not repo_root or not Path(repo_root).exists():
        return 0.3

    # The caller strips the proposal, dropping the final newline; patch
    # complains about a last line that ends mid-line, so restore it.
    if not diff_text.endswith("\n"):
        diff_text += "\n"

    patch_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".patch", delete=False, encoding="utf-8"
        ) as handle:
            handle.write(diff_text)
            patch_path = handle.name

        result = subprocess.run(
            # --batch suppresses interactive questions; --forward makes a
            # reversed/already-applied patch fail instead of prompting
            # "Assume -R?". Otherwise patch can block until the timeout.
            ["patch", "--dry-run", "--batch", "--forward", "-p1", "-i", patch_path],
            cwd=repo_root,
            stdin=subprocess.DEVNULL,
            capture_output=True,
            text=True,
            timeout=10,
        )
        return 0.999 if result.returncode == 0 else 0.001
    except Exception:
        # Best-effort: inability to run the check yields a neutral score.
        return 0.3
    finally:
        if patch_path and os.path.exists(patch_path):
            os.unlink(patch_path)


def _llm_judge(proposed: str, known: str, category: str, test_code: str) -> float:
    """Ask an LLM to rate ``proposed`` on a 0-10 scale and normalise it.

    The prompt contains the category and its description, the first 1000
    characters of the flaky test and of the proposed fix, and up to 800
    characters of the known accepted fix when one is available.

    Args:
        proposed: The proposed fix (normally a unified diff).
        known: The known accepted fix diff, or ``""`` when unavailable.
        category: Upper-case flakiness category code (e.g. ``"TD"``).
        test_code: Source of the original flaky test.

    Returns:
        The judge's integer score divided by 10, clamped to
        ``[0.001, 0.999]``. ``0.5`` (neutral) when no API key is configured,
        or when the request or response parsing fails.
    """
    config = _resolve_judge_config()
    if config is None:
        return 0.5
    api_key, api_base_url, model = config

    cat_desc = CATEGORY_DESCRIPTIONS.get(category, f"Flakiness category: {category}")
    if known:
        known_section = f"Known accepted fix (from merged PR):\n```\n{known[:800]}\n```"
    else:
        known_section = "Known fix: Not available"

    prompt = f"""You are evaluating a proposed fix for a flaky Python test.

Flakiness category: {category}
What this means: {cat_desc}

Original flaky test code:
```python
{test_code[:1000]}
```

Proposed fix (unified diff):
```
{proposed[:1000]}
```

{known_section}

Score the proposed fix from 0 to 10:
- 0-2: Fix is wrong, irrelevant, or harmful
- 3-5: Fix partially addresses the issue but misses root cause
- 6-8: Fix addresses root cause with minor issues
- 9-10: Fix is correct, minimal, and complete

Respond ONLY with JSON:
{{"score": <integer 0-10>, "reason": "<one sentence>"}}"""

    try:
        # Client construction is inside the try so a bad configuration gives
        # a neutral score instead of crashing the grader. The explicit timeout
        # bounds a slow/unreachable provider; the client default is minutes.
        client = OpenAI(api_key=api_key, base_url=api_base_url, timeout=30.0)
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=120,
            temperature=0.0,
        )
        raw = (response.choices[0].message.content or "").strip()
    except Exception:
        # Best-effort: any provider/network failure yields a neutral score.
        return 0.5
    return _parse_judge_score(raw)


def _resolve_judge_config() -> tuple[str, str, str] | None:
    """Pick ``(api_key, base_url, model)`` from the environment, or ``None``.

    Key precedence is ``API_KEY``, then ``OPENROUTER_API_KEY``, then
    ``OPENAI_API_KEY``. The default base URL follows the key that was
    actually selected. An OpenRouter key (``sk-or-`` prefix, or the
    ``OPENROUTER_API_KEY`` fallback) goes to OpenRouter; any other key goes
    to OpenAI. ``API_BASE_URL`` and ``MODEL_NAME`` override the defaults
    when set to a non-empty value.
    """
    # Strip each key up front so a whitespace-only variable cannot mask the
    # keys after it in the precedence order.
    raw_api_key = (os.environ.get("API_KEY") or "").strip()
    openrouter_key = (os.environ.get("OPENROUTER_API_KEY") or "").strip()
    openai_key = (os.environ.get("OPENAI_API_KEY") or "").strip()
    api_key = raw_api_key or openrouter_key or openai_key
    if not api_key:
        return None

    using_openrouter = api_key.startswith("sk-or-") or (
        not raw_api_key and bool(openrouter_key)
    )
    default_base_url = (
        "https://openrouter.ai/api/v1"
        if using_openrouter
        else "https://api.openai.com/v1"
    )
    # ``or`` so an exported-but-empty variable falls back to the default.
    api_base_url = os.environ.get("API_BASE_URL") or default_base_url
    model = os.environ.get("MODEL_NAME") or (
        "qwen/qwen3.6-plus:free"
        if api_base_url.startswith("https://openrouter.ai")
        else "gpt-4o-mini"
    )
    return api_key, api_base_url, model


def _parse_judge_score(raw: str) -> float:
    """Turn the judge's ``{"score": ...}`` reply into ``[0.001, 0.999]``.

    Code fences are removed. If the reply is not pure JSON, the outermost
    ``{...}`` span is tried instead. A missing ``score`` defaults to 5.
    Any unparseable reply yields 0.5.
    """
    text = raw.replace("```json", "").replace("```", "").strip()
    try:
        payload = json.loads(text)
    except ValueError:
        # Tolerate prose around the JSON object, e.g. "Verdict: {...}".
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            return 0.5
        try:
            payload = json.loads(text[start : end + 1])
        except ValueError:
            return 0.5
    if not isinstance(payload, dict):
        return 0.5
    try:
        # float() first so "7" or 7.5 are accepted; NaN/inf fail in int().
        score = int(float(payload.get("score", 5)))
    except (TypeError, ValueError, OverflowError):
        return 0.5
    raw_score = max(0.0, min(10.0, score)) / 10.0
    return max(0.001, min(0.999, raw_score))