Spaces:
Sleeping
Sleeping
| """ | |
| database.py | |
| Responsibility: Supabase persistence for both the embedding workbench and | |
| the TCCM screening pipeline. | |
| Connection: requires SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY environment | |
| variables. On Hugging Face Spaces these are set in Settings -> Secrets. | |
| The app fails fast with a clear error if either is missing. | |
| Schema management: Supabase intentionally restricts CREATE TABLE from the | |
| client SDK. The researcher runs the schema SQL (via `bootstrap_schema_sql()`) | |
| once in the Supabase SQL editor; subsequent app runs find tables in place. | |
| The UI surfaces missing tables explicitly via `assert_schema_present()`. | |
| Why service-role key (not anon key): | |
| Unrestricted reads/writes across all tables in this single-researcher | |
| workbench need the service-role key. The anon key is bound by Row Level | |
| Security and would require RLS policies for every table — overkill here. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import json | |
| import datetime | |
| from typing import Any | |
| from supabase import create_client, Client | |
| # ---------------------------------------------------------- env / client | |
| def _require_env() -> tuple[str, str]: | |
| url = os.environ.get("SUPABASE_URL") | |
| key = os.environ.get("SUPABASE_SERVICE_ROLE_KEY") | |
| if not url or not key: | |
| raise RuntimeError( | |
| "Missing Supabase credentials. Set SUPABASE_URL and " | |
| "SUPABASE_SERVICE_ROLE_KEY in the environment " | |
| "(HF Space Settings -> Variables and secrets -> Secrets)." | |
| ) | |
| return url, key | |
# Lazily created, process-wide Supabase client (see client()).
_client: Client | None = None


def client() -> Client:
    """Return the shared Supabase client, creating it on the first call.

    The instance is cached in the module-level ``_client`` so later calls
    reuse it.

    Raises:
        RuntimeError: from ``_require_env`` when credentials are missing
            (only on the call that would create the client).
    """
    global _client
    if _client is not None:
        return _client
    _client = create_client(*_require_env())
    return _client
| # ---------------------------------------------------------- schema | |
# One-time DDL for every table this module reads and writes. Nothing in this
# module executes it; bootstrap_schema_sql() hands it to the researcher to
# paste into the Supabase SQL editor. Every statement is
# `create table if not exists`, so re-running the whole block is harmless for
# tables that already exist -- but it will NOT alter an existing table whose
# definition has since changed (new columns must be added by hand).
# Keep the table list in assert_schema_present() in sync when adding tables.
SCHEMA_SQL = """
-- Run this entire block once in the Supabase SQL editor
-- (Project -> SQL Editor -> New query -> paste -> Run).
create table if not exists papers (
paper_id text primary key,
doi text,
title text not null,
abstract text,
year integer,
author_keywords text,
source text default 'manual_upload',
created_at timestamptz default now()
);
create table if not exists embeddings_2d (
paper_id text primary key references papers(paper_id) on delete cascade,
coords_json text not null
);
create table if not exists clustering_runs (
run_id bigserial primary key,
params_json text not null,
n_clusters integer,
n_noise integer,
notes text,
created_at timestamptz default now()
);
create table if not exists cluster_assignments (
run_id bigint not null references clustering_runs(run_id) on delete cascade,
paper_id text not null references papers(paper_id) on delete cascade,
cluster_id integer not null,
membership_prob double precision not null,
primary key (run_id, paper_id)
);
create table if not exists cluster_labels (
run_id bigint not null references clustering_runs(run_id) on delete cascade,
cluster_id integer not null,
label text,
subject text,
object_ text,
phenomenon text,
rationale text,
top_paper_ids_json text,
validated_by_researcher boolean default false,
locked boolean default false,
researcher_notes text,
primary key (run_id, cluster_id)
);
create table if not exists tccm_runs (
run_id bigserial primary key,
pattern_set_version text not null,
threshold_rule text not null default 'anti_dominates',
n_papers integer,
n_include integer,
n_exclude integer,
n_marginal integer,
notes text,
created_at timestamptz default now()
);
create table if not exists tccm_classifications (
run_id bigint not null references tccm_runs(run_id) on delete cascade,
paper_id text not null references papers(paper_id) on delete cascade,
verdict text not null check (verdict in ('INCLUDE','EXCLUDE','MARGINAL')),
n_method integer not null default 0,
n_sample integer not null default 0,
n_analytic integer not null default 0,
n_anti integer not null default 0,
fired_terms_json text,
primary key (run_id, paper_id)
);
create table if not exists tccm_marginal_reviews (
run_id bigint not null references tccm_runs(run_id) on delete cascade,
paper_id text not null references papers(paper_id) on delete cascade,
agent_verdict text,
agent_rationale text,
researcher_verdict text,
researcher_notes text,
reviewed_at timestamptz default now(),
primary key (run_id, paper_id)
);
create table if not exists pdf_downloads (
paper_id text primary key references papers(paper_id) on delete cascade,
doi text,
oa_status text,
oa_type text,
source text,
pdf_url text,
scihub_used boolean default false,
downloaded boolean default false,
file_path text,
file_size_bytes bigint,
uploaded_to_supabase boolean default false,
supabase_storage_path text,
error_message text,
checked_at timestamptz,
downloaded_at timestamptz
);
create table if not exists crossref_metadata (
paper_id text primary key references papers(paper_id) on delete cascade,
doi text,
crossref_title text,
crossref_abstract text,
citation_count integer,
references_count integer,
reference_dois text,
publication_type text,
publisher text,
container_title text,
error_message text,
fetched_at timestamptz default now()
);
"""
def bootstrap_schema_sql() -> str:
    """Give back the schema DDL so the researcher can run it by hand.

    The returned text is the module-level ``SCHEMA_SQL`` unchanged; the app
    never executes it itself.
    """
    ddl = SCHEMA_SQL
    return ddl
def assert_schema_present() -> dict[str, bool]:
    """Probe each expected table with a one-row select.

    Returns:
        ``{table_name: present}`` in a fixed order, for the UI to decide
        whether to show the bootstrap SQL.

    Notes:
        Any exception raised by a table's probe query is reported as
        ``False`` for that table, so a network or auth failure is
        indistinguishable from a missing table here. Errors from
        ``client()`` itself (e.g. missing credentials) are not caught and
        propagate to the caller.
    """
    tables = (
        "papers", "embeddings_2d", "clustering_runs",
        "cluster_assignments", "cluster_labels",
        "tccm_runs", "tccm_classifications", "tccm_marginal_reviews",
        "pdf_downloads", "crossref_metadata",
    )
    sb = client()

    def _probe(name: str) -> bool:
        try:
            sb.table(name).select("*").limit(1).execute()
        except Exception:
            return False
        return True

    return {name: _probe(name) for name in tables}
| # ---------------------------------------------------------- papers | |
| def _scrub_nan(v): | |
| """Convert any NaN-like value to None so it's JSON-encodable. | |
| Handles pandas NaN (float), numpy NaN, and plain string 'nan'.""" | |
| try: | |
| import math | |
| if v is None: | |
| return None | |
| if isinstance(v, float) and math.isnan(v): | |
| return None | |
| # pandas NaT, numpy.nan in object columns | |
| if v != v: # NaN is the only value that is not equal to itself | |
| return None | |
| except Exception: | |
| pass | |
| return v | |
def upsert_papers(rows: list[dict]) -> int:
    """Insert or update paper rows keyed on ``paper_id``.

    Each input dict must contain ``paper_id``; ``doi``, ``title``,
    ``abstract``, ``year`` and ``author_keywords`` are optional. Values are
    passed through ``_scrub_nan``; missing text fields are stored as ``""``
    and a missing year as NULL.

    Rows that share a ``paper_id`` are collapsed before sending, keeping the
    last occurrence. Postgres rejects an ``INSERT ... ON CONFLICT DO UPDATE``
    batch that would touch the same row twice, so duplicates would otherwise
    make the whole chunk fail.

    Returns:
        Number of distinct papers sent to Supabase (0 for empty input).
    """
    if not rows:
        return 0

    def _text(row: dict, field: str) -> str:
        value = _scrub_nan(row.get(field))
        return value if value is not None else ""

    # dict keyed on paper_id: a later duplicate overwrites an earlier one.
    by_id: dict = {}
    for r in rows:
        year = _scrub_nan(r.get("year"))
        by_id[r["paper_id"]] = {
            "paper_id": r["paper_id"],
            "doi": _text(r, "doi"),
            "title": _text(r, "title"),
            "abstract": _text(r, "abstract"),
            "year": int(year) if year is not None else None,
            "author_keywords": _text(r, "author_keywords"),
        }
    payload = list(by_id.values())

    chunk = 500
    for i in range(0, len(payload), chunk):
        (client().table("papers")
         .upsert(payload[i:i + chunk], on_conflict="paper_id").execute())
    return len(payload)
def get_all_papers() -> list[dict]:
    """Fetch every row of ``papers`` ordered by ``paper_id``.

    Reads in ranged pages of 1000 and stops at the first short page, so the
    result is not limited by the per-request row cap of the API.
    """
    papers: list[dict] = []
    step = 1000
    start = 0
    while True:
        query = client().table("papers").select("*").order("paper_id")
        page = query.range(start, start + step - 1).execute().data or []
        papers.extend(page)
        if len(page) < step:
            return papers
        start += step
def get_paper(paper_id: str) -> dict | None:
    """Return the ``papers`` row for ``paper_id``, or ``None`` if absent."""
    response = (client().table("papers")
                .select("*")
                .eq("paper_id", paper_id)
                .limit(1)
                .execute())
    matches = response.data
    return matches[0] if matches else None
def count_papers() -> int:
    """Return the exact number of rows in ``papers`` (0 if unknown)."""
    response = (client().table("papers")
                .select("paper_id", count="exact")
                .limit(1)
                .execute())
    total = response.count
    return total if total else 0
| # ---------------------------------------------------------- 2D coords | |
def save_2d_coords_bulk(rows: list[dict]) -> None:
    """Upsert 2-D projection coordinates keyed on ``paper_id``.

    Each row needs ``paper_id`` and ``coords``; ``coords`` is serialized
    with ``json.dumps`` into ``coords_json``. Writes go out in batches of
    500. An empty ``rows`` is a no-op.
    """
    if not rows:
        return
    records = []
    for row in rows:
        records.append({
            "paper_id": row["paper_id"],
            "coords_json": json.dumps(row["coords"]),
        })
    batch_size = 500
    for start in range(0, len(records), batch_size):
        batch = records[start:start + batch_size]
        client().table("embeddings_2d").upsert(
            batch, on_conflict="paper_id").execute()
def get_2d_coords() -> dict[str, list[float]]:
    """Return ``{paper_id: coords}`` for every row of ``embeddings_2d``.

    ``coords`` is the JSON-decoded ``coords_json`` column.

    Rows are read in ranged pages of 1000, ordered by ``paper_id``. A single
    un-ranged select is truncated at the project's API "Max rows" limit
    (1000 by default on Supabase), which silently dropped coordinates for
    larger corpora.
    """
    coords: dict[str, list[float]] = {}
    page_size = 1000
    offset = 0
    while True:
        res = (client().table("embeddings_2d").select("*")
               .order("paper_id")  # stable order so pages neither overlap nor skip
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        for r in batch:
            coords[r["paper_id"]] = json.loads(r["coords_json"])
        if len(batch) < page_size:
            break
        offset += page_size
    return coords
| # ---------------------------------------------------------- clustering runs | |
def create_clustering_run(params: dict, n_clusters: int,
                          n_noise: int, notes: str = "") -> int:
    """Record a clustering run and return its generated ``run_id``.

    ``params`` is stored JSON-encoded; the counts are coerced to ``int``.
    """
    runs_table = client().table("clustering_runs")
    record = {
        "params_json": json.dumps(params),
        "n_clusters": int(n_clusters),
        "n_noise": int(n_noise),
        "notes": notes,
    }
    inserted = runs_table.insert(record).execute()
    first_row = inserted.data[0]
    return first_row["run_id"]
def list_clustering_runs() -> list[dict]:
    """Return all clustering runs, newest ``run_id`` first.

    ``params_json`` is decoded into ``params``; a NULL ``notes`` becomes
    ``""``.
    """
    response = (client().table("clustering_runs")
                .select("*")
                .order("run_id", desc=True)
                .execute())
    return [
        {
            "run_id": row["run_id"],
            "params": json.loads(row["params_json"]),
            "n_clusters": row["n_clusters"],
            "n_noise": row["n_noise"],
            "created_at": row["created_at"],
            "notes": row.get("notes", "") or "",
        }
        for row in response.data or []
    ]
def save_cluster_assignments(run_id: int, paper_ids: list[str],
                             cluster_ids: list[int], probs: list[float]) -> None:
    """Upsert per-paper cluster assignments for one clustering run.

    The three sequences are parallel: element ``i`` of each describes the
    same paper. Rows are upserted on ``(run_id, paper_id)`` in batches of 500.

    Raises:
        ValueError: if the sequences differ in length. Previously ``zip``
            silently truncated to the shortest one and dropped assignments.
    """
    n = len(paper_ids)
    if len(cluster_ids) != n or len(probs) != n:
        raise ValueError(
            "paper_ids, cluster_ids and probs must have equal length "
            f"(got {n}, {len(cluster_ids)}, {len(probs)})"
        )
    payload = [
        {"run_id": int(run_id), "paper_id": pid,
         "cluster_id": int(cid), "membership_prob": float(p)}
        for pid, cid, p in zip(paper_ids, cluster_ids, probs)
    ]
    chunk = 500
    for i in range(0, len(payload), chunk):
        (client().table("cluster_assignments")
         .upsert(payload[i:i + chunk], on_conflict="run_id,paper_id").execute())
def get_cluster_assignments(run_id: int) -> list[dict]:
    """Return every cluster assignment of one clustering run.

    Each item carries ``paper_id``, ``cluster_id``, ``membership_prob`` and
    the joined paper ``title`` / ``abstract`` / ``year``. The list is sorted
    by ``cluster_id``, then by descending ``membership_prob``.

    Rows are read in ranged pages of 1000. A single un-ranged select is
    truncated at the project's API "Max rows" limit (1000 by default on
    Supabase), which silently dropped assignments for larger corpora.
    """
    page_size = 1000
    offset = 0
    raw: list[dict] = []
    while True:
        res = (client().table("cluster_assignments")
               .select("*, papers(title, abstract, year)")
               .eq("run_id", int(run_id))
               .order("paper_id")  # stable order so pages neither overlap nor skip
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        raw.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size

    out = []
    for r in raw:
        p = r.get("papers") or {}
        out.append({
            "paper_id": r["paper_id"],
            "cluster_id": r["cluster_id"],
            "membership_prob": r["membership_prob"],
            "title": p.get("title", ""),
            "abstract": p.get("abstract", ""),
            "year": p.get("year"),
        })
    out.sort(key=lambda d: (d["cluster_id"], -d["membership_prob"]))
    return out
| # ---------------------------------------------------------- cluster labels | |
def save_cluster_label(run_id: int, cluster_id: int, label: str,
                       subject: str = "", object_: str = "",
                       phenomenon: str = "", rationale: str = "",
                       top_paper_ids: list[str] | None = None,
                       validated_by_researcher: bool = False,
                       locked: bool = False,
                       researcher_notes: str = "") -> None:
    """Insert or replace the label of one cluster in a clustering run.

    Upserts on ``(run_id, cluster_id)``. ``top_paper_ids`` is stored as a
    JSON array string (``"[]"`` when not given); the two flags are coerced
    to ``bool``.
    """
    labels_table = client().table("cluster_labels")
    record = {
        "run_id": int(run_id),
        "cluster_id": int(cluster_id),
        "label": label,
        "subject": subject,
        "object_": object_,
        "phenomenon": phenomenon,
        "rationale": rationale,
        "top_paper_ids_json": json.dumps(top_paper_ids or []),
        "validated_by_researcher": bool(validated_by_researcher),
        "locked": bool(locked),
        "researcher_notes": researcher_notes,
    }
    labels_table.upsert(record, on_conflict="run_id,cluster_id").execute()
def get_cluster_labels(run_id: int) -> list[dict]:
    """Return the labels of one clustering run, ordered by ``cluster_id``.

    NULL text columns become ``""``, ``top_paper_ids_json`` is decoded into
    ``top_paper_ids``, and the DB column ``object_`` is exposed as
    ``object``.
    """
    response = (client().table("cluster_labels")
                .select("*")
                .eq("run_id", int(run_id))
                .order("cluster_id")
                .execute())

    def _shape(row: dict) -> dict:
        def text(key: str) -> str:
            return row.get(key) or ""

        return {
            "cluster_id": row["cluster_id"],
            "label": text("label"),
            "subject": text("subject"),
            "object": text("object_"),
            "phenomenon": text("phenomenon"),
            "rationale": text("rationale"),
            "top_paper_ids": json.loads(row.get("top_paper_ids_json") or "[]"),
            "validated_by_researcher": bool(row.get("validated_by_researcher")),
            "locked": bool(row.get("locked")),
            "researcher_notes": text("researcher_notes"),
        }

    return [_shape(row) for row in response.data or []]
| # ---------------------------------------------------------- TCCM runs | |
def create_tccm_run(pattern_set_version: str, threshold_rule: str,
                    n_papers: int, n_include: int, n_exclude: int,
                    n_marginal: int, notes: str = "") -> int:
    """Record a TCCM screening run and return its generated ``run_id``.

    All counts are coerced to ``int`` before insert.
    """
    runs_table = client().table("tccm_runs")
    record = {
        "pattern_set_version": pattern_set_version,
        "threshold_rule": threshold_rule,
        "n_papers": int(n_papers),
        "n_include": int(n_include),
        "n_exclude": int(n_exclude),
        "n_marginal": int(n_marginal),
        "notes": notes,
    }
    inserted = runs_table.insert(record).execute()
    first_row = inserted.data[0]
    return first_row["run_id"]
def list_tccm_runs() -> list[dict]:
    """Return the raw ``tccm_runs`` rows, newest ``run_id`` first."""
    response = (client().table("tccm_runs")
                .select("*")
                .order("run_id", desc=True)
                .execute())
    rows = response.data
    return rows if rows else []
def save_tccm_classifications(run_id: int, rows: list[dict]) -> None:
    """Upsert per-paper TCCM verdicts for one screening run.

    Each row needs ``paper_id`` and ``verdict``; the four hit counters
    default to 0 and ``fired_terms_json`` to ``"{}"``. Rows are upserted on
    ``(run_id, paper_id)`` in batches of 500. The whole payload is built
    before any request, so a malformed row raises before anything is sent.
    """
    rid = int(run_id)

    def _record(row: dict) -> dict:
        return {
            "run_id": rid,
            "paper_id": row["paper_id"],
            "verdict": row["verdict"],
            "n_method": int(row.get("n_method", 0)),
            "n_sample": int(row.get("n_sample", 0)),
            "n_analytic": int(row.get("n_analytic", 0)),
            "n_anti": int(row.get("n_anti", 0)),
            "fired_terms_json": row.get("fired_terms_json", "{}"),
        }

    records = [_record(row) for row in rows]
    batch_size = 500
    for start in range(0, len(records), batch_size):
        client().table("tccm_classifications").upsert(
            records[start:start + batch_size],
            on_conflict="run_id,paper_id",
        ).execute()
def get_tccm_classifications(run_id: int,
                             verdict_filter: str | None = None) -> list[dict]:
    """Return the TCCM verdicts of one run, joined with paper metadata.

    Args:
        run_id: the ``tccm_runs`` id.
        verdict_filter: if truthy, only rows with this ``verdict`` are
            returned.

    Returns:
        One dict per classified paper, ordered by ``paper_id``, with the
        counters, decoded ``fired_terms`` and the joined ``title`` /
        ``abstract`` / ``year`` / ``author_keywords``.

    Rows are read in ranged pages of 1000. A single un-ranged select is
    truncated at the project's API "Max rows" limit (1000 by default on
    Supabase), which silently dropped papers from larger screening runs.
    """
    page_size = 1000
    offset = 0
    raw: list[dict] = []
    while True:
        # Build a fresh query per page: the builder's filter/range methods
        # modify it in place, so one builder must not be reused across pages.
        q = (client().table("tccm_classifications")
             .select("*, papers(title, abstract, year, author_keywords)")
             .eq("run_id", int(run_id)))
        if verdict_filter:
            q = q.eq("verdict", verdict_filter)
        res = (q.order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        raw.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size

    out = []
    for r in raw:
        p = r.get("papers") or {}
        out.append({
            "paper_id": r["paper_id"],
            "verdict": r["verdict"],
            "n_method": r["n_method"],
            "n_sample": r["n_sample"],
            "n_analytic": r["n_analytic"],
            "n_anti": r["n_anti"],
            "fired_terms": json.loads(r.get("fired_terms_json") or "{}"),
            "title": p.get("title", ""),
            "abstract": p.get("abstract", ""),
            "year": p.get("year"),
            "author_keywords": p.get("author_keywords") or "",
        })
    return out
| # ---------------------------------------------------------- marginal reviews | |
def save_marginal_review(run_id: int, paper_id: str,
                         agent_verdict: str, agent_rationale: str,
                         researcher_verdict: str | None = None,
                         researcher_notes: str = "") -> None:
    """Insert or replace the review of one MARGINAL paper in a TCCM run.

    Upserts on ``(run_id, paper_id)`` and stamps ``reviewed_at`` with the
    current time.

    The timestamp is timezone-aware UTC (ISO 8601 with ``+00:00``).
    ``datetime.utcnow()`` is deprecated since Python 3.12, and its naive
    value left the ``timestamptz`` column to interpret it in the database
    session's time zone.
    """
    client().table("tccm_marginal_reviews").upsert({
        "run_id": int(run_id),
        "paper_id": paper_id,
        "agent_verdict": agent_verdict,
        "agent_rationale": agent_rationale,
        "researcher_verdict": researcher_verdict,
        "researcher_notes": researcher_notes,
        "reviewed_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
    }, on_conflict="run_id,paper_id").execute()
def get_marginal_reviews(run_id: int) -> list[dict]:
    """Return the raw ``tccm_marginal_reviews`` rows of one TCCM run.

    Rows are ordered by ``paper_id`` and read in ranged pages of 1000, so
    the result is not truncated by the project's API "Max rows" limit
    (1000 by default on Supabase).
    """
    page_size = 1000
    offset = 0
    out: list[dict] = []
    while True:
        res = (client().table("tccm_marginal_reviews")
               .select("*").eq("run_id", int(run_id))
               .order("paper_id")  # stable order so pages neither overlap nor skip
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        out.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size
    return out
| # ---------------------------------------------------------- PDF downloads | |
def save_pdf_discoveries(rows: list[dict]) -> None:
    """Upsert open-access discovery results into ``pdf_downloads``.

    ``rows`` come from ``pdf_downloader.bulk_discover``; each needs
    ``paper_id``. The row's ``error`` key is stored as ``error_message``.
    All rows share one ``checked_at`` stamp. Writes go out in batches of 500
    keyed on ``paper_id``; only the columns listed here are written, so
    download-status columns of existing rows are left as they were.

    ``checked_at`` is timezone-aware UTC (ISO 8601 with ``+00:00``) rather
    than the deprecated, naive ``datetime.utcnow()``.
    """
    now = datetime.datetime.now(datetime.timezone.utc).isoformat()
    payload = []
    for r in rows:
        payload.append({
            "paper_id": r["paper_id"],
            "doi": r.get("doi") or "",
            "oa_status": r.get("oa_status"),
            "oa_type": r.get("oa_type"),
            "source": r.get("source"),
            "pdf_url": r.get("pdf_url"),
            "scihub_used": bool(r.get("scihub_used", False)),
            "error_message": r.get("error"),
            "checked_at": now,
        })
    chunk = 500
    for i in range(0, len(payload), chunk):
        (client().table("pdf_downloads")
         .upsert(payload[i:i + chunk], on_conflict="paper_id").execute())
def save_pdf_downloads(rows: list[dict]) -> None:
    """Record download results on existing ``pdf_downloads`` rows.

    ``rows`` come from ``pdf_downloader.bulk_download``. Each row triggers
    one UPDATE filtered on ``paper_id``; a paper without an existing
    ``pdf_downloads`` row is therefore not inserted. ``error_message`` is
    overwritten only when the row reports a new ``error``.
    ``downloaded_at`` is set to the shared current time for successful
    downloads and cleared (NULL) otherwise.

    Timestamps are timezone-aware UTC (ISO 8601 with ``+00:00``) rather than
    the deprecated, naive ``datetime.utcnow()``.
    """
    now = datetime.datetime.now(datetime.timezone.utc).isoformat()
    for r in rows:
        update_payload = {
            "downloaded": bool(r.get("downloaded", False)),
            "file_path": r.get("file_path"),
            "file_size_bytes": int(r.get("file_size") or 0),
            "uploaded_to_supabase": bool(r.get("uploaded_to_supabase", False)),
            "supabase_storage_path": r.get("supabase_path"),
            "downloaded_at": now if r.get("downloaded") else None,
        }
        # Only override error_message if there's a new error, so an earlier
        # discovery-stage error is not wiped by a clean download attempt.
        if r.get("error"):
            update_payload["error_message"] = r["error"]
        (client().table("pdf_downloads")
         .update(update_payload)
         .eq("paper_id", r["paper_id"]).execute())
def get_pdf_status(paper_id: str | None = None) -> list[dict]:
    """Return raw ``pdf_downloads`` rows, optionally for one ``paper_id``.

    Rows are ordered by ``paper_id`` and read in ranged pages of 1000, so
    the unfiltered listing is not truncated by the project's API
    "Max rows" limit (1000 by default on Supabase).
    """
    page_size = 1000
    offset = 0
    out: list[dict] = []
    while True:
        # Fresh builder per page: filter/range calls modify it in place.
        q = client().table("pdf_downloads").select("*")
        if paper_id:
            q = q.eq("paper_id", paper_id)
        res = (q.order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        out.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size
    return out
def list_pdfs_with_metadata() -> list[dict]:
    """Return ``pdf_downloads`` rows joined with paper title and year.

    Output is ordered by ``paper_id``; NULL text columns become ``""`` and
    ``file_size_bytes`` is reported as ``file_size_mb`` rounded to 2
    decimals. Rows are read in ranged pages of 1000, so the listing is not
    truncated by the project's API "Max rows" limit (1000 by default on
    Supabase).
    """
    page_size = 1000
    offset = 0
    raw: list[dict] = []
    while True:
        res = (client().table("pdf_downloads")
               .select("*, papers(title, year)")
               .order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        raw.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size

    out = []
    for r in raw:
        p = r.get("papers") or {}
        out.append({
            "paper_id": r["paper_id"],
            "doi": r.get("doi") or "",
            "title": p.get("title", ""),
            "year": p.get("year"),
            "oa_status": r.get("oa_status") or "",
            "oa_type": r.get("oa_type") or "",
            "source": r.get("source") or "",
            "scihub_used": bool(r.get("scihub_used")),
            "downloaded": bool(r.get("downloaded")),
            "file_size_mb": round((r.get("file_size_bytes") or 0) / (1024 ** 2), 2),
            "uploaded_to_supabase": bool(r.get("uploaded_to_supabase")),
            "error_message": r.get("error_message") or "",
        })
    return out
| # ---------------------------------------------------------- Crossref metadata | |
def save_crossref_metadata(rows: list[dict]) -> None:
    """Upsert Crossref lookup results into ``crossref_metadata``.

    ``rows`` come from ``pdf_downloader.bulk_crossref``. Rows without a
    ``paper_id`` are skipped. ``reference_dois`` is stored as a JSON array
    string, or NULL when empty. Writes go out in batches of 500 keyed on
    ``paper_id``.

    ``fetched_at`` is set explicitly to the current UTC time. The column's
    ``default now()`` applies only on INSERT, so a re-fetch that hits an
    existing row through the upsert's conflict path would otherwise keep the
    original, stale timestamp.
    """
    now = datetime.datetime.now(datetime.timezone.utc).isoformat()
    payload = []
    for r in rows:
        if not r.get("paper_id"):
            continue
        ref_dois = r.get("reference_dois") or []
        payload.append({
            "paper_id": r["paper_id"],
            "doi": r.get("doi") or "",
            "crossref_title": r.get("title"),
            "crossref_abstract": r.get("abstract"),
            "citation_count": r.get("citation_count"),
            "references_count": r.get("references_count"),
            "reference_dois": json.dumps(ref_dois) if ref_dois else None,
            "publication_type": r.get("type"),
            "publisher": r.get("publisher"),
            "container_title": r.get("container_title"),
            "error_message": r.get("error"),
            "fetched_at": now,
        })
    chunk = 500
    for i in range(0, len(payload), chunk):
        (client().table("crossref_metadata")
         .upsert(payload[i:i + chunk], on_conflict="paper_id").execute())
def get_crossref_metadata(paper_id: str | None = None) -> list[dict]:
    """Return raw ``crossref_metadata`` rows, optionally for one ``paper_id``.

    Rows are ordered by ``paper_id`` and read in ranged pages of 1000, so
    the unfiltered listing is not truncated by the project's API
    "Max rows" limit (1000 by default on Supabase).
    """
    page_size = 1000
    offset = 0
    out: list[dict] = []
    while True:
        # Fresh builder per page: filter/range calls modify it in place.
        q = client().table("crossref_metadata").select("*")
        if paper_id:
            q = q.eq("paper_id", paper_id)
        res = (q.order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        out.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size
    return out