demoprep / snowflake_auth.py
mikeboone's picture
feat: March 2026 sprint — new vision merge, pipeline improvements, settings refactor
5ac32c1
"""
Centralized Snowflake Authentication Utilities
Supports multiple private key formats for cloud deployment flexibility:
- Direct PEM format (multi-line)
- Base64-encoded PEM (single line, recommended for HF Spaces)
- Newline-escaped PEM (\\n replaced with actual newlines)
"""
import base64
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from supabase_client import get_admin_setting
def _decode_private_key(raw_key: str) -> str:
"""
Decode private key from various formats to standard PEM.
Supports:
- Direct PEM (-----BEGIN ... -----END)
- Base64-encoded PEM
- Escaped newlines (\\n)
Args:
raw_key: Raw private key string from environment
Returns:
Decoded PEM string
"""
if not raw_key:
return raw_key
# Strip whitespace
raw_key = raw_key.strip()
# Already in PEM format
if raw_key.startswith('-----BEGIN'):
# Handle escaped newlines (common in some env var systems)
if '\\n' in raw_key:
raw_key = raw_key.replace('\\n', '\n')
return raw_key
# Try base64 decode
try:
decoded = base64.b64decode(raw_key).decode('utf-8')
if decoded.startswith('-----BEGIN'):
print("βœ… Successfully decoded base64-encoded private key")
return decoded
except Exception:
pass
# Try base64 decode with padding fix
try:
# Add padding if needed
padding = 4 - (len(raw_key) % 4)
if padding != 4:
raw_key_padded = raw_key + ('=' * padding)
else:
raw_key_padded = raw_key
decoded = base64.b64decode(raw_key_padded).decode('utf-8')
if decoded.startswith('-----BEGIN'):
print("βœ… Successfully decoded base64-encoded private key (with padding fix)")
return decoded
except Exception:
pass
# Return as-is if no decoding worked
return raw_key
def get_snowflake_connection_params():
    """
    Build Snowflake connection parameters using key pair authentication.

    All values are read via get_admin_setting(): the private key
    (SNOWFLAKE_KP_PK), an optional passphrase used only when the key is
    encrypted (SNOWFLAKE_KP_PASSPHRASE), and the user, account, role,
    warehouse, database and schema settings.

    Returns:
        dict: Connection parameters for snowflake.connector.connect(). The
        private key is supplied as unencrypted PKCS#8 DER bytes, and
        'schema' falls back to 'PUBLIC' when the setting is empty.

    Raises:
        ValueError: If the stored private key is empty, if the key is
            encrypted but no passphrase is configured, or if the key cannot
            be loaded.
    """
    # Source of truth: admin settings in Supabase
    private_key_raw = get_admin_setting('SNOWFLAKE_KP_PK')

    # Normalise the stored key (PEM, escaped-newline PEM, or base64) to PEM.
    private_key_pem = _decode_private_key(private_key_raw)
    if not private_key_pem:
        # Fail with a clear message; without this guard a None value crashes
        # the 'ENCRYPTED' substring check below with an opaque TypeError.
        raise ValueError("SNOWFLAKE_KP_PK is empty. Set it in Admin Settings.")

    # A PEM containing 'ENCRYPTED' is treated as passphrase-protected; only
    # then is the passphrase looked up. An empty passphrase is normalised to
    # None so the loader always receives bytes or None.
    is_encrypted = 'ENCRYPTED' in private_key_pem
    password = None
    if is_encrypted:
        passphrase = get_admin_setting('SNOWFLAKE_KP_PASSPHRASE', required=False)
        if passphrase:
            password = passphrase.encode()
            print("βœ… Using passphrase for encrypted private key")

    try:
        private_key = load_pem_private_key(
            private_key_pem.encode(),
            password=password,
        )
        print("βœ… Successfully loaded Snowflake private key")
    except Exception as e:
        # Re-raise as ValueError with a targeted message, chaining the
        # loader's original exception for debugging.
        if is_encrypted and not password:
            raise ValueError(
                "Private key is encrypted but SNOWFLAKE_KP_PASSPHRASE is missing. "
                "Set it in Admin Settings."
            ) from e
        raise ValueError(
            f"Could not load private key: {str(e)}\n"
            f"Key format detected: {'PEM' if private_key_pem.startswith('-----BEGIN') else 'Unknown'}\n"
            f"Key length: {len(private_key_pem)} characters"
        ) from e

    # Re-serialise as unencrypted PKCS#8 DER, the form passed to the
    # Snowflake connector's 'private_key' parameter.
    private_key_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption()
    )

    schema_name = get_admin_setting('SNOWFLAKE_SCHEMA', required=False) or 'PUBLIC'
    return {
        'user': get_admin_setting('SNOWFLAKE_KP_USER'),
        'private_key': private_key_bytes,
        'account': get_admin_setting('SNOWFLAKE_ACCOUNT'),
        'role': get_admin_setting('SNOWFLAKE_ROLE', required=False),
        'warehouse': get_admin_setting('SNOWFLAKE_WAREHOUSE'),
        'database': get_admin_setting('SNOWFLAKE_DATABASE'),
        'schema': schema_name,
    }
def get_snowflake_connection():
    """
    Open a Snowflake session authenticated with the configured key pair.

    Parameters come from get_snowflake_connection_params(). Before
    connecting, the mandatory entries (user, private_key, account,
    warehouse, database) are checked; 'role' and 'schema' are not checked.

    Returns:
        snowflake.connector.SnowflakeConnection: Connected Snowflake session

    Raises:
        ValueError: If any mandatory parameter is missing or empty.
    """
    import snowflake.connector

    params = get_snowflake_connection_params()

    # Collect every mandatory key whose value is absent or falsy, keeping
    # the declaration order so the error message is stable.
    missing = []
    for key in ('user', 'private_key', 'account', 'warehouse', 'database'):
        if not params.get(key):
            missing.append(key)

    if missing:
        raise ValueError(f"Missing required Snowflake parameters: {missing}")

    return snowflake.connector.connect(**params)