# RAG10 / streamlit_app.py
# Vivek Kadamati
# Initial commit (ee444c0)
"""Streamlit chat interface for RAG application."""
import streamlit as st
import sys
import os
from datetime import datetime
import json
import pandas as pd
from typing import Optional
import warnings
# Silence all Python warnings so they do not clutter the Streamlit console, and
# turn down TensorFlow's native (C++) logging via TF_CPP_MIN_LOG_LEVEL.
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Make the sibling modules imported below (config, dataset_loader, ...)
# importable regardless of the working directory. Streamlit re-executes this
# script on every rerun inside the same process, so only add the entry once;
# an unconditional append would grow sys.path on every user interaction.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
if _APP_DIR not in sys.path:
    sys.path.append(_APP_DIR)
from config import settings
from dataset_loader import RAGBenchLoader
from vector_store import ChromaDBManager
from llm_client import GroqLLMClient, RAGPipeline
from trace_evaluator import TRACEEvaluator
from embedding_models import EmbeddingFactory
from chunking_strategies import ChunkingFactory
# Page configuration. Keep this ahead of every other st.* call in the script.
# The icon was mis-encoded (UTF-8 emoji bytes shown as "πŸ€–"); restore 🤖.
st.set_page_config(
    page_title="RAG Capstone Project",
    page_icon="🤖",
    layout="wide",
)
def _init_session_state():
    """Seed ``st.session_state`` with a default for every key that is not set yet.

    Defaults are zero-argument factories, so each one is evaluated only when
    its key is missing. This also means list defaults are created fresh rather
    than shared.
    """
    default_factories = {
        "chat_history": list,
        "rag_pipeline": lambda: None,
        "vector_store": lambda: None,
        "collection_loaded": lambda: False,
        "evaluation_results": lambda: None,
        "dataset_size": lambda: 10000,
        "current_dataset": lambda: None,
        "current_llm": lambda: settings.llm_models[1],
        "selected_collection": lambda: None,
        "available_collections": list,
    }
    for state_key, make_default in default_factories.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()


# Initialize session state
_init_session_state()
def get_available_collections():
    """Return the collection names found in the persistent ChromaDB store.

    Any exception is printed to stdout and swallowed, and an empty list is
    returned, so the sidebar can still render without the collection picker.
    """
    try:
        return ChromaDBManager(settings.chroma_persist_directory).list_collections()
    except Exception as e:
        print(f"Error getting collections: {e}")
        return []
def _render_existing_collection_picker(groq_api_key: str, available_collections):
    """Sidebar section: choose a persisted collection and load it."""
    st.subheader("📚 Existing Collections")
    st.write(f"Found {len(available_collections)} collection(s)")
    selected_collection = st.selectbox(
        "Or select existing collection:",
        available_collections,
        key="collection_selector"
    )
    if st.button("📖 Load Existing Collection", type="secondary"):
        if not groq_api_key:
            st.error("Please enter your Groq API key")
        else:
            load_existing_collection(groq_api_key, selected_collection)


def _check_dataset_size(dataset_name: str):
    """Fetch the dataset's train split and store its size in session state.

    On any failure the size falls back to 10,000, so the sample slider stays usable.
    """
    with st.spinner("Checking dataset size..."):
        try:
            from datasets import load_dataset
            st.info(f"Fetching dataset info for '{dataset_name}'...")
            ds = load_dataset(
                "rungalileo/ragbench",
                dataset_name,
                split="train",
                trust_remote_code=True,
                # Force a fresh download to avoid a corrupted local cache.
                # NOTE(review): this re-downloads the whole split just to count
                # rows; consider reading the split size from dataset metadata.
                download_mode="force_redownload"
            )
            dataset_size = len(ds)
            st.session_state.dataset_size = dataset_size
            st.session_state.current_dataset = dataset_name
            st.success(f"✅ Dataset '{dataset_name}' has {dataset_size:,} samples available")
        except Exception as e:
            st.error(f"❌ Error: {str(e)}")
            st.exception(e)
            st.warning("Could not determine dataset size. Using default of 10,000.")
            st.session_state.dataset_size = 10000
            st.session_state.current_dataset = dataset_name


def _render_new_collection_form(groq_api_key: str):
    """Sidebar section: pick dataset, chunking, embedding and LLM, then build a collection."""
    st.subheader("🆕 Create New Collection")

    # 1. Dataset selection
    st.subheader("1. Dataset Selection")
    dataset_name = st.selectbox(
        "Choose Dataset",
        settings.ragbench_datasets,
        index=0
    )
    if st.button("🔍 Check Dataset Size", key="check_size"):
        _check_dataset_size(dataset_name)

    # A stored size only applies to the dataset it was measured for. Any other
    # selection falls back to the 10,000-sample default instead of reusing a
    # stale value.
    if st.session_state.get("current_dataset") == dataset_name:
        max_samples_available = st.session_state.get("dataset_size", 10000)
    else:
        max_samples_available = 10000
    st.caption(f"Max available samples: {max_samples_available:,}")
    num_samples = st.slider(
        "Number of samples",
        min_value=10,
        max_value=max_samples_available,
        value=min(100, max_samples_available),
        step=50 if max_samples_available > 1000 else 10,
        help="Adjust slider to select number of samples"
    )
    load_all_samples = st.checkbox(
        "Load all available samples",
        value=False,
        help="Override slider and load entire dataset"
    )
    st.divider()

    # 2. Chunking strategy
    st.subheader("2. Chunking Strategy")
    chunking_strategy = st.selectbox(
        "Choose Chunking Strategy",
        settings.chunking_strategies,
        index=0
    )
    chunk_size = st.slider(
        "Chunk Size",
        min_value=256,
        max_value=1024,
        value=512,
        step=128
    )
    overlap = st.slider(
        "Overlap",
        min_value=0,
        max_value=200,
        value=50,
        step=10
    )
    st.divider()

    # 3. Embedding model
    st.subheader("3. Embedding Model")
    embedding_model = st.selectbox(
        "Choose Embedding Model",
        settings.embedding_models,
        index=0
    )
    st.divider()

    # 4. LLM used by the pipeline of the new collection
    st.subheader("4. LLM Model")
    llm_model = st.selectbox(
        "Choose LLM",
        settings.llm_models,
        index=1
    )
    st.divider()

    if st.button("🚀 Load Data & Create Collection", type="primary"):
        if not groq_api_key:
            st.error("Please enter your Groq API key")
        else:
            load_and_create_collection(
                groq_api_key,
                dataset_name,
                # None tells the loader to take every available sample.
                None if load_all_samples else num_samples,
                chunking_strategy,
                chunk_size,
                overlap,
                embedding_model,
                llm_model
            )


def _render_welcome():
    """Main area shown before a collection is loaded: instructions and options."""
    st.info("👈 Please configure and load a dataset from the sidebar to begin")
    with st.expander("📖 How to Use", expanded=True):
        st.markdown("""
1. **Enter your Groq API Key** in the sidebar
2. **Select a dataset** from RAG Bench
3. **Choose a chunking strategy** (dense, sparse, hybrid, re-ranking)
4. **Select an embedding model** for document vectorization
5. **Choose an LLM model** for response generation
6. **Click "Load Data & Create Collection"** to initialize
7. **Start chatting** in the chat interface
8. **View retrieved documents** and evaluation metrics
9. **Run TRACE evaluation** on test data
""")
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("📊 Available Datasets")
        for ds in settings.ragbench_datasets:
            st.markdown(f"- {ds}")
    with col2:
        st.subheader("🤖 Available Models")
        st.markdown("**Embedding Models:**")
        for em in settings.embedding_models:
            st.markdown(f"- {em}")
        st.markdown("**LLM Models:**")
        for lm in settings.llm_models:
            st.markdown(f"- {lm}")


def main():
    """Render the app: header, configuration sidebar, and main content area.

    The sidebar either loads an existing ChromaDB collection or builds a new
    one from a RAGBench dataset. Once a collection is loaded, the main area
    shows the Chat / Evaluation / History tabs. Before that, it shows usage
    instructions and the configured datasets and models.
    """
    st.title("🤖 RAG Capstone Project")
    st.markdown("### Retrieval-Augmented Generation with TRACE Evaluation")

    # Re-read the collection list on every run so newly created ones show up.
    available_collections = get_available_collections()
    st.session_state.available_collections = available_collections

    with st.sidebar:
        st.header("Configuration")
        groq_api_key = st.text_input(
            "Groq API Key",
            type="password",
            value=settings.groq_api_key or "",
            help="Enter your Groq API key"
        )
        st.divider()

        # Option 1: use an existing collection
        if available_collections:
            _render_existing_collection_picker(groq_api_key, available_collections)
            st.divider()

        # Option 2: create a new collection
        _render_new_collection_form(groq_api_key)

    if not st.session_state.collection_loaded:
        _render_welcome()
        return

    tab1, tab2, tab3 = st.tabs(["💬 Chat", "📊 Evaluation", "📜 History"])
    with tab1:
        chat_interface()
    with tab2:
        evaluation_interface()
    with tab3:
        history_interface()
def _infer_dataset_name(collection_name: str) -> Optional[str]:
    """Return the RAGBench dataset whose normalized name prefixes *collection_name*, if any."""
    matches = [
        name for name in settings.ragbench_datasets
        if collection_name.startswith(name.replace("-", "_").replace(".", "_") + "_")
    ]
    # Prefer the longest match so that "foo_bar" wins over "foo".
    return max(matches, key=len) if matches else None


def load_existing_collection(api_key: str, collection_name: str):
    """Open an existing ChromaDB collection and build a RAG pipeline over it.

    The pipeline uses the LLM stored in ``st.session_state.current_llm``; it can
    be changed later from the Chat tab. On success the vector store, pipeline,
    collection name and API key are saved in session state and the app reruns.
    On failure the error is shown in the UI.

    Args:
        api_key: Groq API key for the LLM client.
        collection_name: Name of the persisted collection to open.
    """
    with st.spinner(f"Loading collection '{collection_name}'..."):
        try:
            vector_store = ChromaDBManager(settings.chroma_persist_directory)
            vector_store.get_collection(collection_name)

            # A selectbox created inside this button handler would exist only
            # for this single run and always return its first option before
            # the rerun below. Use the LLM already chosen in session state.
            llm_name = st.session_state.current_llm

            st.info("Initializing LLM client...")
            llm_client = GroqLLMClient(
                api_key=api_key,
                model_name=llm_name,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )

            st.info("Creating RAG pipeline...")
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )

            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.selected_collection = collection_name
            st.session_state.groq_api_key = api_key

            # Collections created by this app are named
            # "<dataset>_<chunking>_<model>" with '-' and '.' replaced by '_'.
            # Recover the dataset so the Evaluation tab knows which test data
            # to load.
            dataset_name = _infer_dataset_name(collection_name)
            if dataset_name is not None:
                st.session_state.dataset_name = dataset_name
            elif "dataset_name" in st.session_state:
                # Don't evaluate against a previously loaded collection's dataset.
                del st.session_state["dataset_name"]
        except Exception as e:
            st.error(f"Error loading collection: {str(e)}")
            st.exception(e)
        else:
            st.success(f"✅ Collection '{collection_name}' loaded successfully!")
            st.rerun()
def load_and_create_collection(
    api_key: str,
    dataset_name: str,
    num_samples: Optional[int],
    chunking_strategy: str,
    chunk_size: int,
    overlap: int,
    embedding_model: str,
    llm_model: str
):
    """Load a RAGBench train split, index it into a new collection, and build a RAG pipeline.

    The collection is named "<dataset>_<chunking>_<embedding model>", with '-'
    and '.' replaced by '_'. An existing collection with that name is deleted
    and rebuilt. On success the pipeline, dataset, API key and chosen LLM are
    stored in session state and the app reruns. Errors are shown in the UI.

    Args:
        api_key: Groq API key for the LLM client.
        dataset_name: RAGBench subset to load.
        num_samples: Maximum number of samples to load; None loads all of them.
        chunking_strategy: Chunking strategy name passed to the vector store.
        chunk_size: Chunk size passed to the vector store.
        overlap: Chunk overlap passed to the vector store.
        embedding_model: Embedding model name (its last '/' segment is used in
            the collection name).
        llm_model: Groq model used by the pipeline.
    """
    with st.spinner("Loading dataset and creating collection..."):
        try:
            loader = RAGBenchLoader()
            if num_samples is None:
                st.info(f"Loading {dataset_name} dataset (all available samples)...")
            else:
                st.info(f"Loading {dataset_name} dataset ({num_samples} samples)...")
            # Load the split exactly once.
            dataset = loader.load_dataset(dataset_name, split="train", max_samples=num_samples)
            if not dataset:
                st.error("Failed to load dataset")
                return

            st.info("Initializing vector store...")
            vector_store = ChromaDBManager(settings.chroma_persist_directory)

            collection_name = f"{dataset_name}_{chunking_strategy}_{embedding_model.split('/')[-1]}"
            collection_name = collection_name.replace("-", "_").replace(".", "_")

            # Rebuild from scratch if a collection with this name already exists.
            if collection_name in vector_store.list_collections():
                st.warning(f"Collection '{collection_name}' already exists. Deleting and recreating...")
                vector_store.delete_collection(collection_name)
                st.info("Old collection deleted. Creating new one...")

            st.info(f"Creating collection with {chunking_strategy} chunking...")
            vector_store.load_dataset_into_collection(
                collection_name=collection_name,
                embedding_model_name=embedding_model,
                chunking_strategy=chunking_strategy,
                dataset_data=dataset,
                chunk_size=chunk_size,
                overlap=overlap
            )

            st.info("Initializing LLM client...")
            llm_client = GroqLLMClient(
                api_key=api_key,
                model_name=llm_model,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )

            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.dataset_name = dataset_name
            st.session_state.dataset = dataset
            # The Chat/Evaluation tabs read these to rebuild LLM clients and to
            # show the model that is actually in use.
            st.session_state.groq_api_key = api_key
            st.session_state.current_llm = llm_model
        except Exception as e:
            st.error(f"Error: {str(e)}")
            st.exception(e)
        else:
            st.success(f"✅ Collection '{collection_name}' created successfully!")
            st.rerun()
def _format_distance(distance) -> str:
    """Format a retrieval distance with 4 decimals, or return "N/A" when missing or non-numeric."""
    try:
        return f"{float(distance):.4f}"
    except (TypeError, ValueError):
        return "N/A"


def _render_retrieved_documents(documents, label_prefix: str, key_prefix: str):
    """Show retrieved documents in an expander; the prefixes keep widget labels and keys unique."""
    with st.expander("📄 Retrieved Documents"):
        for doc_idx, doc in enumerate(documents):
            st.markdown(
                f"**Document {doc_idx+1}** (Distance: {_format_distance(doc.get('distance'))})"
            )
            st.text_area(
                f"{label_prefix}{doc_idx}",
                value=doc["document"],
                height=100,
                key=f"{key_prefix}{doc_idx}",
                label_visibility="collapsed"
            )
            if doc.get("metadata"):
                st.caption(f"Metadata: {doc['metadata']}")


def chat_interface():
    """Chat tab: switch LLM, replay the conversation, and answer new questions.

    Each stored turn is re-rendered with its retrieved documents. A new
    question is answered through ``st.session_state.rag_pipeline``, appended to
    ``chat_history``, and followed by a rerun.
    """
    st.subheader("💬 Chat Interface")

    if not st.session_state.collection_loaded:
        st.warning("⚠️ No data loaded. Please use the configuration panel to load a dataset and create a collection.")
        st.info("""
Steps:
1. Select a dataset from the dropdown
2. Click "Load Data & Create Collection" button
3. Wait for the collection to be created
4. Then you can start chatting
""")
        return

    col1, col2, col3 = st.columns([2, 2, 1])
    with col1:
        st.info(f"📚 Collection: {st.session_state.current_collection}")
    with col2:
        selected_llm = st.selectbox(
            "Select LLM for chat:",
            settings.llm_models,
            index=settings.llm_models.index(st.session_state.current_llm),
            key="chat_llm_selector"
        )
        if selected_llm != st.session_state.current_llm:
            # Build the new client first so a failure does not leave the
            # recorded model name out of sync with the pipeline's client.
            llm_client = GroqLLMClient(
                api_key=st.session_state.get("groq_api_key", ""),
                model_name=selected_llm,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay
            )
            st.session_state.rag_pipeline.llm_client = llm_client
            st.session_state.current_llm = selected_llm
    with col3:
        if st.button("🗑️ Clear History"):
            st.session_state.chat_history = []
            st.session_state.rag_pipeline.clear_history()
            st.rerun()

    chat_container = st.container()
    with chat_container:
        for chat_idx, entry in enumerate(st.session_state.chat_history):
            with st.chat_message("user"):
                st.write(entry["query"])
            with st.chat_message("assistant"):
                st.write(entry["response"])
                _render_retrieved_documents(
                    entry["retrieved_documents"],
                    label_prefix=f"doc_{chat_idx}_",
                    key_prefix=f"doc_area_{chat_idx}_",
                )

    query = st.chat_input("Ask a question...")
    if not query:
        return

    if not st.session_state.rag_pipeline or not st.session_state.rag_pipeline.vector_store.current_collection:
        st.error("❌ No data loaded. Please load a dataset first using the configuration panel.")
        st.stop()

    with chat_container:
        with st.chat_message("user"):
            st.write(query)

    with st.spinner("Generating response..."):
        try:
            result = st.session_state.rag_pipeline.query(query)
        except Exception as e:
            st.error(f"❌ Error querying: {str(e)}")
            st.info("Please load a dataset and create a collection first.")
            st.stop()

    with chat_container:
        with st.chat_message("assistant"):
            st.write(result["response"])
            _render_retrieved_documents(
                result["retrieved_documents"],
                label_prefix="doc_current_",
                key_prefix="doc_current_area_",
            )

    st.session_state.chat_history.append(result)
    st.rerun()
def evaluation_interface():
    """Evaluation tab: choose an LLM and sample count, run TRACE, and show or download scores.

    Results are read from ``st.session_state.evaluation_results``. That mapping
    is expected to contain the keys "utilization", "relevance", "adherence",
    "completeness", "average" and "individual_scores".
    """
    st.subheader("📊 TRACE Evaluation")

    if not st.session_state.collection_loaded:
        st.warning("⚠️ No data loaded. Please load a collection first.")
        return

    selector_col, _ = st.columns([3, 1])
    with selector_col:
        selected_llm = st.selectbox(
            "Select LLM for evaluation:",
            settings.llm_models,
            index=settings.llm_models.index(st.session_state.current_llm),
            key="eval_llm_selector"
        )

    st.markdown("""
Run TRACE evaluation metrics on test data:
- **Utilization**: How well the system uses retrieved documents
- **Relevance**: Relevance of retrieved documents to the query
- **Adherence**: How well the response adheres to the retrieved context
- **Completeness**: How complete the response is in answering the query
""")

    num_test_samples = st.slider(
        "Number of test samples",
        min_value=5,
        max_value=50,
        value=10,
        step=5
    )

    if st.button("🔬 Run Evaluation", type="primary"):
        run_evaluation(num_test_samples, selected_llm)

    results = st.session_state.evaluation_results
    if not results:
        return

    st.success("✅ Evaluation Complete!")

    # (label, key in results) for each aggregate score, one column each.
    metric_specs = (
        ("📊 Utilization", "utilization"),
        ("🎯 Relevance", "relevance"),
        ("✅ Adherence", "adherence"),
        ("📝 Completeness", "completeness"),
        ("⭐ Average", "average"),
    )
    for column, (label, key) in zip(st.columns(len(metric_specs)), metric_specs):
        with column:
            st.metric(label, f"{results[key]:.3f}")

    with st.expander("📋 Detailed Results"):
        df = pd.DataFrame(results["individual_scores"])
        st.dataframe(df, use_container_width=True)

    results_json = json.dumps(results, indent=2)
    st.download_button(
        label="💾 Download Results (JSON)",
        data=results_json,
        file_name=f"trace_evaluation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
        mime="application/json"
    )
def run_evaluation(num_samples: int, selected_llm: Optional[str] = None):
    """Run TRACE evaluation on test samples of the loaded collection's dataset.

    Each test question is answered by the current RAG pipeline (5 retrieved
    documents). The responses, retrieved texts and reference answers are
    scored with ``TRACEEvaluator.evaluate_batch``, and the result is stored in
    ``st.session_state.evaluation_results``. Errors are shown in the UI.

    Args:
        num_samples: Number of test samples to request from the loader.
        selected_llm: Optional Groq model to use for this run only. When it
            differs from the chat LLM, it is swapped into the pipeline, and the
            original client is always restored afterwards, even on error.
    """
    dataset_name = st.session_state.get("dataset_name")
    if not dataset_name:
        st.error(
            "Error during evaluation: the source dataset of the loaded collection "
            "is unknown. Create the collection from the sidebar to evaluate it."
        )
        return

    pipeline = st.session_state.rag_pipeline
    original_llm = pipeline.llm_client

    with st.spinner(f"Running evaluation on {num_samples} samples..."):
        try:
            if selected_llm and selected_llm != st.session_state.current_llm:
                st.info(f"Switching to {selected_llm} for evaluation...")
                pipeline.llm_client = GroqLLMClient(
                    api_key=st.session_state.get("groq_api_key", ""),
                    model_name=selected_llm,
                    max_rpm=settings.groq_rpm_limit,
                    rate_limit_delay=settings.rate_limit_delay
                )

            loader = RAGBenchLoader()
            test_data = list(loader.get_test_data(dataset_name, num_samples))
            if not test_data:
                st.warning("No test samples available for evaluation.")
                return

            test_cases = []
            total = len(test_data)
            progress_bar = st.progress(0)
            for done, sample in enumerate(test_data, start=1):
                result = pipeline.query(
                    sample["question"],
                    n_results=5
                )
                test_cases.append({
                    "query": sample["question"],
                    "response": result["response"],
                    "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                    "ground_truth": sample.get("answer", "")
                })
                # Divide by the real sample count so progress stays within [0, 1].
                progress_bar.progress(done / total)

            evaluator = TRACEEvaluator()
            st.session_state.evaluation_results = evaluator.evaluate_batch(test_cases)
        except Exception as e:
            st.error(f"Error during evaluation: {str(e)}")
        finally:
            # Restore the chat LLM whether the run succeeded or failed.
            pipeline.llm_client = original_llm
def history_interface():
    """History tab: list past chat turns and offer them as a JSON download."""
    st.subheader("📜 Chat History")

    history = st.session_state.chat_history
    if not history:
        st.info("No chat history yet. Start a conversation in the Chat tab!")
        return

    _, export_col = st.columns([3, 1])
    with export_col:
        st.download_button(
            label="💾 Export History",
            data=json.dumps(history, indent=2),
            file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json"
        )

    for i, entry in enumerate(history):
        query = entry["query"]
        # Add an ellipsis only when the query was actually truncated.
        preview = query if len(query) <= 50 else f"{query[:50]}..."
        with st.expander(f"💬 Conversation {i+1}: {preview}"):
            st.markdown(f"**Query:** {query}")
            st.markdown(f"**Response:** {entry['response']}")
            st.markdown(f"**Timestamp:** {entry.get('timestamp', 'N/A')}")
            st.markdown("**Retrieved Documents:**")
            for j, doc in enumerate(entry["retrieved_documents"]):
                st.text_area(
                    f"Document {j+1}",
                    value=doc["document"],
                    height=100,
                    key=f"history_doc_{i}_{j}"
                )
# Render the app only when this file is executed as the entry script (as
# Streamlit does); importing the module elsewhere has no UI side effects.
if __name__ == "__main__":
    main()