| | """Streamlit chat interface for RAG application.""" |
| | import streamlit as st |
| | import sys |
| | import os |
| | from datetime import datetime |
| | import json |
| | import pandas as pd |
| | from typing import Optional |
| | import warnings |
| |
|
| | |
| | warnings.filterwarnings('ignore') |
| | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' |
| |
|
| | |
| | sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
| |
|
| | from config import settings |
| | from dataset_loader import RAGBenchLoader |
| | from vector_store import ChromaDBManager |
| | from llm_client import GroqLLMClient, RAGPipeline |
| | from trace_evaluator import TRACEEvaluator |
| | from embedding_models import EmbeddingFactory |
| | from chunking_strategies import ChunkingFactory |
| |
|
| |
|
| | |
# Page setup — must be the first Streamlit call in the script.
st.set_page_config(
    page_title="RAG Capstone Project",
    page_icon="π€",
    layout="wide"
)

# Every session-state key the app relies on, with its default value.
# Initializing from one table keeps the defaults in a single place.
_SESSION_DEFAULTS = {
    "chat_history": [],
    "rag_pipeline": None,
    "vector_store": None,
    "collection_loaded": False,
    "evaluation_results": None,
    "dataset_size": 10000,
    "current_dataset": None,
    "current_llm": settings.llm_models[1],
    "selected_collection": None,
    "available_collections": [],
}

for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
| |
|
def get_available_collections():
    """Return the names of all collections persisted in ChromaDB.

    Failures are swallowed: the error is printed to stdout and an empty
    list is returned, so callers never need their own error handling.
    """
    try:
        manager = ChromaDBManager(settings.chroma_persist_directory)
        return manager.list_collections()
    except Exception as exc:
        print(f"Error getting collections: {exc}")
        return []
| |
|
| |
|
def main():
    """Main Streamlit application: sidebar configuration plus the chat,
    evaluation and history tabs (shown only once a collection is loaded)."""
    st.title("π€ RAG Capstone Project")
    st.markdown("### Retrieval-Augmented Generation with TRACE Evaluation")

    # Refresh the list of on-disk ChromaDB collections on every rerun so the
    # sidebar always reflects what currently exists.
    available_collections = get_available_collections()
    st.session_state.available_collections = available_collections

    with st.sidebar:
        st.header("Configuration")

        # API key lives in widget state here; the load_* helpers stash it in
        # session_state once a collection is actually loaded/created.
        groq_api_key = st.text_input(
            "Groq API Key",
            type="password",
            value=settings.groq_api_key or "",
            help="Enter your Groq API key"
        )

        st.divider()

        # --- Path 1: reuse a collection that already exists on disk ---
        if available_collections:
            st.subheader("π Existing Collections")
            st.write(f"Found {len(available_collections)} collection(s)")

            selected_collection = st.selectbox(
                "Or select existing collection:",
                available_collections,
                key="collection_selector"
            )

            if st.button("π Load Existing Collection", type="secondary"):
                if not groq_api_key:
                    st.error("Please enter your Groq API key")
                else:
                    load_existing_collection(groq_api_key, selected_collection)

            st.divider()

        # --- Path 2: build a brand-new collection from a RAGBench dataset ---
        st.subheader("π Create New Collection")

        st.subheader("1. Dataset Selection")
        dataset_name = st.selectbox(
            "Choose Dataset",
            settings.ragbench_datasets,
            index=0
        )

        # Optional probe: fetch the real dataset size so the sample slider
        # below gets an accurate maximum instead of the 10,000 default.
        if st.button("π Check Dataset Size", key="check_size"):
            with st.spinner("Checking dataset size..."):
                try:
                    from datasets import load_dataset
                    import os

                    st.info(f"Fetching dataset info for '{dataset_name}'...")
                    ds = load_dataset(
                        "rungalileo/ragbench",
                        dataset_name,
                        split="train",
                        trust_remote_code=True,
                        download_mode="force_redownload"
                    )
                    dataset_size = len(ds)

                    st.session_state.dataset_size = dataset_size
                    st.session_state.current_dataset = dataset_name
                    st.success(f"β Dataset '{dataset_name}' has {dataset_size:,} samples available")
                except Exception as e:
                    # Probe failed (network, auth, dataset renamed): fall back
                    # to the default cap so the UI stays usable.
                    st.error(f"β Error: {str(e)}")
                    st.exception(e)
                    st.warning(f"Could not determine dataset size. Using default of 10,000.")
                    st.session_state.dataset_size = 10000
                    st.session_state.current_dataset = dataset_name

        max_samples_available = st.session_state.get('dataset_size', 10000)

        st.caption(f"Max available samples: {max_samples_available:,}")

        num_samples = st.slider(
            "Number of samples",
            min_value=10,
            max_value=max_samples_available,
            value=min(100, max_samples_available),
            step=50 if max_samples_available > 1000 else 10,
            help="Adjust slider to select number of samples"
        )

        load_all_samples = st.checkbox(
            "Load all available samples",
            value=False,
            help="Override slider and load entire dataset"
        )

        st.divider()

        st.subheader("2. Chunking Strategy")
        chunking_strategy = st.selectbox(
            "Choose Chunking Strategy",
            settings.chunking_strategies,
            index=0
        )

        chunk_size = st.slider(
            "Chunk Size",
            min_value=256,
            max_value=1024,
            value=512,
            step=128
        )

        overlap = st.slider(
            "Overlap",
            min_value=0,
            max_value=200,
            value=50,
            step=10
        )

        st.divider()

        st.subheader("3. Embedding Model")
        embedding_model = st.selectbox(
            "Choose Embedding Model",
            settings.embedding_models,
            index=0
        )

        st.divider()

        st.subheader("4. LLM Model")
        llm_model = st.selectbox(
            "Choose LLM",
            settings.llm_models,
            index=1
        )

        st.divider()

        if st.button("π Load Data & Create Collection", type="primary"):
            if not groq_api_key:
                st.error("Please enter your Groq API key")
            else:
                # None signals "load the entire dataset" downstream.
                samples_to_load = None if load_all_samples else num_samples
                load_and_create_collection(
                    groq_api_key,
                    dataset_name,
                    samples_to_load,
                    chunking_strategy,
                    chunk_size,
                    overlap,
                    embedding_model,
                    llm_model
                )

    # Main area: welcome/help screen until a collection is loaded, tabs after.
    if not st.session_state.collection_loaded:
        st.info("π Please configure and load a dataset from the sidebar to begin")

        with st.expander("π How to Use", expanded=True):
            st.markdown("""
            1. **Enter your Groq API Key** in the sidebar
            2. **Select a dataset** from RAG Bench
            3. **Choose a chunking strategy** (dense, sparse, hybrid, re-ranking)
            4. **Select an embedding model** for document vectorization
            5. **Choose an LLM model** for response generation
            6. **Click "Load Data & Create Collection"** to initialize
            7. **Start chatting** in the chat interface
            8. **View retrieved documents** and evaluation metrics
            9. **Run TRACE evaluation** on test data
            """)

        # Quick reference: what datasets and models are configured.
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("π Available Datasets")
            for ds in settings.ragbench_datasets:
                st.markdown(f"- {ds}")

        with col2:
            st.subheader("π€ Available Models")
            st.markdown("**Embedding Models:**")
            for em in settings.embedding_models:
                st.markdown(f"- {em}")

            st.markdown("**LLM Models:**")
            for lm in settings.llm_models:
                st.markdown(f"- {lm}")

    else:
        # Collection loaded: expose the three functional tabs.
        tab1, tab2, tab3 = st.tabs(["π¬ Chat", "π Evaluation", "π History"])

        with tab1:
            chat_interface()

        with tab2:
            evaluation_interface()

        with tab3:
            history_interface()
| |
|
| |
|
def load_existing_collection(api_key: str, collection_name: str):
    """Load an existing ChromaDB collection and wire up the RAG pipeline.

    Args:
        api_key: Groq API key used to build the LLM client.
        collection_name: Name of the persisted collection to open.

    On success, stores the vector store and pipeline in session state,
    marks the collection as loaded, and reruns the app so the tabs appear.
    """
    with st.spinner(f"Loading collection '{collection_name}'..."):
        try:
            vector_store = ChromaDBManager(settings.chroma_persist_directory)
            vector_store.get_collection(collection_name)

            # BUG FIX: the original rendered an st.selectbox here, inside a
            # one-shot button-click branch. A widget created only on the click
            # rerun vanishes on the next rerun and st.rerun() below discards
            # its value immediately, so the selection could never take effect.
            # Use the model already chosen in session state instead; the chat
            # tab has its own selector for switching models later.
            st.info("Initializing LLM client...")
            llm_client = GroqLLMClient(
                api_key=api_key,
                model_name=st.session_state.current_llm,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )

            st.info("Creating RAG pipeline...")
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )

            # Persist everything the chat/evaluation/history tabs need.
            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.selected_collection = collection_name
            st.session_state.groq_api_key = api_key

            st.success(f"β Collection '{collection_name}' loaded successfully!")
            st.rerun()

        except Exception as e:
            st.error(f"Error loading collection: {str(e)}")
            st.exception(e)
| |
|
| |
|
def load_and_create_collection(
    api_key: str,
    dataset_name: str,
    num_samples: Optional[int],
    chunking_strategy: str,
    chunk_size: int,
    overlap: int,
    embedding_model: str,
    llm_model: str
):
    """Load a RAGBench dataset and build a new vector collection from it.

    Args:
        api_key: Groq API key for the LLM client.
        dataset_name: RAGBench subset to load.
        num_samples: Maximum samples to load, or None to load everything.
        chunking_strategy: Strategy name passed to the vector store.
        chunk_size: Chunk size passed to the chunker.
        overlap: Overlap between consecutive chunks.
        embedding_model: Embedding model identifier.
        llm_model: Groq model name used for generation.

    On success, stores the pipeline in session state and reruns the app.
    """
    with st.spinner("Loading dataset and creating collection..."):
        try:
            loader = RAGBenchLoader()

            if num_samples is None:
                st.info(f"Loading {dataset_name} dataset (all available samples)...")
            else:
                st.info(f"Loading {dataset_name} dataset ({num_samples} samples)...")
            # BUG FIX: the original called loader.load_dataset twice in a row,
            # downloading/parsing the same dataset a second time for nothing.
            dataset = loader.load_dataset(dataset_name, split="train", max_samples=num_samples)

            if not dataset:
                st.error("Failed to load dataset")
                return

            st.info("Initializing vector store...")
            vector_store = ChromaDBManager(settings.chroma_persist_directory)

            # Collection names must avoid '-' and '.' for ChromaDB; the model
            # path's last segment keeps the name short.
            collection_name = f"{dataset_name}_{chunking_strategy}_{embedding_model.split('/')[-1]}"
            collection_name = collection_name.replace("-", "_").replace(".", "_")

            # Recreate from scratch if a collection with this name exists.
            existing_collections = vector_store.list_collections()
            if collection_name in existing_collections:
                st.warning(f"Collection '{collection_name}' already exists. Deleting and recreating...")
                vector_store.delete_collection(collection_name)
                st.info("Old collection deleted. Creating new one...")

            st.info(f"Creating collection with {chunking_strategy} chunking...")
            vector_store.load_dataset_into_collection(
                collection_name=collection_name,
                embedding_model_name=embedding_model,
                chunking_strategy=chunking_strategy,
                dataset_data=dataset,
                chunk_size=chunk_size,
                overlap=overlap
            )

            st.info("Initializing LLM client...")
            llm_client = GroqLLMClient(
                api_key=api_key,
                model_name=llm_model,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )

            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )

            # Persist everything the chat/evaluation tabs need.
            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.dataset_name = dataset_name
            st.session_state.dataset = dataset
            # CONSISTENCY FIX: chat_interface and run_evaluation read
            # st.session_state.groq_api_key when switching models, but only
            # load_existing_collection stored it — store it here too.
            st.session_state.groq_api_key = api_key

            st.success(f"β Collection '{collection_name}' created successfully!")
            st.rerun()

        except Exception as e:
            st.error(f"Error: {str(e)}")
            # Surface the traceback like the sibling loader does.
            st.exception(e)
| |
|
| |
|
def _format_distance(distance) -> str:
    """Format a retrieval distance for display, tolerating missing values."""
    # BUG FIX: the original did f"{doc.get('distance', 'N/A'):.4f}", which
    # raises TypeError whenever 'distance' is absent ('N/A' has no .4f form).
    return f"{distance:.4f}" if isinstance(distance, (int, float)) else "N/A"


def _render_retrieved_documents(documents, key_prefix: str):
    """Render a list of retrieved documents inside an expander.

    Args:
        documents: Dicts with 'document' text and optional 'distance'/'metadata'.
        key_prefix: Unique prefix so widget keys never collide across renders.
    """
    with st.expander("π Retrieved Documents"):
        for doc_idx, doc in enumerate(documents):
            st.markdown(f"**Document {doc_idx+1}** (Distance: {_format_distance(doc.get('distance'))})")
            st.text_area(
                f"doc_{key_prefix}_{doc_idx}",
                value=doc["document"],
                height=100,
                key=f"doc_area_{key_prefix}_{doc_idx}",
                label_visibility="collapsed"
            )
            if doc.get("metadata"):
                st.caption(f"Metadata: {doc['metadata']}")


def chat_interface():
    """Chat interface tab: history display, model switcher, and chat input."""
    st.subheader("π¬ Chat Interface")

    # Nothing to chat against until a collection is loaded.
    if not st.session_state.collection_loaded:
        st.warning("β οΈ No data loaded. Please use the configuration panel to load a dataset and create a collection.")
        st.info("""
        Steps:
        1. Select a dataset from the dropdown
        2. Click "Load Data & Create Collection" button
        3. Wait for the collection to be created
        4. Then you can start chatting
        """)
        return

    col1, col2, col3 = st.columns([2, 2, 1])
    with col1:
        st.info(f"π Collection: {st.session_state.current_collection}")

    with col2:
        # Let the user switch the generation model mid-conversation.
        selected_llm = st.selectbox(
            "Select LLM for chat:",
            settings.llm_models,
            index=settings.llm_models.index(st.session_state.current_llm),
            key="chat_llm_selector"
        )

        if selected_llm != st.session_state.current_llm:
            st.session_state.current_llm = selected_llm
            # Rebuild the client so subsequent queries use the new model.
            llm_client = GroqLLMClient(
                api_key=st.session_state.get("groq_api_key", ""),
                model_name=selected_llm,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )
            st.session_state.rag_pipeline.llm_client = llm_client

    with col3:
        if st.button("ποΈ Clear History"):
            st.session_state.chat_history = []
            st.session_state.rag_pipeline.clear_history()
            st.rerun()

    chat_container = st.container()

    # Replay the stored conversation.
    with chat_container:
        for chat_idx, entry in enumerate(st.session_state.chat_history):
            with st.chat_message("user"):
                st.write(entry["query"])

            with st.chat_message("assistant"):
                st.write(entry["response"])
                _render_retrieved_documents(entry["retrieved_documents"], str(chat_idx))

    query = st.chat_input("Ask a question...")

    if query:
        # Defensive: the pipeline can exist without an active collection.
        if not st.session_state.rag_pipeline or not st.session_state.rag_pipeline.vector_store.current_collection:
            st.error("β No data loaded. Please load a dataset first using the configuration panel.")
            st.stop()

        with chat_container:
            with st.chat_message("user"):
                st.write(query)

        with st.spinner("Generating response..."):
            try:
                result = st.session_state.rag_pipeline.query(query)
            except Exception as e:
                st.error(f"β Error querying: {str(e)}")
                st.info("Please load a dataset and create a collection first.")
                st.stop()

        with chat_container:
            with st.chat_message("assistant"):
                st.write(result["response"])
                _render_retrieved_documents(result["retrieved_documents"], "current")

        # result is assumed to carry 'query'/'response'/'retrieved_documents'
        # keys (the replay loop above reads all three) — produced by
        # RAGPipeline.query; verify against that implementation.
        st.session_state.chat_history.append(result)
        st.rerun()
| |
|
| |
|
def evaluation_interface():
    """Evaluation interface tab: configure and run TRACE metrics, show results."""
    st.subheader("π TRACE Evaluation")

    # Evaluation queries the loaded collection, so it needs one first.
    if not st.session_state.collection_loaded:
        st.warning("β οΈ No data loaded. Please load a collection first.")
        return

    # Model choice applies to this evaluation run only; the chat tab keeps
    # its own selector. (col2 is intentionally left empty for layout.)
    col1, col2 = st.columns([3, 1])
    with col1:
        selected_llm = st.selectbox(
            "Select LLM for evaluation:",
            settings.llm_models,
            index=settings.llm_models.index(st.session_state.current_llm),
            key="eval_llm_selector"
        )

    st.markdown("""
    Run TRACE evaluation metrics on test data:
    - **Utilization**: How well the system uses retrieved documents
    - **Relevance**: Relevance of retrieved documents to the query
    - **Adherence**: How well the response adheres to the retrieved context
    - **Completeness**: How complete the response is in answering the query
    """)

    num_test_samples = st.slider(
        "Number of test samples",
        min_value=5,
        max_value=50,
        value=10,
        step=5
    )

    if st.button("π¬ Run Evaluation", type="primary"):
        run_evaluation(num_test_samples, selected_llm)

    # Results live in session state, so they survive reruns and re-render
    # here even after the button press that produced them.
    if st.session_state.evaluation_results:
        results = st.session_state.evaluation_results

        st.success("β Evaluation Complete!")

        # One metric card per TRACE dimension, plus the overall average.
        col1, col2, col3, col4, col5 = st.columns(5)

        with col1:
            st.metric("π Utilization", f"{results['utilization']:.3f}")
        with col2:
            st.metric("π― Relevance", f"{results['relevance']:.3f}")
        with col3:
            st.metric("β Adherence", f"{results['adherence']:.3f}")
        with col4:
            st.metric("π Completeness", f"{results['completeness']:.3f}")
        with col5:
            st.metric("β Average", f"{results['average']:.3f}")

        # Per-sample scores as a table.
        with st.expander("π Detailed Results"):
            df = pd.DataFrame(results["individual_scores"])
            st.dataframe(df, use_container_width=True)

        # Offer the raw results for download as JSON.
        results_json = json.dumps(results, indent=2)
        st.download_button(
            label="πΎ Download Results (JSON)",
            data=results_json,
            file_name=f"trace_evaluation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json"
        )
| |
|
| |
|
def run_evaluation(num_samples: int, selected_llm: str = None):
    """Run TRACE evaluation over test samples from the current dataset.

    Args:
        num_samples: Number of test samples requested from the loader.
        selected_llm: Optional model name; when it differs from the model
            currently in use, the pipeline is temporarily switched to it
            for the duration of the evaluation and restored afterwards.

    Results are stored in st.session_state.evaluation_results.
    """
    with st.spinner(f"Running evaluation on {num_samples} samples..."):
        original_llm = None
        try:
            # Temporarily swap in the requested evaluation LLM.
            if selected_llm and selected_llm != st.session_state.current_llm:
                st.info(f"Switching to {selected_llm} for evaluation...")
                eval_llm_client = GroqLLMClient(
                    api_key=st.session_state.get("groq_api_key", ""),
                    model_name=selected_llm,
                    max_rpm=settings.groq_rpm_limit,
                    rate_limit_delay=settings.rate_limit_delay
                )
                original_llm = st.session_state.rag_pipeline.llm_client
                st.session_state.rag_pipeline.llm_client = eval_llm_client

            loader = RAGBenchLoader()
            test_data = loader.get_test_data(
                st.session_state.dataset_name,
                num_samples
            )

            test_cases = []
            progress_bar = st.progress(0)
            # BUG FIX: the loader may return fewer samples than requested;
            # divide by the actual count so st.progress never exceeds 1.0.
            total = len(test_data)

            for i, sample in enumerate(test_data):
                result = st.session_state.rag_pipeline.query(
                    sample["question"],
                    n_results=5
                )

                test_cases.append({
                    "query": sample["question"],
                    "response": result["response"],
                    "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                    "ground_truth": sample.get("answer", "")
                })

                progress_bar.progress((i + 1) / total)

            evaluator = TRACEEvaluator()
            results = evaluator.evaluate_batch(test_cases)

            st.session_state.evaluation_results = results

        except Exception as e:
            st.error(f"Error during evaluation: {str(e)}")
        finally:
            # BUG FIX: the original restored the chat LLM only on the success
            # path, so an exception mid-run left the evaluation LLM installed.
            if original_llm is not None:
                st.session_state.rag_pipeline.llm_client = original_llm
| |
|
| |
|
def history_interface():
    """History tab: export button plus one expander per past conversation."""
    st.subheader("π Chat History")

    history = st.session_state.chat_history
    if not history:
        st.info("No chat history yet. Start a conversation in the Chat tab!")
        return

    # Export lives in the narrow right-hand column; the left one is spacing.
    col1, col2 = st.columns([3, 1])
    with col2:
        st.download_button(
            label="πΎ Export History",
            data=json.dumps(history, indent=2),
            file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json"
        )

    for conv_idx, record in enumerate(history):
        with st.expander(f"π¬ Conversation {conv_idx+1}: {record['query'][:50]}..."):
            st.markdown(f"**Query:** {record['query']}")
            st.markdown(f"**Response:** {record['response']}")
            st.markdown(f"**Timestamp:** {record.get('timestamp', 'N/A')}")

            st.markdown("**Retrieved Documents:**")
            for doc_pos, doc in enumerate(record["retrieved_documents"]):
                st.text_area(
                    f"Document {doc_pos+1}",
                    value=doc["document"],
                    height=100,
                    key=f"history_doc_{conv_idx}_{doc_pos}"
                )
| |
|
| |
|
# Script entry point. Streamlit re-executes the whole module on every rerun.
if __name__ == "__main__":
    main()
| |
|