# Harshdhsvguyt's picture
# Upload 19 files
# 754d8d3 verified
import streamlit as st
from src.loader import load_documents
from src.chunking import chunk_documents
from src.vectorstore import VectorStore
from src.rag_pipeline import RAGPipeline
from src.utils import ensure_directories
from src.evaluation import analyze_confidence_distribution
import os
import tempfile
from pathlib import Path
# Page configuration: kept ahead of every other st.* call in this script.
st.set_page_config(page_title="Policy RAG Assistant", layout="wide")

# Project-level initialisation (presumably creates the folders the app
# expects -- confirm in src.utils).
ensure_directories()

# The app cannot do anything useful without a Groq API key, so show an error
# and halt this script run when the variable is missing or empty.
_groq_api_key = os.environ.get("GROQ_API_KEY")
if not _groq_api_key:
    st.error("GROQ_API_KEY not set. Please set it as an environment variable.")
    st.stop()
# Seed the per-session slots the rest of the script reads. Streamlit re-runs
# this script on every interaction, so a slot is only written when it is
# missing; values stored by an earlier run are left untouched.
_SESSION_DEFAULTS = {
    "vector_store": None,
    "rag_pipeline": None,
    "uploaded_files_count": 0,
}
for _slot, _initial in _SESSION_DEFAULTS.items():
    if _slot not in st.session_state:
        st.session_state[_slot] = _initial

# Page heading
st.title("Policy RAG Assistant")
st.markdown("Ask questions about company policies")
# Sidebar: document ingestion, index metrics and analytics.
# NOTE(review): nesting was reconstructed from a copy of this file that had
# lost its indentation; the Analytics section is assumed to be unconditional
# (not gated on a loaded vector store) -- confirm against the original.


def _extract_uploaded_text(uploaded_file, loader):
    """Run *loader* on a temporary on-disk copy of *uploaded_file*.

    The loaders take a filesystem path, so the upload's bytes are written to
    a named temporary file first. The file is closed before the loader is
    called and is always removed afterwards, including when the loader raises.

    Args:
        uploaded_file: object exposing ``.name`` and ``.getvalue()`` (as
            returned by ``st.file_uploader``).
        loader: callable taking a ``Path`` and returning the extracted text.

    Returns:
        Whatever *loader* returns for the temporary copy.

    Raises:
        Any exception raised while writing the copy or by *loader*.
    """
    tmp_file = tempfile.NamedTemporaryFile(
        delete=False, suffix=Path(uploaded_file.name).suffix
    )
    tmp_path = Path(tmp_file.name)
    try:
        with tmp_file:
            tmp_file.write(uploaded_file.getvalue())
        return loader(tmp_path)
    finally:
        # delete=False means nobody else will clean this up for us.
        tmp_path.unlink()


def _index_documents(docs):
    """Chunk *docs*, rebuild the vector store and publish it in session state.

    Args:
        docs: non-empty list of document records (the upload path builds
            ``{"text": ..., "metadata": ...}`` dicts).

    Returns:
        The number of chunks added to the vector store.
    """
    chunked = chunk_documents(docs, chunk_size=500, overlap=100)
    vector_store = VectorStore()
    # Start from a clean store so earlier uploads do not mix with this batch.
    vector_store.reset()
    vector_store.add_documents(chunked)
    st.session_state.vector_store = vector_store
    st.session_state.rag_pipeline = RAGPipeline(vector_store)
    st.session_state.uploaded_files_count = len(docs)
    return len(chunked)


with st.sidebar:
    st.header("Setup")
    upload_method = st.radio(
        "Choose upload method:",
        ["Upload files here", "Load from data/policies/"],
        key="upload_method",
    )

    if upload_method == "Upload files here":
        uploaded_files = st.file_uploader(
            "Upload policy documents",
            type=["pdf", "txt", "md"],
            accept_multiple_files=True,
        )
        if uploaded_files and st.button("Process Uploaded Files"):
            with st.spinner("Processing uploaded files..."):
                from src.loader import load_pdf, load_text

                # Lower-cased suffix -> loader; anything else is skipped.
                loaders = {".pdf": load_pdf, ".txt": load_text, ".md": load_text}
                docs = []
                for uploaded_file in uploaded_files:
                    suffix = Path(uploaded_file.name).suffix
                    loader = loaders.get(suffix.lower())
                    if loader is None:
                        continue
                    # UI boundary: report the failing file and keep going
                    # with the remaining uploads.
                    try:
                        text = _extract_uploaded_text(uploaded_file, loader)
                        if text.strip():
                            docs.append({
                                "text": text,
                                "metadata": {
                                    "source": uploaded_file.name,
                                    "type": suffix[1:],
                                },
                            })
                    except Exception as e:
                        st.error(f"Error processing {uploaded_file.name}: {e}")

                if docs:
                    chunk_count = _index_documents(docs)
                    st.success(f"Processed {len(docs)} documents, {chunk_count} chunks")
                else:
                    st.warning("No valid documents were processed")
    elif st.button("Load Documents from Folder"):
        with st.spinner("Loading documents..."):
            docs = load_documents()
            if docs:
                chunk_count = _index_documents(docs)
                st.success(f"Loaded {len(docs)} documents, {chunk_count} chunks")
            else:
                st.warning("No documents found in data/policies/")

    # Index summary, shown once a store has been built in this session.
    if st.session_state.vector_store is not None:
        st.divider()
        col1, col2 = st.columns(2)
        with col1:
            st.metric("Documents", st.session_state.uploaded_files_count)
        with col2:
            st.metric("Total Chunks", st.session_state.vector_store.count())

    st.divider()
    st.header("Analytics")
    if st.button("View Stats"):
        stats = analyze_confidence_distribution()
        st.json(stats)
# Main area: question input and answer rendering.

# Maximum number of characters of a retrieved chunk shown in the preview.
_CHUNK_PREVIEW_CHARS = 300


def _show_prompt_variant(pipeline, question, title, prompt_type):
    """Query *pipeline* with one prompt variant and render it in place.

    Renders a subheader, then (under a spinner) runs the query and shows the
    answer, the confidence metric and, when present, the evaluation payload.

    Args:
        pipeline: object exposing ``query(question, prompt_type=...)``.
        question: the user's question text.
        title: subheader text for this variant.
        prompt_type: prompt variant name forwarded to ``pipeline.query``.

    Returns:
        The result mapping returned by ``pipeline.query``.
    """
    st.subheader(title)
    with st.spinner("Searching..."):
        result = pipeline.query(question, prompt_type=prompt_type)
    st.write(result["answer"])
    st.metric("Confidence", result.get("confidence", "N/A"))
    if result.get("evaluation"):
        st.json(result["evaluation"])
    return result


if st.session_state.rag_pipeline is None:
    st.info("Upload documents or load from folder in the sidebar to get started")
else:
    col1, col2 = st.columns([3, 1])
    with col1:
        question = st.text_input("Ask a question:", placeholder="e.g., What is the vacation policy?")
    with col2:
        prompt_type = st.selectbox("Prompt:", ["improved", "initial", "compare"])

    # Ignore blank / whitespace-only input so it never triggers a query.
    question = (question or "").strip()
    if question:
        pipeline = st.session_state.rag_pipeline
        if prompt_type == "compare":
            # Side-by-side comparison of the two prompt variants.
            colA, colB = st.columns(2)
            with colA:
                _show_prompt_variant(pipeline, question, "Initial Prompt Result", "initial")
            with colB:
                result_improved = _show_prompt_variant(
                    pipeline, question, "Improved Prompt Result", "improved"
                )
            # The chunk list below reflects the improved-prompt run.
            display_chunks = result_improved["retrieved_chunks"]
        else:
            with st.spinner("Searching..."):
                response = pipeline.query(question, prompt_type=prompt_type)
            st.markdown("### Answer")
            st.write(response["answer"])
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Confidence", response.get("confidence", "N/A"))
            with col2:
                st.metric("Sources Used", len(response["retrieved_chunks"]))
            if response.get("evaluation"):
                st.subheader("Evaluation")
                st.json(response["evaluation"])
            if response.get("evidence"):
                with st.expander("Evidence"):
                    for i, ev in enumerate(response["evidence"], 1):
                        st.markdown(f"{i}. {ev}")
            display_chunks = response["retrieved_chunks"]

        with st.expander("Retrieved Chunks"):
            for i, chunk in enumerate(display_chunks, 1):
                st.markdown(f"Chunk {i} (score: {chunk.get('score', 0):.4f})")
                st.markdown(f"Source: {chunk.get('metadata', {}).get('source', 'Unknown')}")
                chunk_text = chunk["text"]
                # Truncate long chunks and mark the cut with an ellipsis.
                if len(chunk_text) > _CHUNK_PREVIEW_CHARS:
                    preview = chunk_text[:_CHUNK_PREVIEW_CHARS] + "..."
                else:
                    preview = chunk_text
                st.text(preview)
                st.divider()