Spaces:
Sleeping
Sleeping
| """ | |
| Simple test script to verify refresh token rotation implementation | |
| Run this after setting up the service to ensure everything works | |
| """ | |
import asyncio
import sys
import os
# Put this script's own directory (dirname(__file__) joined with '.') at the
# front of sys.path. Note this is NOT the parent directory. Presumably this is
# what lets the `app` package below resolve when the script is run directly;
# it must therefore stay ahead of the `app.*` imports.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '.')))
from app.models.refresh_token_model import RefreshTokenModel
from app.utils.jwt import create_refresh_token, decode_token
from datetime import datetime
import logging
# Emit INFO-level messages so each test step's progress line is visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def test_token_family_creation():
    """Create a token family for the test customer and read it back.

    Returns:
        The id of the newly created family, for reuse by later tests.

    Raises:
        Exception: any failure (including a failed assertion) is logged
            and then re-raised unchanged.
    """
    logger.info("\nπ§ͺ Test 1: Token Family Creation")
    expected_customer = "test-user-123"
    try:
        family_id = await RefreshTokenModel.create_token_family(
            customer_id=expected_customer,
            device_info="Test Device",
        )
        logger.info(f"β Created token family: {family_id}")

        # The family must be retrievable and owned by the test customer.
        stored_family = await RefreshTokenModel.get_token_family(family_id)
        assert stored_family is not None, "Family data should exist"
        assert stored_family["customer_id"] == expected_customer
        logger.info("β Token family verified")
    except Exception as e:
        logger.error(f"β Test failed: {str(e)}")
        raise
    else:
        return family_id
async def test_refresh_token_creation(family_id):
    """Issue a regular and a "remember me" refresh token in one family.

    Asserts that the remember-me token expires at least 20 whole days after
    the regular one.

    Args:
        family_id: Token family both tokens are issued into.

    Returns:
        Tuple ``(token1, token_id1, token2, token_id2)`` where the first
        pair is the regular token and the second the remember-me token.

    Raises:
        Exception: any failure is logged and then re-raised unchanged.
    """
    logger.info("\nπ§ͺ Test 2: Refresh Token Creation")
    try:
        # Regular token: remember_me switched off.
        token1, token_id1, expires_at1 = create_refresh_token(
            {"sub": "test-user-123"}, remember_me=False, family_id=family_id
        )
        logger.info(f"β Created refresh token (7 days): {token_id1}")

        # Long-lived token: remember_me switched on.
        token2, token_id2, expires_at2 = create_refresh_token(
            {"sub": "test-user-123"}, remember_me=True, family_id=family_id
        )
        logger.info(f"β Created refresh token (30 days): {token_id2}")

        # Compare the two expiry timestamps in whole days.
        expiry_gap = expires_at2 - expires_at1
        days_diff = expiry_gap.days
        assert days_diff >= 20, f"Remember me token should be longer (diff: {days_diff} days)"
        logger.info(f"β Expiry difference verified: {days_diff} days")
    except Exception as e:
        logger.error(f"β Test failed: {str(e)}")
        raise
    else:
        return token1, token_id1, token2, token_id2
async def test_token_storage(token_id, family_id, expires_at=None, remember_me=True):
    """Store metadata for a refresh token and verify it can be read back.

    The original version stored ``expires_at=datetime.utcnow()``, so the
    record was already at its expiry when the following validation test
    asserted that the token is valid. The default expiry is now in the future.

    Args:
        token_id: Identifier of the token whose metadata is stored.
        family_id: Token family the token belongs to.
        expires_at: Expiry timestamp to store. Defaults to seven days from
            now as a naive UTC datetime, matching the naive ``utcnow()``
            value the original passed.
            NOTE(review): confirm the model expects naive UTC datetimes.
        remember_me: Value stored for, and expected back from, the
            ``remember_me`` field. Defaults to ``True`` as before.

    Returns:
        ``True`` when the metadata was stored and verified.

    Raises:
        Exception: any failure is logged and then re-raised unchanged.
    """
    # Function-scope import so this block does not depend on a file-level change.
    from datetime import timedelta

    logger.info("\nπ§ͺ Test 3: Token Storage")
    if expires_at is None:
        # Must lie in the future, otherwise the token is expired on arrival.
        expires_at = datetime.utcnow() + timedelta(days=7)
    try:
        await RefreshTokenModel.store_refresh_token(
            token_id=token_id,
            customer_id="test-user-123",
            family_id=family_id,
            expires_at=expires_at,
            remember_me=remember_me,
            device_info="Test Device",
            ip_address="127.0.0.1",
        )
        logger.info(f"β Stored token metadata: {token_id}")

        # Verify storage
        metadata = await RefreshTokenModel.get_token_metadata(token_id)
        assert metadata is not None, "Token metadata should exist"
        assert metadata["customer_id"] == "test-user-123"
        assert metadata["remember_me"] == remember_me
        logger.info("β Token metadata verified")
        return True
    except Exception as e:
        logger.error(f"β Test failed: {str(e)}")
        raise
async def test_token_validation(token_id):
    """Check that a token is valid until it is marked as used.

    Args:
        token_id: Identifier of a token that has already been stored.

    Returns:
        ``True`` when both validity checks pass.

    Raises:
        Exception: any failure is logged and then re-raised unchanged.
    """
    logger.info("\nπ§ͺ Test 4: Token Validation")
    try:
        # Before use the token is expected to validate.
        valid_before_use = await RefreshTokenModel.is_token_valid(token_id)
        assert valid_before_use, "Token should be valid initially"
        logger.info("β Token is valid")

        # Consume the token.
        await RefreshTokenModel.mark_token_as_used(token_id)
        logger.info("β Marked token as used")

        # A consumed token must no longer validate.
        valid_after_use = await RefreshTokenModel.is_token_valid(token_id)
        assert not valid_after_use, "Token should be invalid after use"
        logger.info("β Token is invalid after use (rotation working)")
    except Exception as e:
        logger.error(f"β Test failed: {str(e)}")
        raise
    else:
        return True
async def test_token_revocation(family_id):
    """Create and store one token, revoke it, and confirm it stops validating.

    Args:
        family_id: Token family the throwaway token is issued into.

    Returns:
        ``True`` when revocation succeeded and the token is invalid.

    Raises:
        Exception: any failure is logged and then re-raised unchanged.
    """
    logger.info("\nπ§ͺ Test 5: Token Revocation")
    try:
        # Set up a fresh token and persist its metadata.
        _token, token_id, expires_at = create_refresh_token(
            {"sub": "test-user-123"}, family_id=family_id
        )
        await RefreshTokenModel.store_refresh_token(
            token_id=token_id,
            customer_id="test-user-123",
            family_id=family_id,
            expires_at=expires_at,
            remember_me=False,
            device_info="Test Device",
            ip_address="127.0.0.1",
        )

        # Revoke it and check the reported outcome.
        revoked = await RefreshTokenModel.revoke_token(token_id)
        assert revoked, "Token revocation should succeed"
        logger.info(f"β Revoked token: {token_id}")

        # The revoked token must be rejected from now on.
        still_valid = await RefreshTokenModel.is_token_valid(token_id)
        assert not still_valid, "Revoked token should be invalid"
        logger.info("β Revoked token is invalid")
    except Exception as e:
        logger.error(f"β Test failed: {str(e)}")
        raise
    else:
        return True
async def test_family_revocation(family_id):
    """Store three tokens in a family, revoke the family, check all are invalid.

    Args:
        family_id: Token family to populate and then revoke.

    Returns:
        ``True`` when at least the three created tokens were revoked and
        none of them validates afterwards.

    Raises:
        Exception: any failure is logged and then re-raised unchanged.
    """
    logger.info("\nπ§ͺ Test 6: Family Revocation")
    try:
        # Populate the family with three stored tokens.
        tokens = []
        for i in range(3):
            _token, new_token_id, expires_at = create_refresh_token(
                {"sub": "test-user-123"}, family_id=family_id
            )
            await RefreshTokenModel.store_refresh_token(
                token_id=new_token_id,
                customer_id="test-user-123",
                family_id=family_id,
                expires_at=expires_at,
                remember_me=False,
                device_info=f"Test Device {i}",
                ip_address="127.0.0.1",
            )
            tokens.append(new_token_id)
        logger.info(f"β Created {len(tokens)} tokens in family")

        # One call is expected to revoke every token of the family.
        revoked_count = await RefreshTokenModel.revoke_token_family(family_id)
        assert revoked_count >= len(tokens), f"Should revoke at least {len(tokens)} tokens"
        logger.info(f"β Revoked {revoked_count} tokens in family")

        # Each token created above must now be rejected.
        for token_id in tokens:
            still_valid = await RefreshTokenModel.is_token_valid(token_id)
            assert not still_valid, f"Token {token_id} should be invalid"
        logger.info("β All family tokens are invalid")
    except Exception as e:
        logger.error(f"β Test failed: {str(e)}")
        raise
    else:
        return True
async def test_session_management():
    """Create three sessions for a second test customer, then revoke them all.

    Verifies that at least three active sessions are listed after creation
    and that none remain once all of the customer's tokens are revoked.

    Returns:
        ``True`` when every check passes.

    Raises:
        Exception: any failure is logged and then re-raised unchanged.
    """
    logger.info("\nπ§ͺ Test 7: Session Management")
    try:
        customer_id = "test-user-456"
        family_id = await RefreshTokenModel.create_token_family(customer_id, "Test Device")

        # Three stored tokens; remember_me is on for even indices only.
        for i in range(3):
            _token, token_id, expires_at = create_refresh_token(
                {"sub": customer_id}, family_id=family_id
            )
            await RefreshTokenModel.store_refresh_token(
                token_id=token_id,
                customer_id=customer_id,
                family_id=family_id,
                expires_at=expires_at,
                remember_me=(i % 2 == 0),
                device_info=f"Device {i}",
                ip_address=f"192.168.1.{i}",
            )
        logger.info("β Created 3 test sessions")

        # All three should be listed as active.
        sessions = await RefreshTokenModel.get_active_sessions(customer_id)
        assert len(sessions) >= 3, "Should have at least 3 active sessions"
        logger.info(f"β Found {len(sessions)} active sessions")

        # Revoke everything belonging to this customer.
        revoked_count = await RefreshTokenModel.revoke_all_user_tokens(customer_id)
        assert revoked_count >= 3, "Should revoke at least 3 tokens"
        logger.info(f"β Revoked all {revoked_count} user tokens")

        # Nothing may remain active afterwards.
        sessions = await RefreshTokenModel.get_active_sessions(customer_id)
        assert len(sessions) == 0, "Should have no active sessions"
        logger.info("β No active sessions remaining")
    except Exception as e:
        logger.error(f"β Test failed: {str(e)}")
        raise
    else:
        return True
async def cleanup_test_data():
    """Remove the documents and Redis keys created by this test script.

    Deletes every ``refresh_tokens`` document owned by one of the two test
    customers, then walks the ``token_family:*`` keys in Redis.

    The original version deleted *every* ``token_family:*`` key, which on a
    shared Redis instance would also wipe the token families of real users.
    A key is now deleted only when its family can be loaded and belongs to a
    test customer; anything that cannot be attributed is left in place.

    Cleanup is best-effort: any error is logged and swallowed so that it
    never masks the outcome of the tests themselves.
    """
    logger.info("\nπ§Ή Cleaning up test data...")
    test_customers = ["test-user-123", "test-user-456"]
    family_key_prefix = "token_family:"
    try:
        from app.core.nosql_client import db

        # Delete test tokens
        result = await db["refresh_tokens"].delete_many(
            {"customer_id": {"$in": test_customers}}
        )
        logger.info(f"β Deleted {result.deleted_count} test tokens")

        # Clean up Redis token families
        from app.core.cache_client import get_redis

        redis = await get_redis()
        cursor = 0
        deleted = 0
        while True:
            cursor, keys = await redis.scan(cursor, match="token_family:*", count=100)
            for key in keys:
                # Keys are bytes unless the client decodes responses.
                key_text = key.decode() if isinstance(key, bytes) else key
                # Assumes keys look like "token_family:<family_id>" — TODO confirm
                # against RefreshTokenModel's key scheme.
                family_id = key_text[len(family_key_prefix):]
                family = await RefreshTokenModel.get_token_family(family_id)
                if not family or family.get("customer_id") not in test_customers:
                    # Not provably test data: never delete it.
                    continue
                await redis.delete(key)
                deleted += 1
            if cursor == 0:
                break
        logger.info(f"β Deleted {deleted} token families from Redis")
    except Exception as e:
        logger.error(f"β οΈ Cleanup warning: {str(e)}")
async def main():
    """Run every test in order, then clean up the data they created.

    Cleanup now runs in a ``finally`` clause. The original called it only
    after the last test succeeded, so any failing test left its tokens and
    token families behind in the database and in Redis.

    Raises:
        Exception: the first test failure, re-raised after the failure
            banner has been logged.
    """
    logger.info("=" * 60)
    logger.info("Refresh Token Rotation Test Suite")
    logger.info("=" * 60)
    try:
        try:
            # Test 1: Token family creation
            family_id = await test_token_family_creation()
            # Test 2: Refresh token creation
            token1, token_id1, token2, token_id2 = await test_refresh_token_creation(family_id)
            # Test 3: Token storage
            await test_token_storage(token_id1, family_id)
            # Test 4: Token validation
            await test_token_validation(token_id1)
            # Test 5: Token revocation
            await test_token_revocation(family_id)
            # Test 6: Family revocation (uses its own family so that the
            # family from Test 1 is not the one being revoked)
            new_family_id = await RefreshTokenModel.create_token_family("test-user-123", "Test")
            await test_family_revocation(new_family_id)
            # Test 7: Session management
            await test_session_management()
        finally:
            # Always remove test data, whether the tests passed or failed.
            await cleanup_test_data()
        logger.info("\n" + "=" * 60)
        logger.info("β All tests passed!")
        logger.info("=" * 60)
    except Exception as e:
        logger.error("\n" + "=" * 60)
        logger.error("β Tests failed!")
        logger.error(f"Error: {str(e)}")
        logger.error("=" * 60)
        raise
# Run the suite only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())