File size: 7,380 Bytes
2a4b462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
class QueryHelperChainOfThought:
  """Build SQL queries from natural-language input using a GPT chat wrapper.

  Holds the query context (DB engine, schema name, SQL platform, selected
  tables/columns via ``metadataLayout``) plus cached sample rows, and builds
  the system prompts sent to ``gptInstance.getResponseForUserInput``.
  """

  # Annotations are quoted (forward references) so the class can be defined
  # even where these types are not importable at runtime.
  def __init__(self, gptInstance: "ChatgptManager", dbEngine, schemaName,
               platform, metadataLayout: "MetaDataLayout", sampleDataRows,
               gptSampleRows, getSampleDataForTablesAndCols):
    """Store the query context and load the initial sample data.

    Args:
      gptInstance: chat wrapper exposing ``getResponseForUserInput``.
      dbEngine: database handle forwarded to ``getSampleDataForTablesAndCols``.
      schemaName: schema name embedded in the prompts.
      platform: SQL platform/dialect name embedded in the prompts.
      metadataLayout: object whose ``getSelectedTablesAndCols()`` result is
        used as a mapping of table name -> column names.
      sampleDataRows: ``maxRows`` passed when fetching sample data.
      gptSampleRows: number of sample rows per table shown to GPT.
      getSampleDataForTablesAndCols: callable invoked with keyword arguments
        ``dbEngine``, ``schemaName``, ``tablesAndCols`` and ``maxRows``.
    """
    self.gptInstance = gptInstance
    self.schemaName = schemaName
    self.platform = platform
    self.metadataLayout = metadataLayout
    self.sampleDataRows = sampleDataRows
    self.gptSampleRows = gptSampleRows
    self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
    self.dbEngine = dbEngine
    self._onMetadataChange()

  def _onMetadataChange(self):
    """Refresh ``self.sampleData`` for the currently selected tables/columns."""
    self.sampleData = self.getSampleDataForTablesAndCols(
        dbEngine=self.dbEngine,
        schemaName=self.schemaName,
        tablesAndCols=self.metadataLayout.getSelectedTablesAndCols(),
        maxRows=self.sampleDataRows)

  def getMetadata(self) -> "MetaDataLayout":
    """Return the current metadata layout."""
    return self.metadataLayout

  def updateMetadata(self, metadataLayout):
    """Replace the metadata layout and reload the cached sample data."""
    self.metadataLayout = metadataLayout
    self._onMetadataChange()

  def modifySqlQueryEnteredByUser(self, userSqlQuery):
    """Ask GPT to correct ``userSqlQuery`` for ``self.platform``.

    Returns:
      GPT's reply, unmodified.
    """
    platform = self.platform
    userPrompt = f"Please correct the following sql query, also it has to be run on {platform}. sql query is \n {userSqlQuery}."
    systemPrompt = ""
    return self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)

  def filteredSampleDataForProspects(self, prospectTablesAndCols):
    """Return the cached sample data restricted to the prospect columns.

    Args:
      prospectTablesAndCols: mapping of table name -> list of column names.

    Returns:
      dict mapping each table in ``prospectTablesAndCols`` to
      ``self.sampleData[table][columns]`` -- presumably a DataFrame column
      subset (TODO confirm against getSampleDataForTablesAndCols).
    """
    sampleData = self.sampleData
    # Only the columns GPT flagged as relevant are kept for each table.
    return {table: sampleData[table][columns]
            for table, columns in prospectTablesAndCols.items()}

  def extractSingleJson(self, text):
    """Return the first JSON object embedded in ``text``.

    Each ``{`` is tried in order with ``json.JSONDecoder.raw_decode``, so
    nested objects and surrounding prose are handled.

    Raises:
      ValueError: if ``text`` contains no decodable JSON object.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
      try:
        return decoder.raw_decode(text, start)[0]
      except json.JSONDecodeError:
        # Not valid JSON from this brace; try the next one.
        start = text.find("{", start + 1)
    raise ValueError("no JSON object found in GPT response")

  def getQueryForUserInputCoT(self, userInput, chatHistory=None):
    """Generate SQL by splitting the request into subtasks (chain of thought).

    1. Ask GPT whether ``userInput`` is specific enough; on 'Yes' GPT is asked
       to also return a JSON object of subtasks.
    2. Generate one query per subtask with ``getQueryForUserInput``.
    3. Ask GPT to combine the subtask queries into a single query.

    Args:
      userInput: natural-language request.
      chatHistory: optional prior chat turns forwarded to GPT (``None`` means
        an empty history).

    Returns:
      The combined query text from GPT, or a "Please rephrase your query."
      message carrying the text after "Reason" in GPT's reply when the input
      is judged incomplete.
    """
    chatHistory = [] if chatHistory is None else chatHistory
    # Step 1: is the input complete enough to build a query, or must the user rephrase?
    systemPromptForInputClarification = """Given an input text, user want to generate sql query. Please answer if the user input is complete or user needs to ask in more detailed way. Answer in following format. 'Yes' ; if yes, break the userinput into smaller subtask for query generation. Formatted into 
        {
        "Task 1": "task 1 description",
        "Task 2": "task 2 description"
        }
'No' ; if no, then Reason- please be more detailed about customer details; if more modification needed"""
    cotStep1 = self.gptInstance.getResponseForUserInput(userInput, systemPromptForInputClarification, chatHistory)
    # GPT is told to lead with 'Yes'/'No', so only the opening characters are inspected.
    if "yes" not in cotStep1.lstrip().lower()[:5]:
      return f"Please rephrase your query. {cotStep1.partition('Reason')[2]}"

    print("User input sufficient")
    try:
      tasks = self.extractSingleJson(cotStep1)
    except ValueError:
      # GPT answered 'Yes' without parseable subtasks: treat the whole
      # request as a single task rather than failing.
      tasks = {"Task 1": userInput}
    print(f"tasks are {tasks}")

    # Step 2: one query per subtask. getQueryForUserInput returns
    # (query, prospectTablesAndCols); only the query is needed here.
    taskQueries = {}
    for key, task in tasks.items():
      taskQuery, _ = self.getQueryForUserInput(task, chatHistory=chatHistory)
      taskQueries[key] = {"task": task, "taskQuery": taskQuery}
    print(f"tasks and their queries {taskQueries}")

    # Step 3: ask GPT to merge the subtask queries.
    combiningSubtasksQueryPrompt = "Combine following subtask and their queries to generate sql query to answer the user input.\n "
    userPrompt = f"user input is {userInput}" + "".join(
        f" task: {entry['task']}, task query: {entry['taskQuery']}"
        for entry in taskQueries.values())
    return self.gptInstance.getResponseForUserInput(userPrompt, combiningSubtasksQueryPrompt)

  def getQueryForUserInput(self, userInput, chatHistory=None):
    """Generate a SQL query for ``userInput`` in a single GPT round trip.

    Args:
      userInput: natural-language request.
      chatHistory: optional prior chat turns forwarded to GPT (``None`` means
        an empty history; a fresh list is used per call).

    Returns:
      ``(query, prospectTablesAndCols)``, where ``query`` is GPT's reply after
      ``preProcessGptQueryReponse`` (defined outside this class).
    """
    chatHistory = [] if chatHistory is None else chatHistory
    selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
    prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols, chatHistory)
    print("getting prospects", prospectTablesAndCols)
    prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
    systemPromptForQueryGeneration = self.getSystemPromptForQueryGeneration(prospectTablesData, gptSampleRows=self.gptSampleRows)
    queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPromptForQueryGeneration, chatHistory)
    queryByGpt = preProcessGptQueryReponse(queryByGpt, metadataLayout=self.metadataLayout)
    return queryByGpt, prospectTablesAndCols

  def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory=None):
    """Ask GPT which selected tables/columns may hold the requested data.

    A table or column is kept when its name appears in GPT's reply as a whole
    identifier (case-sensitive). A mentioned table keeps an entry even when
    none of its columns are mentioned (empty list).

    Args:
      userInput: natural-language request.
      selectedTablesAndCols: mapping of table name -> column names.
      chatHistory: optional prior chat turns (``None`` means empty history).

    Returns:
      dict of table name -> list of mentioned column names, in the order of
      ``selectedTablesAndCols``.
    """
    chatHistory = [] if chatHistory is None else chatHistory
    systemPromptForProspectColumns = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
    prospectiveTablesColsText = self.gptInstance.getResponseForUserInput(userInput, systemPromptForProspectColumns, chatHistory)
    prospectTablesAndCols = {}
    for table, columns in selectedTablesAndCols.items():
      if self._isNameMentioned(table, prospectiveTablesColsText):
        prospectTablesAndCols[table] = [
            column for column in columns
            if self._isNameMentioned(column, prospectiveTablesColsText)]
    return prospectTablesAndCols

  @staticmethod
  def _isNameMentioned(name, text):
    """Return True if ``name`` occurs in ``text`` as a whole identifier.

    Unlike a plain substring test, ``id`` does not match inside ``product_id``;
    a preceding ``.`` (e.g. ``schema.table`` or ``a.col``) still matches.
    """
    return re.search(rf"(?<!\w){re.escape(name)}(?!\w)", text) is not None

  @staticmethod
  def _flattenPrompt(prompt):
    """Put ``prompt`` on one line and append four trailing spaces.

    Newlines and backslashes become spaces, then each double space is
    replaced once by a single space (longer runs are halved, not collapsed).
    """
    return prompt.replace("\n", " ").replace("\\", " ").replace("  ", " ") + "    "

  def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
    """Build the system prompt asking GPT to write the final SQL query.

    Args:
      prospectTablesData: mapping of table name -> sample data object with a
        ``head(n)`` method (presumably a DataFrame).
      gptSampleRows: number of sample rows per table included in the prompt.
    """
    schemaName = self.schemaName
    platform = self.platform
    # NOTE(review): the example question asks for the "top 5" customers but
    # the example query has no row limit; consider aligning the two.
    exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
FROM lpdatamart.tbl_f_sales a
JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id
WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023
GROUP BY a.customer_id
ORDER BY chandelier_count DESC"""

    question = "top 5 customers who bought most chandeliers in nov 2023"
    prompt = f"""Given an input text, generate the corresponding SQL query for given details. Schema Name is {schemaName}. And sql platform is {platform}.\n following is sample data. Also 
    There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery} """
    # "; " separates entries so one table's data does not run into the next table's name.
    prompt += "; ".join(
        f"table name is {tableName}, table data is {tableData.head(gptSampleRows)}"
        for tableName, tableData in prospectTablesData.items())
    return self._flattenPrompt(prompt)

  def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
    """Build the system prompt asking GPT which tables/columns are relevant.

    Args:
      selectedTablesAndCols: mapping of table name -> column names.
    """
    schemaName = self.schemaName
    platform = self.platform

    prompt = f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {schemaName}. And sql platform is {platform}.\n"""
    # "; " separates entries so the last column of one table does not fuse with the next entry.
    prompt += "; ".join(
        f"table name {tableName} {', '.join(columns)}"
        for tableName, columns in selectedTablesAndCols.items())
    return self._flattenPrompt(prompt)