import os
from io import StringIO
from pathlib import Path

import pandas as pd
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
from spellchecker import SpellChecker

# --- Configuration ---
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'

# Input CSV and on-disk cache files. These are plain relative paths, so they
# are resolved against the current working directory of the process.
DATA_FILE = 'papers_data.pkl'
EMBEDDINGS_FILE = 'paper_embeddings.pt'
CSV_FILE = 'papers_with_abstracts_parallel.csv'


# --- Caching Functions ---
@st.cache_resource
def load_embedding_model():
    """Load the Sentence Transformer model named by EMBEDDING_MODEL (cached)."""
    return SentenceTransformer(EMBEDDING_MODEL)


@st.cache_resource
def load_spell_checker():
    """Create the SpellChecker instance (cached)."""
    return SpellChecker()


# --- Core Functions ---
def _abort_with_error(message):
    """Report a fatal setup problem on the console and in the page, then stop.

    Printing alone leaves the user looking at an empty page, so the same
    message is also rendered with ``st.error`` before ``st.stop()`` ends the
    current script run.
    """
    print(message)
    st.error(message)
    st.stop()


def create_and_save_embeddings(model, data_df):
    """Encode every paper and try to persist the embeddings and dataframe.

    A ``text_to_embed`` column ("<title>. <abstract>") is added to ``data_df``
    in place and encoded with ``model``. The embeddings (moved to CPU) are
    written to EMBEDDINGS_FILE and the dataframe is pickled to DATA_FILE.
    Saving is best-effort: a failure is printed and otherwise ignored.

    Args:
        model: Object exposing ``encode(list_of_str, convert_to_tensor=True,
            show_progress_bar=True)``.
        data_df: DataFrame with at least ``title`` and ``abstract`` columns.

    Returns:
        Tuple ``(corpus_embeddings, data_df)``.
    """
    print("First time setup: Generating and saving embeddings. This may take a moment...")

    # Combine title and abstract for richer embeddings. Both sides are filled:
    # a NaN on either side of the concatenation would yield NaN instead of a
    # string, and the encoder needs text for every row.
    data_df['text_to_embed'] = (
        data_df['title'].fillna('') + ". " + data_df['abstract'].fillna('')
    )

    corpus_embeddings = model.encode(
        data_df['text_to_embed'].tolist(),
        convert_to_tensor=True,
        show_progress_bar=True,
    )

    # Persist to the working directory so later runs can skip re-encoding.
    # Deliberately best-effort: a failed save must not break the search.
    try:
        torch.save(corpus_embeddings.cpu(), EMBEDDINGS_FILE)
        data_df.to_pickle(DATA_FILE)
        print("Embeddings and data saved successfully!")
    except Exception as e:
        print(f"Could not save embeddings to disk: {e}. Will regenerate on each session.")

    return corpus_embeddings, data_df


# cache_resource (not cache_data): the return value contains the model and a
# tensor. These should be shared as-is rather than serialised and copied each
# time the cached value is requested.
@st.cache_resource
def load_data_and_embeddings():
    """Return ``(model, corpus_embeddings, data_df)`` for the app.

    If both EMBEDDINGS_FILE and DATA_FILE exist they are loaded from disk.
    If they are missing or cannot be read, the papers are read from CSV_FILE
    and encoded via ``create_and_save_embeddings``. When the CSV cannot be
    loaded either, the error is shown and the script run is stopped.
    """
    model = load_embedding_model()

    if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
        try:
            corpus_embeddings = torch.load(EMBEDDINGS_FILE)
            # NOTE: read_pickle executes pickle data; DATA_FILE is expected to
            # be the file this app wrote itself, never an untrusted upload.
            data_df = pd.read_pickle(DATA_FILE)
            return model, corpus_embeddings, data_df
        except Exception as e:
            print(f"Could not load saved embeddings: {e}. Regenerating...")

    # Diagnostic only: reports whether the model identifier is also a local path.
    print("embeding model path exists: " + str(Path(EMBEDDING_MODEL).exists()))

    # Fall back to the raw CSV and (re)build the embeddings.
    try:
        data_df = pd.read_csv(CSV_FILE)
        corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
    except FileNotFoundError:
        _abort_with_error(
            f"CSV file '{CSV_FILE}' not found. Please ensure it's in your repository."
        )
    except Exception as e:
        _abort_with_error(f"Error loading data: {e}")

    return model, corpus_embeddings, data_df


def correct_query_spelling(query, spell_checker):
    """Return ``query`` with likely misspelled words replaced.

    Args:
        query: Raw user query; a falsy value yields ``""``.
        spell_checker: Object exposing ``unknown(words)`` and
            ``correction(word)`` (pyspellchecker's ``SpellChecker``).

    Returns:
        The original query when nothing is flagged. Otherwise the words are
        re-joined with single spaces, each flagged word replaced by its
        suggested correction (or kept when no suggestion exists).
    """
    if not query:
        return ""

    words = query.split()
    misspelled = spell_checker.unknown(words)
    if not misspelled:
        return query  # Nothing flagged: keep the query exactly as typed.

    corrected_words = []
    for word in words:
        # unknown() reports words lower-cased, so membership must be tested on
        # the lower-cased form or capitalised typos would never match.
        # All-caps tokens are left untouched so acronyms are not rewritten.
        if word.isupper() or word.lower() not in misspelled:
            corrected_words.append(word)
            continue
        # correction() may return None; fall back to the original word.
        corrected_words.append(spell_checker.correction(word) or word)

    return " ".join(corrected_words)


def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
    """Rank the papers by cosine similarity to ``query``.

    Args:
        query: Search text; a falsy value yields ``[]``.
        model: Object exposing ``encode(text, convert_to_tensor=True)``.
        corpus_embeddings: Tensor with one embedding per row of ``data_df``.
        data_df: DataFrame with ``title``, ``authors``, ``link`` and
            ``abstract`` columns, row-aligned with ``corpus_embeddings``.
        top_k: Maximum number of results to return.

    Returns:
        List of dicts with keys ``title``, ``authors``, ``link``,
        ``abstract`` and ``score``, best match first.
    """
    if not query:
        return []

    # Never ask topk for more rows than exist, and tolerate non-positive k.
    top_k = min(int(top_k), len(corpus_embeddings))
    if top_k <= 0:
        return []

    query_embedding = model.encode(query, convert_to_tensor=True)
    # The corpus may have been loaded from a CPU-saved file while the model
    # encodes on another device; compare both on the corpus' device.
    query_embedding = query_embedding.to(corpus_embeddings.device)

    cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
    top_results = torch.topk(cos_scores, k=top_k)

    results = []
    for score, idx in zip(top_results[0].tolist(), top_results[1].tolist()):
        item = data_df.iloc[idx]
        results.append({
            "title": item["title"],
            "authors": item["authors"],
            "link": item["link"],
            "abstract": item["abstract"],
            "score": score,  # Kept for potential future use; not displayed.
        })
    return results


# --- Streamlit App UI ---
st.set_page_config(page_title="Semantic Paper Search", layout="wide")

st.title("📄 Semantic Research Paper Search")
st.markdown("""
Enter a query below to search through a small collection of ICML 2025 papers.
The search is performed by comparing the semantic meaning of your query with the papers' titles and abstracts.
Spelling mistakes in your query will be automatically corrected.
""")

# Load all necessary data
try:
    model, corpus_embeddings, data_df = load_data_and_embeddings()
    spell_checker = load_spell_checker()

    # --- User Inputs: Search Bar and Slider ---
    col1, col2 = st.columns([4, 1])
    with col1:
        search_query = st.text_input(
            "Enter your search query:",
            placeholder="e.g., machine learning models for time series"
        )
    with col2:
        top_k_results = st.number_input(
            "Number of results",
            min_value=1,
            max_value=100,  # Set a reasonable max
            value=10,
            help="Select the number of top results to display."
        )

    if search_query:
        # --- Perform Typo Correction ---
        corrected_query = correct_query_spelling(search_query, spell_checker)

        # If a correction was made, notify the user
        if corrected_query.lower() != search_query.lower():
            st.info(f"Did you mean: **{corrected_query}**? \n\n*Showing results for the corrected query.*")
        final_query = corrected_query

        # --- Perform Search ---
        search_results = semantic_search(
            final_query, model, corpus_embeddings, data_df, top_k=top_k_results
        )

        st.subheader(f"Found {len(search_results)} results for '{final_query}'")

        # --- Display Results ---
        if search_results:
            for result in search_results:
                with st.container(border=True):
                    # Title as a clickable link
                    st.markdown(f"### [{result['title']}]({result['link']})")
                    # Authors
                    st.caption(f"**Authors:** {result['authors']}")
                    # Expander for the abstract (skipped when it is missing)
                    if pd.notna(result['abstract']):
                        with st.expander("View Abstract"):
                            st.write(result['abstract'])
        else:
            st.warning("No results found. Try a different query.")

except Exception as e:
    # Top-level boundary: surface any unexpected failure in the page.
    st.error(f"An error occurred: {e}")
    st.info("Please ensure all required libraries are installed and the CSV file is present in your repository.")