Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -394,6 +394,7 @@ async def predict(req: Request):
|
|
| 394 |
if k >= len(shap_vals):
|
| 395 |
continue
|
| 396 |
arr = np.array(shap_vals[k], dtype=float) # shape (N, D) or (D,)
|
|
|
|
| 397 |
# reduce to a 1D (D,) vector for the first sample
|
| 398 |
if arr.ndim == 2 and arr.shape[0] >= 1 and arr.shape[1] == len(FEATURES):
|
| 399 |
vec = arr[0, :]
|
|
@@ -424,8 +425,29 @@ async def predict(req: Request):
|
|
| 424 |
else:
|
| 425 |
arr = np.array(shap_vals, dtype=float)
|
| 426 |
|
| 427 |
-
# (1, K,
|
| 428 |
if (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
arr.ndim == 3
|
| 430 |
and arr.shape[0] == 1
|
| 431 |
and arr.shape[1] == len(CLASSES)
|
|
@@ -436,6 +458,7 @@ async def predict(req: Request):
|
|
| 436 |
all_classes[class_name] = {
|
| 437 |
FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
|
| 438 |
}
|
|
|
|
| 439 |
shap_block = {
|
| 440 |
"available": True,
|
| 441 |
"mode": "per_class",
|
|
@@ -454,6 +477,7 @@ async def predict(req: Request):
|
|
| 454 |
all_classes[class_name] = {
|
| 455 |
FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
|
| 456 |
}
|
|
|
|
| 457 |
shap_block = {
|
| 458 |
"available": True,
|
| 459 |
"mode": "per_class",
|
|
|
|
| 394 |
if k >= len(shap_vals):
|
| 395 |
continue
|
| 396 |
arr = np.array(shap_vals[k], dtype=float) # shape (N, D) or (D,)
|
| 397 |
+
|
| 398 |
# reduce to a 1D (D,) vector for the first sample
|
| 399 |
if arr.ndim == 2 and arr.shape[0] >= 1 and arr.shape[1] == len(FEATURES):
|
| 400 |
vec = arr[0, :]
|
|
|
|
| 425 |
else:
|
| 426 |
arr = np.array(shap_vals, dtype=float)
|
| 427 |
|
| 428 |
+
# (1, D, K) <-- THIS IS YOUR (1, 21, 5) CASE
|
| 429 |
if (
|
| 430 |
+
arr.ndim == 3
|
| 431 |
+
and arr.shape[0] == 1
|
| 432 |
+
and arr.shape[1] == len(FEATURES)
|
| 433 |
+
and arr.shape[2] == len(CLASSES)
|
| 434 |
+
):
|
| 435 |
+
# first sample, loop over classes on last axis
|
| 436 |
+
for k, class_name in enumerate(CLASSES):
|
| 437 |
+
vec = arr[0, :, k] # (D,)
|
| 438 |
+
all_classes[class_name] = {
|
| 439 |
+
FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
shap_block = {
|
| 443 |
+
"available": True,
|
| 444 |
+
"mode": "per_class",
|
| 445 |
+
"explained_classes": list(all_classes.keys()),
|
| 446 |
+
"all_classes": all_classes,
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
# (1, K, D)
|
| 450 |
+
elif (
|
| 451 |
arr.ndim == 3
|
| 452 |
and arr.shape[0] == 1
|
| 453 |
and arr.shape[1] == len(CLASSES)
|
|
|
|
| 458 |
all_classes[class_name] = {
|
| 459 |
FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
|
| 460 |
}
|
| 461 |
+
|
| 462 |
shap_block = {
|
| 463 |
"available": True,
|
| 464 |
"mode": "per_class",
|
|
|
|
| 477 |
all_classes[class_name] = {
|
| 478 |
FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
|
| 479 |
}
|
| 480 |
+
|
| 481 |
shap_block = {
|
| 482 |
"available": True,
|
| 483 |
"mode": "per_class",
|