# QueryHelper / utils.py
# Provenance (from the hosting page, preserved as comments):
#   author: anumaurya114exp — commit message: "Update utils.py" — rev ce776f8 (verified)
import psycopg2
import re
import pandas as pd
from persistStorage import retrieveTablesDataFromLocalDb, saveTablesDataToLocalDB
from config import SCHEMA_INFO_FILE_PATH
import os
import pickle
import sqlparse
import json
class DataWrapper:
    """Attribute-style view over a dict, or over a list of key names
    (each initialised to None)."""

    def __init__(self, data):
        if isinstance(data, list):
            # A list of key names becomes attributes, all set to None.
            self.__dict__.update(dict.fromkeys(data))
        elif isinstance(data, dict):
            self.__dict__.update(data)
        # Any other input type is silently ignored, as before.

    def addKey(self, key, val=None):
        """Attach one more attribute after construction."""
        self.__dict__[key] = val

    def __repr__(self):
        return repr(self.__dict__)
class MetaDataLayout:
    """Tracks which tables/columns of a schema are currently selected.

    The layout dict holds three keys: 'schema' (schema name),
    'allTables' ({table: [columns]} for the whole schema) and
    'selectedTables' (the subset chosen via setSelection).
    """

    def __init__(self, schemaName, allTablesAndCols):
        self.schemaName = schemaName
        self.datalayout = {
            "schema": self.schemaName,
            "selectedTables": {},
            "allTables": allTablesAndCols,
        }

    def getDataLayout(self):
        """Return the full layout dict (a live reference, not a copy)."""
        return self.datalayout

    def setSelection(self, tablesAndCols):
        """Mark tables and columns as selected.

        tablesAndCols maps table name -> list of column names, e.g.
        {"table1": ["col1", "col2"], "table2": ["cola", "colb"]}.
        Tables not present in 'allTables' are reported and skipped.
        """
        for tableName, columns in tablesAndCols.items():
            if tableName in self.datalayout['allTables']:
                self.datalayout['selectedTables'][tableName] = columns
            else:
                print(f"Table {tableName} doesn't exists in the schema")

    def resetSelection(self):
        """Drop every current selection."""
        self.datalayout['selectedTables'] = {}

    def getSelectedTablesAndCols(self):
        """Return the {table: [columns]} mapping chosen via setSelection."""
        return self.datalayout['selectedTables']

    def getAllTablesCols(self):
        """Return the {table: [columns]} mapping for the whole schema."""
        return self.datalayout['allTables']
class DbEngine:
    """Thin wrapper around a psycopg2 connection with lazy (re)connect."""

    def __init__(self, dbCreds):
        # dbCreds is expected to expose database/user/password/host/port
        # attributes (e.g. a DataWrapper) -- TODO confirm with callers.
        self.dbCreds = dbCreds
        self._connection = None

    def connect(self):
        """Open a new connection if none exists or the previous one closed."""
        dbCreds = self.dbCreds
        # TCP keepalives so long-idle connections are not silently dropped
        # by intermediate network gear.
        keepaliveKwargs = {
            "keepalives": 1,
            "keepalives_idle": 100,
            "keepalives_interval": 5,
            "keepalives_count": 5,
        }
        # psycopg2 sets .closed to a non-zero value once a connection dies.
        if self._connection is None or self._connection.closed != 0:
            self._connection = psycopg2.connect(
                database=dbCreds.database, user=dbCreds.user,
                password=dbCreds.password, host=dbCreds.host,
                port=dbCreds.port, **keepaliveKwargs)

    def getConnection(self):
        """Return a live connection, reconnecting first if necessary."""
        if self._connection is None or self._connection.closed != 0:
            self.connect()
        return self._connection

    def disconnect(self):
        """Close the connection if it is currently open; otherwise no-op."""
        if self._connection is not None and self._connection.closed == 0:
            self._connection.close()

    def execute_query(self, query):
        """Run *query* and return all rows via cursor.fetchall().

        Bug fix: the original wrapped any error in `raise Exception(e)`,
        which discarded the driver's exception type and the traceback.
        Errors now propagate unchanged (they are still Exception
        subclasses, so existing `except Exception` callers keep working).
        """
        self.connect()
        with self._connection.cursor() as cursor:
            cursor.execute(query)
            return cursor.fetchall()
def executeQuery(dbEngine, query):
    """Delegate to dbEngine.execute_query and return the fetched rows."""
    return dbEngine.execute_query(query)
def executeColumnsQuery(dbEngine, columnQuery):
    """Run *columnQuery* and return only the result-set column names."""
    dbEngine.connect()
    with dbEngine._connection.cursor() as cursor:
        cursor.execute(columnQuery)
        # cursor.description holds one 7-tuple per column; index 0 is the name.
        return [columnDesc[0] for columnDesc in cursor.description]
def closeDbEngine(dbEngine):
    """Close the engine's underlying DB connection (no-op if not open)."""
    dbEngine.disconnect()
def getAllTablesInfo(dbEngine, schemaName, useCache=True):
    """Return {table_name: [column, ...]} for every table in *schemaName*.

    When useCache is True and a pickle cache exists at
    SCHEMA_INFO_FILE_PATH, the cached mapping is returned without touching
    the database; otherwise the mapping is queried and the cache rewritten.
    """
    if useCache and os.path.isfile(SCHEMA_INFO_FILE_PATH):
        with open(SCHEMA_INFO_FILE_PATH, 'rb') as fh:
            return pickle.load(fh)
    print("Getting tables Info, list of tables and columns...")
    allTablesQuery = f"""SELECT table_name FROM information_schema.tables
                        WHERE table_schema = '{schemaName}'"""
    tablesAndCols = {}
    for (tableName,) in executeQuery(dbEngine, allTablesQuery):
        # LIMIT 0 fetches no rows; only the cursor description is wanted.
        columnsQuery = f"""Select * FROM {schemaName}.{tableName} LIMIT 0"""
        tablesAndCols[tableName] = executeColumnsQuery(dbEngine, columnsQuery)
    with open(SCHEMA_INFO_FILE_PATH, 'wb') as fh:
        pickle.dump(tablesAndCols, fh)
    return tablesAndCols
def getSampleDataForTablesAndCols(dbEngine, schemaName, tablesAndCols, maxRows):
    """Fetch up to *maxRows* sample rows per table as pandas DataFrames.

    Tries the local cache first; on a miss, reads each table from the
    database, saves the result to the local cache, and returns a
    {table: DataFrame} mapping. A table that fails to read yields an
    empty DataFrame (best-effort behaviour, preserved from the original).
    """
    data = retrieveTablesDataFromLocalDb(list(tablesAndCols.keys()))
    if data != {}:
        return data
    dbEngine.connect()
    conn = dbEngine.getConnection()
    print("Didn't find any cache/valid cache.")
    print("Getting data from aws redshift")
    for table in tablesAndCols.keys():
        try:
            sqlQuery = f"""select * from {schemaName}.{table} limit {maxRows}"""
            data[table] = pd.read_sql_query(sqlQuery, con=conn)
        except Exception:
            # Bug fix: a bare `except:` also swallowed SystemExit and
            # KeyboardInterrupt. Catch Exception only; keep the
            # best-effort fallback to an empty frame.
            print(f"couldn't read table data. Table: {table}")
            data[table] = pd.DataFrame({})
    saveTablesDataToLocalDB(data)
    return data
# Function to test the generated sql query
def isDataQuery(sql_query):
    """Return True when *sql_query* looks like a read-only query.

    Returns False when the query contains a keyword that modifies data or
    schema. Note this is a conservative keyword screen, not a SQL parser:
    a keyword inside a string literal or comment still triggers False.
    """
    # Bug fix: the original list covered DML only (INSERT/UPDATE/DELETE/
    # MERGE), so destructive DDL such as DROP TABLE or TRUNCATE was
    # classified as a harmless "data query".
    modifying_keywords = ('INSERT', 'UPDATE', 'DELETE', 'MERGE',
                          'DROP', 'TRUNCATE', 'ALTER', 'CREATE')
    pattern = r'\b(?:' + '|'.join(modifying_keywords) + r')\b'
    # \b boundaries prevent false hits inside longer identifiers
    # (e.g. a column named updated_at).
    return re.search(pattern, sql_query.upper()) is None
def extractSqlFromGptResponse(gptReponse):
    """Return the contents of the first ```sql fenced block, or '' if none."""
    fencedSql = re.search(r"```sql\n(.*?)```", gptReponse, re.DOTALL)
    return fencedSql.group(1) if fencedSql else ""
def addSchemaToTableInSQL(sqlQuery, schemaName, tablesList):
    """Prefix each bare table name in *sqlQuery* with 'schemaName.'.

    Matches whitespace-delimited names only: the lookarounds skip names
    already qualified (preceded by '.') or embedded in longer tokens.
    """
    for tableName in tablesList:
        bareTable = re.compile(rf'(?<!\S){re.escape(tableName)}(?!\S)', re.IGNORECASE)
        sqlQuery = bareTable.sub(f'{schemaName}.{tableName}', sqlQuery)
    return sqlQuery
def preProcessGptQueryReponse(gptResponse, metadataLayout: MetaDataLayout):
    """Qualify every known table name in the GPT SQL with the schema name."""
    return addSchemaToTableInSQL(
        gptResponse,
        schemaName=metadataLayout.schemaName,
        tablesList=metadataLayout.getAllTablesCols().keys(),
    )
def remove_with_as(sql_query):
pattern = r'WITH\s+.*?AS\s*\((?:[^()]|\((?:[^()]+|\([^()]*\))*\))*\)\s*,?'
cte_pattern = re.compile(pattern, re.IGNORECASE | re.DOTALL)
sql_query = cte_pattern.sub('', sql_query)
pattern = r'WITH\s+.*?AS\s*\((?:[^()]|\((?:[^()]+|\([^()]*\))*\))*\)\s*'
cte_pattern = re.compile(pattern, re.IGNORECASE | re.DOTALL)
sql_query = cte_pattern.sub('', sql_query)
pattern = r'subquery\d+\s+AS\s*\((?:[^()]|\((?:[^()]+|\([^()]*\))*\))*\)\s*'
cte_pattern = re.compile(pattern, re.IGNORECASE | re.DOTALL)
sql_query = cte_pattern.sub('', sql_query)
return sql_query
def construct_with_stats(final_query, subquery_info):
    """Prepend *final_query* with a WITH clause built from *subquery_info*.

    subquery_info maps subquery name -> {'result': sql, 'description': text};
    each description is flattened to one line and attached as a -- comment.
    Returns final_query unchanged when there are no subqueries.
    """
    clauses = []
    for name, info in subquery_info.items():
        flat_description = info['description'].replace('\n', ' ')
        clauses.append(f"{name} AS (\n{info['result']}\n) \n--{flat_description}\n")
    if not clauses:
        return final_query
    joined = "\n,\n".join(clauses)
    return f"WITH {joined}\n{final_query}"
def get_keys_matching_pattern(dictionary, pattern):
    """Return the keys of *dictionary* whose start matches regex *pattern*."""
    matcher = re.compile(pattern)
    return [key for key in dictionary if matcher.match(key)]
def get_subquery_info(json_response):
    """Pick out the 'subqueryN' entries from a parsed GPT JSON response."""
    return {
        key: json_response[key]
        for key in get_keys_matching_pattern(json_response, r'subquery\d+')
    }
def add_single_query_description(query, json_response):
    """Append the GPT-provided description of the whole query as a SQL
    -- comment, when json_response carries a 'query' entry.

    The description is flattened to a single line first.
    """
    queryInfo = json_response.get("query")
    # Bug fix (idiom): the original used `!= None`, an equality test; the
    # correct identity check is `is not None` and cannot be fooled by a
    # value with a custom __eq__.
    if queryInfo is not None:
        description = queryInfo['description'].replace('\n', ' ')
        query = f"{query}\n --{description}\n"
    return query
def construct_final_query(query, json_response):
    """Rebuild the complete SQL: strip stale CTEs from *query*, re-attach a
    WITH clause from the subqueries in *json_response*, add the overall
    description comment, collapse blank lines, and pretty-print."""
    stripped_query = remove_with_as(query)
    rebuilt = construct_with_stats(stripped_query, get_subquery_info(json_response))
    rebuilt = add_single_query_description(rebuilt, json_response)
    rebuilt = rebuilt.replace('\n\n', '\n')
    return sqlparse.format(rebuilt, reindent=True)
def getQueryFromGptResponse(gptResponse):
    """Pull a formatted SQL string out of a raw GPT response.

    Three parsing strategies are attempted in order:
      1. last ```json fenced block, key 'finalResult'
      2. the whole response with fences stripped, key 'finalResult'
      3. last ```json fenced block, 'result' field of the first key

    Returns (responseToReturn, jsonResponse):
      - responseToReturn: the formatted SQL on success, otherwise the raw
        gptResponse unchanged
      - jsonResponse: the parsed JSON dict, or {} when nothing parsed

    Bug fix: every handler was a bare `except:`, which also trapped
    SystemExit and KeyboardInterrupt; all handlers now catch Exception only.
    """
    tryParsing = True
    parsedSql = False
    if tryParsing:
        try:
            # Method 1: last ```json fenced block, expect 'finalResult'.
            txt = gptResponse.split("```json")[-1].split('```')[0].replace('\n', ' ')
            jsonResponse = json.loads(txt)
            sqlResult = jsonResponse['finalResult']
            parsedSql = True
            tryParsing = False
            print("parsed desired result from gpt response using method 1.")
        except Exception:
            print("Couldn't parse desired result from gpt response using method 1.")
    if tryParsing:
        try:
            # Method 2: strip the fences and parse the whole text.
            jsonResponse = json.loads(gptResponse.replace("```json", "").replace("```", "").replace('\n', ' '))
            sqlResult = jsonResponse['finalResult']
            parsedSql = True
            tryParsing = False
            print("parsed desired result from gpt response using method 2.")
        except Exception:
            print("Couldn't parse desired result from gpt response using method 2")
    if tryParsing:
        try:
            # Method 3: fenced block again, but fall back to the 'result'
            # field of the first top-level key.
            txt = gptResponse.split("```json")[-1].split('```')[0].replace('\n', ' ')
            jsonResponse = json.loads(txt)
            sqlResult = jsonResponse[list(jsonResponse.keys())[0]]['result']
            parsedSql = True
            tryParsing = False
            print("parsed desired result from gpt response using method 3.")
        except Exception:
            print("Couldn't parse desired result from gpt response using method 3.")
    if parsedSql:
        isFormatted = False
        try:
            formattedSql = sqlparse.format(sqlResult, reindent=True)
            responseToReturn = formattedSql
            isFormatted = True
        except Exception:
            isFormatted = False
        if not isFormatted:
            try:
                # sqlResult may itself be a dict carrying a 'result' entry.
                formattedSql = sqlparse.format(sqlResult['result'], reindent=True)
                responseToReturn = formattedSql
                print("gpt didn't give parsed result. So parsing again. the formatting.")
            except Exception:
                responseToReturn = str(sqlResult)
    else:
        responseToReturn = gptResponse
        jsonResponse = {}
    return responseToReturn, jsonResponse