"""
nl2sql-bench/tests/test_all.py
================================
Comprehensive test suite covering:
- Database seeder (determinism + row counts)
- Grader (all reward components, step penalty, edge cases)
- Task registry (all 3 tasks load and produce valid examples)
- Environment (reset, step, episode boundary, done logic)
- Inference log format (regex checks on START / STEP / END)
Run from the project root with:
    pytest tests/ -v
or, setting the import path explicitly:
    PYTHONPATH=.:server pytest tests/ -v
"""
from __future__ import annotations
import re
import sqlite3
import sys
import os
from pathlib import Path
import pytest
# ── Path setup: make both the project root and server/ importable ─────────────
# The two directories are spliced onto the front of sys.path in one step,
# SERVER first, so modules under server/ take precedence over same-named
# modules at the project root.
ROOT = Path(__file__).parent.parent
SERVER = ROOT.joinpath("server")
sys.path[0:0] = [str(SERVER), str(ROOT)]
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest.fixture(scope="session")
def db_conn():
    """Shared in-memory SQLite connection with full schema + seed data.

    Session-scoped: every test that requests ``db_conn`` receives the same
    connection, so tests must treat it as read-only. The connection is closed
    at session teardown, and also if schema creation or seeding fails before
    the fixture yields.
    """
    from db.seed import seed_database

    # Read explicitly as UTF-8 so the result does not depend on the locale.
    schema = (SERVER / "db" / "schema.sql").read_text(encoding="utf-8")
    # check_same_thread=False allows use from threads other than the creator.
    conn = sqlite3.connect(":memory:", check_same_thread=False)
    try:
        conn.row_factory = sqlite3.Row  # rows support access by column name
        conn.executescript(schema)
        seed_database(conn)
        yield conn
    finally:
        conn.close()
@pytest.fixture
def fresh_env():
    """Provide a newly constructed NL2SQLEnvironment to each requesting test.

    The fixture is function-scoped (pytest's default), so every test gets its
    own instance rather than one reused from an earlier test.
    """
    import environment

    return environment.NL2SQLEnvironment()
# ──────────────────────────────────────────────────────────────────────────────
# 1. DATABASE SEEDER
# ──────────────────────────────────────────────────────────────────────────────
class TestSeeder:
    """Row counts, value domains and referential integrity of the seeded DB.

    Every test here only runs SELECT queries against ``db_conn``.
    """

    def test_categories_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM categories").fetchone()
        assert row[0] == 8, "Should have exactly 8 categories"

    def test_products_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM products").fetchone()
        assert row[0] == 64, "Should have 8 products × 8 categories = 64"

    def test_customers_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM customers").fetchone()
        assert row[0] == 150

    def test_orders_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM orders").fetchone()
        assert row[0] > 100, "Should have a meaningful number of orders"

    def test_order_items_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM order_items").fetchone()
        assert row[0] > 200

    def test_reviews_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM reviews").fetchone()
        assert row[0] > 50

    def test_determinism(self, db_conn):
        """Seeding a second connection with the same seed gives identical counts."""
        from db.seed import seed_database

        schema = (SERVER / "db" / "schema.sql").read_text(encoding="utf-8")
        conn2 = sqlite3.connect(":memory:")
        try:
            conn2.executescript(schema)
            seed_database(conn2)
            for tbl in ["categories", "products", "customers", "orders",
                        "order_items", "reviews"]:
                # Table names come from the fixed list above, never from input.
                c1 = db_conn.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0]
                c2 = conn2.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0]
                assert c1 == c2, f"Table {tbl} count mismatch: {c1} vs {c2}"
        finally:
            # Close even when an assertion above fails, so the second
            # in-memory database is not left open for the rest of the session.
            conn2.close()

    def test_tiers_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM customers WHERE tier NOT IN ('bronze','silver','gold')"
        ).fetchone()[0]
        assert bad == 0

    def test_statuses_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM orders "
            "WHERE status NOT IN ('pending','processing','shipped','delivered','cancelled')"
        ).fetchone()[0]
        assert bad == 0

    def test_ratings_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM reviews WHERE rating < 1 OR rating > 5"
        ).fetchone()[0]
        assert bad == 0

    def test_referential_integrity(self, db_conn):
        """Order items should reference valid orders and products."""
        # LEFT JOIN + IS NULL finds order_items rows with no matching parent.
        orphan_orders = db_conn.execute(
            "SELECT COUNT(*) FROM order_items oi "
            "LEFT JOIN orders o ON o.id = oi.order_id WHERE o.id IS NULL"
        ).fetchone()[0]
        assert orphan_orders == 0
        orphan_products = db_conn.execute(
            "SELECT COUNT(*) FROM order_items oi "
            "LEFT JOIN products p ON p.id = oi.product_id WHERE p.id IS NULL"
        ).fetchone()[0]
        assert orphan_products == 0
# ──────────────────────────────────────────────────────────────────────────────
# 2. GRADER
# ──────────────────────────────────────────────────────────────────────────────
class TestGrader:
    """Reward components, step penalty, row-order handling and query execution."""

    def test_exact_match_first_step(self):
        from grader import grade
        gt = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
        result = grade(
            actual_rows=gt.copy(),
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=False,
        )
        assert result.reward == pytest.approx(1.0)
        assert result.exact_match is True
        assert result.syntax_ok is True
        assert result.columns_match is True
        assert result.row_count_match is True
        assert result.step_penalty == 0.0

    def test_syntax_error_gives_zero(self):
        from grader import grade
        result = grade(
            actual_rows=None,
            ground_truth_rows=[{"x": 1}],
            error="near 'SELCT': syntax error",
            step=1,
        )
        assert result.reward == 0.0
        assert result.syntax_ok is False

    def test_step_penalty_applied(self):
        from grader import grade
        gt = [{"n": 1}]
        result = grade(
            actual_rows=gt.copy(),
            ground_truth_rows=gt,
            error=None,
            step=3,  # penalty = (3-1)*0.05 = 0.10
        )
        assert result.reward == pytest.approx(1.0 - 0.10)
        assert result.step_penalty == pytest.approx(0.10)

    def test_columns_wrong_zero_higher_components(self):
        from grader import grade
        gt = [{"name": "Alice", "score": 10}]
        actual = [{"user": "Alice", "points": 10}]  # wrong column names
        result = grade(actual_rows=actual, ground_truth_rows=gt, error=None, step=1)
        assert result.columns_match is False
        assert result.exact_match is False
        # Only syntax score: 0.10
        assert result.reward == pytest.approx(0.10)

    def test_correct_columns_wrong_rows(self):
        from grader import grade
        gt = [{"name": "Alice"}, {"name": "Bob"}]
        actual = [{"name": "Charlie"}, {"name": "Dave"}]
        result = grade(actual_rows=actual, ground_truth_rows=gt, error=None, step=1)
        assert result.columns_match is True
        assert result.row_count_match is True
        assert result.exact_match is False
        # syntax(0.10) + columns(0.20) + row_count(0.20) = 0.50
        assert result.reward == pytest.approx(0.50)

    def test_order_sensitive_wrong_order_is_not_exact(self):
        from grader import grade
        gt = [{"id": 1}, {"id": 2}]
        actual = [{"id": 2}, {"id": 1}]  # reversed
        result = grade(
            actual_rows=actual,
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=True,
        )
        assert result.exact_match is False

    def test_order_insensitive_accepts_different_row_order(self):
        from grader import grade
        gt = [{"id": 1}, {"id": 2}]
        actual = [{"id": 2}, {"id": 1}]  # different order but same content
        result = grade(
            actual_rows=actual,
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=False,
        )
        assert result.exact_match is True

    def test_penalty_never_makes_reward_negative(self):
        from grader import grade
        # Step 99 with a syntax error — reward must be >= 0.
        result = grade(
            actual_rows=None,
            ground_truth_rows=[{"x": 1}],
            error="some error",
            step=99,
        )
        assert result.reward >= 0.0
        # A perfect answer this late would go negative if the step penalty
        # kept growing as in test_step_penalty_applied and were subtracted
        # unclamped. This case exercises the clamp itself rather than the
        # already-zero error path above.
        gt = [{"x": 1}]
        result = grade(
            actual_rows=gt.copy(),
            ground_truth_rows=gt,
            error=None,
            step=99,
        )
        assert result.reward >= 0.0

    def test_execute_query_blocks_writes(self, db_conn):
        from grader import execute_query
        count_sql = "SELECT COUNT(*) FROM categories"
        before = db_conn.execute(count_sql).fetchone()[0]
        rows, err = execute_query(db_conn, "INSERT INTO categories(name) VALUES ('x')")
        assert rows is None
        # Check for None first so a missing error fails as an assertion
        # instead of an AttributeError on err.lower().
        assert err is not None
        assert "not allowed" in err.lower() or "INSERT" in err
        # The write must not have reached the shared session database.
        after = db_conn.execute(count_sql).fetchone()[0]
        assert after == before

    def test_execute_query_returns_rows(self, db_conn):
        from grader import execute_query
        rows, err = execute_query(db_conn, "SELECT id, name FROM categories ORDER BY id")
        assert err is None
        assert len(rows) == 8
        assert "id" in rows[0]
        assert "name" in rows[0]

    def test_compute_ground_truth(self, db_conn):
        from grader import compute_ground_truth
        rows = compute_ground_truth(db_conn, "SELECT COUNT(*) AS n FROM customers")
        assert len(rows) == 1
        assert rows[0]["n"] == 150
# ──────────────────────────────────────────────────────────────────────────────
# 3. TASK REGISTRY
# ──────────────────────────────────────────────────────────────────────────────
class TestTasks:
    """Task registration, example coverage, ground-truth SQL and round-robin."""

    # Every task the registry is expected to expose; shared by the
    # parametrized tests below so the list is maintained in one place.
    TASK_NAMES = ("simple-filter", "join-aggregation", "analytics-window")

    def test_all_tasks_registered(self):
        from tasks import all_task_names
        names = all_task_names()
        for expected in self.TASK_NAMES:
            assert expected in names, f"Task {expected!r} is not registered"

    @pytest.mark.parametrize("task_name", TASK_NAMES)
    def test_task_has_examples(self, task_name):
        from tasks import get_task
        task = get_task(task_name)
        assert len(task.examples) >= 3, f"{task_name} needs at least 3 examples"

    @pytest.mark.parametrize("task_name", TASK_NAMES)
    def test_task_sql_runs_on_real_db(self, task_name, db_conn):
        """Every ground-truth SQL must execute cleanly against the seeded DB."""
        from tasks import get_task
        from grader import execute_query
        task = get_task(task_name)
        for ex in task.examples:
            rows, error = execute_query(db_conn, ex.sql)
            assert error is None, (
                f"Task {task_name!r} SQL failed:\n{ex.sql}\nError: {error}"
            )
            assert rows is not None

    @pytest.mark.parametrize("task_name", TASK_NAMES)
    def test_task_roundrobin(self, task_name):
        from tasks import get_task
        task = get_task(task_name)
        n = len(task.examples)
        seen = [task.next_example() for _ in range(n * 2)]
        # After n calls, the second half should repeat the first half...
        assert seen[:n] == seen[n:]
        # ...and one full cycle must visit every example exactly once. The
        # check above alone would also pass if next_example() never advanced.
        assert sorted(ex.sql for ex in seen[:n]) == sorted(
            ex.sql for ex in task.examples
        )

    def test_schema_context_non_empty(self):
        from tasks import get_task
        task = get_task("simple-filter")
        ctx = task.schema_context()
        assert "customers" in ctx
        assert "orders" in ctx
        assert "products" in ctx
# ──────────────────────────────────────────────────────────────────────────────
# 4. ENVIRONMENT
# ──────────────────────────────────────────────────────────────────────────────
class TestEnvironment:
    """reset/step behaviour, episode boundaries and reward bounds of the env."""

    def test_reset_returns_observation(self, fresh_env):
        obs = fresh_env.reset(task_name="simple-filter")
        assert obs.question != ""
        assert obs.schema_context != ""
        assert obs.task_name == "simple-filter"
        assert obs.done is False
        assert obs.step == 0
        assert obs.reward is None

    def test_reset_state(self, fresh_env):
        fresh_env.reset(task_name="join-aggregation")
        state = fresh_env.state
        assert state.task_name == "join-aggregation"
        assert state.task_difficulty == "medium"
        assert state.step_count == 0
        assert state.solved is False

    def test_step_increments_step_count(self, fresh_env):
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        fresh_env.step(NL2SQLAction(query="SELECT 1"))
        assert fresh_env.state.step_count == 1

    def test_step_syntax_error_gives_nonzero_error(self, fresh_env):
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        obs = fresh_env.step(NL2SQLAction(query="SELCT * FORM broken"))
        assert obs.last_error is not None
        assert obs.reward == 0.0

    def test_step_valid_query_returns_result(self, fresh_env):
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        obs = fresh_env.step(NL2SQLAction(
            query="SELECT id, name FROM customers ORDER BY name LIMIT 5"
        ))
        assert obs.last_error is None
        assert len(obs.last_result) <= 5
        assert obs.reward >= 0.0

    def test_exact_match_ends_episode(self, fresh_env):
        """Submitting the exact ground-truth SQL should solve the episode."""
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        # Get the ground truth SQL from the internal example
        gt_sql = fresh_env._example.sql
        obs = fresh_env.step(NL2SQLAction(query=gt_sql))
        assert obs.done is True
        assert fresh_env.state.solved is True
        assert obs.reward == pytest.approx(1.0)  # step 1, full score

    def test_max_steps_ends_episode(self, fresh_env):
        """Exhausting all steps should end the episode even without solving."""
        from models import NL2SQLAction
        from environment import MAX_STEPS
        fresh_env.reset(task_name="analytics-window")
        obs = None
        for step in range(1, MAX_STEPS + 1):
            obs = fresh_env.step(NL2SQLAction(query="SELECT 1"))
            if step < MAX_STEPS:
                # An unsolved episode must stay open until the budget is used up.
                assert obs.done is False, f"Episode ended early at step {step}"
        assert obs is not None
        assert obs.done is True
        assert fresh_env.state.solved is False

    def test_reset_clears_previous_episode(self, fresh_env):
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        fresh_env.step(NL2SQLAction(query="SELECT 1"))
        # Second reset should clear state
        obs = fresh_env.reset(task_name="join-aggregation")
        assert fresh_env.state.step_count == 0
        assert obs.step == 0
        assert obs.task_name == "join-aggregation"

    @pytest.mark.parametrize("task_name", [
        "simple-filter", "join-aggregation", "analytics-window"
    ])
    def test_all_tasks_solvable(self, task_name):
        """Ground-truth SQL should always produce reward == 1.0 on step 1.

        The env is reset once per example in the task (at least once). This
        aims to exercise every example rather than only the current one.
        It assumes reset() draws the next example in turn — TODO confirm.
        """
        from environment import NL2SQLEnvironment
        from models import NL2SQLAction
        from tasks import get_task
        env = NL2SQLEnvironment()
        for _ in range(max(1, len(get_task(task_name).examples))):
            env.reset(task_name=task_name)
            gt_sql = env._example.sql
            obs = env.step(NL2SQLAction(query=gt_sql))
            assert obs.done is True
            assert obs.reward == pytest.approx(1.0), (
                f"Task {task_name!r}: ground-truth SQL did not score 1.0.\n"
                f"SQL: {gt_sql}\nError: {obs.last_error}\nReward: {obs.reward}"
            )

    def test_score_normalised_to_0_1(self, fresh_env):
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        for step in range(1, 4):
            obs = fresh_env.step(NL2SQLAction(query="SELECT 1 AS x"))
            # Check every step, not just the final observation.
            assert 0.0 <= obs.score <= 1.0, (
                f"score {obs.score} out of [0, 1] at step {step}"
            )

    def test_write_query_blocked(self, fresh_env):
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        obs = fresh_env.step(NL2SQLAction(
            query="INSERT INTO categories(name) VALUES ('hack')"
        ))
        assert obs.last_error is not None
        assert "not allowed" in obs.last_error.lower() or "INSERT" in obs.last_error
# ──────────────────────────────────────────────────────────────────────────────
# 5. LOG FORMAT COMPLIANCE
# ──────────────────────────────────────────────────────────────────────────────
class TestLogFormat:
    """Validate that the inference.py log helpers emit correct format."""

    START_RE = re.compile(
        r"^\[START\] task=\S+ env=\S+ model=\S+$"
    )
    STEP_RE = re.compile(
        r"^\[STEP\] step=\d+ action=.+ reward=\d+\.\d{2} "
        r"done=(true|false) error=.+$"
    )
    END_RE = re.compile(
        r"^\[END\] success=(true|false) steps=\d+ score=\d+\.\d{3} "
        r"rewards=[\d.,]+$"
    )

    @pytest.fixture(autouse=True)
    def _root_first_on_sys_path(self, monkeypatch):
        """Give the project root import priority while each test here runs.

        Module setup puts SERVER ahead of ROOT on sys.path. Prepending ROOT
        makes ``inference`` resolve from the project root. Unlike a bare
        ``sys.path.insert``, monkeypatch restores sys.path on teardown, so no
        duplicate entries pile up and import order for later tests is
        unaffected.
        """
        monkeypatch.syspath_prepend(str(ROOT))

    def _capture(self, func, *args, **kwargs) -> str:
        """Call ``func`` and return everything it printed, stripped."""
        import io
        from contextlib import redirect_stdout
        buf = io.StringIO()
        with redirect_stdout(buf):
            func(*args, **kwargs)
        return buf.getvalue().strip()

    def test_log_start_format(self):
        from inference import log_start
        out = self._capture(log_start, "simple-filter", "Qwen/Qwen2.5-72B")
        assert self.START_RE.match(out), f"Bad [START] format: {out!r}"

    def test_log_step_format_null_error(self):
        from inference import log_step
        out = self._capture(log_step, 1, "SELECT 1", 0.10, False, None)
        assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}"

    def test_log_step_format_with_error(self):
        from inference import log_step
        out = self._capture(log_step, 2, "SELCT 1", 0.0, False, "syntax error")
        assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}"

    def test_log_end_format_success(self):
        from inference import log_end
        out = self._capture(log_end, True, 3, 0.850, [0.50, 1.0, 1.0])
        assert self.END_RE.match(out), f"Bad [END] format: {out!r}"

    def test_log_end_format_failure(self):
        from inference import log_end
        out = self._capture(log_end, False, 5, 0.100, [0.1, 0.0, 0.0, 0.0, 0.0])
        assert self.END_RE.match(out), f"Bad [END] format: {out!r}"

    def test_reward_two_decimal_places(self):
        from inference import log_step
        out = self._capture(log_step, 1, "SELECT 1", 0.5, False, None)
        # reward= field must have exactly 2 decimal places
        match = re.search(r"reward=(\d+\.\d+)", out)
        assert match, "No reward= field found"
        assert len(match.group(1).split(".")[1]) == 2

    def test_score_three_decimal_places(self):
        from inference import log_end
        out = self._capture(log_end, True, 1, 1.0, [1.0])
        # score= field must have exactly 3 decimal places
        match = re.search(r"score=(\d+\.\d+)", out)
        assert match
        assert len(match.group(1).split(".")[1]) == 3