"""Database helpers: schema metadata, a psycopg2 connection wrapper, and
utilities for inspecting SQL text."""

import os
import pickle
import re

import pandas as pd
import psycopg2
from psycopg2 import sql

from persistStorage import retrieveTablesDataFromLocalDb, saveTablesDataToLocalDB
from config import SCHEMA_INFO_FILE_PATH

# Whole-word match for statements that modify data; applied to upper-cased SQL.
_DML_KEYWORD_PATTERN = re.compile(r'\b(?:INSERT|UPDATE|DELETE|MERGE)\b')

# Captures the body of the first ```sql fenced block in a text.
_SQL_FENCE_PATTERN = re.compile(r"```sql\n(.*?)```", re.DOTALL)


class DataWrapper:
    """Expose a set of keys as attributes of a plain object.

    A list creates one attribute per element, each initialised to ``None``.
    A dict creates one attribute per key, holding the matching value.
    Any other input leaves the instance without attributes.
    """

    def __init__(self, data):
        if isinstance(data, list):
            self.__dict__.update(dict.fromkeys(data))
        elif isinstance(data, dict):
            self.__dict__.update(data)

    def addKey(self, key, val=None):
        """Set attribute ``key`` to ``val``, overwriting any existing value."""
        self.__dict__[key] = val

    def __repr__(self):
        return repr(self.__dict__)


class MetaDataLayout:
    """Track which tables/columns of a schema are currently selected.

    ``datalayout`` is a dict with three keys: ``"schema"`` (the schema name),
    ``"allTables"`` (the mapping given at construction) and
    ``"selectedTables"`` (the current selection, initially empty).
    """

    def __init__(self, schemaName, allTablesAndCols):
        self.schemaName = schemaName
        self.datalayout = {
            "schema": self.schemaName,
            "selectedTables": {},
            "allTables": allTablesAndCols,
        }

    def setSelection(self, tablesAndCols):
        """Add tables to the current selection.

        tablesAndCols : {"table1": ["col1", "col2"], "table2": ["cola", "colb"]}

        Tables that are not present in ``allTables`` are reported on stdout
        and skipped. Existing selections for other tables are kept.
        """
        allTables = self.datalayout['allTables']
        selectedTables = self.datalayout['selectedTables']
        for table in tablesAndCols:
            if table in allTables:
                selectedTables[table] = tablesAndCols[table]
            else:
                print(f"Table {table} doesn't exists in the schema")

    def resetSelection(self):
        """Clear the current selection."""
        self.datalayout['selectedTables'] = {}

    def getSelectedTablesAndCols(self):
        """Return the mapping of selected tables to their selected columns."""
        return self.datalayout['selectedTables']

    def getAllTablesCols(self):
        """Return the full table-to-columns mapping given at construction."""
        return self.datalayout['allTables']


def _rollbackQuietly(connection):
    """Roll back ``connection`` if it is open, ignoring database errors.

    After a failed statement psycopg2 keeps the transaction in an aborted
    state and rejects further commands until a rollback is issued. Errors
    raised by the rollback itself are suppressed so that the caller's original
    exception is the one that propagates.
    """
    if connection is None or connection.closed != 0:
        return
    try:
        connection.rollback()
    except psycopg2.Error:
        pass


class DbEngine:
    """Lazily connected wrapper around a single psycopg2 connection.

    ``dbCreds`` must expose ``database``, ``user``, ``password``, ``host`` and
    ``port`` attributes.
    """

    def __init__(self, dbCreds):
        self.dbCreds = dbCreds
        self._connection = None

    def _isOpen(self):
        """Return True when a connection exists and has not been closed."""
        return self._connection is not None and self._connection.closed == 0

    def connect(self):
        """Open the connection unless an open one already exists."""
        if self._isOpen():
            return
        dbCreds = self.dbCreds
        # Keepalive settings are passed through to psycopg2.connect unchanged.
        keepaliveKwargs = {
            "keepalives": 1,
            "keepalives_idle": 100,
            "keepalives_interval": 5,
            "keepalives_count": 5,
        }
        self._connection = psycopg2.connect(
            database=dbCreds.database,
            user=dbCreds.user,
            password=dbCreds.password,
            host=dbCreds.host,
            port=dbCreds.port,
            **keepaliveKwargs,
        )

    def getConnection(self):
        """Return an open connection, connecting first if necessary."""
        self.connect()
        return self._connection

    def disconnect(self):
        """Close the connection if it is currently open."""
        if self._isOpen():
            self._connection.close()

    def execute_query(self, query, params=None):
        """Execute ``query`` and return all rows from ``cursor.fetchall()``.

        ``params`` is forwarded to ``cursor.execute``; leaving it as ``None``
        executes the query text as-is, exactly as before.

        On any failure the open transaction is rolled back so the connection
        stays usable, and the original exception is re-raised with its type
        and traceback intact.
        """
        try:
            self.connect()
            with self._connection.cursor() as cursor:
                cursor.execute(query, params)
                return cursor.fetchall()
        except Exception:
            _rollbackQuietly(self._connection)
            raise


def executeQuery(dbEngine, query, params=None):
    """Run ``query`` through ``dbEngine.execute_query`` and return its rows.

    ``params`` is only forwarded when given, so engines whose
    ``execute_query`` accepts just the query text keep working.
    """
    if params is None:
        return dbEngine.execute_query(query)
    return dbEngine.execute_query(query, params)


def executeColumnsQuery(dbEngine, columnQuery):
    """Execute ``columnQuery`` and return the column names of its result.

    The names are the first item of each entry in ``cursor.description``.
    The transaction is rolled back before re-raising if execution fails.
    """
    connection = dbEngine.getConnection()
    try:
        with connection.cursor() as cursor:
            cursor.execute(columnQuery)
            return [desc[0] for desc in cursor.description]
    except Exception:
        _rollbackQuietly(connection)
        raise


def closeDbEngine(dbEngine):
    """Disconnect ``dbEngine``."""
    dbEngine.disconnect()


def getAllTablesInfo(dbEngine, schemaName, useCache=True):
    """Return a ``{tableName: [columnName, ...]}`` mapping for ``schemaName``.

    With ``useCache`` true and a file present at ``SCHEMA_INFO_FILE_PATH``,
    the pickled mapping in that file is returned without touching the
    database. Otherwise the tables are listed from
    ``information_schema.tables``, each table's columns are read with a
    zero-row select, and the result is pickled to ``SCHEMA_INFO_FILE_PATH``.
    """
    if useCache and os.path.isfile(SCHEMA_INFO_FILE_PATH):
        # NOTE(review): pickle.load can execute arbitrary code; this is only
        # safe while the cache file is written exclusively by this function.
        # NOTE(review): the cache is not keyed by schemaName, so a file written
        # for one schema is returned for any other - confirm this is intended.
        with open(SCHEMA_INFO_FILE_PATH, 'rb') as fh:
            return pickle.load(fh)

    print("Getting tables Info, list of tables and columns...")
    # The schema name is bound as a parameter instead of being interpolated.
    allTablesQuery = (
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = %s"
    )
    tables = executeQuery(dbEngine, allTablesQuery, (schemaName,))

    tablesAndCols = {}
    for table in tables:
        tableName = table[0]
        # Quote identifiers so names with mixed case or special characters
        # resolve to the table the catalog reported.
        columnsQuery = sql.SQL("SELECT * FROM {}.{} LIMIT 0").format(
            sql.Identifier(schemaName),
            sql.Identifier(tableName),
        )
        tablesAndCols[tableName] = executeColumnsQuery(dbEngine, columnsQuery)

    with open(SCHEMA_INFO_FILE_PATH, 'wb') as fh:
        pickle.dump(tablesAndCols, fh)
    return tablesAndCols


def getSampleDataForTablesAndCols(dbEngine, schemaName, tablesAndCols, maxRows):
    """Return ``{table: DataFrame}`` with up to ``maxRows`` rows per table.

    Data returned by ``retrieveTablesDataFromLocalDb`` is used when it is not
    an empty dict. Otherwise every table in ``tablesAndCols`` is read from the
    database and the result is passed to ``saveTablesDataToLocalDB``. A table
    that cannot be read is reported on stdout and mapped to an empty
    DataFrame.
    """
    data = retrieveTablesDataFromLocalDb(list(tablesAndCols.keys()))
    if data != {}:
        return data

    conn = dbEngine.getConnection()
    print("Didn't find any cache/valid cache.")
    print("Getting data from aws redshift")
    for table in tablesAndCols:
        sqlQuery = f"""select * from {schemaName}.{table} limit {maxRows}"""
        try:
            data[table] = pd.read_sql_query(sqlQuery, con=conn)
        except Exception:
            # Without a rollback the aborted transaction would make every
            # remaining table fail as well.
            _rollbackQuietly(conn)
            print(f"couldn't read table data. Table: {table}")
            data[table] = pd.DataFrame({})
    saveTablesDataToLocalDB(data)
    return data


# Function to test the generated sql query
def isDataQuery(sql_query):
    """Return False if ``sql_query`` contains a DML keyword, else True.

    The check is a case-insensitive whole-word search for INSERT, UPDATE,
    DELETE or MERGE anywhere in the text, including inside string literals
    and comments.
    """
    # If no DML keywords are found, it's likely a data query
    return _DML_KEYWORD_PATTERN.search(sql_query.upper()) is None


def extractSqlFromGptResponse(gptReponse):
    """Return the contents of the first ```sql fenced block in ``gptReponse``.

    Returns an empty string when the text contains no such block.
    """
    match = _SQL_FENCE_PATTERN.search(gptReponse)
    return match.group(1) if match else ""


# NOTE(review): the definition of addSchemaToTableInSQL is truncated in the
# available source (it stops in the middle of the regex literal). The visible
# fragment is preserved verbatim below; restore the full body from version
# control before relying on this function.
#
# def addSchemaToTableInSQL(sqlQuery, schemaName, tablesList):
#     for table in tablesList:
#         pattern = re.compile(rf'(?