# app.py — "Ask My Doc": domain-specific RAG demo (Gradio + ChromaDB + Groq)
import os
import re
import gradio as gr
import chromadb
import PyPDF2
from groq import Groq
from sentence_transformers import SentenceTransformer, CrossEncoder
from rank_bm25 import BM25Okapi
# ── Models ──────────────────────────────────────────────────────────────────
# Groq LLM client (API key read from the GROQ_API_KEY env var), a sentence
# embedder for dense retrieval, and a cross-encoder for reranking.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
embedder = SentenceTransformer("all-MiniLM-L6-v2")
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# ── ChromaDB (in-memory) ─────────────────────────────────────────────────────
# In-memory client: the index is rebuilt on every app restart / re-ingest.
chroma_client = chromadb.Client()
collection = chroma_client.get_or_create_collection("rag_docs")
# ── Global store for BM25 ────────────────────────────────────────────────────
# Parallel corpus for sparse retrieval; reset by ingest_documents().
doc_store = [] # list of {"id": str, "text": str, "source": str}
# ── Helpers ──────────────────────────────────────────────────────────────────
def chunk_text(text, chunk_size=400, overlap=80):
    """Split *text* into overlapping word-based chunks.

    Args:
        text: Source string; tokenized by whitespace.
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        list[str]: Chunks in document order; empty list for empty text.

    Raises:
        ValueError: If overlap >= chunk_size (the original loop would
            never advance and spin forever).
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    step = chunk_size - overlap
    return [
        " ".join(words[start : start + chunk_size])
        for start in range(0, len(words), step)
    ]
def extract_text_from_pdf(file_path):
    """Concatenate the extracted text of every page of the PDF at *file_path*."""
    pages = []
    with open(file_path, "rb") as handle:
        for page in PyPDF2.PdfReader(handle).pages:
            # extract_text() can return None (e.g. image-only pages)
            pages.append(page.extract_text() or "")
    return "".join(pages)
def extract_text_from_txt(file_path):
    """Read *file_path* as UTF-8 text, ignoring any undecodable bytes."""
    with open(file_path, mode="r", encoding="utf-8", errors="ignore") as handle:
        return handle.read()
# ── Ingest ───────────────────────────────────────────────────────────────────
def ingest_documents(files):
    """Chunk, embed and index the uploaded files.

    Rebuilds both the ChromaDB collection (dense index) and the module-level
    ``doc_store`` (BM25 corpus) from scratch, so re-ingesting fully replaces
    any previously indexed documents.

    Args:
        files: Gradio file objects (each exposes a ``.name`` path), or None.

    Returns:
        str: Human-readable status message for the UI.
    """
    global doc_store, collection
    if not files:
        return "⚠️ No files uploaded."
    # Reset both indexes so stale chunks never leak into new answers.
    doc_store = []
    chroma_client.delete_collection("rag_docs")
    collection = chroma_client.get_or_create_collection("rag_docs")
    total_chunks = 0
    file_names = []
    for file in files:
        path = file.name
        name = os.path.basename(path)
        file_names.append(name)
        # Case-insensitive extension check so ".PDF" uploads are parsed too.
        if path.lower().endswith(".pdf"):
            raw_text = extract_text_from_pdf(path)
        else:
            raw_text = extract_text_from_txt(path)
        if not raw_text.strip():
            continue  # skip empty / image-only documents
        chunks = chunk_text(raw_text)
        if not chunks:
            continue
        # Batch-encode and batch-add the whole file at once: far faster than
        # one embedder.encode() + collection.add() round-trip per chunk.
        embeddings = embedder.encode(chunks).tolist()
        ids = [f"{name}_chunk_{i}" for i in range(len(chunks))]
        collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=chunks,
            metadatas=[{"source": name, "chunk_index": i} for i in range(len(chunks))],
        )
        # Mirror every chunk into doc_store for BM25 sparse retrieval.
        for chunk_id, chunk in zip(ids, chunks):
            doc_store.append({"id": chunk_id, "text": chunk, "source": name})
        total_chunks += len(chunks)
    return f"βœ… Ingested {total_chunks} chunks from: {', '.join(file_names)}"
# ── Retrieval ────────────────────────────────────────────────────────────────
def vector_search(query, top_k=10):
    """Dense retrieval: embed *query* and fetch the nearest chunks from ChromaDB.

    Args:
        query: Natural-language query string.
        top_k: Maximum number of chunks to return.

    Returns:
        list[dict]: ``{"text", "source"}`` dicts, best match first; empty list
        when nothing has been ingested yet.
    """
    # Guard: with an empty corpus min(top_k, 0) == 0 and Chroma rejects
    # n_results=0 — mirror the empty-store guard bm25_search already has.
    if not doc_store:
        return []
    query_embedding = embedder.encode(query).tolist()
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=min(top_k, len(doc_store)),
    )
    return [
        {"text": doc, "source": meta["source"]}
        for doc, meta in zip(results["documents"][0], results["metadatas"][0])
    ]
def bm25_search(query, top_k=10):
    """Sparse retrieval: rank every chunk in doc_store with BM25 keyword scores."""
    if not doc_store:
        return []
    # Index is rebuilt per query because ingestion can replace doc_store
    # at any time.
    corpus_tokens = [entry["text"].lower().split() for entry in doc_store]
    index = BM25Okapi(corpus_tokens)
    scores = index.get_scores(query.lower().split())
    order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    return [
        {"text": doc_store[idx]["text"], "source": doc_store[idx]["source"]}
        for idx in order[:top_k]
    ]
def reciprocal_rank_fusion(vector_results, bm25_results, k=60):
    """Merge two ranked chunk lists with Reciprocal Rank Fusion (RRF).

    Each chunk accumulates ``1 / (k + rank + 1)`` from every list it appears
    in, so items ranked highly by both retrievers rise to the top.

    Args:
        vector_results: Ranked ``{"text", "source"}`` dicts from dense search.
        bm25_results: Ranked ``{"text", "source"}`` dicts from sparse search.
        k: RRF smoothing constant (60 is the conventional default).

    Returns:
        list[dict]: Fused chunks, best first, deduplicated by chunk text.
    """
    scores = {}
    all_chunks = {}
    for results in (vector_results, bm25_results):
        for rank, chunk in enumerate(results):
            # Key on the FULL chunk text: the previous 100-char-prefix key
            # wrongly merged distinct chunks sharing a long common prefix.
            key = chunk["text"]
            scores[key] = scores.get(key, 0) + 1 / (k + rank + 1)
            all_chunks[key] = chunk
    ranked_keys = sorted(scores, key=scores.get, reverse=True)
    return [all_chunks[key] for key in ranked_keys]
def rerank(query, chunks, top_k=5):
    """Re-score *chunks* against *query* with the cross-encoder; keep the best."""
    if not chunks:
        return []
    scores = reranker.predict([(query, c["text"]) for c in chunks])
    # Sort chunk indices by descending relevance score (stable on ties).
    order = sorted(range(len(chunks)), key=lambda i: scores[i], reverse=True)
    return [chunks[i] for i in order[:top_k]]
# ── Generate answer ──────────────────────────────────────────────────────────
def generate_answer(query, chunks):
    """Ask the Groq LLM to answer *query* grounded in *chunks*, with citations."""
    # Label each chunk so the model can cite it as [Source N].
    context = "\n\n".join(
        f"[Source {idx + 1}: {chunk['source']}]\n{chunk['text']}"
        for idx, chunk in enumerate(chunks)
    )
    prompt = f"""You are a helpful research assistant. Answer the question based ONLY on the provided context.
For every claim you make, cite the source using [Source N] notation.
If the answer is not in the context, say "I couldn't find relevant information in the uploaded documents."
Question: {query}
Context:
{context}
Answer (with citations):"""
    completion = groq_client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=1024,
    )
    return completion.choices[0].message.content
# ── Main pipeline ─────────────────────────────────────────────────────────────
def ask_question(query):
    """Full RAG pipeline: hybrid retrieval → RRF fusion → rerank → generate.

    Args:
        query: The user's question.

    Returns:
        tuple[str, str]: (answer text, formatted source list) for the UI.
    """
    if not query.strip():
        return "⚠️ Please enter a question.", ""
    if not doc_store:
        return "⚠️ Please upload and ingest documents first.", ""
    # Step 1: hybrid retrieval (dense vectors + sparse BM25)
    vector_results = vector_search(query, top_k=10)
    bm25_results = bm25_search(query, top_k=10)
    # Step 2: merge the two ranked lists with Reciprocal Rank Fusion
    fused = reciprocal_rank_fusion(vector_results, bm25_results)
    # Step 3: cross-encoder reranking keeps only the most relevant chunks
    top_chunks = rerank(query, fused, top_k=5)
    # Step 4: grounded answer generation
    answer = generate_answer(query, top_chunks)
    # Step 5: unique source names in first-seen order (dict preserves
    # insertion order) instead of the O(n^2) membership-list scan.
    seen = dict.fromkeys(chunk["source"] for chunk in top_chunks)
    sources = "\n".join(f"β€’ {s}" for s in seen)
    return answer, f"πŸ“„ Sources referenced:\n{sources}"
# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(title="Domain RAG System β€” Ask My Doc") as demo:
    gr.Markdown("# πŸ“š Ask My Doc β€” Domain-Specific RAG System")
    gr.Markdown(
        "Upload your documents (PDF or TXT), then ask questions. "
        "Uses **hybrid retrieval** (vector search + BM25) and **cross-encoder reranking** "
        "to give you cited, accurate answers."
    )
    with gr.Row():
        # Left column: document upload / ingestion controls.
        with gr.Column(scale=1):
            gr.Markdown("### Step 1: Upload Documents")
            file_upload = gr.File(
                label="Upload PDF or TXT files",
                file_types=[".pdf", ".txt"],
                file_count="multiple",
            )
            ingest_btn = gr.Button("πŸ“₯ Ingest Documents", variant="primary")
            ingest_status = gr.Textbox(label="Status", interactive=False)
        # Right column: question box plus answer / source displays.
        with gr.Column(scale=2):
            gr.Markdown("### Step 2: Ask a Question")
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="e.g. What are the key findings in this document?",
                lines=2,
            )
            ask_btn = gr.Button("πŸ” Get Answer", variant="primary")
            answer_output = gr.Textbox(label="Answer", lines=10, interactive=False)
            sources_output = gr.Textbox(label="Sources", lines=3, interactive=False)
    # Wire the buttons to the pipeline functions.
    ingest_btn.click(fn=ingest_documents, inputs=[file_upload], outputs=[ingest_status])
    ask_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output, sources_output])
demo.launch()