from datetime import datetime, timedelta
from typing import Optional
import json
import logging
import uuid

from app.core.nosql_client import db
from app.core.cache_client import get_redis

logger = logging.getLogger("refresh_token_model")


class RefreshTokenModel:
    """Model for managing refresh tokens with rotation support.

    Per-token metadata is persisted in the ``refresh_tokens`` collection of
    ``db``. Per-family rotation bookkeeping is kept in Redis under
    ``token_family:<family_id>`` keys that expire after ``TOKEN_FAMILY_TTL``
    seconds.

    Timestamps written here come from ``datetime.utcnow()`` (naive UTC).
    NOTE(review): ``expires_at`` values passed to :meth:`store_refresh_token`
    are assumed to be naive UTC as well; a tz-aware value would make the
    expiry comparison raise ``TypeError`` — confirm with callers.
    """

    collection = db["refresh_tokens"]

    # Token family tracking for rotation
    TOKEN_FAMILY_TTL = 30 * 24 * 3600  # 30 days in seconds

    @staticmethod
    def _family_key(family_id: str) -> str:
        """Return the Redis key holding the JSON data for ``family_id``."""
        return f"token_family:{family_id}"

    @staticmethod
    async def create_token_family(customer_id: str, device_info: Optional[str] = None) -> str:
        """Create a new token family for refresh token rotation.

        Args:
            customer_id: Owner of the family.
            device_info: Optional free-form description of the client device.

        Returns:
            The new family id (a random UUID4 string).

        Raises:
            Exception: Any Redis error is logged and re-raised.
        """
        family_id = str(uuid.uuid4())
        family_data = {
            "customer_id": customer_id,
            "device_info": device_info,
            "created_at": datetime.utcnow().isoformat(),
            "rotation_count": 0,
        }
        try:
            redis = await get_redis()
            await redis.setex(
                RefreshTokenModel._family_key(family_id),
                RefreshTokenModel.TOKEN_FAMILY_TTL,
                json.dumps(family_data),
            )
            logger.info("Created token family %s for user %s", family_id, customer_id)
            return family_id
        except Exception as e:
            logger.exception("Error creating token family: %s", e)
            raise

    @staticmethod
    async def get_token_family(family_id: str) -> Optional[dict]:
        """Return the decoded family data, or ``None``.

        ``None`` is returned both when the key does not exist (expired or
        revoked) and when Redis/JSON decoding fails; failures are logged.
        """
        try:
            redis = await get_redis()
            data = await redis.get(RefreshTokenModel._family_key(family_id))
            return json.loads(data) if data else None
        except Exception as e:
            logger.exception("Error getting token family: %s", e)
            return None

    @staticmethod
    async def increment_rotation_count(family_id: str) -> int:
        """Increment the rotation count for a token family.

        The existing TTL of the family key is preserved.

        Returns:
            The new rotation count, or ``0`` if the family does not exist,
            vanished before the update could be written, or an error occurred.

        NOTE: this is a read-modify-write, not an atomic operation; two
        concurrent rotations of the same family can lose an increment.
        """
        family_key = RefreshTokenModel._family_key(family_id)
        try:
            redis = await get_redis()
            family_data = await RefreshTokenModel.get_token_family(family_id)
            if not family_data:
                return 0

            family_data["rotation_count"] = family_data.get("rotation_count", 0) + 1
            family_data["last_rotated"] = datetime.utcnow().isoformat()
            payload = json.dumps(family_data)

            # Redis TTL: >0 remaining seconds, -1 no expiry, -2 key missing.
            ttl = await redis.ttl(family_key)
            if ttl > 0:
                await redis.setex(family_key, ttl, payload)
            elif ttl == -1:
                # Key exists without an expiry: write it back rather than
                # silently dropping the update.
                await redis.set(family_key, payload)
            else:
                # Key expired or was deleted (e.g. family revoked) after we read
                # it; do not resurrect it and do not report a phantom increment.
                logger.warning(
                    "Token family %s disappeared before rotation count could be updated",
                    family_id,
                )
                return 0

            logger.info(
                "Incremented rotation count for family %s to %s",
                family_id,
                family_data["rotation_count"],
            )
            return family_data["rotation_count"]
        except Exception as e:
            logger.exception("Error incrementing rotation count: %s", e)
            return 0

    @staticmethod
    async def store_refresh_token(
        token_id: str,
        customer_id: str,
        family_id: str,
        expires_at: datetime,
        remember_me: bool = False,
        device_info: Optional[str] = None,
        ip_address: Optional[str] = None,
    ):
        """Insert metadata for a newly issued refresh token.

        The document starts as neither revoked nor used.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        token_doc = {
            "token_id": token_id,
            "customer_id": customer_id,
            "family_id": family_id,
            "expires_at": expires_at,
            "remember_me": remember_me,
            "device_info": device_info,
            "ip_address": ip_address,
            "created_at": datetime.utcnow(),
            "revoked": False,
            "used": False,
        }
        try:
            await RefreshTokenModel.collection.insert_one(token_doc)
            logger.info("Stored refresh token %s for user %s", token_id, customer_id)
        except Exception as e:
            logger.exception("Error storing refresh token: %s", e)
            raise

    @staticmethod
    async def mark_token_as_used(token_id: str) -> bool:
        """Atomically mark a refresh token as used (for rotation).

        The update only matches a token that is not already used, so of several
        concurrent callers exactly one gets ``True``. The original ``used_at``
        timestamp is never overwritten.

        Returns:
            ``True`` if this call transitioned the token to used. ``False`` if
            the token is unknown, was already used, or an error occurred.
        """
        try:
            result = await RefreshTokenModel.collection.update_one(
                # Compare-and-set: only claim a token nobody has used yet.
                {"token_id": token_id, "used": {"$ne": True}},
                {"$set": {"used": True, "used_at": datetime.utcnow()}},
            )
            if result.modified_count > 0:
                logger.info("Marked token %s as used", token_id)
                return True
            return False
        except Exception as e:
            logger.exception("Error marking token as used: %s", e)
            return False

    @staticmethod
    async def is_token_valid(token_id: str) -> bool:
        """Check whether a refresh token is valid.

        A valid token exists, is not revoked, not used, and not expired.

        Presenting an already-used token is treated as a replay. The whole
        token family is revoked, or only this token if it has no family.
        Errors are logged and reported as invalid.
        """
        try:
            token = await RefreshTokenModel.collection.find_one({"token_id": token_id})
            if not token:
                logger.warning("Token %s not found", token_id)
                return False

            if token.get("revoked"):
                logger.warning("Token %s is revoked", token_id)
                return False

            if token.get("used"):
                logger.warning("Token %s already used - possible replay attack", token_id)
                family_id = token.get("family_id")
                if family_id:
                    # Revoke entire token family on reuse attempt
                    await RefreshTokenModel.revoke_token_family(family_id)
                else:
                    # A {"family_id": None} filter would match every token
                    # without a family (across all users); revoke just this one.
                    await RefreshTokenModel.revoke_token(token_id)
                return False

            expires_at = token.get("expires_at")
            if expires_at is None:
                logger.warning("Token %s has no expiry recorded", token_id)
                return False
            if expires_at < datetime.utcnow():
                logger.warning("Token %s is expired", token_id)
                return False

            return True
        except Exception as e:
            logger.exception("Error checking token validity: %s", e)
            return False

    @staticmethod
    async def get_token_metadata(token_id: str) -> Optional[dict]:
        """Return the stored token document, or ``None`` if missing or on error."""
        try:
            return await RefreshTokenModel.collection.find_one({"token_id": token_id})
        except Exception as e:
            logger.exception("Error getting token metadata: %s", e)
            return None

    @staticmethod
    async def revoke_token(token_id: str) -> bool:
        """Revoke a specific refresh token.

        Returns:
            ``True`` if a document was modified, ``False`` otherwise or on error.
        """
        try:
            result = await RefreshTokenModel.collection.update_one(
                {"token_id": token_id},
                {"$set": {"revoked": True, "revoked_at": datetime.utcnow()}},
            )
            if result.modified_count > 0:
                logger.info("Revoked token %s", token_id)
                return True
            return False
        except Exception as e:
            logger.exception("Error revoking token: %s", e)
            return False

    @staticmethod
    async def revoke_token_family(family_id: str, reason: str = "token_reuse_detected") -> int:
        """Revoke all not-yet-revoked tokens in a family.

        The family's Redis entry is also deleted.

        Args:
            family_id: Family to revoke. A falsy id is rejected, because a
                null filter would match every token lacking a family.
            reason: Value stored in ``revoke_reason`` on each revoked token.

        Returns:
            Number of token documents revoked, or ``0`` on database error.
            A failure to delete the Redis entry is logged but does not hide
            the revocation count.
        """
        if not family_id:
            logger.warning("Refusing to revoke token family with empty id %r", family_id)
            return 0

        try:
            result = await RefreshTokenModel.collection.update_many(
                {"family_id": family_id, "revoked": False},
                {
                    "$set": {
                        "revoked": True,
                        "revoked_at": datetime.utcnow(),
                        "revoke_reason": reason,
                    }
                },
            )
        except Exception as e:
            logger.exception("Error revoking token family: %s", e)
            return 0

        # Also delete the family from Redis (best effort: tokens are already
        # revoked in the database, which is what validity checks consult).
        try:
            redis = await get_redis()
            await redis.delete(RefreshTokenModel._family_key(family_id))
        except Exception as e:
            logger.exception("Error revoking token family: %s", e)

        logger.warning("Revoked %s tokens in family %s", result.modified_count, family_id)
        return result.modified_count

    @staticmethod
    async def revoke_all_user_tokens(customer_id: str) -> int:
        """Revoke all refresh tokens for a user (logout from all devices).

        Returns:
            Number of token documents revoked, or ``0`` on error.
        """
        try:
            result = await RefreshTokenModel.collection.update_many(
                {"customer_id": customer_id, "revoked": False},
                {
                    "$set": {
                        "revoked": True,
                        "revoked_at": datetime.utcnow(),
                        "revoke_reason": "user_logout_all",
                    }
                },
            )
            logger.info("Revoked %s tokens for user %s", result.modified_count, customer_id)
            return result.modified_count
        except Exception as e:
            logger.exception("Error revoking all user tokens: %s", e)
            return 0

    @staticmethod
    async def get_active_sessions(customer_id: str, limit: int = 100) -> list:
        """Return the user's currently redeemable refresh tokens.

        A token counts if it is not revoked, not yet used (rotated away), and
        not expired. This matches the checks in :meth:`is_token_valid`.

        Args:
            customer_id: User whose sessions to list.
            limit: Maximum number of documents to return.

        Returns:
            A list of token documents, or ``[]`` on error.
        """
        try:
            return await RefreshTokenModel.collection.find({
                "customer_id": customer_id,
                "revoked": False,
                # Rotated-away tokens are no longer valid sessions.
                "used": {"$ne": True},
                "expires_at": {"$gt": datetime.utcnow()},
            }).to_list(length=limit)
        except Exception as e:
            logger.exception("Error getting active sessions: %s", e)
            return []

    @staticmethod
    async def cleanup_expired_tokens() -> int:
        """Delete token documents whose ``expires_at`` is in the past.

        Intended to run periodically. Used and revoked tokens are kept until
        they expire, so replay detection in :meth:`is_token_valid` still works
        while they could otherwise have been redeemed.

        Returns:
            Number of documents deleted, or ``0`` on error.
        """
        try:
            result = await RefreshTokenModel.collection.delete_many({
                "expires_at": {"$lt": datetime.utcnow()}
            })
            logger.info("Cleaned up %s expired tokens", result.deleted_count)
            return result.deleted_count
        except Exception as e:
            logger.exception("Error cleaning up expired tokens: %s", e)
            return 0