Spaces:
Sleeping
Sleeping
File size: 8,702 Bytes
da9494d a42ab7e da9494d a42ab7e da9494d a42ab7e da9494d a42ab7e da9494d a42ab7e da9494d a42ab7e da9494d a42ab7e da9494d a42ab7e da9494d a42ab7e da9494d a42ab7e da9494d a42ab7e da9494d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
"""
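Comprehensive Tests for Core Dependencies
Tests cover:
1. get_current_user - JWT extraction & verification
2. check_rate_limit - Rate limiting function
3. get_geolocation (core.utils) - IP geolocation
get_optional_user (optional authentication) is not yet exercised by the tests below.
External services (JWT verification, the geolocation HTTP client) are mocked with
unittest.mock; database access goes through the db_session fixture.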
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from fastapi import HTTPException, Request
# ============================================================================
# 1. get_current_user Tests
# ============================================================================
class TestGetCurrentUser:
    """Test the get_current_user dependency.

    ``verify_access_token`` is patched in the module that ``get_current_user``
    is imported from, so no real JWT decoding takes place.
    """

    # Patch where the name is *looked up*, not where it is defined: the code
    # under test is imported from core.dependencies, so patching a top-level
    # ``dependencies`` module would leave the real verifier in place.
    # NOTE(review): assumes core.dependencies imports verify_access_token by
    # name at module level — confirm.
    _VERIFY_TARGET = "core.dependencies.verify_access_token"

    @staticmethod
    def _make_request(auth_header):
        """Return a Request mock whose ``headers.get()`` yields *auth_header*."""
        request = MagicMock(spec=Request)
        request.headers.get.return_value = auth_header
        return request

    @pytest.mark.asyncio
    async def test_valid_token_returns_user(self, db_session):
        """Valid JWT token returns authenticated user."""
        from core.dependencies import get_current_user
        from core.models import User

        user = User(user_id="usr_dep", email="dep@example.com", token_version=1)
        db_session.add(user)
        await db_session.commit()

        mock_request = self._make_request("Bearer valid_token_here")
        with patch(self._VERIFY_TARGET) as mock_verify:
            mock_verify.return_value = MagicMock(
                user_id="usr_dep",
                email="dep@example.com",
                token_version=1,
            )
            result = await get_current_user(mock_request, db_session)

        # Proves the patch hit the namespace get_current_user actually uses.
        mock_verify.assert_called_once()
        assert result.user_id == "usr_dep"
        assert result.email == "dep@example.com"

    @pytest.mark.asyncio
    async def test_missing_auth_header_raises_401(self, db_session):
        """Missing Authorization header raises 401."""
        from core.dependencies import get_current_user

        mock_request = self._make_request(None)
        with pytest.raises(HTTPException) as exc_info:
            await get_current_user(mock_request, db_session)
        assert exc_info.value.status_code == 401

    @pytest.mark.asyncio
    async def test_invalid_header_format_raises_401(self, db_session):
        """Invalid Authorization header format raises 401."""
        from core.dependencies import get_current_user

        mock_request = self._make_request("InvalidFormat token123")
        with pytest.raises(HTTPException) as exc_info:
            await get_current_user(mock_request, db_session)
        assert exc_info.value.status_code == 401

    @pytest.mark.asyncio
    async def test_expired_token_raises_401(self, db_session):
        """Expired JWT token raises 401."""
        from core.dependencies import get_current_user
        from services.auth_service.jwt_provider import TokenExpiredError

        mock_request = self._make_request("Bearer expired_token")
        with patch(self._VERIFY_TARGET) as mock_verify:
            mock_verify.side_effect = TokenExpiredError("Token expired")
            with pytest.raises(HTTPException) as exc_info:
                await get_current_user(mock_request, db_session)
        assert exc_info.value.status_code == 401

    @pytest.mark.asyncio
    async def test_invalid_token_raises_401(self, db_session):
        """Invalid JWT token raises 401."""
        from core.dependencies import get_current_user
        from services.auth_service.jwt_provider import InvalidTokenError

        mock_request = self._make_request("Bearer invalid_token")
        with patch(self._VERIFY_TARGET) as mock_verify:
            mock_verify.side_effect = InvalidTokenError("Invalid token")
            with pytest.raises(HTTPException) as exc_info:
                await get_current_user(mock_request, db_session)
        assert exc_info.value.status_code == 401

    @pytest.mark.asyncio
    async def test_token_version_mismatch_raises_401(self, db_session):
        """Mismatched token version (after logout) raises 401."""
        from core.dependencies import get_current_user
        from core.models import User

        # Stored user is at token_version=2 (logged out since the token was issued).
        user = User(user_id="usr_logout", email="logout@example.com", token_version=2)
        db_session.add(user)
        await db_session.commit()

        mock_request = self._make_request("Bearer old_token")
        with patch(self._VERIFY_TARGET) as mock_verify:
            # The presented token still carries the old version.
            mock_verify.return_value = MagicMock(
                user_id="usr_logout",
                email="logout@example.com",
                token_version=1,
            )
            with pytest.raises(HTTPException) as exc_info:
                await get_current_user(mock_request, db_session)

        assert exc_info.value.status_code == 401
        assert "invalidated" in exc_info.value.detail.lower()
# ============================================================================
# 2. Rate Limiting Tests (already covered in test_rate_limiting.py)
# ============================================================================
class TestRateLimitDependency:
    """Smoke-test the check_rate_limit dependency function.

    Detailed rate-limiting behaviour is covered in test_rate_limiting.py.
    """

    @pytest.mark.asyncio
    async def test_rate_limit_function_exists(self, db_session):
        """check_rate_limit is importable and allows the first request."""
        from core.dependencies import check_rate_limit

        result = await check_rate_limit(
            db=db_session,
            identifier="test_ip",
            endpoint="/test",
            limit=10,
            window_minutes=15,
        )
        assert isinstance(result, bool)
        # First request for this identifier must be allowed.
        assert result is True
# ============================================================================
# 3. Geolocation Tests
# ============================================================================
class TestGeolocation:
    """Test IP geolocation via core.utils.get_geolocation.

    The HTTP client is patched in every test so the suite never performs a
    real network request.
    """

    # get_geolocation is imported from core.utils, so patch the client there
    # (a top-level ``dependencies`` module is not what the code under test uses).
    # NOTE(review): assumes core.utils does ``import httpx`` — confirm.
    _CLIENT_TARGET = "core.utils.httpx.AsyncClient"

    @staticmethod
    def _client_session(mock_client):
        """Return the object yielded by ``async with httpx.AsyncClient() as c``."""
        return mock_client.return_value.__aenter__.return_value

    @classmethod
    def _fail_requests(cls, mock_client):
        """Make any awaited ``client.get`` raise, simulating an unreachable API."""
        cls._client_session(mock_client).get = AsyncMock(
            side_effect=Exception("API Error")
        )

    @pytest.mark.asyncio
    async def test_geolocation_with_valid_ip(self):
        """Get geolocation for valid IP address."""
        from core.utils import get_geolocation

        with patch(self._CLIENT_TARGET) as mock_client:
            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_response.json.return_value = {
                "status": "success",
                "country": "United States",
                "regionName": "California",
            }
            self._client_session(mock_client).get = AsyncMock(
                return_value=mock_response
            )
            country, region = await get_geolocation("8.8.8.8")

        assert country == "United States"
        assert region == "California"

    @pytest.mark.asyncio
    async def test_geolocation_with_invalid_ip(self):
        """Handle invalid IP gracefully."""
        from core.utils import get_geolocation

        # Block the network: if input validation is missing, the lookup must
        # still degrade gracefully instead of hitting the real API.
        with patch(self._CLIENT_TARGET) as mock_client:
            self._fail_requests(mock_client)
            country, region = await get_geolocation("invalid_ip")

        assert country is None or country == "Unknown"
        assert region is None or region == "Unknown"

    @pytest.mark.asyncio
    async def test_geolocation_with_none_ip(self):
        """Handle None IP gracefully."""
        from core.utils import get_geolocation

        with patch(self._CLIENT_TARGET) as mock_client:
            self._fail_requests(mock_client)
            country, region = await get_geolocation(None)

        assert country is None or country == "Unknown"
        assert region is None or region == "Unknown"

    @pytest.mark.asyncio
    async def test_geolocation_api_failure(self):
        """Handle API failure gracefully."""
        from core.utils import get_geolocation

        with patch(self._CLIENT_TARGET) as mock_client:
            self._fail_requests(mock_client)
            country, region = await get_geolocation("1.1.1.1")

        # Both fields must fall back, matching the other graceful-failure tests.
        assert country is None or country == "Unknown"
        assert region is None or region == "Unknown"
if __name__ == "__main__":
    # Propagate pytest's exit code so a direct run exits non-zero on failure.
    raise SystemExit(pytest.main([__file__, "-v"]))
|