Spaces:
Paused
Paused
File size: 2,571 Bytes
4c5298a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import asyncio
import os
from pathlib import Path
from typing import AsyncGenerator, Generator

import pytest
from dotenv import load_dotenv
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
# Load the test-only environment file before any ``app`` module is imported
# below, presumably so ``app.core.config.settings`` is built from test values.
# NOTE(review): python-dotenv does not override variables already present in
# the process environment unless ``override=True`` is passed.
test_env_path = Path(__file__).parent.joinpath("test.env")
load_dotenv(dotenv_path=test_env_path)
from app.db.database import Base, get_db
from app.main import app
from app.core.config import settings
# Tests run against an in-memory SQLite database through the aiosqlite driver.
# StaticPool hands out one shared connection, so every session talks to the
# same in-memory database instead of each new connection getting an empty one.
SQLALCHEMY_DATABASE_URL = "sqlite+aiosqlite:///:memory:"

engine = create_async_engine(
    SQLALCHEMY_DATABASE_URL,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Async session factory for tests. expire_on_commit=False keeps loaded
# attribute values readable after a commit without an implicit refresh.
TestingSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide a single event loop shared by the whole test session.

    The loop is session-scoped so the session-scoped async ``test_app``
    fixture and the individual tests run on the same loop.

    A brand-new loop is always created instead of calling
    ``asyncio.get_event_loop()``. That call is deprecated when no loop is
    running on modern Python, and it could return the thread's existing
    default loop, which this fixture would then close out from under it.

    NOTE(review): newer pytest-asyncio releases deprecate overriding the
    ``event_loop`` fixture; confirm against the installed version.

    Yields:
        The event loop for the session; it is closed on teardown.
    """
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
@pytest.fixture(scope="session")
async def test_app():
    """Prepare upload directories and the database schema, then yield the app.

    Setup:
      * creates ``./test_uploads/documents`` and ``./test_uploads/images``;
      * drops and recreates every table on ``Base.metadata`` in the test
        engine, so the schema always starts clean.

    Teardown:
      * drops all tables;
      * disposes of the engine, so the single pooled connection is closed
        while the session event loop is still running. Otherwise it would
        outlive the session.

    Yields:
        The ``app`` object imported from ``app.main``.
    """
    # NOTE(review): these paths are relative to the current working directory
    # and are not removed on teardown; presumably they match the upload
    # settings in test.env — confirm.
    os.makedirs("./test_uploads/documents", exist_ok=True)
    os.makedirs("./test_uploads/images", exist_ok=True)

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)  # Ensure clean state
        await conn.run_sync(Base.metadata.create_all)

    yield app

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
    # Release the StaticPool connection instead of leaving it open past the
    # end of the session.
    await engine.dispose()
@pytest.fixture
async def db_session(test_app) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to the in-memory test engine.

    ``test_app`` is requested only for its side effect of creating the
    schema. Without it, a test that asks for ``db_session`` but not
    ``client`` would run against a database with no tables.

    On teardown, pending (uncommitted) work is rolled back. Data that was
    committed, by the test or by the app under test, is NOT undone. The
    in-memory database is shared through ``StaticPool`` and its tables live
    for the whole session, so committed rows stay visible to later tests.
    """
    async with TestingSessionLocal() as session:
        yield session
        # Discard anything the test left uncommitted in this session.
        await session.rollback()
@pytest.fixture
async def client(test_app) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``httpx.AsyncClient`` that sends requests to the app in-process.

    ``test_app`` is requested so the schema exists before any request runs.
    The ``get_db`` dependency is overridden so each request gets its own
    session from ``TestingSessionLocal``, which is bound to the in-memory
    test engine. All dependency overrides are cleared on teardown.
    """

    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
        async with TestingSessionLocal() as session:
            yield session

    app.dependency_overrides[get_db] = override_get_db
    # The ``AsyncClient(app=...)`` shortcut was deprecated in httpx 0.27 and
    # removed in 0.28; an explicit ASGITransport is the supported way to call
    # an ASGI app directly.
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
    app.dependency_overrides.clear()