slszeto commited on
Commit
c6bcc0c
·
verified ·
1 Parent(s): f1767a2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +218 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,220 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import json
3
+ import numpy as np
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import os
6
+ import re
7
+ import boto3
8
+ import psycopg2
9
+ from psycopg2.extensions import connection
10
+ from dotenv import load_dotenv
11
+
12
+ # --- 0. Config ---
13
+ load_dotenv()
14
+
15
+ def get_rds_connection() -> connection:
16
+ region = os.getenv("AWS_REGION")
17
+ secret_arn = os.getenv("RDS_SECRET_ARN")
18
+ host = os.getenv("RDS_HOST")
19
+ dbname = os.getenv("RDS_DB_NAME")
20
+
21
+ sm = boto3.client("secretsmanager", region_name=region)
22
+ secret_value = sm.get_secret_value(SecretId=secret_arn)
23
+ secret_dict = json.loads(secret_value["SecretString"])
24
+
25
+ conn = psycopg2.connect(
26
+ host=host or secret_dict.get("host"),
27
+ port=int(secret_dict.get("port", 5432)),
28
+ dbname=dbname or secret_dict.get("dbname"),
29
+ user=secret_dict["username"],
30
+ password=secret_dict["password"],
31
+ sslmode="require",
32
+ )
33
+ return conn
34
+
35
+ # --- 1. Load the Embedding Model ---
36
+ @st.cache_resource
37
+ def load_model():
38
+ """
39
+ Loads the specialized math embedding model from Hugging Face.
40
+ """
41
+ try:
42
+ model = SentenceTransformer('math-similarity/Bert-MLM_arXiv-MP-class_zbMath')
43
+ return model
44
+ except Exception as e:
45
+ st.error(f"Error loading the embedding model: {e}")
46
+ return None
47
+
48
+
49
+ # --- 2. Load Data from RDS ---
50
+ @st.cache_data
51
+ def load_papers_from_rds():
52
+ """
53
+ Loads theorem data from the RDS database and prepares it for embedding.
54
+ Returns a list of theorem dictionaries with all necessary fields.
55
+ """
56
+ try:
57
+ conn = get_rds_connection()
58
+ cur = conn.cursor()
59
+
60
+ # Fetch all papers with their theorems and embeddings
61
+ cur.execute("""
62
+ SELECT
63
+ tm.paper_id,
64
+ tm.title,
65
+ tm.authors,
66
+ tm.link,
67
+ tm.last_updated,
68
+ tm.summary,
69
+ tm.journal_ref,
70
+ tm.primary_category,
71
+ tm.categories,
72
+ tm.global_notations,
73
+ tm.global_definitions,
74
+ tm.global_assumptions,
75
+ te.theorem_name,
76
+ te.theorem_slogan,
77
+ te.theorem_body,
78
+ te.embedding
79
+ FROM theorem_metadata tm
80
+ JOIN theorem_embedding te ON tm.paper_id = te.paper_id
81
+ ORDER BY tm.paper_id, te.theorem_name;
82
+ """)
83
+
84
+ rows = cur.fetchall()
85
+ cur.close()
86
+ conn.close()
87
+
88
+ all_theorems_data = []
89
+ for row in rows:
90
+ (paper_id, title, authors, link, last_updated, summary,
91
+ journal_ref, primary_category, categories,
92
+ global_notations, global_definitions, global_assumptions,
93
+ theorem_name, theorem_slogan, theorem_body, embedding) = row
94
+
95
+ # Build global context
96
+ global_context_parts = []
97
+ if global_notations:
98
+ global_context_parts.append(f"**Global Notations:**\n{global_notations}")
99
+ if global_definitions:
100
+ global_context_parts.append(f"**Global Definitions:**\n{global_definitions}")
101
+ if global_assumptions:
102
+ global_context_parts.append(f"**Global Assumptions:**\n{global_assumptions}")
103
+
104
+ global_context = "\n\n".join(global_context_parts)
105
+
106
+ # Convert embedding to a numpy float array
107
+ if isinstance(embedding, str):
108
+ embedding = json.loads(embedding)
109
+ if isinstance(embedding, list):
110
+ embedding = np.array(embedding, dtype=np.float32)
111
+ elif isinstance(embedding, np.ndarray):
112
+ embedding = embedding.astype(np.float32)
113
+
114
+ all_theorems_data.append({
115
+ "paper_id": paper_id,
116
+ "paper_title": title,
117
+ "paper_url": link,
118
+ "theorem_name": theorem_name,
119
+ "theorem_slogan": theorem_slogan,
120
+ "theorem_body": theorem_body,
121
+ "global_context": global_context,
122
+ "text_to_embed": f"{global_context}\n\n**Theorem ({theorem_name}):**\n{theorem_body}",
123
+ "stored_embedding": embedding
124
+ })
125
+
126
+ return all_theorems_data
127
+
128
+ except Exception as e:
129
+ st.error(f"Error loading data from RDS: {e}")
130
+ return []
131
+
132
+
133
+ # --- 3. The Search Function ---
134
+ def search_theorems(query, model, theorems_data, embeddings_db):
135
+ """
136
+ Takes a user query and finds the top 5 most similar theorems.
137
+ """
138
+ if not query:
139
+ st.info("Please enter a search query.")
140
+ return
141
+
142
+ query_embedding = model.encode(query, convert_to_tensor=True)
143
+ cosine_scores = util.cos_sim(query_embedding, embeddings_db)[0]
144
+ top_results_indices = np.argsort(-cosine_scores.cpu())[:5]
145
+
146
+ st.subheader("Top 5 Most Similar Theorems")
147
+
148
+ if len(top_results_indices) == 0:
149
+ st.write("No results found.")
150
+ return
151
+
152
+ for i, idx in enumerate(top_results_indices):
153
+ idx = idx.item()
154
+ similarity = cosine_scores[idx].item()
155
+ theorem_info = theorems_data[idx]
156
+
157
+ # Use an expander for each result to keep the main view clean
158
+ expander_title = f"**Result {i+1} | Similarity: {similarity:.4f}**"
159
+ if theorem_info.get("theorem_name"):
160
+ expander_title += f" | {theorem_info['theorem_name']}"
161
+
162
+ with st.expander(expander_title):
163
+ st.markdown(f"**Paper:** {theorem_info.get('paper_title', 'Unknown')}")
164
+ st.markdown(f"**Source:** [{theorem_info['paper_url']}]({theorem_info['paper_url']})")
165
+
166
+ # Display theorem slogan if available
167
+ if theorem_info.get("theorem_slogan"):
168
+ st.markdown(f"**Slogan:** {theorem_info['theorem_slogan']}")
169
+ st.write("")
170
+
171
+ # Display global context in a more readable blockquote
172
+ if theorem_info["global_context"]:
173
+ blockquote_context = "> " + theorem_info["global_context"].replace("\n", "\n> ")
174
+ st.markdown(blockquote_context)
175
+ st.write("")
176
+
177
+ # Clean and display theorem body
178
+ content = theorem_info['theorem_body']
179
+
180
+ # Remove labels, citations, and other disruptive commands
181
+ cleaned_content = re.sub(r'\\(label|cite|eqref)\{.*?\}', '', content)
182
+
183
+ # Convert math delimiters to $$
184
+ cleaned_content = re.sub(r'\\\[(.*?)\\\]', r'$$\1$$', cleaned_content)
185
+ cleaned_content = re.sub(r'\\\((.*?)\\\)', r'$\1$', cleaned_content)
186
+
187
+ # Remove common environment wrappers like \begin\{...\} and \end\{...\}
188
+ cleaned_content = re.sub(r'\\label\{.*?\}', r'', cleaned_content)
189
+ cleaned_content = re.sub(r'\\begin\{.*?\}', r'', cleaned_content)
190
+ cleaned_content = re.sub(r'\\end\{.*?\}', r'', cleaned_content)
191
+
192
+ # Remove extra formatting like newlines and tabs
193
+ cleaned_content = cleaned_content.replace('\n', ' ').replace('\t', ' ').strip()
194
+
195
+ # Use st.markdown() to render the cleaned, mixed text and LaTeX
196
+ st.markdown(f"**Theorem Body:**")
197
+ st.markdown(cleaned_content)
198
+
199
+
200
+ # --- Main App Interface ---
201
+ st.set_page_config(page_title="Theorem Search Demo", layout="wide")
202
+ st.title("📚 Semantic Theorem Search")
203
+ st.write("This demo uses a specialized mathematical language model to find theorems semantically similar to your query.")
204
+
205
+ model = load_model()
206
+ theorems_data = load_papers_from_rds()
207
+
208
+ if model and theorems_data:
209
+ with st.spinner("Preparing embeddings from database..."):
210
+ # Use stored embeddings from database - already numpy arrays
211
+ corpus_embeddings = np.array([item['stored_embedding'] for item in theorems_data])
212
+
213
+ st.success(f"Successfully loaded {len(theorems_data)} theorems from RDS. Ready to search!")
214
+
215
+ user_query = st.text_input("Enter your query:", "The Jones polynomial is a link invariant")
216
 
217
+ if st.button("Search") or user_query:
218
+ search_theorems(user_query, model, theorems_data, corpus_embeddings)
219
+ else:
220
+ st.error("Could not load the model or data from RDS. Please check your database connection and credentials.")