"""
graders.py - Execution-Grounded Reward Function
=================================================
What makes this environment unique: reward is computed from REAL
DuckDB execution results, not just keyword heuristics.
Scoring breakdown (sums to 1.0):

    Real Execution Speedup   35%  - actual timing ratio measured in DuckDB
    Result Correctness       20%  - do both queries return identical data?
    Issue Detection          25%  - keyword match vs. ground-truth issues
    Approval Correctness      8%  - does approve/reject match the expectation?
    Summary Quality           7%  - is the written analysis thorough?
    Severity Labels           5%  - are severity values present?
"""
from typing import Any, Dict, List

from executor import get_executor
from models import Action, Reward
# ── Helpers ────────────────────────────────────────────────────────────────
def _kw_match(text: str, keywords: List[str]) -> bool:
    """Case-insensitive check: does any keyword appear in the text?"""
    t = text.lower()
    return any(kw.lower() in t for kw in keywords)
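# Illustrative behaviour (hypothetical strings): _kw_match("Missing index
# forces a full scan", ["index", "sort"]) returns True, because "index"
# appears as a case-insensitive substring.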
def _suggestions_text(action: Action) -> str:
    """Flatten every free-text field of the action into one searchable string."""
    parts = [
        action.summary or "",
        action.optimized_query or "",
        action.estimated_improvement or "",
    ]
    for s in action.suggestions:
        parts += [
            str(s.get("issue_type", "")),
            str(s.get("description", "")),
            str(s.get("fix", "")),
            str(s.get("severity", "")),
        ]
    return " ".join(parts)
# ── Speedup → score mapping ────────────────────────────────────────────────
def _speedup_score(speedup: float, has_error: bool) -> float:
    """Map real speedup ratio to a score in [0.0, 0.35]."""
    if has_error:
        return 0.0
    if speedup >= 15.0:
        return 0.35
    if speedup >= 8.0:
        return 0.30
    if speedup >= 4.0:
        return 0.25
    if speedup >= 2.0:
        return 0.18
    if speedup >= 1.2:
        return 0.10
    if speedup >= 0.9:  # slightly slower -- acceptable
        return 0.04
    return 0.0  # regression
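# Worked examples, read straight off the thresholds above:
#   _speedup_score(16.0, False) -> 0.35   (>= 15x)
#   _speedup_score(5.0,  False) -> 0.25   (>= 4x, < 8x)
#   _speedup_score(1.0,  False) -> 0.04   (inside the 0.9x tolerance band)
#   _speedup_score(20.0, True)  -> 0.00   (any execution error zeroes the score)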
# ── Main grader ────────────────────────────────────────────────────────────
def grade(task_data: Dict[str, Any], action: Action) -> Reward:
    original_query: str = task_data["sql_query"]
    optimized_query: str = (action.optimized_query or "").strip()
    ground_truth: List[Dict[str, Any]] = task_data["ground_truth_issues"]
    full_text = _suggestions_text(action)
    # ── 1. Real Execution (0.0–0.35) ──────────────────────────────────
    exec_info: Dict[str, Any] = {}
    speedup_sc = 0.0
    correctness_sc = 0.0
    exec_feedback: List[str] = []
    if optimized_query:
        try:
            ex = get_executor()
            exec_info = ex.compare(original_query, optimized_query)
            speedup = exec_info.get("speedup", 1.0)
            r_match = exec_info.get("results_match", False)
            opt_err = exec_info.get("optimized_error")
            # 1a. Speedup score
            speedup_sc = _speedup_score(speedup, bool(opt_err))
            # 1b. Correctness score (0.0–0.20)
            if opt_err:
                correctness_sc = 0.0
            elif r_match:
                correctness_sc = 0.20
            elif exec_info.get("optimized_rows", 0) > 0:
                # Query ran but different results -- partial credit
                correctness_sc = 0.05
            # Feedback lines
            exec_feedback = [
                "\n[DuckDB Execution Results]",
                f"  Original  : {exec_info['original_ms']:.1f} ms "
                f"({exec_info['original_rows']} rows)",
                f"  Optimized : {exec_info['optimized_ms']:.1f} ms "
                f"({exec_info['optimized_rows']} rows)",
                f"  Speedup   : {speedup:.2f}x",
                f"  Correct?  : {'YES' if r_match else 'NO -- results differ'}",
                f"  Verdict   : {exec_info.get('verdict', '')}",
            ]
            if opt_err:
                exec_feedback.append(f"  SQL Error : {opt_err[:200]}")
        except Exception as exc:
            exec_feedback = [f"\n[WARN] Execution engine error: {exc}"]
    # ── 2. Issue Detection (0.0–0.25) ─────────────────────────────────
    detected = 0
    detection_fb: List[str] = ["\n[Issue Detection]"]
    for gt in ground_truth:
        found = _kw_match(full_text, gt["keywords"])
        if found:
            detected += 1
            detection_fb.append(f"  [FOUND] {gt['type']} (line ~{gt['line']})")
        else:
            detection_fb.append(f"  [MISS ] {gt['type']} (line ~{gt['line']})")
    detection_sc = (detected / len(ground_truth)) * 0.25 if ground_truth else 0.0
    # ── 3. Approval Correctness (0.0–0.08) ────────────────────────────
    expected_approved = task_data.get("approved_expected", False)
    approval_sc = 0.08 if action.approved == expected_approved else 0.0
    # ── 4. Summary Quality (0.0–0.07) ─────────────────────────────────
    summary_sc = 0.0
    slen = len(action.summary or "")  # None-safe, matching _suggestions_text
    if slen > 50:
        summary_sc = 0.03
    if slen > 120:
        summary_sc = 0.07
    # ── 5. Severity Labels (0.0–0.05) ─────────────────────────────────
    sev_kw = ["critical", "high", "medium", "low"]
    has_sev = any(
        _kw_match(str(s.get("severity", "")), sev_kw) for s in action.suggestions
    )
    severity_sc = 0.05 if has_sev else 0.0
    # ── Total ─────────────────────────────────────────────────────────
    total = min(
        max(speedup_sc + correctness_sc + detection_sc +
            approval_sc + summary_sc + severity_sc, 0.0),
        1.0,
    )
    total = round(total, 4)
    if total == 0.0 and action.suggestions:
        total = 0.02  # minimum signal for any submission
    breakdown = {
        "execution_speedup": round(speedup_sc, 4),
        "result_correctness": round(correctness_sc, 4),
        "issue_detection": round(detection_sc, 4),
        "approval_correctness": round(approval_sc, 4),
        "summary_quality": round(summary_sc, 4),
        "severity_labels": round(severity_sc, 4),
    }
    feedback = "\n".join(
        exec_feedback
        + detection_fb
        + [
            f"\n  Suggestions submitted: {len(action.suggestions)} "
            f"(expected ~{len(ground_truth)})",
            f"  Approval: {'OK' if action.approved == expected_approved else 'WRONG'} "
            f"(got {'approved' if action.approved else 'rejected'}, "
            f"expected {'approved' if expected_approved else 'rejected'})",
            f"\nTotal score: {total:.4f}",
        ]
    )
    return Reward(score=total, breakdown=breakdown, feedback=feedback)
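# ── Smoke test ─────────────────────────────────────────────────────────────
# Minimal usage sketch, not part of the grading path. It assumes Action is
# keyword-constructible (e.g. a pydantic model) with exactly the fields read
# above, and that the executor can run the trivial query below; real tasks
# reference the environment's seeded tables. All literal values here are
# hypothetical.
if __name__ == "__main__":
    demo_task = {
        "sql_query": "SELECT 1 AS x",
        "ground_truth_issues": [
            {"type": "select_star", "line": 1, "keywords": ["select *", "wildcard"]},
        ],
        "approved_expected": False,
    }
    demo_action = Action(
        summary=(
            "The query projects every column via SELECT * (a wildcard), which "
            "makes DuckDB materialise data the caller never reads; listing "
            "only the required columns avoids that cost."
        ),
        optimized_query="SELECT 1 AS x",
        estimated_improvement="~1x (the query is already trivial)",
        suggestions=[
            {
                "issue_type": "select *",
                "description": "wildcard projection",
                "fix": "list only the required columns",
                "severity": "low",
            }
        ],
        approved=False,
    )
    reward = grade(demo_task, demo_action)
    print(reward.feedback)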