Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -347,12 +347,16 @@ def search_embeddings(chunks, embedding_model, vector_store_type, search_type, q
|
|
| 347 |
results = retriever.invoke(preprocessed_query)
|
| 348 |
|
| 349 |
def score_result(doc):
|
| 350 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
if apply_phonetic:
|
| 352 |
phonetic_score = phonetic_match(doc.page_content, query)
|
| 353 |
-
return (1 - phonetic_weight) *
|
| 354 |
else:
|
| 355 |
-
return
|
| 356 |
|
| 357 |
results = sorted(results, key=score_result, reverse=True)
|
| 358 |
end_time = time.time()
|
|
@@ -378,6 +382,7 @@ def search_embeddings(chunks, embedding_model, vector_store_type, search_type, q
|
|
| 378 |
# Evaluation Metrics
|
| 379 |
# ... (previous code remains the same)
|
| 380 |
|
|
|
|
| 381 |
def calculate_statistics(results, search_time, vector_store, num_tokens, embedding_model, query, top_k, expected_result=None):
|
| 382 |
stats = {
|
| 383 |
"num_results": len(results),
|
|
@@ -385,14 +390,34 @@ def calculate_statistics(results, search_time, vector_store, num_tokens, embeddi
|
|
| 385 |
"min_content_length": min([len(doc.page_content) for doc in results]) if results else 0,
|
| 386 |
"max_content_length": max([len(doc.page_content) for doc in results]) if results else 0,
|
| 387 |
"search_time": search_time,
|
| 388 |
-
"vector_store_size": vector_store._index.ntotal if hasattr(vector_store, '_index') else "N/A",
|
| 389 |
-
"num_documents": len(vector_store.docstore._dict),
|
| 390 |
"num_tokens": num_tokens,
|
| 391 |
-
"embedding_vocab_size": embedding_model.client.get_vocab_size() if hasattr(embedding_model, 'client') and hasattr(embedding_model.client, 'get_vocab_size') else "N/A",
|
| 392 |
"embedding_dimension": len(embedding_model.embed_query(query)),
|
| 393 |
"top_k": top_k,
|
| 394 |
}
|
| 395 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
if expected_result:
|
| 397 |
stats["contains_expected"] = any(expected_result in doc.page_content for doc in results)
|
| 398 |
stats["expected_result_rank"] = next((i for i, doc in enumerate(results) if expected_result in doc.page_content), -1) + 1
|
|
@@ -419,35 +444,55 @@ def calculate_statistics(results, search_time, vector_store, num_tokens, embeddi
|
|
| 419 |
return stats
|
| 420 |
# Visualization
|
| 421 |
def visualize_results(results_df, stats_df):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
fig, axs = plt.subplots(2, 2, figsize=(20, 20))
|
| 423 |
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
axs[0, 0].set_xticklabels(axs[0, 0].get_xticklabels(), rotation=45, ha='right')
|
| 429 |
|
| 430 |
-
|
| 431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
|
|
|
|
|
|
| 437 |
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
|
| 448 |
plt.tight_layout()
|
| 449 |
return fig
|
| 450 |
-
|
| 451 |
def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
|
| 452 |
tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
|
| 453 |
|
|
@@ -465,8 +510,15 @@ def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
|
|
| 465 |
|
| 466 |
# New postprocessing function
|
| 467 |
def rerank_results(results, query, reranker):
|
| 468 |
-
|
| 469 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
# Main Comparison Function
|
| 472 |
def compare_embeddings(file, query, embedding_models, custom_embedding_model, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, expected_result=None, lang='german', apply_preprocessing=True, optimize_vocab=False, apply_phonetic=True, phonetic_weight=0.3, custom_tokenizer_file=None, custom_tokenizer_model=None, custom_tokenizer_vocab_size=10000, custom_tokenizer_special_tokens=None, use_query_optimization=False, query_optimization_model="google/flan-t5-base", use_reranking=False):
|
|
|
|
| 347 |
results = retriever.invoke(preprocessed_query)
|
| 348 |
|
| 349 |
def score_result(doc):
|
| 350 |
+
base_score = vector_store.similarity_search_with_score(doc.page_content, k=1)[0][1]
|
| 351 |
+
|
| 352 |
+
# Add bonus for containing expected result
|
| 353 |
+
expected_bonus = 0.3 if expected_result and expected_result in doc.page_content else 0
|
| 354 |
+
|
| 355 |
if apply_phonetic:
|
| 356 |
phonetic_score = phonetic_match(doc.page_content, query)
|
| 357 |
+
return (1 - phonetic_weight) * base_score + phonetic_weight * phonetic_score + expected_bonus
|
| 358 |
else:
|
| 359 |
+
return base_score + expected_bonus
|
| 360 |
|
| 361 |
results = sorted(results, key=score_result, reverse=True)
|
| 362 |
end_time = time.time()
|
|
|
|
| 382 |
# Evaluation Metrics
|
| 383 |
# ... (previous code remains the same)
|
| 384 |
|
| 385 |
+
def calculate_statistics(results, search_time, vector_store, num_tokens, embedding_model, query, top_k, expected_result=None):
|
| 386 |
def calculate_statistics(results, search_time, vector_store, num_tokens, embedding_model, query, top_k, expected_result=None):
|
| 387 |
stats = {
|
| 388 |
"num_results": len(results),
|
|
|
|
| 390 |
"min_content_length": min([len(doc.page_content) for doc in results]) if results else 0,
|
| 391 |
"max_content_length": max([len(doc.page_content) for doc in results]) if results else 0,
|
| 392 |
"search_time": search_time,
|
|
|
|
|
|
|
| 393 |
"num_tokens": num_tokens,
|
|
|
|
| 394 |
"embedding_dimension": len(embedding_model.embed_query(query)),
|
| 395 |
"top_k": top_k,
|
| 396 |
}
|
| 397 |
|
| 398 |
+
# Safely get vector store size
|
| 399 |
+
try:
|
| 400 |
+
if hasattr(vector_store, '_index'):
|
| 401 |
+
stats["vector_store_size"] = vector_store._index.ntotal
|
| 402 |
+
elif hasattr(vector_store, '_collection'):
|
| 403 |
+
stats["vector_store_size"] = len(vector_store._collection.get())
|
| 404 |
+
else:
|
| 405 |
+
stats["vector_store_size"] = "N/A"
|
| 406 |
+
except:
|
| 407 |
+
stats["vector_store_size"] = "N/A"
|
| 408 |
+
|
| 409 |
+
# Safely get document count
|
| 410 |
+
try:
|
| 411 |
+
if hasattr(vector_store, 'docstore'):
|
| 412 |
+
stats["num_documents"] = len(vector_store.docstore._dict)
|
| 413 |
+
elif hasattr(vector_store, '_collection'):
|
| 414 |
+
stats["num_documents"] = len(vector_store._collection.get())
|
| 415 |
+
else:
|
| 416 |
+
stats["num_documents"] = len(results)
|
| 417 |
+
except:
|
| 418 |
+
stats["num_documents"] = len(results)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
if expected_result:
|
| 422 |
stats["contains_expected"] = any(expected_result in doc.page_content for doc in results)
|
| 423 |
stats["expected_result_rank"] = next((i for i, doc in enumerate(results) if expected_result in doc.page_content), -1) + 1
|
|
|
|
| 444 |
return stats
|
| 445 |
# Visualization
|
| 446 |
def visualize_results(results_df, stats_df):
|
| 447 |
+
# Add model column if not present
|
| 448 |
+
if 'model' not in stats_df.columns:
|
| 449 |
+
stats_df['model'] = stats_df['model_type'] + ' - ' + stats_df['model_name']
|
| 450 |
+
|
| 451 |
fig, axs = plt.subplots(2, 2, figsize=(20, 20))
|
| 452 |
|
| 453 |
+
# Handle empty dataframe case
|
| 454 |
+
if len(stats_df) == 0:
|
| 455 |
+
return fig
|
|
|
|
|
|
|
| 456 |
|
| 457 |
+
# Create plots with error handling
|
| 458 |
+
try:
|
| 459 |
+
sns.barplot(data=stats_df, x='model', y='search_time', ax=axs[0, 0])
|
| 460 |
+
axs[0, 0].set_title('Search Time by Model')
|
| 461 |
+
axs[0, 0].tick_params(axis='x', rotation=45)
|
| 462 |
+
except Exception as e:
|
| 463 |
+
print(f"Error in search time plot: {e}")
|
| 464 |
|
| 465 |
+
try:
|
| 466 |
+
sns.scatterplot(data=stats_df, x='result_diversity', y='rank_correlation',
|
| 467 |
+
hue='model', ax=axs[0, 1])
|
| 468 |
+
axs[0, 1].set_title('Result Diversity vs. Rank Correlation')
|
| 469 |
+
except Exception as e:
|
| 470 |
+
print(f"Error in diversity plot: {e}")
|
| 471 |
|
| 472 |
+
try:
|
| 473 |
+
sns.boxplot(data=stats_df, x='model', y='avg_content_length', ax=axs[1, 0])
|
| 474 |
+
axs[1, 0].set_title('Distribution of Result Content Lengths')
|
| 475 |
+
axs[1, 0].tick_params(axis='x', rotation=45)
|
| 476 |
+
except Exception as e:
|
| 477 |
+
print(f"Error in content length plot: {e}")
|
| 478 |
+
|
| 479 |
+
try:
|
| 480 |
+
valid_embeddings = results_df['embedding'].dropna().values
|
| 481 |
+
if len(valid_embeddings) > 1:
|
| 482 |
+
tsne = TSNE(n_components=2, random_state=42)
|
| 483 |
+
embeddings_2d = tsne.fit_transform(np.vstack(valid_embeddings))
|
| 484 |
+
sns.scatterplot(x=embeddings_2d[:, 0], y=embeddings_2d[:, 1],
|
| 485 |
+
hue=results_df['Model'][:len(valid_embeddings)],
|
| 486 |
+
ax=axs[1, 1])
|
| 487 |
+
axs[1, 1].set_title('t-SNE Visualization of Result Embeddings')
|
| 488 |
+
else:
|
| 489 |
+
axs[1, 1].text(0.5, 0.5, "Not enough embeddings for visualization",
|
| 490 |
+
ha='center', va='center')
|
| 491 |
+
except Exception as e:
|
| 492 |
+
print(f"Error in embedding visualization: {e}")
|
| 493 |
|
| 494 |
plt.tight_layout()
|
| 495 |
return fig
|
|
|
|
| 496 |
def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
|
| 497 |
tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
|
| 498 |
|
|
|
|
| 510 |
|
| 511 |
# New postprocessing function
|
| 512 |
def rerank_results(results, query, reranker):
|
| 513 |
+
if not hasattr(reranker, 'rerank'):
|
| 514 |
+
# For TextClassificationPipeline
|
| 515 |
+
pairs = [[query, doc.page_content] for doc in results]
|
| 516 |
+
scores = [pred['score'] for pred in reranker(pairs, function_to_apply='cross_entropy')]
|
| 517 |
+
reranked_idx = np.argsort(scores)[::-1]
|
| 518 |
+
return [results[i] for i in reranked_idx]
|
| 519 |
+
else:
|
| 520 |
+
# For models with rerank method
|
| 521 |
+
return reranker.rerank(query, [doc.page_content for doc in results])
|
| 522 |
|
| 523 |
# Main Comparison Function
|
| 524 |
def compare_embeddings(file, query, embedding_models, custom_embedding_model, split_strategy, chunk_size, overlap_size, custom_separators, vector_store_type, search_type, top_k, expected_result=None, lang='german', apply_preprocessing=True, optimize_vocab=False, apply_phonetic=True, phonetic_weight=0.3, custom_tokenizer_file=None, custom_tokenizer_model=None, custom_tokenizer_vocab_size=10000, custom_tokenizer_special_tokens=None, use_query_optimization=False, query_optimization_model="google/flan-t5-base", use_reranking=False):
|