""" End-to-end flow tests for phi3-mini-sql-generator demo. Run with: python tests/e2e_flow_test.py Model must be loaded first. Call app.load_model(app.FINE_TUNED_MODEL_ID) before running these tests. """ import app import types # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def sql_out(result): return result[4] def status(result): return result[6] def reset_model_state(): app._model = None app._tokenizer = None app._current_model_id = None def check_sql(result, expected_fragments, description): """Print and assert SQL output checks.""" sql = sql_out(result) status_msg = status(result) ok = True for frag in expected_fragments: if frag not in sql: print(f" FAIL: missing '{frag}' in output") ok = False if ok: print(f" OK: {description}") print(f" SQL: {sql[:200]}") return ok # --------------------------------------------------------------------------- # Scenario 1: Parser still works (no model call) # --------------------------------------------------------------------------- def test_scenario1_parser_keeps_working(): print("\n=== Scenario 1: Parser — accented columns ===") result = app.generate_response( "criar tabela animal com nome nome cientifico e especie", [], "", None, None ) fragments = ["CREATE TABLE animal", "nome TEXT", "cientifico TEXT", "especie TEXT"] return check_sql(result, fragments, "3 columns from Portuguese message") # --------------------------------------------------------------------------- # Scenario 2: SELECT all # --------------------------------------------------------------------------- def test_scenario2_select_all(): print("\n=== Scenario 2: SELECT all rows ===") schema = app.PRESETS["employees"] result = app.generate_response( "liste todos os funcionarios", [], schema, app.FINE_TUNED_MODEL_KEY, None ) sql = sql_out(result) status_msg = status(result) ok = True if "SELECT" not in sql.upper(): print(f" FAIL: no SELECT in output") ok = False if "FROM" not in sql.upper(): print(f" FAIL: no FROM in output") ok = False if ok: print(f" OK: generated SELECT") print(f" SQL: {sql}") return ok # --------------------------------------------------------------------------- # Scenario 3: SELECT with WHERE filter # --------------------------------------------------------------------------- def test_scenario3_select_with_filter(): print("\n=== Scenario 3: SELECT with WHERE ===") schema = app.PRESETS["employees"] result = app.generate_response( "mostre os funcionarios do departamento de vendas", [], schema, app.FINE_TUNED_MODEL_KEY, None ) sql = sql_out(result) ok = True if "SELECT" not in sql.upper(): print(f" FAIL: no SELECT") ok = False if "WHERE" not in sql.upper(): print(f" FAIL: no WHERE") ok = False if "department" in sql.lower() or "vendas" in sql.lower(): print(f" OK: WHERE clause present") print(f" SQL: {sql}") else: print(f" FAIL: filter condition missing") ok = False return ok # --------------------------------------------------------------------------- # Scenario 4: Aggregate (COUNT, AVG, GROUP BY) # --------------------------------------------------------------------------- def test_scenario4_aggregates(): print("\n=== Scenario 4: Aggregate query ===") schema = app.PRESETS["employees"] result = app.generate_response( "qual a media de salarios por departamento", [], schema, app.FINE_TUNED_MODEL_KEY, None ) sql = sql_out(result) ok = True checks = ["SELECT", "AVG", "GROUP BY"] for c in checks: if c not in sql.upper(): print(f" FAIL: missing '{c}'") ok = False if ok: print(f" OK: aggregate query generated") print(f" SQL: {sql}") return ok # --------------------------------------------------------------------------- # Scenario 5: Natural language SQL (Issue 3) # --------------------------------------------------------------------------- def test_scenario5_natural_language(): print("\n=== Scenario 5: Natural language SQL (Issue 3) ===") schema = app.PRESETS["products"] result = app.generate_response( "what is the most expensive product", [], schema, app.FINE_TUNED_MODEL_KEY, None ) sql = sql_out(result) status_msg = status(result) ok = True if not sql.strip(): print(f" FAIL: no SQL generated — model returned: {status_msg[:100]}") ok = False elif "SELECT" not in sql.upper(): print(f" FAIL: output is not SQL: {sql[:100]}") ok = False else: print(f" OK: natural language produced SQL") print(f" SQL: {sql}") return ok # --------------------------------------------------------------------------- # Scenario 6: Multi-turn flow (create → add → remove → query) # --------------------------------------------------------------------------- def test_scenario6_multiturn_flow(): print("\n=== Scenario 6: Multi-turn schema build + query ===") ok = True # Step 1: Create table r1 = app.generate_response( "crie tabela vendas com id produto quantidade total", [], "", None, None ) if not check_sql(r1, ["CREATE TABLE vendas", "id INTEGER", "produto TEXT", "quantidade INTEGER", "total NUMERIC"], "Step 1: CREATE TABLE"): ok = False # Step 2: Add column r2 = app.generate_response("adicione desconto", r1[0], "", None, None) if not check_sql(r2, ["desconto NUMERIC", "CREATE TABLE vendas"], "Step 2: ADD COLUMN"): ok = False # Step 3: Remove column r3 = app.generate_response("remova quantidade", r2[0], "", None, None) sql3 = sql_out(r3) # CORRECT: quantidade should NOT be in SQL (it was removed) if "quantidade" in sql3: print(f" FAIL: 'quantidade' still in table after remove (regression)") ok = False else: print(f" OK: Step 3: REMOVE COLUMN - 'quantidade' removed") # Verify remaining columns still exist for col in ["id", "produto", "desconto", "total"]: if col not in sql3: print(f" FAIL: column '{col}' missing after remove") ok = False # Step 4: Query (model call) final_schema = sql_out(r3) r4 = app.generate_response( "quanto vendemos no total", r3[0], final_schema, app.FINE_TUNED_MODEL_KEY, None ) sql4 = sql_out(r4) if "SELECT" not in sql4.upper(): print(f" FAIL: Step 4 no SELECT generated. Status: {status(r4)[:100]}") ok = False else: print(f" OK: Step 4: model generated SQL from multi-turn context") print(f" SQL: {sql4}") return ok # --------------------------------------------------------------------------- # Run all # --------------------------------------------------------------------------- def run_all(): if app._model is None: print("ERROR: model not loaded. Run app.load_model(app.FINE_TUNED_MODEL_ID) first.") return results = {} results["s1_parser"] = test_scenario1_parser_keeps_working() results["s2_select_all"] = test_scenario2_select_all() results["s3_where"] = test_scenario3_select_with_filter() results["s4_aggregates"] = test_scenario4_aggregates() results["s5_natlang"] = test_scenario5_natural_language() results["s6_multiturn"] = test_scenario6_multiturn_flow() print("\n" + "=" * 50) print("SUMMARY") print("=" * 50) passed = sum(1 for v in results.values() if v) total = len(results) for name, result in results.items(): mark = "PASS" if result else "FAIL" print(f" {mark} {name}") print(f"\n Total: {passed}/{total} passed") return passed == total if __name__ == "__main__": # Check model loaded if app._model is None: print("Model not loaded. Call app.load_model(app.FINE_TUNED_MODEL_ID) then re-run.") print("From python: python -c \"import app; app.load_model(app.FINE_TUNED_MODEL_ID); exec(open('tests/e2e_flow_test.py').read())\"") else: run_all()