""" Centralized Snowflake Authentication Utilities Supports multiple private key formats for cloud deployment flexibility: - Direct PEM format (multi-line) - Base64-encoded PEM (single line, recommended for HF Spaces) - Newline-escaped PEM (\\n replaced with actual newlines) """ import base64 from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.serialization import load_pem_private_key from supabase_client import get_admin_setting def _decode_private_key(raw_key: str) -> str: """ Decode private key from various formats to standard PEM. Supports: - Direct PEM (-----BEGIN ... -----END) - Base64-encoded PEM - Escaped newlines (\\n) Args: raw_key: Raw private key string from environment Returns: Decoded PEM string """ if not raw_key: return raw_key # Strip whitespace raw_key = raw_key.strip() # Already in PEM format if raw_key.startswith('-----BEGIN'): # Handle escaped newlines (common in some env var systems) if '\\n' in raw_key: raw_key = raw_key.replace('\\n', '\n') return raw_key # Try base64 decode try: decoded = base64.b64decode(raw_key).decode('utf-8') if decoded.startswith('-----BEGIN'): print("✅ Successfully decoded base64-encoded private key") return decoded except Exception: pass # Try base64 decode with padding fix try: # Add padding if needed padding = 4 - (len(raw_key) % 4) if padding != 4: raw_key_padded = raw_key + ('=' * padding) else: raw_key_padded = raw_key decoded = base64.b64decode(raw_key_padded).decode('utf-8') if decoded.startswith('-----BEGIN'): print("✅ Successfully decoded base64-encoded private key (with padding fix)") return decoded except Exception: pass # Return as-is if no decoding worked return raw_key def get_snowflake_connection_params(): """ Get standardized Snowflake connection parameters using key pair authentication. Returns: dict: Connection parameters for snowflake.connector.connect() """ # Source of truth: admin settings in Supabase private_key_raw = get_admin_setting('SNOWFLAKE_KP_PK') # Decode the private key from various formats private_key_pem = _decode_private_key(private_key_raw) # Load and process the private key (handle both encrypted and unencrypted) password = None if 'ENCRYPTED' in private_key_pem: # Try to get password from environment if key is encrypted password = get_admin_setting('SNOWFLAKE_KP_PASSPHRASE', required=False) if password: password = password.encode() print("✅ Using passphrase for encrypted private key") try: private_key = load_pem_private_key( private_key_pem.encode(), password=password, ) print("✅ Successfully loaded Snowflake private key") except Exception as e: if 'ENCRYPTED' in private_key_pem and not password: raise ValueError( "Private key is encrypted but SNOWFLAKE_KP_PASSPHRASE is missing. " "Set it in Admin Settings." ) else: raise ValueError( f"Could not load private key: {str(e)}\n" f"Key format detected: {'PEM' if private_key_pem.startswith('-----BEGIN') else 'Unknown'}\n" f"Key length: {len(private_key_pem)} characters" ) # Convert to DER format for Snowflake connector private_key_bytes = private_key.private_bytes( encoding=serialization.Encoding.DER, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption() ) # Return connection parameters schema_name = get_admin_setting('SNOWFLAKE_SCHEMA', required=False) or 'PUBLIC' return { 'user': get_admin_setting('SNOWFLAKE_KP_USER'), 'private_key': private_key_bytes, 'account': get_admin_setting('SNOWFLAKE_ACCOUNT'), 'role': get_admin_setting('SNOWFLAKE_ROLE', required=False), 'warehouse': get_admin_setting('SNOWFLAKE_WAREHOUSE'), 'database': get_admin_setting('SNOWFLAKE_DATABASE'), 'schema': schema_name, } def get_snowflake_connection(): """ Create a Snowflake connection using key pair authentication. Returns: snowflake.connector.SnowflakeConnection: Connected Snowflake session """ import snowflake.connector connection_params = get_snowflake_connection_params() # Validate required parameters required_params = ['user', 'private_key', 'account', 'warehouse', 'database'] missing_params = [param for param in required_params if not connection_params.get(param)] if missing_params: raise ValueError(f"Missing required Snowflake parameters: {missing_params}") return snowflake.connector.connect(**connection_params)