File size: 8,174 Bytes
17674c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240383f
 
 
17674c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""test_data_engine.py — Data engine and DuckDB integration tests.

Covers:
  - Schema introspection correctness
  - Connection isolation (per-request independence)
  - Seed data integrity (schools, years, grades, chronic rate)
  - Query timeout mechanism
  - Empty result handling
  - extract_sql edge cases
"""

import time
import pytest
from data_engine import (
    create_session,
    get_connection,
    seed_database,
    get_schema_info,
    extract_sql,
    validate_sql,
    execute_safe,
    QueryTimeoutError,
    MAX_RESULT_ROWS,
    QUERY_TIMEOUT_SEC,
)


# ── Schema introspection ───────────────────────────────────────────────

class TestSchemaIntrospection:
    """Checks the table/column metadata reported by get_schema_info()."""

    @staticmethod
    def _column_names(db, table):
        """Return the set of column names get_schema_info() lists for *table*.

        Each entry in the per-table list is unpacked as a 3-tuple whose first
        element is the column name.
        """
        return {column for column, _, _ in get_schema_info(db)[table]}

    def test_returns_correct_tables(self, db):
        tables = get_schema_info(db)
        for required in ("enrollment", "attendance"):
            assert required in tables

    def test_enrollment_columns(self, db):
        expected = {"school_year", "school_name", "grade_level", "student_count"}
        assert self._column_names(db, "enrollment") == expected

    def test_attendance_columns(self, db):
        expected = {
            "student_id",
            "school_name",
            "school_year",
            "absence_count",
            "is_chronically_absent",
        }
        assert self._column_names(db, "attendance") == expected

    def test_no_extra_tables(self, db):
        # Day 1 expanded schema: exactly these 5 tables, nothing more.
        expected_tables = ["attendance", "discipline", "enrollment", "grades", "students"]
        assert sorted(get_schema_info(db).keys()) == sorted(expected_tables)


# ── Connection isolation ───────────────────────────────────────────────

class TestConnectionIsolation:
    """Each create_session() call should return an independent database."""

    # ORDER BY makes the table-list comparison independent of catalog order.
    _TABLES_SQL = (
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema='main' ORDER BY table_name"
    )
    _ENROLLMENT_COUNT_SQL = "SELECT COUNT(*) FROM enrollment"

    def test_independent_connections(self):
        """Two sessions expose the same tables, and closing one spares the other."""
        conn_a = create_session()
        conn_b = create_session()
        try:
            # Both have the same tables
            a_tables = conn_a.execute(self._TABLES_SQL).fetchall()
            b_tables = conn_b.execute(self._TABLES_SQL).fetchall()
            assert a_tables == b_tables

            # Closing one doesn't affect the other
            conn_a.close()
            result = conn_b.execute("SELECT 1").fetchone()
            assert result == (1,)
        finally:
            # Release both sessions even when an assertion fails.
            # NOTE(review): conn_a may already be closed here; this relies on
            # close() being safe to call twice on the session object.
            conn_a.close()
            conn_b.close()

    def test_data_is_independent(self):
        """Modifications in one connection don't affect another."""
        conn_a = create_session()
        conn_b = create_session()
        try:
            # Count in both should match initially
            count_a = conn_a.execute(self._ENROLLMENT_COUNT_SQL).fetchone()[0]
            count_b = conn_b.execute(self._ENROLLMENT_COUNT_SQL).fetchone()[0]
            assert count_a == count_b
            # With no rows the DELETE below would prove nothing.
            assert count_b > 0, "enrollment is empty; isolation check is vacuous"

            # Mutate session A only; session B must keep its seeded rows.
            conn_a.execute("DELETE FROM enrollment")
            assert conn_a.execute(self._ENROLLMENT_COUNT_SQL).fetchone()[0] == 0
            assert conn_b.execute(self._ENROLLMENT_COUNT_SQL).fetchone()[0] == count_b
        finally:
            conn_a.close()
            conn_b.close()


# ── Seed data integrity ────────────────────────────────────────────────

class TestSeedDataIntegrity:
    """Seed data should match the expected shape and distributions."""

    # Expected school -> inclusive (lowest, highest) grade served; K is grade 0.
    SCHOOL_GRADE_RANGES = {
        "Lincoln Elementary": (0, 5),
        "Washington Middle": (6, 8),
        "Jefferson High": (9, 12),
        "Roosevelt Academy": (0, 8),
        "Kennedy Prep": (6, 12),
    }

    def test_five_schools(self, db):
        schools = db.execute(
            "SELECT DISTINCT school_name FROM enrollment ORDER BY school_name"
        ).fetchall()
        assert len(schools) == 5
        # Pin the names too, so a renamed school is caught here rather than
        # silently skipped by the per-school grade-range check.
        assert {name for (name,) in schools} == set(self.SCHOOL_GRADE_RANGES)

    def test_four_school_years(self, db):
        years = db.execute(
            "SELECT DISTINCT school_year FROM enrollment ORDER BY school_year"
        ).fetchall()
        assert len(years) == 4
        assert years[0][0] == "2021-2022"
        assert years[-1][0] == "2024-2025"

    def test_thirteen_grade_levels(self, db):
        grades = db.execute(
            "SELECT DISTINCT grade_level FROM enrollment ORDER BY grade_level"
        ).fetchall()
        assert len(grades) == 13  # K (0) through 12
        # The count alone would accept e.g. 1..13; pin the actual values.
        assert [g for (g,) in grades] == list(range(13))

    def test_chronic_absenteeism_rate(self, db):
        chronic, total = db.execute("""
            SELECT SUM(CASE WHEN is_chronically_absent THEN 1 ELSE 0 END), COUNT(*)
            FROM attendance
        """).fetchone()
        # Guard the division: an empty table would otherwise produce a
        # confusing error instead of a clear failure.
        assert total > 0, "attendance table is empty"
        rate = 100.0 * chronic / total
        assert 12 <= rate <= 18, f"Chronic rate {rate:.1f}% outside expected range"

    def test_yoy_enrollment_growth(self, db):
        rows = db.execute("""
            SELECT school_year, SUM(student_count) AS total
            FROM enrollment GROUP BY school_year ORDER BY school_year
        """).fetchall()
        # With fewer than two years the growth loop would assert nothing.
        assert len(rows) >= 2, f"Expected multiple school years, got {rows}"
        for prev, curr in zip(rows, rows[1:]):
            assert curr[1] > prev[1], \
                f"Enrollment did not grow: {prev} -> {curr}"

    def test_school_grade_ranges(self, db):
        """Each school should only have grades in its range."""
        for school, (lo, hi) in self.SCHOOL_GRADE_RANGES.items():
            grades = db.execute(
                "SELECT DISTINCT grade_level FROM enrollment WHERE school_name = ?",
                [school],
            ).fetchall()
            # A missing/renamed school returns no rows; fail instead of
            # passing vacuously.
            assert grades, f"{school} has no enrollment rows"
            for (g,) in grades:
                assert lo <= g <= hi, \
                    f"{school} has grade {g}, expected {lo}-{hi}"


# ── Timeout mechanism ──────────────────────────────────────────────────

class TestTimeout:
    """Timeout configuration and result-size limits of execute_safe().

    NOTE(review): no test here forces a query past its timeout, so the
    QueryTimeoutError path is not exercised; add a deliberately slow query
    once one is known to be accepted by execute_safe().
    """

    def test_fast_query_succeeds(self, db):
        """A trivial query completes within a short explicit timeout."""
        _, df = execute_safe(db, "SELECT 1", timeout_sec=5)
        assert df.shape == (1, 1)
        assert df.iloc[0, 0] == 1

    def test_timeout_configuration(self):
        """The module-level default timeout is 10 seconds."""
        assert QUERY_TIMEOUT_SEC == 10

    def test_max_result_rows_caps_output(self, db):
        """Results are capped at MAX_RESULT_ROWS but never truncated below it.

        Comparing against min(total, cap) keeps the test meaningful whether or
        not the attendance table is larger than the cap.
        """
        total = db.execute("SELECT COUNT(*) FROM attendance").fetchone()[0]
        _, df = execute_safe(db, "SELECT * FROM attendance")
        assert len(df) == min(total, MAX_RESULT_ROWS)


# ── Empty results ──────────────────────────────────────────────────────

class TestEmptyResults:
    """execute_safe() returns empty/zero results without raising when nothing matches."""

    def test_empty_where_clause(self, db):
        # The fenced input also checks that the ```sql fence is stripped
        # before execution.
        fenced_input = "```sql\nSELECT * FROM attendance WHERE school_name = 'Nonexistent'\n```"
        extracted, frame = execute_safe(db, fenced_input)
        assert extracted == "SELECT * FROM attendance WHERE school_name = 'Nonexistent'"
        assert len(frame) == 0

    def test_empty_aggregate(self, db):
        # COUNT(*) over zero rows still yields a single row holding 0.
        _, frame = execute_safe(db, "SELECT COUNT(*) FROM attendance WHERE 1 = 0")
        assert frame.iloc[0, 0] == 0


# ── extract_sql edge cases ─────────────────────────────────────────────

class TestExtractSqlEdgeCases:
    """Unusual-but-valid model outputs that extract_sql() must still handle."""

    @staticmethod
    def _assert_extracts(raw, expected="SELECT 1"):
        """Assert that extract_sql(raw) yields exactly *expected*."""
        assert extract_sql(raw) == expected

    def test_whitespace_only_after_block(self):
        self._assert_extracts("```sql\nSELECT 1\n```   \n  ")

    def test_no_newline_after_fence(self):
        self._assert_extracts("```sql\nSELECT 1```")

    def test_multiple_code_blocks(self):
        """Only the first ```sql``` block should be extracted."""
        self._assert_extracts("```sql\nSELECT 1\n```\n```sql\nSELECT 2\n```")

    def test_case_insensitive_sql_fence(self):
        self._assert_extracts("```SQL\nSELECT 1\n```")

    def test_trailing_semicolons_stripped(self):
        self._assert_extracts("```sql\nSELECT 1;;;;\n```")