File size: 4,668 Bytes
b407a42
 
 
8ee3f02
b407a42
2e26396
 
 
975cfc8
ed83da2
9b51d59
b407a42
975cfc8
 
 
 
 
b407a42
9b51d59
 
 
2e26396
 
8ee3f02
ed83da2
 
 
975cfc8
b407a42
 
 
8a17472
ed83da2
 
 
 
 
 
 
 
 
8ee3f02
 
9b51d59
 
 
 
 
 
 
4ddb8b3
9b51d59
 
 
 
 
4ddb8b3
9b51d59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed83da2
 
9b51d59
ed83da2
 
9b51d59
ed83da2
9b51d59
 
4ddb8b3
975cfc8
ed83da2
8ee3f02
 
 
 
ed83da2
 
 
 
 
 
2e26396
 
 
 
 
 
 
 
 
 
 
 
 
 
a9ccd3b
 
ed83da2
2e26396
ed83da2
 
2e26396
ed83da2
 
2e26396
 
 
 
 
 
 
ed83da2
 
 
2e26396
 
a9ccd3b
2e26396
 
 
ed83da2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137

## bookmyservice-ums/app/utils/jwt.py

from jose import jwt, JWTError
from datetime import datetime, timedelta
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional
from app.core.config import settings
import logging
import uuid

# Signing configuration, read once from application settings at import time.
SECRET_KEY = settings.JWT_SECRET_KEY
ALGORITHM = settings.JWT_ALGORITHM
# Default token lifetimes (access/temp in minutes, refresh in days). They are
# bound as default argument values when the functions below are defined, so
# later changes to `settings` do not affect those defaults.
ACCESS_EXPIRE_MINUTES_DEFAULT = settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
REFRESH_EXPIRE_DAYS_DEFAULT = settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
TEMP_EXPIRE_MINUTES_DEFAULT = settings.JWT_TEMP_TOKEN_EXPIRE_MINUTES

# Remember me settings
# Refresh-token lifetime in days, used instead of `expires_days` when
# create_refresh_token is called with remember_me=True.
REMEMBER_ME_REFRESH_EXPIRE_DAYS = settings.JWT_REMEMBER_ME_EXPIRE_DAYS

# Security scheme
# FastAPI dependency that extracts the credentials from an
# "Authorization: Bearer <token>" request header.
security = HTTPBearer()

# Module logger (app-level logging config applies)
logger = logging.getLogger(__name__)

def create_access_token(data: dict, expires_minutes: int = ACCESS_EXPIRE_MINUTES_DEFAULT):
    """Build and sign an access JWT.

    The caller's claims are copied (``data`` itself is never mutated), and an
    ``exp`` claim is set ``expires_minutes`` from now using naive UTC time.
    Any ``exp`` already present in ``data`` is overwritten.

    Args:
        data: Claims to embed in the token.
        expires_minutes: Token lifetime in minutes.

    Returns:
        The encoded JWT string, signed with ``SECRET_KEY``/``ALGORITHM``.
    """
    expires_at = datetime.utcnow() + timedelta(minutes=expires_minutes)
    claims = data.copy()
    claims["exp"] = expires_at

    # Log only claim names and the expiry -- never claim values or the token.
    logger.info("Creating access token")
    logger.info(
        "Access token claims keys=%s expires_at=%s",
        list(claims),
        expires_at.isoformat(),
    )
    token = jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
    return token

def create_refresh_token(
    data: dict,
    expires_days: int = REFRESH_EXPIRE_DAYS_DEFAULT,
    remember_me: bool = False,
    family_id: Optional[str] = None
):
    """Build and sign a refresh JWT carrying rotation-tracking claims.

    Args:
        data: Base claims to embed; copied, never mutated.
        expires_days: Lifetime in days. Ignored when ``remember_me`` is true,
            in which case ``REMEMBER_ME_REFRESH_EXPIRE_DAYS`` is used instead.
        remember_me: Selects the longer lifetime; also stored as a claim.
        family_id: Optional rotation-family identifier. Added as the
            ``family_id`` claim only when truthy.

    Returns:
        A 3-tuple ``(encoded_token, jti, expires_at)``. ``jti`` is a fresh
        UUID4 string also embedded in the token, and ``expires_at`` is the
        naive-UTC datetime stored in ``exp``.
    """
    claims = data.copy()

    lifetime_days = REMEMBER_ME_REFRESH_EXPIRE_DAYS if remember_me else expires_days
    expires_at = datetime.utcnow() + timedelta(days=lifetime_days)

    # Unique token id so individual refresh tokens can be tracked/revoked.
    jti = str(uuid.uuid4())

    claims.update(exp=expires_at, type="refresh", jti=jti, remember_me=remember_me)

    # Rotation tracking: tokens issued from the same chain share a family id.
    if family_id:
        claims["family_id"] = family_id

    logger.info("Creating refresh token")
    logger.info(
        "Refresh token claims keys=%s expires_at=%s remember_me=%s",
        list(claims),
        expires_at.isoformat(),
        remember_me
    )

    encoded = jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
    return encoded, jti, expires_at

def create_temp_token(data: dict, expires_minutes: int = TEMP_EXPIRE_MINUTES_DEFAULT):
    """Return an access token whose default lifetime is ``TEMP_EXPIRE_MINUTES_DEFAULT``.

    This delegates to ``create_access_token``. No extra claim marks the token
    as temporary, so it is structurally identical to a regular access token.
    """
    logger.info("Creating temporary access token with short expiry")
    token = create_access_token(data, expires_minutes=expires_minutes)
    return token

def decode_token(token: str) -> dict:
    """Decode and verify ``token``, returning ``{}`` instead of raising on failure.

    Any ``JWTError`` raised by ``jwt.decode`` (for example a bad signature or
    an expired token) is logged at WARNING and swallowed.

    Returns:
        The decoded claims, or an empty dict if the token is invalid.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        logger.warning("Token decode failed: %s", str(exc))
        return {}

    logger.info("Token decoded successfully")
    logger.info("Decoded claims keys=%s", list(claims))
    return claims

def verify_token(token: str) -> dict:
    """Verify and decode a JWT, raising ``HTTPException`` (401) if it is invalid.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded claims dict. Its ``sub`` claim is guaranteed non-empty.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``jwt.decode`` rejects the token, or when the ``sub`` claim is
            missing or empty.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as e:
        # An invalid/expired token is a client error, not a server fault: log
        # at WARNING (as decode_token does) so it doesn't trigger error alerts.
        logger.warning("Token verification failed: %s", str(e))
        raise credentials_exception from e

    customer_id: Optional[str] = payload.get("sub")
    # An empty subject cannot identify anyone; reject it like a missing one.
    if not customer_id:
        logger.warning("Verified token missing 'sub' claim")
        raise credentials_exception

    logger.info("Token verified for subject")
    logger.info("Verified claims keys=%s", list(payload.keys()))
    return payload

async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict:
    """FastAPI dependency returning the verified claims of the request's Bearer token.

    Tokens minted by ``create_refresh_token`` (claim ``type == "refresh"``)
    are rejected here, because they have a much longer lifetime than access
    tokens. NOTE(review): this assumes no endpoint authenticates with a
    refresh token through this dependency -- confirm against the refresh route.

    Raises:
        HTTPException: 401 if ``verify_token`` rejects the token, or if the
            token is a refresh token.
    """
    token = credentials.credentials
    logger.info("Authenticating request with Bearer token")
    # Don't log raw tokens; log minimal metadata
    logger.info("Bearer token length=%d", len(token) if token else 0)
    payload = verify_token(token)

    # Without this check, a refresh token whose data includes "sub" passes
    # verify_token and would authenticate API calls for its whole lifetime.
    if payload.get("type") == "refresh":
        logger.warning("Refresh token presented as access credentials")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return payload

async def get_current_customer_id(current_user: dict = Depends(get_current_user)) -> str:
    """FastAPI dependency returning the ``sub`` claim of the authenticated token.

    ``current_user`` is the claims dict produced by ``get_current_user``.
    """
    subject = current_user.get("sub")
    logger.info("Resolved current customer id")
    logger.info("Current customer id=%s", subject)
    return subject