"""Apply pending SQL migrations from ``supabase/migrations`` to Postgres.

Connection settings come from the process environment, falling back to
``.env`` and ``.env.local`` in the repository root (``.env.local`` wins over
``.env``). Each applied migration file name is recorded in
``public.__th_migrations`` so it is not applied a second time.
"""

from __future__ import annotations

import ipaddress
import os
import re
import subprocess
from pathlib import Path

import psycopg
from psycopg import errors as pg_errors


def load_dotenv(path: Path) -> dict[str, str]:
    """Parse a minimal ``KEY=VALUE`` dotenv file.

    Blank lines, ``#`` comments and lines without ``=`` are skipped. Keys and
    values are stripped of surrounding whitespace; quote removal is left to
    the callers that read specific settings.

    Returns an empty dict when *path* does not exist.
    """
    values: dict[str, str] = {}
    if not path.exists():
        return values
    # utf-8-sig drops a leading byte-order mark, which would otherwise become
    # part of the first key and make it unreachable by name.
    for raw_line in path.read_text(encoding="utf-8-sig").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        values[key.strip()] = value.strip()
    return values


def load_runtime_env(repo_root: Path) -> dict[str, str]:
    """Merge ``.env`` and ``.env.local`` found in *repo_root*.

    ``.env.local`` is loaded last so its machine-local values override the
    shared ``.env`` defaults. The error messages below direct users to put
    their settings in ``.env.local``.
    """
    loaded: dict[str, str] = {}
    for candidate in (repo_root / ".env", repo_root / ".env.local"):
        loaded.update(load_dotenv(candidate))
    return loaded


def _read_setting(name: str, loaded: dict[str, str]) -> str:
    """Return setting *name* from the environment, else from *loaded*.

    Surrounding whitespace and one pair of matching quotes are removed.
    Returns "" when the setting is absent.
    """
    value = os.getenv(name, loaded.get(name, "")).strip()
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
        value = value[1:-1]
    return value


def get_database_url(loaded: dict[str, str]) -> str:
    """Return the normalized ``DATABASE_URL`` setting ("" if unset)."""
    value = _read_setting("DATABASE_URL", loaded)
    # Tolerate a stray ':' before the scheme (presumably a copy/paste artefact).
    if value.startswith((":postgresql://", ":postgres://")):
        value = value[1:]
    # Drop square brackets around the password: ``user:[secret]@`` becomes
    # ``user:secret@`` (presumably left over from a connection-string template).
    value = re.sub(
        r"(postgres(?:ql)?://[^:/?#]+:)\[([^\]]+)\](@)",
        r"\1\2\3",
        value,
        flags=re.IGNORECASE,
    )
    return value


def get_database_hostaddr(loaded: dict[str, str]) -> str:
    """Return the optional ``DATABASE_HOSTADDR`` override ("" if unset)."""
    return _read_setting("DATABASE_HOSTADDR", loaded)


def extract_host(database_url: str) -> str:
    """Return the host name from a URL or ``host=`` keyword connection string.

    For URLs, the first of several comma-separated hosts is returned and
    brackets around an IPv6 literal are removed. Returns "" when no host
    can be found.
    """
    if "://" in database_url:
        authority = database_url.split("://", 1)[1]
        # Passwords may contain '/' or '?', so skip past the last '@' before
        # searching for the end of the host part.
        if "@" in authority:
            authority = authority.rsplit("@", 1)[1]
        host_part = re.split(r"[/?#]", authority, maxsplit=1)[0]
        host_part = host_part.split(",", 1)[0]
        if host_part.startswith("["):
            end = host_part.find("]")
            return host_part[1:end] if end > 0 else ""
        return host_part.split(":", 1)[0]
    host_match = re.search(r"\bhost=([^\s]+)", database_url)
    return host_match.group(1) if host_match else ""


def resolve_host_with_nslookup(hostname: str, timeout: float = 10.0) -> str:
    """Resolve *hostname* by running ``nslookup``, bypassing the OS resolver.

    nslookup prints the DNS server it used (including that server's own
    address) in its first paragraph, so only the following paragraphs are
    searched. The first IPv4 answer is preferred, then the first IPv6 answer.

    Returns "" if nslookup is missing, times out, fails, or yields no address.
    """
    if not hostname:
        return ""
    try:
        result = subprocess.run(
            ["nslookup", hostname],
            capture_output=True,
            text=True,
            check=False,
            timeout=timeout,
        )
    except (OSError, subprocess.TimeoutExpired):
        return ""
    if result.returncode != 0:
        return ""
    paragraphs = re.split(r"\n\s*\n", result.stdout.strip(), maxsplit=1)
    if len(paragraphs) < 2:
        return ""
    addresses = []
    for token in paragraphs[1].split():
        try:
            addresses.append(ipaddress.ip_address(token))
        except ValueError:
            continue
    for version in (4, 6):
        for address in addresses:
            if address.version == version:
                return str(address)
    return ""


# Lower-cased fragments of error messages that indicate the host name could
# not be resolved (Python socket on Windows, libpq, glibc, macOS/BSD).
_DNS_FAILURE_MARKERS = (
    "getaddrinfo failed",
    "could not translate host name",
    "name or service not known",
    "nodename nor servname provided",
    "temporary failure in name resolution",
)


def connect_with_fallback(database_url: str, hostaddr_override: str):
    """Open a non-autocommit connection, retrying with ``hostaddr`` on DNS failure.

    If the first attempt fails with a name-resolution error, the connection is
    retried with ``hostaddr`` set to *hostaddr_override*, or else to an address
    found via nslookup.

    Raises:
        psycopg.OperationalError: for failures unrelated to name resolution,
            or if the fallback attempt itself fails.
        RuntimeError: if the host could not be resolved by any means.
    """
    try:
        return psycopg.connect(database_url, autocommit=False)
    except psycopg.OperationalError as exc:
        message = str(exc).lower()
        if not any(marker in message for marker in _DNS_FAILURE_MARKERS):
            raise
        host = extract_host(database_url)
        fallback_hostaddr = hostaddr_override or resolve_host_with_nslookup(host)
        if fallback_hostaddr:
            return psycopg.connect(database_url, autocommit=False, hostaddr=fallback_hostaddr)
        raise RuntimeError(
            "Database host resolution failed. Add DATABASE_HOSTADDR in .env.local or "
            "switch DATABASE_URL to a resolvable Supabase pooler host."
        ) from exc


def _ensure_migrations_table(conn: psycopg.Connection) -> None:
    """Create the migration bookkeeping table if needed and commit."""
    with conn.cursor() as cur:
        cur.execute(
            """
            create table if not exists public.__th_migrations (
                filename text primary key,
                applied_at timestamptz not null default now()
            )
            """
        )
    conn.commit()


def _fetch_applied_filenames(conn: psycopg.Connection) -> set[str]:
    """Return the file names already recorded in ``public.__th_migrations``."""
    with conn.cursor() as cur:
        cur.execute("select filename from public.__th_migrations")
        applied = {row[0] for row in cur.fetchall()}
    # End the implicit transaction opened by the SELECT. Otherwise every later
    # conn.transaction() would only be a savepoint inside it, and a failure in
    # a later migration would roll back all migrations applied before it.
    conn.commit()
    return applied


def _apply_migration(conn: psycopg.Connection, file_path: Path) -> None:
    """Run one migration file and record it, each file in its own transaction."""
    filename = file_path.name
    # utf-8-sig: a byte-order mark would otherwise reach the server as SQL.
    sql = file_path.read_text(encoding="utf-8-sig")
    try:
        with conn.transaction():
            with conn.cursor() as cur:
                cur.execute(sql)
                cur.execute(
                    "insert into public.__th_migrations (filename) values (%s)",
                    (filename,),
                )
    except (
        pg_errors.DuplicateObject,
        pg_errors.DuplicateTable,
        pg_errors.DuplicateFunction,
        pg_errors.DuplicateColumn,
    ) as exc:
        # NOTE(review): any Duplicate* error marks the WHOLE file as applied,
        # even though its remaining statements were rolled back and never ran.
        # This looks intended for databases whose schema predates the tracking
        # table — confirm before relying on it for partially applied files.
        with conn.transaction():
            with conn.cursor() as cur:
                cur.execute(
                    "insert into public.__th_migrations (filename) values (%s) on conflict (filename) do nothing",
                    (filename,),
                )
        print(f"assumed-applied {filename} ({exc.__class__.__name__}: {exc})")
    else:
        # Reached only after the transaction above has committed.
        print(f"applied {filename}")


def main() -> None:
    """Apply every unrecorded ``supabase/migrations/*.sql`` file in name order."""
    repo_root = Path(__file__).resolve().parents[1]
    loaded = load_runtime_env(repo_root)
    database_url = get_database_url(loaded)
    database_hostaddr = get_database_hostaddr(loaded)
    if not database_url:
        raise RuntimeError("DATABASE_URL is required. Set it in .env.local or the shell environment.")

    migration_dir = repo_root / "supabase" / "migrations"
    migration_files = sorted(migration_dir.glob("*.sql"))
    if not migration_files:
        print("No migration files found.")
        return

    with connect_with_fallback(database_url, database_hostaddr) as conn:
        _ensure_migrations_table(conn)
        applied = _fetch_applied_filenames(conn)
        for file_path in migration_files:
            if file_path.name in applied:
                print(f"skip {file_path.name}")
                continue
            _apply_migration(conn, file_path)

    print("Migrations complete.")


if __name__ == "__main__":
    main()