Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -235,17 +235,194 @@ def bootstrap_auc_test(y_true, proba_a, proba_b, classes, n_bootstrap=2000, seed
|
|
| 235 |
return 1.0, auc_a, auc_b, -1, 1 # Not enough valid bootstraps
|
| 236 |
|
| 237 |
diffs = np.array(diffs)
|
| 238 |
-
# Two-sided p-value: proportion of bootstrap diffs that cross zero
|
| 239 |
-
# Under H0: diff=0, we center the diffs
|
| 240 |
centered = diffs - np.mean(diffs)
|
| 241 |
p_value = np.mean(np.abs(centered) >= np.abs(observed_diff))
|
| 242 |
-
p_value = max(p_value, 1.0 / n_bootstrap)
|
| 243 |
|
| 244 |
ci_low = np.percentile(diffs, 2.5)
|
| 245 |
ci_high = np.percentile(diffs, 97.5)
|
| 246 |
|
| 247 |
return p_value, auc_a, auc_b, ci_low, ci_high
|
| 248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
# ============================================================================
|
| 250 |
# Model configs (multi-class compatible)
|
| 251 |
# ============================================================================
|
|
@@ -326,10 +503,8 @@ def run_pipeline(
|
|
| 326 |
log(f" 📋 CSV: Col1=Label, Col2+=Features (no ID column)")
|
| 327 |
fnames = X.columns.tolist()
|
| 328 |
|
| 329 |
-
# Parse user selection: "3 类" -> 3, "2 类(二分类)" -> 2
|
| 330 |
user_n = int(str(n_classes_select).split(" ")[0])
|
| 331 |
|
| 332 |
-
# Validate against actual data
|
| 333 |
detected_classes = sorted(y.unique())
|
| 334 |
detected_classes = [int(c) if hasattr(c, 'item') else c for c in detected_classes]
|
| 335 |
detected_n = len(detected_classes)
|
|
@@ -342,7 +517,6 @@ def run_pipeline(
|
|
| 342 |
n_classes = user_n
|
| 343 |
log(f" ✅ {n_classes} 分类 — 数据验证通过")
|
| 344 |
|
| 345 |
-
# Remap to 0,1,...,n-1
|
| 346 |
label_map = {c: i for i, c in enumerate(classes)}
|
| 347 |
label_map_inv = {i: c for c, i in label_map.items()}
|
| 348 |
y_mapped = y.map(label_map)
|
|
@@ -439,7 +613,7 @@ def run_pipeline(
|
|
| 439 |
plot_multiclass_roc(r['all_yt'], r['all_yproba'], class_indices,
|
| 440 |
f'ROC — {mn} ({task_type}, Macro AUC={r["mean_auc"]:.3f})', f'roc_{mn}', rf)
|
| 441 |
|
| 442 |
-
# Combined ROC (macro per model)
|
| 443 |
plt.figure(figsize=(10, 8))
|
| 444 |
for i, mn in enumerate(mnames):
|
| 445 |
r = amr[mn]
|
|
@@ -515,16 +689,13 @@ def run_pipeline(
|
|
| 515 |
|
| 516 |
# ====================================================================
|
| 517 |
# ★★★ 新增 Part-1:训练集全模型 ROC / PR 曲线
|
| 518 |
-
# 新文件名前缀 train_roc_* / train_pr_*,与原有文件名零冲突
|
| 519 |
# ====================================================================
|
| 520 |
progress(0.57, desc="📈 [新增] 训练集ROC/PR曲线...")
|
| 521 |
log(f"\n 📈 [新增] 各模型训练集(in-sample)ROC / PR 曲线...")
|
| 522 |
|
| 523 |
-
# 两个内部辅助函数,仅用于叠加绘图数据准备
|
| 524 |
def _macro_roc_arrays(yt, yp, nc, cls_idx):
|
| 525 |
y_b = label_binarize(yt, classes=cls_idx)
|
| 526 |
-
if nc == 2:
|
| 527 |
-
y_b = np.hstack([1 - y_b, y_b])
|
| 528 |
all_fpr = np.linspace(0, 1, 300)
|
| 529 |
mean_tpr = np.zeros_like(all_fpr)
|
| 530 |
for c in range(nc):
|
|
@@ -535,8 +706,7 @@ def run_pipeline(
|
|
| 535 |
|
| 536 |
def _macro_pr_arrays(yt, yp, nc, cls_idx):
|
| 537 |
y_b = label_binarize(yt, classes=cls_idx)
|
| 538 |
-
if nc == 2:
|
| 539 |
-
y_b = np.hstack([1 - y_b, y_b])
|
| 540 |
all_rec = np.linspace(0, 1, 300)
|
| 541 |
mean_prec = np.zeros_like(all_rec)
|
| 542 |
for c in range(nc):
|
|
@@ -545,165 +715,129 @@ def run_pipeline(
|
|
| 545 |
mean_prec /= nc
|
| 546 |
return all_rec, mean_prec
|
| 547 |
|
| 548 |
-
_tr_roc = {}
|
| 549 |
-
_tr_pr = {} # mn -> (rec, prec) 供汇总图使用
|
| 550 |
-
|
| 551 |
for mn in mnames:
|
| 552 |
yp_tr = tms[mn].predict_proba(X.values)
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
f'Train ROC — {mn} ({task_type})',
|
| 558 |
-
f'train_roc_{mn}', rf
|
| 559 |
-
)
|
| 560 |
-
plot_multiclass_pr(
|
| 561 |
-
y_mapped.values, yp_tr, class_indices,
|
| 562 |
-
f'Train PR — {mn} ({task_type})',
|
| 563 |
-
f'train_pr_{mn}', rf
|
| 564 |
-
)
|
| 565 |
-
|
| 566 |
fpr_t, tpr_t, auc_t = _macro_roc_arrays(y_mapped.values, yp_tr, n_classes, class_indices)
|
| 567 |
rec_t, prec_t = _macro_pr_arrays(y_mapped.values, yp_tr, n_classes, class_indices)
|
| 568 |
_tr_roc[mn] = (fpr_t, tpr_t, auc_t)
|
| 569 |
_tr_pr[mn] = (rec_t, prec_t)
|
| 570 |
|
| 571 |
-
# 汇总:训练集全模型 ROC(train_roc_all)
|
| 572 |
plt.figure(figsize=(10, 8))
|
| 573 |
for i, mn in enumerate(mnames):
|
| 574 |
fpr_t, tpr_t, auc_t = _tr_roc[mn]
|
| 575 |
-
plt.plot(fpr_t, tpr_t, color=COLORS[i
|
| 576 |
-
|
| 577 |
-
plt.
|
| 578 |
-
plt.
|
| 579 |
-
plt.
|
| 580 |
-
plt.
|
| 581 |
-
plt.
|
| 582 |
-
plt.
|
| 583 |
-
plt.grid(True, alpha=0.15); plt.tight_layout()
|
| 584 |
-
plt.savefig(os.path.join(rf, 'train_roc_all.pdf'), format='pdf', bbox_inches='tight', dpi=300)
|
| 585 |
-
plt.savefig(os.path.join(rf, 'train_roc_all.png'), format='png', bbox_inches='tight', dpi=150)
|
| 586 |
plt.close()
|
| 587 |
|
| 588 |
-
# 汇总:训练集全模型 PR(train_pr_all)
|
| 589 |
plt.figure(figsize=(10, 8))
|
| 590 |
for i, mn in enumerate(mnames):
|
| 591 |
rec_t, prec_t = _tr_pr[mn]
|
| 592 |
-
plt.plot(rec_t, prec_t, color=COLORS[i
|
| 593 |
-
|
| 594 |
-
plt.
|
| 595 |
-
plt.
|
| 596 |
-
plt.
|
| 597 |
-
plt.
|
| 598 |
-
plt.
|
| 599 |
-
plt.grid(True, alpha=0.15); plt.tight_layout()
|
| 600 |
-
plt.savefig(os.path.join(rf, 'train_pr_all.pdf'), format='pdf', bbox_inches='tight', dpi=300)
|
| 601 |
-
plt.savefig(os.path.join(rf, 'train_pr_all.png'), format='png', bbox_inches='tight', dpi=150)
|
| 602 |
plt.close()
|
| 603 |
-
log(f" ✅ 训练集 ROC/PR 已生成
|
| 604 |
|
| 605 |
# ====================================================================
|
| 606 |
-
# ★★★ 新增 Part-2:最终模型
|
| 607 |
-
# 新文件:roc_train_vs_cv_* / pr_train_vs_cv_* / cm_train_*
|
| 608 |
-
# train_vs_cv_*.xlsx
|
| 609 |
-
# 原有文件:roc_* / pr_* / cm_* / model_evaluation.xlsx 均不变
|
| 610 |
# ====================================================================
|
| 611 |
progress(0.59, desc="📊 [新增] 最终模型Train vs CV对比...")
|
| 612 |
log(f"\n 📊 [新增] 最终模型 [{best_mn}] 训练集 vs 内部验证集(CV holdout)...")
|
| 613 |
|
| 614 |
-
# 训练集预测(用全量 fit 后的模型 tms[best_mn])
|
| 615 |
yp_best_tr = tms[best_mn].predict_proba(X.values)
|
| 616 |
yd_best_tr = tms[best_mn].predict(X.values)
|
| 617 |
-
met_tr = compute_multiclass_metrics(
|
| 618 |
-
y_mapped.values, yd_best_tr, yp_best_tr, class_indices)
|
| 619 |
|
| 620 |
-
# 内部 CV holdout(直接取 amr 中已累积的结果,不重新运算)
|
| 621 |
yp_best_cv = amr[best_mn]['all_yproba']
|
| 622 |
yd_best_cv = amr[best_mn]['all_yp']
|
| 623 |
yt_best_cv = amr[best_mn]['all_yt']
|
| 624 |
-
|
| 625 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 626 |
|
| 627 |
log(f" Train → AUC={met_tr['Macro_AUC']:.4f} Acc={met_tr['Accuracy']:.4f}"
|
| 628 |
-
f"
|
| 629 |
log(f" CV-Val → AUC={met_cv['Macro_AUC']:.4f} Acc={met_cv['Accuracy']:.4f}"
|
| 630 |
-
f"
|
| 631 |
-
|
| 632 |
-
# 对比 ROC(roc_train_vs_cv_{best_mn})
|
| 633 |
-
fpr_tb, tpr_tb, auc_tb = _macro_roc_arrays(
|
| 634 |
-
y_mapped.values, yp_best_tr, n_classes, class_indices)
|
| 635 |
-
fpr_cb, tpr_cb, auc_cb = _macro_roc_arrays(
|
| 636 |
-
yt_best_cv, yp_best_cv, n_classes, class_indices)
|
| 637 |
|
|
|
|
|
|
|
|
|
|
| 638 |
fig, ax = plt.subplots(figsize=(10, 8))
|
| 639 |
-
ax.plot(fpr_tb, tpr_tb, color='#e41a1c', lw=2.5,
|
| 640 |
-
|
| 641 |
-
ax.plot(
|
| 642 |
-
|
| 643 |
-
ax.
|
| 644 |
-
ax.
|
| 645 |
-
ax.
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
fontsize=14, fontweight='bold')
|
| 649 |
-
ax.legend(loc='lower right', fontsize=11)
|
| 650 |
-
ax.grid(True, alpha=0.15); plt.tight_layout()
|
| 651 |
-
plt.savefig(os.path.join(rf, f'roc_train_vs_cv_{best_mn}.pdf'),
|
| 652 |
-
format='pdf', bbox_inches='tight', dpi=300)
|
| 653 |
-
plt.savefig(os.path.join(rf, f'roc_train_vs_cv_{best_mn}.png'),
|
| 654 |
-
format='png', bbox_inches='tight', dpi=150)
|
| 655 |
plt.close()
|
| 656 |
|
| 657 |
-
# 对比 PR
|
| 658 |
-
rec_tb, prec_tb = _macro_pr_arrays(
|
| 659 |
-
|
| 660 |
-
rec_cb, prec_cb = _macro_pr_arrays(
|
| 661 |
-
yt_best_cv, yp_best_cv, n_classes, class_indices)
|
| 662 |
-
|
| 663 |
fig, ax = plt.subplots(figsize=(10, 8))
|
| 664 |
-
ax.plot(rec_tb, prec_tb, color='#e41a1c', lw=2.5,
|
| 665 |
-
|
| 666 |
-
ax.
|
| 667 |
-
|
| 668 |
-
ax.
|
| 669 |
-
ax.
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
fontsize=14, fontweight='bold')
|
| 673 |
-
ax.legend(loc='lower left', fontsize=11)
|
| 674 |
-
ax.grid(True, alpha=0.15); plt.tight_layout()
|
| 675 |
-
plt.savefig(os.path.join(rf, f'pr_train_vs_cv_{best_mn}.pdf'),
|
| 676 |
-
format='pdf', bbox_inches='tight', dpi=300)
|
| 677 |
-
plt.savefig(os.path.join(rf, f'pr_train_vs_cv_{best_mn}.png'),
|
| 678 |
-
format='png', bbox_inches='tight', dpi=150)
|
| 679 |
plt.close()
|
| 680 |
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
y_mapped.values, yd_best_tr, class_indices,
|
| 684 |
-
f'Train CM — {best_mn} (Acc={met_tr["Accuracy"]:.3f})',
|
| 685 |
-
f'cm_train_{best_mn}', rf
|
| 686 |
-
)
|
| 687 |
|
| 688 |
-
#
|
| 689 |
-
with pd.ExcelWriter(
|
| 690 |
-
|
| 691 |
-
|
|
|
|
|
|
|
| 692 |
pd.DataFrame([
|
| 693 |
-
{'Split':
|
| 694 |
-
'Macro_AUC':
|
| 695 |
-
'
|
| 696 |
-
'
|
| 697 |
-
|
| 698 |
-
'
|
| 699 |
-
|
| 700 |
-
'
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 704 |
amr[best_mn]['fold_df'].to_excel(w, sheet_name='CV_FoldDetail', index=False)
|
| 705 |
|
| 706 |
-
log(f" ✅ Train vs CV 对比
|
| 707 |
# ====================================================================
|
| 708 |
# ★★★ 新增结束
|
| 709 |
# ====================================================================
|
|
@@ -712,7 +846,6 @@ def run_pipeline(
|
|
| 712 |
progress(0.62, desc="🔥 SHAP分析...")
|
| 713 |
log(f"\n 🔥 SHAP特征分析 (保留模型中 Top 3)...")
|
| 714 |
shap_imp = {}
|
| 715 |
-
# SHAP for top 3 retained models
|
| 716 |
models_for_shap = sorted(retained, key=lambda x: amr[x]['mean_auc'], reverse=True)[:3]
|
| 717 |
|
| 718 |
for si, mn in enumerate(models_for_shap):
|
|
@@ -729,12 +862,10 @@ def run_pipeline(
|
|
| 729 |
exp = shap.KernelExplainer(lambda x, m=mo: m.predict_proba(x), bg)
|
| 730 |
sv = exp.shap_values(Xs)
|
| 731 |
|
| 732 |
-
# Handle SHAP output: could be list of arrays (one per class) or 3D array
|
| 733 |
if isinstance(sv, list):
|
| 734 |
-
# Average absolute SHAP across all classes
|
| 735 |
sv_abs = np.mean([np.abs(s) for s in sv], axis=0)
|
| 736 |
elif sv.ndim == 3:
|
| 737 |
-
sv_abs = np.mean(np.abs(sv), axis=2)
|
| 738 |
else:
|
| 739 |
sv_abs = np.abs(sv)
|
| 740 |
|
|
@@ -745,7 +876,6 @@ def run_pipeline(
|
|
| 745 |
idf = pd.DataFrame({'Feature': fnames, 'Importance': fi}).sort_values('Importance', ascending=False)
|
| 746 |
shap_imp[mn] = idf
|
| 747 |
|
| 748 |
-
# Bar plot (works for any number of classes)
|
| 749 |
plt.figure(figsize=(10, max(6, TOPN * 0.3)))
|
| 750 |
top_df = idf.head(TOPN).iloc[::-1]
|
| 751 |
plt.barh(top_df['Feature'], top_df['Importance'], color='#2563eb', alpha=0.8)
|
|
@@ -759,7 +889,7 @@ def run_pipeline(
|
|
| 759 |
except Exception as e:
|
| 760 |
log(f" ⚠ {mn} SHAP失败: {e}")
|
| 761 |
|
| 762 |
-
# ── Feature Ablation
|
| 763 |
progress(0.72, desc="🧪 特征消融...")
|
| 764 |
log(f"\n 🧪 特征消融 (仅最佳模型 {best_mn})...")
|
| 765 |
ablation_data = None
|
|
@@ -767,11 +897,14 @@ def run_pipeline(
|
|
| 767 |
imp_df = shap_imp[best_mn]
|
| 768 |
top_feats = imp_df.head(TOPN)['Feature'].tolist()
|
| 769 |
fcs = []; aucs_a = []
|
| 770 |
-
|
|
|
|
| 771 |
|
| 772 |
for nf in range(1, len(top_feats) + 1):
|
| 773 |
Xsub = X[top_feats[:nf]]
|
| 774 |
fold_aucs = []
|
|
|
|
|
|
|
| 775 |
for tri, tei in skf.split(Xsub, y_mapped):
|
| 776 |
mf = deepcopy(mcfg[best_mn]['model'])
|
| 777 |
bp2 = bpd.get(best_mn, {})
|
|
@@ -786,31 +919,57 @@ def run_pipeline(
|
|
| 786 |
a = roc_auc_score(yte_f, yproba_f, multi_class='ovr', average='macro')
|
| 787 |
except: a = 0.0
|
| 788 |
fold_aucs.append(a)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 789 |
fcs.append(nf); aucs_a.append(np.mean(fold_aucs))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 790 |
|
| 791 |
-
# Find optimal: first N where AUC >= 95% of full AUC
|
| 792 |
full_auc = amr[best_mn]['mean_auc']
|
| 793 |
opt_n = len(top_feats)
|
| 794 |
for i, a in enumerate(aucs_a):
|
| 795 |
if a >= full_auc * 0.95:
|
| 796 |
opt_n = i + 1; break
|
| 797 |
|
| 798 |
-
ablation_data = {
|
|
|
|
|
|
|
|
|
|
| 799 |
log(f" ✅ 最优特征数: {opt_n} (AUC={aucs_a[opt_n-1]:.4f} vs Full={full_auc:.4f})")
|
| 800 |
|
| 801 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 802 |
plt.figure(figsize=(10, 7))
|
| 803 |
plt.plot(fcs, aucs_a, 'o-', color='#2563eb', lw=2, ms=5)
|
| 804 |
-
plt.scatter([opt_n], [aucs_a[opt_n-1]], s=200, marker='*',
|
| 805 |
-
|
|
|
|
|
|
|
| 806 |
plt.xlabel('Number of Features', fontsize=13); plt.ylabel('Macro AUC', fontsize=13)
|
| 807 |
-
plt.title(f'Feature Ablation — {best_mn} (★ Optimal={opt_n})',
|
|
|
|
| 808 |
plt.legend(fontsize=11); plt.grid(True, alpha=0.15); plt.tight_layout()
|
| 809 |
plt.savefig(os.path.join(rf, 'ablation.pdf'), format='pdf', bbox_inches='tight')
|
| 810 |
plt.savefig(os.path.join(rf, 'ablation.png'), format='png', bbox_inches='tight', dpi=150)
|
| 811 |
plt.close()
|
| 812 |
|
| 813 |
-
# ── External Validation ── 【原有代码,原封不动】
|
| 814 |
val_files_list = [vf for vf in [val_file1, val_file2, val_file3] if vf is not None]
|
| 815 |
final_feats = ablation_data['opt_feats'] if ablation_data else fnames
|
| 816 |
|
|
@@ -826,7 +985,6 @@ def run_pipeline(
|
|
| 826 |
vcol2_is_id = (vcol2.dtype == 'object') or (vcol2.nunique() / len(vcol2) > 0.5)
|
| 827 |
Xe = ed.iloc[:, 2:] if vcol2_is_id else ed.iloc[:, 1:]
|
| 828 |
|
| 829 |
-
# Map validation labels using same mapping
|
| 830 |
ye = ye_raw.map(label_map)
|
| 831 |
if ye.isna().any():
|
| 832 |
log(f" ⚠ 验证集 {vi} 含有训练集中不存在的标签,已跳过")
|
|
@@ -842,8 +1000,16 @@ def run_pipeline(
|
|
| 842 |
yep = fm.predict_proba(Xes.values); yed = fm.predict(Xes.values)
|
| 843 |
ye_np = ye.values
|
| 844 |
|
|
|
|
| 845 |
metrics = compute_multiclass_metrics(ye_np, yed, yep, class_indices)
|
| 846 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
|
| 848 |
sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
|
| 849 |
tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
|
|
@@ -853,42 +1019,87 @@ def run_pipeline(
|
|
| 853 |
plot_confusion_matrix(ye_np, yed, class_indices, f'CM — {tag} ({best_mn})', f'cm{sfx}', rf)
|
| 854 |
|
| 855 |
with pd.ExcelWriter(os.path.join(rf, f'validation{sfx}.xlsx'), engine='openpyxl') as w:
|
|
|
|
| 856 |
pd.DataFrame([{'Model': best_mn, 'N_Features': len(final_feats),
|
| 857 |
'Macro_AUC': metrics['Macro_AUC'], 'Accuracy': metrics['Accuracy'],
|
| 858 |
'Macro_F1': metrics['Macro_F1'], 'Weighted_F1': metrics['Weighted_F1'],
|
| 859 |
'Kappa': metrics['Kappa']}]).to_excel(w, sheet_name='Metrics', index=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 860 |
rpt = pd.DataFrame(metrics['report']).T
|
| 861 |
rpt.to_excel(w, sheet_name='Per_Class', index=True)
|
| 862 |
pd.DataFrame({'Feature': final_feats}).to_excel(w, sheet_name='Features', index=False)
|
| 863 |
|
| 864 |
-
# ── Save Results ── 【原有代码,原封不动】
|
| 865 |
progress(0.92, desc="💾 保存结果...")
|
| 866 |
log(f"\n 💾 保存结果...")
|
| 867 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 868 |
with pd.ExcelWriter(os.path.join(rf, 'model_evaluation.xlsx'), engine='openpyxl') as w:
|
|
|
|
| 869 |
for mn, r in amr.items():
|
| 870 |
r['fold_df'].to_excel(w, sheet_name=mn, index=False)
|
| 871 |
-
|
|
|
|
| 872 |
sd = [{'Model': mn, 'Macro_AUC': r['mean_auc'], 'Accuracy': r['mean_acc'],
|
| 873 |
'Macro_F1': r['mean_f1'], 'Retained': 'Yes' if mn in retained else 'No',
|
| 874 |
'Best': 'Best' if mn == best_mn else ''}
|
| 875 |
for mn, r in amr.items()]
|
| 876 |
-
pd.DataFrame(sd).sort_values('Macro_AUC', ascending=False).to_excel(
|
| 877 |
-
|
|
|
|
|
|
|
| 878 |
if len(bootstrap_df) > 0:
|
| 879 |
bootstrap_df.to_excel(w, sheet_name='Bootstrap_Test', index=False)
|
| 880 |
-
|
|
|
|
| 881 |
best_report = classification_report(amr[best_mn]['all_yt'], amr[best_mn]['all_yp'],
|
| 882 |
labels=class_indices, output_dict=True, zero_division=0)
|
| 883 |
pd.DataFrame(best_report).T.to_excel(w, sheet_name=f'{best_mn}_PerClass', index=True)
|
| 884 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 885 |
if ablation_data:
|
| 886 |
with pd.ExcelWriter(os.path.join(rf, 'feature_ablation.xlsx'), engine='openpyxl') as w:
|
| 887 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 888 |
for mn, idf in shap_imp.items():
|
| 889 |
idf.to_excel(w, sheet_name=f'{mn}_Imp', index=False)
|
| 890 |
|
| 891 |
-
# Save params
|
| 892 |
with open(os.path.join(rf, 'best_params.txt'), 'w', encoding='utf-8') as f:
|
| 893 |
f.write(f"Task: {task_type} Classification ({n_classes} classes)\n")
|
| 894 |
f.write(f"Classes: {classes}\n")
|
|
@@ -914,7 +1125,7 @@ def run_pipeline(
|
|
| 914 |
if ablation_data:
|
| 915 |
f.write(f"\nOptimal Features ({ablation_data['opt_n']}): {', '.join(ablation_data['opt_feats'])}\n")
|
| 916 |
|
| 917 |
-
# Save model
|
| 918 |
pickle.dump({
|
| 919 |
'model_name': best_mn, 'model': tms[best_mn], 'best_params': bpd[best_mn],
|
| 920 |
'classes': classes, 'n_classes': n_classes, 'label_map': label_map,
|
|
|
|
| 235 |
return 1.0, auc_a, auc_b, -1, 1 # Not enough valid bootstraps
|
| 236 |
|
| 237 |
diffs = np.array(diffs)
|
|
|
|
|
|
|
| 238 |
centered = diffs - np.mean(diffs)
|
| 239 |
p_value = np.mean(np.abs(centered) >= np.abs(observed_diff))
|
| 240 |
+
p_value = max(p_value, 1.0 / n_bootstrap)
|
| 241 |
|
| 242 |
ci_low = np.percentile(diffs, 2.5)
|
| 243 |
ci_high = np.percentile(diffs, 97.5)
|
| 244 |
|
| 245 |
return p_value, auc_a, auc_b, ci_low, ci_high
|
| 246 |
|
| 247 |
+
|
| 248 |
+
# ============================================================================
|
| 249 |
+
# ★ 新增全局工具函数:Bootstrap 95%CI + 敏感性/特异性等指标计算
|
| 250 |
+
# ============================================================================
|
| 251 |
+
|
| 252 |
+
def _bootstrap_ci(y_true, y_pred, y_proba, classes, metric_fn, n_bootstrap=1000, seed=42):
|
| 253 |
+
"""
|
| 254 |
+
通用 Bootstrap 95% CI 计算器。
|
| 255 |
+
metric_fn(yt, yp, yproba) -> float
|
| 256 |
+
返回 (point_estimate, ci_low, ci_high)
|
| 257 |
+
"""
|
| 258 |
+
rng = np.random.RandomState(seed)
|
| 259 |
+
n = len(y_true)
|
| 260 |
+
n_cls = len(classes)
|
| 261 |
+
point = metric_fn(y_true, y_pred, y_proba)
|
| 262 |
+
boots = []
|
| 263 |
+
for _ in range(n_bootstrap):
|
| 264 |
+
idx = rng.choice(n, n, replace=True)
|
| 265 |
+
yt_b = y_true[idx]
|
| 266 |
+
if len(np.unique(yt_b)) < n_cls:
|
| 267 |
+
continue
|
| 268 |
+
try:
|
| 269 |
+
boots.append(metric_fn(yt_b, y_pred[idx], y_proba[idx]))
|
| 270 |
+
except:
|
| 271 |
+
pass
|
| 272 |
+
if len(boots) < 50:
|
| 273 |
+
return point, np.nan, np.nan
|
| 274 |
+
return point, float(np.percentile(boots, 2.5)), float(np.percentile(boots, 97.5))
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def compute_extended_metrics_with_ci(y_true, y_pred, y_proba, classes,
|
| 278 |
+
n_bootstrap=1000, seed=42):
|
| 279 |
+
"""
|
| 280 |
+
计算完整的多分类诊断指标,包含 95% Bootstrap CI。
|
| 281 |
+
指标(均为 macro-OvR 平均):
|
| 282 |
+
Accuracy, Macro_AUC, Macro_F1, Weighted_F1, Kappa,
|
| 283 |
+
Sensitivity (Recall), Specificity, PPV (Precision), NPV, F1_macro
|
| 284 |
+
返回 dict,每个指标带 _CI_low / _CI_high。
|
| 285 |
+
同时返回 per_class_df(逐类详细指标)。
|
| 286 |
+
"""
|
| 287 |
+
n_cls = len(classes)
|
| 288 |
+
y_true = np.array(y_true)
|
| 289 |
+
y_pred = np.array(y_pred)
|
| 290 |
+
y_proba = np.array(y_proba)
|
| 291 |
+
|
| 292 |
+
# ── 逐类指标(OvR) ──
|
| 293 |
+
per_rows = []
|
| 294 |
+
for i, cls in enumerate(classes):
|
| 295 |
+
yt_b = (y_true == i).astype(int)
|
| 296 |
+
yp_b = (y_pred == i).astype(int)
|
| 297 |
+
ypr_b = y_proba[:, i]
|
| 298 |
+
|
| 299 |
+
cm_b = confusion_matrix(yt_b, yp_b, labels=[0, 1])
|
| 300 |
+
tn, fp, fn, tp = cm_b.ravel() if cm_b.shape == (2,2) else (0,0,0,0)
|
| 301 |
+
|
| 302 |
+
sens = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
| 303 |
+
spec = tn / (tn + fp) if (tn + fp) > 0 else 0.0
|
| 304 |
+
ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
| 305 |
+
npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
|
| 306 |
+
f1_c = 2*ppv*sens / (ppv+sens) if (ppv+sens) > 0 else 0.0
|
| 307 |
+
try:
|
| 308 |
+
auc_c = roc_auc_score(yt_b, ypr_b)
|
| 309 |
+
except:
|
| 310 |
+
auc_c = np.nan
|
| 311 |
+
|
| 312 |
+
per_rows.append({
|
| 313 |
+
'Class': cls, 'TP': int(tp), 'FP': int(fp), 'FN': int(fn), 'TN': int(tn),
|
| 314 |
+
'Sensitivity': sens, 'Specificity': spec,
|
| 315 |
+
'PPV': ppv, 'NPV': npv, 'F1': f1_c, 'AUC': auc_c
|
| 316 |
+
})
|
| 317 |
+
|
| 318 |
+
per_class_df = pd.DataFrame(per_rows)
|
| 319 |
+
|
| 320 |
+
# ── Macro 平均点估计 ──
|
| 321 |
+
def _macro_sens(yt, yp, ypr):
|
| 322 |
+
vals = []
|
| 323 |
+
for i in range(n_cls):
|
| 324 |
+
yt_b = (yt == i).astype(int); yp_b = (yp == i).astype(int)
|
| 325 |
+
cm_b = confusion_matrix(yt_b, yp_b, labels=[0,1])
|
| 326 |
+
tn,fp,fn,tp = cm_b.ravel() if cm_b.shape==(2,2) else (0,0,0,0)
|
| 327 |
+
vals.append(tp/(tp+fn) if (tp+fn)>0 else 0.0)
|
| 328 |
+
return float(np.mean(vals))
|
| 329 |
+
|
| 330 |
+
def _macro_spec(yt, yp, ypr):
|
| 331 |
+
vals = []
|
| 332 |
+
for i in range(n_cls):
|
| 333 |
+
yt_b = (yt == i).astype(int); yp_b = (yp == i).astype(int)
|
| 334 |
+
cm_b = confusion_matrix(yt_b, yp_b, labels=[0,1])
|
| 335 |
+
tn,fp,fn,tp = cm_b.ravel() if cm_b.shape==(2,2) else (0,0,0,0)
|
| 336 |
+
vals.append(tn/(tn+fp) if (tn+fp)>0 else 0.0)
|
| 337 |
+
return float(np.mean(vals))
|
| 338 |
+
|
| 339 |
+
def _macro_ppv(yt, yp, ypr):
|
| 340 |
+
vals = []
|
| 341 |
+
for i in range(n_cls):
|
| 342 |
+
yt_b = (yt == i).astype(int); yp_b = (yp == i).astype(int)
|
| 343 |
+
cm_b = confusion_matrix(yt_b, yp_b, labels=[0,1])
|
| 344 |
+
tn,fp,fn,tp = cm_b.ravel() if cm_b.shape==(2,2) else (0,0,0,0)
|
| 345 |
+
vals.append(tp/(tp+fp) if (tp+fp)>0 else 0.0)
|
| 346 |
+
return float(np.mean(vals))
|
| 347 |
+
|
| 348 |
+
def _macro_npv(yt, yp, ypr):
|
| 349 |
+
vals = []
|
| 350 |
+
for i in range(n_cls):
|
| 351 |
+
yt_b = (yt == i).astype(int); yp_b = (yp == i).astype(int)
|
| 352 |
+
cm_b = confusion_matrix(yt_b, yp_b, labels=[0,1])
|
| 353 |
+
tn,fp,fn,tp = cm_b.ravel() if cm_b.shape==(2,2) else (0,0,0,0)
|
| 354 |
+
vals.append(tn/(tn+fn) if (tn+fn)>0 else 0.0)
|
| 355 |
+
return float(np.mean(vals))
|
| 356 |
+
|
| 357 |
+
def _macro_f1(yt, yp, ypr):
|
| 358 |
+
return float(f1_score(yt, yp, average='macro', zero_division=0))
|
| 359 |
+
|
| 360 |
+
def _acc(yt, yp, ypr):
|
| 361 |
+
return float(accuracy_score(yt, yp))
|
| 362 |
+
|
| 363 |
+
def _kappa(yt, yp, ypr):
|
| 364 |
+
return float(cohen_kappa_score(yt, yp))
|
| 365 |
+
|
| 366 |
+
def _macro_auc(yt, yp, ypr):
|
| 367 |
+
try:
|
| 368 |
+
if n_cls == 2:
|
| 369 |
+
return float(roc_auc_score(yt, ypr[:, 1]))
|
| 370 |
+
return float(roc_auc_score(yt, ypr, multi_class='ovr', average='macro'))
|
| 371 |
+
except:
|
| 372 |
+
return 0.0
|
| 373 |
+
|
| 374 |
+
def _wf1(yt, yp, ypr):
|
| 375 |
+
return float(f1_score(yt, yp, average='weighted', zero_division=0))
|
| 376 |
+
|
| 377 |
+
metric_fns = {
|
| 378 |
+
'Accuracy': _acc,
|
| 379 |
+
'Macro_AUC': _macro_auc,
|
| 380 |
+
'Sensitivity': _macro_sens,
|
| 381 |
+
'Specificity': _macro_spec,
|
| 382 |
+
'PPV': _macro_ppv,
|
| 383 |
+
'NPV': _macro_npv,
|
| 384 |
+
'Macro_F1': _macro_f1,
|
| 385 |
+
'Weighted_F1': _wf1,
|
| 386 |
+
'Kappa': _kappa,
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
result = {}
|
| 390 |
+
for name, fn in metric_fns.items():
|
| 391 |
+
pt, lo, hi = _bootstrap_ci(y_true, y_pred, y_proba, classes, fn,
|
| 392 |
+
n_bootstrap=n_bootstrap, seed=seed)
|
| 393 |
+
result[name] = pt
|
| 394 |
+
result[f'{name}_CI_low'] = lo
|
| 395 |
+
result[f'{name}_CI_high'] = hi
|
| 396 |
+
|
| 397 |
+
# 保留 report 字段以兼容原有代码
|
| 398 |
+
result['report'] = classification_report(
|
| 399 |
+
y_true, y_pred, labels=list(range(n_cls)), output_dict=True, zero_division=0)
|
| 400 |
+
|
| 401 |
+
return result, per_class_df
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def _fmt(val, lo, hi):
|
| 405 |
+
"""格式化为 '0.xxx (0.xxx–0.xxx)' 供展示"""
|
| 406 |
+
if np.isnan(lo):
|
| 407 |
+
return f"{val:.4f}"
|
| 408 |
+
return f"{val:.4f} ({lo:.4f}–{hi:.4f})"
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def build_metrics_summary_df(ext_metrics, model_name, split_name):
|
| 412 |
+
"""把 compute_extended_metrics_with_ci 结果转为单行 DataFrame,含 CI 列"""
|
| 413 |
+
keys = ['Accuracy','Macro_AUC','Sensitivity','Specificity',
|
| 414 |
+
'PPV','NPV','Macro_F1','Weighted_F1','Kappa']
|
| 415 |
+
row = {'Model': model_name, 'Split': split_name}
|
| 416 |
+
for k in keys:
|
| 417 |
+
row[k] = ext_metrics.get(k, np.nan)
|
| 418 |
+
row[f'{k}_95CI'] = _fmt(
|
| 419 |
+
ext_metrics.get(k, np.nan),
|
| 420 |
+
ext_metrics.get(f'{k}_CI_low', np.nan),
|
| 421 |
+
ext_metrics.get(f'{k}_CI_high', np.nan)
|
| 422 |
+
)
|
| 423 |
+
return pd.DataFrame([row])
|
| 424 |
+
|
| 425 |
+
|
| 426 |
# ============================================================================
|
| 427 |
# Model configs (multi-class compatible)
|
| 428 |
# ============================================================================
|
|
|
|
| 503 |
log(f" 📋 CSV: Col1=Label, Col2+=Features (no ID column)")
|
| 504 |
fnames = X.columns.tolist()
|
| 505 |
|
|
|
|
| 506 |
user_n = int(str(n_classes_select).split(" ")[0])
|
| 507 |
|
|
|
|
| 508 |
detected_classes = sorted(y.unique())
|
| 509 |
detected_classes = [int(c) if hasattr(c, 'item') else c for c in detected_classes]
|
| 510 |
detected_n = len(detected_classes)
|
|
|
|
| 517 |
n_classes = user_n
|
| 518 |
log(f" ✅ {n_classes} 分类 — 数据验证通过")
|
| 519 |
|
|
|
|
| 520 |
label_map = {c: i for i, c in enumerate(classes)}
|
| 521 |
label_map_inv = {i: c for c, i in label_map.items()}
|
| 522 |
y_mapped = y.map(label_map)
|
|
|
|
| 613 |
plot_multiclass_roc(r['all_yt'], r['all_yproba'], class_indices,
|
| 614 |
f'ROC — {mn} ({task_type}, Macro AUC={r["mean_auc"]:.3f})', f'roc_{mn}', rf)
|
| 615 |
|
| 616 |
+
# Combined ROC (macro per model)
|
| 617 |
plt.figure(figsize=(10, 8))
|
| 618 |
for i, mn in enumerate(mnames):
|
| 619 |
r = amr[mn]
|
|
|
|
| 689 |
|
| 690 |
# ====================================================================
|
| 691 |
# ★★★ 新增 Part-1:训练集全模型 ROC / PR 曲线
|
|
|
|
| 692 |
# ====================================================================
|
| 693 |
progress(0.57, desc="📈 [新增] 训练集ROC/PR曲线...")
|
| 694 |
log(f"\n 📈 [新增] 各模型训练集(in-sample)ROC / PR 曲线...")
|
| 695 |
|
|
|
|
| 696 |
def _macro_roc_arrays(yt, yp, nc, cls_idx):
|
| 697 |
y_b = label_binarize(yt, classes=cls_idx)
|
| 698 |
+
if nc == 2: y_b = np.hstack([1 - y_b, y_b])
|
|
|
|
| 699 |
all_fpr = np.linspace(0, 1, 300)
|
| 700 |
mean_tpr = np.zeros_like(all_fpr)
|
| 701 |
for c in range(nc):
|
|
|
|
| 706 |
|
| 707 |
def _macro_pr_arrays(yt, yp, nc, cls_idx):
|
| 708 |
y_b = label_binarize(yt, classes=cls_idx)
|
| 709 |
+
if nc == 2: y_b = np.hstack([1 - y_b, y_b])
|
|
|
|
| 710 |
all_rec = np.linspace(0, 1, 300)
|
| 711 |
mean_prec = np.zeros_like(all_rec)
|
| 712 |
for c in range(nc):
|
|
|
|
| 715 |
mean_prec /= nc
|
| 716 |
return all_rec, mean_prec
|
| 717 |
|
| 718 |
+
_tr_roc = {}; _tr_pr = {}
|
|
|
|
|
|
|
| 719 |
for mn in mnames:
|
| 720 |
yp_tr = tms[mn].predict_proba(X.values)
|
| 721 |
+
plot_multiclass_roc(y_mapped.values, yp_tr, class_indices,
|
| 722 |
+
f'Train ROC — {mn} ({task_type})', f'train_roc_{mn}', rf)
|
| 723 |
+
plot_multiclass_pr(y_mapped.values, yp_tr, class_indices,
|
| 724 |
+
f'Train PR — {mn} ({task_type})', f'train_pr_{mn}', rf)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 725 |
fpr_t, tpr_t, auc_t = _macro_roc_arrays(y_mapped.values, yp_tr, n_classes, class_indices)
|
| 726 |
rec_t, prec_t = _macro_pr_arrays(y_mapped.values, yp_tr, n_classes, class_indices)
|
| 727 |
_tr_roc[mn] = (fpr_t, tpr_t, auc_t)
|
| 728 |
_tr_pr[mn] = (rec_t, prec_t)
|
| 729 |
|
|
|
|
| 730 |
plt.figure(figsize=(10, 8))
|
| 731 |
for i, mn in enumerate(mnames):
|
| 732 |
fpr_t, tpr_t, auc_t = _tr_roc[mn]
|
| 733 |
+
plt.plot(fpr_t, tpr_t, color=COLORS[i%8], lw=2.5, label=f'{mn} (Train Macro AUC={auc_t:.3f})')
|
| 734 |
+
plt.plot([0,1],[0,1],'--',color='#ccc',lw=1)
|
| 735 |
+
plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
|
| 736 |
+
plt.xlabel('False Positive Rate',fontsize=13); plt.ylabel('True Positive Rate',fontsize=13)
|
| 737 |
+
plt.title(f'Train ROC — All Models ({task_type})',fontsize=14,fontweight='bold')
|
| 738 |
+
plt.legend(loc='lower right',fontsize=10); plt.grid(True,alpha=0.15); plt.tight_layout()
|
| 739 |
+
plt.savefig(os.path.join(rf,'train_roc_all.pdf'),format='pdf',bbox_inches='tight',dpi=300)
|
| 740 |
+
plt.savefig(os.path.join(rf,'train_roc_all.png'),format='png',bbox_inches='tight',dpi=150)
|
|
|
|
|
|
|
|
|
|
| 741 |
plt.close()
|
| 742 |
|
|
|
|
| 743 |
plt.figure(figsize=(10, 8))
|
| 744 |
for i, mn in enumerate(mnames):
|
| 745 |
rec_t, prec_t = _tr_pr[mn]
|
| 746 |
+
plt.plot(rec_t, prec_t, color=COLORS[i%8], lw=2.5, label=f'{mn} (Mean AP={prec_t.mean():.3f})')
|
| 747 |
+
plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
|
| 748 |
+
plt.xlabel('Recall',fontsize=13); plt.ylabel('Precision',fontsize=13)
|
| 749 |
+
plt.title(f'Train PR — All Models ({task_type})',fontsize=14,fontweight='bold')
|
| 750 |
+
plt.legend(loc='lower left',fontsize=10); plt.grid(True,alpha=0.15); plt.tight_layout()
|
| 751 |
+
plt.savefig(os.path.join(rf,'train_pr_all.pdf'),format='pdf',bbox_inches='tight',dpi=300)
|
| 752 |
+
plt.savefig(os.path.join(rf,'train_pr_all.png'),format='png',bbox_inches='tight',dpi=150)
|
|
|
|
|
|
|
|
|
|
| 753 |
plt.close()
|
| 754 |
+
log(f" ✅ 训练集 ROC/PR 已生成(各模型独立图 + 汇总图)")
|
| 755 |
|
| 756 |
# ====================================================================
|
| 757 |
+
# ★★★ 新增 Part-2:最终模型 Train vs Internal CV 对比(含扩展指标+CI)
|
|
|
|
|
|
|
|
|
|
| 758 |
# ====================================================================
|
| 759 |
progress(0.59, desc="📊 [新增] 最终模型Train vs CV对比...")
|
| 760 |
log(f"\n 📊 [新增] 最终模型 [{best_mn}] 训练集 vs 内部验证集(CV holdout)...")
|
| 761 |
|
|
|
|
| 762 |
yp_best_tr = tms[best_mn].predict_proba(X.values)
|
| 763 |
yd_best_tr = tms[best_mn].predict(X.values)
|
|
|
|
|
|
|
| 764 |
|
|
|
|
| 765 |
yp_best_cv = amr[best_mn]['all_yproba']
|
| 766 |
yd_best_cv = amr[best_mn]['all_yp']
|
| 767 |
yt_best_cv = amr[best_mn]['all_yt']
|
| 768 |
+
|
| 769 |
+
# 扩展指标(含95%CI)
|
| 770 |
+
met_tr_ext, pc_tr = compute_extended_metrics_with_ci(
|
| 771 |
+
y_mapped.values, yd_best_tr, yp_best_tr, class_indices, n_bootstrap=1000, seed=RS)
|
| 772 |
+
met_cv_ext, pc_cv = compute_extended_metrics_with_ci(
|
| 773 |
+
yt_best_cv, yd_best_cv, yp_best_cv, class_indices, n_bootstrap=1000, seed=RS)
|
| 774 |
+
|
| 775 |
+
# 保留原有兼容字段
|
| 776 |
+
met_tr = {k: met_tr_ext[k] for k in ['Accuracy','Macro_AUC','Macro_F1','Weighted_F1','Kappa','report']}
|
| 777 |
+
met_cv = {k: met_cv_ext[k] for k in ['Accuracy','Macro_AUC','Macro_F1','Weighted_F1','Kappa','report']}
|
| 778 |
|
| 779 |
log(f" Train → AUC={met_tr['Macro_AUC']:.4f} Acc={met_tr['Accuracy']:.4f}"
|
| 780 |
+
f" Sens={met_tr_ext['Sensitivity']:.4f} Spec={met_tr_ext['Specificity']:.4f}")
|
| 781 |
log(f" CV-Val → AUC={met_cv['Macro_AUC']:.4f} Acc={met_cv['Accuracy']:.4f}"
|
| 782 |
+
f" Sens={met_cv_ext['Sensitivity']:.4f} Spec={met_cv_ext['Specificity']:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 783 |
|
| 784 |
+
# 对比 ROC
|
| 785 |
+
fpr_tb, tpr_tb, auc_tb = _macro_roc_arrays(y_mapped.values, yp_best_tr, n_classes, class_indices)
|
| 786 |
+
fpr_cb, tpr_cb, auc_cb = _macro_roc_arrays(yt_best_cv, yp_best_cv, n_classes, class_indices)
|
| 787 |
fig, ax = plt.subplots(figsize=(10, 8))
|
| 788 |
+
ax.plot(fpr_tb, tpr_tb, color='#e41a1c', lw=2.5, label=f'Train set (Macro AUC={auc_tb:.3f})')
|
| 789 |
+
ax.plot(fpr_cb, tpr_cb, color='#377eb8', lw=2.5, linestyle='--', label=f'Internal CV (Macro AUC={auc_cb:.3f})')
|
| 790 |
+
ax.plot([0,1],[0,1],'--',color='#ccc',lw=1)
|
| 791 |
+
ax.set_xlim([-0.02,1.02]); ax.set_ylim([-0.02,1.02])
|
| 792 |
+
ax.set_xlabel('False Positive Rate',fontsize=13); ax.set_ylabel('True Positive Rate',fontsize=13)
|
| 793 |
+
ax.set_title(f'ROC — {best_mn}: Train vs Internal CV ({task_type})',fontsize=14,fontweight='bold')
|
| 794 |
+
ax.legend(loc='lower right',fontsize=11); ax.grid(True,alpha=0.15); plt.tight_layout()
|
| 795 |
+
plt.savefig(os.path.join(rf,f'roc_train_vs_cv_{best_mn}.pdf'),format='pdf',bbox_inches='tight',dpi=300)
|
| 796 |
+
plt.savefig(os.path.join(rf,f'roc_train_vs_cv_{best_mn}.png'),format='png',bbox_inches='tight',dpi=150)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 797 |
plt.close()
|
| 798 |
|
| 799 |
+
# 对比 PR
|
| 800 |
+
rec_tb, prec_tb = _macro_pr_arrays(y_mapped.values, yp_best_tr, n_classes, class_indices)
|
| 801 |
+
rec_cb, prec_cb = _macro_pr_arrays(yt_best_cv, yp_best_cv, n_classes, class_indices)
|
|
|
|
|
|
|
|
|
|
| 802 |
fig, ax = plt.subplots(figsize=(10, 8))
|
| 803 |
+
ax.plot(rec_tb, prec_tb, color='#e41a1c', lw=2.5, label=f'Train set (Mean AP={prec_tb.mean():.3f})')
|
| 804 |
+
ax.plot(rec_cb, prec_cb, color='#377eb8', lw=2.5, linestyle='--', label=f'Internal CV (Mean AP={prec_cb.mean():.3f})')
|
| 805 |
+
ax.set_xlim([-0.02,1.02]); ax.set_ylim([-0.02,1.02])
|
| 806 |
+
ax.set_xlabel('Recall',fontsize=13); ax.set_ylabel('Precision',fontsize=13)
|
| 807 |
+
ax.set_title(f'PR — {best_mn}: Train vs Internal CV ({task_type})',fontsize=14,fontweight='bold')
|
| 808 |
+
ax.legend(loc='lower left',fontsize=11); ax.grid(True,alpha=0.15); plt.tight_layout()
|
| 809 |
+
plt.savefig(os.path.join(rf,f'pr_train_vs_cv_{best_mn}.pdf'),format='pdf',bbox_inches='tight',dpi=300)
|
| 810 |
+
plt.savefig(os.path.join(rf,f'pr_train_vs_cv_{best_mn}.png'),format='png',bbox_inches='tight',dpi=150)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 811 |
plt.close()
|
| 812 |
|
| 813 |
+
plot_confusion_matrix(y_mapped.values, yd_best_tr, class_indices,
|
| 814 |
+
f'Train CM — {best_mn} (Acc={met_tr["Accuracy"]:.3f})', f'cm_train_{best_mn}', rf)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 815 |
|
| 816 |
+
# Train vs CV Excel(含扩展指标+CI)
|
| 817 |
+
with pd.ExcelWriter(os.path.join(rf,f'train_vs_cv_{best_mn}.xlsx'),engine='openpyxl') as w:
|
| 818 |
+
df_tr = build_metrics_summary_df(met_tr_ext, best_mn, 'Train')
|
| 819 |
+
df_cv = build_metrics_summary_df(met_cv_ext, best_mn, 'Internal_CV')
|
| 820 |
+
pd.concat([df_tr, df_cv], ignore_index=True).to_excel(w, sheet_name='Summary_with_CI', index=False)
|
| 821 |
+
# 简洁数字版(兼容旧格式)
|
| 822 |
pd.DataFrame([
|
| 823 |
+
{'Split':'Train','Model':best_mn,
|
| 824 |
+
'Macro_AUC':met_tr_ext['Macro_AUC'],'Accuracy':met_tr_ext['Accuracy'],
|
| 825 |
+
'Sensitivity':met_tr_ext['Sensitivity'],'Specificity':met_tr_ext['Specificity'],
|
| 826 |
+
'PPV':met_tr_ext['PPV'],'NPV':met_tr_ext['NPV'],
|
| 827 |
+
'Macro_F1':met_tr_ext['Macro_F1'],'Weighted_F1':met_tr_ext['Weighted_F1'],
|
| 828 |
+
'Kappa':met_tr_ext['Kappa']},
|
| 829 |
+
{'Split':'Internal_CV','Model':best_mn,
|
| 830 |
+
'Macro_AUC':met_cv_ext['Macro_AUC'],'Accuracy':met_cv_ext['Accuracy'],
|
| 831 |
+
'Sensitivity':met_cv_ext['Sensitivity'],'Specificity':met_cv_ext['Specificity'],
|
| 832 |
+
'PPV':met_cv_ext['PPV'],'NPV':met_cv_ext['NPV'],
|
| 833 |
+
'Macro_F1':met_cv_ext['Macro_F1'],'Weighted_F1':met_cv_ext['Weighted_F1'],
|
| 834 |
+
'Kappa':met_cv_ext['Kappa']},
|
| 835 |
+
]).to_excel(w, sheet_name='Summary_numeric', index=False)
|
| 836 |
+
pc_tr.to_excel(w, sheet_name='Train_PerClass', index=False)
|
| 837 |
+
pc_cv.to_excel(w, sheet_name='CV_PerClass', index=False)
|
| 838 |
amr[best_mn]['fold_df'].to_excel(w, sheet_name='CV_FoldDetail', index=False)
|
| 839 |
|
| 840 |
+
log(f" ✅ Train vs CV 对比(含95%CI)已保存 → train_vs_cv_{best_mn}.xlsx")
|
| 841 |
# ====================================================================
|
| 842 |
# ★★★ 新增结束
|
| 843 |
# ====================================================================
|
|
|
|
| 846 |
progress(0.62, desc="🔥 SHAP分析...")
|
| 847 |
log(f"\n 🔥 SHAP特征分析 (保留模型中 Top 3)...")
|
| 848 |
shap_imp = {}
|
|
|
|
| 849 |
models_for_shap = sorted(retained, key=lambda x: amr[x]['mean_auc'], reverse=True)[:3]
|
| 850 |
|
| 851 |
for si, mn in enumerate(models_for_shap):
|
|
|
|
| 862 |
exp = shap.KernelExplainer(lambda x, m=mo: m.predict_proba(x), bg)
|
| 863 |
sv = exp.shap_values(Xs)
|
| 864 |
|
|
|
|
| 865 |
if isinstance(sv, list):
|
|
|
|
| 866 |
sv_abs = np.mean([np.abs(s) for s in sv], axis=0)
|
| 867 |
elif sv.ndim == 3:
|
| 868 |
+
sv_abs = np.mean(np.abs(sv), axis=2)
|
| 869 |
else:
|
| 870 |
sv_abs = np.abs(sv)
|
| 871 |
|
|
|
|
| 876 |
idf = pd.DataFrame({'Feature': fnames, 'Importance': fi}).sort_values('Importance', ascending=False)
|
| 877 |
shap_imp[mn] = idf
|
| 878 |
|
|
|
|
| 879 |
plt.figure(figsize=(10, max(6, TOPN * 0.3)))
|
| 880 |
top_df = idf.head(TOPN).iloc[::-1]
|
| 881 |
plt.barh(top_df['Feature'], top_df['Importance'], color='#2563eb', alpha=0.8)
|
|
|
|
| 889 |
except Exception as e:
|
| 890 |
log(f" ⚠ {mn} SHAP失败: {e}")
|
| 891 |
|
| 892 |
+
# ── Feature Ablation ── 【原有代码,原封不动;Excel 中新增 p 值列】
|
| 893 |
progress(0.72, desc="🧪 特征消融...")
|
| 894 |
log(f"\n 🧪 特征消融 (仅最佳模型 {best_mn})...")
|
| 895 |
ablation_data = None
|
|
|
|
| 897 |
imp_df = shap_imp[best_mn]
|
| 898 |
top_feats = imp_df.head(TOPN)['Feature'].tolist()
|
| 899 |
fcs = []; aucs_a = []
|
| 900 |
+
# 同时收集每步全量 CV holdout 概率(用于相邻步 p 值)
|
| 901 |
+
all_probas_per_step = []
|
| 902 |
|
| 903 |
for nf in range(1, len(top_feats) + 1):
|
| 904 |
Xsub = X[top_feats[:nf]]
|
| 905 |
fold_aucs = []
|
| 906 |
+
step_yt_all = []; step_yp_all = []; step_yproba_all = []
|
| 907 |
+
|
| 908 |
for tri, tei in skf.split(Xsub, y_mapped):
|
| 909 |
mf = deepcopy(mcfg[best_mn]['model'])
|
| 910 |
bp2 = bpd.get(best_mn, {})
|
|
|
|
| 919 |
a = roc_auc_score(yte_f, yproba_f, multi_class='ovr', average='macro')
|
| 920 |
except: a = 0.0
|
| 921 |
fold_aucs.append(a)
|
| 922 |
+
step_yt_all.extend(yte_f.tolist())
|
| 923 |
+
step_yp_all.extend(mf.predict(Xsub.iloc[tei].values).tolist())
|
| 924 |
+
step_yproba_all.append(yproba_f)
|
| 925 |
+
|
| 926 |
fcs.append(nf); aucs_a.append(np.mean(fold_aucs))
|
| 927 |
+
all_probas_per_step.append({
|
| 928 |
+
'yt': np.array(step_yt_all),
|
| 929 |
+
'yproba': np.vstack(step_yproba_all)
|
| 930 |
+
})
|
| 931 |
|
|
|
|
| 932 |
full_auc = amr[best_mn]['mean_auc']
|
| 933 |
opt_n = len(top_feats)
|
| 934 |
for i, a in enumerate(aucs_a):
|
| 935 |
if a >= full_auc * 0.95:
|
| 936 |
opt_n = i + 1; break
|
| 937 |
|
| 938 |
+
ablation_data = {
|
| 939 |
+
'fcs': fcs, 'aucs': aucs_a, 'feats': top_feats,
|
| 940 |
+
'opt_n': opt_n, 'opt_feats': top_feats[:opt_n]
|
| 941 |
+
}
|
| 942 |
log(f" ✅ 最优特征数: {opt_n} (AUC={aucs_a[opt_n-1]:.4f} vs Full={full_auc:.4f})")
|
| 943 |
|
| 944 |
+
# 计算相邻特征数 Bootstrap p 值(vs full-feature model)
|
| 945 |
+
ref_step = all_probas_per_step[-1] # full features
|
| 946 |
+
ablation_pvals = []
|
| 947 |
+
for si2, step in enumerate(all_probas_per_step):
|
| 948 |
+
if si2 == len(all_probas_per_step) - 1:
|
| 949 |
+
ablation_pvals.append(np.nan) # full vs full
|
| 950 |
+
continue
|
| 951 |
+
p_v, _, _, _, _ = bootstrap_auc_test(
|
| 952 |
+
ref_step['yt'], ref_step['yproba'], step['yproba'],
|
| 953 |
+
class_indices, n_bootstrap=500, seed=RS
|
| 954 |
+
)
|
| 955 |
+
ablation_pvals.append(p_v)
|
| 956 |
+
|
| 957 |
+
# Plot(原有不变)
|
| 958 |
plt.figure(figsize=(10, 7))
|
| 959 |
plt.plot(fcs, aucs_a, 'o-', color='#2563eb', lw=2, ms=5)
|
| 960 |
+
plt.scatter([opt_n], [aucs_a[opt_n-1]], s=200, marker='*',
|
| 961 |
+
color='#ef4444', edgecolors='black', lw=2, zorder=5)
|
| 962 |
+
plt.axhline(y=full_auc, color='gray', ls='--', lw=1, alpha=0.5,
|
| 963 |
+
label=f'Full AUC={full_auc:.3f}')
|
| 964 |
plt.xlabel('Number of Features', fontsize=13); plt.ylabel('Macro AUC', fontsize=13)
|
| 965 |
+
plt.title(f'Feature Ablation — {best_mn} (★ Optimal={opt_n})',
|
| 966 |
+
fontsize=14, fontweight='bold')
|
| 967 |
plt.legend(fontsize=11); plt.grid(True, alpha=0.15); plt.tight_layout()
|
| 968 |
plt.savefig(os.path.join(rf, 'ablation.pdf'), format='pdf', bbox_inches='tight')
|
| 969 |
plt.savefig(os.path.join(rf, 'ablation.png'), format='png', bbox_inches='tight', dpi=150)
|
| 970 |
plt.close()
|
| 971 |
|
| 972 |
+
# ── External Validation ── 【原有代码,原封不动;Excel 新增扩展指标】
|
| 973 |
val_files_list = [vf for vf in [val_file1, val_file2, val_file3] if vf is not None]
|
| 974 |
final_feats = ablation_data['opt_feats'] if ablation_data else fnames
|
| 975 |
|
|
|
|
| 985 |
vcol2_is_id = (vcol2.dtype == 'object') or (vcol2.nunique() / len(vcol2) > 0.5)
|
| 986 |
Xe = ed.iloc[:, 2:] if vcol2_is_id else ed.iloc[:, 1:]
|
| 987 |
|
|
|
|
| 988 |
ye = ye_raw.map(label_map)
|
| 989 |
if ye.isna().any():
|
| 990 |
log(f" ⚠ 验证集 {vi} 含有训练集中不存在的标签,已跳过")
|
|
|
|
| 1000 |
yep = fm.predict_proba(Xes.values); yed = fm.predict(Xes.values)
|
| 1001 |
ye_np = ye.values
|
| 1002 |
|
| 1003 |
+
# 原有基础指标
|
| 1004 |
metrics = compute_multiclass_metrics(ye_np, yed, yep, class_indices)
|
| 1005 |
+
# 新增扩展指标
|
| 1006 |
+
met_ext_vi, pc_vi = compute_extended_metrics_with_ci(
|
| 1007 |
+
ye_np, yed, yep, class_indices, n_bootstrap=1000, seed=RS)
|
| 1008 |
+
|
| 1009 |
+
log(f" ✅ AUC={metrics['Macro_AUC']:.4f} Acc={metrics['Accuracy']:.4f}"
|
| 1010 |
+
f" Sens={met_ext_vi['Sensitivity']:.4f} Spec={met_ext_vi['Specificity']:.4f}"
|
| 1011 |
+
f" PPV={met_ext_vi['PPV']:.4f} NPV={met_ext_vi['NPV']:.4f}"
|
| 1012 |
+
f" F1={metrics['Macro_F1']:.4f} Kappa={metrics['Kappa']:.4f}")
|
| 1013 |
|
| 1014 |
sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
|
| 1015 |
tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
|
|
|
|
| 1019 |
plot_confusion_matrix(ye_np, yed, class_indices, f'CM — {tag} ({best_mn})', f'cm{sfx}', rf)
|
| 1020 |
|
| 1021 |
with pd.ExcelWriter(os.path.join(rf, f'validation{sfx}.xlsx'), engine='openpyxl') as w:
|
| 1022 |
+
# 原有 Metrics sheet(保持兼容)
|
| 1023 |
pd.DataFrame([{'Model': best_mn, 'N_Features': len(final_feats),
|
| 1024 |
'Macro_AUC': metrics['Macro_AUC'], 'Accuracy': metrics['Accuracy'],
|
| 1025 |
'Macro_F1': metrics['Macro_F1'], 'Weighted_F1': metrics['Weighted_F1'],
|
| 1026 |
'Kappa': metrics['Kappa']}]).to_excel(w, sheet_name='Metrics', index=False)
|
| 1027 |
+
# 新增:含 Sensitivity/Specificity/PPV/NPV + 95%CI
|
| 1028 |
+
build_metrics_summary_df(met_ext_vi, best_mn, tag).to_excel(
|
| 1029 |
+
w, sheet_name='Metrics_with_CI', index=False)
|
| 1030 |
+
pc_vi.to_excel(w, sheet_name='PerClass_detail', index=False)
|
| 1031 |
rpt = pd.DataFrame(metrics['report']).T
|
| 1032 |
rpt.to_excel(w, sheet_name='Per_Class', index=True)
|
| 1033 |
pd.DataFrame({'Feature': final_feats}).to_excel(w, sheet_name='Features', index=False)
|
| 1034 |
|
| 1035 |
+
# ── Save Results ── 【原有代码,原封不动;新增扩展指标到 model_evaluation.xlsx】
|
| 1036 |
progress(0.92, desc="💾 保存结果...")
|
| 1037 |
log(f"\n 💾 保存结果...")
|
| 1038 |
|
| 1039 |
+
# 为所有模型计算 CV holdout 扩展指标(含 CI)
|
| 1040 |
+
log(f" 🔬 [新增] 计算各模型完整诊断指标 + 95%CI(Bootstrap n=1000)...")
|
| 1041 |
+
all_ext_metrics = {}
|
| 1042 |
+
all_per_class = {}
|
| 1043 |
+
for mn in mnames:
|
| 1044 |
+
r = amr[mn]
|
| 1045 |
+
ext_m, pc_m = compute_extended_metrics_with_ci(
|
| 1046 |
+
r['all_yt'], r['all_yp'], r['all_yproba'],
|
| 1047 |
+
class_indices, n_bootstrap=1000, seed=RS)
|
| 1048 |
+
all_ext_metrics[mn] = ext_m
|
| 1049 |
+
all_per_class[mn] = pc_m
|
| 1050 |
+
|
| 1051 |
with pd.ExcelWriter(os.path.join(rf, 'model_evaluation.xlsx'), engine='openpyxl') as w:
|
| 1052 |
+
# 原有:各模型分折明细
|
| 1053 |
for mn, r in amr.items():
|
| 1054 |
r['fold_df'].to_excel(w, sheet_name=mn, index=False)
|
| 1055 |
+
|
| 1056 |
+
# 原有:Summary(保持原格式不变)
|
| 1057 |
sd = [{'Model': mn, 'Macro_AUC': r['mean_auc'], 'Accuracy': r['mean_acc'],
|
| 1058 |
'Macro_F1': r['mean_f1'], 'Retained': 'Yes' if mn in retained else 'No',
|
| 1059 |
'Best': 'Best' if mn == best_mn else ''}
|
| 1060 |
for mn, r in amr.items()]
|
| 1061 |
+
pd.DataFrame(sd).sort_values('Macro_AUC', ascending=False).to_excel(
|
| 1062 |
+
w, sheet_name='Summary', index=False)
|
| 1063 |
+
|
| 1064 |
+
# 原有:Bootstrap 检验
|
| 1065 |
if len(bootstrap_df) > 0:
|
| 1066 |
bootstrap_df.to_excel(w, sheet_name='Bootstrap_Test', index=False)
|
| 1067 |
+
|
| 1068 |
+
# 原有:最佳模型 PerClass
|
| 1069 |
best_report = classification_report(amr[best_mn]['all_yt'], amr[best_mn]['all_yp'],
|
| 1070 |
labels=class_indices, output_dict=True, zero_division=0)
|
| 1071 |
pd.DataFrame(best_report).T.to_excel(w, sheet_name=f'{best_mn}_PerClass', index=True)
|
| 1072 |
|
| 1073 |
+
# ★ 新增:所有模型完整指标 + 95%CI(纵向汇总)
|
| 1074 |
+
rows_ci = []
|
| 1075 |
+
for mn in mnames:
|
| 1076 |
+
row = build_metrics_summary_df(all_ext_metrics[mn], mn, 'CV_holdout')
|
| 1077 |
+
rows_ci.append(row)
|
| 1078 |
+
pd.concat(rows_ci, ignore_index=True).to_excel(
|
| 1079 |
+
w, sheet_name='All_Models_Metrics_CI', index=False)
|
| 1080 |
+
|
| 1081 |
+
# ★ 新增:每个模型逐类详细指标
|
| 1082 |
+
for mn in mnames:
|
| 1083 |
+
sheet = f'{mn}_PerClass_detail'[:31] # Excel sheet name limit
|
| 1084 |
+
all_per_class[mn].to_excel(w, sheet_name=sheet, index=False)
|
| 1085 |
+
|
| 1086 |
+
# 特征消融 Excel(原有基础上新增 p 值列)
|
| 1087 |
if ablation_data:
|
| 1088 |
with pd.ExcelWriter(os.path.join(rf, 'feature_ablation.xlsx'), engine='openpyxl') as w:
|
| 1089 |
+
# ★ 新增:Ablation sheet 加入 p 值
|
| 1090 |
+
abl_df = pd.DataFrame({
|
| 1091 |
+
'N': ablation_data['fcs'],
|
| 1092 |
+
'AUC': ablation_data['aucs'],
|
| 1093 |
+
'P_vs_full (Bootstrap)': ablation_pvals, # NaN for full
|
| 1094 |
+
})
|
| 1095 |
+
abl_df['Significant (p<0.05)'] = abl_df['P_vs_full (Bootstrap)'].apply(
|
| 1096 |
+
lambda x: 'Yes' if (not np.isnan(x) and x < 0.05) else ('No' if not np.isnan(x) else 'Ref'))
|
| 1097 |
+
abl_df.to_excel(w, sheet_name='Ablation', index=False)
|
| 1098 |
+
|
| 1099 |
for mn, idf in shap_imp.items():
|
| 1100 |
idf.to_excel(w, sheet_name=f'{mn}_Imp', index=False)
|
| 1101 |
|
| 1102 |
+
# Save params 【原有代码,原封不动】
|
| 1103 |
with open(os.path.join(rf, 'best_params.txt'), 'w', encoding='utf-8') as f:
|
| 1104 |
f.write(f"Task: {task_type} Classification ({n_classes} classes)\n")
|
| 1105 |
f.write(f"Classes: {classes}\n")
|
|
|
|
| 1125 |
if ablation_data:
|
| 1126 |
f.write(f"\nOptimal Features ({ablation_data['opt_n']}): {', '.join(ablation_data['opt_feats'])}\n")
|
| 1127 |
|
| 1128 |
+
# Save model 【原有代码,原封不动】
|
| 1129 |
pickle.dump({
|
| 1130 |
'model_name': best_mn, 'model': tms[best_mn], 'best_params': bpd[best_mn],
|
| 1131 |
'classes': classes, 'n_classes': n_classes, 'label_map': label_map,
|