Spaces:
Sleeping
Sleeping
| """ | |
| Test Suite for Auth Service | |
| Comprehensive tests for the authentication service including: | |
| - JWT token creation and verification | |
| - Token expiry validation | |
| - Token version checking (logout/invalidation) | |
| - Google OAuth token verification (mocked) | |
| - Error handling | |
| """ | |
| import pytest | |
| import os | |
| from datetime import datetime, timedelta | |
| from unittest.mock import patch, MagicMock | |
| from services.auth_service.jwt_provider import ( | |
| JWTService, | |
| TokenPayload, | |
| create_access_token, | |
| create_refresh_token, | |
| verify_access_token, | |
| TokenExpiredError, | |
| InvalidTokenError, | |
| ConfigurationError, | |
| get_jwt_service | |
| ) | |
| from services.auth_service.google_provider import ( | |
| GoogleAuthService, | |
| GoogleUserInfo, | |
| InvalidTokenError as GoogleInvalidTokenError, | |
| ConfigurationError as GoogleConfigError, | |
| get_google_auth_service | |
| ) | |
| # ============================================================================ | |
| # Fixtures | |
| # ============================================================================ | |
@pytest.fixture
def jwt_secret():
    """Provide a test JWT secret.

    Registered as a pytest fixture so tests and other fixtures can request it
    by parameter name. The value is long (57 characters), presumably well
    above whatever minimum length JWTService warns about.
    """
    return "test-secret-key-for-testing-only-do-not-use-in-production"
@pytest.fixture
def jwt_service(jwt_secret):
    """Create a JWTService instance for testing.

    Args:
        jwt_secret: The secret provided by the ``jwt_secret`` fixture.

    Returns:
        A JWTService using HS256, a 15-minute access-token lifetime and a
        7-day refresh-token lifetime. The expiry tests in this module
        assert against these exact values.
    """
    return JWTService(
        secret_key=jwt_secret,
        algorithm="HS256",
        access_expiry_minutes=15,
        refresh_expiry_days=7,
    )
@pytest.fixture
def google_client_id():
    """Provide a test Google OAuth client ID.

    Registered as a pytest fixture so tests can request it by parameter name.
    """
    return "test-google-client-id.apps.googleusercontent.com"
@pytest.fixture
def mock_google_user_info():
    """Provide a canned GoogleUserInfo for the Google OAuth tests.

    The Google tests feed these field values through a mocked verifier
    response and compare the result against this object.
    """
    return GoogleUserInfo(
        google_id="12345678901234567890",
        email="test@example.com",
        name="Test User",
        picture="https://example.com/photo.jpg",
    )
| # ============================================================================ | |
| # JWT Service Tests | |
| # ============================================================================ | |
class TestJWTService:
    """Test JWTService construction and configuration sources."""

    def test_service_initialization(self, jwt_secret):
        """Test that JWT service initializes correctly from explicit arguments."""
        service = JWTService(
            secret_key=jwt_secret,
            algorithm="HS256",
            access_expiry_minutes=15,
            refresh_expiry_days=7,
        )
        assert service.secret_key == jwt_secret
        assert service.algorithm == "HS256"
        assert service.access_expiry_minutes == 15
        assert service.refresh_expiry_days == 7

    def test_service_requires_secret(self, monkeypatch):
        """Test that service requires a secret key."""
        # Clear environment variable so it can't fall back to env
        monkeypatch.delenv("JWT_SECRET", raising=False)
        with pytest.raises(ConfigurationError) as exc_info:
            JWTService(secret_key=None)  # None and no env var
        assert "secret" in str(exc_info.value).lower()

    def test_service_warns_short_secret(self, caplog):
        """Test that service warns about short secret keys."""
        # Only the logged warning matters here; the instance itself is unused.
        JWTService(secret_key="short")
        log_text = caplog.text.lower()
        assert "short" in log_text or "32 chars" in log_text

    def test_service_from_env(self, monkeypatch, jwt_secret):
        """Test that service reads config from environment."""
        monkeypatch.setenv("JWT_SECRET", jwt_secret)
        monkeypatch.setenv("JWT_ALGORITHM", "HS512")
        monkeypatch.setenv("JWT_ACCESS_EXPIRY_MINUTES", "30")
        monkeypatch.setenv("JWT_REFRESH_EXPIRY_DAYS", "14")
        service = JWTService()
        assert service.secret_key == jwt_secret
        assert service.algorithm == "HS512"
        assert service.access_expiry_minutes == 30
        assert service.refresh_expiry_days == 14
class TestAccessTokenCreation:
    """Exercise access-token issuance via create_access_token and create_token."""

    def test_create_access_token(self, jwt_service):
        """An issued access token is a non-empty compact JWT string."""
        encoded = jwt_service.create_access_token(
            user_id="usr_123", email="test@example.com", token_version=1
        )
        assert isinstance(encoded, str)
        assert encoded != ""
        # Compact JWT serialization has three dot-separated segments:
        # header.payload.signature
        assert len(encoded.split(".")) == 3

    def test_access_token_payload(self, jwt_service):
        """Decoding an access token yields the claims it was issued with."""
        encoded = jwt_service.create_access_token(
            user_id="usr_123", email="test@example.com", token_version=1
        )
        claims = jwt_service.verify_token(encoded)
        assert claims.user_id == "usr_123"
        assert claims.email == "test@example.com"
        assert claims.token_version == 1
        assert claims.token_type == "access"

    def test_access_token_expiry(self, jwt_service):
        """Access tokens expire after the 15 minutes the jwt_service fixture configures."""
        lifetime = timedelta(minutes=15)
        slack = timedelta(seconds=1)
        started = datetime.utcnow()
        encoded = jwt_service.create_access_token(
            user_id="usr_123", email="test@example.com"
        )
        finished = datetime.utcnow()
        expires_at = jwt_service.verify_token(encoded).expires_at
        # Bracket the expiry between the two clock readings, widened by one
        # second on each side to tolerate execution time.
        lower = started + lifetime - slack
        upper = finished + lifetime + slack
        assert lower <= expires_at <= upper

    def test_access_token_custom_expiry(self, jwt_service):
        """An explicit expiry_delta overrides the default access lifetime."""
        claims = jwt_service.verify_token(
            jwt_service.create_token(
                user_id="usr_123",
                email="test@example.com",
                token_type="access",
                expiry_delta=timedelta(hours=1),
            )
        )
        lifetime_seconds = (claims.expires_at - claims.issued_at).total_seconds()
        # Roughly one hour, give or take ten seconds.
        assert 3590 <= lifetime_seconds <= 3610

    def test_access_token_extra_claims(self, jwt_service):
        """Extra claims passed at issue time are exposed on payload.extra."""
        encoded = jwt_service.create_token(
            user_id="usr_123",
            email="test@example.com",
            token_type="access",
            extra_claims={"role": "admin", "org": "test_org"},
        )
        extra = jwt_service.verify_token(encoded).extra
        assert extra.get("role") == "admin"
        assert extra.get("org") == "test_org"
class TestRefreshTokenCreation:
    """Exercise refresh-token issuance via create_refresh_token."""

    def test_create_refresh_token(self, jwt_service):
        """A refresh token comes back as a non-empty string."""
        refresh = jwt_service.create_refresh_token(
            user_id="usr_123", email="test@example.com", token_version=1
        )
        assert isinstance(refresh, str)
        assert refresh != ""

    def test_refresh_token_type(self, jwt_service):
        """Decoding a refresh token reports token_type == "refresh"."""
        refresh = jwt_service.create_refresh_token(
            user_id="usr_123", email="test@example.com"
        )
        assert jwt_service.verify_token(refresh).token_type == "refresh"

    def test_refresh_token_longer_expiry(self, jwt_service):
        """A refresh token expires roughly seven days after it is issued."""
        seven_days = timedelta(days=7).total_seconds()
        issued_around = datetime.utcnow()
        refresh = jwt_service.create_refresh_token(
            user_id="usr_123", email="test@example.com"
        )
        remaining = jwt_service.verify_token(refresh).expires_at - issued_around
        # Allow up to ten seconds of drift for execution time.
        assert abs(remaining.total_seconds() - seven_days) < 10
class TestTokenVerification:
    """Exercise JWTService.verify_token on valid and invalid input."""

    def test_verify_valid_token(self, jwt_service):
        """A freshly issued token verifies and round-trips its identity claims."""
        issued = jwt_service.create_access_token(
            user_id="usr_123", email="test@example.com"
        )
        decoded = jwt_service.verify_token(issued)
        assert decoded.user_id == "usr_123"
        assert decoded.email == "test@example.com"

    def test_verify_empty_token(self, jwt_service):
        """An empty string is rejected with an error that mentions 'empty'."""
        with pytest.raises(InvalidTokenError) as raised:
            jwt_service.verify_token("")
        assert "empty" in str(raised.value).lower()

    def test_verify_malformed_token(self, jwt_service):
        """A five-segment string that is not a JWT is rejected."""
        with pytest.raises(InvalidTokenError):
            jwt_service.verify_token("not.a.valid.jwt.token")

    def test_verify_tampered_token(self, jwt_service):
        """A token whose payload segment was edited is rejected."""
        segments = jwt_service.create_access_token(
            user_id="usr_123", email="test@example.com"
        ).split(".")
        # Overwrite the last five characters of the payload segment.
        segments[1] = f"{segments[1][:-5]}AAAAA"
        with pytest.raises(InvalidTokenError):
            jwt_service.verify_token(".".join(segments))

    def test_verify_token_wrong_secret(self, jwt_service):
        """A service holding a different secret cannot verify the token."""
        issued = jwt_service.create_access_token(
            user_id="usr_123", email="test@example.com"
        )
        imposter = JWTService(secret_key="different-secret")
        with pytest.raises(InvalidTokenError):
            imposter.verify_token(issued)
class TestTokenExpiry:
    """Test token expiry behavior."""

    # "Expired" tokens are issued five minutes in the past rather than one
    # second, so these tests do not depend on the provider using zero
    # clock-skew leeway or on how it rounds timestamps.
    _EXPIRED_DELTA = timedelta(minutes=-5)

    def test_expired_token_raises_error(self, jwt_service):
        """Test that an expired token raises TokenExpiredError."""
        token = jwt_service.create_token(
            user_id="usr_123",
            email="test@example.com",
            token_type="access",
            expiry_delta=self._EXPIRED_DELTA,  # already expired
        )
        with pytest.raises(TokenExpiredError) as exc_info:
            jwt_service.verify_token(token)
        assert "expired" in str(exc_info.value).lower()

    def test_token_not_expired_yet(self, jwt_service):
        """Test that a non-expired token verifies successfully."""
        token = jwt_service.create_token(
            user_id="usr_123",
            email="test@example.com",
            token_type="access",
            expiry_delta=timedelta(hours=1),
        )
        # Should not raise
        payload = jwt_service.verify_token(token)
        assert payload.user_id == "usr_123"
        assert not payload.is_expired

    def test_token_expiry_property(self, jwt_service):
        """Test TokenPayload.is_expired on an already-expired token."""
        import jwt as pyjwt  # PyJWT, used to decode without the exp check

        token = jwt_service.create_token(
            user_id="usr_123",
            email="test@example.com",
            expiry_delta=self._EXPIRED_DELTA,
        )
        # verify_token would raise for this token, so decode it directly with
        # expiry verification disabled and rebuild the payload by hand.
        claims = pyjwt.decode(
            token,
            jwt_service.secret_key,
            algorithms=[jwt_service.algorithm],
            options={"verify_exp": False},
        )
        payload = TokenPayload(
            user_id=claims["sub"],
            email=claims["email"],
            issued_at=datetime.utcfromtimestamp(claims["iat"]),
            expires_at=datetime.utcfromtimestamp(claims["exp"]),
            token_version=claims.get("tv", 1),
            token_type=claims.get("type", "access"),
        )
        assert payload.is_expired is True
class TestTokenVersion:
    """Check that the token-version claim survives an issue/verify round trip."""

    @staticmethod
    def _issue_and_decode(service, **version_kwargs):
        """Issue an access token for the standard test user and decode it."""
        token = service.create_access_token(
            user_id="usr_123", email="test@example.com", **version_kwargs
        )
        return service.verify_token(token)

    def test_token_version_in_payload(self, jwt_service):
        """An explicit token_version is carried through to the payload."""
        decoded = self._issue_and_decode(jwt_service, token_version=5)
        assert decoded.token_version == 5

    def test_default_token_version(self, jwt_service):
        """Without an explicit token_version the payload reports version 1."""
        decoded = self._issue_and_decode(jwt_service)
        assert decoded.token_version == 1
class TestConvenienceFunctions:
    """Test module-level convenience functions."""

    @staticmethod
    def _use_fresh_default_service(monkeypatch, secret):
        """Set JWT_SECRET and clear the module's cached default JWTService.

        Both changes go through ``monkeypatch``, which undoes them when the
        test finishes. This keeps the singleton built during one test from
        leaking into later tests. The previous plain assignment to
        ``_default_service`` was never restored.
        """
        import services.auth_service.jwt_provider as jwt_module

        monkeypatch.setenv("JWT_SECRET", secret)
        # raising=False mirrors the old plain assignment, which also worked
        # if the attribute did not exist yet.
        monkeypatch.setattr(jwt_module, "_default_service", None, raising=False)

    def test_create_access_token_function(self, monkeypatch, jwt_secret):
        """Test create_access_token convenience function."""
        self._use_fresh_default_service(monkeypatch, jwt_secret)
        token = create_access_token(user_id="usr_123", email="test@example.com")
        assert isinstance(token, str)
        assert len(token) > 0

    def test_create_refresh_token_function(self, monkeypatch, jwt_secret):
        """Test create_refresh_token convenience function."""
        self._use_fresh_default_service(monkeypatch, jwt_secret)
        token = create_refresh_token(user_id="usr_123", email="test@example.com")
        assert isinstance(token, str)
        # verify_token returns a TokenPayload, not a dict.
        payload = get_jwt_service().verify_token(token)
        assert payload.token_type == "refresh"

    def test_verify_access_token_function(self, monkeypatch, jwt_secret):
        """Test verify_access_token convenience function."""
        self._use_fresh_default_service(monkeypatch, jwt_secret)
        token = create_access_token(user_id="usr_123", email="test@example.com")
        payload = verify_access_token(token)
        assert payload.user_id == "usr_123"

    def test_get_jwt_service_singleton(self, monkeypatch, jwt_secret):
        """Test that get_jwt_service returns singleton."""
        self._use_fresh_default_service(monkeypatch, jwt_secret)
        service1 = get_jwt_service()
        service2 = get_jwt_service()
        assert service1 is service2  # Same instance
| # ============================================================================ | |
| # Google OAuth Tests | |
| # ============================================================================ | |
class TestGoogleAuthService:
    """Test Google OAuth integration."""

    @pytest.fixture
    def mock_verify(self):
        """Patch Google's ID-token verifier for the duration of one test.

        The verification tests below request ``mock_verify``, but no fixture
        or ``@patch`` decorator in this module supplies it, so they would
        error during setup.

        NOTE(review): assumes google_provider calls
        ``google.oauth2.id_token.verify_oauth2_token`` through the module
        attribute (e.g. ``from google.oauth2 import id_token``). That is
        consistent with the tests' ValueError side effect and idinfo-shaped
        return values. Confirm the patch target against google_provider.
        """
        with patch("google.oauth2.id_token.verify_oauth2_token") as mocked:
            yield mocked

    def test_service_initialization(self, google_client_id):
        """Test Google auth service initialization."""
        service = GoogleAuthService(client_id=google_client_id)
        assert service.client_id == google_client_id

    def test_service_requires_client_id(self, monkeypatch):
        """Test that service requires client ID."""
        # Clear environment variables so it can't fall back to env
        monkeypatch.delenv("AUTH_SIGN_IN_GOOGLE_CLIENT_ID", raising=False)
        monkeypatch.delenv("GOOGLE_CLIENT_ID", raising=False)
        with pytest.raises(GoogleConfigError) as exc_info:
            GoogleAuthService(client_id=None)  # None and no env var
        assert "client id" in str(exc_info.value).lower()

    def test_verify_valid_token(self, mock_verify, google_client_id, mock_google_user_info):
        """Test verifying valid Google ID token."""
        # Mock the Google verification
        mock_verify.return_value = {
            'sub': mock_google_user_info.google_id,
            'email': mock_google_user_info.email,
            'name': mock_google_user_info.name,
            'picture': mock_google_user_info.picture,
            'iss': 'accounts.google.com',
            'aud': google_client_id,
        }
        service = GoogleAuthService(client_id=google_client_id)
        user_info = service.verify_token("fake-google-id-token")
        assert user_info.google_id == mock_google_user_info.google_id
        assert user_info.email == mock_google_user_info.email
        assert user_info.name == mock_google_user_info.name
        assert user_info.picture == mock_google_user_info.picture

    def test_verify_invalid_token(self, mock_verify, google_client_id):
        """Test that invalid token raises error."""
        # Mock verification failure
        mock_verify.side_effect = ValueError("Invalid token")
        service = GoogleAuthService(client_id=google_client_id)
        with pytest.raises(GoogleInvalidTokenError) as exc_info:
            service.verify_token("invalid-token")
        assert "invalid" in str(exc_info.value).lower()

    def test_verify_wrong_audience(self, mock_verify, google_client_id):
        """Test that token with wrong audience fails."""
        # Mock token with wrong audience
        mock_verify.return_value = {
            'sub': '12345',
            'email': 'test@example.com',
            'iss': 'accounts.google.com',
            'aud': 'wrong-client-id',
        }
        service = GoogleAuthService(client_id=google_client_id)
        with pytest.raises(GoogleInvalidTokenError):
            service.verify_token("token-for-wrong-app")
| # ============================================================================ | |
| # Run Tests | |
| # ============================================================================ | |
if __name__ == "__main__":
    # Run this file directly with pytest. Propagate pytest's exit status so
    # scripted/CI invocations see failures instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))