Spaces:
Sleeping
Sleeping
| """ | |
| validate.py - Pre-submission validation script for SQL Repair Clinic. | |
| Runs the full checklist: | |
| - All task graders produce scores in [0.0, 1.0] | |
| - Graders are deterministic (same query -> same score) | |
| - Broken query scores < 0.5 (not trivially correct) | |
| - Reference (correct) query scores >= 0.95 | |
| - Environment reset() and step() and state() work | |
| - Reward varies (graders don't always return the same score) | |
| Usage: | |
| python validate.py | |
| """ | |
| from __future__ import annotations | |
| import sqlite3 | |
| import sys | |
| import textwrap | |
| from typing import Tuple | |
| # import local modules | |
| sys.path.insert(0, ".") | |
| from environment import SQLRepairEnv | |
| from models import SQLAction | |
| from tasks import TASK_REGISTRY | |
# Status labels printed in front of every checklist line.
PASS = "OK"
FAIL = "FAIL"

# Outcome of each check() call, in call order.
results = []


def check(name: str, condition: bool, detail: str = "") -> None:
    """Print one checklist line and record its outcome in ``results``.

    Args:
        name: Label printed after the status marker.
        condition: Outcome of the check; truthy prints PASS, falsy prints FAIL.
        detail: Optional text appended in parentheses when non-empty.
    """
    label = FAIL if not condition else PASS
    suffix = f" ({detail})" if detail else ""
    print(f" {label} {name}" + suffix)
    results.append(condition)
# Reference (correct) queries for each task, keyed by task name.
CORRECT_QUERIES = {
    "fix_syntax": (
        "SELECT name, salary FROM employees "
        "WHERE department = 'Engineering' ORDER BY salary DESC"
    ),
    # This entry used to be `textwrap.dedent(...) if False else (...)`; the
    # dedent branch could never be selected, so only the live value is kept.
    "fix_logic": (
        "SELECT c.name, COALESCE(SUM(o.amount), 0) AS total_spent "
        "FROM customers c "
        "LEFT JOIN orders o ON c.id = o.customer_id AND o.status = 'completed' "
        "GROUP BY c.id, c.name ORDER BY total_spent DESC"
    ),
    # Two CTEs: per product/region revenue for 2024-01-01..2024-03-31, then a
    # ROW_NUMBER ranking so only the top product per (region, category) is kept.
    "write_analytical": textwrap.dedent("""\
        WITH revenue AS (
        SELECT p.name AS product_name,
        p.category,
        s.region,
        SUM(s.quantity * p.unit_price) AS total_revenue,
        COUNT(*) AS num_transactions
        FROM sales s
        JOIN products p ON s.product_id = p.id
        WHERE s.sale_date >= '2024-01-01'
        AND s.sale_date <= '2024-03-31'
        GROUP BY p.id, p.category, s.region
        ),
        ranked AS (
        SELECT *,
        ROW_NUMBER() OVER (
        PARTITION BY region, category
        ORDER BY total_revenue DESC
        ) AS rn
        FROM revenue
        )
        SELECT region, category, product_name, total_revenue, num_transactions
        FROM ranked
        WHERE rn = 1
        ORDER BY region ASC, category ASC
    """),
}
def _make_conn(task_name: str) -> sqlite3.Connection:
    """Return a fresh in-memory SQLite database prepared for *task_name*.

    Each statement in the task's ``setup_sql`` entry of ``TASK_REGISTRY`` is
    executed in order and the result is committed. The connection is returned
    open; closing it is left to the caller.
    """
    setup_statements = TASK_REGISTRY[task_name]["setup_sql"]
    db = sqlite3.connect(":memory:")
    for statement in setup_statements:
        db.execute(statement)
    db.commit()
    return db
# Horizontal rule used by the banner here and by the final summary.
SEPARATOR = "=" * 68

# Banner: blank line, rule, title, rule, blank line.
print(
    "\n" + SEPARATOR,
    " SQL Repair Clinic - Pre-Submission Check",
    SEPARATOR + "\n",
    sep="\n",
)
# 1. Grader correctness: each task's own broken query must score below 0.5.
print("1. Broken query scores < 0.5 (grader rejects clearly wrong queries)")
for task_id, spec in TASK_REGISTRY.items():
    db = _make_conn(task_id)
    grade = spec["grader"]
    score, why = grade(db, spec["broken_query"])
    rejected = score < 0.5
    check(f" {task_id}: broken_query score={score:.2f}", rejected, why)
    db.close()
print()
# 2. Grader correctness: every reference query must score at least 0.95.
print("2. Correct query scores >= 0.95")
for task_id, reference_sql in CORRECT_QUERIES.items():
    db = _make_conn(task_id)
    grade = TASK_REGISTRY[task_id]["grader"]
    score, why = grade(db, reference_sql)
    accepted = score >= 0.95
    check(f" {task_id}: correct_query score={score:.2f}", accepted, why)
    db.close()
print()
# 3. Determinism: grading the same query three times on fresh databases
#    must yield a single distinct score.
print("3. Graders are deterministic")
for task_id, reference_sql in CORRECT_QUERIES.items():
    scores = []
    for _attempt in range(3):
        db = _make_conn(task_id)
        outcome = TASK_REGISTRY[task_id]["grader"](db, reference_sql)
        score, _reason = outcome
        scores.append(score)
        db.close()
    distinct_scores = set(scores)
    check(f" {task_id}: scores={scores}", len(distinct_scores) == 1)
print()
# 4. Reward range: every sampled query must be graded inside [0.0, 1.0].
print("4. All reward values in [0.0, 1.0]")

# Two probe queries per task; the variance check further down reads this too.
test_queries = {
    "fix_syntax": ["SELCT * FORM bad_tbl", "SELECT name FROM employees"],
    "fix_logic": ["SELECT * FROM customers", CORRECT_QUERIES["fix_logic"]],
    "write_analytical": ["SELECT 1", CORRECT_QUERIES["write_analytical"]],
}

violations = 0
for task_id, probes in test_queries.items():
    grade = TASK_REGISTRY[task_id]["grader"]
    for sql in probes:
        db = _make_conn(task_id)
        score, _reason = grade(db, sql)
        # Chained comparison on purpose: a NaN score counts as out of range.
        if not (0.0 <= score <= 1.0):
            check(f" {task_id}: score={score}", False, "OUT OF RANGE")
            violations += 1
        db.close()

# A single passing line is recorded only when nothing was out of range.
if violations == 0:
    check(" All scores in [0.0, 1.0]", True)
print()
# 5. Variance: the two probe queries of a task must be scored more than
#    0.01 apart, i.e. the grader is not a constant function.
print("5. Graders produce varying scores (not constant)")
for task_id, probes in test_queries.items():
    first_sql, second_sql = probes[0], probes[1]
    db_a = _make_conn(task_id)
    db_b = _make_conn(task_id)
    score_a, _reason_a = TASK_REGISTRY[task_id]["grader"](db_a, first_sql)
    score_b, _reason_b = TASK_REGISTRY[task_id]["grader"](db_b, second_sql)
    db_a.close()
    db_b.close()
    gap = abs(score_a - score_b)
    check(f" {task_id}: scores vary ({score_a:.2f} != {score_b:.2f})", gap > 0.01)
print()
# 6. Environment API: reset(), state() and step() are exercised per task.
print("6. Environment API: reset / step / state")
env = SQLRepairEnv()
for task_id in TASK_REGISTRY:
    first_obs = env.reset(task=task_id)
    check(f" {task_id}: reset() returns SQLObservation", first_obs is not None)

    snapshot = env.state()
    check(f" {task_id}: state() works", snapshot.task_name == task_id)

    # Submit the reference query; only the reward and info dict are inspected.
    action = SQLAction(query=CORRECT_QUERIES[task_id])
    _next_obs, reward, _done, info = env.step(action)
    check(
        f" {task_id}: step(correct_query) reward={reward:.2f}",
        reward >= 0.95,
        info.get("grader_reason", ""),
    )
print()
# Summary: tally recorded outcomes, print the verdict, and exit with
# status 0 when every check passed, 1 otherwise.
total = len(results)
passed = sum(results)
all_passed = passed == total

if all_passed:
    verdict = " All checks passed - ready to submit!"
else:
    verdict = " Some checks failed - fix before submitting."

print(SEPARATOR)
print(f" RESULT: {passed}/{total} checks passed")
print(verdict)
print(SEPARATOR + "\n")

exit_code = 0 if all_passed else 1
sys.exit(exit_code)