File size: 22,490 Bytes
5b28350
 
 
8d75855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b28350
 
 
8d75855
5b28350
 
 
 
8d75855
5b28350
 
8d75855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b28350
8d75855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b28350
 
8d75855
 
 
5b28350
 
8d75855
 
 
 
 
 
 
5b28350
 
8d75855
 
 
 
 
 
5b28350
 
8d75855
 
 
 
 
 
 
 
 
 
 
 
 
5b28350
 
8d75855
 
 
 
 
 
 
 
 
 
 
5b28350
8d75855
 
 
 
 
 
 
 
 
 
 
5b28350
8d75855
 
 
 
 
 
 
 
 
 
614dd95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d75855
5b28350
 
8d75855
 
 
5b28350
 
8d75855
 
 
5b28350
8d75855
 
 
 
 
 
 
614dd95
8d75855
 
 
 
 
 
 
 
 
 
 
 
 
6ffd841
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d75855
 
 
6ffd841
 
 
 
 
 
 
 
8d75855
6ffd841
 
 
 
 
 
8d75855
 
 
 
 
5b28350
 
 
8d75855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b28350
 
 
8d75855
 
 
 
 
 
 
 
 
 
 
 
 
 
5b28350
 
 
8d75855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b28350
 
8d75855
 
 
 
5b28350
 
 
8d75855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b28350
 
 
8d75855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
614dd95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
"""
database.py

Responsibility: Supabase persistence for both the embedding workbench and
the TCCM screening pipeline.

Connection: requires SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY environment
variables. On Hugging Face Spaces these are set in Settings -> Secrets.
The app fails fast with a clear error if either is missing.

Schema management: Supabase intentionally restricts CREATE TABLE from the
client SDK. The researcher runs the schema SQL (via `bootstrap_schema_sql()`)
once in the Supabase SQL editor; subsequent app runs find tables in place.
The UI surfaces missing tables explicitly via `assert_schema_present()`.

Why service-role key (not anon key):
    Unrestricted reads/writes across all tables in this single-researcher
    workbench need the service-role key. The anon key is bound by Row Level
    Security and would require RLS policies for every table — overkill here.
"""

from __future__ import annotations
import os
import json
import datetime
from typing import Any

from supabase import create_client, Client


# ---------------------------------------------------------- env / client
def _require_env() -> tuple[str, str]:
    """Read the Supabase URL and service-role key from the environment.

    Returns:
        ``(url, key)`` taken from SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY.

    Raises:
        RuntimeError: if either variable is unset or set to an empty string.
    """
    env = os.environ
    url = env.get("SUPABASE_URL")
    key = env.get("SUPABASE_SERVICE_ROLE_KEY")
    if url and key:
        return url, key
    raise RuntimeError(
        "Missing Supabase credentials. Set SUPABASE_URL and "
        "SUPABASE_SERVICE_ROLE_KEY in the environment "
        "(HF Space Settings -> Variables and secrets -> Secrets)."
    )


# Lazily-created, module-wide Supabase client (see client()).
_client: Client | None = None


def client() -> Client:
    """Return the shared Supabase client, building it on first call.

    Credentials are read via ``_require_env()`` only when no client exists
    yet, so a missing secret surfaces as a RuntimeError on first use.
    """
    global _client
    if _client is not None:
        return _client
    supabase_url, service_key = _require_env()
    _client = create_client(supabase_url, service_key)
    return _client


# ---------------------------------------------------------- schema
# DDL for every table this module reads or writes. Per the module docstring,
# the client SDK is not used to run CREATE TABLE; this text is returned by
# bootstrap_schema_sql() for the researcher to paste into the Supabase SQL
# editor once. Every statement uses "create table if not exists", so
# re-running it leaves existing tables untouched. That also means column
# changes made here are NOT applied to tables that already exist.
# NOTE(review): add explicit ALTER statements if the schema evolves.
# Child tables reference papers / clustering_runs / tccm_runs with
# "on delete cascade", so deleting a parent row removes its dependents.
SCHEMA_SQL = """
-- Run this entire block once in the Supabase SQL editor
-- (Project -> SQL Editor -> New query -> paste -> Run).

create table if not exists papers (
    paper_id        text primary key,
    doi             text,
    title           text not null,
    abstract        text,
    year            integer,
    author_keywords text,
    source          text default 'manual_upload',
    created_at      timestamptz default now()
);

create table if not exists embeddings_2d (
    paper_id    text primary key references papers(paper_id) on delete cascade,
    coords_json text not null
);

create table if not exists clustering_runs (
    run_id      bigserial primary key,
    params_json text not null,
    n_clusters  integer,
    n_noise     integer,
    notes       text,
    created_at  timestamptz default now()
);

create table if not exists cluster_assignments (
    run_id           bigint not null references clustering_runs(run_id) on delete cascade,
    paper_id         text not null references papers(paper_id) on delete cascade,
    cluster_id       integer not null,
    membership_prob  double precision not null,
    primary key (run_id, paper_id)
);

create table if not exists cluster_labels (
    run_id                  bigint not null references clustering_runs(run_id) on delete cascade,
    cluster_id              integer not null,
    label                   text,
    subject                 text,
    object_                 text,
    phenomenon              text,
    rationale               text,
    top_paper_ids_json      text,
    validated_by_researcher boolean default false,
    locked                  boolean default false,
    researcher_notes        text,
    primary key (run_id, cluster_id)
);

create table if not exists tccm_runs (
    run_id               bigserial primary key,
    pattern_set_version  text not null,
    threshold_rule       text not null default 'anti_dominates',
    n_papers             integer,
    n_include            integer,
    n_exclude            integer,
    n_marginal           integer,
    notes                text,
    created_at           timestamptz default now()
);

create table if not exists tccm_classifications (
    run_id            bigint not null references tccm_runs(run_id) on delete cascade,
    paper_id          text not null references papers(paper_id) on delete cascade,
    verdict           text not null check (verdict in ('INCLUDE','EXCLUDE','MARGINAL')),
    n_method          integer not null default 0,
    n_sample          integer not null default 0,
    n_analytic        integer not null default 0,
    n_anti            integer not null default 0,
    fired_terms_json  text,
    primary key (run_id, paper_id)
);

create table if not exists tccm_marginal_reviews (
    run_id              bigint not null references tccm_runs(run_id) on delete cascade,
    paper_id            text not null references papers(paper_id) on delete cascade,
    agent_verdict       text,
    agent_rationale     text,
    researcher_verdict  text,
    researcher_notes    text,
    reviewed_at         timestamptz default now(),
    primary key (run_id, paper_id)
);

create table if not exists pdf_downloads (
    paper_id              text primary key references papers(paper_id) on delete cascade,
    doi                   text,
    oa_status             text,
    oa_type               text,
    source                text,
    pdf_url               text,
    scihub_used           boolean default false,
    downloaded            boolean default false,
    file_path             text,
    file_size_bytes       bigint,
    uploaded_to_supabase  boolean default false,
    supabase_storage_path text,
    error_message         text,
    checked_at            timestamptz,
    downloaded_at         timestamptz
);

create table if not exists crossref_metadata (
    paper_id          text primary key references papers(paper_id) on delete cascade,
    doi               text,
    crossref_title    text,
    crossref_abstract text,
    citation_count    integer,
    references_count  integer,
    reference_dois    text,
    publication_type  text,
    publisher         text,
    container_title   text,
    error_message     text,
    fetched_at        timestamptz default now()
);
"""


def bootstrap_schema_sql() -> str:
    """Give back the complete DDL text meant to be run once by hand.

    The UI shows this to the researcher, who pastes it into the Supabase SQL
    editor; nothing here executes SQL.
    """
    ddl = SCHEMA_SQL
    return ddl


def assert_schema_present() -> dict[str, bool]:
    """
    Probe every expected table with a one-row read and report which respond.

    Returns a mapping ``{table_name: present}`` in a fixed table order. Any
    exception raised by a probe is recorded as ``False``; such an exception
    could be a missing table, but also a network or permission error. The
    probes themselves never raise; the UI reacts to ``False`` entries by
    showing the bootstrap SQL. Obtaining the client via ``client()`` happens
    outside the probes and can still raise RuntimeError when credentials
    are missing.
    """
    tables = (
        "papers", "embeddings_2d", "clustering_runs",
        "cluster_assignments", "cluster_labels",
        "tccm_runs", "tccm_classifications", "tccm_marginal_reviews",
        "pdf_downloads", "crossref_metadata",
    )
    sb = client()

    def _probe(name: str) -> bool:
        try:
            sb.table(name).select("*").limit(1).execute()
        except Exception:
            return False
        return True

    return {name: _probe(name) for name in tables}


# ---------------------------------------------------------- papers
def _scrub_nan(v):
    """Convert any NaN-like value to None so it's JSON-encodable.

    Treated as NaN and mapped to None:
      * ``None`` itself;
      * float NaN, including numpy float NaN (a ``float`` subclass);
      * any other value that is not equal to itself (e.g. numpy NaN in
        object columns, pandas NaT);
      * the plain string ``"nan"`` in any case, surrounding whitespace
        ignored. ``str(float("nan"))`` produces ``"nan"``, so textual NaN
        survives CSV/str round-trips.

    Every other value is returned unchanged.
    """
    import math

    if v is None:
        return None
    if isinstance(v, float) and math.isnan(v):
        return None
    if isinstance(v, str):
        # Previously documented but not implemented: textual "nan" leaked
        # through and later broke e.g. int(year).
        return None if v.strip().lower() == "nan" else v
    try:
        # NaN (and NaT) are the common values that are not equal to themselves.
        if v != v:
            return None
    except Exception:
        # Best-effort: some objects cannot turn this comparison into a plain
        # bool (it raises); such values are returned unchanged.
        pass
    return v


def upsert_papers(rows: list[dict]) -> int:
    """Insert or update rows in ``papers``, keyed on ``paper_id``.

    Each input dict must contain ``paper_id``. The keys ``doi``, ``title``,
    ``abstract``, ``year`` and ``author_keywords`` are optional. NaN-like
    values are scrubbed with ``_scrub_nan``. Missing text fields are stored
    as "" and a missing year as NULL; ``year`` is coerced with ``int()``.

    Rows sharing a ``paper_id`` are collapsed before upload, and the last
    occurrence wins. PostgreSQL's ``INSERT ... ON CONFLICT DO UPDATE``
    rejects a command that affects the same key twice, so a duplicate inside
    one 500-row chunk would otherwise fail the whole request.

    Returns:
        The number of distinct papers sent (0 for empty input).
    """
    if not rows:
        return 0

    def _text(value):
        # Text columns get "" rather than NULL for missing values.
        value = _scrub_nan(value)
        return value if value is not None else ""

    # A dict keeps first-seen key order; a later duplicate replaces the record.
    by_id: dict = {}
    for r in rows:
        year = _scrub_nan(r.get("year"))
        by_id[r["paper_id"]] = {
            "paper_id":        r["paper_id"],
            "doi":             _text(r.get("doi")),
            "title":           _text(r.get("title")),
            "abstract":        _text(r.get("abstract")),
            "year":            int(year) if year is not None else None,
            "author_keywords": _text(r.get("author_keywords")),
        }
    payload = list(by_id.values())

    chunk = 500
    for i in range(0, len(payload), chunk):
        (client().table("papers")
         .upsert(payload[i:i + chunk], on_conflict="paper_id").execute())
    return len(payload)


def get_all_papers() -> list[dict]:
    """Return every row of ``papers``, ordered by ``paper_id``.

    Rows are requested 1000 at a time with ``range()``. Paging stops at the
    first page that comes back shorter than a full page.
    """
    page_size = 1000
    papers: list[dict] = []
    start = 0
    while True:
        query = client().table("papers").select("*").order("paper_id")
        page = query.range(start, start + page_size - 1).execute().data or []
        papers.extend(page)
        if len(page) < page_size:
            return papers
        start += page_size


def get_paper(paper_id: str) -> dict | None:
    """Fetch a single ``papers`` row by id, or None when it does not exist."""
    matches = (client().table("papers")
               .select("*")
               .eq("paper_id", paper_id)
               .limit(1)
               .execute()
               .data)
    if not matches:
        return None
    return matches[0]


def count_papers() -> int:
    """Return the exact number of rows in ``papers`` (0 if the count is absent).

    ``count="exact"`` asks the API for the total row count. ``limit(1)``
    keeps the row payload of the response minimal.
    """
    response = (client().table("papers")
                .select("paper_id", count="exact")
                .limit(1)
                .execute())
    total = response.count
    return total if total else 0


# ---------------------------------------------------------- 2D coords
def save_2d_coords_bulk(rows: list[dict]) -> None:
    """Upsert 2-D projection coordinates into ``embeddings_2d``.

    Each row needs ``paper_id`` and ``coords``. The coordinates are stored as
    a JSON string in ``coords_json``. All rows are serialized before any
    request is sent. They are then upserted on ``paper_id`` in batches of 500.
    """
    if not rows:
        return
    records = []
    for row in rows:
        records.append({
            "paper_id": row["paper_id"],
            "coords_json": json.dumps(row["coords"]),
        })
    batch_size = 500
    for start in range(0, len(records), batch_size):
        batch = records[start:start + batch_size]
        client().table("embeddings_2d").upsert(batch, on_conflict="paper_id").execute()


def get_2d_coords() -> dict[str, list[float]]:
    """Return ``{paper_id: coords}`` for every row in ``embeddings_2d``.

    ``coords`` is the JSON-decoded ``coords_json`` column. Rows are fetched
    in pages of 1000, ordered by ``paper_id``. A single unbounded select is
    capped by the Supabase API's max-rows limit (1000 by default) and would
    silently drop coordinates for larger corpora.
    """
    page_size = 1000
    offset = 0
    out: dict[str, list[float]] = {}
    while True:
        # A fresh query per page, so range() parameters never pile up on a
        # reused builder.
        res = (client().table("embeddings_2d").select("*")
               .order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        for r in batch:
            out[r["paper_id"]] = json.loads(r["coords_json"])
        if len(batch) < page_size:
            break
        offset += page_size
    return out


# ---------------------------------------------------------- clustering runs
def create_clustering_run(params: dict, n_clusters: int,
                           n_noise: int, notes: str = "") -> int:
    """Insert a ``clustering_runs`` row and return its generated ``run_id``.

    ``params`` is stored JSON-encoded in ``params_json``. The counts are
    coerced with ``int()``.
    """
    runs = client().table("clustering_runs")
    record = {
        "params_json": json.dumps(params),
        "n_clusters":  int(n_clusters),
        "n_noise":     int(n_noise),
        "notes":       notes,
    }
    inserted = runs.insert(record).execute()
    first_row = inserted.data[0]
    return first_row["run_id"]


def list_clustering_runs() -> list[dict]:
    """List clustering runs, newest ``run_id`` first.

    Each entry has ``run_id``, decoded ``params``, ``n_clusters``,
    ``n_noise``, ``created_at`` and ``notes``. A missing or NULL ``notes``
    becomes "".
    """
    res = (client().table("clustering_runs").select("*")
           .order("run_id", desc=True).execute())

    def _decode(row: dict) -> dict:
        return {
            "run_id":     row["run_id"],
            "params":     json.loads(row["params_json"]),
            "n_clusters": row["n_clusters"],
            "n_noise":    row["n_noise"],
            "created_at": row["created_at"],
            "notes":      row.get("notes", "") or "",
        }

    return [_decode(row) for row in (res.data or [])]


def save_cluster_assignments(run_id: int, paper_ids: list[str],
                              cluster_ids: list[int], probs: list[float]) -> None:
    """Upsert per-paper cluster membership for clustering run *run_id*.

    ``paper_ids``, ``cluster_ids`` and ``probs`` are parallel sequences. Any
    iterable works, since each is materialized with ``list()``. Records are
    upserted on (run_id, paper_id) in chunks of 500.

    Raises:
        ValueError: if the three sequences differ in length. ``zip`` would
            otherwise silently drop the unmatched tail and lose assignments.
    """
    paper_ids, cluster_ids, probs = list(paper_ids), list(cluster_ids), list(probs)
    if not (len(paper_ids) == len(cluster_ids) == len(probs)):
        raise ValueError(
            "save_cluster_assignments: length mismatch "
            f"(paper_ids={len(paper_ids)}, cluster_ids={len(cluster_ids)}, "
            f"probs={len(probs)})"
        )
    rid = int(run_id)
    payload = [
        {"run_id": rid, "paper_id": pid,
         "cluster_id": int(cid), "membership_prob": float(p)}
        for pid, cid, p in zip(paper_ids, cluster_ids, probs)
    ]
    chunk = 500
    for i in range(0, len(payload), chunk):
        (client().table("cluster_assignments")
         .upsert(payload[i:i + chunk], on_conflict="run_id,paper_id").execute())


def get_cluster_assignments(run_id: int) -> list[dict]:
    """Return every assignment of clustering run *run_id*, joined with paper text.

    Each dict has ``paper_id``, ``cluster_id`` and ``membership_prob``, plus
    ``title``, ``abstract`` and ``year`` from the embedded ``papers`` row.
    Without that row, title and abstract are "" and year is None. The result
    is sorted by ``cluster_id``, then by descending ``membership_prob``.

    Rows are fetched in pages of 1000, ordered by ``paper_id``. An unbounded
    select is capped by the Supabase API's max-rows limit (1000 by default)
    and would silently drop assignments for larger corpora.
    """
    page_size = 1000
    offset = 0
    rows: list[dict] = []
    while True:
        # A fresh query per page, so range() parameters never pile up.
        res = (client().table("cluster_assignments")
               .select("*, papers(title, abstract, year)")
               .eq("run_id", int(run_id))
               .order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        rows.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size

    out = []
    for r in rows:
        p = r.get("papers") or {}
        out.append({
            "paper_id":        r["paper_id"],
            "cluster_id":      r["cluster_id"],
            "membership_prob": r["membership_prob"],
            "title":           p.get("title", ""),
            "abstract":        p.get("abstract", ""),
            "year":            p.get("year"),
        })
    out.sort(key=lambda d: (d["cluster_id"], -d["membership_prob"]))
    return out


# ---------------------------------------------------------- cluster labels
def save_cluster_label(run_id: int, cluster_id: int, label: str,
                        subject: str = "", object_: str = "",
                        phenomenon: str = "", rationale: str = "",
                        top_paper_ids: list[str] | None = None,
                        validated_by_researcher: bool = False,
                        locked: bool = False,
                        researcher_notes: str = "") -> None:
    """Create or overwrite the label record for one cluster of one run.

    The record is upserted on (run_id, cluster_id). ``top_paper_ids`` is
    stored as a JSON list (``[]`` when None). The boolean flags are coerced
    with ``bool()``.
    """
    labels = client().table("cluster_labels")
    record = {
        "run_id":                  int(run_id),
        "cluster_id":              int(cluster_id),
        "label":                   label,
        "subject":                 subject,
        "object_":                 object_,
        "phenomenon":              phenomenon,
        "rationale":               rationale,
        "top_paper_ids_json":      json.dumps(top_paper_ids or []),
        "validated_by_researcher": bool(validated_by_researcher),
        "locked":                  bool(locked),
        "researcher_notes":        researcher_notes,
    }
    labels.upsert(record, on_conflict="run_id,cluster_id").execute()


def get_cluster_labels(run_id: int) -> list[dict]:
    """Return the label records of clustering run *run_id*, by ``cluster_id``.

    NULL text columns become "". The stored ``object_`` column is exposed
    under the key ``"object"``. ``top_paper_ids`` is JSON-decoded, with
    ``[]`` when absent.
    """
    res = (client().table("cluster_labels")
           .select("*").eq("run_id", int(run_id))
           .order("cluster_id").execute())

    def _text(row: dict, column: str) -> str:
        return row.get(column) or ""

    return [
        {
            "cluster_id":              row["cluster_id"],
            "label":                   _text(row, "label"),
            "subject":                 _text(row, "subject"),
            "object":                  _text(row, "object_"),
            "phenomenon":              _text(row, "phenomenon"),
            "rationale":               _text(row, "rationale"),
            "top_paper_ids":           json.loads(row.get("top_paper_ids_json") or "[]"),
            "validated_by_researcher": bool(row.get("validated_by_researcher")),
            "locked":                  bool(row.get("locked")),
            "researcher_notes":        _text(row, "researcher_notes"),
        }
        for row in (res.data or [])
    ]


# ---------------------------------------------------------- TCCM runs
def create_tccm_run(pattern_set_version: str, threshold_rule: str,
                     n_papers: int, n_include: int, n_exclude: int,
                     n_marginal: int, notes: str = "") -> int:
    """Insert a ``tccm_runs`` summary row and return its generated ``run_id``.

    All count arguments are coerced with ``int()`` before insertion.
    """
    runs = client().table("tccm_runs")
    record = {
        "pattern_set_version": pattern_set_version,
        "threshold_rule":      threshold_rule,
        "n_papers":            int(n_papers),
        "n_include":           int(n_include),
        "n_exclude":           int(n_exclude),
        "n_marginal":          int(n_marginal),
        "notes":               notes,
    }
    inserted = runs.insert(record).execute()
    first_row = inserted.data[0]
    return first_row["run_id"]


def list_tccm_runs() -> list[dict]:
    """Return raw ``tccm_runs`` rows, newest ``run_id`` first ([] if none)."""
    query = client().table("tccm_runs").select("*").order("run_id", desc=True)
    rows = query.execute().data
    return rows if rows else []


def save_tccm_classifications(run_id: int, rows: list[dict]) -> None:
    """Upsert per-paper TCCM verdicts for run *run_id*.

    Each row needs ``paper_id`` and ``verdict``. The counters default to 0.
    ``fired_terms_json`` is passed through as given and defaults to "{}" when
    the key is missing. The whole payload is built first and then upserted
    on (run_id, paper_id) in chunks of 500.
    """
    records = []
    for row in rows:
        records.append({
            "run_id":           int(run_id),
            "paper_id":         row["paper_id"],
            "verdict":          row["verdict"],
            "n_method":         int(row.get("n_method", 0)),
            "n_sample":         int(row.get("n_sample", 0)),
            "n_analytic":       int(row.get("n_analytic", 0)),
            "n_anti":           int(row.get("n_anti", 0)),
            "fired_terms_json": row.get("fired_terms_json", "{}"),
        })
    step = 500
    for lo in range(0, len(records), step):
        batch = records[lo:lo + step]
        (client().table("tccm_classifications")
         .upsert(batch, on_conflict="run_id,paper_id").execute())


def get_tccm_classifications(run_id: int,
                              verdict_filter: str | None = None) -> list[dict]:
    """Return the classifications of TCCM run *run_id*, joined with paper text.

    When *verdict_filter* is given, only rows with that verdict are returned.
    ``fired_terms`` is the JSON-decoded ``fired_terms_json`` column, with
    ``{}`` when absent. Paper fields come from the embedded ``papers`` row.

    Rows are fetched in pages of 1000, ordered by ``paper_id``. An unbounded
    select is capped by the Supabase API's max-rows limit (1000 by default)
    and would silently drop classifications for larger corpora.
    """
    page_size = 1000
    offset = 0
    rows: list[dict] = []
    while True:
        # A fresh query per page, so range() parameters never pile up.
        q = (client().table("tccm_classifications")
             .select("*, papers(title, abstract, year, author_keywords)")
             .eq("run_id", int(run_id)))
        if verdict_filter:
            q = q.eq("verdict", verdict_filter)
        res = (q.order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        rows.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size

    out = []
    for r in rows:
        p = r.get("papers") or {}
        out.append({
            "paper_id":        r["paper_id"],
            "verdict":         r["verdict"],
            "n_method":        r["n_method"],
            "n_sample":        r["n_sample"],
            "n_analytic":      r["n_analytic"],
            "n_anti":          r["n_anti"],
            "fired_terms":     json.loads(r.get("fired_terms_json") or "{}"),
            "title":           p.get("title", ""),
            "abstract":        p.get("abstract", ""),
            "year":            p.get("year"),
            "author_keywords": p.get("author_keywords") or "",
        })
    return out


# ---------------------------------------------------------- marginal reviews
def save_marginal_review(run_id: int, paper_id: str,
                          agent_verdict: str, agent_rationale: str,
                          researcher_verdict: str | None = None,
                          researcher_notes: str = "") -> None:
    """Upsert the agent/researcher review of one paper in a TCCM run.

    The record is keyed on (run_id, paper_id), so saving again overwrites the
    previous review. ``reviewed_at`` is stamped with the current time as a
    timezone-aware UTC ISO-8601 string with an explicit "+00:00" offset. The
    ``timestamptz`` column therefore does not depend on the database
    session's TimeZone setting to interpret a naive value. The old
    ``datetime.utcnow()`` produced a naive value and is deprecated since
    Python 3.12.
    """
    reviewed_at = datetime.datetime.now(datetime.timezone.utc).isoformat()
    client().table("tccm_marginal_reviews").upsert({
        "run_id":             int(run_id),
        "paper_id":           paper_id,
        "agent_verdict":      agent_verdict,
        "agent_rationale":    agent_rationale,
        "researcher_verdict": researcher_verdict,
        "researcher_notes":   researcher_notes,
        "reviewed_at":        reviewed_at,
    }, on_conflict="run_id,paper_id").execute()


def get_marginal_reviews(run_id: int) -> list[dict]:
    """Return raw ``tccm_marginal_reviews`` rows for run *run_id*.

    Rows are fetched in pages of 1000, ordered by ``paper_id``. An unbounded
    select is capped by the Supabase API's max-rows limit (1000 by default)
    and would silently drop reviews beyond that.
    """
    page_size = 1000
    offset = 0
    out: list[dict] = []
    while True:
        # A fresh query per page, so range() parameters never pile up.
        res = (client().table("tccm_marginal_reviews")
               .select("*").eq("run_id", int(run_id))
               .order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        out.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size
    return out


# ---------------------------------------------------------- PDF downloads
def save_pdf_discoveries(rows: list[dict]) -> None:
    """Upsert PDF-availability lookups into ``pdf_downloads``.

    rows: list from pdf_downloader.bulk_discover. Each needs ``paper_id``;
    the discovery fields ``doi``, ``oa_status``, ``oa_type``, ``source``,
    ``pdf_url``, ``scihub_used`` and ``error`` are optional. All rows share
    one ``checked_at`` stamp: the current time as a timezone-aware UTC
    ISO-8601 string. The old naive ``datetime.utcnow()`` left the
    ``timestamptz`` column to guess the zone and is deprecated since 3.12.
    Records are upserted on ``paper_id`` in chunks of 500. Only the columns
    listed here are written.
    """
    now = datetime.datetime.now(datetime.timezone.utc).isoformat()
    payload = []
    for r in rows:
        payload.append({
            "paper_id":     r["paper_id"],
            "doi":          r.get("doi") or "",
            "oa_status":    r.get("oa_status"),
            "oa_type":      r.get("oa_type"),
            "source":       r.get("source"),
            "pdf_url":      r.get("pdf_url"),
            "scihub_used":  bool(r.get("scihub_used", False)),
            "error_message": r.get("error"),
            "checked_at":   now,
        })
    chunk = 500
    for i in range(0, len(payload), chunk):
        (client().table("pdf_downloads")
         .upsert(payload[i:i + chunk], on_conflict="paper_id").execute())


def save_pdf_downloads(rows: list[dict]) -> None:
    """Record download outcomes on existing ``pdf_downloads`` rows.

    rows: list from pdf_downloader.bulk_download. One UPDATE is issued per
    row, matched on ``paper_id``. Rows without an existing record are
    therefore not inserted. ``downloaded_at`` is set to the current time
    (tz-aware UTC ISO-8601) for successful downloads and cleared otherwise.
    ``error_message`` is only overwritten when the row carries a new error.
    """
    # Timezone-aware UTC: unambiguous for the timestamptz column, unlike the
    # deprecated naive datetime.utcnow().
    now = datetime.datetime.now(datetime.timezone.utc).isoformat()
    for r in rows:
        update_payload = {
            "downloaded":            bool(r.get("downloaded", False)),
            "file_path":             r.get("file_path"),
            "file_size_bytes":       int(r.get("file_size") or 0),
            "uploaded_to_supabase":  bool(r.get("uploaded_to_supabase", False)),
            "supabase_storage_path": r.get("supabase_path"),
            "downloaded_at":         now if r.get("downloaded") else None,
        }
        # Only override error_message if there's a new error
        if r.get("error"):
            update_payload["error_message"] = r["error"]
        (client().table("pdf_downloads")
         .update(update_payload)
         .eq("paper_id", r["paper_id"]).execute())


def get_pdf_status(paper_id: str | None = None) -> list[dict]:
    """Return all pdf_downloads rows, optionally filtered by paper_id.

    Rows are fetched in pages of 1000, ordered by ``paper_id``. An unbounded
    select is capped by the Supabase API's max-rows limit (1000 by default)
    and would silently truncate the table.
    """
    page_size = 1000
    offset = 0
    out: list[dict] = []
    while True:
        # A fresh query per page, so range() parameters never pile up.
        q = client().table("pdf_downloads").select("*")
        if paper_id:
            q = q.eq("paper_id", paper_id)
        res = (q.order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        out.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size
    return out


def list_pdfs_with_metadata() -> list[dict]:
    """Joined view: pdf_downloads rows with paper title and year, by paper_id.

    NULL text columns become "". ``file_size_mb`` is ``file_size_bytes``
    converted to MiB and rounded to 2 decimals. Rows are fetched in pages of
    1000. An unbounded select is capped by the Supabase API's max-rows limit
    (1000 by default) and would silently truncate the listing.
    """
    page_size = 1000
    offset = 0
    rows: list[dict] = []
    while True:
        # A fresh query per page, so range() parameters never pile up.
        res = (client().table("pdf_downloads")
               .select("*, papers(title, year)")
               .order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        rows.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size

    out = []
    for r in rows:
        p = r.get("papers") or {}
        out.append({
            "paper_id":             r["paper_id"],
            "doi":                  r.get("doi") or "",
            "title":                p.get("title", ""),
            "year":                 p.get("year"),
            "oa_status":            r.get("oa_status") or "",
            "oa_type":              r.get("oa_type") or "",
            "source":               r.get("source") or "",
            "scihub_used":          bool(r.get("scihub_used")),
            "downloaded":           bool(r.get("downloaded")),
            "file_size_mb":         round((r.get("file_size_bytes") or 0) / (1024 ** 2), 2),
            "uploaded_to_supabase": bool(r.get("uploaded_to_supabase")),
            "error_message":        r.get("error_message") or "",
        })
    return out


# ---------------------------------------------------------- Crossref metadata
def save_crossref_metadata(rows: list[dict]) -> None:
    """Upsert Crossref lookups into ``crossref_metadata``, keyed on paper_id.

    rows: list from pdf_downloader.bulk_crossref. Rows with a falsy
    ``paper_id`` are skipped. ``reference_dois`` is stored JSON-encoded, or
    NULL when empty or missing. Records are upserted in chunks of 500.
    """
    def _to_record(row: dict) -> dict:
        refs = row.get("reference_dois") or []
        return {
            "paper_id":          row["paper_id"],
            "doi":               row.get("doi") or "",
            "crossref_title":    row.get("title"),
            "crossref_abstract": row.get("abstract"),
            "citation_count":    row.get("citation_count"),
            "references_count":  row.get("references_count"),
            "reference_dois":    json.dumps(refs) if refs else None,
            "publication_type":  row.get("type"),
            "publisher":         row.get("publisher"),
            "container_title":   row.get("container_title"),
            "error_message":     row.get("error"),
        }

    records = [_to_record(row) for row in rows if row.get("paper_id")]
    step = 500
    for lo in range(0, len(records), step):
        (client().table("crossref_metadata")
         .upsert(records[lo:lo + step], on_conflict="paper_id").execute())


def get_crossref_metadata(paper_id: str | None = None) -> list[dict]:
    """Return raw ``crossref_metadata`` rows, optionally only for *paper_id*.

    Rows are fetched in pages of 1000, ordered by ``paper_id``. An unbounded
    select is capped by the Supabase API's max-rows limit (1000 by default)
    and would silently truncate the table.
    """
    page_size = 1000
    offset = 0
    out: list[dict] = []
    while True:
        # A fresh query per page, so range() parameters never pile up.
        q = client().table("crossref_metadata").select("*")
        if paper_id:
            q = q.eq("paper_id", paper_id)
        res = (q.order("paper_id")
               .range(offset, offset + page_size - 1).execute())
        batch = res.data or []
        out.extend(batch)
        if len(batch) < page_size:
            break
        offset += page_size
    return out