Sophie committed
Commit a132e72 · 1 Parent(s): 71df2d7

optimized app

Files changed (1)
  1. src/streamlit_app.py +128 -113
src/streamlit_app.py CHANGED
@@ -11,6 +11,7 @@ from pgvector.psycopg2 import register_vector
 import re
 import requests
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from collections import defaultdict
 from dotenv import load_dotenv
 from latex_clean import clean_latex_for_display
 
@@ -179,86 +180,55 @@ def load_papers_from_rds():
     return []
 
 @st.cache_data(ttl=60*60*24) # cache for 24 hours
-def fetch_citations(paper_url: str, title: str) -> int | None:
-    """
-    Returns citation count if found, else None.
-    Tries the following sources in order:
-    1) OpenAlex by arXiv id
-    2) Semantic Scholar by arXiv id
-    3) Semantic Scholar by title
-    """
-    arx_id = None
-    if paper_url:
-        m = ARXIV_ID_RE.search(paper_url)
-        if m:
-            arx_id = m.group(1)
-    # OpenAlex by arXiv id
-    if arx_id:
-        try:
-            r = requests.get(f"https://api.openalex.org/works/arXiv:{arx_id}", timeout=10)
-            if r.ok:
-                data = r.json()
-                c = data.get("cited_by_count")
-                if isinstance(c, int):
-                    return c
-        except Exception:
-            pass
-    # Semantic Scholar by arXiv id
-    if arx_id:
-        try:
-            r = requests.get(
-                f"https://api.semanticscholar.org/graph/v1/paper/arXiv:{arx_id}",
-                params={"fields": "citationCount"},
-                timeout=10
-            )
-            if r.ok:
-                j = r.json()
-                c = j.get("citationCount")
-                if isinstance(c, int):
-                    return c
-        except Exception:
-            pass
-    # Fallback: Semantic Scholar by title
-    if title:
-        try:
-            r = requests.get(
-                "https://api.semanticscholar.org/graph/v1/paper/search",
-                params={"query": title, "limit": 1, "fields": "title,citationCount"},
-                timeout=10
-            )
-            if r.ok:
-                j = r.json()
-                if j.get("data"):
-                    c = j["data"][0].get("citationCount")
-                    if isinstance(c, int):
-                        return c
-        except Exception:
-            pass
-
-    return None
-
-def add_citations(candidates: list[dict], max_workers: int = 6) -> None:
-    # Select targets with missing citations
-    targets = [
-        it for it in candidates
-        if it.get("source") == "arXiv" and (it.get("citations") in (None, 0))
-    ]
-    if not targets:
-        return
-
-    with ThreadPoolExecutor(max_workers=max_workers) as exe:
-        fut2item = {
-            exe.submit(fetch_citations, it.get("paper_url"), it.get("paper_title")): it
-            for it in targets
-        }
-        for fut in as_completed(fut2item):
-            it = fut2item[fut]
-            try:
-                c = fut.result()
-                if c is not None:
-                    it["citations"] = c
-            except Exception:
-                pass
+def load_authors():
+    conn = get_rds_connection()
+    cur = conn.cursor()
+    cur.execute("""
+        SELECT DISTINCT unnest(p.authors) AS author
+        FROM paper p
+        WHERE p.authors IS NOT NULL
+    """)
+    rows = cur.fetchall()
+    cur.close()
+    conn.close()
+    authors = sorted(r[0] for r in rows if r[0])
+    return authors
+
+@st.cache_data(ttl=60*60*24) # cache for 24 hours
+def load_tags_per_source():
+    conn = get_rds_connection()
+    cur = conn.cursor()
+    cur.execute("""
+        SELECT
+            CASE WHEN p.link ILIKE '%%arxiv.org%%'
+                 THEN 'arXiv'
+                 ELSE 'Stacks Project'
+            END AS source,
+            p.primary_category
+        FROM paper p
+        WHERE p.primary_category IS NOT NULL
+    """)
+    rows = cur.fetchall()
+    cur.close()
+    conn.close()
+
+    from collections import defaultdict
+    tags_per_source = defaultdict(set)
+    for source, cat in rows:
+        tags_per_source[source].add(cat)
+
+    # Make them plain lists so Streamlit cache can serialize easily
+    return {src: sorted(cats) for src, cats in tags_per_source.items()}
+
+@st.cache_data(ttl=60*60*24) # cache for 24 hours
+def load_theorem_count():
+    conn = get_rds_connection()
+    cur = conn.cursor()
+    cur.execute("SELECT COUNT(*) FROM theorem;")
+    (n,) = cur.fetchone()
+    cur.close()
+    conn.close()
+    return int(n)
 
 def extract_arxiv_id(s: str) -> str | None:
     """Return normalized arXiv ID if present in s (URL or raw), else None."""
@@ -335,6 +305,8 @@ def search_and_display(query: str, model, filters: dict):
         st.warning("Please select at least one source.")
         return
 
+    citation_weight = float(filters['citation_weight'])
+
     # Encode query to numpy array
     query_vec = model.encode(query or "", normalize_embeddings=True, convert_to_numpy=True)
 
@@ -391,34 +363,73 @@ def search_and_display(query: str, model, filters: dict):
     if pf_clauses:
         where.append("(" + " OR ".join(pf_clauses) + ")")
 
-    # Filter in SQL
+    # Result type
     if filters['types']:
         like_any = [f"%{t}%" for t in filters['types']]
         where.append(" lower(t.name) ILIKE ANY(%s) ")
         params.append(like_any)
 
+    # Citations
+    low, high = filters["citation_range"]
+    include_unknown = filters["include_unknown_citations"]
+
+    if include_unknown:
+        where.append("( (p.citations BETWEEN %s AND %s) OR p.citations IS NULL )")
+    else:
+        where.append("( p.citations IS NOT NULL AND (p.citations BETWEEN %s AND %s) )")
+
+    params.extend([low, high])
+
+    pool_size = max(50, int(filters['top_k']) * 10)
+
     sql = f"""
     WITH latest_slogan AS (
         SELECT DISTINCT ON (ts.theorem_id)
-            ts.theorem_id, ts.slogan_id, ts.slogan, ts.model
+            ts.theorem_id, ts.slogan_id, ts.slogan
        FROM theorem_slogan ts
        ORDER BY ts.theorem_id, ts.slogan_id DESC
+    ),
+    candidates AS (
+        SELECT
+            p.paper_id,
+            p.title,
+            p.authors,
+            p.link,
+            p.last_updated,
+            p.summary,
+            p.journal_ref,
+            p.primary_category,
+            p.categories,
+            p.citations,
+            t.theorem_id,
+            t.name AS theorem_name,
+            t.body AS theorem_body,
+            ls.slogan AS theorem_slogan,
+            (1.0 - (e.embedding <#> %s::vector)) AS similarity
+        FROM paper p
+        JOIN theorem t ON t.paper_id = p.paper_id
+        JOIN latest_slogan ls ON ls.theorem_id = t.theorem_id
+        JOIN {EMBED_TABLE} e ON e.slogan_id = ls.slogan_id
+        {'WHERE ' + ' AND '.join(where) if where else ''}
+        ORDER BY e.embedding <#> %s::vector ASC
+        LIMIT {pool_size}
     )
     SELECT
-        p.paper_id, p.title, p.authors, p.link, p.last_updated, p.summary,
-        p.journal_ref, p.primary_category, p.categories,
-        t.theorem_id, t.name AS theorem_name, t.body AS theorem_body,
-        ls.slogan AS theorem_slogan,
-        (1.0 - (e.embedding <#> %s::vector)) AS similarity
-    FROM paper p
-    JOIN theorem t ON t.paper_id = p.paper_id
-    JOIN latest_slogan ls ON ls.theorem_id = t.theorem_id
-    JOIN {EMBED_TABLE} e ON e.slogan_id = ls.slogan_id
-    {'WHERE ' + ' AND '.join(where) if where else ''}
-    ORDER BY e.embedding <#> %s::vector ASC
-    LIMIT %s
+        *,
+        (
+            similarity +
+            %s *
+            CASE
+                WHEN citations IS NOT NULL AND citations > 0
+                THEN ln(citations::float)
+                ELSE 0
+            END
+        ) AS weighted_score
+    FROM candidates
+    ORDER BY weighted_score DESC, similarity DESC
+    LIMIT %s;
     """
-    exec_params = [query_vec, *params, query_vec, int(filters['top_k'])]
+    exec_params = [query_vec, *params, query_vec, citation_weight, int(filters['top_k'])]
 
     conn = get_rds_connection()
     cur = conn.cursor()
@@ -430,8 +441,8 @@ def search_and_display(query: str, model, filters: dict):
     # Populate result fields
     items = []
     for (paper_id, title, authors, link, last_updated, summary, journal_ref,
-         primary_category, categories, theorem_id, theorem_name, theorem_body,
-         theorem_slogan, similarity) in rows:
+         primary_category, categories, citations, theorem_id, theorem_name,
+         theorem_body, theorem_slogan, similarity, weighted_score) in rows:
 
         # Determine source from url
         link_str = link or ""
@@ -450,19 +461,16 @@ def search_and_display(query: str, model, filters: dict):
             "source": source,
             "type": inferred_type,
             "journal_published": bool(journal_ref),
-            "citations": None,
+            "citations": citations,
             "theorem_name": theorem_name,
             "theorem_slogan": theorem_slogan,
             "theorem_body": theorem_body,
             "similarity": float(similarity),
+            "score": weighted_score
         })
 
-    # Citations
-    if 'arXiv' in filters['sources']:
-        with st.spinner("Fetching citations..."):
-            add_citations(items)
+    # Compute weighted citation score if applicable
     for it in items:
-        # Compute weighted score if applicable
         it["score"] = compute_score(it["similarity"], it.get("citations"), citation_weight)
 
     # Sort results by weighted score, then cosine similarity, then paper id
@@ -514,10 +522,12 @@ st.title("Math Theorem Search")
 st.write("This demo finds mathematical theorems that are semantically similar to your query.")
 
 model = load_model()
-theorems_data = load_papers_from_rds()
+theorem_count = load_theorem_count()
+authors = load_authors()
+tags_per_source = load_tags_per_source()
 
-if model and theorems_data:
-    st.success(f"Successfully loaded {len(theorems_data)} theorems from arXiv and the Stacks Project. Ready to search!")
+if model:
+    st.success(f"Successfully loaded {theorem_count} theorems from arXiv and the Stacks Project. Ready to search!")
     # --- Sidebar filters ---
     with st.sidebar:
         st.header("Search Filters")
@@ -541,20 +551,23 @@ if model and theorems_data:
         if selected_sources:
             st.write("---")
             selected_types = st.multiselect("Filter by Type:", ALLOWED_TYPES)
-            all_authors = sorted(list(set(a for it in theorems_data for a in (it.get('authors') or []))))
-            selected_authors = st.multiselect("Filter by Author(s):", all_authors)
+            selected_authors = st.multiselect("Filter by Author(s):", authors)
 
-            # Tags come from the union of categories per selected source
-            from collections import defaultdict
+            # Tags per selected source(s)
             tags_per_source = defaultdict(set)
-            for it in theorems_data:
-                tags_per_source[it['source']].add(it.get('primary_category'))
-            union_tags = sorted({t for s in selected_sources for t in tags_per_source.get(s, set()) if t})
+            union_tags = sorted({
+                t
+                for s in selected_sources
+                for t in tags_per_source.get(s, [])
+                if t
+            })
             selected_tags = st.multiselect("Filter by Tag/Category:", union_tags)
+
             paper_filter = st.text_input("Filter by Paper",
                 value="",
                 placeholder="e.g., 2401.12345, Finite Hilbert stability",
                 help="Filter by title substring or arXiv ID/URL. Use commas for multiple.")
+
         if 'arXiv' in selected_sources:
             year_range = st.slider("Filter by Year:", 1991, 2025, (1991, 2025))
             journal_status = st.radio("Publication Status:",
@@ -569,6 +582,7 @@ if model and theorems_data:
                 value=True,
                 help="If unchecked, results with unknown citation counts are excluded."
             )
+
             top_k_results = st.slider("Number of Results to Display:", 1, 20, 5)
 
             filters = {
@@ -587,6 +601,7 @@ if model and theorems_data:
 
     user_query = st.text_input("Enter your query:", "")
     if st.button("Search") or user_query:
-        search_and_display(user_query, model, filters)
+        with st.spinner("Fetching theorems..."):
+            search_and_display(user_query, model, filters)
 else:
     st.error("Could not load the model or data from RDS. Please check your RDS database connection and credentials.")
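
A note on the similarity column used throughout the rewritten query: pgvector's <#> operator returns the negative inner product, and since the app encodes with normalize_embeddings=True, the stored vectors are unit length. So 1.0 - (e.embedding <#> query) works out to 1 + cosine, and ORDER BY ... <#> ... ASC ranks by descending cosine similarity. A minimal numpy sketch of that identity (the 384-dim size is an arbitrary placeholder, not the app's actual embedding width):

import numpy as np

e = np.random.randn(384); e /= np.linalg.norm(e)  # stored slogan embedding (unit norm)
q = np.random.randn(384); q /= np.linalg.norm(q)  # encoded query (unit norm)

neg_inner = -float(e @ q)     # what e.embedding <#> q returns in pgvector
similarity = 1.0 - neg_inner  # the SQL's "similarity" alias: 1 + cos(e, q), in [0, 2]
# Ascending order on neg_inner is descending order on cosine similarity.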
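
The weighted_score expression layers a log-damped citation boost on top of that similarity. The diff still overwrites it in Python via compute_score, whose body is not part of this change, so the sketch below assumes compute_score mirrors the SQL CASE expression; that is an assumption, not a copy of the real function:

import math

def weighted_score(similarity: float, citations: int | None, weight: float) -> float:
    # Assumed mirror of the SQL: similarity + weight * ln(citations) when
    # citations > 0, otherwise plain similarity. The log damps the boost, so
    # each 10x jump in citations adds a constant increment to the score.
    if citations is not None and citations > 0:
        return similarity + weight * math.log(float(citations))
    return similarity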
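
Structurally, the new query is a retrieve-then-rerank: the candidates CTE pulls an oversized pool (max(50, 10 * top_k) rows) by pure vector distance, and the outer SELECT re-sorts that pool by weighted_score before the final LIMIT. The same two-stage idea in Python, with illustrative names that are not from the app:

def rerank(pool: list[dict], top_k: int) -> list[dict]:
    # Stage 1 (not shown) fetched `pool` ordered by embedding distance alone.
    # Stage 2 re-sorts by the citation-weighted score, breaking ties on raw
    # similarity, then truncates to the requested number of results.
    pool.sort(key=lambda it: (it["weighted_score"], it["similarity"]), reverse=True)
    return pool[:top_k]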