"""
Step 1: Fetch citation edges from the Semantic Scholar (S2) API.

Produces: citations.parquet → (citing_arxiv_id, cited_arxiv_id)
where BOTH IDs exist in the ResearchIT Qdrant corpus.

Usage:
    # Option A: Batch API (no API key needed, slower, ~1-2 hours for 1.6M papers)
    python 01_fetch_citation_edges.py --corpus-file arxiv_ids.txt --output citations.parquet

    # Option B: Batch API with API key (faster, ~30-60 min)
    python 01_fetch_citation_edges.py --corpus-file arxiv_ids.txt --output citations.parquet --api-key YOUR_KEY

    # Option C: If you already have the S2 bulk datasets downloaded.
    # --bulk-papers must be a gzipped JSONL file whose records carry "corpusid"
    # and "externalids" (e.g. the S2 `papers` dataset); --bulk-citations is the
    # S2 `citations` dataset.
    python 01_fetch_citation_edges.py --bulk-papers papers.jsonl.gz --bulk-citations citations.jsonl.gz \
        --corpus-file arxiv_ids.txt --output citations.parquet

Prerequisites:
    - arxiv_ids.txt: one arXiv ID per line (e.g. "2303.14957"), exported from Qdrant/Turso
    - pip install httpx pyarrow tqdm

Output schema:
    citing_arxiv_id (string) — the paper that contains the citation
    cited_arxiv_id  (string) — the paper being cited
    is_influential  (bool)   — S2's influential-citation flag (only populated on the bulk path)

Author: ResearchIT ML Pipeline — Phase 6, Step 1
"""
| from __future__ import annotations |
|
|
| import argparse |
| import asyncio |
| import gzip |
| import json |
| import os |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import httpx |
| import pyarrow as pa |
| import pyarrow.parquet as pq |
| from tqdm import tqdm |
|
|
|
|
| |
|
|
# --- Semantic Scholar batch-API settings ------------------------------------
# POST endpoint that resolves many paper IDs in one request.
S2_BATCH_URL: str = "https://api.semanticscholar.org/graph/v1/paper/batch"
# Requested fields: each paper's own external IDs plus the external IDs of
# every paper it references (used to recover cited arXiv IDs).
S2_BATCH_FIELDS: str = "externalIds,references.externalIds"

# Number of arXiv IDs sent in each batch request.
BATCH_SIZE: int = 500

# Retry policy for one batch: at most MAX_RETRIES attempts, with sleeps of
# RETRY_BACKOFF_BASE ** attempt seconds between failed attempts.
MAX_RETRIES: int = 5
RETRY_BACKOFF_BASE: int = 2

# Resume checkpoint is persisted once every CHECKPOINT_EVERY batches.
CHECKPOINT_EVERY: int = 50
|
|
|
|
| |
|
|
def _retry_after_seconds(header_value: str | None, fallback: float) -> float:
    """Return the delay requested by a ``Retry-After`` header, or *fallback*.

    The header may be delay-seconds (possibly fractional) or an HTTP-date.
    Anything that is not a finite, non-negative number falls back to
    *fallback* instead of raising, so one unusual header cannot abort the run.
    """
    if header_value is None:
        return fallback
    try:
        seconds = float(header_value)
    except ValueError:
        return fallback
    # NaN fails both comparisons; inf fails the upper bound.
    return seconds if 0 <= seconds < float("inf") else fallback


async def fetch_one_batch(
    client: httpx.AsyncClient,
    arxiv_ids: list[str],
    api_key: str | None,
    batch_idx: int,
) -> list[tuple[str, str, bool]]:
    """
    Fetch references for a batch of arXiv IDs via the S2 batch endpoint.

    Args:
        client: Shared async HTTP client used for the POST request.
        arxiv_ids: Bare arXiv IDs; each is sent as ``arXiv:<id>``. Response
            entries are matched back to these IDs by position.
        api_key: Optional S2 API key, sent as the ``x-api-key`` header.
        batch_idx: Batch number, used only in log messages.

    Returns:
        (citing_arxiv_id, cited_arxiv_id, is_influential) tuples for every
        reference that has an arXiv ID. ``is_influential`` is always False on
        this path because the requested fields carry no influence flag.
        In-corpus filtering happens later.

        An empty list is returned when the batch is rejected (HTTP 400) or
        when every attempt fails (network errors, non-200 statuses, or
        malformed response bodies).
    """
    s2_ids = [f"arXiv:{aid}" for aid in arxiv_ids]

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["x-api-key"] = api_key

    url = f"{S2_BATCH_URL}?fields={S2_BATCH_FIELDS}"

    for attempt in range(MAX_RETRIES):
        backoff = RETRY_BACKOFF_BASE ** attempt
        try:
            resp = await client.post(
                url,
                json={"ids": s2_ids},
                headers=headers,
                timeout=30.0,
            )
        except httpx.TransportError as e:
            # TransportError is the base of timeouts, connect/read/write errors
            # and protocol errors (e.g. a dropped keep-alive connection), so any
            # transport-level failure is retried rather than crashing the run.
            print(f"  [batch {batch_idx}] Network error: {type(e).__name__}. Retrying (attempt {attempt+1}/{MAX_RETRIES})")
            await asyncio.sleep(backoff)
            continue

        if resp.status_code == 200:
            try:
                results = resp.json()
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                results = None
            if not isinstance(results, list):
                # A truncated or garbled 200 body is treated as transient.
                print(f"  [batch {batch_idx}] Malformed response body. Retrying (attempt {attempt+1}/{MAX_RETRIES})")
                await asyncio.sleep(backoff)
                continue

            edges: list[tuple[str, str, bool]] = []
            # Results are positional: entry i describes arxiv_ids[i].
            for citing_id, paper in zip(arxiv_ids, results):
                if not isinstance(paper, dict):
                    continue  # null entry: no record for this ID
                for ref in paper.get("references") or []:
                    if not isinstance(ref, dict):
                        continue
                    cited_arxiv = (ref.get("externalIds") or {}).get("ArXiv")
                    if cited_arxiv:
                        edges.append((citing_id, cited_arxiv, False))
            return edges

        elif resp.status_code == 429:
            retry_after = _retry_after_seconds(resp.headers.get("Retry-After"), backoff)
            print(f"  [batch {batch_idx}] Rate limited. Waiting {retry_after:g}s (attempt {attempt+1}/{MAX_RETRIES})")
            await asyncio.sleep(retry_after)

        elif resp.status_code == 400:
            # Retrying an identical rejected request would not help.
            print(f"  [batch {batch_idx}] Bad request (400). Skipping batch.")
            return []

        else:
            print(f"  [batch {batch_idx}] HTTP {resp.status_code}. Retrying (attempt {attempt+1}/{MAX_RETRIES})")
            await asyncio.sleep(backoff)

    print(f"  [batch {batch_idx}] FAILED after {MAX_RETRIES} attempts. Skipping.")
    return []
|
|
|
|
def _file_size(path: Path) -> int:
    """Return the size of *path* in bytes, or 0 if it does not exist yet."""
    return path.stat().st_size if path.exists() else 0


def _write_checkpoint(checkpoint_file: Path, state: dict) -> None:
    """Write *state* as JSON to *checkpoint_file* via a temp file + rename.

    The rename swaps in the new checkpoint in one step (atomic on POSIX
    filesystems), so an interrupt mid-write cannot leave a truncated,
    unparseable checkpoint that would break the next resume.
    """
    tmp_file = checkpoint_file.with_name(checkpoint_file.name + ".tmp")
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(state, f)
    tmp_file.replace(checkpoint_file)


def _load_resume_state(
    checkpoint_file: Path,
    edges_file: Path,
) -> tuple[int, list[tuple[str, str, bool]]]:
    """Return ``(start_batch, edges_collected_so_far)`` and resync the edges file.

    Edges are appended after every batch, but the checkpoint is only written
    every CHECKPOINT_EVERY batches. After an interrupt, the edges file can
    therefore hold edges from batches the checkpoint does not cover, plus a
    half-written last line. Those batches will be fetched again, so the file is
    truncated back to the size recorded in the checkpoint. With no checkpoint at
    all, nothing in the edges file is trusted and it is emptied.
    """
    start_batch = 0
    valid_bytes: int | None = 0
    if checkpoint_file.exists():
        with open(checkpoint_file, encoding="utf-8") as f:
            ckpt = json.load(f)
        start_batch = ckpt["next_batch"]
        # Checkpoints written before "edges_bytes" existed: keep the whole file.
        valid_bytes = ckpt.get("edges_bytes")

    if valid_bytes is not None and _file_size(edges_file) > valid_bytes:
        with open(edges_file, "r+b") as f:
            f.truncate(valid_bytes)

    edges: list[tuple[str, str, bool]] = []
    if edges_file.exists():
        with open(edges_file, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                row = json.loads(line)
                edges.append((row["citing"], row["cited"], row["influential"]))
    return start_batch, edges


async def fetch_all_batches(
    corpus_ids: list[str],
    api_key: str | None,
    checkpoint_dir: Path,
) -> list[tuple[str, str, bool]]:
    """
    Fetch citation edges for all corpus IDs using the S2 batch API.

    Supports checkpoint/resume:
    - ``edges_partial.jsonl`` receives each batch's edges as soon as they
      are fetched.
    - ``checkpoint.json`` stores the next batch index and the edges-file size.
      It is written every CHECKPOINT_EVERY batches and again at completion.
    A resumed run continues from the last checkpoint without duplicating edges.

    Args:
        corpus_ids: Bare arXiv IDs, split into batches of BATCH_SIZE.
        api_key: Optional S2 API key (also shortens the inter-request delay).
        checkpoint_dir: Existing directory holding the checkpoint files.

    Returns:
        Every edge collected, including those loaded from a previous run.
    """
    checkpoint_file = checkpoint_dir / "checkpoint.json"
    edges_file = checkpoint_dir / "edges_partial.jsonl"

    start_batch, all_edges = _load_resume_state(checkpoint_file, edges_file)
    if checkpoint_file.exists():
        print(f"Resuming from batch {start_batch} ({len(all_edges)} edges already collected)")

    batches = [corpus_ids[i : i + BATCH_SIZE] for i in range(0, len(corpus_ids), BATCH_SIZE)]
    total_batches = len(batches)
    print(f"Total: {len(corpus_ids)} papers -> {total_batches} batches of {BATCH_SIZE}")
    print(f"Starting from batch {start_batch}")

    # Pause between requests; shorter with an API key.
    # NOTE(review): confirm 0.5 s stays within the rate limit granted to your key.
    delay = 0.5 if api_key else 1.1

    async with httpx.AsyncClient() as client:
        with tqdm(range(start_batch, total_batches), initial=start_batch, total=total_batches) as pbar:
            for batch_idx in pbar:
                edges = await fetch_one_batch(client, batches[batch_idx], api_key, batch_idx)
                all_edges.extend(edges)

                # Persist this batch's edges immediately so an interrupt loses little work.
                with open(edges_file, "a", encoding="utf-8") as f:
                    f.writelines(
                        json.dumps({"citing": citing, "cited": cited, "influential": influential}) + "\n"
                        for citing, cited, influential in edges
                    )

                pbar.set_postfix({"edges": len(all_edges), "batch_edges": len(edges)})

                if (batch_idx + 1) % CHECKPOINT_EVERY == 0:
                    _write_checkpoint(
                        checkpoint_file,
                        {"next_batch": batch_idx + 1, "edges_bytes": _file_size(edges_file)},
                    )

                await asyncio.sleep(delay)

    _write_checkpoint(
        checkpoint_file,
        {"next_batch": total_batches, "status": "complete", "edges_bytes": _file_size(edges_file)},
    )

    return all_edges
|
|
|
|
| |
|
|
def _first_not_none(rec: dict, *keys: str):
    """Return the value of the first of *keys* whose value in *rec* is not None.

    Unlike ``a or b``, a legitimate falsy value such as corpus id 0 is kept.
    """
    for key in keys:
        value = rec.get(key)
        if value is not None:
            return value
    return None


def process_bulk_downloads(
    papers_file: str,
    citations_file: str,
    corpus_set: set[str],
) -> list[tuple[str, str, bool]]:
    """
    Process S2 bulk dataset downloads to extract in-corpus citation edges.

    Args:
        papers_file: Gzipped JSONL whose records carry a corpus ID
            (``corpusid``/``corpusId``) and external IDs
            (``externalids``/``externalIds`` with an ``ArXiv`` entry).
            NOTE(review): S2's ``paper-ids`` dataset appears to map
            sha -> corpusid without external IDs. The ``papers`` dataset is the
            one carrying ``externalids``; confirm which file is passed.
        citations_file: Gzipped JSONL of citation records
            (``citingcorpusid``/``citedcorpusid`` plus optional
            ``isinfluential``; camelCase variants are also accepted).
        corpus_set: arXiv IDs to keep; papers outside it are ignored.

    Returns:
        (citing_arxiv_id, cited_arxiv_id, is_influential) tuples where both
        papers map to an arXiv ID in *corpus_set*. Malformed lines are skipped.
    """
    print("Step 1/2: Building corpusid -> arxiv_id mapping from papers file...")
    corpusid_to_arxiv: dict[int, str] = {}
    # Explicit UTF-8: text-mode gzip.open otherwise uses the locale encoding,
    # which can fail on non-ASCII content (e.g. cp1252 on Windows).
    with gzip.open(papers_file, "rt", encoding="utf-8") as f:
        for line in tqdm(f, desc="Reading papers"):
            try:
                rec = json.loads(line)
                ext_ids = rec.get("externalids") or rec.get("externalIds") or {}
                arxiv_id = ext_ids.get("ArXiv")
                corpus_id = _first_not_none(rec, "corpusid", "corpusId")
                if arxiv_id and corpus_id is not None and arxiv_id in corpus_set:
                    corpusid_to_arxiv[int(corpus_id)] = arxiv_id
            except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
                # Bad JSON, non-dict fields, or non-numeric ids: skip the record.
                continue

    print(f"  Mapped {len(corpusid_to_arxiv)} corpus IDs to arXiv IDs in your corpus")

    if not corpusid_to_arxiv:
        # With no mapped papers no edge can match; skip scanning the citations file.
        print("  No corpus papers mapped - check that --bulk-papers records contain 'externalids'. Skipping citations.")
        print("  Found 0 in-corpus citation edges")
        return []

    print("Step 2/2: Filtering citation edges to in-corpus pairs...")
    edges: list[tuple[str, str, bool]] = []
    with gzip.open(citations_file, "rt", encoding="utf-8") as f:
        for line in tqdm(f, desc="Reading citations"):
            try:
                rec = json.loads(line)
                citing_cid = _first_not_none(rec, "citingcorpusid", "citingCorpusId")
                cited_cid = _first_not_none(rec, "citedcorpusid", "citedCorpusId")
                if citing_cid is None or cited_cid is None:
                    continue

                citing_arxiv = corpusid_to_arxiv.get(int(citing_cid))
                cited_arxiv = corpusid_to_arxiv.get(int(cited_cid))

                if citing_arxiv and cited_arxiv:
                    is_influential = rec.get("isinfluential") or rec.get("isInfluential")
                    edges.append((citing_arxiv, cited_arxiv, bool(is_influential)))
            except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
                continue

    print(f"  Found {len(edges)} in-corpus citation edges")
    return edges
|
|
|
|
| |
|
|
def filter_and_save(
    edges: list[tuple[str, str, bool]],
    corpus_set: set[str],
    output_path: str,
):
    """
    Filter edges to in-corpus pairs, deduplicate, save as parquet, print stats.

    Args:
        edges: (citing_arxiv_id, cited_arxiv_id, is_influential) tuples.
        corpus_set: arXiv IDs in the corpus; both endpoints must be members.
            Self-citations (citing == cited) are dropped.
        output_path: Destination parquet path (snappy compression).

    Duplicate (citing, cited) pairs are merged in first-seen order. The merged
    edge is influential if ANY of its duplicates was, so a non-influential copy
    seen first cannot hide an influential one.
    """
    print(f"Raw edges before filtering: {len(edges)}")

    filtered = [
        (citing, cited, influential)
        for citing, cited, influential in edges
        if citing != cited and citing in corpus_set and cited in corpus_set
    ]
    print(f"In-corpus edges (both sides in corpus): {len(filtered)}")

    # dict preserves insertion order, so output order follows first occurrence.
    merged: dict[tuple[str, str], bool] = {}
    for citing, cited, influential in filtered:
        key = (citing, cited)
        merged[key] = merged.get(key, False) or bool(influential)

    n_edges = len(merged)
    print(f"After deduplication: {n_edges}")

    citing_ids = [citing for citing, _ in merged]
    cited_ids = [cited for _, cited in merged]
    flags = list(merged.values())

    table = pa.table({
        "citing_arxiv_id": pa.array(citing_ids, type=pa.string()),
        "cited_arxiv_id": pa.array(cited_ids, type=pa.string()),
        "is_influential": pa.array(flags, type=pa.bool_()),
    })

    pq.write_table(table, output_path, compression="snappy")
    print(f"Saved {n_edges} citation edges to {output_path}")

    citing_papers = set(citing_ids)
    cited_papers = set(cited_ids)
    influential_count = sum(flags)
    print("\nStats:")
    print(f"  Unique citing papers: {len(citing_papers)}")
    print(f"  Unique cited papers:  {len(cited_papers)}")
    print(f"  Unique papers total:  {len(citing_papers | cited_papers)}")
    print(f"  Avg references per citing paper: {n_edges / max(len(citing_papers), 1):.1f}")
    print(f"  Influential citations: {influential_count} ({100*influential_count/max(n_edges,1):.1f}%)")
|
|
|
|
| |
|
|
def load_corpus_ids(path: str) -> list[str]:
    """Load arXiv IDs from a text file (one per line).

    - Blank lines and lines starting with "#" are skipped.
    - An "arXiv:" prefix is stripped in any letter case, along with
      surrounding whitespace.
    - Duplicate IDs are dropped, keeping first-seen order, so no ID is
      requested from the API twice.
    - The file is read as UTF-8. A leading byte-order mark is removed so it
      cannot corrupt the first ID.
    """
    ids: list[str] = []
    seen: set[str] = set()
    duplicates = 0
    with open(path, encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            if line[:6].lower() == "arxiv:":
                line = line[6:].strip()
                if not line:
                    continue  # a bare prefix carries no ID
            if line in seen:
                duplicates += 1
                continue
            seen.add(line)
            ids.append(line)
    print(f"Loaded {len(ids)} arXiv IDs from {path}")
    if duplicates:
        print(f"  Skipped {duplicates} duplicate IDs")
    return ids
|
|
|
|
def main():
    """Command-line entry point: collect citation edges and write the parquet file.

    If both --bulk-papers and --bulk-citations are given, edges come from the
    local S2 bulk files. Otherwise they come from the S2 batch API, using the
    key from --api-key or the S2_API_KEY environment variable.
    """
    parser = argparse.ArgumentParser(
        description="Fetch citation edges from Semantic Scholar for ResearchIT corpus"
    )
    parser.add_argument(
        "--corpus-file", required=True,
        help="Text file with one arXiv ID per line (e.g. arxiv_ids.txt)"
    )
    parser.add_argument(
        "--output", default="citations.parquet",
        help="Output parquet file path (default: citations.parquet)"
    )
    parser.add_argument(
        "--api-key", default=None,
        help="Semantic Scholar API key (optional, speeds up rate limit)"
    )
    parser.add_argument(
        "--bulk-papers", default=None,
        help="Path to gzipped S2 bulk file with corpusid + externalids records, "
             "e.g. the 'papers' dataset (use with --bulk-citations)"
    )
    parser.add_argument(
        "--bulk-citations", default=None,
        help="Path to S2 bulk citations.jsonl.gz (use with --bulk-papers)"
    )
    parser.add_argument(
        "--checkpoint-dir", default="./citation_checkpoint",
        help="Directory for checkpoint files (batch API mode)"
    )
    parser.add_argument(
        "--max-papers", type=int, default=None,
        help="Limit to first N papers (for testing)"
    )

    args = parser.parse_args()

    # Bulk mode needs both files; with only one, the run would otherwise fall
    # back silently to the much slower batch API.
    if bool(args.bulk_papers) != bool(args.bulk_citations):
        parser.error("--bulk-papers and --bulk-citations must be used together")
    # A negative value would slice from the end of the list.
    if args.max_papers is not None and args.max_papers <= 0:
        parser.error("--max-papers must be a positive integer")

    corpus_ids = load_corpus_ids(args.corpus_file)
    if args.max_papers:
        corpus_ids = corpus_ids[:args.max_papers]
        print(f"  Limited to {len(corpus_ids)} papers (--max-papers)")

    corpus_set = set(corpus_ids)

    if args.bulk_papers and args.bulk_citations:
        print("\n=== BULK DOWNLOAD PATH ===")
        edges = process_bulk_downloads(args.bulk_papers, args.bulk_citations, corpus_set)
    else:
        print("\n=== BATCH API PATH ===")
        if not args.api_key:
            # Fall back to the environment so the key need not appear in shell history.
            args.api_key = os.environ.get("S2_API_KEY")
        if args.api_key:
            print(f"Using API key: {args.api_key[:8]}...")
        else:
            print("No API key - using unauthenticated rate (1 req/s)")
            print("Get a free key at: https://www.semanticscholar.org/product/api#Partner-Form")

        checkpoint_dir = Path(args.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        edges = asyncio.run(fetch_all_batches(corpus_ids, args.api_key, checkpoint_dir))

    filter_and_save(edges, corpus_set, args.output)

    # ASCII-only: the original emoji was corrupted into a line break (a
    # SyntaxError) and can also fail to print on non-UTF-8 consoles.
    print(f"\nDone! Citation edges saved to: {args.output}")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
|
|