File size: 1,501 Bytes
9b2b28b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f166dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import io
import shap
import base64
import pandas as pd
from huggingface_hub import hf_hub_download
from utils.config import AppConfig
from utils.tracing import Tracer

class ExplainTool:
    """Compute SHAP explanations for a model hosted on the Hugging Face Hub.

    The model is downloaded from ``cfg.hf_model_repo`` on the first call to
    ``run`` and cached on the instance. ``run`` returns a global SHAP bar plot
    and a beeswarm plot, each encoded as a PNG ``data:`` URI.
    """

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Store dependencies; the model is not loaded until first use.

        Args:
            cfg: Application config; ``cfg.hf_model_repo`` names the Hub repo.
            tracer: Receives an ``"explain"`` event after each ``run``.
        """
        self.cfg = cfg
        self.tracer = tracer
        self._model = None  # populated lazily by _ensure_model()

    def _ensure_model(self):
        """Download and deserialize ``model.pkl`` once, caching it on ``self``."""
        if self._model is not None:
            return
        import joblib  # local import: only needed when the model is loaded

        path = hf_hub_download(
            repo_id=self.cfg.hf_model_repo,
            filename="model.pkl",
            token=os.getenv("HF_TOKEN"),
        )
        # SECURITY: joblib.load unpickles the file, and unpickling can execute
        # arbitrary code. Only point cfg.hf_model_repo at repositories you trust.
        self._model = joblib.load(path)

    @staticmethod
    def _as_figure(plot_obj):
        """Return the matplotlib Figure behind *plot_obj*.

        shap's plotting functions called with ``show=False`` may hand back an
        Axes (or ``None``, depending on the shap version) rather than a Figure.
        A Figure is returned as-is; anything else is resolved through its
        ``.figure`` attribute.

        Raises:
            TypeError: if no Figure can be obtained from *plot_obj*.
        """
        if hasattr(plot_obj, "savefig"):
            return plot_obj
        fig = getattr(plot_obj, "figure", None)
        if fig is None or not hasattr(fig, "savefig"):
            raise TypeError(
                "Expected a matplotlib Figure or Axes, got "
                f"{type(plot_obj).__name__}"
            )
        return fig

    def _to_data_uri(self, fig) -> str:
        """Render *fig* (a Figure, or an Axes via its figure) as a PNG data URI."""
        figure = self._as_figure(fig)
        with io.BytesIO() as buf:
            figure.savefig(buf, format="png", bbox_inches="tight")
            payload = base64.b64encode(buf.getvalue()).decode("ascii")
        return "data:image/png;base64," + payload

    def _render(self, plot_obj) -> str:
        """Encode a shap plot as a data URI, then clear its figure.

        When no ``ax`` is passed, shap plots appear to draw onto the current
        pyplot figure. Clearing it afterwards keeps the next plot from being
        layered on top of this one.
        """
        fig = self._as_figure(plot_obj)
        try:
            return self._to_data_uri(fig)
        finally:
            fig.clf()

    def run(self, df: pd.DataFrame, max_rows: int = 500):
        """Explain a sample of *df* and return the SHAP plots as data URIs.

        Args:
            df: Feature rows; its column names are passed to shap as feature
                names.
            max_rows: Upper bound on rows used. The same sample (drawn with a
                fixed seed) serves as explainer background and rows explained.

        Returns:
            ``{"global_bar": <data URI>, "beeswarm": <data URI>}``.

        Raises:
            ValueError: if *df* is empty or *max_rows* is less than 1.
        """
        if max_rows < 1:
            raise ValueError(f"max_rows must be >= 1, got {max_rows}")
        # Validate before downloading the model so bad input fails fast.
        if df.empty:
            raise ValueError("cannot explain an empty DataFrame")

        self._ensure_model()
        # Use a small, reproducible sample for speed on CPU Spaces.
        sample = df.sample(min(len(df), max_rows), random_state=42)
        explainer = shap.Explainer(
            self._model, sample, feature_names=list(sample.columns)
        )
        shap_values = explainer(sample)

        # Global mean-|SHAP| bar plot, then the per-sample beeswarm; each is
        # encoded and its figure cleared before the next one is drawn.
        img_bar = self._render(shap.plots.bar(shap_values, show=False))
        img_beeswarm = self._render(shap.plots.beeswarm(shap_values, show=False))

        self.tracer.trace_event("explain", {"rows": len(sample)})
        return {"global_bar": img_bar, "beeswarm": img_beeswarm}