"""Simple test script to verify the refresh token rotation implementation.

Run this after setting up the service (database and Redis reachable) to ensure
everything works end to end. Tests run in order and share state: later tests
reuse the family and token IDs produced by earlier ones.
"""

import asyncio
import sys
import os

# Make the ``app`` package importable regardless of the current working
# directory. NOTE: this inserts the script's *own* directory (joined with '.'),
# not its parent, so the script is expected to sit next to the ``app/`` package.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '.')))

from app.models.refresh_token_model import RefreshTokenModel
from app.utils.jwt import create_refresh_token, decode_token
from datetime import datetime, timedelta, timezone
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Customer IDs used by the tests; cleanup deletes MongoDB tokens for exactly these.
TEST_CUSTOMER_ID = "test-user-123"
SESSION_CUSTOMER_ID = "test-user-456"

# IDs (as strings) of the token families created during this run. Cleanup only
# removes Redis keys belonging to these families, never anyone else's.
_created_family_ids = set()


async def _create_test_family(customer_id, device_info):
    """Create a token family and record its ID so cleanup can remove it later."""
    family_id = await RefreshTokenModel.create_token_family(
        customer_id=customer_id,
        device_info=device_info,
    )
    # Guard against an empty ID: an empty segment would match unrelated keys.
    if family_id:
        _created_family_ids.add(str(family_id))
    return family_id


async def test_token_family_creation():
    """Test 1: create a token family and read it back.

    Returns:
        The new family ID, reused by the following tests.
    """
    logger.info("\n🧪 Test 1: Token Family Creation")
    try:
        family_id = await _create_test_family(TEST_CUSTOMER_ID, "Test Device")
        logger.info(f"✅ Created token family: {family_id}")

        # Verify family exists
        family_data = await RefreshTokenModel.get_token_family(family_id)
        assert family_data is not None, "Family data should exist"
        assert family_data["customer_id"] == TEST_CUSTOMER_ID, (
            f"Unexpected family owner: {family_data['customer_id']!r}"
        )
        logger.info("✅ Token family verified")

        return family_id
    except Exception as e:
        logger.error(f"❌ Test failed: {e}")
        raise


async def test_refresh_token_creation(family_id):
    """Test 2: create refresh tokens with and without "remember me".

    Args:
        family_id: Family the new tokens are attached to.

    Returns:
        ``(token1, token_id1, token2, token_id2)``, where token 1 was created
        with ``remember_me=False`` and token 2 with ``remember_me=True``.
    """
    logger.info("\n🧪 Test 2: Refresh Token Creation")
    try:
        # Create token without remember me
        token1, token_id1, expires_at1 = create_refresh_token(
            {"sub": TEST_CUSTOMER_ID},
            remember_me=False,
            family_id=family_id
        )
        logger.info(f"✅ Created refresh token (7 days): {token_id1}")

        # Create token with remember me
        token2, token_id2, expires_at2 = create_refresh_token(
            {"sub": TEST_CUSTOMER_ID},
            remember_me=True,
            family_id=family_id
        )
        logger.info(f"✅ Created refresh token (30 days): {token_id2}")

        # The remember-me token must outlive the normal one by a wide margin.
        days_diff = (expires_at2 - expires_at1).days
        assert days_diff >= 20, f"Remember me token should be longer (diff: {days_diff} days)"
        logger.info(f"✅ Expiry difference verified: {days_diff} days")

        return token1, token_id1, token2, token_id2
    except Exception as e:
        logger.error(f"❌ Test failed: {e}")
        raise


async def test_token_storage(token_id, family_id, expires_at=None):
    """Test 3: store refresh-token metadata and read it back.

    Args:
        token_id: ID of the token to store.
        family_id: Family the token belongs to.
        expires_at: Expiry to store. Defaults to 7 days from now (naive UTC).
            Storing "now" would record an already-expired token, which Test 4
            then expects to be valid.

    Returns:
        True on success.
    """
    logger.info("\n🧪 Test 3: Token Storage")
    try:
        if expires_at is None:
            # Naive UTC (same kind of value as datetime.utcnow(), which is
            # deprecated since Python 3.12).
            expires_at = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=7)

        await RefreshTokenModel.store_refresh_token(
            token_id=token_id,
            customer_id=TEST_CUSTOMER_ID,
            family_id=family_id,
            expires_at=expires_at,
            remember_me=True,
            device_info="Test Device",
            ip_address="127.0.0.1"
        )
        logger.info(f"✅ Stored token metadata: {token_id}")

        # Verify storage
        metadata = await RefreshTokenModel.get_token_metadata(token_id)
        assert metadata is not None, "Token metadata should exist"
        assert metadata["customer_id"] == TEST_CUSTOMER_ID, (
            f"Unexpected token owner: {metadata['customer_id']!r}"
        )
        # '==' rather than 'is': the stored value's exact type is up to the model.
        assert metadata["remember_me"] == True, "remember_me should be stored as True"
        logger.info("✅ Token metadata verified")

        return True
    except Exception as e:
        logger.error(f"❌ Test failed: {e}")
        raise


async def test_token_validation(token_id):
    """Test 4: a stored token is valid until it is marked as used.

    Args:
        token_id: ID of a token previously stored by Test 3.

    Returns:
        True on success.
    """
    logger.info("\n🧪 Test 4: Token Validation")
    try:
        # Should be valid initially
        is_valid = await RefreshTokenModel.is_token_valid(token_id)
        assert is_valid, "Token should be valid initially"
        logger.info("✅ Token is valid")

        # Mark as used
        await RefreshTokenModel.mark_token_as_used(token_id)
        logger.info("✅ Marked token as used")

        # A used token must not be accepted again (single-use rotation).
        is_valid = await RefreshTokenModel.is_token_valid(token_id)
        assert not is_valid, "Token should be invalid after use"
        logger.info("✅ Token is invalid after use (rotation working)")

        return True
    except Exception as e:
        logger.error(f"❌ Test failed: {e}")
        raise


async def test_token_revocation(family_id):
    """Test 5: revoking a single token makes it invalid.

    Args:
        family_id: Family the freshly created test token is attached to.

    Returns:
        True on success.
    """
    logger.info("\n🧪 Test 5: Token Revocation")
    try:
        # Create and store a test token
        token, token_id, expires_at = create_refresh_token(
            {"sub": TEST_CUSTOMER_ID},
            family_id=family_id
        )
        await RefreshTokenModel.store_refresh_token(
            token_id=token_id,
            customer_id=TEST_CUSTOMER_ID,
            family_id=family_id,
            expires_at=expires_at,
            remember_me=False,
            device_info="Test Device",
            ip_address="127.0.0.1"
        )

        # Revoke the token
        success = await RefreshTokenModel.revoke_token(token_id)
        assert success, "Token revocation should succeed"
        logger.info(f"✅ Revoked token: {token_id}")

        # Verify it's invalid
        is_valid = await RefreshTokenModel.is_token_valid(token_id)
        assert not is_valid, "Revoked token should be invalid"
        logger.info("✅ Revoked token is invalid")

        return True
    except Exception as e:
        logger.error(f"❌ Test failed: {e}")
        raise


async def test_family_revocation(family_id):
    """Test 6: revoking a family invalidates every token in it.

    Args:
        family_id: Family to populate with three tokens and then revoke.

    Returns:
        True on success.
    """
    logger.info("\n🧪 Test 6: Family Revocation")
    try:
        # Create multiple tokens in the family
        tokens = []
        for i in range(3):
            token, token_id, expires_at = create_refresh_token(
                {"sub": TEST_CUSTOMER_ID},
                family_id=family_id
            )
            await RefreshTokenModel.store_refresh_token(
                token_id=token_id,
                customer_id=TEST_CUSTOMER_ID,
                family_id=family_id,
                expires_at=expires_at,
                remember_me=False,
                device_info=f"Test Device {i}",
                ip_address="127.0.0.1"
            )
            tokens.append(token_id)

        logger.info(f"✅ Created {len(tokens)} tokens in family")

        # Revoke entire family
        revoked_count = await RefreshTokenModel.revoke_token_family(family_id)
        assert revoked_count >= len(tokens), f"Should revoke at least {len(tokens)} tokens"
        logger.info(f"✅ Revoked {revoked_count} tokens in family")

        # Verify all are invalid
        for token_id in tokens:
            is_valid = await RefreshTokenModel.is_token_valid(token_id)
            assert not is_valid, f"Token {token_id} should be invalid"

        logger.info("✅ All family tokens are invalid")

        return True
    except Exception as e:
        logger.error(f"❌ Test failed: {e}")
        raise


async def test_session_management():
    """Test 7: list a user's active sessions, then revoke all of them.

    Returns:
        True on success.
    """
    logger.info("\n🧪 Test 7: Session Management")
    try:
        customer_id = SESSION_CUSTOMER_ID
        family_id = await _create_test_family(customer_id, "Test Device")

        # Create multiple sessions
        for i in range(3):
            token, token_id, expires_at = create_refresh_token(
                {"sub": customer_id},
                family_id=family_id
            )
            await RefreshTokenModel.store_refresh_token(
                token_id=token_id,
                customer_id=customer_id,
                family_id=family_id,
                expires_at=expires_at,
                remember_me=i % 2 == 0,
                device_info=f"Device {i}",
                ip_address=f"192.168.1.{i}"
            )

        logger.info("✅ Created 3 test sessions")

        # Get active sessions
        sessions = await RefreshTokenModel.get_active_sessions(customer_id)
        assert len(sessions) >= 3, "Should have at least 3 active sessions"
        logger.info(f"✅ Found {len(sessions)} active sessions")

        # Revoke all user tokens
        revoked_count = await RefreshTokenModel.revoke_all_user_tokens(customer_id)
        assert revoked_count >= 3, "Should revoke at least 3 tokens"
        logger.info(f"✅ Revoked all {revoked_count} user tokens")

        # Verify no active sessions
        sessions = await RefreshTokenModel.get_active_sessions(customer_id)
        assert len(sessions) == 0, "Should have no active sessions"
        logger.info("✅ No active sessions remaining")

        return True
    except Exception as e:
        logger.error(f"❌ Test failed: {e}")
        raise


def _is_test_family_key(key):
    """Return True if a Redis key belongs to a token family created by this run.

    NOTE(review): assumes the family ID appears as a ':'-separated segment of
    the key (e.g. ``token_family:<family_id>``). Confirm this against
    RefreshTokenModel. Keys that don't match are left untouched, so a wrong
    assumption leaks test keys rather than deleting real ones.
    """
    # Redis clients may return keys as bytes or str depending on configuration.
    name = key.decode() if isinstance(key, bytes) else str(key)
    return any(part in _created_family_ids for part in name.split(":"))


async def cleanup_test_data():
    """Remove the data created by this script, on a best-effort basis.

    Deletes MongoDB refresh tokens owned by the test customers and the Redis
    ``token_family:*`` keys of the families created during this run. Any error
    is logged as a warning and not re-raised, so cleanup cannot hide a test
    failure.
    """
    logger.info("\n🧹 Cleaning up test data...")
    try:
        from app.core.nosql_client import db

        # Delete test tokens
        result = await db["refresh_tokens"].delete_many({
            "customer_id": {"$in": [TEST_CUSTOMER_ID, SESSION_CUSTOMER_ID]}
        })
        logger.info(f"✅ Deleted {result.deleted_count} test tokens")

        # Remove only this run's token families. Deleting every
        # "token_family:*" match would also wipe real users' families on a
        # shared Redis instance.
        from app.core.cache_client import get_redis
        redis = await get_redis()

        cursor = 0
        deleted = 0
        while True:
            cursor, keys = await redis.scan(cursor, match="token_family:*", count=100)
            ours = [key for key in keys if _is_test_family_key(key)]
            if ours:
                await redis.delete(*ours)
                deleted += len(ours)
            if cursor == 0:
                break

        logger.info(f"✅ Deleted {deleted} token families from Redis")

    except Exception as e:
        logger.warning(f"⚠️ Cleanup warning: {e}")


async def main():
    """Run all tests in order, always cleaning up afterwards.

    A failing test is logged and re-raised, so ``asyncio.run`` exits with a
    traceback. Cleanup still runs in ``finally``.
    """
    logger.info("=" * 60)
    logger.info("Refresh Token Rotation Test Suite")
    logger.info("=" * 60)

    try:
        # Test 1: Token family creation
        family_id = await test_token_family_creation()

        # Test 2: Refresh token creation (only token_id1 is used below)
        _token1, token_id1, _token2, _token_id2 = await test_refresh_token_creation(family_id)

        # Test 3: Token storage
        await test_token_storage(token_id1, family_id)

        # Test 4: Token validation (same token as Test 3)
        await test_token_validation(token_id1)

        # Test 5: Token revocation
        await test_token_revocation(family_id)

        # Test 6: Family revocation, on a freshly created family
        new_family_id = await _create_test_family(TEST_CUSTOMER_ID, "Test")
        await test_family_revocation(new_family_id)

        # Test 7: Session management
        await test_session_management()

        logger.info("\n" + "=" * 60)
        logger.info("✅ All tests passed!")
        logger.info("=" * 60)

    except Exception as e:
        logger.error("\n" + "=" * 60)
        logger.error("❌ Tests failed!")
        logger.error(f"Error: {e}")
        logger.error("=" * 60)
        raise
    finally:
        # Clean up after failures too; otherwise a failed run leaves its
        # tokens and families behind in MongoDB/Redis.
        await cleanup_test_data()


if __name__ == "__main__":
    asyncio.run(main())