# bookmyservice-ums / app/models/refresh_token_model.py
# (Repository file-view header: MukeshKapoor25, commit 9b51d59 "remember me")
import json
import logging
import uuid
from datetime import datetime, timedelta
from typing import Optional

from app.core.cache_client import get_redis
from app.core.nosql_client import db
# Module logger. The explicit name is kept as-is (rather than __name__) so
# any logging configuration keyed on "refresh_token_model" keeps matching.
_LOGGER_NAME = "refresh_token_model"
logger = logging.getLogger(_LOGGER_NAME)
class RefreshTokenModel:
    """Model for managing refresh tokens with rotation support.

    Token documents are stored in the ``refresh_tokens`` collection. Token
    families (all tokens sharing a ``family_id``) are tracked in Redis as a
    JSON record under ``token_family:<family_id>`` that expires after
    ``TOKEN_FAMILY_TTL`` seconds. Presenting an already-used token revokes
    the whole family (reuse detection).

    Timestamps written by this class are naive UTC (``datetime.utcnow()``).
    """

    collection = db["refresh_tokens"]

    # Token family tracking for rotation: lifetime of the Redis family record.
    TOKEN_FAMILY_TTL = 30 * 24 * 3600  # 30 days in seconds

    @staticmethod
    def _family_key(family_id: str) -> str:
        """Return the Redis key holding the JSON record for ``family_id``."""
        return f"token_family:{family_id}"

    @staticmethod
    async def create_token_family(customer_id: str, device_info: Optional[str] = None) -> str:
        """Create a new token family for refresh token rotation.

        Args:
            customer_id: Owner of the family.
            device_info: Optional device description stored in the record.

        Returns:
            The new family id (a random UUID4 string).

        Raises:
            Any error from Redis is logged and re-raised.
        """
        family_id = str(uuid.uuid4())
        family_data = {
            "customer_id": customer_id,
            "device_info": device_info,
            "created_at": datetime.utcnow().isoformat(),
            "rotation_count": 0,
        }
        try:
            redis = await get_redis()
            await redis.setex(
                RefreshTokenModel._family_key(family_id),
                RefreshTokenModel.TOKEN_FAMILY_TTL,
                json.dumps(family_data),
            )
        except Exception as e:
            logger.error("Error creating token family: %s", e, exc_info=True)
            raise
        logger.info("Created token family %s for user %s", family_id, customer_id)
        return family_id

    @staticmethod
    async def get_token_family(family_id: str) -> Optional[dict]:
        """Return the decoded family record.

        Returns:
            The record dict, or None if the key is missing/expired or any
            error occurs (errors are logged, not raised).
        """
        try:
            redis = await get_redis()
            data = await redis.get(RefreshTokenModel._family_key(family_id))
            if not data:
                return None
            return json.loads(data)
        except Exception as e:
            logger.error("Error getting token family: %s", e, exc_info=True)
            return None

    @staticmethod
    async def increment_rotation_count(family_id: str) -> int:
        """Increment a family's ``rotation_count`` and stamp ``last_rotated``.

        The record is rewritten with its remaining TTL, so rotating never
        extends the family's lifetime.

        Returns:
            The new rotation count, or 0 if the family does not exist, vanished
            during the update, or an error occurred (errors are logged).

        Note:
            This is a non-atomic read-modify-write of a JSON string; concurrent
            rotations of the same family can lose an increment.
        """
        try:
            redis = await get_redis()
            family_key = RefreshTokenModel._family_key(family_id)
            family_data = await RefreshTokenModel.get_token_family(family_id)
            if not family_data:
                return 0

            new_count = family_data.get("rotation_count", 0) + 1
            family_data["rotation_count"] = new_count
            family_data["last_rotated"] = datetime.utcnow().isoformat()
            payload = json.dumps(family_data)

            # Redis TTL: >0 = seconds left, -1 = key has no expiry,
            # -2 = key is gone (0 = about to expire). Only report success
            # when the new record was actually written.
            ttl = await redis.ttl(family_key)
            if ttl > 0:
                await redis.setex(family_key, ttl, payload)
            elif ttl == -1:
                await redis.set(family_key, payload)
            else:
                logger.warning(
                    "Token family %s expired before its rotation count could be updated",
                    family_id,
                )
                return 0

            logger.info("Incremented rotation count for family %s to %s", family_id, new_count)
            return new_count
        except Exception as e:
            logger.error("Error incrementing rotation count: %s", e, exc_info=True)
            return 0

    @staticmethod
    async def store_refresh_token(
        token_id: str,
        customer_id: str,
        family_id: str,
        expires_at: datetime,
        remember_me: bool = False,
        device_info: Optional[str] = None,
        ip_address: Optional[str] = None
    ) -> None:
        """Store refresh token metadata as a new, unused, unrevoked document.

        Raises:
            Any error from the database insert is logged and re-raised.
        """
        token_doc = {
            "token_id": token_id,
            "customer_id": customer_id,
            "family_id": family_id,
            "expires_at": expires_at,
            "remember_me": remember_me,
            "device_info": device_info,
            "ip_address": ip_address,
            "created_at": datetime.utcnow(),
            "revoked": False,
            "used": False,
        }
        try:
            await RefreshTokenModel.collection.insert_one(token_doc)
        except Exception as e:
            logger.error("Error storing refresh token: %s", e, exc_info=True)
            raise
        logger.info("Stored refresh token %s for user %s", token_id, customer_id)

    @staticmethod
    async def mark_token_as_used(token_id: str) -> bool:
        """Atomically claim a refresh token for rotation.

        The update only matches a token that is not yet used and not revoked.
        If two requests race with the same token, exactly one gets True. An
        existing ``used_at`` is never overwritten.

        Returns:
            True if this call transitioned the token to used; False if the
            token is unknown, already used, revoked, or an error occurred.
        """
        try:
            result = await RefreshTokenModel.collection.update_one(
                # $ne (rather than == False) also matches documents lacking the field.
                {"token_id": token_id, "used": {"$ne": True}, "revoked": {"$ne": True}},
                {"$set": {"used": True, "used_at": datetime.utcnow()}},
            )
        except Exception as e:
            logger.error("Error marking token as used: %s", e, exc_info=True)
            return False
        if result.modified_count > 0:
            logger.info("Marked token %s as used", token_id)
            return True
        return False

    @staticmethod
    async def is_token_valid(token_id: str) -> bool:
        """Check if a refresh token is valid (exists, not revoked, not used, not expired).

        Side effect: presenting an already-used token revokes its entire
        family, as a precaution against a replayed token.

        Returns:
            True only if every check passes; False otherwise or on error.
        """
        try:
            token = await RefreshTokenModel.collection.find_one({"token_id": token_id})
            if not token:
                logger.warning("Token %s not found", token_id)
                return False
            if token.get("revoked"):
                logger.warning("Token %s is revoked", token_id)
                return False
            if token.get("used"):
                logger.warning("Token %s already used - possible replay attack", token_id)
                # Revoke entire token family on reuse attempt; revoke_token_family
                # itself refuses an empty/missing family_id.
                await RefreshTokenModel.revoke_token_family(token.get("family_id"))
                return False

            expires_at = token.get("expires_at")
            if expires_at is None:
                logger.warning("Token %s has no expiry", token_id)
                return False
            if expires_at.tzinfo is not None:
                # Normalize aware datetimes to naive UTC; comparing aware with
                # naive would raise TypeError.
                expires_at = expires_at.replace(tzinfo=None) - expires_at.utcoffset()
            if expires_at < datetime.utcnow():
                logger.warning("Token %s is expired", token_id)
                return False
            return True
        except Exception as e:
            logger.error("Error checking token validity: %s", e, exc_info=True)
            return False

    @staticmethod
    async def get_token_metadata(token_id: str) -> Optional[dict]:
        """Return the stored token document, or None if absent or on error."""
        try:
            return await RefreshTokenModel.collection.find_one({"token_id": token_id})
        except Exception as e:
            logger.error("Error getting token metadata: %s", e, exc_info=True)
            return None

    @staticmethod
    async def revoke_token(token_id: str) -> bool:
        """Revoke a specific refresh token.

        Returns:
            True if a document was modified, False otherwise or on error.
        """
        try:
            result = await RefreshTokenModel.collection.update_one(
                {"token_id": token_id},
                {"$set": {"revoked": True, "revoked_at": datetime.utcnow()}},
            )
        except Exception as e:
            logger.error("Error revoking token: %s", e, exc_info=True)
            return False
        if result.modified_count > 0:
            logger.info("Revoked token %s", token_id)
            return True
        return False

    @staticmethod
    async def revoke_token_family(family_id: str, reason: str = "token_reuse_detected") -> int:
        """Revoke all unrevoked tokens in a family and drop its Redis record.

        Args:
            family_id: Family to revoke. Empty/None is rejected (returns 0):
                an equality match on None would also match every document
                missing ``family_id``.
            reason: Value written to ``revoke_reason``.

        Returns:
            Number of tokens revoked; 0 on error.
        """
        if not family_id:
            logger.warning("Refusing to revoke token family with empty family_id")
            return 0
        try:
            result = await RefreshTokenModel.collection.update_many(
                {"family_id": family_id, "revoked": False},
                {
                    "$set": {
                        "revoked": True,
                        "revoked_at": datetime.utcnow(),
                        "revoke_reason": reason,
                    }
                },
            )
        except Exception as e:
            logger.error("Error revoking token family: %s", e, exc_info=True)
            return 0

        revoked = result.modified_count
        # Best-effort: also delete the family from Redis. A Redis failure must
        # not hide that the tokens themselves were revoked.
        try:
            redis = await get_redis()
            await redis.delete(RefreshTokenModel._family_key(family_id))
        except Exception as e:
            logger.error(
                "Error deleting token family %s from Redis: %s", family_id, e, exc_info=True
            )
        logger.warning("Revoked %s tokens in family %s", revoked, family_id)
        return revoked

    @staticmethod
    async def revoke_all_user_tokens(customer_id: str) -> int:
        """Revoke every unrevoked refresh token of a user (logout from all devices).

        Returns:
            Number of tokens revoked; 0 on error.
        """
        try:
            result = await RefreshTokenModel.collection.update_many(
                {"customer_id": customer_id, "revoked": False},
                {
                    "$set": {
                        "revoked": True,
                        "revoked_at": datetime.utcnow(),
                        "revoke_reason": "user_logout_all",
                    }
                },
            )
        except Exception as e:
            logger.error("Error revoking all user tokens: %s", e, exc_info=True)
            return 0
        logger.info("Revoked %s tokens for user %s", result.modified_count, customer_id)
        return result.modified_count

    @staticmethod
    async def get_active_sessions(customer_id: str, limit: int = 100) -> list:
        """Get all active sessions (valid refresh tokens) for a user.

        A token counts as active when it is not revoked, not used (tokens
        rotated away are marked used and would otherwise show up as duplicate
        sessions), and not expired.

        Args:
            customer_id: User whose sessions to list.
            limit: Maximum number of documents returned.

        Returns:
            List of token documents; empty list on error.
        """
        try:
            cursor = RefreshTokenModel.collection.find({
                "customer_id": customer_id,
                "revoked": False,
                "used": {"$ne": True},
                "expires_at": {"$gt": datetime.utcnow()},
            })
            return await cursor.to_list(length=limit)
        except Exception as e:
            logger.error("Error getting active sessions: %s", e, exc_info=True)
            return []

    @staticmethod
    async def cleanup_expired_tokens() -> int:
        """Delete all token documents whose ``expires_at`` is in the past.

        Intended to be run periodically. Returns the number deleted, 0 on error.
        """
        try:
            result = await RefreshTokenModel.collection.delete_many({
                "expires_at": {"$lt": datetime.utcnow()}
            })
        except Exception as e:
            logger.error("Error cleaning up expired tokens: %s", e, exc_info=True)
            return 0
        logger.info("Cleaned up %s expired tokens", result.deleted_count)
        return result.deleted_count