# Source: LFED / tests / test_data_engine.py
# Commit 240383f (Kasualdad): "Day 1 expanded schema: 5 tables, 14B fine-tuned model"
# File size: 8.17 kB
"""test_data_engine.py — Data engine and DuckDB integration tests.
Covers:
- Schema introspection correctness
- Connection isolation (per-request independence)
- Seed data integrity (schools, years, grades, chronic rate)
- Query timeout mechanism
- Empty result handling
- extract_sql edge cases
"""
import time
import pytest
from data_engine import (
create_session,
get_connection,
seed_database,
get_schema_info,
extract_sql,
validate_sql,
execute_safe,
QueryTimeoutError,
MAX_RESULT_ROWS,
QUERY_TIMEOUT_SEC,
)
# ── Schema introspection ───────────────────────────────────────────────
class TestSchemaIntrospection:
    """Checks the table and column metadata reported by get_schema_info().

    The tests index the returned object by table name and unpack each entry
    as a 3-tuple whose first element is the column name.
    """

    def test_returns_correct_tables(self, db):
        """Both core tables appear as keys of the schema info."""
        schema = get_schema_info(db)
        for table in ("enrollment", "attendance"):
            assert table in schema

    def test_enrollment_columns(self, db):
        """The enrollment table exposes exactly its four expected columns."""
        schema = get_schema_info(db)
        expected = {"school_year", "school_name", "grade_level", "student_count"}
        actual = {column for column, _, _ in schema["enrollment"]}
        assert actual == expected

    def test_attendance_columns(self, db):
        """The attendance table exposes exactly its five expected columns."""
        schema = get_schema_info(db)
        expected = {
            "student_id",
            "school_name",
            "school_year",
            "absence_count",
            "is_chronically_absent",
        }
        actual = {column for column, _, _ in schema["attendance"]}
        assert actual == expected

    def test_no_extra_tables(self, db):
        """The schema holds the five Day 1 tables and nothing else."""
        schema = get_schema_info(db)
        # Day 1 expanded schema: enrollment, attendance, students,
        # discipline, grades.
        reported = set(schema.keys())
        assert reported == {
            "enrollment",
            "attendance",
            "students",
            "discipline",
            "grades",
        }
# ── Connection isolation ───────────────────────────────────────────────
class TestConnectionIsolation:
    """Each create_session() should return an independent database.

    Every connection opened here is closed in a ``finally`` block so a
    failing assertion cannot leak an open session into later tests.
    """

    # Lists the user tables of the connected database.
    _LIST_TABLES_SQL = (
        "SELECT table_name FROM information_schema.tables WHERE table_schema='main'"
    )

    def test_independent_connections(self):
        """Two sessions expose the same tables, and closing one leaves the other usable."""
        conn_b = create_session()
        try:
            conn_a = create_session()
            try:
                a_tables = conn_a.execute(self._LIST_TABLES_SQL).fetchall()
                b_tables = conn_b.execute(self._LIST_TABLES_SQL).fetchall()
                # The query has no ORDER BY, so compare order-insensitively.
                assert sorted(a_tables) == sorted(b_tables)
            finally:
                conn_a.close()
            # Closing one doesn't affect the other
            result = conn_b.execute("SELECT 1").fetchone()
            assert result == (1,)
        finally:
            conn_b.close()

    def test_data_is_independent(self):
        """Modifications in one connection don't affect another."""
        count_sql = "SELECT COUNT(*) FROM enrollment"
        conn_a = create_session()
        try:
            conn_b = create_session()
            try:
                # Count in both should match initially
                count_a = conn_a.execute(count_sql).fetchone()[0]
                count_b = conn_b.execute(count_sql).fetchone()[0]
                assert count_a == count_b

                # Mutate session A only; session B must keep its rows.
                conn_a.execute("DELETE FROM enrollment")
                assert conn_a.execute(count_sql).fetchone()[0] == 0
                assert conn_b.execute(count_sql).fetchone()[0] == count_b
            finally:
                conn_b.close()
        finally:
            conn_a.close()
# ── Seed data integrity ────────────────────────────────────────────────
class TestSeedDataIntegrity:
    """Seed data should match the expected shape and distributions.

    Every test takes the ``db`` fixture and only calls ``db.execute(...)``
    followed by ``fetchall()`` / ``fetchone()`` on the result.
    """

    def test_five_schools(self, db):
        """Enrollment covers exactly five distinct schools."""
        schools = db.execute(
            "SELECT DISTINCT school_name FROM enrollment ORDER BY school_name"
        ).fetchall()
        assert len(schools) == 5

    def test_four_school_years(self, db):
        """Enrollment spans four school years, 2021-2022 through 2024-2025."""
        years = db.execute(
            "SELECT DISTINCT school_year FROM enrollment ORDER BY school_year"
        ).fetchall()
        assert len(years) == 4
        assert years[0][0] == "2021-2022"
        assert years[-1][0] == "2024-2025"

    def test_thirteen_grade_levels(self, db):
        """Enrollment covers thirteen distinct grade levels."""
        grades = db.execute(
            "SELECT DISTINCT grade_level FROM enrollment ORDER BY grade_level"
        ).fetchall()
        assert len(grades) == 13  # K (0) through 12

    def test_chronic_absenteeism_rate(self, db):
        """The share of chronically absent attendance rows is between 12% and 18%."""
        rate = db.execute("""
            SELECT 100.0 * SUM(CASE WHEN is_chronically_absent THEN 1 ELSE 0 END) / COUNT(*)
            FROM attendance
        """).fetchone()[0]
        assert 12 <= rate <= 18, f"Chronic rate {rate:.1f}% outside expected range"

    def test_yoy_enrollment_growth(self, db):
        """Total enrollment strictly increases from each school year to the next."""
        rows = db.execute("""
            SELECT school_year, SUM(student_count) AS total
            FROM enrollment GROUP BY school_year ORDER BY school_year
        """).fetchall()
        # Walk consecutive (previous, current) year pairs.
        for previous, current in zip(rows, rows[1:]):
            assert current[1] > previous[1], \
                f"Enrollment did not grow: {previous} -> {current}"

    def test_school_grade_ranges(self, db):
        """Each school should only have grades in its range."""
        grade_ranges = {
            "Lincoln Elementary": (0, 5),
            "Washington Middle": (6, 8),
            "Jefferson High": (9, 12),
            "Roosevelt Academy": (0, 8),
            "Kennedy Prep": (6, 12),
        }
        for school, (lo, hi) in grade_ranges.items():
            grades = db.execute(
                "SELECT DISTINCT grade_level FROM enrollment WHERE school_name = ?",
                [school],
            ).fetchall()
            # Without this guard a school name that matches no rows would
            # make the range check below pass vacuously.
            assert grades, f"{school} has no enrollment rows"
            for (g,) in grades:
                assert lo <= g <= hi, \
                    f"{school} has grade {g}, expected {lo}-{hi}"
# ── Timeout mechanism ──────────────────────────────────────────────────
class TestTimeout:
    """Checks on execute_safe()'s timeout setting and result-row cap.

    NOTE(review): no test here drives a query past its timeout, so the
    timeout path itself is not exercised by this class.
    """

    def test_fast_query_succeeds(self, db):
        """A trivial query run with timeout_sec=5 yields a 1x1 result."""
        _sql, frame = execute_safe(db, "SELECT 1", timeout_sec=5)
        assert frame.shape == (1, 1)

    def test_timeout_configuration(self):
        """The module-level QUERY_TIMEOUT_SEC constant is pinned at 10."""
        expected_default = 10
        assert QUERY_TIMEOUT_SEC == expected_default

    def test_max_result_rows_caps_output(self, db):
        """A full-table select returns at most MAX_RESULT_ROWS rows."""
        _sql, frame = execute_safe(db, "SELECT * FROM attendance")
        row_count = len(frame)
        assert row_count <= MAX_RESULT_ROWS
# ── Empty results ──────────────────────────────────────────────────────
class TestEmptyResults:
    """Zero-row and zero-count queries must come back cleanly from execute_safe()."""

    def test_empty_where_clause(self, db):
        """A fenced query matching no rows returns the bare SQL and an empty result."""
        fenced_query = (
            "```sql\nSELECT * FROM attendance WHERE school_name = 'Nonexistent'\n```"
        )
        returned_sql, frame = execute_safe(db, fenced_query)
        # The code fence is gone from the SQL text handed back.
        assert returned_sql == "SELECT * FROM attendance WHERE school_name = 'Nonexistent'"
        assert len(frame) == 0

    def test_empty_aggregate(self, db):
        """COUNT(*) over an always-false filter yields a single cell holding 0."""
        _sql, frame = execute_safe(db, "SELECT COUNT(*) FROM attendance WHERE 1 = 0")
        first_cell = frame.iloc[0, 0]
        assert first_cell == 0
# ── extract_sql edge cases ─────────────────────────────────────────────
class TestExtractSqlEdgeCases:
    """Unusual but valid fenced inputs that extract_sql() must reduce to bare SQL."""

    def test_whitespace_only_after_block(self):
        """Whitespace trailing the closing fence is ignored."""
        raw = "```sql\nSELECT 1\n``` \n "
        assert extract_sql(raw) == "SELECT 1"

    def test_no_newline_after_fence(self):
        """A closing fence directly after the statement is still recognised."""
        raw = "```sql\nSELECT 1```"
        assert extract_sql(raw) == "SELECT 1"

    def test_multiple_code_blocks(self):
        """Only the first ```sql``` block should be extracted."""
        raw = "```sql\nSELECT 1\n```\n```sql\nSELECT 2\n```"
        assert extract_sql(raw) == "SELECT 1"

    def test_case_insensitive_sql_fence(self):
        """An upper-case SQL language tag on the fence is accepted."""
        raw = "```SQL\nSELECT 1\n```"
        assert extract_sql(raw) == "SELECT 1"

    def test_trailing_semicolons_stripped(self):
        """A run of trailing semicolons is removed from the extracted SQL."""
        raw = "```sql\nSELECT 1;;;;\n```"
        assert extract_sql(raw) == "SELECT 1"