Spaces:
Runtime error
Runtime error
File size: 12,241 Bytes
648ac33 0ff15a5 3d2478f 648ac33 775d0c7 9928b56 0ff15a5 2a4b462 0ff15a5 2a4b462 0ff15a5 2a4b462 0ff15a5 2a4b462 4099f5c c4e1fde 4099f5c c4e1fde 4099f5c 2a4b462 0ff15a5 071923f 0ff15a5 36d99e9 4099f5c 36d99e9 4099f5c 36d99e9 31528fe 36d99e9 4099f5c 36d99e9 4099f5c 36d99e9 a76a91d 0082791 a76a91d 0082791 a76a91d 36d99e9 3d2478f 2a4b462 0ff15a5 72663f4 2a4b462 0ff15a5 2a4b462 0ff15a5 4099f5c 0ff15a5 66eef17 53577dd 9049952 66eef17 9049952 53577dd c23bd0a 66eef17 0ff15a5 f1037ed 2a4b462 0ff15a5 4099f5c c4e1fde 5228359 df58d60 0ff15a5 4099f5c 66eef17 0ff15a5 4099f5c 0ff15a5 2a4b462 0ff15a5 72663f4 0ff15a5 2a4b462 0ff15a5 4099f5c 0ff15a5 4099f5c 0ff15a5 2a4b462 0ff15a5 2a4b462 72663f4 2a4b462 4099f5c 0ff15a5 4099f5c c4e1fde | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 | from gptManager import ChatgptManager
from utils import *
import json
import sqlparse
from constants import TABLE_RELATIONS
class QueryHelperChainOfThought:
    """Chain-of-thought text-to-SQL helper built around a GPT chat client.

    Builds a system prompt from the tables/columns selected in a
    ``MetaDataLayout``, per-table summaries read from ``tableSummaryJson``,
    sample rows fetched from the database, and hard-coded domain notes, then
    asks the model for a JSON answer whose ``finalResult`` key holds the SQL.
    """

    def __init__(self, gptInstanceForCoT: ChatgptManager,
                 dbEngine, schemaName,
                 platform, metadataLayout: MetaDataLayout, sampleDataRows,
                 gptSampleRows, getSampleDataForTablesAndCols, tableSummaryJson='tableSummaryDict.json'):
        """Store the configuration and fetch sample data for the selected tables.

        Args:
            gptInstanceForCoT: chat client; this class calls its
                ``setSystemPrompt``, ``getResponseForUserInput`` and
                ``getResponseForChatHistory`` methods.
            dbEngine: passed through to ``getSampleDataForTablesAndCols``.
            schemaName: schema name used in prompts and for sample-data fetching.
            platform: name of the SQL platform the generated query must run on;
                interpolated into the prompts.
            metadataLayout: source of ``getSelectedTablesAndCols()``, a mapping
                of table name -> selected column names.
            sampleDataRows: ``maxRows`` passed to ``getSampleDataForTablesAndCols``.
            gptSampleRows: number of sample rows per table shown to the model.
            getSampleDataForTablesAndCols: callable taking ``dbEngine``,
                ``schemaName``, ``tablesAndCols`` and ``maxRows`` keywords and
                returning per-table sample data indexable by table name and then
                by a column list (presumably pandas DataFrames, since ``.head``
                is called on them — confirm).
            tableSummaryJson: path to a JSON file mapping table name -> summary.
        """
        self.gptInstanceForCoT = gptInstanceForCoT
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self.tableSummaryJson = tableSummaryJson
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Re-fetch ``self.sampleData`` for the currently selected tables/columns."""
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=selectedTablesAndCols,
            maxRows=self.sampleDataRows,
        )

    def getMetadata(self) -> MetaDataLayout :
        """Return the current metadata layout."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and refresh the cached sample data."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def getQueryForUserInputWithHistory(self, verboseChatHistory, userInput):
        """Ask the model for SQL, taking prior chat history into account.

        Returns:
            ``(finalQuery, verboseResponse)``: the extracted (and, when possible,
            assembled) query, plus the raw model reply.
        """
        prompt = self.getPromptForCot()
        self.gptInstanceForCoT.setSystemPrompt(prompt)
        gptResponse = self.gptInstanceForCoT.getResponseForChatHistory(verboseChatHistory, userInput)
        query, jsonResponse = getQueryFromGptResponse(gptResponse=gptResponse)
        # NOTE(review): assumes getQueryFromGptResponse hands back the raw reply
        # as `query` when it cannot extract SQL — confirm in utils.
        if query != gptResponse:
            finalQuery = construct_final_query(query, jsonResponse)
        else:
            finalQuery = query
        return finalQuery, gptResponse

    def getQueryForUserInputCoT(self, userInput):
        """Ask the model for SQL answering ``userInput`` (no chat history).

        Returns:
            The ``finalResult`` value from the model's JSON reply, pretty-printed
            with ``sqlparse`` when possible (see ``_formatSqlResult``), or the raw
            reply when no JSON with a ``finalResult`` key could be parsed.
        """
        prompt = self.getPromptForCot()
        self.gptInstanceForCoT.setSystemPrompt(prompt)
        gptResponse = self.gptInstanceForCoT.getResponseForUserInput(userInput)
        parsedSql, sqlResult = self._parseFinalResult(gptResponse)
        if not parsedSql:
            return gptResponse
        return self._formatSqlResult(sqlResult)

    @staticmethod
    def _parseFinalResult(gptResponse):
        """Extract ``finalResult`` from a model reply.

        Tries three extraction strategies in order.

        Returns:
            ``(True, value)`` for the first strategy that succeeds, otherwise
            ``(False, None)``. A boolean flag is used because ``value`` may
            itself be ``None``.
        """
        attempts = (
            # 1: body of the last ```json fence. Newlines are flattened because the
            #    prompt's example puts raw newlines inside JSON strings.
            (lambda: gptResponse.split("```json")[-1].split('```')[0].replace('\n', ' '),
             "Couldn't parse desired result from gpt response using method 1."),
            # 2: the reply exactly as received.
            (lambda: gptResponse,
             "Couldn't parse desired result from gpt response using method 2"),
            # 3: all fence markers stripped, newlines flattened.
            (lambda: gptResponse.replace("```json", "").replace("```", "").replace('\n', ' '),
             "Couldn't parse desired result from gpt response using method 3"),
        )
        for methodNumber, (extract, failureMessage) in enumerate(attempts, start=1):
            try:
                sqlResult = json.loads(extract())['finalResult']
            # ValueError covers json.JSONDecodeError; TypeError covers non-dict JSON
            # or a non-string reply; AttributeError covers e.g. a None reply.
            except (ValueError, KeyError, TypeError, AttributeError):
                print(failureMessage)
                continue
            print(f"parsed desired result from gpt response using method {methodNumber}.")
            return True, sqlResult
        return False, None

    @staticmethod
    def _formatSqlResult(sqlResult):
        """Pretty-print ``sqlResult`` with sqlparse, best effort.

        ``sqlResult`` is normally a SQL string. When it is instead a mapping
        with a ``'result'`` entry, that inner SQL is formatted. Anything else
        is returned as ``str(sqlResult)``.
        """
        try:
            return sqlparse.format(sqlResult, reindent=True)
        except Exception:
            # Not a plain SQL string (sqlparse raises for e.g. a dict);
            # fall through and try the nested {"result": ...} form.
            pass
        try:
            formattedSql = sqlparse.format(sqlResult['result'], reindent=True)
        except Exception:
            return str(sqlResult)
        print("gpt didn't give parsed result. So parsing again. the formatting.")
        return formattedSql

    def getPromptForCot(self):
        """Build the full chain-of-thought system prompt.

        The prompt combines:
            * instructions to answer in JSON;
            * a worked example (``egUserInput`` -> ``cotSubtaskOutput``);
            * the table section from ``getSystemPromptForTableCols``;
            * the sample-data section from ``getSystemPromptForQuery``;
            * hard-coded domain notes;
            * ``TABLE_RELATIONS``.
        """
        schemaName = self.schemaName
        platform = self.platform
        egUserInput = "I want to get top 5 product categories by state, then rank categories on decreasing order of total sales"
        # Worked example of the expected JSON answer. subquery2 lists subquery1
        # as its input, so its SQL reads FROM subquery1, matching finalResult.
        cotSubtaskOutput = """{
"subquery1": {
"inputSubquery": [],
"description":"calculate the total sales and assigns ranks to product categories within each state based on the descending order of sales in the tbl_f_sales table, utilizing joins with tbl_d_product and tbl_d_customer tables.",
"result": "SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
FROM lpdatamart.tbl_f_sales a
JOIN lpdatamart.tbl_d_product b
ON a.product_id = b.product_id
JOIN lpdatamart.tbl_d_customer c
ON a.customer_id = c.customer_id
GROUP BY c.state, b.category "
},
"subquery2": {
"inputSubquery": ["subquery1"],
"description":"extracts state, category, and total sales information from a subquery named subquery1, filtering the results to include only categories with ranks up to 5 and sorting them by state and category rank.",
"result":"SELECT state, category, total_sales
FROM subquery1
WHERE category_rank <= 5
ORDER BY state, category_rank"
},
"finalResult":"WITH subquery1 AS (
SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
FROM lpdatamart.tbl_f_sales a
JOIN lpdatamart.tbl_d_product b
ON a.product_id = b.product_id
JOIN lpdatamart.tbl_d_customer c
ON a.customer_id = c.customer_id
GROUP BY c.state, b.category
)
SELECT state, category, total_sales
FROM subquery1
WHERE category_rank <= 5
ORDER BY state, category_rank"
}"""
        salesInfoPrompt = """
Following is tbl_d_lineaction_code table. and tbl_d_lineaction_code and tbl_f_sales
join on tbl_d_lineaction_code.line_action_code, tbl_f_sales.line_action .
Each transaction New order, canceled, returned gets entry in tbl_f_sales. Please consider them while calculating revenue etc.
line_action_code line_action_code_desc load_date catgory sales_type
---------------- ------------------------------ --------------------- ------- ----------
201 NEW 2015-08-03 00:37:53.0 <null> WEB
202 UNSHIPPED 2015-08-03 00:37:53.0 <null> WEB
203 SHIPPED 2015-08-03 00:37:54.0 <null> WEB
204 CANCELED 2015-08-03 00:37:55.0 <null> WEB
205 RETURN 2015-08-03 00:37:55.0 <null> WEB
1 SOLD 2014-09-17 02:24:28.0 <null> POS
2 RETURNED 2014-09-17 02:24:30.0 <null> POS
7 ORDERED 2014-09-17 02:24:31.0 <null> POS
8 ORDER CANCELLED 2014-09-17 02:24:31.0 <null> POS
90 ORDER DELIVERED 2014-09-17 02:24:32.0 <null> POS
"""
        masterIdPrompt = """
###Important
Following is important detail about customer data.
Following is two columns (customer_id and master_customer_id) of tbl_d_customer.
master_customer_id represents real single customer but it can be related to various customer_id which
represets various ways the same customer has logged in or created accounts.
When asked about only customer, use master_customer_id
So when accounting for sales or transaction by particular customer, use their all the related customer_id of their
master_customer_id.
customer_id master_customer_id
---------- ------------------
89998738 1562
96656077 1562
1562 1562
74357421 1562
290007176 1562
"""
        demandPrompt = """
For report creation demands means number of quantity sold."""
        subLocationPrompt = """Following are information about
OPEN_BOX TBL_D_STORE: SUBLOCATIONCODE: 9009, 90093, 90094
PRO/PRO OPENBOX TBL_D_STORE: SUBLOCATIONCODE: 9004, 90094
CONSUMER/CONSUMER OPENBOX 9003, 90093
"""
        # getSystemPromptForTableCols leaves an "XXXX" placeholder in its text.
        # Blank it out with a space so the marker never reaches the model.
        promptTableInfo = self.getSystemPromptForTableCols().replace("XXXX", " ")
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        promptColumnsInfo = self.getSystemPromptForQuery(selectedTablesAndCols)
        prompt = f"""You are a powerful text to sql model. Your task is to return sql query which answers
user's input. Please follow subquery structure if the sql needs to have multiple subqueries. Your response should be in JSON format.
Answer user input with sql query. And the query needs to run on {platform}. and schemaName is {schemaName}.
And use columns and tables provided, in case, you need additional column information, please ask the user. or define them in comments
###example userInput is {egUserInput}. output is {cotSubtaskOutput}. Output should be in json format as provided. Only output should be in response, nothing else.\n\n
tables information are {promptTableInfo}.
columns data are {promptColumnsInfo}.
Following are extra information which user might assume you know.
{salesInfoPrompt}
{masterIdPrompt}
{demandPrompt}
{subLocationPrompt}.
"""
        # NOTE(review): TABLE_RELATIONS is already embedded in both
        # promptTableInfo and promptColumnsInfo; this third copy only adds tokens.
        prompt += f"and table Relations are {TABLE_RELATIONS} "
        return prompt

    def getSystemPromptForTableCols(self):
        """Describe each selected table (summary + selected columns) for the model.

        The returned text ends with an ``"XXXX"`` placeholder followed by
        ``TABLE_RELATIONS``. Callers embedding it in a prompt should blank out
        the placeholder.

        Raises:
            KeyError: if a selected table has no entry in ``tableSummaryJson``.
        """
        with open(self.tableSummaryJson, 'r', encoding='utf-8') as summaryFile:
            tableSummaryDict = json.load(summaryFile)
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        parts = ["""You are a powerful text to sql model. Answer which tables and columns are needed
to answer user input using sql query. and following are tables and columns info. and example user input and result query."""]
        for tableName, columns in selectedTablesAndCols.items():
            parts.append(f"table name {tableName} and summary is {tableSummaryDict[tableName]} ")
            parts.append(f" and columns {', '.join(columns)} \n ")
        # Placeholder kept for compatibility with existing consumers of this text.
        parts.append("XXXX")
        # Join statements
        parts.append(f" and table Relations are {TABLE_RELATIONS} ")
        return "".join(parts)

    def getSystemPromptForQuery(self, prospectTablesAndCols):
        """Build the query-writing prompt with a worked example and sample rows.

        Args:
            prospectTablesAndCols: mapping of table name -> list of column names.
                Each table must be present in ``self.sampleData``.

        Returns:
            The prompt text. Backslashes are replaced with spaces, and any
            ``"XXXX"`` placeholder is replaced with a space.
        """
        schemaName = self.schemaName
        platform = self.platform
        # NOTE(review): the question asks for the "top 5" customers, but the
        # example query has no row limit; confirm this is intended.
        exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
FROM lpdatamart.tbl_f_sales a
JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id
WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023
GROUP BY a.customer_id
ORDER BY chandelier_count DESC"""
        question = "top 5 customers who bought most chandeliers in nov 2023"
        parts = [f"""You are a powerful text to sql model. Answer user input with sql query. And the query needs to run on {platform}. and schemaName is {schemaName}. There is example user input and desired generated sql query. Follow similar patterns as example. eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery}. and table's data is \n """]
        for tableName, columns in prospectTablesAndCols.items():
            sampleRows = self.sampleData[tableName][columns].head(self.gptSampleRows)
            parts.append(f"table name is {tableName}, table data is {sampleRows} \n ")
            parts.append(f"and table Relations are {TABLE_RELATIONS} \n ")
        promptForQuery = "".join(parts)
        # NOTE(review): `.replace(" ", " ")` is a no-op as written; it may have
        # been meant to collapse double spaces. Confirm before changing it.
        return promptForQuery.replace("\\", " ").replace(" ", " ").replace("XXXX", " ")