COCODEDE04 committed on
Commit
706263e
·
verified ·
1 Parent(s): a6c0646

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -60
app.py CHANGED
@@ -339,73 +339,87 @@ async def predict(req: Request):
339
  """
340
  Body: JSON object mapping feature -> numeric value (strings with commas/points ok).
341
  Missing features are imputed if imputer present; else filled with means (if stats) or 0.
 
 
 
342
  """
343
  try:
344
  payload = await req.json()
345
  if not isinstance(payload, dict):
346
  return JSONResponse(status_code=400, content={"error": "Expected JSON object"})
347
 
348
- # Build in EXACT training order
349
  raw = build_raw_vector(payload) # may contain NaNs
350
- raw_imp = apply_imputer_if_any(raw) # impute
351
- z_vec, z_detail, z_mode = apply_scaling_or_stats(raw_imp) # scale / z-score
352
 
353
- # Predict
354
- X = z_vec.reshape(1, -1).astype(np.float32)
355
- raw_logits = model.predict(X, verbose=0)
356
- probs, mode = decode_logits(raw_logits)
357
 
358
  pred_idx = int(np.argmax(probs))
 
359
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
360
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
361
 
362
- # ---------- SHAP for ALL classes ----------
363
- shap_payload: Dict[str, Any] = {"available": bool(EXPLAINER)}
364
- if EXPLAINER is not None:
 
 
 
 
 
 
 
365
  try:
366
- shap_raw = EXPLAINER.shap_values(X, nsamples=100)
367
- shap_all_classes: Dict[str, Dict[str, float]] = {}
368
-
369
- if isinstance(shap_raw, list):
370
- # standard KernelExplainer multi-output: list of length K, each (1, n_features)
371
- for c_idx, cls_name in enumerate(CLASSES):
372
- if c_idx >= len(shap_raw):
373
- break
374
- arr = np.array(shap_raw[c_idx])
375
- if arr.ndim == 2:
376
- vec = arr[0]
377
- else:
378
- vec = arr.reshape(-1)
379
- m = min(len(FEATURES), len(vec))
380
- shap_all_classes[cls_name] = {
381
- FEATURES[i]: float(vec[i]) for i in range(m)
382
- }
 
 
 
 
 
 
 
 
 
 
 
383
  else:
384
- # Fallback: single ndarray, try to interpret first dim as classes
385
- arr = np.array(shap_raw)
386
- if arr.ndim == 3:
387
- # e.g. (K, 1, n_features) or (1, K, n_features)
388
- if arr.shape[1] == 1:
389
- arr2 = arr[:, 0, :]
390
- elif arr.shape[0] == 1:
391
- arr2 = arr[0, :, :]
392
- else:
393
- arr2 = arr.reshape(arr.shape[0], -1)
394
- elif arr.ndim == 2:
395
- # (K, n_features)
396
- arr2 = arr
397
- else:
398
- raise ValueError(f"Unsupported SHAP array shape: {arr.shape}")
399
-
400
- K_eff = min(arr2.shape[0], len(CLASSES))
401
- for c_idx in range(K_eff):
402
- vec = arr2[c_idx]
403
- m = min(len(FEATURES), len(vec))
404
- shap_all_classes[CLASSES[c_idx]] = {
405
- FEATURES[i]: float(vec[i]) for i in range(m)
406
- }
407
-
408
- shap_payload["all_classes"] = shap_all_classes
409
 
410
  except Exception as e:
411
  shap_payload = {
@@ -414,7 +428,7 @@ async def predict(req: Request):
414
  "trace": traceback.format_exc(),
415
  }
416
 
417
- # ---------- final response ----------
418
  return {
419
  "input_ok": (len(missing) == 0),
420
  "missing": missing,
@@ -423,13 +437,13 @@ async def predict(req: Request):
423
  "scaler": bool(scaler),
424
  "z_mode": z_mode,
425
  },
426
- "z_scores": z_detail, # per feature
427
- "probabilities": probs_dict,
428
- "predicted_state": CLASSES[pred_idx],
429
- "shap": shap_payload, # FULL per-class SHAP matrix
430
  "debug": {
431
  "raw_shape": list(raw_logits.shape),
432
- "decode_mode": mode,
433
  "raw_first_row": [float(v) for v in raw_logits[0]],
434
  },
435
  }
@@ -437,5 +451,6 @@ async def predict(req: Request):
437
  except Exception as e:
438
  return JSONResponse(
439
  status_code=500,
440
- content={"error": str(e), "trace": traceback.format_exc()}
441
- )
 
 
339
  """
340
  Body: JSON object mapping feature -> numeric value (strings with commas/points ok).
341
  Missing features are imputed if imputer present; else filled with means (if stats) or 0.
342
+
343
+ This endpoint ALSO computes SHAP values for the *predicted class only*,
344
+ returning one SHAP value per feature (21 in total).
345
  """
346
  try:
347
  payload = await req.json()
348
  if not isinstance(payload, dict):
349
  return JSONResponse(status_code=400, content={"error": "Expected JSON object"})
350
 
351
+ # ---------- 1) Build features in EXACT training order ----------
352
  raw = build_raw_vector(payload) # may contain NaNs
353
+ raw_imp = apply_imputer_if_any(raw) # median / training imputer
354
+ z_vec, z_detail, z_mode = apply_scaling_or_stats(raw_imp) # scaler or manual z-score
355
 
356
+ # ---------- 2) Model prediction ----------
357
+ X_z = z_vec.reshape(1, -1).astype(np.float32)
358
+ raw_logits = model.predict(X_z, verbose=0)
359
+ probs, decode_mode = decode_logits(raw_logits)
360
 
361
  pred_idx = int(np.argmax(probs))
362
+ pred_class = CLASSES[pred_idx]
363
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
364
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
365
 
366
+ # ---------- 3) SHAP explanation for the predicted class ----------
367
+ shap_payload: Dict[str, Any]
368
+
369
+ if not SHAP_AVAILABLE:
370
+ # shap library not installed in this environment
371
+ shap_payload = {
372
+ "available": False,
373
+ "reason": "SHAP library not installed in this environment.",
374
+ }
375
+ else:
376
  try:
377
+ # Helper: probability function in *z-space*
378
+ def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
379
+ """
380
+ Takes (N, n_features) in z-space and returns (N, K) probabilities.
381
+ This mirrors the normal predict pipeline but assumes we're already in z-space.
382
+ """
383
+ raw_local = model.predict(z_batch_np, verbose=0)
384
+ return decode_logits(raw_local)[0].reshape(-1, len(CLASSES))
385
+
386
+ # Scalar function: probability of the *predicted* class only
387
+ def f_scalar(z_batch):
388
+ z_batch = np.array(z_batch, dtype=np.float32)
389
+ probs_batch = model_proba_from_z(z_batch) # (N, K)
390
+ return probs_batch[:, pred_idx] # (N,)
391
+
392
+ # Background: 50 "average" institutions at z=0
393
+ background_z = np.zeros((50, len(FEATURES)), dtype=np.float32)
394
+
395
+ # Create a per-call KernelExplainer for this scalar output
396
+ explainer = shap.KernelExplainer(f_scalar, background_z)
397
+
398
+ # SHAP for this *one* observation (in z-space)
399
+ shap_vals = explainer.shap_values(X_z, nsamples=50)
400
+ shap_arr = np.array(shap_vals)
401
+
402
+ # We expect shape (1, n_features) or (n_features,)
403
+ if shap_arr.ndim == 2 and shap_arr.shape[0] == 1:
404
+ shap_vec = shap_arr[0]
405
  else:
406
+ shap_vec = shap_arr.reshape(-1)
407
+
408
+ if shap_vec.size != len(FEATURES):
409
+ raise ValueError(
410
+ f"Unexpected SHAP vector length {shap_vec.size} "
411
+ f"(expected {len(FEATURES)})"
412
+ )
413
+
414
+ shap_feature_contribs = {
415
+ FEATURES[i]: float(shap_vec[i]) for i in range(len(FEATURES))
416
+ }
417
+
418
+ shap_payload = {
419
+ "available": True,
420
+ "class": pred_class,
421
+ "values": shap_feature_contribs,
422
+ }
 
 
 
 
 
 
 
 
423
 
424
  except Exception as e:
425
  shap_payload = {
 
428
  "trace": traceback.format_exc(),
429
  }
430
 
431
+ # ---------- 4) Final JSON response ----------
432
  return {
433
  "input_ok": (len(missing) == 0),
434
  "missing": missing,
 
437
  "scaler": bool(scaler),
438
  "z_mode": z_mode,
439
  },
440
+ "z_scores": z_detail, # per feature (model input)
441
+ "probabilities": probs_dict, # state → probability
442
+ "predicted_state": pred_class,
443
+ "shap": shap_payload, # explanation for predicted class only
444
  "debug": {
445
  "raw_shape": list(raw_logits.shape),
446
+ "decode_mode": decode_mode,
447
  "raw_first_row": [float(v) for v in raw_logits[0]],
448
  },
449
  }
 
451
  except Exception as e:
452
  return JSONResponse(
453
  status_code=500,
454
+ content={"error": str(e), "trace": traceback.format_exc()},
455
+ )
456
+