stefanoviel committed on
Commit ·
ce35c00
1
Parent(s): 70f287c
using tmp folder
Browse files- src/streamlit_app.py +31 -17
src/streamlit_app.py
CHANGED
|
@@ -1,20 +1,18 @@
|
|
| 1 |
import os
|
| 2 |
-
|
| 3 |
import streamlit as st
|
| 4 |
import pandas as pd
|
| 5 |
from sentence_transformers import SentenceTransformer, util
|
| 6 |
import torch
|
| 7 |
-
from spellchecker import SpellChecker
|
| 8 |
from io import StringIO
|
| 9 |
|
| 10 |
# --- Configuration ---
|
| 11 |
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
|
| 12 |
-
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
# --- Data Loading and Preparation ---
|
| 16 |
-
# This is the raw data provided by the user.
|
| 17 |
-
# In a real application, you might load this from a CSV file.
|
| 18 |
CSV_FILE = 'papers_with_abstracts_parallel.csv'
|
| 19 |
|
| 20 |
# --- Caching Functions ---
|
|
@@ -41,10 +39,14 @@ def create_and_save_embeddings(model, data_df):
|
|
| 41 |
# Generate embeddings
|
| 42 |
corpus_embeddings = model.encode(data_df['text_to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)
|
| 43 |
|
| 44 |
-
# Save embeddings and dataframe
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
return corpus_embeddings, data_df
|
| 49 |
|
| 50 |
def load_data_and_embeddings():
|
|
@@ -53,13 +55,26 @@ def load_data_and_embeddings():
|
|
| 53 |
If files don't exist, it calls the creation function.
|
| 54 |
"""
|
| 55 |
model = load_embedding_model()
|
|
|
|
|
|
|
| 56 |
if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
data_df = pd.read_csv(CSV_FILE)
|
| 62 |
corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
return model, corpus_embeddings, data_df
|
| 65 |
|
|
@@ -91,7 +106,6 @@ def correct_query_spelling(query, spell_checker):
|
|
| 91 |
|
| 92 |
return " ".join(corrected_words)
|
| 93 |
|
| 94 |
-
|
| 95 |
def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
|
| 96 |
"""
|
| 97 |
Performs semantic search on the loaded data.
|
|
@@ -142,7 +156,7 @@ try:
|
|
| 142 |
with col1:
|
| 143 |
search_query = st.text_input(
|
| 144 |
"Enter your search query:",
|
| 145 |
-
placeholder="e.g.,
|
| 146 |
)
|
| 147 |
with col2:
|
| 148 |
top_k_results = st.number_input(
|
|
@@ -187,4 +201,4 @@ try:
|
|
| 187 |
|
| 188 |
except Exception as e:
|
| 189 |
st.error(f"An error occurred: {e}")
|
| 190 |
-
st.info("Please ensure all required libraries are installed
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import streamlit as st
|
| 3 |
import pandas as pd
|
| 4 |
from sentence_transformers import SentenceTransformer, util
|
| 5 |
import torch
|
| 6 |
+
from spellchecker import SpellChecker
|
| 7 |
from io import StringIO
|
| 8 |
|
| 9 |
# --- Configuration ---
|
| 10 |
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
|
| 11 |
+
# Use /tmp directory for temporary files in Hugging Face Spaces
|
| 12 |
+
EMBEDDINGS_FILE = '/tmp/paper_embeddings.pt'
|
| 13 |
+
DATA_FILE = '/tmp/papers_data.pkl'
|
| 14 |
|
| 15 |
# --- Data Loading and Preparation ---
|
|
|
|
|
|
|
| 16 |
CSV_FILE = 'papers_with_abstracts_parallel.csv'
|
| 17 |
|
| 18 |
# --- Caching Functions ---
|
|
|
|
| 39 |
# Generate embeddings
|
| 40 |
corpus_embeddings = model.encode(data_df['text_to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)
|
| 41 |
|
| 42 |
+
# Save embeddings and dataframe to /tmp directory
|
| 43 |
+
try:
|
| 44 |
+
torch.save(corpus_embeddings, EMBEDDINGS_FILE)
|
| 45 |
+
data_df.to_pickle(DATA_FILE)
|
| 46 |
+
st.success("Embeddings and data saved successfully!")
|
| 47 |
+
except Exception as e:
|
| 48 |
+
st.warning(f"Could not save embeddings to disk: {e}. Will regenerate on each session.")
|
| 49 |
+
|
| 50 |
return corpus_embeddings, data_df
|
| 51 |
|
| 52 |
def load_data_and_embeddings():
|
|
|
|
| 55 |
If files don't exist, it calls the creation function.
|
| 56 |
"""
|
| 57 |
model = load_embedding_model()
|
| 58 |
+
|
| 59 |
+
# Check if files exist and are readable
|
| 60 |
if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
|
| 61 |
+
try:
|
| 62 |
+
corpus_embeddings = torch.load(EMBEDDINGS_FILE)
|
| 63 |
+
data_df = pd.read_pickle(DATA_FILE)
|
| 64 |
+
return model, corpus_embeddings, data_df
|
| 65 |
+
except Exception as e:
|
| 66 |
+
st.warning(f"Could not load saved embeddings: {e}. Regenerating...")
|
| 67 |
+
|
| 68 |
+
# Load the raw data from CSV
|
| 69 |
+
try:
|
| 70 |
data_df = pd.read_csv(CSV_FILE)
|
| 71 |
corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
|
| 72 |
+
except FileNotFoundError:
|
| 73 |
+
st.error(f"CSV file '{CSV_FILE}' not found. Please ensure it's in your repository.")
|
| 74 |
+
st.stop()
|
| 75 |
+
except Exception as e:
|
| 76 |
+
st.error(f"Error loading data: {e}")
|
| 77 |
+
st.stop()
|
| 78 |
|
| 79 |
return model, corpus_embeddings, data_df
|
| 80 |
|
|
|
|
| 106 |
|
| 107 |
return " ".join(corrected_words)
|
| 108 |
|
|
|
|
| 109 |
def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
|
| 110 |
"""
|
| 111 |
Performs semantic search on the loaded data.
|
|
|
|
| 156 |
with col1:
|
| 157 |
search_query = st.text_input(
|
| 158 |
"Enter your search query:",
|
| 159 |
+
placeholder="e.g., machine learning models for time series"
|
| 160 |
)
|
| 161 |
with col2:
|
| 162 |
top_k_results = st.number_input(
|
|
|
|
| 201 |
|
| 202 |
except Exception as e:
|
| 203 |
st.error(f"An error occurred: {e}")
|
| 204 |
+
st.info("Please ensure all required libraries are installed and the CSV file is present in your repository.")
|