Spaces:
Running
Running
Sophie
committed on
Commit
·
934ea0e
1
Parent(s):
15b91ce
added filtering by paper; ranking now uses cosine similarity and citation counts
Browse files — src/streamlit_app.py: +127 −39
src/streamlit_app.py
CHANGED
|
@@ -55,7 +55,7 @@ ALLOWED_TYPES = [
|
|
| 55 |
]
|
| 56 |
|
| 57 |
ARXIV_ID_RE = re.compile(
|
| 58 |
-
r'arxiv\.org/(?:abs|pdf)/((?:\d{4}\.\d{4,5}|[a-z\-]+/\d{7}))
|
| 59 |
re.IGNORECASE
|
| 60 |
)
|
| 61 |
|
|
@@ -262,42 +262,108 @@ def add_citations(candidates: list[dict], max_workers: int = 6) -> None:
|
|
| 262 |
except Exception:
|
| 263 |
pass
|
| 264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
# --- Search and Display ---
|
| 266 |
def search_and_display_with_filters(query, model, theorems_data, embeddings_db, filters):
|
| 267 |
-
if not query:
|
| 268 |
-
st.info("Please enter a search query.")
|
| 269 |
-
return
|
| 270 |
if not filters['sources']:
|
| 271 |
st.warning("Please select at least one source.")
|
| 272 |
return
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
|
| 277 |
# Get a larger pool to filter from
|
| 278 |
top_k_pool = min(200, len(theorems_data))
|
| 279 |
top_indices = torch.topk(cosine_scores, k=top_k_pool, sorted=True).indices
|
| 280 |
-
|
| 281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
|
| 283 |
results = []
|
| 284 |
-
low, high = filters['citation_range']
|
| 285 |
|
| 286 |
# Filter results
|
| 287 |
-
for item in
|
| 288 |
-
type_match
|
| 289 |
-
tag_match
|
| 290 |
-
author_match
|
| 291 |
-
source_match
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
|
|
|
|
|
|
| 296 |
if not filters['include_unknown_citations']:
|
| 297 |
continue
|
| 298 |
citation_match = True
|
| 299 |
else:
|
| 300 |
-
citation_match = (low <= int(
|
| 301 |
|
| 302 |
year_match = True
|
| 303 |
if filters['year_range'] and item.get('source') == 'arXiv':
|
|
@@ -314,11 +380,16 @@ def search_and_display_with_filters(query, model, theorems_data, embeddings_db,
|
|
| 314 |
elif status == "Preprint Only":
|
| 315 |
journal_match = not jp
|
| 316 |
|
| 317 |
-
if all([type_match, tag_match, author_match, source_match, citation_match, year_match, journal_match]):
|
| 318 |
-
|
|
|
|
|
|
|
| 319 |
if len(results) >= filters['top_k']:
|
| 320 |
break
|
| 321 |
|
|
|
|
|
|
|
|
|
|
| 322 |
st.subheader(f"Found {len(results)} Matching Results")
|
| 323 |
if not results:
|
| 324 |
st.warning("No results found for the current filters.")
|
|
@@ -327,17 +398,23 @@ def search_and_display_with_filters(query, model, theorems_data, embeddings_db,
|
|
| 327 |
for i, r in enumerate(results):
|
| 328 |
info = r["info"]
|
| 329 |
expander_title = f"**Result {i+1} | Similarity: {r['similarity']:.4f} | Type: {info.get('type','').title()}**"
|
| 330 |
-
with st.expander(expander_title):
|
| 331 |
st.markdown(f"**Paper:** *{info.get('paper_title','Unknown')}*")
|
| 332 |
st.markdown(f"**Authors:** {', '.join(info.get('authors') or []) or 'N/A'}")
|
| 333 |
-
st.markdown(f"**Source:** {info.get('source')} (
|
| 334 |
-
|
| 335 |
-
cit_str = "Unknown" if
|
| 336 |
st.markdown(
|
| 337 |
f"**Math Tag:** `{info.get('primary_category')}` | "
|
| 338 |
f"**Citations:** {cit_str} | "
|
| 339 |
f"**Year:** {info.get('year', 'N/A')}"
|
| 340 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
st.markdown("---")
|
| 342 |
|
| 343 |
if info.get("theorem_slogan"):
|
|
@@ -348,9 +425,13 @@ def search_and_display_with_filters(query, model, theorems_data, embeddings_db,
|
|
| 348 |
st.markdown("> " + cleaned_ctx.replace("\n", "\n> ") )
|
| 349 |
|
| 350 |
cleaned_content = clean_latex_for_display(info['theorem_body'])
|
| 351 |
-
st.markdown("**Theorem Body
|
| 352 |
st.markdown(cleaned_content)
|
| 353 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
# --- Main App Interface ---
|
| 355 |
st.set_page_config(page_title="Theorem Search Demo", layout="wide")
|
| 356 |
st.title("📚 Semantic Theorem Search")
|
|
@@ -380,6 +461,8 @@ if model and theorems_data:
|
|
| 380 |
selected_authors, selected_types, selected_tags = [], [], []
|
| 381 |
year_range, journal_status = None, "All"
|
| 382 |
citation_range = (0, 1000)
|
|
|
|
|
|
|
| 383 |
top_k_results = 5
|
| 384 |
|
| 385 |
if selected_sources:
|
|
@@ -388,36 +471,41 @@ if model and theorems_data:
|
|
| 388 |
all_authors = sorted(list(set(a for it in theorems_data for a in (it.get('authors') or []))))
|
| 389 |
selected_authors = st.multiselect("Filter by Author(s):", all_authors)
|
| 390 |
|
| 391 |
-
# Tags come from union of categories per selected source
|
| 392 |
from collections import defaultdict
|
| 393 |
tags_per_source = defaultdict(set)
|
| 394 |
for it in theorems_data:
|
| 395 |
tags_per_source[it['source']].add(it.get('primary_category'))
|
| 396 |
union_tags = sorted({t for s in selected_sources for t in tags_per_source.get(s, set()) if t})
|
| 397 |
selected_tags = st.multiselect("Filter by Math Tag/Category:", union_tags)
|
| 398 |
-
|
|
|
|
|
|
|
|
|
|
| 399 |
if 'arXiv' in selected_sources:
|
| 400 |
-
year_range = st.slider("Filter by Year
|
| 401 |
-
journal_status = st.radio("Publication Status
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
top_k_results = st.slider("Number of
|
| 410 |
|
| 411 |
filters = {
|
| 412 |
"authors": selected_authors,
|
| 413 |
"types": [t.lower() for t in selected_types],
|
| 414 |
"tags": selected_tags,
|
| 415 |
"sources": selected_sources,
|
|
|
|
| 416 |
"year_range": year_range,
|
| 417 |
"journal_status": journal_status,
|
| 418 |
"citation_range": citation_range,
|
|
|
|
| 419 |
"include_unknown_citations": include_unknown_citations,
|
| 420 |
-
"top_k": top_k_results
|
| 421 |
}
|
| 422 |
|
| 423 |
user_query = st.text_input("Enter your query:", "")
|
|
|
|
| 55 |
]
|
| 56 |
|
| 57 |
ARXIV_ID_RE = re.compile(
|
| 58 |
+
r'(?:arxiv\.org/(?:abs|pdf)/)?((?:\d{4}\.\d{4,5}|[a-z\-]+/\d{7}))',
|
| 59 |
re.IGNORECASE
|
| 60 |
)
|
| 61 |
|
|
|
|
| 262 |
except Exception:
|
| 263 |
pass
|
| 264 |
|
| 265 |
+
def extract_arxiv_id(s: str) -> str | None:
    """Extract the arXiv identifier from *s* (a URL or a bare ID).

    Returns the matched ID string, or None when *s* is empty or contains
    no recognizable arXiv identifier.
    """
    if not s:
        return None
    match = ARXIV_ID_RE.search(s.strip())
    if match:
        return match.group(1)
    return None
|
| 271 |
+
|
| 272 |
+
def normalize_title(s: str) -> str:
    """Canonical form of a paper title: case-folded, outer whitespace removed."""
    title = s or ""
    return title.casefold().strip()
|
| 274 |
+
|
| 275 |
+
def parse_paper_filter_input(raw: str) -> dict:
    """Split a comma-separated filter string into arXiv IDs and title substrings.

    Each comma-separated token is classified: if it contains an arXiv ID
    (URL or bare), the ID goes into ``ids`` (lower-cased); otherwise the
    token is treated as a title substring and normalized into ``titles``.

    e.g. "2401.12345, Optimal Transport"
         -> {"ids": {"2401.12345"}, "titles": {"optimal transport"}}
    """
    ids: set = set()
    titles: set = set()
    if not raw:
        return {"ids": ids, "titles": titles}
    for token in (t.strip() for t in raw.split(",")):
        if not token:
            continue
        arxiv_id = extract_arxiv_id(token)
        if arxiv_id:
            ids.add(arxiv_id.lower())
        else:
            titles.add(normalize_title(token))
    return {"ids": ids, "titles": titles}
|
| 291 |
+
|
| 292 |
+
def item_matches_paper_filter(item: dict, paper_filter: dict) -> bool:
    """Check whether *item* satisfies the user's paper filter.

    A match requires at least one requested arXiv ID or one title
    substring to hit; an empty filter (no IDs and no titles) matches
    everything.
    """
    wanted_ids = paper_filter.get("ids", set())
    wanted_titles = paper_filter.get("titles", set())
    if not (wanted_ids or wanted_titles):
        return True

    # ID comparison: pull the arXiv ID out of the stored URL once.
    item_id = extract_arxiv_id(item.get("paper_url") or "")
    if item_id and item_id.lower() in wanted_ids:
        return True

    # Title comparison: case-insensitive substring test.
    title = normalize_title(item.get("paper_title"))
    return bool(title) and any(needle in title for needle in wanted_titles)
|
| 314 |
+
|
| 315 |
# --- Search and Display ---
|
| 316 |
def search_and_display_with_filters(query, model, theorems_data, embeddings_db, filters):
|
|
|
|
|
|
|
|
|
|
| 317 |
if not filters['sources']:
|
| 318 |
st.warning("Please select at least one source.")
|
| 319 |
return
|
| 320 |
|
| 321 |
+
if query:
|
| 322 |
+
query_embedding = model.encode(query, convert_to_tensor=True)
|
| 323 |
+
cosine_scores = util.cos_sim(query_embedding, embeddings_db)[0]
|
| 324 |
+
else:
|
| 325 |
+
cosine_scores = torch.zeros(len(theorems_data))
|
| 326 |
+
|
| 327 |
+
low, high = filters['citation_range']
|
| 328 |
|
| 329 |
# Get a larger pool to filter from
|
| 330 |
top_k_pool = min(200, len(theorems_data))
|
| 331 |
top_indices = torch.topk(cosine_scores, k=top_k_pool, sorted=True).indices
|
| 332 |
+
top_indices = top_indices.tolist()
|
| 333 |
+
|
| 334 |
+
paper_filter = filters.get("paper_filter", {"ids": set(), "titles": set()})
|
| 335 |
+
matched_indices = []
|
| 336 |
+
if paper_filter and (paper_filter.get("ids") or paper_filter.get("titles")):
|
| 337 |
+
for i, it in enumerate(theorems_data):
|
| 338 |
+
if item_matches_paper_filter(it, paper_filter):
|
| 339 |
+
matched_indices.append(i)
|
| 340 |
+
|
| 341 |
+
pool_indices = list(dict.fromkeys(top_indices + matched_indices))
|
| 342 |
+
pool = [(i, theorems_data[i]) for i in pool_indices]
|
| 343 |
+
|
| 344 |
+
# Fetch citations in parallel
|
| 345 |
+
if ('arXiv' in filters['sources']):
|
| 346 |
+
add_citations([it for _, it in pool])
|
| 347 |
|
| 348 |
results = []
|
|
|
|
| 349 |
|
| 350 |
# Filter results
|
| 351 |
+
for idx, item in pool:
|
| 352 |
+
type_match = (not filters['types']) or (item.get('type','').lower() in filters['types'])
|
| 353 |
+
tag_match = (not filters['tags']) or (item.get('primary_category') in filters['tags'])
|
| 354 |
+
author_match = (not filters['authors']) or any(a in (item.get('authors') or []) for a in filters['authors'])
|
| 355 |
+
source_match = item.get('source') in filters['sources']
|
| 356 |
+
paper_match = item_matches_paper_filter(item, filters['paper_filter'])
|
| 357 |
+
|
| 358 |
+
# Citations & year & journal only for arXiv
|
| 359 |
+
citations = item.get('citations')
|
| 360 |
+
log_cit = np.log1p(int(citations)) if citations is not None else 0.0
|
| 361 |
+
if citations is None:
|
| 362 |
if not filters['include_unknown_citations']:
|
| 363 |
continue
|
| 364 |
citation_match = True
|
| 365 |
else:
|
| 366 |
+
citation_match = (low <= int(citations) <= high)
|
| 367 |
|
| 368 |
year_match = True
|
| 369 |
if filters['year_range'] and item.get('source') == 'arXiv':
|
|
|
|
| 380 |
elif status == "Preprint Only":
|
| 381 |
journal_match = not jp
|
| 382 |
|
| 383 |
+
if all([type_match, tag_match, author_match, source_match, paper_match, citation_match, year_match, journal_match]):
|
| 384 |
+
# Similarity = cosine_similary + citation_weight * log(citation_count)
|
| 385 |
+
similarity = float(cosine_scores[idx].item()) + filters['citation_weight'] * log_cit
|
| 386 |
+
results.append({"idx": idx, "info": item, "similarity": similarity})
|
| 387 |
if len(results) >= filters['top_k']:
|
| 388 |
break
|
| 389 |
|
| 390 |
+
results.sort(key=lambda r: r["similarity"], reverse=True)
|
| 391 |
+
results = results[:filters['top_k']]
|
| 392 |
+
|
| 393 |
st.subheader(f"Found {len(results)} Matching Results")
|
| 394 |
if not results:
|
| 395 |
st.warning("No results found for the current filters.")
|
|
|
|
| 398 |
for i, r in enumerate(results):
|
| 399 |
info = r["info"]
|
| 400 |
expander_title = f"**Result {i+1} | Similarity: {r['similarity']:.4f} | Type: {info.get('type','').title()}**"
|
| 401 |
+
with st.expander(expander_title, expanded=True):
|
| 402 |
st.markdown(f"**Paper:** *{info.get('paper_title','Unknown')}*")
|
| 403 |
st.markdown(f"**Authors:** {', '.join(info.get('authors') or []) or 'N/A'}")
|
| 404 |
+
st.markdown(f"**Source:** {info.get('source')} ({info.get('paper_url')})")
|
| 405 |
+
citations = info.get("citations")
|
| 406 |
+
cit_str = "Unknown" if citations is None else str(citations)
|
| 407 |
st.markdown(
|
| 408 |
f"**Math Tag:** `{info.get('primary_category')}` | "
|
| 409 |
f"**Citations:** {cit_str} | "
|
| 410 |
f"**Year:** {info.get('year', 'N/A')}"
|
| 411 |
)
|
| 412 |
+
# Testing only
|
| 413 |
+
if filters['citation_weight'] > 0:
|
| 414 |
+
base = float(cosine_scores[r["idx"]].item())
|
| 415 |
+
log_cit = np.log1p(int(citations)) if citations is not None else 0.0
|
| 416 |
+
st.caption(
|
| 417 |
+
f"base_cosine={base:.4f} | log(citations)={log_cit:.4f} | weight={filters['citation_weight']:.2f}")
|
| 418 |
st.markdown("---")
|
| 419 |
|
| 420 |
if info.get("theorem_slogan"):
|
|
|
|
| 425 |
st.markdown("> " + cleaned_ctx.replace("\n", "\n> ") )
|
| 426 |
|
| 427 |
cleaned_content = clean_latex_for_display(info['theorem_body'])
|
| 428 |
+
st.markdown(f"**{info['theorem_name'] or 'Theorem Body.'}**")
|
| 429 |
st.markdown(cleaned_content)
|
| 430 |
|
| 431 |
+
# Testing only
|
| 432 |
+
st.markdown('**Paper ID (testing only)**')
|
| 433 |
+
st.markdown(info['paper_id'])
|
| 434 |
+
|
| 435 |
# --- Main App Interface ---
|
| 436 |
st.set_page_config(page_title="Theorem Search Demo", layout="wide")
|
| 437 |
st.title("📚 Semantic Theorem Search")
|
|
|
|
| 461 |
selected_authors, selected_types, selected_tags = [], [], []
|
| 462 |
year_range, journal_status = None, "All"
|
| 463 |
citation_range = (0, 1000)
|
| 464 |
+
citation_weight = 0.0
|
| 465 |
+
include_unknown_citations = True
|
| 466 |
top_k_results = 5
|
| 467 |
|
| 468 |
if selected_sources:
|
|
|
|
| 471 |
all_authors = sorted(list(set(a for it in theorems_data for a in (it.get('authors') or []))))
|
| 472 |
selected_authors = st.multiselect("Filter by Author(s):", all_authors)
|
| 473 |
|
| 474 |
+
# Tags come from the union of categories per selected source
|
| 475 |
from collections import defaultdict
|
| 476 |
tags_per_source = defaultdict(set)
|
| 477 |
for it in theorems_data:
|
| 478 |
tags_per_source[it['source']].add(it.get('primary_category'))
|
| 479 |
union_tags = sorted({t for s in selected_sources for t in tags_per_source.get(s, set()) if t})
|
| 480 |
selected_tags = st.multiselect("Filter by Math Tag/Category:", union_tags)
|
| 481 |
+
paper_filter_raw = st.text_input("Filter by Paper",
|
| 482 |
+
value="",
|
| 483 |
+
placeholder="e.g., 2401.12345, Finite Hilbert stability",
|
| 484 |
+
help="Filter by title substring or arXiv ID/URL. Use commas for multiple.")
|
| 485 |
if 'arXiv' in selected_sources:
|
| 486 |
+
year_range = st.slider("Filter by Year:", 1991, 2025, (1991, 2025))
|
| 487 |
+
journal_status = st.radio("Publication Status:", ["All", "Journal Article", "Preprint Only"], horizontal=True)
|
| 488 |
+
citation_range = st.slider("Filter by Citations:", 0, 1000, (0, 1000))
|
| 489 |
+
citation_weight = st.slider("Citation Weight:", 0.0, 1.0, 0.0, step=0.01)
|
| 490 |
+
include_unknown_citations = st.checkbox(
|
| 491 |
+
"Include entries with unknown citation counts",
|
| 492 |
+
value=True,
|
| 493 |
+
help="If unchecked, results with unknown citation counts are excluded."
|
| 494 |
+
)
|
| 495 |
+
top_k_results = st.slider("Number of Results to Display:", 1, 20, 5)
|
| 496 |
|
| 497 |
filters = {
|
| 498 |
"authors": selected_authors,
|
| 499 |
"types": [t.lower() for t in selected_types],
|
| 500 |
"tags": selected_tags,
|
| 501 |
"sources": selected_sources,
|
| 502 |
+
"paper_filter": parse_paper_filter_input(paper_filter_raw),
|
| 503 |
"year_range": year_range,
|
| 504 |
"journal_status": journal_status,
|
| 505 |
"citation_range": citation_range,
|
| 506 |
+
"citation_weight": citation_weight,
|
| 507 |
"include_unknown_citations": include_unknown_citations,
|
| 508 |
+
"top_k": top_k_results,
|
| 509 |
}
|
| 510 |
|
| 511 |
user_query = st.text_input("Enter your query:", "")
|