import logging
import re
from typing import Optional

from gptManager import ChatgptManager
# NOTE(review): this star import supplies at least MetaDataLayout and
# preProcessGptQueryReponse; consider replacing it with explicit names once
# the full set of names used from utils is confirmed.
from utils import *

logger = logging.getLogger(__name__)

# Characters that can continue a SQL identifier. A table/column name only
# counts as "mentioned" when it is not glued to one of these on either side,
# so e.g. column "id" is not matched inside "customer_id".
_IDENT_CHARS = "A-Za-z0-9_"

# Placeholder appended after each table entry while building a prompt; it is
# expanded to a space only after whitespace has been collapsed.
_TABLE_SEPARATOR = "XXXX"


def _containsIdentifier(text, identifier):
    """Return True if *identifier* appears in *text* as a whole token.

    Unlike a plain ``identifier in text`` substring test, a match must not be
    directly preceded or followed by a letter, digit or underscore.
    """
    escaped = re.escape(str(identifier))
    pattern = rf"(?<![{_IDENT_CHARS}]){escaped}(?![{_IDENT_CHARS}])"
    return re.search(pattern, text) is not None


def _normalizePrompt(prompt):
    """Flatten a prompt onto a single line before sending it to the model.

    Newlines and backslashes become spaces, and runs of spaces are collapsed
    to one. The table separator placeholder is expanded to a space last, so
    consecutive table entries always stay separated.
    """
    prompt = prompt.replace("\n", " ").replace("\\", " ")
    prompt = re.sub(r" {2,}", " ", prompt)
    return prompt.replace(_TABLE_SEPARATOR, " ")


class QueryHelper:
    """Generate and repair SQL queries with a GPT model.

    Sample rows of the tables/columns selected in a ``MetaDataLayout`` are
    fetched once per metadata change and used as prompt context.
    """

    def __init__(self, gptInstance: ChatgptManager, dbEngine, schemaName, platform,
                 metadataLayout: MetaDataLayout, sampleDataRows, gptSampleRows,
                 getSampleDataForTablesAndCols):
        """Store configuration and fetch the initial sample data.

        Args:
            gptInstance: client used for every model call.
            dbEngine: engine passed through to ``getSampleDataForTablesAndCols``.
            schemaName: schema name quoted in prompts and used for sampling.
            platform: SQL platform name quoted in prompts.
            metadataLayout: source of the currently selected tables/columns.
            sampleDataRows: ``maxRows`` passed to the sampling callback.
            gptSampleRows: number of rows per table included in the
                query-generation prompt.
            getSampleDataForTablesAndCols: callable invoked with keyword
                arguments ``dbEngine``, ``schemaName``, ``tablesAndCols`` and
                ``maxRows``. It presumably returns a dict of table name ->
                pandas DataFrame (``.head()`` and column selection are used on
                the values) — TODO confirm.
        """
        self.gptInstance = gptInstance
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Re-fetch sample data for the currently selected tables and columns."""
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=selectedTablesAndCols,
            maxRows=self.sampleDataRows,
        )

    def getMetadata(self) -> MetaDataLayout:
        """Return the current metadata layout."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and refresh the cached sample data."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def modifySqlQueryEnteredByUser(self, userSqlQuery):
        """Ask the model to correct a user-written SQL query for this platform.

        Returns:
            The model's raw response text.
        """
        platform = self.platform
        userPrompt = f"Please correct the following sql query, also it has to be run on {platform}. sql query is \n {userSqlQuery}."
        systemPrompt = ""
        return self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)

    def filteredSampleDataForProspects(self, prospectTablesAndCols):
        """Restrict the cached sample data to the prospective tables/columns.

        Args:
            prospectTablesAndCols: mapping of table name -> list of column names.

        Returns:
            dict of table name -> ``sampleData[table][columns]``. Tables with no
            cached sample data are skipped with a warning. This can happen when
            the layout's selection was changed without calling
            ``updateMetadata``.
        """
        sampleData = self.sampleData
        filteredData = {}
        for table, columns in prospectTablesAndCols.items():
            if table not in sampleData:
                logger.warning("no sample data cached for table %s; skipping", table)
                continue
            filteredData[table] = sampleData[table][columns]
        return filteredData

    def getQueryForUserInput(self, userInput, chatHistory: Optional[list] = None):
        """Generate a SQL query for a natural-language request.

        This makes two model calls. The first narrows the selected
        tables/columns to likely candidates. The second generates the query
        from their sample data.

        Args:
            userInput: the user's natural-language request.
            chatHistory: prior conversation passed to the model. Defaults to
                an empty history (a fresh list per call).

        Returns:
            Tuple ``(query, prospectTablesAndCols)``. ``query`` is the model
            output after ``preProcessGptQueryReponse``.
        """
        if chatHistory is None:
            chatHistory = []
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols, chatHistory)
        logger.debug("getting prospects %s", prospectTablesAndCols)
        prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
        systemPrompt = self.getSystemPromptForQueryGeneration(prospectTablesData, gptSampleRows=self.gptSampleRows)
        queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPrompt, chatHistory)
        queryByGpt = preProcessGptQueryReponse(queryByGpt, metadataLayout=self.metadataLayout)
        return queryByGpt, prospectTablesAndCols

    def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory: Optional[list] = None):
        """Ask the model which selected tables/columns are relevant to *userInput*.

        A table or column is kept when its name appears in the model's response
        as a whole identifier, not merely as a substring of a longer name.

        Args:
            userInput: the user's natural-language request.
            selectedTablesAndCols: mapping of table name -> list of column names.
            chatHistory: prior conversation passed to the model. Defaults to
                an empty history (a fresh list per call).

        Returns:
            dict of mentioned table -> list of its mentioned columns. The list
            may be empty.
        """
        if chatHistory is None:
            chatHistory = []
        systemPrompt = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
        responseText = self.gptInstance.getResponseForUserInput(userInput, systemPrompt, chatHistory)
        prospectTablesAndCols = {}
        for table, columns in selectedTablesAndCols.items():
            if not _containsIdentifier(responseText, table):
                continue
            prospectTablesAndCols[table] = [
                column for column in columns if _containsIdentifier(responseText, column)
            ]
        return prospectTablesAndCols

    def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
        """Build the system prompt used to generate a SQL query.

        The prompt contains one worked example, followed by the first
        ``gptSampleRows`` rows of each prospective table.
        """
        schemaName = self.schemaName
        platform = self.platform
        exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count FROM lpdatamart.tbl_f_sales a JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023 GROUP BY a.customer_id ORDER BY chandelier_count DESC"""
        # The example query has no row limit, so the question must not ask for
        # "top 5". A mismatched example teaches the model to drop requested
        # limits. The question is reworded rather than adding LIMIT/TOP,
        # because that syntax depends on the platform.
        question = "customers ranked by how many chandeliers they bought in nov 2023"
        parts = [f"""Given an input text, generate the corresponding SQL query for given details. Schema Name is {schemaName}. And sql platform is {platform}.\n following is sample data. Also There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery} """]
        for tableName, tableData in prospectTablesData.items():
            parts.append(f"table name is {tableName}, table data is {tableData.head(gptSampleRows)}")
            parts.append(_TABLE_SEPARATOR)
        return _normalizePrompt("".join(parts))

    def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
        """Build the system prompt asking which tables/columns hold the data.

        Every selected table is listed together with its column names.
        """
        schemaName = self.schemaName
        platform = self.platform
        parts = [f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {schemaName}. And sql platform is {platform}.\n"""]
        for tableName, columns in selectedTablesAndCols.items():
            parts.append(f"table name {tableName} {', '.join(columns)}")
            parts.append(_TABLE_SEPARATOR)
        return _normalizePrompt("".join(parts))