"""Register a model in the MLflow model registry and transition it to the "Staging" stage."""

import sys
import os
import json
from pathlib import Path
from typing import Optional

import mlflow
import dagshub  # noqa: F401  -- only used by the commented-out local setup below
from dotenv import load_dotenv

from src.core.logger import logging
from src.core.exception import AppException
from src.core.configuration import AppConfiguration

load_dotenv()

# Registry name used by register_model() when no name is passed in.
REGISTERED_MODEL_NAME = "ToxicTagger-Models"

# get environment variables
uri = os.getenv("MLFLOW_URI")
dagshub_token = os.getenv("DAGSHUB_TOKEN")
dagshub_username = os.getenv("OWNER")

if not dagshub_token or not dagshub_username:
    _missing = [
        name
        for name, value in (("DAGSHUB_TOKEN", dagshub_token), ("OWNER", dagshub_username))
        if not value
    ]
    raise EnvironmentError(f"Dagshub environment variables are not set: {', '.join(_missing)}")

# MLflow reads these to authenticate against the remote tracking server.
os.environ["MLFLOW_TRACKING_USERNAME"] = dagshub_username
os.environ["MLFLOW_TRACKING_PASSWORD"] = dagshub_token

if uri:
    mlflow.set_tracking_uri(uri)
else:
    # Without MLFLOW_URI, MLflow keeps its default tracking URI; make that visible.
    logging.warning("MLFLOW_URI is not set; using MLflow's default tracking URI")

# For local use
# ==================================================================================
# repo_owner = os.getenv("OWNER")
# repo_name = os.getenv("REPO")
# mlflow.set_tracking_uri(uri)
#
# if repo_owner is None:
#     raise EnvironmentError("Missing dagshub logging environment credentials.")
# dagshub.init(repo_owner=repo_owner, repo_name=repo_name, mlflow=True)
# ===================================================================================


class ModelRegistration:
    """Loads experiment information and registers the resulting model in MLflow."""

    def __init__(self, config: Optional[AppConfiguration] = None):
        """
        Initializes the ModelRegistration object by creating a model registration configuration.

        Args:
            config (AppConfiguration, optional): Application configuration providing
                ``model_registration_config()``. A fresh ``AppConfiguration()`` is
                created for this instance when omitted.

        Raises:
            AppException: If an error occurs while creating the model registration configuration.
        """
        try:
            # Build the default here, not in the signature: a signature default would be
            # evaluated once at import time (outside this try block) and shared by all instances.
            if config is None:
                config = AppConfiguration()
            self.registration_config = config.model_registration_config()
        except Exception as e:
            logging.error(f"Failed to create model registration configuration: {e}", exc_info=True)
            raise AppException(e, sys) from e

    def load_exp_info(self, exp_info_file: Path):
        """
        Loads the experiment information from the saved JSON file.

        Args:
            exp_info_file (Path): Path to the 'experiment_info.json' file (a str path also works).

        Returns:
            The parsed JSON content, expected to be a dict. Whatever ``json.load`` produces
            is returned unchanged, so callers should verify the type.

        Raises:
            AppException: If the file does not exist or cannot be read or parsed.
        """
        try:
            exp_info_path = Path(exp_info_file)
            if not exp_info_path.exists():
                raise FileNotFoundError(f"'experiment_info.json' file not found: {exp_info_path}")
            with exp_info_path.open("r", encoding="utf-8") as file:
                return json.load(file)
        except Exception as e:
            logging.error(f"Failed to get model experiment information: {e}", exc_info=True)
            raise AppException(e, sys) from e

    def register(self, experiment_info: dict, register_modelname: str):
        """
        Registers a model in MLflow and transitions the new version to the "Staging" stage.

        Args:
            experiment_info (dict): Experiment details. It must contain 'run_id' (the MLflow run
                that logged the model) and 'model' (the model's artifact path within that run;
                it is appended to the ``runs:/`` URI).
            register_modelname (str): Name under which the model is registered in MLflow.

        Raises:
            AppException: If required keys are missing, or registration or stage transition fails.
        """
        try:
            missing_keys = [key for key in ("run_id", "model") if key not in experiment_info]
            if missing_keys:
                raise KeyError(f"experiment info is missing required keys: {missing_keys}")

            model_uri = f"runs:/{experiment_info['run_id']}/{experiment_info['model']}"
            logging.info("Registering model in Mlflow")
            model_version = mlflow.register_model(model_uri, register_modelname)

            # Transition the model to "Staging" stage.
            # NOTE(review): model version stages are deprecated in recent MLflow releases in
            # favour of aliases (MlflowClient.set_registered_model_alias); kept as-is because
            # downstream steps may look the model up by stage.
            client = mlflow.MlflowClient()
            client.transition_model_version_stage(
                name=register_modelname,
                version=model_version.version,
                stage="Staging",
            )
            logging.info(
                f"{register_modelname} - version : {model_version.version} "
                f"registered and transitioned to Staging"
            )
        except Exception as e:
            logging.error(f"Error during model registration in Mlflow: {e}", exc_info=True)
            raise AppException(e, sys) from e


def register_model(register_modelname: str = REGISTERED_MODEL_NAME):
    """
    Main function to handle the model registration process in MLflow.

    Creates a ModelRegistration object, loads the experiment information from the JSON file
    named by the registration config, registers the model in MLflow and transitions the new
    version to the "Staging" stage.

    Args:
        register_modelname (str): Registry name to register the model under.
            Defaults to "ToxicTagger-Models".

    Raises:
        AppException: If the configuration cannot be created, the experiment information is
            not a dict, or the model registration fails.
    """
    obj = ModelRegistration()
    try:
        logging.info(f"{'='*20}Model Registration{'='*20}")
        exp_info_filepath = obj.registration_config.experiment_info_filepath
        experiment_info = obj.load_exp_info(exp_info_filepath)

        # Raise instead of returning quietly, so the script exits with an error
        # when the experiment information is malformed.
        if not isinstance(experiment_info, dict):
            raise TypeError("'register_model' function expects dict type object ")

        obj.register(experiment_info, register_modelname)
        logging.info(f"{'='*20}Model Registration in Mlflow Completed Successfully{'='*20} \n\n")
    except Exception as e:
        logging.error(f"Error during model registration in Mlflow: {e}", exc_info=True)
        raise AppException(e, sys) from e


if __name__ == "__main__":
    register_model()