davanstrien's picture
davanstrien HF Staff
Fix: filter on is_new_dataset when showing dataset papers
fa3f012
"""
FastAPI + HTMX app for browsing arXiv papers with new ML datasets.
Downloads a Lance dataset from the Hugging Face Hub and loads it locally.
"""
import math
import re
from datetime import date, timedelta
from functools import lru_cache
from typing import Optional
from urllib.parse import urlencode
import lance
import polars as pl
from cachetools import TTLCache
from dotenv import load_dotenv
from fastapi import FastAPI, Query, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import snapshot_download
from markupsafe import Markup
# Load a local .env file, if present, for local development (e.g. HF_TOKEN)
# so settings do not have to be exported in the shell.
load_dotenv()
# The ASGI application that all routes below are registered on.
app = FastAPI(title="ArXiv New ML Datasets")
# Serve files from the ./static directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Jinja2 templates are loaded from ./templates; custom filters are added to
# templates.env.filters further down in this module.
templates = Jinja2Templates(directory="templates")
def highlight_search(text: str, search: str) -> Markup:
    """Return ``text`` as HTML-safe Markup with ``search`` matches highlighted.

    The text is always HTML-escaped, whether or not a search term is given,
    so markup-like characters in titles/abstracts are displayed rather than
    interpreted. Matching is case-insensitive and literal, and it runs on the
    raw (unescaped) text. Matched and unmatched segments are escaped
    separately and each match is wrapped in a ``<mark>`` element. As a result,
    terms containing ``&`` or ``<`` still match, and a highlight never splits
    an HTML entity such as ``&amp;``.

    Args:
        text: Text to render (converted with ``str()``). Falsy values give
            an empty Markup.
        search: Search term to highlight. Falsy means no highlighting.

    Returns:
        Markup that is safe to insert into a template without further escaping.
    """
    import html

    if not text:
        return Markup("")
    raw = str(text)
    if not search:
        return Markup(html.escape(raw))

    pattern = re.compile(re.escape(search), re.IGNORECASE)
    pieces = []
    cursor = 0
    for match in pattern.finditer(raw):
        # Escape the plain text before the match, then the match itself inside <mark>.
        pieces.append(html.escape(raw[cursor:match.start()]))
        pieces.append(
            f'<mark class="bg-yellow-200 px-0.5 rounded">{html.escape(match.group())}</mark>'
        )
        cursor = match.end()
    pieces.append(html.escape(raw[cursor:]))
    return Markup("".join(pieces))
# Register custom filters: expose highlight_search to Jinja templates as `highlight`.
templates.env.filters["highlight"] = highlight_search
def confidence_fmt(score):
    """Render a 0-1 confidence score as a percentage string with one decimal.

    The value is truncated (never rounded up) to the nearest tenth of a
    percent, so a score of 0.9995 is shown as "99.9" instead of "100.0".
    """
    tenths_of_percent = math.floor(score * 1000)
    return "{:.1f}".format(tenths_of_percent / 10)
# Expose confidence_fmt to Jinja templates as the `confidence` filter.
templates.env.filters["confidence"] = confidence_fmt
# Dataset config: Hugging Face Hub dataset repo downloaded by get_lance_dataset().
DATASET_REPO = "librarian-bots/arxiv-cs-papers-lance"
# Cache for the Polars DataFrame built by get_dataframe(); a single entry that
# expires after 6 hours (ttl is in seconds), after which it is rebuilt.
_dataset_cache: TTLCache = TTLCache(maxsize=1, ttl=60 * 60 * 6)
# Cache for the opened Lance dataset (key "ds"), used for vector search.
# Plain dict: the entry never expires for the life of the process.
_lance_cache: dict = {}
# Cache for the SentenceTransformer embedding model (key "model"), filled by
# get_embedding_model() on first use.
_model_cache: dict = {}
def get_lance_dataset():
    """Return the Lance dataset, downloading it from the Hub on the first call.

    The repo snapshot is stored under ``$HF_HOME/arxiv-lance``, falling back to
    ``/tmp/hf_cache/arxiv-lance`` when HF_HOME is unset (./data is not
    writable on Spaces). The opened dataset is kept in ``_lance_cache`` and
    reused on later calls.
    """
    if "ds" in _lance_cache:
        return _lance_cache["ds"]

    import os

    hf_home = os.environ.get("HF_HOME", "/tmp/hf_cache")
    local_dir = f"{hf_home}/arxiv-lance"
    print(f"Downloading dataset from {DATASET_REPO} to {local_dir}...")
    snapshot_download(DATASET_REPO, repo_type="dataset", local_dir=local_dir)

    lance_path = f"{local_dir}/data/train.lance"
    print(f"Loading Lance dataset from {lance_path}")
    dataset = lance.dataset(lance_path)
    _lance_cache["ds"] = dataset
    return dataset
def get_embedding_model():
    """Return the shared BAAI/bge-base-en-v1.5 SentenceTransformer.

    The model is created on the first call and stored in ``_model_cache``.
    The sentence_transformers import is deferred to that point, so the module
    is only imported when the model is actually needed.
    """
    if "model" in _model_cache:
        return _model_cache["model"]

    from sentence_transformers import SentenceTransformer

    print("Loading embedding model...")
    model = SentenceTransformer("BAAI/bge-base-en-v1.5")
    _model_cache["model"] = model
    print("Embedding model loaded!")
    return model
def get_dataframe() -> pl.DataFrame:
    """Return the papers table as a Polars DataFrame, cached with a TTL.

    Only the columns needed for filtering and display are read from Lance.
    The embedding column is left out to save memory. The result is stored
    in ``_dataset_cache`` under ``"df"`` until that cache entry expires.
    """
    if "df" in _dataset_cache:
        return _dataset_cache["df"]

    lance_ds = get_lance_dataset()  # downloads from the HF Hub on first use
    wanted_columns = [
        "id",
        "title",
        "abstract",
        "categories",
        "update_date",
        "authors",
        "is_new_dataset",
        "confidence_score",
    ]
    frame = pl.from_arrow(lance_ds.to_table(columns=wanted_columns))
    _dataset_cache["df"] = frame
    print(f"Loaded {len(frame):,} papers")
    return frame
@lru_cache(maxsize=1)
def get_categories() -> list[str]:
    """Return the curated ML categories that occur as a paper's primary category.

    A paper's primary category is the first space-separated token of its
    ``categories`` string. The result keeps the fixed order of the curated
    list and drops entries that never appear as a primary category.
    """
    ml_cats = ["cs.AI", "cs.CL", "cs.CV", "cs.LG", "cs.NE", "cs.IR", "cs.RO", "stat.ML"]
    primary_cats = set(
        get_dataframe()
        .get_column("categories")
        .str.split(" ")
        .list.first()
        .unique()
        .to_list()
    )
    return [cat for cat in ml_cats if cat in primary_cats]
@lru_cache(maxsize=1)
def get_confidence_options() -> list[dict]:
    """Compute confidence filter options from the actual score distribution.

    The options come from quantiles of the new-dataset papers' scores, so the
    UI adapts to any model's score range. Each option's ``value`` is the
    string the UI submits as ``min_confidence``. Its ``count`` is computed
    against that exact (2-decimal) value, so the count shown matches the
    number of papers the filter then returns.

    Returns:
        List of ``{"value", "label", "count"}`` dicts: "All new datasets"
        (0.5), up to three quantile thresholds, and "All papers" (0). The
        quantile options are omitted when there are no scored new-dataset
        papers.
    """
    df = get_dataframe()
    # Null scores can never pass a ">= threshold" filter, so they are excluded
    # from every count.
    scores = df.filter(pl.col("is_new_dataset"))["confidence_score"].drop_nulls()

    def _count_at_least(threshold: float) -> int:
        # Same predicate filter_papers applies for thresholds >= 0.5.
        return int((scores >= threshold).sum())

    options = [{"value": "0.5", "label": "All new datasets", "count": _count_at_least(0.5)}]
    # quantile() on an empty Series returns None, and float(None) would raise.
    if len(scores) > 0:
        for pct_label, quantile in [("Top 75%", 0.25), ("Top 50%", 0.50), ("Top 25%", 0.75)]:
            value = f"{float(scores.quantile(quantile)):.2f}"
            # Count against the rounded value that will actually be submitted.
            options.append({
                "value": value,
                "label": pct_label,
                "count": _count_at_least(float(value)),
            })
    options.append({"value": "0", "label": "All papers", "count": len(df)})
    return options
@lru_cache(maxsize=1)
def get_histogram_data() -> dict:
    """Build the confidence-score histogram data for display.

    The plotted range comes from the data. The minimum score is floored and
    the maximum ceiled to a 5% boundary (clamped to [0, 1]). The range is then
    widened to at least 25 percentage points for usability and split into 25
    equal bins. Each bin is half-open ``[start, end)`` except the last, which
    is closed so that papers scoring exactly the upper bound (e.g. 1.0) are
    counted. The 50% line marks the prediction boundary.

    Returns:
        dict with:
        - ``bins``: per-bin counts, display heights relative to the tallest
          bin, and cumulative ``papers_above``.
        - ``min_pct`` / ``max_pct``: the range as fractions.
        - ``total_papers`` and ``new_dataset_count``.

        If no paper has a confidence score, ``bins`` is empty.
    """
    df = get_dataframe()
    all_papers = df.select("confidence_score", "is_new_dataset")
    total_papers = all_papers.height
    total_new = all_papers.filter(pl.col("is_new_dataset")).height

    score_min = all_papers["confidence_score"].min()
    score_max = all_papers["confidence_score"].max()
    if score_min is None or score_max is None:
        # Empty or all-null scores: there is no range to bin.
        return {
            "bins": [],
            "min_pct": 0.0,
            "max_pct": 1.0,
            "total_papers": total_papers,
            "new_dataset_count": total_new,
        }
    actual_min = float(score_min)
    actual_max = float(score_max)

    # Round down to nearest 5% for min, round up for max.
    min_pct = max(0, int(actual_min * 20) / 20)
    max_pct = min(1, (int(actual_max * 20) + 1) / 20)

    # Ensure a minimum range of 25% (still covers the data: its span is < 25%).
    if max_pct - min_pct < 0.25:
        center = (min_pct + max_pct) / 2
        min_pct = max(0, center - 0.125)
        max_pct = min(1, center + 0.125)

    num_bins = 25
    bin_width = (max_pct - min_pct) / num_bins
    score = pl.col("confidence_score")

    bins = []
    for i in range(num_bins):
        bin_start = min_pct + i * bin_width
        is_last = i == num_bins - 1
        # Pin the final edge to max_pct exactly and make it inclusive, so the
        # top score (often exactly max_pct) is not dropped from every bin.
        bin_end = float(max_pct) if is_last else min_pct + (i + 1) * bin_width
        upper_ok = (score <= bin_end) if is_last else (score < bin_end)
        in_bin = all_papers.filter((score >= bin_start) & upper_ok)
        bins.append({
            "bin_start": round(bin_start, 3),
            "bin_end": round(bin_end, 3),
            # Floor with a tiny tolerance: int(0.33999999999999997 * 100) would give 33.
            "bin_pct": math.floor(bin_start * 100 + 1e-9),
            "count": in_bin.height,
            "new_dataset_count": in_bin.filter(pl.col("is_new_dataset")).height,
        })

    # Normalize counts for display (tallest bin = 100%).
    max_count = max(b["count"] for b in bins) if bins else 1
    for b in bins:
        b["height_pct"] = int((b["count"] / max_count) * 100) if max_count > 0 else 0
        b["new_height_pct"] = int((b["new_dataset_count"] / max_count) * 100) if max_count > 0 else 0

    # Cumulative count of papers at or above each bin's lower edge.
    total_so_far = total_papers
    for b in bins:
        b["papers_above"] = total_so_far
        total_so_far -= b["count"]

    return {
        "bins": bins,
        "min_pct": round(min_pct, 2),
        "max_pct": round(max_pct, 2),
        "total_papers": total_papers,
        "new_dataset_count": total_new,
    }
def parse_since(since: str) -> Optional[date]:
    """Convert a 'since' shortcut into the earliest date to include.

    Recognised values are "1m" (30 days), "6m" (180 days) and "1y" (365 days),
    counted back from today. Empty, missing or unrecognised values mean
    "all time" and yield None.
    """
    if not since:
        return None
    lookback_days = {"1m": 30, "6m": 180, "1y": 365}.get(since)
    if lookback_days is None:
        return None
    return date.today() - timedelta(days=lookback_days)
def filter_papers(
    df: pl.DataFrame,
    category: Optional[str] = None,
    search: Optional[str] = None,
    min_confidence: float = 0.5,
    since: Optional[str] = None,
) -> pl.DataFrame:
    """Apply the UI filters to the papers dataframe.

    The confidence threshold controls which papers are shown:
    - Papers with is_new_dataset=True have confidence >= 0.5.
    - A threshold of 0 shows all papers.
    - A threshold >= 0.5 shows only new-dataset papers at or above it.
    - A threshold in (0, 0.5) filters on the score alone.

    ``category`` and ``search`` are matched as literal substrings, not regexes.
    ``search`` is case-insensitive and checked against title and abstract.
    ``since`` is a shortcut understood by ``parse_since``.

    Returns:
        The filtered DataFrame (row order unchanged).
    """
    if min_confidence >= 0.5:
        # Show only papers classified as new datasets, filtered by confidence.
        df = df.filter(
            pl.col("is_new_dataset") & (pl.col("confidence_score") >= min_confidence)
        )
    elif min_confidence > 0:
        df = df.filter(pl.col("confidence_score") >= min_confidence)

    if category:
        # literal=True: str.contains defaults to regex mode, where the "." in
        # e.g. "cs.AI" would match any character.
        df = df.filter(pl.col("categories").str.contains(category, literal=True))

    if search:
        # literal=True: user text such as "C++" or "(" is not a valid regex and
        # would otherwise raise inside Polars instead of simply matching.
        needle = search.lower()
        df = df.filter(
            pl.col("title").str.to_lowercase().str.contains(needle, literal=True)
            | pl.col("abstract").str.to_lowercase().str.contains(needle, literal=True)
        )

    # Date filter
    min_date = parse_since(since)
    if min_date:
        df = df.filter(pl.col("update_date") >= min_date)
    return df
def paginate_papers(
    df: pl.DataFrame,
    page: int = 1,
    per_page: int = 20,
    sort: str = "date",
) -> tuple[pl.DataFrame, bool]:
    """Order ``df`` and cut out one page of results.

    ``sort="date"`` orders by update_date, then confidence_score, both
    descending. Any other value (e.g. "relevance") keeps the incoming row
    order, which for semantic search is similarity order.

    One row past the page is fetched so we can tell whether another page exists.

    Returns:
        ``(page_df, has_more)``: at most ``per_page`` rows, and whether more
        rows follow.
    """
    ordered = df
    if sort == "date":
        ordered = df.sort(["update_date", "confidence_score"], descending=[True, True])

    offset = (page - 1) * per_page
    window = ordered.slice(offset, per_page + 1)
    return window.head(per_page), window.height > per_page
def semantic_search(
    query: str,
    k: int = 100,
    category: Optional[str] = None,
    min_confidence: float = 0.5,
    since: Optional[str] = None,
) -> pl.DataFrame:
    """Search using vector similarity via Lance nearest-neighbour search.

    The query is embedded with the cached SentenceTransformer and compared
    against the ``embedding`` column. The confidence, category and date
    filters mirror ``filter_papers``.

    The filters are applied *before* the vector search (``prefilter=True``).
    Lance otherwise filters only the k nearest rows, which with selective
    filters such as ``is_new_dataset = true`` can return far fewer than ``k``
    rows, or none at all.

    Returns:
        DataFrame of the display columns plus ``similarity_score`` in [0, 1]
        (higher is more similar). Rows keep the order Lance returns them in;
        callers treat that as similarity order.
    """
    model = get_embedding_model()
    # Unit-length query vector: the distance -> similarity conversion below
    # assumes normalized embeddings.
    query_embedding = model.encode(query, normalize_embeddings=True).tolist()
    ds = get_lance_dataset()

    # Build SQL filter (Lance supports SQL-like syntax).
    filters = []
    if min_confidence >= 0.5:
        filters.append("is_new_dataset = true")
    if min_confidence > 0:
        # float() guarantees only a number is interpolated into the SQL.
        filters.append(f"confidence_score >= {float(min_confidence)}")
    if category:
        # Escape single quotes in the category name for SQL safety.
        safe_category = category.replace("'", "''")
        filters.append(f"categories LIKE '%{safe_category}%'")
    # Date filter: use a TIMESTAMP literal for Lance/DataFusion.
    min_date = parse_since(since)
    if min_date:
        filters.append(f"update_date >= TIMESTAMP '{min_date.isoformat()} 00:00:00'")
    filter_str = " AND ".join(filters) if filters else None

    # Vector search; include _distance for the similarity calculation.
    results = ds.scanner(
        nearest={"column": "embedding", "q": query_embedding, "k": k},
        filter=filter_str,
        prefilter=True,
        columns=["id", "title", "abstract", "categories", "update_date",
                 "authors", "confidence_score", "_distance"],
    ).to_table()
    df = pl.from_arrow(results)

    # Map distance to a 0-1 similarity. NOTE(review): if Lance reports squared
    # L2 for unit vectors (2 - 2*cos), then 1 - d/2 equals cosine similarity.
    # Either way the score decreases monotonically with distance and is clipped.
    df = df.with_columns(
        (1 - pl.col("_distance") / 2).clip(0, 1).alias("similarity_score")
    ).drop("_distance")
    return df
@app.get("/", response_class=HTMLResponse)
async def home(
    request: Request,
    search: Optional[str] = Query(None),
    search_type: str = Query("keyword"),
    category: Optional[str] = Query(None),
    min_confidence: str = Query("0.5"),  # String to preserve exact value for template
    since: Optional[str] = Query(None),
    sort: str = Query("date"),
):
    """Render the full home page.

    Any filter state present in the URL is passed back to the template, so a
    shared or reloaded URL restores the same filters. Missing optional values
    are passed as empty strings.
    """
    papers_df = get_dataframe()
    context = {
        "request": request,
        "categories": get_categories(),
        "total_papers": papers_df.height,
        "new_dataset_count": papers_df.filter(pl.col("is_new_dataset")).height,
        "histogram_data": get_histogram_data(),
        "confidence_options": get_confidence_options(),
        # Filter state for URL persistence.
        "search": search or "",
        "search_type": search_type,
        "category": category or "",
        "min_confidence": min_confidence,
        "since": since or "",
        "sort": sort,
    }
    return templates.TemplateResponse("index.html", context)
@app.get("/papers", response_class=HTMLResponse)
async def get_papers(
    request: Request,
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1, le=100),
    category: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    min_confidence: float = Query(0.5, ge=0, le=1),
    search_type: str = Query("keyword"),  # "keyword" or "semantic"
    sort: str = Query("date"),  # "date" or "relevance"
    since: Optional[str] = Query(None),  # "1m", "6m", "1y", or None for all
):
    """Return the paginated, filtered paper list as an HTML partial for HTMX.

    Requests without an ``HX-Request`` header (direct browser visits) are
    redirected to ``/`` with the same query string, because this endpoint
    only renders a fragment.

    Search paths:
    - Semantic: uses vector search over the top ``per_page * 5`` candidates.
      It is sorted by date only when ``sort == "date"``; otherwise similarity
      order is kept.
    - Keyword (and no search at all): filters the cached DataFrame and always
      sorts by date.

    The response carries an ``HX-Push-Url`` header with a clean ``/?...`` URL
    that contains only non-default parameters, so browser history points at
    the full page rather than this partial.
    """
    # Redirect direct browser visits to the home page (this endpoint returns partials).
    if "HX-Request" not in request.headers:
        query_string = str(request.url.query)
        redirect_url = f"/?{query_string}" if query_string else "/"
        return RedirectResponse(url=redirect_url, status_code=302)

    if search and search_type == "semantic":
        # Vector search: results arrive ordered by similarity.
        filtered_df = semantic_search(
            query=search,
            k=per_page * 5,  # candidate pool; pagination ends after 5 pages
            category=category,
            min_confidence=min_confidence,
            since=since,
        )
        effective_sort = sort if sort == "date" else "relevance"
        page_df, has_more = paginate_papers(
            filtered_df, page=page, per_page=per_page, sort=effective_sort
        )
    else:
        # Keyword search path.
        df = get_dataframe()
        filtered_df = filter_papers(
            df,
            category=category,
            search=search,
            min_confidence=min_confidence,
            since=since,
        )
        # Keyword search always sorts by date.
        page_df, has_more = paginate_papers(
            filtered_df, page=page, per_page=per_page, sort="date"
        )

    papers = page_df.to_dicts()

    # Clean URL for browser history (/ instead of /papers), non-default values only.
    params = {}
    if search:
        params["search"] = search
    if search_type != "keyword":
        params["search_type"] = search_type
    if category:
        params["category"] = category
    if min_confidence != 0.5:
        # Echo the value exactly as the client sent it. Re-serialising the parsed
        # float turns "0" into "0.0", while home() hands this string to the
        # template unchanged (option values are strings like "0"), which
        # presumably breaks option matching after a reload.
        params["min_confidence"] = request.query_params.get(
            "min_confidence", str(min_confidence)
        )
    if since:
        params["since"] = since
    if sort != "date":
        params["sort"] = sort
    push_url = "/?" + urlencode(params) if params else "/"

    response = templates.TemplateResponse(
        "partials/paper_list.html",
        {
            "request": request,
            "papers": papers,
            "page": page,
            "has_more": has_more,
            "category": category or "",
            "search": search or "",
            "min_confidence": min_confidence,
            "search_type": search_type,
            "sort": sort,
            "since": since or "",
            "total_filtered": len(filtered_df),
        },
    )
    # Tell HTMX to push the clean URL (/ not /papers).
    response.headers["HX-Push-Url"] = push_url
    return response
@app.get("/api/stats")
async def get_stats():
    """Return dataset statistics as JSON.

    The result contains:
    - the total paper count;
    - the number of papers flagged ``is_new_dataset``;
    - the mean confidence score over all papers;
    - the min/max ``update_date`` (as strings).
    """
    papers = get_dataframe()
    update_dates = papers["update_date"]
    return {
        "total_papers": papers.height,
        "new_dataset_count": papers.filter(pl.col("is_new_dataset")).height,
        "avg_confidence": float(papers["confidence_score"].mean()),
        "date_range": {
            "min": str(update_dates.min()),
            "max": str(update_dates.max()),
        },
    }
# Preload dataset and model on startup
@app.on_event("startup")
async def startup_event():
    """Warm the caches at startup: first the DataFrame, then the embedding model.

    This way the first request does not pay the download/load cost.
    """
    steps = (
        ("Preloading dataset...", get_dataframe, "Dataset loaded!"),
        ("Preloading embedding model...", get_embedding_model, "Embedding model loaded!"),
    )
    for before_msg, load, after_msg in steps:
        print(before_msg)
        load()
        print(after_msg)
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 7860.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=7860, host="0.0.0.0")