"""This module contains utility functions for input conversion and validation."""
| import os | |
| import logging | |
| from logging import Logger # For type hinting | |
| import json | |
| import joblib | |
| import streamlit as st | |
| import torch | |
| from .consts import ( | |
| FEATURE_NAMES, | |
| CATEGORY_MAPPING, | |
| GENDER_MAPPING, | |
| STATE_MAPPING, | |
| INPUT_METADATA, | |
| STREAMLIT_VALIDATED, | |
| MODEL_WEIGHTS_FULL_PATH, | |
| CONFIG_PATH, | |
| FEATURE_SCALER_PATH, | |
| ) | |
| from .model import Agent | |
def _ensure_log_dir(log_file: str) -> str:
    """Create the log file's parent directory; fall back to 'app.log' on failure.

    Args:
        log_file (str): Desired log file path.

    Returns:
        str: The usable log file path (original, or 'app.log' if the parent
        directory could not be created).
    """
    parent_dir = os.path.dirname(log_file)  # Parent directory of the log file
    if parent_dir and parent_dir != ".":
        try:
            os.makedirs(name=parent_dir, exist_ok=True)
            print(
                f"Parent directory '{parent_dir}' used to store the log file.", flush=True
            )  # flush=True to ensure the message is printed immediately
        except OSError as e:
            print(
                f"Error creating directory '{parent_dir}': {e} INFO: Using default log file 'app.log' instead.",
                flush=True,
            )
            log_file = "app.log"  # Fall back to a default log file if problem occurs.
    return log_file


def setup_logger(config: dict, propogate: bool = False) -> Logger:
    """Sets up and returns a named logger based on the provided config dictionary.

    The logger gets a console handler and/or a file handler depending on the
    config; any handlers previously attached to the same named logger are
    closed and removed first so repeated calls do not duplicate output.

    Args:
        config (dict): Dictionary containing logging configuration.
        propogate (bool): Whether to allow log messages to propagate to
            ancestor loggers. (Name kept misspelled for backward
            compatibility with existing keyword-argument callers.)

    Returns:
        Logger: Configured logger instance.
    """
    logger_name = config.get("logger_name", "main")
    log_to_file = config.get("log_to_file", True)  # Whether to log to a logfile
    log_file = config.get("log_file", "logs/app.log")  # Log file path
    log_lvl = config.get("log_level", "INFO")
    log_level = getattr(logging, log_lvl.upper(), logging.INFO)  # Fallback to INFO if invalid input
    log_mode = config.get("log_mode", "w")  # Log file open mode
    log_format = config.get("log_format", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    date_format = config.get("date_format", "%Y-%m-%d %H:%M:%S")
    log_to_console = config.get("log_to_console", True)  # Whether to log to console
    logger = logging.getLogger(logger_name)  # Create/fetch logger with the specified name
    if not log_to_file and not log_to_console:
        # No handlers were requested by the config — warn on stdout instead
        print(
            f"Warning: No logging handlers configured for {logger_name}.\nVerbose Logging will be disabled.\nIn 'config/config.json', set ['log_to_file': true] or ['log_to_console': true] if you want to change the logging behavior.",
            flush=True,
        )
    else:
        # Remove all old handlers inherited from previous configurations
        for handler in logger.handlers[:]:
            handler.close()
            logger.removeHandler(handler)
        formatter = logging.Formatter(fmt=log_format, datefmt=date_format)
        if log_to_console:
            console_handler = logging.StreamHandler()  # Sends log messages to the console
            console_handler.setFormatter(formatter)
            logger.addHandler(console_handler)
        if log_to_file:
            # BUGFIX: only create the log directory when file logging is
            # actually enabled (previously the directory was created even
            # when log_to_file was False).
            log_file = _ensure_log_dir(log_file)
            file_handler = logging.FileHandler(
                filename=log_file, mode=log_mode, encoding="utf-8"
            )  # utf-8 enables emoji use in the log file
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
    logger.setLevel(log_level)  # Set logger minimum log level
    # Prevent messages from reaching the root logger's default handlers unless requested
    logger.propagate = propogate
    return logger
def convert_inputs(**kwargs) -> list:
    """Convert user inputs into a list of features for the model.

    Args:
        **kwargs: Dictionary of user inputs (e.g., {'category': 'entertainment', 'amt': 25.0, ...})

    Returns:
        features: A list of converted features ready for model input.
    """
    # BUGFIX/refactor: the category/gender/state branches were three
    # byte-identical copies of the same mapping logic; one lookup table
    # replaces them and new mapped fields only need a new entry here.
    mapped_features = {
        "category": CATEGORY_MAPPING,
        "gender": GENDER_MAPPING,
        "state": STATE_MAPPING,
        # ... Add other mapped fields here
    }
    features = []  # Accumulates the converted feature values in FEATURE_NAMES order
    for feature_name in FEATURE_NAMES:
        try:
            value = kwargs.get(feature_name)
            if value is None:
                raise ValueError(f"Missing required input: {feature_name}")
            # --- Mapped Features ---
            if feature_name in mapped_features:
                mapped_value = mapped_features[feature_name].get(value, None)
                if mapped_value is not None:
                    if not isinstance(mapped_value, float):
                        raise ValueError(f"{feature_name} must be a float.")
                    features.append(mapped_value)
                else:
                    raise ValueError(f"{feature_name}; value={value}; no mapping.")
            # --- Streamlit-Validated Features ---
            elif feature_name in STREAMLIT_VALIDATED:
                # Use INPUT_METADATA for range validation
                meta = INPUT_METADATA.get(feature_name, {})
                min_v = meta.get("min_value")
                max_v = meta.get("max_value")
                if min_v is not None and max_v is not None and not (min_v <= value <= max_v):
                    raise ValueError(f"{feature_name} out of expected range.")
                features.append(float(value))  # Convert to float
            # Default action if not covered by logic above
            else:
                raise ValueError(f"No conversion for {feature_name}")
        except ValueError as e:
            log_and_stop(f"Validation Error for {feature_name}: {e}")
        except Exception as e:  # Catch all other exceptions
            log_and_stop(f"An unexpected fatal error occurred: {e}")
    # Verify final length matches the expected feature count
    if len(features) != len(FEATURE_NAMES):
        log_and_stop(
            f"Fatal Error: Final feature list length mismatch. Created list size: {len(features)} | Expected list size: {len(FEATURE_NAMES)}"
        )
    return features
def load_config():
    """Load and parse the JSON configuration file at CONFIG_PATH.

    Intended to be wrapped with Streamlit caching by the caller.

    Returns:
        config (dict): The parsed configuration dictionary. On any failure
        the error is printed, shown in the Streamlit UI, and execution stops.
    """
    message = ""  # Stays empty unless an error occurs below
    try:
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            return json.load(f)  # Success: hand the parsed dict straight back
    except FileNotFoundError:
        # For streamlit to acknowledge the '\n' character as a newline use ' \n'. Streamlit processes strings as Markdown
        message = f"β Configuration file not found at '{CONFIG_PATH}'. \nPlease ensure the file exists or fix path to file."
    except json.JSONDecodeError as e:
        message = f"β Failed to parse JSON: {e}."
    except Exception as e:  # Catch all other exceptions
        message = f"An unexpected fatal error occurred: {e}"
    # Only reachable when one of the except blocks populated 'message'
    message += " \nStopping Execution."  # Add the common suffix
    print(message)
    st.error(message)
    st.stop()
def load_model(_logger: Logger):
    """Load the model architecture and instantiate it with its trained weights.

    Intended to be wrapped with Streamlit caching by the caller.

    Args:
        _logger (Logger): The logger instance to log messages. Use underscore to prevent hashing by Streamlit.

    Returns:
        Agent (torch.nn.Module): Returns agent to cpu in evaluation mode.
    """
    # Fetch the serialized weights first; a missing file is fatal.
    try:
        model_weights = torch.load(MODEL_WEIGHTS_FULL_PATH, weights_only=True)
        _logger.info(f"β Model weights loaded successfully from {MODEL_WEIGHTS_FULL_PATH}")
    except FileNotFoundError:
        log_and_stop(
            f"β Model Weights file not found at '{MODEL_WEIGHTS_FULL_PATH}'. \nPlease ensure the file exists."
        )
    model_cfg = load_config().get("model", {})  # Architecture hyperparameters
    message = ""  # Stays empty unless an error occurs below
    try:
        agent = Agent(cfg=model_cfg)  # Build the network, then restore weights
        agent.load_state_dict(state_dict=model_weights)
    except RuntimeError as e:
        message = f"β A runtime error occurred while creating model or loading model weights: {e}"
    except FileNotFoundError as e:
        message = f"β Model weights file not found: {e}"
    except KeyError as e:
        message = f"β Missing key in model configuration: {e}"
    except Exception as e:  # Catch all other exceptions
        message = f"An unexpected fatal error occurred: {e}"
    else:
        # Inference-only use: eval mode, pinned to CPU
        return agent.eval().to("cpu")
    # Only reachable when one of the except blocks populated 'message'
    log_and_stop(message)
def load_feature_scaler(_logger: Logger):
    """Load the fitted feature scaler from FEATURE_SCALER_PATH.

    Intended to be wrapped with Streamlit caching by the caller.

    Args:
        _logger (Logger): The logger instance to log messages. Use underscore to prevent hashing by Streamlit.

    Returns:
        feature_scaler: The loaded scaler object.
    """
    message = ""  # Stays empty unless an error occurs below
    try:
        feature_scaler = joblib.load(FEATURE_SCALER_PATH)
        _logger.info(f"β Feature Scaler loaded successfully from {FEATURE_SCALER_PATH}")
    except FileNotFoundError:
        message = f"β Scaler file not found at '{FEATURE_SCALER_PATH}'. \nPlease ensure the file exists or fix path to file."
    except Exception as e:  # Catch all other exceptions
        message = f"An unexpected fatal error occurred: {e}"
    else:
        return feature_scaler
    # Only reachable when one of the except blocks populated 'message'
    log_and_stop(message)
def load_label_scaler(_logger: Logger):
    """Placeholder for loading a label scaler.

    Label scaling is not used in this implementation, so nothing is loaded.
    Kept so the loader API mirrors load_feature_scaler.

    Args:
        _logger (Logger): The logger instance to log messages. Use underscore to prevent hashing by Streamlit.

    Returns:
        label_scaler: Always None in this implementation.
    """
    return None
def log_and_stop(message: str):
    """Helper function to log relevant messages. Handles message and exits the program.

    Args:
        message (str): The message to log and display

    Returns:
        N/A
    """
    # BUGFIX: use .get() with defaults so a config missing the 'logging'
    # section cannot raise KeyError while we are already reporting a
    # different fatal error.
    logger_name = load_config().get("logging", {}).get("logger_name", "main")
    logger = logging.getLogger(logger_name)
    message += " \nStopping Execution."  # Add the common suffix
    # BUGFIX: this path is only taken for fatal problems, so log at ERROR
    # (previously INFO, which file/console handlers above INFO would drop).
    logger.error(message, exc_info=False, stack_info=False)  # Console output
    st.error(message)  # Streamlit UI output
    st.stop()  # Stops Streamlit app