File size: 7,534 Bytes
1dda07c
c5ff675
0ff15a5
11a349e
0ff15a5
1dda07c
 
0ff15a5
 
 
1dda07c
0ff15a5
 
 
1dda07c
 
 
 
 
 
 
0ff15a5
1dda07c
 
 
 
 
 
 
 
 
 
0ff15a5
 
 
1dda07c
 
 
 
 
 
 
 
0ff15a5
 
1dda07c
0ff15a5
 
 
 
 
 
 
9a04841
 
0ff15a5
 
11a349e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ff15a5
 
1dda07c
0ff15a5
714fdb8
0ff15a5
1dda07c
0ff15a5
 
 
 
 
 
 
 
 
 
1dda07c
0ff15a5
1dda07c
 
c792260
11a349e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ff15a5
11a349e
 
 
1dda07c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from gptManager import ChatgptManager
from utils import *
import json 
import sqlparse
from constants import TABLE_RELATIONS

class QueryHelper:
  def __init__(self, gptInstanceForTableCols: ChatgptManager,
               gptInstanceForQuery: ChatgptManager,
               dbEngine, schemaName, 
               platform, metadataLayout: MetaDataLayout, sampleDataRows, 
               gptSampleRows, getSampleDataForTablesAndCols, tableSummaryJson='tableSummaryDict.json'):
    self.gptInstanceForTableCols = gptInstanceForTableCols
    self.gptInstanceForQuery = gptInstanceForQuery
    self.schemaName = schemaName
    self.platform = platform
    self.metadataLayout = metadataLayout
    self.sampleDataRows = sampleDataRows
    self.gptSampleRows = gptSampleRows
    self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
    self.dbEngine = dbEngine
    self.tableSummaryJson = tableSummaryJson
    self._onMetadataChange()
  
  def _onMetadataChange(self):
    metadataLayout = self.metadataLayout
    sampleDataRows = self.sampleDataRows
    dbEngine = self.dbEngine
    schemaName = self.schemaName
    selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()
    self.sampleData = self.getSampleDataForTablesAndCols(dbEngine=dbEngine,schemaName=schemaName,
                                                         tablesAndCols=selectedTablesAndCols, maxRows=sampleDataRows)
    self.promptTableColsInfo = self.getSystemPromptForTableCols()
    self.gptInstanceForTableCols.setSystemPrompt(self.promptTableColsInfo)
    
  
  def getMetadata(self) -> MetaDataLayout :
    return self.metadataLayout
  
  def updateMetadata(self, metadataLayout):
    self.metadataLayout = metadataLayout
    self._onMetadataChange()
  
  def getQueryForUserInput(self, userInput):
    prospectTablesAndColsText = self.gptInstanceForTableCols.getResponseForUserInput(userInput)
    selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
    prospectTablesAndCols = dict()
    for table in selectedTablesAndCols:
      if table in prospectTablesAndColsText:
        prospectTablesAndCols[table] = []
        for col in selectedTablesAndCols[table]:
          if col in prospectTablesAndColsText:
            prospectTablesAndCols[table].append(col)
    print("tables and cols select by gpt", prospectTablesAndCols)
    promptForQuery = self.getSystemPromptForQuery(prospectTablesAndCols)
    self.gptInstanceForQuery.setSystemPrompt(promptForQuery)
    gptResponse = self.gptInstanceForQuery.getResponseForUserInput(userInput)
    #following CoT in select column then get query to save tokens
    tryParsing = True
    parsedSql = False
    if tryParsing:
      try:
        txt = gptResponse.split("```json")[-1].split('```')[0].replace('\n', '')
        sqlResult = json.loads(txt)['finalResult']
        parsedSql = True
        tryParsing = False
      except:
        print("Couldn't parse desired result from gpt response using method 1.")
    if tryParsing:
      try:
        sqlResult = json.loads(gptResponse)['finalResult']
        parsedSql = True
        tryParsing = False
      except:
        print("Couldn't parse desired result from gpt response using method 2")
    if parsedSql:
      isFormatted = False
      try:
        formattedSql = sqlparse.format(sqlResult, reindent=True)
        responseToReturn = formattedSql
        isFormatted = True
      except:
        isFormatted = False
      if not isFormatted:
        try:
          formattedSql = sqlparse.format(sqlResult['result'], reindent=True)
          responseToReturn = formattedSql
          print("gpt didn't give parsed result. So parsing again. the formatting.")
        except:
          responseToReturn = str(sqlResult)
    else:
      responseToReturn = gptResponse
    return responseToReturn
  
  def getSystemPromptForTableCols(self):
    """Build the system prompt used to ask GPT which tables/columns are needed.

    The prompt lists every selected table with its summary (read from the
    ``tableSummaryJson`` file) and its selected columns, followed by the
    ``TABLE_RELATIONS`` constant.

    Returns:
      The assembled prompt string.

    Raises:
      OSError: if the summary file cannot be opened.
      KeyError: if a selected table has no entry in the summary file.
    """
    # Context manager so the summary file handle is closed as soon as it is read.
    with open(self.tableSummaryJson, 'r') as summaryFile:
      tableSummaryDict = json.load(summaryFile)
    selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()

    promptParts = ["""You are a powerful text to sql model. Answer which tables and columns are needed
    to answer user input using sql query. and following are tables and columns info. and example user input and result query."""]
    for tableName in selectedTablesAndCols.keys():
        promptParts.append(f"table name {tableName} and summary is {tableSummaryDict[tableName]}")
        promptParts.append(f" and columns {', '.join(selectedTablesAndCols[tableName])} \n")
    # NOTE(review): this literal marker is only substituted in
    # getSystemPromptForQuery, so here it reaches the prompt verbatim -
    # confirm that is intended.
    promptParts.append("XXXX")
    # Join statements
    promptParts.append(f"and table Relations are {TABLE_RELATIONS}")
    return "".join(promptParts)
    

  def getSystemPromptForQuery(self, prospectTablesAndCols):
    schemaName = self.schemaName
    platform = self.platform
    tableSummaryDict = json.load(open(self.tableSummaryJson, 'r'))
    egUserInput = "I want to get top 5 product categories by state, then rank categories on decreasing order of total sales"
    
    cotSubtaskOutput = """{
          "subquery1": {
            "inputSubquery": [],
            "descriptioin":"calculate the total sales and assigns ranks to product categories within each state based on the descending order of sales in the tbl_f_sales table, utilizing joins with tbl_d_product and tbl_d_customer tables.",
            "result": "SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
    RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
    FROM lpdatamart.tbl_f_sales a
    JOIN lpdatamart.tbl_d_product b
    ON a.product_id = b.product_id
    JOIN lpdatamart.tbl_d_customer c
    ON a.customer_id = c.customer_id
    GROUP BY c.state, b.category "
          },
          "subquery2": {
            "inputSubquery": ["subquery1"],
            "description":"extracts state, category, and total sales information from a subquery named "subquery1," filtering the results to include only categories with ranks up to 5 and sorting them by state and category rank."
            "result":"SELECT state, category, total_sales
FROM ranked_categories
WHERE category_rank <= 5
ORDER BY state, category_rank"
          },
          "finalResult":"WITH subquery1 AS (
    SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
    RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
    FROM lpdatamart.tbl_f_sales a
    JOIN lpdatamart.tbl_d_product b
    ON a.product_id = b.product_id
    JOIN lpdatamart.tbl_d_customer c
    ON a.customer_id = c.customer_id
    GROUP BY c.state, b.category
)
SELECT state, category, total_sales
FROM subquery1
WHERE category_rank <= 5
ORDER BY state, category_rank"
        }"""
    prompt = f"""You are a powerful text to sql model. Your task is to return sql query which answers
    user's input. Please follow subquery structure if the sql needs to have multiple subqueries.
    ###example userInput is {egUserInput}. output is {cotSubtaskOutput}. Output should be in json format as provided. Only output should be in response, nothing else.\n\n 
    """
         
    for idx, tableName in enumerate(prospectTablesAndCols.keys(), start=1):
        prompt += f"table name is {tableName}, table data is {self.sampleData[tableName][prospectTablesAndCols[tableName]].head(self.gptSampleRows)}"   
    prompt += f"and table Relations are {TABLE_RELATIONS}"
    return prompt.replace("\\"," ").replace("  "," ").replace("XXXX", "    ")