Spaces:
Runtime error
Runtime error
update test error handling
Browse files- queryHelperManagerCoT.py +15 -54
queryHelperManagerCoT.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
-
from gptManager import ChatgptManager
|
| 2 |
-
from utils import *
|
| 3 |
import json
|
| 4 |
import sqlparse
|
| 5 |
-
from constants import TABLE_RELATIONS
|
| 6 |
|
| 7 |
class QueryHelperChainOfThought:
|
| 8 |
def __init__(self, gptInstanceForCoT: ChatgptManager,
|
|
@@ -42,53 +42,13 @@ class QueryHelperChainOfThought:
|
|
| 42 |
self.gptInstanceForCoT.setSystemPrompt(prompt)
|
| 43 |
gptResponse = self.gptInstanceForCoT.getResponseForChatHistory(verboseChatHistory, userInput)
|
| 44 |
verboseResponse = gptResponse
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
if
|
| 48 |
-
|
| 49 |
-
txt = gptResponse.split("```json")[-1].split('```')[0].replace('\n', ' ')
|
| 50 |
-
sqlResult = json.loads(txt)['finalResult']
|
| 51 |
-
parsedSql = True
|
| 52 |
-
tryParsing = False
|
| 53 |
-
print("parsed desired result from gpt response using method 1.")
|
| 54 |
-
except:
|
| 55 |
-
print("Couldn't parse desired result from gpt response using method 1.")
|
| 56 |
-
if tryParsing:
|
| 57 |
-
try:
|
| 58 |
-
sqlResult = json.loads(gptResponse.replace("```json","").replace("```","").replace('\n', ' '))['finalResult']
|
| 59 |
-
parsedSql = True
|
| 60 |
-
tryParsing = False
|
| 61 |
-
print("parsed desired result from gpt response using method 2.")
|
| 62 |
-
except:
|
| 63 |
-
print("Couldn't parse desired result from gpt response using method 2")
|
| 64 |
-
if tryParsing:
|
| 65 |
-
try:
|
| 66 |
-
txt = gptResponse.split("```json")[-1].split('```')[0].replace('\n', ' ')
|
| 67 |
-
jsonResponse = json.loads(txt)
|
| 68 |
-
sqlResult = jsonResponse[list(jsonResponse.keys())[0]]['result']
|
| 69 |
-
parsedSql = True
|
| 70 |
-
tryParsing = False
|
| 71 |
-
print("parsed desired result from gpt response using method 3.")
|
| 72 |
-
except:
|
| 73 |
-
print("Couldn't parse desired result from gpt response using method 3.")
|
| 74 |
-
if parsedSql:
|
| 75 |
-
isFormatted = False
|
| 76 |
-
try:
|
| 77 |
-
formattedSql = sqlparse.format(sqlResult, reindent=True)
|
| 78 |
-
responseToReturn = formattedSql
|
| 79 |
-
isFormatted = True
|
| 80 |
-
except:
|
| 81 |
-
isFormatted = False
|
| 82 |
-
if not isFormatted:
|
| 83 |
-
try:
|
| 84 |
-
formattedSql = sqlparse.format(sqlResult['result'], reindent=True)
|
| 85 |
-
responseToReturn = formattedSql
|
| 86 |
-
print("gpt didn't give parsed result. So parsing again. the formatting.")
|
| 87 |
-
except:
|
| 88 |
-
responseToReturn = str(sqlResult)
|
| 89 |
else:
|
| 90 |
-
|
| 91 |
-
return
|
| 92 |
|
| 93 |
|
| 94 |
def getQueryForUserInputCoT(self, userInput):
|
|
@@ -144,7 +104,7 @@ class QueryHelperChainOfThought:
|
|
| 144 |
def getPromptForCot(self):
|
| 145 |
schemaName = self.schemaName
|
| 146 |
platform = self.platform
|
| 147 |
-
tableSummaryDict = json.load(open(
|
| 148 |
selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
|
| 149 |
|
| 150 |
egUserInput = "I want to get top 5 product categories by state, then rank categories on decreasing order of total sales"
|
|
@@ -237,6 +197,8 @@ ORDER BY state, category_rank"
|
|
| 237 |
|
| 238 |
prompt = f"""You are a powerful text to sql model. Your task is to return sql query which answers
|
| 239 |
user's input. Please follow subquery structure if the sql needs to have multiple subqueries. Your response should be in JSON format.
|
|
|
|
|
|
|
| 240 |
###example userInput is {egUserInput}. output is {cotSubtaskOutput}. Output should be in json format as provided. Only output should be in response, nothing else.\n\n
|
| 241 |
tables information are {promptTableInfo}.
|
| 242 |
columns data are {promptColumnsInfo}.
|
|
@@ -254,7 +216,7 @@ ORDER BY state, category_rank"
|
|
| 254 |
def getSystemPromptForTableCols(self):
|
| 255 |
schemaName = self.schemaName
|
| 256 |
platform = self.platform
|
| 257 |
-
tableSummaryDict = json.load(open(
|
| 258 |
selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
|
| 259 |
|
| 260 |
promptTableInfo = f"""You are a powerful text to sql model. Answer which tables and columns are needed
|
|
@@ -271,7 +233,7 @@ ORDER BY state, category_rank"
|
|
| 271 |
def getSystemPromptForQuery(self, prospectTablesAndCols):
|
| 272 |
schemaName = self.schemaName
|
| 273 |
platform = self.platform
|
| 274 |
-
tableSummaryDict = json.load(open(
|
| 275 |
exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
|
| 276 |
FROM lpdatamart.tbl_f_sales a
|
| 277 |
JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
|
|
@@ -285,5 +247,4 @@ ORDER BY chandelier_count DESC"""
|
|
| 285 |
for idx, tableName in enumerate(prospectTablesAndCols.keys(), start=1):
|
| 286 |
promptForQuery += f"table name is {tableName}, table data is {self.sampleData[tableName][prospectTablesAndCols[tableName]].head(self.gptSampleRows)} \n "
|
| 287 |
promptForQuery += f"and table Relations are {TABLE_RELATIONS} \n "
|
| 288 |
-
return promptForQuery.replace("\\"," ").replace(" "," ").replace("XXXX", " ")
|
| 289 |
-
|
|
|
|
| 1 |
+
from .gptManager import ChatgptManager
|
| 2 |
+
from .utils import *
|
| 3 |
import json
|
| 4 |
import sqlparse
|
| 5 |
+
from .constants import TABLE_RELATIONS
|
| 6 |
|
| 7 |
class QueryHelperChainOfThought:
|
| 8 |
def __init__(self, gptInstanceForCoT: ChatgptManager,
|
|
|
|
| 42 |
self.gptInstanceForCoT.setSystemPrompt(prompt)
|
| 43 |
gptResponse = self.gptInstanceForCoT.getResponseForChatHistory(verboseChatHistory, userInput)
|
| 44 |
verboseResponse = gptResponse
|
| 45 |
+
query, jsonResponse = getQueryFromGptResponse(gptResponse=gptResponse)
|
| 46 |
+
|
| 47 |
+
if query!=gptResponse:
|
| 48 |
+
finalQuery = construct_final_query(query, jsonResponse)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
else:
|
| 50 |
+
finalQuery = query
|
| 51 |
+
return finalQuery, verboseResponse
|
| 52 |
|
| 53 |
|
| 54 |
def getQueryForUserInputCoT(self, userInput):
|
|
|
|
| 104 |
def getPromptForCot(self):
|
| 105 |
schemaName = self.schemaName
|
| 106 |
platform = self.platform
|
| 107 |
+
tableSummaryDict = json.load(open(r"./core/queryHelper/tableSummaryDict.json", 'r'))
|
| 108 |
selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
|
| 109 |
|
| 110 |
egUserInput = "I want to get top 5 product categories by state, then rank categories on decreasing order of total sales"
|
|
|
|
| 197 |
|
| 198 |
prompt = f"""You are a powerful text to sql model. Your task is to return sql query which answers
|
| 199 |
user's input. Please follow subquery structure if the sql needs to have multiple subqueries. Your response should be in JSON format.
|
| 200 |
+
Answer user input with sql query. And the query needs to run on {platform}. and schemaName is {schemaName}.
|
| 201 |
+
And use columns and tables provided, in case, you need additional column information, please ask the user.
|
| 202 |
###example userInput is {egUserInput}. output is {cotSubtaskOutput}. Output should be in json format as provided. Only output should be in response, nothing else.\n\n
|
| 203 |
tables information are {promptTableInfo}.
|
| 204 |
columns data are {promptColumnsInfo}.
|
|
|
|
| 216 |
def getSystemPromptForTableCols(self):
|
| 217 |
schemaName = self.schemaName
|
| 218 |
platform = self.platform
|
| 219 |
+
tableSummaryDict = json.load(open(r"./core/queryHelper/tableSummaryDict.json", 'r'))
|
| 220 |
selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
|
| 221 |
|
| 222 |
promptTableInfo = f"""You are a powerful text to sql model. Answer which tables and columns are needed
|
|
|
|
| 233 |
def getSystemPromptForQuery(self, prospectTablesAndCols):
|
| 234 |
schemaName = self.schemaName
|
| 235 |
platform = self.platform
|
| 236 |
+
tableSummaryDict = json.load(open(r"./core/queryHelper/tableSummaryDict.json",'r'))
|
| 237 |
exampleQuery = """SELECT a.customer_id, COUNT(a.product_id) as chandelier_count
|
| 238 |
FROM lpdatamart.tbl_f_sales a
|
| 239 |
JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id
|
|
|
|
| 247 |
for idx, tableName in enumerate(prospectTablesAndCols.keys(), start=1):
|
| 248 |
promptForQuery += f"table name is {tableName}, table data is {self.sampleData[tableName][prospectTablesAndCols[tableName]].head(self.gptSampleRows)} \n "
|
| 249 |
promptForQuery += f"and table Relations are {TABLE_RELATIONS} \n "
|
| 250 |
+
return promptForQuery.replace("\\"," ").replace(" "," ").replace("XXXX", " ")
|
|
|