Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ from fastapi import FastAPI, Request
|
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
from fastapi.responses import JSONResponse
|
| 9 |
|
| 10 |
-
|
| 11 |
try:
|
| 12 |
import shap
|
| 13 |
SHAP_AVAILABLE = True
|
|
@@ -129,7 +129,7 @@ else:
|
|
| 129 |
print("⚠️ No scaler found — using manual z-scoring if stats are available.")
|
| 130 |
|
| 131 |
# Stats (means/std) for fallback manual z-score
|
| 132 |
-
stats = {}
|
| 133 |
if os.path.isfile(STATS_PATH):
|
| 134 |
stats = load_json(STATS_PATH)
|
| 135 |
print(f"Loaded means/std from {STATS_PATH}")
|
|
@@ -147,6 +147,8 @@ def coral_probs_from_logits(logits_np: np.ndarray) -> np.ndarray:
|
|
| 147 |
left = tf.concat([tf.ones_like(sig[:, :1]), sig], axis=1)
|
| 148 |
right = tf.concat([sig, tf.zeros_like(sig[:, :1])], axis=1)
|
| 149 |
probs = tf.clip_by_value(left - right, 1e-12, 1.0)
|
|
|
|
|
|
|
| 150 |
return probs.numpy()
|
| 151 |
|
| 152 |
|
|
@@ -236,8 +238,47 @@ def apply_scaling_or_stats(raw_vec: np.ndarray) -> (np.ndarray, Dict[str, float]
|
|
| 236 |
return z, z_detail, "manual_stats"
|
| 237 |
|
| 238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
# ----------------- FastAPI -----------------
|
| 240 |
-
app = FastAPI(title="Static Fingerprint API", version="1.
|
| 241 |
app.add_middleware(
|
| 242 |
CORSMiddleware,
|
| 243 |
allow_origins=["*"],
|
|
@@ -246,6 +287,7 @@ app.add_middleware(
|
|
| 246 |
allow_headers=["*"],
|
| 247 |
)
|
| 248 |
|
|
|
|
| 249 |
@app.get("/")
|
| 250 |
def root():
|
| 251 |
return {
|
|
@@ -253,6 +295,7 @@ def root():
|
|
| 253 |
"try": ["GET /health", "POST /predict", "POST /debug/z"],
|
| 254 |
}
|
| 255 |
|
|
|
|
| 256 |
@app.get("/health")
|
| 257 |
def health():
|
| 258 |
stats_keys = []
|
|
@@ -271,8 +314,10 @@ def health():
|
|
| 271 |
"imputer": bool(imputer),
|
| 272 |
"scaler": bool(scaler),
|
| 273 |
"stats_available": bool(stats),
|
|
|
|
| 274 |
}
|
| 275 |
|
|
|
|
| 276 |
@app.post("/debug/z")
|
| 277 |
async def debug_z(req: Request):
|
| 278 |
try:
|
|
@@ -299,11 +344,14 @@ async def debug_z(req: Request):
|
|
| 299 |
except Exception as e:
|
| 300 |
return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})
|
| 301 |
|
|
|
|
| 302 |
@app.post("/predict")
|
| 303 |
async def predict(req: Request):
|
| 304 |
"""
|
| 305 |
Body: JSON object mapping feature -> numeric value (strings with commas/points ok).
|
| 306 |
Missing features are imputed if imputer present; else filled with means (if stats) or 0.
|
|
|
|
|
|
|
| 307 |
"""
|
| 308 |
try:
|
| 309 |
payload = await req.json()
|
|
@@ -320,12 +368,11 @@ async def predict(req: Request):
|
|
| 320 |
raw_logits = model.predict(X, verbose=0)
|
| 321 |
probs, mode = decode_logits(raw_logits)
|
| 322 |
|
| 323 |
-
# Package response
|
| 324 |
pred_idx = int(np.argmax(probs))
|
| 325 |
probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
|
| 326 |
missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
|
| 327 |
|
| 328 |
-
|
| 329 |
"input_ok": (len(missing) == 0),
|
| 330 |
"missing": missing,
|
| 331 |
"preprocess": {
|
|
@@ -342,5 +389,26 @@ async def predict(req: Request):
|
|
| 342 |
"raw_first_row": [float(v) for v in raw_logits[0]],
|
| 343 |
},
|
| 344 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
except Exception as e:
|
| 346 |
return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})
|
|
|
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
from fastapi.responses import JSONResponse
|
| 9 |
|
| 10 |
+
# ---------- SHAP optional import ----------
|
| 11 |
try:
|
| 12 |
import shap
|
| 13 |
SHAP_AVAILABLE = True
|
|
|
|
| 129 |
print("⚠️ No scaler found — using manual z-scoring if stats are available.")
|
| 130 |
|
| 131 |
# Stats (means/std) for fallback manual z-score
|
| 132 |
+
stats: Dict[str, Dict[str, float]] = {}
|
| 133 |
if os.path.isfile(STATS_PATH):
|
| 134 |
stats = load_json(STATS_PATH)
|
| 135 |
print(f"Loaded means/std from {STATS_PATH}")
|
|
|
|
| 147 |
left = tf.concat([tf.ones_like(sig[:, :1]), sig], axis=1)
|
| 148 |
right = tf.concat([sig, tf.zeros_like(sig[:, :1])], axis=1)
|
| 149 |
probs = tf.clip_by_value(left - right, 1e-12, 1.0)
|
| 150 |
+
# normalize row-wise just in case
|
| 151 |
+
probs = probs / tf.reduce_sum(probs, axis=1, keepdims=True)
|
| 152 |
return probs.numpy()
|
| 153 |
|
| 154 |
|
|
|
|
| 238 |
return z, z_detail, "manual_stats"
|
| 239 |
|
| 240 |
|
| 241 |
+
# --------- SHAP model wrapper & explainer ---------
|
| 242 |
+
def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
|
| 243 |
+
"""
|
| 244 |
+
Wrapper for SHAP: takes (N, n_features) in z-space and returns (N, K) probabilities.
|
| 245 |
+
"""
|
| 246 |
+
raw = model.predict(z_batch_np, verbose=0)
|
| 247 |
+
if raw.ndim != 2:
|
| 248 |
+
raise ValueError(f"Unexpected raw shape from model: {raw.shape}")
|
| 249 |
+
N, M = raw.shape
|
| 250 |
+
K = len(CLASSES)
|
| 251 |
+
|
| 252 |
+
if M == K - 1:
|
| 253 |
+
# CORAL
|
| 254 |
+
probs = coral_probs_from_logits(raw) # (N, K)
|
| 255 |
+
elif M == K:
|
| 256 |
+
# Softmax or scores
|
| 257 |
+
exps = np.exp(raw - np.max(raw, axis=1, keepdims=True))
|
| 258 |
+
probs = exps / np.sum(exps, axis=1, keepdims=True)
|
| 259 |
+
else:
|
| 260 |
+
# Fallback normalize
|
| 261 |
+
s = np.sum(np.abs(raw), axis=1, keepdims=True)
|
| 262 |
+
probs = np.divide(raw, s, out=np.ones_like(raw) / max(M, 1), where=(s > 0))
|
| 263 |
+
return probs
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
EXPLAINER = None
|
| 267 |
+
if SHAP_AVAILABLE:
|
| 268 |
+
try:
|
| 269 |
+
# Background: 50 "average" institutions at z=0
|
| 270 |
+
BACKGROUND_Z = np.zeros((50, len(FEATURES)), dtype=np.float32)
|
| 271 |
+
EXPLAINER = shap.KernelExplainer(model_proba_from_z, BACKGROUND_Z)
|
| 272 |
+
print("SHAP KernelExplainer initialized.")
|
| 273 |
+
except Exception as e:
|
| 274 |
+
EXPLAINER = None
|
| 275 |
+
print("⚠️ Failed to initialize SHAP explainer:", repr(e))
|
| 276 |
+
else:
|
| 277 |
+
print("SHAP not installed; explanations disabled.")
|
| 278 |
+
|
| 279 |
+
|
| 280 |
# ----------------- FastAPI -----------------
|
| 281 |
+
app = FastAPI(title="Static Fingerprint API", version="1.2.0")
|
| 282 |
app.add_middleware(
|
| 283 |
CORSMiddleware,
|
| 284 |
allow_origins=["*"],
|
|
|
|
| 287 |
allow_headers=["*"],
|
| 288 |
)
|
| 289 |
|
| 290 |
+
|
| 291 |
@app.get("/")
|
| 292 |
def root():
|
| 293 |
return {
|
|
|
|
| 295 |
"try": ["GET /health", "POST /predict", "POST /debug/z"],
|
| 296 |
}
|
| 297 |
|
| 298 |
+
|
| 299 |
@app.get("/health")
|
| 300 |
def health():
|
| 301 |
stats_keys = []
|
|
|
|
| 314 |
"imputer": bool(imputer),
|
| 315 |
"scaler": bool(scaler),
|
| 316 |
"stats_available": bool(stats),
|
| 317 |
+
"shap_available": bool(EXPLAINER is not None),
|
| 318 |
}
|
| 319 |
|
| 320 |
+
|
| 321 |
@app.post("/debug/z")
|
| 322 |
async def debug_z(req: Request):
|
| 323 |
try:
|
|
|
|
| 344 |
except Exception as e:
|
| 345 |
return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})
|
| 346 |
|
| 347 |
+
|
| 348 |
@app.post("/predict")
|
| 349 |
async def predict(req: Request):
|
| 350 |
"""
|
| 351 |
Body: JSON object mapping feature -> numeric value (strings with commas/points ok).
|
| 352 |
Missing features are imputed if imputer present; else filled with means (if stats) or 0.
|
| 353 |
+
|
| 354 |
+
Now also returns SHAP values for the predicted_state (if SHAP is available).
|
| 355 |
"""
|
| 356 |
try:
|
| 357 |
payload = await req.json()
|
|
|
|
| 368 |
raw_logits = model.predict(X, verbose=0)
|
| 369 |
probs, mode = decode_logits(raw_logits)
|
| 370 |
|
|
|
|
| 371 |
pred_idx = int(np.argmax(probs))
|
| 372 |
probs_dict = {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
|
| 373 |
missing = [f for i, f in enumerate(FEATURES) if np.isnan(raw[i])]
|
| 374 |
|
| 375 |
+
resp: Dict[str, Any] = {
|
| 376 |
"input_ok": (len(missing) == 0),
|
| 377 |
"missing": missing,
|
| 378 |
"preprocess": {
|
|
|
|
| 389 |
"raw_first_row": [float(v) for v in raw_logits[0]],
|
| 390 |
},
|
| 391 |
}
|
| 392 |
+
|
| 393 |
+
# ---- SHAP explanation for predicted class ----
|
| 394 |
+
if EXPLAINER is not None:
|
| 395 |
+
try:
|
| 396 |
+
shap_vals_list = EXPLAINER.shap_values(X, nsamples="auto")
|
| 397 |
+
# shap_vals_list is a list of length K (classes)
|
| 398 |
+
if isinstance(shap_vals_list, list) and len(shap_vals_list) == len(CLASSES):
|
| 399 |
+
shap_for_pred = shap_vals_list[pred_idx][0] # (n_features,)
|
| 400 |
+
resp["shap_target"] = CLASSES[pred_idx]
|
| 401 |
+
resp["shap_values"] = {
|
| 402 |
+
FEATURES[i]: float(shap_for_pred[i]) for i in range(len(FEATURES))
|
| 403 |
+
}
|
| 404 |
+
else:
|
| 405 |
+
resp["shap_error"] = "Unexpected SHAP output shape."
|
| 406 |
+
except Exception as e:
|
| 407 |
+
resp["shap_error"] = f"SHAP computation failed: {repr(e)}"
|
| 408 |
+
else:
|
| 409 |
+
resp["shap_error"] = "SHAP not available in this environment."
|
| 410 |
+
|
| 411 |
+
return resp
|
| 412 |
+
|
| 413 |
except Exception as e:
|
| 414 |
return JSONResponse(status_code=500, content={"error": str(e), "trace": traceback.format_exc()})
|