NEWHACKTHONSPACE / validate.py
AKGW580's picture
Update inference and validation
5009d52
"""
validate.py - Pre-submission validation script for SQL Repair Clinic.
Runs the full pre-submission checklist:
- All task graders produce scores in [0.0, 1.0]
- Graders are deterministic (same query -> same score)
- Each task's broken query scores < 0.5 (the grader does not accept it)
- Each reference (correct) query scores >= 0.95
- Environment reset(), step() and state() work
- Reward varies (graders don't always return the same score)
Usage:
python validate.py
"""
from __future__ import annotations
import sqlite3
import sys
import textwrap
from typing import Tuple
# import local modules
sys.path.insert(0, ".")
from environment import SQLRepairEnv
from models import SQLAction
from tasks import TASK_REGISTRY
# Status markers printed in front of every checklist line.
PASS = "OK"
FAIL = "FAIL"

# One entry per check() call, in call order; the summary counts the truthy ones.
results = []


def check(name: str, condition: bool, detail: str = "") -> None:
    """Print a single checklist line and remember whether it passed.

    Args:
        name: Label printed after the status marker.
        condition: Truthy when the check passed; stored as-is in ``results``.
        detail: Optional context, appended in parentheses when non-empty.
    """
    pieces = [f" {PASS if condition else FAIL} {name}"]
    if detail:
        pieces.append(f"({detail})")
    print(" ".join(pieces))
    results.append(condition)
# Reference (correct) solution for each task. Section 2 expects each of these to
# score >= 0.95 with its task's grader, section 3 re-grades them for
# determinism, and section 6 submits them through SQLRepairEnv.step().
CORRECT_QUERIES = {
    "fix_syntax": (
        "SELECT name, salary FROM employees "
        "WHERE department = 'Engineering' ORDER BY salary DESC"
    ),
    # The status filter sits in the JOIN condition (not in WHERE) so that
    # customers without any completed order still survive the LEFT JOIN; their
    # SUM is NULL and COALESCE turns it into 0.
    "fix_logic": (
        "SELECT c.name, COALESCE(SUM(o.amount), 0) AS total_spent "
        "FROM customers c "
        "LEFT JOIN orders o ON c.id = o.customer_id AND o.status = 'completed' "
        "GROUP BY c.id, c.name ORDER BY total_spent DESC"
    ),
    # Revenue per product and region for sales dated 2024-01-01 .. 2024-03-31,
    # keeping only the top-revenue product per (region, category) via
    # ROW_NUMBER(). The date bounds are plain string comparisons, so this
    # presumably relies on sale_date being stored as bare 'YYYY-MM-DD' text.
    "write_analytical": textwrap.dedent("""\
        WITH revenue AS (
        SELECT p.name AS product_name,
        p.category,
        s.region,
        SUM(s.quantity * p.unit_price) AS total_revenue,
        COUNT(*) AS num_transactions
        FROM sales s
        JOIN products p ON s.product_id = p.id
        WHERE s.sale_date >= '2024-01-01'
        AND s.sale_date <= '2024-03-31'
        GROUP BY p.id, p.category, s.region
        ),
        ranked AS (
        SELECT *,
        ROW_NUMBER() OVER (
        PARTITION BY region, category
        ORDER BY total_revenue DESC
        ) AS rn
        FROM revenue
        )
        SELECT region, category, product_name, total_revenue, num_transactions
        FROM ranked
        WHERE rn = 1
        ORDER BY region ASC, category ASC
        """),
}
def _make_conn(task_name: str) -> sqlite3.Connection:
    """Return a fresh in-memory SQLite connection seeded for *task_name*.

    Every statement in the task's ``setup_sql`` list is executed in order and
    then committed, so each grader call starts from an identical database.

    Args:
        task_name: Key into ``TASK_REGISTRY``.

    Returns:
        An open ``sqlite3.Connection``; the caller is responsible for closing it.

    Raises:
        KeyError: if *task_name* is not registered.
        sqlite3.Error: if a setup statement fails. The connection is closed
            before the error propagates, so it is not leaked.
    """
    task = TASK_REGISTRY[task_name]
    conn = sqlite3.connect(":memory:")
    try:
        for stmt in task["setup_sql"]:
            conn.execute(stmt)
        conn.commit()
    except BaseException:
        # Don't hand a half-initialised connection to anyone, and don't leak it.
        conn.close()
        raise
    return conn
# Horizontal rule used to frame the report header and the final summary.
SEPARATOR = "=" * 68

# Header banner: blank line, rule, title, rule, blank line.
_banner_title = " SQL Repair Clinic - Pre-Submission Check"
print("\n" + SEPARATOR)
print(_banner_title)
print(SEPARATOR + "\n")
# 1. Grader correctness: each task's shipped broken query must score < 0.5,
#    i.e. the grader is not lenient enough to accept it.
print("1. Broken query scores < 0.5 (grader rejects clearly wrong queries)")
for task_name, task_def in TASK_REGISTRY.items():
    db = _make_conn(task_name)
    score, why = task_def["grader"](db, task_def["broken_query"])
    check(f" {task_name}: broken_query score={score:.2f}", score < 0.5, why)
    db.close()
print()
# 2. Grader correctness: every reference query in CORRECT_QUERIES must score
#    >= 0.95 against a freshly seeded database for its task.
print("2. Correct query scores >= 0.95")
for task_name in CORRECT_QUERIES:
    reference_sql = CORRECT_QUERIES[task_name]
    db = _make_conn(task_name)
    score, why = TASK_REGISTRY[task_name]["grader"](db, reference_sql)
    check(f" {task_name}: correct_query score={score:.2f}", score >= 0.95, why)
    db.close()
print()
# 3. Determinism: grading the same reference query three times, each on a new
#    database, must produce one distinct score.
print("3. Graders are deterministic")
for task_name, reference_sql in CORRECT_QUERIES.items():
    grader = TASK_REGISTRY[task_name]["grader"]
    observed = []
    for _attempt in range(3):
        db = _make_conn(task_name)
        observed.append(grader(db, reference_sql)[0])
        db.close()
    check(f" {task_name}: scores={observed}", len(set(observed)) == 1)
print()
# 4. Reward range: every probe query must score within [0.0, 1.0]. Each
#    out-of-range score is reported individually; a single passing line is
#    recorded only when nothing was out of range.
print("4. All reward values in [0.0, 1.0]")
# Per task: [a clearly wrong query, a better/correct one]. Section 5 reuses
# these pairs to confirm the grader distinguishes them.
test_queries = {
    "fix_syntax": ["SELCT * FORM bad_tbl", "SELECT name FROM employees"],
    "fix_logic": ["SELECT * FROM customers", CORRECT_QUERIES["fix_logic"]],
    "write_analytical": ["SELECT 1", CORRECT_QUERIES["write_analytical"]],
}
all_in_range = True
for task_name, probe_list in test_queries.items():
    grader = TASK_REGISTRY[task_name]["grader"]
    for probe in probe_list:
        db = _make_conn(task_name)
        score, _ = grader(db, probe)
        if not (0.0 <= score <= 1.0):
            all_in_range = False
            check(f" {task_name}: score={score}", False, "OUT OF RANGE")
        db.close()
if all_in_range:
    check(" All scores in [0.0, 1.0]", True)
print()
# 5. Non-constant graders: the first two probe queries of each task (from
#    test_queries) must score more than 0.01 apart.
print("5. Graders produce varying scores (not constant)")
for task_name, probes in test_queries.items():
    grader = TASK_REGISTRY[task_name]["grader"]
    # Two independent databases, graded in order: probes[0] then probes[1].
    dbs = [_make_conn(task_name), _make_conn(task_name)]
    first, second = (grader(db, probe)[0] for db, probe in zip(dbs, probes))
    for db in dbs:
        db.close()
    check(
        f" {task_name}: scores vary ({first:.2f} != {second:.2f})",
        abs(first - second) > 0.01,
    )
print()
# 6. Environment API: for every registered task, reset() must return an
#    observation, state() must report the selected task, and stepping with the
#    task's reference query must earn a reward >= 0.95.
print("6. Environment API: reset / step / state")
env = SQLRepairEnv()
for tname in TASK_REGISTRY:
    obs = env.reset(task=tname)
    # NOTE(review): only checks for a non-None observation, not its type.
    check(f" {tname}: reset() returns SQLObservation", obs is not None)
    state = env.state()
    check(f" {tname}: state() works", state.task_name == tname)

    reference_sql = CORRECT_QUERIES.get(tname)
    if reference_sql is None:
        # A registered task without a reference query cannot be exercised
        # end-to-end. Record a failure instead of aborting the whole run with
        # a KeyError, which would skip the summary and the exit code.
        check(
            f" {tname}: step(correct_query)",
            False,
            "no reference query in CORRECT_QUERIES",
        )
        continue

    # step() is unpacked as (observation, reward, done, info).
    _obs2, reward, _done, info = env.step(SQLAction(query=reference_sql))
    check(
        f" {tname}: step(correct_query) reward={reward:.2f}",
        reward >= 0.95,
        info.get("grader_reason", ""),
    )
print()
# Summary: report how many recorded checks passed, then exit 0 only when all
# of them did, so the script can gate a CI step.
passed = sum(results)
total = len(results)
all_passed = passed == total

print(SEPARATOR)
print(f" RESULT: {passed}/{total} checks passed")
verdict = (
    " All checks passed - ready to submit!"
    if all_passed
    else " Some checks failed - fix before submitting."
)
print(verdict)
print(f"{SEPARATOR}\n")
sys.exit(0 if all_passed else 1)