# tenderhub-webai-verification / scripts / apply-migrations.py
# engresearch's picture
# Upload folder using huggingface_hub
# 7f88bdf verified
from __future__ import annotations
import os
import re
import subprocess
from pathlib import Path
import psycopg
from psycopg import errors as pg_errors
def load_dotenv(path: Path) -> dict[str, str]:
    """Parse a ``KEY=VALUE`` dotenv file into a dictionary.

    Blank lines, ``#`` comment lines, lines without ``=`` and lines whose key
    is empty are ignored. Keys and values are both stripped of surrounding
    whitespace, and one pair of matching single or double quotes around a
    value is removed. When a key occurs more than once, the last one wins.

    Args:
        path: Location of the dotenv file.

    Returns:
        The parsed key/value pairs; empty when ``path`` is not a regular file.
    """
    values: dict[str, str] = {}
    # is_file() also rejects directories, which read_text() cannot read.
    if not path.is_file():
        return values
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        # Split on the first "=" only so values may themselves contain "=".
        key, separator, value = line.partition("=")
        key = key.strip()
        if not separator or not key:
            continue
        value = value.strip()
        # Drop one layer of matching quotes, e.g. KEY="value" -> value.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        values[key] = value
    return values
def load_runtime_env(repo_root: Path) -> dict[str, str]:
    """Merge the repository's dotenv files into a single mapping.

    ``.env`` is read first and ``.env.local`` second, so when a key is
    defined in both files the value from ``.env.local`` is the one returned.
    Files that do not exist contribute nothing.

    Args:
        repo_root: Directory that contains the dotenv files.

    Returns:
        The merged key/value pairs.
    """
    # Ordered from lowest to highest precedence: dict.update() lets the
    # later file overwrite keys set by the earlier one.
    candidates = [repo_root / ".env", repo_root / ".env.local"]
    loaded: dict[str, str] = {}
    for candidate in candidates:
        loaded.update(load_dotenv(candidate))
    return loaded
def get_database_url(loaded: dict[str, str]) -> str:
    """Return the cleaned-up ``DATABASE_URL`` setting.

    The process environment is consulted first; ``loaded`` (values read from
    dotenv files) supplies the default. The raw value is then normalised:

    * surrounding whitespace is stripped;
    * one pair of matching double or single quotes is removed;
    * a stray ``:`` directly in front of ``postgresql://`` is dropped;
    * square brackets wrapped around the password part
      (``scheme://user:[password]@``) are removed.

    Args:
        loaded: Settings parsed from dotenv files.

    Returns:
        The normalised URL, or ``""`` when nothing is configured.
    """
    file_value = loaded.get("DATABASE_URL", "")
    url = os.getenv("DATABASE_URL", file_value).strip()

    # Remove at most one layer of matching quotes.
    for quote in ('"', "'"):
        if url.startswith(quote) and url.endswith(quote):
            url = url[1:-1]
            break

    if url.startswith(":postgresql://"):
        url = url[1:]

    bracketed_password = re.compile(
        r"(postgres(?:ql)?://[^:/?#]+:)\[([^\]]+)\](@)",
        re.IGNORECASE,
    )
    return bracketed_password.sub(r"\1\2\3", url)
def get_database_hostaddr(loaded: dict[str, str]) -> str:
    """Return the optional ``DATABASE_HOSTADDR`` setting.

    The process environment is consulted first; ``loaded`` (values read from
    dotenv files) supplies the default.

    Args:
        loaded: Settings parsed from dotenv files.

    Returns:
        The whitespace-stripped setting, or ``""`` when it is not configured.
    """
    file_value = loaded.get("DATABASE_HOSTADDR", "")
    hostaddr = os.getenv("DATABASE_HOSTADDR", file_value)
    return hostaddr.strip()
def extract_host(database_url: str) -> str:
    """Extract the host name from a PostgreSQL connection string.

    Both URL form (``scheme://[user[:password]@]host[:port][/db][?query]``)
    and keyword/value form (``host=... port=...``) are understood. For URLs,
    the port, path, query string and fragment are discarded, and the square
    brackets around an IPv6 literal are removed.

    Args:
        database_url: Connection string in URL or keyword/value form.

    Returns:
        The host, or ``""`` when none can be found.
    """
    if "://" in database_url:
        remainder = database_url.split("://", 1)[1]
        # Split on the LAST "@" so a password that itself contains "@" (or
        # other unescaped punctuation) does not truncate the host.
        if "@" in remainder:
            remainder = remainder.rsplit("@", 1)[1]
        # The host[:port] part ends at the first path, query or fragment
        # delimiter; a URL without a path may go straight to "?...".
        host_part = re.split(r"[/?#]", remainder, maxsplit=1)[0]
        if host_part.startswith("["):
            end = host_part.find("]")
            return host_part[1:end] if end > 0 else ""
        host = host_part.split(":", 1)[0]
        if host:
            return host
        # An empty URL host falls through: it may be given as "host=" instead.
    host_match = re.search(r"\bhost=([^\s]+)", database_url)
    if host_match:
        return host_match.group(1)
    return ""
def resolve_host_with_nslookup(hostname: str) -> str:
    """Resolve ``hostname`` to an IP address by running ``nslookup``.

    This is a best-effort helper: every failure (empty host name, missing or
    failing ``nslookup`` executable, timeout, no address in the answer)
    results in ``""`` rather than an exception.

    Args:
        hostname: Host name to resolve.

    Returns:
        The last IPv4 address found in the answer, otherwise the last IPv6
        address found in the answer, otherwise ``""``.
    """
    if not hostname:
        return ""
    try:
        result = subprocess.run(
            ["nslookup", hostname],
            capture_output=True,
            text=True,
            check=False,
            timeout=10,  # do not hang forever on an unreachable resolver
        )
    except (OSError, subprocess.SubprocessError):
        # nslookup is not installed / not executable, or it timed out.
        return ""
    if result.returncode != 0:
        return ""

    # nslookup prints the resolver it queried ("Server:" / "Address:") as the
    # first paragraph and the answer after a blank line. Only the answer may
    # be searched, otherwise the DNS server's own address can be returned
    # (e.g. when the host has no IPv4 record at all).
    sections = re.split(r"\n\s*\n", result.stdout.strip(), maxsplit=1)
    if len(sections) < 2:
        return ""
    answer = sections[1]

    ipv4_matches = re.findall(r"\b(?:\d{1,3}\.){3}\d{1,3}\b", answer)
    if ipv4_matches:
        return ipv4_matches[-1]
    ipv6_matches = re.findall(r"\b(?:[0-9a-fA-F]{0,4}:){2,}[0-9a-fA-F]{0,4}\b", answer)
    if ipv6_matches:
        return ipv6_matches[-1]
    return ""
def connect_with_fallback(database_url: str, hostaddr_override: str):
    """Open a psycopg connection, retrying with an explicit host address.

    A normal connection is attempted first. Only when it fails with an
    ``OperationalError`` whose message contains ``getaddrinfo failed`` is a
    second attempt made, passing ``hostaddr`` explicitly: the override if one
    was given, otherwise whatever ``resolve_host_with_nslookup`` returns for
    the host found in ``database_url``.

    Args:
        database_url: Connection string handed to ``psycopg.connect``.
        hostaddr_override: Address to use for the retry; may be empty.

    Returns:
        The connection returned by ``psycopg.connect`` (``autocommit=False``).

    Raises:
        psycopg.OperationalError: For any failure other than the
            ``getaddrinfo failed`` case (re-raised unchanged).
        RuntimeError: When the lookup failed and no fallback address could
            be determined.
    """
    try:
        connection = psycopg.connect(database_url, autocommit=False)
    except psycopg.OperationalError as exc:
        if "getaddrinfo failed" not in str(exc):
            raise

        retry_hostaddr = hostaddr_override
        if not retry_hostaddr:
            retry_hostaddr = resolve_host_with_nslookup(extract_host(database_url))

        if not retry_hostaddr:
            raise RuntimeError(
                "Database host resolution failed. Add DATABASE_HOSTADDR in .env.local or "
                "switch DATABASE_URL to a resolvable Supabase pooler host."
            ) from exc

        return psycopg.connect(database_url, autocommit=False, hostaddr=retry_hostaddr)
    return connection
def main() -> None:
    """Apply the pending ``supabase/migrations/*.sql`` files to the database.

    Connection settings come from the dotenv files / environment. A
    bookkeeping table, ``public.__th_migrations``, records the file name of
    every migration that has been applied. Files are processed in sorted
    order: recorded ones are skipped, the rest are executed and recorded. A
    migration that fails with one of the "duplicate object" errors listed
    below is recorded as applied without being retried.

    Raises:
        RuntimeError: If ``DATABASE_URL`` is not configured.
    """
    # parents[1] is the parent of the directory that contains this script.
    repo_root = Path(__file__).resolve().parents[1]
    loaded = load_runtime_env(repo_root)
    database_url = get_database_url(loaded)
    database_hostaddr = get_database_hostaddr(loaded)
    if not database_url:
        raise RuntimeError("DATABASE_URL is required. Set it in .env.local or the shell environment.")
    migration_dir = repo_root / "supabase" / "migrations"
    # Sorting the paths fixes the order in which migrations are applied.
    migration_files = sorted(migration_dir.glob("*.sql"))
    if not migration_files:
        print("No migration files found.")
        return
    with connect_with_fallback(database_url, database_hostaddr) as conn:
        # Make sure the bookkeeping table exists before it is queried.
        with conn.cursor() as cur:
            cur.execute(
                """
create table if not exists public.__th_migrations (
filename text primary key,
applied_at timestamptz not null default now()
)
"""
            )
        conn.commit()
        for file_path in migration_files:
            filename = file_path.name
            # A row in the bookkeeping table means this file was handled before.
            with conn.cursor() as cur:
                cur.execute(
                    "select 1 from public.__th_migrations where filename = %s",
                    (filename,),
                )
                already_applied = cur.fetchone() is not None
            if already_applied:
                print(f"skip {filename}")
                continue
            sql = file_path.read_text(encoding="utf-8")
            try:
                # Run the migration and record it inside one transaction
                # block, so the record is only written if the SQL succeeded.
                # NOTE(review): the connection is opened with
                # autocommit=False and the SELECT above has already run on
                # it; presumably this block therefore nests inside an open
                # transaction and nothing is durable until the connection
                # context exits -- confirm against the psycopg docs.
                with conn.transaction():
                    with conn.cursor() as cur:
                        cur.execute(sql)
                        cur.execute(
                            "insert into public.__th_migrations (filename) values (%s)",
                            (filename,),
                        )
                print(f"applied {filename}")
            except (
                pg_errors.DuplicateObject,
                pg_errors.DuplicateTable,
                pg_errors.DuplicateFunction,
                pg_errors.DuplicateColumn,
            ) as exc:
                # The objects this migration creates already exist, so it is
                # treated as applied: only the bookkeeping row is written.
                # "on conflict do nothing" keeps the insert harmless if the
                # row is already there.
                with conn.transaction():
                    with conn.cursor() as cur:
                        cur.execute(
                            "insert into public.__th_migrations (filename) values (%s) on conflict (filename) do nothing",
                            (filename,),
                        )
                print(f"assumed-applied {filename} ({exc.__class__.__name__}: {exc})")
    print("Migrations complete.")
# Run the migrations only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()