Spaces:
Sleeping
Sleeping
"""
RSA/Hybrid decryption utilities for the URL Blink application.

The client sends each payload as base64(JSON); the JSON object's ``type``
field selects one of two encryption modes:

1. ``direct`` -- RSA-OAEP (SHA-256) over the whole plaintext, used for data
   of at most 190 bytes (the OAEP-SHA256 plaintext limit for a 2048-bit key;
   key size assumed -- confirm against the deployed key).
2. ``hybrid`` -- an AES key encrypted with RSA-OAEP, plus AES-GCM over the
   data itself, used for data larger than 190 bytes.
"""
| import base64 | |
| import json | |
| import os | |
| import logging | |
| from typing import Any, Optional | |
| from cryptography.hazmat.primitives import serialization, hashes | |
| from cryptography.hazmat.primitives.asymmetric import padding | |
| from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes | |
| from cryptography.hazmat.backends import default_backend | |
# Module-level logger for key loading and decryption diagnostics.
logger = logging.getLogger(__name__)

# Location of the PEM private key file; overridable via the PRIVATE_KEY_PATH
# environment variable, read once at import time.
PRIVATE_KEY_PATH = os.environ.get("PRIVATE_KEY_PATH", "./PRIVATE_KEY.pem")

# Cache filled by load_private_key() on its first successful load; stays None
# until then.
_private_key = None
def _load_pem(pem_bytes: bytes):
    """Deserialize an unencrypted PEM private key; raises on malformed input."""
    return serialization.load_pem_private_key(
        pem_bytes,
        password=None,
        backend=default_backend(),
    )


def load_private_key():
    """
    Load the RSA private key from an environment variable or a PEM file.

    The first successfully loaded key is cached in the module-level
    ``_private_key`` and returned directly on every later call. A failed load
    is not cached, so each call retries both sources.

    Priority:
        1. ``PRIVATE_KEY`` environment variable (PEM content). Literal ``\\n``
           escape sequences are turned into real newlines, because many
           environment stores cannot hold multi-line values.
        2. The file at ``PRIVATE_KEY_PATH``.

    Returns:
        The private key object, or None if neither source yields a key.
        Load failures are logged as warnings, never raised.
    """
    global _private_key
    if _private_key is not None:
        return _private_key

    # 1. Environment variable.
    private_key_pem = os.getenv("PRIVATE_KEY")
    if private_key_pem:
        # PEM text (header lines + base64) never contains a backslash, so
        # un-escaping "\n" cannot corrupt a key that already has real newlines.
        pem_text = private_key_pem.replace("\\n", "\n")
        try:
            _private_key = _load_pem(pem_text.encode())
        except Exception as e:
            logger.warning("Failed to load private key from env: %s", e)
        else:
            logger.info("Successfully loaded private key from PRIVATE_KEY env variable")
            return _private_key

    # 2. PEM file on disk.
    try:
        with open(PRIVATE_KEY_PATH, "rb") as key_file:
            _private_key = _load_pem(key_file.read())
    except FileNotFoundError:
        logger.warning("Private key file not found: %s", PRIVATE_KEY_PATH)
    except Exception as e:
        logger.warning("Failed to load private key from file: %s", e)
    else:
        logger.info("Successfully loaded private key from %s", PRIVATE_KEY_PATH)
        return _private_key

    logger.warning("No private key available - encrypted data will not be decrypted")
    return None
def decrypt_direct(payload: dict, private_key) -> str:
    """
    Decrypt a payload that was encrypted directly with RSA-OAEP (SHA-256).

    Args:
        payload: Mapping whose 'data' entry is the base64-encoded RSA ciphertext.
        private_key: RSA private key object used for decryption.

    Returns:
        The recovered plaintext, decoded as UTF-8.

    Exceptions from base64 decoding, RSA decryption or UTF-8 decoding are not
    caught here; the caller is responsible for handling them.
    """
    ciphertext = base64.b64decode(payload['data'])
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    plaintext_bytes = private_key.decrypt(ciphertext, oaep)
    return plaintext_bytes.decode('utf-8')
def decrypt_hybrid(payload: dict, private_key) -> str:
    """
    Decrypt a hybrid RSA-OAEP + AES-GCM payload.

    Args:
        payload: Mapping with three base64-encoded entries: 'key' (the AES key,
            encrypted with RSA-OAEP/SHA-256), 'iv' (the GCM nonce) and 'data'
            (the AES-GCM ciphertext with its 16-byte tag appended).
        private_key: RSA private key object used to recover the AES key.

    Returns:
        The recovered plaintext, decoded as UTF-8.

    Exceptions (e.g. a failed GCM tag check in finalize()) propagate to the
    caller.
    """
    tag_len = 16

    # Step 1: recover the symmetric key with RSA-OAEP.
    wrapped_key = base64.b64decode(payload['key'])
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    aes_key = private_key.decrypt(wrapped_key, oaep)

    # Step 2: the client appends the GCM tag to the ciphertext, so split it
    # off the end before decrypting.
    nonce = base64.b64decode(payload['iv'])
    blob = base64.b64decode(payload['data'])
    body, tag = blob[:-tag_len], blob[-tag_len:]

    decryptor = Cipher(
        algorithms.AES(aes_key),
        modes.GCM(nonce, tag),
        backend=default_backend(),
    ).decryptor()
    plaintext = decryptor.update(body)
    plaintext += decryptor.finalize()
    return plaintext.decode('utf-8')
def decrypt_data(encrypted_base64: str) -> Optional[Any]:
    """
    Decrypt one payload produced by the client usageService.

    Wire format: ``btoa(JSON.stringify({ type: 'direct'|'hybrid', ... }))`` --
    an outer base64 layer around a JSON object whose ``type`` field selects
    decrypt_direct or decrypt_hybrid.

    Args:
        encrypted_base64: The outer base64 string received from the client.

    Returns:
        - The decrypted plaintext parsed as JSON, when it is valid JSON.
        - ``{"raw_data": <str>}`` when the plaintext is not JSON.
        - ``{"encrypted_data": <input>, "decryption_status": "no_key_available"}``
          when no private key could be loaded.
        - ``{"encrypted_data": <input>, "decryption_error": "Unknown type: ..."}``
          for an unrecognised ``type``.
        - ``{"encrypted_data": <preview>, "decryption_error": <message>}`` when
          any decoding/decryption step raises. The preview is the first 100
          characters, followed by "..." only if the input was longer.

    Errors are logged and reported in the returned dict rather than raised.
    """
    preview_chars = 100

    private_key = load_private_key()
    # Without a key, pass the ciphertext through untouched so it is not lost.
    if private_key is None:
        logger.warning("No private key - returning encrypted data")
        return {"encrypted_data": encrypted_base64, "decryption_status": "no_key_available"}

    try:
        outer_json = base64.b64decode(encrypted_base64).decode('utf-8')
        payload = json.loads(outer_json)
        encryption_type = payload.get('type')
        if encryption_type == 'direct':
            decrypted_str = decrypt_direct(payload, private_key)
        elif encryption_type == 'hybrid':
            decrypted_str = decrypt_hybrid(payload, private_key)
        else:
            logger.error("Unknown encryption type: %s", encryption_type)
            return {"encrypted_data": encrypted_base64, "decryption_error": f"Unknown type: {encryption_type}"}
        # The plaintext is usually JSON; fall back to the raw string otherwise.
        try:
            return json.loads(decrypted_str)
        except json.JSONDecodeError:
            return {"raw_data": decrypted_str}
    except Exception as e:
        # Some exceptions (e.g. cryptography's InvalidTag on a GCM
        # authentication failure) carry no message, so str(e) would be "".
        detail = str(e) or type(e).__name__
        logger.error("Decryption failed: %s", detail)
        # Guard against non-str input so the error path itself cannot raise
        # (e.g. bytes + "..." or slicing None).
        shown = encrypted_base64 if isinstance(encrypted_base64, str) else repr(encrypted_base64)
        if len(shown) > preview_chars:
            shown = shown[:preview_chars] + "..."
        return {"encrypted_data": shown, "decryption_error": detail}
def decrypt_multiple_blocks(encrypted_data: str) -> list[Any]:
    """
    Decrypt a comma-separated sequence of encrypted blocks.

    The input is split on ',' (safe because the base64 alphabet has no comma).
    Each segment is whitespace-trimmed, blank segments are skipped, and every
    remaining segment is passed to decrypt_data.

    Args:
        encrypted_data: Comma-separated base64 blocks; may be empty or None.

    Returns:
        One decrypted result per non-empty block, in input order; an empty
        list when the input is falsy.
    """
    if not encrypted_data:
        return []
    trimmed = (piece.strip() for piece in encrypted_data.split(','))
    return [decrypt_data(piece) for piece in trimmed if piece]