File size: 5,385 Bytes
d61f3de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
This module defines a class that connects to Snowflake and runs read and write queries.
"""
import json
import os
from snowflake.snowpark import Session
from dotenv import load_dotenv
import logging
# Module-scoped logger: records are attributed to this module and propagate to the
# root handlers, without this module ever reconfiguring the root logger itself.
logger = logging.getLogger(__name__)
# Populate os.environ from a .env file (if present) at import time so the SNOWFLAKE_*
# variables read by SnowFlakeConn are available.
load_dotenv()

class SnowFlakeConn:
    """Wrapper around a Snowpark ``Session`` configured from ``SNOWFLAKE_*`` env vars.

    The session is opened eagerly in ``__init__``. Instances may also be used as a
    context manager; the session is closed when the ``with`` block exits.
    """

    # Required environment variable -> keyword passed to ``Session.builder.configs``.
    # Order matters: it is the order reported in the "missing credentials" error.
    _CREDENTIAL_KEYS = {
        "SNOWFLAKE_USER": "user",
        "SNOWFLAKE_PASSWORD": "password",
        "SNOWFLAKE_ACCOUNT": "account",
        "SNOWFLAKE_ROLE": "role",
        "SNOWFLAKE_DATABASE": "database",
        "SNOWFLAKE_WAREHOUSE": "warehouse",
        "SNOWFLAKE_SCHEMA": "schema",
    }

    def __init__(self):
        self.session = self.connect_to_snowflake()

    def __enter__(self):
        """Return this connection; pairs with ``__exit__`` for ``with`` usage."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the session on leaving the ``with`` block; never suppresses exceptions."""
        self.close_connection()
        return False

    # =========================================================
    def connect_to_snowflake(self):
        """Create a Snowpark session from the ``SNOWFLAKE_*`` environment variables.

        :return: The session created by ``Session.builder.configs(...).create()``.
        :raises ValueError: If any required variable is unset or empty.
        :raises Exception: Any error raised while creating the session is logged and re-raised.
        """
        # Read each variable exactly once so validation and the connection config
        # are guaranteed to see the same values.
        credentials = {
            env_key: self.get_credential(env_key) for env_key in self._CREDENTIAL_KEYS
        }

        missing_credentials = [key for key, value in credentials.items() if not value]
        if missing_credentials:
            error_msg = f"Missing required Snowflake credentials: {', '.join(missing_credentials)}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        conn = {
            param: credentials[env_key]
            for env_key, param in self._CREDENTIAL_KEYS.items()
        }

        try:
            session = Session.builder.configs(conn).create()
        except Exception as e:
            logger.error(f"Failed to connect to Snowflake: {e}")
            raise
        logger.info("Successfully connected to Snowflake")
        return session

    # =========================================================
    def get_credential(self, key):
        """Return the environment variable ``key``, or ``None`` if it is not set."""
        return os.getenv(key)

    # =========================================================
    def run_read_query(self, query, data):
        """Execute a query on Snowflake and return its rows as a pandas DataFrame.

        :param query: SQL text passed to ``self.session.sql``.
        :param data: Human-readable label for what is being read; used only in messages.
        :return: pandas DataFrame whose column names are lower-cased.
        :raises Exception: Any failure is printed, logged and re-raised.
        """
        try:
            dataframe = self.session.sql(query).to_pandas()
            # Lower-case column names for callers (the write path upper-cases them).
            dataframe.columns = dataframe.columns.str.lower()
            print(f"reading {data} table successfully")
            return dataframe
        except Exception as e:
            error_msg = f"Error reading {data}: {e}"
            print(error_msg)
            logger.error(error_msg)
            raise

    # =========================================================
    def store_df_to_snowflake(self, table_name, dataframe, database="SOCIAL_MEDIA_DB",
                              schema="ML_FEATURES", overwrite=False, *, raise_on_error=False):
        """Write a pandas DataFrame to ``database.schema.table_name`` via ``write_pandas``.

        The index is dropped, column names are upper-cased and the table name is
        stripped and upper-cased. The table is created if it does not exist. The
        caller's DataFrame is not modified (``reset_index`` returns a new frame).

        Side effect: this switches the session's current database and schema, so
        later unqualified queries on this session resolve against them.

        :param table_name: Target table name.
        :param dataframe: pandas DataFrame to write.
        :param database: Database to switch the session to before writing.
        :param schema: Schema to switch the session to before writing.
        :param overwrite: Forwarded to ``Session.write_pandas``.
        :param raise_on_error: If True, re-raise failures after reporting them.
            Defaults to False, which preserves the original best-effort behaviour.
        :return: None
        """
        try:
            self.session.use_database(database)
            self.session.use_schema(schema)

            dataframe = dataframe.reset_index(drop=True)
            dataframe.columns = dataframe.columns.str.upper()

            self.session.write_pandas(df=dataframe,
                                      table_name=table_name.strip().upper(),
                                      auto_create_table=True,
                                      overwrite=overwrite,
                                      use_logical_type=True)
            print(f"Data inserted into {table_name} successfully.")

        except Exception as e:
            print(f"Error in creating/updating/inserting table: {e}")
            # Previously this failure was only printed; record it in the log too.
            logger.error(f"Error writing DataFrame to table {table_name}: {e}")
            if raise_on_error:
                raise

    # =========================================================
    def execute_sql_file(self, file_path):
        """Read a SQL file and execute its contents on Snowflake.

        NOTE(review): the whole file is sent as a single ``session.sql`` call;
        confirm Snowflake accepts multi-statement files for this session.

        :param file_path: Path to a UTF-8 encoded SQL file.
        :return: The result of ``collect()`` (a list, possibly empty for DDL/DML),
            or None if reading the file or executing the SQL failed.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                sql_content = file.read()

            result = self.session.sql(sql_content).collect()
            print(f"Successfully executed SQL from {file_path}")
            return result
        except Exception as e:
            print(f"Error executing SQL file {file_path}: {e}")
            logger.error(f"Error executing SQL file {file_path}: {e}")
            return None

    # =========================================================
    def execute_query(self, query, description="query"):
        """Execute a SQL query and return its collected results.

        :param query: SQL query string.
        :param description: Description of the query, used in messages.
        :return: The result of ``collect()`` (a list, possibly empty for DDL/DML),
            or None if execution failed.
        """
        try:
            result = self.session.sql(query).collect()
            print(f"Successfully executed {description}")
            return result
        except Exception as e:
            print(f"Error executing {description}: {e}")
            logger.error(f"Error executing {description}: {e}")
            return None

    # =========================================================
    def get_data(self, data):
        """Placeholder for fetching a named dataset (e.g. comments, contents).

        Not implemented yet: currently does nothing and returns None.
        """
        pass

    # =========================================================
    def close_connection(self):
        """Close the underlying Snowpark session."""
        self.session.close()