import os

import faiss
import gradio as gr
import numpy as np
import pandas as pd
from FlagEmbedding import BGEM3FlagModel
# Pre-trained BGE-M3 embedding model (loaded with use_fp16=True); used both to
# build the index and to encode search queries.
model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)

# Persisted FAISS index. Vector i of the index must correspond to row i (by
# position) of the filtered `df` below, because search hits are mapped back to
# rows with `df.iloc`.
INDEX_FILE_PATH = 'vector_store_bge_m3.index'

# Load the JSON data into a DataFrame.
df = pd.read_json('White-Stride-Red-68.json')
# NOTE(review): astype(str) runs before fillna, so missing values have already
# become the literal strings 'nan' / 'None' and fillna('') never fires; such
# rows are therefore NOT removed by the filter below. The ordering is kept on
# purpose: the existing index file was presumably built from exactly these rows
# (the build code below encodes this filtered df), and changing the filter would
# shift row positions out of alignment with it. If those rows should be dropped,
# change the filter AND delete the index file so it is rebuilt.
df['embeding_context'] = df['embeding_context'].astype(str).fillna('')
# Drop rows whose embedding text is the empty string.
df = df[df['embeding_context'] != '']


def _build_index(frame, index_path):
    """Encode ``frame['embeding_context']`` with ``model``, build an exact L2
    FAISS index over the dense vectors, save it to ``index_path`` and return it.

    Vectors are added in the row order of ``frame``, so index id i refers to
    ``frame.iloc[i]``.

    Raises:
        ValueError: if ``frame`` has no rows (there is nothing to index).
    """
    if frame.empty:
        raise ValueError("Cannot build a FAISS index: no rows with 'embeding_context' to encode.")
    contexts = frame['embeding_context'].tolist()
    dense_vecs = model.encode(contexts, batch_size=12, max_length=1024)['dense_vecs']
    vectors = np.asarray(dense_vecs, dtype='float32')
    built = faiss.IndexFlatL2(vectors.shape[1])
    built.add(vectors)
    faiss.write_index(built, index_path)
    return built


# Reuse the persisted index when present; otherwise build it once and save it.
# Encoding only happens on the build path, since it is the expensive step.
if os.path.exists(INDEX_FILE_PATH):
    index = faiss.read_index(INDEX_FILE_PATH)
    print("FAISS index loaded from file.")
else:
    index = _build_index(df, INDEX_FILE_PATH)
    print("FAISS index created and saved to file.")

# A stale index (e.g. the JSON changed after the index was saved) would map hit
# positions to the wrong rows, so fail loudly instead of serving wrong results.
if index.ntotal != len(df):
    raise ValueError(
        f"FAISS index {INDEX_FILE_PATH!r} holds {index.ntotal} vectors but the "
        f"filtered DataFrame has {len(df)} rows; delete the index file to rebuild it."
    )
def search_query(query_text, num_records=50):
    """Return the rows of ``df`` whose embeddings are nearest to ``query_text``.

    The query is encoded with the module-level ``model``, searched against the
    module-level FAISS ``index``, and each hit position is mapped back to
    ``df`` by position (``iloc``).

    Args:
        query_text: Free-text search query. Blank or ``None`` yields no results.
        num_records: Maximum number of nearest neighbours to retrieve before
            de-duplication (default 50, the previously hard-coded value).

    Returns:
        DataFrame of the matching rows in the order FAISS returned them
        (nearest first), without the 'embeding_context' column, with duplicate
        rows removed and a fresh 0..n-1 index. Has the same columns but no rows
        when there is nothing to return.
    """
    hit_positions = np.empty(0, dtype=np.int64)
    # FAISS cannot return more neighbours than the index holds; asking for more
    # pads the result with -1 ids, which iloc would silently treat as "last row".
    k = min(num_records, index.ntotal)
    # Skip blank input: there is no text to match, so any "nearest" rows would
    # be unrelated to the user's query.
    if query_text and query_text.strip() and k > 0:
        embeddings_query = model.encode([query_text], batch_size=12, max_length=1024)['dense_vecs']
        embeddings_query_np = np.asarray(embeddings_query, dtype='float32')
        _distances, indices = index.search(embeddings_query_np, k)
        # Defensive: drop any -1 padding ids in case the index returns them anyway.
        hit_positions = indices[0][indices[0] >= 0]
    return (
        df.iloc[hit_positions]
        .drop(columns=['embeding_context'])
        .drop_duplicates()
        .reset_index(drop=True)
    )
def gradio_interface(query_text):
    """Gradio event handler: run the search for ``query_text`` and hand the
    resulting DataFrame straight to the output table."""
    return search_query(query_text)
# Single-page search UI: a query box, a Search button and a results table.
with gr.Blocks() as app:
    gr.Markdown("<h1>White Stride Red Search (BGE-M3)</h1>")
    # Input text box for the search query.
    search_input = gr.Textbox(label="Search Query", placeholder="Enter search text", interactive=True)
    # Search button below the text box.
    search_button = gr.Button("Search")
    # Output table for displaying results.
    search_output = gr.DataFrame(label="Search Results")
    # Run the same search whether the user clicks the button or presses Enter
    # in the text box.
    search_button.click(fn=gradio_interface, inputs=search_input, outputs=search_output)
    search_input.submit(fn=gradio_interface, inputs=search_input, outputs=search_output)

# Start the web server only when this file is executed as a script, so merely
# importing the module does not launch the app.
if __name__ == "__main__":
    app.launch()