| from __future__ import annotations |
|
|
| import os |
| import re |
| import subprocess |
| from pathlib import Path |
|
|
| import psycopg |
| from psycopg import errors as pg_errors |
|
|
|
|
def load_dotenv(path: Path) -> dict[str, str]:
    """Parse a dotenv-style file into a ``KEY -> value`` mapping.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    ignored.  Keys and values are whitespace-trimmed and one pair of matching
    surrounding quotes is removed from a value, so ``KEY = "v"`` yields
    ``{"KEY": "v"}``.  Lines with an empty key are skipped.

    Args:
        path: Location of the dotenv file.

    Returns:
        The variables defined in the file (a later duplicate key wins), or an
        empty dict when ``path`` is not an existing regular file.
    """
    values: dict[str, str] = {}
    # is_file() also rejects directories, which read_text() would choke on.
    if not path.is_file():
        return values

    # "utf-8-sig" drops a leading BOM so it cannot become part of the first key.
    for raw_line in path.read_text(encoding="utf-8-sig").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue

        # Split on the first "=" only; the value may itself contain "=".
        key, _, value = line.partition("=")
        key = key.strip()
        if not key:
            continue

        value = value.strip()
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        values[key] = value
    return values
|
|
|
|
def load_runtime_env(repo_root: Path) -> dict[str, str]:
    """Merge the repository's dotenv files into a single mapping.

    ``.env`` is read first and ``.env.local`` second, so a key defined in both
    files takes its value from ``.env.local`` (the machine-local override).
    Files that do not exist contribute nothing.

    Args:
        repo_root: Directory that contains the ``.env`` / ``.env.local`` files.

    Returns:
        The combined variables from both files.
    """
    loaded: dict[str, str] = {}
    # dict.update lets later files win, so list the lowest precedence first.
    for candidate in (repo_root / ".env", repo_root / ".env.local"):
        loaded.update(load_dotenv(candidate))
    return loaded
|
|
|
|
def get_database_url(loaded: dict[str, str]) -> str:
    """Return the normalized ``DATABASE_URL`` connection string.

    The process environment is consulted first; ``loaded`` (values read from
    dotenv files) is only the fallback when the variable is not set there.

    Normalization steps, in order:
      1. trim surrounding whitespace;
      2. remove one pair of wrapping double or single quotes;
      3. drop a stray ``:`` in front of ``postgresql://``;
      4. unwrap a password written as ``[password]`` between ``user:`` and
         ``@`` (presumably a copied placeholder).

    Returns an empty string when the variable is not configured anywhere.
    """
    url = os.getenv("DATABASE_URL", loaded.get("DATABASE_URL", "")).strip()

    # At most one pair of quotes is removed; double quotes are checked first.
    for quote in ('"', "'"):
        if url.startswith(quote) and url.endswith(quote):
            url = url[1:-1]
            break

    if url.startswith(":postgresql://"):
        url = url[1:]

    bracketed_password = re.compile(
        r"(postgres(?:ql)?://[^:/?#]+:)\[([^\]]+)\](@)",
        re.IGNORECASE,
    )
    return bracketed_password.sub(r"\1\2\3", url)
|
|
|
|
def get_database_hostaddr(loaded: dict[str, str]) -> str:
    """Return the whitespace-trimmed ``DATABASE_HOSTADDR`` setting.

    The process environment wins; ``loaded`` (dotenv values) is used only when
    the variable is not set there.  Returns ``""`` when it is not configured.
    """
    key = "DATABASE_HOSTADDR"
    from_file = loaded.get(key, "")
    return os.getenv(key, from_file).strip()
|
|
|
|
def extract_host(database_url: str) -> str:
    """Extract the host name from a PostgreSQL connection string.

    Handles URI-style strings with or without credentials
    (``postgresql://user:pw@host:5432/db``, ``postgresql://host/db``),
    bracketed IPv6 literals (``[::1]``), and ``host=...`` settings in
    keyword/value strings or URI query parameters.

    Args:
        database_url: Connection string in URI or keyword/value form.

    Returns:
        The host without port or brackets, or ``""`` if none can be found.
    """
    authority = None
    if "://" in database_url and "@" in database_url:
        # Split on the LAST "@" so an unescaped "@" inside the password does
        # not truncate the host.
        authority = database_url.rsplit("@", 1)[1]
    elif re.match(r"\s*[a-z][a-z0-9+.-]*://", database_url, re.IGNORECASE):
        # URI without credentials: the host follows the scheme directly.
        authority = database_url.split("://", 1)[1]

    if authority is not None:
        # The host[:port] part ends at the path, query or fragment.
        host_part = re.split(r"[/?#]", authority, maxsplit=1)[0]
        if host_part.startswith("["):
            end = host_part.find("]")
            host = host_part[1:end] if end > 0 else ""
        else:
            host = host_part.split(":", 1)[0]
        if host:
            return host

    # Keyword/value form, or a URI that names the host as a query parameter.
    host_match = re.search(r"\bhost=([^\s&]+)", database_url)
    if host_match:
        return host_match.group(1).strip("'\"")

    return ""
|
|
|
|
def _answer_address_from_nslookup(output: str) -> str:
    """Pick an address from the answer part of ``nslookup`` output."""
    # The first paragraph is the header describing the DNS server that was
    # queried; its address must never be mistaken for the answer.
    sections = re.split(r"\n\s*\n", output.strip(), maxsplit=1)
    if len(sections) < 2:
        return ""
    answer = sections[1]

    ipv4_matches = re.findall(r"\b(?:\d{1,3}\.){3}\d{1,3}\b", answer)
    if ipv4_matches:
        return ipv4_matches[-1]

    ipv6_matches = re.findall(r"\b(?:[0-9a-fA-F]{0,4}:){2,}[0-9a-fA-F]{0,4}\b", answer)
    if ipv6_matches:
        return ipv6_matches[-1]

    return ""


def resolve_host_with_nslookup(hostname: str) -> str:
    """Resolve ``hostname`` to an IP address by running ``nslookup``.

    Only the answer section of the output is inspected (everything after the
    leading DNS-server header), so an IPv6-only host or a failed lookup can no
    longer yield the DNS server's own address.  An IPv4 answer is preferred;
    otherwise the last IPv6 answer is used.

    Args:
        hostname: Host name to look up.

    Returns:
        The resolved address, or ``""`` when the name is empty or looks like a
        command-line option, ``nslookup`` is unavailable, times out or exits
        non-zero, or the output contains no answer address.
    """
    # A leading "-" would be parsed by nslookup as an option, not a host.
    if not hostname or hostname.startswith("-"):
        return ""

    try:
        result = subprocess.run(
            ["nslookup", hostname],
            capture_output=True,
            text=True,
            check=False,
            timeout=10,  # never let an unreachable DNS server hang the caller
        )
    except (OSError, subprocess.TimeoutExpired):
        # nslookup not installed / not executable, or no reply in time.
        return ""

    if result.returncode != 0:
        return ""

    return _answer_address_from_nslookup(result.stdout)
|
|
|
|
def connect_with_fallback(database_url: str, hostaddr_override: str):
    """Open a psycopg connection, retrying with an explicit IP on DNS failure.

    The first attempt uses ``database_url`` as is.  If it fails with an
    ``OperationalError`` whose text indicates a host-name resolution problem,
    a second attempt passes ``hostaddr`` explicitly, taken from
    ``hostaddr_override`` or, when that is empty, from an ``nslookup`` of the
    URL's host.

    Args:
        database_url: Connection string passed to ``psycopg.connect``.
        hostaddr_override: IP address to use for the retry; may be empty.

    Returns:
        An open connection with ``autocommit=False``.

    Raises:
        psycopg.OperationalError: For connection failures that are not
            host-resolution problems (re-raised unchanged), or when the retry
            itself fails.
        RuntimeError: When resolution failed and no fallback address is
            available; chained to the original error.
    """
    # The resolver's error text differs by platform and by which layer reports
    # it, so match the known variants case-insensitively rather than only the
    # Windows wording.
    resolution_markers = (
        "getaddrinfo failed",
        "could not translate host name",
        "name or service not known",
        "temporary failure in name resolution",
        "nodename nor servname provided",
        "no address associated with hostname",
    )

    try:
        return psycopg.connect(database_url, autocommit=False)
    except psycopg.OperationalError as exc:
        message = str(exc).lower()
        if not any(marker in message for marker in resolution_markers):
            raise

        host = extract_host(database_url)
        fallback_hostaddr = hostaddr_override or resolve_host_with_nslookup(host)
        if fallback_hostaddr:
            return psycopg.connect(database_url, autocommit=False, hostaddr=fallback_hostaddr)

        raise RuntimeError(
            "Database host resolution failed. Add DATABASE_HOSTADDR in .env.local or "
            "switch DATABASE_URL to a resolvable Supabase pooler host."
        ) from exc
|
|
|
|
def main() -> None:
    """Apply pending SQL migrations from ``supabase/migrations``.

    Migration files (``*.sql``) are processed in sorted order.  Applied file
    names are recorded in ``public.__th_migrations``; files already recorded
    are skipped.  Each migration and its bookkeeping row are committed
    together in their own transaction, so an "applied" line always refers to
    work that is durably committed even if a later migration fails.

    A migration that fails with a "duplicate object/table/function/column"
    error is assumed to have been applied outside this runner: its own
    statements are rolled back and only its file name is recorded.

    Raises:
        RuntimeError: If ``DATABASE_URL`` is not configured.
    """
    # The repository root is the parent of the directory holding this script.
    repo_root = Path(__file__).resolve().parents[1]
    loaded = load_runtime_env(repo_root)
    database_url = get_database_url(loaded)
    database_hostaddr = get_database_hostaddr(loaded)
    if not database_url:
        raise RuntimeError("DATABASE_URL is required. Set it in .env.local or the shell environment.")

    migration_dir = repo_root / "supabase" / "migrations"
    migration_files = sorted(migration_dir.glob("*.sql"))
    if not migration_files:
        print("No migration files found.")
        return

    with connect_with_fallback(database_url, database_hostaddr) as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                create table if not exists public.__th_migrations (
                    filename text primary key,
                    applied_at timestamptz not null default now()
                )
                """
            )
        conn.commit()

        for file_path in migration_files:
            filename = file_path.name
            with conn.cursor() as cur:
                cur.execute(
                    "select 1 from public.__th_migrations where filename = %s",
                    (filename,),
                )
                already_applied = cur.fetchone() is not None
            # The SELECT opened an implicit transaction (autocommit is off).
            # Close it so the transaction() blocks below are outermost and
            # really COMMIT, instead of being savepoints inside one long
            # transaction that a later failure would roll back as a whole.
            conn.commit()

            if already_applied:
                print(f"skip {filename}")
                continue

            sql = file_path.read_text(encoding="utf-8")
            try:
                with conn.transaction():
                    with conn.cursor() as cur:
                        cur.execute(sql)
                        cur.execute(
                            "insert into public.__th_migrations (filename) values (%s)",
                            (filename,),
                        )
                print(f"applied {filename}")
            except (
                pg_errors.DuplicateObject,
                pg_errors.DuplicateTable,
                pg_errors.DuplicateFunction,
                pg_errors.DuplicateColumn,
            ) as exc:
                # The failed block was rolled back; record the file so it is
                # not retried on every run.
                with conn.transaction():
                    with conn.cursor() as cur:
                        cur.execute(
                            "insert into public.__th_migrations (filename) values (%s) on conflict (filename) do nothing",
                            (filename,),
                        )
                print(f"assumed-applied {filename} ({exc.__class__.__name__}: {exc})")

    print("Migrations complete.")
|
|
|
|
# Run the migration runner only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|