File size: 3,244 Bytes
cc54551 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
"""Tests for database connection and query execution."""
import pytest
import pandas as pd
from ..database import Database, get_database, query
class TestDatabase:
"""Test database connection functionality."""
def test_database_initialization(self):
"""Test database can be initialized."""
db = Database()
assert db is not None
def test_in_memory_connection(self):
"""Test in-memory database connection works."""
db = Database(db_path=None)
conn = db.connection
assert conn is not None
def test_query_returns_dataframe(self):
"""Test that queries return pandas DataFrames."""
db = Database()
result = db.query("SELECT 1 as value")
assert isinstance(result, pd.DataFrame)
assert len(result) == 1
assert result['value'].iloc[0] == 1
def test_sample_data_creation(self):
"""Test sample data is created for demo mode."""
db = Database(db_path=None)
# Force sample data creation
_ = db.connection
# Check tables exist
result = db.query("SELECT * FROM marts.fct_corridor_flows LIMIT 1")
assert not result.empty
def test_global_database_singleton(self):
"""Test global database returns same instance."""
db1 = get_database()
db2 = get_database()
# Should be same connection pool
assert db1 is db2
def test_query_helper_function(self):
"""Test the query helper function works."""
result = query("SELECT 42 as answer")
assert isinstance(result, pd.DataFrame)
assert result['answer'].iloc[0] == 42
class TestDatabaseQueries:
"""Test specific query patterns used by components."""
@pytest.fixture
def db(self):
"""Create fresh database for each test."""
return Database(db_path=None)
def test_corridor_flows_query(self, db):
"""Test corridor flows query structure."""
_ = db.connection # Initialize sample data
result = db.query("""
SELECT corridor_id, hour_bucket, avg_speed_mph
FROM marts.fct_corridor_flows
LIMIT 10
""")
assert 'corridor_id' in result.columns
assert 'hour_bucket' in result.columns
assert 'avg_speed_mph' in result.columns
def test_incentive_events_query(self, db):
"""Test incentive events query structure."""
_ = db.connection
result = db.query("""
SELECT incentive_type, was_accepted, was_completed
FROM marts.fct_incentive_events
LIMIT 10
""")
assert 'incentive_type' in result.columns
assert 'was_accepted' in result.columns
def test_aggregation_query(self, db):
"""Test aggregation queries work correctly."""
_ = db.connection
result = db.query("""
SELECT
incentive_type,
count(*) as count,
avg(offered_amount) as avg_amount
FROM marts.fct_incentive_events
GROUP BY incentive_type
""")
assert 'count' in result.columns
assert 'avg_amount' in result.columns
assert all(result['count'] > 0)
|