# apigateway/tests/test_credit_middleware_integration.py
# Author: jebin2 - commit a295e63: "Add comprehensive test suite for credit service"
"""
Integration Test Suite for Credit Middleware
Tests the complete middleware flow including:
- Request interception
- Credit reservation
- Response inspection
- Automatic confirmation/refund
"""
import pytest
import json
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from services.credit_service.middleware import CreditMiddleware
from services.credit_service.config import CreditServiceConfig
from core.models import User
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def mock_user():
    """Provide a ``User`` test double holding 100 credits.

    The double is spec'd against ``User``, so reading an attribute that
    ``User`` does not define raises ``AttributeError``.
    """
    fake_user = MagicMock(spec=User)
    fake_user.configure_mock(id=1, user_id="test_user_123", credits=100)
    return fake_user
@pytest.fixture
def mock_request(mock_user):
    """Build a ``Request`` double for an authenticated POST to analyze-image.

    ``state.user`` is the ``mock_user`` fixture and
    ``state.credit_transaction_id`` starts out as ``None``.
    """
    req = MagicMock(spec=Request)
    req.method = "POST"
    req.url.path = "/gemini/analyze-image"
    req.client.host = "127.0.0.1"
    req.headers = {"user-agent": "test"}

    req_state = req.state
    req_state.user = mock_user
    req_state.credit_transaction_id = None
    return req
@pytest.fixture
def credit_middleware():
    """Register the test route table and return a new ``CreditMiddleware``.

    Registered routes:
    - analyze-image: 1 credit, sync
    - generate-video: 10 credits, async
    - job status: 0 credits, async
    - one free endpoint

    The middleware is constructed around a bare ``MagicMock`` (presumably
    standing in for the wrapped app).
    """
    # NOTE(review): register() appears to update shared config and nothing
    # here resets it between tests - confirm whether a teardown is needed.
    test_routes = {
        "/gemini/analyze-image": {"cost": 1, "type": "sync"},
        "/gemini/generate-video": {"cost": 10, "type": "async"},
        "/gemini/job/{job_id}": {"cost": 0, "type": "async"},
        "/free-endpoint": {"cost": 0, "type": "free"},
    }
    CreditServiceConfig.register(route_configs=test_routes)
    return CreditMiddleware(MagicMock())
# =============================================================================
# Free Endpoint Tests
# =============================================================================
@pytest.mark.asyncio
async def test_free_endpoint_no_credit_check(credit_middleware, mock_request):
    """Test that free endpoints bypass credit middleware.

    The downstream 200 must be returned, and no credit transaction id may be
    recorded on the request.
    """
    mock_request.url.path = "/free-endpoint"

    async def mock_call_next(request):
        return Response(content="OK", status_code=200)

    response = await credit_middleware.dispatch(mock_request, mock_call_next)
    assert response.status_code == 200
    # ``mock_request.state`` is a MagicMock, so ``hasattr`` on it is always
    # True. The fixture also pre-sets the attribute to None. Check the value
    # instead: a free endpoint must leave it unset (still None).
    assert mock_request.state.credit_transaction_id is None
@pytest.mark.asyncio
async def test_options_request_bypass(credit_middleware, mock_request):
    """An OPTIONS request is handed straight to the downstream handler.

    Nothing credit-related is patched here; the handler's 204 status must
    come back from ``dispatch``.
    """
    async def downstream(request):
        return Response(status_code=204)

    mock_request.method = "OPTIONS"
    result = await credit_middleware.dispatch(mock_request, downstream)
    assert result.status_code == 204
# =============================================================================
# Unauthenticated Request Tests
# =============================================================================
@pytest.mark.asyncio
async def test_unauthenticated_request(credit_middleware, mock_request):
    """A request whose ``state.user`` is ``None`` is answered with 401.

    ``async_session_maker`` is patched so no real database session is opened.
    """
    async def downstream(request):
        return Response(status_code=200)

    mock_request.state.user = None
    with patch('services.credit_service.middleware.async_session_maker'):
        result = await credit_middleware.dispatch(mock_request, downstream)
        assert result.status_code == status.HTTP_401_UNAUTHORIZED
# =============================================================================
# Credit Reservation Tests
# =============================================================================
@pytest.mark.asyncio
async def test_successful_credit_reservation(credit_middleware, mock_request):
    """Test successful credit reservation on request.

    The analyze-image route is registered at a cost of 1. The middleware
    must call ``reserve_credits`` exactly once with ``amount=1`` and still
    return the downstream 200.
    """
    body = b'{"result": "success"}'

    # Mock database session and transaction manager
    with patch('services.credit_service.middleware.async_session_maker') as mock_session:
        mock_db = AsyncMock()
        mock_session.return_value.__aenter__.return_value = mock_db

        with patch('services.credit_service.middleware.CreditTransactionManager') as mock_tm:
            mock_transaction = MagicMock()
            mock_transaction.transaction_id = "ctx_test123"
            mock_tm.reserve_credits = AsyncMock(return_value=mock_transaction)
            # A 200 on this sync route makes the middleware await
            # confirm_credits (pinned by test_sync_success_confirms_credits).
            # A plain MagicMock child is not awaitable. Without this line the
            # await raises TypeError, and the test only passes thanks to the
            # middleware's response-phase error handling.
            mock_tm.confirm_credits = AsyncMock()

            async def mock_call_next(request):
                # Simulate the streamed response body.
                async def body_iterator():
                    yield body

                response = Response(content=body, status_code=200)
                response.body_iterator = body_iterator()
                return response

            response = await credit_middleware.dispatch(mock_request, mock_call_next)

            # Verify reserve_credits was called with the route's cost
            mock_tm.reserve_credits.assert_called_once()
            call_args = mock_tm.reserve_credits.call_args
            assert call_args.kwargs['amount'] == 1  # 1 credit for analyze-image
            assert response.status_code == 200
# =============================================================================
# Insufficient Credits Tests
# =============================================================================
@pytest.mark.asyncio
async def test_insufficient_credits(credit_middleware, mock_request):
    """A reservation failing with ``InsufficientCreditsError`` yields HTTP 402.

    The JSON body's ``detail`` field must mention "Insufficient credits".
    """
    from services.credit_service.transaction_manager import InsufficientCreditsError

    async def downstream(request):
        return Response(status_code=200)

    session_patch = patch('services.credit_service.middleware.async_session_maker')
    manager_patch = patch('services.credit_service.middleware.CreditTransactionManager')
    with session_patch as session_factory, manager_patch as manager:
        session_factory.return_value.__aenter__.return_value = AsyncMock()
        # Reservation is rejected outright.
        manager.reserve_credits = AsyncMock(
            side_effect=InsufficientCreditsError("Not enough credits")
        )

        result = await credit_middleware.dispatch(mock_request, downstream)

        assert result.status_code == status.HTTP_402_PAYMENT_REQUIRED
        payload = json.loads(result.body.decode())
        assert "Insufficient credits" in payload["detail"]
# =============================================================================
# Response Inspection Tests - Sync Endpoints
# =============================================================================
@pytest.mark.asyncio
async def test_sync_success_confirms_credits(credit_middleware, mock_request):
    """A 200 from a sync route finalizes the reservation via ``confirm_credits``.

    The reservation stub returns a transaction with id ``ctx_test123``. The
    downstream handler replies 200 with a small JSON body.
    """
    payload = b'{"result": "image analyzed"}'

    async def downstream(request):
        async def stream():
            yield payload

        reply = Response(content=payload, status_code=200)
        reply.body_iterator = stream()
        return reply

    reserved = MagicMock()
    reserved.transaction_id = "ctx_test123"

    session_patch = patch('services.credit_service.middleware.async_session_maker')
    manager_patch = patch('services.credit_service.middleware.CreditTransactionManager')
    with session_patch as session_factory, manager_patch as manager:
        session_factory.return_value.__aenter__.return_value = AsyncMock()
        manager.reserve_credits = AsyncMock(return_value=reserved)
        manager.confirm_credits = AsyncMock()

        await credit_middleware.dispatch(mock_request, downstream)

        manager.confirm_credits.assert_called_once()
@pytest.mark.asyncio
async def test_sync_failure_refunds_credits(credit_middleware, mock_request):
    """A 400 from a sync route returns the reserved credits via ``refund_credits``.

    The reservation stub returns a transaction with id ``ctx_test123``. The
    downstream handler replies 400 with a JSON ``detail`` body.
    """
    error_payload = b'{"detail": "Invalid image"}'

    async def downstream(request):
        async def stream():
            yield error_payload

        reply = Response(content=error_payload, status_code=400)
        reply.body_iterator = stream()
        return reply

    reserved = MagicMock()
    reserved.transaction_id = "ctx_test123"

    session_patch = patch('services.credit_service.middleware.async_session_maker')
    manager_patch = patch('services.credit_service.middleware.CreditTransactionManager')
    with session_patch as session_factory, manager_patch as manager:
        session_factory.return_value.__aenter__.return_value = AsyncMock()
        manager.reserve_credits = AsyncMock(return_value=reserved)
        manager.refund_credits = AsyncMock()

        await credit_middleware.dispatch(mock_request, downstream)

        manager.refund_credits.assert_called_once()
# =============================================================================
# Response Inspection Tests - Async Endpoints
# =============================================================================
@pytest.mark.asyncio
async def test_async_job_creation_keeps_reserved(credit_middleware, mock_request):
    """Queuing an async job leaves the reservation open.

    The downstream reply reports a queued job. The middleware must call
    neither ``confirm_credits`` nor ``refund_credits`` for it.
    """
    queued_payload = b'{"job_id": "job_abc", "status": "queued"}'

    async def downstream(request):
        async def stream():
            yield queued_payload

        reply = Response(content=queued_payload, status_code=200)
        reply.body_iterator = stream()
        return reply

    mock_request.url.path = "/gemini/generate-video"
    reserved = MagicMock()
    reserved.transaction_id = "ctx_test123"

    session_patch = patch('services.credit_service.middleware.async_session_maker')
    manager_patch = patch('services.credit_service.middleware.CreditTransactionManager')
    with session_patch as session_factory, manager_patch as manager:
        session_factory.return_value.__aenter__.return_value = AsyncMock()
        manager.reserve_credits = AsyncMock(return_value=reserved)
        manager.confirm_credits = AsyncMock()
        manager.refund_credits = AsyncMock()

        await credit_middleware.dispatch(mock_request, downstream)

        # Credits stay reserved: no settlement in either direction yet.
        manager.confirm_credits.assert_not_called()
        manager.refund_credits.assert_not_called()
@pytest.mark.asyncio
async def test_async_job_completed_confirms_credits(credit_middleware, mock_request):
    """Test a job-status poll whose response reports a completed job.

    The job-status route is registered with cost 0. No reservation may be
    made for this request, and the downstream 200 must be passed through.

    NOTE(review): despite its name, this test does not verify that
    ``confirm_credits`` runs for the job's original reservation. That
    presumably requires stubbing however the middleware links a job to its
    reservation - TODO.
    """
    mock_request.url.path = "/gemini/job/job_abc"
    body = b'{"job_id": "job_abc", "status": "completed", "video_url": "..."}'

    with patch('services.credit_service.middleware.async_session_maker') as mock_session:
        mock_db = AsyncMock()
        mock_session.return_value.__aenter__.return_value = mock_db

        with patch('services.credit_service.middleware.CreditTransactionManager') as mock_tm:
            # Status checks cost 0, so reserve_credits must not be invoked.
            mock_tm.reserve_credits = AsyncMock()
            mock_tm.confirm_credits = AsyncMock()

            # Mock completed job response
            async def mock_call_next(request):
                async def body_iterator():
                    yield body

                response = Response(content=body, status_code=200)
                response.body_iterator = body_iterator()
                return response

            response = await credit_middleware.dispatch(mock_request, mock_call_next)

            assert response.status_code == 200
            mock_tm.reserve_credits.assert_not_called()
# =============================================================================
# Error Handling Tests
# =============================================================================
@pytest.mark.asyncio
async def test_database_error_during_reservation(credit_middleware, mock_request):
    """An unexpected exception from ``reserve_credits`` surfaces as HTTP 500."""
    async def downstream(request):
        return Response(status_code=200)

    session_patch = patch('services.credit_service.middleware.async_session_maker')
    manager_patch = patch('services.credit_service.middleware.CreditTransactionManager')
    with session_patch as session_factory, manager_patch as manager:
        session_factory.return_value.__aenter__.return_value = AsyncMock()
        # Reservation blows up with a generic (non-credit) error.
        manager.reserve_credits = AsyncMock(side_effect=Exception("DB connection failed"))

        result = await credit_middleware.dispatch(mock_request, downstream)

        assert result.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_response_phase_error_doesnt_fail_request(credit_middleware, mock_request):
    """Test that errors in response phase don't break the actual response.

    ``confirm_credits`` is made to raise. The test checks both that the
    failing call was actually reached and that the downstream 200 is still
    returned.
    """
    body = b'{"result": "success"}'

    with patch('services.credit_service.middleware.async_session_maker') as mock_session:
        mock_db = AsyncMock()
        mock_session.return_value.__aenter__.return_value = mock_db

        with patch('services.credit_service.middleware.CreditTransactionManager') as mock_tm:
            mock_transaction = MagicMock()
            mock_transaction.transaction_id = "ctx_test123"
            mock_tm.reserve_credits = AsyncMock(return_value=mock_transaction)
            # Confirm will fail, but response should still be returned
            mock_tm.confirm_credits = AsyncMock(side_effect=Exception("Confirm failed"))

            async def mock_call_next(request):
                async def body_iterator():
                    yield body

                response = Response(content=body, status_code=200)
                response.body_iterator = body_iterator()
                return response

            response = await credit_middleware.dispatch(mock_request, mock_call_next)

            # Without this the test could pass vacuously if the response
            # phase never reached confirm_credits.
            mock_tm.confirm_credits.assert_called_once()
            # Response should still be 200 even though confirm failed
            assert response.status_code == 200