COCODEDE04 commited on
Commit
c71e704
·
verified ·
1 Parent(s): 465e1a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -10
app.py CHANGED
@@ -234,9 +234,13 @@ def apply_scaling_or_stats(raw_vec: np.ndarray) -> (np.ndarray, Dict[str, float]
234
  # --------- SHAP: model wrapper & explainer ---------
235
  def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
236
  """
237
- Takes (N, n_features) OR a single 1D sample in z-space
238
- and returns (N, K) probabilities.
239
- Safe for both normal /predict and SHAP calls.
 
 
 
 
240
  """
241
  z = np.array(z_batch_np, dtype=np.float32)
242
 
@@ -244,15 +248,31 @@ def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
244
  if z.ndim == 1:
245
  z = z.reshape(1, -1)
246
 
247
- raw = model.predict(z, verbose=0)
248
- probs, _ = decode_logits(raw)
 
249
 
250
- # decode_logits may return (K,) if N=1, so enforce 2D
251
- probs = np.array(probs, dtype=np.float32)
252
- if probs.ndim == 1:
253
- probs = probs.reshape(1, -1)
254
 
255
- return probs # shape: (N, K)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
 
258
  EXPLAINER = None
 
234
  # --------- SHAP: model wrapper & explainer ---------
235
  def model_proba_from_z(z_batch_np: np.ndarray) -> np.ndarray:
236
  """
237
+ Batch-safe wrapper for SHAP and other callers.
238
+
239
+ Input:
240
+ z_batch_np: (N, n_features) or (n_features,) in z-space
241
+
242
+ Output:
243
+ probs: (N, K) matrix of class probabilities
244
  """
245
  z = np.array(z_batch_np, dtype=np.float32)
246
 
 
248
  if z.ndim == 1:
249
  z = z.reshape(1, -1)
250
 
251
+ raw = model.predict(z, verbose=0) # shape: (N, M)
252
+ if raw.ndim != 2:
253
+ raise ValueError(f"Unexpected raw shape from model: {raw.shape}")
254
 
255
+ N, M = raw.shape
256
+ K = len(CLASSES)
 
 
257
 
258
+ if M == K - 1:
259
+ # CORAL: logits for K-1 thresholds → K probabilities
260
+ probs = coral_probs_from_logits(raw) # (N, K)
261
+ elif M == K:
262
+ # Softmax or unnormalized scores, per row
263
+ exps = np.exp(raw - np.max(raw, axis=1, keepdims=True))
264
+ probs = exps / np.sum(exps, axis=1, keepdims=True) # (N, K)
265
+ else:
266
+ # Fallback: row-wise normalization
267
+ s = np.sum(np.abs(raw), axis=1, keepdims=True) # (N, 1)
268
+ probs = np.divide(
269
+ raw,
270
+ s,
271
+ out=np.ones_like(raw) / max(M, 1),
272
+ where=(s > 0),
273
+ ) # (N, M)
274
+
275
+ return probs
276
 
277
 
278
  EXPLAINER = None