Synav committed on
Commit
bb3bf4d
·
verified ·
1 Parent(s): ddb0ad4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -83
app.py CHANGED
@@ -7,7 +7,8 @@ import joblib
7
  import shap
8
  import matplotlib.pyplot as plt
9
  import os
10
- from huggingface_hub import HfApi
 
11
 
12
 
13
  from sklearn.pipeline import Pipeline
@@ -231,6 +232,70 @@ def publish_to_hub(model_repo_id: str, version_tag: str):
231
  "latest_meta_path": "latest/meta.json",
232
  }
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  # ============================================================
235
  # Streamlit UI
236
  # ============================================================
@@ -311,89 +376,130 @@ with tab_train:
311
 
312
  # ---------------- PREDICT ----------------
313
  with tab_predict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  if st.session_state.pipe is None:
315
- st.warning("Train a model first.")
316
- else:
317
- infer_file = st.file_uploader("Upload inference Excel (.xlsx)", type=["xlsx"])
318
- if infer_file:
319
- df_inf = pd.read_excel(infer_file, engine="openpyxl")
320
- X_inf = df_inf[FEATURE_COLS].copy()
321
- X_inf = X_inf.replace({pd.NA: np.nan})
322
-
323
- for c in CAT_COLS:
324
- X_inf[c] = X_inf[c].astype("object")
325
- X_inf.loc[X_inf[c].isna(), c] = np.nan
326
- X_inf[c] = X_inf[c].map(lambda v: v if pd.isna(v) else str(v))
 
 
 
327
 
328
- for c in NUM_COLS:
329
- X_inf[c] = pd.to_numeric(X_inf[c], errors="coerce")
330
- for c in CAT_COLS:
331
- X_inf[c] = X_inf[c].astype("object")
332
-
333
- pipe = st.session_state.pipe
334
- proba = pipe.predict_proba(X_inf)[:, 1]
335
-
336
- df_out = df_inf.copy()
337
- df_out["predicted_probability"] = proba
338
- st.dataframe(df_out.head())
339
-
340
- st.download_button(
341
- "Download predictions",
342
- df_out.to_csv(index=False).encode(),
343
- "predictions.csv",
344
- "text/csv"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  )
346
 
347
- st.subheader("SHAP explanation")
348
-
349
- with st.form("shap_form"):
350
- row = st.number_input("Row index", 0, len(X_inf) - 1, 0)
351
- explain_btn = st.form_submit_button("Generate SHAP explanation")
352
-
353
- if explain_btn:
354
- X_one = X_inf.iloc[[int(row)]]
355
-
356
- pre = pipe.named_steps["preprocess"]
357
- X_one_t = pre.transform(X_one)
358
-
359
- # Build explainer if missing
360
- if st.session_state.get("explainer") is None:
361
- st.session_state.explainer = build_shap_explainer(pipe, X_inf)
362
-
363
- explainer = st.session_state.explainer
364
- shap_vals = explainer.shap_values(X_one_t)
365
- base = explainer.expected_value
366
- if isinstance(shap_vals, list):
367
- shap_vals = shap_vals[1]
368
-
369
- try:
370
- names = list(pre.get_feature_names_out())
371
- except Exception:
372
- names = [f"f{i}" for i in range(len(shap_vals[0]))]
373
-
374
- try:
375
- x_dense = X_one_t.toarray()[0]
376
- except Exception:
377
- x_dense = np.array(X_one_t)[0]
378
-
379
- exp = shap.Explanation(
380
- values=shap_vals[0],
381
- base_values=float(base) if np.isscalar(base) else float(np.array(base).reshape(-1)[0]),
382
- data=x_dense,
383
- feature_names=names,
384
- )
385
-
386
- c1, c2 = st.columns(2)
387
-
388
- with c1:
389
- st.markdown("**Waterfall**")
390
- fig = plt.figure()
391
- shap.plots.waterfall(exp, show=False, max_display=20)
392
- st.pyplot(fig, clear_figure=True)
393
-
394
- with c2:
395
- st.markdown("**Top features**")
396
- fig2 = plt.figure()
397
- shap.plots.bar(exp, show=False, max_display=20)
398
- st.pyplot(fig2, clear_figure=True)
399
- st.stop()
 
7
  import shap
8
  import matplotlib.pyplot as plt
9
  import os
10
+ from huggingface_hub import hf_hub_download, HfApi
11
+
12
 
13
 
14
  from sklearn.pipeline import Pipeline
 
232
  "latest_meta_path": "latest/meta.json",
233
  }
234
 
235
# Default Hub repository that holds the published model releases.
MODEL_REPO_ID = "Synav/LogiSHAP-Studio-LogReg"


def list_release_versions(model_repo_id: str):
    """Return the version tags found under releases/<version>/model.joblib.

    Sorted newest-first; timestamp-style tags sort lexicographically, which
    is what the UI relies on.
    """
    api = HfApi(token=os.environ.get("HF_TOKEN") or None)
    repo_files = api.list_repo_files(repo_id=model_repo_id, repo_type="model")

    # Only paths shaped like releases/<version>/model.joblib count as releases.
    found = set()
    for path in repo_files:
        if not (path.startswith("releases/") and path.endswith("/model.joblib")):
            continue
        segments = path.split("/")
        if len(segments) >= 3:
            found.add(segments[1])

    return sorted(found, reverse=True)
254
+
255
+
256
def load_model_by_version(model_repo_id: str, version_tag: str):
    """Download and load releases/<version_tag>/{model.joblib, meta.json}.

    Returns (pipeline, meta_dict).
    """
    release_prefix = f"releases/{version_tag}"

    model_path = hf_hub_download(
        repo_id=model_repo_id,
        repo_type="model",
        filename=f"{release_prefix}/model.joblib",
    )
    meta_path = hf_hub_download(
        repo_id=model_repo_id,
        repo_type="model",
        filename=f"{release_prefix}/meta.json",
    )

    loaded_pipe = joblib.load(model_path)
    with open(meta_path, "r", encoding="utf-8") as fh:
        loaded_meta = json.load(fh)

    return loaded_pipe, loaded_meta
276
+
277
+
278
def load_latest_model(model_repo_id: str):
    """Download and load latest/model.joblib plus latest/meta.json.

    Returns (pipeline, meta_dict).
    """
    model_path = hf_hub_download(
        repo_id=model_repo_id,
        repo_type="model",
        filename="latest/model.joblib",
    )
    meta_path = hf_hub_download(
        repo_id=model_repo_id,
        repo_type="model",
        filename="latest/meta.json",
    )

    loaded_pipe = joblib.load(model_path)
    with open(meta_path, "r", encoding="utf-8") as fh:
        loaded_meta = json.load(fh)

    return loaded_pipe, loaded_meta
298
+
299
  # ============================================================
300
  # Streamlit UI
301
  # ============================================================
 
376
 
377
# ---------------- PREDICT ----------------
with tab_predict:
    st.subheader("Select a trained model (no retraining required)")

    # Make sure every session-state slot exists before widgets read it.
    for _key in ("pipe", "meta", "explainer"):
        if _key not in st.session_state:
            st.session_state[_key] = None

    # List available releases from the Hub; degrade to "latest" on failure.
    # NOTE: uses the module-level MODEL_REPO_ID constant — the previous local
    # redefinition shadowing it has been removed.
    try:
        versions = list_release_versions(MODEL_REPO_ID)
    except Exception as e:
        versions = []
        st.error(f"Could not list model versions: {e}")

    choices = ["latest"] + versions if versions else ["latest"]
    selected = st.selectbox("Choose model version", choices, index=0)

    if st.button("Load selected model"):
        try:
            with st.spinner("Loading model from Hugging Face Hub..."):
                if selected == "latest":
                    pipe, meta = load_latest_model(MODEL_REPO_ID)
                else:
                    pipe, meta = load_model_by_version(MODEL_REPO_ID, selected)

            st.session_state.pipe = pipe
            st.session_state.meta = meta
            st.session_state.explainer = None  # rebuild later with inference data
            st.success(f"Loaded model: {selected}")
        except Exception as e:
            st.error(f"Load failed: {e}")

    st.divider()
    if st.session_state.pipe is None:
        st.warning("Load a model version above, then upload an inference Excel.")
        st.stop()

    pipe = st.session_state.pipe

    infer_file = st.file_uploader("Upload inference Excel (.xlsx)", type=["xlsx"])
    if infer_file:
        df_inf = pd.read_excel(infer_file, engine="openpyxl")
        X_inf = df_inf[FEATURE_COLS].copy()
        X_inf = X_inf.replace({pd.NA: np.nan})

        # FIX: guard against an empty sheet — without this, the
        # st.number_input below is created with max_value = -1 (< min) and
        # Streamlit raises before anything renders.
        if len(X_inf) == 0:
            st.warning("Uploaded file contains no data rows.")
            st.stop()

        # Normalize categoricals: force object dtype, keep NaN as NaN,
        # and stringify everything else so the encoder sees uniform values.
        for c in CAT_COLS:
            X_inf[c] = X_inf[c].astype("object")
            X_inf.loc[X_inf[c].isna(), c] = np.nan
            X_inf[c] = X_inf[c].map(lambda v: v if pd.isna(v) else str(v))

        for c in NUM_COLS:
            X_inf[c] = pd.to_numeric(X_inf[c], errors="coerce")
        for c in CAT_COLS:
            X_inf[c] = X_inf[c].astype("object")

        # FIX: removed redundant second `pipe = st.session_state.pipe` —
        # `pipe` is already bound above.
        proba = pipe.predict_proba(X_inf)[:, 1]

        df_out = df_inf.copy()
        df_out["predicted_probability"] = proba
        st.dataframe(df_out.head())

        st.download_button(
            "Download predictions",
            df_out.to_csv(index=False).encode(),
            "predictions.csv",
            "text/csv"
        )

        st.subheader("SHAP explanation")

        # Form defers the (expensive) SHAP computation until submit.
        with st.form("shap_form"):
            row = st.number_input("Row index", 0, len(X_inf) - 1, 0)
            explain_btn = st.form_submit_button("Generate SHAP explanation")

        if explain_btn:
            X_one = X_inf.iloc[[int(row)]]

            pre = pipe.named_steps["preprocess"]
            X_one_t = pre.transform(X_one)

            # Build the explainer lazily — it was reset when a model loaded.
            if st.session_state.get("explainer") is None:
                st.session_state.explainer = build_shap_explainer(pipe, X_inf)

            explainer = st.session_state.explainer
            shap_vals = explainer.shap_values(X_one_t)
            base = explainer.expected_value
            if isinstance(shap_vals, list):
                # Some explainers return per-class lists; keep class 1.
                shap_vals = shap_vals[1]

            try:
                names = list(pre.get_feature_names_out())
            except Exception:
                names = [f"f{i}" for i in range(len(shap_vals[0]))]

            try:
                x_dense = X_one_t.toarray()[0]  # sparse transform output
            except Exception:
                x_dense = np.array(X_one_t)[0]

            exp = shap.Explanation(
                values=shap_vals[0],
                base_values=float(base) if np.isscalar(base) else float(np.array(base).reshape(-1)[0]),
                data=x_dense,
                feature_names=names,
            )

            c1, c2 = st.columns(2)

            with c1:
                st.markdown("**Waterfall**")
                fig = plt.figure()
                shap.plots.waterfall(exp, show=False, max_display=20)
                st.pyplot(fig, clear_figure=True)

            with c2:
                st.markdown("**Top features**")
                fig2 = plt.figure()
                shap.plots.bar(exp, show=False, max_display=20)
                st.pyplot(fig2, clear_figure=True)
            st.stop()