Spaces:
Runtime error
Runtime error
File size: 5,500 Bytes
1dda07c c5ff675 1dda07c 9aff6be 1dda07c 8819e5a 1dda07c fb64f16 a978b21 fb64f16 a978b21 1dda07c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 | from gptManager import ChatgptManager
from utils import *
class QueryHelper:
    """Generate SQL from natural-language questions by prompting a GPT backend.

    The helper holds the user's table/column selection (``metadataLayout``)
    and a cached sample of those tables' rows (``sampleData``). It uses them to
    build system prompts for two GPT round trips. The first asks which
    tables/columns look relevant; the second asks for the SQL query itself.
    """

    # Worked example embedded in the query-generation system prompt so the
    # model follows its patterns.
    # NOTE(review): the example question asks for the "top 5" but the query has
    # no LIMIT clause — confirm whether that is intentional.
    _EXAMPLE_QUESTION = "top 5 customers who bought most chandeliers in nov 2023"
    _EXAMPLE_QUERY = (
        "SELECT a.customer_id, COUNT(a.product_id) as chandelier_count\n"
        "FROM lpdatamart.tbl_f_sales a\n"
        "JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id\n"
        "JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id\n"
        "WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023\n"
        "GROUP BY a.customer_id\n"
        "ORDER BY chandelier_count DESC"
    )

    def __init__(self, gptInstance: "ChatgptManager", dbEngine, schemaName,
                 platform, metadataLayout: "MetaDataLayout", sampleDataRows,
                 gptSampleRows, getSampleDataForTablesAndCols):
        """Store configuration and load the initial sample data.

        Args:
            gptInstance: client whose ``getResponseForUserInput(user, system[, history])``
                returns the model's text reply.
            dbEngine: database handle forwarded to ``getSampleDataForTablesAndCols``.
            schemaName: schema name, quoted in prompts and used for sampling.
            platform: SQL dialect/platform name, quoted in prompts.
            metadataLayout: object exposing ``getSelectedTablesAndCols()``.
            sampleDataRows: ``maxRows`` passed to the sampler.
            gptSampleRows: number of sample rows (``.head(n)``) shown to GPT per table.
            getSampleDataForTablesAndCols: callable returning a mapping of table name
                to sample data (presumably pandas DataFrames — confirm with caller).
        """
        self.gptInstance = gptInstance
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Re-fetch ``self.sampleData`` for the currently selected tables/columns."""
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=self.metadataLayout.getSelectedTablesAndCols(),
            maxRows=self.sampleDataRows,
        )

    def getMetadata(self) -> "MetaDataLayout":
        """Return the current metadata layout."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and refresh the cached sample data."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def modifySqlQueryEnteredByUser(self, userSqlQuery):
        """Ask GPT to correct a user-written SQL query for ``self.platform``.

        Returns the raw model reply (no post-processing is applied here).
        """
        userPrompt = f"Please correct the following sql query, also it has to be run on {self.platform}. sql query is \n {userSqlQuery}."
        return self.gptInstance.getResponseForUserInput(userPrompt, "")

    def filteredSampleDataForProspects(self, prospectTablesAndCols):
        """Restrict the cached sample data to the prospective tables/columns.

        Args:
            prospectTablesAndCols: mapping of table name to the list of column
                names to keep; every table must be present in ``self.sampleData``.

        Returns:
            dict mapping each table name to ``sampleData[table][columns]``.
        """
        return {
            table: self.sampleData[table][columns]
            for table, columns in prospectTablesAndCols.items()
        }

    def getQueryForUserInput(self, userInput, chatHistory=None):
        """Generate a SQL query for ``userInput``.

        Args:
            userInput: the user's natural-language question.
            chatHistory: prior conversation passed through to both GPT calls;
                ``None`` (the default) means a fresh empty history.

        Returns:
            ``(query, prospectTablesAndCols)``: the post-processed query text and
            the tables/columns GPT considered relevant.
        """
        # A fresh list per call: a shared mutable default would leak history
        # between calls if the GPT client ever mutates it.
        if chatHistory is None:
            chatHistory = []
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols, chatHistory)
        print("getting prospects", prospectTablesAndCols)
        prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
        systemPrompt = self.getSystemPromptForQueryGeneration(prospectTablesData, gptSampleRows=self.gptSampleRows)
        queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPrompt, chatHistory)
        queryByGpt = preProcessGptQueryReponse(queryByGpt, metadataLayout=self.metadataLayout)
        return queryByGpt, prospectTablesAndCols

    @staticmethod
    def _mentions(text, name):
        """Return True if ``name`` occurs in ``text`` as a whole identifier.

        Unlike a plain substring test, ``id`` does not match inside
        ``customer_id``. Punctuation such as quotes or dots still counts as a
        boundary.
        """
        import re
        return re.search(rf"(?<!\w){re.escape(name)}(?!\w)", text) is not None

    def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory=None):
        """Ask GPT which of the selected tables/columns are relevant to ``userInput``.

        The reply is scanned as free text. Each selected table named in it is
        kept, together with those of its selected columns that are named
        anywhere in the reply. The column does not need to appear next to the
        table. A table with no matching columns maps to an empty list.

        Returns:
            dict mapping table name to list of column names.
        """
        if chatHistory is None:
            chatHistory = []
        systemPrompt = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
        replyText = self.gptInstance.getResponseForUserInput(userInput, systemPrompt, chatHistory)
        return {
            table: [column for column in columns if self._mentions(replyText, column)]
            for table, columns in selectedTablesAndCols.items()
            if self._mentions(replyText, table)
        }

    @staticmethod
    def _flattenPrompt(prompt):
        """Put a prompt on one line.

        Newlines, backslashes and the ``XXXX`` table separator each become a
        single space.
        """
        return prompt.replace("\n", " ").replace("\\", " ").replace("XXXX", " ")

    def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
        """Build the system prompt for SQL generation.

        The prompt contains the schema and platform, a worked example, and the
        first ``gptSampleRows`` rows (via ``.head``) of each prospective table.
        """
        prompt = (
            "Given an input text, generate the corresponding SQL query for given details. "
            f"Schema Name is {self.schemaName}. And sql platform is {self.platform}.\n"
            " following is sample data. Also\n"
            "There is example user input and desired generated sql query. "
            "Follow similar patterns as example. eg case insensitive, explicit variable "
            f"declaration etc. user input : {self._EXAMPLE_QUESTION}, query : {self._EXAMPLE_QUERY} "
        )
        tableSections = [
            f"table name is {tableName}, table data is {tableData.head(gptSampleRows)}XXXX"
            for tableName, tableData in prospectTablesData.items()
        ]
        return self._flattenPrompt(prompt + "".join(tableSections))

    def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
        """Build the system prompt asking GPT to pick relevant tables/columns.

        Each selected table is listed with its comma-separated column names.
        """
        prompt = (
            "Given an input text, User wants to know which all tables and columns would be possibily "
            "to have the desired data. Output them as json. "
            f"Schema Name is {self.schemaName}. And sql platform is {self.platform}.\n"
        )
        tableSections = [
            f"table name {tableName} {', '.join(columns)}XXXX"
            for tableName, columns in selectedTablesAndCols.items()
        ]
        return self._flattenPrompt(prompt + "".join(tableSections))
|