"""
End-to-end flow tests for the phi3-mini-sql-generator demo.
Run with: python tests/e2e_flow_test.py

The model must be loaded first: call app.load_model(app.FINE_TUNED_MODEL_ID)
before running these tests.
"""
|
|
| import app |
| import types |
|
|
| |
| |
| |
|
|
def sql_out(result):
    """Return the SQL text of a response.

    Args:
        result: Indexable response as returned by ``app.generate_response``;
            the callers in this file treat element 4 as the generated SQL.

    Returns:
        The element stored at index 4 of ``result``.
    """
    return result[4]
|
|
def status(result):
    """Return the status message of a response.

    Args:
        result: Indexable response as returned by ``app.generate_response``;
            the callers in this file treat element 6 as the status text.

    Returns:
        The element stored at index 6 of ``result``.
    """
    return result[6]
|
|
def reset_model_state():
    """Clear the module-level model state held by ``app``.

    Sets ``app._model``, ``app._tokenizer`` and ``app._current_model_id``
    back to ``None``.
    """
    app._model = None
    app._tokenizer = None
    app._current_model_id = None
|
|
|
|
def check_sql(result, expected_fragments, description):
    """Check that every expected fragment appears in the generated SQL.

    Prints one FAIL line per missing fragment, or a single OK line carrying
    ``description`` when nothing is missing. On failure the response status
    is printed as well. The first 200 characters of the SQL are always
    echoed so a failing run can be diagnosed from the output alone.

    Nothing is asserted: the outcome is reported through the return value.

    Args:
        result: Response returned by ``app.generate_response``; the SQL is
            read with ``sql_out`` and the status with ``status``.
        expected_fragments: Iterable of substrings that must each occur
            verbatim (case-sensitive) in the SQL.
        description: Label printed when all fragments are found.

    Returns:
        True when every fragment is present, False otherwise.
    """
    sql = sql_out(result)
    status_msg = status(result)

    missing = [frag for frag in expected_fragments if frag not in sql]
    for frag in missing:
        print(f" FAIL: missing '{frag}' in output")

    if missing:
        # str() guards the slice in case the status is not a string.
        print(f" Status: {str(status_msg)[:100]}")
    else:
        print(f" OK: {description}")
    print(f" SQL: {sql[:200]}")
    return not missing
|
|
|
|
| |
| |
| |
|
|
def test_scenario1_parser_keeps_working():
    """Scenario 1: build a CREATE TABLE from a Portuguese description.

    Calls ``app.generate_response`` with an empty history, an empty schema
    and no model key, then verifies the expected table and column
    fragments with ``check_sql``.

    Returns:
        True when every expected fragment is present in the SQL.
    """
    print("\n=== Scenario 1: Parser — accented columns ===")
    result = app.generate_response(
        "criar tabela animal com nome nome cientifico e especie",
        [], "", None, None
    )
    fragments = [
        "CREATE TABLE animal",
        "nome TEXT",
        "cientifico TEXT",
        "especie TEXT",
    ]
    return check_sql(result, fragments, "3 columns from Portuguese message")
|
|
|
|
| |
| |
| |
|
|
def test_scenario2_select_all():
    """Scenario 2: a "list everything" request must yield a SELECT ... FROM.

    Uses the ``employees`` preset schema and the fine-tuned model key.
    The keyword checks are case-insensitive.

    Returns:
        True when the SQL contains both SELECT and FROM.
    """
    print("\n=== Scenario 2: SELECT all rows ===")
    schema = app.PRESETS["employees"]
    result = app.generate_response(
        "liste todos os funcionarios",
        [], schema, app.FINE_TUNED_MODEL_KEY, None
    )
    sql = sql_out(result)
    sql_upper = sql.upper()  # normalise once for both keyword checks

    ok = True
    if "SELECT" not in sql_upper:
        print(" FAIL: no SELECT in output")
        ok = False
    if "FROM" not in sql_upper:
        print(" FAIL: no FROM in output")
        ok = False
    if ok:
        print(" OK: generated SELECT")
    # Echoed on failure too, so the offending output is visible.
    print(f" SQL: {sql}")
    return ok
|
|
|
|
| |
| |
| |
|
|
def test_scenario3_select_with_filter():
    """Scenario 3: a filtered request must yield SELECT ... WHERE.

    Uses the ``employees`` preset schema and the fine-tuned model key.
    Three independent checks are made: SELECT present, WHERE present
    (both case-insensitive), and the filter mentioning either
    ``department`` or ``vendas``.

    Returns:
        True only when all three checks pass.
    """
    print("\n=== Scenario 3: SELECT with WHERE ===")
    schema = app.PRESETS["employees"]
    result = app.generate_response(
        "mostre os funcionarios do departamento de vendas",
        [], schema, app.FINE_TUNED_MODEL_KEY, None
    )
    sql = sql_out(result)
    sql_upper = sql.upper()
    sql_lower = sql.lower()

    ok = True
    if "SELECT" not in sql_upper:
        print(" FAIL: no SELECT")
        ok = False
    if "WHERE" not in sql_upper:
        print(" FAIL: no WHERE")
        ok = False
    if "department" not in sql_lower and "vendas" not in sql_lower:
        print(" FAIL: filter condition missing")
        ok = False

    # Report success only when every check passed; previously the OK line
    # could appear right after a "no WHERE" failure.
    if ok:
        print(" OK: WHERE clause present")
    # Echoed on failure too, so the offending output is visible.
    print(f" SQL: {sql}")
    return ok
|
|
|
|
| |
| |
| |
|
|
def test_scenario4_aggregates():
    """Scenario 4: an "average per group" request must use AVG and GROUP BY.

    Uses the ``employees`` preset schema and the fine-tuned model key.
    Keyword checks are case-insensitive.

    Returns:
        True when SELECT, AVG and GROUP BY all occur in the SQL.
    """
    print("\n=== Scenario 4: Aggregate query ===")
    schema = app.PRESETS["employees"]
    result = app.generate_response(
        "qual a media de salarios por departamento",
        [], schema, app.FINE_TUNED_MODEL_KEY, None
    )
    sql = sql_out(result)
    sql_upper = sql.upper()  # normalise once instead of once per keyword

    missing = [kw for kw in ("SELECT", "AVG", "GROUP BY") if kw not in sql_upper]
    for kw in missing:
        print(f" FAIL: missing '{kw}'")
    if not missing:
        print(" OK: aggregate query generated")
    # Echoed on failure too, so the offending output is visible.
    print(f" SQL: {sql}")
    return not missing
|
|
|
|
| |
| |
| |
|
|
def test_scenario5_natural_language():
    """Scenario 5: an English question must produce a SELECT statement.

    Uses the ``products`` preset schema and the fine-tuned model key.
    An empty (whitespace-only) SQL output is reported together with the
    response status; a non-empty output without SELECT is reported with
    its first 100 characters.

    Returns:
        True when the SQL is non-empty and contains SELECT
        (case-insensitive).
    """
    print("\n=== Scenario 5: Natural language SQL (Issue 3) ===")
    schema = app.PRESETS["products"]
    result = app.generate_response(
        "what is the most expensive product",
        [], schema, app.FINE_TUNED_MODEL_KEY, None
    )
    sql = sql_out(result)
    status_msg = status(result)

    if not sql.strip():
        # str() guards the slice in case the status is not a string.
        print(f" FAIL: no SQL generated — model returned: {str(status_msg)[:100]}")
        return False
    if "SELECT" not in sql.upper():
        print(f" FAIL: output is not SQL: {sql[:100]}")
        return False

    print(" OK: natural language produced SQL")
    print(f" SQL: {sql}")
    return True
|
|
|
|
| |
| |
| |
|
|
def test_scenario6_multiturn_flow():
    """Scenario 6: build a schema over several turns, then query it.

    Each turn passes element 0 of the previous response back as the second
    argument of ``app.generate_response`` (presumably the conversation
    history; confirm against ``app.generate_response``).

    Steps:
        1. Create table ``vendas`` with four columns.
        2. Add column ``desconto``.
        3. Remove column ``quantidade`` and confirm the others survive.
        4. Ask a question against the resulting schema using the
           fine-tuned model and expect a SELECT.

    A failing step does not stop the flow; later steps still run.

    Returns:
        True when every step passes.
    """
    print("\n=== Scenario 6: Multi-turn schema build + query ===")
    ok = True

    # Step 1: create the table.
    r1 = app.generate_response(
        "crie tabela vendas com id produto quantidade total",
        [], "", None, None
    )
    step1_fragments = [
        "CREATE TABLE vendas",
        "id INTEGER",
        "produto TEXT",
        "quantidade INTEGER",
        "total NUMERIC",
    ]
    if not check_sql(r1, step1_fragments, "Step 1: CREATE TABLE"):
        ok = False

    # Step 2: add a column to the existing table.
    r2 = app.generate_response("adicione desconto", r1[0], "", None, None)
    step2_fragments = ["desconto NUMERIC", "CREATE TABLE vendas"]
    if not check_sql(r2, step2_fragments, "Step 2: ADD COLUMN"):
        ok = False

    # Step 3: remove a column; it must disappear from the SQL.
    r3 = app.generate_response("remova quantidade", r2[0], "", None, None)
    sql3 = sql_out(r3)
    if "quantidade" in sql3:
        print(" FAIL: 'quantidade' still in table after remove (regression)")
        ok = False
    else:
        print(" OK: Step 3: REMOVE COLUMN - 'quantidade' removed")

    # The remaining columns must survive the removal.
    for col in ["id", "produto", "desconto", "total"]:
        if col not in sql3:
            print(f" FAIL: column '{col}' missing after remove")
            ok = False

    # Step 4: query the schema produced by step 3 with the fine-tuned model.
    r4 = app.generate_response(
        "quanto vendemos no total",
        r3[0], sql3, app.FINE_TUNED_MODEL_KEY, None
    )
    sql4 = sql_out(r4)
    if "SELECT" not in sql4.upper():
        # str() guards the slice in case the status is not a string.
        print(f" FAIL: Step 4 no SELECT generated. Status: {str(status(r4))[:100]}")
        ok = False
    else:
        print(" OK: Step 4: model generated SQL from multi-turn context")
        print(f" SQL: {sql4}")

    return ok
|
|
|
|
| |
| |
| |
|
|
def run_all():
    """Run every scenario in order and print a PASS/FAIL summary.

    Nothing is run when ``app._model`` is ``None``; an error message is
    printed instead.

    Returns:
        True when all scenarios pass. False when at least one fails or
        when the model is not loaded, so the result is always a bool.
    """
    if app._model is None:
        print("ERROR: model not loaded. Run app.load_model(app.FINE_TUNED_MODEL_ID) first.")
        return False

    # Insertion order is execution order and summary order.
    scenarios = {
        "s1_parser": test_scenario1_parser_keeps_working,
        "s2_select_all": test_scenario2_select_all,
        "s3_where": test_scenario3_select_with_filter,
        "s4_aggregates": test_scenario4_aggregates,
        "s5_natlang": test_scenario5_natural_language,
        "s6_multiturn": test_scenario6_multiturn_flow,
    }
    results = {name: scenario() for name, scenario in scenarios.items()}

    print("\n" + "=" * 50)
    print("SUMMARY")
    print("=" * 50)
    passed = sum(1 for outcome in results.values() if outcome)
    total = len(results)
    for name, outcome in results.items():
        mark = "PASS" if outcome else "FAIL"
        print(f" {mark} {name}")
    print(f"\n Total: {passed}/{total} passed")

    return passed == total
|
|
|
|
if __name__ == "__main__":
    # Script entry point. The exit status mirrors the outcome so shells and
    # CI can detect failures: 0 when every scenario passes, 1 when any
    # fails or when the suite cannot run because no model is loaded.
    if app._model is None:
        print("Model not loaded. Call app.load_model(app.FINE_TUNED_MODEL_ID) then re-run.")
        print("From python: python -c \"import app; app.load_model(app.FINE_TUNED_MODEL_ID); exec(open('tests/e2e_flow_test.py').read())\"")
        raise SystemExit(1)
    raise SystemExit(0 if run_all() else 1)
|
|