# ZBB_AI / app.py
# Author: ABInBev
# Last commit: "Update app.py" (3639586, verified)
import gradio as gr
import requests
import uuid
import time
import urllib3
import os
import pandas as pd
import base64
from dotenv import load_dotenv
# Read settings such as BASE_URL / BEARER_TOKEN from a local .env file, if present.
load_dotenv()

# Backend requests in this app are sent with verify=False (corporate TLS
# interception), so silence the InsecureRequestWarning urllib3 would emit
# for each one.
# NOTE(review): unverified TLS exposes BEARER_TOKEN to interception; prefer a
# corporate CA bundle (requests `verify=<path>`) if one is available.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

# --- CONFIGURATION ---
# Root URL of the backend; endpoint paths are appended to it.
BASE_URL = os.getenv("BASE_URL")
# Bearer token for the backend (the header reads "Bearer None" when unset).
BEARER_TOKEN = os.getenv("BEARER_TOKEN")

# Headers attached to every backend request (Authorization first, then Content-Type).
HEADERS = dict(Authorization=f"Bearer {BEARER_TOKEN}")
HEADERS["Content-Type"] = "application/json"
def process_zbb_query(user_message, history, session_id, last_query, raw_sql):
    """Send a chat query to the backend agent and stream progress to the UI.

    Gradio generator handler. Every yielded tuple maps onto the outputs
    ``[chatbot, session_id, source_df, last_query_state, raw_sql_state]``.

    Args:
        user_message: Text the user just submitted.
        history: Chat history in messages format, i.e. a list of
            ``{"role": ..., "content": ...}`` dicts. ``None`` is treated as empty.
        session_id: Current session id; a new ``sess_<6 hex chars>`` id is
            generated when it is falsy.
        last_query: Previous query; replaced by ``user_message``.
        raw_sql: Rows from the previous answer; reset to ``[]`` for this query.

    Yields:
        ``(history, session_id, dataframe_update, last_query, raw_sql)``. The
        final yield carries the agent's reply (or an error / timeout message)
        in the last history entry and, when rows came back, a visible DataFrame.
    """
    request_timeout = 30  # seconds per HTTP call, so a stalled backend cannot hang the UI
    max_polls = 200  # roughly 200 s at one poll per second
    # NOTE(review): only "SUCCESS" is confirmed by this code; these failure
    # state names are assumptions -- verify against the backend's /status contract.
    failed_states = {"FAILURE", "FAILED", "ERROR"}

    # 1. Update last-query state and make sure a session id exists.
    last_query = user_message
    if not session_id:
        session_id = f"sess_{uuid.uuid4().hex[:6]}"

    # 2. Append the user turn plus a 'thinking' placeholder; later yields
    #    overwrite the placeholder in place (messages format: role/content dicts).
    if history is None:
        history = []
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": "⏳ **Agent is thinking...**"})

    # New query: forget the rows of the previous answer.
    raw_sql = []

    yield history, session_id, gr.update(visible=False, value=None), last_query, raw_sql

    try:
        # Everything except the placeholder -- this includes the current user turn.
        backend_history = history[:-1]

        payload = {
            "inputs": {
                "session_id": session_id,
                "current_query": user_message,
                "conversation_history": backend_history,
            }
        }
        post_resp = requests.post(
            f"{BASE_URL}/kickoff",
            json=payload,
            headers=HEADERS,
            verify=False,
            timeout=request_timeout,
        )
        post_resp.raise_for_status()
        kickoff_id = post_resp.json().get("kickoff_id")
        if not kickoff_id:
            # Without an id we would poll "/status/None" until the timeout.
            raise ValueError("backend response did not include a kickoff_id")

        # Poll once per second until the task finishes or we give up.
        status_url = f"{BASE_URL}/status/{kickoff_id}"
        completed = False
        result_data = None
        for attempt in range(1, max_polls + 1):
            get_resp = requests.get(
                status_url, headers=HEADERS, verify=False, timeout=request_timeout
            )
            get_resp.raise_for_status()
            status_data = get_resp.json()
            state = status_data.get("state")
            if state == "SUCCESS":
                completed = True
                result_data = status_data.get("result") or {}
                break
            if state in failed_states:
                raise RuntimeError(f"backend task ended with state {state}")

            history[-1]["content"] = f"⏳ **Agent is thinking...** ({attempt}s)"
            yield history, session_id, gr.update(visible=False), last_query, raw_sql
            time.sleep(1)

        if not completed:
            history[-1]["content"] = "⚠️ Request timed out. Please try again."
            yield history, session_id, gr.update(visible=False), last_query, raw_sql
            return

        # Replace the placeholder with the agent's answer.
        response_text = result_data.get("assistant_message") or "No message returned."
        sql_data = result_data.get("sql_result", [])
        history[-1]["content"] = response_text

        if isinstance(sql_data, list) and sql_data:
            raw_sql = sql_data
            df = pd.DataFrame(sql_data)
            # Strip stray single quotes from string cells for display only;
            # raw_sql keeps the values exactly as the backend sent them.
            df = df.map(lambda x: x.strip("'") if isinstance(x, str) else x)
            yield history, session_id, gr.update(visible=True, value=df), last_query, raw_sql
        else:
            yield history, session_id, gr.update(visible=False, value=None), last_query, raw_sql

    except Exception as e:
        # UI boundary: surface any failure (network, HTTP, bad payload) in the chat.
        history[-1]["content"] = f"❌ Error: {str(e)}"
        yield history, session_id, gr.update(visible=False), last_query, raw_sql
def generate_visualization(session_id, last_query, raw_sql):
    """Ask the backend to render a chart for the rows of the last query.

    Gradio generator handler for the visualization button; every yielded
    value is an HTML string for the ``viz_output`` component.

    Args:
        session_id: Session id shared with the chat handler.
        last_query: Most recent user query, sent as ``current_query``.
        raw_sql: Rows returned by the last chat query. Nothing is sent to the
            backend when this is empty or ``None``.

    Yields:
        A progress message, then either an ``<iframe>`` embedding the returned
        HTML as a base64 ``data:`` URL, or a warning / error message.
    """
    request_timeout = 30  # seconds per HTTP call, so a stalled backend cannot hang the UI
    max_polls = 200  # roughly 200 s at one poll per second
    # NOTE(review): only "SUCCESS" is confirmed by this code; these failure
    # state names are assumptions -- verify against the backend's /status contract.
    failed_states = {"FAILURE", "FAILED", "ERROR"}

    if not raw_sql:
        # This is a generator: a bare `return <value>` would be swallowed
        # (it only sets StopIteration.value), so the message must be yielded.
        yield "<h3>⚠️ No data available to visualize. Please run a query first.</h3>"
        return

    yield "<h3>⏳ Generating Visualization...</h3>"
    try:
        payload = {
            "inputs": {
                "session_id": session_id,
                "current_query": last_query,
                "conversation_history": [],
                "router_flag": "viz",
                "sql_result": raw_sql,
            }
        }
        post_resp = requests.post(
            f"{BASE_URL}/kickoff",
            json=payload,
            headers=HEADERS,
            verify=False,
            timeout=request_timeout,
        )
        post_resp.raise_for_status()
        kickoff_id = post_resp.json().get("kickoff_id")
        if not kickoff_id:
            # Without an id we would poll "/status/None" until the timeout.
            raise ValueError("backend response did not include a kickoff_id")

        status_url = f"{BASE_URL}/status/{kickoff_id}"
        for _ in range(max_polls):
            get_resp = requests.get(
                status_url, headers=HEADERS, verify=False, timeout=request_timeout
            )
            get_resp.raise_for_status()
            status_data = get_resp.json()
            state = status_data.get("state")

            if state == "SUCCESS":
                # A missing or null result is treated as "no content".
                result = status_data.get("result") or {}
                raw_html = result.get("final_response", "")
                if not raw_html:
                    yield "<div>No visualization content returned</div>"
                    return
                try:
                    # Embed the backend's HTML as a base64 data: URL so it
                    # renders as its own document inside the iframe.
                    html_b64 = base64.b64encode(raw_html.encode("utf-8")).decode("utf-8")
                except Exception as encode_error:
                    # e.g. final_response was not a string
                    yield f"<div>Error encoding visualization: {str(encode_error)}</div>"
                    return
                yield (
                    f'<iframe src="data:text/html;base64,{html_b64}" '
                    'style="width: 100%; height: 600px; border: none;" '
                    'scrolling="yes"></iframe>'
                )
                return

            if state in failed_states:
                raise RuntimeError(f"backend task ended with state {state}")
            time.sleep(1)

        yield "<h3>⚠️ Visualization timed out.</h3>"
    except Exception as e:
        # UI boundary: show any failure (network, HTTP, bad payload) in the panel.
        yield f"<h3>❌ Error generating visualization: {str(e)}</h3>"
# --- UI Setup ---
with gr.Blocks() as demo:
    gr.Markdown("# 💰 ZBB GenAI Analysis")

    # Per-session state shared by the chat and visualization handlers.
    session_id = gr.State("")
    last_query_state = gr.State("")
    raw_sql_state = gr.State([])

    # Created up front with render=False so the chat handler can target it as
    # an output; it is placed inside the "Data Sources" accordion below.
    source_df = gr.DataFrame(
        label="Query Output",
        interactive=False,
        visible=False,
        wrap=True,
        render=False,
    )

    with gr.Row():
        with gr.Column(scale=3):
            # No `type=` argument: process_zbb_query produces messages-format
            # history (role/content dicts).
            # NOTE(review): Gradio 6 Chatbot is messages-only, but Gradio 5.x
            # falls back to the deprecated "tuples" format when `type` is
            # omitted -- pin gradio>=6 or pass type="messages" on 5.x.
            chatbot = gr.Chatbot(label="ZBB Assistant", height=500)
            msg = gr.Textbox(
                placeholder="Type your query here...",
                label="Command",
            )

            # --- SUGGESTIONS SECTION ---
            gr.Markdown("### 💡 Suggestions")
            examples = gr.Examples(
                examples=[
                    ["How is Overhead tracking vs Budget YTD for NAZ by Function?"],
                    ["Show me NAZ overhead performance"],
                    ["What is the travel budget for SAZ?"],
                    ["Compare Marketing actuals vs budget for M1 2025"],
                    ["Drill down into IT expenses for APAC"],
                ],
                inputs=[msg],
                label="Click to fill query",
            )

        with gr.Column(scale=2):
            with gr.Accordion("📊 Data Sources ", open=False) as data_accordion:
                source_df.render()
            gr.Markdown("### Visualization")
            viz_btn = gr.Button("📊 Get Visualization", variant="primary")
            viz_output = gr.HTML(label="Chart", min_height=500)

    # 1. Chat submission: stream the agent's reply, then clear the textbox.
    submit_event = msg.submit(
        process_zbb_query,
        inputs=[msg, chatbot, session_id, last_query_state, raw_sql_state],
        outputs=[chatbot, session_id, source_df, last_query_state, raw_sql_state],
    ).then(lambda: "", None, [msg])

    # 2. Visualization: render a chart from the rows of the last query.
    viz_btn.click(
        generate_visualization,
        inputs=[session_id, last_query_state, raw_sql_state],
        outputs=[viz_output],
    )
if __name__ == "__main__":
    # Login credentials for Gradio's built-in auth page come from the
    # environment (the module calls load_dotenv() at import time).
    auth_user = os.getenv("ID")
    auth_pass = os.getenv("PASS")
    if not auth_user or not auth_pass:
        # auth=(None, None) would start the app with a login nobody can pass;
        # fail loudly at startup instead.
        raise SystemExit(
            "Set the ID and PASS environment variables (e.g. in .env) before launching."
        )

    demo.launch(
        auth=(auth_user, auth_pass),
        auth_message="Please enter your ABInBev credentials to access the ZBB GenAI Tool.",
        theme=gr.themes.Soft(),
    )