# Commit 6e1d29c (amitcoolll): Improve retrieval — LLM query rewriting,
# clarification, and chat UI.
import os
import streamlit as st
from src.config import (
UPLOAD_DIR,
CHUNK_TOKENS,
CHUNK_OVERLAP,
TOP_K
)
from src.parsers import read_pdf, read_docx
from src.chunking import chunk_text
from src.embeddings import embed_texts
from src.vectorstore import add_documents, reset_collection
from src.rag import answer_question, clarification_question
# ---------------- Page config ----------------
# Wide layout gives the chat transcript and sidebar room to breathe.
st.set_page_config(page_title="Document Chatbot (RAG)", layout="wide")

st.title("📄 Document Chatbot (RAG)")
st.caption(
    "Upload PDF/DOCX → Build Index → Chat using document knowledge with citations"
)
# ---------------- Sidebar ----------------
# Read-only view of the chunking/retrieval config plus two reset actions.
with st.sidebar:
    st.header("⚙️ Settings")
    st.write(f"Chunk size: {CHUNK_TOKENS} tokens")
    st.write(f"Overlap: {CHUNK_OVERLAP} tokens")
    st.write(f"Top-K retrieval: {TOP_K}")

    clear_index = st.button("🧹 Clear Index")
    if clear_index:
        # Drops every indexed chunk from the vector store.
        reset_collection()
        st.success("Index cleared.")

    clear_chat = st.button("🗑️ Clear Chat")
    if clear_chat:
        # Wipe the transcript and any pending clarification flow.
        st.session_state.messages = []
        st.session_state.pending_question = None
        st.session_state.clarification = None
        st.success("Chat cleared.")
# ---------------- Ensure folders ----------------
# Create working directories up front; exist_ok makes this idempotent
# across Streamlit reruns.
for _folder in (UPLOAD_DIR, "./data"):
    os.makedirs(_folder, exist_ok=True)
# ---------------- Upload ----------------
st.subheader("πŸ“€ Upload Documents")
uploaded_files = st.file_uploader(
"Upload PDF/DOCX files",
type=["pdf", "docx"],
accept_multiple_files=True
)
# ---------------- Build Index ----------------
if st.button("βœ… Build Index"):
if not uploaded_files:
st.warning("Please upload at least one document.")
else:
with st.spinner("Indexing documents..."):
documents, metadatas, ids = [], [], []
for uploaded_file in uploaded_files:
file_name = uploaded_file.name
file_bytes = uploaded_file.read()
if file_name.lower().endswith(".pdf"):
pages = read_pdf(file_bytes)
elif file_name.lower().endswith(".docx"):
pages = read_docx(file_bytes)
else:
continue
for page_no, text in pages:
for i, chunk in enumerate(chunk_text(text)):
documents.append(chunk)
metadatas.append({
"file": file_name,
"page": page_no
})
ids.append(f"{file_name}_p{page_no}_c{i}")
if not documents:
st.error("No text extracted. Scanned PDFs need OCR.")
else:
vectors = embed_texts(documents)
add_documents(
documents=documents,
embeddings=vectors,
metadatas=metadatas,
ids=ids
)
st.success(
f"Indexed {len(documents)} chunks "
f"from {len(uploaded_files)} file(s)."
)
st.divider()

# ===================== CHAT UI =====================
st.subheader("💬 Chat with your documents")

# Session state: transcript plus the two slots used by the clarification
# flow. Only seed keys that are missing so reruns keep existing values.
for _key, _default in (
    ("messages", []),
    ("pending_question", None),
    ("clarification", None),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Render the stored transcript before handling new input.
for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])

# Chat input (returns None until the user submits something).
user_input = st.chat_input("Ask something about the uploaded documents...")
if user_input:
    # Record the new user message in the transcript.
    st.session_state.messages.append({
        "role": "user",
        "content": user_input,
    })
    # Fix: echo the message immediately. History was already rendered above
    # (before st.chat_input fired), so without this the assistant's reply
    # would appear with no visible user message until the next rerun.
    with st.chat_message("user"):
        st.markdown(user_input)

    # Ask the LLM whether the question is ambiguous enough to need a
    # clarification round-trip before answering.
    clarify = clarification_question(user_input)
    if clarify:
        # Park the original question; the confirmation buttons below
        # decide whether to answer it or ask the user to rephrase.
        st.session_state.pending_question = user_input
        st.session_state.clarification = clarify
        st.session_state.messages.append({
            "role": "assistant",
            "content": clarify,
        })
        st.rerun()
    else:
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                answer, citations = answer_question(user_input)
            final = answer
            if citations:
                final += "\n\n**Sources:**\n" + "\n".join(f"- {c}" for c in citations)
            st.markdown(final)
        st.session_state.messages.append({
            "role": "assistant",
            "content": final,
        })
# ---------- Clarification buttons ----------
# Shown only while a question is parked awaiting user confirmation.
if st.session_state.pending_question:
    yes_col, no_col = st.columns(2)

    if yes_col.button("✅ Yes, that's what I mean"):
        # Consume the pending question and answer it now.
        question = st.session_state.pending_question
        st.session_state.pending_question = None
        st.session_state.clarification = None
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                answer, citations = answer_question(question)
            reply = answer
            if citations:
                sources = "\n".join(f"- {c}" for c in citations)
                reply += "\n\n**Sources:**\n" + sources
            st.markdown(reply)
        st.session_state.messages.append({
            "role": "assistant",
            "content": reply,
        })
        st.rerun()

    if no_col.button("❌ No, something else"):
        # Abandon the pending question and prompt for a rephrase.
        st.session_state.pending_question = None
        st.session_state.clarification = None
        st.session_state.messages.append({
            "role": "assistant",
            "content": "Okay — please type your question with a bit more detail.",
        })
        st.rerun()