from gptManager import ChatgptManager
from utils import *
import json
import re
import sqlparse
from constants import TABLE_RELATIONS


class QueryHelper:
    """Turn a natural-language question into SQL using two chat-model instances.

    Step 1: ``gptInstanceForTableCols`` is primed with the selected tables, their
    summaries, their columns and ``TABLE_RELATIONS``, and is asked which tables and
    columns the question needs.
    Step 2: ``gptInstanceForQuery`` is primed with sample rows for just those
    tables/columns plus a worked example, and is asked for the SQL itself.
    """

    def __init__(self, gptInstanceForTableCols: ChatgptManager, gptInstanceForQuery: ChatgptManager,
                 dbEngine, schemaName, platform, metadataLayout: MetaDataLayout, sampleDataRows,
                 gptSampleRows, getSampleDataForTablesAndCols, tableSummaryJson='tableSummaryDict.json'):
        """Store collaborators and build the initial table/column system prompt.

        Args:
            gptInstanceForTableCols: chat instance used to pick relevant tables/columns.
            gptInstanceForQuery: chat instance used to generate the SQL.
            dbEngine: passed through to ``getSampleDataForTablesAndCols``.
            schemaName: passed through to ``getSampleDataForTablesAndCols``.
            platform: stored on the instance; not used by the code in this class.
            metadataLayout: provides ``getSelectedTablesAndCols()``.
            sampleDataRows: ``maxRows`` passed to ``getSampleDataForTablesAndCols``.
            gptSampleRows: number of sample rows (via ``.head``) put into the query prompt.
            getSampleDataForTablesAndCols: callable invoked with keyword arguments
                ``dbEngine``, ``schemaName``, ``tablesAndCols`` and ``maxRows``. It must
                return a mapping of table name -> frame indexable by a list of columns
                (presumably pandas DataFrames — confirm against the caller).
            tableSummaryJson: path to a JSON object mapping table name -> summary text.
        """
        self.gptInstanceForTableCols = gptInstanceForTableCols
        self.gptInstanceForQuery = gptInstanceForQuery
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self.tableSummaryJson = tableSummaryJson
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Re-fetch sample data and refresh the table/column system prompt."""
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=selectedTablesAndCols,
            maxRows=self.sampleDataRows,
        )
        self.promptTableColsInfo = self.getSystemPromptForTableCols()
        self.gptInstanceForTableCols.setSystemPrompt(self.promptTableColsInfo)

    def getMetadata(self) -> MetaDataLayout:
        """Return the current metadata layout."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and rebuild sample data and prompts."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def getQueryForUserInput(self, userInput):
        """Return a SQL query (pretty-printed when possible) answering ``userInput``.

        Asks the table/column model which tables/columns are relevant, then primes
        the query model with samples of only those, which keeps the query prompt small.

        Returns:
            The formatted ``finalResult`` SQL if it could be parsed from the model's
            response, otherwise the raw response text.
        """
        prospectTablesAndColsText = self.gptInstanceForTableCols.getResponseForUserInput(userInput)
        prospectTablesAndCols = self._pickMentionedTablesAndCols(prospectTablesAndColsText)
        print("tables and cols select by gpt", prospectTablesAndCols)

        promptForQuery = self.getSystemPromptForQuery(prospectTablesAndCols)
        self.gptInstanceForQuery.setSystemPrompt(promptForQuery)
        gptResponse = self.gptInstanceForQuery.getResponseForUserInput(userInput)

        parsed, sqlResult = self._extractFinalResult(gptResponse)
        if not parsed:
            return gptResponse
        return self._formatSqlResult(sqlResult)

    @staticmethod
    def _isMentioned(name, text):
        """Return True if ``name`` occurs in ``text`` as a whole identifier.

        Plain substring tests would match e.g. column ``id`` inside ``provided``.
        Lookarounds are used instead of ``\\b`` so names that start or end with a
        non-word character still match.
        """
        pattern = r"(?<!\w)" + re.escape(str(name)) + r"(?!\w)"
        return re.search(pattern, text) is not None

    def _pickMentionedTablesAndCols(self, responseText):
        """Return {table: [cols]} for selected tables/columns named in ``responseText``."""
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        picked = dict()
        for table in selectedTablesAndCols:
            if self._isMentioned(table, responseText):
                picked[table] = [col for col in selectedTablesAndCols[table]
                                 if self._isMentioned(col, responseText)]
        return picked

    @staticmethod
    def _extractFinalResult(gptResponse):
        """Parse ``finalResult`` out of the model response.

        Returns:
            ``(True, value)`` on success, otherwise ``(False, None)``.
        """
        # Method 1: the JSON sits in a ```json fenced block (or before the first fence).
        # Newlines become spaces so SQL tokens split across lines are not glued together.
        try:
            fenced = gptResponse.split("```json")[-1].split("```")[0].replace("\n", " ")
            return True, json.loads(fenced)["finalResult"]
        except (AttributeError, ValueError, KeyError, TypeError):
            print("Couldn't parse desired result from gpt response using method 1.")
        # Method 2: the whole response is the JSON document.
        try:
            return True, json.loads(gptResponse)["finalResult"]
        except (ValueError, KeyError, TypeError):
            print("Couldn't parse desired result from gpt response using method 2")
        return False, None

    @staticmethod
    def _prettySql(sql):
        """Pretty-print ``sql`` with sqlparse; return it unchanged if formatting fails."""
        try:
            return sqlparse.format(sql, reindent=True)
        except Exception:  # deliberate best-effort: formatting is cosmetic only
            return sql

    def _formatSqlResult(self, sqlResult):
        """Format the parsed ``finalResult``.

        If ``finalResult`` is a string, it is pretty-printed. If it is a dict holding a
        string ``result`` (the model nested a subquery object there), that string is
        pretty-printed. Anything else is returned as ``str(sqlResult)``.
        """
        if isinstance(sqlResult, str):
            return self._prettySql(sqlResult)
        if isinstance(sqlResult, dict) and isinstance(sqlResult.get('result'), str):
            print("gpt didn't give parsed result. So parsing again. the formatting.")
            return self._prettySql(sqlResult['result'])
        return str(sqlResult)

    def _loadTableSummaries(self):
        """Load the table-name -> summary mapping from ``self.tableSummaryJson``."""
        with open(self.tableSummaryJson, 'r', encoding='utf-8') as fh:
            return json.load(fh)

    def getSystemPromptForTableCols(self):
        """Build the system prompt asking which tables/columns a question needs.

        Lists every selected table with its summary (from ``tableSummaryJson``) and
        columns, followed by ``TABLE_RELATIONS``.

        Raises:
            KeyError: if a selected table has no entry in the summary file.
        """
        tableSummaryDict = self._loadTableSummaries()
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        promptTableInfo = ("You are a powerful text to sql model. Answer which tables and columns are needed "
                           "to answer user input using sql query. and following are tables and columns info. \n"
                           "and example user input and result query.")
        for tableName in selectedTablesAndCols.keys():
            promptTableInfo += f"table name {tableName} and summary is {tableSummaryDict[tableName]}"
            promptTableInfo += f" and columns {', '.join(selectedTablesAndCols[tableName])} \n"
        # The original code appended an "XXXX" placeholder for join statements here that
        # was never replaced in this prompt; a plain separator is used instead.
        promptTableInfo += " "
        promptTableInfo += f"and table Relations are {TABLE_RELATIONS}"
        return promptTableInfo

    def getSystemPromptForQuery(self, prospectTablesAndCols):
        """Build the system prompt asking the query model for the SQL itself.

        Contains one worked chain-of-thought example (subqueries combined into a
        ``finalResult`` CTE), sample rows for each table in ``prospectTablesAndCols``
        restricted to the chosen columns, and ``TABLE_RELATIONS``.

        Args:
            prospectTablesAndCols: mapping of table name -> list of column names. Every
                table must be a key of ``self.sampleData``.

        Returns:
            The prompt with backslashes replaced by spaces and runs of spaces collapsed.
        """
        egUserInput = ("I want to get top 5 product categories by state, then rank categories "
                       "on decreasing order of total sales")
        # Must be valid JSON: the model is told to answer "in json format as provided"
        # and the answer is parsed with json.loads.
        cotSubtaskOutput = """{
    "subquery1": {
        "inputSubquery": [],
        "description": "calculate the total sales and assigns ranks to product categories within each state based on the descending order of sales in the tbl_f_sales table, utilizing joins with tbl_d_product and tbl_d_customer tables.",
        "result": "SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales, RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank FROM lpdatamart.tbl_f_sales a JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id JOIN lpdatamart.tbl_d_customer c ON a.customer_id = c.customer_id GROUP BY c.state, b.category"
    },
    "subquery2": {
        "inputSubquery": ["subquery1"],
        "description": "extracts state, category, and total sales information from a subquery named subquery1, filtering the results to include only categories with ranks up to 5 and sorting them by state and category rank.",
        "result": "SELECT state, category, total_sales FROM subquery1 WHERE category_rank <= 5 ORDER BY state, category_rank"
    },
    "finalResult": "WITH subquery1 AS ( SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales, RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank FROM lpdatamart.tbl_f_sales a JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id JOIN lpdatamart.tbl_d_customer c ON a.customer_id = c.customer_id GROUP BY c.state, b.category ) SELECT state, category, total_sales FROM subquery1 WHERE category_rank <= 5 ORDER BY state, category_rank"
}"""
        prompt = (f"You are a powerful text to sql model. Your task is to return sql query which answers "
                  f"user's input. Please follow subquery structure if the sql needs to have multiple "
                  f"subqueries. ###example userInput is {egUserInput}. output is {cotSubtaskOutput}. "
                  f"Output should be in json format as provided. Only output should be in response, "
                  f"nothing else.\n\n ")
        for tableName, cols in prospectTablesAndCols.items():
            sample = self.sampleData[tableName][cols].head(self.gptSampleRows)
            # Newline keeps one table's sample from running into the next heading.
            prompt += f"table name is {tableName}, table data is {sample}\n"
        prompt += f"and table Relations are {TABLE_RELATIONS}"
        # Strip backslashes (e.g. line-continuation marks in wide frame reprs) and
        # collapse padding spaces to shrink the prompt.
        prompt = prompt.replace("\\", " ")
        return re.sub(r" {2,}", " ", prompt)