# Policy RAG Assistant — Streamlit application entry point.
"""Streamlit front-end for the Policy RAG Assistant.

Indexes policy documents (uploaded in the browser or loaded from
data/policies/) and answers questions about them through a RAG pipeline.
"""
# Standard library first, then third-party, then local packages (PEP 8).
import os
import tempfile
from pathlib import Path

import streamlit as st

from src.chunking import chunk_documents
from src.evaluation import analyze_confidence_distribution
from src.loader import load_documents
from src.rag_pipeline import RAGPipeline
from src.utils import ensure_directories
from src.vectorstore import VectorStore

# Page config — must be the first Streamlit call in the script.
st.set_page_config(page_title="Policy RAG Assistant", layout="wide")

# Make sure the expected data directories exist before anything touches disk.
ensure_directories()

# The LLM backend requires a Groq API key; fail fast with a clear message.
if not os.getenv("GROQ_API_KEY"):
    st.error("GROQ_API_KEY not set. Please set it as an environment variable.")
    st.stop()

# Seed session state with defaults on the first run only, so values
# survive Streamlit's script re-runs.
_DEFAULT_STATE = {
    "vector_store": None,
    "rag_pipeline": None,
    "uploaded_files_count": 0,
}
for _key, _default in _DEFAULT_STATE.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Title
st.title("Policy RAG Assistant")
st.markdown("Ask questions about company policies")
| # Sidebar | |
| with st.sidebar: | |
| st.header("Setup") | |
| upload_method = st.radio( | |
| "Choose upload method:", | |
| ["Upload files here", "Load from data/policies/"], | |
| key="upload_method" | |
| ) | |
| if upload_method == "Upload files here": | |
| uploaded_files = st.file_uploader( | |
| "Upload policy documents", | |
| type=["pdf", "txt", "md"], | |
| accept_multiple_files=True, | |
| ) | |
| if uploaded_files and st.button("Process Uploaded Files"): | |
| with st.spinner("Processing uploaded files..."): | |
| from src.loader import load_pdf, load_text | |
| docs = [] | |
| for uploaded_file in uploaded_files: | |
| try: | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=Path(uploaded_file.name).suffix) as tmp_file: | |
| tmp_file.write(uploaded_file.getvalue()) | |
| tmp_path = Path(tmp_file.name) | |
| if tmp_path.suffix.lower() == ".pdf": | |
| text = load_pdf(tmp_path) | |
| elif tmp_path.suffix.lower() in [".txt", ".md"]: | |
| text = load_text(tmp_path) | |
| else: | |
| continue | |
| if text.strip(): | |
| docs.append({ | |
| "text": text, | |
| "metadata": { | |
| "source": uploaded_file.name, | |
| "type": tmp_path.suffix[1:] | |
| } | |
| }) | |
| tmp_path.unlink() | |
| except Exception as e: | |
| st.error(f"Error processing {uploaded_file.name}: {e}") | |
| if docs: | |
| chunked = chunk_documents(docs, chunk_size=500, overlap=100) | |
| vector_store = VectorStore() | |
| vector_store.reset() | |
| vector_store.add_documents(chunked) | |
| st.session_state.vector_store = vector_store | |
| st.session_state.rag_pipeline = RAGPipeline(vector_store) | |
| st.session_state.uploaded_files_count = len(docs) | |
| st.success(f"Processed {len(docs)} documents, {len(chunked)} chunks") | |
| else: | |
| st.warning("No valid documents were processed") | |
| else: | |
| if st.button("Load Documents from Folder"): | |
| with st.spinner("Loading documents..."): | |
| docs = load_documents() | |
| if docs: | |
| chunked = chunk_documents(docs, chunk_size=500, overlap=100) | |
| vector_store = VectorStore() | |
| vector_store.reset() | |
| vector_store.add_documents(chunked) | |
| st.session_state.vector_store = vector_store | |
| st.session_state.rag_pipeline = RAGPipeline(vector_store) | |
| st.session_state.uploaded_files_count = len(docs) | |
| st.success(f"Loaded {len(docs)} documents, {len(chunked)} chunks") | |
| else: | |
| st.warning("No documents found in data/policies/") | |
| if st.session_state.vector_store: | |
| st.divider() | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.metric("Documents", st.session_state.uploaded_files_count) | |
| with col2: | |
| st.metric("Total Chunks", st.session_state.vector_store.count()) | |
| st.divider() | |
| st.header("Analytics") | |
| if st.button("View Stats"): | |
| stats = analyze_confidence_distribution() | |
| st.json(stats) | |
def _render_result(result):
    """Render one query result: answer, confidence metric, and optional
    evaluation details. (Extracted — was duplicated verbatim for both
    columns of the compare view.)"""
    st.write(result["answer"])
    st.metric("Confidence", result.get("confidence", "N/A"))
    if result.get("evaluation"):
        st.json(result["evaluation"])


# Main area: question input and answer display.
if st.session_state.rag_pipeline is None:
    st.info("Upload documents or load from folder in the sidebar to get started")
else:
    col1, col2 = st.columns([3, 1])
    with col1:
        question = st.text_input("Ask a question:", placeholder="e.g., What is the vacation policy?")
    with col2:
        prompt_type = st.selectbox("Prompt:", ["improved", "initial", "compare"])
    if question:
        if prompt_type == "compare":
            # Side-by-side comparison of the two prompt variants.
            colA, colB = st.columns(2)
            with colA:
                st.subheader("Initial Prompt Result")
                result_initial = st.session_state.rag_pipeline.query(question, prompt_type="initial")
                _render_result(result_initial)
            with colB:
                st.subheader("Improved Prompt Result")
                result_improved = st.session_state.rag_pipeline.query(question, prompt_type="improved")
                _render_result(result_improved)
            # Show the improved run's retrieval context below the comparison.
            display_chunks = result_improved["retrieved_chunks"]
        else:
            with st.spinner("Searching..."):
                response = st.session_state.rag_pipeline.query(question, prompt_type=prompt_type)
            st.markdown("### Answer")
            st.write(response["answer"])
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Confidence", response.get("confidence", "N/A"))
            with col2:
                st.metric("Sources Used", len(response["retrieved_chunks"]))
            if response.get("evaluation"):
                st.subheader("Evaluation")
                st.json(response["evaluation"])
            if response.get("evidence"):
                with st.expander("Evidence"):
                    for i, ev in enumerate(response["evidence"], 1):
                        st.markdown(f"{i}. {ev}")
            display_chunks = response["retrieved_chunks"]

        # Retrieval context common to both modes.
        with st.expander("Retrieved Chunks"):
            for i, chunk in enumerate(display_chunks, 1):
                st.markdown(f"Chunk {i} (score: {chunk.get('score', 0):.4f})")
                st.markdown(f"Source: {chunk.get('metadata', {}).get('source', 'Unknown')}")
                preview = chunk["text"]
                # Truncate long chunks to a 300-char preview.
                st.text(preview[:300] + "..." if len(preview) > 300 else preview)
                st.divider()