## bookmyservice-ums/app/utils/jwt.py
"""JWT helpers: token creation, decoding/verification and FastAPI auth dependencies."""
from jose import jwt, JWTError
from datetime import datetime, timedelta, timezone
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional, Tuple
from app.core.config import settings
import logging
import uuid

SECRET_KEY = settings.JWT_SECRET_KEY
ALGORITHM = settings.JWT_ALGORITHM
ACCESS_EXPIRE_MINUTES_DEFAULT = settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
REFRESH_EXPIRE_DAYS_DEFAULT = settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
TEMP_EXPIRE_MINUTES_DEFAULT = settings.JWT_TEMP_TOKEN_EXPIRE_MINUTES

# Remember me settings
REMEMBER_ME_REFRESH_EXPIRE_DAYS = settings.JWT_REMEMBER_ME_EXPIRE_DAYS

# Security scheme
security = HTTPBearer()

# Module logger (app-level logging config applies)
logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Drop-in replacement for the deprecated ``datetime.utcnow()``. The result is
    kept naive so the ``expire`` value returned by ``create_refresh_token``
    keeps the same type it always had.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _credentials_exception() -> HTTPException:
    """Build the 401 error raised whenever a bearer token is not acceptable."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )


def create_access_token(data: dict, expires_minutes: int = ACCESS_EXPIRE_MINUTES_DEFAULT) -> str:
    """Create a signed access token.

    Args:
        data: Claims to embed. The dict is copied, not mutated.
        expires_minutes: Token lifetime in minutes; sets the ``exp`` claim.

    Returns:
        The encoded JWT string.
    """
    to_encode = data.copy()
    expire = _utcnow() + timedelta(minutes=expires_minutes)
    to_encode.update({"exp": expire})
    # Avoid logging sensitive payload; log minimal context
    logger.info(
        "Creating access token",
    )
    logger.info(
        "Access token claims keys=%s expires_at=%s",
        list(to_encode.keys()),
        expire.isoformat(),
    )
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


def create_refresh_token(
    data: dict,
    expires_days: int = REFRESH_EXPIRE_DAYS_DEFAULT,
    remember_me: bool = False,
    family_id: Optional[str] = None,
) -> Tuple[str, str, datetime]:
    """Create refresh token with rotation support.

    Args:
        data: Base claims to embed. The dict is copied, not mutated.
        expires_days: Lifetime in days. Ignored when ``remember_me`` is true.
        remember_me: Use ``REMEMBER_ME_REFRESH_EXPIRE_DAYS`` as the lifetime;
            the flag is also stored as a claim.
        family_id: Optional rotation-family identifier. It is only added as a
            claim when truthy.

    Returns:
        A tuple ``(encoded_token, jti, expire)``, where ``expire`` is a naive
        UTC datetime.
    """
    to_encode = data.copy()

    # Use longer expiry for remember me
    if remember_me:
        expires_days = REMEMBER_ME_REFRESH_EXPIRE_DAYS

    expire = _utcnow() + timedelta(days=expires_days)

    # Generate unique token ID for tracking
    token_id = str(uuid.uuid4())

    to_encode.update({
        "exp": expire,
        "type": "refresh",  # lets get_current_user refuse refresh tokens as bearer creds
        "jti": token_id,  # JWT ID for token tracking
        "remember_me": remember_me,
    })

    # Add family ID for rotation tracking
    if family_id:
        to_encode["family_id"] = family_id

    logger.info("Creating refresh token")
    logger.info(
        "Refresh token claims keys=%s expires_at=%s remember_me=%s",
        list(to_encode.keys()),
        expire.isoformat(),
        remember_me,
    )
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM), token_id, expire


def create_temp_token(data: dict, expires_minutes: int = TEMP_EXPIRE_MINUTES_DEFAULT) -> str:
    """Create a short-lived access token (same format as ``create_access_token``)."""
    logger.info("Creating temporary access token with short expiry")
    return create_access_token(data, expires_minutes=expires_minutes)


def decode_token(token: str) -> dict:
    """Decode and verify ``token`` without raising.

    Returns:
        The claims dict, or an empty dict if decoding or verification fails.
        This is a deliberate best-effort API.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        logger.info("Token decoded successfully")
        logger.info("Decoded claims keys=%s", list(payload.keys()))
        return payload
    except JWTError as e:
        logger.warning("Token decode failed: %s", str(e))
        return {}


def verify_token(token: str) -> dict:
    """
    Verify and decode JWT token, raise HTTPException if invalid.

    Raises:
        HTTPException: 401 when the token fails verification or has no ``sub``
            claim.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as e:
        logger.error("Token verification failed: %s", str(e))
        # Chain the JWT error so the root cause survives in tracebacks.
        raise _credentials_exception() from e

    customer_id: Optional[str] = payload.get("sub")
    if customer_id is None:
        logger.warning("Verified token missing 'sub' claim")
        raise _credentials_exception()

    logger.info("Token verified for subject")
    logger.info("Verified claims keys=%s", list(payload.keys()))
    return payload


async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict:
    """
    Dependency to get current authenticated user from JWT token.

    Refresh tokens (``type == "refresh"``) are rejected. They are long-lived
    and must not be usable as bearer credentials for regular requests.

    Raises:
        HTTPException: 401 for invalid tokens or refresh tokens.
    """
    token = credentials.credentials
    logger.info("Authenticating request with Bearer token")
    # Don't log raw tokens; log minimal metadata
    logger.info("Bearer token length=%d", len(token) if token else 0)
    payload = verify_token(token)
    if payload.get("type") == "refresh":
        logger.warning("Refresh token presented as bearer credential; rejecting")
        raise _credentials_exception()
    return payload


async def get_current_customer_id(current_user: dict = Depends(get_current_user)) -> str:
    """
    Dependency to get current user ID.

    Returns the ``sub`` claim, which ``verify_token`` guarantees is present.
    """
    customer_id = current_user.get("sub")
    logger.info("Resolved current customer id")
    logger.info("Current customer id=%s", customer_id)
    return customer_id