COCODEDE04 committed on
Commit
c7c0f5c
·
verified ·
1 Parent(s): 648b4ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -5
app.py CHANGED
@@ -7,7 +7,7 @@ from fastapi import FastAPI, Request
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import JSONResponse
9
 
10
-
11
  try:
12
  import shap
13
  SHAP_AVAILABLE = True
@@ -129,7 +129,7 @@ else:
129
  print("⚠️ No scaler found — using manual z-scoring if stats are available.")
130
 
131
  # Stats (means/std) for fallback manual z-score
132
- stats = {}
133
  if os.path.isfile(STATS_PATH):
134
  stats = load_json(STATS_PATH)
135
  print(f"Loaded means/std from {STATS_PATH}")
@@ -147,6 +147,8 @@ def coral_probs_from_logits(logits_np: np.ndarray) -> np.ndarray:
147
  left = tf.concat([tf.ones_like(sig[:, :1]), sig], axis=1)
148
  right = tf.concat([sig, tf.zeros_like(sig[:, :1])], axis=1)
149
  probs = tf.clip_by_value(left - right, 1e-12, 1.0)
 
 
150
  return probs.numpy()
151
 
152
 
@@ -236,8 +238,47 @@ def apply_scaling_or_stats(raw_vec: np.ndarray) -> (np.ndarray, Dict[str, float]
236
  return z, z_detail, "manual_stats"
237
 
238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  # ----------------- FastAPI -----------------
240
- app = FastAPI(title="Static Fingerprint API", version="1.1.0")
241
  app.add_middleware(
242
  CORSMiddleware,
243
  allow_origins=["*"],
@@ -246,6 +287,7 @@ app.add_middleware(
246
  allow_headers=["*"],
247
  )
248
 
 
249
  @app.get("/")
250
  def root():
251
  return {
@@ -253,6 +295,7 @@ def root():
253
  "try": ["GET /health", "POST /predict", "POST /debug/z"],
254
  }
255
 
 
256
  @app.get("/health")
257
  def health():
258
  stats_keys = []
@@ -271,8 +314,10 @@ def health():
271
  "imputer": bool(imputer),
272
  "scaler": bool(scaler),
273
  "stats_available": bool(stats),
 
274
  }
275
 
 
276
  @app.post("/debug/z")
277
  async def debug_z(req: Request):
278
  try:
@@ -299,11 +344,14 @@ async def debug_z(req: Request):
299
  except Exception as e:
300
  return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})
301
 
 
302
  @app.post("/predict")
303
  async def predict(req: Request):
304
  """
305
  Body: JSON object mapping feature -> numeric value (strings with commas/points ok).
306
  Missing features are imputed if imputer present; else filled with means (if stats) or 0.
 
 
307
  """
308
  try:
309
  payload = await req.json()
@@ -320,12 +368,11 @@ async def predict(req: Request):
320
  raw_logits = model.predict(X, verbose=0)
321
  probs, mode = decode_logits(raw_logits)
322
 
323
- # Package response
324
  pred_idx = int(np.argmax(probs))
325
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
326
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
327
 
328
- return {
329
  "input_ok": (len(missing) == 0),
330
  "missing": missing,
331
  "preprocess": {
@@ -342,5 +389,26 @@ async def predict(req: Request):
342
  "raw_first_row": [float(v) for v in raw_logits[0]],
343
  },
344
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  except Exception as e:
346
  return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import JSONResponse
9
 
10
+ # ---------- SHAP optional import ----------
11
  try:
12
  import shap
13
  SHAP_AVAILABLE = True
 
129
  print("⚠️ No scaler found — using manual z-scoring if stats are available.")
130
 
131
  # Stats (means/std) for fallback manual z-score
132
+ stats: Dict[str, Dict[str, float]] = {}
133
  if os.path.isfile(STATS_PATH):
134
  stats = load_json(STATS_PATH)
135
  print(f"Loaded means/std from {STATS_PATH}")
 
147
  left = tf.concat([tf.ones_like(sig[:, :1]), sig], axis=1)
148
  right = tf.concat([sig, tf.zeros_like(sig[:, :1])], axis=1)
149
  probs = tf.clip_by_value(left - right, 1e-12, 1.0)
150
+ # normalize row-wise just in case
151
+ probs = probs / tf.reduce_sum(probs, axis=1, keepdims=True)
152
  return probs.numpy()
153
 
154
 
 
238
  return z, z_detail, "manual_stats"
239
 
240
 
241
+ # --------- SHAP model wrapper & explainer ---------
242
def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
    """
    Prediction wrapper handed to SHAP.

    Runs the model on a batch of already-standardised (z-space) rows and
    converts the raw output to per-class scores, choosing the decoding by
    comparing the output width M with K = len(CLASSES):

    * M == K - 1: decoded with coral_probs_from_logits.
    * M == K:     row-wise softmax (max-shifted for numerical stability).
    * otherwise:  each row divided by the sum of its absolute values; rows
                  whose absolute sum is zero become uniform 1/M. The result
                  has M columns in this case, not K.

    Raises:
        ValueError: if the model output is not 2-D.
    """
    logits = model.predict(z_batch_np, verbose=0)
    if logits.ndim != 2:
        raise ValueError(f"Unexpected raw shape from model: {logits.shape}")

    width = logits.shape[1]
    n_classes = len(CLASSES)

    # Ordinal (CORAL) head: K-1 threshold logits.
    if width == n_classes - 1:
        return coral_probs_from_logits(logits)

    # One score per class: plain softmax.
    if width == n_classes:
        shifted = logits - np.max(logits, axis=1, keepdims=True)
        weights = np.exp(shifted)
        return weights / np.sum(weights, axis=1, keepdims=True)

    # Unknown head layout: L1-normalise, pre-filling with the uniform value so
    # that rows skipped by `where` (zero absolute sum) are still well-defined.
    row_l1 = np.sum(np.abs(logits), axis=1, keepdims=True)
    uniform = np.ones_like(logits) / max(width, 1)
    return np.divide(logits, row_l1, out=uniform, where=(row_l1 > 0))
264
+
265
+
266
# The SHAP explainer is built once at import time. EXPLAINER stays None when
# SHAP is not installed or when construction fails; request handlers are
# expected to check for None before using it.
EXPLAINER = None
if SHAP_AVAILABLE:
    try:
        # Background: the "average" institution, i.e. z = 0 for every feature.
        # One row is enough: KernelExplainer averages the model output over
        # the background rows, so identical rows give the same expectation
        # while each extra row costs one more model evaluation per coalition.
        BACKGROUND_Z = np.zeros((1, len(FEATURES)), dtype=np.float32)
        EXPLAINER = shap.KernelExplainer(model_proba_from_z, BACKGROUND_Z)
        print("SHAP KernelExplainer initialized.")
    except Exception as e:
        # Best effort: the API must still start without explanations.
        EXPLAINER = None
        print("⚠️ Failed to initialize SHAP explainer:", repr(e))
else:
    print("SHAP not installed; explanations disabled.")
278
+
279
+
280
  # ----------------- FastAPI -----------------
281
+ app = FastAPI(title="Static Fingerprint API", version="1.2.0")
282
  app.add_middleware(
283
  CORSMiddleware,
284
  allow_origins=["*"],
 
287
  allow_headers=["*"],
288
  )
289
 
290
+
291
  @app.get("/")
292
  def root():
293
  return {
 
295
  "try": ["GET /health", "POST /predict", "POST /debug/z"],
296
  }
297
 
298
+
299
  @app.get("/health")
300
  def health():
301
  stats_keys = []
 
314
  "imputer": bool(imputer),
315
  "scaler": bool(scaler),
316
  "stats_available": bool(stats),
317
+ "shap_available": bool(EXPLAINER is not None),
318
  }
319
 
320
+
321
  @app.post("/debug/z")
322
  async def debug_z(req: Request):
323
  try:
 
344
  except Exception as e:
345
  return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})
346
 
347
+
348
  @app.post("/predict")
349
  async def predict(req: Request):
350
  """
351
  Body: JSON object mapping feature -> numeric value (strings with commas/points ok).
352
  Missing features are imputed if imputer present; else filled with means (if stats) or 0.
353
+
354
+ Now also returns SHAP values for the predicted_state (if SHAP is available).
355
  """
356
  try:
357
  payload = await req.json()
 
368
  raw_logits = model.predict(X, verbose=0)
369
  probs, mode = decode_logits(raw_logits)
370
 
 
371
  pred_idx = int(np.argmax(probs))
372
  probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
373
  missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
374
 
375
+ resp: Dict[str, Any] = {
376
  "input_ok": (len(missing) == 0),
377
  "missing": missing,
378
  "preprocess": {
 
389
  "raw_first_row": [float(v) for v in raw_logits[0]],
390
  },
391
  }
392
+
393
+ # ---- SHAP explanation for predicted class ----
394
+ if EXPLAINER is not None:
395
+ try:
396
+ shap_vals_list = EXPLAINER.shap_values(X, nsamples="auto")
397
+ # shap_vals_list is a list of length K (classes)
398
+ if isinstance(shap_vals_list, list) and len(shap_vals_list) == len(CLASSES):
399
+ shap_for_pred = shap_vals_list[pred_idx][0] # (n_features,)
400
+ resp["shap_target"] = CLASSES[pred_idx]
401
+ resp["shap_values"] = {
402
+ FEATURES[i]: float(shap_for_pred[i]) for i in range(len(FEATURES))
403
+ }
404
+ else:
405
+ resp["shap_error"] = "Unexpected SHAP output shape."
406
+ except Exception as e:
407
+ resp["shap_error"] = f"SHAP computation failed: {repr(e)}"
408
+ else:
409
+ resp["shap_error"] = "SHAP not available in this environment."
410
+
411
+ return resp
412
+
413
  except Exception as e:
414
  return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})