# apigateway/tests/test_auth_service.py
# Author: jebin2 — commit d9bae79 ("feat: add comprehensive auth service test suite")
"""
Test Suite for Auth Service
Comprehensive tests for the authentication service including:
- JWT token creation and verification
- Token expiry validation
- Token version checking (logout/invalidation)
- Google OAuth token verification (mocked)
- Error handling
"""
import os
from datetime import datetime, timedelta, timezone
from unittest.mock import patch, MagicMock

import pytest

from services.auth_service.jwt_provider import (
    JWTService,
    TokenPayload,
    create_access_token,
    create_refresh_token,
    verify_access_token,
    TokenExpiredError,
    InvalidTokenError,
    ConfigurationError,
    get_jwt_service
)
from services.auth_service.google_provider import (
    GoogleAuthService,
    GoogleUserInfo,
    InvalidTokenError as GoogleInvalidTokenError,
    ConfigurationError as GoogleConfigError,
    get_google_auth_service
)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def jwt_secret():
    """Return the fixed signing secret shared by the JWT tests (test use only)."""
    secret = "test-secret-key-for-testing-only-do-not-use-in-production"
    return secret
@pytest.fixture
def jwt_service(jwt_secret):
    """Build a JWTService with explicitly supplied settings.

    Uses HS256, a 15-minute access-token lifetime and a 7-day refresh-token
    lifetime; several tests below assert against these exact values.
    """
    settings = {
        "secret_key": jwt_secret,
        "algorithm": "HS256",
        "access_expiry_minutes": 15,
        "refresh_expiry_days": 7,
    }
    return JWTService(**settings)
@pytest.fixture
def google_client_id():
    """Return a fake Google OAuth client ID used as the expected audience."""
    client_id = "test-google-client-id.apps.googleusercontent.com"
    return client_id
@pytest.fixture
def mock_google_user_info():
    """Return a GoogleUserInfo whose fields the mocked verifier echoes back."""
    profile = dict(
        google_id="12345678901234567890",
        email="test@example.com",
        name="Test User",
        picture="https://example.com/photo.jpg",
    )
    return GoogleUserInfo(**profile)
# ============================================================================
# JWT Service Tests
# ============================================================================
class TestJWTService:
    """Test JWT service construction and configuration sources."""

    def test_service_initialization(self, jwt_secret):
        """Test that JWT service initializes correctly."""
        service = JWTService(
            secret_key=jwt_secret,
            algorithm="HS256",
            access_expiry_minutes=15,
            refresh_expiry_days=7
        )
        assert service.secret_key == jwt_secret
        assert service.algorithm == "HS256"
        assert service.access_expiry_minutes == 15
        assert service.refresh_expiry_days == 7

    def test_service_requires_secret(self, monkeypatch):
        """Test that service requires a secret key."""
        # Clear environment variable so it can't fall back to env
        monkeypatch.delenv("JWT_SECRET", raising=False)
        with pytest.raises(ConfigurationError) as exc_info:
            JWTService(secret_key=None)  # None and no env var
        assert "secret" in str(exc_info.value).lower()

    def test_service_warns_short_secret(self, caplog):
        """Test that service warns about short secret keys."""
        # Capture WARNING and above explicitly so the assertion does not depend
        # on whatever root logging level another plugin/conftest configured.
        caplog.set_level("WARNING")
        JWTService(secret_key="short")  # constructed only for its log side effect
        log_text = caplog.text.lower()
        assert "short" in log_text or "32 chars" in log_text

    def test_service_from_env(self, monkeypatch, jwt_secret):
        """Test that service reads config from environment."""
        monkeypatch.setenv("JWT_SECRET", jwt_secret)
        monkeypatch.setenv("JWT_ALGORITHM", "HS512")
        monkeypatch.setenv("JWT_ACCESS_EXPIRY_MINUTES", "30")
        monkeypatch.setenv("JWT_REFRESH_EXPIRY_DAYS", "14")
        service = JWTService()
        assert service.secret_key == jwt_secret
        assert service.algorithm == "HS512"
        assert service.access_expiry_minutes == 30
        assert service.refresh_expiry_days == 14
class TestAccessTokenCreation:
    """Test access token creation."""

    def test_create_access_token(self, jwt_service):
        """Test creating an access token."""
        token = jwt_service.create_access_token(
            user_id="usr_123",
            email="test@example.com",
            token_version=1
        )
        assert isinstance(token, str)
        assert len(token) > 0
        assert token.count('.') == 2  # JWT format: header.payload.signature

    def test_access_token_payload(self, jwt_service):
        """Test that access token has correct payload."""
        token = jwt_service.create_access_token(
            user_id="usr_123",
            email="test@example.com",
            token_version=1
        )
        payload = jwt_service.verify_token(token)
        assert payload.user_id == "usr_123"
        assert payload.email == "test@example.com"
        assert payload.token_version == 1
        assert payload.token_type == "access"

    def test_access_token_expiry(self, jwt_service):
        """Test that access token expires 15 minutes (fixture setting) after creation."""
        # Naive UTC "now": the non-deprecated equivalent of datetime.utcnow(),
        # kept naive because payload.expires_at is compared as naive UTC here.
        before = datetime.now(timezone.utc).replace(tzinfo=None)
        token = jwt_service.create_access_token(
            user_id="usr_123",
            email="test@example.com"
        )
        after = datetime.now(timezone.utc).replace(tzinfo=None)
        payload = jwt_service.verify_token(token)
        # Should expire 15 minutes from creation (with some tolerance for
        # execution time and whole-second truncation of the exp claim).
        expected_min = before + timedelta(minutes=15) - timedelta(seconds=1)
        expected_max = after + timedelta(minutes=15) + timedelta(seconds=1)
        assert expected_min <= payload.expires_at <= expected_max

    def test_access_token_custom_expiry(self, jwt_service):
        """Test creating token with custom expiry."""
        custom_delta = timedelta(hours=1)
        token = jwt_service.create_token(
            user_id="usr_123",
            email="test@example.com",
            token_type="access",
            expiry_delta=custom_delta
        )
        payload = jwt_service.verify_token(token)
        time_diff = payload.expires_at - payload.issued_at
        # Should be approximately 1 hour
        assert 3590 <= time_diff.total_seconds() <= 3610

    def test_access_token_extra_claims(self, jwt_service):
        """Test creating token with extra claims."""
        token = jwt_service.create_token(
            user_id="usr_123",
            email="test@example.com",
            token_type="access",
            extra_claims={"role": "admin", "org": "test_org"}
        )
        payload = jwt_service.verify_token(token)
        assert payload.extra.get("role") == "admin"
        assert payload.extra.get("org") == "test_org"
class TestRefreshTokenCreation:
    """Test refresh token creation."""

    def test_create_refresh_token(self, jwt_service):
        """Test creating a refresh token."""
        token = jwt_service.create_refresh_token(
            user_id="usr_123",
            email="test@example.com",
            token_version=1
        )
        assert isinstance(token, str)
        assert len(token) > 0

    def test_refresh_token_type(self, jwt_service):
        """Test that refresh token has correct type."""
        token = jwt_service.create_refresh_token(
            user_id="usr_123",
            email="test@example.com"
        )
        payload = jwt_service.verify_token(token)
        assert payload.token_type == "refresh"

    def test_refresh_token_longer_expiry(self, jwt_service):
        """Test that refresh token expires in 7 days (fixture setting)."""
        # Naive UTC "now": the non-deprecated equivalent of datetime.utcnow(),
        # kept naive because payload.expires_at is compared as naive UTC here.
        before = datetime.now(timezone.utc).replace(tzinfo=None)
        token = jwt_service.create_refresh_token(
            user_id="usr_123",
            email="test@example.com"
        )
        payload = jwt_service.verify_token(token)
        time_diff = payload.expires_at - before
        # Should be approximately 7 days
        expected_seconds = 7 * 24 * 60 * 60
        assert abs(time_diff.total_seconds() - expected_seconds) < 10
class TestTokenVerification:
    """Test token verification."""

    def test_verify_valid_token(self, jwt_service):
        """Test verifying a valid token."""
        token = jwt_service.create_access_token(
            user_id="usr_123",
            email="test@example.com"
        )
        payload = jwt_service.verify_token(token)
        assert payload.user_id == "usr_123"
        assert payload.email == "test@example.com"

    def test_verify_empty_token(self, jwt_service):
        """Test that empty token raises error."""
        with pytest.raises(InvalidTokenError) as exc_info:
            jwt_service.verify_token("")
        assert "empty" in str(exc_info.value).lower()

    def test_verify_malformed_token(self, jwt_service):
        """Test that malformed token raises error."""
        with pytest.raises(InvalidTokenError):
            jwt_service.verify_token("not.a.valid.jwt.token")

    def test_verify_tampered_token(self, jwt_service):
        """Test that a token whose claims were changed after signing is rejected.

        The payload segment is re-encoded as well-formed base64url JSON with a
        different subject, so the only thing wrong with the forged token is
        that the original signature no longer matches. This exercises signature
        verification instead of relying on a JSON/base64 decoding failure.
        """
        import base64
        import json

        token = jwt_service.create_access_token(
            user_id="usr_123",
            email="test@example.com"
        )
        header, payload_segment, signature = token.split('.')

        # JWT segments are unpadded base64url; restore padding before decoding.
        padded = payload_segment + "=" * (-len(payload_segment) % 4)
        claims = json.loads(base64.urlsafe_b64decode(padded))
        claims["sub"] = "usr_attacker"  # "sub" carries the user id
        forged_segment = (
            base64.urlsafe_b64encode(json.dumps(claims).encode("utf-8"))
            .rstrip(b"=")
            .decode("ascii")
        )
        tampered = '.'.join((header, forged_segment, signature))

        with pytest.raises(InvalidTokenError):
            jwt_service.verify_token(tampered)

    def test_verify_token_wrong_secret(self, jwt_service):
        """Test that a token signed with one secret fails under another secret."""
        token = jwt_service.create_access_token(
            user_id="usr_123",
            email="test@example.com"
        )
        # Same algorithm so only the secret differs, and a long secret so the
        # constructor's short-secret warning path is not involved.
        wrong_service = JWTService(
            secret_key="a-different-secret-key-that-is-long-enough-for-hs256",
            algorithm=jwt_service.algorithm,
        )
        with pytest.raises(InvalidTokenError):
            wrong_service.verify_token(token)
class TestTokenExpiry:
    """Test token expiry behavior."""

    def test_expired_token_raises_error(self, jwt_service):
        """Test that expired token raises TokenExpiredError."""
        # Negative expiry delta: the token is already expired when issued.
        token = jwt_service.create_token(
            user_id="usr_123",
            email="test@example.com",
            token_type="access",
            expiry_delta=timedelta(seconds=-1)
        )
        with pytest.raises(TokenExpiredError) as exc_info:
            jwt_service.verify_token(token)
        assert "expired" in str(exc_info.value).lower()

    def test_token_not_expired_yet(self, jwt_service):
        """Test that non-expired token verifies successfully."""
        token = jwt_service.create_token(
            user_id="usr_123",
            email="test@example.com",
            token_type="access",
            expiry_delta=timedelta(hours=1)
        )
        # Should not raise
        payload = jwt_service.verify_token(token)
        assert payload.user_id == "usr_123"
        assert not payload.is_expired

    def test_token_expiry_property(self, jwt_service):
        """Test TokenPayload.is_expired on a payload built from an expired token."""
        token = jwt_service.create_token(
            user_id="usr_123",
            email="test@example.com",
            expiry_delta=timedelta(seconds=-1)
        )
        # Decode without verifying expiry (verify_token would raise instead).
        import jwt as pyjwt
        payload_dict = pyjwt.decode(
            token,
            jwt_service.secret_key,
            algorithms=[jwt_service.algorithm],
            options={"verify_exp": False}
        )

        def _naive_utc(ts):
            # Non-deprecated equivalent of datetime.utcfromtimestamp(ts).
            return datetime.fromtimestamp(ts, tz=timezone.utc).replace(tzinfo=None)

        payload = TokenPayload(
            user_id=payload_dict["sub"],
            email=payload_dict["email"],
            issued_at=_naive_utc(payload_dict["iat"]),
            expires_at=_naive_utc(payload_dict["exp"]),
            token_version=payload_dict.get("tv", 1),
            token_type=payload_dict.get("type", "access")
        )
        assert payload.is_expired is True
class TestTokenVersion:
    """The token-version claim round-trips through create/verify."""

    def test_token_version_in_payload(self, jwt_service):
        """An explicit token_version is preserved in the verified payload."""
        issued = jwt_service.create_access_token(
            user_id="usr_123", email="test@example.com", token_version=5
        )
        assert jwt_service.verify_token(issued).token_version == 5

    def test_default_token_version(self, jwt_service):
        """Omitting token_version yields a verified payload with version 1."""
        issued = jwt_service.create_access_token(
            user_id="usr_123", email="test@example.com"
        )
        assert jwt_service.verify_token(issued).token_version == 1
class TestConvenienceFunctions:
    """Test module-level convenience functions.

    Every test runs with JWT_SECRET set to the test secret and with the
    module-level default service cleared, via the autouse fixture below.
    """

    @pytest.fixture(autouse=True)
    def _fresh_default_service(self, monkeypatch, jwt_secret):
        """Point JWT_SECRET at the test secret and clear the cached default service.

        ``monkeypatch.setattr`` (instead of plain assignment) restores the
        module's ``_default_service`` at teardown. A singleton built here with
        the test secret therefore does not leak into later tests.
        ``_default_service`` is presumably the cache behind get_jwt_service().
        """
        import services.auth_service.jwt_provider as jwt_module

        monkeypatch.setenv("JWT_SECRET", jwt_secret)
        monkeypatch.setattr(jwt_module, "_default_service", None, raising=False)

    def test_create_access_token_function(self):
        """Test create_access_token convenience function."""
        token = create_access_token(
            user_id="usr_123",
            email="test@example.com"
        )
        assert isinstance(token, str)
        assert len(token) > 0

    def test_create_refresh_token_function(self):
        """Test create_refresh_token convenience function."""
        token = create_refresh_token(
            user_id="usr_123",
            email="test@example.com"
        )
        assert isinstance(token, str)
        payload = get_jwt_service().verify_token(token)
        assert payload.token_type == "refresh"

    def test_verify_access_token_function(self):
        """Test verify_access_token convenience function."""
        token = create_access_token(
            user_id="usr_123",
            email="test@example.com"
        )
        payload = verify_access_token(token)
        assert payload.user_id == "usr_123"

    def test_get_jwt_service_singleton(self):
        """Test that get_jwt_service returns singleton."""
        service1 = get_jwt_service()
        service2 = get_jwt_service()
        assert service1 is service2  # Same instance
# ============================================================================
# Google OAuth Tests
# ============================================================================
class TestGoogleAuthService:
    """Test Google OAuth integration (the Google verifier is mocked)."""

    def test_service_initialization(self, google_client_id):
        """Test Google auth service initialization."""
        service = GoogleAuthService(client_id=google_client_id)
        assert service.client_id == google_client_id

    def test_service_requires_client_id(self, monkeypatch):
        """Test that service requires client ID."""
        # Clear environment variables so it can't fall back to env
        monkeypatch.delenv("AUTH_SIGN_IN_GOOGLE_CLIENT_ID", raising=False)
        monkeypatch.delenv("GOOGLE_CLIENT_ID", raising=False)
        with pytest.raises(GoogleConfigError) as exc_info:
            GoogleAuthService(client_id=None)  # None and no env var
        assert "client id" in str(exc_info.value).lower()

    @patch('google.oauth2.id_token.verify_oauth2_token')
    def test_verify_valid_token(self, mock_verify, google_client_id, mock_google_user_info):
        """Test verifying valid Google ID token."""
        # Mock the Google verification
        mock_verify.return_value = {
            'sub': mock_google_user_info.google_id,
            'email': mock_google_user_info.email,
            'name': mock_google_user_info.name,
            'picture': mock_google_user_info.picture,
            'iss': 'accounts.google.com',
            'aud': google_client_id
        }
        service = GoogleAuthService(client_id=google_client_id)
        user_info = service.verify_token("fake-google-id-token")
        assert mock_verify.called
        assert user_info.google_id == mock_google_user_info.google_id
        assert user_info.email == mock_google_user_info.email
        assert user_info.name == mock_google_user_info.name
        assert user_info.picture == mock_google_user_info.picture

    @patch('google.oauth2.id_token.verify_oauth2_token')
    def test_verify_invalid_token(self, mock_verify, google_client_id):
        """Test that invalid token raises error."""
        # Mock verification failure
        mock_verify.side_effect = ValueError("Invalid token")
        service = GoogleAuthService(client_id=google_client_id)
        with pytest.raises(GoogleInvalidTokenError) as exc_info:
            service.verify_token("invalid-token")
        assert "invalid" in str(exc_info.value).lower()
        # The real verifier would also reject "invalid-token", so confirm the
        # error came through the patched path rather than a network/real check.
        assert mock_verify.called

    @patch('google.oauth2.id_token.verify_oauth2_token')
    def test_verify_wrong_audience(self, mock_verify, google_client_id):
        """Test that the service itself rejects claims with the wrong audience."""
        # Mock token with wrong audience
        mock_verify.return_value = {
            'sub': '12345',
            'email': 'test@example.com',
            'iss': 'accounts.google.com',
            'aud': 'wrong-client-id'
        }
        service = GoogleAuthService(client_id=google_client_id)
        with pytest.raises(GoogleInvalidTokenError):
            service.verify_token("token-for-wrong-app")
        # Without this, a patch target that never intercepts the call would let
        # the real verifier's failure satisfy pytest.raises vacuously.
        assert mock_verify.called
# ============================================================================
# Run Tests
# ============================================================================
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when any test fails (the return value was previously dropped).
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))