fudan-renjun commited on
Commit
a21dd2d
·
verified ·
1 Parent(s): 27e1310

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +233 -262
app.py CHANGED
@@ -431,11 +431,97 @@ def run_pipeline(
431
  log(f"\n{'━'*50}")
432
  log(f" ✅ {nm} 个模型训练完成")
433
 
434
- # ============================================================
435
- # ── 辅助函数:Macro ROC / PR 曲线数据 ──
436
- # ============================================================
437
- def _macro_roc_curve(yt, yp, nc, cls_idx):
438
- """Return (all_fpr, mean_tpr, macro_auc) for overlay plotting."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  y_b = label_binarize(yt, classes=cls_idx)
440
  if nc == 2:
441
  y_b = np.hstack([1 - y_b, y_b])
@@ -447,7 +533,7 @@ def run_pipeline(
447
  mean_tpr /= nc; mean_tpr[-1] = 1.0
448
  return all_fpr, mean_tpr, auc_score(all_fpr, mean_tpr)
449
 
450
- def _macro_pr_curve(yt, yp, nc, cls_idx):
451
  y_b = label_binarize(yt, classes=cls_idx)
452
  if nc == 2:
453
  y_b = np.hstack([1 - y_b, y_b])
@@ -459,44 +545,35 @@ def run_pipeline(
459
  mean_prec /= nc
460
  return all_rec, mean_prec
461
 
462
- # ============================================================
463
- # ── 训练集 ROC / PR(所有模型,in-sample)──
464
- # ============================================================
465
- progress(0.40, desc="📈 训练集ROC/PR曲线...")
466
- log(f"\n 📈 绘制各模型训练集 ROC / PR 曲线...")
467
-
468
- train_roc_summary = {} # mn -> train macro_auc
469
- train_roc_data = {} # mn -> (fpr, tpr, auc)
470
- train_pr_data = {} # mn -> (rec, prec)
471
 
472
  for mn in mnames:
473
- yproba_tr = tms[mn].predict_proba(X.values)
474
 
475
- # 每个模型:各类 + macro 的独立 ROC / PR
476
  plot_multiclass_roc(
477
- y_mapped.values, yproba_tr, class_indices,
478
- f'Train ROC — {mn} ({task_type})', f'train_roc_{mn}', rf
 
479
  )
480
  plot_multiclass_pr(
481
- y_mapped.values, yproba_tr, class_indices,
482
- f'Train PR — {mn} ({task_type})', f'train_pr_{mn}', rf
 
483
  )
484
 
485
- fpr_tr, tpr_tr, auc_tr = _macro_roc_curve(
486
- y_mapped.values, yproba_tr, n_classes, class_indices)
487
- rec_tr, prec_tr = _macro_pr_curve(
488
- y_mapped.values, yproba_tr, n_classes, class_indices)
489
-
490
- train_roc_data[mn] = (fpr_tr, tpr_tr, auc_tr)
491
- train_pr_data[mn] = (rec_tr, prec_tr)
492
- train_roc_summary[mn] = auc_tr
493
 
494
- # 汇总训练集 ROC(所有模型叠加
495
  plt.figure(figsize=(10, 8))
496
  for i, mn in enumerate(mnames):
497
- fpr_tr, tpr_tr, auc_tr = train_roc_data[mn]
498
- plt.plot(fpr_tr, tpr_tr, color=COLORS[i % 8], lw=2.5,
499
- label=f'{mn} (Train Macro AUC={auc_tr:.3f})')
500
  plt.plot([0, 1], [0, 1], '--', color='#ccc', lw=1)
501
  plt.xlim([-0.02, 1.02]); plt.ylim([-0.02, 1.02])
502
  plt.xlabel('False Positive Rate', fontsize=13)
@@ -508,92 +585,61 @@ def run_pipeline(
508
  plt.savefig(os.path.join(rf, 'train_roc_all.png'), format='png', bbox_inches='tight', dpi=150)
509
  plt.close()
510
 
511
- # 汇总训练集 PR(所有模型叠加
512
  plt.figure(figsize=(10, 8))
513
  for i, mn in enumerate(mnames):
514
- rec_tr, prec_tr = train_pr_data[mn]
515
- plt.plot(rec_tr, prec_tr, color=COLORS[i % 8], lw=2.5,
516
- label=f'{mn} (Mean AP={prec_tr.mean():.3f})')
517
  plt.xlim([-0.02, 1.02]); plt.ylim([-0.02, 1.02])
518
- plt.xlabel('Recall', fontsize=13); plt.ylabel('Precision', fontsize=13)
 
519
  plt.title(f'Train PR — All Models ({task_type})', fontsize=14, fontweight='bold')
520
  plt.legend(loc='lower left', fontsize=10)
521
  plt.grid(True, alpha=0.15); plt.tight_layout()
522
  plt.savefig(os.path.join(rf, 'train_pr_all.pdf'), format='pdf', bbox_inches='tight', dpi=300)
523
  plt.savefig(os.path.join(rf, 'train_pr_all.png'), format='png', bbox_inches='tight', dpi=150)
524
  plt.close()
525
- log(f" ✅ 训练集 ROC/PR 曲线已生成(��模型独立 + 汇总 {nm*2+2*2} 张图)")
526
-
527
- # ============================================================
528
- # ── 交叉验证 ROC原有逻辑,保留──
529
- # ============================================================
530
- progress(0.42, desc="📈 交叉验证ROC曲线...")
531
- log(f"\n 📈 绘制交叉验证 ROC 曲线...")
532
- for mn in mnames:
533
- r = amr[mn]
534
- plot_multiclass_roc(r['all_yt'], r['all_yproba'], class_indices,
535
- f'CV ROC — {mn} ({task_type}, Macro AUC={r["mean_auc"]:.3f})', f'roc_{mn}', rf)
536
-
537
- # 汇总 CV ROC(所有模型)
538
- plt.figure(figsize=(10, 8))
539
- for i, mn in enumerate(mnames):
540
- r = amr[mn]
541
- fpr_cv, tpr_cv, auc_cv = _macro_roc_curve(
542
- r['all_yt'], r['all_yproba'], n_classes, class_indices)
543
- plt.plot(fpr_cv, tpr_cv, color=COLORS[i % 8], lw=2.5,
544
- label=f'{mn} (CV Macro AUC={auc_cv:.3f})')
545
- plt.plot([0, 1], [0, 1], '--', color='#ccc', lw=1)
546
- plt.xlim([-0.02, 1.02]); plt.ylim([-0.02, 1.02])
547
- plt.xlabel('FPR', fontsize=13); plt.ylabel('TPR', fontsize=13)
548
- plt.title(f'CV ROC — All Models ({task_type})', fontsize=14, fontweight='bold')
549
- plt.legend(loc='lower right', fontsize=10)
550
- plt.grid(True, alpha=0.15); plt.tight_layout()
551
- plt.savefig(os.path.join(rf, 'roc_all.pdf'), format='pdf', bbox_inches='tight', dpi=300)
552
- plt.savefig(os.path.join(rf, 'roc_all.png'), format='png', bbox_inches='tight', dpi=150)
553
- plt.close()
554
-
555
- # ============================================================
556
- # ── 最佳模型:训练集 vs 内部验证集(CV holdout)对比 ──
557
- # ============================================================
558
- progress(0.44, desc="📊 最终模型训练集vs内部验证集对比...")
559
-
560
- # 先确定最佳模型(后续 Bootstrap 也会用到,此处提前计算)
561
- best_mn = max(amr, key=lambda x: amr[x]['mean_auc'])
562
- best_auc = amr[best_mn]['mean_auc']
563
-
564
- log(f"\n 📊 最终模型 [{best_mn}] 训练集 vs 内部验证集(CV holdout)对比...")
565
-
566
- # 训练集预测
567
- yproba_best_train = tms[best_mn].predict_proba(X.values)
568
- ypred_best_train = tms[best_mn].predict(X.values)
569
- metrics_train = compute_multiclass_metrics(
570
- y_mapped.values, ypred_best_train, yproba_best_train, class_indices
571
- )
572
-
573
- # CV holdout(已在 amr 中累积)
574
- yproba_best_cv = amr[best_mn]['all_yproba']
575
- ypred_best_cv = amr[best_mn]['all_yp']
576
- ytrue_best_cv = amr[best_mn]['all_yt']
577
- metrics_cv = compute_multiclass_metrics(
578
- ytrue_best_cv, ypred_best_cv, yproba_best_cv, class_indices
579
- )
580
-
581
- log(f" Train → AUC={metrics_train['Macro_AUC']:.4f} Acc={metrics_train['Accuracy']:.4f}"
582
- f" F1={metrics_train['Macro_F1']:.4f} Kappa={metrics_train['Kappa']:.4f}")
583
- log(f" CV-Val → AUC={metrics_cv['Macro_AUC']:.4f} Acc={metrics_cv['Accuracy']:.4f}"
584
- f" F1={metrics_cv['Macro_F1']:.4f} Kappa={metrics_cv['Kappa']:.4f}")
585
-
586
- # 对比 ROC
587
- fpr_tr_b, tpr_tr_b, auc_tr_b = _macro_roc_curve(
588
- y_mapped.values, yproba_best_train, n_classes, class_indices)
589
- fpr_cv_b, tpr_cv_b, auc_cv_b = _macro_roc_curve(
590
- ytrue_best_cv, yproba_best_cv, n_classes, class_indices)
591
 
592
  fig, ax = plt.subplots(figsize=(10, 8))
593
- ax.plot(fpr_tr_b, tpr_tr_b, color='#e41a1c', lw=2.5,
594
- label=f'Train set (Macro AUC={auc_tr_b:.3f})')
595
- ax.plot(fpr_cv_b, tpr_cv_b, color='#377eb8', lw=2.5, linestyle='--',
596
- label=f'Internal CV (Macro AUC={auc_cv_b:.3f})')
597
  ax.plot([0, 1], [0, 1], '--', color='#ccc', lw=1)
598
  ax.set_xlim([-0.02, 1.02]); ax.set_ylim([-0.02, 1.02])
599
  ax.set_xlabel('False Positive Rate', fontsize=13)
@@ -608,17 +654,17 @@ def run_pipeline(
608
  format='png', bbox_inches='tight', dpi=150)
609
  plt.close()
610
 
611
- # 对比 PR
612
- rec_tr_b, prec_tr_b = _macro_pr_curve(
613
- y_mapped.values, yproba_best_train, n_classes, class_indices)
614
- rec_cv_b, prec_cv_b = _macro_pr_curve(
615
- ytrue_best_cv, yproba_best_cv, n_classes, class_indices)
616
 
617
  fig, ax = plt.subplots(figsize=(10, 8))
618
- ax.plot(rec_tr_b, prec_tr_b, color='#e41a1c', lw=2.5,
619
- label=f'Train set (Mean AP={prec_tr_b.mean():.3f})')
620
- ax.plot(rec_cv_b, prec_cv_b, color='#377eb8', lw=2.5, linestyle='--',
621
- label=f'Internal CV (Mean AP={prec_cv_b.mean():.3f})')
622
  ax.set_xlim([-0.02, 1.02]); ax.set_ylim([-0.02, 1.02])
623
  ax.set_xlabel('Recall', fontsize=13)
624
  ax.set_ylabel('Precision', fontsize=13)
@@ -632,91 +678,41 @@ def run_pipeline(
632
  format='png', bbox_inches='tight', dpi=150)
633
  plt.close()
634
 
635
- # 训练集混淆矩阵(最佳模型
636
  plot_confusion_matrix(
637
- y_mapped.values, ypred_best_train, class_indices,
638
- f'Train CM — {best_mn} (Acc={metrics_train["Accuracy"]:.3f})',
639
  f'cm_train_{best_mn}', rf
640
  )
641
 
642
- # 保存 Train vs CV 汇总 Excel
643
- with pd.ExcelWriter(os.path.join(rf, f'train_vs_cv_{best_mn}.xlsx'),
644
- engine='openpyxl') as w:
645
- summary_rows = [
646
- {'Split': 'Train', 'Model': best_mn,
647
- 'Macro_AUC': metrics_train['Macro_AUC'],
648
- 'Accuracy': metrics_train['Accuracy'],
649
- 'Macro_F1': metrics_train['Macro_F1'],
650
- 'Weighted_F1': metrics_train['Weighted_F1'],
651
- 'Kappa': metrics_train['Kappa']},
652
  {'Split': 'Internal_CV', 'Model': best_mn,
653
- 'Macro_AUC': metrics_cv['Macro_AUC'],
654
- 'Accuracy': metrics_cv['Accuracy'],
655
- 'Macro_F1': metrics_cv['Macro_F1'],
656
- 'Weighted_F1': metrics_cv['Weighted_F1'],
657
- 'Kappa': metrics_cv['Kappa']},
658
- ]
659
- pd.DataFrame(summary_rows).to_excel(w, sheet_name='Summary', index=False)
660
- pd.DataFrame(metrics_train['report']).T.to_excel(w, sheet_name='Train_PerClass', index=True)
661
- pd.DataFrame(metrics_cv['report']).T.to_excel(w, sheet_name='CV_PerClass', index=True)
662
  amr[best_mn]['fold_df'].to_excel(w, sheet_name='CV_FoldDetail', index=False)
663
 
664
- log(f" ✅ Train vs CV 对比图及数据已保存 → train_vs_cv_{best_mn}.xlsx")
665
-
666
- # ── PR Curves (CV,原有逻辑) ──
667
- progress(0.48, desc="📈 交叉验证PR曲线...")
668
- for mn in mnames:
669
- r = amr[mn]
670
- plot_multiclass_pr(r['all_yt'], r['all_yproba'], class_indices,
671
- f'CV PR — {mn} ({task_type})', f'pr_{mn}', rf)
672
-
673
- # ── Confusion Matrices (CV) ──
674
- progress(0.52, desc="📊 混淆矩阵...")
675
- for mn in mnames:
676
- r = amr[mn]
677
- plot_confusion_matrix(r['all_yt'], r['all_yp'], class_indices,
678
- f'CV CM — {mn} (Acc={r["mean_acc"]:.3f})', f'cm_{mn}', rf)
679
-
680
- # ── Bootstrap AUC Test ──
681
- progress(0.55, desc="🔬 Bootstrap AUC 检验...")
682
- log(f"\n 🏆 最佳模型: {best_mn} (Macro AUC={best_auc:.4f})")
683
- log(f" 🔬 Bootstrap 检验 (n=2000, α=0.05)...")
684
 
685
- ALPHA = 0.05
686
- bootstrap_results = []
687
- retained = [best_mn]
688
-
689
- for om in mnames:
690
- if om == best_mn:
691
- continue
692
- p_val, auc_a, auc_b, ci_lo, ci_hi = bootstrap_auc_test(
693
- amr[best_mn]['all_yt'],
694
- amr[best_mn]['all_yproba'],
695
- amr[om]['all_yproba'],
696
- class_indices, n_bootstrap=2000
697
- )
698
- if p_val >= ALPHA:
699
- retained.append(om)
700
- dec = "Retained"
701
- else:
702
- dec = "Excluded"
703
-
704
- bootstrap_results.append({
705
- 'Model_A': best_mn, 'AUC_A': auc_a,
706
- 'Model_B': om, 'AUC_B': auc_b,
707
- 'AUC_Diff': auc_a - auc_b,
708
- 'CI_95_Low': ci_lo, 'CI_95_High': ci_hi,
709
- 'P_value': p_val, 'Decision': dec
710
- })
711
- log(f" {best_mn} vs {om}: ΔAUC={auc_a-auc_b:+.4f} 95%CI=[{ci_lo:+.4f},{ci_hi:+.4f}] P={p_val:.4f} → {dec}")
712
-
713
- bootstrap_df = pd.DataFrame(bootstrap_results).sort_values('P_value', ascending=False) if bootstrap_results else pd.DataFrame()
714
- log(f" ✅ 保留 {len(retained)}/{nm} 个模型: {', '.join(retained)}")
715
-
716
- # ── SHAP ──
717
  progress(0.62, desc="🔥 SHAP分析...")
718
  log(f"\n 🔥 SHAP特征分析 (保留模型中 Top 3)...")
719
  shap_imp = {}
 
720
  models_for_shap = sorted(retained, key=lambda x: amr[x]['mean_auc'], reverse=True)[:3]
721
 
722
  for si, mn in enumerate(models_for_shap):
@@ -733,10 +729,12 @@ def run_pipeline(
733
  exp = shap.KernelExplainer(lambda x, m=mo: m.predict_proba(x), bg)
734
  sv = exp.shap_values(Xs)
735
 
 
736
  if isinstance(sv, list):
 
737
  sv_abs = np.mean([np.abs(s) for s in sv], axis=0)
738
  elif sv.ndim == 3:
739
- sv_abs = np.mean(np.abs(sv), axis=2)
740
  else:
741
  sv_abs = np.abs(sv)
742
 
@@ -747,6 +745,7 @@ def run_pipeline(
747
  idf = pd.DataFrame({'Feature': fnames, 'Importance': fi}).sort_values('Importance', ascending=False)
748
  shap_imp[mn] = idf
749
 
 
750
  plt.figure(figsize=(10, max(6, TOPN * 0.3)))
751
  top_df = idf.head(TOPN).iloc[::-1]
752
  plt.barh(top_df['Feature'], top_df['Importance'], color='#2563eb', alpha=0.8)
@@ -760,7 +759,7 @@ def run_pipeline(
760
  except Exception as e:
761
  log(f" ⚠ {mn} SHAP失败: {e}")
762
 
763
- # ── Feature Ablation ──
764
  progress(0.72, desc="🧪 特征消融...")
765
  log(f"\n 🧪 特征消融 (仅最佳模型 {best_mn})...")
766
  ablation_data = None
@@ -768,6 +767,7 @@ def run_pipeline(
768
  imp_df = shap_imp[best_mn]
769
  top_feats = imp_df.head(TOPN)['Feature'].tolist()
770
  fcs = []; aucs_a = []
 
771
 
772
  for nf in range(1, len(top_feats) + 1):
773
  Xsub = X[top_feats[:nf]]
@@ -788,31 +788,29 @@ def run_pipeline(
788
  fold_aucs.append(a)
789
  fcs.append(nf); aucs_a.append(np.mean(fold_aucs))
790
 
 
791
  full_auc = amr[best_mn]['mean_auc']
792
  opt_n = len(top_feats)
793
  for i, a in enumerate(aucs_a):
794
  if a >= full_auc * 0.95:
795
  opt_n = i + 1; break
796
 
797
- ablation_data = {'fcs': fcs, 'aucs': aucs_a, 'feats': top_feats,
798
- 'opt_n': opt_n, 'opt_feats': top_feats[:opt_n]}
799
  log(f" ✅ 最优特征数: {opt_n} (AUC={aucs_a[opt_n-1]:.4f} vs Full={full_auc:.4f})")
800
 
 
801
  plt.figure(figsize=(10, 7))
802
  plt.plot(fcs, aucs_a, 'o-', color='#2563eb', lw=2, ms=5)
803
- plt.scatter([opt_n], [aucs_a[opt_n-1]], s=200, marker='*',
804
- color='#ef4444', edgecolors='black', lw=2, zorder=5)
805
- plt.axhline(y=full_auc, color='gray', ls='--', lw=1, alpha=0.5,
806
- label=f'Full AUC={full_auc:.3f}')
807
  plt.xlabel('Number of Features', fontsize=13); plt.ylabel('Macro AUC', fontsize=13)
808
- plt.title(f'Feature Ablation — {best_mn} (★ Optimal={opt_n})',
809
- fontsize=14, fontweight='bold')
810
  plt.legend(fontsize=11); plt.grid(True, alpha=0.15); plt.tight_layout()
811
  plt.savefig(os.path.join(rf, 'ablation.pdf'), format='pdf', bbox_inches='tight')
812
  plt.savefig(os.path.join(rf, 'ablation.png'), format='png', bbox_inches='tight', dpi=150)
813
  plt.close()
814
 
815
- # ── External Validation ──
816
  val_files_list = [vf for vf in [val_file1, val_file2, val_file3] if vf is not None]
817
  final_feats = ablation_data['opt_feats'] if ablation_data else fnames
818
 
@@ -828,6 +826,7 @@ def run_pipeline(
828
  vcol2_is_id = (vcol2.dtype == 'object') or (vcol2.nunique() / len(vcol2) > 0.5)
829
  Xe = ed.iloc[:, 2:] if vcol2_is_id else ed.iloc[:, 1:]
830
 
 
831
  ye = ye_raw.map(label_map)
832
  if ye.isna().any():
833
  log(f" ⚠ 验证集 {vi} 含有训练集中不存在的标签,已跳过")
@@ -844,60 +843,52 @@ def run_pipeline(
844
  ye_np = ye.values
845
 
846
  metrics = compute_multiclass_metrics(ye_np, yed, yep, class_indices)
847
- log(f" ✅ AUC={metrics['Macro_AUC']:.4f} Acc={metrics['Accuracy']:.4f}"
848
- f" F1={metrics['Macro_F1']:.4f} Kappa={metrics['Kappa']:.4f}")
849
 
850
  sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
851
  tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
852
 
853
- plot_multiclass_roc(ye_np, yep, class_indices,
854
- f'ROC — {tag} ({best_mn})', f'roc{sfx}', rf)
855
- plot_multiclass_pr(ye_np, yep, class_indices,
856
- f'PR — {tag} ({best_mn})', f'pr{sfx}', rf)
857
- plot_confusion_matrix(ye_np, yed, class_indices,
858
- f'CM — {tag} ({best_mn})', f'cm{sfx}', rf)
859
 
860
  with pd.ExcelWriter(os.path.join(rf, f'validation{sfx}.xlsx'), engine='openpyxl') as w:
861
  pd.DataFrame([{'Model': best_mn, 'N_Features': len(final_feats),
862
  'Macro_AUC': metrics['Macro_AUC'], 'Accuracy': metrics['Accuracy'],
863
  'Macro_F1': metrics['Macro_F1'], 'Weighted_F1': metrics['Weighted_F1'],
864
  'Kappa': metrics['Kappa']}]).to_excel(w, sheet_name='Metrics', index=False)
865
- pd.DataFrame(metrics['report']).T.to_excel(w, sheet_name='Per_Class', index=True)
 
866
  pd.DataFrame({'Feature': final_feats}).to_excel(w, sheet_name='Features', index=False)
867
 
868
- # ── Save Results ──
869
  progress(0.92, desc="💾 保存结果...")
870
  log(f"\n 💾 保存结果...")
871
 
872
  with pd.ExcelWriter(os.path.join(rf, 'model_evaluation.xlsx'), engine='openpyxl') as w:
873
  for mn, r in amr.items():
874
  r['fold_df'].to_excel(w, sheet_name=mn, index=False)
875
- # Summary(新增 Train_AUC 列)
876
- sd = [{'Model': mn,
877
- 'CV_Macro_AUC': r['mean_auc'],
878
- 'Train_Macro_AUC': train_roc_summary.get(mn, ''),
879
- 'CV_Accuracy': r['mean_acc'],
880
- 'CV_Macro_F1': r['mean_f1'],
881
- 'Retained': 'Yes' if mn in retained else 'No',
882
  'Best': 'Best' if mn == best_mn else ''}
883
  for mn, r in amr.items()]
884
- pd.DataFrame(sd).sort_values('CV_Macro_AUC', ascending=False).to_excel(
885
- w, sheet_name='Summary', index=False)
886
  if len(bootstrap_df) > 0:
887
  bootstrap_df.to_excel(w, sheet_name='Bootstrap_Test', index=False)
888
- best_report = classification_report(
889
- amr[best_mn]['all_yt'], amr[best_mn]['all_yp'],
890
- labels=class_indices, output_dict=True, zero_division=0)
891
  pd.DataFrame(best_report).T.to_excel(w, sheet_name=f'{best_mn}_PerClass', index=True)
892
 
893
  if ablation_data:
894
  with pd.ExcelWriter(os.path.join(rf, 'feature_ablation.xlsx'), engine='openpyxl') as w:
895
- pd.DataFrame({'N': ablation_data['fcs'], 'AUC': ablation_data['aucs']}).to_excel(
896
- w, sheet_name='Ablation', index=False)
897
  for mn, idf in shap_imp.items():
898
  idf.to_excel(w, sheet_name=f'{mn}_Imp', index=False)
899
 
900
- # Save params
901
  with open(os.path.join(rf, 'best_params.txt'), 'w', encoding='utf-8') as f:
902
  f.write(f"Task: {task_type} Classification ({n_classes} classes)\n")
903
  f.write(f"Classes: {classes}\n")
@@ -906,8 +897,7 @@ def run_pipeline(
906
  f.write(f"Retained Models: {', '.join(retained)} ({len(retained)}/{nm})\n\n")
907
  for mn in mcfg:
908
  status = "* Best" if mn == best_mn else ("Retained" if mn in retained else "Excluded")
909
- f.write(f"Model: {mn} | CV_AUC={amr[mn]['mean_auc']:.4f}"
910
- f" | Train_AUC={train_roc_summary.get(mn, 'N/A')} | {status}\n")
911
  bp = bpd[mn]
912
  if isinstance(bp, dict):
913
  for k, v in bp.items(): f.write(f" {k}: {v}\n")
@@ -919,40 +909,24 @@ def run_pipeline(
919
  f.write("=" * 50 + "\n")
920
  for _, row in bootstrap_df.iterrows():
921
  f.write(f" {row['Model_A']} vs {row['Model_B']}: ")
922
- f.write(f"dAUC={row['AUC_Diff']:+.4f} "
923
- f"95%CI=[{row['CI_95_Low']:+.4f},{row['CI_95_High']:+.4f}] ")
924
  f.write(f"P={row['P_value']:.4f} -> {row['Decision']}\n")
925
  if ablation_data:
926
- f.write(f"\nOptimal Features ({ablation_data['opt_n']}): "
927
- f"{', '.join(ablation_data['opt_feats'])}\n")
928
- f.write(f"\n{'='*50}\n")
929
- f.write(f"Best Model [{best_mn}] Train vs Internal CV\n")
930
- f.write(f"{'='*50}\n")
931
- f.write(f" Train → AUC={metrics_train['Macro_AUC']:.4f}"
932
- f" Acc={metrics_train['Accuracy']:.4f}"
933
- f" F1={metrics_train['Macro_F1']:.4f}"
934
- f" Kappa={metrics_train['Kappa']:.4f}\n")
935
- f.write(f" CV-Val → AUC={metrics_cv['Macro_AUC']:.4f}"
936
- f" Acc={metrics_cv['Accuracy']:.4f}"
937
- f" F1={metrics_cv['Macro_F1']:.4f}"
938
- f" Kappa={metrics_cv['Kappa']:.4f}\n")
939
-
940
- # Save model
941
  pickle.dump({
942
  'model_name': best_mn, 'model': tms[best_mn], 'best_params': bpd[best_mn],
943
  'classes': classes, 'n_classes': n_classes, 'label_map': label_map,
944
  'features': final_feats, 'task_type': task_type
945
  }, open(os.path.join(rf, f'model_{best_mn}.pkl'), 'wb'))
946
 
947
- # ── ZIP ──
948
  progress(0.97, desc="📦 打包ZIP...")
949
- zp = os.path.join(tempfile.gettempdir(),
950
- f"ml_results_{int(time.time())}_{os.getpid()}.zip")
951
  with zipfile.ZipFile(zp, 'w', zipfile.ZIP_DEFLATED) as zf:
952
  for root, _, files in os.walk(rf):
953
- for fn in files:
954
- zf.write(os.path.join(root, fn),
955
- os.path.relpath(os.path.join(root, fn), rf))
956
 
957
  nf = sum(len(f) for _, _, f in os.walk(rf))
958
  shutil.rmtree(rf, ignore_errors=True); gc.collect()
@@ -973,7 +947,7 @@ def run_pipeline(
973
 
974
 
975
  # ============================================================================
976
- # Gradio UI
977
  # ============================================================================
978
  CUSTOM_CSS = """
979
  .header-banner {
@@ -1029,7 +1003,6 @@ with gr.Blocks(
1029
  <div class="pipeline-box">
1030
  <strong>📋 流程:</strong>
1031
  <code>选择分类数</code> → <code>模型训练</code> → <code>交叉验证</code> →
1032
- <code>训练集ROC/PR</code> → <code>Train vs CV对比</code> →
1033
  <code>SHAP分析</code> → <code>特征消融</code> → <code>外部验证</code>
1034
  &nbsp;&nbsp;|&nbsp;&nbsp;
1035
  <strong>CSV格式:</strong> 第1列=标签(整数), 第2列=ID, 第3列起=特征
@@ -1063,22 +1036,20 @@ with gr.Blocks(
1063
  info="RF=随机森林 DT=决策树 KNN=K近邻 XGB=XGBoost AdaBoost LR=逻辑回归 NB=朴素贝叶斯 SVM=支持向量机",
1064
  )
1065
  with gr.Row():
1066
- btn_all = gr.Button("🔘 全选", size="sm", variant="secondary")
1067
- btn_tree = gr.Button("🌲 树模型", size="sm", variant="secondary")
1068
  btn_linear = gr.Button("📐 线性模型", size="sm", variant="secondary")
1069
- btn_top4 = gr.Button("⚡ 经典四模型", size="sm", variant="secondary")
1070
  btn_all.click(lambda: ALL_MODEL_NAMES, outputs=model_selector)
1071
  btn_tree.click(lambda: ['RF','DT','XGB','AdaBoost'], outputs=model_selector)
1072
  btn_linear.click(lambda: ['LR','SVM','NB'], outputs=model_selector)
1073
  btn_top4.click(lambda: ['RF','XGB','LR','SVM'], outputs=model_selector)
1074
 
1075
  gr.HTML('<div class="section-title">⚙️ 参数配置</div>')
1076
- enable_tuning = gr.Checkbox(
1077
- value=False,
1078
- label="启用超参数调优 (GridSearchCV) ⚠️ 开启后运行时间显著增加")
1079
  with gr.Row():
1080
  cv_folds = gr.Slider(3, 10, value=5, step=1, label="交叉验证折数")
1081
- top_n = gr.Slider(5, 50, value=20, step=1, label="SHAP 前 N 个特征")
1082
  shap_sz = gr.Slider(30, 200, value=80, step=10, label="SHAP 采样数量")
1083
 
1084
  run_btn = gr.Button("🚀 开始分析", variant="primary", size="lg")
@@ -1102,7 +1073,7 @@ with gr.Blocks(
1102
  )
1103
 
1104
  # ============================================================================
1105
- # Authentication
1106
  # ============================================================================
1107
  from datetime import datetime
1108
 
 
431
  log(f"\n{'━'*50}")
432
  log(f" ✅ {nm} 个模型训练完成")
433
 
434
+ # ── ROC Curves ── 【原有代码,原封不动】
435
+ progress(0.42, desc="📈 ROC曲线...")
436
+ log(f"\n 📈 绘制图表...")
437
+ for mn in mnames:
438
+ r = amr[mn]
439
+ plot_multiclass_roc(r['all_yt'], r['all_yproba'], class_indices,
440
+ f'ROC — {mn} ({task_type}, Macro AUC={r["mean_auc"]:.3f})', f'roc_{mn}', rf)
441
+
442
+ # Combined ROC (macro per model) 【原有代码,原封不动】
443
+ plt.figure(figsize=(10, 8))
444
+ for i, mn in enumerate(mnames):
445
+ r = amr[mn]
446
+ y_bin = label_binarize(r['all_yt'], classes=class_indices)
447
+ if n_classes == 2: y_bin = np.hstack([1 - y_bin, y_bin])
448
+ all_fpr = np.linspace(0, 1, 200); mean_tpr = np.zeros_like(all_fpr)
449
+ for c in range(n_classes):
450
+ f, t, _ = roc_curve(y_bin[:, c], r['all_yproba'][:, c])
451
+ mean_tpr += np.interp(all_fpr, f, t)
452
+ mean_tpr /= n_classes; mean_tpr[-1] = 1.0
453
+ ma = auc_score(all_fpr, mean_tpr)
454
+ plt.plot(all_fpr, mean_tpr, color=COLORS[i%8], lw=2.5, label=f'{mn} (Macro AUC={ma:.3f})')
455
+ plt.plot([0,1],[0,1],'--',color='#ccc',lw=1)
456
+ plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
457
+ plt.xlabel('FPR',fontsize=13); plt.ylabel('TPR',fontsize=13)
458
+ plt.title(f'ROC — All Models ({task_type})',fontsize=14,fontweight='bold')
459
+ plt.legend(loc='lower right',fontsize=10); plt.grid(True,alpha=0.15); plt.tight_layout()
460
+ plt.savefig(os.path.join(rf,'roc_all.pdf'),format='pdf',bbox_inches='tight',dpi=300)
461
+ plt.savefig(os.path.join(rf,'roc_all.png'),format='png',bbox_inches='tight',dpi=150)
462
+ plt.close()
463
+
464
+ # ── PR Curves ── 【原有代码,原封不动】
465
+ progress(0.48, desc="📈 PR曲线...")
466
+ for mn in mnames:
467
+ r = amr[mn]
468
+ plot_multiclass_pr(r['all_yt'], r['all_yproba'], class_indices,
469
+ f'PR — {mn} ({task_type})', f'pr_{mn}', rf)
470
+
471
+ # ── Confusion Matrices ── 【原有代码,原封不动】
472
+ progress(0.52, desc="📊 混淆矩阵...")
473
+ for mn in mnames:
474
+ r = amr[mn]
475
+ plot_confusion_matrix(r['all_yt'], r['all_yp'], class_indices,
476
+ f'CM — {mn} (Acc={r["mean_acc"]:.3f})', f'cm_{mn}', rf)
477
+
478
+ # ── Bootstrap AUC Test ── 【原有代码,原封不动】
479
+ progress(0.55, desc="🔬 Bootstrap AUC 检验...")
480
+ best_mn = max(amr, key=lambda x: amr[x]['mean_auc'])
481
+ best_auc = amr[best_mn]['mean_auc']
482
+ log(f"\n 🏆 最佳模型: {best_mn} (Macro AUC={best_auc:.4f})")
483
+ log(f" 🔬 Bootstrap 检验 (n=2000, α=0.05)...")
484
+
485
+ ALPHA = 0.05
486
+ bootstrap_results = []
487
+ retained = [best_mn]
488
+
489
+ for om in mnames:
490
+ if om == best_mn:
491
+ continue
492
+ p_val, auc_a, auc_b, ci_lo, ci_hi = bootstrap_auc_test(
493
+ amr[best_mn]['all_yt'],
494
+ amr[best_mn]['all_yproba'],
495
+ amr[om]['all_yproba'],
496
+ class_indices, n_bootstrap=2000
497
+ )
498
+ if p_val >= ALPHA:
499
+ retained.append(om)
500
+ dec = "Retained"
501
+ else:
502
+ dec = "Excluded"
503
+
504
+ bootstrap_results.append({
505
+ 'Model_A': best_mn, 'AUC_A': auc_a,
506
+ 'Model_B': om, 'AUC_B': auc_b,
507
+ 'AUC_Diff': auc_a - auc_b,
508
+ 'CI_95_Low': ci_lo, 'CI_95_High': ci_hi,
509
+ 'P_value': p_val, 'Decision': dec
510
+ })
511
+ log(f" {best_mn} vs {om}: ΔAUC={auc_a-auc_b:+.4f} 95%CI=[{ci_lo:+.4f},{ci_hi:+.4f}] P={p_val:.4f} → {dec}")
512
+
513
+ bootstrap_df = pd.DataFrame(bootstrap_results).sort_values('P_value', ascending=False) if bootstrap_results else pd.DataFrame()
514
+ log(f" ✅ 保留 {len(retained)}/{nm} 个模型: {', '.join(retained)}")
515
+
516
+ # ====================================================================
517
+ # ★★★ 新增 Part-1:训练集全模型 ROC / PR 曲线
518
+ # 新文件名前缀 train_roc_* / train_pr_*,与原有文件名零冲突
519
+ # ====================================================================
520
+ progress(0.57, desc="📈 [新增] 训练集ROC/PR曲线...")
521
+ log(f"\n 📈 [新增] 各模型训练集(in-sample)ROC / PR 曲线...")
522
+
523
+ # 两个内部辅助函数,仅用于叠加绘图数据准备
524
+ def _macro_roc_arrays(yt, yp, nc, cls_idx):
525
  y_b = label_binarize(yt, classes=cls_idx)
526
  if nc == 2:
527
  y_b = np.hstack([1 - y_b, y_b])
 
533
  mean_tpr /= nc; mean_tpr[-1] = 1.0
534
  return all_fpr, mean_tpr, auc_score(all_fpr, mean_tpr)
535
 
536
+ def _macro_pr_arrays(yt, yp, nc, cls_idx):
537
  y_b = label_binarize(yt, classes=cls_idx)
538
  if nc == 2:
539
  y_b = np.hstack([1 - y_b, y_b])
 
545
  mean_prec /= nc
546
  return all_rec, mean_prec
547
 
548
+ _tr_roc = {} # mn -> (fpr, tpr, auc) 供汇总图使用
549
+ _tr_pr = {} # mn -> (rec, prec) 供汇总图使用
 
 
 
 
 
 
 
550
 
551
  for mn in mnames:
552
+ yp_tr = tms[mn].predict_proba(X.values)
553
 
554
+ # 每个模型独立图:各类别曲线 + macro(复用已有绘函数,仅前缀不同)
555
  plot_multiclass_roc(
556
+ y_mapped.values, yp_tr, class_indices,
557
+ f'Train ROC — {mn} ({task_type})',
558
+ f'train_roc_{mn}', rf
559
  )
560
  plot_multiclass_pr(
561
+ y_mapped.values, yp_tr, class_indices,
562
+ f'Train PR — {mn} ({task_type})',
563
+ f'train_pr_{mn}', rf
564
  )
565
 
566
+ fpr_t, tpr_t, auc_t = _macro_roc_arrays(y_mapped.values, yp_tr, n_classes, class_indices)
567
+ rec_t, prec_t = _macro_pr_arrays(y_mapped.values, yp_tr, n_classes, class_indices)
568
+ _tr_roc[mn] = (fpr_t, tpr_t, auc_t)
569
+ _tr_pr[mn] = (rec_t, prec_t)
 
 
 
 
570
 
571
+ # 汇总训练集全模型 ROC(train_roc_all
572
  plt.figure(figsize=(10, 8))
573
  for i, mn in enumerate(mnames):
574
+ fpr_t, tpr_t, auc_t = _tr_roc[mn]
575
+ plt.plot(fpr_t, tpr_t, color=COLORS[i % 8], lw=2.5,
576
+ label=f'{mn} (Train Macro AUC={auc_t:.3f})')
577
  plt.plot([0, 1], [0, 1], '--', color='#ccc', lw=1)
578
  plt.xlim([-0.02, 1.02]); plt.ylim([-0.02, 1.02])
579
  plt.xlabel('False Positive Rate', fontsize=13)
 
585
  plt.savefig(os.path.join(rf, 'train_roc_all.png'), format='png', bbox_inches='tight', dpi=150)
586
  plt.close()
587
 
588
+ # 汇总训练集全模型 PR(train_pr_all
589
  plt.figure(figsize=(10, 8))
590
  for i, mn in enumerate(mnames):
591
+ rec_t, prec_t = _tr_pr[mn]
592
+ plt.plot(rec_t, prec_t, color=COLORS[i % 8], lw=2.5,
593
+ label=f'{mn} (Mean AP={prec_t.mean():.3f})')
594
  plt.xlim([-0.02, 1.02]); plt.ylim([-0.02, 1.02])
595
+ plt.xlabel('Recall', fontsize=13)
596
+ plt.ylabel('Precision', fontsize=13)
597
  plt.title(f'Train PR — All Models ({task_type})', fontsize=14, fontweight='bold')
598
  plt.legend(loc='lower left', fontsize=10)
599
  plt.grid(True, alpha=0.15); plt.tight_layout()
600
  plt.savefig(os.path.join(rf, 'train_pr_all.pdf'), format='pdf', bbox_inches='tight', dpi=300)
601
  plt.savefig(os.path.join(rf, 'train_pr_all.png'), format='png', bbox_inches='tight', dpi=150)
602
  plt.close()
603
+ log(f" ✅ 训练集 ROC/PR 已生成:各模型独立 + 汇总图(train_roc_all / train_pr_all)")
604
+
605
+ # ====================================================================
606
+ # ★★★ 新增 Part-2:最终模型best_mn训练集 vs 内部 CV 对比
607
+ # 新文件:roc_train_vs_cv_* / pr_train_vs_cv_* / cm_train_*
608
+ # train_vs_cv_*.xlsx
609
+ # 原有文件:roc_* / pr_* / cm_* / model_evaluation.xlsx 均不变
610
+ # ====================================================================
611
+ progress(0.59, desc="📊 [新增] 最终模型Train vs CV对比...")
612
+ log(f"\n 📊 [新增] 最终模型 [{best_mn}] 训练集 vs 内部验证集(CV holdout)...")
613
+
614
+ # 训练集预测(用全量 fit 后的模型 tms[best_mn])
615
+ yp_best_tr = tms[best_mn].predict_proba(X.values)
616
+ yd_best_tr = tms[best_mn].predict(X.values)
617
+ met_tr = compute_multiclass_metrics(
618
+ y_mapped.values, yd_best_tr, yp_best_tr, class_indices)
619
+
620
+ # 内部 CV holdout(直接取 amr 中已累积的结果,不重新运算)
621
+ yp_best_cv = amr[best_mn]['all_yproba']
622
+ yd_best_cv = amr[best_mn]['all_yp']
623
+ yt_best_cv = amr[best_mn]['all_yt']
624
+ met_cv = compute_multiclass_metrics(
625
+ yt_best_cv, yd_best_cv, yp_best_cv, class_indices)
626
+
627
+ log(f" Train → AUC={met_tr['Macro_AUC']:.4f} Acc={met_tr['Accuracy']:.4f}"
628
+ f" F1={met_tr['Macro_F1']:.4f} Kappa={met_tr['Kappa']:.4f}")
629
+ log(f" CV-Val AUC={met_cv['Macro_AUC']:.4f} Acc={met_cv['Accuracy']:.4f}"
630
+ f" F1={met_cv['Macro_F1']:.4f} Kappa={met_cv['Kappa']:.4f}")
631
+
632
+ # 对比 ROC(roc_train_vs_cv_{best_mn})
633
+ fpr_tb, tpr_tb, auc_tb = _macro_roc_arrays(
634
+ y_mapped.values, yp_best_tr, n_classes, class_indices)
635
+ fpr_cb, tpr_cb, auc_cb = _macro_roc_arrays(
636
+ yt_best_cv, yp_best_cv, n_classes, class_indices)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
637
 
638
  fig, ax = plt.subplots(figsize=(10, 8))
639
+ ax.plot(fpr_tb, tpr_tb, color='#e41a1c', lw=2.5,
640
+ label=f'Train set (Macro AUC={auc_tb:.3f})')
641
+ ax.plot(fpr_cb, tpr_cb, color='#377eb8', lw=2.5, linestyle='--',
642
+ label=f'Internal CV (Macro AUC={auc_cb:.3f})')
643
  ax.plot([0, 1], [0, 1], '--', color='#ccc', lw=1)
644
  ax.set_xlim([-0.02, 1.02]); ax.set_ylim([-0.02, 1.02])
645
  ax.set_xlabel('False Positive Rate', fontsize=13)
 
654
  format='png', bbox_inches='tight', dpi=150)
655
  plt.close()
656
 
657
+ # 对比 PR(pr_train_vs_cv_{best_mn})
658
+ rec_tb, prec_tb = _macro_pr_arrays(
659
+ y_mapped.values, yp_best_tr, n_classes, class_indices)
660
+ rec_cb, prec_cb = _macro_pr_arrays(
661
+ yt_best_cv, yp_best_cv, n_classes, class_indices)
662
 
663
  fig, ax = plt.subplots(figsize=(10, 8))
664
+ ax.plot(rec_tb, prec_tb, color='#e41a1c', lw=2.5,
665
+ label=f'Train set (Mean AP={prec_tb.mean():.3f})')
666
+ ax.plot(rec_cb, prec_cb, color='#377eb8', lw=2.5, linestyle='--',
667
+ label=f'Internal CV (Mean AP={prec_cb.mean():.3f})')
668
  ax.set_xlim([-0.02, 1.02]); ax.set_ylim([-0.02, 1.02])
669
  ax.set_xlabel('Recall', fontsize=13)
670
  ax.set_ylabel('Precision', fontsize=13)
 
678
  format='png', bbox_inches='tight', dpi=150)
679
  plt.close()
680
 
681
+ # 训练集混淆矩阵(cm_train_{best_mn}
682
  plot_confusion_matrix(
683
+ y_mapped.values, yd_best_tr, class_indices,
684
+ f'Train CM — {best_mn} (Acc={met_tr["Accuracy"]:.3f})',
685
  f'cm_train_{best_mn}', rf
686
  )
687
 
688
+ # 指标汇总 Excel(train_vs_cv_{best_mn}.xlsx,独立新文件)
689
+ with pd.ExcelWriter(
690
+ os.path.join(rf, f'train_vs_cv_{best_mn}.xlsx'),
691
+ engine='openpyxl') as w:
692
+ pd.DataFrame([
693
+ {'Split': 'Train', 'Model': best_mn,
694
+ 'Macro_AUC': met_tr['Macro_AUC'], 'Accuracy': met_tr['Accuracy'],
695
+ 'Macro_F1': met_tr['Macro_F1'], 'Weighted_F1': met_tr['Weighted_F1'],
696
+ 'Kappa': met_tr['Kappa']},
 
697
  {'Split': 'Internal_CV', 'Model': best_mn,
698
+ 'Macro_AUC': met_cv['Macro_AUC'], 'Accuracy': met_cv['Accuracy'],
699
+ 'Macro_F1': met_cv['Macro_F1'], 'Weighted_F1': met_cv['Weighted_F1'],
700
+ 'Kappa': met_cv['Kappa']},
701
+ ]).to_excel(w, sheet_name='Summary', index=False)
702
+ pd.DataFrame(met_tr['report']).T.to_excel(w, sheet_name='Train_PerClass', index=True)
703
+ pd.DataFrame(met_cv['report']).T.to_excel(w, sheet_name='CV_PerClass', index=True)
 
 
 
704
  amr[best_mn]['fold_df'].to_excel(w, sheet_name='CV_FoldDetail', index=False)
705
 
706
+ log(f" ✅ Train vs CV 对比图及汇总数据已保存 → train_vs_cv_{best_mn}.xlsx")
707
+ # ====================================================================
708
+ # ★★★ 新增结束
709
+ # ====================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
710
 
711
+ # ── SHAP ── 【原有代码,原封不动】
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712
  progress(0.62, desc="🔥 SHAP分析...")
713
  log(f"\n 🔥 SHAP特征分析 (保留模型中 Top 3)...")
714
  shap_imp = {}
715
+ # SHAP for top 3 retained models
716
  models_for_shap = sorted(retained, key=lambda x: amr[x]['mean_auc'], reverse=True)[:3]
717
 
718
  for si, mn in enumerate(models_for_shap):
 
729
  exp = shap.KernelExplainer(lambda x, m=mo: m.predict_proba(x), bg)
730
  sv = exp.shap_values(Xs)
731
 
732
+ # Handle SHAP output: could be list of arrays (one per class) or 3D array
733
  if isinstance(sv, list):
734
+ # Average absolute SHAP across all classes
735
  sv_abs = np.mean([np.abs(s) for s in sv], axis=0)
736
  elif sv.ndim == 3:
737
+ sv_abs = np.mean(np.abs(sv), axis=2) # (samples, features)
738
  else:
739
  sv_abs = np.abs(sv)
740
 
 
745
  idf = pd.DataFrame({'Feature': fnames, 'Importance': fi}).sort_values('Importance', ascending=False)
746
  shap_imp[mn] = idf
747
 
748
+ # Bar plot (works for any number of classes)
749
  plt.figure(figsize=(10, max(6, TOPN * 0.3)))
750
  top_df = idf.head(TOPN).iloc[::-1]
751
  plt.barh(top_df['Feature'], top_df['Importance'], color='#2563eb', alpha=0.8)
 
759
  except Exception as e:
760
  log(f" ⚠ {mn} SHAP失败: {e}")
761
 
762
+ # ── Feature Ablation (for best model only) ── 【原有代码,原封不动】
763
  progress(0.72, desc="🧪 特征消融...")
764
  log(f"\n 🧪 特征消融 (仅最佳模型 {best_mn})...")
765
  ablation_data = None
 
767
  imp_df = shap_imp[best_mn]
768
  top_feats = imp_df.head(TOPN)['Feature'].tolist()
769
  fcs = []; aucs_a = []
770
+ scoring = 'roc_auc_ovr' if n_classes > 2 else 'roc_auc'
771
 
772
  for nf in range(1, len(top_feats) + 1):
773
  Xsub = X[top_feats[:nf]]
 
788
  fold_aucs.append(a)
789
  fcs.append(nf); aucs_a.append(np.mean(fold_aucs))
790
 
791
+ # Find optimal: first N where AUC >= 95% of full AUC
792
  full_auc = amr[best_mn]['mean_auc']
793
  opt_n = len(top_feats)
794
  for i, a in enumerate(aucs_a):
795
  if a >= full_auc * 0.95:
796
  opt_n = i + 1; break
797
 
798
+ ablation_data = {'fcs': fcs, 'aucs': aucs_a, 'feats': top_feats, 'opt_n': opt_n, 'opt_feats': top_feats[:opt_n]}
 
799
  log(f" ✅ 最优特征数: {opt_n} (AUC={aucs_a[opt_n-1]:.4f} vs Full={full_auc:.4f})")
800
 
801
+ # Plot
802
  plt.figure(figsize=(10, 7))
803
  plt.plot(fcs, aucs_a, 'o-', color='#2563eb', lw=2, ms=5)
804
+ plt.scatter([opt_n], [aucs_a[opt_n-1]], s=200, marker='*', color='#ef4444', edgecolors='black', lw=2, zorder=5)
805
+ plt.axhline(y=full_auc, color='gray', ls='--', lw=1, alpha=0.5, label=f'Full AUC={full_auc:.3f}')
 
 
806
  plt.xlabel('Number of Features', fontsize=13); plt.ylabel('Macro AUC', fontsize=13)
807
+ plt.title(f'Feature Ablation — {best_mn} (★ Optimal={opt_n})', fontsize=14, fontweight='bold')
 
808
  plt.legend(fontsize=11); plt.grid(True, alpha=0.15); plt.tight_layout()
809
  plt.savefig(os.path.join(rf, 'ablation.pdf'), format='pdf', bbox_inches='tight')
810
  plt.savefig(os.path.join(rf, 'ablation.png'), format='png', bbox_inches='tight', dpi=150)
811
  plt.close()
812
 
813
+ # ── External Validation ── 【原有代码,原封不动】
814
  val_files_list = [vf for vf in [val_file1, val_file2, val_file3] if vf is not None]
815
  final_feats = ablation_data['opt_feats'] if ablation_data else fnames
816
 
 
826
  vcol2_is_id = (vcol2.dtype == 'object') or (vcol2.nunique() / len(vcol2) > 0.5)
827
  Xe = ed.iloc[:, 2:] if vcol2_is_id else ed.iloc[:, 1:]
828
 
829
+ # Map validation labels using same mapping
830
  ye = ye_raw.map(label_map)
831
  if ye.isna().any():
832
  log(f" ⚠ 验证集 {vi} 含有训练集中不存在的标签,已跳过")
 
843
  ye_np = ye.values
844
 
845
  metrics = compute_multiclass_metrics(ye_np, yed, yep, class_indices)
846
+ log(f" ✅ AUC={metrics['Macro_AUC']:.4f} Acc={metrics['Accuracy']:.4f} F1={metrics['Macro_F1']:.4f} Kappa={metrics['Kappa']:.4f}")
 
847
 
848
  sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
849
  tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
850
 
851
+ plot_multiclass_roc(ye_np, yep, class_indices, f'ROC — {tag} ({best_mn})', f'roc{sfx}', rf)
852
+ plot_multiclass_pr(ye_np, yep, class_indices, f'PR — {tag} ({best_mn})', f'pr{sfx}', rf)
853
+ plot_confusion_matrix(ye_np, yed, class_indices, f'CM — {tag} ({best_mn})', f'cm{sfx}', rf)
 
 
 
854
 
855
  with pd.ExcelWriter(os.path.join(rf, f'validation{sfx}.xlsx'), engine='openpyxl') as w:
856
  pd.DataFrame([{'Model': best_mn, 'N_Features': len(final_feats),
857
  'Macro_AUC': metrics['Macro_AUC'], 'Accuracy': metrics['Accuracy'],
858
  'Macro_F1': metrics['Macro_F1'], 'Weighted_F1': metrics['Weighted_F1'],
859
  'Kappa': metrics['Kappa']}]).to_excel(w, sheet_name='Metrics', index=False)
860
+ rpt = pd.DataFrame(metrics['report']).T
861
+ rpt.to_excel(w, sheet_name='Per_Class', index=True)
862
  pd.DataFrame({'Feature': final_feats}).to_excel(w, sheet_name='Features', index=False)
863
 
864
+ # ── Save Results ── 【原有代码,原封不动】
865
  progress(0.92, desc="💾 保存结果...")
866
  log(f"\n 💾 保存结果...")
867
 
868
  with pd.ExcelWriter(os.path.join(rf, 'model_evaluation.xlsx'), engine='openpyxl') as w:
869
  for mn, r in amr.items():
870
  r['fold_df'].to_excel(w, sheet_name=mn, index=False)
871
+ # Summary with retained status
872
+ sd = [{'Model': mn, 'Macro_AUC': r['mean_auc'], 'Accuracy': r['mean_acc'],
873
+ 'Macro_F1': r['mean_f1'], 'Retained': 'Yes' if mn in retained else 'No',
 
 
 
 
874
  'Best': 'Best' if mn == best_mn else ''}
875
  for mn, r in amr.items()]
876
+ pd.DataFrame(sd).sort_values('Macro_AUC', ascending=False).to_excel(w, sheet_name='Summary', index=False)
877
+ # Bootstrap test results
878
  if len(bootstrap_df) > 0:
879
  bootstrap_df.to_excel(w, sheet_name='Bootstrap_Test', index=False)
880
+ # Per-class report for best model
881
+ best_report = classification_report(amr[best_mn]['all_yt'], amr[best_mn]['all_yp'],
882
+ labels=class_indices, output_dict=True, zero_division=0)
883
  pd.DataFrame(best_report).T.to_excel(w, sheet_name=f'{best_mn}_PerClass', index=True)
884
 
885
  if ablation_data:
886
  with pd.ExcelWriter(os.path.join(rf, 'feature_ablation.xlsx'), engine='openpyxl') as w:
887
+ pd.DataFrame({'N': ablation_data['fcs'], 'AUC': ablation_data['aucs']}).to_excel(w, sheet_name='Ablation', index=False)
 
888
  for mn, idf in shap_imp.items():
889
  idf.to_excel(w, sheet_name=f'{mn}_Imp', index=False)
890
 
891
+ # Save params (English for SCI) 【原有代码,原封不动】
892
  with open(os.path.join(rf, 'best_params.txt'), 'w', encoding='utf-8') as f:
893
  f.write(f"Task: {task_type} Classification ({n_classes} classes)\n")
894
  f.write(f"Classes: {classes}\n")
 
897
  f.write(f"Retained Models: {', '.join(retained)} ({len(retained)}/{nm})\n\n")
898
  for mn in mcfg:
899
  status = "* Best" if mn == best_mn else ("Retained" if mn in retained else "Excluded")
900
+ f.write(f"Model: {mn} | AUC={amr[mn]['mean_auc']:.4f} | {status}\n")
 
901
  bp = bpd[mn]
902
  if isinstance(bp, dict):
903
  for k, v in bp.items(): f.write(f" {k}: {v}\n")
 
909
  f.write("=" * 50 + "\n")
910
  for _, row in bootstrap_df.iterrows():
911
  f.write(f" {row['Model_A']} vs {row['Model_B']}: ")
912
+ f.write(f"dAUC={row['AUC_Diff']:+.4f} 95%CI=[{row['CI_95_Low']:+.4f},{row['CI_95_High']:+.4f}] ")
 
913
  f.write(f"P={row['P_value']:.4f} -> {row['Decision']}\n")
914
  if ablation_data:
915
+ f.write(f"\nOptimal Features ({ablation_data['opt_n']}): {', '.join(ablation_data['opt_feats'])}\n")
916
+
917
+ # Save model 【原有代码,原封不动】
 
 
 
 
 
 
 
 
 
 
 
 
918
  pickle.dump({
919
  'model_name': best_mn, 'model': tms[best_mn], 'best_params': bpd[best_mn],
920
  'classes': classes, 'n_classes': n_classes, 'label_map': label_map,
921
  'features': final_feats, 'task_type': task_type
922
  }, open(os.path.join(rf, f'model_{best_mn}.pkl'), 'wb'))
923
 
924
+ # ── ZIP ── 【原有代码,原封不动】
925
  progress(0.97, desc="📦 打包ZIP...")
926
+ zp = os.path.join(tempfile.gettempdir(), f"ml_results_{int(time.time())}_{os.getpid()}.zip")
 
927
  with zipfile.ZipFile(zp, 'w', zipfile.ZIP_DEFLATED) as zf:
928
  for root, _, files in os.walk(rf):
929
+ for fn in files: zf.write(os.path.join(root, fn), os.path.relpath(os.path.join(root, fn), rf))
 
 
930
 
931
  nf = sum(len(f) for _, _, f in os.walk(rf))
932
  shutil.rmtree(rf, ignore_errors=True); gc.collect()
 
947
 
948
 
949
  # ============================================================================
950
+ # Gradio UI 【原有代码,原封不动】
951
  # ============================================================================
952
  CUSTOM_CSS = """
953
  .header-banner {
 
1003
  <div class="pipeline-box">
1004
  <strong>📋 流程:</strong>
1005
  <code>选择分类数</code> → <code>模型训练</code> → <code>交叉验证</code> →
 
1006
  <code>SHAP分析</code> → <code>特征消融</code> → <code>外部验证</code>
1007
  &nbsp;&nbsp;|&nbsp;&nbsp;
1008
  <strong>CSV格式:</strong> 第1列=标签(整数), 第2列=ID, 第3列起=特征
 
1036
  info="RF=随机森林 DT=决策树 KNN=K近邻 XGB=XGBoost AdaBoost LR=逻辑回归 NB=朴素贝叶斯 SVM=支持向量机",
1037
  )
1038
  with gr.Row():
1039
+ btn_all = gr.Button("🔘 全选", size="sm", variant="secondary")
1040
+ btn_tree = gr.Button("🌲 树模型", size="sm", variant="secondary")
1041
  btn_linear = gr.Button("📐 线性模型", size="sm", variant="secondary")
1042
+ btn_top4 = gr.Button("⚡ 经典四模型", size="sm", variant="secondary")
1043
  btn_all.click(lambda: ALL_MODEL_NAMES, outputs=model_selector)
1044
  btn_tree.click(lambda: ['RF','DT','XGB','AdaBoost'], outputs=model_selector)
1045
  btn_linear.click(lambda: ['LR','SVM','NB'], outputs=model_selector)
1046
  btn_top4.click(lambda: ['RF','XGB','LR','SVM'], outputs=model_selector)
1047
 
1048
  gr.HTML('<div class="section-title">⚙️ 参数配置</div>')
1049
+ enable_tuning = gr.Checkbox(value=False, label="启用超参数调优 (GridSearchCV) ⚠️ 开启后运行时间显著增加")
 
 
1050
  with gr.Row():
1051
  cv_folds = gr.Slider(3, 10, value=5, step=1, label="交叉验证折数")
1052
+ top_n = gr.Slider(5, 50, value=20, step=1, label="SHAP 前 N 个特征")
1053
  shap_sz = gr.Slider(30, 200, value=80, step=10, label="SHAP 采样数量")
1054
 
1055
  run_btn = gr.Button("🚀 开始分析", variant="primary", size="lg")
 
1073
  )
1074
 
1075
  # ============================================================================
1076
+ # Authentication 【原有代码,原封不动】
1077
  # ============================================================================
1078
  from datetime import datetime
1079