Spaces:
Sleeping
Sleeping
| import pytest | |
| from nl2sql.generator import Generator | |
| from nl2sql.types import StageResult | |
# --- Dummy LLMs (respect the 5-tuple contract) --------------------------------
class LLM_OK:
    """Dummy LLM whose ``generate_sql`` returns a well-formed SELECT result."""

    def generate_sql(self, **kwargs):
        """Return a fixed 5-tuple; keyword arguments are accepted and ignored.

        Returns:
            ``(sql, rationale, t_in, t_out, cost)`` with constant values.
        """
        # contract: (sql, rationale, t_in, t_out, cost)
        return "SELECT * FROM singer;", "list all", 10, 5, 0.00001
class LLM_EMPTY_SQL:
    """Dummy LLM that honours the 5-tuple shape but returns an empty SQL string."""

    def generate_sql(self, **kwargs):
        """Return a fixed 5-tuple whose SQL element is ``""``.

        Keyword arguments are accepted and ignored.
        """
        # empty SQL -> should be error
        return "", "reason", 10, 5, 0.0
class LLM_NON_SELECT:
    """Dummy LLM that honours the 5-tuple shape but returns an UPDATE statement."""

    def generate_sql(self, **kwargs):
        """Return a fixed 5-tuple whose SQL element is not a SELECT.

        Keyword arguments are accepted and ignored.
        """
        # non-SELECT SQL -> should be error
        return "UPDATE users SET name='x' WHERE id=1;", "bad", 8, 3, 0.0
class LLM_CONTRACT_NONE:
    """Dummy LLM that breaks the 5-tuple contract by returning ``None``."""

    def generate_sql(self, **kwargs):
        """Return ``None``; keyword arguments are accepted and ignored."""
        # contract violation: None instead of 5-tuple
        return None
class LLM_CONTRACT_SHORT:
    """Dummy LLM that breaks the 5-tuple contract by returning only two items."""

    def generate_sql(self, **kwargs):
        """Return a 2-tuple ``(sql, rationale)``; keyword arguments are ignored."""
        # contract violation: too few items
        return ("SELECT * FROM singer;", "list all")  # only 2
# --- Parametrized negative cases ----------------------------------------------
def test_generator_errors_do_not_create_trace(llm, err_keyword):
    """Check that a failing generator run reports an error and carries no trace.

    Args:
        llm: Stub object handed to ``Generator(llm=...)``.
        err_keyword: Lowercase substring expected in the joined error text.

    NOTE(review): nothing visible in this view binds ``llm`` or
    ``err_keyword``; presumably a ``@pytest.mark.parametrize`` decorator or
    fixtures supply them -- confirm, otherwise pytest cannot resolve them.
    """
    gen = Generator(llm=llm)
    r = gen.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
        plan_text="-- plan --",
        clarify_answers={},
    )
    assert isinstance(r, StageResult)
    assert r.ok is False
    # Error message is flexible; just check a keyword.
    # `or []` keeps the join safe when r.error is None
    # (assumes r.error is otherwise an iterable of strings -- TODO confirm).
    joined = " ".join(r.error or []).lower()
    assert err_keyword in joined
    # On errors, Generator should not attach a trace (we measure only successful stage)
    assert r.trace is None
# --- Positive case (success) ---------------------------------------------------
def test_generator_success_has_valid_trace_and_data():
    """Check that a successful generator run returns SQL data and a coherent trace.

    Runs ``Generator`` with the ``LLM_OK`` stub and asserts on the resulting
    ``StageResult``: success flag, ``data`` payload, and the ``trace`` fields
    that mirror the stub's token counts.
    """
    gen = Generator(llm=LLM_OK())
    r = gen.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
        plan_text="-- plan --",
        clarify_answers={},
    )
    # Basic success checks
    assert isinstance(r, StageResult)
    assert r.ok is True
    assert r.data and r.data["sql"].lower().startswith("select")
    assert "rationale" in r.data
    # Trace should exist and be coherent
    assert r.trace is not None
    assert r.trace.stage == "generator"
    assert isinstance(r.trace.duration_ms, float)
    # Token counts must match what LLM_OK returned (10 in, 5 out)
    assert r.trace.token_in == 10
    assert r.trace.token_out == 5
    # cost can be float or None depending on provider; if present must be numeric
    if r.trace.cost_usd is not None:
        assert isinstance(r.trace.cost_usd, float)
    # Optional notes check -- rationale_len should match length of rationale
    notes = r.trace.notes or {}
    if "rationale_len" in notes:
        assert notes["rationale_len"] == len(r.data.get("rationale", ""))