"""Streamlit chat interface for RAG application.""" import streamlit as st import sys import os from datetime import datetime import json import pandas as pd from typing import Optional import warnings # Suppress warnings warnings.filterwarnings('ignore') os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Add parent directory to path sys.path.append(os.path.dirname(os.path.abspath(__file__))) from config import settings from dataset_loader import RAGBenchLoader from vector_store import ChromaDBManager from llm_client import GroqLLMClient, RAGPipeline from trace_evaluator import TRACEEvaluator from embedding_models import EmbeddingFactory from chunking_strategies import ChunkingFactory # Page configuration st.set_page_config( page_title="RAG Capstone Project", page_icon="🤖", layout="wide" ) # Initialize session state if "chat_history" not in st.session_state: st.session_state.chat_history = [] if "rag_pipeline" not in st.session_state: st.session_state.rag_pipeline = None if "vector_store" not in st.session_state: st.session_state.vector_store = None if "collection_loaded" not in st.session_state: st.session_state.collection_loaded = False if "evaluation_results" not in st.session_state: st.session_state.evaluation_results = None if "dataset_size" not in st.session_state: st.session_state.dataset_size = 10000 if "current_dataset" not in st.session_state: st.session_state.current_dataset = None if "current_llm" not in st.session_state: st.session_state.current_llm = settings.llm_models[1] if "selected_collection" not in st.session_state: st.session_state.selected_collection = None if "available_collections" not in st.session_state: st.session_state.available_collections = [] def get_available_collections(): """Get list of available collections from ChromaDB.""" try: vector_store = ChromaDBManager(settings.chroma_persist_directory) collections = vector_store.list_collections() return collections except Exception as e: print(f"Error getting collections: {e}") return [] def main(): """Main Streamlit application.""" st.title("🤖 RAG Capstone Project") st.markdown("### Retrieval-Augmented Generation with TRACE Evaluation") # Get available collections at startup available_collections = get_available_collections() st.session_state.available_collections = available_collections # Sidebar for configuration with st.sidebar: st.header("Configuration") # API Key input groq_api_key = st.text_input( "Groq API Key", type="password", value=settings.groq_api_key or "", help="Enter your Groq API key" ) st.divider() # Option 1: Use existing collection if available_collections: st.subheader("📚 Existing Collections") st.write(f"Found {len(available_collections)} collection(s)") selected_collection = st.selectbox( "Or select existing collection:", available_collections, key="collection_selector" ) if st.button("📖 Load Existing Collection", type="secondary"): if not groq_api_key: st.error("Please enter your Groq API key") else: load_existing_collection(groq_api_key, selected_collection) st.divider() # Option 2: Create new collection st.subheader("🆕 Create New Collection") # Dataset selection st.subheader("1. Dataset Selection") dataset_name = st.selectbox( "Choose Dataset", settings.ragbench_datasets, index=0 ) # Get dataset size dynamically if st.button("🔍 Check Dataset Size", key="check_size"): with st.spinner("Checking dataset size..."): try: from datasets import load_dataset import os # Load dataset with download_mode to avoid cache issues st.info(f"Fetching dataset info for '{dataset_name}'...") ds = load_dataset( "rungalileo/ragbench", dataset_name, split="train", trust_remote_code=True, download_mode="force_redownload" # Force fresh download to avoid cache corruption ) dataset_size = len(ds) st.session_state.dataset_size = dataset_size st.session_state.current_dataset = dataset_name st.success(f"✅ Dataset '{dataset_name}' has {dataset_size:,} samples available") except Exception as e: st.error(f"❌ Error: {str(e)}") st.exception(e) st.warning(f"Could not determine dataset size. Using default of 10,000.") st.session_state.dataset_size = 10000 st.session_state.current_dataset = dataset_name # Use stored dataset size or default max_samples_available = st.session_state.get('dataset_size', 10000) st.caption(f"Max available samples: {max_samples_available:,}") num_samples = st.slider( "Number of samples", min_value=10, max_value=max_samples_available, value=min(100, max_samples_available), step=50 if max_samples_available > 1000 else 10, help="Adjust slider to select number of samples" ) load_all_samples = st.checkbox( "Load all available samples", value=False, help="Override slider and load entire dataset" ) st.divider() # Chunking strategy st.subheader("2. Chunking Strategy") chunking_strategy = st.selectbox( "Choose Chunking Strategy", settings.chunking_strategies, index=0 ) chunk_size = st.slider( "Chunk Size", min_value=256, max_value=1024, value=512, step=128 ) overlap = st.slider( "Overlap", min_value=0, max_value=200, value=50, step=10 ) st.divider() # Embedding model st.subheader("3. Embedding Model") embedding_model = st.selectbox( "Choose Embedding Model", settings.embedding_models, index=0 ) st.divider() # LLM model selection for new collection st.subheader("4. LLM Model") llm_model = st.selectbox( "Choose LLM", settings.llm_models, index=1 ) st.divider() # Load data button if st.button("🚀 Load Data & Create Collection", type="primary"): if not groq_api_key: st.error("Please enter your Groq API key") else: # Use None for num_samples if loading all data samples_to_load = None if load_all_samples else num_samples load_and_create_collection( groq_api_key, dataset_name, samples_to_load, chunking_strategy, chunk_size, overlap, embedding_model, llm_model ) # Main content area if not st.session_state.collection_loaded: st.info("👈 Please configure and load a dataset from the sidebar to begin") # Show instructions with st.expander("📖 How to Use", expanded=True): st.markdown(""" 1. **Enter your Groq API Key** in the sidebar 2. **Select a dataset** from RAG Bench 3. **Choose a chunking strategy** (dense, sparse, hybrid, re-ranking) 4. **Select an embedding model** for document vectorization 5. **Choose an LLM model** for response generation 6. **Click "Load Data & Create Collection"** to initialize 7. **Start chatting** in the chat interface 8. **View retrieved documents** and evaluation metrics 9. **Run TRACE evaluation** on test data """) # Show available options col1, col2 = st.columns(2) with col1: st.subheader("📊 Available Datasets") for ds in settings.ragbench_datasets: st.markdown(f"- {ds}") with col2: st.subheader("🤖 Available Models") st.markdown("**Embedding Models:**") for em in settings.embedding_models: st.markdown(f"- {em}") st.markdown("**LLM Models:**") for lm in settings.llm_models: st.markdown(f"- {lm}") else: # Create tabs for different functionalities tab1, tab2, tab3 = st.tabs(["💬 Chat", "📊 Evaluation", "📜 History"]) with tab1: chat_interface() with tab2: evaluation_interface() with tab3: history_interface() def load_existing_collection(api_key: str, collection_name: str): """Load an existing collection from ChromaDB.""" with st.spinner(f"Loading collection '{collection_name}'..."): try: # Initialize vector store and get collection vector_store = ChromaDBManager(settings.chroma_persist_directory) vector_store.get_collection(collection_name) # Prompt for LLM selection st.session_state.current_llm = st.selectbox( "Select LLM for this collection:", settings.llm_models, key=f"llm_selector_{collection_name}" ) # Initialize LLM client st.info("Initializing LLM client...") llm_client = GroqLLMClient( api_key=api_key, model_name=st.session_state.current_llm, max_rpm=settings.groq_rpm_limit, rate_limit_delay=settings.rate_limit_delay ) # Create RAG pipeline with correct parameter names st.info("Creating RAG pipeline...") rag_pipeline = RAGPipeline( llm_client=llm_client, vector_store_manager=vector_store ) # Store in session state st.session_state.vector_store = vector_store st.session_state.rag_pipeline = rag_pipeline st.session_state.collection_loaded = True st.session_state.current_collection = collection_name st.session_state.selected_collection = collection_name st.session_state.groq_api_key = api_key st.success(f"✅ Collection '{collection_name}' loaded successfully!") st.rerun() except Exception as e: st.error(f"Error loading collection: {str(e)}") st.exception(e) def load_and_create_collection( api_key: str, dataset_name: str, num_samples: Optional[int], chunking_strategy: str, chunk_size: int, overlap: int, embedding_model: str, llm_model: str ): """Load dataset and create vector collection.""" with st.spinner("Loading dataset and creating collection..."): try: # Initialize dataset loader loader = RAGBenchLoader() # Load dataset if num_samples is None: st.info(f"Loading {dataset_name} dataset (all available samples)...") else: st.info(f"Loading {dataset_name} dataset ({num_samples} samples)...") dataset = loader.load_dataset(dataset_name, split="train", max_samples=num_samples) st.info(f"Loading {dataset_name} dataset...") dataset = loader.load_dataset(dataset_name, split="train", max_samples=num_samples) if not dataset: st.error("Failed to load dataset") return # Initialize vector store st.info("Initializing vector store...") vector_store = ChromaDBManager(settings.chroma_persist_directory) # Create collection name collection_name = f"{dataset_name}_{chunking_strategy}_{embedding_model.split('/')[-1]}" collection_name = collection_name.replace("-", "_").replace(".", "_") # Delete existing collection with same name (if exists) existing_collections = vector_store.list_collections() if collection_name in existing_collections: st.warning(f"Collection '{collection_name}' already exists. Deleting and recreating...") vector_store.delete_collection(collection_name) st.info("Old collection deleted. Creating new one...") # Load data into collection st.info(f"Creating collection with {chunking_strategy} chunking...") vector_store.load_dataset_into_collection( collection_name=collection_name, embedding_model_name=embedding_model, chunking_strategy=chunking_strategy, dataset_data=dataset, chunk_size=chunk_size, overlap=overlap ) # Initialize LLM client st.info("Initializing LLM client...") llm_client = GroqLLMClient( api_key=api_key, model_name=llm_model, max_rpm=settings.groq_rpm_limit, rate_limit_delay=settings.rate_limit_delay ) # Create RAG pipeline with correct parameter names rag_pipeline = RAGPipeline( llm_client=llm_client, vector_store_manager=vector_store ) # Store in session state st.session_state.vector_store = vector_store st.session_state.rag_pipeline = rag_pipeline st.session_state.collection_loaded = True st.session_state.current_collection = collection_name st.session_state.dataset_name = dataset_name st.session_state.dataset = dataset st.success(f"✅ Collection '{collection_name}' created successfully!") st.rerun() except Exception as e: st.error(f"Error: {str(e)}") def chat_interface(): """Chat interface tab.""" st.subheader("💬 Chat Interface") # Check if collection is loaded if not st.session_state.collection_loaded: st.warning("⚠️ No data loaded. Please use the configuration panel to load a dataset and create a collection.") st.info(""" Steps: 1. Select a dataset from the dropdown 2. Click "Load Data & Create Collection" button 3. Wait for the collection to be created 4. Then you can start chatting """) return # Display collection info and LLM selector col1, col2, col3 = st.columns([2, 2, 1]) with col1: st.info(f"📚 Collection: {st.session_state.current_collection}") with col2: # LLM selector for chat selected_llm = st.selectbox( "Select LLM for chat:", settings.llm_models, index=settings.llm_models.index(st.session_state.current_llm), key="chat_llm_selector" ) if selected_llm != st.session_state.current_llm: st.session_state.current_llm = selected_llm # Recreate RAG pipeline with new LLM llm_client = GroqLLMClient( api_key=st.session_state.groq_api_key if "groq_api_key" in st.session_state else "", model_name=selected_llm, max_rpm=settings.groq_rpm_limit, rate_limit_delay=settings.rate_limit_delay ) st.session_state.rag_pipeline.llm_client = llm_client with col3: if st.button("🗑️ Clear History"): st.session_state.chat_history = [] st.session_state.rag_pipeline.clear_history() st.rerun() # Chat container chat_container = st.container() # Display chat history with chat_container: for chat_idx, entry in enumerate(st.session_state.chat_history): # User message with st.chat_message("user"): st.write(entry["query"]) # Assistant message with st.chat_message("assistant"): st.write(entry["response"]) # Show retrieved documents in expander with st.expander("📄 Retrieved Documents"): for doc_idx, doc in enumerate(entry["retrieved_documents"]): st.markdown(f"**Document {doc_idx+1}** (Distance: {doc.get('distance', 'N/A'):.4f})") st.text_area( f"doc_{chat_idx}_{doc_idx}", value=doc["document"], height=100, key=f"doc_area_{chat_idx}_{doc_idx}", label_visibility="collapsed" ) if doc.get("metadata"): st.caption(f"Metadata: {doc['metadata']}") # Chat input query = st.chat_input("Ask a question...") if query: # Check if collection exists if not st.session_state.rag_pipeline or not st.session_state.rag_pipeline.vector_store.current_collection: st.error("❌ No data loaded. Please load a dataset first using the configuration panel.") st.stop() # Add user message with chat_container: with st.chat_message("user"): st.write(query) # Generate response with st.spinner("Generating response..."): try: result = st.session_state.rag_pipeline.query(query) except Exception as e: st.error(f"❌ Error querying: {str(e)}") st.info("Please load a dataset and create a collection first.") st.stop() # Add assistant message with chat_container: with st.chat_message("assistant"): st.write(result["response"]) # Show retrieved documents with st.expander("📄 Retrieved Documents"): for doc_idx, doc in enumerate(result["retrieved_documents"]): st.markdown(f"**Document {doc_idx+1}** (Distance: {doc.get('distance', 'N/A'):.4f})") st.text_area( f"doc_current_{doc_idx}", value=doc["document"], height=100, key=f"doc_current_area_{doc_idx}", label_visibility="collapsed" ) if doc.get("metadata"): st.caption(f"Metadata: {doc['metadata']}") # Store in history st.session_state.chat_history.append(result) st.rerun() def evaluation_interface(): """Evaluation interface tab.""" st.subheader("📊 TRACE Evaluation") # Check if collection is loaded if not st.session_state.collection_loaded: st.warning("⚠️ No data loaded. Please load a collection first.") return # LLM selector for evaluation col1, col2 = st.columns([3, 1]) with col1: selected_llm = st.selectbox( "Select LLM for evaluation:", settings.llm_models, index=settings.llm_models.index(st.session_state.current_llm), key="eval_llm_selector" ) st.markdown(""" Run TRACE evaluation metrics on test data: - **Utilization**: How well the system uses retrieved documents - **Relevance**: Relevance of retrieved documents to the query - **Adherence**: How well the response adheres to the retrieved context - **Completeness**: How complete the response is in answering the query """) num_test_samples = st.slider( "Number of test samples", min_value=5, max_value=50, value=10, step=5 ) if st.button("🔬 Run Evaluation", type="primary"): # Use selected LLM for evaluation run_evaluation(num_test_samples, selected_llm) # Display results if st.session_state.evaluation_results: results = st.session_state.evaluation_results st.success("✅ Evaluation Complete!") # Display aggregate scores col1, col2, col3, col4, col5 = st.columns(5) with col1: st.metric("📊 Utilization", f"{results['utilization']:.3f}") with col2: st.metric("🎯 Relevance", f"{results['relevance']:.3f}") with col3: st.metric("✅ Adherence", f"{results['adherence']:.3f}") with col4: st.metric("📝 Completeness", f"{results['completeness']:.3f}") with col5: st.metric("⭐ Average", f"{results['average']:.3f}") # Detailed results with st.expander("📋 Detailed Results"): df = pd.DataFrame(results["individual_scores"]) st.dataframe(df, use_container_width=True) # Download results results_json = json.dumps(results, indent=2) st.download_button( label="💾 Download Results (JSON)", data=results_json, file_name=f"trace_evaluation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json", mime="application/json" ) def run_evaluation(num_samples: int, selected_llm: str = None): """Run TRACE evaluation.""" with st.spinner(f"Running evaluation on {num_samples} samples..."): try: # Use selected LLM if provided if selected_llm and selected_llm != st.session_state.current_llm: st.info(f"Switching to {selected_llm} for evaluation...") groq_api_key = st.session_state.groq_api_key if "groq_api_key" in st.session_state else "" eval_llm_client = GroqLLMClient( api_key=groq_api_key, model_name=selected_llm, max_rpm=settings.groq_rpm_limit, rate_limit_delay=settings.rate_limit_delay ) # Temporarily replace LLM client original_llm = st.session_state.rag_pipeline.llm_client st.session_state.rag_pipeline.llm_client = eval_llm_client # Get test data loader = RAGBenchLoader() test_data = loader.get_test_data( st.session_state.dataset_name, num_samples ) # Prepare test cases test_cases = [] progress_bar = st.progress(0) for i, sample in enumerate(test_data): # Query the RAG system result = st.session_state.rag_pipeline.query( sample["question"], n_results=5 ) # Prepare test case test_cases.append({ "query": sample["question"], "response": result["response"], "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]], "ground_truth": sample.get("answer", "") }) # Update progress progress_bar.progress((i + 1) / num_samples) # Run evaluation evaluator = TRACEEvaluator() results = evaluator.evaluate_batch(test_cases) st.session_state.evaluation_results = results # Restore original LLM if it was switched if selected_llm and selected_llm != st.session_state.current_llm: st.session_state.rag_pipeline.llm_client = original_llm except Exception as e: st.error(f"Error during evaluation: {str(e)}") def history_interface(): """History interface tab.""" st.subheader("📜 Chat History") if not st.session_state.chat_history: st.info("No chat history yet. Start a conversation in the Chat tab!") return # Export history col1, col2 = st.columns([3, 1]) with col2: history_json = json.dumps(st.session_state.chat_history, indent=2) st.download_button( label="💾 Export History", data=history_json, file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json", mime="application/json" ) # Display history for i, entry in enumerate(st.session_state.chat_history): with st.expander(f"💬 Conversation {i+1}: {entry['query'][:50]}..."): st.markdown(f"**Query:** {entry['query']}") st.markdown(f"**Response:** {entry['response']}") st.markdown(f"**Timestamp:** {entry.get('timestamp', 'N/A')}") st.markdown("**Retrieved Documents:**") for j, doc in enumerate(entry["retrieved_documents"]): st.text_area( f"Document {j+1}", value=doc["document"], height=100, key=f"history_doc_{i}_{j}" ) if __name__ == "__main__": main()