Sophie committed on
Commit 947c57d · 1 Parent(s): 878f0ee

integrated pgvector; updated SQL calls to reference new papers table; minor refactoring
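In brief: the app now registers the pgvector adapter on each connection and lets Postgres rank theorems by vector distance, instead of scoring stored embeddings client-side with torch. A minimal sketch of that pattern (not part of this diff; the connection details and the random query vector are placeholders, while the table and column names follow the diff below):

    import numpy as np
    import psycopg2
    from pgvector.psycopg2 import register_vector

    # Placeholder connection; the app builds its own via get_rds_connection().
    conn = psycopg2.connect(host="localhost", dbname="papers", user="app", password="secret")
    register_vector(conn)  # adapt numpy arrays to/from the Postgres vector type

    query_vec = np.random.rand(1024).astype(np.float32)  # stand-in; must match the vector column's dimension
    cur = conn.cursor()
    # <#> is pgvector's negative-inner-product operator; smaller is more similar,
    # so ascending order returns the closest embeddings first.
    cur.execute(
        "SELECT slogan_id FROM theorem_embedding_qwen "
        "ORDER BY embedding <#> %s::vector LIMIT 5",
        (query_vec,),
    )
    print(cur.fetchall())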

Files changed (2)
  1. requirements.txt +2 -0
  2. src/streamlit_app.py +215 -195
requirements.txt CHANGED
@@ -1,3 +1,5 @@
+requests
+pgvector
 streamlit==1.39.0
 sentence-transformers>=3.0.0
 numpy
src/streamlit_app.py CHANGED
@@ -1,12 +1,12 @@
 import streamlit as st
 import json
 import numpy as np
-from sentence_transformers import SentenceTransformer, util
+from sentence_transformers import SentenceTransformer
 import os
 import boto3
 import psycopg2
 from psycopg2.extensions import connection
-import torch
+from pgvector.psycopg2 import register_vector
 import re
 import requests
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -16,6 +16,7 @@ from latex_clean import clean_latex_for_display
 # Config
 load_dotenv()
 
+
 def get_rds_connection() -> connection:
     region = os.getenv("AWS_REGION")
     secret_arn = os.getenv("RDS_SECRET_ARN")
@@ -34,8 +35,10 @@ def get_rds_connection() -> connection:
         password=secret_dict["password"],
         sslmode="require",
     )
+    register_vector(conn)
     return conn
 
+
 AVAILABLE_TAGS = {
     "arXiv": [
         "math.AC", "math.AG", "math.AP", "math.AT", "math.CA", "math.CO",
@@ -51,7 +54,7 @@ AVAILABLE_TAGS = {
 }
 
 ALLOWED_TYPES = [
-    "theorem", "lemma", "proposition", "corollary", "definition", "remark", "assumption"
+    "theorem", "lemma", "proposition"
 ]
 
 ARXIV_ID_RE = re.compile(
@@ -59,52 +62,63 @@ ARXIV_ID_RE = re.compile(
     re.IGNORECASE
 )
 
+EMBED_TABLE = "theorem_embedding_qwen"
+
+
 # Load the Embedding Model
 @st.cache_resource
 def load_model():
-    """
-    Loads the specialized math embedding model from Hugging Face.
-    """
     try:
-        model = SentenceTransformer('math-similarity/Bert-MLM_arXiv-MP-class_zbMath')
+        model = SentenceTransformer('Qwen/Qwen3-Embedding-0.6B')
         return model
     except Exception as e:
         st.error(f"Error loading the embedding model: {e}")
         return None
 
+def infer_type(name: str) -> str:
+    if not name:
+        return "theorem"
+    lower = name.lower()
+    for t in ["theorem", "lemma", "proposition"]:
+        if t in lower:
+            return t
+    return "theorem"
+
 # Load Data from RDS
 @st.cache_data
 def load_papers_from_rds():
     """
-    Loads theorem data from the RDS database and prepares it for embedding.
+    Loads the theorem data from the RDS database.
    Returns a list of theorem dictionaries with all necessary fields.
    """
    try:
        conn = get_rds_connection()
        cur = conn.cursor()

-        # Fetch all papers with their theorems and embeddings
+        # Fetch all papers with their theorems
        cur.execute("""
-            SELECT
-                tm.paper_id,
-                tm.title,
-                tm.authors,
-                tm.link,
-                tm.last_updated,
-                tm.summary,
-                tm.journal_ref,
-                tm.primary_category,
-                tm.categories,
-                tm.global_notations,
-                tm.global_definitions,
-                tm.global_assumptions,
-                te.theorem_name,
-                te.theorem_slogan,
-                te.theorem_body,
-                te.embedding
-            FROM theorem_metadata tm
-            JOIN theorem_embedding te ON tm.paper_id = te.paper_id
-            ORDER BY tm.paper_id, te.theorem_name;
+            WITH latest_slogan AS (
+                SELECT DISTINCT ON (ts.theorem_id)
+                    ts.theorem_id, ts.slogan_id, ts.slogan
+                FROM theorem_slogan ts
+                ORDER BY ts.theorem_id, ts.slogan_id DESC
+            )
+            SELECT p.paper_id,
+                   p.title,
+                   p.authors,
+                   p.link,
+                   p.last_updated,
+                   p.summary,
+                   p.journal_ref,
+                   p.primary_category,
+                   p.categories,
+                   t.name AS theorem_name,
+                   ls.slogan AS theorem_slogan,
+                   t.body AS theorem_body
+            FROM paper p
+            JOIN theorem t ON t.paper_id = p.paper_id
+            LEFT JOIN latest_slogan ls ON ls.theorem_id = t.theorem_id
+            ORDER BY p.paper_id, t.name;
        """)

        rows = cur.fetchall()
@@ -115,27 +129,7 @@ def load_papers_from_rds():
        for row in rows:
            (paper_id, title, authors, link, last_updated, summary,
             journal_ref, primary_category, categories,
-             global_notations, global_definitions, global_assumptions,
-             theorem_name, theorem_slogan, theorem_body, embedding) = row
-
-            # Build global context
-            global_context_parts = []
-            if global_notations:
-                global_context_parts.append(f"**Global Notations:**\n{global_notations}")
-            if global_definitions:
-                global_context_parts.append(f"**Global Definitions:**\n{global_definitions}")
-            if global_assumptions:
-                global_context_parts.append(f"**Global Assumptions:**\n{global_assumptions}")
-
-            global_context = "\n\n".join(global_context_parts)
-
-            # Convert embedding to a numpy float array
-            if isinstance(embedding, str):
-                embedding = json.loads(embedding)
-            if isinstance(embedding, list):
-                embedding = np.array(embedding, dtype=np.float32)
-            elif isinstance(embedding, np.ndarray):
-                embedding = embedding.astype(np.float32)
+             theorem_name, theorem_slogan, theorem_body) = row

            # Determine source from url
            link_str = link or ""
@@ -145,15 +139,6 @@ def load_papers_from_rds():
                source = "Stacks Project"

            # Determine type from name
-            def infer_type(name: str) -> str:
-                if not name:
-                    return "theorem"
-                lower = name.lower()
-                for t in ["theorem", "lemma", "proposition", "corollary", "definition", "remark", "assumption"]:
-                    if t in lower:
-                        return t
-                return "theorem"
-
            inferred_type = infer_type(theorem_name or "")

            all_theorems_data.append({
@@ -170,8 +155,6 @@ def load_papers_from_rds():
                "theorem_name": theorem_name,
                "theorem_slogan": theorem_slogan,
                "theorem_body": theorem_body,
-                "global_context": global_context,
-                "stored_embedding": embedding,
            })

        return all_theorems_data
@@ -272,7 +255,7 @@ def extract_arxiv_id(s: str) -> str | None:
 def normalize_title(s: str) -> str:
     return (s or "").casefold().strip()
 
-def parse_paper_filter_input(raw: str) -> dict:
+def parse_paper_filter(raw: str) -> dict:
     """
     Parse user input into two sets: arxiv_ids and title substrings.
     Multiple entries may be comma-separated.
@@ -289,165 +272,197 @@ def parse_paper_filter(raw: str) -> dict:
            titles.add(normalize_title(token))
    return {"ids": ids, "titles": titles}

-def item_matches_paper_filter(item: dict, paper_filter: dict) -> bool:
-    """
-    True if the item matches at least one requested arXiv ID or one title substring.
-    If paper_filter is empty (both sets empty), always True.
-    """
-    ids = paper_filter.get("ids", set())
-    titles = paper_filter.get("titles", set())
-    if not ids and not titles:
-        return True
-
-    # Compare IDs (extract once from url)
-    url = item.get("paper_url") or ""
-    item_id = extract_arxiv_id(url)
-    if item_id and item_id.lower() in ids:
-        return True
-
-    # Compare titles (substring, case-insensitive)
-    t = normalize_title(item.get("paper_title"))
-    if t and any(sub in t for sub in titles):
-        return True
-
-    return False
+def compute_score(similarity: float, citations: int, weight: float) -> float:
+    c = int(citations) if citations is not None else 0
+    if c == 0:
+        return float(similarity)
+    return float(similarity) + float(weight) * np.log(c)

 # --- Search and Display ---
-def search_and_display_with_filters(query, model, theorems_data, embeddings_db, filters):
+def search_and_display(query: str, model, filters: dict):
    if not filters['sources']:
        st.warning("Please select at least one source.")
        return

-    if query:
-        query_embedding = model.encode(query, convert_to_tensor=True)
-        cosine_scores = util.cos_sim(query_embedding, embeddings_db)[0]
-    else:
-        cosine_scores = torch.zeros(len(theorems_data))
-
-    low, high = filters['citation_range']
-
-    # Get a larger pool to filter from
-    top_k_pool = min(200, len(theorems_data))
-    top_indices = torch.topk(cosine_scores, k=top_k_pool, sorted=True).indices
-    top_indices = top_indices.tolist()
-
-    paper_filter = filters.get("paper_filter", {"ids": set(), "titles": set()})
-    matched_indices = []
-    if paper_filter and (paper_filter.get("ids") or paper_filter.get("titles")):
-        for i, it in enumerate(theorems_data):
-            if item_matches_paper_filter(it, paper_filter):
-                matched_indices.append(i)
-
-    pool_indices = list(dict.fromkeys(top_indices + matched_indices))
-    pool = [(i, theorems_data[i]) for i in pool_indices]
-
-    # Fetch citations in parallel
-    if ('arXiv' in filters['sources']):
-        add_citations([it for _, it in pool])
-
-    results = []
-
-    # Filter results
-    for idx, item in pool:
-        type_match = (not filters['types']) or (item.get('type','').lower() in filters['types'])
-        tag_match = (not filters['tags']) or (item.get('primary_category') in filters['tags'])
-        author_match = (not filters['authors']) or any(a in (item.get('authors') or []) for a in filters['authors'])
-        source_match = item.get('source') in filters['sources']
-        paper_match = item_matches_paper_filter(item, filters['paper_filter'])
-
-        # Citations & year & journal only for arXiv
-        citations = item.get('citations')
-        log_cit = np.log1p(int(citations)) if citations is not None else 0.0
-        if citations is None:
-            if not filters['include_unknown_citations']:
-                continue
-            citation_match = True
-        else:
-            citation_match = (low <= int(citations) <= high)
-
-        year_match = True
-        if filters['year_range'] and item.get('source') == 'arXiv':
-            y = item.get('year') or 0
-            yr0, yr1 = filters['year_range']
-            year_match = (yr0 <= y <= yr1)
-
-        journal_match = True
-        if item.get('source') == 'arXiv':
-            status = filters['journal_status']
-            jp = bool(item.get('journal_published'))
-            if status == "Journal Article":
-                journal_match = jp
-            elif status == "Preprint Only":
-                journal_match = not jp
-
-        if all([type_match, tag_match, author_match, source_match, paper_match, citation_match, year_match, journal_match]):
-            # Similarity = cosine_similary + citation_weight * log(citation_count)
-            similarity = float(cosine_scores[idx].item()) + filters['citation_weight'] * log_cit
-            results.append({"idx": idx, "info": item, "similarity": similarity})
-            if len(results) >= filters['top_k']:
-                break
-
-    results.sort(key=lambda r: r["similarity"], reverse=True)
-    results = results[:filters['top_k']]
-
-    st.subheader(f"Found {len(results)} Matching Results")
-    if not results:
+    # Encode query to numpy array
+    query_vec = model.encode(query or "", normalize_embeddings=True, convert_to_numpy=True)
+
+    where = []
+    params = []
+
+    # Source
+    if filters['sources']:
+        src_cases = []
+        if 'arXiv' in filters['sources']:
+            src_cases.append(" (p.link ILIKE '%%arxiv.org%%') ")
+        if 'Stacks Project' in filters['sources']:
+            src_cases.append(" (p.link NOT ILIKE '%%arxiv.org%%') ")
+        if src_cases:
+            where.append("(" + " OR ".join(src_cases) + ")")
+
+    # Authors
+    if filters['authors']:
+        where.append(" p.authors && %s ")
+        params.append(filters['authors'])
+
+    # Tag/category
+    if filters['tags']:
+        where.append(" p.primary_category = ANY(%s) ")
+        params.append(filters['tags'])
+
+    # Year (arXiv only)
+    if filters['year_range']:
+        yr0, yr1 = filters['year_range']
+        where.append("""
+            ( (p.link ILIKE '%%arxiv.org%%' AND EXTRACT(YEAR FROM p.last_updated) BETWEEN %s AND %s)
+              OR (p.link NOT ILIKE '%%arxiv.org%%') )
+        """)
+        params.extend([yr0, yr1])
+
+    # Journal status (arXiv only)
+    if filters['journal_status'] != "All":
+        if filters['journal_status'] == "Journal Article":
+            where.append(" (p.link ILIKE '%%arxiv.org%%' AND p.journal_ref IS NOT NULL) ")
+        elif filters['journal_status'] == "Preprint Only":
+            where.append(" (p.link ILIKE '%%arxiv.org%%' AND p.journal_ref IS NULL) ")
+
+    # Paper filter: arXiv id in link or title substring(s)
+    pf = filters.get("paper_filter", {"ids": set(), "titles": set()})
+    id_patterns = [f"%{i}%" for i in pf.get("ids", set())]
+    title_patterns = [f"%{t}%" for t in pf.get("titles", set())]
+    pf_clauses = []
+    if id_patterns:
+        pf_clauses.append(" p.link ILIKE ANY(%s) ")
+        params.append(id_patterns)
+    if title_patterns:
+        pf_clauses.append(" p.title ILIKE ANY(%s) ")
+        params.append(title_patterns)
+    if pf_clauses:
+        where.append("(" + " OR ".join(pf_clauses) + ")")
+
+    # Filter in SQL
+    if filters['types']:
+        like_any = [f"%{t}%" for t in filters['types']]
+        where.append(" lower(t.name) ILIKE ANY(%s) ")
+        params.append(like_any)
+
+    sql = f"""
+        WITH latest_slogan AS (
+            SELECT DISTINCT ON (ts.theorem_id)
+                ts.theorem_id, ts.slogan_id, ts.slogan, ts.model
+            FROM theorem_slogan ts
+            ORDER BY ts.theorem_id, ts.slogan_id DESC
+        )
+        SELECT
+            p.paper_id, p.title, p.authors, p.link, p.last_updated, p.summary,
+            p.journal_ref, p.primary_category, p.categories,
+            t.theorem_id, t.name AS theorem_name, t.body AS theorem_body,
+            ls.slogan AS theorem_slogan,
+            (1.0 - (e.embedding <#> %s::vector)) AS similarity
+        FROM paper p
+        JOIN theorem t ON t.paper_id = p.paper_id
+        JOIN latest_slogan ls ON ls.theorem_id = t.theorem_id
+        JOIN {EMBED_TABLE} e ON e.slogan_id = ls.slogan_id
+        {'WHERE ' + ' AND '.join(where) if where else ''}
+        ORDER BY e.embedding <#> %s::vector ASC
+        LIMIT %s
+    """
+    exec_params = [query_vec, *params, query_vec, int(filters['top_k'])]
+
+    conn = get_rds_connection()
+    cur = conn.cursor()
+    cur.execute(sql, exec_params)
+    rows = cur.fetchall()
+    cur.close()
+    conn.close()
+
+    # Populate result fields
+    items = []
+    for (paper_id, title, authors, link, last_updated, summary, journal_ref,
+         primary_category, categories, theorem_id, theorem_name, theorem_body,
+         theorem_slogan, similarity) in rows:
+
+        # Determine source from url
+        link_str = link or ""
+        source = "arXiv" if link_str.startswith(
+            ("http://arxiv.org", "https://arxiv.org")) or "arxiv.org" in link_str else "Stacks Project"
+
+        inferred_type = infer_type(theorem_name or "")
+
+        items.append({
+            "paper_id": paper_id,
+            "authors": authors,
+            "paper_title": title,
+            "paper_url": link,
+            "year": last_updated.year,
+            "primary_category": primary_category,
+            "source": source,
+            "type": inferred_type,
+            "journal_published": bool(journal_ref),
+            "citations": None,
+            "theorem_name": theorem_name,
+            "theorem_slogan": theorem_slogan,
+            "theorem_body": theorem_body,
+            "similarity": float(similarity),
+        })
+
+    # Citations
+    if 'arXiv' in filters['sources']:
+        with st.spinner("Fetching citations..."):
+            add_citations(items)
+    for it in items:
+        # Compute weighted score if applicable
+        it["score"] = compute_score(it["similarity"], it.get("citations"), citation_weight)
+
+    # Sort results by weighted score, then cosine similarity, then paper id
+    items.sort(key=lambda x: (x["score"], x["similarity"], str(x.get("paper_id"))), reverse=True)
+
+    # Display results
+    st.subheader(f"Found {len(items)} Matching Results")
+    if not items:
        st.warning("No results found for the current filters.")
        return

-    for i, r in enumerate(results):
-        info = r["info"]
-        expander_title = f"**Result {i+1} | Similarity: {r['similarity']:.4f} | Type: {info.get('type','').title()}**"
+    for i, info in enumerate(items):
+        expander_title = f"**Result {i + 1} | Similarity: {info['score']:.4f} | {info.get('type', '').title()}**"
        with st.expander(expander_title, expanded=True):
-            st.markdown(f"**Paper:** *{info.get('paper_title','Unknown')}*")
+            st.markdown(f"**Paper:** *{info.get('paper_title', 'Unknown')}*")
            st.markdown(f"**Authors:** {', '.join(info.get('authors') or []) or 'N/A'}")
            st.markdown(f"**Source:** {info.get('source')} ({info.get('paper_url')})")
            citations = info.get("citations")
            cit_str = "Unknown" if citations is None else str(citations)
            st.markdown(
-                f"**Math Tag:** `{info.get('primary_category')}` | "
+                f"**Tag:** `{info.get('primary_category')}` | "
                f"**Citations:** {cit_str} | "
                f"**Year:** {info.get('year', 'N/A')}"
            )
-            # Testing only
-            if filters['citation_weight'] > 0:
-                base = float(cosine_scores[r["idx"]].item())
-                log_cit = np.log1p(int(citations)) if citations is not None else 0.0
-                st.caption(
-                    f"base_cosine={base:.4f} | log(citations)={log_cit:.4f} | weight={filters['citation_weight']:.2f}")
            st.markdown("---")
-
            if info.get("theorem_slogan"):
                st.markdown(f"**Slogan:** {info['theorem_slogan']}\n")

-            if info.get("global_context"):
-                cleaned_ctx = clean_latex_for_display(info["global_context"])
-                st.markdown("> " + cleaned_ctx.replace("\n", "\n> ") )
-
            cleaned_content = clean_latex_for_display(info['theorem_body'])
            st.markdown(f"**{info['theorem_name'] or 'Theorem Body.'}**")
            st.markdown(cleaned_content)
-            st.markdown("**Paper ID:**")
-            st.markdown(info['paper_id'])
-
-            # Testing only
-            st.markdown('**Paper ID (testing only)**')
-            st.markdown(info['paper_id'])
+            st.markdown("---")
+            # FOR TESTING ONLY
+            st.caption(f"Paper ID: {info['paper_id']}")
+            if info['citations'] is None or info['citations'] == 0:
+                log = 0
+            else:
+                log = np.log(info['citations'])
+            st.caption(
+                f"base_cosine={info['similarity']:.4f} | log(cit)={log:.4f} | weight={filters['citation_weight']:.2f}")

 # --- Main App Interface ---
 st.set_page_config(page_title="Theorem Search Demo", layout="wide")
-st.title("📚 Semantic Theorem Search")
-st.write("This demo uses a specialized mathematical language model to find theorems semantically similar to your query.")
+st.title("Math Theorem Search")
+st.write("This demo finds mathematical theorems that are semantically similar to your query.")
 
 model = load_model()
 theorems_data = load_papers_from_rds()
 
 if model and theorems_data:
-    with st.spinner("Preparing embeddings from database..."):
-        corpus_embeddings = np.array([item['stored_embedding'] for item in theorems_data])
-
    st.success(f"Successfully loaded {len(theorems_data)} theorems from arXiv and the Stacks Project. Ready to search!")
-
    # --- Sidebar filters ---
    with st.sidebar:
        st.header("Search Filters")
@@ -461,6 +476,7 @@ if model and theorems_data:
        )

        selected_authors, selected_types, selected_tags = [], [], []
+        paper_filter = ""
        year_range, journal_status = None, "All"
        citation_range = (0, 1000)
        citation_weight = 0.0
@@ -479,16 +495,20 @@ if model and theorems_data:
        for it in theorems_data:
            tags_per_source[it['source']].add(it.get('primary_category'))
        union_tags = sorted({t for s in selected_sources for t in tags_per_source.get(s, set()) if t})
-        selected_tags = st.multiselect("Filter by Math Tag/Category:", union_tags)
-        paper_filter_raw = st.text_input("Filter by Paper",
+        selected_tags = st.multiselect("Filter by Tag/Category:", union_tags)
+        paper_filter = st.text_input("Filter by Paper",
                                     value="",
                                     placeholder="e.g., 2401.12345, Finite Hilbert stability",
                                     help="Filter by title substring or arXiv ID/URL. Use commas for multiple.")
        if 'arXiv' in selected_sources:
            year_range = st.slider("Filter by Year:", 1991, 2025, (1991, 2025))
-            journal_status = st.radio("Publication Status:", ["All", "Journal Article", "Preprint Only"], horizontal=True)
-            citation_range = st.slider("Filter by Citations:", 0, 1000, (0, 1000))
-            citation_weight = st.slider("Citation Weight:", 0.0, 1.0, 0.0, step=0.01)
+            journal_status = st.radio("Publication Status:",
+                                      ["All", "Journal Article", "Preprint Only"],
+                                      horizontal=True)
+            citation_range = st.slider("Filter by Citations:", 0, 1000, 1000, step=10)
+            citation_weight = st.slider("Citation Weight:", 0.0, 1.0, 0.0, step=0.01,
+                                        help="If nonzero, results are ranked by base_score $+$ weight $\\times$ "
+                                             "$\\log($citations$)$.")
        include_unknown_citations = st.checkbox(
            "Include entries with unknown citation counts",
            value=True,
@@ -501,7 +521,7 @@ if model and theorems_data:
        "types": [t.lower() for t in selected_types],
        "tags": selected_tags,
        "sources": selected_sources,
-        "paper_filter": parse_paper_filter_input(paper_filter_raw),
+        "paper_filter": parse_paper_filter(paper_filter),
        "year_range": year_range,
        "journal_status": journal_status,
        "citation_range": citation_range,
@@ -512,6 +532,6 @@ if model and theorems_data:
 
     user_query = st.text_input("Enter your query:", "")
     if st.button("Search") or user_query:
-        search_and_display_with_filters(user_query, model, theorems_data, corpus_embeddings, filters)
+        search_and_display(user_query, model, filters)
 else:
-    st.error("Could not load the model or data from RDS. Please check your RDS database connection and credentials.")
+    st.error("Could not load the model or data from RDS. Please check your RDS database connection and credentials.")