File size: 7,412 Bytes
6611ead
ca806f9
64e9557
7a417b0
 
 
ce35c00
7a417b0
 
ca806f9
 
c25929c
 
3c2ac96
 
de8365c
7a417b0
1c0cf1d
7a417b0
1a67af9
7a417b0
 
2c1d8e4
53b63ed
1a67af9
7a417b0
 
 
 
 
 
 
 
0fd8f7a
7a417b0
b1a742b
0fd8f7a
 
7a417b0
0fd8f7a
 
7a417b0
0fd8f7a
ce35c00
3c2ac96
ce35c00
b1a742b
ce35c00
b1a742b
ce35c00
7a417b0
 
1a67af9
7a417b0
 
0fd8f7a
 
7a417b0
 
ce35c00
0fd8f7a
 
ce35c00
1a67af9
0fd8f7a
ce35c00
 
b1a742b
1a67af9
b1a742b
0fd8f7a
 
ce35c00
7a417b0
 
ce35c00
b1a742b
ce35c00
 
b1a742b
ce35c00
7a417b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fd8f7a
7a417b0
 
 
 
0fd8f7a
7a417b0
 
 
 
ce35c00
7a417b0
 
 
 
 
0fd8f7a
7a417b0
 
 
 
 
0fd8f7a
7a417b0
 
0fd8f7a
7a417b0
 
 
0fd8f7a
 
 
 
7a417b0
0fd8f7a
7a417b0
0fd8f7a
7a417b0
 
 
0fd8f7a
7a417b0
0fd8f7a
 
7a417b0
0fd8f7a
 
7a417b0
 
 
 
 
64e9557
7a417b0
0fd8f7a
ce35c00
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import os
from pathlib import Path
import streamlit as st
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import torch
from spellchecker import SpellChecker
from io import StringIO

# --- Configuration ---

# Hugging Face model id used to embed both the papers and user queries.
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'

# On-disk artifact paths. NOTE(review): these resolve against the current
# working directory, not the script's own directory — confirm the app is
# always launched from the repository root.
DATA_FILE = 'papers_data.pkl'            # pickled DataFrame incl. text_to_embed column
EMBEDDINGS_FILE = 'paper_embeddings.pt'  # torch tensor of corpus embeddings
CSV_FILE = 'papers_with_abstracts_parallel.csv'  # raw input data


# --- Caching Functions ---
@st.cache_resource
def load_embedding_model():
    """Return the shared SentenceTransformer instance (cached across reruns)."""
    model = SentenceTransformer(EMBEDDING_MODEL)
    return model

@st.cache_resource
def load_spell_checker():
    """Return the shared SpellChecker instance (cached across reruns)."""
    checker = SpellChecker()
    return checker

# --- Core Functions ---
def create_and_save_embeddings(model, data_df):
    """
    Embed every paper and persist both the embeddings and the dataframe.

    Intended to run only once, when no cached artifacts exist on disk.

    Args:
        model: sentence-encoder exposing ``encode(texts, convert_to_tensor=...)``.
        data_df: dataframe with at least ``title`` and ``abstract`` columns.

    Returns:
        Tuple of (corpus embeddings tensor, dataframe augmented with a
        ``text_to_embed`` column).
    """
    print("First time setup: Generating and saving embeddings. This may take a moment...")

    # Title plus abstract gives the encoder more context than the title alone;
    # missing abstracts become empty strings rather than NaN.
    data_df['text_to_embed'] = data_df['title'] + ". " + data_df['abstract'].fillna('')

    texts = data_df['text_to_embed'].tolist()
    corpus_embeddings = model.encode(texts, convert_to_tensor=True, show_progress_bar=True)

    # Persist the artifacts for future sessions. Failure here is non-fatal:
    # we simply regenerate the embeddings next time.
    try:
        torch.save(corpus_embeddings.cpu(), EMBEDDINGS_FILE)
        data_df.to_pickle(DATA_FILE)
        print("Embeddings and data saved successfully!")
    except Exception as e:
        print(f"Could not save embeddings to disk: {e}. Will regenerate on each session.")

    return corpus_embeddings, data_df

@st.cache_resource
def load_data_and_embeddings():
    """
    Load the embedding model, corpus embeddings, and paper dataframe.

    Prefers the artifacts cached on disk; if they are missing or unreadable,
    reads the raw CSV and regenerates the embeddings. Cached with
    ``st.cache_resource`` (not ``st.cache_data``) because the return values —
    a model and a torch tensor — are shared, non-serializable resources that
    should not be copied/pickled on every rerun.

    Returns:
        Tuple of (model, corpus_embeddings, data_df).

    Side effects:
        Stops the Streamlit script (``st.stop``) with a visible error if the
        CSV cannot be read.
    """
    model = load_embedding_model()

    # Fast path: reuse previously generated artifacts if both files exist.
    if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
        try:
            # map_location guards against artifacts written on a GPU host.
            corpus_embeddings = torch.load(EMBEDDINGS_FILE, map_location='cpu')
            data_df = pd.read_pickle(DATA_FILE)
            return model, corpus_embeddings, data_df
        except Exception as e:
            # Non-fatal: fall through and rebuild from the CSV.
            print(f"Could not load saved embeddings: {e}. Regenerating...")

    # Slow path: build everything from the raw CSV.
    try:
        data_df = pd.read_csv(CSV_FILE)
        corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
    except FileNotFoundError:
        # Surface the problem in the UI instead of only the server console.
        st.error(f"CSV file '{CSV_FILE}' not found. Please ensure it's in your repository.")
        st.stop()
    except Exception as e:
        st.error(f"Error loading data: {e}")
        st.stop()

    return model, corpus_embeddings, data_df

def correct_query_spelling(query, spell_checker):
    """
    Return *query* with likely misspelled words replaced by corrections.

    Words the checker cannot suggest a fix for are kept as-is. An empty
    query yields "", and a query with no unknown words is returned unchanged.
    """
    if not query:
        return ""

    tokens = query.split()

    # Words the checker does not recognize.
    unknown_words = spell_checker.unknown(tokens)

    if not unknown_words:
        # No typos detected — hand back the original string untouched.
        return query

    # Replace each flagged word, falling back to the original token when the
    # checker has no suggestion.
    fixed = [
        (spell_checker.correction(tok) or tok) if tok in unknown_words else tok
        for tok in tokens
    ]
    return " ".join(fixed)

def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
    """
    Rank papers by cosine similarity between *query* and the corpus.

    Args:
        query: free-text search string; empty/falsy yields an empty list.
        model: sentence encoder used to embed the query.
        corpus_embeddings: tensor of precomputed paper embeddings.
        data_df: dataframe whose rows align positionally with the embeddings.
        top_k: maximum number of hits to return.

    Returns:
        List of dicts (title/authors/link/abstract/score), best match first.
    """
    if not query:
        return []

    # Embed the query into the same vector space as the corpus.
    query_embedding = model.encode(query, convert_to_tensor=True)

    # One cosine-similarity score per paper.
    similarities = util.cos_sim(query_embedding, corpus_embeddings)[0]

    # Never request more hits than papers exist.
    k = min(top_k, len(corpus_embeddings))
    scores, indices = torch.topk(similarities, k=k)

    results = []
    for score, idx in zip(scores, indices):
        row = data_df.iloc[idx.item()]
        results.append({
            "title": row["title"],
            "authors": row["authors"],
            "link": row["link"],
            "abstract": row["abstract"],
            # Score is kept for potential future use but not displayed.
            "score": score.item(),
        })
    return results

# --- Streamlit App UI ---
# Top-level script body: Streamlit re-executes this on every user interaction,
# so anything expensive must go through the cached loaders above.
st.set_page_config(page_title="Semantic Paper Search", layout="wide")

st.title("📄 Semantic Research Paper Search")
st.markdown("""
Enter a query below to search through a small collection of ICML 2025 papers. 
The search is performed by comparing the semantic meaning of your query with the papers' titles and abstracts.
Spelling mistakes in your query will be automatically corrected.
""")

# Load all necessary data
try:
    # Cached resources: heavy work happens only on the first run of a session.
    model, corpus_embeddings, data_df = load_data_and_embeddings()
    spell_checker = load_spell_checker()

    # --- User Inputs: Search Bar and Slider ---
    # 4:1 split: wide search box next to a narrow result-count control.
    col1, col2 = st.columns([4, 1])
    with col1:
        search_query = st.text_input(
            "Enter your search query:", 
            placeholder="e.g., machine learning models for time series"
        )
    with col2:
        top_k_results = st.number_input(
            "Number of results", 
            min_value=1, 
            max_value=100, # Set a reasonable max
            value=10, 
            help="Select the number of top results to display."
        )

    if search_query:
        # --- Perform Typo Correction ---
        corrected_query = correct_query_spelling(search_query, spell_checker)
        
        # If a correction was made, notify the user
        # (case-insensitive compare so pure capitalization changes don't trigger it).
        if corrected_query.lower() != search_query.lower():
            st.info(f"Did you mean: **{corrected_query}**? \n\n*Showing results for the corrected query.*")
        
        # Always search on the corrected form, even if identical to the input.
        final_query = corrected_query
        
        # --- Perform Search ---
        search_results = semantic_search(final_query, model, corpus_embeddings, data_df, top_k=top_k_results)
        
        st.subheader(f"Found {len(search_results)} results for '{final_query}'")
        
        # --- Display Results ---
        if search_results:
            # One bordered card per hit: linked title, authors, optional abstract.
            for result in search_results:
                with st.container(border=True):
                    # Title as a clickable link
                    st.markdown(f"### [{result['title']}]({result['link']})")
                    
                    # Authors
                    st.caption(f"**Authors:** {result['authors']}")
                    
                    # Expander for the abstract
                    # (skipped when the CSV had no abstract for this paper).
                    if pd.notna(result['abstract']):
                        with st.expander("View Abstract"):
                            st.write(result['abstract'])
        else:
            st.warning("No results found. Try a different query.")

except Exception as e:
    # Catch-all boundary so the app shows a readable error instead of a traceback.
    st.error(f"An error occurred: {e}")
    st.info("Please ensure all required libraries are installed and the CSV file is present in your repository.")