|
|
import pytest |
|
|
from sqlalchemy import create_engine |
|
|
from sqlalchemy.orm import sessionmaker |
|
|
from fastapi.testclient import TestClient |
|
|
|
|
|
from main import app |
|
|
from database.database import get_db |
|
|
from models.base import Base |
|
|
|
|
|
|
|
|
|
|
|
# Tests run against their own SQLite file so they never touch the app's real
# database. The "./" path is resolved relative to the current working
# directory of the pytest process.
TEST_DATABASE_URL = "sqlite:///./test_assessment_platform.db"

# By default sqlite3 refuses to use a connection from a thread other than the
# one that opened it. Disable that check because the app under test may use
# pooled connections from other threads (e.g. during TestClient requests).
_sqlite_connect_args = {"check_same_thread": False}
engine = create_engine(TEST_DATABASE_URL, connect_args=_sqlite_connect_args)

# Session factory for tests: explicit transactions and no automatic flush
# before queries.
TestingSessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)
|
|
|
|
|
|
|
|
def override_get_db():
    """Yield a test-database session in place of the app's ``get_db``.

    This generator is installed as an override for ``get_db`` so that the
    application under test talks to the test database. The session is
    closed when the generator is resumed or finalized after ``yield``,
    which FastAPI does once it is finished with the dependency.

    Yields:
        A session produced by ``TestingSessionLocal``.
    """
    # Create the session *before* entering ``try``. If construction fails
    # there is nothing to close, and a ``db`` assigned inside ``try`` would be
    # unbound in ``finally``, replacing the real error with a NameError.
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
|
|
|
|
|
|
|
|
|
|
|
# Redirect the app's ``get_db`` dependency to the test-database session
# generator for every request made during the test run.
app.dependency_overrides.update({get_db: override_get_db})
|
|
|
|
|
|
|
|
@pytest.fixture(scope="session")
def db_engine():
    """Provide the test engine with a freshly built schema for the session.

    All tables described by ``Base.metadata`` are dropped and recreated before
    the first test that needs them. At the end of the session they are dropped
    again and the engine's pooled connections are released, so the on-disk
    SQLite file does not carry rows from one run into the next.

    Yields:
        The module-level ``engine`` bound to ``TEST_DATABASE_URL``.
    """
    # Clear anything left by a previous run that was aborted before its
    # teardown could execute, then build the schema from the models.
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)
    try:
        yield engine
    finally:
        Base.metadata.drop_all(bind=engine)
        # Close pooled connections so the SQLite file is no longer held open.
        engine.dispose()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="function")
def db_session(db_engine):
    """Provide a session whose changes are rolled back after each test.

    The session is bound to a single connection on which an outer
    transaction is opened up front. Rolling that transaction back at teardown
    discards what the test wrote through this session.

    NOTE(review): whether a ``session.commit()`` inside a test stays within
    the outer transaction depends on the installed SQLAlchemy version (see
    2.0's ``join_transaction_mode``). Confirm before relying on commits here.

    Args:
        db_engine: The session-scoped engine fixture with the schema created.

    Yields:
        A session bound to the per-test connection.
    """
    # The ``with`` guarantees the connection is closed even if ``begin()``,
    # session construction, or any teardown step below raises.
    with db_engine.connect() as connection:
        transaction = connection.begin()
        session = TestingSessionLocal(bind=connection)
        try:
            yield session
        finally:
            session.close()
            # The transaction may already be inactive if an error rolled it
            # back during the test; only roll back what is still open.
            if transaction.is_active:
                transaction.rollback()
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def client(db_engine):
    """Provide a ``TestClient`` for the app, shared by one test module.

    Requesting ``db_engine`` guarantees the schema exists before the first
    request, even in modules whose tests never ask for ``db_engine`` or
    ``db_session`` themselves. The client is used as a context manager, so
    the app's startup/shutdown (lifespan) handling runs around the module's
    tests.

    Args:
        db_engine: The session-scoped engine fixture. It is requested only
            for its side effect of creating the tables.

    Yields:
        A ``TestClient`` wrapping ``app``.
    """
    with TestClient(app) as test_client:
        yield test_client