Spaces:
Running
Running
| """ | |
| Centralized Snowflake Authentication Utilities | |
| Supports multiple private key formats for cloud deployment flexibility: | |
| - Direct PEM format (multi-line) | |
| - Base64-encoded PEM (single line, recommended for HF Spaces) | |
| - Newline-escaped PEM (\\n replaced with actual newlines) | |
| """ | |
| import base64 | |
| from cryptography.hazmat.primitives import serialization | |
| from cryptography.hazmat.primitives.serialization import load_pem_private_key | |
| from supabase_client import get_admin_setting | |
def _decode_private_key(raw_key: str) -> str:
    """
    Decode private key from various formats to standard PEM.

    Supports:
    - Direct PEM (-----BEGIN ... -----END)
    - Base64-encoded PEM (line-wrapped input and missing '=' padding are tolerated)
    - Escaped newlines (\\n)

    Args:
        raw_key: Raw private key string from environment

    Returns:
        Decoded PEM string. Falsy input is returned unchanged, and input that
        matches none of the supported formats is returned stripped but
        otherwise as-is so the caller can report the failure.
    """
    if not raw_key:
        return raw_key

    raw_key = raw_key.strip()

    # Already in PEM format; some env var systems store newlines as a literal
    # backslash-n sequence, so turn those back into real newlines.
    if raw_key.startswith('-----BEGIN'):
        return raw_key.replace('\\n', '\n')

    # Otherwise treat the value as base64. Whitespace (e.g. line wrapping) is
    # removed first so that the padding computation only counts base64 characters.
    compact = ''.join(raw_key.split())
    missing_padding = -len(compact) % 4
    try:
        decoded = base64.b64decode(compact + '=' * missing_padding).decode('utf-8')
    except ValueError:
        # binascii.Error and UnicodeDecodeError are both ValueError subclasses:
        # the value is not base64-encoded text, so fall through to "as-is".
        return raw_key

    # A leading blank line in the encoded file must not hide the PEM header.
    decoded = decoded.lstrip()
    if decoded.startswith('-----BEGIN'):
        if missing_padding:
            print("β Successfully decoded base64-encoded private key (with padding fix)")
        else:
            print("β Successfully decoded base64-encoded private key")
        return decoded

    # Return as-is if no decoding worked
    return raw_key
def get_snowflake_connection_params():
    """
    Get standardized Snowflake connection parameters using key pair authentication.

    The private key is read from the SNOWFLAKE_KP_PK admin setting, decoded to
    PEM, loaded (using SNOWFLAKE_KP_PASSPHRASE when the PEM is marked as
    encrypted) and re-serialized as unencrypted PKCS8 DER bytes.

    Returns:
        dict: Connection parameters for snowflake.connector.connect()

    Raises:
        ValueError: If the private key setting is empty, if the key is
            encrypted but no passphrase is configured, or if the key cannot
            be loaded for any other reason.
    """
    # Source of truth: admin settings in Supabase
    private_key_raw = get_admin_setting('SNOWFLAKE_KP_PK')

    # Decode the private key from various formats
    private_key_pem = _decode_private_key(private_key_raw)
    if not private_key_pem:
        # Fail with a clear message instead of a TypeError on the checks below.
        raise ValueError("SNOWFLAKE_KP_PK is empty. Set it in Admin Settings.")

    # Encrypted PEM headers contain the word ENCRYPTED; only then is a
    # passphrase looked up in the admin settings.
    is_encrypted = 'ENCRYPTED' in private_key_pem
    password = None
    if is_encrypted:
        passphrase = get_admin_setting('SNOWFLAKE_KP_PASSPHRASE', required=False)
        if passphrase:
            password = passphrase.encode()
            print("β Using passphrase for encrypted private key")

    try:
        private_key = load_pem_private_key(
            private_key_pem.encode(),
            password=password,
        )
        print("β Successfully loaded Snowflake private key")
    except Exception as e:
        # Translate any loader failure into a ValueError, keeping the cause.
        if is_encrypted and not password:
            raise ValueError(
                "Private key is encrypted but SNOWFLAKE_KP_PASSPHRASE is missing. "
                "Set it in Admin Settings."
            ) from e
        raise ValueError(
            f"Could not load private key: {e}\n"
            f"Key format detected: {'PEM' if private_key_pem.startswith('-----BEGIN') else 'Unknown'}\n"
            f"Key length: {len(private_key_pem)} characters"
        ) from e

    # Convert to DER format for Snowflake connector
    private_key_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Schema is optional and falls back to PUBLIC
    schema_name = get_admin_setting('SNOWFLAKE_SCHEMA', required=False) or 'PUBLIC'
    return {
        'user': get_admin_setting('SNOWFLAKE_KP_USER'),
        'private_key': private_key_bytes,
        'account': get_admin_setting('SNOWFLAKE_ACCOUNT'),
        'role': get_admin_setting('SNOWFLAKE_ROLE', required=False),
        'warehouse': get_admin_setting('SNOWFLAKE_WAREHOUSE'),
        'database': get_admin_setting('SNOWFLAKE_DATABASE'),
        'schema': schema_name,
    }
def get_snowflake_connection():
    """
    Open a Snowflake session authenticated with a key pair.

    Builds the parameters via get_snowflake_connection_params(), checks that
    the mandatory ones are present, and hands them to the connector.

    Returns:
        snowflake.connector.SnowflakeConnection: Connected Snowflake session

    Raises:
        ValueError: If any mandatory connection parameter is missing or empty.
    """
    # Imported here rather than at module level, so the connector is only
    # needed when a connection is actually opened.
    import snowflake.connector

    params = get_snowflake_connection_params()

    # Collect every mandatory parameter that is absent or falsy
    mandatory = ('user', 'private_key', 'account', 'warehouse', 'database')
    missing = []
    for key in mandatory:
        if not params.get(key):
            missing.append(key)

    if missing:
        raise ValueError(f"Missing required Snowflake parameters: {missing}")

    return snowflake.connector.connect(**params)