MukeshKapoor25's picture
remember me
9b51d59
## bookmyservice-ums/app/utils/jwt.py
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import jwt, JWTError

from app.core.config import settings
# Module logger (app-level logging config applies)
logger = logging.getLogger(__name__)

# Bearer-token extractor consumed by the ``get_current_user`` dependency.
security = HTTPBearer()

# Signing configuration, read once from settings at import time.
SECRET_KEY = settings.JWT_SECRET_KEY
ALGORITHM = settings.JWT_ALGORITHM

# Default token lifetimes. ``REMEMBER_ME_REFRESH_EXPIRE_DAYS`` replaces the
# refresh-token lifetime when ``create_refresh_token`` is called with
# ``remember_me=True``.
ACCESS_EXPIRE_MINUTES_DEFAULT = settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
REFRESH_EXPIRE_DAYS_DEFAULT = settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
TEMP_EXPIRE_MINUTES_DEFAULT = settings.JWT_TEMP_TOKEN_EXPIRE_MINUTES
REMEMBER_ME_REFRESH_EXPIRE_DAYS = settings.JWT_REMEMBER_ME_EXPIRE_DAYS
def create_access_token(data: dict, expires_minutes: int = ACCESS_EXPIRE_MINUTES_DEFAULT):
    """Create a signed access JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            not mutated; an ``exp`` key in it is overwritten.
        expires_minutes: Token lifetime in minutes from now.

    Returns:
        The token produced by ``jwt.encode`` with ``SECRET_KEY`` and
        ``ALGORITHM``.
    """
    to_encode = data.copy()
    # ``datetime.utcnow()`` is deprecated since Python 3.12. Derive the same
    # naive-UTC instant from an aware clock so the encoded ``exp`` and the
    # logged timestamp are unchanged.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    expire = now_utc + timedelta(minutes=expires_minutes)
    to_encode.update({"exp": expire})
    # Avoid logging sensitive payload; log minimal context
    logger.info(
        "Creating access token",
    )
    logger.info(
        "Access token claims keys=%s expires_at=%s",
        list(to_encode.keys()),
        expire.isoformat(),
    )
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_refresh_token(
    data: dict,
    expires_days: int = REFRESH_EXPIRE_DAYS_DEFAULT,
    remember_me: bool = False,
    family_id: Optional[str] = None
):
    """Create refresh token with rotation support.

    Args:
        data: Claims to embed. The dict is copied; ``exp``, ``type``, ``jti``
            and ``remember_me`` keys in it are overwritten.
        expires_days: Lifetime in days. Ignored when ``remember_me`` is true,
            in which case ``REMEMBER_ME_REFRESH_EXPIRE_DAYS`` is used.
        remember_me: Use the longer "remember me" lifetime; also stored as a
            claim.
        family_id: Optional rotation-family identifier. It is added as the
            ``family_id`` claim only when truthy.

    Returns:
        A tuple ``(token, token_id, expire)``:
        - ``token``: the encoded JWT.
        - ``token_id``: the random UUID4 string stored in ``jti``.
        - ``expire``: the expiry as a naive datetime in UTC.
    """
    to_encode = data.copy()
    # Use longer expiry for remember me
    if remember_me:
        expires_days = REMEMBER_ME_REFRESH_EXPIRE_DAYS
    # ``datetime.utcnow()`` is deprecated since Python 3.12. Keep the value
    # naive-UTC because ``expire`` is returned to callers, who may compare or
    # store it alongside other naive UTC datetimes.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    expire = now_utc + timedelta(days=expires_days)
    # Generate unique token ID for tracking
    token_id = str(uuid.uuid4())
    to_encode.update({
        "exp": expire,
        "type": "refresh",
        "jti": token_id,  # JWT ID for token tracking
        "remember_me": remember_me
    })
    # Add family ID for rotation tracking
    if family_id:
        to_encode["family_id"] = family_id
    logger.info("Creating refresh token")
    logger.info(
        "Refresh token claims keys=%s expires_at=%s remember_me=%s",
        list(to_encode.keys()),
        expire.isoformat(),
        remember_me
    )
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM), token_id, expire
def create_temp_token(data: dict, expires_minutes: int = TEMP_EXPIRE_MINUTES_DEFAULT):
    """Issue a short-lived access token.

    Thin wrapper around ``create_access_token``. The only difference is the
    default lifetime, which is ``TEMP_EXPIRE_MINUTES_DEFAULT`` minutes instead
    of the regular access-token default.
    """
    logger.info("Creating temporary access token with short expiry")
    temp_token = create_access_token(data, expires_minutes=expires_minutes)
    return temp_token
def decode_token(token: str) -> dict:
    """Decode ``token`` with ``SECRET_KEY``/``ALGORITHM``; ``{}`` on failure.

    Unlike ``verify_token``, this never raises for a bad token. Any
    ``JWTError`` from ``jwt.decode`` is logged as a warning and an empty dict
    is returned, so callers must treat an empty result as "invalid".
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        logger.warning("Token decode failed: %s", str(err))
        return {}
    logger.info("Token decoded successfully")
    logger.info("Decoded claims keys=%s", list(claims.keys()))
    return claims
def verify_token(token: str) -> dict:
    """
    Verify and decode JWT token, raise HTTPException if invalid.

    The token must decode with ``SECRET_KEY``/``ALGORITHM`` and carry a
    non-empty ``sub`` (subject) claim.

    Returns:
        The decoded claims dict.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            decoding fails or the subject claim is missing or empty.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Keep the try body to the one call that can raise JWTError.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as e:
        logger.error("Token verification failed: %s", str(e))
        # Chain explicitly so the underlying JWT error is kept as the cause.
        raise credentials_exception from e
    customer_id: Optional[str] = payload.get("sub")
    # An empty subject cannot identify a customer; treat it like a missing one.
    if not customer_id:
        logger.warning("Verified token missing 'sub' claim")
        raise credentials_exception
    logger.info("Token verified for subject")
    logger.info("Verified claims keys=%s", list(payload.keys()))
    return payload
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict:
    """
    Dependency to get current authenticated user from JWT token.

    Takes the Bearer credentials, validates them with ``verify_token`` and
    returns the decoded claims. Tokens whose ``type`` claim is ``"refresh"``
    (the marker ``create_refresh_token`` sets) are rejected. Refresh tokens
    are long-lived and are not meant to authenticate ordinary requests.

    NOTE(review): confirm no endpoint relies on passing a refresh token
    through this dependency; such an endpoint should read the refresh token
    explicitly instead.

    Raises:
        HTTPException: 401 if the token is invalid or is a refresh token.
    """
    token = credentials.credentials
    logger.info("Authenticating request with Bearer token")
    # Don't log raw tokens; log minimal metadata
    logger.info("Bearer token length=%d", len(token) if token else 0)
    payload = verify_token(token)
    if payload.get("type") == "refresh":
        logger.warning("Refresh token presented as access token; rejecting")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return payload
async def get_current_customer_id(current_user: dict = Depends(get_current_user)) -> str:
    """
    Return the ``sub`` claim of the authenticated request's token claims.

    ``current_user`` is resolved by FastAPI through ``get_current_user``.
    This dependency only extracts the subject (customer id) from those claims.
    """
    subject = current_user.get("sub")
    logger.info("Resolved current customer id")
    logger.info("Current customer id=%s", subject)
    return subject