Spaces:
Runtime error
Runtime error
| from gptManager import ChatgptManager | |
| from utils import * | |
| import json | |
| import sqlparse | |
| from constants import TABLE_RELATIONS | |
class QueryHelperChainOfThought:
    """Text-to-SQL helper that asks a GPT model for a chain-of-thought JSON answer.

    The system prompt combines per-table summaries (read from the
    ``tableSummaryJson`` file), sample rows for the selected tables/columns, a
    fixed few-shot example and several domain notes. The model is asked to reply
    in JSON whose ``finalResult`` field holds the final SQL query.
    """

    # Exceptions meaning "this candidate text is not the JSON we expect":
    # json.JSONDecodeError subclasses ValueError, a missing "finalResult" key is a
    # KeyError, indexing a non-dict JSON value (or json.loads(None)) is a TypeError,
    # and calling str methods on a non-string reply is an AttributeError.
    _PARSE_ERRORS = (ValueError, KeyError, TypeError, AttributeError)

    # ---- Static prompt fragments: identical on every call, so defined once. ----

    # Few-shot example request shown to the model.
    _EG_USER_INPUT = "I want to get top 5 product categories by state, then rank categories on decreasing order of total sales"

    # Expected JSON answer for _EG_USER_INPUT. Its string values contain raw
    # newlines (not valid strict JSON); parse methods 1 and 3 in
    # _extractFinalResult replace newlines before json.loads, presumably for that
    # reason. subquery2 must read FROM subquery1, the CTE name used in finalResult.
    _COT_SUBTASK_OUTPUT = """{
"subquery1": {
"inputSubquery": [],
"description":"calculate the total sales and assigns ranks to product categories within each state based on the descending order of sales in the tbl_f_sales table, utilizing joins with tbl_d_product and tbl_d_customer tables.",
"result": "SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
FROM lpdatamart.tbl_f_sales a
JOIN lpdatamart.tbl_d_product b
ON a.product_id = b.product_id
JOIN lpdatamart.tbl_d_customer c
ON a.customer_id = c.customer_id
GROUP BY c.state, b.category "
},
"subquery2": {
"inputSubquery": ["subquery1"],
"description":"extracts state, category, and total sales information from a subquery named subquery1, filtering the results to include only categories with ranks up to 5 and sorting them by state and category rank.",
"result":"SELECT state, category, total_sales
FROM subquery1
WHERE category_rank <= 5
ORDER BY state, category_rank"
},
"finalResult":"WITH subquery1 AS (
SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
FROM lpdatamart.tbl_f_sales a
JOIN lpdatamart.tbl_d_product b
ON a.product_id = b.product_id
JOIN lpdatamart.tbl_d_customer c
ON a.customer_id = c.customer_id
GROUP BY c.state, b.category
)
SELECT state, category, total_sales
FROM subquery1
WHERE category_rank <= 5
ORDER BY state, category_rank"
}"""

    # Domain note: line-action codes and how they join to tbl_f_sales.
    _SALES_INFO_PROMPT = """
Following is tbl_d_lineaction_code table. and tbl_d_lineaction_code and tbl_f_sales
join on tbl_d_lineaction_code.line_action_code, tbl_f_sales.line_action .
Each transaction New order, canceled, returned gets entry in tbl_f_sales. Please consider them while calculating revenue etc.
line_action_code line_action_code_desc load_date catgory sales_type
---------------- ------------------------------ --------------------- ------- ----------
201 NEW 2015-08-03 00:37:53.0 <null> WEB
202 UNSHIPPED 2015-08-03 00:37:53.0 <null> WEB
203 SHIPPED 2015-08-03 00:37:54.0 <null> WEB
204 CANCELED 2015-08-03 00:37:55.0 <null> WEB
205 RETURN 2015-08-03 00:37:55.0 <null> WEB
1 SOLD 2014-09-17 02:24:28.0 <null> POS
2 RETURNED 2014-09-17 02:24:30.0 <null> POS
7 ORDERED 2014-09-17 02:24:31.0 <null> POS
8 ORDER CANCELLED 2014-09-17 02:24:31.0 <null> POS
90 ORDER DELIVERED 2014-09-17 02:24:32.0 <null> POS
"""

    # Domain note: customer_id vs master_customer_id in tbl_d_customer.
    _MASTER_ID_PROMPT = """
###Important
Following is important detail about customer data.
Following is two columns (customer_id and master_customer_id) of tbl_d_customer.
master_customer_id represents real single customer but it can be related to various customer_id which
represets various ways the same customer has logged in or created accounts.
When asked about only customer, use master_customer_id
So when accounting for sales or transaction by particular customer, use their all the related customer_id of their
master_customer_id.
customer_id master_customer_id
---------- ------------------
89998738 1562
96656077 1562
1562 1562
74357421 1562
290007176 1562
"""

    # Domain note: meaning of "demand".
    _DEMAND_PROMPT = """
For report creation demands means number of quantity sold."""

    # Domain note: store sub-location codes.
    _SUB_LOCATION_PROMPT = """Following are information about
OPEN_BOX TBL_D_STORE: SUBLOCATIONCODE: 9009, 90093, 90094
PRO/PRO OPENBOX TBL_D_STORE: SUBLOCATIONCODE: 9004, 90094
CONSUMER/CONSUMER OPENBOX 9003, 90093
"""

    # Few-shot example (question + SQL) used by getSystemPromptForQuery.
    # NOTE(review): the question asks for the "top 5" customers but the query has
    # no row limit; left unchanged because the limit syntax depends on `platform`.
    _EXAMPLE_QUESTION = "top 5 customers who bought most chandeliers in nov 2023"
    _EXAMPLE_QUERY = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
FROM lpdatamart.tbl_f_sales a
JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id
WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023
GROUP BY a.customer_id
ORDER BY chandelier_count DESC"""

    def __init__(self, gptInstanceForCoT: ChatgptManager,
                 dbEngine, schemaName,
                 platform, metadataLayout: MetaDataLayout, sampleDataRows,
                 gptSampleRows, getSampleDataForTablesAndCols, tableSummaryJson='tableSummaryDict.json'):
        """Store configuration and load the initial sample data.

        Args:
            gptInstanceForCoT: chat client; this class calls its ``setSystemPrompt``,
                ``getResponseForUserInput`` and ``getResponseForChatHistory``.
            dbEngine: database engine, passed through to ``getSampleDataForTablesAndCols``.
            schemaName: schema name, passed to the sampler and quoted in prompts.
            platform: SQL platform name quoted in prompts (the query must run on it).
            metadataLayout: object exposing ``getSelectedTablesAndCols()``.
            sampleDataRows: ``maxRows`` passed to the sampler.
            gptSampleRows: number of sample rows per table shown to the model.
            getSampleDataForTablesAndCols: callable that fetches sample data per table.
            tableSummaryJson: path to a JSON file mapping table name -> summary text.
        """
        self.gptInstanceForCoT = gptInstanceForCoT
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self.tableSummaryJson = tableSummaryJson
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Re-fetch ``self.sampleData`` for the currently selected tables/columns.

        ``sampleData`` is presumably a mapping of table name -> DataFrame: it is
        later indexed by table, then by a column list, then ``.head()`` is called.
        """
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=self.metadataLayout.getSelectedTablesAndCols(),
            maxRows=self.sampleDataRows,
        )

    def getMetadata(self) -> MetaDataLayout:
        """Return the current metadata layout."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and refresh the cached sample data."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def getQueryForUserInputWithHistory(self, verboseChatHistory, userInput):
        """Ask the model for SQL given prior chat history plus a new user input.

        Returns:
            ``(finalQuery, verboseResponse)`` where ``verboseResponse`` is the raw
            model reply. When ``getQueryFromGptResponse`` hands back the reply
            unchanged, that reply is used as the query as-is; otherwise the query
            is post-processed with ``construct_final_query``.
        """
        self.gptInstanceForCoT.setSystemPrompt(self.getPromptForCot())
        gptResponse = self.gptInstanceForCoT.getResponseForChatHistory(verboseChatHistory, userInput)
        # getQueryFromGptResponse / construct_final_query presumably come from
        # `from utils import *` at file level.
        query, jsonResponse = getQueryFromGptResponse(gptResponse=gptResponse)
        if query != gptResponse:
            finalQuery = construct_final_query(query, jsonResponse)
        else:
            finalQuery = query
        return finalQuery, gptResponse

    def getQueryForUserInputCoT(self, userInput):
        """Ask the model for SQL for a single user input.

        Returns:
            The reindented ``finalResult`` SQL when it can be extracted from the
            model's JSON reply (see ``_formatSql`` for fallbacks), otherwise the
            raw model reply unchanged.
        """
        self.gptInstanceForCoT.setSystemPrompt(self.getPromptForCot())
        gptResponse = self.gptInstanceForCoT.getResponseForUserInput(userInput)
        parsedSql, sqlResult = self._extractFinalResult(gptResponse)
        if not parsedSql:
            return gptResponse
        return self._formatSql(sqlResult)

    def _extractFinalResult(self, gptResponse):
        """Read ``finalResult`` from a reply, trying three candidate texts in order.

        Returns ``(True, value)`` on the first success, else ``(False, None)``.
        A tuple is used so that a JSON ``null`` finalResult still counts as parsed.
        """
        candidates = (
            # 1: body of a ```json fence (whole reply if there is none), newlines flattened.
            (1, lambda: gptResponse.split("```json")[-1].split('```')[0].replace('\n', ' '), "."),
            # 2: the reply exactly as received.
            (2, lambda: gptResponse, ""),
            # 3: every fence marker removed, newlines flattened.
            (3, lambda: gptResponse.replace("```json", "").replace("```", "").replace('\n', ' '), ""),
        )
        for method, buildText, failSuffix in candidates:
            try:
                result = json.loads(buildText())['finalResult']
            except self._PARSE_ERRORS:
                print(f"Couldn't parse desired result from gpt response using method {method}{failSuffix}")
                continue
            print(f"parsed desired result from gpt response using method {method}.")
            return True, result
        return False, None

    def _formatSql(self, sqlResult):
        """Reindent ``sqlResult`` with sqlparse, best-effort.

        Falls back to formatting ``sqlResult['result']`` (for replies that nest
        the SQL one level deeper), and finally to ``str(sqlResult)``.
        """
        try:
            return sqlparse.format(sqlResult, reindent=True)
        except Exception:
            # Deliberate fallback: sqlResult is not a SQL string (e.g. a dict).
            pass
        try:
            formattedSql = sqlparse.format(sqlResult['result'], reindent=True)
        except Exception:
            return str(sqlResult)
        print("gpt didn't give parsed result. So parsing again. the formatting.")
        return formattedSql

    def getPromptForCot(self):
        """Build the full chain-of-thought system prompt.

        Reads the table-summary file (via ``getSystemPromptForTableCols``) and the
        cached ``self.sampleData``.
        """
        # getSystemPromptForTableCols appends a literal "XXXX" marker that nothing
        # else removes from this prompt; strip it so it never reaches the model.
        promptTableInfo = self.getSystemPromptForTableCols().replace("XXXX", " ")
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        promptColumnsInfo = self.getSystemPromptForQuery(selectedTablesAndCols)
        prompt = f"""You are a powerful text to sql model. Your task is to return sql query which answers
user's input. Please follow subquery structure if the sql needs to have multiple subqueries. Your response should be in JSON format.
Answer user input with sql query. And the query needs to run on {self.platform}. and schemaName is {self.schemaName}.
And use columns and tables provided, in case, you need additional column information, please ask the user. or define them in comments
###example userInput is {self._EG_USER_INPUT}. output is {self._COT_SUBTASK_OUTPUT}. Output should be in json format as provided. Only output should be in response, nothing else.\n\n
tables information are {promptTableInfo}.
columns data are {promptColumnsInfo}.
Following are extra information which user might assume you know.
{self._SALES_INFO_PROMPT}
{self._MASTER_ID_PROMPT}
{self._DEMAND_PROMPT}
{self._SUB_LOCATION_PROMPT}.
"""
        # NOTE(review): both helper prompts above already append TABLE_RELATIONS, so
        # it appears three times in the final prompt; consider dropping one copy.
        prompt += f"and table Relations are {TABLE_RELATIONS} "
        return prompt

    def getSystemPromptForTableCols(self):
        """Describe each selected table (summary + columns) plus TABLE_RELATIONS.

        The output contains a literal "XXXX" marker after each table entry;
        ``getPromptForCot`` replaces it with a space before use.

        Raises:
            KeyError: if a selected table has no entry in the summary file.
        """
        with open(self.tableSummaryJson, 'r', encoding='utf-8') as summaryFile:
            tableSummaryDict = json.load(summaryFile)
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        promptTableInfo = """You are a powerful text to sql model. Answer which tables and columns are needed
to answer user input using sql query. and following are tables and columns info. and example user input and result query."""
        for tableName, columns in selectedTablesAndCols.items():
            promptTableInfo += f"table name {tableName} and summary is {tableSummaryDict[tableName]} "
            promptTableInfo += f" and columns {', '.join(columns)} \n "
            promptTableInfo += "XXXX"
        # Join statements
        promptTableInfo += f" and table Relations are {TABLE_RELATIONS} "
        return promptTableInfo

    def getSystemPromptForQuery(self, prospectTablesAndCols):
        """Build the SQL-writing prompt: few-shot example plus sample rows per table.

        Args:
            prospectTablesAndCols: mapping of table name -> columns to show; each
                table must be present in ``self.sampleData``.
        """
        promptForQuery = f"""You are a powerful text to sql model. Answer user input with sql query. And the query needs to run on {self.platform}. and schemaName is {self.schemaName}. There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {self._EXAMPLE_QUESTION}, query : {self._EXAMPLE_QUERY}. and table's data is \n """
        for tableName, columns in prospectTablesAndCols.items():
            sampleRows = self.sampleData[tableName][columns].head(self.gptSampleRows)
            promptForQuery += f"table name is {tableName}, table data is {sampleRows} \n "
        promptForQuery += f"and table Relations are {TABLE_RELATIONS} \n "
        # Backslashes are replaced with spaces.
        # NOTE(review): .replace(" ", " ") is a no-op as written; it may originally have
        # collapsed double spaces - verify against the upstream source.
        return promptForQuery.replace("\\", " ").replace(" ", " ").replace("XXXX", " ")