| import joblib |
| import pandas as pd |
| import tensorflow as tf |
| import logging |
| import json |
| from pathlib import Path |
| from core.config import settings |
| from core.exceptions import ArtifactLoadError |
| from utils.temperature_scaling import TemperatureScaler |
| import requests |
|
|
| logger = logging.getLogger(__name__) |
|
|
class ModelArtifacts:
    """
    Singleton-like container for ML artifacts.

    Loaded once at startup via :meth:`load_artifacts`; obtain the shared
    instance through :meth:`get_instance` (or the module-level
    ``get_artifacts`` helper).
    """

    # Shared singleton instance, created lazily by get_instance().
    _instance = None

    # Decision threshold used whenever no threshold file is available.
    DEFAULT_THRESHOLD = 0.5

    def __init__(self):
        # All artifacts start unloaded; load_artifacts() populates them.
        self.feature_creator = None
        self.preprocessor = None
        self.model = None
        self.scaler = None
        self.threshold = self.DEFAULT_THRESHOLD
        self.background_data = None
        self.is_loaded = False

    def download_if_missing(self, path: Path, url: str):
        """
        Download ``url`` to ``path`` unless the file already exists.

        Best-effort: failures are logged rather than raised, because
        load_artifacts() already warns about any artifact that is still
        missing. A partially written file is deleted on failure so a
        later run retries the download instead of loading a truncated
        artifact.
        """
        if path.exists():
            logger.info(f"Artifact {path.name} found locally. Skipping download.")
            return

        logger.info(f"Downloading {path.name} from {url}...")
        path.parent.mkdir(parents=True, exist_ok=True)
        try:
            # (connect, read) timeouts prevent hanging forever on an
            # unresponsive server; stream to disk in 8 KiB chunks.
            with requests.get(url, stream=True, timeout=(10, 300)) as response:
                response.raise_for_status()
                with open(path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            logger.info(f"Downloaded {path.name}")
        except Exception as e:
            logger.error(f"Failed to download {path.name}: {e}")
            # Remove any partial file so a future path.exists() check
            # does not mistake it for a valid local artifact.
            path.unlink(missing_ok=True)

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first access."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def load_artifacts(self):
        """
        Load all ML artifacts into memory (idempotent).

        Individual missing artifact files are logged as warnings and
        skipped; any other failure aborts and raises ArtifactLoadError
        (chained to the original exception).
        """
        if self.is_loaded:
            logger.info("Artifacts already loaded.")
            return

        logger.info("Loading artifacts...")

        # Fetch any artifact files that are not present locally.
        for path, url in settings.ARTIFACT_URLS.items():
            self.download_if_missing(path, url)

        try:
            self._load_model()
            self._load_scaler()
            self._load_feature_creator()
            self._load_preprocessor()
            self._load_threshold()
            self._load_background_data()

            self.is_loaded = True
            logger.info("All artifacts loaded successfully.")
        except Exception as e:
            logger.exception("Failed to load artifacts.")
            raise ArtifactLoadError(f"Failed to load artifacts: {e}") from e

    def _load_model(self):
        # Keras model, saved with the custom TemperatureScaler layer in scope.
        if not settings.model_path_abs.exists():
            logger.warning(f"Model file not found at {settings.model_path_abs}. Skipping load.")
            return
        with tf.keras.utils.custom_object_scope({'TemperatureScaler': TemperatureScaler}):
            self.model = tf.keras.models.load_model(settings.model_path_abs, compile=False)
        logger.info(f"Keras model loaded from {settings.model_path_abs}.")

    def _load_scaler(self):
        # Calibration scaler; falls back to a fresh (uncalibrated) instance.
        if not settings.scaler_path_abs.exists():
            logger.warning(f"Scaler file not found at {settings.scaler_path_abs}. Using default initialization.")
            self.scaler = TemperatureScaler()
            return
        with tf.keras.utils.custom_object_scope({'TemperatureScaler': TemperatureScaler}):
            self.scaler = tf.keras.models.load_model(settings.scaler_path_abs, compile=False)
        logger.info(f"TemperatureScaler loaded from {settings.scaler_path_abs}.")

    def _load_feature_creator(self):
        # Joblib-pickled feature-engineering transformer.
        if not settings.feature_creator_path_abs.exists():
            logger.warning(f"Feature Creator file not found at {settings.feature_creator_path_abs}. Skipping load.")
            return
        self.feature_creator = joblib.load(settings.feature_creator_path_abs)
        logger.info(f"Feature Creator loaded from {settings.feature_creator_path_abs}.")

    def _load_preprocessor(self):
        # Joblib-pickled preprocessing pipeline.
        if not settings.preprocessor_path_abs.exists():
            logger.warning(f"Preprocessor file not found at {settings.preprocessor_path_abs}. Skipping load.")
            return
        self.preprocessor = joblib.load(settings.preprocessor_path_abs)
        logger.info(f"Preprocessor loaded from {settings.preprocessor_path_abs}.")

    def _load_threshold(self):
        # Clinical decision threshold; keeps the default on any problem.
        # Accepts either {"best_threshold": x} or a bare JSON number
        # (JSON integers arrive as int, so check both numeric types).
        if not settings.threshold_path_abs.exists():
            return
        try:
            with open(settings.threshold_path_abs, 'r') as f:
                data = json.load(f)
            if isinstance(data, dict) and "best_threshold" in data:
                self.threshold = float(data["best_threshold"])
            elif isinstance(data, (int, float)) and not isinstance(data, bool):
                self.threshold = float(data)
            else:
                logger.warning(
                    f"Unrecognized threshold file format ({type(data).__name__}). Using default {self.threshold}"
                )
                return
            logger.info(f"Clinical threshold loaded: {self.threshold}")
        except Exception as e:
            logger.warning(f"Failed to parse threshold file: {e}. Using default 0.5")

    def _load_background_data(self):
        # Background reference dataset (pandas DataFrame) — presumably
        # used for explainability; confirm against callers.
        if not settings.background_data_path_abs.exists():
            logger.warning(f"Background data not found at {settings.background_data_path_abs}. Skipping load.")
            return
        self.background_data = pd.read_csv(settings.background_data_path_abs)
        logger.info("Background data loaded.")

    def clear(self):
        """
        Unload all ML artifacts and return to the initial (unloaded) state.
        """
        logger.info("Unloading artifacts...")
        self.feature_creator = None
        self.preprocessor = None
        self.model = None
        self.scaler = None
        # Reset to the default (not None) so a subsequent load_artifacts()
        # with no threshold file still has a usable threshold.
        self.threshold = self.DEFAULT_THRESHOLD
        self.background_data = None
        self.is_loaded = False
        logger.info("Artifacts unloaded.")
|
|
|
|
| |
def get_artifacts() -> ModelArtifacts:
    """Module-level accessor for the shared ModelArtifacts singleton."""
    artifacts = ModelArtifacts.get_instance()
    return artifacts
|
|