"""
RSA Decryption utilities for the URL Blink application.
"""

import base64
import json
import os
import logging
from typing import Any, Optional

from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.backends import default_backend

logger = logging.getLogger(__name__)

# Path to the private key file (overridable via the PRIVATE_KEY_PATH env variable)
PRIVATE_KEY_PATH = os.getenv("PRIVATE_KEY_PATH", "./PRIVATE_KEY.pem")

# Cache the private key after first load. Only a successful load is cached:
# while this is None, every call to load_private_key() retries both sources.
_private_key = None


def load_private_key():
    """
    Load the RSA private key from environment variable or PEM file.

    Caches the key for subsequent calls.

    Priority:
    1. PRIVATE_KEY environment variable (PEM content)
    2. PRIVATE_KEY_PATH file

    Both sources are loaded with ``password=None``, i.e. the PEM must not be
    passphrase-protected. Load failures are logged and never raised.

    Returns:
        The loaded private key object (expected to be RSA), or None if not available
    """
    global _private_key

    if _private_key is not None:
        return _private_key

    # Try loading from environment variable first
    private_key_pem = os.getenv("PRIVATE_KEY")
    if private_key_pem:
        try:
            _private_key = serialization.load_pem_private_key(
                private_key_pem.encode(),
                password=None,
                backend=default_backend(),
            )
            logger.info("Successfully loaded private key from PRIVATE_KEY env variable")
            return _private_key
        except Exception as e:
            # Best effort: fall through to the file source.
            logger.warning("Failed to load private key from env: %s", e)

    # Try loading from file
    try:
        with open(PRIVATE_KEY_PATH, "rb") as key_file:
            _private_key = serialization.load_pem_private_key(
                key_file.read(),
                password=None,
                backend=default_backend(),
            )
        logger.info("Successfully loaded private key from %s", PRIVATE_KEY_PATH)
        return _private_key
    except FileNotFoundError:
        logger.warning("Private key file not found: %s", PRIVATE_KEY_PATH)
    except Exception as e:
        logger.warning("Failed to load private key from file: %s", e)

    logger.warning("No private key available - encrypted data will not be decrypted")
    return None


def decrypt_data(encrypted_base64: str) -> Optional[Any]:
    """
    Decrypt base64-encoded RSA encrypted data.

    The input is decoded as URL-safe base64 (missing ``=`` padding is restored
    first) and decrypted with RSA-OAEP, using SHA-256 for both the MGF1 mask
    and the OAEP hash. Errors are reported through the return value, never raised.

    Args:
        encrypted_base64: Base64 URL-safe encoded encrypted string

    Returns:
        - The plaintext parsed as JSON (any JSON value; ``None`` only when the
          plaintext is the JSON literal ``null``).
        - ``{"raw_data": <str>}`` if the plaintext is UTF-8 but not valid JSON.
        - ``{"encrypted_data": <input>, "decryption_status": "no_key_available"}``
          if no private key could be loaded.
        - ``{"encrypted_data": <input>, "decryption_error": <message>}`` if
          decoding, decryption or UTF-8 decoding fails.
    """
    try:
        private_key = load_private_key()

        # If no private key, return the encrypted data as-is
        if private_key is None:
            logger.warning("No private key - returning encrypted data")
            return {"encrypted_data": encrypted_base64, "decryption_status": "no_key_available"}

        # URL-safe base64 is often sent without trailing '=' padding. Append
        # exactly enough '=' to reach a multiple of 4 (none if already aligned).
        padded = encrypted_base64 + "=" * (-len(encrypted_base64) % 4)
        encrypted_bytes = base64.urlsafe_b64decode(padded)

        # Decrypt using RSA OAEP with SHA256
        decrypted_bytes = private_key.decrypt(
            encrypted_bytes,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None,
            ),
        )

        decrypted_str = decrypted_bytes.decode("utf-8")

        # Try to parse as JSON; fall back to wrapping the raw string.
        try:
            return json.loads(decrypted_str)
        except json.JSONDecodeError:
            logger.warning("Decrypted data is not valid JSON, returning raw string")
            return {"raw_data": decrypted_str}

    except Exception as e:
        logger.error("Decryption failed: %s", e)
        # Return encrypted data on failure
        return {"encrypted_data": encrypted_base64, "decryption_error": str(e)}


def _block_succeeded(result: Any) -> bool:
    """Return True unless ``result`` is None or decrypt_data's error-marker dict."""
    # Only dicts can carry the error marker; the decrypted payload itself may be
    # any JSON value (int, bool, str, list), so a bare `in` test is unsafe.
    if result is None:
        return False
    return not (isinstance(result, dict) and "decryption_error" in result)


def _decrypt_fixed_blocks(encrypted_data: str, block_size: int) -> Optional[list[Any]]:
    """Decrypt each consecutive ``block_size``-char slice; return None on the first failure."""
    blocks_results = []
    for start in range(0, len(encrypted_data), block_size):
        result = decrypt_data(encrypted_data[start:start + block_size])
        if not _block_succeeded(result):
            return None
        blocks_results.append(result)
    return blocks_results


def decrypt_multiple_blocks(encrypted_data: str) -> list[Any]:
    """
    Decrypt multiple concatenated encrypted blocks.

    RSA 2048-bit encrypted data is 256 bytes, which can be:
    - 344 chars in base64 with padding
    - 342 chars in base64 without padding (URL-safe)

    Strategy:
    1. Input of at most 350 chars is decrypted as a single block.
    2. Otherwise each candidate block size that evenly divides the input is
       tried in order. The first size for which every block decrypts without
       error wins.
    3. If none wins, the input is cut into 344-char blocks and every
       per-block result, including error dicts, is returned.

    Args:
        encrypted_data: Concatenated base64-encoded encrypted blocks

    Returns:
        List of decrypted data objects (see decrypt_data for their shapes).
        A ``None`` result (JSON ``null`` payload) is not included.
    """
    # Common block sizes for RSA-2048 in base64
    # 344 = with padding, 342 = without padding
    # NOTE(review): 343 and 256 chars do not decode to a 256-byte RSA-2048
    # block (256 chars -> 192 bytes); kept for compatibility - confirm intent.
    POSSIBLE_BLOCK_SIZES = [344, 342, 343, 256]

    # First, try to decrypt as a single block
    if len(encrypted_data) <= 350:
        result = decrypt_data(encrypted_data)
        return [] if result is None else [result]

    # Try each possible block size that evenly divides the data
    for block_size in POSSIBLE_BLOCK_SIZES:
        if len(encrypted_data) % block_size != 0:
            continue
        blocks_results = _decrypt_fixed_blocks(encrypted_data, block_size)
        if blocks_results:
            logger.info(
                "Successfully decrypted %d blocks with block_size=%d",
                len(blocks_results),
                block_size,
            )
            return blocks_results

    # Fallback: try to decrypt with default block size 344, collecting all results
    block_size = 344
    logger.warning(
        "Falling back to block_size=%d, data length=%d", block_size, len(encrypted_data)
    )
    results = []
    for start in range(0, len(encrypted_data), block_size):
        result = decrypt_data(encrypted_data[start:start + block_size])
        if result is not None:
            results.append(result)
    return results