"""API tests for the ``/api/v1/generations`` endpoints.

The ``get_db`` dependency is replaced with an ``AsyncMock`` session, so these
tests exercise request handling and serialization without a real database.
"""

import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

from app.main import app
from app.db.database import get_db
from app.services.orchestrator import GenerationOrchestrator


@pytest.fixture
def client():
    """Return a synchronous ``TestClient`` bound to the application."""
    return TestClient(app)


@pytest.fixture
def mock_db_session():
    """Return an ``AsyncMock`` standing in for the async DB session.

    ``add`` is a plain (non-awaitable) ``MagicMock``; ``commit`` and
    ``refresh`` are awaitable. ``refresh`` fills in ``id`` and ``created_at``
    on the instance passed to it, simulating values the database would assign.
    Tests may configure ``execute.return_value`` to control query results.
    """
    mock_session = AsyncMock()
    mock_session.add = MagicMock()
    mock_session.commit = AsyncMock()

    async def mock_refresh(instance):
        # Simulate server-generated columns being populated after commit.
        instance.id = uuid4()
        instance.created_at = "2023-01-01T00:00:00Z"

    mock_session.refresh = AsyncMock(side_effect=mock_refresh)
    return mock_session


@pytest.fixture
def override_get_db(mock_db_session):
    """Return an async-generator replacement for ``get_db`` yielding the mock."""

    async def _get_db():
        yield mock_db_session

    return _get_db


@pytest.fixture
def use_mock_db(override_get_db):
    """Install the ``get_db`` override for one test and always remove it.

    Removal happens in fixture teardown, so it runs even when the test's
    assertions fail; resetting inline at the end of each test body would be
    skipped on failure and leak the override into later tests.
    """
    app.dependency_overrides[get_db] = override_get_db
    yield
    app.dependency_overrides.pop(get_db, None)


# The tests below are plain (sync) functions on purpose: ``TestClient`` is a
# synchronous client that drives the ASGI app on its own event loop, so there
# is nothing to await here.


def test_create_generation_endpoint(client, use_mock_db):
    """POST creates a pending generation and schedules the background task."""
    payload = {"prompt": "A happy upbeat song", "duration": 30}

    with patch(
        "app.api.v1.endpoints.generations.process_generation_task"
    ) as mock_bg_task:
        response = client.post("/api/v1/generations/", json=payload)

        assert response.status_code == 202
        data = response.json()
        assert data["status"] == "pending"
        assert "id" in data
        # TestClient does not return from post() until the ASGI app has
        # finished, which includes running queued background tasks, so the
        # patched task must already have been invoked by now.
        mock_bg_task.assert_called_once()


def test_get_generation_endpoint(client, use_mock_db, mock_db_session):
    """GET returns the stored generation when the query finds one."""
    generation_id = uuid4()

    mock_generation = MagicMock()
    mock_generation.id = generation_id
    mock_generation.status = "completed"
    mock_generation.audio_path = "/path/to/audio.wav"
    mock_generation.generation_metadata = {}
    mock_generation.processing_time_seconds = 10.5
    mock_generation.error_message = None
    mock_generation.created_at = "2023-01-01T00:00:00Z"
    mock_generation.completed_at = "2023-01-01T00:00:10Z"

    # The endpoint awaits session.execute(...) and then calls
    # scalar_one_or_none() on the result.
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = mock_generation
    mock_db_session.execute.return_value = mock_result

    response = client.get(f"/api/v1/generations/{generation_id}")

    assert response.status_code == 200
    data = response.json()
    assert data["id"] == str(generation_id)
    assert data["status"] == "completed"


def test_get_generation_not_found(client, use_mock_db, mock_db_session):
    """GET returns 404 when the query finds no generation."""
    generation_id = uuid4()

    # An empty query result: scalar_one_or_none() yields None.
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = None
    mock_db_session.execute.return_value = mock_result

    response = client.get(f"/api/v1/generations/{generation_id}")

    assert response.status_code == 404