theorem-search / src /streamlit_app.py
Sophie
minor fixes
6de8c39
raw
history blame
8.26 kB
import streamlit as st
import json
import numpy as np
from sentence_transformers import SentenceTransformer, util
import os
import boto3
import psycopg2
from psycopg2.extensions import connection
from dotenv import load_dotenv
from latex_clean import clean_latex_for_display
# Config
# Populate os.environ from a .env file (python-dotenv) so the os.getenv calls
# below can see the AWS / RDS settings.
load_dotenv()


def get_rds_connection() -> connection:
    """Open and return a new psycopg2 connection to the RDS database.

    Username and password always come from the AWS Secrets Manager secret whose
    ARN is in ``RDS_SECRET_ARN`` (client created for region ``AWS_REGION``).
    Host and database name come from ``RDS_HOST`` / ``RDS_DB_NAME`` when those
    are set and non-empty, otherwise from the secret's ``host`` / ``dbname``
    fields. The port comes from the secret, defaulting to 5432. TLS is required
    (``sslmode="require"``). The caller is responsible for closing the result.
    """
    region = os.getenv("AWS_REGION")
    secret_arn = os.getenv("RDS_SECRET_ARN")
    host_override = os.getenv("RDS_HOST")
    dbname_override = os.getenv("RDS_DB_NAME")

    secrets_client = boto3.client("secretsmanager", region_name=region)
    response = secrets_client.get_secret_value(SecretId=secret_arn)
    creds = json.loads(response["SecretString"])

    return psycopg2.connect(
        host=host_override or creds.get("host"),
        port=int(creds.get("port", 5432)),
        dbname=dbname_override or creds.get("dbname"),
        user=creds["username"],
        password=creds["password"],
        sslmode="require",
    )
# Known tags per source collection (not referenced by the code visible in this
# file). The arXiv entries are arXiv math subject-class codes.
# NOTE(review): "Statistics Theory" is a human-readable name (arXiv's title for
# math.ST, which is already listed), not a code like its neighbours; confirm
# whether it is intentional.
AVAILABLE_TAGS = {
    "arXiv": [
        "math.AC", "math.AG", "math.AP", "math.AT", "math.CA", "math.CO", "math.CT", "math.CV",
        "math.DG", "math.DS", "math.FA", "math.GM", "math.GN", "math.GR", "math.GT", "math.HO",
        "math.IT", "math.KT", "math.LO", "math.MG", "math.MP", "math.NA", "math.NT", "math.OA",
        "math.OC", "math.PR", "math.QA", "math.RA", "math.RT", "math.SG", "math.SP", "math.ST",
        "Statistics Theory",
    ],
    "Stacks Project": ["Sets", "Schemes", "Algebraic Stacks", "Étale Cohomology"],
}

# Lower-case statement kinds (presumably used to filter/validate theorem types
# elsewhere; not referenced in the code visible here).
ALLOWED_TYPES = [
    "theorem",
    "lemma",
    "proposition",
    "corollary",
    "definition",
    "remark",
    "assumption",
]
# Load the Embedding Model
# Hugging Face id of the math-specialised sentence-embedding model.
_MODEL_NAME = 'math-similarity/Bert-MLM_arXiv-MP-class_zbMath'


@st.cache_resource
def _load_model_cached():
    """Build and cache the SentenceTransformer; raises if it cannot be loaded.

    Caching lives here, not on ``load_model``, so that only successful loads are
    cached. Streamlit stores nothing when a cached function raises, so a
    transient download failure is retried on the next rerun. Previously a
    cached ``None`` would have disabled search until the process restarted.
    """
    return SentenceTransformer(_MODEL_NAME)


def load_model():
    """
    Loads the specialized math embedding model from Hugging Face.

    Returns:
        The (process-wide cached) SentenceTransformer, or None if loading failed;
        in that case the error is shown in the Streamlit page.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        # UI boundary: report and let the caller show its "could not load" path.
        st.error(f"Error loading the embedding model: {e}")
        return None


# Keep the ``.clear()`` hook callers had when load_model itself was cached.
load_model.clear = _load_model_cached.clear
# Load Data from RDS
# One row per theorem, joined with its paper's metadata. Only the columns that
# end up in the theorem dicts are selected. Column order must match the
# unpacking in _row_to_theorem.
_THEOREMS_QUERY = """
    SELECT
        tm.paper_id,
        tm.title,
        tm.authors,
        tm.link,
        tm.last_updated,
        tm.primary_category,
        tm.global_notations,
        tm.global_definitions,
        tm.global_assumptions,
        te.theorem_name,
        te.theorem_slogan,
        te.theorem_body,
        te.embedding
    FROM theorem_metadata tm
    JOIN theorem_embedding te ON tm.paper_id = te.paper_id
    ORDER BY tm.paper_id, te.theorem_name;
"""


def _build_global_context(notations, definitions, assumptions):
    """Join the paper-level notations/definitions/assumptions into one markdown string.

    Empty or missing parts are skipped; returns "" when all three are empty.
    """
    sections = (
        ("Global Notations", notations),
        ("Global Definitions", definitions),
        ("Global Assumptions", assumptions),
    )
    return "\n\n".join(f"**{label}:**\n{text}" for label, text in sections if text)


def _to_float32_array(embedding):
    """Convert a stored embedding to a float32 numpy vector; None stays None.

    Accepts JSON text such as "[0.1, 0.2]", a list/tuple of numbers, or a numpy
    array. Anything that numpy cannot convert to float32 raises.
    """
    if embedding is None:
        return None
    if isinstance(embedding, str):
        embedding = json.loads(embedding)
    return np.asarray(embedding, dtype=np.float32)


def _row_to_theorem(row):
    """Map one _THEOREMS_QUERY result row to the theorem dict used by the search UI."""
    (paper_id, title, authors, link, last_updated, primary_category,
     global_notations, global_definitions, global_assumptions,
     theorem_name, theorem_slogan, theorem_body, embedding) = row
    return {
        "paper_id": paper_id,
        "authors": authors,
        "paper_title": title,
        "paper_url": link,
        # Tolerate a NULL last_updated instead of failing the whole load.
        "year": last_updated.year if last_updated is not None else None,
        "primary_category": primary_category,
        "theorem_name": theorem_name,
        "theorem_slogan": theorem_slogan,
        "theorem_body": theorem_body,
        "global_context": _build_global_context(
            global_notations, global_definitions, global_assumptions),
        "stored_embedding": _to_float32_array(embedding),
    }


@st.cache_data
def _fetch_theorems_from_rds():
    """Query RDS and return the list of theorem dicts; raises on any failure.

    Caching lives here, not on ``load_papers_from_rds``, so that only successful
    results are cached. Streamlit stores nothing when a cached function raises,
    so a transient database error is retried on the next rerun instead of a
    cached empty list being served.
    """
    conn = get_rds_connection()
    try:
        with conn.cursor() as cur:  # the cursor context manager closes the cursor
            cur.execute(_THEOREMS_QUERY)
            rows = cur.fetchall()
    finally:
        # Close even if the query fails. psycopg2's `with conn:` only ends the
        # transaction, so the close must be explicit.
        conn.close()

    theorems = []
    for row in rows:
        theorem = _row_to_theorem(row)
        # A theorem without an embedding cannot be ranked, and a None would break
        # stacking all embeddings into one array, so it is dropped here.
        if theorem["stored_embedding"] is not None:
            theorems.append(theorem)
    return theorems


def load_papers_from_rds():
    """
    Loads theorem data from the RDS database and prepares it for embedding.

    Returns:
        A list with one dict per theorem. Keys: paper_id, authors, paper_title,
        paper_url, year (None if unknown), primary_category, theorem_name,
        theorem_slogan, theorem_body, global_context (markdown, possibly "") and
        stored_embedding (float32 numpy vector). Theorems with no stored
        embedding are omitted. On failure the error is shown in the Streamlit
        page and an empty list is returned.
    """
    try:
        return _fetch_theorems_from_rds()
    except Exception as e:
        st.error(f"Error loading data from RDS: {e}")
        return []


# Keep the ``.clear()`` hook callers had when this function itself was cached.
load_papers_from_rds.clear = _fetch_theorems_from_rds.clear
# --- 3. The Search Function ---
def _render_theorem_result(rank, similarity, theorem_info):
    """Render one search hit (1-based ``rank``) as a collapsible Streamlit expander."""
    expander_title = f"**Result {rank} | Similarity: {similarity:.4f}**"
    if theorem_info.get("theorem_name"):
        expander_title += f" | {theorem_info['theorem_name']}"

    authors = theorem_info.get("authors")
    if not authors:
        authors_text = "N/A"
    elif isinstance(authors, str):
        # A single string would otherwise be joined character by character.
        authors_text = authors
    else:
        authors_text = ", ".join(authors)

    year = theorem_info.get("year")
    year_text = "N/A" if year is None else year

    with st.expander(expander_title):
        st.markdown(f"**Paper:** {theorem_info.get('paper_title', 'Unknown')}")
        st.markdown(f"**Authors:** {authors_text}")
        st.markdown(f"**Source:** [{theorem_info['paper_url']}]({theorem_info['paper_url']})")
        st.markdown(
            f"**Math Tag:** `{theorem_info['primary_category']}` | **Year:** {year_text}")
        st.markdown("---")
        if theorem_info.get("theorem_slogan"):
            st.markdown(f"**Slogan:** {theorem_info['theorem_slogan']}")
            st.write("")
        if theorem_info["global_context"]:
            cleaned_ctx = clean_latex_for_display(theorem_info["global_context"])
            # Prefix every line so the whole context renders as one markdown blockquote.
            st.markdown("> " + cleaned_ctx.replace("\n", "\n> "))
            st.write("")
        st.markdown("**Theorem Body:**")
        st.markdown(clean_latex_for_display(theorem_info['theorem_body']))


def search_theorems(query, model, theorems_data, embeddings_db, top_k=10):
    """
    Embed ``query`` and render the ``top_k`` most similar theorems.

    Args:
        query: Free-text search string; empty or whitespace-only shows a hint.
        model: SentenceTransformer used to embed the query.
        theorems_data: Theorem dicts; entry i must describe row i of ``embeddings_db``.
        embeddings_db: Stored theorem embeddings, one row per theorem.
        top_k: Maximum number of results to display (default 10).

    Renders directly into the Streamlit page and returns None.
    """
    if not query or not query.strip():
        st.info("Please enter a search query.")
        return

    # Encode to numpy (the encode default) rather than a tensor, so the query and
    # the corpus both become CPU tensors inside util.cos_sim. A GPU query tensor
    # would not match the CPU corpus.
    query_embedding = model.encode(query)
    cosine_scores = util.cos_sim(query_embedding, embeddings_db)[0].cpu().numpy()
    top_results_indices = np.argsort(-cosine_scores)[:top_k]

    if len(top_results_indices) == 0:
        st.write("No results found.")
        return

    st.subheader(f"Top {len(top_results_indices)} Most Similar Theorems")
    for rank, idx in enumerate(top_results_indices, start=1):
        idx = int(idx)
        _render_theorem_result(rank, float(cosine_scores[idx]), theorems_data[idx])
# --- Main App Interface ---
# Page chrome: configuration, title and an explanatory note about arXiv links.
st.set_page_config(page_title="Theorem Search Demo", layout="wide")
st.title("📚 Semantic Theorem Search")
st.write("This demo uses a specialized mathematical language model to find theorems semantically similar to your query.")
st.markdown(
    "*Note: Linking to a specific page within an arXiv PDF is not directly possible.*",
    help="arXiv links redirect to the paper's abstract, not a specific page in the PDF.",
)

# Both loaders are cached by Streamlit, so reruns reuse their results.
model = load_model()
theorems_data = load_papers_from_rds()

if not (model and theorems_data):
    st.error("Could not load the model or data from RDS. Please check your RDS database connection and credentials.")
else:
    with st.spinner("Preparing embeddings from database..."):
        # Stack the per-theorem vectors into one matrix; row i matches theorems_data[i].
        stored_vectors = [item['stored_embedding'] for item in theorems_data]
        corpus_embeddings = np.array(stored_vectors)
    loaded_message = f"Successfully loaded {len(theorems_data)} theorems from arXiv. Ready to search!"
    st.success(loaded_message)

    user_query = st.text_input("Enter your query:", "")
    # Search on an explicit click, or whenever the text box is non-empty.
    search_requested = st.button("Search")
    if search_requested or user_query:
        search_theorems(user_query, model, theorems_data, corpus_embeddings)