Spaces:
Runtime error
Runtime error
Commit ·
2a4b462
1
Parent(s): 8863139
Upload queryHelperManagerCoT.py
Browse files- queryHelperManagerCoT.py +135 -0
queryHelperManagerCoT.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class QueryHelperChainOfThought:
|
| 2 |
+
def __init__(self, gptInstance: ChatgptManager, dbEngine, schemaName,
|
| 3 |
+
platform, metadataLayout: MetaDataLayout, sampleDataRows,
|
| 4 |
+
gptSampleRows, getSampleDataForTablesAndCols):
|
| 5 |
+
self.gptInstance = gptInstance
|
| 6 |
+
self.schemaName = schemaName
|
| 7 |
+
self.platform = platform
|
| 8 |
+
self.metadataLayout = metadataLayout
|
| 9 |
+
self.sampleDataRows = sampleDataRows
|
| 10 |
+
self.gptSampleRows = gptSampleRows
|
| 11 |
+
self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
|
| 12 |
+
self.dbEngine = dbEngine
|
| 13 |
+
self._onMetadataChange()
|
| 14 |
+
|
| 15 |
+
def _onMetadataChange(self):
|
| 16 |
+
metadataLayout = self.metadataLayout
|
| 17 |
+
sampleDataRows = self.sampleDataRows
|
| 18 |
+
dbEngine = self.dbEngine
|
| 19 |
+
schemaName = self.schemaName
|
| 20 |
+
|
| 21 |
+
selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()
|
| 22 |
+
self.sampleData = self.getSampleDataForTablesAndCols(dbEngine=dbEngine,schemaName=schemaName,
|
| 23 |
+
tablesAndCols=selectedTablesAndCols, maxRows=sampleDataRows)
|
| 24 |
+
|
| 25 |
+
def getMetadata(self) -> MetaDataLayout :
|
| 26 |
+
return self.metadataLayout
|
| 27 |
+
|
| 28 |
+
def updateMetadata(self, metadataLayout):
|
| 29 |
+
self.metadataLayout = metadataLayout
|
| 30 |
+
self._onMetadataChange()
|
| 31 |
+
|
| 32 |
+
def modifySqlQueryEnteredByUser(self, userSqlQuery):
|
| 33 |
+
platform = self.platform
|
| 34 |
+
userPrompt = f"Please correct the following sql query, also it has to be run on {platform}. sql query is \n {userSqlQuery}."
|
| 35 |
+
systemPrompt = ""
|
| 36 |
+
modifiedSql = self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)
|
| 37 |
+
return modifiedSql
|
| 38 |
+
|
| 39 |
+
def filteredSampleDataForProspects(self, prospectTablesAndCols):
|
| 40 |
+
sampleData = self.sampleData
|
| 41 |
+
filteredData = {}
|
| 42 |
+
for table in prospectTablesAndCols.keys():
|
| 43 |
+
# filteredData[table] = sampleData[table][prospectTablesAndCols[table]]
|
| 44 |
+
#take all columns of prospects
|
| 45 |
+
filteredData[table] = sampleData[table][prospectTablesAndCols[table]]
|
| 46 |
+
return filteredData
|
| 47 |
+
|
| 48 |
+
def extractSingleJson(self, text):
|
| 49 |
+
pattern = r'\{.*?\}'
|
| 50 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
| 51 |
+
extracted_json = [json.loads(match) for match in matches][0]
|
| 52 |
+
return extracted_json
|
| 53 |
+
|
| 54 |
+
def getQueryForUserInputCoT(self, userInput):
|
| 55 |
+
#1. Is the input complete to create a query, or ask user to reask with more detailed input
|
| 56 |
+
systemPromptForInputClarification = """Given an input text, user want to generate sql query. Please answer if the user input is complete or user needs to ask in more detailed way. Answer in following format. 'Yes' ; if yes, break the userinput into smaller subtask for query generation. Formatted into
|
| 57 |
+
{
|
| 58 |
+
"Task 1": "task 1 description",
|
| 59 |
+
"Task 2": "task 2 description"
|
| 60 |
+
}
|
| 61 |
+
'No' ; if no, then Reason- please be more detailed about customer details; if more modification needed"""
|
| 62 |
+
cotStep1 = self.gptInstance.getResponseForUserInput(userInput, systemPromptForInputClarification, chatHistory)
|
| 63 |
+
if "yes" in cot1.lower()[:5]:
|
| 64 |
+
print("User input sufficient")
|
| 65 |
+
tasks = self.extractSingleJson(cotStep1)
|
| 66 |
+
print(f"tasks are {tasks})
|
| 67 |
+
taskQueries = {}
|
| 68 |
+
for key, task in tasks.items():
|
| 69 |
+
taskQuery = self.getQueryForUserInput(userInput)
|
| 70 |
+
taskQueries[key] = {"task":task, "taskQuery":taskQuery}
|
| 71 |
+
print(f"tasks and their queries {taskQueries}")
|
| 72 |
+
|
| 73 |
+
combiningSubtasksQueryPrompt = f"""Combine following subtask and their queries to generate sql query to answer the user input.\n """
|
| 74 |
+
userPrompt = f"user input is {userInput}"
|
| 75 |
+
for key in taskQueries.keys():
|
| 76 |
+
task = taskQueries[key]["task"]
|
| 77 |
+
query = taskQueries[key]["taskQuery"]
|
| 78 |
+
userPrompt += f" task: {task}, task query: {query}"
|
| 79 |
+
return self.self.gptInstance.getResponseForUserInput(userPrompt, combiningSubtasksQueryPrompt)
|
| 80 |
+
return f"Please rephrase your query. {' '.join(cot1.split('Reason')[1:])}
|
| 81 |
+
|
| 82 |
+
def getQueryForUserInput(self, userInput, chatHistory=[]):
|
| 83 |
+
gptSampleRows = self.gptSampleRows
|
| 84 |
+
selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
|
| 85 |
+
prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols, chatHistory)
|
| 86 |
+
print("getting prospects", prospectTablesAndCols)
|
| 87 |
+
prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
|
| 88 |
+
systemPromptForQueryGeneration = self.getSystemPromptForQueryGeneration(prospectTablesData, gptSampleRows=gptSampleRows)
|
| 89 |
+
queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPromptForQueryGeneration, chatHistory)
|
| 90 |
+
|
| 91 |
+
queryByGpt = preProcessGptQueryReponse(queryByGpt, metadataLayout=self.metadataLayout)
|
| 92 |
+
return queryByGpt, prospectTablesAndCols
|
| 93 |
+
|
| 94 |
+
def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory=[]):
|
| 95 |
+
schemaName = self.schemaName
|
| 96 |
+
|
| 97 |
+
systemPromptForProspectColumns = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
|
| 98 |
+
prospectiveTablesColsText = self.gptInstance.getResponseForUserInput(userInput, systemPromptForProspectColumns, chatHistory)
|
| 99 |
+
prospectTablesAndCols = {}
|
| 100 |
+
for table in selectedTablesAndCols.keys():
|
| 101 |
+
if table in prospectiveTablesColsText:
|
| 102 |
+
prospectTablesAndCols[table] = []
|
| 103 |
+
for column in selectedTablesAndCols[table]:
|
| 104 |
+
if column in prospectiveTablesColsText:
|
| 105 |
+
prospectTablesAndCols[table].append(column)
|
| 106 |
+
return prospectTablesAndCols
|
| 107 |
+
|
| 108 |
+
def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
|
| 109 |
+
schemaName = self.schemaName
|
| 110 |
+
platform = self.platform
|
| 111 |
+
exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
|
| 112 |
+
FROM lpdatamart.tbl_f_sales a
|
| 113 |
+
JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
|
| 114 |
+
JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id
|
| 115 |
+
WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023
|
| 116 |
+
GROUP BY a.customer_id
|
| 117 |
+
ORDER BY chandelier_count DESC"""
|
| 118 |
+
|
| 119 |
+
question = "top 5 customers who bought most chandeliers in nov 2023"
|
| 120 |
+
prompt = f"""Given an input text, generate the corresponding SQL query for given details. Schema Name is {schemaName}. And sql platform is {platform}.\n following is sample data. Also
|
| 121 |
+
There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery} """
|
| 122 |
+
for idx, tableName in enumerate(prospectTablesData.keys(), start=1):
|
| 123 |
+
prompt += f"table name is {tableName}, table data is {prospectTablesData[tableName].head(gptSampleRows)}"
|
| 124 |
+
prompt += "XXXX"
|
| 125 |
+
return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")
|
| 126 |
+
|
| 127 |
+
def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
|
| 128 |
+
schemaName = self.schemaName
|
| 129 |
+
platform = self.platform
|
| 130 |
+
|
| 131 |
+
prompt = f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {schemaName}. And sql platform is {platform}.\n"""
|
| 132 |
+
for idx, tableName in enumerate(selectedTablesAndCols.keys(), start=1):
|
| 133 |
+
prompt += f"table name {tableName} {', '.join(selectedTablesAndCols[tableName])}"
|
| 134 |
+
prompt += "XXXX"
|
| 135 |
+
return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")
|