from openai import OpenAI import pandas as pd import psycopg2 import time import gradio as gr import sqlparse import re import os import warnings from config import * from constants import * from utils import * from gptManager import ChatgptManager # from queryHelper import QueryHelper pd.set_option('display.max_columns', None) pd.set_option('display.max_rows', None) # Filter out all warning messages warnings.filterwarnings("ignore") dbCreds = DataWrapper(DB_CREDS_DATA) dbEngine = DbEngine(dbCreds) dbEngine.connect() tablesAndCols = getAllTablesInfo(dbEngine, SCHEMA_NAME) metadataLayout = MetaDataLayout(schemaName=SCHEMA_NAME, allTablesAndCols=tablesAndCols) metadataLayout.setSelection(DEFAULT_TABLES_COLS) selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols() def getSampleDataForTablesAndCols(dbEngine, schemaName, tablesAndCols, maxRows): data = {} conn = dbEngine.connection for table in tablesAndCols.keys(): try: sqlQuery = f"""select * from {schemaName}.{table} limit {maxRows}""" data[table] = pd.read_sql_query(sqlQuery, con=conn) except Exception as e: print(e) print(f"couldn't read table data. Table: {table}") return data class QueryHelper: def __init__(self, gptInstance, dbEngine, schemaName, platform, metadataLayout, sampleDataRows, gptSampleRows, getSampleDataForTablesAndCols): self.gptInstance = gptInstance self.schemaName = schemaName self.platform = platform self.metadataLayout = metadataLayout self.sampleDataRows = sampleDataRows self.gptSampleRows = gptSampleRows self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols self.dbEngine = dbEngine self._onMetadataChange() def _onMetadataChange(self): metadataLayout = self.metadataLayout sampleDataRows = self.sampleDataRows dbEngine = self.dbEngine schemaName = self.schemaName selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols() self.sampleData = self.getSampleDataForTablesAndCols(dbEngine=dbEngine,schemaName=schemaName, tablesAndCols=selectedTablesAndCols, maxRows=sampleDataRows) def getMetadata(self): return self.metadataLayout def updateMetadata(self, metadataLayout): self.metadataLayout = metadataLayout self._onMetadataChange() def modifySqlQueryEnteredByUser(self, userSqlQuery): platform = self.platform userPrompt = f"Please correct the following sql query, also it has to be run on {platform}. sql query is \n {userSqlQuery}." systemPrompt = "" modifiedSql = self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt) return modifiedSql def filteredSampleDataForProspects(self, prospectTablesAndCols): sampleData = self.sampleData filteredData = {} for table in prospectTablesAndCols.keys(): # filteredData[table] = sampleData[table][prospectTablesAndCols[table]] #take all columns of prospects filteredData[table] = sampleData[table] return filteredData def getQueryForUserInput(self, userInput): gptSampleRows = self.gptSampleRows selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols() prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols) print("getting prospects", prospectTablesAndCols) prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols) systemPromptForQueryGeneration = self.getSystemPromptForQueryGeneration(prospectTablesData, gptSampleRows=gptSampleRows) queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPromptForQueryGeneration) return queryByGpt def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols): schemaName = self.schemaName systemPromptForProspectColumns = self.getSystemPromptForProspectColumns(selectedTablesAndCols) prospectiveTablesColsText = self.gptInstance.getResponseForUserInput(userInput, systemPromptForProspectColumns) prospectTablesAndCols = {} for table in selectedTablesAndCols.keys(): if table in prospectiveTablesColsText: prospectTablesAndCols[table] = [] for column in selectedTablesAndCols[table]: if column in prospectiveTablesColsText: prospectTablesAndCols[table].append(column) return prospectTablesAndCols def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows): schemaName = self.schemaName platform = self.platform prompt = f"""Given an input text, generate the corresponding SQL query for given details. Schema Name is {schemaName}. And sql platform is {platform}.\n following is sample data""" for idx, tableName in enumerate(prospectTablesData.keys(), start=1): prompt += f"table name is {tableName}, table data is {prospectTablesData[tableName].head(gptSampleRows)}" prompt += "XXXX" return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ") def getSystemPromptForProspectColumns(self, selectedTablesAndCols): schemaName = self.schemaName platform = self.platform prompt = f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {schemaName}. And sql platform is {platform}.\n""" for idx, tableName in enumerate(selectedTablesAndCols.keys(), start=1): prompt += f"table name {tableName} {', '.join(selectedTablesAndCols[tableName])}" prompt += "XXXX" return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ") openAIClient = OpenAI(api_key=OPENAI_API_KEY) gptInstance = ChatgptManager(openAIClient, model=GPT_MODEL) queryHelper = QueryHelper(gptInstance=gptInstance, schemaName=SCHEMA_NAME,platform=PLATFORM, metadataLayout=metadataLayout, sampleDataRows=SAMPLE_ROW_MAX, gptSampleRows=GPT_SAMPLE_ROWS, dbEngine=dbEngine, getSampleDataForTablesAndCols=getSampleDataForTablesAndCols) def checkAuth(username, password): global ADMIN, PASSWD if username == ADMIN and password == PASSWD: return True return False # Function to save history of chat def respond(message, chatHistory): """gpt response handler for gradio ui""" global queryHelper botMessage = queryHelper.getQueryForUserInput(message) chatHistory.append((message, botMessage)) time.sleep(2) return "", chatHistory # Function to test the generated sql query def isDataQuery(sql_query): upper_query = sql_query.upper() dml_keywords = ['INSERT', 'UPDATE', 'DELETE', 'MERGE'] for keyword in dml_keywords: if re.search(fr'\b{keyword}\b', upper_query): return False # Found a DML keyword, indicating modification # If no DML keywords are found, it's likely a data query return True def testSQL(sql): global dbEngine, queryHelper sql=sql.replace(';', '') if ('limit' in sql[-15:].lower())==False: sql = sql + ' ' + 'limit 5' sql = str(sql) sql = sqlparse.format(sql, reindent=True, keyword_case='upper') print(sql) if not isDataQuery(sql): return "Sorry not allowed to run. As the query modifies the data." try: conn = dbEngine.connection df = pd.read_sql_query(sql, con=conn) return pd.DataFrame(df) except Exception as e: print(f"Error occured during running the query {sql}.\n and the error is {str(e)}") prompt = f"Please correct the following sql query, also it has to be run on {PLATFORM}. sql query is \n {sql}. the error occured is {str(e)}." modifiedSql = queryHelper.modifySqlQueryEnteredByUser(prompt) return f"The query you entered throws some error. Here is modified version. Please try this.\n {modifiedSql}" def onSelectedTablesChange(tablesSelected): #Updates tables visible and allow selecting columns for them global queryHelper print(f"Selected tables : {tablesSelected}") metadataLayout = queryHelper.getMetadata() allTables = list(metadataLayout.getAllTablesCols()) tableBoxes = [] for i in range(len(allTables)): if allTables[i] in tablesSelected: tableBoxes.append(gr.Textbox(f"Textbox {allTables[i]}", visible=True, label=f"{allTables[i]}")) else: tableBoxes.append(gr.Textbox(f"Textbox {allTables[i]}", visible=False, label=f"{allTables[i]}")) return tableBoxes def onSelectedColumnsChange(*tableBoxes): #update selection of columns and tables (include new tables and cols in gpts context) global queryHelper metadataLayout = queryHelper.getMetadata() allTablesList = list(metadataLayout.getAllTablesCols().keys()) tablesAndCols = {} result = '' print("Getting selected tables and columns from gradio") for tableBox, table in zip(tableBoxes, allTablesList): if isinstance(tableBox, list): if len(tableBox)!=0: tablesAndCols[table] = tableBox else: pass metadataLayout.setSelection(tablesAndCols=tablesAndCols) print("metadata updated") print("Updating queryHelper state, and sample data") queryHelper.updateMetadata(metadataLayout) return "Columns udpated" def onResetToDefaultSelection(): global queryHelper tablesSelected = list(DefaultTablesAndCols.keys()) tableBoxes = [] allTablesList = list(metadataLayout.getAllTablesCols().keys()) for i in range(len(allTablesList)): if allTablesList[i] in tablesSelected: tableBoxes.append(gr.Textbox(f"Textbox {allTablesList[i]}", visible=True, label=f"{allTablesList[i]}")) else: tableBoxes.append(gr.Textbox(f"Textbox {allTablesList[i]}", visible=False, label=f"{allTablesList[i]}")) metadataLayout.resetSelection() metadataLayout.setSelection(DefaultTablesAndCols) queryHelper.updateMetadata(metadataLayout) return tableBoxes with gr.Blocks() as demo: # screen 1 : Chatbot for question answering to generate sql query from user input in english with gr.Tab("Query Helper"): gr.Markdown("""