COCODEDE04 committed on
Commit
45857b7
·
verified ·
1 Parent(s): 7dc78bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -13
app.py CHANGED
@@ -370,50 +370,70 @@ async def predict(req: Request):
370
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
371
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
372
 
373
- # ---------- SHAP EXPLANATION (predicted class only) ----------
374
  shap_out = {"error": "SHAP not computed"}
375
  if EXPLAINER is not None:
376
  try:
377
  shap_vals = EXPLAINER.shap_values(X, nsamples=100)
378
 
379
- # 1) Pull out the array for the predicted class (if multi-output)
380
  if isinstance(shap_vals, list):
 
381
  raw_sv = np.array(shap_vals[pred_idx])
382
  else:
 
383
  raw_sv = np.array(shap_vals)
384
 
385
- # 2) Normalize shapes: we want a 1D vector of length n_features
386
- # Possible shapes we might see:
387
- # (n_features,)
388
- # (1, n_features)
389
- # (n_samples, n_features) -> take first sample
390
  if raw_sv.ndim == 1:
 
391
  shap_vec = raw_sv.astype(float)
 
392
  elif raw_sv.ndim == 2:
 
393
  if raw_sv.shape[0] == 1:
 
394
  shap_vec = raw_sv[0].astype(float)
 
 
 
395
  else:
396
- # assume shape (n_samples, n_features); take first sample
397
  shap_vec = raw_sv[0].astype(float)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  else:
399
- # last resort: flatten sample dims, take first "row"
400
- raw_sv = raw_sv.reshape(raw_sv.shape[0], -1)
401
- shap_vec = raw_sv[0].astype(float)
402
 
 
403
  if shap_vec.shape[0] != len(FEATURES):
404
  raise ValueError(
405
  f"Unexpected SHAP vector length {shap_vec.shape[0]} "
406
  f"(expected {len(FEATURES)})"
407
  )
408
 
409
- # 3) Expected value: baseline logit/prob for that class
410
  exp_raw = EXPLAINER.expected_value
411
  if isinstance(exp_raw, (list, np.ndarray)):
412
  exp_val = float(np.array(exp_raw)[pred_idx])
413
  else:
414
  exp_val = float(exp_raw)
415
 
416
- # 4) Map feature -> SHAP contribution
417
  shap_feature_contribs = {
418
  FEATURES[i]: float(shap_vec[i])
419
  for i in range(len(FEATURES))
 
370
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
371
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
372
 
373
+ # ---------- SHAP EXPLANATION (predicted class only) ----------
374
  shap_out = {"error": "SHAP not computed"}
375
  if EXPLAINER is not None:
376
  try:
377
  shap_vals = EXPLAINER.shap_values(X, nsamples=100)
378
 
379
+ # 1) Pull raw SHAP tensor
380
  if isinstance(shap_vals, list):
381
+ # Classic multi-output: list[len = n_classes], each (n_samples, n_features)
382
  raw_sv = np.array(shap_vals[pred_idx])
383
  else:
384
+ # Single array, possibly (n_samples, n_features) or (n_samples, n_features, n_outputs)
385
  raw_sv = np.array(shap_vals)
386
 
387
+ # 2) Normalize shapes to a 1D vector (n_features,) for the predicted class
 
 
 
 
388
  if raw_sv.ndim == 1:
389
+ # Already (n_features,)
390
  shap_vec = raw_sv.astype(float)
391
+
392
  elif raw_sv.ndim == 2:
393
+ # (n_samples, n_features) or (n_features, 1)
394
  if raw_sv.shape[0] == 1:
395
+ # (1, n_features)
396
  shap_vec = raw_sv[0].astype(float)
397
+ elif raw_sv.shape[1] == 1:
398
+ # (n_features, 1)
399
+ shap_vec = raw_sv[:, 0].astype(float)
400
  else:
401
+ # assume (n_samples, n_features), take first sample
402
  shap_vec = raw_sv[0].astype(float)
403
+
404
+ elif raw_sv.ndim == 3:
405
+ # Most likely (n_samples, n_features, n_outputs)
406
+ n_samples, n_features, n_outputs = raw_sv.shape
407
+ if n_samples < 1:
408
+ raise ValueError(f"SHAP 3D output has zero samples: {raw_sv.shape}")
409
+ if pred_idx >= n_outputs:
410
+ raise ValueError(
411
+ f"SHAP 3D output has only {n_outputs} outputs, "
412
+ f"cannot index class {pred_idx}"
413
+ )
414
+ # take first sample, all features, predicted class
415
+ shap_vec = raw_sv[0, :, pred_idx].astype(float)
416
+
417
  else:
418
+ # Fallback: keep the first axis, collapse all trailing axes into one, then take row 0
419
+ flat = raw_sv.reshape(raw_sv.shape[0], -1)
420
+ shap_vec = flat[0].astype(float)
421
 
422
+ # 3) Sanity check length
423
  if shap_vec.shape[0] != len(FEATURES):
424
  raise ValueError(
425
  f"Unexpected SHAP vector length {shap_vec.shape[0]} "
426
  f"(expected {len(FEATURES)})"
427
  )
428
 
429
+ # 4) Expected value (baseline) for the predicted class
430
  exp_raw = EXPLAINER.expected_value
431
  if isinstance(exp_raw, (list, np.ndarray)):
432
  exp_val = float(np.array(exp_raw)[pred_idx])
433
  else:
434
  exp_val = float(exp_raw)
435
 
436
+ # 5) Map feature -> contribution
437
  shap_feature_contribs = {
438
  FEATURES[i]: float(shap_vec[i])
439
  for i in range(len(FEATURES))