Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -234,9 +234,13 @@ def apply_scaling_or_stats(raw_vec: np.ndarray) -> (np.ndarray, Dict[str, float]
|
|
| 234 |
# --------- SHAP: model wrapper & explainer ---------
|
| 235 |
def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
|
| 236 |
"""
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
"""
|
| 241 |
z = np.array(z_batch_np, dtype=np.float32)
|
| 242 |
|
|
@@ -244,15 +248,31 @@ def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
|
|
| 244 |
if z.ndim == 1:
|
| 245 |
z = z.reshape(1, -1)
|
| 246 |
|
| 247 |
-
raw = model.predict(z, verbose=0)
|
| 248 |
-
|
|
|
|
| 249 |
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
if probs.ndim == 1:
|
| 253 |
-
probs = probs.reshape(1, -1)
|
| 254 |
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
|
| 258 |
EXPLAINER = None
|
|
|
|
| 234 |
# --------- SHAP: model wrapper & explainer ---------
|
| 235 |
def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
|
| 236 |
"""
|
| 237 |
+
Batch-safe wrapper for SHAP and other callers.
|
| 238 |
+
|
| 239 |
+
Input:
|
| 240 |
+
z_batch_np: (N, n_features) or (n_features,) in z-space
|
| 241 |
+
|
| 242 |
+
Output:
|
| 243 |
+
probs: (N, K) matrix of class probabilities
|
| 244 |
"""
|
| 245 |
z = np.array(z_batch_np, dtype=np.float32)
|
| 246 |
|
|
|
|
| 248 |
if z.ndim == 1:
|
| 249 |
z = z.reshape(1, -1)
|
| 250 |
|
| 251 |
+
raw = model.predict(z, verbose=0) # shape: (N, M)
|
| 252 |
+
if raw.ndim != 2:
|
| 253 |
+
raise ValueError(f"Unexpected raw shape from model: {raw.shape}")
|
| 254 |
|
| 255 |
+
N, M = raw.shape
|
| 256 |
+
K = len(CLASSES)
|
|
|
|
|
|
|
| 257 |
|
| 258 |
+
if M == K - 1:
|
| 259 |
+
# CORAL: logits for K-1 thresholds → K probabilities
|
| 260 |
+
probs = coral_probs_from_logits(raw) # (N, K)
|
| 261 |
+
elif M == K:
|
| 262 |
+
# Softmax or unnormalized scores, per row
|
| 263 |
+
exps = np.exp(raw - np.max(raw, axis=1, keepdims=True))
|
| 264 |
+
probs = exps / np.sum(exps, axis=1, keepdims=True) # (N, K)
|
| 265 |
+
else:
|
| 266 |
+
# Fallback: row-wise normalization
|
| 267 |
+
s = np.sum(np.abs(raw), axis=1, keepdims=True) # (N, 1)
|
| 268 |
+
probs = np.divide(
|
| 269 |
+
raw,
|
| 270 |
+
s,
|
| 271 |
+
out=np.ones_like(raw) / max(M, 1),
|
| 272 |
+
where=(s > 0),
|
| 273 |
+
) # (N, M)
|
| 274 |
+
|
| 275 |
+
return probs
|
| 276 |
|
| 277 |
|
| 278 |
EXPLAINER = None
|