# tasks/grader_medium.py
"""

Grader for logic_fix_001 β€” fix wrong JOIN type / WHERE logic.



Harder than easy: agent must get BOTH precision and recall right.

Extra penalty for wrong row count (catches SELECT * with no WHERE).

"""


def grade(
    task: dict,
    agent_query: str,
    run_result: dict,
    prev_absolute_score: float = 0.0,
    step_count: int = 1,
    max_steps: int = 8,
) -> dict:

    # ── Syntax ────────────────────────────────────────────────────────────────
    syntax_ok = run_result["error"] is None
    if not syntax_ok:
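        # Failed queries get the floor score; delta is still clamped, so a
        # regression from a high previous score never exceeds -0.3.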
        absolute_score = 0.05
        delta = round(
            max(-0.3, min(0.5, absolute_score - prev_absolute_score)), 4
        )
        return {
            "value": delta,
            "absolute_score": absolute_score,
            "syntax_ok": False,
            "result_score": 0.0,
            "plan_score": 0.0,
            "delta": delta,
            "status": "syntax_error",
            "feedback": f"syntax_error: {run_result['error'][:100]}",
            "message": f"syntax_error | abs=0.050 | delta={delta:+.3f}",
        }

    expected = task["expected_rows"]
    got = run_result["rows"]

    # ── F1 row score ──────────────────────────────────────────────────────────
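    # Precision over returned rows, recall over expected rows; the score is
    # their harmonic mean (F1). Membership tests compare rows by equality.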
    if not got:
        result_score = 0.0
    else:
        correct_returned = sum(1 for row in got if row in expected)
        correct_expected = sum(1 for row in expected if row in got)

        precision = correct_returned / max(len(got), 1)
        recall    = correct_expected / max(len(expected), 1)

        if precision + recall > 0:
            result_score = 2 * precision * recall / (precision + recall)
        else:
            result_score = 0.0

    # ── Extra penalty for wrong row count ─────────────────────────────────────
    # Logic bugs typically surface as too many rows (a LEFT JOIN keeps
    # unmatched rows, padding them with NULLs). Penalize harder than the
    # easy task to encourage precise reasoning.
    row_count_penalty = 0.0
    if len(got) > len(expected):
        extra = len(got) - len(expected)
        row_count_penalty = min(0.25, extra * 0.08)

    # ── JOIN type hint score ──────────────────────────────────────────────────
    # Gives partial credit for using the right JOIN type even if the rows are off.
    # Avoids a zero-reward cliff for agents that fix the JOIN but have minor issues.
    query_upper = agent_query.upper()
    join_score = 0.0
    if "INNER JOIN" in query_upper:
        join_score = 0.15   # using INNER JOIN is the right direction
    elif "LEFT JOIN" in query_upper:
        join_score = 0.0    # LEFT JOIN is the bug β€” no credit
    elif "JOIN" in query_upper:
        join_score = 0.05   # some join exists β€” small credit

    # ── Efficiency bonus ──────────────────────────────────────────────────────
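    # Only a fully correct result earns the bonus, scaled by the steps left.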
    steps_remaining = max_steps - step_count
    efficiency_bonus = 0.0
    if result_score >= 0.99:
        efficiency_bonus = round(0.05 * (steps_remaining / max_steps), 4)

    # ── Absolute score (medium): syntax 10% + correctness 70% + join 15% + bonus 5% ──
    absolute_score = round(
        min(
            0.99,
            0.10 * 1.0
            + 0.70 * result_score
            + 0.15 * join_score
            + efficiency_bonus
            - row_count_penalty,
        ),
        4,
    )
    absolute_score = max(0.05, absolute_score)  # floor at 0.05

    # ── Delta ─────────────────────────────────────────────────────────────────
    delta = absolute_score - prev_absolute_score
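    # Stall penalty: an unchanged score on any step after the first loses 0.02.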
    if abs(delta) < 0.001 and step_count > 1:
        delta -= 0.02
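    # Clamp to an asymmetric band: regressions capped at -0.3, gains at +0.5.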
    delta = round(max(-0.3, min(0.5, delta)), 4)

    # ── Feedback ─────────────────────────────────────────────────────────────
    issues = []
    if "LEFT JOIN" in query_upper:
        issues.append("join_type: using LEFT JOIN includes rows with no matching department")
    if len(got) > len(expected):
        issues.append(f"extra_rows: got {len(got)} rows, expected {len(expected)}; filter too loose")
    if 0 < len(got) < len(expected):
        issues.append(f"missing_rows: got {len(got)} rows, expected {len(expected)}; filter too strict")
    if result_score < 0.5:
        issues.append("result_rows: output does not match expected; check JOIN and WHERE")
    feedback = "; ".join(issues) if issues else "rows and join look correct"

    status = (
        "solved"     if absolute_score >= 0.99
        else "improving" if delta > 0.01
        else "regression" if delta < -0.01
        else "stalled"
    )

    return {
        "value": delta,
        "absolute_score": absolute_score,
        "syntax_ok": True,
        "result_score": result_score,
        "plan_score": join_score,
        "delta": delta,
        "status": status,
        "feedback": feedback,
        "message": (
            f"{status} | abs={absolute_score:.3f} | delta={delta:+.3f} | "
            f"result={result_score:.0%} | join={join_score:.2f}"
        ),
    }
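

if __name__ == "__main__":
    # Quick smoke test with hypothetical inputs. The real `task` and
    # `run_result` payloads come from the task runner; the shapes and the
    # employees/departments schema below are assumptions for illustration.
    demo_task = {"expected_rows": [(1, "Alice"), (2, "Bob")]}
    demo_run = {
        "error": None,
        "rows": [(1, "Alice"), (2, "Bob"), (3, None)],  # one spurious row from the LEFT JOIN
    }
    demo_query = (
        "SELECT e.id, e.name FROM employees e "
        "LEFT JOIN departments d ON e.dept_id = d.id"
    )
    result = grade(demo_task, demo_query, demo_run,
                   prev_absolute_score=0.05, step_count=2)
    print(result["message"])  # e.g. improving | abs=0.580 | delta=+0.500 | result=80% | join=0.00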