# stochastic / arxiv_tool.py
# Author: Sonu Prasad
# Commit 822c114: initial commit
import logging
import re
import shutil
import urllib.request
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import arxiv

from config import config
@dataclass()
class ArxivPaper:
    """Metadata for one arXiv paper, plus where its PDF was saved locally (if anywhere)."""

    arxiv_id: str  # arXiv identifier as produced by the caller (e.g. ArxivTool.search)
    title: str  # paper title
    authors: list[str]  # author display names, in listed order
    abstract: str  # paper summary text
    pdf_url: str  # remote location of the PDF
    local_path: Optional[Path] = None  # set once the PDF has been downloaded; None until then
class ArxivTool:
    """Search arXiv and download paper PDFs into ``config.PAPERS_DIR``.

    Both public methods are best-effort: failures are logged and reported as
    an empty list / ``None`` rather than raised.
    """

    # New-style identifiers (2007+), e.g. "0704.0001", "1706.03762", "2301.12345v2".
    _NEW_STYLE_ID = re.compile(r"^\d{4}\.\d{4,5}(?:v\d+)?$")
    # Old-style identifiers (pre-2007), e.g. "hep-th/9901001", "math.GT/0309136v1".
    _OLD_STYLE_ID = re.compile(r"^[a-z]+(?:-[a-z]+)*(?:\.[A-Za-z-]+)?/\d{7}(?:v\d+)?$")
    _log = logging.getLogger(__name__)

    def search(self, query: str, max_results: int = 5) -> "list[ArxivPaper]":
        """Return papers matching *query*.

        If *query* is an arXiv identifier (optionally prefixed with ``arXiv:``),
        that exact paper is looked up; otherwise a relevance-sorted search
        returning at most *max_results* papers is run.

        Returns:
            The matching papers, or an empty list if the lookup failed for any
            reason (the error is logged).
        """
        try:
            arxiv_id = self._parse_arxiv_id(query)
            if arxiv_id is not None:
                search = arxiv.Search(id_list=[arxiv_id])
            else:
                search = arxiv.Search(
                    query=query,
                    max_results=max_results,
                    sort_by=arxiv.SortCriterion.Relevance,
                )
            # Client.results() supersedes the deprecated Search.results().
            return [
                ArxivPaper(
                    arxiv_id=self._short_id(result.entry_id),
                    title=result.title,
                    authors=[author.name for author in result.authors],
                    abstract=result.summary,
                    pdf_url=result.pdf_url,
                )
                for result in arxiv.Client().results(search)
            ]
        except Exception:
            self._log.exception("arXiv search failed for query %r", query)
            return []

    def download(self, paper: "ArxivPaper", *, timeout: float = 60.0) -> Optional[Path]:
        """Download *paper*'s PDF into ``config.PAPERS_DIR`` and return its path.

        A file already present at the target path is reused without any
        network access. On success ``paper.local_path`` is set to the path.

        Args:
            paper: The paper whose ``pdf_url`` should be fetched.
            timeout: Timeout in seconds passed to ``urllib.request.urlopen``.

        Returns:
            The local PDF path, or ``None`` if the download failed (logged).
        """
        try:
            config.PAPERS_DIR.mkdir(parents=True, exist_ok=True)
            filepath = config.PAPERS_DIR / self._pdf_filename(paper.arxiv_id, paper.title)
            # _fetch only ever renames a *complete* download into filepath, so an
            # existing file here is never a truncated leftover.
            if not filepath.exists():
                if not paper.pdf_url:
                    self._log.warning("No PDF URL for arXiv paper %s", paper.arxiv_id)
                    return None
                self._fetch(paper.pdf_url, filepath, timeout)
            paper.local_path = filepath
            return filepath
        except Exception:
            self._log.exception(
                "Failed to download arXiv paper %s", getattr(paper, "arxiv_id", "?")
            )
            return None

    @classmethod
    def _parse_arxiv_id(cls, query: str) -> Optional[str]:
        """Return *query* as a bare arXiv identifier, or ``None`` if it is not one."""
        candidate = query.strip()
        if candidate.lower().startswith("arxiv:"):
            candidate = candidate[len("arxiv:"):]
        if cls._NEW_STYLE_ID.match(candidate) or cls._OLD_STYLE_ID.match(candidate):
            return candidate
        return None

    @staticmethod
    def _short_id(entry_id: str) -> str:
        """Return the identifier part of an abs URL, keeping any archive prefix.

        ``http://arxiv.org/abs/hep-th/9901001v1`` -> ``hep-th/9901001v1``; a plain
        ``split('/')[-1]`` would drop the ``hep-th/`` part.
        """
        _, sep, tail = entry_id.partition("/abs/")
        return tail if sep else entry_id.rsplit("/", 1)[-1]

    @staticmethod
    def _pdf_filename(arxiv_id: str, title: str) -> str:
        """Build a filesystem-safe ``<id>_<title>.pdf`` name.

        Path separators in the identifier (old-style IDs contain ``/``) become
        ``_`` so the file lands directly in the papers directory. The title is
        reduced to alphanumerics, spaces, ``-`` and ``_`` and cut to 50 chars.
        """
        safe_id = arxiv_id.replace("/", "_").replace("\\", "_")
        safe_title = "".join(c for c in title if c.isalnum() or c in " -_")[:50]
        return f"{safe_id}_{safe_title}.pdf"

    @staticmethod
    def _fetch(url: str, dest: Path, timeout: float) -> None:
        """Download *url* to *dest* without ever leaving a truncated *dest* behind.

        Data is streamed into ``<dest>.part`` and renamed into place only after
        the transfer completes; the partial file is removed on any failure.
        """
        partial = dest.with_name(dest.name + ".part")
        try:
            with urllib.request.urlopen(url, timeout=timeout) as response:
                with open(partial, "wb") as out:
                    shutil.copyfileobj(response, out)
            partial.replace(dest)
        finally:
            # No-op after a successful rename; removes leftovers after a failure.
            partial.unlink(missing_ok=True)


# Shared module-level instance; ArxivTool keeps no per-instance state.
arxiv_tool = ArxivTool()