""" validate.py - Pre-submission validation script for SQL Repair Clinic. Runs the full checklist: - All task graders produce scores in [0.0, 1.0] - Graders are deterministic (same query -> same score) - Broken query scores < 0.5 (not trivially correct) - Reference (correct) query scores >= 0.95 - Environment reset() and step() and state() work - Reward varies (graders don't always return the same score) Usage: python validate.py """ from __future__ import annotations import sqlite3 import sys import textwrap from typing import Tuple # import local modules sys.path.insert(0, ".") from environment import SQLRepairEnv from models import SQLAction from tasks import TASK_REGISTRY PASS = "OK" FAIL = "FAIL" results = [] def check(name: str, condition: bool, detail: str = "") -> None: status = PASS if condition else FAIL msg = f" {status} {name}" if detail: msg += f" ({detail})" print(msg) results.append(condition) # Reference (correct) queries for each task CORRECT_QUERIES = { "fix_syntax": ( "SELECT name, salary FROM employees " "WHERE department = 'Engineering' ORDER BY salary DESC" ), "fix_logic": textwrap.dedent("""\ SELECT c.name, COALESCE(SUM(o.amount), 0) AS total_spent FROM customers c LEFT JOIN orders o ON c.id = o.customer_id AND o.status = 'completed' GROUP BY c.id, c.name ORDER BY total_spent DESC """) if False else ( "SELECT c.name, COALESCE(SUM(o.amount), 0) AS total_spent " "FROM customers c " "LEFT JOIN orders o ON c.id = o.customer_id AND o.status = 'completed' " "GROUP BY c.id, c.name ORDER BY total_spent DESC" ), "write_analytical": textwrap.dedent("""\ WITH revenue AS ( SELECT p.name AS product_name, p.category, s.region, SUM(s.quantity * p.unit_price) AS total_revenue, COUNT(*) AS num_transactions FROM sales s JOIN products p ON s.product_id = p.id WHERE s.sale_date >= '2024-01-01' AND s.sale_date <= '2024-03-31' GROUP BY p.id, p.category, s.region ), ranked AS ( SELECT *, ROW_NUMBER() OVER ( PARTITION BY region, category ORDER BY total_revenue DESC ) AS rn FROM revenue ) SELECT region, category, product_name, total_revenue, num_transactions FROM ranked WHERE rn = 1 ORDER BY region ASC, category ASC """), } def _make_conn(task_name: str) -> sqlite3.Connection: task = TASK_REGISTRY[task_name] conn = sqlite3.connect(":memory:") for stmt in task["setup_sql"]: conn.execute(stmt) conn.commit() return conn SEPARATOR = "=" * 68 print(f"\n{SEPARATOR}") print(" SQL Repair Clinic - Pre-Submission Check") print(f"{SEPARATOR}\n") # 1. Grader correctness: broken query must score < 0.5 print("1. Broken query scores < 0.5 (grader rejects clearly wrong queries)") for tname, tdef in TASK_REGISTRY.items(): conn = _make_conn(tname) r, reason = tdef["grader"](conn, tdef["broken_query"]) check(f" {tname}: broken_query score={r:.2f}", r < 0.5, reason) conn.close() print() # 2. Grader correctness: correct query must score >= 0.95 print("2. Correct query scores >= 0.95") for tname, correct_sql in CORRECT_QUERIES.items(): conn = _make_conn(tname) r, reason = TASK_REGISTRY[tname]["grader"](conn, correct_sql) check(f" {tname}: correct_query score={r:.2f}", r >= 0.95, reason) conn.close() print() # 3. Determinism: same query returns same score print("3. Graders are deterministic") for tname, correct_sql in CORRECT_QUERIES.items(): scores = [] for _ in range(3): conn = _make_conn(tname) r, _ = TASK_REGISTRY[tname]["grader"](conn, correct_sql) scores.append(r) conn.close() all_equal = len(set(scores)) == 1 check(f" {tname}: scores={scores}", all_equal) print() # 4. Reward range [0.0, 1.0] print("4. All reward values in [0.0, 1.0]") test_queries = { "fix_syntax": ["SELCT * FORM bad_tbl", "SELECT name FROM employees"], "fix_logic": ["SELECT * FROM customers", CORRECT_QUERIES["fix_logic"]], "write_analytical": ["SELECT 1", CORRECT_QUERIES["write_analytical"]], } all_in_range = True for tname, queries in test_queries.items(): for q in queries: conn = _make_conn(tname) r, _ = TASK_REGISTRY[tname]["grader"](conn, q) in_range = 0.0 <= r <= 1.0 if not in_range: check(f" {tname}: score={r}", False, "OUT OF RANGE") all_in_range = False conn.close() if all_in_range: check(" All scores in [0.0, 1.0]", True) print() # 5. Graders produce varying scores print("5. Graders produce varying scores (not constant)") for tname, queries in test_queries.items(): conn1 = _make_conn(tname) conn2 = _make_conn(tname) r1, _ = TASK_REGISTRY[tname]["grader"](conn1, queries[0]) r2, _ = TASK_REGISTRY[tname]["grader"](conn2, queries[1]) conn1.close() conn2.close() check(f" {tname}: scores vary ({r1:.2f} != {r2:.2f})", abs(r1 - r2) > 0.01) print() # 6. Environment API works print("6. Environment API: reset / step / state") env = SQLRepairEnv() for tname in TASK_REGISTRY.keys(): obs = env.reset(task=tname) check(f" {tname}: reset() returns SQLObservation", obs is not None) state = env.state() check(f" {tname}: state() works", state.task_name == tname) obs2, reward, done, info = env.step(SQLAction(query=CORRECT_QUERIES[tname])) check( f" {tname}: step(correct_query) reward={reward:.2f}", reward >= 0.95, info.get("grader_reason", ""), ) print() # Summary passed = sum(results) total = len(results) print(SEPARATOR) print(f" RESULT: {passed}/{total} checks passed") if passed == total: print(" All checks passed - ready to submit!") else: print(" Some checks failed - fix before submitting.") print(f"{SEPARATOR}\n") sys.exit(0 if passed == total else 1)