Spaces:
Runtime error
Runtime error
| import psycopg2 | |
| import re | |
| import pandas as pd | |
| from persistStorage import retrieveTablesDataFromLocalDb, saveTablesDataToLocalDB | |
| from config import SCHEMA_INFO_FILE_PATH | |
| import os | |
| import pickle | |
| import sqlparse | |
| import json | |
class DataWrapper:
    """Attribute-style access over either a list of key names or a dict.

    A list seeds every named attribute with None; a dict copies its
    key/value pairs onto the instance.
    """

    def __init__(self, data):
        if isinstance(data, list):
            # Each listed name becomes an attribute holding None.
            self.__dict__.update(dict.fromkeys(data))
        elif isinstance(data, dict):
            self.__dict__.update(data)

    def addKey(self, key, val=None):
        """Attach (or overwrite) a single attribute on the wrapper."""
        self.__dict__[key] = val

    def __repr__(self):
        return repr(self.__dict__)
class MetaDataLayout:
    """Tracks which tables/columns of a schema are currently selected.

    The layout dict carries three keys: "schema" (the schema name),
    "selectedTables" (the active selection), and "allTables"
    ({table: [columns]} for the whole schema).
    """

    def __init__(self, schemaName, allTablesAndCols):
        self.schemaName = schemaName
        self.datalayout = {
            "schema": schemaName,
            "selectedTables": {},
            "allTables": allTablesAndCols,
        }

    def getDataLayout(self):
        """Return the full layout dict."""
        return self.datalayout

    def setSelection(self, tablesAndCols):
        """Merge a {"table": ["col1", "col2"]} mapping into the selection.

        Tables not present in the schema are reported and skipped.
        """
        for tableName, cols in tablesAndCols.items():
            if tableName in self.datalayout['allTables']:
                self.datalayout['selectedTables'][tableName] = cols
            else:
                print(f"Table {tableName} doesn't exists in the schema")

    def resetSelection(self):
        """Drop every selected table."""
        self.datalayout['selectedTables'] = {}

    def getSelectedTablesAndCols(self):
        """Return the current selection mapping."""
        return self.datalayout['selectedTables']

    def getAllTablesCols(self):
        """Return {table: [columns]} for the whole schema."""
        return self.datalayout['allTables']
class DbEngine:
    """Thin wrapper around a psycopg2 connection with lazy (re)connection.

    dbCreds is expected to expose database/user/password/host/port attributes
    (e.g. a DataWrapper instance).
    """

    def __init__(self, dbCreds):
        self.dbCreds = dbCreds
        self._connection = None

    def connect(self):
        """Open a connection if none exists or the previous one was closed."""
        dbCreds = self.dbCreds
        # TCP keepalives so long-idle connections aren't dropped silently.
        keepaliveKwargs = {
            "keepalives": 1,
            "keepalives_idle": 100,
            "keepalives_interval": 5,
            "keepalives_count": 5,
        }
        if self._connection is None or self._connection.closed != 0:
            self._connection = psycopg2.connect(database=dbCreds.database, user=dbCreds.user,
                                                password=dbCreds.password, host=dbCreds.host,
                                                port=dbCreds.port, **keepaliveKwargs)

    def getConnection(self):
        """Return a live connection, reconnecting first if needed."""
        if self._connection is None or self._connection.closed != 0:
            self.connect()
        return self._connection

    def disconnect(self):
        """Close the connection if it is open; safe to call repeatedly."""
        if self._connection is not None and self._connection.closed == 0:
            self._connection.close()

    def execute_query(self, query):
        """Execute *query* and return all rows from fetchall().

        Database errors now propagate with their original type and traceback;
        previously they were re-wrapped as a bare ``Exception``, which hid the
        error class and the failing frame from callers.
        """
        self.connect()
        with self._connection.cursor() as cursor:
            cursor.execute(query)
            return cursor.fetchall()
def executeQuery(dbEngine, query):
    """Run *query* through the engine and return the fetched rows."""
    return dbEngine.execute_query(query)
def executeColumnsQuery(dbEngine, columnQuery):
    """Return the column names produced by *columnQuery*.

    Intended for zero-row probes such as ``SELECT * FROM t LIMIT 0``: the
    cursor's ``description`` is populated even when no rows come back.
    Uses the engine's public getConnection() (which reconnects if needed)
    instead of poking the private _connection attribute directly.
    """
    conn = dbEngine.getConnection()
    with conn.cursor() as cursor:
        cursor.execute(columnQuery)
        return [desc[0] for desc in cursor.description]
def closeDbEngine(dbEngine):
    """Close the engine's underlying database connection."""
    dbEngine.disconnect()
def getAllTablesInfo(dbEngine, schemaName, useCache=True):
    """Return {table_name: [column_names]} for every table in *schemaName*.

    When useCache is True and a pickle cache exists at SCHEMA_INFO_FILE_PATH,
    it is returned instead of hitting the database. A corrupt or truncated
    cache now falls through to a database rebuild instead of crashing. A
    fresh result is always written back to the cache file.
    """
    if useCache and os.path.isfile(SCHEMA_INFO_FILE_PATH):
        try:
            with open(SCHEMA_INFO_FILE_PATH, 'rb') as fh:
                return pickle.load(fh)
        except (pickle.UnpicklingError, EOFError, OSError, AttributeError):
            # Unreadable/stale cache: ignore it and rebuild below.
            print("Schema info cache is unreadable; rebuilding it.")
    tablesAndCols = {}
    print("Getting tables Info, list of tables and columns...")
    # NOTE(review): schemaName is interpolated into SQL; assumed to come from
    # trusted config rather than user input — confirm before exposing this path.
    allTablesQuery = f"""SELECT table_name FROM information_schema.tables
    WHERE table_schema = '{schemaName}'"""
    tables = executeQuery(dbEngine, allTablesQuery)
    for table in tables:
        tableName = table[0]
        # LIMIT 0 returns no rows but still populates cursor.description.
        columnsQuery = f"""Select * FROM {schemaName}.{tableName} LIMIT 0"""
        tablesAndCols[tableName] = executeColumnsQuery(dbEngine, columnsQuery)
    with open(SCHEMA_INFO_FILE_PATH, 'wb') as fh:
        pickle.dump(tablesAndCols, fh)
    return tablesAndCols
def getSampleDataForTablesAndCols(dbEngine, schemaName, tablesAndCols, maxRows):
    """Return {table: DataFrame} with up to *maxRows* sample rows per table.

    Tries the local cache first; on a miss, pulls each table from the
    database and persists the result back to the local cache. A table that
    fails to load yields an empty DataFrame instead of aborting the fetch.
    """
    data = retrieveTablesDataFromLocalDb(list(tablesAndCols.keys()))
    if data != {}:
        return data
    dbEngine.connect()
    conn = dbEngine.getConnection()
    print("Didn't find any cache/valid cache.")
    print("Getting data from aws redshift")
    for table in tablesAndCols.keys():
        try:
            sqlQuery = f"""select * from {schemaName}.{table} limit {maxRows}"""
            data[table] = pd.read_sql_query(sqlQuery, con=conn)
        except Exception as exc:
            # Was a bare except (swallowed KeyboardInterrupt too); keep the
            # best-effort behavior but surface the reason for the failure.
            print(f"couldn't read table data. Table: {table}")
            print(f"reason: {exc}")
            data[table] = pd.DataFrame({})
    saveTablesDataToLocalDB(data)
    return data
# Statements that modify data (DML) or schema/permissions (DDL/DCL).
# Compiled once at import time instead of per call.
_MODIFYING_KEYWORDS = re.compile(
    r'\b(INSERT|UPDATE|DELETE|MERGE|DROP|TRUNCATE|ALTER|CREATE|GRANT|REVOKE)\b'
)


# Function to test the generated sql query
def isDataQuery(sql_query):
    """Return True when *sql_query* appears to be a read-only data query.

    Previously only INSERT/UPDATE/DELETE/MERGE were screened; DDL statements
    (DROP, TRUNCATE, ALTER, CREATE) and permission changes (GRANT, REVOKE)
    slipped through. Word boundaries keep identifiers such as "created_at"
    from false-positives, but a keyword inside a string literal will still
    match — this is a heuristic, not a parser.
    """
    return _MODIFYING_KEYWORDS.search(sql_query.upper()) is None
def extractSqlFromGptResponse(gptReponse):
    """Return the body of the first ```sql fenced block, or "" when absent.

    The trailing newline before the closing fence is included in the result.
    """
    found = re.search(r"```sql\n(.*?)```", gptReponse, re.DOTALL)
    return found.group(1) if found else ""
def addSchemaToTableInSQL(sqlQuery, schemaName, tablesList):
    """Prefix each bare table name in *sqlQuery* with ``schemaName.``.

    A table token is only rewritten when it stands alone (surrounded by
    whitespace or string edges), so names already qualified as
    ``schema.table`` are left untouched.
    """
    for tbl in tablesList:
        bareName = re.compile(rf'(?<!\S){re.escape(tbl)}(?!\S)', re.IGNORECASE)
        sqlQuery = bareName.sub(f'{schemaName}.{tbl}', sqlQuery)
    return sqlQuery
def preProcessGptQueryReponse(gptResponse, metadataLayout: MetaDataLayout):
    """Qualify every known table in the GPT SQL with the layout's schema name."""
    knownTables = metadataLayout.getAllTablesCols().keys()
    return addSchemaToTableInSQL(
        gptResponse,
        schemaName=metadataLayout.schemaName,
        tablesList=knownTables,
    )
def remove_with_as(sql_query):
    """Strip ``WITH ... AS (...)`` CTE prefixes from *sql_query*.

    Runs three regex substitution passes; the paren-matching subpattern
    tolerates roughly two levels of nested parentheses inside the CTE body.
    NOTE(review): deeper nesting than the regex anticipates may leave parts
    of the CTE behind — confirm against real generated queries.
    """
    # Pass 1: WITH <name> AS ( ... ) followed by an optional trailing comma.
    pattern = r'WITH\s+.*?AS\s*\((?:[^()]|\((?:[^()]+|\([^()]*\))*\))*\)\s*,?'
    cte_pattern = re.compile(pattern, re.IGNORECASE | re.DOTALL)
    sql_query = cte_pattern.sub('', sql_query)
    # Pass 2: same shape without requiring the comma, catching any WITH
    # clause the first pass left behind.
    pattern = r'WITH\s+.*?AS\s*\((?:[^()]|\((?:[^()]+|\([^()]*\))*\))*\)\s*'
    cte_pattern = re.compile(pattern, re.IGNORECASE | re.DOTALL)
    sql_query = cte_pattern.sub('', sql_query)
    # Pass 3: orphaned "subqueryN AS ( ... )" fragments (the naming scheme
    # used by get_subquery_info/construct_with_stats elsewhere in this file).
    pattern = r'subquery\d+\s+AS\s*\((?:[^()]|\((?:[^()]+|\([^()]*\))*\))*\)\s*'
    cte_pattern = re.compile(pattern, re.IGNORECASE | re.DOTALL)
    sql_query = cte_pattern.sub('', sql_query)
    return sql_query
def construct_with_stats(final_query, subquery_info):
    """Prepend a WITH clause built from *subquery_info* to *final_query*.

    subquery_info maps CTE names to {"result": <sql>, "description": <text>};
    each description is flattened to one line and emitted as a trailing
    ``--`` comment after its CTE. With no subqueries, final_query is
    returned unchanged.
    """
    ctes = []
    for name, info in subquery_info.items():
        flatDescription = info['description'].replace('\n', ' ')
        ctes.append(f"{name} AS (\n{info['result']}\n) \n--{flatDescription}\n")
    if not ctes:
        return final_query
    joined = "\n,\n".join(ctes)
    return f"WITH {joined}\n{final_query}"
def get_keys_matching_pattern(dictionary, pattern):
    """Return the keys of *dictionary* whose start matches regex *pattern*."""
    matched = []
    for key in dictionary:
        if re.match(pattern, key):
            matched.append(key)
    return matched
def get_subquery_info(json_response):
    """Return only the ``subqueryN`` entries of *json_response* as a dict."""
    subqueryKey = re.compile(r'subquery\d+')
    return {k: v for k, v in json_response.items() if subqueryKey.match(k)}
def add_single_query_description(query, json_response):
    """Append the top-level query description as a trailing SQL comment.

    If json_response has a "query" entry, its "description" is flattened to
    one line and appended as ``-- ...``; otherwise *query* is returned
    unchanged. (Replaces the ``!= None`` comparison and double dict lookup
    with a single ``get`` + ``is not None`` check.)
    """
    queryInfo = json_response.get("query")
    if queryInfo is not None:
        description = queryInfo['description'].replace('\n', ' ')
        query = f"{query}\n --{description}\n"
    return query
def construct_final_query(query, json_response):
    """Rebuild the final SQL: strip old CTEs, re-attach documented ones, format.

    Pipeline: remove any existing WITH clauses, prepend fresh CTEs built
    from the subqueryN entries of *json_response*, append the top-level
    description comment, collapse blank lines, then reindent via sqlparse.
    """
    stripped = remove_with_as(query)
    withCtes = construct_with_stats(stripped, get_subquery_info(json_response))
    annotated = add_single_query_description(withCtes, json_response)
    collapsed = annotated.replace('\n\n', '\n')
    return sqlparse.format(collapsed, reindent=True)
def getQueryFromGptResponse(gptResponse):
    """Extract ``(formattedSql, jsonResponse)`` from a raw GPT reply.

    Three parsing strategies are attempted in order:
      1. take the last ```json fenced block and read ``['finalResult']``;
      2. strip all fence markers from the whole reply and read ``['finalResult']``;
      3. re-read the fenced block and take the first key's ``['result']``.
    If none succeed, the raw reply is returned together with an empty dict.
    Bare ``except:`` clauses were narrowed to ``except Exception`` so
    KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    jsonResponse = {}
    sqlResult = None
    parsedSql = False
    try:
        # Method 1: last ```json fenced block, 'finalResult' key.
        txt = gptResponse.split("```json")[-1].split('```')[0].replace('\n', ' ')
        jsonResponse = json.loads(txt)
        sqlResult = jsonResponse['finalResult']
        parsedSql = True
        print("parsed desired result from gpt response using method 1.")
    except Exception:
        print("Couldn't parse desired result from gpt response using method 1.")
    if not parsedSql:
        try:
            # Method 2: drop all fence markers and parse the remainder.
            jsonResponse = json.loads(gptResponse.replace("```json", "").replace("```", "").replace('\n', ' '))
            sqlResult = jsonResponse['finalResult']
            parsedSql = True
            print("parsed desired result from gpt response using method 2.")
        except Exception:
            print("Couldn't parse desired result from gpt response using method 2")
    if not parsedSql:
        try:
            # Method 3: fenced block again, but take the first key's 'result'.
            txt = gptResponse.split("```json")[-1].split('```')[0].replace('\n', ' ')
            jsonResponse = json.loads(txt)
            sqlResult = jsonResponse[list(jsonResponse.keys())[0]]['result']
            parsedSql = True
            print("parsed desired result from gpt response using method 3.")
        except Exception:
            print("Couldn't parse desired result from gpt response using method 3.")
    if not parsedSql:
        # Nothing parsed: hand the raw reply back unchanged.
        return gptResponse, {}
    return _formatParsedSqlResult(sqlResult), jsonResponse


def _formatParsedSqlResult(sqlResult):
    """Pretty-print the parsed SQL; tolerate a dict payload, fall back to str()."""
    try:
        return sqlparse.format(sqlResult, reindent=True)
    except Exception:
        pass
    try:
        # sqlResult may itself be a {"result": ...} dict rather than a string.
        formattedSql = sqlparse.format(sqlResult['result'], reindent=True)
        print("gpt didn't give parsed result. So parsing again. the formatting.")
        return formattedSql
    except Exception:
        return str(sqlResult)