fudan-renjun's picture
Update app.py
c2c08a7 verified
"""
ML Multi-Class Classification Pipeline (2-8 classes)
Eye & ENT Hospital of Fudan University — Laboratory Medicine, Ren Jun
Gradio 5.12.0 + Python 3.11
Changelog v3 (vs v2):
[v3-1] compute_multiclass_metrics now returns full per-class and macro
AUC, Accuracy, Sensitivity (Recall), Specificity, Precision (PPV),
NPV, F1 for every class, plus macro/weighted averages.
[v3-2] Per-fold metrics table extended with all new indicators.
[v3-3] Summary sheets (Summary_InternalVal, Train_vs_InternalVal) carry
all new macro indicators.
[v3-4] Per-class detail sheets written for every model (train + val).
[v3-5] External validation Excel extended with all new indicators.
[v3-6] best_params.txt log extended with all new indicators.
[v3-7] Console log shows key new indicators.
Previous fixes retained:
[FIX-1] XGBoost num_class=None bug
[FIX-2] Bootstrap p-value centered on 0
[FIX-3] SHAP 3D axis detection
[FIX-4] Per-model train-set ROC/PR/CM
[FIX-5] Best-model Train vs InternalVal overlay plots
[FIX-6] Train_vs_InternalVal Excel sheet
[FIX-7] Guest account expiry updated
"""
import gc
import os
import pickle
import shutil
import tempfile
import threading
import time
import traceback
import warnings
import zipfile
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.metrics import (
    roc_auc_score, confusion_matrix, roc_curve,
    auc as auc_score, precision_recall_curve, average_precision_score,
    classification_report, accuracy_score, f1_score,
    cohen_kappa_score, precision_score, recall_score
)
from sklearn.preprocessing import label_binarize
from xgboost import XGBClassifier
import seaborn as sns
import shap
import gradio as gr
# Silence all Python warnings process-wide.
warnings.filterwarnings('ignore')

# Publication-quality Matplotlib defaults, applied once when the module loads.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'DejaVu Serif', 'serif'],
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
    'axes.unicode_minus': False,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'axes.linewidth': 1.2,
    'xtick.major.width': 1.0,
    'ytick.major.width': 1.0,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
})
# ============================================================================
# Cache Cleanup
# ============================================================================
# Temp entries prefixed "ml_" that are older than this many minutes get deleted.
CLEANUP_MAX_AGE_MINUTES = 30
# Seconds the background cleaner sleeps between sweeps (10 minutes).
CLEANUP_INTERVAL_SECONDS = 10 * 60
def cleanup_old_temp_files(max_age_minutes=None, tmp_dir=None, prefix="ml_"):
    """Delete stale pipeline temp entries, then run the garbage collector.

    Scans ``tmp_dir`` (default: ``tempfile.gettempdir()``) for entries whose
    name starts with ``prefix`` and whose modification time is older than
    ``max_age_minutes`` (default: ``CLEANUP_MAX_AGE_MINUTES``, read at call
    time). Real directories are removed recursively. Files and symlinks are
    unlinked, and a symlink is never followed into its target.

    Each entry is handled independently. An entry that disappears mid-sweep
    or cannot be removed is skipped rather than aborting the whole sweep.
    A missing or unreadable ``tmp_dir`` is treated as empty.

    NOTE(review): staleness is judged by the entry's own mtime. A run
    directory that has not been written to for longer than the threshold
    is therefore deleted even if that run is still in progress.
    """
    if max_age_minutes is None:
        max_age_minutes = CLEANUP_MAX_AGE_MINUTES
    tmp = tempfile.gettempdir() if tmp_dir is None else tmp_dir
    # An entry is stale when (now - mtime) > max_age, i.e. mtime < cutoff.
    cutoff = time.time() - max_age_minutes * 60

    try:
        entries = os.listdir(tmp)
    except OSError:
        entries = []

    for item in entries:
        if not item.startswith(prefix):
            continue
        path = os.path.join(tmp, item)
        try:
            if os.path.getmtime(path) >= cutoff:
                continue
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path, ignore_errors=True)
            elif os.path.isfile(path) or os.path.islink(path):
                os.remove(path)
        except OSError:
            # Raced with another sweep/run, or no permission: skip this one.
            continue
    gc.collect()
def periodic_cleanup():
    """Sweep stale temp files forever, sleeping between sweeps.

    Sleeps ``CLEANUP_INTERVAL_SECONDS`` and then calls
    ``cleanup_old_temp_files()``, in an endless loop. This is the target of
    the daemon thread started below. An exception from one sweep is printed
    and swallowed so that a single failure cannot terminate the background
    thread, and with it all future cleanup.
    """
    while True:
        time.sleep(CLEANUP_INTERVAL_SECONDS)
        try:
            cleanup_old_temp_files()
        except Exception:  # thread top-level boundary: log and keep running
            traceback.print_exc()


# Daemon thread, so it does not keep the interpreter alive at exit.
_ct = threading.Thread(target=periodic_cleanup, daemon=True)
_ct.start()
# ============================================================================
# [v3-1] Extended metrics: Sensitivity, Specificity, PPV, NPV per class
# ============================================================================
def compute_per_class_sens_spec_ppv_npv(y_true, y_pred, y_proba, classes):
    """Compute one-vs-rest confusion counts and rates for every class.

    For each class ``c`` the task is treated as binary (c vs. rest):
        TP = predicted c  AND true c
        FP = predicted c  AND true != c
        FN = predicted != c AND true c
        TN = predicted != c AND true != c

    Args:
        y_true: true labels (array-like).
        y_pred: predicted labels (array-like), in the same encoding as
            ``classes``.
        y_proba: score matrix. Column ``i`` (``y_proba[:, i]``) is used as
            the score for ``classes[i]``.
        classes: ordered list of class labels.

    Returns:
        ``(per_class, macro)``. ``per_class`` maps each class label to a
        dict with TP/FP/FN/TN counts plus Sensitivity (recall), Specificity,
        PPV (precision), NPV, F1 and one-vs-rest AUC. ``macro`` holds the
        unweighted mean of each rate under ``Macro_<metric>`` keys.

        A rate whose denominator is zero is reported as 0.0. A per-class
        AUC that sklearn cannot compute (e.g. the class is absent from
        ``y_true``) is reported as 0.0.
    """
    n_classes = len(classes)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    y_proba = np.asarray(y_proba)
    y_bin = label_binarize(y_true, classes=classes)
    if n_classes == 2:
        # label_binarize yields a single column for binary problems;
        # expand to one indicator column per class.
        y_bin = np.hstack([1 - y_bin, y_bin])

    def _ratio(num, den):
        return num / den if den > 0 else 0.0

    per_class = {}
    for i, c in enumerate(classes):
        yt_b = y_bin[:, i]                 # true indicator for class c
        yp_b = (y_pred == c).astype(int)   # predicted indicator for class c
        TP = int(np.sum((yt_b == 1) & (yp_b == 1)))
        FP = int(np.sum((yt_b == 0) & (yp_b == 1)))
        FN = int(np.sum((yt_b == 1) & (yp_b == 0)))
        TN = int(np.sum((yt_b == 0) & (yp_b == 0)))
        sens = _ratio(TP, TP + FN)   # Sensitivity = Recall
        spec = _ratio(TN, TN + FP)   # Specificity
        ppv = _ratio(TP, TP + FP)    # PPV = Precision
        npv = _ratio(TN, TN + FN)    # NPV
        f1 = (2 * ppv * sens / (ppv + sens)) if (ppv + sens) > 0 else 0.0
        try:
            auc_c = roc_auc_score(yt_b, y_proba[:, i])
        except (ValueError, IndexError):
            # Undefined AUC (single class present) or missing score column.
            auc_c = 0.0
        per_class[c] = {
            'TP': TP, 'FP': FP, 'FN': FN, 'TN': TN,
            'Sensitivity': sens, 'Specificity': spec,
            'PPV': ppv, 'NPV': npv, 'F1': f1, 'AUC': auc_c,
        }

    macro = {
        f'Macro_{metric}': np.mean([per_class[c][metric] for c in classes])
        for metric in ('Sensitivity', 'Specificity', 'PPV', 'NPV', 'F1', 'AUC')
    }
    return per_class, macro
def compute_multiclass_metrics(y_true, y_pred, y_proba, classes):
    """Compute macro and per-class classification metrics.

    [v3-1] Returns AUC, Accuracy, Sensitivity, Specificity, Precision (PPV),
    NPV and F1, both macro-averaged and per class, plus weighted F1 and
    Cohen's kappa.

    Args:
        y_true: true labels.
        y_pred: predicted labels.
        y_proba: probability matrix with one column per entry of ``classes``.
            For binary tasks column 1 is the score of ``classes[1]``.
        classes: ordered list of class labels.

    Returns:
        dict with keys ``Accuracy``, ``Macro_AUC``, ``Macro_Sensitivity``,
        ``Macro_Specificity``, ``Macro_PPV``, ``Macro_NPV``, ``Macro_F1``,
        ``Weighted_F1``, ``Kappa``, ``per_class`` (dict keyed by class value)
        and ``report`` (sklearn ``classification_report`` as a dict).
        ``Macro_AUC`` is 0.0 when sklearn cannot compute it, e.g. when a
        class is missing from ``y_true``.
    """
    n_classes = len(classes)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    y_proba = np.asarray(y_proba)

    acc = accuracy_score(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)

    try:
        if n_classes == 2:
            macro_auc = roc_auc_score(y_true, y_proba[:, 1])
        else:
            macro_auc = roc_auc_score(y_true, y_proba,
                                      multi_class='ovr', average='macro')
    except (ValueError, IndexError):
        macro_auc = 0.0

    # Macro precision/recall/F1 come from the per-class helper below.
    # Only the weighted F1 needs sklearn's aggregate.
    f1_weighted = f1_score(y_true, y_pred, average='weighted',
                           zero_division=0, labels=classes)

    per_class, macro_ext = compute_per_class_sens_spec_ppv_npv(
        y_true, y_pred, y_proba, classes)

    report = classification_report(
        y_true, y_pred, labels=classes, output_dict=True, zero_division=0)

    return {
        # ── Macro aggregates ──
        'Accuracy': acc,
        'Macro_AUC': macro_auc,
        'Macro_Sensitivity': macro_ext['Macro_Sensitivity'],  # == Macro Recall
        'Macro_Specificity': macro_ext['Macro_Specificity'],
        'Macro_PPV': macro_ext['Macro_PPV'],                  # == Macro Precision
        'Macro_NPV': macro_ext['Macro_NPV'],
        'Macro_F1': macro_ext['Macro_F1'],
        'Weighted_F1': f1_weighted,
        'Kappa': kappa,
        # ── Per-class detail ──
        'per_class': per_class,
        'report': report,
    }
def metrics_to_flat_row(metrics, prefix=''):
    """Flatten the macro-level entries of a metrics dict into one table row.

    Args:
        metrics: dict as produced by ``compute_multiclass_metrics``.
        prefix: text prepended to every output column name.

    Returns:
        dict with columns AUC, Accuracy, Sensitivity, Specificity, PPV, NPV,
        F1, Weighted_F1 and Kappa, each prefixed with ``prefix``.
    """
    column_sources = (
        ('AUC', 'Macro_AUC'),
        ('Accuracy', 'Accuracy'),
        ('Sensitivity', 'Macro_Sensitivity'),
        ('Specificity', 'Macro_Specificity'),
        ('PPV', 'Macro_PPV'),
        ('NPV', 'Macro_NPV'),
        ('F1', 'Macro_F1'),
        ('Weighted_F1', 'Weighted_F1'),
        ('Kappa', 'Kappa'),
    )
    return {f'{prefix}{column}': metrics[source]
            for column, source in column_sources}
def per_class_df(metrics, classes):
    """Build a tidy per-class table from ``compute_multiclass_metrics`` output.

    One row per class (in ``classes`` order) holds the rates and confusion
    counts. A final ``'Macro'`` row holds the macro rates, and its count
    cells are empty strings.
    """
    rate_cols = ('AUC', 'Sensitivity', 'Specificity', 'PPV', 'NPV', 'F1')
    count_cols = ('TP', 'FP', 'FN', 'TN')

    table = []
    for cls in classes:
        stats = metrics['per_class'][cls]
        record = {'Class': cls}
        record.update({col: stats[col] for col in rate_cols})
        record.update({col: stats[col] for col in count_cols})
        table.append(record)

    summary = {'Class': 'Macro'}
    summary.update({col: metrics[f'Macro_{col}'] for col in rate_cols})
    summary.update(dict.fromkeys(count_cols, ''))
    table.append(summary)
    return pd.DataFrame(table)
# ============================================================================
# Plotting helpers
# ============================================================================
def plot_multiclass_roc(y_true, y_proba, classes, title, filepath_prefix, rf):
    """Draw one-vs-rest ROC curves for every class plus a macro-average curve.

    The macro curve is the mean of the per-class TPRs linearly interpolated
    onto the union of all per-class FPR points. The figure is saved as
    ``<filepath_prefix>.pdf`` (300 dpi) and ``<filepath_prefix>.png``
    (150 dpi) inside directory ``rf``.

    Returns:
        ``(macro_auc, auc_dict)``: trapezoidal area under the macro curve,
        and a dict mapping column index to per-class AUC.
    """
    palette = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3',
               '#ff7f00', '#a65628', '#f781bf', '#999999']
    k = len(classes)
    indicators = label_binarize(y_true, classes=classes)
    if k == 2:
        # Binary label_binarize gives one column; make one per class.
        indicators = np.hstack([1 - indicators, indicators])

    fpr_dict, tpr_dict, auc_dict = {}, {}, {}
    for idx in range(k):
        fpr, tpr, _ = roc_curve(indicators[:, idx], y_proba[:, idx])
        fpr_dict[idx] = fpr
        tpr_dict[idx] = tpr
        auc_dict[idx] = auc_score(fpr, tpr)

    grid = np.unique(np.concatenate(list(fpr_dict.values())))
    avg_tpr = sum(np.interp(grid, fpr_dict[idx], tpr_dict[idx])
                  for idx in range(k)) / k
    macro_auc = auc_score(grid, avg_tpr)

    plt.figure(figsize=(10, 8))
    for idx in range(k):
        plt.plot(fpr_dict[idx], tpr_dict[idx],
                 color=palette[idx % len(palette)], lw=2,
                 label=f'Class {classes[idx]} (AUC={auc_dict[idx]:.3f})')
    plt.plot(grid, avg_tpr, 'k--', lw=2.5,
             label=f'Macro Avg (AUC={macro_auc:.3f})')
    plt.plot([0, 1], [0, 1], '--', color='#cccccc', lw=1)
    plt.xlim([-0.02, 1.02])
    plt.ylim([-0.02, 1.02])
    plt.xlabel('False Positive Rate', fontsize=13)
    plt.ylabel('True Positive Rate', fontsize=13)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.legend(loc='lower right', fontsize=9)
    plt.grid(True, alpha=0.15)
    plt.tight_layout()
    for ext, dpi in (('pdf', 300), ('png', 150)):
        plt.savefig(os.path.join(rf, f'{filepath_prefix}.{ext}'),
                    format=ext, bbox_inches='tight', dpi=dpi)
    plt.close()
    return macro_auc, auc_dict
def plot_multiclass_pr(y_true, y_proba, classes, title, filepath_prefix, rf):
    """Plot one-vs-rest Precision-Recall curves, one per class.

    Each legend entry reports that class's average precision (AP) as
    computed by ``average_precision_score``. The trapezoidal area under the
    PR curve is deliberately not used: it is not AP and it is optimistic,
    because it linearly interpolates between operating points.

    The figure is saved as ``<filepath_prefix>.pdf`` (300 dpi) and
    ``<filepath_prefix>.png`` (150 dpi) inside directory ``rf``. The figure
    is always closed, even if plotting or saving fails. Returns None.
    """
    n_classes = len(classes)
    y_bin = label_binarize(y_true, classes=classes)
    if n_classes == 2:
        # Binary label_binarize gives one column; make one per class.
        y_bin = np.hstack([1 - y_bin, y_bin])
    COLORS = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3',
              '#ff7f00', '#a65628', '#f781bf', '#999999']
    fig = plt.figure(figsize=(10, 8))
    try:
        for i in range(n_classes):
            prec, rec, _ = precision_recall_curve(y_bin[:, i], y_proba[:, i])
            ap = average_precision_score(y_bin[:, i], y_proba[:, i])
            plt.plot(rec, prec, color=COLORS[i % len(COLORS)], lw=2,
                     label=f'Class {classes[i]} (AP={ap:.3f})')
        plt.xlim([-0.02, 1.02])
        plt.ylim([-0.02, 1.02])
        plt.xlabel('Recall', fontsize=13)
        plt.ylabel('Precision', fontsize=13)
        plt.title(title, fontsize=14, fontweight='bold')
        plt.legend(loc='lower left', fontsize=9)
        plt.grid(True, alpha=0.15)
        plt.tight_layout()
        plt.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'),
                    format='pdf', bbox_inches='tight', dpi=300)
        plt.savefig(os.path.join(rf, f'{filepath_prefix}.png'),
                    format='png', bbox_inches='tight', dpi=150)
    finally:
        # Release the figure even on failure so open figures do not pile up.
        plt.close(fig)
def plot_confusion_matrix(y_true, y_pred, classes, title, filepath_prefix, rf):
    """Render the confusion matrix as an annotated heatmap and save it.

    Rows are true labels and columns are predicted labels, both ordered by
    ``classes``. The figure grows with the number of classes. It is saved
    as ``<filepath_prefix>.pdf`` (300 dpi) and ``.png`` (150 dpi) in ``rf``.

    Returns:
        The confusion-matrix array from ``confusion_matrix``.
    """
    matrix = confusion_matrix(y_true, y_pred, labels=classes)
    k = len(classes)
    plt.figure(figsize=(max(6, k * 1.2), max(5, k * 1.0)))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        cbar=True,
        xticklabels=classes,
        yticklabels=classes,
        annot_kws={'fontsize': 11},
    )
    plt.xlabel('Predicted', fontsize=12)
    plt.ylabel('True', fontsize=12)
    plt.title(title, fontsize=13, fontweight='bold')
    plt.tight_layout()
    for ext, dpi in (('pdf', 300), ('png', 150)):
        plt.savefig(os.path.join(rf, f'{filepath_prefix}.{ext}'),
                    format=ext, bbox_inches='tight', dpi=dpi)
    plt.close()
    return matrix
def plot_train_vs_val_roc(y_train, train_proba, y_val, val_proba,
                          classes, model_name, filepath_prefix, rf):
    """Overlay the training-set ROC and the internal-validation (CV-OOF) ROC.

    Each curve is the macro average of the one-vs-rest TPRs, interpolated
    onto a 300-point FPR grid, with its final TPR forced to 1.0. The
    figure is saved as ``<filepath_prefix>.pdf``/``.png`` inside ``rf``.

    Returns:
        ``(auc_train, auc_val)``: trapezoidal area under each macro curve.
    """
    k = len(classes)
    grid = np.linspace(0, 1, 300)

    def _macro_curve(labels, scores):
        ind = label_binarize(labels, classes=classes)
        if k == 2:
            ind = np.hstack([1 - ind, ind])
        tpr_sum = np.zeros_like(grid)
        for col in range(k):
            fpr, tpr, _ = roc_curve(ind[:, col], scores[:, col])
            tpr_sum += np.interp(grid, fpr, tpr)
        tpr_avg = tpr_sum / k
        tpr_avg[-1] = 1.0
        return tpr_avg, auc_score(grid, tpr_avg)

    tpr_train, auc_tr = _macro_curve(y_train, train_proba)
    tpr_val, auc_vl = _macro_curve(y_val, val_proba)

    plt.figure(figsize=(10, 8))
    plt.plot(grid, tpr_train, color='#e41a1c', lw=2.5,
             label=f'Training set (Macro AUC={auc_tr:.3f})')
    plt.plot(grid, tpr_val, color='#377eb8', lw=2.5, linestyle='--',
             label=f'Internal validation / CV-OOF (Macro AUC={auc_vl:.3f})')
    plt.plot([0, 1], [0, 1], '--', color='#cccccc', lw=1)
    plt.xlim([-0.02, 1.02])
    plt.ylim([-0.02, 1.02])
    plt.xlabel('False Positive Rate', fontsize=13)
    plt.ylabel('True Positive Rate', fontsize=13)
    plt.title(f'ROC — Train vs Internal Validation — {model_name}',
              fontsize=14, fontweight='bold')
    plt.legend(loc='lower right', fontsize=11)
    plt.grid(True, alpha=0.15)
    plt.tight_layout()
    for ext, dpi in (('pdf', 300), ('png', 150)):
        plt.savefig(os.path.join(rf, f'{filepath_prefix}.{ext}'),
                    format=ext, bbox_inches='tight', dpi=dpi)
    plt.close()
    return auc_tr, auc_vl
def plot_train_vs_val_pr(y_train, train_proba, y_val, val_proba,
                         classes, model_name, filepath_prefix, rf):
    """Overlay the training-set PR curve and the internal-validation PR curve.

    For each set, the drawn curve is the mean of the one-vs-rest precision
    curves interpolated onto a 300-point recall grid. The reported
    "Macro AP" is the unweighted mean of per-class average precision
    (``average_precision_score``). It is not the trapezoidal area under the
    interpolated mean curve, which is not AP and tends to be optimistic.

    The figure is saved as ``<filepath_prefix>.pdf``/``.png`` inside ``rf``
    and is always closed, even if plotting or saving fails.

    Returns:
        ``(ap_train, ap_val)``: macro average precision for each set.
    """
    n_classes = len(classes)

    def macro_pr(y_true, y_proba):
        y_bin = label_binarize(y_true, classes=classes)
        if n_classes == 2:
            y_bin = np.hstack([1 - y_bin, y_bin])
        all_rec = np.linspace(0, 1, 300)
        mean_prec = np.zeros_like(all_rec)
        per_class_ap = []
        for i in range(n_classes):
            prec, rec, _ = precision_recall_curve(y_bin[:, i], y_proba[:, i])
            # Recall comes back in decreasing order; reverse it so np.interp
            # sees increasing x values.
            mean_prec += np.interp(all_rec, rec[::-1], prec[::-1])
            per_class_ap.append(
                average_precision_score(y_bin[:, i], y_proba[:, i]))
        mean_prec /= n_classes
        return all_rec, mean_prec, float(np.mean(per_class_ap))

    rec_tr, prec_tr, ap_tr = macro_pr(y_train, train_proba)
    rec_vl, prec_vl, ap_vl = macro_pr(y_val, val_proba)

    fig = plt.figure(figsize=(10, 8))
    try:
        plt.plot(rec_tr, prec_tr, color='#e41a1c', lw=2.5,
                 label=f'Training set (Macro AP={ap_tr:.3f})')
        plt.plot(rec_vl, prec_vl, color='#377eb8', lw=2.5, linestyle='--',
                 label=f'Internal validation / CV-OOF (Macro AP={ap_vl:.3f})')
        plt.xlim([-0.02, 1.02])
        plt.ylim([-0.02, 1.02])
        plt.xlabel('Recall', fontsize=13)
        plt.ylabel('Precision', fontsize=13)
        plt.title(f'PR — Train vs Internal Validation — {model_name}',
                  fontsize=14, fontweight='bold')
        plt.legend(loc='lower left', fontsize=11)
        plt.grid(True, alpha=0.15)
        plt.tight_layout()
        plt.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'),
                    format='pdf', bbox_inches='tight', dpi=300)
        plt.savefig(os.path.join(rf, f'{filepath_prefix}.png'),
                    format='png', bbox_inches='tight', dpi=150)
    finally:
        plt.close(fig)
    return ap_tr, ap_vl
# ============================================================================
# Bootstrap AUC test [FIX-2 retained]
# ============================================================================
def bootstrap_auc_test(y_true, proba_a, proba_b, classes,
                       n_bootstrap=2000, seed=42):
    """Paired bootstrap test of the macro-AUC difference between two models.

    Both probability matrices must score the same samples, in the same order
    as ``y_true``. Each resample draws indices with replacement and scores
    both models on that same draw. Resamples missing any class, or on which
    sklearn cannot compute AUC, are skipped.

    [FIX-2] H0: AUC_A - AUC_B = 0, tested two-sided. The bootstrap diffs are
    distributed around the observed difference, not around 0. They are
    therefore centred on the observed difference before being compared with
    ``|observed|``. Comparing raw ``|diffs|`` with ``|observed|`` yields
    p ≈ 0.5 for any real effect and almost never rejects H0.

    Args:
        y_true: true labels.
        proba_a, proba_b: probability matrices of models A and B.
        classes: ordered class labels (only their count is used here).
        n_bootstrap: number of resamples to draw.
        seed: seed for ``np.random.RandomState``.

    Returns:
        ``(p_value, auc_a, auc_b, ci_low, ci_high)``. The CI is the 2.5th and
        97.5th percentile of the bootstrap differences. If fewer than 100
        resamples were usable, returns ``(1.0, auc_a, auc_b, -1, 1)``. If the
        observed AUCs cannot be computed they are reported as 0.0.
    """
    y_true = np.asarray(y_true)
    proba_a = np.asarray(proba_a)
    proba_b = np.asarray(proba_b)
    rng = np.random.RandomState(seed)
    n = len(y_true)
    n_classes = len(classes)

    def calc_macro_auc(yt, pa, pb):
        # Raises ValueError when AUC is undefined for this sample.
        if n_classes == 2:
            return roc_auc_score(yt, pa[:, 1]), roc_auc_score(yt, pb[:, 1])
        return (roc_auc_score(yt, pa, multi_class='ovr', average='macro'),
                roc_auc_score(yt, pb, multi_class='ovr', average='macro'))

    try:
        auc_a, auc_b = calc_macro_auc(y_true, proba_a, proba_b)
    except (ValueError, IndexError):
        auc_a, auc_b = 0.0, 0.0
    observed_diff = auc_a - auc_b

    diffs = []
    for _ in range(n_bootstrap):
        idx = rng.choice(n, n, replace=True)
        yt_b = y_true[idx]
        if len(np.unique(yt_b)) < n_classes:
            continue
        try:
            a1, a2 = calc_macro_auc(yt_b, proba_a[idx], proba_b[idx])
        except (ValueError, IndexError):
            # Skip rather than recording a fake diff of 0.
            continue
        diffs.append(a1 - a2)

    if len(diffs) < 100:
        return 1.0, auc_a, auc_b, -1, 1
    diffs = np.asarray(diffs)

    # Null distribution: bootstrap deviations around the observed difference.
    centered = diffs - observed_diff
    p_value = float(np.mean(np.abs(centered) >= abs(observed_diff)))
    # A finite bootstrap cannot resolve p below 1 / (number of resamples).
    p_value = max(p_value, 1.0 / len(diffs))

    ci_low = np.percentile(diffs, 2.5)
    ci_high = np.percentile(diffs, 97.5)
    return p_value, auc_a, auc_b, ci_low, ci_high
# ============================================================================
# [FIX-1] Model configs — XGBoost num_class constructed conditionally
# ============================================================================
ALL_MODEL_NAMES = ['RF', 'DT', 'KNN', 'XGB', 'AdaBoost', 'LR', 'NB', 'SVM']


def get_models_config(selected, n_classes, rs=42):
    """Build an unfitted estimator and a hyper-parameter grid per selected model.

    Args:
        selected: collection of model short names (see ALL_MODEL_NAMES).
            Names with no configuration here are silently ignored.
        n_classes: number of target classes. Selects the XGBoost objective.
        rs: ``random_state`` given to every estimator that accepts one.

    Returns:
        dict ``{name: {'model': estimator, 'params': grid}}`` in the fixed
        order RF, DT, KNN, XGB, AdaBoost, LR, NB, SVM, restricted to
        ``selected``.
    """
    # [FIX-1] num_class is passed only for genuine multi-class problems.
    # The binary objective uses plain logloss instead.
    if n_classes > 2:
        xgb_kwargs = dict(random_state=rs, n_jobs=-1,
                          objective='multi:softprob',
                          num_class=n_classes,
                          eval_metric='mlogloss')
    else:
        xgb_kwargs = dict(random_state=rs, n_jobs=-1,
                          objective='binary:logistic',
                          eval_metric='logloss')

    def entry(model, grid):
        return {'model': model, 'params': grid}

    catalogue = {
        'RF': entry(
            RandomForestClassifier(random_state=rs, n_jobs=-1),
            {'n_estimators': [100, 200], 'max_depth': [20, 50],
             'min_samples_split': [2, 5]}),
        'DT': entry(
            DecisionTreeClassifier(random_state=rs),
            {'max_depth': [20, 50], 'min_samples_split': [2, 10],
             'criterion': ['gini', 'entropy']}),
        'KNN': entry(
            KNeighborsClassifier(n_jobs=-1),
            {'n_neighbors': [3, 5, 7], 'weights': ['uniform', 'distance']}),
        'XGB': entry(
            XGBClassifier(**xgb_kwargs),
            {'n_estimators': [100, 200], 'max_depth': [5, 7],
             'learning_rate': [0.05, 0.1]}),
        'AdaBoost': entry(
            AdaBoostClassifier(random_state=rs),
            {'n_estimators': [50, 100], 'learning_rate': [0.1, 0.5, 1.0]}),
        'LR': entry(
            LogisticRegression(random_state=rs, n_jobs=-1, max_iter=2000),
            {'C': [0.1, 1, 10], 'solver': ['lbfgs']}),
        'NB': entry(
            GaussianNB(),
            {'var_smoothing': [1e-9, 1e-7, 1e-5]}),
        'SVM': entry(
            SVC(probability=True, random_state=rs,
                decision_function_shape='ovr'),
            {'C': [1, 10], 'kernel': ['rbf', 'linear']}),
    }
    return {name: spec for name, spec in catalogue.items() if name in selected}
# ============================================================================
# Main Pipeline
# ============================================================================
def run_pipeline(
train_file, val_file1, val_file2, val_file3, n_classes_select,
selected_models, enable_tuning,
cv_folds, top_n_features, shap_sample_size,
progress=gr.Progress(track_tqdm=True),
):
if train_file is None:
return None, "❌ 请先上传训练集 CSV 文件"
sel = (selected_models if isinstance(selected_models, list)
else [s.strip() for s in str(selected_models).split(",") if s.strip()])
if not sel:
return None, "❌ 请至少选择一个模型"
RS = 42; CVF = int(cv_folds)
TOPN = int(top_n_features); SHAPSZ = int(shap_sample_size)
TUNING = bool(enable_tuning)
L = []
def log(m): L.append(str(m))
rf = tempfile.mkdtemp(prefix="ml_")
try:
# ── Load Data ──
progress(0.02, desc="📂 加载数据...")
log("━" * 60)
log(" 🧬 ML 多分类模型训练与评估系统 v3")
log("━" * 60)
tp = (train_file if isinstance(train_file, str)
else getattr(train_file, 'name', str(train_file)))
data = pd.read_csv(tp)
y = data.iloc[:, 0]
col2 = data.iloc[:, 1]
col2_is_id = ((col2.dtype == 'object') or
(col2.nunique() / len(col2) > 0.5))
if col2_is_id:
X = data.iloc[:, 2:]
log(f" 📋 CSV: Col1=Label, Col2=ID({data.columns[1]}), Col3+=Features")
else:
X = data.iloc[:, 1:]
log(f" 📋 CSV: Col1=Label, Col2+=Features (no ID column)")
fnames = X.columns.tolist()
user_n = int(str(n_classes_select).split(" ")[0])
detected_classes = sorted(y.unique())
detected_classes = [int(c) if hasattr(c, 'item') else c
for c in detected_classes]
detected_n = len(detected_classes)
if detected_n != user_n:
return None, (
f"❌ 您选择了 {user_n} 分类,但数据中检测到 {detected_n} 个类别: "
f"{detected_classes}\n请将分类数修改为 {detected_n},或检查数据标签列")
classes = detected_classes
n_classes = user_n
log(f" ✅ {n_classes} 分类 — 数据验证通过")
label_map = {c: i for i, c in enumerate(classes)}
label_map_inv = {i: c for c, i in label_map.items()}
y_mapped = y.map(label_map)
class_indices = list(range(n_classes))
log(f" 📊 训练集: {X.shape[0]} 样本 × {X.shape[1]} 特征")
log(f" 🏷️ 类别数: {n_classes} 类 — {classes}")
log(f" 📊 分布: {dict(y.value_counts().sort_index())}")
log(f" 🤖 模型: {', '.join(sel)}")
log(f" 🔧 调优: {'开启' if TUNING else '关闭'} | CV: {CVF}折")
if n_classes < 2 or n_classes > 8:
return None, f"❌ 仅支持 2~8 分类,当前检测到 {n_classes} 类"
task_type = "Binary" if n_classes == 2 else f"{n_classes}-Class"
task_type_cn = "二分类" if n_classes == 2 else f"{n_classes}分类"
log(f" 📋 任务: {task_type_cn} ({task_type})")
mcfg = get_models_config(sel, n_classes, RS)
skf = StratifiedKFold(n_splits=CVF, shuffle=True, random_state=RS)
COLORS = ['#2563eb','#f59e0b','#10b981','#ef4444',
'#8b5cf6','#ec4899','#06b6d4','#6b7280']
bpd = {} # best params
amr = {} # all model results (CV-OOF)
tms = {} # trained models (full data)
train_results = {} # metrics on full training set
total = len(mcfg)
# ── Train All Models ──────────────────────────────────────────────
for mi, (mn, cf) in enumerate(mcfg.items()):
pv = 0.05 + 0.35 * mi / total
progress(pv, desc=f"🏋️ [{mi+1}/{total}] 训练 {mn}...")
log(f"\n{'─'*50}")
log(f" 🔄 [{mi+1}/{total}] {mn}")
Xv = X.values
# Optional GridSearch
if TUNING:
log(f" ⏳ GridSearchCV (CV={CVF})...")
scoring = 'roc_auc_ovr' if n_classes > 2 else 'roc_auc'
gs = GridSearchCV(cf['model'], cf['params'], cv=skf,
scoring=scoring, n_jobs=-1, verbose=0)
gs.fit(Xv, y_mapped)
bp = gs.best_params_; bpd[mn] = bp
log(f" ✓ 最佳CV Score: {gs.best_score_:.4f}")
else:
bp = {}; bpd[mn] = "Default"
# Fit final model on full training set
mdl = deepcopy(cf['model'])
if bp: mdl.set_params(**bp)
mdl.fit(Xv, y_mapped)
tms[mn] = mdl
# Training-set metrics
train_proba_full = mdl.predict_proba(Xv)
train_pred_full = mdl.predict(Xv)
train_met = compute_multiclass_metrics(
y_mapped.values, train_pred_full,
train_proba_full, class_indices)
train_results[mn] = {
'proba': train_proba_full,
'pred': train_pred_full,
'metrics': train_met,
}
# ── CV evaluation (OOF = Internal Validation) ──
all_yt = []; all_yp = []; all_yproba = []
fold_metrics = []
for fi, (tri, tei) in enumerate(skf.split(X, y_mapped), 1):
Xtr, Xte = X.iloc[tri].values, X.iloc[tei].values
ytr, yte = y_mapped.iloc[tri], y_mapped.iloc[tei]
mf = deepcopy(cf['model'])
if bp: mf.set_params(**bp)
mf.fit(Xtr, ytr)
ypred = mf.predict(Xte)
yproba = mf.predict_proba(Xte)
all_yt.extend(yte)
all_yp.extend(ypred)
all_yproba.append(yproba)
fm = compute_multiclass_metrics(yte, ypred, yproba, class_indices)
# [v3-2] Extended fold row
fold_metrics.append({
'Fold': fi,
'AUC': fm['Macro_AUC'],
'Accuracy': fm['Accuracy'],
'Sensitivity': fm['Macro_Sensitivity'],
'Specificity': fm['Macro_Specificity'],
'PPV': fm['Macro_PPV'],
'NPV': fm['Macro_NPV'],
'F1': fm['Macro_F1'],
'Weighted_F1': fm['Weighted_F1'],
'Kappa': fm['Kappa'],
})
all_yt = np.array(all_yt)
all_yp = np.array(all_yp)
all_yproba = np.vstack(all_yproba)
# Build fold table with Mean row
fdf = pd.DataFrame(fold_metrics)
mean_row = {
col: (fdf[col].mean() if col != 'Fold' else 'Mean')
for col in fdf.columns
}
fdf = pd.concat([fdf, pd.DataFrame([mean_row])], ignore_index=True)
# OOF aggregate metrics (computed on concatenated OOF predictions)
oof_met = compute_multiclass_metrics(
all_yt, all_yp, all_yproba, class_indices)
amr[mn] = {
'fold_df': fdf,
'mean_auc': mean_row['AUC'],
'mean_acc': mean_row['Accuracy'],
'mean_sens': mean_row['Sensitivity'],
'mean_spec': mean_row['Specificity'],
'mean_ppv': mean_row['PPV'],
'mean_npv': mean_row['NPV'],
'mean_f1': mean_row['F1'],
'mean_wf1': mean_row['Weighted_F1'],
'mean_kappa': mean_row['Kappa'],
'oof_metrics': oof_met,
'all_yt': all_yt,
'all_yp': all_yp,
'all_yproba': all_yproba,
}
# [v3-7] Log all key metrics
tm = train_met; vm = mean_row
log(f" ✅ [Train] AUC={tm['Macro_AUC']:.4f} Acc={tm['Accuracy']:.4f} "
f"Sens={tm['Macro_Sensitivity']:.4f} Spec={tm['Macro_Specificity']:.4f} "
f"PPV={tm['Macro_PPV']:.4f} NPV={tm['Macro_NPV']:.4f} "
f"F1={tm['Macro_F1']:.4f} Kappa={tm['Kappa']:.4f}")
log(f" ✅ [CV-OOF] AUC={vm['AUC']:.4f} Acc={vm['Accuracy']:.4f} "
f"Sens={vm['Sensitivity']:.4f} Spec={vm['Specificity']:.4f} "
f"PPV={vm['PPV']:.4f} NPV={vm['NPV']:.4f} "
f"F1={vm['F1']:.4f} Kappa={vm['Kappa']:.4f}")
mnames = list(amr.keys()); nm = len(mnames)
log(f"\n{'━'*60}")
log(f" ✅ {nm} 个模型训练完成")
# ── Training-set ROC / PR / CM for every model ───────────────────
progress(0.40, desc="📈 训练集曲线...")
log(f"\n 📈 绘制训练集 ROC / PR / CM...")
for mn in mnames:
tr = train_results[mn]
tm = tr['metrics']
plot_multiclass_roc(
y_mapped.values, tr['proba'], class_indices,
f'ROC (Train) — {mn} ({task_type}, AUC={tm["Macro_AUC"]:.3f})',
f'roc_train_{mn}', rf)
plot_multiclass_pr(
y_mapped.values, tr['proba'], class_indices,
f'PR (Train) — {mn} ({task_type})',
f'pr_train_{mn}', rf)
plot_confusion_matrix(
y_mapped.values, tr['pred'], class_indices,
f'CM (Train) — {mn} (Acc={tm["Accuracy"]:.3f})',
f'cm_train_{mn}', rf)
# Combined training-set ROC (all models, macro)
plt.figure(figsize=(10, 8))
for i, mn in enumerate(mnames):
tr = train_results[mn]
y_bin = label_binarize(y_mapped.values, classes=class_indices)
if n_classes == 2: y_bin = np.hstack([1 - y_bin, y_bin])
all_fpr = np.linspace(0, 1, 200); mean_tpr = np.zeros_like(all_fpr)
for c in range(n_classes):
f, t, _ = roc_curve(y_bin[:, c], tr['proba'][:, c])
mean_tpr += np.interp(all_fpr, f, t)
mean_tpr /= n_classes; mean_tpr[-1] = 1.0
ma = auc_score(all_fpr, mean_tpr)
plt.plot(all_fpr, mean_tpr, color=COLORS[i%8], lw=2.5,
label=f'{mn} (Macro AUC={ma:.3f})')
plt.plot([0,1],[0,1],'--',color='#ccc',lw=1)
plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
plt.xlabel('FPR',fontsize=13); plt.ylabel('TPR',fontsize=13)
plt.title(f'ROC (Train) — All Models ({task_type})',
fontsize=14, fontweight='bold')
plt.legend(loc='lower right',fontsize=10)
plt.grid(True,alpha=0.15); plt.tight_layout()
plt.savefig(os.path.join(rf,'roc_train_all.pdf'),
format='pdf',bbox_inches='tight',dpi=300)
plt.savefig(os.path.join(rf,'roc_train_all.png'),
format='png',bbox_inches='tight',dpi=150)
plt.close()
# ── CV-OOF ROC / PR / CM ─────────────────────────────────────────
progress(0.44, desc="📈 内部验证ROC曲线...")
log(f"\n 📈 绘制内部验证(CV-OOF) ROC / PR / CM...")
for mn in mnames:
r = amr[mn]
plot_multiclass_roc(
r['all_yt'], r['all_yproba'], class_indices,
f'ROC (Internal Val) — {mn} ({task_type}, AUC={r["mean_auc"]:.3f})',
f'roc_val_{mn}', rf)
# Combined CV-OOF ROC
plt.figure(figsize=(10, 8))
for i, mn in enumerate(mnames):
r = amr[mn]
y_bin = label_binarize(r['all_yt'], classes=class_indices)
if n_classes == 2: y_bin = np.hstack([1 - y_bin, y_bin])
all_fpr = np.linspace(0, 1, 200); mean_tpr = np.zeros_like(all_fpr)
for c in range(n_classes):
f, t, _ = roc_curve(y_bin[:, c], r['all_yproba'][:, c])
mean_tpr += np.interp(all_fpr, f, t)
mean_tpr /= n_classes; mean_tpr[-1] = 1.0
ma = auc_score(all_fpr, mean_tpr)
plt.plot(all_fpr, mean_tpr, color=COLORS[i%8], lw=2.5,
label=f'{mn} (Macro AUC={ma:.3f})')
plt.plot([0,1],[0,1],'--',color='#ccc',lw=1)
plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
plt.xlabel('FPR',fontsize=13); plt.ylabel('TPR',fontsize=13)
plt.title(f'ROC (Internal Val / CV-OOF) — All Models ({task_type})',
fontsize=14, fontweight='bold')
plt.legend(loc='lower right',fontsize=10)
plt.grid(True,alpha=0.15); plt.tight_layout()
plt.savefig(os.path.join(rf,'roc_val_all.pdf'),
format='pdf',bbox_inches='tight',dpi=300)
plt.savefig(os.path.join(rf,'roc_val_all.png'),
format='png',bbox_inches='tight',dpi=150)
plt.close()
progress(0.48, desc="📈 PR曲线...")
for mn in mnames:
r = amr[mn]
plot_multiclass_pr(
r['all_yt'], r['all_yproba'], class_indices,
f'PR (Internal Val) — {mn} ({task_type})',
f'pr_val_{mn}', rf)
progress(0.51, desc="📊 混淆矩阵...")
for mn in mnames:
r = amr[mn]
plot_confusion_matrix(
r['all_yt'], r['all_yp'], class_indices,
f'CM (Internal Val) — {mn} (Acc={r["mean_acc"]:.3f})',
f'cm_val_{mn}', rf)
# ── Bootstrap AUC Test ────────────────────────────────────────────
progress(0.54, desc="🔬 Bootstrap AUC 检验...")
best_mn = max(amr, key=lambda x: amr[x]['mean_auc'])
best_auc = amr[best_mn]['mean_auc']
log(f"\n 🏆 最佳模型: {best_mn} (Macro AUC={best_auc:.4f})")
log(f" 🔬 Bootstrap 检验 (n=2000, α=0.05)...")
ALPHA = 0.05
bootstrap_results = []
retained = [best_mn]
for om in mnames:
if om == best_mn:
continue
p_val, auc_a, auc_b, ci_lo, ci_hi = bootstrap_auc_test(
amr[best_mn]['all_yt'],
amr[best_mn]['all_yproba'],
amr[om]['all_yproba'],
class_indices, n_bootstrap=2000)
dec = "Retained" if p_val >= ALPHA else "Excluded"
if p_val >= ALPHA:
retained.append(om)
bootstrap_results.append({
'Model_A': best_mn, 'AUC_A': auc_a,
'Model_B': om, 'AUC_B': auc_b,
'AUC_Diff': auc_a - auc_b,
'CI_95_Low': ci_lo, 'CI_95_High': ci_hi,
'P_value': p_val, 'Decision': dec,
})
log(f" {best_mn} vs {om}: ΔAUC={auc_a-auc_b:+.4f} "
f"95%CI=[{ci_lo:+.4f},{ci_hi:+.4f}] "
f"P={p_val:.4f}{dec}")
bootstrap_df = (pd.DataFrame(bootstrap_results)
.sort_values('P_value', ascending=False)
if bootstrap_results else pd.DataFrame())
log(f" ✅ 保留 {len(retained)}/{nm} 个模型: {', '.join(retained)}")
# ── Best model: Train vs Internal Val overlay ─────────────────────
progress(0.57, desc="📈 Train vs Val 对比图...")
log(f"\n 📈 最佳模型 {best_mn}: Train vs Internal Validation 对比...")
auc_tr_b, auc_vl_b = plot_train_vs_val_roc(
y_mapped.values, train_results[best_mn]['proba'],
amr[best_mn]['all_yt'], amr[best_mn]['all_yproba'],
class_indices, best_mn, f'roc_train_vs_val_{best_mn}', rf)
ap_tr_b, ap_vl_b = plot_train_vs_val_pr(
y_mapped.values, train_results[best_mn]['proba'],
amr[best_mn]['all_yt'], amr[best_mn]['all_yproba'],
class_indices, best_mn, f'pr_train_vs_val_{best_mn}', rf)
log(f" ROC — Train AUC={auc_tr_b:.4f} / Val AUC={auc_vl_b:.4f}")
log(f" PR — Train AP={ap_tr_b:.4f} / Val AP={ap_vl_b:.4f}")
# ── SHAP ──────────────────────────────────────────────────────────
progress(0.60, desc="🔥 SHAP分析...")
log(f"\n 🔥 SHAP特征分析 (保留模型中 Top 3)...")
shap_imp = {}
models_for_shap = sorted(retained,
key=lambda x: amr[x]['mean_auc'],
reverse=True)[:3]
for si, mn in enumerate(models_for_shap):
progress(0.60 + 0.10 * si / max(len(models_for_shap), 1),
desc=f"🔥 SHAP: {mn}...")
mo = tms[mn]; Xshap = X.values
ns = min(SHAPSZ, Xshap.shape[0])
np.random.seed(RS)
sidx = np.random.choice(Xshap.shape[0], ns, replace=False)
Xs = Xshap[sidx]
try:
if mn in ['RF', 'XGB', 'DT', 'AdaBoost']:
exp = shap.TreeExplainer(mo)
sv = exp.shap_values(Xs)
else:
bg = Xs[np.random.choice(ns, min(50, ns), replace=False)]
exp = shap.KernelExplainer(
lambda x, m=mo: m.predict_proba(x), bg)
sv = exp.shap_values(Xs)
# [FIX-3] Robust SHAP shape handling
if isinstance(sv, list):
sv_abs = np.mean([np.abs(s) for s in sv], axis=0)
elif sv.ndim == 3:
if sv.shape[2] == n_classes:
sv_abs = np.mean(np.abs(sv), axis=2)
elif sv.shape[1] == n_classes:
sv_abs = np.mean(np.abs(sv), axis=1)
else:
sv_abs = np.abs(sv).mean(axis=-1)
else:
sv_abs = np.abs(sv)
fi = sv_abs.mean(axis=0)
if len(fi) > len(fnames): fi = fi[:len(fnames)]
elif len(fi) < len(fnames):
fi = np.pad(fi, (0, len(fnames) - len(fi)))
idf = (pd.DataFrame({'Feature': fnames, 'Importance': fi})
.sort_values('Importance', ascending=False))
shap_imp[mn] = idf
plt.figure(figsize=(10, max(6, TOPN * 0.3)))
top_df = idf.head(TOPN).iloc[::-1]
plt.barh(top_df['Feature'], top_df['Importance'],
color='#2563eb', alpha=0.8)
plt.xlabel('Mean |SHAP|', fontsize=12)
plt.title(f'SHAP Feature Importance — {mn} (Top {TOPN})',
fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(rf, f'shap_{mn}.pdf'),
format='pdf', bbox_inches='tight')
plt.savefig(os.path.join(rf, f'shap_{mn}.png'),
format='png', bbox_inches='tight', dpi=150)
plt.close()
log(f" ✅ {mn} Top3: "
f"{', '.join(idf.head(3)['Feature'].tolist())}")
except Exception as e:
log(f" ⚠ {mn} SHAP失败: {e}")
# ── Feature Ablation ──────────────────────────────────────────────
progress(0.72, desc="🧪 特征消融...")
log(f"\n 🧪 特征消融 (仅最佳模型 {best_mn})...")
ablation_data = None
if best_mn in shap_imp:
imp_df = shap_imp[best_mn]
top_feats = imp_df.head(TOPN)['Feature'].tolist()
fcs = []; aucs_a = []
for nf in range(1, len(top_feats) + 1):
Xsub = X[top_feats[:nf]]
fold_aucs = []
for tri, tei in skf.split(Xsub, y_mapped):
mf = deepcopy(mcfg[best_mn]['model'])
bp2 = bpd.get(best_mn, {})
if isinstance(bp2, dict) and bp2:
mf.set_params(**bp2)
mf.fit(Xsub.iloc[tri].values, y_mapped.iloc[tri])
yproba_f = mf.predict_proba(Xsub.iloc[tei].values)
yte_f = y_mapped.iloc[tei]
try:
a = (roc_auc_score(yte_f, yproba_f[:, 1])
if n_classes == 2 else
roc_auc_score(yte_f, yproba_f,
multi_class='ovr', average='macro'))
except:
a = 0.0
fold_aucs.append(a)
fcs.append(nf); aucs_a.append(np.mean(fold_aucs))
full_auc = amr[best_mn]['mean_auc']
opt_n = len(top_feats)
for i, a in enumerate(aucs_a):
if a >= full_auc * 0.95:
opt_n = i + 1; break
ablation_data = {
'fcs': fcs, 'aucs': aucs_a, 'feats': top_feats,
'opt_n': opt_n, 'opt_feats': top_feats[:opt_n]
}
log(f" ✅ 最优特征数: {opt_n} "
f"(AUC={aucs_a[opt_n-1]:.4f} vs Full={full_auc:.4f})")
plt.figure(figsize=(10, 7))
plt.plot(fcs, aucs_a, 'o-', color='#2563eb', lw=2, ms=5)
plt.scatter([opt_n], [aucs_a[opt_n-1]], s=200, marker='*',
color='#ef4444', edgecolors='black', lw=2, zorder=5)
plt.axhline(y=full_auc, color='gray', ls='--', lw=1, alpha=0.5,
label=f'Full AUC={full_auc:.3f}')
plt.xlabel('Number of Features', fontsize=13)
plt.ylabel('Macro AUC', fontsize=13)
plt.title(f'Feature Ablation — {best_mn} (★ Optimal={opt_n})',
fontsize=14, fontweight='bold')
plt.legend(fontsize=11); plt.grid(True, alpha=0.15); plt.tight_layout()
plt.savefig(os.path.join(rf, 'ablation.pdf'),
format='pdf', bbox_inches='tight')
plt.savefig(os.path.join(rf, 'ablation.png'),
format='png', bbox_inches='tight', dpi=150)
plt.close()
# ── External Validation ───────────────────────────────────────────
val_files_list = [vf for vf in [val_file1, val_file2, val_file3]
if vf is not None]
final_feats = ablation_data['opt_feats'] if ablation_data else fnames
if val_files_list:
progress(0.82, desc="🧪 外部验证...")
log(f"\n{'━'*60}")
log(f" 🧪 外部验证 ({len(val_files_list)} 个验证集)")
for vi, vf in enumerate(val_files_list, 1):
vp = (vf if isinstance(vf, str)
else getattr(vf, 'name', str(vf)))
ed = pd.read_csv(vp); ye_raw = ed.iloc[:, 0]
vcol2 = ed.iloc[:, 1]
vcol2_is_id = ((vcol2.dtype == 'object') or
(vcol2.nunique() / len(vcol2) > 0.5))
Xe = ed.iloc[:, 2:] if vcol2_is_id else ed.iloc[:, 1:]
ye = ye_raw.map(label_map)
if ye.isna().any():
log(f" ⚠ 验证集 {vi} 含有训练集中不存在的标签,已跳过")
continue
log(f"\n 📊 验证集 {vi}: {Xe.shape[0]} 样本, "
f"{os.path.basename(vp)}")
Xes = Xe[final_feats]; Xtf = X[final_feats]
fm = deepcopy(mcfg[best_mn]['model'])
bp3 = bpd[best_mn]
if isinstance(bp3, dict) and bp3:
fm.set_params(**bp3)
fm.fit(Xtf.values, y_mapped)
yep = fm.predict_proba(Xes.values)
yed = fm.predict(Xes.values)
ye_np = ye.values
ext_met = compute_multiclass_metrics(
ye_np, yed, yep, class_indices)
em = ext_met
log(f" ✅ AUC={em['Macro_AUC']:.4f} "
f"Acc={em['Accuracy']:.4f} "
f"Sens={em['Macro_Sensitivity']:.4f} "
f"Spec={em['Macro_Specificity']:.4f} "
f"PPV={em['Macro_PPV']:.4f} "
f"NPV={em['Macro_NPV']:.4f} "
f"F1={em['Macro_F1']:.4f} "
f"Kappa={em['Kappa']:.4f}")
sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
plot_multiclass_roc(ye_np, yep, class_indices,
f'ROC — {tag} ({best_mn})', f'roc{sfx}', rf)
plot_multiclass_pr(ye_np, yep, class_indices,
f'PR — {tag} ({best_mn})', f'pr{sfx}', rf)
plot_confusion_matrix(ye_np, yed, class_indices,
f'CM — {tag} ({best_mn})', f'cm{sfx}', rf)
# [v3-5] Extended external validation Excel
with pd.ExcelWriter(
os.path.join(rf, f'validation{sfx}.xlsx'),
engine='openpyxl'
) as w:
# Macro metrics row
macro_row = {'Model': best_mn,
'N_Features': len(final_feats)}
macro_row.update(metrics_to_flat_row(em))
pd.DataFrame([macro_row]).to_excel(
w, sheet_name='Metrics_Macro', index=False)
# Per-class detail
per_class_df(em, class_indices).to_excel(
w, sheet_name='Metrics_PerClass', index=False)
pd.DataFrame({'Feature': final_feats}).to_excel(
w, sheet_name='Features', index=False)
# ── Save Results ──────────────────────────────────────────────────
progress(0.92, desc="💾 保存结果...")
log(f"\n 💾 保存结果...")
with pd.ExcelWriter(
os.path.join(rf, 'model_evaluation.xlsx'),
engine='openpyxl'
) as w:
# 1. Per-fold CV results for every model [v3-2 extended columns]
for mn, r in amr.items():
r['fold_df'].to_excel(w, sheet_name=mn, index=False)
# 2. Summary — Internal Validation (CV-OOF) [v3-3 all metrics]
sd = []
for mn, r in amr.items():
row = {
'Model': mn,
'Retained': 'Yes' if mn in retained else 'No',
'Best': 'Best' if mn == best_mn else '',
}
row.update({
'AUC': r['mean_auc'],
'Accuracy': r['mean_acc'],
'Sensitivity': r['mean_sens'],
'Specificity': r['mean_spec'],
'PPV': r['mean_ppv'],
'NPV': r['mean_npv'],
'F1': r['mean_f1'],
'Weighted_F1': r['mean_wf1'],
'Kappa': r['mean_kappa'],
})
sd.append(row)
(pd.DataFrame(sd)
.sort_values('AUC', ascending=False)
.to_excel(w, sheet_name='Summary_InternalVal', index=False))
# 3. Train vs Internal Validation [v3-3 all metrics]
comparison_rows = []
for mn in amr:
tr_m = train_results[mn]['metrics']
vm = amr[mn]
row = {
'Model': mn,
'Train_AUC': tr_m['Macro_AUC'],
'Train_Accuracy': tr_m['Accuracy'],
'Train_Sensitivity': tr_m['Macro_Sensitivity'],
'Train_Specificity': tr_m['Macro_Specificity'],
'Train_PPV': tr_m['Macro_PPV'],
'Train_NPV': tr_m['Macro_NPV'],
'Train_F1': tr_m['Macro_F1'],
'Train_Kappa': tr_m['Kappa'],
'Val_AUC': vm['mean_auc'],
'Val_Accuracy': vm['mean_acc'],
'Val_Sensitivity': vm['mean_sens'],
'Val_Specificity': vm['mean_spec'],
'Val_PPV': vm['mean_ppv'],
'Val_NPV': vm['mean_npv'],
'Val_F1': vm['mean_f1'],
'Val_Kappa': vm['mean_kappa'],
'AUC_Gap': tr_m['Macro_AUC'] - vm['mean_auc'],
'Retained': 'Yes' if mn in retained else 'No',
'Best': 'Best' if mn == best_mn else '',
}
comparison_rows.append(row)
(pd.DataFrame(comparison_rows)
.sort_values('Val_AUC', ascending=False)
.to_excel(w, sheet_name='Train_vs_InternalVal', index=False))
# 4. Bootstrap test
if len(bootstrap_df) > 0:
bootstrap_df.to_excel(w, sheet_name='Bootstrap_Test',
index=False)
# 5. [v3-4] Per-class detail for EVERY model (train + val)
for mn in mnames:
# Val (OOF)
oof_pc = per_class_df(amr[mn]['oof_metrics'], class_indices)
sheet_v = f'{mn}_Val_PerClass'
if len(sheet_v) > 31: sheet_v = sheet_v[:31]
oof_pc.to_excel(w, sheet_name=sheet_v, index=False)
# Train
tr_pc = per_class_df(train_results[mn]['metrics'], class_indices)
sheet_t = f'{mn}_Train_PerClass'
if len(sheet_t) > 31: sheet_t = sheet_t[:31]
tr_pc.to_excel(w, sheet_name=sheet_t, index=False)
# Ablation Excel
if ablation_data:
with pd.ExcelWriter(
os.path.join(rf, 'feature_ablation.xlsx'),
engine='openpyxl'
) as w:
pd.DataFrame({
'N': ablation_data['fcs'],
'AUC': ablation_data['aucs']
}).to_excel(w, sheet_name='Ablation', index=False)
for mn, idf in shap_imp.items():
idf.to_excel(w, sheet_name=f'{mn}_Imp', index=False)
# ── best_params.txt [v3-6] all metrics ──────────────────────────
with open(os.path.join(rf, 'best_params.txt'), 'w',
encoding='utf-8') as f:
f.write(f"Task: {task_type} Classification ({n_classes} classes)\n")
f.write(f"Classes: {classes}\n")
f.write(f"Label Mapping: {label_map}\n\n")
f.write(f"Statistical Test: Bootstrap AUC Test "
f"(n=2000, alpha=0.05)\n")
f.write(f"Retained Models: {', '.join(retained)} "
f"({len(retained)}/{nm})\n\n")
f.write("=" * 65 + "\n")
f.write("Model Performance Summary\n")
f.write(f"{'Metric':<14} "
f"{'AUC':>7} {'Acc':>7} {'Sens':>7} {'Spec':>7} "
f"{'PPV':>7} {'NPV':>7} {'F1':>7} {'Kappa':>7}\n")
f.write("-" * 65 + "\n")
def fmt_row(label, m_auc, m_acc, m_sens, m_spec,
m_ppv, m_npv, m_f1, m_kappa):
return (f"{label:<14} "
f"{m_auc:>7.4f} {m_acc:>7.4f} {m_sens:>7.4f} "
f"{m_spec:>7.4f} {m_ppv:>7.4f} {m_npv:>7.4f} "
f"{m_f1:>7.4f} {m_kappa:>7.4f}\n")
for mn in mcfg:
status = ("* Best" if mn == best_mn
else ("Retained" if mn in retained else "Excluded"))
tr_m = train_results[mn]['metrics']
vm = amr[mn]
f.write(f"\nModel: {mn} | {status}\n")
f.write(fmt_row(
" Train",
tr_m['Macro_AUC'], tr_m['Accuracy'],
tr_m['Macro_Sensitivity'], tr_m['Macro_Specificity'],
tr_m['Macro_PPV'], tr_m['Macro_NPV'],
tr_m['Macro_F1'], tr_m['Kappa']))
f.write(fmt_row(
" CV-OOF",
vm['mean_auc'], vm['mean_acc'],
vm['mean_sens'], vm['mean_spec'],
vm['mean_ppv'], vm['mean_npv'],
vm['mean_f1'], vm['mean_kappa']))
f.write(f" AUC Gap: "
f"{tr_m['Macro_AUC'] - vm['mean_auc']:+.4f}\n")
bp = bpd[mn]
if isinstance(bp, dict):
for k, v in bp.items():
f.write(f" {k}: {v}\n")
else:
f.write(f" Params: {bp}\n")
if len(bootstrap_df) > 0:
f.write("\n" + "=" * 65 + "\n")
f.write("Bootstrap AUC Comparison Results\n")
f.write("=" * 65 + "\n")
for _, row in bootstrap_df.iterrows():
f.write(f" {row['Model_A']} vs {row['Model_B']}: "
f"dAUC={row['AUC_Diff']:+.4f} "
f"95%CI=[{row['CI_95_Low']:+.4f},"
f"{row['CI_95_High']:+.4f}] "
f"P={row['P_value']:.4f} -> {row['Decision']}\n")
if ablation_data:
f.write(f"\nOptimal Features ({ablation_data['opt_n']}): "
f"{', '.join(ablation_data['opt_feats'])}\n")
# Save best model pickle
pickle.dump({
'model_name': best_mn,
'model': tms[best_mn],
'best_params': bpd[best_mn],
'classes': classes,
'n_classes': n_classes,
'label_map': label_map,
'features': final_feats,
'task_type': task_type,
}, open(os.path.join(rf, f'model_{best_mn}.pkl'), 'wb'))
# ── ZIP ───────────────────────────────────────────────────────────
progress(0.97, desc="📦 打包ZIP...")
zp = os.path.join(tempfile.gettempdir(),
f"ml_results_{int(time.time())}_{os.getpid()}.zip")
with zipfile.ZipFile(zp, 'w', zipfile.ZIP_DEFLATED) as zf:
for root, _, files in os.walk(rf):
for fn in files:
zf.write(os.path.join(root, fn),
os.path.relpath(os.path.join(root, fn), rf))
nf = sum(len(f) for _, _, f in os.walk(rf))
shutil.rmtree(rf, ignore_errors=True); gc.collect()
tm_b = train_results[best_mn]['metrics']
log(f"\n{'━'*60}")
log(f" 🎉 分析完成!共 {nf} 个文件已打包")
log(f" 📋 Task: {task_type} | Best Model: {best_mn}")
log(f" 📊 Train — AUC={tm_b['Macro_AUC']:.4f} "
f"Acc={tm_b['Accuracy']:.4f} "
f"Sens={tm_b['Macro_Sensitivity']:.4f} "
f"Spec={tm_b['Macro_Specificity']:.4f} "
f"PPV={tm_b['Macro_PPV']:.4f} NPV={tm_b['Macro_NPV']:.4f} "
f"F1={tm_b['Macro_F1']:.4f}")
log(f" 📊 CV-OOF — AUC={best_auc:.4f} "
f"Acc={amr[best_mn]['mean_acc']:.4f} "
f"Sens={amr[best_mn]['mean_sens']:.4f} "
f"Spec={amr[best_mn]['mean_spec']:.4f} "
f"PPV={amr[best_mn]['mean_ppv']:.4f} "
f"NPV={amr[best_mn]['mean_npv']:.4f} "
f"F1={amr[best_mn]['mean_f1']:.4f}")
log(f"{'━'*60}")
progress(1.0, desc="✅ 完成!")
return zp, "\n".join(L)
except Exception as e:
log(f"\n❌ 错误: {e}")
log(traceback.format_exc())
if os.path.exists(rf): shutil.rmtree(rf, ignore_errors=True)
gc.collect()
return None, "\n".join(L)
# ============================================================================
# Gradio UI
# ============================================================================
# Custom stylesheet handed to gr.Blocks(css=...) for the UI built below.
# It styles the header banner, section titles and pipeline overview boxes
# (all rendered through gr.HTML), the dark monospace log textbox
# (elem_classes="log-area"), caps the container width at 1280px and hides
# the default Gradio footer.
CUSTOM_CSS = """
.header-banner {
background: linear-gradient(135deg, #0a2463 0%, #1e3a7a 40%, #2554a8 100%);
border-radius: 16px; padding: 28px 36px; margin-bottom: 20px;
box-shadow: 0 8px 32px rgba(0,0,0,0.18); position: relative; overflow: hidden;
}
.header-banner::before {
content: ''; position: absolute; top: -50%; right: -20%;
width: 400px; height: 400px;
background: radial-gradient(circle, rgba(96,165,250,0.2) 0%, transparent 70%);
border-radius: 50%;
}
.header-banner img { max-height: 52px; border-radius: 6px; margin-bottom: 12px; }
.header-banner h1 { color: #e2e8f0 !important; font-size: 1.7em !important;
margin: 4px 0 6px 0 !important; font-weight: 700 !important; }
.header-banner p { color: #94a3b8 !important; font-size: 0.92em !important;
margin: 2px 0 !important; line-height: 1.6; }
.header-banner .credit { color: #64748b !important; font-size: 0.82em !important;
margin-top: 10px !important;
border-top: 1px solid rgba(148,163,184,0.15); padding-top: 10px; }
.section-title {
background: linear-gradient(90deg, #2563eb 0%, #3b82f6 100%);
color: white !important; padding: 8px 16px; border-radius: 8px;
font-size: 0.95em !important; font-weight: 600 !important;
margin: 12px 0 8px 0; }
.pipeline-box {
background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
border: 1px solid #bae6fd; border-radius: 12px;
padding: 14px 18px; margin: 8px 0; font-size: 0.88em; }
.pipeline-box code { background: #2563eb; color: white; padding: 2px 8px;
border-radius: 4px; font-size: 0.85em; margin: 0 2px; }
.log-area textarea {
font-family: 'Menlo','Consolas',monospace !important;
font-size: 12.5px !important; line-height: 1.5 !important;
background: #0f172a !important; color: #e2e8f0 !important;
border-radius: 10px !important; padding: 16px !important; }
.gradio-container { max-width: 1280px !important; }
footer { display: none !important; }
"""
# Build the two-column Gradio interface. Components are created in the
# order the statements execute, so statement order below is the page layout.
with gr.Blocks(
    title="ML 多分类模型平台 — 复旦大学附属眼耳鼻喉科医院",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate",
                         neutral_hue="slate"),
    css=CUSTOM_CSS,
) as demo:
    # Header banner; the logo <img> hides itself if it fails to load.
    gr.HTML("""
<div class="header-banner">
<img src="https://huggingface.co/spaces/fudan-renjun/machine-learning-2/resolve/main/hospital_logo.png"
alt="Logo" onerror="this.style.display='none'"/>
<h1>🧬 ML 多分类模型训练与评估平台</h1>
<p>支持 2~8 分类 · 上传 CSV 即可完成全流程分析</p>
<p>评估指标:AUC · Accuracy · Sensitivity · Specificity · PPV · NPV · F1 · Kappa</p>
<p class="credit">复旦大学附属眼耳鼻喉科医院 · 检验科 · 任俊</p>
</div>
""")
    # Pipeline overview, metric list and the expected CSV column layout.
    # NOTE(review): the external-validation loader in run_pipeline drops
    # column 2 only when it looks like an ID (object dtype or >50% unique
    # values); otherwise that column is kept as a feature.
    gr.HTML("""
<div class="pipeline-box">
<strong>📋 流程:</strong>
<code>训练+训练集评估</code> → <code>交叉验证(OOF)</code> →
<code>Train vs Val对比</code> → <code>SHAP</code> →
<code>特征消融</code> → <code>外部验证</code>
&nbsp;|&nbsp;
<strong>指标:</strong>
AUC · Accuracy · Sensitivity · Specificity · PPV · NPV · F1 · Kappa(宏平均+逐类)
&nbsp;|&nbsp;
<strong>CSV:</strong> 第1列=标签(整数), 第2列=ID, 第3列起=特征
</div>
""")
    with gr.Row(equal_height=False):
        # ── Left column: data upload and run settings ──
        with gr.Column(scale=5):
            gr.HTML('<div class="section-title">📂 数据上传</div>')
            train_file = gr.File(label="训练集 CSV(必需)", file_types=[".csv"])
            gr.HTML('<p style="color:#64748b;font-size:0.85em;margin:4px 0 8px 0;">'
                    '验证集可选,支持同时上传 1~3 个</p>')
            # Up to three optional external validation CSVs, side by side.
            with gr.Row():
                val_file1 = gr.File(label="验证集 1(可选)",
                                    file_types=[".csv"], scale=1)
                val_file2 = gr.File(label="验证集 2(可选)",
                                    file_types=[".csv"], scale=1)
                val_file3 = gr.File(label="验证集 3(可选)",
                                    file_types=[".csv"], scale=1)
            gr.HTML('<div class="section-title">🏷️ 分类设置</div>')
            # Expected class count (2-8); the selected label string is
            # passed to run_pipeline.
            n_classes_select = gr.Dropdown(
                choices=["2 类(二分类)","3 类","4 类","5 类",
                         "6 类","7 类","8 类"],
                value="2 类(二分类)", label="选择分类数",
                info="请根据数据标签列的类别数选择,系统将自动验证是否匹配",
            )
            gr.HTML('<div class="section-title">🤖 模型选择</div>')
            # Multi-select of model keys; every model is selected by default.
            model_selector = gr.Dropdown(
                choices=ALL_MODEL_NAMES, value=ALL_MODEL_NAMES,
                multiselect=True, label="选择模型(均支持多分类)",
                info=("RF=随机森林 DT=决策树 KNN=K近邻 XGB=XGBoost "
                      "AdaBoost LR=逻辑回归 NB=朴素贝叶斯 SVM=支持向量机"),
            )
            # Preset buttons: each one overwrites the model selection with
            # a fixed subset of model keys.
            with gr.Row():
                btn_all = gr.Button("🔘 全选", size="sm", variant="secondary")
                btn_tree = gr.Button("🌲 树模型", size="sm", variant="secondary")
                btn_linear = gr.Button("📐 线性模型", size="sm", variant="secondary")
                btn_top4 = gr.Button("⚡ 经典四模型", size="sm", variant="secondary")
            btn_all.click(lambda: ALL_MODEL_NAMES, outputs=model_selector)
            btn_tree.click(lambda: ['RF','DT','XGB','AdaBoost'], outputs=model_selector)
            btn_linear.click(lambda: ['LR','SVM','NB'], outputs=model_selector)
            btn_top4.click(lambda: ['RF','XGB','LR','SVM'], outputs=model_selector)
            gr.HTML('<div class="section-title">⚙️ 参数配置</div>')
            # Optional GridSearchCV tuning; off by default because it is slow.
            enable_tuning = gr.Checkbox(
                value=False,
                label="启用超参数调优 (GridSearchCV) ⚠️ 开启后运行时间显著增加")
            # CV fold count, number of top SHAP features, SHAP sample size.
            with gr.Row():
                cv_folds = gr.Slider(3, 10, value=5, step=1,
                                     label="交叉验证折数")
                top_n = gr.Slider(5, 50, value=20, step=1,
                                  label="SHAP 前 N 个特征")
                shap_sz = gr.Slider(30, 200, value=80, step=10,
                                    label="SHAP 采样数量")
            run_btn = gr.Button("🚀 开始分析", variant="primary", size="lg")
        # ── Right column: run log and result download ──
        with gr.Column(scale=5):
            gr.HTML('<div class="section-title">📋 运行日志</div>')
            log_output = gr.Textbox(
                label="", lines=24, max_lines=50, interactive=False,
                placeholder=("点击「开始分析」后,日志将在此显示...\n"
                             "支持 2~8 分类。\n"
                             "评估指标:AUC / Accuracy / Sensitivity / "
                             "Specificity / PPV / NPV / F1 / Kappa"),
                elem_classes="log-area",
            )
            gr.HTML('<div class="section-title">⬇️ 结果下载</div>')
            zip_output = gr.File(label="分析结果 ZIP 压缩包")
    # Wire the run button. Gradio passes `inputs` to run_pipeline
    # positionally, so this list order must match its parameter order.
    # run_pipeline returns (zip_path or None, joined log text), matching
    # `outputs`. api_name="run" names the endpoint in the generated API.
    run_btn.click(
        fn=run_pipeline,
        inputs=[train_file, val_file1, val_file2, val_file3,
                n_classes_select, model_selector, enable_tuning,
                cv_folds, top_n, shap_sz],
        outputs=[zip_output, log_output],
        api_name="run",
    )
# ============================================================================
# Authentication
# ============================================================================
from datetime import datetime
# Login accounts. "expires" is the last valid login day ("YYYY-MM-DD",
# inclusive); None means the account never expires.
# NOTE(review): these credentials are hard-coded in plain text and are
# readable by anyone with access to this source; consider loading them
# from environment variables / Space secrets instead.
ACCOUNTS = {
    "admin": {"password": "admin123", "expires": None},
    "renjun": {"password": "fudan2025", "expires": "2027-12-31"},
    "guest": {"password": "guest888", "expires": "2027-06-30"},
}


def auth_fn(username, password):
    """Return True if *username* / *password* may log in to the app.

    Used as the ``auth`` callback of ``demo.launch``. Login is refused when
    the user is unknown, the password is not a matching string, or the
    account's ``expires`` date has passed. The expiry date itself is still
    a valid login day. A malformed expiry value denies access (fail closed).
    """
    import hmac  # constant-time comparison, avoids timing side channels

    user = ACCOUNTS.get(username)
    if not user or not isinstance(password, str):
        return False
    # Compare UTF-8 bytes: hmac.compare_digest rejects non-ASCII str input.
    if not hmac.compare_digest(password.encode("utf-8"),
                               user["password"].encode("utf-8")):
        return False
    expires = user["expires"]
    if not expires:
        return True
    try:
        last_valid_day = datetime.strptime(expires, "%Y-%m-%d").date()
    except (TypeError, ValueError):
        return False  # unparseable expiry: deny rather than allow
    # Compare calendar dates so the whole expiry day remains usable
    # (comparing against midnight would lock users out on that day).
    return datetime.now().date() <= last_valid_day
# Start the web server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module for tests or tooling
# does not block on a running server.
if __name__ == "__main__":
    # Enable Gradio's request queue before launching.
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        auth=auth_fn,  # username/password check defined above
        auth_message=("🔐 复旦大学附属眼耳鼻喉科医院 · ML多分类分析平台\n"
                      "请输入账号和密码登录"),
        ssr_mode=False,
    )