Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -362,95 +362,99 @@ async def predict(req: Request):
|
|
| 362 |
Body: JSON object mapping feature -> numeric value (strings with commas/points ok).
|
| 363 |
Missing features are imputed if imputer present; else filled with means (if stats) or 0.
|
| 364 |
|
| 365 |
-
|
| 366 |
-
|
|
|
|
|
|
|
| 367 |
"""
|
| 368 |
try:
|
| 369 |
payload = await req.json()
|
| 370 |
if not isinstance(payload, dict):
|
| 371 |
-
return JSONResponse(
|
|
|
|
|
|
|
|
|
|
| 372 |
|
| 373 |
-
#
|
| 374 |
-
raw = build_raw_vector(payload)
|
| 375 |
-
raw_imp = apply_imputer_if_any(raw)
|
| 376 |
-
z_vec, z_detail, z_mode = apply_scaling_or_stats(raw_imp) #
|
| 377 |
|
| 378 |
-
#
|
| 379 |
-
X_z = z_vec.reshape(1, -1).astype(np.float32)
|
| 380 |
-
raw_logits = model.predict(X_z, verbose=0)
|
| 381 |
-
probs, decode_mode = decode_logits(raw_logits)
|
| 382 |
|
| 383 |
pred_idx = int(np.argmax(probs))
|
| 384 |
-
pred_class = CLASSES[pred_idx]
|
| 385 |
probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
|
| 386 |
missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
|
| 387 |
|
| 388 |
-
#
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
if not SHAP_AVAILABLE:
|
| 392 |
-
shap_payload = {
|
| 393 |
-
"available": False,
|
| 394 |
-
"reason": "SHAP library not installed in this environment.",
|
| 395 |
-
}
|
| 396 |
-
else:
|
| 397 |
try:
|
| 398 |
-
#
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
# Background: 50 "average" institutions at z=0
|
| 408 |
-
background_z = np.zeros((50, len(FEATURES)), dtype=np.float32)
|
| 409 |
-
|
| 410 |
-
# KernelExplainer for a scalar-output model
|
| 411 |
-
explainer = shap.KernelExplainer(f_scalar, background_z)
|
| 412 |
-
|
| 413 |
-
# SHAP for this one observation (in z-space)
|
| 414 |
-
shap_vals = explainer.shap_values(X_z, nsamples=50)
|
| 415 |
-
# For scalar output, shap_vals is usually a 2D array (N, D),
|
| 416 |
-
# but some versions wrap it in a list. Handle both:
|
| 417 |
if isinstance(shap_vals, list):
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
else:
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
# Expect (1, n_features)
|
| 423 |
-
if shap_mat.ndim == 1:
|
| 424 |
-
shap_mat = shap_mat.reshape(1, -1)
|
| 425 |
-
|
| 426 |
-
if shap_mat.shape[0] != 1:
|
| 427 |
-
raise ValueError(f"Unexpected SHAP batch size {shap_mat.shape[0]} (expected 1)")
|
| 428 |
-
if shap_mat.shape[1] != len(FEATURES):
|
| 429 |
-
raise ValueError(
|
| 430 |
-
f"Unexpected SHAP vector length {shap_mat.shape[1]} "
|
| 431 |
-
f"(expected {len(FEATURES)})"
|
| 432 |
)
|
| 433 |
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
shap_feature_contribs = {
|
| 437 |
-
FEATURES[i]: float(shap_vec[i]) for i in range(len(FEATURES))
|
| 438 |
-
}
|
| 439 |
-
|
| 440 |
-
shap_payload = {
|
| 441 |
"available": True,
|
| 442 |
-
"
|
| 443 |
-
"
|
| 444 |
}
|
| 445 |
|
| 446 |
except Exception as e:
|
| 447 |
-
|
| 448 |
"available": False,
|
| 449 |
"error": str(e),
|
| 450 |
"trace": traceback.format_exc(),
|
| 451 |
}
|
| 452 |
|
| 453 |
-
#
|
| 454 |
return {
|
| 455 |
"input_ok": (len(missing) == 0),
|
| 456 |
"missing": missing,
|
|
@@ -459,10 +463,10 @@ async def predict(req: Request):
|
|
| 459 |
"scaler": bool(scaler),
|
| 460 |
"z_mode": z_mode,
|
| 461 |
},
|
| 462 |
-
"z_scores": z_detail, # per
|
| 463 |
-
"probabilities": probs_dict,
|
| 464 |
-
"predicted_state":
|
| 465 |
-
"shap":
|
| 466 |
"debug": {
|
| 467 |
"raw_shape": list(raw_logits.shape),
|
| 468 |
"decode_mode": decode_mode,
|
|
|
|
| 362 |
Body: JSON object mapping feature -> numeric value (strings with commas/points ok).
|
| 363 |
Missing features are imputed if imputer present; else filled with means (if stats) or 0.
|
| 364 |
|
| 365 |
+
Returns:
|
| 366 |
+
- probabilities over classes
|
| 367 |
+
- z-scores per indicator
|
| 368 |
+
- SHAP contributions for *all* classes (if SHAP is available), in z-space.
|
| 369 |
"""
|
| 370 |
try:
|
| 371 |
payload = await req.json()
|
| 372 |
if not isinstance(payload, dict):
|
| 373 |
+
return JSONResponse(
|
| 374 |
+
status_code=400,
|
| 375 |
+
content={"error": "Expected JSON object"},
|
| 376 |
+
)
|
| 377 |
|
| 378 |
+
# 1) Build raw feature vector in training order
|
| 379 |
+
raw = build_raw_vector(payload) # may contain NaNs
|
| 380 |
+
raw_imp = apply_imputer_if_any(raw) # impute
|
| 381 |
+
z_vec, z_detail, z_mode = apply_scaling_or_stats(raw_imp) # scale / z-score
|
| 382 |
|
| 383 |
+
# 2) Predict
|
| 384 |
+
X_z = z_vec.reshape(1, -1).astype(np.float32) # (1, D) in z-space
|
| 385 |
+
raw_logits = model.predict(X_z, verbose=0) # (1, M)
|
| 386 |
+
probs, decode_mode = decode_logits(raw_logits) # (K,)
|
| 387 |
|
| 388 |
pred_idx = int(np.argmax(probs))
|
|
|
|
| 389 |
probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
|
| 390 |
missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
|
| 391 |
|
| 392 |
+
# 3) SHAP for ALL classes (if explainer is available)
|
| 393 |
+
shap_block: Dict[str, Any] = {"available": False}
|
| 394 |
+
if EXPLAINER is not None and SHAP_AVAILABLE:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
try:
|
| 396 |
+
# KernelExplainer built with model_proba_from_z, so we pass z-space
|
| 397 |
+
shap_vals = EXPLAINER.shap_values(X_z, nsamples=50)
|
| 398 |
+
K = len(CLASSES)
|
| 399 |
+
D = len(FEATURES)
|
| 400 |
+
|
| 401 |
+
all_classes: Dict[str, Dict[str, float]] = {}
|
| 402 |
+
|
| 403 |
+
# Case 1: vector-output model → list of length K
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
if isinstance(shap_vals, list):
|
| 405 |
+
if len(shap_vals) != K:
|
| 406 |
+
raise ValueError(
|
| 407 |
+
f"Expected {K} SHAP arrays (one per class), got {len(shap_vals)}"
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
for c_idx, cname in enumerate(CLASSES):
|
| 411 |
+
arr = np.asarray(shap_vals[c_idx])
|
| 412 |
+
if arr.ndim != 2 or arr.shape[0] < 1 or arr.shape[1] != D:
|
| 413 |
+
raise ValueError(
|
| 414 |
+
f"Unexpected SHAP shape for class {cname}: {arr.shape}, expected (1,{D})"
|
| 415 |
+
)
|
| 416 |
+
vec = arr[0] # (D,)
|
| 417 |
+
all_classes[cname] = {
|
| 418 |
+
FEATURES[i]: float(vec[i]) for i in range(D)
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
# Case 2: some SHAP versions return a single (K,D) array
|
| 422 |
+
elif isinstance(shap_vals, np.ndarray):
|
| 423 |
+
arr = np.asarray(shap_vals)
|
| 424 |
+
if arr.ndim == 3 and arr.shape[0] == 1 and arr.shape[2] == D:
|
| 425 |
+
# shape (1, K, D) → take [0]
|
| 426 |
+
arr = arr[0]
|
| 427 |
+
if arr.ndim != 2 or arr.shape[0] != K or arr.shape[1] != D:
|
| 428 |
+
raise ValueError(
|
| 429 |
+
f"Unexpected SHAP ndarray shape {arr.shape}; "
|
| 430 |
+
f"expected (K,{D}) or (1,K,{D})"
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
for c_idx, cname in enumerate(CLASSES):
|
| 434 |
+
vec = arr[c_idx] # (D,)
|
| 435 |
+
all_classes[cname] = {
|
| 436 |
+
FEATURES[i]: float(vec[i]) for i in range(D)
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
else:
|
| 440 |
+
raise TypeError(
|
| 441 |
+
f"Unsupported SHAP output type: {type(shap_vals).__name__}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
)
|
| 443 |
|
| 444 |
+
shap_block = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
"available": True,
|
| 446 |
+
"predicted_class": CLASSES[pred_idx],
|
| 447 |
+
"all_classes": all_classes,
|
| 448 |
}
|
| 449 |
|
| 450 |
except Exception as e:
|
| 451 |
+
shap_block = {
|
| 452 |
"available": False,
|
| 453 |
"error": str(e),
|
| 454 |
"trace": traceback.format_exc(),
|
| 455 |
}
|
| 456 |
|
| 457 |
+
# 4) Final response
|
| 458 |
return {
|
| 459 |
"input_ok": (len(missing) == 0),
|
| 460 |
"missing": missing,
|
|
|
|
| 463 |
"scaler": bool(scaler),
|
| 464 |
"z_mode": z_mode,
|
| 465 |
},
|
| 466 |
+
"z_scores": z_detail, # per indicator, in z-space
|
| 467 |
+
"probabilities": probs_dict,
|
| 468 |
+
"predicted_state": CLASSES[pred_idx],
|
| 469 |
+
"shap": shap_block,
|
| 470 |
"debug": {
|
| 471 |
"raw_shape": list(raw_logits.shape),
|
| 472 |
"decode_mode": decode_mode,
|