# Real_time_RAG / src/streamlit_app.py
# Streamlit PDF RAG demo (Hugging Face Space by anir-1995, rev 622a72f).
import streamlit as st
import chromadb
from sentence_transformers import SentenceTransformer
import fitz # PyMuPDF
import os
import requests
import re
import hashlib
# ─── Page Config ──────────────────────────────────────────────────────────────
# Must be the first Streamlit call in the script.
st.set_page_config(
    page_title="PDF RAG Β· Upload & Ask",
    page_icon="πŸ“‚",
    layout="wide",
    initial_sidebar_state="expanded",
)
# ─── CSS ──────────────────────────────────────────────────────────────────────
# Inject the dark-theme stylesheet used by every custom-HTML component below
# (hero banner, phase bar, pdf/chunk cards, stat boxes, sidebar, empty state).
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:wght@300;400;500;600&family=IBM+Plex+Mono:wght@400;500&display=swap');
html, body, [class*="css"] { font-family: 'IBM Plex Sans', sans-serif; }
.main { background-color: #0b0f1a; }
.hero {
background: linear-gradient(160deg, #0d1424 0%, #0b0f1a 100%);
border: 1px solid #1e2a3e;
border-top: 3px solid #22d3ee;
border-radius: 12px;
padding: 28px 32px;
margin-bottom: 24px;
}
.hero h1 { font-size: 1.8rem; font-weight: 600; color: #e2e8f0; margin: 0 0 6px 0; }
.hero p { color: #64748b; font-size: 0.95rem; margin: 0; }
.phase-bar {
display: flex; gap: 0; margin-bottom: 28px;
border: 1px solid #1e2a3e; border-radius: 10px; overflow: hidden;
}
.phase {
flex: 1; padding: 10px 6px; text-align: center;
font-size: 0.75rem; color: #4b5563; background: #0d1117;
border-right: 1px solid #1e2a3e; line-height: 1.5;
}
.phase:last-child { border-right: none; }
.phase.done { color: #22d3ee; background: rgba(34,211,238,0.05); }
.phase.active { color: #f8fafc; background: rgba(34,211,238,0.1); font-weight: 600; }
.phase-icon { font-size: 1.1rem; display: block; margin-bottom: 2px; }
.pdf-card {
background: #0d1424;
border: 1px solid #1e2a3e;
border-radius: 10px;
padding: 14px 16px;
margin: 8px 0;
display: flex;
align-items: center;
justify-content: space-between;
}
.pdf-name { font-size: 0.85rem; color: #e2e8f0; font-weight: 500; }
.pdf-meta { font-family: 'IBM Plex Mono', monospace; font-size: 0.72rem; color: #475569; margin-top: 3px; }
.pdf-badge {
font-size: 0.72rem; font-family: 'IBM Plex Mono', monospace;
background: rgba(34,211,238,0.1); color: #22d3ee;
border: 1px solid rgba(34,211,238,0.25); padding: 3px 10px; border-radius: 20px;
}
.answer-box {
background: #0d1424;
border: 1px solid #1e3a4a;
border-left: 3px solid #22d3ee;
border-radius: 10px;
padding: 22px 24px;
color: #e2e8f0;
line-height: 1.75;
font-size: 0.96rem;
margin: 12px 0 20px 0;
}
.chunk-card {
background: #0d1117;
border: 1px solid #1e2a3e;
border-radius: 9px;
padding: 14px 18px;
margin: 8px 0;
}
.chunk-top {
display: flex; justify-content: space-between;
align-items: center; margin-bottom: 8px;
}
.chunk-source { font-size: 0.77rem; font-weight: 600; color: #22d3ee; text-transform: uppercase; letter-spacing: 0.05em; }
.chunk-page { font-family: 'IBM Plex Mono', monospace; font-size: 0.72rem; color: #475569; }
.score-bar-wrap { display: flex; align-items: center; gap: 8px; }
.score-bar {
height: 4px; border-radius: 2px; background: #1e2a3e; width: 80px; overflow: hidden;
}
.score-fill { height: 100%; border-radius: 2px; background: #22d3ee; }
.score-num { font-family: 'IBM Plex Mono', monospace; font-size: 0.72rem; color: #22d3ee; }
.chunk-text { font-size: 0.86rem; color: #94a3b8; line-height: 1.65; }
.stat-row { display: flex; gap: 10px; margin: 16px 0; }
.stat-box {
flex: 1; background: #0d1424; border: 1px solid #1e2a3e;
border-radius: 8px; padding: 12px; text-align: center;
}
.stat-val { font-size: 1.35rem; font-weight: 600; color: #22d3ee; }
.stat-lbl { font-size: 0.7rem; color: #475569; margin-top: 2px; }
.section-label {
font-size: 0.7rem; text-transform: uppercase; letter-spacing: 0.1em;
color: #374151; font-weight: 600; margin: 18px 0 8px 0;
}
section[data-testid="stSidebar"] {
background-color: #080c14; border-right: 1px solid #131c2e;
}
.empty-state {
text-align: center; padding: 48px 24px;
border: 2px dashed #1e2a3e; border-radius: 12px; color: #374151;
}
.empty-state .icon { font-size: 2.5rem; margin-bottom: 12px; }
.empty-state p { font-size: 0.9rem; line-height: 1.6; }
</style>
""", unsafe_allow_html=True)
# ─── Session State ────────────────────────────────────────────────────────────
# Persist the index and bookkeeping across Streamlit reruns.
if "indexed_files" not in st.session_state:
    st.session_state.indexed_files = {}  # filename β†’ {chunks, pages, size}
if "chroma_collection" not in st.session_state:
    st.session_state.chroma_collection = None
if "chroma_client" not in st.session_state:
    st.session_state.chroma_client = None
if "total_chunks" not in st.session_state:
    st.session_state.total_chunks = 0
# ─── Load embedding model (cached globally) ───────────────────────────────────
@st.cache_resource(show_spinner=False)
def load_embed_model():
    """Load and cache the sentence-embedding model (all-MiniLM-L6-v2, 384-dim)."""
    return SentenceTransformer('all-MiniLM-L6-v2')
# ─── PDF Extraction ───────────────────────────────────────────────────────────
def extract_text_from_pdf(pdf_bytes: bytes) -> list[dict]:
    """Extract plain text from every page of a PDF.

    Args:
        pdf_bytes: Raw PDF file contents.

    Returns:
        List of ``{"page": <1-based page number>, "text": <stripped text>}``
        dicts, skipping pages that contain no extractable text.
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        pages = []
        for page_num, page in enumerate(doc, start=1):
            text = page.get_text("text").strip()
            if text:
                pages.append({"page": page_num, "text": text})
    finally:
        # Always release the document handle, even if a page fails to parse.
        doc.close()
    return pages
# ─── Chunking ─────────────────────────────────────────────────────────────────
def chunk_text(pages: list[dict], chunk_size: int = 400, overlap: int = 60) -> list[dict]:
    """Split page text into overlapping word-based chunks.

    Args:
        pages: Output of ``extract_text_from_pdf`` — [{"page": int, "text": str}].
        chunk_size: Maximum number of words per chunk.
        overlap: Words shared between consecutive chunks (must be < chunk_size).

    Returns:
        List of ``{"text": str, "page": int}`` dicts. Chunks of 60 characters
        or fewer are dropped to filter headers, page numbers, and other noise.
    """
    step = chunk_size - overlap  # stride between chunk starts
    chunks = []
    for p in pages:
        words = p["text"].split()
        for start in range(0, len(words), step):
            chunk_str = " ".join(words[start:start + chunk_size]).strip()
            if len(chunk_str) > 60:
                chunks.append({"text": chunk_str, "page": p["page"]})
    return chunks
# ─── Index PDF into ChromaDB ──────────────────────────────────────────────────
def index_pdf(filename: str, pdf_bytes: bytes, embed_model):
    """Extract, chunk, embed, and index a single PDF into the session ChromaDB.

    Args:
        filename: Display name of the uploaded file; also used to derive chunk ids.
        pdf_bytes: Raw PDF file contents.
        embed_model: SentenceTransformer instance used to embed the chunks.

    Returns:
        Tuple ``(n_chunks, n_pages)`` indexed, or ``(0, 0)`` when the PDF
        yielded no usable text.
    """
    # Lazily create the in-memory Chroma client/collection on first use.
    if st.session_state.chroma_client is None:
        st.session_state.chroma_client = chromadb.Client()
        st.session_state.chroma_collection = st.session_state.chroma_client.get_or_create_collection(
            name="pdf_rag", metadata={"hnsw:space": "cosine"}
        )
    collection = st.session_state.chroma_collection
    # Extract & chunk
    pages = extract_text_from_pdf(pdf_bytes)
    chunks = chunk_text(pages)
    if not chunks:
        return 0, 0
    # Embed all chunks in one batched call.
    texts = [c["text"] for c in chunks]
    embeddings = embed_model.encode(texts, batch_size=32, show_progress_bar=False).tolist()
    # Ids carry a filename-derived prefix so chunks from different files never
    # collide (md5 is used only as a short stable hash, not for security).
    prefix = hashlib.md5(filename.encode()).hexdigest()[:8]
    ids = [f"{prefix}_chunk_{i}" for i in range(len(chunks))]
    metas = [{"filename": filename, "page": c["page"]} for c in chunks]
    collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metas)
    # Record per-file stats shown in the sidebar and the summary row.
    st.session_state.indexed_files[filename] = {
        "chunks": len(chunks),
        "pages": len(pages),
        "size_kb": round(len(pdf_bytes) / 1024, 1),
    }
    st.session_state.total_chunks += len(chunks)
    return len(chunks), len(pages)
# ─── RAG Query ────────────────────────────────────────────────────────────────
def rag_query(question: str, embed_model, top_k: int, api_key: str):
    """Retrieve the top-k chunks for a question and generate a grounded answer.

    Args:
        question: Natural-language user question.
        embed_model: SentenceTransformer instance used to embed the question.
        top_k: Number of chunks to retrieve from the index.
        api_key: Groq API key for the chat-completions request.

    Returns:
        Tuple ``(answer, chunks)`` where ``answer`` is the LLM reply and
        ``chunks`` is a list of dicts with text, filename, page, and a
        relevance percentage.

    Raises:
        requests.HTTPError: On a non-2xx response from the Groq API.
    """
    collection = st.session_state.chroma_collection
    q_emb = embed_model.encode(question).tolist()
    results = collection.query(query_embeddings=[q_emb], n_results=top_k)
    # Flatten Chroma's nested per-query result lists into one dict per chunk.
    chunks = []
    for doc, meta, distance in zip(
        results["documents"][0], results["metadatas"][0], results["distances"][0]
    ):
        chunks.append({
            "text": doc,
            "filename": meta["filename"],
            "page": meta["page"],
            # Cosine distance β†’ percentage similarity. NOTE(review): can be
            # negative when distance > 1 — confirm whether clamping is wanted.
            "relevance": round((1 - distance) * 100, 1),
        })
    context = "\n\n".join(
        f"[Source: {c['filename']}, Page {c['page']}]\n{c['text']}" for c in chunks
    )
    prompt = f"""You are a helpful assistant. Answer the user's question using ONLY the document context provided below. Be concise and clear. Always mention the source filename and page number when referencing specific information. If the answer cannot be found in the provided context, say "I couldn't find that information in the uploaded documents."
Document Context:
{context}
Question: {question}
Answer:"""
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": "llama-3.3-70b-versatile",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 600,
        "temperature": 0.2,
    }
    r = requests.post(
        "https://api.groq.com/openai/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=30,
    )
    r.raise_for_status()
    answer = r.json()["choices"][0]["message"]["content"]
    return answer, chunks
# ─── Determine current phase ──────────────────────────────────────────────────
# Phase 1 until at least one document has been indexed, then phase 2.
has_docs = bool(st.session_state.indexed_files)
phase = 2 if has_docs else 1
# ─── Sidebar ──────────────────────────────────────────────────────────────────
# Holds the API key input, the list of indexed documents, and a stack summary.
with st.sidebar:
    st.markdown("## πŸ“‚ PDF RAG Demo")
    st.markdown("<div style='color:#374151;font-size:0.8rem'>Upload β†’ Extract β†’ Index β†’ Ask</div>", unsafe_allow_html=True)
    st.markdown("---")
    # Prefer a key from the environment (e.g. an HF Spaces secret); otherwise
    # ask the user to paste one.
    env_key = os.environ.get("GROQ_API_KEY", "")
    if env_key:
        api_key = env_key
        st.success("βœ… Groq key loaded from secrets")
    else:
        api_key = st.text_input("πŸ”‘ Groq API Key", type="password", placeholder="gsk_...", help="Free at console.groq.com")
        if not api_key:
            st.caption("Get free key β†’ [console.groq.com](https://console.groq.com)")
    st.markdown("---")
    st.markdown("<div class='section-label'>Indexed Documents</div>", unsafe_allow_html=True)
    if st.session_state.indexed_files:
        for fname, info in st.session_state.indexed_files.items():
            st.markdown(f"""
<div style='padding:6px 0;border-bottom:1px solid #131c2e'>
<div style='font-size:0.8rem;color:#e2e8f0'>πŸ“„ {fname}</div>
<div style='font-size:0.72rem;color:#475569;font-family:IBM Plex Mono,monospace'>
{info["pages"]}p Β· {info["chunks"]} chunks Β· {info["size_kb"]}KB
</div>
</div>""", unsafe_allow_html=True)
        st.markdown("---")
        if st.button("πŸ—‘οΈ Clear all & reset", use_container_width=True):
            # Drop all index state so the next rerun starts clean.
            for key in ["indexed_files", "chroma_collection", "chroma_client", "total_chunks"]:
                del st.session_state[key]
            st.rerun()
    else:
        st.markdown("<div style='color:#374151;font-size:0.82rem'>No documents indexed yet.</div>", unsafe_allow_html=True)
    st.markdown("---")
    st.markdown("""
<div style='font-size:0.77rem;color:#374151;line-height:1.9'>
<b style='color:#4b5563'>Stack</b><br>
πŸ“„ PDF parsing: PyMuPDF<br>
βœ‚οΈ Chunking: word-overlap (400w)<br>
πŸ”’ Embeddings: all-MiniLM-L6-v2<br>
πŸ—„οΈ Vector DB: ChromaDB in-memory<br>
🧠 LLM: Groq · Llama 3.3 70B<br>
🌐 Hosting: HuggingFace Spaces
</div>
""", unsafe_allow_html=True)
# ─── Hero ─────────────────────────────────────────────────────────────────────
st.markdown("""
<div class='hero'>
<h1>πŸ“‚ PDF RAG β€” Upload & Ask</h1>
<p>Upload any PDF documents Β· They get extracted, chunked, embedded, and indexed Β· Then ask questions across all of them</p>
</div>
""", unsafe_allow_html=True)
# Phase bar. Bug fix: the four middle stages previously used
# `"active" if phase == 1 else "done"`, which highlighted Extract/Chunk/Embed/
# Index as "active" before anything was uploaded. They should be inactive in
# phase 1 and "done" once documents are indexed.
st.markdown(f"""
<div class='phase-bar'>
<div class='phase {"done" if phase > 1 else "active"}'>
<span class='phase-icon'>πŸ“€</span>Upload PDFs
</div>
<div class='phase {"done" if phase > 1 else ""}'>
<span class='phase-icon'>πŸ“‘</span>Extract Text
</div>
<div class='phase {"done" if phase > 1 else ""}'>
<span class='phase-icon'>βœ‚οΈ</span>Chunk
</div>
<div class='phase {"done" if phase > 1 else ""}'>
<span class='phase-icon'>πŸ”’</span>Embed
</div>
<div class='phase {"done" if phase > 1 else ""}'>
<span class='phase-icon'>πŸ—„οΈ</span>Index
</div>
<div class='phase {"active" if phase == 2 else ""}'>
<span class='phase-icon'>πŸ’¬</span>Ask Questions
</div>
</div>
""", unsafe_allow_html=True)
# ─── Load model ───────────────────────────────────────────────────────────────
# Cached by @st.cache_resource, so the spinner is only slow on the first run.
with st.spinner("βš™οΈ Loading embedding model..."):
    embed_model = load_embed_model()
# ════════════════════════════════════════════════════════════
# PHASE 1 β€” Upload & Index
# ════════════════════════════════════════════════════════════
st.markdown("<div class='section-label'>Step 1 β€” Upload PDF Documents</div>", unsafe_allow_html=True)
uploaded_files = st.file_uploader(
    "Drop your PDF files here",
    type=["pdf"],
    accept_multiple_files=True,
    label_visibility="collapsed",
)
if uploaded_files:
    # Only offer files that have not been indexed in this session yet.
    new_files = [f for f in uploaded_files if f.name not in st.session_state.indexed_files]
    if new_files:
        st.markdown(f"**{len(new_files)} new file(s) ready to index:**")
        for f in new_files:
            st.markdown(f"<div class='pdf-card'><div><div class='pdf-name'>πŸ“„ {f.name}</div><div class='pdf-meta'>{round(f.size/1024,1)} KB</div></div><div class='pdf-badge'>ready</div></div>", unsafe_allow_html=True)
        if st.button(f"⚑ Extract & Index {len(new_files)} PDF(s)", type="primary", use_container_width=True):
            progress = st.progress(0, text="Starting...")
            for idx, f in enumerate(new_files):
                progress.progress(idx / len(new_files), text=f"Processing: {f.name}")
                pdf_bytes = f.read()
                with st.spinner(f"Extracting & indexing **{f.name}**..."):
                    n_chunks, n_pages = index_pdf(f.name, pdf_bytes, embed_model)
                st.success(f"βœ… **{f.name}** β†’ {n_pages} pages Β· {n_chunks} chunks indexed")
            progress.progress(1.0, text="Done!")
            st.balloons()
            # Rerun so the sidebar list and phase bar reflect the new index.
            st.rerun()
    else:
        st.info("All uploaded files are already indexed. Upload new files or ask questions below.")
elif not has_docs:
    st.markdown("""
<div class='empty-state'>
<div class='icon'>πŸ“‚</div>
<p><b style='color:#94a3b8'>No documents uploaded yet</b><br>
Upload one or more PDF files above to get started.<br>
Any topic works β€” reports, manuals, research papers, policies.</p>
</div>
""", unsafe_allow_html=True)
# ════════════════════════════════════════════════════════════
# PHASE 2 β€” Stats & Query
# ════════════════════════════════════════════════════════════
if has_docs:
    total_pages = sum(v["pages"] for v in st.session_state.indexed_files.values())
    st.markdown("<div class='section-label' style='margin-top:24px'>Index Summary</div>", unsafe_allow_html=True)
    st.markdown(f"""
<div class='stat-row'>
<div class='stat-box'><div class='stat-val'>{len(st.session_state.indexed_files)}</div><div class='stat-lbl'>Documents</div></div>
<div class='stat-box'><div class='stat-val'>{total_pages}</div><div class='stat-lbl'>Pages Parsed</div></div>
<div class='stat-box'><div class='stat-val'>{st.session_state.total_chunks}</div><div class='stat-lbl'>Chunks Indexed</div></div>
<div class='stat-box'><div class='stat-val'>384</div><div class='stat-lbl'>Embedding Dims</div></div>
</div>
""", unsafe_allow_html=True)
    # No key β†’ no LLM calls; stop rendering here.
    if not api_key:
        st.warning("πŸ‘ˆ Enter your Groq API key in the sidebar to start asking questions.")
        st.stop()
    st.markdown("---")
    st.markdown("<div class='section-label'>Step 2 β€” Ask a Question</div>", unsafe_allow_html=True)
    col1, col2 = st.columns([5, 1])
    with col1:
        question = st.text_input("", placeholder="What does the document say about...?", label_visibility="collapsed")
    with col2:
        top_k = st.selectbox("Top K", [2, 3, 4, 5], index=1, help="Number of chunks to retrieve")
    ask_btn = st.button("πŸ” Search & Answer", type="primary", use_container_width=True)
    if ask_btn and question:
        with st.spinner("πŸ” Searching index and generating answer..."):
            try:
                answer, chunks = rag_query(question, embed_model, top_k, api_key)
                st.markdown("<div class='section-label'>Answer</div>", unsafe_allow_html=True)
                st.markdown(f"<div class='answer-box'>{answer}</div>", unsafe_allow_html=True)
                st.markdown("<div class='section-label'>Retrieved Chunks (context sent to LLM)</div>", unsafe_allow_html=True)
                for chunk in chunks:
                    # Clamp to 0-100 so a negative relevance (cosine distance
                    # > 1) cannot produce an invalid CSS width.
                    bar_width = max(0, min(100, int(chunk['relevance'])))
                    st.markdown(f"""
<div class='chunk-card'>
<div class='chunk-top'>
<div>
<div class='chunk-source'>πŸ“„ {chunk['filename']}</div>
<div class='chunk-page'>Page {chunk['page']}</div>
</div>
<div class='score-bar-wrap'>
<div class='score-bar'><div class='score-fill' style='width:{bar_width}%'></div></div>
<div class='score-num'>{chunk['relevance']}%</div>
</div>
</div>
<div class='chunk-text'>{chunk['text']}</div>
</div>
""", unsafe_allow_html=True)
            except requests.HTTPError as e:
                # 401 means the key itself is bad; everything else is surfaced as-is.
                if e.response.status_code == 401:
                    st.error("❌ Invalid Groq API key.")
                else:
                    st.error(f"❌ API error: {str(e)}")
            except Exception as e:
                st.error(f"❌ Error: {str(e)}")
    elif ask_btn and not question:
        st.warning("Please enter a question.")