# RagViz / app.py
# Pranesh64's picture
# Update app.py
# 0b16e2b verified
import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import faiss
import PyPDF2
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA
# Checkpoint name passed to SentenceTransformer for the embedding model.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Created once at import time; the functions in this file call model.encode().
model = SentenceTransformer(EMBEDDING_MODEL_NAME)
def read_pdf(file):
    """Extract the text of every page of a PDF as a single string.

    Args:
        file: Passed unchanged to ``PyPDF2.PdfReader``.

    Returns:
        The text of all pages in page order, each page followed by a
        newline, with leading and trailing whitespace stripped from the
        combined result. A page without extractable text contributes only
        its newline.
    """
    reader = PyPDF2.PdfReader(file)
    # Defensive: treat a page whose extraction yields None/"" as empty text
    # so one unreadable page cannot abort the whole document with TypeError.
    page_texts = [(page.extract_text() or "") + "\n" for page in reader.pages]
    # Single join instead of repeated ``+=`` on a growing string.
    return "".join(page_texts).strip()
def chunk_text(text, chunk_size):
    """Split text into consecutive chunks of at most ``chunk_size`` words.

    Args:
        text: Source text; words are separated on any whitespace. ``None``
            or an empty string produces no chunks.
        chunk_size: Maximum number of words per chunk. Non-integer numbers
            (for example a float coming from a UI slider) are truncated to
            ``int``.

    Returns:
        A list of strings, each the single-space-joined words of one chunk.
        The final chunk may contain fewer than ``chunk_size`` words.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    # ``range`` rejects float steps, so normalise before using it.
    size = int(chunk_size)
    if size < 1:
        raise ValueError(f"chunk_size must be at least 1, got {chunk_size!r}")
    if not text:
        return []
    words = text.split()
    return [
        " ".join(words[start:start + size])
        for start in range(0, len(words), size)
    ]
def embed_chunks(chunks):
    """Encode text chunks with the module-level embedding model.

    Args:
        chunks: Sequence of chunk strings, or ``None``.

    Returns:
        Whatever ``model.encode(chunks)`` returns, or ``None`` when there is
        nothing to encode (``chunks`` is ``None`` or empty).
    """
    # Truthiness covers both None and an empty list; len(None) would raise.
    if not chunks:
        return None
    return model.encode(chunks)
def build_faiss_index(vectors):
    """Build a flat L2 FAISS index containing the given vectors.

    Args:
        vectors: 2-D array-like with one vector per row.

    Returns:
        A ``faiss.IndexFlatL2`` whose dimension is the row length of
        ``vectors`` and to which every row has been added.

    Raises:
        ValueError: If ``vectors`` is ``None`` or is not two-dimensional.
    """
    if vectors is None:
        raise ValueError("No vectors to index: create chunks and embeddings first.")
    # The FAISS Python bindings expect C-contiguous float32 input.
    matrix = np.ascontiguousarray(vectors, dtype=np.float32)
    if matrix.ndim != 2:
        raise ValueError(f"Expected a 2-D array of vectors, got shape {matrix.shape}")
    index = faiss.IndexFlatL2(matrix.shape[1])
    index.add(matrix)
    return index
def search_faiss(index, query_vec, k):
    """Return the nearest neighbours of a single query vector.

    Args:
        index: Object exposing ``ntotal`` and ``search(queries, k)`` the way
            a FAISS index does.
        query_vec: The query embedding, either 1-D or shaped ``(1, dim)``.
        k: Requested number of neighbours; truncated to ``int`` and capped at
            the number of vectors stored in ``index``.

    Returns:
        ``(indices, distances)``: two 1-D arrays for the first query row, in
        the order produced by ``index.search``. Slots FAISS could not fill
        (label ``-1``) are removed, so both arrays may be shorter than ``k``.
    """
    # Sliders may deliver floats, and asking for more neighbours than exist
    # makes FAISS pad the result with -1 labels.
    top_k = min(int(k), int(index.ntotal))
    if top_k < 1:
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.float32)

    queries = np.ascontiguousarray(query_vec, dtype=np.float32)
    if queries.ndim == 1:
        queries = queries.reshape(1, -1)

    distances, indices = index.search(queries, top_k)
    # A -1 label used as a Python list index would silently pick the last
    # chunk, so drop such entries here.
    found = indices[0] >= 0
    return indices[0][found], distances[0][found]
def visualize_3d(vectors, query_vec, highlight_idx=None):
    """Plot chunk vectors and the query in a 3-D PCA projection.

    Args:
        vectors: 2-D array of chunk vectors, one per row, or ``None``.
        query_vec: Query vector stacked below ``vectors``; the last row of
            the stacked matrix is drawn as the query, so this is expected to
            contribute exactly one row. May be ``None``.
        highlight_idx: Optional iterable of chunk row indices to emphasise,
            ordered by rank. Indices outside ``[0, n_chunks)`` are ignored.

    Returns:
        A ``plotly.graph_objects.Figure`` with up to three traces: ordinary
        chunks, highlighted chunks labelled ``Top-<rank>``, and the query.
        An empty figure is returned when ``vectors`` or ``query_vec`` is
        ``None``.
    """
    if vectors is None or query_vec is None:
        return go.Figure()

    all_vecs = np.vstack([vectors, query_vec])

    # PCA cannot produce more components than min(n_samples, n_features);
    # with a single chunk plus the query there are only two samples, so the
    # component count is clamped and missing axes are zero-filled.
    n_components = min(3, *all_vecs.shape)
    reduced = PCA(n_components=n_components).fit_transform(all_vecs)
    if n_components < 3:
        reduced = np.pad(reduced, ((0, 0), (0, 3 - n_components)))

    n_chunks = reduced.shape[0] - 1
    ranked = [] if highlight_idx is None else [int(i) for i in highlight_idx]
    # Keep each hit's original rank so labels stay correct even when an
    # invalid index is dropped.
    highlights = [
        (rank, i) for rank, i in enumerate(ranked, start=1) if 0 <= i < n_chunks
    ]
    highlighted = {i for _, i in highlights}  # set: O(1) membership tests

    fig = go.Figure()

    # Ordinary chunk points (everything that is not highlighted).
    normal_rows = [i for i in range(n_chunks) if i not in highlighted]
    if normal_rows:
        fig.add_trace(go.Scatter3d(
            x=reduced[normal_rows, 0],
            y=reduced[normal_rows, 1],
            z=reduced[normal_rows, 2],
            mode="markers",
            marker=dict(size=4, color="skyblue"),
            name="Chunks",
        ))

    # Highlighted chunk points.
    if highlights:
        highlight_rows = [i for _, i in highlights]
        fig.add_trace(go.Scatter3d(
            x=reduced[highlight_rows, 0],
            y=reduced[highlight_rows, 1],
            z=reduced[highlight_rows, 2],
            mode="markers+text",
            text=[f"Top-{rank}" for rank, _ in highlights],
            marker=dict(size=8, color="limegreen", symbol="diamond"),
            name="Nearest Chunks",
        ))

    # Query point: the last row of the stacked matrix.
    fig.add_trace(go.Scatter3d(
        x=[reduced[-1, 0]],
        y=[reduced[-1, 1]],
        z=[reduced[-1, 2]],
        mode="markers+text",
        text=["Query"],
        marker=dict(size=10, color="red"),
        name="Query",
    ))

    fig.update_layout(
        title="📐 RAG Vector Space (3D)",
        height=500,
    )
    return fig
# Caption for the "next" navigation button, keyed by tab index (0-4).
TAB_LABELS = dict(
    enumerate(
        (
            "Next ➡ Chunking",
            "Next ➡ Embeddings",
            "Next ➡ FAISS Search",
            "Next ➡ Answer",
            "Done ✅",
        )
    )
)
def go_next(tab):
    """Return the index of the following tab, never going past tab 4."""
    following = tab + 1
    if following > 4:
        return 4
    return following
def go_prev(tab):
    """Return the index of the preceding tab, never going below tab 0."""
    preceding = tab - 1
    if preceding < 0:
        return 0
    return preceding
def update_next_label(tab):
    """Return a Gradio update setting a component's value to the label for ``tab``.

    Raises ``KeyError`` if ``tab`` is not a key of ``TAB_LABELS``.
    """
    label = TAB_LABELS[tab]
    return gr.update(value=label)
def update_prev_visibility(tab):
    """Return a Gradio update that makes a component visible only when ``tab`` is above 0."""
    past_first_tab = 0 < tab
    return gr.update(visible=past_first_tab)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Five-tab walkthrough: each tab runs one stage and passes its result to
    # later stages through the gr.State holders declared below.
    gr.Markdown("# 📄➡️🧠 RAG Visualizer (PDF + FAISS)")
    gr.Markdown("### See how PDFs turn into vectors and answers")

    current_tab = gr.State(0)           # id of the tab currently shown
    chunks_state = gr.State([])         # chunk strings from the chunking tab
    embeddings_state = gr.State(None)   # chunk embeddings; None until built
    faiss_index_state = gr.State(None)  # FAISS index; None until built
    query_vec_state = gr.State(None)    # declared but not wired to any event in this block

    tabs = gr.Tabs(selected=0)
    with tabs:
        # ---------------- TAB 0 ----------------
        with gr.Tab("📄 Upload PDF", id=0) as upload_tab:
            pdf_file = gr.File(file_types=[".pdf"])
            pdf_text = gr.Textbox(lines=8, label="Extracted Text")

            def load_pdf(file):
                """Return the extracted text of the uploaded PDF."""
                # The button can be pressed before anything is uploaded.
                if file is None:
                    raise gr.Error("Please upload a PDF file first.")
                return read_pdf(file)

            gr.Button("📖 Read PDF").click(
                load_pdf,
                inputs=pdf_file,
                outputs=pdf_text,
            )

        # ---------------- TAB 1 ----------------
        with gr.Tab("✂️ Chunking", id=1) as chunking_tab:
            chunk_size = gr.Slider(50, 200, 100, step=10)
            chunk_table = gr.Dataframe()

            def run_chunking(text, size):
                """Chunk ``text`` and return (table of chunks, list of chunks)."""
                # int(): the slider value may arrive as a float, which
                # range() inside chunk_text would reject.
                chunks = chunk_text(text or "", int(size))
                df = pd.DataFrame({
                    "Chunk ID": range(len(chunks)),
                    "Text": chunks,
                })
                return df, chunks

            gr.Button("✂️ Create Chunks").click(
                run_chunking,
                inputs=[pdf_text, chunk_size],
                outputs=[chunk_table, chunks_state],
            )

        # ---------------- TAB 2 ----------------
        with gr.Tab("🧠 Embeddings + FAISS", id=2) as index_tab:
            embed_info = gr.Markdown()

            def build_embeddings(chunks):
                """Embed the chunks and index them.

                Returns (status markdown, embeddings, index); the last two
                are None when there are no chunks to embed.
                """
                vecs = embed_chunks(chunks) if chunks else None
                if vecs is None:
                    return "### ⚠️ No chunks found. Create chunks first.", None, None
                index = build_faiss_index(vecs)
                status = (
                    "\n### ✅ FAISS Index Ready\n"
                    f"- Chunks: **{len(chunks)}**\n"
                    f"- Vector Dim: **{vecs.shape[1]}**\n"
                )
                return status, vecs, index

            gr.Button("🧠 Build FAISS Index").click(
                build_embeddings,
                inputs=chunks_state,
                outputs=[embed_info, embeddings_state, faiss_index_state],
            )

        # ---------------- TAB 3 ----------------
        with gr.Tab("🔍 Retrieval + 3D View", id=3) as retrieval_tab:
            query = gr.Textbox(label="Ask a question")
            k = gr.Slider(1, 10, 3, step=1)
            results = gr.Dataframe()
            plot = gr.Plot()

            def retrieve(query, chunks, vectors, index, k):
                """Search the index and return (results table, 3-D figure)."""
                if index is None or vectors is None or not chunks:
                    raise gr.Error("Please build the FAISS index first.")
                if not query or not query.strip():
                    raise gr.Error("Please enter a question.")

                q_vec = model.encode([query])
                # Never request more neighbours than there are chunks: FAISS
                # fills missing slots with label -1, and chunks[-1] would
                # silently return the last chunk.
                top_k = min(int(k), len(chunks))
                idx, dist = search_faiss(index, q_vec, top_k)
                hits = [
                    (int(i), float(d))
                    for i, d in zip(idx, dist)
                    if 0 <= int(i) < len(chunks)
                ]
                hit_ids = [i for i, _ in hits]

                df = pd.DataFrame({
                    "Chunk ID": hit_ids,
                    "Distance": [d for _, d in hits],
                    # Results come back nearest-first, so the first id marks
                    # the nearest chunk.
                    "Nearest": [i == hit_ids[0] for i in hit_ids],
                    # Ellipsis only when the preview was actually cut.
                    "Text": [
                        chunks[i] if len(chunks[i]) <= 200 else chunks[i][:200] + "..."
                        for i in hit_ids
                    ],
                })
                fig = visualize_3d(vectors, q_vec, highlight_idx=hit_ids)
                return df, fig

            gr.Button("🔍 Search").click(
                retrieve,
                inputs=[query, chunks_state, embeddings_state, faiss_index_state, k],
                outputs=[results, plot],
            )

        # ---------------- TAB 4 ----------------
        with gr.Tab("📝 Answer Generation", id=4) as answer_tab:
            answer_query = gr.Textbox(label="Ask a question for answer")
            answer_k = gr.Slider(1, 5, 2, step=1, label="Number of chunks to use")
            answer_output = gr.Textbox(label="Generated Answer", lines=6)

            def generate_answer(query, chunks, vectors, index, k):
                """Return an answer text assembled from the top-k retrieved chunks."""
                if not query or not query.strip():
                    return "Please enter a question."
                if vectors is None or index is None or not chunks:
                    return "Please build the FAISS index first."

                q_vec = model.encode([query])
                # Same guard as in retrieval: cap k and drop -1/out-of-range
                # labels before indexing into chunks.
                top_k = min(int(k), len(chunks))
                idx, _ = search_faiss(index, q_vec, top_k)
                relevant_chunks = [
                    chunks[int(i)] for i in idx if 0 <= int(i) < len(chunks)
                ]

                parts = ["Based on the most relevant information found:\n\n"]
                parts.extend(
                    f"**Relevant passage {n}:**\n{chunk}\n\n"
                    for n, chunk in enumerate(relevant_chunks, start=1)
                )
                parts.append(
                    f"**Summary:** The above {len(relevant_chunks)} passage(s) contain the most relevant information to answer your question about: '{query}'"
                )
                return "".join(parts)

            gr.Button("📝 Generate Answer").click(
                generate_answer,
                inputs=[answer_query, chunks_state, embeddings_state, faiss_index_state, answer_k],
                outputs=answer_output,
            )

    with gr.Row():
        nav_prev = gr.Button("⬅ Back", visible=False)
        nav_next = gr.Button("Next ➡ Chunking")

    def _refresh_nav(event):
        """After ``event``, update the next-button label and back-button visibility."""
        return event.then(
            fn=update_next_label,
            inputs=current_tab,
            outputs=nav_next,
        ).then(
            fn=update_prev_visibility,
            inputs=current_tab,
            outputs=nav_prev,
        )

    def _select_tab(tab_id):
        """Return a zero-argument callback yielding ``tab_id`` (avoids late binding in the loop)."""
        return lambda: tab_id

    # Both navigation buttons run the same chain: move current_tab, show that
    # tab, then refresh the buttons.
    for button, step in ((nav_next, go_next), (nav_prev, go_prev)):
        moved = button.click(fn=step, inputs=current_tab, outputs=current_tab)
        shown = moved.then(
            fn=lambda tab: gr.update(selected=tab),
            inputs=current_tab,
            outputs=tabs,
        )
        _refresh_nav(shown)

    # Clicking a tab header bypasses the buttons; record the selection so the
    # button label and visibility do not go stale.
    tab_items = (upload_tab, chunking_tab, index_tab, retrieval_tab, answer_tab)
    for tab_id, tab_item in enumerate(tab_items):
        _refresh_nav(
            tab_item.select(fn=_select_tab(tab_id), inputs=None, outputs=current_tab)
        )
def main():
    """Launch the Gradio app with debug mode enabled."""
    demo.launch(debug=True)


if __name__ == "__main__":
    main()