"""test_data_engine.py — Data engine and DuckDB integration tests.
Covers:
- Schema introspection correctness
- Connection isolation (per-request independence)
- Seed data integrity (schools, years, grades, chronic rate)
- Query timeout configuration and result-row cap (the timeout firing itself is not yet exercised)
- Empty result handling
- extract_sql edge cases
"""
import time
import pytest
from data_engine import (
create_session,
get_connection,
seed_database,
get_schema_info,
extract_sql,
validate_sql,
execute_safe,
QueryTimeoutError,
MAX_RESULT_ROWS,
QUERY_TIMEOUT_SEC,
)
# ── Schema introspection ───────────────────────────────────────────────
class TestSchemaIntrospection:
    """Table and column metadata reported by get_schema_info()."""

    @staticmethod
    def _column_names(conn, table):
        """Return the set of column names get_schema_info() lists for *table*."""
        # Entries are unpacked as 3-tuples whose first item is the column name.
        return {col for col, _kind, _extra in get_schema_info(conn)[table]}

    def test_returns_correct_tables(self, db):
        """The core enrollment and attendance tables are reported."""
        schema = get_schema_info(db)
        for table in ("enrollment", "attendance"):
            assert table in schema

    def test_enrollment_columns(self, db):
        """enrollment exposes exactly its four columns."""
        assert self._column_names(db, "enrollment") == {
            "school_year",
            "school_name",
            "grade_level",
            "student_count",
        }

    def test_attendance_columns(self, db):
        """attendance exposes exactly its five columns."""
        assert self._column_names(db, "attendance") == {
            "student_id",
            "school_name",
            "school_year",
            "absence_count",
            "is_chronically_absent",
        }

    def test_no_extra_tables(self, db):
        """Only the five tables of the expanded (Day 1) schema are present."""
        assert set(get_schema_info(db)) == {
            "enrollment",
            "attendance",
            "students",
            "discipline",
            "grades",
        }
# ── Connection isolation ───────────────────────────────────────────────
class TestConnectionIsolation:
    """Each create_session() should return an independent database."""

    def test_independent_connections(self):
        """Two sessions expose the same tables; closing one leaves the other usable."""
        # ORDER BY makes the row-by-row comparison independent of catalog order.
        tables_sql = (
            "SELECT table_name FROM information_schema.tables "
            "WHERE table_schema='main' ORDER BY table_name"
        )
        conn_a = create_session()
        conn_b = create_session()
        try:
            try:
                a_tables = conn_a.execute(tables_sql).fetchall()
                b_tables = conn_b.execute(tables_sql).fetchall()
            finally:
                # Closing conn_a is both cleanup and the action under test below.
                conn_a.close()
            # Two empty catalogs would compare equal, so require some tables.
            assert a_tables, "session has no tables in schema 'main'"
            assert a_tables == b_tables
            # conn_a is now closed; conn_b must be unaffected.
            assert conn_b.execute("SELECT 1").fetchone() == (1,)
        finally:
            conn_b.close()

    def test_data_is_independent(self):
        """Modifications in one connection don't affect another."""
        count_sql = "SELECT COUNT(*) FROM enrollment"
        conn_a = create_session()
        conn_b = create_session()
        try:
            baseline = conn_b.execute(count_sql).fetchone()[0]
            # Both sessions start from the same seed...
            assert conn_a.execute(count_sql).fetchone()[0] == baseline
            # ...which must be non-empty, or the delete below proves nothing.
            assert baseline > 0
            conn_a.execute("DELETE FROM enrollment")
            assert conn_a.execute(count_sql).fetchone()[0] == 0
            # The write made through conn_a must not be visible via conn_b.
            assert conn_b.execute(count_sql).fetchone()[0] == baseline
        finally:
            conn_a.close()
            conn_b.close()
# ── Seed data integrity ────────────────────────────────────────────────
class TestSeedDataIntegrity:
    """Seed data should match the expected shape and distributions."""

    def test_five_schools(self, db):
        schools = db.execute(
            "SELECT DISTINCT school_name FROM enrollment ORDER BY school_name"
        ).fetchall()
        assert len(schools) == 5

    def test_four_school_years(self, db):
        years = db.execute(
            "SELECT DISTINCT school_year FROM enrollment ORDER BY school_year"
        ).fetchall()
        assert len(years) == 4
        assert years[0][0] == "2021-2022"
        assert years[-1][0] == "2024-2025"

    def test_thirteen_grade_levels(self, db):
        """Grades run K (stored as 0) through 12 with no gaps."""
        grades = db.execute(
            "SELECT DISTINCT grade_level FROM enrollment ORDER BY grade_level"
        ).fetchall()
        # Compare exact values, not just the count, so that gaps or
        # out-of-range grades that still total 13 are caught.
        assert [g for (g,) in grades] == list(range(13))

    def test_chronic_absenteeism_rate(self, db):
        rate = db.execute("""
            SELECT 100.0 * SUM(CASE WHEN is_chronically_absent THEN 1 ELSE 0 END) / COUNT(*)
            FROM attendance
        """).fetchone()[0]
        # SUM over zero rows is NULL; fail clearly instead of with a TypeError.
        assert rate is not None, "attendance table is empty"
        assert 12 <= rate <= 18, f"Chronic rate {rate:.1f}% outside expected range"

    def test_yoy_enrollment_growth(self, db):
        rows = db.execute("""
            SELECT school_year, SUM(student_count) AS total
            FROM enrollment GROUP BY school_year ORDER BY school_year
        """).fetchall()
        # With fewer than two years there is nothing to compare.
        assert len(rows) >= 2, "need at least two school years to measure growth"
        for prev, curr in zip(rows, rows[1:]):
            assert curr[1] > prev[1], \
                f"Enrollment did not grow: {prev} → {curr}"

    def test_school_grade_ranges(self, db):
        """Each school should only have grades in its range."""
        grade_ranges = {
            "Lincoln Elementary": (0, 5),
            "Washington Middle": (6, 8),
            "Jefferson High": (9, 12),
            "Roosevelt Academy": (0, 8),
            "Kennedy Prep": (6, 12),
        }
        for school, (lo, hi) in grade_ranges.items():
            grades = db.execute(
                "SELECT DISTINCT grade_level FROM enrollment WHERE school_name = ?",
                [school],
            ).fetchall()
            # A school with no rows (e.g. a renamed or misspelled seed entry)
            # would otherwise pass the range check vacuously.
            assert grades, f"{school} has no enrollment rows"
            for (g,) in grades:
                assert lo <= g <= hi, \
                    f"{school} has grade {g}, expected {lo}-{hi}"
# ── Timeout mechanism ──────────────────────────────────────────────────
class TestTimeout:
    """Query timeout configuration and the result-row cap.

    NOTE(review): no test here drives a query past its timeout, so the
    QueryTimeoutError path itself is not exercised. TODO: add one once a
    reliably slow query that validate_sql() accepts is available.
    """

    def test_fast_query_succeeds(self, db):
        """A trivial query completes under an explicitly passed timeout_sec."""
        _, df = execute_safe(db, "SELECT 1", timeout_sec=5)
        assert df.shape == (1, 1)

    def test_timeout_configuration(self):
        """The default query timeout constant is 10 seconds."""
        assert QUERY_TIMEOUT_SEC == 10

    def test_max_result_rows_caps_output(self, db):
        """execute_safe() returns at most MAX_RESULT_ROWS rows."""
        _, df = execute_safe(db, "SELECT * FROM attendance")
        assert len(df) <= MAX_RESULT_ROWS
# ── Empty results ──────────────────────────────────────────────────────
class TestEmptyResults:
    """Zero-row results must come back cleanly rather than raising."""

    def test_empty_where_clause(self, db):
        """A fenced query matching nothing yields the bare SQL and zero rows."""
        fenced = "```sql\nSELECT * FROM attendance WHERE school_name = 'Nonexistent'\n```"
        extracted, frame = execute_safe(db, fenced)
        assert extracted == "SELECT * FROM attendance WHERE school_name = 'Nonexistent'"
        assert frame.shape[0] == 0

    def test_empty_aggregate(self, db):
        """COUNT(*) over an always-false filter still returns a single 0."""
        _, frame = execute_safe(db, "SELECT COUNT(*) FROM attendance WHERE 1 = 0")
        count = frame.iloc[0, 0]
        assert count == 0
# ── extract_sql edge cases ─────────────────────────────────────────────
class TestExtractSqlEdgeCases:
    """Unusual-but-valid inputs that extract_sql() must still unwrap."""

    def test_whitespace_only_after_block(self):
        """Blank space trailing the closing fence is ignored."""
        assert extract_sql("```sql\nSELECT 1\n``` \n ") == "SELECT 1"

    def test_no_newline_after_fence(self):
        """A closing fence glued to the statement is still recognised."""
        assert extract_sql("```sql\nSELECT 1```") == "SELECT 1"

    def test_multiple_code_blocks(self):
        """When several sql fences appear, the first one wins."""
        two_blocks = "```sql\nSELECT 1\n```\n```sql\nSELECT 2\n```"
        assert extract_sql(two_blocks) == "SELECT 1"

    def test_case_insensitive_sql_fence(self):
        """An upper-case SQL language tag is accepted."""
        assert extract_sql("```SQL\nSELECT 1\n```") == "SELECT 1"

    def test_trailing_semicolons_stripped(self):
        """Any run of trailing semicolons is removed."""
        assert extract_sql("```sql\nSELECT 1;;;;\n```") == "SELECT 1"