Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import json | |
| # Import your modules here | |
| from Agents.togetherAIAgent import generate_article_from_query | |
| from Agents.wikiAgent import get_wiki_data | |
| from Agents.rankerAgent import rankerAgent | |
| from Query_Modification.QueryModification import query_Modifier, getKeywords | |
| from Ranking.RRF.RRF_implementation import reciprocal_rank_fusion_three, reciprocal_rank_fusion_six | |
| from Retrieval.tf_idf import tf_idf_pipeline | |
| from Retrieval.bm25 import bm25_pipeline | |
| from Retrieval.vision import vision_pipeline | |
| from Retrieval.openSource import open_source_pipeline | |
| from Baseline.boolean import boolean_pipeline | |
| from AnswerGeneration.getAnswer import generate_answer_withContext, generate_answer_zeroShot | |
# Load the mini Wikipedia collection once at import time.
# A context manager closes the file handle deterministically (the previous
# bare open() call leaked it); JSON is read as UTF-8 regardless of the
# platform's default locale encoding.
with open('Datasets/mini_wiki_collection.json', 'r', encoding='utf-8') as collection_file:
    miniWikiCollection = json.load(collection_file)

# Map each article's 'wikipedia_id' to its text, joining the pieces stored
# under 'text' with single spaces so every document is one string.
miniWikiCollectionDict = {
    wiki['wikipedia_id']: " ".join(wiki['text']) for wiki in miniWikiCollection
}
# Fallback contexts handed to the answer generator when retrieval yields nothing.
_NO_WIKI_CONTEXT = "Can't find a Wiki article for this query."
_NO_DOCUMENT_CONTEXT = "Can't find a matching document for this query."


def _top_document_context(ranking, stringify_id=False):
    """Return the text of the first document listed in ``ranking``.

    Args:
        ranking: Indexable collection of document ids; index 0 is used as
            the top hit.
        stringify_id: When true, the id is converted with ``str()`` before
            the lookup (the BM25 rankings are looked up this way).

    Returns:
        The matching text from ``miniWikiCollectionDict``, or a fallback
        message when the ranking is empty/``None`` or the id is unknown.
    """
    try:
        doc_id = ranking[0]
        return miniWikiCollectionDict[str(doc_id) if stringify_id else doc_id]
    except (IndexError, KeyError, TypeError):
        return _NO_DOCUMENT_CONTEXT


def process_query(query):
    """Answer ``query`` with every strategy and let the ranker pick the best.

    The query is rewritten once, both versions are run through the boolean,
    TF-IDF, BM25, vision and open-source retrievers, and the TF-IDF/BM25/
    open-source rankings are additionally fused with reciprocal rank fusion.
    The top document of each ranking (or a fallback message when there is
    none) becomes the context for answer generation. A Wiki-agent context,
    a generated article and a zero-shot answer complete the candidate set,
    which is then passed to ``rankerAgent``.

    Args:
        query: The user's question.

    Returns:
        A flat tuple of 34 values: best model, best answer, then an
        (answer, context) pair for agent1, agent2, boolean, tf_idf, bm25,
        vision, open_source, their five "modified" counterparts, the three
        RRF variants, and finally the zero-shot answer with a placeholder
        context string.
    """
    # Query modification
    modified_query = query_Modifier(query)

    # Context generation
    article = generate_article_from_query(query)

    # Keyword extraction and context lookup on Wiki
    wiki_data = get_wiki_data(getKeywords(query))

    # Rankings for the original query
    boolean_ranking = boolean_pipeline(query)
    tf_idf_ranking = tf_idf_pipeline(query)
    bm25_ranking = bm25_pipeline(query)
    vision_ranking = vision_pipeline(query)
    open_source_ranking = open_source_pipeline(query)

    # Rankings for the modified query
    boolean_ranking_modified = boolean_pipeline(modified_query)
    tf_idf_ranking_modified = tf_idf_pipeline(modified_query)
    bm25_ranking_modified = bm25_pipeline(modified_query)
    vision_ranking_modified = vision_pipeline(modified_query)
    open_source_ranking_modified = open_source_pipeline(modified_query)

    # Reciprocal rank fusion over TF-IDF, BM25 and the open-source retriever
    rrf_ranking = reciprocal_rank_fusion_three(
        tf_idf_ranking, bm25_ranking, open_source_ranking
    )
    rrf_ranking_modified = reciprocal_rank_fusion_three(
        tf_idf_ranking_modified, bm25_ranking_modified, open_source_ranking_modified
    )
    rrf_ranking_combined = reciprocal_rank_fusion_six(
        tf_idf_ranking, bm25_ranking, open_source_ranking,
        tf_idf_ranking_modified, bm25_ranking_modified, open_source_ranking_modified
    )

    try:
        agent1_context = wiki_data[0]
    except (IndexError, KeyError, TypeError):
        agent1_context = _NO_WIKI_CONTEXT

    # One entry per candidate: (ranker key, question to ask, context).
    # The order drives the answer-generation order, the key order of the
    # ranker input and the order of the returned tuple, so it must stay in
    # sync with the output components wired up in the UI.
    runs = [
        ("agent1", query, agent1_context),
        ("agent2", query, article),
        ("boolean", query, _top_document_context(boolean_ranking)),
        ("tf_idf", query, _top_document_context(tf_idf_ranking)),
        # BM25 ids are stringified before the lookup, unlike the other retrievers.
        ("bm25", query, _top_document_context(bm25_ranking, stringify_id=True)),
        ("vision", query, _top_document_context(vision_ranking)),
        ("open_source", query, _top_document_context(open_source_ranking)),
        ("boolean_modified", modified_query,
         _top_document_context(boolean_ranking_modified)),
        ("tf_idf_modified", modified_query,
         _top_document_context(tf_idf_ranking_modified)),
        ("bm25_modified", modified_query,
         _top_document_context(bm25_ranking_modified, stringify_id=True)),
        ("vision_modified", modified_query,
         _top_document_context(vision_ranking_modified)),
        ("open_source_modified", modified_query,
         _top_document_context(open_source_ranking_modified)),
        ("tf_idf_bm25_open_RRF_Ranking", query,
         _top_document_context(rrf_ranking)),
        ("tf_idf_bm25_open_RRF_Ranking_modified", modified_query,
         _top_document_context(rrf_ranking_modified)),
        # The combined fusion mixes both query variants; it is answered
        # against the original query.
        ("tf_idf_bm25_open_RRF_Ranking_combined", query,
         _top_document_context(rrf_ranking_combined)),
    ]

    # Generating answers
    answers = {
        key: generate_answer_withContext(question, context)
        for key, question, context in runs
    }
    zeroShot = generate_answer_zeroShot(query)

    # Ranking the best answer
    rankerAgentInput = {"query": query, **answers, "zeroShot": zeroShot}
    best_model, best_answer = rankerAgent(rankerAgentInput)

    outputs = [best_model, best_answer]
    for key, _question, context in runs:
        outputs.extend((answers[key], context))
    outputs.extend((zeroShot, "Zero-shot doesn't have a context."))
    return tuple(outputs)
# Custom stylesheet for the Gradio UI. The id selectors below match the
# elem_id values given to components in create_interface ("fancy-column",
# "query-input", "submit-button", "best-model-output", "best-answer-output").
# It is attached to the Blocks object in the __main__ guard before launch.
css = """
#fancy-column {
background: linear-gradient(135deg, #1a242f, #2b3a44); /* Dark blue-gray gradient background */
padding: 20px;
border-radius: 15px;
}
#query-input, #submit-button, #best-model-output, #best-answer-output {
border-radius: 10px; /* Rounded corners */
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.3); /* Darker shadow for better contrast */
background-color: #34495e; /* Dark background for inputs */
color: #ecf0f1; /* Light text for good readability */
}
#query-input:focus, #submit-button:focus, #best-model-output:focus, #best-answer-output:focus {
outline: none;
border: 2px solid #7f8c8d; /* Subtle accent border on focus */
}
#submit-button {
background-color: #16a085; /* Muted teal color for button */
color: #ecf0f1; /* Light text for button */
font-weight: bold;
padding: 10px;
}
#submit-button:hover {
background-color: #1abc9c; /* Slightly lighter teal on hover */
}
#best-model-output, #best-answer-output {
background-color: #2c3e50; /* Darker background for output boxes */
}
#best-model-output label, #best-answer-output label, #query-input label {
color: #ecf0f1; /* Light text for labels */
}
"""
| # Interface creation | |
def create_interface():
    """Build the Gradio Blocks UI and wire the submit button to process_query.

    The layout has a query box with a submit button, a "Best Model" /
    "Best Answer" row, and one (answer, context) row per answering strategy.
    The rows are created in the same order in which process_query returns
    its values, because Gradio assigns return values to outputs by position.

    Returns:
        The constructed ``gr.Blocks`` instance; the caller launches it.
    """
    # Internal row names -> names shown in the UI. Rows not listed here are
    # displayed under their internal name.
    display_labels = {
        "Agent 1": "Wiki Search",
        "Agent 2": "Llama Context Generation",
        # Was keyed as "Open Source Answer", which never matched the row
        # name actually used, so this row kept its internal name.
        "Open Source": "MiniLM Text Embedding model",
        "Open Source (Modified)": "MiniLM Text Embedding model (Modified)",
        "TF-IDF + BM25 + Open RRF": "RRF (TF-IDF + BM25 + MiniLM)",
        "TF-IDF + BM25 + Open RRF (Modified)": "RRF (TF-IDF + BM25 + MiniLM) (Modified)",
        "TF-IDF + BM25 + Open RRF (Combined)": "RRF (TF-IDF + BM25 + MiniLM) (Combined)",
    }

    # Must stay in the order of the (answer, context) pairs returned by
    # process_query.
    row_names = [
        "Agent 1",
        "Agent 2",
        "Boolean",
        "TF-IDF",
        "BM25",
        "Vision",
        "Open Source",
        "Boolean (Modified)",
        "TF-IDF (Modified)",
        "BM25 (Modified)",
        "Vision (Modified)",
        "Open Source (Modified)",
        "TF-IDF + BM25 + Open RRF",
        "TF-IDF + BM25 + Open RRF (Modified)",
        "TF-IDF + BM25 + Open RRF (Combined)",
        "Zero Shot",
    ]

    def create_answer_row(label):
        """Create one row holding an answer textbox and a context textbox."""
        label = display_labels.get(label, label)
        with gr.Row():
            # Gradio expects integer scale values; 2:3 keeps the former
            # 1.2:1.8 width ratio. The elem_ids reuse the styled ids so the
            # rows pick up the same CSS rules as the "best" boxes.
            answer_textbox = gr.Textbox(label=f"{label} Answer", interactive=False, scale=2, elem_id="best-model-output")
            context_textbox = gr.Textbox(label=f"{label} Context", scale=3, elem_id="best-answer-output")
        return answer_textbox, context_textbox

    # NOTE(review): block nesting was reconstructed from de-indented source;
    # confirm the per-strategy column belongs beside (not inside) the
    # styled column.
    with gr.Blocks() as interface:
        with gr.Column(elem_id="fancy-column", scale=3):  # Fancy column with extra styling
            with gr.Row():
                query_input = gr.Textbox(label="Enter your query", scale=3, elem_id="query-input")
                submit_button = gr.Button("Submit", scale=1, elem_id="submit-button")
            # Equal integer scales keep the former 1.5:1.5 split.
            with gr.Row():
                best_model_output = gr.Textbox(label="Best Model", interactive=False, scale=1, elem_id="best-model-output")
                best_answer_output = gr.Textbox(label="Best Answer", interactive=False, scale=1, elem_id="best-answer-output")

        with gr.Column():
            row_outputs = []
            for row_name in row_names:
                row_outputs.extend(create_answer_row(row_name))

        submit_button.click(
            fn=process_query,
            inputs=query_input,
            outputs=[best_model_output, best_answer_output, *row_outputs],
        )
    return interface
# Script entry point: build the UI, attach the stylesheet, start the server.
if __name__ == "__main__":
    app = create_interface()
    app.css = css
    app.launch()