COCODEDE04 commited on
Commit
1122e44
·
verified ·
1 Parent(s): 130812d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -1
app.py CHANGED
@@ -394,6 +394,7 @@ async def predict(req: Request):
394
  if k >= len(shap_vals):
395
  continue
396
  arr = np.array(shap_vals[k], dtype=float) # shape (N, D) or (D,)
 
397
  # reduce to a 1D (D,) vector for the first sample
398
  if arr.ndim == 2 and arr.shape[0] >= 1 and arr.shape[1] == len(FEATURES):
399
  vec = arr[0, :]
@@ -424,8 +425,29 @@ async def predict(req: Request):
424
  else:
425
  arr = np.array(shap_vals, dtype=float)
426
 
427
- # (1, K, D)
428
  if (
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  arr.ndim == 3
430
  and arr.shape[0] == 1
431
  and arr.shape[1] == len(CLASSES)
@@ -436,6 +458,7 @@ async def predict(req: Request):
436
  all_classes[class_name] = {
437
  FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
438
  }
 
439
  shap_block = {
440
  "available": True,
441
  "mode": "per_class",
@@ -454,6 +477,7 @@ async def predict(req: Request):
454
  all_classes[class_name] = {
455
  FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
456
  }
 
457
  shap_block = {
458
  "available": True,
459
  "mode": "per_class",
 
394
  if k >= len(shap_vals):
395
  continue
396
  arr = np.array(shap_vals[k], dtype=float) # shape (N, D) or (D,)
397
+
398
  # reduce to a 1D (D,) vector for the first sample
399
  if arr.ndim == 2 and arr.shape[0] >= 1 and arr.shape[1] == len(FEATURES):
400
  vec = arr[0, :]
 
425
  else:
426
  arr = np.array(shap_vals, dtype=float)
427
 
428
+ # (1, D, K) <-- THIS IS YOUR (1, 21, 5) CASE
429
  if (
430
+ arr.ndim == 3
431
+ and arr.shape[0] == 1
432
+ and arr.shape[1] == len(FEATURES)
433
+ and arr.shape[2] == len(CLASSES)
434
+ ):
435
+ # first sample, loop over classes on last axis
436
+ for k, class_name in enumerate(CLASSES):
437
+ vec = arr[0, :, k] # (D,)
438
+ all_classes[class_name] = {
439
+ FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
440
+ }
441
+
442
+ shap_block = {
443
+ "available": True,
444
+ "mode": "per_class",
445
+ "explained_classes": list(all_classes.keys()),
446
+ "all_classes": all_classes,
447
+ }
448
+
449
+ # (1, K, D)
450
+ elif (
451
  arr.ndim == 3
452
  and arr.shape[0] == 1
453
  and arr.shape[1] == len(CLASSES)
 
458
  all_classes[class_name] = {
459
  FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
460
  }
461
+
462
  shap_block = {
463
  "available": True,
464
  "mode": "per_class",
 
477
  all_classes[class_name] = {
478
  FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
479
  }
480
+
481
  shap_block = {
482
  "available": True,
483
  "mode": "per_class",