| """ |
| Security utilities for AegisLM SaaS Backend. |
| |
| Production-ready JWT handling, password hashing, |
| API key generation, and security utilities. |
| """ |
|
|
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, Union, Any

from jose import JWTError, jwt
from passlib.context import CryptContext
from passlib.hash import bcrypt

from .config import settings
|
|
|
|
| |
# Module-level passlib context restricted to the bcrypt scheme.
# NOTE(review): get_password_hash() and verify_password() in this file build
# their own CryptContext instances and do not use this object; it is kept for
# any external importers.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
|
|
|
|
def create_access_token(
    subject: Union[str, Any],
    expires_delta: Optional[timedelta] = None
) -> str:
    """
    Create a signed JWT access token.

    Args:
        subject: Token subject (usually the user ID); stored as
            ``str(subject)`` in the ``sub`` claim.
        expires_delta: Lifetime of the token. When ``None`` (the default)
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` is used. An explicit
            zero or negative delta is honoured as given rather than being
            silently replaced by the default.

    Returns:
        str: JWT signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``.
    """
    # `is None` rather than truthiness: timedelta(0) is falsy and used to be
    # swapped for the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

    # Timezone-aware UTC; datetime.utcnow() is naive and deprecated since 3.12.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = {"exp": expire, "sub": str(subject)}
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
def verify_token(token: str) -> Optional[str]:
    """
    Decode a JWT and return its subject claim.

    The token is checked by ``jwt.decode`` against ``settings.SECRET_KEY``,
    accepting only ``settings.ALGORITHM``.

    Args:
        token: Encoded JWT to verify

    Returns:
        Optional[str]: The ``sub`` claim when decoding succeeds and the claim
        is present; ``None`` when decoding raises ``JWTError`` or ``sub`` is
        missing.
    """
    # NOTE(review): create_refresh_token() signs with the same key and
    # algorithm, so a refresh token whose payload carries "sub" is accepted
    # here as well. Consider a token-type claim to tell them apart.
    try:
        claims = jwt.decode(
            token,
            key=settings.SECRET_KEY,
            algorithms=[settings.ALGORITHM],
        )
    except JWTError:
        return None
    return claims.get("sub")
|
|
|
|
def get_password_hash(password: str) -> str:
    """
    Hash a password with Argon2, falling back to bcrypt.

    Argon2 (time_cost=3, memory_cost=65536, parallelism=4, hash_len=32) is
    tried first. When passlib has no Argon2 backend available it raises
    ``passlib.exc.MissingBackendError`` (not ``ImportError``). That case is
    caught explicitly and bcrypt with 12 rounds is used instead.

    Args:
        password: Plain text password (8-72 characters)

    Returns:
        str: Hashed password string produced by passlib

    Raises:
        ValueError: If the password is shorter than 8 or longer than 72
            characters. On the bcrypt fallback only, also raised if the
            password exceeds 72 bytes once UTF-8 encoded.
        Exception: If passlib is unavailable or hashing otherwise fails.
            The underlying error is chained as ``__cause__``.
    """
    if len(password) < 8:
        raise ValueError("Password must be at least 8 characters long for security")
    if len(password) > 72:
        raise ValueError("Password must be less than 72 characters long")

    try:
        from passlib.context import CryptContext
        from passlib.exc import MissingBackendError
    except ImportError as exc:
        raise Exception(f"Password hashing failed: {exc}") from exc

    # Preferred path: Argon2. Construction is inside the try as well, in case
    # passlib checks backend availability while building the context.
    try:
        argon2_context = CryptContext(
            schemes=["argon2"],
            argon2__time_cost=3,
            argon2__memory_cost=65536,
            argon2__parallelism=4,
            argon2__hash_len=32,
            deprecated="auto",
        )
        return argon2_context.hash(password)
    except (ImportError, MissingBackendError):
        pass  # no Argon2 backend installed -> fall through to bcrypt
    except Exception as exc:
        raise Exception(f"Password hashing failed: {exc}") from exc

    # bcrypt only consumes the first 72 *bytes* of the secret. Multi-byte
    # characters can push a <=72-character password past that limit, so
    # refuse instead of hashing a silently truncated secret.
    if len(password.encode("utf-8")) > 72:
        raise ValueError("Password must be at most 72 bytes when UTF-8 encoded")

    try:
        bcrypt_context = CryptContext(
            schemes=["bcrypt"],
            bcrypt__rounds=12,
            deprecated="auto",
        )
        return bcrypt_context.hash(password)
    except Exception as exc:
        raise Exception(f"Password hashing failed: {exc}") from exc
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plain-text password against a stored Argon2 or bcrypt hash.

    The function fails closed: it never raises and returns ``False`` on any
    error. Malformed or unrecognised hashes are treated as a plain mismatch.
    Any other failure, such as passlib or an Argon2 backend being unavailable,
    is logged so that a misconfiguration does not masquerade as
    "wrong password".

    Args:
        plain_password: Plain text password
        hashed_password: Stored hash to verify against

    Returns:
        bool: True only if the password matches the hash
    """
    try:
        from passlib.context import CryptContext

        context = CryptContext(
            schemes=["argon2", "bcrypt"],
            argon2__time_cost=3,
            argon2__memory_cost=65536,
            argon2__parallelism=4,
            argon2__hash_len=32,
            deprecated="auto",
        )
        return context.verify(plain_password, hashed_password)
    except (ValueError, TypeError):
        # Unidentifiable/malformed hash or non-string input: a mismatch.
        return False
    except Exception:
        # Unexpected failure (e.g. missing hashing backend). Keep the
        # fail-closed contract but make the cause visible in the logs.
        import logging

        logging.getLogger(__name__).exception(
            "Password verification failed unexpectedly"
        )
        return False
|
|
|
|
def generate_api_key() -> str:
    """
    Generate a random, URL-safe API key.

    ``settings.API_KEY_LENGTH`` is passed to ``secrets.token_urlsafe`` as the
    number of random *bytes*. The Base64url-encoded result is therefore
    roughly 4/3 that many characters long, not exactly ``API_KEY_LENGTH``.

    Returns:
        str: API key
    """
    entropy_bytes = settings.API_KEY_LENGTH
    api_key = secrets.token_urlsafe(entropy_bytes)
    return api_key
|
|
|
|
def generate_secure_token(length: int = 32) -> str:
    """
    Return a cryptographically secure, URL-safe random token.

    Args:
        length: Number of random bytes to draw. This is not the length of the
            returned string: the Base64url result is about 4/3 as long
            (e.g. 32 bytes -> 43 characters).

    Returns:
        str: Secure token
    """
    token = secrets.token_urlsafe(nbytes=length)
    return token
|
|
|
|
def mask_sensitive_data(data: str, mask_char: str = "*", visible_chars: int = 4) -> str:
    """
    Obscure a sensitive string for logging, keeping a short visible prefix.

    The first ``visible_chars`` characters are kept and every remaining
    character is replaced with ``mask_char``. If the string is no longer than
    ``visible_chars``, the whole string is masked so short secrets are never
    shown in full.

    Args:
        data: Sensitive data to mask
        mask_char: Character used for each masked position
        visible_chars: Number of leading characters left readable

    Returns:
        str: Masked data
    """
    hidden_count = len(data) - visible_chars
    if hidden_count <= 0:
        # Nothing would remain hidden; mask everything instead.
        return mask_char * len(data)
    return f"{data[:visible_chars]}{mask_char * hidden_count}"
|
|
|
|
def validate_password_strength(password: str) -> tuple[bool, list[str]]:
    """
    Validate password strength following professional security standards.

    Requirements: 8-72 characters, at least one uppercase letter, one
    lowercase letter, one digit and one ASCII punctuation character
    (anything in ``string.punctuation``).

    Args:
        password: Password to validate

    Returns:
        tuple[bool, list[str]]: ``(is_valid, error_messages)``. The list holds
        one message per failed rule and is empty when the password is valid.
    """
    import string  # local: the module header does not import `string`

    errors: list[str] = []

    if len(password) < 8:
        errors.append("Password must be at least 8 characters long")

    # NOTE(review): this counts characters; bcrypt's limit is 72 *bytes*, so
    # multi-byte passwords can pass here yet exceed it.
    if len(password) > 72:
        errors.append("Password must be less than 72 characters long (bcrypt limit)")

    if not any(c.isupper() for c in password):
        errors.append("Password must contain at least one uppercase letter")

    if not any(c.islower() for c in password):
        errors.append("Password must contain at least one lowercase letter")

    if not any(c.isdigit() for c in password):
        errors.append("Password must contain at least one digit")

    # string.punctuation is a superset of the previous hand-written list, which
    # omitted ~ ` ' " / and backslash and so rejected otherwise strong passwords.
    if not any(c in string.punctuation for c in password):
        errors.append("Password must contain at least one special character")

    return not errors, errors
|
|
|
|
def sanitize_input(input_string: str) -> str:
    """
    Strip a fixed set of markup/path characters and trim surrounding whitespace.

    Every occurrence of the characters <, >, &, both quote characters,
    forward slash and backslash is deleted. Leading and trailing whitespace
    is then removed.

    NOTE(review): deleting characters is not a substitute for context-aware
    escaping (HTML output encoding, SQL parameters, shell quoting). Treat this
    as a best-effort filter rather than the primary injection defence.

    Args:
        input_string: Input string to sanitize

    Returns:
        str: Sanitized string
    """
    removal_table = str.maketrans("", "", "<>&\"'/\\")
    return input_string.translate(removal_table).strip()
|
|
|
|
def rate_limit_key(
    identifier: str,
    window: int = 60,
    now: Optional[float] = None,
) -> str:
    """
    Build a fixed-window rate-limit key for Redis.

    All calls that fall into the same ``window``-second bucket, counted from
    the Unix epoch, produce the same key.

    Args:
        identifier: Unique identifier (IP, user ID, etc.)
        window: Window length in seconds; must be positive.
        now: Optional Unix timestamp to bucket instead of the current time.
            Useful for tests, or for callers that already read the clock.

    Returns:
        str: Key of the form ``rate_limit:<identifier>:<bucket>``

    Raises:
        ValueError: If ``window`` is not positive (it previously surfaced as
            a ZeroDivisionError for 0).
    """
    if window <= 0:
        raise ValueError("window must be a positive number of seconds")
    if now is None:
        # Must be timezone-aware: a naive datetime.utcnow().timestamp() is
        # interpreted as *local* time and is skewed by the host's UTC offset.
        now = datetime.now(timezone.utc).timestamp()
    bucket = int(now // window)
    return f"rate_limit:{identifier}:{bucket}"
|
|
|
|
def verify_websocket_token(token: str) -> Optional[int]:
    """
    Authenticate a WebSocket connection token and return its numeric user ID.

    Delegates decoding to ``verify_token`` and converts the returned subject
    with ``int()``.

    Args:
        token: JWT supplied by the WebSocket client

    Returns:
        Optional[int]: The user ID. ``None`` when the token is rejected, the
        subject is missing or empty, or the subject cannot be parsed as an
        integer.
    """
    try:
        subject = verify_token(token)
        return int(subject) if subject else None
    except (JWTError, ValueError, TypeError):
        return None
|
|
|
|
def create_refresh_token(
    data: dict,
    expires_delta: Optional[timedelta] = None
) -> str:
    """
    Create a signed JWT refresh token.

    Args:
        data: Claims to embed in the token. The dict itself is not modified.
            Any ``exp`` key it contains is ignored in favour of the computed
            expiry.
        expires_delta: Lifetime of the token. ``None`` (the default) means
            7 days. An explicit zero or negative delta is honoured as given.

    Returns:
        str: JWT refresh token signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``
    """
    # `is None` rather than truthiness: timedelta(0) is falsy and used to be
    # swapped for the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(days=7)

    # Timezone-aware UTC; datetime.utcnow() is naive and deprecated since 3.12.
    expire = datetime.now(timezone.utc) + expires_delta

    # Computed expiry goes last so a stale "exp" in `data` (e.g. claims copied
    # from a previously decoded token) cannot override it.
    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
def verify_refresh_token(token: str) -> Optional[dict]:
    """
    Decode a refresh token and return its full claim set.

    Args:
        token: Encoded JWT refresh token

    Returns:
        Optional[dict]: Decoded payload, or ``None`` if ``jwt.decode`` raises
        ``JWTError``.
    """
    # NOTE(review): tokens from create_access_token() use the same key and
    # algorithm, so an access token also decodes successfully here.
    try:
        claims = jwt.decode(
            token,
            key=settings.SECRET_KEY,
            algorithms=[settings.ALGORITHM],
        )
    except JWTError:
        claims = None
    return claims
|
|