Spaces:
Sleeping
Sleeping
"""
QueryForge Judge — deterministic DuckDB grading + Anthropic AI quality scoring.

Grading pipeline for each submitted SQL query:

  Stage 1 — Syntax (0.0 → 0.15)
      DuckDB EXPLAIN parses the query. Fail → score = 0.0.
  Stage 2 — Execution (→ 0.30)
      Run the full query against in-memory DuckDB seeded with task data.
      Fail → score = 0.15 (syntax was fine, runtime error).
  Stage 3 — Correctness (→ 0.80)
      Compare returned rows against expected rows.
      Perfect match → deterministic score reaches 0.80.
      Partial credit for correct row count or partial row matches.
  Stage 4 — AI Quality (→ 1.0)
      Anthropic claude-haiku-4-5 evaluates optimization, code style, and
      semantic correctness vs. the reference solution.
      The AI score can move the final score up to 1.0 when rows are correct,
      or provide nuanced feedback even when rows are partially wrong.

Environment variable required:
  ANTHROPIC_API_KEY — standard Anthropic SDK key.
"""
| import json | |
| import re | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import anthropic | |
| import duckdb | |
| try: | |
| from .tasks import SQLTask, TestCase | |
| except ImportError: | |
| from tasks import SQLTask, TestCase | |
| JUDGE_MODEL = "claude-haiku-4-5-20251001" | |
# ---------------------------------------------------------------------------
# Stage 1 — Syntax check
# ---------------------------------------------------------------------------
| def _reject_multi_statement(query: str) -> Optional[str]: | |
| """Return an error message if the query contains multiple statements.""" | |
| # Strip string literals and comments before checking for semicolons | |
| stripped = re.sub(r"'[^']*'", "", query) # remove string literals | |
| stripped = re.sub(r"--[^\n]*", "", stripped) # remove line comments | |
| stripped = re.sub(r"/\*.*?\*/", "", stripped, flags=re.DOTALL) # block comments | |
| stripped = stripped.strip().rstrip(";") # allow a single trailing semicolon | |
| if ";" in stripped: | |
| return "Multi-statement queries are not allowed." | |
| return None | |
def check_syntax(query: str) -> Tuple[bool, Optional[str]]:
    """
    Return (is_valid, error_message) for *query*.

    Strategy: reject multi-statement input, then run ``EXPLAIN <query>``
    against an empty in-memory DuckDB and classify any exception by its text:

    - "Parser Error" in the message → genuine syntax error → invalid.  This is
      checked first so that a parser message which happens to mention e.g.
      "column" is never mistaken for a bind-time error.
    - Catalog/Binder-style messages → tables or columns are unknown because
      nothing is seeded, but the SQL parsed → valid.
    - Any other exception → treated as a syntax error to be safe.
    """
    multi_err = _reject_multi_statement(query)
    if multi_err:
        return False, multi_err

    conn = duckdb.connect(":memory:")
    try:
        conn.execute(f"EXPLAIN {query}")
    except Exception as exc:
        msg = str(exc)
        # A parser failure always wins over the looser substring tags below.
        if "Parser Error" in msg:
            return False, msg
        # Catalog/Binder errors mean the SQL parsed fine; tables just aren't seeded.
        bind_time_tags = (
            "Catalog Error",
            "Binder Error",
            "Table with name",
            "Referenced column",
            "does not exist",
            "column",
        )
        if any(tag in msg for tag in bind_time_tags):
            return True, None
        return False, msg
    else:
        return True, None
    finally:
        conn.close()
# ---------------------------------------------------------------------------
# Stage 2 — Execution
# ---------------------------------------------------------------------------
def execute_query(
    schema_ddl: str, query: str
) -> Tuple[bool, Optional[List[Dict[str, Any]]], Optional[str]]:
    """
    Run *query* in a throwaway in-memory DuckDB seeded with *schema_ddl*.

    Returns a triple ``(success, rows, error_message)``: on success ``rows``
    is a list of column→value dicts with numpy scalars converted to native
    Python values and ``error_message`` is ``None``; on any exception the
    triple is ``(False, None, str(exception))``.  The connection is always
    closed.
    """
    db = duckdb.connect(":memory:")
    try:
        db.execute(schema_ddl)
        frame = db.execute(query).fetchdf()
        # fetchdf() hands back numpy scalars; make every cell a plain Python value.
        records: List[Dict[str, Any]] = [
            {column: _native(cell) for column, cell in record.items()}
            for record in frame.to_dict(orient="records")
        ]
    except Exception as err:
        return False, None, str(err)
    else:
        return True, records, None
    finally:
        db.close()
| def _native(value: Any) -> Any: | |
| """Convert numpy scalars β native Python types for JSON-safe comparison.""" | |
| try: | |
| import numpy as np # duckdb fetchdf() returns numpy types | |
| if isinstance(value, (np.integer,)): | |
| return int(value) | |
| if isinstance(value, (np.floating,)): | |
| return float(value) | |
| if isinstance(value, np.bool_): | |
| return bool(value) | |
| except ImportError: | |
| pass | |
| return value | |
# ---------------------------------------------------------------------------
# Stage 3 — Row correctness
# ---------------------------------------------------------------------------
| def _normalize(row: Dict[str, Any]) -> Dict[str, Any]: | |
| """Round floats to 2 dp so 999.99000000001 == 999.99.""" | |
| return { | |
| k: (round(float(v), 2) if isinstance(v, float) else v) | |
| for k, v in row.items() | |
| } | |
| def _sort_key(row: Dict[str, Any], order_by: Optional[str]) -> tuple: | |
| if order_by: | |
| cols = [c.strip() for c in order_by.split(",")] | |
| return tuple(str(row.get(c, "")) for c in cols) | |
| return tuple(str(v) for v in row.values()) | |
def rows_match(
    actual: List[Dict[str, Any]],
    expected: List[Dict[str, Any]],
    order_by: Optional[str] = None,
) -> Tuple[float, str]:
    """
    Compare *actual* vs *expected* rows and return ``(score, feedback)``.

    Actual rows are first projected onto the expected column names
    (case-insensitively, extra columns ignored) and floats are rounded via
    ``_normalize``.

    Scoring:
        1.0        — exact match (or both sides empty)
        0.5–0.9    — row count matches, some rows differ
        0.10–0.45  — row count wrong; scaled by the share of expected rows
                     that were actually returned (each expected row can be
                     matched at most once, so duplicates earn nothing extra)
        0.8        — rows returned although an empty result was expected
        0.0        — empty result when rows were expected

    Both sides are sorted with ``_sort_key`` before the positional comparison,
    so the comparison does not depend on the order the query returned rows in.
    """
    if not expected:
        if not actual:
            return 1.0, "No expected rows — query accepted."
        return 0.8, f"Expected empty result but got {len(actual)} row(s)."
    if not actual:
        return 0.0, f"Query returned 0 rows; expected {len(expected)}."

    # Project actual rows to only the expected columns (agent may SELECT extra).
    # Case-insensitive matching: map lower(actual_col) -> actual_col.
    expected_cols = list(expected[0].keys())
    lower_map = {k.lower(): k for k in actual[0].keys()}

    def _project(row: Dict[str, Any]) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        for ec in expected_cols:
            actual_key = lower_map.get(ec.lower())
            if actual_key is not None:
                out[ec] = row[actual_key]
        return out

    actual_norm = [_normalize(_project(row)) for row in actual]
    expected_norm = [_normalize(row) for row in expected]

    if len(actual_norm) != len(expected_norm):
        # Multiset match: every expected row may be claimed by one returned
        # row only.  Without this, a query that repeats a correct row (e.g.
        # an accidental cross join) would push coverage above 1.0 and the
        # score past the intended 0.45 ceiling.  A list (not a set/Counter)
        # is used so unhashable cell values still compare correctly.
        unclaimed = [tuple(sorted(r.items())) for r in expected_norm]
        correct_rows = 0
        for row in actual_norm:
            key = tuple(sorted(row.items()))
            if key in unclaimed:
                unclaimed.remove(key)
                correct_rows += 1
        coverage = correct_rows / len(expected)
        # Base 0.10 for count mismatch, up to 0.45 for full coverage.
        score = 0.10 + 0.35 * coverage
        return score, (
            f"Row count mismatch: got {len(actual_norm)}, expected {len(expected)}. "
            f"{correct_rows}/{len(expected)} expected rows present."
        )

    actual_sorted = sorted(actual_norm, key=lambda r: _sort_key(r, order_by))
    expected_sorted = sorted(expected_norm, key=lambda r: _sort_key(r, order_by))
    matches = sum(1 for a, e in zip(actual_sorted, expected_sorted) if a == e)
    row_accuracy = matches / len(expected)
    if row_accuracy == 1.0:
        return 1.0, "All rows match perfectly."
    score = 0.5 + 0.4 * row_accuracy
    return score, f"{matches}/{len(expected)} rows match correctly."
# ---------------------------------------------------------------------------
# Stage 4 — Anthropic AI judge
# ---------------------------------------------------------------------------
def call_anthropic_judge(
    task: SQLTask,
    agent_query: str,
    execution_success: bool,
    execution_error: Optional[str],
    actual_rows: Optional[List[Dict[str, Any]]],
    deterministic_score: float,
) -> Tuple[float, str, str]:
    """
    Ask the model named by ``JUDGE_MODEL`` to score the query on three axes:
      - Correctness  (0–0.50)
      - Optimization (0–0.30) — avoids inefficiencies, uses best SQL patterns
      - Code quality (0–0.20) — readable, well-aliased, idiomatic SQL

    Returns ``(final_score, feedback, improvement_hint)`` where ``final_score``
    is the model's reported ``"score"`` clamped to [0.0, 1.0].

    Never raises for judge-side problems: if building the prompt, creating the
    client, calling the API, or parsing the reply fails, the result is
    ``(deterministic_score, "[AI Judge unavailable] …", task.hint)``.
    """
    try:
        # default=str keeps non-JSON values (timestamps, decimals, …) from
        # aborting the grade; they are only shown to the judge as text.
        sample_actual = json.dumps(
            actual_rows[:5] if actual_rows else [], indent=2, default=str
        )
        sample_expected = json.dumps(
            task.test_cases[0].expected_rows if task.test_cases else [],
            indent=2,
            default=str,
        )
        prompt = f"""\
You are a strict SQL expert judge scoring an agent's query for the task below.
## Task ({task.level})
{task.description}
## Agent Query
```sql
{agent_query}
```
## Execution
- Success: {execution_success}
- Error: {execution_error or "None"}
- Rows returned (first 5): {sample_actual}
- Expected rows: {sample_expected}
## Reference Solution
```sql
{task.solution_query}
```
## Deterministic row-match score (0.0–1.0): {deterministic_score:.3f}
Score the agent query on THREE axes and sum them for the final score:
| Axis | Max | Criteria |
|--------------|------|----------|
| Correctness | 0.50 | Produces the right rows for the stated goal |
| Optimization | 0.30 | Avoids cartesian products / correlated subqueries; uses efficient patterns (CTEs, explicit JOINs, proper GROUP BY) |
| Code quality | 0.20 | Readable aliases, clean formatting, no redundant clauses |
IMPORTANT rules:
- If execution failed with a runtime error, Correctness ≤ 0.10.
- If rows are fully correct per deterministic score ≥ 0.95, Correctness ≥ 0.40.
- For the medium task: a query that still uses comma-join syntax scores Optimization ≤ 0.05.
- For the hard task: a query without a CTE scores Optimization ≤ 0.10.
Respond with ONLY valid JSON (no markdown fences):
{{
"correctness": <float 0.0–0.50>,
"optimization": <float 0.0–0.30>,
"code_quality": <float 0.0–0.20>,
"score": <sum of above, float 0.0–1.0>,
"feedback": "<2–3 sentences summarising what the agent did right/wrong>",
"hint": "<one concrete actionable improvement, or 'Excellent!' if score >= 0.95>"
}}"""
        # Created inside the try so a client-construction failure also takes
        # the graceful fallback path instead of propagating to the caller.
        client = anthropic.Anthropic()
        message = client.messages.create(
            model=JUDGE_MODEL,
            max_tokens=512,
            messages=[
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": "{"},  # prefill forces JSON-only reply
            ],
        )
        # Prepend the prefilled "{" back before parsing.
        raw = "{" + message.content[0].text.strip()
        # Belt-and-suspenders: keep only the outermost {...} span.
        brace_start = raw.find("{")
        brace_end = raw.rfind("}") + 1
        if brace_start != -1 and brace_end > brace_start:
            raw = raw[brace_start:brace_end]
        data = json.loads(raw)
        score = float(data["score"])
        if score != score:
            # json.loads accepts NaN, and min(1.0, nan) would clamp it to 1.0.
            raise ValueError("judge returned a NaN score")
        score = max(0.0, min(1.0, score))
        feedback = str(data.get("feedback", ""))
        hint = str(data.get("hint", ""))
        return score, feedback, hint
    except Exception as exc:
        # Graceful fallback — no API key, network error, or parse failure.
        msg = str(exc).lower()
        if "api_key" in msg or "auth" in msg:
            reason = "ANTHROPIC_API_KEY not set — deterministic scoring only (max 0.80)"
        else:
            reason = f"AI judge call failed ({type(exc).__name__}) — fell back to deterministic score"
        return (
            deterministic_score,
            f"[AI Judge unavailable] {reason}.",
            task.hint,
        )
| # --------------------------------------------------------------------------- | |
| # Public entry point | |
| # --------------------------------------------------------------------------- | |
# Keyword detectors for the structural checks.  They run on the upper-cased
# query and use word boundaries / \s so that newlines or tabs after a keyword
# ("WITH\n", "JOIN\n", "PARTITION  BY") are recognised and substrings of other
# words ("CASCADE", "DESCRIPTION", "= 30") are not.
_CTE_RE = re.compile(r"\bWITH\b")
_JOIN_RE = re.compile(r"\bJOIN\b")
_RECURSIVE_RE = re.compile(r"\bRECURSIVE\b")
_MANAGER_ANCHOR_RE = re.compile(r"MANAGER_ID\s*=\s*3\b")
_ROW_NUMBER_RE = re.compile(r"\bROW_NUMBER\b")
_ASC_RE = re.compile(r"\bASC\b")
_DESC_RE = re.compile(r"\bDESC\b")
_PARTITION_BY_RE = re.compile(r"\bPARTITION\s+BY\b")


def _structural_penalty(task: SQLTask, agent_query: str) -> Tuple[float, str]:
    """Return (penalty, feedback_suffix) for task-specific structural rules.

    Only the first applicable rule group fires, in this order: hard-level
    (needs a CTE), medium-level (needs an explicit JOIN), then the three
    expert tasks selected by ``task.id``.  The checks are purely textual, so
    keywords inside comments or string literals also count.
    """
    sql = agent_query.upper()

    if task.level == "hard" and not _CTE_RE.search(sql):
        # hard task demands a CTE
        return 0.30, " (Penalty: no CTE detected — task requires WITH clause.)"

    if task.level == "medium" and not _JOIN_RE.search(sql):
        # medium task demands explicit JOINs
        return 0.20, " (Penalty: no explicit JOIN — task requires JOIN … ON syntax.)"

    penalty = 0.0
    note = ""

    if task.id == "task_expert_recursive":
        # Two bugs: anchor uses WHERE id=3 (includes VP Eng) + non-recursive CTE (misses deep levels)
        if not _RECURSIVE_RE.search(sql):
            penalty += 0.30
            note += " (Penalty: WITH RECURSIVE required — hardcoded levels won't scale.)"
        if not _MANAGER_ANCHOR_RE.search(sql):
            penalty += 0.15
            note += " (Penalty: anchor should select subordinates via manager_id, not the VP themselves.)"
        return min(penalty, 0.40), note

    if task.id == "task_expert_rank":
        # Two bugs: ROW_NUMBER (drops ties) + ASC ordering (picks lowest instead of highest)
        if _ROW_NUMBER_RE.search(sql):
            penalty += 0.20
            note += " (Penalty: ROW_NUMBER() drops tied rows — use RANK() or DENSE_RANK().)"
        if _ASC_RE.search(sql) and not _DESC_RE.search(sql):
            penalty += 0.15
            note += " (Penalty: ordering by revenue ASC picks lowest earners, not highest.)"
        return min(penalty, 0.35), note

    if task.id == "task_expert_window":
        # Three bugs: missing PARTITION BY on both windows + tied revenues need correct ranking
        partition_count = len(_PARTITION_BY_RE.findall(sql))
        if partition_count == 0:
            penalty += 0.20
            note += " (Penalty: missing PARTITION BY — both SUM and RANK must be partitioned per region.)"
        elif partition_count < 2:
            # need at least 2 (one per window function)
            penalty += 0.10
            note += " (Penalty: only one window function has PARTITION BY — both need it.)"
        return min(penalty, 0.30), note

    return 0.0, ""


def grade(
    task: SQLTask, agent_query: str
) -> Tuple[float, str, Dict[str, Any]]:
    """
    Full grading pipeline. Returns (score, feedback, details_dict).

    Partial progress scoring:
        0.001      — syntax error (unparseable) or multi-statement input
        0.15–0.30  — syntax valid, runtime error (AI score × 0.3, floor 0.15)
        0.30       — deterministic floor once the query executes
        0.30–0.80  — deterministic score from row matching minus any
                     structural penalty
        up to 0.999 — correct rows + AI quality assessment

    Scores from the row-matching path are clamped to [0.001, 0.999].
    ``details`` collects the intermediate results of every stage reached.

    Raises:
        IndexError: if the query executes but ``task.test_cases`` is empty.
    """
    details: Dict[str, Any] = {}

    # ── Stage 1: syntax ──────────────────────────────────────────────────────
    syntax_ok, syntax_error = check_syntax(agent_query)
    details["syntax_valid"] = syntax_ok
    details["syntax_error"] = syntax_error
    if not syntax_ok:
        return 0.001, f"Syntax error: {syntax_error}", details

    # ── Stage 2: execution ───────────────────────────────────────────────────
    exec_ok, rows, exec_error = execute_query(task.schema_ddl, agent_query)
    details["execution_success"] = exec_ok
    details["execution_error"] = exec_error
    details["rows_returned"] = len(rows) if rows else 0
    if not exec_ok:
        # Syntax valid but runtime error — call AI for nuanced feedback.
        ai_score, ai_feedback, ai_hint = call_anthropic_judge(
            task, agent_query, False, exec_error, None, 0.15
        )
        details["ai_score"] = ai_score
        details["ai_feedback"] = ai_feedback
        details["ai_hint"] = ai_hint
        final = max(0.15, ai_score * 0.3)  # cap at 0.3 when execution fails
        return final, f"Runtime error: {exec_error} | AI: {ai_feedback}", details

    # ── Stage 3: row correctness ─────────────────────────────────────────────
    test_case = task.test_cases[0]
    row_score, row_feedback = rows_match(rows, test_case.expected_rows, test_case.order_by)
    details["row_match_score"] = row_score
    details["row_match_feedback"] = row_feedback

    # ── Stage 3b: structural checks (task-specific) ──────────────────────────
    # These prevent high scores when the agent submits the broken query verbatim
    # or ignores the task's structural requirement.
    structural_penalty, penalty_note = _structural_penalty(task, agent_query)
    row_feedback += penalty_note
    details["structural_penalty"] = structural_penalty

    # Deterministic score: 0.30 base for executing + up to 0.50 for rows − penalty
    deterministic_score = max(0.30, 0.30 + 0.50 * row_score - structural_penalty)

    # ── Stage 4: AI quality ──────────────────────────────────────────────────
    ai_score, ai_feedback, ai_hint = call_anthropic_judge(
        task, agent_query, True, None, rows, deterministic_score
    )
    details["ai_score"] = ai_score
    details["ai_feedback"] = ai_feedback
    details["ai_hint"] = ai_hint

    # Final blending:
    #   AI judge offline (fallback) → use deterministic score directly
    #   rows fully correct          → trust AI score (can reach the 0.999 clamp)
    #   rows partially wrong        → clamp AI score to not exceed deterministic
    # The judge returns deterministic_score unchanged when it falls back, so a
    # near-equal score is taken as the fallback signal.
    ai_is_fallback = abs(ai_score - deterministic_score) < 0.001
    if ai_is_fallback:
        final_score = deterministic_score
    elif row_score >= 0.95:
        final_score = ai_score
    elif row_score >= 0.5:
        # Blend: AI provides nuance but can't exceed deterministic ceiling.
        final_score = min(deterministic_score, ai_score + 0.05)
    else:
        # Low row accuracy — stay near deterministic.
        final_score = min(deterministic_score, ai_score * 0.6)
    final_score = max(0.001, min(0.999, final_score))

    feedback = (
        f"[Rows] {row_feedback} "
        f"[AI Judge] {ai_feedback} "
        f"[Hint] {ai_hint}"
    )
    return final_score, feedback, details