AshenH commited on
Commit
aed2def
·
verified ·
1 Parent(s): 2dcd5ce

Update tools/explain_tool.py

Browse files
Files changed (1) hide show
  1. tools/explain_tool.py +81 -20
tools/explain_tool.py CHANGED
@@ -1,44 +1,105 @@
 
1
  import os
2
  import io
3
- import shap
4
  import base64
 
 
 
5
  import pandas as pd
 
 
6
  from huggingface_hub import hf_hub_download
 
7
  from utils.config import AppConfig
8
  from utils.tracing import Tracer
9
 
 
10
  class ExplainTool:
 
 
 
 
11
  def __init__(self, cfg: AppConfig, tracer: Tracer):
12
  self.cfg = cfg
13
  self.tracer = tracer
14
  self._model = None
 
15
 
16
  def _ensure_model(self):
17
- if self._model is None:
18
- import joblib
19
- path = hf_hub_download(repo_id=self.cfg.hf_model_repo, filename="model.pkl", token=os.getenv("HF_TOKEN"))
20
- self._model = joblib.load(path)
 
 
 
 
 
 
 
21
 
22
- def _to_data_uri(self, fig) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  buf = io.BytesIO()
24
- fig.savefig(buf, format="png", bbox_inches="tight")
 
25
  buf.seek(0)
26
  return "data:image/png;base64," + base64.b64encode(buf.read()).decode()
27
 
28
- def run(self, df: pd.DataFrame):
 
 
 
29
  self._ensure_model()
30
- # Use a small sample for speed on CPU Spaces
31
- sample = df.sample(min(len(df), 500), random_state=42)
32
- explainer = shap.Explainer(self._model, sample, feature_names=list(sample.columns))
33
- shap_values = explainer(sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # Global summary plot
36
- fig1 = shap.plots.bar(shap_values, show=False)
37
- img1 = self._to_data_uri(fig1)
 
38
 
39
- # Beeswarm (optional)
40
- fig2 = shap.plots.beeswarm(shap_values, show=False)
41
- img2 = self._to_data_uri(fig2)
 
42
 
43
- self.tracer.trace_event("explain", {"rows": len(sample)})
44
- return {"global_bar": img1, "beeswarm": img2}
 
1
+ # space/tools/explain_tool.py
2
  import os
3
  import io
4
+ import json
5
  import base64
6
+ from typing import Dict, Optional
7
+
8
+ import shap
9
  import pandas as pd
10
+ import matplotlib.pyplot as plt
11
+ import joblib
12
  from huggingface_hub import hf_hub_download
13
+
14
  from utils.config import AppConfig
15
  from utils.tracing import Tracer
16
 
17
+
18
  class ExplainTool:
19
+ """
20
+ Generates lightweight global SHAP visualizations (bar + beeswarm) for a sample
21
+ of the current DataFrame. Designed to run on CPU in HF Spaces.
22
+ """
23
  def __init__(self, cfg: AppConfig, tracer: Tracer):
24
  self.cfg = cfg
25
  self.tracer = tracer
26
  self._model = None
27
+ self._feature_order = None
28
 
29
  def _ensure_model(self):
30
+ if self._model is not None:
31
+ return
32
+ token = os.getenv("HF_TOKEN")
33
+ repo = self.cfg.hf_model_repo
34
+
35
+ model_path = hf_hub_download(
36
+ repo_id=repo,
37
+ filename="model.pkl",
38
+ token=token
39
+ )
40
+ self._model = joblib.load(model_path)
41
 
42
+ # read optional feature metadata to keep column order consistent
43
+ try:
44
+ meta_path = hf_hub_download(
45
+ repo_id=repo,
46
+ filename="feature_metadata.json",
47
+ token=token
48
+ )
49
+ with open(meta_path, "r", encoding="utf-8") as f:
50
+ meta = json.load(f) or {}
51
+ self._feature_order = meta.get("feature_order")
52
+ except Exception:
53
+ self._feature_order = None
54
+
55
+ @staticmethod
56
+ def _to_data_uri(fig) -> str:
57
  buf = io.BytesIO()
58
+ fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
59
+ plt.close(fig)
60
  buf.seek(0)
61
  return "data:image/png;base64," + base64.b64encode(buf.read()).decode()
62
 
63
+ def run(self, df: Optional[pd.DataFrame]) -> Dict[str, str]:
64
+ """
65
+ Returns dict of {plot_name: data_uri_png}. If df is None/empty, returns {}.
66
+ """
67
  self._ensure_model()
68
+ if df is None or len(df) == 0:
69
+ return {}
70
+
71
+ # Select & sample features
72
+ if self._feature_order:
73
+ missing = [c for c in self._feature_order if c not in df.columns]
74
+ if missing:
75
+ # best effort: intersect
76
+ X = df[[c for c in self._feature_order if c in df.columns]].copy()
77
+ else:
78
+ X = df[self._feature_order].copy()
79
+ else:
80
+ X = df.copy()
81
+
82
+ # Small sample for speed
83
+ n = min(len(X), 500)
84
+ sample = X.sample(n, random_state=42) if len(X) > n else X
85
+
86
+ # Build explainer and compute SHAP values
87
+ explainer = shap.Explainer(self._model, sample)
88
+ sv = explainer(sample)
89
+
90
+ # --- Global bar plot ---
91
+ fig_bar = plt.figure()
92
+ shap.plots.bar(sv, show=False)
93
+ bar_uri = self._to_data_uri(fig_bar)
94
 
95
+ # --- Beeswarm plot ---
96
+ fig_bee = plt.figure()
97
+ shap.plots.beeswarm(sv, show=False)
98
+ bee_uri = self._to_data_uri(fig_bee)
99
 
100
+ try:
101
+ self.tracer.trace_event("explain", {"rows": int(n)})
102
+ except Exception:
103
+ pass
104
 
105
+ return {"global_bar": bar_uri, "beeswarm": bee_uri}