Spaces:
Build error
Build error
| # search_content.py | |
import functools

import faiss
import pandas as pd
from sentence_transformers import SentenceTransformer
# Locations of the three artifacts main_search loads before answering a query.
# NOTE(review): the SentenceTransformer loader expects a saved-model directory
# or a model name; confirm this ".pkl" path is something it can actually read.
MODEL_SAVE_PATH = "all-distilroberta-v1-model.pkl"  # sentence-transformer model
FAISS_INDEX_FILE_PATH = "index.faiss"  # prebuilt Faiss vector index
DATA_FILE_PATH = "omdena_qna_dataset/omdena_faq_training_data.csv"  # FAQ rows
| def load_transformer_model(model_file): | |
| """Load a sentence transformer model from a file.""" | |
| return SentenceTransformer.load(model_file) | |
def load_faiss_index(filename):
    """Read the Faiss index stored at *filename* and hand it back."""
    index = faiss.read_index(filename)
    return index
def load_data(file_path):
    """Load the FAQ CSV and add the columns the search step relies on.

    Args:
        file_path: Path (or file-like object) accepted by ``pandas.read_csv``.
            The CSV must contain ``Questions`` and ``Answers`` columns.

    Returns:
        A DataFrame indexed by ``id`` with these additions:

        * ``id``: the original 0-based row position, kept as a regular column
          as well as the index.
        * ``QNA``: strings of the form ``"Question: <q>, Answer: <a>"``.
    """
    data_frame = pd.read_csv(file_path)
    # Row positions serve as the ids that search_content looks up after a
    # Faiss search (assumes the index was built in this same row order).
    data_frame["id"] = data_frame.index
    # Vectorized concatenation produces the same strings as the former
    # row-wise ``apply(axis=1)`` (NaN still renders as "nan") without a
    # Python-level call per row.
    data_frame["QNA"] = (
        "Question: "
        + data_frame["Questions"].astype(str)
        + ", Answer: "
        + data_frame["Answers"].astype(str)
    )
    return data_frame.set_index(["id"], drop=False)
def search_content(query, data_frame_indexed, transformer_model, faiss_index, k=5):
    """Return the rows of *data_frame_indexed* that Faiss ranks closest to *query*.

    Args:
        query: Text to search for.
        data_frame_indexed: DataFrame whose index labels are the ids stored in
            *faiss_index* (as produced by ``load_data``).
        transformer_model: Object with an ``encode(list_of_texts)`` method.
            Its output is passed to ``faiss.normalize_L2``, so it presumably
            must be a 2-D float32 array -- confirm for custom encoders.
        faiss_index: Faiss index to query.
        k: Maximum number of results to return.

    Returns:
        A copy of the matching rows in Faiss rank order, with an added
        ``similarities`` column holding the score Faiss reported for each row.
        Whether that score is a similarity or a distance depends on the
        index's metric, which is not visible here. Fewer than *k* rows are
        returned when the index holds fewer than *k* vectors.
    """
    # Encode the query as a batch of one, then L2-normalize it in place.
    query_vector = transformer_model.encode([query])
    faiss.normalize_L2(query_vector)

    # search() returns (scores, ids), each with one row per query.
    scores, ids = faiss_index.search(query_vector, k)

    # Faiss pads unfilled result slots with id -1. Drop them so the label
    # lookup below does not raise KeyError.
    hits = [
        (int(row_id), float(score))
        for row_id, score in zip(ids[0], scores[0])
        if row_id != -1
    ]
    row_ids = [row_id for row_id, _ in hits]
    similarities = [score for _, score in hits]

    # Copy so that adding a column works on an independent frame and does not
    # raise pandas' SettingWithCopyWarning.
    results = data_frame_indexed.loc[row_ids].copy()
    results["similarities"] = similarities
    return results
@functools.lru_cache(maxsize=1)
def _load_search_resources():
    """Load the model, Faiss index and FAQ data once and reuse them afterwards."""
    return (
        load_transformer_model(MODEL_SAVE_PATH),
        load_faiss_index(FAISS_INDEX_FILE_PATH),
        load_data(DATA_FILE_PATH),
    )


def main_search(query):
    """Search the FAQ data for *query* and return the matching answers.

    The model, index and data are loaded on the first call and cached for the
    rest of the process, so later queries skip the expensive reloads. A failed
    load is not cached and will be retried on the next call.

    Args:
        query: Text to search for.

    Returns:
        The ``Answers`` column of the top results returned by
        ``search_content`` (a pandas Series).
    """
    transformer_model, faiss_index, data_frame_indexed = _load_search_resources()
    results = search_content(query, data_frame_indexed, transformer_model, faiss_index)
    return results["Answers"]
if __name__ == "__main__":
    # Script entry point: run a single sample query and print its answers.
    query = "school courses"
    answers = main_search(query)
    print(answers)