|
|
import streamlit as st |
|
|
import os |
|
|
import pandas as pd |
|
|
from pandasai import SmartDataframe |
|
|
from pandasai.responses.response_parser import ResponseParser |
|
|
from pandasai.llm import GoogleGemini |
|
|
import plotly.graph_objects as go |
|
|
from PIL import Image |
|
|
import io |
|
|
import base64 |
|
|
import requests |
|
|
import google.generativeai as genai |
|
|
from fpdf import FPDF |
|
|
import markdown2 |
|
|
import re |
|
|
from markdown_pdf import MarkdownPdf, Section |
|
|
|
|
|
|
|
|
# Endpoint of the Irisplus stock-card "generic report" export (POST target).
API_URL = "https://irisplus.elixir.co.zw/public/api/profile/reporting/stock-card/genericReports"
# Form payload selecting which stored report definition the API should run.
PAYLOAD = {
    "stock_card_report_id": "d2f1a0e1-7be1-472c-9610-94287154e544"
}
|
|
|
|
|
|
|
|
|
|
|
# --- Gemini configuration (module-level side effects) ---
# Require the API key up front; st.stop() aborts this script run so nothing
# below executes without credentials.
gemini_api_key = os.environ.get('GOOGLE_API_KEY')
if not gemini_api_key:
    st.error("GOOGLE_API_KEY environment variable not set.")
    st.stop()

genai.configure(api_key=gemini_api_key)

# Low temperature keeps the generated business report factual and repeatable.
generation_config = {
    "temperature": 0.2,
    "top_p": 0.95,
    "max_output_tokens": 5000,
}

# Model used for the long-form report in the Reports tab; chat questions go
# through PandasAI's GoogleGemini wrapper instead (see generateResponse).
model = genai.GenerativeModel(
    model_name="gemini-2.0-flash-thinking-exp",
    generation_config=generation_config,
)
|
|
|
|
|
def fetch_data():
    """Fetch the stock-card report from the reporting API.

    POSTs ``PAYLOAD`` to ``API_URL`` and returns the ``actual_report`` rows
    as a DataFrame with all-empty columns dropped.

    Returns:
        pandas.DataFrame on success, or None (after showing a Streamlit
        error) on network failure, non-200 status, invalid JSON, or an
        unexpected response shape.
    """
    try:
        # Timeout keeps the Streamlit script from hanging forever when the
        # reporting service is slow or unreachable.
        response = requests.post(API_URL, data=PAYLOAD, timeout=30)
    except requests.RequestException as exc:
        # Connection errors / timeouts previously crashed the whole script run.
        st.error(f"Error fetching data: {exc}")
        return None

    if response.status_code != 200:
        st.error(f"Error fetching data: {response.status_code} - {response.text}")
        return None

    try:
        data = response.json()
    except ValueError:
        st.error("Error: Response is not valid JSON.")
        return None

    if isinstance(data, dict) and isinstance(data.get('actual_report'), list):
        df = pd.DataFrame(data['actual_report'])
        # Drop columns that are entirely empty so downstream prompts and
        # tables stay compact.
        df.dropna(axis=1, how='all', inplace=True)
        return df

    st.error("Unexpected response format from API.")
    return None
|
|
|
|
|
def md_to_pdf(md_text, pdf):
    """Render basic Markdown (headings, **bold**, plain text) onto an FPDF page.

    Supports '# ' / '## ' / '### ' headings and inline '**bold**' spans;
    any other line is written as wrapped plain text. Deliberately minimal:
    lists, links, tables etc. are printed verbatim.

    Args:
        md_text: raw Markdown source.
        pdf: an FPDF instance with a page already added.
    """
    # BUG FIX: parse the raw Markdown directly. The previous
    # markdown2.markdown() call converted the text to HTML first
    # (e.g. '# H' -> '<h1>H</h1>'), so none of the prefix checks
    # below could ever match.
    pdf.set_font("Arial", "", 12)

    for line in md_text.split('\n'):
        line = line.strip()

        if line.startswith("# "):
            pdf.set_font("Arial", "B", 18)
            pdf.cell(0, 10, line[2:], ln=True)
        elif line.startswith("## "):
            pdf.set_font("Arial", "B", 16)
            pdf.cell(0, 10, line[3:], ln=True)
        elif line.startswith("### "):
            pdf.set_font("Arial", "B", 14)
            pdf.cell(0, 10, line[4:], ln=True)
        elif "**" in line:
            # Splitting on '**' alternates regular/bold: odd-indexed parts
            # were inside the markers.
            parts = line.split("**")
            for i, part in enumerate(parts):
                style = "B" if i % 2 == 1 else ""
                pdf.set_font("Arial", style, 12)
                pdf.cell(0, 10, part, ln=False)
            pdf.ln()
        else:
            pdf.set_font("Arial", "", 12)
            pdf.multi_cell(0, 10, line)
|
|
|
|
|
|
|
|
def generate_pdf(report_text):
    """Render the Markdown report text to PDF and return the raw PDF bytes."""
    pdf = FPDF()
    pdf.add_page()
    try:
        # Register a Unicode TTF so non-Latin characters render; fall back
        # to the built-in core font when arial.ttf isn't shipped with the app.
        pdf.add_font('Arial', '', 'arial.ttf', uni=True)
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
        # are no longer swallowed.
        st.warning("Arial font not found. Unicode might not work.")
    pdf.set_font("Arial", "", 12)
    md_to_pdf(report_text, pdf)
    # dest="S" returns the document as a string; latin-1 round-trips the
    # raw byte values into a bytes object.
    pdf_bytes = pdf.output(dest="S").encode("latin1")
    return pdf_bytes
|
|
|
|
|
|
|
|
class StreamLitResponse(ResponseParser):
    """PandasAI response parser that normalizes results into typed dicts.

    Each formatter returns ``{'type': ..., 'value': ...}`` so the chat layer
    can dispatch on 'dataframe' / 'plot' / 'text' uniformly.
    """

    def __init__(self, context):
        super().__init__(context)

    def format_dataframe(self, result):
        """Tag a DataFrame result for tabular rendering."""
        return {'type': 'dataframe', 'value': result['value']}

    def format_plot(self, result):
        """Normalize a plot result to a base64-encoded PNG string."""
        try:
            image = result['value']
            if isinstance(image, Image.Image):
                # In-memory PIL image: serialize to PNG bytes first.
                buffer = io.BytesIO()
                image.save(buffer, format="PNG")
                raw = buffer.getvalue()
            elif isinstance(image, bytes):
                raw = image
            elif isinstance(image, str) and os.path.exists(image):
                # Path to an image file written to disk by PandasAI.
                with open(image, "rb") as handle:
                    raw = handle.read()
            else:
                return {'type': 'text', 'value': "Unsupported image format"}
            encoded = base64.b64encode(raw).decode('utf-8')
            return {'type': 'plot', 'value': encoded}
        except Exception as e:
            return {'type': 'text', 'value': f"Error processing plot: {e}"}

    def format_other(self, result):
        """Fall back to a plain-text message for any other result type."""
        return {'type': 'text', 'value': str(result['value'])}
|
|
|
|
|
def generateResponse(prompt, df):
    """Answer a natural-language question about *df* via PandasAI.

    Wraps *df* in a SmartDataframe backed by Gemini, with StreamLitResponse
    normalizing whatever the agent returns.
    """
    config = {
        "llm": GoogleGemini(api_key=gemini_api_key),
        "response_parser": StreamLitResponse,
    }
    agent = SmartDataframe(df, config=config)
    return agent.chat(prompt)
|
|
|
|
|
def render_chat_message(message):
    """Display one chat-history entry, dispatching on its payload keys.

    A message may carry a 'dataframe' or a 'plot' payload; its 'content'
    text (if any) is rendered in addition to either.
    """
    if "dataframe" in message:
        st.dataframe(message["dataframe"])
    elif "plot" in message:
        try:
            plot = message["plot"]
            if isinstance(plot, str):
                # Base64 PNG string as produced by StreamLitResponse.
                st.image(f"data:image/png;base64,{plot}")
            elif isinstance(plot, Image.Image):
                st.image(plot)
            elif isinstance(plot, go.Figure):
                st.plotly_chart(plot)
            elif isinstance(plot, bytes):
                st.image(Image.open(io.BytesIO(plot)))
            else:
                st.write("Unsupported plot format")
        except Exception as e:
            st.error(f"Error rendering plot: {e}")
    if "content" in message:
        st.markdown(message["content"])
|
|
|
|
|
def _response_to_message(result):
    """Convert a generateResponse() result into a chat-history message dict."""
    if isinstance(result, dict):
        response_type = result.get('type', 'text')
        response_value = result.get('value')
        if response_type == 'dataframe':
            return {"role": "assistant", "content": "Here's the table:",
                    "dataframe": response_value}
        if response_type == 'plot':
            return {"role": "assistant", "content": "Here's the chart:",
                    "plot": response_value}
        return {"role": "assistant", "content": str(response_value)}
    # Non-dict results (plain strings, numbers) are shown as text.
    return {"role": "assistant", "content": str(result)}


def handle_userinput(question, df):
    """Record the user's question, query the data, and append the answer.

    Appends both the user message and the assistant response to
    ``st.session_state.chat_history``. Any failure is surfaced as a
    Streamlit error instead of crashing the app.
    """
    try:
        if df is not None and not df.empty:
            st.session_state.chat_history.append({
                "role": "user",
                "content": question
            })
            result = generateResponse(question, df)
            st.session_state.chat_history.append(_response_to_message(result))
        else:
            st.write("No data loaded.")
    except Exception as e:
        # Top-level UI boundary: report rather than propagate.
        st.error(f"Error processing input: {e}")
|
|
|
|
|
def main():
    """Streamlit entry point: chat tab, reports tab, and sidebar controls."""
    st.set_page_config(page_title="AI Chat with Your Data", page_icon="📊")

    # Session state survives Streamlit reruns: chat transcript + cached data.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "dfs" not in st.session_state:
        st.session_state.dfs = fetch_data()

    tab_chat, tab_reports = st.tabs(["Chat", "Reports"])

    with tab_chat:
        st.title("AI Chat with Your Data 📊")

        # First pass: render the transcript accumulated so far.
        chat_container = st.container()
        with chat_container:
            for message in st.session_state.chat_history:
                with st.chat_message(message["role"]):
                    render_chat_message(message)

        user_question = st.chat_input("Ask a question about your data:")
        if user_question:
            handle_userinput(user_question, st.session_state.dfs)

            # Clear and re-render the container so the new question/answer
            # show up in this same script run instead of the next rerun.
            chat_container.empty()
            with chat_container:
                for message in st.session_state.chat_history:
                    with st.chat_message(message["role"]):
                        render_chat_message(message)

    with tab_reports:
        st.title("Reports")
        st.write("Filter by product to generate a report")
        # NOTE(review): fetches fresh data rather than reusing
        # st.session_state.dfs, so every rerun issues an extra API call.
        df_report = fetch_data()
        if df_report is not None and not df_report.empty:
            product_names = df_report["product"].unique().tolist() if "product" in df_report.columns else []
            selected_products = st.multiselect("Select Product(s)", product_names, default=product_names)
            if st.button("Apply Filters and Generate Report"):
                filtered_df = df_report.copy()
                if selected_products:
                    filtered_df = filtered_df[filtered_df["product"].isin(selected_products)]

                st.write("Filtered DataFrame Preview:")
                with st.expander("Preview"):
                    st.dataframe(filtered_df.head())

                with st.spinner("Generating Report, Please Wait...."):
                    # The whole filtered table is inlined into the prompt as
                    # Markdown; very large selections may exceed the model's
                    # context window.
                    prompt = f"""
                    You are an expert business analyst. Analyze the following data and generate a comprehensive and insightful business report including key performance indicators and recommendations.\n\nData:\n{filtered_df.to_markdown(index=False)}
                    """
                    response = model.generate_content(prompt)
                    report = response.text

                    try:
                        st.markdown(report)
                    except Exception as e:
                        st.error(f"Error generating report {e}")
        else:
            st.error("No data available for reports.")

    with st.sidebar:
        st.subheader("Options")
        if st.button("Reload Data"):
            with st.spinner("Fetching latest data..."):
                st.session_state.dfs = fetch_data()
            st.success("Data refreshed!")
        if st.button("Clear Chat"):
            st.session_state.chat_history = []
            # NOTE(review): st.experimental_rerun is deprecated/removed in
            # newer Streamlit releases (use st.rerun) - confirm the installed
            # version before upgrading.
            st.experimental_rerun()
|
|
|
|
|
# Script entry point when launched via `streamlit run`.
if __name__ == "__main__":
    main()
|
|
|