import gradio as gr
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics import ndcg_score

# ----------------------------
# Load dataset (MS MARCO v1.1)
# ----------------------------
# NOTE: each MS MARCO v1.1 record stores its candidate passages as a list
# under item["passages"]["passage_text"]; there is no flat item["passage"]
# field (the original access raised KeyError). Flatten all candidates.
dataset = load_dataset("ms_marco", "v1.1", split="train[:10000]")
passages = [
    text
    for item in dataset
    for text in item["passages"]["passage_text"]
]
print(f"Loaded {len(passages)} passages")

# ----------------------------
# Load SBERT model
# ----------------------------
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# ----------------------------
# Build FAISS index
# ----------------------------
embeddings = model.encode(passages, convert_to_numpy=True, show_progress_bar=True)
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
print("FAISS index built with", index.ntotal, "passages")


# ----------------------------
# Search
# ----------------------------
def _search_indices(query, k=10):
    """Return (corpus indices, L2 distances) of the top-k passages for *query*.

    Internal helper shared by `search` and `evaluate` so that evaluation can
    work with exact corpus positions instead of re-locating passages by text.
    """
    query_vec = model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_vec, k)
    return indices[0], distances[0]


def search(query, k=10):
    """Return the top-k passages for *query* as (passage_text, distance) pairs.

    Distance is squared L2 between SBERT embeddings; smaller = more similar.
    """
    indices, distances = _search_indices(query, k)
    return [(passages[i], float(dist)) for i, dist in zip(indices, distances)]


# ----------------------------
# Evaluation metrics
# ----------------------------
def evaluate(query, relevant_passages, k=10):
    """Compute IR metrics for a query given a list of relevant passages (ground truth).

    Returns a dict with Precision@k, Recall@k, F1, MRR and nDCG@k, each
    rounded to 3 decimals. Relevance is binary exact-text match against
    *relevant_passages*.
    """
    indices, distances = _search_indices(query, k)
    relevant = set(relevant_passages)  # O(1) membership tests

    # Binary relevance of each retrieved passage, in rank order.
    # Using FAISS corpus indices directly fixes the original
    # passages.index(text) lookup, which was O(n) per result and returned
    # the wrong position whenever the corpus contains duplicate texts.
    y_true = [1 if passages[i] in relevant else 0 for i in indices]

    # Full-corpus relevance / score vectors for nDCG. Retrieved passages
    # score 1 / (1 + distance): always positive and strictly above the 0
    # given to non-retrieved passages. (The original 1 - distance went
    # negative for L2 distances > 1, ranking retrieved passages *below*
    # non-retrieved ones and corrupting nDCG.)
    y_true_full = np.array([[1 if p in relevant else 0 for p in passages]])
    y_scores_full = np.zeros((1, len(passages)))
    for i, dist in zip(indices, distances):
        y_scores_full[0, i] = 1.0 / (1.0 + float(dist))

    hits = sum(y_true)
    precision = hits / k
    recall = hits / len(relevant_passages) if relevant_passages else 0
    f1 = (
        (2 * precision * recall) / (precision + recall)
        if (precision + recall) > 0
        else 0
    )
    # MRR: reciprocal rank of the first relevant hit, 0 if none retrieved.
    mrr = 1.0 / (y_true.index(1) + 1) if 1 in y_true else 0
    ndcg = ndcg_score(y_true_full, y_scores_full, k=k)

    return {
        "Precision@10": round(precision, 3),
        "Recall@10": round(recall, 3),
        "F1": round(f1, 3),
        "MRR": round(mrr, 3),
        "nDCG@10": round(ndcg, 3),
    }


# ----------------------------
# Gradio interface
# ----------------------------
def gradio_interface(query, relevant_texts):
    """Gradio callback: run the search and, when ground truth is supplied
    (one passage per line), also compute the IR metrics."""
    results = search(query, k=10)
    metrics = {}
    if relevant_texts.strip():
        relevant_passages = [t.strip() for t in relevant_texts.split("\n") if t.strip()]
        metrics = evaluate(query, relevant_passages, k=10)
    return results, metrics


demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Enter your query"),
        gr.Textbox(
            label="Enter relevant passages (ground truth, one per line)",
            placeholder="Optional",
        ),
    ],
    outputs=[
        gr.Dataframe(headers=["Passage", "Distance"], label="Top-10 Results"),
        gr.Label(label="Evaluation Metrics"),
    ],
    title="SBERT + FAISS Semantic Search",
    description=(
        "Enter a query to search MS MARCO passages. Optionally provide "
        "ground truth passages to compute IR metrics."
    ),
)

if __name__ == "__main__":
    demo.launch()