anumaurya114exp commited on
Commit
0ff15a5
·
1 Parent(s): 1d61200

new cot and with history query helper

Browse files
Files changed (5) hide show
  1. app.py +13 -22
  2. constants.py +37 -2
  3. gptManager.py +12 -59
  4. queryHelperManager.py +47 -59
  5. queryHelperManagerCoT.py +91 -107
app.py CHANGED
@@ -35,8 +35,10 @@ selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()
35
 
36
 
37
  openAIClient = OpenAI(api_key=OPENAI_API_KEY)
38
- gptInstance = ChatgptManager(openAIClient, model=GPT_MODEL)
39
- queryHelper = QueryHelper(gptInstance=gptInstance,
 
 
40
  schemaName=SCHEMA_NAME,platform=PLATFORM,
41
  metadataLayout=metadataLayout,
42
  sampleDataRows=SAMPLE_ROW_MAX,
@@ -46,8 +48,8 @@ queryHelper = QueryHelper(gptInstance=gptInstance,
46
 
47
 
48
  openAIClient2 = OpenAI(api_key=OPENAI_API_KEY)
49
- gptInstance2 = ChatgptManager(openAIClient2, model=GPT_MODEL)
50
- queryHelperCot = QueryHelperChainOfThought(gptInstance=gptInstance2,
51
  schemaName=SCHEMA_NAME,platform=PLATFORM,
52
  metadataLayout=metadataLayout,
53
  sampleDataRows=SAMPLE_ROW_MAX,
@@ -68,16 +70,15 @@ def respond(message, chatHistory):
68
  """gpt response handler for gradio ui"""
69
  global queryHelper
70
  try:
71
- botMessage, prospectTablesAndCols = queryHelper.getQueryForUserInput(message, chatHistory)
72
  except Exception as e:
73
  errorMessage = {"function":"queryHelper.getQueryForUserInput","error":str(e), "userInput":message}
74
  saveLog(errorMessage, 'error')
75
  raise ValueError(str(e))
76
  queryGenerated = extractSqlFromGptResponse(botMessage)
77
- logMessage = {"userInput":message, "tablesColsSelectedByGpt":str(prospectTablesAndCols) , "queryGenerated":queryGenerated, "completeGptResponse":botMessage, "function":"queryHelper.getQueryForUserInput"}
78
  saveLog(logMessage)
79
  chatHistory.append((message, botMessage))
80
- time.sleep(2)
81
  return "", chatHistory
82
 
83
  # Function to save history of chat
@@ -85,19 +86,14 @@ def respondCoT(message, chatHistory):
85
  """gpt response handler for gradio ui"""
86
  global queryHelperCot
87
  try:
88
- if "modify" in message[:12].lower():
89
- botMessage, prospectTablesAndCols = queryHelperCot.getQueryForUserInput(message, chatHistory)
90
- else:
91
- botMessage, prospectTablesAndCols = queryHelperCot.getQueryForUserInputCoT(message)
92
  except Exception as e:
93
- errorMessage = {"function":"queryHelperFineTuned.getQueryForUserInput","error":str(e), "userInput":message}
94
- saveLog(errorMessage, 'error')
95
- raise ValueError(str(e))
96
- queryGenerated = extractSqlFromGptResponse(botMessage)
97
- logMessage = {"userInput":message, "tablesColsSelectedByGpt":str(prospectTablesAndCols) , "queryGenerated":queryGenerated, "completeGptResponse":botMessage, "function":"queryHelperCot.getQueryForUserInputCoT"}
98
  saveLog(logMessage)
99
  chatHistory.append((message, botMessage))
100
- time.sleep(2)
101
  return "", chatHistory
102
 
103
 
@@ -131,11 +127,6 @@ def testSQL(sql):
131
  dbEngine2.disconnect()
132
 
133
  print(f"Error occured during running the query {sql}.\n and the error is {str(e)}")
134
-
135
- # prompt = f"Please correct the following sql query, also it has to be run on {PLATFORM}. sql query is \n {sql}. the error occured is {str(e)}."
136
- # modifiedSql = queryHelper.modifySqlQueryEnteredByUser(prompt)
137
- # logMessage = {"function":"queryHelper.modifySqlQueryEnteredByUser", "sqlQuery":sql, "modifiedSQLQuery":modifiedSql}
138
- # saveLog(logMessage, 'info')
139
  return f"The query you entered throws some error. Here is the error.\n {str(e)}"
140
 
141
 
 
35
 
36
 
37
  openAIClient = OpenAI(api_key=OPENAI_API_KEY)
38
+ gptInstanceForTableCols = ChatgptManager(openAIClient, model=GPT_MODEL)
39
+ gptInstanceForQuery = ChatgptManager(openAIClient, model=GPT_MODEL)
40
+ queryHelper = QueryHelper(gptInstanceForTableCols=gptInstanceForTableCols,
41
+ gptInstanceForQuery=gptInstanceForQuery,
42
  schemaName=SCHEMA_NAME,platform=PLATFORM,
43
  metadataLayout=metadataLayout,
44
  sampleDataRows=SAMPLE_ROW_MAX,
 
48
 
49
 
50
  openAIClient2 = OpenAI(api_key=OPENAI_API_KEY)
51
+ gptInstanceForCoT = ChatgptManager(openAIClient2, model=GPT_MODEL)
52
+ queryHelperCot = QueryHelperChainOfThought(gptInstanceForCoT=gptInstanceForCoT,
53
  schemaName=SCHEMA_NAME,platform=PLATFORM,
54
  metadataLayout=metadataLayout,
55
  sampleDataRows=SAMPLE_ROW_MAX,
 
70
  """gpt response handler for gradio ui"""
71
  global queryHelper
72
  try:
73
+ botMessage = queryHelper.getQueryForUserInput(message)
74
  except Exception as e:
75
  errorMessage = {"function":"queryHelper.getQueryForUserInput","error":str(e), "userInput":message}
76
  saveLog(errorMessage, 'error')
77
  raise ValueError(str(e))
78
  queryGenerated = extractSqlFromGptResponse(botMessage)
79
+ logMessage = {"userInput":message, "queryGenerated":queryGenerated, "completeGptResponse":botMessage, "function":"queryHelper.getQueryForUserInput"}
80
  saveLog(logMessage)
81
  chatHistory.append((message, botMessage))
 
82
  return "", chatHistory
83
 
84
  # Function to save history of chat
 
86
  """gpt response handler for gradio ui"""
87
  global queryHelperCot
88
  try:
89
+ botMessage = queryHelperCot.getQueryForUserInputCoT(message)
 
 
 
90
  except Exception as e:
91
+ errorMessage = {"function":"queryHelperCot.getQueryForUserInput","error":str(e), "userInput":message}
92
+ saveLog(errorMessage, 'error')
93
+ raise ValueError(str(e))
94
+ logMessage = {"userInput":message, "completeGptResponse":botMessage, "function":"queryHelperCot.getQueryForUserInputCoT"}
 
95
  saveLog(logMessage)
96
  chatHistory.append((message, botMessage))
 
97
  return "", chatHistory
98
 
99
 
 
127
  dbEngine2.disconnect()
128
 
129
  print(f"Error occured during running the query {sql}.\n and the error is {str(e)}")
 
 
 
 
 
130
  return f"The query you entered throws some error. Here is the error.\n {str(e)}"
131
 
132
 
constants.py CHANGED
@@ -2,7 +2,7 @@ __all__ = ["SCHEMA_NAME", "GPT_SAMPLE_ROWS", "PLATFORM", "SAMPLE_ROW_MAX", "DEFA
2
 
3
  #Constants
4
  SCHEMA_NAME = "lpdatamart"
5
- GPT_SAMPLE_ROWS = 5
6
  PLATFORM = "Amazon Redshift"
7
  SAMPLE_ROW_MAX = 50
8
  QUERY_TIMEOUT = 20 #timeout in seconds
@@ -30,4 +30,39 @@ event_col = ['event_id', 'event_type', 'event_description', 'event_detail', 'sta
30
  DEFAULT_TABLES_COLS = {"tbl_d_customer":customer_col, "tbl_d_product":product_col, "tbl_f_sales":sales_col,
31
  "tbl_d_store":store_col, "tbl_d_channel":channel_col, "tbl_d_lineaction_code":lineaction_col,
32
  "tbl_d_calendar":calendar_col, 'tbl_f_browse':browse_col, 'tbl_d_time': time_col, 'tbl_d_browse_action': browse_action_col,
33
- 'tbl_d_browse_category':browse_category_col, 'tbl_d_style':style_col, 'tbl_f_emailing': email_col, 'tbl_d_event':event_col}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  #Constants
4
  SCHEMA_NAME = "lpdatamart"
5
+ GPT_SAMPLE_ROWS = 1
6
  PLATFORM = "Amazon Redshift"
7
  SAMPLE_ROW_MAX = 50
8
  QUERY_TIMEOUT = 20 #timeout in seconds
 
30
  DEFAULT_TABLES_COLS = {"tbl_d_customer":customer_col, "tbl_d_product":product_col, "tbl_f_sales":sales_col,
31
  "tbl_d_store":store_col, "tbl_d_channel":channel_col, "tbl_d_lineaction_code":lineaction_col,
32
  "tbl_d_calendar":calendar_col, 'tbl_f_browse':browse_col, 'tbl_d_time': time_col, 'tbl_d_browse_action': browse_action_col,
33
+ 'tbl_d_browse_category':browse_category_col, 'tbl_d_style':style_col, 'tbl_f_emailing': email_col, 'tbl_d_event':event_col}
34
+
35
+
36
+ TABLE_RELATIONS = """tbl_d_store and tbl_f_sales on store_id
37
+ tbl_d_time and tbl_f_sales on time_id
38
+ tbl_d_product and tbl_f_sales on product_id
39
+ tbl_d_channel and tbl_f_sales on channel_id
40
+ tbl_d_customer and tbl_f_sales on customer_id
41
+ tbl_d_source and tbl_f_sales on source_id
42
+ tbl_d_calender and tbl_f_sales on date_id
43
+ tbl_d_associate and tbl_f_sales on associate_id
44
+ tbl_d_promption and tbl_f_sales on promotion_id
45
+ tbl_d_keycode and tbl_f_sales on keycode_id
46
+ tbl_d_lineaction_code and tbl_f_sales on tbl_d_lineaction_code.line_action_code, tbl_f_sales.line_action
47
+ tbl_d_event and tbl_f_emailing on event_id
48
+ tbl_d_calender and tbl_f_emailing on date_id
49
+ tbl_d_e_sourceid and tbl_f_emailing on email_source_key
50
+ tbl_d_time and tbl_f_emailing on time_id
51
+ tbl_d_customer and tbl_f_emailing on customer_id
52
+ tbl_d_email and tbl_f_email on email_key
53
+ tbl_d_email and tbl_d_url on url_id
54
+ tbl_f_mailing and tbl_d_calender on date_id
55
+ tbl_d_customer and tbl_f_mailing on customer_id
56
+ tbl_d_keycode and tbl_f_mailing on keycode_id
57
+ tbl_d_email and tbl_f_browse on email_key
58
+ tbl_d_calender and tbl_f_browse on date_id
59
+ tbl_d_product and tbl_f_browse on product_id
60
+ tbl_d_browse_action and tbl_f_browse on browse_action_id
61
+ tbl_d_browse_style and tbl_f_browse on browse_style_id
62
+ tbl_d_source and tbl_f_activity on source_id
63
+ tbl_d_calender and tbl_f_activity on date_id
64
+ tbl_d_time and tbl_f_activity on time_id
65
+ tbl_d_customer and tbl_f_activity on customer_id
66
+ tbl_d_customer and tbl_f_opt_out on customer_id
67
+ tbl_d_calender and tbl_f_opt_out on date_id
68
+ tbl_d_time and tbl_f_opt_out on time_id"""
gptManager.py CHANGED
@@ -7,40 +7,19 @@ class ChatgptManager:
7
  self.tokenLimit = tokenLimit
8
  self.model = model
9
  self.throwError = throwError
 
10
 
11
- def _chatHistoryToGptMessages(self, chatHistory=[]):
12
- messages = []
13
- for i in range(len(chatHistory)):
14
- if i%2==0:
15
- message = {"role":"user", "content":chatHistory[i]}
16
- else:
17
- message = {"role":"assistant", "content": chatHistory[i]}
18
- messages.append(message)
19
- return messages
20
-
21
- def getResponseForUserInput(self, userInput, systemPrompt, chatHistory=[]):
22
- self.messages = self._chatHistoryToGptMessages(chatHistory[:])
23
- newMessage = {"role":"system", "content":systemPrompt}
24
- if not self.isTokeLimitExceeding(newMessage):
25
- self.messages.append(newMessage)
26
  else:
27
- if chatHistory==[]:
28
- raise ValueError("System Prompt Too long.")
29
- return self.getResponseForUserInput(userInput=userInput, systemPrompt=systemPrompt)
30
 
 
31
  userMessage = {"role":"user", "content":userInput}
32
- if not self.isTokeLimitExceeding(userMessage):
33
- self.messages.append(userMessage)
34
- else:
35
- if chatHistory==[]:
36
- raise ValueError("Token Limit exceeding. With user input")
37
- return self.getResponseForUserInput(userInput=userInput, systemPrompt=systemPrompt)
38
-
39
- # completion = self.client.chat.completions.create(
40
- # model="gpt-3.5-turbo-1106",
41
- # messages=self.messages,
42
- # temperature=0,
43
- # )
44
  print(self.messages, "messages being sent to gpt for completion.")
45
  try:
46
  completion = self.client.chat.completions.create(
@@ -51,34 +30,8 @@ class ChatgptManager:
51
  gptResponse = completion.choices[0].message.content
52
  except Exception as e:
53
  if not self.throwError:
54
- gptResponse = "Error while connecting with gpt " + str(e)[:50] + "..."
55
-
56
-
57
-
58
 
59
  self.messages.append({"role": "assistant", "content": gptResponse})
60
- return gptResponse
61
-
62
- def isTokeLimitExceeding(self, newMessage=None, truncate=True, throwError=True):
63
- if self.getTokenCount(newMessage=newMessage) > self.tokenLimit:
64
- return True
65
- return False
66
-
67
-
68
- def getTokenCount(self, newMessage=None):
69
- """Token count including new Message"""
70
-
71
- def getWordsCount(text):
72
- return len(re.findall(r'\b\w+\b', text))
73
-
74
- messages = self.messages[:]
75
- if newMessage!=None:
76
- messages.append(newMessage)
77
-
78
- if len(messages)!=0:
79
- combinedContent = " ".join([str(msg["content"]) for msg in messages])
80
- else:
81
- combinedContent = ""
82
-
83
- currentTokensInMessages = getWordsCount(combinedContent)
84
- return currentTokensInMessages
 
7
  self.tokenLimit = tokenLimit
8
  self.model = model
9
  self.throwError = throwError
10
+ self.messages = []
11
 
12
+ def setSystemPrompt(self, systemPrompt):
13
+ systemMessage = {"role":"system", "content":systemPrompt}
14
+ if len(self.messages)==0:
15
+ self.messages = [systemMessage]
 
 
 
 
 
 
 
 
 
 
 
16
  else:
17
+ del self.messages[0]
18
+ self.messages.insert(0, systemMessage)
 
19
 
20
+ def getResponseForUserInput(self, userInput):
21
  userMessage = {"role":"user", "content":userInput}
22
+ self.messages.append(userMessage)
 
 
 
 
 
 
 
 
 
 
 
23
  print(self.messages, "messages being sent to gpt for completion.")
24
  try:
25
  completion = self.client.chat.completions.create(
 
30
  gptResponse = completion.choices[0].message.content
31
  except Exception as e:
32
  if not self.throwError:
33
+ errorText = "Error while connecting with gpt " + str(e)[:100] + "..."
34
+ return errorText
 
 
35
 
36
  self.messages.append({"role": "assistant", "content": gptResponse})
37
+ return gptResponse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
queryHelperManager.py CHANGED
@@ -1,11 +1,16 @@
1
  from gptManager import ChatgptManager
2
  from utils import *
 
 
3
 
4
  class QueryHelper:
5
- def __init__(self, gptInstance: ChatgptManager, dbEngine, schemaName,
 
 
6
  platform, metadataLayout: MetaDataLayout, sampleDataRows,
7
- gptSampleRows, getSampleDataForTablesAndCols):
8
- self.gptInstance = gptInstance
 
9
  self.schemaName = schemaName
10
  self.platform = platform
11
  self.metadataLayout = metadataLayout
@@ -13,6 +18,7 @@ class QueryHelper:
13
  self.gptSampleRows = gptSampleRows
14
  self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
15
  self.dbEngine = dbEngine
 
16
  self._onMetadataChange()
17
 
18
  def _onMetadataChange(self):
@@ -20,10 +26,12 @@ class QueryHelper:
20
  sampleDataRows = self.sampleDataRows
21
  dbEngine = self.dbEngine
22
  schemaName = self.schemaName
23
-
24
  selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()
25
  self.sampleData = self.getSampleDataForTablesAndCols(dbEngine=dbEngine,schemaName=schemaName,
26
  tablesAndCols=selectedTablesAndCols, maxRows=sampleDataRows)
 
 
 
27
 
28
  def getMetadata(self) -> MetaDataLayout :
29
  return self.metadataLayout
@@ -31,52 +39,43 @@ class QueryHelper:
31
  def updateMetadata(self, metadataLayout):
32
  self.metadataLayout = metadataLayout
33
  self._onMetadataChange()
34
-
35
- def modifySqlQueryEnteredByUser(self, userSqlQuery):
36
- platform = self.platform
37
- userPrompt = f"Please correct the following sql query, also it has to be run on {platform}. sql query is \n {userSqlQuery}."
38
- systemPrompt = ""
39
- modifiedSql = self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)
40
- return modifiedSql
41
-
42
- def filteredSampleDataForProspects(self, prospectTablesAndCols):
43
- sampleData = self.sampleData
44
- filteredData = {}
45
- for table in prospectTablesAndCols.keys():
46
- # filteredData[table] = sampleData[table][prospectTablesAndCols[table]]
47
- #take all columns of prospects
48
- filteredData[table] = sampleData[table][prospectTablesAndCols[table]]
49
- return filteredData
50
 
51
- def getQueryForUserInput(self, userInput, chatHistory=[]):
52
- gptSampleRows = self.gptSampleRows
53
  selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
54
- prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols, chatHistory)
55
- print("getting prospects", prospectTablesAndCols)
56
- prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
57
- systemPromptForQueryGeneration = self.getSystemPromptForQueryGeneration(prospectTablesData, gptSampleRows=gptSampleRows)
58
- queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPromptForQueryGeneration, chatHistory)
59
-
60
- queryByGpt = preProcessGptQueryReponse(queryByGpt, metadataLayout=self.metadataLayout)
61
- return queryByGpt, prospectTablesAndCols
62
-
63
- def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory=[]):
 
 
 
64
  schemaName = self.schemaName
 
 
 
65
 
66
- systemPromptForProspectColumns = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
67
- prospectiveTablesColsText = self.gptInstance.getResponseForUserInput(userInput, systemPromptForProspectColumns, chatHistory)
68
- prospectTablesAndCols = {}
69
- for table in selectedTablesAndCols.keys():
70
- if table in prospectiveTablesColsText:
71
- prospectTablesAndCols[table] = []
72
- for column in selectedTablesAndCols[table]:
73
- if column in prospectiveTablesColsText:
74
- prospectTablesAndCols[table].append(column)
75
- return prospectTablesAndCols
76
 
77
- def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
78
  schemaName = self.schemaName
79
  platform = self.platform
 
80
  exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
81
  FROM lpdatamart.tbl_f_sales a
82
  JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
@@ -86,20 +85,9 @@ GROUP BY a.customer_id
86
  ORDER BY chandelier_count DESC"""
87
 
88
  question = "top 5 customers who bought most chandeliers in nov 2023"
89
- prompt = f"""Given an input text, generate the corresponding SQL query for given details. Schema Name is {schemaName}. And sql platform is {platform}.\n following is sample data. Also
90
- There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery} """
91
- for idx, tableName in enumerate(prospectTablesData.keys(), start=1):
92
- prompt += f"table name is {tableName}, table data is {prospectTablesData[tableName].head(gptSampleRows)}"
93
- prompt += "XXXX"
94
- return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")
95
-
96
- def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
97
- schemaName = self.schemaName
98
- platform = self.platform
99
-
100
- prompt = f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {schemaName}. And sql platform is {platform}.\n"""
101
- for idx, tableName in enumerate(selectedTablesAndCols.keys(), start=1):
102
- prompt += f"table name {tableName} {', '.join(selectedTablesAndCols[tableName])}"
103
- prompt += "XXXX"
104
- return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")
105
 
 
1
  from gptManager import ChatgptManager
2
  from utils import *
3
+ import json
4
+ from constants import TABLE_RELATIONS
5
 
6
  class QueryHelper:
7
+ def __init__(self, gptInstanceForTableCols: ChatgptManager,
8
+ gptInstanceForQuery: ChatgptManager,
9
+ dbEngine, schemaName,
10
  platform, metadataLayout: MetaDataLayout, sampleDataRows,
11
+ gptSampleRows, getSampleDataForTablesAndCols, tableSummaryJson='tableSummaryDict.json'):
12
+ self.gptInstanceForTableCols = gptInstanceForTableCols
13
+ self.gptInstanceForQuery = gptInstanceForQuery
14
  self.schemaName = schemaName
15
  self.platform = platform
16
  self.metadataLayout = metadataLayout
 
18
  self.gptSampleRows = gptSampleRows
19
  self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
20
  self.dbEngine = dbEngine
21
+ self.tableSummaryJson = tableSummaryJson
22
  self._onMetadataChange()
23
 
24
  def _onMetadataChange(self):
 
26
  sampleDataRows = self.sampleDataRows
27
  dbEngine = self.dbEngine
28
  schemaName = self.schemaName
 
29
  selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()
30
  self.sampleData = self.getSampleDataForTablesAndCols(dbEngine=dbEngine,schemaName=schemaName,
31
  tablesAndCols=selectedTablesAndCols, maxRows=sampleDataRows)
32
+ self.promptTableColsInfo = self.getSystemPromptForTableCols()
33
+ self.gptInstanceForTableCols.setSystemPrompt(self.promptTableColsInfo)
34
+
35
 
36
  def getMetadata(self) -> MetaDataLayout :
37
  return self.metadataLayout
 
39
  def updateMetadata(self, metadataLayout):
40
  self.metadataLayout = metadataLayout
41
  self._onMetadataChange()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ def getQueryForUserInput(self, userInput):
44
+ prospectTablesAndColsText = self.gptInstanceForTableCols.getResponseForUserInput(userInput)
45
  selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
46
+ prospectTablesAndCols = dict()
47
+ for table in selectedTablesAndCols:
48
+ if table in prospectTablesAndColsText:
49
+ prospectTablesAndCols[table] = []
50
+ for col in selectedTablesAndCols[table]:
51
+ if col in prospectTablesAndColsText:
52
+ prospectTablesAndCols[table].append(col)
53
+ promptForQuery = getSystemPromptForQuery(prospectTablesAndCols)
54
+ self.gptInstanceForQuery.setSystemPrompt(promptForQuery)
55
+ gptResponse = self.gptInstanceForQuery.getResponseForUserInput(userInput)
56
+ return gptResponse
57
+
58
+ def getSystemPromptForTableCols(self):
59
  schemaName = self.schemaName
60
+ platform = self.platform
61
+ tableSummaryDict = json.load(self.tableSummaryJson)
62
+ selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
63
 
64
+ promptTableInfo = f"""You are a powerful text to sql model. Answer which tables and columns are needed
65
+ to answer user input using sql query. and following are tables and columns info. and example user input and result query."""
66
+ for idx, tableName in enumerate(selectedTablesAndCols.keys(), start=1):
67
+ promptTableInfo += f"table name {tableName} and summary is {tableSummaryDict[tableName]}"
68
+ promptTableInfo += f" and columns {', '.join(selectedTablesAndCols[tableName])} \n"
69
+ promptTableInfo += "XXXX"
70
+ #Join statements
71
+ promptTableInfo += f"and table Relations are {TABLE_RELATIONS}"
72
+ return promptTableInfo
73
+
74
 
75
+ def getSystemPromptForQuery(self, prospectTablesAndCols):
76
  schemaName = self.schemaName
77
  platform = self.platform
78
+ tableSummaryDict = json.load(self.tableSummaryJson)
79
  exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
80
  FROM lpdatamart.tbl_f_sales a
81
  JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
 
85
  ORDER BY chandelier_count DESC"""
86
 
87
  question = "top 5 customers who bought most chandeliers in nov 2023"
88
+ promptForQuery = f"""You are a powerful text to sql model. Answer user input with sql query. And the query needs to run on {platform}. and schemaName is {schemaName}. There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery}. and table's data is \n"""
89
+ for idx, tableName in enumerate(prospectTablesAndCols.keys(), start=1):
90
+ promptForQuery += f"table name is {tableName}, table data is {self.sampleData[tableName][prospectTablesAndCols[tableName]].head(gptSampleRows)}"
91
+ promptForQuery += f"and table Relations are {TABLE_RELATIONS}"
92
+ return promptForQuery.replace("\\"," ").replace(" "," ").replace("XXXX", " ")
 
 
 
 
 
 
 
 
 
 
 
93
 
queryHelperManagerCoT.py CHANGED
@@ -1,13 +1,14 @@
1
  from gptManager import ChatgptManager
2
  from utils import *
3
- import re
4
- import json
5
 
6
- class QueryHelperChainOfThought:
7
- def __init__(self, gptInstance: ChatgptManager, dbEngine, schemaName,
 
8
  platform, metadataLayout: MetaDataLayout, sampleDataRows,
9
- gptSampleRows, getSampleDataForTablesAndCols):
10
- self.gptInstance = gptInstance
11
  self.schemaName = schemaName
12
  self.platform = platform
13
  self.metadataLayout = metadataLayout
@@ -15,6 +16,7 @@ class QueryHelperChainOfThought:
15
  self.gptSampleRows = gptSampleRows
16
  self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
17
  self.dbEngine = dbEngine
 
18
  self._onMetadataChange()
19
 
20
  def _onMetadataChange(self):
@@ -22,10 +24,10 @@ class QueryHelperChainOfThought:
22
  sampleDataRows = self.sampleDataRows
23
  dbEngine = self.dbEngine
24
  schemaName = self.schemaName
25
-
26
  selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()
27
  self.sampleData = self.getSampleDataForTablesAndCols(dbEngine=dbEngine,schemaName=schemaName,
28
  tablesAndCols=selectedTablesAndCols, maxRows=sampleDataRows)
 
29
 
30
  def getMetadata(self) -> MetaDataLayout :
31
  return self.metadataLayout
@@ -33,91 +35,93 @@ class QueryHelperChainOfThought:
33
  def updateMetadata(self, metadataLayout):
34
  self.metadataLayout = metadataLayout
35
  self._onMetadataChange()
36
-
37
- def modifySqlQueryEnteredByUser(self, userSqlQuery):
38
- platform = self.platform
39
- userPrompt = f"Please correct the following sql query, also it has to be run on {platform}. sql query is \n {userSqlQuery}."
40
- systemPrompt = ""
41
- modifiedSql = self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)
42
- return modifiedSql
43
-
44
- def filteredSampleDataForProspects(self, prospectTablesAndCols):
45
- sampleData = self.sampleData
46
- filteredData = {}
47
- for table in prospectTablesAndCols.keys():
48
- # filteredData[table] = sampleData[table][prospectTablesAndCols[table]]
49
- #take all columns of prospects
50
- filteredData[table] = sampleData[table][prospectTablesAndCols[table]]
51
- return filteredData
52
-
53
- def extractSingleJson(self, text):
54
- pattern = r'\{.*?\}'
55
- matches = re.findall(pattern, text, re.DOTALL)
56
- extracted_json = [json.loads(match) for match in matches][0]
57
- return extracted_json
58
 
59
  def getQueryForUserInputCoT(self, userInput):
60
- selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
61
- systemPromptTail = self.getSystemPromptTailForCoTStep1(selectedTablesAndCols)
62
- #1. Is the input complete to create a query, or ask user to reask with more detailed input
63
- systemPromptForInputClarification = """Given an input text, user want to generate sql query. Please answer if the user input is complete or user needs to ask in more detailed way. Answer in following format. 'Yes' ; if yes, break the userinput into smaller subtask for query generation. Formatted into
64
- {
65
- "Task 1": "task 1 description",
66
- "Task 2": "task 2 description"
67
- }
68
- 'No' ; if no, then Reason- please be more detailed about customer details; if more modification needed."""
69
- systemPromptForInputClarification = systemPromptForInputClarification + '\n' + systemPromptTail
70
- cotStep1 = self.gptInstance.getResponseForUserInput(userInput, systemPromptForInputClarification)
71
- if "yes" in cotStep1.lower()[:5]:
72
- print("User input sufficient")
73
- tasks = self.extractSingleJson(cotStep1)
74
- print(f"tasks are {tasks}")
75
- taskQueries = {}
76
- prospectTablesAndColsAll = []
77
- for key, task in tasks.items():
78
- taskQuery, prospectTablesAndCols = self.getQueryForUserInput(userInput)
79
- taskQueries[key] = {"task":task, "taskQuery":taskQuery}
80
- prospectTablesAndColsAll.append(prospectTablesAndCols)
81
- print(f"tasks and their queries {taskQueries}")
82
-
83
- combiningSubtasksQueryPrompt = f"""Combine following subtask and their queries to generate sql query to answer the user input.\n """
84
- userPrompt = f"user input is {userInput}"
85
- for key in taskQueries.keys():
86
- task = taskQueries[key]["task"]
87
- query = taskQueries[key]["taskQuery"]
88
- userPrompt += f" task: {task}, task query: {query}"
89
- return self.gptInstance.getResponseForUserInput(userPrompt, combiningSubtasksQueryPrompt), prospectTablesAndColsAll
90
- return f"Please rephrase your query. {' '.join(cotStep1.split('Reason')[1:])}", None
91
 
92
- def getQueryForUserInput(self, userInput, chatHistory=[]):
93
- gptSampleRows = self.gptSampleRows
 
 
94
  selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
95
- prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols, chatHistory)
96
- print("getting prospects", prospectTablesAndCols)
97
- prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
98
- systemPromptForQueryGeneration = self.getSystemPromptForQueryGeneration(prospectTablesData, gptSampleRows=gptSampleRows)
99
- queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPromptForQueryGeneration, chatHistory)
100
 
101
- queryByGpt = preProcessGptQueryReponse(queryByGpt, metadataLayout=self.metadataLayout)
102
- return queryByGpt, prospectTablesAndCols
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory=[]):
 
 
 
 
 
 
 
 
 
 
 
105
  schemaName = self.schemaName
 
 
 
106
 
107
- systemPromptForProspectColumns = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
108
- prospectiveTablesColsText = self.gptInstance.getResponseForUserInput(userInput, systemPromptForProspectColumns, chatHistory)
109
- prospectTablesAndCols = {}
110
- for table in selectedTablesAndCols.keys():
111
- if table in prospectiveTablesColsText:
112
- prospectTablesAndCols[table] = []
113
- for column in selectedTablesAndCols[table]:
114
- if column in prospectiveTablesColsText:
115
- prospectTablesAndCols[table].append(column)
116
- return prospectTablesAndCols
117
 
118
- def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
119
  schemaName = self.schemaName
120
  platform = self.platform
 
121
  exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
122
  FROM lpdatamart.tbl_f_sales a
123
  JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
@@ -127,29 +131,9 @@ GROUP BY a.customer_id
127
  ORDER BY chandelier_count DESC"""
128
 
129
  question = "top 5 customers who bought most chandeliers in nov 2023"
130
- prompt = f"""Given an input text, generate the corresponding SQL query for given details. Schema Name is {schemaName}. And sql platform is {platform}.\n following is sample data. Also
131
- There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery} """
132
- for idx, tableName in enumerate(prospectTablesData.keys(), start=1):
133
- prompt += f"table name is {tableName}, table data is {prospectTablesData[tableName].head(gptSampleRows)}"
134
- prompt += "XXXX"
135
- return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")
136
-
137
- def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
138
- schemaName = self.schemaName
139
- platform = self.platform
140
-
141
- prompt = f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {schemaName}. And sql platform is {platform}.\n"""
142
- for idx, tableName in enumerate(selectedTablesAndCols.keys(), start=1):
143
- prompt += f"table name {tableName} {', '.join(selectedTablesAndCols[tableName])}"
144
- prompt += "XXXX"
145
- return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")
146
-
147
- def getSystemPromptTailForCoTStep1(self, selectedTablesAndCols):
148
- schemaName = self.schemaName
149
- platform = self.platform
150
-
151
- prompt = f"""schema name is {schemaName}. And sql platform is {platform}. and table info are below.\n"""
152
- for idx, tableName in enumerate(selectedTablesAndCols.keys(), start=1):
153
- prompt += f"table name {tableName} {', '.join(selectedTablesAndCols[tableName])}"
154
- prompt += "XXXX"
155
- return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")
 
1
  from gptManager import ChatgptManager
2
  from utils import *
3
+ import json
4
+ from constants import TABLE_RELATIONS
5
 
6
+ class QueryHelper:
7
+ def __init__(self, gptInstanceForCoT: ChatgptManager,
8
+ dbEngine, schemaName,
9
  platform, metadataLayout: MetaDataLayout, sampleDataRows,
10
+ gptSampleRows, getSampleDataForTablesAndCols, tableSummaryJson='tableSummaryDict.json'):
11
+ self.gptInstanceForCoT = gptInstanceForCoT
12
  self.schemaName = schemaName
13
  self.platform = platform
14
  self.metadataLayout = metadataLayout
 
16
  self.gptSampleRows = gptSampleRows
17
  self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
18
  self.dbEngine = dbEngine
19
+ self.tableSummaryJson = tableSummaryJson
20
  self._onMetadataChange()
21
 
22
  def _onMetadataChange(self):
 
24
  sampleDataRows = self.sampleDataRows
25
  dbEngine = self.dbEngine
26
  schemaName = self.schemaName
 
27
  selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()
28
  self.sampleData = self.getSampleDataForTablesAndCols(dbEngine=dbEngine,schemaName=schemaName,
29
  tablesAndCols=selectedTablesAndCols, maxRows=sampleDataRows)
30
+
31
 
32
  def getMetadata(self) -> MetaDataLayout :
33
  return self.metadataLayout
 
35
  def updateMetadata(self, metadataLayout):
36
  self.metadataLayout = metadataLayout
37
  self._onMetadataChange()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def getQueryForUserInputCoT(self, userInput):
40
+ prompt = self.getPromptForCot()
41
+ self.gptInstanceForCot.setSystemPrompt(userInput)
42
+ gptResponse = self.gptInstanceForCoT.getResponseForUserInput(userInput)
43
+ return gptResponse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ def getPromptForCot(self):
46
+ schemaName = self.schemaName
47
+ platform = self.platform
48
+ tableSummaryDict = json.load(self.tableSummaryJson)
49
  selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
50
+
51
+ egUserInput = "I want to get top 5 product categories by state, then rank categories on decreasing order of total sales"
 
 
 
52
 
53
+ cotSubtaskOutput = """{
54
+ "subquery1": {
55
+ "inputSubquery": [],
56
+ "descriptioin":"calculate the total sales and assigns ranks to product categories within each state based on the descending order of sales in the tbl_f_sales table, utilizing joins with tbl_d_product and tbl_d_customer tables.",
57
+ "result": "SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
58
+ RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
59
+ FROM lpdatamart.tbl_f_sales a
60
+ JOIN lpdatamart.tbl_d_product b
61
+ ON a.product_id = b.product_id
62
+ JOIN lpdatamart.tbl_d_customer c
63
+ ON a.customer_id = c.customer_id
64
+ GROUP BY c.state, b.category "
65
+ },
66
+ "subquery2": {
67
+ "inputSubquery": ["subquery1"],
68
+ "description":"extracts state, category, and total sales information from a subquery named "subquery1," filtering the results to include only categories with ranks up to 5 and sorting them by state and category rank."
69
+ "result":"SELECT state, category, total_sales
70
+ FROM ranked_categories
71
+ WHERE category_rank <= 5
72
+ ORDER BY state, category_rank"
73
+ },
74
+ "finalResult":"WITH subquery1 AS (
75
+ SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
76
+ RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
77
+ FROM lpdatamart.tbl_f_sales a
78
+ JOIN lpdatamart.tbl_d_product b
79
+ ON a.product_id = b.product_id
80
+ JOIN lpdatamart.tbl_d_customer c
81
+ ON a.customer_id = c.customer_id
82
+ GROUP BY c.state, b.category
83
+ )
84
+ SELECT state, category, total_sales
85
+ FROM subquery1
86
+ WHERE category_rank <= 5
87
+ ORDER BY state, category_rank"
88
+ }"""
89
+ promptTableInfo = self.getSystemPromptForTableCols()
90
+ selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
91
+ promptColumnsInfo = getSystemPromptForQuery(selectedTablesAndCols)
92
 
93
+ prompt = f"""You are a powerful text to sql model. Your task is to return sql query which answers
94
+ user's input. Please follow subquery structure if the sql needs to have multiple subqueries.
95
+ ###example userInput {egUserInput}. output {cotSubtaskStructure}
96
+ tables information are {promptTableInfo}.
97
+ columns data are {promptColumnsInfo}.
98
+ """
99
+
100
+ prompt += f"and table Relations are {TABLE_RELATIONS}"
101
+
102
+ return prompt
103
+
104
+ def getSystemPromptForTableCols(self):
105
  schemaName = self.schemaName
106
+ platform = self.platform
107
+ tableSummaryDict = json.load(self.tableSummaryJson)
108
+ selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
109
 
110
+ promptTableInfo = f"""You are a powerful text to sql model. Answer which tables and columns are needed
111
+ to answer user input using sql query. and following are tables and columns info. and example user input and result query."""
112
+ for idx, tableName in enumerate(selectedTablesAndCols.keys(), start=1):
113
+ promptTableInfo += f"table name {tableName} and summary is {tableSummaryDict[tableName]}"
114
+ promptTableInfo += f" and columns {', '.join(selectedTablesAndCols[tableName])} \n"
115
+ promptTableInfo += "XXXX"
116
+ #Join statements
117
+ promptTableInfo += f"and table Relations are {TABLE_RELATIONS}"
118
+ return promptTableInfo
119
+
120
 
121
+ def getSystemPromptForQuery(self, prospectTablesAndCols):
122
  schemaName = self.schemaName
123
  platform = self.platform
124
+ tableSummaryDict = json.load(self.tableSummaryJson)
125
  exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
126
  FROM lpdatamart.tbl_f_sales a
127
  JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
 
131
  ORDER BY chandelier_count DESC"""
132
 
133
  question = "top 5 customers who bought most chandeliers in nov 2023"
134
+ promptForQuery = f"""You are a powerful text to sql model. Answer user input with sql query. And the query needs to run on {platform}. and schemaName is {schemaName}. There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery}. and table's data is \n"""
135
+ for idx, tableName in enumerate(prospectTablesAndCols.keys(), start=1):
136
+ promptForQuery += f"table name is {tableName}, table data is {self.sampleData[tableName][prospectTablesAndCols[tableName]].head(gptSampleRows)}"
137
+ promptForQuery += f"and table Relations are {TABLE_RELATIONS}"
138
+ return promptForQuery.replace("\\"," ").replace(" "," ").replace("XXXX", " ")
139
+