|
|
import gradio as gr |
|
|
import requests |
|
|
import uuid |
|
|
import time |
|
|
import urllib3 |
|
|
import os |
|
|
import pandas as pd |
|
|
import base64 |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
# Pull variables from a local .env file (when one exists) into the process environment.
load_dotenv()

# Every backend call in this module passes verify=False, so silence urllib3's
# per-request InsecureRequestWarning.
# NOTE(review): disabling TLS verification is a security risk; prefer pointing
# requests at the correct CA bundle instead.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Backend endpoint and bearer token; each is None when absent from the environment.
BASE_URL = os.environ.get("BASE_URL")
BEARER_TOKEN = os.environ.get("BEARER_TOKEN")

# Headers shared by every backend request (insertion order kept: Authorization first).
HEADERS = dict(
    [
        ("Authorization", "Bearer " + str(BEARER_TOKEN)),
        ("Content-Type", "application/json"),
    ]
)
|
|
|
|
|
def process_zbb_query(user_message, history, session_id, last_query, raw_sql):
    """
    Main chat handler: send the user's question to the backend agent and stream progress.

    Format: List of Dictionaries [{"role": "user", "content": ...}] (chat "messages" history).

    Flow:
      1. Append the user turn plus an assistant "thinking" placeholder and yield at once.
      2. POST ``{BASE_URL}/kickoff`` with the session id, query and prior history.
      3. Poll ``{BASE_URL}/status/<kickoff_id>`` roughly once per second (200 polls max),
         refreshing the placeholder with a counter on each poll.
      4. On ``state == "SUCCESS"`` show ``assistant_message`` and, when ``sql_result``
         is a non-empty list, the rows as a table.

    Args:
        user_message: Text submitted from the textbox.
        history: Chat history in messages format; ``None`` is treated as empty.
        session_id: Current session id; a new ``sess_xxxxxx`` id is minted when falsy.
        last_query: Previous query; replaced by ``user_message``.
        raw_sql: Previous SQL rows; reset to ``[]`` for every new query.

    Yields:
        ``(history, session_id, dataframe_update, last_query, raw_sql)`` tuples.
        Any exception is reported in the chat as an error message rather than raised.
    """
    max_polls = 200          # about 200 s at one poll per second
    poll_interval_s = 1
    request_timeout_s = 30   # per-HTTP-call cap so a hung backend cannot block forever
    # NOTE(review): the failure state names are an assumption; confirm against the
    # backend's /status API. Any other non-SUCCESS state keeps polling, as before.
    failure_states = {"FAILURE", "FAILED", "ERROR"}

    # Ignore blank submissions instead of sending an empty query to the backend.
    if not user_message or not user_message.strip():
        yield history or [], session_id, gr.update(), last_query, raw_sql
        return

    last_query = user_message
    if not session_id:
        session_id = f"sess_{uuid.uuid4().hex[:6]}"

    # Work on a copy so the caller's list is never mutated in place.
    history = list(history or [])
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": "β³ **Agent is thinking...**"})

    # Rows from any previous query no longer apply.
    raw_sql = []
    yield history, session_id, gr.update(visible=False, value=None), last_query, raw_sql

    try:
        # Everything except the "thinking" placeholder (so it includes the new user turn).
        backend_history = history[:-1]
        payload = {
            "inputs": {
                "session_id": session_id,
                "current_query": user_message,
                "conversation_history": backend_history,
            }
        }

        post_resp = requests.post(
            f"{BASE_URL}/kickoff",
            json=payload,
            headers=HEADERS,
            verify=False,
            timeout=request_timeout_s,
        )
        post_resp.raise_for_status()
        kickoff_id = post_resp.json().get("kickoff_id")
        if not kickoff_id:
            # Without an id the status URL would be ".../status/None"; fail clearly instead.
            raise ValueError("Backend did not return a kickoff_id.")

        status_url = f"{BASE_URL}/status/{kickoff_id}"
        succeeded = False
        result_data = None

        for attempt in range(max_polls):
            get_resp = requests.get(
                status_url, headers=HEADERS, verify=False, timeout=request_timeout_s
            )
            get_resp.raise_for_status()
            status_data = get_resp.json()
            state = status_data.get("state")

            if state == "SUCCESS":
                succeeded = True
                result_data = status_data.get("result")
                break
            if state in failure_states:
                raise RuntimeError(f"Backend job failed (state={state}).")

            history[-1]["content"] = f"β³ **Agent is thinking...** ({attempt + 1}s)"
            yield history, session_id, gr.update(visible=False), last_query, raw_sql
            time.sleep(poll_interval_s)

        if not succeeded:
            history[-1]["content"] = "β οΈ Request timed out. Please try again."
            yield history, session_id, gr.update(visible=False), last_query, raw_sql
            return

        # A successful job with no result payload is not a timeout; treat it as empty.
        if result_data is None:
            result_data = {}
        if not isinstance(result_data, dict):
            raise TypeError(
                f"Unexpected result payload type: {type(result_data).__name__}"
            )

        history[-1]["content"] = result_data.get("assistant_message", "No message returned.")
        sql_data = result_data.get("sql_result", [])

        if isinstance(sql_data, list) and sql_data:
            raw_sql = sql_data
            df = pd.DataFrame(sql_data)
            # Drop stray single quotes the backend leaves around string cells.
            df = df.map(lambda x: x.strip("'") if isinstance(x, str) else x)
            yield history, session_id, gr.update(visible=True, value=df), last_query, raw_sql
        else:
            yield history, session_id, gr.update(visible=False, value=None), last_query, raw_sql

    except Exception as e:
        # Top-level UI boundary: surface the failure in the chat instead of crashing the handler.
        history[-1]["content"] = f"β Error: {str(e)}"
        yield history, session_id, gr.update(visible=False), last_query, raw_sql
|
|
|
|
|
|
|
|
def generate_visualization(session_id, last_query, raw_sql):
    """
    Handler for the Visualization button: ask the backend to chart the last query's rows.

    This function is a generator: each yielded string is streamed into the HTML output.
    Every user-visible message must therefore be *yielded*. A ``return <value>`` inside
    a generator only sets ``StopIteration.value`` and is never displayed.

    Args:
        session_id: Session id of the chat that produced ``raw_sql``.
        last_query: The user query whose results are being visualized.
        raw_sql: Rows returned by the last chat query; falsy means nothing to plot.

    Yields:
        HTML strings: a progress message, then either an ``<iframe>`` embedding the
        backend's HTML (base64 ``data:`` URL) or a warning/error message.
    """
    max_polls = 200          # about 200 s at one poll per second
    poll_interval_s = 1
    request_timeout_s = 30   # per-HTTP-call cap so a hung backend cannot block forever
    # NOTE(review): the failure state names are an assumption; confirm against the
    # backend's /status API. Any other non-SUCCESS state keeps polling, as before.
    failure_states = {"FAILURE", "FAILED", "ERROR"}

    if not raw_sql:
        # Must be yielded (not returned) so the warning actually reaches the UI.
        yield "<h3>β οΈ No data available to visualize. Please run a query first.</h3>"
        return

    yield "<h3>β³ Generating Visualization...</h3>"

    try:
        payload = {
            "inputs": {
                "session_id": session_id,
                "current_query": last_query,
                "conversation_history": [],
                "router_flag": "viz",
                "sql_result": raw_sql,
            }
        }

        post_resp = requests.post(
            f"{BASE_URL}/kickoff",
            json=payload,
            headers=HEADERS,
            verify=False,
            timeout=request_timeout_s,
        )
        post_resp.raise_for_status()
        kickoff_id = post_resp.json().get("kickoff_id")
        if not kickoff_id:
            # Without an id the status URL would be ".../status/None"; fail clearly instead.
            raise ValueError("Backend did not return a kickoff_id.")

        status_url = f"{BASE_URL}/status/{kickoff_id}"

        for _ in range(max_polls):
            get_resp = requests.get(
                status_url, headers=HEADERS, verify=False, timeout=request_timeout_s
            )
            get_resp.raise_for_status()
            status_data = get_resp.json()
            state = status_data.get("state")

            if state == "SUCCESS":
                # `or {}` also covers an explicit null "result" in the response.
                result = status_data.get("result") or {}
                raw_html = result.get("final_response", "")

                if not raw_html:
                    yield "<div>No visualization content returned</div>"
                    return

                try:
                    html_b64 = base64.b64encode(raw_html.encode("utf-8")).decode("utf-8")
                except Exception as encode_error:
                    yield f"<div>Error encoding visualization: {str(encode_error)}</div>"
                else:
                    # A base64 data: URL keeps the returned markup out of the host page's DOM.
                    iframe_html = f"""
                <iframe
                    src="data:text/html;base64,{html_b64}"
                    style="width: 100%; height: 600px; border: none;"
                    scrolling="yes">
                </iframe>
                """
                    yield iframe_html
                return

            if state in failure_states:
                raise RuntimeError(f"Backend job failed (state={state}).")

            time.sleep(poll_interval_s)

        yield "<h3>β οΈ Visualization timed out.</h3>"

    except Exception as e:
        # Top-level UI boundary: show the failure in the chart area instead of crashing.
        yield f"<h3>β Error generating visualization: {str(e)}</h3>"
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# π° ZBB GenAI Analysis")

    # Per-session state shared by the chat and visualization handlers.
    session_id = gr.State("")
    last_query_state = gr.State("")
    raw_sql_state = gr.State([])

    # Built up front but placed later inside the "Data Sources" accordion:
    # render=False defers layout until .render() is called.
    source_df = gr.DataFrame(
        label="Query Output",
        interactive=False,
        visible=False,
        wrap=True,
        render=False,
    )

    with gr.Row():
        # Left column: chat transcript, query box and clickable example prompts.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="ZBB Assistant", height=500)
            msg = gr.Textbox(placeholder="Type your query here...", label="Command")

            gr.Markdown("### π‘ Suggestions")
            examples = gr.Examples(
                examples=[
                    [prompt]
                    for prompt in (
                        "How is Overhead tracking vs Budget YTD for NAZ by Function?",
                        "Show me NAZ overhead performance",
                        "What is the travel budget for SAZ?",
                        "Compare Marketing actuals vs budget for M1 2025",
                        "Drill down into IT expenses for APAC",
                    )
                ],
                inputs=[msg],
                label="Click to fill query",
            )

        # Right column: raw query rows plus the generated chart.
        with gr.Column(scale=2):
            with gr.Accordion("π Data Sources ", open=False) as data_accordion:
                source_df.render()

            gr.Markdown("### Visualization")
            viz_btn = gr.Button("π Get Visualization", variant="primary")
            viz_output = gr.HTML(label="Chart", min_height=500)

    # Pressing Enter runs the chat handler; once it finishes, the textbox is cleared.
    submit_event = msg.submit(
        fn=process_zbb_query,
        inputs=[msg, chatbot, session_id, last_query_state, raw_sql_state],
        outputs=[chatbot, session_id, source_df, last_query_state, raw_sql_state],
    ).then(fn=lambda: "", inputs=None, outputs=[msg])

    # The button streams HTML from the visualization handler into the chart panel.
    viz_btn.click(
        fn=generate_visualization,
        inputs=[session_id, last_query_state, raw_sql_state],
        outputs=[viz_output],
    )
|
|
|
|
|
if __name__ == "__main__":
    # load_dotenv() already ran at import time, so values from .env are visible here.
    # The redundant local re-imports and second load_dotenv() call were removed.
    app_username = os.getenv("ID")
    app_password = os.getenv("PASS")

    # Fail fast: with a missing value the auth tuple would carry None into Gradio's
    # login check, and presumably no typed-in credentials could ever match it.
    if not app_username or not app_password:
        raise SystemExit(
            "Missing login credentials: set the ID and PASS environment variables (e.g. in .env)."
        )

    demo.launch(
        auth=(app_username, app_password),
        auth_message="Please enter your ABInBev credentials to access the ZBB GenAI Tool.",
        theme=gr.themes.Soft(),
    )