ihute / tests /test_database.py
LeonceNsh's picture
Upload 4 files
cc54551 verified
"""Tests for database connection and query execution."""
import pytest
import pandas as pd
from ..database import Database, get_database, query
class TestDatabase:
"""Test database connection functionality."""
def test_database_initialization(self):
"""Test database can be initialized."""
db = Database()
assert db is not None
def test_in_memory_connection(self):
"""Test in-memory database connection works."""
db = Database(db_path=None)
conn = db.connection
assert conn is not None
def test_query_returns_dataframe(self):
"""Test that queries return pandas DataFrames."""
db = Database()
result = db.query("SELECT 1 as value")
assert isinstance(result, pd.DataFrame)
assert len(result) == 1
assert result['value'].iloc[0] == 1
def test_sample_data_creation(self):
"""Test sample data is created for demo mode."""
db = Database(db_path=None)
# Force sample data creation
_ = db.connection
# Check tables exist
result = db.query("SELECT * FROM marts.fct_corridor_flows LIMIT 1")
assert not result.empty
def test_global_database_singleton(self):
"""Test global database returns same instance."""
db1 = get_database()
db2 = get_database()
# Should be same connection pool
assert db1 is db2
def test_query_helper_function(self):
"""Test the query helper function works."""
result = query("SELECT 42 as answer")
assert isinstance(result, pd.DataFrame)
assert result['answer'].iloc[0] == 42
class TestDatabaseQueries:
"""Test specific query patterns used by components."""
@pytest.fixture
def db(self):
"""Create fresh database for each test."""
return Database(db_path=None)
def test_corridor_flows_query(self, db):
"""Test corridor flows query structure."""
_ = db.connection # Initialize sample data
result = db.query("""
SELECT corridor_id, hour_bucket, avg_speed_mph
FROM marts.fct_corridor_flows
LIMIT 10
""")
assert 'corridor_id' in result.columns
assert 'hour_bucket' in result.columns
assert 'avg_speed_mph' in result.columns
def test_incentive_events_query(self, db):
"""Test incentive events query structure."""
_ = db.connection
result = db.query("""
SELECT incentive_type, was_accepted, was_completed
FROM marts.fct_incentive_events
LIMIT 10
""")
assert 'incentive_type' in result.columns
assert 'was_accepted' in result.columns
def test_aggregation_query(self, db):
"""Test aggregation queries work correctly."""
_ = db.connection
result = db.query("""
SELECT
incentive_type,
count(*) as count,
avg(offered_amount) as avg_amount
FROM marts.fct_incentive_events
GROUP BY incentive_type
""")
assert 'count' in result.columns
assert 'avg_amount' in result.columns
assert all(result['count'] > 0)