Spaces:
Sleeping
Sleeping
| """AI/ML paper pipeline. | |
| Fetches papers from HuggingFace Daily Papers + arXiv, enriches with | |
| HF ecosystem metadata, and writes to the database. | |
| """ | |
| import logging | |
| import re | |
| import time | |
| from datetime import datetime, timedelta, timezone | |
| import arxiv | |
| import requests | |
| from src.config import ( | |
| ARXIV_LARGE_CATS, | |
| ARXIV_SMALL_CATS, | |
| EXCLUDE_RE, | |
| GITHUB_URL_RE, | |
| HF_API, | |
| HF_MAX_AGE_DAYS, | |
| INCLUDE_RE, | |
| MAX_ABSTRACT_CHARS_AIML, | |
| ) | |
| from src.db import create_run, finish_run, insert_papers | |
| log = logging.getLogger(__name__) | |
# ---------------------------------------------------------------------------
# HuggingFace API
# ---------------------------------------------------------------------------
def fetch_hf_daily(date_str: str) -> list[dict]:
    """Return the HF Daily Papers entries for *date_str* (``YYYY-MM-DD``).

    Best-effort: any network error or malformed JSON yields an empty list
    so a single bad day never aborts the pipeline.
    """
    try:
        response = requests.get(f"{HF_API}/daily_papers?date={date_str}", timeout=30)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        return []
def fetch_hf_trending(limit: int = 50) -> list[dict]:
    """Return up to *limit* currently-trending HF papers.

    Best-effort: any network error or malformed JSON yields an empty list.
    """
    try:
        response = requests.get(
            f"{HF_API}/daily_papers?sort=trending&limit={limit}", timeout=30
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        return []
| def arxiv_id_to_date(arxiv_id: str) -> datetime | None: | |
| """Extract approximate publication date from arXiv ID (YYMM.NNNNN).""" | |
| match = re.match(r"(\d{2})(\d{2})\.\d+", arxiv_id) | |
| if not match: | |
| return None | |
| year = 2000 + int(match.group(1)) | |
| month = int(match.group(2)) | |
| if not (1 <= month <= 12): | |
| return None | |
| return datetime(year, month, 1, tzinfo=timezone.utc) | |
def normalize_hf_paper(hf_entry: dict) -> dict | None:
    """Convert an HF daily_papers entry to our normalized format.

    Handles both wrapped (``{"paper": {...}}``) and flat entry shapes, and
    tolerates fields that are present but ``None`` in the API response.

    Returns:
        The normalized paper dict, or None if the paper is older than
        HF_MAX_AGE_DAYS (HF "trending" sometimes resurfaces old papers).
    """
    paper = hf_entry.get("paper", hf_entry)
    arxiv_id = paper.get("id", "")

    # Author entries may be plain strings or dicts; dicts carry the name
    # directly or nested under a "user" object that can itself be None.
    authors = []
    for a in paper.get("authors", []):
        if isinstance(a, dict):
            name = a.get("name") or (a.get("user") or {}).get("fullname", "")
            if name:
                authors.append(name)
        elif isinstance(a, str):
            authors.append(a)

    github_repo = hf_entry.get("githubRepo") or paper.get("githubRepo") or ""

    # Age gate: only enforceable when the ID encodes a parseable date.
    pub_date = arxiv_id_to_date(arxiv_id)
    if pub_date and (datetime.now(timezone.utc) - pub_date).days > HF_MAX_AGE_DAYS:
        return None

    # Coalesce before .replace(): "summary"/"title" may be present-but-None.
    abstract = paper.get("summary") or paper.get("abstract") or ""
    return {
        "arxiv_id": arxiv_id,
        "title": (paper.get("title") or "").replace("\n", " ").strip(),
        "authors": authors[:10],
        "abstract": abstract.replace("\n", " ").strip(),
        "published": paper.get("publishedAt", paper.get("published", "")),
        "categories": paper.get("categories", []),
        "pdf_url": f"https://arxiv.org/pdf/{arxiv_id}" if arxiv_id else "",
        "arxiv_url": f"https://arxiv.org/abs/{arxiv_id}" if arxiv_id else "",
        "comment": "",
        "source": "hf",
        # `paper` already resolves the wrapped/flat ambiguity; keep the
        # top-level fallback for entries that carry upvotes outside "paper".
        "hf_upvotes": paper.get("upvotes", hf_entry.get("upvotes", 0)),
        "github_repo": github_repo,
        "github_stars": None,
        "hf_models": [],
        "hf_datasets": [],
        "hf_spaces": [],
    }
# ---------------------------------------------------------------------------
# arXiv fetching
# ---------------------------------------------------------------------------
def fetch_arxiv_category(
    cat: str,
    start: datetime,
    end: datetime,
    max_results: int,
    filter_keywords: bool,
) -> list[dict]:
    """Fetch papers from a single arXiv category within [start, end].

    Results arrive newest-first, so iteration stops at the first paper
    published before *start*. When *filter_keywords* is set, the combined
    title+abstract must match INCLUDE_RE and must not match EXCLUDE_RE.
    """
    client = arxiv.Client(page_size=200, delay_seconds=3.0, num_retries=3)
    search = arxiv.Search(
        query=f"cat:{cat}",
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate,
        sort_order=arxiv.SortOrder.Descending,
    )
    collected: list[dict] = []
    for result in client.results(search):
        published = result.published.replace(tzinfo=timezone.utc)
        if published < start:
            break  # descending order: everything after this is older still
        if published > end:
            continue
        if filter_keywords:
            haystack = f"{result.title} {result.summary}"
            if not INCLUDE_RE.search(haystack) or EXCLUDE_RE.search(haystack):
                continue
        collected.append(_arxiv_result_to_dict(result))
    return collected
def _arxiv_result_to_dict(result: arxiv.Result) -> dict:
    """Convert an arxiv.Result to our normalized format."""
    versioned_id = result.entry_id.split("/abs/")[-1]
    arxiv_id = re.sub(r"v\d+$", "", versioned_id)  # drop trailing "vN"

    # Best-effort repo link: first GitHub URL mentioned in abstract/comment.
    candidates = GITHUB_URL_RE.findall(f"{result.summary} {result.comment or ''}")
    github_repo = candidates[0].rstrip(".") if candidates else ""

    return {
        "arxiv_id": arxiv_id,
        "title": result.title.replace("\n", " ").strip(),
        "authors": [author.name for author in result.authors[:10]],
        "abstract": result.summary.replace("\n", " ").strip(),
        "published": result.published.isoformat(),
        "categories": list(result.categories),
        "pdf_url": result.pdf_url,
        "arxiv_url": result.entry_id,
        "comment": (result.comment or "").replace("\n", " ").strip(),
        "source": "arxiv",
        "hf_upvotes": 0,
        "github_repo": github_repo,
        "github_stars": None,
        "hf_models": [],
        "hf_datasets": [],
        "hf_spaces": [],
    }
# ---------------------------------------------------------------------------
# Enrichment
# ---------------------------------------------------------------------------
def enrich_paper(paper: dict) -> dict:
    """Attach linked HF models/datasets/spaces to *paper* in place.

    Queries the HF Hub by version-stripped arXiv ID. Lookups are
    best-effort: failures leave the existing lists untouched. Returns the
    same dict for convenient reassignment by callers.
    """
    arxiv_id = paper["arxiv_id"]
    if not arxiv_id:
        return paper
    base_id = re.sub(r"v\d+$", "", arxiv_id)

    lookups = (
        ("models", "hf_models", 5),
        ("datasets", "hf_datasets", 3),
        ("spaces", "hf_spaces", 3),
    )
    for resource, key, limit in lookups:
        endpoint = f"{HF_API}/{resource}?filter=arxiv:{base_id}&limit={limit}&sort=likes"
        try:
            resp = requests.get(endpoint, timeout=15)
            if resp.ok:
                paper[key] = [
                    {"id": it.get("id", it.get("_id", "")), "likes": it.get("likes", 0)}
                    for it in resp.json()
                ]
        except (requests.RequestException, ValueError):
            pass  # best-effort enrichment; move on to the next resource
        time.sleep(0.2)  # gentle rate limiting between Hub calls
    return paper
# ---------------------------------------------------------------------------
# Merge
# ---------------------------------------------------------------------------
def merge_papers(hf_papers: list[dict], arxiv_papers: list[dict]) -> list[dict]:
    """Deduplicate by version-stripped arXiv ID, merging duplicate entries.

    arXiv entries are inserted first; an HF entry with the same ID upgrades
    the record to ``source="both"`` and fills in upvotes, the GitHub repo,
    and categories where the arXiv record lacks them.
    """
    def canonical(pid: str) -> str:
        # arXiv IDs may carry a version suffix (e.g. "v2"); drop it.
        return re.sub(r"v\d+$", "", pid)

    merged: dict[str, dict] = {}
    for entry in arxiv_papers:
        key = canonical(entry["arxiv_id"])
        if key:
            merged[key] = entry

    for entry in hf_papers:
        key = canonical(entry["arxiv_id"])
        if not key:
            continue
        found = merged.get(key)
        if found is None:
            merged[key] = entry
            continue
        found["source"] = "both"
        found["hf_upvotes"] = max(found.get("hf_upvotes", 0), entry.get("hf_upvotes", 0))
        if entry.get("github_repo") and not found.get("github_repo"):
            found["github_repo"] = entry["github_repo"]
        if not found.get("categories") and entry.get("categories"):
            found["categories"] = entry["categories"]

    return list(merged.values())
# ---------------------------------------------------------------------------
# Pipeline entry point
# ---------------------------------------------------------------------------
def run_aiml_pipeline(
    start: datetime | None = None,
    end: datetime | None = None,
    max_papers: int = 300,
    skip_enrich: bool = False,
) -> int:
    """Run the full AI/ML pipeline and return the run ID.

    Args:
        start: Window start; defaults to 7 days before *end*.
        end: Window end; defaults to now (UTC).
        max_papers: Per-category cap passed to the arXiv fetcher.
        skip_enrich: Skip the per-paper HF ecosystem lookups (much faster).

    Raises:
        Exception: re-raised after the run is marked as failed in the DB.
    """
    if end is None:
        end = datetime.now(timezone.utc)
    if start is None:
        start = end - timedelta(days=7)
    # Ensure timezone-aware; a naive end is pushed to end-of-day so the
    # whole final calendar day is covered by the window.
    if start.tzinfo is None:
        start = start.replace(tzinfo=timezone.utc)
    if end.tzinfo is None:
        end = end.replace(tzinfo=timezone.utc, hour=23, minute=59, second=59)

    run_id = create_run("aiml", start.date().isoformat(), end.date().isoformat())
    log.info("Run %d: %s to %s", run_id, start.date(), end.date())
    try:
        # Step 1: HF daily papers (one request per day in the window) + trending.
        log.info("Fetching HuggingFace Daily Papers ...")
        hf_papers_raw = []
        current = start
        while current <= end:
            hf_papers_raw.extend(fetch_hf_daily(current.strftime("%Y-%m-%d")))
            current += timedelta(days=1)
        hf_papers_raw.extend(fetch_hf_trending(limit=50))
        # normalize_hf_paper returns None for stale papers; drop those.
        hf_papers = [p for p in (normalize_hf_paper(e) for e in hf_papers_raw) if p is not None]
        log.info("HF papers: %d", len(hf_papers))

        # Step 2: arXiv — large categories are keyword-filtered to keep
        # volume down; small categories are taken wholesale.
        log.info("Fetching arXiv papers ...")
        arxiv_papers = []
        for cat in ARXIV_LARGE_CATS:
            papers = fetch_arxiv_category(cat, start, end, max_papers, filter_keywords=True)
            arxiv_papers.extend(papers)
            log.info(" %s: %d papers (keyword-filtered)", cat, len(papers))
        for cat in ARXIV_SMALL_CATS:
            papers = fetch_arxiv_category(cat, start, end, max_papers, filter_keywords=False)
            arxiv_papers.extend(papers)
            log.info(" %s: %d papers", cat, len(papers))

        # Step 3: Deduplicate the two sources by arXiv ID.
        all_papers = merge_papers(hf_papers, arxiv_papers)
        log.info("Merged: %d unique papers", len(all_papers))

        # Step 4: Attach linked HF models/datasets/spaces (slow: one API
        # round-trip per resource type per paper).
        if not skip_enrich:
            log.info("Enriching with HF ecosystem links ...")
            for i, paper in enumerate(all_papers):
                all_papers[i] = enrich_paper(paper)
                if (i + 1) % 25 == 0:
                    log.info(" Enriched %d/%d ...", i + 1, len(all_papers))
            log.info("Enrichment complete")

        # Step 5: Persist and close out the run.
        insert_papers(all_papers, run_id, "aiml")
        finish_run(run_id, len(all_papers))
        log.info("Done — %d papers inserted", len(all_papers))
        return run_id
    except Exception:
        # Mark the run failed before propagating so the DB reflects reality.
        # (Fix: the original bound the exception to an unused name `e`.)
        finish_run(run_id, 0, status="failed")
        log.exception("Pipeline failed")
        raise