nl2sql-copilot / tests /test_generator.py
Melika Kheirieh
style: format code with ruff
c1bc4eb
import pytest
from nl2sql.generator import Generator
from nl2sql.types import StageResult
# --- Dummy LLMs (respect the 5-tuple contract) --------------------------------
class LLM_OK:
def generate_sql(self, **kwargs):
# contract: (sql, rationale, t_in, t_out, cost)
return "SELECT * FROM singer;", "list all", 10, 5, 0.00001
class LLM_EMPTY_SQL:
def generate_sql(self, **kwargs):
# empty SQL β†’ should be error
return "", "reason", 10, 5, 0.0
class LLM_NON_SELECT:
def generate_sql(self, **kwargs):
# non-SELECT SQL β†’ should be error
return "UPDATE users SET name='x' WHERE id=1;", "bad", 8, 3, 0.0
class LLM_CONTRACT_NONE:
def generate_sql(self, **kwargs):
# contract violation: None instead of 5-tuple
return None
class LLM_CONTRACT_SHORT:
def generate_sql(self, **kwargs):
# contract violation: too few items
return ("SELECT * FROM singer;", "list all") # only 2
# --- Parametrized negative cases ----------------------------------------------
@pytest.mark.parametrize(
"llm, err_keyword",
[
(LLM_EMPTY_SQL(), "empty"), # empty or non-string sql
(LLM_NON_SELECT(), "non-select"), # generated non-SELECT
(LLM_CONTRACT_NONE(), "contract violation"),
(LLM_CONTRACT_SHORT(), "contract violation"),
],
)
def test_generator_errors_do_not_create_trace(llm, err_keyword):
gen = Generator(llm=llm)
r = gen.run(
user_query="show all singers",
schema_preview="CREATE TABLE singer(id int, name text);",
plan_text="-- plan --",
clarify_answers={},
)
assert isinstance(r, StageResult)
assert r.ok is False
# Error message is flexible; just check a keyword
joined = " ".join(r.error or []).lower()
assert err_keyword in joined
# On errors, Generator should not attach a trace (we measure only successful stage)
assert r.trace is None
# --- Positive case (success) ---------------------------------------------------
def test_generator_success_has_valid_trace_and_data():
gen = Generator(llm=LLM_OK())
r = gen.run(
user_query="show all singers",
schema_preview="CREATE TABLE singer(id int, name text);",
plan_text="-- plan --",
clarify_answers={},
)
# Basic success checks
assert isinstance(r, StageResult)
assert r.ok is True
assert r.data and r.data["sql"].lower().startswith("select")
assert "rationale" in r.data
# Trace should exist and be coherent
assert r.trace is not None
assert r.trace.stage == "generator"
assert isinstance(r.trace.duration_ms, float)
assert r.trace.token_in == 10
assert r.trace.token_out == 5
# cost can be float or None depending on provider; if present must be numeric
if r.trace.cost_usd is not None:
assert isinstance(r.trace.cost_usd, float)
# Optional notes check – rationale_len should match length of rationale
notes = r.trace.notes or {}
if "rationale_len" in notes:
assert notes["rationale_len"] == len(r.data.get("rationale", ""))