Spaces:
Sleeping
Sleeping
| """ | |
| ML Binary Classification Pipeline | |
| Eye & ENT Hospital of Fudan University — Laboratory Medicine, Ren Jun | |
| Gradio 5.12.0 + Python 3.11 | |
| """ | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier | |
| from sklearn.tree import DecisionTreeClassifier | |
| from sklearn.neighbors import KNeighborsClassifier | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.naive_bayes import GaussianNB | |
| from sklearn.svm import SVC | |
| from xgboost import XGBClassifier | |
| from sklearn.model_selection import StratifiedKFold, GridSearchCV | |
| from sklearn.metrics import ( | |
| roc_auc_score, confusion_matrix, roc_curve, | |
| auc as auc_score, precision_recall_curve | |
| ) | |
| import seaborn as sns | |
| import warnings | |
| from scipy import stats | |
| import os | |
| import shap | |
| import pickle | |
| from copy import deepcopy | |
| import zipfile | |
| import tempfile | |
| import traceback | |
| import time | |
| import shutil | |
| import gc | |
| import threading | |
| import gradio as gr | |
| warnings.filterwarnings('ignore') | |
| plt.rcParams['font.sans-serif'] = ['DejaVu Sans'] | |
| plt.rcParams['axes.unicode_minus'] = False | |
| # ============================================================================ | |
| # Cache Cleanup System | |
| # ============================================================================ | |
CLEANUP_MAX_AGE_MINUTES = 30    # temp artifacts older than 30 minutes are deleted
CLEANUP_INTERVAL_SECONDS = 600  # background sweep runs every 10 minutes
CLEANUP_MAX_DISK_MB = 1024      # force-clean when ml_* items in the temp dir exceed 1 GB

def cleanup_old_temp_files():
    """Delete expired ``ml_*`` result folders and ZIPs from the system temp dir.

    Every entry of ``tempfile.gettempdir()`` whose name starts with ``ml_``
    and whose mtime is older than CLEANUP_MAX_AGE_MINUTES is removed:
    result directories and packaged ``.zip`` downloads alike. Runs
    best-effort — filesystem races with concurrent workers are swallowed so
    the background sweeper thread never crashes the app.
    """
    now = time.time()
    max_age = CLEANUP_MAX_AGE_MINUTES * 60
    cleaned_dirs = 0
    cleaned_zips = 0
    cleaned_mb = 0.0
    tmp_dir = tempfile.gettempdir()
    try:
        for item in os.listdir(tmp_dir):
            if not item.startswith("ml_"):
                continue
            item_path = os.path.join(tmp_dir, item)
            if os.path.isdir(item_path):
                # Expired per-run result folder.
                age = now - os.path.getmtime(item_path)
                if age > max_age:
                    size = sum(os.path.getsize(os.path.join(r, f))
                               for r, _, fs in os.walk(item_path) for f in fs)
                    shutil.rmtree(item_path, ignore_errors=True)
                    cleaned_dirs += 1
                    cleaned_mb += size / (1024 * 1024)
            elif item.endswith(".zip") and os.path.isfile(item_path):
                # Expired packaged ZIP result.
                age = now - os.path.getmtime(item_path)
                if age > max_age:
                    size = os.path.getsize(item_path)
                    os.remove(item_path)
                    cleaned_zips += 1
                    cleaned_mb += size / (1024 * 1024)
    except OSError:
        # Listing/stat race with another worker — retry on the next sweep.
        pass
    # Also reclaim Python-side memory.
    gc.collect()
    if cleaned_dirs or cleaned_zips:
        print(f"[Cleanup] 清理 {cleaned_dirs} 个临时文件夹和 {cleaned_zips} 个ZIP, 释放 {cleaned_mb:.1f} MB")
def check_disk_pressure(max_mb=None):
    """Force-clean every ``ml_*`` item in the temp dir once usage exceeds a cap.

    max_mb: size cap in MB; defaults to CLEANUP_MAX_DISK_MB. Unlike
    cleanup_old_temp_files() this ignores file age — when the cap is
    exceeded, ALL ``ml_*`` folders and files are removed immediately.
    """
    if max_mb is None:
        max_mb = CLEANUP_MAX_DISK_MB
    tmp_dir = tempfile.gettempdir()
    total_mb = 0
    try:
        # Tally the disk footprint of every ml_* folder and file.
        for item in os.listdir(tmp_dir):
            if not item.startswith("ml_"):
                continue
            item_path = os.path.join(tmp_dir, item)
            if os.path.isdir(item_path):
                total_mb += sum(os.path.getsize(os.path.join(r, f))
                                for r, _, fs in os.walk(item_path) for f in fs) / (1024 * 1024)
            elif os.path.isfile(item_path):
                total_mb += os.path.getsize(item_path) / (1024 * 1024)
    except OSError:
        # Best-effort: a concurrent delete must not crash the sweeper.
        pass
    if total_mb > max_mb:
        print(f"[Cleanup] 磁盘占用 {total_mb:.0f}MB > {max_mb}MB, 强制清理!")
        for item in os.listdir(tmp_dir):
            item_path = os.path.join(tmp_dir, item)
            if item.startswith("ml_"):
                try:
                    if os.path.isdir(item_path):
                        shutil.rmtree(item_path, ignore_errors=True)
                    elif os.path.isfile(item_path):
                        os.remove(item_path)
                except OSError:
                    # Never let a filesystem race kill the background thread.
                    pass
        gc.collect()
def periodic_cleanup():
    """Background loop: sweep expired temp files, then check disk pressure."""
    while True:
        time.sleep(CLEANUP_INTERVAL_SECONDS)
        cleanup_old_temp_files()
        check_disk_pressure()

# Launch the daemon sweeper thread (it dies with the main process).
_cleanup_thread = threading.Thread(target=periodic_cleanup, daemon=True)
_cleanup_thread.start()
# Derive the startup message from the constants so it can never drift out of
# sync (the previous hard-coded text claimed a 500 MB cap while the actual
# CLEANUP_MAX_DISK_MB is 1024 MB).
print(f"[Cleanup] 后台自动清理已启动 (每{CLEANUP_INTERVAL_SECONDS // 60}分钟检查, "
      f"{CLEANUP_MAX_AGE_MINUTES}分钟过期, 上限{CLEANUP_MAX_DISK_MB}MB)")
| # ============================================================================ | |
| # Helper Functions | |
| # ============================================================================ | |
def compute_midrank(x):
    """Return one-based midranks of *x*, averaging ranks across tied values."""
    order = np.argsort(x)
    sorted_vals = x[order]
    n = len(x)
    ranks = np.zeros(n, dtype=float)
    lo = 0
    while lo < n:
        hi = lo
        while hi < n and sorted_vals[hi] == sorted_vals[lo]:
            hi += 1
        # Average the 0-based ranks over the tie group [lo, hi).
        ranks[lo:hi] = 0.5 * (lo + hi - 1)
        lo = hi
    midranks = np.empty(n, dtype=float)
    midranks[order] = ranks + 1
    return midranks

def fastDeLong(pst, m):
    """Fast DeLong AUC and covariance estimation.

    pst: (k, m+n) matrix of scores — k classifiers over m positive samples
    (leading columns) followed by n negative samples.
    Returns (aucs, cov): the k AUCs and their k-by-k DeLong covariance.
    """
    n_neg = pst.shape[1] - m
    n_clf = pst.shape[0]
    pos_ranks = np.empty([n_clf, m])
    neg_ranks = np.empty([n_clf, n_neg])
    all_ranks = np.empty([n_clf, m + n_neg])
    for row in range(n_clf):
        pos_ranks[row] = compute_midrank(pst[row, :m])
        neg_ranks[row] = compute_midrank(pst[row, m:])
        all_ranks[row] = compute_midrank(pst[row])
    aucs = all_ranks[:, :m].sum(1) / m / n_neg - (m + 1.0) / 2.0 / n_neg
    # Structural components of the DeLong variance decomposition.
    v01 = (all_ranks[:, :m] - pos_ranks) / n_neg
    v10 = 1.0 - (all_ranks[:, m:] - neg_ranks) / m
    return aucs, np.cov(v01) / m + np.cov(v10) / n_neg

def delong_roc_test(gt, p1, p2):
    """DeLong paired test comparing the ROC AUCs of two score vectors.

    gt: binary ground truth (1 = positive class).
    Returns (two-sided p-value, AUC of p1, AUC of p2).
    """
    order = (-gt).argsort()
    n_pos = int(gt.sum())
    stacked = np.vstack([p1, p2])[:, order]
    aucs, cov = fastDeLong(stacked, n_pos)
    contrast = np.array([[1, -1]])
    # z-statistic of the AUC difference under the DeLong covariance.
    z = np.abs(np.diff(aucs)) / np.sqrt(np.dot(np.dot(contrast, cov), contrast.T))
    # Two-sided p-value computed in log10 space for numerical stability.
    log10_p = np.log10(2) + stats.norm.logsf(z, 0, 1) / np.log(10)
    return 10 ** log10_p[0][0], aucs[0], aucs[1]
def find_optimal_threshold(y_true, y_probs, method='youden'):
    """Pick the probability cutoff maximizing Youden's J (TPR - FPR).

    Returns (threshold, J value at that threshold, index into the ROC
    threshold array). ``method`` is kept for interface compatibility;
    only the Youden criterion is implemented.
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_probs)
    youden_j = tpr - fpr
    best = np.argmax(youden_j)
    return thresholds[best], youden_j[best], best
def calculate_net_benefit(y_true, y_probs, threshold):
    """Net benefit of treating at a given threshold probability (for DCA).

    Net benefit = TP/n - FP/n * (pt / (1 - pt)), the standard decision-curve
    formula. Counts are computed directly with numpy rather than
    sklearn.confusion_matrix, which returns a 1x1 matrix (and broke the
    4-way unpack) whenever truth and predictions were both single-class —
    a case that occurs at extreme thresholds during the DCA sweep.
    ``threshold`` must be < 1 (callers sweep up to at most 0.99).
    """
    yt = np.asarray(y_true)
    yp = (np.asarray(y_probs) >= threshold).astype(int)
    n = len(yt)
    tp = int(np.sum((yp == 1) & (yt == 1)))
    fp = int(np.sum((yp == 1) & (yt == 0)))
    return (tp / n) - (fp / n) * (threshold / (1 - threshold))
def plot_dca(y_true, y_probs_dict, title, save_prefix, result_dir, final_model=None):
    """Draw a standard clinical DCA curve (similar to the R ``rmda`` package format).

    y_true: binary labels; y_probs_dict: {model name: predicted probabilities};
    final_model: optional model name to tag as "(Final)" in the legend.
    Saves ``<save_prefix>.pdf`` and ``.png`` into result_dir.
    """
    prevalence = np.mean(y_true)
    # Cap the threshold axis at a clinically sensible range.
    max_thr = min(0.99, max(prevalence * 3, 0.6)) if prevalence < 0.5 else 0.9
    thresholds = np.linspace(0.01, max_thr, 200)
    plt.figure(figsize=(10, 7))
    # Treat All reference: treat everyone regardless of probability.
    ta_nb = [prevalence - (1 - prevalence) * (pt / (1 - pt)) for pt in thresholds]
    plt.plot(thresholds, ta_nb, 'k-', lw=1.5, label='Treat All')
    # Treat None reference (net benefit y=0).
    plt.axhline(y=0, color='#555555', lw=1.5, linestyle='-', label='Treat None')
    # One net-benefit curve per model.
    DCA_COLORS = ['#e41a1c','#377eb8','#4daf4a','#984ea3','#ff7f00','#a65628','#f781bf','#999999']
    for idx, (mn, yp) in enumerate(y_probs_dict.items()):
        nbs = [calculate_net_benefit(y_true, yp, t) for t in thresholds]
        lbl = f'{mn} (Final)' if mn == final_model else mn
        plt.plot(thresholds, nbs, color=DCA_COLORS[idx % len(DCA_COLORS)], lw=2, label=lbl)
    # Restrict the y-axis to the clinically meaningful range.
    y_min = max(min(ta_nb), -0.05) - 0.01
    y_max = max(prevalence * 1.5, 0.15)
    plt.xlim([0, max_thr]); plt.ylim([y_min, y_max])
    plt.xlabel('Threshold Probability', fontsize=13)
    plt.ylabel('Net Benefit', fontsize=13)
    plt.title(title, fontsize=15, fontweight='bold')
    plt.legend(loc='upper right', fontsize=10, framealpha=0.9)
    plt.grid(True, alpha=0.15); plt.tight_layout()
    plt.savefig(os.path.join(result_dir, f'{save_prefix}.pdf'), format='pdf', bbox_inches='tight', dpi=300)
    plt.savefig(os.path.join(result_dir, f'{save_prefix}.png'), format='png', bbox_inches='tight', dpi=150)
    plt.close()
| # ============================================================================ | |
| # Model configs | |
| # ============================================================================ | |
ALL_MODEL_NAMES = ['RF', 'DT', 'KNN', 'XGB', 'AdaBoost', 'LR', 'NB', 'SVM']

def get_models_config(selected, rs=42):
    """Build the estimator + hyperparameter-grid catalog for the chosen models.

    selected: iterable of model short names (subset of ALL_MODEL_NAMES).
    rs: random seed forwarded to every estimator that accepts one.
    Returns {name: {'model': estimator, 'params': GridSearchCV grid}},
    preserving the canonical catalog order.
    """
    catalog = {
        'RF': {'model': RandomForestClassifier(random_state=rs, n_jobs=-1),
               'params': {'n_estimators': [100, 200], 'max_depth': [20, 50],
                          'min_samples_split': [2, 5], 'max_features': ['sqrt']}},
        'DT': {'model': DecisionTreeClassifier(random_state=rs),
               'params': {'max_depth': [20, 50], 'min_samples_split': [2, 10],
                          'min_samples_leaf': [1, 4], 'criterion': ['gini', 'entropy']}},
        'KNN': {'model': KNeighborsClassifier(n_jobs=-1),
                'params': {'n_neighbors': [3, 5, 7], 'weights': ['uniform', 'distance'],
                           'metric': ['euclidean', 'manhattan']}},
        'XGB': {'model': XGBClassifier(random_state=rs, eval_metric='logloss', n_jobs=-1),
                'params': {'n_estimators': [100, 200], 'max_depth': [5, 7],
                           'learning_rate': [0.05, 0.1], 'subsample': [0.8, 1.0],
                           'colsample_bytree': [0.8, 1.0]}},
        'AdaBoost': {'model': AdaBoostClassifier(random_state=rs),
                     'params': {'n_estimators': [50, 100], 'learning_rate': [0.1, 0.5, 1.0]}},
        'LR': {'model': LogisticRegression(random_state=rs, n_jobs=-1, max_iter=2000),
               'params': {'C': [0.1, 1, 10], 'penalty': ['l2'], 'solver': ['lbfgs', 'liblinear']}},
        'NB': {'model': GaussianNB(),
               'params': {'var_smoothing': [1e-9, 1e-7, 1e-5]}},
        'SVM': {'model': SVC(probability=True, random_state=rs),
                'params': {'C': [1, 10], 'kernel': ['rbf', 'linear'], 'gamma': ['scale', 'auto']}},
    }
    return {name: spec for name, spec in catalog.items() if name in selected}
| # ============================================================================ | |
| # Main Pipeline with Progress | |
| # ============================================================================ | |
def run_pipeline(
    train_file, val_file1, val_file2, val_file3, selected_models, enable_tuning,
    cv_folds, alpha, top_n_features, shap_sample_size,
    progress=gr.Progress(track_tqdm=True),
):
    """End-to-end binary-classification pipeline driven from the Gradio UI.

    Steps: load CSV → optional GridSearchCV tuning → stratified-CV evaluation →
    ROC/PR/confusion-matrix plots → DeLong model comparison → SHAP importance →
    feature ablation → final-model selection → optional external validation →
    Excel/plot export → single ZIP for download.

    CSV layout assumption (visible from the slicing below): column 0 = label,
    column 1 = sample ID (skipped), columns 2+ = features.
    Returns (zip_path, log_text); zip_path is None on error.
    """
    if train_file is None:
        return None, "❌ 请先上传训练集 CSV 文件"
    # Accept either a list from the UI checkbox group or a comma-separated string.
    sel = selected_models if isinstance(selected_models, list) else [s.strip() for s in str(selected_models).split(",") if s.strip()]
    if not sel:
        return None, "❌ 请至少选择一个模型"
    RS = 42; CVF = int(cv_folds); ALP = float(alpha)
    TOPN = int(top_n_features); SHAPSZ = int(shap_sample_size)
    TUNING = bool(enable_tuning)
    L = []  # accumulated log lines shown in the UI textbox
    def log(m): L.append(str(m))
    # Per-run scratch dir; the ml_ prefix makes it visible to the cleanup thread.
    rf = tempfile.mkdtemp(prefix="ml_")
    try:
        # ── Load ──
        progress(0.02, desc="📂 加载数据...")
        log("━" * 50)
        log(" 🧬 ML 二分类模型训练与评估系统")
        log("━" * 50)
        tp = train_file if isinstance(train_file, str) else getattr(train_file, 'name', str(train_file))
        data = pd.read_csv(tp)
        # Column 0 = label, column 1 = ID (skipped), columns 2+ = features.
        X = data.iloc[:, 2:]; y = data.iloc[:, 0]
        fnames = X.columns.tolist()
        # Auto 0/1: remap any two-valued label column onto {0, 1}.
        ul = sorted(y.unique())
        if set(ul) != {0, 1}:
            lm = {ul[0]: 0, ul[1]: 1}; y = y.map(lm)
            log(f" ⚙ 标签已自动转换: {lm}")
        log(f" 📊 训练集: {X.shape[0]} 样本 × {X.shape[1]} 特征")
        log(f" 📊 标签: {dict(y.value_counts())}")
        log(f" 🤖 模型: {', '.join(sel)}")
        log(f" 🔧 调优: {'开启' if TUNING else '关闭'} | CV: {CVF}折")
        mcfg = get_models_config(sel, RS)
        skf = StratifiedKFold(n_splits=CVF, shuffle=True, random_state=RS)
        # ── Train ──
        bpd = {}; amr = {}; tms = {}  # best params / per-model CV results / trained models
        total = len(mcfg)
        COLORS = ['#2563eb','#f59e0b','#10b981','#ef4444','#8b5cf6','#ec4899','#06b6d4','#6b7280']
        for mi, (mn, cf) in enumerate(mcfg.items()):
            pv = 0.05 + 0.40 * mi / total
            progress(pv, desc=f"🏋️ [{mi+1}/{total}] 训练 {mn}...")
            log(f"\n{'─'*40}")
            log(f" 🔄 [{mi+1}/{total}] {mn}")
            Xv = X.values
            if TUNING:
                log(f" ⏳ GridSearchCV (CV={CVF})...")
                gs = GridSearchCV(cf['model'], cf['params'], cv=skf, scoring='roc_auc', n_jobs=-1, verbose=0)
                gs.fit(Xv, y)
                bp = gs.best_params_; bpd[mn] = bp
                log(f" ✓ 最佳AUC: {gs.best_score_:.4f}")
            else:
                bp = {}; bpd[mn] = "默认参数"
            # Refit on the full training set with the selected params.
            mdl = deepcopy(cf['model'])
            if bp: mdl.set_params(**bp)
            mdl.fit(Xv, y)
            tms[mn] = {'model': mdl, 'scaler': None}
            # CV eval: per-fold metrics plus pooled out-of-fold predictions.
            folds = []; ayt = []; ayp = []; tprs = []
            bfpr = np.linspace(0, 1, 101)  # common FPR grid for averaging ROC curves
            for fi, (tri, tei) in enumerate(skf.split(X, y), 1):
                Xtr, Xte = X.iloc[tri].values, X.iloc[tei].values
                ytr, yte = y.iloc[tri], y.iloc[tei]
                mf = deepcopy(cf['model'])
                if bp: mf.set_params(**bp)
                mf.fit(Xtr, ytr)
                ypp = mf.predict_proba(Xte)[:, 1]
                ypd = (ypp > 0.5).astype(int)
                tn, fp, fn, tp = confusion_matrix(yte, ypd).ravel()
                se = tp/(tp+fn) if tp+fn else 0; sp = tn/(tn+fp) if tn+fp else 0
                ac = (tp+tn)/(tp+tn+fp+fn); pr = tp/(tp+fp) if tp+fp else 0
                f1 = 2*pr*se/(pr+se) if pr+se else 0
                auc_v = roc_auc_score(yte, ypp)
                folds.append({'Fold': fi, 'AUC': auc_v, 'Accuracy': ac, 'Sensitivity': se,
                              'Specificity': sp, 'Precision': pr, 'F1': f1, 'TP': tp, 'TN': tn, 'FP': fp, 'FN': fn})
                ayt.extend(yte); ayp.extend(ypp)
                fa, ta, _ = roc_curve(yte, ypp)
                ti = np.interp(bfpr, fa, ta); ti[0] = 0.0; tprs.append(ti)
            rdf = pd.DataFrame(folds)
            # Append a summary row: mean metrics, summed confusion counts.
            mr = {'Fold': 'Mean', 'AUC': rdf['AUC'].mean(), 'Accuracy': rdf['Accuracy'].mean(),
                  'Sensitivity': rdf['Sensitivity'].mean(), 'Specificity': rdf['Specificity'].mean(),
                  'Precision': rdf['Precision'].mean(), 'F1': rdf['F1'].mean(),
                  'TP': rdf['TP'].sum(), 'TN': rdf['TN'].sum(), 'FP': rdf['FP'].sum(), 'FN': rdf['FN'].sum()}
            rdf = pd.concat([rdf, pd.DataFrame([mr])], ignore_index=True)
            # Youden-optimal cutoff over the pooled out-of-fold predictions.
            ot, yv, _ = find_optimal_threshold(np.array(ayt), np.array(ayp))
            amr[mn] = {'results_df': rdf, 'mean_auc': mr['AUC'], 'all_y_true': np.array(ayt),
                       'all_y_probs': np.array(ayp), 'tprs': tprs, 'base_fpr': bfpr,
                       'optimal_threshold': ot, 'youden_index': yv}
            log(f" ✅ AUC={mr['AUC']:.4f} Acc={mr['Accuracy']:.4f} 阈值={ot:.4f}")
        mnames = list(amr.keys()); nm = len(mnames)
        log(f"\n{'━'*50}")
        log(f" ✅ {nm} 个模型训练完成")
        # ── ROC ── mean ROC per model with a ±1 SD band across folds.
        progress(0.48, desc="📈 绘制ROC曲线...")
        log(f"\n 📈 绘制图表...")
        plt.figure(figsize=(12, 10))
        for i, mn in enumerate(mnames):
            r = amr[mn]; mt = np.mean(r['tprs'], axis=0); mt[-1] = 1.0
            ma = auc_score(r['base_fpr'], mt); sa = r['results_df'].iloc[:-1]['AUC'].std()
            st = np.std(r['tprs'], axis=0)
            c = COLORS[i % 8]
            plt.plot(r['base_fpr'], mt, color=c, lw=2.5, alpha=0.85, label=f'{mn} (AUC={ma:.3f}±{sa:.3f})')
            plt.fill_between(r['base_fpr'], np.maximum(mt-st, 0), np.minimum(mt+st, 1), color=c, alpha=0.08)
        plt.plot([0,1],[0,1],'--',lw=2,color='#9ca3af',alpha=0.5)
        plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
        plt.xlabel('False Positive Rate',fontsize=13); plt.ylabel('True Positive Rate',fontsize=13)
        plt.title('ROC Curves — Internal Cross-Validation',fontsize=15,fontweight='bold')
        plt.legend(loc="lower right",fontsize=10); plt.grid(True,alpha=0.2); plt.tight_layout()
        plt.savefig(os.path.join(rf,'roc_all.pdf'),format='pdf',bbox_inches='tight',dpi=300)
        plt.savefig(os.path.join(rf,'roc_all.png'),format='png',bbox_inches='tight',dpi=150)
        plt.close()
        # ── PR ── pooled PR curve per model; AUPRC mean/SD recomputed per fold.
        progress(0.52, desc="📈 绘制PR曲线...")
        plt.figure(figsize=(12, 10))
        for i, mn in enumerate(mnames):
            r = amr[mn]; pra = []
            # Re-fit per fold to get a fold-wise AUPRC spread for the legend.
            for tri, tei in skf.split(X, y):
                cf2 = mcfg[mn]; mpr = deepcopy(cf2['model'])
                bp2 = bpd[mn]
                if isinstance(bp2, dict) and bp2: mpr.set_params(**bp2)
                mpr.fit(X.iloc[tri].values, y.iloc[tri])
                yp2 = mpr.predict_proba(X.iloc[tei].values)[:,1]
                pc, rc, _ = precision_recall_curve(y.iloc[tei], yp2)
                pra.append(auc_score(rc, pc))
            mpr_v = np.mean(pra); spr = np.std(pra)
            pa, ra, _ = precision_recall_curve(r['all_y_true'], r['all_y_probs'])
            plt.plot(ra, pa, color=COLORS[i%8], lw=2.5, alpha=0.85, label=f'{mn} (AUPRC={mpr_v:.3f}±{spr:.3f})')
        plt.xlim([-0.02,1.02]); plt.ylim([-0.02,1.02])
        plt.xlabel('Recall',fontsize=13); plt.ylabel('Precision',fontsize=13)
        plt.title('Precision-Recall Curves — Internal CV',fontsize=15,fontweight='bold')
        plt.legend(loc="lower left",fontsize=10); plt.grid(True,alpha=0.2); plt.tight_layout()
        plt.savefig(os.path.join(rf,'pr_all.pdf'),format='pdf',bbox_inches='tight',dpi=300)
        plt.savefig(os.path.join(rf,'pr_all.png'),format='png',bbox_inches='tight',dpi=150)
        plt.close()
        # ── CM ── grid of confusion matrices, up to 4 per row.
        progress(0.55, desc="📊 绘制混淆矩阵...")
        nc = min(4, nm); nr = (nm+nc-1)//nc
        fig, axes = plt.subplots(nr, nc, figsize=(4.2*nc, 4.2*nr))
        if nm == 1: axes = np.array([axes])
        af = axes.flatten()
        for i, mn in enumerate(mnames):
            # Predictions at each model's own optimal (Youden) threshold.
            r = amr[mn]; ypc = (r['all_y_probs']>=r['optimal_threshold']).astype(int)
            cm = confusion_matrix(r['all_y_true'], ypc)
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                        xticklabels=['Neg','Pos'], yticklabels=['Neg','Pos'], ax=af[i], annot_kws={'fontsize':12})
            af[i].set_xlabel('Predicted'); af[i].set_ylabel('True')
            acc = (cm[0,0]+cm[1,1])/cm.sum()
            af[i].set_title(f'{mn} (Acc={acc:.3f})',fontsize=12,fontweight='bold')
        for i in range(nm, len(af)): af[i].set_visible(False)
        plt.suptitle('Confusion Matrices',fontsize=15,fontweight='bold',y=1.0)
        plt.tight_layout()
        plt.savefig(os.path.join(rf,'confusion_matrices.pdf'),format='pdf',bbox_inches='tight',dpi=300)
        plt.savefig(os.path.join(rf,'confusion_matrices.png'),format='png',bbox_inches='tight',dpi=150)
        plt.close()
        # ── DeLong ── keep every model whose AUC is not significantly below the best.
        progress(0.58, desc="🔬 DeLong检验...")
        bmn = max(amr, key=lambda x: amr[x]['mean_auc'])
        bma = amr[bmn]['mean_auc']
        log(f"\n 🏆 最佳模型: {bmn} (AUC={bma:.4f})")
        dlr = []; retained = [bmn]
        for om in mnames:
            if om == bmn: continue
            try: pv, a1, a2 = delong_roc_test(amr[bmn]['all_y_true'], amr[bmn]['all_y_probs'], amr[om]['all_y_probs'])
            except: pv=1.0; a1=bma; a2=amr[om]['mean_auc']
            if pv >= ALP: retained.append(om); dec = "保留"
            else: dec = "排除"
            dlr.append({'Model1': bmn, 'AUC1': a1, 'Model2': om, 'AUC2': a2, 'P': pv, 'Decision': dec})
            log(f" {bmn} vs {om}: P={pv:.2e} → {dec}")
        dldf = pd.DataFrame(dlr).sort_values('P', ascending=False) if dlr else pd.DataFrame()
        log(f" ✅ 保留 {len(retained)} 个模型: {', '.join(retained)}")
        # ── SHAP ── feature importance per retained model on a subsample.
        progress(0.62, desc="🔥 SHAP分析...")
        log(f"\n 🔥 SHAP特征分析...")
        shap_imp = {}
        for si, mn in enumerate(retained):
            progress(0.62+0.10*si/len(retained), desc=f"🔥 SHAP: {mn}...")
            mo = tms[mn]['model']; Xshap = X.values
            ns = min(SHAPSZ, Xshap.shape[0])
            np.random.seed(RS); sidx = np.random.choice(Xshap.shape[0], ns, replace=False)
            Xs = Xshap[sidx]
            try:
                if mn in ['RF','XGB','DT','AdaBoost']:
                    # Tree models: exact, fast TreeExplainer.
                    exp = shap.TreeExplainer(mo); sv = exp.shap_values(Xs)
                    if isinstance(sv, list): sv = sv[1]
                else:
                    # Others: model-agnostic KernelExplainer with a small background set.
                    bg = Xs[np.random.choice(ns, min(100,ns), replace=False)]
                    exp = shap.KernelExplainer(lambda x, m=mo: m.predict_proba(x)[:,1], bg)
                    sv = exp.shap_values(Xs)
                    if isinstance(sv, list): sv = sv[0]
                # Normalize SHAP output shape across library/model versions.
                sv = np.array(sv)
                if sv.ndim > 2: sv = sv[0]
                fi = np.abs(sv).mean(0)
                if fi.ndim > 1: fi = fi.flatten()
                if len(fi) > len(fnames): fi = fi[:len(fnames)]
                elif len(fi) < len(fnames): fi = np.pad(fi, (0, len(fnames)-len(fi)))
                idf = pd.DataFrame({'Feature': fnames, 'Importance': fi}).sort_values('Importance', ascending=False)
                shap_imp[mn] = idf
                Xdf = pd.DataFrame(Xs, columns=fnames)
                if sv.shape[1] > Xdf.shape[1]: sv = sv[:,:Xdf.shape[1]]
                elif sv.shape[1] < Xdf.shape[1]: sv = np.hstack([sv, np.zeros((sv.shape[0], Xdf.shape[1]-sv.shape[1]))])
                plt.figure(figsize=(12,8))
                shap.summary_plot(sv, Xdf, plot_type="dot", show=False, max_display=TOPN)
                plt.title(f'SHAP — {mn} (Top {TOPN})',fontsize=14,fontweight='bold'); plt.tight_layout()
                plt.savefig(os.path.join(rf,f'shap_{mn}.pdf'),format='pdf',bbox_inches='tight')
                plt.savefig(os.path.join(rf,f'shap_{mn}.png'),format='png',bbox_inches='tight',dpi=150)
                plt.close()
                log(f" ✅ {mn} Top3: {', '.join(idf.head(3)['Feature'].tolist())}")
            except Exception as e:
                # SHAP failures are non-fatal; the model simply skips ablation later.
                log(f" ⚠ {mn} SHAP失败: {e}")
        # ── Ablation ── CV AUC using only the top-1..top-N SHAP features.
        progress(0.75, desc="🧪 特征消融...")
        log(f"\n 🧪 特征消融研究...")
        ablr = {}
        for mn in retained:
            if mn not in shap_imp: continue
            tfs = shap_imp[mn].head(TOPN)['Feature'].tolist()
            fcs = []; aucs_a = []; asp = {}
            for nf in range(1, len(tfs)+1):
                Xsub = X[tfs[:nf]]
                fa = []; syt = []; syp = []
                for tri, tei in skf.split(Xsub, y):
                    mf = deepcopy(mcfg[mn]['model'])
                    bp2 = bpd.get(mn, {})
                    if isinstance(bp2, dict) and bp2: mf.set_params(**bp2)
                    mf.fit(Xsub.iloc[tri].values, y.iloc[tri])
                    yp2 = mf.predict_proba(Xsub.iloc[tei].values)[:,1]
                    syt.extend(y.iloc[tei]); syp.extend(yp2)
                    fa.append(roc_auc_score(y.iloc[tei], yp2))
                fcs.append(nf); aucs_a.append(np.mean(fa))
                asp[nf] = {'yt': np.array(syt), 'yp': np.array(syp)}
            # Smallest feature count whose AUC is NOT significantly below the
            # full-feature model (DeLong; heuristic fallback if the test fails).
            fp = amr[mn]['all_y_probs']; fauc = amr[mn]['mean_auc']; optn = None; adl = []
            for nf in range(1, len(tfs)+1):
                sd = asp[nf]; sa = aucs_a[nf-1]
                try:
                    pv = delong_roc_test(sd['yt'], fp, sd['yp'])[0] if len(fp)==len(sd['yp']) else (0.1 if abs(sa-fauc)<=0.05 else 0.01)
                except: pv = 0.1 if abs(sa-fauc)<=0.05 else 0.01
                sig = "Sig" if pv < ALP else "NS"
                adl.append({'N': nf, 'AUC': sa, 'Full_AUC': fauc, 'P': pv, 'Sig': sig})
                if optn is None and pv >= ALP: optn = nf
            ablr[mn] = {'fcs': fcs, 'aucs': aucs_a, 'tfs': tfs, 'dl': pd.DataFrame(adl),
                        'optn': optn or len(tfs), 'optf': tfs[:optn] if optn else tfs}
            log(f" {mn}: 最优 {ablr[mn]['optn']} 个特征")
        # Final model: the retained model that needs the fewest features.
        fcands = {}
        for mn in retained:
            if mn in ablr:
                ar = ablr[mn]
                fcands[mn] = {'nf': ar['optn'], 'feats': ar['optf'], 'auc': ar['aucs'][ar['optn']-1]}
        fmn = min(fcands, key=lambda x: fcands[x]['nf']) if fcands else None
        fmi = fcands.get(fmn) if fmn else None
        if fmn: log(f"\n ⭐ 最终模型: {fmn} ({fmi['nf']}特征, AUC={fmi['auc']:.4f})")
        # Ablation plot: AUC vs. feature count, star = chosen optimum.
        progress(0.80, desc="📈 消融曲线...")
        plt.figure(figsize=(12,8))
        for i, (mn, ar) in enumerate(ablr.items()):
            c = COLORS[i%8]
            plt.plot(ar['fcs'], ar['aucs'], marker='o', lw=2, ms=5, color=c, label=mn)
            on = ar['optn']; oa = ar['aucs'][on-1]
            plt.scatter([on],[oa], s=200, marker='*', color=c, edgecolors='black', lw=2, zorder=5)
        plt.xlabel('Number of Features',fontsize=13); plt.ylabel('AUC',fontsize=13)
        plt.title('Feature Ablation (★=Optimal)',fontsize=15,fontweight='bold')
        plt.legend(fontsize=11); plt.grid(True,alpha=0.2); plt.tight_layout()
        plt.savefig(os.path.join(rf,'ablation.pdf'),format='pdf',bbox_inches='tight')
        plt.savefig(os.path.join(rf,'ablation.png'),format='png',bbox_inches='tight',dpi=150)
        plt.close()
        # DCA — Internal (standard clinical format).
        progress(0.83, desc="📈 DCA曲线...")
        dca_probs = {mn: amr[mn]['all_y_probs'] for mn in retained}
        plot_dca(amr[retained[0]]['all_y_true'], dca_probs,
                 'Decision Curve Analysis — Internal CV', 'dca', rf, final_model=fmn)
        # ── External Validation (supports multiple validation sets) ──
        val_files_list = []
        for vf in [val_file1, val_file2, val_file3]:
            if vf is not None:
                val_files_list.append(vf)
        if val_files_list and fmn:
            progress(0.86, desc="🧪 外部验证...")
            log(f"\n{'━'*50}")
            log(f" 🧪 外部验证 ({len(val_files_list)} 个验证集)")
            for vi, vf in enumerate(val_files_list, 1):
                vp = vf if isinstance(vf, str) else getattr(vf, 'name', str(vf))
                ed = pd.read_csv(vp); Xe = ed.iloc[:,2:]; ye = ed.iloc[:,0]
                ule = sorted(ye.unique())
                if set(ule)!={0,1}: lme={ule[0]:0,ule[1]:1}; ye=ye.map(lme)
                log(f"\n 📊 验证集 {vi}: {Xe.shape[0]} 样本, {os.path.basename(vp)}")
                # Retrain the final model on the training set restricted to
                # its selected features, then score the external cohort.
                Xes = Xe[fmi['feats']]; Xtf = X[fmi['feats']]
                fm = deepcopy(mcfg[fmn]['model'])
                bp3 = bpd[fmn]
                if isinstance(bp3, dict) and bp3: fm.set_params(**bp3)
                fm.fit(Xtf.values, y)
                yep = fm.predict_proba(Xes.values)[:,1]; yed = (yep>0.5).astype(int)
                tn,fp,fn,tp = confusion_matrix(ye,yed).ravel()
                se=tp/(tp+fn) if tp+fn else 0; sp=tn/(tn+fp) if tn+fp else 0
                ac=(tp+tn)/(tp+tn+fp+fn); pr=tp/(tp+fp) if tp+fp else 0
                f1v=2*pr*se/(pr+se) if pr+se else 0; ea=roc_auc_score(ye,yep)
                log(f" ✅ AUC={ea:.4f} Acc={ac:.4f} Sens={se:.4f} Spec={sp:.4f} F1={f1v:.4f}")
                sfx = f'_ext{vi}' if len(val_files_list) > 1 else '_ext'
                tag = f'Validation {vi}' if len(val_files_list) > 1 else 'External'
                # ROC
                fe,te,_ = roc_curve(ye,yep)
                plt.figure(figsize=(10,8))
                plt.plot(fe,te,'#2563eb',lw=2.5,label=f'{fmn} (AUC={ea:.3f})')
                plt.plot([0,1],[0,1],'--',color='gray'); plt.xlabel('FPR'); plt.ylabel('TPR')
                plt.title(f'ROC — {tag} ({fmn})',fontweight='bold'); plt.legend(); plt.grid(True,alpha=0.2); plt.tight_layout()
                plt.savefig(os.path.join(rf,f'roc{sfx}.pdf'),format='pdf',bbox_inches='tight')
                plt.savefig(os.path.join(rf,f'roc{sfx}.png'),format='png',bbox_inches='tight',dpi=150)
                plt.close()
                # PR
                pe,re,_ = precision_recall_curve(ye,yep)
                plt.figure(figsize=(10,8))
                plt.plot(re,pe,'#2563eb',lw=2.5,label=fmn); plt.xlabel('Recall'); plt.ylabel('Precision')
                plt.title(f'PR — {tag} ({fmn})',fontweight='bold'); plt.legend(); plt.grid(True,alpha=0.2); plt.tight_layout()
                plt.savefig(os.path.join(rf,f'pr{sfx}.pdf'),format='pdf',bbox_inches='tight')
                plt.savefig(os.path.join(rf,f'pr{sfx}.png'),format='png',bbox_inches='tight',dpi=150)
                plt.close()
                # CM
                cme = confusion_matrix(ye,yed)
                plt.figure(figsize=(8,6))
                sns.heatmap(cme,annot=True,fmt='d',cmap='Blues',cbar=False,xticklabels=['Neg','Pos'],yticklabels=['Neg','Pos'])
                plt.xlabel('Predicted'); plt.ylabel('True')
                plt.title(f'CM — {tag} ({fmn})',fontweight='bold'); plt.tight_layout()
                plt.savefig(os.path.join(rf,f'cm{sfx}.pdf'),format='pdf',bbox_inches='tight')
                plt.savefig(os.path.join(rf,f'cm{sfx}.png'),format='png',bbox_inches='tight',dpi=150)
                plt.close()
                # DCA — standard clinical format.
                plot_dca(ye, {fmn: yep}, f'DCA — {tag} ({fmn})', f'dca{sfx}', rf)
                # Excel
                with pd.ExcelWriter(os.path.join(rf,f'validation{sfx}.xlsx'),engine='openpyxl') as w:
                    pd.DataFrame([{'Model':fmn,'N_Features':fmi['nf'],'AUC':ea,'Accuracy':ac,
                                   'Sensitivity':se,'Specificity':sp,'Precision':pr,'F1':f1v}]).to_excel(w,sheet_name='Metrics',index=False)
                    pd.DataFrame({'Feature':fmi['feats']}).to_excel(w,sheet_name='Features',index=False)
        # ── Save Excels ──
        progress(0.92, desc="💾 保存结果...")
        log(f"\n 💾 保存结果文件...")
        with pd.ExcelWriter(os.path.join(rf,'model_evaluation.xlsx'),engine='openpyxl') as w:
            # One sheet per model plus a sorted cross-model summary.
            for mn, r in amr.items(): r['results_df'].to_excel(w,sheet_name=mn,index=False)
            sd = []
            for mn, r in amr.items():
                rw = r['results_df'].iloc[-1].to_dict()
                rw.update({'Model':mn,'Retained':'Yes' if mn in retained else 'No','Final':'Yes' if mn==fmn else 'No'})
                sd.append(rw)
            sdf = pd.DataFrame(sd)
            cols = ['Model','Retained','Final']+[c for c in sdf.columns if c not in ['Model','Fold','Retained','Final']]
            sdf[cols].sort_values('AUC',ascending=False).to_excel(w,sheet_name='Summary',index=False)
            if len(dldf)>0: dldf.to_excel(w,sheet_name='DeLong',index=False)
        with pd.ExcelWriter(os.path.join(rf,'feature_ablation.xlsx'),engine='openpyxl') as w:
            for mn, ar in ablr.items():
                pd.DataFrame({'N':ar['fcs'],'AUC':ar['aucs']}).to_excel(w,sheet_name=mn,index=False)
                if 'dl' in ar: ar['dl'].to_excel(w,sheet_name=f'{mn}_DL',index=False)
            for mn, idf in shap_imp.items():
                idf.to_excel(w,sheet_name=f'{mn}_Imp',index=False)
        with open(os.path.join(rf,'best_params.txt'),'w',encoding='utf-8') as f:
            f.write("模型最佳超参数\n"+"="*50+"\n\n")
            for mn in mcfg:
                f.write(f"模型: {mn}\n")
                bp = bpd[mn]
                if isinstance(bp,dict):
                    for k,v in bp.items(): f.write(f" {k}: {v}\n")
                else: f.write(f" {bp}\n")
                f.write(f" AUC: {amr[mn]['mean_auc']:.4f}\n 保留: {'是' if mn in retained else '否'}\n\n")
            if fmn: f.write(f"\n最终模型: {fmn}\n特征({fmi['nf']}): {', '.join(fmi['feats'])}\n")
        if fmn:
            # Persist the deployable final model with its features and threshold.
            pickle.dump({'model_name':fmn,'model':tms[fmn]['model'],'best_params':bpd[fmn],
                         'features':fmi['feats'],'n_features':fmi['nf'],'auc':fmi['auc'],
                         'threshold':amr[fmn]['optimal_threshold']},
                        open(os.path.join(rf,f'model_{fmn}.pkl'),'wb'))
        # ── ZIP ──
        progress(0.97, desc="📦 打包ZIP...")
        # Unique file name avoids collisions between concurrent users.
        zp = os.path.join(tempfile.gettempdir(), f"ml_results_{int(time.time())}_{os.getpid()}.zip")
        with zipfile.ZipFile(zp,'w',zipfile.ZIP_DEFLATED) as zf:
            for root,_,files in os.walk(rf):
                for fn in files: zf.write(os.path.join(root,fn), os.path.relpath(os.path.join(root,fn),rf))
        nf = sum(len(f) for _,_,f in os.walk(rf))
        # Remove the temp result folder right away (the ZIP is already packed).
        shutil.rmtree(rf, ignore_errors=True)
        gc.collect()
        log(f"\n{'━'*50}")
        log(f" 🎉 分析完成!共 {nf} 个文件已打包")
        log(f" 💾 临时文件已自动清理")
        log(f"{'━'*50}")
        progress(1.0, desc="✅ 完成!")
        return zp, "\n".join(L)
    except Exception as e:
        log(f"\n❌ 错误: {e}")
        log(traceback.format_exc())
        # Clean up the temp folder on errors too.
        if os.path.exists(rf):
            shutil.rmtree(rf, ignore_errors=True)
        gc.collect()
        return None, "\n".join(L)
# ============================================================================
# Beautiful Gradio UI
# ============================================================================
# Custom stylesheet passed to gr.Blocks(css=...); the CSS text below is sent
# to the browser verbatim.
CUSTOM_CSS = """
/* ── Header Banner ── */
.header-banner {
  background: linear-gradient(135deg, #0a2463 0%, #1e3a7a 40%, #2554a8 100%);
  border-radius: 16px;
  padding: 28px 36px;
  margin-bottom: 20px;
  box-shadow: 0 8px 32px rgba(0,0,0,0.18);
  position: relative;
  overflow: hidden;
}
.header-banner::before {
  content: '';
  position: absolute;
  top: -50%;
  right: -20%;
  width: 400px;
  height: 400px;
  background: radial-gradient(circle, rgba(96,165,250,0.2) 0%, transparent 70%);
  border-radius: 50%;
}
.header-banner img {
  max-height: 52px;
  border-radius: 6px;
  margin-bottom: 12px;
}
.header-banner h1 {
  color: #e2e8f0 !important;
  font-size: 1.7em !important;
  margin: 4px 0 6px 0 !important;
  font-weight: 700 !important;
  letter-spacing: 0.5px;
}
.header-banner p {
  color: #94a3b8 !important;
  font-size: 0.92em !important;
  margin: 2px 0 !important;
  line-height: 1.6;
}
.header-banner .credit {
  color: #64748b !important;
  font-size: 0.82em !important;
  margin-top: 10px !important;
  border-top: 1px solid rgba(148,163,184,0.15);
  padding-top: 10px;
}
/* ── Section Cards ── */
.section-title {
  background: linear-gradient(90deg, #2563eb 0%, #3b82f6 100%);
  color: white !important;
  padding: 8px 16px;
  border-radius: 8px;
  font-size: 0.95em !important;
  font-weight: 600 !important;
  margin: 12px 0 8px 0;
  letter-spacing: 0.3px;
}
/* ── Pipeline Steps ── */
.pipeline-box {
  background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
  border: 1px solid #bae6fd;
  border-radius: 12px;
  padding: 14px 18px;
  margin: 8px 0;
  font-size: 0.88em;
}
.pipeline-box code {
  background: #2563eb;
  color: white;
  padding: 2px 8px;
  border-radius: 4px;
  font-size: 0.85em;
  margin: 0 2px;
}
/* ── Buttons ── */
.quick-btn {
  border-radius: 8px !important;
  font-weight: 500 !important;
  transition: all 0.2s ease !important;
}
.quick-btn:hover {
  transform: translateY(-1px) !important;
  box-shadow: 0 4px 12px rgba(0,0,0,0.1) !important;
}
/* ── Log Area ── */
.log-area textarea {
  font-family: 'Menlo', 'Consolas', 'Monaco', monospace !important;
  font-size: 12.5px !important;
  line-height: 1.5 !important;
  background: #0f172a !important;
  color: #e2e8f0 !important;
  border-radius: 10px !important;
  padding: 16px !important;
  border: 1px solid #1e293b !important;
}
/* ── General Polish ── */
.gradio-container {
  max-width: 1280px !important;
}
footer { display: none !important; }
"""
# ── UI definition ───────────────────────────────────────────────────────────
# Builds the whole Gradio app. NOTE: in gr.Blocks, widget creation order is
# render order, so the statement sequence below is load-bearing.
with gr.Blocks(
    title="ML 二分类模型平台 — 复旦大学附属眼耳鼻喉科医院",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate", neutral_hue="slate"),
    css=CUSTOM_CSS,
) as demo:
    # ── Header ──
    # Hospital banner; the logo <img> removes itself (onerror) if the asset 404s.
    gr.HTML("""
    <div class="header-banner">
        <img src="https://huggingface.co/spaces/fudan-renjun/machine-learning-2/resolve/main/hospital_logo.png"
             alt="Logo" onerror="this.style.display='none'"/>
        <h1>🧬 ML 二分类模型训练与评估平台</h1>
        <p>上传训练集与验证集 CSV,自动完成模型训练、交叉验证、统计检验、特征分析,结果打包下载</p>
        <p class="credit">复旦大学附属眼耳鼻喉科医院 · 检验科 · 任俊</p>
    </div>
    """)
    # ── Pipeline Info ──
    # Static description of the analysis stages and the expected CSV layout
    # (column 1 = label, column 2 = ID, columns 3+ = features).
    gr.HTML("""
    <div class="pipeline-box">
        <strong>📋 分析流程:</strong>
        <code>模型训练</code> → <code>交叉验证</code> → <code>DeLong检验</code> →
        <code>SHAP分析</code> → <code>特征消融</code> → <code>外部验证</code>
        |
        <strong>CSV格式:</strong> 第1列=标签, 第2列=ID, 第3列起=特征
    </div>
    """)
    with gr.Row(equal_height=False):
        # ══ Left Panel: data upload, model selection, run parameters ══
        with gr.Column(scale=5):
            gr.HTML('<div class="section-title">📂 数据上传</div>')
            # Training set is mandatory; up to 3 optional validation sets each
            # get an independent evaluation report in the pipeline.
            train_file = gr.File(label="训练集 CSV(必需)", file_types=[".csv"])
            gr.HTML('<p style="color:#64748b;font-size:0.85em;margin:4px 0 8px 0;">验证集为可选项,支持同时上传 1~3 个验证集,分别生成独立的评估报告</p>')
            with gr.Row():
                val_file1 = gr.File(label="验证集 1(可选)", file_types=[".csv"], scale=1)
                val_file2 = gr.File(label="验证集 2(可选)", file_types=[".csv"], scale=1)
                val_file3 = gr.File(label="验证集 3(可选)", file_types=[".csv"], scale=1)
            gr.HTML('<div class="section-title">🤖 模型选择</div>')
            # Multi-select of classifiers; defaults to every model in
            # ALL_MODEL_NAMES (defined earlier in this file).
            model_selector = gr.Dropdown(
                choices=ALL_MODEL_NAMES,
                value=ALL_MODEL_NAMES,
                multiselect=True,
                label="选择模型(可多选,默认全部)",
                info="RF=随机森林 DT=决策树 KNN=K近邻 XGB=极限梯度提升 AdaBoost=自适应提升 LR=逻辑回归 NB=朴素贝叶斯 SVM=支持向量机",
            )
            # Quick-select shortcuts: each click simply overwrites the
            # dropdown's value with a preset subset of model names.
            with gr.Row():
                btn_all = gr.Button("🔘 全选", size="sm", variant="secondary", elem_classes="quick-btn")
                btn_tree = gr.Button("🌲 树模型", size="sm", variant="secondary", elem_classes="quick-btn")
                btn_linear = gr.Button("📐 线性模型", size="sm", variant="secondary", elem_classes="quick-btn")
                btn_top4 = gr.Button("⚡ 经典四模型", size="sm", variant="secondary", elem_classes="quick-btn")
            btn_all.click(lambda: ALL_MODEL_NAMES, outputs=model_selector)
            btn_tree.click(lambda: ['RF','DT','XGB','AdaBoost'], outputs=model_selector)
            btn_linear.click(lambda: ['LR','SVM','NB'], outputs=model_selector)
            btn_top4.click(lambda: ['RF','XGB','LR','SVM'], outputs=model_selector)
            gr.HTML('<div class="section-title">⚙️ 参数配置</div>')
            # Pipeline knobs: grid-search toggle, CV folds, DeLong alpha,
            # and SHAP display/sampling sizes.
            enable_tuning = gr.Checkbox(value=False, label="启用超参数调优 (GridSearchCV) ⚠️ 开启后运行时间显著增加")
            with gr.Row():
                cv_folds = gr.Slider(3, 10, value=5, step=1, label="交叉验证折数")
                alpha_sl = gr.Slider(0.01, 0.10, value=0.05, step=0.01, label="DeLong 显著性水平 α")
            with gr.Row():
                top_n = gr.Slider(5, 50, value=20, step=1, label="SHAP 前 N 个特征")
                shap_sz = gr.Slider(30, 200, value=80, step=10, label="SHAP 采样数量")
            run_btn = gr.Button("🚀 开始分析", variant="primary", size="lg")
        # ══ Right Panel: live log output and result download ══
        with gr.Column(scale=5):
            gr.HTML('<div class="section-title">📋 运行日志</div>')
            log_output = gr.Textbox(
                label="", lines=22, max_lines=50, interactive=False,
                placeholder="点击「开始分析」后,运行日志将在此实时显示...",
                elem_classes="log-area",
            )
            gr.HTML('<div class="section-title">⬇️ 结果下载</div>')
            zip_output = gr.File(label="分析结果 ZIP 压缩包")
    # ── Connect ──
    # Wire the run button to the pipeline entry point (defined earlier in
    # this file); api_name="run" also exposes it via the Gradio API.
    run_btn.click(
        fn=run_pipeline,
        inputs=[train_file, val_file1, val_file2, val_file3, model_selector, enable_tuning, cv_folds, alpha_sl, top_n, shap_sz],
        outputs=[zip_output, log_output],
        api_name="run",
    )
| # ============================================================================ | |
| # Authentication with Expiration | |
| # ============================================================================ | |
| from datetime import datetime | |
# ┌─────────────────────────────────────────────────┐
# │ Account configuration — edit usernames,         │
# │ passwords and expiry dates here.                │
# │ Format: "username": {"password": "...",         │
# │                      "expires": "YYYY-MM-DD"}   │
# │ Set "expires": None for no expiration.          │
# └─────────────────────────────────────────────────┘
# NOTE(review): credentials live in plaintext in the source file — anyone
# with read access to the Space/repo can see them. Consider moving them to
# environment variables / HF Space secrets.
ACCOUNTS = {
    "admin": {
        "password": "admin123",
        "expires": None,  # never expires
    },
    "renjun": {
        "password": "fudan2025",
        "expires": "2026-12-31",  # expires December 31, 2026
    },
    "guest": {
        "password": "guest888",
        "expires": "2025-06-30",  # example: an already-expired account
    },
}
def auth_fn(username, password, accounts=None):
    """Validate a username/password pair and enforce account expiration.

    Fixes over the previous version:
      * The password check now uses ``hmac.compare_digest`` — a plain
        ``!=`` on attacker-controlled input leaks match-prefix timing.
      * Expiry is compared at date granularity, inclusive: an account with
        ``"expires": "2026-12-31"`` stays valid through the whole of that
        day, matching the intent of the config comments, instead of dying
        at that day's midnight (the old ``datetime.now() > exp_date``
        off-by-one).

    Args:
        username: Login name entered on the Gradio auth screen (untrusted).
        password: Plaintext password entered by the user (untrusted).
        accounts: Optional account mapping ``{name: {"password", "expires"}}``;
            defaults to the module-level ``ACCOUNTS`` (kept as a keyword
            parameter so Gradio's positional ``auth(username, password)``
            call is unaffected).

    Returns:
        bool: True when the login is accepted, False otherwise (bad user,
        bad password, expired account, or malformed expiry string).
    """
    import hmac
    from datetime import date, datetime

    if accounts is None:
        accounts = ACCOUNTS
    # Fail closed on anything that is not a plain string pair.
    if not isinstance(username, str) or not isinstance(password, str):
        return False
    user = accounts.get(username)
    if not user:
        return False
    # Constant-time comparison; encode so non-ASCII input cannot raise.
    if not hmac.compare_digest(user["password"].encode("utf-8"),
                               password.encode("utf-8")):
        return False
    if user["expires"] is not None:
        try:
            exp_date = datetime.strptime(user["expires"], "%Y-%m-%d").date()
            # Account is valid through the expiry date itself (inclusive).
            if date.today() > exp_date:
                return False
        except ValueError:
            # Malformed "expires" string → deny access rather than crash.
            return False
    return True
# Enable request queuing so long-running analyses are serialized instead of
# blocking concurrent users, then start the web server.
demo.queue()
demo.launch(
    server_name="0.0.0.0",  # bind all interfaces (needed inside containers/Spaces)
    server_port=7860,       # standard Gradio / HF Spaces port
    auth=auth_fn,           # credential + expiry check defined above
    auth_message="🔐 复旦大学附属眼耳鼻喉科医院 · ML分析平台\n请输入账号和密码登录",
    ssr_mode=False,         # presumably disabled to avoid SSR issues on Spaces — TODO confirm
)