File size: 5,500 Bytes
1dda07c
c5ff675
1dda07c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9aff6be
1dda07c
 
 
 
 
 
 
 
 
 
 
8819e5a
1dda07c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb64f16
a978b21
 
 
 
 
fb64f16
 
a978b21
 
 
1dda07c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import re

from gptManager import ChatgptManager
from utils import *

class QueryHelper:
  """Turn natural-language requests into SQL by prompting a GPT manager.

  The helper caches sample data for the tables/columns selected in the
  metadata layout, asks the GPT manager which of them are relevant to a
  request, and then asks it again for a SQL query, giving it sample rows of
  only the relevant tables/columns.
  """

  # Few-shot example embedded verbatim in the query-generation prompt.
  _EXAMPLE_QUESTION = "top 5 customers who bought most chandeliers in nov 2023"
  _EXAMPLE_QUERY = (
    "SELECT a.customer_id, COUNT(a.product_id) as chandelier_count\n"
    "FROM lpdatamart.tbl_f_sales a\n"
    "JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id\n"
    "JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id\n"
    "WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023\n"
    "GROUP BY a.customer_id\n"
    "ORDER BY chandelier_count DESC"
  )

  # Every system prompt ends with this padding. It is appended *after*
  # whitespace normalisation so that it survives the double-space collapse.
  _PROMPT_PADDING = "    "

  # Project types are written as string annotations so that defining the
  # class does not require evaluating them.
  def __init__(self, gptInstance: "ChatgptManager", dbEngine, schemaName,
               platform, metadataLayout: "MetaDataLayout", sampleDataRows,
               gptSampleRows, getSampleDataForTablesAndCols):
    """Store the collaborators and load sample data for the selected columns.

    Args:
      gptInstance: object whose ``getResponseForUserInput`` is used to talk
        to the model.
      dbEngine: passed through to ``getSampleDataForTablesAndCols``.
      schemaName: passed to the sample-data loader and named in the prompts.
      platform: SQL platform name mentioned in the prompts.
      metadataLayout: object providing ``getSelectedTablesAndCols()``.
      sampleDataRows: forwarded to the sample-data loader as ``maxRows``.
      gptSampleRows: number of rows per table (``.head(n)``) placed in the
        query-generation prompt.
      getSampleDataForTablesAndCols: callable invoked with the keyword
        arguments ``dbEngine``, ``schemaName``, ``tablesAndCols`` and
        ``maxRows``; its result is kept as ``self.sampleData``.
    """
    self.gptInstance = gptInstance
    self.schemaName = schemaName
    self.platform = platform
    self.metadataLayout = metadataLayout
    self.sampleDataRows = sampleDataRows
    self.gptSampleRows = gptSampleRows
    self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
    self.dbEngine = dbEngine
    self._onMetadataChange()

  def _onMetadataChange(self):
    """Reload ``self.sampleData`` for the currently selected tables/columns."""
    self.sampleData = self.getSampleDataForTablesAndCols(
      dbEngine=self.dbEngine,
      schemaName=self.schemaName,
      tablesAndCols=self.metadataLayout.getSelectedTablesAndCols(),
      maxRows=self.sampleDataRows,
    )

  def getMetadata(self) -> "MetaDataLayout":
    """Return the metadata layout currently in use."""
    return self.metadataLayout

  def updateMetadata(self, metadataLayout):
    """Replace the metadata layout and refresh the cached sample data."""
    self.metadataLayout = metadataLayout
    self._onMetadataChange()

  def modifySqlQueryEnteredByUser(self, userSqlQuery):
    """Ask the model to correct ``userSqlQuery`` for the configured platform.

    Returns:
      Whatever ``gptInstance.getResponseForUserInput`` returns for the
      correction prompt (sent with an empty system prompt).
    """
    userPrompt = f"Please correct the following sql query, also it has to be run on {self.platform}. sql query is \n {userSqlQuery}."
    systemPrompt = ""
    return self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)

  def filteredSampleDataForProspects(self, prospectTablesAndCols):
    """Restrict the cached sample data to the prospect tables and columns.

    Args:
      prospectTablesAndCols: mapping of table name -> list of column names.

    Returns:
      dict mapping each prospect table to ``self.sampleData[table][columns]``
      (presumably a pandas DataFrame column selection; verify against the
      sample-data loader).
    """
    sampleData = self.sampleData
    return {
      table: sampleData[table][columns]
      for table, columns in prospectTablesAndCols.items()
    }

  def getQueryForUserInput(self, userInput, chatHistory=None):
    """Generate a SQL query for ``userInput``.

    Args:
      userInput: the user's natural-language request.
      chatHistory: optional history forwarded to the GPT manager. A fresh
        list is used when omitted, so calls never share state through a
        default argument.

    Returns:
      Tuple ``(query, prospectTablesAndCols)``: the model response after
      ``preProcessGptQueryReponse`` and the tables/columns judged relevant.
    """
    history = [] if chatHistory is None else chatHistory
    selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
    prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols, history)
    print("getting prospects", prospectTablesAndCols)
    prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
    systemPrompt = self.getSystemPromptForQueryGeneration(prospectTablesData, gptSampleRows=self.gptSampleRows)
    queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPrompt, history)

    queryByGpt = preProcessGptQueryReponse(queryByGpt, metadataLayout=self.metadataLayout)
    return queryByGpt, prospectTablesAndCols

  def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory=None):
    """Ask the model which selected tables/columns are relevant to the input.

    A table is kept when its name occurs in the model's reply as a whole
    identifier; for each kept table, the columns that occur as whole
    identifiers are kept. Whole-identifier matching prevents a short name
    such as ``id`` from being picked up merely because ``product_id`` was
    mentioned.

    Args:
      userInput: the user's natural-language request.
      selectedTablesAndCols: mapping of table name -> list of column names.
      chatHistory: optional history forwarded to the GPT manager; a fresh
        list is used when omitted.

    Returns:
      dict mapping each matched table name to its list of matched columns
      (possibly empty), in the order of ``selectedTablesAndCols``.
    """
    history = [] if chatHistory is None else chatHistory
    systemPrompt = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
    replyText = self.gptInstance.getResponseForUserInput(userInput, systemPrompt, history)
    return {
      table: [column for column in columns if self._mentions(replyText, column)]
      for table, columns in selectedTablesAndCols.items()
      if self._mentions(replyText, table)
    }

  def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
    """Build the system prompt used to generate the SQL query.

    Args:
      prospectTablesData: mapping of table name -> object with ``.head(n)``
        (presumably a pandas DataFrame; verify against the loader).
      gptSampleRows: number of rows passed to ``.head`` for each table.

    Returns:
      A single-line prompt containing the schema, platform, a worked
      example and the sample rows of every prospect table.
    """
    header = (
      "Given an input text, generate the corresponding SQL query for given details. "
      f"Schema Name is {self.schemaName}. And sql platform is {self.platform}.\n"
      " following is sample data. Also \n"
      "    There is example user input and desired generated sql query. "
      "Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. "
      f"user input : {self._EXAMPLE_QUESTION}, query : {self._EXAMPLE_QUERY} "
    )
    # Join with a space so one table's data does not run into the next
    # table's "table name is ..." label.
    tableSections = " ".join(
      f"table name is {tableName}, table data is {tableData.head(gptSampleRows)}"
      for tableName, tableData in prospectTablesData.items()
    )
    return self._flattenPrompt(header + tableSections)

  def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
    """Build the system prompt asking which tables/columns hold the data.

    Args:
      selectedTablesAndCols: mapping of table name -> list of column names.

    Returns:
      A single-line prompt listing every selected table with its columns.
    """
    header = f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {self.schemaName}. And sql platform is {self.platform}.\n"""
    # Join with a space so the last column of one table is not fused with
    # the "table name" label of the next one.
    tableSections = " ".join(
      f"table name {tableName} {', '.join(columns)}"
      for tableName, columns in selectedTablesAndCols.items()
    )
    return self._flattenPrompt(header + tableSections)

  @classmethod
  def _flattenPrompt(cls, prompt):
    """Put ``prompt`` on one line, collapse space pairs and add the padding.

    Newlines and backslashes become spaces and each pair of spaces is
    collapsed once (single pass). The padding is appended afterwards instead
    of going through a placeholder, so text that happens to contain the old
    ``XXXX`` placeholder is left intact.
    """
    flattened = prompt.replace("\n", " ").replace("\\", " ").replace("  ", " ")
    return flattened + cls._PROMPT_PADDING

  @staticmethod
  def _mentions(text, identifier):
    """Return True when ``identifier`` occurs in ``text`` as a whole token.

    The match must not be directly preceded or followed by a word character
    (letter, digit or underscore); matching stays case-sensitive.
    """
    pattern = rf"(?<!\w){re.escape(identifier)}(?!\w)"
    return re.search(pattern, text) is not None