Sophie committed on
Commit
934ea0e
·
1 Parent(s): 15b91ce

added filtering by paper; ranking now uses cosine similarity and citation counts

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +127 -39
src/streamlit_app.py CHANGED
@@ -55,7 +55,7 @@ ALLOWED_TYPES = [
55
  ]
56
 
57
  ARXIV_ID_RE = re.compile(
58
- r'arxiv\.org/(?:abs|pdf)/((?:\d{4}\.\d{4,5}|[a-z\-]+/\d{7}))(?:v\d+)?',
59
  re.IGNORECASE
60
  )
61
 
@@ -262,42 +262,108 @@ def add_citations(candidates: list[dict], max_workers: int = 6) -> None:
262
  except Exception:
263
  pass
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  # --- Search and Display ---
266
  def search_and_display_with_filters(query, model, theorems_data, embeddings_db, filters):
267
- if not query:
268
- st.info("Please enter a search query.")
269
- return
270
  if not filters['sources']:
271
  st.warning("Please select at least one source.")
272
  return
273
 
274
- query_embedding = model.encode(query, convert_to_tensor=True)
275
- cosine_scores = util.cos_sim(query_embedding, embeddings_db)[0]
 
 
 
 
 
276
 
277
  # Get a larger pool to filter from
278
  top_k_pool = min(200, len(theorems_data))
279
  top_indices = torch.topk(cosine_scores, k=top_k_pool, sorted=True).indices
280
- pool_items = [theorems_data[int(i.item())] for i in top_indices]
281
- add_citations(pool_items)
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
  results = []
284
- low, high = filters['citation_range']
285
 
286
  # Filter results
287
- for item in pool_items:
288
- type_match = (not filters['types']) or (item.get('type','').lower() in filters['types'])
289
- tag_match = (not filters['tags']) or (item.get('primary_category') in filters['tags'])
290
- author_match = (not filters['authors']) or any(a in (item.get('authors') or []) for a in filters['authors'])
291
- source_match = item.get('source') in filters['sources']
292
-
293
- # Citations & year & journal only meaningful for arXiv
294
- cit = item.get('citations')
295
- if cit is None:
 
 
296
  if not filters['include_unknown_citations']:
297
  continue
298
  citation_match = True
299
  else:
300
- citation_match = (low <= int(cit) <= high)
301
 
302
  year_match = True
303
  if filters['year_range'] and item.get('source') == 'arXiv':
@@ -314,11 +380,16 @@ def search_and_display_with_filters(query, model, theorems_data, embeddings_db,
314
  elif status == "Preprint Only":
315
  journal_match = not jp
316
 
317
- if all([type_match, tag_match, author_match, source_match, citation_match, year_match, journal_match]):
318
- results.append({"info": item, "similarity": float(cosine_scores[theorems_data.index(item)].item())})
 
 
319
  if len(results) >= filters['top_k']:
320
  break
321
 
 
 
 
322
  st.subheader(f"Found {len(results)} Matching Results")
323
  if not results:
324
  st.warning("No results found for the current filters.")
@@ -327,17 +398,23 @@ def search_and_display_with_filters(query, model, theorems_data, embeddings_db,
327
  for i, r in enumerate(results):
328
  info = r["info"]
329
  expander_title = f"**Result {i+1} | Similarity: {r['similarity']:.4f} | Type: {info.get('type','').title()}**"
330
- with st.expander(expander_title):
331
  st.markdown(f"**Paper:** *{info.get('paper_title','Unknown')}*")
332
  st.markdown(f"**Authors:** {', '.join(info.get('authors') or []) or 'N/A'}")
333
- st.markdown(f"**Source:** {info.get('source')} ([Link]({info.get('paper_url')}))")
334
- cit = info.get("citations")
335
- cit_str = "Unknown" if cit is None else str(cit)
336
  st.markdown(
337
  f"**Math Tag:** `{info.get('primary_category')}` | "
338
  f"**Citations:** {cit_str} | "
339
  f"**Year:** {info.get('year', 'N/A')}"
340
  )
 
 
 
 
 
 
341
  st.markdown("---")
342
 
343
  if info.get("theorem_slogan"):
@@ -348,9 +425,13 @@ def search_and_display_with_filters(query, model, theorems_data, embeddings_db,
348
  st.markdown("> " + cleaned_ctx.replace("\n", "\n> ") )
349
 
350
  cleaned_content = clean_latex_for_display(info['theorem_body'])
351
- st.markdown("**Theorem Body:**")
352
  st.markdown(cleaned_content)
353
 
 
 
 
 
354
  # --- Main App Interface ---
355
  st.set_page_config(page_title="Theorem Search Demo", layout="wide")
356
  st.title("📚 Semantic Theorem Search")
@@ -380,6 +461,8 @@ if model and theorems_data:
380
  selected_authors, selected_types, selected_tags = [], [], []
381
  year_range, journal_status = None, "All"
382
  citation_range = (0, 1000)
 
 
383
  top_k_results = 5
384
 
385
  if selected_sources:
@@ -388,36 +471,41 @@ if model and theorems_data:
388
  all_authors = sorted(list(set(a for it in theorems_data for a in (it.get('authors') or []))))
389
  selected_authors = st.multiselect("Filter by Author(s):", all_authors)
390
 
391
- # Tags come from union of categories per selected source
392
  from collections import defaultdict
393
  tags_per_source = defaultdict(set)
394
  for it in theorems_data:
395
  tags_per_source[it['source']].add(it.get('primary_category'))
396
  union_tags = sorted({t for s in selected_sources for t in tags_per_source.get(s, set()) if t})
397
  selected_tags = st.multiselect("Filter by Math Tag/Category:", union_tags)
398
-
 
 
 
399
  if 'arXiv' in selected_sources:
400
- year_range = st.slider("Filter by Year (for arXiv):", 1991, 2025, (1991, 2025))
401
- journal_status = st.radio("Publication Status (for arXiv):", ["All", "Journal Article", "Preprint Only"], horizontal=True)
402
-
403
- citation_range = st.slider("Filter by Citations:", 0, 1000, (0, 1000))
404
- include_unknown_citations = st.checkbox(
405
- "Include entries with unknown citation counts",
406
- value=True,
407
- help="If unchecked, results with unknown citation counts are excluded."
408
- )
409
- top_k_results = st.slider("Number of results to display:", 1, 20, 5)
410
 
411
  filters = {
412
  "authors": selected_authors,
413
  "types": [t.lower() for t in selected_types],
414
  "tags": selected_tags,
415
  "sources": selected_sources,
 
416
  "year_range": year_range,
417
  "journal_status": journal_status,
418
  "citation_range": citation_range,
 
419
  "include_unknown_citations": include_unknown_citations,
420
- "top_k": top_k_results
421
  }
422
 
423
  user_query = st.text_input("Enter your query:", "")
 
55
  ]
56
 
57
# Matches an arXiv identifier either inside an abs/pdf URL or as a bare ID.
# Group 1 captures the normalized ID: new-style "2401.12345" or old-style
# "math/0211159". The trailing (?!\d) lookaheads reject over-long digit runs
# (e.g. "2401.123456") instead of silently capturing a truncated, wrong ID;
# an optional "vN" version suffix is tolerated but never captured.
ARXIV_ID_RE = re.compile(
    r'(?:arxiv\.org/(?:abs|pdf)/)?'                       # optional URL prefix
    r'((?:\d{4}\.\d{4,5}(?!\d)|[a-z\-]+/\d{7}(?!\d)))'    # the ID itself
    r'(?:v\d+)?',                                         # optional version suffix
    re.IGNORECASE
)
61
 
 
262
  except Exception:
263
  pass
264
 
265
def extract_arxiv_id(s: str) -> str | None:
    """Pull the arXiv identifier out of *s* (a URL or a bare ID).

    Returns the captured ID string, or None when *s* is empty or
    contains nothing that looks like an arXiv ID.
    """
    if s:
        match = ARXIV_ID_RE.search(s.strip())
        if match:
            return match.group(1)
    return None
271
+
272
def normalize_title(s: str) -> str:
    """Casefold and trim a paper title for loose, case-insensitive comparison."""
    if not s:
        return ""
    return s.casefold().strip()
274
+
275
def parse_paper_filter_input(raw: str) -> dict:
    """Split a comma-separated paper filter into arXiv IDs and title substrings.

    Each comma-separated token is classified: tokens containing an arXiv ID
    (bare or inside a URL) go into ``"ids"`` (lower-cased); everything else
    becomes a case-insensitive title substring in ``"titles"``.

    e.g. "2401.12345, Optimal Transport"
         -> {"ids": {"2401.12345"}, "titles": {"optimal transport"}}
    """
    result = {"ids": set(), "titles": set()}
    if raw:
        for chunk in raw.split(","):
            token = chunk.strip()
            if not token:
                continue
            arxiv_id = extract_arxiv_id(token)
            if arxiv_id:
                result["ids"].add(arxiv_id.lower())
            else:
                result["titles"].add(normalize_title(token))
    return result
291
+
292
def item_matches_paper_filter(item: dict, paper_filter: dict) -> bool:
    """Decide whether *item* passes the user's paper filter.

    An empty filter (no IDs and no title substrings) matches everything.
    Otherwise the item matches when its paper URL contains one of the
    requested arXiv IDs, or its title contains one of the requested
    substrings (case-insensitive).
    """
    wanted_ids = paper_filter.get("ids", set())
    wanted_titles = paper_filter.get("titles", set())

    # No constraints -> everything matches.
    if not (wanted_ids or wanted_titles):
        return True

    # Match on the arXiv ID extracted once from the item's paper URL.
    item_id = extract_arxiv_id(item.get("paper_url") or "")
    if item_id is not None and item_id.lower() in wanted_ids:
        return True

    # Match on a case-insensitive title substring.
    title = normalize_title(item.get("paper_title"))
    if title and any(sub in title for sub in wanted_titles):
        return True

    return False
314
+
315
  # --- Search and Display ---
316
  def search_and_display_with_filters(query, model, theorems_data, embeddings_db, filters):
 
 
 
317
  if not filters['sources']:
318
  st.warning("Please select at least one source.")
319
  return
320
 
321
+ if query:
322
+ query_embedding = model.encode(query, convert_to_tensor=True)
323
+ cosine_scores = util.cos_sim(query_embedding, embeddings_db)[0]
324
+ else:
325
+ cosine_scores = torch.zeros(len(theorems_data))
326
+
327
+ low, high = filters['citation_range']
328
 
329
  # Get a larger pool to filter from
330
  top_k_pool = min(200, len(theorems_data))
331
  top_indices = torch.topk(cosine_scores, k=top_k_pool, sorted=True).indices
332
+ top_indices = top_indices.tolist()
333
+
334
+ paper_filter = filters.get("paper_filter", {"ids": set(), "titles": set()})
335
+ matched_indices = []
336
+ if paper_filter and (paper_filter.get("ids") or paper_filter.get("titles")):
337
+ for i, it in enumerate(theorems_data):
338
+ if item_matches_paper_filter(it, paper_filter):
339
+ matched_indices.append(i)
340
+
341
+ pool_indices = list(dict.fromkeys(top_indices + matched_indices))
342
+ pool = [(i, theorems_data[i]) for i in pool_indices]
343
+
344
+ # Fetch citations in parallel
345
+ if ('arXiv' in filters['sources']):
346
+ add_citations([it for _, it in pool])
347
 
348
  results = []
 
349
 
350
  # Filter results
351
+ for idx, item in pool:
352
+ type_match = (not filters['types']) or (item.get('type','').lower() in filters['types'])
353
+ tag_match = (not filters['tags']) or (item.get('primary_category') in filters['tags'])
354
+ author_match = (not filters['authors']) or any(a in (item.get('authors') or []) for a in filters['authors'])
355
+ source_match = item.get('source') in filters['sources']
356
+ paper_match = item_matches_paper_filter(item, filters['paper_filter'])
357
+
358
+ # Citations & year & journal only for arXiv
359
+ citations = item.get('citations')
360
+ log_cit = np.log1p(int(citations)) if citations is not None else 0.0
361
+ if citations is None:
362
  if not filters['include_unknown_citations']:
363
  continue
364
  citation_match = True
365
  else:
366
+ citation_match = (low <= int(citations) <= high)
367
 
368
  year_match = True
369
  if filters['year_range'] and item.get('source') == 'arXiv':
 
380
  elif status == "Preprint Only":
381
  journal_match = not jp
382
 
383
+ if all([type_match, tag_match, author_match, source_match, paper_match, citation_match, year_match, journal_match]):
384
+ # Similarity = cosine_similarity + citation_weight * log1p(citation_count)
385
+ similarity = float(cosine_scores[idx].item()) + filters['citation_weight'] * log_cit
386
+ results.append({"idx": idx, "info": item, "similarity": similarity})
387
  if len(results) >= filters['top_k']:
388
  break
389
 
390
+ results.sort(key=lambda r: r["similarity"], reverse=True)
391
+ results = results[:filters['top_k']]
392
+
393
  st.subheader(f"Found {len(results)} Matching Results")
394
  if not results:
395
  st.warning("No results found for the current filters.")
 
398
  for i, r in enumerate(results):
399
  info = r["info"]
400
  expander_title = f"**Result {i+1} | Similarity: {r['similarity']:.4f} | Type: {info.get('type','').title()}**"
401
+ with st.expander(expander_title, expanded=True):
402
  st.markdown(f"**Paper:** *{info.get('paper_title','Unknown')}*")
403
  st.markdown(f"**Authors:** {', '.join(info.get('authors') or []) or 'N/A'}")
404
+ st.markdown(f"**Source:** {info.get('source')} ({info.get('paper_url')})")
405
+ citations = info.get("citations")
406
+ cit_str = "Unknown" if citations is None else str(citations)
407
  st.markdown(
408
  f"**Math Tag:** `{info.get('primary_category')}` | "
409
  f"**Citations:** {cit_str} | "
410
  f"**Year:** {info.get('year', 'N/A')}"
411
  )
412
+ # Testing only
413
+ if filters['citation_weight'] > 0:
414
+ base = float(cosine_scores[r["idx"]].item())
415
+ log_cit = np.log1p(int(citations)) if citations is not None else 0.0
416
+ st.caption(
417
+ f"base_cosine={base:.4f} | log(citations)={log_cit:.4f} | weight={filters['citation_weight']:.2f}")
418
  st.markdown("---")
419
 
420
  if info.get("theorem_slogan"):
 
425
  st.markdown("> " + cleaned_ctx.replace("\n", "\n> ") )
426
 
427
  cleaned_content = clean_latex_for_display(info['theorem_body'])
428
+ st.markdown(f"**{info['theorem_name'] or 'Theorem Body.'}**")
429
  st.markdown(cleaned_content)
430
 
431
+ # Testing only
432
+ st.markdown('**Paper ID (testing only)**')
433
+ st.markdown(info['paper_id'])
434
+
435
  # --- Main App Interface ---
436
  st.set_page_config(page_title="Theorem Search Demo", layout="wide")
437
  st.title("📚 Semantic Theorem Search")
 
461
  selected_authors, selected_types, selected_tags = [], [], []
462
  year_range, journal_status = None, "All"
463
  citation_range = (0, 1000)
464
+ citation_weight = 0.0
465
+ include_unknown_citations = True
466
  top_k_results = 5
467
 
468
  if selected_sources:
 
471
  all_authors = sorted(list(set(a for it in theorems_data for a in (it.get('authors') or []))))
472
  selected_authors = st.multiselect("Filter by Author(s):", all_authors)
473
 
474
+ # Tags come from the union of categories per selected source
475
  from collections import defaultdict
476
  tags_per_source = defaultdict(set)
477
  for it in theorems_data:
478
  tags_per_source[it['source']].add(it.get('primary_category'))
479
  union_tags = sorted({t for s in selected_sources for t in tags_per_source.get(s, set()) if t})
480
  selected_tags = st.multiselect("Filter by Math Tag/Category:", union_tags)
481
+ paper_filter_raw = st.text_input("Filter by Paper",
482
+ value="",
483
+ placeholder="e.g., 2401.12345, Finite Hilbert stability",
484
+ help="Filter by title substring or arXiv ID/URL. Use commas for multiple.")
485
  if 'arXiv' in selected_sources:
486
+ year_range = st.slider("Filter by Year:", 1991, 2025, (1991, 2025))
487
+ journal_status = st.radio("Publication Status:", ["All", "Journal Article", "Preprint Only"], horizontal=True)
488
+ citation_range = st.slider("Filter by Citations:", 0, 1000, (0, 1000))
489
+ citation_weight = st.slider("Citation Weight:", 0.0, 1.0, 0.0, step=0.01)
490
+ include_unknown_citations = st.checkbox(
491
+ "Include entries with unknown citation counts",
492
+ value=True,
493
+ help="If unchecked, results with unknown citation counts are excluded."
494
+ )
495
+ top_k_results = st.slider("Number of Results to Display:", 1, 20, 5)
496
 
497
  filters = {
498
  "authors": selected_authors,
499
  "types": [t.lower() for t in selected_types],
500
  "tags": selected_tags,
501
  "sources": selected_sources,
502
+ "paper_filter": parse_paper_filter_input(paper_filter_raw),
503
  "year_range": year_range,
504
  "journal_status": journal_status,
505
  "citation_range": citation_range,
506
+ "citation_weight": citation_weight,
507
  "include_unknown_citations": include_unknown_citations,
508
+ "top_k": top_k_results,
509
  }
510
 
511
  user_query = st.text_input("Enter your query:", "")