# researchit-reranker-phase6 / scripts / 01_fetch_citation_edges.py
# Author: siddhm11
# Commit c82215c (verified): "Add 01_fetch_citation_edges.py"
"""
Step 1: Fetch citation edges from Semantic Scholar API.
Produces: citations.parquet → (citing_arxiv_id, cited_arxiv_id)
where BOTH IDs exist in the ResearchIT Qdrant corpus.
Usage:
# Option A: Batch API (no API key needed, slower, ~1-2 hours for 1.6M papers)
python 01_fetch_citation_edges.py --corpus-file arxiv_ids.txt --output citations.parquet
# Option B: Batch API with API key (faster, ~30-60 min)
python 01_fetch_citation_edges.py --corpus-file arxiv_ids.txt --output citations.parquet --api-key YOUR_KEY
# Option C: If you already have the S2 bulk datasets downloaded:
python 01_fetch_citation_edges.py --bulk-papers paper-ids.jsonl.gz --bulk-citations citations.jsonl.gz \
--corpus-file arxiv_ids.txt --output citations.parquet
Prerequisites:
- arxiv_ids.txt: one arXiv ID per line (e.g. "2303.14957"), exported from Qdrant/Turso
- pip install httpx pyarrow tqdm
Output schema:
citing_arxiv_id (string) — the paper that contains the citation
cited_arxiv_id (string) — the paper being cited
is_influential (bool) — S2's influential citation flag (bulk path only; always False on the batch API path)
Author: ResearchIT ML Pipeline — Phase 6, Step 1
"""
from __future__ import annotations
import argparse
import asyncio
import gzip
import json
import os
import sys
import time
from pathlib import Path
import httpx
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm
# ── Constants ────────────────────────────────────────────────────────────────
# Semantic Scholar Graph API: batch paper-lookup endpoint and the fields
# requested for every paper (its own external IDs plus its references' IDs).
S2_BATCH_URL = "https://api.semanticscholar.org/graph/v1/paper/batch"
S2_BATCH_FIELDS = ",".join(("externalIds", "references.externalIds"))

# Request sizing and retry policy.
BATCH_SIZE = 500  # S2 hard limit on IDs per batch request
MAX_RETRIES = 5  # attempts per batch before giving up on it
RETRY_BACKOFF_BASE = 2  # retries wait RETRY_BACKOFF_BASE ** attempt seconds

# Resume support.
CHECKPOINT_EVERY = 50  # persist a checkpoint after every N batches
# ── Batch API Path ───────────────────────────────────────────────────────────
def _retry_after_seconds(header_value: str | None, attempt: int) -> float:
    """Return how many seconds to wait before retrying a rate-limited request.

    Uses the ``Retry-After`` header when it is a finite, non-negative number of
    seconds (integer or fractional). When the header is missing or in another
    form (e.g. an HTTP-date), falls back to ``RETRY_BACKOFF_BASE ** attempt``
    instead of raising ``ValueError`` and aborting the whole run.
    """
    fallback = float(RETRY_BACKOFF_BASE ** attempt)
    if header_value is None:
        return fallback
    try:
        seconds = float(header_value)
    except ValueError:
        return fallback
    # `0 <= x < inf` is False for negatives, NaN and infinity alike.
    if not 0.0 <= seconds < float("inf"):
        return fallback
    return seconds


def _edges_from_batch_response(
    arxiv_ids: list[str],
    results: list,
) -> list[tuple[str, str, bool]]:
    """Convert a parsed S2 batch response into (citing, cited, False) edges.

    Each response entry is paired with the requested arXiv ID at the same
    position. Entries that are not objects (S2 returns null for IDs it cannot
    resolve) are skipped. Only references carrying an ``ArXiv`` external ID
    produce an edge.
    """
    edges: list[tuple[str, str, bool]] = []
    for citing_id, paper in zip(arxiv_ids, results):
        if not isinstance(paper, dict):
            continue
        for ref in paper.get("references") or []:
            if not isinstance(ref, dict):
                continue
            ext_ids = ref.get("externalIds") or {}
            cited_arxiv = ext_ids.get("ArXiv")
            if cited_arxiv:
                edges.append((citing_id, cited_arxiv, False))
    return edges


async def fetch_one_batch(
    client: httpx.AsyncClient,
    arxiv_ids: list[str],
    api_key: str | None,
    batch_idx: int,
) -> list[tuple[str, str, bool]]:
    """
    Fetch references for a batch of arXiv IDs via the S2 batch endpoint.

    Args:
        client: Shared async HTTP client used for the POST request.
        arxiv_ids: Bare arXiv IDs (e.g. "2303.14957"); sent as "arXiv:<id>".
        api_key: Optional S2 API key, sent as the ``x-api-key`` header.
        batch_idx: Index of this batch; used only in log messages.

    Returns:
        List of (citing_arxiv_id, cited_arxiv_id, is_influential) tuples.
        Only edges where the cited paper has an arXiv ID are returned, and
        ``is_influential`` is always False on this path (the requested fields
        do not include it). In-corpus filtering happens later.
        An empty list is returned when the batch is rejected with HTTP 400 or
        still failing after MAX_RETRIES attempts.
    """
    # Format IDs for S2: "arXiv:2303.14957"
    s2_ids = [f"arXiv:{aid}" for aid in arxiv_ids]
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["x-api-key"] = api_key
    url = f"{S2_BATCH_URL}?fields={S2_BATCH_FIELDS}"

    for attempt in range(MAX_RETRIES):
        attempt_label = f"attempt {attempt+1}/{MAX_RETRIES}"
        try:
            resp = await client.post(
                url,
                json={"ids": s2_ids},
                headers=headers,
                timeout=30.0,
            )
        except httpx.TransportError as e:
            # Base class of timeouts, connect/read/write and protocol errors,
            # so every transient transport failure is retried, not just three.
            print(f" [batch {batch_idx}] Network error: {type(e).__name__}. Retrying ({attempt_label})")
            await asyncio.sleep(RETRY_BACKOFF_BASE ** attempt)
            continue

        if resp.status_code == 200:
            try:
                results = resp.json()
            except ValueError:  # includes json.JSONDecodeError
                results = None
            if isinstance(results, list):
                return _edges_from_batch_response(arxiv_ids, results)
            print(f" [batch {batch_idx}] Malformed response body. Retrying ({attempt_label})")
            await asyncio.sleep(RETRY_BACKOFF_BASE ** attempt)
        elif resp.status_code == 429:
            wait = _retry_after_seconds(resp.headers.get("Retry-After"), attempt)
            print(f" [batch {batch_idx}] Rate limited. Waiting {wait:g}s ({attempt_label})")
            await asyncio.sleep(wait)
        elif resp.status_code == 400:
            print(f" [batch {batch_idx}] Bad request (400). Skipping batch.")
            return []
        else:
            print(f" [batch {batch_idx}] HTTP {resp.status_code}. Retrying ({attempt_label})")
            await asyncio.sleep(RETRY_BACKOFF_BASE ** attempt)

    print(f" [batch {batch_idx}] FAILED after {MAX_RETRIES} attempts. Skipping.")
    return []
def _load_partial_edges(edges_file: Path) -> list[tuple[str, str, bool]]:
    """Read edges previously appended to the partial-edges JSONL file.

    Lines that cannot be parsed (e.g. a last line cut short by a crash) are
    skipped with a warning instead of aborting the resume.
    """
    edges: list[tuple[str, str, bool]] = []
    skipped = 0
    with open(edges_file, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                row = json.loads(line)
                edges.append((row["citing"], row["cited"], row["influential"]))
            except (json.JSONDecodeError, KeyError, TypeError):
                skipped += 1
    if skipped:
        print(f"  Warning: skipped {skipped} unreadable line(s) in {edges_file}")
    return edges


def _write_checkpoint(checkpoint_file: Path, state: dict) -> None:
    """Write checkpoint JSON atomically (temp file, then rename over the old one).

    A crash mid-write therefore never leaves a half-written checkpoint.json.
    """
    tmp = checkpoint_file.with_name(checkpoint_file.name + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(state, f)
    tmp.replace(checkpoint_file)


async def fetch_all_batches(
    corpus_ids: list[str],
    api_key: str | None,
    checkpoint_dir: Path,
) -> list[tuple[str, str, bool]]:
    """
    Fetch citation edges for all corpus IDs using the S2 batch API.

    Supports checkpoint/resume via two files in ``checkpoint_dir``:
      - ``edges_partial.jsonl``: edges appended after every batch.
      - ``checkpoint.json``: ``next_batch`` plus the partial file's size in
        bytes (``edges_bytes``) at that moment, written every CHECKPOINT_EVERY
        batches and once at the end.
    On resume, edges appended after the last checkpoint are truncated away
    (those batches are re-fetched), so a crash does not duplicate edges.
    Checkpoints written without ``edges_bytes`` are resumed as before.

    Returns:
        All collected (citing_arxiv_id, cited_arxiv_id, is_influential) tuples,
        including those loaded from a previous run.
    """
    checkpoint_file = checkpoint_dir / "checkpoint.json"
    edges_file = checkpoint_dir / "edges_partial.jsonl"
    all_edges: list[tuple[str, str, bool]] = []
    start_batch = 0

    if checkpoint_file.exists():
        with open(checkpoint_file, encoding="utf-8") as f:
            ckpt = json.load(f)
        start_batch = ckpt["next_batch"]
        if edges_file.exists():
            saved_bytes = ckpt.get("edges_bytes")
            if saved_bytes is not None and edges_file.stat().st_size > saved_bytes:
                # Edges written after the last checkpoint belong to batches
                # that are about to be fetched again; drop them.
                with open(edges_file, "r+b") as f:
                    f.truncate(saved_bytes)
            all_edges = _load_partial_edges(edges_file)
        print(f"Resuming from batch {start_batch} ({len(all_edges)} edges already collected)")
    elif edges_file.exists():
        # Without a checkpoint every batch is re-fetched from 0, so leftover
        # partial edges would only be duplicated on a later resume.
        print(f"No checkpoint found; discarding stale {edges_file}")
        edges_file.unlink()

    batches = [corpus_ids[i : i + BATCH_SIZE] for i in range(0, len(corpus_ids), BATCH_SIZE)]
    total_batches = len(batches)
    print(f"Total: {len(corpus_ids)} papers → {total_batches} batches of {BATCH_SIZE}")
    print(f"Starting from batch {start_batch}")

    # Pause between batches: 1.1 s without a key, 0.5 s with one.
    delay = 0.5 if api_key else 1.1

    async with httpx.AsyncClient() as client:
        pbar = tqdm(range(start_batch, total_batches), initial=start_batch, total=total_batches)
        for batch_idx in pbar:
            edges = await fetch_one_batch(client, batches[batch_idx], api_key, batch_idx)
            all_edges.extend(edges)

            # Append this batch's edges (also creates the file on first use).
            with open(edges_file, "a", encoding="utf-8") as f:
                f.writelines(
                    json.dumps({"citing": citing, "cited": cited, "influential": influential}) + "\n"
                    for citing, cited, influential in edges
                )
            pbar.set_postfix({"edges": len(all_edges), "batch_edges": len(edges)})

            if (batch_idx + 1) % CHECKPOINT_EVERY == 0:
                _write_checkpoint(
                    checkpoint_file,
                    {"next_batch": batch_idx + 1, "edges_bytes": edges_file.stat().st_size},
                )
            await asyncio.sleep(delay)

    final_bytes = edges_file.stat().st_size if edges_file.exists() else 0
    _write_checkpoint(
        checkpoint_file,
        {"next_batch": total_batches, "status": "complete", "edges_bytes": final_bytes},
    )
    return all_edges
# ── Bulk Download Path ───────────────────────────────────────────────────────
def process_bulk_downloads(
    papers_file: str,
    citations_file: str,
    corpus_set: set[str],
) -> list[tuple[str, str, bool]]:
    """
    Process S2 bulk dataset downloads to extract in-corpus citation edges.

    Args:
        papers_file: gzipped JSONL whose records carry a corpus ID
            (``corpusid`` or ``corpusId``) and an ``externalids``/``externalIds``
            object with an ``ArXiv`` key.
            NOTE(review): in the S2 Datasets API these fields live in the
            ``papers`` dataset, while ``paper-ids`` maps sha -> corpusid only;
            passing the latter yields an empty mapping. Confirm the input file.
        citations_file: gzipped JSONL of citation records with
            ``citingcorpusid``/``citedcorpusid`` (either casing) and an
            optional ``isinfluential``/``isInfluential`` flag.
        corpus_set: arXiv IDs in the ResearchIT corpus; only these are mapped.

    Returns:
        (citing_arxiv_id, cited_arxiv_id, is_influential) tuples where both
        endpoints are in ``corpus_set``. Malformed lines are skipped.
    """
    print("Step 1/2: Building corpusid → arxiv_id mapping from paper-ids...")
    corpusid_to_arxiv: dict[int, str] = {}
    # Dataset files are UTF-8 JSONL; don't depend on the platform's locale.
    with gzip.open(papers_file, "rt", encoding="utf-8") as f:
        for line in tqdm(f, desc="Reading paper-ids"):
            try:
                rec = json.loads(line)
                ext_ids = rec.get("externalids") or rec.get("externalIds") or {}
                arxiv_id = ext_ids.get("ArXiv")
                corpus_id = rec.get("corpusid") or rec.get("corpusId")
                if arxiv_id and corpus_id and arxiv_id in corpus_set:
                    corpusid_to_arxiv[int(corpus_id)] = arxiv_id
            except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
                # Malformed record (bad JSON, non-object fields, bad corpus id).
                continue
    print(f" Mapped {len(corpusid_to_arxiv)} corpus IDs to arXiv IDs in your corpus")

    if not corpusid_to_arxiv:
        # No edge can match an empty mapping; skip scanning the citations file.
        print(" Warning: no corpus papers were mapped; check that the papers file "
              "has corpusid + externalids.ArXiv fields. Skipping citations.")
        return []

    print("Step 2/2: Filtering citation edges to in-corpus pairs...")
    edges: list[tuple[str, str, bool]] = []
    with gzip.open(citations_file, "rt", encoding="utf-8") as f:
        for line in tqdm(f, desc="Reading citations"):
            try:
                rec = json.loads(line)
                citing_cid = rec.get("citingcorpusid") or rec.get("citingCorpusId")
                cited_cid = rec.get("citedcorpusid") or rec.get("citedCorpusId")
                is_influential = rec.get("isinfluential", False) or rec.get("isInfluential", False)
                citing_arxiv = corpusid_to_arxiv.get(int(citing_cid)) if citing_cid else None
                cited_arxiv = corpusid_to_arxiv.get(int(cited_cid)) if cited_cid else None
                if citing_arxiv and cited_arxiv:
                    edges.append((citing_arxiv, cited_arxiv, bool(is_influential)))
            except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
                continue
    print(f" Found {len(edges)} in-corpus citation edges")
    return edges
# ── Filter & Save ────────────────────────────────────────────────────────────
def filter_and_save(
    edges: list[tuple[str, str, bool]],
    corpus_set: set[str],
    output_path: str,
):
    """
    Filter edges to in-corpus pairs, deduplicate, and save as parquet.

    An edge is kept when both endpoints are in ``corpus_set`` and it is not a
    self-citation. Duplicate (citing, cited) pairs are merged into one row,
    in first-seen order; the merged row is marked influential if ANY of its
    duplicates was. The parent directory of ``output_path`` is created if
    needed. Summary statistics are printed afterwards.

    Args:
        edges: (citing_arxiv_id, cited_arxiv_id, is_influential) tuples.
        corpus_set: arXiv IDs in the ResearchIT corpus.
        output_path: Destination parquet file (snappy-compressed).
    """
    print(f"Raw edges before filtering: {len(edges)}")

    # One pass: filter and merge duplicates (dicts keep first-insertion order).
    in_corpus = 0
    merged: dict[tuple[str, str], bool] = {}
    for citing, cited, influential in edges:
        if citing == cited or citing not in corpus_set or cited not in corpus_set:
            continue
        in_corpus += 1
        key = (citing, cited)
        merged[key] = merged.get(key, False) or bool(influential)
    print(f"In-corpus edges (both sides in corpus): {in_corpus}")
    print(f"After deduplication: {len(merged)}")

    citing_ids = [citing for citing, _ in merged]
    cited_ids = [cited for _, cited in merged]
    flags = list(merged.values())

    # Save as parquet
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    table = pa.table({
        "citing_arxiv_id": pa.array(citing_ids, type=pa.string()),
        "cited_arxiv_id": pa.array(cited_ids, type=pa.string()),
        "is_influential": pa.array(flags, type=pa.bool_()),
    })
    pq.write_table(table, output_path, compression="snappy")
    print(f"Saved {len(merged)} citation edges to {output_path}")

    # Print stats
    citing_papers = set(citing_ids)
    cited_papers = set(cited_ids)
    print("\nStats:")
    print(f" Unique citing papers: {len(citing_papers)}")
    print(f" Unique cited papers: {len(cited_papers)}")
    print(f" Unique papers total: {len(citing_papers | cited_papers)}")
    print(f" Avg references per citing paper: {len(merged) / max(len(citing_papers), 1):.1f}")
    influential_count = sum(flags)
    print(f" Influential citations: {influential_count} ({100*influential_count/max(len(merged),1):.1f}%)")
# ── Main ─────────────────────────────────────────────────────────────────────
def load_corpus_ids(path: str) -> list[str]:
    """
    Load arXiv IDs from a text file (one per line).

    Blank lines and lines starting with ``#`` are ignored. An optional
    ``arXiv:`` prefix is stripped in any letter case (``arXiv:``, ``ARXIV:``,
    ``arxiv:``, ...) together with surrounding whitespace. A UTF-8 byte-order
    mark at the start of the file (common in Windows exports) is ignored
    rather than being glued onto the first ID.

    Version suffixes (e.g. ``v2``) are kept as-is.
    NOTE(review): S2 ``ArXiv`` external IDs are usually versionless, so
    versioned IDs here would not match cited IDs; confirm the export format.

    Returns:
        IDs in file order, duplicates included.
    """
    prefix = "arxiv:"
    ids: list[str] = []
    with open(path, encoding="utf-8-sig") as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            # Handle various formats: "2303.14957", "arXiv:2303.14957", etc.
            if line[: len(prefix)].lower() == prefix:
                line = line[len(prefix):].strip()
            if line:  # a bare "arXiv:" line carries no ID
                ids.append(line)
    print(f"Loaded {len(ids)} arXiv IDs from {path}")
    return ids
def main():
    """
    Command-line entry point: parse arguments, collect edges, write parquet.

    Uses the bulk-download path when both --bulk-papers and --bulk-citations
    are given, otherwise the S2 batch API path with checkpoint/resume.
    Supplying only one of the two bulk flags is a usage error, instead of
    silently falling back to the much slower API path.
    """
    parser = argparse.ArgumentParser(
        description="Fetch citation edges from Semantic Scholar for ResearchIT corpus"
    )
    parser.add_argument(
        "--corpus-file", required=True,
        help="Text file with one arXiv ID per line (e.g. arxiv_ids.txt)"
    )
    parser.add_argument(
        "--output", default="citations.parquet",
        help="Output parquet file path (default: citations.parquet)"
    )
    parser.add_argument(
        "--api-key", default=None,
        help="Semantic Scholar API key (optional, speeds up rate limit)"
    )
    parser.add_argument(
        "--bulk-papers", default=None,
        help="Path to S2 bulk paper-ids.jsonl.gz (use bulk download path)"
    )
    parser.add_argument(
        "--bulk-citations", default=None,
        help="Path to S2 bulk citations.jsonl.gz (use bulk download path)"
    )
    parser.add_argument(
        "--checkpoint-dir", default="./citation_checkpoint",
        help="Directory for checkpoint files (batch API mode)"
    )
    parser.add_argument(
        "--max-papers", type=int, default=None,
        help="Limit to first N papers (for testing)"
    )
    args = parser.parse_args()

    # Fail fast: a lone bulk flag almost certainly means a typo, not "use the API".
    if bool(args.bulk_papers) != bool(args.bulk_citations):
        parser.error("--bulk-papers and --bulk-citations must be given together")

    # Load corpus
    corpus_ids = load_corpus_ids(args.corpus_file)
    if args.max_papers:
        corpus_ids = corpus_ids[:args.max_papers]
        print(f" Limited to {len(corpus_ids)} papers (--max-papers)")
    corpus_set = set(corpus_ids)

    # Choose path
    if args.bulk_papers and args.bulk_citations:
        print("\n=== BULK DOWNLOAD PATH ===")
        edges = process_bulk_downloads(args.bulk_papers, args.bulk_citations, corpus_set)
    else:
        print("\n=== BATCH API PATH ===")
        if not args.api_key:
            # Fall back to the environment variable
            args.api_key = os.environ.get("S2_API_KEY")
        if args.api_key:
            # Never echo any part of the secret to the console or logs.
            print("Using API key: (hidden)")
        else:
            print("No API key — using unauthenticated rate (1 req/s)")
            print("Get a free key at: https://www.semanticscholar.org/product/api#Partner-Form")
        checkpoint_dir = Path(args.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        edges = asyncio.run(fetch_all_batches(corpus_ids, args.api_key, checkpoint_dir))

    # Filter to in-corpus and save
    filter_and_save(edges, corpus_set, args.output)
    print(f"\n✅ Done! Citation edges saved to: {args.output}")


if __name__ == "__main__":
    main()