| |
| """Paper research helper for ML Intern Codex. |
| |
| Emulates the useful parts of upstream ml-intern's hf_papers tool with public |
| HTTP APIs: Hugging Face Papers, arXiv/ar5iv HTML, and Semantic Scholar. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import html.parser |
| import json |
| import os |
| import re |
| import sys |
| import urllib.error |
| import urllib.parse |
| import urllib.request |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from typing import Any |
|
|
|
|
# Base URLs of the public HTTP services this helper queries.
HF_API: str = "https://huggingface.co/api"
ARXIV_HTML: str = "https://arxiv.org/html"
AR5IV_HTML: str = "https://ar5iv.labs.arxiv.org/html"
S2_API: str = "https://api.semanticscholar.org"

# Character cap applied to the text of a single section printed by read_paper.
MAX_SECTION_TEXT_LEN: int = 8000
|
|
|
|
def request_json(url: str, params: dict[str, Any] | None = None, method: str = "GET", body: dict[str, Any] | None = None) -> Any:
    """Perform an HTTP request and return the decoded JSON response.

    Args:
        url: Endpoint to call.
        params: Optional query parameters; entries whose value is ``None``
            are left out of the query string.
        method: HTTP verb passed to ``urllib.request.Request``.
        body: Optional payload, sent JSON-encoded with a matching
            ``Content-Type`` header.

    Returns:
        Whatever ``json.loads`` produces for the UTF-8 response body.

    Raises:
        RuntimeError: when the server answers with an HTTP error status; the
            message carries the status code and the first 500 characters of
            the error body.
    """
    if params:
        query = urllib.parse.urlencode(
            {key: value for key, value in params.items() if value is not None}
        )
        url = f"{url}?{query}"

    headers = {"User-Agent": "ml-intern-codex/0.1"}
    payload = None
    if body is not None:
        payload = json.dumps(body).encode("utf-8")
        headers["Content-Type"] = "application/json"

    # The optional API key is only attached to Semantic Scholar URLs.
    api_key = os.environ.get("S2_API_KEY")
    if api_key and url.startswith(S2_API):
        headers["x-api-key"] = api_key

    req = urllib.request.Request(url, data=payload, headers=headers, method=method)
    try:
        with urllib.request.urlopen(req, timeout=30) as resp:
            raw = resp.read()
    except urllib.error.HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"{url} returned HTTP {exc.code}: {detail[:500]}") from exc
    return json.loads(raw.decode("utf-8"))
|
|
|
|
def request_text(url: str) -> str:
    """Fetch ``url`` with a GET request and return the body as text.

    The body is decoded as UTF-8; undecodable bytes are replaced rather than
    raising. Network and HTTP errors from ``urllib`` propagate to the caller.
    """
    headers = {"User-Agent": "ml-intern-codex/0.1"}
    req = urllib.request.Request(url, headers=headers)
    with urllib.request.urlopen(req, timeout=30) as resp:
        raw = resp.read()
    return raw.decode("utf-8", errors="replace")
|
|
|
|
def arxiv_s2_id(arxiv_id: str) -> str:
    """Return ``arxiv_id`` prefixed with ``ARXIV:``, as used in the Semantic Scholar URLs built in this file."""
    return "ARXIV:{}".format(arxiv_id)
|
|
|
|
def truncate(text: str, limit: int) -> str:
    """Collapse whitespace in ``text`` and shorten it to ``limit`` characters.

    Every run of whitespace becomes a single space and the ends are trimmed.
    If the result is longer than ``limit`` it is cut at ``limit`` characters,
    trailing whitespace is removed and ``"..."`` is appended.
    """
    normalized = re.sub(r"\s+", " ", text).strip()
    if len(normalized) <= limit:
        return normalized
    return normalized[:limit].rstrip() + "..."
|
|
|
|
def paper_arxiv_id(paper: dict[str, Any]) -> str:
    """Return the arXiv identifier recorded on a paper record, or ``""``.

    Lookup order: ``externalIds["ArXiv"]`` (or the ``external_ids`` spelling),
    then the ``arxiv_id`` key, then the ``id`` key.

    Args:
        paper: A paper dictionary as returned by one of the APIs.

    Returns:
        The first non-empty identifier found. An empty string is returned when
        none is present, including when ``id`` exists but is ``None`` or empty,
        so callers always receive a value matching the ``str`` annotation.
    """
    external = paper.get("externalIds") or paper.get("external_ids") or {}
    # The trailing `or ""` covers an explicit `"id": None`, which a plain
    # `.get("id", "")` default would pass through unchanged.
    return external.get("ArXiv") or paper.get("arxiv_id") or paper.get("id") or ""
|
|
|
|
def format_hf_paper(paper: dict[str, Any], idx: int) -> str:
    """Render one Hugging Face paper record as a Markdown entry.

    When the record holds a nested ``"paper"`` dictionary, fields are read
    from it, falling back to the outer record for title and id.

    Args:
        paper: The paper record.
        idx: Number shown in the entry heading.

    Returns:
        Newline-joined lines: a heading, then the arXiv id and URL, the
        publication date, the GitHub URL and a summary capped at 500
        characters, each only when available.
    """
    inner = paper.get("paper")
    if not isinstance(inner, dict):
        inner = paper

    heading = inner.get("title") or paper.get("title") or "(untitled)"
    paper_id = inner.get("id") or inner.get("arxivId") or paper.get("id") or ""
    abstract = inner.get("summary") or inner.get("abstract") or ""

    out = [f"### {idx}. {heading}"]
    if paper_id:
        out += [f"arxiv_id: {paper_id}", f"https://arxiv.org/abs/{paper_id}"]
    for label, key in (("Published", "publishedAt"), ("GitHub", "githubUrl")):
        value = inner.get(key)
        if value:
            out.append(f"{label}: {value}")
    if abstract:
        out.append(truncate(abstract, 500))
    return "\n".join(out)
|
|
|
|
def format_s2_paper(paper: dict[str, Any], idx: int) -> str:
    """Render one Semantic Scholar paper record as a Markdown entry.

    Args:
        paper: The paper record.
        idx: Number shown in the entry heading.

    Returns:
        Newline-joined lines: a heading, a ``" | "``-separated metadata line
        (year, citations, optional venue and arXiv id), then the arXiv URL
        and the TL;DR text when those are available.
    """
    arxiv_id = paper_arxiv_id(paper)
    summary = (paper.get("tldr") or {}).get("text", "")
    venue = paper.get("venue") or ""
    heading = paper.get("title") or "(untitled)"
    year = paper.get("year") or "?"
    citations = paper.get("citationCount", 0)

    meta = [f"Year: {year}", f"Citations: {citations}"]
    if venue:
        meta.append(f"Venue: {venue}")
    if arxiv_id:
        meta.append(f"arxiv_id: {arxiv_id}")

    out = [f"### {idx}. {heading}", " | ".join(meta)]
    if arxiv_id:
        out.append(f"https://arxiv.org/abs/{arxiv_id}")
    if summary:
        out.append(f"TL;DR: {summary}")
    return "\n".join(out)
|
|
|
|
| class ArxivHTMLParser(html.parser.HTMLParser): |
| def __init__(self) -> None: |
| super().__init__() |
| self.capture_title = False |
| self.capture_abstract = False |
| self.capture_heading = False |
| self.capture_paragraph = False |
| self.title_parts: list[str] = [] |
| self.abstract_parts: list[str] = [] |
| self.sections: list[dict[str, Any]] = [] |
| self.current_heading: list[str] = [] |
| self.current_paragraph: list[str] = [] |
| self.current_section: dict[str, Any] | None = None |
|
|
| def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: |
| classes = dict(attrs).get("class", "") or "" |
| if tag == "h1" and "ltx_title" in classes: |
| self.capture_title = True |
| elif tag == "div" and "ltx_abstract" in classes: |
| self.capture_abstract = True |
| elif tag in {"h2", "h3"} and "ltx_title" in classes: |
| self.capture_heading = True |
| self.current_heading = [] |
| elif tag == "p": |
| self.capture_paragraph = True |
| self.current_paragraph = [] |
|
|
| def handle_endtag(self, tag: str) -> None: |
| if tag == "h1" and self.capture_title: |
| self.capture_title = False |
| elif tag == "div" and self.capture_abstract: |
| self.capture_abstract = False |
| elif tag in {"h2", "h3"} and self.capture_heading: |
| heading = truncate(" ".join(self.current_heading), 500) |
| section_id = "" |
| match = re.match(r"^([A-Z]?\d+(?:\.\d+)*)\s", heading) |
| if match: |
| section_id = match.group(1) |
| self.current_section = {"id": section_id, "title": heading, "text": ""} |
| self.sections.append(self.current_section) |
| self.capture_heading = False |
| elif tag == "p" and self.capture_paragraph: |
| paragraph = truncate(" ".join(self.current_paragraph), 4000) |
| if paragraph: |
| if self.capture_abstract: |
| self.abstract_parts.append(paragraph) |
| elif self.current_section is not None: |
| existing = self.current_section["text"] |
| self.current_section["text"] = (existing + "\n\n" + paragraph).strip() |
| self.capture_paragraph = False |
|
|
| def handle_data(self, data: str) -> None: |
| text = data.strip() |
| if not text: |
| return |
| if self.capture_title: |
| self.title_parts.append(text.removeprefix("Title:")) |
| if self.capture_heading: |
| self.current_heading.append(text) |
| if self.capture_paragraph: |
| self.current_paragraph.append(text) |
|
|
|
|
def parse_arxiv_html(html_text: str) -> dict[str, Any]:
    """Parse an arXiv HTML page into its title, abstract and sections.

    Returns:
        A dict with ``"title"`` (capped at 500 characters), ``"abstract"``
        (capped at 2000 characters) and ``"sections"``, the list of section
        dicts collected by ``ArxivHTMLParser``.
    """
    extractor = ArxivHTMLParser()
    extractor.feed(html_text)
    title = truncate(" ".join(extractor.title_parts), 500)
    abstract = truncate(" ".join(extractor.abstract_parts), 2000)
    return {"title": title, "abstract": abstract, "sections": extractor.sections}
|
|
|
|
def op_trending(args: argparse.Namespace) -> str:
    """List Hugging Face daily papers, optionally filtered by a keyword.

    With ``--query`` set, three times ``--limit`` papers are requested and
    only those whose JSON representation contains the query
    (case-insensitively) are kept. At most ``--limit`` entries are printed.
    """
    fetch_count = args.limit * 3 if args.query else args.limit
    query_params: dict[str, Any] = {"limit": fetch_count}
    if args.date:
        query_params["date"] = args.date

    candidates = request_json(f"{HF_API}/daily_papers", query_params)
    if args.query:
        needle = args.query.lower()
        candidates = list(
            filter(lambda item: needle in json.dumps(item, ensure_ascii=False).lower(), candidates)
        )

    report = ["# Trending Papers"]
    for rank, item in enumerate(candidates[: args.limit], 1):
        report += [format_hf_paper(item, rank), ""]
    return "\n".join(report)
|
|
|
|
def op_search(args: argparse.Namespace) -> str:
    """Search papers by keyword and return a Markdown listing.

    Plain relevance queries go to the Hugging Face paper search. As soon as a
    date bound, category, minimum citation count or non-relevance sort order
    is requested, the Semantic Scholar bulk search endpoint is used instead.

    Raises:
        SystemExit: if ``--query`` was not given.
    """
    if not args.query:
        raise SystemExit("search requires --query")

    custom_sort = args.sort_by != "relevance"
    wants_s2 = any((args.date_from, args.date_to, args.categories, args.min_citations, custom_sort))

    if not wants_s2:
        hits = request_json(f"{HF_API}/papers/search", {"q": args.query, "limit": args.limit})
        report = [f"# Papers matching '{args.query}' (Hugging Face Papers)"]
        for rank, hit in enumerate(hits[: args.limit], 1):
            report += [format_hf_paper(hit, rank), ""]
        return "\n".join(report)

    s2_params: dict[str, Any] = {
        "query": args.query,
        "limit": args.limit,
        "fields": "title,externalIds,year,citationCount,tldr,venue,publicationDate",
    }
    if args.date_from or args.date_to:
        # Open-ended ranges keep the colon: "from:" or ":to".
        s2_params["publicationDateOrYear"] = f"{args.date_from or ''}:{args.date_to or ''}"
    if args.categories:
        s2_params["fieldsOfStudy"] = args.categories
    if args.min_citations:
        s2_params["minCitationCount"] = str(args.min_citations)
    if custom_sort:
        s2_params["sort"] = f"{args.sort_by}:desc"

    payload = request_json(f"{S2_API}/graph/v1/paper/search/bulk", s2_params)
    report = [f"# Papers matching '{args.query}' (Semantic Scholar)"]
    for rank, hit in enumerate(payload.get("data", [])[: args.limit], 1):
        report += [format_s2_paper(hit, rank), ""]
    return "\n".join(report)
|
|
|
|
def op_paper_details(args: argparse.Namespace) -> str:
    """Fetch one paper from the Hugging Face papers API and format it.

    The output starts with the title and the Hugging Face / arXiv links,
    followed by any of ``publishedAt``, ``submittedOnDailyAt`` and
    ``githubUrl`` that are set, then the abstract and AI summary sections
    when present.

    Raises:
        SystemExit: if ``--arxiv-id`` was not given.
    """
    arxiv_id = args.arxiv_id
    if not arxiv_id:
        raise SystemExit("paper_details requires --arxiv-id")

    paper = request_json(f"{HF_API}/papers/{arxiv_id}")
    report = [
        f"# {paper.get('title', arxiv_id)}",
        f"https://huggingface.co/papers/{arxiv_id}",
        f"https://arxiv.org/abs/{arxiv_id}",
    ]
    report.extend(
        f"{key}: {paper[key]}"
        for key in ("publishedAt", "submittedOnDailyAt", "githubUrl")
        if paper.get(key)
    )
    if paper.get("summary"):
        report += ["", "## Abstract", paper["summary"]]
    if paper.get("ai_summary"):
        report += ["", "## AI Summary", str(paper["ai_summary"])]
    return "\n".join(report)
|
|
|
|
| def op_read_paper(args: argparse.Namespace) -> str: |
| if not args.arxiv_id: |
| raise SystemExit("read_paper requires --arxiv-id") |
| parsed = None |
| for base in (ARXIV_HTML, AR5IV_HTML): |
| try: |
| parsed = parse_arxiv_html(request_text(f"{base}/{args.arxiv_id}")) |
| if parsed["sections"]: |
| break |
| except Exception: |
| parsed = None |
| if not parsed or not parsed["sections"]: |
| return op_paper_details(args) + f"\n\nHTML sections unavailable. PDF: https://arxiv.org/pdf/{args.arxiv_id}" |
| if not args.section: |
| lines = [f"# {parsed['title'] or args.arxiv_id}", f"https://arxiv.org/abs/{args.arxiv_id}", "", "## Abstract", parsed["abstract"], "", "## Sections"] |
| for section in parsed["sections"]: |
| preview = truncate(section.get("text", ""), 280) |
| lines.append(f"- {section['title']}: {preview}") |
| return "\n".join(lines) |
| wanted = args.section.lower() |
| selected = None |
| for section in parsed["sections"]: |
| if section["id"].lower() == wanted or wanted in section["title"].lower(): |
| selected = section |
| break |
| if not selected: |
| available = "\n".join(f"- {section['title']}" for section in parsed["sections"]) |
| raise SystemExit(f"section not found. Available sections:\n{available}") |
| return "\n".join([ |
| f"# {selected['title']}", |
| f"https://arxiv.org/abs/{args.arxiv_id}", |
| "", |
| truncate(selected.get("text", ""), MAX_SECTION_TEXT_LEN), |
| ]) |
|
|
|
|
def format_citation(entry: dict[str, Any]) -> str:
    """Render one citation or reference entry as a Markdown bullet.

    The paper is taken from ``citingPaper`` or, failing that, ``citedPaper``.
    The bullet shows title, year and citation count, then the arXiv id and an
    ``[influential]`` marker when applicable. If the entry carries contexts,
    the first one (capped at 220 characters) is added on a second line.
    """
    paper = entry.get("citingPaper") or entry.get("citedPaper") or {}
    heading = paper.get("title") or "(untitled)"
    year = paper.get("year") or "?"
    citations = paper.get("citationCount", 0)

    fragments = [f"- {heading} ({year}, {citations} cites)"]
    arxiv_id = paper_arxiv_id(paper)
    if arxiv_id:
        fragments.append(f" arxiv:{arxiv_id}")
    if entry.get("isInfluential"):
        fragments.append(" [influential]")
    snippets = entry.get("contexts") or []
    if snippets:
        fragments.append(f"\n > {truncate(snippets[0], 220)}")
    return "".join(fragments)
|
|
|
|
def op_citation_graph(args: argparse.Namespace) -> str:
    """List the references and/or citations of a paper from Semantic Scholar.

    ``--direction`` selects which lists are requested; the requests run
    concurrently. Sections are reported in a fixed order (References before
    Citations) regardless of which request finishes first, so repeated runs
    produce the same layout. A failing request is reported as an ``Error:``
    line under its own heading instead of aborting the other section.

    Raises:
        SystemExit: if ``--arxiv-id`` was not given.
    """
    if not args.arxiv_id:
        raise SystemExit("citation_graph requires --arxiv-id")

    fields = "title,externalIds,year,citationCount,influentialCitationCount,contexts,intents,isInfluential"
    params = {"fields": fields, "limit": args.limit}
    paper_ref = arxiv_s2_id(args.arxiv_id)

    paths: dict[str, str] = {}
    if args.direction in {"references", "both"}:
        paths["References"] = f"/graph/v1/paper/{paper_ref}/references"
    if args.direction in {"citations", "both"}:
        paths["Citations"] = f"/graph/v1/paper/{paper_ref}/citations"

    lines = [f"# Citation Graph for {args.arxiv_id}", f"https://arxiv.org/abs/{args.arxiv_id}"]
    with ThreadPoolExecutor(max_workers=2) as pool:
        pending = {
            name: pool.submit(request_json, f"{S2_API}{path}", params)
            for name, path in paths.items()
        }
    # Leaving the ``with`` block waits for every request; walking the dict
    # afterwards reports results in submission order, not completion order.
    for name, future in pending.items():
        lines.append("")
        lines.append(f"## {name}")
        try:
            # ``or []`` also covers a response whose "data" is null.
            for entry in future.result().get("data") or []:
                lines.append(format_citation(entry))
        except Exception as exc:
            lines.append(f"Error: {exc}")
    return "\n".join(lines)
|
|
|
|
def op_resources(args: argparse.Namespace) -> str:
    """List Hugging Face datasets, models and/or collections linked to a paper.

    ``args.operation`` selects which resource kinds are queried
    (``find_datasets``, ``find_models``, ``find_collections`` or
    ``find_all_resources``). The requests run concurrently. Sections are
    reported in a fixed order (Datasets, Models, Collections) regardless of
    which request finishes first, so repeated runs produce the same layout.
    A failing request is reported as an ``Error:`` line under its own heading.

    Raises:
        SystemExit: if ``--arxiv-id`` was not given.
    """
    if not args.arxiv_id:
        raise SystemExit(f"{args.operation} requires --arxiv-id")

    sort = {"downloads": "downloads", "likes": "likes", "trending": "trendingScore"}[args.sort]
    calls: dict[str, tuple[str, dict[str, Any]]] = {}
    if args.operation in {"find_datasets", "find_all_resources"}:
        calls["Datasets"] = (f"{HF_API}/datasets", {"filter": f"arxiv:{args.arxiv_id}", "limit": args.limit, "sort": sort, "direction": -1})
    if args.operation in {"find_models", "find_all_resources"}:
        calls["Models"] = (f"{HF_API}/models", {"filter": f"arxiv:{args.arxiv_id}", "limit": args.limit, "sort": sort, "direction": -1})
    if args.operation in {"find_collections", "find_all_resources"}:
        calls["Collections"] = (f"{HF_API}/collections", {"paper": args.arxiv_id})

    lines = [f"# Resources linked to paper {args.arxiv_id}", f"https://huggingface.co/papers/{args.arxiv_id}"]
    with ThreadPoolExecutor(max_workers=3) as pool:
        pending = {
            name: pool.submit(request_json, url, params)
            for name, (url, params) in calls.items()
        }
    # Leaving the ``with`` block waits for every request; walking the dict
    # afterwards reports results in submission order, not completion order.
    for name, future in pending.items():
        lines.append("")
        lines.append(f"## {name}")
        try:
            items = future.result()
            for item in items[: args.limit]:
                repo_id = item.get("id") or item.get("modelId") or item.get("slug") or item.get("title")
                meta = []
                downloads = item.get("downloads")
                if downloads is not None:
                    meta.append(f"downloads={downloads}")
                likes = item.get("likes")
                if likes is not None:
                    meta.append(f"likes={likes}")
                lines.append(f"- {repo_id}" + (f" ({', '.join(meta)})" if meta else ""))
        except Exception as exc:
            lines.append(f"Error: {exc}")
    return "\n".join(lines)
|
|
|
|
def op_snippet_search(args: argparse.Namespace) -> str:
    """Search Semantic Scholar for text snippets matching a query.

    Optional date bounds, categories and a minimum citation count are passed
    through as filters. Each hit is printed with its paper title, arXiv id,
    section name and the snippet text (capped at 400 characters), when those
    are present.

    Raises:
        SystemExit: if ``--query`` was not given.
    """
    if not args.query:
        raise SystemExit("snippet_search requires --query")

    query_params: dict[str, Any] = {
        "query": args.query,
        "limit": args.limit,
        "fields": "title,externalIds,year,citationCount",
    }
    if args.date_from or args.date_to:
        query_params["publicationDateOrYear"] = f"{args.date_from or ''}:{args.date_to or ''}"
    if args.categories:
        query_params["fieldsOfStudy"] = args.categories
    if args.min_citations:
        query_params["minCitationCount"] = str(args.min_citations)

    payload = request_json(f"{S2_API}/graph/v1/snippet/search", query_params)
    hits = payload.get("data", [])[: args.limit]

    report = [f"# Snippet Search: {args.query}"]
    for rank, hit in enumerate(hits, 1):
        paper = hit.get("paper") or {}
        snippet = hit.get("snippet") or {}
        report.append(f"### {rank}. {paper.get('title', '(untitled)')}")
        arxiv_id = paper_arxiv_id(paper)
        if arxiv_id:
            report.append(f"arxiv:{arxiv_id}")
        section_name = snippet.get("section")
        if section_name:
            report.append(f"Section: {section_name}")
        excerpt = snippet.get("text")
        if excerpt:
            report.append(f"> {truncate(excerpt, 400)}")
        report.append("")
    return "\n".join(report)
|
|
|
|
def op_recommend(args: argparse.Namespace) -> str:
    """Recommend papers via the Semantic Scholar recommendations API.

    With ``--arxiv-id`` the single-paper endpoint is used. Otherwise the
    comma-separated ``--positive-ids`` / ``--negative-ids`` lists are posted
    to the multi-paper endpoint. When both ``--arxiv-id`` and
    ``--positive-ids`` are given, the single-paper endpoint wins.

    Raises:
        SystemExit: if neither ``--arxiv-id`` nor ``--positive-ids`` is given.
    """
    if not (args.arxiv_id or args.positive_ids):
        raise SystemExit("recommend requires --arxiv-id or --positive-ids")

    fields = "title,externalIds,year,citationCount,tldr,venue"

    def to_s2_ids(csv: str) -> list[str]:
        # Split a comma-separated list, dropping blanks, and prefix each id.
        return [arxiv_s2_id(part.strip()) for part in csv.split(",") if part.strip()]

    if args.arxiv_id or not args.positive_ids:
        payload = request_json(
            f"{S2_API}/recommendations/v1/papers/forpaper/{arxiv_s2_id(args.arxiv_id)}",
            {"fields": fields, "limit": args.limit, "from": "recent"},
        )
    else:
        liked = to_s2_ids(args.positive_ids)
        disliked = to_s2_ids(args.negative_ids)
        payload = request_json(
            f"{S2_API}/recommendations/v1/papers/",
            {"fields": fields, "limit": args.limit},
            method="POST",
            body={"positivePaperIds": liked, "negativePaperIds": disliked},
        )

    report = ["# Recommended Papers"]
    for rank, paper in enumerate(payload.get("recommendedPapers", [])[: args.limit], 1):
        report += [format_s2_paper(paper, rank), ""]
    return "\n".join(report)
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser: one positional operation plus options.

    Options are registered from a table so that each flag and its settings
    sit on a single line; registration order matches the help output order.
    """
    operations = [
        "trending",
        "search",
        "paper_details",
        "read_paper",
        "citation_graph",
        "snippet_search",
        "recommend",
        "find_datasets",
        "find_models",
        "find_collections",
        "find_all_resources",
    ]
    option_specs: tuple[tuple[str, dict[str, Any]], ...] = (
        ("--query", {}),
        ("--arxiv-id", {}),
        ("--section", {}),
        ("--direction", {"choices": ["citations", "references", "both"], "default": "both"}),
        ("--date", {}),
        ("--date-from", {"default": ""}),
        ("--date-to", {"default": ""}),
        ("--categories", {}),
        ("--min-citations", {"type": int}),
        ("--sort-by", {"choices": ["relevance", "citationCount", "publicationDate"], "default": "relevance"}),
        ("--positive-ids", {"default": ""}),
        ("--negative-ids", {"default": ""}),
        ("--sort", {"choices": ["downloads", "likes", "trending"], "default": "downloads"}),
        ("--limit", {"type": int, "default": 10}),
    )

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("operation", choices=operations)
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser
|
|
|
|
def main() -> int:
    """Parse the command line, run the requested operation and print its output.

    ``--limit`` is clamped to the range 1..50 before dispatch.

    Returns:
        0 on success; 1 when the operation raises an ``Exception``, in which
        case the error is written to stderr. ``SystemExit`` raised by an
        operation is not caught here and propagates to the caller.
    """
    args = build_parser().parse_args()
    args.limit = max(1, min(args.limit, 50))

    dispatch = {
        "trending": op_trending,
        "search": op_search,
        "paper_details": op_paper_details,
        "read_paper": op_read_paper,
        "citation_graph": op_citation_graph,
        "snippet_search": op_snippet_search,
        "recommend": op_recommend,
    }
    # All four resource lookups share one handler, which branches on args.operation.
    dispatch.update(
        dict.fromkeys(
            ("find_datasets", "find_models", "find_collections", "find_all_resources"),
            op_resources,
        )
    )

    try:
        print(dispatch[args.operation](args))
    except Exception as exc:
        print(f"Error running papers {args.operation}: {exc}", file=sys.stderr)
        return 1
    return 0
|
|
|
|
# Script entry point: the return value of main() becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|