Spaces:
Runtime error
Runtime error
| from openai import OpenAI | |
| import pandas as pd | |
| import psycopg2 | |
| import time | |
| import gradio as gr | |
| import sqlparse | |
| import re | |
| import os | |
| import warnings | |
| from persistStorage import saveLog, getAllLogFilesPaths, getNewCsvFilePath, removeAllCsvFiles | |
| from config import * | |
| from constants import * | |
| from utils import * | |
| from gptManager import ChatgptManager | |
| from queryHelperManagerCoT import QueryHelperChainOfThought | |
# Print DataFrames in full: no limit on the number of displayed columns or rows.
for displayOption in ('display.max_columns', 'display.max_rows'):
    pd.set_option(displayOption, None)

# Filter out all warning messages
warnings.filterwarnings("ignore")

# Usernames appended by checkAuth on every successful login.
LOGGED_IN_USERS = []

# Database handles built from the configured credentials.
dbCreds = DataWrapper(DB_CREDS_DATA)
dbEngine = DbEngine(dbCreds)

# Read the table/column listing of the schema once, at start-up.
print("getting tablesAndCols..")
tablesAndCols = getAllTablesInfo(dbEngine, SCHEMA_NAME)
print("Done.")

# Start with the default table/column selection applied to the metadata layout.
metadataLayout = MetaDataLayout(schemaName=SCHEMA_NAME, allTablesAndCols=tablesAndCols)
metadataLayout.setSelection(DEFAULT_TABLES_COLS)
selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()

# GPT client and the query helper shared by all the UI callbacks below.
openAIClient2 = OpenAI(api_key=OPENAI_API_KEY)
gptInstanceForCoT = ChatgptManager(openAIClient2, model=GPT_MODEL)
queryHelperCot = QueryHelperChainOfThought(
    gptInstanceForCoT=gptInstanceForCoT,
    schemaName=SCHEMA_NAME,
    platform=PLATFORM,
    metadataLayout=metadataLayout,
    sampleDataRows=SAMPLE_ROW_MAX,
    gptSampleRows=GPT_SAMPLE_ROWS,
    dbEngine=dbEngine,
    getSampleDataForTablesAndCols=getSampleDataForTablesAndCols,
)
def checkAuth(username, password):
    """Gradio ``auth`` callback: accept only the configured admin credentials.

    Args:
        username: user name typed into the login form.
        password: password typed into the login form.

    Returns:
        True when both values match ``ADMIN`` / ``PASSWD``, otherwise False.
        On success the user name is also appended to ``LOGGED_IN_USERS``.
    """
    import hmac  # function-scope import: the file does not import hmac at the top

    def _sameText(supplied, expected):
        # compare_digest only accepts ASCII str or bytes, so compare the UTF-8
        # encodings; anything that is not text falls back to plain equality.
        if not isinstance(supplied, str) or not isinstance(expected, str):
            return supplied == expected
        return hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))

    # Evaluate both fields unconditionally so the response time does not reveal
    # which of the two was wrong.
    userMatches = _sameText(username, ADMIN)
    passwordMatches = _sameText(password, PASSWD)
    if userMatches and passwordMatches:
        LOGGED_IN_USERS.append(username)
        print("user logged in...", username)
        return True
    return False
# Chat handler: asks the query helper for an answer and records the exchange.
def respondCoT(message, chatHistory, verboseChatHistory, loggedUser):
    """gpt response handler for gradio ui

    Args:
        message: text the user submitted in the chat box.
        chatHistory: list of (user message, parsed reply) pairs shown in the chatbot.
        verboseChatHistory: list of (user message, complete GPT reply) pairs; this
            is the history handed to the query helper.
        loggedUser: per-session list that receives the user name on first use.

    Returns:
        ``("", chatHistory, verboseChatHistory, loggedUser)``; the empty string
        is the new value for the message box.

    Raises:
        ValueError: if the query helper fails. The error is logged first and the
            original exception is attached as the cause.
    """
    # NOTE: the most recent login is taken, which is not necessarily the user of
    # this session when several people are logged in.
    if not loggedUser and LOGGED_IN_USERS:
        loggedUser.append(LOGGED_IN_USERS[-1])

    try:
        botMessage, verboseBotMessage = queryHelperCot.getQueryForUserInputWithHistory(verboseChatHistory, message)
    except Exception as e:
        errorMessage = {"function":"queryHelperCot.getQueryForUserInputWithHistory","error":str(e), "userInput":message}
        saveLog(errorMessage, 'error')
        # Chain the cause so the original traceback is not lost.
        raise ValueError(str(e)) from e

    logMessage = {"userInput":message, "completeGptResponse":verboseBotMessage,
                  "parsedResponse":botMessage, "function":"queryHelperCot.getQueryForUserInputWithHistory"}
    saveLog(logMessage)
    chatHistory.append((message, botMessage))
    verboseChatHistory.append((message, verboseBotMessage))
    return "", chatHistory, verboseChatHistory, loggedUser
def preProcessSQL(sql):
    """Normalise a user-supplied query and cap its result size.

    All semicolons are removed, ``limit 100`` is appended when the query does
    not already end with a LIMIT clause, and the text is re-formatted with
    sqlparse (re-indented, keywords upper-cased).

    Args:
        sql: raw SQL text entered by the user.

    Returns:
        ``(formattedSql, disclaimer)``; ``disclaimer`` is an empty string unless
        the default limit was added.
    """
    # NOTE: this also removes semicolons that appear inside string literals.
    sql = sql.replace(';', '')
    disclaimerOutputStripping = ""

    # Look for a real trailing LIMIT clause (optionally "LIMIT a, b" or followed
    # by OFFSET) instead of searching the last few characters for the substring
    # "limit": that misfired on identifiers such as rate_limit and missed
    # "LIMIT 10 OFFSET 1000000", which then received a second LIMIT.
    # Only numeric limits (or ALL) are recognised.
    trailingLimit = re.search(
        r"\blimit\s+(?:\d+|all)(?:\s*,\s*\d+)?(?:\s+offset\s+\d+)?\s*$",
        sql,
        flags=re.IGNORECASE,
    )
    if trailingLimit is None:
        sql = f"{sql} limit 100"
        disclaimerOutputStripping = (
            "Results are stripped to show only top 100 rows.\n"
            "Please add your custom limit to get extended result.\n"
            "eg\n select * from schema.table limit 200\n"
        )

    sql = sqlparse.format(sql, reindent=True, keyword_case='upper')
    return sql, disclaimerOutputStripping
def onGetResultCsvFile(sql):
    """Run a read-only query and offer its result as a downloadable CSV file.

    Args:
        sql: SQL text taken from the query box.

    Returns:
        A ``gr.File`` component pointing at the freshly written CSV. When the
        query fails the CSV is written from an empty DataFrame.

    Raises:
        gr.Error: if ``isDataQuery`` rejects the query.
    """
    sql, _ = preProcessSQL(sql=sql)
    if not isDataQuery(sql):
        # The bound output is a file component, so a plain message string would
        # be treated as a (non-existent) file path; surface it as a UI error.
        raise gr.Error("Sorry not allowed to run. As the query modifies the data.")

    dbEngine2 = None
    try:
        dbEngine2 = DbEngine(dbCreds)
        dbEngine2.connect()
        df = pd.read_sql_query(sql, con=dbEngine2.getConnection())
    except Exception as e:
        # Record the failure instead of dropping it silently, then fall back to
        # an empty result.
        errorMessage = {"function":"onGetResultCsvFile","error":str(e), "userInput":sql}
        saveLog(errorMessage, 'error')
        df = pd.DataFrame()
    finally:
        # dbEngine2 is still None when the constructor itself raised.
        if dbEngine2 is not None:
            dbEngine2.disconnect()

    # Drop the previously generated CSV files before writing the new one.
    removeAllCsvFiles()
    csvFilePath = getNewCsvFilePath()
    df.to_csv(csvFilePath, index=False)
    return gr.File(csvFilePath)
def testSQL(sql):
    """Run a read-only query and return its result for the "Run Query" tab.

    Args:
        sql: SQL text taken from the query box.

    Returns:
        ``(message, table)`` on every path. ``message`` is the row-limit
        disclaimer (possibly empty), the rejection notice, or the error text;
        ``table`` is a ``gr.Dataframe`` holding the result, or an empty frame
        when nothing was run or the query failed.
    """
    sql, disclaimerOutputStripping = preProcessSQL(sql=sql)
    if not isDataQuery(sql):
        # Two outputs are bound to this handler, so return two values here too.
        return "Sorry not allowed to run. As the query modifies the data.", gr.Dataframe(pd.DataFrame())

    dbEngine2 = None
    try:
        dbEngine2 = DbEngine(dbCreds)
        dbEngine2.connect()
        conn = dbEngine2.getConnection()
        df = pd.read_sql_query(sql, con=conn)
    except Exception as e:
        errorMessage = {"function":"testSQL","error":str(e), "userInput":sql}
        saveLog(errorMessage, 'error')
        print(f"Error occured during running the query {sql}.\n and the error is {str(e)}")
        return f"The query you entered throws some error. Here is the error.\n {str(e)}", gr.Dataframe(pd.DataFrame())
    finally:
        # dbEngine2 is still None when the constructor itself raised.
        if dbEngine2 is not None:
            dbEngine2.disconnect()

    return disclaimerOutputStripping, gr.Dataframe(df)
def onSelectedTablesChange(tablesSelected):
    """Rebuild the per-table column dropdowns after the table selection changes.

    One dropdown is produced for every table known to the metadata layout; it is
    visible only when its table is in ``tablesSelected`` and is pre-filled with
    the columns currently selected for that table, if any.

    Args:
        tablesSelected: table names chosen in the "Selected Tables" dropdown.

    Returns:
        List of ``gr.Dropdown`` components, one per table.
    """
    print(f"Selected tables : {tablesSelected}")
    layout = queryHelperCot.getMetadata()
    columnsByTable = layout.getAllTablesCols()
    chosenColumns = layout.getSelectedTablesAndCols()

    return [
        gr.Dropdown(
            tableColumns,
            visible=tableName in tablesSelected,
            value=chosenColumns.get(tableName, None),
            multiselect=True,
            label=tableName,
            info="Select columns of a table",
        )
        for tableName, tableColumns in columnsByTable.items()
    ]
def onSelectedColumnsChange(*tableBoxes):
    """Apply the column choices made in the per-table dropdowns.

    The dropdown values are paired positionally with the table names of the
    metadata layout. Tables whose value is an empty list, or not a list at all,
    are left out of the new selection.

    Args:
        *tableBoxes: one value per table dropdown, in table order.

    Returns:
        Status text for the result box.
    """
    metadataLayout = queryHelperCot.getMetadata()
    allTablesList = list(metadataLayout.getAllTablesCols().keys())

    print("Getting selected tables and columns from gradio")
    tablesAndCols = {
        table: columns
        for columns, table in zip(tableBoxes, allTablesList)
        if isinstance(columns, list) and columns
    }

    metadataLayout.setSelection(tablesAndCols=tablesAndCols)
    print("metadata updated")
    print("Updating queryHelperCot state, and sample data")
    queryHelperCot.updateMetadata(metadataLayout)
    return "Columns updated"
def onResetToDefaultSelection():
    """Restore the default table/column selection and rebuild the dropdowns.

    Returns:
        List of ``gr.Dropdown`` components, one per table of the metadata layout;
        a dropdown is visible only when its table is part of the restored
        selection.
    """
    metadataLayout = queryHelperCot.getMetadata()
    # Reset to the same defaults that are applied at start-up. The module-level
    # ``tablesAndCols`` holds every table and column of the schema, so using it
    # here selected everything instead of the default subset.
    metadataLayout.setSelection(tablesAndCols=DEFAULT_TABLES_COLS)
    queryHelperCot.updateMetadata(metadataLayout)

    # Read the layout back from the helper after the update.
    metadataLayout = queryHelperCot.getMetadata()
    allTablesAndCols = metadataLayout.getAllTablesCols()
    selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()

    return [
        gr.Dropdown(
            tableColumns,
            visible=tableName in selectedTablesAndCols,
            value=selectedTablesAndCols.get(tableName, None),
            multiselect=True,
            label=tableName,
            info="Select columns of a table",
        )
        for tableName, tableColumns in allTablesAndCols.items()
    ]
def onSyncLogsWithDataDir():
    """Return a multi-file ``gr.File`` component listing the current log files."""
    return gr.File(getAllLogFilesPaths(), file_count='multiple')
# Gradio UI: four tabs (chat, run query, setup, log files) wired to the handlers above.
with gr.Blocks() as demo:
    # State components holding the user name and the verbose chat history.
    loggedUser = gr.State([])
    verboseChatHistory = gr.State([])

    # screen 1 : chat with the query helper
    with gr.Tab("Query Helper"):
        gr.Markdown("""<h1><center> Query Helper</center></h1>""")
        chatbot = gr.Chatbot()
        msg = gr.Textbox()

        def clearChatHistory():
            """Return an empty list, used as the new verbose chat history."""
            return []

        clear = gr.ClearButton([msg, chatbot, verboseChatHistory], value="Clear Chat")
        clearButton = gr.Button("Clear Context")
        clearButton.click(clearChatHistory, inputs=None, outputs=[verboseChatHistory])

        # respondCoT takes and returns the same four components.
        msg.submit(
            respondCoT,
            [msg, chatbot, verboseChatHistory, loggedUser],
            [msg, chatbot, verboseChatHistory, loggedUser],
        )

    # screen 2 : To run sql query against database
    with gr.Tab("Run Query"):
        gr.Markdown("""<h1><center> Run Query </center></h1>""")
        text_input = gr.Textbox(label = 'Input SQL Query', placeholder="Write your SQL query here ...")
        text_button = gr.Button("RUN QUERY")
        text_output = gr.Textbox(label = 'Result')
        table_output = gr.Dataframe(pd.DataFrame())
        clear = gr.ClearButton([text_input, text_output])
        text_button.click(testSQL, inputs=text_input, outputs=[text_output, table_output])

        csvFileComponent = gr.File([], file_count='multiple')
        downloadCsv = gr.Button("Generate csv result file")
        downloadCsv.click(onGetResultCsvFile, inputs=text_input, outputs=csvFileComponent)

    # screen 3 : To set creds, schema, tables and columns
    with gr.Tab("Setup"):
        gr.Markdown("""<h1><center> Setup Tab </center></h1>""")
        # Rebinds text_input; the handlers of the "Run Query" tab are already wired.
        text_input = gr.Textbox(label = 'schema name', value= SCHEMA_NAME)

        allTablesAndCols = queryHelperCot.getMetadata().getAllTablesCols()
        selectedTablesAndCols = queryHelperCot.getMetadata().getSelectedTablesAndCols()
        allTablesList = list(allTablesAndCols.keys())
        selectedTablesList = list(selectedTablesAndCols.keys())

        dropDown = gr.Dropdown(
            allTablesList,
            value=selectedTablesList,
            multiselect=True,
            label="Selected Tables",
            info="Select Tables from available tables of the schema",
        )
        refreshTables = gr.Button("Refresh selected tables")

        # One column dropdown per table; only the selected tables start out
        # visible and pre-filled.
        tableBoxes = []
        for tableName in allTablesList:
            isSelected = tableName in selectedTablesList
            columnsDropDown = gr.Dropdown(
                allTablesAndCols[tableName],
                visible=isSelected,
                value=selectedTablesAndCols.get(tableName, None) if isSelected else None,
                multiselect=True,
                label=tableName,
                info="Select columns of a table",
            )
            tableBoxes.append(columnsDropDown)

        refreshTables.click(onSelectedTablesChange, inputs=dropDown, outputs=tableBoxes)

        columnsTextBox = gr.Textbox(label = 'Result')
        refreshColumns = gr.Button("Refresh selected columns and Reload Data")
        refreshColumns.click(onSelectedColumnsChange, inputs=tableBoxes, outputs=columnsTextBox)

        resetToDefaultSelection = gr.Button("Reset to Default")
        resetToDefaultSelection.click(onResetToDefaultSelection, inputs=None, outputs=tableBoxes)

    #screen 4 for downloading logs
    with gr.Tab("Log-files"):
        downloadableFilesPaths = getAllLogFilesPaths()
        fileComponent = gr.File(downloadableFilesPaths, file_count='multiple')
        refreshLogs = gr.Button("Sync Log files from /data")
        refreshLogs.click(onSyncLogsWithDataDir, inputs=None, outputs=fileComponent)

print("Ready to launch...")
demo.launch(share=True, debug=True, ssl_verify=False, auth=checkAuth)