COCODEDE04 commited on
Commit
04ccf8e
·
verified ·
1 Parent(s): a850728

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -86
app.py CHANGED
@@ -350,107 +350,91 @@ async def predict(req: Request):
350
  """
351
  Body: JSON object mapping feature -> numeric value (strings with commas/points ok).
352
  Missing features are imputed if imputer present; else filled with means (if stats) or 0.
 
 
 
 
 
353
  """
354
  try:
355
  payload = await req.json()
356
  if not isinstance(payload, dict):
357
- return JSONResponse(status_code=400, content={"error": "Expected JSON object"})
 
 
 
358
 
359
- # ---------- PREPROCESSING ----------
360
- raw = build_raw_vector(payload) # may contain NaNs
361
- raw_imp = apply_imputer_if_any(raw) # impute
362
- z_vec, z_detail, z_mode = apply_scaling_or_stats(raw_imp) # scale / z-score
363
 
364
- # ---------- PREDICTION ----------
365
  X = z_vec.reshape(1, -1).astype(np.float32)
366
  raw_logits = model.predict(X, verbose=0)
367
- probs, mode = decode_logits(raw_logits)
368
 
369
  pred_idx = int(np.argmax(probs))
370
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
371
- missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
 
 
 
372
 
373
- # ---------- SHAP EXPLANATION (predicted class only) ----------
374
- shap_out = {"error": "SHAP not computed"}
375
- if EXPLAINER is not None:
376
  try:
377
- shap_vals = EXPLAINER.shap_values(X, nsamples=100)
 
 
 
 
378
 
379
- # 1) Pull raw SHAP tensor
380
  if isinstance(shap_vals, list):
381
- # Classic multi-output: list[len = n_classes], each (n_samples, n_features)
382
- raw_sv = np.array(shap_vals[pred_idx])
383
- else:
384
- # Single array, possibly (n_samples, n_features) or (n_samples, n_features, n_outputs)
385
- raw_sv = np.array(shap_vals)
386
-
387
- # 2) Normalize shapes to a 1D vector (n_features,) for the predicted class
388
- if raw_sv.ndim == 1:
389
- # Already (n_features,)
390
- shap_vec = raw_sv.astype(float)
391
-
392
- elif raw_sv.ndim == 2:
393
- # (n_samples, n_features) or (n_features, 1)
394
- if raw_sv.shape[0] == 1:
395
- # (1, n_features)
396
- shap_vec = raw_sv[0].astype(float)
397
- elif raw_sv.shape[1] == 1:
398
- # (n_features, 1)
399
- shap_vec = raw_sv[:, 0].astype(float)
400
- else:
401
- # assume (n_samples, n_features), take first sample
402
- shap_vec = raw_sv[0].astype(float)
403
-
404
- elif raw_sv.ndim == 3:
405
- # Most likely (n_samples, n_features, n_outputs)
406
- n_samples, n_features, n_outputs = raw_sv.shape
407
- if n_samples < 1:
408
- raise ValueError(f"SHAP 3D output has zero samples: {raw_sv.shape}")
409
- if pred_idx >= n_outputs:
410
- raise ValueError(
411
- f"SHAP 3D output has only {n_outputs} outputs, "
412
- f"cannot index class {pred_idx}"
413
- )
414
- # take first sample, all features, predicted class
415
- shap_vec = raw_sv[0, :, pred_idx].astype(float)
416
 
417
  else:
418
- # Fallback: flatten all sample dims, keep first feature-block
419
- flat = raw_sv.reshape(raw_sv.shape[0], -1)
420
- shap_vec = flat[0].astype(float)
421
-
422
- # 3) Sanity check length
423
- if shap_vec.shape[0] != len(FEATURES):
424
- raise ValueError(
425
- f"Unexpected SHAP vector length {shap_vec.shape[0]} "
426
- f"(expected {len(FEATURES)})"
427
- )
428
-
429
- # 4) Expected value (baseline) for the predicted class
430
- exp_raw = EXPLAINER.expected_value
431
- if isinstance(exp_raw, (list, np.ndarray)):
432
- exp_val = float(np.array(exp_raw)[pred_idx])
433
- else:
434
- exp_val = float(exp_raw)
435
-
436
- # 5) Map feature -> contribution
437
- shap_feature_contribs = {
438
- FEATURES[i]: float(shap_vec[i])
439
- for i in range(len(FEATURES))
440
- }
441
-
442
- shap_out = {
443
- "explained_class": CLASSES[pred_idx],
444
- "expected_value": exp_val,
445
- "shap_values": shap_feature_contribs,
446
- }
447
 
448
  except Exception as e:
449
- shap_out = {"error": str(e), "trace": traceback.format_exc()}
450
- else:
451
- shap_out = {"error": "SHAP not available on server"}
 
 
452
 
453
- # ---------- RESPONSE ----------
454
  return {
455
  "input_ok": (len(missing) == 0),
456
  "missing": missing,
@@ -459,13 +443,13 @@ async def predict(req: Request):
459
  "scaler": bool(scaler),
460
  "z_mode": z_mode,
461
  },
462
- "z_scores": z_detail, # per feature (z-space)
463
- "probabilities": probs_dict, # per class
464
  "predicted_state": CLASSES[pred_idx],
465
- "shap": shap_out, # SHAP for predicted state only
466
  "debug": {
467
  "raw_shape": list(raw_logits.shape),
468
- "decode_mode": mode,
469
  "raw_first_row": [float(v) for v in raw_logits[0]],
470
  },
471
  }
@@ -473,5 +457,5 @@ async def predict(req: Request):
473
  except Exception as e:
474
  return JSONResponse(
475
  status_code=500,
476
- content={"error": str(e), "trace": traceback.format_exc()}
477
  )
 
350
  """
351
  Body: JSON object mapping feature -> numeric value (strings with commas/points ok).
352
  Missing features are imputed if imputer present; else filled with means (if stats) or 0.
353
+ Returns:
354
+ - probabilities per state
355
+ - predicted_state
356
+ - z_scores (per feature, after imputation & scaling pipeline)
357
+ - shap: per-class explanations if available
358
  """
359
  try:
360
  payload = await req.json()
361
  if not isinstance(payload, dict):
362
+ return JSONResponse(
363
+ status_code=400,
364
+ content={"error": "Expected JSON object"},
365
+ )
366
 
367
+ # ---------- 1) Preprocess: raw -> imputed -> z ----------
368
+ raw_vec = build_raw_vector(payload) # (21,) may contain NaNs
369
+ raw_imp = apply_imputer_if_any(raw_vec) # impute missing
370
+ z_vec, z_detail, z_mode = apply_scaling_or_stats(raw_imp)
371
 
372
+ # ---------- 2) Model prediction ----------
373
  X = z_vec.reshape(1, -1).astype(np.float32)
374
  raw_logits = model.predict(X, verbose=0)
375
+ probs, decode_mode = decode_logits(raw_logits)
376
 
377
  pred_idx = int(np.argmax(probs))
378
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
379
+ missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw_vec[i])]
380
+
381
+ # ---------- 3) SHAP explanations (all classes) ----------
382
+ shap_block: Dict[str, Any] = {"available": False}
383
 
384
+ if EXPLAINER is not None and SHAP_AVAILABLE:
 
 
385
  try:
386
+ X_z = z_vec.reshape(1, -1).astype(np.float32)
387
+ # KernelExplainer: usually returns list of length K (one (1,D) array per class)
388
+ shap_vals = EXPLAINER.shap_values(X_z, nsamples=50)
389
+
390
+ all_classes: Dict[str, Dict[str, float]] = {}
391
 
 
392
  if isinstance(shap_vals, list):
393
+ # per-class outputs
394
+ for k, class_name in enumerate(CLASSES):
395
+ if k >= len(shap_vals):
396
+ continue
397
+ vec = np.array(shap_vals[k][0], dtype=float) # shape (D,)
398
+ if vec.shape[0] != len(FEATURES):
399
+ # shape mismatch: bail out gracefully
400
+ continue
401
+ all_classes[class_name] = {
402
+ FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
403
+ }
404
+
405
+ shap_block = {
406
+ "available": True,
407
+ "mode": "per_class",
408
+ "explained_classes": list(all_classes.keys()),
409
+ "all_classes": all_classes,
410
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
 
412
  else:
413
+ # single array (1, D) treat as "predicted class only" fallback
414
+ vec = np.array(shap_vals[0], dtype=float)
415
+ if vec.shape[0] == len(FEATURES):
416
+ shap_block = {
417
+ "available": True,
418
+ "mode": "single_class",
419
+ "explained_class": CLASSES[pred_idx],
420
+ "values": {
421
+ FEATURES[i]: float(vec[i]) for i in range(len(FEATURES))
422
+ },
423
+ }
424
+ else:
425
+ shap_block = {
426
+ "available": False,
427
+ "error": f"Unexpected SHAP vector length {vec.shape[0]} (expected {len(FEATURES)})",
428
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
  except Exception as e:
431
+ shap_block = {
432
+ "available": False,
433
+ "error": str(e),
434
+ "trace": traceback.format_exc(),
435
+ }
436
 
437
+ # ---------- 4) Build response ----------
438
  return {
439
  "input_ok": (len(missing) == 0),
440
  "missing": missing,
 
443
  "scaler": bool(scaler),
444
  "z_mode": z_mode,
445
  },
446
+ "z_scores": z_detail, # per feature
447
+ "probabilities": probs_dict, # per state
448
  "predicted_state": CLASSES[pred_idx],
449
+ "shap": shap_block,
450
  "debug": {
451
  "raw_shape": list(raw_logits.shape),
452
+ "decode_mode": decode_mode,
453
  "raw_first_row": [float(v) for v in raw_logits[0]],
454
  },
455
  }
 
457
  except Exception as e:
458
  return JSONResponse(
459
  status_code=500,
460
+ content={"error": str(e), "trace": traceback.format_exc()},
461
  )