Spaces:
Runtime error
Runtime error
| from openai import OpenAI | |
| import pandas as pd | |
| import psycopg2 | |
| import time | |
| import gradio as gr | |
| import sqlparse | |
| import re | |
| import os | |
| import warnings | |
| from persistStorage import saveLog | |
| from config import * | |
| from constants import * | |
| from utils import * | |
| from gptManager import ChatgptManager | |
| # from queryHelper import QueryHelper | |
# Sample DataFrames are embedded into GPT prompts as text (see
# QueryHelper.getSystemPromptForQueryGeneration), so lift pandas' display
# limits to keep columns/rows from being elided in that rendering.
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)

# Filter out all warning messages
warnings.filterwarnings("ignore")

# Open the database connection from the configured credentials.
dbCreds = DataWrapper(DB_CREDS_DATA)
dbEngine = DbEngine(dbCreds)
dbEngine.connect()

# Discover every table/column of the schema, then apply the default selection.
tablesAndCols = getAllTablesInfo(dbEngine, SCHEMA_NAME)
metadataLayout = MetaDataLayout(
    schemaName=SCHEMA_NAME,
    allTablesAndCols=tablesAndCols,
)
metadataLayout.setSelection(DEFAULT_TABLES_COLS)
selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()
def getSampleDataForTablesAndCols(dbEngine, schemaName, tablesAndCols, maxRows):
    """Read up to ``maxRows`` sample rows for each table.

    Args:
        dbEngine: object whose ``connection`` attribute is handed to
            ``pd.read_sql_query``.
        schemaName: schema that qualifies each table name.
        tablesAndCols: mapping (or iterable) of table names; only the table
            names are used, every column is read.
        maxRows: value placed in the ``limit`` clause.

    Returns:
        dict mapping table name -> DataFrame. Tables that could not be read
        are reported on stdout and left out of the result (best effort).
    """
    data = {}
    conn = dbEngine.connection
    for table in tablesAndCols:
        # Identifiers cannot be bound as query parameters, so schema/table are
        # interpolated directly; they must come from trusted metadata.
        sqlQuery = f"""select * from {schemaName}.{table} limit {maxRows}"""
        try:
            data[table] = pd.read_sql_query(sqlQuery, con=conn)
        except Exception as e:
            print(e)
            print(f"couldn't read table data. Table: {table}")
            # A failed statement can leave a DB-API connection (e.g. psycopg2)
            # inside an aborted transaction, which would make every later
            # table fail too. Roll back when the connection supports it.
            rollback = getattr(conn, "rollback", None)
            if callable(rollback):
                try:
                    rollback()
                except Exception as rollbackError:
                    print(f"rollback failed: {rollbackError}")
    return data
class QueryHelper:
    """Generate SQL for natural-language questions through a GPT wrapper.

    Keeps the current table/column selection (``metadataLayout``), a cache of
    sample rows for the selected tables, and builds the system prompts that
    are sent to ``gptInstance``.
    """

    def __init__(self, gptInstance, dbEngine, schemaName,
                 platform, metadataLayout, sampleDataRows,
                 gptSampleRows, getSampleDataForTablesAndCols):
        """Store collaborators and load sample data for the current selection.

        Args:
            gptInstance: object exposing
                ``getResponseForUserInput(userPrompt, systemPrompt)``.
            dbEngine: passed through to ``getSampleDataForTablesAndCols``.
            schemaName: schema name interpolated into the prompts.
            platform: SQL platform name interpolated into the prompts.
            metadataLayout: object exposing ``getSelectedTablesAndCols()``.
            sampleDataRows: ``maxRows`` used when loading sample data.
            gptSampleRows: number of rows per table (``head``) put in prompts.
            getSampleDataForTablesAndCols: callable invoked with keywords
                ``dbEngine``, ``schemaName``, ``tablesAndCols``, ``maxRows``
                that returns a mapping of table name -> sample data.
        """
        self.gptInstance = gptInstance
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Reload ``self.sampleData`` for the currently selected tables."""
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=selectedTablesAndCols,
            maxRows=self.sampleDataRows,
        )

    def getMetadata(self):
        """Return the metadata layout currently in use."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and refresh the cached sample data."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def modifySqlQueryEnteredByUser(self, userSqlQuery):
        """Ask GPT for a corrected version of ``userSqlQuery`` and return its reply."""
        platform = self.platform
        userPrompt = f"Please correct the following sql query, also it has to be run on {platform}. sql query is \n {userSqlQuery}."
        systemPrompt = ""
        return self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)

    def filteredSampleDataForProspects(self, prospectTablesAndCols):
        """Return cached sample data for the prospect tables.

        All columns of each prospect table are kept; the per-table column
        lists in ``prospectTablesAndCols`` are not applied. Tables that have
        no cached sample data (for example because loading them failed) are
        skipped instead of raising ``KeyError``.
        """
        sampleData = self.sampleData
        return {
            table: sampleData[table]
            for table in prospectTablesAndCols
            if table in sampleData
        }

    def getQueryForUserInput(self, userInput):
        """Return GPT's SQL answer for ``userInput``.

        First asks GPT which selected tables look relevant, then sends the
        sample data of those tables as context for the query generation.
        """
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols)
        print("getting prospects", prospectTablesAndCols)
        prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
        systemPromptForQueryGeneration = self.getSystemPromptForQueryGeneration(
            prospectTablesData, gptSampleRows=self.gptSampleRows
        )
        return self.gptInstance.getResponseForUserInput(userInput, systemPromptForQueryGeneration)

    @staticmethod
    def _mentions(name, text):
        """Return True when ``name`` occurs in ``text`` as a whole identifier.

        A plain substring test would report ``user`` as present in
        ``user_orders`` and ``id`` as present in ``order_id``.
        """
        pattern = rf"(?<!\w){re.escape(str(name))}(?!\w)"
        return re.search(pattern, text) is not None

    def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols):
        """Ask GPT which selected tables/columns are relevant to ``userInput``.

        The reply is scanned as free text: a table is kept when its name
        appears in the reply, and for kept tables each column whose name
        appears anywhere in the reply is kept.

        Returns:
            dict mapping table name -> list of matched column names.
        """
        systemPromptForProspectColumns = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
        replyText = self.gptInstance.getResponseForUserInput(userInput, systemPromptForProspectColumns) or ""
        prospectTablesAndCols = {}
        for table, columns in selectedTablesAndCols.items():
            if self._mentions(table, replyText):
                prospectTablesAndCols[table] = [
                    column for column in columns if self._mentions(column, replyText)
                ]
        return prospectTablesAndCols

    @staticmethod
    def _flattenPrompt(prompt):
        """Turn newlines, backslashes and the XXXX separator into spaces."""
        # NOTE(review): the single-space -> single-space replace is a no-op as
        # written; presumably it was meant to collapse double spaces - confirm.
        return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")

    def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
        """Build the system prompt holding sample rows of the prospect tables.

        Args:
            prospectTablesData: mapping table name -> object with ``head(n)``
                (DataFrames in this file).
            gptSampleRows: number of rows per table to include.
        """
        schemaName = self.schemaName
        platform = self.platform
        parts = [f"""Given an input text, generate the corresponding SQL query for given details. Schema Name is {schemaName}. And sql platform is {platform}.\n following is sample data"""]
        for tableName, tableData in prospectTablesData.items():
            parts.append(f"table name is {tableName}, table data is {tableData.head(gptSampleRows)}")
            parts.append("XXXX")
        return self._flattenPrompt("".join(parts))

    def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
        """Build the system prompt listing every selected table with its columns."""
        schemaName = self.schemaName
        platform = self.platform
        parts = [f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {schemaName}. And sql platform is {platform}.\n"""]
        for tableName, columns in selectedTablesAndCols.items():
            parts.append(f"table name {tableName} {', '.join(columns)}")
            parts.append("XXXX")
        return self._flattenPrompt("".join(parts))
# Build the OpenAI-backed GPT manager, then the shared QueryHelper that the
# Gradio callbacks below use.
openAIClient = OpenAI(api_key=OPENAI_API_KEY)
gptInstance = ChatgptManager(openAIClient, model=GPT_MODEL)
queryHelper = QueryHelper(
    gptInstance=gptInstance,
    dbEngine=dbEngine,
    schemaName=SCHEMA_NAME,
    platform=PLATFORM,
    metadataLayout=metadataLayout,
    sampleDataRows=SAMPLE_ROW_MAX,
    gptSampleRows=GPT_SAMPLE_ROWS,
    getSampleDataForTablesAndCols=getSampleDataForTablesAndCols,
)
def checkAuth(username, password):
    """Return True when ``username``/``password`` match ``ADMIN``/``PASSWD``.

    Used as the Gradio ``auth`` callback. Comparisons go through
    ``hmac.compare_digest`` so the time taken does not reveal how many leading
    characters of a guess were correct.
    """
    import hmac

    values = (username, password, ADMIN, PASSWD)
    if not all(isinstance(value, str) for value in values):
        # A non-string can never equal the configured string credentials.
        return False
    # Encode to bytes: compare_digest rejects non-ASCII str arguments.
    userMatches = hmac.compare_digest(username.encode("utf-8"), ADMIN.encode("utf-8"))
    passwordMatches = hmac.compare_digest(password.encode("utf-8"), PASSWD.encode("utf-8"))
    # Both comparisons are evaluated before combining (no short-circuit).
    return userMatches and passwordMatches
# Chat handler: generates SQL for the user's message and appends it to the chat history
def respond(message, chatHistory):
    """gpt response handler for gradio ui

    Args:
        message: text the user submitted.
        chatHistory: list of ``(userMessage, botMessage)`` tuples shown in the
            chatbot; the new exchange is appended to it.

    Returns:
        ``("", chatHistory)`` - the empty string clears the input textbox.
    """
    if chatHistory is None:
        chatHistory = []
    try:
        botMessage = queryHelper.getQueryForUserInput(message)
    except Exception as e:
        errorMessage = {"function":"queryHelper.getQueryForUserInput","error":str(e), "userInput":message}
        saveLog(errorMessage, 'error')
        # Without a fallback, botMessage would be unbound below and the
        # handler itself would crash with UnboundLocalError.
        botMessage = "Sorry, something went wrong while generating the query. Please try again."
    else:
        logMessage = {"userInput":message, "queryGenerated":botMessage}
        saveLog(logMessage)
    chatHistory.append((message, botMessage))
    # NOTE(review): fixed 2 second pause kept from the original; its purpose
    # is not visible in this file - confirm it is still needed.
    time.sleep(2)
    return "", chatHistory
# Guard used before running user SQL: only read-style statements are allowed
def isDataQuery(sql_query):
    """Return True when ``sql_query`` contains no data/schema-modifying keyword.

    The check is a case-insensitive whole-word search, so identifiers such as
    ``update_date`` or ``created_at`` are not flagged. It is deliberately
    conservative: a keyword inside a string literal also rejects the query.
    """
    modifying_keywords = (
        'INSERT', 'UPDATE', 'DELETE', 'MERGE',  # DML
        'DROP', 'TRUNCATE', 'ALTER', 'CREATE',  # DDL
        'GRANT', 'REVOKE',                      # privileges
    )
    pattern = r'\b(?:' + '|'.join(modifying_keywords) + r')\b'
    upper_query = (sql_query or '').upper()
    # No modifying keyword found -> it's likely a data query.
    return re.search(pattern, upper_query) is None
def _ensureRowLimit(sql, defaultLimit=5):
    """Append ``limit <defaultLimit>`` unless ``sql`` already ends with a LIMIT clause."""
    # Anchored at the end of the statement so that the word "limit" inside a
    # literal or identifier is not mistaken for a clause, while trailing
    # whitespace and "LIMIT n OFFSET m" are still recognised.
    hasLimit = re.search(
        r"\blimit\s+(?:\d+|all)(?:\s+offset\s+\d+)?\s*$", sql, flags=re.IGNORECASE
    )
    if hasLimit:
        return sql
    return f"{sql} limit {defaultLimit}"


def testSQL(sql):
    """Run a read-only SQL query and return its result.

    Semicolons are stripped, a ``limit 5`` is added when the query has no
    trailing LIMIT clause, and queries rejected by ``isDataQuery`` are not run.

    Returns:
        The result DataFrame on success; otherwise a message string (refusal,
        or a GPT-suggested corrected query when execution failed).
    """
    sql = (sql or "").replace(';', '').strip()
    if not sql:
        return "Please enter a SQL query to run."
    if not isDataQuery(sql):
        return "Sorry not allowed to run. As the query modifies the data."
    sql = _ensureRowLimit(sql)
    sql = sqlparse.format(sql, reindent=True, keyword_case='upper')
    print(sql)
    conn = None
    try:
        conn = dbEngine.connection
        return pd.read_sql_query(sql, con=conn)
    except Exception as e:
        # A failed statement can leave a DB-API connection (e.g. psycopg2) in
        # an aborted transaction; roll back so later queries can still run.
        rollback = getattr(conn, "rollback", None)
        if callable(rollback):
            try:
                rollback()
            except Exception as rollbackError:
                print(f"rollback failed: {rollbackError}")
        errorMessage = {"function":"testSQL","error":str(e), "userInput":sql}
        saveLog(errorMessage, 'error')
        print(f"Error occured during running the query {sql}.\n and the error is {str(e)}")
        # NOTE(review): modifySqlQueryEnteredByUser wraps its argument in the
        # same "Please correct..." instruction again - confirm that is intended.
        prompt = f"Please correct the following sql query, also it has to be run on {PLATFORM}. sql query is \n {sql}. the error occured is {str(e)}."
        modifiedSql = queryHelper.modifySqlQueryEnteredByUser(prompt)
        return f"The query you entered throws some error. Here is modified version. Please try this.\n {modifiedSql}"
def onSelectedTablesChange(tablesSelected):
    """Show the column dropdown of every selected table and hide the others.

    Args:
        tablesSelected: table names chosen in the "Selected Tables" dropdown.

    Returns:
        One update per table, in the order of ``getAllTablesCols()``, matching
        the per-table column dropdowns wired as outputs.
    """
    print(f"Selected tables : {tablesSelected}")
    selected = set(tablesSelected or [])
    allTables = list(queryHelper.getMetadata().getAllTablesCols())
    updates = []
    for table in allTables:
        if table in selected:
            # Only toggle visibility so columns already picked are kept.
            updates.append(gr.update(visible=True))
        else:
            # Clear hidden dropdowns so onSelectedColumnsChange skips them.
            updates.append(gr.update(visible=False, value=[]))
    return updates
def onSelectedColumnsChange(*tableBoxes):
    """Apply the column selection from the UI to the metadata and reload data.

    Args:
        *tableBoxes: one value per table, in the order of
            ``getAllTablesCols()``; only non-empty lists count as a selection.

    Returns:
        Status text for the result textbox.
    """
    metadataLayout = queryHelper.getMetadata()
    allTablesList = list(metadataLayout.getAllTablesCols().keys())
    print("Getting selected tables and columns from gradio")
    # update selection of columns and tables (include new tables and cols in gpts context)
    tablesAndCols = {
        table: columns
        for columns, table in zip(tableBoxes, allTablesList)
        if isinstance(columns, list) and columns
    }
    metadataLayout.setSelection(tablesAndCols=tablesAndCols)
    print("metadata updated")
    print("Updating queryHelper state, and sample data")
    queryHelper.updateMetadata(metadataLayout)
    return "Columns updated"
def onResetToDefaultSelection():
    """Restore the default table/column selection and refresh the dropdowns.

    Returns:
        One update per table, in the order of ``getAllTablesCols()``: tables
        in the restored selection are shown with their selected columns, all
        others are hidden and cleared.
    """
    metadataLayout = queryHelper.getMetadata()
    metadataLayout.resetSelection()
    # DEFAULT_TABLES_COLS is the same constant the module-level setup applies.
    metadataLayout.setSelection(DEFAULT_TABLES_COLS)
    queryHelper.updateMetadata(metadataLayout)
    selection = metadataLayout.getSelectedTablesAndCols()
    allTablesList = list(metadataLayout.getAllTablesCols().keys())
    return [
        gr.update(visible=table in selection, value=selection.get(table, []))
        for table in allTablesList
    ]
with gr.Blocks() as demo:
    # screen 1 : Chatbot for question answering to generate sql query from user input in english
    with gr.Tab("Query Helper"):
        gr.Markdown("""<h1><center> Query Helper</center></h1>""")
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.ClearButton([msg, chatbot])
        msg.submit(respond, [msg, chatbot], [msg, chatbot])

    # screen 2 : To run sql query against database
    with gr.Tab("Run Query"):
        gr.Markdown("""<h1><center> Run Query </center></h1>""")
        text_input = gr.Textbox(label = 'Input SQL Query', placeholder="Write your SQL query here ...")
        text_output = gr.Textbox(label = 'Result')
        text_button = gr.Button("RUN QUERY")
        clear = gr.ClearButton([text_input, text_output])
        text_button.click(testSQL, inputs=text_input, outputs=text_output)

    # screen 3 : To set creds, schema, tables and columns
    with gr.Tab("Setup"):
        # Header previously repeated "Run Query" from the tab above.
        gr.Markdown("""<h1><center> Setup </center></h1>""")
        text_input = gr.Textbox(label = 'schema name', value= SCHEMA_NAME)
        allTablesAndCols = queryHelper.getMetadata().getAllTablesCols()
        selectedTablesAndCols = queryHelper.getMetadata().getSelectedTablesAndCols()
        allTablesList = list(allTablesAndCols.keys())
        selectedTablesList = list(selectedTablesAndCols.keys())
        dropDown = gr.Dropdown(
            allTablesList, value=selectedTablesList, multiselect=True, label="Selected Tables", info="Select Tables from available tables of the schema"
        )
        refreshTables = gr.Button("Refresh selected tables")

        # One column dropdown per table, in allTablesList order (the callbacks
        # rely on this order). Unselected tables start hidden with no value:
        # .get() yields None for them because selectedTablesList is exactly
        # the key set of selectedTablesAndCols.
        tableBoxes = []
        for tableName in allTablesList:
            columnsDropDown = gr.Dropdown(
                allTablesAndCols[tableName],
                visible=tableName in selectedTablesList,
                value=selectedTablesAndCols.get(tableName, None),
                multiselect=True,
                label=tableName,
                info="Select columns of a table",
            )
            tableBoxes.append(columnsDropDown)

        refreshTables.click(onSelectedTablesChange, inputs=dropDown, outputs=tableBoxes)
        columnsTextBox = gr.Textbox(label = 'Result')
        refreshColumns = gr.Button("Refresh selected columns and Reload Data")
        refreshColumns.click(onSelectedColumnsChange, inputs=tableBoxes, outputs=columnsTextBox)
        resetToDefaultSelection = gr.Button("Reset to Default")
        resetToDefaultSelection.click(onResetToDefaultSelection, inputs=None, outputs=tableBoxes)

demo.launch(share=True, debug=True, ssl_verify=False, auth=checkAuth)
# NOTE(review): this runs only once launch() returns; reconnecting at that
# point looks unintended (possibly meant to close the connection) - confirm.
dbEngine.connect()