| from cryptography.fernet import Fernet |
| from cryptography.hazmat.primitives import hashes |
| from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC as PBKDF2 |
| import base64 |
| import os |
| from app.config import settings |
|
|
class EncryptionManager:
    """Encrypt sensitive data at rest.

    A Fernet key is derived from ``settings.jwt_secret`` with
    PBKDF2-HMAC-SHA256 on first use and cached on the instance.

    The PBKDF2 salt must be identical in every process that needs to read
    the same data. A fresh random salt per process would derive a different
    key on every start, making previously encrypted data undecryptable.
    """

    # Fixed, application-specific salt. Secrecy of the key comes from
    # ``settings.jwt_secret``; the salt only has to be stable so that the same
    # key is derived on every run. Pass ``salt=`` for a deployment-specific value.
    _DEFAULT_SALT = b"app.encryption-manager.pbkdf2-salt.v1"

    # PBKDF2 work factor used for key derivation.
    _KDF_ITERATIONS = 600000

    def __init__(self, salt: bytes = _DEFAULT_SALT):
        """Create a manager; key derivation is deferred until first use.

        Args:
            salt: PBKDF2 salt. Must be the same for every process that
                encrypts or decrypts the same data.
        """
        self._salt = salt
        self._cipher = None  # Fernet instance, built lazily by _get_cipher
        self._key = None  # base64-encoded key, derived lazily by _get_key

    def _get_key(self) -> bytes:
        """Derive (once) and return the Fernet key from the JWT secret.

        Returns:
            32 bytes of PBKDF2 output, urlsafe-base64 encoded as Fernet expects.
        """
        if self._key is not None:
            return self._key

        # NOTE: the key is bound to settings.jwt_secret; changing that secret
        # makes data encrypted under the old secret undecryptable.
        password = settings.jwt_secret.encode()

        kdf = PBKDF2(
            algorithm=hashes.SHA256(),
            length=32,
            salt=self._salt,
            iterations=self._KDF_ITERATIONS,
        )

        self._key = base64.urlsafe_b64encode(kdf.derive(password))
        return self._key

    def _get_cipher(self) -> "Fernet":
        """Return the cached Fernet cipher, creating it on first call."""
        if self._cipher is None:
            self._cipher = Fernet(self._get_key())
        return self._cipher

    def encrypt(self, data: str) -> str:
        """Encrypt ``data`` and return a text token.

        Falsy input (``""`` or ``None``) is returned unchanged.

        The Fernet token is base64-encoded once more before being returned.
        This outer layer is kept so the stored format matches what ``decrypt``
        expects.
        """
        if not data:
            return data

        token = self._get_cipher().encrypt(data.encode())
        return base64.urlsafe_b64encode(token).decode()

    def decrypt(self, encrypted_data: str) -> str:
        """Reverse ``encrypt``.

        Falsy input is returned unchanged.

        Raises:
            ValueError: on any failure (bad base64, wrong key or tampered
                token, non-string input, key-derivation error).
        """
        if not encrypted_data:
            return encrypted_data

        try:
            cipher = self._get_cipher()
            decoded = base64.urlsafe_b64decode(encrypted_data.encode())
            return cipher.decrypt(decoded).decode()
        except Exception as e:
            # Surface every failure uniformly as ValueError (callers such as
            # decrypt_dict rely on that), but keep the original cause attached.
            raise ValueError(f"Decryption failed: {e}") from e

    def encrypt_dict(self, data: dict, fields: list) -> dict:
        """Return a shallow copy of ``data`` with the listed fields encrypted.

        Missing or falsy fields are left as-is. Values are converted with
        ``str()`` before encryption, so ``decrypt_dict`` yields strings.
        """
        encrypted = data.copy()
        for field in fields:
            if field in encrypted and encrypted[field]:
                encrypted[field] = self.encrypt(str(encrypted[field]))
        return encrypted

    def decrypt_dict(self, data: dict, fields: list) -> dict:
        """Return a shallow copy of ``data`` with the listed fields decrypted.

        This is best-effort: a field that fails to decrypt keeps its stored
        value instead of raising.
        """
        decrypted = data.copy()
        for field in fields:
            if field in decrypted and decrypted[field]:
                try:
                    decrypted[field] = self.decrypt(decrypted[field])
                except ValueError:
                    # decrypt() converts every failure to ValueError. Leave
                    # the value untouched, but no longer swallow
                    # KeyboardInterrupt/SystemExit as the bare except did.
                    pass
        return decrypted

    def hash_sensitive_data(self, data: str) -> str:
        """Return the hex SHA-256 digest of ``data`` (one-way, unsalted).

        Intended for comparing secrets such as API keys without storing them.
        NOTE(review): compare digests with ``hmac.compare_digest`` to avoid
        timing leaks.
        """
        import hashlib
        return hashlib.sha256(data.encode()).hexdigest()


# Module-level shared instance, constructed at import time (key derivation
# itself is deferred until the first encrypt/decrypt call).
encryption_manager = EncryptionManager()
|
|
|
|
class SecureStorage:
    """Secure storage for sensitive configuration.

    Uses the module-level ``encryption_manager`` to encrypt and decrypt
    HuggingFace tokens, and masks secrets for display.
    """

    def __init__(self):
        self.encryption = encryption_manager

    def store_hf_credentials(self, user_id: str, token: str, repo: str) -> dict:
        """Build a storable record with the HuggingFace token encrypted.

        A falsy ``token`` is stored as ``None``; ``user_id`` and ``repo``
        are stored as given.
        """
        return {
            'user_id': user_id,
            'hf_token': self.encryption.encrypt(token) if token else None,
            'hf_repo': repo,
        }

    def retrieve_hf_credentials(self, stored_data: dict) -> dict:
        """Rebuild credentials from a record made by ``store_hf_credentials``.

        Missing keys yield ``None``.

        Raises:
            ValueError: if a present token cannot be decrypted.
        """
        encrypted_token = stored_data.get('hf_token')
        return {
            'user_id': stored_data.get('user_id'),
            'hf_token': self.encryption.decrypt(encrypted_token) if encrypted_token else None,
            'hf_repo': stored_data.get('hf_repo'),
        }

    def mask_sensitive_data(self, data: str, visible_chars: int = 4) -> str:
        """Mask ``data`` for display, leaving only its leading characters visible.

        Args:
            data: Secret to mask. ``None`` or ``""`` yields ``""``.
            visible_chars: Number of leading characters to reveal. Negative
                values are treated as 0; otherwise ``data[:-k]`` would reveal
                nearly the whole secret.

        Returns:
            A string of the same length as ``data``. If ``data`` is no longer
            than ``visible_chars``, every character is masked.
        """
        if not data:
            return ''

        visible = max(visible_chars, 0)
        if len(data) <= visible:
            return '*' * len(data)
        return data[:visible] + '*' * (len(data) - visible)


# Module-level shared instance.
secure_storage = SecureStorage()
|
|