Spaces:
Running
Running
File size: 6,834 Bytes
3972bf0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | import os
import ssl
import re
import sys
import logging
import urllib.parse
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.engine.url import make_url
def debug_print(msg):
    """Emit a ``--- DB_DEBUG: <msg> ---`` line on stderr and flush it right away.

    Flushing immediately keeps the debug output visible even if the process
    dies during startup.
    """
    print(f"--- DB_DEBUG: {msg} ---", file=sys.stderr, flush=True)
debug_print("Loading database.py")

# Absolute directory containing this module.
# NOTE(review): not referenced in this file's visible code — presumably
# imported by other modules, so it is kept.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# PostgreSQL URLs used for local development when the AUTH_DATABASE_URL /
# MANDI_DATABASE_URL environment variables are not set. The effective URLs
# are read from the environment and normalised further down in this module.
DEFAULT_AUTH_URL = "postgresql+pg8000://postgres:root@localhost:5432/auth_db"
DEFAULT_MANDI_URL = "postgresql+pg8000://postgres:root@localhost:5432/mandi_db"
# Helper to clean and normalise database settings into SQLAlchemy URLs.
def format_db_url(name, url: str) -> str:
    """Normalise a raw database setting into a SQLAlchemy URL string.

    Accepts either a URL (``postgres://``, ``postgresql://`` or
    ``postgresql+<driver>://``) or a libpq-style key/value string such as
    ``user=postgres password=secret host=db.example.com port=5432 dbname=postgres``.
    Key/value values may be bare (no whitespace) or single-quoted; quoted
    values may contain spaces and backslash-escaped quotes/backslashes.

    The scheme of any Postgres URL is rewritten to select the driver:
    ``psycopg2`` when the string contains "supabase" (case-insensitive),
    otherwise ``pg8000``.

    Args:
        name: Label used only in debug output (e.g. "AUTH").
        url: Raw value, typically read from an environment variable.

    Returns:
        The normalised URL, or "" when ``url`` is empty/None (a blank string
        also ends up as ""). Strings that are neither a Postgres URL nor a
        complete key/value set are returned stripped but otherwise unchanged;
        ``create_engine`` will reject them later.
    """
    if not url:
        debug_print(f"{name} is EMPTY")
        return ""
    url = url.strip()

    # Key/value (libpq conninfo) form, e.g. as copied from a provider dashboard.
    if "user=" in url and "host=" in url:
        debug_print(f"Detected key-value format for {name}. Attempting to parse...")
        try:
            kv = {}
            # Each pair is key=value or key = value; the value is either
            # '...single-quoted...' (backslash escapes allowed) or a bare run
            # of non-whitespace characters.
            pair_re = r"(\w+)\s*=\s*(?:'((?:[^'\\]|\\.)*)'|(\S+))"
            for match in re.finditer(pair_re, url):
                key, quoted, bare = match.groups()
                if quoted is not None:
                    # Undo libpq escaping: \' -> ' and \\ -> \
                    value = re.sub(r"\\(.)", r"\1", quoted)
                else:
                    value = bare
                kv[key.lower()] = value

            if all(k in kv for k in ("user", "password", "host", "dbname")):
                port = kv.get("port", "5432")
                # Percent-encode credentials so characters such as "@", ":",
                # "/" or spaces survive URL parsing. quote(safe="") rather than
                # quote_plus: SQLAlchemy decodes userinfo with
                # urllib.parse.unquote, which does not turn "+" back into a space.
                safe_user = urllib.parse.quote(kv["user"], safe="")
                safe_password = urllib.parse.quote(kv["password"], safe="")
                # The driver here is provisional; the scheme is rewritten below.
                url = f"postgresql+psycopg2://{safe_user}:{safe_password}@{kv['host']}:{port}/{kv['dbname']}"
                debug_print(f"Parsed {name} into SQLAlchemy format (with encoded password).")
            else:
                debug_print(f"Incomplete key-value pairs for {name}: {list(kv.keys())}")
        except Exception as e:
            # Best-effort parsing: log and fall through with the stripped input.
            debug_print(f"Failed to parse key-value string for {name}: {e}")

    # Pick the driver: psycopg2 for Supabase, pg8000 for everything else.
    dialect = "+psycopg2" if "supabase" in url.lower() else "+pg8000"

    # Rewrite postgres://, postgresql:// and postgresql+<driver>:// to the
    # chosen driver. The pattern is anchored, so non-Postgres strings pass
    # through untouched.
    url = re.sub(r"^postgres(ql)?(\+\w+)?://", f"postgresql{dialect}://", url, count=1)
    return url
# Read the raw settings (local-development defaults apply when a variable is
# unset) and normalise both into SQLAlchemy URLs, AUTH first, then MANDI.
AUTH_RAW = os.getenv("AUTH_DATABASE_URL", DEFAULT_AUTH_URL)
MANDI_RAW = os.getenv("MANDI_DATABASE_URL", DEFAULT_MANDI_URL)
AUTH_DATABASE_URL, MANDI_DATABASE_URL = (
    format_db_url("AUTH", AUTH_RAW),
    format_db_url("MANDI", MANDI_RAW),
)

# format_db_url yields "" for an empty/blank value; fail fast at import time.
if not AUTH_DATABASE_URL:
    raise ValueError("AUTH_DATABASE_URL is not set or empty.")
if not MANDI_DATABASE_URL:
    raise ValueError("MANDI_DATABASE_URL is not set or empty.")
# create_engine() keyword arguments for each Postgres engine.
#   pool_pre_ping  - check that a pooled connection is alive before handing it out.
#   pool_recycle   - replace connections older than 1800 s (30 min) so ones the
#                    server dropped quietly are not reused.
#   connect_args   - starts empty; may be filled in later (e.g. an SSL context).
auth_engine_args = dict(
    pool_size=10,
    max_overflow=20,
    pool_pre_ping=True,
    pool_recycle=1800,
    connect_args={},
)
mandi_engine_args = dict(
    pool_size=20,
    max_overflow=30,
    pool_pre_ping=True,
    pool_recycle=1800,
    connect_args={},
)
# For remote databases TLS is configured per driver:
#   * pg8000: the query string is dropped and an ssl_context is passed via
#     connect_args.
#   * psycopg2: TLS is requested through sslmode in the URL query string.
def apply_ssl_if_needed(url: str, engine_args: dict) -> str:
    """Adjust ``url``/``engine_args`` so connections to known hosted providers use TLS.

    Only URLs containing one of the provider markers below are touched;
    anything else (e.g. localhost) is returned unchanged.

    Args:
        url: SQLAlchemy URL string.
        engine_args: ``create_engine`` kwargs for this engine. For pg8000 its
            ``"connect_args"`` entry is replaced in place.

    Returns:
        The URL to pass to ``create_engine``.
    """
    is_external = any(host in url for host in ["neon.tech", "supabase", "aws.com", "elephantsql.com"])
    if not is_external:
        return url

    # Decide the driver from the scheme only, so a password or database name
    # that happens to contain "pg8000"/"psycopg2" cannot misroute the URL.
    scheme = url.partition("://")[0]

    if "pg8000" in scheme:
        # Strip all query parameters (pg8000 is configured through
        # connect_args here instead) and pass an SSL context.
        ssl_context = ssl.create_default_context()
        # SECURITY NOTE(review): certificate and hostname verification are
        # disabled, so the link is encrypted but the server is not
        # authenticated (MITM-able). Kept as-is because enabling verification
        # may require provider CA bundles not in the default trust store — confirm.
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
        engine_args["connect_args"] = {"ssl_context": ssl_context}
        return url.split("?")[0]

    if "psycopg2" in scheme:
        # Require TLS unless the URL already chooses an sslmode explicitly.
        query = url.partition("?")[2]
        if "sslmode" not in urllib.parse.parse_qs(query, keep_blank_values=True):
            separator = "&" if "?" in url else "?"
            url = f"{url}{separator}sslmode=require"
        if ":6543" in url:
            # Port 6543 is presumably the Supabase connection pooler; this is
            # informational only — no parameters are changed here.
            debug_print("Using Supabase Pooler (6543). Ensuring compatibility parameters.")

    return url
# Apply provider-specific TLS handling; each call may also update the
# matching *_engine_args dict (AUTH is processed before MANDI).
AUTH_DATABASE_URL, MANDI_DATABASE_URL = (
    apply_ssl_if_needed(AUTH_DATABASE_URL, auth_engine_args),
    apply_ssl_if_needed(MANDI_DATABASE_URL, mandi_engine_args),
)
def safe_create_engine(name, url, args):
    """Create a SQLAlchemy engine for ``url``, logging progress to stderr.

    The URL is parsed with ``make_url`` first so driver/host/port can be
    logged and a malformed value fails here with a labelled log line. No
    connection is attempted: doing so could block startup or fail on a
    transient network problem, and SQLAlchemy connects on first use.

    Args:
        name: Label used in debug output (e.g. "AUTH").
        url: SQLAlchemy database URL string.
        args: Keyword arguments forwarded to ``create_engine``.

    Returns:
        The created engine.

    Raises:
        Whatever ``make_url`` or ``create_engine`` raises; it is logged and
        re-raised unchanged (no engine is returned on failure).
    """
    try:
        parsed = make_url(url)
        debug_print(f"Creating {name} engine (Driver: {parsed.drivername}, Host: {parsed.host}, Port: {parsed.port})")
        engine = create_engine(url, **args)
    except Exception as e:
        # NOTE(review): for an unparsable URL the error text may echo the raw
        # string, credentials included — consider redacting before logging.
        debug_print(f"CRITICAL ERROR in {name} engine creation: {str(e)}")
        raise
    else:
        debug_print(f"Engine {name} created successfully.")
        return engine
# One engine per database (AUTH created first, then MANDI).
auth_engine = safe_create_engine("AUTH", AUTH_DATABASE_URL, auth_engine_args)
mandi_engine = safe_create_engine("MANDI", MANDI_DATABASE_URL, mandi_engine_args)

# Session factories bound to their engine; autoflush is disabled.
AuthSessionLocal = sessionmaker(bind=auth_engine, autoflush=False, autocommit=False)
MandiSessionLocal = sessionmaker(bind=mandi_engine, autoflush=False, autocommit=False)

# Separate declarative bases, presumably one per database's model set.
AuthBase, MandiBase = declarative_base(), declarative_base()
def get_auth_db():
    """Yield a session from ``AuthSessionLocal`` and close it when the consumer is done.

    Generator form so it can serve as a per-request dependency.
    NOTE(review): the framework wiring is not visible here.
    """
    # Session.__exit__ closes the session, matching an explicit try/finally close().
    with AuthSessionLocal() as session:
        yield session
def get_mandi_db():
    """Yield a session from ``MandiSessionLocal`` and close it when the consumer is done.

    Generator form so it can serve as a per-request dependency.
    NOTE(review): the framework wiring is not visible here.
    """
    # Session.__exit__ closes the session, matching an explicit try/finally close().
    with MandiSessionLocal() as session:
        yield session
|