"""Gradio demo that walks a PDF through a RAG pipeline.

Steps shown in the UI: read a PDF, split it into word chunks, embed the
chunks, index them with FAISS, retrieve the nearest chunks for a question,
and display them (table, 3D PCA plot, and a stitched-together "answer").
"""

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import faiss
import PyPDF2
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA

# Created once at import time and shared by every callback below.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


def read_pdf(file):
    """Return the concatenated text of every page of a PDF.

    Args:
        file: Value handed to ``PyPDF2.PdfReader`` (here: whatever the
            ``gr.File`` component provides), or ``None`` if nothing was
            uploaded.

    Returns:
        The page texts joined with newlines and stripped of surrounding
        whitespace; an empty string when ``file`` is ``None``.
    """
    if file is None:
        return ""
    reader = PyPDF2.PdfReader(file)
    # ``or ""`` keeps a page that yields no text from breaking the join.
    return "\n".join(page.extract_text() or "" for page in reader.pages).strip()


def chunk_text(text, chunk_size):
    """Split ``text`` into consecutive chunks of ``chunk_size`` words.

    Args:
        text: Text to split on whitespace. ``None``/empty gives ``[]``.
        chunk_size: Words per chunk; coerced to ``int`` because UI sliders
            may deliver a float.

    Returns:
        List of chunk strings; the last chunk may be shorter.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if not text:
        return []
    size = int(chunk_size)
    if size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    words = text.split()
    return [" ".join(words[i:i + size]) for i in range(0, len(words), size)]


def embed_chunks(chunks):
    """Encode ``chunks`` with the module-level model.

    Returns:
        The result of ``model.encode(chunks)``, or ``None`` when there are
        no chunks to encode.
    """
    if not chunks:
        return None
    return model.encode(chunks)


def build_faiss_index(vectors):
    """Build an exact L2 FAISS index over ``vectors``.

    Args:
        vectors: 2-D array-like of embeddings, one row per chunk.

    Returns:
        A ``faiss.IndexFlatL2`` containing every row of ``vectors``.

    Raises:
        ValueError: If ``vectors`` is ``None`` or empty.
    """
    if vectors is None or len(vectors) == 0:
        raise ValueError("Cannot build a FAISS index without vectors")
    # FAISS expects C-contiguous float32 input.
    vectors = np.ascontiguousarray(vectors, dtype=np.float32)
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index


def search_faiss(index, query_vec, k):
    """Return the ``k`` nearest neighbours of ``query_vec``.

    Args:
        index: FAISS index to search.
        query_vec: 2-D array holding the query embedding in its first row.
        k: Requested number of neighbours; coerced to ``int`` and clamped
            to the number of indexed vectors.

    Returns:
        ``(indices, distances)`` for the first query row, nearest first.
        Padding entries (index ``-1``) are removed, so fewer than ``k``
        results may be returned.
    """
    k = min(int(k), index.ntotal)
    if k <= 0:
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.float32)
    query_vec = np.ascontiguousarray(query_vec, dtype=np.float32)
    distances, indices = index.search(query_vec, k)
    # FAISS marks "no result" slots with -1; indexing chunks[-1] would
    # silently return the last chunk, so drop those slots here.
    found = indices[0] >= 0
    return indices[0][found], distances[0][found]


def _project_to_3d(points):
    """Reduce ``points`` to three PCA axes, zero-padding missing axes."""
    # PCA cannot produce more components than min(n_samples, n_features).
    n_components = min(3, points.shape[0], points.shape[1])
    reduced = PCA(n_components=n_components).fit_transform(points)
    if n_components < 3:
        padding = np.zeros((reduced.shape[0], 3 - n_components))
        reduced = np.hstack([reduced, padding])
    return reduced


def visualize_3d(vectors, query_vec, highlight_idx=None):
    """Plot chunk embeddings and the query in a shared 3D PCA space.

    Args:
        vectors: Chunk embeddings, one row per chunk.
        query_vec: Query embedding; stacked below ``vectors`` and treated
            as the final row.
        highlight_idx: Optional iterable of chunk indices to draw as
            "nearest" markers; out-of-range values are ignored.

    Returns:
        A ``plotly.graph_objects.Figure``; empty when either embedding
        input is ``None``.
    """
    if vectors is None or query_vec is None:
        return go.Figure()

    reduced = _project_to_3d(np.vstack([vectors, query_vec]))
    n_chunks = reduced.shape[0] - 1  # last row is the query

    requested = [] if highlight_idx is None else [int(i) for i in highlight_idx]
    valid_highlights = [i for i in requested if 0 <= i < n_chunks]

    # Boolean mask of the chunks that are NOT highlighted.
    is_normal = np.ones(n_chunks, dtype=bool)
    is_normal[valid_highlights] = False

    fig = go.Figure()

    # normal chunk points (exclude highlighted)
    if is_normal.any():
        normal_points = reduced[:n_chunks][is_normal]
        fig.add_trace(go.Scatter3d(
            x=normal_points[:, 0],
            y=normal_points[:, 1],
            z=normal_points[:, 2],
            mode="markers",
            marker=dict(size=4, color="skyblue"),
            name="Chunks",
        ))

    # highlighted chunk points, labelled in the order they were given
    if valid_highlights:
        fig.add_trace(go.Scatter3d(
            x=reduced[valid_highlights, 0],
            y=reduced[valid_highlights, 1],
            z=reduced[valid_highlights, 2],
            mode="markers+text",
            text=[f"Top-{j+1}" for j in range(len(valid_highlights))],
            marker=dict(size=8, color="limegreen", symbol="diamond"),
            name="Nearest Chunks",
        ))

    # query point
    fig.add_trace(go.Scatter3d(
        x=[reduced[-1, 0]],
        y=[reduced[-1, 1]],
        z=[reduced[-1, 2]],
        mode="markers+text",
        text=["Query"],
        marker=dict(size=10, color="red"),
        name="Query",
    ))

    fig.update_layout(
        title="📐 RAG Vector Space (3D)",
        height=500,
    )
    return fig


# Label shown on the "next" button while the tab with that id is selected.
TAB_LABELS = {
    0: "Next ➡ Chunking",
    1: "Next ➡ Embeddings",
    2: "Next ➡ FAISS Search",
    3: "Next ➡ Answer",
    4: "Done ✅",
}

_FIRST_TAB = min(TAB_LABELS)
_LAST_TAB = max(TAB_LABELS)


def go_next(tab):
    """Return the id of the following tab, stopping at the last one."""
    return min(tab + 1, _LAST_TAB)


def go_prev(tab):
    """Return the id of the preceding tab, stopping at the first one."""
    return max(tab - 1, _FIRST_TAB)


def update_next_label(tab):
    """Return an update that sets the next-button text for ``tab``."""
    return gr.update(value=TAB_LABELS[tab])


def update_prev_visibility(tab):
    """Return an update that hides the back button on the first tab."""
    return gr.update(visible=tab > 0)


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📄➡️🧠 RAG Visualizer (PDF + FAISS)")
    gr.Markdown("### See how PDFs turn into vectors and answers")

    # Per-session state passed between the pipeline steps.
    current_tab = gr.State(0)
    chunks_state = gr.State([])
    embeddings_state = gr.State(None)
    faiss_index_state = gr.State(None)
    query_vec_state = gr.State(None)

    tabs = gr.Tabs(selected=0)

    with tabs:
        # ---------------- TAB 0 ----------------
        with gr.Tab("📄 Upload PDF", id=0):
            pdf_file = gr.File(file_types=[".pdf"])
            pdf_text = gr.Textbox(lines=8, label="Extracted Text")

            def load_pdf(file):
                """Callback: extract text from the uploaded PDF."""
                return read_pdf(file)

            gr.Button("📖 Read PDF").click(
                load_pdf,
                inputs=pdf_file,
                outputs=pdf_text,
            )

        # ---------------- TAB 1 ----------------
        with gr.Tab("✂️ Chunking", id=1):
            chunk_size = gr.Slider(50, 200, 100, step=10)
            chunk_table = gr.Dataframe()

            def run_chunking(text, size):
                """Callback: chunk ``text`` and return (table, chunk list)."""
                chunks = chunk_text(text, size)
                df = pd.DataFrame({
                    "Chunk ID": range(len(chunks)),
                    "Text": chunks,
                })
                return df, chunks

            gr.Button("✂️ Create Chunks").click(
                run_chunking,
                inputs=[pdf_text, chunk_size],
                outputs=[chunk_table, chunks_state],
            )

        # ---------------- TAB 2 ----------------
        with gr.Tab("🧠 Embeddings + FAISS", id=2):
            embed_info = gr.Markdown()

            def build_embeddings(chunks):
                """Callback: embed ``chunks`` and index them.

                Returns (markdown summary, embeddings, FAISS index); the
                last two are ``None`` when there is nothing to embed.
                """
                vecs = embed_chunks(chunks)
                if vecs is None:
                    return (
                        "### ⚠️ No chunks to embed\n"
                        "Create chunks in the previous tab first.",
                        None,
                        None,
                    )
                index = build_faiss_index(vecs)
                # Built line by line so the Markdown is not indented
                # (indented text would render as a code block).
                summary = (
                    "\n### ✅ FAISS Index Ready\n"
                    f"- Chunks: **{len(chunks)}**\n"
                    f"- Vector Dim: **{vecs.shape[1]}**\n"
                )
                return summary, vecs, index

            gr.Button("🧠 Build FAISS Index").click(
                build_embeddings,
                inputs=chunks_state,
                outputs=[embed_info, embeddings_state, faiss_index_state],
            )

        # ---------------- TAB 3 ----------------
        with gr.Tab("🔍 Retrieval + 3D View", id=3):
            query = gr.Textbox(label="Ask a question")
            k = gr.Slider(1, 10, 3, step=1)
            results = gr.Dataframe()
            plot = gr.Plot()

            def retrieve(query, chunks, vectors, index, k):
                """Callback: find the ``k`` chunks nearest to ``query``.

                Returns (results table, 3D figure).

                Raises:
                    gr.Error: If the query is empty or no index was built.
                """
                if not query or not query.strip():
                    raise gr.Error("Please enter a question.")
                if vectors is None or index is None:
                    raise gr.Error("Please build the FAISS index first.")

                q_vec = model.encode([query])
                idx, dist = search_faiss(index, q_vec, k)

                # Chunks may have been regenerated after the index was
                # built; ignore hits that no longer map to a chunk.
                in_range = idx < len(chunks)
                idx, dist = idx[in_range], dist[in_range]

                df = pd.DataFrame({
                    "Chunk ID": idx,
                    "Distance": dist,
                    # Results are ordered nearest-first, so only row 0
                    # is the nearest chunk.
                    "Nearest": [rank == 0 for rank in range(len(idx))],
                    "Text": [chunks[i][:200] + "..." for i in idx],
                })

                fig = visualize_3d(vectors, q_vec, highlight_idx=idx)
                return df, fig

            gr.Button("🔍 Search").click(
                retrieve,
                inputs=[query, chunks_state, embeddings_state, faiss_index_state, k],
                outputs=[results, plot],
            )

        # ---------------- TAB 4 ----------------
        with gr.Tab("📝 Answer Generation", id=4):
            answer_query = gr.Textbox(label="Ask a question for answer")
            answer_k = gr.Slider(1, 5, 2, step=1, label="Number of chunks to use")
            answer_output = gr.Textbox(label="Generated Answer", lines=6)

            def generate_answer(query, chunks, vectors, index, k):
                """Callback: stitch the ``k`` nearest chunks into an answer.

                No language model is involved: the "answer" is the
                retrieved passages followed by a fixed summary sentence.
                Returns a user-facing message instead when the query is
                empty, the index is missing, or nothing was retrieved.
                """
                if not query or not query.strip():
                    return "Please enter a question."
                if vectors is None or index is None:
                    return "Please build the FAISS index first."

                q_vec = model.encode([query])
                idx, _ = search_faiss(index, q_vec, k)

                # Get the most relevant chunks (skipping stale indices).
                relevant_chunks = [chunks[i] for i in idx if i < len(chunks)]
                if not relevant_chunks:
                    return "No relevant passages were found."

                # Create a simple answer by combining the relevant chunks.
                parts = ["Based on the most relevant information found:\n\n"]
                parts.extend(
                    f"**Relevant passage {number}:**\n{chunk}\n\n"
                    for number, chunk in enumerate(relevant_chunks, start=1)
                )
                parts.append(
                    f"**Summary:** The above {len(relevant_chunks)} passage(s) contain the most relevant information to answer your question about: '{query}'"
                )
                return "".join(parts)

            gr.Button("📝 Generate Answer").click(
                generate_answer,
                inputs=[answer_query, chunks_state, embeddings_state, faiss_index_state, answer_k],
                outputs=answer_output,
            )

    with gr.Row():
        nav_prev = gr.Button("⬅ Back", visible=False)
        nav_next = gr.Button("Next ➡ Chunking")

    def _wire_navigation(button, step_fn):
        """Attach the shared click chain to a navigation ``button``.

        The chain updates ``current_tab`` with ``step_fn``, selects that
        tab, then refreshes the next-button label and the back-button
        visibility.
        """
        button.click(
            fn=step_fn,
            inputs=current_tab,
            outputs=current_tab,
        ).then(
            fn=lambda tab: gr.update(selected=tab),
            inputs=current_tab,
            outputs=tabs,
        ).then(
            fn=update_next_label,
            inputs=current_tab,
            outputs=nav_next,
        ).then(
            fn=update_prev_visibility,
            inputs=current_tab,
            outputs=nav_prev,
        )

    _wire_navigation(nav_next, go_next)
    _wire_navigation(nav_prev, go_prev)


if __name__ == "__main__":
    demo.launch(debug=True)