Spaces:
Sleeping
Sleeping
| from datetime import datetime, timedelta | |
| from typing import Optional | |
| import uuid | |
| import logging | |
| from app.core.nosql_client import db | |
| from app.core.cache_client import get_redis | |
| logger = logging.getLogger("refresh_token_model") | |
class RefreshTokenModel:
    """Persistence helpers for refresh tokens with rotation support.

    Token documents live in the ``refresh_tokens`` collection of ``db``;
    token-family metadata lives in Redis under ``token_family:<family_id>``
    as a JSON string with a TTL. Every method is static and is meant to be
    called on the class (``await RefreshTokenModel.revoke_token(...)``).
    """

    collection = db["refresh_tokens"]

    # Token family tracking for rotation
    TOKEN_FAMILY_TTL = 30 * 24 * 3600  # 30 days in seconds

    @staticmethod
    def _family_key(family_id: str) -> str:
        """Return the Redis key under which a token family is stored."""
        return f"token_family:{family_id}"

    @staticmethod
    async def create_token_family(customer_id: str, device_info: Optional[str] = None) -> str:
        """Create a new token family for refresh token rotation.

        Args:
            customer_id: Owner of the family.
            device_info: Optional free-form device description.

        Returns:
            The newly generated family id (a UUID4 string).

        Raises:
            Exception: Any error from the Redis client is logged and re-raised.
        """
        import json

        family_id = str(uuid.uuid4())
        family_data = {
            "customer_id": customer_id,
            "device_info": device_info,
            "created_at": datetime.utcnow().isoformat(),
            "rotation_count": 0,
        }
        try:
            redis = await get_redis()
            await redis.setex(
                RefreshTokenModel._family_key(family_id),
                RefreshTokenModel.TOKEN_FAMILY_TTL,
                json.dumps(family_data),
            )
            logger.info("Created token family %s for user %s", family_id, customer_id)
            return family_id
        except Exception as e:
            logger.error("Error creating token family: %s", e, exc_info=True)
            raise

    @staticmethod
    async def get_token_family(family_id: str) -> Optional[dict]:
        """Get token family data.

        Returns:
            The decoded family dict, or ``None`` when the key is absent or
            any error occurs (errors are logged, not raised).
        """
        import json

        try:
            redis = await get_redis()
            raw = await redis.get(RefreshTokenModel._family_key(family_id))
            if not raw:
                return None
            return json.loads(raw)
        except Exception as e:
            logger.error("Error getting token family: %s", e, exc_info=True)
            return None

    @staticmethod
    async def increment_rotation_count(family_id: str) -> int:
        """Increment rotation count for a token family.

        The stored record is rewritten with the remaining TTL so that a
        rotation does not extend the family's lifetime. This is a
        read-modify-write sequence and is not atomic.

        Returns:
            The new rotation count, or ``0`` when the family does not exist
            (or expired mid-update) or an error occurred.
        """
        import json

        try:
            redis = await get_redis()
            family_key = RefreshTokenModel._family_key(family_id)

            family_data = await RefreshTokenModel.get_token_family(family_id)
            if not family_data:
                return 0

            new_count = family_data.get("rotation_count", 0) + 1
            family_data["rotation_count"] = new_count
            family_data["last_rotated"] = datetime.utcnow().isoformat()
            payload = json.dumps(family_data)

            # Redis TTL: positive = seconds left, -1 = key has no expiry,
            # -2 = key does not exist.
            ttl = await redis.ttl(family_key)
            if ttl > 0:
                await redis.setex(family_key, ttl, payload)
            elif ttl == -1:
                # Key exists without an expiry: keep it that way.
                await redis.set(family_key, payload)
            else:
                # The family expired between the read and the write; do not
                # report an increment that was never persisted.
                logger.warning(
                    "Token family %s expired before rotation count could be updated",
                    family_id,
                )
                return 0

            logger.info(
                "Incremented rotation count for family %s to %s", family_id, new_count
            )
            return new_count
        except Exception as e:
            logger.error("Error incrementing rotation count: %s", e, exc_info=True)
            return 0

    @staticmethod
    async def store_refresh_token(
        token_id: str,
        customer_id: str,
        family_id: str,
        expires_at: datetime,
        remember_me: bool = False,
        device_info: Optional[str] = None,
        ip_address: Optional[str] = None,
    ) -> None:
        """Store refresh token metadata.

        Inserts one document with ``revoked`` and ``used`` both ``False`` and
        ``created_at`` set to the current naive UTC time.

        Raises:
            Exception: Any insert error is logged and re-raised.
        """
        token_doc = {
            "token_id": token_id,
            "customer_id": customer_id,
            "family_id": family_id,
            "expires_at": expires_at,
            "remember_me": remember_me,
            "device_info": device_info,
            "ip_address": ip_address,
            "created_at": datetime.utcnow(),
            "revoked": False,
            "used": False,
        }
        try:
            await RefreshTokenModel.collection.insert_one(token_doc)
            logger.info("Stored refresh token %s for user %s", token_id, customer_id)
        except Exception as e:
            logger.error("Error storing refresh token: %s", e, exc_info=True)
            raise

    @staticmethod
    async def mark_token_as_used(token_id: str) -> bool:
        """Mark a refresh token as used (for rotation).

        Returns:
            ``True`` when a document was modified, ``False`` when none was
            or an error occurred.
        """
        # NOTE(review): the filter does not require used == False, so this is
        # not a compare-and-set; two concurrent rotations of the same token
        # can both see True. Confirm whether callers rely on that.
        try:
            result = await RefreshTokenModel.collection.update_one(
                {"token_id": token_id},
                {"$set": {"used": True, "used_at": datetime.utcnow()}},
            )
            if result.modified_count > 0:
                logger.info("Marked token %s as used", token_id)
                return True
            return False
        except Exception as e:
            logger.error("Error marking token as used: %s", e, exc_info=True)
            return False

    @staticmethod
    async def is_token_valid(token_id: str) -> bool:
        """Check if a refresh token is valid (exists, not revoked, not used, not expired).

        Side effect: when the token was already used, the whole token family
        is revoked as a replay-attack countermeasure.

        Returns:
            ``True`` only when every check passes; ``False`` otherwise,
            including on any error.
        """
        try:
            token = await RefreshTokenModel.collection.find_one({"token_id": token_id})
            if not token:
                logger.warning("Token %s not found", token_id)
                return False

            if token.get("revoked"):
                logger.warning("Token %s is revoked", token_id)
                return False

            if token.get("used"):
                logger.warning("Token %s already used - possible replay attack", token_id)
                # Revoke entire token family on reuse attempt. Skip when the
                # document carries no family id: a None filter value would
                # match unrelated documents instead of this token's family.
                family_id = token.get("family_id")
                if family_id:
                    await RefreshTokenModel.revoke_token_family(family_id)
                else:
                    logger.warning(
                        "Token %s has no family_id; family revocation skipped", token_id
                    )
                return False

            expires_at = token.get("expires_at")
            if expires_at is None:
                # Without an expiry the token cannot be shown to be live.
                logger.warning("Token %s has no expires_at", token_id)
                return False
            if expires_at < datetime.utcnow():
                logger.warning("Token %s is expired", token_id)
                return False

            return True
        except Exception as e:
            logger.error("Error checking token validity: %s", e, exc_info=True)
            return False

    @staticmethod
    async def get_token_metadata(token_id: str) -> Optional[dict]:
        """Get refresh token metadata.

        Returns:
            The stored document, or ``None`` when not found or on error.
        """
        try:
            return await RefreshTokenModel.collection.find_one({"token_id": token_id})
        except Exception as e:
            logger.error("Error getting token metadata: %s", e, exc_info=True)
            return None

    @staticmethod
    async def revoke_token(token_id: str) -> bool:
        """Revoke a specific refresh token.

        Returns:
            ``True`` when a document was modified, ``False`` when none was
            or an error occurred.
        """
        try:
            result = await RefreshTokenModel.collection.update_one(
                {"token_id": token_id},
                {"$set": {"revoked": True, "revoked_at": datetime.utcnow()}},
            )
            if result.modified_count > 0:
                logger.info("Revoked token %s", token_id)
                return True
            return False
        except Exception as e:
            logger.error("Error revoking token: %s", e, exc_info=True)
            return False

    @staticmethod
    async def revoke_token_family(family_id: str, reason: str = "token_reuse_detected") -> int:
        """Revoke all tokens in a family (security breach detection).

        Args:
            family_id: Family whose non-revoked tokens are revoked.
            reason: Value stored in ``revoke_reason`` on each revoked token.

        Returns:
            Number of token documents modified; ``0`` when ``family_id`` is
            empty or the collection update failed.
        """
        if not family_id:
            # Assumes MongoDB-style matching, where a None filter value also
            # matches documents that lack the field entirely.
            logger.warning("revoke_token_family called without a family_id; nothing revoked")
            return 0

        try:
            result = await RefreshTokenModel.collection.update_many(
                {"family_id": family_id, "revoked": False},
                {
                    "$set": {
                        "revoked": True,
                        "revoked_at": datetime.utcnow(),
                        "revoke_reason": reason,
                    }
                },
            )
        except Exception as e:
            logger.error("Error revoking token family: %s", e, exc_info=True)
            return 0

        # Also delete the family from Redis. This is best-effort: the tokens
        # are already revoked, so a cache failure must not hide the count.
        try:
            redis = await get_redis()
            await redis.delete(RefreshTokenModel._family_key(family_id))
        except Exception as e:
            logger.error("Error revoking token family: %s", e, exc_info=True)

        logger.warning("Revoked %s tokens in family %s", result.modified_count, family_id)
        return result.modified_count

    @staticmethod
    async def revoke_all_user_tokens(customer_id: str) -> int:
        """Revoke all refresh tokens for a user (logout from all devices).

        Returns:
            Number of token documents modified; ``0`` when ``customer_id`` is
            empty or an error occurred.
        """
        if not customer_id:
            # Same reasoning as for families: never run the bulk update with
            # an empty owner filter.
            logger.warning("revoke_all_user_tokens called without a customer_id; nothing revoked")
            return 0

        try:
            result = await RefreshTokenModel.collection.update_many(
                {"customer_id": customer_id, "revoked": False},
                {
                    "$set": {
                        "revoked": True,
                        "revoked_at": datetime.utcnow(),
                        "revoke_reason": "user_logout_all",
                    }
                },
            )
            logger.info("Revoked %s tokens for user %s", result.modified_count, customer_id)
            return result.modified_count
        except Exception as e:
            logger.error("Error revoking all user tokens: %s", e, exc_info=True)
            return 0

    @staticmethod
    async def get_active_sessions(customer_id: str, limit: int = 100) -> list:
        """Get all active sessions (valid refresh tokens) for a user.

        A token counts as active when it is not revoked, not yet used for
        rotation, and not expired - the same criteria ``is_token_valid``
        applies.

        Args:
            customer_id: Owner of the sessions.
            limit: Maximum number of documents returned.

        Returns:
            A list of token documents; empty on error.
        """
        try:
            cursor = RefreshTokenModel.collection.find(
                {
                    "customer_id": customer_id,
                    "revoked": False,
                    # Rotated-out tokens are superseded by their successor and
                    # would otherwise show up as duplicate sessions.
                    "used": {"$ne": True},
                    "expires_at": {"$gt": datetime.utcnow()},
                }
            )
            return await cursor.to_list(length=limit)
        except Exception as e:
            logger.error("Error getting active sessions: %s", e, exc_info=True)
            return []

    @staticmethod
    async def cleanup_expired_tokens() -> int:
        """Cleanup expired tokens (run periodically).

        Returns:
            Number of documents deleted; ``0`` on error.
        """
        try:
            result = await RefreshTokenModel.collection.delete_many(
                {"expires_at": {"$lt": datetime.utcnow()}}
            )
            logger.info("Cleaned up %s expired tokens", result.deleted_count)
            return result.deleted_count
        except Exception as e:
            logger.error("Error cleaning up expired tokens: %s", e, exc_info=True)
            return 0