# DrinksAdmin / app.py
# Author: rairo — last change "Update app.py" (commit f07fd52, verified)
import base64
import io
import json
import os
import re
import tempfile

import google.generativeai as genai
import markdown2
import pandas as pd
import pandasai as PandasAI
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from langchain_google_genai import ChatGoogleGenerativeAI
from markdown_pdf import MarkdownPdf, Section
from pandasai import SmartDatalake, SmartDataframe
from pandasai.llm import GoogleGemini
from pandasai.responses.response_parser import ResponseParser
from PIL import Image
#from fpdf import FPDF
# --- Gemini configuration ---------------------------------------------------
# The API key comes from the environment. Without it nothing below can talk
# to Gemini, so an error is shown and the Streamlit run is stopped.
gemini_api_key = os.getenv('GOOGLE_API_KEY')
if not gemini_api_key:
    st.error("GOOGLE_API_KEY environment variable not set.")
    st.stop()
genai.configure(api_key=gemini_api_key)

# Sampling settings for `model`, which writes the reports in the Reports tab.
generation_config = dict(
    temperature=0.2,
    top_p=0.95,
    max_output_tokens=5000,
)
model = genai.GenerativeModel(
    model_name="gemini-2.0-flash-thinking-exp",
    generation_config=generation_config,
)

# LangChain chat model handed to PandasAI as its LLM (used by generateResponse).
llm1 = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-thinking-exp",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)
def load_data():
    """Read the three source CSV files and return them keyed by role.

    Returns:
        A dict with 'events', 'customers' and 'products' DataFrames, or None
        when a file cannot be read or any resulting frame has no rows. In both
        failure cases an error is shown in the Streamlit UI.
    """
    sources = (
        ('events', "Delta-Events.csv"),
        ('customers', "delta_customers.csv"),
        ('products', "Customer_Products.csv"),
    )
    try:
        frames = {key: pd.read_csv(path) for key, path in sources}
        # Validate data: every file must contribute at least one row.
        if any(frame.empty for frame in frames.values()):
            st.error("One or more data files are empty.")
            return None
        return frames
    except Exception as e:
        st.error(f"Error loading data: {e}")
        return None
# Dashboard
def create_dashboard(data):
    """Draw the Business Insights dashboard charts.

    Args:
        data: dict holding the 'events', 'customers' and 'products'
            DataFrames (as returned by load_data()).
    """
    st.header("Business Insights Dashboard")

    events = data['events']
    # The two source files spell the suburb column differently ('Surbub' vs
    # 'Surburb') and name the order value differently. Both are renamed to a
    # common 'Suburb' / 'Order Value $' schema before stacking.
    event_orders = events[['Surbub', 'Order Value $']].rename(columns={'Surbub': 'Suburb'})
    customer_orders = data['customers'][['Surburb', 'Order_Value']].rename(
        columns={'Surburb': 'Suburb', 'Order_Value': 'Order Value $'}
    )
    merged_orders = pd.concat([event_orders, customer_orders])

    with st.container():
        left_col, right_col = st.columns(2)

        with left_col:
            # Total order value per suburb across both sources.
            by_suburb = merged_orders.groupby('Suburb')['Order Value $'].sum().reset_index()
            st.plotly_chart(
                px.bar(by_suburb, x='Suburb', y='Order Value $',
                       title='Total Order Value by Suburb'),
                use_container_width=True,
            )

        with right_col:
            # Event order value per event type. reset_index() already names
            # the columns 'Event' and 'Order Value $'.
            value_by_event = events.groupby('Event')['Order Value $'].sum().reset_index()
            st.plotly_chart(
                px.pie(value_by_event, names='Event', values='Order Value $',
                       title='Event Type Distribution By Order Value'),
                use_container_width=True,
            )

    # Top products by total quantity.
    with st.container():
        st.subheader("Product Performance")
        top_products = (
            data['products']
            .groupby('Product')['Quantity']
            .sum()
            .nlargest(10)
            .reset_index()
        )
        st.plotly_chart(
            px.bar(top_products, x='Product', y='Quantity',
                   title='Top 10 Products by Quantity Sold'),
            use_container_width=True,
        )
# --- Chat Tab Functions ---
class StreamLitResponse(ResponseParser):
    """PandasAI response parser that tags results with a display type.

    The three overridden hooks each return ``{'type': ..., 'value': ...}``
    where type is 'dataframe', 'plot' or 'text'. The chat code can then store
    and re-render results as plain dicts.
    """

    def __init__(self, context):
        super().__init__(context)

    def format_dataframe(self, result):
        """Tag a DataFrame result; the frame itself is passed through as-is."""
        return dict(type='dataframe', value=result['value'])

    def format_plot(self, result):
        """Convert a plot result to a base64 string tagged as 'plot'.

        Accepts a PIL image (re-encoded as PNG), raw bytes, or a path to an
        existing file. Any other value, or any failure while converting,
        yields a 'text' entry describing the problem.
        """
        try:
            source = result['value']
            if isinstance(source, Image.Image):
                buffer = io.BytesIO()
                source.save(buffer, format="PNG")
                raw_bytes = buffer.getvalue()
            elif isinstance(source, bytes):
                raw_bytes = source
            elif isinstance(source, str) and os.path.exists(source):
                with open(source, "rb") as handle:
                    raw_bytes = handle.read()
            else:
                return dict(type='text', value="Unsupported image format")
            # Store base64 text so the chat history holds a plain string.
            encoded = base64.b64encode(raw_bytes).decode('utf-8')
            return dict(type='plot', value=encoded)
        except Exception as e:
            return dict(type='text', value=f"Error processing plot: {e}")

    def format_other(self, result):
        """Fallback: present any other result as its string form."""
        return dict(type='text', value=str(result['value']))
def generateResponse(prompt, data):
    """Answer *prompt* over every DataFrame in *data* with a PandasAI SmartDatalake.

    Args:
        prompt: the user's natural-language question.
        data: mapping of name -> DataFrame (e.g. the dict from load_data()).

    Returns:
        The value of ``SmartDatalake.chat``, or None when *data* is not a dict
        of DataFrames (an error is shown in that case). The chat value may be
        a {'type', 'value'} dict built by StreamLitResponse or another object;
        handle_userinput deals with both.
    """
    is_valid = isinstance(data, dict) and all(
        isinstance(frame, pd.DataFrame) for frame in data.values()
    )
    if not is_valid:
        st.error("Invalid data format. Expected a dictionary of DataFrames.")
        return None

    # PandasAI receives the response parser class and builds it itself.
    agent_config = {"llm": llm1, "response_parser": StreamLitResponse}
    lake = SmartDatalake(list(data.values()), config=agent_config)
    return lake.chat(prompt)
def render_chat_message(message):
    """Render one chat-history entry inside the current chat bubble.

    Args:
        message: dict with a "role" and optionally:
            - "content": markdown text,
            - "dataframe": shown with st.dataframe,
            - "plot": a base64 image string (as stored by handle_userinput from
              StreamLitResponse.format_plot), a PIL image, a Plotly
              ``go.Figure`` or raw image bytes.

    Rendering errors for plots are reported with st.error, not raised.
    """
    # Text first: assistant captions such as "Here's the table:" introduce
    # the attachment that follows them.
    if "content" in message:
        st.markdown(message["content"])

    if "dataframe" in message:
        st.dataframe(message["dataframe"])
    elif "plot" in message:
        try:
            plot_data = message["plot"]
            if isinstance(plot_data, str):
                st.image(f"data:image/png;base64,{plot_data}")
            elif isinstance(plot_data, Image.Image):
                st.image(plot_data)
            elif isinstance(plot_data, go.Figure):
                # Needs `plotly.graph_objects` imported as `go` at module level;
                # without it this check raised NameError and hid the bytes branch.
                st.plotly_chart(plot_data)
            elif isinstance(plot_data, bytes):
                st.image(Image.open(io.BytesIO(plot_data)))
            else:
                st.write("Unsupported plot format")
        except Exception as e:
            st.error(f"Error rendering plot: {e}")
def handle_userinput(question, data):
    """Record *question* in the chat history and append the assistant's answer.

    The answer comes from generateResponse(). A {'type', 'value'} dict becomes
    a history entry with a 'dataframe' or 'plot' attachment when its type says
    so. Anything else is stored as plain text. Errors are reported with
    st.error rather than raised.

    Args:
        question: the user's natural-language question.
        data: dict of DataFrames (from load_data()) or None.
    """
    try:
        has_usable_data = data and all(not df.empty for df in data.values())
        if not has_usable_data:
            st.error("No valid data available for analysis.")
            return

        st.session_state.chat_history.append({"role": "user", "content": question})
        result = generateResponse(question, data)

        if not isinstance(result, dict):
            reply = {"role": "assistant", "content": str(result)}
        else:
            kind = result.get('type', 'text')
            value = result.get('value')
            if kind == 'dataframe':
                reply = {"role": "assistant", "content": "Here's the table:", "dataframe": value}
            elif kind == 'plot':
                reply = {"role": "assistant", "content": "Here's the chart:", "plot": value}
            else:
                reply = {"role": "assistant", "content": str(value)}
        st.session_state.chat_history.append(reply)
    except Exception as e:
        st.error(f"Error processing input: {e}")
def _render_sidebar():
    """Render the sidebar's data-management buttons."""
    with st.sidebar:
        st.header("Data Management")
        if st.button("Reload Data"):
            st.session_state.data = load_data()
        if st.button("Clear Chat"):
            st.session_state.chat_history = []


def _render_chat_tab():
    """Render the chat history and answer a newly submitted question."""
    st.title("AI Data Analyst")
    # Reserve the history area above the input box, but fill it only after a
    # new question has been answered so each message is drawn exactly once.
    # Calling .empty() on an st.container() inserts a new placeholder rather
    # than clearing the container, so render-then-re-render duplicates history.
    chat_container = st.container()
    user_question = st.chat_input("Ask a question about your data:")
    if user_question:
        handle_userinput(user_question, st.session_state.data)
    with chat_container:
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                render_chat_message(message)


def _build_report_prompt(filtered_data):
    """Return the Gemini prompt asking for a markdown report on *filtered_data*.

    Each DataFrame is serialised with ``to_json(orient='records')``. The
    resulting JSON strings are then embedded via ``json.dumps``.
    """
    json_data = {k: v.to_json(orient='records') for k, v in filtered_data.items()}
    return "\n".join([
        "",
        "Analyze this business data and generate a comprehensive report in plain text format. Use markdown for headings and structure. Do not include any json.",
        "Data:",
        json.dumps(json_data, indent=2),
        "No introductory quips or salutations or follow up questions, just write the report.",
        "",
    ])


def _build_report_pdf(report):
    """Render the markdown *report* to a PDF and return the file's bytes."""
    pdf = MarkdownPdf()
    pdf.meta["title"] = 'Suburb Business Report'
    pdf.add_section(Section(report, toc=False))
    # MarkdownPdf writes to a path. A private temporary directory is removed
    # automatically afterwards, and the file is not held open while
    # MarkdownPdf writes it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        pdf_path = os.path.join(tmp_dir, "report.pdf")
        pdf.save(pdf_path)
        with open(pdf_path, "rb") as f:
            return f.read()


def _render_reports_tab():
    """Render the suburb filter and generate an AI-written report on demand."""
    st.title("Custom Reports")
    data = st.session_state.data
    if not data:
        st.error("No data available for reports")
        return

    events = data['events']
    customers = data['customers']
    # Suburb columns are spelled differently in the two source files.
    suburbs = pd.concat([events['Surbub'], customers['Surburb']]).unique()
    selected_suburbs = st.multiselect("Select Suburbs", suburbs)

    if not st.button("Generate Report"):
        return

    with st.spinner("Analyzing data..."):
        # An empty selection means "all suburbs"; products are never filtered.
        if selected_suburbs:
            events = events[events['Surbub'].isin(selected_suburbs)]
            customers = customers[customers['Surburb'].isin(selected_suburbs)]
        filtered_data = {
            'events': events,
            'customers': customers,
            'products': data['products'],
        }

        try:
            report = model.generate_content(_build_report_prompt(filtered_data)).text
        except Exception as e:
            # Show Gemini failures in the tab instead of aborting the script run.
            st.error(f"Error generating report: {e}")
            return
        # NOTE(review): model output is rendered as raw HTML below.
        html_text = markdown2.markdown(report)

        try:
            pdf_bytes = _build_report_pdf(report)
            st.download_button(
                label="Download Report as PDF",
                data=pdf_bytes,
                file_name="report.pdf",
                mime="application/pdf",
            )
            # Display the report below the download button.
            st.write(html_text, unsafe_allow_html=True)
        except Exception as e:
            st.error(f"Error generating PDF: {e}")
            st.write(html_text, unsafe_allow_html=True)


def main():
    """Streamlit entry point: configure the page, initialise state, draw the UI."""
    st.set_page_config(page_title="Business Analytics Suite", page_icon="📊", layout="wide")

    # Initialize session state
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "data" not in st.session_state:
        st.session_state.data = load_data()

    # Sidebar first: its buttons mutate session state. Handling them before
    # the tabs render lets "Reload Data" / "Clear Chat" take effect in the same
    # run instead of one rerun late.
    _render_sidebar()

    tab_dashboard, tab_chat, tab_reports = st.tabs(["📊 Dashboard", "💬 Chat", "📈 Reports"])

    with tab_dashboard:
        if st.session_state.data:
            create_dashboard(st.session_state.data)
        else:
            st.error("Failed to load data for dashboard")

    with tab_chat:
        _render_chat_tab()

    with tab_reports:
        _render_reports_tab()
# Start the app when this file is executed as the main module.
if __name__ == "__main__":
    main()