File size: 20,059 Bytes
cb38618
c5631f7
23f813e
790f821
23f813e
 
790f821
23f813e
cb38618
23f813e
790f821
bafbec1
790f821
 
e912511
 
9fd72a9
790f821
8a76b1c
 
6a85195
8a76b1c
 
 
 
790f821
 
 
 
8a76b1c
 
bafbec1
 
e0ef454
cb38618
23f813e
cb38618
2008452
fa6a00f
 
2008452
fa6a00f
672f9ae
 
2008452
2be906f
fa6a00f
 
894519b
f5c7c21
c428fa2
 
4a229f4
 
 
8fc34e7
 
 
282a1ba
 
 
0cbb043
 
 
 
 
 
 
 
 
8fc34e7
739c1d1
 
 
 
4a229f4
 
 
8a76b1c
 
 
4a229f4
8fc34e7
 
 
 
739c1d1
 
 
 
8fc34e7
 
 
 
 
 
 
 
 
 
739c1d1
 
8fc34e7
 
 
 
 
 
 
 
 
 
 
bafbec1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23f813e
e912511
23f813e
 
 
 
 
 
2008452
23f813e
790f821
 
112d15c
790f821
 
 
 
 
 
 
 
 
23f813e
5dbdcc3
bafbec1
e912511
 
23f813e
c5a694c
 
fa6a00f
 
 
 
c5a694c
e912511
 
 
3574be0
bafbec1
 
812553a
f5c7c21
e912511
08c0da1
 
 
 
 
3c25866
 
 
08c0da1
246d445
3c25866
e9eb721
5dbdcc3
e9eb721
 
 
3c25866
81407fe
c5a694c
d36b2b7
2196b2c
adbfdfa
3a28483
9fd72a9
 
 
9b6dcba
 
9fd72a9
 
 
 
 
 
 
 
 
1fae026
 
9fd72a9
 
 
 
df8d6b8
9fd72a9
d67ae7d
 
 
9fd72a9
5d7e5fe
9fd72a9
2b898d5
77a8bfe
914fcbb
3c25866
9fd72a9
9b6dcba
 
9fd72a9
 
 
79ec2bb
 
9fd72a9
79ec2bb
 
 
1fae026
 
9fd72a9
 
 
79ec2bb
 
 
 
9fd72a9
 
914fcbb
9fd72a9
d67ae7d
914fcbb
d67ae7d
5d7e5fe
fa6a00f
 
 
 
3c25866
 
 
e912511
 
 
fa6a00f
 
c8c96ab
790f821
c62acd1
 
c5a694c
 
 
 
5decd13
c5a694c
 
 
 
 
 
 
9b6dcba
 
42c1511
 
 
c5a694c
 
42c1511
c5a694c
 
 
1fae026
 
42c1511
 
 
 
c5a694c
 
42c1511
c5a694c
42c1511
c5a694c
5d7e5fe
c5a694c
 
 
 
fa6a00f
c5a694c
9b6dcba
 
c5a694c
 
 
 
 
 
 
 
 
1fae026
 
c5a694c
 
 
 
 
 
 
 
 
 
 
5d7e5fe
fa6a00f
 
 
 
3c25866
 
c5a694c
 
 
fa6a00f
 
 
c5a694c
5decd13
c5a694c
 
 
f9b7f8c
3574be0
fa6a00f
 
 
7694344
fa6a00f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d7e5fe
fa6a00f
 
 
3c25866
 
fa6a00f
 
 
3574be0
08c0da1
 
3574be0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
import os
import pandas as pd
from sqlalchemy import create_engine, inspect, URL
from langchain_openai import AzureChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import create_sql_agent
from langchain import PromptTemplate, SQLDatabase
from langchain_experimental.sql.base import SQLDatabaseChain
import streamlit as st
import pyodbc                                                                                           
import openai
import hmac
from langchain_openai import AzureChatOpenAI

from tabulate import tabulate

from utils import SQLDatabaseChainPatched, table_search, extract_question_type, extract_table_name, extract_question_list

# --- Azure OpenAI configuration ---------------------------------------------
# Older variants kept for reference; the active block below reads the primary
# environment variables directly.
#os.environ['OPENAI_API_KEY'] = os.environ['OPENAI_API_KEY2']
#os.environ['AZURE_OPENAI_ENDPOINT'] = os.environ['AZURE_OPENAI_ENDPOINT2']

#openai.api_key = os.environ['OPENAI_API_KEY']
#openai.api_type = 'azure'
#openai.api_base = os.environ['AZURE_OPENAI_ENDPOINT']
#openai.api_version = os.environ['OPENAI_API_VERSION']
# Configure the openai module for Azure; all values come from the environment
# and the script will fail fast with KeyError if any is missing.
openai.api_key = os.environ['OPENAI_API_KEY']
openai.api_type = 'azure'
openai.api_base = os.environ['AZURE_OPENAI_ENDPOINT']
openai.api_version = os.environ['OPENAI_API_VERSION']
# langchain's AzureChatOpenAI reads AZURE_OPENAI_API_KEY; mirror the key there.
os.environ['AZURE_OPENAI_API_KEY'] = os.environ['OPENAI_API_KEY']

# App login password, checked by check_password() below.
password = os.environ['app_password']

# Azure deployment name used for every LLM call in this app.
deployment_name = "gpt-4o"

# Diagnostic: confirm the FreeTDS/ODBC driver is visible to pyodbc.
print(pyodbc.drivers()) 

# Table/view name -> human-readable label. Used for the UI table list and to
# translate the user's table selection back into a real table name.
mapping = {'History_All_Skus_Availability': 'SKU Availability', \
           'HISTORY_AVAVBAIL':'AV availability', \
           'HISTORY_BUFamilyAvailability': 'Family and Business Unit (BU) availability', \
            'HISTORY_OpenOrderShortage': 'Part Shortage', \
            'MasterSkuAvBom_PA': 'Business unit to SKU to AV Mapping', \
           'SMF_WT_BASE_ORDER': 'All Orders', \
           'SMF_WT_BASE_FORECAST': 'All Forecast', \
          'DAILY_INVENTORY': 'Part Inventory', \
          'HISTORY_Sku_Shortage': 'SKU Shortage', 'PART_PRICE_MASTER': 'Part Prices'}
# Inverse mapping: label -> table name (labels are unique, so this is safe).
inv_mapping = {val: key for (key,val) in mapping.items()}

st.title("Welcome to the Analysis GPT")

# Show the user which datasets (by friendly label) are queryable.
st.markdown("We have the following table information - {}".format(", ".join(list(inv_mapping.keys()))))

# Prompt template for the SQL chain. Placeholders {dialect}, {table_info} and
# {input} are filled in by the chain at call time. Fixed typo: "youn" -> "you".
template = """

            You are a database expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
            The final answer should be in a concise natural language. 
            
            Use the history if you can not understand the question.

            Make sure you understand the plural nouns and process them accordingly to ensure correct query.
            
            For instance -- 
            
            commodities should be converted to commodity,
            
            products should be converted to product, 
            
            SKUs to be converted to SKU, 
            
            families to be converted to family.
            
            If the question is in another language, translate it to English before proceeding.
            
            Do not repeat the question while generating the SQL query.
            
            Only generate a correct {dialect} query.
            
            Once the SQLResult is available, generate the final answer in natural language format. Do not regenerate the question or SQL query in the final answer.
            If a question asks about price increase or decrease, first you should get the data for the given time period and then use your intelligence to calculate the increase/decrease over time. 
            
            Breakdown a complex queries into subproblems and solve them.
            
            If the question asks any information for any particular number of days, use the lookback from the maximum date in the table, not from today's date.
            
            Please note that MSSQL does not use LIMIT, but uses TOP clause.
            
            You may also need to resolve the column name, as per the metadata. For instance, if the user asks about families and the column name is family, you should use family in the generated SQL.
            
            Make sure that the column names are present in the table, by looking at the metadata.
            
            If a question asks about availability over a period of time, you need to use SUM to calculate the total availability over that time period.
            
            If a question mentions SKU, then use SKU column for filter, do not use any other column like comodity
            
            If a question asks about AV of shortage, do not use AV in the SQL query as AV is not a valid column name. AV is the key in the Shortage column.
            
            In the OpenOrderShotage table, the column Item should be used to extract the part ids, to answer questions related to shortage.
            
            In the OpenOrderShotage table, Customer_Part_Name column is equivalent to SKU.
            
            The AV_Shortage column in History_All_Skus_Availability table is a dictionary. So use this column judiciously.
            
            Use the following format:
            Question: Question here
            SQLQuery: SQL Query to run
            SQLResult: Result of the SQLQuery
            Answer: Final answer here.
            
            Only use the following tables:
            {table_info}
            Question: {input}
            """

def check_password():
    """Return ``True`` if the user has entered the correct password.

    Gate for the rest of the app: renders a password box, validates the input
    against the module-level ``password`` (from the ``app_password`` env var)
    in constant time, and caches the result in ``st.session_state``.
    """

    def password_entered():
        """on_change callback: validate the password typed by the user."""
        # Compare as UTF-8 bytes: hmac.compare_digest accepts str only when
        # both sides are ASCII and raises TypeError otherwise, so encoding
        # makes the check robust for any configured password.
        entered = st.session_state["password"].encode("utf-8")
        expected = password.encode("utf-8")
        if hmac.compare_digest(entered, expected):
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # Don't store the password.
        else:
            st.session_state["password_correct"] = False

    # Return True if the password is validated (possibly on an earlier rerun).
    if st.session_state.get("password_correct", False):
        return True

    # Show input for password.
    st.text_input(
        "Password", type="password", on_change=password_entered, key="password"
    )
    # The key exists but the flag is False -> a wrong password was submitted.
    if "password_correct" in st.session_state:
        st.error("😕 Password incorrect")
    return False
    
if __name__ == '__main__':
    # --- Database connection (MSSQL via FreeTDS/pyodbc) ----------------------
    # NOTE(review): server, database and UID are hard-coded; only the DB
    # password comes from the environment. Consider moving the full DSN to
    # configuration / a secrets manager.
    connection_string = ("Driver=FreeTDS;Server=crawlersdb.c3pzpntwjvdf.us-east-1.rds.amazonaws.com;Database=SmartCleverST;PORT=1433;UID=CleverData;PWD={};TrustServerCertificate=yes;".format(os.environ['DB_PWD'])
    )
    connection_url = URL.create(
        "mssql+pyodbc", 
        query={"odbc_connect": connection_string}
    )
    engine = create_engine(connection_url)
    # view_support=True exposes SQL views; 3 sample rows per table are embedded
    # in the table info that gets handed to the LLM.
    db = SQLDatabase(engine=engine, sample_rows_in_table_info=3, view_support=True)

    # "top_k" is declared even though the template above never references it --
    # presumably required by the chain's expected input variables; TODO confirm
    # against SQLDatabaseChainPatched in utils.
    prompt = PromptTemplate(template=template, input_variables=["dialect","input","table_info","top_k"])
    llm = AzureChatOpenAI(
    deployment_name=deployment_name, temperature=0
    )
    
    db_chain = SQLDatabaseChainPatched.from_llm(
            llm, db,
            prompt=prompt,
        )
    # Register the LLM under the '4k' key -- looks like a context-window bucket
    # expected by the patched chain; verify in utils.
    db_chain.set_llms(llms={
            '4k': llm
        })

    #question = st.text_input("Ask a question in natural language and press enter")
    
    # --- Session-state initialisation (persists across Streamlit reruns) -----
    # messages: full chat transcript rendered in the UI.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # last_message_failed: set when the previous turn was judged off-topic,
    # which triggers the Yes/No follow-up buttons further below.
    if "last_message_failed" not in st.session_state:
        st.session_state.last_message_failed = False

    # ask_user_selection / prev_selection: drive the table-picker fallback when
    # automatic table search could not produce an answer.
    if "ask_user_selection" not in st.session_state:
        st.session_state.ask_user_selection = False
        st.session_state.prev_selection = []
        
    # Re-render the whole transcript on every rerun (Streamlit redraws from scratch).
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if not check_password():
        st.stop()  # Do not continue if check_password is not True.

    question = st.chat_input("What is your question today? Type and press enter.")
        
    #if 'questions' not in st.session_state:
    #    st.session_state['questions'] = []

    # history: previously generated SQL commands, fed back to the chain so it
    # can resolve follow-up questions.
    if 'history' not in st.session_state:
        st.session_state['history'] = []

    if "previous_response" not in st.session_state:
        st.session_state['previous_response'] = ""
        
    # ---------------- Main question-handling path ----------------------------
    if question is not None and question != "":
        #q_relevant = extract_question_type(llm, question)
        with st.chat_message("user"):
            st.markdown(question)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": question})
    
        #if 'yes' in q_relevant.lower():
        # Only attempt DB retrieval if the previous turn did not already fail;
        # otherwise fall through to the Yes/No clarification flow below.
        if st.session_state['previous_response'] != "Sorry I may not have answer to this question.":
            st.session_state.last_message_failed = False
            with st.status("Retrieving results..."):
                #top_table_names = table_search(question, topk=1)['table'].tolist()
                # The LLM may split a compound question into a list of sub-questions.
                questions = extract_question_list(llm, question)
                
                # NOTE(review): isinstance(questions, list) is the idiomatic check.
                if type(questions) == list:
                    # Multi-question branch: answer each sub-question separately
                    # and join the responses.
                    responses = []
                    for q in questions:
                        #top_table_names = extract_table_name(llm, q) #[extract_table_name(llm, q)] 
                        # Semantic search for the 3 most relevant tables.
                        top_table_names = table_search(q, topk=3)['table'].tolist()
                        print (top_table_names)
                        #history = st.session_state['questions']
                        history = st.session_state['history']
                        try:
                            db_chain._call(inputs={'query': q, 'history': history, \
                                   'table_names_to_use': top_table_names})
                        except:
                            # NOTE(review): bare except silently swallows every
                            # error (even KeyboardInterrupt); results are read
                            # from intermediate_steps below regardless.
                            pass

                        # Prefer the natural-language result; fall back to
                        # rendering the raw SQL rows as an ASCII table.
                        if db_chain.intermediate_steps.get("result",'') != '':
                            response = db_chain.intermediate_steps.get("result",'')
                        elif db_chain.intermediate_steps.get("sql_data",'') != '':
                            out = pd.DataFrame.from_dict(db_chain.intermediate_steps['sql_data'])
                            response = tabulate(out, headers='keys', tablefmt='psql')
                        else:
                            response = ""

                        # Discard answers that leaked the prompt scaffolding.
                        if "SQLQuery" in response or "Answer:" in response:
                            response = ""
                            
                        responses.append(response)
                        st.session_state['history'].append(db_chain.intermediate_steps.get("sql_cmd",''))

                    response = "\n\n".join(responses)
                    if response == "":
                        response = "Sorry I may not have answer to this question."
                        
                else:
                    # Single-question branch (same retrieval flow, but failures
                    # here additionally trigger the manual table-picker).
                    #top_table_names = extract_table_name(llm, question) #[extract_table_name(llm, question)] 
                    top_table_names = table_search(question, topk=3)['table'].tolist()
                    print (top_table_names)
                    #history = st.session_state['questions']
                    history = st.session_state['history']
                    #try:
                    db_chain._call(inputs={'query': question, 'history': history, \
                               'table_names_to_use': top_table_names})
                    #except:
                    #    pass

                    if db_chain.intermediate_steps.get("result",'') != '':
                        response = db_chain.intermediate_steps.get("result",'')
                    elif db_chain.intermediate_steps.get("sql_data",'') != '':
                        out = pd.DataFrame.from_dict(db_chain.intermediate_steps['sql_data'])
                        response = tabulate(out, headers='keys', tablefmt='psql')
                    elif db_chain.intermediate_steps.get("sql_cmd_unchecked",'') == '':
                        #print (db_chain)
                        #st.markdown("Sorry I cannot answer that. Please try again later.")
                        response = "Sorry I may not have answer to this question."
                    else:
                        #st.markdown("Sorry I cannot answer that. Please try again later.")
                        response = "Sorry I may not have answer to this question."

                    if "SQLQuery" in response or "Answer:" in response:
                        response = "Sorry I may not have answer to this question."
                        
                    st.session_state['history'].append(db_chain.intermediate_steps.get("sql_cmd",''))

                    # On failure, offer the candidate tables (by friendly label)
                    # so the user can pick the right one manually.
                    if response == "Sorry I may not have answer to this question.":
                        st.session_state.ask_user_selection = True
                        st.session_state.prev_selection = [mapping[tab] for tab in top_table_names if tab in mapping]

            st.session_state['previous_response'] = response
            
            with st.chat_message("assistant"):
                st.markdown(response)
            # Add assistant response to chat history
            # (skipped while a table selection is pending, so the user's
            # question stays the last transcript entry).
            if st.session_state.ask_user_selection == False:
                st.session_state.messages.append({"role": "assistant", "content": response})

        else:
            # Previous turn failed: ask whether to retry against the DB or
            # answer from the model's own knowledge.
            with st.chat_message("assistant"):
                st.markdown("Looks like this question is not related to the database, but a generic. Do you want me to answer it from the table? Otherwise I will use my own knowledge.")
                st.session_state.last_message_failed = True

    # ---------------- Yes/No follow-up after a failed turn --------------------
    # NOTE(review): the "Yes" branch duplicates the main retrieval flow above
    # almost verbatim -- a shared helper would remove ~80 lines of duplication.
    if st.session_state.last_message_failed == True:
        if st.button("Yes"):
            # Assumes the last transcript entry is the user's question -- TODO
            # confirm this invariant holds on every rerun path.
            question = st.session_state.messages[-1]['content']
            with st.status("Retrieving results..."):
                #top_table_names = table_search(question, topk=1)['table'].tolist()
                questions = extract_question_list(llm, question)

                if type(questions) == list:
                    responses = []
                    for q in questions:
                        #top_table_names = extract_table_name(llm, q) #[extract_table_name(llm, q)] 
                        top_table_names = table_search(q, topk=3)['table'].tolist()
                        print (top_table_names)
                        #history = st.session_state['questions']
                        history = st.session_state['history']
                        try:
                            db_chain._call(inputs={'query': q, 'history': history, \
                                   'table_names_to_use': top_table_names})
                        except:
                            # NOTE(review): bare except, same caveat as above.
                            pass

                        if db_chain.intermediate_steps.get("result",'') != '':
                            response = db_chain.intermediate_steps.get("result",'')
                        elif db_chain.intermediate_steps.get("sql_data",'') != '':
                            out = pd.DataFrame.from_dict(db_chain.intermediate_steps['sql_data'])
                            response = tabulate(out, headers='keys', tablefmt='psql')
                        else:
                            response = ""

                        if "SQLQuery" in response or "Answer:" in response:
                            response = ""
                            
                        responses.append(response)
                        st.session_state['history'].append(db_chain.intermediate_steps.get("sql_cmd",''))

                    response = "\n\n".join(responses)
                    if response == "":
                        response = "Sorry I may not have answer to this question."
                    
                else:
                    #top_table_names = extract_table_name(llm, question) #[extract_table_name(llm, question)] 
                    top_table_names = table_search(question, topk=3)['table'].tolist()
                    print (top_table_names)
                    #history = st.session_state['questions']
                    history = st.session_state['history']
                    #try:
                    db_chain._call(inputs={'query': question, 'history': history, \
                               'table_names_to_use': top_table_names})
                    #except:
                    #    pass

                    if db_chain.intermediate_steps.get("result",'') != '':
                        response = db_chain.intermediate_steps.get("result",'')
                    elif db_chain.intermediate_steps.get("sql_data",'') != '':
                        out = pd.DataFrame.from_dict(db_chain.intermediate_steps['sql_data'])
                        response = tabulate(out, headers='keys', tablefmt='psql')
                    elif db_chain.intermediate_steps.get("sql_cmd_unchecked",'') == '':
                        response = "Sorry I may not have answer to this question."
                    else:
                        response = "Sorry I may not have answer to this question."

                    if "SQLQuery" in response or "Answer:" in response:
                        response = "Sorry I may not have answer to this question."
                        
                    st.session_state['history'].append(db_chain.intermediate_steps.get("sql_cmd",''))

                    if response == "Sorry I may not have answer to this question.":
                        st.session_state.ask_user_selection = True
                        st.session_state.prev_selection = [mapping[tab] for tab in top_table_names if tab in mapping]

            st.session_state['previous_response'] = response
            with st.chat_message("assistant"):
                st.markdown(response)
            # Add assistant response to chat history
            if st.session_state.ask_user_selection == False:
                st.session_state.messages.append({"role": "assistant", "content": response})
                
        elif st.button("No"):
            # Answer directly from the LLM, bypassing the database entirely.
            question = st.session_state.messages[-1]['content']
            response = llm.invoke(question).content
            with st.chat_message("assistant"):
                st.markdown(response)
            st.session_state.messages.append({"role": "assistant", "content": response})

    # ---------------- Manual table selection fallback -------------------------
    # Only shown when more than one candidate table survived the mapping filter;
    # the user's pick is translated back to a real table name via inv_mapping.
    if st.session_state.ask_user_selection == True and len(st.session_state.prev_selection) > 1:
        top_table_names = [st.selectbox("You can help me to look at the right table, so that I can answer to your previous question", \
                                       tuple(st.session_state.prev_selection))]
        top_table_names = [inv_mapping[i] for i in top_table_names]
        question = st.session_state.messages[-1]['content']
        with st.status("Retrieving results..."):
            history = st.session_state['history']
            #try:
            db_chain._call(inputs={'query': question, 'history': history, \
                       'table_names_to_use': top_table_names})
            #except:
            #    pass
    
            if db_chain.intermediate_steps.get("result",'') != '':
                response = db_chain.intermediate_steps.get("result",'')
            elif db_chain.intermediate_steps.get("sql_data",'') != '':
                out = pd.DataFrame.from_dict(db_chain.intermediate_steps['sql_data'])
                response = tabulate(out, headers='keys', tablefmt='psql')
            elif db_chain.intermediate_steps.get("sql_cmd_unchecked",'') == '':
                response = "Sorry I still cannot answer to this question."
            else:
                response = "Sorry I still cannot answer to this question."
    
            if "SQLQuery" in response or "Answer:" in response:
                response = "Sorry I still cannot answer to this question."
                
        st.session_state['history'].append(db_chain.intermediate_steps.get("sql_cmd",''))

        # Selection handled (or abandoned): clear the fallback state.
        st.session_state.ask_user_selection = False
        st.session_state.prev_selection = []

        st.session_state['previous_response'] = response
        with st.chat_message("assistant"):
            st.markdown(response)
            
    # Wipe the SQL history and transcript (previous_response is intentionally
    # left as-is -- TODO confirm that is desired).
    if st.button("Reset Chat History"):
        #st.session_state['questions'] = []
        st.session_state['history'] = []
        st.session_state.messages = []