import os
import pytest
import asyncio
from typing import AsyncGenerator, Generator
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from dotenv import load_dotenv
from pathlib import Path

# Load test environment variables. The ``app`` imports below deliberately come
# after this call, presumably so that settings read at import time pick up the
# values from test.env (confirm against app.core.config).
test_env_path = Path(__file__).parent / "test.env"
load_dotenv(test_env_path)

from app.db.database import Base, get_db  # noqa: E402
from app.main import app  # noqa: E402
from app.core.config import settings  # noqa: E402

# Use an in-memory SQLite database for testing. StaticPool hands out a single
# shared connection, so every session created below sees the same in-memory
# database (each new ``:memory:`` connection would otherwise start empty).
SQLALCHEMY_DATABASE_URL = "sqlite+aiosqlite:///:memory:"

engine = create_async_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)

TestingSessionLocal = sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide one event loop shared by the whole test session.

    A fresh loop is always created rather than calling
    ``asyncio.get_event_loop()``, which is deprecated when no loop is running
    and could return (and then close) the interpreter's default loop.
    """
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="session")
async def test_app():
    """Create the database schema once per session, yield the app, then tear down.

    Also creates ``./test_uploads/documents`` and ``./test_uploads/images``
    relative to the current working directory; these directories are not
    removed on teardown.

    Yields:
        The FastAPI ``app`` object imported from ``app.main``.
    """
    # Create test upload directories
    os.makedirs("./test_uploads/documents", exist_ok=True)
    os.makedirs("./test_uploads/images", exist_ok=True)

    # Create tables
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)  # Ensure clean state
        await conn.run_sync(Base.metadata.create_all)

    yield app

    # Cleanup: drop the schema, then close the connection held by the engine.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
    await engine.dispose()


@pytest.fixture
async def db_session(test_app) -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session for direct use in a test.

    Depends on ``test_app`` so the schema always exists, even when no test
    using ``client`` has run yet in this session.

    NOTE(review): this session is distinct from the sessions the ``client``
    fixture's ``get_db`` override creates; because StaticPool shares one
    connection, the rollback below may not undo data committed through HTTP
    requests. Confirm this isolation is what the tests expect.
    """
    async with TestingSessionLocal() as session:
        yield session
        # Rollback any changes made in the test
        await session.rollback()


@pytest.fixture
async def client(test_app) -> AsyncGenerator[AsyncClient, None]:
    """Yield an async HTTP client bound to the app, with ``get_db`` overridden.

    The override opens a new ``TestingSessionLocal`` session each time the
    dependency is resolved. Requests are dispatched in-process via
    ``ASGITransport``; the ``AsyncClient(app=...)`` shortcut was deprecated in
    httpx 0.27 and removed in 0.28.

    All dependency overrides are cleared on teardown.
    """

    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
        async with TestingSessionLocal() as session:
            yield session

    test_app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(
            transport=ASGITransport(app=test_app),
            base_url="http://test",
        ) as async_client:
            yield async_client
    finally:
        # Always restore the real dependencies, even if client shutdown fails.
        test_app.dependency_overrides.clear()