""" Encryption Utility for AegisLM Provides AES-256 encryption for sensitive fields at rest. """ import base64 import hashlib import os from typing import Optional, Tuple from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import padding from security.secret_manager import get_encryption_key class EncryptionService: """ AES-256 encryption service for sensitive data at rest. Uses AES-256-CBC with PKCS7 padding. """ _instance: Optional["EncryptionService"] = None def __new__(cls) -> "EncryptionService": """Singleton pattern for encryption service.""" if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance def __init__(self): """Initialize the encryption service.""" if not hasattr(self, "_initialized"): self._key = self._derive_key(get_encryption_key()) self._initialized = True def _derive_key(self, key_input: str) -> bytes: """ Derive a 256-bit key from the input key. Args: key_input: Input key string Returns: 32-byte derived key """ # Use SHA-256 to derive a 32-byte key from the input return hashlib.sha256(key_input.encode()).digest() def _pad(self, data: bytes) -> bytes: """ Pad data using PKCS7 padding. Args: data: Data to pad Returns: Padded data """ padder = padding.PKCS7(algorithms.AES.block_size).padder() return padder.update(data) + padder.finalize() def _unpad(self, data: bytes) -> bytes: """ Remove PKCS7 padding. Args: data: Padded data Returns: Unpadded data """ unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() return unpadder.update(data) + unpadder.finalize() def encrypt(self, plaintext: str) -> str: """ Encrypt plaintext using AES-256-CBC. Args: plaintext: Plain text to encrypt Returns: Base64-encoded ciphertext (IV + ciphertext) """ # Generate random IV iv = os.urandom(16) # Create cipher cipher = Cipher( algorithms.AES(self._key), modes.CBC(iv), backend=default_backend() ) encryptor = cipher.encryptor() # Pad and encrypt padded_data = self._pad(plaintext.encode("utf-8")) ciphertext = encryptor.update(padded_data) + encryptor.finalize() # Combine IV and ciphertext combined = iv + ciphertext # Return base64-encoded result return base64.b64encode(combined).decode("utf-8") def decrypt(self, ciphertext: str) -> str: """ Decrypt ciphertext using AES-256-CBC. Args: ciphertext: Base64-encoded ciphertext (IV + ciphertext) Returns: Decrypted plaintext """ # Decode from base64 combined = base64.b64decode(ciphertext) # Extract IV and ciphertext iv = combined[:16] ciphertext = combined[16:] # Create cipher cipher = Cipher( algorithms.AES(self._key), modes.CBC(iv), backend=default_backend() ) decryptor = cipher.decryptor() # Decrypt and unpad padded_data = decryptor.update(ciphertext) + decryptor.finalize() plaintext = self._unpad(padded_data) return plaintext.decode("utf-8") def encrypt_bytes(self, data: bytes) -> str: """ Encrypt bytes data. Args: data: Bytes to encrypt Returns: Base64-encoded ciphertext """ return self.encrypt(base64.b64encode(data).decode("utf-8")) def decrypt_bytes(self, ciphertext: str) -> bytes: """ Decrypt to bytes. Args: ciphertext: Base64-encoded ciphertext Returns: Decrypted bytes """ plaintext = self.decrypt(ciphertext) return base64.b64decode(plaintext) def get_encryption_service() -> EncryptionService: """ Get the singleton encryption service instance. Returns: EncryptionService instance """ return EncryptionService() # Convenience functions def encrypt(plaintext: str) -> str: """Encrypt plaintext.""" return get_encryption_service().encrypt(plaintext) def decrypt(ciphertext: str) -> str: """Decrypt ciphertext.""" return get_encryption_service().decrypt(ciphertext) # ============================================================================= # Password hashing (using bcrypt) # ============================================================================= def hash_password(password: str) -> str: """ Hash a password using bcrypt. Args: password: Plain text password Returns: bcrypt hash """ import bcrypt return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") def verify_password(password: str, hashed: str) -> bool: """ Verify a password against a bcrypt hash. Args: password: Plain text password hashed: bcrypt hash Returns: True if password matches """ import bcrypt try: return bcrypt.checkpw(password.encode("utf-8"), hashed.encode("utf-8")) except Exception: return False