Spaces:
Runtime error
Runtime error
File size: 4,793 Bytes
import re

from gptManager import ChatgptManager
from utils import MetaDataLayout
class QueryHelper:
    """Build prompts for a chat model to generate or correct SQL queries.

    Sample data for the tables/columns selected in ``metadataLayout`` is
    fetched once at construction (and again on ``updateMetadata``) and cached
    in ``self.sampleData``; it is embedded into the query-generation prompt.
    """

    # Project types are quoted so the annotations are not evaluated when the
    # class is defined; the module imports them at the top regardless.
    def __init__(self, gptInstance: "ChatgptManager", dbEngine, schemaName,
                 platform, metadataLayout: "MetaDataLayout", sampleDataRows,
                 gptSampleRows, getSampleDataForTablesAndCols):
        """Store collaborators and load the initial sample data.

        Args:
            gptInstance: client whose ``getResponseForUserInput`` is called
                for every prompt.
            dbEngine: database handle forwarded to
                ``getSampleDataForTablesAndCols``.
            schemaName: schema name quoted in prompts and used for sampling.
            platform: SQL platform name quoted in prompts.
            metadataLayout: object exposing ``getSelectedTablesAndCols()``.
            sampleDataRows: ``maxRows`` passed when fetching sample data.
            gptSampleRows: rows per table (via ``.head(n)``) put in the
                query-generation prompt.
            getSampleDataForTablesAndCols: callable invoked with keyword
                arguments ``dbEngine``, ``schemaName``, ``tablesAndCols`` and
                ``maxRows``; its result is cached as ``self.sampleData``.
        """
        self.gptInstance = gptInstance
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Re-fetch ``self.sampleData`` for the currently selected tables/columns."""
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=self.metadataLayout.getSelectedTablesAndCols(),
            maxRows=self.sampleDataRows,
        )

    def getMetadata(self) -> "MetaDataLayout":
        """Return the metadata layout currently in use."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and refresh the cached sample data."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def modifySqlQueryEnteredByUser(self, userSqlQuery):
        """Ask the model to correct ``userSqlQuery`` for ``self.platform``.

        Returns:
            Whatever ``gptInstance.getResponseForUserInput`` returns.
        """
        userPrompt = f"Please correct the following sql query, also it has to be run on {self.platform}. sql query is \n {userSqlQuery}."
        systemPrompt = ""
        return self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)

    def filteredSampleDataForProspects(self, prospectTablesAndCols):
        """Return cached sample data for each table key of ``prospectTablesAndCols``.

        All columns of each prospect table are kept on purpose (the
        per-column filter was disabled). Raises ``KeyError`` for a table
        that has no cached sample data.
        """
        return {table: self.sampleData[table] for table in prospectTablesAndCols}

    def getQueryForUserInput(self, userInput, chatHistory=None):
        """Generate a SQL query for ``userInput`` with two model calls.

        The first call picks prospective tables/columns; the second generates
        the query from their sample data. The same ``chatHistory`` list is
        forwarded to both calls.

        Args:
            userInput: natural-language request.
            chatHistory: prior conversation passed to the client; a fresh
                empty list is used when omitted.

        Returns:
            ``(query, prospectTablesAndCols)``, where ``query`` is the
            post-processed model reply.
        """
        # NOTE(review): preProcessGptQueryReponse is not defined or imported
        # anywhere in this module (the bare call raised NameError); presumably
        # it lives in utils next to MetaDataLayout -- confirm.
        from utils import preProcessGptQueryReponse

        if chatHistory is None:
            chatHistory = []
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols, chatHistory)
        print("getting prospects", prospectTablesAndCols)
        prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
        systemPrompt = self.getSystemPromptForQueryGeneration(prospectTablesData, gptSampleRows=self.gptSampleRows)
        queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPrompt, chatHistory)
        # Previously referenced an undefined local ``metadataLayout``.
        queryByGpt = preProcessGptQueryReponse(queryByGpt, metadataLayout=self.metadataLayout)
        return queryByGpt, prospectTablesAndCols

    def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory=None):
        """Ask the model which selected tables/columns are relevant to ``userInput``.

        Returns:
            A dict mapping each selected table whose name appears in the
            reply to the list of its selected columns that also appear in it.
        """
        if chatHistory is None:
            chatHistory = []
        systemPrompt = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
        replyText = self.gptInstance.getResponseForUserInput(userInput, systemPrompt, chatHistory)
        # NOTE(review): plain substring matching -- "order" also matches
        # "orders", and short column names such as "id" match most replies.
        return {
            table: [column for column in columns if column in replyText]
            for table, columns in selectedTablesAndCols.items()
            if table in replyText
        }

    def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
        """Build the system prompt for generating the SQL query.

        Each value of ``prospectTablesData`` must provide ``.head(n)``; the
        first ``gptSampleRows`` rows of each are embedded as text.
        """
        header = (
            "Given an input text, generate the corresponding SQL query for given details. "
            f"Schema Name is {self.schemaName}. And sql platform is {self.platform}.\n following is sample data"
        )
        # "XXXX" marks the end of each table section; _compactPrompt turns it
        # into a single space after squeezing whitespace.
        tableSections = "".join(
            f"table name is {tableName}, table data is {tableData.head(gptSampleRows)}XXXX"
            for tableName, tableData in prospectTablesData.items()
        )
        # The space keeps "sample data" from running into "table name".
        return self._compactPrompt(f"{header} {tableSections}")

    def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
        """Build the system prompt asking which tables/columns hold the desired data."""
        header = (
            "Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. "
            f"Schema Name is {self.schemaName}. And sql platform is {self.platform}.\n"
        )
        tableSections = "".join(
            f"table name {tableName} {', '.join(columns)}XXXX"
            for tableName, columns in selectedTablesAndCols.items()
        )
        return self._compactPrompt(header + tableSections)

    @staticmethod
    def _compactPrompt(prompt):
        """Flatten ``prompt`` onto one line with single spaces.

        Newlines and backslashes become spaces (presumably the line
        continuations of a wrapped table repr), runs of spaces collapse to
        one, and only then are the "XXXX" section markers turned into spaces.
        """
        flattened = prompt.replace("\n", " ").replace("\\", " ")
        squeezed = re.sub(" {2,}", " ", flattened)
        return squeezed.replace("XXXX", " ")
|