# Researcher/src/pipelines/semantic_scholar.py
# (amarck — Initial commit: Research Intelligence System, a0f27fa)
"""Semantic Scholar enrichment — connected papers, TL;DR, and topic extraction.
Uses the free S2 Academic Graph API. No API key required but rate-limited
to a shared pool. With a key (x-api-key header), 1 req/sec guaranteed.
Enrichment strategy:
1. Batch lookup all papers → TL;DR + S2 paper ID (1 API call per 500 papers)
2. Top N papers by score → references + recommendations (2 calls each)
3. Topic extraction from title/abstract (local, no API)
"""
import json
import logging
import re
import time
import requests
log = logging.getLogger(__name__)
from src.db import (
clear_connections,
get_arxiv_id_map,
get_conn,
get_top_papers,
insert_connections,
update_paper_s2,
update_paper_topics,
)
# Base URL of the S2 Academic Graph API (paper lookup, references).
S2_GRAPH = "https://api.semanticscholar.org/graph/v1"
# Base URL of the S2 Recommendations API (related-paper suggestions).
S2_RECO = "https://api.semanticscholar.org/recommendations/v1"
# Extra headers sent with every S2 request.
S2_HEADERS: dict[str, str] = {} # Add {"x-api-key": "..."} if you have one
# How many top papers get full connection enrichment
TOP_N_CONNECTIONS = 30
# Rate limit pause between requests (seconds) — keeps us just over the
# 1 req/sec guarantee mentioned in the module docstring.
RATE_LIMIT = 1.1
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
def enrich_run(run_id: int, domain: str):
    """Enrich all scored papers in a run with S2 data + topics.

    Three passes over the run's papers, best composite score first:
      1. batch TL;DR / S2-id lookup for every paper,
      2. reference + recommendation fetch for the top TOP_N_CONNECTIONS,
      3. local regex topic tagging (no API calls).
    Per-paper connection failures are logged and skipped, never raised.
    """
    with get_conn() as conn:
        rows = conn.execute(
            "SELECT id, arxiv_id, title, abstract, composite FROM papers "
            "WHERE run_id=? AND composite IS NOT NULL "
            "ORDER BY composite DESC",
            (run_id,),
        ).fetchall()
        papers = [dict(row) for row in rows]

    if not papers:
        log.info("No scored papers in run %d, skipping", run_id)
        return

    arxiv_map = get_arxiv_id_map(run_id)
    log.info("Enriching %d papers from run %d (%s)...", len(papers), run_id, domain)

    # Step 1: Batch TL;DR + S2 ID
    _batch_tldr(papers)

    # Step 2: Connected papers for top N
    top_papers = papers[:TOP_N_CONNECTIONS]
    for done, paper in enumerate(top_papers, start=1):
        try:
            _fetch_connections(paper, arxiv_map)
        except Exception as exc:  # best-effort: one bad paper must not stop the run
            log.warning("Error fetching connections for %s: %s", paper["arxiv_id"], exc)
        if done % 10 == 0:
            log.info("Connections: %d/%d", done, len(top_papers))

    # Step 3: Topic extraction (local, instant)
    for paper in papers:
        topics = extract_topics(paper["title"], paper.get("abstract", ""), domain)
        if topics:
            update_paper_topics(paper["id"], topics)

    log.info("Done enriching run %d", run_id)
# ---------------------------------------------------------------------------
# Step 1: Batch TL;DR
# ---------------------------------------------------------------------------
def _batch_tldr(papers: list[dict]):
    """Batch fetch TL;DR and S2 paper IDs.

    One POST to the S2 batch endpoint per 500 papers (the endpoint's
    documented maximum). For every paper found, persists the S2 id and
    TL;DR text via update_paper_s2() and records "s2_paper_id" on the
    paper dict in place. All failures are logged and skipped so
    enrichment stays best-effort.
    """
    chunk_size = 500  # S2 batch endpoint accepts at most 500 ids per call
    for start in range(0, len(papers), chunk_size):
        chunk = papers[start : start + chunk_size]
        ids = [f"arXiv:{p['arxiv_id']}" for p in chunk]
        try:
            resp = requests.post(
                f"{S2_GRAPH}/paper/batch",
                params={"fields": "externalIds,tldr"},
                json={"ids": ids},
                headers=S2_HEADERS,
                timeout=30,
            )
            resp.raise_for_status()
            results = resp.json()
        except (requests.RequestException, ValueError) as e:
            # RequestException covers transport + HTTP errors; ValueError
            # covers a non-JSON body from resp.json().
            log.warning("Batch TL;DR failed: %s", e)
            time.sleep(RATE_LIMIT)
            continue
        if not isinstance(results, list):
            # On bad requests the API returns an error object (a dict);
            # zipping a dict would silently iterate its keys.
            log.warning("Batch TL;DR returned unexpected payload: %r", results)
            time.sleep(RATE_LIMIT)
            continue
        for paper, s2_data in zip(chunk, results):
            if s2_data is None:
                continue  # entry is null when S2 does not know the paper
            s2_id = s2_data.get("paperId", "")
            tldr_obj = s2_data.get("tldr")
            tldr_text = tldr_obj.get("text", "") if tldr_obj else ""
            update_paper_s2(paper["id"], s2_id, tldr_text)
            paper["s2_paper_id"] = s2_id
        found = sum(1 for r in results if r is not None)
        log.info("Batch TL;DR: %d/%d papers found in S2", found, len(chunk))
        time.sleep(RATE_LIMIT)
# ---------------------------------------------------------------------------
# Step 2: Connected papers (references + recommendations)
# ---------------------------------------------------------------------------
def _fetch_connections(paper: dict, arxiv_map: dict[str, int]):
    """Fetch references and recommendations for a single paper.

    Drops any previously stored connections, pulls up to 30 references
    and 15 recommendations from S2 (each call rate-limited), and inserts
    the combined rows. Network failures are logged, not raised.
    """
    arxiv_id = paper["arxiv_id"]
    paper_id = paper["id"]

    # Clear old connections before re-fetching
    clear_connections(paper_id)

    connections: list[dict] = []

    # References (papers this one cites)
    time.sleep(RATE_LIMIT)
    try:
        resp = requests.get(
            f"{S2_GRAPH}/paper/arXiv:{arxiv_id}/references",
            params={"fields": "title,year,externalIds", "limit": 30},
            headers=S2_HEADERS,
            timeout=15,
        )
        if resp.ok:
            for item in resp.json().get("data", []):
                row = _connection_row(
                    paper_id, item.get("citedPaper"), "reference", arxiv_map
                )
                if row:
                    connections.append(row)
        else:
            log.debug("References HTTP %s for %s", resp.status_code, arxiv_id)
    except requests.RequestException as e:
        log.warning("References failed for %s: %s", arxiv_id, e)

    # Recommendations (related papers suggested by S2)
    time.sleep(RATE_LIMIT)
    try:
        resp = requests.get(
            f"{S2_RECO}/papers/forpaper/arXiv:{arxiv_id}",
            params={"fields": "title,year,externalIds", "limit": 15},
            headers=S2_HEADERS,
            timeout=15,
        )
        if resp.ok:
            for rec in resp.json().get("recommendedPapers", []):
                row = _connection_row(paper_id, rec, "recommendation", arxiv_map)
                if row:
                    connections.append(row)
        else:
            log.debug("Recommendations HTTP %s for %s", resp.status_code, arxiv_id)
    except requests.RequestException as e:
        log.warning("Recommendations failed for %s: %s", arxiv_id, e)

    if connections:
        insert_connections(connections)


def _connection_row(
    paper_id: int, meta: dict | None, kind: str, arxiv_map: dict[str, int]
) -> dict | None:
    """Build one connection row from an S2 paper stub, or None if unusable.

    Stubs without a title (withdrawn/unresolved papers) are dropped —
    the same filtering previously inlined for both references and
    recommendations.
    """
    if not meta or not meta.get("title"):
        return None
    ext = meta.get("externalIds") or {}
    c_arxiv = ext.get("ArXiv", "")
    return {
        "paper_id": paper_id,
        "connected_arxiv_id": c_arxiv,
        "connected_s2_id": meta.get("paperId", ""),
        "connected_title": meta.get("title", ""),
        "connected_year": meta.get("year"),
        "connection_type": kind,
        # Cross-link to a paper already in this run, when the arXiv id matches.
        "in_db_paper_id": arxiv_map.get(c_arxiv),
    }
# ---------------------------------------------------------------------------
# Step 3: Topic extraction (local, no API)
# ---------------------------------------------------------------------------
# Topic tag -> compiled pattern for AI/ML papers. Definition order matters:
# sorted() is stable, so score ties resolve to the earlier entry here.
AIML_TOPICS = {
    "Video Generation": re.compile(
        r"video.generat|text.to.video|video.diffusion|video.synth|video.edit", re.I),
    "Image Generation": re.compile(
        r"image.generat|text.to.image|(?:stable|latent).diffusion|image.synth|image.edit", re.I),
    "Language Models": re.compile(
        r"language.model|(?:large|foundation).model|\bllm\b|\bgpt\b|instruction.tun|fine.tun", re.I),
    "Code": re.compile(
        r"code.generat|code.complet|program.synth|vibe.cod|software.engineer", re.I),
    "Multimodal": re.compile(
        r"multimodal|vision.language|\bvlm\b|visual.question|image.text", re.I),
    "Efficiency": re.compile(
        r"quantiz|distillat|pruning|efficient|scaling.law|compress|accelerat", re.I),
    "Agents": re.compile(
        r"\bagent\b|tool.use|function.call|planning|agentic", re.I),
    "Speech / Audio": re.compile(
        r"text.to.speech|\btts\b|speech|audio.generat|voice|music.generat", re.I),
    "3D / Vision": re.compile(
        r"\b3d\b|nerf|gaussian.splat|point.cloud|depth.estim|object.detect|segmentat", re.I),
    "Retrieval / RAG": re.compile(
        r"retriev|\brag\b|knowledge.(?:base|graph)|in.context.learn|embedding", re.I),
    "Robotics": re.compile(
        r"robot|embodied|manipulat|locomotion|navigation", re.I),
    "Reasoning": re.compile(
        r"reasoning|chain.of.thought|mathemat|logic|theorem", re.I),
    "Training": re.compile(
        r"reinforcement.learn|\brlhf\b|\bdpo\b|preference|reward.model|alignment", re.I),
    "Architecture": re.compile(
        r"attention.mechanism|state.space|\bmamba\b|mixture.of.expert|\bmoe\b|transformer", re.I),
    "Benchmark": re.compile(
        r"benchmark|evaluat|leaderboard|dataset|scaling.law", re.I),
    "World Models": re.compile(
        r"world.model|environment.model|predictive.model|dynamics.model", re.I),
    "Optimization": re.compile(
        r"optimi[zs]|gradient|convergence|learning.rate|loss.function|multi.objective|adversarial.train", re.I),
    "RL": re.compile(
        r"reinforcement.learn|\brl\b|reward|policy.gradient|q.learning|bandit", re.I),
}

# Topic tag -> compiled pattern for security papers. Same ordering rule.
SECURITY_TOPICS = {
    "Web Security": re.compile(
        r"web.(?:secur|app|vuln)|xss|injection|csrf|waf|\bbrowser.secur", re.I),
    "Network": re.compile(
        r"network.secur|intrusion|\bids\b|firewall|traffic|\bdns\b|\bbgp\b|\bddos\b|fingerprint|scanning|packet", re.I),
    "Malware": re.compile(
        r"malware|ransomware|trojan|botnet|rootkit|worm|backdoor", re.I),
    "Vulnerabilities": re.compile(
        r"vulnerab|\bcve\b|exploit|fuzzing|fuzz|buffer.overflow|zero.day|attack.surface|security.bench", re.I),
    "Cryptography": re.compile(
        r"cryptograph|encryption|decrypt|protocol|\btls\b|\bssl\b|cipher", re.I),
    "Hardware": re.compile(
        r"side.channel|timing.attack|spectre|meltdown|hardware|firmware|microarch|fault.inject|emfi|embedded.secur", re.I),
    "Reverse Engineering": re.compile(
        r"reverse.engineer|binary|decompil|obfuscat|disassembl", re.I),
    "Mobile": re.compile(
        r"\bandroid\b|\bios.secur|mobile.secur", re.I),
    "Cloud": re.compile(
        r"cloud.secur|container.secur|docker|kubernetes|serverless|devsecops", re.I),
    "Authentication": re.compile(
        r"authentica|identity|credential|phishing|password|oauth|passkey|webauthn", re.I),
    "Privacy": re.compile(
        r"privacy|anonymi|differential.privacy|data.leak|tracking|membership.inference", re.I),
    "LLM Security": re.compile(
        r"(?:llm|language.model).*(secur|attack|jailbreak|safety|risk|unsafe|inject|adversar)|prompt.inject|red.team|rubric.attack|preference.drift", re.I),
    "Forensics": re.compile(
        r"forensic|incident.response|audit|log.analy|carver|tamper|evidence", re.I),
    "Blockchain": re.compile(
        r"blockchain|smart.contract|solana|ethereum|memecoin|mev|defi|token|cryptocurrency", re.I),
    "Supply Chain": re.compile(
        r"supply.chain|dependency|package.secur|software.comp|sbom", re.I),
}


def extract_topics(title: str, abstract: str, domain: str) -> list[str]:
    """Extract up to 3 topic tags from title and abstract.

    A title hit scores 3 (strong signal), a hit in the first 500 chars of
    the abstract scores 1. Tags are returned best-first; score ties keep
    the pattern-table order (sorted() is stable). domain == "aiml" selects
    AIML_TOPICS, everything else SECURITY_TOPICS. Tolerates None for
    title and abstract (DB rows may have NULLs).
    """
    patterns = AIML_TOPICS if domain == "aiml" else SECURITY_TOPICS
    title = title or ""  # title can be NULL in the DB; original crashed here
    abstract_head = (abstract or "")[:500]  # head of the abstract is enough signal
    scored: dict[str, int] = {}
    for topic, pattern in patterns.items():
        score = 0
        if pattern.search(title):
            score += 3  # Title match is strong signal
        if pattern.search(abstract_head):
            score += 1
        if score > 0:
            scored[topic] = score
    ranked = sorted(scored.items(), key=lambda x: -x[1])
    return [t for t, _ in ranked[:3]]