D0nG4667
fix: artifacts loader, complete singleton pattern
0a36c22
import joblib
import json
import logging
import requests
from typing import Any, Optional
from pathlib import Path
from sklearn.pipeline import Pipeline
from core.config import settings
from core.exceptions import ArtifactLoadError
logger = logging.getLogger(__name__)
class ModelArtifacts:
    """
    Singleton-like container for the ML artifacts used by the service.

    Holds the model pipeline, the decision threshold and the SHAP background
    data. Obtain the shared instance through ``get_instance()`` and populate
    it with ``load_artifacts()``.
    """

    # Fallback decision threshold: used when no threshold file is present
    # and restored by ``clear()``.
    DEFAULT_THRESHOLD = 0.5

    _instance = None

    def __init__(self):
        self.model: Optional[Pipeline] = None
        self.threshold: float = self.DEFAULT_THRESHOLD
        self.shap_background: Any = None
        self.feature_names: Optional[list] = None
        self.is_loaded = False

    def download_if_missing(self, path: Path, url: str, timeout: float = 60.0):
        """
        Download ``url`` to ``path`` unless ``path`` already exists.

        This is best-effort: a failed download is logged, not raised, and the
        caller finds out when it later tries to load the missing file.

        The payload is streamed into a sibling ``<name>.part`` file and only
        renamed onto ``path`` once the transfer has finished, so an
        interrupted download never leaves a truncated file that a later run
        would mistake for a valid artifact.

        Args:
            path: Destination file. Parent directories are created as needed.
            url: Location to fetch the artifact from.
            timeout: Value passed to ``requests.get(timeout=...)`` so a
                stalled server cannot block the caller indefinitely.
        """
        path = Path(path)
        if path.exists():
            logger.info(f"Artifact {path.name} found locally. Skipping download.")
            return

        logger.info(f"Downloading {path.name} from {url}...")
        path.parent.mkdir(parents=True, exist_ok=True)
        partial_path = path.with_name(path.name + ".part")
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            # Publish the file only after every byte has been written.
            partial_path.replace(path)
            logger.info(f"Downloaded {path.name}")
        except Exception as e:
            logger.error(f"Failed to download {path.name}: {e}")
            # Remove the incomplete temp file so it cannot accumulate.
            if partial_path.exists():
                try:
                    partial_path.unlink()
                except OSError as cleanup_error:
                    logger.warning(
                        f"Could not remove partial download {partial_path}: {cleanup_error}"
                    )

    # Singleton pattern
    @classmethod
    def get_instance(cls) -> "ModelArtifacts":
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def load_artifacts(self):
        """
        Load all artifacts from disk into memory.

        Does nothing if the artifacts are already loaded. Any file listed in
        ``settings.ARTIFACT_URLS`` that is missing locally is downloaded
        first. The model file is mandatory; the threshold and SHAP background
        files are optional and only produce a warning when absent.

        Raises:
            ArtifactLoadError: If the model file is missing or any artifact
                fails to load. The original exception is chained as the cause.
        """
        if self.is_loaded:
            logger.info("Artifacts already loaded.")
            return

        # Directories are now Path objects
        model_dir = settings.MODEL_DIR
        artifacts_dir = settings.ARTIFACTS_DIR
        logger.info(f"Loading models from {model_dir}")
        logger.info(f"Loading artifacts from {artifacts_dir}")

        # Ensure all artifacts are present
        for path, url in settings.ARTIFACT_URLS.items():
            self.download_if_missing(path, url)

        try:
            # Load Model
            # NOTE(security): joblib.load unpickles the file and can run
            # arbitrary code; ARTIFACT_URLS must point at trusted sources.
            model_path = model_dir / settings.MODEL_FILENAME
            if model_path.exists():
                self.model = joblib.load(model_path)
                logger.info(f"Loaded model from {model_path}")
            else:
                logger.error(f"Model file not found at {model_path}")
                raise FileNotFoundError(f"Model not found at {model_path}")

            # Load Threshold
            thresh_path = artifacts_dir / settings.THRESHOLD_FILENAME
            if thresh_path.exists():
                with open(thresh_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                self.threshold = float(
                    data.get("threshold", self.DEFAULT_THRESHOLD)
                )
                logger.info(f"Loaded threshold {self.threshold} from {thresh_path}")
            else:
                # Assign the default explicitly so the value always matches
                # the log message, whatever the attribute held before.
                self.threshold = self.DEFAULT_THRESHOLD
                logger.warning(
                    f"Threshold file not found at {thresh_path}, using default 0.5"
                )

            # Load SHAP Background
            shap_path = artifacts_dir / settings.SHAP_BACKGROUND_FILENAME
            if shap_path.exists():
                self.shap_background = joblib.load(shap_path)
                logger.info(f"Loaded SHAP background from {shap_path}")
            else:
                logger.warning(
                    f"SHAP background not found at {shap_path}, SHAP explanations might fail."
                )

            self.is_loaded = True
            logger.info("All artifacts loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load artifacts: {e}")
            # Chain the cause so the original traceback is preserved.
            raise ArtifactLoadError(f"Critical error loading artifacts: {e}") from e

    def clear(self):
        """
        Unload all ML artifacts from memory.

        Every attribute returns to the value ``__init__`` gives it, so the
        threshold goes back to the default float rather than ``None`` and a
        later ``load_artifacts()`` starts from a clean state.
        """
        logger.info("Unloading artifacts...")
        self.model = None
        self.threshold = self.DEFAULT_THRESHOLD
        self.shap_background = None
        self.feature_names = None
        self.is_loaded = False
        logger.info("Artifacts unloaded.")
# Module-level accessor for the shared ModelArtifacts instance
def get_artifacts() -> ModelArtifacts:
    """Return the shared ``ModelArtifacts`` instance via ``get_instance()``."""
    shared_artifacts = ModelArtifacts.get_instance()
    return shared_artifacts