import streamlit as st
import json
import numpy as np
from sentence_transformers import SentenceTransformer, util
import os
import boto3
import psycopg2
from psycopg2.extensions import connection
import torch
import re
import requests
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from dotenv import load_dotenv
from latex_clean import clean_latex_for_display
# Config
load_dotenv()
def get_rds_connection() -> connection:
    region = os.getenv("AWS_REGION")
    secret_arn = os.getenv("RDS_SECRET_ARN")
    host = os.getenv("RDS_HOST")
    dbname = os.getenv("RDS_DB_NAME")
    sm = boto3.client("secretsmanager", region_name=region)
    secret_value = sm.get_secret_value(SecretId=secret_arn)
    secret_dict = json.loads(secret_value["SecretString"])
    conn = psycopg2.connect(
        host=host or secret_dict.get("host"),
        port=int(secret_dict.get("port", 5432)),
        dbname=dbname or secret_dict.get("dbname"),
        user=secret_dict["username"],
        password=secret_dict["password"],
        sslmode="require",
    )
    return conn
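# The Secrets Manager payload is expected to carry the keys read above; values here are hypothetical:
# {"username": "app_user", "password": "...", "host": "db.example.com", "port": 5432, "dbname": "theorems"}
# When set, the RDS_HOST / RDS_DB_NAME environment variables take precedence over the secret's host / dbname.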
AVAILABLE_TAGS = {
    "arXiv": [
        "math.AC", "math.AG", "math.AP", "math.AT", "math.CA", "math.CO",
        "math.CT", "math.CV", "math.DG", "math.DS", "math.FA", "math.GM",
        "math.GN", "math.GR", "math.GT", "math.HO", "math.IT", "math.KT",
        "math.LO", "math.MG", "math.MP", "math.NA", "math.NT", "math.OA",
        "math.OC", "math.PR", "math.QA", "math.RA", "math.RT", "math.SG",
        "math.SP", "math.ST", "Statistics Theory"
    ],
    "Stacks Project": [
        "Sets", "Schemes", "Algebraic Stacks", "Étale Cohomology"
    ]
}
ALLOWED_TYPES = [
    "theorem", "lemma", "proposition", "corollary", "definition", "remark", "assumption"
]
ARXIV_ID_RE = re.compile(
    r'(?:arxiv\.org/(?:abs|pdf)/)?((?:\d{4}\.\d{4,5}|[a-z\-]+/\d{7}))',
    re.IGNORECASE
)
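# Matches both new-style and old-style arXiv identifiers, with or without a URL prefix, e.g.:
#   "https://arxiv.org/abs/2401.12345" -> captures "2401.12345"
#   "math/0211159" (old-style, illustrative) -> captures "math/0211159"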
# Load the embedding model
@st.cache_resource  # assumed: cache the model across Streamlit reruns so it loads once per process
def load_model():
    """
    Loads the specialized math embedding model from Hugging Face.
    """
    try:
        model = SentenceTransformer('math-similarity/Bert-MLM_arXiv-MP-class_zbMath')
        return model
    except Exception as e:
        st.error(f"Error loading the embedding model: {e}")
        return None
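# Note: the downstream similarity search assumes this model's embedding dimension
# matches the vectors stored in the theorem_embedding table.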
# Load data from RDS
@st.cache_data  # assumed: cache the query result so the database is not re-read on every rerun
def load_papers_from_rds():
    """
    Loads theorem data from the RDS database and prepares it for embedding.
    Returns a list of theorem dictionaries with all necessary fields.
    """
    def infer_type(name: str) -> str:
        """Infer the statement type from its name, defaulting to 'theorem'."""
        if not name:
            return "theorem"
        lower = name.lower()
        for t in ALLOWED_TYPES:
            if t in lower:
                return t
        return "theorem"

    try:
        conn = get_rds_connection()
        cur = conn.cursor()
        # Fetch all papers with their theorems and embeddings
        cur.execute("""
            SELECT
                tm.paper_id,
                tm.title,
                tm.authors,
                tm.link,
                tm.last_updated,
                tm.summary,
                tm.journal_ref,
                tm.primary_category,
                tm.categories,
                tm.global_notations,
                tm.global_definitions,
                tm.global_assumptions,
                te.theorem_name,
                te.theorem_slogan,
                te.theorem_body,
                te.embedding
            FROM theorem_metadata tm
            JOIN theorem_embedding te ON tm.paper_id = te.paper_id
            ORDER BY tm.paper_id, te.theorem_name;
        """)
        rows = cur.fetchall()
        cur.close()
        conn.close()
        all_theorems_data = []
        for row in rows:
            (paper_id, title, authors, link, last_updated, summary,
             journal_ref, primary_category, categories,
             global_notations, global_definitions, global_assumptions,
             theorem_name, theorem_slogan, theorem_body, embedding) = row
            # Build the global context shown above each theorem
            global_context_parts = []
            if global_notations:
                global_context_parts.append(f"**Global Notations:**\n{global_notations}")
            if global_definitions:
                global_context_parts.append(f"**Global Definitions:**\n{global_definitions}")
            if global_assumptions:
                global_context_parts.append(f"**Global Assumptions:**\n{global_assumptions}")
            global_context = "\n\n".join(global_context_parts)
            # Convert the stored embedding to a float32 numpy array
            if isinstance(embedding, str):
                embedding = json.loads(embedding)
            if isinstance(embedding, list):
                embedding = np.array(embedding, dtype=np.float32)
            elif isinstance(embedding, np.ndarray):
                embedding = embedding.astype(np.float32)
            # Determine the source from the URL
            link_str = link or ""
            if link_str.startswith("http://arxiv.org") or link_str.startswith("https://arxiv.org"):
                source = "arXiv"
            else:
                source = "Stacks Project"
            inferred_type = infer_type(theorem_name or "")
            all_theorems_data.append({
                "paper_id": paper_id,
                "authors": authors,
                "paper_title": title,
                "paper_url": link,
                "year": last_updated.year if last_updated else None,
                "primary_category": primary_category,
                "source": source,
                "type": inferred_type,
                "journal_published": bool(journal_ref),
                "citations": None,
                "theorem_name": theorem_name,
                "theorem_slogan": theorem_slogan,
                "theorem_body": theorem_body,
                "global_context": global_context,
                "stored_embedding": embedding,
            })
        return all_theorems_data
    except Exception as e:
        st.error(f"Error loading data from RDS: {e}")
        return []
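# Each returned record pairs display fields with a precomputed float32
# 'stored_embedding'; the app stacks these into corpus_embeddings at startup.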
@st.cache_data(ttl=60 * 60 * 24)  # cache citation lookups for 24 hours
def fetch_citations(paper_url: str, title: str) -> int | None:
    """
    Returns the citation count if found, else None.
    Tries the following sources in order:
      1) OpenAlex by arXiv id
      2) Semantic Scholar by arXiv id
      3) Semantic Scholar by title
    """
    arx_id = None
    if paper_url:
        m = ARXIV_ID_RE.search(paper_url)
        if m:
            arx_id = m.group(1)
    # OpenAlex by arXiv id
    if arx_id:
        try:
            r = requests.get(f"https://api.openalex.org/works/arXiv:{arx_id}", timeout=10)
            if r.ok:
                data = r.json()
                c = data.get("cited_by_count")
                if isinstance(c, int):
                    return c
        except Exception:
            pass
    # Semantic Scholar by arXiv id
    if arx_id:
        try:
            r = requests.get(
                f"https://api.semanticscholar.org/graph/v1/paper/arXiv:{arx_id}",
                params={"fields": "citationCount"},
                timeout=10
            )
            if r.ok:
                j = r.json()
                c = j.get("citationCount")
                if isinstance(c, int):
                    return c
        except Exception:
            pass
    # Fallback: Semantic Scholar by title
    if title:
        try:
            r = requests.get(
                "https://api.semanticscholar.org/graph/v1/paper/search",
                params={"query": title, "limit": 1, "fields": "title,citationCount"},
                timeout=10
            )
            if r.ok:
                j = r.json()
                if j.get("data"):
                    c = j["data"][0].get("citationCount")
                    if isinstance(c, int):
                        return c
        except Exception:
            pass
    return None
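# Illustrative call (hypothetical paper):
#   fetch_citations("https://arxiv.org/abs/2401.12345", "Finite Hilbert stability")
# returns the count from the first API that answers, or None if all three lookups fail.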
def add_citations(candidates: list[dict], max_workers: int = 6) -> None:
    """Backfill 'citations' in place for arXiv items that are missing a count."""
    # Select targets with missing citation counts
    targets = [
        it for it in candidates
        if it.get("source") == "arXiv" and (it.get("citations") in (None, 0))
    ]
    if not targets:
        return
    with ThreadPoolExecutor(max_workers=max_workers) as exe:
        fut2item = {
            exe.submit(fetch_citations, it.get("paper_url"), it.get("paper_title")): it
            for it in targets
        }
        for fut in as_completed(fut2item):
            it = fut2item[fut]
            try:
                c = fut.result()
                if c is not None:
                    it["citations"] = c
            except Exception:
                pass
def extract_arxiv_id(s: str) -> str | None:
    """Return the normalized arXiv ID if present in s (URL or raw), else None."""
    if not s:
        return None
    m = ARXIV_ID_RE.search(s.strip())
    return m.group(1) if m else None
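# e.g. extract_arxiv_id("https://arxiv.org/pdf/2401.12345") -> "2401.12345"
#      extract_arxiv_id("no id here") -> None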
def normalize_title(s: str) -> str:
    return (s or "").casefold().strip()
def parse_paper_filter_input(raw: str) -> dict:
    """
    Parse user input into two sets: arXiv IDs and title substrings.
    Multiple entries may be comma-separated.
    e.g. "2401.12345, Optimal Transport" -> {"ids": {"2401.12345"}, "titles": {"optimal transport"}}
    """
    ids, titles = set(), set()
    if not raw:
        return {"ids": ids, "titles": titles}
    for token in [t.strip() for t in raw.split(",") if t.strip()]:
        arx = extract_arxiv_id(token)
        if arx:
            ids.add(arx.lower())
        else:
            titles.add(normalize_title(token))
    return {"ids": ids, "titles": titles}
def item_matches_paper_filter(item: dict, paper_filter: dict) -> bool:
    """
    True if the item matches at least one requested arXiv ID or one title substring.
    If paper_filter is empty (both sets empty), always True.
    """
    ids = paper_filter.get("ids", set())
    titles = paper_filter.get("titles", set())
    if not ids and not titles:
        return True
    # Compare IDs (extracted once from the URL)
    url = item.get("paper_url") or ""
    item_id = extract_arxiv_id(url)
    if item_id and item_id.lower() in ids:
        return True
    # Compare titles (substring, case-insensitive)
    t = normalize_title(item.get("paper_title"))
    if t and any(sub in t for sub in titles):
        return True
    return False
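# e.g. with paper_filter = parse_paper_filter_input("finite hilbert stability"),
# an item whose title contains that phrase (case-insensitive) matches.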
# --- Search and Display ---
def search_and_display_with_filters(query, model, theorems_data, embeddings_db, filters):
    if not filters['sources']:
        st.warning("Please select at least one source.")
        return
    if query:
        query_embedding = model.encode(query, convert_to_tensor=True)
        cosine_scores = util.cos_sim(query_embedding, embeddings_db)[0]
    else:
        cosine_scores = torch.zeros(len(theorems_data))
    low, high = filters['citation_range']
    # Take a larger candidate pool to filter from
    top_k_pool = min(200, len(theorems_data))
    top_indices = torch.topk(cosine_scores, k=top_k_pool, sorted=True).indices.tolist()
    # Items matching an explicit paper filter always join the pool,
    # even if they fall outside the top-k by cosine score
    paper_filter = filters.get("paper_filter", {"ids": set(), "titles": set()})
    matched_indices = []
    if paper_filter and (paper_filter.get("ids") or paper_filter.get("titles")):
        for i, it in enumerate(theorems_data):
            if item_matches_paper_filter(it, paper_filter):
                matched_indices.append(i)
    pool_indices = list(dict.fromkeys(top_indices + matched_indices))
    pool = [(i, theorems_data[i]) for i in pool_indices]
    # Fetch citation counts in parallel
    if 'arXiv' in filters['sources']:
        add_citations([it for _, it in pool])
    results = []
    # Apply the filters
    for idx, item in pool:
        type_match = (not filters['types']) or (item.get('type', '').lower() in filters['types'])
        tag_match = (not filters['tags']) or (item.get('primary_category') in filters['tags'])
        author_match = (not filters['authors']) or any(a in (item.get('authors') or []) for a in filters['authors'])
        source_match = item.get('source') in filters['sources']
        paper_match = item_matches_paper_filter(item, filters['paper_filter'])
        # Citations, year, and journal status apply only to arXiv entries
        citations = item.get('citations')
        log_cit = np.log1p(int(citations)) if citations is not None else 0.0
        if citations is None:
            if not filters['include_unknown_citations']:
                continue
            citation_match = True
        else:
            citation_match = (low <= int(citations) <= high)
        year_match = True
        if filters['year_range'] and item.get('source') == 'arXiv':
            y = item.get('year') or 0
            yr0, yr1 = filters['year_range']
            year_match = (yr0 <= y <= yr1)
        journal_match = True
        if item.get('source') == 'arXiv':
            status = filters['journal_status']
            jp = bool(item.get('journal_published'))
            if status == "Journal Article":
                journal_match = jp
            elif status == "Preprint Only":
                journal_match = not jp
        if all([type_match, tag_match, author_match, source_match, paper_match, citation_match, year_match, journal_match]):
            # similarity = cosine_similarity + citation_weight * log(1 + citation_count)
            similarity = float(cosine_scores[idx].item()) + filters['citation_weight'] * log_cit
            results.append({"idx": idx, "info": item, "similarity": similarity})
    # Rank the full set of matches, then truncate; breaking out of the loop early
    # could drop results whose citation boost outranks earlier pool entries
    results.sort(key=lambda r: r["similarity"], reverse=True)
    results = results[:filters['top_k']]
    st.subheader(f"Found {len(results)} Matching Results")
    if not results:
        st.warning("No results found for the current filters.")
        return
    for i, r in enumerate(results):
        info = r["info"]
        expander_title = f"**Result {i+1} | Similarity: {r['similarity']:.4f} | Type: {info.get('type', '').title()}**"
        with st.expander(expander_title, expanded=True):
            st.markdown(f"**Paper:** *{info.get('paper_title', 'Unknown')}*")
            st.markdown(f"**Authors:** {', '.join(info.get('authors') or []) or 'N/A'}")
            st.markdown(f"**Source:** {info.get('source')} ({info.get('paper_url')})")
            citations = info.get("citations")
            cit_str = "Unknown" if citations is None else str(citations)
            st.markdown(
                f"**Math Tag:** `{info.get('primary_category')}` | "
                f"**Citations:** {cit_str} | "
                f"**Year:** {info.get('year', 'N/A')}"
            )
            # Testing only: show the score breakdown when citation weighting is active
            if filters['citation_weight'] > 0:
                base = float(cosine_scores[r["idx"]].item())
                log_cit = np.log1p(int(citations)) if citations is not None else 0.0
                st.caption(
                    f"base_cosine={base:.4f} | log(citations)={log_cit:.4f} | weight={filters['citation_weight']:.2f}")
            st.markdown("---")
            if info.get("theorem_slogan"):
                st.markdown(f"**Slogan:** {info['theorem_slogan']}\n")
            if info.get("global_context"):
                cleaned_ctx = clean_latex_for_display(info["global_context"])
                st.markdown("> " + cleaned_ctx.replace("\n", "\n> "))
            cleaned_content = clean_latex_for_display(info['theorem_body'])
            st.markdown(f"**{info['theorem_name'] or 'Theorem Body.'}**")
            st.markdown(cleaned_content)
            # Testing only
            st.markdown('**Paper ID (testing only)**')
            st.markdown(info['paper_id'])
# --- Main App Interface ---
st.set_page_config(page_title="Theorem Search Demo", layout="wide")
st.title("📚 Semantic Theorem Search")
st.write("This demo uses a specialized mathematical language model to find theorems semantically similar to your query.")
model = load_model()
theorems_data = load_papers_from_rds()
if model and theorems_data:
    with st.spinner("Preparing embeddings from database..."):
        corpus_embeddings = np.array([item['stored_embedding'] for item in theorems_data])
    st.success(f"Successfully loaded {len(theorems_data)} theorems from arXiv and the Stacks Project. Ready to search!")
    # --- Sidebar filters ---
    with st.sidebar:
        st.header("Search Filters")
        all_sources = ['arXiv', 'Stacks Project']
        selected_sources = st.multiselect(
            "Filter by Source(s):",
            all_sources,
            default=all_sources[:1],
            help="Select one or more sources to reveal more filters."
        )
        # Defaults, used when no source is selected
        selected_authors, selected_types, selected_tags = [], [], []
        year_range, journal_status = None, "All"
        citation_range = (0, 1000)
        citation_weight = 0.0
        include_unknown_citations = True
        top_k_results = 5
        paper_filter_raw = ""  # default, so the filters dict below is always well-defined
        if selected_sources:
            st.write("---")
            selected_types = st.multiselect("Filter by Type:", ALLOWED_TYPES)
            all_authors = sorted(set(a for it in theorems_data for a in (it.get('authors') or [])))
            selected_authors = st.multiselect("Filter by Author(s):", all_authors)
            # Tags come from the union of categories per selected source
            tags_per_source = defaultdict(set)
            for it in theorems_data:
                tags_per_source[it['source']].add(it.get('primary_category'))
            union_tags = sorted({t for s in selected_sources for t in tags_per_source.get(s, set()) if t})
            selected_tags = st.multiselect("Filter by Math Tag/Category:", union_tags)
            paper_filter_raw = st.text_input(
                "Filter by Paper",
                value="",
                placeholder="e.g., 2401.12345, Finite Hilbert stability",
                help="Filter by title substring or arXiv ID/URL. Use commas for multiple."
            )
            if 'arXiv' in selected_sources:
                year_range = st.slider("Filter by Year:", 1991, 2025, (1991, 2025))
                journal_status = st.radio("Publication Status:", ["All", "Journal Article", "Preprint Only"], horizontal=True)
                citation_range = st.slider("Filter by Citations:", 0, 1000, (0, 1000))
                citation_weight = st.slider("Citation Weight:", 0.0, 1.0, 0.0, step=0.01)
                include_unknown_citations = st.checkbox(
                    "Include entries with unknown citation counts",
                    value=True,
                    help="If unchecked, results with unknown citation counts are excluded."
                )
            top_k_results = st.slider("Number of Results to Display:", 1, 20, 5)
    filters = {
        "authors": selected_authors,
        "types": [t.lower() for t in selected_types],
        "tags": selected_tags,
        "sources": selected_sources,
        "paper_filter": parse_paper_filter_input(paper_filter_raw),
        "year_range": year_range,
        "journal_status": journal_status,
        "citation_range": citation_range,
        "citation_weight": citation_weight,
        "include_unknown_citations": include_unknown_citations,
        "top_k": top_k_results,
    }
    user_query = st.text_input("Enter your query:", "")
    if st.button("Search") or user_query:
        search_and_display_with_filters(user_query, model, theorems_data, corpus_embeddings, filters)
else:
    st.error("Could not load the model or data from RDS. Please check your RDS database connection and credentials.")