ToxicTweet-Tagger / operations /register_model.py
Subi003's picture
Upload folder using huggingface_hub
4c01182 verified
# Register model in the Mlflow model registry and transition it to the "Staging" stage.
import sys, os
import json
import mlflow
import dagshub
from pathlib import Path
from src.core.logger import logging
from src.core.exception import AppException
from src.core.configuration import AppConfiguration
from dotenv import load_dotenv
# Load variables from a .env file (if one is present) before reading the environment.
load_dotenv()

# Tracking-server URI and DagsHub credentials all come from environment variables.
# Any of them is None when the corresponding variable is unset.
uri = os.environ.get("MLFLOW_URI")
dagshub_token = os.environ.get("DAGSHUB_TOKEN")
dagshub_username = os.environ.get("OWNER")

# Both credentials are mandatory: stop at import time if either is missing or empty.
if not (dagshub_token and dagshub_username):
    raise EnvironmentError("Dagshub environment variables is not set")

# Export the credentials under the variable names MLflow reads for authentication.
os.environ.update(
    MLFLOW_TRACKING_USERNAME=dagshub_username,
    MLFLOW_TRACKING_PASSWORD=dagshub_token,
)

# NOTE(review): `uri` is passed through unchecked, so it may be None here
# (hence the type-ignore) -- confirm that is the intended fallback.
mlflow.set_tracking_uri(uri)  # type: ignore
# For local use
# ==================================================================================
# repo_owner = os.getenv("OWNER")
# repo_name = os.getenv("REPO")
# mlflow.set_tracking_uri(uri)
#
# if repo_owner is None:
# raise EnvironmentError("Missing dagshub logging environment credentials.")
# dagshub.init(repo_owner=repo_owner, repo_name=repo_name, mlflow=True)
# ===================================================================================
class ModelRegistration:
    """
    Registers a trained model in the Mlflow model registry.

    Holds the model registration configuration and offers helpers to read the
    saved experiment information and to register the model and transition it
    to the "Staging" stage.
    """

    def __init__(self, config: "AppConfiguration | None" = None):
        """
        Initializes the ModelRegistration object by creating the model registration configuration.

        Args:
            config: Application configuration object exposing 'model_registration_config()'.
                When omitted, a new AppConfiguration is created for this instance.

        Raises:
            AppException: If error occurs during creation of the model registration configuration
        """
        try:
            # Build the default here rather than in the signature: a default written
            # in the signature is evaluated once at import time, shared by every
            # instance, and any failure would bypass the AppException wrapping below.
            if config is None:
                config = AppConfiguration()
            self.registration_config = config.model_registration_config()

        except Exception as e:
            logging.error(f"Failed to create model registration configuration: {e}", exc_info=True)
            raise AppException(e, sys) from e

    def load_exp_info(self, exp_info_file: Path):
        """
        Loads the experiment information from the saved JSON file.

        Args:
            exp_info_file (Path): The file path to the 'experiment_info.json' file.

        Returns:
            The decoded JSON content of the file (a dict is expected by the caller).

        Raises:
            AppException: If the file does not exist or cannot be read/decoded.
        """
        try:
            if not os.path.exists(exp_info_file):
                raise FileNotFoundError("'experiment_info.json' file not found")

            # Explicit encoding so decoding does not depend on the platform locale.
            with open(exp_info_file, 'r', encoding="utf-8") as file:
                exp_info = json.load(file)
            return exp_info

        except Exception as e:
            logging.error(f"Failed to get model experiment information: {e}", exc_info=True)
            raise AppException(e, sys) from e

    def register(self, experiment_info: dict, register_modelname: str) -> None:
        """
        Registers a model in Mlflow using the provided experiment information and model name.
        Transitions the registered model version to the "Staging" stage.

        Args:
            experiment_info (dict): dictionary containing experiment details:
                'run_id' (id of the run that logged the model) and
                'model' (artifact path of the model inside that run).
            register_modelname (str): name under which the model is registered in Mlflow.

        Raises:
            AppException: If a required key is missing or any Mlflow call fails.
        """
        try:
            # "runs:/<run_id>/<artifact_path>" points at the model logged by that run
            model_uri = f"runs:/{experiment_info['run_id']}/{experiment_info['model']}"

            logging.info("Registering model in Mlflow")
            model_version = mlflow.register_model(model_uri, register_modelname)

            # Transition the model to "Staging" stage
            # NOTE(review): model stages are marked deprecated in newer Mlflow
            # releases -- confirm against the Mlflow version this project pins.
            client = mlflow.MlflowClient()
            client.transition_model_version_stage(
                name=register_modelname,
                version=model_version.version,
                stage="Staging",
            )
            logging.info(f"{register_modelname} - version : {model_version.version} registered and transitioned to Staging")

        except Exception as e:
            logging.error(f"Error during model registration in Mlflow: {e}", exc_info=True)
            raise AppException(e, sys) from e
def register_model(register_modelname: str = "ToxicTagger-Models"):
    """
    Main function to handle the model registration process in Mlflow.

    Initializes a ModelRegistration object, loads experiment information from a JSON file,
    and registers the model in Mlflow. The registered model is transitioned to the "Staging" stage.

    Args:
        register_modelname (str): name under which the model is registered in Mlflow.
            Defaults to "ToxicTagger-Models".

    Raises:
        AppException: If an error occurs during model registration, including when the
            loaded experiment information is not a dict.
    """
    obj = ModelRegistration()
    try:
        logging.info(f"{'='*20}Model Registration{'='*20}")
        exp_info_filepath = obj.registration_config.experiment_info_filepath
        experiment_info = obj.load_exp_info(exp_info_filepath)

        # Fail loudly: returning here would let the run finish "successfully"
        # without any model having been registered.
        if not isinstance(experiment_info, dict):
            raise TypeError(
                f"Expected experiment info to be a dict, got {type(experiment_info).__name__}"
            )

        obj.register(experiment_info, register_modelname)
        logging.info(f"{'='*20}Model Registration in Mlflow Completed Successfully{'='*20} \n\n")

    except AppException:
        # Already logged and wrapped by the ModelRegistration method that failed;
        # pass it through instead of logging and wrapping it a second time.
        raise
    except Exception as e:
        logging.error(f"Error during model registration in Mlflow: {e}", exc_info=True)
        raise AppException(e, sys) from e


if __name__ == "__main__":
    register_model()