| """ | |
| ML Multi-Class Classification Pipeline (2-8 classes) | |
| Eye & ENT Hospital of Fudan University — Laboratory Medicine, Ren Jun | |
| Gradio 5.12.0 + Python 3.11 | |
| Changelog v3 (vs v2): | |
| [v3-1] compute_multiclass_metrics now returns full per-class and macro | |
| AUC, Accuracy, Sensitivity (Recall), Specificity, Precision (PPV), | |
| NPV, F1 for every class, plus macro/weighted averages. | |
| [v3-2] Per-fold metrics table extended with all new indicators. | |
| [v3-3] Summary sheets (Summary_InternalVal, Train_vs_InternalVal) carry | |
| all new macro indicators. | |
| [v3-4] Per-class detail sheets written for every model (train + val). | |
| [v3-5] External validation Excel extended with all new indicators. | |
| [v3-6] best_params.txt log extended with all new indicators. | |
| [v3-7] Console log shows key new indicators. | |
| Previous fixes retained: | |
| [FIX-1] XGBoost num_class=None bug | |
| [FIX-2] Bootstrap p-value centered on 0 | |
| [FIX-3] SHAP 3D axis detection | |
| [FIX-4] Per-model train-set ROC/PR/CM | |
| [FIX-5] Best-model Train vs InternalVal overlay plots | |
| [FIX-6] Train_vs_InternalVal Excel sheet | |
| [FIX-7] Guest account expiry updated | |
| """ | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier | |
| from sklearn.tree import DecisionTreeClassifier | |
| from sklearn.neighbors import KNeighborsClassifier | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.naive_bayes import GaussianNB | |
| from sklearn.svm import SVC | |
| from xgboost import XGBClassifier | |
| from sklearn.model_selection import StratifiedKFold, GridSearchCV | |
| from sklearn.metrics import ( | |
| roc_auc_score, confusion_matrix, roc_curve, | |
| auc as auc_score, precision_recall_curve, | |
| classification_report, accuracy_score, f1_score, | |
| cohen_kappa_score, precision_score, recall_score | |
| ) | |
| from sklearn.preprocessing import label_binarize | |
| import seaborn as sns | |
| import warnings | |
| import os | |
| import shap | |
| import pickle | |
| from copy import deepcopy | |
| import zipfile | |
| import tempfile | |
| import traceback | |
| import time | |
| import shutil | |
| import gc | |
| import threading | |
| import gradio as gr | |
| warnings.filterwarnings('ignore') | |
| # Publication-quality plot settings | |
| plt.rcParams['font.family'] = 'serif' | |
| plt.rcParams['font.serif'] = ['Times New Roman', 'DejaVu Serif', 'serif'] | |
| plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans'] | |
| plt.rcParams['axes.unicode_minus'] = False | |
| plt.rcParams['figure.dpi'] = 150 | |
| plt.rcParams['savefig.dpi'] = 300 | |
| plt.rcParams['axes.linewidth'] = 1.2 | |
| plt.rcParams['xtick.major.width'] = 1.0 | |
| plt.rcParams['ytick.major.width'] = 1.0 | |
| plt.rcParams['xtick.labelsize'] = 11 | |
| plt.rcParams['ytick.labelsize'] = 11 | |
| # ============================================================================ | |
| # Cache Cleanup | |
| # ============================================================================ | |
| CLEANUP_MAX_AGE_MINUTES = 30 | |
| CLEANUP_INTERVAL_SECONDS = 600 | |
| def cleanup_old_temp_files(): | |
| now = time.time(); tmp = tempfile.gettempdir() | |
| try: | |
| for item in os.listdir(tmp): | |
| p = os.path.join(tmp, item) | |
| if item.startswith("ml_"): | |
| age = now - os.path.getmtime(p) | |
| if age > CLEANUP_MAX_AGE_MINUTES * 60: | |
| if os.path.isdir(p): shutil.rmtree(p, ignore_errors=True) | |
| elif os.path.isfile(p): os.remove(p) | |
| except: pass | |
| gc.collect() | |
| def periodic_cleanup(): | |
| while True: | |
| time.sleep(CLEANUP_INTERVAL_SECONDS) | |
| cleanup_old_temp_files() | |
| _ct = threading.Thread(target=periodic_cleanup, daemon=True); _ct.start() | |
| # ============================================================================ | |
| # [v3-1] Extended metrics: Sensitivity, Specificity, PPV, NPV per class | |
| # ============================================================================ | |
| def compute_per_class_sens_spec_ppv_npv(y_true, y_pred, y_proba, classes): | |
| """ | |
| For each class c, treat it as a binary OvR problem: | |
| TP = predicted c AND true c | |
| FP = predicted c AND true != c | |
| FN = predicted != c AND true c | |
| TN = predicted != c AND true != c | |
| Returns a dict keyed by class index with: | |
| Sensitivity (Recall / TPR), Specificity (TNR), | |
| PPV (Precision), NPV, F1, AUC (OvR) | |
| Also returns macro averages of each metric. | |
| """ | |
| n_classes = len(classes) | |
| y_true = np.asarray(y_true) | |
| y_pred = np.asarray(y_pred) | |
| y_bin = label_binarize(y_true, classes=classes) | |
| if n_classes == 2: | |
| y_bin = np.hstack([1 - y_bin, y_bin]) | |
| per_class = {} | |
| for i, c in enumerate(classes): | |
| yt_b = y_bin[:, i] # true binary label for class c | |
| yp_b = (y_pred == c).astype(int) | |
| TP = int(np.sum((yt_b == 1) & (yp_b == 1))) | |
| FP = int(np.sum((yt_b == 0) & (yp_b == 1))) | |
| FN = int(np.sum((yt_b == 1) & (yp_b == 0))) | |
| TN = int(np.sum((yt_b == 0) & (yp_b == 0))) | |
| sens = TP / (TP + FN) if (TP + FN) > 0 else 0.0 # Sensitivity = Recall | |
| spec = TN / (TN + FP) if (TN + FP) > 0 else 0.0 # Specificity | |
| ppv = TP / (TP + FP) if (TP + FP) > 0 else 0.0 # PPV = Precision | |
| npv = TN / (TN + FN) if (TN + FN) > 0 else 0.0 # NPV | |
| f1 = (2 * ppv * sens / (ppv + sens)) if (ppv + sens) > 0 else 0.0 | |
| try: | |
| auc_c = roc_auc_score(yt_b, y_proba[:, i]) | |
| except: | |
| auc_c = 0.0 | |
| per_class[c] = { | |
| 'TP': TP, 'FP': FP, 'FN': FN, 'TN': TN, | |
| 'Sensitivity': sens, 'Specificity': spec, | |
| 'PPV': ppv, 'NPV': npv, 'F1': f1, 'AUC': auc_c | |
| } | |
| # Macro averages | |
| macro = {} | |
| for metric in ['Sensitivity', 'Specificity', 'PPV', 'NPV', 'F1', 'AUC']: | |
| macro[f'Macro_{metric}'] = np.mean([per_class[c][metric] for c in classes]) | |
| return per_class, macro | |
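
# Illustrative sketch (an addition for documentation; it is never called by
# the pipeline): how the OvR counts above translate into the reported
# metrics on a tiny, made-up 3-class example. The toy arrays are assumptions
# for demonstration only, not data used by this app.
def _demo_per_class_metrics():
    y_true = np.array([0, 0, 1, 1, 2, 2])
    y_pred = np.array([0, 1, 1, 1, 2, 0])
    y_proba = np.array([
        [0.7, 0.2, 0.1],
        [0.4, 0.5, 0.1],
        [0.2, 0.7, 0.1],
        [0.1, 0.8, 0.1],
        [0.1, 0.2, 0.7],
        [0.5, 0.2, 0.3],
    ])
    per_class, macro = compute_per_class_sens_spec_ppv_npv(
        y_true, y_pred, y_proba, classes=[0, 1, 2])
    # For class 0: TP=1, FP=1, FN=1, TN=3, so Sensitivity=0.50, Specificity=0.75
    return per_class, macro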

def compute_multiclass_metrics(y_true, y_pred, y_proba, classes):
    """
    [v3-1] Extended: returns AUC, Accuracy, Sensitivity, Specificity,
    Precision (PPV), NPV, F1 — macro and per-class — plus Kappa.
    """
    n_classes = len(classes)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    acc = accuracy_score(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)
    # Macro AUC
    try:
        if n_classes == 2:
            macro_auc = roc_auc_score(y_true, y_proba[:, 1])
        else:
            macro_auc = roc_auc_score(y_true, y_proba,
                                      multi_class='ovr', average='macro')
    except:
        macro_auc = 0.0
    # sklearn macro/weighted aggregates
    f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0, labels=classes)
    f1_weighted = f1_score(y_true, y_pred, average='weighted', zero_division=0, labels=classes)
    prec_macro = precision_score(y_true, y_pred, average='macro', zero_division=0, labels=classes)
    recall_macro = recall_score(y_true, y_pred, average='macro', zero_division=0, labels=classes)
    # Per-class Sensitivity / Specificity / PPV / NPV / F1 / AUC
    per_class, macro_ext = compute_per_class_sens_spec_ppv_npv(
        y_true, y_pred, y_proba, classes)
    # sklearn classification_report (for precision/recall/f1 by class)
    report = classification_report(
        y_true, y_pred, labels=classes, output_dict=True, zero_division=0)
    return {
        # ── Macro aggregates ──
        'Accuracy': acc,
        'Macro_AUC': macro_auc,
        'Macro_Sensitivity': macro_ext['Macro_Sensitivity'],  # == Macro Recall
        'Macro_Specificity': macro_ext['Macro_Specificity'],
        'Macro_PPV': macro_ext['Macro_PPV'],                  # == Macro Precision
        'Macro_NPV': macro_ext['Macro_NPV'],
        'Macro_F1': macro_ext['Macro_F1'],
        'Weighted_F1': f1_weighted,
        'Kappa': kappa,
        # ── Per-class detail ──
        'per_class': per_class,  # dict keyed by class value
        'report': report,
    }

def metrics_to_flat_row(metrics, prefix=''):
    """Flatten a metrics dict into a single-row dict for DataFrame construction."""
    row = {
        f'{prefix}AUC': metrics['Macro_AUC'],
        f'{prefix}Accuracy': metrics['Accuracy'],
        f'{prefix}Sensitivity': metrics['Macro_Sensitivity'],
        f'{prefix}Specificity': metrics['Macro_Specificity'],
        f'{prefix}PPV': metrics['Macro_PPV'],
        f'{prefix}NPV': metrics['Macro_NPV'],
        f'{prefix}F1': metrics['Macro_F1'],
        f'{prefix}Weighted_F1': metrics['Weighted_F1'],
        f'{prefix}Kappa': metrics['Kappa'],
    }
    return row

def per_class_df(metrics, classes):
    """Build a tidy per-class DataFrame from compute_multiclass_metrics output."""
    rows = []
    for c in classes:
        pc = metrics['per_class'][c]
        rows.append({
            'Class': c,
            'AUC': pc['AUC'],
            'Sensitivity': pc['Sensitivity'],
            'Specificity': pc['Specificity'],
            'PPV': pc['PPV'],
            'NPV': pc['NPV'],
            'F1': pc['F1'],
            'TP': pc['TP'],
            'FP': pc['FP'],
            'FN': pc['FN'],
            'TN': pc['TN'],
        })
    # Append macro row
    rows.append({
        'Class': 'Macro',
        'AUC': metrics['Macro_AUC'],
        'Sensitivity': metrics['Macro_Sensitivity'],
        'Specificity': metrics['Macro_Specificity'],
        'PPV': metrics['Macro_PPV'],
        'NPV': metrics['Macro_NPV'],
        'F1': metrics['Macro_F1'],
        'TP': '', 'FP': '', 'FN': '', 'TN': '',
    })
    return pd.DataFrame(rows)
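
# Illustrative sketch (an addition, never called by the pipeline): how the
# three helpers above are meant to be chained, i.e. compute_multiclass_metrics
# feeding one flat summary row (metrics_to_flat_row) and one tidy per-class
# table (per_class_df). The toy binary arrays below are assumptions made up
# purely for demonstration.
def _demo_metrics_tables():
    classes = [0, 1]
    y_true = np.array([0, 0, 1, 1])
    y_pred = np.array([0, 1, 1, 1])
    y_proba = np.array([[0.8, 0.2], [0.4, 0.6], [0.3, 0.7], [0.2, 0.8]])
    m = compute_multiclass_metrics(y_true, y_pred, y_proba, classes)
    summary_row = metrics_to_flat_row(m, prefix='Val_')  # single dict row
    per_class_table = per_class_df(m, classes)           # DataFrame with a Macro row
    return pd.DataFrame([summary_row]), per_class_table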

# ============================================================================
# Plotting helpers
# ============================================================================
def plot_multiclass_roc(y_true, y_proba, classes, title, filepath_prefix, rf):
    """Plot ROC curves: one-vs-rest for each class + macro average."""
    n_classes = len(classes)
    y_bin = label_binarize(y_true, classes=classes)
    if n_classes == 2:
        y_bin = np.hstack([1 - y_bin, y_bin])
    fpr_dict, tpr_dict, auc_dict = {}, {}, {}
    for i in range(n_classes):
        fpr_dict[i], tpr_dict[i], _ = roc_curve(y_bin[:, i], y_proba[:, i])
        auc_dict[i] = auc_score(fpr_dict[i], tpr_dict[i])
    all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(n_classes)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr_dict[i], tpr_dict[i])
    mean_tpr /= n_classes
    macro_auc = auc_score(all_fpr, mean_tpr)
    COLORS = ['#e41a1c','#377eb8','#4daf4a','#984ea3',
              '#ff7f00','#a65628','#f781bf','#999999']
    plt.figure(figsize=(10, 8))
    for i in range(n_classes):
        plt.plot(fpr_dict[i], tpr_dict[i], color=COLORS[i % len(COLORS)], lw=2,
                 label=f'Class {classes[i]} (AUC={auc_dict[i]:.3f})')
    plt.plot(all_fpr, mean_tpr, 'k--', lw=2.5,
             label=f'Macro Avg (AUC={macro_auc:.3f})')
    plt.plot([0,1],[0,1],'--',color='#cccccc',lw=1)
    plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
    plt.xlabel('False Positive Rate', fontsize=13)
    plt.ylabel('True Positive Rate', fontsize=13)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.legend(loc='lower right', fontsize=9)
    plt.grid(True, alpha=0.15); plt.tight_layout()
    plt.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'),
                format='pdf', bbox_inches='tight', dpi=300)
    plt.savefig(os.path.join(rf, f'{filepath_prefix}.png'),
                format='png', bbox_inches='tight', dpi=150)
    plt.close()
    return macro_auc, auc_dict
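
# Minimal sketch (an addition, not called anywhere) of the macro-averaging
# step used above and in the train-vs-validation overlays further down:
# per-class TPRs are interpolated onto a shared FPR grid, averaged, and the
# AUC of the mean curve is taken. Inputs are assumed to be an already
# binarized label matrix and a matching probability matrix.
def _demo_macro_average_roc(y_bin, y_proba, n_classes, grid_size=100):
    all_fpr = np.linspace(0, 1, grid_size)       # common FPR grid
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        fpr_i, tpr_i, _ = roc_curve(y_bin[:, i], y_proba[:, i])
        mean_tpr += np.interp(all_fpr, fpr_i, tpr_i)
    mean_tpr /= n_classes                        # average of per-class TPRs
    return all_fpr, mean_tpr, auc_score(all_fpr, mean_tpr)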

def plot_multiclass_pr(y_true, y_proba, classes, title, filepath_prefix, rf):
    """Plot Precision-Recall curves for each class."""
    n_classes = len(classes)
    y_bin = label_binarize(y_true, classes=classes)
    if n_classes == 2:
        y_bin = np.hstack([1 - y_bin, y_bin])
    COLORS = ['#e41a1c','#377eb8','#4daf4a','#984ea3',
              '#ff7f00','#a65628','#f781bf','#999999']
    plt.figure(figsize=(10, 8))
    for i in range(n_classes):
        prec, rec, _ = precision_recall_curve(y_bin[:, i], y_proba[:, i])
        ap = auc_score(rec, prec)
        plt.plot(rec, prec, color=COLORS[i % len(COLORS)], lw=2,
                 label=f'Class {classes[i]} (AP={ap:.3f})')
    plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
    plt.xlabel('Recall', fontsize=13); plt.ylabel('Precision', fontsize=13)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.legend(loc='lower left', fontsize=9)
    plt.grid(True, alpha=0.15); plt.tight_layout()
    plt.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'),
                format='pdf', bbox_inches='tight', dpi=300)
    plt.savefig(os.path.join(rf, f'{filepath_prefix}.png'),
                format='png', bbox_inches='tight', dpi=150)
    plt.close()

def plot_confusion_matrix(y_true, y_pred, classes, title, filepath_prefix, rf):
    """Plot confusion matrix heatmap."""
    cm = confusion_matrix(y_true, y_pred, labels=classes)
    plt.figure(figsize=(max(6, len(classes)*1.2), max(5, len(classes)*1.0)))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True,
                xticklabels=classes, yticklabels=classes,
                annot_kws={'fontsize': 11})
    plt.xlabel('Predicted', fontsize=12); plt.ylabel('True', fontsize=12)
    plt.title(title, fontsize=13, fontweight='bold'); plt.tight_layout()
    plt.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'),
                format='pdf', bbox_inches='tight', dpi=300)
    plt.savefig(os.path.join(rf, f'{filepath_prefix}.png'),
                format='png', bbox_inches='tight', dpi=150)
    plt.close()
    return cm

def plot_train_vs_val_roc(y_train, train_proba, y_val, val_proba,
                          classes, model_name, filepath_prefix, rf):
    """Overlay train-set ROC and internal-validation (CV OOF) ROC."""
    n_classes = len(classes)
    def macro_roc(y_true, y_proba):
        y_bin = label_binarize(y_true, classes=classes)
        if n_classes == 2:
            y_bin = np.hstack([1 - y_bin, y_bin])
        all_fpr = np.linspace(0, 1, 300)
        mean_tpr = np.zeros_like(all_fpr)
        for i in range(n_classes):
            f, t, _ = roc_curve(y_bin[:, i], y_proba[:, i])
            mean_tpr += np.interp(all_fpr, f, t)
        mean_tpr /= n_classes; mean_tpr[-1] = 1.0
        return all_fpr, mean_tpr, auc_score(all_fpr, mean_tpr)
    fpr_tr, tpr_tr, auc_tr = macro_roc(y_train, train_proba)
    fpr_vl, tpr_vl, auc_vl = macro_roc(y_val, val_proba)
    plt.figure(figsize=(10, 8))
    plt.plot(fpr_tr, tpr_tr, color='#e41a1c', lw=2.5,
             label=f'Training set (Macro AUC={auc_tr:.3f})')
    plt.plot(fpr_vl, tpr_vl, color='#377eb8', lw=2.5, linestyle='--',
             label=f'Internal validation / CV-OOF (Macro AUC={auc_vl:.3f})')
    plt.plot([0,1],[0,1],'--',color='#cccccc',lw=1)
    plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
    plt.xlabel('False Positive Rate', fontsize=13)
    plt.ylabel('True Positive Rate', fontsize=13)
    plt.title(f'ROC — Train vs Internal Validation — {model_name}',
              fontsize=14, fontweight='bold')
    plt.legend(loc='lower right', fontsize=11)
    plt.grid(True, alpha=0.15); plt.tight_layout()
    plt.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'),
                format='pdf', bbox_inches='tight', dpi=300)
    plt.savefig(os.path.join(rf, f'{filepath_prefix}.png'),
                format='png', bbox_inches='tight', dpi=150)
    plt.close()
    return auc_tr, auc_vl

def plot_train_vs_val_pr(y_train, train_proba, y_val, val_proba,
                         classes, model_name, filepath_prefix, rf):
    """Overlay train-set PR and internal-validation PR."""
    n_classes = len(classes)
    def macro_pr(y_true, y_proba):
        y_bin = label_binarize(y_true, classes=classes)
        if n_classes == 2:
            y_bin = np.hstack([1 - y_bin, y_bin])
        all_rec = np.linspace(0, 1, 300)
        mean_prec = np.zeros_like(all_rec)
        for i in range(n_classes):
            prec, rec, _ = precision_recall_curve(y_bin[:, i], y_proba[:, i])
            mean_prec += np.interp(all_rec, rec[::-1], prec[::-1])
        mean_prec /= n_classes
        return all_rec, mean_prec, auc_score(all_rec, mean_prec)
    rec_tr, prec_tr, ap_tr = macro_pr(y_train, train_proba)
    rec_vl, prec_vl, ap_vl = macro_pr(y_val, val_proba)
    plt.figure(figsize=(10, 8))
    plt.plot(rec_tr, prec_tr, color='#e41a1c', lw=2.5,
             label=f'Training set (Macro AP={ap_tr:.3f})')
    plt.plot(rec_vl, prec_vl, color='#377eb8', lw=2.5, linestyle='--',
             label=f'Internal validation / CV-OOF (Macro AP={ap_vl:.3f})')
    plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
    plt.xlabel('Recall', fontsize=13); plt.ylabel('Precision', fontsize=13)
    plt.title(f'PR — Train vs Internal Validation — {model_name}',
              fontsize=14, fontweight='bold')
    plt.legend(loc='lower left', fontsize=11)
    plt.grid(True, alpha=0.15); plt.tight_layout()
    plt.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'),
                format='pdf', bbox_inches='tight', dpi=300)
    plt.savefig(os.path.join(rf, f'{filepath_prefix}.png'),
                format='png', bbox_inches='tight', dpi=150)
    plt.close()
    return ap_tr, ap_vl

# ============================================================================
# Bootstrap AUC test [FIX-2 retained]
# ============================================================================
def bootstrap_auc_test(y_true, proba_a, proba_b, classes,
                       n_bootstrap=2000, seed=42):
    rng = np.random.RandomState(seed)
    n = len(y_true)
    n_classes = len(classes)
    def calc_macro_auc(yt, pa, pb):
        try:
            if n_classes == 2:
                a1 = roc_auc_score(yt, pa[:, 1])
                a2 = roc_auc_score(yt, pb[:, 1])
            else:
                a1 = roc_auc_score(yt, pa, multi_class='ovr', average='macro')
                a2 = roc_auc_score(yt, pb, multi_class='ovr', average='macro')
            return a1, a2
        except:
            return 0.0, 0.0
    auc_a, auc_b = calc_macro_auc(y_true, proba_a, proba_b)
    observed_diff = auc_a - auc_b
    diffs = []
    for _ in range(n_bootstrap):
        idx = rng.choice(n, n, replace=True)
        yt_b = y_true[idx]; pa_b = proba_a[idx]; pb_b = proba_b[idx]
        if len(np.unique(yt_b)) < n_classes:
            continue
        a1, a2 = calc_macro_auc(yt_b, pa_b, pb_b)
        diffs.append(a1 - a2)
    if len(diffs) < 100:
        return 1.0, auc_a, auc_b, -1, 1
    diffs = np.array(diffs)
    # [FIX-2] H0: diff = 0, two-sided
    p_value = np.mean(np.abs(diffs) >= np.abs(observed_diff))
    p_value = max(p_value, 1.0 / n_bootstrap)
    ci_low = np.percentile(diffs, 2.5)
    ci_high = np.percentile(diffs, 97.5)
    return p_value, auc_a, auc_b, ci_low, ci_high
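
# Illustrative sketch (an addition, never called by the pipeline) of how the
# [FIX-2] p-value is read: under H0 the AUC difference is 0, so the two-sided
# p-value is the share of bootstrap differences at least as extreme, in
# absolute value, as the observed one. The synthetic 3-class data below is an
# assumption made up for demonstration only.
def _demo_bootstrap_auc_test(seed=0):
    rng = np.random.RandomState(seed)
    n, k = 200, 3
    y = rng.randint(0, k, n)
    # Model A: mildly informative probabilities; Model B: pure noise.
    logits_a = rng.randn(n, k) + 1.5 * np.eye(k)[y]
    proba_a = np.exp(logits_a) / np.exp(logits_a).sum(axis=1, keepdims=True)
    logits_b = rng.randn(n, k)
    proba_b = np.exp(logits_b) / np.exp(logits_b).sum(axis=1, keepdims=True)
    p, auc_a, auc_b, lo, hi = bootstrap_auc_test(
        y, proba_a, proba_b, classes=[0, 1, 2], n_bootstrap=500)
    return p, auc_a, auc_b, (lo, hi)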

# ============================================================================
# [FIX-1] Model configs — XGBoost num_class constructed conditionally
# ============================================================================
ALL_MODEL_NAMES = ['RF', 'DT', 'KNN', 'XGB', 'AdaBoost', 'LR', 'NB', 'SVM']

def get_models_config(selected, n_classes, rs=42):
    xgb_kwargs = dict(random_state=rs, eval_metric='mlogloss', n_jobs=-1)
    if n_classes > 2:
        xgb_kwargs['objective'] = 'multi:softprob'
        xgb_kwargs['num_class'] = n_classes
    else:
        xgb_kwargs['objective'] = 'binary:logistic'
        xgb_kwargs['eval_metric'] = 'logloss'
    cfg = {
        'RF': {'model': RandomForestClassifier(random_state=rs, n_jobs=-1),
               'params': {'n_estimators': [100,200], 'max_depth': [20,50],
                          'min_samples_split': [2,5]}},
        'DT': {'model': DecisionTreeClassifier(random_state=rs),
               'params': {'max_depth': [20,50], 'min_samples_split': [2,10],
                          'criterion': ['gini','entropy']}},
        'KNN': {'model': KNeighborsClassifier(n_jobs=-1),
                'params': {'n_neighbors': [3,5,7],
                           'weights': ['uniform','distance']}},
        'XGB': {'model': XGBClassifier(**xgb_kwargs),
                'params': {'n_estimators': [100,200], 'max_depth': [5,7],
                           'learning_rate': [0.05,0.1]}},
        'AdaBoost': {'model': AdaBoostClassifier(random_state=rs),
                     'params': {'n_estimators': [50,100],
                                'learning_rate': [0.1,0.5,1.0]}},
        'LR': {'model': LogisticRegression(random_state=rs, n_jobs=-1,
                                           max_iter=2000),
               'params': {'C': [0.1,1,10], 'solver': ['lbfgs']}},
        'NB': {'model': GaussianNB(),
               'params': {'var_smoothing': [1e-9,1e-7,1e-5]}},
        'SVM': {'model': SVC(probability=True, random_state=rs,
                             decision_function_shape='ovr'),
                'params': {'C': [1,10], 'kernel': ['rbf','linear']}},
    }
    return {k: v for k, v in cfg.items() if k in selected}
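
# Small sketch (an addition, never called by the pipeline) of the [FIX-1]
# behaviour: the XGBoost estimator is configured differently for binary and
# multi-class tasks, so num_class is only ever passed as a real integer and
# never as None. Nothing here is fitted; the configs are just built and read.
def _demo_xgb_config_difference():
    bin_model = get_models_config(['XGB'], n_classes=2)['XGB']['model']
    multi_model = get_models_config(['XGB'], n_classes=4)['XGB']['model']
    # objective is an explicit XGBClassifier constructor argument, so it is
    # safe to read back: 'binary:logistic' vs 'multi:softprob'.
    return bin_model.objective, multi_model.objective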

# ============================================================================
# Main Pipeline
# ============================================================================
def run_pipeline(
    train_file, val_file1, val_file2, val_file3, n_classes_select,
    selected_models, enable_tuning,
    cv_folds, top_n_features, shap_sample_size,
    progress=gr.Progress(track_tqdm=True),
):
    if train_file is None:
        return None, "❌ 请先上传训练集 CSV 文件"
    sel = (selected_models if isinstance(selected_models, list)
           else [s.strip() for s in str(selected_models).split(",") if s.strip()])
    if not sel:
        return None, "❌ 请至少选择一个模型"
    RS = 42; CVF = int(cv_folds)
    TOPN = int(top_n_features); SHAPSZ = int(shap_sample_size)
    TUNING = bool(enable_tuning)
    L = []
    def log(m): L.append(str(m))
    rf = tempfile.mkdtemp(prefix="ml_")
    try:
        # ── Load Data ──
        progress(0.02, desc="📂 加载数据...")
        log("━" * 60)
        log(" 🧬 ML 多分类模型训练与评估系统 v3")
        log("━" * 60)
        tp = (train_file if isinstance(train_file, str)
              else getattr(train_file, 'name', str(train_file)))
        data = pd.read_csv(tp)
        y = data.iloc[:, 0]
        col2 = data.iloc[:, 1]
        col2_is_id = ((col2.dtype == 'object') or
                      (col2.nunique() / len(col2) > 0.5))
        if col2_is_id:
            X = data.iloc[:, 2:]
            log(f" 📋 CSV: Col1=Label, Col2=ID({data.columns[1]}), Col3+=Features")
        else:
            X = data.iloc[:, 1:]
            log(f" 📋 CSV: Col1=Label, Col2+=Features (no ID column)")
        fnames = X.columns.tolist()
        user_n = int(str(n_classes_select).split(" ")[0])
        detected_classes = sorted(y.unique())
        detected_classes = [int(c) if hasattr(c, 'item') else c
                            for c in detected_classes]
        detected_n = len(detected_classes)
        if detected_n != user_n:
            return None, (
                f"❌ 您选择了 {user_n} 分类,但数据中检测到 {detected_n} 个类别: "
                f"{detected_classes}\n请将分类数修改为 {detected_n},或检查数据标签列")
        classes = detected_classes
        n_classes = user_n
        log(f" ✅ {n_classes} 分类 — 数据验证通过")
        label_map = {c: i for i, c in enumerate(classes)}
        label_map_inv = {i: c for c, i in label_map.items()}
        y_mapped = y.map(label_map)
        class_indices = list(range(n_classes))
        log(f" 📊 训练集: {X.shape[0]} 样本 × {X.shape[1]} 特征")
        log(f" 🏷️ 类别数: {n_classes} 类 — {classes}")
        log(f" 📊 分布: {dict(y.value_counts().sort_index())}")
        log(f" 🤖 模型: {', '.join(sel)}")
        log(f" 🔧 调优: {'开启' if TUNING else '关闭'} | CV: {CVF}折")
        if n_classes < 2 or n_classes > 8:
            return None, f"❌ 仅支持 2~8 分类,当前检测到 {n_classes} 类"
        task_type = "Binary" if n_classes == 2 else f"{n_classes}-Class"
        task_type_cn = "二分类" if n_classes == 2 else f"{n_classes}分类"
        log(f" 📋 任务: {task_type_cn} ({task_type})")
        mcfg = get_models_config(sel, n_classes, RS)
        skf = StratifiedKFold(n_splits=CVF, shuffle=True, random_state=RS)
        COLORS = ['#2563eb','#f59e0b','#10b981','#ef4444',
                  '#8b5cf6','#ec4899','#06b6d4','#6b7280']
        bpd = {}            # best params
        amr = {}            # all model results (CV-OOF)
        tms = {}            # trained models (full data)
        train_results = {}  # metrics on full training set
        total = len(mcfg)
        # ── Train All Models ──────────────────────────────────────────────
        for mi, (mn, cf) in enumerate(mcfg.items()):
            pv = 0.05 + 0.35 * mi / total
            progress(pv, desc=f"🏋️ [{mi+1}/{total}] 训练 {mn}...")
            log(f"\n{'─'*50}")
            log(f" 🔄 [{mi+1}/{total}] {mn}")
            Xv = X.values
            # Optional GridSearch
            if TUNING:
                log(f" ⏳ GridSearchCV (CV={CVF})...")
                scoring = 'roc_auc_ovr' if n_classes > 2 else 'roc_auc'
                gs = GridSearchCV(cf['model'], cf['params'], cv=skf,
                                  scoring=scoring, n_jobs=-1, verbose=0)
                gs.fit(Xv, y_mapped)
                bp = gs.best_params_; bpd[mn] = bp
                log(f" ✓ 最佳CV Score: {gs.best_score_:.4f}")
            else:
                bp = {}; bpd[mn] = "Default"
            # Fit final model on full training set
            mdl = deepcopy(cf['model'])
            if bp: mdl.set_params(**bp)
            mdl.fit(Xv, y_mapped)
            tms[mn] = mdl
            # Training-set metrics
            train_proba_full = mdl.predict_proba(Xv)
            train_pred_full = mdl.predict(Xv)
            train_met = compute_multiclass_metrics(
                y_mapped.values, train_pred_full,
                train_proba_full, class_indices)
            train_results[mn] = {
                'proba': train_proba_full,
                'pred': train_pred_full,
                'metrics': train_met,
            }
            # ── CV evaluation (OOF = Internal Validation) ──
            all_yt = []; all_yp = []; all_yproba = []
            fold_metrics = []
            for fi, (tri, tei) in enumerate(skf.split(X, y_mapped), 1):
                Xtr, Xte = X.iloc[tri].values, X.iloc[tei].values
                ytr, yte = y_mapped.iloc[tri], y_mapped.iloc[tei]
                mf = deepcopy(cf['model'])
                if bp: mf.set_params(**bp)
                mf.fit(Xtr, ytr)
                ypred = mf.predict(Xte)
                yproba = mf.predict_proba(Xte)
                all_yt.extend(yte)
                all_yp.extend(ypred)
                all_yproba.append(yproba)
                fm = compute_multiclass_metrics(yte, ypred, yproba, class_indices)
                # [v3-2] Extended fold row
                fold_metrics.append({
                    'Fold': fi,
                    'AUC': fm['Macro_AUC'],
                    'Accuracy': fm['Accuracy'],
                    'Sensitivity': fm['Macro_Sensitivity'],
                    'Specificity': fm['Macro_Specificity'],
                    'PPV': fm['Macro_PPV'],
                    'NPV': fm['Macro_NPV'],
                    'F1': fm['Macro_F1'],
                    'Weighted_F1': fm['Weighted_F1'],
                    'Kappa': fm['Kappa'],
                })
            all_yt = np.array(all_yt)
            all_yp = np.array(all_yp)
            all_yproba = np.vstack(all_yproba)
            # Build fold table with Mean row
            fdf = pd.DataFrame(fold_metrics)
            mean_row = {
                col: (fdf[col].mean() if col != 'Fold' else 'Mean')
                for col in fdf.columns
            }
            fdf = pd.concat([fdf, pd.DataFrame([mean_row])], ignore_index=True)
            # OOF aggregate metrics (computed on concatenated OOF predictions)
            oof_met = compute_multiclass_metrics(
                all_yt, all_yp, all_yproba, class_indices)
            amr[mn] = {
                'fold_df': fdf,
                'mean_auc': mean_row['AUC'],
                'mean_acc': mean_row['Accuracy'],
                'mean_sens': mean_row['Sensitivity'],
                'mean_spec': mean_row['Specificity'],
                'mean_ppv': mean_row['PPV'],
                'mean_npv': mean_row['NPV'],
                'mean_f1': mean_row['F1'],
                'mean_wf1': mean_row['Weighted_F1'],
                'mean_kappa': mean_row['Kappa'],
                'oof_metrics': oof_met,
                'all_yt': all_yt,
                'all_yp': all_yp,
                'all_yproba': all_yproba,
            }
            # [v3-7] Log all key metrics
            tm = train_met; vm = mean_row
            log(f" ✅ [Train] AUC={tm['Macro_AUC']:.4f} Acc={tm['Accuracy']:.4f} "
                f"Sens={tm['Macro_Sensitivity']:.4f} Spec={tm['Macro_Specificity']:.4f} "
                f"PPV={tm['Macro_PPV']:.4f} NPV={tm['Macro_NPV']:.4f} "
                f"F1={tm['Macro_F1']:.4f} Kappa={tm['Kappa']:.4f}")
            log(f" ✅ [CV-OOF] AUC={vm['AUC']:.4f} Acc={vm['Accuracy']:.4f} "
                f"Sens={vm['Sensitivity']:.4f} Spec={vm['Specificity']:.4f} "
                f"PPV={vm['PPV']:.4f} NPV={vm['NPV']:.4f} "
                f"F1={vm['F1']:.4f} Kappa={vm['Kappa']:.4f}")
        mnames = list(amr.keys()); nm = len(mnames)
        log(f"\n{'━'*60}")
        log(f" ✅ {nm} 个模型训练完成")
        # ── Training-set ROC / PR / CM for every model ────────────────────
        progress(0.40, desc="📈 训练集曲线...")
        log(f"\n 📈 绘制训练集 ROC / PR / CM...")
        for mn in mnames:
            tr = train_results[mn]
            tm = tr['metrics']
            plot_multiclass_roc(
                y_mapped.values, tr['proba'], class_indices,
                f'ROC (Train) — {mn} ({task_type}, AUC={tm["Macro_AUC"]:.3f})',
                f'roc_train_{mn}', rf)
            plot_multiclass_pr(
                y_mapped.values, tr['proba'], class_indices,
                f'PR (Train) — {mn} ({task_type})',
                f'pr_train_{mn}', rf)
            plot_confusion_matrix(
                y_mapped.values, tr['pred'], class_indices,
                f'CM (Train) — {mn} (Acc={tm["Accuracy"]:.3f})',
                f'cm_train_{mn}', rf)
        # Combined training-set ROC (all models, macro)
        plt.figure(figsize=(10, 8))
        for i, mn in enumerate(mnames):
            tr = train_results[mn]
            y_bin = label_binarize(y_mapped.values, classes=class_indices)
            if n_classes == 2: y_bin = np.hstack([1 - y_bin, y_bin])
            all_fpr = np.linspace(0, 1, 200); mean_tpr = np.zeros_like(all_fpr)
            for c in range(n_classes):
                f, t, _ = roc_curve(y_bin[:, c], tr['proba'][:, c])
                mean_tpr += np.interp(all_fpr, f, t)
            mean_tpr /= n_classes; mean_tpr[-1] = 1.0
            ma = auc_score(all_fpr, mean_tpr)
            plt.plot(all_fpr, mean_tpr, color=COLORS[i%8], lw=2.5,
                     label=f'{mn} (Macro AUC={ma:.3f})')
        plt.plot([0,1],[0,1],'--',color='#ccc',lw=1)
        plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
        plt.xlabel('FPR',fontsize=13); plt.ylabel('TPR',fontsize=13)
        plt.title(f'ROC (Train) — All Models ({task_type})',
                  fontsize=14, fontweight='bold')
        plt.legend(loc='lower right',fontsize=10)
        plt.grid(True,alpha=0.15); plt.tight_layout()
        plt.savefig(os.path.join(rf,'roc_train_all.pdf'),
                    format='pdf',bbox_inches='tight',dpi=300)
        plt.savefig(os.path.join(rf,'roc_train_all.png'),
                    format='png',bbox_inches='tight',dpi=150)
        plt.close()
        # ── CV-OOF ROC / PR / CM ──────────────────────────────────────────
        progress(0.44, desc="📈 内部验证ROC曲线...")
        log(f"\n 📈 绘制内部验证(CV-OOF) ROC / PR / CM...")
        for mn in mnames:
            r = amr[mn]
            plot_multiclass_roc(
                r['all_yt'], r['all_yproba'], class_indices,
                f'ROC (Internal Val) — {mn} ({task_type}, AUC={r["mean_auc"]:.3f})',
                f'roc_val_{mn}', rf)
        # Combined CV-OOF ROC
        plt.figure(figsize=(10, 8))
        for i, mn in enumerate(mnames):
            r = amr[mn]
            y_bin = label_binarize(r['all_yt'], classes=class_indices)
            if n_classes == 2: y_bin = np.hstack([1 - y_bin, y_bin])
            all_fpr = np.linspace(0, 1, 200); mean_tpr = np.zeros_like(all_fpr)
            for c in range(n_classes):
                f, t, _ = roc_curve(y_bin[:, c], r['all_yproba'][:, c])
                mean_tpr += np.interp(all_fpr, f, t)
            mean_tpr /= n_classes; mean_tpr[-1] = 1.0
            ma = auc_score(all_fpr, mean_tpr)
            plt.plot(all_fpr, mean_tpr, color=COLORS[i%8], lw=2.5,
                     label=f'{mn} (Macro AUC={ma:.3f})')
        plt.plot([0,1],[0,1],'--',color='#ccc',lw=1)
        plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
        plt.xlabel('FPR',fontsize=13); plt.ylabel('TPR',fontsize=13)
        plt.title(f'ROC (Internal Val / CV-OOF) — All Models ({task_type})',
                  fontsize=14, fontweight='bold')
        plt.legend(loc='lower right',fontsize=10)
        plt.grid(True,alpha=0.15); plt.tight_layout()
        plt.savefig(os.path.join(rf,'roc_val_all.pdf'),
                    format='pdf',bbox_inches='tight',dpi=300)
        plt.savefig(os.path.join(rf,'roc_val_all.png'),
                    format='png',bbox_inches='tight',dpi=150)
        plt.close()
        progress(0.48, desc="📈 PR曲线...")
        for mn in mnames:
            r = amr[mn]
            plot_multiclass_pr(
                r['all_yt'], r['all_yproba'], class_indices,
                f'PR (Internal Val) — {mn} ({task_type})',
                f'pr_val_{mn}', rf)
        progress(0.51, desc="📊 混淆矩阵...")
        for mn in mnames:
            r = amr[mn]
            plot_confusion_matrix(
                r['all_yt'], r['all_yp'], class_indices,
                f'CM (Internal Val) — {mn} (Acc={r["mean_acc"]:.3f})',
                f'cm_val_{mn}', rf)
        # ── Bootstrap AUC Test ────────────────────────────────────────────
        progress(0.54, desc="🔬 Bootstrap AUC 检验...")
        best_mn = max(amr, key=lambda x: amr[x]['mean_auc'])
        best_auc = amr[best_mn]['mean_auc']
        log(f"\n 🏆 最佳模型: {best_mn} (Macro AUC={best_auc:.4f})")
        log(f" 🔬 Bootstrap 检验 (n=2000, α=0.05)...")
        ALPHA = 0.05
        bootstrap_results = []
        retained = [best_mn]
        for om in mnames:
            if om == best_mn:
                continue
            p_val, auc_a, auc_b, ci_lo, ci_hi = bootstrap_auc_test(
                amr[best_mn]['all_yt'],
                amr[best_mn]['all_yproba'],
                amr[om]['all_yproba'],
                class_indices, n_bootstrap=2000)
            dec = "Retained" if p_val >= ALPHA else "Excluded"
            if p_val >= ALPHA:
                retained.append(om)
            bootstrap_results.append({
                'Model_A': best_mn, 'AUC_A': auc_a,
                'Model_B': om, 'AUC_B': auc_b,
                'AUC_Diff': auc_a - auc_b,
                'CI_95_Low': ci_lo, 'CI_95_High': ci_hi,
                'P_value': p_val, 'Decision': dec,
            })
            log(f" {best_mn} vs {om}: ΔAUC={auc_a-auc_b:+.4f} "
                f"95%CI=[{ci_lo:+.4f},{ci_hi:+.4f}] "
                f"P={p_val:.4f} → {dec}")
        bootstrap_df = (pd.DataFrame(bootstrap_results)
                        .sort_values('P_value', ascending=False)
                        if bootstrap_results else pd.DataFrame())
        log(f" ✅ 保留 {len(retained)}/{nm} 个模型: {', '.join(retained)}")
        # ── Best model: Train vs Internal Val overlay ─────────────────────
        progress(0.57, desc="📈 Train vs Val 对比图...")
        log(f"\n 📈 最佳模型 {best_mn}: Train vs Internal Validation 对比...")
        auc_tr_b, auc_vl_b = plot_train_vs_val_roc(
            y_mapped.values, train_results[best_mn]['proba'],
            amr[best_mn]['all_yt'], amr[best_mn]['all_yproba'],
            class_indices, best_mn, f'roc_train_vs_val_{best_mn}', rf)
        ap_tr_b, ap_vl_b = plot_train_vs_val_pr(
            y_mapped.values, train_results[best_mn]['proba'],
            amr[best_mn]['all_yt'], amr[best_mn]['all_yproba'],
            class_indices, best_mn, f'pr_train_vs_val_{best_mn}', rf)
        log(f" ROC — Train AUC={auc_tr_b:.4f} / Val AUC={auc_vl_b:.4f}")
        log(f" PR — Train AP={ap_tr_b:.4f} / Val AP={ap_vl_b:.4f}")
        # ── SHAP ──────────────────────────────────────────────────────────
        progress(0.60, desc="🔥 SHAP分析...")
        log(f"\n 🔥 SHAP特征分析 (保留模型中 Top 3)...")
        shap_imp = {}
        models_for_shap = sorted(retained,
                                 key=lambda x: amr[x]['mean_auc'],
                                 reverse=True)[:3]
        for si, mn in enumerate(models_for_shap):
            progress(0.60 + 0.10 * si / max(len(models_for_shap), 1),
                     desc=f"🔥 SHAP: {mn}...")
            mo = tms[mn]; Xshap = X.values
            ns = min(SHAPSZ, Xshap.shape[0])
            np.random.seed(RS)
            sidx = np.random.choice(Xshap.shape[0], ns, replace=False)
            Xs = Xshap[sidx]
            try:
                if mn in ['RF', 'XGB', 'DT', 'AdaBoost']:
                    exp = shap.TreeExplainer(mo)
                    sv = exp.shap_values(Xs)
                else:
                    bg = Xs[np.random.choice(ns, min(50, ns), replace=False)]
                    exp = shap.KernelExplainer(
                        lambda x, m=mo: m.predict_proba(x), bg)
                    sv = exp.shap_values(Xs)
                # [FIX-3] Robust SHAP shape handling
                if isinstance(sv, list):
                    sv_abs = np.mean([np.abs(s) for s in sv], axis=0)
                elif sv.ndim == 3:
                    if sv.shape[2] == n_classes:
                        sv_abs = np.mean(np.abs(sv), axis=2)
                    elif sv.shape[1] == n_classes:
                        sv_abs = np.mean(np.abs(sv), axis=1)
                    else:
                        sv_abs = np.abs(sv).mean(axis=-1)
                else:
                    sv_abs = np.abs(sv)
                fi = sv_abs.mean(axis=0)
                if len(fi) > len(fnames): fi = fi[:len(fnames)]
                elif len(fi) < len(fnames):
                    fi = np.pad(fi, (0, len(fnames) - len(fi)))
                idf = (pd.DataFrame({'Feature': fnames, 'Importance': fi})
                       .sort_values('Importance', ascending=False))
                shap_imp[mn] = idf
                plt.figure(figsize=(10, max(6, TOPN * 0.3)))
                top_df = idf.head(TOPN).iloc[::-1]
                plt.barh(top_df['Feature'], top_df['Importance'],
                         color='#2563eb', alpha=0.8)
                plt.xlabel('Mean |SHAP|', fontsize=12)
                plt.title(f'SHAP Feature Importance — {mn} (Top {TOPN})',
                          fontsize=13, fontweight='bold')
                plt.tight_layout()
                plt.savefig(os.path.join(rf, f'shap_{mn}.pdf'),
                            format='pdf', bbox_inches='tight')
                plt.savefig(os.path.join(rf, f'shap_{mn}.png'),
                            format='png', bbox_inches='tight', dpi=150)
                plt.close()
                log(f" ✅ {mn} Top3: "
                    f"{', '.join(idf.head(3)['Feature'].tolist())}")
            except Exception as e:
                log(f" ⚠ {mn} SHAP失败: {e}")
        # ── Feature Ablation ──────────────────────────────────────────────
        progress(0.72, desc="🧪 特征消融...")
        log(f"\n 🧪 特征消融 (仅最佳模型 {best_mn})...")
        ablation_data = None
        if best_mn in shap_imp:
            imp_df = shap_imp[best_mn]
            top_feats = imp_df.head(TOPN)['Feature'].tolist()
            fcs = []; aucs_a = []
            for nf in range(1, len(top_feats) + 1):
                Xsub = X[top_feats[:nf]]
                fold_aucs = []
                for tri, tei in skf.split(Xsub, y_mapped):
                    mf = deepcopy(mcfg[best_mn]['model'])
                    bp2 = bpd.get(best_mn, {})
                    if isinstance(bp2, dict) and bp2:
                        mf.set_params(**bp2)
                    mf.fit(Xsub.iloc[tri].values, y_mapped.iloc[tri])
                    yproba_f = mf.predict_proba(Xsub.iloc[tei].values)
                    yte_f = y_mapped.iloc[tei]
                    try:
                        a = (roc_auc_score(yte_f, yproba_f[:, 1])
                             if n_classes == 2 else
                             roc_auc_score(yte_f, yproba_f,
                                           multi_class='ovr', average='macro'))
                    except:
                        a = 0.0
                    fold_aucs.append(a)
                fcs.append(nf); aucs_a.append(np.mean(fold_aucs))
            full_auc = amr[best_mn]['mean_auc']
            opt_n = len(top_feats)
            for i, a in enumerate(aucs_a):
                if a >= full_auc * 0.95:
                    opt_n = i + 1; break
            ablation_data = {
                'fcs': fcs, 'aucs': aucs_a, 'feats': top_feats,
                'opt_n': opt_n, 'opt_feats': top_feats[:opt_n]
            }
            log(f" ✅ 最优特征数: {opt_n} "
                f"(AUC={aucs_a[opt_n-1]:.4f} vs Full={full_auc:.4f})")
            plt.figure(figsize=(10, 7))
            plt.plot(fcs, aucs_a, 'o-', color='#2563eb', lw=2, ms=5)
            plt.scatter([opt_n], [aucs_a[opt_n-1]], s=200, marker='*',
                        color='#ef4444', edgecolors='black', lw=2, zorder=5)
            plt.axhline(y=full_auc, color='gray', ls='--', lw=1, alpha=0.5,
                        label=f'Full AUC={full_auc:.3f}')
            plt.xlabel('Number of Features', fontsize=13)
            plt.ylabel('Macro AUC', fontsize=13)
            plt.title(f'Feature Ablation — {best_mn} (★ Optimal={opt_n})',
                      fontsize=14, fontweight='bold')
            plt.legend(fontsize=11); plt.grid(True, alpha=0.15); plt.tight_layout()
            plt.savefig(os.path.join(rf, 'ablation.pdf'),
                        format='pdf', bbox_inches='tight')
            plt.savefig(os.path.join(rf, 'ablation.png'),
                        format='png', bbox_inches='tight', dpi=150)
            plt.close()
        # ── External Validation ───────────────────────────────────────────
        val_files_list = [vf for vf in [val_file1, val_file2, val_file3]
                          if vf is not None]
        final_feats = ablation_data['opt_feats'] if ablation_data else fnames
        if val_files_list:
            progress(0.82, desc="🧪 外部验证...")
            log(f"\n{'━'*60}")
            log(f" 🧪 外部验证 ({len(val_files_list)} 个验证集)")
            for vi, vf in enumerate(val_files_list, 1):
                vp = (vf if isinstance(vf, str)
                      else getattr(vf, 'name', str(vf)))
                ed = pd.read_csv(vp); ye_raw = ed.iloc[:, 0]
                vcol2 = ed.iloc[:, 1]
                vcol2_is_id = ((vcol2.dtype == 'object') or
                               (vcol2.nunique() / len(vcol2) > 0.5))
                Xe = ed.iloc[:, 2:] if vcol2_is_id else ed.iloc[:, 1:]
                ye = ye_raw.map(label_map)
                if ye.isna().any():
                    log(f" ⚠ 验证集 {vi} 含有训练集中不存在的标签,已跳过")
                    continue
                log(f"\n 📊 验证集 {vi}: {Xe.shape[0]} 样本, "
                    f"{os.path.basename(vp)}")
                Xes = Xe[final_feats]; Xtf = X[final_feats]
                fm = deepcopy(mcfg[best_mn]['model'])
                bp3 = bpd[best_mn]
                if isinstance(bp3, dict) and bp3:
                    fm.set_params(**bp3)
                fm.fit(Xtf.values, y_mapped)
                yep = fm.predict_proba(Xes.values)
                yed = fm.predict(Xes.values)
                ye_np = ye.values
                ext_met = compute_multiclass_metrics(
                    ye_np, yed, yep, class_indices)
                em = ext_met
                log(f" ✅ AUC={em['Macro_AUC']:.4f} "
                    f"Acc={em['Accuracy']:.4f} "
                    f"Sens={em['Macro_Sensitivity']:.4f} "
                    f"Spec={em['Macro_Specificity']:.4f} "
                    f"PPV={em['Macro_PPV']:.4f} "
                    f"NPV={em['Macro_NPV']:.4f} "
                    f"F1={em['Macro_F1']:.4f} "
                    f"Kappa={em['Kappa']:.4f}")
                sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
                tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
                plot_multiclass_roc(ye_np, yep, class_indices,
                                    f'ROC — {tag} ({best_mn})', f'roc{sfx}', rf)
                plot_multiclass_pr(ye_np, yep, class_indices,
                                   f'PR — {tag} ({best_mn})', f'pr{sfx}', rf)
                plot_confusion_matrix(ye_np, yed, class_indices,
                                      f'CM — {tag} ({best_mn})', f'cm{sfx}', rf)
                # [v3-5] Extended external validation Excel
                with pd.ExcelWriter(
                    os.path.join(rf, f'validation{sfx}.xlsx'),
                    engine='openpyxl'
                ) as w:
                    # Macro metrics row
                    macro_row = {'Model': best_mn,
                                 'N_Features': len(final_feats)}
                    macro_row.update(metrics_to_flat_row(em))
                    pd.DataFrame([macro_row]).to_excel(
                        w, sheet_name='Metrics_Macro', index=False)
                    # Per-class detail
                    per_class_df(em, class_indices).to_excel(
                        w, sheet_name='Metrics_PerClass', index=False)
                    pd.DataFrame({'Feature': final_feats}).to_excel(
                        w, sheet_name='Features', index=False)
        # ── Save Results ──────────────────────────────────────────────────
        progress(0.92, desc="💾 保存结果...")
        log(f"\n 💾 保存结果...")
        with pd.ExcelWriter(
            os.path.join(rf, 'model_evaluation.xlsx'),
            engine='openpyxl'
        ) as w:
            # 1. Per-fold CV results for every model [v3-2 extended columns]
            for mn, r in amr.items():
                r['fold_df'].to_excel(w, sheet_name=mn, index=False)
            # 2. Summary — Internal Validation (CV-OOF) [v3-3 all metrics]
            sd = []
            for mn, r in amr.items():
                row = {
                    'Model': mn,
                    'Retained': 'Yes' if mn in retained else 'No',
                    'Best': 'Best' if mn == best_mn else '',
                }
                row.update({
                    'AUC': r['mean_auc'],
                    'Accuracy': r['mean_acc'],
                    'Sensitivity': r['mean_sens'],
                    'Specificity': r['mean_spec'],
                    'PPV': r['mean_ppv'],
                    'NPV': r['mean_npv'],
                    'F1': r['mean_f1'],
                    'Weighted_F1': r['mean_wf1'],
                    'Kappa': r['mean_kappa'],
                })
                sd.append(row)
            (pd.DataFrame(sd)
             .sort_values('AUC', ascending=False)
             .to_excel(w, sheet_name='Summary_InternalVal', index=False))
            # 3. Train vs Internal Validation [v3-3 all metrics]
            comparison_rows = []
            for mn in amr:
                tr_m = train_results[mn]['metrics']
                vm = amr[mn]
                row = {
                    'Model': mn,
                    'Train_AUC': tr_m['Macro_AUC'],
                    'Train_Accuracy': tr_m['Accuracy'],
                    'Train_Sensitivity': tr_m['Macro_Sensitivity'],
                    'Train_Specificity': tr_m['Macro_Specificity'],
                    'Train_PPV': tr_m['Macro_PPV'],
                    'Train_NPV': tr_m['Macro_NPV'],
                    'Train_F1': tr_m['Macro_F1'],
                    'Train_Kappa': tr_m['Kappa'],
                    'Val_AUC': vm['mean_auc'],
                    'Val_Accuracy': vm['mean_acc'],
                    'Val_Sensitivity': vm['mean_sens'],
                    'Val_Specificity': vm['mean_spec'],
                    'Val_PPV': vm['mean_ppv'],
                    'Val_NPV': vm['mean_npv'],
                    'Val_F1': vm['mean_f1'],
                    'Val_Kappa': vm['mean_kappa'],
                    'AUC_Gap': tr_m['Macro_AUC'] - vm['mean_auc'],
                    'Retained': 'Yes' if mn in retained else 'No',
                    'Best': 'Best' if mn == best_mn else '',
                }
                comparison_rows.append(row)
            (pd.DataFrame(comparison_rows)
             .sort_values('Val_AUC', ascending=False)
             .to_excel(w, sheet_name='Train_vs_InternalVal', index=False))
            # 4. Bootstrap test
            if len(bootstrap_df) > 0:
                bootstrap_df.to_excel(w, sheet_name='Bootstrap_Test',
                                      index=False)
            # 5. [v3-4] Per-class detail for EVERY model (train + val)
            for mn in mnames:
                # Val (OOF)
                oof_pc = per_class_df(amr[mn]['oof_metrics'], class_indices)
                sheet_v = f'{mn}_Val_PerClass'
                if len(sheet_v) > 31: sheet_v = sheet_v[:31]
                oof_pc.to_excel(w, sheet_name=sheet_v, index=False)
                # Train
                tr_pc = per_class_df(train_results[mn]['metrics'], class_indices)
                sheet_t = f'{mn}_Train_PerClass'
                if len(sheet_t) > 31: sheet_t = sheet_t[:31]
                tr_pc.to_excel(w, sheet_name=sheet_t, index=False)
        # Ablation Excel
        if ablation_data:
            with pd.ExcelWriter(
                os.path.join(rf, 'feature_ablation.xlsx'),
                engine='openpyxl'
            ) as w:
                pd.DataFrame({
                    'N': ablation_data['fcs'],
                    'AUC': ablation_data['aucs']
                }).to_excel(w, sheet_name='Ablation', index=False)
                for mn, idf in shap_imp.items():
                    idf.to_excel(w, sheet_name=f'{mn}_Imp', index=False)
        # ── best_params.txt [v3-6] all metrics ────────────────────────────
        with open(os.path.join(rf, 'best_params.txt'), 'w',
                  encoding='utf-8') as f:
            f.write(f"Task: {task_type} Classification ({n_classes} classes)\n")
            f.write(f"Classes: {classes}\n")
            f.write(f"Label Mapping: {label_map}\n\n")
            f.write(f"Statistical Test: Bootstrap AUC Test "
                    f"(n=2000, alpha=0.05)\n")
            f.write(f"Retained Models: {', '.join(retained)} "
                    f"({len(retained)}/{nm})\n\n")
            f.write("=" * 65 + "\n")
            f.write("Model Performance Summary\n")
            f.write(f"{'Metric':<14} "
                    f"{'AUC':>7} {'Acc':>7} {'Sens':>7} {'Spec':>7} "
                    f"{'PPV':>7} {'NPV':>7} {'F1':>7} {'Kappa':>7}\n")
            f.write("-" * 65 + "\n")
            def fmt_row(label, m_auc, m_acc, m_sens, m_spec,
                        m_ppv, m_npv, m_f1, m_kappa):
                return (f"{label:<14} "
                        f"{m_auc:>7.4f} {m_acc:>7.4f} {m_sens:>7.4f} "
                        f"{m_spec:>7.4f} {m_ppv:>7.4f} {m_npv:>7.4f} "
                        f"{m_f1:>7.4f} {m_kappa:>7.4f}\n")
            for mn in mcfg:
                status = ("* Best" if mn == best_mn
                          else ("Retained" if mn in retained else "Excluded"))
                tr_m = train_results[mn]['metrics']
                vm = amr[mn]
                f.write(f"\nModel: {mn} | {status}\n")
                f.write(fmt_row(
                    " Train",
                    tr_m['Macro_AUC'], tr_m['Accuracy'],
                    tr_m['Macro_Sensitivity'], tr_m['Macro_Specificity'],
                    tr_m['Macro_PPV'], tr_m['Macro_NPV'],
                    tr_m['Macro_F1'], tr_m['Kappa']))
                f.write(fmt_row(
                    " CV-OOF",
                    vm['mean_auc'], vm['mean_acc'],
                    vm['mean_sens'], vm['mean_spec'],
                    vm['mean_ppv'], vm['mean_npv'],
                    vm['mean_f1'], vm['mean_kappa']))
                f.write(f" AUC Gap: "
                        f"{tr_m['Macro_AUC'] - vm['mean_auc']:+.4f}\n")
                bp = bpd[mn]
                if isinstance(bp, dict):
                    for k, v in bp.items():
                        f.write(f" {k}: {v}\n")
                else:
                    f.write(f" Params: {bp}\n")
            if len(bootstrap_df) > 0:
                f.write("\n" + "=" * 65 + "\n")
                f.write("Bootstrap AUC Comparison Results\n")
                f.write("=" * 65 + "\n")
                for _, row in bootstrap_df.iterrows():
                    f.write(f" {row['Model_A']} vs {row['Model_B']}: "
                            f"dAUC={row['AUC_Diff']:+.4f} "
                            f"95%CI=[{row['CI_95_Low']:+.4f},"
                            f"{row['CI_95_High']:+.4f}] "
                            f"P={row['P_value']:.4f} -> {row['Decision']}\n")
            if ablation_data:
                f.write(f"\nOptimal Features ({ablation_data['opt_n']}): "
                        f"{', '.join(ablation_data['opt_feats'])}\n")
        # Save best model pickle (wrapped in `with` so the file handle is closed)
        with open(os.path.join(rf, f'model_{best_mn}.pkl'), 'wb') as pkf:
            pickle.dump({
                'model_name': best_mn,
                'model': tms[best_mn],
                'best_params': bpd[best_mn],
                'classes': classes,
                'n_classes': n_classes,
                'label_map': label_map,
                'features': final_feats,
                'task_type': task_type,
            }, pkf)
        # ── ZIP ───────────────────────────────────────────────────────────
        progress(0.97, desc="📦 打包ZIP...")
        zp = os.path.join(tempfile.gettempdir(),
                          f"ml_results_{int(time.time())}_{os.getpid()}.zip")
        with zipfile.ZipFile(zp, 'w', zipfile.ZIP_DEFLATED) as zf:
            for root, _, files in os.walk(rf):
                for fn in files:
                    zf.write(os.path.join(root, fn),
                             os.path.relpath(os.path.join(root, fn), rf))
        nf = sum(len(f) for _, _, f in os.walk(rf))
        shutil.rmtree(rf, ignore_errors=True); gc.collect()
        tm_b = train_results[best_mn]['metrics']
        log(f"\n{'━'*60}")
        log(f" 🎉 分析完成!共 {nf} 个文件已打包")
        log(f" 📋 Task: {task_type} | Best Model: {best_mn}")
        log(f" 📊 Train — AUC={tm_b['Macro_AUC']:.4f} "
            f"Acc={tm_b['Accuracy']:.4f} "
            f"Sens={tm_b['Macro_Sensitivity']:.4f} "
            f"Spec={tm_b['Macro_Specificity']:.4f} "
            f"PPV={tm_b['Macro_PPV']:.4f} NPV={tm_b['Macro_NPV']:.4f} "
            f"F1={tm_b['Macro_F1']:.4f}")
        log(f" 📊 CV-OOF — AUC={best_auc:.4f} "
            f"Acc={amr[best_mn]['mean_acc']:.4f} "
            f"Sens={amr[best_mn]['mean_sens']:.4f} "
            f"Spec={amr[best_mn]['mean_spec']:.4f} "
            f"PPV={amr[best_mn]['mean_ppv']:.4f} "
            f"NPV={amr[best_mn]['mean_npv']:.4f} "
            f"F1={amr[best_mn]['mean_f1']:.4f}")
        log(f"{'━'*60}")
        progress(1.0, desc="✅ 完成!")
        return zp, "\n".join(L)
    except Exception as e:
        log(f"\n❌ 错误: {e}")
        log(traceback.format_exc())
        if os.path.exists(rf): shutil.rmtree(rf, ignore_errors=True)
        gc.collect()
        return None, "\n".join(L)

# ============================================================================
# Gradio UI
# ============================================================================
CUSTOM_CSS = """
.header-banner {
  background: linear-gradient(135deg, #0a2463 0%, #1e3a7a 40%, #2554a8 100%);
  border-radius: 16px; padding: 28px 36px; margin-bottom: 20px;
  box-shadow: 0 8px 32px rgba(0,0,0,0.18); position: relative; overflow: hidden;
}
.header-banner::before {
  content: ''; position: absolute; top: -50%; right: -20%;
  width: 400px; height: 400px;
  background: radial-gradient(circle, rgba(96,165,250,0.2) 0%, transparent 70%);
  border-radius: 50%;
}
.header-banner img { max-height: 52px; border-radius: 6px; margin-bottom: 12px; }
.header-banner h1 { color: #e2e8f0 !important; font-size: 1.7em !important;
  margin: 4px 0 6px 0 !important; font-weight: 700 !important; }
.header-banner p { color: #94a3b8 !important; font-size: 0.92em !important;
  margin: 2px 0 !important; line-height: 1.6; }
.header-banner .credit { color: #64748b !important; font-size: 0.82em !important;
  margin-top: 10px !important;
  border-top: 1px solid rgba(148,163,184,0.15); padding-top: 10px; }
.section-title {
  background: linear-gradient(90deg, #2563eb 0%, #3b82f6 100%);
  color: white !important; padding: 8px 16px; border-radius: 8px;
  font-size: 0.95em !important; font-weight: 600 !important;
  margin: 12px 0 8px 0; }
.pipeline-box {
  background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
  border: 1px solid #bae6fd; border-radius: 12px;
  padding: 14px 18px; margin: 8px 0; font-size: 0.88em; }
.pipeline-box code { background: #2563eb; color: white; padding: 2px 8px;
  border-radius: 4px; font-size: 0.85em; margin: 0 2px; }
.log-area textarea {
  font-family: 'Menlo','Consolas',monospace !important;
  font-size: 12.5px !important; line-height: 1.5 !important;
  background: #0f172a !important; color: #e2e8f0 !important;
  border-radius: 10px !important; padding: 16px !important; }
.gradio-container { max-width: 1280px !important; }
footer { display: none !important; }
"""

with gr.Blocks(
    title="ML 多分类模型平台 — 复旦大学附属眼耳鼻喉科医院",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate",
                         neutral_hue="slate"),
    css=CUSTOM_CSS,
) as demo:
    gr.HTML("""
    <div class="header-banner">
      <img src="https://huggingface.co/spaces/fudan-renjun/machine-learning-2/resolve/main/hospital_logo.png"
           alt="Logo" onerror="this.style.display='none'"/>
      <h1>🧬 ML 多分类模型训练与评估平台</h1>
      <p>支持 2~8 分类 · 上传 CSV 即可完成全流程分析</p>
      <p>评估指标:AUC · Accuracy · Sensitivity · Specificity · PPV · NPV · F1 · Kappa</p>
      <p class="credit">复旦大学附属眼耳鼻喉科医院 · 检验科 · 任俊</p>
    </div>
    """)
    gr.HTML("""
    <div class="pipeline-box">
      <strong>📋 流程:</strong>
      <code>训练+训练集评估</code> → <code>交叉验证(OOF)</code> →
      <code>Train vs Val对比</code> → <code>SHAP</code> →
      <code>特征消融</code> → <code>外部验证</code>
      |
      <strong>指标:</strong>
      AUC · Accuracy · Sensitivity · Specificity · PPV · NPV · F1 · Kappa(宏平均+逐类)
      |
      <strong>CSV:</strong> 第1列=标签(整数), 第2列=ID, 第3列起=特征
    </div>
    """)
    with gr.Row(equal_height=False):
        with gr.Column(scale=5):
            gr.HTML('<div class="section-title">📂 数据上传</div>')
            train_file = gr.File(label="训练集 CSV(必需)", file_types=[".csv"])
            gr.HTML('<p style="color:#64748b;font-size:0.85em;margin:4px 0 8px 0;">'
                    '验证集可选,支持同时上传 1~3 个</p>')
            with gr.Row():
                val_file1 = gr.File(label="验证集 1(可选)",
                                    file_types=[".csv"], scale=1)
                val_file2 = gr.File(label="验证集 2(可选)",
                                    file_types=[".csv"], scale=1)
                val_file3 = gr.File(label="验证集 3(可选)",
                                    file_types=[".csv"], scale=1)
            gr.HTML('<div class="section-title">🏷️ 分类设置</div>')
            n_classes_select = gr.Dropdown(
                choices=["2 类(二分类)","3 类","4 类","5 类",
                         "6 类","7 类","8 类"],
                value="2 类(二分类)", label="选择分类数",
                info="请根据数据标签列的类别数选择,系统将自动验证是否匹配",
            )
            gr.HTML('<div class="section-title">🤖 模型选择</div>')
            model_selector = gr.Dropdown(
                choices=ALL_MODEL_NAMES, value=ALL_MODEL_NAMES,
                multiselect=True, label="选择模型(均支持多分类)",
                info=("RF=随机森林 DT=决策树 KNN=K近邻 XGB=XGBoost "
                      "AdaBoost LR=逻辑回归 NB=朴素贝叶斯 SVM=支持向量机"),
            )
            with gr.Row():
                btn_all = gr.Button("🔘 全选", size="sm", variant="secondary")
                btn_tree = gr.Button("🌲 树模型", size="sm", variant="secondary")
                btn_linear = gr.Button("📐 线性模型", size="sm", variant="secondary")
                btn_top4 = gr.Button("⚡ 经典四模型", size="sm", variant="secondary")
            btn_all.click(lambda: ALL_MODEL_NAMES, outputs=model_selector)
            btn_tree.click(lambda: ['RF','DT','XGB','AdaBoost'], outputs=model_selector)
            btn_linear.click(lambda: ['LR','SVM','NB'], outputs=model_selector)
            btn_top4.click(lambda: ['RF','XGB','LR','SVM'], outputs=model_selector)
            gr.HTML('<div class="section-title">⚙️ 参数配置</div>')
            enable_tuning = gr.Checkbox(
                value=False,
                label="启用超参数调优 (GridSearchCV) ⚠️ 开启后运行时间显著增加")
            with gr.Row():
                cv_folds = gr.Slider(3, 10, value=5, step=1,
                                     label="交叉验证折数")
                top_n = gr.Slider(5, 50, value=20, step=1,
                                  label="SHAP 前 N 个特征")
                shap_sz = gr.Slider(30, 200, value=80, step=10,
                                    label="SHAP 采样数量")
            run_btn = gr.Button("🚀 开始分析", variant="primary", size="lg")
        with gr.Column(scale=5):
            gr.HTML('<div class="section-title">📋 运行日志</div>')
            log_output = gr.Textbox(
                label="", lines=24, max_lines=50, interactive=False,
                placeholder=("点击「开始分析」后,日志将在此显示...\n"
                             "支持 2~8 分类。\n"
                             "评估指标:AUC / Accuracy / Sensitivity / "
                             "Specificity / PPV / NPV / F1 / Kappa"),
                elem_classes="log-area",
            )
            gr.HTML('<div class="section-title">⬇️ 结果下载</div>')
            zip_output = gr.File(label="分析结果 ZIP 压缩包")
    run_btn.click(
        fn=run_pipeline,
        inputs=[train_file, val_file1, val_file2, val_file3,
                n_classes_select, model_selector, enable_tuning,
                cv_folds, top_n, shap_sz],
        outputs=[zip_output, log_output],
        api_name="run",
    )

# ============================================================================
# Authentication
# ============================================================================
from datetime import datetime

ACCOUNTS = {
    "admin": {"password": "admin123", "expires": None},
    "renjun": {"password": "fudan2025", "expires": "2027-12-31"},
    "guest": {"password": "guest888", "expires": "2027-06-30"},
}

def auth_fn(username, password):
    user = ACCOUNTS.get(username)
    if not user or user["password"] != password: return False
    if user["expires"]:
        try:
            if datetime.now() > datetime.strptime(user["expires"], "%Y-%m-%d"):
                return False
        except: return False
    return True

demo.queue()
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    auth=auth_fn,
    auth_message=("🔐 复旦大学附属眼耳鼻喉科医院 · ML多分类分析平台\n"
                  "请输入账号和密码登录"),
    ssr_mode=False,
)