# ============================================================
# app/core/security.py - JWT & Password Management (Secure)
# ============================================================

import logging
import re
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import jwt
from passlib.context import CryptContext

from app.config import settings

logger = logging.getLogger(__name__)


# ============================================================
# Startup Validation
# ============================================================

def validate_jwt_config():
    """Validate the JWT configuration; intended to be called at startup.

    Raises:
        RuntimeError: if ``settings.JWT_SECRET`` or
            ``settings.JWT_REFRESH_SECRET`` is empty/unset, or if both
            secrets have the same value.
    """
    if not settings.JWT_SECRET:
        raise RuntimeError(
            "CRITICAL: JWT_SECRET environment variable must be set! "
            "Cannot start application without a secure JWT secret."
        )

    if not settings.JWT_REFRESH_SECRET:
        raise RuntimeError(
            "CRITICAL: JWT_REFRESH_SECRET environment variable must be set! "
            "Cannot start application without a secure refresh token secret."
        )

    # Distinct secrets ensure a token signed for one purpose can never pass
    # signature verification for the other.
    if settings.JWT_SECRET == settings.JWT_REFRESH_SECRET:
        raise RuntimeError(
            "CRITICAL: JWT_SECRET and JWT_REFRESH_SECRET must be different!"
        )

    logger.info("✅ JWT configuration validated")


# ============================================================
# Password Hashing
# ============================================================

pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
    bcrypt__rounds=settings.BCRYPT_ROUNDS,
)


def hash_password(password: str) -> str:
    """Hash ``password`` with the module's bcrypt ``pwd_context``.

    Args:
        password: Plain-text password to hash.

    Returns:
        The encoded hash string produced by ``pwd_context.hash``.

    Raises:
        Exception: any error raised by the hashing backend is logged and
            then re-raised unchanged.
    """
    try:
        return pwd_context.hash(password)
    except Exception as e:
        logger.error("Error hashing password: %s", e)
        raise


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check ``plain_password`` against ``hashed_password``.

    Args:
        plain_password: Candidate plain-text password.
        hashed_password: Previously stored hash to compare against.

    Returns:
        The result of ``pwd_context.verify``; ``False`` if verification
        raises for any reason (the error is logged, never propagated).
    """
    try:
        return pwd_context.verify(plain_password, hashed_password)
    except Exception as e:
        # Deliberate best-effort: a malformed hash must read as "no match"
        # to the caller rather than surface as an exception.
        logger.error("Error verifying password: %s", e)
        return False


# ============================================================
# Password Strength Validation
# ============================================================

def validate_password_strength(password: str) -> None:
    """Validate that ``password`` meets the minimum strength rules.

    The rules are: at least 8 characters, at least one ``A-Z`` letter,
    at least one ``a-z`` letter, and at least one digit (``\\d``).

    Args:
        password: Plain-text password to check.

    Raises:
        ValueError: if one or more rules fail; the message lists every
            unmet requirement, comma-separated.
    """
    errors = []

    if len(password) < 8:
        errors.append("at least 8 characters")
    if not re.search(r"[A-Z]", password):
        errors.append("an uppercase letter")
    if not re.search(r"[a-z]", password):
        errors.append("a lowercase letter")
    if not re.search(r"\d", password):
        errors.append("a number")

    if errors:
        raise ValueError(f"Password must contain {', '.join(errors)}")


# ============================================================
# JWT Token Management - Dual Token Strategy
# ============================================================

def _create_token(
    data: Dict[str, Any],
    secret: str,
    expires_delta: timedelta,
    token_type: str,
) -> str:
    """Internal: build and sign a JWT carrying the standard security claims.

    Args:
        data: Custom claims to embed. The dict is copied, not mutated.
        secret: Signing key.
        expires_delta: Lifetime of the token, counted from now (UTC).
        token_type: Value stored in the ``type`` claim.

    Returns:
        The encoded JWT, signed with ``settings.JWT_ALGORITHM``.
    """
    now = datetime.now(timezone.utc)

    to_encode = data.copy()
    # Applied after the copy, so these claims override any same-named keys
    # supplied by the caller in ``data``.
    to_encode.update({
        "exp": now + expires_delta,
        "iat": now,                    # Issued At
        "jti": str(uuid.uuid4()),      # JWT ID (for blacklisting)
        "iss": settings.JWT_ISSUER,    # Issuer
        "aud": settings.JWT_AUDIENCE,  # Audience
        "type": token_type,
    })

    return jwt.encode(
        to_encode,
        secret,
        algorithm=settings.JWT_ALGORITHM,
    )


def _verify_token(
    token: str,
    secret: str,
    expected_type: str,
) -> Optional[Dict[str, Any]]:
    """Internal: decode a JWT and validate it.

    Decoding is restricted to ``settings.JWT_ALGORITHM`` and checks the
    issuer and audience against the configured values; afterwards the
    ``type`` claim must equal ``expected_type``.

    Args:
        token: Encoded JWT.
        secret: Key the token is expected to be signed with.
        expected_type: Required value of the ``type`` claim.

    Returns:
        The decoded payload, or ``None`` if the token is expired, invalid,
        of the wrong type, or any other error occurs. This function never
        raises; failures are logged instead.
    """
    try:
        payload = jwt.decode(
            token,
            secret,
            algorithms=[settings.JWT_ALGORITHM],
            issuer=settings.JWT_ISSUER,
            audience=settings.JWT_AUDIENCE,
        )

        # Verify token type
        actual_type = payload.get("type")
        if actual_type != expected_type:
            logger.warning(
                "Token type mismatch: expected %s, got %s",
                expected_type,
                actual_type,
            )
            return None

        return payload

    # The specific handlers come first so each failure mode gets its own
    # log line before the generic InvalidTokenError handler.
    except jwt.ExpiredSignatureError:
        logger.warning("Token has expired")
        return None
    except jwt.InvalidIssuerError:
        logger.warning("Invalid token issuer")
        return None
    except jwt.InvalidAudienceError:
        logger.warning("Invalid token audience")
        return None
    except jwt.InvalidTokenError as e:
        logger.warning("Invalid token: %s", e)
        return None
    except Exception as e:
        logger.error("Error verifying token: %s", e)
        return None


# ============================================================
# Access Token (lifetime: settings.JWT_ACCESS_EXPIRY_MINUTES)
# ============================================================

def create_access_token(
    user_id: str,
    email: Optional[str],
    phone: Optional[str],
    role: str,
) -> str:
    """Create an access token for API authentication.

    The token is signed with ``settings.JWT_SECRET``, has type ``"access"``
    and expires after ``settings.JWT_ACCESS_EXPIRY_MINUTES`` minutes.

    Args:
        user_id: Stored in both the ``user_id`` and ``sub`` claims.
        email: Stored in the ``email`` claim (may be ``None``).
        phone: Stored in the ``phone`` claim (may be ``None``).
        role: Stored in the ``role`` claim.

    Returns:
        The encoded JWT.
    """
    return _create_token(
        data={
            "user_id": user_id,
            "sub": user_id,  # Standard JWT subject claim
            "email": email,
            "phone": phone,
            "role": role,
        },
        secret=settings.JWT_SECRET,
        expires_delta=timedelta(minutes=settings.JWT_ACCESS_EXPIRY_MINUTES),
        token_type="access",
    )


def verify_access_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify and decode an access token.

    Returns:
        The decoded payload, or ``None`` if the token is not a valid
        ``"access"`` token signed with ``settings.JWT_SECRET``.
    """
    return _verify_token(token, settings.JWT_SECRET, "access")


# ============================================================
# Refresh Token (lifetime: settings.JWT_REFRESH_EXPIRY_DAYS)
# ============================================================

def create_refresh_token(
    user_id: str,
    email: Optional[str],
    phone: Optional[str],
    role: str,
) -> str:
    """Create a refresh token for obtaining new access tokens.

    The token is signed with ``settings.JWT_REFRESH_SECRET``, has type
    ``"refresh"`` and expires after ``settings.JWT_REFRESH_EXPIRY_DAYS``
    days.

    Args:
        user_id: Stored in both the ``user_id`` and ``sub`` claims.
        email: Stored in the ``email`` claim (may be ``None``).
        phone: Stored in the ``phone`` claim (may be ``None``).
        role: Stored in the ``role`` claim.

    Returns:
        The encoded JWT.
    """
    return _create_token(
        data={
            "user_id": user_id,
            "sub": user_id,
            "email": email,
            "phone": phone,
            "role": role,
        },
        secret=settings.JWT_REFRESH_SECRET,  # Different secret!
        expires_delta=timedelta(days=settings.JWT_REFRESH_EXPIRY_DAYS),
        token_type="refresh",
    )


def verify_refresh_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify and decode a refresh token.

    Returns:
        The decoded payload, or ``None`` if the token is not a valid
        ``"refresh"`` token signed with ``settings.JWT_REFRESH_SECRET``.
    """
    return _verify_token(token, settings.JWT_REFRESH_SECRET, "refresh")


# ============================================================
# Token Pair Creation (Login Response)
# ============================================================

def create_token_pair(
    user_id: str,
    email: Optional[str],
    phone: Optional[str],
    role: str,
) -> Dict[str, Any]:
    """Build the login response payload.

    Only an access token is issued; the response intentionally contains no
    ``refresh_token`` key.

    Args:
        user_id: Identifier of the authenticated user.
        email: User email (may be ``None``).
        phone: User phone (may be ``None``).
        role: User role.

    Returns:
        A dict with ``access_token`` (str), ``token_type`` (``"bearer"``)
        and ``expires_in`` (int, seconds).
    """
    access_token = create_access_token(user_id, email, phone, role)
    # NOTE: issuing a refresh token alongside the access token is
    # deprecated, so create_refresh_token() is deliberately not called here.

    # NOTE(review): this message and earlier comments describe the token as
    # long-lived (60 days), while other docs in this module say 15 minutes.
    # The real lifetime is whatever JWT_ACCESS_EXPIRY_MINUTES is set to -
    # confirm the deployed value.
    logger.info("Long-lived token created for user: %s", user_id)

    return {
        "access_token": access_token,
        "token_type": "bearer",
        # Seconds, derived from the same setting used to sign the token.
        "expires_in": settings.JWT_ACCESS_EXPIRY_MINUTES * 60,
    }


# ============================================================
# Password Reset Token (lifetime: settings.JWT_RESET_EXPIRY_MINUTES)
# ============================================================

def create_reset_token(identifier: str) -> str:
    """Create a password reset token.

    The token is signed with ``settings.JWT_SECRET``, has type ``"reset"``,
    carries ``purpose == "password_reset"`` and expires after
    ``settings.JWT_RESET_EXPIRY_MINUTES`` minutes.

    Args:
        identifier: Stored in the ``identifier`` claim.

    Returns:
        The encoded JWT.
    """
    return _create_token(
        data={
            "identifier": identifier,
            "purpose": "password_reset",
        },
        secret=settings.JWT_SECRET,
        expires_delta=timedelta(minutes=settings.JWT_RESET_EXPIRY_MINUTES),
        token_type="reset",
    )


def verify_reset_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify and decode a password reset token.

    Returns:
        The decoded payload, or ``None`` if the token is not a valid
        ``"reset"`` token or its ``purpose`` claim is not
        ``"password_reset"``.
    """
    payload = _verify_token(token, settings.JWT_SECRET, "reset")

    if payload and payload.get("purpose") != "password_reset":
        logger.warning("Reset token has wrong purpose")
        return None

    return payload


# ============================================================
# Legacy Aliases (for backward compatibility)
# ============================================================

def create_login_token(
    user_id: str,
    email: Optional[str],
    phone: Optional[str],
    role: str,
) -> str:
    """DEPRECATED: Use create_access_token() instead.

    Logs a deprecation warning and delegates to ``create_access_token``.
    """
    logger.warning("create_login_token is deprecated, use create_access_token")
    return create_access_token(user_id, email, phone, role)


def decode_reset_token(token: str) -> Optional[Dict[str, Any]]:
    """DEPRECATED: Use verify_reset_token() instead."""
    return verify_reset_token(token)


def decode_access_token(token: str) -> Optional[Dict[str, Any]]:
    """DEPRECATED: Use verify_access_token() instead."""
    return verify_access_token(token)


# Alias for backward compatibility
verify_token = verify_access_token