theorem-search / src /streamlit_app.py
Sophie
added filtering by paper; ranking now uses cosine similarity and citation counts
934ea0e
raw
history blame
19.8 kB
import streamlit as st
import json
import numpy as np
from sentence_transformers import SentenceTransformer, util
import os
import boto3
import psycopg2
from psycopg2.extensions import connection
import torch
import re
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
from dotenv import load_dotenv
from latex_clean import clean_latex_for_display
# Config
load_dotenv()
def get_rds_connection() -> connection:
region = os.getenv("AWS_REGION")
secret_arn = os.getenv("RDS_SECRET_ARN")
host = os.getenv("RDS_HOST")
dbname = os.getenv("RDS_DB_NAME")
sm = boto3.client("secretsmanager", region_name=region)
secret_value = sm.get_secret_value(SecretId=secret_arn)
secret_dict = json.loads(secret_value["SecretString"])
conn = psycopg2.connect(
host=host or secret_dict.get("host"),
port=int(secret_dict.get("port", 5432)),
dbname=dbname or secret_dict.get("dbname"),
user=secret_dict["username"],
password=secret_dict["password"],
sslmode="require",
)
return conn
AVAILABLE_TAGS = {
"arXiv": [
"math.AC", "math.AG", "math.AP", "math.AT", "math.CA", "math.CO",
"math.CT", "math.CV", "math.DG", "math.DS", "math.FA", "math.GM",
"math.GN", "math.GR", "math.GT", "math.HO", "math.IT", "math.KT",
"math.LO", "math.MG", "math.MP", "math.NA", "math.NT", "math.OA",
"math.OC", "math.PR", "math.QA", "math.RA", "math.RT", "math.SG",
"math.SP", "math.ST", "Statistics Theory"
],
"Stacks Project": [
"Sets", "Schemes", "Algebraic Stacks", "Étale Cohomology"
]
}
ALLOWED_TYPES = [
"theorem", "lemma", "proposition", "corollary", "definition", "remark", "assumption"
]
ARXIV_ID_RE = re.compile(
r'(?:arxiv\.org/(?:abs|pdf)/)?((?:\d{4}\.\d{4,5}|[a-z\-]+/\d{7}))',
re.IGNORECASE
)
# Load the Embedding Model
@st.cache_resource
def load_model():
"""
Loads the specialized math embedding model from Hugging Face.
"""
try:
model = SentenceTransformer('math-similarity/Bert-MLM_arXiv-MP-class_zbMath')
return model
except Exception as e:
st.error(f"Error loading the embedding model: {e}")
return None
# Load Data from RDS
@st.cache_data
def load_papers_from_rds():
"""
Loads theorem data from the RDS database and prepares it for embedding.
Returns a list of theorem dictionaries with all necessary fields.
"""
try:
conn = get_rds_connection()
cur = conn.cursor()
# Fetch all papers with their theorems and embeddings
cur.execute("""
SELECT
tm.paper_id,
tm.title,
tm.authors,
tm.link,
tm.last_updated,
tm.summary,
tm.journal_ref,
tm.primary_category,
tm.categories,
tm.global_notations,
tm.global_definitions,
tm.global_assumptions,
te.theorem_name,
te.theorem_slogan,
te.theorem_body,
te.embedding
FROM theorem_metadata tm
JOIN theorem_embedding te ON tm.paper_id = te.paper_id
ORDER BY tm.paper_id, te.theorem_name;
""")
rows = cur.fetchall()
cur.close()
conn.close()
all_theorems_data = []
for row in rows:
(paper_id, title, authors, link, last_updated, summary,
journal_ref, primary_category, categories,
global_notations, global_definitions, global_assumptions,
theorem_name, theorem_slogan, theorem_body, embedding) = row
# Build global context
global_context_parts = []
if global_notations:
global_context_parts.append(f"**Global Notations:**\n{global_notations}")
if global_definitions:
global_context_parts.append(f"**Global Definitions:**\n{global_definitions}")
if global_assumptions:
global_context_parts.append(f"**Global Assumptions:**\n{global_assumptions}")
global_context = "\n\n".join(global_context_parts)
# Convert embedding to a numpy float array
if isinstance(embedding, str):
embedding = json.loads(embedding)
if isinstance(embedding, list):
embedding = np.array(embedding, dtype=np.float32)
elif isinstance(embedding, np.ndarray):
embedding = embedding.astype(np.float32)
# Determine source from url
link_str = link or ""
if link_str.startswith("http://arxiv.org") or link_str.startswith("https://arxiv.org"):
source = "arXiv"
else:
source = "Stacks Project"
# Determine type from name
def infer_type(name: str) -> str:
if not name:
return "theorem"
lower = name.lower()
for t in ["theorem", "lemma", "proposition", "corollary", "definition", "remark", "assumption"]:
if t in lower:
return t
return "theorem"
inferred_type = infer_type(theorem_name or "")
all_theorems_data.append({
"paper_id": paper_id,
"authors": authors,
"paper_title": title,
"paper_url": link,
"year": last_updated.year,
"primary_category": primary_category,
"source": source,
"type": inferred_type,
"journal_published": bool(journal_ref),
"citations": None,
"theorem_name": theorem_name,
"theorem_slogan": theorem_slogan,
"theorem_body": theorem_body,
"global_context": global_context,
"stored_embedding": embedding,
})
return all_theorems_data
except Exception as e:
st.error(f"Error loading data from RDS: {e}")
return []
@st.cache_data(ttl=60*60*24) # cache for 24 hours
def fetch_citations(paper_url: str, title: str) -> int | None:
"""
Returns citation count if found, else None.
Tries the following sources in order:
1) OpenAlex by arXiv id
2) Semantic Scholar by arXiv id
3) Semantic Scholar by title
"""
arx_id = None
if paper_url:
m = ARXIV_ID_RE.search(paper_url)
if m:
arx_id = m.group(1)
# OpenAlex by arXiv id
if arx_id:
try:
r = requests.get(f"https://api.openalex.org/works/arXiv:{arx_id}", timeout=10)
if r.ok:
data = r.json()
c = data.get("cited_by_count")
if isinstance(c, int):
return c
except Exception:
pass
# Semantic Scholar by arXiv id
if arx_id:
try:
r = requests.get(
f"https://api.semanticscholar.org/graph/v1/paper/arXiv:{arx_id}",
params={"fields": "citationCount"},
timeout=10
)
if r.ok:
j = r.json()
c = j.get("citationCount")
if isinstance(c, int):
return c
except Exception:
pass
# Fallback: Semantic Scholar by title
if title:
try:
r = requests.get(
"https://api.semanticscholar.org/graph/v1/paper/search",
params={"query": title, "limit": 1, "fields": "title,citationCount"},
timeout=10
)
if r.ok:
j = r.json()
if j.get("data"):
c = j["data"][0].get("citationCount")
if isinstance(c, int):
return c
except Exception:
pass
return None
def add_citations(candidates: list[dict], max_workers: int = 6) -> None:
# Select targets with missing citations
targets = [
it for it in candidates
if it.get("source") == "arXiv" and (it.get("citations") in (None, 0))
]
if not targets:
return
with ThreadPoolExecutor(max_workers=max_workers) as exe:
fut2item = {
exe.submit(fetch_citations, it.get("paper_url"), it.get("paper_title")): it
for it in targets
}
for fut in as_completed(fut2item):
it = fut2item[fut]
try:
c = fut.result()
if c is not None:
it["citations"] = c
except Exception:
pass
def extract_arxiv_id(s: str) -> str | None:
"""Return normalized arXiv ID if present in s (URL or raw), else None."""
if not s:
return None
m = ARXIV_ID_RE.search(s.strip())
return m.group(1) if m else None
def normalize_title(s: str) -> str:
return (s or "").casefold().strip()
def parse_paper_filter_input(raw: str) -> dict:
"""
Parse user input into two sets: arxiv_ids and title substrings.
Multiple entries may be comma-separated.
e.g. "2401.12345, Optimal Transport" -> {"ids":{"2401.12345"}, "titles":{"optimal transport"}}
"""
ids, titles = set(), set()
if not raw:
return {"ids": ids, "titles": titles}
for token in [t.strip() for t in raw.split(",") if t.strip()]:
arx = extract_arxiv_id(token)
if arx:
ids.add(arx.lower())
else:
titles.add(normalize_title(token))
return {"ids": ids, "titles": titles}
def item_matches_paper_filter(item: dict, paper_filter: dict) -> bool:
"""
True if the item matches at least one requested arXiv ID or one title substring.
If paper_filter is empty (both sets empty), always True.
"""
ids = paper_filter.get("ids", set())
titles = paper_filter.get("titles", set())
if not ids and not titles:
return True
# Compare IDs (extract once from url)
url = item.get("paper_url") or ""
item_id = extract_arxiv_id(url)
if item_id and item_id.lower() in ids:
return True
# Compare titles (substring, case-insensitive)
t = normalize_title(item.get("paper_title"))
if t and any(sub in t for sub in titles):
return True
return False
# --- Search and Display ---
def search_and_display_with_filters(query, model, theorems_data, embeddings_db, filters):
if not filters['sources']:
st.warning("Please select at least one source.")
return
if query:
query_embedding = model.encode(query, convert_to_tensor=True)
cosine_scores = util.cos_sim(query_embedding, embeddings_db)[0]
else:
cosine_scores = torch.zeros(len(theorems_data))
low, high = filters['citation_range']
# Get a larger pool to filter from
top_k_pool = min(200, len(theorems_data))
top_indices = torch.topk(cosine_scores, k=top_k_pool, sorted=True).indices
top_indices = top_indices.tolist()
paper_filter = filters.get("paper_filter", {"ids": set(), "titles": set()})
matched_indices = []
if paper_filter and (paper_filter.get("ids") or paper_filter.get("titles")):
for i, it in enumerate(theorems_data):
if item_matches_paper_filter(it, paper_filter):
matched_indices.append(i)
pool_indices = list(dict.fromkeys(top_indices + matched_indices))
pool = [(i, theorems_data[i]) for i in pool_indices]
# Fetch citations in parallel
if ('arXiv' in filters['sources']):
add_citations([it for _, it in pool])
results = []
# Filter results
for idx, item in pool:
type_match = (not filters['types']) or (item.get('type','').lower() in filters['types'])
tag_match = (not filters['tags']) or (item.get('primary_category') in filters['tags'])
author_match = (not filters['authors']) or any(a in (item.get('authors') or []) for a in filters['authors'])
source_match = item.get('source') in filters['sources']
paper_match = item_matches_paper_filter(item, filters['paper_filter'])
# Citations & year & journal only for arXiv
citations = item.get('citations')
log_cit = np.log1p(int(citations)) if citations is not None else 0.0
if citations is None:
if not filters['include_unknown_citations']:
continue
citation_match = True
else:
citation_match = (low <= int(citations) <= high)
year_match = True
if filters['year_range'] and item.get('source') == 'arXiv':
y = item.get('year') or 0
yr0, yr1 = filters['year_range']
year_match = (yr0 <= y <= yr1)
journal_match = True
if item.get('source') == 'arXiv':
status = filters['journal_status']
jp = bool(item.get('journal_published'))
if status == "Journal Article":
journal_match = jp
elif status == "Preprint Only":
journal_match = not jp
if all([type_match, tag_match, author_match, source_match, paper_match, citation_match, year_match, journal_match]):
# Similarity = cosine_similary + citation_weight * log(citation_count)
similarity = float(cosine_scores[idx].item()) + filters['citation_weight'] * log_cit
results.append({"idx": idx, "info": item, "similarity": similarity})
if len(results) >= filters['top_k']:
break
results.sort(key=lambda r: r["similarity"], reverse=True)
results = results[:filters['top_k']]
st.subheader(f"Found {len(results)} Matching Results")
if not results:
st.warning("No results found for the current filters.")
return
for i, r in enumerate(results):
info = r["info"]
expander_title = f"**Result {i+1} | Similarity: {r['similarity']:.4f} | Type: {info.get('type','').title()}**"
with st.expander(expander_title, expanded=True):
st.markdown(f"**Paper:** *{info.get('paper_title','Unknown')}*")
st.markdown(f"**Authors:** {', '.join(info.get('authors') or []) or 'N/A'}")
st.markdown(f"**Source:** {info.get('source')} ({info.get('paper_url')})")
citations = info.get("citations")
cit_str = "Unknown" if citations is None else str(citations)
st.markdown(
f"**Math Tag:** `{info.get('primary_category')}` | "
f"**Citations:** {cit_str} | "
f"**Year:** {info.get('year', 'N/A')}"
)
# Testing only
if filters['citation_weight'] > 0:
base = float(cosine_scores[r["idx"]].item())
log_cit = np.log1p(int(citations)) if citations is not None else 0.0
st.caption(
f"base_cosine={base:.4f} | log(citations)={log_cit:.4f} | weight={filters['citation_weight']:.2f}")
st.markdown("---")
if info.get("theorem_slogan"):
st.markdown(f"**Slogan:** {info['theorem_slogan']}\n")
if info.get("global_context"):
cleaned_ctx = clean_latex_for_display(info["global_context"])
st.markdown("> " + cleaned_ctx.replace("\n", "\n> ") )
cleaned_content = clean_latex_for_display(info['theorem_body'])
st.markdown(f"**{info['theorem_name'] or 'Theorem Body.'}**")
st.markdown(cleaned_content)
# Testing only
st.markdown('**Paper ID (testing only)**')
st.markdown(info['paper_id'])
# --- Main App Interface ---
st.set_page_config(page_title="Theorem Search Demo", layout="wide")
st.title("📚 Semantic Theorem Search")
st.write("This demo uses a specialized mathematical language model to find theorems semantically similar to your query.")
model = load_model()
theorems_data = load_papers_from_rds()
if model and theorems_data:
with st.spinner("Preparing embeddings from database..."):
corpus_embeddings = np.array([item['stored_embedding'] for item in theorems_data])
st.success(f"Successfully loaded {len(theorems_data)} theorems from arXiv and the Stacks Project. Ready to search!")
# --- Sidebar filters ---
with st.sidebar:
st.header("Search Filters")
all_sources = ['arXiv', 'Stacks Project']
selected_sources = st.multiselect(
"Filter by Source(s):",
all_sources,
default=all_sources[:1] if all_sources else [],
help="Select one or more sources to reveal more filters."
)
selected_authors, selected_types, selected_tags = [], [], []
year_range, journal_status = None, "All"
citation_range = (0, 1000)
citation_weight = 0.0
include_unknown_citations = True
top_k_results = 5
if selected_sources:
st.write("---")
selected_types = st.multiselect("Filter by Type:", ALLOWED_TYPES)
all_authors = sorted(list(set(a for it in theorems_data for a in (it.get('authors') or []))))
selected_authors = st.multiselect("Filter by Author(s):", all_authors)
# Tags come from the union of categories per selected source
from collections import defaultdict
tags_per_source = defaultdict(set)
for it in theorems_data:
tags_per_source[it['source']].add(it.get('primary_category'))
union_tags = sorted({t for s in selected_sources for t in tags_per_source.get(s, set()) if t})
selected_tags = st.multiselect("Filter by Math Tag/Category:", union_tags)
paper_filter_raw = st.text_input("Filter by Paper",
value="",
placeholder="e.g., 2401.12345, Finite Hilbert stability",
help="Filter by title substring or arXiv ID/URL. Use commas for multiple.")
if 'arXiv' in selected_sources:
year_range = st.slider("Filter by Year:", 1991, 2025, (1991, 2025))
journal_status = st.radio("Publication Status:", ["All", "Journal Article", "Preprint Only"], horizontal=True)
citation_range = st.slider("Filter by Citations:", 0, 1000, (0, 1000))
citation_weight = st.slider("Citation Weight:", 0.0, 1.0, 0.0, step=0.01)
include_unknown_citations = st.checkbox(
"Include entries with unknown citation counts",
value=True,
help="If unchecked, results with unknown citation counts are excluded."
)
top_k_results = st.slider("Number of Results to Display:", 1, 20, 5)
filters = {
"authors": selected_authors,
"types": [t.lower() for t in selected_types],
"tags": selected_tags,
"sources": selected_sources,
"paper_filter": parse_paper_filter_input(paper_filter_raw),
"year_range": year_range,
"journal_status": journal_status,
"citation_range": citation_range,
"citation_weight": citation_weight,
"include_unknown_citations": include_unknown_citations,
"top_k": top_k_results,
}
user_query = st.text_input("Enter your query:", "")
if st.button("Search") or user_query:
search_and_display_with_filters(user_query, model, theorems_data, corpus_embeddings, filters)
else:
st.error("Could not load the model or data from RDS. Please check your RDS database connection and credentials.")