"""
FastAPI + HTMX app for browsing arxiv papers with new ML datasets.
Downloads Lance dataset from HuggingFace Hub and loads locally.
"""
import math
import re
from datetime import date, timedelta
from functools import lru_cache
from typing import Optional
from urllib.parse import urlencode
import lance
import polars as pl
from cachetools import TTLCache
from dotenv import load_dotenv
from fastapi import FastAPI, Query, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import snapshot_download
from markupsafe import Markup
# Load .env file for local development (HF_TOKEN)
load_dotenv()

# FastAPI application: static assets under /static, Jinja2 templates for HTML.
app = FastAPI(title="ArXiv New ML Datasets")
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
def highlight_search(text: str, search: str) -> Markup:
    """Highlight occurrences of *search* in *text* with a yellow <mark> span.

    The text is HTML-escaped before highlighting so paper content cannot
    inject markup; the only trusted HTML is the <mark> wrapper added here.

    Args:
        text: Raw (unescaped) text to render.
        search: Case-insensitive term to highlight; falsy disables highlighting.

    Returns:
        Markup safe for direct template interpolation.
    """
    import html

    if not search or not text:
        # Escape even when not highlighting so both paths render consistently
        # and raw HTML in the source text is never marked safe.
        return Markup(html.escape(str(text))) if text else Markup("")
    # Escape HTML in the text first so the regex runs over the exact
    # characters that will be rendered.
    escaped_text = html.escape(str(text))
    # Escape the search term the same way: the text now contains entities
    # (e.g. "&" -> "&amp;"), so a raw "&" pattern would match *inside* an
    # entity and the <mark> tags would split it, corrupting the output.
    escaped_search = html.escape(search)
    pattern = re.compile(re.escape(escaped_search), re.IGNORECASE)
    highlighted = pattern.sub(
        lambda m: f'<mark class="bg-yellow-200 px-0.5 rounded">{m.group()}</mark>',
        escaped_text,
    )
    return Markup(highlighted)
# Register custom Jinja2 filters (used in templates as {{ value|highlight(term) }})
templates.env.filters["highlight"] = highlight_search
def confidence_fmt(score):
    """Render a 0-1 confidence score as a percentage string with one decimal.

    Truncates rather than rounds, so values like 0.9995 display as "99.9"
    instead of rounding up to a misleading "100.0".
    """
    tenths_of_percent = math.floor(score * 1000)
    return f"{tenths_of_percent / 10:.1f}"
templates.env.filters["confidence"] = confidence_fmt

# Dataset config: HF Hub repo holding the Lance-format arxiv dump.
DATASET_REPO = "librarian-bots/arxiv-cs-papers-lance"
# Cache for the Polars DataFrame (expires every 6 hours to pick up new data)
_dataset_cache: TTLCache = TTLCache(maxsize=1, ttl=60 * 60 * 6)
# Cache for Lance dataset connection (for vector search); never expires
_lance_cache: dict = {}
# Cache for embedding model (lazy loaded on first semantic search)
_model_cache: dict = {}
def get_lance_dataset():
    """Return the Lance dataset handle, downloading from HF Hub on first use.

    The handle is memoised in ``_lance_cache`` so the snapshot download and
    the dataset open happen at most once per process.
    """
    if "ds" in _lance_cache:
        return _lance_cache["ds"]

    import os

    # Use HF_HOME or /tmp for Spaces compatibility (./data not writable on Spaces)
    cache_root = os.environ.get("HF_HOME", "/tmp/hf_cache")
    snapshot_dir = f"{cache_root}/arxiv-lance"
    print(f"Downloading dataset from {DATASET_REPO} to {snapshot_dir}...")
    snapshot_download(
        DATASET_REPO,
        repo_type="dataset",
        local_dir=snapshot_dir,
    )
    dataset_path = f"{snapshot_dir}/data/train.lance"
    print(f"Loading Lance dataset from {dataset_path}")
    _lance_cache["ds"] = lance.dataset(dataset_path)
    return _lance_cache["ds"]
def get_embedding_model():
    """Return the shared SentenceTransformer, loading it lazily on first call."""
    try:
        return _model_cache["model"]
    except KeyError:
        from sentence_transformers import SentenceTransformer

        print("Loading embedding model...")
        model = SentenceTransformer("BAAI/bge-base-en-v1.5")
        print("Embedding model loaded!")
        _model_cache["model"] = model
        return model
def get_dataframe() -> pl.DataFrame:
    """Return the papers table as a Polars DataFrame (TTL-cached for 6h)."""
    try:
        return _dataset_cache["df"]
    except KeyError:
        pass  # expired or never loaded — rebuild below
    ds = get_lance_dataset()  # Downloads from HF Hub if not cached
    # Only the columns needed for filtering/display; the embedding column
    # stays on disk to keep memory usage down.
    wanted = [
        "id", "title", "abstract", "categories", "update_date",
        "authors", "is_new_dataset", "confidence_score",
    ]
    df = pl.from_arrow(ds.to_table(columns=wanted))
    _dataset_cache["df"] = df
    print(f"Loaded {len(df):,} papers")
    return df
@lru_cache(maxsize=1)
def get_categories() -> list[str]:
    """Return the ML-related category codes present in the data, in fixed order."""
    df = get_dataframe()
    # Primary category is the first token of the space-separated categories field.
    primary = (
        df.select(pl.col("categories").str.split(" ").list.first().alias("cat"))
        .unique()
        .sort("cat")
        .to_series()
        .to_list()
    )
    present = set(primary)
    # Restrict to the common ML-related categories, preserving display order.
    ml_cats = ["cs.AI", "cs.CL", "cs.CV", "cs.LG", "cs.NE", "cs.IR", "cs.RO", "stat.ML"]
    return [c for c in ml_cats if c in present]
@lru_cache(maxsize=1)
def get_confidence_options() -> list[dict]:
    """Compute confidence filter options from the actual data distribution.

    Uses percentiles of the new-dataset scores so the UI adapts to any
    model's score range.

    Returns:
        List of dicts with "value" (string echoed back as min_confidence),
        "label" and "count" keys, ordered from the "All new datasets"
        default down to "All papers".
    """
    df = get_dataframe()
    scores = df.filter(pl.col("is_new_dataset"))["confidence_score"]
    options = [{"value": "0.5", "label": "All new datasets", "count": len(scores)}]
    for pct_label, quantile in [("Top 75%", 0.25), ("Top 50%", 0.50), ("Top 25%", 0.75)]:
        # Round to 2 decimals *before* counting: the option "value" sent back
        # to /papers is the rounded string, so counting against the raw
        # quantile could make the displayed count disagree with the number
        # of papers the filter actually returns.
        threshold = round(float(scores.quantile(quantile)), 2)
        count = scores.filter(scores >= threshold).len()
        options.append({
            "value": f"{threshold:.2f}",
            "label": pct_label,
            "count": int(count),
        })
    options.append({"value": "0", "label": "All papers", "count": len(df)})
    return options
@lru_cache(maxsize=1)
def get_histogram_data() -> dict:
    """Get confidence distribution data for histogram display.

    Dynamically determines the range from actual data distribution.

    Returns:
        Dict with "bins" (list of per-bin dicts: bounds, counts, normalized
        display heights, cumulative papers_above), "min_pct"/"max_pct"
        (display range as 0-1 floats), "total_papers" and
        "new_dataset_count". The 50% line marks the prediction boundary.
    """
    df = get_dataframe()
    # Get all papers with confidence scores
    all_papers = df.select("confidence_score", "is_new_dataset")
    # Dynamically determine the range from actual data
    # Round to nearest 5% for clean boundaries
    actual_min = float(all_papers["confidence_score"].min())
    actual_max = float(all_papers["confidence_score"].max())
    # Round down to nearest 5% for min, round up for max
    min_pct = max(0, (int(actual_min * 20) / 20))  # Floor to 5%
    max_pct = min(1, ((int(actual_max * 20) + 1) / 20))  # Ceil to 5%
    # Ensure minimum range of 25% for usability
    if max_pct - min_pct < 0.25:
        center = (min_pct + max_pct) / 2
        min_pct = max(0, center - 0.125)
        max_pct = min(1, center + 0.125)
    # Use 25 bins for good granularity
    num_bins = 25
    bin_width = (max_pct - min_pct) / num_bins
    bins = []
    for i in range(num_bins):
        bin_start = min_pct + i * bin_width
        bin_end = min_pct + (i + 1) * bin_width
        # Count papers in this bin.
        # NOTE(review): bins are half-open [start, end), so a paper whose
        # score equals max_pct exactly falls outside the last bin — only
        # possible when the max score lands exactly on a 5% boundary.
        count = all_papers.filter(
            (pl.col("confidence_score") >= bin_start) &
            (pl.col("confidence_score") < bin_end)
        ).height
        # Count new_dataset papers in this bin (subset of count above)
        new_dataset_count = all_papers.filter(
            (pl.col("confidence_score") >= bin_start) &
            (pl.col("confidence_score") < bin_end) &
            (pl.col("is_new_dataset"))
        ).height
        bins.append({
            "bin_start": round(bin_start, 3),
            "bin_end": round(bin_end, 3),
            "bin_pct": int(bin_start * 100),
            "count": count,
            "new_dataset_count": new_dataset_count,
        })
    # Normalize counts for display (max height = 100%)
    max_count = max(b["count"] for b in bins) if bins else 1
    for b in bins:
        b["height_pct"] = int((b["count"] / max_count) * 100) if max_count > 0 else 0
        b["new_height_pct"] = int((b["new_dataset_count"] / max_count) * 100) if max_count > 0 else 0
    # Calculate cumulative counts from each threshold
    # (how many papers are at or above this threshold)
    total_so_far = all_papers.height
    for b in bins:
        b["papers_above"] = total_so_far
        total_so_far -= b["count"]
    return {
        "bins": bins,
        "min_pct": round(min_pct, 2),
        "max_pct": round(max_pct, 2),
        "total_papers": all_papers.height,
        "new_dataset_count": all_papers.filter(pl.col("is_new_dataset")).height,
    }
def parse_since(since: str) -> Optional[date]:
    """Translate a 'since' shortcut ("1m", "6m", "1y") into a cutoff date.

    Returns None for empty or unrecognised values, meaning "all time".
    """
    offsets = {"1m": 30, "6m": 180, "1y": 365}
    days = offsets.get(since or "")
    if days is None:
        return None
    return date.today() - timedelta(days=days)
def filter_papers(
    df: pl.DataFrame,
    category: Optional[str] = None,
    search: Optional[str] = None,
    min_confidence: float = 0.5,
    since: Optional[str] = None,
) -> pl.DataFrame:
    """Apply filters to the papers dataframe.

    The confidence threshold controls which papers are shown:
    - Papers with is_new_dataset=True have confidence >= 0.5
    - Setting threshold to 0 shows all papers
    - Setting threshold >= 0.5 effectively shows only new_dataset papers

    Args:
        df: Full papers table.
        category: Substring match against the space-separated categories field.
        search: Case-insensitive substring match on title or abstract.
        min_confidence: Minimum confidence_score (semantics above).
        since: "1m"/"6m"/"1y" shortcut, or None/"" for all time.

    Returns:
        Filtered DataFrame with the same schema as the input.
    """
    if min_confidence >= 0.5:
        # Show only papers classified as new datasets, filtered by confidence
        df = df.filter(
            pl.col("is_new_dataset") & (pl.col("confidence_score") >= min_confidence)
        )
    elif min_confidence > 0:
        df = df.filter(pl.col("confidence_score") >= min_confidence)
    if category:
        # literal=True: plain substring semantics. Polars str.contains treats
        # the pattern as a regex by default, so "." in "cs.AI" would be a
        # wildcard matching any character.
        df = df.filter(pl.col("categories").str.contains(category, literal=True))
    if search:
        # literal=True also prevents a runtime error on user queries that
        # contain regex metacharacters (e.g. "c++" or "a(b").
        search_lower = search.lower()
        df = df.filter(
            pl.col("title").str.to_lowercase().str.contains(search_lower, literal=True)
            | pl.col("abstract").str.to_lowercase().str.contains(search_lower, literal=True)
        )
    # Date filter
    min_date = parse_since(since)
    if min_date:
        df = df.filter(pl.col("update_date") >= min_date)
    return df
def paginate_papers(
    df: pl.DataFrame,
    page: int = 1,
    per_page: int = 20,
    sort: str = "date",
) -> tuple[pl.DataFrame, bool]:
    """Sort and paginate papers.

    Sort options:
    - "date": update_date desc, ties broken by confidence_score desc.
    - "relevance": keep the incoming row order (semantic search already
      sorts by similarity).

    Returns:
        (page_df, has_more): the requested page and whether at least one
        more page exists after it.
    """
    ordered = (
        df.sort(["update_date", "confidence_score"], descending=[True, True])
        if sort == "date"
        else df  # "relevance": rows arrive pre-sorted by similarity
    )
    offset = (page - 1) * per_page
    # Fetch one extra row purely to detect whether a further page exists.
    window = ordered.slice(offset, per_page + 1)
    return window.head(per_page), len(window) > per_page
def semantic_search(
    query: str,
    k: int = 100,
    category: Optional[str] = None,
    min_confidence: float = 0.5,
    since: Optional[str] = None,
) -> pl.DataFrame:
    """Search using vector similarity via Lance nearest neighbor.

    Embeds the query with the shared BGE model, then runs a filtered
    nearest-neighbour scan against the "embedding" column.

    Args:
        query: Free-text query to embed.
        k: Maximum number of neighbours to fetch (pagination buffer).
        category: Optional substring filter on the categories field.
        min_confidence: Same semantics as filter_papers().
        since: "1m"/"6m"/"1y" shortcut, or None for all time.

    Returns:
        DataFrame with similarity_score column (0-1, higher is more similar),
        ordered by the nearest-neighbour scan.
    """
    model = get_embedding_model()
    query_embedding = model.encode(query).tolist()
    ds = get_lance_dataset()
    # Build SQL filter (Lance supports SQL-like syntax)
    filters = []
    if min_confidence >= 0.5:
        filters.append("is_new_dataset = true")
        filters.append(f"confidence_score >= {min_confidence}")
    elif min_confidence > 0:
        filters.append(f"confidence_score >= {min_confidence}")
    if category:
        # Escape single quotes in category name for SQL safety
        safe_category = category.replace("'", "''")
        filters.append(f"categories LIKE '%{safe_category}%'")
    # Date filter - use TIMESTAMP literal for Lance/DataFusion
    min_date = parse_since(since)
    if min_date:
        filters.append(f"update_date >= TIMESTAMP '{min_date.isoformat()} 00:00:00'")
    filter_str = " AND ".join(filters) if filters else None
    # Vector search - include _distance for similarity calculation
    results = ds.scanner(
        nearest={"column": "embedding", "q": query_embedding, "k": k},
        filter=filter_str,
        columns=["id", "title", "abstract", "categories", "update_date",
                 "authors", "confidence_score", "_distance"]
    ).to_table()
    df = pl.from_arrow(results)
    # Convert L2 distance to similarity score (0-1 range)
    # For normalized embeddings: similarity = 1 - distance/2
    # BGE embeddings are normalized, so L2 distance ranges from 0 to 2
    # NOTE(review): 1 - d/2 equals cosine similarity only when _distance is
    # *squared* L2 on unit vectors (d = 2 - 2*cos). If Lance returns plain
    # L2 the mapping is nonlinear but still monotonic, so ranking is
    # unaffected — confirm which metric the index reports.
    df = df.with_columns(
        (1 - pl.col("_distance") / 2).clip(0, 1).alias("similarity_score")
    ).drop("_distance")
    return df
@app.get("/", response_class=HTMLResponse)
async def home(
    request: Request,
    search: Optional[str] = Query(None),
    search_type: str = Query("keyword"),
    category: Optional[str] = Query(None),
    min_confidence: str = Query("0.5"),  # String to preserve exact value for template
    since: Optional[str] = Query(None),
    sort: str = Query("date"),
):
    """Render the home page with optional initial filter state from URL."""
    df = get_dataframe()
    context = {
        "request": request,
        # Dataset-wide stats and filter widgets for the header.
        "categories": get_categories(),
        "total_papers": len(df),
        "new_dataset_count": df.filter(pl.col("is_new_dataset")).height,
        "histogram_data": get_histogram_data(),
        "confidence_options": get_confidence_options(),
        # Pass filter state through so a shared URL restores the same view.
        "search": search or "",
        "search_type": search_type,
        "category": category or "",
        "min_confidence": min_confidence,
        "since": since or "",
        "sort": sort,
    }
    return templates.TemplateResponse("index.html", context)
@app.get("/papers", response_class=HTMLResponse)
async def get_papers(
    request: Request,
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1, le=100),
    category: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    min_confidence: float = Query(0.5, ge=0, le=1),
    search_type: str = Query("keyword"),  # "keyword" or "semantic"
    sort: str = Query("date"),  # "date" or "relevance"
    since: Optional[str] = Query(None),  # "1m", "6m", "1y", or None for all
):
    """Get paginated and filtered papers (returns HTML partial for HTMX).

    If accessed directly (not via HTMX), redirects to home page with same
    params so the full page renders instead of a bare partial.
    """
    # Redirect direct browser visits to home page (this endpoint returns partials)
    if "HX-Request" not in request.headers:
        # Build redirect URL with current query params
        query_string = str(request.url.query)
        redirect_url = f"/?{query_string}" if query_string else "/"
        return RedirectResponse(url=redirect_url, status_code=302)
    if search and search_type == "semantic":
        # Vector search - returns pre-sorted by similarity
        filtered_df = semantic_search(
            query=search,
            k=per_page * 5,  # Get more for pagination buffer
            category=category,
            min_confidence=min_confidence,
            since=since,
        )
        # Default to relevance sort for semantic, but allow date sort
        effective_sort = sort if sort == "date" else "relevance"
        # NOTE(review): with k = per_page * 5 the semantic path caps out at
        # 5 pages; later pages come back empty by design of the buffer.
        page_df, has_more = paginate_papers(
            filtered_df, page=page, per_page=per_page, sort=effective_sort
        )
    else:
        # Existing keyword search path
        df = get_dataframe()
        filtered_df = filter_papers(
            df,
            category=category,
            search=search,
            min_confidence=min_confidence,
            since=since,
        )
        # Keyword search always sorts by date
        page_df, has_more = paginate_papers(
            filtered_df, page=page, per_page=per_page, sort="date"
        )
    # Convert to list of dicts for template
    papers = page_df.to_dicts()
    # Build clean URL for browser history (/ instead of /papers)
    # Only include non-default values to keep URLs short
    params = {}
    if search:
        params["search"] = search
    if search_type != "keyword":
        params["search_type"] = search_type
    if category:
        params["category"] = category
    if min_confidence != 0.5:
        params["min_confidence"] = min_confidence
    if since:
        params["since"] = since
    if sort != "date":
        params["sort"] = sort
    push_url = "/?" + urlencode(params) if params else "/"
    # Render the partial; filter state is echoed back so the HTMX controls
    # on the page stay in sync with what was actually applied.
    response = templates.TemplateResponse(
        "partials/paper_list.html",
        {
            "request": request,
            "papers": papers,
            "page": page,
            "has_more": has_more,
            "category": category or "",
            "search": search or "",
            "min_confidence": min_confidence,
            "search_type": search_type,
            "sort": sort,
            "since": since or "",
            "total_filtered": len(filtered_df),
        },
    )
    # Tell HTMX to push clean URL (/ not /papers)
    response.headers["HX-Push-Url"] = push_url
    return response
@app.get("/api/stats")
async def get_stats():
    """Get dataset statistics as JSON."""
    df = get_dataframe()
    new_dataset_df = df.filter(pl.col("is_new_dataset"))
    dates = df["update_date"]
    return {
        "total_papers": len(df),
        "new_dataset_count": len(new_dataset_df),
        "avg_confidence": float(df["confidence_score"].mean()),
        "date_range": {
            "min": str(dates.min()),
            "max": str(dates.max()),
        },
    }
# Preload dataset and model on startup
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favour of
# lifespan handlers; works today, worth migrating eventually.
@app.on_event("startup")
async def startup_event():
    """Preload dataset and embedding model on startup.

    Avoids a very slow first request: both the HF snapshot download and the
    SentenceTransformer load happen here instead of lazily on first use.
    """
    print("Preloading dataset...")
    get_dataframe()
    print("Dataset loaded!")
    print("Preloading embedding model...")
    get_embedding_model()
    print("Embedding model loaded!")
if __name__ == "__main__":
    # Local dev entry point; port 7860 is the Hugging Face Spaces default.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)