import base64
import io
import os

import pandas as pd
import shap
from huggingface_hub import hf_hub_download

from utils.config import AppConfig
from utils.tracing import Tracer


class ExplainTool:
    """Render SHAP explanation plots for the configured model as PNG data URIs.

    The model is downloaded from the Hugging Face Hub on first use and cached
    on the instance for subsequent calls.
    """

    # Upper bound on the number of rows explained per call, and the seed used
    # to draw them so repeated calls on the same frame pick the same rows.
    _MAX_SAMPLE_ROWS = 500
    _SAMPLE_SEED = 42

    def __init__(self, cfg: AppConfig, tracer: Tracer):
        """Store the configuration and tracer; the model is loaded lazily."""
        self.cfg = cfg
        self.tracer = tracer
        self._model = None

    def _ensure_model(self) -> None:
        """Download and deserialize ``model.pkl`` once, caching it on ``self``.

        The repository id comes from ``cfg.hf_model_repo`` and the access
        token from the ``HF_TOKEN`` environment variable (``None`` if unset).
        """
        if self._model is not None:
            return

        # Imported lazily so joblib is only required when a model is loaded.
        import joblib

        path = hf_hub_download(
            repo_id=self.cfg.hf_model_repo,
            filename="model.pkl",
            token=os.getenv("HF_TOKEN"),
        )
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only point cfg.hf_model_repo at repositories you trust.
        self._model = joblib.load(path)

    @staticmethod
    def _as_figure(plot_obj):
        """Return the savable figure behind ``plot_obj``.

        Accepts either an object that already exposes ``savefig`` (a Figure)
        or one that exposes its owner through a ``figure`` attribute (an
        Axes).

        Raises:
            TypeError: if no savable figure can be resolved, e.g. when a plot
                function returned ``None``.
        """
        if hasattr(plot_obj, "savefig"):
            return plot_obj
        figure = getattr(plot_obj, "figure", None)
        if figure is None or not hasattr(figure, "savefig"):
            raise TypeError(
                "expected a matplotlib Figure or Axes, got "
                f"{type(plot_obj).__name__}; shap plots called with "
                "show=False must return the drawn Axes/Figure"
            )
        return figure

    def _to_data_uri(self, fig) -> str:
        """Encode a figure (or the figure owning an Axes) as a PNG data URI.

        Args:
            fig: A matplotlib Figure, or an Axes whose parent figure is saved.

        Returns:
            A ``data:image/png;base64,...`` string.

        Raises:
            TypeError: if ``fig`` is neither a Figure nor an Axes.
        """
        figure = self._as_figure(fig)
        buf = io.BytesIO()
        figure.savefig(buf, format="png", bbox_inches="tight")
        return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    def _render_plot(self, plot_fn, shap_values) -> str:
        """Draw one SHAP plot without showing it and return it as a data URI.

        The figure is cleared afterwards: shap draws onto pyplot's current
        figure when no axes are supplied, so leftover artists would otherwise
        bleed into the next plot rendered by this tool.
        """
        drawn = plot_fn(shap_values, show=False)
        figure = self._as_figure(drawn)
        try:
            return self._to_data_uri(figure)
        finally:
            figure.clear()

    def run(self, df: pd.DataFrame) -> dict:
        """Compute SHAP values on a sample of ``df`` and render summary plots.

        Args:
            df: Feature frame to explain; its columns are used as the feature
                names passed to the explainer.

        Returns:
            A dict with keys ``"global_bar"`` and ``"beeswarm"``, each holding
            a PNG data URI.

        Raises:
            ValueError: if ``df`` has no rows.
            TypeError: if a shap plot call does not return a Figure or Axes.
        """
        if len(df) == 0:
            raise ValueError("cannot compute SHAP explanations for an empty DataFrame")

        self._ensure_model()

        # Use a small sample for speed on CPU Spaces
        sample = df.sample(
            min(len(df), self._MAX_SAMPLE_ROWS),
            random_state=self._SAMPLE_SEED,
        )
        explainer = shap.Explainer(
            self._model, sample, feature_names=list(sample.columns)
        )
        shap_values = explainer(sample)

        # Global summary plot
        global_bar = self._render_plot(shap.plots.bar, shap_values)

        # Beeswarm (optional)
        beeswarm = self._render_plot(shap.plots.beeswarm, shap_values)

        self.tracer.trace_event("explain", {"rows": len(sample)})
        return {"global_bar": global_bar, "beeswarm": beeswarm}