File size: 4,820 Bytes
1dda07c
c5ff675
0ff15a5
 
1dda07c
 
0ff15a5
 
 
1dda07c
0ff15a5
 
 
1dda07c
 
 
 
 
 
 
0ff15a5
1dda07c
 
 
 
 
 
 
 
 
 
0ff15a5
 
 
1dda07c
 
 
 
 
 
 
 
0ff15a5
 
1dda07c
0ff15a5
 
 
 
 
 
 
 
 
 
 
 
 
1dda07c
0ff15a5
 
 
1dda07c
0ff15a5
 
 
 
 
 
 
 
 
 
1dda07c
0ff15a5
1dda07c
 
0ff15a5
fb64f16
a978b21
 
 
 
 
fb64f16
 
a978b21
0ff15a5
 
 
 
 
1dda07c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from gptManager import ChatgptManager
from utils import *
import json 
from constants import TABLE_RELATIONS

class QueryHelper:
  """Two-step text-to-SQL helper driven by two ChatgptManager instances.

  Step one asks ``gptInstanceForTableCols`` which of the selected tables and
  columns are relevant to the user's question. Step two hands sample rows of
  just those tables/columns to ``gptInstanceForQuery`` and returns its answer.
  """

  # Annotations are quoted so that defining the class does not depend on which
  # names the star-import from ``utils`` happens to export.
  def __init__(self, gptInstanceForTableCols: "ChatgptManager",
               gptInstanceForQuery: "ChatgptManager",
               dbEngine, schemaName,
               platform, metadataLayout: "MetaDataLayout", sampleDataRows,
               gptSampleRows, getSampleDataForTablesAndCols, tableSummaryJson='tableSummaryDict.json'):
    """Store the collaborators and build the initial table/column prompt.

    Args:
      gptInstanceForTableCols: chat client used to pick relevant tables/columns.
      gptInstanceForQuery: chat client used to generate the SQL.
      dbEngine: passed through to ``getSampleDataForTablesAndCols``.
      schemaName: passed to the sampler and quoted in the query prompt.
      platform: name of the SQL platform the generated query must run on.
      metadataLayout: object exposing ``getSelectedTablesAndCols()``.
      sampleDataRows: ``maxRows`` value passed to the sampler.
      gptSampleRows: number of sample rows per table shown in the query prompt.
      getSampleDataForTablesAndCols: callable invoked with the keyword
        arguments ``dbEngine``, ``schemaName``, ``tablesAndCols``, ``maxRows``.
      tableSummaryJson: path to (or open file object for) a JSON document
        mapping table name to a summary. A path is preferable because the
        document is re-read on every metadata change.
    """
    self.gptInstanceForTableCols = gptInstanceForTableCols
    self.gptInstanceForQuery = gptInstanceForQuery
    self.schemaName = schemaName
    self.platform = platform
    self.metadataLayout = metadataLayout
    self.sampleDataRows = sampleDataRows
    self.gptSampleRows = gptSampleRows
    self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
    self.dbEngine = dbEngine
    self.tableSummaryJson = tableSummaryJson
    self._onMetadataChange()

  def _onMetadataChange(self):
    """Refresh cached sample data and the table/column system prompt."""
    selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
    self.sampleData = self.getSampleDataForTablesAndCols(
      dbEngine=self.dbEngine,
      schemaName=self.schemaName,
      tablesAndCols=selectedTablesAndCols,
      maxRows=self.sampleDataRows,
    )
    self.promptTableColsInfo = self.getSystemPromptForTableCols()
    self.gptInstanceForTableCols.setSystemPrompt(self.promptTableColsInfo)

  def getMetadata(self) -> "MetaDataLayout":
    """Return the metadata layout currently in use."""
    return self.metadataLayout

  def updateMetadata(self, metadataLayout):
    """Replace the metadata layout and rebuild everything derived from it."""
    self.metadataLayout = metadataLayout
    self._onMetadataChange()

  def getQueryForUserInput(self, userInput):
    """Return the query model's response for ``userInput``.

    The table/column model is asked first; every selected table whose name
    occurs in its reply is kept, together with those of its selected columns
    whose names also occur in the reply. Matching is plain substring search
    on the reply text.
    """
    prospectTablesAndColsText = self.gptInstanceForTableCols.getResponseForUserInput(userInput)
    selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()

    prospectTablesAndCols = {}
    for table in selectedTablesAndCols:
      if table not in prospectTablesAndColsText:
        continue
      prospectTablesAndCols[table] = [
        col for col in selectedTablesAndCols[table]
        if col in prospectTablesAndColsText
      ]

    # Must be the bound method: it reads self.sampleData and friends.
    promptForQuery = self.getSystemPromptForQuery(prospectTablesAndCols)
    self.gptInstanceForQuery.setSystemPrompt(promptForQuery)
    return self.gptInstanceForQuery.getResponseForUserInput(userInput)

  def _loadTableSummary(self):
    """Return the parsed content of ``self.tableSummaryJson``.

    ``json.load`` needs a file object, so a path is opened here; an object
    that already exposes ``read`` is handed to ``json.load`` directly.
    """
    source = self.tableSummaryJson
    if hasattr(source, "read"):
      return json.load(source)
    with open(source, encoding="utf-8") as summaryFile:
      return json.load(summaryFile)

  def getSystemPromptForTableCols(self):
    """Build the system prompt describing every selected table and its columns.

    Raises:
      KeyError: if a selected table has no entry in the summary JSON.
    """
    tableSummaryDict = self._loadTableSummary()
    selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()

    promptParts = ["""You are a powerful text to sql model. Answer which tables and columns are needed
    to answer user input using sql query. and following are tables and columns info. and example user input and result query."""]
    for tableName in selectedTablesAndCols:
      promptParts.append(
        f"table name {tableName} and summary is {tableSummaryDict[tableName]}"
        f" and columns {', '.join(selectedTablesAndCols[tableName])} \n"
      )
    # NOTE(review): "XXXX" is only ever replaced in getSystemPromptForQuery, so
    # it reaches the model verbatim here - confirm whether that is intended.
    promptParts.append("XXXX")
    # Join statements
    promptParts.append(f"and table Relations are {TABLE_RELATIONS}")
    return "".join(promptParts)

  def getSystemPromptForQuery(self, prospectTablesAndCols):
    """Build the system prompt used to generate the SQL query.

    Args:
      prospectTablesAndCols: dict mapping table name to the list of column
        names whose sample rows should be shown to the model.

    The prompt contains a fixed example question/query pair, the first
    ``self.gptSampleRows`` sample rows of each listed table, and the table
    relations. ``self.sampleData[table]`` must support list-of-columns
    indexing and ``.head(n)`` (presumably a pandas DataFrame - confirm
    against the sampler).
    """
    exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
FROM lpdatamart.tbl_f_sales a
JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id
WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023
GROUP BY a.customer_id
ORDER BY chandelier_count DESC"""
    question = "top 5 customers who bought most chandeliers in nov 2023"
    platform = self.platform
    schemaName = self.schemaName

    promptParts = [f"""You are a powerful text to sql model. Answer user input with sql query. And the query needs to run on {platform}. and schemaName is {schemaName}. There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery}. and table's data is \n"""]
    for tableName, cols in prospectTablesAndCols.items():
      sampleRows = self.sampleData[tableName][cols].head(self.gptSampleRows)
      promptParts.append(f"table name is {tableName}, table data is {sampleRows}")
    promptParts.append(f"and table Relations are {TABLE_RELATIONS}")

    promptForQuery = "".join(promptParts)
    # Normalise the text: backslashes become spaces, double spaces collapse
    # once, and any "XXXX" placeholder becomes four spaces.
    return promptForQuery.replace("\\"," ").replace("  "," ").replace("XXXX", "    ")