Spaces:
Running
Running
import datetime
import json
import os
import re
from collections import defaultdict

import boto3
import psycopg2
import streamlit as st
import streamlit_antd_components as sac
import torch
from dotenv import load_dotenv
from pgvector.psycopg2 import register_vector
from psycopg2.extensions import connection
from sentence_transformers import SentenceTransformer

from latex_clean import clean_latex_for_display
# Config
# NOTE(review): presumably a workaround for Streamlit's module file-watcher
# erroring when it inspects ``torch.classes.__path__``; emptying the path list
# stops that inspection. Confirm it is still needed with the pinned versions.
torch.classes.__path__ = []
# Load variables from a .env file into the process environment, so the
# os.getenv(...) calls in get_rds_connection can see them.
load_dotenv()
def get_rds_connection() -> connection:
    """Open a new SSL connection to the RDS PostgreSQL database.

    Credentials come from the AWS Secrets Manager secret whose ARN is in
    ``RDS_SECRET_ARN``; the Secrets Manager client uses the ``AWS_REGION``
    region. The secret must contain ``username`` and ``password``. ``host``,
    ``port`` (default 5432) and ``dbname`` are read from it too. The
    ``RDS_HOST`` and ``RDS_DB_NAME`` environment variables override host and
    dbname when they are set. The pgvector adapter is registered on the
    connection before it is returned.

    Returns:
        An open psycopg2 connection. The caller owns it and must close it.

    Raises:
        KeyError: if the secret has no ``username`` or ``password`` key.
        Errors from boto3, psycopg2 or pgvector propagate unchanged. If
        registering the vector type fails, the connection is closed first.
    """
    region = os.getenv("AWS_REGION")
    secret_arn = os.getenv("RDS_SECRET_ARN")
    host = os.getenv("RDS_HOST")
    dbname = os.getenv("RDS_DB_NAME")

    sm = boto3.client("secretsmanager", region_name=region)
    secret_value = sm.get_secret_value(SecretId=secret_arn)
    secret_dict = json.loads(secret_value["SecretString"])

    conn = psycopg2.connect(
        host=host or secret_dict.get("host"),
        port=int(secret_dict.get("port", 5432)),
        dbname=dbname or secret_dict.get("dbname"),
        user=secret_dict["username"],
        password=secret_dict["password"],
        sslmode="require",
    )
    try:
        # Needed so numpy query vectors can be bound to ``%s::vector`` params.
        register_vector(conn)
    except Exception:
        # Don't leak the just-opened connection if registration fails
        # (for example when the database has no ``vector`` type).
        conn.close()
        raise
    return conn
# Result kinds the UI can filter by. Order matters: infer_type() returns the
# first kind whose name occurs in a theorem's name.
ALLOWED_TYPES = ["theorem", "lemma", "proposition", "corollary"]

# An arXiv identifier, optionally preceded by an arxiv.org/abs/ or /pdf/ URL
# prefix. Group 1 is the bare ID, either new style (e.g. 2401.12345) or old
# style (e.g. hep-th/9901001). Any version suffix is not captured.
ARXIV_ID_RE = re.compile(
    r"(?:arxiv\.org/(?:abs|pdf)/)?((?:\d{4}\.\d{4,5}|[a-z\-]+/\d{7}))",
    flags=re.IGNORECASE,
)

# Table holding the slogan embeddings that queries are compared against.
EMBED_TABLE = "theorem_embedding_qwen"
# Load the Embedding Model
@st.cache_resource
def _load_embedding_model():
    """Construct the Qwen3 0.6B sentence-embedding model, once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    without ``st.cache_resource`` the model would be reloaded on each rerun.
    Per the Streamlit caching docs, a call that raises is not cached, so a
    failed load is retried on the next rerun.
    """
    return SentenceTransformer('Qwen/Qwen3-Embedding-0.6B')


def load_model():
    """Return the shared embedding model, or ``None`` if it cannot be loaded.

    A load failure is reported in the app with ``st.error`` instead of
    propagating, so the caller can render its own fallback message.
    """
    try:
        return _load_embedding_model()
    except Exception as e:
        st.error(f"Error loading the embedding model: {e}")
        return None
def infer_type(name: str) -> str:
    """Guess a result's kind from its theorem name.

    The name is lower-cased and checked for each entry of ``ALLOWED_TYPES``
    in list order; the first entry found as a substring is returned. Falls
    back to ``"theorem"`` for an empty/``None`` name or when nothing matches.
    """
    fallback = "theorem"
    if name:
        lowered = name.lower()
        return next((kind for kind in ALLOWED_TYPES if kind in lowered), fallback)
    return fallback
# cache for 24 hours
@st.cache_data(ttl=24 * 60 * 60)
def load_authors():
    """Return every distinct, non-empty author name across all papers, sorted.

    ``p.authors`` is unnested in SQL, so each array element is one candidate
    name; falsy names are dropped. The result is cached by ``st.cache_data``
    for 24 hours, so Streamlit reruns do not re-query the database.
    """
    conn = get_rds_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT DISTINCT unnest(p.authors) AS author
                FROM paper p
                WHERE p.authors IS NOT NULL
            """)
            rows = cur.fetchall()
    finally:
        # Release the connection even if the query fails.
        conn.close()
    return sorted(r[0] for r in rows if r[0])
# cache for 24 hours
@st.cache_data(ttl=24 * 60 * 60)
def load_tags_per_source():
    """Return ``{source: sorted distinct primary categories}`` for papers that have one.

    A paper counts as ``'arXiv'`` when its link contains ``arxiv.org``; every
    other paper, including one with no link, counts as ``'Stacks Project'``.
    The same rule is used by the search filter. The result is cached by
    ``st.cache_data`` for 24 hours, so Streamlit reruns do not re-query the
    database.
    """
    conn = get_rds_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT
                    CASE WHEN p.link ILIKE '%%arxiv.org%%'
                         THEN 'arXiv'
                         ELSE 'Stacks Project'
                    END AS source,
                    p.primary_category
                FROM paper p
                WHERE p.primary_category IS NOT NULL
            """)
            rows = cur.fetchall()
    finally:
        # Release the connection even if the query fails.
        conn.close()

    tags_per_source = defaultdict(set)
    for source, cat in rows:
        tags_per_source[source].add(cat)
    return {src: sorted(cats) for src, cats in tags_per_source.items()}
# cache for 24 hours
@st.cache_data(ttl=24 * 60 * 60)
def load_theorem_count():
    """Return the number of rows in the ``theorem`` table as an ``int``.

    The result is cached by ``st.cache_data`` for 24 hours, so Streamlit
    reruns do not re-query the database.
    """
    conn = get_rds_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT COUNT(*) FROM theorem;")
            (n,) = cur.fetchone()
    finally:
        # Release the connection even if the query fails.
        conn.close()
    return int(n)
| def extract_arxiv_id(s: str) -> str | None: | |
| """Return normalized arXiv ID if present in s (URL or raw), else None.""" | |
| if not s: | |
| return None | |
| m = ARXIV_ID_RE.search(s.strip()) | |
| return m.group(1) if m else None | |
def normalize_title(s: str) -> str:
    """Case-fold and trim ``s`` so title matching ignores case and surrounding whitespace.

    Empty or ``None`` input yields ``""``.
    """
    if not s:
        return ""
    return s.casefold().strip()
def parse_paper_filter(raw: str) -> dict:
    """
    Split the free-text paper filter into arXiv IDs and title substrings.

    ``raw`` is split on commas and each piece is stripped; blank pieces are
    ignored. A piece containing an arXiv ID contributes that ID, lower-cased,
    to ``"ids"``; any other piece contributes its case-folded text to
    ``"titles"``.
    e.g. "2401.12345, Optimal Transport" -> {"ids":{"2401.12345"}, "titles":{"optimal transport"}}
    """
    result = {"ids": set(), "titles": set()}
    if not raw:
        return result
    stripped = (piece.strip() for piece in raw.split(","))
    for token in filter(None, stripped):
        arxiv_id = extract_arxiv_id(token)
        if arxiv_id:
            result["ids"].add(arxiv_id.lower())
        else:
            result["titles"].add(normalize_title(token))
    return result
def save_feedback(feedback, query, url, theorem_name, filters):
    """Record a user's vote on a search result (UNFINISHED).

    NOTE(review): this function never executes an INSERT, so no feedback is
    persisted. Nothing in the visible code calls it either; the feedback
    buttons in the results view ignore their return value. The database
    connection it opens is now always closed instead of leaking.

    Args:
        feedback: the user's vote (presumably the selected thumbs button; confirm).
        query: the search text the rated result was returned for.
        url: presumably the rated result's paper URL; confirm.
        theorem_name: name of the rated theorem.
        filters: the filter dict used for the search; it can contain sets,
            so it must go through ``make_json_safe`` before ``json.dumps``.
    """
    def make_json_safe(obj):
        """Recursively convert ``obj`` into JSON-serialisable builtins.

        Dict values are converted recursively. Sets, tuples and lists become
        lists with their elements converted too; previously, set and tuple
        elements were left unconverted. Objects exposing ``.item()`` (e.g.
        NumPy scalars) are unwrapped to Python scalars.
        """
        if isinstance(obj, dict):
            return {k: make_json_safe(v) for k, v in obj.items()}
        if isinstance(obj, (set, tuple, list)):
            return [make_json_safe(v) for v in obj]
        if hasattr(obj, "item"):
            return obj.item()
        return obj

    conn = get_rds_connection()
    try:
        # TODO(review): INSERT the feedback here, serialising ``filters`` with
        # make_json_safe. The target table/columns are not defined in this file.
        pass
    finally:
        conn.close()
# --- Search and Display ---

# Picks the newest slogan (highest slogan_id) per theorem; its embedding is
# what the query is compared against.
_LATEST_SLOGAN_CTE = """
    latest_slogan AS (
        SELECT DISTINCT ON (ts.theorem_id)
            ts.theorem_id, ts.slogan_id, ts.slogan
        FROM theorem_slogan ts
        ORDER BY ts.theorem_id, ts.slogan_id DESC
    )"""


def _build_filter_clauses(filters: dict) -> tuple[list[str], list]:
    """Translate the sidebar ``filters`` dict into WHERE fragments and bind values.

    Fragments are meant to be ANDed together. ``params`` holds the values for
    their ``%s`` placeholders, in order. LIKE wildcards are written as ``%%``
    because the final query is executed with bind parameters.
    """
    where = []
    params = []
    # Source: arXiv papers are recognised by their link; all others are Stacks Project.
    if filters['sources']:
        src_cases = []
        if 'arXiv' in filters['sources']:
            src_cases.append(" (p.link ILIKE '%%arxiv.org%%') ")
        if 'Stacks Project' in filters['sources']:
            src_cases.append(" (p.link NOT ILIKE '%%arxiv.org%%') ")
        if src_cases:
            where.append("(" + " OR ".join(src_cases) + ")")
    # Authors: the paper's author array overlaps the selection.
    if filters['authors']:
        where.append(" p.authors && %s ")
        params.append(filters['authors'])
    # Tag/category
    if filters['tags']:
        where.append(" p.primary_category = ANY(%s) ")
        params.append(filters['tags'])
    # Year (arXiv only): non-arXiv papers pass regardless of year.
    if filters['year_range']:
        yr0, yr1 = filters['year_range']
        where.append("""
            ( (p.link ILIKE '%%arxiv.org%%' AND EXTRACT(YEAR FROM p.last_updated) BETWEEN %s AND %s)
              OR (p.link NOT ILIKE '%%arxiv.org%%') )
        """)
        params.extend([yr0, yr1])
    # Journal status (arXiv only)
    if filters['journal_status'] != "All":
        if filters['journal_status'] == "Journal Article":
            where.append(" (p.link ILIKE '%%arxiv.org%%' AND p.journal_ref IS NOT NULL) ")
        elif filters['journal_status'] == "Preprint Only":
            where.append(" (p.link ILIKE '%%arxiv.org%%' AND p.journal_ref IS NULL) ")
    # Paper filter: arXiv id in link OR title substring(s)
    pf = filters.get("paper_filter", {"ids": set(), "titles": set()})
    id_patterns = [f"%{i}%" for i in pf.get("ids", set())]
    title_patterns = [f"%{t}%" for t in pf.get("titles", set())]
    pf_clauses = []
    if id_patterns:
        pf_clauses.append(" p.link ILIKE ANY(%s) ")
        params.append(id_patterns)
    if title_patterns:
        pf_clauses.append(" p.title ILIKE ANY(%s) ")
        params.append(title_patterns)
    if pf_clauses:
        where.append("(" + " OR ".join(pf_clauses) + ")")
    # Result type
    if filters['types']:
        like_any = [f"%{t}%" for t in filters['types']]
        where.append(" lower(t.name) ILIKE ANY(%s) ")
        params.append(like_any)
    # Citations (always applied)
    low, high = filters["citation_range"]
    if filters["include_unknown_citations"]:
        where.append("( (p.citations BETWEEN %s AND %s) OR p.citations IS NULL )")
    else:
        where.append("( p.citations IS NOT NULL AND (p.citations BETWEEN %s AND %s) )")
    params.extend([low, high])
    return where, params


def _candidate_select(where: list[str]) -> str:
    """Return the similarity-ordered SELECT over filtered theorems, without a LIMIT.

    The text holds one query-vector ``%s`` before the WHERE fragments'
    placeholders and one after them. pgvector's ``<#>`` is the *negative*
    inner product, so ``-1 * (<#>)`` is the inner product. That equals cosine
    similarity when the stored embeddings are unit-normalised like the query
    vector (assumed; confirm). The old ``1.0 - (<#>)`` reported 1 + inner
    product.
    """
    where_sql = ("WHERE " + " AND ".join(where)) if where else ""
    return f"""
        SELECT
            p.paper_id,
            p.title,
            p.authors,
            p.link,
            p.last_updated,
            p.summary,
            p.journal_ref,
            p.primary_category,
            p.categories,
            p.citations,
            t.theorem_id,
            t.name AS theorem_name,
            t.body AS theorem_body,
            ls.slogan AS theorem_slogan,
            ((e.embedding <#> %s::vector) * -1) AS similarity
        FROM paper p
        JOIN theorem t ON t.paper_id = p.paper_id
        JOIN latest_slogan ls ON ls.theorem_id = t.theorem_id
        JOIN {EMBED_TABLE} e ON e.slogan_id = ls.slogan_id
        {where_sql}
        ORDER BY e.embedding <#> %s::vector ASC
    """


def _row_to_result(row) -> dict:
    """Convert one query row into the dict rendered by ``_display_results``.

    A row has 15 columns, plus a trailing ``weighted_score`` when citation
    weighting is on. ``score`` is that weighted score if present, otherwise
    the similarity.
    """
    (paper_id, title, authors, link, last_updated, _summary, journal_ref,
     primary_category, _categories, citations, theorem_id, theorem_name,
     theorem_body, theorem_slogan, similarity, *extra) = row
    score = extra[0] if extra else similarity
    link_str = link or ""
    return {
        "paper_id": paper_id,
        "authors": authors,
        "paper_title": title,
        "paper_url": link,
        "year": last_updated.year if last_updated else None,
        "primary_category": primary_category,
        "source": "arXiv" if "arxiv.org" in link_str else "Stacks Project",
        "type": infer_type(theorem_name or ""),
        "journal_published": bool(journal_ref),
        "citations": citations,
        "theorem_id": theorem_id,
        "theorem_name": theorem_name,
        "theorem_slogan": theorem_slogan,
        "theorem_body": theorem_body,
        "similarity": float(similarity),
        "score": float(score),
    }


def _fetch_results(query_vec, where: list[str], params: list, top_k: int,
                   citation_weight: float) -> list[dict]:
    """Run the vector search and return up to ``top_k`` result dicts.

    With ``citation_weight == 0`` results are ordered purely by similarity.
    Otherwise the ``max(50, 10 * top_k)`` most similar rows are re-ranked by
    ``similarity + citation_weight * ln(citations)``, where missing or
    non-positive citation counts contribute 0. The connection is always
    closed, even if the query fails.
    """
    select_sql = _candidate_select(where)
    if citation_weight == 0.0:
        sql = f"WITH {_LATEST_SLOGAN_CTE} {select_sql} LIMIT %s;"
        exec_params = [query_vec, *params, query_vec, top_k]
    else:
        pool_size = max(50, top_k * 10)
        sql = f"""
            WITH {_LATEST_SLOGAN_CTE},
            candidates AS (
                {select_sql}
                LIMIT {pool_size}
            )
            SELECT
                *,
                (
                    similarity +
                    %s * CASE
                        WHEN citations IS NOT NULL AND citations > 0
                        THEN ln(citations::float)
                        ELSE 0
                    END
                ) AS weighted_score
            FROM candidates
            ORDER BY weighted_score DESC, similarity DESC
            LIMIT %s;
        """
        exec_params = [query_vec, *params, query_vec, citation_weight, top_k]

    conn = get_rds_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(sql, exec_params)
            rows = cur.fetchall()
    finally:
        conn.close()
    return [_row_to_result(row) for row in rows]


def _display_results(results: list[dict], weighted: bool) -> None:
    """Render ``results`` as Streamlit expanders.

    The header shows the similarity, plus the weighted score when citation
    weighting is on.
    """
    st.subheader(f"Found {len(results)} Matching Results")
    if not results:
        st.warning("No results found for the current filters.")
        return
    for i, info in enumerate(results):
        if weighted:
            score_label = f"Score: {info['score']:.4f} | Similarity: {info['similarity']:.4f}"
        else:
            score_label = f"Similarity: {info['similarity']:.4f}"
        expander_title = f"**Result {i + 1} | {score_label} | {info.get('type', '').title()}**"
        with st.expander(expander_title, expanded=True):
            st.markdown(f"**Paper:** *{info.get('paper_title', 'Unknown')}*")
            st.markdown(f"**Authors:** {', '.join(info.get('authors') or []) or 'N/A'}")
            st.markdown(f"**Source:** {info.get('source')}")
            sac.buttons(
                items=[sac.ButtonsItem(label=info.get("paper_url"), icon="link-45deg", href=info.get("paper_url"))],
                variant="outline",
                color="violet",
                index=-1,
                key=f"link_{i}"
            )
            citations = info.get("citations")
            cit_str = "Unknown" if citations is None else str(citations)
            # "year" is always present but may be None; show N/A rather than "None".
            year = info.get("year")
            year_str = "N/A" if year is None else year
            st.markdown(
                f"**Tag:** `{info.get('primary_category')}` | "
                f"**Citations:** {cit_str} | "
                f"**Year:** {year_str}"
            )
            st.markdown("---")
            if info.get("theorem_slogan"):
                st.markdown(f"**Slogan:** {info['theorem_slogan']}\n")
            cleaned_content = clean_latex_for_display(info['theorem_body'])
            st.markdown(f"**{info['theorem_name'] or 'Theorem Body.'}**")
            st.markdown(cleaned_content)
            sac.buttons(
                items=[
                    sac.ButtonsItem(icon="hand-thumbs-up"),
                    sac.ButtonsItem(icon="hand-thumbs-down")
                ],
                variant="outline",
                color="violet",
                index=-1,
                key=f"feedback_{i}")


def search_and_display(query: str, model, filters: dict):
    """Embed ``query``, search theorems matching ``filters``, and render the results.

    Args:
        query: free-text search; ``None``/empty is embedded as ``""``.
        model: the sentence-embedding model returned by ``load_model``.
        filters: the dict built by the sidebar. Keys: sources, authors, tags,
            year_range, journal_status, paper_filter, types, citation_range,
            include_unknown_citations, citation_weight, top_k.

    Shows a warning and returns early when no source is selected.
    """
    if not filters['sources']:
        st.warning("Please select at least one source.")
        return
    citation_weight = float(filters['citation_weight'])
    # Unit-normalised numpy vector, bound to the ``%s::vector`` placeholders.
    query_vec = model.encode(query or "", normalize_embeddings=True, convert_to_numpy=True)
    where, params = _build_filter_clauses(filters)
    results = _fetch_results(query_vec, where, params, int(filters['top_k']), citation_weight)
    _display_results(results, weighted=citation_weight != 0.0)
# --- Main App Interface ---
st.set_page_config(page_title="Theorem Search Demo", layout="wide")
st.title("Math Theorem Search")
st.write("This demo finds mathematical theorems that are semantically similar to your query.")

# The year slider runs through the *current* year. A hard-coded end year
# would silently exclude newer arXiv papers under the default range.
FIRST_ARXIV_YEAR = 1991
CURRENT_YEAR = datetime.date.today().year
# The citation slider stops at this value. Selecting it means "this many or
# more", so it becomes an effectively unbounded upper limit. Otherwise the
# default full range would drop papers with more than 1000 citations.
CITATION_SLIDER_MAX = 1000
CITATION_NO_UPPER_BOUND = 2**31 - 1  # largest Postgres ``integer``

model = load_model()
theorem_count = load_theorem_count()
authors = load_authors()
tags_per_source = load_tags_per_source()

if model:
    st.success(f"Successfully loaded {theorem_count} theorems from arXiv and the Stacks Project. Ready to search!")

    # --- Sidebar filters ---
    st.logo(image="images/math-ai-logo.jpg", size="large", link="https://sites.math.washington.edu/ai/")
    with st.sidebar:
        st.header("Search Filters")
        all_sources = ['arXiv', 'Stacks Project']
        selected_sources = st.multiselect(
            "Filter by Source(s):",
            all_sources,
            default=all_sources[:1] if all_sources else [],
            help="Select one or more sources to reveal more filters."
        )

        # Defaults used when the corresponding widgets are not shown.
        selected_authors, selected_types, selected_tags = [], [], []
        paper_filter = ""
        year_range, journal_status = None, "All"
        citation_range = (0, CITATION_SLIDER_MAX)
        citation_weight = 0.0
        include_unknown_citations = True
        top_k_results = 5

        if selected_sources:
            st.write("---")
            selected_types = st.multiselect("Filter by Type:", ALLOWED_TYPES)
            selected_authors = st.multiselect("Filter by Author(s):", authors)

            # Tags per selected source(s)
            union_tags = sorted({
                t
                for s in selected_sources
                for t in tags_per_source.get(s, [])
                if t
            })
            selected_tags = st.multiselect("Filter by Tag/Category:", union_tags)
            paper_filter = st.text_input("Filter by Paper",
                                         value="",
                                         placeholder="e.g., 2401.12345, Finite Hilbert stability",
                                         help="Filter by title substring or arXiv ID/URL. Use commas for multiple.")

            if 'arXiv' in selected_sources:
                year_range = st.slider("Filter by Year:", FIRST_ARXIV_YEAR, CURRENT_YEAR,
                                       (FIRST_ARXIV_YEAR, CURRENT_YEAR))
                journal_status = st.radio("Publication Status:",
                                          ["All", "Journal Article", "Preprint Only"],
                                          horizontal=True)
                citation_range = st.slider("Filter by Citations:", 0, CITATION_SLIDER_MAX,
                                           (0, CITATION_SLIDER_MAX), step=10,
                                           help=f"{CITATION_SLIDER_MAX} means {CITATION_SLIDER_MAX} or more.")
                citation_weight = st.slider("Citation Weight:", 0.0, 1.0, 0.0, step=0.01,
                                            help="If nonzero, results are ranked by base_score $+$ weight $\\times$ "
                                                 "$\\log($citations$)$. This will increase search time."
                                            )
                include_unknown_citations = st.checkbox(
                    "Include entries with unknown citation counts",
                    value=True,
                    help="If unchecked, results with unknown citation counts are excluded."
                )
            top_k_results = st.slider("Number of Results to Display:", 1, 20, 5)

    # Treat the slider's top value as "no upper limit" (see CITATION_SLIDER_MAX).
    low_citations, high_citations = citation_range
    if high_citations >= CITATION_SLIDER_MAX:
        high_citations = CITATION_NO_UPPER_BOUND

    filters = {
        "authors": selected_authors,
        "types": [t.lower() for t in selected_types],
        "tags": selected_tags,
        "sources": selected_sources,
        "paper_filter": parse_paper_filter(paper_filter),
        "year_range": year_range,
        "journal_status": journal_status,
        "citation_range": (low_citations, high_citations),
        "citation_weight": citation_weight,
        "include_unknown_citations": include_unknown_citations,
        "top_k": top_k_results,
    }

    user_query = st.text_input("Enter your query:", "")
    if st.button("Search") or user_query:
        with st.spinner("Fetching theorems..."):
            search_and_display(user_query, model, filters)
else:
    st.error("Could not load the model or data from RDS. Please check your RDS database connection and credentials.")