stefanoviel commited on
Commit
1a67af9
·
1 Parent(s): beda8ad
Files changed (1) hide show
  1. src/streamlit_app.py +30 -37
src/streamlit_app.py CHANGED
@@ -20,12 +20,13 @@ CSV_FILE = 'papers_with_abstracts_parallel.csv'
20
 
21
 
22
  # --- Caching Functions ---
23
-
 
24
  def load_embedding_model():
25
  """Loads the Sentence Transformer model and caches it."""
26
  return SentenceTransformer(EMBEDDING_MODEL)
27
 
28
-
29
  def load_spell_checker():
30
  """Loads the SpellChecker object and caches it."""
31
  return SpellChecker()
@@ -34,54 +35,56 @@ def load_spell_checker():
34
  def create_and_save_embeddings(model, data_df):
35
  """
36
  Generates and saves document embeddings and the dataframe.
37
- This function is called only once if the files don't exist.
38
  """
39
  st.info("First time setup: Generating and saving embeddings. This may take a moment...")
40
- # Combine title and abstract for richer embeddings
41
- data_df['text_to_embed'] = data_df['title'] + ". " + data_df['abstract'].fillna('')
42
 
43
- # Generate embeddings
44
- corpus_embeddings = model.encode(data_df['text_to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)
 
 
 
45
 
46
- # Save embeddings and dataframe to /tmp directory
47
  try:
48
  torch.save(corpus_embeddings, EMBEDDINGS_FILE)
49
  data_df.to_pickle(DATA_FILE)
50
- st.success("Embeddings and data saved successfully!")
51
  except Exception as e:
52
- st.warning(f"Could not save embeddings to disk: {e}. Will regenerate on each session.")
53
 
54
  return corpus_embeddings, data_df
55
 
 
56
  def load_data_and_embeddings():
57
  """
58
- Loads the saved embeddings and dataframe from disk.
59
- If files don't exist, it calls the creation function.
60
  """
61
  model = load_embedding_model()
62
 
63
- # Check if files exist and are readable
64
- if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
65
  try:
66
- corpus_embeddings = torch.load(EMBEDDINGS_FILE)
67
  data_df = pd.read_pickle(DATA_FILE)
 
68
  return model, corpus_embeddings, data_df
69
  except Exception as e:
70
- st.warning(f"Could not load saved embeddings: {e}. Regenerating...")
71
-
72
- # Load the raw data from CSV
73
  try:
74
  data_df = pd.read_csv(CSV_FILE)
75
  corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
76
  except FileNotFoundError:
77
- st.error(f"CSV file '{CSV_FILE}' not found. Please ensure it's in your repository.")
78
  st.stop()
79
  except Exception as e:
80
- st.error(f"Error loading data: {e}")
81
  st.stop()
82
 
83
  return model, corpus_embeddings, data_df
84
 
 
85
  def correct_query_spelling(query, spell_checker):
86
  """
87
  Corrects potential spelling mistakes in the user's query.
@@ -150,12 +153,13 @@ The search is performed by comparing the semantic meaning of your query with the
150
  Spelling mistakes in your query will be automatically corrected.
151
  """)
152
 
153
- # Load all necessary data
154
  try:
 
155
  model, corpus_embeddings, data_df = load_data_and_embeddings()
156
  spell_checker = load_spell_checker()
157
 
158
- # --- User Inputs: Search Bar and Slider ---
159
  col1, col2 = st.columns([4, 1])
160
  with col1:
161
  search_query = st.text_input(
@@ -166,37 +170,26 @@ try:
166
  top_k_results = st.number_input(
167
  "Number of results",
168
  min_value=1,
169
- max_value=100, # Set a reasonable max
170
  value=10,
171
  help="Select the number of top results to display."
172
  )
173
 
174
  if search_query:
175
- # --- Perform Typo Correction ---
176
  corrected_query = correct_query_spelling(search_query, spell_checker)
177
 
178
- # If a correction was made, notify the user
179
  if corrected_query.lower() != search_query.lower():
180
  st.info(f"Did you mean: **{corrected_query}**? \n\n*Showing results for the corrected query.*")
181
 
182
- final_query = corrected_query
183
-
184
- # --- Perform Search ---
185
- search_results = semantic_search(final_query, model, corpus_embeddings, data_df, top_k=top_k_results)
186
 
187
- st.subheader(f"Found {len(search_results)} results for '{final_query}'")
188
 
189
- # --- Display Results ---
190
  if search_results:
191
  for result in search_results:
192
  with st.container(border=True):
193
- # Title as a clickable link
194
  st.markdown(f"### [{result['title']}]({result['link']})")
195
-
196
- # Authors
197
  st.caption(f"**Authors:** {result['authors']}")
198
-
199
- # Expander for the abstract
200
  if pd.notna(result['abstract']):
201
  with st.expander("View Abstract"):
202
  st.write(result['abstract'])
@@ -204,5 +197,5 @@ try:
204
  st.warning("No results found. Try a different query.")
205
 
206
  except Exception as e:
207
- st.error(f"An error occurred: {e}")
208
  st.info("Please ensure all required libraries are installed and the CSV file is present in your repository.")
 
20
 
21
 
22
  # --- Caching Functions ---
23
+ # --- Caching Functions (Unchanged but crucial) ---
24
+ @st.cache_resource
25
  def load_embedding_model():
26
  """Loads the Sentence Transformer model and caches it."""
27
  return SentenceTransformer(EMBEDDING_MODEL)
28
 
29
+ @st.cache_resource
30
  def load_spell_checker():
31
  """Loads the SpellChecker object and caches it."""
32
  return SpellChecker()
 
35
  def create_and_save_embeddings(model, data_df):
36
  """
37
  Generates and saves document embeddings and the dataframe.
38
+ This function is called only once if the files don't exist in the persistent directory.
39
  """
40
  st.info("First time setup: Generating and saving embeddings. This may take a moment...")
41
+ data_df['text_to_embed'] = data_df['title'].fillna('') + ". " + data_df['abstract'].fillna('')
 
42
 
43
+ corpus_embeddings = model.encode(
44
+ data_df['text_to_embed'].tolist(),
45
+ convert_to_tensor=True,
46
+ show_progress_bar=True
47
+ )
48
 
 
49
  try:
50
  torch.save(corpus_embeddings, EMBEDDINGS_FILE)
51
  data_df.to_pickle(DATA_FILE)
52
+ st.success("Embeddings and data saved successfully for future sessions!")
53
  except Exception as e:
54
+ st.warning(f"Could not save embeddings to persistent storage: {e}. Will regenerate on next session.")
55
 
56
  return corpus_embeddings, data_df
57
 
58
+ @st.cache_data
59
  def load_data_and_embeddings():
60
  """
61
+ Loads data and embeddings. It first tries to load from the persistent directory.
62
+ If files don't exist, it creates them. The results are cached for the current session.
63
  """
64
  model = load_embedding_model()
65
 
66
+ if DATA_FILE.exists() and EMBEDDINGS_FILE.exists():
 
67
  try:
 
68
  data_df = pd.read_pickle(DATA_FILE)
69
+ corpus_embeddings = torch.load(EMBEDDINGS_FILE)
70
  return model, corpus_embeddings, data_df
71
  except Exception as e:
72
+ st.warning(f"Could not load saved files: {e}. Regenerating...")
73
+
74
+ # Fallback to creating embeddings if they don't exist
75
  try:
76
  data_df = pd.read_csv(CSV_FILE)
77
  corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
78
  except FileNotFoundError:
79
+ st.error(f"The required data file '{CSV_FILE}' was not found. Please make sure it's in your repository root.")
80
  st.stop()
81
  except Exception as e:
82
+ st.error(f"An unexpected error occurred while loading data: {e}")
83
  st.stop()
84
 
85
  return model, corpus_embeddings, data_df
86
 
87
+ # ... (The rest of your functions `correct_query_spelling` and `semantic_search` remain the same) ...
88
  def correct_query_spelling(query, spell_checker):
89
  """
90
  Corrects potential spelling mistakes in the user's query.
 
153
  Spelling mistakes in your query will be automatically corrected.
154
  """)
155
 
156
+ # --- App Logic ---
157
  try:
158
+ # Load all necessary data using the corrected function
159
  model, corpus_embeddings, data_df = load_data_and_embeddings()
160
  spell_checker = load_spell_checker()
161
 
162
+ # --- User Inputs ---
163
  col1, col2 = st.columns([4, 1])
164
  with col1:
165
  search_query = st.text_input(
 
170
  top_k_results = st.number_input(
171
  "Number of results",
172
  min_value=1,
173
+ max_value=100,
174
  value=10,
175
  help="Select the number of top results to display."
176
  )
177
 
178
  if search_query:
 
179
  corrected_query = correct_query_spelling(search_query, spell_checker)
180
 
 
181
  if corrected_query.lower() != search_query.lower():
182
  st.info(f"Did you mean: **{corrected_query}**? \n\n*Showing results for the corrected query.*")
183
 
184
+ search_results = semantic_search(corrected_query, model, corpus_embeddings, data_df, top_k=top_k_results)
 
 
 
185
 
186
+ st.subheader(f"Found {len(search_results)} results for '{corrected_query}'")
187
 
 
188
  if search_results:
189
  for result in search_results:
190
  with st.container(border=True):
 
191
  st.markdown(f"### [{result['title']}]({result['link']})")
 
 
192
  st.caption(f"**Authors:** {result['authors']}")
 
 
193
  if pd.notna(result['abstract']):
194
  with st.expander("View Abstract"):
195
  st.write(result['abstract'])
 
197
  st.warning("No results found. Try a different query.")
198
 
199
  except Exception as e:
200
+ st.error(f"An error occurred during app execution: {e}")
201
  st.info("Please ensure all required libraries are installed and the CSV file is present in your repository.")