Spaces: Running
Sophie committed · Commit a132e72 · Parent(s): 71df2d7
optimized app
Browse files: src/streamlit_app.py (+128 −113)
src/streamlit_app.py
CHANGED
@@ -11,6 +11,7 @@ from pgvector.psycopg2 import register_vector
 import re
 import requests
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from collections import defaultdict
 from dotenv import load_dotenv
 from latex_clean import clean_latex_for_display
 
@@ -179,86 +180,55 @@ def load_papers_from_rds():
         return []
 
 @st.cache_data(ttl=60*60*24) # cache for 24 hours
-def …
-    """…"""
-    …
-            return c
-    except Exception:
-        pass
-    # Fallback: Semantic Scholar by title
-    if title:
-        try:
-            r = requests.get(
-                "https://api.semanticscholar.org/graph/v1/paper/search",
-                params={"query": title, "limit": 1, "fields": "title,citationCount"},
-                timeout=10
-            )
-            if r.ok:
-                j = r.json()
-                if j.get("data"):
-                    c = j["data"][0].get("citationCount")
-                    if isinstance(c, int):
-                        return c
-        except Exception:
-            pass
-
-    return None
-
-def add_citations(candidates: list[dict], max_workers: int = 6) -> None:
-    # Select targets with missing citations
-    targets = [
-        it for it in candidates
-        if it.get("source") == "arXiv" and (it.get("citations") in (None, 0))
-    ]
-    if not targets:
-        return
-    …
-        pass
+def load_authors():
+    conn = get_rds_connection()
+    cur = conn.cursor()
+    cur.execute("""
+        SELECT DISTINCT unnest(p.authors) AS author
+        FROM paper p
+        WHERE p.authors IS NOT NULL
+    """)
+    rows = cur.fetchall()
+    cur.close()
+    conn.close()
+    authors = sorted(r[0] for r in rows if r[0])
+    return authors
+
+@st.cache_data(ttl=60*60*24) # cache for 24 hours
+def load_tags_per_source():
+    conn = get_rds_connection()
+    cur = conn.cursor()
+    cur.execute("""
+        SELECT
+            CASE WHEN p.link ILIKE '%%arxiv.org%%'
+                 THEN 'arXiv'
+                 ELSE 'Stacks Project'
+            END AS source,
+            p.primary_category
+        FROM paper p
+        WHERE p.primary_category IS NOT NULL
+    """)
+    rows = cur.fetchall()
+    cur.close()
+    conn.close()
+
+    tags_per_source = defaultdict(set)
+    for source, cat in rows:
+        tags_per_source[source].add(cat)
+
+    # Make them plain lists so Streamlit cache can serialize easily
+    return {src: sorted(cats) for src, cats in tags_per_source.items()}
+
+@st.cache_data(ttl=60*60*24) # cache for 24 hours
+def load_theorem_count():
+    conn = get_rds_connection()
+    cur = conn.cursor()
+    cur.execute("SELECT COUNT(*) FROM theorem;")
+    (n,) = cur.fetchone()
+    cur.close()
+    conn.close()
+    return int(n)
 
 def extract_arxiv_id(s: str) -> str | None:
     """Return normalized arXiv ID if present in s (URL or raw), else None."""
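
All three new loaders repeat the same connect, query, fetch, close sequence. If that pattern keeps growing it could be factored into one helper; the sketch below is hypothetical (not part of this commit) and assumes `get_rds_connection()` returns a standard DB-API connection, as the loaders above use it:

    import contextlib

    def run_query(sql: str, params=None):
        # Hypothetical helper: open a connection, run one query, always clean up.
        conn = get_rds_connection()
        try:
            with contextlib.closing(conn.cursor()) as cur:
                cur.execute(sql, params)
                return cur.fetchall()
        finally:
            conn.close()

    # load_theorem_count() would then reduce to:
    # (n,) = run_query("SELECT COUNT(*) FROM theorem;")[0]
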
@@ -335,6 +305,8 @@ def search_and_display(query: str, model, filters: dict):
         st.warning("Please select at least one source.")
         return
 
+    citation_weight = float(filters['citation_weight'])
+
     # Encode query to numpy array
     query_vec = model.encode(query or "", normalize_embeddings=True, convert_to_numpy=True)
 
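
Passing the NumPy `query_vec` straight into the `%s::vector` placeholders later in the query relies on pgvector's psycopg2 adapter, which the file already imports (`from pgvector.psycopg2 import register_vector` in the first hunk's context). A minimal sketch of that wiring, where the DSN and model name are placeholders rather than values from this app:

    import psycopg2
    from pgvector.psycopg2 import register_vector
    from sentence_transformers import SentenceTransformer

    conn = psycopg2.connect("dbname=theorems")        # placeholder DSN
    register_vector(conn)                             # lets psycopg2 adapt numpy arrays to vector

    model = SentenceTransformer("all-MiniLM-L6-v2")   # placeholder model name
    vec = model.encode("a query", normalize_embeddings=True, convert_to_numpy=True)

    cur = conn.cursor()
    cur.execute("SELECT %s::vector <#> %s::vector;", (vec, vec))  # negative inner product
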
@@ -391,34 +363,73 @@ def search_and_display(query: str, model, filters: dict):
     if pf_clauses:
         where.append("(" + " OR ".join(pf_clauses) + ")")
 
-    #
+    # Result type
     if filters['types']:
         like_any = [f"%{t}%" for t in filters['types']]
         where.append(" lower(t.name) ILIKE ANY(%s) ")
         params.append(like_any)
 
+    # Citations
+    low, high = filters["citation_range"]
+    include_unknown = filters["include_unknown_citations"]
+
+    if include_unknown:
+        where.append("( (p.citations BETWEEN %s AND %s) OR p.citations IS NULL )")
+    else:
+        where.append("( p.citations IS NOT NULL AND (p.citations BETWEEN %s AND %s) )")
+
+    params.extend([low, high])
+
+    pool_size = max(50, int(filters['top_k']) * 10)
+
     sql = f"""
     WITH latest_slogan AS (
         SELECT DISTINCT ON (ts.theorem_id)
-            ts.theorem_id, ts.slogan_id, ts.slogan
+               ts.theorem_id, ts.slogan_id, ts.slogan
        FROM theorem_slogan ts
        ORDER BY ts.theorem_id, ts.slogan_id DESC
+    ),
+    candidates AS (
+        SELECT
+            p.paper_id,
+            p.title,
+            p.authors,
+            p.link,
+            p.last_updated,
+            p.summary,
+            p.journal_ref,
+            p.primary_category,
+            p.categories,
+            p.citations,
+            t.theorem_id,
+            t.name AS theorem_name,
+            t.body AS theorem_body,
+            ls.slogan AS theorem_slogan,
+            (1.0 - (e.embedding <#> %s::vector)) AS similarity
+        FROM paper p
+        JOIN theorem t ON t.paper_id = p.paper_id
+        JOIN latest_slogan ls ON ls.theorem_id = t.theorem_id
+        JOIN {EMBED_TABLE} e ON e.slogan_id = ls.slogan_id
+        {'WHERE ' + ' AND '.join(where) if where else ''}
+        ORDER BY e.embedding <#> %s::vector ASC
+        LIMIT {pool_size}
     )
     SELECT
-        …
+        *,
+        (
+            similarity +
+            %s *
+            CASE
+                WHEN citations IS NOT NULL AND citations > 0
+                THEN ln(citations::float)
+                ELSE 0
+            END
+        ) AS weighted_score
+    FROM candidates
+    ORDER BY weighted_score DESC, similarity DESC
+    LIMIT %s;
     """
-    exec_params = [query_vec, *params, query_vec, int(filters['top_k'])]
+    exec_params = [query_vec, *params, query_vec, citation_weight, int(filters['top_k'])]
 
     conn = get_rds_connection()
     cur = conn.cursor()
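
The new query ranks a candidate pool of `max(50, 10 * top_k)` nearest neighbors by `similarity + citation_weight * ln(citations)`, falling back to plain similarity when the citation count is NULL or zero. The client-side `compute_score` used further down is defined elsewhere in the file; a Python equivalent consistent with this SQL would be (a sketch, not the file's actual definition):

    import math

    def compute_score(similarity: float, citations: int | None, citation_weight: float) -> float:
        # Mirror the SQL: add a log-scaled citation bonus only when the count is known and positive.
        if citations is not None and citations > 0:
            return similarity + citation_weight * math.log(citations)
        return similarity
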
@@ -430,8 +441,8 @@ def search_and_display(query: str, model, filters: dict):
     # Populate result fields
     items = []
     for (paper_id, title, authors, link, last_updated, summary, journal_ref,
-         primary_category, categories, theorem_id, theorem_name,
-         theorem_slogan, similarity) in rows:
+         primary_category, categories, citations, theorem_id, theorem_name,
+         theorem_body, theorem_slogan, similarity, weighted_score) in rows:
 
         # Determine source from url
         link_str = link or ""
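
The widened tuple unpack works only because `SELECT *` over the `candidates` CTE preserves the column order declared above. A defensive alternative (not in the commit) is to key rows by name via the cursor metadata:

    cols = [d[0] for d in cur.description]              # column names from the executed query
    items_raw = [dict(zip(cols, row)) for row in cur.fetchall()]
    for row in items_raw:
        similarity = float(row["similarity"])           # keyed access survives column reordering
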
@@ -450,19 +461,16 @@ def search_and_display(query: str, model, filters: dict):
             "source": source,
             "type": inferred_type,
             "journal_published": bool(journal_ref),
-            "citations": …
+            "citations": citations,
             "theorem_name": theorem_name,
             "theorem_slogan": theorem_slogan,
             "theorem_body": theorem_body,
             "similarity": float(similarity),
+            "score": weighted_score
         })
 
-    #
-    if 'arXiv' in filters['sources']:
-        with st.spinner("Fetching citations..."):
-            add_citations(items)
+    # Compute weighted citation score if applicable
     for it in items:
-        # Compute weighted score if applicable
         it["score"] = compute_score(it["similarity"], it.get("citations"), citation_weight)
 
     # Sort results by weighted score, then cosine similarity, then paper id
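
The trailing comment describes a sort that sits outside this hunk; assuming `paper_id` breaks ties ascending, the line it refers to is presumably equivalent to:

    items.sort(key=lambda it: (-it["score"], -it["similarity"], it["paper_id"]))
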
@@ -514,10 +522,12 @@ st.title("Math Theorem Search")
 st.write("This demo finds mathematical theorems that are semantically similar to your query.")
 
 model = load_model()
-
+theorem_count = load_theorem_count()
+authors = load_authors()
+tags_per_source = load_tags_per_source()
 
-if model and theorems_data:
-    st.success(f"Successfully loaded {…
+if model:
+    st.success(f"Successfully loaded {theorem_count} theorems from arXiv and the Stacks Project. Ready to search!")
     # --- Sidebar filters ---
     with st.sidebar:
         st.header("Search Filters")
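
Since all three loaders carry `@st.cache_data(ttl=60*60*24)`, these startup calls hit RDS at most once per day per process. Streamlit's cached functions also expose `.clear()`, so a manual refresh after a bulk ingest could look like:

    load_authors.clear()
    load_tags_per_source.clear()
    load_theorem_count.clear()
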
@@ -541,20 +551,23 @@ if model and theorems_data:
         if selected_sources:
             st.write("---")
         selected_types = st.multiselect("Filter by Type:", ALLOWED_TYPES)
-
-        selected_authors = st.multiselect("Filter by Author(s):", all_authors)
+        selected_authors = st.multiselect("Filter by Author(s):", authors)
 
-        # Tags
-        from collections import defaultdict
-        tags_per_source = defaultdict(set)
-        …
+        # Tags per selected source(s)
+        union_tags = sorted({
+            t
+            for s in selected_sources
+            for t in tags_per_source.get(s, [])
+            if t
+        })
         selected_tags = st.multiselect("Filter by Tag/Category:", union_tags)
+
         paper_filter = st.text_input("Filter by Paper",
                                      value="",
                                      placeholder="e.g., 2401.12345, Finite Hilbert stability",
                                      help="Filter by title substring or arXiv ID/URL. Use commas for multiple.")
+
         if 'arXiv' in selected_sources:
             year_range = st.slider("Filter by Year:", 1991, 2025, (1991, 2025))
             journal_status = st.radio("Publication Status:",
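
With `tags_per_source` now supplied by `load_tags_per_source()` as a `{source: [tags]}` dict, the comprehension unions tags across whichever sources are selected. For example, with illustrative data only:

    tags_per_source = {"arXiv": ["math.AG", "math.NT"], "Stacks Project": ["stacks"]}
    selected_sources = ["arXiv", "Stacks Project"]
    union_tags = sorted({t for s in selected_sources for t in tags_per_source.get(s, []) if t})
    # -> ['math.AG', 'math.NT', 'stacks']
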
@@ -569,6 +582,7 @@ if model and theorems_data:
                 value=True,
                 help="If unchecked, results with unknown citation counts are excluded."
             )
+
         top_k_results = st.slider("Number of Results to Display:", 1, 20, 5)
 
         filters = {
@@ -587,6 +601,7 @@ if model and theorems_data:
 
     user_query = st.text_input("Enter your query:", "")
     if st.button("Search") or user_query:
-        search_and_display(user_query, model, filters)
+        with st.spinner("Fetching theorems..."):
+            search_and_display(user_query, model, filters)
 else:
     st.error("Could not load the model or data from RDS. Please check your RDS database connection and credentials.")