Spaces:
Runtime error
Runtime error
| from gptManager import ChatgptManager | |
| from utils import * | |
| import json | |
| import sqlparse | |
| from constants import TABLE_RELATIONS | |
| class QueryHelper: | |
| def __init__(self, gptInstanceForTableCols: ChatgptManager, | |
| gptInstanceForQuery: ChatgptManager, | |
| dbEngine, schemaName, | |
| platform, metadataLayout: MetaDataLayout, sampleDataRows, | |
| gptSampleRows, getSampleDataForTablesAndCols, tableSummaryJson='tableSummaryDict.json'): | |
| self.gptInstanceForTableCols = gptInstanceForTableCols | |
| self.gptInstanceForQuery = gptInstanceForQuery | |
| self.schemaName = schemaName | |
| self.platform = platform | |
| self.metadataLayout = metadataLayout | |
| self.sampleDataRows = sampleDataRows | |
| self.gptSampleRows = gptSampleRows | |
| self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols | |
| self.dbEngine = dbEngine | |
| self.tableSummaryJson = tableSummaryJson | |
| self._onMetadataChange() | |
| def _onMetadataChange(self): | |
| metadataLayout = self.metadataLayout | |
| sampleDataRows = self.sampleDataRows | |
| dbEngine = self.dbEngine | |
| schemaName = self.schemaName | |
| selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols() | |
| self.sampleData = self.getSampleDataForTablesAndCols(dbEngine=dbEngine,schemaName=schemaName, | |
| tablesAndCols=selectedTablesAndCols, maxRows=sampleDataRows) | |
| self.promptTableColsInfo = self.getSystemPromptForTableCols() | |
| self.gptInstanceForTableCols.setSystemPrompt(self.promptTableColsInfo) | |
| def getMetadata(self) -> MetaDataLayout : | |
| return self.metadataLayout | |
| def updateMetadata(self, metadataLayout): | |
| self.metadataLayout = metadataLayout | |
| self._onMetadataChange() | |
| def getQueryForUserInput(self, userInput): | |
| prospectTablesAndColsText = self.gptInstanceForTableCols.getResponseForUserInput(userInput) | |
| selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols() | |
| prospectTablesAndCols = dict() | |
| for table in selectedTablesAndCols: | |
| if table in prospectTablesAndColsText: | |
| prospectTablesAndCols[table] = [] | |
| for col in selectedTablesAndCols[table]: | |
| if col in prospectTablesAndColsText: | |
| prospectTablesAndCols[table].append(col) | |
| print("tables and cols select by gpt", prospectTablesAndCols) | |
| promptForQuery = self.getSystemPromptForQuery(prospectTablesAndCols) | |
| self.gptInstanceForQuery.setSystemPrompt(promptForQuery) | |
| gptResponse = self.gptInstanceForQuery.getResponseForUserInput(userInput) | |
| #following CoT in select column then get query to save tokens | |
| tryParsing = True | |
| parsedSql = False | |
| if tryParsing: | |
| try: | |
| txt = gptResponse.split("```json")[-1].split('```')[0].replace('\n', '') | |
| sqlResult = json.loads(txt)['finalResult'] | |
| parsedSql = True | |
| tryParsing = False | |
| except: | |
| print("Couldn't parse desired result from gpt response using method 1.") | |
| if tryParsing: | |
| try: | |
| sqlResult = json.loads(gptResponse)['finalResult'] | |
| parsedSql = True | |
| tryParsing = False | |
| except: | |
| print("Couldn't parse desired result from gpt response using method 2") | |
| if parsedSql: | |
| isFormatted = False | |
| try: | |
| formattedSql = sqlparse.format(sqlResult, reindent=True) | |
| responseToReturn = formattedSql | |
| isFormatted = True | |
| except: | |
| isFormatted = False | |
| if not isFormatted: | |
| try: | |
| formattedSql = sqlparse.format(sqlResult['result'], reindent=True) | |
| responseToReturn = formattedSql | |
| print("gpt didn't give parsed result. So parsing again. the formatting.") | |
| except: | |
| responseToReturn = str(sqlResult) | |
| else: | |
| responseToReturn = gptResponse | |
| return responseToReturn | |
| def getSystemPromptForTableCols(self): | |
| schemaName = self.schemaName | |
| platform = self.platform | |
| tableSummaryDict = json.load(open(self.tableSummaryJson, 'r')) | |
| selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols() | |
| promptTableInfo = f"""You are a powerful text to sql model. Answer which tables and columns are needed | |
| to answer user input using sql query. and following are tables and columns info. and example user input and result query.""" | |
| for idx, tableName in enumerate(selectedTablesAndCols.keys(), start=1): | |
| promptTableInfo += f"table name {tableName} and summary is {tableSummaryDict[tableName]}" | |
| promptTableInfo += f" and columns {', '.join(selectedTablesAndCols[tableName])} \n" | |
| promptTableInfo += "XXXX" | |
| #Join statements | |
| promptTableInfo += f"and table Relations are {TABLE_RELATIONS}" | |
| return promptTableInfo | |
| def getSystemPromptForQuery(self, prospectTablesAndCols): | |
| schemaName = self.schemaName | |
| platform = self.platform | |
| tableSummaryDict = json.load(open(self.tableSummaryJson, 'r')) | |
| egUserInput = "I want to get top 5 product categories by state, then rank categories on decreasing order of total sales" | |
| cotSubtaskOutput = """{ | |
| "subquery1": { | |
| "inputSubquery": [], | |
| "descriptioin":"calculate the total sales and assigns ranks to product categories within each state based on the descending order of sales in the tbl_f_sales table, utilizing joins with tbl_d_product and tbl_d_customer tables.", | |
| "result": "SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales, | |
| RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank | |
| FROM lpdatamart.tbl_f_sales a | |
| JOIN lpdatamart.tbl_d_product b | |
| ON a.product_id = b.product_id | |
| JOIN lpdatamart.tbl_d_customer c | |
| ON a.customer_id = c.customer_id | |
| GROUP BY c.state, b.category " | |
| }, | |
| "subquery2": { | |
| "inputSubquery": ["subquery1"], | |
| "description":"extracts state, category, and total sales information from a subquery named "subquery1," filtering the results to include only categories with ranks up to 5 and sorting them by state and category rank." | |
| "result":"SELECT state, category, total_sales | |
| FROM ranked_categories | |
| WHERE category_rank <= 5 | |
| ORDER BY state, category_rank" | |
| }, | |
| "finalResult":"WITH subquery1 AS ( | |
| SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales, | |
| RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank | |
| FROM lpdatamart.tbl_f_sales a | |
| JOIN lpdatamart.tbl_d_product b | |
| ON a.product_id = b.product_id | |
| JOIN lpdatamart.tbl_d_customer c | |
| ON a.customer_id = c.customer_id | |
| GROUP BY c.state, b.category | |
| ) | |
| SELECT state, category, total_sales | |
| FROM subquery1 | |
| WHERE category_rank <= 5 | |
| ORDER BY state, category_rank" | |
| }""" | |
| prompt = f"""You are a powerful text to sql model. Your task is to return sql query which answers | |
| user's input. Please follow subquery structure if the sql needs to have multiple subqueries. | |
| ###example userInput is {egUserInput}. output is {cotSubtaskOutput}. Output should be in json format as provided. Only output should be in response, nothing else.\n\n | |
| """ | |
| for idx, tableName in enumerate(prospectTablesAndCols.keys(), start=1): | |
| prompt += f"table name is {tableName}, table data is {self.sampleData[tableName][prospectTablesAndCols[tableName]].head(self.gptSampleRows)}" | |
| prompt += f"and table Relations are {TABLE_RELATIONS}" | |
| return prompt.replace("\\"," ").replace(" "," ").replace("XXXX", " ") | |