fudan-renjun commited on
Commit
e195614
·
verified ·
1 Parent(s): cdf9ac4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +903 -0
app.py ADDED
@@ -0,0 +1,903 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ML Multi-Class Classification Pipeline (2-8 classes)
3
+ Eye & ENT Hospital of Fudan University — Laboratory Medicine, Ren Jun
4
+ Gradio 5.12.0 + Python 3.11
5
+ """
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import matplotlib
10
+ matplotlib.use('Agg')
11
+ import matplotlib.pyplot as plt
12
+ from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
13
+ from sklearn.tree import DecisionTreeClassifier
14
+ from sklearn.neighbors import KNeighborsClassifier
15
+ from sklearn.linear_model import LogisticRegression
16
+ from sklearn.naive_bayes import GaussianNB
17
+ from sklearn.svm import SVC
18
+ from xgboost import XGBClassifier
19
+ from sklearn.model_selection import StratifiedKFold, GridSearchCV
20
+ from sklearn.metrics import (
21
+ roc_auc_score, confusion_matrix, roc_curve,
22
+ auc as auc_score, precision_recall_curve,
23
+ classification_report, accuracy_score, f1_score,
24
+ cohen_kappa_score
25
+ )
26
+ from sklearn.preprocessing import label_binarize
27
+ import seaborn as sns
28
+ import warnings
29
+ from scipy import stats
30
+ import os
31
+ import shap
32
+ import pickle
33
+ from copy import deepcopy
34
+ import zipfile
35
+ import tempfile
36
+ import traceback
37
+ import time
38
+ import shutil
39
+ import gc
40
+ import threading
41
+ import gradio as gr
42
+ from itertools import cycle
43
+
44
# Suppress every warning category for the whole process.
warnings.filterwarnings('ignore')

# Publication-quality plot settings for SCI papers, applied as one batch.
plt.rcParams.update({
    # Fonts: serif family first, with fallbacks listed in priority order.
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'DejaVu Serif', 'serif'],
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
    'axes.unicode_minus': False,
    # Resolution: on-screen figures at 150 dpi, saved files at 300 dpi by default.
    'figure.dpi': 150,
    'savefig.dpi': 300,
    # Line weights and tick label sizes.
    'axes.linewidth': 1.2,
    'xtick.major.width': 1.0,
    'ytick.major.width': 1.0,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
})
57
+
58
+ # ============================================================================
59
+ # Cache Cleanup
60
+ # ============================================================================
61
CLEANUP_MAX_AGE_MINUTES = 30    # entries older than this many minutes are deleted
CLEANUP_INTERVAL_SECONDS = 600  # sleep between two background cleanup passes
CLEANUP_MAX_DISK_MB = 1024      # NOTE(review): not referenced in the visible code - confirm it is still used


def cleanup_old_temp_files(max_age_minutes=None, tmp_dir=None, prefix="ml_"):
    """Delete stale entries with a given name prefix from a temporary directory.

    Best-effort: an entry that cannot be inspected or removed is skipped and
    the scan continues with the remaining entries.

    Args:
        max_age_minutes: Age threshold in minutes, compared against each
            entry's modification time. Defaults to CLEANUP_MAX_AGE_MINUTES.
        tmp_dir: Directory to scan. Defaults to tempfile.gettempdir().
        prefix: Only entries whose name starts with this string are considered.

    Returns:
        None.
    """
    if max_age_minutes is None:
        max_age_minutes = CLEANUP_MAX_AGE_MINUTES
    if tmp_dir is None:
        tmp_dir = tempfile.gettempdir()

    # Entries modified before this timestamp are considered stale.
    cutoff = time.time() - max_age_minutes * 60

    try:
        entries = os.listdir(tmp_dir)
    except OSError:
        entries = []

    for entry in entries:
        if not entry.startswith(prefix):
            continue
        path = os.path.join(tmp_dir, entry)
        try:
            if os.path.getmtime(path) >= cutoff:
                continue
            if os.path.isdir(path):
                shutil.rmtree(path, ignore_errors=True)
            elif os.path.isfile(path):
                os.remove(path)
        except OSError:
            # The entry vanished or is not accessible; keep going so one bad
            # entry does not prevent the rest from being cleaned up.
            continue

    gc.collect()
77
+
78
def periodic_cleanup():
    """Run cleanup_old_temp_files in an endless loop; never returns.

    Each iteration first sleeps CLEANUP_INTERVAL_SECONDS and then performs one
    cleanup pass. Meant to be used as the target of a background thread.
    """
    while True:
        time.sleep(CLEANUP_INTERVAL_SECONDS)
        try:
            cleanup_old_temp_files()
        except Exception:
            # An uncaught exception would end this thread and silently stop
            # all future cleanup passes, so report it and keep looping.
            traceback.print_exc()


# Started at import time; daemon=True so the thread does not keep the
# interpreter alive at shutdown.
_ct = threading.Thread(target=periodic_cleanup, daemon=True)
_ct.start()
84
+
85
+ # ============================================================================
86
+ # Multi-class Helper Functions
87
+ # ============================================================================
88
+
89
def multiclass_roc_auc(y_true, y_proba, classes):
    """Return one ROC AUC value for binary or multi-class probability output.

    Args:
        y_true: True labels, one per sample.
        y_proba: Predicted probabilities indexed as [sample, class].
        classes: Class labels; only their count is used here.

    Returns:
        For two classes, the AUC of column 1 of ``y_proba``; otherwise the
        macro-averaged one-vs-rest AUC. Returns 0.0 when the score cannot be
        computed, for both the binary and the multi-class case.
    """
    try:
        if len(classes) == 2:
            return roc_auc_score(y_true, y_proba[:, 1])
        return roc_auc_score(y_true, y_proba, multi_class='ovr', average='macro')
    except (ValueError, IndexError):
        # 0.0 is this file's "not computable" sentinel; it is not a real AUC.
        return 0.0
98
+
99
def plot_multiclass_roc(y_true, y_proba, classes, title, filepath_prefix, rf):
    """Plot one-vs-rest ROC curves per class plus their macro average.

    Args:
        y_true: True labels, one per sample.
        y_proba: Predicted probabilities indexed as [sample, class]; column i
            is scored against ``classes[i]``.
        classes: Class labels, in the column order of ``y_proba``.
        title: Figure title.
        filepath_prefix: File name without extension.
        rf: Output directory.

    Returns:
        ``(macro_auc, auc_dict)`` where ``auc_dict`` maps the class index to
        its one-vs-rest AUC and ``macro_auc`` is the area under the averaged
        curve.

    Side effects:
        Writes ``<filepath_prefix>.pdf`` (300 dpi) and ``<filepath_prefix>.png``
        (150 dpi) into ``rf``.
    """
    n_classes = len(classes)
    y_bin = label_binarize(y_true, classes=classes)
    if n_classes == 2:
        # Expand the two-class indicator into one column per class.
        y_bin = np.hstack([1 - y_bin, y_bin])

    fpr_dict, tpr_dict, auc_dict = {}, {}, {}
    for i in range(n_classes):
        fpr_dict[i], tpr_dict[i], _ = roc_curve(y_bin[:, i], y_proba[:, i])
        auc_dict[i] = auc_score(fpr_dict[i], tpr_dict[i])

    # Macro average: interpolate every class curve onto the union of all FPR
    # points, then average the TPR values.
    all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(n_classes)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr_dict[i], tpr_dict[i])
    mean_tpr /= n_classes
    macro_auc = auc_score(all_fpr, mean_tpr)

    COLORS = ['#e41a1c','#377eb8','#4daf4a','#984ea3','#ff7f00','#a65628','#f781bf','#999999']
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        for i in range(n_classes):
            ax.plot(fpr_dict[i], tpr_dict[i], color=COLORS[i % len(COLORS)], lw=2,
                    label=f'Class {classes[i]} (AUC={auc_dict[i]:.3f})')
        ax.plot(all_fpr, mean_tpr, 'k--', lw=2.5, label=f'Macro Avg (AUC={macro_auc:.3f})')
        ax.plot([0, 1], [0, 1], '--', color='#cccccc', lw=1)  # chance diagonal
        ax.set_xlim([-0.02, 1.02])
        ax.set_ylim([-0.02, 1.02])
        ax.set_xlabel('False Positive Rate', fontsize=13)
        ax.set_ylabel('True Positive Rate', fontsize=13)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(loc='lower right', fontsize=9)
        ax.grid(True, alpha=0.15)
        fig.tight_layout()
        fig.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'), format='pdf', bbox_inches='tight', dpi=300)
        fig.savefig(os.path.join(rf, f'{filepath_prefix}.png'), format='png', bbox_inches='tight', dpi=150)
    finally:
        # Always release the figure, even if plotting or saving fails;
        # otherwise it stays registered in pyplot.
        plt.close(fig)
    return macro_auc, auc_dict
134
+
135
def plot_multiclass_pr(y_true, y_proba, classes, title, filepath_prefix, rf):
    """Plot one-vs-rest precision-recall curves, one per class.

    The legend reports average precision (AP) for each class. AP is computed
    with ``average_precision_score`` rather than as the trapezoidal area under
    the curve, so that the number matches its "AP" label.

    Args:
        y_true: True labels, one per sample.
        y_proba: Predicted probabilities indexed as [sample, class]; column i
            is scored against ``classes[i]``.
        classes: Class labels, in the column order of ``y_proba``.
        title: Figure title.
        filepath_prefix: File name without extension.
        rf: Output directory.

    Returns:
        None.

    Side effects:
        Writes ``<filepath_prefix>.pdf`` (300 dpi) and ``<filepath_prefix>.png``
        (150 dpi) into ``rf``.
    """
    # Local import: the file-level sklearn.metrics import list does not
    # include this name.
    from sklearn.metrics import average_precision_score

    n_classes = len(classes)
    y_bin = label_binarize(y_true, classes=classes)
    if n_classes == 2:
        # Expand the two-class indicator into one column per class.
        y_bin = np.hstack([1 - y_bin, y_bin])

    COLORS = ['#e41a1c','#377eb8','#4daf4a','#984ea3','#ff7f00','#a65628','#f781bf','#999999']
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        for i in range(n_classes):
            prec, rec, _ = precision_recall_curve(y_bin[:, i], y_proba[:, i])
            ap = average_precision_score(y_bin[:, i], y_proba[:, i])
            ax.plot(rec, prec, color=COLORS[i % len(COLORS)], lw=2,
                    label=f'Class {classes[i]} (AP={ap:.3f})')
        ax.set_xlim([-0.02, 1.02])
        ax.set_ylim([-0.02, 1.02])
        ax.set_xlabel('Recall', fontsize=13)
        ax.set_ylabel('Precision', fontsize=13)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(loc='lower left', fontsize=9)
        ax.grid(True, alpha=0.15)
        fig.tight_layout()
        fig.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'), format='pdf', bbox_inches='tight', dpi=300)
        fig.savefig(os.path.join(rf, f'{filepath_prefix}.png'), format='png', bbox_inches='tight', dpi=150)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
156
+
157
def plot_confusion_matrix(y_true, y_pred, classes, title, filepath_prefix, rf):
    """Draw the confusion matrix as an annotated heatmap and save it.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        classes: Labels that define the row/column order of the matrix and
            the tick labels.
        title: Figure title.
        filepath_prefix: File name without extension.
        rf: Output directory.

    Returns:
        The confusion matrix returned by ``confusion_matrix``.

    Side effects:
        Writes ``<filepath_prefix>.pdf`` (300 dpi) and ``<filepath_prefix>.png``
        (150 dpi) into ``rf``.
    """
    cm = confusion_matrix(y_true, y_pred, labels=classes)
    n = len(classes)
    # The figure grows with the number of classes but never below 6 x 5 inches.
    fig, ax = plt.subplots(figsize=(max(6, n * 1.2), max(5, n * 1.0)))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True,
                    xticklabels=classes, yticklabels=classes,
                    annot_kws={'fontsize': 11}, ax=ax)
        ax.set_xlabel('Predicted', fontsize=12)
        ax.set_ylabel('True', fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold')
        fig.tight_layout()
        fig.savefig(os.path.join(rf, f'{filepath_prefix}.pdf'), format='pdf', bbox_inches='tight', dpi=300)
        fig.savefig(os.path.join(rf, f'{filepath_prefix}.png'), format='png', bbox_inches='tight', dpi=150)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    return cm
169
+
170
def compute_multiclass_metrics(y_true, y_pred, y_proba, classes):
    """Compute summary classification metrics for one set of predictions.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        y_proba: Predicted probabilities indexed as [sample, class].
        classes: Class labels; passed as ``labels`` to the per-class report
            and used to pick the binary or multi-class AUC.

    Returns:
        A dict with the keys ``Accuracy``, ``Macro_AUC``, ``Macro_F1``,
        ``Weighted_F1``, ``Kappa`` and ``report`` (the dict form of
        ``classification_report``). ``Macro_AUC`` is 0.0 when the AUC cannot
        be computed.
    """
    n_classes = len(classes)
    acc = accuracy_score(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)

    # Per-class precision / recall / F1 as a nested dict.
    report = classification_report(y_true, y_pred, labels=classes, output_dict=True, zero_division=0)

    try:
        if n_classes == 2:
            macro_auc = roc_auc_score(y_true, y_proba[:, 1])
        else:
            macro_auc = roc_auc_score(y_true, y_proba, multi_class='ovr', average='macro')
    except (ValueError, IndexError):
        # 0.0 is this file's "not computable" sentinel; it is not a real AUC.
        macro_auc = 0.0

    f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
    f1_weighted = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    return {
        'Accuracy': acc, 'Macro_AUC': macro_auc, 'Macro_F1': f1_macro,
        'Weighted_F1': f1_weighted, 'Kappa': kappa, 'report': report
    }
195
+
196
+
197
def bootstrap_auc_test(y_true, proba_a, proba_b, classes, n_bootstrap=2000, seed=42):
    """Paired bootstrap test for a difference in (macro) AUC between two models.

    Used for multi-class problems as an alternative to the DeLong test. Both
    models are scored on the same resampled rows in every replicate.

    Args:
        y_true: True labels, one per sample.
        proba_a: Probabilities of model A, indexed as [sample, class].
        proba_b: Probabilities of model B, same layout as ``proba_a``.
        classes: Class labels; only their count is used.
        n_bootstrap: Number of resampling rounds attempted.
        seed: Seed for the resampling random generator.

    Returns:
        ``(p_value, auc_a, auc_b, ci_low, ci_high)``. ``ci_low``/``ci_high``
        are the 2.5th and 97.5th percentiles of the bootstrap AUC differences
        (A minus B). When fewer than 100 replicates are usable the result is
        ``(1.0, auc_a, auc_b, -1.0, 1.0)``.
    """
    # Plain arrays so that ``[idx]`` below is positional for any input type.
    y_true = np.asarray(y_true)
    proba_a = np.asarray(proba_a)
    proba_b = np.asarray(proba_b)

    rng = np.random.RandomState(seed)
    n = len(y_true)
    n_classes = len(classes)

    def calc_macro_auc(yt, pa, pb):
        """Score both models on the same labels; binary uses column 1."""
        if n_classes == 2:
            return roc_auc_score(yt, pa[:, 1]), roc_auc_score(yt, pb[:, 1])
        return (roc_auc_score(yt, pa, multi_class='ovr', average='macro'),
                roc_auc_score(yt, pb, multi_class='ovr', average='macro'))

    # Observed AUCs on the full sample; 0.0 is the "not computable" sentinel.
    try:
        auc_a, auc_b = calc_macro_auc(y_true, proba_a, proba_b)
    except (ValueError, IndexError):
        auc_a, auc_b = 0.0, 0.0
    observed_diff = auc_a - auc_b

    diffs = []
    for _ in range(n_bootstrap):
        idx = rng.choice(n, n, replace=True)
        yt_b = y_true[idx]
        # Skip resamples that miss a class.
        if len(np.unique(yt_b)) < n_classes:
            continue
        try:
            a1, a2 = calc_macro_auc(yt_b, proba_a[idx], proba_b[idx])
        except (ValueError, IndexError):
            # Drop unusable replicates instead of recording a fake zero
            # difference, which would distort the null distribution.
            continue
        diffs.append(a1 - a2)

    if len(diffs) < 100:
        return 1.0, auc_a, auc_b, -1.0, 1.0  # Not enough valid bootstraps

    diffs = np.array(diffs)
    # Two-sided p-value: centre the bootstrap differences to mimic H0
    # (diff = 0) and count how often they are at least as extreme as the
    # observed difference.
    centered = diffs - np.mean(diffs)
    p_value = np.mean(np.abs(centered) >= np.abs(observed_diff))
    # Floor at the resolution of the replicates actually used, not the number
    # attempted.
    p_value = max(p_value, 1.0 / len(diffs))

    ci_low = np.percentile(diffs, 2.5)
    ci_high = np.percentile(diffs, 97.5)

    return p_value, auc_a, auc_b, ci_low, ci_high
248
+
249
+ # ============================================================================
250
+ # Model configs (multi-class compatible)
251
+ # ============================================================================
252
ALL_MODEL_NAMES = ['RF', 'DT', 'KNN', 'XGB', 'AdaBoost', 'LR', 'NB', 'SVM']


def get_models_config(selected, n_classes, rs=42):
    """Build the estimator and hyper-parameter grid for each selected model.

    Args:
        selected: Container of model names to keep (tested with ``in``).
        n_classes: Number of target classes; decides how the XGBoost model is
            configured.
        rs: Random seed passed to every estimator that accepts one.

    Returns:
        A dict mapping model name to ``{'model': estimator, 'params': grid}``,
        containing only names present in ``selected``, in the fixed order
        RF, DT, KNN, XGB, AdaBoost, LR, NB, SVM.
    """
    # Configure XGBoost directly for the task instead of building a
    # multi-class style model first and replacing it afterwards.
    if n_classes > 2:
        xgb_model = XGBClassifier(random_state=rs, eval_metric='mlogloss', n_jobs=-1,
                                  num_class=n_classes, objective='multi:softprob')
    else:
        xgb_model = XGBClassifier(random_state=rs, eval_metric='logloss', n_jobs=-1)

    cfg = {
        'RF': {'model': RandomForestClassifier(random_state=rs, n_jobs=-1),
               'params': {'n_estimators': [100, 200], 'max_depth': [20, 50], 'min_samples_split': [2, 5]}},
        'DT': {'model': DecisionTreeClassifier(random_state=rs),
               'params': {'max_depth': [20, 50], 'min_samples_split': [2, 10], 'criterion': ['gini', 'entropy']}},
        'KNN': {'model': KNeighborsClassifier(n_jobs=-1),
                'params': {'n_neighbors': [3, 5, 7], 'weights': ['uniform', 'distance']}},
        'XGB': {'model': xgb_model,
                'params': {'n_estimators': [100, 200], 'max_depth': [5, 7], 'learning_rate': [0.05, 0.1]}},
        'AdaBoost': {'model': AdaBoostClassifier(random_state=rs),
                     'params': {'n_estimators': [50, 100], 'learning_rate': [0.1, 0.5, 1.0]}},
        'LR': {'model': LogisticRegression(random_state=rs, n_jobs=-1, max_iter=2000),
               'params': {'C': [0.1, 1, 10], 'solver': ['lbfgs']}},
        'NB': {'model': GaussianNB(),
               'params': {'var_smoothing': [1e-9, 1e-7, 1e-5]}},
        # probability=True is set because the pipeline calls predict_proba.
        'SVM': {'model': SVC(probability=True, random_state=rs, decision_function_shape='ovr'),
                'params': {'C': [1, 10], 'kernel': ['rbf', 'linear']}},
    }
    return {k: v for k, v in cfg.items() if k in selected}
281
+
282
+
283
+ # ============================================================================
284
+ # Main Pipeline
285
+ # ============================================================================
286
+ def run_pipeline(
287
+ train_file, val_file1, val_file2, val_file3, n_classes_select,
288
+ selected_models, enable_tuning,
289
+ cv_folds, top_n_features, shap_sample_size,
290
+ progress=gr.Progress(track_tqdm=True),
291
+ ):
292
+ if train_file is None:
293
+ return None, "❌ 请先上传训练集 CSV 文件"
294
+ sel = selected_models if isinstance(selected_models, list) else [s.strip() for s in str(selected_models).split(",") if s.strip()]
295
+ if not sel:
296
+ return None, "❌ 请至少选择一个模型"
297
+
298
+ RS = 42; CVF = int(cv_folds)
299
+ TOPN = int(top_n_features); SHAPSZ = int(shap_sample_size)
300
+ TUNING = bool(enable_tuning)
301
+
302
+ L = []
303
+ def log(m): L.append(str(m))
304
+
305
+ rf = tempfile.mkdtemp(prefix="ml_")
306
+
307
+ try:
308
+ # ── Load Data ──
309
+ progress(0.02, desc="📂 加载数据...")
310
+ log("━" * 50)
311
+ log(" 🧬 ML 多分类模型训练与评估系统")
312
+ log("━" * 50)
313
+
314
+ tp = train_file if isinstance(train_file, str) else getattr(train_file, 'name', str(train_file))
315
+ data = pd.read_csv(tp)
316
+
317
+ # Smart CSV format detection
318
+ y = data.iloc[:, 0]
319
+ col2 = data.iloc[:, 1]
320
+ col2_is_id = (col2.dtype == 'object') or (col2.nunique() / len(col2) > 0.5)
321
+ if col2_is_id:
322
+ X = data.iloc[:, 2:]
323
+ log(f" 📋 CSV: Col1=Label, Col2=ID({data.columns[1]}), Col3+=Features")
324
+ else:
325
+ X = data.iloc[:, 1:]
326
+ log(f" 📋 CSV: Col1=Label, Col2+=Features (no ID column)")
327
+ fnames = X.columns.tolist()
328
+
329
+ # Parse user selection: "3 类" -> 3, "2 类(二分类)" -> 2
330
+ user_n = int(str(n_classes_select).split(" ")[0])
331
+
332
+ # Validate against actual data
333
+ detected_classes = sorted(y.unique())
334
+ detected_classes = [int(c) if hasattr(c, 'item') else c for c in detected_classes]
335
+ detected_n = len(detected_classes)
336
+
337
+ if detected_n != user_n:
338
+ return None, (f"❌ 您选择了 {user_n} 分类,但数据中检测到 {detected_n} 个类别: {detected_classes}\n"
339
+ f"请将分类数修改为 {detected_n},或检查数据标签列")
340
+
341
+ classes = detected_classes
342
+ n_classes = user_n
343
+ log(f" ✅ {n_classes} 分类 — 数据验证通过")
344
+
345
+ # Remap to 0,1,...,n-1
346
+ label_map = {c: i for i, c in enumerate(classes)}
347
+ label_map_inv = {i: c for c, i in label_map.items()}
348
+ y_mapped = y.map(label_map)
349
+ class_indices = list(range(n_classes))
350
+
351
+ log(f" 📊 训练集: {X.shape[0]} 样本 × {X.shape[1]} 特征")
352
+ log(f" 🏷️ 类别数: {n_classes} 类 — {classes}")
353
+ log(f" 📊 分布: {dict(y.value_counts().sort_index())}")
354
+ log(f" 🤖 模型: {', '.join(sel)}")
355
+ log(f" 🔧 调优: {'开启' if TUNING else '关闭'} | CV: {CVF}折")
356
+
357
+ if n_classes < 2 or n_classes > 8:
358
+ return None, f"❌ 仅支持 2~8 分类,当前检测到 {n_classes} 类"
359
+
360
+ task_type = "Binary" if n_classes == 2 else f"{n_classes}-Class"
361
+ task_type_cn = "二分类" if n_classes == 2 else f"{n_classes}分类"
362
+ log(f" 📋 任务: {task_type_cn} ({task_type})")
363
+
364
+ mcfg = get_models_config(sel, n_classes, RS)
365
+ skf = StratifiedKFold(n_splits=CVF, shuffle=True, random_state=RS)
366
+
367
+ # ── Train All Models ──
368
+ bpd = {}; amr = {}; tms = {}
369
+ total = len(mcfg)
370
+ COLORS = ['#2563eb','#f59e0b','#10b981','#ef4444','#8b5cf6','#ec4899','#06b6d4','#6b7280']
371
+
372
+ for mi, (mn, cf) in enumerate(mcfg.items()):
373
+ pv = 0.05 + 0.35 * mi / total
374
+ progress(pv, desc=f"🏋️ [{mi+1}/{total}] 训练 {mn}...")
375
+ log(f"\n{'─'*40}")
376
+ log(f" 🔄 [{mi+1}/{total}] {mn}")
377
+
378
+ Xv = X.values
379
+ if TUNING:
380
+ log(f" ⏳ GridSearchCV (CV={CVF})...")
381
+ scoring = 'roc_auc_ovr' if n_classes > 2 else 'roc_auc'
382
+ gs = GridSearchCV(cf['model'], cf['params'], cv=skf, scoring=scoring, n_jobs=-1, verbose=0)
383
+ gs.fit(Xv, y_mapped)
384
+ bp = gs.best_params_; bpd[mn] = bp
385
+ log(f" ✓ 最佳CV Score: {gs.best_score_:.4f}")
386
+ else:
387
+ bp = {}; bpd[mn] = "Default"
388
+
389
+ mdl = deepcopy(cf['model'])
390
+ if bp: mdl.set_params(**bp)
391
+ mdl.fit(Xv, y_mapped)
392
+ tms[mn] = mdl
393
+
394
+ # CV evaluation
395
+ all_yt = []; all_yp = []; all_yproba = []
396
+ fold_metrics = []
397
+
398
+ for fi, (tri, tei) in enumerate(skf.split(X, y_mapped), 1):
399
+ Xtr, Xte = X.iloc[tri].values, X.iloc[tei].values
400
+ ytr, yte = y_mapped.iloc[tri], y_mapped.iloc[tei]
401
+ mf = deepcopy(cf['model'])
402
+ if bp: mf.set_params(**bp)
403
+ mf.fit(Xtr, ytr)
404
+ ypred = mf.predict(Xte)
405
+ yproba = mf.predict_proba(Xte)
406
+
407
+ all_yt.extend(yte); all_yp.extend(ypred); all_yproba.append(yproba)
408
+
409
+ metrics = compute_multiclass_metrics(yte, ypred, yproba, class_indices)
410
+ fold_metrics.append({
411
+ 'Fold': fi, 'Accuracy': metrics['Accuracy'],
412
+ 'Macro_AUC': metrics['Macro_AUC'], 'Macro_F1': metrics['Macro_F1'],
413
+ 'Weighted_F1': metrics['Weighted_F1'], 'Kappa': metrics['Kappa']
414
+ })
415
+
416
+ all_yt = np.array(all_yt); all_yp = np.array(all_yp)
417
+ all_yproba = np.vstack(all_yproba)
418
+
419
+ fdf = pd.DataFrame(fold_metrics)
420
+ mean_row = {col: fdf[col].mean() if col != 'Fold' else 'Mean' for col in fdf.columns}
421
+ fdf = pd.concat([fdf, pd.DataFrame([mean_row])], ignore_index=True)
422
+
423
+ amr[mn] = {
424
+ 'fold_df': fdf, 'mean_auc': mean_row['Macro_AUC'],
425
+ 'mean_acc': mean_row['Accuracy'], 'mean_f1': mean_row['Macro_F1'],
426
+ 'all_yt': all_yt, 'all_yp': all_yp, 'all_yproba': all_yproba
427
+ }
428
+ log(f" ✅ AUC={mean_row['Macro_AUC']:.4f} Acc={mean_row['Accuracy']:.4f} F1={mean_row['Macro_F1']:.4f} Kappa={mean_row['Kappa']:.4f}")
429
+
430
+ mnames = list(amr.keys()); nm = len(mnames)
431
+ log(f"\n{'━'*50}")
432
+ log(f" ✅ {nm} 个模型训练完成")
433
+
434
+ # ── ROC Curves ──
435
+ progress(0.42, desc="📈 ROC曲线...")
436
+ log(f"\n 📈 绘制图表...")
437
+ for mn in mnames:
438
+ r = amr[mn]
439
+ plot_multiclass_roc(r['all_yt'], r['all_yproba'], class_indices,
440
+ f'ROC — {mn} ({task_type}, Macro AUC={r["mean_auc"]:.3f})', f'roc_{mn}', rf)
441
+
442
+ # Combined ROC (macro per model)
443
+ plt.figure(figsize=(10, 8))
444
+ for i, mn in enumerate(mnames):
445
+ r = amr[mn]
446
+ y_bin = label_binarize(r['all_yt'], classes=class_indices)
447
+ if n_classes == 2: y_bin = np.hstack([1 - y_bin, y_bin])
448
+ all_fpr = np.linspace(0, 1, 200); mean_tpr = np.zeros_like(all_fpr)
449
+ for c in range(n_classes):
450
+ f, t, _ = roc_curve(y_bin[:, c], r['all_yproba'][:, c])
451
+ mean_tpr += np.interp(all_fpr, f, t)
452
+ mean_tpr /= n_classes; mean_tpr[-1] = 1.0
453
+ ma = auc_score(all_fpr, mean_tpr)
454
+ plt.plot(all_fpr, mean_tpr, color=COLORS[i%8], lw=2.5, label=f'{mn} (Macro AUC={ma:.3f})')
455
+ plt.plot([0,1],[0,1],'--',color='#ccc',lw=1)
456
+ plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
457
+ plt.xlabel('FPR',fontsize=13); plt.ylabel('TPR',fontsize=13)
458
+ plt.title(f'ROC — All Models ({task_type})',fontsize=14,fontweight='bold')
459
+ plt.legend(loc='lower right',fontsize=10); plt.grid(True,alpha=0.15); plt.tight_layout()
460
+ plt.savefig(os.path.join(rf,'roc_all.pdf'),format='pdf',bbox_inches='tight',dpi=300)
461
+ plt.savefig(os.path.join(rf,'roc_all.png'),format='png',bbox_inches='tight',dpi=150)
462
+ plt.close()
463
+
464
+ # ── PR Curves ──
465
+ progress(0.48, desc="📈 PR曲线...")
466
+ for mn in mnames:
467
+ r = amr[mn]
468
+ plot_multiclass_pr(r['all_yt'], r['all_yproba'], class_indices,
469
+ f'PR — {mn} ({task_type})', f'pr_{mn}', rf)
470
+
471
+ # ── Confusion Matrices ──
472
+ progress(0.52, desc="📊 混淆矩阵...")
473
+ for mn in mnames:
474
+ r = amr[mn]
475
+ plot_confusion_matrix(r['all_yt'], r['all_yp'], class_indices,
476
+ f'CM — {mn} (Acc={r["mean_acc"]:.3f})', f'cm_{mn}', rf)
477
+
478
+ # ── Bootstrap AUC Test (模型统计检验) ──
479
+ progress(0.55, desc="🔬 Bootstrap AUC 检验...")
480
+ best_mn = max(amr, key=lambda x: amr[x]['mean_auc'])
481
+ best_auc = amr[best_mn]['mean_auc']
482
+ log(f"\n 🏆 最佳模型: {best_mn} (Macro AUC={best_auc:.4f})")
483
+ log(f" 🔬 Bootstrap 检验 (n=2000, α=0.05)...")
484
+
485
+ ALPHA = 0.05
486
+ bootstrap_results = []
487
+ retained = [best_mn]
488
+
489
+ for om in mnames:
490
+ if om == best_mn:
491
+ continue
492
+ p_val, auc_a, auc_b, ci_lo, ci_hi = bootstrap_auc_test(
493
+ amr[best_mn]['all_yt'],
494
+ amr[best_mn]['all_yproba'],
495
+ amr[om]['all_yproba'],
496
+ class_indices, n_bootstrap=2000
497
+ )
498
+ if p_val >= ALPHA:
499
+ retained.append(om)
500
+ dec = "Retained"
501
+ else:
502
+ dec = "Excluded"
503
+
504
+ bootstrap_results.append({
505
+ 'Model_A': best_mn, 'AUC_A': auc_a,
506
+ 'Model_B': om, 'AUC_B': auc_b,
507
+ 'AUC_Diff': auc_a - auc_b,
508
+ 'CI_95_Low': ci_lo, 'CI_95_High': ci_hi,
509
+ 'P_value': p_val, 'Decision': dec
510
+ })
511
+ log(f" {best_mn} vs {om}: ΔAUC={auc_a-auc_b:+.4f} 95%CI=[{ci_lo:+.4f},{ci_hi:+.4f}] P={p_val:.4f} → {dec}")
512
+
513
+ bootstrap_df = pd.DataFrame(bootstrap_results).sort_values('P_value', ascending=False) if bootstrap_results else pd.DataFrame()
514
+ log(f" ✅ 保留 {len(retained)}/{nm} 个模型: {', '.join(retained)}")
515
+
516
+ # ── SHAP ──
517
+ progress(0.62, desc="🔥 SHAP分析...")
518
+ log(f"\n 🔥 SHAP特征分析 (保留模型中 Top 3)...")
519
+ shap_imp = {}
520
+ # SHAP for top 3 retained models
521
+ models_for_shap = sorted(retained, key=lambda x: amr[x]['mean_auc'], reverse=True)[:3]
522
+
523
+ for si, mn in enumerate(models_for_shap):
524
+ progress(0.62 + 0.10 * si / max(len(models_for_shap),1), desc=f"🔥 SHAP: {mn}...")
525
+ mo = tms[mn]; Xshap = X.values
526
+ ns = min(SHAPSZ, Xshap.shape[0])
527
+ np.random.seed(RS); sidx = np.random.choice(Xshap.shape[0], ns, replace=False)
528
+ Xs = Xshap[sidx]
529
+ try:
530
+ if mn in ['RF', 'XGB', 'DT', 'AdaBoost']:
531
+ exp = shap.TreeExplainer(mo); sv = exp.shap_values(Xs)
532
+ else:
533
+ bg = Xs[np.random.choice(ns, min(50, ns), replace=False)]
534
+ exp = shap.KernelExplainer(lambda x, m=mo: m.predict_proba(x), bg)
535
+ sv = exp.shap_values(Xs)
536
+
537
+ # Handle SHAP output: could be list of arrays (one per class) or 3D array
538
+ if isinstance(sv, list):
539
+ # Average absolute SHAP across all classes
540
+ sv_abs = np.mean([np.abs(s) for s in sv], axis=0)
541
+ elif sv.ndim == 3:
542
+ sv_abs = np.mean(np.abs(sv), axis=2) # (samples, features)
543
+ else:
544
+ sv_abs = np.abs(sv)
545
+
546
+ fi = sv_abs.mean(axis=0)
547
+ if len(fi) > len(fnames): fi = fi[:len(fnames)]
548
+ elif len(fi) < len(fnames): fi = np.pad(fi, (0, len(fnames) - len(fi)))
549
+
550
+ idf = pd.DataFrame({'Feature': fnames, 'Importance': fi}).sort_values('Importance', ascending=False)
551
+ shap_imp[mn] = idf
552
+
553
+ # Bar plot (works for any number of classes)
554
+ plt.figure(figsize=(10, max(6, TOPN * 0.3)))
555
+ top_df = idf.head(TOPN).iloc[::-1]
556
+ plt.barh(top_df['Feature'], top_df['Importance'], color='#2563eb', alpha=0.8)
557
+ plt.xlabel('Mean |SHAP|', fontsize=12)
558
+ plt.title(f'SHAP Feature Importance — {mn} (Top {TOPN})', fontsize=13, fontweight='bold')
559
+ plt.tight_layout()
560
+ plt.savefig(os.path.join(rf, f'shap_{mn}.pdf'), format='pdf', bbox_inches='tight')
561
+ plt.savefig(os.path.join(rf, f'shap_{mn}.png'), format='png', bbox_inches='tight', dpi=150)
562
+ plt.close()
563
+ log(f" ✅ {mn} Top3: {', '.join(idf.head(3)['Feature'].tolist())}")
564
+ except Exception as e:
565
+ log(f" ⚠ {mn} SHAP失败: {e}")
566
+
567
+ # ── Feature Ablation (for best model only) ──
568
+ progress(0.72, desc="🧪 特征消融...")
569
+ log(f"\n 🧪 特征消融 (仅最佳模型 {best_mn})...")
570
+ ablation_data = None
571
+ if best_mn in shap_imp:
572
+ imp_df = shap_imp[best_mn]
573
+ top_feats = imp_df.head(TOPN)['Feature'].tolist()
574
+ fcs = []; aucs_a = []
575
+ scoring = 'roc_auc_ovr' if n_classes > 2 else 'roc_auc'
576
+
577
+ for nf in range(1, len(top_feats) + 1):
578
+ Xsub = X[top_feats[:nf]]
579
+ fold_aucs = []
580
+ for tri, tei in skf.split(Xsub, y_mapped):
581
+ mf = deepcopy(mcfg[best_mn]['model'])
582
+ bp2 = bpd.get(best_mn, {})
583
+ if isinstance(bp2, dict) and bp2: mf.set_params(**bp2)
584
+ mf.fit(Xsub.iloc[tri].values, y_mapped.iloc[tri])
585
+ yproba_f = mf.predict_proba(Xsub.iloc[tei].values)
586
+ yte_f = y_mapped.iloc[tei]
587
+ try:
588
+ if n_classes == 2:
589
+ a = roc_auc_score(yte_f, yproba_f[:, 1])
590
+ else:
591
+ a = roc_auc_score(yte_f, yproba_f, multi_class='ovr', average='macro')
592
+ except: a = 0.0
593
+ fold_aucs.append(a)
594
+ fcs.append(nf); aucs_a.append(np.mean(fold_aucs))
595
+
596
+ # Find optimal: first N where AUC >= 95% of full AUC
597
+ full_auc = amr[best_mn]['mean_auc']
598
+ opt_n = len(top_feats)
599
+ for i, a in enumerate(aucs_a):
600
+ if a >= full_auc * 0.95:
601
+ opt_n = i + 1; break
602
+
603
+ ablation_data = {'fcs': fcs, 'aucs': aucs_a, 'feats': top_feats, 'opt_n': opt_n, 'opt_feats': top_feats[:opt_n]}
604
+ log(f" ✅ 最优特征数: {opt_n} (AUC={aucs_a[opt_n-1]:.4f} vs Full={full_auc:.4f})")
605
+
606
+ # Plot
607
+ plt.figure(figsize=(10, 7))
608
+ plt.plot(fcs, aucs_a, 'o-', color='#2563eb', lw=2, ms=5)
609
+ plt.scatter([opt_n], [aucs_a[opt_n-1]], s=200, marker='*', color='#ef4444', edgecolors='black', lw=2, zorder=5)
610
+ plt.axhline(y=full_auc, color='gray', ls='--', lw=1, alpha=0.5, label=f'Full AUC={full_auc:.3f}')
611
+ plt.xlabel('Number of Features', fontsize=13); plt.ylabel('Macro AUC', fontsize=13)
612
+ plt.title(f'Feature Ablation — {best_mn} (★ Optimal={opt_n})', fontsize=14, fontweight='bold')
613
+ plt.legend(fontsize=11); plt.grid(True, alpha=0.15); plt.tight_layout()
614
+ plt.savefig(os.path.join(rf, 'ablation.pdf'), format='pdf', bbox_inches='tight')
615
+ plt.savefig(os.path.join(rf, 'ablation.png'), format='png', bbox_inches='tight', dpi=150)
616
+ plt.close()
617
+
618
+ # ── External Validation ──
619
+ val_files_list = [vf for vf in [val_file1, val_file2, val_file3] if vf is not None]
620
+ final_feats = ablation_data['opt_feats'] if ablation_data else fnames
621
+
622
+ if val_files_list:
623
+ progress(0.82, desc="🧪 外部验证...")
624
+ log(f"\n{'━'*50}")
625
+ log(f" 🧪 外部验证 ({len(val_files_list)} 个验证集)")
626
+
627
+ for vi, vf in enumerate(val_files_list, 1):
628
+ vp = vf if isinstance(vf, str) else getattr(vf, 'name', str(vf))
629
+ ed = pd.read_csv(vp); ye_raw = ed.iloc[:, 0]
630
+ vcol2 = ed.iloc[:, 1]
631
+ vcol2_is_id = (vcol2.dtype == 'object') or (vcol2.nunique() / len(vcol2) > 0.5)
632
+ Xe = ed.iloc[:, 2:] if vcol2_is_id else ed.iloc[:, 1:]
633
+
634
+ # Map validation labels using same mapping
635
+ ye = ye_raw.map(label_map)
636
+ if ye.isna().any():
637
+ log(f" ⚠ 验证集 {vi} 含有训练集中不存在的标签,已跳过")
638
+ continue
639
+
640
+ log(f"\n 📊 验证集 {vi}: {Xe.shape[0]} 样本, {os.path.basename(vp)}")
641
+
642
+ Xes = Xe[final_feats]; Xtf = X[final_feats]
643
+ fm = deepcopy(mcfg[best_mn]['model'])
644
+ bp3 = bpd[best_mn]
645
+ if isinstance(bp3, dict) and bp3: fm.set_params(**bp3)
646
+ fm.fit(Xtf.values, y_mapped)
647
+ yep = fm.predict_proba(Xes.values); yed = fm.predict(Xes.values)
648
+ ye_np = ye.values
649
+
650
+ metrics = compute_multiclass_metrics(ye_np, yed, yep, class_indices)
651
+ log(f" ✅ AUC={metrics['Macro_AUC']:.4f} Acc={metrics['Accuracy']:.4f} F1={metrics['Macro_F1']:.4f} Kappa={metrics['Kappa']:.4f}")
652
+
653
+ sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
654
+ tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
655
+
656
+ plot_multiclass_roc(ye_np, yep, class_indices, f'ROC — {tag} ({best_mn})', f'roc{sfx}', rf)
657
+ plot_multiclass_pr(ye_np, yep, class_indices, f'PR — {tag} ({best_mn})', f'pr{sfx}', rf)
658
+ plot_confusion_matrix(ye_np, yed, class_indices, f'CM — {tag} ({best_mn})', f'cm{sfx}', rf)
659
+
660
+ with pd.ExcelWriter(os.path.join(rf, f'validation{sfx}.xlsx'), engine='openpyxl') as w:
661
+ pd.DataFrame([{'Model': best_mn, 'N_Features': len(final_feats),
662
+ 'Macro_AUC': metrics['Macro_AUC'], 'Accuracy': metrics['Accuracy'],
663
+ 'Macro_F1': metrics['Macro_F1'], 'Weighted_F1': metrics['Weighted_F1'],
664
+ 'Kappa': metrics['Kappa']}]).to_excel(w, sheet_name='Metrics', index=False)
665
+ rpt = pd.DataFrame(metrics['report']).T
666
+ rpt.to_excel(w, sheet_name='Per_Class', index=True)
667
+ pd.DataFrame({'Feature': final_feats}).to_excel(w, sheet_name='Features', index=False)
668
+
669
+ # ── Save Results ──
670
+ progress(0.92, desc="💾 保存结果...")
671
+ log(f"\n 💾 保存结果...")
672
+
673
+ with pd.ExcelWriter(os.path.join(rf, 'model_evaluation.xlsx'), engine='openpyxl') as w:
674
+ for mn, r in amr.items():
675
+ r['fold_df'].to_excel(w, sheet_name=mn, index=False)
676
+ # Summary with retained status
677
+ sd = [{'Model': mn, 'Macro_AUC': r['mean_auc'], 'Accuracy': r['mean_acc'],
678
+ 'Macro_F1': r['mean_f1'], 'Retained': 'Yes' if mn in retained else 'No',
679
+ 'Best': 'Best' if mn == best_mn else ''}
680
+ for mn, r in amr.items()]
681
+ pd.DataFrame(sd).sort_values('Macro_AUC', ascending=False).to_excel(w, sheet_name='Summary', index=False)
682
+ # Bootstrap test results
683
+ if len(bootstrap_df) > 0:
684
+ bootstrap_df.to_excel(w, sheet_name='Bootstrap_Test', index=False)
685
+ # Per-class report for best model
686
+ best_report = classification_report(amr[best_mn]['all_yt'], amr[best_mn]['all_yp'],
687
+ labels=class_indices, output_dict=True, zero_division=0)
688
+ pd.DataFrame(best_report).T.to_excel(w, sheet_name=f'{best_mn}_PerClass', index=True)
689
+
690
+ if ablation_data:
691
+ with pd.ExcelWriter(os.path.join(rf, 'feature_ablation.xlsx'), engine='openpyxl') as w:
692
+ pd.DataFrame({'N': ablation_data['fcs'], 'AUC': ablation_data['aucs']}).to_excel(w, sheet_name='Ablation', index=False)
693
+ for mn, idf in shap_imp.items():
694
+ idf.to_excel(w, sheet_name=f'{mn}_Imp', index=False)
695
+
696
+ # Save params (English for SCI)
697
+ with open(os.path.join(rf, 'best_params.txt'), 'w', encoding='utf-8') as f:
698
+ f.write(f"Task: {task_type} Classification ({n_classes} classes)\n")
699
+ f.write(f"Classes: {classes}\n")
700
+ f.write(f"Label Mapping: {label_map}\n\n")
701
+ f.write(f"Statistical Test: Bootstrap AUC Test (n=2000, alpha=0.05)\n")
702
+ f.write(f"Retained Models: {', '.join(retained)} ({len(retained)}/{nm})\n\n")
703
+ for mn in mcfg:
704
+ status = "* Best" if mn == best_mn else ("Retained" if mn in retained else "Excluded")
705
+ f.write(f"Model: {mn} | AUC={amr[mn]['mean_auc']:.4f} | {status}\n")
706
+ bp = bpd[mn]
707
+ if isinstance(bp, dict):
708
+ for k, v in bp.items(): f.write(f" {k}: {v}\n")
709
+ else: f.write(f" {bp}\n")
710
+ f.write("\n")
711
+ if len(bootstrap_df) > 0:
712
+ f.write("\n" + "=" * 50 + "\n")
713
+ f.write("Bootstrap AUC Comparison Results\n")
714
+ f.write("=" * 50 + "\n")
715
+ for _, row in bootstrap_df.iterrows():
716
+ f.write(f" {row['Model_A']} vs {row['Model_B']}: ")
717
+ f.write(f"dAUC={row['AUC_Diff']:+.4f} 95%CI=[{row['CI_95_Low']:+.4f},{row['CI_95_High']:+.4f}] ")
718
+ f.write(f"P={row['P_value']:.4f} -> {row['Decision']}\n")
719
+ if ablation_data:
720
+ f.write(f"\nOptimal Features ({ablation_data['opt_n']}): {', '.join(ablation_data['opt_feats'])}\n")
721
+
722
+ # Save model
723
+ pickle.dump({
724
+ 'model_name': best_mn, 'model': tms[best_mn], 'best_params': bpd[best_mn],
725
+ 'classes': classes, 'n_classes': n_classes, 'label_map': label_map,
726
+ 'features': final_feats, 'task_type': task_type
727
+ }, open(os.path.join(rf, f'model_{best_mn}.pkl'), 'wb'))
728
+
729
+ # ── ZIP ──
730
+ progress(0.97, desc="📦 打包ZIP...")
731
+ zp = os.path.join(tempfile.gettempdir(), f"ml_results_{int(time.time())}_{os.getpid()}.zip")
732
+ with zipfile.ZipFile(zp, 'w', zipfile.ZIP_DEFLATED) as zf:
733
+ for root, _, files in os.walk(rf):
734
+ for fn in files: zf.write(os.path.join(root, fn), os.path.relpath(os.path.join(root, fn), rf))
735
+
736
+ nf = sum(len(f) for _, _, f in os.walk(rf))
737
+ shutil.rmtree(rf, ignore_errors=True); gc.collect()
738
+
739
+ log(f"\n{'━'*50}")
740
+ log(f" 🎉 分析完成!共 {nf} 个文件已打包")
741
+ log(f" 📋 Task: {task_type} | Best Model: {best_mn}")
742
+ log(f"{'━'*50}")
743
+ progress(1.0, desc="✅ 完成!")
744
+ return zp, "\n".join(L)
745
+
746
+ except Exception as e:
747
+ log(f"\n❌ 错误: {e}")
748
+ log(traceback.format_exc())
749
+ if os.path.exists(rf): shutil.rmtree(rf, ignore_errors=True)
750
+ gc.collect()
751
+ return None, "\n".join(L)
752
+
753
+
754
+ # ============================================================================
755
+ # Gradio UI
756
+ # ============================================================================
757
# Stylesheet injected into the page through gr.Blocks(css=CUSTOM_CSS).
# Selectors and where they are used in the layout below:
#   .header-banner   - top banner (logo, title, credit line) rendered via gr.HTML
#   .section-title   - blue heading strip placed above each group of controls
#   .pipeline-box    - light-blue box describing the workflow and the CSV format
#   .log-area        - dark monospace styling for the run-log textbox
#                      (attached with elem_classes="log-area")
#   .gradio-container / footer - cap the page width and hide the page footer
CUSTOM_CSS = """
.header-banner {
background: linear-gradient(135deg, #0a2463 0%, #1e3a7a 40%, #2554a8 100%);
border-radius: 16px; padding: 28px 36px; margin-bottom: 20px;
box-shadow: 0 8px 32px rgba(0,0,0,0.18); position: relative; overflow: hidden;
}
.header-banner::before {
content: ''; position: absolute; top: -50%; right: -20%;
width: 400px; height: 400px;
background: radial-gradient(circle, rgba(96,165,250,0.2) 0%, transparent 70%);
border-radius: 50%;
}
.header-banner img { max-height: 52px; border-radius: 6px; margin-bottom: 12px; }
.header-banner h1 { color: #e2e8f0 !important; font-size: 1.7em !important; margin: 4px 0 6px 0 !important; font-weight: 700 !important; }
.header-banner p { color: #94a3b8 !important; font-size: 0.92em !important; margin: 2px 0 !important; line-height: 1.6; }
.header-banner .credit { color: #64748b !important; font-size: 0.82em !important; margin-top: 10px !important; border-top: 1px solid rgba(148,163,184,0.15); padding-top: 10px; }
.section-title {
background: linear-gradient(90deg, #2563eb 0%, #3b82f6 100%);
color: white !important; padding: 8px 16px; border-radius: 8px;
font-size: 0.95em !important; font-weight: 600 !important; margin: 12px 0 8px 0;
}
.pipeline-box {
background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
border: 1px solid #bae6fd; border-radius: 12px; padding: 14px 18px; margin: 8px 0; font-size: 0.88em;
}
.pipeline-box code { background: #2563eb; color: white; padding: 2px 8px; border-radius: 4px; font-size: 0.85em; margin: 0 2px; }
.log-area textarea {
font-family: 'Menlo','Consolas',monospace !important; font-size: 12.5px !important; line-height: 1.5 !important;
background: #0f172a !important; color: #e2e8f0 !important; border-radius: 10px !important; padding: 16px !important;
}
.gradio-container { max-width: 1280px !important; }
footer { display: none !important; }
"""
790
+
791
# Build the single-page UI. The left column collects the inputs (data files,
# class count, model choice, run parameters); the right column shows the run
# log and the downloadable result archive.
# NOTE(review): the nesting below was reconstructed from a whitespace-collapsed
# copy of this file - confirm the Row/Column membership against the original.
with gr.Blocks(
    title="ML 多分类模型平台 — 复旦大学附属眼耳鼻喉科医院",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate", neutral_hue="slate"),
    css=CUSTOM_CSS,
) as demo:

    # Header banner; the logo <img> hides itself (onerror) if it fails to load.
    gr.HTML("""
    <div class="header-banner">
    <img src="https://huggingface.co/spaces/fudan-renjun/machine-learning-2/resolve/main/hospital_logo.png"
    alt="Logo" onerror="this.style.display='none'"/>
    <h1>🧬 ML 多分类模型训练与评估平台</h1>
    <p>支持 2~8 分类 · 上传 CSV 即可完成全流程分析</p>
    <p class="credit">复旦大学附属眼耳鼻喉科医院 · 检验科 · 任俊</p>
    </div>
    """)

    # Workflow overview plus the expected CSV layout
    # (column 1 = integer label, column 2 = ID, columns 3+ = features).
    gr.HTML("""
    <div class="pipeline-box">
    <strong>📋 流程:</strong>
    <code>选择分类数</code> → <code>模型训练</code> → <code>交叉验证</code> →
    <code>SHAP分析</code> → <code>特征消融</code> → <code>外部验证</code>
    &nbsp;&nbsp;|&nbsp;&nbsp;
    <strong>CSV格式:</strong> 第1列=标签(整数), 第2列=ID, 第3列起=特征
    &nbsp;&nbsp;|&nbsp;&nbsp;
    <strong>支持:</strong> 2分类 / 3分类 / ... / 8分类
    </div>
    """)

    with gr.Row(equal_height=False):
        # ── Left column: inputs ──
        with gr.Column(scale=5):
            # Data upload: one required training CSV, up to three optional
            # validation CSVs.
            gr.HTML('<div class="section-title">📂 数据上传</div>')
            train_file = gr.File(label="训练集 CSV(必需)", file_types=[".csv"])
            gr.HTML('<p style="color:#64748b;font-size:0.85em;margin:4px 0 8px 0;">验证集可选,支持同时上传 1~3 个,分别生成独立报告</p>')
            with gr.Row():
                val_file1 = gr.File(label="验证集 1(可选)", file_types=[".csv"], scale=1)
                val_file2 = gr.File(label="验证集 2(可选)", file_types=[".csv"], scale=1)
                val_file3 = gr.File(label="验证集 3(可选)", file_types=[".csv"], scale=1)

            # Number of classes (2-8); the selected label string is what gets
            # passed to run_pipeline.
            gr.HTML('<div class="section-title">🏷️ 分类设置</div>')
            n_classes_select = gr.Dropdown(
                choices=["2 类(二分类)", "3 类", "4 类", "5 类", "6 类", "7 类", "8 类"],
                value="2 类(二分类)",
                label="选择分类数",
                info="请根据数据标签列的类别数选择,系统将自动验证是否匹配",
            )

            # Model selection: multiselect over ALL_MODEL_NAMES, all selected
            # by default.
            gr.HTML('<div class="section-title">🤖 模型选择</div>')
            model_selector = gr.Dropdown(
                choices=ALL_MODEL_NAMES, value=ALL_MODEL_NAMES, multiselect=True,
                label="选择模型(均支持多分类)",
                info="RF=随机森林 DT=决策树 KNN=K近邻 XGB=XGBoost AdaBoost LR=逻辑回归 NB=朴素贝叶斯 SVM=支持向量机",
            )
            with gr.Row():
                btn_all = gr.Button("🔘 全选", size="sm", variant="secondary")
                btn_tree = gr.Button("🌲 树模型", size="sm", variant="secondary")
                btn_linear = gr.Button("📐 线性模型", size="sm", variant="secondary")
                btn_top4 = gr.Button("⚡ 经典四模型", size="sm", variant="secondary")
            # Preset buttons: each overwrites the dropdown value with a fixed
            # list of model names.
            btn_all.click(lambda: ALL_MODEL_NAMES, outputs=model_selector)
            btn_tree.click(lambda: ['RF','DT','XGB','AdaBoost'], outputs=model_selector)
            btn_linear.click(lambda: ['LR','SVM','NB'], outputs=model_selector)
            btn_top4.click(lambda: ['RF','XGB','LR','SVM'], outputs=model_selector)

            # Run parameters: optional grid-search tuning (off by default),
            # CV folds, number of top SHAP features, SHAP sample size.
            gr.HTML('<div class="section-title">⚙️ 参数配置</div>')
            enable_tuning = gr.Checkbox(value=False, label="启用超参数调优 (GridSearchCV) ⚠️ 开启后运行时间显著增加")
            with gr.Row():
                cv_folds = gr.Slider(3, 10, value=5, step=1, label="交叉验证折数")
                top_n = gr.Slider(5, 50, value=20, step=1, label="SHAP 前 N 个特征")
                shap_sz = gr.Slider(30, 200, value=80, step=10, label="SHAP 采样数量")

            run_btn = gr.Button("🚀 开始分析", variant="primary", size="lg")

        # ── Right column: outputs ──
        with gr.Column(scale=5):
            gr.HTML('<div class="section-title">📋 运行日志</div>')
            log_output = gr.Textbox(
                label="", lines=24, max_lines=50, interactive=False,
                placeholder="点击「开始分析」后,日志将在此显示...\n支持 2~8 分类。",
                elem_classes="log-area",
            )
            gr.HTML('<div class="section-title">⬇️ 结果下载</div>')
            zip_output = gr.File(label="分析结果 ZIP 压缩包")

    # Wire the run button to the pipeline. run_pipeline receives the ten inputs
    # in the order listed and its two return values fill the ZIP file
    # component and the log textbox, in that order.
    run_btn.click(
        fn=run_pipeline,
        inputs=[train_file, val_file1, val_file2, val_file3, n_classes_select,
                model_selector, enable_tuning, cv_folds, top_n, shap_sz],
        outputs=[zip_output, log_output],
        api_name="run",
    )
879
+
880
+ # ============================================================================
881
+ # Authentication
882
+ # ============================================================================
883
+ from datetime import datetime
884
+
885
# Login table read by auth_fn: username -> {"password", "expires"}.
#   "password": compared directly against what the user types (plaintext).
#   "expires":  None for no expiry, otherwise a "YYYY-MM-DD" string that
#               auth_fn parses with "%Y-%m-%d".
# NOTE(review): these are plaintext credentials committed with the source;
# anyone who can read this file can log in. Consider loading them from
# environment variables / deployment secrets instead.
ACCOUNTS = {
    "admin": {"password": "admin123", "expires": None},
    "renjun": {"password": "fudan2025", "expires": "2026-12-31"},
    "guest": {"password": "guest888", "expires": "2025-06-30"},
}
890
+
891
def auth_fn(username, password, accounts=None, now=None):
    """Decide whether a username/password pair may log in.

    Passed to ``demo.launch(auth=auth_fn)``, which calls it with the two
    positional arguments only; ``accounts`` and ``now`` exist so the check can
    be exercised deterministically.

    Args:
        username: Login name typed by the user.
        password: Password typed by the user.
        accounts: Mapping ``username -> {"password": str, "expires": str | None}``.
            Defaults to the module-level ``ACCOUNTS`` table.
        now: Naive ``datetime`` treated as the current time. Defaults to
            ``datetime.now()``.

    Returns:
        ``True`` when the account exists, the password matches and the account
        has not expired; ``False`` otherwise. An ``"expires"`` value that
        cannot be parsed as ``YYYY-MM-DD`` is treated as expired (fail closed).
    """
    import hmac  # local import keeps this function self-contained

    if accounts is None:
        accounts = ACCOUNTS
    if now is None:
        now = datetime.now()

    user = accounts.get(username)
    if not user:
        return False

    stored = user.get("password")
    if not isinstance(stored, str) or not isinstance(password, str):
        return False
    # Constant-time comparison; encode first because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(password.encode("utf-8"), stored.encode("utf-8")):
        return False

    expires = user.get("expires")
    if expires:
        try:
            deadline = datetime.strptime(expires, "%Y-%m-%d")
        except (ValueError, TypeError):
            # Malformed expiry date: refuse the login rather than guess.
            return False
        # The deadline is midnight at the start of the expiry date, so the
        # account stops working as soon as that day begins.
        if now > deadline:
            return False
    return True
899
+
900
# Turn on the request queue, then serve on all interfaces at port 7860.
# Every visitor must pass auth_fn before reaching the UI; server-side
# rendering is switched off.
demo.queue().launch(
    server_name="0.0.0.0",
    server_port=7860,
    auth=auth_fn,
    auth_message="🔐 复旦大学附属眼耳鼻喉科医院 · ML多分类分析平台\n请输入账号和密码登录",
    ssr_mode=False,
)