COCODEDE04 commited on
Commit
b8d0ae8
·
verified ·
1 Parent(s): c7c0f5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -16
app.py CHANGED
@@ -372,7 +372,7 @@ async def predict(req: Request):
372
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
373
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
374
 
375
- resp: Dict[str, Any] = {
376
  "input_ok": (len(missing) == 0),
377
  "missing": missing,
378
  "preprocess": {
@@ -380,35 +380,67 @@ async def predict(req: Request):
380
  "scaler": bool(scaler),
381
  "z_mode": z_mode,
382
  },
383
- "z_scores": z_detail, # per feature
384
  "probabilities": probs_dict,
385
  "predicted_state": CLASSES[pred_idx],
 
386
  "debug": {
387
  "raw_shape": list(raw_logits.shape),
388
  "decode_mode": mode,
389
  "raw_first_row": [float(v) for v in raw_logits[0]],
390
  },
391
  }
392
-
 
 
 
 
393
  # ---- SHAP explanation for predicted class ----
 
 
394
  if EXPLAINER is not None:
395
  try:
396
- shap_vals_list = EXPLAINER.shap_values(X, nsamples="auto")
397
- # shap_vals_list is a list of length K (classes)
398
- if isinstance(shap_vals_list, list) and len(shap_vals_list) == len(CLASSES):
399
- shap_for_pred = shap_vals_list[pred_idx][0] # (n_features,)
400
- resp["shap_target"] = CLASSES[pred_idx]
401
- resp["shap_values"] = {
402
- FEATURES[i]: float(shap_for_pred[i]) for i in range(len(FEATURES))
403
- }
 
 
 
 
 
 
 
 
 
 
 
 
404
  else:
405
- resp["shap_error"] = "Unexpected SHAP output shape."
 
 
 
 
 
 
 
 
 
 
 
 
406
  except Exception as e:
407
- resp["shap_error"] = f"SHAP computation failed: {repr(e)}"
 
 
 
408
  else:
409
- resp["shap_error"] = "SHAP not available in this environment."
410
-
411
- return resp
412
 
413
  except Exception as e:
414
  return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})
 
372
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
373
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
374
 
375
+ return {
376
  "input_ok": (len(missing) == 0),
377
  "missing": missing,
378
  "preprocess": {
 
380
  "scaler": bool(scaler),
381
  "z_mode": z_mode,
382
  },
383
+ "z_scores": z_detail,
384
  "probabilities": probs_dict,
385
  "predicted_state": CLASSES[pred_idx],
386
+ "shap": shap_out,
387
  "debug": {
388
  "raw_shape": list(raw_logits.shape),
389
  "decode_mode": mode,
390
  "raw_first_row": [float(v) for v in raw_logits[0]],
391
  },
392
  }
393
+
394
+ pred_idx = int(np.argmax(probs))
395
+ probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
396
+ missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
397
+
398
  # ---- SHAP explanation for predicted class ----
399
+ # -------- SHAP EXPLANATION (predicted class only) --------
400
+ shap_out = None
401
  if EXPLAINER is not None:
402
  try:
403
+ # X is already z-space: shape (1, n_features)
404
+ shap_vals = EXPLAINER.shap_values(X, nsamples=100)
405
+
406
+ # Case 1: multi-output -> list of length K, each (1, n_features)
407
+ if isinstance(shap_vals, list):
408
+ shap_vec = np.array(shap_vals[pred_idx][0], dtype=float)
409
+ # expected_value may also be a list per class
410
+ exp_val_raw = EXPLAINER.expected_value
411
+ if isinstance(exp_val_raw, (list, np.ndarray)):
412
+ exp_val = float(exp_val_raw[pred_idx])
413
+ else:
414
+ exp_val = float(exp_val_raw)
415
+ # Case 2: single-output -> ndarray (1, n_features)
416
+ elif isinstance(shap_vals, np.ndarray):
417
+ shap_vec = np.array(shap_vals[0], dtype=float)
418
+ exp_val_raw = EXPLAINER.expected_value
419
+ if isinstance(exp_val_raw, (list, np.ndarray)):
420
+ exp_val = float(exp_val_raw[0])
421
+ else:
422
+ exp_val = float(exp_val_raw)
423
  else:
424
+ raise TypeError(f"Unsupported SHAP return type: {type(shap_vals)}")
425
+
426
+ # Map feature -> SHAP contribution (for the predicted class)
427
+ shap_feature_contribs = {
428
+ FEATURES[i]: float(shap_vec[i])
429
+ for i in range(len(FEATURES))
430
+ }
431
+
432
+ shap_out = {
433
+ "explained_class": CLASSES[pred_idx],
434
+ "expected_value": exp_val,
435
+ "shap_values": shap_feature_contribs,
436
+ }
437
  except Exception as e:
438
+ shap_out = {
439
+ "error": str(e),
440
+ "trace": traceback.format_exc()
441
+ }
442
  else:
443
+ shap_out = {"error": "SHAP not available on server"}
 
 
444
 
445
  except Exception as e:
446
  return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})