# Researcher/src/pipelines/aiml.py
# Initial commit: Research Intelligence System (a0f27fa, amarck)
"""AI/ML paper pipeline.
Fetches papers from HuggingFace Daily Papers + arXiv, enriches with
HF ecosystem metadata, and writes to the database.
"""
import logging
import re
import time
from datetime import datetime, timedelta, timezone
import arxiv
import requests
from src.config import (
ARXIV_LARGE_CATS,
ARXIV_SMALL_CATS,
EXCLUDE_RE,
GITHUB_URL_RE,
HF_API,
HF_MAX_AGE_DAYS,
INCLUDE_RE,
MAX_ABSTRACT_CHARS_AIML,
)
from src.db import create_run, finish_run, insert_papers
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# HuggingFace API
# ---------------------------------------------------------------------------
def fetch_hf_daily(date_str: str) -> list[dict]:
    """Fetch the HuggingFace Daily Papers listing for a single date.

    Args:
        date_str: Value sent verbatim as the ``date`` query parameter
            (the pipeline passes ``YYYY-MM-DD`` strings).

    Returns:
        The decoded JSON list of entries. An empty list is returned when the
        request fails, the body is not valid JSON, or the payload is not a
        list, so callers can always ``extend`` with the result.
    """
    url = f"{HF_API}/daily_papers?date={date_str}"
    try:
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        payload = resp.json()
    except (requests.RequestException, ValueError) as exc:
        # Best-effort source: one missing day must not abort the run, but the
        # failure should still be visible in the logs.
        log.warning("HF daily papers fetch failed for %s: %s", date_str, exc)
        return []
    if not isinstance(payload, list):
        # Error responses can be JSON objects; extending a list with a dict
        # would inject its keys (strings) into the paper list.
        log.warning(
            "HF daily papers for %s returned %s instead of a list",
            date_str,
            type(payload).__name__,
        )
        return []
    return payload
def fetch_hf_trending(limit: int = 50) -> list[dict]:
    """Fetch the currently trending HuggingFace papers.

    Args:
        limit: Value sent as the ``limit`` query parameter.

    Returns:
        The decoded JSON list of entries. An empty list is returned when the
        request fails, the body is not valid JSON, or the payload is not a
        list, so callers can always ``extend`` with the result.
    """
    url = f"{HF_API}/daily_papers?sort=trending&limit={limit}"
    try:
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        payload = resp.json()
    except (requests.RequestException, ValueError) as exc:
        # Trending is a supplementary source; log the failure and carry on.
        log.warning("HF trending papers fetch failed: %s", exc)
        return []
    if not isinstance(payload, list):
        # Guard against JSON error objects being treated as a paper list.
        log.warning(
            "HF trending papers returned %s instead of a list",
            type(payload).__name__,
        )
        return []
    return payload
def arxiv_id_to_date(arxiv_id: str) -> datetime | None:
"""Extract approximate publication date from arXiv ID (YYMM.NNNNN)."""
match = re.match(r"(\d{2})(\d{2})\.\d+", arxiv_id)
if not match:
return None
year = 2000 + int(match.group(1))
month = int(match.group(2))
if not (1 <= month <= 12):
return None
return datetime(year, month, 1, tzinfo=timezone.utc)
def normalize_hf_paper(hf_entry: dict) -> dict | None:
    """Convert an HF daily_papers entry to our normalized paper dict.

    The paper fields are read from ``hf_entry["paper"]`` when that key holds
    a value, otherwise from ``hf_entry`` itself. Fields that are missing or
    JSON ``null`` are treated as empty rather than raising.

    Args:
        hf_entry: One element of the list returned by the HF daily_papers
            endpoint.

    Returns:
        The normalized paper dict with ``source`` set to ``"hf"``, or ``None``
        when the date derived from the arXiv ID is more than
        ``HF_MAX_AGE_DAYS`` days old. Papers whose ID yields no date are kept.
    """
    paper = hf_entry.get("paper") or hf_entry
    arxiv_id = paper.get("id") or ""

    # Age filter first: no point normalizing a paper that will be dropped.
    pub_date = arxiv_id_to_date(arxiv_id)
    if pub_date and (datetime.now(timezone.utc) - pub_date).days > HF_MAX_AGE_DAYS:
        return None

    authors = []
    for author in paper.get("authors") or []:
        if isinstance(author, dict):
            user = author.get("user")
            # Fall back to the linked user's full name when "name" is absent,
            # null or empty; tolerate "user" being null.
            fullname = user.get("fullname") if isinstance(user, dict) else ""
            name = author.get("name") or fullname or ""
            if name:
                authors.append(name)
        elif isinstance(author, str):
            authors.append(author)

    # Paper-level upvotes take precedence over entry-level ones; a null count
    # becomes 0 so numeric comparisons downstream do not fail.
    upvotes = paper.get("upvotes")
    if upvotes is None:
        upvotes = hf_entry.get("upvotes")
    if upvotes is None:
        upvotes = 0

    title = paper.get("title") or ""
    abstract = paper.get("summary") or paper.get("abstract") or ""

    return {
        "arxiv_id": arxiv_id,
        "title": title.replace("\n", " ").strip(),
        "authors": authors[:10],
        "abstract": abstract.replace("\n", " ").strip(),
        "published": paper.get("publishedAt") or paper.get("published") or "",
        "categories": paper.get("categories") or [],
        "pdf_url": f"https://arxiv.org/pdf/{arxiv_id}" if arxiv_id else "",
        "arxiv_url": f"https://arxiv.org/abs/{arxiv_id}" if arxiv_id else "",
        "comment": "",
        "source": "hf",
        "hf_upvotes": upvotes,
        "github_repo": hf_entry.get("githubRepo") or paper.get("githubRepo") or "",
        "github_stars": None,
        "hf_models": [],
        "hf_datasets": [],
        "hf_spaces": [],
    }
# ---------------------------------------------------------------------------
# arXiv fetching
# ---------------------------------------------------------------------------
def fetch_arxiv_category(
    cat: str,
    start: datetime,
    end: datetime,
    max_results: int,
    filter_keywords: bool,
) -> list[dict]:
    """Collect papers from one arXiv category published within [start, end].

    Results are requested sorted by submission date, descending, and capped at
    ``max_results`` before any filtering. When ``filter_keywords`` is true, a
    paper is kept only if its title+summary matches ``INCLUDE_RE`` and does not
    match ``EXCLUDE_RE``.

    Returns:
        Normalized paper dicts, in the order the API yielded them.
    """
    search = arxiv.Search(
        query=f"cat:{cat}",
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate,
        sort_order=arxiv.SortOrder.Descending,
    )
    api = arxiv.Client(page_size=200, delay_seconds=3.0, num_retries=3)

    collected = []
    for entry in api.results(search):
        published_at = entry.published.replace(tzinfo=timezone.utc)
        if published_at < start:
            # Descending order was requested, so later results are treated as
            # older still and iteration stops here.
            break
        if published_at > end:
            continue
        if filter_keywords:
            haystack = f"{entry.title} {entry.summary}"
            if not INCLUDE_RE.search(haystack) or EXCLUDE_RE.search(haystack):
                continue
        collected.append(_arxiv_result_to_dict(entry))
    return collected
def _arxiv_result_to_dict(result: arxiv.Result) -> dict:
    """Build the normalized paper dict for a single ``arxiv.Result``.

    The stored ``arxiv_id`` is the part of ``entry_id`` after ``/abs/`` with
    any trailing ``vN`` version suffix removed. The first ``GITHUB_URL_RE``
    match in the summary or comment (minus trailing dots) becomes
    ``github_repo``.
    """

    def one_line(text: str) -> str:
        # Collapse hard line breaks so the field is stored on a single line.
        return text.replace("\n", " ").strip()

    comment = result.comment or ""
    versioned_id = result.entry_id.split("/abs/")[-1]

    repo_matches = GITHUB_URL_RE.findall(f"{result.summary} {comment}")
    if repo_matches:
        first_repo = repo_matches[0].rstrip(".")
    else:
        first_repo = ""

    normalized = {
        "arxiv_id": re.sub(r"v\d+$", "", versioned_id),
        "title": one_line(result.title),
        "authors": [author.name for author in result.authors[:10]],
        "abstract": one_line(result.summary),
        "published": result.published.isoformat(),
        "categories": list(result.categories),
        "pdf_url": result.pdf_url,
        "arxiv_url": result.entry_id,
        "comment": one_line(comment),
        "source": "arxiv",
        "hf_upvotes": 0,
        "github_repo": first_repo,
        "github_stars": None,
        "hf_models": [],
        "hf_datasets": [],
        "hf_spaces": [],
    }
    return normalized
# ---------------------------------------------------------------------------
# Enrichment
# ---------------------------------------------------------------------------
def enrich_paper(paper: dict) -> dict:
    """Attach linked HF models, datasets and spaces to *paper*, in place.

    For each resource type the HF API is queried with an ``arxiv:<id>`` filter
    (version suffix stripped), sorted by likes. On success the matching key
    (``hf_models``, ``hf_datasets``, ``hf_spaces``) is replaced with a list of
    ``{"id", "likes"}`` dicts. Enrichment is best-effort: on a request error,
    a non-OK status, invalid JSON or an unexpected payload shape the existing
    value is left untouched.

    Args:
        paper: Normalized paper dict; must contain ``"arxiv_id"``.

    Returns:
        The same ``paper`` object that was passed in.
    """
    arxiv_id = paper["arxiv_id"]
    if not arxiv_id:
        return paper

    base_id = re.sub(r"v\d+$", "", arxiv_id)
    for resource, key, limit in (
        ("models", "hf_models", 5),
        ("datasets", "hf_datasets", 3),
        ("spaces", "hf_spaces", 3),
    ):
        url = f"{HF_API}/{resource}?filter=arxiv:{base_id}&limit={limit}&sort=likes"
        items = None
        try:
            resp = requests.get(url, timeout=15)
            if resp.ok:
                items = resp.json()
        except (requests.RequestException, ValueError) as exc:
            log.debug("HF %s lookup failed for %s: %s", resource, base_id, exc)

        # Only a list of objects is a usable answer; an error object or other
        # JSON value would otherwise raise on item.get().
        if isinstance(items, list):
            paper[key] = [
                {"id": item.get("id", item.get("_id", "")), "likes": item.get("likes", 0)}
                for item in items
                if isinstance(item, dict)
            ]

        # Small pause between requests to stay gentle on the API.
        time.sleep(0.2)
    return paper
# ---------------------------------------------------------------------------
# Merge
# ---------------------------------------------------------------------------
def merge_papers(hf_papers: list[dict], arxiv_papers: list[dict]) -> list[dict]:
    """Deduplicate papers by version-less arXiv ID, merging HF data into arXiv.

    arXiv entries are indexed first (a later duplicate replaces an earlier
    one). Each HF entry is then either added as new or folded into the entry
    already stored under its ID: the larger upvote count wins, and
    ``github_repo`` / ``categories`` are filled in only where the stored entry
    has none. ``source`` becomes ``"both"`` only when the stored entry came
    from ``arxiv_papers``; repeated HF entries for the same paper keep their
    original source. Entries with an empty or missing ID are dropped.

    The stored dicts are updated in place, so the returned list shares its
    dict objects with the input lists.

    Args:
        hf_papers: Normalized papers from HuggingFace.
        arxiv_papers: Normalized papers from arXiv.

    Returns:
        One dict per unique arXiv ID, arXiv-sourced entries first.
    """
    by_id: dict[str, dict] = {}
    for p in arxiv_papers:
        aid = re.sub(r"v\d+$", "", p.get("arxiv_id") or "")
        if aid:
            by_id[aid] = p

    # Snapshot of IDs that really came from arXiv, so duplicate HF entries
    # (several daily listings plus trending) are not mislabelled as "both".
    from_arxiv = set(by_id)

    for p in hf_papers:
        aid = re.sub(r"v\d+$", "", p.get("arxiv_id") or "")
        if not aid:
            continue
        existing = by_id.get(aid)
        if existing is None:
            by_id[aid] = p
            continue

        if aid in from_arxiv:
            existing["source"] = "both"
        # "or 0" guards against null upvote counts, which max() cannot order.
        existing["hf_upvotes"] = max(existing.get("hf_upvotes") or 0, p.get("hf_upvotes") or 0)
        if p.get("github_repo") and not existing.get("github_repo"):
            existing["github_repo"] = p["github_repo"]
        if not existing.get("categories") and p.get("categories"):
            existing["categories"] = p["categories"]
    return list(by_id.values())
# ---------------------------------------------------------------------------
# Pipeline entry point
# ---------------------------------------------------------------------------
def run_aiml_pipeline(
    start: datetime | None = None,
    end: datetime | None = None,
    max_papers: int = 300,
    skip_enrich: bool = False,
) -> int:
    """Run the full AI/ML pipeline and return the run ID.

    Steps: fetch HF daily papers for every calendar day from ``start`` to
    ``end`` (inclusive) plus HF trending, fetch each configured arXiv
    category, merge by arXiv ID, optionally enrich with HF ecosystem links,
    and insert the result into the database.

    Args:
        start: Beginning of the window. Defaults to seven days before ``end``.
            A naive value is treated as UTC.
        end: End of the window. Defaults to the current UTC time. A naive
            value is treated as UTC and moved to 23:59:59 of that day.
        max_papers: ``max_results`` passed to each arXiv category fetch.
        skip_enrich: When true, the HF enrichment step is skipped.

    Returns:
        The ID returned by ``create_run``.

    Raises:
        Exception: Any error raised by a pipeline step is re-raised after the
            run has been marked as failed.
    """
    if end is None:
        end = datetime.now(timezone.utc)
    if start is None:
        start = end - timedelta(days=7)
    # Ensure timezone-aware
    if start.tzinfo is None:
        start = start.replace(tzinfo=timezone.utc)
    if end.tzinfo is None:
        end = end.replace(tzinfo=timezone.utc, hour=23, minute=59, second=59)

    run_id = create_run("aiml", start.date().isoformat(), end.date().isoformat())
    log.info("Run %d: %s to %s", run_id, start.date(), end.date())
    try:
        # Step 1: Fetch HF papers
        log.info("Fetching HuggingFace Daily Papers ...")
        hf_papers_raw = []
        # Walk calendar dates rather than datetimes: comparing datetimes would
        # skip the final day whenever start's time-of-day is later than end's.
        day = start.date()
        last_day = end.date()
        while day <= last_day:
            hf_papers_raw.extend(fetch_hf_daily(day.strftime("%Y-%m-%d")))
            day += timedelta(days=1)
        hf_papers_raw.extend(fetch_hf_trending(limit=50))
        hf_papers = [p for p in (normalize_hf_paper(e) for e in hf_papers_raw) if p is not None]
        log.info("HF papers: %d", len(hf_papers))

        # Step 2: Fetch arXiv papers
        log.info("Fetching arXiv papers ...")
        arxiv_papers = []
        for cat in ARXIV_LARGE_CATS:
            papers = fetch_arxiv_category(cat, start, end, max_papers, filter_keywords=True)
            arxiv_papers.extend(papers)
            log.info(" %s: %d papers (keyword-filtered)", cat, len(papers))
        for cat in ARXIV_SMALL_CATS:
            papers = fetch_arxiv_category(cat, start, end, max_papers, filter_keywords=False)
            arxiv_papers.extend(papers)
            log.info(" %s: %d papers", cat, len(papers))

        # Step 3: Merge
        all_papers = merge_papers(hf_papers, arxiv_papers)
        log.info("Merged: %d unique papers", len(all_papers))

        # Step 4: Enrich
        if not skip_enrich:
            log.info("Enriching with HF ecosystem links ...")
            for i, paper in enumerate(all_papers):
                all_papers[i] = enrich_paper(paper)
                if (i + 1) % 25 == 0:
                    log.info(" Enriched %d/%d ...", i + 1, len(all_papers))
            log.info("Enrichment complete")

        # Step 5: Insert into DB
        insert_papers(all_papers, run_id, "aiml")
        finish_run(run_id, len(all_papers))
        log.info("Done — %d papers inserted", len(all_papers))
        return run_id
    except Exception:
        # Log first so the root cause is recorded even if marking the run as
        # failed raises in turn.
        log.exception("Pipeline failed")
        finish_run(run_id, 0, status="failed")
        raise