Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -372,7 +372,7 @@ async def predict(req: Request):
|
|
| 372 |
probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
|
| 373 |
missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
|
| 374 |
|
| 375 |
-
|
| 376 |
"input_ok": (len(missing) == 0),
|
| 377 |
"missing": missing,
|
| 378 |
"preprocess": {
|
|
@@ -380,35 +380,67 @@ async def predict(req: Request):
|
|
| 380 |
"scaler": bool(scaler),
|
| 381 |
"z_mode": z_mode,
|
| 382 |
},
|
| 383 |
-
"z_scores": z_detail,
|
| 384 |
"probabilities": probs_dict,
|
| 385 |
"predicted_state": CLASSES[pred_idx],
|
|
|
|
| 386 |
"debug": {
|
| 387 |
"raw_shape": list(raw_logits.shape),
|
| 388 |
"decode_mode": mode,
|
| 389 |
"raw_first_row": [float(v) for v in raw_logits[0]],
|
| 390 |
},
|
| 391 |
}
|
| 392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
# ---- SHAP explanation for predicted class ----
|
|
|
|
|
|
|
| 394 |
if EXPLAINER is not None:
|
| 395 |
try:
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
else:
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
except Exception as e:
|
| 407 |
-
|
|
|
|
|
|
|
|
|
|
| 408 |
else:
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
return resp
|
| 412 |
|
| 413 |
except Exception as e:
|
| 414 |
return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})
|
|
|
|
| 372 |
probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
|
| 373 |
missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
|
| 374 |
|
| 375 |
+
return {
|
| 376 |
"input_ok": (len(missing) == 0),
|
| 377 |
"missing": missing,
|
| 378 |
"preprocess": {
|
|
|
|
| 380 |
"scaler": bool(scaler),
|
| 381 |
"z_mode": z_mode,
|
| 382 |
},
|
| 383 |
+
"z_scores": z_detail,
|
| 384 |
"probabilities": probs_dict,
|
| 385 |
"predicted_state": CLASSES[pred_idx],
|
| 386 |
+
"shap": shap_out,
|
| 387 |
"debug": {
|
| 388 |
"raw_shape": list(raw_logits.shape),
|
| 389 |
"decode_mode": mode,
|
| 390 |
"raw_first_row": [float(v) for v in raw_logits[0]],
|
| 391 |
},
|
| 392 |
}
|
| 393 |
+
|
| 394 |
+
pred_idx = int(np.argmax(probs))
|
| 395 |
+
probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
|
| 396 |
+
missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
|
| 397 |
+
|
| 398 |
# ---- SHAP explanation for predicted class ----
|
| 399 |
+
# -------- SHAP EXPLANATION (predicted class only) --------
|
| 400 |
+
shap_out = None
|
| 401 |
if EXPLAINER is not None:
|
| 402 |
try:
|
| 403 |
+
# X is already z-space: shape (1, n_features)
|
| 404 |
+
shap_vals = EXPLAINER.shap_values(X, nsamples=100)
|
| 405 |
+
|
| 406 |
+
# Case 1: multi-output -> list of length K, each (1, n_features)
|
| 407 |
+
if isinstance(shap_vals, list):
|
| 408 |
+
shap_vec = np.array(shap_vals[pred_idx][0], dtype=float)
|
| 409 |
+
# expected_value may also be a list per class
|
| 410 |
+
exp_val_raw = EXPLAINER.expected_value
|
| 411 |
+
if isinstance(exp_val_raw, (list, np.ndarray)):
|
| 412 |
+
exp_val = float(exp_val_raw[pred_idx])
|
| 413 |
+
else:
|
| 414 |
+
exp_val = float(exp_val_raw)
|
| 415 |
+
# Case 2: single-output -> ndarray (1, n_features)
|
| 416 |
+
elif isinstance(shap_vals, np.ndarray):
|
| 417 |
+
shap_vec = np.array(shap_vals[0], dtype=float)
|
| 418 |
+
exp_val_raw = EXPLAINER.expected_value
|
| 419 |
+
if isinstance(exp_val_raw, (list, np.ndarray)):
|
| 420 |
+
exp_val = float(exp_val_raw[0])
|
| 421 |
+
else:
|
| 422 |
+
exp_val = float(exp_val_raw)
|
| 423 |
else:
|
| 424 |
+
raise TypeError(f"Unsupported SHAP return type: {type(shap_vals)}")
|
| 425 |
+
|
| 426 |
+
# Map feature -> SHAP contribution (for the predicted class)
|
| 427 |
+
shap_feature_contribs = {
|
| 428 |
+
FEATURES[i]: float(shap_vec[i])
|
| 429 |
+
for i in range(len(FEATURES))
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
shap_out = {
|
| 433 |
+
"explained_class": CLASSES[pred_idx],
|
| 434 |
+
"expected_value": exp_val,
|
| 435 |
+
"shap_values": shap_feature_contribs,
|
| 436 |
+
}
|
| 437 |
except Exception as e:
|
| 438 |
+
shap_out = {
|
| 439 |
+
"error": str(e),
|
| 440 |
+
"trace": traceback.format_exc()
|
| 441 |
+
}
|
| 442 |
else:
|
| 443 |
+
shap_out = {"error": "SHAP not available on server"}
|
|
|
|
|
|
|
| 444 |
|
| 445 |
except Exception as e:
|
| 446 |
return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})
|