import json

from gptManager import ChatgptManager
from utils import *


class QueryHelperChainOfThought:
    """Build SQL queries from natural-language input with the help of a GPT client.

    The helper keeps a metadata layout (the selected tables and columns), a
    cache of sample data for those tables, and a GPT client. Query generation
    is a two-step process: first the model is asked which tables/columns look
    relevant, then it is asked for the SQL given sample rows of only those
    tables/columns. ``getQueryForUserInputCoT`` adds a chain-of-thought layer
    that splits the request into subtasks and merges the per-task queries.
    """

    # Few-shot example embedded into the query-generation system prompt.
    _EXAMPLE_QUESTION = "top 5 customers who bought most chandeliers in nov 2023"
    _EXAMPLE_QUERY = (
        "SELECT a.customer_id, COUNT(a.product_id) as chandelier_count "
        "FROM lpdatamart.tbl_f_sales a "
        "JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id "
        "JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id "
        "WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' "
        "AND c.calendar_month = 'NOVEMBER' AND c.year = 2023 "
        "GROUP BY a.customer_id ORDER BY chandelier_count DESC"
    )

    def __init__(self, gptInstance: ChatgptManager, dbEngine, schemaName, platform,
                 metadataLayout: MetaDataLayout, sampleDataRows, gptSampleRows,
                 getSampleDataForTablesAndCols):
        """Store the collaborators and load sample data for the current metadata.

        Args:
            gptInstance: client exposing ``getResponseForUserInput``.
            dbEngine: database handle forwarded to ``getSampleDataForTablesAndCols``.
            schemaName: schema name inserted into the prompts.
            platform: SQL platform name inserted into the prompts.
            metadataLayout: object exposing ``getSelectedTablesAndCols()``.
            sampleDataRows: ``maxRows`` passed to the sample-data loader.
            gptSampleRows: number of rows (``.head(n)``) shown to the model.
            getSampleDataForTablesAndCols: callable accepting the keyword
                arguments ``dbEngine``, ``schemaName``, ``tablesAndCols`` and
                ``maxRows`` and returning a mapping of table name to data.
        """
        self.gptInstance = gptInstance
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Reload ``self.sampleData`` for the currently selected tables and columns."""
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=selectedTablesAndCols,
            maxRows=self.sampleDataRows,
        )

    def getMetadata(self) -> MetaDataLayout:
        """Return the metadata layout currently in use."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and refresh the cached sample data."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def modifySqlQueryEnteredByUser(self, userSqlQuery):
        """Ask the model to correct a user-written SQL query for ``self.platform``.

        Returns:
            Whatever ``gptInstance.getResponseForUserInput`` returns.
        """
        userPrompt = (
            f"Please correct the following sql query, also it has to be run on "
            f"{self.platform}. sql query is \n {userSqlQuery}."
        )
        systemPrompt = ""
        return self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)

    def filteredSampleDataForProspects(self, prospectTablesAndCols):
        """Restrict the cached sample data to the prospect tables and columns.

        Args:
            prospectTablesAndCols: mapping of table name to a list of column
                names.

        Returns:
            dict mapping each table name to ``self.sampleData[table][columns]``.
            (Presumably a pandas DataFrame column selection; the data type is
            determined by the injected sample-data loader.)
        """
        sampleData = self.sampleData
        return {
            table: sampleData[table][columns]
            for table, columns in prospectTablesAndCols.items()
        }

    def extractSingleJson(self, text):
        """Return the first JSON object embedded in ``text``.

        Every ``{`` is tried in order as the start of a JSON document, so
        nested objects are parsed correctly and brace-delimited fragments that
        are not valid JSON are skipped.

        Raises:
            ValueError: if ``text`` contains no parsable JSON object.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                # Not valid JSON from this brace; try the next candidate.
                start = text.find("{", start + 1)
                continue
            return parsed
        raise ValueError("No JSON object found in text")

    def getQueryForUserInputCoT(self, userInput, chatHistory=None):
        """Generate a SQL query through a chain-of-thought decomposition.

        Steps:
            1. Ask the model whether ``userInput`` is detailed enough and, if
               so, to split it into subtasks (returned as a JSON object).
            2. Generate one query per subtask.
            3. Ask the model to combine the subtask queries into one query.

        Args:
            userInput: natural-language request.
            chatHistory: optional history forwarded to the clarification call.

        Returns:
            The combined query text from the model, or a "Please rephrase"
            message (including the model's reason) when the input was judged
            insufficient.

        Raises:
            ValueError: if the model answers "yes" without a parsable task list.
        """
        chatHistory = [] if chatHistory is None else chatHistory

        # 1. Is the input complete enough to create a query, or should the
        #    user re-ask with more detail?
        systemPromptForInputClarification = (
            "Given an input text, user want to generate sql query. "
            "Please answer if the user input is complete or user needs to ask in more detailed way. "
            "Answer in following format. "
            "'Yes' ; if yes, break the userinput into smaller subtask for query generation. \n"
            'Formatted into { "Task 1": "task 1 description", "Task 2": "task 2 description" } '
            "'No' ; if no, then Reason- please be more detailed about customer details; "
            "if more modification needed"
        )
        cotStep1 = self.gptInstance.getResponseForUserInput(
            userInput, systemPromptForInputClarification, chatHistory
        )

        # Only the first few characters are inspected so that a "yes" appearing
        # later in an explanation is not mistaken for the verdict.
        if "yes" not in cotStep1.lower()[:5]:
            reason = " ".join(cotStep1.split("Reason")[1:])
            return f"Please rephrase your query. {reason}"

        print("User input sufficient")
        tasks = self.extractSingleJson(cotStep1)
        print(f"tasks are {tasks}")

        # 2. One query per subtask. getQueryForUserInput returns
        #    (query, prospects); only the query text is needed here.
        taskQueries = {}
        for key, task in tasks.items():
            taskQuery, _prospects = self.getQueryForUserInput(task)
            taskQueries[key] = {"task": task, "taskQuery": taskQuery}
        print(f"tasks and their queries {taskQueries}")

        # 3. Let the model merge the subtask queries into the final answer.
        combiningSubtasksQueryPrompt = (
            "Combine following subtask and their queries to generate sql query "
            "to answer the user input.\n "
        )
        userPrompt = f"user input is {userInput}" + "".join(
            f" task: {entry['task']}, task query: {entry['taskQuery']}"
            for entry in taskQueries.values()
        )
        return self.gptInstance.getResponseForUserInput(
            userPrompt, combiningSubtasksQueryPrompt
        )

    def getQueryForUserInput(self, userInput, chatHistory=None):
        """Generate a SQL query for ``userInput`` in two model calls.

        Args:
            userInput: natural-language request.
            chatHistory: optional history forwarded to both model calls.

        Returns:
            Tuple ``(query, prospectTablesAndCols)`` where ``query`` is the
            model response after ``preProcessGptQueryReponse`` and
            ``prospectTablesAndCols`` maps the tables judged relevant to their
            relevant columns.
        """
        chatHistory = [] if chatHistory is None else chatHistory
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        prospectTablesAndCols = self.getProspectiveTablesAndCols(
            userInput, selectedTablesAndCols, chatHistory
        )
        print("getting prospects", prospectTablesAndCols)

        prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
        systemPromptForQueryGeneration = self.getSystemPromptForQueryGeneration(
            prospectTablesData, gptSampleRows=self.gptSampleRows
        )
        queryByGpt = self.gptInstance.getResponseForUserInput(
            userInput, systemPromptForQueryGeneration, chatHistory
        )
        queryByGpt = preProcessGptQueryReponse(queryByGpt, metadataLayout=self.metadataLayout)
        return queryByGpt, prospectTablesAndCols

    def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory=None):
        """Ask the model which selected tables/columns are relevant to ``userInput``.

        The model response is not parsed as JSON; a table or column counts as
        relevant when its name occurs anywhere in the response text. Columns
        are only considered for tables that were themselves mentioned.

        Returns:
            dict mapping each mentioned table to the list of its mentioned
            columns (possibly empty).
        """
        chatHistory = [] if chatHistory is None else chatHistory
        systemPromptForProspectColumns = self.getSystemPromptForProspectColumns(
            selectedTablesAndCols
        )
        prospectiveTablesColsText = self.gptInstance.getResponseForUserInput(
            userInput, systemPromptForProspectColumns, chatHistory
        )
        return {
            table: [column for column in columns if column in prospectiveTablesColsText]
            for table, columns in selectedTablesAndCols.items()
            if table in prospectiveTablesColsText
        }

    @staticmethod
    def _flattenPrompt(prompt):
        """Turn newlines, backslashes and the ``XXXX`` separator into spaces."""
        # NOTE(review): replace(" ", " ") is a no-op as written; it presumably
        # was meant to collapse double spaces -- confirm against the original.
        return (
            prompt.replace("\n", " ")
            .replace("\\", " ")
            .replace(" ", " ")
            .replace("XXXX", " ")
        )

    def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
        """Build the system prompt used for SQL generation.

        Args:
            prospectTablesData: mapping of table name to an object exposing
                ``.head(n)`` (presumably a pandas DataFrame).
            gptSampleRows: number of rows of each table to embed in the prompt.

        Returns:
            The prompt as a single flattened string containing the schema,
            platform, a worked example and the sample rows of every table.
        """
        prompt = (
            "Given an input text, generate the corresponding SQL query for given details. "
            f"Schema Name is {self.schemaName}. And sql platform is {self.platform}.\n"
            " following is sample data. Also There is example user input and desired "
            "generated sql query. "
            "Follow similar patterns as example. eg case insensitive, explicit variable "
            "declaration etc. "
            f"user input : {self._EXAMPLE_QUESTION}, query : {self._EXAMPLE_QUERY} "
        )
        # "XXXX" is a placeholder separator that _flattenPrompt turns into a space.
        prompt += "".join(
            f"table name is {tableName}, table data is {tableData.head(gptSampleRows)}XXXX"
            for tableName, tableData in prospectTablesData.items()
        )
        return self._flattenPrompt(prompt)

    def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
        """Build the system prompt asking which tables/columns hold the data.

        Args:
            selectedTablesAndCols: mapping of table name to an iterable of
                column-name strings.

        Returns:
            The prompt as a single flattened string listing every table with
            its columns.
        """
        prompt = (
            "Given an input text, User wants to know which all tables and columns would be "
            "possibily to have the desired data. Output them as json. "
            f"Schema Name is {self.schemaName}. And sql platform is {self.platform}.\n"
        )
        # "XXXX" is a placeholder separator that _flattenPrompt turns into a space.
        prompt += "".join(
            f"table name {tableName} {', '.join(columns)}XXXX"
            for tableName, columns in selectedTablesAndCols.items()
        )
        return self._flattenPrompt(prompt)