# Source: ALM-2 / backend/core/security.py
# Uploaded by ACA050 ("Upload 520 files"), commit 2ed8996 (verified)
"""
Security utilities for AegisLM SaaS Backend.
Production-ready JWT handling, password hashing,
API key generation, and security utilities.
"""
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, Union, Any

from jose import JWTError, jwt
from passlib.context import CryptContext
from passlib.hash import bcrypt

from .config import settings
# Module-level password hashing context (bcrypt only).
# NOTE(review): get_password_hash / verify_password in this module build
# their own Argon2-based CryptContext objects and do not use this one. It is
# presumably kept because other modules import it; confirm before changing
# its schemes, since that would change the hashes those callers produce.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(
    subject: Union[str, Any],
    expires_delta: Optional[timedelta] = None
) -> str:
    """
    Create a signed JWT access token.

    Args:
        subject: Token subject (usually the user ID). Stored as
            ``str(subject)`` in the ``sub`` claim.
        expires_delta: Lifetime of the token. When None, defaults to
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. An explicit
            ``timedelta(0)`` is honoured rather than replaced by the default.

    Returns:
        str: Encoded JWT signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``. Claims: ``exp``, ``sub`` and ``type="access"``.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {
        "exp": expire,
        "sub": str(subject),
        # Explicit token-type claim so verifiers can tell access tokens
        # apart from refresh tokens signed with the same key.
        "type": "access",
    }
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def verify_token(token: str) -> Optional[str]:
    """
    Verify and decode a JWT access token.

    Args:
        token: JWT token to verify

    Returns:
        Optional[str]: The ``sub`` claim (user ID) if the token is valid.
        None if any of the following holds:
          - the token is empty or None;
          - ``jwt.decode`` rejects it (bad signature, expired, malformed);
          - it carries ``type="refresh"``;
          - it has no ``sub`` claim.
    """
    # jose raises AttributeError (not JWTError) for a None token.
    if not token:
        return None
    try:
        payload = jwt.decode(
            token,
            settings.SECRET_KEY,
            algorithms=[settings.ALGORITHM]
        )
    except JWTError:
        return None
    # A refresh token must not be usable as an access token. Tokens without
    # a "type" claim are still accepted for backward compatibility.
    if payload.get("type") == "refresh":
        return None
    user_id: Optional[str] = payload.get("sub")
    if user_id is None:
        return None
    return user_id
def get_password_hash(password: str) -> str:
    """
    Hash a password, preferring Argon2 and falling back to bcrypt.

    Argon2 is used when passlib has an Argon2 backend available. Otherwise a
    warning is logged and the password is hashed with bcrypt (12 rounds).

    Args:
        password: Plain text password (8-72 characters)

    Returns:
        str: Hashed password string produced by passlib

    Raises:
        ValueError: If the password is shorter than 8 or longer than 72
            characters.
        Exception: If hashing fails. The underlying error is chained as
            ``__cause__``.
    """
    # Validate password length for security. The 72 limit mirrors bcrypt's
    # input limit (note: bcrypt counts bytes, this check counts characters).
    if len(password) < 8:
        raise ValueError("Password must be at least 8 characters long for security")
    if len(password) > 72:
        raise ValueError("Password must be less than 72 characters long")

    # passlib reports a missing Argon2 library with MissingBackendError (a
    # RuntimeError), not ImportError. Catching it is what makes the bcrypt
    # fallback below reachable.
    from passlib.exc import MissingBackendError

    try:
        argon2_context = CryptContext(
            schemes=["argon2"],
            argon2__time_cost=3,  # Number of iterations
            argon2__memory_cost=65536,  # Memory usage in KiB
            argon2__parallelism=4,  # Degree of parallelism
            argon2__hash_len=32,  # Hash length in bytes
            deprecated="auto"
        )
        return argon2_context.hash(password)
    except (ImportError, MissingBackendError):
        logging.getLogger(__name__).warning(
            "Argon2 backend unavailable; falling back to bcrypt for password hashing"
        )
    except Exception as e:
        raise Exception(f"Password hashing failed: {str(e)}") from e

    try:
        bcrypt_context = CryptContext(
            schemes=["bcrypt"],
            bcrypt__rounds=12,
            deprecated="auto"
        )
        return bcrypt_context.hash(password)
    except Exception as e:
        raise Exception(f"Password hashing failed: {str(e)}") from e
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Verify a plain text password against an Argon2 or bcrypt hash.

    Args:
        plain_password: Plain text password
        hashed_password: Stored password hash

    Returns:
        bool: True if the password matches. False if it does not, if the hash
        is malformed or unrecognised, or if verification fails for any other
        reason (unexpected failures are logged rather than silently hidden).
    """
    try:
        context = CryptContext(
            schemes=["argon2", "bcrypt"],  # Accept hashes from either scheme
            argon2__time_cost=3,
            argon2__memory_cost=65536,
            argon2__parallelism=4,
            argon2__hash_len=32,
            deprecated="auto"
        )
        return context.verify(plain_password, hashed_password)
    except (ValueError, TypeError):
        # Unrecognised/malformed hash or wrong argument type: treat as a
        # non-matching password.
        return False
    except Exception:
        # Anything else (e.g. a missing hashing backend) is a server-side
        # problem. Keep failing closed, but make it visible instead of
        # silently rejecting every login.
        logging.getLogger(__name__).exception(
            "Password verification failed unexpectedly"
        )
        return False
def generate_api_key() -> str:
    """
    Return a fresh random, URL-safe API key.

    The key is drawn from ``settings.API_KEY_LENGTH`` random bytes and
    Base64url-encoded by ``secrets.token_urlsafe``. The returned string is
    therefore roughly 1.3 characters per byte, not exactly that length.

    Returns:
        str: API key
    """
    entropy_bytes = settings.API_KEY_LENGTH
    api_key = secrets.token_urlsafe(entropy_bytes)
    return api_key
def generate_secure_token(length: int = 32) -> str:
    """
    Return a cryptographically secure, URL-safe random token.

    Args:
        length: Number of random bytes of entropy (not the output length).
            The token is Base64url-encoded without padding, so it is longer
            than ``length`` (e.g. 32 bytes -> 43 characters).

    Returns:
        str: Secure token
    """
    entropy_bytes = length
    token = secrets.token_urlsafe(entropy_bytes)
    return token
def mask_sensitive_data(data: str, mask_char: str = "*", visible_chars: int = 4) -> str:
    """
    Mask sensitive data for logging, keeping only a short visible prefix.

    Args:
        data: Sensitive data to mask
        mask_char: String emitted once per hidden character
        visible_chars: Number of leading characters to keep visible. Negative
            values are treated as 0, so they never reveal more of ``data``.

    Returns:
        str: ``data`` with every character after the first ``visible_chars``
        replaced by ``mask_char``. Values no longer than ``visible_chars`` are
        masked completely.
    """
    # Without clamping, a negative count slices from the END (data[:-1]) and
    # would expose almost the whole secret.
    visible = max(visible_chars, 0)
    if len(data) <= visible:
        return mask_char * len(data)
    return data[:visible] + mask_char * (len(data) - visible)
def validate_password_strength(password: str) -> tuple[bool, list[str]]:
    """
    Check a password against the complexity policy.

    Each failing rule contributes one message, in this order:
      1. at least 8 characters
      2. at most 72 characters (bcrypt's input limit)
      3. at least one uppercase letter
      4. at least one lowercase letter
      5. at least one digit
      6. at least one special character from ``!@#$%^&*()_+-=[]{}|;:,.<>?``

    Args:
        password: Password to validate

    Returns:
        tuple[bool, list[str]]: ``(is_valid, error_messages)``. ``is_valid`` is
        True exactly when ``error_messages`` is empty.
    """
    specials = set("!@#$%^&*()_+-=[]{}|;:,.<>?")
    size = len(password)
    checks = (
        (size < 8, "Password must be at least 8 characters long"),
        (size > 72, "Password must be less than 72 characters long (bcrypt limit)"),
        (not any(ch.isupper() for ch in password),
         "Password must contain at least one uppercase letter"),
        (not any(ch.islower() for ch in password),
         "Password must contain at least one lowercase letter"),
        (not any(ch.isdigit() for ch in password),
         "Password must contain at least one digit"),
        (specials.isdisjoint(password),
         "Password must contain at least one special character"),
    )
    problems = [message for failed, message in checks if failed]
    return not problems, problems
def sanitize_input(input_string: str) -> str:
    """
    Delete a fixed set of markup and quoting characters from user input.

    Every angle bracket, ampersand, double quote, single quote, forward
    slash and backslash is removed, then leading/trailing whitespace is
    trimmed.

    NOTE(review): this is a character blacklist. It is not a substitute for
    parameterized queries or context-aware output escaping.

    Args:
        input_string: Input string to sanitize

    Returns:
        str: Sanitized string
    """
    blocked = ["<", ">", "&", "\"", "'", "/", "\\"]
    deletion_table = {ord(ch): None for ch in blocked}
    return input_string.translate(deletion_table).strip()
def rate_limit_key(identifier: str, window: int = 60) -> str:
    """
    Generate a fixed-window rate limit key for Redis.

    All calls made within the same ``window``-second slice of Unix epoch
    time (UTC) return the same key, regardless of the server's time zone.

    Args:
        identifier: Unique identifier (IP, user ID, etc.)
        window: Window length in seconds; must be positive.

    Returns:
        str: Key of the form ``rate_limit:<identifier>:<window index>``

    Raises:
        ValueError: If ``window`` is not positive.
    """
    if window <= 0:
        raise ValueError("window must be a positive number of seconds")
    # Aware UTC "now". A naive datetime.utcnow() is interpreted as *local*
    # time by .timestamp(), which shifts bucket indices by the server's UTC
    # offset and across DST changes.
    bucket = int(datetime.now(timezone.utc).timestamp() // window)
    return f"rate_limit:{identifier}:{bucket}"
def verify_websocket_token(token: str) -> Optional[int]:
    """
    Authenticate a WebSocket connection from its JWT.

    Validation is delegated to ``verify_token``. The returned subject is then
    converted to an integer user ID.

    Args:
        token: JWT supplied by the WebSocket client

    Returns:
        Optional[int]: The user ID, or None when the token is rejected, has
        an empty subject, or the subject cannot be parsed by ``int()``.
    """
    try:
        subject = verify_token(token)
        return int(subject) if subject else None
    except (JWTError, ValueError, TypeError):
        return None
def create_refresh_token(
    data: dict,
    expires_delta: Optional[timedelta] = None
) -> str:
    """
    Create a signed JWT refresh token.

    Args:
        data: Claims to embed (e.g. the user's ``sub``). The dict is not
            modified. Keys in ``data`` take precedence over the generated
            ``exp`` and the default ``type="refresh"`` claim.
        expires_delta: Lifetime of the token (default: 7 days when None).
            An explicit ``timedelta(0)`` is honoured.

    Returns:
        str: Encoded JWT signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``.
    """
    if expires_delta is None:
        expires_delta = timedelta(days=7)
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    # Explicit token-type claim so verifiers can tell refresh tokens apart
    # from access tokens signed with the same key.
    to_encode = {"exp": expire, "type": "refresh", **data}
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def verify_refresh_token(token: str) -> Optional[dict]:
    """
    Verify and decode a JWT refresh token.

    Args:
        token: JWT refresh token

    Returns:
        Optional[dict]: The decoded payload if the token is valid. None if
        any of the following holds:
          - the token is empty or None;
          - ``jwt.decode`` rejects it (bad signature, expired, malformed);
          - it is explicitly typed as an access token (``type="access"``).
    """
    # jose raises AttributeError (not JWTError) for a None token.
    if not token:
        return None
    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None
    # Refuse tokens marked as access tokens so a short-lived access token
    # cannot be exchanged as a refresh token. Tokens without a "type" claim
    # are still accepted for backward compatibility.
    if payload.get("type") == "access":
        return None
    return payload