import joblib
import json
import logging
import requests
from typing import Any, Optional
from pathlib import Path
from sklearn.pipeline import Pipeline
from core.config import settings
from core.exceptions import ArtifactLoadError
logger = logging.getLogger(__name__)
class ModelArtifacts:
    """
    Singleton-like container for ML artifacts (model, decision threshold,
    SHAP background data). Loaded once at startup via ``load_artifacts``.
    """

    # Lazily created shared instance (see get_instance()).
    _instance = None

    # Fallback decision threshold when no threshold artifact is present.
    DEFAULT_THRESHOLD: float = 0.5

    def __init__(self):
        # Fitted sklearn pipeline; None until load_artifacts() succeeds.
        self.model: Optional[Pipeline] = None
        # Classification decision threshold.
        self.threshold: float = self.DEFAULT_THRESHOLD
        # Background dataset for SHAP explanations (loaded via joblib).
        self.shap_background: Any = None
        # Feature names; not populated by load_artifacts() — presumably set
        # elsewhere. TODO(review): confirm who writes this.
        self.feature_names: Optional[list] = None
        # Guards against double-loading.
        self.is_loaded = False

    # Singleton pattern
    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first use.

        NOTE(review): not thread-safe; fine if first called once at startup.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def download_if_missing(self, path: Path, url: str):
        """Download ``url`` to ``path`` unless the file already exists.

        Best-effort: failures are logged, not raised — load_artifacts()
        decides later whether a missing artifact is fatal. The response is
        streamed to a temporary sibling file and renamed into place, so a
        failed transfer can never leave a truncated file at ``path`` that
        would be mistaken for a valid artifact on the next run.
        """
        if path.exists():
            logger.info("Artifact %s found locally. Skipping download.", path.name)
            return
        logger.info("Downloading %s from %s...", path.name, url)
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = path.with_suffix(path.suffix + ".part")
        try:
            with requests.get(url, stream=True) as response:
                response.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
            # Atomic on POSIX: a partial file never becomes `path`.
            tmp_path.replace(path)
            logger.info("Downloaded %s", path.name)
        except Exception as e:
            # BUG FIX: the original wrote directly to `path`, so a failed
            # download left a truncated artifact that passed the exists()
            # check forever after. Remove the partial temp file instead.
            tmp_path.unlink(missing_ok=True)
            logger.error("Failed to download %s: %s", path.name, e)

    def load_artifacts(self):
        """
        Loads all artifacts from disk into memory.

        The model is required; threshold and SHAP background are optional
        and fall back with a warning.

        Raises:
            ArtifactLoadError: if the model file is missing or any artifact
                fails to deserialize.
        """
        if self.is_loaded:
            logger.info("Artifacts already loaded.")
            return

        # Directories are Path objects supplied by settings.
        model_dir = settings.MODEL_DIR
        artifacts_dir = settings.ARTIFACTS_DIR
        logger.info("Loading models from %s", model_dir)
        logger.info("Loading artifacts from %s", artifacts_dir)

        # Ensure all artifacts are present (best-effort download).
        for path, url in settings.ARTIFACT_URLS.items():
            self.download_if_missing(path, url)

        try:
            # --- Model (required) ---
            model_path = model_dir / settings.MODEL_FILENAME
            if model_path.exists():
                self.model = joblib.load(model_path)
                logger.info("Loaded model from %s", model_path)
            else:
                logger.error("Model file not found at %s", model_path)
                raise FileNotFoundError(f"Model not found at {model_path}")

            # --- Threshold (optional, defaults to 0.5) ---
            thresh_path = artifacts_dir / settings.THRESHOLD_FILENAME
            if thresh_path.exists():
                with open(thresh_path, "r") as f:
                    data = json.load(f)
                self.threshold = float(data.get("threshold", self.DEFAULT_THRESHOLD))
                logger.info("Loaded threshold %s from %s", self.threshold, thresh_path)
            else:
                logger.warning(
                    f"Threshold file not found at {thresh_path}, using default 0.5"
                )

            # --- SHAP background (optional) ---
            shap_path = artifacts_dir / settings.SHAP_BACKGROUND_FILENAME
            if shap_path.exists():
                self.shap_background = joblib.load(shap_path)
                logger.info("Loaded SHAP background from %s", shap_path)
            else:
                logger.warning(
                    f"SHAP background not found at {shap_path}, SHAP explanations might fail."
                )

            self.is_loaded = True
            logger.info("All artifacts loaded successfully.")
        except Exception as e:
            logger.error("Failed to load artifacts: %s", e)
            # BUG FIX: chain the cause so tracebacks keep the root failure.
            raise ArtifactLoadError(f"Critical error loading artifacts: {e}") from e

    def clear(self):
        """
        Unloads all ML artifacts from memory and resets state to defaults.
        """
        logger.info("Unloading artifacts...")
        self.model = None
        # BUG FIX: the original set threshold to None, violating the declared
        # float type and breaking any comparison after a clear() without a
        # successful reload. Reset to the documented default instead.
        self.threshold = self.DEFAULT_THRESHOLD
        self.shap_background = None
        # Also reset feature_names; the original left stale names behind.
        self.feature_names = None
        self.is_loaded = False
        logger.info("Artifacts unloaded.")
# Module-level accessor for the shared artifacts container.
def get_artifacts() -> ModelArtifacts:
    """Return the process-wide ModelArtifacts singleton."""
    instance = ModelArtifacts.get_instance()
    return instance
|