# Page residue from the hosting site's header ("Spaces: Paused"); not part of the code.
import asyncio
import os
from pathlib import Path
from typing import AsyncGenerator, Generator

import pytest
from dotenv import load_dotenv
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
# Apply the test-specific environment file (test.env, next to this file)
# BEFORE the application modules below are imported. Application code that
# reads configuration at import time (presumably app.core.config.settings —
# confirm) then sees the test values.
# NOTE: by default load_dotenv() does not override variables that are already
# set in the process environment.
test_env_path = Path(__file__).parent.joinpath("test.env")
load_dotenv(dotenv_path=test_env_path)

# These imports deliberately follow load_dotenv(); hoisting them to the top of
# the file would let them read the environment before test.env is applied.
from app.db.database import Base, get_db  # noqa: E402
from app.main import app  # noqa: E402
from app.core.config import settings  # noqa: E402
| # Use SQLite for testing | |
| SQLALCHEMY_DATABASE_URL = "sqlite+aiosqlite:///:memory:" | |
| engine = create_async_engine( | |
| SQLALCHEMY_DATABASE_URL, | |
| connect_args={"check_same_thread": False}, | |
| poolclass=StaticPool, | |
| ) | |
| TestingSessionLocal = sessionmaker( | |
| engine, | |
| class_=AsyncSession, | |
| expire_on_commit=False, | |
| ) | |
@pytest.fixture
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide a brand-new event loop for each test and close it afterwards.

    Yields:
        A fresh event loop owned (and closed) by this fixture.
    """
    # NOTE(review): overriding the ``event_loop`` fixture is deprecated in newer
    # pytest-asyncio releases (loop scope is configured instead) — confirm the
    # pinned pytest-asyncio version still honours this override.
    #
    # Always create a new loop. Reusing asyncio.get_event_loop() and then
    # closing it leaves that closed loop installed, so the next test would be
    # handed an already-closed loop ("Event loop is closed").
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        loop.close()
@pytest.fixture
async def test_app() -> AsyncGenerator:
    """Yield the application (``app.main.app``) backed by a fresh test schema.

    Setup creates the upload directories used by tests and (re)creates every
    table registered on ``Base.metadata``. Teardown drops those tables again.

    Yields:
        The application object imported from ``app.main``.
    """
    # NOTE(review): async fixtures declared with plain @pytest.fixture need
    # pytest-asyncio in ``asyncio_mode = auto``; in strict mode use
    # @pytest_asyncio.fixture instead — confirm project configuration.
    # NOTE(review): these paths are relative to the current working directory
    # and are not removed on teardown; presumably they match the upload
    # settings in test.env — confirm.
    for upload_dir in ("./test_uploads/documents", "./test_uploads/images"):
        os.makedirs(upload_dir, exist_ok=True)

    async with engine.begin() as conn:
        # Drop first: with StaticPool the in-memory database is shared for the
        # whole session, so tables left over from an earlier test may exist.
        await conn.run_sync(Base.metadata.drop_all)
        await conn.run_sync(Base.metadata.create_all)

    try:
        yield app
    finally:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an AsyncSession bound to the test engine for direct DB access.

    Uncommitted changes are rolled back on teardown; data the test explicitly
    commits is not undone here.

    Yields:
        An ``AsyncSession`` created from ``TestingSessionLocal``.
    """
    # NOTE(review): plain @pytest.fixture on an async generator requires
    # pytest-asyncio auto mode — confirm project configuration.
    async with TestingSessionLocal() as session:
        try:
            yield session
        finally:
            # Discard anything the test left pending, even if teardown is
            # triggered by an exception thrown into the generator.
            await session.rollback()
@pytest.fixture
async def client(test_app) -> AsyncGenerator[AsyncClient, None]:
    """Yield an httpx ``AsyncClient`` that sends requests to the app in-process.

    The ``get_db`` dependency is overridden so request handlers receive
    sessions from the test engine. All dependency overrides are cleared on
    teardown.

    Args:
        test_app: The application yielded by the ``test_app`` fixture (which
            also ensures the test tables exist).

    Yields:
        An ``AsyncClient`` with base URL ``http://test``.
    """
    # NOTE(review): plain @pytest.fixture on an async generator requires
    # pytest-asyncio auto mode — confirm project configuration.

    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
        async with TestingSessionLocal() as session:
            yield session

    test_app.dependency_overrides[get_db] = override_get_db
    try:
        # httpx 0.28 removed the AsyncClient(app=...) shortcut; an explicit
        # ASGITransport is the supported way to call an ASGI app in-process.
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    finally:
        # Always restore the real dependencies, even if client shutdown fails.
        test_app.dependency_overrides.clear()