Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,903 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ML Multi-Class Classification Pipeline (2-8 classes)
|
| 3 |
+
Eye & ENT Hospital of Fudan University — Laboratory Medicine, Ren Jun
|
| 4 |
+
Gradio 5.12.0 + Python 3.11
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import matplotlib
|
| 10 |
+
matplotlib.use('Agg')
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
|
| 13 |
+
from sklearn.tree import DecisionTreeClassifier
|
| 14 |
+
from sklearn.neighbors import KNeighborsClassifier
|
| 15 |
+
from sklearn.linear_model import LogisticRegression
|
| 16 |
+
from sklearn.naive_bayes import GaussianNB
|
| 17 |
+
from sklearn.svm import SVC
|
| 18 |
+
from xgboost import XGBClassifier
|
| 19 |
+
from sklearn.model_selection import StratifiedKFold, GridSearchCV
|
| 20 |
+
from sklearn.metrics import (
|
| 21 |
+
roc_auc_score, confusion_matrix, roc_curve,
|
| 22 |
+
auc as auc_score, precision_recall_curve,
|
| 23 |
+
classification_report, accuracy_score, f1_score,
|
| 24 |
+
cohen_kappa_score
|
| 25 |
+
)
|
| 26 |
+
from sklearn.preprocessing import label_binarize
|
| 27 |
+
import seaborn as sns
|
| 28 |
+
import warnings
|
| 29 |
+
from scipy import stats
|
| 30 |
+
import os
|
| 31 |
+
import shap
|
| 32 |
+
import pickle
|
| 33 |
+
from copy import deepcopy
|
| 34 |
+
import zipfile
|
| 35 |
+
import tempfile
|
| 36 |
+
import traceback
|
| 37 |
+
import time
|
| 38 |
+
import shutil
|
| 39 |
+
import gc
|
| 40 |
+
import threading
|
| 41 |
+
import gradio as gr
|
| 42 |
+
from itertools import cycle
|
| 43 |
+
|
| 44 |
+
# Silence every warning category for the whole process.
warnings.filterwarnings('ignore')

# Publication-quality plot settings for SCI papers
plt.rcParams.update({
    # Fonts: serif body text, with explicit fallbacks.
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'DejaVu Serif', 'serif'],
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
    'axes.unicode_minus': False,
    # Resolution: lighter on-screen figures, 300 dpi when saved.
    'figure.dpi': 150,
    'savefig.dpi': 300,
    # Line weights and tick label sizes.
    'axes.linewidth': 1.2,
    'xtick.major.width': 1.0,
    'ytick.major.width': 1.0,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
})
|
| 57 |
+
|
| 58 |
+
# ============================================================================
|
| 59 |
+
# Cache Cleanup
|
| 60 |
+
# ============================================================================
|
| 61 |
+
# Age threshold, in minutes, beyond which an "ml_" temp entry is deleted by cleanup_old_temp_files().
CLEANUP_MAX_AGE_MINUTES = 30
|
| 62 |
+
# Pause, in seconds, between successive sweeps performed by periodic_cleanup().
CLEANUP_INTERVAL_SECONDS = 600
|
| 63 |
+
# NOTE(review): not referenced by any code visible in this part of the file;
# presumably a disk-usage cap in MB — confirm it is used elsewhere or remove it.
CLEANUP_MAX_DISK_MB = 1024
|
| 64 |
+
|
| 65 |
+
def cleanup_old_temp_files():
    """Delete stale ``ml_*`` entries from the system temp directory.

    An entry is stale when its modification time lies more than
    ``CLEANUP_MAX_AGE_MINUTES`` minutes in the past.  Stale directories are
    removed recursively and stale regular files are unlinked.

    The sweep is best-effort: a filesystem error on one entry (for example it
    disappeared between listing and deletion) skips that entry only, so the
    remaining entries are still processed.  A garbage-collection pass always
    runs at the end.
    """
    now = time.time()
    tmp = tempfile.gettempdir()
    max_age_seconds = CLEANUP_MAX_AGE_MINUTES * 60

    try:
        entries = os.listdir(tmp)
    except OSError:
        # Temp directory unreadable: nothing to sweep this round.
        entries = []

    for item in entries:
        if not item.startswith("ml_"):
            continue
        p = os.path.join(tmp, item)
        try:
            if now - os.path.getmtime(p) <= max_age_seconds:
                continue
            if os.path.isdir(p):
                shutil.rmtree(p, ignore_errors=True)
            elif os.path.isfile(p):
                os.remove(p)
        except OSError:
            # Only filesystem errors are tolerated; keep going with the rest.
            continue

    gc.collect()
|
| 77 |
+
|
| 78 |
+
def periodic_cleanup():
    """Run the temp-file sweep forever, pausing between passes.

    Never returns; intended to be the target of a background thread.
    """
    while True:
        # Sleep first, so the first sweep happens one full interval after start.
        time.sleep(CLEANUP_INTERVAL_SECONDS)
        cleanup_old_temp_files()
|
| 82 |
+
|
| 83 |
+
# Launch the sweeper as a daemon thread at import time.
_ct = threading.Thread(target=periodic_cleanup, daemon=True)
_ct.start()
|
| 84 |
+
|
| 85 |
+
# ============================================================================
|
| 86 |
+
# Multi-class Helper Functions
|
| 87 |
+
# ============================================================================
|
| 88 |
+
|
| 89 |
+
def multiclass_roc_auc(y_true, y_proba, classes):
    """Return one ROC AUC score for binary or multi-class predictions.

    Args:
        y_true: true labels.
        y_proba: 2-D array of predicted probabilities, one column per class.
        classes: sequence of class labels; only its length is used.

    Returns:
        For two classes, the ROC AUC computed from column 1 of ``y_proba``.
        Otherwise the one-vs-rest macro-averaged ROC AUC.  If the score
        cannot be computed, ``0.0`` is returned instead of raising.
    """
    n_classes = len(classes)
    try:
        # Both branches are guarded so binary and multi-class inputs fail the
        # same way (0.0), matching compute_multiclass_metrics.
        if n_classes == 2:
            return roc_auc_score(y_true, y_proba[:, 1])
        return roc_auc_score(y_true, y_proba, multi_class='ovr', average='macro')
    except Exception:
        # Not a bare except: KeyboardInterrupt / SystemExit must propagate.
        return 0.0
|
| 98 |
+
|
| 99 |
+
def plot_multiclass_roc(y_true, y_proba, classes, title, filepath_prefix, rf):
    """Plot one-vs-rest ROC curves per class plus their macro average.

    Args:
        y_true: true labels.
        y_proba: 2-D array of predicted probabilities, one column per class,
            in the same order as ``classes``.
        classes: list of class labels.
        title: figure title.
        filepath_prefix: output file name without extension.
        rf: directory the figure files are written into.

    Returns:
        ``(macro_auc, auc_dict)`` where ``auc_dict`` maps the class position
        (0..n_classes-1) to that class's one-vs-rest AUC.

    Side effects:
        Writes ``<rf>/<filepath_prefix>.pdf`` and ``.png``.
    """
    n_classes = len(classes)
    y_bin = label_binarize(y_true, classes=classes)
    if n_classes == 2:
        # Two classes give a single indicator column; expand it to one
        # column per class so the per-class loop below works unchanged.
        y_bin = np.hstack([1 - y_bin, y_bin])

    fpr_dict, tpr_dict, auc_dict = {}, {}, {}
    for i in range(n_classes):
        fpr_dict[i], tpr_dict[i], _ = roc_curve(y_bin[:, i], y_proba[:, i])
        auc_dict[i] = auc_score(fpr_dict[i], tpr_dict[i])

    # Macro average: interpolate every class curve onto the union of all FPR
    # points, then average the TPR values.
    all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(n_classes)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr_dict[i], tpr_dict[i])
    mean_tpr /= n_classes
    macro_auc = auc_score(all_fpr, mean_tpr)

    COLORS = ['#e41a1c','#377eb8','#4daf4a','#984ea3','#ff7f00','#a65628','#f781bf','#999999']
    # Explicit Figure/Axes instead of pyplot's implicit "current figure", and
    # try/finally so the figure is released even if plotting or saving fails.
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        for i in range(n_classes):
            ax.plot(fpr_dict[i], tpr_dict[i], color=COLORS[i % len(COLORS)], lw=2,
                    label=f'Class {classes[i]} (AUC={auc_dict[i]:.3f})')
        ax.plot(all_fpr, mean_tpr, 'k--', lw=2.5, label=f'Macro Avg (AUC={macro_auc:.3f})')
        ax.plot([0, 1], [0, 1], '--', color='#cccccc', lw=1)  # chance line
        ax.set_xlim([-0.02, 1.02])
        ax.set_ylim([-0.02, 1.02])
        ax.set_xlabel('False Positive Rate', fontsize=13)
        ax.set_ylabel('True Positive Rate', fontsize=13)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(loc='lower right', fontsize=9)
        ax.grid(True, alpha=0.15)
        fig.tight_layout()
        fig.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'), format='pdf', bbox_inches='tight', dpi=300)
        fig.savefig(os.path.join(rf, f'{filepath_prefix}.png'), format='png', bbox_inches='tight', dpi=150)
    finally:
        plt.close(fig)
    return macro_auc, auc_dict
|
| 134 |
+
|
| 135 |
+
def plot_multiclass_pr(y_true, y_proba, classes, title, filepath_prefix, rf):
    """Plot a one-vs-rest precision-recall curve for every class.

    Args:
        y_true: true labels.
        y_proba: 2-D array of predicted probabilities, one column per class,
            in the same order as ``classes``.
        classes: list of class labels.
        title: figure title.
        filepath_prefix: output file name without extension.
        rf: directory the figure files are written into.

    Side effects:
        Writes ``<rf>/<filepath_prefix>.pdf`` and ``.png``.
    """
    n_classes = len(classes)
    y_bin = label_binarize(y_true, classes=classes)
    if n_classes == 2:
        # Two classes give a single indicator column; expand to two columns.
        y_bin = np.hstack([1 - y_bin, y_bin])

    COLORS = ['#e41a1c','#377eb8','#4daf4a','#984ea3','#ff7f00','#a65628','#f781bf','#999999']
    # Explicit Figure/Axes plus try/finally: the figure is always closed,
    # even when plotting or saving raises.
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        for i in range(n_classes):
            prec, rec, _ = precision_recall_curve(y_bin[:, i], y_proba[:, i])
            # NOTE(review): the value shown as "AP" is the trapezoidal area
            # under the PR curve, not sklearn's average_precision_score.
            ap = auc_score(rec, prec)
            ax.plot(rec, prec, color=COLORS[i % len(COLORS)], lw=2,
                    label=f'Class {classes[i]} (AP={ap:.3f})')
        ax.set_xlim([-0.02, 1.02])
        ax.set_ylim([-0.02, 1.02])
        ax.set_xlabel('Recall', fontsize=13)
        ax.set_ylabel('Precision', fontsize=13)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(loc='lower left', fontsize=9)
        ax.grid(True, alpha=0.15)
        fig.tight_layout()
        fig.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'), format='pdf', bbox_inches='tight', dpi=300)
        fig.savefig(os.path.join(rf, f'{filepath_prefix}.png'), format='png', bbox_inches='tight', dpi=150)
    finally:
        plt.close(fig)
|
| 156 |
+
|
| 157 |
+
def plot_confusion_matrix(y_true, y_pred, classes, title, filepath_prefix, rf):
    """Draw the confusion matrix as an annotated heatmap and save it.

    Args:
        y_true: true labels.
        y_pred: predicted labels.
        classes: list of class labels; fixes row/column order and tick labels.
        title: figure title.
        filepath_prefix: output file name without extension.
        rf: directory the figure files are written into.

    Returns:
        The confusion matrix computed with ``labels=classes``.

    Side effects:
        Writes ``<rf>/<filepath_prefix>.pdf`` and ``.png``.
    """
    cm = confusion_matrix(y_true, y_pred, labels=classes)
    # Figure grows with the number of classes, with a minimum of 6 x 5 inches.
    width = max(6, len(classes) * 1.2)
    height = max(5, len(classes) * 1.0)
    # Explicit Figure/Axes plus try/finally: the figure is always closed,
    # even when plotting or saving raises.
    fig, ax = plt.subplots(figsize=(width, height))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True,
                    xticklabels=classes, yticklabels=classes,
                    annot_kws={'fontsize': 11}, ax=ax)
        ax.set_xlabel('Predicted', fontsize=12)
        ax.set_ylabel('True', fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold')
        fig.tight_layout()
        fig.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'), format='pdf', bbox_inches='tight', dpi=300)
        fig.savefig(os.path.join(rf, f'{filepath_prefix}.png'), format='png', bbox_inches='tight', dpi=150)
    finally:
        plt.close(fig)
    return cm
|
| 169 |
+
|
| 170 |
+
def compute_multiclass_metrics(y_true, y_pred, y_proba, classes):
    """Compute summary classification metrics for one set of predictions.

    Args:
        y_true: true labels.
        y_pred: predicted labels.
        y_proba: 2-D array of predicted probabilities, one column per class.
        classes: list of class labels; used as ``labels`` for the
            classification report and to pick the binary/multi-class AUC.

    Returns:
        dict with keys ``'Accuracy'``, ``'Macro_AUC'``, ``'Macro_F1'``,
        ``'Weighted_F1'``, ``'Kappa'`` and ``'report'`` (the
        ``classification_report`` output as a dict).  ``'Macro_AUC'`` is
        ``0.0`` when the AUC cannot be computed.
    """
    n_classes = len(classes)
    acc = accuracy_score(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)

    # Per-class precision/recall/F1 from classification_report
    report = classification_report(y_true, y_pred, labels=classes, output_dict=True, zero_division=0)

    # AUC: column 1 for binary, one-vs-rest macro average otherwise.
    try:
        if n_classes == 2:
            macro_auc = roc_auc_score(y_true, y_proba[:, 1])
        else:
            macro_auc = roc_auc_score(y_true, y_proba, multi_class='ovr', average='macro')
    except Exception:
        # Not a bare except: KeyboardInterrupt / SystemExit must propagate.
        macro_auc = 0.0

    f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
    f1_weighted = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    return {
        'Accuracy': acc, 'Macro_AUC': macro_auc, 'Macro_F1': f1_macro,
        'Weighted_F1': f1_weighted, 'Kappa': kappa, 'report': report
    }
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def bootstrap_auc_test(y_true, proba_a, proba_b, classes, n_bootstrap=2000, seed=42):
    """Paired bootstrap test for a difference in AUC between two models.

    A resampling alternative to the DeLong test that also works for
    multi-class problems.  Both models are scored on the same resampled rows;
    resamples that do not contain every class are discarded.

    Args:
        y_true: true labels, one per sample.
        proba_a: 2-D predicted probabilities of model A (one column per class).
        proba_b: 2-D predicted probabilities of model B, same rows as A.
        classes: sequence of class labels; only its length is used.
        n_bootstrap: number of resamples to draw.
        seed: seed for the resampling random generator.

    Returns:
        ``(p_value, auc_a, auc_b, ci_low, ci_high)``.  ``auc_a``/``auc_b`` are
        the AUCs on the full data (column-1 AUC for two classes, one-vs-rest
        macro AUC otherwise).  ``ci_low``/``ci_high`` are the 2.5th and 97.5th
        percentiles of the bootstrap differences ``auc_a - auc_b``.  When
        fewer than 100 resamples are usable the result is
        ``(1.0, auc_a, auc_b, -1, 1)``.
    """
    # Rows are selected by position below; coerce to ndarrays so a pandas
    # Series is not indexed by label and plain lists are accepted.
    y_true = np.asarray(y_true)
    proba_a = np.asarray(proba_a)
    proba_b = np.asarray(proba_b)

    rng = np.random.RandomState(seed)
    n = len(y_true)
    n_classes = len(classes)

    def calc_macro_auc(yt, pa, pb):
        """Score both models on the same labels; (0.0, 0.0) if scoring fails."""
        try:
            if n_classes == 2:
                a1 = roc_auc_score(yt, pa[:, 1])
                a2 = roc_auc_score(yt, pb[:, 1])
            else:
                a1 = roc_auc_score(yt, pa, multi_class='ovr', average='macro')
                a2 = roc_auc_score(yt, pb, multi_class='ovr', average='macro')
            return a1, a2
        except Exception:
            # Not a bare except: KeyboardInterrupt / SystemExit must propagate.
            return 0.0, 0.0

    # Observed AUC on the full sample
    auc_a, auc_b = calc_macro_auc(y_true, proba_a, proba_b)
    observed_diff = auc_a - auc_b

    # Bootstrap
    diffs = []
    for _ in range(n_bootstrap):
        idx = rng.choice(n, n, replace=True)
        yt_b = y_true[idx]
        # Ensure all classes are present in the bootstrap sample
        if len(np.unique(yt_b)) < n_classes:
            continue
        a1, a2 = calc_macro_auc(yt_b, proba_a[idx], proba_b[idx])
        diffs.append(a1 - a2)

    if len(diffs) < 100:
        return 1.0, auc_a, auc_b, -1, 1  # Not enough valid bootstraps

    diffs = np.array(diffs)
    # Two-sided p-value: centre the bootstrap differences to emulate H0
    # (diff = 0), then count how often they are at least as extreme as the
    # observed difference.
    centered = diffs - np.mean(diffs)
    p_value = np.mean(np.abs(centered) >= np.abs(observed_diff))
    # Floor at the resolution of the resamples actually used; discarded
    # resamples must not make the p-value look finer than it is.
    p_value = max(p_value, 1.0 / len(diffs))

    ci_low = np.percentile(diffs, 2.5)
    ci_high = np.percentile(diffs, 97.5)

    return p_value, auc_a, auc_b, ci_low, ci_high
|
| 248 |
+
|
| 249 |
+
# ============================================================================
|
| 250 |
+
# Model configs (multi-class compatible)
|
| 251 |
+
# ============================================================================
|
| 252 |
+
# Short names of the supported classifiers; the same strings are the keys of
# the config dict built in get_models_config().
ALL_MODEL_NAMES = ['RF', 'DT', 'KNN', 'XGB', 'AdaBoost', 'LR', 'NB', 'SVM']
|
| 253 |
+
|
| 254 |
+
def get_models_config(selected, n_classes, rs=42):
    """Build the estimator and hyper-parameter grid for each selected model.

    Args:
        selected: container of model short names to keep
            ('RF', 'DT', 'KNN', 'XGB', 'AdaBoost', 'LR', 'NB', 'SVM').
        n_classes: number of target classes; only affects the XGBoost setup.
        rs: random seed given to every estimator that accepts one.

    Returns:
        dict mapping model name to ``{'model': unfitted estimator,
        'params': grid dict}``, in the fixed order listed above and restricted
        to the names found in ``selected``.
    """
    # Build the XGBoost estimator for the task directly, rather than creating
    # a multi-class-flavoured one with num_class=None and replacing it later.
    if n_classes > 2:
        xgb_model = XGBClassifier(random_state=rs, eval_metric='mlogloss', n_jobs=-1,
                                  num_class=n_classes, objective='multi:softprob')
    else:
        xgb_model = XGBClassifier(random_state=rs, eval_metric='logloss', n_jobs=-1)

    cfg = {
        'RF': {'model': RandomForestClassifier(random_state=rs, n_jobs=-1),
               'params': {'n_estimators': [100,200], 'max_depth': [20,50], 'min_samples_split': [2,5]}},
        'DT': {'model': DecisionTreeClassifier(random_state=rs),
               'params': {'max_depth': [20,50], 'min_samples_split': [2,10], 'criterion': ['gini','entropy']}},
        'KNN': {'model': KNeighborsClassifier(n_jobs=-1),
                'params': {'n_neighbors': [3,5,7], 'weights': ['uniform','distance']}},
        'XGB': {'model': xgb_model,
                'params': {'n_estimators': [100,200], 'max_depth': [5,7], 'learning_rate': [0.05,0.1]}},
        'AdaBoost': {'model': AdaBoostClassifier(random_state=rs),
                     'params': {'n_estimators': [50,100], 'learning_rate': [0.1,0.5,1.0]}},
        'LR': {'model': LogisticRegression(random_state=rs, n_jobs=-1, max_iter=2000),
               'params': {'C': [0.1,1,10], 'solver': ['lbfgs']}},
        'NB': {'model': GaussianNB(),
               'params': {'var_smoothing': [1e-9,1e-7,1e-5]}},
        # probability=True is required because the pipeline calls predict_proba.
        'SVM': {'model': SVC(probability=True, random_state=rs, decision_function_shape='ovr'),
                'params': {'C': [1,10], 'kernel': ['rbf','linear']}},
    }
    return {k: v for k, v in cfg.items() if k in selected}
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# ============================================================================
|
| 284 |
+
# Main Pipeline
|
| 285 |
+
# ============================================================================
|
| 286 |
+
def run_pipeline(
|
| 287 |
+
train_file, val_file1, val_file2, val_file3, n_classes_select,
|
| 288 |
+
selected_models, enable_tuning,
|
| 289 |
+
cv_folds, top_n_features, shap_sample_size,
|
| 290 |
+
progress=gr.Progress(track_tqdm=True),
|
| 291 |
+
):
|
| 292 |
+
if train_file is None:
|
| 293 |
+
return None, "❌ 请先上传训练集 CSV 文件"
|
| 294 |
+
sel = selected_models if isinstance(selected_models, list) else [s.strip() for s in str(selected_models).split(",") if s.strip()]
|
| 295 |
+
if not sel:
|
| 296 |
+
return None, "❌ 请至少选择一个模型"
|
| 297 |
+
|
| 298 |
+
RS = 42; CVF = int(cv_folds)
|
| 299 |
+
TOPN = int(top_n_features); SHAPSZ = int(shap_sample_size)
|
| 300 |
+
TUNING = bool(enable_tuning)
|
| 301 |
+
|
| 302 |
+
L = []
|
| 303 |
+
def log(m): L.append(str(m))
|
| 304 |
+
|
| 305 |
+
rf = tempfile.mkdtemp(prefix="ml_")
|
| 306 |
+
|
| 307 |
+
try:
|
| 308 |
+
# ── Load Data ──
|
| 309 |
+
progress(0.02, desc="📂 加载数据...")
|
| 310 |
+
log("━" * 50)
|
| 311 |
+
log(" 🧬 ML 多分类模型训练与评估系统")
|
| 312 |
+
log("━" * 50)
|
| 313 |
+
|
| 314 |
+
tp = train_file if isinstance(train_file, str) else getattr(train_file, 'name', str(train_file))
|
| 315 |
+
data = pd.read_csv(tp)
|
| 316 |
+
|
| 317 |
+
# Smart CSV format detection
|
| 318 |
+
y = data.iloc[:, 0]
|
| 319 |
+
col2 = data.iloc[:, 1]
|
| 320 |
+
col2_is_id = (col2.dtype == 'object') or (col2.nunique() / len(col2) > 0.5)
|
| 321 |
+
if col2_is_id:
|
| 322 |
+
X = data.iloc[:, 2:]
|
| 323 |
+
log(f" 📋 CSV: Col1=Label, Col2=ID({data.columns[1]}), Col3+=Features")
|
| 324 |
+
else:
|
| 325 |
+
X = data.iloc[:, 1:]
|
| 326 |
+
log(f" 📋 CSV: Col1=Label, Col2+=Features (no ID column)")
|
| 327 |
+
fnames = X.columns.tolist()
|
| 328 |
+
|
| 329 |
+
# Parse user selection: "3 类" -> 3, "2 类(二分类)" -> 2
|
| 330 |
+
user_n = int(str(n_classes_select).split(" ")[0])
|
| 331 |
+
|
| 332 |
+
# Validate against actual data
|
| 333 |
+
detected_classes = sorted(y.unique())
|
| 334 |
+
detected_classes = [int(c) if hasattr(c, 'item') else c for c in detected_classes]
|
| 335 |
+
detected_n = len(detected_classes)
|
| 336 |
+
|
| 337 |
+
if detected_n != user_n:
|
| 338 |
+
return None, (f"❌ 您选择了 {user_n} 分类,但数据中检测到 {detected_n} 个类别: {detected_classes}\n"
|
| 339 |
+
f"请将分类数修改为 {detected_n},或检查数据标签列")
|
| 340 |
+
|
| 341 |
+
classes = detected_classes
|
| 342 |
+
n_classes = user_n
|
| 343 |
+
log(f" ✅ {n_classes} 分类 — 数据验证通过")
|
| 344 |
+
|
| 345 |
+
# Remap to 0,1,...,n-1
|
| 346 |
+
label_map = {c: i for i, c in enumerate(classes)}
|
| 347 |
+
label_map_inv = {i: c for c, i in label_map.items()}
|
| 348 |
+
y_mapped = y.map(label_map)
|
| 349 |
+
class_indices = list(range(n_classes))
|
| 350 |
+
|
| 351 |
+
log(f" 📊 训练集: {X.shape[0]} 样本 × {X.shape[1]} 特征")
|
| 352 |
+
log(f" 🏷️ 类别数: {n_classes} 类 — {classes}")
|
| 353 |
+
log(f" 📊 分布: {dict(y.value_counts().sort_index())}")
|
| 354 |
+
log(f" 🤖 模型: {', '.join(sel)}")
|
| 355 |
+
log(f" 🔧 调优: {'开启' if TUNING else '关闭'} | CV: {CVF}折")
|
| 356 |
+
|
| 357 |
+
if n_classes < 2 or n_classes > 8:
|
| 358 |
+
return None, f"❌ 仅支持 2~8 分类,当前检测到 {n_classes} 类"
|
| 359 |
+
|
| 360 |
+
task_type = "Binary" if n_classes == 2 else f"{n_classes}-Class"
|
| 361 |
+
task_type_cn = "二分类" if n_classes == 2 else f"{n_classes}分类"
|
| 362 |
+
log(f" 📋 任务: {task_type_cn} ({task_type})")
|
| 363 |
+
|
| 364 |
+
mcfg = get_models_config(sel, n_classes, RS)
|
| 365 |
+
skf = StratifiedKFold(n_splits=CVF, shuffle=True, random_state=RS)
|
| 366 |
+
|
| 367 |
+
# ── Train All Models ──
|
| 368 |
+
bpd = {}; amr = {}; tms = {}
|
| 369 |
+
total = len(mcfg)
|
| 370 |
+
COLORS = ['#2563eb','#f59e0b','#10b981','#ef4444','#8b5cf6','#ec4899','#06b6d4','#6b7280']
|
| 371 |
+
|
| 372 |
+
for mi, (mn, cf) in enumerate(mcfg.items()):
|
| 373 |
+
pv = 0.05 + 0.35 * mi / total
|
| 374 |
+
progress(pv, desc=f"🏋️ [{mi+1}/{total}] 训练 {mn}...")
|
| 375 |
+
log(f"\n{'─'*40}")
|
| 376 |
+
log(f" 🔄 [{mi+1}/{total}] {mn}")
|
| 377 |
+
|
| 378 |
+
Xv = X.values
|
| 379 |
+
if TUNING:
|
| 380 |
+
log(f" ⏳ GridSearchCV (CV={CVF})...")
|
| 381 |
+
scoring = 'roc_auc_ovr' if n_classes > 2 else 'roc_auc'
|
| 382 |
+
gs = GridSearchCV(cf['model'], cf['params'], cv=skf, scoring=scoring, n_jobs=-1, verbose=0)
|
| 383 |
+
gs.fit(Xv, y_mapped)
|
| 384 |
+
bp = gs.best_params_; bpd[mn] = bp
|
| 385 |
+
log(f" ✓ 最佳CV Score: {gs.best_score_:.4f}")
|
| 386 |
+
else:
|
| 387 |
+
bp = {}; bpd[mn] = "Default"
|
| 388 |
+
|
| 389 |
+
mdl = deepcopy(cf['model'])
|
| 390 |
+
if bp: mdl.set_params(**bp)
|
| 391 |
+
mdl.fit(Xv, y_mapped)
|
| 392 |
+
tms[mn] = mdl
|
| 393 |
+
|
| 394 |
+
# CV evaluation
|
| 395 |
+
all_yt = []; all_yp = []; all_yproba = []
|
| 396 |
+
fold_metrics = []
|
| 397 |
+
|
| 398 |
+
for fi, (tri, tei) in enumerate(skf.split(X, y_mapped), 1):
|
| 399 |
+
Xtr, Xte = X.iloc[tri].values, X.iloc[tei].values
|
| 400 |
+
ytr, yte = y_mapped.iloc[tri], y_mapped.iloc[tei]
|
| 401 |
+
mf = deepcopy(cf['model'])
|
| 402 |
+
if bp: mf.set_params(**bp)
|
| 403 |
+
mf.fit(Xtr, ytr)
|
| 404 |
+
ypred = mf.predict(Xte)
|
| 405 |
+
yproba = mf.predict_proba(Xte)
|
| 406 |
+
|
| 407 |
+
all_yt.extend(yte); all_yp.extend(ypred); all_yproba.append(yproba)
|
| 408 |
+
|
| 409 |
+
metrics = compute_multiclass_metrics(yte, ypred, yproba, class_indices)
|
| 410 |
+
fold_metrics.append({
|
| 411 |
+
'Fold': fi, 'Accuracy': metrics['Accuracy'],
|
| 412 |
+
'Macro_AUC': metrics['Macro_AUC'], 'Macro_F1': metrics['Macro_F1'],
|
| 413 |
+
'Weighted_F1': metrics['Weighted_F1'], 'Kappa': metrics['Kappa']
|
| 414 |
+
})
|
| 415 |
+
|
| 416 |
+
all_yt = np.array(all_yt); all_yp = np.array(all_yp)
|
| 417 |
+
all_yproba = np.vstack(all_yproba)
|
| 418 |
+
|
| 419 |
+
fdf = pd.DataFrame(fold_metrics)
|
| 420 |
+
mean_row = {col: fdf[col].mean() if col != 'Fold' else 'Mean' for col in fdf.columns}
|
| 421 |
+
fdf = pd.concat([fdf, pd.DataFrame([mean_row])], ignore_index=True)
|
| 422 |
+
|
| 423 |
+
amr[mn] = {
|
| 424 |
+
'fold_df': fdf, 'mean_auc': mean_row['Macro_AUC'],
|
| 425 |
+
'mean_acc': mean_row['Accuracy'], 'mean_f1': mean_row['Macro_F1'],
|
| 426 |
+
'all_yt': all_yt, 'all_yp': all_yp, 'all_yproba': all_yproba
|
| 427 |
+
}
|
| 428 |
+
log(f" ✅ AUC={mean_row['Macro_AUC']:.4f} Acc={mean_row['Accuracy']:.4f} F1={mean_row['Macro_F1']:.4f} Kappa={mean_row['Kappa']:.4f}")
|
| 429 |
+
|
| 430 |
+
mnames = list(amr.keys()); nm = len(mnames)
|
| 431 |
+
log(f"\n{'━'*50}")
|
| 432 |
+
log(f" ✅ {nm} 个模型训练完成")
|
| 433 |
+
|
| 434 |
+
# ── ROC Curves ──
|
| 435 |
+
progress(0.42, desc="📈 ROC曲线...")
|
| 436 |
+
log(f"\n 📈 绘制图表...")
|
| 437 |
+
for mn in mnames:
|
| 438 |
+
r = amr[mn]
|
| 439 |
+
plot_multiclass_roc(r['all_yt'], r['all_yproba'], class_indices,
|
| 440 |
+
f'ROC — {mn} ({task_type}, Macro AUC={r["mean_auc"]:.3f})', f'roc_{mn}', rf)
|
| 441 |
+
|
| 442 |
+
# Combined ROC (macro per model)
|
| 443 |
+
plt.figure(figsize=(10, 8))
|
| 444 |
+
for i, mn in enumerate(mnames):
|
| 445 |
+
r = amr[mn]
|
| 446 |
+
y_bin = label_binarize(r['all_yt'], classes=class_indices)
|
| 447 |
+
if n_classes == 2: y_bin = np.hstack([1 - y_bin, y_bin])
|
| 448 |
+
all_fpr = np.linspace(0, 1, 200); mean_tpr = np.zeros_like(all_fpr)
|
| 449 |
+
for c in range(n_classes):
|
| 450 |
+
f, t, _ = roc_curve(y_bin[:, c], r['all_yproba'][:, c])
|
| 451 |
+
mean_tpr += np.interp(all_fpr, f, t)
|
| 452 |
+
mean_tpr /= n_classes; mean_tpr[-1] = 1.0
|
| 453 |
+
ma = auc_score(all_fpr, mean_tpr)
|
| 454 |
+
plt.plot(all_fpr, mean_tpr, color=COLORS[i%8], lw=2.5, label=f'{mn} (Macro AUC={ma:.3f})')
|
| 455 |
+
plt.plot([0,1],[0,1],'--',color='#ccc',lw=1)
|
| 456 |
+
plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
|
| 457 |
+
plt.xlabel('FPR',fontsize=13); plt.ylabel('TPR',fontsize=13)
|
| 458 |
+
plt.title(f'ROC — All Models ({task_type})',fontsize=14,fontweight='bold')
|
| 459 |
+
plt.legend(loc='lower right',fontsize=10); plt.grid(True,alpha=0.15); plt.tight_layout()
|
| 460 |
+
plt.savefig(os.path.join(rf,'roc_all.pdf'),format='pdf',bbox_inches='tight',dpi=300)
|
| 461 |
+
plt.savefig(os.path.join(rf,'roc_all.png'),format='png',bbox_inches='tight',dpi=150)
|
| 462 |
+
plt.close()
|
| 463 |
+
|
| 464 |
+
# ── PR Curves ──
|
| 465 |
+
progress(0.48, desc="📈 PR曲线...")
|
| 466 |
+
for mn in mnames:
|
| 467 |
+
r = amr[mn]
|
| 468 |
+
plot_multiclass_pr(r['all_yt'], r['all_yproba'], class_indices,
|
| 469 |
+
f'PR — {mn} ({task_type})', f'pr_{mn}', rf)
|
| 470 |
+
|
| 471 |
+
# ── Confusion Matrices ──
|
| 472 |
+
progress(0.52, desc="📊 混淆矩阵...")
|
| 473 |
+
for mn in mnames:
|
| 474 |
+
r = amr[mn]
|
| 475 |
+
plot_confusion_matrix(r['all_yt'], r['all_yp'], class_indices,
|
| 476 |
+
f'CM — {mn} (Acc={r["mean_acc"]:.3f})', f'cm_{mn}', rf)
|
| 477 |
+
|
| 478 |
+
# ── Bootstrap AUC Test (模型统计检验) ──
|
| 479 |
+
progress(0.55, desc="🔬 Bootstrap AUC 检验...")
|
| 480 |
+
best_mn = max(amr, key=lambda x: amr[x]['mean_auc'])
|
| 481 |
+
best_auc = amr[best_mn]['mean_auc']
|
| 482 |
+
log(f"\n 🏆 最佳模型: {best_mn} (Macro AUC={best_auc:.4f})")
|
| 483 |
+
log(f" 🔬 Bootstrap 检验 (n=2000, α=0.05)...")
|
| 484 |
+
|
| 485 |
+
ALPHA = 0.05
|
| 486 |
+
bootstrap_results = []
|
| 487 |
+
retained = [best_mn]
|
| 488 |
+
|
| 489 |
+
for om in mnames:
|
| 490 |
+
if om == best_mn:
|
| 491 |
+
continue
|
| 492 |
+
p_val, auc_a, auc_b, ci_lo, ci_hi = bootstrap_auc_test(
|
| 493 |
+
amr[best_mn]['all_yt'],
|
| 494 |
+
amr[best_mn]['all_yproba'],
|
| 495 |
+
amr[om]['all_yproba'],
|
| 496 |
+
class_indices, n_bootstrap=2000
|
| 497 |
+
)
|
| 498 |
+
if p_val >= ALPHA:
|
| 499 |
+
retained.append(om)
|
| 500 |
+
dec = "Retained"
|
| 501 |
+
else:
|
| 502 |
+
dec = "Excluded"
|
| 503 |
+
|
| 504 |
+
bootstrap_results.append({
|
| 505 |
+
'Model_A': best_mn, 'AUC_A': auc_a,
|
| 506 |
+
'Model_B': om, 'AUC_B': auc_b,
|
| 507 |
+
'AUC_Diff': auc_a - auc_b,
|
| 508 |
+
'CI_95_Low': ci_lo, 'CI_95_High': ci_hi,
|
| 509 |
+
'P_value': p_val, 'Decision': dec
|
| 510 |
+
})
|
| 511 |
+
log(f" {best_mn} vs {om}: ΔAUC={auc_a-auc_b:+.4f} 95%CI=[{ci_lo:+.4f},{ci_hi:+.4f}] P={p_val:.4f} → {dec}")
|
| 512 |
+
|
| 513 |
+
bootstrap_df = pd.DataFrame(bootstrap_results).sort_values('P_value', ascending=False) if bootstrap_results else pd.DataFrame()
|
| 514 |
+
log(f" ✅ 保留 {len(retained)}/{nm} 个模型: {', '.join(retained)}")
|
| 515 |
+
|
| 516 |
+
# ── SHAP ──
|
| 517 |
+
progress(0.62, desc="🔥 SHAP分析...")
|
| 518 |
+
log(f"\n 🔥 SHAP特征分析 (保留模型中 Top 3)...")
|
| 519 |
+
shap_imp = {}
|
| 520 |
+
# SHAP for top 3 retained models
|
| 521 |
+
models_for_shap = sorted(retained, key=lambda x: amr[x]['mean_auc'], reverse=True)[:3]
|
| 522 |
+
|
| 523 |
+
for si, mn in enumerate(models_for_shap):
|
| 524 |
+
progress(0.62 + 0.10 * si / max(len(models_for_shap),1), desc=f"🔥 SHAP: {mn}...")
|
| 525 |
+
mo = tms[mn]; Xshap = X.values
|
| 526 |
+
ns = min(SHAPSZ, Xshap.shape[0])
|
| 527 |
+
np.random.seed(RS); sidx = np.random.choice(Xshap.shape[0], ns, replace=False)
|
| 528 |
+
Xs = Xshap[sidx]
|
| 529 |
+
try:
|
| 530 |
+
if mn in ['RF', 'XGB', 'DT', 'AdaBoost']:
|
| 531 |
+
exp = shap.TreeExplainer(mo); sv = exp.shap_values(Xs)
|
| 532 |
+
else:
|
| 533 |
+
bg = Xs[np.random.choice(ns, min(50, ns), replace=False)]
|
| 534 |
+
exp = shap.KernelExplainer(lambda x, m=mo: m.predict_proba(x), bg)
|
| 535 |
+
sv = exp.shap_values(Xs)
|
| 536 |
+
|
| 537 |
+
# Handle SHAP output: could be list of arrays (one per class) or 3D array
|
| 538 |
+
if isinstance(sv, list):
|
| 539 |
+
# Average absolute SHAP across all classes
|
| 540 |
+
sv_abs = np.mean([np.abs(s) for s in sv], axis=0)
|
| 541 |
+
elif sv.ndim == 3:
|
| 542 |
+
sv_abs = np.mean(np.abs(sv), axis=2) # (samples, features)
|
| 543 |
+
else:
|
| 544 |
+
sv_abs = np.abs(sv)
|
| 545 |
+
|
| 546 |
+
fi = sv_abs.mean(axis=0)
|
| 547 |
+
if len(fi) > len(fnames): fi = fi[:len(fnames)]
|
| 548 |
+
elif len(fi) < len(fnames): fi = np.pad(fi, (0, len(fnames) - len(fi)))
|
| 549 |
+
|
| 550 |
+
idf = pd.DataFrame({'Feature': fnames, 'Importance': fi}).sort_values('Importance', ascending=False)
|
| 551 |
+
shap_imp[mn] = idf
|
| 552 |
+
|
| 553 |
+
# Bar plot (works for any number of classes)
|
| 554 |
+
plt.figure(figsize=(10, max(6, TOPN * 0.3)))
|
| 555 |
+
top_df = idf.head(TOPN).iloc[::-1]
|
| 556 |
+
plt.barh(top_df['Feature'], top_df['Importance'], color='#2563eb', alpha=0.8)
|
| 557 |
+
plt.xlabel('Mean |SHAP|', fontsize=12)
|
| 558 |
+
plt.title(f'SHAP Feature Importance — {mn} (Top {TOPN})', fontsize=13, fontweight='bold')
|
| 559 |
+
plt.tight_layout()
|
| 560 |
+
plt.savefig(os.path.join(rf, f'shap_{mn}.pdf'), format='pdf', bbox_inches='tight')
|
| 561 |
+
plt.savefig(os.path.join(rf, f'shap_{mn}.png'), format='png', bbox_inches='tight', dpi=150)
|
| 562 |
+
plt.close()
|
| 563 |
+
log(f" ✅ {mn} Top3: {', '.join(idf.head(3)['Feature'].tolist())}")
|
| 564 |
+
except Exception as e:
|
| 565 |
+
log(f" ⚠ {mn} SHAP失败: {e}")
|
| 566 |
+
|
| 567 |
+
# ── Feature Ablation (for best model only) ──
|
| 568 |
+
progress(0.72, desc="🧪 特征消融...")
|
| 569 |
+
log(f"\n 🧪 特征消融 (仅最佳模型 {best_mn})...")
|
| 570 |
+
ablation_data = None
|
| 571 |
+
if best_mn in shap_imp:
|
| 572 |
+
imp_df = shap_imp[best_mn]
|
| 573 |
+
top_feats = imp_df.head(TOPN)['Feature'].tolist()
|
| 574 |
+
fcs = []; aucs_a = []
|
| 575 |
+
scoring = 'roc_auc_ovr' if n_classes > 2 else 'roc_auc'
|
| 576 |
+
|
| 577 |
+
for nf in range(1, len(top_feats) + 1):
|
| 578 |
+
Xsub = X[top_feats[:nf]]
|
| 579 |
+
fold_aucs = []
|
| 580 |
+
for tri, tei in skf.split(Xsub, y_mapped):
|
| 581 |
+
mf = deepcopy(mcfg[best_mn]['model'])
|
| 582 |
+
bp2 = bpd.get(best_mn, {})
|
| 583 |
+
if isinstance(bp2, dict) and bp2: mf.set_params(**bp2)
|
| 584 |
+
mf.fit(Xsub.iloc[tri].values, y_mapped.iloc[tri])
|
| 585 |
+
yproba_f = mf.predict_proba(Xsub.iloc[tei].values)
|
| 586 |
+
yte_f = y_mapped.iloc[tei]
|
| 587 |
+
try:
|
| 588 |
+
if n_classes == 2:
|
| 589 |
+
a = roc_auc_score(yte_f, yproba_f[:, 1])
|
| 590 |
+
else:
|
| 591 |
+
a = roc_auc_score(yte_f, yproba_f, multi_class='ovr', average='macro')
|
| 592 |
+
except: a = 0.0
|
| 593 |
+
fold_aucs.append(a)
|
| 594 |
+
fcs.append(nf); aucs_a.append(np.mean(fold_aucs))
|
| 595 |
+
|
| 596 |
+
# Find optimal: first N where AUC >= 95% of full AUC
|
| 597 |
+
full_auc = amr[best_mn]['mean_auc']
|
| 598 |
+
opt_n = len(top_feats)
|
| 599 |
+
for i, a in enumerate(aucs_a):
|
| 600 |
+
if a >= full_auc * 0.95:
|
| 601 |
+
opt_n = i + 1; break
|
| 602 |
+
|
| 603 |
+
ablation_data = {'fcs': fcs, 'aucs': aucs_a, 'feats': top_feats, 'opt_n': opt_n, 'opt_feats': top_feats[:opt_n]}
|
| 604 |
+
log(f" ✅ 最优特征数: {opt_n} (AUC={aucs_a[opt_n-1]:.4f} vs Full={full_auc:.4f})")
|
| 605 |
+
|
| 606 |
+
# Plot
|
| 607 |
+
plt.figure(figsize=(10, 7))
|
| 608 |
+
plt.plot(fcs, aucs_a, 'o-', color='#2563eb', lw=2, ms=5)
|
| 609 |
+
plt.scatter([opt_n], [aucs_a[opt_n-1]], s=200, marker='*', color='#ef4444', edgecolors='black', lw=2, zorder=5)
|
| 610 |
+
plt.axhline(y=full_auc, color='gray', ls='--', lw=1, alpha=0.5, label=f'Full AUC={full_auc:.3f}')
|
| 611 |
+
plt.xlabel('Number of Features', fontsize=13); plt.ylabel('Macro AUC', fontsize=13)
|
| 612 |
+
plt.title(f'Feature Ablation — {best_mn} (★ Optimal={opt_n})', fontsize=14, fontweight='bold')
|
| 613 |
+
plt.legend(fontsize=11); plt.grid(True, alpha=0.15); plt.tight_layout()
|
| 614 |
+
plt.savefig(os.path.join(rf, 'ablation.pdf'), format='pdf', bbox_inches='tight')
|
| 615 |
+
plt.savefig(os.path.join(rf, 'ablation.png'), format='png', bbox_inches='tight', dpi=150)
|
| 616 |
+
plt.close()
|
| 617 |
+
|
| 618 |
+
# ── External Validation ──
|
| 619 |
+
val_files_list = [vf for vf in [val_file1, val_file2, val_file3] if vf is not None]
|
| 620 |
+
final_feats = ablation_data['opt_feats'] if ablation_data else fnames
|
| 621 |
+
|
| 622 |
+
if val_files_list:
|
| 623 |
+
progress(0.82, desc="🧪 外部验证...")
|
| 624 |
+
log(f"\n{'━'*50}")
|
| 625 |
+
log(f" 🧪 外部验证 ({len(val_files_list)} 个验证集)")
|
| 626 |
+
|
| 627 |
+
for vi, vf in enumerate(val_files_list, 1):
|
| 628 |
+
vp = vf if isinstance(vf, str) else getattr(vf, 'name', str(vf))
|
| 629 |
+
ed = pd.read_csv(vp); ye_raw = ed.iloc[:, 0]
|
| 630 |
+
vcol2 = ed.iloc[:, 1]
|
| 631 |
+
vcol2_is_id = (vcol2.dtype == 'object') or (vcol2.nunique() / len(vcol2) > 0.5)
|
| 632 |
+
Xe = ed.iloc[:, 2:] if vcol2_is_id else ed.iloc[:, 1:]
|
| 633 |
+
|
| 634 |
+
# Map validation labels using same mapping
|
| 635 |
+
ye = ye_raw.map(label_map)
|
| 636 |
+
if ye.isna().any():
|
| 637 |
+
log(f" ⚠ 验证集 {vi} 含有训练集中不存在的标签,已跳过")
|
| 638 |
+
continue
|
| 639 |
+
|
| 640 |
+
log(f"\n 📊 验证集 {vi}: {Xe.shape[0]} 样本, {os.path.basename(vp)}")
|
| 641 |
+
|
| 642 |
+
Xes = Xe[final_feats]; Xtf = X[final_feats]
|
| 643 |
+
fm = deepcopy(mcfg[best_mn]['model'])
|
| 644 |
+
bp3 = bpd[best_mn]
|
| 645 |
+
if isinstance(bp3, dict) and bp3: fm.set_params(**bp3)
|
| 646 |
+
fm.fit(Xtf.values, y_mapped)
|
| 647 |
+
yep = fm.predict_proba(Xes.values); yed = fm.predict(Xes.values)
|
| 648 |
+
ye_np = ye.values
|
| 649 |
+
|
| 650 |
+
metrics = compute_multiclass_metrics(ye_np, yed, yep, class_indices)
|
| 651 |
+
log(f" ✅ AUC={metrics['Macro_AUC']:.4f} Acc={metrics['Accuracy']:.4f} F1={metrics['Macro_F1']:.4f} Kappa={metrics['Kappa']:.4f}")
|
| 652 |
+
|
| 653 |
+
sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
|
| 654 |
+
tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
|
| 655 |
+
|
| 656 |
+
plot_multiclass_roc(ye_np, yep, class_indices, f'ROC — {tag} ({best_mn})', f'roc{sfx}', rf)
|
| 657 |
+
plot_multiclass_pr(ye_np, yep, class_indices, f'PR — {tag} ({best_mn})', f'pr{sfx}', rf)
|
| 658 |
+
plot_confusion_matrix(ye_np, yed, class_indices, f'CM — {tag} ({best_mn})', f'cm{sfx}', rf)
|
| 659 |
+
|
| 660 |
+
with pd.ExcelWriter(os.path.join(rf, f'validation{sfx}.xlsx'), engine='openpyxl') as w:
|
| 661 |
+
pd.DataFrame([{'Model': best_mn, 'N_Features': len(final_feats),
|
| 662 |
+
'Macro_AUC': metrics['Macro_AUC'], 'Accuracy': metrics['Accuracy'],
|
| 663 |
+
'Macro_F1': metrics['Macro_F1'], 'Weighted_F1': metrics['Weighted_F1'],
|
| 664 |
+
'Kappa': metrics['Kappa']}]).to_excel(w, sheet_name='Metrics', index=False)
|
| 665 |
+
rpt = pd.DataFrame(metrics['report']).T
|
| 666 |
+
rpt.to_excel(w, sheet_name='Per_Class', index=True)
|
| 667 |
+
pd.DataFrame({'Feature': final_feats}).to_excel(w, sheet_name='Features', index=False)
|
| 668 |
+
|
| 669 |
+
# ── Save Results ──
|
| 670 |
+
progress(0.92, desc="💾 保存结果...")
|
| 671 |
+
log(f"\n 💾 保存结果...")
|
| 672 |
+
|
| 673 |
+
with pd.ExcelWriter(os.path.join(rf, 'model_evaluation.xlsx'), engine='openpyxl') as w:
|
| 674 |
+
for mn, r in amr.items():
|
| 675 |
+
r['fold_df'].to_excel(w, sheet_name=mn, index=False)
|
| 676 |
+
# Summary with retained status
|
| 677 |
+
sd = [{'Model': mn, 'Macro_AUC': r['mean_auc'], 'Accuracy': r['mean_acc'],
|
| 678 |
+
'Macro_F1': r['mean_f1'], 'Retained': 'Yes' if mn in retained else 'No',
|
| 679 |
+
'Best': 'Best' if mn == best_mn else ''}
|
| 680 |
+
for mn, r in amr.items()]
|
| 681 |
+
pd.DataFrame(sd).sort_values('Macro_AUC', ascending=False).to_excel(w, sheet_name='Summary', index=False)
|
| 682 |
+
# Bootstrap test results
|
| 683 |
+
if len(bootstrap_df) > 0:
|
| 684 |
+
bootstrap_df.to_excel(w, sheet_name='Bootstrap_Test', index=False)
|
| 685 |
+
# Per-class report for best model
|
| 686 |
+
best_report = classification_report(amr[best_mn]['all_yt'], amr[best_mn]['all_yp'],
|
| 687 |
+
labels=class_indices, output_dict=True, zero_division=0)
|
| 688 |
+
pd.DataFrame(best_report).T.to_excel(w, sheet_name=f'{best_mn}_PerClass', index=True)
|
| 689 |
+
|
| 690 |
+
if ablation_data:
|
| 691 |
+
with pd.ExcelWriter(os.path.join(rf, 'feature_ablation.xlsx'), engine='openpyxl') as w:
|
| 692 |
+
pd.DataFrame({'N': ablation_data['fcs'], 'AUC': ablation_data['aucs']}).to_excel(w, sheet_name='Ablation', index=False)
|
| 693 |
+
for mn, idf in shap_imp.items():
|
| 694 |
+
idf.to_excel(w, sheet_name=f'{mn}_Imp', index=False)
|
| 695 |
+
|
| 696 |
+
# Save params (English for SCI)
|
| 697 |
+
with open(os.path.join(rf, 'best_params.txt'), 'w', encoding='utf-8') as f:
|
| 698 |
+
f.write(f"Task: {task_type} Classification ({n_classes} classes)\n")
|
| 699 |
+
f.write(f"Classes: {classes}\n")
|
| 700 |
+
f.write(f"Label Mapping: {label_map}\n\n")
|
| 701 |
+
f.write(f"Statistical Test: Bootstrap AUC Test (n=2000, alpha=0.05)\n")
|
| 702 |
+
f.write(f"Retained Models: {', '.join(retained)} ({len(retained)}/{nm})\n\n")
|
| 703 |
+
for mn in mcfg:
|
| 704 |
+
status = "* Best" if mn == best_mn else ("Retained" if mn in retained else "Excluded")
|
| 705 |
+
f.write(f"Model: {mn} | AUC={amr[mn]['mean_auc']:.4f} | {status}\n")
|
| 706 |
+
bp = bpd[mn]
|
| 707 |
+
if isinstance(bp, dict):
|
| 708 |
+
for k, v in bp.items(): f.write(f" {k}: {v}\n")
|
| 709 |
+
else: f.write(f" {bp}\n")
|
| 710 |
+
f.write("\n")
|
| 711 |
+
if len(bootstrap_df) > 0:
|
| 712 |
+
f.write("\n" + "=" * 50 + "\n")
|
| 713 |
+
f.write("Bootstrap AUC Comparison Results\n")
|
| 714 |
+
f.write("=" * 50 + "\n")
|
| 715 |
+
for _, row in bootstrap_df.iterrows():
|
| 716 |
+
f.write(f" {row['Model_A']} vs {row['Model_B']}: ")
|
| 717 |
+
f.write(f"dAUC={row['AUC_Diff']:+.4f} 95%CI=[{row['CI_95_Low']:+.4f},{row['CI_95_High']:+.4f}] ")
|
| 718 |
+
f.write(f"P={row['P_value']:.4f} -> {row['Decision']}\n")
|
| 719 |
+
if ablation_data:
|
| 720 |
+
f.write(f"\nOptimal Features ({ablation_data['opt_n']}): {', '.join(ablation_data['opt_feats'])}\n")
|
| 721 |
+
|
| 722 |
+
# Save model
|
| 723 |
+
pickle.dump({
|
| 724 |
+
'model_name': best_mn, 'model': tms[best_mn], 'best_params': bpd[best_mn],
|
| 725 |
+
'classes': classes, 'n_classes': n_classes, 'label_map': label_map,
|
| 726 |
+
'features': final_feats, 'task_type': task_type
|
| 727 |
+
}, open(os.path.join(rf, f'model_{best_mn}.pkl'), 'wb'))
|
| 728 |
+
|
| 729 |
+
# ── ZIP ──
|
| 730 |
+
progress(0.97, desc="📦 打包ZIP...")
|
| 731 |
+
zp = os.path.join(tempfile.gettempdir(), f"ml_results_{int(time.time())}_{os.getpid()}.zip")
|
| 732 |
+
with zipfile.ZipFile(zp, 'w', zipfile.ZIP_DEFLATED) as zf:
|
| 733 |
+
for root, _, files in os.walk(rf):
|
| 734 |
+
for fn in files: zf.write(os.path.join(root, fn), os.path.relpath(os.path.join(root, fn), rf))
|
| 735 |
+
|
| 736 |
+
nf = sum(len(f) for _, _, f in os.walk(rf))
|
| 737 |
+
shutil.rmtree(rf, ignore_errors=True); gc.collect()
|
| 738 |
+
|
| 739 |
+
log(f"\n{'━'*50}")
|
| 740 |
+
log(f" 🎉 分析完成!共 {nf} 个文件已打包")
|
| 741 |
+
log(f" 📋 Task: {task_type} | Best Model: {best_mn}")
|
| 742 |
+
log(f"{'━'*50}")
|
| 743 |
+
progress(1.0, desc="✅ 完成!")
|
| 744 |
+
return zp, "\n".join(L)
|
| 745 |
+
|
| 746 |
+
except Exception as e:
|
| 747 |
+
log(f"\n❌ 错误: {e}")
|
| 748 |
+
log(traceback.format_exc())
|
| 749 |
+
if os.path.exists(rf): shutil.rmtree(rf, ignore_errors=True)
|
| 750 |
+
gc.collect()
|
| 751 |
+
return None, "\n".join(L)
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
# ============================================================================
|
| 755 |
+
# Gradio UI
|
| 756 |
+
# ============================================================================
|
| 757 |
+
CUSTOM_CSS = """
|
| 758 |
+
.header-banner {
|
| 759 |
+
background: linear-gradient(135deg, #0a2463 0%, #1e3a7a 40%, #2554a8 100%);
|
| 760 |
+
border-radius: 16px; padding: 28px 36px; margin-bottom: 20px;
|
| 761 |
+
box-shadow: 0 8px 32px rgba(0,0,0,0.18); position: relative; overflow: hidden;
|
| 762 |
+
}
|
| 763 |
+
.header-banner::before {
|
| 764 |
+
content: ''; position: absolute; top: -50%; right: -20%;
|
| 765 |
+
width: 400px; height: 400px;
|
| 766 |
+
background: radial-gradient(circle, rgba(96,165,250,0.2) 0%, transparent 70%);
|
| 767 |
+
border-radius: 50%;
|
| 768 |
+
}
|
| 769 |
+
.header-banner img { max-height: 52px; border-radius: 6px; margin-bottom: 12px; }
|
| 770 |
+
.header-banner h1 { color: #e2e8f0 !important; font-size: 1.7em !important; margin: 4px 0 6px 0 !important; font-weight: 700 !important; }
|
| 771 |
+
.header-banner p { color: #94a3b8 !important; font-size: 0.92em !important; margin: 2px 0 !important; line-height: 1.6; }
|
| 772 |
+
.header-banner .credit { color: #64748b !important; font-size: 0.82em !important; margin-top: 10px !important; border-top: 1px solid rgba(148,163,184,0.15); padding-top: 10px; }
|
| 773 |
+
.section-title {
|
| 774 |
+
background: linear-gradient(90deg, #2563eb 0%, #3b82f6 100%);
|
| 775 |
+
color: white !important; padding: 8px 16px; border-radius: 8px;
|
| 776 |
+
font-size: 0.95em !important; font-weight: 600 !important; margin: 12px 0 8px 0;
|
| 777 |
+
}
|
| 778 |
+
.pipeline-box {
|
| 779 |
+
background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
|
| 780 |
+
border: 1px solid #bae6fd; border-radius: 12px; padding: 14px 18px; margin: 8px 0; font-size: 0.88em;
|
| 781 |
+
}
|
| 782 |
+
.pipeline-box code { background: #2563eb; color: white; padding: 2px 8px; border-radius: 4px; font-size: 0.85em; margin: 0 2px; }
|
| 783 |
+
.log-area textarea {
|
| 784 |
+
font-family: 'Menlo','Consolas',monospace !important; font-size: 12.5px !important; line-height: 1.5 !important;
|
| 785 |
+
background: #0f172a !important; color: #e2e8f0 !important; border-radius: 10px !important; padding: 16px !important;
|
| 786 |
+
}
|
| 787 |
+
.gradio-container { max-width: 1280px !important; }
|
| 788 |
+
footer { display: none !important; }
|
| 789 |
+
"""
|
| 790 |
+
|
| 791 |
+
# Build the Gradio page: a header banner, a pipeline summary, a left column
# with the inputs (files, class count, models, parameters) and a right column
# with the run log and the downloadable ZIP.
# NOTE(review): nesting below is reconstructed from the statement order —
# confirm against the original file's indentation.
with gr.Blocks(
    title="ML 多分类模型平台 — 复旦大学附属眼耳鼻喉科医院",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate", neutral_hue="slate"),
    css=CUSTOM_CSS,
) as demo:

    # Header banner; the logo hides itself via onerror if the image fails to load.
    gr.HTML("""
<div class="header-banner">
<img src="https://huggingface.co/spaces/fudan-renjun/machine-learning-2/resolve/main/hospital_logo.png"
alt="Logo" onerror="this.style.display='none'"/>
<h1>🧬 ML 多分类模型训练与评估平台</h1>
<p>支持 2~8 分类 · 上传 CSV 即可完成全流程分析</p>
<p class="credit">复旦大学附属眼耳鼻喉科医院 · 检验科 · 任俊</p>
</div>
""")

    # Static summary of the pipeline steps and the expected CSV layout.
    gr.HTML("""
<div class="pipeline-box">
<strong>📋 流程:</strong>
<code>选择分类数</code> → <code>模型训练</code> → <code>交叉验证</code> →
<code>SHAP分析</code> → <code>特征消融</code> → <code>外部验证</code>
|
<strong>CSV格式:</strong> 第1列=标签(整数), 第2列=ID, 第3列起=特征
|
<strong>支持:</strong> 2分类 / 3分类 / ... / 8分类
</div>
""")

    with gr.Row(equal_height=False):
        # Left column: all user inputs.
        with gr.Column(scale=5):
            gr.HTML('<div class="section-title">📂 数据上传</div>')
            train_file = gr.File(label="训练集 CSV(必需)", file_types=[".csv"])
            gr.HTML('<p style="color:#64748b;font-size:0.85em;margin:4px 0 8px 0;">验证集可选,支持同时上传 1~3 个,分别生成独立报告</p>')
            # Up to three optional validation CSV uploads, side by side.
            with gr.Row():
                val_file1 = gr.File(label="验证集 1(可选)", file_types=[".csv"], scale=1)
                val_file2 = gr.File(label="验证集 2(可选)", file_types=[".csv"], scale=1)
                val_file3 = gr.File(label="验证集 3(可选)", file_types=[".csv"], scale=1)

            gr.HTML('<div class="section-title">🏷️ 分类设置</div>')
            # Number of classes (2 through 8), chosen by the user as a label string.
            n_classes_select = gr.Dropdown(
                choices=["2 类(二分类)", "3 类", "4 类", "5 类", "6 类", "7 类", "8 类"],
                value="2 类(二分类)",
                label="选择分类数",
                info="请根据数据标签列的类别数选择,系统将自动验证是否匹配",
            )

            gr.HTML('<div class="section-title">🤖 模型选择</div>')
            # Multi-select of model names; every entry of ALL_MODEL_NAMES is selected by default.
            model_selector = gr.Dropdown(
                choices=ALL_MODEL_NAMES, value=ALL_MODEL_NAMES, multiselect=True,
                label="选择模型(均支持多分类)",
                info="RF=随机森林 DT=决策树 KNN=K近邻 XGB=XGBoost AdaBoost LR=逻辑回归 NB=朴素贝叶斯 SVM=支持向量机",
            )
            with gr.Row():
                btn_all = gr.Button("🔘 全选", size="sm", variant="secondary")
                btn_tree = gr.Button("🌲 树模型", size="sm", variant="secondary")
                btn_linear = gr.Button("📐 线性模型", size="sm", variant="secondary")
                btn_top4 = gr.Button("⚡ 经典四模型", size="sm", variant="secondary")
            # Each preset button overwrites the model dropdown's value with a fixed list.
            btn_all.click(lambda: ALL_MODEL_NAMES, outputs=model_selector)
            btn_tree.click(lambda: ['RF','DT','XGB','AdaBoost'], outputs=model_selector)
            btn_linear.click(lambda: ['LR','SVM','NB'], outputs=model_selector)
            btn_top4.click(lambda: ['RF','XGB','LR','SVM'], outputs=model_selector)

            gr.HTML('<div class="section-title">⚙️ 参数配置</div>')
            # Hyper-parameter tuning is off by default; the label warns that it increases runtime.
            enable_tuning = gr.Checkbox(value=False, label="启用超参数调优 (GridSearchCV) ⚠️ 开启后运行时间显著增加")
            with gr.Row():
                cv_folds = gr.Slider(3, 10, value=5, step=1, label="交叉验证折数")
                top_n = gr.Slider(5, 50, value=20, step=1, label="SHAP 前 N 个特征")
                shap_sz = gr.Slider(30, 200, value=80, step=10, label="SHAP 采样数量")

            run_btn = gr.Button("🚀 开始分析", variant="primary", size="lg")

        # Right column: read-only run log and the result archive.
        with gr.Column(scale=5):
            gr.HTML('<div class="section-title">📋 运行日志</div>')
            log_output = gr.Textbox(
                label="", lines=24, max_lines=50, interactive=False,
                placeholder="点击「开始分析」后,日志将在此显示...\n支持 2~8 分类。",
                elem_classes="log-area",
            )
            gr.HTML('<div class="section-title">⬇️ 结果下载</div>')
            zip_output = gr.File(label="分析结果 ZIP 压缩包")

    # Run button: call run_pipeline with the ten inputs in this order and route
    # its two return values to the ZIP file component and the log textbox.
    # The event is registered under the API name "run".
    run_btn.click(
        fn=run_pipeline,
        inputs=[train_file, val_file1, val_file2, val_file3, n_classes_select,
                model_selector, enable_tuning, cv_folds, top_n, shap_sz],
        outputs=[zip_output, log_output],
        api_name="run",
    )
|
| 879 |
+
|
| 880 |
+
# ============================================================================
|
| 881 |
+
# Authentication
|
| 882 |
+
# ============================================================================
|
| 883 |
+
from datetime import datetime
|
| 884 |
+
|
| 885 |
+
# Login accounts: username -> {"password": plaintext password,
# "expires": "YYYY-MM-DD" string, or None for an account that never expires}.
# NOTE(review): plaintext credentials are committed in source; prefer loading
# them from environment variables / deployment secrets and storing only hashes.
ACCOUNTS = {
    "admin": {"password": "admin123", "expires": None},
    "renjun": {"password": "fudan2025", "expires": "2026-12-31"},
    "guest": {"password": "guest888", "expires": "2025-06-30"},
}


def auth_fn(username, password):
    """Return True when the credentials match an unexpired entry in ACCOUNTS.

    Args:
        username: Account name looked up in ACCOUNTS.
        password: Password supplied by the user; anything that is not a
            ``str`` is rejected.

    Returns:
        bool: False for an unknown user, a wrong password, an unparseable
        ``expires`` value, or an expired account; True otherwise.
    """
    import hmac  # local import keeps this block self-contained

    user = ACCOUNTS.get(username)
    if not user or not isinstance(password, str):
        return False

    # Constant-time comparison; encode to bytes because compare_digest
    # refuses non-ASCII str arguments.
    supplied = password.encode("utf-8")
    expected = user["password"].encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        return False

    expires = user["expires"]
    if not expires:
        return True

    try:
        expiry = datetime.strptime(expires, "%Y-%m-%d")
    except (TypeError, ValueError):
        # A malformed expiry date fails closed instead of granting access.
        return False

    # The parsed value is midnight at the start of the expiry date, so the
    # account is already rejected during that date (naive local time).
    # NOTE(review): confirm whether the expiry day was meant to be inclusive.
    return datetime.now() <= expiry
|
| 899 |
+
|
| 900 |
+
# Turn on Gradio's request queue, then start the server.
demo.queue()
# Listen on all network interfaces at port 7860; every login attempt is
# checked by auth_fn, and auth_message is the text shown on the login page.
demo.launch(server_name="0.0.0.0", server_port=7860, auth=auth_fn,
            auth_message="🔐 复旦大学附属眼耳鼻喉科医院 · ML多分类分析平台\n请输入账号和密码登录",
            ssr_mode=False)
|