# D0nG4667
# fix(api): fix singleton AttributeError in artifacts_loader.py
# d9bf10d
import joblib
import pandas as pd
import tensorflow as tf
import logging
import json
from pathlib import Path
from core.config import settings
from core.exceptions import ArtifactLoadError
from utils.temperature_scaling import TemperatureScaler
import requests
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ModelArtifacts:
    """
    Singleton-like container for ML artifacts.
    Loaded at startup.

    Holds the Keras model, temperature scaler, feature creator,
    preprocessor, decision threshold and background data. Use
    ``get_instance()`` to obtain the shared instance and
    ``load_artifacts()`` to populate it.
    """

    _instance = None

    # Threshold used when no threshold file is available or it cannot be parsed.
    DEFAULT_THRESHOLD = 0.5

    # (connect, read) timeout in seconds for artifact downloads, so a stalled
    # server cannot block the loader forever.
    DOWNLOAD_TIMEOUT = (10, 60)

    def __init__(self):
        self._reset_state()

    def _reset_state(self):
        """Put every artifact slot back to its initial, unloaded value."""
        self.feature_creator = None
        self.preprocessor = None
        self.model = None
        self.scaler = None
        self.threshold = self.DEFAULT_THRESHOLD
        self.background_data = None
        self.is_loaded = False

    def download_if_missing(self, path: Path, url: str):
        """
        Download ``url`` to ``path`` unless ``path`` already exists.

        The payload is streamed into a sibling ``<name>.part`` file and only
        renamed to ``path`` after the transfer completes, so an interrupted
        download never leaves a truncated artifact that a later run would
        mistake for a valid one.

        Download failures are logged and swallowed (best effort); the caller
        decides what to do when the file is still missing afterwards.
        """
        path = Path(path)
        if path.exists():
            logger.info(f"Artifact {path.name} found locally. Skipping download.")
            return

        logger.info(f"Downloading {path.name} from {url}...")
        path.parent.mkdir(parents=True, exist_ok=True)
        partial_path = path.with_name(path.name + ".part")
        try:
            with requests.get(url, stream=True, timeout=self.DOWNLOAD_TIMEOUT) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            # Publish the file only once it has been written completely.
            partial_path.replace(path)
            logger.info(f"Downloaded {path.name}")
        except Exception as e:
            logger.error(f"Failed to download {path.name}: {e}")
            # Remove any partially written data.
            try:
                if partial_path.exists():
                    partial_path.unlink()
            except OSError as cleanup_error:
                logger.warning(f"Could not remove partial download {partial_path}: {cleanup_error}")

    # Singleton pattern
    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first call."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @staticmethod
    def _parse_threshold(data):
        """
        Extract a threshold from decoded JSON content.

        Accepts either a mapping with a ``"best_threshold"`` key
        (e.g. ``{"best_threshold": 0.35}``) or a bare JSON number.

        Returns:
            The threshold as a float, or ``None`` if the format is not
            recognized.

        Raises:
            TypeError, ValueError: if ``"best_threshold"`` is present but
                cannot be converted to a float.
        """
        if isinstance(data, dict):
            if "best_threshold" not in data:
                return None
            return float(data["best_threshold"])
        # bool is a subclass of int; a JSON true/false is not a threshold.
        if isinstance(data, bool):
            return None
        if isinstance(data, (int, float)):
            return float(data)
        return None

    def _load_threshold(self):
        """Read the threshold file if present; keep the current value on any problem."""
        if not settings.threshold_path_abs.exists():
            return
        try:
            with open(settings.threshold_path_abs, 'r', encoding="utf-8") as f:
                data = json.load(f)
            parsed = self._parse_threshold(data)
        except Exception as e:
            logger.warning(f"Failed to parse threshold file: {e}. Using default 0.5")
            return
        if parsed is None:
            logger.warning(
                f"Unrecognized threshold format in {settings.threshold_path_abs}. "
                f"Keeping threshold {self.threshold}"
            )
            return
        self.threshold = parsed
        logger.info(f"Clinical threshold loaded: {self.threshold}")

    def load_artifacts(self):
        """
        Loads all ML artifacts into memory.

        Missing artifact files are first fetched from
        ``settings.ARTIFACT_URLS``. Files that are still missing are skipped
        with a warning (the scaler falls back to a default
        ``TemperatureScaler``). Calling this again after a successful load is
        a no-op.

        Raises:
            ArtifactLoadError: if loading any present artifact fails; the
                original exception is attached as the cause.
        """
        if self.is_loaded:
            logger.info("Artifacts already loaded.")
            return

        logger.info("Loading artifacts...")

        # Ensure all artifacts are present
        for path, url in settings.ARTIFACT_URLS.items():
            self.download_if_missing(path, url)

        try:
            # 1. Load Keras Model
            if not settings.model_path_abs.exists():
                logger.warning(f"Model file not found at {settings.model_path_abs}. Skipping load.")
            else:
                # Register custom objects in case the saved model references them.
                with tf.keras.utils.custom_object_scope({'TemperatureScaler': TemperatureScaler}):
                    self.model = tf.keras.models.load_model(settings.model_path_abs, compile=False)
                logger.info(f"Keras model loaded from {settings.model_path_abs}.")

            # 2. Load Temperature Scaler
            if not settings.scaler_path_abs.exists():
                logger.warning(f"Scaler file not found at {settings.scaler_path_abs}. Using default initialization.")
                self.scaler = TemperatureScaler()  # Default fallback
            else:
                with tf.keras.utils.custom_object_scope({'TemperatureScaler': TemperatureScaler}):
                    self.scaler = tf.keras.models.load_model(settings.scaler_path_abs, compile=False)
                logger.info(f"TemperatureScaler loaded from {settings.scaler_path_abs}.")

            # 3. Load Feature Creator (Pipeline)
            if not settings.feature_creator_path_abs.exists():
                logger.warning(f"Feature Creator file not found at {settings.feature_creator_path_abs}. Skipping load.")
            else:
                self.feature_creator = joblib.load(settings.feature_creator_path_abs)
                logger.info(f"Feature Creator loaded from {settings.feature_creator_path_abs}.")

            # 4. Load Preprocessor
            if not settings.preprocessor_path_abs.exists():
                logger.warning(f"Preprocessor file not found at {settings.preprocessor_path_abs}. Skipping load.")
            else:
                self.preprocessor = joblib.load(settings.preprocessor_path_abs)
                logger.info(f"Preprocessor loaded from {settings.preprocessor_path_abs}.")

            # 5. Load Threshold
            self._load_threshold()

            # 6. Load Background Data
            if not settings.background_data_path_abs.exists():
                logger.warning(f"Background data not found at {settings.background_data_path_abs}. Skipping load.")
            else:
                self.background_data = pd.read_csv(settings.background_data_path_abs)
                logger.info("Background data loaded.")

            self.is_loaded = True
            logger.info("All artifacts loaded successfully.")
        except Exception as e:
            logger.exception(f"Failed to load artifacts: {e}")
            # Chain the cause so the original traceback is preserved.
            raise ArtifactLoadError(f"Failed to load artifacts: {e}") from e

    def clear(self):
        """
        Unloads all ML artifacts from memory.

        The threshold is reset to ``DEFAULT_THRESHOLD`` rather than ``None``
        so that a later reload without a threshold file still leaves a usable
        value.
        """
        logger.info("Unloading artifacts...")
        self._reset_state()
        logger.info("Artifacts unloaded.")
# Accessor for the global (shared) instance
def get_artifacts() -> ModelArtifacts:
    """Return the shared ``ModelArtifacts`` object obtained via ``get_instance()``."""
    shared_artifacts = ModelArtifacts.get_instance()
    return shared_artifacts