"""
graders.py β€” Execution-Grounded Reward Function
=================================================
What makes this environment unique: reward is computed from REAL
DuckDB execution results, not just keyword heuristics.
Scoring breakdown (sums to 1.0):
Real Execution Speedup 35% β€” actual timing ratio from DuckDB
Result Correctness 20% β€” both queries return identical data?
Issue Detection 25% β€” keyword match vs ground truth
Approval Correctness 8% β€” correctly flags query as bad?
Summary Quality 7% β€” is the written analysis thorough?
Severity Labels 5% β€” are severity values present?
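
Worked example (illustrative, using the weights above): an answer whose
optimized query runs 5x faster with matching results (0.25 + 0.20), that
names 2 of 4 ground-truth issues (0.125), gives the expected approval
verdict (0.08), writes a summary over 120 characters (0.07), and labels
severities (0.05) scores 0.775 in total.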
"""
from typing import Any, Dict, List, Optional
from executor import get_executor
from models import Action, Reward
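
# NOTE: the shapes below are inferred from how this module uses them; the
# authoritative definitions live in models.py and executor.py.
#   Action : summary, optimized_query, estimated_improvement, approved,
#            suggestions (list of dicts with issue_type/description/fix/severity)
#   Reward : score, breakdown, feedback
#   get_executor().compare(original_sql, optimized_sql) -> dict with keys such as
#            speedup, results_match, optimized_error, original_ms, optimized_ms,
#            original_rows, optimized_rows, verdict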

# ── Helpers ──────────────────────────────────────────────────────────────
def _kw_match(text: str, keywords: List[str]) -> bool:
    t = text.lower()
    return any(kw.lower() in t for kw in keywords)

def _suggestions_text(action: Action) -> str:
    # Coalesce optional fields to "" so the join below cannot fail on None
    # (optimized_query, for example, may be unset -- see grade()).
    parts = [
        action.summary or "",
        action.optimized_query or "",
        action.estimated_improvement or "",
    ]
    for s in action.suggestions:
        parts += [
            str(s.get("issue_type", "")),
            str(s.get("description", "")),
            str(s.get("fix", "")),
            str(s.get("severity", "")),
        ]
    return " ".join(parts)

# ── Speedup → score mapping ───────────────────────────────────────────────
def _speedup_score(speedup: float, has_error: bool) -> float:
    """Map real speedup ratio to a score in [0.0, 0.35]."""
    if has_error:
        return 0.0
    if speedup >= 15.0:
        return 0.35
    if speedup >= 8.0:
        return 0.30
    if speedup >= 4.0:
        return 0.25
    if speedup >= 2.0:
        return 0.18
    if speedup >= 1.2:
        return 0.10
    if speedup >= 0.9:  # slightly slower -- acceptable
        return 0.04
    return 0.0  # regression
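
# Illustrative values for the mapping above:
#   _speedup_score(16.0, has_error=False) -> 0.35   (large, verified win)
#   _speedup_score(3.0,  has_error=False) -> 0.18   (moderate win)
#   _speedup_score(0.5,  has_error=False) -> 0.0    (regression)
#   _speedup_score(20.0, has_error=True)  -> 0.0    (an erroring query scores nothing)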

# ── Main grader ───────────────────────────────────────────────────────────
def grade(task_data: Dict[str, Any], action: Action) -> Reward:
    original_query: str = task_data["sql_query"]
    optimized_query: str = (action.optimized_query or "").strip()
    ground_truth: List[Dict[str, Any]] = task_data["ground_truth_issues"]
    full_text = _suggestions_text(action)

    # ── 1. Real Execution (0.0–0.35) ─────────────────────────────────
    exec_info: Dict[str, Any] = {}
    speedup_sc = 0.0
    correctness_sc = 0.0
    exec_feedback: List[str] = []
    if optimized_query:
        try:
            ex = get_executor()
            exec_info = ex.compare(original_query, optimized_query)
            speedup = exec_info.get("speedup", 1.0)
            r_match = exec_info.get("results_match", False)
            opt_err = exec_info.get("optimized_error")
            # 1a. Speedup score
            speedup_sc = _speedup_score(speedup, bool(opt_err))
            # 1b. Correctness score (0.0-0.20)
            if opt_err:
                correctness_sc = 0.0
            elif r_match:
                correctness_sc = 0.20
            elif exec_info.get("optimized_rows", 0) > 0:
                # Query ran but returned different results -- partial credit
                correctness_sc = 0.05
            # Feedback lines
            exec_feedback = [
                "\n[DuckDB Execution Results]",
                f"  Original  : {exec_info['original_ms']:.1f} ms "
                f"({exec_info['original_rows']} rows)",
                f"  Optimized : {exec_info['optimized_ms']:.1f} ms "
                f"({exec_info['optimized_rows']} rows)",
                f"  Speedup   : {speedup:.2f}x",
                f"  Correct?  : {'YES' if r_match else 'NO -- results differ'}",
                f"  Verdict   : {exec_info.get('verdict', '')}",
            ]
            if opt_err:
                exec_feedback.append(f"  SQL Error : {opt_err[:200]}")
        except Exception as exc:
            exec_feedback = [f"\n[WARN] Execution engine error: {exc}"]

    # ── 2. Issue Detection (0.0–0.25) ────────────────────────────────
    detected = 0
    detection_fb: List[str] = ["\n[Issue Detection]"]
    for gt in ground_truth:
        found = _kw_match(full_text, gt["keywords"])
        if found:
            detected += 1
            detection_fb.append(f"  [FOUND] {gt['type']} (line ~{gt['line']})")
        else:
            detection_fb.append(f"  [MISS ] {gt['type']} (line ~{gt['line']})")
    detection_sc = (detected / len(ground_truth)) * 0.25 if ground_truth else 0.0
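    # Illustrative: detecting 2 of 4 ground-truth issues yields (2 / 4) * 0.25 = 0.125.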

    # ── 3. Approval Correctness (0.0–0.08) ───────────────────────────
    expected_approved = task_data.get("approved_expected", False)
    approval_sc = 0.08 if action.approved == expected_approved else 0.0

    # ── 4. Summary Quality (0.0–0.07) ────────────────────────────────
    summary_sc = 0.0
    slen = len(action.summary or "")
    if slen > 50:
        summary_sc = 0.03
    if slen > 120:
        summary_sc = 0.07

    # ── 5. Severity Labels (0.0–0.05) ────────────────────────────────
    sev_kw = ["critical", "high", "medium", "low"]
    has_sev = any(
        _kw_match(str(s.get("severity", "")), sev_kw) for s in action.suggestions
    )
    severity_sc = 0.05 if has_sev else 0.0

    # ── Total ─────────────────────────────────────────────────────────
    total = min(
        max(speedup_sc + correctness_sc + detection_sc +
            approval_sc + summary_sc + severity_sc, 0.0),
        1.0,
    )
    total = round(total, 4)
    if total == 0.0 and action.suggestions:
        total = 0.02  # minimum signal for any submission
    breakdown = {
        "execution_speedup": round(speedup_sc, 4),
        "result_correctness": round(correctness_sc, 4),
        "issue_detection": round(detection_sc, 4),
        "approval_correctness": round(approval_sc, 4),
        "summary_quality": round(summary_sc, 4),
        "severity_labels": round(severity_sc, 4),
    }

    feedback = "\n".join(
        exec_feedback
        + detection_fb
        + [
            f"\n  Suggestions submitted: {len(action.suggestions)} "
            f"(expected ~{len(ground_truth)})",
            f"  Approval: {'✅' if action.approved == expected_approved else '❌'} "
            f"(got {'approved' if action.approved else 'rejected'}, "
            f"expected {'approved' if expected_approved else 'rejected'})",
            f"\n🏆 Total score: {total:.4f}",
        ]
    )
    return Reward(score=total, breakdown=breakdown, feedback=feedback)
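

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the grading pipeline. It assumes
    # Action accepts the fields used above as keyword arguments (see models.py
    # for the real definition) and that the hypothetical `orders` table exists
    # in the executor's DuckDB database; if it does not, the execution component
    # simply falls back to a warning line in the feedback.
    demo_task = {
        "sql_query": "SELECT * FROM orders",  # hypothetical slow query
        "ground_truth_issues": [
            {"type": "select_star", "line": 1, "keywords": ["select *", "column"]},
        ],
        "approved_expected": False,
    }
    demo_action = Action(
        summary="SELECT * pulls every column; project only the columns you need.",
        optimized_query="SELECT order_id, total FROM orders",  # hypothetical schema
        estimated_improvement="~2x less data scanned",
        suggestions=[
            {
                "issue_type": "select_star",
                "description": "Query selects all columns",
                "fix": "List only the required columns",
                "severity": "medium",
            }
        ],
        approved=False,
    )
    reward = grade(demo_task, demo_action)
    print(reward.score)
    print(reward.feedback)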