"""
FastAPI + HTMX app for browsing arXiv papers that introduce new ML datasets.
Downloads a Lance dataset from the HuggingFace Hub and queries it locally.
"""

import html
import math
import os
import re
from datetime import date, timedelta
from functools import lru_cache
from typing import Optional
from urllib.parse import urlencode

import lance
import polars as pl
from cachetools import TTLCache
from dotenv import load_dotenv
from fastapi import FastAPI, Query, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import snapshot_download
from markupsafe import Markup

# Load .env file for local development (HF_TOKEN)
load_dotenv()

app = FastAPI(title="ArXiv New ML Datasets")
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")


def highlight_search(text: str, search: str) -> Markup:
    """Highlight search terms in text with a yellow background."""
    if not search or not text:
        return Markup(text) if text else Markup("")

    # Escape HTML in the text first so paper content can't inject markup
    text = html.escape(str(text))

    # Escape the search term the same way, otherwise a query containing
    # &, <, >, or quotes would never match the escaped text; re.escape then
    # neutralizes any regex metacharacters in the term
    pattern = re.compile(re.escape(html.escape(search)), re.IGNORECASE)
    highlighted = pattern.sub(
        lambda m: f'<mark class="bg-yellow-200 px-0.5 rounded">{m.group()}</mark>',
        text,
    )
    return Markup(highlighted)

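# For example, highlight_search("Deep Learning", "learn") wraps "Learn" in a
# <mark> tag: matching is case-insensitive, and because both the text and the
# query are HTML-escaped first, a search like "Q&A" still lines up with the
# escaped text ("Q&amp;A").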

# Register custom filters
templates.env.filters["highlight"] = highlight_search


def confidence_fmt(score):
    """Format confidence as percentage, truncating to 1 decimal to avoid rounding 99.95->100."""
    pct = math.floor(score * 1000) / 10
    return f"{pct:.1f}"

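# Worked example: confidence_fmt(0.99999) returns "99.9", whereas plain
# rounding (f"{0.99999 * 100:.1f}") would display "100.0".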

templates.env.filters["confidence"] = confidence_fmt

# Dataset config
DATASET_REPO = "librarian-bots/arxiv-cs-papers-lance"

# Cache for dataset (reload every 6 hours)
_dataset_cache: TTLCache = TTLCache(maxsize=1, ttl=60 * 60 * 6)

# Cache for Lance dataset connection (for vector search)
_lance_cache: dict = {}

# Cache for embedding model (lazy loaded on first semantic search)
_model_cache: dict = {}


def get_lance_dataset():
    """Download dataset from HF Hub (cached) and return Lance connection."""
    if "ds" not in _lance_cache:
        # Use HF_HOME or /tmp for Spaces compatibility (./data not writable on Spaces)
        cache_base = os.environ.get("HF_HOME", "/tmp/hf_cache")
        local_dir = f"{cache_base}/arxiv-lance"
        print(f"Downloading dataset from {DATASET_REPO} to {local_dir}...")
        snapshot_download(
            DATASET_REPO,
            repo_type="dataset",
            local_dir=local_dir,
        )
        lance_path = f"{local_dir}/data/train.lance"
        print(f"Loading Lance dataset from {lance_path}")
        _lance_cache["ds"] = lance.dataset(lance_path)
    return _lance_cache["ds"]


def get_embedding_model():
    """Load embedding model (cached, lazy-loaded on first semantic search)."""
    if "model" not in _model_cache:
        from sentence_transformers import SentenceTransformer
        print("Loading embedding model...")
        _model_cache["model"] = SentenceTransformer("BAAI/bge-base-en-v1.5")
        print("Embedding model loaded!")
    return _model_cache["model"]


def get_dataframe() -> pl.DataFrame:
    """Load Lance dataset and convert to Polars DataFrame."""
    cache_key = "df"
    if cache_key in _dataset_cache:
        return _dataset_cache[cache_key]

    ds = get_lance_dataset()  # Downloads from HF Hub if not cached
    # Select columns needed for filtering/display (exclude embeddings for memory)
    columns = [
        "id", "title", "abstract", "categories", "update_date",
        "authors", "is_new_dataset", "confidence_score"
    ]
    arrow_table = ds.to_table(columns=columns)
    df = pl.from_arrow(arrow_table)
    _dataset_cache[cache_key] = df
    print(f"Loaded {len(df):,} papers")
    return df


@lru_cache(maxsize=1)
def get_categories() -> list[str]:
    """Get unique category prefixes for filtering."""
    df = get_dataframe()
    # Extract primary category (before first space or as-is)
    categories = (
        df.select(pl.col("categories").str.split(" ").list.first().alias("cat"))
        .unique()
        .sort("cat")
        .to_series()
        .to_list()
    )
    # Get common ML-related categories
    ml_cats = ["cs.AI", "cs.CL", "cs.CV", "cs.LG", "cs.NE", "cs.IR", "cs.RO", "stat.ML"]
    return [c for c in ml_cats if c in categories]


@lru_cache(maxsize=1)
def get_confidence_options() -> list[dict]:
    """Compute confidence filter options from actual data distribution.

    Uses percentiles so the UI adapts to any model's score range.
    """
    df = get_dataframe()
    scores = df.filter(pl.col("is_new_dataset"))["confidence_score"]

    options = [{"value": "0.5", "label": "All new datasets", "count": len(scores)}]

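    # "Top X%" keeps papers at or above the (100 - X)th percentile of the
    # confidence scores among predicted new-dataset papers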
    for pct_label, quantile in [("Top 75%", 0.25), ("Top 50%", 0.50), ("Top 25%", 0.75)]:
        threshold = float(scores.quantile(quantile))
        count = scores.filter(scores >= threshold).len()
        options.append({
            "value": f"{threshold:.2f}",
            "label": pct_label,
            "count": int(count),
        })

    options.append({"value": "0", "label": "All papers", "count": len(df)})
    return options


@lru_cache(maxsize=1)
def get_histogram_data() -> dict:
    """Get confidence distribution data for histogram display.

    Dynamically determines the range from actual data distribution.
    Returns dict with bins and metadata. The 50% line marks the prediction boundary.
    """
    df = get_dataframe()

    # Get all papers with confidence scores
    all_papers = df.select("confidence_score", "is_new_dataset")

    # Dynamically determine the range from actual data
    # Round to nearest 5% for clean boundaries
    actual_min = float(all_papers["confidence_score"].min())
    actual_max = float(all_papers["confidence_score"].max())

    # Round down to nearest 5% for min, round up for max
    min_pct = max(0, (int(actual_min * 20) / 20))  # Floor to 5%
    max_pct = min(1, ((int(actual_max * 20) + 1) / 20))  # Ceil to 5%

    # Ensure minimum range of 25% for usability
    if max_pct - min_pct < 0.25:
        center = (min_pct + max_pct) / 2
        min_pct = max(0, center - 0.125)
        max_pct = min(1, center + 0.125)

    # Use 25 bins for good granularity
    num_bins = 25
    bin_width = (max_pct - min_pct) / num_bins

    bins = []
    for i in range(num_bins):
        bin_start = min_pct + i * bin_width
        bin_end = min_pct + (i + 1) * bin_width

        # Half-open bins [start, end); the last bin also includes its upper
        # edge so a score exactly at max_pct isn't dropped
        upper = (
            pl.col("confidence_score") <= bin_end
            if i == num_bins - 1
            else pl.col("confidence_score") < bin_end
        )
        in_bin = (pl.col("confidence_score") >= bin_start) & upper

        # Count all papers in this bin
        count = all_papers.filter(in_bin).height

        # Count new_dataset papers in this bin
        new_dataset_count = all_papers.filter(
            in_bin & pl.col("is_new_dataset")
        ).height

        bins.append({
            "bin_start": round(bin_start, 3),
            "bin_end": round(bin_end, 3),
            "bin_pct": int(bin_start * 100),
            "count": count,
            "new_dataset_count": new_dataset_count,
        })

    # Normalize counts for display (max height = 100%)
    max_count = max(b["count"] for b in bins) if bins else 1
    for b in bins:
        b["height_pct"] = int((b["count"] / max_count) * 100) if max_count > 0 else 0
        b["new_height_pct"] = int((b["new_dataset_count"] / max_count) * 100) if max_count > 0 else 0

    # Calculate cumulative counts from each threshold
    # (how many papers are at or above this threshold)
    total_so_far = all_papers.height
    for b in bins:
        b["papers_above"] = total_so_far
        total_so_far -= b["count"]

    return {
        "bins": bins,
        "min_pct": round(min_pct, 2),
        "max_pct": round(max_pct, 2),
        "total_papers": all_papers.height,
        "new_dataset_count": all_papers.filter(pl.col("is_new_dataset")).height,
    }


def parse_since(since: str) -> Optional[date]:
    """Parse 'since' parameter to a date. Returns None for 'all time'."""
    if not since:
        return None
    today = date.today()
    if since == "1m":
        return today - timedelta(days=30)
    elif since == "6m":
        return today - timedelta(days=180)
    elif since == "1y":
        return today - timedelta(days=365)
    return None

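# e.g. parse_since("6m") -> today minus 180 days; any other value (including
# unrecognized strings) falls back to None, meaning "all time"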

def filter_papers(
    df: pl.DataFrame,
    category: Optional[str] = None,
    search: Optional[str] = None,
    min_confidence: float = 0.5,
    since: Optional[str] = None,
) -> pl.DataFrame:
    """Apply filters to the papers dataframe.

    The confidence threshold controls which papers are shown:
    - Papers with is_new_dataset=True have confidence >= 0.5
    - Setting threshold to 0 shows all papers
    - Setting threshold >= 0.5 effectively shows only new_dataset papers
    """
    if min_confidence >= 0.5:
        # Show only papers classified as new datasets, filtered by confidence
        df = df.filter(
            pl.col("is_new_dataset") & (pl.col("confidence_score") >= min_confidence)
        )
    elif min_confidence > 0:
        df = df.filter(pl.col("confidence_score") >= min_confidence)

    if category:
        # literal=True: the "." in category names like "cs.AI" would otherwise
        # be treated as a regex wildcard
        df = df.filter(pl.col("categories").str.contains(category, literal=True))

    if search:
        search_lower = search.lower()
        # literal=True so queries containing regex metacharacters (e.g. "c++")
        # don't raise or match unexpectedly
        df = df.filter(
            pl.col("title").str.to_lowercase().str.contains(search_lower, literal=True)
            | pl.col("abstract").str.to_lowercase().str.contains(search_lower, literal=True)
        )

    # Date filter
    min_date = parse_since(since)
    if min_date:
        df = df.filter(pl.col("update_date") >= min_date)

    return df


def paginate_papers(
    df: pl.DataFrame,
    page: int = 1,
    per_page: int = 20,
    sort: str = "date",
) -> tuple[pl.DataFrame, bool]:
    """Sort and paginate papers, return (page_df, has_more).

    Sort options:
    - "date": By update_date desc, then confidence_score desc
    - "relevance": Keep existing order (for semantic search similarity)
    """
    if sort == "date":
        df_sorted = df.sort(
            ["update_date", "confidence_score"], descending=[True, True]
        )
    else:
        # "relevance" - keep existing order (already sorted by similarity for semantic)
        df_sorted = df

    start = (page - 1) * per_page
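    # Fetch one extra row beyond the page so we can tell whether a further
    # page exists without counting the whole frame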
    page_df = df_sorted.slice(start, per_page + 1)
    has_more = len(page_df) > per_page

    return page_df.head(per_page), has_more


def semantic_search(
    query: str,
    k: int = 100,
    category: Optional[str] = None,
    min_confidence: float = 0.5,
    since: Optional[str] = None,
) -> pl.DataFrame:
    """Search using vector similarity via Lance nearest neighbor.

    Returns DataFrame with similarity_score column (0-1, higher is more similar).
    """
    model = get_embedding_model()
    query_embedding = model.encode(query).tolist()

    ds = get_lance_dataset()

    # Build SQL filter (Lance supports SQL-like syntax)
    filters = []
    if min_confidence >= 0.5:
        filters.append("is_new_dataset = true")
        filters.append(f"confidence_score >= {min_confidence}")
    elif min_confidence > 0:
        filters.append(f"confidence_score >= {min_confidence}")
    if category:
        # Escape single quotes in category name for SQL safety
        safe_category = category.replace("'", "''")
        filters.append(f"categories LIKE '%{safe_category}%'")
    # Date filter - use TIMESTAMP literal for Lance/DataFusion
    min_date = parse_since(since)
    if min_date:
        filters.append(f"update_date >= TIMESTAMP '{min_date.isoformat()} 00:00:00'")
    filter_str = " AND ".join(filters) if filters else None

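    # Note: by default Lance applies the filter to the ANN candidates after
    # the vector search, so heavily filtered queries can return fewer than k
    # rows (prefilter=True would filter first, if the installed Lance version
    # supports it)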
    # Vector search - include _distance for similarity calculation
    results = ds.scanner(
        nearest={"column": "embedding", "q": query_embedding, "k": k},
        filter=filter_str,
        columns=["id", "title", "abstract", "categories", "update_date",
                 "authors", "confidence_score", "_distance"]
    ).to_table()

    df = pl.from_arrow(results)

    # Convert L2 distance to similarity score (0-1 range)
    # For normalized embeddings: similarity = 1 - distance/2
    # BGE embeddings are normalized, so L2 distance ranges from 0 to 2
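    # e.g. _distance 0.0 -> similarity 1.0, 1.0 -> 0.5, 2.0 -> 0.0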
    df = df.with_columns(
        (1 - pl.col("_distance") / 2).clip(0, 1).alias("similarity_score")
    ).drop("_distance")

    return df


@app.get("/", response_class=HTMLResponse)
async def home(
    request: Request,
    search: Optional[str] = Query(None),
    search_type: str = Query("keyword"),
    category: Optional[str] = Query(None),
    min_confidence: str = Query("0.5"),  # String to preserve exact value for template
    since: Optional[str] = Query(None),
    sort: str = Query("date"),
):
    """Render the home page with optional initial filter state from URL."""
    df = get_dataframe()
    categories = get_categories()
    histogram_data = get_histogram_data()
    confidence_options = get_confidence_options()

    # Get stats
    total_papers = len(df)
    new_dataset_count = df.filter(pl.col("is_new_dataset")).height

    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "categories": categories,
            "total_papers": total_papers,
            "new_dataset_count": new_dataset_count,
            "histogram_data": histogram_data,
            "confidence_options": confidence_options,
            # Pass filter state for URL persistence
            "search": search or "",
            "search_type": search_type,
            "category": category or "",
            "min_confidence": min_confidence,
            "since": since or "",
            "sort": sort,
        },
    )


@app.get("/papers", response_class=HTMLResponse)
async def get_papers(
    request: Request,
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1, le=100),
    category: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    min_confidence: float = Query(0.5, ge=0, le=1),
    search_type: str = Query("keyword"),  # "keyword" or "semantic"
    sort: str = Query("date"),  # "date" or "relevance"
    since: Optional[str] = Query(None),  # "1m", "6m", "1y", or None for all
):
    """Get paginated and filtered papers (returns HTML partial for HTMX).

    If accessed directly (not via HTMX), redirects to home page with same params.
    """
    # Redirect direct browser visits to home page (this endpoint returns partials)
    if "HX-Request" not in request.headers:
        # Build redirect URL with current query params
        query_string = str(request.url.query)
        redirect_url = f"/?{query_string}" if query_string else "/"
        return RedirectResponse(url=redirect_url, status_code=302)

    if search and search_type == "semantic":
        # Vector search - returns pre-sorted by similarity
        filtered_df = semantic_search(
            query=search,
            k=per_page * 5,  # Get more for pagination buffer
            category=category,
            min_confidence=min_confidence,
            since=since,
        )
        # Default to relevance sort for semantic, but allow date sort
        effective_sort = sort if sort == "date" else "relevance"
        page_df, has_more = paginate_papers(
            filtered_df, page=page, per_page=per_page, sort=effective_sort
        )
    else:
        # Existing keyword search path
        df = get_dataframe()
        filtered_df = filter_papers(
            df,
            category=category,
            search=search,
            min_confidence=min_confidence,
            since=since,
        )
        # Keyword search always sorts by date
        page_df, has_more = paginate_papers(
            filtered_df, page=page, per_page=per_page, sort="date"
        )

    # Convert to list of dicts for template
    papers = page_df.to_dicts()

    # Build clean URL for browser history (/ instead of /papers)
    # Only include non-default values to keep URLs short
    params = {}
    if search:
        params["search"] = search
    if search_type != "keyword":
        params["search_type"] = search_type
    if category:
        params["category"] = category
    if min_confidence != 0.5:
        params["min_confidence"] = min_confidence
    if since:
        params["since"] = since
    if sort != "date":
        params["sort"] = sort
    push_url = "/?" + urlencode(params) if params else "/"

    response = templates.TemplateResponse(
        "partials/paper_list.html",
        {
            "request": request,
            "papers": papers,
            "page": page,
            "has_more": has_more,
            "category": category or "",
            "search": search or "",
            "min_confidence": min_confidence,
            "search_type": search_type,
            "sort": sort,
            "since": since or "",
            "total_filtered": len(filtered_df),
        },
    )
    # Tell HTMX to push clean URL (/ not /papers)
    response.headers["HX-Push-Url"] = push_url
    return response


@app.get("/api/stats")
async def get_stats():
    """Get dataset statistics as JSON."""
    df = get_dataframe()

    new_datasets = df.filter(pl.col("is_new_dataset"))

    return {
        "total_papers": len(df),
        "new_dataset_count": len(new_datasets),
        "avg_confidence": float(df["confidence_score"].mean()),
        "date_range": {
            "min": str(df["update_date"].min()),
            "max": str(df["update_date"].max()),
        },
    }

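# Quick sanity check once the server is running (port 7860, as configured in
# the uvicorn.run call below):
#   curl http://localhost:7860/api/stats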

# Preload dataset and model on startup
@app.on_event("startup")
async def startup_event():
    """Preload dataset and embedding model on startup."""
    print("Preloading dataset...")
    get_dataframe()
    print("Dataset loaded!")
    print("Preloading embedding model...")
    get_embedding_model()
    print("Embedding model loaded!")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)