import json
import logging
from pathlib import Path
from typing import Any, Optional, Union

import joblib

from core.config import settings
from core.exceptions import ArtifactLoadError
from utils.engines import QSRuleBasedFallbackEngine
from utils.model import QSSelectiveCalibratedModel
from utils.rules import QS_FALLBACK_RULES

logger = logging.getLogger(__name__)


def _read_json(path: Path) -> Any:
    """Parse the UTF-8 encoded JSON file at *path* and return its content."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_qs_model(base: Union[str, Path]) -> QSSelectiveCalibratedModel:
    """Rebuild a QSSelectiveCalibratedModel from the artifacts stored in *base*.

    The directory is expected to contain ``calibrator.joblib``,
    ``preprocessor.joblib``, ``label_encoder.joblib``, ``threshold.json``,
    ``hierarchy_map.json``, ``schema.json`` and ``metadata.json``.

    Args:
        base: Directory holding the serialized artifacts.

    Returns:
        The reconstructed model.

    Raises:
        FileNotFoundError: If one of the JSON files is missing.
        KeyError: If ``threshold.json`` has no ``"threshold"`` key.
        json.JSONDecodeError: If one of the JSON files is malformed.
    """
    base = Path(base)

    # 1. Load the joblib-serialized components.
    # NOTE: joblib.load unpickles its input, which can execute arbitrary code;
    # the artifact directory must only ever contain trusted files.
    calibrator = joblib.load(base / "calibrator.joblib")
    preprocessor = joblib.load(base / "preprocessor.joblib")
    label_encoder = joblib.load(base / "label_encoder.joblib")

    # 2. Load the JSON metadata.
    threshold = _read_json(base / "threshold.json")["threshold"]
    hierarchy_map = _read_json(base / "hierarchy_map.json")
    schema = _read_json(base / "schema.json")
    metadata = _read_json(base / "metadata.json")

    # 3. Instantiate the model container. The ``.get`` defaults keep metadata
    # payloads that lack these keys loadable.
    model = QSSelectiveCalibratedModel(
        calibrator=calibrator,
        preprocessor=preprocessor,
        label_encoder=label_encoder,
        threshold=threshold,
        others_label=metadata.get("others_label", "Others"),
        nrm_hierarchy_map=hierarchy_map,
        model_version=metadata.get("model_version", "v1.0.0"),
        schema_version=metadata.get("schema_version", "v1"),
        hierarchy_version=metadata.get("hierarchy_version", "v1"),
        drift_log_path=base / "drift" / "schema_drift.log",
    )

    # 4. Restore the rename map from schema.json, keeping the model's own
    # value when the file does not provide one.
    # NOTE: Do NOT overwrite model.TRAIN_SCHEMA here; per the original author's
    # note, the model's __post_init__ rebuilds it from the loaded preprocessor.
    model.RENAME_MAP = schema.get("RENAME_MAP", model.RENAME_MAP)

    return model


class ModelArtifacts:
    """Container for the ML artifacts that are loaded at startup.

    ``get_instance`` lazily creates one shared instance and returns it on
    every call. Constructing the class directly is not prevented and yields
    an independent, empty container.
    """

    # Shared instance created on the first get_instance() call.
    _instance: Optional["ModelArtifacts"] = None

    def __init__(self) -> None:
        """Create an empty container; nothing is read from disk here."""
        self.model: Optional[QSSelectiveCalibratedModel] = None
        self.fallback_engine: Optional[QSRuleBasedFallbackEngine] = None
        self.shap_background: Any = None
        self.is_loaded = False

    @classmethod
    def get_instance(cls) -> "ModelArtifacts":
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def load_artifacts(self) -> None:
        """Load all artifacts from ``settings.MODEL_DIR`` into memory.

        Does nothing if the artifacts are already loaded. Instance state is
        only updated once every step has succeeded, so a failed load never
        leaves the container half-populated.

        Raises:
            ArtifactLoadError: If any step of the loading process fails; the
                original exception is attached as ``__cause__``.
        """
        if self.is_loaded:
            logger.info("Artifacts already loaded.")
            return

        logger.info(f"Loading models from {settings.MODEL_DIR}")

        try:
            # Coerce to Path so that a plain-string setting also supports the
            # ``/`` operator used for the SHAP background path below.
            model_dir = Path(settings.MODEL_DIR)

            # 1. Reconstruct the model from its serialized artifacts.
            model = load_qs_model(model_dir)

            # 2. Instantiate the rule-based fallback engine.
            fallback_engine = QSRuleBasedFallbackEngine(rule_map=QS_FALLBACK_RULES)

            # 3. Load the SHAP background, which is optional.
            shap_background = None
            shap_path = model_dir / "shap_background.joblib"
            if shap_path.exists():
                # NOTE: joblib.load unpickles; only load trusted files.
                shap_background = joblib.load(shap_path)
                logger.info(f"Loaded SHAP background from {shap_path}")
            else:
                logger.warning(
                    f"SHAP background not found at {shap_path}, SHAP explanations might fail."
                )

            # Build the summary inside the try block so that a missing model
            # attribute is reported as a load failure before state is committed.
            summary = (
                f"Successfully loaded model: {model.model_name} "
                f"(version: {model.model_version}), Fallback Engine, and SHAP context."
            )
        except Exception as e:
            # logger.exception records the traceback alongside the message.
            logger.exception(f"Failed to load artifacts: {e}")
            raise ArtifactLoadError(f"Critical error loading artifacts: {e}") from e

        # Commit the new state only after everything loaded successfully.
        self.model = model
        self.fallback_engine = fallback_engine
        self.shap_background = shap_background
        self.is_loaded = True
        logger.info(summary)

    def clear(self) -> None:
        """Drop all references to the loaded artifacts and mark as unloaded."""
        logger.info("Unloading artifacts...")
        self.model = None
        self.fallback_engine = None
        self.shap_background = None
        self.is_loaded = False
        logger.info("Artifacts unloaded.")


# Global instance
def get_artifacts() -> ModelArtifacts:
    """Return the shared ModelArtifacts instance."""
    return ModelArtifacts.get_instance()