Sophie committed
Commit ab99a3d · 1 Parent(s): cb6c277

better latex parsing when displaying theorems (not perfect)

Files changed (2)
  1. src/latex_clean.py +186 -0
  2. src/streamlit_app.py +39 -36
src/latex_clean.py ADDED
@@ -0,0 +1,186 @@
+ import re
+
+ _MATH_ENVS = [
+     # display / alignment
+     "align", "equation", "gather", "multline", "flalign", "dmath",
+     "aligned", "alignedat", "split",
+     # arrays & matrices
+     "array", "matrix", "pmatrix", "bmatrix", "Bmatrix", "vmatrix", "Vmatrix", "smallmatrix", "cases",
+ ]
+
+ def _fix_truncated_end_braces(s: str) -> str:
+     # Add the missing "}" to a truncated "\end{env" at a whitespace boundary or end of string
+     return re.sub(r'(\\end\{[A-Za-z]+(?:\*)?)(?=\s|$)', r'\1}', s)
+
+ def _close_unclosed_envs(s: str) -> str:
+     token = re.compile(
+         r'\\begin\{(?P<b_env>[A-Za-z]+)(?P<b_star>\*)?\}'
+         r'|\\end\{(?P<e_env>[A-Za-z]+)(?P<e_star>\*)?}?',
+         re.DOTALL
+     )
+
+     stack = []
+     for m in token.finditer(s):
+         if m.group('b_env'):
+             env = m.group('b_env')
+             star = m.group('b_star') or ''
+             if env in _MATH_ENVS:
+                 stack.append((env, star))
+         else:
+             env = m.group('e_env')
+             star = m.group('e_star') or ''
+             if stack and stack[-1] == (env, star):
+                 stack.pop()
+
+     if not stack:
+         return s
+
+     # Append missing delimiters in reverse order (innermost environment first)
+     closers = ''.join(f'\n\\end{{{env}{star}}}' for env, star in reversed(stack))
+     return s + closers
+
+ def _balance_math_fences(s: str) -> str:
+     # $$ blocks
+     if s.count('$$') % 2 == 1:
+         s = s.rstrip() + '\n$$'
+     # \[ \]
+     if len(re.findall(r'\\\[', s)) > len(re.findall(r'\\\]', s)):
+         s = s.rstrip() + '\n\\]'
+     # \( \)
+     if len(re.findall(r'\\\(', s)) > len(re.findall(r'\\\)', s)):
+         s = s.rstrip() + '\\)'
+
+     return s
+
+ def _repair_unbalanced_math(text: str) -> str:
+     # normalize newlines
+     text = text.replace('\r\n', '\n').replace('\r', '\n')
+     # fix truncated \end{env
+     text = _fix_truncated_end_braces(text)
+     # append closing \end{...} for any unclosed math envs we care about
+     text = _close_unclosed_envs(text)
+     # make sure $$ / \[ / \( are closed
+     text = _balance_math_fences(text)
+     return text
+
+ def clean_latex_for_display(text: str) -> str:
+     """Cleans raw LaTeX for display in Streamlit."""
+     if not text:
+         return text
+
+     # Fix potential truncation errors
+     text = _repair_unbalanced_math(text)
+
+     # Remove macro definitions and other commands that should not be displayed
+     text = re.sub(
+         r"""
+         \\(?:DeclareMathOperator|newcommand|renewcommand)\*?   # command
+         \s*\{[^{}]+\}                                          # {name}
+         (?:\s*\[\d+\])?                                        # optional [n]
+         (?:\s*\[[^\]]*\])?                                     # optional [default]
+         \s*\{[^{}]*\}                                          # {body} (no nesting)
+         """,
+         "",
+         text,
+         flags=re.VERBOSE | re.DOTALL,
+     )
+
+     text = re.sub(r'\\(label|ref|eqref|cite|footnote|footnotetext|alert)\{[^}]*\}', '', text)
+
+     # align/align* normalization
+     def _normalize_align_blocks(s: str) -> str:
+         out, i, n = [], 0, len(s)
+         begin_pat = re.compile(r'\\begin\{align(\*)?\}', re.DOTALL)
+
+         while i < n:
+             m = begin_pat.search(s, i)
+             if not m:
+                 out.append(s[i:])
+                 break
+
+             # Copy everything before this block
+             out.append(s[i:m.start()])
+
+             star = m.group(1) or ""  # "" or "*"
+             body_start = m.end()
+             rest = s[body_start:]
+
+             # Try exact end: \end{align*} or \end{align}
+             exact_end = re.search(rf'\\end\{{align{re.escape(star)}\}}', rest)
+             if exact_end:
+                 end_start_in_rest = exact_end.start()
+                 end_consumed = exact_end.end()
+             else:
+                 # Fallback: accept a truncated end like "\end{align*"
+                 trunc = re.search(rf'\\end\{{align{re.escape(star)}', rest)
+                 if not trunc:
+                     out.append(s[m.start():])
+                     break
+                 end_start_in_rest = trunc.start()
+                 end_consumed = trunc.end() + (1 if rest[trunc.end():].startswith('}') else 0)
+
+             body = rest[:end_start_in_rest]
+
+             # Clean the body
+             body = re.sub(r'\\tag\{[^}]*\}', '', body)
+             body = re.sub(r'\\(?:nonumber|notag)\b', '', body)
+             body = re.sub(r'\\label\{[^}]*\}', '', body)
+
+             # Trim the trailing "\\" on the final line
+             lines = [ln.rstrip() for ln in body.strip().split('\n')]
+             if lines and lines[-1].endswith(r'\\'):
+                 lines[-1] = lines[-1][:-2].rstrip()
+             cleaned = '\n'.join(lines).strip()
+
+             # Emit a single aligned block
+             out.append(f"$$\n\\begin{{aligned}}\n{cleaned}\n\\end{{aligned}}\n$$")
+
+             # Advance past the end tag (exact or truncated)
+             i = body_start + end_consumed
+
+         return ''.join(out)
+
+     text = _normalize_align_blocks(text)
+
+     text = re.sub(r'\\\[\s*(.*?)\s*\\\]', r'$$\n\1\n$$', text, flags=re.DOTALL)
+     text = re.sub(r'\\\(\s*(.*?)\s*\\\)', r'$\1$', text, flags=re.DOTALL)
+
+     # Turn \item into Markdown bullets
+     text = re.sub(r'\\begin\{(?:enumerate|itemize)\}', '', text)
+     text = re.sub(r'\\end\{(?:enumerate|itemize)\}', '', text)
+     text = re.sub(r'^[ \t]*\\item[ \t]*', r'- ', text, flags=re.MULTILINE)
+
+     # Wrap "&"-aligned single lines that sit outside existing $$...$$ blocks
+     parts = re.split(r'(\$\$[\s\S]*?\$\$)', text)  # keep math blocks intact
+     for i in range(0, len(parts), 2):
+         segment = parts[i]
+         lines = segment.split('\n')
+         for j, ln in enumerate(lines):
+             if '&' in ln and not ln.strip().startswith(('-', '$')):
+                 lines[j] = f"$$\n\\begin{{aligned}}\n{ln}\n\\end{{aligned}}\n$$"
+         parts[i] = '\n'.join(lines)
+     text = ''.join(parts)
+
+     def _isolate_display_math(s: str) -> str:
+         """Ensure each $$...$$ block is on its own lines with padding blank lines."""
+         parts = re.split(r'(\$\$[\s\S]*?\$\$)', s)  # keep the $$...$$ blocks
+         for i in range(1, len(parts), 2):  # only the $$ blocks (odd indices)
+             block = parts[i]  # starts with $$, ends with $$
+             # normalize interior newlines: $$\n ... \n$$
+             if not block.startswith('$$\n'):
+                 block = '$$\n' + block[2:].lstrip()
+             if not block.endswith('\n$$'):
+                 block = block[:-2].rstrip() + '\n$$'
+             parts[i] = block
+
+             # ensure a blank line before and after the block
+             if i - 1 >= 0:
+                 parts[i - 1] = parts[i - 1].rstrip() + '\n\n'
+             if i + 1 < len(parts):
+                 parts[i + 1] = '\n\n' + parts[i + 1].lstrip()
+         return ''.join(parts)
+     text = _isolate_display_math(text)
+
+     # Collapse runs of blank lines and trim
+     text = re.sub(r'\n{3,}', '\n\n', text).strip()
+     return text
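
As a sanity check on the new module, here is a minimal smoke test (the input strings are invented examples of truncated LaTeX; the commented output is what the pipeline above should produce for them, modulo exact whitespace):

```python
from latex_clean import clean_latex_for_display

# A theorem body cut off mid-\end, as truncated extractions often are
raw = "Let $x, y, z$ be reals.\n\\begin{align*}\nx &= y \\\\\ny &= z\n\\end{align*"
print(clean_latex_for_display(raw))
# Let $x, y, z$ be reals.
#
# $$
# \begin{aligned}
# x &= y \\
# y &= z
# \end{aligned}
# $$

# Unbalanced display math gets closed off as well
print(clean_latex_for_display("We have\n$$\nE = mc^2"))
# We have
#
# $$
# E = mc^2
# $$
```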
src/streamlit_app.py CHANGED
@@ -3,13 +3,13 @@ import json
  import numpy as np
  from sentence_transformers import SentenceTransformer, util
  import os
- import re
  import boto3
  import psycopg2
  from psycopg2.extensions import connection
  from dotenv import load_dotenv
+ from latex_clean import clean_latex_for_display

- # --- 0. Config ---
+ # Config
  load_dotenv()

  def get_rds_connection() -> connection:
@@ -32,7 +32,25 @@ def get_rds_connection() -> connection:
      )
      return conn

- # --- 1. Load the Embedding Model ---
+ AVAILABLE_TAGS = {
+     "arXiv": [
+         "math.AC", "math.AG", "math.AP", "math.AT", "math.CA", "math.CO",
+         "math.CT", "math.CV", "math.DG", "math.DS", "math.FA", "math.GM",
+         "math.GN", "math.GR", "math.GT", "math.HO", "math.IT", "math.KT",
+         "math.LO", "math.MG", "math.MP", "math.NA", "math.NT", "math.OA",
+         "math.OC", "math.PR", "math.QA", "math.RA", "math.RT", "math.SG",
+         "math.SP", "math.ST", "Statistics Theory"
+     ],
+     "Stacks Project": [
+         "Sets", "Schemes", "Algebraic Stacks", "Étale Cohomology"
+     ]
+ }
+
+ ALLOWED_TYPES = [
+     "theorem", "lemma", "proposition", "corollary", "definition", "remark", "assumption"
+ ]
+
+ # Load the Embedding Model
  @st.cache_resource
  def load_model():
      """
@@ -46,7 +64,7 @@ def load_model():
      return None


- # --- 2. Load Data from RDS ---
+ # Load Data from RDS
  @st.cache_data
  def load_papers_from_rds():
      """
@@ -113,13 +131,15 @@ def load_papers_from_rds():

          all_theorems_data.append({
              "paper_id": paper_id,
+             "authors": authors,
              "paper_title": title,
              "paper_url": link,
+             "year": last_updated.year,
+             "primary_category": primary_category,
              "theorem_name": theorem_name,
              "theorem_slogan": theorem_slogan,
              "theorem_body": theorem_body,
              "global_context": global_context,
-             "text_to_embed": f"{global_context}\n\n**Theorem ({theorem_name}):**\n{theorem_body}",
              "stored_embedding": embedding
          })

@@ -133,7 +153,7 @@ def load_papers_from_rds():
  # --- 3. The Search Function ---
  def search_theorems(query, model, theorems_data, embeddings_db):
      """
-     Takes a user query and finds the top 5 most similar theorems.
+     Takes a user query and finds the top 10 most similar theorems.
      """
      if not query:
          st.info("Please enter a search query.")
@@ -141,7 +161,7 @@ def search_theorems(query, model, theorems_data, embeddings_db):

      query_embedding = model.encode(query, convert_to_tensor=True)
      cosine_scores = util.cos_sim(query_embedding, embeddings_db)[0]
-     top_results_indices = np.argsort(-cosine_scores.cpu())[:5]
+     top_results_indices = np.argsort(-cosine_scores.cpu())[:10]

-     st.subheader("Top 5 Most Similar Theorems")
+     st.subheader("Top 10 Most Similar Theorems")

@@ -154,67 +174,50 @@ def search_theorems(query, model, theorems_data, embeddings_db):
          similarity = cosine_scores[idx].item()
          theorem_info = theorems_data[idx]

-         # Use an expander for each result to keep the main view clean
          expander_title = f"**Result {i+1} | Similarity: {similarity:.4f}**"
          if theorem_info.get("theorem_name"):
              expander_title += f" | {theorem_info['theorem_name']}"

          with st.expander(expander_title):
              st.markdown(f"**Paper:** {theorem_info.get('paper_title', 'Unknown')}")
+             st.markdown(f"**Authors:** {', '.join(theorem_info['authors']) if theorem_info['authors'] else 'N/A'}")
              st.markdown(f"**Source:** [{theorem_info['paper_url']}]({theorem_info['paper_url']})")
+             st.markdown(
+                 f"**Math Tag:** `{theorem_info['primary_category']}` | **Year:** {theorem_info.get('year', 'N/A')}")
+             st.markdown("---")

-             # Display theorem slogan if available
              if theorem_info.get("theorem_slogan"):
                  st.markdown(f"**Slogan:** {theorem_info['theorem_slogan']}")
                  st.write("")

-             # Display global context in a more readable blockquote
              if theorem_info["global_context"]:
-                 blockquote_context = "> " + theorem_info["global_context"].replace("\n", "\n> ")
-                 st.markdown(blockquote_context)
+                 cleaned_ctx = clean_latex_for_display(theorem_info["global_context"])
+                 st.markdown("> " + cleaned_ctx.replace("\n", "\n> "))
                  st.write("")

-             # Clean and display theorem body
-             content = theorem_info['theorem_body']
-
-             # Remove labels, citations, and other disruptive commands
-             cleaned_content = re.sub(r'\\(label|cite|eqref)\{.*?\}', '', content)
-
-             # Convert math delimiters to $$
-             cleaned_content = re.sub(r'\\\[(.*?)\\\]', r'$$\1$$', cleaned_content)
-             cleaned_content = re.sub(r'\\\((.*?)\\\)', r'$\1$', cleaned_content)
-
-             # Remove common environment wrappers like \begin{...} and \end{...}
-             cleaned_content = re.sub(r'\\label\{.*?\}', r'', cleaned_content)
-             cleaned_content = re.sub(r'\\begin\{.*?\}', r'', cleaned_content)
-             cleaned_content = re.sub(r'\\end\{.*?\}', r'', cleaned_content)
-
-             # Remove extra formatting like newlines and tabs
-             cleaned_content = cleaned_content.replace('\n', ' ').replace('\t', ' ').strip()
-
-             # Use st.markdown() to render the cleaned, mixed text and LaTeX
+             cleaned_content = clean_latex_for_display(theorem_info['theorem_body'])
              st.markdown(f"**Theorem Body:**")
              st.markdown(cleaned_content)

-
  # --- Main App Interface ---
  st.set_page_config(page_title="Theorem Search Demo", layout="wide")
  st.title("📚 Semantic Theorem Search")
  st.write("This demo uses a specialized mathematical language model to find theorems semantically similar to your query.")
+ st.markdown("*Note: Linking to a specific page within an arXiv PDF is not directly possible.*",
+             help="arXiv links redirect to the paper's abstract, not a specific page in the PDF.")

  model = load_model()
  theorems_data = load_papers_from_rds()

  if model and theorems_data:
      with st.spinner("Preparing embeddings from database..."):
-         # Use stored embeddings from database - already numpy arrays
          corpus_embeddings = np.array([item['stored_embedding'] for item in theorems_data])

-     st.success(f"Successfully loaded {len(theorems_data)} theorems from RDS. Ready to search!")
+     st.success(f"Successfully loaded {len(theorems_data)} theorems from arXiv. Ready to search!")

-     user_query = st.text_input("Enter your query:", "The Jones polynomial is a link invariant")
+     user_query = st.text_input("Enter your query:", "")

      if st.button("Search") or user_query:
          search_theorems(user_query, model, theorems_data, corpus_embeddings)
  else:
-     st.error("Could not load the model or data from RDS. Please check your database connection and credentials.")
+     st.error("Could not load the model or data from RDS. Please check your RDS database connection and credentials.")
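
For context when reading the search changes: the retrieval step is plain cosine-similarity top-k over the stored embeddings, now with k = 10. A self-contained sketch of that step (the model name and toy corpus here are placeholders, not what the app uses):

```python
import numpy as np
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model, not the app's
corpus = [
    "The Jones polynomial is an invariant of oriented links.",
    "Every ideal in a Noetherian ring is finitely generated.",
]
corpus_embeddings = model.encode(corpus, convert_to_tensor=True)

query_embedding = model.encode("knot invariants", convert_to_tensor=True)
scores = util.cos_sim(query_embedding, corpus_embeddings)[0].cpu().numpy()

top_indices = np.argsort(-scores)[:10]  # highest similarity first
for rank, idx in enumerate(top_indices, start=1):
    print(f"{rank}. score={scores[idx]:.4f} | {corpus[idx]}")
```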