File size: 2,312 Bytes
0e07955
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Fernet encryption utilities for user-registered database credentials.

Encryption key is sourced from `dataeyond__db__credential__key` env variable,
intentionally separate from the user-auth bcrypt salt (`emarcal__bcrypt__salt`).

Usage:
    from src.utils.db_credential_encryption import encrypt_credentials_dict, decrypt_credentials_dict

    # Before INSERT:
    safe_creds = encrypt_credentials_dict(raw_credentials)

    # After SELECT:
    plain_creds = decrypt_credentials_dict(row.credentials)
"""

from cryptography.fernet import Fernet
from src.config.settings import settings

# Credential keys whose values are secrets and must never be stored in
# plaintext. Which supported database types use each key:
#   - "password"             -> postgres, mysql, sqlserver, supabase, snowflake
#   - "service_account_json" -> bigquery
SENSITIVE_FIELDS: frozenset[str] = frozenset(("password", "service_account_json"))


def _get_cipher() -> Fernet:
    """Build a Fernet cipher from the configured credential key.

    Returns:
        A ``Fernet`` instance keyed with ``settings.dataeyond_db_credential_key``.

    Raises:
        ValueError: if the key is unset/empty, or if it is set but is not a
            valid Fernet key. In the latter case the message names the
            environment variable (never the key value) and the library's
            original error is chained as ``__cause__``.
    """
    key = settings.dataeyond_db_credential_key
    if not key:
        raise ValueError(
            "dataeyond__db__credential__key is not set. "
            "Generate one with: Fernet.generate_key().decode()"
        )
    try:
        return Fernet(key.encode())
    except ValueError as exc:
        # Fernet's own message for a malformed key does not say which config
        # value is wrong; re-raise with the variable name so misconfiguration
        # is easy to diagnose. Deliberately do not include the key itself.
        raise ValueError(
            "dataeyond__db__credential__key is not a valid Fernet key "
            "(expected 32 url-safe base64-encoded bytes). "
            "Generate one with: Fernet.generate_key().decode()"
        ) from exc


def encrypt_credential(value: str) -> str:
    """Turn one plaintext credential into its Fernet token, returned as text."""
    fernet = _get_cipher()
    token = fernet.encrypt(value.encode())
    return token.decode()


def decrypt_credential(value: str) -> str:
    """Recover the plaintext of one credential from its Fernet token string.

    Errors from ``Fernet.decrypt`` (e.g. a token produced with a different
    key) propagate unchanged to the caller.
    """
    fernet = _get_cipher()
    plaintext = fernet.decrypt(value.encode())
    return plaintext.decode()


def encrypt_credentials_dict(creds: dict) -> dict:
    """Encrypt the secret entries of a credentials mapping.

    The input mapping is not modified. A shallow copy is returned in which
    every key listed in ``SENSITIVE_FIELDS`` holding a truthy value has been
    replaced by its Fernet token (as text); falsy or absent entries are copied
    as-is. Call this before inserting a new DatabaseClient record.
    """
    fernet = _get_cipher()
    encrypted = dict(creds)
    for name in SENSITIVE_FIELDS:
        plaintext = encrypted.get(name)
        if not plaintext:
            continue
        token = fernet.encrypt(plaintext.encode())
        encrypted[name] = token.decode()
    return encrypted


def decrypt_credentials_dict(creds: dict) -> dict:
    """Decrypt the secret entries of a stored credentials mapping.

    The input mapping is not modified. A shallow copy is returned in which
    every key listed in ``SENSITIVE_FIELDS`` holding a truthy value has been
    replaced by the plaintext recovered from its Fernet token; falsy or absent
    entries are copied as-is. Call this after fetching a DatabaseClient
    record from the DB.
    """
    fernet = _get_cipher()
    decrypted = dict(creds)
    for name in SENSITIVE_FIELDS:
        token = decrypted.get(name)
        if not token:
            continue
        plaintext = fernet.decrypt(token.encode())
        decrypted[name] = plaintext.decode()
    return decrypted