COCODEDE04 commited on
Commit
130812d
·
verified ·
1 Parent(s): 04ccf8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -17
app.py CHANGED
@@ -378,41 +378,103 @@ async def predict(req: Request):
378
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
379
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw_vec[i])]
380
 
381
- # ---------- 3) SHAP explanations (all classes) ----------
382
  shap_block: Dict[str, Any] = {"available": False}
383
 
384
  if EXPLAINER is not None and SHAP_AVAILABLE:
385
  try:
386
  X_z = z_vec.reshape(1, -1).astype(np.float32)
387
- # KernelExplainer: usually returns list of length K (one (1,D) array per class)
388
  shap_vals = EXPLAINER.shap_values(X_z, nsamples=50)
389
 
390
  all_classes: Dict[str, Dict[str, float]] = {}
391
 
 
392
  if isinstance(shap_vals, list):
393
- # per-class outputs
394
  for k, class_name in enumerate(CLASSES):
395
  if k >= len(shap_vals):
396
  continue
397
- vec = np.array(shap_vals[k][0], dtype=float) # shape (D,)
398
- if vec.shape[0] != len(FEATURES):
399
- # shape mismatch: bail out gracefully
 
 
 
 
 
400
  continue
 
401
  all_classes[class_name] = {
402
  FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
403
  }
404
 
405
- shap_block = {
406
- "available": True,
407
- "mode": "per_class",
408
- "explained_classes": list(all_classes.keys()),
409
- "all_classes": all_classes,
410
- }
 
 
 
 
 
 
411
 
 
412
  else:
413
- # single array (1, D) – treat as "predicted class only" fallback
414
- vec = np.array(shap_vals[0], dtype=float)
415
- if vec.shape[0] == len(FEATURES):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
  shap_block = {
417
  "available": True,
418
  "mode": "single_class",
@@ -421,10 +483,11 @@ async def predict(req: Request):
421
  FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
422
  },
423
  }
 
424
  else:
425
  shap_block = {
426
  "available": False,
427
- "error": f"Unexpected SHAP vector length {vec.shape[0]} (expected {len(FEATURES)})",
428
  }
429
 
430
  except Exception as e:
@@ -433,7 +496,6 @@ async def predict(req: Request):
433
  "error": str(e),
434
  "trace": traceback.format_exc(),
435
  }
436
-
437
  # ---------- 4) Build response ----------
438
  return {
439
  "input_ok": (len(missing) == 0),
 
378
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
379
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw_vec[i])]
380
 
381
+ # ---------- 3) SHAP explanations (all classes) ----------
382
  shap_block: Dict[str, Any] = {"available": False}
383
 
384
  if EXPLAINER is not None and SHAP_AVAILABLE:
385
  try:
386
  X_z = z_vec.reshape(1, -1).astype(np.float32)
 
387
  shap_vals = EXPLAINER.shap_values(X_z, nsamples=50)
388
 
389
  all_classes: Dict[str, Dict[str, float]] = {}
390
 
391
+ # ---------- CASE 1: SHAP returns list (usual multi-class) ----------
392
  if isinstance(shap_vals, list):
 
393
  for k, class_name in enumerate(CLASSES):
394
  if k >= len(shap_vals):
395
  continue
396
+ arr = np.array(shap_vals[k], dtype=float) # shape (N, D) or (D,)
397
+ # reduce to a 1D (D,) vector for the first sample
398
+ if arr.ndim == 2 and arr.shape[0] >= 1 and arr.shape[1] == len(FEATURES):
399
+ vec = arr[0, :]
400
+ elif arr.ndim == 1 and arr.shape[0] == len(FEATURES):
401
+ vec = arr
402
+ else:
403
+ # shape we don't know how to handle for this class
404
  continue
405
+
406
  all_classes[class_name] = {
407
  FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
408
  }
409
 
410
+ if all_classes:
411
+ shap_block = {
412
+ "available": True,
413
+ "mode": "per_class",
414
+ "explained_classes": list(all_classes.keys()),
415
+ "all_classes": all_classes,
416
+ }
417
+ else:
418
+ shap_block = {
419
+ "available": False,
420
+ "error": "No per-class SHAP vectors matched expected shape.",
421
+ }
422
 
423
+ # ---------- CASE 2: SHAP returns a numpy array ----------
424
  else:
425
+ arr = np.array(shap_vals, dtype=float)
426
+
427
+ # (1, K, D)
428
+ if (
429
+ arr.ndim == 3
430
+ and arr.shape[0] == 1
431
+ and arr.shape[1] == len(CLASSES)
432
+ and arr.shape[2] == len(FEATURES)
433
+ ):
434
+ for k, class_name in enumerate(CLASSES):
435
+ vec = arr[0, k, :] # (D,)
436
+ all_classes[class_name] = {
437
+ FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
438
+ }
439
+ shap_block = {
440
+ "available": True,
441
+ "mode": "per_class",
442
+ "explained_classes": list(all_classes.keys()),
443
+ "all_classes": all_classes,
444
+ }
445
+
446
+ # (K, D)
447
+ elif (
448
+ arr.ndim == 2
449
+ and arr.shape[0] == len(CLASSES)
450
+ and arr.shape[1] == len(FEATURES)
451
+ ):
452
+ for k, class_name in enumerate(CLASSES):
453
+ vec = arr[k, :] # (D,)
454
+ all_classes[class_name] = {
455
+ FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
456
+ }
457
+ shap_block = {
458
+ "available": True,
459
+ "mode": "per_class",
460
+ "explained_classes": list(all_classes.keys()),
461
+ "all_classes": all_classes,
462
+ }
463
+
464
+ # Single-vector fallback: (1, D) or (D,)
465
+ elif arr.ndim == 2 and arr.shape[0] == 1 and arr.shape[1] == len(FEATURES):
466
+ vec = arr[0, :] # (D,)
467
+ shap_block = {
468
+ "available": True,
469
+ "mode": "single_class",
470
+ "explained_class": CLASSES[pred_idx],
471
+ "values": {
472
+ FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
473
+ },
474
+ }
475
+
476
+ elif arr.ndim == 1 and arr.shape[0] == len(FEATURES):
477
+ vec = arr # (D,)
478
  shap_block = {
479
  "available": True,
480
  "mode": "single_class",
 
483
  FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
484
  },
485
  }
486
+
487
  else:
488
  shap_block = {
489
  "available": False,
490
+ "error": f"Unexpected SHAP array shape {arr.shape}",
491
  }
492
 
493
  except Exception as e:
 
496
  "error": str(e),
497
  "trace": traceback.format_exc(),
498
  }
 
499
  # ---------- 4) Build response ----------
500
  return {
501
  "input_ok": (len(missing) == 0),