Spaces:
Runtime error
Runtime error
| from gptManager import ChatgptManager | |
| from utils import * | |
| import json | |
| from constants import TABLE_RELATIONS | |
class QueryHelper:
    """Turn a natural-language question into SQL using two GPT calls.

    1. ``gptInstanceForTableCols`` is primed with a summary of every selected
       table and its columns. It is asked which tables/columns a question needs.
    2. ``gptInstanceForQuery`` is primed with sample rows of only the tables
       and columns named in step 1. It is asked for the SQL query itself.
    """

    def __init__(self, gptInstanceForTableCols: ChatgptManager,
                 gptInstanceForQuery: ChatgptManager,
                 dbEngine, schemaName,
                 platform, metadataLayout: MetaDataLayout, sampleDataRows,
                 gptSampleRows, getSampleDataForTablesAndCols,
                 tableSummaryJson='tableSummaryDict.json'):
        """Store configuration and build the table/column system prompt.

        Args:
            gptInstanceForTableCols: GPT wrapper that picks relevant tables/cols.
            gptInstanceForQuery: GPT wrapper that writes the SQL query.
            dbEngine: passed through to ``getSampleDataForTablesAndCols``.
            schemaName: schema name, passed to the sample-data fetcher and
                quoted in the query prompt.
            platform: SQL platform name quoted in the query prompt.
            metadataLayout: provides ``getSelectedTablesAndCols()``, a mapping
                of table name -> iterable of column names.
            sampleDataRows: ``maxRows`` passed to the sample-data fetcher.
            gptSampleRows: number of sample rows per table shown to the query
                model.
            getSampleDataForTablesAndCols: callable invoked as
                ``(dbEngine=, schemaName=, tablesAndCols=, maxRows=)``. It
                presumably returns a mapping of table name -> pandas DataFrame
                (confirm against the implementation).
            tableSummaryJson: path to a JSON file mapping table name -> summary
                text. An already-loaded dict is also accepted.
        """
        self.gptInstanceForTableCols = gptInstanceForTableCols
        self.gptInstanceForQuery = gptInstanceForQuery
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self.tableSummaryJson = tableSummaryJson
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Refetch sample data and re-prime the table/cols GPT instance."""
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=selectedTablesAndCols,
            maxRows=self.sampleDataRows,
        )
        self.promptTableColsInfo = self.getSystemPromptForTableCols()
        self.gptInstanceForTableCols.setSystemPrompt(self.promptTableColsInfo)

    def getMetadata(self) -> MetaDataLayout:
        """Return the current metadata layout."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and rebuild the derived state."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def getQueryForUserInput(self, userInput):
        """Return the query model's answer for ``userInput``.

        First asks the table/cols model which tables and columns are relevant.
        Then re-primes the query model with sample data for just those and
        forwards ``userInput`` to it.
        """
        prospectTablesAndColsText = self.gptInstanceForTableCols.getResponseForUserInput(userInput)
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        # NOTE: plain substring matching against the model's reply; a short
        # column name (e.g. "id") can match unrelated words.
        prospectTablesAndCols = {
            table: [col for col in selectedTablesAndCols[table]
                    if col in prospectTablesAndColsText]
            for table in selectedTablesAndCols
            if table in prospectTablesAndColsText
        }
        promptForQuery = self.getSystemPromptForQuery(prospectTablesAndCols)
        self.gptInstanceForQuery.setSystemPrompt(promptForQuery)
        return self.gptInstanceForQuery.getResponseForUserInput(userInput)

    def _loadTableSummaries(self):
        """Return the table-name -> summary mapping from ``self.tableSummaryJson``.

        ``json.load`` needs a readable file object, not a path, so a path is
        opened here. A dict that was already loaded is returned unchanged.
        """
        source = self.tableSummaryJson
        if isinstance(source, dict):
            return source
        with open(source, encoding="utf-8") as summaryFile:
            return json.load(summaryFile)

    @staticmethod
    def _cleanPrompt(prompt):
        """Tidy a prompt before it is sent to GPT.

        Backslashes become spaces and the ``XXXX`` placeholder is blanked.
        Runs of spaces (such as DataFrame column padding) are collapsed to a
        single space to save tokens.
        """
        prompt = prompt.replace("\\", " ").replace("XXXX", " ")
        while "  " in prompt:
            prompt = prompt.replace("  ", " ")
        return prompt

    def getSystemPromptForTableCols(self):
        """Build the system prompt asking which tables/columns a question needs."""
        tableSummaryDict = self._loadTableSummaries()
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        promptTableInfo = (
            "You are a powerful text to sql model. Answer which tables and columns are needed\n"
            "to answer user input using sql query. and following are tables and columns info."
            " and example user input and result query.\n"
        )
        for tableName in selectedTablesAndCols:
            # A table missing from the summary file should not abort the helper.
            summary = tableSummaryDict.get(tableName, "not available")
            promptTableInfo += f"table name {tableName} and summary is {summary}"
            promptTableInfo += f" and columns {', '.join(selectedTablesAndCols[tableName])} \n"
        # NOTE(review): "XXXX" presumably marks where example question/query
        # pairs were meant to go (the intro mentions them). _cleanPrompt blanks
        # it so the literal placeholder is never sent to the model.
        promptTableInfo += "XXXX"
        # Join statements
        promptTableInfo += f"and table Relations are {TABLE_RELATIONS}"
        return self._cleanPrompt(promptTableInfo)

    def getSystemPromptForQuery(self, prospectTablesAndCols):
        """Build the system prompt for SQL generation.

        Args:
            prospectTablesAndCols: mapping of table name -> list of column
                names. Each is used to select columns from ``self.sampleData``.
        """
        # NOTE(review): the example question asks for the "top 5" but the
        # example query has no row limit -- consider adding the platform's
        # LIMIT/TOP clause.
        exampleQuery = (
            "SELECT a.customer_id, COUNT(a.product_id) as chandelier_count\n"
            "FROM lpdatamart.tbl_f_sales a\n"
            "JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id\n"
            "JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id\n"
            "WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023\n"
            "GROUP BY a.customer_id\n"
            "ORDER BY chandelier_count DESC"
        )
        question = "top 5 customers who bought most chandeliers in nov 2023"
        promptForQuery = f"""You are a powerful text to sql model. Answer user input with sql query. And the query needs to run on {self.platform}. and schemaName is {self.schemaName}. There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery}. and table's data is \n"""
        for tableName in prospectTablesAndCols:
            # Presumably a pandas DataFrame: select the chosen columns, keep the
            # first gptSampleRows rows.
            tableData = self.sampleData[tableName][prospectTablesAndCols[tableName]].head(self.gptSampleRows)
            # Newline keeps consecutive table dumps from running together.
            promptForQuery += f"table name is {tableName}, table data is {tableData}\n"
        promptForQuery += f"and table Relations are {TABLE_RELATIONS}"
        return self._cleanPrompt(promptForQuery)