File size: 1,501 Bytes
9b2b28b 0f166dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import os
import io
import shap
import base64
import pandas as pd
from huggingface_hub import hf_hub_download
from utils.config import AppConfig
from utils.tracing import Tracer
class ExplainTool:
    """Render global SHAP explanation plots for a model stored on the Hugging Face Hub.

    The model is downloaded and deserialized lazily, the first time it is
    needed. Plots are returned as base64-encoded PNG ``data:`` URIs.
    """

    # Upper bound on the rows handed to SHAP; keeps runs fast on CPU Spaces.
    _MAX_SAMPLE_ROWS = 500
    # Fixed seed so repeated runs over the same frame sample the same rows.
    _SAMPLE_SEED = 42

    def __init__(self, cfg: AppConfig, tracer: Tracer):
        """Store the configuration and tracer; the model is not loaded yet.

        Args:
            cfg: Application config; ``cfg.hf_model_repo`` names the Hub repo
                that holds ``model.pkl``.
            tracer: Receives an ``"explain"`` event after each successful run.
        """
        self.cfg = cfg
        self.tracer = tracer
        self._model = None  # populated on demand by _ensure_model()

    def _ensure_model(self) -> None:
        """Download and load ``model.pkl`` from the Hub once; later calls are no-ops."""
        if self._model is not None:
            return
        import joblib

        path = hf_hub_download(
            repo_id=self.cfg.hf_model_repo,
            filename="model.pkl",
            token=os.getenv("HF_TOKEN"),
        )
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only point hf_model_repo at a repository you trust.
        self._model = joblib.load(path)

    @staticmethod
    def _resolve_figure(plot_obj):
        """Return the figure-like object that can ``savefig`` for *plot_obj*.

        Accepts either an object that has ``savefig`` itself or one exposing
        its owner through a ``figure`` attribute (as matplotlib Axes do).

        Raises:
            TypeError: If no object with ``savefig`` can be found.
        """
        if hasattr(plot_obj, "savefig"):
            return plot_obj
        owner = getattr(plot_obj, "figure", None)
        if owner is not None and hasattr(owner, "savefig"):
            return owner
        raise TypeError(
            "expected a matplotlib Figure or Axes, got "
            f"{type(plot_obj).__name__}"
        )

    def _to_data_uri(self, fig) -> str:
        """Encode a plot as a ``data:image/png;base64,...`` URI.

        Args:
            fig: A matplotlib Figure, or an Axes whose ``figure`` is used.

        Returns:
            The PNG rendering of the figure as a base64 data URI string.

        Raises:
            TypeError: If *fig* is neither figure-like nor axes-like.
        """
        figure = self._resolve_figure(fig)
        buf = io.BytesIO()
        figure.savefig(buf, format="png", bbox_inches="tight")
        encoded = base64.b64encode(buf.getvalue()).decode()
        return "data:image/png;base64," + encoded

    def run(self, df: pd.DataFrame):
        """Compute SHAP values on a sample of *df* and render two summary plots.

        Args:
            df: Feature frame to explain; at most ``_MAX_SAMPLE_ROWS`` rows
                are sampled from it.

        Returns:
            A dict with keys ``"global_bar"`` and ``"beeswarm"``, each a PNG
            data URI.

        Raises:
            ValueError: If *df* is empty.
            TypeError: If a SHAP plot call does not return a Figure or Axes.
        """
        if df.empty:
            raise ValueError("Cannot compute SHAP explanations for an empty DataFrame")

        self._ensure_model()
        # Use a small sample for speed on CPU Spaces.
        sample = df.sample(
            min(len(df), self._MAX_SAMPLE_ROWS),
            random_state=self._SAMPLE_SEED,
        )
        explainer = shap.Explainer(
            self._model, sample, feature_names=list(sample.columns)
        )
        shap_values = explainer(sample)

        plotters = (
            ("global_bar", shap.plots.bar),
            ("beeswarm", shap.plots.beeswarm),
        )
        images = {}
        for key, plot in plotters:
            # Depending on the SHAP version, show=False hands back a Figure or
            # an Axes; normalise to the figure before exporting.
            figure = self._resolve_figure(plot(shap_values, show=False))
            try:
                images[key] = self._to_data_uri(figure)
            finally:
                # SHAP draws onto pyplot's current figure when none is passed;
                # wipe it so the next plot (or the next run) starts clean
                # instead of being layered on top of this one.
                figure.clear()

        self.tracer.trace_event("explain", {"rows": len(sample)})
        return images