Spaces:
Sleeping
Sleeping
| """ | |
| Integration Test Suite for Credit Middleware | |
| Tests the complete middleware flow including: | |
| - Request interception | |
| - Credit reservation | |
| - Response inspection | |
| - Automatic confirmation/refund | |
| """ | |
| import pytest | |
| import json | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| from fastapi import Request, Response, status | |
| from fastapi.responses import JSONResponse | |
| from services.credit_service.middleware import CreditMiddleware | |
| from services.credit_service.config import CreditServiceConfig | |
| from core.models import User | |
| # ============================================================================= | |
| # Fixtures | |
| # ============================================================================= | |
@pytest.fixture
def mock_user():
    """Create a mock ``User`` holding 100 credits.

    ``spec=User`` makes reading an attribute that ``User`` does not define
    raise ``AttributeError``, so a middleware typo fails loudly in tests.
    """
    user = MagicMock(spec=User)
    user.id = 1
    user.user_id = "test_user_123"
    user.credits = 100
    return user
@pytest.fixture
def mock_request(mock_user):
    """Create a mock FastAPI ``Request`` for ``POST /gemini/analyze-image``.

    The request is authenticated as ``mock_user`` (via ``request.state.user``)
    and starts with ``request.state.credit_transaction_id`` set to ``None``.
    """
    request = MagicMock(spec=Request)
    request.method = "POST"
    request.url.path = "/gemini/analyze-image"
    request.state.user = mock_user
    # Explicit None so tests can see whether a transaction id was attached;
    # an unset MagicMock attribute would be an (always truthy) child mock.
    request.state.credit_transaction_id = None
    request.client.host = "127.0.0.1"
    request.headers = {"user-agent": "test"}
    return request
@pytest.fixture
def credit_middleware():
    """Register test route costs and return a ``CreditMiddleware`` instance.

    Registered routes:
    - ``/gemini/analyze-image``: cost 1, type ``sync``
    - ``/gemini/generate-video``: cost 10, type ``async``
    - ``/gemini/job/{job_id}``: cost 0, type ``async``
    - ``/free-endpoint``: cost 0, type ``free``

    The wrapped app is a ``MagicMock``; the tests in this module call
    ``dispatch`` directly with their own ``call_next``.
    """
    # NOTE(review): this registers configuration on the CreditServiceConfig
    # class and nothing undoes it after the test; confirm whether other test
    # modules can be affected by this state.
    CreditServiceConfig.register(
        route_configs={
            "/gemini/analyze-image": {"cost": 1, "type": "sync"},
            "/gemini/generate-video": {"cost": 10, "type": "async"},
            "/gemini/job/{job_id}": {"cost": 0, "type": "async"},
            "/free-endpoint": {"cost": 0, "type": "free"},
        }
    )
    return CreditMiddleware(MagicMock())
| # ============================================================================= | |
| # Free Endpoint Tests | |
| # ============================================================================= | |
@pytest.mark.asyncio
async def test_free_endpoint_no_credit_check(credit_middleware, mock_request):
    """Free endpoints pass through without attaching a credit transaction."""
    mock_request.url.path = "/free-endpoint"

    async def mock_call_next(request):
        return Response(content="OK", status_code=200)

    response = await credit_middleware.dispatch(mock_request, mock_call_next)

    assert response.status_code == 200
    # ``request.state`` is a MagicMock, so ``hasattr`` is always True on it
    # (and the fixture sets this attribute anyway). Check instead that the
    # fixture's initial None was left untouched.
    assert mock_request.state.credit_transaction_id is None
@pytest.mark.asyncio
async def test_options_request_bypass(credit_middleware, mock_request):
    """OPTIONS (CORS preflight) requests are passed straight to the handler."""
    mock_request.method = "OPTIONS"

    async def mock_call_next(request):
        return Response(status_code=204)

    response = await credit_middleware.dispatch(mock_request, mock_call_next)

    assert response.status_code == 204
| # ============================================================================= | |
| # Unauthenticated Request Tests | |
| # ============================================================================= | |
@pytest.mark.asyncio
async def test_unauthenticated_request(credit_middleware, mock_request):
    """A paid route with no user on ``request.state`` is rejected with 401."""
    mock_request.state.user = None

    async def mock_call_next(request):
        return Response(status_code=200)

    # Patched so no real database session can be opened during the test.
    with patch('services.credit_service.middleware.async_session_maker'):
        response = await credit_middleware.dispatch(mock_request, mock_call_next)

    assert response.status_code == status.HTTP_401_UNAUTHORIZED
| # ============================================================================= | |
| # Credit Reservation Tests | |
| # ============================================================================= | |
@pytest.mark.asyncio
async def test_successful_credit_reservation(credit_middleware, mock_request):
    """A paid request reserves the route cost (1 credit for analyze-image)."""
    with patch('services.credit_service.middleware.async_session_maker') as mock_session:
        mock_db = AsyncMock()
        # Models ``async with async_session_maker() as db`` yielding mock_db.
        mock_session.return_value.__aenter__.return_value = mock_db

        with patch('services.credit_service.middleware.CreditTransactionManager') as mock_tm:
            mock_transaction = MagicMock()
            mock_transaction.transaction_id = "ctx_test123"
            mock_tm.reserve_credits = AsyncMock(return_value=mock_transaction)
            # Awaitable stub for the success path; a plain MagicMock child
            # attribute cannot be awaited.
            mock_tm.confirm_credits = AsyncMock()

            async def mock_call_next(request):
                # Attach an async body iterator, presumably consumed by the
                # middleware's response inspection.
                async def body_iterator():
                    yield b'{"result": "success"}'

                response = Response(content=b'{"result": "success"}', status_code=200)
                response.body_iterator = body_iterator()
                return response

            response = await credit_middleware.dispatch(mock_request, mock_call_next)

            assert response.status_code == 200
            mock_tm.reserve_credits.assert_called_once()
            # 1 credit for /gemini/analyze-image (see credit_middleware fixture).
            assert mock_tm.reserve_credits.call_args.kwargs['amount'] == 1
| # ============================================================================= | |
| # Insufficient Credits Tests | |
| # ============================================================================= | |
@pytest.mark.asyncio
async def test_insufficient_credits(credit_middleware, mock_request):
    """Reject with 402 and skip the handler when reservation fails for lack of credits."""
    from services.credit_service.transaction_manager import InsufficientCreditsError

    with patch('services.credit_service.middleware.async_session_maker') as mock_session:
        mock_db = AsyncMock()
        mock_session.return_value.__aenter__.return_value = mock_db

        with patch('services.credit_service.middleware.CreditTransactionManager') as mock_tm:
            mock_tm.reserve_credits = AsyncMock(
                side_effect=InsufficientCreditsError("Not enough credits")
            )
            # AsyncMock so we can verify the downstream handler never ran.
            mock_call_next = AsyncMock(return_value=Response(status_code=200))

            response = await credit_middleware.dispatch(mock_request, mock_call_next)

            assert response.status_code == status.HTTP_402_PAYMENT_REQUIRED
            content = json.loads(response.body.decode())
            assert "Insufficient credits" in content["detail"]
            mock_call_next.assert_not_awaited()
| # ============================================================================= | |
| # Response Inspection Tests - Sync Endpoints | |
| # ============================================================================= | |
@pytest.mark.asyncio
async def test_sync_success_confirms_credits(credit_middleware, mock_request):
    """A 200 response on a sync route confirms the reservation and does not refund it."""
    with patch('services.credit_service.middleware.async_session_maker') as mock_session:
        mock_db = AsyncMock()
        mock_session.return_value.__aenter__.return_value = mock_db

        with patch('services.credit_service.middleware.CreditTransactionManager') as mock_tm:
            mock_transaction = MagicMock()
            mock_transaction.transaction_id = "ctx_test123"
            mock_tm.reserve_credits = AsyncMock(return_value=mock_transaction)
            mock_tm.confirm_credits = AsyncMock()
            mock_tm.refund_credits = AsyncMock()

            async def mock_call_next(request):
                async def body_iterator():
                    yield b'{"result": "image analyzed"}'

                response = Response(content=b'{"result": "image analyzed"}', status_code=200)
                response.body_iterator = body_iterator()
                return response

            await credit_middleware.dispatch(mock_request, mock_call_next)

            mock_tm.confirm_credits.assert_called_once()
            mock_tm.refund_credits.assert_not_called()
@pytest.mark.asyncio
async def test_sync_failure_refunds_credits(credit_middleware, mock_request):
    """A 400 response on a sync route refunds the reservation and does not confirm it."""
    with patch('services.credit_service.middleware.async_session_maker') as mock_session:
        mock_db = AsyncMock()
        mock_session.return_value.__aenter__.return_value = mock_db

        with patch('services.credit_service.middleware.CreditTransactionManager') as mock_tm:
            mock_transaction = MagicMock()
            mock_transaction.transaction_id = "ctx_test123"
            mock_tm.reserve_credits = AsyncMock(return_value=mock_transaction)
            mock_tm.refund_credits = AsyncMock()
            mock_tm.confirm_credits = AsyncMock()

            async def mock_call_next(request):
                async def body_iterator():
                    yield b'{"detail": "Invalid image"}'

                response = Response(content=b'{"detail": "Invalid image"}', status_code=400)
                response.body_iterator = body_iterator()
                return response

            await credit_middleware.dispatch(mock_request, mock_call_next)

            mock_tm.refund_credits.assert_called_once()
            mock_tm.confirm_credits.assert_not_called()
| # ============================================================================= | |
| # Response Inspection Tests - Async Endpoints | |
| # ============================================================================= | |
@pytest.mark.asyncio
async def test_async_job_creation_keeps_reserved(credit_middleware, mock_request):
    """Queuing an async job leaves its credits reserved (neither confirmed nor refunded)."""
    mock_request.url.path = "/gemini/generate-video"

    with patch('services.credit_service.middleware.async_session_maker') as mock_session:
        mock_db = AsyncMock()
        mock_session.return_value.__aenter__.return_value = mock_db

        with patch('services.credit_service.middleware.CreditTransactionManager') as mock_tm:
            mock_transaction = MagicMock()
            mock_transaction.transaction_id = "ctx_test123"
            mock_tm.reserve_credits = AsyncMock(return_value=mock_transaction)
            mock_tm.confirm_credits = AsyncMock()
            mock_tm.refund_credits = AsyncMock()

            job_payload = b'{"job_id": "job_abc", "status": "queued"}'

            async def mock_call_next(request):
                async def body_iterator():
                    yield job_payload

                response = Response(content=job_payload, status_code=200)
                response.body_iterator = body_iterator()
                return response

            await credit_middleware.dispatch(mock_request, mock_call_next)

            mock_tm.confirm_credits.assert_not_called()
            mock_tm.refund_credits.assert_not_called()
@pytest.mark.asyncio
async def test_async_job_completed_confirms_credits(credit_middleware, mock_request):
    """Polling a completed job on the zero-cost status route passes through.

    ``/gemini/job/{job_id}`` is registered with cost 0, so no reservation is
    made here. This test checks that the completed-job response is returned
    unchanged in status and that nothing is reserved.

    NOTE(review): despite the name, confirmation of a previously reserved job
    transaction is not exercised by this test.
    """
    mock_request.url.path = "/gemini/job/job_abc"

    with patch('services.credit_service.middleware.async_session_maker') as mock_session:
        mock_db = AsyncMock()
        mock_session.return_value.__aenter__.return_value = mock_db

        with patch('services.credit_service.middleware.CreditTransactionManager') as mock_tm:
            mock_tm.reserve_credits = AsyncMock()
            mock_tm.confirm_credits = AsyncMock()

            job_payload = b'{"job_id": "job_abc", "status": "completed", "video_url": "..."}'

            async def mock_call_next(request):
                async def body_iterator():
                    yield job_payload

                response = Response(content=job_payload, status_code=200)
                response.body_iterator = body_iterator()
                return response

            response = await credit_middleware.dispatch(mock_request, mock_call_next)

            assert response.status_code == 200
            # Cost is 0 for this route, so no reservation should be attempted.
            mock_tm.reserve_credits.assert_not_called()
| # ============================================================================= | |
| # Error Handling Tests | |
| # ============================================================================= | |
@pytest.mark.asyncio
async def test_database_error_during_reservation(credit_middleware, mock_request):
    """A failure while reserving credits yields 500 and the handler is not run."""
    with patch('services.credit_service.middleware.async_session_maker') as mock_session:
        mock_db = AsyncMock()
        mock_session.return_value.__aenter__.return_value = mock_db

        with patch('services.credit_service.middleware.CreditTransactionManager') as mock_tm:
            # Simulate a database failure during reservation.
            mock_tm.reserve_credits = AsyncMock(side_effect=Exception("DB connection failed"))
            mock_call_next = AsyncMock(return_value=Response(status_code=200))

            response = await credit_middleware.dispatch(mock_request, mock_call_next)

            assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
            mock_call_next.assert_not_awaited()
@pytest.mark.asyncio
async def test_response_phase_error_doesnt_fail_request(credit_middleware, mock_request):
    """An error while confirming credits must not break the handler's response."""
    with patch('services.credit_service.middleware.async_session_maker') as mock_session:
        mock_db = AsyncMock()
        mock_session.return_value.__aenter__.return_value = mock_db

        with patch('services.credit_service.middleware.CreditTransactionManager') as mock_tm:
            mock_transaction = MagicMock()
            mock_transaction.transaction_id = "ctx_test123"
            mock_tm.reserve_credits = AsyncMock(return_value=mock_transaction)
            # Confirmation blows up; the client response should be unaffected.
            mock_tm.confirm_credits = AsyncMock(side_effect=Exception("Confirm failed"))

            async def mock_call_next(request):
                async def body_iterator():
                    yield b'{"result": "success"}'

                response = Response(content=b'{"result": "success"}', status_code=200)
                response.body_iterator = body_iterator()
                return response

            response = await credit_middleware.dispatch(mock_request, mock_call_next)

            # Ensure the failing path was actually exercised, otherwise the
            # status assertion below would pass trivially.
            mock_tm.confirm_credits.assert_called_once()
            assert response.status_code == 200