# hf-papers / hf_papers_tool.py

from __future__ import annotations

import json
import os
import re
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.parse import urlencode
from urllib.request import Request, urlopen

DEFAULT_LIMIT = 20
DEFAULT_TIMEOUT_SEC = 30
MAX_API_LIMIT = 100


def _load_token() -> str | None:
    # Check for request-scoped token first (when running as MCP server)
    try:
        from fast_agent.mcp.auth.context import request_bearer_token

        ctx_token = request_bearer_token.get()
        if ctx_token:
            return ctx_token
    except ImportError:
        pass
    return None


def _normalize_date_param(value: str | None) -> str | None:
    if not value:
        return None
    return value.strip()


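# The HF_ENDPOINT environment variable (honored below) lets deployments point
# at a mirror or staging host instead of the public hub; when it is unset,
# requests go to https://huggingface.co.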
def _build_url(params: dict[str, Any]) -> str:
    base = os.getenv("HF_ENDPOINT", "https://huggingface.co").rstrip("/")
    query = urlencode({k: v for k, v in params.items() if v is not None}, doseq=True)
    return f"{base}/api/daily_papers?{query}" if query else f"{base}/api/daily_papers"


def _request_json(url: str) -> list[dict[str, Any]]:
    headers = {"Accept": "application/json"}
    token = _load_token()
    if token:
        headers["Authorization"] = f"Bearer {token}"
    request = Request(url, headers=headers, method="GET")
    try:
        with urlopen(request, timeout=DEFAULT_TIMEOUT_SEC) as response:
            raw = response.read()
    except HTTPError as exc:
        error_body = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HF API error {exc.code} for {url}: {error_body}") from exc
    except URLError as exc:
        raise RuntimeError(f"HF API request failed for {url}: {exc}") from exc
    payload = json.loads(raw)
    if not isinstance(payload, list):
        raise RuntimeError("Unexpected response shape from /api/daily_papers")
    return payload


def _extract_search_blob(item: dict[str, Any]) -> str:
    paper = item.get("paper") or {}
    authors = paper.get("authors") or []
    author_names = [a.get("name", "") for a in authors if isinstance(a, dict)]
    ai_keywords = paper.get("ai_keywords") or []
    if isinstance(ai_keywords, list):
        ai_keywords_text = " ".join(str(k) for k in ai_keywords)
    else:
        ai_keywords_text = str(ai_keywords)
    parts = [
        item.get("title"),
        item.get("summary"),
        paper.get("title"),
        paper.get("summary"),
        paper.get("ai_summary"),
        ai_keywords_text,
        " ".join(author_names),
        paper.get("id"),
        paper.get("projectPage"),
        paper.get("githubRepo"),
    ]
    text = " ".join(str(part) for part in parts if part)
    return text.lower()


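# Matching uses AND semantics over whitespace-separated tokens: every token in
# the query must appear somewhere in the lowercased blob, in any order, so a
# query like "video diffusion" (an illustrative example) also matches a paper
# titled "Diffusion Models for Video". An empty query matches every entry.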
def _matches_query(item: dict[str, Any], query: str) -> bool:
    tokens = [t for t in re.split(r"\s+", query.strip().lower()) if t]
    if not tokens:
        return True
    haystack = _extract_search_blob(item)
    return all(token in haystack for token in tokens)


def hf_papers_search(
    query: str | None = None,
    *,
    date: str | None = None,
    week: str | None = None,
    month: str | None = None,
    submitter: str | None = None,
    sort: str | None = None,
    limit: int | None = None,
    page: int | None = None,
    max_pages: int | None = None,
    api_limit: int | None = None,
) -> dict[str, Any]:
    """
    Search Hugging Face Daily Papers with optional local filtering.

    Args:
        query: Case-insensitive keyword search across title, summary, authors,
            AI summary/keywords, project page, repo link, and paper id.
        date: ISO date (YYYY-MM-DD).
        week: ISO week (YYYY-Www).
        month: ISO month (YYYY-MM).
        submitter: HF username of the submitter.
        sort: "publishedAt" or "trending".
        limit: Max results to return after filtering (default 20).
        page: Page index for the API (default 0).
        max_pages: Number of pages to fetch for local filtering (default 1).
        api_limit: Page size for the API (default 50, max 100).

    Returns:
        dict with query metadata and list of daily paper entries.
    """
    resolved_limit = DEFAULT_LIMIT if limit is None else max(int(limit), 1)
    start_page = max(int(page or 0), 0)
    pages_to_fetch = max(int(max_pages or 1), 1)
    per_page = 50 if api_limit is None else max(int(api_limit), 1)
    per_page = min(per_page, MAX_API_LIMIT)
    params_base: dict[str, Any] = {
        "date": _normalize_date_param(date),
        "week": _normalize_date_param(week),
        "month": _normalize_date_param(month),
        "submitter": submitter.strip() if submitter else None,
        "sort": sort.strip() if sort else None,
        "limit": per_page,
    }
    results: list[dict[str, Any]] = []
    pages_fetched = 0
    for page_index in range(start_page, start_page + pages_to_fetch):
        params = {**params_base, "p": page_index}
        url = _build_url(params)
        payload = _request_json(url)
        pages_fetched += 1
        if query:
            filtered = [item for item in payload if _matches_query(item, query)]
        else:
            filtered = payload
        results.extend(filtered)
        if len(results) >= resolved_limit:
            break
    return {
        "query": query,
        "params": {
            **{k: v for k, v in params_base.items() if v is not None},
            "page": start_page,
            "max_pages": pages_fetched,
            "api_limit": per_page,
        },
        "returned": min(len(results), resolved_limit),
        "data": results[:resolved_limit],
    }


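if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the tool's public
    # surface: it assumes network access to huggingface.co (or whatever
    # HF_ENDPOINT points at) and uses an illustrative query.
    result = hf_papers_search("diffusion", limit=3)
    print(f"returned={result['returned']} params={result['params']}")
    for entry in result["data"]:
        paper = entry.get("paper") or {}
        print("-", paper.get("title") or entry.get("title"))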