# AudioForge: backend/tests/test_api_generations.py
# (page header residue: OnyxlMunkey, commit c618549)
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
from app.main import app
from app.db.database import get_db
from app.services.orchestrator import GenerationOrchestrator
@pytest.fixture
def client():
    """Provide a synchronous ``TestClient`` wrapping the FastAPI ``app``."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def mock_db_session():
    """Build an ``AsyncMock`` stand-in for the database session.

    - ``add`` is a plain ``MagicMock`` (called without ``await``).
    - ``commit`` is an ``AsyncMock``.
    - ``refresh`` is an ``AsyncMock`` whose side effect assigns a fresh
      ``uuid4()`` to ``id`` and a fixed ``created_at`` string on the object
      it is awaited with.
    """
    async def _fake_refresh(obj):
        # Fill in the fields that the endpoint presumably reads back after
        # refresh (NOTE(review): confirm against the endpoint implementation).
        obj.id = uuid4()
        obj.created_at = "2023-01-01T00:00:00Z"

    session = AsyncMock()
    session.add = MagicMock()
    session.commit = AsyncMock()
    session.refresh = AsyncMock(side_effect=_fake_refresh)
    return session
@pytest.fixture
def override_get_db(mock_db_session):
    """Return an async-generator dependency that yields ``mock_db_session``.

    The tests install it as ``app.dependency_overrides[get_db]``, so the
    endpoints under test receive the mock session instead of a real one.
    """
    async def _yield_mock_session():
        yield mock_db_session

    return _yield_mock_session
def test_create_generation_endpoint(client, override_get_db):
    """POST /api/v1/generations/ accepts the request and reports it as pending.

    ``process_generation_task`` is patched in the endpoint module so that no
    real generation work runs while the request is handled. The test is a
    plain ``def``: ``TestClient`` is synchronous and nothing here is awaited.
    """
    payload = {
        "prompt": "A happy upbeat song",
        "duration": 30,
    }

    app.dependency_overrides[get_db] = override_get_db
    try:
        with patch("app.api.v1.endpoints.generations.process_generation_task"):
            response = client.post("/api/v1/generations/", json=payload)
            assert response.status_code == 202
            data = response.json()
            assert data["status"] == "pending"
            assert "id" in data
            # NOTE(review): once the endpoint's scheduling mechanism is
            # confirmed (e.g. BackgroundTasks.add_task), bind the patch with
            # ``as mock_task`` and assert it was invoked.
    finally:
        # Reset even when an assertion fails so the override cannot leak
        # into later tests that share the module-level ``app``.
        app.dependency_overrides = {}
def test_get_generation_endpoint(client, override_get_db, mock_db_session):
    """GET /api/v1/generations/{id} returns the stored generation.

    The session's ``execute`` is stubbed so that ``scalar_one_or_none()`` on
    its result yields a completed generation record. The test is a plain
    ``def``: ``TestClient`` is synchronous and nothing here is awaited.
    """
    generation_id = uuid4()

    mock_generation = MagicMock()
    mock_generation.id = generation_id
    mock_generation.status = "completed"
    mock_generation.audio_path = "/path/to/audio.wav"
    mock_generation.generation_metadata = {}
    mock_generation.processing_time_seconds = 10.5
    mock_generation.error_message = None
    mock_generation.created_at = "2023-01-01T00:00:00Z"
    mock_generation.completed_at = "2023-01-01T00:00:10Z"

    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = mock_generation
    # Child attributes of an AsyncMock are AsyncMocks, so awaiting
    # ``session.execute(...)`` resolves to ``mock_result``.
    mock_db_session.execute.return_value = mock_result

    app.dependency_overrides[get_db] = override_get_db
    try:
        response = client.get(f"/api/v1/generations/{generation_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(generation_id)
        assert data["status"] == "completed"
    finally:
        # Reset even when an assertion fails so the override cannot leak
        # into later tests that share the module-level ``app``.
        app.dependency_overrides = {}
def test_get_generation_not_found(client, override_get_db, mock_db_session):
    """GET /api/v1/generations/{id} returns 404 when no record matches.

    The session's ``execute`` is stubbed so that ``scalar_one_or_none()``
    returns ``None``. The test is a plain ``def``: ``TestClient`` is
    synchronous and nothing here is awaited.
    """
    generation_id = uuid4()

    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = None
    # Awaiting the AsyncMock child ``execute`` resolves to ``mock_result``.
    mock_db_session.execute.return_value = mock_result

    app.dependency_overrides[get_db] = override_get_db
    try:
        response = client.get(f"/api/v1/generations/{generation_id}")
        assert response.status_code == 404
    finally:
        # Reset even when the assertion fails so the override cannot leak
        # into later tests that share the module-level ``app``.
        app.dependency_overrides = {}
=