anumaurya114exp commited on
Commit
2a4b462
·
1 Parent(s): 8863139

Upload queryHelperManagerCoT.py

Browse files
Files changed (1) hide show
  1. queryHelperManagerCoT.py +135 -0
queryHelperManagerCoT.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class QueryHelperChainOfThought:
2
+ def __init__(self, gptInstance: ChatgptManager, dbEngine, schemaName,
3
+ platform, metadataLayout: MetaDataLayout, sampleDataRows,
4
+ gptSampleRows, getSampleDataForTablesAndCols):
5
+ self.gptInstance = gptInstance
6
+ self.schemaName = schemaName
7
+ self.platform = platform
8
+ self.metadataLayout = metadataLayout
9
+ self.sampleDataRows = sampleDataRows
10
+ self.gptSampleRows = gptSampleRows
11
+ self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
12
+ self.dbEngine = dbEngine
13
+ self._onMetadataChange()
14
+
15
+ def _onMetadataChange(self):
16
+ metadataLayout = self.metadataLayout
17
+ sampleDataRows = self.sampleDataRows
18
+ dbEngine = self.dbEngine
19
+ schemaName = self.schemaName
20
+
21
+ selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()
22
+ self.sampleData = self.getSampleDataForTablesAndCols(dbEngine=dbEngine,schemaName=schemaName,
23
+ tablesAndCols=selectedTablesAndCols, maxRows=sampleDataRows)
24
+
25
+ def getMetadata(self) -> MetaDataLayout :
26
+ return self.metadataLayout
27
+
28
+ def updateMetadata(self, metadataLayout):
29
+ self.metadataLayout = metadataLayout
30
+ self._onMetadataChange()
31
+
32
+ def modifySqlQueryEnteredByUser(self, userSqlQuery):
33
+ platform = self.platform
34
+ userPrompt = f"Please correct the following sql query, also it has to be run on {platform}. sql query is \n {userSqlQuery}."
35
+ systemPrompt = ""
36
+ modifiedSql = self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)
37
+ return modifiedSql
38
+
39
+ def filteredSampleDataForProspects(self, prospectTablesAndCols):
40
+ sampleData = self.sampleData
41
+ filteredData = {}
42
+ for table in prospectTablesAndCols.keys():
43
+ # filteredData[table] = sampleData[table][prospectTablesAndCols[table]]
44
+ #take all columns of prospects
45
+ filteredData[table] = sampleData[table][prospectTablesAndCols[table]]
46
+ return filteredData
47
+
48
+ def extractSingleJson(self, text):
49
+ pattern = r'\{.*?\}'
50
+ matches = re.findall(pattern, text, re.DOTALL)
51
+ extracted_json = [json.loads(match) for match in matches][0]
52
+ return extracted_json
53
+
54
+ def getQueryForUserInputCoT(self, userInput):
55
+ #1. Is the input complete to create a query, or ask user to reask with more detailed input
56
+ systemPromptForInputClarification = """Given an input text, user want to generate sql query. Please answer if the user input is complete or user needs to ask in more detailed way. Answer in following format. 'Yes' ; if yes, break the userinput into smaller subtask for query generation. Formatted into
57
+ {
58
+ "Task 1": "task 1 description",
59
+ "Task 2": "task 2 description"
60
+ }
61
+ 'No' ; if no, then Reason- please be more detailed about customer details; if more modification needed"""
62
+ cotStep1 = self.gptInstance.getResponseForUserInput(userInput, systemPromptForInputClarification, chatHistory)
63
+ if "yes" in cot1.lower()[:5]:
64
+ print("User input sufficient")
65
+ tasks = self.extractSingleJson(cotStep1)
66
+ print(f"tasks are {tasks})
67
+ taskQueries = {}
68
+ for key, task in tasks.items():
69
+ taskQuery = self.getQueryForUserInput(userInput)
70
+ taskQueries[key] = {"task":task, "taskQuery":taskQuery}
71
+ print(f"tasks and their queries {taskQueries}")
72
+
73
+ combiningSubtasksQueryPrompt = f"""Combine following subtask and their queries to generate sql query to answer the user input.\n """
74
+ userPrompt = f"user input is {userInput}"
75
+ for key in taskQueries.keys():
76
+ task = taskQueries[key]["task"]
77
+ query = taskQueries[key]["taskQuery"]
78
+ userPrompt += f" task: {task}, task query: {query}"
79
+ return self.self.gptInstance.getResponseForUserInput(userPrompt, combiningSubtasksQueryPrompt)
80
+ return f"Please rephrase your query. {' '.join(cot1.split('Reason')[1:])}
81
+
82
+ def getQueryForUserInput(self, userInput, chatHistory=[]):
83
+ gptSampleRows = self.gptSampleRows
84
+ selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
85
+ prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols, chatHistory)
86
+ print("getting prospects", prospectTablesAndCols)
87
+ prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
88
+ systemPromptForQueryGeneration = self.getSystemPromptForQueryGeneration(prospectTablesData, gptSampleRows=gptSampleRows)
89
+ queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPromptForQueryGeneration, chatHistory)
90
+
91
+ queryByGpt = preProcessGptQueryReponse(queryByGpt, metadataLayout=self.metadataLayout)
92
+ return queryByGpt, prospectTablesAndCols
93
+
94
+ def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory=[]):
95
+ schemaName = self.schemaName
96
+
97
+ systemPromptForProspectColumns = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
98
+ prospectiveTablesColsText = self.gptInstance.getResponseForUserInput(userInput, systemPromptForProspectColumns, chatHistory)
99
+ prospectTablesAndCols = {}
100
+ for table in selectedTablesAndCols.keys():
101
+ if table in prospectiveTablesColsText:
102
+ prospectTablesAndCols[table] = []
103
+ for column in selectedTablesAndCols[table]:
104
+ if column in prospectiveTablesColsText:
105
+ prospectTablesAndCols[table].append(column)
106
+ return prospectTablesAndCols
107
+
108
+ def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
109
+ schemaName = self.schemaName
110
+ platform = self.platform
111
+ exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
112
+ FROM lpdatamart.tbl_f_sales a
113
+ JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
114
+ JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id
115
+ WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023
116
+ GROUP BY a.customer_id
117
+ ORDER BY chandelier_count DESC"""
118
+
119
+ question = "top 5 customers who bought most chandeliers in nov 2023"
120
+ prompt = f"""Given an input text, generate the corresponding SQL query for given details. Schema Name is {schemaName}. And sql platform is {platform}.\n following is sample data. Also
121
+ There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery} """
122
+ for idx, tableName in enumerate(prospectTablesData.keys(), start=1):
123
+ prompt += f"table name is {tableName}, table data is {prospectTablesData[tableName].head(gptSampleRows)}"
124
+ prompt += "XXXX"
125
+ return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")
126
+
127
+ def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
128
+ schemaName = self.schemaName
129
+ platform = self.platform
130
+
131
+ prompt = f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {schemaName}. And sql platform is {platform}.\n"""
132
+ for idx, tableName in enumerate(selectedTablesAndCols.keys(), start=1):
133
+ prompt += f"table name {tableName} {', '.join(selectedTablesAndCols[tableName])}"
134
+ prompt += "XXXX"
135
+ return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")