Spaces:
Sleeping
Sleeping
| """Semantic Scholar enrichment — connected papers, TL;DR, and topic extraction. | |
| Uses the free S2 Academic Graph API. No API key required but rate-limited | |
| to a shared pool. With a key (x-api-key header), 1 req/sec guaranteed. | |
| Enrichment strategy: | |
| 1. Batch lookup all papers → TL;DR + S2 paper ID (1 API call per 500 papers) | |
| 2. Top N papers by score → references + recommendations (2 calls each) | |
| 3. Topic extraction from title/abstract (local, no API) | |
| """ | |
| import json | |
| import logging | |
| import re | |
| import time | |
| import requests | |
| log = logging.getLogger(__name__) | |
| from src.db import ( | |
| clear_connections, | |
| get_arxiv_id_map, | |
| get_conn, | |
| get_top_papers, | |
| insert_connections, | |
| update_paper_s2, | |
| update_paper_topics, | |
| ) | |
# Base URLs of the two Semantic Scholar services this module calls.
S2_GRAPH = "https://api.semanticscholar.org/graph/v1"
S2_RECO = "https://api.semanticscholar.org/recommendations/v1"

# Headers passed to every S2 request made by this module. Empty by default;
# set {"x-api-key": "..."} here to use a personal API key.
S2_HEADERS: dict[str, str] = {}

# Number of best-scoring papers that get the connection lookup.
TOP_N_CONNECTIONS = 30

# Seconds slept around each S2 request to stay under the rate limit.
RATE_LIMIT = 1.1
| # --------------------------------------------------------------------------- | |
| # Main entry point | |
| # --------------------------------------------------------------------------- | |
def enrich_run(run_id: int, domain: str):
    """Run the enrichment pipeline over every scored paper of one run.

    Loads the run's papers that have a non-NULL ``composite`` score (highest
    first), then:

    1. fetches TL;DR text and S2 ids for all of them in batches,
    2. fetches connected papers for the first ``TOP_N_CONNECTIONS`` of them,
    3. tags every paper with topics extracted from its title and abstract.

    Args:
        run_id: Run whose papers are enriched.
        domain: Domain label forwarded to ``extract_topics``; also logged.
    """
    query = (
        "SELECT id, arxiv_id, title, abstract, composite FROM papers "
        "WHERE run_id=? AND composite IS NOT NULL "
        "ORDER BY composite DESC"
    )
    with get_conn() as conn:
        papers = [dict(row) for row in conn.execute(query, (run_id,)).fetchall()]

    if not papers:
        log.info("No scored papers in run %d, skipping", run_id)
        return

    arxiv_map = get_arxiv_id_map(run_id)
    log.info("Enriching %d papers from run %d (%s)...", len(papers), run_id, domain)

    # Stage 1: TL;DR + S2 id for the whole run.
    _batch_tldr(papers)

    # Stage 2: connections, only for the head of the score-ordered list.
    leaders = papers[:TOP_N_CONNECTIONS]
    total = len(leaders)
    for done, paper in enumerate(leaders, start=1):
        try:
            _fetch_connections(paper, arxiv_map)
        except Exception as exc:
            # One failing paper must not stop the remaining lookups.
            log.warning("Error fetching connections for %s: %s", paper['arxiv_id'], exc)
        if done % 10 == 0:
            log.info("Connections: %d/%d", done, total)

    # Stage 3: local topic tagging; papers with no match are left untouched.
    for paper in papers:
        tags = extract_topics(paper["title"], paper.get("abstract", ""), domain)
        if tags:
            update_paper_topics(paper["id"], tags)

    log.info("Done enriching run %d", run_id)
| # --------------------------------------------------------------------------- | |
| # Step 1: Batch TL;DR | |
| # --------------------------------------------------------------------------- | |
| def _batch_tldr(papers: list[dict]): | |
| """Batch fetch TL;DR and S2 paper IDs.""" | |
| chunk_size = 500 | |
| for start in range(0, len(papers), chunk_size): | |
| chunk = papers[start : start + chunk_size] | |
| ids = [f"arXiv:{p['arxiv_id']}" for p in chunk] | |
| try: | |
| resp = requests.post( | |
| f"{S2_GRAPH}/paper/batch", | |
| params={"fields": "externalIds,tldr"}, | |
| json={"ids": ids}, | |
| headers=S2_HEADERS, | |
| timeout=30, | |
| ) | |
| resp.raise_for_status() | |
| results = resp.json() | |
| except Exception as e: | |
| log.warning("Batch TL;DR failed: %s", e) | |
| time.sleep(RATE_LIMIT) | |
| continue | |
| for paper, s2_data in zip(chunk, results): | |
| if s2_data is None: | |
| continue | |
| s2_id = s2_data.get("paperId", "") | |
| tldr_obj = s2_data.get("tldr") | |
| tldr_text = tldr_obj.get("text", "") if tldr_obj else "" | |
| update_paper_s2(paper["id"], s2_id, tldr_text) | |
| paper["s2_paper_id"] = s2_id | |
| found = sum(1 for r in results if r is not None) | |
| log.info("Batch TL;DR: %d/%d papers found in S2", found, len(chunk)) | |
| time.sleep(RATE_LIMIT) | |
| # --------------------------------------------------------------------------- | |
| # Step 2: Connected papers (references + recommendations) | |
| # --------------------------------------------------------------------------- | |
def _s2_fetch_items(url: str, params: dict, items_key: str, label: str, arxiv_id: str):
    """Return the list stored under ``items_key`` of one S2 GET, or None on failure."""
    time.sleep(RATE_LIMIT)
    try:
        resp = requests.get(url, params=params, headers=S2_HEADERS, timeout=15)
        if not resp.ok:
            log.warning("%s failed for %s: %s", label, arxiv_id, f"HTTP {resp.status_code}")
            return None
        payload = resp.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers an undecodable JSON body.
        log.warning("%s failed for %s: %s", label, arxiv_id, e)
        return None
    if not isinstance(payload, dict):
        log.warning("%s failed for %s: %s", label, arxiv_id, "unexpected response shape")
        return None
    # The key may be missing or null; both mean "no items".
    return payload.get(items_key) or []


def _connection_row(paper_id, related, connection_type: str, arxiv_map: dict[str, int]):
    """Build one connection record from an S2 paper object, or None if it has no title."""
    if not isinstance(related, dict) or not related.get("title"):
        return None
    ext = related.get("externalIds") or {}
    c_arxiv = ext.get("ArXiv") or ""
    return {
        "paper_id": paper_id,
        "connected_arxiv_id": c_arxiv,
        "connected_s2_id": related.get("paperId", ""),
        "connected_title": related.get("title", ""),
        "connected_year": related.get("year"),
        "connection_type": connection_type,
        "in_db_paper_id": arxiv_map.get(c_arxiv),
    }


def _fetch_connections(paper: dict, arxiv_map: dict[str, int]):
    """Fetch references and recommendations for one paper and store them.

    Two S2 requests are made (each preceded by a ``RATE_LIMIT`` pause): up to
    30 references and up to 15 recommended papers. Entries without a title
    are dropped. Stored connections are replaced only when at least one of
    the two requests succeeded, so a transient S2 failure does not wipe
    previously stored data.

    Args:
        paper: Paper dict carrying ``"id"`` and ``"arxiv_id"``.
        arxiv_map: Lookup from arXiv id to a paper id, used to fill
            ``in_db_paper_id`` for connected papers present in the map.
    """
    arxiv_id = paper["arxiv_id"]
    paper_id = paper["id"]

    references = _s2_fetch_items(
        f"{S2_GRAPH}/paper/arXiv:{arxiv_id}/references",
        {"fields": "title,year,externalIds", "limit": 30},
        "data",
        "References",
        arxiv_id,
    )
    recommendations = _s2_fetch_items(
        f"{S2_RECO}/papers/forpaper/arXiv:{arxiv_id}",
        {"fields": "title,year,externalIds", "limit": 15},
        "recommendedPapers",
        "Recommendations",
        arxiv_id,
    )

    if references is None and recommendations is None:
        # Nothing usable came back: keep whatever is already stored.
        return

    connections: list[dict] = []
    for item in references or []:
        # Each reference wraps the cited paper under "citedPaper".
        cited = item.get("citedPaper") if isinstance(item, dict) else None
        row = _connection_row(paper_id, cited, "reference", arxiv_map)
        if row is not None:
            connections.append(row)
    for rec in recommendations or []:
        row = _connection_row(paper_id, rec, "recommendation", arxiv_map)
        if row is not None:
            connections.append(row)

    # Replace old rows only now that fresh data is in hand.
    clear_connections(paper_id)
    if connections:
        insert_connections(connections)
| # --------------------------------------------------------------------------- | |
| # Step 3: Topic extraction (local, no API) | |
| # --------------------------------------------------------------------------- | |
# Topic label -> case-insensitive pattern for the "aiml" domain.
# Entry order is significant: extract_topics ranks with a stable sort, so
# earlier labels win ties.
AIML_TOPICS = {
    label: re.compile(expr, re.I)
    for label, expr in (
        ("Video Generation",
         r"video.generat|text.to.video|video.diffusion|video.synth|video.edit"),
        ("Image Generation",
         r"image.generat|text.to.image|(?:stable|latent).diffusion|image.synth|image.edit"),
        ("Language Models",
         r"language.model|(?:large|foundation).model|\bllm\b|\bgpt\b|instruction.tun|fine.tun"),
        ("Code",
         r"code.generat|code.complet|program.synth|vibe.cod|software.engineer"),
        ("Multimodal",
         r"multimodal|vision.language|\bvlm\b|visual.question|image.text"),
        ("Efficiency",
         r"quantiz|distillat|pruning|efficient|scaling.law|compress|accelerat"),
        ("Agents",
         r"\bagent\b|tool.use|function.call|planning|agentic"),
        ("Speech / Audio",
         r"text.to.speech|\btts\b|speech|audio.generat|voice|music.generat"),
        ("3D / Vision",
         r"\b3d\b|nerf|gaussian.splat|point.cloud|depth.estim|object.detect|segmentat"),
        ("Retrieval / RAG",
         r"retriev|\brag\b|knowledge.(?:base|graph)|in.context.learn|embedding"),
        ("Robotics",
         r"robot|embodied|manipulat|locomotion|navigation"),
        ("Reasoning",
         r"reasoning|chain.of.thought|mathemat|logic|theorem"),
        ("Training",
         r"reinforcement.learn|\brlhf\b|\bdpo\b|preference|reward.model|alignment"),
        ("Architecture",
         r"attention.mechanism|state.space|\bmamba\b|mixture.of.expert|\bmoe\b|transformer"),
        ("Benchmark",
         r"benchmark|evaluat|leaderboard|dataset|scaling.law"),
        ("World Models",
         r"world.model|environment.model|predictive.model|dynamics.model"),
        ("Optimization",
         r"optimi[zs]|gradient|convergence|learning.rate|loss.function|multi.objective|adversarial.train"),
        ("RL",
         r"reinforcement.learn|\brl\b|reward|policy.gradient|q.learning|bandit"),
    )
}
# Topic label -> case-insensitive pattern used for every non-"aiml" domain.
# Entry order is significant: extract_topics ranks with a stable sort, so
# earlier labels win ties.
SECURITY_TOPICS = {
    label: re.compile(expr, re.I)
    for label, expr in (
        ("Web Security",
         r"web.(?:secur|app|vuln)|xss|injection|csrf|waf|\bbrowser.secur"),
        ("Network",
         r"network.secur|intrusion|\bids\b|firewall|traffic|\bdns\b|\bbgp\b|\bddos\b|fingerprint|scanning|packet"),
        ("Malware",
         r"malware|ransomware|trojan|botnet|rootkit|worm|backdoor"),
        ("Vulnerabilities",
         r"vulnerab|\bcve\b|exploit|fuzzing|fuzz|buffer.overflow|zero.day|attack.surface|security.bench"),
        ("Cryptography",
         r"cryptograph|encryption|decrypt|protocol|\btls\b|\bssl\b|cipher"),
        ("Hardware",
         r"side.channel|timing.attack|spectre|meltdown|hardware|firmware|microarch|fault.inject|emfi|embedded.secur"),
        ("Reverse Engineering",
         r"reverse.engineer|binary|decompil|obfuscat|disassembl"),
        ("Mobile",
         r"\bandroid\b|\bios.secur|mobile.secur"),
        ("Cloud",
         r"cloud.secur|container.secur|docker|kubernetes|serverless|devsecops"),
        ("Authentication",
         r"authentica|identity|credential|phishing|password|oauth|passkey|webauthn"),
        ("Privacy",
         r"privacy|anonymi|differential.privacy|data.leak|tracking|membership.inference"),
        ("LLM Security",
         r"(?:llm|language.model).*(secur|attack|jailbreak|safety|risk|unsafe|inject|adversar)|prompt.inject|red.team|rubric.attack|preference.drift"),
        ("Forensics",
         r"forensic|incident.response|audit|log.analy|carver|tamper|evidence"),
        ("Blockchain",
         r"blockchain|smart.contract|solana|ethereum|memecoin|mev|defi|token|cryptocurrency"),
        ("Supply Chain",
         r"supply.chain|dependency|package.secur|software.comp|sbom"),
    )
}
def extract_topics(title: str, abstract: str, domain: str, max_topics: int = 3) -> list[str]:
    """Return the best-matching topic labels for a paper.

    Every topic pattern of the selected vocabulary is searched in the title
    and in the first 500 characters of the abstract. A title hit scores 3, an
    abstract hit scores 1, and topics are returned by descending score. Ties
    keep the vocabulary's declaration order.

    Args:
        title: Paper title; ``None`` or empty is treated as no title.
        abstract: Paper abstract; ``None`` or empty is treated as no abstract.
        domain: ``"aiml"`` selects ``AIML_TOPICS``; any other value selects
            ``SECURITY_TOPICS``.
        max_topics: Maximum number of labels to return (default 3).

    Returns:
        Up to ``max_topics`` topic labels, strongest first; empty if nothing
        matched.
    """
    patterns = AIML_TOPICS if domain == "aiml" else SECURITY_TOPICS
    # Guard both inputs: pattern.search(None) would raise TypeError.
    title_text = title or ""
    abstract_head = (abstract or "")[:500]

    scored: dict[str, int] = {}
    for topic, pattern in patterns.items():
        score = 0
        if pattern.search(title_text):
            score += 3  # a title match is the stronger signal
        if pattern.search(abstract_head):
            score += 1
        if score > 0:
            scored[topic] = score

    # sorted() is stable, so equal scores stay in vocabulary order.
    ranked = sorted(scored.items(), key=lambda x: -x[1])
    return [t for t, _ in ranked[:max_topics]]