# Rag_chatbot / app.py
# Author: Muhammad Yasir
# (HuggingFace Space page header; commit 86bc089 — kept as a comment so the
# file parses as Python.)
import os
import re
import math
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any
import gradio as gr
import numpy as np
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
from groq import Groq
# -----------------------------
# Utils
# -----------------------------
def clean_text(t: str) -> str:
    """Normalize raw PDF-extracted text.

    Drops NUL bytes, collapses runs of spaces/tabs to one space, squeezes
    three or more consecutive newlines down to a single blank line, and
    strips surrounding whitespace.
    """
    cleaned = t.replace("\x00", " ")
    cleaned = re.sub(r"[ \t]+", " ", cleaned)
    cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
    return cleaned.strip()
def split_into_sentences(text: str) -> List[str]:
    """Split *text* into sentences with a lightweight punctuation heuristic.

    Whitespace is collapsed first; the split happens after '.', '!' or '?'.
    Works reasonably for English; other scripts (e.g. Urdu) may need a
    smarter splitter later.
    """
    normalized = re.sub(r"\s+", " ", text).strip()
    if not normalized:
        return []
    pieces = re.split(r"(?<=[.!?])\s+", normalized)
    return [piece.strip() for piece in pieces if piece.strip()]
def chunk_text_semantic(
    text: str,
    target_words: int = 180,
    overlap_words: int = 40
) -> List[str]:
    """
    Semantic-ish chunking: sentence-based, then pack sentences until target_words.
    Overlap via last overlap_words words from previous chunk.
    """
    sentences = split_into_sentences(text)
    out: List[str] = []
    buffer: List[str] = []
    buffer_len = 0
    for sentence in sentences:
        n_words = len(sentence.split())
        # Keep packing while we fit; an empty buffer always accepts the
        # sentence, even one longer than the target.
        if not buffer or buffer_len + n_words <= target_words:
            buffer.append(sentence)
            buffer_len += n_words
            continue
        # Buffer is full: emit it, then seed the next buffer with the tail
        # of the emitted chunk so adjacent chunks share context.
        emitted = " ".join(buffer).strip()
        if emitted:
            out.append(emitted)
        tail = " ".join(emitted.split()[-overlap_words:]) if overlap_words > 0 else ""
        buffer = [tail, sentence] if tail else [sentence]
        buffer_len = len(" ".join(buffer).split())
    remainder = " ".join(buffer).strip()
    if remainder:
        out.append(remainder)
    return out
def cosine_sim_matrix(query_vec: np.ndarray, mat: np.ndarray) -> np.ndarray:
    """Cosine similarity of one query vector against every row of a matrix.

    query_vec has shape (d,), mat has shape (n, d); returns shape (n,).
    A tiny epsilon in the denominators guards against zero-norm vectors.
    """
    unit_query = query_vec / (np.linalg.norm(query_vec) + 1e-12)
    unit_rows = mat / (np.linalg.norm(mat, axis=1, keepdims=True) + 1e-12)
    return unit_rows @ unit_query
# -----------------------------
# Data structures
# -----------------------------
@dataclass
class Chunk:
    """One retrievable unit of text extracted from a single PDF page."""
    # Basename of the source PDF file.
    doc_name: str
    # 1-based page number within that PDF.
    page: int
    # Cleaned, chunked text content for this unit.
    text: str
# -----------------------------
# RAG Core
# -----------------------------
class RAGChatbot:
    """Minimal in-memory RAG pipeline.

    PDFs are split into overlapping sentence-packed chunks, embedded with a
    sentence-transformers model, retrieved by cosine similarity, and the top
    chunks are handed to a Groq-hosted LLM as grounded context.
    """

    def __init__(self, embed_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Load the embedding model and the Groq client.

        Raises:
            RuntimeError: if the GROQ_API_KEY environment variable is unset.
        """
        self.embedder = SentenceTransformer(embed_model_name)
        self.chunks: List[Chunk] = []
        # Fix: size the empty matrix from the loaded model instead of
        # hard-coding 384 (MiniLM's dimension), so alternative embedders
        # passed via embed_model_name work unchanged.
        dim = self.embedder.get_sentence_embedding_dimension() or 384
        self.embeddings: np.ndarray = np.zeros((0, dim), dtype=np.float32)
        groq_key = os.getenv("GROQ_API_KEY", "").strip()
        if not groq_key:
            raise RuntimeError("GROQ_API_KEY env variable missing. Set it before running.")
        self.groq = Groq(api_key=groq_key)

    def ingest_pdfs(self, files: List[Any]) -> Dict[str, Any]:
        """Extract, chunk and embed text from uploaded PDFs.

        Args:
            files: gradio uploaded file objects (each has a ``.name`` path);
                may be None/empty if the user clicked build without files.

        Returns:
            ``{"ok": bool, "msg": str}``; on success the whole index
            (self.chunks / self.embeddings) is replaced.
        """
        all_chunks: List[Chunk] = []
        # Fix: guard against the upload widget handing us None.
        for f in files or []:
            path = f.name
            doc_name = os.path.basename(path)
            reader = PdfReader(path)
            for i, page in enumerate(reader.pages):
                page_text = page.extract_text() or ""
                page_text = clean_text(page_text)
                if not page_text:
                    continue
                # Chunk per page so page numbers stay attached to each chunk.
                ctexts = chunk_text_semantic(page_text, target_words=180, overlap_words=40)
                for ct in ctexts:
                    all_chunks.append(Chunk(doc_name=doc_name, page=i + 1, text=ct))
        if not all_chunks:
            return {"ok": False, "msg": "No text extracted from PDFs (maybe scanned images). Try text-based PDFs."}
        texts = [c.text for c in all_chunks]
        # normalize_embeddings=True keeps vectors unit-length, so cosine
        # similarity reduces to a dot product at query time.
        embs = self.embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
        self.chunks = all_chunks
        self.embeddings = embs.astype(np.float32)
        return {"ok": True, "msg": f"Ingested {len(files)} PDF(s), built {len(all_chunks)} chunks."}

    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[Chunk, float]]:
        """Return the top_k most similar chunks as (chunk, score) pairs.

        Returns an empty list when no documents have been ingested yet.
        """
        if self.embeddings.shape[0] == 0:
            return []
        qv = self.embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)[0].astype(np.float32)
        sims = cosine_sim_matrix(qv, self.embeddings)  # (n,)
        idx = np.argsort(-sims)[:top_k]
        return [(self.chunks[i], float(sims[i])) for i in idx]

    def build_prompt(self, question: str, retrieved: List[Tuple[Chunk, float]], chat_history: List[Tuple[str, str]]) -> str:
        """Assemble the grounded prompt: rules, recent history, context, question."""
        # Short history window to avoid token explosion.
        hist = chat_history[-6:] if chat_history else []
        history_block = ""
        if hist:
            history_lines = []
            for u, a in hist:
                history_lines.append(f"User: {u}")
                history_lines.append(f"Assistant: {a}")
            history_block = "\n".join(history_lines)
        context_lines = []
        for ch, score in retrieved:
            # Tag each chunk with provenance so the model can cite sources.
            context_lines.append(f"[{ch.doc_name} | page {ch.page} | score {score:.3f}]\n{ch.text}")
        context_block = "\n\n".join(context_lines)
        prompt = f"""You are a helpful RAG chatbot.
Rules:
- Answer ONLY using the provided context. If context is insufficient, say: "I don't have enough information in the uploaded PDFs."
- Keep the answer clear and structured.
- After the answer, include a "Sources" section listing document name + page numbers used.
Chat history (may help follow-ups):
{history_block if history_block else "(no prior history)"}
Context:
{context_block}
Question:
{question}
Now write the answer.
"""
        return prompt

    def ask_groq(self, prompt: str, model: str = "llama-3.1-8b-instant") -> str:
        """Send the prompt to Groq chat completions and return the reply text."""
        resp = self.groq.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a retrieval-augmented assistant."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.2,  # low temperature: stick close to the context
            max_tokens=700,
        )
        return resp.choices[0].message.content
# -----------------------------
# Gradio App
# -----------------------------
# Lazily-constructed singleton so startup errors (e.g. missing API key)
# surface as readable messages when first used, not at import time.
rag = None


def init_rag():
    """Return the shared RAGChatbot, constructing it on first call."""
    global rag
    if rag is not None:
        return rag
    rag = RAGChatbot()
    return rag
def on_upload(files, state):
    """Build the knowledge base from uploaded PDFs.

    Returns (status_message, fresh_state); the chat history is reset
    whenever new documents are ingested.
    """
    bot = init_rag()
    result = bot.ingest_pdfs(files)
    # New documents invalidate prior conversation context.
    fresh_state = {"history": [], "ready": result["ok"]}
    return result["msg"], fresh_state
def chat_fn(message, chat_history, state, top_k):
    """Handle one chat turn: retrieve context, query Groq, update history.

    Args:
        message: the user's question from the textbox.
        chat_history: list of (user, assistant) tuples shown in the Chatbot.
        state: session dict with "history" and "ready" keys (may be None).
        top_k: number of chunks to retrieve (slider value).

    Returns:
        (updated_chat_history, textbox_value). The textbox value is always
        "" so the user's input box is cleared, never overwritten with
        status text.
    """
    # Fix: ignore empty submissions instead of sending blank queries.
    if not message or not message.strip():
        return chat_history, ""
    if not state or not state.get("ready"):
        # Fix: the old code returned this warning as the second output,
        # which is wired to the input textbox — it clobbered the user's
        # typed question and never appeared in the chat. Show it in the
        # chat window instead. (Checked before init_rag so no chatbot is
        # constructed just to reject the request.)
        chat_history = chat_history + [(message, "Please upload PDF files first.")]
        return chat_history, ""
    bot = init_rag()
    retrieved = bot.retrieve(message, top_k=int(top_k))
    if not retrieved:
        answer = "I don't have enough information in the uploaded PDFs."
        chat_history = chat_history + [(message, answer)]
        state["history"] = chat_history
        return chat_history, ""
    prompt = bot.build_prompt(message, retrieved, state.get("history", []))
    answer = bot.ask_groq(prompt)
    chat_history = chat_history + [(message, answer)]
    state["history"] = chat_history
    return chat_history, ""
def clear_chat(state):
    """Empty the chat transcript while keeping the knowledge base intact.

    Returns ([], state) — an empty Chatbot value plus the (mutated) state.
    """
    state = {} if state is None else state
    state["history"] = []
    return [], state
# Gradio UI: upload/build controls on top, chat panel below.
with gr.Blocks(title="Enhanced RAG PDF Chatbot (Groq)") as demo:
    gr.Markdown("# 📄 Enhanced RAG-Based Chatbot (Groq + Multi-PDF)")
    gr.Markdown(
        "Upload multiple PDFs, then ask questions. The bot retrieves relevant chunks and answers with sources (page numbers)."
    )

    # Per-session state: chat transcript plus a "knowledge base ready" flag.
    session_state = gr.State({"history": [], "ready": False})

    with gr.Row():
        pdf_files = gr.File(
            file_types=[".pdf"],
            file_count="multiple",
            label="Upload PDF files",
        )
        status_box = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        top_k_slider = gr.Slider(2, 10, value=5, step=1, label="Top-K chunks to retrieve")
        build_btn = gr.Button("Build Knowledge Base")

    build_btn.click(on_upload, inputs=[pdf_files, session_state], outputs=[status_box, session_state])

    chat_window = gr.Chatbot(label="Chat", height=420)
    question_box = gr.Textbox(label="Your question", placeholder="Ask something from the PDFs...")
    send_btn = gr.Button("Send")
    clear_btn = gr.Button("Clear Chat")

    # Button click and textbox Enter share the same handler wiring.
    turn_inputs = [question_box, chat_window, session_state, top_k_slider]
    turn_outputs = [chat_window, question_box]
    send_btn.click(chat_fn, inputs=turn_inputs, outputs=turn_outputs)
    question_box.submit(chat_fn, inputs=turn_inputs, outputs=turn_outputs)
    clear_btn.click(clear_chat, inputs=[session_state], outputs=[chat_window, session_state])

    gr.Markdown(
        "### Notes\n"
        "- Set `GROQ_API_KEY` in HuggingFace Space secrets.\n"
        "- If your PDFs are scanned images, text extraction may fail (need OCR enhancement)."
    )


if __name__ == "__main__":
    demo.launch()