File size: 6,812 Bytes
775d0c7
 
0ff15a5
 
775d0c7
9928b56
0ff15a5
 
2a4b462
0ff15a5
 
2a4b462
 
 
 
 
 
 
0ff15a5
2a4b462
 
 
 
 
 
 
 
 
 
0ff15a5
2a4b462
 
 
 
 
 
 
 
 
0ff15a5
 
 
 
2a4b462
0ff15a5
 
 
 
2a4b462
0ff15a5
 
2a4b462
0ff15a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a4b462
0ff15a5
 
 
 
 
 
 
 
 
 
 
 
2a4b462
0ff15a5
5fa214e
0ff15a5
2a4b462
0ff15a5
 
 
 
 
 
 
 
 
 
2a4b462
0ff15a5
2a4b462
 
0ff15a5
2a4b462
 
 
 
 
 
 
 
 
0ff15a5
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from gptManager import ChatgptManager
from utils import *
import json 
from constants import TABLE_RELATIONS

class QueryHelperChainOfThought:
  """Builds chain-of-thought text-to-SQL prompts and sends user questions to a GPT manager.

  The helper keeps the currently selected tables/columns (via ``metadataLayout``),
  fetches sample data for them whenever the layout changes, and assembles system
  prompts from that sample data, a table-summary JSON file and ``TABLE_RELATIONS``.

  Attributes:
    gptInstanceForCoT: object exposing ``setSystemPrompt(prompt)`` and
      ``getResponseForUserInput(userInput)``.
    dbEngine, schemaName: forwarded unchanged to ``getSampleDataForTablesAndCols``;
      ``schemaName`` is also interpolated into the query prompt.
    platform: interpolated into the query prompt as the platform the SQL must run on.
    metadataLayout: object exposing ``getSelectedTablesAndCols()``, which is used here
      as a mapping of table name -> iterable of column-name strings.
    sampleDataRows: forwarded as ``maxRows`` when fetching sample data.
    gptSampleRows: passed to ``.head()`` when sample data is embedded in a prompt.
    getSampleDataForTablesAndCols: callable accepting the keyword arguments
      ``dbEngine``, ``schemaName``, ``tablesAndCols`` and ``maxRows``.
    tableSummaryJson: path of a JSON file mapping table name -> summary text.
    sampleData: result of the last ``getSampleDataForTablesAndCols`` call; indexed by
      table name, then by a column list (presumably pandas DataFrames - TODO confirm).
  """

  def __init__(self, gptInstanceForCoT: "ChatgptManager",
               dbEngine, schemaName,
               platform, metadataLayout: "MetaDataLayout", sampleDataRows,
               gptSampleRows, getSampleDataForTablesAndCols, tableSummaryJson='tableSummaryDict.json'):
    """Store the collaborators and immediately load sample data for the selected tables."""
    self.gptInstanceForCoT = gptInstanceForCoT
    self.schemaName = schemaName
    self.platform = platform
    self.metadataLayout = metadataLayout
    self.sampleDataRows = sampleDataRows
    self.gptSampleRows = gptSampleRows
    self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
    self.dbEngine = dbEngine
    self.tableSummaryJson = tableSummaryJson
    self._onMetadataChange()

  def _onMetadataChange(self):
    """Refresh ``self.sampleData`` for the tables/columns currently selected in the layout."""
    selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
    self.sampleData = self.getSampleDataForTablesAndCols(
        dbEngine=self.dbEngine,
        schemaName=self.schemaName,
        tablesAndCols=selectedTablesAndCols,
        maxRows=self.sampleDataRows,
    )

  def _loadTableSummaries(self):
    """Return the parsed content of the table-summary JSON file.

    The file is opened in a ``with`` block so the handle is always closed.
    """
    with open(self.tableSummaryJson, 'r') as summaryFile:
      return json.load(summaryFile)

  def getMetadata(self) -> "MetaDataLayout":
    """Return the metadata layout currently in use."""
    return self.metadataLayout

  def updateMetadata(self, metadataLayout):
    """Replace the metadata layout and reload the sample data for the new selection."""
    self.metadataLayout = metadataLayout
    self._onMetadataChange()

  def getQueryForUserInputCoT(self, userInput):
    """Answer ``userInput`` using the chain-of-thought system prompt.

    Args:
      userInput: the user's natural-language question.

    Returns:
      Whatever ``gptInstanceForCoT.getResponseForUserInput`` returns.
    """
    prompt = self.getPromptForCot()
    # The generated prompt (not the user's question) is the system prompt; the
    # question itself is sent as the user message below.
    self.gptInstanceForCoT.setSystemPrompt(prompt)
    return self.gptInstanceForCoT.getResponseForUserInput(userInput)

  def getPromptForCot(self):
    """Build the chain-of-thought system prompt.

    The prompt contains one worked example (user input plus the expected
    subquery-structured output), the table information from
    ``getSystemPromptForTableCols``, the sample-data prompt from
    ``getSystemPromptForQuery`` for the selected tables, and ``TABLE_RELATIONS``.

    Returns:
      The prompt as a string.
    """
    egUserInput = "I want to get top 5 product categories by state, then rank categories on decreasing order of total sales"

    # Few-shot example of the expected output layout. It is illustrative text for
    # the model, not strict JSON (the SQL strings span several lines).
    cotSubtaskOutput = """{
          "subquery1": {
            "inputSubquery": [],
            "description":"calculate the total sales and assigns ranks to product categories within each state based on the descending order of sales in the tbl_f_sales table, utilizing joins with tbl_d_product and tbl_d_customer tables.",
            "result": "SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
    RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
    FROM lpdatamart.tbl_f_sales a
    JOIN lpdatamart.tbl_d_product b
    ON a.product_id = b.product_id
    JOIN lpdatamart.tbl_d_customer c
    ON a.customer_id = c.customer_id
    GROUP BY c.state, b.category "
          },
          "subquery2": {
            "inputSubquery": ["subquery1"],
            "description":"extracts state, category, and total sales information from a subquery named "subquery1," filtering the results to include only categories with ranks up to 5 and sorting them by state and category rank.",
            "result":"SELECT state, category, total_sales
FROM subquery1
WHERE category_rank <= 5
ORDER BY state, category_rank"
          },
          "finalResult":"WITH subquery1 AS (
    SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
    RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
    FROM lpdatamart.tbl_f_sales a
    JOIN lpdatamart.tbl_d_product b
    ON a.product_id = b.product_id
    JOIN lpdatamart.tbl_d_customer c
    ON a.customer_id = c.customer_id
    GROUP BY c.state, b.category
)
SELECT state, category, total_sales
FROM subquery1
WHERE category_rank <= 5
ORDER BY state, category_rank"
        }"""
    promptTableInfo = self.getSystemPromptForTableCols()
    selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
    promptColumnsInfo = self.getSystemPromptForQuery(selectedTablesAndCols)

    prompt = f"""You are a powerful text to sql model. Your task is to return sql query which answers
    user's input. Please follow subquery structure if the sql needs to have multiple subqueries.
    ###example userInput {egUserInput}. output {cotSubtaskOutput}
    tables information are {promptTableInfo}.
    columns data are {promptColumnsInfo}.
    """

    prompt += f"and table Relations are {TABLE_RELATIONS}"

    return prompt

  def getSystemPromptForTableCols(self):
    """Build the prompt section describing each selected table, its summary and its columns.

    Returns:
      The prompt section as a string, ending with ``TABLE_RELATIONS``.

    Raises:
      KeyError: if a selected table has no entry in the table-summary JSON file.
    """
    tableSummaryDict = self._loadTableSummaries()
    selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()

    promptTableInfo = f"""You are a powerful text to sql model. Answer which tables and columns are needed
    to answer user input using sql query. and following are tables and columns info. and example user input and result query."""
    for tableName, columns in selectedTablesAndCols.items():
        promptTableInfo += f"table name {tableName} and summary is {tableSummaryDict[tableName]}"
        promptTableInfo += f" and columns {', '.join(columns)} \n"
    # NOTE(review): literal "XXXX" marker is kept from the original. Only
    # getSystemPromptForQuery replaces "XXXX", and only in its own output, so the
    # marker stays in this string - confirm whether that is intended.
    promptTableInfo += "XXXX"
    #Join statements
    promptTableInfo += f"and table Relations are {TABLE_RELATIONS}"
    return promptTableInfo

  def getSystemPromptForQuery(self, prospectTablesAndCols):
    """Build the query-generation prompt with one example and sample data per table.

    Args:
      prospectTablesAndCols: mapping of table name -> list of column names; each
        table must be present in ``self.sampleData``.

    Returns:
      The prompt string with backslashes replaced by spaces, one pass of
      double-space collapsing, and any "XXXX" marker replaced by four spaces.
    """
    schemaName = self.schemaName
    platform = self.platform
    exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
FROM lpdatamart.tbl_f_sales a
JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id
WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023
GROUP BY a.customer_id
ORDER BY chandelier_count DESC"""

    question = "top 5 customers who bought most chandeliers in nov 2023"
    promptForQuery = f"""You are a powerful text to sql model. Answer user input with sql query. And the query needs to run on {platform}. and schemaName is {schemaName}. There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery}. and table's data is \n"""
    for tableName, columns in prospectTablesAndCols.items():
        # Only the first gptSampleRows rows of the selected columns are embedded.
        sampleRows = self.sampleData[tableName][columns].head(self.gptSampleRows)
        promptForQuery += f"table name is {tableName}, table data is {sampleRows}"
    promptForQuery += f"and table Relations are {TABLE_RELATIONS}"
    return promptForQuery.replace("\\"," ").replace("  "," ").replace("XXXX", "    ")