# FairRelay / brain / tests / conftest.py
# Uploaded by MouleeswaranM via huggingface_hub (commit fcf8749).
import pytest
import asyncio
import os
from typing import AsyncGenerator, Generator
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.pool import StaticPool
from httpx import AsyncClient, ASGITransport
from app.database import Base, get_db
from app.main import app
from app.models import Driver, Route, DriverStatsDaily, FairnessConfig
from app.models.driver import VehicleType, PreferredLanguage
from tests.fixtures.test_data import generate_drivers, generate_routes, generate_allocation_request
from datetime import date, timedelta
import numpy as np
# Database URL used by the whole test suite: the TEST_DATABASE_URL environment
# variable when it is set, otherwise an in-memory SQLite database (aiosqlite).
TEST_DB_URL = os.environ.get("TEST_DATABASE_URL", "sqlite+aiosqlite:///:memory:")
@pytest.fixture(scope="session")
async def test_engine():
    """Session-scoped async engine with the schema created once up front.

    All tables in ``Base.metadata`` are created before the engine is handed
    out and dropped again when the test session finishes. The engine is
    always disposed, even if schema creation or teardown fails.

    For SQLite URLs the engine:
    - uses ``StaticPool``, so every checkout reuses a single DBAPI connection
      (needed for ``:memory:`` SQLite, where each new connection would
      otherwise see a fresh, empty database);
    - disables the ``check_same_thread`` guard.

    Yields:
        The ``AsyncEngine`` bound to ``TEST_DB_URL``.
    """
    # Decide on the URL scheme (the part before the first ':') rather than a
    # substring anywhere in the URL. Otherwise a non-SQLite URL whose
    # credentials or database name happen to contain "sqlite" would receive
    # SQLite-only connect args.
    is_sqlite = TEST_DB_URL.split(":", 1)[0].startswith("sqlite")
    engine_kwargs = {}
    if is_sqlite:
        engine_kwargs["connect_args"] = {"check_same_thread": False}
        engine_kwargs["poolclass"] = StaticPool
    # Non-SQLite URLs get the dialect's default pool and no extra connect args.
    engine = create_async_engine(TEST_DB_URL, **engine_kwargs)
    try:
        # Create tables
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
        # Drop tables
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
    finally:
        # Release pooled connections no matter how setup/teardown ended.
        await engine.dispose()
@pytest.fixture
async def db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
    """Yield a per-test ``AsyncSession`` wrapped in an outer transaction.

    The fixture opens a dedicated connection, begins a transaction on it, and
    binds a session (``expire_on_commit=False``) to that connection.

    On teardown, in this order:
    1. the session is closed;
    2. the outer transaction is rolled back;
    3. the connection is closed.

    NOTE(review): isolation relies on SQLAlchemy 2.0's default
    ``join_transaction_mode``, under which ``session.commit()`` does not
    commit the outer transaction. An explicit ``session.rollback()`` inside
    a test can still end that transaction early — confirm if tests depend
    on it.
    """
    async with test_engine.connect() as connection:
        outer_tx = await connection.begin()
        async with AsyncSession(bind=connection, expire_on_commit=False) as session:
            yield session
        await outer_tx.rollback()
@pytest.fixture
async def client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """HTTP client that calls ``app`` in-process using the test DB session.

    ``get_db`` is overridden so requests made through this client receive
    the same ``db_session`` object the test itself uses.

    Teardown restores ``app.dependency_overrides`` to exactly what it was
    before the fixture ran. Overrides added during the test (including ours)
    are removed, but overrides that existed beforehand are kept. A blanket
    ``clear()`` would have discarded those too.

    Yields:
        An ``httpx.AsyncClient`` routed through ``ASGITransport`` to ``app``
        with base URL ``http://test``.
    """
    async def override_get_db():
        yield db_session

    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
            yield ac
    finally:
        # Roll the override mapping back even if closing the client raises.
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
@pytest.fixture
async def sample_drivers(db_session):
    """Insert 50 generated drivers (``ev_ratio=0.2``) and return them.

    Each record from ``generate_drivers`` is mapped onto a ``Driver``:
    - its ``is_ev`` flag selects ``VehicleType.EV`` or ``VehicleType.ICE``;
    - its ``preferred_language`` string is converted to ``PreferredLanguage``.

    All drivers are added to ``db_session`` and committed together.

    Returns:
        List of ``Driver`` objects in the order the generator produced them.
    """
    drivers = [
        Driver(
            external_id=spec["id"],
            name=spec["name"],
            vehicle_capacity_kg=spec["vehicle_capacity_kg"],
            preferred_language=PreferredLanguage(spec["preferred_language"]),
            vehicle_type=VehicleType.EV if spec["is_ev"] else VehicleType.ICE,
            battery_range_km=spec["battery_range_km"],
            charging_time_minutes=spec["charging_time_minutes"],
        )
        for spec in generate_drivers(count=50, ev_ratio=0.2)
    ]
    db_session.add_all(drivers)
    await db_session.commit()
    return drivers
@pytest.fixture
async def sample_routes(db_session):
    """Insert 50 generated routes dated today and return them.

    This fixture is for tests that need routes already stored in the
    database. The allocation flow creates its own routes from packages, so
    allocation-endpoint tests mainly need ``allocation_request`` instead.

    NOTE(review): ``num_packages`` is approximated as two per stop, and
    ``route_difficulty_score`` is taken from the generator's
    ``parking_difficulty`` field. Confirm this mapping is intended.

    Returns:
        List of ``Route`` objects in generation order.
    """
    routes = [
        Route(
            date=date.today(),
            cluster_id=spec["cluster_id"],
            total_weight_kg=spec["total_weight_kg"],
            num_packages=spec["stops"] * 2,
            num_stops=spec["stops"],
            route_difficulty_score=spec["parking_difficulty"],
            estimated_time_minutes=spec["estimated_time_minutes"],
            total_distance_km=spec["total_distance_km"],
        )
        for spec in generate_routes(count=50)
    ]
    db_session.add_all(routes)
    await db_session.commit()
    return routes
@pytest.fixture
async def allocation_request(sample_drivers):
    """Build an allocation request payload for the persisted sample drivers.

    Driver entries are taken from ``sample_drivers``, so their IDs match the
    ``external_id`` values stored in the test database. Route data is freshly
    generated with a count equal to the number of drivers.

    Returns:
        Whatever ``generate_allocation_request`` builds from the driver
        dicts and the generated route data.
    """
    driver_payloads = [
        {
            "id": drv.external_id,
            "name": drv.name,
            "vehicle_capacity_kg": drv.vehicle_capacity_kg,
            "preferred_language": drv.preferred_language.value,
            "is_ev": drv.vehicle_type == VehicleType.EV,
        }
        for drv in sample_drivers
    ]
    fresh_routes = generate_routes(count=len(driver_payloads))
    return generate_allocation_request(driver_payloads, fresh_routes)
@pytest.fixture
async def active_config(db_session):
    """Insert and return an active ``FairnessConfig`` with fixed test values.

    The config has ``is_active=True``, recovery mode enabled, and the
    thresholds and weights listed below. It is committed through
    ``db_session``.
    """
    settings = {
        "is_active": True,
        "gini_threshold": 0.35,
        "stddev_threshold": 25.0,
        "max_gap_threshold": 25.0,
        "recovery_mode_enabled": True,
        "ev_safety_margin_pct": 10.0,
        "ev_charging_penalty_weight": 0.3,
        "recovery_penalty_weight": 3.0,
    }
    config = FairnessConfig(**settings)
    db_session.add(config)
    await db_session.commit()
    return config