"""
Premium Streamlit UI for LangGraph RAG Q&A Agent
Enhanced with Blue & Black theme and dynamic dashboard
"""
import streamlit as st
import sys
from pathlib import Path
import json
from datetime import datetime, timezone
import plotly.graph_objects as go
import plotly.express as px
from typing import Dict, Any
# Add src to path
sys.path.insert(0, str(Path(__file__).parent))
from rag_pipeline import RAGPipeline
from llm_utils import create_llm_handler
from reflection import create_reflection_evaluator
from agent_workflow import create_rag_agent
from evaluation import create_evaluator
# Premium CSS Styling - Blue & Black Theme
PREMIUM_CSS = """
"""
@st.cache_resource
def initialize_agent(provider="huggingface", use_llm_reflection=False):
    """Initialize and cache the RAG agent.

    Cached with ``st.cache_resource`` so the pipeline, LLM handler and
    agent graph are built once per (provider, use_llm_reflection) pair.

    Args:
        provider: LLM provider name (e.g. "huggingface" or "openai").
        use_llm_reflection: When True, the reflection evaluator critiques
            answers with the LLM itself (more accurate, slower).

    Returns:
        The configured RAG agent created by ``create_rag_agent``.

    Raises:
        FileNotFoundError: If the project's ``data`` directory is missing.
    """
    # Resolve paths relative to the project root (the parent of this file's
    # directory) so the app works regardless of the current working directory.
    # Note: Path is already imported at module level; no local re-import needed.
    project_root = Path(__file__).resolve().parent.parent
    data_dir = project_root / "data"
    chroma_dir = project_root / "chroma_db"

    # Fail fast with a clear message if the knowledge base is absent.
    if not data_dir.exists():
        raise FileNotFoundError(f"Data directory not found: {data_dir}")

    # Build (or reuse) the persisted Chroma index.
    rag_pipeline = RAGPipeline(
        data_directory=str(data_dir),
        collection_name="rag_knowledge_base",
        persist_directory=str(chroma_dir)
    )
    rag_pipeline.build_index(force_rebuild=False)

    llm_handler = create_llm_handler(
        provider=provider,
        model_name="google/flan-t5-large",  # Force Flan-T5
        temperature=0.7,
        max_tokens=500
    )

    # Reflection only receives the LLM handler when LLM-based reflection
    # is requested; otherwise the evaluator falls back to heuristics.
    reflection_evaluator = create_reflection_evaluator(
        llm_handler=llm_handler if use_llm_reflection else None,
        use_llm_reflection=use_llm_reflection
    )

    agent = create_rag_agent(
        rag_pipeline=rag_pipeline,
        llm_handler=llm_handler,
        reflection_evaluator=reflection_evaluator,
        max_iterations=2
    )
    return agent
def initialize_evaluator():
    """Create and return the RAG evaluation helper.

    Thin wrapper around :func:`create_evaluator`, kept as a named function
    so the app's initialization steps read uniformly.
    """
    evaluator = create_evaluator()
    return evaluator
def create_gauge_chart(value: float, title: str, max_value: float = 1.0) -> go.Figure:
    """Create a premium gauge chart.

    Args:
        value: The score to display, on a 0..``max_value`` scale.
        title: Gauge title text.
        max_value: Full-scale value; ``value`` is normalized against it and
            shown as a percentage. Defaults to 1.0 (previously this parameter
            was accepted but ignored; with the default the output is unchanged).

    Returns:
        A themed Plotly gauge figure.
    """
    # Normalize to a percentage so the gauge axis can stay fixed at 0-100.
    percent = (value / max_value) * 100
    fig = go.Figure(go.Indicator(
        mode="gauge+number+delta",
        value=percent,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': title, 'font': {'size': 16, 'color': '#e0e7ff'}},
        number={'suffix': "%", 'font': {'size': 40, 'color': '#3b82f6'}},
        gauge={
            'axis': {'range': [None, 100], 'tickwidth': 1, 'tickcolor': "#94a3b8"},
            'bar': {'color': "#3b82f6"},
            'bgcolor': "rgba(30, 41, 59, 0.5)",
            'borderwidth': 2,
            'bordercolor': "rgba(59, 130, 246, 0.3)",
            # Red / amber / green bands for low / medium / high scores.
            'steps': [
                {'range': [0, 40], 'color': 'rgba(239, 68, 68, 0.3)'},
                {'range': [40, 70], 'color': 'rgba(245, 158, 11, 0.3)'},
                {'range': [70, 100], 'color': 'rgba(16, 185, 129, 0.3)'}
            ],
            # Purple marker line at the 80% "target" mark.
            'threshold': {
                'line': {'color': "#8b5cf6", 'width': 4},
                'thickness': 0.75,
                'value': 80
            }
        }
    ))
    # Transparent backgrounds so the dashboard theme shows through.
    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font={'color': "#e0e7ff", 'family': "Inter"},
        height=250,
        margin=dict(l=20, r=20, t=50, b=20)
    )
    return fig
def create_bar_chart(data: Dict[str, float], title: str) -> go.Figure:
    """Create a premium bar chart whose bar colors track the bar values."""
    labels = list(data.keys())
    values = list(data.values())
    bars = go.Bar(
        x=labels,
        y=values,
        marker=dict(
            # Color each bar by its own value: red -> amber -> green.
            color=values,
            colorscale=[[0, '#ef4444'], [0.5, '#f59e0b'], [1, '#10b981']],
            line=dict(color='rgba(59, 130, 246, 0.5)', width=2)
        ),
        text=[format(v, '.3f') for v in values],
        textposition='outside',
        textfont=dict(color='#e0e7ff', size=14)
    )
    fig = go.Figure(data=[bars])
    # Dark translucent plot area; y-axis pinned to [0, 1] score range.
    fig.update_layout(
        title=dict(text=title, font=dict(size=18, color='#e0e7ff')),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(30, 41, 59, 0.3)',
        font={'color': "#e0e7ff", 'family': "Inter"},
        xaxis=dict(gridcolor='rgba(59, 130, 246, 0.1)'),
        yaxis=dict(gridcolor='rgba(59, 130, 246, 0.1)', range=[0, 1]),
        height=300,
        margin=dict(l=40, r=40, t=60, b=40)
    )
    return fig
def create_radar_chart(scores: Dict[str, float]) -> go.Figure:
    """Create a premium radar chart for score breakdown."""
    # Humanize the axis labels: "some_metric" -> "Some Metric".
    axis_labels = [name.replace('_', ' ').title() for name in scores]
    trace = go.Scatterpolar(
        r=list(scores.values()),
        theta=axis_labels,
        fill='toself',
        fillcolor='rgba(59, 130, 246, 0.3)',
        line=dict(color='#3b82f6', width=3)
    )
    fig = go.Figure(data=trace)
    # Radial axis fixed to the 0-1 score range; transparent paper background.
    fig.update_layout(
        polar=dict(
            bgcolor='rgba(30, 41, 59, 0.3)',
            radialaxis=dict(
                visible=True,
                range=[0, 1],
                gridcolor='rgba(59, 130, 246, 0.2)',
                tickfont=dict(color='#94a3b8')
            ),
            angularaxis=dict(
                gridcolor='rgba(59, 130, 246, 0.2)',
                tickfont=dict(color='#e0e7ff', size=11)
            )
        ),
        paper_bgcolor='rgba(0,0,0,0)',
        font={'color': "#e0e7ff", 'family': "Inter"},
        height=400,
        margin=dict(l=80, r=80, t=40, b=40)
    )
    return fig
def display_premium_metrics(evaluation_result: Dict, reflection_result: Dict):
    """Display premium metrics dashboard.

    Renders a four-tab dashboard (overview gauges, quality scores,
    reflection analysis, detailed report with JSON download) from the
    evaluator output and the agent's reflection result.

    Args:
        evaluation_result: Evaluator output; keys read here include
            "metrics", "num_contexts", "query", "generated_answer".
        reflection_result: Reflection output; keys read here include
            "score", "relevance", "recommendation", "reasoning",
            "method", "score_breakdown".
    """
    st.markdown("---")
    st.markdown("## 📊 **Dynamic Performance Dashboard**")
    metrics = evaluation_result.get("metrics", {})
    # Tab system
    tab1, tab2, tab3, tab4 = st.tabs([
        "🎯 **Overview**",
        "📈 **Quality Scores**",
        "🔍 **Reflection Analysis**",
        "📋 **Detailed Report**"
    ])
    with tab1:
        st.markdown("### Real-Time Performance Metrics")
        # Top row - Gauge charts
        col1, col2, col3 = st.columns(3)
        with col1:
            context_rel = metrics.get('context_relevance', 0)
            fig = create_gauge_chart(context_rel, "Context Relevance")
            st.plotly_chart(fig, use_container_width=True)
        with col2:
            reflection_score = reflection_result.get('score', 0)
            fig = create_gauge_chart(reflection_score, "Overall Quality")
            st.plotly_chart(fig, use_container_width=True)
        with col3:
            # Calculate average score: mean of the available reference-based
            # metrics (ROUGE average and BERTScore F1). Falls back to the
            # reflection score when neither metric is present (count == 0).
            avg_score = 0
            count = 0
            if "rouge" in metrics:
                avg_score += sum(metrics["rouge"].values()) / len(metrics["rouge"])
                count += 1
            if "bertscore" in metrics:
                avg_score += metrics["bertscore"].get("f1", 0)
                count += 1
            if count > 0:
                avg_score /= count
            fig = create_gauge_chart(avg_score if count > 0 else reflection_score, "Combined Score")
            st.plotly_chart(fig, use_container_width=True)
        # Bottom row - Key stats rendered as four HTML stat cards.
        st.markdown("### Key Statistics")
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.markdown(f"""
            {metrics.get('answer_length', 0)}
            Characters
            """, unsafe_allow_html=True)
        with col2:
            st.markdown(f"""
            {metrics.get('word_count', 0)}
            Words
            """, unsafe_allow_html=True)
        with col3:
            st.markdown(f"""
            {evaluation_result.get('num_contexts', 0)}
            Contexts Used
            """, unsafe_allow_html=True)
        with col4:
            # NOTE(review): 'iterations' is read but the markdown body below
            # is empty — the stat-card HTML appears to have been lost in a
            # prior edit; confirm and restore.
            iterations = st.session_state.get('iterations', 0)
            st.markdown(f"""
            """, unsafe_allow_html=True)
    with tab2:
        # Reference-based quality metrics; both require a reference answer.
        st.markdown("### Quality Assessment Scores")
        col1, col2 = st.columns(2)
        with col1:
            if "rouge" in metrics:
                st.markdown("#### 📝 ROUGE Scores")
                rouge_data = {
                    'ROUGE-1': metrics["rouge"].get('rouge1', 0),
                    'ROUGE-2': metrics["rouge"].get('rouge2', 0),
                    'ROUGE-L': metrics["rouge"].get('rougeL', 0)
                }
                fig = create_bar_chart(rouge_data, "ROUGE Score Analysis")
                st.plotly_chart(fig, use_container_width=True)
            else:
                st.info("💡 Add a reference answer to see ROUGE scores")
        with col2:
            if "bertscore" in metrics:
                st.markdown("#### 🧠 BERTScore Metrics")
                bert_data = {
                    'Precision': metrics["bertscore"].get('precision', 0),
                    'Recall': metrics["bertscore"].get('recall', 0),
                    'F1 Score': metrics["bertscore"].get('f1', 0)
                }
                fig = create_bar_chart(bert_data, "BERTScore Analysis")
                st.plotly_chart(fig, use_container_width=True)
            else:
                st.info("💡 Add a reference answer to see BERTScore")
    with tab3:
        st.markdown("### Reflection Analysis Dashboard")
        col1, col2 = st.columns([1, 2])
        with col1:
            # Map the relevance verdict to a badge style + icon.
            # NOTE(review): badge_class is computed but not interpolated into
            # the markdown below — likely stripped HTML; confirm.
            relevance = reflection_result.get('relevance', 'Unknown')
            if relevance == "Relevant":
                badge_class = "status-relevant"
                icon = "✅"
            elif relevance == "Partially Relevant":
                badge_class = "status-partial"
                icon = "⚠️"
            else:
                badge_class = "status-irrelevant"
                icon = "❌"
            st.markdown(f"""
            {icon} {relevance}
            {reflection_result.get('score', 0):.1%}
            Quality Score
            Recommendation:
            {reflection_result.get('recommendation', 'N/A')}
            """, unsafe_allow_html=True)
        with col2:
            st.markdown("#### 💭 Reasoning")
            st.markdown(f"""
            {reflection_result.get('reasoning', 'No reasoning provided')}
            """, unsafe_allow_html=True)
        # Radar chart for score breakdown — only the heuristic reflection
        # method produces a per-criterion 'score_breakdown'.
        if reflection_result.get('method') == 'heuristic':
            breakdown = reflection_result.get('score_breakdown', {})
            if breakdown:
                st.markdown("#### 📊 Score Breakdown")
                fig = create_radar_chart(breakdown)
                st.plotly_chart(fig, use_container_width=True)
    with tab4:
        st.markdown("### Detailed Evaluation Report")
        # Timestamp - FIXED: timezone-aware UTC (not naive local time).
        timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
        st.markdown(f"**Generated:** `{timestamp}`")
        col1, col2 = st.columns(2)
        with col1:
            st.markdown("#### Evaluation Metrics")
            st.json(metrics)
        with col2:
            st.markdown("#### Reflection Analysis")
            st.json(reflection_result)
        # Download button: bundle metrics + reflection into one JSON report.
        combined_data = {
            "timestamp": timestamp,
            "query": evaluation_result.get("query", ""),
            "generated_answer": evaluation_result.get("generated_answer", ""),
            "evaluation_metrics": metrics,
            "reflection_analysis": reflection_result
        }
        json_str = json.dumps(combined_data, indent=2)
        st.download_button(
            label="📥 **Download Complete Report (JSON)**",
            data=json_str,
            file_name=f"rag_evaluation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json",
            use_container_width=True
        )
def main():
    """Main Premium Streamlit app.

    Flow: configure the page, render sidebar controls, initialize the
    cached agent (and optional evaluator), then on submit run the query
    through the agent and render the answer, analytics dashboard, and
    processing details.
    """
    # Page chrome must be configured before any other Streamlit call.
    st.set_page_config(
        page_title="LangGraph RAG Q&A Agent - Premium",
        page_icon="🤖",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    # Apply premium CSS
    st.markdown(PREMIUM_CSS, unsafe_allow_html=True)
    # Premium Header
    # NOTE(review): header markup is empty — confirm the HTML block was not
    # lost in a prior edit.
    st.markdown("""
    """, unsafe_allow_html=True)
    # Sidebar: provider/reflection/evaluation controls, optional reference
    # answer, and one-click sample queries.
    with st.sidebar:
        st.markdown("## ⚙️ **Control Panel**")
        provider = st.selectbox(
            "🔌 LLM Provider",
            ["huggingface", "openai"],
            help="Select your preferred LLM provider"
        )
        use_llm_reflection = st.checkbox(
            "🧠 LLM Reflection Mode",
            value=False,
            help="Enable AI-powered reflection (more accurate, slower)"
        )
        enable_evaluation = st.checkbox(
            "📊 Advanced Analytics",
            value=True,
            help="Enable comprehensive evaluation metrics"
        )
        st.markdown("---")
        st.markdown("### 📝 **Reference Answer**")
        reference_answer = st.text_area(
            "Optional: For comparison metrics",
            placeholder="Provide a reference answer to calculate advanced metrics...",
            height=100,
            label_visibility="collapsed"
        )
        st.markdown("---")
        st.markdown("### 💡 **Quick Queries**")
        sample_queries = [
            ("🤖", "What is machine learning?"),
            ("🐍", "Explain Python programming"),
            ("☁️", "What is cloud computing?"),
            ("💾", "Tell me about databases"),
            ("🧠", "What is deep learning?"),
            ("📊", "Explain supervised learning"),
            ("🗄️", "What are NoSQL databases?"),
            ("💬", "What is NLP?")
        ]
        # Clicking a sample stores it in session state; the main text input
        # picks it up via st.session_state.get("query", "").
        for icon, query in sample_queries:
            if st.button(f"{icon} {query}", key=f"sample_{query}", use_container_width=True):
                st.session_state["query"] = query
        st.markdown("---")
        st.markdown("""
        📊 Metrics Available
        - Context Relevance
        - Quality Scores
        - ROUGE Analysis*
        - BERTScore*
        - Reflection Insights
        *Requires reference answer
        """, unsafe_allow_html=True)
    # Initialize components (agent is cached via st.cache_resource).
    try:
        with st.spinner("🚀 Initializing AI Agent..."):
            agent = initialize_agent(provider, use_llm_reflection)
            if enable_evaluation:
                evaluator = initialize_evaluator()
        st.sidebar.success("✅ **System Online**")
    except Exception as e:
        # Truncate the message to keep the sidebar readable; stop the app.
        st.sidebar.error(f"❌ **Error:** {str(e)[:50]}...")
        st.stop()
    # Main query interface
    st.markdown("### 💬 **Ask Your Question**")
    query = st.text_input(
        "Query input",
        value=st.session_state.get("query", ""),
        placeholder="Type your question about AI, Python, ML, Cloud, or Databases...",
        label_visibility="collapsed"
    )
    # NOTE(review): only col1/col2 are used; col3-col5 act as spacers.
    col1, col2, col3, col4, col5 = st.columns([2, 2, 2, 2, 2])
    with col1:
        submit_button = st.button("🚀 **Ask Question**", type="primary", use_container_width=True)
    with col2:
        clear_button = st.button("🗑️ **Clear**", use_container_width=True)
    if clear_button:
        st.session_state.clear()
        st.rerun()
    # Process query
    if submit_button and query:
        with st.spinner("🤔 Processing your question..."):
            try:
                result = agent.query(query)
                # Store iterations so the dashboard stat card can read it.
                st.session_state['iterations'] = result.get('iteration', 0)
                # Display answer - CRITICAL SECTION
                st.markdown("---")
                st.markdown("## 💬 **AI Response**")
                # Get answer from multiple possible keys (agent versions
                # differ in which key carries the final text).
                answer = result.get('final_response', '') or result.get('answer', '')
                if answer and answer.strip():
                    st.markdown(f"""
                    {answer}
                    """, unsafe_allow_html=True)
                else:
                    st.warning("⚠️ Answer was generated but appears empty. Check terminal output.")
                    st.code(str(result), language="json")  # Debug output
                # Show iteration info
                if result.get("iteration", 0) > 0:
                    st.info(f"🔄 Answer refined {result['iteration']} time(s) using reflection feedback")
                # Evaluation
                if enable_evaluation:
                    with st.spinner("📊 Calculating analytics..."):
                        retrieved_contexts = None
                        if result.get("retrieved_chunks"):
                            retrieved_contexts = [chunk["content"] for chunk in result["retrieved_chunks"]]
                        # Reference answer is optional; pass None when blank
                        # so reference-based metrics are skipped.
                        evaluation_result = evaluator.evaluate_response(
                            query=query,
                            generated_answer=answer,
                            reference_answer=reference_answer if reference_answer.strip() else None,
                            retrieved_contexts=retrieved_contexts
                        )
                        display_premium_metrics(evaluation_result, result.get("reflection", {}))
                # Processing details
                st.markdown("---")
                st.markdown("## 🔍 **Processing Pipeline**")
                col1, col2 = st.columns(2)
                with col1:
                    with st.expander("📋 **Planning Phase**", expanded=False):
                        st.markdown(f"""
                        {result.get("plan", "No plan available")}
                        """, unsafe_allow_html=True)
                with col2:
                    chunks = result.get('retrieved_chunks', [])
                    with st.expander(f"🔍 **Retrieved Context** ({len(chunks)} chunks)", expanded=False):
                        if chunks:
                            # Show each chunk's source, similarity, and a
                            # 300-char content preview.
                            for i, chunk in enumerate(chunks, 1):
                                st.markdown(f"""
                                Chunk {i} -
                                {chunk['metadata']['source']}
                                Similarity: {chunk['similarity_score']:.3f}
                                {chunk["content"][:300]}{'...' if len(chunk["content"]) > 300 else ''}
                                """, unsafe_allow_html=True)
                        else:
                            st.info("No context retrieval needed for this query")
            except Exception as e:
                st.error(f"❌ **Error:** {str(e)}")
                with st.expander("🔍 **Error Details**"):
                    import traceback
                    st.code(traceback.format_exc())
    # Premium Footer
    # NOTE(review): footer markup is empty — confirm the HTML block was not
    # lost in a prior edit.
    st.markdown("""
    """, unsafe_allow_html=True)
# Script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()