boq-api / models /artifacts_loader.py
gabcares's picture
Upload 80 files
72fdabd verified
Raw
History Blame Contribute Delete
4.55 kB
import joblib
import json
import logging
from pathlib import Path
from typing import Optional, Union, Any
from utils.model import QSSelectiveCalibratedModel
from utils.engines import QSRuleBasedFallbackEngine
from utils.rules import QS_FALLBACK_RULES
from core.config import settings
from core.exceptions import ArtifactLoadError
# Module-level logger named after this module (``__name__``), so records
# propagate through the package's logger hierarchy.
logger = logging.getLogger(name=__name__)
# Files load_qs_model reads from the artifact directory; all are mandatory.
_REQUIRED_ARTIFACT_FILES = (
    "calibrator.joblib",
    "preprocessor.joblib",
    "label_encoder.joblib",
    "threshold.json",
    "hierarchy_map.json",
    "schema.json",
    "metadata.json",
)


def _read_json(path: Path) -> Any:
    """Parse and return the UTF-8 encoded JSON document stored at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_qs_model(base: Union[str, Path]) -> QSSelectiveCalibratedModel:
    """
    Safely reconstructs the hermetic QSSelectiveCalibratedModel from serialized artifacts.

    Args:
        base: Directory containing the exported artifacts (the joblib files,
            ``threshold.json``, ``hierarchy_map.json``, ``schema.json`` and
            ``metadata.json``).

    Returns:
        A QSSelectiveCalibratedModel assembled from the artifacts in *base*.
        Its drift log path is ``<base>/drift/schema_drift.log``.

    Raises:
        FileNotFoundError: If any required artifact file is missing. The
            message lists every missing file, not just the first one.
        KeyError: If ``threshold.json`` has no ``"threshold"`` key.
        Errors raised while unpickling or parsing JSON propagate unchanged.
    """
    base = Path(base)

    # Fail fast with one complete report instead of discovering missing
    # files one at a time across repeated deploy attempts.
    missing = [name for name in _REQUIRED_ARTIFACT_FILES if not (base / name).is_file()]
    if missing:
        raise FileNotFoundError(
            f"Missing required model artifact(s) in {base}: {', '.join(missing)}"
        )

    # 1. Load the joblib-serialized pipeline components.
    # NOTE(security): joblib.load unpickles arbitrary objects and can execute
    # code; *base* must only ever point at trusted, self-produced artifacts.
    calibrator = joblib.load(base / "calibrator.joblib")
    preprocessor = joblib.load(base / "preprocessor.joblib")
    label_encoder = joblib.load(base / "label_encoder.joblib")

    # 2. Load JSON metadata (UTF-8, files closed deterministically).
    threshold = _read_json(base / "threshold.json")["threshold"]
    hierarchy_map = _read_json(base / "hierarchy_map.json")
    schema = _read_json(base / "schema.json")
    metadata = _read_json(base / "metadata.json")

    # 3. Instantiate the hermetic model container.
    model = QSSelectiveCalibratedModel(
        calibrator=calibrator,
        preprocessor=preprocessor,
        label_encoder=label_encoder,
        threshold=threshold,
        # Fall back to "Others" when loading an older metadata payload.
        others_label=metadata.get("others_label", "Others"),
        nrm_hierarchy_map=hierarchy_map,
        model_version=metadata.get("model_version", "v1.0.0"),
        schema_version=metadata.get("schema_version", "v1"),
        hierarchy_version=metadata.get("hierarchy_version", "v1"),
        drift_log_path=base / "drift" / "schema_drift.log",
    )

    # 4. Restore the column rename map from schema.json, keeping the model's
    # existing RENAME_MAP when the key is absent. Assuming RENAME_MAP is a
    # plain class attribute, this assignment shadows it on this instance only.
    # NOTE: Do NOT overwrite model.TRAIN_SCHEMA here; per the model's design,
    # __post_init__ reconstructs it from the loaded preprocessor pipeline.
    model.RENAME_MAP = schema.get("RENAME_MAP", model.RENAME_MAP)
    return model
class ModelArtifacts:
    """
    Singleton-like container for ML artifacts.

    Intended to be populated once at startup via :meth:`load_artifacts` and
    shared through :meth:`get_instance`. Attributes:

    * ``model``: the QSSelectiveCalibratedModel built by ``load_qs_model``.
    * ``fallback_engine``: a QSRuleBasedFallbackEngine over QS_FALLBACK_RULES.
    * ``shap_background``: the object loaded from
      ``shap_background.joblib``, or ``None`` if that file was absent.
    * ``is_loaded``: True only after every artifact loaded successfully.

    Direct instantiation is not prevented; ``get_instance`` is the shared
    accessor.
    """

    _instance: Optional["ModelArtifacts"] = None

    def __init__(self):
        self.model: Optional[QSSelectiveCalibratedModel] = None
        self.fallback_engine: Optional[QSRuleBasedFallbackEngine] = None
        self.shap_background: Any = None
        self.is_loaded = False

    @classmethod
    def get_instance(cls) -> "ModelArtifacts":
        """Return the shared instance, creating an unloaded one on first call."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def load_artifacts(self) -> None:
        """
        Loads all artifacts from disk into memory.

        Every artifact is built into a local variable first and assigned to
        the instance only after all steps succeed. A failure therefore never
        leaves a half-populated container with ``is_loaded`` in an
        inconsistent state. Calling this again after a successful load is a
        no-op.

        Raises:
            ArtifactLoadError: If any artifact fails to load. The underlying
                exception is chained as ``__cause__``.
        """
        if self.is_loaded:
            logger.info("Artifacts already loaded.")
            return

        # settings.MODEL_DIR may be a str or a Path; the "/" joins need a Path.
        model_dir = Path(settings.MODEL_DIR)
        logger.info("Loading models from %s", model_dir)

        try:
            # 1. Safely reconstruct the hermetic model.
            model = load_qs_model(model_dir)
            # 2. Instantiate the deterministic fallback engine.
            fallback_engine = QSRuleBasedFallbackEngine(rule_map=QS_FALLBACK_RULES)
            # 3. Load the SHAP background (optional).
            shap_background = self._load_shap_background(model_dir)
        except Exception as e:
            # Startup boundary: log with traceback, surface one domain error.
            logger.exception("Failed to load artifacts: %s", e)
            raise ArtifactLoadError(f"Critical error loading artifacts: {e}") from e

        # Commit atomically, only once everything above has succeeded.
        self.model = model
        self.fallback_engine = fallback_engine
        self.shap_background = shap_background
        self.is_loaded = True

        # getattr: a missing descriptive attribute must not fail an
        # otherwise successful load.
        logger.info(
            "Successfully loaded model: %s (version: %s), Fallback Engine, and SHAP context.",
            getattr(model, "model_name", "<unknown>"),
            getattr(model, "model_version", "<unknown>"),
        )

    @staticmethod
    def _load_shap_background(model_dir: Path) -> Any:
        """Return the SHAP background from *model_dir*, or None if the file is absent."""
        shap_path = model_dir / "shap_background.joblib"
        if not shap_path.exists():
            logger.warning(
                "SHAP background not found at %s, SHAP explanations might fail.",
                shap_path,
            )
            return None
        # NOTE(security): joblib.load unpickles arbitrary objects; the model
        # directory must be trusted.
        background = joblib.load(shap_path)
        logger.info("Loaded SHAP background from %s", shap_path)
        return background

    def clear(self) -> None:
        """
        Unloads all ML artifacts from memory.

        Drops this container's references to the artifacts and resets
        ``is_loaded`` so a later :meth:`load_artifacts` call reloads them.
        """
        logger.info("Unloading artifacts...")
        self.model = None
        self.fallback_engine = None
        self.shap_background = None
        self.is_loaded = False
        logger.info("Artifacts unloaded.")
# Module-level accessor for the shared ModelArtifacts container.
def get_artifacts() -> ModelArtifacts:
    """
    Return the shared :class:`ModelArtifacts` instance.

    Thin wrapper over ``ModelArtifacts.get_instance()``. The instance is
    created unloaded on first use; call ``load_artifacts()`` on it to
    populate the models.
    """
    shared = ModelArtifacts.get_instance()
    return shared