File size: 8,772 Bytes
26326f6
 
 
 
 
 
3639586
 
26326f6
 
 
 
3639586
26326f6
 
 
 
 
 
35b5d7a
 
26326f6
 
3639586
 
 
 
 
 
 
35b5d7a
 
26326f6
3639586
 
 
 
 
 
 
2856029
3639586
 
26326f6
35b5d7a
3639586
 
 
 
35b5d7a
 
 
 
 
3639586
35b5d7a
 
 
 
 
 
26326f6
b0749ae
35b5d7a
ab49d91
3639586
35b5d7a
 
 
3639586
35b5d7a
 
 
 
3639586
 
 
 
35b5d7a
ab49d91
3639586
ab49d91
 
3639586
 
 
 
 
 
 
 
 
 
 
 
 
35b5d7a
3639586
 
5c70ff9
35b5d7a
3639586
 
b0749ae
 
3639586
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2856029
 
26326f6
3639586
35b5d7a
3639586
5c70ff9
3639586
 
 
 
 
 
 
 
 
 
 
5c70ff9
3639586
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c70ff9
3639586
 
 
 
 
 
 
 
 
 
5c70ff9
 
b0749ae
 
 
 
9f16e51
 
 
 
 
35b5d7a
b0749ae
 
9f16e51
b0749ae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import gradio as gr
import requests
import uuid
import time
import urllib3
import os
import pandas as pd
import base64
from dotenv import load_dotenv

# Pull settings from a local .env file (if one exists) into the process
# environment before any of the os.getenv() lookups below run.
load_dotenv()

# Disable SSL warnings for corporate networks.
# Every backend call in this file passes verify=False, which would otherwise
# make urllib3 emit an InsecureRequestWarning on each request.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# --- CONFIGURATION ---
# Endpoints are built as f"{BASE_URL}/kickoff" and f"{BASE_URL}/status/<id>",
# so a trailing slash in the configured value is stripped to avoid "//" in
# the request path. An unset variable becomes "" and the first request then
# fails with an invalid-URL error that the handlers display to the user.
BASE_URL = (os.getenv("BASE_URL") or "").rstrip("/")
# NOTE(review): if BEARER_TOKEN is unset the header below is sent as the
# literal text "Bearer None" — confirm the deployment always provides it.
BEARER_TOKEN = os.getenv("BEARER_TOKEN")
HEADERS = {
    "Authorization": f"Bearer {BEARER_TOKEN}",
    "Content-Type": "application/json",
}

def process_zbb_query(user_message, history, session_id, last_query, raw_sql):
    """
    Main chat handler: submit a query to the backend and stream the reply.

    This is a generator. Every yield is the 5-tuple
    ``(history, session_id, dataframe_update, last_query, raw_sql)`` in the
    same order as the ``outputs`` wired to the submit event.

    Chat format: list of dictionaries ``[{"role": "user", "content": ...}]``.

    Args:
        user_message: Text typed by the user.
        history: Existing chat messages (list of role/content dicts). It is
            copied, so the caller's list is never mutated.
        session_id: Current session identifier; a new ``sess_<6 hex>`` value
            is generated when it is empty.
        last_query: Previously stored query; replaced by ``user_message``.
        raw_sql: Previously stored SQL rows; reset for every new query.

    Yields:
        Progressive UI states: a "thinking" placeholder, one refresh per
        poll, and finally either the assistant's answer (with the DataFrame
        shown when ``sql_result`` is a non-empty list), a timeout notice, or
        an error message.
    """
    max_polls = 200          # upper bound on status checks
    poll_interval_s = 1      # pause between status checks
    request_timeout_s = 60   # per-request cap so a dead backend cannot hang the UI
    # NOTE(review): only "SUCCESS" is known from the original code; these
    # failure names are an assumption so a failed run stops polling early.
    # Confirm against the backend's status API.
    failure_states = {"FAILED", "FAILURE", "ERROR"}

    # Work on a copy so the list handed in by the caller is left untouched.
    history = list(history or [])

    # Nothing to send: leave every output as it was.
    if not user_message or not user_message.strip():
        yield history, session_id, gr.update(), last_query, raw_sql
        return

    # 1. Update Last Query State & Session
    last_query = user_message
    if not session_id:
        session_id = f"sess_{uuid.uuid4().hex[:6]}"

    # 2. Add User Message & 'Thinking' Placeholder (Dictionary format)
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": "⏳ **Agent is thinking...**"})

    # Reset raw_sql state for new query
    raw_sql = []

    # Yield initial state (DataFrame hidden and cleared)
    yield history, session_id, gr.update(visible=False, value=None), last_query, raw_sql

    try:
        # History sent to the backend excludes the trailing placeholder.
        backend_history = history[:-1]

        # Kickoff request
        kickoff_url = f"{BASE_URL}/kickoff"
        payload = {
            "inputs": {
                "session_id": session_id,
                "current_query": user_message,
                "conversation_history": backend_history,
            }
        }

        post_resp = requests.post(
            kickoff_url,
            json=payload,
            headers=HEADERS,
            verify=False,
            timeout=request_timeout_s,
        )
        post_resp.raise_for_status()
        kickoff_id = post_resp.json().get("kickoff_id")
        if not kickoff_id:
            # Without an id the status URL would end in "/status/None".
            raise RuntimeError("Backend did not return a kickoff_id.")

        # Polling Loop
        status_url = f"{BASE_URL}/status/{kickoff_id}"
        result_data = None
        finished = False

        for i in range(max_polls):
            get_resp = requests.get(
                status_url,
                headers=HEADERS,
                verify=False,
                timeout=request_timeout_s,
            )
            # Surface HTTP errors instead of parsing an error page as JSON.
            get_resp.raise_for_status()
            status_data = get_resp.json()
            state = status_data.get("state")

            if state == "SUCCESS":
                # An empty result is still a completed run, not a timeout.
                result_data = status_data.get("result") or {}
                finished = True
                break

            if state in failure_states:
                raise RuntimeError(f"Backend reported state {state}.")

            # Update thinking status with progress
            history[-1]["content"] = f"⏳ **Agent is thinking...** ({i+1}s)"
            yield history, session_id, gr.update(visible=False), last_query, raw_sql
            time.sleep(poll_interval_s)

        # Handle Results
        if not finished:
            history[-1]["content"] = "⚠️ Request timed out. Please try again."
            yield history, session_id, gr.update(visible=False), last_query, raw_sql
            return

        response_text = result_data.get("assistant_message", "No message returned.")
        sql_data = result_data.get("sql_result", [])

        # Replace 'Thinking' message with actual response
        history[-1]["content"] = response_text

        # If SQL data exists, update DataFrame & State
        if isinstance(sql_data, list) and sql_data:
            raw_sql = sql_data
            df = pd.DataFrame(sql_data)

            def _strip_quotes(cell):
                # Drop leading/trailing single quotes from string cells only.
                return cell.strip("'") if isinstance(cell, str) else cell

            # DataFrame.map exists from pandas 2.1; older releases only
            # provide the equivalent applymap.
            elementwise = df.map if hasattr(df, "map") else df.applymap
            df = elementwise(_strip_quotes)
            yield history, session_id, gr.update(visible=True, value=df), last_query, raw_sql
        else:
            yield history, session_id, gr.update(visible=False, value=None), last_query, raw_sql

    except Exception as e:
        # UI boundary: show the failure in the chat instead of crashing the event.
        history[-1]["content"] = f"❌ Error: {str(e)}"
        yield history, session_id, gr.update(visible=False), last_query, raw_sql


def generate_visualization(session_id, last_query, raw_sql):
    """
    Handler for the Visualization button.

    This is a generator: every message intended for the UI must be yielded.
    (A ``return "<html>"`` inside a generator only sets the StopIteration
    value and is never displayed.)

    Args:
        session_id: Session identifier forwarded to the backend.
        last_query: The most recent user query, sent as ``current_query``.
        raw_sql: SQL rows from the last successful query, sent as
            ``sql_result``. When empty or ``None`` no request is made.

    Yields:
        HTML strings: a warning when there is no data, a progress message,
        then either an iframe embedding the returned HTML (base64 data URI),
        a "no content" notice, a timeout notice, or an error message.
    """
    if not raw_sql:
        yield "<h3>⚠️ No data available to visualize. Please run a query first.</h3>"
        return

    max_polls = 200          # upper bound on status checks
    poll_interval_s = 1      # pause between status checks
    request_timeout_s = 60   # per-request cap so a dead backend cannot hang the UI

    yield "<h3>⏳ Generating Visualization...</h3>"

    try:
        kickoff_url = f"{BASE_URL}/kickoff"
        payload = {
            "inputs": {
                "session_id": session_id,
                "current_query": last_query,
                "conversation_history": [],
                "router_flag": "viz",
                "sql_result": raw_sql,
            }
        }

        post_resp = requests.post(
            kickoff_url,
            json=payload,
            headers=HEADERS,
            verify=False,
            timeout=request_timeout_s,
        )
        post_resp.raise_for_status()
        kickoff_id = post_resp.json().get("kickoff_id")
        if not kickoff_id:
            # Without an id the status URL would end in "/status/None".
            raise RuntimeError("Backend did not return a kickoff_id.")

        status_url = f"{BASE_URL}/status/{kickoff_id}"

        for _ in range(max_polls):
            get_resp = requests.get(
                status_url,
                headers=HEADERS,
                verify=False,
                timeout=request_timeout_s,
            )
            # Surface HTTP errors instead of parsing an error page as JSON.
            get_resp.raise_for_status()
            status_data = get_resp.json()

            if status_data.get("state") == "SUCCESS":
                result = status_data.get("result") or {}
                raw_html = result.get("final_response", "")

                if not raw_html:
                    yield "<div>No visualization content returned</div>"
                    return

                try:
                    # Embed the returned document as a base64 data URI so it
                    # renders inside its own iframe.
                    html_b64 = base64.b64encode(raw_html.encode('utf-8')).decode('utf-8')
                    iframe_html = f"""
                    <iframe 
                        src="data:text/html;base64,{html_b64}" 
                        style="width: 100%; height: 600px; border: none;"
                        scrolling="yes">
                    </iframe>
                    """
                    yield iframe_html
                except Exception as encode_error:
                    yield f"<div>Error encoding visualization: {str(encode_error)}</div>"
                return

            time.sleep(poll_interval_s)

        yield "<h3>⚠️ Visualization timed out.</h3>"

    except Exception as e:
        # UI boundary: show the failure in the output panel.
        yield f"<h3>❌ Error generating visualization: {str(e)}</h3>"


# --- UI Setup ---
with gr.Blocks() as demo:
    # Page title (emoji restored; the previous literal was mis-decoded UTF-8).
    gr.Markdown("# 💰 ZBB GenAI Analysis")

    # Per-session state threaded through both handlers.
    session_id = gr.State("")
    last_query_state = gr.State("")
    raw_sql_state = gr.State([])

    # Pre-define dataframe (hidden initially). render=False defers placement
    # until source_df.render() is called inside the accordion below.
    source_df = gr.DataFrame(
        label="Query Output",
        interactive=False,
        visible=False,
        wrap=True,
        render=False
    )

    with gr.Row():
        with gr.Column(scale=3):
            # No type= argument is passed here, while process_zbb_query
            # emits role/content dictionaries.
            # NOTE(review): this relies on the installed Gradio release
            # treating the messages format as the Chatbot default — confirm
            # against the pinned Gradio version.
            chatbot = gr.Chatbot(label="ZBB Assistant", height=500)
            msg = gr.Textbox(
                placeholder="Type your query here...",
                label="Command"
            )

            # --- SUGGESTIONS SECTION ---
            gr.Markdown("### 💡 Suggestions")
            examples = gr.Examples(
                examples=[["How is Overhead tracking vs Budget YTD for NAZ by Function?"],
                    ["Show me NAZ overhead performance"],
                    ["What is the travel budget for SAZ?"],
                    ["Compare Marketing actuals vs budget for M1 2025"],
                    ["Drill down into IT expenses for APAC"]
                ],
                inputs=[msg],
                label="Click to fill query"
            )

        with gr.Column(scale=2):
            with gr.Accordion("📊 Data Sources ", open=False) as data_accordion:
                source_df.render()

            gr.Markdown("### Visualization")
            viz_btn = gr.Button("📊 Get Visualization", variant="primary")
            viz_output = gr.HTML(label="Chart", min_height=500)

    # 1. Chat Submission: stream the handler's yields into the outputs, then
    #    clear the textbox once the handler has finished.
    submit_event = msg.submit(
        process_zbb_query, 
        inputs=[msg, chatbot, session_id, last_query_state, raw_sql_state], 
        outputs=[chatbot, session_id, source_df, last_query_state, raw_sql_state]
    ).then(lambda: "", None, [msg]) 

    # 2. Visualization Event: uses the state stored by the last chat query.
    viz_btn.click(
        generate_visualization,
        inputs=[session_id, last_query_state, raw_sql_state],
        outputs=[viz_output]
    )

if __name__ == "__main__":
    # os and load_dotenv are imported, and load_dotenv() is called, at the top
    # of this file, so the environment is already populated here.
    auth_user = os.getenv("ID")
    auth_password = os.getenv("PASS")
    if not auth_user or not auth_password:
        # Launching with auth=(None, None) would show a login page that no
        # credentials can pass; stop with a clear message instead.
        raise SystemExit(
            "Missing login credentials: set the ID and PASS environment variables."
        )

    # Define a clean theme using Inter (modern, highly readable)
    clean_theme = gr.themes.Soft(
        font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
        primary_hue="blue",
    )
    demo.launch(
        auth=(auth_user, auth_password),
        auth_message="Please enter your ABInBev credentials to access the ZBB GenAI Tool.",
        theme=clean_theme
    )