File size: 4,274 Bytes
fdadf61
 
 
0df85ee
fdadf61
0df85ee
 
fdadf61
 
 
 
 
 
ab973af
fdadf61
 
 
 
 
ab973af
fdadf61
ab973af
fdadf61
 
 
 
 
 
 
0df85ee
 
 
 
 
ab973af
 
 
 
 
 
0df85ee
 
 
 
 
 
0a36c22
fdadf61
 
 
 
 
 
 
 
 
ab973af
fdadf61
 
 
 
 
 
 
 
 
 
ab973af
0df85ee
 
 
 
fdadf61
 
 
 
 
 
 
 
 
 
 
 
 
ab973af
fdadf61
 
 
 
ab973af
 
 
fdadf61
 
 
 
 
 
 
ab973af
 
 
fdadf61
 
 
ab973af
fdadf61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab973af
fdadf61
ab973af
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import joblib
import json
import logging
import requests
from typing import Any, Optional
from pathlib import Path

from sklearn.pipeline import Pipeline
from core.config import settings
from core.exceptions import ArtifactLoadError

logger = logging.getLogger(__name__)


class ModelArtifacts:
    """
    Singleton-like container for ML artifacts.
    Loaded at startup via load_artifacts(); cleared at shutdown via clear().
    """

    _instance = None

    def __init__(self):
        # Trained sklearn pipeline; None until load_artifacts() succeeds.
        self.model: Optional[Pipeline] = None
        # Classification decision threshold; 0.5 unless overridden on disk.
        self.threshold: float = 0.5
        # Background dataset handed to SHAP explainers (loaded via joblib).
        self.shap_background: Any = None
        # Ordered feature names; not populated anywhere in this module yet.
        self.feature_names: Optional[list] = None
        # True only after load_artifacts() completed without raising.
        self.is_loaded = False

    def download_if_missing(self, path: Path, url: str):
        """
        Download the artifact at *url* to *path* unless it already exists.

        Best-effort by design: a failed download is logged, not raised, so
        startup continues and fails later with a precise "not found" error.
        A partially written file is removed on failure so a future run does
        not mistake it for a complete artifact and skip the download.
        """
        if path.exists():
            logger.info("Artifact %s found locally. Skipping download.", path.name)
            return

        logger.info("Downloading %s from %s...", path.name, url)
        path.parent.mkdir(parents=True, exist_ok=True)
        try:
            with requests.get(url, stream=True) as response:
                response.raise_for_status()
                with open(path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:  # filter out keep-alive chunks
                            f.write(chunk)
            logger.info("Downloaded %s", path.name)
        except Exception as e:
            # Remove any partial file: leaving it would make the next
            # startup treat the artifact as present and skip re-download.
            path.unlink(missing_ok=True)
            logger.error("Failed to download %s: %s", path.name, e)

    # Singleton pattern
    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first access."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def load_artifacts(self):
        """
        Loads all artifacts from disk into memory.

        Downloads any missing artifact first (best-effort), then loads the
        model (required), the threshold and the SHAP background (optional,
        with logged fallbacks). Idempotent: a second call is a no-op.

        Raises:
            ArtifactLoadError: if the model is missing or any load step fails.
        """
        if self.is_loaded:
            logger.info("Artifacts already loaded.")
            return

        # Directories are now Path objects
        model_dir = settings.MODEL_DIR
        artifacts_dir = settings.ARTIFACTS_DIR

        logger.info("Loading models from %s", model_dir)
        logger.info("Loading artifacts from %s", artifacts_dir)

        # Ensure all artifacts are present
        for path, url in settings.ARTIFACT_URLS.items():
            self.download_if_missing(path, url)

        try:
            # Load Model (required — hard failure if absent)
            model_path = model_dir / settings.MODEL_FILENAME
            if model_path.exists():
                self.model = joblib.load(model_path)
                logger.info("Loaded model from %s", model_path)
            else:
                logger.error("Model file not found at %s", model_path)
                raise FileNotFoundError(f"Model not found at {model_path}")

            # Load Threshold (optional — default 0.5 if file absent)
            thresh_path = artifacts_dir / settings.THRESHOLD_FILENAME
            if thresh_path.exists():
                with open(thresh_path, "r") as f:
                    data = json.load(f)
                    self.threshold = float(data.get("threshold", 0.5))
                logger.info("Loaded threshold %s from %s", self.threshold, thresh_path)
            else:
                logger.warning(
                    "Threshold file not found at %s, using default 0.5", thresh_path
                )

            # Load SHAP Background (optional — explanations degrade without it)
            shap_path = artifacts_dir / settings.SHAP_BACKGROUND_FILENAME
            if shap_path.exists():
                self.shap_background = joblib.load(shap_path)
                logger.info("Loaded SHAP background from %s", shap_path)
            else:
                logger.warning(
                    "SHAP background not found at %s, SHAP explanations might fail.",
                    shap_path,
                )

            self.is_loaded = True
            logger.info("All artifacts loaded successfully.")

        except Exception as e:
            logger.error("Failed to load artifacts: %s", e)
            # Chain the cause so the original traceback is preserved.
            raise ArtifactLoadError(f"Critical error loading artifacts: {e}") from e

    def clear(self):
        """
        Unloads all ML artifacts from memory and restores defaults.
        """
        logger.info("Unloading artifacts...")
        self.model = None
        # Reset to the documented default rather than None: the attribute is
        # declared float and downstream comparisons expect a number.
        self.threshold = 0.5
        self.shap_background = None
        self.feature_names = None
        self.is_loaded = False
        logger.info("Artifacts unloaded.")


# Global instance
def get_artifacts() -> ModelArtifacts:
    """Accessor for the process-wide artifact container (FastAPI-style dependency)."""
    instance = ModelArtifacts.get_instance()
    return instance