Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -531,6 +531,111 @@ def format_results(results, stats):
|
|
| 531 |
formatted_results.append(result)
|
| 532 |
return formatted_results
|
| 533 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 534 |
# Gradio Interface
|
| 535 |
def launch_interface(share=True):
|
| 536 |
with gr.Blocks() as iface:
|
|
@@ -592,6 +697,51 @@ def launch_interface(share=True):
|
|
| 592 |
outputs=[results_output, stats_output, plot_output]
|
| 593 |
)
|
| 594 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
|
| 596 |
tutorial_md = """
|
| 597 |
# Advanced Embedding Comparison Tool Tutorial
|
|
@@ -618,5 +768,33 @@ def launch_interface(share=True):
|
|
| 618 |
|
| 619 |
iface.launch(share=share)
|
| 620 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
if __name__ == "__main__":
|
| 622 |
launch_interface()
|
|
|
|
| 531 |
formatted_results.append(result)
|
| 532 |
return formatted_results
|
| 533 |
|
| 534 |
+
|
| 535 |
+
#####
|
| 536 |
+
from sklearn.model_selection import ParameterGrid
|
| 537 |
+
from tqdm import tqdm
|
| 538 |
+
|
| 539 |
+
# ... (previous code remains the same)
|
| 540 |
+
|
| 541 |
+
# New function for automated testing
|
| 542 |
+
def automated_testing(file, query, test_params):
    """Run a grid search over embedding/search configurations.

    Parameters
    ----------
    file : file-like or None
        Uploaded file object; ``file.name`` is forwarded to ``process_files``.
        ``None`` means no document upload.
    query : str
        Search query evaluated against every configuration. Never mutated,
        so each grid point starts from the user's original query.
    test_params : dict[str, list]
        Mapping of parameter name -> list of candidate values, expanded with
        ``sklearn.model_selection.ParameterGrid``.

    Returns
    -------
    tuple[pandas.DataFrame, pandas.DataFrame]
        Formatted per-result rows and per-configuration statistics.
    """
    all_results = []
    all_stats = []

    param_grid = ParameterGrid(test_params)

    # Lazily created once: loading the cross-encoder per iteration is expensive.
    reranker = None

    for params in tqdm(param_grid, desc="Running tests"):
        chunks, embedding_model, num_tokens = process_files(
            file.name if file else None,
            params['model_type'],
            params['model_name'],
            params['split_strategy'],
            params['chunk_size'],
            params['overlap_size'],
            params.get('custom_separators', None),
            params['lang'],
            params['apply_preprocessing'],
            params.get('custom_tokenizer_file', None),
            params.get('custom_tokenizer_model', None),
            params.get('custom_tokenizer_vocab_size', 10000),
            params.get('custom_tokenizer_special_tokens', None)
        )

        if params['optimize_vocab']:
            tokenizer, optimized_chunks = optimize_vocabulary(chunks)
            chunks = optimized_chunks

        # BUG FIX: the original reassigned `query` itself, so after the first
        # grid point with query optimization enabled, the expanded query leaked
        # into every later iteration. Work on a per-iteration copy instead.
        current_query = query
        if params['use_query_optimization']:
            optimized_queries = optimize_query(current_query, params['query_optimization_model'])
            current_query = " ".join(optimized_queries)

        results, search_time, vector_store, results_raw = search_embeddings(
            chunks,
            embedding_model,
            params['vector_store_type'],
            params['search_type'],
            current_query,
            params['top_k'],
            params['lang'],
            params['apply_phonetic'],
            params['phonetic_weight']
        )

        if params['use_reranking']:
            if reranker is None:
                reranker = pipeline("text-classification", model="cross-encoder/ms-marco-MiniLM-L-12-v2")
            results_raw = rerank_results(results_raw, current_query, reranker)

        stats = calculate_statistics(results_raw, search_time, vector_store, num_tokens, embedding_model, current_query, params['top_k'])
        stats["model"] = f"{params['model_type']} - {params['model_name']}"
        stats.update(params)

        all_results.extend(format_results(results_raw, stats))
        all_stats.append(stats)

    return pd.DataFrame(all_results), pd.DataFrame(all_stats)
|
| 597 |
+
|
| 598 |
+
# Function to analyze results and propose best model and settings
|
| 599 |
+
def analyze_results(stats_df):
    """Score every tested configuration and recommend the best one.

    Adds a ``weighted_score`` column to *stats_df* in place, picks the row
    with the highest score, and returns a recommendations dict with the
    winning model, its settings, and a summary of its metrics.

    Parameters
    ----------
    stats_df : pandas.DataFrame
        One row per tested configuration; must contain the metric columns
        below plus the configuration columns referenced in the result.

    Returns
    -------
    dict
        Keys: ``best_model``, ``best_settings``, ``performance_summary``.
    """
    # Weighted scoring: search_time counts against a configuration,
    # the three quality metrics count in its favour. (Adjust as needed.)
    metric_weights = {
        'search_time': -0.3,  # Lower is better
        'result_diversity': 0.2,
        'rank_correlation': 0.3,
        'silhouette_score': 0.2
    }

    # Accumulate the weighted sum column by column.
    score = None
    for metric, weight in metric_weights.items():
        contribution = stats_df[metric] * weight
        score = contribution if score is None else score + contribution
    stats_df['weighted_score'] = score

    # Row with the highest weighted score wins.
    best_config = stats_df.loc[stats_df['weighted_score'].idxmax()]

    setting_keys = (
        'split_strategy', 'chunk_size', 'overlap_size',
        'vector_store_type', 'search_type', 'top_k',
        'optimize_vocab', 'use_query_optimization', 'use_reranking',
    )
    summary_keys = (
        'search_time', 'result_diversity',
        'rank_correlation', 'silhouette_score',
    )

    return {
        'best_model': f"{best_config['model_type']} - {best_config['model_name']}",
        'best_settings': {key: best_config[key] for key in setting_keys},
        'performance_summary': {key: best_config[key] for key in summary_keys},
    }
|
| 637 |
+
####
|
| 638 |
+
|
| 639 |
# Gradio Interface
|
| 640 |
def launch_interface(share=True):
|
| 641 |
with gr.Blocks() as iface:
|
|
|
|
| 697 |
outputs=[results_output, stats_output, plot_output]
|
| 698 |
)
|
| 699 |
|
| 700 |
+
####
|
| 701 |
+
with gr.Tab("Automated"):
|
| 702 |
+
auto_file_input = gr.File(label="Upload File (Optional)")
|
| 703 |
+
auto_query_input = gr.Textbox(label="Search Query")
|
| 704 |
+
auto_model_types = gr.CheckboxGroup(
|
| 705 |
+
choices=["HuggingFace", "OpenAI", "Cohere"],
|
| 706 |
+
label="Model Types to Test"
|
| 707 |
+
)
|
| 708 |
+
auto_model_names = gr.TextArea(label="Model Names to Test (comma-separated)")
|
| 709 |
+
auto_split_strategies = gr.CheckboxGroup(
|
| 710 |
+
choices=["token", "recursive"],
|
| 711 |
+
label="Split Strategies to Test"
|
| 712 |
+
)
|
| 713 |
+
auto_chunk_sizes = gr.TextArea(label="Chunk Sizes to Test (comma-separated)")
|
| 714 |
+
auto_overlap_sizes = gr.TextArea(label="Overlap Sizes to Test (comma-separated)")
|
| 715 |
+
auto_vector_store_types = gr.CheckboxGroup(
|
| 716 |
+
choices=["FAISS", "Chroma"],
|
| 717 |
+
label="Vector Store Types to Test"
|
| 718 |
+
)
|
| 719 |
+
auto_search_types = gr.CheckboxGroup(
|
| 720 |
+
choices=["similarity", "mmr", "custom"],
|
| 721 |
+
label="Search Types to Test"
|
| 722 |
+
)
|
| 723 |
+
auto_top_k = gr.TextArea(label="Top K Values to Test (comma-separated)")
|
| 724 |
+
auto_optimize_vocab = gr.Checkbox(label="Test Vocabulary Optimization", value=True)
|
| 725 |
+
auto_use_query_optimization = gr.Checkbox(label="Test Query Optimization", value=True)
|
| 726 |
+
auto_use_reranking = gr.Checkbox(label="Test Reranking", value=True)
|
| 727 |
+
|
| 728 |
+
auto_results_output = gr.Dataframe(label="Automated Test Results", interactive=False)
|
| 729 |
+
auto_stats_output = gr.Dataframe(label="Automated Test Statistics", interactive=False)
|
| 730 |
+
recommendations_output = gr.JSON(label="Recommendations")
|
| 731 |
+
|
| 732 |
+
auto_submit_button = gr.Button("Run Automated Tests")
|
| 733 |
+
auto_submit_button.click(
|
| 734 |
+
fn=lambda *args: run_automated_tests_and_analyze(*args),
|
| 735 |
+
inputs=[
|
| 736 |
+
auto_file_input, auto_query_input, auto_model_types, auto_model_names,
|
| 737 |
+
auto_split_strategies, auto_chunk_sizes, auto_overlap_sizes,
|
| 738 |
+
auto_vector_store_types, auto_search_types, auto_top_k,
|
| 739 |
+
auto_optimize_vocab, auto_use_query_optimization, auto_use_reranking
|
| 740 |
+
],
|
| 741 |
+
outputs=[auto_results_output, auto_stats_output, recommendations_output]
|
| 742 |
+
)
|
| 743 |
+
###
|
| 744 |
+
|
| 745 |
|
| 746 |
tutorial_md = """
|
| 747 |
# Advanced Embedding Comparison Tool Tutorial
|
|
|
|
| 768 |
|
| 769 |
iface.launch(share=share)
|
| 770 |
|
| 771 |
+
def run_automated_tests_and_analyze(*args):
    """Glue between the Gradio "Automated" tab and the grid-search backend.

    Unpacks the raw widget values, turns the comma-separated text inputs
    into lists, builds the parameter grid, runs ``automated_testing`` and
    ``analyze_results``, and returns the three outputs the UI displays.

    Returns
    -------
    tuple
        ``(results_df, stats_df, recommendations)``.
    """
    (file, query, model_types, model_names, split_strategies, chunk_sizes,
     overlap_sizes, vector_store_types, search_types, top_k_values,
     optimize_vocab, use_query_optimization, use_reranking) = args

    def _int_list(csv_text):
        # "100, 200" -> [100, 200]
        return [int(item.strip()) for item in csv_text.split(',')]

    def _str_list(csv_text):
        return [item.strip() for item in csv_text.split(',')]

    test_params = {
        'model_type': model_types,
        'model_name': _str_list(model_names),
        'split_strategy': split_strategies,
        'chunk_size': _int_list(chunk_sizes),
        'overlap_size': _int_list(overlap_sizes),
        'vector_store_type': vector_store_types,
        'search_type': search_types,
        'top_k': _int_list(top_k_values),
        'lang': ['german'],  # You can expand this if needed
        'apply_preprocessing': [True],
        'optimize_vocab': [optimize_vocab],
        'apply_phonetic': [True],
        'phonetic_weight': [0.3],
        'use_query_optimization': [use_query_optimization],
        'query_optimization_model': ['google/flan-t5-base'],
        'use_reranking': [use_reranking],
    }

    results_df, stats_df = automated_testing(file, query, test_params)
    return results_df, stats_df, analyze_results(stats_df)
|
| 798 |
+
|
| 799 |
if __name__ == "__main__":
|
| 800 |
launch_interface()
|