Spaces:
Runtime error
Runtime error
File size: 7,534 Bytes
1dda07c c5ff675 0ff15a5 11a349e 0ff15a5 1dda07c 0ff15a5 1dda07c 0ff15a5 1dda07c 0ff15a5 1dda07c 0ff15a5 1dda07c 0ff15a5 1dda07c 0ff15a5 9a04841 0ff15a5 11a349e 0ff15a5 1dda07c 0ff15a5 714fdb8 0ff15a5 1dda07c 0ff15a5 1dda07c 0ff15a5 1dda07c c792260 11a349e 0ff15a5 11a349e 1dda07c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | from gptManager import ChatgptManager
from utils import *
import json
import sqlparse
from constants import TABLE_RELATIONS
class QueryHelper:
    """Two-step text-to-SQL helper built on two chat-model instances.

    Step 1 asks ``gptInstanceForTableCols`` which of the selected tables and
    columns are relevant to the user's request. Step 2 asks
    ``gptInstanceForQuery`` for the SQL, giving it sample rows of only those
    tables/columns (the original author's stated goal: save tokens).

    Annotations are quoted so the class body does not need the annotated
    project types to be resolvable at definition time.
    """

    def __init__(self, gptInstanceForTableCols: "ChatgptManager",
                 gptInstanceForQuery: "ChatgptManager",
                 dbEngine, schemaName,
                 platform, metadataLayout: "MetaDataLayout", sampleDataRows,
                 gptSampleRows, getSampleDataForTablesAndCols,
                 tableSummaryJson='tableSummaryDict.json'):
        """Store collaborators and build the initial sample data and prompt.

        Args:
            gptInstanceForTableCols: chat instance used to pick relevant
                tables/columns; its system prompt is set here.
            gptInstanceForQuery: chat instance used to produce the SQL.
            dbEngine: forwarded to ``getSampleDataForTablesAndCols``.
            schemaName: forwarded to ``getSampleDataForTablesAndCols``.
            platform: stored on the instance; not otherwise used by this class.
            metadataLayout: object exposing ``getSelectedTablesAndCols()``.
            sampleDataRows: passed as ``maxRows`` when fetching sample data.
            gptSampleRows: number of leading rows per table embedded in the
                query prompt.
            getSampleDataForTablesAndCols: callable invoked with keyword
                arguments ``dbEngine``, ``schemaName``, ``tablesAndCols`` and
                ``maxRows``; its result is indexed by table name.
            tableSummaryJson: path of a JSON file indexed by table name to
                obtain each table's summary.
        """
        self.gptInstanceForTableCols = gptInstanceForTableCols
        self.gptInstanceForQuery = gptInstanceForQuery
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self.tableSummaryJson = tableSummaryJson
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Refresh sample data and the table/column system prompt."""
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=selectedTablesAndCols,
            maxRows=self.sampleDataRows,
        )
        self.promptTableColsInfo = self.getSystemPromptForTableCols()
        self.gptInstanceForTableCols.setSystemPrompt(self.promptTableColsInfo)

    def getMetadata(self) -> "MetaDataLayout":
        """Return the metadata layout currently in use."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and rebuild everything derived from it."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def getQueryForUserInput(self, userInput):
        """Return SQL (or the raw model reply) answering ``userInput``.

        Returns:
            The re-indented SQL when a ``finalResult`` could be extracted from
            the model reply and formatted; ``str(finalResult)`` when it could
            be extracted but not formatted; otherwise the unmodified reply.
        """
        prospectTablesAndColsText = self.gptInstanceForTableCols.getResponseForUserInput(userInput)
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        prospectTablesAndCols = self._pickMentionedTablesAndCols(
            prospectTablesAndColsText, selectedTablesAndCols)
        print("tables and cols select by gpt", prospectTablesAndCols)

        # Second step: only the tables/columns picked above go into the prompt.
        promptForQuery = self.getSystemPromptForQuery(prospectTablesAndCols)
        self.gptInstanceForQuery.setSystemPrompt(promptForQuery)
        gptResponse = self.gptInstanceForQuery.getResponseForUserInput(userInput)

        parsedSql, sqlResult = self._extractFinalResult(gptResponse)
        if not parsedSql:
            return gptResponse
        return self._formatSqlResult(sqlResult)

    @staticmethod
    def _pickMentionedTablesAndCols(responseText, selectedTablesAndCols):
        """Keep the selected tables, and their columns, named in ``responseText``.

        Matching is a plain substring test, so a name that is contained in a
        longer name also counts as mentioned.
        """
        return {
            table: [col for col in selectedTablesAndCols[table] if col in responseText]
            for table in selectedTablesAndCols
            if table in responseText
        }

    @staticmethod
    def _extractFinalResult(gptResponse):
        """Pull the ``finalResult`` entry out of a model reply.

        Method 1 parses the text after the last ```` ```json ```` marker up to
        the next fence (or the whole reply when there is no marker); method 2
        parses the reply as-is.

        Returns:
            ``(True, value)`` on success, ``(False, None)`` otherwise.
        """
        attempts = (
            ("Couldn't parse desired result from gpt response using method 1.",
             lambda text: text.split("```json")[-1].split('```')[0]),
            ("Couldn't parse desired result from gpt response using method 2",
             lambda text: text),
        )
        for failureMessage, getPayload in attempts:
            try:
                # strict=False accepts raw newlines inside JSON strings, so the
                # SQL keeps its line breaks instead of having adjacent lines
                # glued together by stripping "\n" before parsing.
                payload = json.loads(getPayload(gptResponse), strict=False)
                return True, payload['finalResult']
            except (ValueError, KeyError, TypeError, AttributeError):
                print(failureMessage)
        return False, None

    @staticmethod
    def _formatSqlResult(sqlResult):
        """Re-indent the extracted SQL, falling back to ``str(sqlResult)``.

        When ``sqlResult`` is not a string, its ``'result'`` entry is formatted
        instead (the model sometimes nests the query one level deeper).
        """
        sqlText = sqlResult
        usedNestedResult = False
        if not isinstance(sqlResult, str):
            try:
                sqlText = sqlResult['result']
            except (KeyError, TypeError):
                return str(sqlResult)
            usedNestedResult = True
        try:
            formattedSql = sqlparse.format(sqlText, reindent=True)
        except Exception:
            # Best effort: an unformattable result is still returned as text.
            return str(sqlResult)
        if usedNestedResult:
            print("gpt didn't give parsed result. So parsing again. the formatting.")
        return formattedSql

    def getSystemPromptForTableCols(self):
        """Build the system prompt describing every selected table.

        Raises:
            OSError: if the summary file cannot be opened.
            KeyError: if a selected table has no entry in the summary file.
        """
        with open(self.tableSummaryJson, 'r') as summaryFile:
            tableSummaryDict = json.load(summaryFile)
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()

        promptParts = ["""You are a powerful text to sql model. Answer which tables and columns are needed
to answer user input using sql query. and following are tables and columns info. and example user input and result query."""]
        for tableName in selectedTablesAndCols:
            promptParts.append(
                f"table name {tableName} and summary is {tableSummaryDict[tableName]}")
            promptParts.append(
                f" and columns {', '.join(selectedTablesAndCols[tableName])} \n")
        # NOTE(review): the marker is appended once, after the table list; the
        # code this was recovered from had lost its indentation, so per-table
        # placement cannot be ruled out - confirm. It is also never replaced
        # in this prompt (only getSystemPromptForQuery strips "XXXX").
        promptParts.append("XXXX")
        # Join statements
        promptParts.append(f"and table Relations are {TABLE_RELATIONS}")
        return "".join(promptParts)

    def getSystemPromptForQuery(self, prospectTablesAndCols):
        """Build the system prompt for SQL generation.

        Args:
            prospectTablesAndCols: mapping of table name to the list of its
                columns to show; sample rows are taken from ``self.sampleData``
                (presumably pandas DataFrames, since ``.head`` is called).
        """
        egUserInput = "I want to get top 5 product categories by state, then rank categories on decreasing order of total sales"
        # NOTE(review): the example below is prompt text and is kept verbatim,
        # including its "descriptioin" key and the missing comma after the
        # second description - confirm before correcting, it changes the prompt.
        cotSubtaskOutput = """{
"subquery1": {
"inputSubquery": [],
"descriptioin":"calculate the total sales and assigns ranks to product categories within each state based on the descending order of sales in the tbl_f_sales table, utilizing joins with tbl_d_product and tbl_d_customer tables.",
"result": "SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
FROM lpdatamart.tbl_f_sales a
JOIN lpdatamart.tbl_d_product b
ON a.product_id = b.product_id
JOIN lpdatamart.tbl_d_customer c
ON a.customer_id = c.customer_id
GROUP BY c.state, b.category "
},
"subquery2": {
"inputSubquery": ["subquery1"],
"description":"extracts state, category, and total sales information from a subquery named "subquery1," filtering the results to include only categories with ranks up to 5 and sorting them by state and category rank."
"result":"SELECT state, category, total_sales
FROM ranked_categories
WHERE category_rank <= 5
ORDER BY state, category_rank"
},
"finalResult":"WITH subquery1 AS (
SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
FROM lpdatamart.tbl_f_sales a
JOIN lpdatamart.tbl_d_product b
ON a.product_id = b.product_id
JOIN lpdatamart.tbl_d_customer c
ON a.customer_id = c.customer_id
GROUP BY c.state, b.category
)
SELECT state, category, total_sales
FROM subquery1
WHERE category_rank <= 5
ORDER BY state, category_rank"
}"""
        promptParts = [f"""You are a powerful text to sql model. Your task is to return sql query which answers
user's input. Please follow subquery structure if the sql needs to have multiple subqueries.
###example userInput is {egUserInput}. output is {cotSubtaskOutput}. Output should be in json format as provided. Only output should be in response, nothing else.\n\n
"""]
        for tableName, columns in prospectTablesAndCols.items():
            tablePreview = self.sampleData[tableName][columns].head(self.gptSampleRows)
            promptParts.append(f"table name is {tableName}, table data is {tablePreview}")
        promptParts.append(f"and table Relations are {TABLE_RELATIONS}")
        prompt = "".join(promptParts)
        # NOTE(review): as written the middle replace swaps a single space for
        # a single space (a no-op); presumably it was meant to collapse runs
        # of spaces - confirm before changing, since it alters the prompt.
        return prompt.replace("\\"," ").replace(" "," ").replace("XXXX", " ")
|