Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -431,54 +431,254 @@ def run_pipeline(
|
|
| 431 |
log(f"\n{'━'*50}")
|
| 432 |
log(f" ✅ {nm} 个模型训练完成")
|
| 433 |
|
| 434 |
-
#
|
| 435 |
-
|
| 436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
for mn in mnames:
|
| 438 |
r = amr[mn]
|
| 439 |
plot_multiclass_roc(r['all_yt'], r['all_yproba'], class_indices,
|
| 440 |
-
f'ROC — {mn} ({task_type}, Macro AUC={r["mean_auc"]:.3f})', f'roc_{mn}', rf)
|
| 441 |
|
| 442 |
-
#
|
| 443 |
plt.figure(figsize=(10, 8))
|
| 444 |
for i, mn in enumerate(mnames):
|
| 445 |
r = amr[mn]
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
plt.
|
| 456 |
-
plt.
|
| 457 |
-
plt.
|
| 458 |
-
plt.title(f'ROC — All Models ({task_type})',fontsize=14,fontweight='bold')
|
| 459 |
-
plt.legend(loc='lower right',fontsize=10); plt.grid(True,alpha=0.15); plt.tight_layout()
|
| 460 |
-
plt.savefig(os.path.join(rf,'roc_all.pdf'),format='pdf',bbox_inches='tight',dpi=300)
|
| 461 |
-
plt.savefig(os.path.join(rf,'roc_all.png'),format='png',bbox_inches='tight',dpi=150)
|
| 462 |
plt.close()
|
| 463 |
|
| 464 |
-
#
|
| 465 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
for mn in mnames:
|
| 467 |
r = amr[mn]
|
| 468 |
plot_multiclass_pr(r['all_yt'], r['all_yproba'], class_indices,
|
| 469 |
-
f'PR — {mn} ({task_type})', f'pr_{mn}', rf)
|
| 470 |
|
| 471 |
-
# ── Confusion Matrices ──
|
| 472 |
progress(0.52, desc="📊 混淆矩阵...")
|
| 473 |
for mn in mnames:
|
| 474 |
r = amr[mn]
|
| 475 |
plot_confusion_matrix(r['all_yt'], r['all_yp'], class_indices,
|
| 476 |
-
f'CM — {mn} (Acc={r["mean_acc"]:.3f})', f'cm_{mn}', rf)
|
| 477 |
|
| 478 |
-
# ── Bootstrap AUC Test
|
| 479 |
progress(0.55, desc="🔬 Bootstrap AUC 检验...")
|
| 480 |
-
best_mn = max(amr, key=lambda x: amr[x]['mean_auc'])
|
| 481 |
-
best_auc = amr[best_mn]['mean_auc']
|
| 482 |
log(f"\n 🏆 最佳模型: {best_mn} (Macro AUC={best_auc:.4f})")
|
| 483 |
log(f" 🔬 Bootstrap 检验 (n=2000, α=0.05)...")
|
| 484 |
|
|
@@ -517,7 +717,6 @@ def run_pipeline(
|
|
| 517 |
progress(0.62, desc="🔥 SHAP分析...")
|
| 518 |
log(f"\n 🔥 SHAP特征分析 (保留模型中 Top 3)...")
|
| 519 |
shap_imp = {}
|
| 520 |
-
# SHAP for top 3 retained models
|
| 521 |
models_for_shap = sorted(retained, key=lambda x: amr[x]['mean_auc'], reverse=True)[:3]
|
| 522 |
|
| 523 |
for si, mn in enumerate(models_for_shap):
|
|
@@ -534,12 +733,10 @@ def run_pipeline(
|
|
| 534 |
exp = shap.KernelExplainer(lambda x, m=mo: m.predict_proba(x), bg)
|
| 535 |
sv = exp.shap_values(Xs)
|
| 536 |
|
| 537 |
-
# Handle SHAP output: could be list of arrays (one per class) or 3D array
|
| 538 |
if isinstance(sv, list):
|
| 539 |
-
# Average absolute SHAP across all classes
|
| 540 |
sv_abs = np.mean([np.abs(s) for s in sv], axis=0)
|
| 541 |
elif sv.ndim == 3:
|
| 542 |
-
sv_abs = np.mean(np.abs(sv), axis=2)
|
| 543 |
else:
|
| 544 |
sv_abs = np.abs(sv)
|
| 545 |
|
|
@@ -550,7 +747,6 @@ def run_pipeline(
|
|
| 550 |
idf = pd.DataFrame({'Feature': fnames, 'Importance': fi}).sort_values('Importance', ascending=False)
|
| 551 |
shap_imp[mn] = idf
|
| 552 |
|
| 553 |
-
# Bar plot (works for any number of classes)
|
| 554 |
plt.figure(figsize=(10, max(6, TOPN * 0.3)))
|
| 555 |
top_df = idf.head(TOPN).iloc[::-1]
|
| 556 |
plt.barh(top_df['Feature'], top_df['Importance'], color='#2563eb', alpha=0.8)
|
|
@@ -564,7 +760,7 @@ def run_pipeline(
|
|
| 564 |
except Exception as e:
|
| 565 |
log(f" ⚠ {mn} SHAP失败: {e}")
|
| 566 |
|
| 567 |
-
# ── Feature Ablation
|
| 568 |
progress(0.72, desc="🧪 特征消融...")
|
| 569 |
log(f"\n 🧪 特征消融 (仅最佳模型 {best_mn})...")
|
| 570 |
ablation_data = None
|
|
@@ -572,7 +768,6 @@ def run_pipeline(
|
|
| 572 |
imp_df = shap_imp[best_mn]
|
| 573 |
top_feats = imp_df.head(TOPN)['Feature'].tolist()
|
| 574 |
fcs = []; aucs_a = []
|
| 575 |
-
scoring = 'roc_auc_ovr' if n_classes > 2 else 'roc_auc'
|
| 576 |
|
| 577 |
for nf in range(1, len(top_feats) + 1):
|
| 578 |
Xsub = X[top_feats[:nf]]
|
|
@@ -593,23 +788,25 @@ def run_pipeline(
|
|
| 593 |
fold_aucs.append(a)
|
| 594 |
fcs.append(nf); aucs_a.append(np.mean(fold_aucs))
|
| 595 |
|
| 596 |
-
# Find optimal: first N where AUC >= 95% of full AUC
|
| 597 |
full_auc = amr[best_mn]['mean_auc']
|
| 598 |
opt_n = len(top_feats)
|
| 599 |
for i, a in enumerate(aucs_a):
|
| 600 |
if a >= full_auc * 0.95:
|
| 601 |
opt_n = i + 1; break
|
| 602 |
|
| 603 |
-
ablation_data = {'fcs': fcs, 'aucs': aucs_a, 'feats': top_feats,
|
|
|
|
| 604 |
log(f" ✅ 最优特征数: {opt_n} (AUC={aucs_a[opt_n-1]:.4f} vs Full={full_auc:.4f})")
|
| 605 |
|
| 606 |
-
# Plot
|
| 607 |
plt.figure(figsize=(10, 7))
|
| 608 |
plt.plot(fcs, aucs_a, 'o-', color='#2563eb', lw=2, ms=5)
|
| 609 |
-
plt.scatter([opt_n], [aucs_a[opt_n-1]], s=200, marker='*',
|
| 610 |
-
|
|
|
|
|
|
|
| 611 |
plt.xlabel('Number of Features', fontsize=13); plt.ylabel('Macro AUC', fontsize=13)
|
| 612 |
-
plt.title(f'Feature Ablation — {best_mn} (★ Optimal={opt_n})',
|
|
|
|
| 613 |
plt.legend(fontsize=11); plt.grid(True, alpha=0.15); plt.tight_layout()
|
| 614 |
plt.savefig(os.path.join(rf, 'ablation.pdf'), format='pdf', bbox_inches='tight')
|
| 615 |
plt.savefig(os.path.join(rf, 'ablation.png'), format='png', bbox_inches='tight', dpi=150)
|
|
@@ -631,7 +828,6 @@ def run_pipeline(
|
|
| 631 |
vcol2_is_id = (vcol2.dtype == 'object') or (vcol2.nunique() / len(vcol2) > 0.5)
|
| 632 |
Xe = ed.iloc[:, 2:] if vcol2_is_id else ed.iloc[:, 1:]
|
| 633 |
|
| 634 |
-
# Map validation labels using same mapping
|
| 635 |
ye = ye_raw.map(label_map)
|
| 636 |
if ye.isna().any():
|
| 637 |
log(f" ⚠ 验证集 {vi} 含有训练集中不存在的标签,已跳过")
|
|
@@ -648,22 +844,25 @@ def run_pipeline(
|
|
| 648 |
ye_np = ye.values
|
| 649 |
|
| 650 |
metrics = compute_multiclass_metrics(ye_np, yed, yep, class_indices)
|
| 651 |
-
log(f" ✅ AUC={metrics['Macro_AUC']:.4f} Acc={metrics['Accuracy']:.4f}
|
|
|
|
| 652 |
|
| 653 |
sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
|
| 654 |
tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
|
| 655 |
|
| 656 |
-
plot_multiclass_roc(ye_np, yep, class_indices,
|
| 657 |
-
|
| 658 |
-
|
|
|
|
|
|
|
|
|
|
| 659 |
|
| 660 |
with pd.ExcelWriter(os.path.join(rf, f'validation{sfx}.xlsx'), engine='openpyxl') as w:
|
| 661 |
pd.DataFrame([{'Model': best_mn, 'N_Features': len(final_feats),
|
| 662 |
'Macro_AUC': metrics['Macro_AUC'], 'Accuracy': metrics['Accuracy'],
|
| 663 |
'Macro_F1': metrics['Macro_F1'], 'Weighted_F1': metrics['Weighted_F1'],
|
| 664 |
'Kappa': metrics['Kappa']}]).to_excel(w, sheet_name='Metrics', index=False)
|
| 665 |
-
|
| 666 |
-
rpt.to_excel(w, sheet_name='Per_Class', index=True)
|
| 667 |
pd.DataFrame({'Feature': final_feats}).to_excel(w, sheet_name='Features', index=False)
|
| 668 |
|
| 669 |
# ── Save Results ──
|
|
@@ -673,27 +872,32 @@ def run_pipeline(
|
|
| 673 |
with pd.ExcelWriter(os.path.join(rf, 'model_evaluation.xlsx'), engine='openpyxl') as w:
|
| 674 |
for mn, r in amr.items():
|
| 675 |
r['fold_df'].to_excel(w, sheet_name=mn, index=False)
|
| 676 |
-
# Summary
|
| 677 |
-
sd = [{'Model': mn,
|
| 678 |
-
'
|
|
|
|
|
|
|
|
|
|
|
|
|
| 679 |
'Best': 'Best' if mn == best_mn else ''}
|
| 680 |
for mn, r in amr.items()]
|
| 681 |
-
pd.DataFrame(sd).sort_values('
|
| 682 |
-
|
| 683 |
if len(bootstrap_df) > 0:
|
| 684 |
bootstrap_df.to_excel(w, sheet_name='Bootstrap_Test', index=False)
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
pd.DataFrame(best_report).T.to_excel(w, sheet_name=f'{best_mn}_PerClass', index=True)
|
| 689 |
|
| 690 |
if ablation_data:
|
| 691 |
with pd.ExcelWriter(os.path.join(rf, 'feature_ablation.xlsx'), engine='openpyxl') as w:
|
| 692 |
-
pd.DataFrame({'N': ablation_data['fcs'], 'AUC': ablation_data['aucs']}).to_excel(
|
|
|
|
| 693 |
for mn, idf in shap_imp.items():
|
| 694 |
idf.to_excel(w, sheet_name=f'{mn}_Imp', index=False)
|
| 695 |
|
| 696 |
-
# Save params
|
| 697 |
with open(os.path.join(rf, 'best_params.txt'), 'w', encoding='utf-8') as f:
|
| 698 |
f.write(f"Task: {task_type} Classification ({n_classes} classes)\n")
|
| 699 |
f.write(f"Classes: {classes}\n")
|
|
@@ -702,7 +906,8 @@ def run_pipeline(
|
|
| 702 |
f.write(f"Retained Models: {', '.join(retained)} ({len(retained)}/{nm})\n\n")
|
| 703 |
for mn in mcfg:
|
| 704 |
status = "* Best" if mn == best_mn else ("Retained" if mn in retained else "Excluded")
|
| 705 |
-
f.write(f"Model: {mn} |
|
|
|
|
| 706 |
bp = bpd[mn]
|
| 707 |
if isinstance(bp, dict):
|
| 708 |
for k, v in bp.items(): f.write(f" {k}: {v}\n")
|
|
@@ -714,10 +919,23 @@ def run_pipeline(
|
|
| 714 |
f.write("=" * 50 + "\n")
|
| 715 |
for _, row in bootstrap_df.iterrows():
|
| 716 |
f.write(f" {row['Model_A']} vs {row['Model_B']}: ")
|
| 717 |
-
f.write(f"dAUC={row['AUC_Diff']:+.4f}
|
|
|
|
| 718 |
f.write(f"P={row['P_value']:.4f} -> {row['Decision']}\n")
|
| 719 |
if ablation_data:
|
| 720 |
-
f.write(f"\nOptimal Features ({ablation_data['opt_n']}):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 721 |
|
| 722 |
# Save model
|
| 723 |
pickle.dump({
|
|
@@ -728,10 +946,13 @@ def run_pipeline(
|
|
| 728 |
|
| 729 |
# ── ZIP ──
|
| 730 |
progress(0.97, desc="📦 打包ZIP...")
|
| 731 |
-
zp = os.path.join(tempfile.gettempdir(),
|
|
|
|
| 732 |
with zipfile.ZipFile(zp, 'w', zipfile.ZIP_DEFLATED) as zf:
|
| 733 |
for root, _, files in os.walk(rf):
|
| 734 |
-
for fn in files:
|
|
|
|
|
|
|
| 735 |
|
| 736 |
nf = sum(len(f) for _, _, f in os.walk(rf))
|
| 737 |
shutil.rmtree(rf, ignore_errors=True); gc.collect()
|
|
@@ -808,6 +1029,7 @@ with gr.Blocks(
|
|
| 808 |
<div class="pipeline-box">
|
| 809 |
<strong>📋 流程:</strong>
|
| 810 |
<code>选择分类数</code> → <code>模型训练</code> → <code>交叉验证</code> →
|
|
|
|
| 811 |
<code>SHAP分析</code> → <code>特征消融</code> → <code>外部验证</code>
|
| 812 |
|
|
| 813 |
<strong>CSV格式:</strong> 第1列=标签(整数), 第2列=ID, 第3列起=特征
|
|
@@ -841,20 +1063,22 @@ with gr.Blocks(
|
|
| 841 |
info="RF=随机森林 DT=决策树 KNN=K近邻 XGB=XGBoost AdaBoost LR=逻辑回归 NB=朴素贝叶斯 SVM=支持向量机",
|
| 842 |
)
|
| 843 |
with gr.Row():
|
| 844 |
-
btn_all
|
| 845 |
-
btn_tree
|
| 846 |
btn_linear = gr.Button("📐 线性模型", size="sm", variant="secondary")
|
| 847 |
-
btn_top4
|
| 848 |
btn_all.click(lambda: ALL_MODEL_NAMES, outputs=model_selector)
|
| 849 |
btn_tree.click(lambda: ['RF','DT','XGB','AdaBoost'], outputs=model_selector)
|
| 850 |
btn_linear.click(lambda: ['LR','SVM','NB'], outputs=model_selector)
|
| 851 |
btn_top4.click(lambda: ['RF','XGB','LR','SVM'], outputs=model_selector)
|
| 852 |
|
| 853 |
gr.HTML('<div class="section-title">⚙️ 参数配置</div>')
|
| 854 |
-
enable_tuning = gr.Checkbox(
|
|
|
|
|
|
|
| 855 |
with gr.Row():
|
| 856 |
cv_folds = gr.Slider(3, 10, value=5, step=1, label="交叉验证折数")
|
| 857 |
-
top_n
|
| 858 |
shap_sz = gr.Slider(30, 200, value=80, step=10, label="SHAP 采样数量")
|
| 859 |
|
| 860 |
run_btn = gr.Button("🚀 开始分析", variant="primary", size="lg")
|
|
@@ -900,4 +1124,4 @@ def auth_fn(username, password):
|
|
| 900 |
demo.queue()
|
| 901 |
demo.launch(server_name="0.0.0.0", server_port=7860, auth=auth_fn,
|
| 902 |
auth_message="🔐 复旦大学附属眼耳鼻喉科医院 · ML多分类分析平台\n请输入账号和密码登录",
|
| 903 |
-
ssr_mode=False)
|
|
|
|
| 431 |
log(f"\n{'━'*50}")
|
| 432 |
log(f" ✅ {nm} 个模型训练完成")
|
| 433 |
|
| 434 |
+
# ============================================================
|
| 435 |
+
# ── 辅助函数:Macro ROC / PR 曲线数据 ──
|
| 436 |
+
# ============================================================
|
| 437 |
+
def _macro_roc_curve(yt, yp, nc, cls_idx):
    """Compute a macro-averaged ROC curve suitable for overlay plots.

    Every class is scored one-vs-rest, its ROC curve is interpolated onto a
    shared grid of 300 evenly spaced false-positive rates, and the per-class
    true-positive rates are averaged with equal weight.

    Args:
        yt: True labels; binarized against ``cls_idx``.
        yp: Score matrix, read column-wise as ``yp[:, c]`` for each class.
        nc: Number of classes, i.e. how many columns of ``yp`` are used.
        cls_idx: Class labels handed to ``label_binarize``.

    Returns:
        A tuple ``(fpr_grid, avg_tpr, area)`` where ``area`` is
        ``auc_score(fpr_grid, avg_tpr)``.
    """
    indicators = label_binarize(yt, classes=cls_idx)
    if nc == 2:
        # Two-class case: put the complement next to the binarized labels so
        # there is one indicator column per class (scikit-learn documents a
        # single-column result from label_binarize for binary problems).
        indicators = np.hstack([1 - indicators, indicators])

    fpr_grid = np.linspace(0, 1, 300)
    avg_tpr = np.zeros_like(fpr_grid)
    for col in range(nc):
        cls_fpr, cls_tpr, _thresholds = roc_curve(indicators[:, col], yp[:, col])
        avg_tpr += np.interp(fpr_grid, cls_fpr, cls_tpr)
    avg_tpr /= nc
    # Pin the right-hand end of the averaged curve to TPR = 1 at FPR = 1.
    avg_tpr[-1] = 1.0

    area = auc_score(fpr_grid, avg_tpr)
    return fpr_grid, avg_tpr, area
|
| 449 |
+
|
| 450 |
+
def _macro_pr_curve(yt, yp, nc, cls_idx):
    """Compute a macro-averaged precision-recall curve for overlay plots.

    Every class is scored one-vs-rest, its precision is interpolated onto a
    shared grid of 300 evenly spaced recall values, and the per-class
    precisions are averaged with equal weight.

    Args:
        yt: True labels; binarized against ``cls_idx``.
        yp: Score matrix, read column-wise as ``yp[:, c]`` for each class.
        nc: Number of classes, i.e. how many columns of ``yp`` are used.
        cls_idx: Class labels handed to ``label_binarize``.

    Returns:
        A tuple ``(recall_grid, avg_precision)`` of equal-length arrays, with
        ``recall_grid`` ascending from 0 to 1.
    """
    indicators = label_binarize(yt, classes=cls_idx)
    if nc == 2:
        # Two-class case: put the complement next to the binarized labels so
        # there is one indicator column per class (scikit-learn documents a
        # single-column result from label_binarize for binary problems).
        indicators = np.hstack([1 - indicators, indicators])

    recall_grid = np.linspace(0, 1, 300)
    avg_precision = np.zeros_like(recall_grid)
    for col in range(nc):
        cls_prec, cls_rec, _thresholds = precision_recall_curve(
            indicators[:, col], yp[:, col])
        # The curve arrays are reversed so np.interp receives recall as
        # ascending x-coordinates (scikit-learn documents recall as being
        # returned in decreasing order).
        avg_precision += np.interp(recall_grid, cls_rec[::-1], cls_prec[::-1])
    avg_precision /= nc
    return recall_grid, avg_precision
|
| 461 |
+
|
| 462 |
+
# ============================================================
|
| 463 |
+
# ── 训练集 ROC / PR(所有模型,in-sample)──
|
| 464 |
+
# ============================================================
|
| 465 |
+
progress(0.40, desc="📈 训练集ROC/PR曲线...")
|
| 466 |
+
log(f"\n 📈 绘制各模型训练集 ROC / PR 曲线...")
|
| 467 |
+
|
| 468 |
+
train_roc_summary = {} # mn -> train macro_auc
|
| 469 |
+
train_roc_data = {} # mn -> (fpr, tpr, auc)
|
| 470 |
+
train_pr_data = {} # mn -> (rec, prec)
|
| 471 |
+
|
| 472 |
+
for mn in mnames:
|
| 473 |
+
yproba_tr = tms[mn].predict_proba(X.values)
|
| 474 |
+
|
| 475 |
+
# 每个模型:各类 + macro 的独立 ROC / PR 图
|
| 476 |
+
plot_multiclass_roc(
|
| 477 |
+
y_mapped.values, yproba_tr, class_indices,
|
| 478 |
+
f'Train ROC — {mn} ({task_type})', f'train_roc_{mn}', rf
|
| 479 |
+
)
|
| 480 |
+
plot_multiclass_pr(
|
| 481 |
+
y_mapped.values, yproba_tr, class_indices,
|
| 482 |
+
f'Train PR — {mn} ({task_type})', f'train_pr_{mn}', rf
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
fpr_tr, tpr_tr, auc_tr = _macro_roc_curve(
|
| 486 |
+
y_mapped.values, yproba_tr, n_classes, class_indices)
|
| 487 |
+
rec_tr, prec_tr = _macro_pr_curve(
|
| 488 |
+
y_mapped.values, yproba_tr, n_classes, class_indices)
|
| 489 |
+
|
| 490 |
+
train_roc_data[mn] = (fpr_tr, tpr_tr, auc_tr)
|
| 491 |
+
train_pr_data[mn] = (rec_tr, prec_tr)
|
| 492 |
+
train_roc_summary[mn] = auc_tr
|
| 493 |
+
|
| 494 |
+
# 汇总训练集 ROC(所有模型叠加)
|
| 495 |
+
plt.figure(figsize=(10, 8))
|
| 496 |
+
for i, mn in enumerate(mnames):
|
| 497 |
+
fpr_tr, tpr_tr, auc_tr = train_roc_data[mn]
|
| 498 |
+
plt.plot(fpr_tr, tpr_tr, color=COLORS[i % 8], lw=2.5,
|
| 499 |
+
label=f'{mn} (Train Macro AUC={auc_tr:.3f})')
|
| 500 |
+
plt.plot([0, 1], [0, 1], '--', color='#ccc', lw=1)
|
| 501 |
+
plt.xlim([-0.02, 1.02]); plt.ylim([-0.02, 1.02])
|
| 502 |
+
plt.xlabel('False Positive Rate', fontsize=13)
|
| 503 |
+
plt.ylabel('True Positive Rate', fontsize=13)
|
| 504 |
+
plt.title(f'Train ROC — All Models ({task_type})', fontsize=14, fontweight='bold')
|
| 505 |
+
plt.legend(loc='lower right', fontsize=10)
|
| 506 |
+
plt.grid(True, alpha=0.15); plt.tight_layout()
|
| 507 |
+
plt.savefig(os.path.join(rf, 'train_roc_all.pdf'), format='pdf', bbox_inches='tight', dpi=300)
|
| 508 |
+
plt.savefig(os.path.join(rf, 'train_roc_all.png'), format='png', bbox_inches='tight', dpi=150)
|
| 509 |
+
plt.close()
|
| 510 |
+
|
| 511 |
+
# 汇总训练集 PR(所有模型叠加)
|
| 512 |
+
plt.figure(figsize=(10, 8))
|
| 513 |
+
for i, mn in enumerate(mnames):
|
| 514 |
+
rec_tr, prec_tr = train_pr_data[mn]
|
| 515 |
+
plt.plot(rec_tr, prec_tr, color=COLORS[i % 8], lw=2.5,
|
| 516 |
+
label=f'{mn} (Mean AP={prec_tr.mean():.3f})')
|
| 517 |
+
plt.xlim([-0.02, 1.02]); plt.ylim([-0.02, 1.02])
|
| 518 |
+
plt.xlabel('Recall', fontsize=13); plt.ylabel('Precision', fontsize=13)
|
| 519 |
+
plt.title(f'Train PR — All Models ({task_type})', fontsize=14, fontweight='bold')
|
| 520 |
+
plt.legend(loc='lower left', fontsize=10)
|
| 521 |
+
plt.grid(True, alpha=0.15); plt.tight_layout()
|
| 522 |
+
plt.savefig(os.path.join(rf, 'train_pr_all.pdf'), format='pdf', bbox_inches='tight', dpi=300)
|
| 523 |
+
plt.savefig(os.path.join(rf, 'train_pr_all.png'), format='png', bbox_inches='tight', dpi=150)
|
| 524 |
+
plt.close()
|
| 525 |
+
log(f" ✅ 训练集 ROC/PR 曲线已生成(各模型独立 + 汇总共 {nm*2+2*2} 张图)")
|
| 526 |
+
|
| 527 |
+
# ============================================================
|
| 528 |
+
# ── 交叉验证 ROC(原有逻辑,保留)──
|
| 529 |
+
# ============================================================
|
| 530 |
+
progress(0.42, desc="📈 交叉验证ROC曲线...")
|
| 531 |
+
log(f"\n 📈 绘制交叉验证 ROC 曲线...")
|
| 532 |
for mn in mnames:
|
| 533 |
r = amr[mn]
|
| 534 |
plot_multiclass_roc(r['all_yt'], r['all_yproba'], class_indices,
|
| 535 |
+
f'CV ROC — {mn} ({task_type}, Macro AUC={r["mean_auc"]:.3f})', f'roc_{mn}', rf)
|
| 536 |
|
| 537 |
+
# 汇总 CV ROC(所有模型)
|
| 538 |
plt.figure(figsize=(10, 8))
|
| 539 |
for i, mn in enumerate(mnames):
|
| 540 |
r = amr[mn]
|
| 541 |
+
fpr_cv, tpr_cv, auc_cv = _macro_roc_curve(
|
| 542 |
+
r['all_yt'], r['all_yproba'], n_classes, class_indices)
|
| 543 |
+
plt.plot(fpr_cv, tpr_cv, color=COLORS[i % 8], lw=2.5,
|
| 544 |
+
label=f'{mn} (CV Macro AUC={auc_cv:.3f})')
|
| 545 |
+
plt.plot([0, 1], [0, 1], '--', color='#ccc', lw=1)
|
| 546 |
+
plt.xlim([-0.02, 1.02]); plt.ylim([-0.02, 1.02])
|
| 547 |
+
plt.xlabel('FPR', fontsize=13); plt.ylabel('TPR', fontsize=13)
|
| 548 |
+
plt.title(f'CV ROC — All Models ({task_type})', fontsize=14, fontweight='bold')
|
| 549 |
+
plt.legend(loc='lower right', fontsize=10)
|
| 550 |
+
plt.grid(True, alpha=0.15); plt.tight_layout()
|
| 551 |
+
plt.savefig(os.path.join(rf, 'roc_all.pdf'), format='pdf', bbox_inches='tight', dpi=300)
|
| 552 |
+
plt.savefig(os.path.join(rf, 'roc_all.png'), format='png', bbox_inches='tight', dpi=150)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
plt.close()
|
| 554 |
|
| 555 |
+
# ============================================================
|
| 556 |
+
# ── 最佳模型:训练集 vs 内部验证集(CV holdout)对比 ──
|
| 557 |
+
# ============================================================
|
| 558 |
+
progress(0.44, desc="📊 最终模型训练集vs内部验证集对比...")
|
| 559 |
+
|
| 560 |
+
# 先确定最佳模型(后续 Bootstrap 也会用到,此处提前计算)
|
| 561 |
+
best_mn = max(amr, key=lambda x: amr[x]['mean_auc'])
|
| 562 |
+
best_auc = amr[best_mn]['mean_auc']
|
| 563 |
+
|
| 564 |
+
log(f"\n 📊 最终模型 [{best_mn}] 训练集 vs 内部验证集(CV holdout)对比...")
|
| 565 |
+
|
| 566 |
+
# 训练集预测
|
| 567 |
+
yproba_best_train = tms[best_mn].predict_proba(X.values)
|
| 568 |
+
ypred_best_train = tms[best_mn].predict(X.values)
|
| 569 |
+
metrics_train = compute_multiclass_metrics(
|
| 570 |
+
y_mapped.values, ypred_best_train, yproba_best_train, class_indices
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
# CV holdout(已在 amr 中累积)
|
| 574 |
+
yproba_best_cv = amr[best_mn]['all_yproba']
|
| 575 |
+
ypred_best_cv = amr[best_mn]['all_yp']
|
| 576 |
+
ytrue_best_cv = amr[best_mn]['all_yt']
|
| 577 |
+
metrics_cv = compute_multiclass_metrics(
|
| 578 |
+
ytrue_best_cv, ypred_best_cv, yproba_best_cv, class_indices
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
log(f" Train → AUC={metrics_train['Macro_AUC']:.4f} Acc={metrics_train['Accuracy']:.4f}"
|
| 582 |
+
f" F1={metrics_train['Macro_F1']:.4f} Kappa={metrics_train['Kappa']:.4f}")
|
| 583 |
+
log(f" CV-Val → AUC={metrics_cv['Macro_AUC']:.4f} Acc={metrics_cv['Accuracy']:.4f}"
|
| 584 |
+
f" F1={metrics_cv['Macro_F1']:.4f} Kappa={metrics_cv['Kappa']:.4f}")
|
| 585 |
+
|
| 586 |
+
# 对比 ROC
|
| 587 |
+
fpr_tr_b, tpr_tr_b, auc_tr_b = _macro_roc_curve(
|
| 588 |
+
y_mapped.values, yproba_best_train, n_classes, class_indices)
|
| 589 |
+
fpr_cv_b, tpr_cv_b, auc_cv_b = _macro_roc_curve(
|
| 590 |
+
ytrue_best_cv, yproba_best_cv, n_classes, class_indices)
|
| 591 |
+
|
| 592 |
+
fig, ax = plt.subplots(figsize=(10, 8))
|
| 593 |
+
ax.plot(fpr_tr_b, tpr_tr_b, color='#e41a1c', lw=2.5,
|
| 594 |
+
label=f'Train set (Macro AUC={auc_tr_b:.3f})')
|
| 595 |
+
ax.plot(fpr_cv_b, tpr_cv_b, color='#377eb8', lw=2.5, linestyle='--',
|
| 596 |
+
label=f'Internal CV (Macro AUC={auc_cv_b:.3f})')
|
| 597 |
+
ax.plot([0, 1], [0, 1], '--', color='#ccc', lw=1)
|
| 598 |
+
ax.set_xlim([-0.02, 1.02]); ax.set_ylim([-0.02, 1.02])
|
| 599 |
+
ax.set_xlabel('False Positive Rate', fontsize=13)
|
| 600 |
+
ax.set_ylabel('True Positive Rate', fontsize=13)
|
| 601 |
+
ax.set_title(f'ROC — {best_mn}: Train vs Internal CV ({task_type})',
|
| 602 |
+
fontsize=14, fontweight='bold')
|
| 603 |
+
ax.legend(loc='lower right', fontsize=11)
|
| 604 |
+
ax.grid(True, alpha=0.15); plt.tight_layout()
|
| 605 |
+
plt.savefig(os.path.join(rf, f'roc_train_vs_cv_{best_mn}.pdf'),
|
| 606 |
+
format='pdf', bbox_inches='tight', dpi=300)
|
| 607 |
+
plt.savefig(os.path.join(rf, f'roc_train_vs_cv_{best_mn}.png'),
|
| 608 |
+
format='png', bbox_inches='tight', dpi=150)
|
| 609 |
+
plt.close()
|
| 610 |
+
|
| 611 |
+
# 对比 PR
|
| 612 |
+
rec_tr_b, prec_tr_b = _macro_pr_curve(
|
| 613 |
+
y_mapped.values, yproba_best_train, n_classes, class_indices)
|
| 614 |
+
rec_cv_b, prec_cv_b = _macro_pr_curve(
|
| 615 |
+
ytrue_best_cv, yproba_best_cv, n_classes, class_indices)
|
| 616 |
+
|
| 617 |
+
fig, ax = plt.subplots(figsize=(10, 8))
|
| 618 |
+
ax.plot(rec_tr_b, prec_tr_b, color='#e41a1c', lw=2.5,
|
| 619 |
+
label=f'Train set (Mean AP={prec_tr_b.mean():.3f})')
|
| 620 |
+
ax.plot(rec_cv_b, prec_cv_b, color='#377eb8', lw=2.5, linestyle='--',
|
| 621 |
+
label=f'Internal CV (Mean AP={prec_cv_b.mean():.3f})')
|
| 622 |
+
ax.set_xlim([-0.02, 1.02]); ax.set_ylim([-0.02, 1.02])
|
| 623 |
+
ax.set_xlabel('Recall', fontsize=13)
|
| 624 |
+
ax.set_ylabel('Precision', fontsize=13)
|
| 625 |
+
ax.set_title(f'PR — {best_mn}: Train vs Internal CV ({task_type})',
|
| 626 |
+
fontsize=14, fontweight='bold')
|
| 627 |
+
ax.legend(loc='lower left', fontsize=11)
|
| 628 |
+
ax.grid(True, alpha=0.15); plt.tight_layout()
|
| 629 |
+
plt.savefig(os.path.join(rf, f'pr_train_vs_cv_{best_mn}.pdf'),
|
| 630 |
+
format='pdf', bbox_inches='tight', dpi=300)
|
| 631 |
+
plt.savefig(os.path.join(rf, f'pr_train_vs_cv_{best_mn}.png'),
|
| 632 |
+
format='png', bbox_inches='tight', dpi=150)
|
| 633 |
+
plt.close()
|
| 634 |
+
|
| 635 |
+
# 训练集混淆矩阵(最佳模型)
|
| 636 |
+
plot_confusion_matrix(
|
| 637 |
+
y_mapped.values, ypred_best_train, class_indices,
|
| 638 |
+
f'Train CM — {best_mn} (Acc={metrics_train["Accuracy"]:.3f})',
|
| 639 |
+
f'cm_train_{best_mn}', rf
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
# 保存 Train vs CV 汇总 Excel
|
| 643 |
+
with pd.ExcelWriter(os.path.join(rf, f'train_vs_cv_{best_mn}.xlsx'),
|
| 644 |
+
engine='openpyxl') as w:
|
| 645 |
+
summary_rows = [
|
| 646 |
+
{'Split': 'Train', 'Model': best_mn,
|
| 647 |
+
'Macro_AUC': metrics_train['Macro_AUC'],
|
| 648 |
+
'Accuracy': metrics_train['Accuracy'],
|
| 649 |
+
'Macro_F1': metrics_train['Macro_F1'],
|
| 650 |
+
'Weighted_F1': metrics_train['Weighted_F1'],
|
| 651 |
+
'Kappa': metrics_train['Kappa']},
|
| 652 |
+
{'Split': 'Internal_CV', 'Model': best_mn,
|
| 653 |
+
'Macro_AUC': metrics_cv['Macro_AUC'],
|
| 654 |
+
'Accuracy': metrics_cv['Accuracy'],
|
| 655 |
+
'Macro_F1': metrics_cv['Macro_F1'],
|
| 656 |
+
'Weighted_F1': metrics_cv['Weighted_F1'],
|
| 657 |
+
'Kappa': metrics_cv['Kappa']},
|
| 658 |
+
]
|
| 659 |
+
pd.DataFrame(summary_rows).to_excel(w, sheet_name='Summary', index=False)
|
| 660 |
+
pd.DataFrame(metrics_train['report']).T.to_excel(w, sheet_name='Train_PerClass', index=True)
|
| 661 |
+
pd.DataFrame(metrics_cv['report']).T.to_excel(w, sheet_name='CV_PerClass', index=True)
|
| 662 |
+
amr[best_mn]['fold_df'].to_excel(w, sheet_name='CV_FoldDetail', index=False)
|
| 663 |
+
|
| 664 |
+
log(f" ✅ Train vs CV 对比图及数据已保存 → train_vs_cv_{best_mn}.xlsx")
|
| 665 |
+
|
| 666 |
+
# ── PR Curves (CV,原有逻辑) ──
|
| 667 |
+
progress(0.48, desc="📈 交叉验证PR曲线...")
|
| 668 |
for mn in mnames:
|
| 669 |
r = amr[mn]
|
| 670 |
plot_multiclass_pr(r['all_yt'], r['all_yproba'], class_indices,
|
| 671 |
+
f'CV PR — {mn} ({task_type})', f'pr_{mn}', rf)
|
| 672 |
|
| 673 |
+
# ── Confusion Matrices (CV) ──
|
| 674 |
progress(0.52, desc="📊 混淆矩阵...")
|
| 675 |
for mn in mnames:
|
| 676 |
r = amr[mn]
|
| 677 |
plot_confusion_matrix(r['all_yt'], r['all_yp'], class_indices,
|
| 678 |
+
f'CV CM — {mn} (Acc={r["mean_acc"]:.3f})', f'cm_{mn}', rf)
|
| 679 |
|
| 680 |
+
# ── Bootstrap AUC Test ──
|
| 681 |
progress(0.55, desc="🔬 Bootstrap AUC 检验...")
|
|
|
|
|
|
|
| 682 |
log(f"\n 🏆 最佳模型: {best_mn} (Macro AUC={best_auc:.4f})")
|
| 683 |
log(f" 🔬 Bootstrap 检验 (n=2000, α=0.05)...")
|
| 684 |
|
|
|
|
| 717 |
progress(0.62, desc="🔥 SHAP分析...")
|
| 718 |
log(f"\n 🔥 SHAP特征分析 (保留模型中 Top 3)...")
|
| 719 |
shap_imp = {}
|
|
|
|
| 720 |
models_for_shap = sorted(retained, key=lambda x: amr[x]['mean_auc'], reverse=True)[:3]
|
| 721 |
|
| 722 |
for si, mn in enumerate(models_for_shap):
|
|
|
|
| 733 |
exp = shap.KernelExplainer(lambda x, m=mo: m.predict_proba(x), bg)
|
| 734 |
sv = exp.shap_values(Xs)
|
| 735 |
|
|
|
|
| 736 |
if isinstance(sv, list):
|
|
|
|
| 737 |
sv_abs = np.mean([np.abs(s) for s in sv], axis=0)
|
| 738 |
elif sv.ndim == 3:
|
| 739 |
+
sv_abs = np.mean(np.abs(sv), axis=2)
|
| 740 |
else:
|
| 741 |
sv_abs = np.abs(sv)
|
| 742 |
|
|
|
|
| 747 |
idf = pd.DataFrame({'Feature': fnames, 'Importance': fi}).sort_values('Importance', ascending=False)
|
| 748 |
shap_imp[mn] = idf
|
| 749 |
|
|
|
|
| 750 |
plt.figure(figsize=(10, max(6, TOPN * 0.3)))
|
| 751 |
top_df = idf.head(TOPN).iloc[::-1]
|
| 752 |
plt.barh(top_df['Feature'], top_df['Importance'], color='#2563eb', alpha=0.8)
|
|
|
|
| 760 |
except Exception as e:
|
| 761 |
log(f" ⚠ {mn} SHAP失败: {e}")
|
| 762 |
|
| 763 |
+
# ── Feature Ablation ──
|
| 764 |
progress(0.72, desc="🧪 特征消融...")
|
| 765 |
log(f"\n 🧪 特征消融 (仅最佳模型 {best_mn})...")
|
| 766 |
ablation_data = None
|
|
|
|
| 768 |
imp_df = shap_imp[best_mn]
|
| 769 |
top_feats = imp_df.head(TOPN)['Feature'].tolist()
|
| 770 |
fcs = []; aucs_a = []
|
|
|
|
| 771 |
|
| 772 |
for nf in range(1, len(top_feats) + 1):
|
| 773 |
Xsub = X[top_feats[:nf]]
|
|
|
|
| 788 |
fold_aucs.append(a)
|
| 789 |
fcs.append(nf); aucs_a.append(np.mean(fold_aucs))
|
| 790 |
|
|
|
|
| 791 |
full_auc = amr[best_mn]['mean_auc']
|
| 792 |
opt_n = len(top_feats)
|
| 793 |
for i, a in enumerate(aucs_a):
|
| 794 |
if a >= full_auc * 0.95:
|
| 795 |
opt_n = i + 1; break
|
| 796 |
|
| 797 |
+
ablation_data = {'fcs': fcs, 'aucs': aucs_a, 'feats': top_feats,
|
| 798 |
+
'opt_n': opt_n, 'opt_feats': top_feats[:opt_n]}
|
| 799 |
log(f" ✅ 最优特征数: {opt_n} (AUC={aucs_a[opt_n-1]:.4f} vs Full={full_auc:.4f})")
|
| 800 |
|
|
|
|
| 801 |
plt.figure(figsize=(10, 7))
|
| 802 |
plt.plot(fcs, aucs_a, 'o-', color='#2563eb', lw=2, ms=5)
|
| 803 |
+
plt.scatter([opt_n], [aucs_a[opt_n-1]], s=200, marker='*',
|
| 804 |
+
color='#ef4444', edgecolors='black', lw=2, zorder=5)
|
| 805 |
+
plt.axhline(y=full_auc, color='gray', ls='--', lw=1, alpha=0.5,
|
| 806 |
+
label=f'Full AUC={full_auc:.3f}')
|
| 807 |
plt.xlabel('Number of Features', fontsize=13); plt.ylabel('Macro AUC', fontsize=13)
|
| 808 |
+
plt.title(f'Feature Ablation — {best_mn} (★ Optimal={opt_n})',
|
| 809 |
+
fontsize=14, fontweight='bold')
|
| 810 |
plt.legend(fontsize=11); plt.grid(True, alpha=0.15); plt.tight_layout()
|
| 811 |
plt.savefig(os.path.join(rf, 'ablation.pdf'), format='pdf', bbox_inches='tight')
|
| 812 |
plt.savefig(os.path.join(rf, 'ablation.png'), format='png', bbox_inches='tight', dpi=150)
|
|
|
|
| 828 |
vcol2_is_id = (vcol2.dtype == 'object') or (vcol2.nunique() / len(vcol2) > 0.5)
|
| 829 |
Xe = ed.iloc[:, 2:] if vcol2_is_id else ed.iloc[:, 1:]
|
| 830 |
|
|
|
|
| 831 |
ye = ye_raw.map(label_map)
|
| 832 |
if ye.isna().any():
|
| 833 |
log(f" ⚠ 验证集 {vi} 含有训练集中不存在的标签,已跳过")
|
|
|
|
| 844 |
ye_np = ye.values
|
| 845 |
|
| 846 |
metrics = compute_multiclass_metrics(ye_np, yed, yep, class_indices)
|
| 847 |
+
log(f" ✅ AUC={metrics['Macro_AUC']:.4f} Acc={metrics['Accuracy']:.4f}"
|
| 848 |
+
f" F1={metrics['Macro_F1']:.4f} Kappa={metrics['Kappa']:.4f}")
|
| 849 |
|
| 850 |
sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
|
| 851 |
tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
|
| 852 |
|
| 853 |
+
plot_multiclass_roc(ye_np, yep, class_indices,
|
| 854 |
+
f'ROC — {tag} ({best_mn})', f'roc{sfx}', rf)
|
| 855 |
+
plot_multiclass_pr(ye_np, yep, class_indices,
|
| 856 |
+
f'PR — {tag} ({best_mn})', f'pr{sfx}', rf)
|
| 857 |
+
plot_confusion_matrix(ye_np, yed, class_indices,
|
| 858 |
+
f'CM — {tag} ({best_mn})', f'cm{sfx}', rf)
|
| 859 |
|
| 860 |
with pd.ExcelWriter(os.path.join(rf, f'validation{sfx}.xlsx'), engine='openpyxl') as w:
|
| 861 |
pd.DataFrame([{'Model': best_mn, 'N_Features': len(final_feats),
|
| 862 |
'Macro_AUC': metrics['Macro_AUC'], 'Accuracy': metrics['Accuracy'],
|
| 863 |
'Macro_F1': metrics['Macro_F1'], 'Weighted_F1': metrics['Weighted_F1'],
|
| 864 |
'Kappa': metrics['Kappa']}]).to_excel(w, sheet_name='Metrics', index=False)
|
| 865 |
+
pd.DataFrame(metrics['report']).T.to_excel(w, sheet_name='Per_Class', index=True)
|
|
|
|
| 866 |
pd.DataFrame({'Feature': final_feats}).to_excel(w, sheet_name='Features', index=False)
|
| 867 |
|
| 868 |
# ── Save Results ──
|
|
|
|
| 872 |
with pd.ExcelWriter(os.path.join(rf, 'model_evaluation.xlsx'), engine='openpyxl') as w:
|
| 873 |
for mn, r in amr.items():
|
| 874 |
r['fold_df'].to_excel(w, sheet_name=mn, index=False)
|
| 875 |
+
# Summary(新增 Train_AUC 列)
|
| 876 |
+
sd = [{'Model': mn,
|
| 877 |
+
'CV_Macro_AUC': r['mean_auc'],
|
| 878 |
+
'Train_Macro_AUC': train_roc_summary.get(mn, ''),
|
| 879 |
+
'CV_Accuracy': r['mean_acc'],
|
| 880 |
+
'CV_Macro_F1': r['mean_f1'],
|
| 881 |
+
'Retained': 'Yes' if mn in retained else 'No',
|
| 882 |
'Best': 'Best' if mn == best_mn else ''}
|
| 883 |
for mn, r in amr.items()]
|
| 884 |
+
pd.DataFrame(sd).sort_values('CV_Macro_AUC', ascending=False).to_excel(
|
| 885 |
+
w, sheet_name='Summary', index=False)
|
| 886 |
if len(bootstrap_df) > 0:
|
| 887 |
bootstrap_df.to_excel(w, sheet_name='Bootstrap_Test', index=False)
|
| 888 |
+
best_report = classification_report(
|
| 889 |
+
amr[best_mn]['all_yt'], amr[best_mn]['all_yp'],
|
| 890 |
+
labels=class_indices, output_dict=True, zero_division=0)
|
| 891 |
pd.DataFrame(best_report).T.to_excel(w, sheet_name=f'{best_mn}_PerClass', index=True)
|
| 892 |
|
| 893 |
if ablation_data:
|
| 894 |
with pd.ExcelWriter(os.path.join(rf, 'feature_ablation.xlsx'), engine='openpyxl') as w:
|
| 895 |
+
pd.DataFrame({'N': ablation_data['fcs'], 'AUC': ablation_data['aucs']}).to_excel(
|
| 896 |
+
w, sheet_name='Ablation', index=False)
|
| 897 |
for mn, idf in shap_imp.items():
|
| 898 |
idf.to_excel(w, sheet_name=f'{mn}_Imp', index=False)
|
| 899 |
|
| 900 |
+
# Save params
|
| 901 |
with open(os.path.join(rf, 'best_params.txt'), 'w', encoding='utf-8') as f:
|
| 902 |
f.write(f"Task: {task_type} Classification ({n_classes} classes)\n")
|
| 903 |
f.write(f"Classes: {classes}\n")
|
|
|
|
| 906 |
f.write(f"Retained Models: {', '.join(retained)} ({len(retained)}/{nm})\n\n")
|
| 907 |
for mn in mcfg:
|
| 908 |
status = "* Best" if mn == best_mn else ("Retained" if mn in retained else "Excluded")
|
| 909 |
+
f.write(f"Model: {mn} | CV_AUC={amr[mn]['mean_auc']:.4f}"
|
| 910 |
+
f" | Train_AUC={train_roc_summary.get(mn, 'N/A')} | {status}\n")
|
| 911 |
bp = bpd[mn]
|
| 912 |
if isinstance(bp, dict):
|
| 913 |
for k, v in bp.items(): f.write(f" {k}: {v}\n")
|
|
|
|
| 919 |
f.write("=" * 50 + "\n")
|
| 920 |
for _, row in bootstrap_df.iterrows():
|
| 921 |
f.write(f" {row['Model_A']} vs {row['Model_B']}: ")
|
| 922 |
+
f.write(f"dAUC={row['AUC_Diff']:+.4f} "
|
| 923 |
+
f"95%CI=[{row['CI_95_Low']:+.4f},{row['CI_95_High']:+.4f}] ")
|
| 924 |
f.write(f"P={row['P_value']:.4f} -> {row['Decision']}\n")
|
| 925 |
if ablation_data:
|
| 926 |
+
f.write(f"\nOptimal Features ({ablation_data['opt_n']}): "
|
| 927 |
+
f"{', '.join(ablation_data['opt_feats'])}\n")
|
| 928 |
+
f.write(f"\n{'='*50}\n")
|
| 929 |
+
f.write(f"Best Model [{best_mn}] Train vs Internal CV\n")
|
| 930 |
+
f.write(f"{'='*50}\n")
|
| 931 |
+
f.write(f" Train → AUC={metrics_train['Macro_AUC']:.4f}"
|
| 932 |
+
f" Acc={metrics_train['Accuracy']:.4f}"
|
| 933 |
+
f" F1={metrics_train['Macro_F1']:.4f}"
|
| 934 |
+
f" Kappa={metrics_train['Kappa']:.4f}\n")
|
| 935 |
+
f.write(f" CV-Val → AUC={metrics_cv['Macro_AUC']:.4f}"
|
| 936 |
+
f" Acc={metrics_cv['Accuracy']:.4f}"
|
| 937 |
+
f" F1={metrics_cv['Macro_F1']:.4f}"
|
| 938 |
+
f" Kappa={metrics_cv['Kappa']:.4f}\n")
|
| 939 |
|
| 940 |
# Save model
|
| 941 |
pickle.dump({
|
|
|
|
| 946 |
|
| 947 |
# ── ZIP ──
|
| 948 |
progress(0.97, desc="📦 打包ZIP...")
|
| 949 |
+
zp = os.path.join(tempfile.gettempdir(),
|
| 950 |
+
f"ml_results_{int(time.time())}_{os.getpid()}.zip")
|
| 951 |
with zipfile.ZipFile(zp, 'w', zipfile.ZIP_DEFLATED) as zf:
|
| 952 |
for root, _, files in os.walk(rf):
|
| 953 |
+
for fn in files:
|
| 954 |
+
zf.write(os.path.join(root, fn),
|
| 955 |
+
os.path.relpath(os.path.join(root, fn), rf))
|
| 956 |
|
| 957 |
nf = sum(len(f) for _, _, f in os.walk(rf))
|
| 958 |
shutil.rmtree(rf, ignore_errors=True); gc.collect()
|
|
|
|
| 1029 |
<div class="pipeline-box">
|
| 1030 |
<strong>📋 流程:</strong>
|
| 1031 |
<code>选择分类数</code> → <code>模型训练</code> → <code>交叉验证</code> →
|
| 1032 |
+
<code>训练集ROC/PR</code> → <code>Train vs CV对比</code> →
|
| 1033 |
<code>SHAP分析</code> → <code>特征消融</code> → <code>外部验证</code>
|
| 1034 |
|
|
| 1035 |
<strong>CSV格式:</strong> 第1列=标签(整数), 第2列=ID, 第3列起=特征
|
|
|
|
| 1063 |
info="RF=随机森林 DT=决策树 KNN=K近邻 XGB=XGBoost AdaBoost LR=逻辑回归 NB=朴素贝叶斯 SVM=支持向量机",
|
| 1064 |
)
|
| 1065 |
with gr.Row():
|
| 1066 |
+
btn_all = gr.Button("🔘 全选", size="sm", variant="secondary")
|
| 1067 |
+
btn_tree = gr.Button("🌲 树模型", size="sm", variant="secondary")
|
| 1068 |
btn_linear = gr.Button("📐 线性模型", size="sm", variant="secondary")
|
| 1069 |
+
btn_top4 = gr.Button("⚡ 经典四模型", size="sm", variant="secondary")
|
| 1070 |
btn_all.click(lambda: ALL_MODEL_NAMES, outputs=model_selector)
|
| 1071 |
btn_tree.click(lambda: ['RF','DT','XGB','AdaBoost'], outputs=model_selector)
|
| 1072 |
btn_linear.click(lambda: ['LR','SVM','NB'], outputs=model_selector)
|
| 1073 |
btn_top4.click(lambda: ['RF','XGB','LR','SVM'], outputs=model_selector)
|
| 1074 |
|
| 1075 |
gr.HTML('<div class="section-title">⚙️ 参数配置</div>')
|
| 1076 |
+
enable_tuning = gr.Checkbox(
|
| 1077 |
+
value=False,
|
| 1078 |
+
label="启用超参数调优 (GridSearchCV) ⚠️ 开启后运行时间显著增加")
|
| 1079 |
with gr.Row():
|
| 1080 |
cv_folds = gr.Slider(3, 10, value=5, step=1, label="交叉验证折数")
|
| 1081 |
+
top_n = gr.Slider(5, 50, value=20, step=1, label="SHAP 前 N 个特征")
|
| 1082 |
shap_sz = gr.Slider(30, 200, value=80, step=10, label="SHAP 采样数量")
|
| 1083 |
|
| 1084 |
run_btn = gr.Button("🚀 开始分析", variant="primary", size="lg")
|
|
|
|
| 1124 |
demo.queue()
|
| 1125 |
demo.launch(server_name="0.0.0.0", server_port=7860, auth=auth_fn,
|
| 1126 |
auth_message="🔐 复旦大学附属眼耳鼻喉科医院 · ML多分类分析平台\n请输入账号和密码登录",
|
| 1127 |
+
ssr_mode=False)
|