Spaces:
Sleeping
Sleeping
davanstrien HF Staff
Dynamic confidence thresholds: percentile-based dropdown adapts to any model
# 70197b9 verified
"""
FastAPI + HTMX app for browsing arXiv papers that introduce new ML datasets.

Downloads the Lance dataset from the Hugging Face Hub and loads it locally.
"""
| import math | |
| import re | |
| from datetime import date, timedelta | |
| from functools import lru_cache | |
| from typing import Optional | |
| from urllib.parse import urlencode | |
| import lance | |
| import polars as pl | |
| from cachetools import TTLCache | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, Query, Request | |
| from fastapi.responses import HTMLResponse, RedirectResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from huggingface_hub import snapshot_download | |
| from markupsafe import Markup | |
# Read a local .env file, if any, so HF_TOKEN is available during development.
load_dotenv()

# Application object, static asset mount and Jinja2 template environment.
app = FastAPI(title="ArXiv New ML Datasets")
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
def highlight_search(text: str, search: str) -> Markup:
    """Return ``text`` as HTML-safe markup with occurrences of ``search`` highlighted.

    The text is always HTML-escaped before being wrapped in ``Markup``, whether
    or not a search term is given, so content from the dataset can never inject
    markup into the page.

    Matching is case-insensitive and literal (``search`` is passed through
    ``re.escape``). It is performed on the *raw* text and each piece is escaped
    afterwards, so terms containing ``&``, ``<`` or quotes still match, and a
    term such as ``amp`` cannot match inside an ``&amp;`` entity.

    Args:
        text: Text to render; any falsy value yields empty markup.
        search: Term to highlight; when falsy the text is only escaped.

    Returns:
        ``Markup`` in which every match is wrapped in a ``<mark>`` element.
    """
    import html

    if not text:
        return Markup("")

    raw = str(text)
    if not search:
        return Markup(html.escape(raw))

    pattern = re.compile(re.escape(search), re.IGNORECASE)
    pieces = []
    cursor = 0
    for match in pattern.finditer(raw):
        # Escape the unmatched gap and the matched text separately so the
        # <mark> tags are the only unescaped HTML in the result.
        pieces.append(html.escape(raw[cursor:match.start()]))
        pieces.append(
            f'<mark class="bg-yellow-200 px-0.5 rounded">{html.escape(match.group())}</mark>'
        )
        cursor = match.end()
    pieces.append(html.escape(raw[cursor:]))
    return Markup("".join(pieces))
# Expose highlight_search to templates as the "highlight" filter.
templates.env.filters.update(highlight=highlight_search)
def confidence_fmt(score):
    """Render a 0-1 confidence score as a percentage string with one decimal.

    The value is truncated (floored) to a tenth of a percent rather than
    rounded, so that e.g. 0.9995 is shown as "99.9" instead of "100.0".
    """
    tenths_of_percent = math.floor(score * 1000)
    return "{:.1f}".format(tenths_of_percent / 10)
# Expose confidence_fmt to templates as the "confidence" filter.
templates.env.filters.update(confidence=confidence_fmt)
# Dataset config: Hub repository that holds the Lance-formatted papers table.
DATASET_REPO = "librarian-bots/arxiv-cs-papers-lance"

# DataFrame cache: a single entry that expires after six hours.
_dataset_cache: TTLCache = TTLCache(maxsize=1, ttl=6 * 60 * 60)

# Lance dataset connection used for vector search (filled by get_lance_dataset).
_lance_cache: dict = dict()

# Embedding model, loaded lazily on the first semantic search.
_model_cache: dict = dict()
def get_lance_dataset():
    """Return the Lance dataset handle, downloading the snapshot on first use.

    The snapshot of ``DATASET_REPO`` is downloaded into
    ``<HF_HOME or /tmp/hf_cache>/arxiv-lance`` and ``data/train.lance`` inside
    it is opened. The handle is memoised in ``_lance_cache`` under ``"ds"``,
    so later calls skip the download.
    """
    cached = _lance_cache.get("ds")
    if cached is not None:
        return cached

    import os

    # ./data is not writable on Spaces, so fall back to /tmp when HF_HOME is unset.
    cache_base = os.environ.get("HF_HOME", "/tmp/hf_cache")
    local_dir = f"{cache_base}/arxiv-lance"
    print(f"Downloading dataset from {DATASET_REPO} to {local_dir}...")
    snapshot_download(DATASET_REPO, repo_type="dataset", local_dir=local_dir)

    lance_path = f"{local_dir}/data/train.lance"
    print(f"Loading Lance dataset from {lance_path}")
    dataset = lance.dataset(lance_path)
    _lance_cache["ds"] = dataset
    return dataset
def get_embedding_model():
    """Return the sentence-embedding model, loading it on the first call.

    ``sentence_transformers`` is imported lazily inside the function, and the
    loaded ``BAAI/bge-base-en-v1.5`` model is kept in ``_model_cache``.
    """
    model = _model_cache.get("model")
    if model is not None:
        return model

    from sentence_transformers import SentenceTransformer

    print("Loading embedding model...")
    model = SentenceTransformer("BAAI/bge-base-en-v1.5")
    _model_cache["model"] = model
    print("Embedding model loaded!")
    return model
def get_dataframe() -> pl.DataFrame:
    """Return the papers table as a Polars DataFrame, cached in ``_dataset_cache``.

    Only the columns needed for filtering and display are read from the Lance
    dataset; the embedding column is left out to save memory.
    """
    cache_key = "df"
    if cache_key in _dataset_cache:
        return _dataset_cache[cache_key]

    # Triggers the Hub download if the Lance dataset is not opened yet.
    dataset = get_lance_dataset()
    wanted_columns = [
        "id",
        "title",
        "abstract",
        "categories",
        "update_date",
        "authors",
        "is_new_dataset",
        "confidence_score",
    ]
    papers = pl.from_arrow(dataset.to_table(columns=wanted_columns))
    _dataset_cache[cache_key] = papers
    print(f"Loaded {len(papers):,} papers")
    return papers
def get_categories() -> list[str]:
    """Return the common ML categories that occur as a primary category.

    The primary category of a paper is the first space-separated token of its
    ``categories`` string. The result keeps the fixed order of the ML list
    below and contains only the entries found in the data.
    """
    primary_category = pl.col("categories").str.split(" ").list.first().alias("cat")
    present = (
        get_dataframe()
        .select(primary_category)
        .unique()
        .sort("cat")
        .to_series()
        .to_list()
    )
    ml_cats = ["cs.AI", "cs.CL", "cs.CV", "cs.LG", "cs.NE", "cs.IR", "cs.RO", "stat.ML"]
    return [name for name in ml_cats if name in present]
def get_confidence_options() -> list[dict]:
    """Build the confidence-filter dropdown options from the score distribution.

    Thresholds are percentiles of ``confidence_score`` over the rows flagged
    ``is_new_dataset``, so the options adapt to any model's score range.

    Each option's ``count`` is computed against the *rounded* threshold that is
    emitted as ``value`` (two decimals), because that rounded value is what gets
    submitted back as the filter. Counting at the unrounded quantile could show
    a number that differs from what the filter returns.

    Returns:
        A list of ``{"value": str, "label": str, "count": int}`` dicts:
        "All new datasets" (0.5), the percentile options, then "All papers" (0).
        The percentile options are omitted when there are no new-dataset rows.
    """
    df = get_dataframe()
    scores = df.filter(pl.col("is_new_dataset"))["confidence_score"]

    options = [{"value": "0.5", "label": "All new datasets", "count": len(scores)}]

    for pct_label, quantile in (("Top 75%", 0.25), ("Top 50%", 0.50), ("Top 25%", 0.75)):
        raw_threshold = scores.quantile(quantile)
        if raw_threshold is None:
            # No scores to take a quantile of; skip instead of crashing on float(None).
            continue
        value = f"{float(raw_threshold):.2f}"
        cutoff = float(value)  # the threshold the filter will actually receive
        count = scores.filter(scores >= cutoff).len()
        options.append({
            "value": value,
            "label": pct_label,
            "count": int(count),
        })

    options.append({"value": "0", "label": "All papers", "count": len(df)})
    return options
def get_histogram_data() -> dict:
    """Compute the confidence-score histogram shown in the UI.

    The displayed range is derived from the data: the minimum is floored and
    the maximum is pushed up to a 5% boundary, both are clamped to [0, 1], and
    the range is widened to at least 25%. It is split into 25 equal bins.

    Every bin is half-open ``[start, end)`` except the last, which also
    includes the upper bound. Without that, scores equal to the top of the
    range (e.g. exactly 1.0, where the range is clamped to 1) would fall in no
    bin and be missing from the counts.

    Returns:
        A dict with:
        - ``bins``: list of dicts with ``bin_start``, ``bin_end``, ``bin_pct``,
          ``count``, ``new_dataset_count``, ``height_pct``, ``new_height_pct``
          (heights normalised so the tallest bin is 100) and ``papers_above``
          (total rows minus the counts of all earlier bins).
        - ``min_pct`` / ``max_pct``: the displayed range, rounded to 2 decimals.
        - ``total_papers`` and ``new_dataset_count`` for the whole table.
    """
    df = get_dataframe()
    all_papers = df.select("confidence_score", "is_new_dataset")
    score = pl.col("confidence_score")

    # min()/max() return None for an empty (or all-null) column; fall back to
    # the full 0-1 range instead of failing on float(None).
    lowest = all_papers["confidence_score"].min()
    highest = all_papers["confidence_score"].max()
    actual_min = float(lowest) if lowest is not None else 0.0
    actual_max = float(highest) if highest is not None else 1.0

    # Snap to 5% boundaries: floor the minimum, push the maximum to the next step.
    min_pct = max(0, int(actual_min * 20) / 20)
    max_pct = min(1, (int(actual_max * 20) + 1) / 20)

    # Ensure a minimum range of 25% for usability.
    if max_pct - min_pct < 0.25:
        center = (min_pct + max_pct) / 2
        min_pct = max(0, center - 0.125)
        max_pct = min(1, center + 0.125)

    num_bins = 25
    bin_width = (max_pct - min_pct) / num_bins

    bins = []
    for i in range(num_bins):
        bin_start = min_pct + i * bin_width
        bin_end = min_pct + (i + 1) * bin_width

        if i == num_bins - 1:
            # Last bin is closed on the right so the maximum score is counted.
            upper_bound = score <= max_pct
        else:
            upper_bound = score < bin_end

        # Filter once per bin and reuse the slice for both counts.
        in_bin = all_papers.filter((score >= bin_start) & upper_bound)
        count = in_bin.height
        new_dataset_count = in_bin.filter(pl.col("is_new_dataset")).height

        bins.append({
            "bin_start": round(bin_start, 3),
            "bin_end": round(bin_end, 3),
            "bin_pct": int(bin_start * 100),
            "count": count,
            "new_dataset_count": new_dataset_count,
        })

    # Normalise bar heights so the tallest bin is 100%.
    max_count = max(b["count"] for b in bins) if bins else 1
    for b in bins:
        b["height_pct"] = int((b["count"] / max_count) * 100) if max_count > 0 else 0
        b["new_height_pct"] = (
            int((b["new_dataset_count"] / max_count) * 100) if max_count > 0 else 0
        )

    # Papers at or above each bin's lower edge: start from the total and
    # subtract the bins already passed.
    total_so_far = all_papers.height
    for b in bins:
        b["papers_above"] = total_so_far
        total_so_far -= b["count"]

    return {
        "bins": bins,
        "min_pct": round(min_pct, 2),
        "max_pct": round(max_pct, 2),
        "total_papers": all_papers.height,
        "new_dataset_count": all_papers.filter(pl.col("is_new_dataset")).height,
    }
def parse_since(since: str) -> Optional[date]:
    """Translate a ``since`` shortcut into the earliest date to include.

    Recognised shortcuts are "1m" (30 days), "6m" (180 days) and "1y"
    (365 days), counted back from today. An empty or unrecognised value
    returns None, meaning no date restriction.
    """
    window_days = {"1m": 30, "6m": 180, "1y": 365}
    if not since or since not in window_days:
        return None
    return date.today() - timedelta(days=window_days[since])
def filter_papers(
    df: pl.DataFrame,
    category: Optional[str] = None,
    search: Optional[str] = None,
    min_confidence: float = 0.5,
    since: Optional[str] = None,
) -> pl.DataFrame:
    """Apply the keyword-search filters to the papers dataframe.

    The confidence threshold controls which papers are shown:
    - Papers with is_new_dataset=True have confidence >= 0.5
    - Setting threshold to 0 shows all papers
    - Setting threshold >= 0.5 effectively shows only new_dataset papers

    ``category`` and ``search`` are matched as literal substrings, not regular
    expressions. User input such as ``(`` or ``c++`` must not be interpreted as
    a pattern, and the ``.`` in ``cs.AI`` must only match a dot.

    Args:
        df: Papers table to filter.
        category: Keep rows whose ``categories`` string contains this text.
        search: Keep rows whose title or abstract contains this text,
            case-insensitively.
        min_confidence: Minimum ``confidence_score``; values <= 0 disable it.
        since: Date shortcut understood by ``parse_since``.

    Returns:
        The filtered dataframe.
    """
    if min_confidence > 0:
        df = df.filter(pl.col("confidence_score") >= min_confidence)

    if category:
        df = df.filter(pl.col("categories").str.contains(category, literal=True))

    if search:
        needle = search.lower()
        df = df.filter(
            pl.col("title").str.to_lowercase().str.contains(needle, literal=True)
            | pl.col("abstract").str.to_lowercase().str.contains(needle, literal=True)
        )

    # Date filter
    min_date = parse_since(since)
    if min_date:
        df = df.filter(pl.col("update_date") >= min_date)

    return df
def paginate_papers(
    df: pl.DataFrame,
    page: int = 1,
    per_page: int = 20,
    sort: str = "date",
) -> tuple[pl.DataFrame, bool]:
    """Order the papers and cut out one page; return ``(page_df, has_more)``.

    Sort options:
    - "date": newest ``update_date`` first, ties broken by higher confidence
    - anything else ("relevance"): the incoming row order is kept, which for
      semantic search is already the similarity order

    One row beyond the page size is fetched to tell whether another page exists.
    """
    ordered = (
        df.sort(["update_date", "confidence_score"], descending=[True, True])
        if sort == "date"
        else df
    )
    offset = (page - 1) * per_page
    window = ordered.slice(offset, per_page + 1)
    has_more = len(window) > per_page
    return window.head(per_page), has_more
def semantic_search(
    query: str,
    k: int = 100,
    category: Optional[str] = None,
    min_confidence: float = 0.5,
    since: Optional[str] = None,
) -> pl.DataFrame:
    """Search papers by vector similarity using Lance nearest-neighbour lookup.

    The query is embedded with the cached embedding model and compared against
    the ``embedding`` column. When any filter is active it is applied *before*
    the nearest-neighbour step (``prefilter=True``). With Lance's default
    post-filtering, the k nearest rows are selected first and then filtered,
    which can return far fewer than ``k`` rows, or none.

    Args:
        query: Free-text query to embed.
        k: Number of nearest neighbours to request.
        category: Keep rows whose ``categories`` contains this text (SQL LIKE).
        min_confidence: Minimum ``confidence_score``; values <= 0 disable it.
        since: Date shortcut understood by ``parse_since``.

    Returns:
        DataFrame with the display columns plus ``similarity_score``
        (clipped to 0-1, higher is more similar); ``_distance`` is dropped.
    """
    model = get_embedding_model()
    query_embedding = model.encode(query).tolist()
    ds = get_lance_dataset()

    # Build the SQL-like filter string understood by Lance.
    filters = []
    if min_confidence > 0:
        filters.append(f"confidence_score >= {min_confidence}")
    if category:
        # Double single quotes so the value cannot terminate the SQL literal.
        safe_category = category.replace("'", "''")
        filters.append(f"categories LIKE '%{safe_category}%'")

    # Date filter - use TIMESTAMP literal for Lance/DataFusion
    min_date = parse_since(since)
    if min_date:
        filters.append(f"update_date >= TIMESTAMP '{min_date.isoformat()} 00:00:00'")

    filter_str = " AND ".join(filters) if filters else None

    # Vector search - include _distance for the similarity calculation.
    results = ds.scanner(
        nearest={"column": "embedding", "q": query_embedding, "k": k},
        filter=filter_str,
        # Only ask for pre-filtering when there is something to filter on.
        prefilter=filter_str is not None,
        columns=["id", "title", "abstract", "categories", "update_date",
                 "authors", "confidence_score", "_distance"],
    ).to_table()

    df = pl.from_arrow(results)

    # Map distance to a 0-1 similarity: 1 - distance / 2, clipped.
    # NOTE(review): this assumes normalised embeddings with distances in 0-2;
    # confirm against the metric Lance reports in _distance.
    df = df.with_columns(
        (1 - pl.col("_distance") / 2).clip(0, 1).alias("similarity_score")
    ).drop("_distance")

    return df
async def home(
    request: Request,
    search: Optional[str] = Query(None),
    search_type: str = Query("keyword"),
    category: Optional[str] = Query(None),
    min_confidence: str = Query("0.5"),  # String to preserve exact value for template
    since: Optional[str] = Query(None),
    sort: str = Query("date"),
):
    """Render ``index.html`` with dataset stats and the filter state from the URL.

    The query parameters are passed through to the template unchanged (missing
    optional values become empty strings) so the page can restore the filters.
    NOTE(review): no route decorator is visible on this handler in this view;
    confirm how it is registered.
    """
    df = get_dataframe()

    context = {
        "request": request,
        "categories": get_categories(),
        # Headline statistics
        "total_papers": len(df),
        "new_dataset_count": df.filter(pl.col("is_new_dataset")).height,
        "histogram_data": get_histogram_data(),
        "confidence_options": get_confidence_options(),
        # Filter state for URL persistence
        "search": search or "",
        "search_type": search_type,
        "category": category or "",
        "min_confidence": min_confidence,
        "since": since or "",
        "sort": sort,
    }
    return templates.TemplateResponse("index.html", context)
async def get_papers(
    request: Request,
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1, le=100),
    category: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    min_confidence: float = Query(0.5, ge=0, le=1),
    search_type: str = Query("keyword"),  # "keyword" or "semantic"
    sort: str = Query("date"),  # "date" or "relevance"
    since: Optional[str] = Query(None),  # "1m", "6m", "1y", or None for all
):
    """Return one page of filtered papers as an HTML partial for HTMX.

    Requests without an ``HX-Request`` header (direct browser visits) are
    redirected to the home page with the same query string, since this
    endpoint only renders a fragment.

    A semantic search with a non-empty query uses vector search over
    ``per_page * 5`` neighbours and keeps the similarity order unless
    ``sort == "date"``. Every other request uses the keyword filter path,
    which always sorts by date. The response carries an ``HX-Push-Url`` header
    with a clean ``/`` URL containing only the non-default parameters.
    """
    if "HX-Request" not in request.headers:
        # Direct visit: send the browser to the full page, keeping the filters.
        query_string = str(request.url.query)
        target = f"/?{query_string}" if query_string else "/"
        return RedirectResponse(url=target, status_code=302)

    if search and search_type == "semantic":
        # Results arrive ordered by similarity; fetch extra rows as a paging buffer.
        filtered_df = semantic_search(
            query=search,
            k=per_page * 5,
            category=category,
            min_confidence=min_confidence,
            since=since,
        )
        # Relevance is the default for semantic search; date sort stays available.
        effective_sort = "date" if sort == "date" else "relevance"
    else:
        filtered_df = filter_papers(
            get_dataframe(),
            category=category,
            search=search,
            min_confidence=min_confidence,
            since=since,
        )
        # Keyword search always sorts by date.
        effective_sort = "date"

    page_df, has_more = paginate_papers(
        filtered_df, page=page, per_page=per_page, sort=effective_sort
    )

    # Query parameters for the pushed URL: (name, value, differs-from-default).
    url_state = (
        ("search", search, bool(search)),
        ("search_type", search_type, search_type != "keyword"),
        ("category", category, bool(category)),
        ("min_confidence", min_confidence, min_confidence != 0.5),
        ("since", since, bool(since)),
        ("sort", sort, sort != "date"),
    )
    params = {name: value for name, value, keep in url_state if keep}
    push_url = "/?" + urlencode(params) if params else "/"

    context = {
        "request": request,
        "papers": page_df.to_dicts(),
        "page": page,
        "has_more": has_more,
        "category": category or "",
        "search": search or "",
        "min_confidence": min_confidence,
        "search_type": search_type,
        "sort": sort,
        "since": since or "",
        "total_filtered": len(filtered_df),
    }
    response = templates.TemplateResponse("partials/paper_list.html", context)
    # Tell HTMX to push the clean URL (/ rather than /papers).
    response.headers["HX-Push-Url"] = push_url
    return response
async def get_stats():
    """Return summary statistics of the papers table as a JSON-serialisable dict.

    Includes the total row count, the number of rows flagged ``is_new_dataset``,
    the mean confidence score, and the min/max ``update_date`` as strings.
    """
    df = get_dataframe()
    update_dates = df["update_date"]
    new_dataset_total = len(df.filter(pl.col("is_new_dataset")))
    return {
        "total_papers": len(df),
        "new_dataset_count": new_dataset_total,
        "avg_confidence": float(df["confidence_score"].mean()),
        "date_range": {
            "min": str(update_dates.min()),
            "max": str(update_dates.max()),
        },
    }
# Warm the caches at startup so the first request does not pay for loading.
async def startup_event():
    """Load the papers dataframe and the embedding model, logging progress."""
    preload_steps = (
        ("Preloading dataset...", get_dataframe, "Dataset loaded!"),
        ("Preloading embedding model...", get_embedding_model, "Embedding model loaded!"),
    )
    for start_message, loader, done_message in preload_steps:
        print(start_message)
        loader()
        print(done_message)
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(app=app, host="0.0.0.0", port=7860)