Developer committed on
Commit
f380f3e
Β·
1 Parent(s): 81f53e2

Simple Hello World test app to debug HF Spaces

Browse files
Files changed (3) hide show
  1. Dockerfile +5 -18
  2. streamlit_app.py +45 -1485
  3. streamlit_app_backup.py +1502 -0
Dockerfile CHANGED
@@ -5,9 +5,7 @@ WORKDIR /app
5
 
6
  # Install system dependencies
7
  RUN apt-get update && apt-get install -y \
8
- build-essential \
9
  curl \
10
- git \
11
  && rm -rf /var/lib/apt/lists/*
12
 
13
  # Create a non-root user for HuggingFace Spaces
@@ -19,25 +17,15 @@ ENV HOME=/home/user \
19
  # Set working directory for user
20
  WORKDIR $HOME/app
21
 
22
- # Install torch CPU first (smaller download ~200MB vs 2GB)
23
- RUN pip install --no-cache-dir --upgrade pip && \
24
- pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu && \
25
- echo "Torch installed successfully"
26
-
27
  # Copy requirements and install Python dependencies
28
  COPY --chown=user requirements.txt .
29
- RUN pip install --no-cache-dir -r requirements.txt && \
30
- echo "All requirements installed successfully"
31
-
32
- # Pre-download the lightweight embedding model to avoid timeout at startup
33
- RUN python -c "from sentence_transformers import SentenceTransformer; print('Downloading model...'); SentenceTransformer('all-MiniLM-L6-v2'); print('Model downloaded!')"
34
 
35
  # Copy application files
36
  COPY --chown=user . .
37
 
38
- # Create directories for data (use /tmp for ephemeral storage on Spaces)
39
- RUN mkdir -p /tmp/chroma_db /tmp/data_cache
40
-
41
  # Set environment variables
42
  ENV PYTHONUNBUFFERED=1
43
  ENV SPACE_ID=1
@@ -45,6 +33,5 @@ ENV SPACE_ID=1
45
  # Expose port for Streamlit (HuggingFace Spaces uses 7860)
46
  EXPOSE 7860
47
 
48
- # No healthcheck - let HF Spaces handle it
49
- # Run Streamlit with logging
50
- CMD ["streamlit", "run", "streamlit_app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.enableCORS=false", "--server.enableXsrfProtection=false", "--logger.level=info"]
 
5
 
6
  # Install system dependencies
7
  RUN apt-get update && apt-get install -y \
 
8
  curl \
 
9
  && rm -rf /var/lib/apt/lists/*
10
 
11
  # Create a non-root user for HuggingFace Spaces
 
17
  # Set working directory for user
18
  WORKDIR $HOME/app
19
 
 
 
 
 
 
20
  # Copy requirements and install Python dependencies
21
  COPY --chown=user requirements.txt .
22
+ RUN pip install --no-cache-dir --upgrade pip && \
23
+ pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu && \
24
+ pip install --no-cache-dir -r requirements.txt
 
 
25
 
26
  # Copy application files
27
  COPY --chown=user . .
28
 
 
 
 
29
  # Set environment variables
30
  ENV PYTHONUNBUFFERED=1
31
  ENV SPACE_ID=1
 
33
  # Expose port for Streamlit (HuggingFace Spaces uses 7860)
34
  EXPOSE 7860
35
 
36
+ # Run Streamlit
37
+ CMD ["streamlit", "run", "streamlit_app.py", "--server.port=7860", "--server.address=0.0.0.0"]
 
streamlit_app.py CHANGED
@@ -1,1502 +1,62 @@
1
- """Streamlit chat interface for RAG application."""
2
  import streamlit as st
3
- import sys
4
- import os
5
- from datetime import datetime
6
- import json
7
- import pandas as pd
8
- from typing import Optional
9
- import warnings
10
-
11
- # Suppress warnings
12
- warnings.filterwarnings('ignore')
13
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
14
-
15
- # Add parent directory to path
16
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
17
-
18
- # Check if running on HuggingFace Spaces
19
- IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
20
-
21
- from config import settings
22
- from dataset_loader import RAGBenchLoader
23
- from vector_store import ChromaDBManager, create_vector_store
24
- try:
25
- from vector_store import QdrantManager, QDRANT_AVAILABLE
26
- except ImportError:
27
- QDRANT_AVAILABLE = False
28
- from llm_client import GroqLLMClient, OllamaLLMClient, RAGPipeline, create_llm_client
29
- from trace_evaluator import TRACEEvaluator
30
- from embedding_models import EmbeddingFactory
31
- from chunking_strategies import ChunkingFactory
32
 
33
-
34
- # Page configuration
35
  st.set_page_config(
36
- page_title="RAG Capstone Project",
37
  page_icon="πŸ€–",
38
  layout="wide"
39
  )
40
 
41
- # Initialize session state
42
- if "chat_history" not in st.session_state:
43
- st.session_state.chat_history = []
44
-
45
- if "rag_pipeline" not in st.session_state:
46
- st.session_state.rag_pipeline = None
47
-
48
- if "vector_store" not in st.session_state:
49
- st.session_state.vector_store = None
50
-
51
- if "collection_loaded" not in st.session_state:
52
- st.session_state.collection_loaded = False
53
-
54
- if "evaluation_results" not in st.session_state:
55
- st.session_state.evaluation_results = None
56
-
57
- if "dataset_size" not in st.session_state:
58
- st.session_state.dataset_size = 10000
59
-
60
- if "current_dataset" not in st.session_state:
61
- st.session_state.current_dataset = None
62
-
63
- if "current_llm" not in st.session_state:
64
- st.session_state.current_llm = settings.llm_models[1]
65
-
66
- if "selected_collection" not in st.session_state:
67
- st.session_state.selected_collection = None
68
-
69
- if "available_collections" not in st.session_state:
70
- st.session_state.available_collections = []
71
-
72
- if "dataset_name" not in st.session_state:
73
- st.session_state.dataset_name = None
74
-
75
- if "collection_name" not in st.session_state:
76
- st.session_state.collection_name = None
77
-
78
- if "embedding_model" not in st.session_state:
79
- st.session_state.embedding_model = None
80
-
81
- if "groq_api_key" not in st.session_state:
82
- st.session_state.groq_api_key = ""
83
-
84
- if "llm_provider" not in st.session_state:
85
- st.session_state.llm_provider = settings.llm_provider
86
-
87
- if "ollama_model" not in st.session_state:
88
- st.session_state.ollama_model = settings.ollama_model
89
 
90
- if "vector_store_provider" not in st.session_state:
91
- st.session_state.vector_store_provider = settings.vector_store_provider
92
 
93
- if "qdrant_url" not in st.session_state:
94
- st.session_state.qdrant_url = settings.qdrant_url
95
-
96
- if "qdrant_api_key" not in st.session_state:
97
- st.session_state.qdrant_api_key = settings.qdrant_api_key
98
-
99
-
100
- def get_available_collections(provider: str = None):
101
- """Get list of available collections from vector store."""
102
- provider = provider or st.session_state.get("vector_store_provider", "chroma")
103
-
104
- try:
105
- if provider == "qdrant" and QDRANT_AVAILABLE:
106
- qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url
107
- qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key
108
- if qdrant_url and qdrant_api_key:
109
- vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
110
- collections = vector_store.list_collections()
111
- return collections
112
- return []
113
- else:
114
- vector_store = ChromaDBManager(settings.chroma_persist_directory)
115
- collections = vector_store.list_collections()
116
- return collections
117
- except Exception as e:
118
- print(f"Error getting collections: {e}")
119
- return []
120
-
121
-
122
- def main():
123
- """Main Streamlit application."""
124
- st.title("πŸ€– RAG Capstone Project")
125
- st.markdown("### Retrieval-Augmented Generation with TRACE Evaluation")
126
-
127
- # Show HuggingFace Spaces notice
128
- if IS_HUGGINGFACE_SPACE:
129
- st.info("πŸ€— Running on Hugging Face Spaces - Using Groq API (cloud-based LLM)")
130
-
131
- # Get available collections at startup
132
- available_collections = get_available_collections()
133
- st.session_state.available_collections = available_collections
134
-
135
- # Sidebar for configuration
136
- with st.sidebar:
137
- st.header("Configuration")
138
-
139
- # LLM Provider Selection - Disable Ollama on HuggingFace Spaces
140
- st.subheader("πŸ”Œ LLM Provider")
141
-
142
- if IS_HUGGINGFACE_SPACE:
143
- # Force Groq on HuggingFace Spaces (Ollama not available)
144
- st.caption("☁️ **Groq API** (Ollama unavailable on Spaces)")
145
- llm_provider = "groq"
146
- st.session_state.llm_provider = "groq"
147
- else:
148
- llm_provider = st.radio(
149
- "Choose LLM Provider:",
150
- options=["groq", "ollama"],
151
- index=0 if st.session_state.llm_provider == "groq" else 1,
152
- format_func=lambda x: "☁️ Groq API (Cloud)" if x == "groq" else "πŸ–₯️ Ollama (Local)",
153
- help="Groq: Cloud API with rate limits. Ollama: Local unlimited inference.",
154
- key="llm_provider_radio"
155
- )
156
- st.session_state.llm_provider = llm_provider
157
-
158
- # Provider-specific settings
159
- if llm_provider == "groq":
160
- st.caption("⚠️ Free tier: 30 requests/min")
161
-
162
- # On HuggingFace Spaces, check for API key in secrets first
163
- default_api_key = os.environ.get("GROQ_API_KEY", "") or settings.groq_api_key or ""
164
-
165
- # API Key input
166
- groq_api_key = st.text_input(
167
- "Groq API Key",
168
- type="password",
169
- value=default_api_key,
170
- help="Enter your Groq API key (or set GROQ_API_KEY in Spaces secrets)"
171
- )
172
-
173
- if IS_HUGGINGFACE_SPACE and not groq_api_key:
174
- st.warning("πŸ’‘ Tip: Add GROQ_API_KEY to your Space secrets for persistence")
175
- else:
176
- # Ollama settings (only available locally)
177
- st.caption("βœ… No rate limits - unlimited usage!")
178
- ollama_host = st.text_input(
179
- "Ollama Host",
180
- value=settings.ollama_host,
181
- help="Ollama server URL (default: http://localhost:11434)"
182
- )
183
-
184
- ollama_model = st.selectbox(
185
- "Select Ollama Model:",
186
- options=settings.ollama_models,
187
- index=settings.ollama_models.index(st.session_state.ollama_model) if st.session_state.ollama_model in settings.ollama_models else 0,
188
- key="ollama_model_selector"
189
- )
190
- st.session_state.ollama_model = ollama_model
191
-
192
- # Connection check button
193
- if st.button("πŸ” Check Ollama Connection"):
194
- try:
195
- import requests
196
- response = requests.get(f"{ollama_host}/api/tags", timeout=5)
197
- if response.status_code == 200:
198
- models = response.json().get("models", [])
199
- model_names = [m["name"] for m in models]
200
- st.success(f"βœ… Connected! Available models: {', '.join(model_names)}")
201
- else:
202
- st.error(f"❌ Connection failed: {response.status_code}")
203
- except Exception as e:
204
- st.error(f"❌ Cannot connect to Ollama: {e}")
205
- st.info("Make sure Ollama is running: `ollama serve`")
206
-
207
- groq_api_key = "" # Not needed for Ollama
208
-
209
- st.divider()
210
-
211
- # Vector Store Provider Selection
212
- st.subheader("πŸ’Ύ Vector Store")
213
-
214
- if IS_HUGGINGFACE_SPACE:
215
- st.caption("☁️ Use **Qdrant Cloud** for persistent storage")
216
- vector_store_options = ["qdrant", "chroma"]
217
- default_idx = 0
218
- else:
219
- vector_store_options = ["chroma", "qdrant"]
220
- default_idx = 0
221
-
222
- vector_store_provider = st.radio(
223
- "Choose Vector Store:",
224
- options=vector_store_options,
225
- index=default_idx,
226
- format_func=lambda x: "☁️ Qdrant Cloud (Persistent)" if x == "qdrant" else "πŸ’Ύ ChromaDB (Local)",
227
- help="Qdrant: Cloud storage (persistent). ChromaDB: Local storage (ephemeral on Spaces).",
228
- key="vector_store_radio"
229
- )
230
- st.session_state.vector_store_provider = vector_store_provider
231
-
232
- # Qdrant settings
233
- if vector_store_provider == "qdrant":
234
- default_qdrant_url = os.environ.get("QDRANT_URL", "") or settings.qdrant_url
235
- default_qdrant_key = os.environ.get("QDRANT_API_KEY", "") or settings.qdrant_api_key
236
-
237
- qdrant_url = st.text_input(
238
- "Qdrant URL",
239
- value=default_qdrant_url,
240
- placeholder="https://xxx-xxx.aws.cloud.qdrant.io:6333",
241
- help="Your Qdrant Cloud cluster URL"
242
- )
243
- qdrant_api_key = st.text_input(
244
- "Qdrant API Key",
245
- type="password",
246
- value=default_qdrant_key,
247
- help="Your Qdrant API key"
248
- )
249
- st.session_state.qdrant_url = qdrant_url
250
- st.session_state.qdrant_api_key = qdrant_api_key
251
-
252
- if not qdrant_url or not qdrant_api_key:
253
- st.warning("⚠️ Get free Qdrant Cloud at: https://cloud.qdrant.io")
254
-
255
- # Test Qdrant connection
256
- if st.button("πŸ” Test Qdrant Connection"):
257
- if qdrant_url and qdrant_api_key:
258
- try:
259
- test_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
260
- collections = test_store.list_collections()
261
- st.success(f"βœ… Connected! Found {len(collections)} collection(s)")
262
- except Exception as e:
263
- st.error(f"❌ Connection failed: {e}")
264
- else:
265
- st.error("Please enter Qdrant URL and API Key")
266
-
267
- st.divider()
268
-
269
- # Get available collections based on provider
270
- available_collections = get_available_collections(vector_store_provider)
271
- st.session_state.available_collections = available_collections
272
-
273
- # Option 1: Use existing collection
274
- if available_collections:
275
- st.subheader("πŸ“š Existing Collections")
276
- st.write(f"Found {len(available_collections)} collection(s)")
277
-
278
- selected_collection = st.selectbox(
279
- "Or select existing collection:",
280
- available_collections,
281
- key="collection_selector"
282
- )
283
-
284
- if st.button("πŸ“– Load Existing Collection", type="secondary"):
285
- # Validate based on provider
286
- if llm_provider == "groq" and not groq_api_key:
287
- st.error("Please enter your Groq API key")
288
- elif vector_store_provider == "qdrant" and (not st.session_state.get("qdrant_url") or not st.session_state.get("qdrant_api_key")):
289
- st.error("Please enter Qdrant URL and API Key")
290
- else:
291
- load_existing_collection(
292
- groq_api_key,
293
- selected_collection,
294
- llm_provider,
295
- ollama_host if llm_provider == "ollama" else None,
296
- vector_store_provider
297
- )
298
-
299
- st.divider()
300
-
301
- # Option 2: Create new collection
302
- st.subheader("πŸ†• Create New Collection")
303
-
304
- # Dataset selection
305
- st.subheader("1. Dataset Selection")
306
- dataset_name = st.selectbox(
307
- "Choose Dataset",
308
- settings.ragbench_datasets,
309
- index=0
310
- )
311
-
312
- # Get dataset size dynamically
313
- if st.button("πŸ” Check Dataset Size", key="check_size"):
314
- with st.spinner("Checking dataset size..."):
315
- try:
316
- from datasets import load_dataset
317
-
318
- # Load dataset with download_mode to avoid cache issues
319
- st.info(f"Fetching dataset info for '{dataset_name}'...")
320
- ds = load_dataset(
321
- "rungalileo/ragbench",
322
- dataset_name,
323
- split="train",
324
- trust_remote_code=True,
325
- download_mode="force_redownload" # Force fresh download to avoid cache corruption
326
- )
327
- dataset_size = len(ds)
328
-
329
- st.session_state.dataset_size = dataset_size
330
- st.session_state.current_dataset = dataset_name
331
- st.success(f"βœ… Dataset '{dataset_name}' has {dataset_size:,} samples available")
332
- except Exception as e:
333
- st.error(f"❌ Error: {str(e)}")
334
- st.exception(e)
335
- st.warning(f"Could not determine dataset size. Using default of 10,000.")
336
- st.session_state.dataset_size = 10000
337
- st.session_state.current_dataset = dataset_name
338
-
339
- # Use stored dataset size or default
340
- max_samples_available = st.session_state.get('dataset_size', 10000)
341
-
342
- st.caption(f"Max available samples: {max_samples_available:,}")
343
-
344
- num_samples = st.slider(
345
- "Number of samples",
346
- min_value=10,
347
- max_value=max_samples_available,
348
- value=min(100, max_samples_available),
349
- step=50 if max_samples_available > 1000 else 10,
350
- help="Adjust slider to select number of samples"
351
- )
352
-
353
- load_all_samples = st.checkbox(
354
- "Load all available samples",
355
- value=False,
356
- help="Override slider and load entire dataset"
357
- )
358
-
359
- st.divider()
360
-
361
- # Chunking strategy
362
- st.subheader("2. Chunking Strategy")
363
- chunking_strategy = st.selectbox(
364
- "Choose Chunking Strategy",
365
- settings.chunking_strategies,
366
- index=0
367
- )
368
-
369
- chunk_size = st.slider(
370
- "Chunk Size",
371
- min_value=256,
372
- max_value=1024,
373
- value=512,
374
- step=128
375
- )
376
-
377
- overlap = st.slider(
378
- "Overlap",
379
- min_value=0,
380
- max_value=200,
381
- value=50,
382
- step=10
383
- )
384
-
385
- st.divider()
386
-
387
- # Embedding model
388
- st.subheader("3. Embedding Model")
389
- embedding_model = st.selectbox(
390
- "Choose Embedding Model",
391
- settings.embedding_models,
392
- index=0
393
- )
394
-
395
- st.divider()
396
-
397
- # LLM model selection for new collection
398
- st.subheader("4. LLM Model")
399
- if llm_provider == "groq":
400
- llm_model = st.selectbox(
401
- "Choose Groq LLM",
402
- settings.llm_models,
403
- index=1
404
- )
405
- else:
406
- llm_model = st.selectbox(
407
- "Choose Ollama Model",
408
- settings.ollama_models,
409
- index=settings.ollama_models.index(st.session_state.ollama_model) if st.session_state.ollama_model in settings.ollama_models else 0,
410
- key="llm_model_ollama"
411
- )
412
-
413
- st.divider()
414
-
415
- # Load data button
416
- if st.button("πŸš€ Load Data & Create Collection", type="primary"):
417
- # Validate based on provider
418
- if llm_provider == "groq" and not groq_api_key:
419
- st.error("Please enter your Groq API key")
420
- elif vector_store_provider == "qdrant" and (not st.session_state.get("qdrant_url") or not st.session_state.get("qdrant_api_key")):
421
- st.error("Please enter Qdrant URL and API Key")
422
- else:
423
- # Use None for num_samples if loading all data
424
- samples_to_load = None if load_all_samples else num_samples
425
- load_and_create_collection(
426
- groq_api_key,
427
- dataset_name,
428
- samples_to_load,
429
- chunking_strategy,
430
- chunk_size,
431
- overlap,
432
- embedding_model,
433
- llm_model,
434
- llm_provider,
435
- ollama_host if llm_provider == "ollama" else None,
436
- vector_store_provider
437
- )
438
-
439
- # Main content area
440
- if not st.session_state.collection_loaded:
441
- st.info("πŸ‘ˆ Please configure and load a dataset from the sidebar to begin")
442
-
443
- # Show instructions
444
- with st.expander("πŸ“– How to Use", expanded=True):
445
- st.markdown("""
446
- 1. **Enter your Groq API Key** in the sidebar
447
- 2. **Select a dataset** from RAG Bench
448
- 3. **Choose a chunking strategy** (dense, sparse, hybrid, re-ranking)
449
- 4. **Select an embedding model** for document vectorization
450
- 5. **Choose an LLM model** for response generation
451
- 6. **Click "Load Data & Create Collection"** to initialize
452
- 7. **Start chatting** in the chat interface
453
- 8. **View retrieved documents** and evaluation metrics
454
- 9. **Run TRACE evaluation** on test data
455
- """)
456
-
457
- # Show available options
458
- col1, col2 = st.columns(2)
459
-
460
- with col1:
461
- st.subheader("πŸ“Š Available Datasets")
462
- for ds in settings.ragbench_datasets:
463
- st.markdown(f"- {ds}")
464
-
465
- with col2:
466
- st.subheader("πŸ€– Available Models")
467
- st.markdown("**Embedding Models:**")
468
- for em in settings.embedding_models:
469
- st.markdown(f"- {em}")
470
-
471
- st.markdown("**LLM Models:**")
472
- for lm in settings.llm_models:
473
- st.markdown(f"- {lm}")
474
-
475
- else:
476
- # Create tabs for different functionalities
477
- tab1, tab2, tab3 = st.tabs(["πŸ’¬ Chat", "πŸ“Š Evaluation", "πŸ“œ History"])
478
-
479
- with tab1:
480
- chat_interface()
481
-
482
- with tab2:
483
- evaluation_interface()
484
-
485
- with tab3:
486
- history_interface()
487
-
488
-
489
- def load_existing_collection(api_key: str, collection_name: str, llm_provider: str = "groq", ollama_host: str = None, vector_store_provider: str = "chroma"):
490
- """Load an existing collection from vector store."""
491
- with st.spinner(f"Loading collection '{collection_name}'..."):
492
- try:
493
- # Initialize vector store based on provider
494
- if vector_store_provider == "qdrant":
495
- qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url
496
- qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key
497
- vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
498
- else:
499
- vector_store = ChromaDBManager(settings.chroma_persist_directory)
500
-
501
- vector_store.get_collection(collection_name)
502
-
503
- # Extract dataset name from collection name (format: dataset_name_strategy_model)
504
- # Try to find which dataset this collection is based on
505
- dataset_name = None
506
- for ds in settings.ragbench_datasets:
507
- if collection_name.startswith(ds.replace("-", "_")):
508
- dataset_name = ds
509
- break
510
-
511
- if not dataset_name:
512
- dataset_name = collection_name.split("_")[0] # Fallback: use first part
513
-
514
- # Prompt for LLM selection based on provider
515
- if llm_provider == "groq":
516
- st.session_state.current_llm = st.selectbox(
517
- "Select Groq LLM for this collection:",
518
- settings.llm_models,
519
- key=f"llm_selector_{collection_name}"
520
- )
521
- else:
522
- st.session_state.current_llm = st.selectbox(
523
- "Select Ollama Model for this collection:",
524
- settings.ollama_models,
525
- key=f"ollama_selector_{collection_name}"
526
- )
527
-
528
- # Initialize LLM client based on provider
529
- st.info(f"Initializing LLM client ({llm_provider})...")
530
- llm_client = create_llm_client(
531
- provider=llm_provider,
532
- api_key=api_key,
533
- api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
534
- model_name=st.session_state.current_llm,
535
- ollama_host=ollama_host or settings.ollama_host,
536
- max_rpm=settings.groq_rpm_limit,
537
- rate_limit_delay=settings.rate_limit_delay,
538
- max_retries=settings.max_retries,
539
- retry_delay=settings.retry_delay
540
- )
541
-
542
- # Create RAG pipeline with correct parameter names
543
- st.info("Creating RAG pipeline...")
544
- rag_pipeline = RAGPipeline(
545
- llm_client=llm_client,
546
- vector_store_manager=vector_store
547
- )
548
-
549
- # Store in session state
550
- st.session_state.vector_store = vector_store
551
- st.session_state.rag_pipeline = rag_pipeline
552
- st.session_state.collection_loaded = True
553
- st.session_state.current_collection = collection_name
554
- st.session_state.selected_collection = collection_name
555
- st.session_state.groq_api_key = api_key
556
- st.session_state.dataset_name = dataset_name
557
- st.session_state.collection_name = collection_name
558
- st.session_state.llm_provider = llm_provider
559
-
560
- # Display system prompt and model info
561
- provider_icon = "☁️" if llm_provider == "groq" else "πŸ–₯️"
562
- st.success(f"βœ… Collection '{collection_name}' loaded successfully! {provider_icon} Using {llm_provider.upper()}")
563
-
564
- with st.expander("πŸ€– Model & System Prompt Information", expanded=False):
565
- col1, col2 = st.columns(2)
566
- with col1:
567
- st.write(f"**Provider:** {provider_icon} {llm_provider.upper()}")
568
- st.write(f"**Model:** {st.session_state.current_llm}")
569
- st.write(f"**Collection:** {collection_name}")
570
- st.write(f"**Dataset:** {dataset_name}")
571
- with col2:
572
- st.write(f"**Temperature:** 0.0")
573
- st.write(f"**Max Tokens:** 2048")
574
- if llm_provider == "groq":
575
- st.write(f"**Rate Limit:** {settings.groq_rpm_limit} RPM")
576
- else:
577
- st.write(f"**Rate Limit:** βœ… Unlimited (Local)")
578
-
579
- st.markdown("#### System Prompt")
580
- st.info("""
581
- You are a Fact-Checking and Citation Specialist. Your task is to perform a rigorous audit of a response against provided documents to determine its accuracy, relevance, and level of support.
582
-
583
- **Task:**
584
- 1. Analyze the provided documents and identify information relevant to the user's question
585
- 2. Evaluate the response sentence-by-sentence
586
- 3. Verify each response sentence maps to supporting document sentences
587
- 4. Identify which document sentences were actually used in the response
588
- """)
589
-
590
- st.rerun()
591
-
592
- except Exception as e:
593
- st.error(f"Error loading collection: {str(e)}")
594
- st.exception(e)
595
-
596
-
597
- def load_and_create_collection(
598
- api_key: str,
599
- dataset_name: str,
600
- num_samples: Optional[int],
601
- chunking_strategy: str,
602
- chunk_size: int,
603
- overlap: int,
604
- embedding_model: str,
605
- llm_model: str,
606
- llm_provider: str = "groq",
607
- ollama_host: str = None,
608
- vector_store_provider: str = "chroma"
609
- ):
610
- """Load dataset and create vector collection."""
611
- with st.spinner("Loading dataset and creating collection..."):
612
- try:
613
- # Initialize dataset loader
614
- loader = RAGBenchLoader()
615
-
616
- # Load dataset
617
- if num_samples is None:
618
- st.info(f"Loading {dataset_name} dataset (all available samples)...")
619
- else:
620
- st.info(f"Loading {dataset_name} dataset ({num_samples} samples)...")
621
- dataset = loader.load_dataset(dataset_name, split="train", max_samples=num_samples)
622
-
623
- if not dataset:
624
- st.error("Failed to load dataset")
625
- return
626
-
627
- # Initialize vector store based on provider
628
- st.info(f"Initializing vector store ({vector_store_provider})...")
629
- if vector_store_provider == "qdrant":
630
- qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url
631
- qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key
632
- vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
633
- else:
634
- vector_store = ChromaDBManager(settings.chroma_persist_directory)
635
-
636
- # Create collection name
637
- collection_name = f"{dataset_name}_{chunking_strategy}_{embedding_model.split('/')[-1]}"
638
- collection_name = collection_name.replace("-", "_").replace(".", "_")
639
-
640
- # Delete existing collection with same name (if exists)
641
- existing_collections = vector_store.list_collections()
642
- if collection_name in existing_collections:
643
- st.warning(f"Collection '{collection_name}' already exists. Deleting and recreating...")
644
- vector_store.delete_collection(collection_name)
645
- st.info("Old collection deleted. Creating new one...")
646
-
647
- # Load data into collection
648
- st.info(f"Creating collection with {chunking_strategy} chunking...")
649
- vector_store.load_dataset_into_collection(
650
- collection_name=collection_name,
651
- embedding_model_name=embedding_model,
652
- chunking_strategy=chunking_strategy,
653
- dataset_data=dataset,
654
- chunk_size=chunk_size,
655
- overlap=overlap
656
- )
657
-
658
- # Initialize LLM client based on provider
659
- st.info(f"Initializing LLM client ({llm_provider})...")
660
- llm_client = create_llm_client(
661
- provider=llm_provider,
662
- api_key=api_key,
663
- api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
664
- model_name=llm_model,
665
- ollama_host=ollama_host or settings.ollama_host,
666
- max_rpm=settings.groq_rpm_limit,
667
- rate_limit_delay=settings.rate_limit_delay,
668
- max_retries=settings.max_retries,
669
- retry_delay=settings.retry_delay
670
- )
671
-
672
- # Create RAG pipeline with correct parameter names
673
- rag_pipeline = RAGPipeline(
674
- llm_client=llm_client,
675
- vector_store_manager=vector_store
676
- )
677
-
678
- # Store in session state
679
- st.session_state.vector_store = vector_store
680
- st.session_state.rag_pipeline = rag_pipeline
681
- st.session_state.collection_loaded = True
682
- st.session_state.current_collection = collection_name
683
- st.session_state.dataset_name = dataset_name
684
- st.session_state.dataset = dataset
685
- st.session_state.collection_name = collection_name
686
- st.session_state.embedding_model = embedding_model
687
- st.session_state.groq_api_key = api_key
688
- st.session_state.llm_provider = llm_provider
689
- st.session_state.vector_store_provider = vector_store_provider
690
-
691
- provider_icon = "☁️" if llm_provider == "groq" else "πŸ–₯️"
692
- vs_icon = "☁️" if vector_store_provider == "qdrant" else "πŸ’Ύ"
693
- st.success(f"βœ… Collection '{collection_name}' created successfully! {provider_icon} Using {llm_provider.upper()}")
694
- st.rerun()
695
-
696
- except Exception as e:
697
- st.error(f"Error: {str(e)}")
698
-
699
-
700
- def chat_interface():
701
- """Chat interface tab."""
702
- st.subheader("πŸ’¬ Chat Interface")
703
-
704
- # Check if collection is loaded
705
- if not st.session_state.collection_loaded:
706
- st.warning("⚠️ No data loaded. Please use the configuration panel to load a dataset and create a collection.")
707
- st.info("""
708
- Steps:
709
- 1. Select a dataset from the dropdown
710
- 2. Click "Load Data & Create Collection" button
711
- 3. Wait for the collection to be created
712
- 4. Then you can start chatting
713
- """)
714
- return
715
-
716
- # Display collection info and LLM selector
717
- col1, col2, col3 = st.columns([2, 2, 1])
718
- with col1:
719
- provider_icon = "☁️" if st.session_state.get("llm_provider", "groq") == "groq" else "πŸ–₯️"
720
- st.info(f"πŸ“š Collection: {st.session_state.current_collection} | {provider_icon} {st.session_state.get('llm_provider', 'groq').upper()}")
721
-
722
- with col2:
723
- # LLM selector for chat - based on provider
724
- current_provider = st.session_state.get("llm_provider", "groq")
725
- if current_provider == "groq":
726
- model_options = settings.llm_models
727
- try:
728
- current_index = settings.llm_models.index(st.session_state.current_llm)
729
- except ValueError:
730
- current_index = 0
731
- else:
732
- model_options = settings.ollama_models
733
- try:
734
- current_index = settings.ollama_models.index(st.session_state.current_llm)
735
- except ValueError:
736
- current_index = 0
737
-
738
- selected_llm = st.selectbox(
739
- f"Select {'Groq' if current_provider == 'groq' else 'Ollama'} Model for chat:",
740
- model_options,
741
- index=current_index,
742
- key="chat_llm_selector"
743
- )
744
-
745
- if selected_llm != st.session_state.current_llm:
746
- st.session_state.current_llm = selected_llm
747
- # Recreate LLM client with new model
748
- llm_client = create_llm_client(
749
- provider=current_provider,
750
- api_key=st.session_state.groq_api_key if "groq_api_key" in st.session_state else "",
751
- api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
752
- model_name=selected_llm,
753
- ollama_host=settings.ollama_host,
754
- max_rpm=settings.groq_rpm_limit,
755
- rate_limit_delay=settings.rate_limit_delay
756
- )
757
- st.session_state.rag_pipeline.llm = llm_client
758
-
759
- with col3:
760
- if st.button("πŸ—‘οΈ Clear History"):
761
- st.session_state.chat_history = []
762
- st.session_state.rag_pipeline.clear_history()
763
- st.rerun()
764
-
765
- # Show system prompt info in expandable section
766
- with st.expander("πŸ€– System Prompt & Model Info", expanded=False):
767
- current_provider = st.session_state.get("llm_provider", "groq")
768
- col1, col2 = st.columns(2)
769
- with col1:
770
- provider_icon = "☁️" if current_provider == "groq" else "πŸ–₯️"
771
- st.write(f"**Provider:** {provider_icon} {current_provider.upper()}")
772
- st.write(f"**LLM Model:** {st.session_state.current_llm}")
773
- st.write(f"**Temperature:** 0.0")
774
- st.write(f"**Max Tokens:** 2048")
775
- with col2:
776
- st.write(f"**Collection:** {st.session_state.current_collection}")
777
- st.write(f"**Dataset:** {st.session_state.get('dataset_name', 'N/A')}")
778
- if current_provider == "groq":
779
- st.write(f"**Rate Limit:** {settings.groq_rpm_limit} RPM")
780
- else:
781
- st.write(f"**Rate Limit:** βœ… Unlimited (Local)")
782
-
783
- st.markdown("#### System Prompt Being Used")
784
- system_prompt = """You are a Fact-Checking and Citation Specialist. Your task is to perform a rigorous audit of a response against provided documents to determine its accuracy, relevance, and level of support.
785
-
786
- **TASK OVERVIEW**
787
- 1. **Analyze Documents**: Review the provided documents and identify information relevant to the user's question.
788
- 2. **Evaluate Response**: Review the provided answer sentence-by-sentence.
789
- 3. **Verify Support**: Map each answer sentence to specific supporting sentences in the documents.
790
- 4. **Identify Utilization**: Determine which document sentences were actually used (directly or implicitly) to form the answer."""
791
- st.info(system_prompt)
792
-
793
- # Chat container
794
- chat_container = st.container()
795
-
796
- # Display chat history
797
- with chat_container:
798
- for chat_idx, entry in enumerate(st.session_state.chat_history):
799
- # User message
800
- with st.chat_message("user"):
801
- st.write(entry["query"])
802
-
803
- # Assistant message
804
- with st.chat_message("assistant"):
805
- st.write(entry["response"])
806
-
807
- # Show retrieved documents in expander
808
- with st.expander("πŸ“„ Retrieved Documents"):
809
- for doc_idx, doc in enumerate(entry["retrieved_documents"]):
810
- st.markdown(f"**Document {doc_idx+1}** (Distance: {doc.get('distance', 'N/A'):.4f})")
811
- st.text_area(
812
- f"doc_{chat_idx}_{doc_idx}",
813
- value=doc["document"],
814
- height=100,
815
- key=f"doc_area_{chat_idx}_{doc_idx}",
816
- label_visibility="collapsed"
817
- )
818
- if doc.get("metadata"):
819
- st.caption(f"Metadata: {doc['metadata']}")
820
-
821
- # Chat input
822
- query = st.chat_input("Ask a question...")
823
-
824
- if query:
825
- # Check if collection exists
826
- if not st.session_state.rag_pipeline or not st.session_state.rag_pipeline.vector_store.current_collection:
827
- st.error("❌ No data loaded. Please load a dataset first using the configuration panel.")
828
- st.stop()
829
-
830
- # Add user message
831
- with chat_container:
832
- with st.chat_message("user"):
833
- st.write(query)
834
-
835
- # Generate response
836
- with st.spinner("Generating response..."):
837
- try:
838
- result = st.session_state.rag_pipeline.query(query)
839
- except Exception as e:
840
- st.error(f"❌ Error querying: {str(e)}")
841
- st.info("Please load a dataset and create a collection first.")
842
- st.stop()
843
-
844
- # Add assistant message
845
- with chat_container:
846
- with st.chat_message("assistant"):
847
- st.write(result["response"])
848
-
849
- # Show retrieved documents
850
- with st.expander("πŸ“„ Retrieved Documents"):
851
- for doc_idx, doc in enumerate(result["retrieved_documents"]):
852
- st.markdown(f"**Document {doc_idx+1}** (Distance: {doc.get('distance', 'N/A'):.4f})")
853
- st.text_area(
854
- f"doc_current_{doc_idx}",
855
- value=doc["document"],
856
- height=100,
857
- key=f"doc_current_area_{doc_idx}",
858
- label_visibility="collapsed"
859
- )
860
- if doc.get("metadata"):
861
- st.caption(f"Metadata: {doc['metadata']}")
862
-
863
- # Store in history
864
- st.session_state.chat_history.append(result)
865
- st.rerun()
866
 
 
 
 
 
 
 
867
 
868
- def evaluation_interface():
869
- """Evaluation interface tab."""
870
- st.subheader("πŸ“Š RAG Evaluation")
871
-
872
- # Check if collection is loaded
873
- if not st.session_state.collection_loaded:
874
- st.warning("⚠️ No data loaded. Please load a collection first.")
875
- return
876
-
877
- # Evaluation method selector
878
- eval_col1, eval_col2 = st.columns([2, 1])
879
- with eval_col1:
880
- evaluation_method = st.radio(
881
- "Evaluation Method:",
882
- options=["TRACE (Heuristic)", "GPT Labeling (LLM-based)", "Hybrid (Both)"],
883
- horizontal=True,
884
- help="TRACE is fast (no LLM). GPT Labeling is accurate but slower (requires LLM calls)."
885
- )
886
-
887
- # Map UI labels to method IDs
888
- method_map = {
889
- "TRACE (Heuristic)": "trace",
890
- "GPT Labeling (LLM-based)": "gpt_labeling",
891
- "Hybrid (Both)": "hybrid"
892
- }
893
- selected_method = method_map[evaluation_method]
894
-
895
- # LLM selector for evaluation
896
- current_provider = st.session_state.get("llm_provider", "groq")
897
- col1, col2 = st.columns([3, 1])
898
- with col1:
899
- # Show provider-specific models
900
- if current_provider == "groq":
901
- model_options = settings.llm_models
902
- try:
903
- current_index = settings.llm_models.index(st.session_state.current_llm)
904
- except ValueError:
905
- current_index = 0
906
- else:
907
- model_options = settings.ollama_models
908
- try:
909
- current_index = settings.ollama_models.index(st.session_state.current_llm)
910
- except ValueError:
911
- current_index = 0
912
-
913
- selected_llm = st.selectbox(
914
- f"Select {'Groq' if current_provider == 'groq' else 'Ollama'} Model for evaluation:",
915
- model_options,
916
- index=current_index,
917
- key="eval_llm_selector"
918
- )
919
-
920
- # Show provider info
921
- provider_icon = "☁️" if current_provider == "groq" else "πŸ–₯️"
922
- if current_provider == "ollama":
923
- st.caption(f"{provider_icon} Using local Ollama - **No rate limits!** Fast evaluation possible.")
924
- else:
925
- st.caption(f"{provider_icon} Using Groq API - Rate limited to {settings.groq_rpm_limit} RPM")
926
-
927
- # Show method description
928
- method_descriptions = {
929
- "trace": """
930
- **TRACE Heuristic Method** (Fast, Rule-Based)
931
- - Utilization: How well the system uses retrieved documents
932
- - Relevance: Relevance of retrieved documents to the query
933
- - Adherence: How well the response adheres to the retrieved context
934
- - Completeness: How complete the response is in answering the query
935
- - ⚑ Speed: ~100ms per evaluation
936
- - πŸ’° Cost: Free (no API calls)
937
- """,
938
- "gpt_labeling": """
939
- **GPT Labeling Method** (Accurate, LLM-based)
940
- - Uses sentence-level LLM analysis (from RAGBench paper)
941
- - Context Relevance: Fraction of context relevant to query
942
- - Context Utilization: Fraction of relevant context used
943
- - Completeness: Fraction of relevant info covered
944
- - Adherence: Response supported by context (no hallucinations)
945
- - ⏱️ Speed: ~2-5 seconds per evaluation
946
- - πŸ’° Cost: ~$0.002-0.01 per evaluation
947
- """,
948
- "hybrid": """
949
- **Hybrid Method** (Comprehensive)
950
- - Runs both TRACE and GPT Labeling methods
951
- - Provides both fast and accurate evaluation metrics
952
- - Best for detailed analysis
953
- - ⏱️ Speed: ~3-6 seconds per evaluation
954
- - πŸ’° Cost: Same as GPT Labeling
955
- """
956
- }
957
-
958
- st.markdown(method_descriptions[selected_method])
959
-
960
- # Get maximum test samples available for current dataset
961
- try:
962
- loader = RAGBenchLoader()
963
- max_test_samples = loader.get_test_data_size(st.session_state.dataset_name)
964
- st.caption(f"πŸ“Š Available test samples: {max_test_samples:,}")
965
- except Exception as e:
966
- max_test_samples = 100
967
- st.caption(f"Available test samples: ~{max_test_samples} (estimated)")
968
-
969
- # Ensure min and max are reasonable
970
- max_test_samples = max(5, min(max_test_samples, 500)) # Cap at 500 for performance
971
-
972
- num_test_samples = st.slider(
973
- "Number of test samples",
974
- min_value=5,
975
- max_value=max_test_samples,
976
- value=min(10, max_test_samples),
977
- step=5
978
- )
979
-
980
- # Show warning for GPT labeling (API cost) - only for Groq
981
- if selected_method in ["gpt_labeling", "hybrid"]:
982
- current_provider = st.session_state.get("llm_provider", "groq")
983
- if current_provider == "groq":
984
- st.warning(f"⚠️ **{evaluation_method}** requires LLM API calls. This will incur costs and be slower due to rate limiting ({settings.groq_rpm_limit} RPM).")
985
- else:
986
- st.info(f"ℹ️ **{evaluation_method}** using local Ollama - **No rate limits!** Evaluation will be much faster.")
987
-
988
- if st.button("πŸ”¬ Run Evaluation", type="primary"):
989
- # Use selected LLM for evaluation
990
- run_evaluation(num_test_samples, selected_llm, selected_method)
991
-
992
- # Display results
993
- if st.session_state.evaluation_results:
994
- results = st.session_state.evaluation_results
995
-
996
- st.success("βœ… Evaluation Complete!")
997
- st.divider()
998
- st.markdown("## πŸ“Š Evaluation Metrics")
999
-
1000
- # Display aggregate scores - handle both TRACE and GPT Labeling metric names
1001
- st.markdown("### Main Metrics")
1002
- col1, col2, col3, col4, col5 = st.columns(5)
1003
-
1004
- # Determine which metrics are available
1005
- utilization = results.get('utilization') or results.get('context_utilization', 0)
1006
- relevance = results.get('relevance') or results.get('context_relevance', 0)
1007
- adherence = results.get('adherence', 0)
1008
- completeness = results.get('completeness', 0)
1009
- average = results.get('average', 0)
1010
-
1011
- with col1:
1012
- st.metric("πŸ“Š Utilization", f"{utilization:.3f}")
1013
- with col2:
1014
- st.metric("🎯 Relevance", f"{relevance:.3f}")
1015
- with col3:
1016
- st.metric("βœ… Adherence", f"{adherence:.3f}")
1017
- with col4:
1018
- st.metric("πŸ“ Completeness", f"{completeness:.3f}")
1019
- with col5:
1020
- st.metric("⭐ Average", f"{average:.3f}")
1021
-
1022
- # Detailed results summary - handle both metric types
1023
- if "individual_scores" in results:
1024
- with st.expander("πŸ“‹ Summary Metrics by Query"):
1025
- df = pd.DataFrame(results["individual_scores"])
1026
- st.dataframe(df, use_container_width=True)
1027
-
1028
- # Detailed per-query results
1029
- if "detailed_results" in results and results["detailed_results"]:
1030
- with st.expander("πŸ” Detailed Per-Query Analysis"):
1031
- for query_result in results.get("detailed_results", []):
1032
- with st.expander(f"Query {query_result['query_id']}: {query_result['question'][:60]}..."):
1033
- st.markdown("### Question")
1034
- st.write(query_result['question'])
1035
-
1036
- st.markdown("### LLM Response")
1037
- st.write(query_result.get('llm_response', 'N/A'))
1038
-
1039
- st.markdown("### Retrieved Documents")
1040
- for doc_idx, doc in enumerate(query_result.get('retrieved_documents', []), 1):
1041
- with st.expander(f"πŸ“„ Document {doc_idx}"):
1042
- st.write(doc)
1043
-
1044
- if query_result.get('ground_truth'):
1045
- st.markdown("### Ground Truth")
1046
- st.write(query_result['ground_truth'])
1047
-
1048
- # Display metrics with correct labels based on method
1049
- metrics = query_result.get('metrics', {})
1050
- if metrics:
1051
- st.markdown("### Evaluation Metrics")
1052
- col1, col2, col3, col4, col5 = st.columns(5)
1053
-
1054
- # Get metric values (handle both TRACE and GPT names)
1055
- util_val = metrics.get('utilization') or metrics.get('context_utilization', 0)
1056
- rel_val = metrics.get('relevance') or metrics.get('context_relevance', 0)
1057
- adh_val = metrics.get('adherence', 0)
1058
- comp_val = metrics.get('completeness', 0)
1059
- avg_val = metrics.get('average', 0)
1060
-
1061
- with col1:
1062
- st.metric("Util", f"{util_val:.3f}")
1063
- with col2:
1064
- st.metric("Rel", f"{rel_val:.3f}")
1065
- with col3:
1066
- st.metric("Adh", f"{adh_val:.3f}")
1067
- with col4:
1068
- st.metric("Comp", f"{comp_val:.3f}")
1069
- with col5:
1070
- st.metric("Avg", f"{avg_val:.3f}")
1071
-
1072
- # For GPT Labeling and Hybrid methods, show additional metrics
1073
- method = results.get("method", "")
1074
- if "gpt_labeling" in method or "hybrid" in method:
1075
- # Show RMSE aggregation metrics (consistency across evaluations)
1076
- if "rmse_metrics" in results:
1077
- st.markdown("### πŸ“Š RMSE Aggregation (Metric Consistency)")
1078
- rmse_data = results.get("rmse_metrics", {})
1079
-
1080
- rmse_cols = st.columns(4)
1081
- with rmse_cols[0]:
1082
- rel_mean = rmse_data.get("context_relevance", {}).get("mean", 0)
1083
- rel_std = rmse_data.get("context_relevance", {}).get("std_dev", 0)
1084
- st.metric("Relevance", f"{rel_mean:.3f} Β±{rel_std:.3f}", help="Mean and Std Dev")
1085
- with rmse_cols[1]:
1086
- util_mean = rmse_data.get("context_utilization", {}).get("mean", 0)
1087
- util_std = rmse_data.get("context_utilization", {}).get("std_dev", 0)
1088
- st.metric("Utilization", f"{util_mean:.3f} Β±{util_std:.3f}", help="Mean and Std Dev")
1089
- with rmse_cols[2]:
1090
- comp_mean = rmse_data.get("completeness", {}).get("mean", 0)
1091
- comp_std = rmse_data.get("completeness", {}).get("std_dev", 0)
1092
- st.metric("Completeness", f"{comp_mean:.3f} Β±{comp_std:.3f}", help="Mean and Std Dev")
1093
- with rmse_cols[3]:
1094
- adh_mean = rmse_data.get("adherence", {}).get("mean", 0)
1095
- adh_std = rmse_data.get("adherence", {}).get("std_dev", 0)
1096
- st.metric("Adherence", f"{adh_mean:.3f} Β±{adh_std:.3f}", help="Mean and Std Dev")
1097
-
1098
- # Show detailed RMSE statistics in expander
1099
- with st.expander("See detailed RMSE aggregation statistics"):
1100
- for metric_name, metric_data in rmse_data.items():
1101
- st.write(f"**{metric_name}**")
1102
- col1, col2, col3, col4 = st.columns(4)
1103
- with col1:
1104
- st.write(f"Mean: {metric_data.get('mean', 0):.4f}")
1105
- with col2:
1106
- st.write(f"Std Dev: {metric_data.get('std_dev', 0):.4f}")
1107
- with col3:
1108
- st.write(f"Min: {metric_data.get('min', 0):.4f}")
1109
- with col4:
1110
- st.write(f"Max: {metric_data.get('max', 0):.4f}")
1111
-
1112
- # Show per-metric statistics if available
1113
- if "per_metric_statistics" in results:
1114
- st.markdown("### πŸ“ˆ Per-Metric Statistics (Distribution)")
1115
- stats_data = results.get("per_metric_statistics", {})
1116
-
1117
- stats_cols = st.columns(4)
1118
- with stats_cols[0]:
1119
- rel_stats = stats_data.get("context_relevance", {})
1120
- st.metric("Relevance Mean", f"{rel_stats.get('mean', 0):.3f}", help=f"Median: {rel_stats.get('median', 0):.3f}")
1121
- with stats_cols[1]:
1122
- util_stats = stats_data.get("context_utilization", {})
1123
- st.metric("Utilization Mean", f"{util_stats.get('mean', 0):.3f}", help=f"Median: {util_stats.get('median', 0):.3f}")
1124
- with stats_cols[2]:
1125
- comp_stats = stats_data.get("completeness", {})
1126
- st.metric("Completeness Mean", f"{comp_stats.get('mean', 0):.3f}", help=f"Median: {comp_stats.get('median', 0):.3f}")
1127
- with stats_cols[3]:
1128
- adh_stats = stats_data.get("adherence", {})
1129
- st.metric("Adherence Mean", f"{adh_stats.get('mean', 0):.3f}", help=f"Median: {adh_stats.get('median', 0):.3f}")
1130
-
1131
- # Show detailed statistics
1132
- with st.expander("See detailed per-metric statistics"):
1133
- for metric_name, metric_stats in stats_data.items():
1134
- st.write(f"**{metric_name}**")
1135
- col1, col2 = st.columns(2)
1136
- with col1:
1137
- st.write(f"""
1138
- - Mean: {metric_stats.get('mean', 0):.4f}
1139
- - Median: {metric_stats.get('median', 0):.4f}
1140
- - Std Dev: {metric_stats.get('std_dev', 0):.4f}
1141
- - Min: {metric_stats.get('min', 0):.4f}
1142
- - Max: {metric_stats.get('max', 0):.4f}
1143
- """)
1144
- with col2:
1145
- st.write(f"""
1146
- - 25th percentile: {metric_stats.get('percentile_25', 0):.4f}
1147
- - 75th percentile: {metric_stats.get('percentile_75', 0):.4f}
1148
- - Perfect (>=0.95): {metric_stats.get('perfect_count', 0)}
1149
- - Poor (<0.3): {metric_stats.get('poor_count', 0)}
1150
- - Samples: {metric_stats.get('sample_count', 0)}
1151
- """)
1152
-
1153
- # Show RMSE vs RAGBench Ground Truth (per RAGBench paper requirement)
1154
- if "rmse_vs_ground_truth" in results:
1155
- st.markdown("### πŸ“‰ RMSE vs RAGBench Ground Truth")
1156
- st.info("Compares predicted TRACE scores against original RAGBench dataset scores")
1157
- rmse_gt = results.get("rmse_vs_ground_truth", {})
1158
- per_metric_rmse = rmse_gt.get("per_metric_rmse", {})
1159
-
1160
- if per_metric_rmse:
1161
- rmse_gt_cols = st.columns(5)
1162
- with rmse_gt_cols[0]:
1163
- st.metric("Relevance RMSE", f"{per_metric_rmse.get('context_relevance', 0):.4f}",
1164
- delta=None, help="Lower is better (0 = perfect match)")
1165
- with rmse_gt_cols[1]:
1166
- st.metric("Utilization RMSE", f"{per_metric_rmse.get('context_utilization', 0):.4f}")
1167
- with rmse_gt_cols[2]:
1168
- st.metric("Completeness RMSE", f"{per_metric_rmse.get('completeness', 0):.4f}")
1169
- with rmse_gt_cols[3]:
1170
- st.metric("Adherence RMSE", f"{per_metric_rmse.get('adherence', 0):.4f}")
1171
- with rmse_gt_cols[4]:
1172
- agg_rmse = rmse_gt.get("aggregated_rmse", 0)
1173
- consistency = rmse_gt.get("consistency_score", 0)
1174
- st.metric("Aggregated RMSE", f"{agg_rmse:.4f}",
1175
- delta=f"Consistency: {consistency:.2%}", delta_color="normal")
1176
-
1177
- # Show AUCROC vs RAGBench Ground Truth (per RAGBench paper requirement)
1178
- if "aucroc_vs_ground_truth" in results:
1179
- st.markdown("### πŸ“Š AUC-ROC vs RAGBench Ground Truth")
1180
- st.info("Area Under ROC Curve comparing predicted vs ground truth binary classifications")
1181
- auc_gt = results.get("aucroc_vs_ground_truth", {})
1182
-
1183
- if auc_gt:
1184
- auc_cols = st.columns(5)
1185
- with auc_cols[0]:
1186
- st.metric("Relevance AUC", f"{auc_gt.get('context_relevance', 0):.4f}",
1187
- help="Higher is better (1.0 = perfect classification)")
1188
- with auc_cols[1]:
1189
- st.metric("Utilization AUC", f"{auc_gt.get('context_utilization', 0):.4f}")
1190
- with auc_cols[2]:
1191
- st.metric("Completeness AUC", f"{auc_gt.get('completeness', 0):.4f}")
1192
- with auc_cols[3]:
1193
- st.metric("Adherence AUC", f"{auc_gt.get('adherence', 0):.4f}")
1194
- with auc_cols[4]:
1195
- avg_auc = auc_gt.get("average", 0)
1196
- st.metric("Average AUC", f"{avg_auc:.4f}")
1197
-
1198
- # Download results
1199
- st.divider()
1200
- st.markdown("## πŸ’Ύ Download Results")
1201
-
1202
- # Create a comprehensive download with all details
1203
- download_data = {
1204
- "evaluation_metadata": {
1205
- "timestamp": datetime.now().isoformat(),
1206
- "dataset": st.session_state.dataset_name,
1207
- "method": results.get("evaluation_config", {}).get("evaluation_method", "gpt_labeling_prompts"),
1208
- "total_samples": results.get("num_samples", 0),
1209
- "embedding_model": st.session_state.embedding_model,
1210
- },
1211
- "aggregate_metrics": {
1212
- "context_relevance": results.get("context_relevance") or results.get("relevance", 0),
1213
- "context_utilization": results.get("context_utilization") or results.get("utilization", 0),
1214
- "completeness": results.get("completeness", 0),
1215
- "adherence": results.get("adherence", 0),
1216
- "average": results.get("average", 0),
1217
- },
1218
- "rmse_metrics": results.get("rmse_metrics", {}),
1219
- "per_metric_statistics": results.get("per_metric_statistics", {}),
1220
- "rmse_vs_ground_truth": results.get("rmse_vs_ground_truth", {}),
1221
- "aucroc_vs_ground_truth": results.get("aucroc_vs_ground_truth", {}),
1222
- "detailed_results": results.get("detailed_results", [])
1223
- }
1224
-
1225
- results_json = json.dumps(download_data, indent=2, default=str)
1226
-
1227
- col1, col2 = st.columns(2)
1228
- with col1:
1229
- st.download_button(
1230
- label="πŸ“₯ Download Complete Results (JSON)",
1231
- data=results_json,
1232
- file_name=f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
1233
- mime="application/json",
1234
- help="Download all evaluation results including metrics and per-query details"
1235
- )
1236
- with col2:
1237
- st.download_button(
1238
- label="πŸ“‹ Download Metrics Only (JSON)",
1239
- data=json.dumps(download_data["aggregate_metrics"], indent=2),
1240
- file_name=f"evaluation_metrics_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
1241
- mime="application/json",
1242
- help="Download only the aggregate metrics"
1243
- )
1244
 
 
 
 
 
 
1245
 
1246
- def run_evaluation(num_samples: int, selected_llm: str = None, method: str = "trace"):
1247
- """Run evaluation using selected method (TRACE, GPT Labeling, or Hybrid).
1248
-
1249
- Args:
1250
- num_samples: Number of test samples to evaluate
1251
- selected_llm: LLM model to use for evaluation
1252
- method: Evaluation method ("trace", "gpt_labeling", or "hybrid")
1253
- """
1254
- with st.spinner(f"Running evaluation on {num_samples} samples..."):
1255
- try:
1256
- # Create logs container
1257
- logs_container = st.container()
1258
- logs_list = []
1259
-
1260
- # Display logs header once outside function
1261
- logs_placeholder = st.empty()
1262
-
1263
- def add_log(message: str):
1264
- """Add log message and update display."""
1265
- logs_list.append(message)
1266
- with logs_placeholder.container():
1267
- st.markdown("### πŸ“‹ Evaluation Logs:")
1268
- for log_msg in logs_list:
1269
- st.caption(log_msg)
1270
-
1271
- # Log evaluation start
1272
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
1273
- add_log(f"⏱️ Evaluation started at {timestamp}")
1274
- add_log(f"πŸ“Š Dataset: {st.session_state.dataset_name}")
1275
- add_log(f"πŸ“ˆ Total samples: {num_samples}")
1276
- add_log(f"πŸ€– LLM Model: {selected_llm if selected_llm else st.session_state.current_llm}")
1277
- add_log(f"πŸ”— Vector Store: {st.session_state.collection_name}")
1278
- add_log(f"🧠 Embedding Model: {st.session_state.embedding_model}")
1279
-
1280
- # Map method names
1281
- method_names = {
1282
- "trace": "TRACE (Heuristic)",
1283
- "gpt_labeling": "GPT Labeling (LLM-based)",
1284
- "hybrid": "Hybrid (Both)"
1285
- }
1286
- add_log(f"πŸ”¬ Evaluation Method: {method_names.get(method, method)}")
1287
-
1288
- # Use selected LLM if provided - create with appropriate provider
1289
- eval_llm_client = None
1290
- original_llm = None
1291
- current_provider = st.session_state.get("llm_provider", "groq")
1292
-
1293
- if selected_llm and selected_llm != st.session_state.current_llm:
1294
- add_log(f"πŸ”„ Switching LLM to {selected_llm} ({current_provider.upper()})...")
1295
- groq_api_key = st.session_state.groq_api_key if "groq_api_key" in st.session_state else ""
1296
- eval_llm_client = create_llm_client(
1297
- provider=current_provider,
1298
- api_key=groq_api_key,
1299
- api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
1300
- model_name=selected_llm,
1301
- ollama_host=settings.ollama_host,
1302
- max_rpm=settings.groq_rpm_limit,
1303
- rate_limit_delay=settings.rate_limit_delay,
1304
- max_retries=settings.max_retries,
1305
- retry_delay=settings.retry_delay
1306
- )
1307
- # Temporarily replace LLM client
1308
- original_llm = st.session_state.rag_pipeline.llm
1309
- st.session_state.rag_pipeline.llm = eval_llm_client
1310
- else:
1311
- eval_llm_client = st.session_state.rag_pipeline.llm
1312
-
1313
- # Log provider info
1314
- provider_icon = "☁️" if current_provider == "groq" else "πŸ–₯️"
1315
- add_log(f"{provider_icon} LLM Provider: {current_provider.upper()}")
1316
-
1317
- # Get test data
1318
- add_log("πŸ“₯ Loading test data...")
1319
- loader = RAGBenchLoader()
1320
- test_data = loader.get_test_data(
1321
- st.session_state.dataset_name,
1322
- num_samples
1323
- )
1324
- add_log(f"βœ… Loaded {len(test_data)} test samples")
1325
-
1326
- # Prepare test cases
1327
- test_cases = []
1328
-
1329
- progress_bar = st.progress(0)
1330
- status_text = st.empty()
1331
-
1332
- add_log("πŸ” Processing samples...")
1333
- for i, sample in enumerate(test_data):
1334
- status_text.text(f"Processing sample {i+1}/{num_samples}")
1335
-
1336
- # Query the RAG system
1337
- result = st.session_state.rag_pipeline.query(
1338
- sample["question"],
1339
- n_results=5
1340
- )
1341
-
1342
- # Prepare test case
1343
- test_cases.append({
1344
- "query": sample["question"],
1345
- "response": result["response"],
1346
- "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
1347
- "ground_truth": sample.get("answer", "")
1348
- })
1349
-
1350
- # Update progress
1351
- progress_bar.progress((i + 1) / num_samples)
1352
-
1353
- # Log every 10 samples
1354
- if (i + 1) % 10 == 0 or (i + 1) == num_samples:
1355
- add_log(f" βœ“ Processed {i + 1}/{num_samples} samples")
1356
-
1357
- status_text.text(f"Running {method_names.get(method, method)} evaluation...")
1358
- add_log(f"πŸ“Š Running evaluation using {method_names.get(method, method)}...")
1359
-
1360
- # Extract chunking and embedding metadata from session state
1361
- # (These were stored when the collection was loaded/created)
1362
- chunking_strategy = st.session_state.vector_store.chunking_strategy if st.session_state.vector_store else None
1363
- embedding_model = st.session_state.embedding_model
1364
- chunk_size = st.session_state.vector_store.chunk_size if st.session_state.vector_store else None
1365
- chunk_overlap = st.session_state.vector_store.chunk_overlap if st.session_state.vector_store else None
1366
-
1367
- # Log retrieval configuration
1368
- add_log(f"πŸ”§ Retrieval Configuration:")
1369
- add_log(f" β€’ Chunking Strategy: {chunking_strategy or 'Unknown'}")
1370
- add_log(f" β€’ Chunk Size: {chunk_size or 'Unknown'}")
1371
- add_log(f" β€’ Chunk Overlap: {chunk_overlap or 'Unknown'}")
1372
- add_log(f" β€’ Embedding Model: {embedding_model or 'Unknown'}")
1373
-
1374
- # Import unified pipeline
1375
- try:
1376
- from evaluation_pipeline import UnifiedEvaluationPipeline
1377
-
1378
- # Run evaluation with metadata using unified pipeline
1379
- pipeline = UnifiedEvaluationPipeline(
1380
- llm_client=eval_llm_client,
1381
- chunking_strategy=chunking_strategy,
1382
- embedding_model=embedding_model,
1383
- chunk_size=chunk_size,
1384
- chunk_overlap=chunk_overlap
1385
- )
1386
-
1387
- # Run evaluation with selected method
1388
- results = pipeline.evaluate_batch(test_cases, method=method)
1389
-
1390
- except ImportError:
1391
- # Fallback to TRACE only if evaluation_pipeline module not available
1392
- add_log("⚠️ evaluation_pipeline module not found, falling back to TRACE...")
1393
-
1394
- # Run evaluation with metadata using TRACE
1395
- evaluator = TRACEEvaluator(
1396
- chunking_strategy=chunking_strategy,
1397
- embedding_model=embedding_model,
1398
- chunk_size=chunk_size,
1399
- chunk_overlap=chunk_overlap
1400
- )
1401
- results = evaluator.evaluate_batch(test_cases)
1402
-
1403
- st.session_state.evaluation_results = results
1404
-
1405
- # Log evaluation results summary
1406
- add_log("βœ… Evaluation completed successfully!")
1407
-
1408
- # Display appropriate metrics based on method
1409
- if method == "trace":
1410
- add_log(f" β€’ Utilization: {results.get('utilization', 0):.2%}")
1411
- add_log(f" β€’ Relevance: {results.get('relevance', 0):.2%}")
1412
- add_log(f" β€’ Adherence: {results.get('adherence', 0):.2%}")
1413
- add_log(f" β€’ Completeness: {results.get('completeness', 0):.2%}")
1414
- add_log(f" β€’ Average: {results.get('average', 0):.2%}")
1415
- elif method == "gpt_labeling":
1416
- if "context_relevance" in results:
1417
- add_log(f" β€’ Context Relevance: {results.get('context_relevance', 0):.2%}")
1418
- add_log(f" β€’ Context Utilization: {results.get('context_utilization', 0):.2%}")
1419
- add_log(f" β€’ Completeness: {results.get('completeness', 0):.2%}")
1420
- add_log(f" β€’ Adherence: {results.get('adherence', 0):.2%}")
1421
- add_log(f" β€’ Average: {results.get('average', 0):.2%}")
1422
- # NEW: Display RMSE and AUCROC metrics if available
1423
- if "rmse_metrics" in results:
1424
- add_log(f"πŸ“ˆ RMSE Metrics (vs ground truth):")
1425
- rmse_metrics = results.get("rmse_metrics", {})
1426
- add_log(f" β€’ Context Relevance RMSE: {rmse_metrics.get('relevance', 0):.4f}")
1427
- add_log(f" β€’ Context Utilization RMSE: {rmse_metrics.get('utilization', 0):.4f}")
1428
- add_log(f" β€’ Completeness RMSE: {rmse_metrics.get('completeness', 0):.4f}")
1429
- add_log(f" β€’ Adherence RMSE: {rmse_metrics.get('adherence', 0):.4f}")
1430
- add_log(f" β€’ Average RMSE: {rmse_metrics.get('average', 0):.4f}")
1431
- if "auc_metrics" in results:
1432
- add_log(f"πŸ“Š AUCROC Metrics (binary classification):")
1433
- auc_metrics = results.get("auc_metrics", {})
1434
- add_log(f" β€’ Context Relevance AUCROC: {auc_metrics.get('relevance', 0):.4f}")
1435
- add_log(f" β€’ Context Utilization AUCROC: {auc_metrics.get('utilization', 0):.4f}")
1436
- add_log(f" β€’ Completeness AUCROC: {auc_metrics.get('completeness', 0):.4f}")
1437
- add_log(f" β€’ Adherence AUCROC: {auc_metrics.get('adherence', 0):.4f}")
1438
- add_log(f" β€’ Average AUCROC: {auc_metrics.get('average', 0):.4f}")
1439
- elif method == "hybrid":
1440
- add_log(" πŸ“Š TRACE Metrics:")
1441
- trace_res = results.get("trace_results", {})
1442
- add_log(f" β€’ Utilization: {trace_res.get('utilization', 0):.2%}")
1443
- add_log(f" β€’ Relevance: {trace_res.get('relevance', 0):.2%}")
1444
- add_log(f" β€’ Adherence: {trace_res.get('adherence', 0):.2%}")
1445
- add_log(f" β€’ Completeness: {trace_res.get('completeness', 0):.2%}")
1446
- add_log(" 🧠 GPT Labeling Metrics:")
1447
- gpt_res = results.get("gpt_results", {})
1448
- add_log(f" β€’ Context Relevance: {gpt_res.get('context_relevance', 0):.2%}")
1449
- add_log(f" β€’ Context Utilization: {gpt_res.get('context_utilization', 0):.2%}")
1450
- add_log(f" β€’ Completeness: {gpt_res.get('completeness', 0):.2%}")
1451
- add_log(f" β€’ Adherence: {gpt_res.get('adherence', 0):.2%}")
1452
-
1453
- # Restore original LLM if it was switched
1454
- if selected_llm and selected_llm != st.session_state.current_llm and original_llm:
1455
- st.session_state.rag_pipeline.llm = original_llm
1456
- add_log(f"πŸ”„ Restored original LLM")
1457
-
1458
- add_log(f"⏱️ Evaluation completed at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
1459
-
1460
- except Exception as e:
1461
- st.error(f"Error during evaluation: {str(e)}")
1462
- add_log(f"❌ Error: {str(e)}")
1463
 
 
 
 
 
 
1464
 
1465
- def history_interface():
1466
- """History interface tab."""
1467
- st.subheader("πŸ“œ Chat History")
1468
-
1469
- if not st.session_state.chat_history:
1470
- st.info("No chat history yet. Start a conversation in the Chat tab!")
1471
- return
1472
-
1473
- # Export history
1474
- col1, col2 = st.columns([3, 1])
1475
- with col2:
1476
- history_json = json.dumps(st.session_state.chat_history, indent=2)
1477
- st.download_button(
1478
- label="πŸ’Ύ Export History",
1479
- data=history_json,
1480
- file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
1481
- mime="application/json"
1482
- )
1483
-
1484
- # Display history
1485
- for i, entry in enumerate(st.session_state.chat_history):
1486
- with st.expander(f"πŸ’¬ Conversation {i+1}: {entry['query'][:50]}..."):
1487
- st.markdown(f"**Query:** {entry['query']}")
1488
- st.markdown(f"**Response:** {entry['response']}")
1489
- st.markdown(f"**Timestamp:** {entry.get('timestamp', 'N/A')}")
1490
-
1491
- st.markdown("**Retrieved Documents:**")
1492
- for j, doc in enumerate(entry["retrieved_documents"]):
1493
- st.text_area(
1494
- f"Document {j+1}",
1495
- value=doc["document"],
1496
- height=100,
1497
- key=f"history_doc_{i}_{j}"
1498
- )
1499
 
 
 
 
 
 
1500
 
1501
- if __name__ == "__main__":
1502
- main()
 
1
"""Simple Hello World app to debug a HuggingFace Space deployment.

Renders a minimal page first (so a blank Space is immediately
distinguishable from an import crash), then probes each heavy
dependency one at a time and reports success/failure in the UI.
"""
import importlib
import sys

import streamlit as st

st.set_page_config(
    page_title="RAG Capstone - Test",
    page_icon="πŸ€–",
    layout="wide"
)

st.title("πŸ€– Hello World!")
st.write("If you can see this, the HuggingFace Space is working!")

st.success("βœ… Streamlit is running successfully on port 7860")

st.info(f"Python version: {sys.version}")


def _check_import(module_name: str, show_version: bool = False) -> None:
    """Try to import *module_name* and report the outcome in the UI.

    Args:
        module_name: Importable top-level module name (e.g. ``"pandas"``).
        show_version: If True, append the module's ``__version__`` to the
            success message (used for torch).
    """
    try:
        module = importlib.import_module(module_name)
    except Exception as e:  # broad on purpose: any failure must be visible
        st.error(f"❌ {module_name}: {e}")
    else:
        if show_version:
            st.success(f"βœ… {module_name} imported (version: {module.__version__})")
        else:
            st.success(f"βœ… {module_name} imported")


# Test basic imports β€” one probe per heavy dependency, in the same order
# as the original debug script. (Probing the top-level package is enough
# to detect a broken install; attribute resolution is left to the real app.)
_check_import("pandas")
_check_import("numpy")
_check_import("groq")
_check_import("torch", show_version=True)
_check_import("sentence_transformers")
_check_import("chromadb")
_check_import("qdrant_client")

st.markdown("---")
st.write("All basic imports completed!")
streamlit_app_backup.py ADDED
@@ -0,0 +1,1502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Streamlit chat interface for RAG application."""
2
+ import streamlit as st
3
+ import sys
4
+ import os
5
+ from datetime import datetime
6
+ import json
7
+ import pandas as pd
8
+ from typing import Optional
9
+ import warnings
10
+
11
# Suppress warnings so the Streamlit log stays readable on Spaces.
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TensorFlow C++ log spam

# Add parent directory to path so sibling project modules resolve
# when this file is run directly as a script.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Check if running on HuggingFace Spaces (Spaces sets SPACE_ID; the
# Dockerfile also sets it explicitly).
IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None

from config import settings
from dataset_loader import RAGBenchLoader
from vector_store import ChromaDBManager, create_vector_store
try:
    from vector_store import QdrantManager, QDRANT_AVAILABLE
except ImportError:
    # Qdrant support is optional; fall back to ChromaDB-only mode.
    QDRANT_AVAILABLE = False
from llm_client import GroqLLMClient, OllamaLLMClient, RAGPipeline, create_llm_client
from trace_evaluator import TRACEEvaluator
from embedding_models import EmbeddingFactory
from chunking_strategies import ChunkingFactory


# Page configuration (must be the first Streamlit call in the script)
st.set_page_config(
    page_title="RAG Capstone Project",
    page_icon="πŸ€–",
    layout="wide"
)
40
+
41
# Initialize session state: every key the app reads is seeded exactly
# once per browser session. Insertion order matches the original
# one-guard-per-key version.
_SESSION_DEFAULTS = {
    "chat_history": [],
    "rag_pipeline": None,
    "vector_store": None,
    "collection_loaded": False,
    "evaluation_results": None,
    "dataset_size": 10000,
    "current_dataset": None,
    "current_llm": settings.llm_models[1],
    "selected_collection": None,
    "available_collections": [],
    "dataset_name": None,
    "collection_name": None,
    "embedding_model": None,
    "groq_api_key": "",
    "llm_provider": settings.llm_provider,
    "ollama_model": settings.ollama_model,
    "vector_store_provider": settings.vector_store_provider,
    "qdrant_url": settings.qdrant_url,
    "qdrant_api_key": settings.qdrant_api_key,
}

for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
98
+
99
+
100
def get_available_collections(provider: str = None):
    """Return the collection names known to the active vector store.

    When *provider* is not given, falls back to the provider stored in
    session state (default ``"chroma"``).  Any backend failure is
    printed to stdout and reported as an empty list so the UI never
    crashes on a connectivity problem.
    """
    backend = provider or st.session_state.get("vector_store_provider", "chroma")

    try:
        if backend == "qdrant" and QDRANT_AVAILABLE:
            url = st.session_state.get("qdrant_url") or settings.qdrant_url
            key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key
            if not (url and key):
                # No credentials configured yet -- nothing to list.
                return []
            store = create_vector_store("qdrant", url=url, api_key=key)
        else:
            # ChromaDB path (also taken when Qdrant support is unavailable).
            store = ChromaDBManager(settings.chroma_persist_directory)
        return store.list_collections()
    except Exception as e:
        print(f"Error getting collections: {e}")
        return []
120
+
121
+
122
def main():
    """Main Streamlit application.

    Renders the sidebar configuration (LLM provider, vector store,
    collection loading/creation) and the main content area (either
    onboarding instructions or the Chat / Evaluation / History tabs,
    depending on whether a collection is loaded).
    """
    st.title("πŸ€– RAG Capstone Project")
    st.markdown("### Retrieval-Augmented Generation with TRACE Evaluation")

    # Show HuggingFace Spaces notice
    if IS_HUGGINGFACE_SPACE:
        st.info("πŸ€— Running on Hugging Face Spaces - Using Groq API (cloud-based LLM)")

    # Get available collections at startup
    available_collections = get_available_collections()
    st.session_state.available_collections = available_collections

    # Sidebar for configuration
    with st.sidebar:
        st.header("Configuration")

        # LLM Provider Selection - Disable Ollama on HuggingFace Spaces
        st.subheader("πŸ”Œ LLM Provider")

        if IS_HUGGINGFACE_SPACE:
            # Force Groq on HuggingFace Spaces (Ollama not available)
            st.caption("☁️ **Groq API** (Ollama unavailable on Spaces)")
            llm_provider = "groq"
            st.session_state.llm_provider = "groq"
        else:
            llm_provider = st.radio(
                "Choose LLM Provider:",
                options=["groq", "ollama"],
                index=0 if st.session_state.llm_provider == "groq" else 1,
                format_func=lambda x: "☁️ Groq API (Cloud)" if x == "groq" else "πŸ–₯️ Ollama (Local)",
                help="Groq: Cloud API with rate limits. Ollama: Local unlimited inference.",
                key="llm_provider_radio"
            )
            st.session_state.llm_provider = llm_provider

        # Provider-specific settings
        if llm_provider == "groq":
            st.caption("⚠️ Free tier: 30 requests/min")

            # On HuggingFace Spaces, check for API key in secrets first
            default_api_key = os.environ.get("GROQ_API_KEY", "") or settings.groq_api_key or ""

            # API Key input
            groq_api_key = st.text_input(
                "Groq API Key",
                type="password",
                value=default_api_key,
                help="Enter your Groq API key (or set GROQ_API_KEY in Spaces secrets)"
            )

            if IS_HUGGINGFACE_SPACE and not groq_api_key:
                st.warning("πŸ’‘ Tip: Add GROQ_API_KEY to your Space secrets for persistence")
        else:
            # Ollama settings (only available locally)
            st.caption("βœ… No rate limits - unlimited usage!")
            ollama_host = st.text_input(
                "Ollama Host",
                value=settings.ollama_host,
                help="Ollama server URL (default: http://localhost:11434)"
            )

            ollama_model = st.selectbox(
                "Select Ollama Model:",
                options=settings.ollama_models,
                index=settings.ollama_models.index(st.session_state.ollama_model) if st.session_state.ollama_model in settings.ollama_models else 0,
                key="ollama_model_selector"
            )
            st.session_state.ollama_model = ollama_model

            # Connection check button
            if st.button("πŸ” Check Ollama Connection"):
                try:
                    import requests
                    response = requests.get(f"{ollama_host}/api/tags", timeout=5)
                    if response.status_code == 200:
                        models = response.json().get("models", [])
                        model_names = [m["name"] for m in models]
                        st.success(f"βœ… Connected! Available models: {', '.join(model_names)}")
                    else:
                        st.error(f"❌ Connection failed: {response.status_code}")
                except Exception as e:
                    st.error(f"❌ Cannot connect to Ollama: {e}")
                    st.info("Make sure Ollama is running: `ollama serve`")

            # NOTE: groq_api_key is kept defined in both branches so the
            # button handlers below can reference it unconditionally.
            groq_api_key = ""  # Not needed for Ollama

        st.divider()

        # Vector Store Provider Selection
        st.subheader("πŸ’Ύ Vector Store")

        if IS_HUGGINGFACE_SPACE:
            # Qdrant listed first on Spaces, since local disk is ephemeral there.
            st.caption("☁️ Use **Qdrant Cloud** for persistent storage")
            vector_store_options = ["qdrant", "chroma"]
            default_idx = 0
        else:
            vector_store_options = ["chroma", "qdrant"]
            default_idx = 0

        vector_store_provider = st.radio(
            "Choose Vector Store:",
            options=vector_store_options,
            index=default_idx,
            format_func=lambda x: "☁️ Qdrant Cloud (Persistent)" if x == "qdrant" else "πŸ’Ύ ChromaDB (Local)",
            help="Qdrant: Cloud storage (persistent). ChromaDB: Local storage (ephemeral on Spaces).",
            key="vector_store_radio"
        )
        st.session_state.vector_store_provider = vector_store_provider

        # Qdrant settings
        if vector_store_provider == "qdrant":
            default_qdrant_url = os.environ.get("QDRANT_URL", "") or settings.qdrant_url
            default_qdrant_key = os.environ.get("QDRANT_API_KEY", "") or settings.qdrant_api_key

            qdrant_url = st.text_input(
                "Qdrant URL",
                value=default_qdrant_url,
                placeholder="https://xxx-xxx.aws.cloud.qdrant.io:6333",
                help="Your Qdrant Cloud cluster URL"
            )
            qdrant_api_key = st.text_input(
                "Qdrant API Key",
                type="password",
                value=default_qdrant_key,
                help="Your Qdrant API key"
            )
            st.session_state.qdrant_url = qdrant_url
            st.session_state.qdrant_api_key = qdrant_api_key

            if not qdrant_url or not qdrant_api_key:
                st.warning("⚠️ Get free Qdrant Cloud at: https://cloud.qdrant.io")

            # Test Qdrant connection
            if st.button("πŸ” Test Qdrant Connection"):
                if qdrant_url and qdrant_api_key:
                    try:
                        test_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
                        collections = test_store.list_collections()
                        st.success(f"βœ… Connected! Found {len(collections)} collection(s)")
                    except Exception as e:
                        st.error(f"❌ Connection failed: {e}")
                else:
                    st.error("Please enter Qdrant URL and API Key")

        st.divider()

        # Get available collections based on provider (re-queried here
        # because the provider choice above may differ from startup).
        available_collections = get_available_collections(vector_store_provider)
        st.session_state.available_collections = available_collections

        # Option 1: Use existing collection
        if available_collections:
            st.subheader("πŸ“š Existing Collections")
            st.write(f"Found {len(available_collections)} collection(s)")

            selected_collection = st.selectbox(
                "Or select existing collection:",
                available_collections,
                key="collection_selector"
            )

            if st.button("πŸ“– Load Existing Collection", type="secondary"):
                # Validate based on provider
                if llm_provider == "groq" and not groq_api_key:
                    st.error("Please enter your Groq API key")
                elif vector_store_provider == "qdrant" and (not st.session_state.get("qdrant_url") or not st.session_state.get("qdrant_api_key")):
                    st.error("Please enter Qdrant URL and API Key")
                else:
                    # ollama_host only exists in the Ollama branch; the
                    # conditional below never evaluates it under Groq.
                    load_existing_collection(
                        groq_api_key,
                        selected_collection,
                        llm_provider,
                        ollama_host if llm_provider == "ollama" else None,
                        vector_store_provider
                    )

            st.divider()

        # Option 2: Create new collection
        st.subheader("πŸ†• Create New Collection")

        # Dataset selection
        st.subheader("1. Dataset Selection")
        dataset_name = st.selectbox(
            "Choose Dataset",
            settings.ragbench_datasets,
            index=0
        )

        # Get dataset size dynamically
        if st.button("πŸ” Check Dataset Size", key="check_size"):
            with st.spinner("Checking dataset size..."):
                try:
                    from datasets import load_dataset

                    # Load dataset with download_mode to avoid cache issues
                    st.info(f"Fetching dataset info for '{dataset_name}'...")
                    ds = load_dataset(
                        "rungalileo/ragbench",
                        dataset_name,
                        split="train",
                        trust_remote_code=True,
                        download_mode="force_redownload"  # Force fresh download to avoid cache corruption
                    )
                    dataset_size = len(ds)

                    st.session_state.dataset_size = dataset_size
                    st.session_state.current_dataset = dataset_name
                    st.success(f"βœ… Dataset '{dataset_name}' has {dataset_size:,} samples available")
                except Exception as e:
                    st.error(f"❌ Error: {str(e)}")
                    st.exception(e)
                    st.warning(f"Could not determine dataset size. Using default of 10,000.")
                    st.session_state.dataset_size = 10000
                    st.session_state.current_dataset = dataset_name

        # Use stored dataset size or default
        max_samples_available = st.session_state.get('dataset_size', 10000)

        st.caption(f"Max available samples: {max_samples_available:,}")

        num_samples = st.slider(
            "Number of samples",
            min_value=10,
            max_value=max_samples_available,
            value=min(100, max_samples_available),
            step=50 if max_samples_available > 1000 else 10,
            help="Adjust slider to select number of samples"
        )

        load_all_samples = st.checkbox(
            "Load all available samples",
            value=False,
            help="Override slider and load entire dataset"
        )

        st.divider()

        # Chunking strategy
        st.subheader("2. Chunking Strategy")
        chunking_strategy = st.selectbox(
            "Choose Chunking Strategy",
            settings.chunking_strategies,
            index=0
        )

        chunk_size = st.slider(
            "Chunk Size",
            min_value=256,
            max_value=1024,
            value=512,
            step=128
        )

        overlap = st.slider(
            "Overlap",
            min_value=0,
            max_value=200,
            value=50,
            step=10
        )

        st.divider()

        # Embedding model
        st.subheader("3. Embedding Model")
        embedding_model = st.selectbox(
            "Choose Embedding Model",
            settings.embedding_models,
            index=0
        )

        st.divider()

        # LLM model selection for new collection
        st.subheader("4. LLM Model")
        if llm_provider == "groq":
            llm_model = st.selectbox(
                "Choose Groq LLM",
                settings.llm_models,
                index=1
            )
        else:
            llm_model = st.selectbox(
                "Choose Ollama Model",
                settings.ollama_models,
                index=settings.ollama_models.index(st.session_state.ollama_model) if st.session_state.ollama_model in settings.ollama_models else 0,
                key="llm_model_ollama"
            )

        st.divider()

        # Load data button
        if st.button("πŸš€ Load Data & Create Collection", type="primary"):
            # Validate based on provider
            if llm_provider == "groq" and not groq_api_key:
                st.error("Please enter your Groq API key")
            elif vector_store_provider == "qdrant" and (not st.session_state.get("qdrant_url") or not st.session_state.get("qdrant_api_key")):
                st.error("Please enter Qdrant URL and API Key")
            else:
                # Use None for num_samples if loading all data
                samples_to_load = None if load_all_samples else num_samples
                load_and_create_collection(
                    groq_api_key,
                    dataset_name,
                    samples_to_load,
                    chunking_strategy,
                    chunk_size,
                    overlap,
                    embedding_model,
                    llm_model,
                    llm_provider,
                    ollama_host if llm_provider == "ollama" else None,
                    vector_store_provider
                )

    # Main content area
    if not st.session_state.collection_loaded:
        st.info("πŸ‘ˆ Please configure and load a dataset from the sidebar to begin")

        # Show instructions
        with st.expander("πŸ“– How to Use", expanded=True):
            st.markdown("""
            1. **Enter your Groq API Key** in the sidebar
            2. **Select a dataset** from RAG Bench
            3. **Choose a chunking strategy** (dense, sparse, hybrid, re-ranking)
            4. **Select an embedding model** for document vectorization
            5. **Choose an LLM model** for response generation
            6. **Click "Load Data & Create Collection"** to initialize
            7. **Start chatting** in the chat interface
            8. **View retrieved documents** and evaluation metrics
            9. **Run TRACE evaluation** on test data
            """)

        # Show available options
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("πŸ“Š Available Datasets")
            for ds in settings.ragbench_datasets:
                st.markdown(f"- {ds}")

        with col2:
            st.subheader("πŸ€– Available Models")
            st.markdown("**Embedding Models:**")
            for em in settings.embedding_models:
                st.markdown(f"- {em}")

            st.markdown("**LLM Models:**")
            for lm in settings.llm_models:
                st.markdown(f"- {lm}")

    else:
        # Create tabs for different functionalities
        tab1, tab2, tab3 = st.tabs(["πŸ’¬ Chat", "πŸ“Š Evaluation", "πŸ“œ History"])

        with tab1:
            chat_interface()

        with tab2:
            evaluation_interface()

        with tab3:
            history_interface()
487
+
488
+
489
def load_existing_collection(api_key: str, collection_name: str, llm_provider: str = "groq", ollama_host: str = None, vector_store_provider: str = "chroma"):
    """Load an existing collection from vector store.

    Connects to the chosen vector-store backend, attaches to
    *collection_name*, builds an LLM client and RAG pipeline, stores
    everything in session state, and reruns the app so the chat tabs
    appear.

    Args:
        api_key: Groq API key (ignored for Ollama).
        collection_name: Name of the existing collection to attach to.
        llm_provider: "groq" or "ollama".
        ollama_host: Ollama server URL; falls back to settings when None.
        vector_store_provider: "chroma" or "qdrant".
    """
    with st.spinner(f"Loading collection '{collection_name}'..."):
        try:
            # Initialize vector store based on provider
            if vector_store_provider == "qdrant":
                qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url
                qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key
                vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
            else:
                vector_store = ChromaDBManager(settings.chroma_persist_directory)

            vector_store.get_collection(collection_name)

            # Extract dataset name from collection name (format: dataset_name_strategy_model)
            # Try to find which dataset this collection is based on
            dataset_name = None
            for ds in settings.ragbench_datasets:
                if collection_name.startswith(ds.replace("-", "_")):
                    dataset_name = ds
                    break

            if not dataset_name:
                dataset_name = collection_name.split("_")[0]  # Fallback: use first part

            # Prompt for LLM selection based on provider
            # NOTE(review): rendering a selectbox inside a button handler
            # followed by st.rerun() means this widget only exists for one
            # script run; the value read here is its default -- confirm
            # this is the intended UX.
            if llm_provider == "groq":
                st.session_state.current_llm = st.selectbox(
                    "Select Groq LLM for this collection:",
                    settings.llm_models,
                    key=f"llm_selector_{collection_name}"
                )
            else:
                st.session_state.current_llm = st.selectbox(
                    "Select Ollama Model for this collection:",
                    settings.ollama_models,
                    key=f"ollama_selector_{collection_name}"
                )

            # Initialize LLM client based on provider
            st.info(f"Initializing LLM client ({llm_provider})...")
            llm_client = create_llm_client(
                provider=llm_provider,
                api_key=api_key,
                api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
                model_name=st.session_state.current_llm,
                ollama_host=ollama_host or settings.ollama_host,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay,
                max_retries=settings.max_retries,
                retry_delay=settings.retry_delay
            )

            # Create RAG pipeline with correct parameter names
            st.info("Creating RAG pipeline...")
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )

            # Store in session state so the chat/evaluation tabs can use them
            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.selected_collection = collection_name
            st.session_state.groq_api_key = api_key
            st.session_state.dataset_name = dataset_name
            st.session_state.collection_name = collection_name
            st.session_state.llm_provider = llm_provider

            # Display system prompt and model info
            provider_icon = "☁️" if llm_provider == "groq" else "πŸ–₯️"
            st.success(f"βœ… Collection '{collection_name}' loaded successfully! {provider_icon} Using {llm_provider.upper()}")

            with st.expander("πŸ€– Model & System Prompt Information", expanded=False):
                col1, col2 = st.columns(2)
                with col1:
                    st.write(f"**Provider:** {provider_icon} {llm_provider.upper()}")
                    st.write(f"**Model:** {st.session_state.current_llm}")
                    st.write(f"**Collection:** {collection_name}")
                    st.write(f"**Dataset:** {dataset_name}")
                with col2:
                    st.write(f"**Temperature:** 0.0")
                    st.write(f"**Max Tokens:** 2048")
                    if llm_provider == "groq":
                        st.write(f"**Rate Limit:** {settings.groq_rpm_limit} RPM")
                    else:
                        st.write(f"**Rate Limit:** βœ… Unlimited (Local)")

                st.markdown("#### System Prompt")
                st.info("""
                You are a Fact-Checking and Citation Specialist. Your task is to perform a rigorous audit of a response against provided documents to determine its accuracy, relevance, and level of support.

                **Task:**
                1. Analyze the provided documents and identify information relevant to the user's question
                2. Evaluate the response sentence-by-sentence
                3. Verify each response sentence maps to supporting document sentences
                4. Identify which document sentences were actually used in the response
                """)

            st.rerun()

        except Exception as e:
            st.error(f"Error loading collection: {str(e)}")
            st.exception(e)
595
+
596
+
597
def load_and_create_collection(
    api_key: str,
    dataset_name: str,
    num_samples: Optional[int],
    chunking_strategy: str,
    chunk_size: int,
    overlap: int,
    embedding_model: str,
    llm_model: str,
    llm_provider: str = "groq",
    ollama_host: str = None,
    vector_store_provider: str = "chroma"
):
    """Load dataset and create vector collection.

    Downloads the RAGBench split, (re)creates a vector collection named
    ``{dataset}_{strategy}_{model}``, wires up the LLM client and RAG
    pipeline, and stores everything in session state.

    Args:
        api_key: Groq API key (ignored for Ollama).
        dataset_name: RAGBench subset to load.
        num_samples: Max samples to ingest; None means all available.
        chunking_strategy: One of ``settings.chunking_strategies``.
        chunk_size: Target chunk size in characters/tokens per strategy.
        overlap: Overlap between consecutive chunks.
        embedding_model: Embedding model identifier.
        llm_model: LLM model name for the chosen provider.
        llm_provider: "groq" or "ollama".
        ollama_host: Ollama server URL; falls back to settings when None.
        vector_store_provider: "chroma" or "qdrant".
    """
    with st.spinner("Loading dataset and creating collection..."):
        try:
            # Initialize dataset loader
            loader = RAGBenchLoader()

            # Load dataset
            if num_samples is None:
                st.info(f"Loading {dataset_name} dataset (all available samples)...")
            else:
                st.info(f"Loading {dataset_name} dataset ({num_samples} samples)...")
            dataset = loader.load_dataset(dataset_name, split="train", max_samples=num_samples)

            if not dataset:
                st.error("Failed to load dataset")
                return

            # Initialize vector store based on provider
            st.info(f"Initializing vector store ({vector_store_provider})...")
            if vector_store_provider == "qdrant":
                qdrant_url = st.session_state.get("qdrant_url") or settings.qdrant_url
                qdrant_api_key = st.session_state.get("qdrant_api_key") or settings.qdrant_api_key
                vector_store = create_vector_store("qdrant", url=qdrant_url, api_key=qdrant_api_key)
            else:
                vector_store = ChromaDBManager(settings.chroma_persist_directory)

            # Create collection name (sanitized: some backends reject '-' and '.')
            collection_name = f"{dataset_name}_{chunking_strategy}_{embedding_model.split('/')[-1]}"
            collection_name = collection_name.replace("-", "_").replace(".", "_")

            # Delete existing collection with same name (if exists)
            existing_collections = vector_store.list_collections()
            if collection_name in existing_collections:
                st.warning(f"Collection '{collection_name}' already exists. Deleting and recreating...")
                vector_store.delete_collection(collection_name)
                st.info("Old collection deleted. Creating new one...")

            # Load data into collection
            st.info(f"Creating collection with {chunking_strategy} chunking...")
            vector_store.load_dataset_into_collection(
                collection_name=collection_name,
                embedding_model_name=embedding_model,
                chunking_strategy=chunking_strategy,
                dataset_data=dataset,
                chunk_size=chunk_size,
                overlap=overlap
            )

            # Initialize LLM client based on provider
            st.info(f"Initializing LLM client ({llm_provider})...")
            llm_client = create_llm_client(
                provider=llm_provider,
                api_key=api_key,
                api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
                model_name=llm_model,
                ollama_host=ollama_host or settings.ollama_host,
                max_rpm=settings.groq_rpm_limit,
                rate_limit_delay=settings.rate_limit_delay,
                max_retries=settings.max_retries,
                retry_delay=settings.retry_delay
            )

            # Create RAG pipeline with correct parameter names
            rag_pipeline = RAGPipeline(
                llm_client=llm_client,
                vector_store_manager=vector_store
            )

            # Store in session state so the chat/evaluation tabs can use them
            st.session_state.vector_store = vector_store
            st.session_state.rag_pipeline = rag_pipeline
            st.session_state.collection_loaded = True
            st.session_state.current_collection = collection_name
            st.session_state.dataset_name = dataset_name
            st.session_state.dataset = dataset
            st.session_state.collection_name = collection_name
            st.session_state.embedding_model = embedding_model
            st.session_state.groq_api_key = api_key
            st.session_state.llm_provider = llm_provider
            st.session_state.vector_store_provider = vector_store_provider

            # (removed unused `vs_icon` local -- it was computed but never shown)
            provider_icon = "☁️" if llm_provider == "groq" else "πŸ–₯️"
            st.success(f"βœ… Collection '{collection_name}' created successfully! {provider_icon} Using {llm_provider.upper()}")
            st.rerun()

        except Exception as e:
            st.error(f"Error: {str(e)}")
            # Show full traceback, matching load_existing_collection's handler.
            st.exception(e)
698
+
699
+
700
+ def chat_interface():
701
+ """Chat interface tab."""
702
+ st.subheader("πŸ’¬ Chat Interface")
703
+
704
+ # Check if collection is loaded
705
+ if not st.session_state.collection_loaded:
706
+ st.warning("⚠️ No data loaded. Please use the configuration panel to load a dataset and create a collection.")
707
+ st.info("""
708
+ Steps:
709
+ 1. Select a dataset from the dropdown
710
+ 2. Click "Load Data & Create Collection" button
711
+ 3. Wait for the collection to be created
712
+ 4. Then you can start chatting
713
+ """)
714
+ return
715
+
716
+ # Display collection info and LLM selector
717
+ col1, col2, col3 = st.columns([2, 2, 1])
718
+ with col1:
719
+ provider_icon = "☁️" if st.session_state.get("llm_provider", "groq") == "groq" else "πŸ–₯️"
720
+ st.info(f"πŸ“š Collection: {st.session_state.current_collection} | {provider_icon} {st.session_state.get('llm_provider', 'groq').upper()}")
721
+
722
+ with col2:
723
+ # LLM selector for chat - based on provider
724
+ current_provider = st.session_state.get("llm_provider", "groq")
725
+ if current_provider == "groq":
726
+ model_options = settings.llm_models
727
+ try:
728
+ current_index = settings.llm_models.index(st.session_state.current_llm)
729
+ except ValueError:
730
+ current_index = 0
731
+ else:
732
+ model_options = settings.ollama_models
733
+ try:
734
+ current_index = settings.ollama_models.index(st.session_state.current_llm)
735
+ except ValueError:
736
+ current_index = 0
737
+
738
+ selected_llm = st.selectbox(
739
+ f"Select {'Groq' if current_provider == 'groq' else 'Ollama'} Model for chat:",
740
+ model_options,
741
+ index=current_index,
742
+ key="chat_llm_selector"
743
+ )
744
+
745
+ if selected_llm != st.session_state.current_llm:
746
+ st.session_state.current_llm = selected_llm
747
+ # Recreate LLM client with new model
748
+ llm_client = create_llm_client(
749
+ provider=current_provider,
750
+ api_key=st.session_state.groq_api_key if "groq_api_key" in st.session_state else "",
751
+ api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
752
+ model_name=selected_llm,
753
+ ollama_host=settings.ollama_host,
754
+ max_rpm=settings.groq_rpm_limit,
755
+ rate_limit_delay=settings.rate_limit_delay
756
+ )
757
+ st.session_state.rag_pipeline.llm = llm_client
758
+
759
+ with col3:
760
+ if st.button("πŸ—‘οΈ Clear History"):
761
+ st.session_state.chat_history = []
762
+ st.session_state.rag_pipeline.clear_history()
763
+ st.rerun()
764
+
765
+ # Show system prompt info in expandable section
766
+ with st.expander("πŸ€– System Prompt & Model Info", expanded=False):
767
+ current_provider = st.session_state.get("llm_provider", "groq")
768
+ col1, col2 = st.columns(2)
769
+ with col1:
770
+ provider_icon = "☁️" if current_provider == "groq" else "πŸ–₯️"
771
+ st.write(f"**Provider:** {provider_icon} {current_provider.upper()}")
772
+ st.write(f"**LLM Model:** {st.session_state.current_llm}")
773
+ st.write(f"**Temperature:** 0.0")
774
+ st.write(f"**Max Tokens:** 2048")
775
+ with col2:
776
+ st.write(f"**Collection:** {st.session_state.current_collection}")
777
+ st.write(f"**Dataset:** {st.session_state.get('dataset_name', 'N/A')}")
778
+ if current_provider == "groq":
779
+ st.write(f"**Rate Limit:** {settings.groq_rpm_limit} RPM")
780
+ else:
781
+ st.write(f"**Rate Limit:** βœ… Unlimited (Local)")
782
+
783
+ st.markdown("#### System Prompt Being Used")
784
+ system_prompt = """You are a Fact-Checking and Citation Specialist. Your task is to perform a rigorous audit of a response against provided documents to determine its accuracy, relevance, and level of support.
785
+
786
+ **TASK OVERVIEW**
787
+ 1. **Analyze Documents**: Review the provided documents and identify information relevant to the user's question.
788
+ 2. **Evaluate Response**: Review the provided answer sentence-by-sentence.
789
+ 3. **Verify Support**: Map each answer sentence to specific supporting sentences in the documents.
790
+ 4. **Identify Utilization**: Determine which document sentences were actually used (directly or implicitly) to form the answer."""
791
+ st.info(system_prompt)
792
+
793
+ # Chat container
794
+ chat_container = st.container()
795
+
796
+ # Display chat history
797
+ with chat_container:
798
+ for chat_idx, entry in enumerate(st.session_state.chat_history):
799
+ # User message
800
+ with st.chat_message("user"):
801
+ st.write(entry["query"])
802
+
803
+ # Assistant message
804
+ with st.chat_message("assistant"):
805
+ st.write(entry["response"])
806
+
807
+ # Show retrieved documents in expander
808
+ with st.expander("πŸ“„ Retrieved Documents"):
809
+ for doc_idx, doc in enumerate(entry["retrieved_documents"]):
810
+ st.markdown(f"**Document {doc_idx+1}** (Distance: {doc.get('distance', 'N/A'):.4f})")
811
+ st.text_area(
812
+ f"doc_{chat_idx}_{doc_idx}",
813
+ value=doc["document"],
814
+ height=100,
815
+ key=f"doc_area_{chat_idx}_{doc_idx}",
816
+ label_visibility="collapsed"
817
+ )
818
+ if doc.get("metadata"):
819
+ st.caption(f"Metadata: {doc['metadata']}")
820
+
821
+ # Chat input
822
+ query = st.chat_input("Ask a question...")
823
+
824
+ if query:
825
+ # Check if collection exists
826
+ if not st.session_state.rag_pipeline or not st.session_state.rag_pipeline.vector_store.current_collection:
827
+ st.error("❌ No data loaded. Please load a dataset first using the configuration panel.")
828
+ st.stop()
829
+
830
+ # Add user message
831
+ with chat_container:
832
+ with st.chat_message("user"):
833
+ st.write(query)
834
+
835
+ # Generate response
836
+ with st.spinner("Generating response..."):
837
+ try:
838
+ result = st.session_state.rag_pipeline.query(query)
839
+ except Exception as e:
840
+ st.error(f"❌ Error querying: {str(e)}")
841
+ st.info("Please load a dataset and create a collection first.")
842
+ st.stop()
843
+
844
+ # Add assistant message
845
+ with chat_container:
846
+ with st.chat_message("assistant"):
847
+ st.write(result["response"])
848
+
849
+ # Show retrieved documents
850
+ with st.expander("πŸ“„ Retrieved Documents"):
851
+ for doc_idx, doc in enumerate(result["retrieved_documents"]):
852
+ st.markdown(f"**Document {doc_idx+1}** (Distance: {doc.get('distance', 'N/A'):.4f})")
853
+ st.text_area(
854
+ f"doc_current_{doc_idx}",
855
+ value=doc["document"],
856
+ height=100,
857
+ key=f"doc_current_area_{doc_idx}",
858
+ label_visibility="collapsed"
859
+ )
860
+ if doc.get("metadata"):
861
+ st.caption(f"Metadata: {doc['metadata']}")
862
+
863
+ # Store in history
864
+ st.session_state.chat_history.append(result)
865
+ st.rerun()
866
+
867
+
868
def evaluation_interface() -> None:
    """Evaluation interface tab.

    Lets the user pick an evaluation method (TRACE heuristic, GPT
    labeling, or hybrid), an LLM model, and a sample count, then kicks
    off `run_evaluation` and renders whatever results are stored in
    ``st.session_state.evaluation_results`` (aggregate metrics, per-query
    details, RMSE/AUC-ROC comparisons, and JSON downloads).
    """
    st.subheader("πŸ“Š RAG Evaluation")

    # Check if collection is loaded; evaluation needs a populated vector store.
    if not st.session_state.collection_loaded:
        st.warning("⚠️ No data loaded. Please load a collection first.")
        return

    # Evaluation method selector
    eval_col1, eval_col2 = st.columns([2, 1])
    with eval_col1:
        evaluation_method = st.radio(
            "Evaluation Method:",
            options=["TRACE (Heuristic)", "GPT Labeling (LLM-based)", "Hybrid (Both)"],
            horizontal=True,
            help="TRACE is fast (no LLM). GPT Labeling is accurate but slower (requires LLM calls)."
        )

    # Map human-readable UI labels to internal method IDs.
    method_map = {
        "TRACE (Heuristic)": "trace",
        "GPT Labeling (LLM-based)": "gpt_labeling",
        "Hybrid (Both)": "hybrid"
    }
    selected_method = method_map[evaluation_method]

    # LLM selector for evaluation
    current_provider = st.session_state.get("llm_provider", "groq")
    col1, col2 = st.columns([3, 1])
    with col1:
        # Show provider-specific models (Groq cloud vs. local Ollama).
        if current_provider == "groq":
            model_options = settings.llm_models
            try:
                current_index = settings.llm_models.index(st.session_state.current_llm)
            except ValueError:
                # Current model not in this provider's list; default to first.
                current_index = 0
        else:
            model_options = settings.ollama_models
            try:
                current_index = settings.ollama_models.index(st.session_state.current_llm)
            except ValueError:
                current_index = 0

        selected_llm = st.selectbox(
            f"Select {'Groq' if current_provider == 'groq' else 'Ollama'} Model for evaluation:",
            model_options,
            index=current_index,
            key="eval_llm_selector"
        )

    # Show provider info (rate limit implications differ per provider).
    provider_icon = "☁️" if current_provider == "groq" else "πŸ–₯️"
    if current_provider == "ollama":
        st.caption(f"{provider_icon} Using local Ollama - **No rate limits!** Fast evaluation possible.")
    else:
        st.caption(f"{provider_icon} Using Groq API - Rate limited to {settings.groq_rpm_limit} RPM")

    # Show a description of the selected method (speed/cost trade-offs).
    method_descriptions = {
        "trace": """
    **TRACE Heuristic Method** (Fast, Rule-Based)
    - Utilization: How well the system uses retrieved documents
    - Relevance: Relevance of retrieved documents to the query
    - Adherence: How well the response adheres to the retrieved context
    - Completeness: How complete the response is in answering the query
    - ⚑ Speed: ~100ms per evaluation
    - πŸ’° Cost: Free (no API calls)
    """,
        "gpt_labeling": """
    **GPT Labeling Method** (Accurate, LLM-based)
    - Uses sentence-level LLM analysis (from RAGBench paper)
    - Context Relevance: Fraction of context relevant to query
    - Context Utilization: Fraction of relevant context used
    - Completeness: Fraction of relevant info covered
    - Adherence: Response supported by context (no hallucinations)
    - ⏱️ Speed: ~2-5 seconds per evaluation
    - πŸ’° Cost: ~$0.002-0.01 per evaluation
    """,
        "hybrid": """
    **Hybrid Method** (Comprehensive)
    - Runs both TRACE and GPT Labeling methods
    - Provides both fast and accurate evaluation metrics
    - Best for detailed analysis
    - ⏱️ Speed: ~3-6 seconds per evaluation
    - πŸ’° Cost: Same as GPT Labeling
    """
    }

    st.markdown(method_descriptions[selected_method])

    # Get maximum test samples available for current dataset.
    try:
        loader = RAGBenchLoader()
        max_test_samples = loader.get_test_data_size(st.session_state.dataset_name)
        st.caption(f"πŸ“Š Available test samples: {max_test_samples:,}")
    except Exception as e:
        # Fall back to a rough estimate if the loader cannot report a size.
        max_test_samples = 100
        st.caption(f"Available test samples: ~{max_test_samples} (estimated)")

    # Ensure min and max are reasonable.
    max_test_samples = max(5, min(max_test_samples, 500))  # Cap at 500 for performance

    num_test_samples = st.slider(
        "Number of test samples",
        min_value=5,
        max_value=max_test_samples,
        value=min(10, max_test_samples),
        step=5
    )

    # Show warning for GPT labeling (API cost) - only for Groq.
    if selected_method in ["gpt_labeling", "hybrid"]:
        current_provider = st.session_state.get("llm_provider", "groq")
        if current_provider == "groq":
            st.warning(f"⚠️ **{evaluation_method}** requires LLM API calls. This will incur costs and be slower due to rate limiting ({settings.groq_rpm_limit} RPM).")
        else:
            st.info(f"ℹ️ **{evaluation_method}** using local Ollama - **No rate limits!** Evaluation will be much faster.")

    if st.button("πŸ”¬ Run Evaluation", type="primary"):
        # Use selected LLM for evaluation; results land in session state.
        run_evaluation(num_test_samples, selected_llm, selected_method)

    # Display results (persisted across reruns via session state).
    if st.session_state.evaluation_results:
        results = st.session_state.evaluation_results

        st.success("βœ… Evaluation Complete!")
        st.divider()
        st.markdown("## πŸ“Š Evaluation Metrics")

        # Display aggregate scores - handle both TRACE and GPT Labeling metric names.
        st.markdown("### Main Metrics")
        col1, col2, col3, col4, col5 = st.columns(5)

        # Determine which metrics are available. NOTE(review): the `or`
        # fallback also triggers when the first key holds a legitimate 0.0
        # (falsy); in practice each method populates only one key set.
        utilization = results.get('utilization') or results.get('context_utilization', 0)
        relevance = results.get('relevance') or results.get('context_relevance', 0)
        adherence = results.get('adherence', 0)
        completeness = results.get('completeness', 0)
        average = results.get('average', 0)

        with col1:
            st.metric("πŸ“Š Utilization", f"{utilization:.3f}")
        with col2:
            st.metric("🎯 Relevance", f"{relevance:.3f}")
        with col3:
            st.metric("βœ… Adherence", f"{adherence:.3f}")
        with col4:
            st.metric("πŸ“ Completeness", f"{completeness:.3f}")
        with col5:
            st.metric("⭐ Average", f"{average:.3f}")

        # Detailed results summary - handle both metric types.
        if "individual_scores" in results:
            with st.expander("πŸ“‹ Summary Metrics by Query"):
                df = pd.DataFrame(results["individual_scores"])
                st.dataframe(df, use_container_width=True)

        # Detailed per-query results.
        # NOTE(review): st.expander instances are nested three deep here;
        # Streamlit historically disallows nested expanders and may raise
        # at render time -- verify against the installed Streamlit version.
        if "detailed_results" in results and results["detailed_results"]:
            with st.expander("πŸ” Detailed Per-Query Analysis"):
                for query_result in results.get("detailed_results", []):
                    with st.expander(f"Query {query_result['query_id']}: {query_result['question'][:60]}..."):
                        st.markdown("### Question")
                        st.write(query_result['question'])

                        st.markdown("### LLM Response")
                        st.write(query_result.get('llm_response', 'N/A'))

                        st.markdown("### Retrieved Documents")
                        for doc_idx, doc in enumerate(query_result.get('retrieved_documents', []), 1):
                            with st.expander(f"πŸ“„ Document {doc_idx}"):
                                st.write(doc)

                        if query_result.get('ground_truth'):
                            st.markdown("### Ground Truth")
                            st.write(query_result['ground_truth'])

                        # Display metrics with correct labels based on method.
                        metrics = query_result.get('metrics', {})
                        if metrics:
                            st.markdown("### Evaluation Metrics")
                            col1, col2, col3, col4, col5 = st.columns(5)

                            # Get metric values (handle both TRACE and GPT names).
                            util_val = metrics.get('utilization') or metrics.get('context_utilization', 0)
                            rel_val = metrics.get('relevance') or metrics.get('context_relevance', 0)
                            adh_val = metrics.get('adherence', 0)
                            comp_val = metrics.get('completeness', 0)
                            avg_val = metrics.get('average', 0)

                            with col1:
                                st.metric("Util", f"{util_val:.3f}")
                            with col2:
                                st.metric("Rel", f"{rel_val:.3f}")
                            with col3:
                                st.metric("Adh", f"{adh_val:.3f}")
                            with col4:
                                st.metric("Comp", f"{comp_val:.3f}")
                            with col5:
                                st.metric("Avg", f"{avg_val:.3f}")

        # For GPT Labeling and Hybrid methods, show additional metrics.
        method = results.get("method", "")
        if "gpt_labeling" in method or "hybrid" in method:
            # Show RMSE aggregation metrics (consistency across evaluations).
            if "rmse_metrics" in results:
                st.markdown("### πŸ“Š RMSE Aggregation (Metric Consistency)")
                rmse_data = results.get("rmse_metrics", {})

                rmse_cols = st.columns(4)
                with rmse_cols[0]:
                    rel_mean = rmse_data.get("context_relevance", {}).get("mean", 0)
                    rel_std = rmse_data.get("context_relevance", {}).get("std_dev", 0)
                    st.metric("Relevance", f"{rel_mean:.3f} Β±{rel_std:.3f}", help="Mean and Std Dev")
                with rmse_cols[1]:
                    util_mean = rmse_data.get("context_utilization", {}).get("mean", 0)
                    util_std = rmse_data.get("context_utilization", {}).get("std_dev", 0)
                    st.metric("Utilization", f"{util_mean:.3f} Β±{util_std:.3f}", help="Mean and Std Dev")
                with rmse_cols[2]:
                    comp_mean = rmse_data.get("completeness", {}).get("mean", 0)
                    comp_std = rmse_data.get("completeness", {}).get("std_dev", 0)
                    st.metric("Completeness", f"{comp_mean:.3f} Β±{comp_std:.3f}", help="Mean and Std Dev")
                with rmse_cols[3]:
                    adh_mean = rmse_data.get("adherence", {}).get("mean", 0)
                    adh_std = rmse_data.get("adherence", {}).get("std_dev", 0)
                    st.metric("Adherence", f"{adh_mean:.3f} Β±{adh_std:.3f}", help="Mean and Std Dev")

                # Show detailed RMSE statistics in expander.
                with st.expander("See detailed RMSE aggregation statistics"):
                    for metric_name, metric_data in rmse_data.items():
                        st.write(f"**{metric_name}**")
                        col1, col2, col3, col4 = st.columns(4)
                        with col1:
                            st.write(f"Mean: {metric_data.get('mean', 0):.4f}")
                        with col2:
                            st.write(f"Std Dev: {metric_data.get('std_dev', 0):.4f}")
                        with col3:
                            st.write(f"Min: {metric_data.get('min', 0):.4f}")
                        with col4:
                            st.write(f"Max: {metric_data.get('max', 0):.4f}")

            # Show per-metric statistics if available.
            if "per_metric_statistics" in results:
                st.markdown("### πŸ“ˆ Per-Metric Statistics (Distribution)")
                stats_data = results.get("per_metric_statistics", {})

                stats_cols = st.columns(4)
                with stats_cols[0]:
                    rel_stats = stats_data.get("context_relevance", {})
                    st.metric("Relevance Mean", f"{rel_stats.get('mean', 0):.3f}", help=f"Median: {rel_stats.get('median', 0):.3f}")
                with stats_cols[1]:
                    util_stats = stats_data.get("context_utilization", {})
                    st.metric("Utilization Mean", f"{util_stats.get('mean', 0):.3f}", help=f"Median: {util_stats.get('median', 0):.3f}")
                with stats_cols[2]:
                    comp_stats = stats_data.get("completeness", {})
                    st.metric("Completeness Mean", f"{comp_stats.get('mean', 0):.3f}", help=f"Median: {comp_stats.get('median', 0):.3f}")
                with stats_cols[3]:
                    adh_stats = stats_data.get("adherence", {})
                    st.metric("Adherence Mean", f"{adh_stats.get('mean', 0):.3f}", help=f"Median: {adh_stats.get('median', 0):.3f}")

                # Show detailed statistics.
                with st.expander("See detailed per-metric statistics"):
                    for metric_name, metric_stats in stats_data.items():
                        st.write(f"**{metric_name}**")
                        col1, col2 = st.columns(2)
                        with col1:
                            st.write(f"""
    - Mean: {metric_stats.get('mean', 0):.4f}
    - Median: {metric_stats.get('median', 0):.4f}
    - Std Dev: {metric_stats.get('std_dev', 0):.4f}
    - Min: {metric_stats.get('min', 0):.4f}
    - Max: {metric_stats.get('max', 0):.4f}
    """)
                        with col2:
                            st.write(f"""
    - 25th percentile: {metric_stats.get('percentile_25', 0):.4f}
    - 75th percentile: {metric_stats.get('percentile_75', 0):.4f}
    - Perfect (>=0.95): {metric_stats.get('perfect_count', 0)}
    - Poor (<0.3): {metric_stats.get('poor_count', 0)}
    - Samples: {metric_stats.get('sample_count', 0)}
    """)

            # Show RMSE vs RAGBench Ground Truth (per RAGBench paper requirement).
            if "rmse_vs_ground_truth" in results:
                st.markdown("### πŸ“‰ RMSE vs RAGBench Ground Truth")
                st.info("Compares predicted TRACE scores against original RAGBench dataset scores")
                rmse_gt = results.get("rmse_vs_ground_truth", {})
                per_metric_rmse = rmse_gt.get("per_metric_rmse", {})

                if per_metric_rmse:
                    rmse_gt_cols = st.columns(5)
                    with rmse_gt_cols[0]:
                        st.metric("Relevance RMSE", f"{per_metric_rmse.get('context_relevance', 0):.4f}",
                                  delta=None, help="Lower is better (0 = perfect match)")
                    with rmse_gt_cols[1]:
                        st.metric("Utilization RMSE", f"{per_metric_rmse.get('context_utilization', 0):.4f}")
                    with rmse_gt_cols[2]:
                        st.metric("Completeness RMSE", f"{per_metric_rmse.get('completeness', 0):.4f}")
                    with rmse_gt_cols[3]:
                        st.metric("Adherence RMSE", f"{per_metric_rmse.get('adherence', 0):.4f}")
                    with rmse_gt_cols[4]:
                        agg_rmse = rmse_gt.get("aggregated_rmse", 0)
                        consistency = rmse_gt.get("consistency_score", 0)
                        st.metric("Aggregated RMSE", f"{agg_rmse:.4f}",
                                  delta=f"Consistency: {consistency:.2%}", delta_color="normal")

            # Show AUCROC vs RAGBench Ground Truth (per RAGBench paper requirement).
            if "aucroc_vs_ground_truth" in results:
                st.markdown("### πŸ“Š AUC-ROC vs RAGBench Ground Truth")
                st.info("Area Under ROC Curve comparing predicted vs ground truth binary classifications")
                auc_gt = results.get("aucroc_vs_ground_truth", {})

                if auc_gt:
                    auc_cols = st.columns(5)
                    with auc_cols[0]:
                        st.metric("Relevance AUC", f"{auc_gt.get('context_relevance', 0):.4f}",
                                  help="Higher is better (1.0 = perfect classification)")
                    with auc_cols[1]:
                        st.metric("Utilization AUC", f"{auc_gt.get('context_utilization', 0):.4f}")
                    with auc_cols[2]:
                        st.metric("Completeness AUC", f"{auc_gt.get('completeness', 0):.4f}")
                    with auc_cols[3]:
                        st.metric("Adherence AUC", f"{auc_gt.get('adherence', 0):.4f}")
                    with auc_cols[4]:
                        avg_auc = auc_gt.get("average", 0)
                        st.metric("Average AUC", f"{avg_auc:.4f}")

        # Download results
        st.divider()
        st.markdown("## πŸ’Ύ Download Results")

        # Create a comprehensive download with all details.
        download_data = {
            "evaluation_metadata": {
                "timestamp": datetime.now().isoformat(),
                "dataset": st.session_state.dataset_name,
                # NOTE(review): other sections read results["method"]; this
                # reads results["evaluation_config"]["evaluation_method"] --
                # confirm which key the evaluators actually populate.
                "method": results.get("evaluation_config", {}).get("evaluation_method", "gpt_labeling_prompts"),
                "total_samples": results.get("num_samples", 0),
                "embedding_model": st.session_state.embedding_model,
            },
            "aggregate_metrics": {
                "context_relevance": results.get("context_relevance") or results.get("relevance", 0),
                "context_utilization": results.get("context_utilization") or results.get("utilization", 0),
                "completeness": results.get("completeness", 0),
                "adherence": results.get("adherence", 0),
                "average": results.get("average", 0),
            },
            "rmse_metrics": results.get("rmse_metrics", {}),
            "per_metric_statistics": results.get("per_metric_statistics", {}),
            "rmse_vs_ground_truth": results.get("rmse_vs_ground_truth", {}),
            "aucroc_vs_ground_truth": results.get("aucroc_vs_ground_truth", {}),
            "detailed_results": results.get("detailed_results", [])
        }

        # default=str makes non-JSON-serializable values stringify instead of raising.
        results_json = json.dumps(download_data, indent=2, default=str)

        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                label="πŸ“₯ Download Complete Results (JSON)",
                data=results_json,
                file_name=f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
                mime="application/json",
                help="Download all evaluation results including metrics and per-query details"
            )
        with col2:
            st.download_button(
                label="πŸ“‹ Download Metrics Only (JSON)",
                data=json.dumps(download_data["aggregate_metrics"], indent=2),
                file_name=f"evaluation_metrics_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
                mime="application/json",
                help="Download only the aggregate metrics"
            )
1244
+
1245
+
1246
def run_evaluation(num_samples: int, selected_llm: str = None, method: str = "trace"):
    """Run evaluation using selected method (TRACE, GPT Labeling, or Hybrid).

    Args:
        num_samples: Number of test samples to evaluate
        selected_llm: LLM model to use for evaluation
        method: Evaluation method ("trace", "gpt_labeling", or "hybrid")

    Side effects:
        Stores aggregate results in ``st.session_state.evaluation_results``
        and streams progress logs into the page. If a different LLM is
        selected just for this run, the pipeline's original LLM client is
        restored afterwards -- even when the evaluation fails.
    """
    with st.spinner(f"Running evaluation on {num_samples} samples..."):
        # Logging infrastructure is set up BEFORE the try-block so the
        # except-handler can always call add_log. (Previously add_log was
        # defined inside the try, so an early failure raised NameError in
        # the handler and masked the real error.)
        logs_list = []
        logs_placeholder = st.empty()

        def add_log(message: str):
            """Append a log message and re-render the full log list."""
            logs_list.append(message)
            with logs_placeholder.container():
                st.markdown("### πŸ“‹ Evaluation Logs:")
                for log_msg in logs_list:
                    st.caption(log_msg)

        # Tracks the swapped-out LLM client; restored in `finally` below.
        original_llm = None

        try:
            # Log evaluation start and configuration.
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            add_log(f"⏱️ Evaluation started at {timestamp}")
            add_log(f"πŸ“Š Dataset: {st.session_state.dataset_name}")
            add_log(f"πŸ“ˆ Total samples: {num_samples}")
            add_log(f"πŸ€– LLM Model: {selected_llm if selected_llm else st.session_state.current_llm}")
            add_log(f"πŸ”— Vector Store: {st.session_state.collection_name}")
            add_log(f"🧠 Embedding Model: {st.session_state.embedding_model}")

            # Map internal method IDs to display names.
            method_names = {
                "trace": "TRACE (Heuristic)",
                "gpt_labeling": "GPT Labeling (LLM-based)",
                "hybrid": "Hybrid (Both)"
            }
            add_log(f"πŸ”¬ Evaluation Method: {method_names.get(method, method)}")

            # Use selected LLM if provided - create with appropriate provider.
            eval_llm_client = None
            current_provider = st.session_state.get("llm_provider", "groq")

            if selected_llm and selected_llm != st.session_state.current_llm:
                add_log(f"πŸ”„ Switching LLM to {selected_llm} ({current_provider.upper()})...")
                groq_api_key = st.session_state.groq_api_key if "groq_api_key" in st.session_state else ""
                eval_llm_client = create_llm_client(
                    provider=current_provider,
                    api_key=groq_api_key,
                    api_keys=settings.groq_api_keys if settings.groq_api_keys else None,
                    model_name=selected_llm,
                    ollama_host=settings.ollama_host,
                    max_rpm=settings.groq_rpm_limit,
                    rate_limit_delay=settings.rate_limit_delay,
                    max_retries=settings.max_retries,
                    retry_delay=settings.retry_delay
                )
                # Temporarily replace the pipeline's LLM client; the original
                # is stashed so the finally-block can always put it back.
                original_llm = st.session_state.rag_pipeline.llm
                st.session_state.rag_pipeline.llm = eval_llm_client
            else:
                eval_llm_client = st.session_state.rag_pipeline.llm

            # Log provider info.
            provider_icon = "☁️" if current_provider == "groq" else "πŸ–₯️"
            add_log(f"{provider_icon} LLM Provider: {current_provider.upper()}")

            # Get test data.
            add_log("πŸ“₯ Loading test data...")
            loader = RAGBenchLoader()
            test_data = loader.get_test_data(
                st.session_state.dataset_name,
                num_samples
            )
            add_log(f"βœ… Loaded {len(test_data)} test samples")

            # Prepare test cases by querying the RAG system for each sample.
            test_cases = []

            progress_bar = st.progress(0)
            status_text = st.empty()

            add_log("πŸ” Processing samples...")
            for i, sample in enumerate(test_data):
                status_text.text(f"Processing sample {i+1}/{num_samples}")

                # Query the RAG system.
                result = st.session_state.rag_pipeline.query(
                    sample["question"],
                    n_results=5
                )

                # Prepare test case for the evaluator.
                test_cases.append({
                    "query": sample["question"],
                    "response": result["response"],
                    "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                    "ground_truth": sample.get("answer", "")
                })

                # Update progress.
                progress_bar.progress((i + 1) / num_samples)

                # Log every 10 samples (and always on the last one).
                if (i + 1) % 10 == 0 or (i + 1) == num_samples:
                    add_log(f" βœ“ Processed {i + 1}/{num_samples} samples")

            status_text.text(f"Running {method_names.get(method, method)} evaluation...")
            add_log(f"πŸ“Š Running evaluation using {method_names.get(method, method)}...")

            # Extract chunking and embedding metadata from session state
            # (these were stored when the collection was loaded/created).
            chunking_strategy = st.session_state.vector_store.chunking_strategy if st.session_state.vector_store else None
            embedding_model = st.session_state.embedding_model
            chunk_size = st.session_state.vector_store.chunk_size if st.session_state.vector_store else None
            chunk_overlap = st.session_state.vector_store.chunk_overlap if st.session_state.vector_store else None

            # Log retrieval configuration.
            add_log(f"πŸ”§ Retrieval Configuration:")
            add_log(f" β€’ Chunking Strategy: {chunking_strategy or 'Unknown'}")
            add_log(f" β€’ Chunk Size: {chunk_size or 'Unknown'}")
            add_log(f" β€’ Chunk Overlap: {chunk_overlap or 'Unknown'}")
            add_log(f" β€’ Embedding Model: {embedding_model or 'Unknown'}")

            # Prefer the unified pipeline; fall back to TRACE if unavailable.
            try:
                from evaluation_pipeline import UnifiedEvaluationPipeline

                # Run evaluation with metadata using unified pipeline.
                pipeline = UnifiedEvaluationPipeline(
                    llm_client=eval_llm_client,
                    chunking_strategy=chunking_strategy,
                    embedding_model=embedding_model,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap
                )

                # Run evaluation with selected method.
                results = pipeline.evaluate_batch(test_cases, method=method)

            except ImportError:
                # Fallback to TRACE only if evaluation_pipeline module not available.
                add_log("⚠️ evaluation_pipeline module not found, falling back to TRACE...")

                evaluator = TRACEEvaluator(
                    chunking_strategy=chunking_strategy,
                    embedding_model=embedding_model,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap
                )
                results = evaluator.evaluate_batch(test_cases)

            st.session_state.evaluation_results = results

            # Log evaluation results summary.
            add_log("βœ… Evaluation completed successfully!")

            # Display appropriate metrics based on method.
            if method == "trace":
                add_log(f" β€’ Utilization: {results.get('utilization', 0):.2%}")
                add_log(f" β€’ Relevance: {results.get('relevance', 0):.2%}")
                add_log(f" β€’ Adherence: {results.get('adherence', 0):.2%}")
                add_log(f" β€’ Completeness: {results.get('completeness', 0):.2%}")
                add_log(f" β€’ Average: {results.get('average', 0):.2%}")
            elif method == "gpt_labeling":
                if "context_relevance" in results:
                    add_log(f" β€’ Context Relevance: {results.get('context_relevance', 0):.2%}")
                    add_log(f" β€’ Context Utilization: {results.get('context_utilization', 0):.2%}")
                    add_log(f" β€’ Completeness: {results.get('completeness', 0):.2%}")
                    add_log(f" β€’ Adherence: {results.get('adherence', 0):.2%}")
                    add_log(f" β€’ Average: {results.get('average', 0):.2%}")
                # Display RMSE and AUCROC metrics if available.
                if "rmse_metrics" in results:
                    add_log(f"πŸ“ˆ RMSE Metrics (vs ground truth):")
                    rmse_metrics = results.get("rmse_metrics", {})
                    add_log(f" β€’ Context Relevance RMSE: {rmse_metrics.get('relevance', 0):.4f}")
                    add_log(f" β€’ Context Utilization RMSE: {rmse_metrics.get('utilization', 0):.4f}")
                    add_log(f" β€’ Completeness RMSE: {rmse_metrics.get('completeness', 0):.4f}")
                    add_log(f" β€’ Adherence RMSE: {rmse_metrics.get('adherence', 0):.4f}")
                    add_log(f" β€’ Average RMSE: {rmse_metrics.get('average', 0):.4f}")
                if "auc_metrics" in results:
                    add_log(f"πŸ“Š AUCROC Metrics (binary classification):")
                    auc_metrics = results.get("auc_metrics", {})
                    add_log(f" β€’ Context Relevance AUCROC: {auc_metrics.get('relevance', 0):.4f}")
                    add_log(f" β€’ Context Utilization AUCROC: {auc_metrics.get('utilization', 0):.4f}")
                    add_log(f" β€’ Completeness AUCROC: {auc_metrics.get('completeness', 0):.4f}")
                    add_log(f" β€’ Adherence AUCROC: {auc_metrics.get('adherence', 0):.4f}")
                    add_log(f" β€’ Average AUCROC: {auc_metrics.get('average', 0):.4f}")
            elif method == "hybrid":
                add_log(" πŸ“Š TRACE Metrics:")
                trace_res = results.get("trace_results", {})
                add_log(f" β€’ Utilization: {trace_res.get('utilization', 0):.2%}")
                add_log(f" β€’ Relevance: {trace_res.get('relevance', 0):.2%}")
                add_log(f" β€’ Adherence: {trace_res.get('adherence', 0):.2%}")
                add_log(f" β€’ Completeness: {trace_res.get('completeness', 0):.2%}")
                add_log(" 🧠 GPT Labeling Metrics:")
                gpt_res = results.get("gpt_results", {})
                add_log(f" β€’ Context Relevance: {gpt_res.get('context_relevance', 0):.2%}")
                add_log(f" β€’ Context Utilization: {gpt_res.get('context_utilization', 0):.2%}")
                add_log(f" β€’ Completeness: {gpt_res.get('completeness', 0):.2%}")
                add_log(f" β€’ Adherence: {gpt_res.get('adherence', 0):.2%}")

            add_log(f"⏱️ Evaluation completed at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

        except Exception as e:
            st.error(f"Error during evaluation: {str(e)}")
            add_log(f"❌ Error: {str(e)}")
        finally:
            # Always restore the pipeline's original LLM if it was swapped.
            # (Previously the restore only ran on success, so a failed
            # evaluation left the chat pipeline using the evaluation model.)
            if original_llm is not None:
                st.session_state.rag_pipeline.llm = original_llm
                add_log(f"πŸ”„ Restored original LLM")
1463
+
1464
+
1465
def history_interface():
    """History interface tab.

    Renders the saved conversations from ``st.session_state.chat_history``:
    offers a timestamped JSON export and shows each entry (query, response,
    timestamp, retrieved documents) inside an expander.
    """
    st.subheader("πŸ“œ Chat History")

    # Nothing recorded yet — point the user at the Chat tab and bail out.
    if not st.session_state.chat_history:
        st.info("No chat history yet. Start a conversation in the Chat tab!")
        return

    # Export button lives in the narrow right-hand column; the wide left
    # column is intentionally unused (layout spacer).
    _, export_col = st.columns([3, 1])
    with export_col:
        # ensure_ascii=False keeps emoji/non-ASCII text readable in the
        # exported file instead of \uXXXX escapes.
        history_json = json.dumps(
            st.session_state.chat_history, indent=2, ensure_ascii=False
        )
        st.download_button(
            label="πŸ’Ύ Export History",
            data=history_json,
            file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json",
        )

    # Display history. Use .get() throughout (timestamp already did) so a
    # single malformed entry cannot crash the whole tab with a KeyError.
    for i, entry in enumerate(st.session_state.chat_history):
        query = entry.get("query", "")
        with st.expander(f"πŸ’¬ Conversation {i+1}: {query[:50]}..."):
            st.markdown(f"**Query:** {query}")
            st.markdown(f"**Response:** {entry.get('response', '')}")
            st.markdown(f"**Timestamp:** {entry.get('timestamp', 'N/A')}")

            st.markdown("**Retrieved Documents:**")
            for j, doc in enumerate(entry.get("retrieved_documents", [])):
                st.text_area(
                    f"Document {j+1}",
                    value=doc.get("document", ""),
                    height=100,
                    # Unique widget key per (conversation, document) pair —
                    # Streamlit requires distinct keys for repeated widgets.
                    key=f"history_doc_{i}_{j}",
                )
1501
# Script entry point: run the Streamlit app's main() only when executed
# directly (e.g. via `streamlit run streamlit_app.py`), not on import.
if __name__ == "__main__":
    main()