"""test_data_engine.py — Data engine and DuckDB integration tests.

Covers:
  - Schema introspection correctness
  - Connection isolation (per-request independence)
  - Seed data integrity (schools, years, grades, chronic rate)
  - Query timeout configuration and result-row cap
  - Empty result handling
  - extract_sql edge cases

Tests that take a ``db`` argument rely on a pytest fixture defined outside
this file (presumably conftest.py — confirm) that supplies a seeded
connection.
"""

import time

import pytest

from data_engine import (
    create_session,
    get_connection,
    seed_database,
    get_schema_info,
    extract_sql,
    validate_sql,
    execute_safe,
    QueryTimeoutError,
    MAX_RESULT_ROWS,
    QUERY_TIMEOUT_SEC,
)

# Ordered explicitly so two result lists can be compared with ``==``;
# without ORDER BY the row order of the catalog query is unspecified.
_LIST_TABLES_SQL = (
    "SELECT table_name FROM information_schema.tables "
    "WHERE table_schema='main' ORDER BY table_name"
)

_ENROLLMENT_COUNT_SQL = "SELECT COUNT(*) FROM enrollment"


# ── Schema introspection ───────────────────────────────────────────────

class TestSchemaIntrospection:
    """get_schema_info() should return correct table/column metadata.

    The tests treat the return value as a mapping from table name to an
    iterable of 3-tuples whose first element is the column name.
    """

    def test_returns_correct_tables(self, db):
        """The core tables are present in the schema info."""
        info = get_schema_info(db)
        assert "enrollment" in info
        assert "attendance" in info

    def test_enrollment_columns(self, db):
        """enrollment exposes exactly the expected column names."""
        info = get_schema_info(db)
        cols = {name for name, _, _ in info["enrollment"]}
        assert cols == {"school_year", "school_name", "grade_level", "student_count"}

    def test_attendance_columns(self, db):
        """attendance exposes exactly the expected column names."""
        info = get_schema_info(db)
        cols = {name for name, _, _ in info["attendance"]}
        assert cols == {
            "student_id",
            "school_name",
            "school_year",
            "absence_count",
            "is_chronically_absent",
        }

    def test_no_extra_tables(self, db):
        """The schema contains the five expected tables and nothing else."""
        info = get_schema_info(db)
        # Day 1 expanded schema: 5 tables
        # (enrollment, attendance, students, discipline, grades)
        expected_tables = {"enrollment", "attendance", "students", "discipline", "grades"}
        assert set(info.keys()) == expected_tables


# ── Connection isolation ───────────────────────────────────────────────

class TestConnectionIsolation:
    """Each create_session() should return an independent database."""

    def test_independent_connections(self):
        """Two sessions expose the same tables; closing one leaves the other usable."""
        conn_a = create_session()
        conn_b = create_session()
        try:
            try:
                # Both have the same tables
                a_tables = conn_a.execute(_LIST_TABLES_SQL).fetchall()
                b_tables = conn_b.execute(_LIST_TABLES_SQL).fetchall()
                assert a_tables == b_tables
            finally:
                # Closed here (exactly once) so the check below runs against
                # a session whose sibling is already gone.
                conn_a.close()

            # Closing one doesn't affect the other
            result = conn_b.execute("SELECT 1").fetchone()
            assert result == (1,)
        finally:
            conn_b.close()

    def test_data_is_independent(self):
        """Modifications in one connection don't affect another."""
        conn_a = create_session()
        conn_b = create_session()
        try:
            # Count in both should match initially
            count_a = conn_a.execute(_ENROLLMENT_COUNT_SQL).fetchone()[0]
            count_b = conn_b.execute(_ENROLLMENT_COUNT_SQL).fetchone()[0]
            assert count_a == count_b
            # An empty table would make the DELETE below unobservable.
            assert count_a > 0, "enrollment is empty; isolation cannot be verified"

            # Mutate session A only, then confirm session B is untouched.
            conn_a.execute("DELETE FROM enrollment")
            assert conn_a.execute(_ENROLLMENT_COUNT_SQL).fetchone()[0] == 0
            assert conn_b.execute(_ENROLLMENT_COUNT_SQL).fetchone()[0] == count_b
        finally:
            conn_a.close()
            conn_b.close()


# ── Seed data integrity ────────────────────────────────────────────────

class TestSeedDataIntegrity:
    """Seed data should match the expected shape and distributions."""

    def test_five_schools(self, db):
        """enrollment contains exactly five distinct school names."""
        schools = db.execute(
            "SELECT DISTINCT school_name FROM enrollment ORDER BY school_name"
        ).fetchall()
        assert len(schools) == 5

    def test_four_school_years(self, db):
        """enrollment spans four school years, 2021-2022 through 2024-2025."""
        years = db.execute(
            "SELECT DISTINCT school_year FROM enrollment ORDER BY school_year"
        ).fetchall()
        assert len(years) == 4
        assert years[0][0] == "2021-2022"
        assert years[-1][0] == "2024-2025"

    def test_thirteen_grade_levels(self, db):
        """enrollment contains thirteen distinct grade levels."""
        grades = db.execute(
            "SELECT DISTINCT grade_level FROM enrollment ORDER BY grade_level"
        ).fetchall()
        assert len(grades) == 13  # K (0) through 12

    def test_chronic_absenteeism_rate(self, db):
        """The share of chronically absent attendance rows is 12-18 percent."""
        rate = db.execute("""
            SELECT 100.0 * SUM(CASE WHEN is_chronically_absent THEN 1 ELSE 0 END)
                   / COUNT(*)
            FROM attendance
        """).fetchone()[0]
        assert 12 <= rate <= 18, f"Chronic rate {rate:.1f}% outside expected range"

    def test_yoy_enrollment_growth(self, db):
        """Total enrollment strictly increases from each school year to the next."""
        rows = db.execute("""
            SELECT school_year, SUM(student_count) AS total
            FROM enrollment
            GROUP BY school_year
            ORDER BY school_year
        """).fetchall()
        # With fewer than two years the loop below would pass without
        # comparing anything.
        assert len(rows) >= 2, "Need at least two school years to compare"
        for prev, curr in zip(rows, rows[1:]):
            assert curr[1] > prev[1], \
                f"Enrollment did not grow: {prev} → {curr}"

    def test_school_grade_ranges(self, db):
        """Each school should only have grades in its range."""
        grade_ranges = {
            "Lincoln Elementary": (0, 5),
            "Washington Middle": (6, 8),
            "Jefferson High": (9, 12),
            "Roosevelt Academy": (0, 8),
            "Kennedy Prep": (6, 12),
        }
        for school, (lo, hi) in grade_ranges.items():
            grades = db.execute(
                "SELECT DISTINCT grade_level FROM enrollment WHERE school_name = ?",
                [school],
            ).fetchall()
            # A misspelled or missing school would otherwise pass silently.
            assert grades, f"{school} has no enrollment rows"
            for (g,) in grades:
                assert lo <= g <= hi, \
                    f"{school} has grade {g}, expected {lo}-{hi}"


# ── Timeout mechanism ──────────────────────────────────────────────────

class TestTimeout:
    """Timeout and result-cap settings of execute_safe().

    NOTE(review): no test here actually triggers a timeout on a slow query;
    only the fast path, the default constant, and the row cap are covered.
    """

    def test_fast_query_succeeds(self, db):
        """A trivial query completes within an explicit 5-second timeout."""
        _sql, df = execute_safe(db, "SELECT 1", timeout_sec=5)
        assert df.shape == (1, 1)

    def test_timeout_configuration(self):
        """Verify the default timeout constant QUERY_TIMEOUT_SEC is 10."""
        assert QUERY_TIMEOUT_SEC == 10

    def test_max_result_rows_caps_output(self, db):
        """MAX_RESULT_ROWS should limit results."""
        _sql, df = execute_safe(db, "SELECT * FROM attendance")
        assert len(df) <= MAX_RESULT_ROWS


# ── Empty results ──────────────────────────────────────────────────────

class TestEmptyResults:
    """Queries that return no rows should not error."""

    def test_empty_where_clause(self, db):
        """A fenced query matching nothing returns the bare SQL and zero rows."""
        sql, df = execute_safe(
            db,
            "```sql\nSELECT * FROM attendance WHERE school_name = 'Nonexistent'\n```",
        )
        assert sql == "SELECT * FROM attendance WHERE school_name = 'Nonexistent'"
        assert len(df) == 0

    def test_empty_aggregate(self, db):
        """COUNT(*) over an empty selection yields a single 0 cell."""
        _sql, df = execute_safe(
            db,
            "SELECT COUNT(*) FROM attendance WHERE 1 = 0",
        )
        assert df.iloc[0, 0] == 0


# ── extract_sql edge cases ─────────────────────────────────────────────

class TestExtractSqlEdgeCases:
    """extract_sql() should handle unusual but valid inputs."""

    def test_whitespace_only_after_block(self):
        """Trailing whitespace after the closing fence is ignored."""
        result = extract_sql("```sql\nSELECT 1\n``` \n ")
        assert result == "SELECT 1"

    def test_no_newline_after_fence(self):
        """A closing fence directly after the SQL still terminates the block."""
        result = extract_sql("```sql\nSELECT 1```")
        assert result == "SELECT 1"

    def test_multiple_code_blocks(self):
        """Only the first ```sql``` block should be extracted."""
        result = extract_sql("```sql\nSELECT 1\n```\n```sql\nSELECT 2\n```")
        assert result == "SELECT 1"

    def test_case_insensitive_sql_fence(self):
        """An upper-case ```SQL fence tag is accepted."""
        result = extract_sql("```SQL\nSELECT 1\n```")
        assert result == "SELECT 1"

    def test_trailing_semicolons_stripped(self):
        """Any run of trailing semicolons is removed from the extracted SQL."""
        result = extract_sql("```sql\nSELECT 1;;;;\n```")
        assert result == "SELECT 1"