"""
database.py

Responsibility:
    Supabase persistence for both the embedding workbench and the TCCM
    screening pipeline.

Connection:
    Requires the SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY environment
    variables. On Hugging Face Spaces these are set under
    Settings -> Variables and secrets -> Secrets. The first call to
    `client()` raises a RuntimeError with a clear message if either is
    missing.

Schema management:
    Supabase intentionally restricts CREATE TABLE from the client SDK. The
    researcher runs the schema SQL (via `bootstrap_schema_sql()`) once in the
    Supabase SQL editor; subsequent app runs find the tables in place. The UI
    surfaces missing tables explicitly via `assert_schema_present()`.

Why the service-role key (not the anon key):
    Unrestricted reads/writes across all tables in this single-researcher
    workbench need the service-role key. The anon key is bound by Row Level
    Security and would require RLS policies for every table — overkill here.

Pagination:
    PostgREST caps how many rows a single request returns (Supabase's default
    max-rows setting is 1000). Every bulk read therefore goes through
    `_fetch_all`, which pages with `.range()` over a deterministic ordering
    instead of silently returning a truncated result.
"""
from __future__ import annotations

import datetime
import json
import os
from typing import Any, Callable

from supabase import create_client, Client


# ---------------------------------------------------------- env / client

def _require_env() -> tuple[str, str]:
    """Return (url, service_role_key) read from the environment.

    Raises:
        RuntimeError: if either variable is unset or empty.
    """
    url = os.environ.get("SUPABASE_URL")
    key = os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
    if not url or not key:
        raise RuntimeError(
            "Missing Supabase credentials. Set SUPABASE_URL and "
            "SUPABASE_SERVICE_ROLE_KEY in the environment "
            "(HF Space Settings -> Variables and secrets -> Secrets)."
        )
    return url, key


# Lazily created module-level client; see client().
_client: Client | None = None


def client() -> Client:
    """Return the shared Supabase client, creating it on first use.

    Raises:
        RuntimeError: on the first call, if credentials are missing.
    """
    global _client
    if _client is None:
        url, key = _require_env()
        _client = create_client(url, key)
    return _client


# ---------------------------------------------------------- internal helpers

# Rows per request for bulk upserts.
_WRITE_CHUNK = 500

# Rows requested per read page. Correctness does not depend on this being at
# or below the server's max-rows cap: _fetch_all advances by the number of
# rows actually returned and stops only on an empty page.
_READ_PAGE = 1000


def _utc_now_iso() -> str:
    """Return the current time as a timezone-aware ISO-8601 string (UTC).

    Replaces the deprecated, naive `datetime.utcnow()`; the explicit +00:00
    offset means timestamptz columns do not depend on the session time zone.
    """
    return datetime.datetime.now(datetime.timezone.utc).isoformat()


def _upsert_chunked(table: str, payload: list[dict], on_conflict: str) -> None:
    """Upsert `payload` into `table` in batches of _WRITE_CHUNK rows.

    An empty payload issues no requests.
    """
    for start in range(0, len(payload), _WRITE_CHUNK):
        (client().table(table)
         .upsert(payload[start:start + _WRITE_CHUNK], on_conflict=on_conflict)
         .execute())


def _fetch_all(make_query: Callable[[], Any]) -> list[dict]:
    """Return every row of a select query, fetched page by page.

    `make_query` must return a fresh, *ordered* select builder on each call:
    a fresh builder keeps each page's `.range()` independent of the previous
    page's, and a stable ordering keeps pages from overlapping or skipping
    rows. Paging ends on the first empty page, so the result is complete even
    when the server caps responses below _READ_PAGE rows.
    """
    out: list[dict] = []
    offset = 0
    while True:
        res = make_query().range(offset, offset + _READ_PAGE - 1).execute()
        batch = res.data or []
        if not batch:
            return out
        out.extend(batch)
        offset += len(batch)


# ---------------------------------------------------------- schema

SCHEMA_SQL = """
-- Run this entire block once in the Supabase SQL editor
-- (Project -> SQL Editor -> New query -> paste -> Run).

create table if not exists papers (
    paper_id text primary key,
    doi text,
    title text not null,
    abstract text,
    year integer,
    author_keywords text,
    source text default 'manual_upload',
    created_at timestamptz default now()
);

create table if not exists embeddings_2d (
    paper_id text primary key references papers(paper_id) on delete cascade,
    coords_json text not null
);

create table if not exists clustering_runs (
    run_id bigserial primary key,
    params_json text not null,
    n_clusters integer,
    n_noise integer,
    notes text,
    created_at timestamptz default now()
);

create table if not exists cluster_assignments (
    run_id bigint not null references clustering_runs(run_id) on delete cascade,
    paper_id text not null references papers(paper_id) on delete cascade,
    cluster_id integer not null,
    membership_prob double precision not null,
    primary key (run_id, paper_id)
);

create table if not exists cluster_labels (
    run_id bigint not null references clustering_runs(run_id) on delete cascade,
    cluster_id integer not null,
    label text,
    subject text,
    object_ text,
    phenomenon text,
    rationale text,
    top_paper_ids_json text,
    validated_by_researcher boolean default false,
    locked boolean default false,
    researcher_notes text,
    primary key (run_id, cluster_id)
);

create table if not exists tccm_runs (
    run_id bigserial primary key,
    pattern_set_version text not null,
    threshold_rule text not null default 'anti_dominates',
    n_papers integer,
    n_include integer,
    n_exclude integer,
    n_marginal integer,
    notes text,
    created_at timestamptz default now()
);

create table if not exists tccm_classifications (
    run_id bigint not null references tccm_runs(run_id) on delete cascade,
    paper_id text not null references papers(paper_id) on delete cascade,
    verdict text not null check (verdict in ('INCLUDE','EXCLUDE','MARGINAL')),
    n_method integer not null default 0,
    n_sample integer not null default 0,
    n_analytic integer not null default 0,
    n_anti integer not null default 0,
    fired_terms_json text,
    primary key (run_id, paper_id)
);

create table if not exists tccm_marginal_reviews (
    run_id bigint not null references tccm_runs(run_id) on delete cascade,
    paper_id text not null references papers(paper_id) on delete cascade,
    agent_verdict text,
    agent_rationale text,
    researcher_verdict text,
    researcher_notes text,
    reviewed_at timestamptz default now(),
    primary key (run_id, paper_id)
);

create table if not exists pdf_downloads (
    paper_id text primary key references papers(paper_id) on delete cascade,
    doi text,
    oa_status text,
    oa_type text,
    source text,
    pdf_url text,
    scihub_used boolean default false,
    downloaded boolean default false,
    file_path text,
    file_size_bytes bigint,
    uploaded_to_supabase boolean default false,
    supabase_storage_path text,
    error_message text,
    checked_at timestamptz,
    downloaded_at timestamptz
);

create table if not exists crossref_metadata (
    paper_id text primary key references papers(paper_id) on delete cascade,
    doi text,
    crossref_title text,
    crossref_abstract text,
    citation_count integer,
    references_count integer,
    reference_dois text,
    publication_type text,
    publisher text,
    container_title text,
    error_message text,
    fetched_at timestamptz default now()
);
"""


def bootstrap_schema_sql() -> str:
    """Return the full schema SQL for one-time manual execution."""
    return SCHEMA_SQL


def assert_schema_present() -> dict[str, bool]:
    """
    Verify all expected tables exist by attempting bounded reads.

    Returns {table_name: present}. A probe that fails for any reason
    (including a transient network error) marks that table as absent; the UI
    handles missing tables by showing the bootstrap SQL. Missing credentials
    are not swallowed: `client()` raises RuntimeError before any probe runs.
    """
    expected = [
        "papers", "embeddings_2d", "clustering_runs", "cluster_assignments",
        "cluster_labels", "tccm_runs", "tccm_classifications",
        "tccm_marginal_reviews", "pdf_downloads", "crossref_metadata",
    ]
    report: dict[str, bool] = {}
    c = client()
    for tbl in expected:
        try:
            c.table(tbl).select("*").limit(1).execute()
            report[tbl] = True
        except Exception:
            # Deliberate best effort: report, don't raise.
            report[tbl] = False
    return report


# ---------------------------------------------------------- papers

def _scrub_nan(v: Any) -> Any:
    """Map missing-value markers to None so the value is JSON-encodable.

    Handles None, float/numpy NaN and pandas NaT (values not equal to
    themselves), and the plain string 'nan' (e.g. a missing pandas value
    after `astype(str)`). Values whose self-comparison raises (such as
    arrays) are returned unchanged.
    """
    if v is None:
        return None
    if isinstance(v, str):
        return None if v == "nan" else v
    try:
        # NaN-like values are the only ones that are not equal to themselves.
        if v != v:
            return None
    except Exception:
        # Ambiguous comparisons are left untouched (best effort).
        pass
    return v


def upsert_papers(rows: list[dict]) -> int:
    """Insert or update papers keyed on paper_id.

    Missing text fields (doi, title, abstract, author_keywords) are stored
    as "" and a missing year as NULL. Returns the number of rows sent.
    """
    if not rows:
        return 0

    def text(row: dict, key: str) -> str:
        value = _scrub_nan(row.get(key))
        return value if value is not None else ""

    payload = []
    for r in rows:
        year = _scrub_nan(r.get("year"))
        payload.append({
            "paper_id": r["paper_id"],
            "doi": text(r, "doi"),
            "title": text(r, "title"),
            "abstract": text(r, "abstract"),
            "year": int(year) if year is not None else None,
            "author_keywords": text(r, "author_keywords"),
        })
    _upsert_chunked("papers", payload, "paper_id")
    return len(payload)


def get_all_papers() -> list[dict]:
    """Return every row of `papers`, ordered by paper_id."""
    return _fetch_all(
        lambda: client().table("papers").select("*").order("paper_id"))


def get_paper(paper_id: str) -> dict | None:
    """Return the paper row for `paper_id`, or None if it does not exist."""
    res = (client().table("papers")
           .select("*").eq("paper_id", paper_id).limit(1).execute())
    return (res.data or [None])[0]


def count_papers() -> int:
    """Return the exact number of rows in `papers`."""
    res = (client().table("papers")
           .select("paper_id", count="exact").limit(1).execute())
    return res.count or 0


# ---------------------------------------------------------- 2D coords

def save_2d_coords_bulk(rows: list[dict]) -> None:
    """Upsert 2-D coordinates; each row needs 'paper_id' and 'coords'.

    'coords' is stored JSON-encoded in embeddings_2d.coords_json.
    """
    if not rows:
        return
    payload = [
        {"paper_id": r["paper_id"], "coords_json": json.dumps(r["coords"])}
        for r in rows
    ]
    _upsert_chunked("embeddings_2d", payload, "paper_id")


def get_2d_coords() -> dict[str, list[float]]:
    """Return {paper_id: decoded coords} for every stored paper."""
    rows = _fetch_all(
        lambda: client().table("embeddings_2d").select("*").order("paper_id"))
    return {r["paper_id"]: json.loads(r["coords_json"]) for r in rows}


# ---------------------------------------------------------- clustering runs

def create_clustering_run(params: dict, n_clusters: int, n_noise: int,
                          notes: str = "") -> int:
    """Insert a clustering_runs row and return its generated run_id."""
    res = client().table("clustering_runs").insert({
        "params_json": json.dumps(params),
        "n_clusters": int(n_clusters),
        "n_noise": int(n_noise),
        "notes": notes,
    }).execute()
    return res.data[0]["run_id"]


def list_clustering_runs() -> list[dict]:
    """Return all clustering runs, newest (highest run_id) first."""
    rows = _fetch_all(
        lambda: (client().table("clustering_runs").select("*")
                 .order("run_id", desc=True)))
    return [
        {
            "run_id": r["run_id"],
            "params": json.loads(r["params_json"]),
            "n_clusters": r["n_clusters"],
            "n_noise": r["n_noise"],
            "created_at": r["created_at"],
            "notes": r.get("notes", "") or "",
        }
        for r in rows
    ]


def save_cluster_assignments(run_id: int, paper_ids: list[str],
                             cluster_ids: list[int],
                             probs: list[float]) -> None:
    """Upsert per-paper cluster assignments for one run.

    The three sequences are parallel (index i describes one paper).

    Raises:
        ValueError: if their lengths differ — `zip` would otherwise drop the
            surplus rows silently.
    """
    paper_ids, cluster_ids, probs = list(paper_ids), list(cluster_ids), list(probs)
    if not (len(paper_ids) == len(cluster_ids) == len(probs)):
        raise ValueError(
            "save_cluster_assignments: length mismatch "
            f"(paper_ids={len(paper_ids)}, cluster_ids={len(cluster_ids)}, "
            f"probs={len(probs)})"
        )
    payload = [
        {"run_id": int(run_id), "paper_id": pid,
         "cluster_id": int(cid), "membership_prob": float(p)}
        for pid, cid, p in zip(paper_ids, cluster_ids, probs)
    ]
    _upsert_chunked("cluster_assignments", payload, "run_id,paper_id")


def get_cluster_assignments(run_id: int) -> list[dict]:
    """Return one run's assignments joined with paper title/abstract/year.

    Sorted by cluster_id, then by descending membership probability.
    """
    rows = _fetch_all(
        lambda: (client().table("cluster_assignments")
                 .select("*, papers(title, abstract, year)")
                 .eq("run_id", int(run_id))
                 .order("paper_id")))
    out = []
    for r in rows:
        p = r.get("papers") or {}
        out.append({
            "paper_id": r["paper_id"],
            "cluster_id": r["cluster_id"],
            "membership_prob": r["membership_prob"],
            "title": p.get("title", ""),
            "abstract": p.get("abstract", ""),
            "year": p.get("year"),
        })
    out.sort(key=lambda d: (d["cluster_id"], -d["membership_prob"]))
    return out


# ---------------------------------------------------------- cluster labels

def save_cluster_label(run_id: int, cluster_id: int, label: str,
                       subject: str = "", object_: str = "",
                       phenomenon: str = "", rationale: str = "",
                       top_paper_ids: list[str] | None = None,
                       validated_by_researcher: bool = False,
                       locked: bool = False,
                       researcher_notes: str = "") -> None:
    """Insert or replace the label for (run_id, cluster_id)."""
    client().table("cluster_labels").upsert({
        "run_id": int(run_id),
        "cluster_id": int(cluster_id),
        "label": label,
        "subject": subject,
        "object_": object_,
        "phenomenon": phenomenon,
        "rationale": rationale,
        "top_paper_ids_json": json.dumps(top_paper_ids or []),
        "validated_by_researcher": bool(validated_by_researcher),
        "locked": bool(locked),
        "researcher_notes": researcher_notes,
    }, on_conflict="run_id,cluster_id").execute()


def get_cluster_labels(run_id: int) -> list[dict]:
    """Return one run's cluster labels ordered by cluster_id.

    NULL text columns come back as ""; the DB column `object_` is exposed
    under the key "object".
    """
    rows = _fetch_all(
        lambda: (client().table("cluster_labels")
                 .select("*").eq("run_id", int(run_id))
                 .order("cluster_id")))
    return [
        {
            "cluster_id": r["cluster_id"],
            "label": r.get("label") or "",
            "subject": r.get("subject") or "",
            "object": r.get("object_") or "",
            "phenomenon": r.get("phenomenon") or "",
            "rationale": r.get("rationale") or "",
            "top_paper_ids": json.loads(r.get("top_paper_ids_json") or "[]"),
            "validated_by_researcher": bool(r.get("validated_by_researcher")),
            "locked": bool(r.get("locked")),
            "researcher_notes": r.get("researcher_notes") or "",
        }
        for r in rows
    ]


# ---------------------------------------------------------- TCCM runs

def create_tccm_run(pattern_set_version: str, threshold_rule: str,
                    n_papers: int, n_include: int, n_exclude: int,
                    n_marginal: int, notes: str = "") -> int:
    """Insert a tccm_runs row and return its generated run_id."""
    res = client().table("tccm_runs").insert({
        "pattern_set_version": pattern_set_version,
        "threshold_rule": threshold_rule,
        "n_papers": int(n_papers),
        "n_include": int(n_include),
        "n_exclude": int(n_exclude),
        "n_marginal": int(n_marginal),
        "notes": notes,
    }).execute()
    return res.data[0]["run_id"]


def list_tccm_runs() -> list[dict]:
    """Return all TCCM runs as raw rows, newest (highest run_id) first."""
    return _fetch_all(
        lambda: (client().table("tccm_runs").select("*")
                 .order("run_id", desc=True)))


def save_tccm_classifications(run_id: int, rows: list[dict]) -> None:
    """Upsert per-paper TCCM verdicts and term counts for one run.

    Missing counts default to 0; a missing fired_terms_json defaults to "{}".
    """
    payload = [
        {
            "run_id": int(run_id),
            "paper_id": r["paper_id"],
            "verdict": r["verdict"],
            "n_method": int(r.get("n_method", 0)),
            "n_sample": int(r.get("n_sample", 0)),
            "n_analytic": int(r.get("n_analytic", 0)),
            "n_anti": int(r.get("n_anti", 0)),
            "fired_terms_json": r.get("fired_terms_json", "{}"),
        }
        for r in rows
    ]
    _upsert_chunked("tccm_classifications", payload, "run_id,paper_id")


def get_tccm_classifications(run_id: int,
                             verdict_filter: str | None = None) -> list[dict]:
    """Return one run's classifications joined with paper metadata.

    If `verdict_filter` is truthy, only rows with that verdict are returned.
    Rows are ordered by paper_id.
    """
    def query():
        q = (client().table("tccm_classifications")
             .select("*, papers(title, abstract, year, author_keywords)")
             .eq("run_id", int(run_id)))
        if verdict_filter:
            q = q.eq("verdict", verdict_filter)
        return q.order("paper_id")

    out = []
    for r in _fetch_all(query):
        p = r.get("papers") or {}
        out.append({
            "paper_id": r["paper_id"],
            "verdict": r["verdict"],
            "n_method": r["n_method"],
            "n_sample": r["n_sample"],
            "n_analytic": r["n_analytic"],
            "n_anti": r["n_anti"],
            "fired_terms": json.loads(r.get("fired_terms_json") or "{}"),
            "title": p.get("title", ""),
            "abstract": p.get("abstract", ""),
            "year": p.get("year"),
            "author_keywords": p.get("author_keywords") or "",
        })
    return out


# ---------------------------------------------------------- marginal reviews

def save_marginal_review(run_id: int, paper_id: str, agent_verdict: str,
                         agent_rationale: str,
                         researcher_verdict: str | None = None,
                         researcher_notes: str = "") -> None:
    """Insert or replace the review of one MARGINAL paper in a run.

    reviewed_at is set explicitly so it also refreshes on updates.
    """
    client().table("tccm_marginal_reviews").upsert({
        "run_id": int(run_id),
        "paper_id": paper_id,
        "agent_verdict": agent_verdict,
        "agent_rationale": agent_rationale,
        "researcher_verdict": researcher_verdict,
        "researcher_notes": researcher_notes,
        "reviewed_at": _utc_now_iso(),
    }, on_conflict="run_id,paper_id").execute()


def get_marginal_reviews(run_id: int) -> list[dict]:
    """Return one run's marginal reviews as raw rows, ordered by paper_id."""
    return _fetch_all(
        lambda: (client().table("tccm_marginal_reviews")
                 .select("*").eq("run_id", int(run_id))
                 .order("paper_id")))


# ---------------------------------------------------------- PDF downloads

def save_pdf_discoveries(rows: list[dict]) -> None:
    """rows: list from pdf_downloader.bulk_discover.

    Upserts discovery fields only; download-status columns of existing rows
    are left untouched. All rows share one checked_at timestamp.
    """
    now = _utc_now_iso()
    payload = [
        {
            "paper_id": r["paper_id"],
            "doi": r.get("doi") or "",
            "oa_status": r.get("oa_status"),
            "oa_type": r.get("oa_type"),
            "source": r.get("source"),
            "pdf_url": r.get("pdf_url"),
            "scihub_used": bool(r.get("scihub_used", False)),
            "error_message": r.get("error"),
            "checked_at": now,
        }
        for r in rows
    ]
    _upsert_chunked("pdf_downloads", payload, "paper_id")


def save_pdf_downloads(rows: list[dict]) -> None:
    """rows: list from pdf_downloader.bulk_download.

    Updates existing pdf_downloads records with download status, one request
    per row. Rows with no existing record are not inserted (plain UPDATE).
    """
    now = _utc_now_iso()
    for r in rows:
        update_payload = {
            "downloaded": bool(r.get("downloaded", False)),
            "file_path": r.get("file_path"),
            "file_size_bytes": int(r.get("file_size") or 0),
            "uploaded_to_supabase": bool(r.get("uploaded_to_supabase", False)),
            "supabase_storage_path": r.get("supabase_path"),
            "downloaded_at": now if r.get("downloaded") else None,
        }
        # Only override error_message if there's a new error, so a discovery
        # error is not wiped by a later status update without one.
        if r.get("error"):
            update_payload["error_message"] = r["error"]
        (client().table("pdf_downloads")
         .update(update_payload)
         .eq("paper_id", r["paper_id"]).execute())


def get_pdf_status(paper_id: str | None = None) -> list[dict]:
    """Return all pdf_downloads rows, optionally filtered by paper_id."""
    def query():
        q = client().table("pdf_downloads").select("*")
        if paper_id:
            q = q.eq("paper_id", paper_id)
        return q.order("paper_id")

    return _fetch_all(query)


def list_pdfs_with_metadata() -> list[dict]:
    """Joined view: pdf_downloads with paper title and year, by paper_id."""
    rows = _fetch_all(
        lambda: (client().table("pdf_downloads")
                 .select("*, papers(title, year)")
                 .order("paper_id")))
    out = []
    for r in rows:
        p = r.get("papers") or {}
        out.append({
            "paper_id": r["paper_id"],
            "doi": r.get("doi") or "",
            "title": p.get("title", ""),
            "year": p.get("year"),
            "oa_status": r.get("oa_status") or "",
            "oa_type": r.get("oa_type") or "",
            "source": r.get("source") or "",
            "scihub_used": bool(r.get("scihub_used")),
            "downloaded": bool(r.get("downloaded")),
            "file_size_mb": round((r.get("file_size_bytes") or 0) / (1024 ** 2), 2),
            "uploaded_to_supabase": bool(r.get("uploaded_to_supabase")),
            "error_message": r.get("error_message") or "",
        })
    return out


# ---------------------------------------------------------- Crossref metadata

def save_crossref_metadata(rows: list[dict]) -> None:
    """rows: list from pdf_downloader.bulk_crossref.

    Rows without a paper_id are skipped. fetched_at is set explicitly: the
    column default only fires on INSERT, so without it a re-fetch (upsert
    conflict -> update) would keep the original fetch time.
    """
    now = _utc_now_iso()
    payload = []
    for r in rows:
        if not r.get("paper_id"):
            continue
        ref_dois = r.get("reference_dois") or []
        payload.append({
            "paper_id": r["paper_id"],
            "doi": r.get("doi") or "",
            "crossref_title": r.get("title"),
            "crossref_abstract": r.get("abstract"),
            "citation_count": r.get("citation_count"),
            "references_count": r.get("references_count"),
            "reference_dois": json.dumps(ref_dois) if ref_dois else None,
            "publication_type": r.get("type"),
            "publisher": r.get("publisher"),
            "container_title": r.get("container_title"),
            "error_message": r.get("error"),
            "fetched_at": now,
        })
    _upsert_chunked("crossref_metadata", payload, "paper_id")


def get_crossref_metadata(paper_id: str | None = None) -> list[dict]:
    """Return crossref_metadata rows, optionally filtered by paper_id."""
    def query():
        q = client().table("crossref_metadata").select("*")
        if paper_id:
            q = q.eq("paper_id", paper_id)
        return q.order("paper_id")

    return _fetch_all(query)