"""Streamlit chat interface for RAG application.""" import streamlit as st import sys import os from datetime import datetime import json import pandas as pd from typing import Optional import warnings # Suppress warnings warnings.filterwarnings('ignore') os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Add parent directory to path sys.path.append(os.path.dirname(os.path.abspath(__file__))) # Check if running on HuggingFace Spaces IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None from config import settings from dataset_loader import RAGBenchLoader from vector_store import ChromaDBManager, create_vector_store try: from vector_store import QdrantManager, QDRANT_AVAILABLE except ImportError: QDRANT_AVAILABLE = False from llm_client import GroqLLMClient, OllamaLLMClient, RAGPipeline, create_llm_client from trace_evaluator import TRACEEvaluator from embedding_models import EmbeddingFactory from chunking_strategies import ChunkingFactory import uuid # Background evaluation support try: from background_evaluator import ( create_job, start_background_job, load_job_status, list_jobs, delete_job, STATUS_PENDING, STATUS_RUNNING, STATUS_COMPLETED, STATUS_FAILED ) BACKGROUND_EVAL_AVAILABLE = True except ImportError: BACKGROUND_EVAL_AVAILABLE = False # Page configuration st.set_page_config( page_title="RAG Capstone Project", page_icon="🤖", layout="wide" ) # Initialize session state if "chat_history" not in st.session_state: st.session_state.chat_history = [] if "rag_pipeline" not in st.session_state: st.session_state.rag_pipeline = None if "vector_store" not in st.session_state: st.session_state.vector_store = None if "collection_loaded" not in st.session_state: st.session_state.collection_loaded = False if "evaluation_results" not in st.session_state: st.session_state.evaluation_results = None if "dataset_size" not in st.session_state: st.session_state.dataset_size = 10000 if "current_dataset" not in st.session_state: st.session_state.current_dataset = None if "current_llm" not in st.session_state: st.session_state.current_llm = settings.llm_models[1] if "selected_collection" not in st.session_state: st.session_state.selected_collection = None if "available_collections" not in st.session_state: st.session_state.available_collections = [] if "dataset_name" not in st.session_state: st.session_state.dataset_name = None if "collection_name" not in st.session_state: st.session_state.collection_name = None if "embedding_model" not in st.session_state: st.session_state.embedding_model = None if "groq_api_key" not in st.session_state: st.session_state.groq_api_key = "" if "llm_provider" not in st.session_state: st.session_state.llm_provider = settings.llm_provider if "ollama_model" not in st.session_state: st.session_state.ollama_model = settings.ollama_model if "vector_store_provider" not in st.session_state: st.session_state.vector_store_provider = settings.vector_store_provider if "qdrant_url" not in st.session_state: st.session_state.qdrant_url = settings.qdrant_url if "qdrant_api_key" not in st.session_state: st.session_state.qdrant_api_key = settings.qdrant_api_key def get_available_collections(provider: str = None): """Get list of available collections from vector store.""" provider = provider or st.session_state.get("vector_store_provider", "chroma") try: if provider == "qdrant" and QDRANT_AVAILABLE: qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key if qdrant_url and qdrant_api_key: vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key) collections = vector_store.list_collections() return collections return [] else: vector_store = ChromaDBManager(settings.chroma_persist_directory) collections = vector_store.list_collections() return collections except Exception as e: print(f"Error getting collections: {e}") return [] def main(): """Main Streamlit application.""" st.title("🤖 RAG Capstone Project") st.markdown("### Retrieval-Augmented Generation with TRACE Evaluation") # Show HuggingFace Spaces notice if IS_HUGGINGFACE_SPACE: st.info("🤗 Running on Hugging Face Spaces - Using Groq API (cloud-based LLM)") # Get available collections at startup available_collections = get_available_collections() st.session_state.available_collections = available_collections # Sidebar for configuration with st.sidebar: st.header("Configuration") # LLM Provider Selection - Disable Ollama on HuggingFace Spaces st.subheader("🔌 LLM Provider") if IS_HUGGINGFACE_SPACE: # Force Groq on HuggingFace Spaces (Ollama not available) st.caption("â˜ī¸ **Groq API** (Ollama unavailable on Spaces)") llm_provider = "groq" st.session_state.llm_provider = "groq" else: llm_provider = st.radio( "Choose LLM Provider:", options=["groq", "ollama"], index=0 if st.session_state.llm_provider == "groq" else 1, format_func=lambda x: "â˜ī¸ Groq API (Cloud)" if x == "groq" else "đŸ–Ĩī¸ Ollama (Local)", help="Groq: Cloud API with rate limits. Ollama: Local unlimited inference.", key="llm_provider_radio" ) st.session_state.llm_provider = llm_provider # Provider-specific settings if llm_provider == "groq": st.caption("âš ī¸ Free tier: 30 requests/min") # On HuggingFace Spaces, check for API key in secrets first default_api_key = os.environ.get("GROQ_API_KEY", "") or settings.groq_api_key or "" # API Key input groq_api_key = st.text_input( "Groq API Key", type="password", value=default_api_key, help="Enter your Groq API key (or set GROQ_API_KEY in Spaces secrets)" ) if IS_HUGGINGFACE_SPACE and not groq_api_key: st.warning("💡 Tip: Add GROQ_API_KEY to your Space secrets for persistence") else: # Ollama settings (only available locally) st.caption("✅ No rate limits - unlimited usage!") ollama_host = st.text_input( "Ollama Host", value=settings.ollama_host, help="Ollama server URL (default: http://localhost:11434)" ) ollama_model = st.selectbox( "Select Ollama Model:", options=settings.ollama_models, index=settings.ollama_models.index(st.session_state.ollama_model) if st.session_state.ollama_model in settings.ollama_models else 0, key="ollama_model_selector" ) st.session_state.ollama_model = ollama_model # Connection check button if st.button("🔍 Check Ollama Connection"): try: import requests response = requests.get(f"{ollama_host}/api/tags", timeout=5) if response.status_code == 200: models = response.json().get("models", []) model_names = [m["name"] for m in models] st.success(f"✅ Connected! Available models: {', '.join(model_names)}") else: st.error(f"❌ Connection failed: {response.status_code}") except Exception as e: st.error(f"❌ Cannot connect to Ollama: {e}") st.info("Make sure Ollama is running: `ollama serve`") groq_api_key = "" # Not needed for Ollama st.divider() # Vector Store Provider Selection st.subheader("💾 Vector Store") if IS_HUGGINGFACE_SPACE: st.caption("â˜ī¸ Use **Qdrant Cloud** for persistent storage") vector_store_options = ["qdrant", "chroma"] default_idx = 0 else: vector_store_options = ["chroma", "qdrant"] default_idx = 0 vector_store_provider = st.radio( "Choose Vector Store:", options=vector_store_options, index=default_idx, format_func=lambda x: "â˜ī¸ Qdrant Cloud (Persistent)" if x == "qdrant" else "💾 ChromaDB (Local)", help="Qdrant: Cloud storage (persistent). ChromaDB: Local storage (ephemeral on Spaces).", key="vector_store_radio" ) st.session_state.vector_store_provider = vector_store_provider # Qdrant settings if vector_store_provider == "qdrant": default_qdrant_url = os.environ.get("QDRANT_URL", "") or settings.qdrant_url default_qdrant_key = os.environ.get("QDRANT_API_KEY", "") or settings.qdrant_api_key qdrant_url = st.text_input( "Qdrant URL", value=default_qdrant_url, placeholder="https://xxx-xxx.aws.cloud.qdrant.io:6333", help="Your Qdrant Cloud cluster URL" ) qdrant_api_key = st.text_input( "Qdrant API Key", type="password", value=default_qdrant_key, help="Your Qdrant API key" ) st.session_state.qdrant_url = qdrant_url st.session_state.qdrant_api_key = qdrant_api_key if not qdrant_url or not qdrant_api_key: st.warning("âš ī¸ Get free Qdrant Cloud at: https://cloud.qdrant.io") # Test Qdrant connection if st.button("🔍 Test Qdrant Connection"): if qdrant_url and qdrant_api_key: try: test_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key) collections = test_store.list_collections() st.success(f"✅ Connected! Found {len(collections)} collection(s)") except Exception as e: st.error(f"❌ Connection failed: {e}") else: st.error("Please enter Qdrant URL and API Key") st.divider() # Get available collections based on provider available_collections = get_available_collections(vector_store_provider) st.session_state.available_collections = available_collections # Option 1: Use existing collection if available_collections: st.subheader("📚 Existing Collections") st.write(f"Found {len(available_collections)} collection(s)") selected_collection = st.selectbox( "Or select existing collection:", available_collections, key="collection_selector" ) if st.button("📖 Load Existing Collection", type="secondary"): # Validate based on provider if llm_provider == "groq" and not groq_api_key: st.error("Please enter your Groq API key") elif vector_store_provider == "qdrant" and (not st.session_state.get("qdrant_url") or not st.session_state.get("qdrant_api_key")): st.error("Please enter Qdrant URL and API Key") else: load_existing_collection( groq_api_key, selected_collection, llm_provider, ollama_host if llm_provider == "ollama" else None, vector_store_provider ) st.divider() # Option 2: Create new collection st.subheader("🆕 Create New Collection") # Dataset selection st.subheader("1. Dataset Selection") dataset_name = st.selectbox( "Choose Dataset", settings.ragbench_datasets, index=0 ) # Get dataset size dynamically if st.button("🔍 Check Dataset Size", key="check_size"): with st.spinner("Checking dataset size..."): try: from datasets import load_dataset # Load dataset with download_mode to avoid cache issues st.info(f"Fetching dataset info for '{dataset_name}'...") ds = load_dataset( "rungalileo/ragbench", dataset_name, split="train", trust_remote_code=True, download_mode="force_redownload" # Force fresh download to avoid cache corruption ) dataset_size = len(ds) st.session_state.dataset_size = dataset_size st.session_state.current_dataset = dataset_name st.success(f"✅ Dataset '{dataset_name}' has {dataset_size:,} samples available") except Exception as e: st.error(f"❌ Error: {str(e)}") st.exception(e) st.warning(f"Could not determine dataset size. Using default of 10,000.") st.session_state.dataset_size = 10000 st.session_state.current_dataset = dataset_name # Use stored dataset size or default max_samples_available = st.session_state.get('dataset_size', 10000) st.caption(f"Max available samples: {max_samples_available:,}") num_samples = st.slider( "Number of samples", min_value=10, max_value=max_samples_available, value=min(100, max_samples_available), step=50 if max_samples_available > 1000 else 10, help="Adjust slider to select number of samples" ) load_all_samples = st.checkbox( "Load all available samples", value=False, help="Override slider and load entire dataset" ) st.divider() # Chunking strategy st.subheader("2. Chunking Strategy") chunking_strategy = st.selectbox( "Choose Chunking Strategy", settings.chunking_strategies, index=0 ) chunk_size = st.slider( "Chunk Size", min_value=256, max_value=1024, value=512, step=128 ) overlap = st.slider( "Overlap", min_value=0, max_value=200, value=50, step=10 ) st.divider() # Embedding model st.subheader("3. Embedding Model") embedding_model = st.selectbox( "Choose Embedding Model", settings.embedding_models, index=0 ) st.divider() # LLM model selection for new collection st.subheader("4. LLM Model") if llm_provider == "groq": llm_model = st.selectbox( "Choose Groq LLM", settings.llm_models, index=1 ) else: llm_model = st.selectbox( "Choose Ollama Model", settings.ollama_models, index=settings.ollama_models.index(st.session_state.ollama_model) if st.session_state.ollama_model in settings.ollama_models else 0, key="llm_model_ollama" ) st.divider() # Load data button if st.button("🚀 Load Data & Create Collection", type="primary"): # Validate based on provider if llm_provider == "groq" and not groq_api_key: st.error("Please enter your Groq API key") elif vector_store_provider == "qdrant" and (not st.session_state.get("qdrant_url") or not st.session_state.get("qdrant_api_key")): st.error("Please enter Qdrant URL and API Key") else: # Use None for num_samples if loading all data samples_to_load = None if load_all_samples else num_samples load_and_create_collection( groq_api_key, dataset_name, samples_to_load, chunking_strategy, chunk_size, overlap, embedding_model, llm_model, llm_provider, ollama_host if llm_provider == "ollama" else None, vector_store_provider ) # Main content area if not st.session_state.collection_loaded: st.info("👈 Please configure and load a dataset from the sidebar to begin") # Show instructions with st.expander("📖 How to Use", expanded=True): st.markdown(""" 1. **Enter your Groq API Key** in the sidebar 2. **Select a dataset** from RAG Bench 3. **Choose a chunking strategy** (dense, sparse, hybrid, re-ranking) 4. **Select an embedding model** for document vectorization 5. **Choose an LLM model** for response generation 6. **Click "Load Data & Create Collection"** to initialize 7. **Start chatting** in the chat interface 8. **View retrieved documents** and evaluation metrics 9. **Run TRACE evaluation** on test data """) # Show available options col1, col2 = st.columns(2) with col1: st.subheader("📊 Available Datasets") for ds in settings.ragbench_datasets: st.markdown(f"- {ds}") with col2: st.subheader("🤖 Available Models") st.markdown("**Embedding Models:**") for em in settings.embedding_models: st.markdown(f"- {em}") st.markdown("**LLM Models:**") for lm in settings.llm_models: st.markdown(f"- {lm}") else: # Create tabs for different functionalities tab1, tab2, tab3 = st.tabs(["đŸ’Ŧ Chat", "📊 Evaluation", "📜 History"]) with tab1: chat_interface() with tab2: evaluation_interface() with tab3: history_interface() def load_existing_collection(api_key: str, collection_name: str, llm_provider: str = "groq", ollama_host: str = None, vector_store_provider: str = "chroma"): """Load an existing collection from vector store.""" with st.spinner(f"Loading collection '{collection_name}'..."): try: # Initialize vector store based on provider if vector_store_provider == "qdrant": qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key) else: vector_store = ChromaDBManager(settings.chroma_persist_directory) vector_store.get_collection(collection_name) # Extract dataset name from collection name (format: dataset_name_strategy_model) # Try to find which dataset this collection is based on dataset_name = None for ds in settings.ragbench_datasets: if collection_name.startswith(ds.replace("-", "_")): dataset_name = ds break if not dataset_name: dataset_name = collection_name.split("_")[0] # Fallback: use first part # Prompt for LLM selection based on provider if llm_provider == "groq": st.session_state.current_llm = st.selectbox( "Select Groq LLM for this collection:", settings.llm_models, key=f"llm_selector_{collection_name}" ) else: st.session_state.current_llm = st.selectbox( "Select Ollama Model for this collection:", settings.ollama_models, key=f"ollama_selector_{collection_name}" ) # Initialize LLM client based on provider st.info(f"Initializing LLM client ({llm_provider})...") llm_client = create_llm_client( provider=llm_provider, api_key=api_key, api_keys=settings.groq_api_keys if settings.groq_api_keys else None, model_name=st.session_state.current_llm, ollama_host=ollama_host or settings.ollama_host, max_rpm=settings.groq_rpm_limit, rate_limit_delay=settings.rate_limit_delay, max_retries=settings.max_retries, retry_delay=settings.retry_delay ) # Create RAG pipeline with correct parameter names st.info("Creating RAG pipeline...") rag_pipeline = RAGPipeline( llm_client=llm_client, vector_store_manager=vector_store ) # Store in session state st.session_state.vector_store = vector_store st.session_state.rag_pipeline = rag_pipeline st.session_state.collection_loaded = True st.session_state.current_collection = collection_name st.session_state.selected_collection = collection_name st.session_state.groq_api_key = api_key st.session_state.dataset_name = dataset_name st.session_state.collection_name = collection_name st.session_state.llm_provider = llm_provider # Extract embedding model from collection metadata or vector store embedding_model_name = None # For ChromaDB: check collection metadata if hasattr(vector_store, 'current_collection') and vector_store.current_collection: if hasattr(vector_store.current_collection, 'metadata'): collection_metadata = vector_store.current_collection.metadata if collection_metadata and "embedding_model" in collection_metadata: embedding_model_name = collection_metadata["embedding_model"] # For Qdrant or fallback: check if embedding_model was loaded on the vector store if not embedding_model_name and hasattr(vector_store, 'embedding_model') and vector_store.embedding_model: if hasattr(vector_store.embedding_model, 'model_name'): embedding_model_name = vector_store.embedding_model.model_name # Set session state if embedding_model_name: st.session_state.embedding_model = embedding_model_name else: st.session_state.embedding_model = None # Display system prompt and model info provider_icon = "â˜ī¸" if llm_provider == "groq" else "đŸ–Ĩī¸" st.success(f"✅ Collection '{collection_name}' loaded successfully! {provider_icon} Using {llm_provider.upper()}") with st.expander("🤖 Model & System Prompt Information", expanded=False): col1, col2 = st.columns(2) with col1: st.write(f"**Provider:** {provider_icon} {llm_provider.upper()}") st.write(f"**Model:** {st.session_state.current_llm}") st.write(f"**Collection:** {collection_name}") st.write(f"**Dataset:** {dataset_name}") with col2: st.write(f"**Temperature:** 0.0") st.write(f"**Max Tokens:** 2048") if llm_provider == "groq": st.write(f"**Rate Limit:** {settings.groq_rpm_limit} RPM") else: st.write(f"**Rate Limit:** ✅ Unlimited (Local)") st.markdown("#### System Prompt") st.info(""" You are a Fact-Checking and Citation Specialist. Your task is to perform a rigorous audit of a response against provided documents to determine its accuracy, relevance, and level of support. **Task:** 1. Analyze the provided documents and identify information relevant to the user's question 2. Evaluate the response sentence-by-sentence 3. Verify each response sentence maps to supporting document sentences 4. Identify which document sentences were actually used in the response """) st.rerun() except Exception as e: st.error(f"Error loading collection: {str(e)}") st.exception(e) def load_and_create_collection( api_key: str, dataset_name: str, num_samples: Optional[int], chunking_strategy: str, chunk_size: int, overlap: int, embedding_model: str, llm_model: str, llm_provider: str = "groq", ollama_host: str = None, vector_store_provider: str = "chroma" ): """Load dataset and create vector collection.""" with st.spinner("Loading dataset and creating collection..."): try: # Initialize dataset loader loader = RAGBenchLoader() # Load dataset if num_samples is None: st.info(f"Loading {dataset_name} dataset (all available samples)...") else: st.info(f"Loading {dataset_name} dataset ({num_samples} samples)...") dataset = loader.load_dataset(dataset_name, split="train", max_samples=num_samples) if not dataset: st.error("Failed to load dataset") return # Initialize vector store based on provider st.info(f"Initializing vector store ({vector_store_provider})...") if vector_store_provider == "qdrant": qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key) else: vector_store = ChromaDBManager(settings.chroma_persist_directory) # Create collection name collection_name = f"{dataset_name}_{chunking_strategy}_{embedding_model.split('/')[-1]}" collection_name = collection_name.replace("-", "_").replace(".", "_") # Delete existing collection with same name (if exists) existing_collections = vector_store.list_collections() if collection_name in existing_collections: st.warning(f"Collection '{collection_name}' already exists. Deleting and recreating...") vector_store.delete_collection(collection_name) st.info("Old collection deleted. Creating new one...") # Load data into collection st.info(f"Creating collection with {chunking_strategy} chunking...") vector_store.load_dataset_into_collection( collection_name=collection_name, embedding_model_name=embedding_model, chunking_strategy=chunking_strategy, dataset_data=dataset, chunk_size=chunk_size, overlap=overlap ) # Initialize LLM client based on provider st.info(f"Initializing LLM client ({llm_provider})...") llm_client = create_llm_client( provider=llm_provider, api_key=api_key, api_keys=settings.groq_api_keys if settings.groq_api_keys else None, model_name=llm_model, ollama_host=ollama_host or settings.ollama_host, max_rpm=settings.groq_rpm_limit, rate_limit_delay=settings.rate_limit_delay, max_retries=settings.max_retries, retry_delay=settings.retry_delay ) # Create RAG pipeline with correct parameter names rag_pipeline = RAGPipeline( llm_client=llm_client, vector_store_manager=vector_store ) # Store in session state st.session_state.vector_store = vector_store st.session_state.rag_pipeline = rag_pipeline st.session_state.collection_loaded = True st.session_state.current_collection = collection_name st.session_state.dataset_name = dataset_name st.session_state.dataset = dataset st.session_state.collection_name = collection_name st.session_state.embedding_model = embedding_model st.session_state.groq_api_key = api_key st.session_state.llm_provider = llm_provider st.session_state.vector_store_provider = vector_store_provider provider_icon = "â˜ī¸" if llm_provider == "groq" else "đŸ–Ĩī¸" vs_icon = "â˜ī¸" if vector_store_provider == "qdrant" else "💾" st.success(f"✅ Collection '{collection_name}' created successfully! {provider_icon} Using {llm_provider.upper()}") st.rerun() except Exception as e: st.error(f"Error: {str(e)}") def chat_interface(): """Chat interface tab.""" st.subheader("đŸ’Ŧ Chat Interface") # Check if collection is loaded if not st.session_state.collection_loaded: st.warning("âš ī¸ No data loaded. Please use the configuration panel to load a dataset and create a collection.") st.info(""" Steps: 1. Select a dataset from the dropdown 2. Click "Load Data & Create Collection" button 3. Wait for the collection to be created 4. Then you can start chatting """) return # Display collection info and LLM selector col1, col2, col3 = st.columns([2, 2, 1]) with col1: provider_icon = "â˜ī¸" if st.session_state.get("llm_provider", "groq") == "groq" else "đŸ–Ĩī¸" st.info(f"📚 Collection: {st.session_state.current_collection} | {provider_icon} {st.session_state.get('llm_provider', 'groq').upper()}") with col2: # LLM selector for chat - based on provider current_provider = st.session_state.get("llm_provider", "groq") if current_provider == "groq": model_options = settings.llm_models try: current_index = settings.llm_models.index(st.session_state.current_llm) except ValueError: current_index = 0 else: model_options = settings.ollama_models try: current_index = settings.ollama_models.index(st.session_state.current_llm) except ValueError: current_index = 0 selected_llm = st.selectbox( f"Select {'Groq' if current_provider == 'groq' else 'Ollama'} Model for chat:", model_options, index=current_index, key="chat_llm_selector" ) if selected_llm != st.session_state.current_llm: st.session_state.current_llm = selected_llm # Recreate LLM client with new model llm_client = create_llm_client( provider=current_provider, api_key=st.session_state.groq_api_key if "groq_api_key" in st.session_state else "", api_keys=settings.groq_api_keys if settings.groq_api_keys else None, model_name=selected_llm, ollama_host=settings.ollama_host, max_rpm=settings.groq_rpm_limit, rate_limit_delay=settings.rate_limit_delay ) st.session_state.rag_pipeline.llm = llm_client with col3: if st.button("đŸ—‘ī¸ Clear History"): st.session_state.chat_history = [] st.session_state.rag_pipeline.clear_history() st.rerun() # Show system prompt info in expandable section with st.expander("🤖 System Prompt & Model Info", expanded=False): current_provider = st.session_state.get("llm_provider", "groq") col1, col2 = st.columns(2) with col1: provider_icon = "â˜ī¸" if current_provider == "groq" else "đŸ–Ĩī¸" st.write(f"**Provider:** {provider_icon} {current_provider.upper()}") st.write(f"**LLM Model:** {st.session_state.current_llm}") st.write(f"**Temperature:** 0.0") st.write(f"**Max Tokens:** 2048") with col2: st.write(f"**Collection:** {st.session_state.current_collection}") st.write(f"**Dataset:** {st.session_state.get('dataset_name', 'N/A')}") if current_provider == "groq": st.write(f"**Rate Limit:** {settings.groq_rpm_limit} RPM") else: st.write(f"**Rate Limit:** ✅ Unlimited (Local)") st.markdown("#### System Prompt Being Used") system_prompt = """You are a Fact-Checking and Citation Specialist. Your task is to perform a rigorous audit of a response against provided documents to determine its accuracy, relevance, and level of support. **TASK OVERVIEW** 1. **Analyze Documents**: Review the provided documents and identify information relevant to the user's question. 2. **Evaluate Response**: Review the provided answer sentence-by-sentence. 3. **Verify Support**: Map each answer sentence to specific supporting sentences in the documents. 4. **Identify Utilization**: Determine which document sentences were actually used (directly or implicitly) to form the answer.""" st.info(system_prompt) # Chat container chat_container = st.container() # Display chat history with chat_container: for chat_idx, entry in enumerate(st.session_state.chat_history): # User message with st.chat_message("user"): st.write(entry["query"]) # Assistant message with st.chat_message("assistant"): st.write(entry["response"]) # Show retrieved documents in expander with st.expander("📄 Retrieved Documents"): for doc_idx, doc in enumerate(entry["retrieved_documents"]): st.markdown(f"**Document {doc_idx+1}** (Distance: {doc.get('distance', 'N/A'):.4f})") st.text_area( f"doc_{chat_idx}_{doc_idx}", value=doc["document"], height=100, key=f"doc_area_{chat_idx}_{doc_idx}", label_visibility="collapsed" ) if doc.get("metadata"): st.caption(f"Metadata: {doc['metadata']}") # Chat input query = st.chat_input("Ask a question...") if query: # Check if collection exists if not st.session_state.rag_pipeline or not st.session_state.rag_pipeline.vector_store.current_collection: st.error("❌ No data loaded. Please load a dataset first using the configuration panel.") st.stop() # Add user message with chat_container: with st.chat_message("user"): st.write(query) # Generate response with st.spinner("Generating response..."): try: result = st.session_state.rag_pipeline.query(query) except Exception as e: st.error(f"❌ Error querying: {str(e)}") st.info("Please load a dataset and create a collection first.") st.stop() # Add assistant message with chat_container: with st.chat_message("assistant"): st.write(result["response"]) # Show retrieved documents with st.expander("📄 Retrieved Documents"): for doc_idx, doc in enumerate(result["retrieved_documents"]): st.markdown(f"**Document {doc_idx+1}** (Distance: {doc.get('distance', 'N/A'):.4f})") st.text_area( f"doc_current_{doc_idx}", value=doc["document"], height=100, key=f"doc_current_area_{doc_idx}", label_visibility="collapsed" ) if doc.get("metadata"): st.caption(f"Metadata: {doc['metadata']}") # Store in history st.session_state.chat_history.append(result) st.rerun() def evaluation_interface(): """Evaluation interface tab.""" st.subheader("📊 RAG Evaluation") # Check if collection is loaded if not st.session_state.collection_loaded: st.warning("âš ī¸ No data loaded. Please load a collection first.") return # Evaluation method selector eval_col1, eval_col2 = st.columns([2, 1]) with eval_col1: evaluation_method = st.radio( "Evaluation Method:", options=["TRACE (Heuristic)", "GPT Labeling (LLM-based)", "Hybrid (Both)"], horizontal=True, help="TRACE is fast (no LLM). GPT Labeling is accurate but slower (requires LLM calls)." ) # Map UI labels to method IDs method_map = { "TRACE (Heuristic)": "trace", "GPT Labeling (LLM-based)": "gpt_labeling", "Hybrid (Both)": "hybrid" } selected_method = method_map[evaluation_method] # LLM selector for evaluation current_provider = st.session_state.get("llm_provider", "groq") col1, col2 = st.columns([3, 1]) with col1: # Show provider-specific models if current_provider == "groq": model_options = settings.llm_models try: current_index = settings.llm_models.index(st.session_state.current_llm) except ValueError: current_index = 0 else: model_options = settings.ollama_models try: current_index = settings.ollama_models.index(st.session_state.current_llm) except ValueError: current_index = 0 selected_llm = st.selectbox( f"Select {'Groq' if current_provider == 'groq' else 'Ollama'} Model for evaluation:", model_options, index=current_index, key="eval_llm_selector" ) # Show provider info provider_icon = "â˜ī¸" if current_provider == "groq" else "đŸ–Ĩī¸" if current_provider == "ollama": st.caption(f"{provider_icon} Using local Ollama - **No rate limits!** Fast evaluation possible.") else: st.caption(f"{provider_icon} Using Groq API - Rate limited to {settings.groq_rpm_limit} RPM") # Show method description method_descriptions = { "trace": """ **TRACE Heuristic Method** (Fast, Rule-Based) - Utilization: How well the system uses retrieved documents - Relevance: Relevance of retrieved documents to the query - Adherence: How well the response adheres to the retrieved context - Completeness: How complete the response is in answering the query - ⚡ Speed: ~100ms per evaluation - 💰 Cost: Free (no API calls) """, "gpt_labeling": """ **GPT Labeling Method** (Accurate, LLM-based) - Uses sentence-level LLM analysis (from RAGBench paper) - Context Relevance: Fraction of context relevant to query - Context Utilization: Fraction of relevant context used - Completeness: Fraction of relevant info covered - Adherence: Response supported by context (no hallucinations) - âąī¸ Speed: ~2-5 seconds per evaluation - 💰 Cost: ~$0.002-0.01 per evaluation """, "hybrid": """ **Hybrid Method** (Comprehensive) - Runs both TRACE and GPT Labeling methods - Provides both fast and accurate evaluation metrics - Best for detailed analysis - âąī¸ Speed: ~3-6 seconds per evaluation - 💰 Cost: Same as GPT Labeling """ } st.markdown(method_descriptions[selected_method]) # Get maximum test samples available for current dataset try: loader = RAGBenchLoader() max_test_samples = loader.get_test_data_size(st.session_state.dataset_name) st.caption(f"📊 Available test samples: {max_test_samples:,}") except Exception as e: max_test_samples = 100 st.caption(f"Available test samples: ~{max_test_samples} (estimated)") # Ensure min and max are reasonable max_test_samples = max(5, min(max_test_samples, 500)) # Cap at 500 for performance num_test_samples = st.slider( "Number of test samples", min_value=5, max_value=max_test_samples, value=min(10, max_test_samples), step=5 ) # Show warning for GPT labeling (API cost) - only for Groq if selected_method in ["gpt_labeling", "hybrid"]: current_provider = st.session_state.get("llm_provider", "groq") if current_provider == "groq": st.warning(f"âš ī¸ **{evaluation_method}** requires LLM API calls. This will incur costs and be slower due to rate limiting ({settings.groq_rpm_limit} RPM).") else: st.info(f"â„šī¸ **{evaluation_method}** using local Ollama - **No rate limits!** Evaluation will be much faster.") # Background evaluation toggle run_in_background = st.checkbox( "🔄 Run in Background", value=False, help="Run evaluation in background. You can close the browser and check results later." ) col_run, col_jobs = st.columns([1, 1]) with col_run: if st.button("đŸ”Ŧ Run Evaluation", type="primary"): if run_in_background: # Start background job start_background_evaluation( num_test_samples, selected_llm, selected_method ) else: # Use selected LLM for evaluation run_evaluation(num_test_samples, selected_llm, selected_method) with col_jobs: if st.button("📋 View Background Jobs"): st.session_state.show_background_jobs = True # Show background jobs panel if st.session_state.get("show_background_jobs", False): display_background_jobs() # Display results if st.session_state.evaluation_results: results = st.session_state.evaluation_results st.success("✅ Evaluation Complete!") st.divider() st.markdown("## 📊 Evaluation Metrics") # Display aggregate scores - handle both TRACE and GPT Labeling metric names st.markdown("### Main Metrics") col1, col2, col3, col4, col5 = st.columns(5) # Determine which metrics are available utilization = results.get('utilization') or results.get('context_utilization', 0) relevance = results.get('relevance') or results.get('context_relevance', 0) adherence = results.get('adherence', 0) completeness = results.get('completeness', 0) average = results.get('average', 0) with col1: st.metric("📊 Utilization", f"{utilization:.3f}") with col2: st.metric("đŸŽ¯ Relevance", f"{relevance:.3f}") with col3: st.metric("✅ Adherence", f"{adherence:.3f}") with col4: st.metric("📝 Completeness", f"{completeness:.3f}") with col5: st.metric("⭐ Average", f"{average:.3f}") # Detailed results summary - handle both metric types if "individual_scores" in results: with st.expander("📋 Summary Metrics by Query"): df = pd.DataFrame(results["individual_scores"]) st.dataframe(df, use_container_width=True) # Detailed per-query results if "detailed_results" in results and results["detailed_results"]: with st.expander("🔍 Detailed Per-Query Analysis"): for query_result in results.get("detailed_results", []): with st.expander(f"Query {query_result['query_id']}: {query_result['question'][:60]}..."): st.markdown("### Question") st.write(query_result['question']) st.markdown("### LLM Response") st.write(query_result.get('llm_response', 'N/A')) st.markdown("### Retrieved Documents") for doc_idx, doc in enumerate(query_result.get('retrieved_documents', []), 1): with st.expander(f"📄 Document {doc_idx}"): st.write(doc) if query_result.get('ground_truth'): st.markdown("### Ground Truth") st.write(query_result['ground_truth']) # Display metrics with correct labels based on method metrics = query_result.get('metrics', {}) if metrics: st.markdown("### Evaluation Metrics") col1, col2, col3, col4, col5 = st.columns(5) # Get metric values (handle both TRACE and GPT names) util_val = metrics.get('utilization') or metrics.get('context_utilization', 0) rel_val = metrics.get('relevance') or metrics.get('context_relevance', 0) adh_val = metrics.get('adherence', 0) comp_val = metrics.get('completeness', 0) avg_val = metrics.get('average', 0) with col1: st.metric("Util", f"{util_val:.3f}") with col2: st.metric("Rel", f"{rel_val:.3f}") with col3: st.metric("Adh", f"{adh_val:.3f}") with col4: st.metric("Comp", f"{comp_val:.3f}") with col5: st.metric("Avg", f"{avg_val:.3f}") # For GPT Labeling and Hybrid methods, show additional metrics method = results.get("method", "") if "gpt_labeling" in method or "hybrid" in method: # Show RMSE aggregation metrics (consistency across evaluations) if "rmse_metrics" in results: st.markdown("### 📊 RMSE Aggregation (Metric Consistency)") rmse_data = results.get("rmse_metrics", {}) rmse_cols = st.columns(4) with rmse_cols[0]: rel_mean = rmse_data.get("context_relevance", {}).get("mean", 0) rel_std = rmse_data.get("context_relevance", {}).get("std_dev", 0) st.metric("Relevance", f"{rel_mean:.3f} Âą{rel_std:.3f}", help="Mean and Std Dev") with rmse_cols[1]: util_mean = rmse_data.get("context_utilization", {}).get("mean", 0) util_std = rmse_data.get("context_utilization", {}).get("std_dev", 0) st.metric("Utilization", f"{util_mean:.3f} Âą{util_std:.3f}", help="Mean and Std Dev") with rmse_cols[2]: comp_mean = rmse_data.get("completeness", {}).get("mean", 0) comp_std = rmse_data.get("completeness", {}).get("std_dev", 0) st.metric("Completeness", f"{comp_mean:.3f} Âą{comp_std:.3f}", help="Mean and Std Dev") with rmse_cols[3]: adh_mean = rmse_data.get("adherence", {}).get("mean", 0) adh_std = rmse_data.get("adherence", {}).get("std_dev", 0) st.metric("Adherence", f"{adh_mean:.3f} Âą{adh_std:.3f}", help="Mean and Std Dev") # Show detailed RMSE statistics in expander with st.expander("See detailed RMSE aggregation statistics"): for metric_name, metric_data in rmse_data.items(): st.write(f"**{metric_name}**") col1, col2, col3, col4 = st.columns(4) with col1: st.write(f"Mean: {metric_data.get('mean', 0):.4f}") with col2: st.write(f"Std Dev: {metric_data.get('std_dev', 0):.4f}") with col3: st.write(f"Min: {metric_data.get('min', 0):.4f}") with col4: st.write(f"Max: {metric_data.get('max', 0):.4f}") # Show per-metric statistics if available if "per_metric_statistics" in results: st.markdown("### 📈 Per-Metric Statistics (Distribution)") stats_data = results.get("per_metric_statistics", {}) stats_cols = st.columns(4) with stats_cols[0]: rel_stats = stats_data.get("context_relevance", {}) st.metric("Relevance Mean", f"{rel_stats.get('mean', 0):.3f}", help=f"Median: {rel_stats.get('median', 0):.3f}") with stats_cols[1]: util_stats = stats_data.get("context_utilization", {}) st.metric("Utilization Mean", f"{util_stats.get('mean', 0):.3f}", help=f"Median: {util_stats.get('median', 0):.3f}") with stats_cols[2]: comp_stats = stats_data.get("completeness", {}) st.metric("Completeness Mean", f"{comp_stats.get('mean', 0):.3f}", help=f"Median: {comp_stats.get('median', 0):.3f}") with stats_cols[3]: adh_stats = stats_data.get("adherence", {}) st.metric("Adherence Mean", f"{adh_stats.get('mean', 0):.3f}", help=f"Median: {adh_stats.get('median', 0):.3f}") # Show detailed statistics with st.expander("See detailed per-metric statistics"): for metric_name, metric_stats in stats_data.items(): st.write(f"**{metric_name}**") col1, col2 = st.columns(2) with col1: st.write(f""" - Mean: {metric_stats.get('mean', 0):.4f} - Median: {metric_stats.get('median', 0):.4f} - Std Dev: {metric_stats.get('std_dev', 0):.4f} - Min: {metric_stats.get('min', 0):.4f} - Max: {metric_stats.get('max', 0):.4f} """) with col2: st.write(f""" - 25th percentile: {metric_stats.get('percentile_25', 0):.4f} - 75th percentile: {metric_stats.get('percentile_75', 0):.4f} - Perfect (>=0.95): {metric_stats.get('perfect_count', 0)} - Poor (<0.3): {metric_stats.get('poor_count', 0)} - Samples: {metric_stats.get('sample_count', 0)} """) # Show RMSE vs RAGBench Ground Truth (per RAGBench paper requirement) if "rmse_vs_ground_truth" in results: st.markdown("### 📉 RMSE vs RAGBench Ground Truth") st.info("Compares predicted TRACE scores against original RAGBench dataset scores") rmse_gt = results.get("rmse_vs_ground_truth", {}) per_metric_rmse = rmse_gt.get("per_metric_rmse", {}) if per_metric_rmse: rmse_gt_cols = st.columns(5) with rmse_gt_cols[0]: st.metric("Relevance RMSE", f"{per_metric_rmse.get('context_relevance', 0):.4f}", delta=None, help="Lower is better (0 = perfect match)") with rmse_gt_cols[1]: st.metric("Utilization RMSE", f"{per_metric_rmse.get('context_utilization', 0):.4f}") with rmse_gt_cols[2]: st.metric("Completeness RMSE", f"{per_metric_rmse.get('completeness', 0):.4f}") with rmse_gt_cols[3]: st.metric("Adherence RMSE", f"{per_metric_rmse.get('adherence', 0):.4f}") with rmse_gt_cols[4]: agg_rmse = rmse_gt.get("aggregated_rmse", 0) consistency = rmse_gt.get("consistency_score", 0) st.metric("Aggregated RMSE", f"{agg_rmse:.4f}", delta=f"Consistency: {consistency:.2%}", delta_color="normal") # Show AUCROC vs RAGBench Ground Truth (per RAGBench paper requirement) if "aucroc_vs_ground_truth" in results: st.markdown("### 📊 AUC-ROC vs RAGBench Ground Truth") st.info("Area Under ROC Curve comparing predicted vs ground truth binary classifications") auc_gt = results.get("aucroc_vs_ground_truth", {}) if auc_gt: auc_cols = st.columns(5) with auc_cols[0]: st.metric("Relevance AUC", f"{auc_gt.get('context_relevance', 0):.4f}", help="Higher is better (1.0 = perfect classification)") with auc_cols[1]: st.metric("Utilization AUC", f"{auc_gt.get('context_utilization', 0):.4f}") with auc_cols[2]: st.metric("Completeness AUC", f"{auc_gt.get('completeness', 0):.4f}") with auc_cols[3]: st.metric("Adherence AUC", f"{auc_gt.get('adherence', 0):.4f}") with auc_cols[4]: avg_auc = auc_gt.get("average", 0) st.metric("Average AUC", f"{avg_auc:.4f}") # Download results st.divider() st.markdown("## 💾 Download Results") # Create a comprehensive download with all details download_data = { "evaluation_metadata": { "timestamp": datetime.now().isoformat(), "dataset": st.session_state.dataset_name, "method": results.get("evaluation_config", {}).get("evaluation_method", "gpt_labeling_prompts"), "total_samples": results.get("num_samples", 0), "embedding_model": st.session_state.embedding_model, }, "aggregate_metrics": { "context_relevance": results.get("context_relevance") or results.get("relevance", 0), "context_utilization": results.get("context_utilization") or results.get("utilization", 0), "completeness": results.get("completeness", 0), "adherence": results.get("adherence", 0), "average": results.get("average", 0), }, "rmse_metrics": results.get("rmse_metrics", {}), "per_metric_statistics": results.get("per_metric_statistics", {}), "rmse_vs_ground_truth": results.get("rmse_vs_ground_truth", {}), "aucroc_vs_ground_truth": results.get("aucroc_vs_ground_truth", {}), "detailed_results": results.get("detailed_results", []) } results_json = json.dumps(download_data, indent=2, default=str) col1, col2 = st.columns(2) with col1: st.download_button( label="đŸ“Ĩ Download Complete Results (JSON)", data=results_json, file_name=f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json", mime="application/json", help="Download all evaluation results including metrics and per-query details" ) with col2: st.download_button( label="📋 Download Metrics Only (JSON)", data=json.dumps(download_data["aggregate_metrics"], indent=2), file_name=f"evaluation_metrics_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json", mime="application/json", help="Download only the aggregate metrics" ) def start_background_evaluation(num_samples: int, selected_llm: str = None, method: str = "trace"): """Start evaluation in background thread. Args: num_samples: Number of test samples to evaluate selected_llm: LLM model to use for evaluation method: Evaluation method """ if not BACKGROUND_EVAL_AVAILABLE: st.error("Background evaluation not available. Running in foreground.") run_evaluation(num_samples, selected_llm, method) return try: # Generate unique job ID job_id = f"{st.session_state.collection_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{str(uuid.uuid4())[:8]}" # Create job record job = create_job( job_id=job_id, collection_name=st.session_state.collection_name, dataset_name=st.session_state.dataset_name, num_samples=num_samples, method=method, llm_model=selected_llm or st.session_state.current_llm, embedding_model=st.session_state.embedding_model, llm_provider=st.session_state.get("llm_provider", "groq") ) # Start background thread start_background_job( job_id=job_id, rag_pipeline=st.session_state.rag_pipeline, vector_store=st.session_state.vector_store, dataset_name=st.session_state.dataset_name, num_samples=num_samples, method=method, llm_model=selected_llm or st.session_state.current_llm, embedding_model=st.session_state.embedding_model, llm_provider=st.session_state.get("llm_provider", "groq"), groq_api_key=st.session_state.get("groq_api_key"), groq_api_keys=settings.groq_api_keys if settings.groq_api_keys else None, ollama_host=settings.ollama_host ) st.success(f""" ✅ **Background evaluation started!** - **Job ID:** `{job_id}` - **Method:** {method} - **Samples:** {num_samples} You can safely close this page. Check back later using "View Background Jobs" button. """) # Show the jobs panel st.session_state.show_background_jobs = True except Exception as e: st.error(f"Failed to start background evaluation: {str(e)}") def display_background_jobs(): """Display background jobs status panel.""" st.divider() st.markdown("### 📋 Background Evaluation Jobs") if not BACKGROUND_EVAL_AVAILABLE: st.warning("Background evaluation module not available.") return # Refresh button if st.button("🔄 Refresh Status"): st.rerun() # Get all jobs jobs = list_jobs() if not jobs: st.info("No background jobs found.") if st.button("Close Panel"): st.session_state.show_background_jobs = False st.rerun() return # Display jobs for job in jobs[:10]: # Show last 10 jobs job_id = job.get("job_id", "Unknown") status = job.get("status", "unknown") progress = job.get("progress", 0) # Status icon status_icons = { "pending": "âŗ", "running": "🔄", "completed": "✅", "failed": "❌" } icon = status_icons.get(status, "❓") with st.expander(f"{icon} {job.get('collection_name', 'Unknown')} - {status.upper()} ({progress}%)", expanded=(status == "running")): col1, col2 = st.columns([2, 1]) with col1: st.write(f"**Job ID:** `{job_id}`") st.write(f"**Method:** {job.get('method', 'Unknown')}") st.write(f"**Samples:** {job.get('num_samples', 0)}") st.write(f"**LLM:** {job.get('llm_model', 'Unknown')}") st.write(f"**Created:** {job.get('created_at', 'Unknown')}") if status == "running": st.progress(progress / 100) st.write(f"**Current Step:** {job.get('current_step', '')}") if status == "completed": st.write(f"**Completed:** {job.get('completed_at', 'Unknown')}") results_file = job.get('results_file') if results_file: st.write(f"**Results File:** `{results_file}`") # Load results into session state if st.button(f"📊 Load Results", key=f"load_{job_id}"): results = job.get("results") if results: st.session_state.evaluation_results = results st.success("Results loaded! Scroll down to see metrics.") st.rerun() if status == "failed": st.error(f"**Error:** {job.get('error', 'Unknown error')}") with col2: # Show logs logs = job.get("logs", []) if logs: st.write("**Recent Logs:**") for log in logs[-5:]: # Last 5 logs st.caption(f"{log.get('message', '')}") # Delete button if st.button(f"đŸ—‘ī¸ Delete", key=f"del_{job_id}"): delete_job(job_id) st.success(f"Job {job_id} deleted.") st.rerun() if st.button("Close Panel"): st.session_state.show_background_jobs = False st.rerun() def run_evaluation(num_samples: int, selected_llm: str = None, method: str = "trace"): """Run evaluation using selected method (TRACE, GPT Labeling, or Hybrid). Args: num_samples: Number of test samples to evaluate selected_llm: LLM model to use for evaluation method: Evaluation method ("trace", "gpt_labeling", or "hybrid") """ with st.spinner(f"Running evaluation on {num_samples} samples..."): try: # Create logs container logs_container = st.container() logs_list = [] # Display logs header once outside function logs_placeholder = st.empty() def add_log(message: str): """Add log message and update display.""" logs_list.append(message) with logs_placeholder.container(): st.markdown("### 📋 Evaluation Logs:") for log_msg in logs_list: st.caption(log_msg) # Log evaluation start timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") add_log(f"âąī¸ Evaluation started at {timestamp}") add_log(f"📊 Dataset: {st.session_state.dataset_name}") add_log(f"📈 Total samples: {num_samples}") add_log(f"🤖 LLM Model: {selected_llm if selected_llm else st.session_state.current_llm}") add_log(f"🔗 Vector Store: {st.session_state.collection_name}") add_log(f"🧠 Embedding Model: {st.session_state.embedding_model}") # Map method names method_names = { "trace": "TRACE (Heuristic)", "gpt_labeling": "GPT Labeling (LLM-based)", "hybrid": "Hybrid (Both)" } add_log(f"đŸ”Ŧ Evaluation Method: {method_names.get(method, method)}") # Use selected LLM if provided - create with appropriate provider eval_llm_client = None original_llm = None current_provider = st.session_state.get("llm_provider", "groq") if selected_llm and selected_llm != st.session_state.current_llm: add_log(f"🔄 Switching LLM to {selected_llm} ({current_provider.upper()})...") groq_api_key = st.session_state.groq_api_key if "groq_api_key" in st.session_state else "" eval_llm_client = create_llm_client( provider=current_provider, api_key=groq_api_key, api_keys=settings.groq_api_keys if settings.groq_api_keys else None, model_name=selected_llm, ollama_host=settings.ollama_host, max_rpm=settings.groq_rpm_limit, rate_limit_delay=settings.rate_limit_delay, max_retries=settings.max_retries, retry_delay=settings.retry_delay ) # Temporarily replace LLM client original_llm = st.session_state.rag_pipeline.llm st.session_state.rag_pipeline.llm = eval_llm_client else: eval_llm_client = st.session_state.rag_pipeline.llm # Log provider info provider_icon = "â˜ī¸" if current_provider == "groq" else "đŸ–Ĩī¸" add_log(f"{provider_icon} LLM Provider: {current_provider.upper()}") # Get test data add_log("đŸ“Ĩ Loading test data...") loader = RAGBenchLoader() test_data = loader.get_test_data( st.session_state.dataset_name, num_samples ) add_log(f"✅ Loaded {len(test_data)} test samples") # Prepare test cases test_cases = [] progress_bar = st.progress(0) status_text = st.empty() add_log("🔍 Processing samples...") for i, sample in enumerate(test_data): status_text.text(f"Processing sample {i+1}/{num_samples}") # Query the RAG system result = st.session_state.rag_pipeline.query( sample["question"], n_results=5 ) # Prepare test case test_cases.append({ "query": sample["question"], "response": result["response"], "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]], "ground_truth": sample.get("answer", "") }) # Update progress progress_bar.progress((i + 1) / num_samples) # Log every 10 samples if (i + 1) % 10 == 0 or (i + 1) == num_samples: add_log(f" ✓ Processed {i + 1}/{num_samples} samples") status_text.text(f"Running {method_names.get(method, method)} evaluation...") add_log(f"📊 Running evaluation using {method_names.get(method, method)}...") # Extract chunking and embedding metadata from session state # (These were stored when the collection was loaded/created) chunking_strategy = st.session_state.vector_store.chunking_strategy if st.session_state.vector_store else None embedding_model = st.session_state.embedding_model chunk_size = st.session_state.vector_store.chunk_size if st.session_state.vector_store else None chunk_overlap = st.session_state.vector_store.chunk_overlap if st.session_state.vector_store else None # Log retrieval configuration add_log(f"🔧 Retrieval Configuration:") add_log(f" â€ĸ Chunking Strategy: {chunking_strategy or 'Unknown'}") add_log(f" â€ĸ Chunk Size: {chunk_size or 'Unknown'}") add_log(f" â€ĸ Chunk Overlap: {chunk_overlap or 'Unknown'}") add_log(f" â€ĸ Embedding Model: {embedding_model or 'Unknown'}") # Import unified pipeline try: from evaluation_pipeline import UnifiedEvaluationPipeline # Run evaluation with metadata using unified pipeline pipeline = UnifiedEvaluationPipeline( llm_client=eval_llm_client, chunking_strategy=chunking_strategy, embedding_model=embedding_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap ) # Run evaluation with selected method results = pipeline.evaluate_batch(test_cases, method=method) except ImportError: # Fallback to TRACE only if evaluation_pipeline module not available add_log("âš ī¸ evaluation_pipeline module not found, falling back to TRACE...") # Run evaluation with metadata using TRACE evaluator = TRACEEvaluator( chunking_strategy=chunking_strategy, embedding_model=embedding_model, chunk_size=chunk_size, chunk_overlap=chunk_overlap ) results = evaluator.evaluate_batch(test_cases) st.session_state.evaluation_results = results # Log evaluation results summary add_log("✅ Evaluation completed successfully!") # Display appropriate metrics based on method if method == "trace": add_log(f" â€ĸ Utilization: {results.get('utilization', 0):.2%}") add_log(f" â€ĸ Relevance: {results.get('relevance', 0):.2%}") add_log(f" â€ĸ Adherence: {results.get('adherence', 0):.2%}") add_log(f" â€ĸ Completeness: {results.get('completeness', 0):.2%}") add_log(f" â€ĸ Average: {results.get('average', 0):.2%}") elif method == "gpt_labeling": if "context_relevance" in results: add_log(f" â€ĸ Context Relevance: {results.get('context_relevance', 0):.2%}") add_log(f" â€ĸ Context Utilization: {results.get('context_utilization', 0):.2%}") add_log(f" â€ĸ Completeness: {results.get('completeness', 0):.2%}") add_log(f" â€ĸ Adherence: {results.get('adherence', 0):.2%}") add_log(f" â€ĸ Average: {results.get('average', 0):.2%}") # NEW: Display RMSE and AUCROC metrics if available if "rmse_metrics" in results: add_log(f"📈 RMSE Metrics (vs ground truth):") rmse_metrics = results.get("rmse_metrics", {}) add_log(f" â€ĸ Context Relevance RMSE: {rmse_metrics.get('relevance', 0):.4f}") add_log(f" â€ĸ Context Utilization RMSE: {rmse_metrics.get('utilization', 0):.4f}") add_log(f" â€ĸ Completeness RMSE: {rmse_metrics.get('completeness', 0):.4f}") add_log(f" â€ĸ Adherence RMSE: {rmse_metrics.get('adherence', 0):.4f}") add_log(f" â€ĸ Average RMSE: {rmse_metrics.get('average', 0):.4f}") if "auc_metrics" in results: add_log(f"📊 AUCROC Metrics (binary classification):") auc_metrics = results.get("auc_metrics", {}) add_log(f" â€ĸ Context Relevance AUCROC: {auc_metrics.get('relevance', 0):.4f}") add_log(f" â€ĸ Context Utilization AUCROC: {auc_metrics.get('utilization', 0):.4f}") add_log(f" â€ĸ Completeness AUCROC: {auc_metrics.get('completeness', 0):.4f}") add_log(f" â€ĸ Adherence AUCROC: {auc_metrics.get('adherence', 0):.4f}") add_log(f" â€ĸ Average AUCROC: {auc_metrics.get('average', 0):.4f}") elif method == "hybrid": add_log(" 📊 TRACE Metrics:") trace_res = results.get("trace_results", {}) add_log(f" â€ĸ Utilization: {trace_res.get('utilization', 0):.2%}") add_log(f" â€ĸ Relevance: {trace_res.get('relevance', 0):.2%}") add_log(f" â€ĸ Adherence: {trace_res.get('adherence', 0):.2%}") add_log(f" â€ĸ Completeness: {trace_res.get('completeness', 0):.2%}") add_log(" 🧠 GPT Labeling Metrics:") gpt_res = results.get("gpt_results", {}) add_log(f" â€ĸ Context Relevance: {gpt_res.get('context_relevance', 0):.2%}") add_log(f" â€ĸ Context Utilization: {gpt_res.get('context_utilization', 0):.2%}") add_log(f" â€ĸ Completeness: {gpt_res.get('completeness', 0):.2%}") add_log(f" â€ĸ Adherence: {gpt_res.get('adherence', 0):.2%}") # Restore original LLM if it was switched if selected_llm and selected_llm != st.session_state.current_llm and original_llm: st.session_state.rag_pipeline.llm = original_llm add_log(f"🔄 Restored original LLM") add_log(f"âąī¸ Evaluation completed at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") # Auto-save evaluation results to JSON file try: collection_name = st.session_state.get("collection_name", "unknown_collection") timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') filename = f"{collection_name}_{timestamp}_evaluation.json" # Prepare save data with metadata save_data = { "collection_name": collection_name, "timestamp": datetime.now().isoformat(), "evaluation_method": method, "num_samples": num_samples, "embedding_model": st.session_state.get("embedding_model", "unknown"), "llm_model": st.session_state.get("current_llm", "unknown"), "results": results } # Save to file with open(filename, 'w', encoding='utf-8') as f: json.dump(save_data, f, indent=2, default=str) add_log(f"💾 Results auto-saved to: {filename}") st.success(f"💾 Results saved to: {filename}") except Exception as save_error: add_log(f"âš ī¸ Could not auto-save results: {str(save_error)}") except Exception as e: st.error(f"Error during evaluation: {str(e)}") add_log(f"❌ Error: {str(e)}") def history_interface(): """History interface tab.""" st.subheader("📜 Chat History") if not st.session_state.chat_history: st.info("No chat history yet. Start a conversation in the Chat tab!") return # Export history col1, col2 = st.columns([3, 1]) with col2: history_json = json.dumps(st.session_state.chat_history, indent=2) st.download_button( label="💾 Export History", data=history_json, file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json", mime="application/json" ) # Display history for i, entry in enumerate(st.session_state.chat_history): with st.expander(f"đŸ’Ŧ Conversation {i+1}: {entry['query'][:50]}..."): st.markdown(f"**Query:** {entry['query']}") st.markdown(f"**Response:** {entry['response']}") st.markdown(f"**Timestamp:** {entry.get('timestamp', 'N/A')}") st.markdown("**Retrieved Documents:**") for j, doc in enumerate(entry["retrieved_documents"]): st.text_area( f"Document {j+1}", value=doc["document"], height=100, key=f"history_doc_{i}_{j}" ) if __name__ == "__main__": main()