"""Cryptographic utilities for secret encryption/decryption.

Provides Fernet symmetric authenticated encryption over wallet seeds.
Secrets are stored with a sentinel prefix: ``ENC::<token>``.
If a value does not start with ``ENC::`` it is treated as legacy
(plain/base64) and should be migrated by the caller.
"""
from __future__ import annotations

import base64
import os
from typing import Optional

from cryptography.fernet import Fernet, InvalidToken

# Most recently built cipher, plus the normalised key it was built from so
# that a call with a different key rebuilds instead of reusing the old one.
_FERNET: Optional[Fernet] = None
_FERNET_KEY: Optional[bytes] = None
_PREFIX = "ENC::"
_KEY_LEN = 32  # Fernet keys are exactly 32 raw bytes


def _derive_or_load_key(raw: Optional[str]) -> Optional[bytes]:
    """Return a canonical urlsafe-base64 Fernet key for *raw*, or None.

    Accepts either:
      * A 43/44 char urlsafe-base64 Fernet key (padding optional)
      * 32 raw bytes in standard base64 (re-encoded to urlsafe)
      * A hex string of length 64 (interpreted as 32 bytes)

    Returns:
        The 44-byte padded urlsafe-base64 key, or None when *raw* is empty.

    Raises:
        ValueError: if *raw* is non-empty but matches none of the formats.
    """
    if not raw:
        return None
    raw = raw.strip()

    # Direct Fernet key. Re-encode the decoded bytes so the result is always
    # padded: Fernet() rejects an unpadded 43-char key, and a 44-char string
    # without "=" decodes to 33 bytes, which is not a valid key either.
    if len(raw) in (43, 44) and all(
        c.isalnum() or c in ("-", "_") for c in raw.rstrip("=")
    ):
        try:
            decoded = base64.urlsafe_b64decode(raw + ("=" * (-len(raw) % 4)))
        except ValueError:  # binascii.Error, or non-ASCII "alnum" characters
            decoded = b""
        if len(decoded) == _KEY_LEN:
            return base64.urlsafe_b64encode(decoded)

    # Hex 64 -> bytes. bytes.fromhex tolerates embedded whitespace, so the
    # length is checked rather than assumed.
    if len(raw) == 64:
        try:
            decoded = bytes.fromhex(raw)
        except ValueError:
            decoded = b""
        if len(decoded) == _KEY_LEN:
            return base64.urlsafe_b64encode(decoded)

    # Fallback: standard base64 that decodes to exactly 32 bytes.
    try:
        decoded = base64.b64decode(raw + ("=" * (-len(raw) % 4)))
    except ValueError:
        decoded = b""
    if len(decoded) == _KEY_LEN:
        return base64.urlsafe_b64encode(decoded)

    raise ValueError("ENCRYPTION_KEY provided is not a valid 32-byte or Fernet key material")


def get_fernet(encryption_key: Optional[str]) -> Optional[Fernet]:
    """Return a Fernet cipher for *encryption_key*, caching the last one built.

    The cached cipher is reused only when the supplied key normalises to the
    same key material; a different key builds (and caches) a new cipher.
    When no key is supplied, the previously cached cipher is returned if one
    exists, otherwise None.

    Raises:
        ValueError: if *encryption_key* is non-empty but not valid key material.
    """
    global _FERNET, _FERNET_KEY

    key_bytes = _derive_or_load_key(encryption_key)
    if not key_bytes:
        # No key given: fall back to whatever cipher is already cached.
        return _FERNET

    if _FERNET is None or key_bytes != _FERNET_KEY:
        _FERNET = Fernet(key_bytes)
        _FERNET_KEY = key_bytes
    return _FERNET


def encrypt_secret(plaintext: str, encryption_key: Optional[str]) -> str:
    """Encrypt a secret; returns ``ENC::<token>``.

    If *plaintext* is empty, or no cipher is available (no key supplied and
    none cached), the original value is returned unchanged (no prefix).
    """
    if not plaintext:
        return plaintext
    f = get_fernet(encryption_key)
    if not f:
        return plaintext  # no encryption key configured
    token = f.encrypt(plaintext.encode()).decode()
    return f"{_PREFIX}{token}"


def is_encrypted(value: str) -> bool:
    """Return True if *value* is a string carrying the ``ENC::`` prefix."""
    return isinstance(value, str) and value.startswith(_PREFIX)


def decrypt_secret(value: Optional[str], encryption_key: Optional[str]) -> Optional[str]:
    """Decrypt a stored secret.

    Returns:
        * *value* unchanged when it is empty/None.
        * For values without the ``ENC::`` prefix: the base64-decoded text,
          or *value* as-is when it does not decode (treated as plain text).
        * For prefixed values: the decrypted text; *value* unchanged when no
          cipher is available; None when the token fails authentication.
    """
    if not value:
        return value

    if not is_encrypted(value):
        # Legacy base64 encoded value - decode it
        try:
            return base64.b64decode(value).decode()
        except (ValueError, TypeError):
            # Not base64 / not UTF-8 / not decodable: might be plain text
            return value

    token = value[len(_PREFIX):]
    f = get_fernet(encryption_key)
    if not f:
        return value  # cannot decrypt without key
    try:
        return f.decrypt(token.encode()).decode()
    except InvalidToken:
        # Signal the decryption failure with None instead of raising
        return None