Spaces:
Runtime error
Runtime error
| from gptManager import ChatgptManager | |
| from utils import * | |
class QueryHelper:
    """Turn natural-language requests into SQL with the help of a GPT client.

    The helper keeps a cache of sample rows for the currently selected tables
    and columns, asks the GPT client which of them look relevant to a request,
    and then asks it again for a query, supplying sample rows of only those
    prospects as context.

    The attributes stored by ``__init__`` are exactly its arguments, plus
    ``sampleData`` which is (re)built by ``_onMetadataChange``.
    """

    # Annotations for project types are quoted so they are not evaluated when
    # the class body is executed; the names only need to resolve for tooling.
    def __init__(self, gptInstance: "ChatgptManager", dbEngine, schemaName,
                 platform, metadataLayout: "MetaDataLayout", sampleDataRows,
                 gptSampleRows, getSampleDataForTablesAndCols):
        """Store the collaborators and load the initial sample data.

        Args:
            gptInstance: Client exposing ``getResponseForUserInput``.
            dbEngine: Passed through unchanged to the sample-data loader.
            schemaName: Schema name, interpolated into prompts and passed to
                the sample-data loader.
            platform: SQL platform name, interpolated into prompts.
            metadataLayout: Object exposing ``getSelectedTablesAndCols``.
            sampleDataRows: ``maxRows`` value given to the sample-data loader.
            gptSampleRows: Number of rows per table shown to GPT (used as the
                argument of ``.head``).
            getSampleDataForTablesAndCols: Callable invoked with the keyword
                arguments ``dbEngine``, ``schemaName``, ``tablesAndCols`` and
                ``maxRows``; its result becomes ``self.sampleData``.
        """
        self.gptInstance = gptInstance
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self._onMetadataChange()

    @staticmethod
    def _flattenPrompt(prompt):
        """Return *prompt* with newlines, backslashes and ``XXXX`` markers turned into spaces."""
        # NOTE(review): the third replace is a no-op exactly as it appears in
        # the source; it may originally have collapsed double spaces - confirm.
        return (
            prompt.replace("\n", " ")
            .replace("\\", " ")
            .replace(" ", " ")
            .replace("XXXX", " ")
        )

    def _onMetadataChange(self):
        """Reload ``self.sampleData`` for the currently selected tables and columns."""
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=self.metadataLayout.getSelectedTablesAndCols(),
            maxRows=self.sampleDataRows,
        )

    def getMetadata(self) -> "MetaDataLayout":
        """Return the metadata layout currently in use."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and refresh the cached sample data."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def modifySqlQueryEnteredByUser(self, userSqlQuery):
        """Ask the GPT client to correct *userSqlQuery* for the configured platform.

        Returns whatever ``getResponseForUserInput`` returns for the request;
        the system prompt sent with it is empty.
        """
        userPrompt = f"Please correct the following sql query, also it has to be run on {self.platform}. sql query is \n {userSqlQuery}."
        return self.gptInstance.getResponseForUserInput(userPrompt, "")

    def filteredSampleDataForProspects(self, prospectTablesAndCols):
        """Restrict the cached sample data to the prospect tables and columns.

        Args:
            prospectTablesAndCols: Mapping of table name to a list of column
                names.

        Returns:
            A dict mapping each given table to ``self.sampleData[table]``
            indexed by that table's column list.

        Raises:
            KeyError: If a table is missing from ``self.sampleData``.
        """
        sampleData = self.sampleData
        return {
            table: sampleData[table][columns]
            for table, columns in prospectTablesAndCols.items()
        }

    def getQueryForUserInput(self, userInput, chatHistory=None):
        """Generate a SQL query for *userInput*.

        Args:
            userInput: The user's request text.
            chatHistory: Optional history forwarded to the GPT client for both
                the prospect lookup and the query generation. A fresh empty
                list is used when omitted.

        Returns:
            A ``(query, prospectTablesAndCols)`` tuple, where ``query`` is the
            GPT response after ``preProcessGptQueryReponse``.
        """
        # A literal [] default would be one list shared by every call.
        chatHistory = [] if chatHistory is None else chatHistory
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        prospectTablesAndCols = self.getProspectiveTablesAndCols(
            userInput, selectedTablesAndCols, chatHistory
        )
        print("getting prospects", prospectTablesAndCols)
        prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
        systemPrompt = self.getSystemPromptForQueryGeneration(
            prospectTablesData, gptSampleRows=self.gptSampleRows
        )
        queryByGpt = self.gptInstance.getResponseForUserInput(
            userInput, systemPrompt, chatHistory
        )
        queryByGpt = preProcessGptQueryReponse(queryByGpt, metadataLayout=self.metadataLayout)
        return queryByGpt, prospectTablesAndCols

    def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory=None):
        """Ask the GPT client which selected tables and columns fit *userInput*.

        The response is not parsed: a table is kept when its name occurs as a
        substring of the response text, and for a kept table each column is
        kept under the same substring test.

        Args:
            userInput: The user's request text.
            selectedTablesAndCols: Mapping of table name to column names.
            chatHistory: Optional history forwarded to the GPT client. A fresh
                empty list is used when omitted.

        Returns:
            A dict mapping each matched table to its matched columns (possibly
            an empty list), in the iteration order of *selectedTablesAndCols*.
        """
        # A literal [] default would be one list shared by every call.
        chatHistory = [] if chatHistory is None else chatHistory
        systemPrompt = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
        responseText = self.gptInstance.getResponseForUserInput(
            userInput, systemPrompt, chatHistory
        )
        return {
            table: [column for column in columns if column in responseText]
            for table, columns in selectedTablesAndCols.items()
            if table in responseText
        }

    def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
        """Build the system prompt used to generate a SQL query.

        The prompt names the schema and platform, includes one fixed example
        question/query pair, and appends the first *gptSampleRows* rows of each
        table in *prospectTablesData* (via ``.head``). The result is flattened
        onto a single line.
        """
        exampleQuery = (
            "SELECT a.customer_id, COUNT(a.product_id) as chandelier_count\n"
            "FROM lpdatamart.tbl_f_sales a\n"
            "JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id\n"
            "JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id\n"
            "WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023\n"
            "GROUP BY a.customer_id\n"
            "ORDER BY chandelier_count DESC"
        )
        question = "top 5 customers who bought most chandeliers in nov 2023"
        header = (
            f"Given an input text, generate the corresponding SQL query for given details. "
            f"Schema Name is {self.schemaName}. And sql platform is {self.platform}.\n following is sample data. Also\n"
            f"There is example user input and desired generated sql query. Follow similar patterns as example. "
            f"eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery} "
        )
        # "XXXX" marks the end of each table's section; _flattenPrompt turns it
        # into a space.
        tableSections = "".join(
            f"table name is {tableName}, table data is {tableData.head(gptSampleRows)}XXXX"
            for tableName, tableData in prospectTablesData.items()
        )
        return self._flattenPrompt(header + tableSections)

    def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
        """Build the system prompt used to shortlist relevant tables and columns.

        The prompt names the schema and platform and lists every table in
        *selectedTablesAndCols* followed by its comma-separated columns. The
        result is flattened onto a single line.
        """
        header = f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {self.schemaName}. And sql platform is {self.platform}.\n"""
        # "XXXX" marks the end of each table's section; _flattenPrompt turns it
        # into a space.
        tableSections = "".join(
            f"table name {tableName} {', '.join(columns)}XXXX"
            for tableName, columns in selectedTablesAndCols.items()
        )
        return self._flattenPrompt(header + tableSections)