# Atlan / app.py
# ashkunwar
# Update application with enhanced features for Hugging Face deployment
# 4ee7173
"""Streamlit front end for the Atlan Customer Support Copilot.

Pages (selected in the sidebar): bulk classification dashboard, interactive
AI agent, single-ticket classification, and upload-and-classify. Tickets are
classified by ``TicketClassifier`` and answered by ``EnhancedRAGPipeline``,
both imported further down from project modules.
"""
import streamlit as st
# Page config is set right after importing Streamlit, before any other st.*
# call — presumably because older Streamlit versions require
# set_page_config to be the first Streamlit command.
st.set_page_config(
    page_title="🎯 Atlan Customer Support Copilot",
    page_icon="🎯",
    layout="wide",
    initial_sidebar_state="expanded"
)
import json
import asyncio
import logging
import os
# NOTE(review): List, Dict, datetime, pd and go do not appear to be used in
# this file; confirm before removing.
from typing import List, Dict
from datetime import datetime
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment so the
# GROQ_API_KEY lookup that follows can find them during local development.
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _read_streamlit_secret(key, section=None):
    """Return ``key`` from ``st.secrets`` (or from its ``section`` table), else None.

    Reading ``st.secrets`` may raise — e.g. when no secrets file exists,
    depending on the Streamlit version — which is the normal situation for
    local development where the key comes from ``.env``. Such errors are
    logged and treated as "not present" so the remaining sources are still
    tried instead of aborting the app.
    """
    if not hasattr(st, 'secrets'):
        return None
    try:
        store = st.secrets if section is None else getattr(st.secrets, section)
        if key in store:
            return store[key]
    except Exception as exc:
        logger.debug("Streamlit secrets lookup for %s failed: %s", key, exc)
    return None


def _configure_groq_api_key():
    """Ensure GROQ_API_KEY is in ``os.environ`` and announce where it came from.

    Sources, in priority order: top-level Streamlit secrets, the existing
    environment (including values loaded from ``.env``), then a ``default``
    table inside Streamlit secrets.

    Returns:
        True when a key was found, False otherwise.
    """
    secret_value = _read_streamlit_secret('GROQ_API_KEY')
    if secret_value is not None:
        os.environ['GROQ_API_KEY'] = secret_value
        st.success("πŸ”‘ API key loaded from Streamlit Cloud secrets")
        return True
    if 'GROQ_API_KEY' in os.environ:
        st.success("πŸ”‘ API key loaded from environment variables")
        return True
    secret_value = _read_streamlit_secret('GROQ_API_KEY', section='default')
    if secret_value is not None:
        os.environ['GROQ_API_KEY'] = secret_value
        st.success("πŸ”‘ API key loaded from Streamlit secrets")
        return True
    return False


def _show_missing_api_key_help():
    """Explain how to provide GROQ_API_KEY on each supported deployment target."""
    st.error("⚠️ GROQ_API_KEY not found!")
    st.info("**For Hugging Face Spaces deployment:**")
    st.info("Add your API key in the Space settings > Secrets tab")
    st.code("""
# In Hugging Face Spaces Secrets:
GROQ_API_KEY = "your_groq_api_key_here"
""")
    st.info("**For Streamlit Cloud deployment:**")
    st.info("Add your API key in the Streamlit Cloud app settings > Secrets tab")
    st.info("**For local development:**")
    st.info("Add GROQ_API_KEY to your .env file")
    st.code("""
# In .env file:
GROQ_API_KEY=your_groq_api_key_here
""")


# Try multiple sources for API key: Streamlit secrets, environment variables, .env file.
# st.stop() is kept out of the try body so whatever it raises can never be
# reported as a configuration error.
try:
    _api_key_found = _configure_groq_api_key()
except Exception as e:
    st.error(f"⚠️ Error accessing API key: {e}")
    st.error("Please check your configuration")
    st.stop()
else:
    if not _api_key_found:
        _show_missing_api_key_help()
        st.stop()
# Import application modules after environment setup (presumably so that
# they can see GROQ_API_KEY in os.environ when they load).
_app_import_error = None
try:
    from models import Ticket, TicketClassification, TopicTagEnum, SentimentEnum, PriorityEnum
    from classifier import TicketClassifier
    from enhanced_rag import EnhancedRAGPipeline
except ImportError as exc:
    _app_import_error = exc
# A missing project module is fatal: explain it in the UI and end this run.
if _app_import_error is not None:
    st.error(f"❌ Failed to import required modules: {_app_import_error}")
    st.error("Please ensure all required files are present in the directory")
    st.stop()
# Global stylesheet injected on every run. Only .main-header is referenced by
# markup in this file; .ticket-card, .tag and .metric-card appear unused.
_APP_CSS = """
<style>
.main-header {
text-align: center;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 2rem;
border-radius: 10px;
margin-bottom: 2rem;
}
.ticket-card {
border: 1px solid #e1e5e9;
border-radius: 8px;
padding: 1rem;
margin: 1rem 0;
background: white;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.tag {
background: #667eea;
color: white;
padding: 0.2rem 0.5rem;
border-radius: 15px;
font-size: 0.8rem;
margin: 0.2rem;
display: inline-block;
}
.metric-card {
background: white;
padding: 1rem;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
text-align: center;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
@st.cache_resource
def _load_ai_models():
    """Build the ticket classifier and the RAG pipeline that shares its client.

    Wrapped in st.cache_resource so the models are built once and reused on
    later reruns. Errors are deliberately not caught here: a call that raises
    produces no cached value. The next rerun therefore retries, instead of
    permanently serving a cached ``(None, None)``.
    """
    classifier = TicketClassifier()
    rag_pipeline = EnhancedRAGPipeline(groq_client=classifier.client)
    return classifier, rag_pipeline


def initialize_ai_models():
    """Return ``(classifier, rag_pipeline)``, or ``(None, None)`` on failure.

    On failure the error is logged and shown in the UI; callers are expected
    to check for None and stop.
    """
    try:
        return _load_ai_models()
    except Exception as e:
        logger.exception("Failed to initialize AI models")
        st.error(f"❌ Failed to initialize AI models: {e}")
        return None, None
def load_sample_tickets(path='sample_tickets.json'):
    """Load tickets from a JSON file, falling back to two built-in demo tickets.

    Args:
        path: JSON file holding a list of objects whose keys are ``Ticket``
            constructor arguments. Relative paths resolve against the current
            working directory (the default keeps the previous behaviour).

    Returns:
        A list of ``Ticket``. If the file does not exist, a warning is shown
        and demo tickets are returned; on any other error (bad JSON, invalid
        ticket fields) an error is shown and an empty list is returned.
    """
    try:
        # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) would
        # fail or mis-decode non-ASCII ticket text.
        with open(path, 'r', encoding='utf-8') as f:
            tickets_data = json.load(f)
        return [Ticket(**ticket_data) for ticket_data in tickets_data]
    except FileNotFoundError:
        st.warning("πŸ“‹ Sample tickets file not found. Using demo data for cloud deployment.")
        # Create minimal demo data for cloud deployment
        demo_tickets = [
            {
                "id": "DEMO-001",
                "subject": "Demo ticket - Connection issue",
                "body": "This is a demo ticket showing connection problems with our data source."
            },
            {
                "id": "DEMO-002",
                "subject": "Demo ticket - API question",
                "body": "This is a demo ticket asking about API usage and documentation."
            }
        ]
        return [Ticket(**ticket_data) for ticket_data in demo_tickets]
    except Exception as e:
        st.error(f"❌ Error loading tickets: {e}")
        return []
async def classify_tickets_async(classifier, tickets):
    """Classify ``tickets`` in one bulk call and pair each ticket with its result.

    Returns:
        A list of ``(ticket, classification)`` tuples. If the bulk call (or
        pairing its output) fails, the error is shown in the UI and an empty
        list is returned instead.
    """
    try:
        results = await classifier.classify_tickets_bulk(tickets)
        paired = [(ticket, result) for ticket, result in zip(tickets, results)]
    except Exception as err:
        st.error(f"❌ Classification error: {err}")
        return []
    return paired
def run_async(coro):
    """Run the coroutine ``coro`` to completion from synchronous code and return its result.

    The Streamlit page functions are synchronous, so async classifier/RAG
    calls are driven through this helper:

    * the thread's current event loop is reused when one is set and open;
    * a missing or closed loop is replaced by a new one, installed for the
      thread so later calls reuse it;
    * if a loop is already running in this thread (``run_until_complete``
      would raise), the coroutine runs on a fresh loop in a worker thread and
      this call blocks until it finishes.

    Raises:
        Whatever ``coro`` raises.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current event loop in this thread.
        loop = None
    if loop is not None and loop.is_running():
        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coro).result()
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coro)
def calculate_stats(classified_tickets):
    """Summarise ``(ticket, classification)`` pairs for the dashboards.

    Returns a dict with:
        total: number of pairs.
        high_priority: classifications with priority ``PriorityEnum.P0``.
        frustrated: classifications whose sentiment is FRUSTRATED or ANGRY.
        rag_eligible: classifications with at least one tag in the RAG topic list.
        most_common_tag: first-seen tag with the highest count, or 'N/A'.
        tag_counts: tag value -> count, in first-seen order.
    Empty or None input yields all-zero statistics.
    """
    if not classified_tickets:
        return {
            'total': 0,
            'high_priority': 0,
            'frustrated': 0,
            'rag_eligible': 0,
            'most_common_tag': 'N/A',
            'tag_counts': {}
        }
    total = len(classified_tickets)
    # Topic values that count as RAG-eligible.
    rag_topics = ['How-to', 'Product', 'Best practices', 'API/SDK', 'SSO']
    upset_sentiments = [SentimentEnum.FRUSTRATED, SentimentEnum.ANGRY]
    high_priority = 0
    frustrated = 0
    rag_eligible = 0
    tag_counts = {}
    # Single pass accumulating every counter at once.
    for _ticket, result in classified_tickets:
        tag_values = [tag.value for tag in result.topic_tags]
        if result.priority == PriorityEnum.P0:
            high_priority += 1
        if result.sentiment in upset_sentiments:
            frustrated += 1
        if any(value in rag_topics for value in tag_values):
            rag_eligible += 1
        for value in tag_values:
            tag_counts[value] = tag_counts.get(value, 0) + 1
    most_common_tag = max(tag_counts, key=tag_counts.get) if tag_counts else 'N/A'
    return {
        'total': total,
        'high_priority': high_priority,
        'frustrated': frustrated,
        'rag_eligible': rag_eligible,
        'most_common_tag': most_common_tag,
        'tag_counts': tag_counts
    }
def _sentiment_color(sentiment_value):
    """Return the badge colour for a sentiment label (case-insensitive substring match)."""
    lowered = sentiment_value.lower()
    if 'frustrated' in lowered:
        return '#ff6b6b'
    if 'angry' in lowered:
        return '#ff3838'
    if 'curious' in lowered:
        return '#4ecdc4'
    return '#95a5a6'


def _priority_color(priority_value):
    """Return the badge colour for a priority label: 'P0' red, 'P1' orange, else green."""
    if 'P0' in priority_value:
        return '#ff3838'
    if 'P1' in priority_value:
        return '#ffa726'
    return '#66bb6a'


def display_ticket_card(ticket, classification):
    """Render one classified ticket.

    Shows the id, subject and body (truncated to 300 characters), one badge
    per topic tag, sentiment and priority badges, and the model's reasoning,
    followed by a divider.
    """
    with st.container():
        st.markdown(f"**{ticket.id}**")
        st.write(f"**Subject:** {ticket.subject}")
        st.write(f"**Message:** {ticket.body[:300]}{'...' if len(ticket.body) > 300 else ''}")
        st.write("**πŸ“‹ Topics:**")
        topic_tags = list(classification.topic_tags or [])
        if topic_tags:
            # One column per tag; st.columns rejects a count of 0, hence the guard.
            for column, tag in zip(st.columns(len(topic_tags)), topic_tags):
                with column:
                    st.markdown(f'<span style="background: #667eea; color: white; padding: 0.2rem 0.5rem; border-radius: 10px; font-size: 0.8rem; margin: 0.1rem;">{tag.value}</span>', unsafe_allow_html=True)
        else:
            st.caption("No topics assigned")
        sentiment_color = _sentiment_color(classification.sentiment.value)
        st.markdown(f"**😊 Sentiment:** <span style='background: {sentiment_color}; color: white; padding: 0.3rem 0.8rem; border-radius: 15px; font-size: 0.9rem;'>{classification.sentiment.value}</span>", unsafe_allow_html=True)
        priority_color = _priority_color(classification.priority.value)
        st.markdown(f"**πŸ”₯ Priority:** <span style='background: {priority_color}; color: white; padding: 0.3rem 0.8rem; border-radius: 15px; font-size: 0.9rem;'>{classification.priority.value}</span>", unsafe_allow_html=True)
        st.write(f"**πŸ€– AI Reasoning:** {classification.reasoning}")
        st.divider()
def main():
    """Load the AI models, draw the header banner and render the page chosen in the sidebar.

    Ends the script run via st.stop() if either model failed to initialise.
    """
    classifier, rag_pipeline = initialize_ai_models()
    if any(model is None for model in (classifier, rag_pipeline)):
        st.stop()
    banner_html = """
<div class="main-header">
<h1>🎯 Atlan Customer Support Copilot</h1>
<p>AI-powered ticket classification and intelligent response generation</p>
</div>
"""
    st.markdown(banner_html, unsafe_allow_html=True)
    # Sidebar label -> page renderer; insertion order is the order shown in the sidebar.
    page_renderers = {
        "πŸ“Š Bulk Classification Dashboard": lambda: bulk_dashboard_page(classifier),
        "πŸ€– Interactive AI Agent": lambda: interactive_agent_page(classifier, rag_pipeline),
        "πŸ“ Single Ticket Classification": lambda: single_ticket_page(classifier),
        "πŸ“‚ Upload & Classify": lambda: upload_classify_page(classifier),
    }
    st.sidebar.title("🧭 Navigation")
    selected_page = st.sidebar.selectbox("Choose a page", list(page_renderers))
    render_page = page_renderers.get(selected_page)
    if render_page is not None:
        render_page()
def _load_bulk_results(classifier):
    """Load the sample tickets and classify them.

    Returns:
        A list of ``(ticket, classification)`` pairs, or ``[]`` when nothing
        could be loaded or classification failed (the error is shown in the UI).
    """
    with st.spinner("πŸ”„ Loading and classifying sample tickets..."):
        tickets = load_sample_tickets()
        if not tickets:
            return []
        try:
            classified_tickets = run_async(classify_tickets_async(classifier, tickets))
        except Exception as e:
            st.error(f"❌ Error during classification: {e}")
            return []
        # classify_tickets_async shows its own error and returns [] on failure,
        # so only report success when something was actually classified.
        if classified_tickets:
            st.success(f"βœ… Successfully classified {len(classified_tickets)} tickets!")
        return classified_tickets


def _render_bulk_metrics(stats):
    """Show the five headline metrics produced by calculate_stats."""
    col1, col2, col3, col4, col5 = st.columns(5)
    with col1:
        st.metric("πŸ“‹ Total Tickets", stats['total'])
    with col2:
        st.metric("🚨 High Priority", stats['high_priority'])
    with col3:
        st.metric("😤 Frustrated/Angry", stats['frustrated'])
    with col4:
        st.metric("πŸ€– RAG-Eligible", stats['rag_eligible'])
    with col5:
        st.metric("🏷️ Top Topic", stats['most_common_tag'])


def _render_bulk_charts(results, tag_counts):
    """Draw the priority pie chart and the topic bar chart side by side."""
    col1, col2 = st.columns(2)
    with col1:
        # Priority distribution
        priority_data = {}
        for _, classification in results:
            priority = classification.priority.value
            priority_data[priority] = priority_data.get(priority, 0) + 1
        priority_names = list(priority_data.keys())
        fig_priority = px.pie(
            values=list(priority_data.values()),
            names=priority_names,
            # color_discrete_map is keyed by the values passed as `color`;
            # without `color`, plotly ignores the map and uses its default palette.
            color=priority_names,
            title="πŸ“Š Priority Distribution",
            # NOTE(review): keys assume PriorityEnum values look like 'P0 (High)';
            # confirm against models.PriorityEnum.
            color_discrete_map={
                'P0 (High)': '#ff3838',
                'P1 (Medium)': '#ffa726',
                'P2 (Low)': '#66bb6a'
            }
        )
        st.plotly_chart(fig_priority, use_container_width=True)
    with col2:
        # Topic distribution
        fig_tags = px.bar(
            x=list(tag_counts.values()),
            y=list(tag_counts.keys()),
            orientation='h',
            title="🏷️ Topic Distribution",
            labels={'x': 'Count', 'y': 'Topics'}
        )
        fig_tags.update_layout(height=400)
        st.plotly_chart(fig_tags, use_container_width=True)


def _render_filtered_tickets(results):
    """Show priority/sentiment/topic filters and the ticket cards that match all of them."""
    st.subheader("πŸ“‹ All Classified Tickets")
    col1, col2, col3 = st.columns(3)
    with col1:
        priority_filter = st.selectbox("Filter by Priority",
                                       ["All"] + [p.value for p in PriorityEnum])
    with col2:
        sentiment_filter = st.selectbox("Filter by Sentiment",
                                        ["All"] + [s.value for s in SentimentEnum])
    with col3:
        topic_filter = st.selectbox("Filter by Topic",
                                    ["All"] + [t.value for t in TopicTagEnum])
    # Apply filters
    filtered_results = results
    if priority_filter != "All":
        filtered_results = [(t, c) for t, c in filtered_results if c.priority.value == priority_filter]
    if sentiment_filter != "All":
        filtered_results = [(t, c) for t, c in filtered_results if c.sentiment.value == sentiment_filter]
    if topic_filter != "All":
        filtered_results = [(t, c) for t, c in filtered_results if any(tag.value == topic_filter for tag in c.topic_tags)]
    st.info(f"Showing {len(filtered_results)} of {len(results)} tickets")
    for ticket, classification in filtered_results:
        display_ticket_card(ticket, classification)


def bulk_dashboard_page(classifier):
    """Bulk classification dashboard page.

    On first visit in a session, the sample tickets are loaded, classified
    and stored in ``st.session_state.bulk_results``. ``None`` means "not
    loaded yet"; ``[]`` means "loaded, nothing to show". The page then shows
    summary metrics, priority/topic charts and a filterable ticket list.
    """
    st.header("πŸ“Š Bulk Classification Dashboard")
    st.subheader("Auto-loaded sample tickets with AI classification")
    # Initialize session state for bulk results
    if 'bulk_results' not in st.session_state:
        st.session_state.bulk_results = None
    # Auto-load bulk results
    if st.session_state.bulk_results is None:
        st.session_state.bulk_results = _load_bulk_results(classifier)
    results = st.session_state.bulk_results
    if results:
        stats = calculate_stats(results)
        _render_bulk_metrics(stats)
        # Visualizations
        if stats['tag_counts']:
            _render_bulk_charts(results, stats['tag_counts'])
        _render_filtered_tickets(results)
    # Refresh button, deliberately outside `if results` so that a failed or
    # empty load can be retried without starting a new session.
    if st.button("πŸ”„ Refresh Classifications"):
        st.session_state.bulk_results = None
        st.rerun()
def _build_analysis_markdown(topic_tags, sentiment, priority, reasoning):
    """Return the back-end analysis as Markdown, one field per paragraph.

    Fields are joined with blank lines: in Markdown a single newline is only a
    soft line break, so one-field-per-line text would render as one run-on
    paragraph.
    """
    tag_list = ', '.join(f'`{tag}`' for tag in topic_tags)
    fields = [
        f"**🏷️ Topic Tags:** {tag_list}",
        f"**😊 Sentiment:** `{sentiment}`",
        f"**βš‘ Priority:** `{priority}`",
        f"**πŸ€– AI Reasoning:** {reasoning}",
    ]
    return "\n\n".join(fields)


def interactive_agent_page(classifier, rag_pipeline):
    """Interactive AI agent page.

    Classifies a free-text customer question, asks the RAG pipeline for a
    response using the predicted topic tags, and shows the internal analysis
    next to either the generated answer (with sources) or the routing message.
    """
    st.header("πŸ€– Interactive AI Agent")
    st.subheader("Submit a new ticket or question from any channel")
    # Input form
    with st.form("interactive_form"):
        question = st.text_area(
            "Customer Question or Ticket:",
            placeholder="Enter the customer's question or ticket description...",
            height=150
        )
        # NOTE(review): the selected channel is shown but not used by the processing below.
        st.selectbox(
            "Channel:",
            ["Web", "Email", "WhatsApp", "Voice", "Live Chat"]
        )
        submit_button = st.form_submit_button("πŸš€ Process with AI Agent")
    if not submit_button:
        return
    # Whitespace-only input would otherwise be sent to the classifier and the LLM.
    if not question.strip():
        st.warning("Please enter a question before submitting.")
        return
    with st.spinner("πŸ€– Analyzing question and generating response..."):
        try:
            # Create a dummy ticket for classification
            ticket = Ticket(id="INTERACTIVE-001", subject=question[:80], body=question)
            # Classify the ticket
            classification = run_async(classifier.classify_ticket(ticket))
            topic_tags = [tag.value for tag in classification.topic_tags]
            # Generate response using RAG pipeline
            rag_result = run_async(rag_pipeline.generate_answer(question, topic_tags))
            # Display results in two columns
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("πŸ“Š Internal Analysis (Back-end View)")
                st.markdown(_build_analysis_markdown(
                    topic_tags,
                    classification.sentiment.value,
                    classification.priority.value,
                    classification.reasoning,
                ))
            with col2:
                st.subheader("πŸ’¬ Final Response (Front-end View)")
                if rag_result['type'] == 'direct_answer':
                    st.success("πŸ’‘ Direct Answer (RAG-Generated)")
                    st.write(rag_result['answer'])
                    if rag_result.get('sources'):
                        st.subheader("πŸ“š Sources:")
                        for source in rag_result['sources']:
                            st.markdown(f"- [{source}]({source})")
                else:
                    st.warning("πŸ“‹ Ticket Routed")
                    st.write(rag_result['message'])
        except Exception as e:
            st.error(f"❌ Error processing question: {e}")
def single_ticket_page(classifier):
    """Single ticket classification page.

    Collects an id, subject and body in a form, classifies the ticket and
    renders the result with display_ticket_card. All three fields are
    required; whitespace-only values count as missing.
    """
    st.header("πŸ“ Single Ticket Classification")
    with st.form("single_ticket_form"):
        ticket_id = st.text_input("Ticket ID:", placeholder="e.g., TICKET-001")
        subject = st.text_input("Subject:", placeholder="Enter ticket subject")
        body = st.text_area("Message Body:", placeholder="Enter the full ticket message...", height=150)
        classify_button = st.form_submit_button("πŸ” Classify Ticket")
    if not classify_button:
        return
    # Tell the user why nothing happens instead of silently ignoring an incomplete form.
    if not (ticket_id.strip() and subject.strip() and body.strip()):
        st.warning("Please fill in Ticket ID, Subject and Message Body.")
        return
    with st.spinner("πŸ”„ Classifying ticket..."):
        try:
            ticket = Ticket(id=ticket_id, subject=subject, body=body)
            classification = run_async(classifier.classify_ticket(ticket))
            st.success("βœ… Classification complete!")
            display_ticket_card(ticket, classification)
        except Exception as e:
            st.error(f"❌ Error classifying ticket: {e}")
def _coerce_ticket_records(data):
    """Normalise parsed upload JSON into a list of ticket dicts.

    Accepts a JSON array of objects, or a single object (treated as one
    ticket).

    Raises:
        ValueError: for any other shape. This gives the user a clear message
            instead of an opaque ``Ticket(**...)`` TypeError.
    """
    if isinstance(data, dict):
        return [data]
    if not isinstance(data, list):
        raise ValueError("expected a JSON array of ticket objects")
    for index, item in enumerate(data):
        if not isinstance(item, dict):
            raise ValueError(f"item {index} is not a JSON object")
    return data


def upload_classify_page(classifier):
    """Upload & classify page.

    Parses an uploaded JSON file of tickets. When the user clicks the button,
    it classifies them and shows summary metrics plus one card per ticket.
    Parsing and classification errors are shown in the UI rather than raised.
    """
    st.header("πŸ“‚ Upload & Classify Tickets")
    uploaded_file = st.file_uploader("Choose a JSON file", type="json")
    if uploaded_file is None:
        return
    try:
        tickets_data = _coerce_ticket_records(json.load(uploaded_file))
        tickets = [Ticket(**ticket_data) for ticket_data in tickets_data]
    except Exception as e:
        st.error(f"❌ Error loading file: {e}")
        return
    st.info(f"πŸ“„ Loaded {len(tickets)} tickets from file")
    if not st.button("πŸš€ Classify All Tickets"):
        return
    with st.spinner("πŸ”„ Classifying tickets..."):
        try:
            classified_tickets = run_async(classify_tickets_async(classifier, tickets))
            if not classified_tickets:
                # Nothing to show: either the file had no tickets or
                # classify_tickets_async has already displayed its error.
                return
            st.success(f"βœ… Successfully classified {len(classified_tickets)} tickets!")
            # Display statistics
            stats = calculate_stats(classified_tickets)
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Total", stats['total'])
            with col2:
                st.metric("High Priority", stats['high_priority'])
            with col3:
                st.metric("Frustrated", stats['frustrated'])
            with col4:
                st.metric("RAG-Eligible", stats['rag_eligible'])
            # Display tickets
            for ticket, classification in classified_tickets:
                display_ticket_card(ticket, classification)
        except Exception as e:
            st.error(f"❌ Error during classification: {e}")
# Footer
def show_footer():
    """Render the page footer: a horizontal rule followed by a centred credits block."""
    footer_html = """
<div style="text-align: center; color: #666; padding: 1rem;">
<p>🎯 <strong>Atlan Customer Support Copilot</strong> - AI-powered ticket classification and response generation</p>
<p>Built with Streamlit β€’ Powered by Groq AI β€’ Enhanced RAG Pipeline</p>
</div>
"""
    st.markdown("---")
    st.markdown(footer_html, unsafe_allow_html=True)
# Run the app
def _run_app():
    """Render the selected page, then the shared footer."""
    main()
    show_footer()


if __name__ == "__main__":
    _run_app()