COCODEDE04 commited on
Commit
4b96a3d
·
verified ·
1 Parent(s): 9b0fc98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -31
app.py CHANGED
@@ -350,20 +350,22 @@ async def predict(req: Request):
350
  """
351
  Body: JSON object mapping feature -> numeric value (strings with commas/points ok).
352
  Missing features are imputed if imputer present; else filled with means (if stats) or 0.
353
-
354
- Now also returns SHAP values for the predicted_state (if SHAP is available).
355
  """
356
  try:
357
  payload = await req.json()
358
  if not isinstance(payload, dict):
359
  return JSONResponse(status_code=400, content={"error": "Expected JSON object"})
360
 
 
 
 
 
361
  # Build in EXACT training order
362
  raw = build_raw_vector(payload) # may contain NaNs
363
  raw_imp = apply_imputer_if_any(raw) # impute
364
  z_vec, z_detail, z_mode = apply_scaling_or_stats(raw_imp) # scale / z-score
365
 
366
- # Predict
367
  X = z_vec.reshape(1, -1).astype(np.float32)
368
  raw_logits = model.predict(X, verbose=0)
369
  probs, mode = decode_logits(raw_logits)
@@ -372,32 +374,7 @@ async def predict(req: Request):
372
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
373
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
374
 
375
- return {
376
- "input_ok": (len(missing) == 0),
377
- "missing": missing,
378
- "preprocess": {
379
- "imputer": bool(imputer),
380
- "scaler": bool(scaler),
381
- "z_mode": z_mode,
382
- },
383
- "z_scores": z_detail,
384
- "probabilities": probs_dict,
385
- "predicted_state": CLASSES[pred_idx],
386
- "shap": shap_out,
387
- "debug": {
388
- "raw_shape": list(raw_logits.shape),
389
- "decode_mode": mode,
390
- "raw_first_row": [float(v) for v in raw_logits[0]],
391
- },
392
- }
393
-
394
- pred_idx = int(np.argmax(probs))
395
- probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
396
- missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
397
-
398
- # ---- SHAP explanation for predicted class ----
399
- # -------- SHAP EXPLANATION (predicted class only) --------
400
- shap_out = None
401
  if EXPLAINER is not None:
402
  try:
403
  # X is already z-space: shape (1, n_features)
@@ -406,12 +383,12 @@ async def predict(req: Request):
406
  # Case 1: multi-output -> list of length K, each (1, n_features)
407
  if isinstance(shap_vals, list):
408
  shap_vec = np.array(shap_vals[pred_idx][0], dtype=float)
409
- # expected_value may also be a list per class
410
  exp_val_raw = EXPLAINER.expected_value
411
  if isinstance(exp_val_raw, (list, np.ndarray)):
412
  exp_val = float(exp_val_raw[pred_idx])
413
  else:
414
  exp_val = float(exp_val_raw)
 
415
  # Case 2: single-output -> ndarray (1, n_features)
416
  elif isinstance(shap_vals, np.ndarray):
417
  shap_vec = np.array(shap_vals[0], dtype=float)
@@ -420,6 +397,8 @@ async def predict(req: Request):
420
  exp_val = float(exp_val_raw[0])
421
  else:
422
  exp_val = float(exp_val_raw)
 
 
423
  else:
424
  raise TypeError(f"Unsupported SHAP return type: {type(shap_vals)}")
425
 
@@ -434,6 +413,7 @@ async def predict(req: Request):
434
  "expected_value": exp_val,
435
  "shap_values": shap_feature_contribs,
436
  }
 
437
  except Exception as e:
438
  shap_out = {
439
  "error": str(e),
@@ -442,5 +422,31 @@ async def predict(req: Request):
442
  else:
443
  shap_out = {"error": "SHAP not available on server"}
444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  except Exception as e:
446
- return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})
 
 
 
 
 
 
 
350
  """
351
  Body: JSON object mapping feature -> numeric value (strings with commas/points ok).
352
  Missing features are imputed if imputer present; else filled with means (if stats) or 0.
 
 
353
  """
354
  try:
355
  payload = await req.json()
356
  if not isinstance(payload, dict):
357
  return JSONResponse(status_code=400, content={"error": "Expected JSON object"})
358
 
359
+ # default SHAP block – will be overwritten if explanation succeeds
360
+ shap_out = {"error": "SHAP not computed"}
361
+
362
+ # ---------- PREPROCESSING ----------
363
  # Build in EXACT training order
364
  raw = build_raw_vector(payload) # may contain NaNs
365
  raw_imp = apply_imputer_if_any(raw) # impute
366
  z_vec, z_detail, z_mode = apply_scaling_or_stats(raw_imp) # scale / z-score
367
 
368
+ # ---------- PREDICTION ----------
369
  X = z_vec.reshape(1, -1).astype(np.float32)
370
  raw_logits = model.predict(X, verbose=0)
371
  probs, mode = decode_logits(raw_logits)
 
374
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
375
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
376
 
377
+ # ---------- SHAP EXPLANATION (predicted class only) ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  if EXPLAINER is not None:
379
  try:
380
  # X is already z-space: shape (1, n_features)
 
383
  # Case 1: multi-output -> list of length K, each (1, n_features)
384
  if isinstance(shap_vals, list):
385
  shap_vec = np.array(shap_vals[pred_idx][0], dtype=float)
 
386
  exp_val_raw = EXPLAINER.expected_value
387
  if isinstance(exp_val_raw, (list, np.ndarray)):
388
  exp_val = float(exp_val_raw[pred_idx])
389
  else:
390
  exp_val = float(exp_val_raw)
391
+
392
  # Case 2: single-output -> ndarray (1, n_features)
393
  elif isinstance(shap_vals, np.ndarray):
394
  shap_vec = np.array(shap_vals[0], dtype=float)
 
397
  exp_val = float(exp_val_raw[0])
398
  else:
399
  exp_val = float(exp_val_raw)
400
+
401
+ # Anything else – we consider wrong type
402
  else:
403
  raise TypeError(f"Unsupported SHAP return type: {type(shap_vals)}")
404
 
 
413
  "expected_value": exp_val,
414
  "shap_values": shap_feature_contribs,
415
  }
416
+
417
  except Exception as e:
418
  shap_out = {
419
  "error": str(e),
 
422
  else:
423
  shap_out = {"error": "SHAP not available on server"}
424
 
425
+ # ---------- RESPONSE ----------
426
+ return {
427
+ "input_ok": (len(missing) == 0),
428
+ "missing": missing,
429
+ "preprocess": {
430
+ "imputer": bool(imputer),
431
+ "scaler": bool(scaler),
432
+ "z_mode": z_mode,
433
+ },
434
+ "z_scores": z_detail, # per feature
435
+ "probabilities": probs_dict, # per class
436
+ "predicted_state": CLASSES[pred_idx],
437
+ "shap": shap_out, # SHAP for predicted state only
438
+ "debug": {
439
+ "raw_shape": list(raw_logits.shape),
440
+ "decode_mode": mode,
441
+ "raw_first_row": [float(v) for v in raw_logits[0]],
442
+ },
443
+ }
444
+
445
  except Exception as e:
446
+ return JSONResponse(
447
+ status_code=500,
448
+ content={
449
+ "error": str(e),
450
+ "trace": traceback.format_exc()
451
+ }
452
+ )