COCODEDE04 committed on
Commit
465e1a1
·
verified ·
1 Parent(s): 706263e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -39
app.py CHANGED
@@ -234,23 +234,25 @@ def apply_scaling_or_stats(raw_vec: np.ndarray) -> (np.ndarray, Dict[str, float]
234
  # --------- SHAP: model wrapper & explainer ---------
235
  def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
236
  """
237
- Wrapper for SHAP: takes (N, n_features) in z-space and returns (N, K) probabilities.
 
 
238
  """
239
- raw = model.predict(z_batch_np, verbose=0)
240
- if raw.ndim != 2:
241
- raise ValueError(f"Unexpected raw shape from model: {raw.shape}")
242
- N, M = raw.shape
243
- K = len(CLASSES)
244
 
245
- if M == K - 1:
246
- probs = coral_probs_from_logits(raw) # (N, K)
247
- elif M == K:
248
- exps = np.exp(raw - np.max(raw, axis=1, keepdims=True))
249
- probs = exps / np.sum(exps, axis=1, keepdims=True)
250
- else:
251
- s = np.sum(np.abs(raw), axis=1, keepdims=True)
252
- probs = np.divide(raw, s, out=np.ones_like(raw) / max(M, 1), where=(s > 0))
253
- return probs
 
 
 
 
254
 
255
 
256
  EXPLAINER = None
@@ -341,7 +343,7 @@ async def predict(req: Request):
341
  Missing features are imputed if imputer present; else filled with means (if stats) or 0.
342
 
343
  This endpoint ALSO computes SHAP values for the *predicted class only*,
344
- returning one SHAP value per feature (21 in total).
345
  """
346
  try:
347
  payload = await req.json()
@@ -367,50 +369,50 @@ async def predict(req: Request):
367
  shap_payload: Dict[str, Any]
368
 
369
  if not SHAP_AVAILABLE:
370
- # shap library not installed in this environment
371
  shap_payload = {
372
  "available": False,
373
  "reason": "SHAP library not installed in this environment.",
374
  }
375
  else:
376
  try:
377
- # Helper: probability function in *z-space*
378
- def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
379
- """
380
- Takes (N, n_features) in z-space and returns (N, K) probabilities.
381
- This mirrors the normal predict pipeline but assumes we're already in z-space.
382
- """
383
- raw_local = model.predict(z_batch_np, verbose=0)
384
- return decode_logits(raw_local)[0].reshape(-1, len(CLASSES))
385
-
386
  # Scalar function: probability of the *predicted* class only
387
  def f_scalar(z_batch):
388
- z_batch = np.array(z_batch, dtype=np.float32)
 
 
 
389
  probs_batch = model_proba_from_z(z_batch) # (N, K)
390
  return probs_batch[:, pred_idx] # (N,)
391
 
392
  # Background: 50 "average" institutions at z=0
393
  background_z = np.zeros((50, len(FEATURES)), dtype=np.float32)
394
 
395
- # Create a per-call KernelExplainer for this scalar output
396
  explainer = shap.KernelExplainer(f_scalar, background_z)
397
 
398
- # SHAP for this *one* observation (in z-space)
399
  shap_vals = explainer.shap_values(X_z, nsamples=50)
400
- shap_arr = np.array(shap_vals)
401
-
402
- # We expect shape (1, n_features) or (n_features,)
403
- if shap_arr.ndim == 2 and shap_arr.shape[0] == 1:
404
- shap_vec = shap_arr[0]
405
  else:
406
- shap_vec = shap_arr.reshape(-1)
407
 
408
- if shap_vec.size != len(FEATURES):
 
 
 
 
 
 
409
  raise ValueError(
410
- f"Unexpected SHAP vector length {shap_vec.size} "
411
  f"(expected {len(FEATURES)})"
412
  )
413
 
 
 
414
  shap_feature_contribs = {
415
  FEATURES[i]: float(shap_vec[i]) for i in range(len(FEATURES))
416
  }
@@ -452,5 +454,4 @@ async def predict(req: Request):
452
  return JSONResponse(
453
  status_code=500,
454
  content={"error": str(e), "trace": traceback.format_exc()},
455
- )
456
-
 
234
  # --------- SHAP: model wrapper & explainer ---------
235
  def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
236
  """
237
+ Takes (N, n_features) OR a single 1D sample in z-space
238
+ and returns (N, K) probabilities.
239
+ Safe for both normal /predict and SHAP calls.
240
  """
241
+ z = np.array(z_batch_np, dtype=np.float32)
 
 
 
 
242
 
243
+ # Ensure 2D: (N, D)
244
+ if z.ndim == 1:
245
+ z = z.reshape(1, -1)
246
+
247
+ raw = model.predict(z, verbose=0)
248
+ probs, _ = decode_logits(raw)
249
+
250
+ # decode_logits may return (K,) if N=1, so enforce 2D
251
+ probs = np.array(probs, dtype=np.float32)
252
+ if probs.ndim == 1:
253
+ probs = probs.reshape(1, -1)
254
+
255
+ return probs # shape: (N, K)
256
 
257
 
258
  EXPLAINER = None
 
343
  Missing features are imputed if imputer present; else filled with means (if stats) or 0.
344
 
345
  This endpoint ALSO computes SHAP values for the *predicted class only*,
346
+ returning one SHAP value per feature (21 in total) when SHAP is available.
347
  """
348
  try:
349
  payload = await req.json()
 
369
  shap_payload: Dict[str, Any]
370
 
371
  if not SHAP_AVAILABLE:
 
372
  shap_payload = {
373
  "available": False,
374
  "reason": "SHAP library not installed in this environment.",
375
  }
376
  else:
377
  try:
 
 
 
 
 
 
 
 
 
378
  # Scalar function: probability of the *predicted* class only
379
  def f_scalar(z_batch):
380
+ """
381
+ z_batch: (N, D) or (D,)
382
+ returns: (N,) probability of the predicted class
383
+ """
384
  probs_batch = model_proba_from_z(z_batch) # (N, K)
385
  return probs_batch[:, pred_idx] # (N,)
386
 
387
  # Background: 50 "average" institutions at z=0
388
  background_z = np.zeros((50, len(FEATURES)), dtype=np.float32)
389
 
390
+ # KernelExplainer for a scalar-output model
391
  explainer = shap.KernelExplainer(f_scalar, background_z)
392
 
393
+ # SHAP for this one observation (in z-space)
394
  shap_vals = explainer.shap_values(X_z, nsamples=50)
395
+ # For scalar output, shap_vals is usually a 2D array (N, D),
396
+ # but some versions wrap it in a list. Handle both:
397
+ if isinstance(shap_vals, list):
398
+ shap_mat = np.array(shap_vals[0])
 
399
  else:
400
+ shap_mat = np.array(shap_vals)
401
 
402
+ # Expect (1, n_features)
403
+ if shap_mat.ndim == 1:
404
+ shap_mat = shap_mat.reshape(1, -1)
405
+
406
+ if shap_mat.shape[0] != 1:
407
+ raise ValueError(f"Unexpected SHAP batch size {shap_mat.shape[0]} (expected 1)")
408
+ if shap_mat.shape[1] != len(FEATURES):
409
  raise ValueError(
410
+ f"Unexpected SHAP vector length {shap_mat.shape[1]} "
411
  f"(expected {len(FEATURES)})"
412
  )
413
 
414
+ shap_vec = shap_mat[0] # (n_features,)
415
+
416
  shap_feature_contribs = {
417
  FEATURES[i]: float(shap_vec[i]) for i in range(len(FEATURES))
418
  }
 
454
  return JSONResponse(
455
  status_code=500,
456
  content={"error": str(e), "trace": traceback.format_exc()},
457
+ )