"""
Security utilities for AegisLM SaaS Backend.

Production-ready JWT handling, password hashing, API key generation, and
security utilities.
"""

import functools
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, Union, Any

from jose import JWTError, jwt
from passlib.context import CryptContext
from passlib.exc import MissingBackendError
from passlib.hash import bcrypt

from .config import settings

# Password hashing context (bcrypt only). Not used by the hashing helpers in
# this module (they prefer Argon2); retained in case other modules import it.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Argon2 cost parameters shared by hashing and verification contexts.
_ARGON2_SETTINGS = {
    "argon2__time_cost": 3,  # number of iterations
    "argon2__memory_cost": 65536,  # memory usage in KiB (64 MiB)
    "argon2__parallelism": 4,  # number of parallel lanes
    "argon2__hash_len": 32,  # digest length in bytes
}

# "type" claim values that keep access and refresh tokens from being used
# interchangeably. Tokens issued before this claim existed carry no "type"
# and are still accepted by both verifiers (backward compatibility).
_TOKEN_TYPE_CLAIM = "type"
_ACCESS_TOKEN_TYPE = "access"
_REFRESH_TOKEN_TYPE = "refresh"

# Default lifetime of refresh tokens when no explicit delta is given.
_REFRESH_TOKEN_DEFAULT_DAYS = 7

# Characters accepted by validate_password_strength as "special".
_SPECIAL_CHARACTERS = frozenset("!@#$%^&*()_+-=[]{}|;:,.<>?")

# Translation table deleting the characters stripped by sanitize_input.
_DANGEROUS_CHARS_TABLE = str.maketrans("", "", "<>&\"'/\\")


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


@functools.lru_cache(maxsize=None)
def _argon2_context() -> CryptContext:
    """Return the cached Argon2 hashing context (built on first use)."""
    return CryptContext(schemes=["argon2"], deprecated="auto", **_ARGON2_SETTINGS)


@functools.lru_cache(maxsize=None)
def _bcrypt_context() -> CryptContext:
    """Return the cached bcrypt fallback hashing context (built on first use)."""
    return CryptContext(schemes=["bcrypt"], bcrypt__rounds=12, deprecated="auto")


@functools.lru_cache(maxsize=None)
def _verify_context() -> CryptContext:
    """Return the cached context able to verify both Argon2 and bcrypt hashes."""
    return CryptContext(
        schemes=["argon2", "bcrypt"], deprecated="auto", **_ARGON2_SETTINGS
    )


def create_access_token(
    subject: Union[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
    """
    Create a signed JWT access token.

    Args:
        subject: Token subject (usually user ID); stored as ``str(subject)``
            in the ``sub`` claim.
        expires_delta: Token lifetime. When ``None``, defaults to
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        str: JWT signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``, carrying ``exp``, ``sub`` and ``type="access"``.
    """
    # Compare against None so an explicit zero-length delta is honoured
    # instead of silently falling back to the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = {
        "exp": _utcnow() + expires_delta,
        "sub": str(subject),
        _TOKEN_TYPE_CLAIM: _ACCESS_TOKEN_TYPE,
    }
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def verify_token(token: str) -> Optional[str]:
    """
    Verify and decode a JWT access token.

    Refresh tokens (``type == "refresh"``) are rejected so a long-lived
    refresh token cannot be used as an access token.

    Args:
        token: JWT token to verify.

    Returns:
        Optional[str]: The ``sub`` claim (user ID) if the token is valid,
        otherwise ``None``.
    """
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
    except JWTError:
        return None

    if payload.get(_TOKEN_TYPE_CLAIM) == _REFRESH_TOKEN_TYPE:
        return None

    user_id: Optional[str] = payload.get("sub")
    if user_id is None:
        return None
    return user_id


def get_password_hash(password: str) -> str:
    """
    Hash a password with Argon2, falling back to bcrypt if the Argon2
    backend is unavailable.

    Args:
        password: Plain text password (8-72 characters).

    Returns:
        str: Hashed password (Argon2 or bcrypt encoded string).

    Raises:
        ValueError: If the password is shorter than 8 or longer than 72
            characters.
        Exception: If hashing fails.
    """
    # Validate password length for security. 72 is bcrypt's input limit
    # (note: bcrypt counts bytes, this check counts characters).
    if len(password) < 8:
        raise ValueError("Password must be at least 8 characters long for security")
    if len(password) > 72:
        raise ValueError("Password must be less than 72 characters long")

    try:
        return _argon2_context().hash(password)
    except (ImportError, MissingBackendError):
        # passlib loads the Argon2 backend (argon2-cffi) lazily on first use,
        # so a missing backend surfaces as MissingBackendError here rather
        # than as an ImportError; fall through to bcrypt.
        pass
    except Exception as exc:
        raise Exception(f"Password hashing failed: {str(exc)}") from exc

    try:
        return _bcrypt_context().hash(password)
    except Exception as exc:
        raise Exception(f"Password hashing failed: {str(exc)}") from exc


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Verify a password against a stored hash (supports Argon2 and bcrypt).

    Args:
        plain_password: Plain text password.
        hashed_password: Stored Argon2 or bcrypt hash.

    Returns:
        bool: True if the password matches; False on mismatch or on any
        verification error.
    """
    try:
        return _verify_context().verify(plain_password, hashed_password)
    except Exception:
        # Deliberately fail closed: any error (e.g. an unrecognised hash
        # format or a missing backend) is treated as "not verified".
        return False


def generate_api_key() -> str:
    """
    Generate a secure, URL-safe API key.

    ``settings.API_KEY_LENGTH`` is the number of random bytes; the returned
    base64url string is roughly 1.3x that many characters.

    Returns:
        str: API key
    """
    return secrets.token_urlsafe(settings.API_KEY_LENGTH)


def generate_secure_token(length: int = 32) -> str:
    """
    Generate a secure, URL-safe random token.

    Args:
        length: Number of random bytes (the returned base64url string is
            roughly 1.3x longer).

    Returns:
        str: Secure token
    """
    return secrets.token_urlsafe(length)


def mask_sensitive_data(data: str, mask_char: str = "*", visible_chars: int = 4) -> str:
    """
    Mask sensitive data for logging.

    Args:
        data: Sensitive data to mask.
        mask_char: Character to use for masking.
        visible_chars: Number of leading characters to keep visible;
            negative values are treated as 0.

    Returns:
        str: Masked data. If ``data`` is not longer than ``visible_chars``
        it is masked entirely.
    """
    # Clamp: a negative slice bound would otherwise reveal almost all of data.
    visible = max(visible_chars, 0)
    if len(data) <= visible:
        return mask_char * len(data)
    return data[:visible] + mask_char * (len(data) - visible)


def validate_password_strength(password: str) -> tuple[bool, list[str]]:
    """
    Validate password strength.

    Rules: 8-72 characters, at least one uppercase letter, one lowercase
    letter, one digit and one special character from
    ``!@#$%^&*()_+-=[]{}|;:,.<>?``.

    Args:
        password: Password to validate.

    Returns:
        tuple[bool, list[str]]: (is_valid, error_messages)
    """
    errors = []

    # 8-72 characters (72 is the bcrypt input limit)
    if len(password) < 8:
        errors.append("Password must be at least 8 characters long")
    if len(password) > 72:
        errors.append("Password must be less than 72 characters long (bcrypt limit)")
    if not any(c.isupper() for c in password):
        errors.append("Password must contain at least one uppercase letter")
    if not any(c.islower() for c in password):
        errors.append("Password must contain at least one lowercase letter")
    if not any(c.isdigit() for c in password):
        errors.append("Password must contain at least one digit")
    if not any(c in _SPECIAL_CHARACTERS for c in password):
        errors.append("Password must contain at least one special character")

    return len(errors) == 0, errors


def sanitize_input(input_string: str) -> str:
    """
    Strip potentially dangerous characters from user input.

    Removes ``< > & " ' / \\`` and trims surrounding whitespace. This is a
    blunt filter, not a substitute for context-appropriate output escaping
    or parameterized queries.

    Args:
        input_string: Input string to sanitize.

    Returns:
        str: Sanitized string
    """
    return input_string.translate(_DANGEROUS_CHARS_TABLE).strip()


def rate_limit_key(identifier: str, window: int = 60) -> str:
    """
    Generate a rate limit key for Redis, bucketed by fixed time window.

    Args:
        identifier: Unique identifier (IP, user ID, etc.).
        window: Time window in seconds (must be positive).

    Returns:
        str: Key of the form ``rate_limit:<identifier>:<bucket>``.

    Raises:
        ValueError: If ``window`` is not positive.
    """
    if window <= 0:
        raise ValueError("window must be a positive number of seconds")
    # Use an aware UTC datetime: calling .timestamp() on the naive value from
    # datetime.utcnow() interprets it as *local* time, shifting every bucket
    # by the host's UTC offset and making keys differ between hosts.
    bucket = int(_utcnow().timestamp() // window)
    return f"rate_limit:{identifier}:{bucket}"


def verify_websocket_token(token: str) -> Optional[int]:
    """
    Verify a WebSocket JWT access token and return the user ID.

    Args:
        token: JWT token from the WebSocket connection.

    Returns:
        Optional[int]: User ID if the token is valid and its subject is an
        integer, otherwise ``None``.
    """
    try:
        # Same verification rules as HTTP access tokens.
        user_id = verify_token(token)
        if user_id:
            return int(user_id)
        return None
    except (JWTError, ValueError, TypeError):
        return None


def create_refresh_token(
    data: dict, expires_delta: Optional[timedelta] = None
) -> str:
    """
    Create a signed JWT refresh token.

    Args:
        data: Claims to encode in the token (not mutated). Keys in ``data``
            take precedence over the generated ``exp`` and ``type`` claims.
        expires_delta: Token lifetime (default: 7 days when ``None``).

    Returns:
        str: JWT refresh token carrying ``type="refresh"`` unless ``data``
        overrides it.
    """
    # Compare against None so an explicit zero-length delta is honoured.
    if expires_delta is None:
        expires_delta = timedelta(days=_REFRESH_TOKEN_DEFAULT_DAYS)
    to_encode = {
        "exp": _utcnow() + expires_delta,
        _TOKEN_TYPE_CLAIM: _REFRESH_TOKEN_TYPE,
        **data,
    }
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def verify_refresh_token(token: str) -> Optional[dict]:
    """
    Verify and decode a JWT refresh token.

    Access tokens (``type == "access"``) are rejected so a short-lived
    access token cannot be exchanged as a refresh token.

    Args:
        token: JWT refresh token.

    Returns:
        Optional[dict]: Token payload if valid, None otherwise.
    """
    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None

    if payload.get(_TOKEN_TYPE_CLAIM) == _ACCESS_TOKEN_TYPE:
        return None
    return payload