# Provenance (was stray Hugging Face page text): uploaded by SoggyBurritos, "Upload 12 files", commit 2d39721 (verified)
"""This module contains utility functions for input conversion and validation."""
import os
import logging
from logging import Logger # For type hinting
import json
import joblib
import streamlit as st
import torch
from .consts import (
FEATURE_NAMES,
CATEGORY_MAPPING,
GENDER_MAPPING,
STATE_MAPPING,
INPUT_METADATA,
STREAMLIT_VALIDATED,
MODEL_WEIGHTS_FULL_PATH,
CONFIG_PATH,
FEATURE_SCALER_PATH,
)
from .model import Agent
def setup_logger(config: dict, propogate: bool = False) -> Logger:
    """Sets up and returns a named logger based on the provided config dictionary.

    Handlers (console and/or file) are attached according to the config. Any
    handlers already present on the logger — inherited or left over from a
    previous call — are closed and removed first, so repeated calls do not
    duplicate output and a "no handlers" config really disables logging.

    Args:
        config (dict): Dictionary containing logging configuration. Recognized
            keys (all optional): "logger_name", "log_to_file", "log_file",
            "log_level", "log_mode", "log_format", "date_format",
            "log_to_console".
        propogate (bool): Whether to allow log messages to propagate to
            ancestor loggers. (Parameter name kept for backward compatibility.)

    Returns:
        Logger: Configured logger instance.
    """
    logger_name = config.get("logger_name", "main")
    log_to_file = config.get("log_to_file", True)  # Set whether to log to a logfile or not
    log_file = config.get("log_file", "logs/app.log")  # Get the log file path
    log_lvl = config.get("log_level", "INFO")
    log_level = getattr(logging, log_lvl.upper(), logging.INFO)  # Fall back to INFO on invalid input
    log_mode = config.get("log_mode", "w")  # Set the log file mode
    log_format = config.get("log_format", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    date_format = config.get("date_format", "%Y-%m-%d %H:%M:%S")
    log_to_console = config.get("log_to_console", True)  # Set whether to log to console or not

    logger = logging.getLogger(logger_name)  # Create logger object with the specified name

    # Always clear stale handlers first (previously this only happened when at
    # least one output was enabled, so old handlers could survive or stack).
    for handler in logger.handlers[:]:
        handler.close()
        logger.removeHandler(handler)

    if not log_to_file and not log_to_console:
        # If no handlers are specified by the config
        print(
            f"Warning: No logging handlers configured for {logger_name}.\nVerbose Logging will be disabled.\nIn 'config/config.json', set ['log_to_file': true] or ['log_to_console': true] if you want to change the logging behavior.",
            flush=True,
        )
    else:
        handlers = []  # Handlers to attach, built below
        formatter = logging.Formatter(
            fmt=log_format, datefmt=date_format
        )  # Create a formatter for the log messages
        if log_to_console:
            console_handler = logging.StreamHandler()  # Send log messages to the console
            console_handler.setFormatter(formatter)
            handlers.append(console_handler)
        if log_to_file:
            # Only touch the filesystem when file logging is actually enabled
            # (previously the directory was created even for console-only setups).
            parent_dir = os.path.dirname(log_file)  # Parent directory of the log file
            if parent_dir and parent_dir != ".":
                try:
                    os.makedirs(name=parent_dir, exist_ok=True)
                    print(
                        f"Parent directory '{parent_dir}' used to store the log file.", flush=True
                    )  # flush=True to ensure the message is printed immediately
                except OSError as e:
                    print(
                        f"Error creating directory '{parent_dir}': {e} INFO: Using default log file 'app.log' instead.",
                        flush=True,
                    )
                    log_file = "app.log"  # Fall back to a default log file if problem occurs.
            file_handler = logging.FileHandler(
                filename=log_file, mode=log_mode, encoding="utf-8"
            )  # utf-8 encoding enables emoji use in log messages
            file_handler.setFormatter(formatter)
            handlers.append(file_handler)
        # Add the handlers to the logger
        for handler in handlers:
            logger.addHandler(handler)

    logger.setLevel(log_level)  # Set logger minimum log level
    logger.propagate = propogate  # By default, don't bubble messages up to the root logger's handlers
    return logger
def convert_inputs(**kwargs) -> list:
    """Convert user inputs into a list of features for the model.

    Args:
        **kwargs: Dictionary of user inputs (e.g., {'category': 'entertainment', 'amt': 25.0, ...})
    Returns:
        features: A list of converted float features, ordered as FEATURE_NAMES.
    """
    # Mapped categorical features share identical handling; dispatch on name
    # instead of duplicating the same branch per feature (was copy-pasted 3x).
    mapped_features = {
        "category": CATEGORY_MAPPING,
        "gender": GENDER_MAPPING,
        "state": STATE_MAPPING,
        # ... Add other mapped fields here
    }
    features = []  # Create empty list to store all the features
    for feature_name in FEATURE_NAMES:  # Loop through FEATURE_NAMES
        try:
            # Get the value from the kwargs dictionary
            value = kwargs.get(feature_name)
            # Perform validation (using metadata where possible)
            if value is None:
                raise ValueError(f"Missing required input: {feature_name}")
            # --- Mapped Features ---
            if feature_name in mapped_features:
                mapped_value = mapped_features[feature_name].get(value)
                if mapped_value is None:
                    raise ValueError(f"{feature_name}; value={value}; no mapping.")
                if not isinstance(mapped_value, float):
                    raise ValueError(f"{feature_name} must be a float.")
                features.append(mapped_value)
            # --- Streamlit-Validated Features ---
            elif feature_name in STREAMLIT_VALIDATED:
                # Use INPUT_METADATA for range validation
                meta = INPUT_METADATA.get(feature_name, {})
                min_v = meta.get("min_value")
                max_v = meta.get("max_value")
                if min_v is not None and max_v is not None and not (min_v <= value <= max_v):
                    raise ValueError(f"{feature_name} out of expected range.")
                features.append(float(value))  # Convert to float
            # Default action if not covered by logic above
            else:
                raise ValueError(f"No conversion for {feature_name}")
        except ValueError as e:
            log_and_stop(f"Validation Error for {feature_name}: {e}")
        except Exception as e:  # Catch all other exceptions
            log_and_stop(f"An unexpected fatal error occurred: {e}")
    # Verify final length
    if len(features) != len(FEATURE_NAMES):
        log_and_stop(
            f"Fatal Error: Final feature list length mismatch. Created list size: {len(features)} | Expected list size: {len(FEATURE_NAMES)}"
        )
    return features
@st.cache_data
def load_config():
    """Loads configuration file using global variable. Optimized using streamlit caching.

    Args:
        N/A
    Returns:
        config (dict): the python dictionary containing configuration data
    """
    error_text = ""  # Populated by the except branches below when loading fails
    try:
        with open(CONFIG_PATH, "r", encoding="utf-8") as cfg_file:
            return json.load(cfg_file)  # Success: hand back the parsed dict directly
    except FileNotFoundError:
        # For streamlit to acknowledge the '\n' character as a newline use ' \n'. Streamlit processes strings as Markdown
        error_text = (
            f"❌ Configuration file not found at '{CONFIG_PATH}'. "
            "\nPlease ensure the file exists or fix path to file."
        )
    except json.JSONDecodeError as e:
        error_text = f"❌ Failed to parse JSON: {e}."
    except Exception as e:  # Catch all other exceptions
        error_text = f"An unexpected fatal error occurred: {e}"
    finally:
        # Runs on every path; only acts when one of the except branches fired.
        if error_text:
            error_text += " \nStopping Execution."  # Add the common suffix
            print(error_text)
            st.error(error_text)
            st.stop()  # Raises internally, halting the Streamlit script
@st.cache_resource
def load_model(_logger: Logger):
    """Helper function that loads the model's architecture and instantiates a model with its trained weights. Optimized using streamlit caching.

    Args:
        _logger (Logger): The logger instance to log messages. Use underscore to prevent hashing by Streamlit.
    Returns:
        Agent (torch.nn.Module): Returns agent to cpu in evaluation mode.
    """
    message = ""  # Initialize variable in case of errors
    try:
        # weights_only=True avoids unpickling arbitrary objects from the file
        model_weights = torch.load(MODEL_WEIGHTS_FULL_PATH, weights_only=True)
        # NOTE(review): 'βœ…' looks like a mojibake'd ✅ — verify the file's encoding
        _logger.info(f"βœ… Model weights loaded successfully from {MODEL_WEIGHTS_FULL_PATH}")
    except FileNotFoundError:
        message = f"❌ Model Weights file not found at '{MODEL_WEIGHTS_FULL_PATH}'. \nPlease ensure the file exists."
        log_and_stop(message)
    except Exception as e:
        # A corrupt/incompatible weights file previously escaped as a raw
        # traceback; route it through the same fatal-error path instead.
        message = f"An unexpected fatal error occurred: {e}"
        log_and_stop(message)
    CONFIG = load_config()
    MODEL_CONFIG = CONFIG.get("model", {})
    try:
        agent = Agent(cfg=MODEL_CONFIG)  # Create agent instance
        agent.load_state_dict(state_dict=model_weights)
    except RuntimeError as e:
        message = f"❌ A runtime error occurred while creating model or loading model weights: {e}"
    except FileNotFoundError as e:
        message = f"❌ Model weights file not found: {e}"
    except KeyError as e:
        message = f"❌ Missing key in model configuration: {e}"
    except Exception as e:  # Catch all other exceptions
        message = f"An unexpected fatal error occurred: {e}"
    # Execute if no exception was caught
    else:
        return agent.eval().to("cpu")
    # If exception was thrown continue to the finally block
    finally:
        if message:
            log_and_stop(message)
@st.cache_data
def load_feature_scaler(_logger: Logger):
    """Loads the feature scaler using the global variable. Optimized using streamlit caching.

    Args:
        _logger (Logger): The logger instance to log messages. Use underscore to prevent hashing by Streamlit.
    Returns:
        feature_scaler: the loaded scaler object
    """
    failure_msg = ""  # Set by the except branches below when loading fails
    # Load feature scaler
    try:
        scaler = joblib.load(FEATURE_SCALER_PATH)
        _logger.info(f"βœ… Feature Scaler loaded successfully from {FEATURE_SCALER_PATH}")
    except FileNotFoundError:
        failure_msg = (
            f"❌ Scaler file not found at '{FEATURE_SCALER_PATH}'. "
            "\nPlease ensure the file exists or fix path to file."
        )
    except Exception as e:  # Any other failure is treated as fatal
        failure_msg = f"An unexpected fatal error occurred: {e}"
    else:
        # Success path: no exception was raised above
        return scaler
    finally:
        # Runs on every path; only acts when an except branch fired
        if failure_msg:
            log_and_stop(failure_msg)
@st.cache_data
def load_label_scaler(_logger: Logger):
    """Loads the label scaler using the global variable. Optimized using streamlit caching.

    Args:
        _logger (Logger): The logger instance to log messages. Use underscore to prevent hashing by Streamlit.
    Returns:
        label_scaler: always None — no label scaler is used in this implementation.
    """
    # No label scaling is performed in this implementation; keep the hook
    # so callers have a consistent loader API.
    return None
def log_and_stop(message: str):
    """Helper function to log relevant messages. Handles message and exits the program.

    Args:
        message (str): The message to log and display
    Returns:
        N/A (st.stop() raises internally to halt the Streamlit script)
    """
    logger_name = load_config()["logging"]["logger_name"]
    logger = logging.getLogger(logger_name)
    message += " \nStopping Execution."  # Add the common suffix
    # These are fatal/stop conditions: previously logged at INFO level, which a
    # WARNING-or-higher configured log_level would silently drop.
    logger.error(message)  # Console/file output via the configured handlers
    st.error(message)  # Streamlit UI output
    st.stop()  # Stops Streamlit app