"""Chain-of-thought text-to-SQL helper: prompt construction and reply parsing."""
from gptManager import ChatgptManager
# NOTE(review): MetaDataLayout, getQueryFromGptResponse and construct_final_query
# are never imported by name in this module, so they presumably come from this
# star import. The original import order is kept so name shadowing is unchanged.
from utils import *
import json
import sqlparse
from constants import TABLE_RELATIONS


# ---------------------------------------------------------------------------
# Static prompt fragments used by QueryHelperChainOfThought.getPromptForCot.
# They are written as adjacent string literals so the runtime text stays a
# single run of space-separated words while the source remains readable.
# ---------------------------------------------------------------------------

_COT_EXAMPLE_USER_INPUT = (
    "I want to get top 5 product categories by state, "
    "then rank categories on decreasing order of total sales"
)

# Few-shot example of the JSON structure the model is asked to produce.
# NOTE(review): the first subquery uses the key "descriptioin" (typo) while the
# second uses "description", and subquery2's "result" selects from
# `ranked_categories` although the CTE is named `subquery1`. Left untouched
# because this text is sent to the model verbatim — confirm before changing.
_COT_EXAMPLE_OUTPUT = (
    '{ "subquery1": { "inputSubquery": [], '
    '"descriptioin":"calculate the total sales and assigns ranks to product '
    'categories within each state based on the descending order of sales in '
    'the tbl_f_sales table, utilizing joins with tbl_d_product and '
    'tbl_d_customer tables.", '
    '"result": "SELECT c.state, b.category, SUM(a.transaction_amount) as '
    'total_sales, RANK() OVER(PARTITION BY c.state ORDER BY '
    'SUM(a.transaction_amount) DESC) as category_rank FROM '
    'lpdatamart.tbl_f_sales a JOIN lpdatamart.tbl_d_product b ON '
    'a.product_id = b.product_id JOIN lpdatamart.tbl_d_customer c ON '
    'a.customer_id = c.customer_id GROUP BY c.state, b.category " }, '
    '"subquery2": { "inputSubquery": ["subquery1"], '
    '"description":"extracts state, category, and total sales information '
    'from a subquery named subquery1, filtering the results to include only '
    'categories with ranks up to 5 and sorting them by state and category '
    'rank.", '
    '"result":"SELECT state, category, total_sales FROM ranked_categories '
    'WHERE category_rank <= 5 ORDER BY state, category_rank" }, '
    '"finalResult":"WITH subquery1 AS ( SELECT c.state, b.category, '
    'SUM(a.transaction_amount) as total_sales, RANK() OVER(PARTITION BY '
    'c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank FROM '
    'lpdatamart.tbl_f_sales a JOIN lpdatamart.tbl_d_product b ON '
    'a.product_id = b.product_id JOIN lpdatamart.tbl_d_customer c ON '
    'a.customer_id = c.customer_id GROUP BY c.state, b.category ) SELECT '
    'state, category, total_sales FROM subquery1 WHERE category_rank <= 5 '
    'ORDER BY state, category_rank" }'
)

_SALES_INFO_PROMPT = (
    " Following is tbl_d_lineaction_code table. "
    "and tbl_d_lineaction_code and tbl_f_sales join on "
    "tbl_d_lineaction_code.line_action_code, tbl_f_sales.line_action . "
    "Each transaction New order, canceled, returned gets entry in tbl_f_sales. "
    "Please consider them while calculating revenue etc. "
    "line_action_code line_action_code_desc load_date catgory sales_type "
    "---------------- ------------------------------ --------------------- ------- ---------- "
    "201 NEW 2015-08-03 00:37:53.0 WEB "
    "202 UNSHIPPED 2015-08-03 00:37:53.0 WEB "
    "203 SHIPPED 2015-08-03 00:37:54.0 WEB "
    "204 CANCELED 2015-08-03 00:37:55.0 WEB "
    "205 RETURN 2015-08-03 00:37:55.0 WEB "
    "1 SOLD 2014-09-17 02:24:28.0 POS "
    "2 RETURNED 2014-09-17 02:24:30.0 POS "
    "7 ORDERED 2014-09-17 02:24:31.0 POS "
    "8 ORDER CANCELLED 2014-09-17 02:24:31.0 POS "
    "90 ORDER DELIVERED 2014-09-17 02:24:32.0 POS "
)

_MASTER_ID_PROMPT = (
    " ###Important "
    "Following is important detail about customer data. "
    "Following is two columns (customer_id and master_customer_id) of tbl_d_customer. "
    "master_customer_id represents real single customer but it can be related to "
    "various customer_id which represets various ways the same customer has logged in "
    "or created accounts. "
    "When asked about only customer, use master_customer_id "
    "So when accounting for sales or transaction by particular customer, "
    "use their all the related customer_id of their master_customer_id. "
    "customer_id master_customer_id "
    "---------- ------------------ "
    "89998738 1562 "
    "96656077 1562 "
    "1562 1562 "
    "74357421 1562 "
    "290007176 1562 "
)

_DEMAND_PROMPT = " For report creation demands means number of quantity sold."

_SUB_LOCATION_PROMPT = (
    "Following are information about "
    "OPEN_BOX TBL_D_STORE: SUBLOCATIONCODE: 9009, 90093, 90094 "
    "PRO/PRO OPENBOX TBL_D_STORE: SUBLOCATIONCODE: 9004, 90094 "
    "CONSUMER/CONSUMER OPENBOX 9003, 90093 "
)

# Few-shot example embedded in the per-query system prompt.
_EXAMPLE_QUESTION = "top 5 customers who bought most chandeliers in nov 2023"
_EXAMPLE_QUERY = (
    "SELECT a.customer_id, COUNT(a.product_id) as chandelier_count "
    "FROM lpdatamart.tbl_f_sales a "
    "JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id "
    "JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id "
    "WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' "
    "AND c.calendar_month = 'NOVEMBER' AND c.year = 2023 "
    "GROUP BY a.customer_id ORDER BY chandelier_count DESC"
)


def _fencedJsonPayload(text):
    """Return the text inside the last ```json fence, newlines flattened."""
    return text.split("```json")[-1].split('```')[0].replace('\n', ' ')


def _rawPayload(text):
    """Return the reply unchanged (the whole reply is tried as JSON)."""
    return text


def _unfencedPayload(text):
    """Return the reply with all code-fence markers removed, newlines flattened."""
    return text.replace("```json", "").replace("```", "").replace('\n', ' ')


# Ordered (candidate builder, success message, failure message) triples.
# The messages are kept verbatim, including the inconsistent trailing period.
_FINAL_RESULT_STRATEGIES = (
    (
        _fencedJsonPayload,
        "parsed desired result from gpt response using method 1.",
        "Couldn't parse desired result from gpt response using method 1.",
    ),
    (
        _rawPayload,
        "parsed desired result from gpt response using method 2.",
        "Couldn't parse desired result from gpt response using method 2",
    ),
    (
        _unfencedPayload,
        "parsed desired result from gpt response using method 3.",
        "Couldn't parse desired result from gpt response using method 3",
    ),
)


class QueryHelperChainOfThought:
    """Build chain-of-thought text-to-SQL prompts and extract SQL from replies.

    The helper owns a GPT client, the currently selected tables/columns
    (``metadataLayout``) and a sample-data cache that is refreshed whenever
    the metadata changes.
    """

    def __init__(self, gptInstanceForCoT: ChatgptManager, dbEngine, schemaName, platform,
                 metadataLayout: MetaDataLayout, sampleDataRows, gptSampleRows,
                 getSampleDataForTablesAndCols, tableSummaryJson='tableSummaryDict.json'):
        """Store the collaborators and load the initial sample data.

        Args:
            gptInstanceForCoT: client used to set the system prompt and get replies.
            dbEngine: passed through to ``getSampleDataForTablesAndCols``.
            schemaName: schema name interpolated into the prompts.
            platform: SQL platform name interpolated into the prompts.
            metadataLayout: provides ``getSelectedTablesAndCols()``.
            sampleDataRows: ``maxRows`` passed to the sample-data loader.
            gptSampleRows: number of sample rows per table shown in the prompt.
            getSampleDataForTablesAndCols: callable invoked with keyword
                arguments ``dbEngine``, ``schemaName``, ``tablesAndCols`` and
                ``maxRows``; its result is stored as ``self.sampleData``.
            tableSummaryJson: path of the JSON file mapping table name to summary.
        """
        self.gptInstanceForCoT = gptInstanceForCoT
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self.tableSummaryJson = tableSummaryJson
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Reload ``self.sampleData`` for the currently selected tables/columns."""
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=self.metadataLayout.getSelectedTablesAndCols(),
            maxRows=self.sampleDataRows,
        )

    def _loadTableSummary(self):
        """Read and return the table-summary JSON, closing the file afterwards.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        with open(self.tableSummaryJson, 'r') as summaryFile:
            return json.load(summaryFile)

    def getMetadata(self) -> MetaDataLayout:
        """Return the metadata layout currently in use."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and refresh the cached sample data."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def getQueryForUserInputWithHistory(self, verboseChatHistory, userInput):
        """Answer ``userInput`` in the context of a chat history.

        Returns:
            A ``(finalQuery, verboseResponse)`` tuple where ``verboseResponse``
            is the unmodified GPT reply. When ``getQueryFromGptResponse``
            hands back the reply unchanged, that reply is used as the query;
            otherwise the query is assembled with ``construct_final_query``.
        """
        self.gptInstanceForCoT.setSystemPrompt(self.getPromptForCot())
        gptResponse = self.gptInstanceForCoT.getResponseForChatHistory(verboseChatHistory, userInput)
        query, jsonResponse = getQueryFromGptResponse(gptResponse=gptResponse)
        if query != gptResponse:
            finalQuery = construct_final_query(query, jsonResponse)
        else:
            finalQuery = query
        return finalQuery, gptResponse

    @staticmethod
    def _extractFinalResult(gptResponse):
        """Try each parsing strategy in order and pull out ``finalResult``.

        Returns:
            ``(True, value)`` for the first strategy whose candidate text is
            valid JSON containing a ``finalResult`` key, else ``(False, None)``.
        """
        for buildCandidate, successMessage, failureMessage in _FINAL_RESULT_STRATEGIES:
            try:
                finalResult = json.loads(buildCandidate(gptResponse))['finalResult']
            except (ValueError, KeyError, TypeError, AttributeError, IndexError):
                # ValueError covers json.JSONDecodeError; the others cover a
                # reply that is not a str or JSON that is not an object with
                # the expected key. Best effort: report and try the next one.
                print(failureMessage)
                continue
            print(successMessage)
            return True, finalResult
        return False, None

    def getQueryForUserInputCoT(self, userInput):
        """Answer ``userInput`` and return the SQL extracted from the reply.

        Returns:
            The ``finalResult`` SQL reformatted by ``sqlparse`` when it can be
            extracted; ``str(finalResult)`` when it was extracted but could not
            be formatted; the raw GPT reply when nothing could be parsed.
        """
        self.gptInstanceForCoT.setSystemPrompt(self.getPromptForCot())
        gptResponse = self.gptInstanceForCoT.getResponseForUserInput(userInput)

        wasParsed, sqlResult = self._extractFinalResult(gptResponse)
        if not wasParsed:
            return gptResponse

        try:
            return sqlparse.format(sqlResult, reindent=True)
        except Exception:
            # sqlparse's failure types are not pinned down here, so stay broad
            # (but no longer swallow KeyboardInterrupt/SystemExit) and fall
            # through to the nested {"result": ...} shape.
            pass

        try:
            formattedSql = sqlparse.format(sqlResult['result'], reindent=True)
        except Exception:
            return str(sqlResult)
        print("gpt didn't give parsed result. So parsing again. the formatting.")
        return formattedSql

    def getPromptForCot(self):
        """Assemble the full chain-of-thought system prompt.

        Combines the task instructions, a few-shot example, the table prompt
        from ``getSystemPromptForTableCols``, the column/sample-data prompt
        from ``getSystemPromptForQuery``, fixed domain notes and
        ``TABLE_RELATIONS``.
        """
        schemaName = self.schemaName
        platform = self.platform
        promptTableInfo = self.getSystemPromptForTableCols()
        promptColumnsInfo = self.getSystemPromptForQuery(
            self.metadataLayout.getSelectedTablesAndCols()
        )
        prompt = (
            "You are a powerful text to sql model. "
            "Your task is to return sql query which answers user's input. "
            "Please follow subquery structure if the sql needs to have multiple subqueries. "
            "Your response should be in JSON format. "
            "Answer user input with sql query. "
            f"And the query needs to run on {platform}. "
            f"and schemaName is {schemaName}. "
            "And use columns and tables provided, in case, "
            "you need additional column information, please ask the user. "
            "or define them in comments "
            f"###example userInput is {_COT_EXAMPLE_USER_INPUT}. "
            f"output is {_COT_EXAMPLE_OUTPUT}. "
            "Output should be in json format as provided. "
            "Only output should be in response, nothing else.\n\n "
            f"tables information are {promptTableInfo}. "
            f"columns data are {promptColumnsInfo}. "
            "Following are extra information which user might assume you know. "
            f"{_SALES_INFO_PROMPT} {_MASTER_ID_PROMPT} {_DEMAND_PROMPT} {_SUB_LOCATION_PROMPT}. "
        )
        prompt += f"and table Relations are {TABLE_RELATIONS} "
        return prompt

    def getSystemPromptForTableCols(self):
        """Describe every selected table (summary and columns) for the model.

        Raises:
            KeyError: if a selected table has no entry in the summary JSON.
        """
        tableSummaryDict = self._loadTableSummary()
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        promptParts = [
            "You are a powerful text to sql model. "
            "Answer which tables and columns are needed to answer user input using sql query. "
            "and following are tables and columns info. "
            "and example user input and result query."
        ]
        for tableName, columns in selectedTablesAndCols.items():
            columnList = ', '.join(columns)
            promptParts.append(f"table name {tableName} and summary is {tableSummaryDict[tableName]} ")
            promptParts.append(f" and columns {columnList} \n ")
        # NOTE(review): the sentinel and the relations suffix are assumed to
        # follow the loop (appended once, not per table) — confirm. The "XXXX"
        # sentinel is not stripped on this path, so it reaches the model as-is.
        promptParts.append("XXXX")  # Join statements
        promptParts.append(f" and table Relations are {TABLE_RELATIONS} ")
        return "".join(promptParts)

    def getSystemPromptForQuery(self, prospectTablesAndCols):
        """Build the query prompt with sample rows for the given tables.

        Args:
            prospectTablesAndCols: mapping of table name to the columns whose
                sample rows should be shown. ``self.sampleData[table]`` is
                indexed by that column selection and ``.head(n)`` is called on
                the result (presumably pandas DataFrames — confirm).
        """
        schemaName = self.schemaName
        platform = self.platform
        promptParts = [
            "You are a powerful text to sql model. "
            "Answer user input with sql query. "
            f"And the query needs to run on {platform}. "
            f"and schemaName is {schemaName}. "
            "There is example user input and desired generated sql query. "
            "Follow similar patterns as example. "
            "eg case insensitive, explicit variable declaration etc. "
            f"user input : {_EXAMPLE_QUESTION}, query : {_EXAMPLE_QUERY}. "
            "and table's data is \n "
        ]
        for tableName, columns in prospectTablesAndCols.items():
            sampleRows = self.sampleData[tableName][columns].head(self.gptSampleRows)
            promptParts.append(f"table name is {tableName}, table data is {sampleRows} \n ")
        # NOTE(review): assumed to be appended once after the loop — confirm.
        promptParts.append(f"and table Relations are {TABLE_RELATIONS} \n ")
        promptForQuery = "".join(promptParts)
        # NOTE(review): the middle replace maps a single space to a single
        # space, i.e. it is a no-op as written; it may have been meant to
        # collapse double spaces — confirm before changing.
        return promptForQuery.replace("\\", " ").replace(" ", " ").replace("XXXX", " ")