Spaces:
Sleeping
Sleeping
| """ | |
Comprehensive Tests for Core Dependencies
Tests cover:
1. get_current_user - JWT extraction & verification
2. check_rate_limit - Rate limiting function (smoke test)
3. get_geolocation - IP geolocation
get_optional_user (optional authentication) has no tests in this module yet.
Relies on the db_session fixture for database access; JWT verification and the
geolocation HTTP client are mocked.
| """ | |
| import pytest | |
| from unittest.mock import MagicMock, AsyncMock, patch | |
| from fastapi import HTTPException, Request | |
| # ============================================================================ | |
| # 1. get_current_user Tests | |
| # ============================================================================ | |
class TestGetCurrentUser:
    """Tests for the ``get_current_user`` dependency.

    ``verify_access_token`` is patched at the name ``core.dependencies`` looks
    up, so no real JWT verification happens in these tests.
    """

    # Patch the name where it is *used* (core.dependencies), not where it was
    # once defined. The previous "dependencies.verify_access_token" target
    # pointed at a different module, leaving the real verifier in place.
    VERIFY_TARGET = "core.dependencies.verify_access_token"

    @staticmethod
    def _make_request(auth_header):
        """Build a mock ``Request`` whose Authorization header is *auth_header*.

        Lookups are case-insensitive on the header name. Any other header, or
        a missing Authorization header (``auth_header=None``), returns the
        ``default`` passed to ``headers.get``.
        """
        request = MagicMock(spec=Request)

        def _get(key, default=None):
            if key.lower() == "authorization" and auth_header is not None:
                return auth_header
            return default

        request.headers.get.side_effect = _get
        return request

    async def test_valid_token_returns_user(self, db_session):
        """A valid bearer token resolves to the matching persisted user."""
        from core.dependencies import get_current_user
        from core.models import User

        db_session.add(
            User(user_id="usr_dep", email="dep@example.com", token_version=1)
        )
        await db_session.commit()

        request = self._make_request("Bearer valid_token_here")
        with patch(self.VERIFY_TARGET) as mock_verify:
            mock_verify.return_value = MagicMock(
                user_id="usr_dep",
                email="dep@example.com",
                token_version=1,
            )
            result = await get_current_user(request, db_session)

        assert mock_verify.called
        assert result.user_id == "usr_dep"
        assert result.email == "dep@example.com"

    async def test_missing_auth_header_raises_401(self, db_session):
        """A request without an Authorization header is rejected with 401."""
        from core.dependencies import get_current_user

        request = self._make_request(None)
        with pytest.raises(HTTPException) as exc_info:
            await get_current_user(request, db_session)

        assert exc_info.value.status_code == 401

    async def test_invalid_header_format_raises_401(self, db_session):
        """A non-Bearer Authorization header is rejected with 401."""
        from core.dependencies import get_current_user

        request = self._make_request("InvalidFormat token123")
        with pytest.raises(HTTPException) as exc_info:
            await get_current_user(request, db_session)

        assert exc_info.value.status_code == 401

    async def test_expired_token_raises_401(self, db_session):
        """An expired token (verifier raises TokenExpiredError) yields 401."""
        from core.dependencies import get_current_user
        from services.auth_service.jwt_provider import TokenExpiredError

        request = self._make_request("Bearer expired_token")
        with patch(self.VERIFY_TARGET) as mock_verify:
            mock_verify.side_effect = TokenExpiredError("Token expired")
            with pytest.raises(HTTPException) as exc_info:
                await get_current_user(request, db_session)

        # Without this, a wrong patch target could still produce a 401 from
        # the real verifier and the test would pass for the wrong reason.
        assert mock_verify.called
        assert exc_info.value.status_code == 401

    async def test_invalid_token_raises_401(self, db_session):
        """An invalid token (verifier raises InvalidTokenError) yields 401."""
        from core.dependencies import get_current_user
        from services.auth_service.jwt_provider import InvalidTokenError

        request = self._make_request("Bearer invalid_token")
        with patch(self.VERIFY_TARGET) as mock_verify:
            mock_verify.side_effect = InvalidTokenError("Invalid token")
            with pytest.raises(HTTPException) as exc_info:
                await get_current_user(request, db_session)

        assert mock_verify.called
        assert exc_info.value.status_code == 401

    async def test_token_version_mismatch_raises_401(self, db_session):
        """A token whose version is behind the user's (e.g. after logout) yields 401."""
        from core.dependencies import get_current_user
        from core.models import User

        # Stored user is on token_version=2 (simulates a logout bump).
        db_session.add(
            User(user_id="usr_logout", email="logout@example.com", token_version=2)
        )
        await db_session.commit()

        request = self._make_request("Bearer old_token")
        with patch(self.VERIFY_TARGET) as mock_verify:
            # The presented token still carries the old version.
            mock_verify.return_value = MagicMock(
                user_id="usr_logout",
                email="logout@example.com",
                token_version=1,
            )
            with pytest.raises(HTTPException) as exc_info:
                await get_current_user(request, db_session)

        assert mock_verify.called
        assert exc_info.value.status_code == 401
        assert "invalidated" in exc_info.value.detail.lower()
| # ============================================================================ | |
| # 2. Rate Limiting Tests (already covered in test_rate_limiting.py) | |
| # ============================================================================ | |
class TestRateLimitDependency:
    """Smoke test for the ``check_rate_limit`` dependency function.

    Detailed rate-limiting behaviour is covered in test_rate_limiting.py.
    """

    async def test_rate_limit_function_exists(self, db_session):
        """The first request for a fresh identifier/endpoint pair is allowed."""
        from core.dependencies import check_rate_limit

        allowed = await check_rate_limit(
            db=db_session,
            identifier="test_ip",
            endpoint="/test",
            limit=10,
            window_minutes=15,
        )

        # `is True` checks both the type (bool) and the value in one step;
        # `== True` is flagged by linters (E712) and accepts truthy non-bools.
        assert allowed is True
| # ============================================================================ | |
| # 3. Geolocation Tests | |
| # ============================================================================ | |
class TestGeolocation:
    """Tests for ``core.utils.get_geolocation``.

    The HTTP client is patched where ``core.utils`` resolves it, so tests that
    supply a fake response do not depend on the network.
    """

    # assumes core.utils does `import httpx` and calls httpx.AsyncClient()
    # per lookup — TODO confirm. The old "dependencies.httpx..." target pointed
    # at a module these tests no longer import get_geolocation from.
    CLIENT_TARGET = "core.utils.httpx.AsyncClient"

    @staticmethod
    def _session(mock_client_cls):
        """Return the mock yielded by ``async with AsyncClient() as client``."""
        return mock_client_cls.return_value.__aenter__.return_value

    @staticmethod
    def _assert_unknown(value):
        """Unresolvable lookups may report either ``None`` or ``"Unknown"``."""
        assert value is None or value == "Unknown"

    async def test_geolocation_with_valid_ip(self):
        """A successful API response is mapped to (country, region)."""
        from core.utils import get_geolocation

        with patch(self.CLIENT_TARGET) as mock_client:
            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_response.json.return_value = {
                "status": "success",
                "country": "United States",
                "regionName": "California",
            }
            self._session(mock_client).get.return_value = mock_response

            country, region = await get_geolocation("8.8.8.8")

        # Guards against a stale patch target silently falling back to a
        # real network lookup.
        assert mock_client.called
        assert country == "United States"
        assert region == "California"

    async def test_geolocation_with_invalid_ip(self):
        """A malformed IP string is handled without raising."""
        from core.utils import get_geolocation

        country, region = await get_geolocation("invalid_ip")

        # Either None or "Unknown" is accepted for both fields.
        self._assert_unknown(country)
        self._assert_unknown(region)

    async def test_geolocation_with_none_ip(self):
        """A ``None`` IP is handled without raising."""
        from core.utils import get_geolocation

        country, region = await get_geolocation(None)

        self._assert_unknown(country)
        self._assert_unknown(region)

    async def test_geolocation_api_failure(self):
        """An exception from the HTTP client is handled without raising."""
        from core.utils import get_geolocation

        with patch(self.CLIENT_TARGET) as mock_client:
            self._session(mock_client).get.side_effect = Exception("API Error")

            country, region = await get_geolocation("1.1.1.1")

        assert mock_client.called
        self._assert_unknown(country)
        self._assert_unknown(region)
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly exits
    # non-zero when any test fails (pytest.main returns the code, not exits).
    raise SystemExit(pytest.main([__file__, "-v"]))