COCODEDE04 committed on
Commit
a6c0646
·
verified ·
1 Parent(s): 45857b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -94
app.py CHANGED
@@ -1,5 +1,5 @@
1
- import os, json, io, traceback
2
- from typing import Any, Dict, List, Optional
3
 
4
  import numpy as np
5
  import tensorflow as tf
@@ -7,7 +7,7 @@ from fastapi import FastAPI, Request
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import JSONResponse
9
 
10
- # ---------- SHAP optional import ----------
11
  try:
12
  import shap
13
  SHAP_AVAILABLE = True
@@ -94,8 +94,7 @@ def load_joblib_if_exists(candidates: List[str]):
94
  p = os.path.join(os.getcwd(), name)
95
  if os.path.isfile(p):
96
  try:
97
- # Import inside to avoid hard dependency if not used
98
- import joblib # type: ignore
99
  with open(p, "rb") as fh:
100
  obj = joblib.load(fh)
101
  return obj, p, None
@@ -147,8 +146,6 @@ def coral_probs_from_logits(logits_np: np.ndarray) -> np.ndarray:
147
  left = tf.concat([tf.ones_like(sig[:, :1]), sig], axis=1)
148
  right = tf.concat([sig, tf.zeros_like(sig[:, :1])], axis=1)
149
  probs = tf.clip_by_value(left - right, 1e-12, 1.0)
150
- # normalize row-wise just in case
151
- probs = probs / tf.reduce_sum(probs, axis=1, keepdims=True)
152
  return probs.numpy()
153
 
154
 
@@ -164,17 +161,14 @@ def decode_logits(raw: np.ndarray) -> (np.ndarray, str):
164
  K = len(CLASSES)
165
 
166
  if M == K - 1:
167
- # CORAL logits
168
  probs = coral_probs_from_logits(raw)[0]
169
  return probs, "auto_coral"
170
  elif M == K:
171
- # Softmax or unnormalized scores
172
  row = raw[0]
173
  exps = np.exp(row - np.max(row))
174
  probs = exps / np.sum(exps)
175
  return probs, "auto_softmax"
176
  else:
177
- # Fallback: normalize across whatever is there
178
  row = raw[0]
179
  s = float(np.sum(np.abs(row)))
180
  probs = (row / s) if s > 0 else np.ones_like(row) / len(row)
@@ -202,7 +196,6 @@ def build_raw_vector(payload: Dict[str, Any]) -> np.ndarray:
202
 
203
  def apply_imputer_if_any(x: np.ndarray) -> np.ndarray:
204
  if imputer is not None:
205
- # imputer expects 2D
206
  return imputer.transform(x.reshape(1, -1)).astype(np.float32)[0]
207
  # fallback: replace NaNs with feature means from stats if available, else 0
208
  out = x.copy()
@@ -238,7 +231,7 @@ def apply_scaling_or_stats(raw_vec: np.ndarray) -> (np.ndarray, Dict[str, float]
238
  return z, z_detail, "manual_stats"
239
 
240
 
241
- # --------- SHAP model wrapper & explainer ---------
242
  def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
243
  """
244
  Wrapper for SHAP: takes (N, n_features) in z-space and returns (N, K) probabilities.
@@ -250,14 +243,11 @@ def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
250
  K = len(CLASSES)
251
 
252
  if M == K - 1:
253
- # CORAL
254
  probs = coral_probs_from_logits(raw) # (N, K)
255
  elif M == K:
256
- # Softmax or scores
257
  exps = np.exp(raw - np.max(raw, axis=1, keepdims=True))
258
  probs = exps / np.sum(exps, axis=1, keepdims=True)
259
  else:
260
- # Fallback normalize
261
  s = np.sum(np.abs(raw), axis=1, keepdims=True)
262
  probs = np.divide(raw, s, out=np.ones_like(raw) / max(M, 1), where=(s > 0))
263
  return probs
@@ -266,7 +256,6 @@ def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
266
  EXPLAINER = None
267
  if SHAP_AVAILABLE:
268
  try:
269
- # Background: 50 "average" institutions at z=0
270
  BACKGROUND_Z = np.zeros((50, len(FEATURES)), dtype=np.float32)
271
  EXPLAINER = shap.KernelExplainer(model_proba_from_z, BACKGROUND_Z)
272
  print("SHAP KernelExplainer initialized.")
@@ -314,7 +303,7 @@ def health():
314
  "imputer": bool(imputer),
315
  "scaler": bool(scaler),
316
  "stats_available": bool(stats),
317
- "shap_available": bool(EXPLAINER is not None),
318
  }
319
 
320
 
@@ -356,12 +345,12 @@ async def predict(req: Request):
356
  if not isinstance(payload, dict):
357
  return JSONResponse(status_code=400, content={"error": "Expected JSON object"})
358
 
359
- # ---------- PREPROCESSING ----------
360
  raw = build_raw_vector(payload) # may contain NaNs
361
  raw_imp = apply_imputer_if_any(raw) # impute
362
  z_vec, z_detail, z_mode = apply_scaling_or_stats(raw_imp) # scale / z-score
363
 
364
- # ---------- PREDICTION ----------
365
  X = z_vec.reshape(1, -1).astype(np.float32)
366
  raw_logits = model.predict(X, verbose=0)
367
  probs, mode = decode_logits(raw_logits)
@@ -370,87 +359,62 @@ async def predict(req: Request):
370
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
371
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
372
 
373
- # ---------- SHAP EXPLANATION (predicted class only) ----------
374
- shap_out = {"error": "SHAP not computed"}
375
  if EXPLAINER is not None:
376
  try:
377
- shap_vals = EXPLAINER.shap_values(X, nsamples=100)
378
-
379
- # 1) Pull raw SHAP tensor
380
- if isinstance(shap_vals, list):
381
- # Classic multi-output: list[len = n_classes], each (n_samples, n_features)
382
- raw_sv = np.array(shap_vals[pred_idx])
 
 
 
 
 
 
 
 
 
 
 
383
  else:
384
- # Single array, possibly (n_samples, n_features) or (n_samples, n_features, n_outputs)
385
- raw_sv = np.array(shap_vals)
386
-
387
- # 2) Normalize shapes to a 1D vector (n_features,) for the predicted class
388
- if raw_sv.ndim == 1:
389
- # Already (n_features,)
390
- shap_vec = raw_sv.astype(float)
391
-
392
- elif raw_sv.ndim == 2:
393
- # (n_samples, n_features) or (n_features, 1)
394
- if raw_sv.shape[0] == 1:
395
- # (1, n_features)
396
- shap_vec = raw_sv[0].astype(float)
397
- elif raw_sv.shape[1] == 1:
398
- # (n_features, 1)
399
- shap_vec = raw_sv[:, 0].astype(float)
400
  else:
401
- # assume (n_samples, n_features), take first sample
402
- shap_vec = raw_sv[0].astype(float)
403
-
404
- elif raw_sv.ndim == 3:
405
- # Most likely (n_samples, n_features, n_outputs)
406
- n_samples, n_features, n_outputs = raw_sv.shape
407
- if n_samples < 1:
408
- raise ValueError(f"SHAP 3D output has zero samples: {raw_sv.shape}")
409
- if pred_idx >= n_outputs:
410
- raise ValueError(
411
- f"SHAP 3D output has only {n_outputs} outputs, "
412
- f"cannot index class {pred_idx}"
413
- )
414
- # take first sample, all features, predicted class
415
- shap_vec = raw_sv[0, :, pred_idx].astype(float)
416
-
417
- else:
418
- # Fallback: flatten all sample dims, keep first feature-block
419
- flat = raw_sv.reshape(raw_sv.shape[0], -1)
420
- shap_vec = flat[0].astype(float)
421
-
422
- # 3) Sanity check length
423
- if shap_vec.shape[0] != len(FEATURES):
424
- raise ValueError(
425
- f"Unexpected SHAP vector length {shap_vec.shape[0]} "
426
- f"(expected {len(FEATURES)})"
427
- )
428
-
429
- # 4) Expected value (baseline) for the predicted class
430
- exp_raw = EXPLAINER.expected_value
431
- if isinstance(exp_raw, (list, np.ndarray)):
432
- exp_val = float(np.array(exp_raw)[pred_idx])
433
- else:
434
- exp_val = float(exp_raw)
435
 
436
- # 5) Map feature -> contribution
437
- shap_feature_contribs = {
438
- FEATURES[i]: float(shap_vec[i])
439
- for i in range(len(FEATURES))
440
- }
 
 
441
 
442
- shap_out = {
443
- "explained_class": CLASSES[pred_idx],
444
- "expected_value": exp_val,
445
- "shap_values": shap_feature_contribs,
446
- }
447
 
448
  except Exception as e:
449
- shap_out = {"error": str(e), "trace": traceback.format_exc()}
450
- else:
451
- shap_out = {"error": "SHAP not available on server"}
 
 
452
 
453
- # ---------- RESPONSE ----------
454
  return {
455
  "input_ok": (len(missing) == 0),
456
  "missing": missing,
@@ -459,10 +423,10 @@ async def predict(req: Request):
459
  "scaler": bool(scaler),
460
  "z_mode": z_mode,
461
  },
462
- "z_scores": z_detail, # per feature (z-space)
463
- "probabilities": probs_dict, # per class
464
  "predicted_state": CLASSES[pred_idx],
465
- "shap": shap_out, # SHAP for predicted state only
466
  "debug": {
467
  "raw_shape": list(raw_logits.shape),
468
  "decode_mode": mode,
 
1
+ import os, json, traceback
2
+ from typing import Any, Dict, List
3
 
4
  import numpy as np
5
  import tensorflow as tf
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import JSONResponse
9
 
10
+ # Try SHAP
11
  try:
12
  import shap
13
  SHAP_AVAILABLE = True
 
94
  p = os.path.join(os.getcwd(), name)
95
  if os.path.isfile(p):
96
  try:
97
+ import joblib # lazy import
 
98
  with open(p, "rb") as fh:
99
  obj = joblib.load(fh)
100
  return obj, p, None
 
146
  left = tf.concat([tf.ones_like(sig[:, :1]), sig], axis=1)
147
  right = tf.concat([sig, tf.zeros_like(sig[:, :1])], axis=1)
148
  probs = tf.clip_by_value(left - right, 1e-12, 1.0)
 
 
149
  return probs.numpy()
150
 
151
 
 
161
  K = len(CLASSES)
162
 
163
  if M == K - 1:
 
164
  probs = coral_probs_from_logits(raw)[0]
165
  return probs, "auto_coral"
166
  elif M == K:
 
167
  row = raw[0]
168
  exps = np.exp(row - np.max(row))
169
  probs = exps / np.sum(exps)
170
  return probs, "auto_softmax"
171
  else:
 
172
  row = raw[0]
173
  s = float(np.sum(np.abs(row)))
174
  probs = (row / s) if s > 0 else np.ones_like(row) / len(row)
 
196
 
197
  def apply_imputer_if_any(x: np.ndarray) -> np.ndarray:
198
  if imputer is not None:
 
199
  return imputer.transform(x.reshape(1, -1)).astype(np.float32)[0]
200
  # fallback: replace NaNs with feature means from stats if available, else 0
201
  out = x.copy()
 
231
  return z, z_detail, "manual_stats"
232
 
233
 
234
+ # --------- SHAP: model wrapper & explainer ---------
235
  def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
236
  """
237
  Wrapper for SHAP: takes (N, n_features) in z-space and returns (N, K) probabilities.
 
243
  K = len(CLASSES)
244
 
245
  if M == K - 1:
 
246
  probs = coral_probs_from_logits(raw) # (N, K)
247
  elif M == K:
 
248
  exps = np.exp(raw - np.max(raw, axis=1, keepdims=True))
249
  probs = exps / np.sum(exps, axis=1, keepdims=True)
250
  else:
 
251
  s = np.sum(np.abs(raw), axis=1, keepdims=True)
252
  probs = np.divide(raw, s, out=np.ones_like(raw) / max(M, 1), where=(s > 0))
253
  return probs
 
256
  EXPLAINER = None
257
  if SHAP_AVAILABLE:
258
  try:
 
259
  BACKGROUND_Z = np.zeros((50, len(FEATURES)), dtype=np.float32)
260
  EXPLAINER = shap.KernelExplainer(model_proba_from_z, BACKGROUND_Z)
261
  print("SHAP KernelExplainer initialized.")
 
303
  "imputer": bool(imputer),
304
  "scaler": bool(scaler),
305
  "stats_available": bool(stats),
306
+ "shap_available": bool(EXPLAINER),
307
  }
308
 
309
 
 
345
  if not isinstance(payload, dict):
346
  return JSONResponse(status_code=400, content={"error": "Expected JSON object"})
347
 
348
+ # Build in EXACT training order
349
  raw = build_raw_vector(payload) # may contain NaNs
350
  raw_imp = apply_imputer_if_any(raw) # impute
351
  z_vec, z_detail, z_mode = apply_scaling_or_stats(raw_imp) # scale / z-score
352
 
353
+ # Predict
354
  X = z_vec.reshape(1, -1).astype(np.float32)
355
  raw_logits = model.predict(X, verbose=0)
356
  probs, mode = decode_logits(raw_logits)
 
359
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
360
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
361
 
362
+ # ---------- SHAP for ALL classes ----------
363
+ shap_payload: Dict[str, Any] = {"available": bool(EXPLAINER)}
364
  if EXPLAINER is not None:
365
  try:
366
+ shap_raw = EXPLAINER.shap_values(X, nsamples=100)
367
+ shap_all_classes: Dict[str, Dict[str, float]] = {}
368
+
369
+ if isinstance(shap_raw, list):
370
+ # standard KernelExplainer multi-output: list of length K, each (1, n_features)
371
+ for c_idx, cls_name in enumerate(CLASSES):
372
+ if c_idx >= len(shap_raw):
373
+ break
374
+ arr = np.array(shap_raw[c_idx])
375
+ if arr.ndim == 2:
376
+ vec = arr[0]
377
+ else:
378
+ vec = arr.reshape(-1)
379
+ m = min(len(FEATURES), len(vec))
380
+ shap_all_classes[cls_name] = {
381
+ FEATURES[i]: float(vec[i]) for i in range(m)
382
+ }
383
  else:
384
+ # Fallback: single ndarray, try to interpret first dim as classes
385
+ arr = np.array(shap_raw)
386
+ if arr.ndim == 3:
387
+ # e.g. (K, 1, n_features) or (1, K, n_features)
388
+ if arr.shape[1] == 1:
389
+ arr2 = arr[:, 0, :]
390
+ elif arr.shape[0] == 1:
391
+ arr2 = arr[0, :, :]
392
+ else:
393
+ arr2 = arr.reshape(arr.shape[0], -1)
394
+ elif arr.ndim == 2:
395
+ # (K, n_features)
396
+ arr2 = arr
 
 
 
397
  else:
398
+ raise ValueError(f"Unsupported SHAP array shape: {arr.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
 
400
+ K_eff = min(arr2.shape[0], len(CLASSES))
401
+ for c_idx in range(K_eff):
402
+ vec = arr2[c_idx]
403
+ m = min(len(FEATURES), len(vec))
404
+ shap_all_classes[CLASSES[c_idx]] = {
405
+ FEATURES[i]: float(vec[i]) for i in range(m)
406
+ }
407
 
408
+ shap_payload["all_classes"] = shap_all_classes
 
 
 
 
409
 
410
  except Exception as e:
411
+ shap_payload = {
412
+ "available": False,
413
+ "error": str(e),
414
+ "trace": traceback.format_exc(),
415
+ }
416
 
417
+ # ---------- final response ----------
418
  return {
419
  "input_ok": (len(missing) == 0),
420
  "missing": missing,
 
423
  "scaler": bool(scaler),
424
  "z_mode": z_mode,
425
  },
426
+ "z_scores": z_detail, # per feature
427
+ "probabilities": probs_dict,
428
  "predicted_state": CLASSES[pred_idx],
429
+ "shap": shap_payload, # FULL per-class SHAP matrix
430
  "debug": {
431
  "raw_shape": list(raw_logits.shape),
432
  "decode_mode": mode,