"""test_data_engine.py — Data engine and DuckDB integration tests.

Covers:
- Schema introspection correctness
- Connection isolation (per-request independence)
- Seed data integrity (schools, years, grades, chronic absenteeism rate)
- Query timeout configuration and result-row cap
- Empty result handling
- extract_sql edge cases
"""
|
|
| import time |
| import pytest |
| from data_engine import ( |
| create_session, |
| get_connection, |
| seed_database, |
| get_schema_info, |
| extract_sql, |
| validate_sql, |
| execute_safe, |
| QueryTimeoutError, |
| MAX_RESULT_ROWS, |
| QUERY_TIMEOUT_SEC, |
| ) |
|
|
|
|
| |
|
|
class TestSchemaIntrospection:
    """Check the table and column metadata reported by get_schema_info()."""

    def test_returns_correct_tables(self, db):
        """Both core tables are present in the schema mapping."""
        schema = get_schema_info(db)
        for table in ("enrollment", "attendance"):
            assert table in schema

    def test_enrollment_columns(self, db):
        """The enrollment table exposes exactly its four columns."""
        schema = get_schema_info(db)
        # Each metadata entry is unpacked as a 3-tuple; the column name comes first.
        column_names = {column for column, _, _ in schema["enrollment"]}
        expected = {"school_year", "school_name", "grade_level", "student_count"}
        assert column_names == expected

    def test_attendance_columns(self, db):
        """The attendance table exposes exactly its five columns."""
        schema = get_schema_info(db)
        column_names = {column for column, _, _ in schema["attendance"]}
        expected = {
            "student_id",
            "school_name",
            "school_year",
            "absence_count",
            "is_chronically_absent",
        }
        assert column_names == expected

    def test_no_extra_tables(self, db):
        """The schema lists exactly these five tables and nothing else."""
        schema = get_schema_info(db)
        expected = {"enrollment", "attendance", "students", "discipline", "grades"}
        assert set(schema) == expected
|
|
|
|
| |
|
|
class TestConnectionIsolation:
    """Each create_session() should return an independent database."""

    # ORDER BY makes the list comparison independent of catalog iteration order.
    _TABLES_SQL = (
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema='main' ORDER BY table_name"
    )

    def test_independent_connections(self):
        """Two sessions expose the same tables, and closing one leaves the other usable."""
        conn_a = create_session()
        conn_b = create_session()
        a_closed = False
        try:
            a_tables = conn_a.execute(self._TABLES_SQL).fetchall()
            b_tables = conn_b.execute(self._TABLES_SQL).fetchall()
            # Without this, two empty databases would pass the equality check.
            assert a_tables, "create_session() returned a database with no tables"
            assert a_tables == b_tables

            conn_a.close()
            a_closed = True
            # Closing session A must not invalidate session B.
            assert conn_b.execute("SELECT 1").fetchone() == (1,)
        finally:
            # Release both sessions even when an assertion above fails;
            # the flag avoids closing session A twice.
            if not a_closed:
                conn_a.close()
            conn_b.close()

    def test_data_is_independent(self):
        """Modifications in one connection don't affect another."""
        count_sql = "SELECT COUNT(*) FROM enrollment"
        conn_a = create_session()
        conn_b = create_session()
        try:
            count_a = conn_a.execute(count_sql).fetchone()[0]
            count_b = conn_b.execute(count_sql).fetchone()[0]
            # Both sessions start from the same data.
            assert count_a == count_b
            # A non-empty table is required for the DELETE below to be observable.
            assert count_a > 0, "enrollment is empty; cannot test isolation"

            # Mutate session A only ...
            conn_a.execute("DELETE FROM enrollment")
            assert conn_a.execute(count_sql).fetchone()[0] == 0
            # ... and session B must still see all of its rows.
            assert conn_b.execute(count_sql).fetchone()[0] == count_b
        finally:
            conn_a.close()
            conn_b.close()
|
|
|
|
| |
|
|
class TestSeedDataIntegrity:
    """Seed data should match the expected shape and distributions."""

    def test_five_schools(self, db):
        """Exactly five distinct schools appear in enrollment."""
        schools = db.execute(
            "SELECT DISTINCT school_name FROM enrollment ORDER BY school_name"
        ).fetchall()
        assert len(schools) == 5

    def test_four_school_years(self, db):
        """Four school years, spanning 2021-2022 through 2024-2025."""
        years = db.execute(
            "SELECT DISTINCT school_year FROM enrollment ORDER BY school_year"
        ).fetchall()
        assert len(years) == 4
        assert years[0][0] == "2021-2022"
        assert years[-1][0] == "2024-2025"

    def test_thirteen_grade_levels(self, db):
        """Thirteen distinct grade levels (presumably K-12 encoded as 0-12)."""
        grades = db.execute(
            "SELECT DISTINCT grade_level FROM enrollment ORDER BY grade_level"
        ).fetchall()
        assert len(grades) == 13

    def test_chronic_absenteeism_rate(self, db):
        """The overall chronic-absence percentage falls within 12-18%."""
        rate = db.execute("""
            SELECT 100.0 * SUM(CASE WHEN is_chronically_absent THEN 1 ELSE 0 END) / COUNT(*)
            FROM attendance
        """).fetchone()[0]
        # SUM over zero rows is NULL. An empty table would otherwise surface as a
        # confusing TypeError in the comparison or the :.1f format below.
        assert rate is not None, "attendance table is empty"
        assert 12 <= rate <= 18, f"Chronic rate {rate:.1f}% outside expected range"

    def test_yoy_enrollment_growth(self, db):
        """Total enrollment strictly increases from each school year to the next."""
        totals = db.execute("""
            SELECT school_year, SUM(student_count) AS total
            FROM enrollment GROUP BY school_year ORDER BY school_year
        """).fetchall()
        # With fewer than two years the pairwise loop would pass vacuously.
        assert len(totals) >= 2, f"Need at least two school years, got {len(totals)}"
        for prev, curr in zip(totals, totals[1:]):
            assert curr[1] > prev[1], f"Enrollment did not grow: {prev} → {curr}"

    def test_school_grade_ranges(self, db):
        """Each school should only have grades in its range."""
        # Inclusive (lowest, highest) grade level allowed per school.
        grade_ranges = {
            "Lincoln Elementary": (0, 5),
            "Washington Middle": (6, 8),
            "Jefferson High": (9, 12),
            "Roosevelt Academy": (0, 8),
            "Kennedy Prep": (6, 12),
        }
        for school, (lo, hi) in grade_ranges.items():
            grades = db.execute(
                "SELECT DISTINCT grade_level FROM enrollment WHERE school_name = ?",
                [school],
            ).fetchall()
            # A misspelled or missing school would otherwise pass vacuously.
            assert grades, f"{school} has no enrollment rows"
            for (g,) in grades:
                assert lo <= g <= hi, f"{school} has grade {g}, expected {lo}-{hi}"
|
|
|
|
| |
|
|
class TestTimeout:
    """Timeout configuration and result-size cap for execute_safe().

    NOTE(review): no test in this class drives a slow query into
    QueryTimeoutError. TODO: add one once a query is known that is reliably
    slow and still accepted by the SQL validator.
    """

    def test_fast_query_succeeds(self, db):
        """A trivial query completes within an explicit timeout and returns 1x1."""
        _, df = execute_safe(db, "SELECT 1", timeout_sec=5)
        assert df.shape == (1, 1)

    def test_timeout_configuration(self):
        """The default query timeout constant is 10 (seconds, per the _SEC suffix).

        This pins the default value only. Overriding it via ``timeout_sec`` is
        exercised by test_fast_query_succeeds.
        """
        assert QUERY_TIMEOUT_SEC == 10

    def test_max_result_rows_caps_output(self, db):
        """Results from execute_safe never exceed MAX_RESULT_ROWS rows.

        NOTE(review): this only exercises the cap if attendance holds more than
        MAX_RESULT_ROWS rows. Confirm against the seed size.
        """
        _, df = execute_safe(db, "SELECT * FROM attendance")
        assert len(df) <= MAX_RESULT_ROWS
|
|
|
|
| |
|
|
class TestEmptyResults:
    """A query whose result set is empty must be handled without raising."""

    def test_empty_where_clause(self, db):
        """A fenced query matching no rows yields the bare SQL and zero rows."""
        fenced = "```sql\nSELECT * FROM attendance WHERE school_name = 'Nonexistent'\n```"
        extracted, frame = execute_safe(db, fenced)
        assert extracted == "SELECT * FROM attendance WHERE school_name = 'Nonexistent'"
        assert len(frame) == 0

    def test_empty_aggregate(self, db):
        """COUNT(*) over an always-false predicate still yields one row holding 0."""
        query = "SELECT COUNT(*) FROM attendance WHERE 1 = 0"
        _, frame = execute_safe(db, query)
        assert frame.iloc[0, 0] == 0
|
|
|
|
| |
|
|
class TestExtractSqlEdgeCases:
    """Unusual but valid inputs that extract_sql() must still handle."""

    def test_whitespace_only_after_block(self):
        """Trailing spaces and newlines after the closing fence are ignored."""
        raw = "```sql\nSELECT 1\n``` \n "
        assert extract_sql(raw) == "SELECT 1"

    def test_no_newline_after_fence(self):
        """The closing fence may directly follow the statement with no newline."""
        raw = "```sql\nSELECT 1```"
        assert extract_sql(raw) == "SELECT 1"

    def test_multiple_code_blocks(self):
        """Only the first ```sql``` block should be extracted."""
        raw = "```sql\nSELECT 1\n```\n```sql\nSELECT 2\n```"
        assert extract_sql(raw) == "SELECT 1"

    def test_case_insensitive_sql_fence(self):
        """An upper-case ``SQL`` language tag is accepted like ``sql``."""
        raw = "```SQL\nSELECT 1\n```"
        assert extract_sql(raw) == "SELECT 1"

    def test_trailing_semicolons_stripped(self):
        """A run of trailing semicolons is removed from the extracted SQL."""
        raw = "```sql\nSELECT 1;;;;\n```"
        assert extract_sql(raw) == "SELECT 1"
|
|