# TOPICMODELLING / database.py
# Milind Kamat
# Add PDF Downloader workbench
# 614dd95
"""
database.py
Responsibility: Supabase persistence for both the embedding workbench and
the TCCM screening pipeline.
Connection: requires SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY environment
variables. On Hugging Face Spaces these are set in Settings -> Secrets.
The app fails fast with a clear error if either is missing.
Schema management: Supabase intentionally restricts CREATE TABLE from the
client SDK. The researcher runs the schema SQL (via `bootstrap_schema_sql()`)
once in the Supabase SQL editor; subsequent app runs find tables in place.
The UI surfaces missing tables explicitly via `assert_schema_present()`.
Why service-role key (not anon key):
Unrestricted reads/writes across all tables in this single-researcher
workbench need the service-role key. The anon key is bound by Row Level
Security and would require RLS policies for every table — overkill here.
"""
from __future__ import annotations
import os
import json
import datetime
from typing import Any
from supabase import create_client, Client
# ---------------------------------------------------------- env / client
def _require_env() -> tuple[str, str]:
    """Read the Supabase connection settings from the environment.

    Returns:
        ``(url, key)`` taken from ``SUPABASE_URL`` and
        ``SUPABASE_SERVICE_ROLE_KEY``.

    Raises:
        RuntimeError: if either variable is unset or empty.
    """
    supabase_url = os.environ.get("SUPABASE_URL")
    service_key = os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
    if supabase_url and service_key:
        return supabase_url, service_key
    raise RuntimeError(
        "Missing Supabase credentials. Set SUPABASE_URL and "
        "SUPABASE_SERVICE_ROLE_KEY in the environment "
        "(HF Space Settings -> Variables and secrets -> Secrets)."
    )
# Module-wide Supabase client; stays None until client() is first called.
_client: Client | None = None


def client() -> Client:
    """Return the shared Supabase client, creating it on first use.

    The credentials are read via ``_require_env()`` only when no client has
    been created yet, so a missing variable surfaces as ``RuntimeError`` on
    the first call.
    """
    global _client
    if _client is not None:
        return _client
    url, key = _require_env()
    _client = create_client(url, key)
    return _client
# ---------------------------------------------------------- schema
# DDL for every table this module reads or writes. It is not executed by the
# app: bootstrap_schema_sql() returns it so it can be pasted into the Supabase
# SQL editor. Every statement is `create table if not exists`. The table names
# here are the same ones probed by assert_schema_present().
SCHEMA_SQL = """
-- Run this entire block once in the Supabase SQL editor
-- (Project -> SQL Editor -> New query -> paste -> Run).
create table if not exists papers (
paper_id text primary key,
doi text,
title text not null,
abstract text,
year integer,
author_keywords text,
source text default 'manual_upload',
created_at timestamptz default now()
);
create table if not exists embeddings_2d (
paper_id text primary key references papers(paper_id) on delete cascade,
coords_json text not null
);
create table if not exists clustering_runs (
run_id bigserial primary key,
params_json text not null,
n_clusters integer,
n_noise integer,
notes text,
created_at timestamptz default now()
);
create table if not exists cluster_assignments (
run_id bigint not null references clustering_runs(run_id) on delete cascade,
paper_id text not null references papers(paper_id) on delete cascade,
cluster_id integer not null,
membership_prob double precision not null,
primary key (run_id, paper_id)
);
create table if not exists cluster_labels (
run_id bigint not null references clustering_runs(run_id) on delete cascade,
cluster_id integer not null,
label text,
subject text,
object_ text,
phenomenon text,
rationale text,
top_paper_ids_json text,
validated_by_researcher boolean default false,
locked boolean default false,
researcher_notes text,
primary key (run_id, cluster_id)
);
create table if not exists tccm_runs (
run_id bigserial primary key,
pattern_set_version text not null,
threshold_rule text not null default 'anti_dominates',
n_papers integer,
n_include integer,
n_exclude integer,
n_marginal integer,
notes text,
created_at timestamptz default now()
);
create table if not exists tccm_classifications (
run_id bigint not null references tccm_runs(run_id) on delete cascade,
paper_id text not null references papers(paper_id) on delete cascade,
verdict text not null check (verdict in ('INCLUDE','EXCLUDE','MARGINAL')),
n_method integer not null default 0,
n_sample integer not null default 0,
n_analytic integer not null default 0,
n_anti integer not null default 0,
fired_terms_json text,
primary key (run_id, paper_id)
);
create table if not exists tccm_marginal_reviews (
run_id bigint not null references tccm_runs(run_id) on delete cascade,
paper_id text not null references papers(paper_id) on delete cascade,
agent_verdict text,
agent_rationale text,
researcher_verdict text,
researcher_notes text,
reviewed_at timestamptz default now(),
primary key (run_id, paper_id)
);
create table if not exists pdf_downloads (
paper_id text primary key references papers(paper_id) on delete cascade,
doi text,
oa_status text,
oa_type text,
source text,
pdf_url text,
scihub_used boolean default false,
downloaded boolean default false,
file_path text,
file_size_bytes bigint,
uploaded_to_supabase boolean default false,
supabase_storage_path text,
error_message text,
checked_at timestamptz,
downloaded_at timestamptz
);
create table if not exists crossref_metadata (
paper_id text primary key references papers(paper_id) on delete cascade,
doi text,
crossref_title text,
crossref_abstract text,
citation_count integer,
references_count integer,
reference_dois text,
publication_type text,
publisher text,
container_title text,
error_message text,
fetched_at timestamptz default now()
);
"""
def bootstrap_schema_sql() -> str:
    """Return the schema DDL (``SCHEMA_SQL``) unchanged.

    The text is meant to be run by hand, once, in the Supabase SQL editor;
    this function does not execute anything.
    """
    schema_text: str = SCHEMA_SQL
    return schema_text
def assert_schema_present() -> dict[str, bool]:
    """Probe each expected table with a one-row read and report the outcome.

    Returns:
        ``{table_name: present}`` for every expected table. A table is marked
        present when the probe query executes and absent when it raises.

    Failures of the individual probes are not propagated, so the UI can show
    the bootstrap SQL instead. Obtaining the client happens before the probes
    and is not guarded.
    """
    table_names = (
        "papers", "embeddings_2d", "clustering_runs",
        "cluster_assignments", "cluster_labels",
        "tccm_runs", "tccm_classifications", "tccm_marginal_reviews",
        "pdf_downloads", "crossref_metadata",
    )
    supabase = client()
    presence: dict[str, bool] = {}
    for name in table_names:
        try:
            supabase.table(name).select("*").limit(1).execute()
        except Exception:
            # Any probe failure is reported as "table not present".
            presence[name] = False
        else:
            presence[name] = True
    return presence
# ---------------------------------------------------------- papers
def _scrub_nan(v):
    """Convert any NaN-like value to None so it's JSON-encodable.

    Handled inputs:
      * ``None`` -> ``None``
      * float NaN (including float subclasses) -> ``None``
      * the string ``'nan'`` (any case, surrounding whitespace ignored)
        -> ``None``; other strings are returned untouched
      * any other object that is not equal to itself -> ``None``

    Everything else is returned unchanged. Objects whose self-comparison
    cannot be reduced to a plain boolean are also returned unchanged.
    """
    import math

    if v is None:
        return None
    if isinstance(v, str):
        # A NaN that was stringified upstream arrives as the text 'nan';
        # without this branch it would later reach int() and raise.
        return None if v.strip().lower() == "nan" else v
    if isinstance(v, float):
        return None if math.isnan(v) else v
    try:
        # NaN-like objects are the ones that are not equal to themselves.
        if v != v:
            return None
    except Exception:
        # The comparison or its truth test failed; leave the value as is.
        return v
    return v
def upsert_papers(rows: list[dict]) -> int:
    """Insert or update rows of the ``papers`` table, keyed on ``paper_id``.

    Each input dict must contain ``paper_id``; ``doi``, ``title``,
    ``abstract``, ``author_keywords`` and ``year`` are optional. Values are
    passed through ``_scrub_nan``; a missing text field is stored as ``""``
    and a missing year as ``None`` (otherwise it is converted with ``int``).

    Returns:
        The number of records sent (0 for an empty input, in which case no
        request is made).
    """
    if not rows:
        return 0

    def _text(row: dict, field: str):
        cleaned = _scrub_nan(row.get(field))
        return "" if cleaned is None else cleaned

    records = []
    for row in rows:
        year_value = _scrub_nan(row.get("year"))
        records.append({
            "paper_id": row["paper_id"],
            "doi": _text(row, "doi"),
            "title": _text(row, "title"),
            "abstract": _text(row, "abstract"),
            "year": None if year_value is None else int(year_value),
            "author_keywords": _text(row, "author_keywords"),
        })

    # Records are sent in batches of 500 per request.
    batch_size = 500
    for start in range(0, len(records), batch_size):
        batch = records[start:start + batch_size]
        client().table("papers").upsert(batch, on_conflict="paper_id").execute()
    return len(records)
def get_all_papers() -> list[dict]:
    """Return every row of ``papers``, ordered by ``paper_id``.

    Rows are fetched in pages of 1000 using ``range``; reading stops at the
    first page that comes back shorter than the page size.
    """
    page_size = 1000
    papers: list[dict] = []
    start = 0
    while True:
        query = client().table("papers").select("*").order("paper_id")
        page = query.range(start, start + page_size - 1).execute().data or []
        papers.extend(page)
        if len(page) < page_size:
            return papers
        start += page_size
def get_paper(paper_id: str) -> dict | None:
    """Return the ``papers`` row whose ``paper_id`` matches, or ``None``."""
    query = client().table("papers").select("*").eq("paper_id", paper_id).limit(1)
    matches = query.execute().data
    if not matches:
        return None
    return matches[0]
def count_papers() -> int:
    """Return the exact row count of ``papers`` (0 when no count is reported)."""
    query = client().table("papers").select("paper_id", count="exact").limit(1)
    total = query.execute().count
    return total if total else 0
# ---------------------------------------------------------- 2D coords
def save_2d_coords_bulk(rows: list[dict]) -> None:
    """Upsert 2D coordinates into ``embeddings_2d``, keyed on ``paper_id``.

    Each input dict must carry ``paper_id`` and ``coords``; ``coords`` is
    stored JSON-encoded in ``coords_json``. An empty input makes no request.
    """
    if not rows:
        return
    records = []
    for row in rows:
        records.append({
            "paper_id": row["paper_id"],
            "coords_json": json.dumps(row["coords"]),
        })
    # Records are sent in batches of 500 per request.
    batch_size = 500
    for start in range(0, len(records), batch_size):
        batch = records[start:start + batch_size]
        client().table("embeddings_2d").upsert(batch, on_conflict="paper_id").execute()
def get_2d_coords() -> dict[str, list[float]]:
    """Return ``{paper_id: coords}`` for every row of ``embeddings_2d``.

    ``coords`` is the JSON-decoded ``coords_json`` column.

    Rows are read in pages of 1000 ordered by ``paper_id``, the same way
    ``get_all_papers`` does, because a single unbounded select is capped by
    the API's maximum-rows setting and would silently drop the remainder.
    """
    page_size = 1000
    offset = 0
    coords: dict[str, list[float]] = {}
    while True:
        res = (client().table("embeddings_2d").select("*")
               .order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        for row in batch:
            coords[row["paper_id"]] = json.loads(row["coords_json"])
        if len(batch) < page_size:
            break
        offset += page_size
    return coords
# ---------------------------------------------------------- clustering runs
def create_clustering_run(params: dict, n_clusters: int,
                          n_noise: int, notes: str = "") -> int:
    """Insert one row into ``clustering_runs`` and return its ``run_id``.

    ``params`` is stored JSON-encoded in ``params_json``; the two counts are
    converted with ``int`` before being sent.
    """
    record = {
        "params_json": json.dumps(params),
        "n_clusters": int(n_clusters),
        "n_noise": int(n_noise),
        "notes": notes,
    }
    inserted = client().table("clustering_runs").insert(record).execute()
    new_row = inserted.data[0]
    return new_row["run_id"]
def list_clustering_runs() -> list[dict]:
    """Return all clustering runs, newest ``run_id`` first.

    Each entry has ``run_id``, ``params`` (decoded from ``params_json``),
    ``n_clusters``, ``n_noise``, ``created_at`` and ``notes`` (``""`` when
    absent or empty).
    """
    query = client().table("clustering_runs").select("*").order("run_id", desc=True)
    stored_runs = query.execute().data or []
    return [
        {
            "run_id": run["run_id"],
            "params": json.loads(run["params_json"]),
            "n_clusters": run["n_clusters"],
            "n_noise": run["n_noise"],
            "created_at": run["created_at"],
            "notes": run.get("notes", "") or "",
        }
        for run in stored_runs
    ]
def save_cluster_assignments(run_id: int, paper_ids: list[str],
                             cluster_ids: list[int], probs: list[float]) -> None:
    """Upsert per-paper cluster assignments for one clustering run.

    The three sequences are parallel: element *i* of each describes the same
    paper. Rows are keyed on ``(run_id, paper_id)`` and sent in batches of
    500. Empty sequences make no request.

    Raises:
        ValueError: if the three sequences differ in length. Pairing them
            anyway would silently drop the trailing entries.
    """
    paper_ids = list(paper_ids)
    cluster_ids = list(cluster_ids)
    probs = list(probs)
    if not (len(paper_ids) == len(cluster_ids) == len(probs)):
        raise ValueError(
            "paper_ids, cluster_ids and probs must have the same length "
            f"(got {len(paper_ids)}, {len(cluster_ids)}, {len(probs)})"
        )
    payload = [
        {"run_id": int(run_id), "paper_id": pid,
         "cluster_id": int(cid), "membership_prob": float(p)}
        for pid, cid, p in zip(paper_ids, cluster_ids, probs)
    ]
    chunk = 500
    for i in range(0, len(payload), chunk):
        (client().table("cluster_assignments")
         .upsert(payload[i:i + chunk], on_conflict="run_id,paper_id").execute())
def get_cluster_assignments(run_id: int) -> list[dict]:
    """Return the assignments of one clustering run joined with paper fields.

    Each entry has ``paper_id``, ``cluster_id``, ``membership_prob`` and the
    joined ``title``, ``abstract`` and ``year`` of the paper. The result is
    sorted by ``cluster_id`` ascending, then ``membership_prob`` descending.

    Rows are read in pages of 1000 ordered by ``paper_id``, because a single
    unbounded select is capped by the API's maximum-rows setting and would
    silently return only part of a large run.
    """
    page_size = 1000
    offset = 0
    out: list[dict] = []
    while True:
        res = (client().table("cluster_assignments")
               .select("*, papers(title, abstract, year)")
               .eq("run_id", int(run_id))
               .order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        for r in batch:
            # The embedded paper may be missing; fall back to empty fields.
            p = r.get("papers") or {}
            out.append({
                "paper_id": r["paper_id"],
                "cluster_id": r["cluster_id"],
                "membership_prob": r["membership_prob"],
                "title": p.get("title", ""),
                "abstract": p.get("abstract", ""),
                "year": p.get("year"),
            })
        if len(batch) < page_size:
            break
        offset += page_size
    out.sort(key=lambda d: (d["cluster_id"], -d["membership_prob"]))
    return out
# ---------------------------------------------------------- cluster labels
def save_cluster_label(run_id: int, cluster_id: int, label: str,
                       subject: str = "", object_: str = "",
                       phenomenon: str = "", rationale: str = "",
                       top_paper_ids: list[str] | None = None,
                       validated_by_researcher: bool = False,
                       locked: bool = False,
                       researcher_notes: str = "") -> None:
    """Upsert the label record of one cluster, keyed on ``(run_id, cluster_id)``.

    ``top_paper_ids`` is stored JSON-encoded (``[]`` when omitted or empty);
    the ids are converted with ``int`` and the flags with ``bool``.
    """
    encoded_top_ids = json.dumps(top_paper_ids or [])
    record = {
        "run_id": int(run_id),
        "cluster_id": int(cluster_id),
        "label": label,
        "subject": subject,
        "object_": object_,
        "phenomenon": phenomenon,
        "rationale": rationale,
        "top_paper_ids_json": encoded_top_ids,
        "validated_by_researcher": bool(validated_by_researcher),
        "locked": bool(locked),
        "researcher_notes": researcher_notes,
    }
    labels_table = client().table("cluster_labels")
    labels_table.upsert(record, on_conflict="run_id,cluster_id").execute()
def get_cluster_labels(run_id: int) -> list[dict]:
    """Return the label records of one clustering run, ordered by ``cluster_id``.

    Text fields default to ``""``, ``top_paper_ids`` is decoded from
    ``top_paper_ids_json`` (``[]`` when empty) and the two flags are coerced
    to ``bool``. The stored ``object_`` column is exposed as ``object``.
    """
    # (output key, stored column) for the plain text fields, in output order.
    text_fields = (
        ("label", "label"),
        ("subject", "subject"),
        ("object", "object_"),
        ("phenomenon", "phenomenon"),
        ("rationale", "rationale"),
    )
    query = (client().table("cluster_labels")
             .select("*").eq("run_id", int(run_id))
             .order("cluster_id"))
    labels = []
    for row in query.execute().data or []:
        entry = {"cluster_id": row["cluster_id"]}
        for out_key, column in text_fields:
            entry[out_key] = row.get(column) or ""
        entry["top_paper_ids"] = json.loads(row.get("top_paper_ids_json") or "[]")
        entry["validated_by_researcher"] = bool(row.get("validated_by_researcher"))
        entry["locked"] = bool(row.get("locked"))
        entry["researcher_notes"] = row.get("researcher_notes") or ""
        labels.append(entry)
    return labels
# ---------------------------------------------------------- TCCM runs
def create_tccm_run(pattern_set_version: str, threshold_rule: str,
                    n_papers: int, n_include: int, n_exclude: int,
                    n_marginal: int, notes: str = "") -> int:
    """Insert one row into ``tccm_runs`` and return its ``run_id``.

    The four counts are converted with ``int`` before being sent.
    """
    record = {
        "pattern_set_version": pattern_set_version,
        "threshold_rule": threshold_rule,
        "n_papers": int(n_papers),
        "n_include": int(n_include),
        "n_exclude": int(n_exclude),
        "n_marginal": int(n_marginal),
        "notes": notes,
    }
    inserted = client().table("tccm_runs").insert(record).execute()
    new_row = inserted.data[0]
    return new_row["run_id"]
def list_tccm_runs() -> list[dict]:
    """Return all rows of ``tccm_runs`` as stored, newest ``run_id`` first."""
    query = client().table("tccm_runs").select("*").order("run_id", desc=True)
    stored_runs = query.execute().data
    return stored_runs if stored_runs else []
def save_tccm_classifications(run_id: int, rows: list[dict]) -> None:
    """Upsert per-paper verdicts of one TCCM run, keyed on ``(run_id, paper_id)``.

    Each input dict must carry ``paper_id`` and ``verdict``. The four hit
    counters default to 0 and ``fired_terms_json`` to ``"{}"`` when absent.
    Records are sent in batches of 500; an empty input makes no request.
    """
    counter_fields = ("n_method", "n_sample", "n_analytic", "n_anti")
    records = []
    for row in rows:
        record = {
            "run_id": int(run_id),
            "paper_id": row["paper_id"],
            "verdict": row["verdict"],
        }
        for field in counter_fields:
            record[field] = int(row.get(field, 0))
        record["fired_terms_json"] = row.get("fired_terms_json", "{}")
        records.append(record)

    batch_size = 500
    for start in range(0, len(records), batch_size):
        batch = records[start:start + batch_size]
        (client().table("tccm_classifications")
         .upsert(batch, on_conflict="run_id,paper_id").execute())
def get_tccm_classifications(run_id: int,
                             verdict_filter: str | None = None) -> list[dict]:
    """Return the classifications of one TCCM run joined with paper fields.

    Args:
        run_id: run whose rows are read.
        verdict_filter: when truthy, only rows with this ``verdict`` are
            returned.

    Each entry has ``paper_id``, ``verdict``, the four hit counters,
    ``fired_terms`` (decoded from ``fired_terms_json``, ``{}`` when empty)
    and the joined ``title``, ``abstract``, ``year`` and ``author_keywords``.

    Rows are read in pages of 1000 ordered by ``paper_id``, because a single
    unbounded select is capped by the API's maximum-rows setting and would
    silently return only part of a large run.
    """
    page_size = 1000
    offset = 0
    out: list[dict] = []
    while True:
        # The query is rebuilt for every page rather than reusing one builder.
        q = (client().table("tccm_classifications")
             .select("*, papers(title, abstract, year, author_keywords)")
             .eq("run_id", int(run_id)))
        if verdict_filter:
            q = q.eq("verdict", verdict_filter)
        res = (q.order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        for r in batch:
            # The embedded paper may be missing; fall back to empty fields.
            p = r.get("papers") or {}
            out.append({
                "paper_id": r["paper_id"],
                "verdict": r["verdict"],
                "n_method": r["n_method"],
                "n_sample": r["n_sample"],
                "n_analytic": r["n_analytic"],
                "n_anti": r["n_anti"],
                "fired_terms": json.loads(r.get("fired_terms_json") or "{}"),
                "title": p.get("title", ""),
                "abstract": p.get("abstract", ""),
                "year": p.get("year"),
                "author_keywords": p.get("author_keywords") or "",
            })
        if len(batch) < page_size:
            break
        offset += page_size
    return out
# ---------------------------------------------------------- marginal reviews
def save_marginal_review(run_id: int, paper_id: str,
                         agent_verdict: str, agent_rationale: str,
                         researcher_verdict: str | None = None,
                         researcher_notes: str = "") -> None:
    """Upsert the review of one marginal paper, keyed on ``(run_id, paper_id)``.

    ``reviewed_at`` is set to the current time on every call.
    """
    # Timezone-aware UTC: the column is timestamptz, and a naive string
    # (as produced by the deprecated utcnow()) carries no offset.
    reviewed_at = datetime.datetime.now(datetime.timezone.utc).isoformat()
    client().table("tccm_marginal_reviews").upsert({
        "run_id": int(run_id),
        "paper_id": paper_id,
        "agent_verdict": agent_verdict,
        "agent_rationale": agent_rationale,
        "researcher_verdict": researcher_verdict,
        "researcher_notes": researcher_notes,
        "reviewed_at": reviewed_at,
    }, on_conflict="run_id,paper_id").execute()
def get_marginal_reviews(run_id: int) -> list[dict]:
    """Return the stored ``tccm_marginal_reviews`` rows of one TCCM run.

    Rows are returned as stored, ordered by ``paper_id``. They are read in
    pages of 1000 because a single unbounded select is capped by the API's
    maximum-rows setting and would silently return only part of a large run.
    """
    page_size = 1000
    offset = 0
    reviews: list[dict] = []
    while True:
        res = (client().table("tccm_marginal_reviews")
               .select("*").eq("run_id", int(run_id))
               .order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        reviews.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size
    return reviews
# ---------------------------------------------------------- PDF downloads
def save_pdf_discoveries(rows: list[dict]) -> None:
    """Upsert PDF discovery results into ``pdf_downloads``, keyed on ``paper_id``.

    rows: list from pdf_downloader.bulk_discover. Each dict must carry
    ``paper_id``; ``doi``, ``oa_status``, ``oa_type``, ``source``,
    ``pdf_url``, ``scihub_used`` and ``error`` are optional. ``error`` is
    stored in ``error_message`` and ``checked_at`` is set to the time of
    this call. Records are sent in batches of 500; an empty input makes no
    request.
    """
    # Timezone-aware UTC: the column is timestamptz, and a naive string
    # (as produced by the deprecated utcnow()) carries no offset.
    now = datetime.datetime.now(datetime.timezone.utc).isoformat()
    payload = [
        {
            "paper_id": r["paper_id"],
            "doi": r.get("doi") or "",
            "oa_status": r.get("oa_status"),
            "oa_type": r.get("oa_type"),
            "source": r.get("source"),
            "pdf_url": r.get("pdf_url"),
            "scihub_used": bool(r.get("scihub_used", False)),
            "error_message": r.get("error"),
            "checked_at": now,
        }
        for r in rows
    ]
    chunk = 500
    for i in range(0, len(payload), chunk):
        (client().table("pdf_downloads")
         .upsert(payload[i:i + chunk], on_conflict="paper_id").execute())
def save_pdf_downloads(rows: list[dict]) -> None:
    """Update existing ``pdf_downloads`` records with download status.

    rows: list from pdf_downloader.bulk_download. Each dict must carry
    ``paper_id``; one update request is issued per row, filtered on that
    id. ``downloaded_at`` is set to the time of this call for rows whose
    ``downloaded`` is truthy and to ``None`` otherwise. An empty input
    makes no request.
    """
    # Timezone-aware UTC: the column is timestamptz, and a naive string
    # (as produced by the deprecated utcnow()) carries no offset.
    now = datetime.datetime.now(datetime.timezone.utc).isoformat()
    for r in rows:
        update_payload = {
            "downloaded": bool(r.get("downloaded", False)),
            "file_path": r.get("file_path"),
            "file_size_bytes": int(r.get("file_size") or 0),
            "uploaded_to_supabase": bool(r.get("uploaded_to_supabase", False)),
            "supabase_storage_path": r.get("supabase_path"),
            "downloaded_at": now if r.get("downloaded") else None,
        }
        # Only override error_message if there's a new error
        if r.get("error"):
            update_payload["error_message"] = r["error"]
        (client().table("pdf_downloads")
         .update(update_payload)
         .eq("paper_id", r["paper_id"]).execute())
def get_pdf_status(paper_id: str | None = None) -> list[dict]:
    """Return all pdf_downloads rows, optionally filtered by paper_id.

    Rows are returned as stored, ordered by ``paper_id``. They are read in
    pages of 1000 because a single unbounded select is capped by the API's
    maximum-rows setting and would silently return only part of the table.
    """
    page_size = 1000
    offset = 0
    rows: list[dict] = []
    while True:
        # The query is rebuilt for every page rather than reusing one builder.
        q = client().table("pdf_downloads").select("*")
        if paper_id:
            q = q.eq("paper_id", paper_id)
        res = (q.order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        rows.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size
    return rows
def list_pdfs_with_metadata() -> list[dict]:
    """Joined view: pdf_downloads with paper title and year, by ``paper_id``.

    Text fields default to ``""``, flags are coerced to ``bool`` and
    ``file_size_mb`` is ``file_size_bytes`` converted to MiB, rounded to two
    decimals (0 when the size is missing).

    Rows are read in pages of 1000 because a single unbounded select is
    capped by the API's maximum-rows setting and would silently return only
    part of the table.
    """
    page_size = 1000
    offset = 0
    out: list[dict] = []
    while True:
        res = (client().table("pdf_downloads")
               .select("*, papers(title, year)")
               .order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        for r in batch:
            # The embedded paper may be missing; fall back to empty fields.
            p = r.get("papers") or {}
            out.append({
                "paper_id": r["paper_id"],
                "doi": r.get("doi") or "",
                "title": p.get("title", ""),
                "year": p.get("year"),
                "oa_status": r.get("oa_status") or "",
                "oa_type": r.get("oa_type") or "",
                "source": r.get("source") or "",
                "scihub_used": bool(r.get("scihub_used")),
                "downloaded": bool(r.get("downloaded")),
                "file_size_mb": round((r.get("file_size_bytes") or 0) / (1024 ** 2), 2),
                "uploaded_to_supabase": bool(r.get("uploaded_to_supabase")),
                "error_message": r.get("error_message") or "",
            })
        if len(batch) < page_size:
            break
        offset += page_size
    return out
# ---------------------------------------------------------- Crossref metadata
def save_crossref_metadata(rows: list[dict]) -> None:
    """Upsert Crossref results into ``crossref_metadata``, keyed on ``paper_id``.

    rows: list from pdf_downloader.bulk_crossref. Entries without a truthy
    ``paper_id`` are skipped. ``reference_dois`` is stored JSON-encoded, or
    as ``None`` when the list is missing or empty. Records are sent in
    batches of 500; if nothing remains after skipping, no request is made.
    """
    records = []
    for row in rows:
        paper_id = row.get("paper_id")
        if not paper_id:
            continue
        reference_dois = row.get("reference_dois") or []
        encoded_refs = json.dumps(reference_dois) if reference_dois else None
        records.append({
            "paper_id": row["paper_id"],
            "doi": row.get("doi") or "",
            "crossref_title": row.get("title"),
            "crossref_abstract": row.get("abstract"),
            "citation_count": row.get("citation_count"),
            "references_count": row.get("references_count"),
            "reference_dois": encoded_refs,
            "publication_type": row.get("type"),
            "publisher": row.get("publisher"),
            "container_title": row.get("container_title"),
            "error_message": row.get("error"),
        })

    batch_size = 500
    for start in range(0, len(records), batch_size):
        batch = records[start:start + batch_size]
        (client().table("crossref_metadata")
         .upsert(batch, on_conflict="paper_id").execute())
def get_crossref_metadata(paper_id: str | None = None) -> list[dict]:
    """Return all crossref_metadata rows, optionally filtered by paper_id.

    Rows are returned as stored, ordered by ``paper_id``. They are read in
    pages of 1000 because a single unbounded select is capped by the API's
    maximum-rows setting and would silently return only part of the table.
    """
    page_size = 1000
    offset = 0
    rows: list[dict] = []
    while True:
        # The query is rebuilt for every page rather than reusing one builder.
        q = client().table("crossref_metadata").select("*")
        if paper_id:
            q = q.eq("paper_id", paper_id)
        res = (q.order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        rows.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size
    return rows