from gptManager import ChatgptManager
from utils import *
import json
from constants import TABLE_RELATIONS


class QueryHelper:
    """Turn a natural-language request into a SQL query using two GPT instances.

    The first instance (``gptInstanceForTableCols``) is asked which tables and
    columns are relevant to the user input. The second
    (``gptInstanceForQuery``) is then primed with sample data for only those
    tables/columns and asked to produce the SQL query.
    """

    # Few-shot example embedded verbatim in the query-generation system prompt.
    _EXAMPLE_QUESTION = "top 5 customers who bought most chandeliers in nov 2023"
    _EXAMPLE_QUERY = (
        "SELECT a.customer_id, COUNT(a.product_id) as chandelier_count "
        "FROM lpdatamart.tbl_f_sales a "
        "JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id "
        "JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id "
        "WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' "
        "AND c.calendar_month = 'NOVEMBER' AND c.year = 2023 "
        "GROUP BY a.customer_id ORDER BY chandelier_count DESC"
    )

    def __init__(self,
                 gptInstanceForTableCols: ChatgptManager,
                 gptInstanceForQuery: ChatgptManager,
                 dbEngine,
                 schemaName,
                 platform,
                 metadataLayout: MetaDataLayout,
                 sampleDataRows,
                 gptSampleRows,
                 getSampleDataForTablesAndCols,
                 tableSummaryJson='tableSummaryDict.json'):
        """Store the collaborators and build the initial table/column prompt.

        Args:
            gptInstanceForTableCols: GPT wrapper used to pick relevant tables
                and columns for a user request.
            gptInstanceForQuery: GPT wrapper used to generate the SQL query.
            dbEngine: Passed through unchanged to
                ``getSampleDataForTablesAndCols``.
            schemaName: Schema name; forwarded to the sample-data loader and
                interpolated into the query prompt.
            platform: Name of the SQL platform the query must run on;
                interpolated into the query prompt.
            metadataLayout: Provides ``getSelectedTablesAndCols()``.
            sampleDataRows: Forwarded to the sample-data loader as ``maxRows``.
            gptSampleRows: Number of sample rows per table shown to GPT
                (passed to ``.head(...)``).
            getSampleDataForTablesAndCols: Callable invoked with keyword
                arguments ``dbEngine``, ``schemaName``, ``tablesAndCols`` and
                ``maxRows``; its result is stored as ``self.sampleData``.
            tableSummaryJson: Path to a JSON file indexed by table name whose
                values are table summaries.
        """
        self.gptInstanceForTableCols = gptInstanceForTableCols
        self.gptInstanceForQuery = gptInstanceForQuery
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self.tableSummaryJson = tableSummaryJson
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Refresh the sample data and the table/column system prompt.

        Called from ``__init__`` and whenever the metadata layout is replaced.
        """
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=selectedTablesAndCols,
            maxRows=self.sampleDataRows,
        )
        self.promptTableColsInfo = self.getSystemPromptForTableCols()
        self.gptInstanceForTableCols.setSystemPrompt(self.promptTableColsInfo)

    def _loadTableSummary(self):
        """Return the parsed contents of ``self.tableSummaryJson``.

        The file is opened in a ``with`` block so the handle is always closed.
        """
        with open(self.tableSummaryJson, 'r') as summaryFile:
            return json.load(summaryFile)

    def getMetadata(self) -> MetaDataLayout:
        """Return the metadata layout currently in use."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and rebuild everything derived from it."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def getQueryForUserInput(self, userInput):
        """Return the query-GPT response for ``userInput``.

        The table/column GPT instance is queried first. Every selected table
        whose name occurs in that response is kept, together with those of its
        columns whose names also occur in the response. Matching is plain
        substring containment on the response text.
        """
        prospectTablesAndColsText = self.gptInstanceForTableCols.getResponseForUserInput(userInput)
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()

        prospectTablesAndCols = {
            table: [
                col
                for col in selectedTablesAndCols[table]
                if col in prospectTablesAndColsText
            ]
            for table in selectedTablesAndCols
            if table in prospectTablesAndColsText
        }
        print("tables and cols select by gpt", prospectTablesAndCols)

        promptForQuery = self.getSystemPromptForQuery(prospectTablesAndCols)
        self.gptInstanceForQuery.setSystemPrompt(promptForQuery)
        return self.gptInstanceForQuery.getResponseForUserInput(userInput)

    def getSystemPromptForTableCols(self):
        """Build the system prompt used to pick relevant tables and columns.

        For every selected table the prompt lists its name, its summary from
        the table-summary JSON file and its selected columns, followed by the
        table relations from ``TABLE_RELATIONS``.

        Raises:
            KeyError: If a selected table has no entry in the summary file.
        """
        tableSummaryDict = self._loadTableSummary()
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()

        promptParts = [
            "You are a powerful text to sql model. "
            "Answer which tables and columns are needed to answer user input "
            "using sql query. "
            "and following are tables and columns info. "
            "and example user input and result query."
        ]
        for tableName in selectedTablesAndCols:
            promptParts.append(
                f"table name {tableName} and summary is {tableSummaryDict[tableName]}"
            )
            promptParts.append(
                f" and columns {', '.join(selectedTablesAndCols[tableName])} \n"
            )

        # NOTE(review): this marker is emitted as-is in this prompt; only
        # getSystemPromptForQuery strips "XXXX" from its own output - confirm
        # whether it is meant to reach the model.
        promptParts.append("XXXX")

        # Join statements
        promptParts.append(f"and table Relations are {TABLE_RELATIONS}")
        return "".join(promptParts)

    def getSystemPromptForQuery(self, prospectTablesAndCols):
        """Build the system prompt used to generate the SQL query.

        Args:
            prospectTablesAndCols: Mapping of table name to the list of column
                names to expose to the model.

        Returns:
            The prompt text: instructions, one worked example, the first
            ``self.gptSampleRows`` rows of sample data for each requested
            table/column set, and the table relations.
        """
        promptParts = [
            "You are a powerful text to sql model. "
            "Answer user input with sql query. "
            f"And the query needs to run on {self.platform}. "
            f"and schemaName is {self.schemaName}. "
            "There is example user input and desired generated sql query. "
            "Follow similar patterns as example. \n"
            "eg case insensitive, explicit variable declaration etc. "
            f"user input : {self._EXAMPLE_QUESTION}, query : {self._EXAMPLE_QUERY}. "
            "and table's data is \n"
        ]
        for tableName, columns in prospectTablesAndCols.items():
            # presumably self.sampleData maps table name -> pandas DataFrame
            # (column-list indexing followed by .head) - TODO confirm
            sampleRows = self.sampleData[tableName][columns].head(self.gptSampleRows)
            promptParts.append(
                f"table name is {tableName}, table data is {sampleRows}"
            )
        promptParts.append(f"and table Relations are {TABLE_RELATIONS}")

        promptForQuery = "".join(promptParts)
        # NOTE(review): the middle replace swaps a single space for a single
        # space, i.e. it is a no-op as written; it may have been meant to
        # collapse double spaces - confirm before changing.
        return promptForQuery.replace("\\", " ").replace(" ", " ").replace("XXXX", " ")