Spaces:
Build error
Build error
| import os | |
| import re | |
| import math | |
| from dataclasses import dataclass | |
| from typing import List, Tuple, Dict, Any | |
| import gradio as gr | |
| import numpy as np | |
| from pypdf import PdfReader | |
| from sentence_transformers import SentenceTransformer | |
| from groq import Groq | |
# -----------------------------
# Utils
# -----------------------------
def clean_text(t: str) -> str:
    """Normalize text extracted from a PDF page.

    Drops NUL bytes, collapses runs of spaces/tabs to a single space,
    squeezes 3+ consecutive newlines down to one blank line, and strips
    surrounding whitespace.
    """
    cleaned = t.replace("\x00", " ")
    cleaned = re.sub(r"[ \t]+", " ", cleaned)
    cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
    return cleaned.strip()
def split_into_sentences(text: str) -> List[str]:
    """Split *text* into sentences on terminal punctuation (., !, ?).

    All whitespace is collapsed to single spaces first; empty fragments
    are dropped. A simple heuristic that works reasonably for English.
    """
    collapsed = re.sub(r"\s+", " ", text).strip()
    if not collapsed:
        return []
    fragments = re.split(r"(?<=[.!?])\s+", collapsed)
    return [frag.strip() for frag in fragments if frag.strip()]
def chunk_text_semantic(
    text: str,
    target_words: int = 180,
    overlap_words: int = 40,
) -> List[str]:
    """
    Semantic-ish chunking: sentence-based, then pack sentences until target_words.
    Overlap via last overlap_words words from previous chunk.
    """
    # Sentence split inlined: collapse whitespace, split on ., ?, ! boundaries.
    collapsed = re.sub(r"\s+", " ", text).strip()
    raw = re.split(r"(?<=[.!?])\s+", collapsed) if collapsed else []
    sentences = [frag.strip() for frag in raw if frag.strip()]

    chunks: List[str] = []
    pending: List[str] = []
    pending_words = 0
    for sentence in sentences:
        n_words = len(sentence.split())
        if pending and pending_words + n_words > target_words:
            # Current chunk is full: flush it, then seed the next chunk with
            # the trailing overlap_words words so context carries across.
            done = " ".join(pending).strip()
            if done:
                chunks.append(done)
            tail = " ".join(done.split()[-overlap_words:]) if overlap_words > 0 else ""
            pending = ([tail] if tail else []) + [sentence]
            pending_words = len(" ".join(pending).split())
        else:
            pending.append(sentence)
            pending_words += n_words
    remainder = " ".join(pending).strip()
    if remainder:
        chunks.append(remainder)
    return chunks
def cosine_sim_matrix(query_vec: np.ndarray, mat: np.ndarray) -> np.ndarray:
    """Cosine similarity between one query vector and every row of a matrix.

    query_vec has shape (d,), mat has shape (n, d); returns an (n,) array.
    A tiny epsilon in the denominators guards against division by zero.
    """
    eps = 1e-12
    unit_query = query_vec / (np.linalg.norm(query_vec) + eps)
    row_norms = np.linalg.norm(mat, axis=1, keepdims=True) + eps
    return (mat / row_norms) @ unit_query
# -----------------------------
# Data structures
# -----------------------------
@dataclass
class Chunk:
    """One retrievable unit of text plus its provenance (for source citations)."""
    # Bug fix: the original class had field annotations but no @dataclass
    # decorator, so Chunk(doc_name=..., page=..., text=...) raised TypeError
    # at ingest time. dataclass is already imported at the top of the file.
    doc_name: str  # basename of the source PDF file
    page: int      # 1-based page number within the PDF
    text: str      # the chunk's text content
# -----------------------------
# RAG Core
# -----------------------------
class RAGChatbot:
    """Retrieval-augmented chatbot: embeds PDF chunks locally, answers via Groq."""

    def __init__(self, embed_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Load the sentence-transformer embedder and create the Groq client.

        Raises:
            RuntimeError: if the GROQ_API_KEY environment variable is unset.
        """
        self.embedder = SentenceTransformer(embed_model_name)
        self.chunks: List[Chunk] = []
        # Empty (0, 384) placeholder; 384 is the MiniLM-L6-v2 embedding width.
        self.embeddings: np.ndarray = np.zeros((0, 384), dtype=np.float32)
        groq_key = os.getenv("GROQ_API_KEY", "").strip()
        if not groq_key:
            raise RuntimeError("GROQ_API_KEY env variable missing. Set it before running.")
        self.groq = Groq(api_key=groq_key)

    def ingest_pdfs(self, files: List[Any]) -> Dict[str, Any]:
        """Extract, chunk, and embed text from uploaded PDFs.

        Args:
            files: gradio uploaded file objects (have .name)

        Returns:
            A {"ok": bool, "msg": str} status dict.
        """
        # Robustness fix: Gradio passes None/[] when nothing was uploaded;
        # the original crashed iterating None.
        if not files:
            return {"ok": False, "msg": "No PDF files provided. Please upload at least one PDF."}
        all_chunks: List[Chunk] = []
        for f in files:
            path = f.name
            doc_name = os.path.basename(path)
            reader = PdfReader(path)
            for i, page in enumerate(reader.pages):
                page_text = page.extract_text() or ""
                page_text = clean_text(page_text)
                if not page_text:
                    continue  # image-only or empty page
                # chunk per page, but chunk further semantically
                ctexts = chunk_text_semantic(page_text, target_words=180, overlap_words=40)
                for ct in ctexts:
                    all_chunks.append(Chunk(doc_name=doc_name, page=i + 1, text=ct))
        if not all_chunks:
            return {"ok": False, "msg": "No text extracted from PDFs (maybe scanned images). Try text-based PDFs."}
        texts = [c.text for c in all_chunks]
        # normalize_embeddings=True makes cosine similarity a plain dot product.
        embs = self.embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
        self.chunks = all_chunks
        self.embeddings = embs.astype(np.float32)
        return {"ok": True, "msg": f"Ingested {len(files)} PDF(s), built {len(all_chunks)} chunks."}

    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[Chunk, float]]:
        """Return the top_k (chunk, cosine score) pairs for *query*.

        Returns an empty list when no documents have been ingested yet.
        """
        if self.embeddings.shape[0] == 0:
            return []
        qv = self.embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)[0].astype(np.float32)
        sims = cosine_sim_matrix(qv, self.embeddings)  # (n,)
        idx = np.argsort(-sims)[:top_k]  # highest similarity first
        return [(self.chunks[i], float(sims[i])) for i in idx]

    def build_prompt(self, question: str, retrieved: List[Tuple[Chunk, float]], chat_history: List[Tuple[str, str]]) -> str:
        """Assemble the grounded prompt: rules + recent history + retrieved context + question."""
        # Short history window to avoid token explosion
        hist = chat_history[-6:] if chat_history else []
        history_block = ""
        if hist:
            history_lines = []
            for u, a in hist:
                history_lines.append(f"User: {u}")
                history_lines.append(f"Assistant: {a}")
            history_block = "\n".join(history_lines)
        context_lines = []
        for ch, score in retrieved:
            # Tag each chunk with provenance so the model can cite sources.
            context_lines.append(f"[{ch.doc_name} | page {ch.page} | score {score:.3f}]\n{ch.text}")
        context_block = "\n\n".join(context_lines)
        prompt = f"""You are a helpful RAG chatbot.
Rules:
- Answer ONLY using the provided context. If context is insufficient, say: "I don't have enough information in the uploaded PDFs."
- Keep the answer clear and structured.
- After the answer, include a "Sources" section listing document name + page numbers used.
Chat history (may help follow-ups):
{history_block if history_block else "(no prior history)"}
Context:
{context_block}
Question:
{question}
Now write the answer.
"""
        return prompt

    def ask_groq(self, prompt: str, model: str = "llama-3.1-8b-instant") -> str:
        """Send *prompt* to the Groq chat-completions API and return the reply text."""
        resp = self.groq.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a retrieval-augmented assistant."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.2,  # low temperature: grounded, deterministic-ish answers
            max_tokens=700,
        )
        return resp.choices[0].message.content
# -----------------------------
# Gradio App
# -----------------------------
rag = None  # will init lazily to show friendly errors


def init_rag():
    """Return the process-wide RAGChatbot, creating it on first call."""
    global rag
    rag = rag if rag is not None else RAGChatbot()
    return rag
def on_upload(files, state):
    """Build the knowledge base from uploaded PDFs; reset chat state for the new docs."""
    result = init_rag().ingest_pdfs(files)
    # New documents invalidate any prior conversation, so start fresh.
    fresh_state = {"history": [], "ready": result["ok"]}
    return result["msg"], fresh_state
def chat_fn(message, chat_history, state, top_k):
    """Answer *message* from the indexed PDFs and append the turn to the chat.

    Returns (updated chat history, text to put back in the input box).
    """
    bot = init_rag()
    if not (state and state.get("ready")):
        # Knowledge base not built yet: leave history untouched, surface a hint.
        return chat_history, "Please upload PDF files first."
    hits = bot.retrieve(message, top_k=int(top_k))
    if hits:
        # History is passed BEFORE this turn is appended, on purpose.
        answer = bot.ask_groq(bot.build_prompt(message, hits, state.get("history", [])))
    else:
        answer = "I don't have enough information in the uploaded PDFs."
    updated = chat_history + [(message, answer)]
    state["history"] = updated
    return updated, ""
def clear_chat(state):
    """Empty the chat pane and wipe stored history, preserving other state keys."""
    state = {} if state is None else state
    state["history"] = []
    return [], state
# Gradio UI: upload PDFs, build the knowledge base, then chat with sources.
# NOTE(review): original indentation was lost in the paste; the Row grouping
# below (files+status in one row, slider+button in the next) is reconstructed
# from conventional layout — confirm against the deployed Space.
with gr.Blocks(title="Enhanced RAG PDF Chatbot (Groq)") as demo:
    gr.Markdown("# 📄 Enhanced RAG-Based Chatbot (Groq + Multi-PDF)")
    gr.Markdown(
        "Upload multiple PDFs, then ask questions. The bot retrieves relevant chunks and answers with sources (page numbers)."
    )
    # Per-session state: chat history plus a flag set once ingestion succeeds.
    state = gr.State({"history": [], "ready": False})
    with gr.Row():
        files = gr.File(
            file_types=[".pdf"],
            file_count="multiple",
            label="Upload PDF files"
        )
        status = gr.Textbox(label="Status", interactive=False)
    with gr.Row():
        top_k = gr.Slider(2, 10, value=5, step=1, label="Top-K chunks to retrieve")
        upload_btn = gr.Button("Build Knowledge Base")
    # Ingestion: builds embeddings and resets the conversation state.
    upload_btn.click(on_upload, inputs=[files, state], outputs=[status, state])
    chatbot = gr.Chatbot(label="Chat", height=420)
    msg = gr.Textbox(label="Your question", placeholder="Ask something from the PDFs...")
    send = gr.Button("Send")
    clear = gr.Button("Clear Chat")
    # Both the Send button and pressing Enter in the textbox trigger a turn;
    # chat_fn returns (history, input-box text), so success clears the input.
    send.click(chat_fn, inputs=[msg, chatbot, state, top_k], outputs=[chatbot, msg])
    msg.submit(chat_fn, inputs=[msg, chatbot, state, top_k], outputs=[chatbot, msg])
    clear.click(clear_chat, inputs=[state], outputs=[chatbot, state])
    gr.Markdown(
        "### Notes\n"
        "- Set `GROQ_API_KEY` in HuggingFace Space secrets.\n"
        "- If your PDFs are scanned images, text extraction may fail (need OCR enhancement)."
    )

if __name__ == "__main__":
    demo.launch()