|
|
import os |
|
|
import io |
|
|
import shap |
|
|
import base64 |
|
|
import pandas as pd |
|
|
from huggingface_hub import hf_hub_download |
|
|
from utils.config import AppConfig |
|
|
from utils.tracing import Tracer |
|
|
|
|
|
class ExplainTool:
    """Render SHAP explanation plots for a model downloaded from the Hugging Face Hub.

    The model is fetched lazily on the first call to :meth:`run` and cached on
    the instance. Plots are returned as base64-encoded PNG ``data:`` URIs.
    """

    # Upper bound on the number of rows handed to SHAP per run.
    _MAX_SAMPLE_ROWS = 500
    # Fixed seed so repeated runs over the same frame explain the same rows.
    _SAMPLE_SEED = 42

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Store the configuration and tracer; the model is loaded on demand.

        Args:
            cfg: Application config; ``cfg.hf_model_repo`` names the Hub repo
                that holds ``model.pkl``.
            tracer: Tracer whose ``trace_event`` is called after each run.
        """
        self.cfg = cfg
        self.tracer = tracer
        self._model = None  # populated by _ensure_model()

    def _ensure_model(self):
        """Download and deserialize ``model.pkl`` once; later calls are no-ops."""
        if self._model is not None:
            return

        import joblib

        path = hf_hub_download(
            repo_id=self.cfg.hf_model_repo,
            filename="model.pkl",
            token=os.getenv("HF_TOKEN"),
        )
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only point cfg.hf_model_repo at a repository you trust.
        self._model = joblib.load(path)

    @staticmethod
    def _figure_of(plot_result):
        """Return the figure-like object behind a plot call's return value.

        Objects exposing ``savefig`` are returned as-is; otherwise the object's
        ``figure`` attribute (as matplotlib Axes provide) is used. Returns
        ``None`` when neither is available.
        """
        if hasattr(plot_result, "savefig"):
            return plot_result
        return getattr(plot_result, "figure", None)

    def _to_data_uri(self, fig) -> str:
        """Encode a plot as a ``data:image/png;base64,...`` URI.

        Args:
            fig: A matplotlib Figure, or an Axes whose parent figure should be
                saved. shap plot functions called with ``show=False`` can
                return either, depending on the installed shap version.

        Returns:
            The PNG rendering of the figure as a base64 data URI.

        Raises:
            TypeError: If no saveable figure can be derived from ``fig``
                (for example when the plot call returned ``None``).
        """
        figure = self._figure_of(fig)
        if figure is None or not hasattr(figure, "savefig"):
            raise TypeError(
                "expected a matplotlib Figure or Axes, got "
                f"{type(fig).__name__}"
            )
        with io.BytesIO() as buf:
            figure.savefig(buf, format="png", bbox_inches="tight")
            encoded = base64.b64encode(buf.getvalue()).decode("ascii")
        return "data:image/png;base64," + encoded

    def _render(self, plot_fn, shap_values) -> str:
        """Draw one shap plot, encode it, and clear the figure afterwards.

        Clearing matters because shap draws onto the current pyplot figure:
        without it the next plot would be painted over this one, and the
        artists of every rendered plot would stay alive.
        """
        result = plot_fn(shap_values, show=False)
        try:
            return self._to_data_uri(result)
        finally:
            clear = getattr(self._figure_of(result), "clear", None)
            if callable(clear):
                clear()

    def run(self, df: pd.DataFrame):
        """Compute SHAP values for a sample of ``df`` and render summary plots.

        At most ``_MAX_SAMPLE_ROWS`` rows are sampled (with a fixed seed) and
        used both as the explainer's background data and as the rows being
        explained. A ``"explain"`` trace event carrying the sampled row count
        is emitted once both plots are rendered.

        Args:
            df: Feature frame; its columns are passed to SHAP as feature names.

        Returns:
            A dict with keys ``"global_bar"`` and ``"beeswarm"``, each mapping
            to a PNG data URI.

        Raises:
            ValueError: If ``df`` is empty.
        """
        # Fail fast, before paying for the model download.
        if df.empty:
            raise ValueError("cannot explain an empty DataFrame")

        self._ensure_model()

        sample = df.sample(
            min(len(df), self._MAX_SAMPLE_ROWS),
            random_state=self._SAMPLE_SEED,
        )
        explainer = shap.Explainer(
            self._model, sample, feature_names=list(sample.columns)
        )
        shap_values = explainer(sample)

        images = {
            "global_bar": self._render(shap.plots.bar, shap_values),
            "beeswarm": self._render(shap.plots.beeswarm, shap_values),
        }

        self.tracer.trace_event("explain", {"rows": len(sample)})
        return images