File size: 2,978 Bytes
17c5137 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | from src.entity.config_entity import ModelPusherConfig
from src.entity import artifact_entity
from src.predictor import ModelResolver
from src.exception import CropException
from src.logger import logging
from src.utils import load_object, save_object
from src.entity.artifact_entity import (
DataTransformationArtifact,
ModelTrainerArtifact,
ModelPusherArtifact,
)
import sys
import os
class ModelPusher:
    """Publish the trained pipeline objects to the pusher and saved-model locations.

    The transformer, model and target encoder referenced by the upstream
    artifacts are loaded with ``load_object`` and written out with
    ``save_object`` twice: once to the paths held by ``ModelPusherConfig`` and
    once to the paths returned by ``ModelResolver``.
    """

    def __init__(
        self,
        model_pusher_config: ModelPusherConfig,
        data_transformation_artifact: DataTransformationArtifact,
        model_trainer_artifact: ModelTrainerArtifact,
    ) -> None:
        """Store the configuration and upstream artifacts and build the resolver.

        Args:
            model_pusher_config: Supplies the pusher output paths and
                ``saved_model_dir``.
            data_transformation_artifact: Supplies ``transform_object_path``
                and ``target_encoder_path``.
            model_trainer_artifact: Supplies ``model_path``.

        Raises:
            CropException: If anything in the setup raises; the original
                exception is attached as the cause.
        """
        try:
            logging.info(f"{'>'*20} Model Pusher Initiated {'<'*30}")
            self.model_pusher_config = model_pusher_config
            self.data_transformation_artifact = data_transformation_artifact
            self.model_trainer_artifact = model_trainer_artifact
            # The resolver is rooted at the saved-model directory from the config.
            self.model_resolver = ModelResolver(
                model_registry=self.model_pusher_config.saved_model_dir
            )
        except Exception as e:
            # Chain explicitly so the traceback shows the root cause directly.
            raise CropException(e, sys) from e

    def initiate_model_pusher(self) -> ModelPusherArtifact:
        """Load the three pipeline objects and save them to both destinations.

        Returns:
            ModelPusherArtifact built from the config's ``pusher_model_dir``
            and ``saved_model_dir``.

        Raises:
            CropException: If loading, path resolution or saving raises; the
                original exception is attached as the cause.
        """
        try:
            # Load the objects produced by the transformation and training stages.
            logging.info("Loading transformer model and target encoder")
            transformer = load_object(
                file_path=self.data_transformation_artifact.transform_object_path
            )
            model = load_object(file_path=self.model_trainer_artifact.model_path)
            target_encoder = load_object(
                file_path=self.data_transformation_artifact.target_encoder_path
            )

            # First destination: the pusher paths from the config.
            logging.info("Saving model into model pusher directory")
            save_object(
                file_path=self.model_pusher_config.pusher_transformer_path,
                obj=transformer,
            )
            save_object(
                file_path=self.model_pusher_config.pusher_model_path,
                obj=model,
            )
            save_object(
                file_path=self.model_pusher_config.pusher_target_encoder_path,
                obj=target_encoder,
            )

            # Second destination: paths chosen by the resolver. All three are
            # resolved before any write happens, as in the original ordering.
            # NOTE(review): presumably these point at a new versioned folder
            # under saved_model_dir - confirm against ModelResolver.
            logging.info("Saving model in saved model dir")
            transformer_path = self.model_resolver.get_latest_save_transformer_path()
            model_path = self.model_resolver.get_latest_save_model_path()
            target_encoder_path = (
                self.model_resolver.get_latest_save_target_encoder_path()
            )
            save_object(file_path=transformer_path, obj=transformer)
            save_object(file_path=model_path, obj=model)
            save_object(file_path=target_encoder_path, obj=target_encoder)

            model_pusher_artifact = ModelPusherArtifact(
                pusher_model_dir=self.model_pusher_config.pusher_model_dir,
                saved_model_dir=self.model_pusher_config.saved_model_dir,
            )
            logging.info(f"Model Pusher artifact: {model_pusher_artifact}")
            return model_pusher_artifact
        except Exception as e:
            # Chain explicitly so the traceback shows the root cause directly.
            raise CropException(e, sys) from e
|