File size: 5,373 Bytes
37bd4dd
1dda07c
67524d7
522ce80
37bd4dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67524d7
37bd4dd
 
 
4cb09a6
 
 
 
 
 
67524d7
 
37bd4dd
4cb09a6
67524d7
 
 
 
 
 
37bd4dd
 
67524d7
 
37bd4dd
 
67524d7
 
 
 
 
 
 
37bd4dd
 
 
 
 
 
 
 
67524d7
 
37bd4dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522ce80
 
 
 
67524d7
 
0a41071
522ce80
37bd4dd
 
 
 
 
 
522ce80
 
1dda07c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67524d7
1dda07c
 
fd68fcc
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import psycopg2
import re
import pandas as pd
from persistStorage import retrieveTablesDataFromLocalDb, saveTablesDataToLocalDB

class DataWrapper:
  """Expose the entries of a list or dict as instance attributes.

  A list of names becomes attributes initialised to ``None``; a dict's
  key/value pairs are copied onto the instance unchanged. Any other input
  type leaves the instance without extra attributes.
  """

  def __init__(self, data):
    if isinstance(data, dict):
      self.__dict__.update(data)
    elif isinstance(data, list):
      # Every listed name starts out as None.
      self.__dict__.update(dict.fromkeys(data))

  def addKey(self, key, val=None):
    """Set attribute ``key`` to ``val`` (``None`` by default)."""
    self.__dict__[key] = val

  def __repr__(self):
    return repr(self.__dict__)

class MetaDataLayout:
  """Track a schema's tables/columns and the subset currently selected.

  All state lives in ``self.datalayout``, a dict with the keys ``"schema"``,
  ``"selectedTables"`` (initially empty) and ``"allTables"`` (the mapping
  passed to the constructor).
  """

  def __init__(self, schemaName, allTablesAndCols):
    self.schemaName = schemaName
    layout = {}
    layout["schema"] = schemaName
    layout["selectedTables"] = {}
    layout["allTables"] = allTablesAndCols
    self.datalayout = layout

  def setSelection(self, tablesAndCols):
    """Add tables (and their columns) to the current selection.

    tablesAndCols : {"table1": ["col1", "col2"], "table2": ["cola", "colb"]}

    Tables that are not keys of ``allTables`` are reported on stdout and
    skipped. Tables selected earlier stay selected unless overwritten here.
    """
    known = self.datalayout['allTables'].keys()
    selected = self.datalayout['selectedTables']
    for name in tablesAndCols:
      if name not in known:
        print(f"Table {name} doesn't exists in the schema")
        continue
      selected[name] = tablesAndCols[name]

  def resetSelection(self):
    """Clear the selection; the full table list is kept."""
    self.datalayout['selectedTables'] = {}

  def getSelectedTablesAndCols(self):
    """Return the mapping of selected table names to their columns."""
    return self.datalayout['selectedTables']

  def getAllTablesCols(self):
    """Return the mapping of every known table name to its columns."""
    return self.datalayout['allTables']

  

  
class DbEngine:
  """Lazily opened psycopg2 connection built from a credentials object.

  ``dbCreds`` must expose ``database``, ``user``, ``password``, ``host`` and
  ``port`` attributes; they are passed straight to ``psycopg2.connect``.
  """

  def __init__(self, dbCreds):
    self.dbCreds = dbCreds
    self._connection = None

  def _isOpen(self):
    """Return True when a connection exists and psycopg2 reports it open."""
    return self._connection is not None and self._connection.closed == 0

  def connect(self):
    """Open a new connection unless an open one is already held."""
    if self._isOpen():
      return
    dbCreds = self.dbCreds
    # TCP keepalive parameters forwarded to libpq by psycopg2.connect.
    keepaliveKwargs = {
        "keepalives": 1,
        "keepalives_idle": 100,
        "keepalives_interval": 5,
        "keepalives_count": 5,
    }
    self._connection = psycopg2.connect(database=dbCreds.database, user=dbCreds.user,
                                        password=dbCreds.password, host=dbCreds.host,
                                        port=dbCreds.port, **keepaliveKwargs)

  def getConnection(self):
    """Return the open connection, (re)connecting first if needed."""
    self.connect()
    return self._connection

  def disconnect(self):
    """Close the connection if it is open; a later call reconnects."""
    if self._isOpen():
      self._connection.close()

  def execute_query(self, query):
    """Run ``query`` and return every row from ``cursor.fetchall()``.

    If execution fails, the current transaction is rolled back before the
    original exception is re-raised unchanged (type and traceback intact).
    Without the rollback psycopg2 keeps the still-open connection in the
    "current transaction is aborted" state, and every later query on this
    engine would fail as well.

    Raises:
      Whatever connecting, ``cursor.execute`` or ``fetchall`` raise.
    """
    conn = self.getConnection()
    try:
      with conn.cursor() as cursor:
        cursor.execute(query)
        return cursor.fetchall()
    except Exception:
      self._rollbackQuietly()
      raise

  def _rollbackQuietly(self):
    """Best-effort rollback used while another error is already propagating."""
    if not self._isOpen():
      return
    try:
      self._connection.rollback()
    except psycopg2.Error:
      # The connection itself is presumably broken; the original error is
      # the one worth surfacing, so the rollback failure is dropped.
      pass

  
def executeQuery(dbEngine, query):
  """Run ``query`` through ``dbEngine.execute_query`` and hand back its rows."""
  return dbEngine.execute_query(query)

def executeColumnsQuery(dbEngine, columnQuery):
  """Execute ``columnQuery`` and return the names of its result columns.

  Only ``cursor.description`` is read and no rows are fetched. That is why
  ``getAllTablesInfo`` in this module passes a ``... LIMIT 0`` query.

  The connection comes from ``dbEngine.getConnection()``, which connects if
  needed. On failure the transaction is rolled back before the original
  exception propagates, so the engine's connection stays usable. psycopg2
  otherwise rejects every later statement with "current transaction is
  aborted".
  """
  conn = dbEngine.getConnection()
  try:
    with conn.cursor() as cursor:
      cursor.execute(columnQuery)
      # description holds one sequence per result column; item 0 is the name.
      return [desc[0] for desc in cursor.description]
  except Exception:
    try:
      conn.rollback()
    except psycopg2.Error:
      pass  # connection presumably unusable; surface the original error instead
    raise
  
def closeDbEngine(dbEngine):
  """Close ``dbEngine``'s connection by calling its ``disconnect`` method."""
  disconnect = dbEngine.disconnect
  disconnect()

def getAllTablesInfo(dbEngine, schemaName):
  """Map every table in ``schemaName`` to its list of column names.

  Table names come from ``information_schema.tables``. The columns of each
  table are read from the cursor description of a ``SELECT * ... LIMIT 0``.

  The queries are composed with ``psycopg2.sql``. The schema name is bound
  as a quoted SQL literal, and schema and table names are quoted as
  identifiers. Names containing quotes, spaces, reserved words or other
  characters that need quoting therefore neither break the query nor
  inject SQL.
  """
  # Local import: the module imports only the top-level psycopg2 package,
  # which does not necessarily load the ``sql`` submodule.
  from psycopg2 import sql

  allTablesQuery = sql.SQL(
      "SELECT table_name FROM information_schema.tables WHERE table_schema = {}"
  ).format(sql.Literal(schemaName))
  tablesAndCols = {}
  for row in executeQuery(dbEngine, allTablesQuery):
    tableName = row[0]
    columnsQuery = sql.SQL("Select * FROM {}.{} LIMIT 0").format(
        sql.Identifier(schemaName), sql.Identifier(tableName))
    tablesAndCols[tableName] = executeColumnsQuery(dbEngine, columnsQuery)
  return tablesAndCols

def getSampleDataForTablesAndCols(dbEngine, schemaName, tablesAndCols, maxRows):
    """Return up to ``maxRows`` rows per table as pandas DataFrames.

    The local cache is consulted first through ``retrieveTablesDataFromLocalDb``.
    Whatever it returns is handed back as-is unless it is an empty dict.
    Otherwise each table named in ``tablesAndCols`` is queried with
    ``select * ... limit maxRows``, and the collected frames are passed to
    ``saveTablesDataToLocalDB`` and returned.

    Note: only the keys (table names) of ``tablesAndCols`` are used. The
    column lists are ignored and every column is fetched.

    A table that cannot be read is reported on stdout and mapped to an empty
    DataFrame. The failed transaction is then rolled back so the remaining
    tables can still be read over the same connection. psycopg2 otherwise
    rejects every later statement, and all subsequent tables would silently
    come back empty (and be cached that way).
    """
    # Local import: the module imports only the top-level psycopg2 package.
    from psycopg2 import sql

    data = retrieveTablesDataFromLocalDb(list(tablesAndCols.keys()))
    if data != {}:
        return data

    conn = dbEngine.getConnection()
    print("Didn't find any cache/valid cache.")
    print("Getting data from aws redshift")
    for table in tablesAndCols.keys():
        try:
            # pandas needs a plain string for DBAPI (non-SQLAlchemy) connections,
            # so the composed query is rendered against this connection.
            # int() sits inside the try so a bad maxRows fails per table, as before.
            sqlQuery = sql.SQL("select * from {}.{} limit {}").format(
                sql.Identifier(schemaName), sql.Identifier(table),
                sql.Literal(int(maxRows))).as_string(conn)
            data[table] = pd.read_sql_query(sqlQuery, con=conn)
        except Exception:
            print(f"couldn't read table data. Table: {table}")
            data[table] = pd.DataFrame({})
            try:
                conn.rollback()
            except psycopg2.Error:
                pass  # connection presumably unusable; later tables fail the same way
    saveTablesDataToLocalDB(data)
    return data

# Keywords whose presence marks a generated SQL statement as something other
# than a plain read: DML, DDL, privilege changes and data load/unload/maintenance.
_MODIFYING_SQL_KEYWORDS = (
    'INSERT', 'UPDATE', 'DELETE', 'MERGE',
    'CREATE', 'ALTER', 'DROP', 'TRUNCATE',
    'GRANT', 'REVOKE',
    'COPY', 'UNLOAD', 'VACUUM', 'CALL',
)


def isDataQuery(sql_query, blockedKeywords=_MODIFYING_SQL_KEYWORDS):
    """Return True if the generated ``sql_query`` looks like a read-only query.

    The check is a case-insensitive whole-word search for each keyword in
    ``blockedKeywords``. Previously only INSERT/UPDATE/DELETE/MERGE were
    checked, so statements such as ``DROP TABLE`` passed as data queries.
    This is a heuristic guard, not a SQL parser. A blocked keyword inside a
    string literal, or an identifier spelled exactly like one, also yields
    ``False``.
    """
    upper_query = sql_query.upper()
    for keyword in blockedKeywords:
        if re.search(rf'\b{re.escape(keyword.upper())}\b', upper_query):
            return False  # Found a modifying keyword

    # No blocked keyword found: treat it as a data query.
    return True

def extractSqlFromGptResponse(gptReponse):
  """Return the body of the first ```sql fenced code block in ``gptReponse``.

  The fence tag matches case-insensitively (```sql, ```SQL). It may be
  followed by spaces/tabs and then an LF or CRLF line break. The result is
  the text between that line break and the next closing ``` and usually
  ends with a newline. Returns "" when no such block is present.
  """
  sqlPattern = re.compile(r"```sql[ \t]*\r?\n(.*?)```", re.DOTALL | re.IGNORECASE)
  match = sqlPattern.search(gptReponse)
  return match.group(1) if match else ""
  
def addSchemaToTableInSQL(sqlQuery, schemaName, tablesList):
  """Prefix bare occurrences of each table name with ``schemaName.``.

  A table name is replaced (case-insensitively) when it is preceded by
  whitespace or the start of the string. It must be followed by whitespace,
  ``;``, ``)`` or the end of the string. Occurrences that are already
  qualified (``schema.table``), used as a qualifier (``table.column``),
  quoted, or part of a longer word are left alone. The inserted text is
  ``schemaName.table``, spelled as in ``tablesList``.
  """
  for table in tablesList:
    # Accepting ';' and ')' after the name covers "FROM t;" and
    # "(... FROM t)", which a whitespace-only lookahead silently skipped.
    pattern = re.compile(rf'(?<!\S){re.escape(table)}(?![^\s;)])', re.IGNORECASE)
    replacement = f'{schemaName}.{table}'
    # A function replacement keeps backslashes in names from being read as
    # regex template escapes.
    sqlQuery = pattern.sub(lambda _match, text=replacement: text, sqlQuery)
  return sqlQuery

def preProcessGptQueryReponse(gptResponse, metadataLayout: MetaDataLayout):
  """Qualify every known table name in ``gptResponse`` with the layout's schema.

  The table names are the keys of ``metadataLayout.getAllTablesCols()``; the
  rewriting itself is delegated to ``addSchemaToTableInSQL``.
  """
  return addSchemaToTableInSQL(
      gptResponse,
      schemaName=metadataLayout.schemaName,
      tablesList=metadataLayout.getAllTablesCols().keys(),
  )