# NOTE: removed HuggingFace Spaces page-status residue ("Spaces: Sleeping")
# that leaked into this file during export; it is not part of the program.
| """Streamlit chat interface for RAG application.""" | |
| import streamlit as st | |
| import sys | |
| import os | |
| from datetime import datetime | |
| import json | |
| import pandas as pd | |
| from typing import Optional | |
| import warnings | |
| # Suppress warnings | |
| warnings.filterwarnings('ignore') | |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | |
| # Add parent directory to path | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
| # Check if running on HuggingFace Spaces | |
| IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None | |
| from config import settings | |
| from dataset_loader import RAGBenchLoader | |
| from vector_store import ChromaDBManager, create_vector_store | |
| try: | |
| from vector_store import QdrantManager, QDRANT_AVAILABLE | |
| except ImportError: | |
| QDRANT_AVAILABLE = False | |
| from llm_client import GroqLLMClient, OllamaLLMClient, RAGPipeline, create_llm_client | |
| from trace_evaluator import TRACEEvaluator | |
| from embedding_models import EmbeddingFactory | |
| from chunking_strategies import ChunkingFactory | |
| import uuid | |
| # Background evaluation support | |
| try: | |
| from background_evaluator import ( | |
| create_job, start_background_job, load_job_status, | |
| list_jobs, delete_job, STATUS_PENDING, STATUS_RUNNING, | |
| STATUS_COMPLETED, STATUS_FAILED | |
| ) | |
| BACKGROUND_EVAL_AVAILABLE = True | |
| except ImportError: | |
| BACKGROUND_EVAL_AVAILABLE = False | |
| # Page configuration | |
| st.set_page_config( | |
| page_title="RAG Capstone Project", | |
| page_icon="π€", | |
| layout="wide" | |
| ) | |
# Initialize session state.
# Only keys that are missing are set, so user state survives Streamlit reruns.
# Consolidated from 19 copy-pasted `if key not in st.session_state` blocks.
_SESSION_DEFAULTS = {
    "chat_history": [],                 # list of query/response result dicts
    "rag_pipeline": None,               # RAGPipeline once a collection is loaded
    "vector_store": None,               # ChromaDBManager or Qdrant manager
    "collection_loaded": False,         # gates the chat/eval/history tabs
    "evaluation_results": None,
    "dataset_size": 10000,              # default until "Check Dataset Size" runs
    "current_dataset": None,
    "current_llm": settings.llm_models[1],
    "selected_collection": None,
    "available_collections": [],
    "dataset_name": None,
    "collection_name": None,
    "embedding_model": None,
    "groq_api_key": "",
    "llm_provider": settings.llm_provider,
    "ollama_model": settings.ollama_model,
    "vector_store_provider": settings.vector_store_provider,
    "qdrant_url": settings.qdrant_url,
    "qdrant_api_key": settings.qdrant_api_key,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
def get_available_collections(provider: str = None):
    """Return the list of collection names from the configured vector store.

    Args:
        provider: "qdrant" or "chroma". Defaults to the provider stored in
            session state (falling back to "chroma").

    Returns:
        List of collection names. Empty list when credentials are missing,
        the Qdrant client is not installed, or the lookup fails.
    """
    provider = provider or st.session_state.get("vector_store_provider", "chroma")
    try:
        if provider == "qdrant":
            # Fix: previously an unavailable Qdrant client fell through to the
            # ChromaDB branch and silently listed the *wrong* store's
            # collections. Report "none" instead.
            if not QDRANT_AVAILABLE:
                return []
            qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url
            qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key
            if qdrant_url and qdrant_api_key:
                vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
                return vector_store.list_collections()
            return []
        # ChromaDB needs no credentials; it reads the local persist directory.
        vector_store = ChromaDBManager(settings.chroma_persist_directory)
        return vector_store.list_collections()
    except Exception as e:
        # Collection discovery is best-effort; the UI treats [] as "none found".
        print(f"Error getting collections: {e}")
        return []
def main():
    """Main Streamlit application.

    Renders the sidebar configuration (LLM provider, vector store, dataset /
    chunking / embedding / LLM choices) and either onboarding instructions or
    the Chat / Evaluation / History tabs once a collection is loaded.
    """
    st.title("π€ RAG Capstone Project")
    st.markdown("### Retrieval-Augmented Generation with TRACE Evaluation")
    # Show HuggingFace Spaces notice
    if IS_HUGGINGFACE_SPACE:
        st.info("π€ Running on Hugging Face Spaces - Using Groq API (cloud-based LLM)")
    # Get available collections at startup
    available_collections = get_available_collections()
    st.session_state.available_collections = available_collections
    # Sidebar for configuration
    with st.sidebar:
        st.header("Configuration")
        # LLM Provider Selection - Disable Ollama on HuggingFace Spaces
        st.subheader("π LLM Provider")
        if IS_HUGGINGFACE_SPACE:
            # Force Groq on HuggingFace Spaces (Ollama not available)
            st.caption("βοΈ **Groq API** (Ollama unavailable on Spaces)")
            llm_provider = "groq"
            st.session_state.llm_provider = "groq"
        else:
            llm_provider = st.radio(
                "Choose LLM Provider:",
                options=["groq", "ollama"],
                index=0 if st.session_state.llm_provider == "groq" else 1,
                format_func=lambda x: "βοΈ Groq API (Cloud)" if x == "groq" else "π₯οΈ Ollama (Local)",
                help="Groq: Cloud API with rate limits. Ollama: Local unlimited inference.",
                key="llm_provider_radio"
            )
            st.session_state.llm_provider = llm_provider
        # Provider-specific settings
        if llm_provider == "groq":
            st.caption("β οΈ Free tier: 30 requests/min")
            # On HuggingFace Spaces, check for API key in secrets first
            default_api_key = os.environ.get("GROQ_API_KEY", "") or settings.groq_api_key or ""
            # API Key input
            groq_api_key = st.text_input(
                "Groq API Key",
                type="password",
                value=default_api_key,
                help="Enter your Groq API key (or set GROQ_API_KEY in Spaces secrets)"
            )
            if IS_HUGGINGFACE_SPACE and not groq_api_key:
                st.warning("π‘ Tip: Add GROQ_API_KEY to your Space secrets for persistence")
        else:
            # Ollama settings (only available locally)
            st.caption("β No rate limits - unlimited usage!")
            ollama_host = st.text_input(
                "Ollama Host",
                value=settings.ollama_host,
                help="Ollama server URL (default: http://localhost:11434)"
            )
            ollama_model = st.selectbox(
                "Select Ollama Model:",
                options=settings.ollama_models,
                index=settings.ollama_models.index(st.session_state.ollama_model) if st.session_state.ollama_model in settings.ollama_models else 0,
                key="ollama_model_selector"
            )
            st.session_state.ollama_model = ollama_model
            # Connection check button
            if st.button("π Check Ollama Connection"):
                try:
                    import requests
                    response = requests.get(f"{ollama_host}/api/tags", timeout=5)
                    if response.status_code == 200:
                        models = response.json().get("models", [])
                        model_names = [m["name"] for m in models]
                        st.success(f"β Connected! Available models: {', '.join(model_names)}")
                    else:
                        st.error(f"β Connection failed: {response.status_code}")
                except Exception as e:
                    st.error(f"β Cannot connect to Ollama: {e}")
                    st.info("Make sure Ollama is running: `ollama serve`")
            groq_api_key = ""  # Not needed for Ollama
        st.divider()
        # Vector Store Provider Selection
        st.subheader("πΎ Vector Store")
        if IS_HUGGINGFACE_SPACE:
            # Qdrant listed first on Spaces because local disk is ephemeral there.
            st.caption("βοΈ Use **Qdrant Cloud** for persistent storage")
            vector_store_options = ["qdrant", "chroma"]
            default_idx = 0
        else:
            vector_store_options = ["chroma", "qdrant"]
            default_idx = 0
        vector_store_provider = st.radio(
            "Choose Vector Store:",
            options=vector_store_options,
            index=default_idx,
            format_func=lambda x: "βοΈ Qdrant Cloud (Persistent)" if x == "qdrant" else "πΎ ChromaDB (Local)",
            help="Qdrant: Cloud storage (persistent). ChromaDB: Local storage (ephemeral on Spaces).",
            key="vector_store_radio"
        )
        st.session_state.vector_store_provider = vector_store_provider
        # Qdrant settings
        if vector_store_provider == "qdrant":
            default_qdrant_url = os.environ.get("QDRANT_URL", "") or settings.qdrant_url
            default_qdrant_key = os.environ.get("QDRANT_API_KEY", "") or settings.qdrant_api_key
            qdrant_url = st.text_input(
                "Qdrant URL",
                value=default_qdrant_url,
                placeholder="https://xxx-xxx.aws.cloud.qdrant.io:6333",
                help="Your Qdrant Cloud cluster URL"
            )
            qdrant_api_key = st.text_input(
                "Qdrant API Key",
                type="password",
                value=default_qdrant_key,
                help="Your Qdrant API key"
            )
            st.session_state.qdrant_url = qdrant_url
            st.session_state.qdrant_api_key = qdrant_api_key
            if not qdrant_url or not qdrant_api_key:
                st.warning("β οΈ Get free Qdrant Cloud at: https://cloud.qdrant.io")
            # Test Qdrant connection
            if st.button("π Test Qdrant Connection"):
                if qdrant_url and qdrant_api_key:
                    try:
                        test_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
                        collections = test_store.list_collections()
                        st.success(f"β Connected! Found {len(collections)} collection(s)")
                    except Exception as e:
                        st.error(f"β Connection failed: {e}")
                else:
                    st.error("Please enter Qdrant URL and API Key")
        st.divider()
        # Get available collections based on provider
        # (refreshed here because the provider radio may have just changed)
        available_collections = get_available_collections(vector_store_provider)
        st.session_state.available_collections = available_collections
        # Option 1: Use existing collection
        if available_collections:
            st.subheader("π Existing Collections")
            st.write(f"Found {len(available_collections)} collection(s)")
            selected_collection = st.selectbox(
                "Or select existing collection:",
                available_collections,
                key="collection_selector"
            )
            if st.button("π Load Existing Collection", type="secondary"):
                # Validate based on provider
                if llm_provider == "groq" and not groq_api_key:
                    st.error("Please enter your Groq API key")
                elif vector_store_provider == "qdrant" and (not st.session_state.get("qdrant_url") or not st.session_state.get("qdrant_api_key")):
                    st.error("Please enter Qdrant URL and API Key")
                else:
                    # `ollama_host` is only defined (and only evaluated) on the
                    # ollama branch; the conditional expression is lazy.
                    load_existing_collection(
                        groq_api_key,
                        selected_collection,
                        llm_provider,
                        ollama_host if llm_provider == "ollama" else None,
                        vector_store_provider
                    )
            st.divider()
        # Option 2: Create new collection
        st.subheader("π Create New Collection")
        # Dataset selection
        st.subheader("1. Dataset Selection")
        dataset_name = st.selectbox(
            "Choose Dataset",
            settings.ragbench_datasets,
            index=0
        )
        # Get dataset size dynamically
        if st.button("π Check Dataset Size", key="check_size"):
            with st.spinner("Checking dataset size..."):
                try:
                    from datasets import load_dataset
                    # Load dataset with download_mode to avoid cache issues
                    st.info(f"Fetching dataset info for '{dataset_name}'...")
                    ds = load_dataset(
                        "rungalileo/ragbench",
                        dataset_name,
                        split="train",
                        trust_remote_code=True,
                        download_mode="force_redownload"  # Force fresh download to avoid cache corruption
                    )
                    dataset_size = len(ds)
                    st.session_state.dataset_size = dataset_size
                    st.session_state.current_dataset = dataset_name
                    st.success(f"β Dataset '{dataset_name}' has {dataset_size:,} samples available")
                except Exception as e:
                    st.error(f"β Error: {str(e)}")
                    st.exception(e)
                    st.warning(f"Could not determine dataset size. Using default of 10,000.")
                    st.session_state.dataset_size = 10000
                    st.session_state.current_dataset = dataset_name
        # Use stored dataset size or default
        max_samples_available = st.session_state.get('dataset_size', 10000)
        st.caption(f"Max available samples: {max_samples_available:,}")
        num_samples = st.slider(
            "Number of samples",
            min_value=10,
            max_value=max_samples_available,
            value=min(100, max_samples_available),
            step=50 if max_samples_available > 1000 else 10,
            help="Adjust slider to select number of samples"
        )
        load_all_samples = st.checkbox(
            "Load all available samples",
            value=False,
            help="Override slider and load entire dataset"
        )
        st.divider()
        # Chunking strategy
        st.subheader("2. Chunking Strategy")
        chunking_strategy = st.selectbox(
            "Choose Chunking Strategy",
            settings.chunking_strategies,
            index=0
        )
        chunk_size = st.slider(
            "Chunk Size",
            min_value=256,
            max_value=1024,
            value=512,
            step=128
        )
        overlap = st.slider(
            "Overlap",
            min_value=0,
            max_value=200,
            value=50,
            step=10
        )
        st.divider()
        # Embedding model
        st.subheader("3. Embedding Model")
        embedding_model = st.selectbox(
            "Choose Embedding Model",
            settings.embedding_models,
            index=0
        )
        st.divider()
        # LLM model selection for new collection
        st.subheader("4. LLM Model")
        if llm_provider == "groq":
            llm_model = st.selectbox(
                "Choose Groq LLM",
                settings.llm_models,
                index=1
            )
        else:
            llm_model = st.selectbox(
                "Choose Ollama Model",
                settings.ollama_models,
                index=settings.ollama_models.index(st.session_state.ollama_model) if st.session_state.ollama_model in settings.ollama_models else 0,
                key="llm_model_ollama"
            )
        st.divider()
        # Load data button
        if st.button("π Load Data & Create Collection", type="primary"):
            # Validate based on provider
            if llm_provider == "groq" and not groq_api_key:
                st.error("Please enter your Groq API key")
            elif vector_store_provider == "qdrant" and (not st.session_state.get("qdrant_url") or not st.session_state.get("qdrant_api_key")):
                st.error("Please enter Qdrant URL and API Key")
            else:
                # Use None for num_samples if loading all data
                samples_to_load = None if load_all_samples else num_samples
                load_and_create_collection(
                    groq_api_key,
                    dataset_name,
                    samples_to_load,
                    chunking_strategy,
                    chunk_size,
                    overlap,
                    embedding_model,
                    llm_model,
                    llm_provider,
                    ollama_host if llm_provider == "ollama" else None,
                    vector_store_provider
                )
    # Main content area
    if not st.session_state.collection_loaded:
        st.info("π Please configure and load a dataset from the sidebar to begin")
        # Show instructions
        with st.expander("π How to Use", expanded=True):
            st.markdown("""
1. **Enter your Groq API Key** in the sidebar
2. **Select a dataset** from RAG Bench
3. **Choose a chunking strategy** (dense, sparse, hybrid, re-ranking)
4. **Select an embedding model** for document vectorization
5. **Choose an LLM model** for response generation
6. **Click "Load Data & Create Collection"** to initialize
7. **Start chatting** in the chat interface
8. **View retrieved documents** and evaluation metrics
9. **Run TRACE evaluation** on test data
""")
        # Show available options
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("π Available Datasets")
            for ds in settings.ragbench_datasets:
                st.markdown(f"- {ds}")
        with col2:
            st.subheader("π€ Available Models")
            st.markdown("**Embedding Models:**")
            for em in settings.embedding_models:
                st.markdown(f"- {em}")
            st.markdown("**LLM Models:**")
            for lm in settings.llm_models:
                st.markdown(f"- {lm}")
    else:
        # Create tabs for different functionalities
        tab1, tab2, tab3 = st.tabs(["π¬ Chat", "π Evaluation", "π History"])
        with tab1:
            chat_interface()
        with tab2:
            evaluation_interface()
        with tab3:
            history_interface()
def load_existing_collection(api_key: str, collection_name: str, llm_provider: str = "groq", ollama_host: str = None, vector_store_provider: str = "chroma"):
    """Load an existing collection from vector store.

    Wires up the vector store, LLM client, and RAG pipeline for a previously
    created collection, stores all handles in session state, then reruns the
    app so the chat tabs appear.

    Args:
        api_key: Groq API key (unused for Ollama).
        collection_name: Name of the existing collection to load.
        llm_provider: "groq" or "ollama".
        ollama_host: Ollama server URL (only used for the "ollama" provider).
        vector_store_provider: "qdrant" or "chroma".
    """
    with st.spinner(f"Loading collection '{collection_name}'..."):
        try:
            # Initialize vector store based on provider
            if vector_store_provider == "qdrant":
                qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url
                qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key
                vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
            else:
                vector_store = ChromaDBManager(settings.chroma_persist_directory)
            vector_store.get_collection(collection_name)
            # Extract dataset name from collection name (format: dataset_name_strategy_model)
            # Try to find which dataset this collection is based on
            dataset_name = None
            for ds in settings.ragbench_datasets:
                if collection_name.startswith(ds.replace("-", "_")):
                    dataset_name = ds
                    break
            if not dataset_name:
                dataset_name = collection_name.split("_")[0]  # Fallback: use first part
            # Prompt for LLM selection based on provider
            # NOTE(review): this selectbox is created inside a button-triggered
            # run and the function ends with st.rerun(); the widget likely
            # shows only for one script run — confirm the intended UX.
            if llm_provider == "groq":
                st.session_state.current_llm = st.selectbox(
                    "Select Groq LLM for this collection:",
                    settings.llm_models,
                    key=f"llm_selector_{collection_name}"
                )
            else:
                st.session_state.current_llm = st.selectbox(
                    "Select Ollama Model for this collection:",
                    settings.ollama_models,
                    key=f"ollama_selector_{collection_name}"
                )
            # Initialize LLM client based on provider
            st.info(f"Initializing LLM client ({llm_provider})...")
            llm_client = create_llm_client(
                provider=llm_provider,
                api_key=api_key,
                api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
                model_name=st.session_state.current_llm,
                ollama_host=ollama_host or settings.ollama_host,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay,
                max_retries=settings.max_retries,
                retry_delay=settings.retry_delay
            )
            # Create RAG pipeline with correct parameter names
            st.info("Creating RAG pipeline...")
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )
            # Store in session state
            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.selected_collection = collection_name
            st.session_state.groq_api_key = api_key
            st.session_state.dataset_name = dataset_name
            st.session_state.collection_name = collection_name
            st.session_state.llm_provider = llm_provider
            # Extract embedding model from collection metadata or vector store
            embedding_model_name = None
            # For ChromaDB: check collection metadata
            if hasattr(vector_store, 'current_collection') and vector_store.current_collection:
                if hasattr(vector_store.current_collection, 'metadata'):
                    collection_metadata = vector_store.current_collection.metadata
                    if collection_metadata and "embedding_model" in collection_metadata:
                        embedding_model_name = collection_metadata["embedding_model"]
            # For Qdrant or fallback: check if embedding_model was loaded on the vector store
            if not embedding_model_name and hasattr(vector_store, 'embedding_model') and vector_store.embedding_model:
                if hasattr(vector_store.embedding_model, 'model_name'):
                    embedding_model_name = vector_store.embedding_model.model_name
            # Set session state (None when the model could not be determined)
            if embedding_model_name:
                st.session_state.embedding_model = embedding_model_name
            else:
                st.session_state.embedding_model = None
            # Display system prompt and model info
            provider_icon = "βοΈ" if llm_provider == "groq" else "π₯οΈ"
            st.success(f"β Collection '{collection_name}' loaded successfully! {provider_icon} Using {llm_provider.upper()}")
            with st.expander("π€ Model & System Prompt Information", expanded=False):
                col1, col2 = st.columns(2)
                with col1:
                    st.write(f"**Provider:** {provider_icon} {llm_provider.upper()}")
                    st.write(f"**Model:** {st.session_state.current_llm}")
                    st.write(f"**Collection:** {collection_name}")
                    st.write(f"**Dataset:** {dataset_name}")
                with col2:
                    st.write(f"**Temperature:** 0.0")
                    st.write(f"**Max Tokens:** 2048")
                    if llm_provider == "groq":
                        st.write(f"**Rate Limit:** {settings.groq_rpm_limit} RPM")
                    else:
                        st.write(f"**Rate Limit:** β Unlimited (Local)")
                st.markdown("#### System Prompt")
                st.info("""
You are a Fact-Checking and Citation Specialist. Your task is to perform a rigorous audit of a response against provided documents to determine its accuracy, relevance, and level of support.
**Task:**
1. Analyze the provided documents and identify information relevant to the user's question
2. Evaluate the response sentence-by-sentence
3. Verify each response sentence maps to supporting document sentences
4. Identify which document sentences were actually used in the response
""")
            st.rerun()
        except Exception as e:
            st.error(f"Error loading collection: {str(e)}")
            st.exception(e)
def load_and_create_collection(
    api_key: str,
    dataset_name: str,
    num_samples: Optional[int],
    chunking_strategy: str,
    chunk_size: int,
    overlap: int,
    embedding_model: str,
    llm_model: str,
    llm_provider: str = "groq",
    ollama_host: str = None,
    vector_store_provider: str = "chroma"
):
    """Load a RAGBench dataset and build a fresh vector collection for it.

    Creates (or recreates) a collection named
    ``{dataset}_{strategy}_{embedding-model}``, loads the chunked and embedded
    documents into it, wires up the LLM client and RAG pipeline, stores all
    handles in session state, then reruns the app.

    Args:
        api_key: Groq API key (unused for Ollama).
        dataset_name: RAGBench subset to load.
        num_samples: Number of samples to load, or None for all available.
        chunking_strategy: One of settings.chunking_strategies.
        chunk_size: Target chunk size passed to the chunker.
        overlap: Overlap between consecutive chunks.
        embedding_model: Embedding model identifier (may contain '/').
        llm_model: LLM model name for the chosen provider.
        llm_provider: "groq" or "ollama".
        ollama_host: Ollama server URL (only used for the "ollama" provider).
        vector_store_provider: "qdrant" or "chroma".
    """
    with st.spinner("Loading dataset and creating collection..."):
        try:
            # Initialize dataset loader
            loader = RAGBenchLoader()
            # Load dataset (num_samples=None means "everything available")
            if num_samples is None:
                st.info(f"Loading {dataset_name} dataset (all available samples)...")
            else:
                st.info(f"Loading {dataset_name} dataset ({num_samples} samples)...")
            dataset = loader.load_dataset(dataset_name, split="train", max_samples=num_samples)
            if not dataset:
                st.error("Failed to load dataset")
                return
            # Initialize vector store based on provider
            st.info(f"Initializing vector store ({vector_store_provider})...")
            if vector_store_provider == "qdrant":
                qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url
                qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key
                vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
            else:
                vector_store = ChromaDBManager(settings.chroma_persist_directory)
            # Collection name encodes dataset, strategy, and embedding model;
            # normalize '-' and '.' which vector stores commonly reject.
            collection_name = f"{dataset_name}_{chunking_strategy}_{embedding_model.split('/')[-1]}"
            collection_name = collection_name.replace("-", "_").replace(".", "_")
            # Delete existing collection with same name (if exists)
            existing_collections = vector_store.list_collections()
            if collection_name in existing_collections:
                st.warning(f"Collection '{collection_name}' already exists. Deleting and recreating...")
                vector_store.delete_collection(collection_name)
                st.info("Old collection deleted. Creating new one...")
            # Load data into collection
            st.info(f"Creating collection with {chunking_strategy} chunking...")
            vector_store.load_dataset_into_collection(
                collection_name=collection_name,
                embedding_model_name=embedding_model,
                chunking_strategy=chunking_strategy,
                dataset_data=dataset,
                chunk_size=chunk_size,
                overlap=overlap
            )
            # Initialize LLM client based on provider
            st.info(f"Initializing LLM client ({llm_provider})...")
            llm_client = create_llm_client(
                provider=llm_provider,
                api_key=api_key,
                api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
                model_name=llm_model,
                ollama_host=ollama_host or settings.ollama_host,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay,
                max_retries=settings.max_retries,
                retry_delay=settings.retry_delay
            )
            # Create RAG pipeline with correct parameter names
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )
            # Store in session state
            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.dataset_name = dataset_name
            st.session_state.dataset = dataset
            st.session_state.collection_name = collection_name
            st.session_state.embedding_model = embedding_model
            st.session_state.groq_api_key = api_key
            st.session_state.llm_provider = llm_provider
            st.session_state.vector_store_provider = vector_store_provider
            # (removed unused local `vs_icon` — it was computed but never used)
            provider_icon = "βοΈ" if llm_provider == "groq" else "π₯οΈ"
            st.success(f"β Collection '{collection_name}' created successfully! {provider_icon} Using {llm_provider.upper()}")
            st.rerun()
        except Exception as e:
            st.error(f"Error: {str(e)}")
            # Show the full traceback, consistent with load_existing_collection().
            st.exception(e)
def chat_interface():
    """Render the Chat tab: model selector, chat history, and query input.

    Requires a loaded collection (st.session_state.collection_loaded); all
    queries are answered through st.session_state.rag_pipeline.

    Fix: the retrieved-document headers previously formatted
    ``doc.get('distance', 'N/A')`` with ``:.4f``, which raises ValueError
    whenever 'distance' is missing (format spec applied to the 'N/A' string).
    Distances are now formatted only when numeric.
    """
    st.subheader("π¬ Chat Interface")
    # Check if collection is loaded
    if not st.session_state.collection_loaded:
        st.warning("β οΈ No data loaded. Please use the configuration panel to load a dataset and create a collection.")
        st.info("""
Steps:
1. Select a dataset from the dropdown
2. Click "Load Data & Create Collection" button
3. Wait for the collection to be created
4. Then you can start chatting
""")
        return
    # Display collection info and LLM selector
    col1, col2, col3 = st.columns([2, 2, 1])
    with col1:
        provider_icon = "βοΈ" if st.session_state.get("llm_provider", "groq") == "groq" else "π₯οΈ"
        st.info(f"π Collection: {st.session_state.current_collection} | {provider_icon} {st.session_state.get('llm_provider', 'groq').upper()}")
    with col2:
        # LLM selector for chat - based on provider
        current_provider = st.session_state.get("llm_provider", "groq")
        if current_provider == "groq":
            model_options = settings.llm_models
            try:
                current_index = settings.llm_models.index(st.session_state.current_llm)
            except ValueError:
                current_index = 0
        else:
            model_options = settings.ollama_models
            try:
                current_index = settings.ollama_models.index(st.session_state.current_llm)
            except ValueError:
                current_index = 0
        selected_llm = st.selectbox(
            f"Select {'Groq' if current_provider == 'groq' else 'Ollama'} Model for chat:",
            model_options,
            index=current_index,
            key="chat_llm_selector"
        )
        if selected_llm != st.session_state.current_llm:
            st.session_state.current_llm = selected_llm
            # Recreate LLM client with new model
            llm_client = create_llm_client(
                provider=current_provider,
                api_key=st.session_state.groq_api_key if "groq_api_key" in st.session_state else "",
                api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
                model_name=selected_llm,
                ollama_host=settings.ollama_host,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )
            # NOTE(review): assumes RAGPipeline exposes its client as `.llm`
            # (it is constructed with llm_client=...) — confirm attribute name.
            st.session_state.rag_pipeline.llm = llm_client
    with col3:
        if st.button("ποΈ Clear History"):
            st.session_state.chat_history = []
            st.session_state.rag_pipeline.clear_history()
            st.rerun()
    # Show system prompt info in expandable section
    with st.expander("π€ System Prompt & Model Info", expanded=False):
        current_provider = st.session_state.get("llm_provider", "groq")
        col1, col2 = st.columns(2)
        with col1:
            provider_icon = "βοΈ" if current_provider == "groq" else "π₯οΈ"
            st.write(f"**Provider:** {provider_icon} {current_provider.upper()}")
            st.write(f"**LLM Model:** {st.session_state.current_llm}")
            st.write(f"**Temperature:** 0.0")
            st.write(f"**Max Tokens:** 2048")
        with col2:
            st.write(f"**Collection:** {st.session_state.current_collection}")
            st.write(f"**Dataset:** {st.session_state.get('dataset_name', 'N/A')}")
            if current_provider == "groq":
                st.write(f"**Rate Limit:** {settings.groq_rpm_limit} RPM")
            else:
                st.write(f"**Rate Limit:** β Unlimited (Local)")
        st.markdown("#### System Prompt Being Used")
        system_prompt = """You are a Fact-Checking and Citation Specialist. Your task is to perform a rigorous audit of a response against provided documents to determine its accuracy, relevance, and level of support.
**TASK OVERVIEW**
1. **Analyze Documents**: Review the provided documents and identify information relevant to the user's question.
2. **Evaluate Response**: Review the provided answer sentence-by-sentence.
3. **Verify Support**: Map each answer sentence to specific supporting sentences in the documents.
4. **Identify Utilization**: Determine which document sentences were actually used (directly or implicitly) to form the answer."""
        st.info(system_prompt)
    # Chat container
    chat_container = st.container()
    # Display chat history
    with chat_container:
        for chat_idx, entry in enumerate(st.session_state.chat_history):
            # User message
            with st.chat_message("user"):
                st.write(entry["query"])
            # Assistant message
            with st.chat_message("assistant"):
                st.write(entry["response"])
                # Show retrieved documents in expander
                with st.expander("π Retrieved Documents"):
                    for doc_idx, doc in enumerate(entry["retrieved_documents"]):
                        # Format distance only when numeric; 'N/A' otherwise.
                        distance = doc.get("distance")
                        distance_text = f"{distance:.4f}" if isinstance(distance, (int, float)) else "N/A"
                        st.markdown(f"**Document {doc_idx+1}** (Distance: {distance_text})")
                        st.text_area(
                            f"doc_{chat_idx}_{doc_idx}",
                            value=doc["document"],
                            height=100,
                            key=f"doc_area_{chat_idx}_{doc_idx}",
                            label_visibility="collapsed"
                        )
                        if doc.get("metadata"):
                            st.caption(f"Metadata: {doc['metadata']}")
    # Chat input
    query = st.chat_input("Ask a question...")
    if query:
        # Check if collection exists
        if not st.session_state.rag_pipeline or not st.session_state.rag_pipeline.vector_store.current_collection:
            st.error("β No data loaded. Please load a dataset first using the configuration panel.")
            st.stop()
        # Add user message
        with chat_container:
            with st.chat_message("user"):
                st.write(query)
        # Generate response
        with st.spinner("Generating response..."):
            try:
                result = st.session_state.rag_pipeline.query(query)
            except Exception as e:
                st.error(f"β Error querying: {str(e)}")
                st.info("Please load a dataset and create a collection first.")
                st.stop()
        # Add assistant message
        with chat_container:
            with st.chat_message("assistant"):
                st.write(result["response"])
                # Show retrieved documents
                with st.expander("π Retrieved Documents"):
                    for doc_idx, doc in enumerate(result["retrieved_documents"]):
                        # Same numeric-only distance formatting as above.
                        distance = doc.get("distance")
                        distance_text = f"{distance:.4f}" if isinstance(distance, (int, float)) else "N/A"
                        st.markdown(f"**Document {doc_idx+1}** (Distance: {distance_text})")
                        st.text_area(
                            f"doc_current_{doc_idx}",
                            value=doc["document"],
                            height=100,
                            key=f"doc_current_area_{doc_idx}",
                            label_visibility="collapsed"
                        )
                        if doc.get("metadata"):
                            st.caption(f"Metadata: {doc['metadata']}")
        # Store in history
        st.session_state.chat_history.append(result)
        st.rerun()
def evaluation_interface():
    """Evaluation interface tab.

    Renders the complete evaluation workflow:
      * method selection (TRACE heuristic / GPT labeling / hybrid),
      * LLM model and sample-count selection,
      * foreground or background execution of the evaluation run,
      * display of aggregate metrics, per-query details, RMSE / AUC-ROC
        comparisons against RAGBench ground truth, and JSON downloads.

    Relies on st.session_state keys populated elsewhere in the app:
    collection_loaded, llm_provider, current_llm, dataset_name,
    embedding_model, evaluation_results, show_background_jobs.
    """
    st.subheader("π RAG Evaluation")
    # Check if collection is loaded
    if not st.session_state.collection_loaded:
        st.warning("β οΈ No data loaded. Please load a collection first.")
        return
    # Evaluation method selector
    eval_col1, eval_col2 = st.columns([2, 1])
    with eval_col1:
        evaluation_method = st.radio(
            "Evaluation Method:",
            options=["TRACE (Heuristic)", "GPT Labeling (LLM-based)", "Hybrid (Both)"],
            horizontal=True,
            help="TRACE is fast (no LLM). GPT Labeling is accurate but slower (requires LLM calls)."
        )
    # Map UI labels to method IDs
    method_map = {
        "TRACE (Heuristic)": "trace",
        "GPT Labeling (LLM-based)": "gpt_labeling",
        "Hybrid (Both)": "hybrid"
    }
    selected_method = method_map[evaluation_method]
    # LLM selector for evaluation
    current_provider = st.session_state.get("llm_provider", "groq")
    col1, col2 = st.columns([3, 1])
    with col1:
        # Show provider-specific models
        if current_provider == "groq":
            model_options = settings.llm_models
            try:
                current_index = settings.llm_models.index(st.session_state.current_llm)
            except ValueError:
                # Current LLM not in the Groq list - fall back to first entry.
                current_index = 0
        else:
            model_options = settings.ollama_models
            try:
                current_index = settings.ollama_models.index(st.session_state.current_llm)
            except ValueError:
                current_index = 0
        selected_llm = st.selectbox(
            f"Select {'Groq' if current_provider == 'groq' else 'Ollama'} Model for evaluation:",
            model_options,
            index=current_index,
            key="eval_llm_selector"
        )
    # Show provider info
    provider_icon = "βοΈ" if current_provider == "groq" else "π₯οΈ"
    if current_provider == "ollama":
        st.caption(f"{provider_icon} Using local Ollama - **No rate limits!** Fast evaluation possible.")
    else:
        st.caption(f"{provider_icon} Using Groq API - Rate limited to {settings.groq_rpm_limit} RPM")
    # Show method description
    method_descriptions = {
        "trace": """
        **TRACE Heuristic Method** (Fast, Rule-Based)
        - Utilization: How well the system uses retrieved documents
        - Relevance: Relevance of retrieved documents to the query
        - Adherence: How well the response adheres to the retrieved context
        - Completeness: How complete the response is in answering the query
        - β‘ Speed: ~100ms per evaluation
        - π° Cost: Free (no API calls)
        """,
        "gpt_labeling": """
        **GPT Labeling Method** (Accurate, LLM-based)
        - Uses sentence-level LLM analysis (from RAGBench paper)
        - Context Relevance: Fraction of context relevant to query
        - Context Utilization: Fraction of relevant context used
        - Completeness: Fraction of relevant info covered
        - Adherence: Response supported by context (no hallucinations)
        - β±οΈ Speed: ~2-5 seconds per evaluation
        - π° Cost: ~$0.002-0.01 per evaluation
        """,
        "hybrid": """
        **Hybrid Method** (Comprehensive)
        - Runs both TRACE and GPT Labeling methods
        - Provides both fast and accurate evaluation metrics
        - Best for detailed analysis
        - β±οΈ Speed: ~3-6 seconds per evaluation
        - π° Cost: Same as GPT Labeling
        """
    }
    st.markdown(method_descriptions[selected_method])
    # Get maximum test samples available for current dataset
    try:
        loader = RAGBenchLoader()
        max_test_samples = loader.get_test_data_size(st.session_state.dataset_name)
        st.caption(f"π Available test samples: {max_test_samples:,}")
    except Exception as e:
        # Dataset size lookup failed - show a conservative estimate instead.
        max_test_samples = 100
        st.caption(f"Available test samples: ~{max_test_samples} (estimated)")
    # Ensure min and max are reasonable
    max_test_samples = max(5, min(max_test_samples, 500))  # Cap at 500 for performance
    num_test_samples = st.slider(
        "Number of test samples",
        min_value=5,
        max_value=max_test_samples,
        value=min(10, max_test_samples),
        step=5
    )
    # Show warning for GPT labeling (API cost) - only for Groq
    if selected_method in ["gpt_labeling", "hybrid"]:
        current_provider = st.session_state.get("llm_provider", "groq")
        if current_provider == "groq":
            st.warning(f"β οΈ **{evaluation_method}** requires LLM API calls. This will incur costs and be slower due to rate limiting ({settings.groq_rpm_limit} RPM).")
        else:
            st.info(f"βΉοΈ **{evaluation_method}** using local Ollama - **No rate limits!** Evaluation will be much faster.")
    # Background evaluation toggle
    run_in_background = st.checkbox(
        "π Run in Background",
        value=False,
        help="Run evaluation in background. You can close the browser and check results later."
    )
    col_run, col_jobs = st.columns([1, 1])
    with col_run:
        if st.button("π¬ Run Evaluation", type="primary"):
            if run_in_background:
                # Start background job
                start_background_evaluation(
                    num_test_samples,
                    selected_llm,
                    selected_method
                )
            else:
                # Use selected LLM for evaluation
                run_evaluation(num_test_samples, selected_llm, selected_method)
    with col_jobs:
        if st.button("π View Background Jobs"):
            st.session_state.show_background_jobs = True
    # Show background jobs panel
    if st.session_state.get("show_background_jobs", False):
        display_background_jobs()
    # Display results
    if st.session_state.evaluation_results:
        results = st.session_state.evaluation_results
        st.success("β Evaluation Complete!")
        st.divider()
        st.markdown("## π Evaluation Metrics")
        # Display aggregate scores - handle both TRACE and GPT Labeling metric names
        st.markdown("### Main Metrics")
        col1, col2, col3, col4, col5 = st.columns(5)
        # Determine which metrics are available (TRACE uses 'utilization'/'relevance',
        # GPT labeling uses 'context_utilization'/'context_relevance').
        # NOTE(review): `or` falls through when the first key holds 0.0, which
        # masks a legitimate zero score - confirm this is intended.
        utilization = results.get('utilization') or results.get('context_utilization', 0)
        relevance = results.get('relevance') or results.get('context_relevance', 0)
        adherence = results.get('adherence', 0)
        completeness = results.get('completeness', 0)
        average = results.get('average', 0)
        with col1:
            st.metric("π Utilization", f"{utilization:.3f}")
        with col2:
            st.metric("π― Relevance", f"{relevance:.3f}")
        with col3:
            st.metric("β Adherence", f"{adherence:.3f}")
        with col4:
            st.metric("π Completeness", f"{completeness:.3f}")
        with col5:
            st.metric("β Average", f"{average:.3f}")
        # Detailed results summary - handle both metric types
        if "individual_scores" in results:
            with st.expander("π Summary Metrics by Query"):
                df = pd.DataFrame(results["individual_scores"])
                st.dataframe(df, use_container_width=True)
        # Detailed per-query results
        # NOTE(review): st.expander widgets nest three deep below
        # (analysis > per-query > per-document); Streamlit historically raises
        # StreamlitAPIException for nested expanders - confirm the installed
        # version tolerates this layout.
        if "detailed_results" in results and results["detailed_results"]:
            with st.expander("π Detailed Per-Query Analysis"):
                for query_result in results.get("detailed_results", []):
                    with st.expander(f"Query {query_result['query_id']}: {query_result['question'][:60]}..."):
                        st.markdown("### Question")
                        st.write(query_result['question'])
                        st.markdown("### LLM Response")
                        st.write(query_result.get('llm_response', 'N/A'))
                        st.markdown("### Retrieved Documents")
                        for doc_idx, doc in enumerate(query_result.get('retrieved_documents', []), 1):
                            with st.expander(f"π Document {doc_idx}"):
                                st.write(doc)
                        if query_result.get('ground_truth'):
                            st.markdown("### Ground Truth")
                            st.write(query_result['ground_truth'])
                        # Display metrics with correct labels based on method
                        metrics = query_result.get('metrics', {})
                        if metrics:
                            st.markdown("### Evaluation Metrics")
                            col1, col2, col3, col4, col5 = st.columns(5)
                            # Get metric values (handle both TRACE and GPT names)
                            util_val = metrics.get('utilization') or metrics.get('context_utilization', 0)
                            rel_val = metrics.get('relevance') or metrics.get('context_relevance', 0)
                            adh_val = metrics.get('adherence', 0)
                            comp_val = metrics.get('completeness', 0)
                            avg_val = metrics.get('average', 0)
                            with col1:
                                st.metric("Util", f"{util_val:.3f}")
                            with col2:
                                st.metric("Rel", f"{rel_val:.3f}")
                            with col3:
                                st.metric("Adh", f"{adh_val:.3f}")
                            with col4:
                                st.metric("Comp", f"{comp_val:.3f}")
                            with col5:
                                st.metric("Avg", f"{avg_val:.3f}")
        # For GPT Labeling and Hybrid methods, show additional metrics
        method = results.get("method", "")
        if "gpt_labeling" in method or "hybrid" in method:
            # Show RMSE aggregation metrics (consistency across evaluations)
            if "rmse_metrics" in results:
                st.markdown("### π RMSE Aggregation (Metric Consistency)")
                rmse_data = results.get("rmse_metrics", {})
                rmse_cols = st.columns(4)
                with rmse_cols[0]:
                    rel_mean = rmse_data.get("context_relevance", {}).get("mean", 0)
                    rel_std = rmse_data.get("context_relevance", {}).get("std_dev", 0)
                    st.metric("Relevance", f"{rel_mean:.3f} Β±{rel_std:.3f}", help="Mean and Std Dev")
                with rmse_cols[1]:
                    util_mean = rmse_data.get("context_utilization", {}).get("mean", 0)
                    util_std = rmse_data.get("context_utilization", {}).get("std_dev", 0)
                    st.metric("Utilization", f"{util_mean:.3f} Β±{util_std:.3f}", help="Mean and Std Dev")
                with rmse_cols[2]:
                    comp_mean = rmse_data.get("completeness", {}).get("mean", 0)
                    comp_std = rmse_data.get("completeness", {}).get("std_dev", 0)
                    st.metric("Completeness", f"{comp_mean:.3f} Β±{comp_std:.3f}", help="Mean and Std Dev")
                with rmse_cols[3]:
                    adh_mean = rmse_data.get("adherence", {}).get("mean", 0)
                    adh_std = rmse_data.get("adherence", {}).get("std_dev", 0)
                    st.metric("Adherence", f"{adh_mean:.3f} Β±{adh_std:.3f}", help="Mean and Std Dev")
                # Show detailed RMSE statistics in expander
                with st.expander("See detailed RMSE aggregation statistics"):
                    for metric_name, metric_data in rmse_data.items():
                        st.write(f"**{metric_name}**")
                        col1, col2, col3, col4 = st.columns(4)
                        with col1:
                            st.write(f"Mean: {metric_data.get('mean', 0):.4f}")
                        with col2:
                            st.write(f"Std Dev: {metric_data.get('std_dev', 0):.4f}")
                        with col3:
                            st.write(f"Min: {metric_data.get('min', 0):.4f}")
                        with col4:
                            st.write(f"Max: {metric_data.get('max', 0):.4f}")
            # Show per-metric statistics if available
            if "per_metric_statistics" in results:
                st.markdown("### π Per-Metric Statistics (Distribution)")
                stats_data = results.get("per_metric_statistics", {})
                stats_cols = st.columns(4)
                with stats_cols[0]:
                    rel_stats = stats_data.get("context_relevance", {})
                    st.metric("Relevance Mean", f"{rel_stats.get('mean', 0):.3f}", help=f"Median: {rel_stats.get('median', 0):.3f}")
                with stats_cols[1]:
                    util_stats = stats_data.get("context_utilization", {})
                    st.metric("Utilization Mean", f"{util_stats.get('mean', 0):.3f}", help=f"Median: {util_stats.get('median', 0):.3f}")
                with stats_cols[2]:
                    comp_stats = stats_data.get("completeness", {})
                    st.metric("Completeness Mean", f"{comp_stats.get('mean', 0):.3f}", help=f"Median: {comp_stats.get('median', 0):.3f}")
                with stats_cols[3]:
                    adh_stats = stats_data.get("adherence", {})
                    st.metric("Adherence Mean", f"{adh_stats.get('mean', 0):.3f}", help=f"Median: {adh_stats.get('median', 0):.3f}")
                # Show detailed statistics
                with st.expander("See detailed per-metric statistics"):
                    for metric_name, metric_stats in stats_data.items():
                        st.write(f"**{metric_name}**")
                        col1, col2 = st.columns(2)
                        with col1:
                            st.write(f"""
                            - Mean: {metric_stats.get('mean', 0):.4f}
                            - Median: {metric_stats.get('median', 0):.4f}
                            - Std Dev: {metric_stats.get('std_dev', 0):.4f}
                            - Min: {metric_stats.get('min', 0):.4f}
                            - Max: {metric_stats.get('max', 0):.4f}
                            """)
                        with col2:
                            st.write(f"""
                            - 25th percentile: {metric_stats.get('percentile_25', 0):.4f}
                            - 75th percentile: {metric_stats.get('percentile_75', 0):.4f}
                            - Perfect (>=0.95): {metric_stats.get('perfect_count', 0)}
                            - Poor (<0.3): {metric_stats.get('poor_count', 0)}
                            - Samples: {metric_stats.get('sample_count', 0)}
                            """)
            # Show RMSE vs RAGBench Ground Truth (per RAGBench paper requirement)
            if "rmse_vs_ground_truth" in results:
                st.markdown("### π RMSE vs RAGBench Ground Truth")
                st.info("Compares predicted TRACE scores against original RAGBench dataset scores")
                rmse_gt = results.get("rmse_vs_ground_truth", {})
                per_metric_rmse = rmse_gt.get("per_metric_rmse", {})
                if per_metric_rmse:
                    rmse_gt_cols = st.columns(5)
                    with rmse_gt_cols[0]:
                        st.metric("Relevance RMSE", f"{per_metric_rmse.get('context_relevance', 0):.4f}",
                                  delta=None, help="Lower is better (0 = perfect match)")
                    with rmse_gt_cols[1]:
                        st.metric("Utilization RMSE", f"{per_metric_rmse.get('context_utilization', 0):.4f}")
                    with rmse_gt_cols[2]:
                        st.metric("Completeness RMSE", f"{per_metric_rmse.get('completeness', 0):.4f}")
                    with rmse_gt_cols[3]:
                        st.metric("Adherence RMSE", f"{per_metric_rmse.get('adherence', 0):.4f}")
                    with rmse_gt_cols[4]:
                        agg_rmse = rmse_gt.get("aggregated_rmse", 0)
                        consistency = rmse_gt.get("consistency_score", 0)
                        st.metric("Aggregated RMSE", f"{agg_rmse:.4f}",
                                  delta=f"Consistency: {consistency:.2%}", delta_color="normal")
            # Show AUCROC vs RAGBench Ground Truth (per RAGBench paper requirement)
            if "aucroc_vs_ground_truth" in results:
                st.markdown("### π AUC-ROC vs RAGBench Ground Truth")
                st.info("Area Under ROC Curve comparing predicted vs ground truth binary classifications")
                auc_gt = results.get("aucroc_vs_ground_truth", {})
                if auc_gt:
                    auc_cols = st.columns(5)
                    with auc_cols[0]:
                        st.metric("Relevance AUC", f"{auc_gt.get('context_relevance', 0):.4f}",
                                  help="Higher is better (1.0 = perfect classification)")
                    with auc_cols[1]:
                        st.metric("Utilization AUC", f"{auc_gt.get('context_utilization', 0):.4f}")
                    with auc_cols[2]:
                        st.metric("Completeness AUC", f"{auc_gt.get('completeness', 0):.4f}")
                    with auc_cols[3]:
                        st.metric("Adherence AUC", f"{auc_gt.get('adherence', 0):.4f}")
                    with auc_cols[4]:
                        avg_auc = auc_gt.get("average", 0)
                        st.metric("Average AUC", f"{avg_auc:.4f}")
        # Download results
        st.divider()
        st.markdown("## πΎ Download Results")
        # Create a comprehensive download with all details
        download_data = {
            "evaluation_metadata": {
                "timestamp": datetime.now().isoformat(),
                "dataset": st.session_state.dataset_name,
                "method": results.get("evaluation_config", {}).get("evaluation_method", "gpt_labeling_prompts"),
                "total_samples": results.get("num_samples", 0),
                "embedding_model": st.session_state.embedding_model,
            },
            "aggregate_metrics": {
                # Same `or`-fallback pattern as above: prefer GPT-labeling key
                # names, fall back to TRACE key names.
                "context_relevance": results.get("context_relevance") or results.get("relevance", 0),
                "context_utilization": results.get("context_utilization") or results.get("utilization", 0),
                "completeness": results.get("completeness", 0),
                "adherence": results.get("adherence", 0),
                "average": results.get("average", 0),
            },
            "rmse_metrics": results.get("rmse_metrics", {}),
            "per_metric_statistics": results.get("per_metric_statistics", {}),
            "rmse_vs_ground_truth": results.get("rmse_vs_ground_truth", {}),
            "aucroc_vs_ground_truth": results.get("aucroc_vs_ground_truth", {}),
            "detailed_results": results.get("detailed_results", [])
        }
        # default=str stringifies anything json can't serialize (e.g. datetimes).
        results_json = json.dumps(download_data, indent=2, default=str)
        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                label="π₯ Download Complete Results (JSON)",
                data=results_json,
                file_name=f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
                mime="application/json",
                help="Download all evaluation results including metrics and per-query details"
            )
        with col2:
            st.download_button(
                label="π Download Metrics Only (JSON)",
                data=json.dumps(download_data["aggregate_metrics"], indent=2),
                file_name=f"evaluation_metrics_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
                mime="application/json",
                help="Download only the aggregate metrics"
            )
def start_background_evaluation(num_samples: int, selected_llm: Optional[str] = None, method: str = "trace"):
    """Start evaluation in background thread.

    Creates a persistent job record, then launches the evaluation in a
    background thread so the user can close the browser and check results
    later via the jobs panel. Falls back to a foreground run_evaluation()
    call when the background_evaluator module is unavailable.

    Args:
        num_samples: Number of test samples to evaluate
        selected_llm: LLM model to use for evaluation; when None, the
            session's current LLM is used
        method: Evaluation method ("trace", "gpt_labeling", or "hybrid")
    """
    if not BACKGROUND_EVAL_AVAILABLE:
        st.error("Background evaluation not available. Running in foreground.")
        run_evaluation(num_samples, selected_llm, method)
        return
    try:
        # Unique, sortable job ID: collection name + timestamp + short random suffix.
        job_id = f"{st.session_state.collection_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{str(uuid.uuid4())[:8]}"
        # Create job record. The return value is intentionally discarded:
        # the jobs panel re-reads job state via list_jobs()/load_job_status().
        create_job(
            job_id=job_id,
            collection_name=st.session_state.collection_name,
            dataset_name=st.session_state.dataset_name,
            num_samples=num_samples,
            method=method,
            llm_model=selected_llm or st.session_state.current_llm,
            embedding_model=st.session_state.embedding_model,
            llm_provider=st.session_state.get("llm_provider", "groq")
        )
        # Start background thread
        start_background_job(
            job_id=job_id,
            rag_pipeline=st.session_state.rag_pipeline,
            vector_store=st.session_state.vector_store,
            dataset_name=st.session_state.dataset_name,
            num_samples=num_samples,
            method=method,
            llm_model=selected_llm or st.session_state.current_llm,
            embedding_model=st.session_state.embedding_model,
            llm_provider=st.session_state.get("llm_provider", "groq"),
            groq_api_key=st.session_state.get("groq_api_key"),
            groq_api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
            ollama_host=settings.ollama_host
        )
        st.success(f"""
        β **Background evaluation started!**
        - **Job ID:** `{job_id}`
        - **Method:** {method}
        - **Samples:** {num_samples}
        You can safely close this page. Check back later using "View Background Jobs" button.
        """)
        # Show the jobs panel
        st.session_state.show_background_jobs = True
    except Exception as e:
        st.error(f"Failed to start background evaluation: {str(e)}")
def display_background_jobs():
    """Display background jobs status panel.

    Shows the 10 most recent jobs with status, progress, recent log lines,
    and per-job actions (load completed results into the session, delete).
    Per-job buttons carry unique `key=` values so Streamlit can tell them
    apart across loop iterations.
    """
    st.divider()
    st.markdown("### π Background Evaluation Jobs")
    if not BACKGROUND_EVAL_AVAILABLE:
        st.warning("Background evaluation module not available.")
        return
    # Refresh button
    if st.button("π Refresh Status"):
        st.rerun()
    # Get all jobs
    jobs = list_jobs()
    if not jobs:
        st.info("No background jobs found.")
        # Only one "Close Panel" button renders per run: this one on the
        # empty path (followed by return), or the one at the end otherwise.
        if st.button("Close Panel"):
            st.session_state.show_background_jobs = False
            st.rerun()
        return
    # Display jobs
    for job in jobs[:10]:  # Show last 10 jobs
        job_id = job.get("job_id", "Unknown")
        status = job.get("status", "unknown")
        progress = job.get("progress", 0)
        # Status icon
        status_icons = {
            "pending": "β³",
            "running": "π",
            "completed": "β ",
            "failed": "β"
        }
        icon = status_icons.get(status, "β")
        # Auto-expand only the job(s) currently running.
        with st.expander(f"{icon} {job.get('collection_name', 'Unknown')} - {status.upper()} ({progress}%)", expanded=(status == "running")):
            col1, col2 = st.columns([2, 1])
            with col1:
                st.write(f"**Job ID:** `{job_id}`")
                st.write(f"**Method:** {job.get('method', 'Unknown')}")
                st.write(f"**Samples:** {job.get('num_samples', 0)}")
                st.write(f"**LLM:** {job.get('llm_model', 'Unknown')}")
                st.write(f"**Created:** {job.get('created_at', 'Unknown')}")
                if status == "running":
                    # progress is stored as a 0-100 percentage; st.progress takes 0.0-1.0.
                    st.progress(progress / 100)
                    st.write(f"**Current Step:** {job.get('current_step', '')}")
                if status == "completed":
                    st.write(f"**Completed:** {job.get('completed_at', 'Unknown')}")
                    results_file = job.get('results_file')
                    if results_file:
                        st.write(f"**Results File:** `{results_file}`")
                    # Load results into session state
                    if st.button(f"π Load Results", key=f"load_{job_id}"):
                        results = job.get("results")
                        if results:
                            st.session_state.evaluation_results = results
                            st.success("Results loaded! Scroll down to see metrics.")
                            st.rerun()
                if status == "failed":
                    st.error(f"**Error:** {job.get('error', 'Unknown error')}")
            with col2:
                # Show logs
                logs = job.get("logs", [])
                if logs:
                    st.write("**Recent Logs:**")
                    for log in logs[-5:]:  # Last 5 logs
                        st.caption(f"{log.get('message', '')}")
                # Delete button
                if st.button(f"ποΈ Delete", key=f"del_{job_id}"):
                    delete_job(job_id)
                    st.success(f"Job {job_id} deleted.")
                    st.rerun()
    if st.button("Close Panel"):
        st.session_state.show_background_jobs = False
        st.rerun()
def run_evaluation(num_samples: int, selected_llm: Optional[str] = None, method: str = "trace"):
    """Run evaluation using selected method (TRACE, GPT Labeling, or Hybrid).

    Queries the RAG pipeline for each test sample, evaluates the collected
    (query, response, documents) triples with the chosen method, stores the
    results in st.session_state.evaluation_results, and auto-saves them to
    a timestamped JSON file next to the app.

    Args:
        num_samples: Number of test samples to evaluate
        selected_llm: LLM model to use for evaluation; when None, the
            session's current LLM is used
        method: Evaluation method ("trace", "gpt_labeling", or "hybrid")
    """
    with st.spinner(f"Running evaluation on {num_samples} samples..."):
        try:
            # Create logs container
            logs_container = st.container()
            logs_list = []
            # Display logs header once outside function
            logs_placeholder = st.empty()

            def add_log(message: str):
                """Add log message and update display."""
                logs_list.append(message)
                with logs_placeholder.container():
                    st.markdown("### π Evaluation Logs:")
                    for log_msg in logs_list:
                        st.caption(log_msg)

            # Log evaluation start
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            add_log(f"β±οΈ Evaluation started at {timestamp}")
            add_log(f"π Dataset: {st.session_state.dataset_name}")
            add_log(f"π Total samples: {num_samples}")
            add_log(f"π€ LLM Model: {selected_llm if selected_llm else st.session_state.current_llm}")
            add_log(f"π Vector Store: {st.session_state.collection_name}")
            add_log(f"π§ Embedding Model: {st.session_state.embedding_model}")
            # Map method names
            method_names = {
                "trace": "TRACE (Heuristic)",
                "gpt_labeling": "GPT Labeling (LLM-based)",
                "hybrid": "Hybrid (Both)"
            }
            add_log(f"π¬ Evaluation Method: {method_names.get(method, method)}")
            # Use selected LLM if provided - create with appropriate provider
            eval_llm_client = None
            original_llm = None
            current_provider = st.session_state.get("llm_provider", "groq")
            if selected_llm and selected_llm != st.session_state.current_llm:
                add_log(f"π Switching LLM to {selected_llm} ({current_provider.upper()})...")
                groq_api_key = st.session_state.groq_api_key if "groq_api_key" in st.session_state else ""
                eval_llm_client = create_llm_client(
                    provider=current_provider,
                    api_key=groq_api_key,
                    api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
                    model_name=selected_llm,
                    ollama_host=settings.ollama_host,
                    max_rpm=settings.groq_rpm_limit,
                    rate_limit_delay=settings.rate_limit_delay,
                    max_retries=settings.max_retries,
                    retry_delay=settings.retry_delay
                )
                # Temporarily replace LLM client.
                # NOTE(review): if an exception occurs before the restore near
                # the end of this function, the swapped client stays in place.
                original_llm = st.session_state.rag_pipeline.llm
                st.session_state.rag_pipeline.llm = eval_llm_client
            else:
                eval_llm_client = st.session_state.rag_pipeline.llm
            # Log provider info
            provider_icon = "βοΈ" if current_provider == "groq" else "π₯οΈ"
            add_log(f"{provider_icon} LLM Provider: {current_provider.upper()}")
            # Get test data
            add_log("π₯ Loading test data...")
            loader = RAGBenchLoader()
            test_data = loader.get_test_data(
                st.session_state.dataset_name,
                num_samples
            )
            add_log(f"β Loaded {len(test_data)} test samples")
            # Prepare test cases
            test_cases = []
            progress_bar = st.progress(0)
            status_text = st.empty()
            add_log("π Processing samples...")
            for i, sample in enumerate(test_data):
                status_text.text(f"Processing sample {i+1}/{num_samples}")
                # Query the RAG system
                result = st.session_state.rag_pipeline.query(
                    sample["question"],
                    n_results=5
                )
                # Prepare test case
                test_cases.append({
                    "query": sample["question"],
                    "response": result["response"],
                    "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                    "ground_truth": sample.get("answer", "")
                })
                # Update progress
                progress_bar.progress((i + 1) / num_samples)
                # Log every 10 samples
                if (i + 1) % 10 == 0 or (i + 1) == num_samples:
                    add_log(f" β Processed {i + 1}/{num_samples} samples")
            status_text.text(f"Running {method_names.get(method, method)} evaluation...")
            add_log(f"π Running evaluation using {method_names.get(method, method)}...")
            # Extract chunking and embedding metadata from session state
            # (These were stored when the collection was loaded/created)
            chunking_strategy = st.session_state.vector_store.chunking_strategy if st.session_state.vector_store else None
            embedding_model = st.session_state.embedding_model
            chunk_size = st.session_state.vector_store.chunk_size if st.session_state.vector_store else None
            chunk_overlap = st.session_state.vector_store.chunk_overlap if st.session_state.vector_store else None
            # Log retrieval configuration
            add_log("π§ Retrieval Configuration:")
            add_log(f" β’ Chunking Strategy: {chunking_strategy or 'Unknown'}")
            add_log(f" β’ Chunk Size: {chunk_size or 'Unknown'}")
            add_log(f" β’ Chunk Overlap: {chunk_overlap or 'Unknown'}")
            add_log(f" β’ Embedding Model: {embedding_model or 'Unknown'}")
            # Import unified pipeline
            try:
                from evaluation_pipeline import UnifiedEvaluationPipeline
                # Run evaluation with metadata using unified pipeline
                pipeline = UnifiedEvaluationPipeline(
                    llm_client=eval_llm_client,
                    chunking_strategy=chunking_strategy,
                    embedding_model=embedding_model,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap
                )
                # Run evaluation with selected method
                results = pipeline.evaluate_batch(test_cases, method=method)
            except ImportError:
                # Fallback to TRACE only if evaluation_pipeline module not available
                add_log("β οΈ evaluation_pipeline module not found, falling back to TRACE...")
                # Run evaluation with metadata using TRACE
                evaluator = TRACEEvaluator(
                    chunking_strategy=chunking_strategy,
                    embedding_model=embedding_model,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap
                )
                results = evaluator.evaluate_batch(test_cases)
            st.session_state.evaluation_results = results
            # Log evaluation results summary
            add_log("β Evaluation completed successfully!")
            # Display appropriate metrics based on method
            if method == "trace":
                add_log(f" β’ Utilization: {results.get('utilization', 0):.2%}")
                add_log(f" β’ Relevance: {results.get('relevance', 0):.2%}")
                add_log(f" β’ Adherence: {results.get('adherence', 0):.2%}")
                add_log(f" β’ Completeness: {results.get('completeness', 0):.2%}")
                add_log(f" β’ Average: {results.get('average', 0):.2%}")
            elif method == "gpt_labeling":
                if "context_relevance" in results:
                    add_log(f" β’ Context Relevance: {results.get('context_relevance', 0):.2%}")
                    add_log(f" β’ Context Utilization: {results.get('context_utilization', 0):.2%}")
                    add_log(f" β’ Completeness: {results.get('completeness', 0):.2%}")
                    add_log(f" β’ Adherence: {results.get('adherence', 0):.2%}")
                    add_log(f" β’ Average: {results.get('average', 0):.2%}")
                # NEW: Display RMSE and AUCROC metrics if available
                # NOTE(review): keys read below ('relevance', 'utilization') differ
                # from the 'context_relevance'/'context_utilization' keys used by
                # the results display - confirm which names the pipeline emits.
                if "rmse_metrics" in results:
                    add_log("π RMSE Metrics (vs ground truth):")
                    rmse_metrics = results.get("rmse_metrics", {})
                    add_log(f" β’ Context Relevance RMSE: {rmse_metrics.get('relevance', 0):.4f}")
                    add_log(f" β’ Context Utilization RMSE: {rmse_metrics.get('utilization', 0):.4f}")
                    add_log(f" β’ Completeness RMSE: {rmse_metrics.get('completeness', 0):.4f}")
                    add_log(f" β’ Adherence RMSE: {rmse_metrics.get('adherence', 0):.4f}")
                    add_log(f" β’ Average RMSE: {rmse_metrics.get('average', 0):.4f}")
                if "auc_metrics" in results:
                    add_log("π AUCROC Metrics (binary classification):")
                    auc_metrics = results.get("auc_metrics", {})
                    add_log(f" β’ Context Relevance AUCROC: {auc_metrics.get('relevance', 0):.4f}")
                    add_log(f" β’ Context Utilization AUCROC: {auc_metrics.get('utilization', 0):.4f}")
                    add_log(f" β’ Completeness AUCROC: {auc_metrics.get('completeness', 0):.4f}")
                    add_log(f" β’ Adherence AUCROC: {auc_metrics.get('adherence', 0):.4f}")
                    add_log(f" β’ Average AUCROC: {auc_metrics.get('average', 0):.4f}")
            elif method == "hybrid":
                add_log(" π TRACE Metrics:")
                trace_res = results.get("trace_results", {})
                add_log(f" β’ Utilization: {trace_res.get('utilization', 0):.2%}")
                add_log(f" β’ Relevance: {trace_res.get('relevance', 0):.2%}")
                add_log(f" β’ Adherence: {trace_res.get('adherence', 0):.2%}")
                add_log(f" β’ Completeness: {trace_res.get('completeness', 0):.2%}")
                add_log(" π§ GPT Labeling Metrics:")
                gpt_res = results.get("gpt_results", {})
                add_log(f" β’ Context Relevance: {gpt_res.get('context_relevance', 0):.2%}")
                add_log(f" β’ Context Utilization: {gpt_res.get('context_utilization', 0):.2%}")
                add_log(f" β’ Completeness: {gpt_res.get('completeness', 0):.2%}")
                add_log(f" β’ Adherence: {gpt_res.get('adherence', 0):.2%}")
            # Restore original LLM if it was switched
            if selected_llm and selected_llm != st.session_state.current_llm and original_llm:
                st.session_state.rag_pipeline.llm = original_llm
                add_log("π Restored original LLM")
            add_log(f"β±οΈ Evaluation completed at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
            # Auto-save evaluation results to JSON file
            try:
                collection_name = st.session_state.get("collection_name", "unknown_collection")
                timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                filename = f"{collection_name}_{timestamp}_evaluation.json"
                # Prepare save data with metadata
                save_data = {
                    "collection_name": collection_name,
                    "timestamp": datetime.now().isoformat(),
                    "evaluation_method": method,
                    "num_samples": num_samples,
                    "embedding_model": st.session_state.get("embedding_model", "unknown"),
                    "llm_model": st.session_state.get("current_llm", "unknown"),
                    "results": results
                }
                # Save to file
                with open(filename, 'w', encoding='utf-8') as f:
                    json.dump(save_data, f, indent=2, default=str)
                # Fix: report the actual saved path (the f-string placeholders
                # were previously missing, printing a literal placeholder).
                add_log(f"πΎ Results auto-saved to: {filename}")
                st.success(f"πΎ Results saved to: {filename}")
            except Exception as save_error:
                # Best-effort persistence: evaluation results are still in
                # session state even if the file write fails.
                add_log(f"β οΈ Could not auto-save results: {str(save_error)}")
        except Exception as e:
            st.error(f"Error during evaluation: {str(e)}")
            add_log(f"β Error: {str(e)}")
def history_interface():
    """History interface tab.

    Shows an export-to-JSON button plus one collapsible panel per past
    conversation with its query, response, timestamp and source documents.
    """
    st.subheader("π Chat History")
    history = st.session_state.chat_history
    if not history:
        st.info("No chat history yet. Start a conversation in the Chat tab!")
        return
    # Export button sits in the narrow right-hand column.
    col1, col2 = st.columns([3, 1])
    with col2:
        serialized = json.dumps(history, indent=2)
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        st.download_button(
            label="πΎ Export History",
            data=serialized,
            file_name=f"chat_history_{stamp}.json",
            mime="application/json"
        )
    # One expander per past exchange, oldest first.
    for conv_idx, conversation in enumerate(history):
        header = f"π¬ Conversation {conv_idx+1}: {conversation['query'][:50]}..."
        with st.expander(header):
            st.markdown(f"**Query:** {conversation['query']}")
            st.markdown(f"**Response:** {conversation['response']}")
            st.markdown(f"**Timestamp:** {conversation.get('timestamp', 'N/A')}")
            st.markdown("**Retrieved Documents:**")
            for doc_pos, retrieved in enumerate(conversation["retrieved_documents"]):
                st.text_area(
                    f"Document {doc_pos+1}",
                    value=retrieved["document"],
                    height=100,
                    key=f"history_doc_{conv_idx}_{doc_pos}"
                )
# Script entry point: launch the Streamlit app when executed directly.
if __name__ == "__main__":
    main()