| |
| import os |
| import io |
| import json |
| import base64 |
| from typing import Dict, Optional |
|
|
| import shap |
| import pandas as pd |
| import matplotlib.pyplot as plt |
| import joblib |
| from huggingface_hub import hf_hub_download |
|
|
| from utils.config import AppConfig |
| from utils.tracing import Tracer |
|
|
|
|
class ExplainTool:
    """
    Generates global SHAP visualizations for a sample of rows (CPU-friendly).

    The model (``model.pkl``) and, optionally, ``feature_metadata.json`` are
    downloaded from the Hugging Face Hub repo ``cfg.hf_model_repo`` (passing the
    ``HF_TOKEN`` environment variable as the token) the first time :meth:`run`
    is called with a non-empty frame, then cached on the instance.
    """

    # Default cap on the number of rows handed to SHAP; larger inputs are
    # down-sampled to keep the explanation cheap on CPU.
    DEFAULT_MAX_ROWS = 500
    # Fixed seed so repeated runs on the same frame plot the same sample.
    SAMPLE_SEED = 42

    def __init__(self, cfg: AppConfig, tracer: Tracer):
        self.cfg = cfg
        self.tracer = tracer
        self._model = None          # loaded lazily by _ensure_model()
        self._feature_order = None  # optional column order from feature_metadata.json

    @staticmethod
    def _log():
        """Return this module's logger (imported locally to keep the file's import block unchanged)."""
        import logging

        return logging.getLogger(__name__)

    def _ensure_model(self) -> None:
        """
        Load the model and the optional feature order from the Hub, once per instance.

        Errors from ``hf_hub_download`` / ``joblib.load`` for ``model.pkl``
        propagate to the caller. A missing or malformed ``feature_metadata.json``
        is tolerated: it is logged and all input columns are used as given.
        """
        if self._model is not None:
            return
        token = os.getenv("HF_TOKEN")
        repo = self.cfg.hf_model_repo

        model_path = hf_hub_download(repo_id=repo, filename="model.pkl", token=token)
        # SECURITY NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code. Only point hf_model_repo at repositories you trust.
        model = joblib.load(model_path)

        feature_order = None
        try:
            meta_path = hf_hub_download(
                repo_id=repo, filename="feature_metadata.json", token=token
            )
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f) or {}
            feature_order = meta.get("feature_order")
        except Exception:
            # Metadata is optional, but record the failure so an ignored
            # feature order does not go unnoticed.
            self._log().warning(
                "Could not load feature_metadata.json from %s; using input columns as-is",
                repo,
                exc_info=True,
            )
            feature_order = None

        if feature_order is not None and not isinstance(feature_order, (list, tuple)):
            # A string, for example, would otherwise be iterated character by character.
            self._log().warning(
                "Ignoring non-list feature_order in %s: %r", repo, feature_order
            )
            feature_order = None

        self._feature_order = feature_order
        # Assign the model last so a set _model implies loading has fully finished.
        self._model = model

    def _select_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Return a copy of ``df`` restricted to, and ordered by, the known feature order.

        Without a feature order every column is kept. Expected features absent
        from ``df`` are skipped with a warning. If none are present, a
        ValueError is raised instead of handing SHAP a column-less frame.
        """
        if not self._feature_order:
            return df.copy()
        cols = [c for c in self._feature_order if c in df.columns]
        if not cols:
            raise ValueError(
                "None of the model's expected feature columns are present in the input frame"
            )
        missing = [c for c in self._feature_order if c not in df.columns]
        if missing:
            self._log().warning(
                "Input frame is missing %d expected feature(s): %s", len(missing), missing
            )
        return df[cols].copy()

    @staticmethod
    def _single_output(sv):
        """
        Reduce a two-output explanation to the second output (the positive class).

        Some explainers may return values shaped (rows, features, 2) for binary
        classifiers. The bar and beeswarm plots expect a single output per row,
        so output index 1 is kept. Any other shape is returned unchanged.
        """
        values = getattr(sv, "values", None)
        if getattr(values, "ndim", None) == 3 and values.shape[-1] == 2:
            return sv[:, :, 1]
        return sv

    @staticmethod
    def _to_data_uri(fig) -> str:
        """Encode ``fig`` as a base64 PNG data URI, closing the figure even if saving fails."""
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
        finally:
            # pyplot keeps every open figure alive until it is closed.
            plt.close(fig)
        return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("ascii")

    @classmethod
    def _render(cls, plot_fn, sv) -> str:
        """Draw ``plot_fn(sv, show=False)`` on a fresh figure and return it as a PNG data URI."""
        # plt.figure() makes the new figure current; shap.plots.* (called
        # without ax=) draw onto the current figure, as the original code relied on.
        fig = plt.figure()
        try:
            plot_fn(sv, show=False)
        except Exception:
            plt.close(fig)  # don't leak the half-drawn figure
            raise
        return cls._to_data_uri(fig)

    def run(
        self, df: Optional[pd.DataFrame], max_rows: int = DEFAULT_MAX_ROWS
    ) -> Dict[str, str]:
        """
        Build global SHAP plots for ``df``.

        Args:
            df: Input rows. ``None`` or a frame with no rows yields ``{}``
                without loading the model.
            max_rows: Maximum number of rows explained. Larger frames are
                randomly down-sampled with a fixed seed. Defaults to 500.

        Returns:
            ``{"global_bar": <PNG data URI>, "beeswarm": <PNG data URI>}``,
            or ``{}`` for empty input.

        Raises:
            ValueError: if ``max_rows`` < 1, or if a feature order is known
                and none of its columns are present in ``df``.
        """
        if df is None or len(df) == 0:
            return {}
        if max_rows < 1:
            raise ValueError(f"max_rows must be >= 1, got {max_rows}")
        self._ensure_model()

        X = self._select_features(df)
        n = min(len(X), max_rows)
        sample = X.sample(n, random_state=self.SAMPLE_SEED) if len(X) > n else X

        # The sample serves both as SHAP's background data and as the rows explained.
        explainer = shap.Explainer(self._model, sample)
        sv = self._single_output(explainer(sample))

        bar_uri = self._render(shap.plots.bar, sv)
        bee_uri = self._render(shap.plots.beeswarm, sv)

        try:
            self.tracer.trace_event("explain", {"rows": int(n)})
        except Exception:
            # Tracing is best-effort and must never fail the request; log the reason.
            self._log().debug("trace_event('explain') failed", exc_info=True)

        return {"global_bar": bar_uri, "beeswarm": bee_uri}
|
|