AshenH committed on
Commit
e4818d5
·
verified ·
1 Parent(s): 4cad9bd

Update tools/explain_tool.py

Browse files
Files changed (1) hide show
  1. tools/explain_tool.py +289 -50
tools/explain_tool.py CHANGED
@@ -3,10 +3,13 @@ import os
3
  import io
4
  import json
5
  import base64
 
6
  from typing import Dict, Optional
7
 
8
  import shap
9
  import pandas as pd
 
 
10
  import matplotlib.pyplot as plt
11
  import joblib
12
  from huggingface_hub import hf_hub_download
@@ -14,70 +17,306 @@ from huggingface_hub import hf_hub_download
14
  from utils.config import AppConfig
15
  from utils.tracing import Tracer
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  class ExplainTool:
19
  """
20
- Generates global SHAP visualizations for a sample of rows (CPU-friendly).
 
21
  """
 
22
  def __init__(self, cfg: AppConfig, tracer: Tracer):
23
  self.cfg = cfg
24
  self.tracer = tracer
25
  self._model = None
26
  self._feature_order = None
27
-
 
 
28
  def _ensure_model(self):
 
29
  if self._model is not None:
30
  return
31
- token = os.getenv("HF_TOKEN")
32
- repo = self.cfg.hf_model_repo
33
-
34
- model_path = hf_hub_download(repo_id=repo, filename="model.pkl", token=token)
35
- self._model = joblib.load(model_path)
36
-
37
  try:
38
- meta_path = hf_hub_download(repo_id=repo, filename="feature_metadata.json", token=token)
39
- with open(meta_path, "r", encoding="utf-8") as f:
40
- meta = json.load(f) or {}
41
- self._feature_order = meta.get("feature_order")
42
- except Exception:
43
- self._feature_order = None
44
-
45
- @staticmethod
46
- def _to_data_uri(fig) -> str:
47
- buf = io.BytesIO()
48
- fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
49
- plt.close(fig)
50
- buf.seek(0)
51
- return "data:image/png;base64," + base64.b64encode(buf.read()).decode()
52
-
53
- def run(self, df: Optional[pd.DataFrame]) -> Dict[str, str]:
54
- self._ensure_model()
55
- if df is None or len(df) == 0:
56
- return {}
57
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  if self._feature_order:
59
- cols = [c for c in self._feature_order if c in df.columns]
60
- X = df[cols].copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  else:
 
62
  X = df.copy()
63
-
64
- n = min(len(X), 500)
65
- sample = X.sample(n, random_state=42) if len(X) > n else X
66
-
67
- explainer = shap.Explainer(self._model, sample)
68
- sv = explainer(sample)
69
-
70
- fig_bar = plt.figure()
71
- shap.plots.bar(sv, show=False)
72
- bar_uri = self._to_data_uri(fig_bar)
73
-
74
- fig_bee = plt.figure()
75
- shap.plots.beeswarm(sv, show=False)
76
- bee_uri = self._to_data_uri(fig_bee)
77
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  try:
79
- self.tracer.trace_event("explain", {"rows": int(n)})
80
- except Exception:
81
- pass
82
-
83
- return {"global_bar": bar_uri, "beeswarm": bee_uri}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import io
4
  import json
5
  import base64
6
+ import logging
7
  from typing import Dict, Optional
8
 
9
  import shap
10
  import pandas as pd
11
+ import matplotlib
12
+ matplotlib.use('Agg') # Non-interactive backend
13
  import matplotlib.pyplot as plt
14
  import joblib
15
  from huggingface_hub import hf_hub_download
 
17
  from utils.config import AppConfig
18
  from utils.tracing import Tracer
19
 
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Constants
23
+ MAX_SAMPLE_SIZE = 1000
24
+ MIN_SAMPLE_SIZE = 10
25
+ DEFAULT_SAMPLE_SIZE = 500
26
+ MAX_IMAGE_SIZE_MB = 5
27
+
28
+
29
+ class ExplainToolError(Exception):
30
+ """Custom exception for explanation tool errors."""
31
+ pass
32
+
33
 
34
  class ExplainTool:
35
  """
36
+ Generates SHAP-based model explanations with global visualizations.
37
+ CPU-friendly with sampling for large datasets.
38
  """
39
+
40
  def __init__(self, cfg: AppConfig, tracer: Tracer):
41
  self.cfg = cfg
42
  self.tracer = tracer
43
  self._model = None
44
  self._feature_order = None
45
+
46
+ logger.info("ExplainTool initialized (lazy loading)")
47
+
48
  def _ensure_model(self):
49
+ """Lazy load model and metadata from HuggingFace."""
50
  if self._model is not None:
51
  return
52
+
 
 
 
 
 
53
  try:
54
+ token = os.getenv("HF_TOKEN")
55
+ repo = self.cfg.hf_model_repo
56
+
57
+ if not repo:
58
+ raise ExplainToolError("HF_MODEL_REPO not configured")
59
+
60
+ logger.info(f"Loading model for explanations from: {repo}")
61
+
62
+ # Download and load model
63
+ try:
64
+ model_path = hf_hub_download(
65
+ repo_id=repo,
66
+ filename="model.pkl",
67
+ token=token
68
+ )
69
+ self._model = joblib.load(model_path)
70
+ logger.info(f"Model loaded: {type(self._model).__name__}")
71
+ except Exception as e:
72
+ raise ExplainToolError(f"Failed to load model: {e}") from e
73
+
74
+ # Load feature metadata
75
+ try:
76
+ meta_path = hf_hub_download(
77
+ repo_id=repo,
78
+ filename="feature_metadata.json",
79
+ token=token
80
+ )
81
+ with open(meta_path, "r", encoding="utf-8") as f:
82
+ meta = json.load(f) or {}
83
+
84
+ self._feature_order = meta.get("feature_order")
85
+ logger.info(f"Loaded feature order: {len(self._feature_order or [])} features")
86
+
87
+ except Exception as e:
88
+ logger.warning(f"Could not load feature metadata: {e}")
89
+ self._feature_order = None
90
+
91
+ except ExplainToolError:
92
+ raise
93
+ except Exception as e:
94
+ raise ExplainToolError(f"Model initialization failed: {e}") from e
95
+
96
+ def _validate_data(self, df: pd.DataFrame) -> tuple[bool, str]:
97
+ """
98
+ Validate input dataframe.
99
+ Returns (is_valid, error_message).
100
+ """
101
+ if df is None or df.empty:
102
+ return False, "Input dataframe is empty"
103
+
104
+ if len(df.columns) == 0:
105
+ return False, "Dataframe has no columns"
106
+
107
+ return True, ""
108
+
109
+ def _prepare_features(self, df: pd.DataFrame) -> pd.DataFrame:
110
+ """
111
+ Prepare feature matrix for SHAP analysis.
112
+ Selects and orders features according to model expectations.
113
+ """
114
  if self._feature_order:
115
+ # Use specified feature order
116
+ available_features = [col for col in self._feature_order if col in df.columns]
117
+ missing_features = [col for col in self._feature_order if col not in df.columns]
118
+
119
+ if missing_features:
120
+ logger.warning(
121
+ f"Missing {len(missing_features)} features for explanation: "
122
+ f"{missing_features[:5]}"
123
+ )
124
+
125
+ if not available_features:
126
+ raise ExplainToolError(
127
+ f"No required features found in dataframe. "
128
+ f"Required: {self._feature_order}, "
129
+ f"Available: {list(df.columns)}"
130
+ )
131
+
132
+ X = df[available_features].copy()
133
+ logger.info(f"Using {len(available_features)} features for explanation")
134
  else:
135
+ # Use all columns
136
  X = df.copy()
137
+ logger.warning("No feature order specified - using all columns")
138
+
139
+ # Remove non-numeric columns
140
+ numeric_cols = X.select_dtypes(include=['number']).columns
141
+ if len(numeric_cols) < len(X.columns):
142
+ dropped = set(X.columns) - set(numeric_cols)
143
+ logger.warning(f"Dropping {len(dropped)} non-numeric columns: {list(dropped)[:5]}")
144
+ X = X[numeric_cols]
145
+
146
+ if X.empty or len(X.columns) == 0:
147
+ raise ExplainToolError("No numeric features available for explanation")
148
+
149
+ return X
150
+
151
+ def _sample_data(self, X: pd.DataFrame, sample_size: int = DEFAULT_SAMPLE_SIZE) -> pd.DataFrame:
152
+ """
153
+ Sample data for SHAP analysis to keep computation manageable.
154
+ """
155
+ n = len(X)
156
+
157
+ if n <= MIN_SAMPLE_SIZE:
158
+ logger.info(f"Using all {n} rows (below minimum sample size)")
159
+ return X
160
+
161
+ # Determine sample size
162
+ target_size = min(sample_size, MAX_SAMPLE_SIZE)
163
+ target_size = max(target_size, MIN_SAMPLE_SIZE)
164
+
165
+ if n <= target_size:
166
+ logger.info(f"Using all {n} rows (below target sample size)")
167
+ return X
168
+
169
+ # Simple random sampling with a fixed seed for reproducibility (not stratified)
170
  try:
171
+ sample = X.sample(n=target_size, random_state=42)
172
+ logger.info(f"Sampled {target_size} rows from {n} total")
173
+ return sample
174
+ except Exception as e:
175
+ logger.warning(f"Sampling failed: {e}, using head()")
176
+ return X.head(target_size)
177
+
178
+ @staticmethod
179
+ def _to_data_uri(fig) -> str:
180
+ """
181
+ Convert matplotlib figure to base64 data URI.
182
+ Logs a warning if the PNG exceeds MAX_IMAGE_SIZE_MB; the image is still returned.
183
+ """
184
+ try:
185
+ buf = io.BytesIO()
186
+ fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
187
+ plt.close(fig)
188
+ buf.seek(0)
189
+
190
+ # Check size
191
+ size_mb = len(buf.getvalue()) / (1024 * 1024)
192
+ if size_mb > MAX_IMAGE_SIZE_MB:
193
+ logger.warning(f"Generated image is large: {size_mb:.2f} MB")
194
+
195
+ data_uri = "data:image/png;base64," + base64.b64encode(buf.read()).decode()
196
+ logger.debug(f"Generated data URI of size: {len(data_uri)} chars")
197
+
198
+ return data_uri
199
+
200
+ except Exception as e:
201
+ logger.error(f"Failed to convert figure to data URI: {e}")
202
+ raise ExplainToolError(f"Image conversion failed: {e}") from e
203
+
204
+ def _generate_shap_values(self, X: pd.DataFrame) -> shap.Explanation:
205
+ """
206
+ Generate SHAP values for the sample.
207
+ """
208
+ try:
209
+ logger.info("Creating SHAP explainer...")
210
+ explainer = shap.Explainer(self._model, X)
211
+
212
+ logger.info("Computing SHAP values...")
213
+ shap_values = explainer(X)
214
+
215
+ logger.info(f"SHAP values computed: shape={shap_values.values.shape}")
216
+ return shap_values
217
+
218
+ except Exception as e:
219
+ raise ExplainToolError(f"SHAP computation failed: {e}") from e
220
+
221
+ def _create_bar_plot(self, shap_values: shap.Explanation) -> str:
222
+ """Create global feature importance bar plot."""
223
+ try:
224
+ logger.info("Creating bar plot...")
225
+ fig = plt.figure(figsize=(10, 6))
226
+ shap.plots.bar(shap_values, show=False, max_display=20)
227
+ plt.title("Feature Importance (Global)", fontsize=14, pad=20)
228
+ plt.xlabel("Mean |SHAP value|", fontsize=12)
229
+ plt.tight_layout()
230
+
231
+ uri = self._to_data_uri(fig)
232
+ logger.info("Bar plot created successfully")
233
+ return uri
234
+
235
+ except Exception as e:
236
+ logger.error(f"Bar plot creation failed: {e}")
237
+ # Return an empty string (plot omitted by the caller) rather than failing completely
238
+ return ""
239
+
240
+ def _create_beeswarm_plot(self, shap_values: shap.Explanation) -> str:
241
+ """Create beeswarm plot showing feature effects."""
242
+ try:
243
+ logger.info("Creating beeswarm plot...")
244
+ fig = plt.figure(figsize=(10, 8))
245
+ shap.plots.beeswarm(shap_values, show=False, max_display=20)
246
+ plt.title("Feature Effects Distribution", fontsize=14, pad=20)
247
+ plt.tight_layout()
248
+
249
+ uri = self._to_data_uri(fig)
250
+ logger.info("Beeswarm plot created successfully")
251
+ return uri
252
+
253
+ except Exception as e:
254
+ logger.error(f"Beeswarm plot creation failed: {e}")
255
+ return ""
256
+
257
+ def run(self, df: Optional[pd.DataFrame]) -> Dict[str, str]:
258
+ """
259
+ Generate SHAP explanations for input data.
260
+
261
+ Args:
262
+ df: Input dataframe with features
263
+
264
+ Returns:
265
+ Dictionary mapping plot names to base64 data URIs
266
+
267
+ Raises:
268
+ ExplainToolError: If explanation generation fails
269
+ """
270
+ try:
271
+ # Validate input
272
+ is_valid, error_msg = self._validate_data(df)
273
+ if not is_valid:
274
+ logger.warning(f"Invalid input: {error_msg}")
275
+ return {}
276
+
277
+ # Ensure model is loaded
278
+ self._ensure_model()
279
+
280
+ # Prepare features
281
+ X = self._prepare_features(df)
282
+ logger.info(f"Prepared features: {X.shape}")
283
+
284
+ # Sample data for efficiency
285
+ sample = self._sample_data(X)
286
+
287
+ # Generate SHAP values
288
+ shap_values = self._generate_shap_values(sample)
289
+
290
+ # Create visualizations
291
+ result = {}
292
+
293
+ # Bar plot (feature importance)
294
+ bar_uri = self._create_bar_plot(shap_values)
295
+ if bar_uri:
296
+ result["global_bar"] = bar_uri
297
+
298
+ # Beeswarm plot (feature effects)
299
+ bee_uri = self._create_beeswarm_plot(shap_values)
300
+ if bee_uri:
301
+ result["beeswarm"] = bee_uri
302
+
303
+ # Log success
304
+ logger.info(f"Generated {len(result)} explanation visualizations")
305
+
306
+ if self.tracer:
307
+ self.tracer.trace_event("explain", {
308
+ "rows": len(sample),
309
+ "features": len(X.columns),
310
+ "visualizations": len(result)
311
+ })
312
+
313
+ return result
314
+
315
+ except ExplainToolError:
316
+ raise
317
+ except Exception as e:
318
+ error_msg = f"Explanation generation failed: {str(e)}"
319
+ logger.error(error_msg)
320
+ if self.tracer:
321
+ self.tracer.trace_event("explain_error", {"error": error_msg})
322
+ raise ExplainToolError(error_msg) from e