File size: 13,234 Bytes
2d39721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
"""This module contains utility functions for input conversion and validation."""

import os
import logging
from logging import Logger  # For type hinting
import json
import joblib
import streamlit as st
import torch

from .consts import (
    FEATURE_NAMES,
    CATEGORY_MAPPING,
    GENDER_MAPPING,
    STATE_MAPPING,
    INPUT_METADATA,
    STREAMLIT_VALIDATED,
    MODEL_WEIGHTS_FULL_PATH,
    CONFIG_PATH,
    FEATURE_SCALER_PATH,
)
from .model import Agent


def setup_logger(config: dict, propogate: bool = False) -> Logger:
    """Set up and return a named logger based on the provided config dictionary.

    Handlers (console and/or file) are rebuilt from scratch on every call, so
    repeated calls for the same logger name never accumulate duplicates.

    Args:

        config (dict): Dictionary containing logging configuration. Recognized
            keys (all optional): 'logger_name', 'log_to_file', 'log_file',
            'log_level', 'log_mode', 'log_format', 'date_format',
            'log_to_console'.

        propogate (bool): Whether to allow log messages to propagate to ancestor
            loggers. (Parameter name — including its spelling — kept as-is for
            backward compatibility with existing callers.)

    Returns:

        Logger: Configured logger instance.

    """
    logger_name = config.get("logger_name", "main")
    log_to_file = config.get("log_to_file", True)  # Whether to log to a logfile
    log_file = config.get("log_file", "logs/app.log")  # Log file path
    log_lvl = config.get("log_level", "INFO")
    # Accept either a level name (e.g. "info") or a numeric level. Coerce to str
    # before .upper() so a non-string config value cannot crash, and fall back
    # to INFO for unrecognized names.
    if isinstance(log_lvl, int):
        log_level = log_lvl
    else:
        log_level = getattr(logging, str(log_lvl).upper(), logging.INFO)
    log_mode = config.get("log_mode", "w")  # Log file open mode
    log_format = config.get("log_format", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    date_format = config.get("date_format", "%Y-%m-%d %H:%M:%S")
    log_to_console = config.get("log_to_console", True)  # Whether to log to console

    logger = logging.getLogger(logger_name)  # Create/fetch logger with the specified name

    # Always clear previously-attached handlers first, so reconfiguration
    # (including disabling all outputs) starts from a clean slate. Previously
    # this cleanup was skipped when both outputs were disabled, leaving stale
    # handlers attached.
    for handler in logger.handlers[:]:
        handler.close()
        logger.removeHandler(handler)

    if not log_to_file and not log_to_console:
        # No output destinations configured: warn on stdout since the logger
        # itself has nowhere to write.
        print(
            f"Warning: No logging handlers configured for {logger_name}.\nVerbose Logging will be disabled.\nIn 'config/config.json', set ['log_to_file': true] or ['log_to_console': true] if you want to change the logging behavior.",
            flush=True,
        )
    else:
        formatter = logging.Formatter(
            fmt=log_format, datefmt=date_format
        )  # Shared formatter for all handlers

        if log_to_console:
            console_handler = logging.StreamHandler()  # Log messages to the console
            console_handler.setFormatter(formatter)
            logger.addHandler(console_handler)

        if log_to_file:
            # Create the log file's parent directory only when file logging is
            # actually enabled (previously this ran for console-only setups too).
            parent_dir = os.path.dirname(log_file)
            if parent_dir and parent_dir != ".":
                try:
                    os.makedirs(name=parent_dir, exist_ok=True)
                    print(
                        f"Parent directory '{parent_dir}' used to store the log file.", flush=True
                    )  # flush=True to ensure the message is printed immediately
                except OSError as e:
                    print(
                        f"Error creating directory '{parent_dir}': {e} INFO: Using default log file 'app.log' instead.",
                        flush=True,
                    )
                    log_file = "app.log"  # Fall back to a default log file if problem occurs.

            file_handler = logging.FileHandler(
                filename=log_file, mode=log_mode, encoding="utf-8"
            )  # utf-8 encoding enables emoji use in log messages
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    logger.setLevel(log_level)  # Set logger minimum log level

    # Without propagation, the root logger's default handlers are bypassed.
    logger.propagate = propogate

    return logger


def convert_inputs(**kwargs) -> list:
    """Convert user inputs into a list of float features for the model.

    Categorical fields (category, gender, state) are translated via their
    dedicated mapping tables; numeric fields validated by the Streamlit UI are
    range-checked against INPUT_METADATA and cast to float. Any validation
    failure is routed through ``log_and_stop`` (which halts the app).

    Args:

        **kwargs: Dictionary of user inputs (e.g., {'category': 'entertainment', 'amt': 25.0, ...})

    Returns:

        features: A list of converted features ready for model input.

    """
    # Dispatch table for categorical features, replacing three copy-pasted
    # branches that differed only in which mapping they consulted.
    categorical_mappings = {
        "category": CATEGORY_MAPPING,
        "gender": GENDER_MAPPING,
        "state": STATE_MAPPING,
        # ... Add other mapped fields here
    }

    features = []  # Accumulates the converted feature values, in FEATURE_NAMES order

    for feature_name in FEATURE_NAMES:
        try:
            value = kwargs.get(feature_name)

            # Every feature listed in FEATURE_NAMES is required.
            if value is None:
                raise ValueError(f"Missing required input: {feature_name}")

            # --- Mapped (categorical) Features ---
            if feature_name in categorical_mappings:
                mapped_value = categorical_mappings[feature_name].get(value)
                if mapped_value is None:
                    raise ValueError(f"{feature_name}; value={value}; no mapping.")
                # Mappings are expected to hold floats ready for the model.
                if not isinstance(mapped_value, float):
                    raise ValueError(f"{feature_name} must be a float.")
                features.append(mapped_value)

            # --- Streamlit-Validated (numeric) Features ---
            elif feature_name in STREAMLIT_VALIDATED:
                # Use INPUT_METADATA for range validation
                meta = INPUT_METADATA.get(feature_name, {})
                min_v = meta.get("min_value")
                max_v = meta.get("max_value")

                if min_v is not None and max_v is not None and not (min_v <= value <= max_v):
                    raise ValueError(f"{feature_name} out of expected range.")
                features.append(float(value))  # Convert to float

            # Feature not covered by any conversion rule above
            else:
                raise ValueError(f"No conversion for {feature_name}")
        except ValueError as e:
            log_and_stop(f"Validation Error for {feature_name}: {e}")
        except Exception as e:  # Catch all other exceptions
            log_and_stop(f"An unexpected fatal error occurred: {e}")

    # Sanity check: exactly one converted value per expected feature.
    if len(features) != len(FEATURE_NAMES):
        log_and_stop(
            f"Fatal Error: Final feature list length mismatch. Created list size: {len(features)} | Expected list size: {len(FEATURE_NAMES)}"
        )

    return features


@st.cache_data
def load_config():
    """Load and parse the JSON configuration file at CONFIG_PATH.

    Cached by Streamlit so the file is read only once per session. On any
    failure, an error is printed, shown in the Streamlit UI, and the app is
    stopped.

    Args:

        N/A

    Returns:

        config (dict): the python dictionary containing configuration data

    """
    try:
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # Streamlit renders strings as Markdown, so '  \n' forces a line break.
        message = f"❌ Configuration file not found at '{CONFIG_PATH}'.  \nPlease ensure the file exists or fix path to file."
    except json.JSONDecodeError as e:
        message = f"❌ Failed to parse JSON: {e}."
    except Exception as e:  # Catch all other exceptions
        message = f"An unexpected fatal error occurred: {e}"

    # Only reached on failure: report the error and halt the app.
    message += "  \nStopping Execution."  # Add the common suffix
    print(message)
    st.error(message)
    st.stop()


@st.cache_resource
def load_model(_logger: Logger):
    """Load the trained model weights and return an Agent ready for inference.

    Cached by Streamlit as a resource so the model is built only once per
    session. Any failure is routed through ``log_and_stop`` (which halts the
    app).

    Args:

        _logger (Logger): The logger instance to log messages. Use underscore to prevent hashing by Streamlit.

    Returns:

        Agent (torch.nn.Module): Returns agent to cpu in evaluation mode.

    """
    message = ""  # Initialize variable in case of errors
    try:
        model_weights = torch.load(MODEL_WEIGHTS_FULL_PATH, weights_only=True)
        _logger.info(f"βœ… Model weights loaded successfully from {MODEL_WEIGHTS_FULL_PATH}")
    except FileNotFoundError:
        message = f"❌ Model Weights file not found at '{MODEL_WEIGHTS_FULL_PATH}'.  \nPlease ensure the file exists."
        log_and_stop(message)
    except Exception as e:  # e.g. corrupt or incompatible weights file
        # Previously any non-FileNotFoundError from torch.load escaped
        # unhandled; route it through the common fatal-error path instead.
        log_and_stop(f"An unexpected fatal error occurred: {e}")

    CONFIG = load_config()
    MODEL_CONFIG = CONFIG.get("model", {})

    try:
        agent = Agent(cfg=MODEL_CONFIG)  # Create agent instance
        agent.load_state_dict(state_dict=model_weights)
    except RuntimeError as e:
        message = f"❌ A runtime error occurred while creating model or loading model weights: {e}"
    except FileNotFoundError as e:
        message = f"❌ Model weights file not found: {e}"
    except KeyError as e:
        message = f"❌ Missing key in model configuration: {e}"
    except Exception as e:  # Catch all other exceptions
        message = f"An unexpected fatal error occurred: {e}"
    # Execute if no exception was caught
    else:
        return agent.eval().to("cpu")
    # If exception was thrown continue to the finally block
    finally:
        if message:
            log_and_stop(message)


@st.cache_data
def load_feature_scaler(_logger: Logger):
    """Load and return the fitted feature scaler from FEATURE_SCALER_PATH.

    Cached by Streamlit so the scaler is deserialized only once per session.
    Any failure is routed through ``log_and_stop`` (which halts the app).

    Args:

        _logger (Logger): The logger instance to log messages. Use underscore to prevent hashing by Streamlit.

    Returns:

        feature_scaler: the loaded scaler object

    """
    try:
        feature_scaler = joblib.load(FEATURE_SCALER_PATH)
        _logger.info(f"βœ… Feature Scaler loaded successfully from {FEATURE_SCALER_PATH}")
        return feature_scaler
    except FileNotFoundError:
        failure = f"❌ Scaler file not found at '{FEATURE_SCALER_PATH}'.  \nPlease ensure the file exists or fix path to file."
    except Exception as e:  # Catch all other exceptions
        failure = f"An unexpected fatal error occurred: {e}"

    # Only reached on failure: report and halt the app.
    log_and_stop(failure)


@st.cache_data
def load_label_scaler(_logger: Logger):
    """Return the label scaler — always None in this implementation.

    Kept for API symmetry with ``load_feature_scaler`` so callers can use the
    same loading pattern for both scalers.

    Args:

        _logger (Logger): The logger instance to log messages. Use underscore to prevent hashing by Streamlit.

    Returns:

        label_scaler: the loaded scaler object (None; no label scaler is used)

    """
    # No label scaling is performed in this implementation.
    return None


def log_and_stop(message: str):
    """Log a fatal message, surface it in the Streamlit UI, and stop the app.

    Args:

        message (str): The message to log and display

    Returns:

        N/A

    """
    # Reuse the app's configured logger name so the message lands in the same
    # handlers set up by setup_logger.
    logger_name = load_config()["logging"]["logger_name"]
    logger = logging.getLogger(logger_name)

    message += "  \nStopping Execution."  # Add the common suffix

    # This path always terminates the app, so log at ERROR severity (was INFO,
    # which hid fatal messages from handlers filtering below INFO).
    logger.error(message)
    st.error(message)  # Streamlit UI output
    st.stop()  # Stops Streamlit app