File size: 3,617 Bytes
6423ff2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

from app.main import app
from app.db.database import get_db
from app.services.orchestrator import GenerationOrchestrator

@pytest.fixture
def client():
    """Provide a synchronous ``TestClient`` wired to the FastAPI ``app``.

    The client is created without a ``with`` block, so it is simply handed
    to the test as-is.
    """
    test_client = TestClient(app)
    return test_client

@pytest.fixture
def mock_db_session():
    """Build an ``AsyncMock`` stand-in for the async database session.

    ``add`` is a plain (non-awaitable) ``MagicMock``; ``commit`` and
    ``refresh`` are awaitable. Awaiting ``refresh(obj)`` assigns a fresh
    ``uuid4()`` to ``obj.id`` and a fixed timestamp string to
    ``obj.created_at``. Any other attribute (for example ``execute``) is an
    auto-created ``AsyncMock`` child that individual tests configure.
    """

    async def _populate_server_defaults(instance):
        # Stand in for values the database would fill in on refresh.
        instance.id = uuid4()
        instance.created_at = "2023-01-01T00:00:00Z"

    session = AsyncMock()
    # add() must not return a coroutine; presumably mirrors an async
    # session whose add() is synchronous — confirm against app.db.database.
    session.add = MagicMock()
    session.commit = AsyncMock()
    session.refresh = AsyncMock(side_effect=_populate_server_defaults)
    return session

@pytest.fixture
def override_get_db(mock_db_session):
    """Return a drop-in replacement for the ``get_db`` dependency.

    The returned async generator function yields the shared
    ``mock_db_session``, so every request made during one test sees the
    same mock. Tests install it via ``app.dependency_overrides[get_db]``.
    """

    async def _yield_mock_session():
        yield mock_db_session

    return _yield_mock_session

@pytest.mark.asyncio
async def test_create_generation_endpoint(client, override_get_db):
    """POST /api/v1/generations/ accepts a job and schedules processing.

    Expects a 202 response whose body has ``status == "pending"`` and an
    ``id``. ``process_generation_task`` is patched out so no real work runs,
    and the test checks that the endpoint actually made use of it.
    """
    app.dependency_overrides[get_db] = override_get_db
    try:
        payload = {
            "prompt": "A happy upbeat song",
            "duration": 30,
        }

        with patch("app.api.v1.endpoints.generations.process_generation_task") as mock_bg_task:
            response = client.post("/api/v1/generations/", json=payload)

            assert response.status_code == 202
            data = response.json()
            assert data["status"] == "pending"
            assert "id" in data

            # Starlette's TestClient waits for the whole ASGI call, including
            # any BackgroundTasks run after the response is sent, before
            # client.post() returns, so the patched task has been used by now.
            # mock_calls also records child calls (e.g. ``.delay(...)``) in
            # case the task is not invoked directly — TODO confirm which form
            # the endpoint uses and tighten this assertion accordingly.
            assert mock_bg_task.mock_calls, "process_generation_task was never scheduled"
    finally:
        # Reset even when an assertion fails, so the mocked DB override
        # cannot leak into later tests.
        app.dependency_overrides.clear()

@pytest.mark.asyncio
async def test_get_generation_endpoint(client, override_get_db, mock_db_session):
    """GET /api/v1/generations/{id} returns the stored generation.

    The mocked session's ``execute`` result yields a completed generation;
    the response must echo its id and status.
    """
    app.dependency_overrides[get_db] = override_get_db
    try:
        generation_id = uuid4()
        mock_generation = MagicMock()
        mock_generation.id = generation_id
        mock_generation.status = "completed"
        mock_generation.audio_path = "/path/to/audio.wav"
        mock_generation.generation_metadata = {}
        mock_generation.processing_time_seconds = 10.5
        mock_generation.error_message = None
        mock_generation.created_at = "2023-01-01T00:00:00Z"
        mock_generation.completed_at = "2023-01-01T00:00:10Z"

        # execute is an AsyncMock child of the session mock, so awaiting it
        # produces mock_result; the endpoint is assumed to read the row via
        # scalar_one_or_none().
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = mock_generation
        mock_db_session.execute.return_value = mock_result

        response = client.get(f"/api/v1/generations/{generation_id}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(generation_id)
        assert data["status"] == "completed"
    finally:
        # Reset even on failure so the override cannot leak into other tests.
        app.dependency_overrides.clear()

@pytest.mark.asyncio
async def test_get_generation_not_found(client, override_get_db, mock_db_session):
    """GET /api/v1/generations/{id} answers 404 when no row matches."""
    app.dependency_overrides[get_db] = override_get_db
    try:
        generation_id = uuid4()

        # Make the awaited execute() result report "no row".
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_db_session.execute.return_value = mock_result

        response = client.get(f"/api/v1/generations/{generation_id}")

        assert response.status_code == 404
    finally:
        # Reset even on failure so the override cannot leak into other tests.
        app.dependency_overrides.clear()