Spaces:
Runtime error
Runtime error
File size: 8,395 Bytes
775d0c7 9b37f9e 775d0c7 2a4b462 ff7bbd5 2a4b462 ff7bbd5 1eca8aa e376fbb 2a4b462 f789cd9 2a4b462 8e7d4e9 2a4b462 8e7d4e9 2a4b462 8e7d4e9 2a4b462 8e7d4e9 2a4b462 ff7bbd5 2a4b462 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 | from gptManager import ChatgptManager
from utils import *
import re
import json
class QueryHelperChainOfThought:
    """Turn natural-language questions into SQL with the help of a chat model.

    The helper keeps the currently selected tables/columns (``metadataLayout``)
    and a cache of sample rows for them, and builds the system prompts that are
    sent to ``gptInstance``.  ``getQueryForUserInput`` produces a query in one
    shot; ``getQueryForUserInputCoT`` first asks the model to validate and split
    the request into subtasks, generates a query per subtask and then asks the
    model to combine them.
    """

    # System prompt for the first chain-of-thought step (is the input complete?).
    _INPUT_CLARIFICATION_PROMPT = (
        "Given an input text, user want to generate sql query. Please answer if the user input is "
        "complete or user needs to ask in more detailed way. Answer in following format. 'Yes' ; if "
        "yes, break the userinput into smaller subtask for query generation. Formatted into\n"
        "{\n"
        '"Task 1": "task 1 description",\n'
        '"Task 2": "task 2 description"\n'
        "}\n"
        "'No' ; if no, then Reason- please be more detailed about customer details; if more "
        "modification needed."
    )

    # System prompt for the final chain-of-thought step (merge subtask queries).
    _COMBINE_SUBTASKS_PROMPT = (
        "Combine following subtask and their queries to generate sql query to answer the user input.\n "
    )

    def __init__(self, gptInstance: "ChatgptManager", dbEngine, schemaName,
                 platform, metadataLayout: "MetaDataLayout", sampleDataRows,
                 gptSampleRows, getSampleDataForTablesAndCols):
        """Store the collaborators and load sample data for the selected tables.

        Args:
            gptInstance: Chat client; only ``getResponseForUserInput`` is used.
            dbEngine: Database handle, passed through to
                ``getSampleDataForTablesAndCols``.
            schemaName: Schema name interpolated into the prompts.
            platform: SQL platform/dialect name interpolated into the prompts.
            metadataLayout: Object exposing ``getSelectedTablesAndCols()``.
            sampleDataRows: ``maxRows`` passed to the sample-data loader.
            gptSampleRows: Number of sample rows per table shown to the model.
            getSampleDataForTablesAndCols: Callable accepting the keyword
                arguments ``dbEngine``, ``schemaName``, ``tablesAndCols`` and
                ``maxRows``; its result is cached in ``self.sampleData``.
        """
        self.gptInstance = gptInstance
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Reload ``self.sampleData`` for the currently selected tables/columns."""
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine,
            schemaName=self.schemaName,
            tablesAndCols=self.metadataLayout.getSelectedTablesAndCols(),
            maxRows=self.sampleDataRows,
        )

    def getMetadata(self) -> "MetaDataLayout":
        """Return the metadata layout currently in use."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and refresh the cached sample data."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def modifySqlQueryEnteredByUser(self, userSqlQuery):
        """Ask the model to correct a user-written SQL query for ``self.platform``.

        Returns:
            The model's response, unprocessed.
        """
        userPrompt = (
            f"Please correct the following sql query, also it has to be run on {self.platform}. "
            f"sql query is \n {userSqlQuery}."
        )
        return self.gptInstance.getResponseForUserInput(userPrompt, "")

    def filteredSampleDataForProspects(self, prospectTablesAndCols):
        """Restrict the cached sample data to the given tables and columns.

        Args:
            prospectTablesAndCols: Mapping of table name to a list of columns.

        Returns:
            Dict mapping each table to ``self.sampleData[table][columns]``.
            Presumably the cached values are pandas DataFrames (they are later
            used with ``.head``) - confirm against the sample-data loader.

        Raises:
            KeyError: If a table (or column) is missing from the cached data.
        """
        return {
            table: self.sampleData[table][columns]
            for table, columns in prospectTablesAndCols.items()
        }

    def extractSingleJson(self, text):
        """Return the first JSON object embedded in ``text``.

        Every ``{`` is tried as the start of an object and the first one that
        decodes wins, so nested objects are handled and stray braces before the
        real payload are skipped.

        Raises:
            json.JSONDecodeError: If ``{`` occurs but no candidate decodes.
            IndexError: If ``text`` contains no ``{`` at all.
        """
        decoder = json.JSONDecoder()
        lastError = None
        start = text.find("{")
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(text[start:])
            except json.JSONDecodeError as err:
                lastError = err
            else:
                return parsed
            start = text.find("{", start + 1)
        if lastError is not None:
            raise lastError
        raise IndexError("no JSON object found in text")

    def getQueryForUserInputCoT(self, userInput):
        """Generate a SQL query in chain-of-thought fashion.

        Step 1 asks the model whether ``userInput`` is detailed enough and, if
        so, to split it into subtasks.  Each subtask then gets its own query,
        and a final call asks the model to combine them.

        Returns:
            ``(combinedQueryResponse, prospectTablesAndColsPerTask)`` when the
            model accepts the input, otherwise ``(rephraseMessage, None)``.
        """
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        systemPrompt = (
            self._INPUT_CLARIFICATION_PROMPT
            + '\n'
            + self.getSystemPromptTailForCoTStep1(selectedTablesAndCols)
        )
        cotStep1 = self.gptInstance.getResponseForUserInput(userInput, systemPrompt)

        # Only the start of the reply is inspected; five characters leave room
        # for a quoted 'Yes'.  Leading whitespace is ignored.
        if "yes" not in cotStep1.lstrip().lower()[:5]:
            return f"Please rephrase your query. {' '.join(cotStep1.split('Reason')[1:])}", None

        print("User input sufficient")
        try:
            tasks = self.extractSingleJson(cotStep1)
        except (IndexError, ValueError):
            # The model approved the input but returned no usable task list:
            # treat the whole request as a single task instead of crashing.
            tasks = {"Task 1": userInput}
        print(f"tasks are {tasks}")

        taskQueries = {}
        prospectTablesAndColsAll = []
        for key, task in tasks.items():
            # Generate the query for this subtask, not for the full user input.
            taskQuery, prospectTablesAndCols = self.getQueryForUserInput(task)
            taskQueries[key] = {"task": task, "taskQuery": taskQuery}
            prospectTablesAndColsAll.append(prospectTablesAndCols)
        print(f"tasks and their queries {taskQueries}")

        userPrompt = f"user input is {userInput}" + "".join(
            f" task: {entry['task']}, task query: {entry['taskQuery']}"
            for entry in taskQueries.values()
        )
        combinedQuery = self.gptInstance.getResponseForUserInput(
            userPrompt, self._COMBINE_SUBTASKS_PROMPT
        )
        return combinedQuery, prospectTablesAndColsAll

    def getQueryForUserInput(self, userInput, chatHistory=None):
        """Generate a SQL query for ``userInput`` in a single pass.

        Args:
            userInput: Natural-language request.
            chatHistory: Optional prior conversation forwarded to the model.
                ``None`` means a fresh empty history for this call.

        Returns:
            ``(query, prospectTablesAndCols)`` where ``query`` is the model
            response after ``preProcessGptQueryReponse``.
        """
        history = [] if chatHistory is None else chatHistory
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        prospectTablesAndCols = self.getProspectiveTablesAndCols(
            userInput, selectedTablesAndCols, history
        )
        print("getting prospects", prospectTablesAndCols)
        prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
        systemPrompt = self.getSystemPromptForQueryGeneration(
            prospectTablesData, gptSampleRows=self.gptSampleRows
        )
        queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPrompt, history)
        queryByGpt = preProcessGptQueryReponse(queryByGpt, metadataLayout=self.metadataLayout)
        return queryByGpt, prospectTablesAndCols

    def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory=None):
        """Ask the model which of the selected tables/columns are relevant.

        The reply is not parsed as JSON; a table or column counts as selected
        when its name appears in the reply as a whole token, so a column such
        as ``id`` is not picked up merely because ``customer_id`` is mentioned.

        Args:
            userInput: Natural-language request.
            selectedTablesAndCols: Mapping of table name to its columns.
            chatHistory: Optional prior conversation forwarded to the model.

        Returns:
            Mapping of each mentioned table to its mentioned columns, in the
            order they appear in ``selectedTablesAndCols``.
        """
        history = [] if chatHistory is None else chatHistory
        systemPrompt = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
        replyText = self.gptInstance.getResponseForUserInput(userInput, systemPrompt, history)
        return {
            table: [column for column in columns if self._mentionsName(column, replyText)]
            for table, columns in selectedTablesAndCols.items()
            if self._mentionsName(table, replyText)
        }

    def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
        """Build the system prompt used to generate a SQL query.

        The prompt contains the schema name, the platform, one worked example
        and, per table, the first ``gptSampleRows`` rows of its sample data.
        """
        schemaName = self.schemaName
        platform = self.platform
        # NOTE(review): the example question asks for the "top 5" but the
        # example query has no row limit - confirm whether that is intended.
        exampleQuery = (
            "SELECT a.customer_id, COUNT(a.product_id) as chandelier_count\n"
            "FROM lpdatamart.tbl_f_sales a\n"
            "JOIN lpdatamart.tbl_d_product b ON a.product_id = b.product_id\n"
            "JOIN lpdatamart.tbl_d_calendar c ON a.date_id = c.date_id\n"
            "WHERE UPPER(b.product_name) LIKE '%CHANDELIER%' AND c.calendar_month = 'NOVEMBER' AND c.year = 2023\n"
            "GROUP BY a.customer_id\n"
            "ORDER BY chandelier_count DESC"
        )
        question = "top 5 customers who bought most chandeliers in nov 2023"
        header = (
            f"Given an input text, generate the corresponding SQL query for given details. "
            f"Schema Name is {schemaName}. And sql platform is {platform}.\n following is sample data. Also\n"
            f"There is example user input and desired generated sql query. Follow similar patterns as example. "
            f"eg case insensitive, explicit variable declaration etc. user input : {question}, query : {exampleQuery} "
        )
        tables = "".join(
            f"table name is {tableName}, table data is {tableData.head(gptSampleRows)}XXXX"
            for tableName, tableData in prospectTablesData.items()
        )
        return self._flattenPrompt(header + tables)

    def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
        """Build the system prompt asking which tables/columns hold the data."""
        header = (
            f"Given an input text, User wants to know which all tables and columns would be possibily "
            f"to have the desired data. Output them as json. Schema Name is {self.schemaName}. "
            f"And sql platform is {self.platform}.\n"
        )
        return self._flattenPrompt(header + self._describeTablesAndCols(selectedTablesAndCols))

    def getSystemPromptTailForCoTStep1(self, selectedTablesAndCols):
        """Build the schema/table description appended to the step-1 CoT prompt."""
        header = (
            f"schema name is {self.schemaName}. And sql platform is {self.platform}. "
            f"and table info are below.\n"
        )
        return self._flattenPrompt(header + self._describeTablesAndCols(selectedTablesAndCols))

    @staticmethod
    def _describeTablesAndCols(selectedTablesAndCols):
        """Render each table and its columns, terminated by the XXXX separator."""
        return "".join(
            f"table name {tableName} {', '.join(columns)}XXXX"
            for tableName, columns in selectedTablesAndCols.items()
        )

    @staticmethod
    def _flattenPrompt(prompt):
        """Put the prompt on one line: newlines, backslashes and XXXX become spaces.

        NOTE(review): the previous code also chained ``.replace(" ", " ")``,
        which is a no-op as written; it may have been meant to collapse double
        spaces - confirm before reinstating.
        """
        return prompt.replace("\n", " ").replace("\\", " ").replace("XXXX", " ")

    @staticmethod
    def _mentionsName(name, text):
        """Return True when ``name`` occurs in ``text`` as a whole token."""
        return re.search(rf"(?<!\w){re.escape(name)}(?!\w)", text) is not None