"""Streamlit PDF RAG demo.

Upload PDFs → extract text (PyMuPDF) → chunk (word overlap) → embed
(all-MiniLM-L6-v2) → index (in-memory ChromaDB) → answer questions with a
Groq-hosted Llama model over the retrieved context.

NOTE(review): the original file's inline HTML/CSS markup was stripped during
extraction; the ``st.markdown`` bodies below carry the visible text content
only — restore the original markup if the styled layout is required.
"""

import streamlit as st
import chromadb
from sentence_transformers import SentenceTransformer
import fitz  # PyMuPDF
import os
import requests
import re
import hashlib

# ─── Page Config ──────────────────────────────────────────────────────────────
st.set_page_config(
    page_title="PDF RAG · Upload & Ask",
    page_icon="📂",
    layout="wide",
    initial_sidebar_state="expanded",
)

# ─── CSS ──────────────────────────────────────────────────────────────────────
st.markdown("""
""", unsafe_allow_html=True)

# ─── Session State ────────────────────────────────────────────────────────────
if "indexed_files" not in st.session_state:
    st.session_state.indexed_files = {}  # filename → {chunks, pages, size_kb}
if "chroma_collection" not in st.session_state:
    st.session_state.chroma_collection = None
if "chroma_client" not in st.session_state:
    st.session_state.chroma_client = None
if "total_chunks" not in st.session_state:
    st.session_state.total_chunks = 0


# ─── Load embedding model (cached globally) ───────────────────────────────────
@st.cache_resource(show_spinner=False)
def load_embed_model():
    """Load the sentence-transformer once per process (Streamlit resource cache)."""
    return SentenceTransformer('all-MiniLM-L6-v2')


# ─── PDF Extraction ───────────────────────────────────────────────────────────
def extract_text_from_pdf(pdf_bytes: bytes) -> list[dict]:
    """Extract per-page plain text from a PDF given as raw bytes.

    Returns a list of ``{"page": 1-based page number, "text": str}`` dicts;
    pages with no extractable text are skipped.
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        pages = []
        for page_num, page in enumerate(doc, start=1):
            text = page.get_text("text").strip()
            if text:
                pages.append({"page": page_num, "text": text})
        return pages
    finally:
        # Close the document even if a page fails to parse.
        doc.close()


# ─── Chunking ─────────────────────────────────────────────────────────────────
def chunk_text(pages: list[dict], chunk_size: int = 400, overlap: int = 60) -> list[dict]:
    """Split page texts into overlapping word-window chunks.

    Args:
        pages: output of :func:`extract_text_from_pdf`.
        chunk_size: window size in words.
        overlap: number of words shared between consecutive windows.

    Returns:
        List of ``{"text": str, "page": int}`` dicts. Fragments of 60
        characters or fewer are dropped as too short to be useful context.
    """
    # Guard against overlap >= chunk_size, which would make the window
    # advance by zero (or negative) words and loop forever.
    step = max(1, chunk_size - overlap)
    chunks = []
    for p in pages:
        words = p["text"].split()
        start = 0
        while start < len(words):
            window = " ".join(words[start:start + chunk_size]).strip()
            if len(window) > 60:
                chunks.append({"text": window, "page": p["page"]})
            start += step
    return chunks


# ─── Index PDF into ChromaDB ──────────────────────────────────────────────────
def index_pdf(filename: str, pdf_bytes: bytes, embed_model):
    """Extract, chunk, embed and add one PDF to the shared Chroma collection.

    Records per-file stats in ``st.session_state.indexed_files`` and returns
    ``(n_chunks, n_pages)``.
    """
    # Init or reuse the in-memory ChromaDB client/collection.
    if st.session_state.chroma_client is None:
        st.session_state.chroma_client = chromadb.Client()
        st.session_state.chroma_collection = st.session_state.chroma_client.get_or_create_collection(
            name="pdf_rag",
            metadata={"hnsw:space": "cosine"}
        )
    collection = st.session_state.chroma_collection

    # Extract & chunk
    pages = extract_text_from_pdf(pdf_bytes)
    chunks = chunk_text(pages)
    if not chunks:
        # Fix: record the file even when nothing was extractable (e.g. a
        # scanned/image-only PDF); otherwise it is never marked as indexed
        # and gets re-offered — and re-parsed — on every rerun.
        st.session_state.indexed_files[filename] = {
            "chunks": 0,
            "pages": len(pages),
            "size_kb": round(len(pdf_bytes) / 1024, 1),
        }
        return 0, len(pages)

    # Embed & add
    texts = [c["text"] for c in chunks]
    embeddings = embed_model.encode(texts, batch_size=32, show_progress_bar=False).tolist()

    # Chunk ids are namespaced by a hash of the filename so different files
    # cannot collide in the shared collection.
    file_tag = hashlib.md5(filename.encode()).hexdigest()[:8]
    ids, docs, metas, embeds = [], [], [], []
    for i, (chunk, emb) in enumerate(zip(chunks, embeddings)):
        ids.append(f"{file_tag}_chunk_{i}")
        docs.append(chunk["text"])
        metas.append({"filename": filename, "page": chunk["page"]})
        embeds.append(emb)

    collection.add(ids=ids, embeddings=embeds, documents=docs, metadatas=metas)

    st.session_state.indexed_files[filename] = {
        "chunks": len(chunks),
        "pages": len(pages),
        "size_kb": round(len(pdf_bytes) / 1024, 1),
    }
    st.session_state.total_chunks += len(chunks)
    return len(chunks), len(pages)


# ─── RAG Query ────────────────────────────────────────────────────────────────
def rag_query(question: str, embed_model, top_k: int, api_key: str):
    """Retrieve ``top_k`` chunks for *question* and ask the Groq LLM.

    Returns ``(answer_text, retrieved_chunks)`` where each chunk dict has
    ``text``, ``filename``, ``page`` and a ``relevance`` percentage.

    Raises:
        requests.HTTPError: if the Groq API call fails (caller handles 401).
    """
    collection = st.session_state.chroma_collection
    q_emb = embed_model.encode(question).tolist()
    results = collection.query(query_embeddings=[q_emb], n_results=top_k)

    chunks = []
    for i in range(len(results["documents"][0])):
        distance = results["distances"][0][i]
        # Cosine distance ranges over [0, 2]; clamp so the displayed
        # relevance stays within 0–100% instead of going negative.
        relevance = round(max(0.0, min(1.0, 1 - distance)) * 100, 1)
        chunks.append({
            "text": results["documents"][0][i],
            "filename": results["metadatas"][0][i]["filename"],
            "page": results["metadatas"][0][i]["page"],
            "relevance": relevance,
        })

    context = "\n\n".join([
        f"[Source: {c['filename']}, Page {c['page']}]\n{c['text']}" for c in chunks
    ])

    prompt = f"""You are a helpful assistant. Answer the user's question using ONLY the document context provided below. Be concise and clear. Always mention the source filename and page number when referencing specific information. If the answer cannot be found in the provided context, say "I couldn't find that information in the uploaded documents." Document Context: {context} Question: {question} Answer:"""

    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": "llama-3.3-70b-versatile",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 600,
        "temperature": 0.2,
    }
    r = requests.post("https://api.groq.com/openai/v1/chat/completions",
                      headers=headers, json=payload, timeout=30)
    r.raise_for_status()
    answer = r.json()["choices"][0]["message"]["content"]
    return answer, chunks


# ─── Determine current phase ──────────────────────────────────────────────────
has_docs = len(st.session_state.indexed_files) > 0
# phase feeds the phase-bar markup (stripped in transit); kept for when the
# original highlighted-step HTML is restored.
phase = 1 if not has_docs else 2

# ─── Sidebar ──────────────────────────────────────────────────────────────────
with st.sidebar:
    st.markdown("## 📂 PDF RAG Demo")
    st.markdown("""
Upload → Extract → Index → Ask
""", unsafe_allow_html=True)
    st.markdown("---")

    # Prefer a key from the environment (e.g. HF Spaces secret) over manual entry.
    env_key = os.environ.get("GROQ_API_KEY", "")
    if env_key:
        api_key = env_key
        st.success("✅ Groq key loaded from secrets")
    else:
        api_key = st.text_input("🔑 Groq API Key", type="password",
                                placeholder="gsk_...",
                                help="Free at console.groq.com")
        if not api_key:
            st.caption("Get free key → [console.groq.com](https://console.groq.com)")

    st.markdown("---")
    st.markdown("""
Indexed Documents
""", unsafe_allow_html=True)

    if st.session_state.indexed_files:
        for fname, info in st.session_state.indexed_files.items():
            st.markdown(f"""
📄 {fname}
{info["pages"]}p · {info["chunks"]} chunks · {info["size_kb"]}KB
""", unsafe_allow_html=True)
        st.markdown("---")
        if st.button("🗑️ Clear all & reset", use_container_width=True):
            # Dropping the client reference discards the in-memory ChromaDB;
            # the guards at the top of the script re-initialize on rerun.
            for key in ["indexed_files", "chroma_collection", "chroma_client", "total_chunks"]:
                del st.session_state[key]
            st.rerun()
    else:
        st.markdown("""
No documents indexed yet.
""", unsafe_allow_html=True)

    st.markdown("---")
    st.markdown("""
Stack
📄 PDF parsing: PyMuPDF
✂️ Chunking: word-overlap (400w)
🔢 Embeddings: all-MiniLM-L6-v2
🗄️ Vector DB: ChromaDB in-memory
🧠 LLM: Groq · Llama 3.3 70B
🌐 Hosting: HuggingFace Spaces
""", unsafe_allow_html=True)

# ─── Hero ─────────────────────────────────────────────────────────────────────
st.markdown("""
📂 PDF RAG — Upload & Ask

Upload any PDF documents · They get extracted, chunked, embedded, and indexed · Then ask questions across all of them
""", unsafe_allow_html=True)

# Phase bar
st.markdown(f"""
📤Upload PDFs
📑Extract Text
✂️Chunk
🔢Embed
🗄️Index
💬Ask Questions
""", unsafe_allow_html=True)

# ─── Load model ───────────────────────────────────────────────────────────────
with st.spinner("⚙️ Loading embedding model..."):
    embed_model = load_embed_model()

# ════════════════════════════════════════════════════════════
# PHASE 1 — Upload & Index
# ════════════════════════════════════════════════════════════
st.markdown("""
Step 1 — Upload PDF Documents
""", unsafe_allow_html=True)

uploaded_files = st.file_uploader(
    "Drop your PDF files here",
    type=["pdf"],
    accept_multiple_files=True,
    label_visibility="collapsed"
)

if uploaded_files:
    # Only files not already indexed need processing.
    new_files = [f for f in uploaded_files if f.name not in st.session_state.indexed_files]
    if new_files:
        st.markdown(f"**{len(new_files)} new file(s) ready to index:**")
        for f in new_files:
            st.markdown(f"""
📄 {f.name}
{round(f.size/1024,1)} KB
ready
""", unsafe_allow_html=True)
        if st.button(f"⚡ Extract & Index {len(new_files)} PDF(s)",
                     type="primary", use_container_width=True):
            progress = st.progress(0, text="Starting...")
            for idx, f in enumerate(new_files):
                progress.progress(idx / len(new_files), text=f"Processing: {f.name}")
                pdf_bytes = f.read()
                with st.spinner(f"Extracting & indexing **{f.name}**..."):
                    n_chunks, n_pages = index_pdf(f.name, pdf_bytes, embed_model)
                st.success(f"✅ **{f.name}** → {n_pages} pages · {n_chunks} chunks indexed")
            progress.progress(1.0, text="Done!")
            st.balloons()
            st.rerun()
    else:
        st.info("All uploaded files are already indexed. Upload new files or ask questions below.")
elif not has_docs:
    st.markdown("""
📂

No documents uploaded yet
Upload one or more PDF files above to get started.
Any topic works — reports, manuals, research papers, policies.
""", unsafe_allow_html=True)

# ════════════════════════════════════════════════════════════
# PHASE 2 — Stats & Query
# ════════════════════════════════════════════════════════════
if has_docs:
    total_pages = sum(v["pages"] for v in st.session_state.indexed_files.values())
    st.markdown("""
Index Summary
""", unsafe_allow_html=True)
    st.markdown(f"""
{len(st.session_state.indexed_files)}
Documents
{total_pages}
Pages Parsed
{st.session_state.total_chunks}
Chunks Indexed
384
Embedding Dims
""", unsafe_allow_html=True)

    if not api_key:
        st.warning("👈 Enter your Groq API key in the sidebar to start asking questions.")
        st.stop()

    st.markdown("---")
    st.markdown("""
Step 2 — Ask a Question
""", unsafe_allow_html=True)

    col1, col2 = st.columns([5, 1])
    with col1:
        # Fix: give the widget a real label (hidden) instead of an empty
        # string, which triggers a Streamlit accessibility warning.
        question = st.text_input("Question",
                                 placeholder="What does the document say about...?",
                                 label_visibility="collapsed")
    with col2:
        top_k = st.selectbox("Top K", [2, 3, 4, 5], index=1,
                             help="Number of chunks to retrieve")

    ask_btn = st.button("🔍 Search & Answer", type="primary", use_container_width=True)

    if ask_btn and question:
        with st.spinner("🔍 Searching index and generating answer..."):
            try:
                answer, chunks = rag_query(question, embed_model, top_k, api_key)
                st.markdown("""
Answer
""", unsafe_allow_html=True)
                st.markdown(f"""
{answer}
""", unsafe_allow_html=True)
                st.markdown("""
Retrieved Chunks (context sent to LLM)
""", unsafe_allow_html=True)
                for i, chunk in enumerate(chunks):
                    # Width of the (stripped) relevance bar in the card markup.
                    bar_width = int(chunk['relevance'])
                    st.markdown(f"""
📄 {chunk['filename']}
Page {chunk['page']}
{chunk['relevance']}%
{chunk['text']}
""", unsafe_allow_html=True)
            except requests.HTTPError as e:
                # Fix: e.response can be None for some transport failures.
                if e.response is not None and e.response.status_code == 401:
                    st.error("❌ Invalid Groq API key.")
                else:
                    st.error(f"❌ API error: {str(e)}")
            except Exception as e:
                st.error(f"❌ Error: {str(e)}")
    elif ask_btn and not question:
        st.warning("Please enter a question.")