# phi3-mini-sql-generator-demo / tests/e2e_flow_test.py
# Commit d88f966 by Shizu0n:
#   feat: normalize SQL questions and add deterministic SQL routing
"""
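End-to-end flow tests for the phi3-mini-sql-generator demo.

The model must already be loaded in the same Python process before these
tests run: call app.load_model(app.FINE_TUNED_MODEL_ID) first, e.g.

    python -c "import app; app.load_model(app.FINE_TUNED_MODEL_ID); exec(open('tests/e2e_flow_test.py').read())"

Running `python tests/e2e_flow_test.py` on its own only prints these
instructions unless the model is already loaded in that process.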
"""
import app
import types
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def sql_out(result):
    """Return the generated-SQL field (index 4) of a ``generate_response`` result."""
    sql_position = 4
    return result[sql_position]
def status(result):
    """Return the status-message field (index 6) of a ``generate_response`` result."""
    status_position = 6
    return result[status_position]
def reset_model_state():
    """Set ``app._model``, ``app._tokenizer`` and ``app._current_model_id`` to None.

    Not called by the scenarios in this file; presumably meant to force a
    fresh model load between runs (confirm with callers).
    """
    for attr_name in ("_model", "_tokenizer", "_current_model_id"):
        setattr(app, attr_name, None)
def check_sql(result, expected_fragments, description):
    """Report whether the generated SQL contains every expected fragment.

    Prints a FAIL line per missing fragment, an OK line when none are
    missing, and the first 200 characters of the SQL in both cases so a
    failure can be diagnosed from the log. Does not assert.

    Args:
        result: Value returned by ``app.generate_response``; the SQL text is
            read with ``sql_out``. A ``None`` SQL field is treated as empty.
        expected_fragments: Substrings that must appear verbatim
            (case-sensitive) in the SQL.
        description: Label printed when all fragments are present.

    Returns:
        True if every fragment was found, False otherwise.
    """
    sql = sql_out(result) or ""
    missing = [frag for frag in expected_fragments if frag not in sql]
    for frag in missing:
        print(f" FAIL: missing '{frag}' in output")
    if not missing:
        print(f" OK: {description}")
    # Always show the SQL: on failure it is the most useful diagnostic.
    print(f" SQL: {sql[:200]}")
    return not missing
# ---------------------------------------------------------------------------
# Scenario 1: Parser still works (no model call)
# ---------------------------------------------------------------------------
def test_scenario1_parser_keeps_working():
    """Scenario 1: the CREATE TABLE parser path.

    Sends a Portuguese "create table" request with an empty schema and no
    model key (the section header notes no model call), then checks that
    the table and its three TEXT columns appear in the SQL output.

    Returns:
        The ``check_sql`` verdict (True when every fragment is present).
    """
    print("\n=== Scenario 1: Parser — accented columns ===")
    expected = [
        "CREATE TABLE animal",
        "nome TEXT",
        "cientifico TEXT",
        "especie TEXT",
    ]
    outcome = app.generate_response(
        "criar tabela animal com nome nome cientifico e especie",
        [], "", None, None,
    )
    return check_sql(outcome, expected, "3 columns from Portuguese message")
# ---------------------------------------------------------------------------
# Scenario 2: SELECT all
# ---------------------------------------------------------------------------
def test_scenario2_select_all():
    """Scenario 2: ask the fine-tuned model to list every employee.

    Uses the ``employees`` preset schema. Passes when the returned SQL
    contains both SELECT and FROM (case-insensitive). The SQL is printed
    whether or not the checks pass, so failures can be diagnosed.

    Returns:
        True if both keywords are present, False otherwise.
    """
    print("\n=== Scenario 2: SELECT all rows ===")
    schema = app.PRESETS["employees"]
    result = app.generate_response(
        "liste todos os funcionarios",
        [], schema, app.FINE_TUNED_MODEL_KEY, None
    )
    sql = sql_out(result)
    upper_sql = sql.upper()
    ok = True
    if "SELECT" not in upper_sql:
        print(" FAIL: no SELECT in output")
        ok = False
    if "FROM" not in upper_sql:
        print(" FAIL: no FROM in output")
        ok = False
    if ok:
        print(" OK: generated SELECT")
    print(f" SQL: {sql}")
    return ok
# ---------------------------------------------------------------------------
# Scenario 3: SELECT with WHERE filter
# ---------------------------------------------------------------------------
def test_scenario3_select_with_filter():
    """Scenario 3: a request that requires a WHERE filter.

    Uses the ``employees`` preset schema. Passes when the SQL contains
    SELECT and WHERE, and the text from WHERE onward mentions
    ``department`` or ``vendas`` (case-insensitive). The SQL is printed
    whether or not the checks pass.

    Returns:
        True when all checks pass, False otherwise.
    """
    print("\n=== Scenario 3: SELECT with WHERE ===")
    schema = app.PRESETS["employees"]
    result = app.generate_response(
        "mostre os funcionarios do departamento de vendas",
        [], schema, app.FINE_TUNED_MODEL_KEY, None
    )
    sql = sql_out(result)
    ok = True
    if "SELECT" not in sql.upper():
        print(" FAIL: no SELECT")
        ok = False
    lowered = sql.lower()
    where_at = lowered.find("where")
    if where_at == -1:
        print(" FAIL: no WHERE")
        ok = False
    elif "department" in lowered[where_at:] or "vendas" in lowered[where_at:]:
        # Only text after WHERE counts: "department" appearing in the SELECT
        # list alone does not mean the rows are filtered by it.
        print(" OK: WHERE clause present")
    else:
        print(" FAIL: filter condition missing")
        ok = False
    print(f" SQL: {sql}")
    return ok
# ---------------------------------------------------------------------------
# Scenario 4: Aggregate (COUNT, AVG, GROUP BY)
# ---------------------------------------------------------------------------
def test_scenario4_aggregates():
    """Scenario 4: an aggregate question (average salary per department).

    Uses the ``employees`` preset schema. Passes when the SQL contains
    SELECT, AVG and GROUP BY (case-insensitive, whitespace-normalized). The
    SQL is printed whether or not the checks pass.

    Returns:
        True when every keyword is present, False otherwise.
    """
    print("\n=== Scenario 4: Aggregate query ===")
    schema = app.PRESETS["employees"]
    result = app.generate_response(
        "qual a media de salarios por departamento",
        [], schema, app.FINE_TUNED_MODEL_KEY, None
    )
    sql = sql_out(result)
    # Collapse all whitespace runs to single spaces so a keyword split
    # across lines (e.g. "GROUP\nBY") is still recognised.
    normalized = " ".join(sql.upper().split())
    missing = [kw for kw in ("SELECT", "AVG", "GROUP BY") if kw not in normalized]
    for kw in missing:
        print(f" FAIL: missing '{kw}'")
    if not missing:
        print(" OK: aggregate query generated")
    print(f" SQL: {sql}")
    return not missing
# ---------------------------------------------------------------------------
# Scenario 5: Natural language SQL (Issue 3)
# ---------------------------------------------------------------------------
def test_scenario5_natural_language():
    """Scenario 5: an English natural-language question must yield SQL (Issue 3).

    Uses the ``products`` preset schema. Fails when the SQL output is blank
    (printing the start of the status message instead) or lacks SELECT.

    Returns:
        True when SQL containing SELECT was produced, False otherwise.
    """
    print("\n=== Scenario 5: Natural language SQL (Issue 3) ===")
    result = app.generate_response(
        "what is the most expensive product",
        [], app.PRESETS["products"], app.FINE_TUNED_MODEL_KEY, None
    )
    generated = sql_out(result)
    message = status(result)
    if not generated.strip():
        print(f" FAIL: no SQL generated — model returned: {message[:100]}")
        return False
    if "SELECT" not in generated.upper():
        print(f" FAIL: output is not SQL: {generated[:100]}")
        return False
    print(" OK: natural language produced SQL")
    print(f" SQL: {generated}")
    return True
# ---------------------------------------------------------------------------
# Scenario 6: Multi-turn flow (create → add → remove → query)
# ---------------------------------------------------------------------------
def test_scenario6_multiturn_flow():
    """Scenario 6: build a schema over several turns, then query it with the model.

    Steps:
      1. Create ``vendas`` with id/produto/quantidade/total (no model key).
      2. Add ``desconto``, passing step 1's ``result[0]`` as the history.
      3. Remove ``quantidade``; the other four columns must remain.
      4. Ask the fine-tuned model a question using the step-3 SQL as schema.

    Step-3 column checks match whole words, because a plain substring test
    is too weak (e.g. ``"id"`` occurs inside ``"quantidade"``).

    Returns:
        True when every step passed, False otherwise.
    """
    import re

    def has_word(text, word):
        # Whole-identifier match; see the docstring for why substrings fail.
        return re.search(rf"\b{re.escape(word)}\b", text) is not None

    print("\n=== Scenario 6: Multi-turn schema build + query ===")
    ok = True

    # Step 1: create the table (empty schema, no model key).
    r1 = app.generate_response(
        "crie tabela vendas com id produto quantidade total",
        [], "", None, None
    )
    step1_fragments = [
        "CREATE TABLE vendas",
        "id INTEGER",
        "produto TEXT",
        "quantidade INTEGER",
        "total NUMERIC",
    ]
    if not check_sql(r1, step1_fragments, "Step 1: CREATE TABLE"):
        ok = False

    # Step 2: add a column; r1[0] is passed back as the history argument.
    r2 = app.generate_response("adicione desconto", r1[0], "", None, None)
    if not check_sql(r2, ["desconto NUMERIC", "CREATE TABLE vendas"], "Step 2: ADD COLUMN"):
        ok = False

    # Step 3: remove a column.
    r3 = app.generate_response("remova quantidade", r2[0], "", None, None)
    sql3 = sql_out(r3) or ""
    # CORRECT: quantidade should NOT be in SQL (it was removed)
    if has_word(sql3, "quantidade"):
        print(" FAIL: 'quantidade' still in table after remove (regression)")
        ok = False
    else:
        print(" OK: Step 3: REMOVE COLUMN - 'quantidade' removed")
    # Verify remaining columns still exist
    for col in ["id", "produto", "desconto", "total"]:
        if not has_word(sql3, col):
            print(f" FAIL: column '{col}' missing after remove")
            ok = False

    # Step 4: query with the model, using the step-3 SQL as the schema.
    r4 = app.generate_response(
        "quanto vendemos no total",
        r3[0], sql3, app.FINE_TUNED_MODEL_KEY, None
    )
    sql4 = sql_out(r4)
    if "SELECT" not in sql4.upper():
        print(f" FAIL: Step 4 no SELECT generated. Status: {status(r4)[:100]}")
        ok = False
    else:
        print(" OK: Step 4: model generated SQL from multi-turn context")
    print(f" SQL: {sql4}")
    return ok
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
def run_all():
    """Run all six scenarios in order and print a PASS/FAIL summary.

    Returns:
        True if every scenario passed. False if any scenario failed, or if
        ``app._model`` is None (in that case no scenario is run).
    """
    if app._model is None:
        print("ERROR: model not loaded. Run app.load_model(app.FINE_TUNED_MODEL_ID) first.")
        # Return a bool (not None) so the return type is consistent.
        return False
    scenarios = [
        ("s1_parser", test_scenario1_parser_keeps_working),
        ("s2_select_all", test_scenario2_select_all),
        ("s3_where", test_scenario3_select_with_filter),
        ("s4_aggregates", test_scenario4_aggregates),
        ("s5_natlang", test_scenario5_natural_language),
        ("s6_multiturn", test_scenario6_multiturn_flow),
    ]
    # Dict comprehensions evaluate in order, so scenarios run as listed.
    results = {name: scenario() for name, scenario in scenarios}
    print("\n" + "=" * 50)
    print("SUMMARY")
    print("=" * 50)
    passed = sum(1 for v in results.values() if v)
    total = len(results)
    for name, result in results.items():
        mark = "PASS" if result else "FAIL"
        print(f" {mark} {name}")
    print(f"\n Total: {passed}/{total} passed")
    return passed == total
if __name__ == "__main__":
    # The scenarios read the model held in app's module state. If nothing in
    # this process has loaded it yet, only print how to do so.
    model_ready = app._model is not None
    if model_ready:
        run_all()
    else:
        print("Model not loaded. Call app.load_model(app.FINE_TUNED_MODEL_ID) then re-run.")
        print("From python: python -c \"import app; app.load_model(app.FINE_TUNED_MODEL_ID); exec(open('tests/e2e_flow_test.py').read())\"")