File size: 5,385 Bytes
d61f3de | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | """
This module provides a class that creates a connection to Snowflake and runs queries (read and write).
"""
import json
import os
from snowflake.snowpark import Session
from dotenv import load_dotenv
import logging
# Module-scoped logger; records still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)
# Populate environment variables from a local .env file (python-dotenv) so the
# SNOWFLAKE_* values read via os.getenv below are available at import time.
load_dotenv()
class SnowFlakeConn:
    """Wrapper around a Snowpark ``Session`` for reading and writing Snowflake data.

    The session is opened eagerly in ``__init__`` from ``SNOWFLAKE_*``
    environment variables (see ``connect_to_snowflake``). Instances can also be
    used as a context manager so the session is closed automatically::

        with SnowFlakeConn() as conn:
            df = conn.run_read_query("SELECT ...", "comments")
    """

    # Session.builder config key -> environment variable that supplies it.
    # Insertion order matters: it fixes the order of names in the
    # "Missing required Snowflake credentials" error message.
    _CREDENTIAL_ENV_VARS = {
        "user": "SNOWFLAKE_USER",
        "password": "SNOWFLAKE_PASSWORD",
        "account": "SNOWFLAKE_ACCOUNT",
        "role": "SNOWFLAKE_ROLE",
        "database": "SNOWFLAKE_DATABASE",
        "warehouse": "SNOWFLAKE_WAREHOUSE",
        "schema": "SNOWFLAKE_SCHEMA",
    }

    def __init__(self):
        """Open a Snowflake session immediately.

        :raises ValueError: if a required SNOWFLAKE_* variable is unset or empty
        :raises Exception: whatever session creation raises on a failed login
        """
        self.session = self.connect_to_snowflake()

    def __enter__(self):
        """Return this connection for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Close the session when leaving a ``with`` block; exceptions are not suppressed."""
        self.close_connection()
        return False

    # =========================================================
    def connect_to_snowflake(self):
        """Create a Snowpark session from ``SNOWFLAKE_*`` environment variables.

        :return: a connected ``snowflake.snowpark.Session``
        :raises ValueError: if any required variable is unset or empty
        :raises Exception: any error from ``Session.builder...create()``,
            re-raised after logging
        """
        # Read each variable once so validation and the session config see
        # exactly the same values.
        conn = {
            config_key: self.get_credential(env_var)
            for config_key, env_var in self._CREDENTIAL_ENV_VARS.items()
        }
        missing_credentials = [
            env_var
            for config_key, env_var in self._CREDENTIAL_ENV_VARS.items()
            if not conn[config_key]
        ]
        if missing_credentials:
            error_msg = f"Missing required Snowflake credentials: {', '.join(missing_credentials)}"
            logger.error(error_msg)
            raise ValueError(error_msg)
        try:
            session = Session.builder.configs(conn).create()
        except Exception as e:
            logger.error("Failed to connect to Snowflake: %s", e)
            raise
        logger.info("Successfully connected to Snowflake")
        return session

    # =========================================================
    def get_credential(self, key):
        """Return the environment variable ``key``, or None if it is not set."""
        return os.getenv(key)

    # =========================================================
    def run_read_query(self, query, data):
        """Run a read query on Snowflake and return the result as a pandas DataFrame.

        :param query: SQL query string to execute
        :param data: human-readable name of what is being read; used only in messages
        :return: pandas DataFrame whose column names are lower-cased
        :raises Exception: any error from Snowpark/pandas, re-raised after logging
        """
        try:
            dataframe = self.session.sql(query).to_pandas()
            # Normalize column names to lower case for downstream pandas code.
            dataframe.columns = dataframe.columns.str.lower()
            print(f"reading {data} table successfully")
            return dataframe
        except Exception as e:
            error_msg = f"Error reading {data}: {e}"
            print(error_msg)
            logger.error(error_msg)
            raise

    # =========================================================
    def store_df_to_snowflake(self, table_name, dataframe, database="SOCIAL_MEDIA_DB", schema="ML_FEATURES", overwrite=False):
        """Write a pandas DataFrame to a Snowflake table, creating the table if needed.

        Best-effort: on failure the error is printed and logged, and the method
        returns None instead of raising.

        NOTE: ``use_database``/``use_schema`` change the session's *current*
        database and schema, so later unqualified queries on this connection
        (e.g. via ``run_read_query``) resolve against ``database.schema`` too.

        :param table_name: target table name; stripped and upper-cased before use
        :param dataframe: pandas DataFrame to write; the caller's object is not
            modified (a re-indexed copy is written)
        :param database: database the session switches to before writing
        :param schema: schema the session switches to before writing
        :param overwrite: forwarded to ``Session.write_pandas``; True replaces the
            existing table data instead of appending
        :return: None
        """
        try:
            self.session.use_database(database)
            self.session.use_schema(schema)
            # reset_index returns a new frame, so the column rename below does
            # not touch the caller's DataFrame.
            dataframe = dataframe.reset_index(drop=True)
            # str() first: `.str.upper()` raises on non-string labels such as
            # the integer columns of a DataFrame built from a numpy array.
            dataframe.columns = [str(column).upper() for column in dataframe.columns]
            self.session.write_pandas(
                df=dataframe,
                table_name=table_name.strip().upper(),
                auto_create_table=True,
                overwrite=overwrite,
                use_logical_type=True,
            )
        except Exception as e:
            error_msg = f"Error in creating/updating/inserting table: {e}"
            print(error_msg)
            logger.error(error_msg)
            return None
        print(f"Data inserted into {table_name} successfully.")
        return None

    # =========================================================
    def execute_sql_file(self, file_path):
        """Execute the SQL text contained in a file.

        The whole file is sent as one ``session.sql(...)`` call.
        NOTE(review): Snowflake may reject text with several statements unless
        MULTI_STATEMENT_COUNT allows it; confirm for multi-statement files.

        :param file_path: path to a UTF-8 encoded SQL file
        :return: the rows from ``collect()``, or None on any error (unreadable
            file, SQL failure); errors are printed and logged, not raised
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                sql_content = file.read()
            result = self.session.sql(sql_content).collect()
        except Exception as e:
            error_msg = f"Error executing SQL file {file_path}: {e}"
            print(error_msg)
            logger.error(error_msg)
            return None
        print(f"Successfully executed SQL from {file_path}")
        return result

    # =========================================================
    def execute_query(self, query, description="query"):
        """Execute a SQL query and return its collected rows.

        :param query: SQL query string
        :param description: short label for the query, used only in messages
        :return: the rows from ``collect()``, or None if execution failed;
            errors are printed and logged, not raised
        """
        try:
            result = self.session.sql(query).collect()
        except Exception as e:
            error_msg = f"Error executing {description}: {e}"
            print(error_msg)
            logger.error(error_msg)
            return None
        print(f"Successfully executed {description}")
        return result

    # =========================================================
    def get_data(self, data):
        """Placeholder for fetching a named dataset (comments, contents, ...).

        Not implemented yet: always returns None.
        """
        return None

    # =========================================================
    def close_connection(self):
        """Close the underlying Snowpark session."""
        self.session.close()
|