File size: 1,912 Bytes
a7d80f2
 
7e4b5f8
a7d80f2
7e4b5f8
a7d80f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e4b5f8
a7d80f2
 
 
 
 
 
 
 
 
 
 
 
 
7e4b5f8
 
 
 
 
 
 
 
 
 
 
 
a7d80f2
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import shutil
import os
import json
from pathlib import Path
from datetime import datetime
from mlpipeline.entity import ModelPusherConfig, ModelPusherArtifact
from mlpipeline.logging.logger import get_logger
from mlpipeline.exception import ModelPusherException
import sys

# Module-level logger from the project's logging helper, named after this module.
logger = get_logger(__name__)


class ModelPusher:
    """Push a trained model into the model registry and record push metadata.

    The model at ``config.model_path`` (a single file or a directory) is
    copied into ``config.model_registry_path`` under its own base name, and a
    ``push_metadata.json`` describing the push is written to
    ``config.root_dir``.
    """

    def __init__(self, config: ModelPusherConfig):
        """Store the pusher configuration.

        Args:
            config: Provides ``model_path``, ``model_registry_path`` and
                ``root_dir``, all read by :meth:`push_model`.
        """
        self.config = config

    def push_model(self) -> ModelPusherArtifact:
        """Copy the model into the registry and write ``push_metadata.json``.

        A previous copy of the same model in the registry is replaced.

        Returns:
            A ``ModelPusherArtifact`` whose ``pushed_model_path`` is the
            registry location of the copied model.

        Raises:
            ModelPusherException: Wraps any error raised while pushing, e.g. a
                missing source model or a registry location that overlaps the
                source model. The original error is chained as ``__cause__``.
        """
        try:
            logger.info("Starting model pusher")

            os.makedirs(self.config.model_registry_path, exist_ok=True)
            os.makedirs(self.config.root_dir, exist_ok=True)

            model_source = Path(self.config.model_path)
            model_dest = Path(self.config.model_registry_path) / model_source.name

            self._copy_model(model_source, model_dest)

            logger.info(f"Model pushed to: {model_dest}")

            # Save push metadata; ``pushed_at`` is a naive local-time ISO-8601 timestamp.
            metadata = {
                "pushed_at": datetime.now().isoformat(),
                "source_path": str(model_source),
                "destination_path": str(model_dest),
                "status": "success",
            }

            metadata_path = Path(self.config.root_dir) / "push_metadata.json"
            with open(metadata_path, "w", encoding="utf-8") as f:
                json.dump(metadata, f, indent=4)

            return ModelPusherArtifact(
                pushed_model_path=str(model_dest),
                is_pushed=True,
                message="Model pushed successfully",
            )
        except Exception as e:
            # Chain the original error so its traceback is kept as __cause__.
            raise ModelPusherException(str(e), sys) from e

    @staticmethod
    def _copy_model(source: Path, dest: Path) -> None:
        """Copy a model file or directory from *source* to *dest*.

        Anything already at *dest* (file, directory or symlink) is removed
        first, so *dest* always ends up holding exactly the new model.

        Raises:
            FileNotFoundError: If *source* does not exist.
            ValueError: If *dest* is *source* itself or, for a directory
                model, lies inside it. Clearing *dest* would otherwise delete
                the model being pushed, or ``copytree`` would copy the
                directory into itself.
        """
        if not source.exists():
            raise FileNotFoundError(f"Model not found at: {source}")

        resolved_source = source.resolve()
        resolved_dest = dest.resolve()
        if resolved_dest == resolved_source or (
            source.is_dir() and resolved_source in resolved_dest.parents
        ):
            raise ValueError(
                f"Registry destination {dest} overlaps the source model {source}"
            )

        # Remove a previous push. Symlinks are unlinked rather than followed so
        # nothing is ever written through a link to outside the registry.
        if dest.is_symlink() or dest.is_file():
            dest.unlink()
        elif dest.is_dir():
            shutil.rmtree(dest)

        if source.is_dir():
            shutil.copytree(source, dest)
        else:
            shutil.copy2(source, dest)