| | """ |
| | FastAPI + HTMX app for browsing arxiv papers with new ML datasets. |
| | Downloads Lance dataset from HuggingFace Hub and loads locally. |
| | """ |
| |
|
import html
import math
import re
from datetime import date, timedelta
from functools import lru_cache
from typing import Optional
from urllib.parse import urlencode
| |
|
| | import lance |
| | import polars as pl |
| | from cachetools import TTLCache |
| | from dotenv import load_dotenv |
| | from fastapi import FastAPI, Query, Request |
| | from fastapi.responses import HTMLResponse, RedirectResponse |
| | from fastapi.staticfiles import StaticFiles |
| | from fastapi.templating import Jinja2Templates |
| | from huggingface_hub import snapshot_download |
| | from markupsafe import Markup |
| |
|
| | |
| | load_dotenv() |
| |
|
| | app = FastAPI(title="ArXiv New ML Datasets") |
| | app.mount("/static", StaticFiles(directory="static"), name="static") |
| | templates = Jinja2Templates(directory="templates") |
| |
|
| |
|
| | def highlight_search(text: str, search: str) -> Markup: |
| | """Highlight search terms in text with yellow background.""" |
| | if not search or not text: |
| | return Markup(text) if text else Markup("") |
| |
|
| | |
| | import html |
| | text = html.escape(str(text)) |
| |
|
| | |
| | pattern = re.compile(re.escape(search), re.IGNORECASE) |
| | highlighted = pattern.sub( |
| | lambda m: f'<mark class="bg-yellow-200 px-0.5 rounded">{m.group()}</mark>', |
| | text |
| | ) |
| | return Markup(highlighted) |
| |
|
| |
|
| | |
| | templates.env.filters["highlight"] = highlight_search |
| |
|
| |
|
def confidence_fmt(score):
    """Render *score* (a 0..1 float) as a percentage string with one decimal.

    Truncates instead of rounding so borderline values such as 0.9995 show
    as "99.9" rather than rounding up to a misleading "100.0".
    """
    tenths_of_percent = math.floor(score * 1000)
    return f"{tenths_of_percent / 10:.1f}"
| |
|
| |
|
# Expose confidence_fmt to templates as the `confidence` Jinja filter.
templates.env.filters["confidence"] = confidence_fmt
| |
|
| | |
# Hugging Face Hub dataset repo holding the Lance-format arxiv papers.
DATASET_REPO = "librarian-bots/arxiv-cs-papers-lance"

# Polars DataFrame cache; the single entry expires after 6 hours so fresh
# data can be picked up without a restart (see get_dataframe()).
_dataset_cache: TTLCache = TTLCache(maxsize=1, ttl=60 * 60 * 6)

# Lance dataset handle, cached for the process lifetime (no expiry).
_lance_cache: dict = {}

# Sentence-transformers model, loaded lazily and kept for the process lifetime.
_model_cache: dict = {}
| |
|
| |
|
def get_lance_dataset():
    """Return the Lance dataset handle, downloading from HF Hub on first use.

    The handle is cached in _lance_cache for the life of the process.
    """
    cached = _lance_cache.get("ds")
    if cached is not None:
        return cached

    import os

    cache_base = os.environ.get("HF_HOME", "/tmp/hf_cache")
    local_dir = f"{cache_base}/arxiv-lance"
    print(f"Downloading dataset from {DATASET_REPO} to {local_dir}...")
    snapshot_download(
        DATASET_REPO,
        repo_type="dataset",
        local_dir=local_dir,
    )
    lance_path = f"{local_dir}/data/train.lance"
    print(f"Loading Lance dataset from {lance_path}")
    handle = lance.dataset(lance_path)
    _lance_cache["ds"] = handle
    return handle
| |
|
| |
|
def get_embedding_model():
    """Return the embedding model, lazily loading it on first call.

    Import and model load both happen inside the function so that app
    startup without semantic search never pays the cost.
    """
    model = _model_cache.get("model")
    if model is None:
        from sentence_transformers import SentenceTransformer

        print("Loading embedding model...")
        model = SentenceTransformer("BAAI/bge-base-en-v1.5")
        _model_cache["model"] = model
        print("Embedding model loaded!")
    return model
| |
|
| |
|
def get_dataframe() -> pl.DataFrame:
    """Load the Lance dataset into a Polars DataFrame (TTL-cached, 6h)."""
    try:
        return _dataset_cache["df"]
    except KeyError:
        pass  # expired or first call — rebuild below

    wanted_columns = [
        "id", "title", "abstract", "categories", "update_date",
        "authors", "is_new_dataset", "confidence_score",
    ]
    arrow_table = get_lance_dataset().to_table(columns=wanted_columns)
    frame = pl.from_arrow(arrow_table)
    _dataset_cache["df"] = frame
    print(f"Loaded {len(frame):,} papers")
    return frame
| |
|
| |
|
@lru_cache(maxsize=1)
def get_categories() -> list[str]:
    """Return the ML-relevant primary categories actually present in the data.

    Order follows the hand-picked ml_cats list, not the data.
    """
    df = get_dataframe()

    # Primary category = first token of the space-separated `categories` field.
    primary_cats = (
        df.select(pl.col("categories").str.split(" ").list.first().alias("cat"))
        .unique()
        .sort("cat")
        .to_series()
        .to_list()
    )
    present = set(primary_cats)

    ml_cats = ["cs.AI", "cs.CL", "cs.CV", "cs.LG", "cs.NE", "cs.IR", "cs.RO", "stat.ML"]
    return [cat for cat in ml_cats if cat in present]
| |
|
| |
|
@lru_cache(maxsize=1)
def get_confidence_options() -> list[dict]:
    """Compute confidence filter options from actual data distribution.

    Uses percentiles so the UI adapts to any model's score range. Each
    option is {"value": threshold-as-string, "label": ..., "count": ...}.
    """
    df = get_dataframe()
    scores = df.filter(pl.col("is_new_dataset"))["confidence_score"]

    options = [{"value": "0.5", "label": "All new datasets", "count": len(scores)}]

    # "Top X%" = papers at or above the (100-X)th percentile of scores.
    percentile_specs = (("Top 75%", 0.25), ("Top 50%", 0.50), ("Top 25%", 0.75))
    for label, quantile in percentile_specs:
        cutoff = float(scores.quantile(quantile))
        n_at_or_above = scores.filter(scores >= cutoff).len()
        options.append(
            {"value": f"{cutoff:.2f}", "label": label, "count": int(n_at_or_above)}
        )

    options.append({"value": "0", "label": "All papers", "count": len(df)})
    return options
| |
|
| |
|
@lru_cache(maxsize=1)
def get_histogram_data() -> dict:
    """Get confidence distribution data for histogram display.

    Dynamically determines the range from actual data distribution.
    Returns dict with bins and metadata. The 50% line marks the prediction boundary.
    """
    df = get_dataframe()
    all_papers = df.select("confidence_score", "is_new_dataset")

    actual_min = float(all_papers["confidence_score"].min())
    actual_max = float(all_papers["confidence_score"].max())

    # Snap the display range outward onto a 5% grid around the observed scores.
    min_pct = max(0, int(actual_min * 20) / 20)
    max_pct = min(1, (int(actual_max * 20) + 1) / 20)

    # Guarantee a minimum visible span of 25 percentage points.
    if max_pct - min_pct < 0.25:
        center = (min_pct + max_pct) / 2
        min_pct = max(0, center - 0.125)
        max_pct = min(1, center + 0.125)

    num_bins = 25
    bin_width = (max_pct - min_pct) / num_bins

    bins = []
    for i in range(num_bins):
        bin_start = min_pct + i * bin_width
        bin_end = min_pct + (i + 1) * bin_width

        # Half-open [start, end) bins, except the LAST bin which is closed
        # [start, end]. Bug fix: with `< bin_end` everywhere, a paper whose
        # score equals max_pct (e.g. exactly 1.0 after the min(1, ...) clamp)
        # fell outside every bin, skewing counts and the papers_above totals.
        in_bin = pl.col("confidence_score") >= bin_start
        if i == num_bins - 1:
            in_bin = in_bin & (pl.col("confidence_score") <= bin_end)
        else:
            in_bin = in_bin & (pl.col("confidence_score") < bin_end)

        # Filter once per bin; the original scanned the frame twice per bin.
        bin_papers = all_papers.filter(in_bin)
        count = bin_papers.height
        new_dataset_count = bin_papers.filter(pl.col("is_new_dataset")).height

        bins.append({
            "bin_start": round(bin_start, 3),
            "bin_end": round(bin_end, 3),
            "bin_pct": int(bin_start * 100),
            "count": count,
            "new_dataset_count": new_dataset_count,
        })

    # Normalize bar heights against the tallest bin (0-100 integer percent).
    max_count = max(b["count"] for b in bins) if bins else 1
    for b in bins:
        b["height_pct"] = int((b["count"] / max_count) * 100) if max_count > 0 else 0
        b["new_height_pct"] = (
            int((b["new_dataset_count"] / max_count) * 100) if max_count > 0 else 0
        )

    # Running total: number of papers in this bin or any bin to its right.
    total_so_far = all_papers.height
    for b in bins:
        b["papers_above"] = total_so_far
        total_so_far -= b["count"]

    return {
        "bins": bins,
        "min_pct": round(min_pct, 2),
        "max_pct": round(max_pct, 2),
        "total_papers": all_papers.height,
        "new_dataset_count": all_papers.filter(pl.col("is_new_dataset")).height,
    }
| |
|
| |
|
def parse_since(since: str) -> Optional[date]:
    """Parse 'since' parameter to a date. Returns None for 'all time'.

    Recognized windows: "1m" (30 days), "6m" (180 days), "1y" (365 days).
    Anything else — including empty/None — means no date filter.
    """
    windows = {"1m": 30, "6m": 180, "1y": 365}
    if not since or since not in windows:
        return None
    return date.today() - timedelta(days=windows[since])
| |
|
| |
|
def filter_papers(
    df: pl.DataFrame,
    category: Optional[str] = None,
    search: Optional[str] = None,
    min_confidence: float = 0.5,
    since: Optional[str] = None,
) -> pl.DataFrame:
    """Apply filters to the papers dataframe.

    The confidence threshold controls which papers are shown:
    - Papers with is_new_dataset=True have confidence >= 0.5
    - Setting threshold to 0 shows all papers
    - Setting threshold >= 0.5 effectively shows only new_dataset papers
    """
    if min_confidence >= 0.5:
        df = df.filter(
            pl.col("is_new_dataset") & (pl.col("confidence_score") >= min_confidence)
        )
    elif min_confidence > 0:
        df = df.filter(pl.col("confidence_score") >= min_confidence)

    if category:
        # literal=True: polars str.contains treats the pattern as a regex by
        # default, so the "." in e.g. "cs.AI" would match any character.
        df = df.filter(pl.col("categories").str.contains(category, literal=True))

    if search:
        # literal=True: user input is plain text, not a regex. Without it, a
        # query like "c++" raises a regex compile error and 500s the request.
        needle = search.lower()
        df = df.filter(
            pl.col("title").str.to_lowercase().str.contains(needle, literal=True)
            | pl.col("abstract").str.to_lowercase().str.contains(needle, literal=True)
        )

    # Date window filter ("1m"/"6m"/"1y"); None means all time.
    min_date = parse_since(since)
    if min_date:
        df = df.filter(pl.col("update_date") >= min_date)

    return df
| |
|
| |
|
def paginate_papers(
    df: pl.DataFrame,
    page: int = 1,
    per_page: int = 20,
    sort: str = "date",
) -> tuple[pl.DataFrame, bool]:
    """Sort and paginate papers, return (page_df, has_more).

    Sort options:
    - "date": update_date desc, ties broken by confidence_score desc
    - anything else (e.g. "relevance"): keep incoming row order, used for
      semantic-search similarity ordering
    """
    ordered = (
        df.sort(["update_date", "confidence_score"], descending=[True, True])
        if sort == "date"
        else df
    )

    offset = (page - 1) * per_page
    # Grab one extra row purely to detect whether a next page exists.
    window = ordered.slice(offset, per_page + 1)
    has_more = len(window) > per_page

    return window.head(per_page), has_more
| |
|
| |
|
def semantic_search(
    query: str,
    k: int = 100,
    category: Optional[str] = None,
    min_confidence: float = 0.5,
    since: Optional[str] = None,
) -> pl.DataFrame:
    """Search using vector similarity via Lance nearest neighbor.

    Returns DataFrame with similarity_score column (0-1, higher is more similar).

    Filters are passed to Lance as a SQL-style predicate string rather than
    applied in polars, so they constrain the ANN result set directly.
    """
    model = get_embedding_model()
    query_embedding = model.encode(query).tolist()

    ds = get_lance_dataset()

    # Build the SQL-style filter string. Thresholds >= 0.5 imply "new dataset"
    # papers only, mirroring the logic in filter_papers().
    filters = []
    if min_confidence >= 0.5:
        filters.append("is_new_dataset = true")
        filters.append(f"confidence_score >= {min_confidence}")
    elif min_confidence > 0:
        filters.append(f"confidence_score >= {min_confidence}")
    if category:
        # Escape single quotes for the SQL string literal (category values are
        # from a fixed UI list, so this is defense in depth, not user input).
        safe_category = category.replace("'", "''")
        filters.append(f"categories LIKE '%{safe_category}%'")
    # Optional date window ("1m"/"6m"/"1y"); None means all time.
    min_date = parse_since(since)
    if min_date:
        filters.append(f"update_date >= TIMESTAMP '{min_date.isoformat()} 00:00:00'")
    filter_str = " AND ".join(filters) if filters else None

    # k-nearest-neighbor scan on the "embedding" column; Lance adds the
    # `_distance` column to the result automatically.
    results = ds.scanner(
        nearest={"column": "embedding", "q": query_embedding, "k": k},
        filter=filter_str,
        columns=["id", "title", "abstract", "categories", "update_date",
                 "authors", "confidence_score", "_distance"]
    ).to_table()

    df = pl.from_arrow(results)

    # Map distance to a 0-1 similarity score.
    # NOTE(review): 1 - d/2 assumes a cosine-style distance in [0, 2];
    # confirm the index's distance metric matches this mapping.
    df = df.with_columns(
        (1 - pl.col("_distance") / 2).clip(0, 1).alias("similarity_score")
    ).drop("_distance")

    return df
| |
|
| |
|
@app.get("/", response_class=HTMLResponse)
async def home(
    request: Request,
    search: Optional[str] = Query(None),
    search_type: str = Query("keyword"),
    category: Optional[str] = Query(None),
    min_confidence: str = Query("0.5"),
    since: Optional[str] = Query(None),
    sort: str = Query("date"),
):
    """Render the home page with optional initial filter state from URL."""
    df = get_dataframe()

    context = {
        "request": request,
        "categories": get_categories(),
        "total_papers": len(df),
        "new_dataset_count": df.filter(pl.col("is_new_dataset")).height,
        "histogram_data": get_histogram_data(),
        "confidence_options": get_confidence_options(),
        # Echo the current filter state so the form renders pre-filled.
        "search": search or "",
        "search_type": search_type,
        "category": category or "",
        "min_confidence": min_confidence,
        "since": since or "",
        "sort": sort,
    }
    return templates.TemplateResponse("index.html", context)
| |
|
| |
|
@app.get("/papers", response_class=HTMLResponse)
async def get_papers(
    request: Request,
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1, le=100),
    category: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    min_confidence: float = Query(0.5, ge=0, le=1),
    search_type: str = Query("keyword"),
    sort: str = Query("date"),
    since: Optional[str] = Query(None),
):
    """Get paginated and filtered papers (returns HTML partial for HTMX).

    If accessed directly (not via HTMX), redirects to home page with same params.
    """
    # Non-HTMX requests (bookmarked/shared URLs) get the full page instead of
    # a bare partial; the HX-Request header is set by HTMX on its fetches.
    if "HX-Request" not in request.headers:
        query_string = str(request.url.query)
        redirect_url = f"/?{query_string}" if query_string else "/"
        return RedirectResponse(url=redirect_url, status_code=302)

    if search and search_type == "semantic":
        # Vector-search path: over-fetch 5 pages worth of neighbors so a few
        # "next page" clicks still have rows.
        # NOTE(review): pages beyond k / per_page will come back empty.
        filtered_df = semantic_search(
            query=search,
            k=per_page * 5,
            category=category,
            min_confidence=min_confidence,
            since=since,
        )
        # Any non-"date" sort keeps Lance's nearest-neighbor ordering.
        effective_sort = sort if sort == "date" else "relevance"
        page_df, has_more = paginate_papers(
            filtered_df, page=page, per_page=per_page, sort=effective_sort
        )
    else:
        # Keyword path: filter the TTL-cached DataFrame in-process.
        df = get_dataframe()
        filtered_df = filter_papers(
            df,
            category=category,
            search=search,
            min_confidence=min_confidence,
            since=since,
        )
        page_df, has_more = paginate_papers(
            filtered_df, page=page, per_page=per_page, sort="date"
        )

    papers = page_df.to_dicts()

    # Build the canonical URL for this filter state; only non-default params
    # are included so the address bar stays clean.
    params = {}
    if search:
        params["search"] = search
    if search_type != "keyword":
        params["search_type"] = search_type
    if category:
        params["category"] = category
    if min_confidence != 0.5:
        params["min_confidence"] = min_confidence
    if since:
        params["since"] = since
    if sort != "date":
        params["sort"] = sort
    push_url = "/?" + urlencode(params) if params else "/"

    response = templates.TemplateResponse(
        "partials/paper_list.html",
        {
            "request": request,
            "papers": papers,
            "page": page,
            "has_more": has_more,
            "category": category or "",
            "search": search or "",
            "min_confidence": min_confidence,
            "search_type": search_type,
            "sort": sort,
            "since": since or "",
            "total_filtered": len(filtered_df),
        },
    )
    # HX-Push-Url tells HTMX to update the browser URL to match the filters.
    response.headers["HX-Push-Url"] = push_url
    return response
| |
|
| |
|
@app.get("/api/stats")
async def get_stats():
    """Get dataset statistics as JSON."""
    df = get_dataframe()
    flagged = df.filter(pl.col("is_new_dataset"))
    dates = df["update_date"]

    return {
        "total_papers": len(df),
        "new_dataset_count": len(flagged),
        "avg_confidence": float(df["confidence_score"].mean()),
        "date_range": {
            "min": str(dates.min()),
            "max": str(dates.max()),
        },
    }
| |
|
| |
|
| | |
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers — consider migrating when next touching app construction.
@app.on_event("startup")
async def startup_event():
    """Preload dataset and embedding model on startup.

    Warms both caches so the first user request doesn't pay the dataset
    download or model-load cost.
    """
    print("Preloading dataset...")
    get_dataframe()
    print("Dataset loaded!")
    print("Preloading embedding model...")
    get_embedding_model()
    print("Embedding model loaded!")
| |
|
| |
|
if __name__ == "__main__":
    import uvicorn

    # Bind on all interfaces; port 7860 presumably matches the hosting
    # platform's expected port (it is the Hugging Face Spaces default).
    uvicorn.run(app, host="0.0.0.0", port=7860)
| |
|