Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import re | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
# Directory that contains this script; all data/result paths are resolved against it.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]


def load_json(path):
    """Read and decode a UTF-8 JSON file located relative to ``BASE_DIR``.

    Args:
        path: Path relative to the application directory (an absolute path is used as-is,
            because ``os.path.join`` discards ``BASE_DIR`` in that case).

    Returns:
        The decoded JSON value.
    """
    full_path = os.path.join(BASE_DIR, path)
    with open(full_path, encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload
# Characters removed from any text shown in the UI:
#   * supplementary-plane code points U+10000..U+10FFFF (emoji etc. — note this also drops
#     rare CJK Extension B+ ideographs),
#   * the U+FFF0..U+FFFF "Specials" block,
#   * zero-width, directional-formatting, line/paragraph-separator and BOM characters.
# Compiled once: the original rebuilt three patterns on every call.
_STRIP_CHARS_RE = re.compile(
    r'[\U00010000-\U0010ffff'
    r'\ufff0-\uffff'
    r'\u200b-\u200f\u2028-\u202f\u2060-\u206f\ufeff]'
)


def clean_text(text):
    """Return *text* as a display-safe string with invisible/astral characters removed.

    Args:
        text: Any value. ``None`` becomes ``''``; other non-strings are converted with
            ``str()`` and then cleaned like strings (previously they bypassed cleaning).

    Returns:
        The cleaned string with surrounding whitespace stripped.
    """
    if text is None:
        # Avoid rendering the literal "None" for missing JSON fields.
        return ''
    if not isinstance(text, str):
        text = str(text)
    return _STRIP_CHARS_RE.sub('', text).strip()
# ── Datasets, pre-computed experiment outputs and run configuration ──
# (paths are resolved against BASE_DIR by load_json; files are loaded in this order)
member_data, non_member_data = (
    load_json(rel_path) for rel_path in ("data/member.json", "data/non_member.json")
)
mia_results, utility_results, perturb_results, full_results = (
    load_json(rel_path)
    for rel_path in (
        "results/mia_results.json",
        "results/utility_results.json",
        "results/perturbation_results.json",
        "results/mia_full_results.json",
    )
)
config = load_json("config.json")

# Chart font (must cover Greek labels such as the ε in "LS(ε=0.02)"); unicode_minus=False
# renders negative tick labels with an ASCII hyphen instead of U+2212.
plt.rcParams.update({'font.sans-serif': ['DejaVu Sans'], 'axes.unicode_minus': False})
def _section(results, key):
    """Return ``results[key]`` or an empty dict when that experiment is absent."""
    return results.get(key, {})


def _loss_stats(res, m_mean, nm_mean, std=.03):
    """Return (member mean, non-member mean, member std, non-member std), with fallbacks."""
    return (
        res.get('member_loss_mean', m_mean),
        res.get('non_member_loss_mean', nm_mean),
        res.get('member_loss_std', std),
        res.get('non_member_loss_std', std),
    )


# Per-experiment result sections (empty dicts when missing from the JSON files).
bl, s002, s02 = (_section(mia_results, k) for k in ('baseline', 'smooth_0.02', 'smooth_0.2'))
p001, p0015, p002 = (
    _section(perturb_results, k)
    for k in ('perturbation_0.01', 'perturbation_0.015', 'perturbation_0.02')
)

# Attack AUCs (0 when unavailable).
bl_auc, s002_auc, s02_auc = (res.get('auc', 0) for res in (bl, s002, s02))
op001_auc, op0015_auc, op002_auc = (res.get('auc', 0) for res in (p001, p0015, p002))

# Task accuracy expressed in percent.
bl_acc, s002_acc, s02_acc = (
    _section(utility_results, k).get('accuracy', 0) * 100
    for k in ('baseline', 'smooth_0.02', 'smooth_0.2')
)

# Loss-distribution statistics; the literals are placeholder fallbacks.
bl_m_mean, bl_nm_mean, bl_m_std, bl_nm_std = _loss_stats(bl, .19, .23)
s002_m_mean, s002_nm_mean, s002_m_std, s002_nm_std = _loss_stats(s002, .20, .22)
s02_m_mean, s02_nm_mean, s02_m_std, s02_nm_std = _loss_stats(s02, .42, .44)

model_name = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct')

# Per trained model: loss statistics, display label and attack AUC (used by cb_attack).
MODEL_INFO = {
    "baseline": dict(m_mean=bl_m_mean, nm_mean=bl_nm_mean, m_std=bl_m_std, nm_std=bl_nm_std,
                     label="Baseline", auc=bl_auc),
    "smooth_0.02": dict(m_mean=s002_m_mean, nm_mean=s002_nm_mean, m_std=s002_m_std, nm_std=s002_nm_std,
                        label=u"LS(\u03b5=0.02)", auc=s002_auc),
    "smooth_0.2": dict(m_mean=s02_m_mean, nm_mean=s02_nm_mean, m_std=s02_m_std, nm_std=s02_nm_std,
                       label=u"LS(\u03b5=0.2)", auc=s02_auc),
}
# Task-type id -> Chinese display name.
TYPE_CN = {'calculation':'基础计算','word_problem':'应用题','concept':'概念问答','error_correction':'错题订正'}

# (concept, definition) pairs used by the 'concept' questions.
_CONCEPTS = [("面积","面积是平面图形所占平面的大小"),("周长","周长是封闭图形边线一周的总长度"),
             ("分数","分数表示整体等分后取若干份"),("小数","小数用小数点表示比1小的数"),("平均数","平均数是总和除以个数")]


def _make_eval_question(i, task, rng):
    """Build the (question, answer) pair for synthetic evaluation item *i* of type *task*.

    Random numbers are drawn from *rng* in exactly the same order as the original inline
    loop, so the generated pool is unchanged except for the fix below.
    """
    if task == 'calculation':
        a, b = int(rng.randint(10, 500)), int(rng.randint(10, 500))
        op = ['+', '-', 'x'][i % 3]
        if op == '+':
            return f"请计算: {a} + {b} = ?", str(a + b)
        if op == '-':
            return f"请计算: {a} - {b} = ?", str(a - b)
        return f"请计算: {a} x {b} = ?", str(a * b)
    if task == 'word_problem':
        a, b, c = int(rng.randint(5, 200)), int(rng.randint(3, 50)), int(rng.randint(5, 50))
        # Subtraction stories ("有a个,吃掉b个") must not go negative: take the larger
        # operand first. No extra random draws, so later items are unaffected.
        hi, lo = max(a, b), min(a, b)
        templates = [
            (f"小明有{hi}个苹果,吃掉{lo}个,还剩多少?", str(hi - lo)),
            (f"每组{a}人,共{b}组,总计多少人?", str(a * b)),
            (f"图书馆有{hi}本书,借出{lo}本又买了{c}本,现有多少?", str(hi - lo + c)),
            (f"商店有{hi}支笔,卖出{lo}支,还剩?", str(hi - lo)),
            (f"小红有{a}颗糖,小明给她{b}颗,现在多少?", str(a + b)),
        ]
        return templates[i % len(templates)]
    if task == 'concept':
        name, definition = _CONCEPTS[i % len(_CONCEPTS)]
        return f"请解释什么是{name}?", definition
    # 'error_correction': show a wrong sum and ask for the right one.
    a, b = int(rng.randint(10, 99)), int(rng.randint(10, 99))
    wrong = a + b + int(rng.choice([-1, 1, -10, 10]))
    return f"有同学算{a}+{b}={wrong},正确答案是?", str(a + b)


# 300 synthetic questions: 120 calculation, 90 word problems, 60 concepts, 30 corrections.
# Each item also records a simulated correct/incorrect flag per model, drawn with the
# model's measured accuracy. A private RandomState(777) yields the same sequence as the
# former np.random.seed(777) without reseeding NumPy's global RNG used by the callbacks.
_EVAL_TYPES = ['calculation']*120 + ['word_problem']*90 + ['concept']*60 + ['error_correction']*30
_eval_rng = np.random.RandomState(777)
EVAL_POOL = []
for _i, _t in enumerate(_EVAL_TYPES):
    _q, _ans = _make_eval_question(_i, _t, _eval_rng)
    EVAL_POOL.append({'question': _q, 'answer': _ans, 'type_cn': TYPE_CN[_t],
                      'baseline': bool(_eval_rng.random_sample() < bl_acc/100),
                      'smooth_0.02': bool(_eval_rng.random_sample() < s002_acc/100),
                      'smooth_0.2': bool(_eval_rng.random_sample() < s02_acc/100)})
# ══════════════ Charts ══════════════
def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
    """Draw a one-dimensional gauge showing where a sample's loss falls relative to ``thr``.

    The x-range runs from ``m_mean - 3*m_std`` to ``nm_mean + 3*nm_std``, widened so that
    ``loss_val`` is always visible. The region left of ``thr`` is shaded as the member zone
    and the region right of it as the non-member zone.

    Args:
        loss_val: Loss of the attacked sample (marked with a triangle).
        m_mean, nm_mean: Member / non-member mean loss (dashed guides).
        thr: Decision threshold; ``loss_val < thr`` is drawn in the member colour.
        m_std, nm_std: Standard deviations used to size the x-range.

    Returns:
        A matplotlib Figure that has already been released from pyplot's figure registry.
        This function runs on every attack request, and unclosed pyplot figures would
        otherwise accumulate for the lifetime of the server. The Figure object itself stays
        renderable (Gradio calls ``savefig`` on it).
    """
    fig, ax = plt.subplots(figsize=(9, 2.6))
    xaxis = ax.get_xaxis_transform()  # x in data coordinates, y in axes fraction
    xlo = min(m_mean - 3 * m_std, loss_val - .01)
    xhi = max(nm_mean + 3 * nm_std, loss_val + .01)
    ax.axvspan(xlo, thr, alpha=.08, color='#3b82f6')
    ax.axvspan(thr, xhi, alpha=.08, color='#ef4444')
    ax.axvline(thr, color='#1e293b', lw=2, zorder=3)
    ax.text(thr, 1.08, f'Threshold={thr:.4f}', ha='center', va='bottom', fontsize=8.5,
            fontweight='bold', color='#1e293b', transform=xaxis)
    ax.axvline(m_mean, color='#3b82f6', lw=1, ls='--', alpha=.4)
    ax.axvline(nm_mean, color='#ef4444', lw=1, ls='--', alpha=.4)
    mc = '#3b82f6' if loss_val < thr else '#ef4444'
    ax.plot(loss_val, .5, marker='v', ms=15, color=mc, zorder=5, transform=xaxis)
    ax.text(loss_val, .78, f'Loss={loss_val:.4f}', ha='center', fontsize=10, fontweight='bold', color=mc,
            transform=xaxis, bbox=dict(boxstyle='round,pad=.25', fc='white', ec=mc, alpha=.9))
    ax.text((xlo + thr) / 2, .42, 'Member Zone', ha='center', va='center', fontsize=9.5,
            color='#3b82f6', alpha=.35, fontweight='bold', transform=xaxis)
    ax.text((thr + xhi) / 2, .42, 'Non-Member Zone', ha='center', va='center', fontsize=9.5,
            color='#ef4444', alpha=.35, fontweight='bold', transform=xaxis)
    ax.set_xlim(xlo, xhi)
    ax.set_yticks([])
    for side in ['top', 'right', 'left']:
        ax.spines[side].set_visible(False)
    ax.set_xlabel('Loss Value', fontsize=9)
    # Operate on this figure explicitly rather than pyplot's implicit "current" figure.
    fig.tight_layout()
    plt.close(fig)
    return fig
def fig_loss_dist():
    """Plot member vs non-member loss histograms, one panel per trained model.

    Only models present in ``full_results`` are drawn. Each panel title carries the model's
    attack AUC from ``mia_results`` (0 when missing). An empty figure is returned when no
    model is available; previously ``plt.subplots(1, 0)`` raised at app start-up.
    """
    candidates = [('baseline', 'Baseline'), ('smooth_0.02', u'LS(\u03b5=0.02)'), ('smooth_0.2', u'LS(\u03b5=0.2)')]
    items = [(k, l, mia_results.get(k, {}).get('auc', 0)) for k, l in candidates if k in full_results]
    if not items:
        return plt.figure()
    n = len(items)
    # squeeze=False always yields a 2-D array of axes, so no special case for n == 1.
    fig, axes = plt.subplots(1, n, figsize=(6.5 * n, 5.2), squeeze=False)
    for ax, (k, l, a) in zip(axes[0], items):
        m = full_results[k]['member_losses']
        nm = full_results[k]['non_member_losses']
        # Shared bin edges so the two histograms are directly comparable.
        bins = np.linspace(min(min(m), min(nm)), max(max(m), max(nm)), 28)
        ax.hist(m, bins=bins, alpha=.5, color='#3b82f6', label='Member', density=True)
        ax.hist(nm, bins=bins, alpha=.5, color='#ef4444', label='Non-Member', density=True)
        ax.set_title(f'{l} | AUC={a:.4f}', fontsize=12, fontweight='bold')
        ax.set_xlabel('Loss', fontsize=10)
        ax.set_ylabel('Density', fontsize=10)
        ax.legend(fontsize=9)
        ax.grid(axis='y', alpha=.15)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
    fig.tight_layout()
    return fig
def fig_perturb_dist():
    """Show how output perturbation blurs the baseline loss distributions.

    For each sigma in (0.01, 0.015, 0.02), Gaussian noise N(0, sigma) is added to the
    baseline member / non-member losses and the noisy histograms are plotted side by side.
    The panel titles show the matching AUC from ``perturb_results``. Returns an empty
    figure when ``full_results`` has no baseline entry.
    """
    base = full_results.get('baseline', {})
    if not base:
        return plt.figure()
    ml = np.array(base['member_losses'])
    nl = np.array(base['non_member_losses'])
    fig, axes = plt.subplots(1, 3, figsize=(19.5, 5.2))
    for ax, s in zip(axes, [0.01, 0.015, 0.02]):
        # Private generators reproduce the draws of the old np.random.seed(42)/seed(43)
        # calls without resetting NumPy's global RNG, which the interactive callbacks use.
        mp = ml + np.random.RandomState(42).normal(0, s, len(ml))
        nmp = nl + np.random.RandomState(43).normal(0, s, len(nl))
        v = np.concatenate([mp, nmp])
        bins = np.linspace(v.min(), v.max(), 28)
        ax.hist(mp, bins=bins, alpha=.5, color='#3b82f6', label='Member+noise', density=True)
        ax.hist(nmp, bins=bins, alpha=.5, color='#ef4444', label='Non-Mem+noise', density=True)
        pa = perturb_results.get(f'perturbation_{s}', {}).get('auc', 0)
        ax.set_title(f'OP(s={s}) | AUC={pa:.4f}', fontsize=12, fontweight='bold')
        ax.set_xlabel('Loss', fontsize=10)
        ax.set_ylabel('Density', fontsize=10)
        ax.legend(fontsize=9)
        ax.grid(axis='y', alpha=.15)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
    fig.tight_layout()
    return fig
def fig_auc_bar():
    """Bar chart of MIA AUC for every available strategy, with the 0.5 random-guess line.

    Label-smoothing bars come from ``mia_results``; output-perturbation bars come from
    ``perturb_results``. Returns an empty figure when neither has any entry.
    """
    trained = [('baseline', 'Baseline', '#64748b'), (u'smooth_0.02', u'LS(\u03b5=0.02)', '#3b82f6'),
               ('smooth_0.2', u'LS(\u03b5=0.2)', '#1d4ed8')]
    perturbed = [('perturbation_0.01', 'OP(s=0.01)', '#10b981'), ('perturbation_0.015', 'OP(s=0.015)', '#059669'),
                 ('perturbation_0.02', 'OP(s=0.02)', '#047857')]
    data = [(n, mia_results[k]['auc'], c) for k, n, c in trained if k in mia_results]
    data += [(n, perturb_results[k]['auc'], c) for k, n, c in perturbed if k in perturb_results]
    if not data:
        # zip(*[]) cannot be unpacked into three names; avoid crashing at start-up.
        return plt.figure()
    fig, ax = plt.subplots(figsize=(11, 5.5))
    ns, vs, cs = zip(*data)
    bars = ax.bar(ns, vs, color=cs, width=.5, edgecolor='white', lw=1.5)
    for b, v in zip(bars, vs):
        ax.text(b.get_x() + b.get_width() / 2, b.get_height() + .002, f'{v:.4f}',
                ha='center', fontsize=10, fontweight='bold')
    ax.axhline(.5, color='#ef4444', ls='--', lw=1.5, alpha=.5, label='Random (0.5)')
    ax.set_ylabel('MIA AUC', fontsize=11)
    # Keep the usual 0.48 floor, but lower it when an AUC drops below it, so bars are
    # never hidden and the limits never invert.
    ax.set_ylim(min(.48, min(vs) - .02), max(vs) + .03)
    ax.legend(fontsize=9)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.grid(axis='y', alpha=.15)
    ax.tick_params(axis='x', labelsize=10)
    fig.tight_layout()
    return fig
def fig_acc_bar():
    """Bar chart of task accuracy (%) for every available strategy.

    Label-smoothing bars read ``utility_results``. Output-perturbation bars reuse the
    baseline accuracy ``bl_acc``, because perturbation does not change model parameters.
    Returns an empty figure when there is nothing to plot.
    """
    trained = [('baseline', 'Baseline', '#64748b'), ('smooth_0.02', u'LS(\u03b5=0.02)', '#3b82f6'),
               ('smooth_0.2', u'LS(\u03b5=0.2)', '#1d4ed8')]
    perturbed = [('perturbation_0.01', 'OP(s=0.01)', '#10b981'), ('perturbation_0.015', 'OP(s=0.015)', '#059669'),
                 ('perturbation_0.02', 'OP(s=0.02)', '#047857')]
    data = [(n, utility_results[k]['accuracy'] * 100, c) for k, n, c in trained if k in utility_results]
    data += [(n, bl_acc, c) for k, n, c in perturbed if k in perturb_results]
    if not data:
        # zip(*[]) cannot be unpacked into three names; avoid crashing at start-up.
        return plt.figure()
    fig, ax = plt.subplots(figsize=(11, 5.5))
    ns, vs, cs = zip(*data)
    bars = ax.bar(ns, vs, color=cs, width=.5, edgecolor='white', lw=1.5)
    for b, v in zip(bars, vs):
        ax.text(b.get_x() + b.get_width() / 2, v + .4, f'{v:.1f}%', ha='center', fontsize=10, fontweight='bold')
    ax.set_ylabel('Accuracy (%)', fontsize=11)
    ax.set_ylim(0, 100)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.grid(axis='y', alpha=.15)
    ax.tick_params(axis='x', labelsize=10)
    fig.tight_layout()
    return fig
def fig_tradeoff():
    """Scatter every strategy at (accuracy, MIA AUC) to visualise the privacy-utility trade-off.

    Output-perturbation points are placed at the baseline accuracy (0.633 when the baseline
    has no recorded accuracy), because perturbation leaves the model parameters unchanged.
    """
    fig, ax = plt.subplots(figsize=(9, 6.5))
    trained = [
        ('baseline', 'Baseline', 'o', '#64748b'),
        ('smooth_0.02', u'LS(\u03b5=0.02)', 's', '#3b82f6'),
        ('smooth_0.2', u'LS(\u03b5=0.2)', 's', '#1d4ed8'),
    ]
    perturbed = [
        ('perturbation_0.01', 'OP(s=0.01)', '^', '#10b981'),
        ('perturbation_0.015', 'OP(s=0.015)', 'D', '#059669'),
        ('perturbation_0.02', 'OP(s=0.02)', '^', '#047857'),
    ]
    points = [
        (label, utility_results[key]['accuracy'], mia_results[key]['auc'], marker, colour)
        for key, label, marker, colour in trained
        if key in mia_results and key in utility_results
    ]
    base_acc = utility_results.get('baseline', {}).get('accuracy', .633)
    points.extend(
        (label, base_acc, perturb_results[key]['auc'], marker, colour)
        for key, label, marker, colour in perturbed
        if key in perturb_results
    )
    for label, acc, auc, marker, colour in points:
        ax.scatter(acc, auc, label=label, marker=marker, color=colour, s=180,
                   edgecolors='white', lw=2, zorder=5)
    ax.axhline(.5, color='#cbd5e1', ls='--', alpha=.8, label='Random')
    ax.set_xlabel('Accuracy', fontsize=11, fontweight='bold')
    ax.set_ylabel('MIA AUC', fontsize=11, fontweight='bold')
    if points:
        accs = [p[1] for p in points]
        aucs = [p[2] for p in points]
        ax.set_xlim(min(accs) - .03, max(accs) + .05)
        ax.set_ylim(min(min(aucs), .5) - .02, max(aucs) + .025)
    ax.legend(fontsize=8, loc='upper right')
    ax.grid(True, alpha=.12)
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    plt.tight_layout()
    return fig
# ══════════════ Callbacks ══════════════
def _md_cell(value):
    """Render *value* for one Markdown table cell: cleaned, single-line, with ``|`` escaped."""
    text = clean_text(str(value))
    # A raw newline or pipe would terminate the cell/row and break the table layout.
    return " ".join(text.splitlines()).replace("|", "\\|")


def cb_sample(src):
    """Pick a random record from the selected data pool for the data-browser tab.

    Args:
        src: Radio label; "成员数据(训练集)" selects ``member_data``, anything else
            selects ``non_member_data``.

    Returns:
        (metadata Markdown table, question text, answer text). All three are empty strings
        when the selected pool is empty.
    """
    pool = member_data if src == "成员数据(训练集)" else non_member_data
    if not pool:
        # np.random.randint(0) raises ValueError; show nothing instead of erroring.
        return "", "", ""
    s = pool[np.random.randint(len(pool))]
    m = s['metadata']
    md = (
        "| 字段 | 值 |\n|---|---|\n"
        f"| 👤 姓名 | {_md_cell(m.get('name', ''))} |\n"
        f"| 🆔 学号 | {_md_cell(m.get('student_id', ''))} |\n"
        f"| 🏫 班级 | {_md_cell(m.get('class', ''))} |\n"
        f"| 💯 成绩 | {_md_cell(m.get('score', ''))} 分 |\n"
        # TYPE_CN is the module-wide task-type label table (was duplicated here).
        f"| 🔖 类型 | {TYPE_CN.get(s.get('task_type', ''), '')} |\n"
    )
    return md, clean_text(s.get('question', '')), clean_text(s.get('answer', ''))
# Attack-target radio label -> experiment key. "perturbation_*" keys are looked up in
# perturb_results, the others in MODEL_INFO / full_results (see cb_attack).
_ATK_LABELS = (
    u"基线模型 (Baseline)",
    u"标签平滑 (\u03b5=0.02)",
    u"标签平滑 (\u03b5=0.2)",
    u"输出扰动 (\u03c3=0.01)",
    u"输出扰动 (\u03c3=0.015)",
    u"输出扰动 (\u03c3=0.02)",
)
_ATK_KEYS = ("baseline", "smooth_0.02", "smooth_0.2",
             "perturbation_0.01", "perturbation_0.015", "perturbation_0.02")
ATK_MAP = dict(zip(_ATK_LABELS, _ATK_KEYS))
def cb_attack(idx, src, target):
    """Run the loss-threshold membership-inference demo on one sample.

    The sample's loss is read from ``full_results`` for the target model. For
    output-perturbation targets, the baseline loss plus N(0, sigma) noise is used. The loss
    is compared with the midpoint threshold ``(member_mean + non_member_mean) / 2``;
    ``loss < threshold`` means "predicted member".

    Args:
        idx: Sample index from the slider; clamped into ``[0, len(pool) - 1]``.
        src: Data-source radio label ("成员数据(训练集)" selects members).
        target: Attack-target radio label; unknown labels fall back to the baseline.

    Returns:
        (question Markdown, gauge Figure, verdict Markdown).
    """
    is_mem = src == "成员数据(训练集)"
    pool = member_data if is_mem else non_member_data
    idx = max(0, min(int(idx), len(pool) - 1))
    sample = pool[idx]
    key = ATK_MAP.get(target, "baseline")
    lk = 'member_losses' if is_mem else 'non_member_losses'
    if key.startswith("perturbation_"):
        sigma = float(key.split("_")[1])
        fr = full_results.get('baseline', {})
        if idx < len(fr.get(lk, [])):
            base_loss = fr[lk][idx]
        else:
            # No recorded loss for this index: synthesise one near the baseline mean.
            base_loss = float(np.random.normal(bl_m_mean if is_mem else bl_nm_mean, .02))
        # Deterministic noise per (sample, sigma). A private RandomState gives the same value
        # as the former np.random.seed(...) without resetting NumPy's global RNG, which
        # cb_sample/cb_eval rely on for their random picks.
        noise_rng = np.random.RandomState(idx * 1000 + int(sigma * 1000))
        loss = base_loss + noise_rng.normal(0, sigma)
        # Perturbation does not change the model, so baseline statistics apply.
        mm, nm, ms, ns = bl_m_mean, bl_nm_mean, bl_m_std, bl_nm_std
        auc_v = perturb_results.get(key, {}).get('auc', 0)
        lbl = u"OP(\u03c3=" + str(sigma) + ")"
    else:
        info = MODEL_INFO.get(key, MODEL_INFO['baseline'])
        fr = full_results.get(key, full_results.get('baseline', {}))
        if idx < len(fr.get(lk, [])):
            loss = fr[lk][idx]
        else:
            loss = float(np.random.normal(info['m_mean'] if is_mem else info['nm_mean'], .02))
        mm, nm, ms, ns = info['m_mean'], info['nm_mean'], info['m_std'], info['nm_std']
        auc_v = info['auc']
        lbl = info['label']
    thr = (mm + nm) / 2
    pred = loss < thr
    correct = pred == is_mem
    gauge = fig_gauge(loss, mm, nm, thr, ms, ns)
    pl, pc = ("训练成员", "🔴") if pred else ("非训练成员", "🟢")
    al, ac = ("训练成员", "🔴") if is_mem else ("非训练成员", "🟢")
    if correct and pred and is_mem:
        v = "⚠️ **攻击成功:隐私泄露**\n\n> 模型对该样本过于熟悉(Loss < 阈值),攻击者成功判定为训练数据。"
    elif correct:
        v = "✅ **判定正确**\n\n> 攻击者的判定与真实身份一致。"
    else:
        v = "🛡️ **防御成功**\n\n> 攻击者的判定错误,防御起到了保护作用。"
    res = (v + "\n\n**🎯 攻击目标**: " + lbl + " | **📊 AUC**: " + f"{auc_v:.4f}" + "\n\n"
           "| | 攻击者判定 | 真实身份 |\n|---|---|---|\n"
           "| 身份 | " + pc + " " + pl + " | " + ac + " " + al + " |\n"
           "| Loss | " + f"{loss:.4f}" + " | 阈值: " + f"{thr:.4f}" + " |\n")
    qtxt = "**📝 样本题号 #" + str(idx) + "**\n\n" + clean_text(sample.get('question', ''))[:500]
    return qtxt, gauge, res
# Utility-demo radio label -> (EVAL_POOL correctness key, accuracy % shown in the header).
# Output-perturbation options reuse the baseline entries: perturbation leaves the model
# parameters untouched.
_EVAL_OPTIONS = (
    (u"基线模型", "baseline", bl_acc),
    (u"标签平滑 (\u03b5=0.02)", "smooth_0.02", s002_acc),
    (u"标签平滑 (\u03b5=0.2)", "smooth_0.2", s02_acc),
    (u"输出扰动 (\u03c3=0.01)", "baseline", bl_acc),
    (u"输出扰动 (\u03c3=0.015)", "baseline", bl_acc),
    (u"输出扰动 (\u03c3=0.02)", "baseline", bl_acc),
)
EVAL_ACC = {label: acc for label, _key, acc in _EVAL_OPTIONS}
EVAL_KEY = {label: key for label, key, _acc in _EVAL_OPTIONS}


def cb_eval(model):
    """Draw a random EVAL_POOL question and report whether *model* answered it correctly.

    Args:
        model: Radio label of the model/strategy; unknown labels fall back to the baseline.

    Returns:
        A Markdown summary: model name and accuracy, then a table with the question type,
        question, reference answer and the simulated verdict.
    """
    key = EVAL_KEY.get(model, "baseline")
    acc = EVAL_ACC.get(model, bl_acc)
    item = EVAL_POOL[np.random.randint(len(EVAL_POOL))]
    answered_ok = item.get(key, item.get('baseline', False))
    verdict = "✅ 正确" if answered_ok else "❌ 错误"
    note = ""
    if u"\u03c3" in model:
        note = "\n\n> 💡 输出扰动不改变模型参数,准确率与基线一致。"
    header = "**🖥️ 模型**: " + model + " (准确率: " + f"{acc:.1f}%" + ")\n\n"
    table = "".join([
        "| 项目 | 内容 |\n|---|---|\n",
        "| 类型 | " + item['type_cn'] + " |\n",
        "| 题目 | " + item['question'] + " |\n",
        "| 正确答案 | " + item['answer'] + " |\n",
        "| 判定 | " + verdict + " |",
    ])
    return header + table + note
# ══════════════ UI styling (CSS) ══════════════
# Global stylesheet passed to gr.Blocks(css=CSS). It styles the grid page background, the
# title card (.title-area, used by the gr.HTML header), tab panels and tab navigation,
# Markdown content (.prose headings, tables, blockquotes), primary buttons and accordions,
# and hides the page footer.
# NOTE(review): `.wrap.svelte-182y6v9` / `.label.svelte-182y6v9` are compiler-generated
# class names of one specific Gradio build; they silently stop matching after a Gradio
# upgrade. Consider elem_classes on the gr.Accordion components instead — TODO confirm
# which Gradio version the Space pins.
# NOTE(review): `.tabitem` has a fixed 760px height with overflow-y:auto, so taller tab
# content scrolls inside the panel rather than extending the page.
CSS = """
/* 1. 带有十字准星网格的全局背景 */
body {
background-color: #f8fafc !important;
background-image:
linear-gradient(#e2e8f0 1px, transparent 1px),
linear-gradient(90deg, #e2e8f0 1px, transparent 1px) !important;
background-size: 20px 20px !important;
background-position: center center !important;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif !important;
}
.gradio-container {
max-width: 1200px !important;
margin: 40px auto !important;
}
/* 2. 科技感悬浮 Title 面板 */
.title-area {
background: #ffffff;
padding: 28px 40px;
border-radius: 12px;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05), 0 2px 4px -1px rgba(0, 0, 0, 0.03);
margin-bottom: 24px;
border-left: 6px solid #2563eb;
text-align: left;
}
.title-area h1 {
color: #0f172a !important;
font-size: 1.7rem !important;
font-weight: 800 !important;
margin: 0 0 8px 0 !important;
letter-spacing: -0.02em;
}
.title-area p {
color: #64748b !important;
font-size: 1rem !important;
margin: 0 !important;
font-weight: 500;
}
/* 3. 死死锁住所有标签页的大小,防止跳动 (760px 高度) */
.tabitem {
background: rgba(255, 255, 255, 0.98) !important;
border-radius: 0 0 12px 12px !important;
border: 1px solid #e2e8f0 !important;
border-top: none !important;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.05) !important;
padding: 32px 40px !important;
height: 760px !important;
max-height: 760px !important;
overflow-y: auto !important;
overflow-x: hidden !important;
}
/* 优雅的隐藏式滚动条 */
.tabitem::-webkit-scrollbar { width: 6px; }
.tabitem::-webkit-scrollbar-track { background: transparent; }
.tabitem::-webkit-scrollbar-thumb { background: #cbd5e1; border-radius: 10px; }
.tabitem::-webkit-scrollbar-thumb:hover { background: #94a3b8; }
/* 4. Tab 导航栏拟态设计 */
.tab-nav {
border-bottom: none !important;
gap: 4px !important;
padding: 0 10px !important;
}
.tab-nav button {
font-size: 15px !important;
padding: 12px 24px !important;
font-weight: 600 !important;
color: #64748b !important;
background: #e2e8f0 !important;
border: 1px solid #e2e8f0 !important;
border-bottom: none !important;
border-radius: 10px 10px 0 0 !important;
transition: all 0.2s ease !important;
}
.tab-nav button:hover {
color: #2563eb !important;
background: #ffffff !important;
}
.tab-nav button.selected {
color: #2563eb !important;
background: #ffffff !important;
border-top: 3px solid #2563eb !important;
z-index: 2 !important;
box-shadow: 0 -4px 6px -2px rgba(0,0,0,0.02) !important;
}
/* 5. 内部排版优化 */
.prose h2 {
font-size: 1.3rem !important;
color: #0f172a !important;
font-weight: 700 !important;
margin-top: 0 !important;
padding-bottom: 10px !important;
border-bottom: 1px solid #e2e8f0 !important;
}
.prose h3 {
font-size: 1.1rem !important;
color: #1e293b !important;
font-weight: 600 !important;
margin-top: 1.5em !important;
}
/* 6. 高级数据表格 */
.prose table {
width: 100% !important;
border-collapse: separate !important;
border-spacing: 0 !important;
border-radius: 8px !important;
overflow: hidden !important;
margin: 1.2em 0 !important;
border: 1px solid #e2e8f0 !important;
font-size: 0.9rem !important;
}
.prose th {
background: #f8fafc !important;
color: #475569 !important;
font-weight: 600 !important;
padding: 12px 16px !important;
border-bottom: 1px solid #e2e8f0 !important;
text-align: left !important;
}
.prose td {
padding: 12px 16px !important;
color: #1e293b !important;
border-bottom: 1px solid #f1f5f9 !important;
}
.prose tr:last-child td { border-bottom: none !important; }
.prose tr:hover td { background: #f0f9ff !important; }
/* 7. 纯色科技感按钮 */
button.primary {
background: #2563eb !important;
color: white !important;
border: none !important;
border-radius: 6px !important;
font-weight: 600 !important;
padding: 10px 24px !important;
box-shadow: 0 4px 6px -1px rgba(37, 99, 235, 0.2) !important;
transition: all 0.2s ease !important;
}
button.primary:hover {
background: #1d4ed8 !important;
transform: translateY(-1px) !important;
box-shadow: 0 6px 10px -1px rgba(37, 99, 235, 0.3) !important;
}
/* 8. 提示框 */
.prose blockquote {
border-left: 4px solid #3b82f6 !important;
background: #eff6ff !important;
padding: 16px 20px !important;
border-radius: 0 8px 8px 0 !important;
color: #1e40af !important;
font-size: 0.95rem !important;
margin: 1.5em 0 !important;
}
/* 9. Accordion 折叠面板美化 */
.wrap.svelte-182y6v9 {
border: 1px solid #e2e8f0 !important;
border-radius: 8px !important;
background: #f8fafc !important;
box-shadow: none !important;
}
.label.svelte-182y6v9 {
font-weight: 600 !important;
color: #0f172a !important;
}
footer { display: none !important; }
"""
# ══════════════ UI layout ══════════════
# Upper bound for the attack-tab sample slider, derived from the loaded data instead of a
# hard-coded 999. cb_attack still clamps the index to the size of the selected pool.
_MAX_SAMPLE_IDX = max(len(member_data), len(non_member_data), 1) - 1

with gr.Blocks(title="MIA攻防研究", theme=gr.themes.Base(), css=CSS) as demo:
    gr.HTML("""<div class="title-area">
<h1>🎓 教育大模型中的成员推理攻击及其防御研究</h1>
<p>Membership Inference Attack & Defense on Educational LLM Dashboard</p>
</div>""")
    # ═══════ Tab 1: Experiment overview (collapsible panels) ═══════
    with gr.Tab("📊 实验总览"):
        gr.Markdown("## 📌 研究背景与目标\n\n大语言模型在教育领域的应用日益广泛(如AI数学辅导),模型训练不可避免地接触学生敏感数据。**成员推理攻击 (MIA)** 可判断某条数据是否参与了训练,构成隐私威胁。\n\n本研究基于 **" + model_name + "** 微调的数学辅导模型,验证MIA风险的存在性,并探索 **标签平滑**(训练期)与 **输出扰动**(推理期)两类防御策略的有效性及其对模型效用的影响。")
        with gr.Accordion("📈 展开查看:实验核心指标", open=True):
            gr.Markdown(
                "| 🛡️ 策略配置 | 📊 AUC | 🎯 准确率 | 💡 说明 |\n|---|---|---|---|\n"
                "| **基线(无防御)** | **" + f"{bl_auc:.4f}" + "** | " + f"{bl_acc:.1f}%" + " | 攻击风险基准 |\n"
                "| " + u"LS(\u03b5=0.02)" + " | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 训练期防御 |\n"
                "| " + u"LS(\u03b5=0.2)" + " | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 训练期防御 |\n"
                "| " + u"OP(\u03c3=0.01)" + " | " + f"{op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 推理期防御 |\n"
                "| " + u"OP(\u03c3=0.015)" + " | " + f"{op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 推理期防御 |\n"
                "| " + u"OP(\u03c3=0.02)" + " | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 推理期防御 |\n\n"
                "> 💡 **指标提示**: AUC越接近0.5 = 防御越有效;准确率越高 = 模型效用越好。"
            )
        with gr.Accordion("🚀 展开查看:实验流程规划", open=False):
            gr.Markdown(
                "| 阶段 | 内容 | 方法 |\n|---|---|---|\n"
                "| 1. 数据准备 | 2000条数学辅导对话 | 模板化生成,含姓名/学号/成绩 |\n"
                "| 2. 基线训练 | " + model_name + " + LoRA | 标准微调(r=8, alpha=16, 10 epochs) |\n"
                "| 3. 防御训练 | " + u"\u03b5=0.02 / \u03b5=0.2" + " | 两组标签平滑参数分别训练 |\n"
                "| 4. 攻击测试 | 3个模型 + 3组扰动 | Loss阈值判定,AUC评估 |\n"
                "| 5. 效用评估 | 300道数学题 | 6种配置分别测试准确率 |\n"
                "| 6. 综合分析 | 隐私-效用权衡 | 定量对比与可视化 |\n"
            )
    # ═══════ Tab 2: Data and model ═══════
    with gr.Tab("📁 数据与模型"):
        gr.Markdown("## 📦 实验数据集概况")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(
                    "| 数据组 | 数量 | 用途 | 说明 |\n|---|---|---|---|\n"
                    "| 🔴 成员数据 | 1000条 | 模型训练 | 模型会\"记住\",Loss偏低 |\n"
                    "| 🟢 非成员数据 | 1000条 | 攻击对照 | 模型\"没见过\",Loss偏高 |\n\n"
                    "> ⚠️ 两组数据格式完全相同(均含隐私字段),这是MIA实验的标准设置——攻击者无法从格式区分。"
                )
            with gr.Column(scale=1):
                gr.Markdown(
                    "| 任务类别 | 数量 | 占比 |\n|---|---|---|\n"
                    "| 🧮 基础计算 | 800 | 40% |\n| 📝 应用题 | 600 | 30% |\n| 🧠 概念问答 | 400 | 20% |\n| ✍️ 错题订正 | 200 | 10% |\n"
                )
        gr.Markdown("### 🔍 数据样例浏览提取")
        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                d_src = gr.Radio(["成员数据(训练集)","非成员数据(测试集)"], value="成员数据(训练集)", label="选择靶向数据来源")
                d_btn = gr.Button("🎲 随机提取样本", variant="primary")
                d_meta = gr.Markdown()
            with gr.Column(scale=3):
                d_q = gr.Textbox(label="🧑🎓 学生提问 (Prompt)", lines=4, interactive=False)
                d_a = gr.Textbox(label="🤖 标准回答 (Ground Truth)", lines=4, interactive=False)
        d_btn.click(cb_sample, [d_src], [d_meta, d_q, d_a])
    # ═══════ Tab 3: Attack demonstration ═══════
    with gr.Tab("🎯 攻击验证"):
        gr.Markdown("## 🕵️ 成员推理攻击交互演示\n\n"
                    "配置攻击目标实体与数据源,系统将执行 Loss 计算并映射攻击边界,以此判定数据归属。")
        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                a_target = gr.Radio([u"基线模型 (Baseline)",u"标签平滑 (\u03b5=0.02)",u"标签平滑 (\u03b5=0.2)",
                                     u"输出扰动 (\u03c3=0.01)",u"输出扰动 (\u03c3=0.015)",u"输出扰动 (\u03c3=0.02)"],
                                    value=u"基线模型 (Baseline)", label="选择攻击目标")
                a_src = gr.Radio(["成员数据(训练集)","非成员数据(测试集)"], value="成员数据(训练集)", label="数据来源")
                # Range follows the loaded data; default 12 is kept when the data is large enough.
                a_idx = gr.Slider(0, _MAX_SAMPLE_IDX, step=1, value=min(12, _MAX_SAMPLE_IDX), label="定位样本 ID")
                a_btn = gr.Button("⚡ 执行成员推理攻击", variant="primary", size="lg")
                a_qtxt = gr.Markdown()
            with gr.Column(scale=3):
                a_gauge = gr.Plot(label="Loss位置判定 (Decision Boundary)")
                a_res = gr.Markdown()
        a_btn.click(cb_attack, [a_idx, a_src, a_target], [a_qtxt, a_gauge, a_res])
    # ═══════ Tab 4: Defense analysis ═══════
    with gr.Tab("🛡️ 防御分析"):
        with gr.Accordion("📊 展开查看:防御对比直方图", open=False):
            gr.Markdown("### MIA攻击AUC对比\n\n> 柱子越矮 = AUC越低 = 攻击越难成功 = 防御越有效")
            gr.Plot(value=fig_auc_bar())
            gr.Markdown("### Loss分布对比\n#### 三个模型(训练期防御效果)\n\n> 蓝色=成员,红色=非成员。两色重叠越多 = 攻击者越难区分")
            gr.Plot(value=fig_loss_dist())
            gr.Markdown("#### 输出扰动效果(推理期防御)\n\n> 在基线模型Loss上加噪声,随噪声增大分布更加重叠")
            gr.Plot(value=fig_perturb_dist())
        with gr.Accordion("⚙️ 展开查看:完整实验数据与机制说明", open=True):
            gr.Markdown(
                "### 完整实验数据表\n\n"
                "| 策略 | 类型 | AUC | 准确率 | AUC变化 |\n|---|---|---|---|---|\n"
                "| 基线 | — | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | — |\n"
                "| " + u"LS(\u03b5=0.02)" + " | 训练期 | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | " + f"{s002_auc-bl_auc:+.4f}" + " |\n"
                "| " + u"LS(\u03b5=0.2)" + " | 训练期 | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | " + f"{s02_auc-bl_auc:+.4f}" + " |\n"
                "| " + u"OP(\u03c3=0.01)" + " | 推理期 | " + f"{op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op001_auc-bl_auc:+.4f}" + " |\n"
                "| " + u"OP(\u03c3=0.015)" + " | 推理期 | " + f"{op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op0015_auc-bl_auc:+.4f}" + " |\n"
                "| " + u"OP(\u03c3=0.02)" + " | 推理期 | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op002_auc-bl_auc:+.4f}" + " |\n\n"
                "### 防御机制对比\n\n"
                "| 维度 | 标签平滑 | 输出扰动 |\n|---|---|---|\n"
                "| **阶段** | 训练期 | 推理期 |\n"
                "| **原理** | 软化标签降低记忆 | Loss加噪遮蔽信号 |\n"
                "| **需重训** | 是 | 否 |\n"
                "| **效用** | 取决于参数 | 无 |\n"
                "| **部署** | 训练时介入 | 即插即用 |\n\n"
                "**标签平滑公式**: `y_smooth = (1 - ε) * y_onehot + ε / V`\n\n"
                "**输出扰动公式**: `L_perturbed = L_original + N(0, σ²)`\n")
        with gr.Accordion("🖼️ 展开查看:静态高分辨率学术图表", open=False):
            # Pre-rendered figures are optional; missing files are skipped.
            for fn, cap in [("fig1_loss_distribution_comparison.png","Loss分布对比"),
                            ("fig2_privacy_utility_tradeoff_fixed.png","隐私-效用权衡"),
                            ("fig3_defense_comparison_bar.png","防御策略AUC对比")]:
                p = os.path.join(BASE_DIR, "figures", fn)
                if os.path.exists(p):
                    gr.Markdown("#### " + cap)
                    gr.Image(value=p, show_label=False, height=420)
    # ═══════ Tab 5: Utility evaluation ═══════
    with gr.Tab("⚖️ 效用评估"):
        gr.Markdown("## 🎯 模型效用测试\n\n> 基于300道数学测试题评估各策略对模型实际能力的影响")
        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Plot(value=fig_acc_bar())
            with gr.Column():
                gr.Plot(value=fig_tradeoff())
        gr.Markdown("## 🎮 在线效用抽样演示\n\n从测试题库中随机抽取,流式验证不同模型/策略的保留作答情况。")
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                e_model = gr.Radio([u"基线模型",u"标签平滑 (\u03b5=0.02)",u"标签平滑 (\u03b5=0.2)",
                                    u"输出扰动 (\u03c3=0.01)",u"输出扰动 (\u03c3=0.015)",u"输出扰动 (\u03c3=0.02)"], value=u"基线模型", label="选择验证模型")
                e_btn = gr.Button("🧪 随机抽题测试", variant="primary")
            with gr.Column(scale=2):
                e_res = gr.Markdown()
        e_btn.click(cb_eval, [e_model], [e_res])
    # ═══════ Tab 6: Conclusions (collapsible panels) ═══════
    with gr.Tab("📝 研究结论"):
        gr.Markdown("## 💡 核心研究发现与最佳实践\n\n---")
        with gr.Accordion("🚨 一、教育大模型存在可量化的MIA风险", open=True):
            gr.Markdown("基线模型 AUC = **" + f"{bl_auc:.4f}" + "** > 0.5,成员平均Loss (" + f"{bl_m_mean:.4f}"
                        + ") 显著小于 非成员 (" + f"{bl_nm_mean:.4f}" + "),实验铁证表明模型对训练数据存在可利用的记忆效应。")
        with gr.Accordion("🛡️ 二、标签平滑(训练期防御)有效性验证", open=True):
            gr.Markdown(
                "| 参数 | AUC | 准确率 | 分析 |\n|---|---|---|---|\n"
                "| 基线 | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 无防御 |\n"
                "| " + u"\u03b5=0.02" + " | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 正则化提升泛化 |\n"
                "| " + u"\u03b5=0.2" + " | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 防御增强 |\n"
            )
        with gr.Accordion("🎛️ 三、输出扰动(推理期防御)有效性验证", open=True):
            gr.Markdown(
                "| 参数 | AUC | AUC降幅 | 准确率 |\n|---|---|---|---|\n"
                "| " + u"\u03c3=0.01" + " | " + f"{op001_auc:.4f}" + " | " + f"{bl_auc-op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n"
                "| " + u"\u03c3=0.015" + " | " + f"{op0015_auc:.4f}" + " | " + f"{bl_auc-op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n"
                "| " + u"\u03c3=0.02" + " | " + f"{op002_auc:.4f}" + " | " + f"{bl_auc-op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n\n"
                "**结论:零效用损失,适合已部署系统的后期加固。**"
            )
        with gr.Accordion("⚖️ 四、隐私-效用权衡总结", open=False):
            gr.Markdown(
                "| 策略 | AUC | 准确率 | 隐私 | 效用 |\n|---|---|---|---|---|\n"
                "| 基线 | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 风险最高 | 基准 |\n"
                "| " + u"LS(\u03b5=0.02)" + " | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 降低 | 提升 |\n"
                "| " + u"LS(\u03b5=0.2)" + " | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 显著降低 | 可接受 |\n"
                "| " + u"OP(\u03c3=0.02)" + " | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 显著降低 | 不变 |\n\n"
                "> 两类策略机制互补:标签平滑从训练阶段降低记忆,输出扰动从推理阶段遮蔽信号。建议组合使用以构建立体防御体系。"
            )
    # Bottom spacer below the tabs.
    gr.HTML("<div style='text-align:center;color:#94a3b8;font-size:.82rem;padding:16px 0 8px'></div>")

# Start the server only when run as a script, so importing this module (e.g. for tests
# or tooling) does not launch it.
if __name__ == "__main__":
    demo.launch()