fudan-renjun commited on
Commit
27e1310
·
verified ·
1 Parent(s): e195614

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +294 -70
app.py CHANGED
@@ -431,54 +431,254 @@ def run_pipeline(
431
  log(f"\n{'━'*50}")
432
  log(f" ✅ {nm} 个模型训练完成")
433
 
434
- # ── ROC Curves ──
435
- progress(0.42, desc="📈 ROC曲线...")
436
- log(f"\n 📈 绘制图表...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
  for mn in mnames:
438
  r = amr[mn]
439
  plot_multiclass_roc(r['all_yt'], r['all_yproba'], class_indices,
440
- f'ROC — {mn} ({task_type}, Macro AUC={r["mean_auc"]:.3f})', f'roc_{mn}', rf)
441
 
442
- # Combined ROC (macro per model)
443
  plt.figure(figsize=(10, 8))
444
  for i, mn in enumerate(mnames):
445
  r = amr[mn]
446
- y_bin = label_binarize(r['all_yt'], classes=class_indices)
447
- if n_classes == 2: y_bin = np.hstack([1 - y_bin, y_bin])
448
- all_fpr = np.linspace(0, 1, 200); mean_tpr = np.zeros_like(all_fpr)
449
- for c in range(n_classes):
450
- f, t, _ = roc_curve(y_bin[:, c], r['all_yproba'][:, c])
451
- mean_tpr += np.interp(all_fpr, f, t)
452
- mean_tpr /= n_classes; mean_tpr[-1] = 1.0
453
- ma = auc_score(all_fpr, mean_tpr)
454
- plt.plot(all_fpr, mean_tpr, color=COLORS[i%8], lw=2.5, label=f'{mn} (Macro AUC={ma:.3f})')
455
- plt.plot([0,1],[0,1],'--',color='#ccc',lw=1)
456
- plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
457
- plt.xlabel('FPR',fontsize=13); plt.ylabel('TPR',fontsize=13)
458
- plt.title(f'ROC — All Models ({task_type})',fontsize=14,fontweight='bold')
459
- plt.legend(loc='lower right',fontsize=10); plt.grid(True,alpha=0.15); plt.tight_layout()
460
- plt.savefig(os.path.join(rf,'roc_all.pdf'),format='pdf',bbox_inches='tight',dpi=300)
461
- plt.savefig(os.path.join(rf,'roc_all.png'),format='png',bbox_inches='tight',dpi=150)
462
  plt.close()
463
 
464
- # ── PR Curves ──
465
- progress(0.48, desc="📈 PR曲线...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  for mn in mnames:
467
  r = amr[mn]
468
  plot_multiclass_pr(r['all_yt'], r['all_yproba'], class_indices,
469
- f'PR — {mn} ({task_type})', f'pr_{mn}', rf)
470
 
471
- # ── Confusion Matrices ──
472
  progress(0.52, desc="📊 混淆矩阵...")
473
  for mn in mnames:
474
  r = amr[mn]
475
  plot_confusion_matrix(r['all_yt'], r['all_yp'], class_indices,
476
- f'CM — {mn} (Acc={r["mean_acc"]:.3f})', f'cm_{mn}', rf)
477
 
478
- # ── Bootstrap AUC Test (模型统计检验) ──
479
  progress(0.55, desc="🔬 Bootstrap AUC 检验...")
480
- best_mn = max(amr, key=lambda x: amr[x]['mean_auc'])
481
- best_auc = amr[best_mn]['mean_auc']
482
  log(f"\n 🏆 最佳模型: {best_mn} (Macro AUC={best_auc:.4f})")
483
  log(f" 🔬 Bootstrap 检验 (n=2000, α=0.05)...")
484
 
@@ -517,7 +717,6 @@ def run_pipeline(
517
  progress(0.62, desc="🔥 SHAP分析...")
518
  log(f"\n 🔥 SHAP特征分析 (保留模型中 Top 3)...")
519
  shap_imp = {}
520
- # SHAP for top 3 retained models
521
  models_for_shap = sorted(retained, key=lambda x: amr[x]['mean_auc'], reverse=True)[:3]
522
 
523
  for si, mn in enumerate(models_for_shap):
@@ -534,12 +733,10 @@ def run_pipeline(
534
  exp = shap.KernelExplainer(lambda x, m=mo: m.predict_proba(x), bg)
535
  sv = exp.shap_values(Xs)
536
 
537
- # Handle SHAP output: could be list of arrays (one per class) or 3D array
538
  if isinstance(sv, list):
539
- # Average absolute SHAP across all classes
540
  sv_abs = np.mean([np.abs(s) for s in sv], axis=0)
541
  elif sv.ndim == 3:
542
- sv_abs = np.mean(np.abs(sv), axis=2) # (samples, features)
543
  else:
544
  sv_abs = np.abs(sv)
545
 
@@ -550,7 +747,6 @@ def run_pipeline(
550
  idf = pd.DataFrame({'Feature': fnames, 'Importance': fi}).sort_values('Importance', ascending=False)
551
  shap_imp[mn] = idf
552
 
553
- # Bar plot (works for any number of classes)
554
  plt.figure(figsize=(10, max(6, TOPN * 0.3)))
555
  top_df = idf.head(TOPN).iloc[::-1]
556
  plt.barh(top_df['Feature'], top_df['Importance'], color='#2563eb', alpha=0.8)
@@ -564,7 +760,7 @@ def run_pipeline(
564
  except Exception as e:
565
  log(f" ⚠ {mn} SHAP失败: {e}")
566
 
567
- # ── Feature Ablation (for best model only) ──
568
  progress(0.72, desc="🧪 特征消融...")
569
  log(f"\n 🧪 特征消融 (仅最佳模型 {best_mn})...")
570
  ablation_data = None
@@ -572,7 +768,6 @@ def run_pipeline(
572
  imp_df = shap_imp[best_mn]
573
  top_feats = imp_df.head(TOPN)['Feature'].tolist()
574
  fcs = []; aucs_a = []
575
- scoring = 'roc_auc_ovr' if n_classes > 2 else 'roc_auc'
576
 
577
  for nf in range(1, len(top_feats) + 1):
578
  Xsub = X[top_feats[:nf]]
@@ -593,23 +788,25 @@ def run_pipeline(
593
  fold_aucs.append(a)
594
  fcs.append(nf); aucs_a.append(np.mean(fold_aucs))
595
 
596
- # Find optimal: first N where AUC >= 95% of full AUC
597
  full_auc = amr[best_mn]['mean_auc']
598
  opt_n = len(top_feats)
599
  for i, a in enumerate(aucs_a):
600
  if a >= full_auc * 0.95:
601
  opt_n = i + 1; break
602
 
603
- ablation_data = {'fcs': fcs, 'aucs': aucs_a, 'feats': top_feats, 'opt_n': opt_n, 'opt_feats': top_feats[:opt_n]}
 
604
  log(f" ✅ 最优特征数: {opt_n} (AUC={aucs_a[opt_n-1]:.4f} vs Full={full_auc:.4f})")
605
 
606
- # Plot
607
  plt.figure(figsize=(10, 7))
608
  plt.plot(fcs, aucs_a, 'o-', color='#2563eb', lw=2, ms=5)
609
- plt.scatter([opt_n], [aucs_a[opt_n-1]], s=200, marker='*', color='#ef4444', edgecolors='black', lw=2, zorder=5)
610
- plt.axhline(y=full_auc, color='gray', ls='--', lw=1, alpha=0.5, label=f'Full AUC={full_auc:.3f}')
 
 
611
  plt.xlabel('Number of Features', fontsize=13); plt.ylabel('Macro AUC', fontsize=13)
612
- plt.title(f'Feature Ablation — {best_mn} (★ Optimal={opt_n})', fontsize=14, fontweight='bold')
 
613
  plt.legend(fontsize=11); plt.grid(True, alpha=0.15); plt.tight_layout()
614
  plt.savefig(os.path.join(rf, 'ablation.pdf'), format='pdf', bbox_inches='tight')
615
  plt.savefig(os.path.join(rf, 'ablation.png'), format='png', bbox_inches='tight', dpi=150)
@@ -631,7 +828,6 @@ def run_pipeline(
631
  vcol2_is_id = (vcol2.dtype == 'object') or (vcol2.nunique() / len(vcol2) > 0.5)
632
  Xe = ed.iloc[:, 2:] if vcol2_is_id else ed.iloc[:, 1:]
633
 
634
- # Map validation labels using same mapping
635
  ye = ye_raw.map(label_map)
636
  if ye.isna().any():
637
  log(f" ⚠ 验证集 {vi} 含有训练集中不存在的标签,已跳过")
@@ -648,22 +844,25 @@ def run_pipeline(
648
  ye_np = ye.values
649
 
650
  metrics = compute_multiclass_metrics(ye_np, yed, yep, class_indices)
651
- log(f" ✅ AUC={metrics['Macro_AUC']:.4f} Acc={metrics['Accuracy']:.4f} F1={metrics['Macro_F1']:.4f} Kappa={metrics['Kappa']:.4f}")
 
652
 
653
  sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
654
  tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
655
 
656
- plot_multiclass_roc(ye_np, yep, class_indices, f'ROC — {tag} ({best_mn})', f'roc{sfx}', rf)
657
- plot_multiclass_pr(ye_np, yep, class_indices, f'PR — {tag} ({best_mn})', f'pr{sfx}', rf)
658
- plot_confusion_matrix(ye_np, yed, class_indices, f'CM — {tag} ({best_mn})', f'cm{sfx}', rf)
 
 
 
659
 
660
  with pd.ExcelWriter(os.path.join(rf, f'validation{sfx}.xlsx'), engine='openpyxl') as w:
661
  pd.DataFrame([{'Model': best_mn, 'N_Features': len(final_feats),
662
  'Macro_AUC': metrics['Macro_AUC'], 'Accuracy': metrics['Accuracy'],
663
  'Macro_F1': metrics['Macro_F1'], 'Weighted_F1': metrics['Weighted_F1'],
664
  'Kappa': metrics['Kappa']}]).to_excel(w, sheet_name='Metrics', index=False)
665
- rpt = pd.DataFrame(metrics['report']).T
666
- rpt.to_excel(w, sheet_name='Per_Class', index=True)
667
  pd.DataFrame({'Feature': final_feats}).to_excel(w, sheet_name='Features', index=False)
668
 
669
  # ── Save Results ──
@@ -673,27 +872,32 @@ def run_pipeline(
673
  with pd.ExcelWriter(os.path.join(rf, 'model_evaluation.xlsx'), engine='openpyxl') as w:
674
  for mn, r in amr.items():
675
  r['fold_df'].to_excel(w, sheet_name=mn, index=False)
676
- # Summary with retained status
677
- sd = [{'Model': mn, 'Macro_AUC': r['mean_auc'], 'Accuracy': r['mean_acc'],
678
- 'Macro_F1': r['mean_f1'], 'Retained': 'Yes' if mn in retained else 'No',
 
 
 
 
679
  'Best': 'Best' if mn == best_mn else ''}
680
  for mn, r in amr.items()]
681
- pd.DataFrame(sd).sort_values('Macro_AUC', ascending=False).to_excel(w, sheet_name='Summary', index=False)
682
- # Bootstrap test results
683
  if len(bootstrap_df) > 0:
684
  bootstrap_df.to_excel(w, sheet_name='Bootstrap_Test', index=False)
685
- # Per-class report for best model
686
- best_report = classification_report(amr[best_mn]['all_yt'], amr[best_mn]['all_yp'],
687
- labels=class_indices, output_dict=True, zero_division=0)
688
  pd.DataFrame(best_report).T.to_excel(w, sheet_name=f'{best_mn}_PerClass', index=True)
689
 
690
  if ablation_data:
691
  with pd.ExcelWriter(os.path.join(rf, 'feature_ablation.xlsx'), engine='openpyxl') as w:
692
- pd.DataFrame({'N': ablation_data['fcs'], 'AUC': ablation_data['aucs']}).to_excel(w, sheet_name='Ablation', index=False)
 
693
  for mn, idf in shap_imp.items():
694
  idf.to_excel(w, sheet_name=f'{mn}_Imp', index=False)
695
 
696
- # Save params (English for SCI)
697
  with open(os.path.join(rf, 'best_params.txt'), 'w', encoding='utf-8') as f:
698
  f.write(f"Task: {task_type} Classification ({n_classes} classes)\n")
699
  f.write(f"Classes: {classes}\n")
@@ -702,7 +906,8 @@ def run_pipeline(
702
  f.write(f"Retained Models: {', '.join(retained)} ({len(retained)}/{nm})\n\n")
703
  for mn in mcfg:
704
  status = "* Best" if mn == best_mn else ("Retained" if mn in retained else "Excluded")
705
- f.write(f"Model: {mn} | AUC={amr[mn]['mean_auc']:.4f} | {status}\n")
 
706
  bp = bpd[mn]
707
  if isinstance(bp, dict):
708
  for k, v in bp.items(): f.write(f" {k}: {v}\n")
@@ -714,10 +919,23 @@ def run_pipeline(
714
  f.write("=" * 50 + "\n")
715
  for _, row in bootstrap_df.iterrows():
716
  f.write(f" {row['Model_A']} vs {row['Model_B']}: ")
717
- f.write(f"dAUC={row['AUC_Diff']:+.4f} 95%CI=[{row['CI_95_Low']:+.4f},{row['CI_95_High']:+.4f}] ")
 
718
  f.write(f"P={row['P_value']:.4f} -> {row['Decision']}\n")
719
  if ablation_data:
720
- f.write(f"\nOptimal Features ({ablation_data['opt_n']}): {', '.join(ablation_data['opt_feats'])}\n")
 
 
 
 
 
 
 
 
 
 
 
 
721
 
722
  # Save model
723
  pickle.dump({
@@ -728,10 +946,13 @@ def run_pipeline(
728
 
729
  # ── ZIP ──
730
  progress(0.97, desc="📦 打包ZIP...")
731
- zp = os.path.join(tempfile.gettempdir(), f"ml_results_{int(time.time())}_{os.getpid()}.zip")
 
732
  with zipfile.ZipFile(zp, 'w', zipfile.ZIP_DEFLATED) as zf:
733
  for root, _, files in os.walk(rf):
734
- for fn in files: zf.write(os.path.join(root, fn), os.path.relpath(os.path.join(root, fn), rf))
 
 
735
 
736
  nf = sum(len(f) for _, _, f in os.walk(rf))
737
  shutil.rmtree(rf, ignore_errors=True); gc.collect()
@@ -808,6 +1029,7 @@ with gr.Blocks(
808
  <div class="pipeline-box">
809
  <strong>📋 流程:</strong>
810
  <code>选择分类数</code> → <code>模型训练</code> → <code>交叉验证</code> →
 
811
  <code>SHAP分析</code> → <code>特征消融</code> → <code>外部验证</code>
812
  &nbsp;&nbsp;|&nbsp;&nbsp;
813
  <strong>CSV格式:</strong> 第1列=标签(整数), 第2列=ID, 第3列起=特征
@@ -841,20 +1063,22 @@ with gr.Blocks(
841
  info="RF=随机森林 DT=决策树 KNN=K近邻 XGB=XGBoost AdaBoost LR=逻辑回归 NB=朴素贝叶斯 SVM=支持向量机",
842
  )
843
  with gr.Row():
844
- btn_all = gr.Button("🔘 全选", size="sm", variant="secondary")
845
- btn_tree = gr.Button("🌲 树模型", size="sm", variant="secondary")
846
  btn_linear = gr.Button("📐 线性模型", size="sm", variant="secondary")
847
- btn_top4 = gr.Button("⚡ 经典四模型", size="sm", variant="secondary")
848
  btn_all.click(lambda: ALL_MODEL_NAMES, outputs=model_selector)
849
  btn_tree.click(lambda: ['RF','DT','XGB','AdaBoost'], outputs=model_selector)
850
  btn_linear.click(lambda: ['LR','SVM','NB'], outputs=model_selector)
851
  btn_top4.click(lambda: ['RF','XGB','LR','SVM'], outputs=model_selector)
852
 
853
  gr.HTML('<div class="section-title">⚙️ 参数配置</div>')
854
- enable_tuning = gr.Checkbox(value=False, label="启用超参数调优 (GridSearchCV) ⚠️ 开启后运行时间显著增加")
 
 
855
  with gr.Row():
856
  cv_folds = gr.Slider(3, 10, value=5, step=1, label="交叉验证折数")
857
- top_n = gr.Slider(5, 50, value=20, step=1, label="SHAP 前 N 个特征")
858
  shap_sz = gr.Slider(30, 200, value=80, step=10, label="SHAP 采样数量")
859
 
860
  run_btn = gr.Button("🚀 开始分析", variant="primary", size="lg")
@@ -900,4 +1124,4 @@ def auth_fn(username, password):
900
  demo.queue()
901
  demo.launch(server_name="0.0.0.0", server_port=7860, auth=auth_fn,
902
  auth_message="🔐 复旦大学附属眼耳鼻喉科医院 · ML多分类分析平台\n请输入账号和密码登录",
903
- ssr_mode=False)
 
431
  log(f"\n{'━'*50}")
432
  log(f" ✅ {nm} 个模型训练完成")
433
 
434
+ # ============================================================
435
+ # ── 辅助函数:Macro ROC / PR 曲线数据 ──
436
+ # ============================================================
437
+ def _macro_roc_curve(yt, yp, nc, cls_idx):
438
+ """Return (all_fpr, mean_tpr, macro_auc) for overlay plotting."""
439
+ y_b = label_binarize(yt, classes=cls_idx)
440
+ if nc == 2:
441
+ y_b = np.hstack([1 - y_b, y_b])
442
+ all_fpr = np.linspace(0, 1, 300)
443
+ mean_tpr = np.zeros_like(all_fpr)
444
+ for c in range(nc):
445
+ f_, t_, _ = roc_curve(y_b[:, c], yp[:, c])
446
+ mean_tpr += np.interp(all_fpr, f_, t_)
447
+ mean_tpr /= nc; mean_tpr[-1] = 1.0
448
+ return all_fpr, mean_tpr, auc_score(all_fpr, mean_tpr)
449
+
450
+ def _macro_pr_curve(yt, yp, nc, cls_idx):
451
+ y_b = label_binarize(yt, classes=cls_idx)
452
+ if nc == 2:
453
+ y_b = np.hstack([1 - y_b, y_b])
454
+ all_rec = np.linspace(0, 1, 300)
455
+ mean_prec = np.zeros_like(all_rec)
456
+ for c in range(nc):
457
+ prec_, rec_, _ = precision_recall_curve(y_b[:, c], yp[:, c])
458
+ mean_prec += np.interp(all_rec[::-1], rec_[::-1], prec_[::-1])[::-1]
459
+ mean_prec /= nc
460
+ return all_rec, mean_prec
461
+
462
+ # ============================================================
463
+ # ── 训练集 ROC / PR(所有模型,in-sample)──
464
+ # ============================================================
465
+ progress(0.40, desc="📈 训练集ROC/PR曲线...")
466
+ log(f"\n 📈 绘制各模型训练集 ROC / PR 曲线...")
467
+
468
+ train_roc_summary = {} # mn -> train macro_auc
469
+ train_roc_data = {} # mn -> (fpr, tpr, auc)
470
+ train_pr_data = {} # mn -> (rec, prec)
471
+
472
+ for mn in mnames:
473
+ yproba_tr = tms[mn].predict_proba(X.values)
474
+
475
+ # 每个模型:各类 + macro 的独立 ROC / PR 图
476
+ plot_multiclass_roc(
477
+ y_mapped.values, yproba_tr, class_indices,
478
+ f'Train ROC — {mn} ({task_type})', f'train_roc_{mn}', rf
479
+ )
480
+ plot_multiclass_pr(
481
+ y_mapped.values, yproba_tr, class_indices,
482
+ f'Train PR — {mn} ({task_type})', f'train_pr_{mn}', rf
483
+ )
484
+
485
+ fpr_tr, tpr_tr, auc_tr = _macro_roc_curve(
486
+ y_mapped.values, yproba_tr, n_classes, class_indices)
487
+ rec_tr, prec_tr = _macro_pr_curve(
488
+ y_mapped.values, yproba_tr, n_classes, class_indices)
489
+
490
+ train_roc_data[mn] = (fpr_tr, tpr_tr, auc_tr)
491
+ train_pr_data[mn] = (rec_tr, prec_tr)
492
+ train_roc_summary[mn] = auc_tr
493
+
494
+ # 汇总训练集 ROC(所有模型叠加)
495
+ plt.figure(figsize=(10, 8))
496
+ for i, mn in enumerate(mnames):
497
+ fpr_tr, tpr_tr, auc_tr = train_roc_data[mn]
498
+ plt.plot(fpr_tr, tpr_tr, color=COLORS[i % 8], lw=2.5,
499
+ label=f'{mn} (Train Macro AUC={auc_tr:.3f})')
500
+ plt.plot([0, 1], [0, 1], '--', color='#ccc', lw=1)
501
+ plt.xlim([-0.02, 1.02]); plt.ylim([-0.02, 1.02])
502
+ plt.xlabel('False Positive Rate', fontsize=13)
503
+ plt.ylabel('True Positive Rate', fontsize=13)
504
+ plt.title(f'Train ROC — All Models ({task_type})', fontsize=14, fontweight='bold')
505
+ plt.legend(loc='lower right', fontsize=10)
506
+ plt.grid(True, alpha=0.15); plt.tight_layout()
507
+ plt.savefig(os.path.join(rf, 'train_roc_all.pdf'), format='pdf', bbox_inches='tight', dpi=300)
508
+ plt.savefig(os.path.join(rf, 'train_roc_all.png'), format='png', bbox_inches='tight', dpi=150)
509
+ plt.close()
510
+
511
+ # 汇总训练集 PR(所有模型叠加)
512
+ plt.figure(figsize=(10, 8))
513
+ for i, mn in enumerate(mnames):
514
+ rec_tr, prec_tr = train_pr_data[mn]
515
+ plt.plot(rec_tr, prec_tr, color=COLORS[i % 8], lw=2.5,
516
+ label=f'{mn} (Mean AP={prec_tr.mean():.3f})')
517
+ plt.xlim([-0.02, 1.02]); plt.ylim([-0.02, 1.02])
518
+ plt.xlabel('Recall', fontsize=13); plt.ylabel('Precision', fontsize=13)
519
+ plt.title(f'Train PR — All Models ({task_type})', fontsize=14, fontweight='bold')
520
+ plt.legend(loc='lower left', fontsize=10)
521
+ plt.grid(True, alpha=0.15); plt.tight_layout()
522
+ plt.savefig(os.path.join(rf, 'train_pr_all.pdf'), format='pdf', bbox_inches='tight', dpi=300)
523
+ plt.savefig(os.path.join(rf, 'train_pr_all.png'), format='png', bbox_inches='tight', dpi=150)
524
+ plt.close()
525
+ log(f" ✅ 训练集 ROC/PR 曲线已生成(各模型独立 + 汇总共 {nm*2+2*2} 张图)")
526
+
527
+ # ============================================================
528
+ # ── 交叉验证 ROC(原有逻辑,保留)──
529
+ # ============================================================
530
+ progress(0.42, desc="📈 交叉验证ROC曲线...")
531
+ log(f"\n 📈 绘制交叉验证 ROC 曲线...")
532
  for mn in mnames:
533
  r = amr[mn]
534
  plot_multiclass_roc(r['all_yt'], r['all_yproba'], class_indices,
535
+ f'CV ROC — {mn} ({task_type}, Macro AUC={r["mean_auc"]:.3f})', f'roc_{mn}', rf)
536
 
537
+ # 汇总 CV ROC(所有模型)
538
  plt.figure(figsize=(10, 8))
539
  for i, mn in enumerate(mnames):
540
  r = amr[mn]
541
+ fpr_cv, tpr_cv, auc_cv = _macro_roc_curve(
542
+ r['all_yt'], r['all_yproba'], n_classes, class_indices)
543
+ plt.plot(fpr_cv, tpr_cv, color=COLORS[i % 8], lw=2.5,
544
+ label=f'{mn} (CV Macro AUC={auc_cv:.3f})')
545
+ plt.plot([0, 1], [0, 1], '--', color='#ccc', lw=1)
546
+ plt.xlim([-0.02, 1.02]); plt.ylim([-0.02, 1.02])
547
+ plt.xlabel('FPR', fontsize=13); plt.ylabel('TPR', fontsize=13)
548
+ plt.title(f'CV ROC — All Models ({task_type})', fontsize=14, fontweight='bold')
549
+ plt.legend(loc='lower right', fontsize=10)
550
+ plt.grid(True, alpha=0.15); plt.tight_layout()
551
+ plt.savefig(os.path.join(rf, 'roc_all.pdf'), format='pdf', bbox_inches='tight', dpi=300)
552
+ plt.savefig(os.path.join(rf, 'roc_all.png'), format='png', bbox_inches='tight', dpi=150)
 
 
 
 
553
  plt.close()
554
 
555
+ # ============================================================
556
+ # ── 最佳模型:训练集 vs 内部验证集(CV holdout)对比 ──
557
+ # ============================================================
558
+ progress(0.44, desc="📊 最终模型训练集vs内部验证集对比...")
559
+
560
+ # 先确定最佳模型(后续 Bootstrap 也会用到,此处提前计算)
561
+ best_mn = max(amr, key=lambda x: amr[x]['mean_auc'])
562
+ best_auc = amr[best_mn]['mean_auc']
563
+
564
+ log(f"\n 📊 最终模型 [{best_mn}] 训练集 vs 内部验证集(CV holdout)对比...")
565
+
566
+ # 训练集预测
567
+ yproba_best_train = tms[best_mn].predict_proba(X.values)
568
+ ypred_best_train = tms[best_mn].predict(X.values)
569
+ metrics_train = compute_multiclass_metrics(
570
+ y_mapped.values, ypred_best_train, yproba_best_train, class_indices
571
+ )
572
+
573
+ # CV holdout(已在 amr 中累积)
574
+ yproba_best_cv = amr[best_mn]['all_yproba']
575
+ ypred_best_cv = amr[best_mn]['all_yp']
576
+ ytrue_best_cv = amr[best_mn]['all_yt']
577
+ metrics_cv = compute_multiclass_metrics(
578
+ ytrue_best_cv, ypred_best_cv, yproba_best_cv, class_indices
579
+ )
580
+
581
+ log(f" Train → AUC={metrics_train['Macro_AUC']:.4f} Acc={metrics_train['Accuracy']:.4f}"
582
+ f" F1={metrics_train['Macro_F1']:.4f} Kappa={metrics_train['Kappa']:.4f}")
583
+ log(f" CV-Val → AUC={metrics_cv['Macro_AUC']:.4f} Acc={metrics_cv['Accuracy']:.4f}"
584
+ f" F1={metrics_cv['Macro_F1']:.4f} Kappa={metrics_cv['Kappa']:.4f}")
585
+
586
+ # 对比 ROC
587
+ fpr_tr_b, tpr_tr_b, auc_tr_b = _macro_roc_curve(
588
+ y_mapped.values, yproba_best_train, n_classes, class_indices)
589
+ fpr_cv_b, tpr_cv_b, auc_cv_b = _macro_roc_curve(
590
+ ytrue_best_cv, yproba_best_cv, n_classes, class_indices)
591
+
592
+ fig, ax = plt.subplots(figsize=(10, 8))
593
+ ax.plot(fpr_tr_b, tpr_tr_b, color='#e41a1c', lw=2.5,
594
+ label=f'Train set (Macro AUC={auc_tr_b:.3f})')
595
+ ax.plot(fpr_cv_b, tpr_cv_b, color='#377eb8', lw=2.5, linestyle='--',
596
+ label=f'Internal CV (Macro AUC={auc_cv_b:.3f})')
597
+ ax.plot([0, 1], [0, 1], '--', color='#ccc', lw=1)
598
+ ax.set_xlim([-0.02, 1.02]); ax.set_ylim([-0.02, 1.02])
599
+ ax.set_xlabel('False Positive Rate', fontsize=13)
600
+ ax.set_ylabel('True Positive Rate', fontsize=13)
601
+ ax.set_title(f'ROC — {best_mn}: Train vs Internal CV ({task_type})',
602
+ fontsize=14, fontweight='bold')
603
+ ax.legend(loc='lower right', fontsize=11)
604
+ ax.grid(True, alpha=0.15); plt.tight_layout()
605
+ plt.savefig(os.path.join(rf, f'roc_train_vs_cv_{best_mn}.pdf'),
606
+ format='pdf', bbox_inches='tight', dpi=300)
607
+ plt.savefig(os.path.join(rf, f'roc_train_vs_cv_{best_mn}.png'),
608
+ format='png', bbox_inches='tight', dpi=150)
609
+ plt.close()
610
+
611
+ # 对比 PR
612
+ rec_tr_b, prec_tr_b = _macro_pr_curve(
613
+ y_mapped.values, yproba_best_train, n_classes, class_indices)
614
+ rec_cv_b, prec_cv_b = _macro_pr_curve(
615
+ ytrue_best_cv, yproba_best_cv, n_classes, class_indices)
616
+
617
+ fig, ax = plt.subplots(figsize=(10, 8))
618
+ ax.plot(rec_tr_b, prec_tr_b, color='#e41a1c', lw=2.5,
619
+ label=f'Train set (Mean AP={prec_tr_b.mean():.3f})')
620
+ ax.plot(rec_cv_b, prec_cv_b, color='#377eb8', lw=2.5, linestyle='--',
621
+ label=f'Internal CV (Mean AP={prec_cv_b.mean():.3f})')
622
+ ax.set_xlim([-0.02, 1.02]); ax.set_ylim([-0.02, 1.02])
623
+ ax.set_xlabel('Recall', fontsize=13)
624
+ ax.set_ylabel('Precision', fontsize=13)
625
+ ax.set_title(f'PR — {best_mn}: Train vs Internal CV ({task_type})',
626
+ fontsize=14, fontweight='bold')
627
+ ax.legend(loc='lower left', fontsize=11)
628
+ ax.grid(True, alpha=0.15); plt.tight_layout()
629
+ plt.savefig(os.path.join(rf, f'pr_train_vs_cv_{best_mn}.pdf'),
630
+ format='pdf', bbox_inches='tight', dpi=300)
631
+ plt.savefig(os.path.join(rf, f'pr_train_vs_cv_{best_mn}.png'),
632
+ format='png', bbox_inches='tight', dpi=150)
633
+ plt.close()
634
+
635
+ # 训练集混淆矩阵(最佳模型)
636
+ plot_confusion_matrix(
637
+ y_mapped.values, ypred_best_train, class_indices,
638
+ f'Train CM — {best_mn} (Acc={metrics_train["Accuracy"]:.3f})',
639
+ f'cm_train_{best_mn}', rf
640
+ )
641
+
642
+ # 保存 Train vs CV 汇总 Excel
643
+ with pd.ExcelWriter(os.path.join(rf, f'train_vs_cv_{best_mn}.xlsx'),
644
+ engine='openpyxl') as w:
645
+ summary_rows = [
646
+ {'Split': 'Train', 'Model': best_mn,
647
+ 'Macro_AUC': metrics_train['Macro_AUC'],
648
+ 'Accuracy': metrics_train['Accuracy'],
649
+ 'Macro_F1': metrics_train['Macro_F1'],
650
+ 'Weighted_F1': metrics_train['Weighted_F1'],
651
+ 'Kappa': metrics_train['Kappa']},
652
+ {'Split': 'Internal_CV', 'Model': best_mn,
653
+ 'Macro_AUC': metrics_cv['Macro_AUC'],
654
+ 'Accuracy': metrics_cv['Accuracy'],
655
+ 'Macro_F1': metrics_cv['Macro_F1'],
656
+ 'Weighted_F1': metrics_cv['Weighted_F1'],
657
+ 'Kappa': metrics_cv['Kappa']},
658
+ ]
659
+ pd.DataFrame(summary_rows).to_excel(w, sheet_name='Summary', index=False)
660
+ pd.DataFrame(metrics_train['report']).T.to_excel(w, sheet_name='Train_PerClass', index=True)
661
+ pd.DataFrame(metrics_cv['report']).T.to_excel(w, sheet_name='CV_PerClass', index=True)
662
+ amr[best_mn]['fold_df'].to_excel(w, sheet_name='CV_FoldDetail', index=False)
663
+
664
+ log(f" ✅ Train vs CV 对比图及数据已保存 → train_vs_cv_{best_mn}.xlsx")
665
+
666
+ # ── PR Curves (CV,原有逻辑) ──
667
+ progress(0.48, desc="📈 交叉验证PR曲线...")
668
  for mn in mnames:
669
  r = amr[mn]
670
  plot_multiclass_pr(r['all_yt'], r['all_yproba'], class_indices,
671
+ f'CV PR — {mn} ({task_type})', f'pr_{mn}', rf)
672
 
673
+ # ── Confusion Matrices (CV) ──
674
  progress(0.52, desc="📊 混淆矩阵...")
675
  for mn in mnames:
676
  r = amr[mn]
677
  plot_confusion_matrix(r['all_yt'], r['all_yp'], class_indices,
678
+ f'CV CM — {mn} (Acc={r["mean_acc"]:.3f})', f'cm_{mn}', rf)
679
 
680
+ # ── Bootstrap AUC Test ──
681
  progress(0.55, desc="🔬 Bootstrap AUC 检验...")
 
 
682
  log(f"\n 🏆 最佳模型: {best_mn} (Macro AUC={best_auc:.4f})")
683
  log(f" 🔬 Bootstrap 检验 (n=2000, α=0.05)...")
684
 
 
717
  progress(0.62, desc="🔥 SHAP分析...")
718
  log(f"\n 🔥 SHAP特征分析 (保留模型中 Top 3)...")
719
  shap_imp = {}
 
720
  models_for_shap = sorted(retained, key=lambda x: amr[x]['mean_auc'], reverse=True)[:3]
721
 
722
  for si, mn in enumerate(models_for_shap):
 
733
  exp = shap.KernelExplainer(lambda x, m=mo: m.predict_proba(x), bg)
734
  sv = exp.shap_values(Xs)
735
 
 
736
  if isinstance(sv, list):
 
737
  sv_abs = np.mean([np.abs(s) for s in sv], axis=0)
738
  elif sv.ndim == 3:
739
+ sv_abs = np.mean(np.abs(sv), axis=2)
740
  else:
741
  sv_abs = np.abs(sv)
742
 
 
747
  idf = pd.DataFrame({'Feature': fnames, 'Importance': fi}).sort_values('Importance', ascending=False)
748
  shap_imp[mn] = idf
749
 
 
750
  plt.figure(figsize=(10, max(6, TOPN * 0.3)))
751
  top_df = idf.head(TOPN).iloc[::-1]
752
  plt.barh(top_df['Feature'], top_df['Importance'], color='#2563eb', alpha=0.8)
 
760
  except Exception as e:
761
  log(f" ⚠ {mn} SHAP失败: {e}")
762
 
763
+ # ── Feature Ablation ──
764
  progress(0.72, desc="🧪 特征消融...")
765
  log(f"\n 🧪 特征消融 (仅最佳模型 {best_mn})...")
766
  ablation_data = None
 
768
  imp_df = shap_imp[best_mn]
769
  top_feats = imp_df.head(TOPN)['Feature'].tolist()
770
  fcs = []; aucs_a = []
 
771
 
772
  for nf in range(1, len(top_feats) + 1):
773
  Xsub = X[top_feats[:nf]]
 
788
  fold_aucs.append(a)
789
  fcs.append(nf); aucs_a.append(np.mean(fold_aucs))
790
 
 
791
  full_auc = amr[best_mn]['mean_auc']
792
  opt_n = len(top_feats)
793
  for i, a in enumerate(aucs_a):
794
  if a >= full_auc * 0.95:
795
  opt_n = i + 1; break
796
 
797
+ ablation_data = {'fcs': fcs, 'aucs': aucs_a, 'feats': top_feats,
798
+ 'opt_n': opt_n, 'opt_feats': top_feats[:opt_n]}
799
  log(f" ✅ 最优特征数: {opt_n} (AUC={aucs_a[opt_n-1]:.4f} vs Full={full_auc:.4f})")
800
 
 
801
  plt.figure(figsize=(10, 7))
802
  plt.plot(fcs, aucs_a, 'o-', color='#2563eb', lw=2, ms=5)
803
+ plt.scatter([opt_n], [aucs_a[opt_n-1]], s=200, marker='*',
804
+ color='#ef4444', edgecolors='black', lw=2, zorder=5)
805
+ plt.axhline(y=full_auc, color='gray', ls='--', lw=1, alpha=0.5,
806
+ label=f'Full AUC={full_auc:.3f}')
807
  plt.xlabel('Number of Features', fontsize=13); plt.ylabel('Macro AUC', fontsize=13)
808
+ plt.title(f'Feature Ablation — {best_mn} (★ Optimal={opt_n})',
809
+ fontsize=14, fontweight='bold')
810
  plt.legend(fontsize=11); plt.grid(True, alpha=0.15); plt.tight_layout()
811
  plt.savefig(os.path.join(rf, 'ablation.pdf'), format='pdf', bbox_inches='tight')
812
  plt.savefig(os.path.join(rf, 'ablation.png'), format='png', bbox_inches='tight', dpi=150)
 
828
  vcol2_is_id = (vcol2.dtype == 'object') or (vcol2.nunique() / len(vcol2) > 0.5)
829
  Xe = ed.iloc[:, 2:] if vcol2_is_id else ed.iloc[:, 1:]
830
 
 
831
  ye = ye_raw.map(label_map)
832
  if ye.isna().any():
833
  log(f" ⚠ 验证集 {vi} 含有训练集中不存在的标签,已跳过")
 
844
  ye_np = ye.values
845
 
846
  metrics = compute_multiclass_metrics(ye_np, yed, yep, class_indices)
847
+ log(f" ✅ AUC={metrics['Macro_AUC']:.4f} Acc={metrics['Accuracy']:.4f}"
848
+ f" F1={metrics['Macro_F1']:.4f} Kappa={metrics['Kappa']:.4f}")
849
 
850
  sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
851
  tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
852
 
853
+ plot_multiclass_roc(ye_np, yep, class_indices,
854
+ f'ROC — {tag} ({best_mn})', f'roc{sfx}', rf)
855
+ plot_multiclass_pr(ye_np, yep, class_indices,
856
+ f'PR — {tag} ({best_mn})', f'pr{sfx}', rf)
857
+ plot_confusion_matrix(ye_np, yed, class_indices,
858
+ f'CM — {tag} ({best_mn})', f'cm{sfx}', rf)
859
 
860
  with pd.ExcelWriter(os.path.join(rf, f'validation{sfx}.xlsx'), engine='openpyxl') as w:
861
  pd.DataFrame([{'Model': best_mn, 'N_Features': len(final_feats),
862
  'Macro_AUC': metrics['Macro_AUC'], 'Accuracy': metrics['Accuracy'],
863
  'Macro_F1': metrics['Macro_F1'], 'Weighted_F1': metrics['Weighted_F1'],
864
  'Kappa': metrics['Kappa']}]).to_excel(w, sheet_name='Metrics', index=False)
865
+ pd.DataFrame(metrics['report']).T.to_excel(w, sheet_name='Per_Class', index=True)
 
866
  pd.DataFrame({'Feature': final_feats}).to_excel(w, sheet_name='Features', index=False)
867
 
868
  # ── Save Results ──
 
872
  with pd.ExcelWriter(os.path.join(rf, 'model_evaluation.xlsx'), engine='openpyxl') as w:
873
  for mn, r in amr.items():
874
  r['fold_df'].to_excel(w, sheet_name=mn, index=False)
875
+ # Summary(新增 Train_AUC 列)
876
+ sd = [{'Model': mn,
877
+ 'CV_Macro_AUC': r['mean_auc'],
878
+ 'Train_Macro_AUC': train_roc_summary.get(mn, ''),
879
+ 'CV_Accuracy': r['mean_acc'],
880
+ 'CV_Macro_F1': r['mean_f1'],
881
+ 'Retained': 'Yes' if mn in retained else 'No',
882
  'Best': 'Best' if mn == best_mn else ''}
883
  for mn, r in amr.items()]
884
+ pd.DataFrame(sd).sort_values('CV_Macro_AUC', ascending=False).to_excel(
885
+ w, sheet_name='Summary', index=False)
886
  if len(bootstrap_df) > 0:
887
  bootstrap_df.to_excel(w, sheet_name='Bootstrap_Test', index=False)
888
+ best_report = classification_report(
889
+ amr[best_mn]['all_yt'], amr[best_mn]['all_yp'],
890
+ labels=class_indices, output_dict=True, zero_division=0)
891
  pd.DataFrame(best_report).T.to_excel(w, sheet_name=f'{best_mn}_PerClass', index=True)
892
 
893
  if ablation_data:
894
  with pd.ExcelWriter(os.path.join(rf, 'feature_ablation.xlsx'), engine='openpyxl') as w:
895
+ pd.DataFrame({'N': ablation_data['fcs'], 'AUC': ablation_data['aucs']}).to_excel(
896
+ w, sheet_name='Ablation', index=False)
897
  for mn, idf in shap_imp.items():
898
  idf.to_excel(w, sheet_name=f'{mn}_Imp', index=False)
899
 
900
+ # Save params
901
  with open(os.path.join(rf, 'best_params.txt'), 'w', encoding='utf-8') as f:
902
  f.write(f"Task: {task_type} Classification ({n_classes} classes)\n")
903
  f.write(f"Classes: {classes}\n")
 
906
  f.write(f"Retained Models: {', '.join(retained)} ({len(retained)}/{nm})\n\n")
907
  for mn in mcfg:
908
  status = "* Best" if mn == best_mn else ("Retained" if mn in retained else "Excluded")
909
+ f.write(f"Model: {mn} | CV_AUC={amr[mn]['mean_auc']:.4f}"
910
+ f" | Train_AUC={train_roc_summary.get(mn, 'N/A')} | {status}\n")
911
  bp = bpd[mn]
912
  if isinstance(bp, dict):
913
  for k, v in bp.items(): f.write(f" {k}: {v}\n")
 
919
  f.write("=" * 50 + "\n")
920
  for _, row in bootstrap_df.iterrows():
921
  f.write(f" {row['Model_A']} vs {row['Model_B']}: ")
922
+ f.write(f"dAUC={row['AUC_Diff']:+.4f} "
923
+ f"95%CI=[{row['CI_95_Low']:+.4f},{row['CI_95_High']:+.4f}] ")
924
  f.write(f"P={row['P_value']:.4f} -> {row['Decision']}\n")
925
  if ablation_data:
926
+ f.write(f"\nOptimal Features ({ablation_data['opt_n']}): "
927
+ f"{', '.join(ablation_data['opt_feats'])}\n")
928
+ f.write(f"\n{'='*50}\n")
929
+ f.write(f"Best Model [{best_mn}] Train vs Internal CV\n")
930
+ f.write(f"{'='*50}\n")
931
+ f.write(f" Train → AUC={metrics_train['Macro_AUC']:.4f}"
932
+ f" Acc={metrics_train['Accuracy']:.4f}"
933
+ f" F1={metrics_train['Macro_F1']:.4f}"
934
+ f" Kappa={metrics_train['Kappa']:.4f}\n")
935
+ f.write(f" CV-Val → AUC={metrics_cv['Macro_AUC']:.4f}"
936
+ f" Acc={metrics_cv['Accuracy']:.4f}"
937
+ f" F1={metrics_cv['Macro_F1']:.4f}"
938
+ f" Kappa={metrics_cv['Kappa']:.4f}\n")
939
 
940
  # Save model
941
  pickle.dump({
 
946
 
947
  # ── ZIP ──
948
  progress(0.97, desc="📦 打包ZIP...")
949
+ zp = os.path.join(tempfile.gettempdir(),
950
+ f"ml_results_{int(time.time())}_{os.getpid()}.zip")
951
  with zipfile.ZipFile(zp, 'w', zipfile.ZIP_DEFLATED) as zf:
952
  for root, _, files in os.walk(rf):
953
+ for fn in files:
954
+ zf.write(os.path.join(root, fn),
955
+ os.path.relpath(os.path.join(root, fn), rf))
956
 
957
  nf = sum(len(f) for _, _, f in os.walk(rf))
958
  shutil.rmtree(rf, ignore_errors=True); gc.collect()
 
1029
  <div class="pipeline-box">
1030
  <strong>📋 流程:</strong>
1031
  <code>选择分类数</code> → <code>模型训练</code> → <code>交叉验证</code> →
1032
+ <code>训练集ROC/PR</code> → <code>Train vs CV对比</code> →
1033
  <code>SHAP分析</code> → <code>特征消融</code> → <code>外部验证</code>
1034
  &nbsp;&nbsp;|&nbsp;&nbsp;
1035
  <strong>CSV格式:</strong> 第1列=标签(整数), 第2列=ID, 第3列起=特征
 
1063
  info="RF=随机森林 DT=决策树 KNN=K近邻 XGB=XGBoost AdaBoost LR=逻辑回归 NB=朴素贝叶斯 SVM=支持向量机",
1064
  )
1065
  with gr.Row():
1066
+ btn_all = gr.Button("🔘 全选", size="sm", variant="secondary")
1067
+ btn_tree = gr.Button("🌲 树模型", size="sm", variant="secondary")
1068
  btn_linear = gr.Button("📐 线性模型", size="sm", variant="secondary")
1069
+ btn_top4 = gr.Button("⚡ 经典四模型", size="sm", variant="secondary")
1070
  btn_all.click(lambda: ALL_MODEL_NAMES, outputs=model_selector)
1071
  btn_tree.click(lambda: ['RF','DT','XGB','AdaBoost'], outputs=model_selector)
1072
  btn_linear.click(lambda: ['LR','SVM','NB'], outputs=model_selector)
1073
  btn_top4.click(lambda: ['RF','XGB','LR','SVM'], outputs=model_selector)
1074
 
1075
  gr.HTML('<div class="section-title">⚙️ 参数配置</div>')
1076
+ enable_tuning = gr.Checkbox(
1077
+ value=False,
1078
+ label="启用超参数调优 (GridSearchCV) ⚠️ 开启后运行时间显著增加")
1079
  with gr.Row():
1080
  cv_folds = gr.Slider(3, 10, value=5, step=1, label="交叉验证折数")
1081
+ top_n = gr.Slider(5, 50, value=20, step=1, label="SHAP 前 N 个特征")
1082
  shap_sz = gr.Slider(30, 200, value=80, step=10, label="SHAP 采样数量")
1083
 
1084
  run_btn = gr.Button("🚀 开始分析", variant="primary", size="lg")
 
1124
  demo.queue()
1125
  demo.launch(server_name="0.0.0.0", server_port=7860, auth=auth_fn,
1126
  auth_message="🔐 复旦大学附属眼耳鼻喉科医院 · ML多分类分析平台\n请输入账号和密码登录",
1127
+ ssr_mode=False)