| """ |
| Paper Database Manager for semantic paper discovery. |
| |
| Usage: |
| uv run references/paper_db.py index # Index all papers in references/ |
| uv run references/paper_db.py search "query" # Search papers |
| uv run references/paper_db.py cite <paper> # Find citations/references |
| uv run references/paper_db.py refs <paper> # Find references from paper |
| uv run references/paper_db.py related <paper> # Find related papers |
| uv run references/paper_db.py discover # Discover new papers via citations |
| uv run references/paper_db.py fetch <arxiv_id> # Fetch paper from arXiv |
| uv run references/paper_db.py graph # Generate citation graph |
| uv run references/paper_db.py stats # Show database statistics |
| """ |
|
|
import json
import os
import re
import sys
import time
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Optional

import requests
import yaml
|
|
# Semantic Scholar Graph API endpoint and the fields requested for each paper.
S2_API = "https://api.semanticscholar.org/graph/v1"
S2_FIELDS = "title,authors,year,venue,citationCount,abstract,externalIds,references,citations,tldr"
|
|
@dataclass
class Paper:
    """Paper metadata."""
    id: str
    title: str
    authors: list[str]
    year: int
    venue: str
    url: str
    arxiv_id: Optional[str] = None
    s2_id: Optional[str] = None
    doi: Optional[str] = None
    citation_count: int = 0
    abstract: str = ""
    tldr: str = ""
    keywords: list[str] = field(default_factory=list)
    references: list[str] = field(default_factory=list)
    cited_by: list[str] = field(default_factory=list)
    local_path: str = ""
    fetched: bool = False

    def to_dict(self):
        return asdict(self)

    @classmethod
    def from_dict(cls, d):
        return cls(**d)
|
|
class PaperDB:
    """Paper database with semantic search and citation discovery."""
|
|
    def __init__(self, base_dir: str = "references"):
        self.base_dir = Path(base_dir)
        self.db_path = self.base_dir / "paper_db.json"
        self.papers: dict[str, Paper] = {}
        self.s2_cache: dict[str, dict] = {}
        self.load()
|
|
    def load(self):
        """Load database from disk."""
        if self.db_path.exists():
            data = json.loads(self.db_path.read_text())
            self.papers = {k: Paper.from_dict(v) for k, v in data.get("papers", {}).items()}
            self.s2_cache = data.get("s2_cache", {})
            print(f"Loaded {len(self.papers)} papers from database")
|
|
    def save(self):
        """Save database to disk."""
        data = {
            "papers": {k: v.to_dict() for k, v in self.papers.items()},
            "s2_cache": self.s2_cache,
        }
        self.db_path.write_text(json.dumps(data, indent=2, ensure_ascii=False))
        print(f"Saved {len(self.papers)} papers to database")
|
|
    def index_local_papers(self):
        """Index all papers from local folders."""
        for folder in self.base_dir.iterdir():
            if not folder.is_dir():
                continue
            if folder.name.startswith("research_") or folder.name.startswith("."):
                continue

            md_path = folder / "paper.md"
            if not md_path.exists():
                continue

            paper_id = folder.name
            if paper_id in self.papers and self.papers[paper_id].fetched:
                continue

            # Read the markdown and pull metadata from its YAML front matter.
            content = md_path.read_text(encoding='utf-8', errors='ignore')
            metadata = self._parse_front_matter(content)

            if not metadata:
                print(f"  Skipping {paper_id}: no metadata")
                continue

            paper = Paper(
                id=paper_id,
                title=metadata.get("title", ""),
                authors=metadata.get("authors", []),
                year=metadata.get("year", 0),
                venue=metadata.get("venue", ""),
                url=metadata.get("url", ""),
                arxiv_id=metadata.get("arxiv"),
                local_path=str(folder),
                fetched=True,
            )

            # Tag the paper with any known domain terms found in the text.
            paper.keywords = self._extract_keywords(content)

            self.papers[paper_id] = paper
            print(f"  Indexed: {paper_id} - {paper.title[:50]}...")
|
|
    def _parse_front_matter(self, content: str) -> dict:
        """Parse YAML front matter from markdown."""
        match = re.match(r'^---\n(.*?)\n---', content, re.DOTALL)
        if match:
            try:
                return yaml.safe_load(match.group(1))
            except yaml.YAMLError:
                pass
        return {}
|
|
    def _extract_keywords(self, content: str) -> list[str]:
        """Extract keywords from paper content."""
        keywords = set()
        terms = [
            "dependency parsing", "biaffine", "transformer", "BERT", "PhoBERT",
            "Vietnamese", "treebank", "POS tagging", "NER", "word segmentation",
            "BiLSTM", "attention", "neural", "deep learning", "pre-trained",
            "Universal Dependencies", "CoNLL", "UAS", "LAS", "multi-task",
        ]
        content_lower = content.lower()
        for term in terms:
            if term.lower() in content_lower:
                keywords.add(term)
        return list(keywords)
|
|
    def enrich_with_s2(self, paper_id: str, force: bool = False):
        """Enrich paper with Semantic Scholar data."""
        if paper_id not in self.papers:
            print(f"Paper not found: {paper_id}")
            return

        paper = self.papers[paper_id]

        # Skip papers that are already enriched unless forced.
        if paper.s2_id and not force:
            return

        # Look up by arXiv id first; fall back to a title search.
        s2_data = None
        if paper.arxiv_id:
            s2_data = self._fetch_s2_paper(f"arXiv:{paper.arxiv_id}")
        if not s2_data and paper.title:
            s2_data = self._search_s2_paper(paper.title)

        if not s2_data:
            print(f"  Could not find S2 data for: {paper.title[:50]}")
            return

        # Copy over citation metadata (S2 may return null abstracts).
        paper.s2_id = s2_data.get("paperId")
        paper.citation_count = s2_data.get("citationCount", 0)
        paper.abstract = s2_data.get("abstract") or ""
        if s2_data.get("tldr"):
            paper.tldr = s2_data["tldr"].get("text", "")

        # Backfill external identifiers we did not already have.
        ext_ids = s2_data.get("externalIds", {})
        if not paper.arxiv_id and ext_ids.get("ArXiv"):
            paper.arxiv_id = ext_ids["ArXiv"]
        if not paper.doi and ext_ids.get("DOI"):
            paper.doi = ext_ids["DOI"]

        print(f"  Enriched: {paper_id} (citations: {paper.citation_count})")
|
|
    def _fetch_s2_paper(self, paper_id: str) -> Optional[dict]:
        """Fetch paper from Semantic Scholar by ID."""
        if paper_id in self.s2_cache:
            return self.s2_cache[paper_id]
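        # Cache misses hit the network; responses are stored in the database
        # file so later runs avoid refetching.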
|
|
        try:
            url = f"{S2_API}/paper/{paper_id}"
            params = {"fields": S2_FIELDS}
            response = requests.get(url, params=params, timeout=10)
            if response.status_code == 200:
                data = response.json()
                self.s2_cache[paper_id] = data
                return data
            elif response.status_code == 429:
                # Rate limited: back off briefly, then retry.
                print("  Rate limited, waiting...")
                time.sleep(2)
                return self._fetch_s2_paper(paper_id)
        except Exception as e:
            print(f"  S2 fetch error: {e}")
        return None
|
|
    def _search_s2_paper(self, title: str) -> Optional[dict]:
        """Search Semantic Scholar by title."""
        cache_key = f"search:{title[:100]}"
        if cache_key in self.s2_cache:
            return self.s2_cache[cache_key]

        try:
            url = f"{S2_API}/paper/search"
            params = {"query": title, "limit": 1, "fields": S2_FIELDS}
            response = requests.get(url, params=params, timeout=10)
            if response.status_code == 200:
                data = response.json()
                if data.get("data"):
                    result = data["data"][0]
                    self.s2_cache[cache_key] = result
                    return result
            elif response.status_code == 429:
                print("  Rate limited, waiting...")
                time.sleep(2)
                return self._search_s2_paper(title)
        except Exception as e:
            print(f"  S2 search error: {e}")
        return None
|
|
    def get_citations(self, paper_id: str, limit: int = 20) -> list[dict]:
        """Get papers that cite this paper."""
        paper = self.papers.get(paper_id)
        if not paper or not paper.s2_id:
            # Enrich on demand so we have an S2 id to query.
            self.enrich_with_s2(paper_id)
            paper = self.papers.get(paper_id)
            if not paper or not paper.s2_id:
                return []

        try:
            url = f"{S2_API}/paper/{paper.s2_id}/citations"
            params = {"fields": "title,authors,year,venue,citationCount", "limit": limit}
            response = requests.get(url, params=params, timeout=10)
            if response.status_code == 200:
                data = response.json()
                return [c["citingPaper"] for c in data.get("data", []) if c.get("citingPaper")]
        except Exception as e:
            print(f"  Error getting citations: {e}")
        return []
|
|
    def get_references(self, paper_id: str, limit: int = 20) -> list[dict]:
        """Get papers that this paper cites."""
        paper = self.papers.get(paper_id)
        if not paper or not paper.s2_id:
            self.enrich_with_s2(paper_id)
            paper = self.papers.get(paper_id)
            if not paper or not paper.s2_id:
                return []

        try:
            url = f"{S2_API}/paper/{paper.s2_id}/references"
            params = {"fields": "title,authors,year,venue,citationCount", "limit": limit}
            response = requests.get(url, params=params, timeout=10)
            if response.status_code == 200:
                data = response.json()
                return [r["citedPaper"] for r in data.get("data", []) if r.get("citedPaper")]
        except Exception as e:
            print(f"  Error getting references: {e}")
        return []
|
|
    def search(self, query: str) -> list[Paper]:
        """Search papers by keyword."""
        query_lower = query.lower()
        results = []
        for paper in self.papers.values():
            score = 0
            if query_lower in paper.title.lower():
                score += 10
            for author in paper.authors:
                if query_lower in author.lower():
                    score += 5
            for kw in paper.keywords:
                if query_lower in kw.lower():
                    score += 3
            if paper.abstract and query_lower in paper.abstract.lower():
                score += 2
            if paper.tldr and query_lower in paper.tldr.lower():
                score += 2

            if score > 0:
                results.append((score, paper))

        results.sort(key=lambda x: (-x[0], -x[1].citation_count))
        return [p for _, p in results]
|
|
    def discover_related(self, topic: str = "Vietnamese dependency parsing", limit: int = 20) -> list[dict]:
        """Discover new papers via Semantic Scholar search."""
        try:
            url = f"{S2_API}/paper/search"
            params = {
                "query": topic,
                "limit": limit,
                "fields": "title,authors,year,venue,citationCount,abstract,externalIds",
            }
            response = requests.get(url, params=params, timeout=10)
            if response.status_code == 200:
                data = response.json()
                papers = data.get("data", [])

                existing_titles = {p.title.lower() for p in self.papers.values()}
                new_papers = [
                    p for p in papers
                    if p.get("title", "").lower() not in existing_titles
                ]
                return new_papers
        except Exception as e:
            print(f"  Error discovering papers: {e}")
        return []
|
|
    def discover_via_citations(self, min_citations: int = 5) -> list[dict]:
        """Discover new papers by following citation networks."""
        discovered = []
        seen_ids = set()

        # S2 ids of papers we already have locally.
        local_s2_ids = {p.s2_id for p in self.papers.values() if p.s2_id}

        for paper_id, paper in self.papers.items():
            if not paper.s2_id:
                continue

            # Papers that cite this one.
            citations = self.get_citations(paper_id, limit=10)
            time.sleep(0.5)

            for cite in citations:
                s2_id = cite.get("paperId")
                if not s2_id or s2_id in seen_ids or s2_id in local_s2_ids:
                    continue
                seen_ids.add(s2_id)

                citation_count = cite.get("citationCount", 0)
                if citation_count >= min_citations:
                    cite["_discovered_via"] = f"cites {paper_id}"
                    discovered.append(cite)

            # Papers this one cites.
            refs = self.get_references(paper_id, limit=10)
            time.sleep(0.5)

            for ref in refs:
                s2_id = ref.get("paperId")
                if not s2_id or s2_id in seen_ids or s2_id in local_s2_ids:
                    continue
                seen_ids.add(s2_id)

                citation_count = ref.get("citationCount", 0)
                if citation_count >= min_citations:
                    ref["_discovered_via"] = f"cited by {paper_id}"
                    discovered.append(ref)

        # Most-cited candidates first.
        discovered.sort(key=lambda x: x.get("citationCount", 0), reverse=True)
        return discovered
|
|
    def generate_graph(self) -> str:
        """Generate citation graph in Mermaid format."""
| lines = ["graph TD"] |
|
|
| |
| for paper_id, paper in self.papers.items(): |
| label = f"{paper.year}: {paper.title[:30]}..." |
| lines.append(f' {paper_id.replace(".", "_").replace("-", "_")}["{label}"]') |
|
|
| |
| |
|
|
| return "\n".join(lines) |
|
|
    def print_stats(self):
        """Print database statistics."""
        print("\n=== Paper Database Statistics ===")
        print(f"Total papers: {len(self.papers)}")
        print(f"With S2 data: {sum(1 for p in self.papers.values() if p.s2_id)}")
        print(f"With abstracts: {sum(1 for p in self.papers.values() if p.abstract)}")

        # Papers per year.
        by_year = {}
        for p in self.papers.values():
            by_year[p.year] = by_year.get(p.year, 0) + 1
        print("\nBy year:")
        for year in sorted(by_year.keys()):
            print(f"  {year}: {by_year[year]} papers")

        # Papers per venue (first token of the venue string).
        by_venue = {}
        for p in self.papers.values():
            venue = p.venue.split()[0] if p.venue else "Unknown"
            by_venue[venue] = by_venue.get(venue, 0) + 1
        print("\nBy venue:")
        for venue, count in sorted(by_venue.items(), key=lambda x: -x[1])[:10]:
            print(f"  {venue}: {count} papers")

        # Most-cited papers.
        print("\nTop cited papers:")
        for p in sorted(self.papers.values(), key=lambda x: -x.citation_count)[:5]:
            if p.citation_count > 0:
                print(f"  [{p.citation_count}] {p.title[:60]}...")
|
|
def fetch_arxiv_paper(arxiv_id: str, db: PaperDB):
    """Fetch a paper from arXiv and add to database."""
    import arxiv
    import pymupdf4llm
    import unicodedata
    import tarfile
    import gzip
    from io import BytesIO
|
|
    # Accept bare ids, "arxiv:" prefixes, and abs/pdf URLs.
    arxiv_id = re.sub(r'^(arxiv:|https?://arxiv\.org/(abs|pdf)/)', '', arxiv_id)
    arxiv_id = arxiv_id.rstrip('/').removesuffix('.pdf')
|
|
    # Fetch metadata via the arXiv API.
    print(f"  Fetching metadata for arXiv:{arxiv_id}...")
    client = arxiv.Client()
    try:
        paper = next(client.results(arxiv.Search(id_list=[arxiv_id])))
    except StopIteration:
        print(f"  Paper not found: {arxiv_id}")
        return None
|
|
    # Build a folder name like "<year>.arxiv.<lastname>" from the first author,
    # ASCII-folding diacritics by stripping Unicode combining marks.
    year = paper.published.year
    first_author = paper.authors[0].name if paper.authors else "unknown"
    lastname = first_author.split()[-1] if first_author.split() else first_author
    normalized = unicodedata.normalize('NFD', lastname)
    author = ''.join(c for c in normalized if unicodedata.category(c) != 'Mn').lower()
|
|
    folder_name = f"{year}.arxiv.{author}"
    folder = db.base_dir / folder_name
    folder.mkdir(exist_ok=True)
    print(f"  Title: {paper.title[:60]}...")
    print(f"  Folder: {folder}")
|
|
    # Write YAML front matter matching what index_local_papers expects.
    authors_yaml = '\n'.join(f'  - "{a.name}"' for a in paper.authors)
    front_matter = f'''---
title: "{paper.title}"
authors:
{authors_yaml}
year: {year}
venue: "arXiv"
url: "{paper.entry_id}"
arxiv: "{arxiv_id}"
---

'''
|
|
    # Prefer the LaTeX source; fall back to PDF text extraction below.
    tex_content = None
    source_url = f"https://arxiv.org/e-print/{arxiv_id}"
    try:
        response = requests.get(source_url, allow_redirects=True, timeout=30)
        response.raise_for_status()
        content = response.content
|
|
        try:
            with tarfile.open(fileobj=BytesIO(content), mode='r:gz') as tar:
                tex_files = [m.name for m in tar.getmembers() if m.name.endswith('.tex')]
                source_dir = folder / "source"
                source_dir.mkdir(exist_ok=True)
                tar.extractall(path=source_dir)  # trusts the arXiv archive contents
                print(f"  Extracted {len(tar.getmembers())} source files")
|
|
                # Prefer a file named like "main.tex"; otherwise take the first .tex.
                main_tex = None
                for name in tex_files:
                    if 'main' in name.lower():
                        main_tex = name
                        break
                if not main_tex and tex_files:
                    main_tex = tex_files[0]
|
|
                if main_tex:
                    with open(source_dir / main_tex, 'r', encoding='utf-8', errors='ignore') as f:
                        tex_content = f.read()
        except tarfile.TarError:
            # Not a tar archive: maybe a gzip-compressed single .tex file.
            try:
                tex_content = gzip.decompress(content).decode('utf-8', errors='ignore')
                if '\\documentclass' not in tex_content:
                    tex_content = None
            except Exception:
                pass
    except Exception as e:
        print(f"  Could not fetch source: {e}")
|
|
    if tex_content:
        (folder / "paper.tex").write_text(tex_content, encoding='utf-8')
        print("  Saved: paper.tex")
|
|
        md = tex_content
        doc_match = re.search(r'\\begin\{document\}', md)
        if doc_match:
            md = md[doc_match.end():]
        md = re.sub(r'\\end\{document\}.*', '', md, flags=re.DOTALL)
        md = re.sub(r'%.*$', '', md, flags=re.MULTILINE)
        md = re.sub(r'\\section\*?\{([^}]+)\}', r'# \1', md)
        md = re.sub(r'\\subsection\*?\{([^}]+)\}', r'## \1', md)
        md = re.sub(r'\\textbf\{([^}]+)\}', r'**\1**', md)
        md = re.sub(r'\\textit\{([^}]+)\}', r'*\1*', md)
        md = re.sub(r'\\cite\w*\{([^}]+)\}', r'[\1]', md)

        (folder / "paper.md").write_text(front_matter + md.strip(), encoding='utf-8')
        print("  Generated: paper.md")
        has_source = True
    else:
        has_source = False
|
|
    # Always keep a PDF copy alongside the source.
    pdf_path = folder / "paper.pdf"
    paper.download_pdf(filename=str(pdf_path))
    print("  Downloaded: paper.pdf")
|
|
    # No usable LaTeX source: extract Markdown from the PDF instead.
    if not has_source:
        md_text = pymupdf4llm.to_markdown(str(pdf_path))
        (folder / "paper.md").write_text(front_matter + md_text, encoding='utf-8')
        print("  Extracted: paper.md (from PDF)")
|
|
    # Register the paper in the database and enrich it with S2 metadata.
    new_paper = Paper(
        id=folder_name,
        title=paper.title,
        authors=[a.name for a in paper.authors],
        year=year,
        venue="arXiv",
        url=paper.entry_id,
        arxiv_id=arxiv_id,
        local_path=str(folder),
        fetched=True,
    )
    db.papers[folder_name] = new_paper
    db.enrich_with_s2(folder_name)
|
|
    return folder_name
|
|
def main():
    if len(sys.argv) < 2:
        print(__doc__)
        return

    # Run from the project root so the relative "references" path resolves.
    script_dir = Path(__file__).parent
    os.chdir(script_dir.parent)

    db = PaperDB("references")
    cmd = sys.argv[1]

    if cmd == "index":
        print("Indexing local papers...")
        db.index_local_papers()
        print("\nEnriching with Semantic Scholar data...")
        for paper_id in list(db.papers.keys()):
            db.enrich_with_s2(paper_id)
            time.sleep(0.5)
        db.save()
        db.print_stats()

    elif cmd == "search":
        query = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "Vietnamese"
        print(f"Searching for: {query}\n")
        results = db.search(query)
        for paper in results[:10]:
            print(f"  [{paper.year}] {paper.title[:60]}...")
            print(f"    Authors: {', '.join(paper.authors[:3])}")
            print(f"    Citations: {paper.citation_count}, Venue: {paper.venue}")
            if paper.tldr:
                print(f"    TLDR: {paper.tldr[:100]}...")
            print()

    elif cmd == "cite":
        paper_id = sys.argv[2] if len(sys.argv) > 2 else ""
        if not paper_id:
            print("Usage: paper_db.py cite <paper_id>")
            return
        print(f"Citations for: {paper_id}\n")
        citations = db.get_citations(paper_id)
        for cite in citations[:15]:
            authors = [a["name"] for a in cite.get("authors", [])[:2]]
            print(f"  [{cite.get('year', '?')}] {cite.get('title', '?')[:60]}...")
            print(f"    Authors: {', '.join(authors)}")
            print(f"    Citations: {cite.get('citationCount', 0)}")
            print()

    elif cmd == "refs":
        paper_id = sys.argv[2] if len(sys.argv) > 2 else ""
        if not paper_id:
            print("Usage: paper_db.py refs <paper_id>")
            return
        print(f"References from: {paper_id}\n")
        refs = db.get_references(paper_id)
        for ref in refs[:15]:
            authors = [a["name"] for a in ref.get("authors", [])[:2]]
            print(f"  [{ref.get('year', '?')}] {ref.get('title', '?')[:60]}...")
            print(f"    Authors: {', '.join(authors)}")
            print(f"    Citations: {ref.get('citationCount', 0)}")
            print()

    elif cmd == "related":
        paper_id = sys.argv[2] if len(sys.argv) > 2 else ""
        if paper_id and paper_id in db.papers:
            paper = db.papers[paper_id]
            query = paper.title
        else:
            query = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "Vietnamese dependency parsing"
        print(f"Finding related papers for: {query[:50]}...\n")
        related = db.discover_related(query, limit=15)
        for p in related:
            authors = [a["name"] for a in p.get("authors", [])[:2]]
            print(f"  [{p.get('year', '?')}] {p.get('title', '?')[:60]}...")
            print(f"    Authors: {', '.join(authors)}")
            print(f"    Citations: {p.get('citationCount', 0)}, Venue: {p.get('venue', '?')}")
            ext_ids = p.get("externalIds", {})
            if ext_ids.get("ArXiv"):
                print(f"    arXiv: {ext_ids['ArXiv']}")
            print()

    elif cmd == "discover":
        print("Discovering new papers via citation network...\n")
        discovered = db.discover_via_citations(min_citations=10)
        print(f"Found {len(discovered)} new papers:\n")
        for p in discovered[:20]:
            authors = [a["name"] for a in p.get("authors", [])[:2]]
            print(f"  [{p.get('year', '?')}] {p.get('title', '?')[:60]}...")
            print(f"    Authors: {', '.join(authors)}")
            print(f"    Citations: {p.get('citationCount', 0)}")
            print(f"    Discovered via: {p.get('_discovered_via', '?')}")
            ext_ids = p.get("externalIds", {})
            if ext_ids.get("ArXiv"):
                print(f"    arXiv: {ext_ids['ArXiv']}")
            print()

    elif cmd == "graph":
        print("Generating citation graph...\n")
        graph = db.generate_graph()
        print(graph)

    elif cmd == "fetch":
        arxiv_id = sys.argv[2] if len(sys.argv) > 2 else ""
        if not arxiv_id:
            print("Usage: paper_db.py fetch <arxiv_id>")
            return
        print(f"Fetching paper: {arxiv_id}")
        fetch_arxiv_paper(arxiv_id, db)
        db.save()

    elif cmd == "stats":
        db.print_stats()

    else:
        print(f"Unknown command: {cmd}")
        print(__doc__)


if __name__ == "__main__":
    main()
|
|