| """ | |
| graders.py β Execution-Grounded Reward Function | |
| ================================================= | |
| What makes this environment unique: reward is computed from REAL | |
| DuckDB execution results, not just keyword heuristics. | |
| Scoring breakdown (sums to 1.0): | |
| Real Execution Speedup 35% β actual timing ratio from DuckDB | |
| Result Correctness 20% β both queries return identical data? | |
| Issue Detection 25% β keyword match vs ground truth | |
| Approval Correctness 8% β correctly flags query as bad? | |
| Summary Quality 7% β is the written analysis thorough? | |
| Severity Labels 5% β are severity values present? | |
| """ | |

from typing import Any, Dict, List

from executor import get_executor
from models import Action, Reward


# ── Helpers ──────────────────────────────────────────────────────────────────
def _kw_match(text: str, keywords: List[str]) -> bool:
    """Case-insensitively check whether `text` contains any of `keywords`."""
    t = text.lower()
    return any(kw.lower() in t for kw in keywords)
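# e.g. _kw_match("Full TABLE scan on orders", ["table scan", "seq scan"]) -> True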


def _suggestions_text(action: Action) -> str:
    """Flatten every free-text field of an Action into one searchable string."""
    # `or ""` guards: these fields may be None on a partial submission
    # (see the same guard in grade()), and " ".join() raises on None.
    parts = [
        action.summary or "",
        action.optimized_query or "",
        action.estimated_improvement or "",
    ]
    for s in action.suggestions:
        parts += [
            str(s.get("issue_type", "")),
            str(s.get("description", "")),
            str(s.get("fix", "")),
            str(s.get("severity", "")),
        ]
    return " ".join(parts)


# ── Speedup → score mapping ───────────────────────────────────────────────────
def _speedup_score(speedup: float, has_error: bool) -> float:
    """Map a measured speedup ratio to a score in [0.0, 0.35]."""
    if has_error:
        return 0.0
    if speedup >= 15.0:
        return 0.35
    if speedup >= 8.0:
        return 0.30
    if speedup >= 4.0:
        return 0.25
    if speedup >= 2.0:
        return 0.18
    if speedup >= 1.2:
        return 0.10
    if speedup >= 0.9:  # at worst slightly slower -> acceptable
        return 0.04
    return 0.0  # regression
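

# Illustrative tier checks (these follow directly from the thresholds above):
#   _speedup_score(9.3, False)  -> 0.30   (8x tier)
#   _speedup_score(1.0, False)  -> 0.04   (roughly unchanged)
#   _speedup_score(0.5, False)  -> 0.0    (regression)
#   _speedup_score(20.0, True)  -> 0.0    (an execution error zeroes the score)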


# ── Main grader ───────────────────────────────────────────────────────────────
def grade(task_data: Dict[str, Any], action: Action) -> Reward:
    """Score an Action against ground truth plus real DuckDB measurements."""
    original_query: str = task_data["sql_query"]
    optimized_query: str = (action.optimized_query or "").strip()
    ground_truth: List[Dict[str, Any]] = task_data["ground_truth_issues"]
    full_text = _suggestions_text(action)

    # ── 1. Real Execution: speedup (0.0-0.35) + correctness (0.0-0.20) ──
    exec_info: Dict[str, Any] = {}
    speedup_sc = 0.0
    correctness_sc = 0.0
    exec_feedback: List[str] = []
    if optimized_query:
        try:
            ex = get_executor()
            exec_info = ex.compare(original_query, optimized_query)
            speedup = exec_info.get("speedup", 1.0)
            r_match = exec_info.get("results_match", False)
            opt_err = exec_info.get("optimized_error")

            # 1a. Speedup score (0.0-0.35)
            speedup_sc = _speedup_score(speedup, bool(opt_err))

            # 1b. Correctness score (0.0-0.20)
            if opt_err:
                correctness_sc = 0.0
            elif r_match:
                correctness_sc = 0.20
            elif exec_info.get("optimized_rows", 0) > 0:
                # Query ran but returned different results -- partial credit
                correctness_sc = 0.05

            # Feedback lines
            exec_feedback = [
                "\n[DuckDB Execution Results]",
                f"  Original  : {exec_info['original_ms']:.1f} ms "
                f"({exec_info['original_rows']} rows)",
                f"  Optimized : {exec_info['optimized_ms']:.1f} ms "
                f"({exec_info['optimized_rows']} rows)",
                f"  Speedup   : {speedup:.2f}x",
                f"  Correct?  : {'YES' if r_match else 'NO -- results differ'}",
                f"  Verdict   : {exec_info.get('verdict', '')}",
            ]
            if opt_err:
                exec_feedback.append(f"  SQL Error : {opt_err[:200]}")
        except Exception as exc:
            exec_feedback = [f"\n[WARN] Execution engine error: {exc}"]

    # ── 2. Issue Detection (0.0-0.25) ────────────────────────────────
    detected = 0
    detection_fb: List[str] = ["\n[Issue Detection]"]
    for gt in ground_truth:
        found = _kw_match(full_text, gt["keywords"])
        if found:
            detected += 1
            detection_fb.append(f"  [FOUND] {gt['type']} (line ~{gt['line']})")
        else:
            detection_fb.append(f"  [MISS ] {gt['type']} (line ~{gt['line']})")
    detection_sc = (detected / len(ground_truth)) * 0.25 if ground_truth else 0.0
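    # e.g. 3 of 4 ground-truth issues matched -> (3 / 4) * 0.25 = 0.1875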

    # ── 3. Approval Correctness (0.0-0.08) ───────────────────────────
    expected_approved = task_data.get("approved_expected", False)
    approval_sc = 0.08 if action.approved == expected_approved else 0.0

    # ── 4. Summary Quality (0.0-0.07) ────────────────────────────────
    summary_sc = 0.0
    slen = len(action.summary or "")
    if slen > 50:
        summary_sc = 0.03
    if slen > 120:  # longer summary supersedes the 0.03 tier
        summary_sc = 0.07

    # ── 5. Severity Labels (0.0-0.05) ────────────────────────────────
    sev_kw = ["critical", "high", "medium", "low"]
    has_sev = any(
        _kw_match(str(s.get("severity", "")), sev_kw) for s in action.suggestions
    )
    severity_sc = 0.05 if has_sev else 0.0

    # ── Total ─────────────────────────────────────────────────────────
    total = min(
        max(speedup_sc + correctness_sc + detection_sc +
            approval_sc + summary_sc + severity_sc, 0.0),
        1.0,
    )
    total = round(total, 4)
    if total == 0.0 and action.suggestions:
        total = 0.02  # minimum signal for any non-empty submission

    breakdown = {
        "execution_speedup": round(speedup_sc, 4),
        "result_correctness": round(correctness_sc, 4),
        "issue_detection": round(detection_sc, 4),
        "approval_correctness": round(approval_sc, 4),
        "summary_quality": round(summary_sc, 4),
        "severity_labels": round(severity_sc, 4),
    }

    feedback = "\n".join(
        exec_feedback
        + detection_fb
        + [
            f"\n  Suggestions submitted: {len(action.suggestions)} "
            f"(expected ~{len(ground_truth)})",
            f"  Approval: {'✓' if action.approved == expected_approved else '✗'} "
            f"(got {'approved' if action.approved else 'rejected'}, "
            f"expected {'approved' if expected_approved else 'rejected'})",
            f"\nTotal score: {total:.4f}",
        ]
    )
    return Reward(score=total, breakdown=breakdown, feedback=feedback)
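

# ── Manual smoke test ─────────────────────────────────────────────────────────
# A minimal sketch of how `grade` is wired up end-to-end. The Action fields
# used below (summary, optimized_query, estimated_improvement, suggestions,
# approved) mirror the attributes read above, but the keyword-argument
# constructor call and all demo data are illustrative assumptions, not part
# of the environment spec. Running this also requires an executor whose
# DuckDB tables match the demo SQL.
if __name__ == "__main__":
    demo_task = {
        "sql_query": "SELECT * FROM orders",
        "ground_truth_issues": [
            {"type": "select_star", "keywords": ["select *", "explicit columns"], "line": 1},
        ],
        "approved_expected": False,
    }
    demo_action = Action(
        summary="SELECT * pulls every column; projecting only the needed "
                "columns cuts scan width and lets DuckDB prune more aggressively.",
        optimized_query="SELECT order_id FROM orders",
        estimated_improvement="~2x",
        suggestions=[{
            "issue_type": "select_star",
            "description": "Unbounded projection via SELECT *",
            "fix": "List only the required columns explicitly",
            "severity": "high",
        }],
        approved=False,
    )
    reward = grade(demo_task, demo_action)
    print(reward.feedback)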