Spaces:
Sleeping
Sleeping
| import os | |
| import pandas as pd | |
| from sqlalchemy import create_engine, inspect, URL | |
| from langchain_openai import AzureChatOpenAI | |
| from langchain_community.utilities import SQLDatabase | |
| from langchain_community.agent_toolkits import create_sql_agent | |
| from langchain import PromptTemplate, SQLDatabase | |
| from langchain_experimental.sql.base import SQLDatabaseChain | |
| import streamlit as st | |
| import pyodbc | |
| import openai | |
| import hmac | |
| from langchain_openai import AzureChatOpenAI | |
| from tabulate import tabulate | |
| from utils import SQLDatabaseChainPatched, table_search, extract_question_type, extract_table_name, extract_question_list | |
| #os.environ['OPENAI_API_KEY'] = os.environ['OPENAI_API_KEY2'] | |
| #os.environ['AZURE_OPENAI_ENDPOINT'] = os.environ['AZURE_OPENAI_ENDPOINT2'] | |
| #openai.api_key = os.environ['OPENAI_API_KEY'] | |
| #openai.api_type = 'azure' | |
| #openai.api_base = os.environ['AZURE_OPENAI_ENDPOINT'] | |
| #openai.api_version = os.environ['OPENAI_API_VERSION'] | |
# Configure the `openai` module for Azure from environment variables.
# Each lookup raises KeyError at import time when its variable is missing.
openai.api_key = os.environ['OPENAI_API_KEY']
openai.api_type = 'azure'
openai.api_base = os.environ['AZURE_OPENAI_ENDPOINT']
openai.api_version = os.environ['OPENAI_API_VERSION']
# Re-export the key under the AZURE_OPENAI_API_KEY variable name as well.
os.environ['AZURE_OPENAI_API_KEY'] = os.environ['OPENAI_API_KEY']

# Shared application password checked by the login gate.
password = os.environ['app_password']
deployment_name = "gpt-4o"

# Print the ODBC drivers visible to pyodbc (diagnostic output on stdout).
print(pyodbc.drivers())

# Database table name -> human-readable label shown in the UI.
mapping = {
    'History_All_Skus_Availability': 'SKU Availability',
    'HISTORY_AVAVBAIL': 'AV availability',
    'HISTORY_BUFamilyAvailability': 'Family and Business Unit (BU) availability',
    'HISTORY_OpenOrderShortage': 'Part Shortage',
    'MasterSkuAvBom_PA': 'Business unit to SKU to AV Mapping',
    'SMF_WT_BASE_ORDER': 'All Orders',
    'SMF_WT_BASE_FORECAST': 'All Forecast',
    'DAILY_INVENTORY': 'Part Inventory',
    'HISTORY_Sku_Shortage': 'SKU Shortage',
    'PART_PRICE_MASTER': 'Part Prices',
}
# Reverse lookup: human-readable label -> database table name.
inv_mapping = {label: table for table, label in mapping.items()}

# Page header listing every table label the assistant knows about.
st.title("Welcome to the Analysis GPT")
available_tables = ", ".join(inv_mapping)
st.markdown("We have the following table information - {}".format(available_tables))
# Prompt for the SQL chain. The placeholders {dialect}, {table_info} and
# {input} are filled in when the prompt is formatted.
# The table referred to in the shortage rules is spelled "OpenOrderShortage",
# matching the HISTORY_OpenOrderShortage key of `mapping`.
# NOTE(review): "comodity" below may be a typo for "commodity" -- confirm
# against the real column name before changing it.
template = """
You are a database expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
The final answer should be in a concise natural language.
Use the history if you can not understand the question.
Make sure you understand the plural nouns and process them accordingly to ensure correct query.
For instance --
commodities should be converted to commodity,
products should be converted to product,
SKUs to be converted to SKU,
families to be converted to family.
If the question is in another language, translate it to English before proceeding.
Do not repeat the question while generating the SQL query.
Only generate a correct {dialect} query.
Once the SQLResult is available, generate the final answer in natural language format. Do not regenerate the question or SQL query in the final answer.
If a question asks about price increase or decrease, first you should get the data for the given time period and then use your intelligence to calculate the increase/decrease over time.
Breakdown a complex queries into subproblems and solve them.
If the question asks any information for any particular number of days, use the lookback from the maximum date in the table, not from today's date.
Please note that MSSQL does not use LIMIT, but uses TOP clause.
You may also need to resolve the column name, as per the metadata. For instance, if the user asks about families and the column name is family, you should use family in the generated SQL.
Make sure that the column names are present in the table, by looking at the metadata.
If a question asks about availability over a period of time, you need to use SUM to calculate the total availability over that time period.
If a question mentions SKU, then use SKU column for filter, do not use any other column like comodity
If a question asks about AV of shortage, do not use AV in the SQL query as AV is not a valid column name. AV is the key in the Shortage column.
In the OpenOrderShortage table, the column Item should be used to extract the part ids, to answer questions related to shortage.
In the OpenOrderShortage table, Customer_Part_Name column is equivalent to SKU.
The AV_Shortage column in History_All_Skus_Availability table is a dictionary. So use this column judiciously.
Use the following format:
Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here.
Only use the following tables:
{table_info}
Question: {input}
"""
def check_password():
    """Gate the app behind the shared password.

    Renders a password input until the user has entered the value held in
    the module-level ``password``. The outcome is kept in
    ``st.session_state["password_correct"]``.

    Returns:
        bool: ``True`` once the correct password has been entered in this
        session, ``False`` otherwise (the input box, and an error after a
        failed attempt, are rendered in that case).
    """

    def password_entered():
        """Compare the submitted password with the expected one."""
        submitted = st.session_state["password"]
        # compare_digest() raises TypeError for str arguments containing
        # non-ASCII characters, so compare the UTF-8 bytes instead.
        if hmac.compare_digest(submitted.encode("utf-8"), password.encode("utf-8")):
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # Don't store the password.
        else:
            st.session_state["password_correct"] = False

    # Already validated earlier in this session.
    if st.session_state.get("password_correct", False):
        return True

    # Show input for password.
    st.text_input(
        "Password", type="password", on_change=password_entered, key="password"
    )
    # The key only exists after at least one attempt; reaching this point
    # with the key present means that attempt failed.
    if "password_correct" in st.session_state:
        st.error("😕 Password incorrect")
    return False
import logging

_logger = logging.getLogger(__name__)

# Fallback replies. NO_ANSWER_MESSAGE is also compared against
# st.session_state['previous_response'] to detect that the last attempt failed.
NO_ANSWER_MESSAGE = "Sorry I may not have answer to this question."
STILL_NO_ANSWER_MESSAGE = "Sorry I still cannot answer to this question."


def _ask_chain(db_chain, query, table_names, fallback, *, swallow_errors=False):
    """Run the SQL chain for one query and return the text to display.

    Args:
        db_chain: The SQL chain; its ``_call`` is invoked and its
            ``intermediate_steps`` mapping is read afterwards.
        query: Natural-language question passed to the chain.
        table_names: Tables the chain is restricted to.
        fallback: Text returned when the chain produced nothing usable.
        swallow_errors: When true, an exception raised by the chain is logged
            and treated as "no answer" instead of propagating.

    Returns:
        The chain's ``result`` if non-empty, else ``sql_data`` rendered as a
        psql-style table, else ``fallback``. Replies that still contain prompt
        scaffolding ("SQLQuery" / "Answer:") are replaced by ``fallback``.

    Side effects:
        Appends the chain's ``sql_cmd`` (or '') to
        ``st.session_state['history']``.
    """
    inputs = {
        'query': query,
        'history': st.session_state['history'],
        'table_names_to_use': table_names,
    }
    if swallow_errors:
        try:
            db_chain._call(inputs=inputs)
        except Exception:
            # Best effort: one failing sub-question must not abort the rest.
            _logger.exception("SQL chain failed for question %r", query)
    else:
        db_chain._call(inputs=inputs)

    steps = db_chain.intermediate_steps
    if steps.get("result", '') != '':
        response = steps.get("result", '')
    elif steps.get("sql_data", '') != '':
        frame = pd.DataFrame.from_dict(steps['sql_data'])
        response = tabulate(frame, headers='keys', tablefmt='psql')
    else:
        response = fallback
    # The model sometimes echoes the prompt format instead of answering.
    if "SQLQuery" in response or "Answer:" in response:
        response = fallback

    st.session_state['history'].append(steps.get("sql_cmd", ''))
    return response


def _last_user_question():
    """Return the most recent user message in the chat, or None if there is none."""
    for message in reversed(st.session_state.messages):
        if message["role"] == "user":
            return message["content"]
    return None


def _answer_question(llm, db_chain, question):
    """Answer ``question`` from the database and render the reply.

    The question is first split by ``extract_question_list``. When that
    returns a list, each sub-question is answered on its own (chain errors
    are logged and skipped) and the non-empty answers are joined. Otherwise
    the question is answered as a whole and chain errors propagate.

    Side effects:
        Updates ``previous_response``; on failure sets ``ask_user_selection``
        and stores the candidate table labels in ``prev_selection``. The
        reply is appended to the chat history only while
        ``ask_user_selection`` is false.
    """
    with st.status("Retrieving results..."):
        questions = extract_question_list(llm, question)
        # Stays empty when `questions` is an empty list, so the failure
        # handling below never sees an unbound name.
        top_table_names = []
        if isinstance(questions, list):
            responses = []
            for sub_question in questions:
                top_table_names = table_search(sub_question, topk=3)['table'].tolist()
                print(top_table_names)
                responses.append(
                    _ask_chain(db_chain, sub_question, top_table_names, "",
                               swallow_errors=True)
                )
            # Drop empty answers so that "every sub-question failed" is
            # recognised even when there are several sub-questions.
            response = "\n\n".join(r for r in responses if r)
            if response == "":
                response = NO_ANSWER_MESSAGE
        else:
            top_table_names = table_search(question, topk=3)['table'].tolist()
            print(top_table_names)
            response = _ask_chain(db_chain, question, top_table_names, NO_ANSWER_MESSAGE)

        if response == NO_ANSWER_MESSAGE:
            # Let the user pick the right table among the candidates.
            st.session_state.ask_user_selection = True
            st.session_state.prev_selection = [
                mapping[tab] for tab in top_table_names if tab in mapping
            ]
        st.session_state['previous_response'] = response

    with st.chat_message("assistant"):
        st.markdown(response)
    # Add assistant response to chat history
    if not st.session_state.ask_user_selection:
        st.session_state.messages.append({"role": "assistant", "content": response})


def main():
    """Build the database chain and run one pass of the Streamlit chat UI."""
    connection_string = (
        "Driver=FreeTDS;Server=crawlersdb.c3pzpntwjvdf.us-east-1.rds.amazonaws.com;Database=SmartCleverST;PORT=1433;UID=CleverData;PWD={};TrustServerCertificate=yes;".format(os.environ['DB_PWD'])
    )
    connection_url = URL.create(
        "mssql+pyodbc",
        query={"odbc_connect": connection_string},
    )
    engine = create_engine(connection_url)
    db = SQLDatabase(engine=engine, sample_rows_in_table_info=3, view_support=True)
    prompt = PromptTemplate(
        template=template,
        input_variables=["dialect", "input", "table_info", "top_k"],
    )
    llm = AzureChatOpenAI(deployment_name=deployment_name, temperature=0)
    db_chain = SQLDatabaseChainPatched.from_llm(llm, db, prompt=prompt)
    db_chain.set_llms(llms={'4k': llm})

    # Session defaults.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "last_message_failed" not in st.session_state:
        st.session_state.last_message_failed = False
    if "ask_user_selection" not in st.session_state:
        st.session_state.ask_user_selection = False
        st.session_state.prev_selection = []

    # Replay the conversation so far.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if not check_password():
        st.stop()  # Do not continue if check_password is not True.

    question = st.chat_input("What is your question today? Type and press enter.")
    if 'history' not in st.session_state:
        st.session_state['history'] = []
    if "previous_response" not in st.session_state:
        st.session_state['previous_response'] = ""

    if question is not None and question != "":
        with st.chat_message("user"):
            st.markdown(question)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": question})
        if st.session_state['previous_response'] != NO_ANSWER_MESSAGE:
            st.session_state.last_message_failed = False
            _answer_question(llm, db_chain, question)
        else:
            # The previous attempt failed: ask how to treat this question
            # before querying again.
            with st.chat_message("assistant"):
                st.markdown("Looks like this question is not related to the database, but a generic. Do you want me to answer it from the table? Otherwise I will use my own knowledge.")
            st.session_state.last_message_failed = True

    # "Yes" answers from the database, "No" from the model's own knowledge.
    # Both act on the latest *user* message; nothing is offered when the
    # chat is empty (for example after a reset).
    if st.session_state.last_message_failed:
        pending_question = _last_user_question()
        if pending_question is not None:
            if st.button("Yes"):
                _answer_question(llm, db_chain, pending_question)
            elif st.button("No"):
                response = llm.invoke(pending_question).content
                with st.chat_message("assistant"):
                    st.markdown(response)
                st.session_state.messages.append({"role": "assistant", "content": response})

    # After a failure with several candidate tables, retry the latest user
    # question against the table the user selects.
    retry_question = _last_user_question()
    if (
        st.session_state.ask_user_selection
        and len(st.session_state.prev_selection) > 1
        and retry_question is not None
    ):
        chosen_label = st.selectbox(
            "You can help me to look at the right table, so that I can answer to your previous question",
            tuple(st.session_state.prev_selection),
        )
        top_table_names = [inv_mapping[chosen_label]]
        with st.status("Retrieving results..."):
            response = _ask_chain(
                db_chain, retry_question, top_table_names, STILL_NO_ANSWER_MESSAGE
            )
            st.session_state.ask_user_selection = False
            st.session_state.prev_selection = []
            st.session_state['previous_response'] = response
        with st.chat_message("assistant"):
            st.markdown(response)

    if st.button("Reset Chat History"):
        st.session_state['history'] = []
        st.session_state.messages = []


if __name__ == '__main__':
    main()