"""
FastAPI + HTMX app for browsing arxiv papers with new ML datasets.

Downloads Lance dataset from HuggingFace Hub and loads locally.
"""

import html
import math
import re
from datetime import date, timedelta
from functools import lru_cache
from typing import Optional
from urllib.parse import urlencode

import lance
import polars as pl
from cachetools import TTLCache
from dotenv import load_dotenv
from fastapi import FastAPI, Query, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import snapshot_download
from markupsafe import Markup

# Load .env file for local development (HF_TOKEN)
load_dotenv()

app = FastAPI(title="ArXiv New ML Datasets")
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")


def highlight_search(text: str, search: str) -> Markup:
    """Return *text* as HTML-safe markup with matches of *search* highlighted.

    Matching is a case-insensitive literal (non-regex) substring search. The
    text is HTML-escaped piece by piece around the matches, so matching runs
    against the raw text (a search for "<" or "amp" behaves as expected) and
    no part of the input is ever emitted unescaped. Each match is wrapped in a
    ``<mark>`` element, which browsers render with a yellow background by
    default.

    Args:
        text: Text to render; falsy values produce empty markup. Non-str
            values are converted with ``str()``.
        search: Term to highlight; falsy means "escape only, no highlight".

    Returns:
        A ``Markup`` instance that templates may insert without re-escaping.
    """
    if not text:
        return Markup("")
    text = str(text)
    if not search:
        # Markup() marks its argument as already safe, so raw data must be
        # escaped before it is wrapped or it bypasses template autoescaping.
        return Markup(html.escape(text))

    pattern = re.compile(re.escape(search), re.IGNORECASE)
    pieces = []
    last_end = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[last_end:match.start()]))
        pieces.append(f"<mark>{html.escape(match.group())}</mark>")
        last_end = match.end()
    pieces.append(html.escape(text[last_end:]))
    return Markup("".join(pieces))


# Register custom filters
templates.env.filters["highlight"] = highlight_search


def confidence_fmt(score):
    """Format a 0-1 confidence score as a percentage string with 1 decimal.

    The value is truncated (floored) rather than rounded so that e.g. 0.9995
    renders as "99.9" instead of being rounded up to "100.0".
    """
    pct = math.floor(score * 1000) / 10
    return f"{pct:.1f}"


templates.env.filters["confidence"] = confidence_fmt

# Dataset config
DATASET_REPO = "librarian-bots/arxiv-cs-papers-lance"

# Cache for dataset (reload every 6 hours)
_dataset_cache: TTLCache = TTLCache(maxsize=1, ttl=60 * 60 * 6)

# Cache for Lance dataset connection (for vector search)
_lance_cache: dict = {}

# Cache for embedding model (lazy loaded on first semantic search)
_model_cache: dict = {}
def get_lance_dataset():
    """Download the dataset from the HF Hub (first call only) and return a Lance handle.

    The handle is stored in ``_lance_cache`` and reused on every later call,
    so ``snapshot_download`` and ``lance.dataset`` run once per process.
    """
    if "ds" not in _lance_cache:
        import os

        # Use HF_HOME or /tmp for Spaces compatibility (./data not writable on Spaces)
        cache_base = os.environ.get("HF_HOME", "/tmp/hf_cache")
        local_dir = f"{cache_base}/arxiv-lance"
        print(f"Downloading dataset from {DATASET_REPO} to {local_dir}...")
        snapshot_download(
            DATASET_REPO,
            repo_type="dataset",
            local_dir=local_dir,
        )
        lance_path = f"{local_dir}/data/train.lance"
        print(f"Loading Lance dataset from {lance_path}")
        _lance_cache["ds"] = lance.dataset(lance_path)
    return _lance_cache["ds"]


def get_embedding_model():
    """Load the BGE sentence-embedding model (cached after the first call)."""
    if "model" not in _model_cache:
        from sentence_transformers import SentenceTransformer

        print("Loading embedding model...")
        _model_cache["model"] = SentenceTransformer("BAAI/bge-base-en-v1.5")
        print("Embedding model loaded!")
    return _model_cache["model"]


def get_dataframe() -> pl.DataFrame:
    """Return the papers table as a Polars DataFrame (TTL-cached).

    Only the columns needed for filtering and display are materialized; the
    embedding column is left out to save memory.
    """
    cache_key = "df"
    if cache_key in _dataset_cache:
        return _dataset_cache[cache_key]

    ds = get_lance_dataset()  # Downloads from HF Hub if not cached

    # Select columns needed for filtering/display (exclude embeddings for memory)
    columns = [
        "id", "title", "abstract", "categories", "update_date",
        "authors", "is_new_dataset", "confidence_score",
    ]
    arrow_table = ds.to_table(columns=columns)
    df = pl.from_arrow(arrow_table)

    _dataset_cache[cache_key] = df
    print(f"Loaded {len(df):,} papers")
    return df


@lru_cache(maxsize=1)
def get_categories() -> list[str]:
    """Return the allow-listed ML categories that occur as a primary category.

    The primary category is the first space-separated token of the
    ``categories`` column. The result keeps the allow-list's order.
    """
    df = get_dataframe()
    # Extract primary category (before first space or as-is)
    categories = (
        df.select(pl.col("categories").str.split(" ").list.first().alias("cat"))
        .unique()
        .sort("cat")
        .to_series()
        .to_list()
    )
    # Get common ML-related categories
    ml_cats = ["cs.AI", "cs.CL", "cs.CV", "cs.LG", "cs.NE", "cs.IR", "cs.RO", "stat.ML"]
    return [c for c in ml_cats if c in categories]


@lru_cache(maxsize=1)
def get_confidence_options() -> list[dict]:
    """Compute confidence filter options from the actual data distribution.

    Uses percentiles of the new-dataset papers' scores so the UI adapts to
    any model's score range. Each option has a ``value`` (string passed back
    as ``min_confidence``), a ``label`` and the number of papers it selects.
    """
    df = get_dataframe()
    scores = df.filter(pl.col("is_new_dataset"))["confidence_score"]

    options = [{"value": "0.5", "label": "All new datasets", "count": len(scores)}]
    # quantile() of an empty Series is None; skip percentile options then.
    if len(scores) > 0:
        for pct_label, quantile in [("Top 75%", 0.25), ("Top 50%", 0.50), ("Top 25%", 0.75)]:
            # Round before counting so the advertised count matches what
            # filtering by the (2-decimal) option value actually returns.
            threshold = round(float(scores.quantile(quantile)), 2)
            count = scores.filter(scores >= threshold).len()
            options.append({
                "value": f"{threshold:.2f}",
                "label": pct_label,
                "count": int(count),
            })
    options.append({"value": "0", "label": "All papers", "count": len(df)})
    return options


@lru_cache(maxsize=1)
def get_histogram_data() -> dict:
    """Get confidence distribution data for histogram display.

    The displayed range is derived from the data: the observed min/max are
    widened to 5% boundaries (and to at least a 25% span where possible),
    then split into 25 equal bins. Bins are half-open ``[start, end)`` except
    the last, which is closed so the maximum score (e.g. exactly 1.0) is
    counted. The 50% line marks the prediction boundary.

    Returns:
        dict with ``bins`` (per-bin counts, heights and cumulative counts),
        ``min_pct``, ``max_pct``, ``total_papers`` and ``new_dataset_count``.
    """
    df = get_dataframe()

    # Get all papers with confidence scores
    all_papers = df.select("confidence_score", "is_new_dataset")
    score = pl.col("confidence_score")

    # Dynamically determine the range from actual data
    actual_min = float(all_papers["confidence_score"].min())
    actual_max = float(all_papers["confidence_score"].max())

    # Round down to nearest 5% for min, round up for max
    min_pct = max(0, (int(actual_min * 20) / 20))  # Floor to 5%
    max_pct = min(1, ((int(actual_max * 20) + 1) / 20))  # Ceil to 5%

    # Ensure minimum range of 25% for usability
    if max_pct - min_pct < 0.25:
        center = (min_pct + max_pct) / 2
        min_pct = max(0, center - 0.125)
        max_pct = min(1, center + 0.125)

    # Use 25 bins for good granularity
    num_bins = 25
    bin_width = (max_pct - min_pct) / num_bins

    bins = []
    for i in range(num_bins):
        is_last = i == num_bins - 1
        bin_start = min_pct + i * bin_width
        # Pin the final edge to max_pct (no float drift) and make it inclusive:
        # max_pct is clamped to 1.0, so a score of exactly 1.0 would otherwise
        # fall outside every half-open bin.
        bin_end = max_pct if is_last else min_pct + (i + 1) * bin_width
        upper = (score <= bin_end) if is_last else (score < bin_end)

        in_bin = all_papers.filter((score >= bin_start) & upper)
        count = in_bin.height
        new_dataset_count = in_bin.filter(pl.col("is_new_dataset")).height

        bins.append({
            "bin_start": round(bin_start, 3),
            "bin_end": round(bin_end, 3),
            "bin_pct": int(bin_start * 100),
            "count": count,
            "new_dataset_count": new_dataset_count,
        })

    # Normalize counts for display (max height = 100%)
    max_count = max(b["count"] for b in bins) if bins else 1
    for b in bins:
        b["height_pct"] = int((b["count"] / max_count) * 100) if max_count > 0 else 0
        b["new_height_pct"] = int((b["new_dataset_count"] / max_count) * 100) if max_count > 0 else 0

    # Calculate cumulative counts from each threshold
    # (how many papers are at or above this threshold)
    total_so_far = all_papers.height
    for b in bins:
        b["papers_above"] = total_so_far
        total_so_far -= b["count"]

    return {
        "bins": bins,
        "min_pct": round(min_pct, 2),
        "max_pct": round(max_pct, 2),
        "total_papers": all_papers.height,
        "new_dataset_count": all_papers.filter(pl.col("is_new_dataset")).height,
    }


def parse_since(since: str) -> Optional[date]:
    """Parse 'since' parameter ("1m", "6m", "1y") to a date; None means all time."""
    if not since:
        return None
    today = date.today()
    if since == "1m":
        return today - timedelta(days=30)
    elif since == "6m":
        return today - timedelta(days=180)
    elif since == "1y":
        return today - timedelta(days=365)
    return None


def filter_papers(
    df: pl.DataFrame,
    category: Optional[str] = None,
    search: Optional[str] = None,
    min_confidence: float = 0.5,
    since: Optional[str] = None,
) -> pl.DataFrame:
    """Apply filters to the papers dataframe.

    The confidence threshold controls which papers are shown:
    - Papers with is_new_dataset=True have confidence >= 0.5
    - Setting threshold to 0 shows all papers
    - Setting threshold >= 0.5 effectively shows only new_dataset papers

    ``category`` and ``search`` are matched as literal substrings (search is
    case-insensitive against title and abstract), never as regular
    expressions, so user input such as "c++" cannot raise a regex error.
    """
    if min_confidence >= 0.5:
        # Show only papers classified as new datasets, filtered by confidence
        df = df.filter(
            pl.col("is_new_dataset") & (pl.col("confidence_score") >= min_confidence)
        )
    elif min_confidence > 0:
        df = df.filter(pl.col("confidence_score") >= min_confidence)

    if category:
        df = df.filter(pl.col("categories").str.contains(category, literal=True))

    if search:
        search_lower = search.lower()
        df = df.filter(
            pl.col("title").str.to_lowercase().str.contains(search_lower, literal=True)
            | pl.col("abstract").str.to_lowercase().str.contains(search_lower, literal=True)
        )

    # Date filter
    min_date = parse_since(since)
    if min_date:
        df = df.filter(pl.col("update_date") >= min_date)

    return df


def paginate_papers(
    df: pl.DataFrame,
    page: int = 1,
    per_page: int = 20,
    sort: str = "date",
) -> tuple[pl.DataFrame, bool]:
    """Sort and paginate papers, return (page_df, has_more).

    Sort options:
    - "date": By update_date desc, then confidence_score desc
    - "relevance": Keep existing order (for semantic search similarity)

    One extra row is fetched to decide ``has_more`` without a second query.
    """
    if sort == "date":
        df_sorted = df.sort(
            ["update_date", "confidence_score"], descending=[True, True]
        )
    else:
        # "relevance" - keep existing order (already sorted by similarity for semantic)
        df_sorted = df

    start = (page - 1) * per_page
    page_df = df_sorted.slice(start, per_page + 1)
    has_more = len(page_df) > per_page

    return page_df.head(per_page), has_more


def semantic_search(
    query: str,
    k: int = 100,
    category: Optional[str] = None,
    min_confidence: float = 0.5,
    since: Optional[str] = None,
) -> pl.DataFrame:
    """Search using vector similarity via Lance nearest neighbor.

    Returns up to ``k`` rows that satisfy the filters, with a
    ``similarity_score`` column (0-1, higher is more similar).
    """
    model = get_embedding_model()
    query_embedding = model.encode(query).tolist()

    ds = get_lance_dataset()

    # Build SQL filter (Lance supports SQL-like syntax)
    filters = []
    if min_confidence >= 0.5:
        filters.append("is_new_dataset = true")
        filters.append(f"confidence_score >= {min_confidence}")
    elif min_confidence > 0:
        filters.append(f"confidence_score >= {min_confidence}")

    if category:
        # Escape single quotes in category name for SQL safety
        safe_category = category.replace("'", "''")
        filters.append(f"categories LIKE '%{safe_category}%'")

    # Date filter - use TIMESTAMP literal for Lance/DataFusion
    min_date = parse_since(since)
    if min_date:
        filters.append(f"update_date >= TIMESTAMP '{min_date.isoformat()} 00:00:00'")

    filter_str = " AND ".join(filters) if filters else None

    # Vector search - include _distance for similarity calculation.
    # prefilter=True applies the filter before the nearest-neighbor step;
    # with post-filtering the top-k neighbors are found first and then
    # filtered, which can leave far fewer than k rows (or none) when the
    # filter is selective, as the default is_new_dataset filter is.
    results = ds.scanner(
        nearest={"column": "embedding", "q": query_embedding, "k": k},
        filter=filter_str,
        prefilter=True,
        columns=["id", "title", "abstract", "categories", "update_date",
                 "authors", "confidence_score", "_distance"],
    ).to_table()

    df = pl.from_arrow(results)

    # Convert L2 distance to similarity score (0-1 range)
    # For normalized embeddings: similarity = 1 - distance/2
    # BGE embeddings are normalized, so L2 distance ranges from 0 to 2
    df = df.with_columns(
        (1 - pl.col("_distance") / 2).clip(0, 1).alias("similarity_score")
    ).drop("_distance")

    return df


@app.get("/", response_class=HTMLResponse)
async def home(
    request: Request,
    search: Optional[str] = Query(None),
    search_type: str = Query("keyword"),
    category: Optional[str] = Query(None),
    min_confidence: str = Query("0.5"),  # String to preserve exact value for template
    since: Optional[str] = Query(None),
    sort: str = Query("date"),
):
    """Render the home page with optional initial filter state from URL."""
    df = get_dataframe()
    categories = get_categories()
    histogram_data = get_histogram_data()
    confidence_options = get_confidence_options()

    # Get stats
    total_papers = len(df)
    new_dataset_count = df.filter(pl.col("is_new_dataset")).height

    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "categories": categories,
            "total_papers": total_papers,
            "new_dataset_count": new_dataset_count,
            "histogram_data": histogram_data,
            "confidence_options": confidence_options,
            # Pass filter state for URL persistence
            "search": search or "",
            "search_type": search_type,
            "category": category or "",
            "min_confidence": min_confidence,
            "since": since or "",
            "sort": sort,
        },
    )


@app.get("/papers", response_class=HTMLResponse)
async def get_papers(
    request: Request,
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1, le=100),
    category: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    min_confidence: float = Query(0.5, ge=0, le=1),
    search_type: str = Query("keyword"),  # "keyword" or "semantic"
    sort: str = Query("date"),  # "date" or "relevance"
    since: Optional[str] = Query(None),  # "1m", "6m", "1y", or None for all
):
    """Get paginated and filtered papers (returns HTML partial for HTMX).

    If accessed directly (not via HTMX), redirects to home page with same params.
    """
    # Redirect direct browser visits to home page (this endpoint returns partials)
    if "HX-Request" not in request.headers:
        # Build redirect URL with current query params
        query_string = str(request.url.query)
        redirect_url = f"/?{query_string}" if query_string else "/"
        return RedirectResponse(url=redirect_url, status_code=302)

    if search and search_type == "semantic":
        # Vector search - returns pre-sorted by similarity
        filtered_df = semantic_search(
            query=search,
            k=per_page * 5,  # Get more for pagination buffer
            category=category,
            min_confidence=min_confidence,
            since=since,
        )
        # Default to relevance sort for semantic, but allow date sort
        effective_sort = sort if sort == "date" else "relevance"
        page_df, has_more = paginate_papers(
            filtered_df, page=page, per_page=per_page, sort=effective_sort
        )
    else:
        # Existing keyword search path
        df = get_dataframe()
        filtered_df = filter_papers(
            df,
            category=category,
            search=search,
            min_confidence=min_confidence,
            since=since,
        )
        # Keyword search always sorts by date
        page_df, has_more = paginate_papers(
            filtered_df, page=page, per_page=per_page, sort="date"
        )

    # Convert to list of dicts for template
    papers = page_df.to_dicts()

    # Build clean URL for browser history (/ instead of /papers)
    # Only include non-default values to keep URLs short
    params = {}
    if search:
        params["search"] = search
    if search_type != "keyword":
        params["search_type"] = search_type
    if category:
        params["category"] = category
    if min_confidence != 0.5:
        params["min_confidence"] = min_confidence
    if since:
        params["since"] = since
    if sort != "date":
        params["sort"] = sort
    push_url = "/?" + urlencode(params) if params else "/"

    response = templates.TemplateResponse(
        "partials/paper_list.html",
        {
            "request": request,
            "papers": papers,
            "page": page,
            "has_more": has_more,
            "category": category or "",
            "search": search or "",
            "min_confidence": min_confidence,
            "search_type": search_type,
            "sort": sort,
            "since": since or "",
            "total_filtered": len(filtered_df),
        },
    )
    # Tell HTMX to push clean URL (/ not /papers)
    response.headers["HX-Push-Url"] = push_url
    return response


@app.get("/api/stats")
async def get_stats():
    """Get dataset statistics as JSON."""
    df = get_dataframe()
    new_datasets = df.filter(pl.col("is_new_dataset"))

    return {
        "total_papers": len(df),
        "new_dataset_count": len(new_datasets),
        "avg_confidence": float(df["confidence_score"].mean()),
        "date_range": {
            "min": str(df["update_date"].min()),
            "max": str(df["update_date"].max()),
        },
    }


# Preload dataset and model on startup
@app.on_event("startup")
async def startup_event():
    """Preload dataset and embedding model on startup."""
    print("Preloading dataset...")
    get_dataframe()
    print("Dataset loaded!")
    print("Preloading embedding model...")
    get_embedding_model()
    print("Embedding model loaded!")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)