Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,6 +11,13 @@ from sentence_transformers import SentenceTransformer
|
|
| 11 |
from scipy.spatial.distance import cosine
|
| 12 |
import PyPDF2
|
| 13 |
import spacy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from difflib import SequenceMatcher
|
| 15 |
|
| 16 |
# Load spaCy model
|
|
@@ -32,16 +39,18 @@ model, tfidf_vectorizer, word2vec_model = load_models()
|
|
| 32 |
# Initialize session state for results table if not already present
|
| 33 |
if 'results_df' not in st.session_state:
|
| 34 |
st.session_state.results_df = pd.DataFrame(columns=[
|
| 35 |
-
"LLM1", "LLM2",
|
| 36 |
-
"
|
|
|
|
|
|
|
| 37 |
"Combined Similarity (%)"
|
| 38 |
])
|
| 39 |
|
| 40 |
-
# Initialize session state for radar chart data
|
| 41 |
if 'radar_chart_data' not in st.session_state:
|
| 42 |
st.session_state.radar_chart_data = []
|
| 43 |
|
| 44 |
-
# Functions
|
| 45 |
@st.cache_data
|
| 46 |
def chunk_text(text, chunk_size=500):
|
| 47 |
return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
|
|
@@ -71,6 +80,7 @@ def calculate_word_similarity_ratio(text1, text2):
|
|
| 71 |
try:
|
| 72 |
doc1 = nlp(text1)
|
| 73 |
doc2 = nlp(text2)
|
|
|
|
| 74 |
words1 = [token.text for token in doc1 if not token.is_stop and not token.is_punct]
|
| 75 |
words2 = [token.text for token in doc2 if not token.is_stop and not token.is_punct]
|
| 76 |
|
|
@@ -81,7 +91,7 @@ def calculate_word_similarity_ratio(text1, text2):
|
|
| 81 |
word_embeddings2 = model.encode(words2)
|
| 82 |
|
| 83 |
similarities = np.array([
|
| 84 |
-
max([1 - cosine(emb1, emb2) for emb2 in word_embeddings2], default=0)
|
| 85 |
for emb1 in word_embeddings1
|
| 86 |
])
|
| 87 |
|
|
@@ -153,7 +163,7 @@ def calculate_paraphrasing_similarity(text1, text2):
|
|
| 153 |
chunks_2 = chunk_text(text2)
|
| 154 |
embeddings_1 = create_embeddings(chunks_1)
|
| 155 |
embeddings_2 = create_embeddings(chunks_2)
|
| 156 |
-
|
| 157 |
if embeddings_1.size > 0 and embeddings_2.size > 0:
|
| 158 |
similarities, average_similarity = calculate_similarity_ratio_and_find_matches(embeddings_1, embeddings_2)
|
| 159 |
return average_similarity * 100
|
|
@@ -171,10 +181,12 @@ def calculate_direct_text_comparison_similarity(text1, text2):
|
|
| 171 |
bleu_score = calculate_bleu_score(text1, text2) * 100
|
| 172 |
rouge_l_score = calculate_rouge_l_score(text1, text2)
|
| 173 |
bertscore = calculate_bertscore(text1, text2)
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
tfidf_cosine_similarity * 0.2 +
|
| 177 |
-
|
|
|
|
|
|
|
| 178 |
except Exception as e:
|
| 179 |
st.error(f"Error calculating direct text comparison similarity: {e}")
|
| 180 |
return 0
|
|
@@ -190,6 +202,7 @@ def calculate_summarization_similarity(text1, text2):
|
|
| 190 |
|
| 191 |
# Streamlit UI
|
| 192 |
st.title("Text-Based Similarity Comparison")
|
|
|
|
| 193 |
|
| 194 |
# Create a two-column layout for input
|
| 195 |
col1, col2 = st.columns([2, 1])
|
|
@@ -198,12 +211,11 @@ with col1:
|
|
| 198 |
st.sidebar.title("LLM Details")
|
| 199 |
llm1_name = st.sidebar.text_input("What is LLM1?", "LLM1")
|
| 200 |
llm2_name = st.sidebar.text_input("What is LLM2?", "LLM2")
|
| 201 |
-
|
| 202 |
st.write("## Input")
|
| 203 |
-
|
| 204 |
# Create two columns for text input
|
| 205 |
input_col1, input_col2 = st.columns(2)
|
| 206 |
-
|
| 207 |
with input_col1:
|
| 208 |
st.write(f"{llm1_name} response")
|
| 209 |
upload_pdf_1 = st.file_uploader(f"Upload PDF for {llm1_name} response", type="pdf", key="pdf1")
|
|
@@ -211,7 +223,7 @@ with col1:
|
|
| 211 |
text_input_1 = extract_pdf_text(upload_pdf_1)
|
| 212 |
else:
|
| 213 |
text_input_1 = st.text_area(f"Text for {llm1_name}", height=150, key="text1")
|
| 214 |
-
|
| 215 |
with input_col2:
|
| 216 |
st.write(f"{llm2_name} response")
|
| 217 |
upload_pdf_2 = st.file_uploader(f"Upload PDF for {llm2_name} response", type="pdf", key="pdf2")
|
|
@@ -228,9 +240,9 @@ with col1:
|
|
| 228 |
summarization_similarity = calculate_summarization_similarity(text_input_1, text_input_2)
|
| 229 |
|
| 230 |
# Combine all metrics into a single similarity score
|
| 231 |
-
total_similarity = (paraphrasing_similarity * 0.
|
| 232 |
-
direct_text_comparison_similarity * 0.
|
| 233 |
-
summarization_similarity * 0.
|
| 234 |
|
| 235 |
# Update results table in session state
|
| 236 |
new_row = pd.Series({
|
|
@@ -241,6 +253,7 @@ with col1:
|
|
| 241 |
"Summarization Similarity (%)": summarization_similarity,
|
| 242 |
"Combined Similarity (%)": total_similarity
|
| 243 |
})
|
|
|
|
| 244 |
st.session_state.results_df = pd.concat([st.session_state.results_df, new_row.to_frame().T], ignore_index=True)
|
| 245 |
|
| 246 |
# Add new data for radar chart
|
|
@@ -252,37 +265,57 @@ with col1:
|
|
| 252 |
})
|
| 253 |
|
| 254 |
# Display metrics with large and bold text
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
with col2:
|
| 262 |
-
|
| 263 |
-
|
| 264 |
# Display radar chart
|
| 265 |
if st.session_state.radar_chart_data:
|
| 266 |
st.subheader("Metrics Comparison")
|
| 267 |
-
|
|
|
|
| 268 |
num_vars = len(labels)
|
| 269 |
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
| 270 |
angles += angles[:1]
|
| 271 |
-
|
| 272 |
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
|
| 273 |
-
|
| 274 |
# Plot each response with a different color
|
| 275 |
color_palette = sns.color_palette("husl", len(st.session_state.radar_chart_data))
|
| 276 |
for idx, data in enumerate(st.session_state.radar_chart_data):
|
| 277 |
values = [
|
| 278 |
-
data["paraphrasing_similarity"],
|
| 279 |
-
data["direct_text_comparison_similarity"],
|
| 280 |
data["summarization_similarity"]
|
| 281 |
]
|
| 282 |
values += values[:1]
|
| 283 |
ax.fill(angles, values, color=color_palette[idx], alpha=0.25, label=data["name"])
|
| 284 |
ax.plot(angles, values, color=color_palette[idx], linewidth=2, linestyle='solid')
|
| 285 |
-
|
| 286 |
ax.set_yticklabels([])
|
| 287 |
ax.set_xticks(angles[:-1])
|
| 288 |
ax.set_xticklabels(labels)
|
|
@@ -292,14 +325,20 @@ with col2:
|
|
| 292 |
|
| 293 |
# Display metrics sliders beside the radar chart
|
| 294 |
if st.session_state.radar_chart_data:
|
| 295 |
-
st.subheader("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
metrics = st.session_state.radar_chart_data[-1]
|
| 297 |
for metric_name in ["paraphrasing_similarity", "direct_text_comparison_similarity", "summarization_similarity"]:
|
| 298 |
st.slider(
|
| 299 |
-
metric_name
|
| 300 |
-
0, 100,
|
| 301 |
-
int(metrics[metric_name]),
|
| 302 |
-
key=metric_name,
|
| 303 |
disabled=True, # Make the slider non-editable
|
| 304 |
format="%.0f" # Format the slider value to be an integer
|
| 305 |
)
|
|
@@ -311,7 +350,7 @@ with results_col:
|
|
| 311 |
st.write("## Detailed Results Table")
|
| 312 |
if not st.session_state.results_df.empty:
|
| 313 |
st.write(st.session_state.results_df)
|
| 314 |
-
|
| 315 |
# Download the results as a CSV file
|
| 316 |
csv_data = st.session_state.results_df.to_csv(index=False).encode('utf-8')
|
| 317 |
st.download_button(label="Download Results as CSV", data=csv_data, file_name='similarity_results.csv', mime='text/csv')
|
|
@@ -319,9 +358,46 @@ with results_col:
|
|
| 319 |
with actions_col:
|
| 320 |
if st.button("Reset Table"):
|
| 321 |
st.session_state.results_df = pd.DataFrame(columns=[
|
| 322 |
-
"LLM1", "LLM2",
|
| 323 |
-
"
|
|
|
|
|
|
|
| 324 |
"Combined Similarity (%)"
|
| 325 |
])
|
| 326 |
st.session_state.radar_chart_data = []
|
| 327 |
-
st.write("Results table has been reset.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from scipy.spatial.distance import cosine
|
| 12 |
import PyPDF2
|
| 13 |
import spacy
|
| 14 |
+
try:
|
| 15 |
+
nlp = spacy.load("en_core_web_sm")
|
| 16 |
+
except OSError:
|
| 17 |
+
from spacy.cli import download
|
| 18 |
+
download("en_core_web_sm")
|
| 19 |
+
nlp = spacy.load("en_core_web_sm")
|
| 20 |
+
|
| 21 |
from difflib import SequenceMatcher
|
| 22 |
|
| 23 |
# Load spaCy model
|
|
|
|
| 39 |
# Initialize session state for results table if not already present
|
| 40 |
if 'results_df' not in st.session_state:
|
| 41 |
st.session_state.results_df = pd.DataFrame(columns=[
|
| 42 |
+
"LLM1", "LLM2",
|
| 43 |
+
"Paraphrasing Similarity (%)",
|
| 44 |
+
"Direct Text Comparison (%)",
|
| 45 |
+
"Summarization Similarity (%)",
|
| 46 |
"Combined Similarity (%)"
|
| 47 |
])
|
| 48 |
|
| 49 |
+
# Initialize session state for radar chart data
|
| 50 |
if 'radar_chart_data' not in st.session_state:
|
| 51 |
st.session_state.radar_chart_data = []
|
| 52 |
|
| 53 |
+
# Functions (same as before)
|
| 54 |
@st.cache_data
|
| 55 |
def chunk_text(text, chunk_size=500):
|
| 56 |
return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
|
|
|
|
| 80 |
try:
|
| 81 |
doc1 = nlp(text1)
|
| 82 |
doc2 = nlp(text2)
|
| 83 |
+
|
| 84 |
words1 = [token.text for token in doc1 if not token.is_stop and not token.is_punct]
|
| 85 |
words2 = [token.text for token in doc2 if not token.is_stop and not token.is_punct]
|
| 86 |
|
|
|
|
| 91 |
word_embeddings2 = model.encode(words2)
|
| 92 |
|
| 93 |
similarities = np.array([
|
| 94 |
+
max([1 - cosine(emb1, emb2) for emb2 in word_embeddings2], default=0)
|
| 95 |
for emb1 in word_embeddings1
|
| 96 |
])
|
| 97 |
|
|
|
|
| 163 |
chunks_2 = chunk_text(text2)
|
| 164 |
embeddings_1 = create_embeddings(chunks_1)
|
| 165 |
embeddings_2 = create_embeddings(chunks_2)
|
| 166 |
+
|
| 167 |
if embeddings_1.size > 0 and embeddings_2.size > 0:
|
| 168 |
similarities, average_similarity = calculate_similarity_ratio_and_find_matches(embeddings_1, embeddings_2)
|
| 169 |
return average_similarity * 100
|
|
|
|
| 181 |
bleu_score = calculate_bleu_score(text1, text2) * 100
|
| 182 |
rouge_l_score = calculate_rouge_l_score(text1, text2)
|
| 183 |
bertscore = calculate_bertscore(text1, text2)
|
| 184 |
+
return (levenshtein_ratio * 0.1 +
|
| 185 |
+
jaccard_similarity * 0.2 +
|
| 186 |
+
tfidf_cosine_similarity * 0.2 +
|
| 187 |
+
bleu_score * 0.2 +
|
| 188 |
+
rouge_l_score * 0.2 +
|
| 189 |
+
bertscore * 0.2) / 1.1
|
| 190 |
except Exception as e:
|
| 191 |
st.error(f"Error calculating direct text comparison similarity: {e}")
|
| 192 |
return 0
|
|
|
|
| 202 |
|
| 203 |
# Streamlit UI
|
| 204 |
st.title("Text-Based Similarity Comparison")
|
| 205 |
+
st.markdown("*Use in wide mode*")
|
| 206 |
|
| 207 |
# Create a two-column layout for input
|
| 208 |
col1, col2 = st.columns([2, 1])
|
|
|
|
| 211 |
st.sidebar.title("LLM Details")
|
| 212 |
llm1_name = st.sidebar.text_input("What is LLM1?", "LLM1")
|
| 213 |
llm2_name = st.sidebar.text_input("What is LLM2?", "LLM2")
|
| 214 |
+
|
| 215 |
st.write("## Input")
|
| 216 |
+
|
| 217 |
# Create two columns for text input
|
| 218 |
input_col1, input_col2 = st.columns(2)
|
|
|
|
| 219 |
with input_col1:
|
| 220 |
st.write(f"{llm1_name} response")
|
| 221 |
upload_pdf_1 = st.file_uploader(f"Upload PDF for {llm1_name} response", type="pdf", key="pdf1")
|
|
|
|
| 223 |
text_input_1 = extract_pdf_text(upload_pdf_1)
|
| 224 |
else:
|
| 225 |
text_input_1 = st.text_area(f"Text for {llm1_name}", height=150, key="text1")
|
| 226 |
+
|
| 227 |
with input_col2:
|
| 228 |
st.write(f"{llm2_name} response")
|
| 229 |
upload_pdf_2 = st.file_uploader(f"Upload PDF for {llm2_name} response", type="pdf", key="pdf2")
|
|
|
|
| 240 |
summarization_similarity = calculate_summarization_similarity(text_input_1, text_input_2)
|
| 241 |
|
| 242 |
# Combine all metrics into a single similarity score
|
| 243 |
+
total_similarity = (paraphrasing_similarity * 0.6 + # High weight
|
| 244 |
+
direct_text_comparison_similarity * 0.3 + # Moderate weight
|
| 245 |
+
summarization_similarity * 0.1) # Low weight
|
| 246 |
|
| 247 |
# Update results table in session state
|
| 248 |
new_row = pd.Series({
|
|
|
|
| 253 |
"Summarization Similarity (%)": summarization_similarity,
|
| 254 |
"Combined Similarity (%)": total_similarity
|
| 255 |
})
|
| 256 |
+
|
| 257 |
st.session_state.results_df = pd.concat([st.session_state.results_df, new_row.to_frame().T], ignore_index=True)
|
| 258 |
|
| 259 |
# Add new data for radar chart
|
|
|
|
| 265 |
})
|
| 266 |
|
| 267 |
# Display metrics with large and bold text
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# Define a style for the combined score
|
| 271 |
+
combined_score_style = """
|
| 272 |
+
<style>
|
| 273 |
+
.combined-score {
|
| 274 |
+
font-size: 48px;
|
| 275 |
+
font-weight: bold;
|
| 276 |
+
color: #4CAF50; /* Green color for positive emphasis */
|
| 277 |
+
background-color: #f0f0f5;
|
| 278 |
+
padding: 20px;
|
| 279 |
+
border-radius: 15px;
|
| 280 |
+
text-align: center;
|
| 281 |
+
margin-top: 30px;
|
| 282 |
+
box-shadow: 2px 2px 12px rgba(0, 0, 0, 0.1);
|
| 283 |
+
}
|
| 284 |
+
</style>
|
| 285 |
+
"""
|
| 286 |
+
|
| 287 |
+
# Apply the style
|
| 288 |
+
st.markdown(combined_score_style, unsafe_allow_html=True)
|
| 289 |
+
|
| 290 |
+
# Display the combined similarity score
|
| 291 |
+
st.markdown(f'<div class="combined-score">Combined Similarity Score: {total_similarity:.2f}%</div>', unsafe_allow_html=True)
|
| 292 |
+
|
| 293 |
|
| 294 |
with col2:
|
| 295 |
+
|
|
|
|
| 296 |
# Display radar chart
|
| 297 |
if st.session_state.radar_chart_data:
|
| 298 |
st.subheader("Metrics Comparison")
|
| 299 |
+
st.markdown("*Larger area = More similarity of responses.*")
|
| 300 |
+
labels = ["Context similarity", "Words Similarity", "Summarization Similarity"]
|
| 301 |
num_vars = len(labels)
|
| 302 |
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
| 303 |
angles += angles[:1]
|
| 304 |
+
|
| 305 |
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
|
| 306 |
+
|
| 307 |
# Plot each response with a different color
|
| 308 |
color_palette = sns.color_palette("husl", len(st.session_state.radar_chart_data))
|
| 309 |
for idx, data in enumerate(st.session_state.radar_chart_data):
|
| 310 |
values = [
|
| 311 |
+
data["paraphrasing_similarity"],
|
| 312 |
+
data["direct_text_comparison_similarity"],
|
| 313 |
data["summarization_similarity"]
|
| 314 |
]
|
| 315 |
values += values[:1]
|
| 316 |
ax.fill(angles, values, color=color_palette[idx], alpha=0.25, label=data["name"])
|
| 317 |
ax.plot(angles, values, color=color_palette[idx], linewidth=2, linestyle='solid')
|
| 318 |
+
|
| 319 |
ax.set_yticklabels([])
|
| 320 |
ax.set_xticks(angles[:-1])
|
| 321 |
ax.set_xticklabels(labels)
|
|
|
|
| 325 |
|
| 326 |
# Display metrics sliders beside the radar chart
|
| 327 |
if st.session_state.radar_chart_data:
|
| 328 |
+
st.subheader("Similarity Factors")
|
| 329 |
+
st.markdown("*100 being the best case*")
|
| 330 |
+
slider_labels = {
|
| 331 |
+
"paraphrasing_similarity": "Context",
|
| 332 |
+
"direct_text_comparison_similarity": "Words",
|
| 333 |
+
"summarization_similarity": "Summary"
|
| 334 |
+
}
|
| 335 |
metrics = st.session_state.radar_chart_data[-1]
|
| 336 |
for metric_name in ["paraphrasing_similarity", "direct_text_comparison_similarity", "summarization_similarity"]:
|
| 337 |
st.slider(
|
| 338 |
+
slider_labels[metric_name],
|
| 339 |
+
0, 100,
|
| 340 |
+
int(metrics[metric_name]),
|
| 341 |
+
key=metric_name,
|
| 342 |
disabled=True, # Make the slider non-editable
|
| 343 |
format="%.0f" # Format the slider value to be an integer
|
| 344 |
)
|
|
|
|
| 350 |
st.write("## Detailed Results Table")
|
| 351 |
if not st.session_state.results_df.empty:
|
| 352 |
st.write(st.session_state.results_df)
|
| 353 |
+
|
| 354 |
# Download the results as a CSV file
|
| 355 |
csv_data = st.session_state.results_df.to_csv(index=False).encode('utf-8')
|
| 356 |
st.download_button(label="Download Results as CSV", data=csv_data, file_name='similarity_results.csv', mime='text/csv')
|
|
|
|
| 358 |
with actions_col:
|
| 359 |
if st.button("Reset Table"):
|
| 360 |
st.session_state.results_df = pd.DataFrame(columns=[
|
| 361 |
+
"LLM1", "LLM2",
|
| 362 |
+
"Paraphrasing Similarity (%)",
|
| 363 |
+
"Direct Text Comparison (%)",
|
| 364 |
+
"Summarization Similarity (%)",
|
| 365 |
"Combined Similarity (%)"
|
| 366 |
])
|
| 367 |
st.session_state.radar_chart_data = []
|
| 368 |
+
st.write("Results table has been reset.")
|
| 369 |
+
# Add an "About" button in the sidebar
|
| 370 |
+
if st.sidebar.button("About"):
|
| 371 |
+
st.sidebar.markdown("""
|
| 372 |
+
### About This App
|
| 373 |
+
This app compares text similarity between different responses from Language Models (LLMs).
|
| 374 |
+
It calculates various similarity metrics and provides a comprehensive comparison using a radar chart.
|
| 375 |
+
**Features:**
|
| 376 |
+
- Upload or input text for comparison.
|
| 377 |
+
- Calculate and display multiple similarity metrics.
|
| 378 |
+
- Visualize the results using a radar chart.
|
| 379 |
+
- Download the results as a CSV file.
|
| 380 |
+
**Similarity Metrics:**
|
| 381 |
+
1. **Paraphrasing Similarity**:
|
| 382 |
+
- Compares chunks of text from both LLM responses using embeddings generated by a pre-trained model.
|
| 383 |
+
- Calculates the average cosine similarity between the chunks.
|
| 384 |
+
2. **Direct Text Comparison**:
|
| 385 |
+
- Uses a combination of metrics:
|
| 386 |
+
- **Levenshtein Ratio**: Measures the similarity based on the minimum edit distance.
|
| 387 |
+
- **Jaccard Similarity**: Compares the overlap of unique words.
|
| 388 |
+
- **TF-IDF Cosine Similarity**: Compares the text using TF-IDF vectorization.
|
| 389 |
+
- **BLEU Score**: Evaluates the overlap of n-grams.
|
| 390 |
+
- **ROUGE-L Score**: Measures the longest matching sequence of words.
|
| 391 |
+
- **BERTScore**: Uses BERT embeddings to compare sentence similarity.
|
| 392 |
+
3. **Summarization Similarity**:
|
| 393 |
+
- Uses the Word Mover's Distance (WMD) to compare the semantic distance between the summaries of the texts.
|
| 394 |
+
4. **Combined Similarity**:
|
| 395 |
+
- A weighted average of the above metrics to provide an overall similarity score.
|
| 396 |
+
**Developed with:**
|
| 397 |
+
- Streamlit
|
| 398 |
+
- Sentence Transformers
|
| 399 |
+
- SpaCy
|
| 400 |
+
- Scikit-learn
|
| 401 |
+
- NLTK
|
| 402 |
+
- Gensim
|
| 403 |
+
""")
|