Spaces:
Runtime error
Runtime error
Commit ·
11a349e
1
Parent(s): 7d0c63c
added CoT in select table then get query prompt style
Browse files- queryHelperManager.py +83 -14
queryHelperManager.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from gptManager import ChatgptManager
|
| 2 |
from utils import *
|
| 3 |
import json
|
|
|
|
| 4 |
from constants import TABLE_RELATIONS
|
| 5 |
|
| 6 |
class QueryHelper:
|
|
@@ -54,7 +55,42 @@ class QueryHelper:
|
|
| 54 |
promptForQuery = self.getSystemPromptForQuery(prospectTablesAndCols)
|
| 55 |
self.gptInstanceForQuery.setSystemPrompt(promptForQuery)
|
| 56 |
gptResponse = self.gptInstanceForQuery.getResponseForUserInput(userInput)
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
def getSystemPromptForTableCols(self):
|
| 60 |
schemaName = self.schemaName
|
|
@@ -77,18 +113,51 @@ class QueryHelper:
|
|
| 77 |
schemaName = self.schemaName
|
| 78 |
platform = self.platform
|
| 79 |
tableSummaryDict = json.load(open(self.tableSummaryJson, 'r'))
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
for idx, tableName in enumerate(prospectTablesAndCols.keys(), start=1):
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
return
|
| 94 |
|
|
|
|
| 1 |
from gptManager import ChatgptManager
|
| 2 |
from utils import *
|
| 3 |
import json
|
| 4 |
+
import sqlparse
|
| 5 |
from constants import TABLE_RELATIONS
|
| 6 |
|
| 7 |
class QueryHelper:
|
|
|
|
| 55 |
promptForQuery = self.getSystemPromptForQuery(prospectTablesAndCols)
|
| 56 |
self.gptInstanceForQuery.setSystemPrompt(promptForQuery)
|
| 57 |
gptResponse = self.gptInstanceForQuery.getResponseForUserInput(userInput)
|
| 58 |
+
#following CoT in select column then get query to save tokens
|
| 59 |
+
tryParsing = True
|
| 60 |
+
parsedSql = False
|
| 61 |
+
if tryParsing:
|
| 62 |
+
try:
|
| 63 |
+
txt = gptResponse.split("```json")[-1].split('```')[0].replace('\n', '')
|
| 64 |
+
sqlResult = json.loads(txt)['finalResult']
|
| 65 |
+
parsedSql = True
|
| 66 |
+
tryParsing = False
|
| 67 |
+
except:
|
| 68 |
+
print("Couldn't parse desired result from gpt response using method 1.")
|
| 69 |
+
if tryParsing:
|
| 70 |
+
try:
|
| 71 |
+
sqlResult = json.loads(gptResponse)['finalResult']
|
| 72 |
+
parsedSql = True
|
| 73 |
+
tryParsing = False
|
| 74 |
+
except:
|
| 75 |
+
print("Couldn't parse desired result from gpt response using method 2")
|
| 76 |
+
if parsedSql:
|
| 77 |
+
isFormatted = False
|
| 78 |
+
try:
|
| 79 |
+
formattedSql = sqlparse.format(sqlResult, reindent=True)
|
| 80 |
+
responseToReturn = formattedSql
|
| 81 |
+
isFormatted = True
|
| 82 |
+
except:
|
| 83 |
+
isFormatted = False
|
| 84 |
+
if not isFormatted:
|
| 85 |
+
try:
|
| 86 |
+
formattedSql = sqlparse.format(sqlResult['result'], reindent=True)
|
| 87 |
+
responseToReturn = formattedSql
|
| 88 |
+
print("gpt didn't give parsed result. So parsing again. the formatting.")
|
| 89 |
+
except:
|
| 90 |
+
responseToReturn = str(sqlResult)
|
| 91 |
+
else:
|
| 92 |
+
responseToReturn = gptResponse
|
| 93 |
+
return responseToReturn
|
| 94 |
|
| 95 |
def getSystemPromptForTableCols(self):
|
| 96 |
schemaName = self.schemaName
|
|
|
|
| 113 |
schemaName = self.schemaName
|
| 114 |
platform = self.platform
|
| 115 |
tableSummaryDict = json.load(open(self.tableSummaryJson, 'r'))
|
| 116 |
+
egUserInput = "I want to get top 5 product categories by state, then rank categories on decreasing order of total sales"
|
| 117 |
+
|
| 118 |
+
cotSubtaskOutput = """{
|
| 119 |
+
"subquery1": {
|
| 120 |
+
"inputSubquery": [],
|
| 121 |
+
"descriptioin":"calculate the total sales and assigns ranks to product categories within each state based on the descending order of sales in the tbl_f_sales table, utilizing joins with tbl_d_product and tbl_d_customer tables.",
|
| 122 |
+
"result": "SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
|
| 123 |
+
RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
|
| 124 |
+
FROM lpdatamart.tbl_f_sales a
|
| 125 |
+
JOIN lpdatamart.tbl_d_product b
|
| 126 |
+
ON a.product_id = b.product_id
|
| 127 |
+
JOIN lpdatamart.tbl_d_customer c
|
| 128 |
+
ON a.customer_id = c.customer_id
|
| 129 |
+
GROUP BY c.state, b.category "
|
| 130 |
+
},
|
| 131 |
+
"subquery2": {
|
| 132 |
+
"inputSubquery": ["subquery1"],
|
| 133 |
+
"description":"extracts state, category, and total sales information from a subquery named "subquery1," filtering the results to include only categories with ranks up to 5 and sorting them by state and category rank."
|
| 134 |
+
"result":"SELECT state, category, total_sales
|
| 135 |
+
FROM ranked_categories
|
| 136 |
+
WHERE category_rank <= 5
|
| 137 |
+
ORDER BY state, category_rank"
|
| 138 |
+
},
|
| 139 |
+
"finalResult":"WITH subquery1 AS (
|
| 140 |
+
SELECT c.state, b.category, SUM(a.transaction_amount) as total_sales,
|
| 141 |
+
RANK() OVER(PARTITION BY c.state ORDER BY SUM(a.transaction_amount) DESC) as category_rank
|
| 142 |
+
FROM lpdatamart.tbl_f_sales a
|
| 143 |
+
JOIN lpdatamart.tbl_d_product b
|
| 144 |
+
ON a.product_id = b.product_id
|
| 145 |
+
JOIN lpdatamart.tbl_d_customer c
|
| 146 |
+
ON a.customer_id = c.customer_id
|
| 147 |
+
GROUP BY c.state, b.category
|
| 148 |
+
)
|
| 149 |
+
SELECT state, category, total_sales
|
| 150 |
+
FROM subquery1
|
| 151 |
+
WHERE category_rank <= 5
|
| 152 |
+
ORDER BY state, category_rank"
|
| 153 |
+
}"""
|
| 154 |
+
prompt = f"""You are a powerful text to sql model. Your task is to return sql query which answers
|
| 155 |
+
user's input. Please follow subquery structure if the sql needs to have multiple subqueries.
|
| 156 |
+
###example userInput is {egUserInput}. output is {cotSubtaskOutput}. Output should be in json format as provided. Only output should be in response, nothing else.\n\n
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
for idx, tableName in enumerate(prospectTablesAndCols.keys(), start=1):
|
| 160 |
+
prompt += f"table name is {tableName}, table data is {self.sampleData[tableName][prospectTablesAndCols[tableName]].head(self.gptSampleRows)}"
|
| 161 |
+
prompt += f"and table Relations are {TABLE_RELATIONS}"
|
| 162 |
+
return prompt.replace("\\"," ").replace(" "," ").replace("XXXX", " ")
|
| 163 |
|