siddhm11 committed on
Commit
c82215c
·
verified ·
1 Parent(s): dc5b2e5

Add 01_fetch_citation_edges.py

Browse files
Files changed (1) hide show
  1. scripts/01_fetch_citation_edges.py +388 -0
scripts/01_fetch_citation_edges.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Step 1: Fetch citation edges from Semantic Scholar API.
3
+
4
+ Produces: citations.parquet → (citing_arxiv_id, cited_arxiv_id)
5
+ where BOTH IDs exist in the ResearchIT Qdrant corpus.
6
+
7
+ Usage:
8
+ # Option A: Batch API (no API key needed, slower, ~1-2 hours for 1.6M papers)
9
+ python 01_fetch_citation_edges.py --corpus-file arxiv_ids.txt --output citations.parquet
10
+
11
+ # Option B: Batch API with API key (faster, ~30-60 min)
12
+ python 01_fetch_citation_edges.py --corpus-file arxiv_ids.txt --output citations.parquet --api-key YOUR_KEY
13
+
14
+ # Option C: If you already have the S2 bulk datasets downloaded:
15
+ python 01_fetch_citation_edges.py --bulk-papers paper-ids.jsonl.gz --bulk-citations citations.jsonl.gz \
16
+ --corpus-file arxiv_ids.txt --output citations.parquet
17
+
18
+ Prerequisites:
19
+ - arxiv_ids.txt: one arXiv ID per line (e.g. "2303.14957"), exported from Qdrant/Turso
20
+ - pip install httpx pyarrow tqdm
21
+
22
+ Output schema:
23
+ citing_arxiv_id (string) — the paper that contains the citation
24
+ cited_arxiv_id (string) — the paper being cited
25
+ is_influential (bool) — S2's influential citation flag (if available)
26
+
27
+ Author: ResearchIT ML Pipeline — Phase 6, Step 1
28
+ """
29
+ from __future__ import annotations
30
+
31
+ import argparse
32
+ import asyncio
33
+ import gzip
34
+ import json
35
+ import os
36
+ import sys
37
+ import time
38
+ from pathlib import Path
39
+
40
+ import httpx
41
+ import pyarrow as pa
42
+ import pyarrow.parquet as pq
43
+ from tqdm import tqdm
44
+
45
+
46
+ # ── Constants ────────────────────────────────────────────────────────────────
47
+
48
# Semantic Scholar Graph API endpoint that the batch fetcher POSTs paper IDs to.
S2_BATCH_URL = "https://api.semanticscholar.org/graph/v1/paper/batch"
# Fields requested per paper: its own external IDs and those of each reference.
S2_BATCH_FIELDS = "externalIds,references.externalIds"
BATCH_SIZE = 500  # S2 hard limit
MAX_RETRIES = 5  # per batch
RETRY_BACKOFF_BASE = 2  # exponential backoff base (seconds)
CHECKPOINT_EVERY = 50  # save checkpoint every N batches
54
+
55
+
56
+ # ── Batch API Path ───────────────────────────────────────────────────────────
57
+
58
def _parse_retry_after(header_value: str | None, fallback: float) -> float:
    """Return the delay in seconds requested by a ``Retry-After`` header.

    Falls back to *fallback* when the header is missing or is not a finite,
    non-negative number (for example when the server sends an HTTP-date).
    """
    if header_value is None:
        return fallback
    try:
        seconds = float(header_value)
    except ValueError:
        return fallback
    # The chained comparison also rejects NaN and infinity.
    return seconds if 0 <= seconds < float("inf") else fallback


async def fetch_one_batch(
    client: httpx.AsyncClient,
    arxiv_ids: list[str],
    api_key: str | None,
    batch_idx: int,
) -> list[tuple[str, str, bool]]:
    """
    Fetch references for a batch of arXiv IDs via the S2 batch endpoint.

    Args:
        client: Shared async HTTP client used to issue the POST request.
        arxiv_ids: Bare arXiv IDs (no "arXiv:" prefix) of the citing papers.
        api_key: Optional API key, sent as the ``x-api-key`` header.
        batch_idx: Index of this batch; used only in log messages.

    Returns:
        A list of (citing_arxiv_id, cited_arxiv_id, is_influential) tuples.
        Only edges whose cited paper has an arXiv ID are returned, and the
        influential flag is always False on this path because the requested
        fields do not include it. In-corpus filtering happens later.
        An empty list is returned when the batch is rejected with HTTP 400 or
        still fails after MAX_RETRIES attempts.
    """
    # Format IDs for S2: "arXiv:2303.14957"
    s2_ids = [f"arXiv:{aid}" for aid in arxiv_ids]

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["x-api-key"] = api_key

    url = f"{S2_BATCH_URL}?fields={S2_BATCH_FIELDS}"

    for attempt in range(MAX_RETRIES):
        backoff = RETRY_BACKOFF_BASE ** attempt

        try:
            resp = await client.post(
                url,
                json={"ids": s2_ids},
                headers=headers,
                timeout=30.0,
            )
        except httpx.TransportError as e:
            # TransportError is the common base of timeouts, connect/read/write
            # failures and protocol errors, all of which are worth retrying.
            print(f" [batch {batch_idx}] Network error: {type(e).__name__}. Retrying (attempt {attempt+1}/{MAX_RETRIES})")
            await asyncio.sleep(backoff)
            continue

        if resp.status_code == 200:
            try:
                results = resp.json()
            except ValueError:
                results = None
            if not isinstance(results, list):
                # A truncated or unexpected body is treated as a transient failure.
                print(f" [batch {batch_idx}] Malformed response body. Retrying (attempt {attempt+1}/{MAX_RETRIES})")
                await asyncio.sleep(backoff)
                continue

            edges = []
            # Results are paired with the request IDs by position; entries
            # that are None are skipped.
            for citing_id, paper in zip(arxiv_ids, results):
                if not isinstance(paper, dict):
                    continue
                for ref in paper.get("references") or []:
                    ext_ids = (ref or {}).get("externalIds") or {}
                    cited_arxiv = ext_ids.get("ArXiv")
                    if cited_arxiv:
                        edges.append((citing_id, cited_arxiv, False))
            return edges

        if resp.status_code == 429:
            retry_after = _parse_retry_after(resp.headers.get("Retry-After"), float(backoff))
            print(f" [batch {batch_idx}] Rate limited. Waiting {retry_after:g}s (attempt {attempt+1}/{MAX_RETRIES})")
            await asyncio.sleep(retry_after)
        elif resp.status_code == 400:
            # Retrying an invalid request cannot succeed.
            print(f" [batch {batch_idx}] Bad request (400). Skipping batch.")
            return []
        else:
            print(f" [batch {batch_idx}] HTTP {resp.status_code}. Retrying (attempt {attempt+1}/{MAX_RETRIES})")
            await asyncio.sleep(backoff)

    print(f" [batch {batch_idx}] FAILED after {MAX_RETRIES} attempts. Skipping.")
    return []
123
+
124
+
125
def _write_checkpoint(
    checkpoint_file: Path,
    edges_file: Path,
    next_batch: int,
    complete: bool = False,
) -> None:
    """Atomically record the next batch to fetch and the committed size of the edges file."""
    state = {
        "next_batch": next_batch,
        # Byte length of the partial edges file that belongs to batches < next_batch.
        "edges_bytes": edges_file.stat().st_size if edges_file.exists() else 0,
    }
    if complete:
        state["status"] = "complete"
    # Write to a sibling temp file and rename so a crash never leaves a
    # half-written checkpoint behind.
    tmp_file = checkpoint_file.with_name(checkpoint_file.name + ".tmp")
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(state, f)
    tmp_file.replace(checkpoint_file)


async def fetch_all_batches(
    corpus_ids: list[str],
    api_key: str | None,
    checkpoint_dir: Path,
) -> list[tuple[str, str, bool]]:
    """
    Fetch citation edges for all corpus IDs using the S2 batch API.

    Supports checkpoint/resume: edges are appended to ``edges_partial.jsonl``
    after every batch and ``checkpoint.json`` is written every
    CHECKPOINT_EVERY batches. The checkpoint stores the byte length of the
    edges file at that moment, so on resume any edges written after the last
    checkpoint are truncated away before those batches are fetched again.
    This keeps the partial file free of duplicated batches.

    Args:
        corpus_ids: Bare arXiv IDs to fetch references for, in a stable order.
        api_key: Optional S2 API key; when set, a shorter delay is used
            between requests.
        checkpoint_dir: Existing directory holding the checkpoint files.

    Returns:
        All (citing_arxiv_id, cited_arxiv_id, is_influential) tuples collected,
        including those restored from a previous run.
    """
    checkpoint_file = checkpoint_dir / "checkpoint.json"
    edges_file = checkpoint_dir / "edges_partial.jsonl"
    all_edges: list[tuple[str, str, bool]] = []
    start_batch = 0

    if checkpoint_file.exists():
        with open(checkpoint_file, encoding="utf-8") as f:
            ckpt = json.load(f)
        start_batch = ckpt["next_batch"]

        # Load previously saved edges
        if edges_file.exists():
            committed = ckpt.get("edges_bytes")
            if committed is not None and edges_file.stat().st_size > committed:
                # Drop edges from batches that will be fetched again.
                with open(edges_file, "r+b") as f:
                    f.truncate(committed)
            with open(edges_file, encoding="utf-8") as f:
                for line in f:
                    try:
                        row = json.loads(line)
                    except json.JSONDecodeError:
                        # Half-written line left by an interrupted run
                        # (only possible with checkpoints lacking "edges_bytes").
                        continue
                    all_edges.append((row["citing"], row["cited"], row["influential"]))
        print(f"Resuming from batch {start_batch} ({len(all_edges)} edges already collected)")
    elif edges_file.exists():
        # No checkpoint means we start from batch 0; leftover edges from an
        # interrupted run would otherwise be duplicated by the appends below.
        edges_file.unlink()

    # Split into batches
    batches = [
        corpus_ids[i : i + BATCH_SIZE]
        for i in range(0, len(corpus_ids), BATCH_SIZE)
    ]

    total_batches = len(batches)
    print(f"Total: {len(corpus_ids)} papers → {total_batches} batches of {BATCH_SIZE}")
    print(f"Starting from batch {start_batch}")

    # Rate limiting: 1 req/s without key, slightly faster with key
    delay = 0.5 if api_key else 1.1

    async with httpx.AsyncClient() as client:
        pbar = tqdm(range(start_batch, total_batches), initial=start_batch, total=total_batches)
        for batch_idx in pbar:
            edges = await fetch_one_batch(client, batches[batch_idx], api_key, batch_idx)
            all_edges.extend(edges)

            # Append edges to partial file
            with open(edges_file, "a", encoding="utf-8") as f:
                f.writelines(
                    json.dumps({"citing": citing, "cited": cited, "influential": influential}) + "\n"
                    for citing, cited, influential in edges
                )

            pbar.set_postfix({"edges": len(all_edges), "batch_edges": len(edges)})

            # Checkpoint periodically
            if (batch_idx + 1) % CHECKPOINT_EVERY == 0:
                _write_checkpoint(checkpoint_file, edges_file, batch_idx + 1)

            await asyncio.sleep(delay)

    # Final checkpoint
    _write_checkpoint(checkpoint_file, edges_file, total_batches, complete=True)

    return all_edges
193
+
194
+
195
+ # ── Bulk Download Path ───────────────────────────────────────────────────────
196
+
197
def process_bulk_downloads(
    papers_file: str,
    citations_file: str,
    corpus_set: set[str],
) -> list[tuple[str, str, bool]]:
    """
    Process S2 bulk dataset downloads to extract in-corpus citation edges.

    Both inputs are gzip-compressed JSON-lines files and are decoded as UTF-8
    regardless of the platform's default encoding. Records that cannot be
    parsed or that have an unexpected shape are skipped.

    Args:
        papers_file: paper-ids.jsonl.gz (corpusid → externalIds mapping).
        citations_file: citations.jsonl.gz (citingcorpusid → citedcorpusid edges).
        corpus_set: arXiv IDs that make up the corpus; only edges with both
            endpoints in this set are returned.

    Returns:
        A list of (citing_arxiv_id, cited_arxiv_id, is_influential) tuples.
    """
    print("Step 1/2: Building corpusid → arxiv_id mapping from paper-ids...")
    corpusid_to_arxiv: dict[int, str] = {}
    with gzip.open(papers_file, "rt", encoding="utf-8") as f:
        for line in tqdm(f, desc="Reading paper-ids"):
            try:
                rec = json.loads(line)
                # Both lower-case and camelCase key spellings are accepted.
                ext_ids = rec.get("externalids") or rec.get("externalIds") or {}
                arxiv_id = ext_ids.get("ArXiv")
                corpus_id = rec.get("corpusid") or rec.get("corpusId")
                if arxiv_id and corpus_id and arxiv_id in corpus_set:
                    corpusid_to_arxiv[int(corpus_id)] = arxiv_id
            except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
                # Malformed JSON, a non-numeric corpus ID, or a non-object record.
                continue

    print(f" Mapped {len(corpusid_to_arxiv)} corpus IDs to arXiv IDs in your corpus")

    print("Step 2/2: Filtering citation edges to in-corpus pairs...")
    edges: list[tuple[str, str, bool]] = []
    with gzip.open(citations_file, "rt", encoding="utf-8") as f:
        for line in tqdm(f, desc="Reading citations"):
            try:
                rec = json.loads(line)
                citing_cid = rec.get("citingcorpusid") or rec.get("citingCorpusId")
                cited_cid = rec.get("citedcorpusid") or rec.get("citedCorpusId")
                is_influential = rec.get("isinfluential", False) or rec.get("isInfluential", False)

                # The mapping only holds in-corpus papers, so a successful
                # lookup on both sides already implies an in-corpus edge.
                citing_arxiv = corpusid_to_arxiv.get(int(citing_cid)) if citing_cid else None
                cited_arxiv = corpusid_to_arxiv.get(int(cited_cid)) if cited_cid else None

                if citing_arxiv and cited_arxiv:
                    edges.append((citing_arxiv, cited_arxiv, bool(is_influential)))
            except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
                continue

    print(f" Found {len(edges)} in-corpus citation edges")
    return edges
244
+
245
+
246
+ # ── Filter & Save ────────────────────────────────────────────────────────────
247
+
248
def filter_and_save(
    edges: list[tuple[str, str, bool]],
    corpus_set: set[str],
    output_path: str,
):
    """
    Keep in-corpus, non-self edges, drop duplicates, and write them to parquet.

    An edge survives when both endpoints are in ``corpus_set`` and differ from
    each other. For repeated (citing, cited) pairs the first occurrence wins,
    including its influential flag. The result is written to ``output_path``
    with snappy compression and summary statistics are printed.
    """
    print(f"Raw edges before filtering: {len(edges)}")

    # Pass 1: membership and self-citation filter.
    in_corpus = []
    for src, dst, flag in edges:
        if src == dst:
            continue
        if src not in corpus_set or dst not in corpus_set:
            continue
        in_corpus.append((src, dst, flag))
    print(f"In-corpus edges (both sides in corpus): {len(in_corpus)}")

    # Pass 2: first-wins deduplication; dict insertion order keeps the
    # original edge order.
    first_seen = {}
    for src, dst, flag in in_corpus:
        first_seen.setdefault((src, dst), flag)

    edge_count = len(first_seen)
    print(f"After deduplication: {edge_count}")

    # Column-wise layout for the parquet table.
    citing_col = []
    cited_col = []
    flag_col = []
    for (src, dst), flag in first_seen.items():
        citing_col.append(src)
        cited_col.append(dst)
        flag_col.append(flag)

    columns = {
        "citing_arxiv_id": pa.array(citing_col, type=pa.string()),
        "cited_arxiv_id": pa.array(cited_col, type=pa.string()),
        "is_influential": pa.array(flag_col, type=pa.bool_()),
    }
    pq.write_table(pa.table(columns), output_path, compression="snappy")
    print(f"Saved {edge_count} citation edges to {output_path}")

    # Summary statistics.
    citing_papers = set(citing_col)
    cited_papers = set(cited_col)
    avg_refs = edge_count / max(len(citing_papers), 1)
    influential_count = sum(1 for flag in flag_col if flag)
    influential_pct = 100 * influential_count / max(edge_count, 1)

    print(f"\nStats:")
    print(f" Unique citing papers: {len(citing_papers)}")
    print(f" Unique cited papers: {len(cited_papers)}")
    print(f" Unique papers total: {len(citing_papers | cited_papers)}")
    print(f" Avg references per citing paper: {avg_refs:.1f}")
    print(f" Influential citations: {influential_count} ({influential_pct:.1f}%)")
297
+
298
+
299
+ # ── Main ─────────────────────────────────────────────────────────────────────
300
+
301
def load_corpus_ids(path: str) -> list[str]:
    """Load arXiv IDs from a text file (one per line).

    Blank lines and lines starting with "#" are ignored. A leading "arXiv:"
    prefix is removed regardless of letter case. The file is read as UTF-8
    and a leading byte-order mark, if present, is discarded so it cannot end
    up inside the first ID.

    Args:
        path: Path of the text file to read.

    Returns:
        The IDs in file order; duplicates are kept.
    """
    ids = []
    # "utf-8-sig" decodes plain UTF-8 and also strips a BOM if one is present.
    with open(path, encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            # Handle various formats: "2303.14957", "arXiv:2303.14957", etc.
            if line[:6].lower() == "arxiv:":
                line = line[6:].strip()
            if line:
                ids.append(line)
    print(f"Loaded {len(ids)} arXiv IDs from {path}")
    return ids
316
+
317
+
318
def main():
    """Parse command-line arguments, collect citation edges, and write the parquet file.

    The bulk-download path is used when both --bulk-papers and
    --bulk-citations are given; otherwise the batch API path is used, with
    the API key taken from --api-key or the S2_API_KEY environment variable.
    Supplying only one of the two bulk options is rejected as a usage error
    instead of silently falling back to the much slower API path.
    """
    parser = argparse.ArgumentParser(
        description="Fetch citation edges from Semantic Scholar for ResearchIT corpus"
    )
    parser.add_argument(
        "--corpus-file", required=True,
        help="Text file with one arXiv ID per line (e.g. arxiv_ids.txt)"
    )
    parser.add_argument(
        "--output", default="citations.parquet",
        help="Output parquet file path (default: citations.parquet)"
    )
    parser.add_argument(
        "--api-key", default=None,
        help="Semantic Scholar API key (optional, speeds up rate limit)"
    )
    parser.add_argument(
        "--bulk-papers", default=None,
        help="Path to S2 bulk paper-ids.jsonl.gz (use bulk download path)"
    )
    parser.add_argument(
        "--bulk-citations", default=None,
        help="Path to S2 bulk citations.jsonl.gz (use bulk download path)"
    )
    parser.add_argument(
        "--checkpoint-dir", default="./citation_checkpoint",
        help="Directory for checkpoint files (batch API mode)"
    )
    parser.add_argument(
        "--max-papers", type=int, default=None,
        help="Limit to first N papers (for testing)"
    )

    args = parser.parse_args()

    # Validate option combinations before doing any work.
    if bool(args.bulk_papers) != bool(args.bulk_citations):
        parser.error("--bulk-papers and --bulk-citations must be given together")
    if args.max_papers is not None and args.max_papers < 0:
        parser.error("--max-papers must be non-negative")

    # Load corpus
    corpus_ids = load_corpus_ids(args.corpus_file)
    if args.max_papers is not None:
        corpus_ids = corpus_ids[:args.max_papers]
        print(f" Limited to {len(corpus_ids)} papers (--max-papers)")

    corpus_set = set(corpus_ids)

    # Choose path
    if args.bulk_papers and args.bulk_citations:
        print("\n=== BULK DOWNLOAD PATH ===")
        edges = process_bulk_downloads(args.bulk_papers, args.bulk_citations, corpus_set)
    else:
        print("\n=== BATCH API PATH ===")
        if not args.api_key:
            # Check environment variable
            args.api_key = os.environ.get("S2_API_KEY")
        if args.api_key:
            print(f"Using API key: {args.api_key[:8]}...")
        else:
            print("No API key — using unauthenticated rate (1 req/s)")
            print("Get a free key at: https://www.semanticscholar.org/product/api#Partner-Form")

        checkpoint_dir = Path(args.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        edges = asyncio.run(fetch_all_batches(corpus_ids, args.api_key, checkpoint_dir))

    # Filter to in-corpus and save
    filter_and_save(edges, corpus_set, args.output)

    print(f"\n✅ Done! Citation edges saved to: {args.output}")
385
+
386
+
387
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()