fudan-renjun commited on
Commit
fc27abe
·
verified ·
1 Parent(s): 3116aa0

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +921 -0
app.py ADDED
@@ -0,0 +1,921 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ML Binary Classification Pipeline
3
+ Eye & ENT Hospital of Fudan University — Laboratory Medicine, Ren Jun
4
+ Gradio 5.12.0 + Python 3.11
5
+ """
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import matplotlib
10
+ matplotlib.use('Agg')
11
+ import matplotlib.pyplot as plt
12
+ from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
13
+ from sklearn.tree import DecisionTreeClassifier
14
+ from sklearn.neighbors import KNeighborsClassifier
15
+ from sklearn.linear_model import LogisticRegression
16
+ from sklearn.naive_bayes import GaussianNB
17
+ from sklearn.svm import SVC
18
+ from xgboost import XGBClassifier
19
+ from sklearn.model_selection import StratifiedKFold, GridSearchCV
20
+ from sklearn.metrics import (
21
+ roc_auc_score, confusion_matrix, roc_curve,
22
+ auc as auc_score, precision_recall_curve
23
+ )
24
+ import seaborn as sns
25
+ import warnings
26
+ from scipy import stats
27
+ import os
28
+ import shap
29
+ import pickle
30
+ from copy import deepcopy
31
+ import zipfile
32
+ import tempfile
33
+ import traceback
34
+ import time
35
+ import shutil
36
+ import gc
37
+ import threading
38
+ import gradio as gr
39
+
40
# Silence every warning category for the whole process.
warnings.filterwarnings('ignore')
# Use the DejaVu Sans font for sans-serif text in all figures.
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
# Draw the minus sign on axes with an ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
43
+
44
+ # ============================================================================
45
+ # Cache Cleanup System
46
+ # ============================================================================
47
# Temporary ml_* artefacts older than this many minutes are deleted automatically.
CLEANUP_MAX_AGE_MINUTES = 30
# Pause between two background cleanup passes (10 minutes, in seconds).
CLEANUP_INTERVAL_SECONDS = 10 * 60
# When ml_* artefacts in the temp directory exceed this many MB (1 GB) they are force-deleted.
CLEANUP_MAX_DISK_MB = 1 * 1024
50
+
51
def cleanup_old_temp_files():
    """Delete expired ``ml_*`` result folders and ZIP archives from the temp dir.

    An entry is expired when its modification time is more than
    ``CLEANUP_MAX_AGE_MINUTES`` minutes in the past. Directories whose name
    starts with ``ml_`` are removed recursively; regular files whose name
    starts with ``ml_`` and ends with ``.zip`` are unlinked.

    Errors are handled per entry, so one entry that disappears or cannot be
    read (for example because another request removed it concurrently) no
    longer aborts the cleanup of all remaining entries.

    Returns:
        None. A summary line is printed when at least one folder was removed.
    """

    def _dir_size(path):
        # Best-effort size of a directory tree in bytes; files that vanish
        # while walking are simply not counted.
        total = 0
        for root, _, files in os.walk(path):
            for name in files:
                try:
                    total += os.path.getsize(os.path.join(root, name))
                except OSError:
                    continue
        return total

    now = time.time()
    max_age = CLEANUP_MAX_AGE_MINUTES * 60
    cleaned_dirs = 0
    cleaned_mb = 0.0
    tmp_dir = tempfile.gettempdir()

    try:
        items = os.listdir(tmp_dir)
    except OSError:
        items = []

    for item in items:
        if not item.startswith("ml_"):
            continue
        item_path = os.path.join(tmp_dir, item)
        try:
            if os.path.isdir(item_path):
                # Expired result folder.
                if now - os.path.getmtime(item_path) <= max_age:
                    continue
                size = _dir_size(item_path)
                shutil.rmtree(item_path, ignore_errors=True)
                cleaned_dirs += 1
            elif item.endswith(".zip") and os.path.isfile(item_path):
                # Expired ZIP result archive.
                if now - os.path.getmtime(item_path) <= max_age:
                    continue
                size = os.path.getsize(item_path)
                os.remove(item_path)
            else:
                continue
        except OSError:
            # Entry vanished or is not accessible: skip it, keep going.
            continue
        cleaned_mb += size / (1024 * 1024)

    # Ask Python to release memory that is no longer referenced.
    gc.collect()

    if cleaned_dirs > 0:
        print(f"[Cleanup] 清理 {cleaned_dirs} 个临时文件夹, 释放 {cleaned_mb:.1f} MB")
86
+
87
+
88
def check_disk_pressure():
    """Force-delete all ``ml_*`` temp artefacts when they exceed the disk budget.

    Sums the size of every file or directory in the temp dir whose name starts
    with ``ml_``. If the total is above ``CLEANUP_MAX_DISK_MB`` megabytes,
    every such entry is removed regardless of its age and a garbage
    collection pass is triggered.

    Failures are handled per entry (``OSError`` only) instead of with a bare
    ``except``, and a failing directory listing makes the function return
    quietly rather than raise into the background thread.

    Returns:
        None.
    """

    def _entry_size_mb(path):
        # Size of a file, or of a whole directory tree, in megabytes.
        if os.path.isdir(path):
            total = 0
            for root, _, files in os.walk(path):
                for name in files:
                    try:
                        total += os.path.getsize(os.path.join(root, name))
                    except OSError:
                        continue
            return total / (1024 * 1024)
        if os.path.isfile(path):
            return os.path.getsize(path) / (1024 * 1024)
        return 0.0

    tmp_dir = tempfile.gettempdir()
    try:
        paths = [os.path.join(tmp_dir, item)
                 for item in os.listdir(tmp_dir) if item.startswith("ml_")]
    except OSError:
        return

    total_mb = 0
    for item_path in paths:
        try:
            total_mb += _entry_size_mb(item_path)
        except OSError:
            continue

    if total_mb > CLEANUP_MAX_DISK_MB:
        print(f"[Cleanup] 磁盘占用 {total_mb:.0f}MB > {CLEANUP_MAX_DISK_MB}MB, 强制清理!")
        for item_path in paths:
            try:
                if os.path.isdir(item_path):
                    shutil.rmtree(item_path, ignore_errors=True)
                elif os.path.isfile(item_path):
                    os.remove(item_path)
            except OSError:
                # Best effort: the entry may already be gone.
                continue
        gc.collect()
114
+
115
+
116
def periodic_cleanup():
    """Run the two cleanup passes forever, sleeping between rounds.

    Intended as the target of a background thread. Each round sleeps
    ``CLEANUP_INTERVAL_SECONDS`` seconds, then calls
    ``cleanup_old_temp_files`` and ``check_disk_pressure``. Each task is
    guarded separately so that an unexpected exception in one of them is
    reported and does not terminate the loop (and with it all future cleanup).
    """
    while True:
        time.sleep(CLEANUP_INTERVAL_SECONDS)
        for task in (cleanup_old_temp_files, check_disk_pressure):
            try:
                task()
            except Exception as exc:
                print(f"[Cleanup] 清理任务异常 ({task.__name__}): {exc}")
122
+
123
# Start the background cleanup thread when the module is imported. It is a
# daemon thread, so it does not keep the interpreter alive on shutdown.
_cleanup_thread = threading.Thread(target=periodic_cleanup, daemon=True)
_cleanup_thread.start()
# The banner is built from the constants so it cannot drift from the real
# settings (it previously announced a 500MB cap while the limit is 1024 MB).
print(
    f"[Cleanup] 后台自动清理已启动 (每{CLEANUP_INTERVAL_SECONDS // 60}分钟检查, "
    f"{CLEANUP_MAX_AGE_MINUTES}分钟过期, 上限{CLEANUP_MAX_DISK_MB}MB)"
)
127
+
128
+ # ============================================================================
129
+ # Helper Functions
130
+ # ============================================================================
131
def compute_midrank(x):
    """Return the 1-based mid-ranks of the values in ``x``.

    Tied values all receive the average of the ranks they span, e.g.
    ``[1, 1, 2]`` -> ``[1.5, 1.5, 3.0]``.

    Args:
        x: 1-D NumPy array of comparable values.

    Returns:
        A float array of the same length as ``x`` holding the mid-rank of
        each element, in the original element order.
    """
    order = np.argsort(x)
    sorted_vals = x[order]
    count = len(x)
    sorted_ranks = np.zeros(count, dtype=float)
    start = 0
    while start < count:
        # Scan from start + 1 so every group holds at least one element.
        # Starting the scan at ``start`` never advances on NaN (NaN != NaN)
        # and would loop forever.
        stop = start + 1
        while stop < count and sorted_vals[stop] == sorted_vals[start]:
            stop += 1
        # 0-based average position of the tie group [start, stop).
        sorted_ranks[start:stop] = 0.5 * (start + stop - 1)
        start = stop
    midranks = np.empty(count, dtype=float)
    # Scatter back to the input order and switch to 1-based ranks.
    midranks[order] = sorted_ranks + 1
    return midranks
140
+
141
def fastDeLong(pst, m):
    """Compute AUCs and their DeLong covariance for several score vectors.

    Args:
        pst: 2-D array with one row of scores per model; in every row the
            first ``m`` columns are the positive samples and the remaining
            columns the negative samples.
        m: number of positive samples.

    Returns:
        Tuple ``(aucs, covariance)``: the AUC of each row and the covariance
        estimate ``cov(v01) / m + cov(v10) / n`` of those AUCs, where ``n``
        is the number of negative samples.
    """
    n_rows = pst.shape[0]
    n = pst.shape[1] - m

    pos_ranks = np.empty([n_rows, m])
    neg_ranks = np.empty([n_rows, n])
    all_ranks = np.empty([n_rows, m + n])
    for row_idx in range(n_rows):
        scores = pst[row_idx]
        pos_ranks[row_idx] = compute_midrank(scores[:m])
        neg_ranks[row_idx] = compute_midrank(scores[m:])
        all_ranks[row_idx] = compute_midrank(scores)

    pooled_pos = all_ranks[:, :m]
    pooled_neg = all_ranks[:, m:]
    aucs = pooled_pos.sum(1) / m / n - (m + 1.0) / 2.0 / n
    v01 = (pooled_pos - pos_ranks) / n
    v10 = 1.0 - (pooled_neg - neg_ranks) / m
    covariance = np.cov(v01) / m + np.cov(v10) / n
    return aucs, covariance
150
+
151
def delong_roc_test(gt, p1, p2):
    """Two-sided DeLong test for the difference between two correlated AUCs.

    Args:
        gt: array of ground-truth labels coded 0/1 (the sum of ``gt`` is used
            as the number of positives).
        p1: scores of the first model, aligned with ``gt``.
        p2: scores of the second model, aligned with ``gt``.

    Returns:
        Tuple ``(p_value, auc1, auc2)``.

        When the estimated variance of the AUC difference is not positive
        (for example when both models output identical scores) the z
        statistic is undefined. In that case the p-value is ``1.0`` if the
        AUCs are equal and ``0.0`` otherwise, instead of the NaN that a 0/0
        division would produce.
    """
    gt = np.asarray(gt)
    # Put the positive samples first, as fastDeLong expects.
    order = (-gt).argsort()
    m = int(gt.sum())
    pst = np.vstack([p1, p2])[:, order]
    aucs, cov = fastDeLong(pst, m)

    contrast = np.array([[1, -1]])
    variance = np.dot(np.dot(contrast, cov), contrast.T).item()
    diff = abs(float(aucs[0] - aucs[1]))

    if variance <= 0:
        p_value = 1.0 if diff == 0 else 0.0
    else:
        z = diff / np.sqrt(variance)
        # log-space evaluation of 2 * sf(z) keeps precision for large z.
        p_value = float(np.exp(np.log(2) + stats.norm.logsf(z, 0, 1)))
    return p_value, aucs[0], aucs[1]
159
+
160
def find_optimal_threshold(y_true, y_probs, method='youden'):
    """Pick the ROC threshold that maximises Youden's J (TPR - FPR).

    Args:
        y_true: ground-truth binary labels.
        y_probs: predicted scores for the positive class.
        method: accepted for interface compatibility; the body does not
            read it and always applies the Youden criterion.

    Returns:
        Tuple ``(threshold, youden_j, index)`` for the best point of the
        curve returned by ``roc_curve``.
    """
    false_pos_rate, true_pos_rate, cutoffs = roc_curve(y_true, y_probs)
    youden_j = true_pos_rate - false_pos_rate
    best = np.argmax(youden_j)
    return cutoffs[best], youden_j[best], best
164
+
165
def calculate_net_benefit(y_true, y_probs, threshold):
    """Net benefit of treating every sample whose score is >= ``threshold``.

    Computes ``TP/n - FP/n * threshold / (1 - threshold)``.

    True and false positives are counted directly with NumPy, so the
    function also works when ``y_true`` holds a single class; unpacking a
    confusion matrix into four cells fails in that case.

    Args:
        y_true: binary labels coded 0/1.
        y_probs: predicted scores, aligned with ``y_true``.
        threshold: threshold probability, ``0 <= threshold < 1``.

    Returns:
        The net benefit as a float.

    Raises:
        ValueError: if ``threshold`` is outside ``[0, 1)`` (the odds weight
            would divide by zero at 1) or if ``y_true`` is empty.
    """
    if not 0 <= threshold < 1:
        raise ValueError(f"threshold must be in [0, 1), got {threshold!r}")
    truth = np.asarray(y_true)
    n = len(truth)
    if n == 0:
        raise ValueError("y_true must not be empty")
    flagged = np.asarray(y_probs) >= threshold
    tp = np.count_nonzero(flagged & (truth == 1))
    fp = np.count_nonzero(flagged & (truth == 0))
    return (tp / n) - (fp / n) * (threshold / (1 - threshold))
170
+
171
def plot_dca(y_true, y_probs_dict, title, save_prefix, result_dir, final_model=None):
    """Draw a clinical decision-curve-analysis (DCA) chart and save it.

    The chart shows the "Treat All" and "Treat None" reference strategies
    plus one net-benefit curve per model (layout similar to the R ``rmda``
    package). It is written to ``<result_dir>/<save_prefix>.pdf`` (300 dpi)
    and ``<result_dir>/<save_prefix>.png`` (150 dpi).

    The figure is handled through explicit ``fig``/``ax`` objects and closed
    in a ``finally`` block, so it is released even when plotting or saving
    raises, and the function does not depend on pyplot's implicit current
    figure.

    Args:
        y_true: binary labels coded 0/1.
        y_probs_dict: mapping of model name -> predicted scores aligned with
            ``y_true``.
        title: chart title.
        save_prefix: output file name without extension.
        result_dir: directory that receives the two files.
        final_model: optional model name; its legend entry gets a
            ``(Final)`` suffix.

    Returns:
        None.
    """
    prevalence = np.mean(y_true)
    # Upper end of the threshold axis: wider for rare outcomes, fixed at 0.9
    # once at least half of the samples are positive.
    max_thr = min(0.99, max(prevalence * 3, 0.6)) if prevalence < 0.5 else 0.9
    thresholds = np.linspace(0.01, max_thr, 200)

    # Net benefit of treating everybody at each threshold.
    treat_all = [prevalence - (1 - prevalence) * (pt / (1 - pt)) for pt in thresholds]
    palette = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3',
               '#ff7f00', '#a65628', '#f781bf', '#999999']

    fig, ax = plt.subplots(figsize=(10, 7))
    try:
        ax.plot(thresholds, treat_all, 'k-', lw=1.5, label='Treat All')
        # Treating nobody always has a net benefit of zero.
        ax.axhline(y=0, color='#555555', lw=1.5, linestyle='-', label='Treat None')

        for idx, (model_name, probs) in enumerate(y_probs_dict.items()):
            benefits = [calculate_net_benefit(y_true, probs, t) for t in thresholds]
            label = f'{model_name} (Final)' if model_name == final_model else model_name
            ax.plot(thresholds, benefits, color=palette[idx % len(palette)], lw=2, label=label)

        # Clip the y-axis to the clinically relevant range.
        y_min = max(min(treat_all), -0.05) - 0.01
        y_max = max(prevalence * 1.5, 0.15)
        ax.set_xlim([0, max_thr])
        ax.set_ylim([y_min, y_max])
        ax.set_xlabel('Threshold Probability', fontsize=13)
        ax.set_ylabel('Net Benefit', fontsize=13)
        ax.set_title(title, fontsize=15, fontweight='bold')
        ax.legend(loc='upper right', fontsize=10, framealpha=0.9)
        ax.grid(True, alpha=0.15)
        fig.tight_layout()
        fig.savefig(os.path.join(result_dir, f'{save_prefix}.pdf'), format='pdf', bbox_inches='tight', dpi=300)
        fig.savefig(os.path.join(result_dir, f'{save_prefix}.png'), format='png', bbox_inches='tight', dpi=150)
    finally:
        plt.close(fig)
205
+
206
+ # ============================================================================
207
+ # Model configs
208
+ # ============================================================================
209
# Short model keys; the same eight keys are used by the configuration
# table built in get_models_config.
ALL_MODEL_NAMES = ['RF', 'DT', 'KNN', 'XGB', 'AdaBoost', 'LR', 'NB', 'SVM']
210
+
211
def get_models_config(selected, rs=42):
    """Return estimator and hyper-parameter grid for each selected model.

    Args:
        selected: container of model keys (``'RF'``, ``'DT'``, ``'KNN'``,
            ``'XGB'``, ``'AdaBoost'``, ``'LR'``, ``'NB'``, ``'SVM'``); keys
            not present in the table are ignored.
        rs: random seed passed to every estimator that accepts one.

    Returns:
        Dict ``{key: {'model': estimator, 'params': grid}}`` restricted to
        the selected keys, in the fixed order of the table below.
    """
    registry = [
        ('RF',
         RandomForestClassifier(random_state=rs, n_jobs=-1),
         {'n_estimators': [100, 200], 'max_depth': [20, 50],
          'min_samples_split': [2, 5], 'max_features': ['sqrt']}),
        ('DT',
         DecisionTreeClassifier(random_state=rs),
         {'max_depth': [20, 50], 'min_samples_split': [2, 10],
          'min_samples_leaf': [1, 4], 'criterion': ['gini', 'entropy']}),
        ('KNN',
         KNeighborsClassifier(n_jobs=-1),
         {'n_neighbors': [3, 5, 7], 'weights': ['uniform', 'distance'],
          'metric': ['euclidean', 'manhattan']}),
        ('XGB',
         XGBClassifier(random_state=rs, eval_metric='logloss', n_jobs=-1),
         {'n_estimators': [100, 200], 'max_depth': [5, 7],
          'learning_rate': [0.05, 0.1], 'subsample': [0.8, 1.0],
          'colsample_bytree': [0.8, 1.0]}),
        ('AdaBoost',
         AdaBoostClassifier(random_state=rs),
         {'n_estimators': [50, 100], 'learning_rate': [0.1, 0.5, 1.0]}),
        ('LR',
         LogisticRegression(random_state=rs, n_jobs=-1, max_iter=2000),
         {'C': [0.1, 1, 10], 'penalty': ['l2'], 'solver': ['lbfgs', 'liblinear']}),
        ('NB',
         GaussianNB(),
         {'var_smoothing': [1e-9, 1e-7, 1e-5]}),
        ('SVM',
         SVC(probability=True, random_state=rs),
         {'C': [1, 10], 'kernel': ['rbf', 'linear'], 'gamma': ['scale', 'auto']}),
    ]
    return {
        key: {'model': estimator, 'params': grid}
        for key, estimator, grid in registry
        if key in selected
    }
231
+
232
+ # ============================================================================
233
+ # Main Pipeline with Progress
234
+ # ============================================================================
235
def run_pipeline(
    train_file, val_file1, val_file2, val_file3, selected_models, enable_tuning,
    cv_folds, alpha, top_n_features, shap_sample_size,
    progress=gr.Progress(track_tqdm=True),
):
    """Train, compare and package binary classifiers for the Gradio UI.

    Stages: load the training CSV; optionally grid-search each selected
    model; fit and evaluate every model with stratified K-fold CV; draw ROC,
    PR and confusion-matrix figures; run DeLong tests against the best model
    to decide which models are retained; compute SHAP importances; run a
    feature-ablation study; pick the retained model needing the fewest
    features as the final model; draw DCA curves; optionally evaluate the
    final model on up to three external validation sets; write Excel / text /
    pickle outputs; and pack everything into a ZIP archive.

    Args:
        train_file: path string, or an object exposing ``.name``, of the
            training CSV. Column 0 is the label and columns from index 2
            onward are the features (column 1 is not used — presumably an
            ID column; confirm with the data owner).
        val_file1, val_file2, val_file3: optional validation CSVs with the
            same layout; ``None`` entries are skipped.
        selected_models: list of model keys or a comma-separated string.
        enable_tuning: truthy to run ``GridSearchCV`` per model.
        cv_folds: number of stratified folds.
        alpha: significance level for the DeLong tests.
        top_n_features: number of top SHAP features used for plots/ablation.
        shap_sample_size: maximum number of rows sampled for SHAP.
        progress: Gradio progress reporter.

    Returns:
        Tuple ``(zip_path, log_text)``. ``zip_path`` is ``None`` when the
        input is invalid or any stage raises; the log then holds the error
        and traceback.
    """
    if train_file is None:
        return None, "❌ 请先上传训练集 CSV 文件"
    sel = selected_models if isinstance(selected_models, list) else [s.strip() for s in str(selected_models).split(",") if s.strip()]
    if not sel:
        return None, "❌ 请至少选择一个模型"

    RS = 42; CVF = int(cv_folds); ALP = float(alpha)
    TOPN = int(top_n_features); SHAPSZ = int(shap_sample_size)
    TUNING = bool(enable_tuning)

    L = []
    def log(m): L.append(str(m))

    # Working directory for all artefacts; the "ml_" prefix lets the
    # background cleanup recognise leftovers.
    rf = tempfile.mkdtemp(prefix="ml_")

    try:
        # ── Load ──
        progress(0.02, desc="📂 加载数据...")
        log("━" * 50)
        log(" 🧬 ML 二分类模型训练与评估系统")
        log("━" * 50)

        train_path = train_file if isinstance(train_file, str) else getattr(train_file, 'name', str(train_file))
        data = pd.read_csv(train_path)
        X = data.iloc[:, 2:]; y = data.iloc[:, 0]
        fnames = X.columns.tolist()

        # Auto 0/1: any two-valued label column is mapped to 0/1 in sorted
        # order. Reject other class counts up front; one class would raise
        # IndexError below and three or more would silently map to NaN.
        ul = sorted(y.unique())
        if len(ul) != 2:
            raise ValueError(f"标签列必须恰好包含两个类别, 实际为 {len(ul)} 个: {ul}")
        if set(ul) != {0, 1}:
            lm = {ul[0]: 0, ul[1]: 1}; y = y.map(lm)
            log(f" ⚙ 标签已自动转换: {lm}")

        log(f" 📊 训练集: {X.shape[0]} 样本 × {X.shape[1]} 特征")
        log(f" 📊 标签: {dict(y.value_counts())}")
        log(f" 🤖 模型: {', '.join(sel)}")
        log(f" 🔧 调优: {'开启' if TUNING else '关闭'} | CV: {CVF}折")

        mcfg = get_models_config(sel, RS)
        # Fixed seed: every later skf.split call yields the same folds, so
        # pooled CV predictions of different models/feature sets are aligned.
        skf = StratifiedKFold(n_splits=CVF, shuffle=True, random_state=RS)

        # ── Train ──
        bpd = {}; amr = {}; tms = {}
        total = len(mcfg)
        COLORS = ['#2563eb','#f59e0b','#10b981','#ef4444','#8b5cf6','#ec4899','#06b6d4','#6b7280']

        for mi, (mn, cf) in enumerate(mcfg.items()):
            pv = 0.05 + 0.40 * mi / total
            progress(pv, desc=f"🏋️ [{mi+1}/{total}] 训练 {mn}...")
            log(f"\n{'─'*40}")
            log(f" 🔄 [{mi+1}/{total}] {mn}")

            Xv = X.values
            if TUNING:
                log(f" ⏳ GridSearchCV (CV={CVF})...")
                gs = GridSearchCV(cf['model'], cf['params'], cv=skf, scoring='roc_auc', n_jobs=-1, verbose=0)
                gs.fit(Xv, y)
                bp = gs.best_params_; bpd[mn] = bp
                log(f" ✓ 最佳AUC: {gs.best_score_:.4f}")
            else:
                bp = {}; bpd[mn] = "默认参数"

            # Full-feature model fitted on all training rows (used for SHAP).
            mdl = deepcopy(cf['model'])
            if bp: mdl.set_params(**bp)
            mdl.fit(Xv, y)
            tms[mn] = {'model': mdl, 'scaler': None}

            # CV eval
            folds = []; ayt = []; ayp = []; tprs = []
            bfpr = np.linspace(0, 1, 101)
            for fi, (tri, tei) in enumerate(skf.split(X, y), 1):
                Xtr, Xte = X.iloc[tri].values, X.iloc[tei].values
                ytr, yte = y.iloc[tri], y.iloc[tei]
                mf = deepcopy(cf['model'])
                if bp: mf.set_params(**bp)
                mf.fit(Xtr, ytr)
                ypp = mf.predict_proba(Xte)[:, 1]
                ypd = (ypp > 0.5).astype(int)
                tn, fp, fn, tp = confusion_matrix(yte, ypd).ravel()
                se = tp/(tp+fn) if tp+fn else 0; sp = tn/(tn+fp) if tn+fp else 0
                ac = (tp+tn)/(tp+tn+fp+fn); pr = tp/(tp+fp) if tp+fp else 0
                f1 = 2*pr*se/(pr+se) if pr+se else 0
                auc_v = roc_auc_score(yte, ypp)
                folds.append({'Fold': fi, 'AUC': auc_v, 'Accuracy': ac, 'Sensitivity': se,
                              'Specificity': sp, 'Precision': pr, 'F1': f1, 'TP': tp, 'TN': tn, 'FP': fp, 'FN': fn})
                ayt.extend(yte); ayp.extend(ypp)
                fa, ta, _ = roc_curve(yte, ypp)
                # Resample each fold's ROC on a common FPR grid for averaging.
                ti = np.interp(bfpr, fa, ta); ti[0] = 0.0; tprs.append(ti)

            rdf = pd.DataFrame(folds)
            mr = {'Fold': 'Mean', 'AUC': rdf['AUC'].mean(), 'Accuracy': rdf['Accuracy'].mean(),
                  'Sensitivity': rdf['Sensitivity'].mean(), 'Specificity': rdf['Specificity'].mean(),
                  'Precision': rdf['Precision'].mean(), 'F1': rdf['F1'].mean(),
                  'TP': rdf['TP'].sum(), 'TN': rdf['TN'].sum(), 'FP': rdf['FP'].sum(), 'FN': rdf['FN'].sum()}
            rdf = pd.concat([rdf, pd.DataFrame([mr])], ignore_index=True)
            ot, yv, _ = find_optimal_threshold(np.array(ayt), np.array(ayp))
            amr[mn] = {'results_df': rdf, 'mean_auc': mr['AUC'], 'all_y_true': np.array(ayt),
                       'all_y_probs': np.array(ayp), 'tprs': tprs, 'base_fpr': bfpr,
                       'optimal_threshold': ot, 'youden_index': yv}
            log(f" ✅ AUC={mr['AUC']:.4f} Acc={mr['Accuracy']:.4f} 阈值={ot:.4f}")

        mnames = list(amr.keys()); nm = len(mnames)
        log(f"\n{'━'*50}")
        log(f" ✅ {nm} 个模型训练完成")

        # ── ROC ──
        progress(0.48, desc="📈 绘制ROC曲线...")
        log(f"\n 📈 绘制图表...")
        plt.figure(figsize=(12, 10))
        for i, mn in enumerate(mnames):
            r = amr[mn]; mt = np.mean(r['tprs'], axis=0); mt[-1] = 1.0
            ma = auc_score(r['base_fpr'], mt); sa = r['results_df'].iloc[:-1]['AUC'].std()
            st = np.std(r['tprs'], axis=0)
            c = COLORS[i % 8]
            plt.plot(r['base_fpr'], mt, color=c, lw=2.5, alpha=0.85, label=f'{mn} (AUC={ma:.3f}±{sa:.3f})')
            plt.fill_between(r['base_fpr'], np.maximum(mt-st, 0), np.minimum(mt+st, 1), color=c, alpha=0.08)
        plt.plot([0,1],[0,1],'--',lw=2,color='#9ca3af',alpha=0.5)
        plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
        plt.xlabel('False Positive Rate',fontsize=13); plt.ylabel('True Positive Rate',fontsize=13)
        plt.title('ROC Curves — Internal Cross-Validation',fontsize=15,fontweight='bold')
        plt.legend(loc="lower right",fontsize=10); plt.grid(True,alpha=0.2); plt.tight_layout()
        plt.savefig(os.path.join(rf,'roc_all.pdf'),format='pdf',bbox_inches='tight',dpi=300)
        plt.savefig(os.path.join(rf,'roc_all.png'),format='png',bbox_inches='tight',dpi=150)
        plt.close()

        # ── PR ──
        progress(0.52, desc="📈 绘制PR曲线...")
        plt.figure(figsize=(12, 10))
        for i, mn in enumerate(mnames):
            r = amr[mn]; pra = []
            # Per-fold AUPRC (mean ± std in the legend); the plotted curve
            # itself uses the pooled CV predictions.
            for tri, tei in skf.split(X, y):
                cf2 = mcfg[mn]; mpr = deepcopy(cf2['model'])
                bp2 = bpd[mn]
                if isinstance(bp2, dict) and bp2: mpr.set_params(**bp2)
                mpr.fit(X.iloc[tri].values, y.iloc[tri])
                yp2 = mpr.predict_proba(X.iloc[tei].values)[:,1]
                pc, rc, _ = precision_recall_curve(y.iloc[tei], yp2)
                pra.append(auc_score(rc, pc))
            mpr_v = np.mean(pra); spr = np.std(pra)
            pa, ra, _ = precision_recall_curve(r['all_y_true'], r['all_y_probs'])
            plt.plot(ra, pa, color=COLORS[i%8], lw=2.5, alpha=0.85, label=f'{mn} (AUPRC={mpr_v:.3f}±{spr:.3f})')
        plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
        plt.xlabel('Recall',fontsize=13); plt.ylabel('Precision',fontsize=13)
        plt.title('Precision-Recall Curves — Internal CV',fontsize=15,fontweight='bold')
        plt.legend(loc="lower left",fontsize=10); plt.grid(True,alpha=0.2); plt.tight_layout()
        plt.savefig(os.path.join(rf,'pr_all.pdf'),format='pdf',bbox_inches='tight',dpi=300)
        plt.savefig(os.path.join(rf,'pr_all.png'),format='png',bbox_inches='tight',dpi=150)
        plt.close()

        # ── CM ──
        progress(0.55, desc="📊 绘制混淆矩阵...")
        nc = min(4, nm); nr = (nm+nc-1)//nc
        fig, axes = plt.subplots(nr, nc, figsize=(4.2*nc, 4.2*nr))
        # plt.subplots returns a bare Axes (not an array) for a 1x1 grid.
        if nm == 1: axes = np.array([axes])
        af = axes.flatten()
        for i, mn in enumerate(mnames):
            r = amr[mn]; ypc = (r['all_y_probs']>=r['optimal_threshold']).astype(int)
            cm = confusion_matrix(r['all_y_true'], ypc)
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                        xticklabels=['Neg','Pos'], yticklabels=['Neg','Pos'], ax=af[i], annot_kws={'fontsize':12})
            af[i].set_xlabel('Predicted'); af[i].set_ylabel('True')
            acc = (cm[0,0]+cm[1,1])/cm.sum()
            af[i].set_title(f'{mn} (Acc={acc:.3f})',fontsize=12,fontweight='bold')
        for i in range(nm, len(af)): af[i].set_visible(False)
        plt.suptitle('Confusion Matrices',fontsize=15,fontweight='bold',y=1.0)
        plt.tight_layout()
        plt.savefig(os.path.join(rf,'confusion_matrices.pdf'),format='pdf',bbox_inches='tight',dpi=300)
        plt.savefig(os.path.join(rf,'confusion_matrices.png'),format='png',bbox_inches='tight',dpi=150)
        plt.close()

        # ── DeLong ──
        progress(0.58, desc="🔬 DeLong检验...")
        bmn = max(amr, key=lambda x: amr[x]['mean_auc'])
        bma = amr[bmn]['mean_auc']
        log(f"\n 🏆 最佳模型: {bmn} (AUC={bma:.4f})")
        dlr = []; retained = [bmn]
        for om in mnames:
            if om == bmn: continue
            try:
                pv, a1, a2 = delong_roc_test(amr[bmn]['all_y_true'], amr[bmn]['all_y_probs'], amr[om]['all_y_probs'])
            except Exception:
                pv = 1.0; a1 = bma; a2 = amr[om]['mean_auc']
            # An undefined (NaN) p-value is treated like a failed test, i.e.
            # "no evidence of a difference"; NaN >= ALP is False and would
            # otherwise exclude the model.
            if not np.isfinite(pv): pv = 1.0
            if pv >= ALP: retained.append(om); dec = "保留"
            else: dec = "排除"
            dlr.append({'Model1': bmn, 'AUC1': a1, 'Model2': om, 'AUC2': a2, 'P': pv, 'Decision': dec})
            log(f" {bmn} vs {om}: P={pv:.2e} → {dec}")
        dldf = pd.DataFrame(dlr).sort_values('P', ascending=False) if dlr else pd.DataFrame()
        log(f" ✅ 保留 {len(retained)} 个模型: {', '.join(retained)}")

        # ── SHAP ──
        progress(0.62, desc="🔥 SHAP分析...")
        log(f"\n 🔥 SHAP特征分析...")
        shap_imp = {}
        for si, mn in enumerate(retained):
            progress(0.62+0.10*si/len(retained), desc=f"🔥 SHAP: {mn}...")
            mo = tms[mn]['model']; Xshap = X.values
            ns = min(SHAPSZ, Xshap.shape[0])
            np.random.seed(RS); sidx = np.random.choice(Xshap.shape[0], ns, replace=False)
            Xs = Xshap[sidx]
            try:
                if mn in ['RF','XGB','DT','AdaBoost']:
                    exp = shap.TreeExplainer(mo); sv = exp.shap_values(Xs)
                    if isinstance(sv, list): sv = sv[1]
                else:
                    bg = Xs[np.random.choice(ns, min(100,ns), replace=False)]
                    exp = shap.KernelExplainer(lambda x, m=mo: m.predict_proba(x)[:,1], bg)
                    sv = exp.shap_values(Xs)
                    if isinstance(sv, list): sv = sv[0]
                sv = np.array(sv)
                if sv.ndim > 2:
                    if sv.shape[:2] == Xs.shape:
                        # (n_samples, n_features, n_outputs) layout: keep the
                        # last output, i.e. the positive class. Indexing [0]
                        # here would select the first *sample*.
                        sv = sv[:, :, -1]
                    else:
                        sv = sv[0]
                fi = np.abs(sv).mean(0)
                if fi.ndim > 1: fi = fi.flatten()
                if len(fi) > len(fnames): fi = fi[:len(fnames)]
                elif len(fi) < len(fnames): fi = np.pad(fi, (0, len(fnames)-len(fi)))
                idf = pd.DataFrame({'Feature': fnames, 'Importance': fi}).sort_values('Importance', ascending=False)
                shap_imp[mn] = idf
                Xdf = pd.DataFrame(Xs, columns=fnames)
                if sv.shape[1] > Xdf.shape[1]: sv = sv[:,:Xdf.shape[1]]
                elif sv.shape[1] < Xdf.shape[1]: sv = np.hstack([sv, np.zeros((sv.shape[0], Xdf.shape[1]-sv.shape[1]))])
                plt.figure(figsize=(12,8))
                shap.summary_plot(sv, Xdf, plot_type="dot", show=False, max_display=TOPN)
                plt.title(f'SHAP — {mn} (Top {TOPN})',fontsize=14,fontweight='bold'); plt.tight_layout()
                plt.savefig(os.path.join(rf,f'shap_{mn}.pdf'),format='pdf',bbox_inches='tight')
                plt.savefig(os.path.join(rf,f'shap_{mn}.png'),format='png',bbox_inches='tight',dpi=150)
                plt.close()
                log(f" ✅ {mn} Top3: {', '.join(idf.head(3)['Feature'].tolist())}")
            except Exception as e:
                # SHAP is best effort: a failing model is only dropped from
                # the ablation study.
                log(f" ⚠ {mn} SHAP失败: {e}")

        # ── Ablation ──
        progress(0.75, desc="🧪 特征消融...")
        log(f"\n 🧪 特征消融研究...")
        ablr = {}
        for mn in retained:
            if mn not in shap_imp: continue
            tfs = shap_imp[mn].head(TOPN)['Feature'].tolist()
            fcs = []; aucs_a = []; asp = {}
            # Re-run CV with the top-1, top-2, ... top-N SHAP features.
            for nf in range(1, len(tfs)+1):
                Xsub = X[tfs[:nf]]
                fa = []; syt = []; syp = []
                for tri, tei in skf.split(Xsub, y):
                    mf = deepcopy(mcfg[mn]['model'])
                    bp2 = bpd.get(mn, {})
                    if isinstance(bp2, dict) and bp2: mf.set_params(**bp2)
                    mf.fit(Xsub.iloc[tri].values, y.iloc[tri])
                    yp2 = mf.predict_proba(Xsub.iloc[tei].values)[:,1]
                    syt.extend(y.iloc[tei]); syp.extend(yp2)
                    fa.append(roc_auc_score(y.iloc[tei], yp2))
                fcs.append(nf); aucs_a.append(np.mean(fa))
                asp[nf] = {'yt': np.array(syt), 'yp': np.array(syp)}
            full_probs = amr[mn]['all_y_probs']; fauc = amr[mn]['mean_auc']; optn = None; adl = []
            for nf in range(1, len(tfs)+1):
                sd = asp[nf]; sa = aucs_a[nf-1]
                # Heuristic p-value used when the DeLong test cannot be run.
                fallback = 0.1 if abs(sa-fauc)<=0.05 else 0.01
                try:
                    pv = delong_roc_test(sd['yt'], full_probs, sd['yp'])[0] if len(full_probs)==len(sd['yp']) else fallback
                except Exception:
                    pv = fallback
                if not np.isfinite(pv): pv = fallback
                sig = "Sig" if pv < ALP else "NS"
                adl.append({'N': nf, 'AUC': sa, 'Full_AUC': fauc, 'P': pv, 'Sig': sig})
                # Smallest feature count not significantly worse than the full model.
                if optn is None and pv >= ALP: optn = nf
            ablr[mn] = {'fcs': fcs, 'aucs': aucs_a, 'tfs': tfs, 'dl': pd.DataFrame(adl),
                        'optn': optn or len(tfs), 'optf': tfs[:optn] if optn else tfs}
            log(f" {mn}: 最优 {ablr[mn]['optn']} 个特征")

        # Final model: the retained model that needs the fewest features.
        fcands = {}
        for mn in retained:
            if mn in ablr:
                ar = ablr[mn]
                fcands[mn] = {'nf': ar['optn'], 'feats': ar['optf'], 'auc': ar['aucs'][ar['optn']-1]}
        fmn = min(fcands, key=lambda x: fcands[x]['nf']) if fcands else None
        fmi = fcands.get(fmn) if fmn else None
        if fmn: log(f"\n ⭐ 最终模型: {fmn} ({fmi['nf']}特征, AUC={fmi['auc']:.4f})")

        # Ablation plot
        progress(0.80, desc="📈 消融曲线...")
        plt.figure(figsize=(12,8))
        for i, (mn, ar) in enumerate(ablr.items()):
            c = COLORS[i%8]
            plt.plot(ar['fcs'], ar['aucs'], marker='o', lw=2, ms=5, color=c, label=mn)
            on = ar['optn']; oa = ar['aucs'][on-1]
            plt.scatter([on],[oa], s=200, marker='*', color=c, edgecolors='black', lw=2, zorder=5)
        plt.xlabel('Number of Features',fontsize=13); plt.ylabel('AUC',fontsize=13)
        plt.title('Feature Ablation (★=Optimal)',fontsize=15,fontweight='bold')
        plt.legend(fontsize=11); plt.grid(True,alpha=0.2); plt.tight_layout()
        plt.savefig(os.path.join(rf,'ablation.pdf'),format='pdf',bbox_inches='tight')
        plt.savefig(os.path.join(rf,'ablation.png'),format='png',bbox_inches='tight',dpi=150)
        plt.close()

        # DCA — Internal (standard clinical layout)
        progress(0.83, desc="📈 DCA曲线...")
        dca_probs = {mn: amr[mn]['all_y_probs'] for mn in retained}
        plot_dca(amr[retained[0]]['all_y_true'], dca_probs,
                 'Decision Curve Analysis — Internal CV', 'dca', rf, final_model=fmn)

        # Fit the final model once on its selected feature subset. The same
        # fitted estimator is used for every validation set and is the one
        # that gets pickled, so the saved model matches the saved feature
        # list (the full-feature model in `tms` expects all columns).
        final_fit = None
        if fmn:
            final_fit = deepcopy(mcfg[fmn]['model'])
            bp3 = bpd[fmn]
            if isinstance(bp3, dict) and bp3: final_fit.set_params(**bp3)
            final_fit.fit(X[fmi['feats']].values, y)

        # ── External Validation (several validation sets supported) ──
        val_files_list = []
        for vf in [val_file1, val_file2, val_file3]:
            if vf is not None:
                val_files_list.append(vf)

        if val_files_list and fmn:
            progress(0.86, desc="🧪 外部验证...")
            log(f"\n{'━'*50}")
            log(f" 🧪 外部验证 ({len(val_files_list)} 个验证集)")

            for vi, vf in enumerate(val_files_list, 1):
                vp = vf if isinstance(vf, str) else getattr(vf, 'name', str(vf))
                ed = pd.read_csv(vp); Xe = ed.iloc[:,2:]; ye = ed.iloc[:,0]
                ule = sorted(ye.unique())
                if set(ule)!={0,1}:
                    lme={ule[0]:0,ule[1]:1}; ye=ye.map(lme)
                log(f"\n 📊 验证集 {vi}: {Xe.shape[0]} 样本, {os.path.basename(vp)}")

                Xes = Xe[fmi['feats']]
                yep = final_fit.predict_proba(Xes.values)[:,1]; yed = (yep>0.5).astype(int)
                tn,fp,fn,tp = confusion_matrix(ye,yed).ravel()
                se=tp/(tp+fn) if tp+fn else 0; sp=tn/(tn+fp) if tn+fp else 0
                ac=(tp+tn)/(tp+tn+fp+fn); pr=tp/(tp+fp) if tp+fp else 0
                f1v=2*pr*se/(pr+se) if pr+se else 0; ea=roc_auc_score(ye,yep)
                log(f" ✅ AUC={ea:.4f} Acc={ac:.4f} Sens={se:.4f} Spec={sp:.4f} F1={f1v:.4f}")

                # File suffix / title tag; numbered only when there are several sets.
                sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
                tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'

                # ROC
                fe,te,_ = roc_curve(ye,yep)
                plt.figure(figsize=(10,8))
                plt.plot(fe,te,'#2563eb',lw=2.5,label=f'{fmn} (AUC={ea:.3f})')
                plt.plot([0,1],[0,1],'--',color='gray'); plt.xlabel('FPR'); plt.ylabel('TPR')
                plt.title(f'ROC — {tag} ({fmn})',fontweight='bold'); plt.legend(); plt.grid(True,alpha=0.2); plt.tight_layout()
                plt.savefig(os.path.join(rf,f'roc{sfx}.pdf'),format='pdf',bbox_inches='tight')
                plt.savefig(os.path.join(rf,f'roc{sfx}.png'),format='png',bbox_inches='tight',dpi=150)
                plt.close()

                # PR
                pe,re,_ = precision_recall_curve(ye,yep)
                plt.figure(figsize=(10,8))
                plt.plot(re,pe,'#2563eb',lw=2.5,label=fmn); plt.xlabel('Recall'); plt.ylabel('Precision')
                plt.title(f'PR — {tag} ({fmn})',fontweight='bold'); plt.legend(); plt.grid(True,alpha=0.2); plt.tight_layout()
                plt.savefig(os.path.join(rf,f'pr{sfx}.pdf'),format='pdf',bbox_inches='tight')
                plt.savefig(os.path.join(rf,f'pr{sfx}.png'),format='png',bbox_inches='tight',dpi=150)
                plt.close()

                # CM
                cme = confusion_matrix(ye,yed)
                plt.figure(figsize=(8,6))
                sns.heatmap(cme,annot=True,fmt='d',cmap='Blues',cbar=False,xticklabels=['Neg','Pos'],yticklabels=['Neg','Pos'])
                plt.xlabel('Predicted'); plt.ylabel('True')
                plt.title(f'CM — {tag} ({fmn})',fontweight='bold'); plt.tight_layout()
                plt.savefig(os.path.join(rf,f'cm{sfx}.pdf'),format='pdf',bbox_inches='tight')
                plt.savefig(os.path.join(rf,f'cm{sfx}.png'),format='png',bbox_inches='tight',dpi=150)
                plt.close()

                # DCA — standard clinical layout
                plot_dca(ye, {fmn: yep}, f'DCA — {tag} ({fmn})', f'dca{sfx}', rf)

                # Excel
                with pd.ExcelWriter(os.path.join(rf,f'validation{sfx}.xlsx'),engine='openpyxl') as w:
                    pd.DataFrame([{'Model':fmn,'N_Features':fmi['nf'],'AUC':ea,'Accuracy':ac,
                                   'Sensitivity':se,'Specificity':sp,'Precision':pr,'F1':f1v}]).to_excel(w,sheet_name='Metrics',index=False)
                    pd.DataFrame({'Feature':fmi['feats']}).to_excel(w,sheet_name='Features',index=False)

        # ── Save Excels ──
        progress(0.92, desc="💾 保存结果...")
        log(f"\n 💾 保存结果文件...")
        with pd.ExcelWriter(os.path.join(rf,'model_evaluation.xlsx'),engine='openpyxl') as w:
            for mn, r in amr.items(): r['results_df'].to_excel(w,sheet_name=mn,index=False)
            sd = []
            for mn, r in amr.items():
                rw = r['results_df'].iloc[-1].to_dict()
                rw.update({'Model':mn,'Retained':'Yes' if mn in retained else 'No','Final':'Yes' if mn==fmn else 'No'})
                sd.append(rw)
            sdf = pd.DataFrame(sd)
            cols = ['Model','Retained','Final']+[c for c in sdf.columns if c not in ['Model','Fold','Retained','Final']]
            sdf[cols].sort_values('AUC',ascending=False).to_excel(w,sheet_name='Summary',index=False)
            if len(dldf)>0: dldf.to_excel(w,sheet_name='DeLong',index=False)

        # A workbook without any sheet cannot be saved, so skip the file when
        # every SHAP analysis failed and there is nothing to write.
        if ablr or shap_imp:
            with pd.ExcelWriter(os.path.join(rf,'feature_ablation.xlsx'),engine='openpyxl') as w:
                for mn, ar in ablr.items():
                    pd.DataFrame({'N':ar['fcs'],'AUC':ar['aucs']}).to_excel(w,sheet_name=mn,index=False)
                    if 'dl' in ar: ar['dl'].to_excel(w,sheet_name=f'{mn}_DL',index=False)
                for mn, idf in shap_imp.items():
                    idf.to_excel(w,sheet_name=f'{mn}_Imp',index=False)

        with open(os.path.join(rf,'best_params.txt'),'w',encoding='utf-8') as f:
            f.write("模型最佳超参数\n"+"="*50+"\n\n")
            for mn in mcfg:
                f.write(f"模型: {mn}\n")
                bp = bpd[mn]
                if isinstance(bp,dict):
                    for k,v in bp.items(): f.write(f" {k}: {v}\n")
                else: f.write(f" {bp}\n")
                f.write(f" AUC: {amr[mn]['mean_auc']:.4f}\n 保留: {'是' if mn in retained else '否'}\n\n")
            if fmn: f.write(f"\n最终模型: {fmn}\n特征({fmi['nf']}): {', '.join(fmi['feats'])}\n")

        if fmn:
            # NOTE(review): 'threshold' still comes from the full-feature CV
            # run of this model, not from the reduced feature set — confirm
            # that this is intended.
            with open(os.path.join(rf,f'model_{fmn}.pkl'),'wb') as pkl_file:
                pickle.dump({'model_name':fmn,'model':final_fit,'best_params':bpd[fmn],
                             'features':fmi['feats'],'n_features':fmi['nf'],'auc':fmi['auc'],
                             'threshold':amr[fmn]['optimal_threshold']},
                            pkl_file)

        # ── ZIP ──
        progress(0.97, desc="📦 打包ZIP...")
        # mkstemp creates a uniquely named file atomically, so concurrent
        # requests cannot collide even within the same second and process.
        # The "ml_" prefix and ".zip" suffix keep it visible to the cleanup.
        zfd, zp = tempfile.mkstemp(prefix="ml_results_", suffix=".zip")
        os.close(zfd)
        with zipfile.ZipFile(zp,'w',zipfile.ZIP_DEFLATED) as zf:
            for root,_,files in os.walk(rf):
                for fn in files: zf.write(os.path.join(root,fn), os.path.relpath(os.path.join(root,fn),rf))

        nf = sum(len(f) for _,_,f in os.walk(rf))

        # The ZIP is complete, so the working directory can go right away.
        shutil.rmtree(rf, ignore_errors=True)
        gc.collect()

        log(f"\n{'━'*50}")
        log(f" 🎉 分析完成!共 {nf} 个文件已打包")
        log(f" 💾 临时文件已自动清理")
        log(f"{'━'*50}")
        progress(1.0, desc="✅ 完成!")
        return zp, "\n".join(L)

    except Exception as e:
        log(f"\n❌ 错误: {e}")
        log(traceback.format_exc())
        # Remove the working directory on failure as well.
        if os.path.exists(rf):
            shutil.rmtree(rf, ignore_errors=True)
        gc.collect()
        return None, "\n".join(L)
671
+
672
+
673
+ # ============================================================================
674
+ # Beautiful Gradio UI
675
+ # ============================================================================
676
# Stylesheet handed to gr.Blocks(css=CUSTOM_CSS) when the UI is built.
# The selectors target class names that the UI code attaches itself:
#   .header-banner / .pipeline-box / .section-title -> raw gr.HTML snippets
#   .quick-btn / .log-area                          -> elem_classes on components
# The final `footer` rule hides the page footer element.
CUSTOM_CSS = """
/* ── Header Banner ── */
.header-banner {
    background: linear-gradient(135deg, #0a2463 0%, #1e3a7a 40%, #2554a8 100%);
    border-radius: 16px;
    padding: 28px 36px;
    margin-bottom: 20px;
    box-shadow: 0 8px 32px rgba(0,0,0,0.18);
    position: relative;
    overflow: hidden;
}
.header-banner::before {
    content: '';
    position: absolute;
    top: -50%;
    right: -20%;
    width: 400px;
    height: 400px;
    background: radial-gradient(circle, rgba(96,165,250,0.2) 0%, transparent 70%);
    border-radius: 50%;
}
.header-banner img {
    max-height: 52px;
    border-radius: 6px;
    margin-bottom: 12px;
}
.header-banner h1 {
    color: #e2e8f0 !important;
    font-size: 1.7em !important;
    margin: 4px 0 6px 0 !important;
    font-weight: 700 !important;
    letter-spacing: 0.5px;
}
.header-banner p {
    color: #94a3b8 !important;
    font-size: 0.92em !important;
    margin: 2px 0 !important;
    line-height: 1.6;
}
.header-banner .credit {
    color: #64748b !important;
    font-size: 0.82em !important;
    margin-top: 10px !important;
    border-top: 1px solid rgba(148,163,184,0.15);
    padding-top: 10px;
}

/* ── Section Cards ── */
.section-title {
    background: linear-gradient(90deg, #2563eb 0%, #3b82f6 100%);
    color: white !important;
    padding: 8px 16px;
    border-radius: 8px;
    font-size: 0.95em !important;
    font-weight: 600 !important;
    margin: 12px 0 8px 0;
    letter-spacing: 0.3px;
}

/* ── Pipeline Steps ── */
.pipeline-box {
    background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
    border: 1px solid #bae6fd;
    border-radius: 12px;
    padding: 14px 18px;
    margin: 8px 0;
    font-size: 0.88em;
}
.pipeline-box code {
    background: #2563eb;
    color: white;
    padding: 2px 8px;
    border-radius: 4px;
    font-size: 0.85em;
    margin: 0 2px;
}

/* ── Buttons ── */
.quick-btn {
    border-radius: 8px !important;
    font-weight: 500 !important;
    transition: all 0.2s ease !important;
}
.quick-btn:hover {
    transform: translateY(-1px) !important;
    box-shadow: 0 4px 12px rgba(0,0,0,0.1) !important;
}

/* ── Log Area ── */
.log-area textarea {
    font-family: 'Menlo', 'Consolas', 'Monaco', monospace !important;
    font-size: 12.5px !important;
    line-height: 1.5 !important;
    background: #0f172a !important;
    color: #e2e8f0 !important;
    border-radius: 10px !important;
    padding: 16px !important;
    border: 1px solid #1e293b !important;
}

/* ── General Polish ── */
.gradio-container {
    max-width: 1280px !important;
}
footer { display: none !important; }
"""
782
+
783
# Build the whole web UI inside one gr.Blocks context bound to `demo`.
# Layout: a banner and a pipeline summary on top, then a two-column row
# (inputs/configuration on the left, log + download on the right), and
# finally the event wiring for the "run" button.
# NOTE(review): the Row/Column nesting below was reconstructed from a
# flattened listing — confirm container membership against the rendered page.
with gr.Blocks(
    title="ML 二分类模型平台 — 复旦大学附属眼耳鼻喉科医院",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate", neutral_hue="slate"),
    css=CUSTOM_CSS,
) as demo:

    # ── Header ──
    # Static banner; the <img> hides itself via onerror if the logo fails to load.
    gr.HTML("""
    <div class="header-banner">
        <img src="https://huggingface.co/spaces/fudan-renjun/machine-learning-2/resolve/main/hospital_logo.png"
             alt="Logo" onerror="this.style.display='none'"/>
        <h1>🧬 ML 二分类模型训练与评估平台</h1>
        <p>上传训练集与验证集 CSV,自动完成模型训练、交叉验证、统计检验、特征分析,结果打包下载</p>
        <p class="credit">复旦大学附属眼耳鼻喉科医院 · 检验科 · 任俊</p>
    </div>
    """)

    # ── Pipeline Info ──
    # Static summary of the analysis steps and the expected CSV column layout.
    gr.HTML("""
    <div class="pipeline-box">
        <strong>📋 分析流程:</strong>
        <code>模型训练</code> → <code>交叉验证</code> → <code>DeLong检验</code> →
        <code>SHAP分析</code> → <code>特征消融</code> → <code>外部验证</code>
        &nbsp;&nbsp;|&nbsp;&nbsp;
        <strong>CSV格式:</strong> 第1列=标签, 第2列=ID, 第3列起=特征
    </div>
    """)

    with gr.Row(equal_height=False):
        # ══ Left Panel ══
        with gr.Column(scale=5):

            # Data upload: one training CSV plus up to three validation CSVs.
            gr.HTML('<div class="section-title">📂 数据上传</div>')
            train_file = gr.File(label="训练集 CSV(必需)", file_types=[".csv"])
            gr.HTML('<p style="color:#64748b;font-size:0.85em;margin:4px 0 8px 0;">验证集为可选项,支持同时上传 1~3 个验证集,分别生成独立的评估报告</p>')
            with gr.Row():
                val_file1 = gr.File(label="验证集 1(可选)", file_types=[".csv"], scale=1)
                val_file2 = gr.File(label="验证集 2(可选)", file_types=[".csv"], scale=1)
                val_file3 = gr.File(label="验证集 3(可选)", file_types=[".csv"], scale=1)

            # Model selection: multi-select dropdown, all models chosen by default.
            gr.HTML('<div class="section-title">🤖 模型选择</div>')
            model_selector = gr.Dropdown(
                choices=ALL_MODEL_NAMES,
                value=ALL_MODEL_NAMES,
                multiselect=True,
                label="选择模型(可多选,默认全部)",
                info="RF=随机森林 DT=决策树 KNN=K近邻 XGB=极限梯度提升 AdaBoost=自适应提升 LR=逻辑回归 NB=朴素贝叶斯 SVM=支持向量机",
            )
            with gr.Row():
                btn_all = gr.Button("🔘 全选", size="sm", variant="secondary", elem_classes="quick-btn")
                btn_tree = gr.Button("🌲 树模型", size="sm", variant="secondary", elem_classes="quick-btn")
                btn_linear = gr.Button("📐 线性模型", size="sm", variant="secondary", elem_classes="quick-btn")
                btn_top4 = gr.Button("⚡ 经典四模型", size="sm", variant="secondary", elem_classes="quick-btn")
            # Each quick-select button overwrites the dropdown value with a preset list.
            btn_all.click(lambda: ALL_MODEL_NAMES, outputs=model_selector)
            btn_tree.click(lambda: ['RF','DT','XGB','AdaBoost'], outputs=model_selector)
            btn_linear.click(lambda: ['LR','SVM','NB'], outputs=model_selector)
            btn_top4.click(lambda: ['RF','XGB','LR','SVM'], outputs=model_selector)

            # Run parameters: tuning toggle, CV folds, DeLong alpha, SHAP settings.
            gr.HTML('<div class="section-title">⚙️ 参数配置</div>')
            enable_tuning = gr.Checkbox(value=False, label="启用超参数调优 (GridSearchCV) ⚠️ 开启后运行时间显著增加")
            with gr.Row():
                cv_folds = gr.Slider(3, 10, value=5, step=1, label="交叉验证折数")
                alpha_sl = gr.Slider(0.01, 0.10, value=0.05, step=0.01, label="DeLong 显著性水平 α")
            with gr.Row():
                top_n = gr.Slider(5, 50, value=20, step=1, label="SHAP 前 N 个特征")
                shap_sz = gr.Slider(30, 200, value=80, step=10, label="SHAP 采样数量")

            run_btn = gr.Button("🚀 开始分析", variant="primary", size="lg")

        # ══ Right Panel ══
        with gr.Column(scale=5):
            # Read-only text area that receives the log string returned by the run.
            gr.HTML('<div class="section-title">📋 运行日志</div>')
            log_output = gr.Textbox(
                label="", lines=22, max_lines=50, interactive=False,
                placeholder="点击「开始分析」后,运行日志将在此实时显示...",
                elem_classes="log-area",
            )

            # Download slot for the ZIP path returned by the run.
            gr.HTML('<div class="section-title">⬇️ 结果下载</div>')
            zip_output = gr.File(label="分析结果 ZIP 压缩包")

    # ── Connect ──
    # Component values are passed to run_pipeline in the order listed in `inputs`;
    # its two return values fill `outputs` in order (ZIP file, then log text).
    # NOTE(review): run_pipeline's signature is outside this view — confirm the
    # input order matches its parameters.
    run_btn.click(
        fn=run_pipeline,
        inputs=[train_file, val_file1, val_file2, val_file3, model_selector, enable_tuning, cv_folds, alpha_sl, top_n, shap_sz],
        outputs=[zip_output, log_output],
        api_name="run",
    )
871
+
872
+ # ============================================================================
873
+ # Authentication with Expiration
874
+ # ============================================================================
875
+ from datetime import datetime
876
+
877
# Account configuration — edit usernames, passwords and validity here.
# Format:  "username": {"password": "<password>", "expires": "YYYY-MM-DD"}
# Use "expires": None for an account that never expires.
# "expires" is the LAST day on which the account may still log in.
#
# SECURITY NOTE(review): these credentials are stored in plaintext in the
# source file, which is visible to anyone who can read the repository.
# Prefer loading them from deployment secrets / environment variables.
ACCOUNTS = {
    "admin": {
        "password": "admin123",
        "expires": None,  # never expires
    },
    "renjun": {
        "password": "fudan2025",
        "expires": "2026-12-31",  # valid through 2026-12-31
    },
    "guest": {
        "password": "guest888",
        "expires": "2025-06-30",  # example of an already-expired account
    },
}


def auth_fn(username, password):
    """Check a login attempt against ACCOUNTS, including account expiry.

    Args:
        username: Login name typed by the user.
        password: Password typed by the user.

    Returns:
        True only when the username exists, the password matches and the
        account's "expires" date (if any) has not passed. Every failure mode
        — unknown user, wrong password, non-string input, unparseable expiry
        value — returns False rather than raising.
    """
    import hmac  # local import keeps this block self-contained

    # Reject non-string credentials up front instead of failing on .encode().
    if not isinstance(username, str) or not isinstance(password, str):
        return False

    user = ACCOUNTS.get(username)
    if not user:
        return False

    # Constant-time comparison avoids leaking how many leading characters
    # matched. Compare bytes: compare_digest rejects non-ASCII str arguments.
    supplied = password.encode("utf-8")
    expected = user["password"].encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        return False

    expires = user["expires"]
    if expires is None:
        return True

    try:
        last_valid_day = datetime.strptime(expires, "%Y-%m-%d").date()
    except (TypeError, ValueError):
        # A malformed expiry entry fails closed.
        return False

    # Compare calendar dates so the account still works during its expiry day;
    # comparing against midnight would reject it for that entire day.
    return datetime.now().date() <= last_valid_day
913
+
914
# Server settings gathered in one place: listen on every interface at port
# 7860, require a login through auth_fn, and keep server-side rendering off.
_LAUNCH_OPTIONS = {
    "server_name": "0.0.0.0",
    "server_port": 7860,
    "auth": auth_fn,
    "auth_message": "🔐 复旦大学附属眼耳鼻喉科医院 · ML分析平台\n请输入账号和密码登录",
    "ssr_mode": False,
}

# Turn on the request queue first, then start the app.
demo.queue()
demo.launch(**_LAUNCH_OPTIONS)