File size: 2,312 Bytes
"""Fernet encryption utilities for user-registered database credentials.

The encryption key is read from the `dataeyond__db__credential__key` env
variable, which is intentionally kept separate from the user-auth bcrypt salt
(`emarcal__bcrypt__salt`). Only the fields listed in `SENSITIVE_FIELDS` are
encrypted; all other credential keys are passed through unchanged.

Usage:
    from src.utils.db_credential_encryption import encrypt_credentials_dict, decrypt_credentials_dict

    # Before INSERT:
    safe_creds = encrypt_credentials_dict(raw_credentials)

    # After SELECT:
    plain_creds = decrypt_credentials_dict(row.credentials)
"""
from cryptography.fernet import Fernet
from src.config.settings import settings
# Credential keys whose values must be encrypted before they are stored.
# One entry per secret-bearing field across the supported database types:
#   - "password"             -> postgres, mysql, sqlserver, supabase, snowflake
#   - "service_account_json" -> bigquery
# A frozenset keeps the collection immutable at module level.
SENSITIVE_FIELDS: frozenset[str] = frozenset(("password", "service_account_json"))
def _get_cipher() -> Fernet:
    """Build a Fernet cipher from the configured credential key.

    The key is re-read from ``settings.dataeyond_db_credential_key`` on every
    call. It is expected to be a ``str``, because it is ``.encode()``-d before
    being handed to ``Fernet``.

    Raises:
        ValueError: if the configured key is missing or empty.
    """
    configured_key = settings.dataeyond_db_credential_key
    if configured_key:
        return Fernet(configured_key.encode())
    # Fail loudly and tell the operator how to produce a valid key.
    raise ValueError(
        "dataeyond__db__credential__key is not set. "
        "Generate one with: Fernet.generate_key().decode()"
    )
def encrypt_credential(value: str) -> str:
    """Encrypt one plaintext credential and return the Fernet token as text.

    The value is encoded with ``str.encode()`` (UTF-8) before encryption, and
    the resulting token bytes are decoded back into a ``str``.

    Raises:
        ValueError: if the credential key is not configured.
    """
    token = _get_cipher().encrypt(value.encode())
    return token.decode()
def decrypt_credential(value: str) -> str:
    """Decrypt one Fernet token string back to its plaintext credential.

    This is the inverse of ``encrypt_credential``. The token is encoded to
    bytes before decryption, and the plaintext bytes are decoded to ``str``.

    Raises:
        ValueError: if the credential key is not configured.
        cryptography.fernet.InvalidToken: if the token is malformed or was not
            produced with the configured key.
    """
    plaintext = _get_cipher().decrypt(value.encode())
    return plaintext.decode()
def encrypt_credentials_dict(creds: dict) -> dict:
    """Return a shallow copy of *creds* with its sensitive fields encrypted.

    Only keys listed in ``SENSITIVE_FIELDS`` are rewritten, and only when
    their value is truthy. Missing or empty fields are copied as-is, and the
    input dict itself is never modified. Sensitive values are expected to be
    ``str``, because each one is ``.encode()``-d before encryption.

    Call this before inserting a new DatabaseClient record.

    Raises:
        ValueError: if the credential key is not configured.
    """
    # Build the cipher once and reuse it for every field.
    cipher = _get_cipher()
    encrypted = dict(creds)
    for name in SENSITIVE_FIELDS:
        plaintext = encrypted.get(name)
        if not plaintext:
            continue
        encrypted[name] = cipher.encrypt(plaintext.encode()).decode()
    return encrypted
def decrypt_credentials_dict(creds: dict) -> dict:
    """Return a shallow copy of *creds* with its sensitive fields decrypted.

    This is the inverse of ``encrypt_credentials_dict``. Keys listed in
    ``SENSITIVE_FIELDS`` that hold a truthy value are decrypted; every other
    entry is copied unchanged, and the input dict is never modified.

    Call this after fetching a DatabaseClient record from DB.

    Raises:
        ValueError: if the credential key is not configured.
        cryptography.fernet.InvalidToken: if a stored token is invalid for
            the configured key.
    """
    cipher = _get_cipher()
    decrypted = dict(creds)
    # Decrypt every populated sensitive field, then merge the results back
    # into the copy. Non-sensitive keys are left untouched.
    decrypted.update(
        {
            name: cipher.decrypt(decrypted[name].encode()).decode()
            for name in SENSITIVE_FIELDS
            if decrypted.get(name)
        }
    )
    return decrypted
|