Spaces:
Sleeping
Sleeping
| """ | |
| Comprehensive Tests for Credit Service | |
| Tests cover: | |
| 1. Credit Manager - reserve, confirm, refund operations | |
| 2. Error pattern matching - refundable vs non-refundable | |
| 3. Job completion handling | |
| 4. Credits Router endpoints | |
| 5. Credit middleware (if needed) | |
| Uses mocked database and user models. | |
| """ | |
| import pytest | |
| from datetime import datetime | |
| from unittest.mock import patch, MagicMock, AsyncMock | |
| from fastapi.testclient import TestClient | |
| # ============================================================================ | |
| # 1. Credit Manager Tests | |
| # ============================================================================ | |
class TestCreditReservation:
    """Test credit reservation functionality.

    ``reserve_credit(session, user, amount=...)`` is expected to return a bool
    and, on success, to deduct ``amount`` from ``user.credits`` in place; on
    failure the balance must be left untouched.
    """

    async def test_reserve_credit_success(self):
        """Successfully reserve credits from user balance."""
        from services.credit_service.credit_manager import reserve_credit

        # Mock user with sufficient credits
        mock_user = MagicMock()
        mock_user.user_id = "usr_123"
        mock_user.credits = 10
        mock_session = AsyncMock()

        result = await reserve_credit(mock_session, mock_user, amount=5)

        # `is True` rather than `== True`: pins the bool return type instead of
        # accepting any value that merely compares equal to True.
        assert result is True
        assert mock_user.credits == 5  # 10 - 5

    async def test_reserve_credit_insufficient(self):
        """Cannot reserve more credits than user has."""
        from services.credit_service.credit_manager import reserve_credit

        mock_user = MagicMock()
        mock_user.user_id = "usr_123"
        mock_user.credits = 3
        mock_session = AsyncMock()

        result = await reserve_credit(mock_session, mock_user, amount=5)

        assert result is False
        assert mock_user.credits == 3  # Unchanged

    async def test_reserve_credit_exact_amount(self):
        """Can reserve exact balance."""
        from services.credit_service.credit_manager import reserve_credit

        mock_user = MagicMock()
        mock_user.user_id = "usr_123"  # same shape as the other reservation tests
        mock_user.credits = 10
        mock_session = AsyncMock()

        result = await reserve_credit(mock_session, mock_user, amount=10)

        assert result is True
        assert mock_user.credits == 0
class TestCreditConfirmation:
    """Tests for ``confirm_credit``, which finalises a job's reserved credits."""

    @staticmethod
    async def _confirm(reserved, job_id=None):
        """Await ``confirm_credit`` on a mock job holding ``reserved`` credits.

        ``job_id`` is assigned only when supplied; otherwise the mock keeps
        MagicMock's auto-generated attribute. The job is returned so the
        caller can inspect it afterwards.
        """
        from services.credit_service.credit_manager import confirm_credit

        job = MagicMock()
        if job_id is not None:
            job.job_id = job_id
        job.credits_reserved = reserved
        await confirm_credit(AsyncMock(), job)
        return job

    async def test_confirm_credit_clears_reservation(self):
        """After confirmation the job no longer tracks any reserved credits."""
        job = await self._confirm(5, job_id="job_123")
        assert job.credits_reserved == 0

    async def test_confirm_credit_no_reservation(self):
        """Confirming a job with nothing reserved leaves the count at zero."""
        job = await self._confirm(0)
        assert job.credits_reserved == 0
class TestCreditRefund:
    """Test credit refund functionality.

    ``refund_credit(session, job, reason)`` returns a bool. On success it adds
    the job's reserved credits back to the owning user (fetched through
    ``session.execute(...).scalar_one_or_none()``), zeroes
    ``job.credits_reserved`` and marks ``job.credits_refunded``.
    """

    async def test_refund_credit_success(self):
        """Successfully refund credits to user."""
        from services.credit_service.credit_manager import refund_credit
        from core.models import User

        # Mock job with reserved credits
        mock_job = MagicMock()
        mock_job.job_id = "job_123"
        mock_job.user_id = 1
        mock_job.credits_reserved = 5
        mock_job.credits_refunded = False

        # Mock user
        mock_user = MagicMock(spec=User)
        mock_user.id = 1
        mock_user.user_id = "usr_123"
        mock_user.credits = 10

        # Mock database session returning the user for the lookup query
        mock_session = AsyncMock()
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = mock_user
        mock_session.execute.return_value = mock_result

        result = await refund_credit(mock_session, mock_job, "Test refund")

        assert result is True
        assert mock_user.credits == 15  # 10 + 5
        assert mock_job.credits_reserved == 0
        assert mock_job.credits_refunded is True

    async def test_refund_credit_no_reservation(self):
        """Cannot refund if no credits were reserved."""
        from services.credit_service.credit_manager import refund_credit

        mock_job = MagicMock()
        mock_job.credits_reserved = 0
        # Explicitly "not yet refunded": an unset MagicMock attribute is a
        # truthy MagicMock, which would let the already-refunded guard satisfy
        # this test instead of the no-reservation guard it is meant to cover.
        mock_job.credits_refunded = False
        mock_session = AsyncMock()

        result = await refund_credit(mock_session, mock_job, "Test")

        assert result is False
        # A rejected refund must not mark the job as refunded.
        assert mock_job.credits_refunded is False

    async def test_refund_credit_already_refunded(self):
        """Cannot refund credits twice."""
        from services.credit_service.credit_manager import refund_credit

        mock_job = MagicMock()
        mock_job.credits_reserved = 5
        mock_job.credits_refunded = True
        mock_session = AsyncMock()

        result = await refund_credit(mock_session, mock_job, "Test")

        assert result is False
| # ============================================================================ | |
| # 2. Error Pattern Matching Tests | |
| # ============================================================================ | |
class TestErrorPatternMatching:
    """Test refundable vs non-refundable error detection.

    Every case compares with ``is True`` / ``is False`` so the bool return
    contract of ``is_refundable_error`` is pinned, not just its truthiness.
    """

    @staticmethod
    def _is_refundable(message):
        """Import ``is_refundable_error`` and apply it to ``message``."""
        from services.credit_service.credit_manager import is_refundable_error

        return is_refundable_error(message)

    def test_refundable_api_key_error(self):
        """API key errors are refundable."""
        assert self._is_refundable("API_KEY_INVALID: The API key is invalid") is True

    def test_refundable_quota_exceeded(self):
        """Quota exceeded is refundable."""
        assert self._is_refundable("QUOTA_EXCEEDED: Daily quota exceeded") is True

    def test_refundable_internal_error(self):
        """Internal server errors are refundable."""
        assert self._is_refundable("INTERNAL_ERROR: Something went wrong") is True

    def test_refundable_timeout(self):
        """Timeouts are refundable."""
        assert self._is_refundable("Request TIMEOUT after 30 seconds") is True

    def test_refundable_500_error(self):
        """HTTP 500 errors are refundable."""
        assert self._is_refundable("Server returned 500 Internal Server Error") is True

    def test_non_refundable_safety_filter(self):
        """Safety filter blocks are not refundable."""
        assert self._is_refundable("Content blocked by safety filter") is False

    def test_non_refundable_invalid_input(self):
        """Invalid input errors are not refundable."""
        assert self._is_refundable("INVALID_INPUT: Bad image format") is False

    def test_non_refundable_400_error(self):
        """HTTP 400 errors are not refundable."""
        assert self._is_refundable("Bad request: 400 status code") is False

    def test_non_refundable_cancelled(self):
        """User cancellations are not refundable."""
        assert self._is_refundable("User cancelled the operation") is False

    def test_refundable_max_retries(self):
        """Max retries exceeded is refundable."""
        assert self._is_refundable("Failed after max retries") is True

    def test_unknown_error_not_refundable(self):
        """Unknown errors default to non-refundable."""
        assert self._is_refundable("Some random unknown error") is False

    def test_empty_error_not_refundable(self):
        """Empty or missing error messages are not refundable."""
        assert self._is_refundable("") is False
        assert self._is_refundable(None) is False
| # ============================================================================ | |
| # 3. Job Completion Handling Tests | |
| # ============================================================================ | |
class TestJobCompletionHandling:
    """Test credit handling when jobs complete.

    Both ``confirm_credit`` and ``refund_credit`` are patched in every test,
    so each case asserts which branch ``handle_job_completion`` took and that
    the other branch was not taken. Patching only one of them would let a
    wrong branch run the real credit operation against mocks. Both functions
    are coroutines (they are awaited elsewhere in this file), so ``patch``
    substitutes ``AsyncMock`` objects. ``assert_awaited_once`` therefore also
    catches a missing ``await``.
    """

    _CREDIT_MANAGER = "services.credit_service.credit_manager"

    async def _complete(self, job):
        """Run ``handle_job_completion`` on ``job`` with both credit ops patched.

        Returns:
            ``(mock_confirm, mock_refund)`` for the caller to assert on.
        """
        from services.credit_service.credit_manager import handle_job_completion

        mock_session = AsyncMock()
        with patch(f"{self._CREDIT_MANAGER}.confirm_credit") as mock_confirm:
            with patch(f"{self._CREDIT_MANAGER}.refund_credit") as mock_refund:
                await handle_job_completion(mock_session, job)
        return mock_confirm, mock_refund

    async def test_completed_job_confirms_credits(self):
        """Completed jobs confirm credit usage."""
        mock_job = MagicMock()
        mock_job.job_id = "job_123"
        mock_job.status = "completed"
        mock_job.credits_reserved = 5

        mock_confirm, mock_refund = await self._complete(mock_job)

        mock_confirm.assert_awaited_once()
        mock_refund.assert_not_called()

    async def test_failed_refundable_job_refunds(self):
        """Failed jobs with refundable errors get refunds."""
        mock_job = MagicMock()
        mock_job.status = "failed"
        mock_job.error_message = "API_KEY_INVALID: Bad key"
        mock_job.credits_reserved = 5
        # Realistic pre-refund state; an unset MagicMock attribute is truthy.
        mock_job.credits_refunded = False

        mock_confirm, mock_refund = await self._complete(mock_job)

        mock_refund.assert_awaited_once()
        mock_confirm.assert_not_called()

    async def test_failed_non_refundable_job_keeps_credits(self):
        """Failed jobs with non-refundable errors keep credits."""
        mock_job = MagicMock()
        mock_job.status = "failed"
        mock_job.error_message = "Safety filter blocked content"
        mock_job.credits_reserved = 5

        mock_confirm, mock_refund = await self._complete(mock_job)

        mock_confirm.assert_awaited_once()
        mock_refund.assert_not_called()

    async def test_cancelled_before_start_refunds(self):
        """Cancelled jobs that never started get refunds."""
        mock_job = MagicMock()
        mock_job.status = "cancelled"
        mock_job.started_at = None
        mock_job.credits_reserved = 5
        mock_job.credits_refunded = False

        mock_confirm, mock_refund = await self._complete(mock_job)

        mock_refund.assert_awaited_once()
        mock_confirm.assert_not_called()

    async def test_cancelled_during_processing_keeps_credits(self):
        """Cancelled jobs that started keep credits."""
        mock_job = MagicMock()
        mock_job.status = "cancelled"
        mock_job.started_at = datetime.utcnow()
        mock_job.credits_reserved = 5

        mock_confirm, mock_refund = await self._complete(mock_job)

        mock_confirm.assert_awaited_once()
        mock_refund.assert_not_called()
| # ============================================================================ | |
| # 4. Credits Router Tests | |
| # ============================================================================ | |
class TestCreditsRouter:
    """Test credits API endpoints.

    Each test mounts ``routers.credits.router`` on a fresh FastAPI app. The
    router reads the caller from ``request.state.user``. Authenticated tests
    register an HTTP middleware that sets it to a mock user. The "requires
    auth" tests register no middleware at all.
    """

    @staticmethod
    def _make_client(*, user=None, jobs=None, raise_server_exceptions=True):
        """Build a ``TestClient`` for the credits router.

        Args:
            user: Object exposed as ``request.state.user`` through an HTTP
                middleware. When None, no middleware is registered and the
                request is unauthenticated.
            jobs: When not None, ``get_db`` is overridden with a mock session
                whose ``execute(...).scalars().all()`` returns this list.
            raise_server_exceptions: Forwarded to ``TestClient``. With the
                default (True), an unhandled exception inside the app is
                re-raised into the test instead of becoming a 500 response.

        Returns:
            A ``TestClient`` wrapping the configured app.
        """
        from fastapi import FastAPI
        from routers.credits import router

        app = FastAPI()

        if user is not None:
            # The middleware must be registered on the app. Merely defining
            # the coroutine (as the earlier version of these tests did) has
            # no effect.
            async def add_user_to_state(request, call_next):
                request.state.user = user
                return await call_next(request)

            app.middleware("http")(add_user_to_state)

        if jobs is not None:
            from core.database import get_db

            async def mock_get_db():
                mock_db = AsyncMock()
                mock_result = MagicMock()
                mock_result.scalars.return_value.all.return_value = jobs
                mock_db.execute.return_value = mock_result
                yield mock_db

            app.dependency_overrides[get_db] = mock_get_db

        app.include_router(router)
        return TestClient(app, raise_server_exceptions=raise_server_exceptions)

    def test_get_balance_requires_auth(self):
        """GET /credits/balance requires authentication."""
        # No middleware sets request.state.user. raise_server_exceptions=False
        # turns the resulting server error into a 500 response. With the
        # TestClient default, the exception would be raised here and the
        # status assertion would never run.
        client = self._make_client(raise_server_exceptions=False)

        response = client.get("/credits/balance")

        # Should fail without auth
        assert response.status_code == 500  # Attribute Error - no middleware

    def test_get_balance_returns_user_credits(self):
        """GET /credits/balance returns user's credit balance."""
        mock_user = MagicMock()
        mock_user.user_id = "usr_123"
        mock_user.credits = 50
        mock_user.last_used_at = None
        client = self._make_client(user=mock_user)

        response = client.get("/credits/balance")

        assert response.status_code == 200
        data = response.json()
        assert data["user_id"] == "usr_123"
        assert data["credits"] == 50

    def test_get_history_requires_auth(self):
        """GET /credits/history requires authentication."""
        # get_db is overridden (with no rows) so the unauthenticated request
        # never opens a real database session.
        client = self._make_client(jobs=[], raise_server_exceptions=False)

        response = client.get("/credits/history")

        # Should fail without auth
        assert response.status_code == 500  # Attribute Error - no middleware

    def test_get_history_returns_paginated_jobs(self):
        """GET /credits/history returns paginated job list."""
        mock_user = MagicMock()
        mock_user.id = 1  # integer primary key, as in the refund tests
        mock_user.user_id = "usr_123"
        mock_user.credits = 50

        # Mock database with one job
        mock_job = MagicMock()
        mock_job.job_id = "job_123"
        mock_job.job_type = "generate-video"
        mock_job.status = "completed"
        mock_job.credits_reserved = 10
        mock_job.credits_refunded = False
        mock_job.error_message = None
        mock_job.created_at = datetime.utcnow()
        mock_job.completed_at = datetime.utcnow()

        client = self._make_client(user=mock_user, jobs=[mock_job])

        response = client.get("/credits/history")

        assert response.status_code == 200
        data = response.json()
        assert data["user_id"] == "usr_123"
        assert data["current_balance"] == 50
        assert len(data["history"]) == 1
        assert data["history"][0]["job_id"] == "job_123"

    def test_get_history_pagination(self):
        """GET /credits/history supports pagination."""
        mock_user = MagicMock()
        mock_user.id = 1
        mock_user.user_id = "usr_123"
        mock_user.credits = 50
        client = self._make_client(user=mock_user, jobs=[])

        response = client.get("/credits/history?page=2&limit=10")

        assert response.status_code == 200
        data = response.json()
        assert data["page"] == 2
        assert data["limit"] == 10
if __name__ == "__main__":
    # pytest.main returns the session exit code; raising SystemExit with it
    # makes a direct `python this_file.py` run exit non-zero when tests fail
    # (discarding it would always exit 0).
    raise SystemExit(pytest.main([__file__, "-v"]))