"""Fernet encryption utilities for user-registered database credentials.

The encryption key is sourced from the ``dataeyond__db__credential__key``
env variable (read here as ``settings.dataeyond_db_credential_key``). It is
intentionally separate from the user-auth bcrypt salt
(``emarcal__bcrypt__salt``).

Usage::

    from src.utils.db_credential_encryption import (
        decrypt_credentials_dict,
        encrypt_credentials_dict,
    )

    # Before INSERT:
    safe_creds = encrypt_credentials_dict(raw_credentials)

    # After SELECT:
    plain_creds = decrypt_credentials_dict(row.credentials)
"""

from collections.abc import Callable
from typing import Any

from cryptography.fernet import Fernet

from src.config.settings import settings

# Sensitive credential field names that must be encrypted at rest.
# Covers all supported DB types:
#   - password             : postgres, mysql, sqlserver, supabase, snowflake
#   - service_account_json : bigquery
SENSITIVE_FIELDS: frozenset[str] = frozenset({"password", "service_account_json"})


def _get_cipher() -> Fernet:
    """Build a Fernet cipher from the configured credential key.

    Raises:
        ValueError: if the key is unset/empty, or is not a valid Fernet key.
    """
    key = settings.dataeyond_db_credential_key
    if not key:
        raise ValueError(
            "dataeyond__db__credential__key is not set. "
            "Generate one with: Fernet.generate_key().decode()"
        )
    key_bytes = key.encode()
    try:
        return Fernet(key_bytes)
    except ValueError as err:
        # Fernet's own message does not say which setting is malformed;
        # re-raise (same exception type) with the env variable name.
        raise ValueError(
            "dataeyond__db__credential__key is not a valid Fernet key "
            "(expected 32 url-safe base64-encoded bytes). "
            "Generate one with: Fernet.generate_key().decode()"
        ) from err


def encrypt_credential(value: str) -> str:
    """Encrypt a single credential string value.

    Returns the Fernet token as a text string.

    Raises:
        ValueError: if the encryption key is missing or invalid.
    """
    return _get_cipher().encrypt(value.encode()).decode()


def decrypt_credential(value: str) -> str:
    """Decrypt a single Fernet-encrypted credential string.

    Raises:
        ValueError: if the encryption key is missing or invalid.
        cryptography.fernet.InvalidToken: if *value* is not a valid token
            for the configured key (e.g. corrupted or encrypted with a
            different key).
    """
    return _get_cipher().decrypt(value.encode()).decode()


def _transform_sensitive_fields(
    creds: dict[str, Any],
    transform: Callable[[bytes], bytes],
) -> dict[str, Any]:
    """Return a shallow copy of *creds* with each sensitive field passed
    through *transform* (str -> bytes -> transform -> str)."""
    result = dict(creds)
    for field in SENSITIVE_FIELDS:
        value = result.get(field)
        # Falsy values (missing, None, "") are left untouched: there is
        # nothing to protect, and decrypting "" would raise InvalidToken.
        if value:
            result[field] = transform(value.encode()).decode()
    return result


def encrypt_credentials_dict(creds: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of the credentials dict with sensitive fields encrypted.

    Call this before inserting a new DatabaseClient record. The input dict is
    not modified; only the keys in ``SENSITIVE_FIELDS`` that hold a truthy
    value are replaced.

    Raises:
        ValueError: if the encryption key is missing or invalid (raised even
            when *creds* contains no sensitive fields).
    """
    return _transform_sensitive_fields(creds, _get_cipher().encrypt)


def decrypt_credentials_dict(creds: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of the credentials dict with sensitive fields decrypted.

    Call this after fetching a DatabaseClient record from DB. The input dict
    is not modified; only the keys in ``SENSITIVE_FIELDS`` that hold a truthy
    value are replaced.

    Raises:
        ValueError: if the encryption key is missing or invalid.
        cryptography.fernet.InvalidToken: if a sensitive field is not a valid
            token for the configured key.
    """
    return _transform_sensitive_fields(creds, _get_cipher().decrypt)