Spaces:
Sleeping
Sleeping
import hashlib
import html
import os
import re
import uuid

import chromadb
import fitz  # PyMuPDF
import requests
import streamlit as st
from sentence_transformers import SentenceTransformer
# --- Page config: browser-tab title/icon, wide layout, sidebar open ---
_PAGE_SETTINGS = dict(
    page_title="PDF RAG Β· Upload & Ask",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_SETTINGS)
# --- Custom CSS: fonts, dark theme, hero banner, phase bar, cards, stats ---
# Injected once per run as a raw <style> element.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:wght@300;400;500;600&family=IBM+Plex+Mono:wght@400;500&display=swap');
html, body, [class*="css"] { font-family: 'IBM Plex Sans', sans-serif; }
.main { background-color: #0b0f1a; }
.hero {
background: linear-gradient(160deg, #0d1424 0%, #0b0f1a 100%);
border: 1px solid #1e2a3e;
border-top: 3px solid #22d3ee;
border-radius: 12px;
padding: 28px 32px;
margin-bottom: 24px;
}
.hero h1 { font-size: 1.8rem; font-weight: 600; color: #e2e8f0; margin: 0 0 6px 0; }
.hero p { color: #64748b; font-size: 0.95rem; margin: 0; }
.phase-bar {
display: flex; gap: 0; margin-bottom: 28px;
border: 1px solid #1e2a3e; border-radius: 10px; overflow: hidden;
}
.phase {
flex: 1; padding: 10px 6px; text-align: center;
font-size: 0.75rem; color: #4b5563; background: #0d1117;
border-right: 1px solid #1e2a3e; line-height: 1.5;
}
.phase:last-child { border-right: none; }
.phase.done { color: #22d3ee; background: rgba(34,211,238,0.05); }
.phase.active { color: #f8fafc; background: rgba(34,211,238,0.1); font-weight: 600; }
.phase-icon { font-size: 1.1rem; display: block; margin-bottom: 2px; }
.pdf-card {
background: #0d1424;
border: 1px solid #1e2a3e;
border-radius: 10px;
padding: 14px 16px;
margin: 8px 0;
display: flex;
align-items: center;
justify-content: space-between;
}
.pdf-name { font-size: 0.85rem; color: #e2e8f0; font-weight: 500; }
.pdf-meta { font-family: 'IBM Plex Mono', monospace; font-size: 0.72rem; color: #475569; margin-top: 3px; }
.pdf-badge {
font-size: 0.72rem; font-family: 'IBM Plex Mono', monospace;
background: rgba(34,211,238,0.1); color: #22d3ee;
border: 1px solid rgba(34,211,238,0.25); padding: 3px 10px; border-radius: 20px;
}
.answer-box {
background: #0d1424;
border: 1px solid #1e3a4a;
border-left: 3px solid #22d3ee;
border-radius: 10px;
padding: 22px 24px;
color: #e2e8f0;
line-height: 1.75;
font-size: 0.96rem;
margin: 12px 0 20px 0;
}
.chunk-card {
background: #0d1117;
border: 1px solid #1e2a3e;
border-radius: 9px;
padding: 14px 18px;
margin: 8px 0;
}
.chunk-top {
display: flex; justify-content: space-between;
align-items: center; margin-bottom: 8px;
}
.chunk-source { font-size: 0.77rem; font-weight: 600; color: #22d3ee; text-transform: uppercase; letter-spacing: 0.05em; }
.chunk-page { font-family: 'IBM Plex Mono', monospace; font-size: 0.72rem; color: #475569; }
.score-bar-wrap { display: flex; align-items: center; gap: 8px; }
.score-bar {
height: 4px; border-radius: 2px; background: #1e2a3e; width: 80px; overflow: hidden;
}
.score-fill { height: 100%; border-radius: 2px; background: #22d3ee; }
.score-num { font-family: 'IBM Plex Mono', monospace; font-size: 0.72rem; color: #22d3ee; }
.chunk-text { font-size: 0.86rem; color: #94a3b8; line-height: 1.65; }
.stat-row { display: flex; gap: 10px; margin: 16px 0; }
.stat-box {
flex: 1; background: #0d1424; border: 1px solid #1e2a3e;
border-radius: 8px; padding: 12px; text-align: center;
}
.stat-val { font-size: 1.35rem; font-weight: 600; color: #22d3ee; }
.stat-lbl { font-size: 0.7rem; color: #475569; margin-top: 2px; }
.section-label {
font-size: 0.7rem; text-transform: uppercase; letter-spacing: 0.1em;
color: #374151; font-weight: 600; margin: 18px 0 8px 0;
}
section[data-testid="stSidebar"] {
background-color: #080c14; border-right: 1px solid #131c2e;
}
.empty-state {
text-align: center; padding: 48px 24px;
border: 2px dashed #1e2a3e; border-radius: 12px; color: #374151;
}
.empty-state .icon { font-size: 2.5rem; margin-bottom: 12px; }
.empty-state p { font-size: 0.9rem; line-height: 1.6; }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# --- Session state: create each key only on the session's first run ---
# Each entry is (key, zero-argument factory producing the initial value).
_SESSION_DEFAULTS = (
    ("indexed_files", dict),            # filename -> {"chunks", "pages", "size_kb"}
    ("chroma_collection", lambda: None),
    ("chroma_client", lambda: None),
    ("total_chunks", int),              # int() == 0
)
for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# --- Embedding model ---
@st.cache_resource(show_spinner=False)
def load_embed_model():
    """Return the 'all-MiniLM-L6-v2' SentenceTransformer used for chunks and queries.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` keeps one model instance per server process, so it
    is not reloaded on each rerun. The caller already shows its own spinner,
    hence ``show_spinner=False``.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')
# --- PDF extraction ---
def extract_text_from_pdf(pdf_bytes: bytes) -> list[dict]:
    """Return ``{"page": int, "text": str}`` for every page with extractable text.

    Pages are numbered from 1. A page whose text is empty after ``strip()``
    is skipped. The document is opened as a context manager so it is closed
    even if text extraction raises part-way through.
    """
    pages = []
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        for page_num, page in enumerate(doc, start=1):
            text = page.get_text("text").strip()
            if text:
                pages.append({"page": page_num, "text": text})
    return pages
# --- Chunking ---
def chunk_text(
    pages: list[dict],
    chunk_size: int = 400,
    overlap: int = 60,
    min_chars: int = 60,
) -> list[dict]:
    """Split each page's text into overlapping word windows.

    Chunks never span pages. Each window holds up to ``chunk_size``
    whitespace-separated words. Consecutive windows share ``overlap`` words.
    A page stops producing windows once a window reaches its last word, so no
    trailing chunk consists solely of already-covered overlap words.

    Args:
        pages: Dicts with ``"page"`` (page number) and ``"text"`` keys.
        chunk_size: Maximum number of words per chunk; must be positive.
        overlap: Number of words shared by consecutive chunks; must be
            smaller than ``chunk_size``.
        min_chars: Chunks whose joined text has ``min_chars`` characters or
            fewer are dropped.

    Returns:
        List of ``{"text": str, "page": int}`` dicts in page/position order.

    Raises:
        ValueError: If ``chunk_size <= 0`` or ``overlap >= chunk_size``.
            Either would otherwise make the window never advance.
    """
    step = chunk_size - overlap
    if chunk_size <= 0 or step <= 0:
        raise ValueError("chunk_size must be positive and greater than overlap")

    chunks = []
    for p in pages:
        words = p["text"].split()
        for start in range(0, len(words), step):
            window = " ".join(words[start:start + chunk_size])
            if len(window) > min_chars:
                chunks.append({"text": window, "page": p["page"]})
            if start + chunk_size >= len(words):
                break  # this window already reached the end of the page
    return chunks
# --- Index PDF into ChromaDB ---
def index_pdf(filename: str, pdf_bytes: bytes, embed_model):
    """Extract, chunk, embed and store one PDF in this session's ChromaDB collection.

    The ChromaDB client and collection are created lazily on first use and
    kept in ``st.session_state``. On success, the file's stats are recorded in
    ``st.session_state.indexed_files`` and ``st.session_state.total_chunks``
    is incremented.

    Args:
        filename: Name of the uploaded file. It is stored as chunk metadata,
            and its hash prefixes the chunk IDs.
        pdf_bytes: Raw PDF content.
        embed_model: Model whose ``encode(texts, batch_size=..., show_progress_bar=...)``
            returns an object with ``.tolist()``.

    Returns:
        ``(n_chunks, n_pages)``: the number of chunks added and the number of
        pages with extractable text. If chunking yields nothing, nothing is
        stored, the file is not recorded, and ``n_chunks`` is 0.
    """
    if st.session_state.chroma_client is None:
        st.session_state.chroma_client = chromadb.Client()
        # NOTE(review): in-memory chromadb clients in one process appear to
        # share a single backend. A fixed collection name would then be visible
        # to every browser session and would outlive "Clear all & reset".
        # A random per-session name keeps each session's index isolated.
        st.session_state.chroma_collection = st.session_state.chroma_client.get_or_create_collection(
            name=f"pdf_rag_{uuid.uuid4().hex}", metadata={"hnsw:space": "cosine"}
        )
    collection = st.session_state.chroma_collection

    pages = extract_text_from_pdf(pdf_bytes)
    chunks = chunk_text(pages)
    if not chunks:
        return 0, len(pages)

    texts = [c["text"] for c in chunks]
    embeddings = embed_model.encode(texts, batch_size=32, show_progress_bar=False).tolist()
    # The full filename digest (computed once) namespaces this file's chunk IDs;
    # a truncated prefix makes ID collisions between different files likelier.
    file_key = hashlib.md5(filename.encode()).hexdigest()
    collection.add(
        ids=[f"{file_key}_chunk_{i}" for i in range(len(chunks))],
        embeddings=embeddings,
        documents=texts,
        metadatas=[{"filename": filename, "page": c["page"]} for c in chunks],
    )

    st.session_state.indexed_files[filename] = {
        "chunks": len(chunks),
        "pages": len(pages),
        "size_kb": round(len(pdf_bytes) / 1024, 1),
    }
    st.session_state.total_chunks += len(chunks)
    return len(chunks), len(pages)
# --- RAG query ---
def rag_query(question: str, embed_model, top_k: int, api_key: str):
    """Retrieve the most similar chunks and ask the Groq chat API to answer from them.

    Args:
        question: User question; embedded with ``embed_model`` for retrieval.
        embed_model: Same model that embedded the indexed chunks.
        top_k: Requested number of chunks. It is capped at the number of
            chunks stored in the collection.
        api_key: Groq API key, sent as a Bearer token.

    Returns:
        ``(answer, chunks)``. ``answer`` is the model's reply text. ``chunks``
        is a list of ``{"text", "filename", "page", "relevance"}`` dicts in
        retrieval order. ``relevance`` is ``(1 - cosine distance) * 100``,
        rounded to one decimal, and can be negative.

    Raises:
        RuntimeError: If no collection exists or it holds no chunks.
        requests.HTTPError: If the Groq API returns an error status.
    """
    collection = st.session_state.chroma_collection
    if collection is None:
        raise RuntimeError("No documents have been indexed yet.")
    # Some chromadb versions raise when n_results exceeds the stored count.
    n_results = min(top_k, collection.count())
    if n_results < 1:
        raise RuntimeError("The document index is empty.")

    q_emb = embed_model.encode(question).tolist()
    results = collection.query(query_embeddings=[q_emb], n_results=n_results)
    chunks = [
        {
            "text": doc,
            "filename": meta["filename"],
            "page": meta["page"],
            "relevance": round((1 - dist) * 100, 1),
        }
        for doc, meta, dist in zip(
            results["documents"][0], results["metadatas"][0], results["distances"][0]
        )
    ]

    context = "\n\n".join(
        f"[Source: {c['filename']}, Page {c['page']}]\n{c['text']}" for c in chunks
    )
    prompt = f"""You are a helpful assistant. Answer the user's question using ONLY the document context provided below. Be concise and clear. Always mention the source filename and page number when referencing specific information. If the answer cannot be found in the provided context, say "I couldn't find that information in the uploaded documents."
Document Context:
{context}
Question: {question}
Answer:"""

    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": "llama-3.3-70b-versatile",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 600,
        "temperature": 0.2,
    }
    r = requests.post("https://api.groq.com/openai/v1/chat/completions", headers=headers, json=payload, timeout=30)
    r.raise_for_status()
    answer = r.json()["choices"][0]["message"]["content"]
    return answer, chunks
# --- Which pipeline phase to highlight: 1 = upload/index, 2 = ask ---
has_docs = bool(st.session_state.indexed_files)
phase = 2 if has_docs else 1
# --- Sidebar: API key, indexed-document list, reset button, stack info ---
with st.sidebar:
    st.markdown("## π PDF RAG Demo")
    st.markdown("<div style='color:#374151;font-size:0.8rem'>Upload β Extract β Index β Ask</div>", unsafe_allow_html=True)
    st.markdown("---")

    # Use GROQ_API_KEY from the environment if set; otherwise ask for it.
    env_key = os.environ.get("GROQ_API_KEY", "")
    if env_key:
        api_key = env_key
        st.success("β Groq key loaded from secrets")
    else:
        api_key = st.text_input("π Groq API Key", type="password", placeholder="gsk_...", help="Free at console.groq.com")
        if not api_key:
            st.caption("Get free key β [console.groq.com](https://console.groq.com)")

    st.markdown("---")
    st.markdown("<div class='section-label'>Indexed Documents</div>", unsafe_allow_html=True)
    if st.session_state.indexed_files:
        for fname, info in st.session_state.indexed_files.items():
            # File names come from the uploader (user-controlled), so escape
            # them before interpolating into raw HTML.
            st.markdown(
                f"""
<div style='padding:6px 0;border-bottom:1px solid #131c2e'>
<div style='font-size:0.8rem;color:#e2e8f0'>π {html.escape(fname)}</div>
<div style='font-size:0.72rem;color:#475569;font-family:IBM Plex Mono,monospace'>
{info["pages"]}p Β· {info["chunks"]} chunks Β· {info["size_kb"]}KB
</div>
</div>""",
                unsafe_allow_html=True,
            )
        st.markdown("---")
        if st.button("ποΈ Clear all & reset", use_container_width=True):
            # Drop the vector collection too. Forgetting the client object alone
            # would leave its data alive in the process.
            client = st.session_state.get("chroma_client")
            collection = st.session_state.get("chroma_collection")
            if client is not None and collection is not None:
                try:
                    client.delete_collection(collection.name)
                except Exception:
                    pass  # best-effort cleanup; the reset below must still happen
            for key in ["indexed_files", "chroma_collection", "chroma_client", "total_chunks"]:
                st.session_state.pop(key, None)
            st.rerun()
    else:
        st.markdown("<div style='color:#374151;font-size:0.82rem'>No documents indexed yet.</div>", unsafe_allow_html=True)

    st.markdown("---")
    st.markdown("""
<div style='font-size:0.77rem;color:#374151;line-height:1.9'>
<b style='color:#4b5563'>Stack</b><br>
π PDF parsing: PyMuPDF<br>
βοΈ Chunking: word-overlap (400w)<br>
π’ Embeddings: all-MiniLM-L6-v2<br>
ποΈ Vector DB: ChromaDB in-memory<br>
π§ LLM: Groq Β· Llama 3.3 70B<br>
π Hosting: HuggingFace Spaces
</div>
""", unsafe_allow_html=True)
# --- Hero banner: page title plus a one-line description of the pipeline ---
_HERO_TITLE = "π PDF RAG β Upload & Ask"
_HERO_SUBTITLE = (
    "Upload any PDF documents Β· They get extracted, chunked, embedded, and indexed Β· "
    "Then ask questions across all of them"
)
st.markdown(
    f"""
<div class='hero'>
<h1>{_HERO_TITLE}</h1>
<p>{_HERO_SUBTITLE}</p>
</div>
""",
    unsafe_allow_html=True,
)
# --- Phase bar ---
# Phase 1: the upload and ingestion steps are all "active" and the ask step is plain.
# Phase 2: those steps become "done" and the ask step is "active".
_ingest_state = "active" if phase == 1 else "done"
_phase_steps = [
    ("done" if phase > 1 else "active", "π€", "Upload PDFs"),
    (_ingest_state, "π", "Extract Text"),
    (_ingest_state, "βοΈ", "Chunk"),
    (_ingest_state, "π’", "Embed"),
    (_ingest_state, "ποΈ", "Index"),
    ("active" if phase == 2 else "", "π¬", "Ask Questions"),
]
_phase_cells = "".join(
    f"\n<div class='phase {state}'>\n<span class='phase-icon'>{icon}</span>{label}\n</div>"
    for state, icon, label in _phase_steps
)
st.markdown(f"\n<div class='phase-bar'>{_phase_cells}\n</div>\n", unsafe_allow_html=True)
# --- Obtain the embedding model, showing a spinner while it loads ---
_model_spinner = st.spinner("βοΈ Loading embedding model...")
with _model_spinner:
    embed_model = load_embed_model()
# ==========================================================
# PHASE 1 - Upload & Index
# ==========================================================
st.markdown("<div class='section-label'>Step 1 β Upload PDF Documents</div>", unsafe_allow_html=True)
uploaded_files = st.file_uploader(
    "Drop your PDF files here",
    type=["pdf"],
    accept_multiple_files=True,
    label_visibility="collapsed",
)
if uploaded_files:
    # Keep only files not yet indexed. Of several uploads sharing one name,
    # keep the first, so a name is never indexed (and counted) twice.
    seen_names = set(st.session_state.indexed_files)
    new_files = []
    for f in uploaded_files:
        if f.name not in seen_names:
            seen_names.add(f.name)
            new_files.append(f)

    if new_files:
        st.markdown(f"**{len(new_files)} new file(s) ready to index:**")
        for f in new_files:
            # The upload name is user-controlled; escape it for the raw-HTML card.
            st.markdown(
                f"<div class='pdf-card'><div><div class='pdf-name'>π {html.escape(f.name)}</div>"
                f"<div class='pdf-meta'>{round(f.size/1024,1)} KB</div></div>"
                "<div class='pdf-badge'>ready</div></div>",
                unsafe_allow_html=True,
            )
        if st.button(f"β‘ Extract & Index {len(new_files)} PDF(s)", type="primary", use_container_width=True):
            progress = st.progress(0, text="Starting...")
            for idx, f in enumerate(new_files):
                progress.progress(idx / len(new_files), text=f"Processing: {f.name}")
                # getvalue() returns the whole upload regardless of the buffer's
                # current read position (read() would return b"" once consumed).
                pdf_bytes = f.getvalue()
                with st.spinner(f"Extracting & indexing **{f.name}**..."):
                    n_chunks, n_pages = index_pdf(f.name, pdf_bytes, embed_model)
                st.success(f"β **{f.name}** β {n_pages} pages Β· {n_chunks} chunks indexed")
            progress.progress(1.0, text="Done!")
            st.balloons()
            st.rerun()
    else:
        st.info("All uploaded files are already indexed. Upload new files or ask questions below.")
elif not has_docs:
    st.markdown("""
<div class='empty-state'>
<div class='icon'>π</div>
<p><b style='color:#94a3b8'>No documents uploaded yet</b><br>
Upload one or more PDF files above to get started.<br>
Any topic works β reports, manuals, research papers, policies.</p>
</div>
""", unsafe_allow_html=True)
# ==========================================================
# PHASE 2 - Stats & Query
# ==========================================================
if has_docs:
    total_pages = sum(info["pages"] for info in st.session_state.indexed_files.values())
    st.markdown("<div class='section-label' style='margin-top:24px'>Index Summary</div>", unsafe_allow_html=True)
    st.markdown(f"""
<div class='stat-row'>
<div class='stat-box'><div class='stat-val'>{len(st.session_state.indexed_files)}</div><div class='stat-lbl'>Documents</div></div>
<div class='stat-box'><div class='stat-val'>{total_pages}</div><div class='stat-lbl'>Pages Parsed</div></div>
<div class='stat-box'><div class='stat-val'>{st.session_state.total_chunks}</div><div class='stat-lbl'>Chunks Indexed</div></div>
<div class='stat-box'><div class='stat-val'>384</div><div class='stat-lbl'>Embedding Dims</div></div>
</div>
""", unsafe_allow_html=True)

    if not api_key:
        st.warning("π Enter your Groq API key in the sidebar to start asking questions.")
        st.stop()

    st.markdown("---")
    st.markdown("<div class='section-label'>Step 2 β Ask a Question</div>", unsafe_allow_html=True)
    col1, col2 = st.columns([5, 1])
    with col1:
        # A non-empty label (hidden visually) avoids Streamlit's empty-label
        # accessibility warning.
        question = st.text_input("Question", placeholder="What does the document say about...?", label_visibility="collapsed")
    with col2:
        top_k = st.selectbox("Top K", [2, 3, 4, 5], index=1, help="Number of chunks to retrieve")
    ask_btn = st.button("π Search & Answer", type="primary", use_container_width=True)

    if ask_btn and question:
        answer, chunks = None, []
        with st.spinner("π Searching index and generating answer..."):
            try:
                answer, chunks = rag_query(question, embed_model, top_k, api_key)
            except requests.HTTPError as e:
                if e.response is not None and e.response.status_code == 401:
                    st.error("β Invalid Groq API key.")
                else:
                    st.error(f"β API error: {str(e)}")
            except Exception as e:  # UI boundary: show any other failure to the user
                st.error(f"β Error: {str(e)}")

        if answer is not None:
            # The LLM answer and the PDF text are untrusted. Escape them before
            # putting them into raw HTML, and keep the answer's line breaks.
            answer_html = html.escape(answer).replace("\n", "<br>")
            st.markdown("<div class='section-label'>Answer</div>", unsafe_allow_html=True)
            st.markdown(f"<div class='answer-box'>{answer_html}</div>", unsafe_allow_html=True)
            st.markdown("<div class='section-label'>Retrieved Chunks (context sent to LLM)</div>", unsafe_allow_html=True)
            for chunk in chunks:
                # rag_query's relevance is (1 - cosine distance) * 100 and can be
                # negative. Clamp so the bar width is a valid percentage.
                bar_width = max(0, min(100, int(chunk["relevance"])))
                st.markdown(f"""
<div class='chunk-card'>
<div class='chunk-top'>
<div>
<div class='chunk-source'>π {html.escape(chunk['filename'])}</div>
<div class='chunk-page'>Page {chunk['page']}</div>
</div>
<div class='score-bar-wrap'>
<div class='score-bar'><div class='score-fill' style='width:{bar_width}%'></div></div>
<div class='score-num'>{chunk['relevance']}%</div>
</div>
</div>
<div class='chunk-text'>{html.escape(chunk['text'])}</div>
</div>
""", unsafe_allow_html=True)
    elif ask_btn and not question:
        st.warning("Please enter a question.")