# CapStoneRAG10 / streamlit_app.py
# Developer
# Feature: Background evaluation - run GPT evaluation without keeping UI open
# 3bfe1a4
"""Streamlit chat interface for RAG application."""
import streamlit as st
import sys
import os
from datetime import datetime
import json
import pandas as pd
from typing import Optional
import warnings

# Suppress all Python warnings, and set TF_CPP_MIN_LOG_LEVEL=3 to silence
# TensorFlow's native (C++) log output, so neither clutters the server console.
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Put the directory containing this file on sys.path so the sibling modules
# imported below (config, vector_store, llm_client, ...) resolve no matter
# which working directory Streamlit was launched from.  This must run before
# those imports.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Check if running on HuggingFace Spaces: the SPACE_ID environment variable is
# used as the marker.  The UI uses this flag to force the Groq provider and to
# list Qdrant first among the vector-store options.
IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None

from config import settings
from dataset_loader import RAGBenchLoader
from vector_store import ChromaDBManager, create_vector_store

# Qdrant integration is optional: if it cannot be imported, QDRANT_AVAILABLE is
# set to False (QdrantManager is then left undefined).
try:
    from vector_store import QdrantManager, QDRANT_AVAILABLE
except ImportError:
    QDRANT_AVAILABLE = False

from llm_client import GroqLLMClient, OllamaLLMClient, RAGPipeline, create_llm_client
from trace_evaluator import TRACEEvaluator
from embedding_models import EmbeddingFactory
from chunking_strategies import ChunkingFactory
import uuid

# Background evaluation support (run GPT evaluation without keeping the UI
# open) is optional as well; BACKGROUND_EVAL_AVAILABLE records whether the
# background_evaluator module could be imported.
try:
    from background_evaluator import (
        create_job, start_background_job, load_job_status,
        list_jobs, delete_job, STATUS_PENDING, STATUS_RUNNING,
        STATUS_COMPLETED, STATUS_FAILED
    )
    BACKGROUND_EVAL_AVAILABLE = True
except ImportError:
    BACKGROUND_EVAL_AVAILABLE = False
# Page configuration
st.set_page_config(
    page_title="RAG Capstone Project",
    page_icon="πŸ€–",
    layout="wide"
)


def _init_session_state():
    """Seed ``st.session_state`` with the app's defaults without clobbering values.

    Each default is produced by a zero-argument factory that is only invoked
    for keys not yet present in ``st.session_state``.  Settings are therefore
    read lazily, and each new session gets its own fresh list objects.
    """
    defaults = {
        "chat_history": list,
        "rag_pipeline": lambda: None,
        "vector_store": lambda: None,
        "collection_loaded": lambda: False,
        "evaluation_results": lambda: None,
        "dataset_size": lambda: 10000,
        "current_dataset": lambda: None,
        "current_llm": lambda: settings.llm_models[1],
        "selected_collection": lambda: None,
        "available_collections": list,
        "dataset_name": lambda: None,
        "collection_name": lambda: None,
        "embedding_model": lambda: None,
        "groq_api_key": lambda: "",
        "llm_provider": lambda: settings.llm_provider,
        "ollama_model": lambda: settings.ollama_model,
        "vector_store_provider": lambda: settings.vector_store_provider,
        "qdrant_url": lambda: settings.qdrant_url,
        "qdrant_api_key": lambda: settings.qdrant_api_key,
    }
    for state_key, make_default in defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()


# Initialize session state
_init_session_state()
def get_available_collections(provider: Optional[str] = None):
    """Get list of available collections from vector store.

    Args:
        provider: ``"qdrant"`` or ``"chroma"``.  When falsy, the provider stored
            in ``st.session_state.vector_store_provider`` is used (``"chroma"``
            if that key is missing).

    Returns:
        The result of the store's ``list_collections()``.  An empty list is
        returned when:

        - Qdrant is requested but unavailable;
        - Qdrant is requested but its URL or API key is not configured;
        - listing raises an error (the error is printed to stdout).
    """
    provider = provider or st.session_state.get("vector_store_provider", "chroma")
    try:
        if provider == "qdrant":
            # Do not fall back to local ChromaDB when Qdrant support is
            # missing: those collections could not be opened through the
            # Qdrant code path anyway, so listing them would be misleading.
            if not QDRANT_AVAILABLE:
                return []
            qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url
            qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key
            if not (qdrant_url and qdrant_api_key):
                return []
            vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
        else:
            vector_store = ChromaDBManager(settings.chroma_persist_directory)
        return vector_store.list_collections()
    except Exception as e:
        # Best effort: the sidebar simply shows no existing collections.
        print(f"Error getting collections: {e}")
        return []
def main():
    """Main Streamlit application.

    Renders the sidebar configuration, then the main content area.

    The sidebar lets the user:

    - choose the LLM provider (Groq or Ollama);
    - choose the vector store (ChromaDB or Qdrant);
    - either load an existing collection or build a new one from a RAGBench
      dataset.

    The main content area shows usage instructions until
    ``st.session_state.collection_loaded`` is true, and the Chat / Evaluation /
    History tabs afterwards.
    """
    st.title("πŸ€– RAG Capstone Project")
    st.markdown("### Retrieval-Augmented Generation with TRACE Evaluation")

    # Show HuggingFace Spaces notice
    if IS_HUGGINGFACE_SPACE:
        st.info("πŸ€— Running on Hugging Face Spaces - Using Groq API (cloud-based LLM)")

    # NOTE: available collections are listed once, inside the sidebar, after
    # the vector-store provider for this run is known.  Listing them up front
    # as well (with the previous run's provider) only duplicated a possibly
    # remote round-trip whose result was overwritten before use.

    # Sidebar for configuration
    with st.sidebar:
        st.header("Configuration")

        # Only the Ollama branch below sets a host; Groq callers get None.
        ollama_host = None

        # LLM Provider Selection - Disable Ollama on HuggingFace Spaces
        st.subheader("πŸ”Œ LLM Provider")
        if IS_HUGGINGFACE_SPACE:
            # Force Groq on HuggingFace Spaces (Ollama not available)
            st.caption("☁️ **Groq API** (Ollama unavailable on Spaces)")
            llm_provider = "groq"
            st.session_state.llm_provider = "groq"
        else:
            llm_provider = st.radio(
                "Choose LLM Provider:",
                options=["groq", "ollama"],
                index=0 if st.session_state.llm_provider == "groq" else 1,
                format_func=lambda x: "☁️ Groq API (Cloud)" if x == "groq" else "πŸ–₯️ Ollama (Local)",
                help="Groq: Cloud API with rate limits. Ollama: Local unlimited inference.",
                key="llm_provider_radio"
            )
            st.session_state.llm_provider = llm_provider

        # Provider-specific settings
        if llm_provider == "groq":
            st.caption("⚠️ Free tier: 30 requests/min")
            # On HuggingFace Spaces, check for API key in secrets first
            default_api_key = os.environ.get("GROQ_API_KEY", "") or settings.groq_api_key or ""
            # API Key input
            groq_api_key = st.text_input(
                "Groq API Key",
                type="password",
                value=default_api_key,
                help="Enter your Groq API key (or set GROQ_API_KEY in Spaces secrets)"
            )
            if IS_HUGGINGFACE_SPACE and not groq_api_key:
                st.warning("πŸ’‘ Tip: Add GROQ_API_KEY to your Space secrets for persistence")
        else:
            # Ollama settings (only available locally)
            st.caption("βœ… No rate limits - unlimited usage!")
            ollama_host = st.text_input(
                "Ollama Host",
                value=settings.ollama_host,
                help="Ollama server URL (default: http://localhost:11434)"
            )
            ollama_model = st.selectbox(
                "Select Ollama Model:",
                options=settings.ollama_models,
                index=settings.ollama_models.index(st.session_state.ollama_model) if st.session_state.ollama_model in settings.ollama_models else 0,
                key="ollama_model_selector"
            )
            st.session_state.ollama_model = ollama_model
            # Connection check button
            if st.button("πŸ” Check Ollama Connection"):
                try:
                    import requests
                    # rstrip avoids a "//api/tags" URL when the host ends in "/".
                    response = requests.get(f"{ollama_host.rstrip('/')}/api/tags", timeout=5)
                    if response.status_code == 200:
                        models = response.json().get("models", [])
                        model_names = [m["name"] for m in models]
                        st.success(f"βœ… Connected! Available models: {', '.join(model_names)}")
                    else:
                        st.error(f"❌ Connection failed: {response.status_code}")
                except Exception as e:
                    st.error(f"❌ Cannot connect to Ollama: {e}")
                    st.info("Make sure Ollama is running: `ollama serve`")
            groq_api_key = ""  # Not needed for Ollama

        st.divider()

        # Vector Store Provider Selection
        st.subheader("πŸ’Ύ Vector Store")
        if IS_HUGGINGFACE_SPACE:
            st.caption("☁️ Use **Qdrant Cloud** for persistent storage")
            # On Spaces local storage is ephemeral, so Qdrant is listed first.
            vector_store_options = ["qdrant", "chroma"]
        else:
            vector_store_options = ["chroma", "qdrant"]
        vector_store_provider = st.radio(
            "Choose Vector Store:",
            options=vector_store_options,
            index=0,
            format_func=lambda x: "☁️ Qdrant Cloud (Persistent)" if x == "qdrant" else "πŸ’Ύ ChromaDB (Local)",
            help="Qdrant: Cloud storage (persistent). ChromaDB: Local storage (ephemeral on Spaces).",
            key="vector_store_radio"
        )
        st.session_state.vector_store_provider = vector_store_provider

        # Qdrant settings
        if vector_store_provider == "qdrant":
            if not QDRANT_AVAILABLE:
                # Without this hint the collection list below is just empty.
                st.warning("Qdrant support is not available in this environment (is qdrant-client installed?)")
            default_qdrant_url = os.environ.get("QDRANT_URL", "") or settings.qdrant_url
            default_qdrant_key = os.environ.get("QDRANT_API_KEY", "") or settings.qdrant_api_key
            qdrant_url = st.text_input(
                "Qdrant URL",
                value=default_qdrant_url,
                placeholder="https://xxx-xxx.aws.cloud.qdrant.io:6333",
                help="Your Qdrant Cloud cluster URL"
            )
            qdrant_api_key = st.text_input(
                "Qdrant API Key",
                type="password",
                value=default_qdrant_key,
                help="Your Qdrant API key"
            )
            st.session_state.qdrant_url = qdrant_url
            st.session_state.qdrant_api_key = qdrant_api_key
            if not qdrant_url or not qdrant_api_key:
                st.warning("⚠️ Get free Qdrant Cloud at: https://cloud.qdrant.io")
            # Test Qdrant connection
            if st.button("πŸ” Test Qdrant Connection"):
                if qdrant_url and qdrant_api_key:
                    try:
                        test_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
                        collections = test_store.list_collections()
                        st.success(f"βœ… Connected! Found {len(collections)} collection(s)")
                    except Exception as e:
                        st.error(f"❌ Connection failed: {e}")
                else:
                    st.error("Please enter Qdrant URL and API Key")

        st.divider()

        # Get available collections based on provider
        available_collections = get_available_collections(vector_store_provider)
        st.session_state.available_collections = available_collections

        # Option 1: Use existing collection
        if available_collections:
            st.subheader("πŸ“š Existing Collections")
            st.write(f"Found {len(available_collections)} collection(s)")
            selected_collection = st.selectbox(
                "Or select existing collection:",
                available_collections,
                key="collection_selector"
            )
            if st.button("πŸ“– Load Existing Collection", type="secondary"):
                # Validate based on provider
                if llm_provider == "groq" and not groq_api_key:
                    st.error("Please enter your Groq API key")
                elif vector_store_provider == "qdrant" and (not st.session_state.get("qdrant_url") or not st.session_state.get("qdrant_api_key")):
                    st.error("Please enter Qdrant URL and API Key")
                else:
                    load_existing_collection(
                        groq_api_key,
                        selected_collection,
                        llm_provider,
                        ollama_host,  # None unless the Ollama provider is active
                        vector_store_provider
                    )
            st.divider()

        # Option 2: Create new collection
        st.subheader("πŸ†• Create New Collection")

        # Dataset selection
        st.subheader("1. Dataset Selection")
        dataset_name = st.selectbox(
            "Choose Dataset",
            settings.ragbench_datasets,
            index=0
        )

        # Get dataset size dynamically
        if st.button("πŸ” Check Dataset Size", key="check_size"):
            with st.spinner("Checking dataset size..."):
                try:
                    from datasets import load_dataset
                    # Load dataset with download_mode to avoid cache issues
                    st.info(f"Fetching dataset info for '{dataset_name}'...")
                    # NOTE(review): force_redownload re-fetches the whole split
                    # on every click just to count rows; consider a cheaper
                    # metadata-only lookup.
                    ds = load_dataset(
                        "rungalileo/ragbench",
                        dataset_name,
                        split="train",
                        trust_remote_code=True,
                        download_mode="force_redownload"  # Force fresh download to avoid cache corruption
                    )
                    dataset_size = len(ds)
                    st.session_state.dataset_size = dataset_size
                    st.session_state.current_dataset = dataset_name
                    st.success(f"βœ… Dataset '{dataset_name}' has {dataset_size:,} samples available")
                except Exception as e:
                    st.error(f"❌ Error: {str(e)}")
                    st.exception(e)
                    st.warning("Could not determine dataset size. Using default of 10,000.")
                    st.session_state.dataset_size = 10000
                    st.session_state.current_dataset = dataset_name

        # The cached size was measured for ``current_dataset`` only; once the
        # user switches to another dataset, fall back to the default.
        if st.session_state.get("current_dataset") == dataset_name:
            max_samples_available = st.session_state.get('dataset_size', 10000)
        else:
            max_samples_available = 10000
        st.caption(f"Max available samples: {max_samples_available:,}")
        # st.slider requires max_value >= min_value, so never go below 10.
        slider_max = max(10, max_samples_available)
        num_samples = st.slider(
            "Number of samples",
            min_value=10,
            max_value=slider_max,
            value=min(100, slider_max),
            step=50 if slider_max > 1000 else 10,
            help="Adjust slider to select number of samples"
        )
        load_all_samples = st.checkbox(
            "Load all available samples",
            value=False,
            help="Override slider and load entire dataset"
        )

        st.divider()

        # Chunking strategy
        st.subheader("2. Chunking Strategy")
        chunking_strategy = st.selectbox(
            "Choose Chunking Strategy",
            settings.chunking_strategies,
            index=0
        )
        chunk_size = st.slider(
            "Chunk Size",
            min_value=256,
            max_value=1024,
            value=512,
            step=128
        )
        overlap = st.slider(
            "Overlap",
            min_value=0,
            max_value=200,
            value=50,
            step=10
        )

        st.divider()

        # Embedding model
        st.subheader("3. Embedding Model")
        embedding_model = st.selectbox(
            "Choose Embedding Model",
            settings.embedding_models,
            index=0
        )

        st.divider()

        # LLM model selection for new collection
        st.subheader("4. LLM Model")
        if llm_provider == "groq":
            llm_model = st.selectbox(
                "Choose Groq LLM",
                settings.llm_models,
                index=1
            )
        else:
            llm_model = st.selectbox(
                "Choose Ollama Model",
                settings.ollama_models,
                index=settings.ollama_models.index(st.session_state.ollama_model) if st.session_state.ollama_model in settings.ollama_models else 0,
                key="llm_model_ollama"
            )

        st.divider()

        # Load data button
        if st.button("πŸš€ Load Data & Create Collection", type="primary"):
            # Validate based on provider
            if llm_provider == "groq" and not groq_api_key:
                st.error("Please enter your Groq API key")
            elif vector_store_provider == "qdrant" and (not st.session_state.get("qdrant_url") or not st.session_state.get("qdrant_api_key")):
                st.error("Please enter Qdrant URL and API Key")
            else:
                # Use None for num_samples if loading all data
                samples_to_load = None if load_all_samples else num_samples
                load_and_create_collection(
                    groq_api_key,
                    dataset_name,
                    samples_to_load,
                    chunking_strategy,
                    chunk_size,
                    overlap,
                    embedding_model,
                    llm_model,
                    llm_provider,
                    ollama_host,  # None unless the Ollama provider is active
                    vector_store_provider
                )

    # Main content area
    if not st.session_state.collection_loaded:
        st.info("πŸ‘ˆ Please configure and load a dataset from the sidebar to begin")
        # Show instructions
        with st.expander("πŸ“– How to Use", expanded=True):
            st.markdown("""
1. **Enter your Groq API Key** in the sidebar
2. **Select a dataset** from RAG Bench
3. **Choose a chunking strategy** (dense, sparse, hybrid, re-ranking)
4. **Select an embedding model** for document vectorization
5. **Choose an LLM model** for response generation
6. **Click "Load Data & Create Collection"** to initialize
7. **Start chatting** in the chat interface
8. **View retrieved documents** and evaluation metrics
9. **Run TRACE evaluation** on test data
""")
        # Show available options
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("πŸ“Š Available Datasets")
            for ds in settings.ragbench_datasets:
                st.markdown(f"- {ds}")
        with col2:
            st.subheader("πŸ€– Available Models")
            st.markdown("**Embedding Models:**")
            for em in settings.embedding_models:
                st.markdown(f"- {em}")
            st.markdown("**LLM Models:**")
            for lm in settings.llm_models:
                st.markdown(f"- {lm}")
    else:
        # Create tabs for different functionalities
        tab1, tab2, tab3 = st.tabs(["πŸ’¬ Chat", "πŸ“Š Evaluation", "πŸ“œ History"])
        with tab1:
            chat_interface()
        with tab2:
            evaluation_interface()
        with tab3:
            history_interface()
def _infer_dataset_name(collection_name: str) -> str:
    """Best-effort recovery of the RAGBench dataset a collection was built from.

    Collections created by this app are named ``{dataset}_{strategy}_{model}``,
    with "-" and "." replaced by "_".  The longest normalized dataset name that
    equals the collection name, or prefixes it followed by "_", wins.  Because
    of the separator and longest-match rules, one dataset name that is a prefix
    of another cannot shadow it.  Falls back to the first "_"-separated token.
    """
    best_match = None
    best_length = -1
    for ds in settings.ragbench_datasets:
        normalized = ds.replace("-", "_").replace(".", "_")
        if collection_name == normalized or collection_name.startswith(normalized + "_"):
            if len(normalized) > best_length:
                best_match, best_length = ds, len(normalized)
    if best_match is not None:
        return best_match
    return collection_name.split("_")[0]  # Fallback: use first part


def load_existing_collection(api_key: str, collection_name: str, llm_provider: str = "groq", ollama_host: str = None, vector_store_provider: str = "chroma"):
    """Load an existing collection from vector store.

    Opens ``collection_name`` in the selected vector store and builds an LLM
    client plus ``RAGPipeline`` for it.  The results are stored in
    ``st.session_state`` (pipeline, store, collection/dataset names, provider
    settings, detected embedding model), and then the script is rerun.

    Args:
        api_key: Groq API key (may be empty when ``llm_provider`` is "ollama").
        collection_name: Name of the collection to open.
        llm_provider: "groq" or "ollama".
        ollama_host: Ollama server URL; ``settings.ollama_host`` when None.
        vector_store_provider: "qdrant" or "chroma".

    Errors are reported in the UI (message and traceback) instead of raised.
    """
    with st.spinner(f"Loading collection '{collection_name}'..."):
        try:
            # Initialize vector store based on provider
            if vector_store_provider == "qdrant":
                qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url
                qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key
                vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
            else:
                vector_store = ChromaDBManager(settings.chroma_persist_directory)
            vector_store.get_collection(collection_name)

            dataset_name = _infer_dataset_name(collection_name)

            # Pick the generation model WITHOUT rendering a widget.  A
            # st.selectbox created here exists only for the button-click run
            # and is thrown away by st.rerun() below.  It would therefore
            # always return its first option and silently override the
            # user's model choice.  Instead, keep the user's current choice
            # when it is valid for this provider: the chat model for Groq, the
            # sidebar model for Ollama.  The chat tab lets the user switch
            # models afterwards.
            if llm_provider == "groq":
                model_options = settings.llm_models
                preferred_model = st.session_state.get("current_llm")
            else:
                model_options = settings.ollama_models
                preferred_model = st.session_state.get("ollama_model")
            if preferred_model in model_options:
                st.session_state.current_llm = preferred_model
            else:
                st.session_state.current_llm = model_options[0] if model_options else None

            # Initialize LLM client based on provider
            st.info(f"Initializing LLM client ({llm_provider})...")
            llm_client = create_llm_client(
                provider=llm_provider,
                api_key=api_key,
                api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
                model_name=st.session_state.current_llm,
                ollama_host=ollama_host or settings.ollama_host,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay,
                max_retries=settings.max_retries,
                retry_delay=settings.retry_delay
            )

            # Create RAG pipeline with correct parameter names
            st.info("Creating RAG pipeline...")
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )

            # Store in session state
            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.selected_collection = collection_name
            st.session_state.groq_api_key = api_key
            st.session_state.dataset_name = dataset_name
            st.session_state.collection_name = collection_name
            st.session_state.llm_provider = llm_provider
            st.session_state.vector_store_provider = vector_store_provider
            # Remember the host so later client re-creation (chat tab) reuses it.
            st.session_state.ollama_host = ollama_host or settings.ollama_host
            # Reset the chat-tab model picker so that it re-initializes from
            # current_llm instead of a choice made for a previous collection.
            if "chat_llm_selector" in st.session_state:
                del st.session_state["chat_llm_selector"]

            # Extract embedding model from collection metadata or vector store
            embedding_model_name = None
            # For ChromaDB: check collection metadata
            if hasattr(vector_store, 'current_collection') and vector_store.current_collection:
                if hasattr(vector_store.current_collection, 'metadata'):
                    collection_metadata = vector_store.current_collection.metadata
                    if collection_metadata and "embedding_model" in collection_metadata:
                        embedding_model_name = collection_metadata["embedding_model"]
            # For Qdrant or fallback: check if embedding_model was loaded on the vector store
            if not embedding_model_name and hasattr(vector_store, 'embedding_model') and vector_store.embedding_model:
                if hasattr(vector_store.embedding_model, 'model_name'):
                    embedding_model_name = vector_store.embedding_model.model_name
            # None when the model could not be determined.
            st.session_state.embedding_model = embedding_model_name

            # Display system prompt and model info.
            # NOTE(review): everything rendered below is discarded by the
            # st.rerun() at the end of this block, so users never see it.
            provider_icon = "☁️" if llm_provider == "groq" else "πŸ–₯️"
            st.success(f"βœ… Collection '{collection_name}' loaded successfully! {provider_icon} Using {llm_provider.upper()}")
            with st.expander("πŸ€– Model & System Prompt Information", expanded=False):
                col1, col2 = st.columns(2)
                with col1:
                    st.write(f"**Provider:** {provider_icon} {llm_provider.upper()}")
                    st.write(f"**Model:** {st.session_state.current_llm}")
                    st.write(f"**Collection:** {collection_name}")
                    st.write(f"**Dataset:** {dataset_name}")
                with col2:
                    st.write(f"**Temperature:** 0.0")
                    st.write(f"**Max Tokens:** 2048")
                    if llm_provider == "groq":
                        st.write(f"**Rate Limit:** {settings.groq_rpm_limit} RPM")
                    else:
                        st.write(f"**Rate Limit:** βœ… Unlimited (Local)")
                st.markdown("#### System Prompt")
                st.info("""
You are a Fact-Checking and Citation Specialist. Your task is to perform a rigorous audit of a response against provided documents to determine its accuracy, relevance, and level of support.
**Task:**
1. Analyze the provided documents and identify information relevant to the user's question
2. Evaluate the response sentence-by-sentence
3. Verify each response sentence maps to supporting document sentences
4. Identify which document sentences were actually used in the response
""")
            st.rerun()
        except Exception as e:
            st.error(f"Error loading collection: {str(e)}")
            st.exception(e)
def load_and_create_collection(
    api_key: str,
    dataset_name: str,
    num_samples: Optional[int],
    chunking_strategy: str,
    chunk_size: int,
    overlap: int,
    embedding_model: str,
    llm_model: str,
    llm_provider: str = "groq",
    ollama_host: str = None,
    vector_store_provider: str = "chroma"
):
    """Load dataset and create vector collection.

    Loads up to ``num_samples`` training samples of ``dataset_name`` (all of
    them when ``num_samples`` is None) and (re)creates a collection named
    ``{dataset}_{strategy}_{model}``, with "-" and "." replaced by "_".  Any
    existing collection with that name is deleted first.  It then builds an
    LLM client for ``llm_model`` and a ``RAGPipeline``, stores them in
    ``st.session_state``, and reruns the script.

    Args:
        api_key: Groq API key (may be empty when ``llm_provider`` is "ollama").
        dataset_name: RAGBench subset name.
        num_samples: Maximum samples to load, or None for the full split.
        chunking_strategy: Chunking strategy name passed to the vector store.
        chunk_size: Chunk size passed to the vector store.
        overlap: Chunk overlap passed to the vector store.
        embedding_model: Embedding model name (its last "/" segment is used
            in the collection name).
        llm_model: Generation model for the pipeline.
        llm_provider: "groq" or "ollama".
        ollama_host: Ollama server URL; ``settings.ollama_host`` when None.
        vector_store_provider: "qdrant" or "chroma".

    Errors are reported in the UI (message and traceback) instead of raised.
    """
    with st.spinner("Loading dataset and creating collection..."):
        try:
            # Initialize dataset loader
            loader = RAGBenchLoader()
            # Load dataset
            if num_samples is None:
                st.info(f"Loading {dataset_name} dataset (all available samples)...")
            else:
                st.info(f"Loading {dataset_name} dataset ({num_samples} samples)...")
            dataset = loader.load_dataset(dataset_name, split="train", max_samples=num_samples)
            if not dataset:
                st.error("Failed to load dataset")
                return

            # Initialize vector store based on provider
            st.info(f"Initializing vector store ({vector_store_provider})...")
            if vector_store_provider == "qdrant":
                qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url
                qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key
                vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
            else:
                vector_store = ChromaDBManager(settings.chroma_persist_directory)

            # Create collection name
            collection_name = f"{dataset_name}_{chunking_strategy}_{embedding_model.split('/')[-1]}"
            collection_name = collection_name.replace("-", "_").replace(".", "_")

            # Delete existing collection with same name (if exists)
            existing_collections = vector_store.list_collections()
            if collection_name in existing_collections:
                st.warning(f"Collection '{collection_name}' already exists. Deleting and recreating...")
                vector_store.delete_collection(collection_name)
                st.info("Old collection deleted. Creating new one...")

            # Load data into collection
            st.info(f"Creating collection with {chunking_strategy} chunking...")
            vector_store.load_dataset_into_collection(
                collection_name=collection_name,
                embedding_model_name=embedding_model,
                chunking_strategy=chunking_strategy,
                dataset_data=dataset,
                chunk_size=chunk_size,
                overlap=overlap
            )

            # Initialize LLM client based on provider
            st.info(f"Initializing LLM client ({llm_provider})...")
            llm_client = create_llm_client(
                provider=llm_provider,
                api_key=api_key,
                api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
                model_name=llm_model,
                ollama_host=ollama_host or settings.ollama_host,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay,
                max_retries=settings.max_retries,
                retry_delay=settings.retry_delay
            )

            # Create RAG pipeline with correct parameter names
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )

            # Store in session state
            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.selected_collection = collection_name
            st.session_state.dataset_name = dataset_name
            st.session_state.dataset = dataset
            st.session_state.collection_name = collection_name
            st.session_state.embedding_model = embedding_model
            st.session_state.groq_api_key = api_key
            st.session_state.llm_provider = llm_provider
            st.session_state.vector_store_provider = vector_store_provider
            # Keep the displayed/selected model in sync with the model the
            # pipeline was actually built with.
            st.session_state.current_llm = llm_model
            # Remember the host so later client re-creation (chat tab) reuses it.
            st.session_state.ollama_host = ollama_host or settings.ollama_host
            # Reset the chat-tab model picker so that it re-initializes from
            # current_llm instead of a choice made for a previous collection.
            if "chat_llm_selector" in st.session_state:
                del st.session_state["chat_llm_selector"]

            provider_icon = "☁️" if llm_provider == "groq" else "πŸ–₯️"
            st.success(f"βœ… Collection '{collection_name}' created successfully! {provider_icon} Using {llm_provider.upper()}")
            st.rerun()
        except Exception as e:
            st.error(f"Error: {str(e)}")
            # Match load_existing_collection: show the traceback for debugging.
            st.exception(e)
def _format_distance(distance) -> str:
    """Render a retrieval distance for display.

    Numbers are shown with four decimals.  A missing (None) distance shows as
    "N/A", and any other non-numeric value is shown as-is.  Formatting a
    non-number with ``:.4f`` would raise and abort the whole page.
    """
    if distance is None:
        return "N/A"
    try:
        return f"{distance:.4f}"
    except (TypeError, ValueError):
        return str(distance)


def chat_interface():
    """Chat interface tab.

    Shows the active collection and a model picker.  Switching models rebuilds
    the LLM client on the existing pipeline.  The tab renders the chat history
    (with retrieved documents) and answers new queries via
    ``st.session_state.rag_pipeline.query``.  Each result is appended to
    ``st.session_state.chat_history`` and the script is rerun, so the new
    exchange is then rendered from history.
    """
    st.subheader("πŸ’¬ Chat Interface")

    # Check if collection is loaded
    if not st.session_state.collection_loaded:
        st.warning("⚠️ No data loaded. Please use the configuration panel to load a dataset and create a collection.")
        st.info("""
Steps:
1. Select a dataset from the dropdown
2. Click "Load Data & Create Collection" button
3. Wait for the collection to be created
4. Then you can start chatting
""")
        return

    current_provider = st.session_state.get("llm_provider", "groq")

    # Display collection info and LLM selector
    col1, col2, col3 = st.columns([2, 2, 1])
    with col1:
        provider_icon = "☁️" if current_provider == "groq" else "πŸ–₯️"
        st.info(f"πŸ“š Collection: {st.session_state.current_collection} | {provider_icon} {current_provider.upper()}")
    with col2:
        # LLM selector for chat - based on provider
        model_options = settings.llm_models if current_provider == "groq" else settings.ollama_models
        try:
            current_index = model_options.index(st.session_state.current_llm)
        except ValueError:
            current_index = 0
        selected_llm = st.selectbox(
            f"Select {'Groq' if current_provider == 'groq' else 'Ollama'} Model for chat:",
            model_options,
            index=current_index,
            key="chat_llm_selector"
        )
        if selected_llm != st.session_state.current_llm:
            st.session_state.current_llm = selected_llm
            # Recreate LLM client with new model, using the same retry policy
            # as the collection loaders and the host the collection was
            # loaded with (when recorded).
            llm_client = create_llm_client(
                provider=current_provider,
                api_key=st.session_state.get("groq_api_key", ""),
                api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
                model_name=selected_llm,
                ollama_host=st.session_state.get("ollama_host") or settings.ollama_host,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay,
                max_retries=settings.max_retries,
                retry_delay=settings.retry_delay
            )
            # NOTE(review): assumes RAGPipeline reads its client from `.llm`
            # (it is constructed with `llm_client=`); confirm in llm_client.py.
            st.session_state.rag_pipeline.llm = llm_client
    with col3:
        if st.button("πŸ—‘οΈ Clear History"):
            st.session_state.chat_history = []
            st.session_state.rag_pipeline.clear_history()
            st.rerun()

    # Show system prompt info in expandable section
    with st.expander("πŸ€– System Prompt & Model Info", expanded=False):
        col1, col2 = st.columns(2)
        with col1:
            provider_icon = "☁️" if current_provider == "groq" else "πŸ–₯️"
            st.write(f"**Provider:** {provider_icon} {current_provider.upper()}")
            st.write(f"**LLM Model:** {st.session_state.current_llm}")
            st.write(f"**Temperature:** 0.0")
            st.write(f"**Max Tokens:** 2048")
        with col2:
            st.write(f"**Collection:** {st.session_state.current_collection}")
            st.write(f"**Dataset:** {st.session_state.get('dataset_name', 'N/A')}")
            if current_provider == "groq":
                st.write(f"**Rate Limit:** {settings.groq_rpm_limit} RPM")
            else:
                st.write(f"**Rate Limit:** βœ… Unlimited (Local)")
        st.markdown("#### System Prompt Being Used")
        system_prompt = """You are a Fact-Checking and Citation Specialist. Your task is to perform a rigorous audit of a response against provided documents to determine its accuracy, relevance, and level of support.
**TASK OVERVIEW**
1. **Analyze Documents**: Review the provided documents and identify information relevant to the user's question.
2. **Evaluate Response**: Review the provided answer sentence-by-sentence.
3. **Verify Support**: Map each answer sentence to specific supporting sentences in the documents.
4. **Identify Utilization**: Determine which document sentences were actually used (directly or implicitly) to form the answer."""
        st.info(system_prompt)

    # Chat container
    chat_container = st.container()

    # Display chat history
    with chat_container:
        for chat_idx, entry in enumerate(st.session_state.chat_history):
            # User message
            with st.chat_message("user"):
                st.write(entry["query"])
            # Assistant message
            with st.chat_message("assistant"):
                st.write(entry["response"])
                # Show retrieved documents in expander
                with st.expander("πŸ“„ Retrieved Documents"):
                    for doc_idx, doc in enumerate(entry["retrieved_documents"]):
                        st.markdown(f"**Document {doc_idx+1}** (Distance: {_format_distance(doc.get('distance'))})")
                        st.text_area(
                            f"doc_{chat_idx}_{doc_idx}",
                            value=doc["document"],
                            height=100,
                            key=f"doc_area_{chat_idx}_{doc_idx}",
                            label_visibility="collapsed"
                        )
                        if doc.get("metadata"):
                            st.caption(f"Metadata: {doc['metadata']}")

    # Chat input
    query = st.chat_input("Ask a question...")
    if query:
        # Check if collection exists
        if not st.session_state.rag_pipeline or not st.session_state.rag_pipeline.vector_store.current_collection:
            st.error("❌ No data loaded. Please load a dataset first using the configuration panel.")
            st.stop()

        # Add user message
        with chat_container:
            with st.chat_message("user"):
                st.write(query)

        # Generate response
        with st.spinner("Generating response..."):
            try:
                result = st.session_state.rag_pipeline.query(query)
            except Exception as e:
                st.error(f"❌ Error querying: {str(e)}")
                st.info("Please load a dataset and create a collection first.")
                st.stop()

        # Add assistant message
        with chat_container:
            with st.chat_message("assistant"):
                st.write(result["response"])
                # Show retrieved documents
                with st.expander("πŸ“„ Retrieved Documents"):
                    for doc_idx, doc in enumerate(result["retrieved_documents"]):
                        st.markdown(f"**Document {doc_idx+1}** (Distance: {_format_distance(doc.get('distance'))})")
                        st.text_area(
                            f"doc_current_{doc_idx}",
                            value=doc["document"],
                            height=100,
                            key=f"doc_current_area_{doc_idx}",
                            label_visibility="collapsed"
                        )
                        if doc.get("metadata"):
                            st.caption(f"Metadata: {doc['metadata']}")

        # Store in history; the rerun re-renders this exchange from history.
        st.session_state.chat_history.append(result)
        st.rerun()
# Description rendered under the method selector, keyed by internal method ID.
# st.markdown dedents its input, so the indentation inside these strings is
# not rendered.
_EVAL_METHOD_DESCRIPTIONS = {
    "trace": """
    **TRACE Heuristic Method** (Fast, Rule-Based)
    - Utilization: How well the system uses retrieved documents
    - Relevance: Relevance of retrieved documents to the query
    - Adherence: How well the response adheres to the retrieved context
    - Completeness: How complete the response is in answering the query
    - ⚑ Speed: ~100ms per evaluation
    - πŸ’° Cost: Free (no API calls)
    """,
    "gpt_labeling": """
    **GPT Labeling Method** (Accurate, LLM-based)
    - Uses sentence-level LLM analysis (from RAGBench paper)
    - Context Relevance: Fraction of context relevant to query
    - Context Utilization: Fraction of relevant context used
    - Completeness: Fraction of relevant info covered
    - Adherence: Response supported by context (no hallucinations)
    - ⏱️ Speed: ~2-5 seconds per evaluation
    - πŸ’° Cost: ~$0.002-0.01 per evaluation
    """,
    "hybrid": """
    **Hybrid Method** (Comprehensive)
    - Runs both TRACE and GPT Labeling methods
    - Provides both fast and accurate evaluation metrics
    - Best for detailed analysis
    - ⏱️ Speed: ~3-6 seconds per evaluation
    - πŸ’° Cost: Same as GPT Labeling
    """,
}

# (results key, display label) for the four GPT-labeling / RAGBench metrics,
# in the column order used by every GPT-specific section.
_GPT_METRIC_COLUMNS = (
    ("context_relevance", "Relevance"),
    ("context_utilization", "Utilization"),
    ("completeness", "Completeness"),
    ("adherence", "Adherence"),
)


def _first_metric(data, *keys):
    """Return the first truthy value stored under *keys* in *data*, else 0.

    Results use TRACE names (``utilization``) or GPT-labeling names
    (``context_utilization``) depending on the method. This mirrors the
    ``a or b`` lookup used to bridge them but never returns ``None``, which
    would break the numeric formatting of the metric widgets.
    """
    for key in keys:
        value = data.get(key)
        if value:
            return value
    return 0


def _render_metric_row(data, labels):
    """Render utilization, relevance, adherence, completeness and average of
    *data* as five side-by-side ``st.metric`` widgets titled by *labels*."""
    values = (
        _first_metric(data, "utilization", "context_utilization"),
        _first_metric(data, "relevance", "context_relevance"),
        _first_metric(data, "adherence"),
        _first_metric(data, "completeness"),
        _first_metric(data, "average"),
    )
    for column, label, value in zip(st.columns(5), labels, values):
        with column:
            st.metric(label, f"{value:.3f}")


def _render_per_query_details(detailed_results):
    """Render one expander per evaluated query with its inputs, outputs and scores.

    Streamlit raises ``StreamlitAPIException`` when an expander is opened
    inside another expander, so this section is introduced by a heading
    instead of an outer expander, and retrieved documents are shown as
    labelled blocks instead of nested expanders.
    """
    st.markdown("### πŸ” Detailed Per-Query Analysis")
    for position, query_result in enumerate(detailed_results, 1):
        question = str(query_result.get("question", ""))
        # Fall back to the list position when the result carries no explicit ID.
        query_id = query_result.get("query_id", position)
        with st.expander(f"Query {query_id}: {question[:60]}..."):
            st.markdown("### Question")
            st.write(question)
            st.markdown("### LLM Response")
            st.write(query_result.get("llm_response", "N/A"))
            st.markdown("### Retrieved Documents")
            for doc_idx, doc in enumerate(query_result.get("retrieved_documents") or [], 1):
                st.markdown(f"**πŸ“„ Document {doc_idx}**")
                st.write(doc)
            if query_result.get("ground_truth"):
                st.markdown("### Ground Truth")
                st.write(query_result["ground_truth"])
            metrics = query_result.get("metrics") or {}
            if metrics:
                st.markdown("### Evaluation Metrics")
                _render_metric_row(metrics, ("Util", "Rel", "Adh", "Comp", "Avg"))


def _render_gpt_labeling_sections(results):
    """Render the GPT-labeling-specific sections that are present in *results*.

    Each section (metric-consistency aggregation, per-metric distribution,
    RMSE and AUC-ROC against RAGBench ground truth) is shown only when its
    key exists in *results*.
    """
    if "rmse_metrics" in results:
        st.markdown("### πŸ“Š RMSE Aggregation (Metric Consistency)")
        rmse_data = results.get("rmse_metrics") or {}
        for column, (key, label) in zip(st.columns(4), _GPT_METRIC_COLUMNS):
            agg = rmse_data.get(key) or {}
            with column:
                st.metric(
                    label,
                    f"{agg.get('mean', 0):.3f} Β±{agg.get('std_dev', 0):.3f}",
                    help="Mean and Std Dev",
                )
        with st.expander("See detailed RMSE aggregation statistics"):
            stat_fields = (("mean", "Mean"), ("std_dev", "Std Dev"), ("min", "Min"), ("max", "Max"))
            for metric_name, metric_data in rmse_data.items():
                st.write(f"**{metric_name}**")
                for column, (stat_key, stat_label) in zip(st.columns(4), stat_fields):
                    with column:
                        st.write(f"{stat_label}: {metric_data.get(stat_key, 0):.4f}")

    if "per_metric_statistics" in results:
        st.markdown("### πŸ“ˆ Per-Metric Statistics (Distribution)")
        stats_data = results.get("per_metric_statistics") or {}
        for column, (key, label) in zip(st.columns(4), _GPT_METRIC_COLUMNS):
            metric_stats = stats_data.get(key) or {}
            with column:
                st.metric(
                    f"{label} Mean",
                    f"{metric_stats.get('mean', 0):.3f}",
                    help=f"Median: {metric_stats.get('median', 0):.3f}",
                )
        with st.expander("See detailed per-metric statistics"):
            for metric_name, metric_stats in stats_data.items():
                st.write(f"**{metric_name}**")
                left_col, right_col = st.columns(2)
                with left_col:
                    st.write("\n".join([
                        f"- Mean: {metric_stats.get('mean', 0):.4f}",
                        f"- Median: {metric_stats.get('median', 0):.4f}",
                        f"- Std Dev: {metric_stats.get('std_dev', 0):.4f}",
                        f"- Min: {metric_stats.get('min', 0):.4f}",
                        f"- Max: {metric_stats.get('max', 0):.4f}",
                    ]))
                with right_col:
                    st.write("\n".join([
                        f"- 25th percentile: {metric_stats.get('percentile_25', 0):.4f}",
                        f"- 75th percentile: {metric_stats.get('percentile_75', 0):.4f}",
                        f"- Perfect (>=0.95): {metric_stats.get('perfect_count', 0)}",
                        f"- Poor (<0.3): {metric_stats.get('poor_count', 0)}",
                        f"- Samples: {metric_stats.get('sample_count', 0)}",
                    ]))

    # RMSE vs RAGBench ground truth (per RAGBench paper requirement).
    if "rmse_vs_ground_truth" in results:
        st.markdown("### πŸ“‰ RMSE vs RAGBench Ground Truth")
        st.info("Compares predicted TRACE scores against original RAGBench dataset scores")
        rmse_gt = results.get("rmse_vs_ground_truth") or {}
        per_metric_rmse = rmse_gt.get("per_metric_rmse") or {}
        if per_metric_rmse:
            columns = st.columns(5)
            for idx, (key, label) in enumerate(_GPT_METRIC_COLUMNS):
                with columns[idx]:
                    st.metric(
                        f"{label} RMSE",
                        f"{per_metric_rmse.get(key, 0):.4f}",
                        help="Lower is better (0 = perfect match)" if idx == 0 else None,
                    )
            with columns[4]:
                agg_rmse = rmse_gt.get("aggregated_rmse", 0)
                consistency = rmse_gt.get("consistency_score", 0)
                st.metric("Aggregated RMSE", f"{agg_rmse:.4f}",
                          delta=f"Consistency: {consistency:.2%}", delta_color="normal")

    # AUC-ROC vs RAGBench ground truth (per RAGBench paper requirement).
    if "aucroc_vs_ground_truth" in results:
        st.markdown("### πŸ“Š AUC-ROC vs RAGBench Ground Truth")
        st.info("Area Under ROC Curve comparing predicted vs ground truth binary classifications")
        auc_gt = results.get("aucroc_vs_ground_truth") or {}
        if auc_gt:
            columns = st.columns(5)
            for idx, (key, label) in enumerate(_GPT_METRIC_COLUMNS):
                with columns[idx]:
                    st.metric(
                        f"{label} AUC",
                        f"{auc_gt.get(key, 0):.4f}",
                        help="Higher is better (1.0 = perfect classification)" if idx == 0 else None,
                    )
            with columns[4]:
                st.metric("Average AUC", f"{auc_gt.get('average', 0):.4f}")


def _render_download_section(results):
    """Offer *results* as two JSON downloads: the full report and aggregate metrics only."""
    st.divider()
    st.markdown("## πŸ’Ύ Download Results")

    evaluation_config = results.get("evaluation_config") or {}
    download_data = {
        "evaluation_metadata": {
            "timestamp": datetime.now().isoformat(),
            "dataset": st.session_state.get("dataset_name"),
            # Prefer the explicit config, then the method recorded on the results.
            "method": (
                evaluation_config.get("evaluation_method")
                or results.get("method")
                or "gpt_labeling_prompts"
            ),
            "total_samples": results.get("num_samples", 0),
            "embedding_model": st.session_state.get("embedding_model"),
        },
        "aggregate_metrics": {
            "context_relevance": results.get("context_relevance") or results.get("relevance", 0),
            "context_utilization": results.get("context_utilization") or results.get("utilization", 0),
            "completeness": results.get("completeness", 0),
            "adherence": results.get("adherence", 0),
            "average": results.get("average", 0),
        },
        "rmse_metrics": results.get("rmse_metrics", {}),
        "per_metric_statistics": results.get("per_metric_statistics", {}),
        "rmse_vs_ground_truth": results.get("rmse_vs_ground_truth", {}),
        "aucroc_vs_ground_truth": results.get("aucroc_vs_ground_truth", {}),
        "detailed_results": results.get("detailed_results", []),
    }
    # default=str: results may carry values json cannot encode natively
    # (e.g. non-builtin numeric types); stringify them rather than fail.
    results_json = json.dumps(download_data, indent=2, default=str)
    metrics_json = json.dumps(download_data["aggregate_metrics"], indent=2, default=str)
    file_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    col1, col2 = st.columns(2)
    with col1:
        st.download_button(
            label="πŸ“₯ Download Complete Results (JSON)",
            data=results_json,
            file_name=f"evaluation_results_{file_stamp}.json",
            mime="application/json",
            help="Download all evaluation results including metrics and per-query details"
        )
    with col2:
        st.download_button(
            label="πŸ“‹ Download Metrics Only (JSON)",
            data=metrics_json,
            file_name=f"evaluation_metrics_{file_stamp}.json",
            mime="application/json",
            help="Download only the aggregate metrics"
        )


def _render_evaluation_results(results):
    """Render the metrics dashboard and downloads for an evaluation *results* dict."""
    st.success("βœ… Evaluation Complete!")
    st.divider()
    st.markdown("## πŸ“Š Evaluation Metrics")

    # Aggregate scores - handles both TRACE and GPT Labeling metric names.
    st.markdown("### Main Metrics")
    _render_metric_row(
        results,
        ("πŸ“Š Utilization", "🎯 Relevance", "βœ… Adherence", "πŸ“ Completeness", "⭐ Average"),
    )

    if "individual_scores" in results:
        with st.expander("πŸ“‹ Summary Metrics by Query"):
            st.dataframe(pd.DataFrame(results["individual_scores"]), use_container_width=True)

    detailed_results = results.get("detailed_results")
    if detailed_results:
        _render_per_query_details(detailed_results)

    method = str(results.get("method") or "")
    if "gpt_labeling" in method or "hybrid" in method:
        _render_gpt_labeling_sections(results)

    _render_download_section(results)


def evaluation_interface():
    """Render the evaluation tab.

    Lets the user choose an evaluation method (TRACE heuristic, GPT labeling
    or hybrid), an LLM from the active provider and a sample count, then runs
    the evaluation in the foreground (``run_evaluation``) or as a background
    job (``start_background_evaluation``). Any results held in
    ``st.session_state.evaluation_results`` are rendered below the controls.
    """
    st.subheader("πŸ“Š RAG Evaluation")

    if not st.session_state.collection_loaded:
        st.warning("⚠️ No data loaded. Please load a collection first.")
        return

    method_col, _ = st.columns([2, 1])
    with method_col:
        evaluation_method = st.radio(
            "Evaluation Method:",
            options=["TRACE (Heuristic)", "GPT Labeling (LLM-based)", "Hybrid (Both)"],
            horizontal=True,
            help="TRACE is fast (no LLM). GPT Labeling is accurate but slower (requires LLM calls)."
        )
    # Map UI labels to method IDs.
    method_map = {
        "TRACE (Heuristic)": "trace",
        "GPT Labeling (LLM-based)": "gpt_labeling",
        "Hybrid (Both)": "hybrid"
    }
    selected_method = method_map[evaluation_method]

    # LLM selector for evaluation; any non-Groq provider lists the Ollama models.
    current_provider = st.session_state.get("llm_provider", "groq")
    is_groq = current_provider == "groq"
    model_col, _ = st.columns([3, 1])
    with model_col:
        model_options = settings.llm_models if is_groq else settings.ollama_models
        try:
            current_index = model_options.index(st.session_state.current_llm)
        except ValueError:
            current_index = 0
        selected_llm = st.selectbox(
            f"Select {'Groq' if is_groq else 'Ollama'} Model for evaluation:",
            model_options,
            index=current_index,
            key="eval_llm_selector"
        )
        provider_icon = "☁️" if is_groq else "πŸ–₯️"
        if current_provider == "ollama":
            st.caption(f"{provider_icon} Using local Ollama - **No rate limits!** Fast evaluation possible.")
        else:
            st.caption(f"{provider_icon} Using Groq API - Rate limited to {settings.groq_rpm_limit} RPM")

    st.markdown(_EVAL_METHOD_DESCRIPTIONS[selected_method])

    # Maximum test samples available for the current dataset.
    try:
        max_test_samples = RAGBenchLoader().get_test_data_size(st.session_state.dataset_name)
        st.caption(f"πŸ“Š Available test samples: {max_test_samples:,}")
    except Exception:
        max_test_samples = 100
        st.caption(f"Available test samples: ~{max_test_samples} (estimated)")
    # Never below the slider's minimum of 5; cap at 500 for performance.
    max_test_samples = max(5, min(max_test_samples, 500))
    num_test_samples = st.slider(
        "Number of test samples",
        min_value=5,
        max_value=max_test_samples,
        value=min(10, max_test_samples),
        step=5
    )

    # LLM-based methods: warn about cost/rate limits on Groq only.
    if selected_method in ("gpt_labeling", "hybrid"):
        if is_groq:
            st.warning(f"⚠️ **{evaluation_method}** requires LLM API calls. This will incur costs and be slower due to rate limiting ({settings.groq_rpm_limit} RPM).")
        else:
            st.info(f"ℹ️ **{evaluation_method}** using local Ollama - **No rate limits!** Evaluation will be much faster.")

    run_in_background = st.checkbox(
        "πŸ”„ Run in Background",
        value=False,
        help="Run evaluation in background. You can close the browser and check results later."
    )
    col_run, col_jobs = st.columns([1, 1])
    with col_run:
        if st.button("πŸ”¬ Run Evaluation", type="primary"):
            if run_in_background:
                start_background_evaluation(num_test_samples, selected_llm, selected_method)
            else:
                run_evaluation(num_test_samples, selected_llm, selected_method)
    with col_jobs:
        if st.button("πŸ“‹ View Background Jobs"):
            st.session_state.show_background_jobs = True

    if st.session_state.get("show_background_jobs", False):
        display_background_jobs()

    if st.session_state.evaluation_results:
        _render_evaluation_results(st.session_state.evaluation_results)
def start_background_evaluation(num_samples: int, selected_llm: str = None, method: str = "trace"):
    """Register an evaluation job and start it via ``start_background_job``.

    When the optional ``background_evaluator`` module failed to import, the
    evaluation runs in the foreground through ``run_evaluation`` instead.
    If the job record is created but the worker fails to start, the record
    is deleted (best-effort) so it does not linger in the jobs panel.

    Args:
        num_samples: Number of test samples to evaluate
        selected_llm: LLM model to use for evaluation; falls back to
            ``st.session_state.current_llm`` when empty
        method: Evaluation method ("trace", "gpt_labeling", or "hybrid")
    """
    if not BACKGROUND_EVAL_AVAILABLE:
        st.error("Background evaluation not available. Running in foreground.")
        run_evaluation(num_samples, selected_llm, method)
        return

    job_id = None
    job_created = False
    job_started = False
    try:
        collection_name = st.session_state.collection_name
        dataset_name = st.session_state.dataset_name
        embedding_model = st.session_state.embedding_model
        llm_model = selected_llm or st.session_state.current_llm
        llm_provider = st.session_state.get("llm_provider", "groq")

        # Collection + timestamp + short uuid: readable and unique per click.
        job_id = f"{collection_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{str(uuid.uuid4())[:8]}"

        create_job(
            job_id=job_id,
            collection_name=collection_name,
            dataset_name=dataset_name,
            num_samples=num_samples,
            method=method,
            llm_model=llm_model,
            embedding_model=embedding_model,
            llm_provider=llm_provider
        )
        job_created = True

        start_background_job(
            job_id=job_id,
            rag_pipeline=st.session_state.rag_pipeline,
            vector_store=st.session_state.vector_store,
            dataset_name=dataset_name,
            num_samples=num_samples,
            method=method,
            llm_model=llm_model,
            embedding_model=embedding_model,
            llm_provider=llm_provider,
            groq_api_key=st.session_state.get("groq_api_key"),
            groq_api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
            ollama_host=settings.ollama_host
        )
        job_started = True

        st.success("\n".join([
            "βœ… **Background evaluation started!**",
            f"- **Job ID:** `{job_id}`",
            f"- **Method:** {method}",
            f"- **Samples:** {num_samples}",
            "",
            'You can safely close this page. Check back later using "View Background Jobs" button.',
        ]))
        # Open the jobs panel so the new job is visible immediately.
        st.session_state.show_background_jobs = True
    except Exception as e:
        if job_created and not job_started:
            # The record exists but no worker was started for it; without this
            # it would stay listed (presumably as pending) indefinitely.
            try:
                delete_job(job_id)
            except Exception:
                pass  # best-effort cleanup; the original failure is reported below
        st.error(f"Failed to start background evaluation: {str(e)}")
# Icon shown in each job's expander title, keyed by job status string.
_JOB_STATUS_ICONS = {
    "pending": "⏳",
    "running": "πŸ”„",
    "completed": "βœ…",
    "failed": "❌"
}


def _progress_fraction(progress):
    """Convert a 0-100 progress value into the 0.0-1.0 float st.progress accepts.

    st.progress raises for floats outside 0.0-1.0, so out-of-range values are
    clamped and non-numeric values map to 0.0.
    """
    try:
        fraction = float(progress) / 100
    except (TypeError, ValueError):
        return 0.0
    return min(max(fraction, 0.0), 1.0)


def _close_jobs_panel_button():
    """Render the button that hides the background-jobs panel."""
    if st.button("Close Panel"):
        st.session_state.show_background_jobs = False
        st.rerun()


def _render_background_job(job):
    """Render one job record as an expander with details, logs and actions."""
    job_id = job.get("job_id", "Unknown")
    status = job.get("status", "unknown")
    progress = job.get("progress", 0)
    icon = _JOB_STATUS_ICONS.get(status, "❓")

    title = f"{icon} {job.get('collection_name', 'Unknown')} - {str(status).upper()} ({progress}%)"
    with st.expander(title, expanded=(status == "running")):
        info_col, log_col = st.columns([2, 1])
        with info_col:
            st.write(f"**Job ID:** `{job_id}`")
            st.write(f"**Method:** {job.get('method', 'Unknown')}")
            st.write(f"**Samples:** {job.get('num_samples', 0)}")
            st.write(f"**LLM:** {job.get('llm_model', 'Unknown')}")
            st.write(f"**Created:** {job.get('created_at', 'Unknown')}")

            if status == "running":
                st.progress(_progress_fraction(progress))
                st.write(f"**Current Step:** {job.get('current_step', '')}")

            if status == "completed":
                st.write(f"**Completed:** {job.get('completed_at', 'Unknown')}")
                results_file = job.get('results_file')
                if results_file:
                    st.write(f"**Results File:** `{results_file}`")
                # Copy the job's results into session state so the evaluation
                # tab renders them.
                if st.button("πŸ“Š Load Results", key=f"load_{job_id}"):
                    results = job.get("results")
                    if results:
                        st.session_state.evaluation_results = results
                        st.success("Results loaded! Scroll down to see metrics.")
                        st.rerun()
                    else:
                        st.warning("No results found in this job record.")

            if status == "failed":
                st.error(f"**Error:** {job.get('error', 'Unknown error')}")

        with log_col:
            logs = job.get("logs") or []
            if logs:
                st.write("**Recent Logs:**")
                for log in logs[-5:]:  # last 5 entries
                    message = log.get('message', '') if isinstance(log, dict) else log
                    st.caption(f"{message}")
            if st.button("πŸ—‘οΈ Delete", key=f"del_{job_id}"):
                delete_job(job_id)
                st.success(f"Job {job_id} deleted.")
                st.rerun()


def display_background_jobs():
    """Render the background-jobs panel: refresh, one entry per job, and close.

    Shows at most the first 10 records returned by ``list_jobs`` (presumably
    the most recent ones - depends on ``list_jobs`` ordering).
    """
    st.divider()
    st.markdown("### πŸ“‹ Background Evaluation Jobs")

    if not BACKGROUND_EVAL_AVAILABLE:
        st.warning("Background evaluation module not available.")
        return

    if st.button("πŸ”„ Refresh Status"):
        st.rerun()

    jobs = list_jobs()
    if not jobs:
        st.info("No background jobs found.")
        _close_jobs_panel_button()
        return

    for job in jobs[:10]:
        _render_background_job(job)

    _close_jobs_panel_button()
# Display names for the evaluation method IDs, used in logs and status text.
_EVAL_METHOD_NAMES = {
    "trace": "TRACE (Heuristic)",
    "gpt_labeling": "GPT Labeling (LLM-based)",
    "hybrid": "Hybrid (Both)"
}

# (results key, log label) for the four GPT-labeling metrics, in log order.
_GPT_LOG_METRICS = (
    ("context_relevance", "Context Relevance"),
    ("context_utilization", "Context Utilization"),
    ("completeness", "Completeness"),
    ("adherence", "Adherence"),
)

# (results key, log label) for the TRACE metrics, in log order.
_TRACE_LOG_METRICS = (
    ("utilization", "Utilization"),
    ("relevance", "Relevance"),
    ("adherence", "Adherence"),
    ("completeness", "Completeness"),
)


def _format_pct(value):
    """Format a 0-1 score as a percentage; a missing/None score shows as 0.00%."""
    return f"{(value or 0):.2%}"


def _format_score(value):
    """Format an error/AUC score to 4 decimals; a missing/None score shows as 0.0000."""
    return f"{(value or 0):.4f}"


def _ground_truth_log_lines(results):
    """Return log lines for RMSE / AUC-ROC against RAGBench ground truth.

    Reads the same keys the results dashboard renders
    (``rmse_vs_ground_truth.per_metric_rmse`` and ``aucroc_vs_ground_truth``).
    """
    lines = []
    rmse_gt = results.get("rmse_vs_ground_truth") or {}
    per_metric_rmse = rmse_gt.get("per_metric_rmse") or {}
    if per_metric_rmse:
        lines.append("πŸ“ˆ RMSE Metrics (vs ground truth):")
        for key, label in _GPT_LOG_METRICS:
            lines.append(f" β€’ {label} RMSE: {_format_score(per_metric_rmse.get(key))}")
        lines.append(f" β€’ Aggregated RMSE: {_format_score(rmse_gt.get('aggregated_rmse'))}")

    auc_gt = results.get("aucroc_vs_ground_truth") or {}
    if auc_gt:
        lines.append("πŸ“Š AUCROC Metrics (binary classification):")
        for key, label in _GPT_LOG_METRICS:
            lines.append(f" β€’ {label} AUCROC: {_format_score(auc_gt.get(key))}")
        lines.append(f" β€’ Average AUCROC: {_format_score(auc_gt.get('average'))}")
    return lines


def _summary_log_lines(results, method):
    """Return the log lines summarising *results* produced by *method*."""
    lines = []
    if method == "trace":
        for key, label in _TRACE_LOG_METRICS + (("average", "Average"),):
            lines.append(f" β€’ {label}: {_format_pct(results.get(key))}")
    elif method == "gpt_labeling":
        if "context_relevance" in results:
            for key, label in _GPT_LOG_METRICS + (("average", "Average"),):
                lines.append(f" β€’ {label}: {_format_pct(results.get(key))}")
        lines.extend(_ground_truth_log_lines(results))
    elif method == "hybrid":
        lines.append(" πŸ“Š TRACE Metrics:")
        trace_res = results.get("trace_results") or {}
        for key, label in _TRACE_LOG_METRICS:
            lines.append(f" β€’ {label}: {_format_pct(trace_res.get(key))}")
        lines.append(" 🧠 GPT Labeling Metrics:")
        gpt_res = results.get("gpt_results") or {}
        for key, label in _GPT_LOG_METRICS:
            lines.append(f" β€’ {label}: {_format_pct(gpt_res.get(key))}")
        lines.extend(_ground_truth_log_lines(results))
    return lines


def _collect_test_cases(test_data, add_log):
    """Query the RAG pipeline for every test sample and build evaluator inputs.

    Progress is measured against ``len(test_data)`` rather than the requested
    sample count: the loader may return fewer rows, and st.progress rejects
    values above 1.0.

    Returns:
        ``(test_cases, status_text)`` - the evaluator input dicts and the
        status placeholder, which the caller reuses for the next phase.
    """
    total = len(test_data)
    progress_bar = st.progress(0)
    status_text = st.empty()
    add_log("πŸ” Processing samples...")

    test_cases = []
    for done, sample in enumerate(test_data, 1):
        status_text.text(f"Processing sample {done}/{total}")
        result = st.session_state.rag_pipeline.query(sample["question"], n_results=5)
        test_cases.append({
            "query": sample["question"],
            "response": result["response"],
            "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
            "ground_truth": sample.get("answer", "")
        })
        progress_bar.progress(done / total)
        # Log every 10 samples and always the final one.
        if done % 10 == 0 or done == total:
            add_log(f" βœ“ Processed {done}/{total} samples")
    return test_cases, status_text


def _evaluate_test_cases(test_cases, method, llm_client, add_log):
    """Score *test_cases* and return ``(results, method_actually_used)``.

    Uses ``evaluation_pipeline.UnifiedEvaluationPipeline`` when that module
    is importable; otherwise falls back to heuristic TRACE scoring, and the
    returned method is ``"trace"`` so logs and saved metadata describe what
    actually ran.
    """
    # Chunking settings were stored on the vector store when the collection
    # was loaded/created.
    vector_store = st.session_state.vector_store
    chunking_strategy = vector_store.chunking_strategy if vector_store else None
    embedding_model = st.session_state.embedding_model
    chunk_size = vector_store.chunk_size if vector_store else None
    chunk_overlap = vector_store.chunk_overlap if vector_store else None

    add_log(f"πŸ”§ Retrieval Configuration:")
    add_log(f" β€’ Chunking Strategy: {chunking_strategy or 'Unknown'}")
    add_log(f" β€’ Chunk Size: {chunk_size or 'Unknown'}")
    add_log(f" β€’ Chunk Overlap: {chunk_overlap or 'Unknown'}")
    add_log(f" β€’ Embedding Model: {embedding_model or 'Unknown'}")

    # Only the import is guarded: an ImportError raised while evaluating must
    # surface as an error, not silently turn into TRACE results.
    try:
        from evaluation_pipeline import UnifiedEvaluationPipeline
    except ImportError:
        add_log("⚠️ evaluation_pipeline module not found, falling back to TRACE...")
        evaluator = TRACEEvaluator(
            chunking_strategy=chunking_strategy,
            embedding_model=embedding_model,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        return evaluator.evaluate_batch(test_cases), "trace"

    pipeline = UnifiedEvaluationPipeline(
        llm_client=llm_client,
        chunking_strategy=chunking_strategy,
        embedding_model=embedding_model,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    return pipeline.evaluate_batch(test_cases, method=method), method


def _auto_save_results(results, method, num_samples, llm_model, add_log):
    """Best-effort save of *results* plus run metadata to a timestamped JSON file.

    The file ``<collection>_<YYYYmmdd_HHMMSS>_evaluation.json`` is written to
    the current working directory. Failures are logged, not raised, so a
    save problem never hides completed results.
    """
    try:
        collection_name = st.session_state.get("collection_name", "unknown_collection")
        saved_at = datetime.now()
        filename = f"{collection_name}_{saved_at.strftime('%Y%m%d_%H%M%S')}_evaluation.json"
        save_data = {
            "collection_name": collection_name,
            "timestamp": saved_at.isoformat(),
            "evaluation_method": method,
            "num_samples": num_samples,
            "embedding_model": st.session_state.get("embedding_model", "unknown"),
            "llm_model": llm_model,
            "results": results
        }
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2, default=str)
        add_log(f"πŸ’Ύ Results auto-saved to: {filename}")
        st.success(f"πŸ’Ύ Results saved to: {filename}")
    except Exception as save_error:
        add_log(f"⚠️ Could not auto-save results: {str(save_error)}")


def run_evaluation(num_samples: int, selected_llm: str = None, method: str = "trace"):
    """Run an evaluation in the foreground and store the results in session state.

    Queries the RAG pipeline for each RAGBench test sample, scores the
    answers with the selected method, logs a summary, stores the results in
    ``st.session_state.evaluation_results`` and auto-saves them to JSON.

    If *selected_llm* differs from the session's current LLM, a temporary
    client is swapped into the pipeline. It is always swapped back - even
    when the evaluation fails - so later chat queries keep using the user's
    model.

    Args:
        num_samples: Number of test samples to evaluate
        selected_llm: LLM model to use for evaluation
        method: Evaluation method ("trace", "gpt_labeling", or "hybrid")
    """
    with st.spinner(f"Running evaluation on {num_samples} samples..."):
        logs_list = []
        logs_placeholder = st.empty()

        def add_log(message: str):
            """Append *message* and redraw the whole log in its placeholder."""
            logs_list.append(message)
            with logs_placeholder.container():
                st.markdown("### πŸ“‹ Evaluation Logs:")
                for log_msg in logs_list:
                    st.caption(log_msg)

        llm_swapped = False
        original_llm = None
        try:
            method_label = _EVAL_METHOD_NAMES.get(method, method)
            current_provider = st.session_state.get("llm_provider", "groq")

            add_log(f"⏱️ Evaluation started at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
            add_log(f"πŸ“Š Dataset: {st.session_state.dataset_name}")
            add_log(f"πŸ“ˆ Total samples: {num_samples}")
            add_log(f"πŸ€– LLM Model: {selected_llm if selected_llm else st.session_state.current_llm}")
            add_log(f"πŸ”— Vector Store: {st.session_state.collection_name}")
            add_log(f"🧠 Embedding Model: {st.session_state.embedding_model}")
            add_log(f"πŸ”¬ Evaluation Method: {method_label}")

            if selected_llm and selected_llm != st.session_state.current_llm:
                add_log(f"πŸ”„ Switching LLM to {selected_llm} ({current_provider.upper()})...")
                eval_llm_client = create_llm_client(
                    provider=current_provider,
                    api_key=st.session_state.get("groq_api_key", ""),
                    api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
                    model_name=selected_llm,
                    ollama_host=settings.ollama_host,
                    max_rpm=settings.groq_rpm_limit,
                    rate_limit_delay=settings.rate_limit_delay,
                    max_retries=settings.max_retries,
                    retry_delay=settings.retry_delay
                )
                # Temporarily replace the pipeline's LLM; restored in `finally`.
                original_llm = st.session_state.rag_pipeline.llm
                st.session_state.rag_pipeline.llm = eval_llm_client
                llm_swapped = True
            else:
                eval_llm_client = st.session_state.rag_pipeline.llm

            provider_icon = "☁️" if current_provider == "groq" else "πŸ–₯️"
            add_log(f"{provider_icon} LLM Provider: {current_provider.upper()}")

            add_log("πŸ“₯ Loading test data...")
            test_data = RAGBenchLoader().get_test_data(st.session_state.dataset_name, num_samples)
            add_log(f"βœ… Loaded {len(test_data)} test samples")

            test_cases, status_text = _collect_test_cases(test_data, add_log)

            status_text.text(f"Running {method_label} evaluation...")
            add_log(f"πŸ“Š Running evaluation using {method_label}...")
            results, method = _evaluate_test_cases(test_cases, method, eval_llm_client, add_log)
            st.session_state.evaluation_results = results

            add_log("βœ… Evaluation completed successfully!")
            for line in _summary_log_lines(results, method):
                add_log(line)
        except Exception as e:
            st.error(f"Error during evaluation: {str(e)}")
            add_log(f"❌ Error: {str(e)}")
            return
        finally:
            if llm_swapped:
                st.session_state.rag_pipeline.llm = original_llm
                add_log(f"πŸ”„ Restored original LLM")

        add_log(f"⏱️ Evaluation completed at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        # Record the model that actually produced the responses.
        used_llm = selected_llm or st.session_state.get("current_llm", "unknown")
        _auto_save_results(results, method, num_samples, used_llm, add_log)
def history_interface():
    """Render the chat-history tab: a JSON export button and one expander per exchange.

    Each history entry is expected to carry ``query``, ``response`` and
    ``retrieved_documents`` (items with a ``document`` field), plus an
    optional ``timestamp``.
    """
    st.subheader("πŸ“œ Chat History")

    history = st.session_state.chat_history
    if not history:
        st.info("No chat history yet. Start a conversation in the Chat tab!")
        return

    _, export_col = st.columns([3, 1])
    with export_col:
        # Entries are RAG pipeline result dicts; default=str keeps the export
        # from failing on values json cannot encode natively, matching the
        # other JSON exports in this file.
        history_json = json.dumps(history, indent=2, default=str)
        st.download_button(
            label="πŸ’Ύ Export History",
            data=history_json,
            file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json"
        )

    for i, entry in enumerate(history):
        with st.expander(f"πŸ’¬ Conversation {i+1}: {entry['query'][:50]}..."):
            st.markdown(f"**Query:** {entry['query']}")
            st.markdown(f"**Response:** {entry['response']}")
            st.markdown(f"**Timestamp:** {entry.get('timestamp', 'N/A')}")
            st.markdown("**Retrieved Documents:**")
            for j, doc in enumerate(entry.get("retrieved_documents") or []):
                st.text_area(
                    f"Document {j+1}",
                    value=doc["document"],
                    height=100,
                    key=f"history_doc_{i}_{j}"
                )
# Script entry point: render the app when this file is executed as the main
# module. main() is not visible in this chunk; presumably defined earlier in
# the file.
if __name__ == "__main__":
    main()