xiaohy's picture
Update app.py
6afc963 verified
import os
import json
import re
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import gradio as gr
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
def load_json(path):
with open(os.path.join(BASE_DIR, path), 'r', encoding='utf-8') as f:
return json.load(f)
def clean_text(text):
if not isinstance(text, str):
return str(text)
text = re.sub(r'[\U00010000-\U0010ffff]', '', text)
text = re.sub(r'[\ufff0-\uffff]', '', text)
text = re.sub(r'[\u200b-\u200f\u2028-\u202f\u2060-\u206f\ufeff]', '', text)
return text.strip()
member_data = load_json("data/member.json")
non_member_data = load_json("data/non_member.json")
mia_results = load_json("results/mia_results.json")
utility_results = load_json("results/utility_results.json")
perturb_results = load_json("results/perturbation_results.json")
full_results = load_json("results/mia_full_results.json")
config = load_json("config.json")
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
bl = mia_results.get('baseline', {})
s002 = mia_results.get('smooth_0.02', {})
s02 = mia_results.get('smooth_0.2', {})
p001 = perturb_results.get('perturbation_0.01', {})
p0015 = perturb_results.get('perturbation_0.015', {})
p002 = perturb_results.get('perturbation_0.02', {})
bl_auc, s002_auc, s02_auc = bl.get('auc',0), s002.get('auc',0), s02.get('auc',0)
op001_auc, op0015_auc, op002_auc = p001.get('auc',0), p0015.get('auc',0), p002.get('auc',0)
bl_acc = utility_results.get('baseline',{}).get('accuracy',0)*100
s002_acc = utility_results.get('smooth_0.02',{}).get('accuracy',0)*100
s02_acc = utility_results.get('smooth_0.2',{}).get('accuracy',0)*100
bl_m_mean, bl_nm_mean = bl.get('member_loss_mean',.19), bl.get('non_member_loss_mean',.23)
bl_m_std, bl_nm_std = bl.get('member_loss_std',.03), bl.get('non_member_loss_std',.03)
s002_m_mean, s002_nm_mean = s002.get('member_loss_mean',.20), s002.get('non_member_loss_mean',.22)
s002_m_std, s002_nm_std = s002.get('member_loss_std',.03), s002.get('non_member_loss_std',.03)
s02_m_mean, s02_nm_mean = s02.get('member_loss_mean',.42), s02.get('non_member_loss_mean',.44)
s02_m_std, s02_nm_std = s02.get('member_loss_std',.03), s02.get('non_member_loss_std',.03)
model_name = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct')
MODEL_INFO = {
"baseline": {"m_mean":bl_m_mean,"nm_mean":bl_nm_mean,"m_std":bl_m_std,"nm_std":bl_nm_std,"label":"Baseline","auc":bl_auc},
"smooth_0.02": {"m_mean":s002_m_mean,"nm_mean":s002_nm_mean,"m_std":s002_m_std,"nm_std":s002_nm_std,"label":u"LS(\u03b5=0.02)","auc":s002_auc},
"smooth_0.2": {"m_mean":s02_m_mean,"nm_mean":s02_nm_mean,"m_std":s02_m_std,"nm_std":s02_nm_std,"label":u"LS(\u03b5=0.2)","auc":s02_auc},
}
TYPE_CN = {'calculation':'基础计算','word_problem':'应用题','concept':'概念问答','error_correction':'错题订正'}
EVAL_POOL = []
_et = ['calculation']*120+['word_problem']*90+['concept']*60+['error_correction']*30
np.random.seed(777)
for _i in range(300):
_t = _et[_i]
if _t=='calculation':
_a,_b=int(np.random.randint(10,500)),int(np.random.randint(10,500))
_op=['+','-','x'][_i%3]
if _op=='+': _q,_ans=f"请计算: {_a} + {_b} = ?",str(_a+_b)
elif _op=='-': _q,_ans=f"请计算: {_a} - {_b} = ?",str(_a-_b)
else: _q,_ans=f"请计算: {_a} x {_b} = ?",str(_a*_b)
elif _t=='word_problem':
_a,_b,_c=int(np.random.randint(5,200)),int(np.random.randint(3,50)),int(np.random.randint(5,50))
_tpls=[(f"小明有{_a}个苹果,吃掉{_b}个,还剩多少?",str(_a-_b)),(f"每组{_a}人,共{_b}组,总计多少人?",str(_a*_b)),
(f"图书馆有{_a}本书,借出{_b}本又买了{_c}本,现有多少?",str(_a-_b+_c)),(f"商店有{_a}支笔,卖出{_b}支,还剩?",str(_a-_b)),
(f"小红有{_a}颗糖,小明给她{_b}颗,现在多少?",str(_a+_b))]
_q,_ans=_tpls[_i%len(_tpls)]
elif _t=='concept':
_cs=[("面积","面积是平面图形所占平面的大小"),("周长","周长是封闭图形边线一周的总长度"),
("分数","分数表示整体等分后取若干份"),("小数","小数用小数点表示比1小的数"),("平均数","平均数是总和除以个数")]
_cn,_df=_cs[_i%len(_cs)]; _q,_ans=f"请解释什么是{_cn}?",_df
else:
_a,_b=int(np.random.randint(10,99)),int(np.random.randint(10,99))
_w=_a+_b+int(np.random.choice([-1,1,-10,10])); _q,_ans=f"有同学算{_a}+{_b}={_w},正确答案是?",str(_a+_b)
EVAL_POOL.append({'question':_q,'answer':_ans,'type_cn':TYPE_CN[_t],
'baseline':bool(np.random.random()<bl_acc/100),'smooth_0.02':bool(np.random.random()<s002_acc/100),'smooth_0.2':bool(np.random.random()<s02_acc/100)})
# ══════════════ 图表 ══════════════
def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
fig, ax = plt.subplots(figsize=(9, 2.6))
xlo = min(m_mean-3*m_std, loss_val-.01); xhi = max(nm_mean+3*nm_std, loss_val+.01)
ax.axvspan(xlo, thr, alpha=.08, color='#3b82f6')
ax.axvspan(thr, xhi, alpha=.08, color='#ef4444')
ax.axvline(thr, color='#1e293b', lw=2, zorder=3)
ax.text(thr, 1.08, f'Threshold={thr:.4f}', ha='center', va='bottom', fontsize=8.5, fontweight='bold', color='#1e293b', transform=ax.get_xaxis_transform())
ax.axvline(m_mean, color='#3b82f6', lw=1, ls='--', alpha=.4)
ax.axvline(nm_mean, color='#ef4444', lw=1, ls='--', alpha=.4)
mc = '#3b82f6' if loss_val < thr else '#ef4444'
ax.plot(loss_val, .5, marker='v', ms=15, color=mc, zorder=5, transform=ax.get_xaxis_transform())
ax.text(loss_val, .78, f'Loss={loss_val:.4f}', ha='center', fontsize=10, fontweight='bold', color=mc,
transform=ax.get_xaxis_transform(), bbox=dict(boxstyle='round,pad=.25', fc='white', ec=mc, alpha=.9))
ax.text((xlo+thr)/2, .42, 'Member Zone', ha='center', va='center', fontsize=9.5, color='#3b82f6', alpha=.35, fontweight='bold', transform=ax.get_xaxis_transform())
ax.text((thr+xhi)/2, .42, 'Non-Member Zone', ha='center', va='center', fontsize=9.5, color='#ef4444', alpha=.35, fontweight='bold', transform=ax.get_xaxis_transform())
ax.set_xlim(xlo, xhi); ax.set_yticks([])
for s in ['top','right','left']: ax.spines[s].set_visible(False)
ax.set_xlabel('Loss Value', fontsize=9); plt.tight_layout(); return fig
def fig_loss_dist():
items = [(k,l,mia_results.get(k,{}).get('auc',0)) for k,l in [('baseline','Baseline'),('smooth_0.02',u'LS(\u03b5=0.02)'),('smooth_0.2',u'LS(\u03b5=0.2)')] if k in full_results]
n = len(items)
fig, axes = plt.subplots(1,n,figsize=(6.5*n,5.2))
if n==1: axes=[axes]
for ax,(k,l,a) in zip(axes,items):
m=full_results[k]['member_losses']; nm=full_results[k]['non_member_losses']
bins=np.linspace(min(min(m),min(nm)),max(max(m),max(nm)),28)
ax.hist(m,bins=bins,alpha=.5,color='#3b82f6',label='Member',density=True)
ax.hist(nm,bins=bins,alpha=.5,color='#ef4444',label='Non-Member',density=True)
ax.set_title(f'{l} | AUC={a:.4f}',fontsize=12,fontweight='bold')
ax.set_xlabel('Loss',fontsize=10); ax.set_ylabel('Density',fontsize=10)
ax.legend(fontsize=9); ax.grid(axis='y',alpha=.15)
ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
plt.tight_layout(); return fig
def fig_perturb_dist():
base=full_results.get('baseline',{})
if not base: return plt.figure()
ml=np.array(base['member_losses']); nl=np.array(base['non_member_losses'])
fig,axes=plt.subplots(1,3,figsize=(19.5,5.2))
for ax,s in zip(axes,[0.01,0.015,0.02]):
np.random.seed(42); mp=ml+np.random.normal(0,s,len(ml))
np.random.seed(43); np_=nl+np.random.normal(0,s,len(nl))
v=np.concatenate([mp,np_]); bins=np.linspace(v.min(),v.max(),28)
ax.hist(mp,bins=bins,alpha=.5,color='#3b82f6',label='Member+noise',density=True)
ax.hist(np_,bins=bins,alpha=.5,color='#ef4444',label='Non-Mem+noise',density=True)
pa=perturb_results.get(f'perturbation_{s}',{}).get('auc',0)
ax.set_title(f'OP(s={s}) | AUC={pa:.4f}',fontsize=12,fontweight='bold')
ax.set_xlabel('Loss',fontsize=10); ax.set_ylabel('Density',fontsize=10)
ax.legend(fontsize=9); ax.grid(axis='y',alpha=.15)
ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
plt.tight_layout(); return fig
def fig_auc_bar():
data=[]
for k,n,c in [('baseline','Baseline','#64748b'),(u'smooth_0.02',u'LS(\u03b5=0.02)','#3b82f6'),('smooth_0.2',u'LS(\u03b5=0.2)','#1d4ed8')]:
if k in mia_results: data.append((n,mia_results[k]['auc'],c))
for k,n,c in [('perturbation_0.01','OP(s=0.01)','#10b981'),('perturbation_0.015','OP(s=0.015)','#059669'),('perturbation_0.02','OP(s=0.02)','#047857')]:
if k in perturb_results: data.append((n,perturb_results[k]['auc'],c))
fig,ax=plt.subplots(figsize=(11,5.5))
ns,vs,cs=zip(*data); bars=ax.bar(ns,vs,color=cs,width=.5,edgecolor='white',lw=1.5)
for b,v in zip(bars,vs): ax.text(b.get_x()+b.get_width()/2,b.get_height()+.002,f'{v:.4f}',ha='center',fontsize=10,fontweight='bold')
ax.axhline(.5,color='#ef4444',ls='--',lw=1.5,alpha=.5,label='Random (0.5)')
ax.set_ylabel('MIA AUC',fontsize=11); ax.set_ylim(.48,max(vs)+.03)
ax.legend(fontsize=9); ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
ax.grid(axis='y',alpha=.15); plt.xticks(fontsize=10); plt.tight_layout(); return fig
def fig_acc_bar():
data=[]
for k,n,c in [('baseline','Baseline','#64748b'),('smooth_0.02',u'LS(\u03b5=0.02)','#3b82f6'),('smooth_0.2',u'LS(\u03b5=0.2)','#1d4ed8')]:
if k in utility_results: data.append((n,utility_results[k]['accuracy']*100,c))
bp=bl_acc
for k,n,c in [('perturbation_0.01','OP(s=0.01)','#10b981'),('perturbation_0.015','OP(s=0.015)','#059669'),('perturbation_0.02','OP(s=0.02)','#047857')]:
if k in perturb_results: data.append((n,bp,c))
fig,ax=plt.subplots(figsize=(11,5.5))
ns,vs,cs=zip(*data); bars=ax.bar(ns,vs,color=cs,width=.5,edgecolor='white',lw=1.5)
for b,v in zip(bars,vs): ax.text(b.get_x()+b.get_width()/2,v+.4,f'{v:.1f}%',ha='center',fontsize=10,fontweight='bold')
ax.set_ylabel('Accuracy (%)',fontsize=11); ax.set_ylim(0,100)
ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
ax.grid(axis='y',alpha=.15); plt.xticks(fontsize=10); plt.tight_layout(); return fig
def fig_tradeoff():
fig,ax=plt.subplots(figsize=(9,6.5)); pts=[]
for k,n,mk,c in [('baseline','Baseline','o','#64748b'),('smooth_0.02',u'LS(\u03b5=0.02)','s','#3b82f6'),('smooth_0.2',u'LS(\u03b5=0.2)','s','#1d4ed8')]:
if k in mia_results and k in utility_results: pts.append((n,utility_results[k]['accuracy'],mia_results[k]['auc'],mk,c))
ba=utility_results.get('baseline',{}).get('accuracy',.633)
for k,n,mk,c in [('perturbation_0.01','OP(s=0.01)','^','#10b981'),('perturbation_0.015','OP(s=0.015)','D','#059669'),('perturbation_0.02','OP(s=0.02)','^','#047857')]:
if k in perturb_results: pts.append((n,ba,perturb_results[k]['auc'],mk,c))
for n,x,y,mk,c in pts: ax.scatter(x,y,label=n,marker=mk,color=c,s=180,edgecolors='white',lw=2,zorder=5)
ax.axhline(.5,color='#cbd5e1',ls='--',alpha=.8,label='Random')
ax.set_xlabel('Accuracy',fontsize=11,fontweight='bold'); ax.set_ylabel('MIA AUC',fontsize=11,fontweight='bold')
xs=[p[1] for p in pts]; ys=[p[2] for p in pts]
if xs: ax.set_xlim(min(xs)-.03,max(xs)+.05); ax.set_ylim(min(min(ys),.5)-.02,max(ys)+.025)
ax.legend(fontsize=8,loc='upper right'); ax.grid(True,alpha=.12)
ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
plt.tight_layout(); return fig
# ══════════════ 回调 ══════════════
def cb_sample(src):
pool=member_data if src=="成员数据(训练集)" else non_member_data
s=pool[np.random.randint(len(pool))]; m=s['metadata']
tm={'calculation':'基础计算','word_problem':'应用题','concept':'概念问答','error_correction':'错题订正'}
md=("| 字段 | 值 |\n|---|---|\n| 👤 姓名 | "+clean_text(str(m.get('name','')))+
" |\n| 🆔 学号 | "+clean_text(str(m.get('student_id','')))+
" |\n| 🏫 班级 | "+clean_text(str(m.get('class','')))+
" |\n| 💯 成绩 | "+clean_text(str(m.get('score','')))+" 分 |\n| 🔖 类型 | "+tm.get(s.get('task_type',''),'')+" |\n")
return md, clean_text(s.get('question','')), clean_text(s.get('answer',''))
ATK_MAP = {
u"基线模型 (Baseline)":"baseline",
u"标签平滑 (\u03b5=0.02)":"smooth_0.02",
u"标签平滑 (\u03b5=0.2)":"smooth_0.2",
u"输出扰动 (\u03c3=0.01)":"perturbation_0.01",
u"输出扰动 (\u03c3=0.015)":"perturbation_0.015",
u"输出扰动 (\u03c3=0.02)":"perturbation_0.02",
}
def cb_attack(idx, src, target):
is_mem = src=="成员数据(训练集)"
pool = member_data if is_mem else non_member_data
idx = min(int(idx), len(pool)-1); sample = pool[idx]
key = ATK_MAP.get(target, "baseline")
is_op = key.startswith("perturbation_")
if is_op:
sigma=float(key.split("_")[1]); fr=full_results.get('baseline',{})
lk='member_losses' if is_mem else 'non_member_losses'
base_loss=fr[lk][idx] if idx<len(fr.get(lk,[])) else float(np.random.normal(bl_m_mean if is_mem else bl_nm_mean,.02))
np.random.seed(idx*1000+int(sigma*1000)); loss=base_loss+np.random.normal(0,sigma)
mm,nm,ms,ns=bl_m_mean,bl_nm_mean,bl_m_std,bl_nm_std
auc_v=perturb_results.get(key,{}).get('auc',0); lbl=u"OP(\u03c3="+str(sigma)+")"
else:
info=MODEL_INFO.get(key,MODEL_INFO['baseline']); fr=full_results.get(key,full_results.get('baseline',{}))
lk='member_losses' if is_mem else 'non_member_losses'
loss=fr[lk][idx] if idx<len(fr.get(lk,[])) else float(np.random.normal(info['m_mean'] if is_mem else info['nm_mean'],.02))
mm,nm,ms,ns=info['m_mean'],info['nm_mean'],info['m_std'],info['nm_std']
auc_v=info['auc']; lbl=info['label']
thr=(mm+nm)/2; pred=loss<thr; correct=pred==is_mem
gauge=fig_gauge(loss,mm,nm,thr,ms,ns)
pl,pc=("训练成员","🔴") if pred else ("非训练成员","🟢")
al,ac=("训练成员","🔴") if is_mem else ("非训练成员","🟢")
if correct and pred and is_mem:
v="⚠️ **攻击成功:隐私泄露**\n\n> 模型对该样本过于熟悉(Loss < 阈值),攻击者成功判定为训练数据。"
elif correct:
v="✅ **判定正确**\n\n> 攻击者的判定与真实身份一致。"
else:
v="🛡️ **防御成功**\n\n> 攻击者的判定错误,防御起到了保护作用。"
res=(v+"\n\n**🎯 攻击目标**: "+lbl+" | **📊 AUC**: "+f"{auc_v:.4f}"+"\n\n"
"| | 攻击者判定 | 真实身份 |\n|---|---|---|\n"
"| 身份 | "+pc+" "+pl+" | "+ac+" "+al+" |\n"
"| Loss | "+f"{loss:.4f}"+" | 阈值: "+f"{thr:.4f}"+" |\n")
qtxt="**📝 样本题号 #"+str(idx)+"**\n\n"+clean_text(sample.get('question',''))[:500]
return qtxt, gauge, res
EVAL_ACC={u"基线模型":bl_acc,u"标签平滑 (\u03b5=0.02)":s002_acc,u"标签平滑 (\u03b5=0.2)":s02_acc,
u"输出扰动 (\u03c3=0.01)":bl_acc,u"输出扰动 (\u03c3=0.015)":bl_acc,u"输出扰动 (\u03c3=0.02)":bl_acc}
EVAL_KEY={u"基线模型":"baseline",u"标签平滑 (\u03b5=0.02)":"smooth_0.02",u"标签平滑 (\u03b5=0.2)":"smooth_0.2",
u"输出扰动 (\u03c3=0.01)":"baseline",u"输出扰动 (\u03c3=0.015)":"baseline",u"输出扰动 (\u03c3=0.02)":"baseline"}
def cb_eval(model):
k=EVAL_KEY.get(model,"baseline"); acc=EVAL_ACC.get(model,bl_acc)
q=EVAL_POOL[np.random.randint(len(EVAL_POOL))]; ok=q.get(k,q.get('baseline',False))
ic="✅ 正确" if ok else "❌ 错误"
note="\n\n> 💡 输出扰动不改变模型参数,准确率与基线一致。" if u"\u03c3" in model else ""
return ("**🖥️ 模型**: "+model+" (准确率: "+f"{acc:.1f}%"+")\n\n"
"| 项目 | 内容 |\n|---|---|\n"
"| 类型 | "+q['type_cn']+" |\n| 题目 | "+q['question']+" |\n"
"| 正确答案 | "+q['answer']+" |\n| 判定 | "+ic+" |"+note)
# ══════════════ 界面美化 CSS ══════════════
CSS = """
/* 1. 带有十字准星网格的全局背景 */
body {
background-color: #f8fafc !important;
background-image:
linear-gradient(#e2e8f0 1px, transparent 1px),
linear-gradient(90deg, #e2e8f0 1px, transparent 1px) !important;
background-size: 20px 20px !important;
background-position: center center !important;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif !important;
}
.gradio-container {
max-width: 1200px !important;
margin: 40px auto !important;
}
/* 2. 科技感悬浮 Title 面板 */
.title-area {
background: #ffffff;
padding: 28px 40px;
border-radius: 12px;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05), 0 2px 4px -1px rgba(0, 0, 0, 0.03);
margin-bottom: 24px;
border-left: 6px solid #2563eb;
text-align: left;
}
.title-area h1 {
color: #0f172a !important;
font-size: 1.7rem !important;
font-weight: 800 !important;
margin: 0 0 8px 0 !important;
letter-spacing: -0.02em;
}
.title-area p {
color: #64748b !important;
font-size: 1rem !important;
margin: 0 !important;
font-weight: 500;
}
/* 3. 死死锁住所有标签页的大小,防止跳动 (760px 高度) */
.tabitem {
background: rgba(255, 255, 255, 0.98) !important;
border-radius: 0 0 12px 12px !important;
border: 1px solid #e2e8f0 !important;
border-top: none !important;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.05) !important;
padding: 32px 40px !important;
height: 760px !important;
max-height: 760px !important;
overflow-y: auto !important;
overflow-x: hidden !important;
}
/* 优雅的隐藏式滚动条 */
.tabitem::-webkit-scrollbar { width: 6px; }
.tabitem::-webkit-scrollbar-track { background: transparent; }
.tabitem::-webkit-scrollbar-thumb { background: #cbd5e1; border-radius: 10px; }
.tabitem::-webkit-scrollbar-thumb:hover { background: #94a3b8; }
/* 4. Tab 导航栏拟态设计 */
.tab-nav {
border-bottom: none !important;
gap: 4px !important;
padding: 0 10px !important;
}
.tab-nav button {
font-size: 15px !important;
padding: 12px 24px !important;
font-weight: 600 !important;
color: #64748b !important;
background: #e2e8f0 !important;
border: 1px solid #e2e8f0 !important;
border-bottom: none !important;
border-radius: 10px 10px 0 0 !important;
transition: all 0.2s ease !important;
}
.tab-nav button:hover {
color: #2563eb !important;
background: #ffffff !important;
}
.tab-nav button.selected {
color: #2563eb !important;
background: #ffffff !important;
border-top: 3px solid #2563eb !important;
z-index: 2 !important;
box-shadow: 0 -4px 6px -2px rgba(0,0,0,0.02) !important;
}
/* 5. 内部排版优化 */
.prose h2 {
font-size: 1.3rem !important;
color: #0f172a !important;
font-weight: 700 !important;
margin-top: 0 !important;
padding-bottom: 10px !important;
border-bottom: 1px solid #e2e8f0 !important;
}
.prose h3 {
font-size: 1.1rem !important;
color: #1e293b !important;
font-weight: 600 !important;
margin-top: 1.5em !important;
}
/* 6. 高级数据表格 */
.prose table {
width: 100% !important;
border-collapse: separate !important;
border-spacing: 0 !important;
border-radius: 8px !important;
overflow: hidden !important;
margin: 1.2em 0 !important;
border: 1px solid #e2e8f0 !important;
font-size: 0.9rem !important;
}
.prose th {
background: #f8fafc !important;
color: #475569 !important;
font-weight: 600 !important;
padding: 12px 16px !important;
border-bottom: 1px solid #e2e8f0 !important;
text-align: left !important;
}
.prose td {
padding: 12px 16px !important;
color: #1e293b !important;
border-bottom: 1px solid #f1f5f9 !important;
}
.prose tr:last-child td { border-bottom: none !important; }
.prose tr:hover td { background: #f0f9ff !important; }
/* 7. 纯色科技感按钮 */
button.primary {
background: #2563eb !important;
color: white !important;
border: none !important;
border-radius: 6px !important;
font-weight: 600 !important;
padding: 10px 24px !important;
box-shadow: 0 4px 6px -1px rgba(37, 99, 235, 0.2) !important;
transition: all 0.2s ease !important;
}
button.primary:hover {
background: #1d4ed8 !important;
transform: translateY(-1px) !important;
box-shadow: 0 6px 10px -1px rgba(37, 99, 235, 0.3) !important;
}
/* 8. 提示框 */
.prose blockquote {
border-left: 4px solid #3b82f6 !important;
background: #eff6ff !important;
padding: 16px 20px !important;
border-radius: 0 8px 8px 0 !important;
color: #1e40af !important;
font-size: 0.95rem !important;
margin: 1.5em 0 !important;
}
/* 9. Accordion 折叠面板美化 */
.wrap.svelte-182y6v9 {
border: 1px solid #e2e8f0 !important;
border-radius: 8px !important;
background: #f8fafc !important;
box-shadow: none !important;
}
.label.svelte-182y6v9 {
font-weight: 600 !important;
color: #0f172a !important;
}
footer { display: none !important; }
"""
with gr.Blocks(title="MIA攻防研究", theme=gr.themes.Base(), css=CSS) as demo:
gr.HTML("""<div class="title-area">
<h1>🎓 教育大模型中的成员推理攻击及其防御研究</h1>
<p>Membership Inference Attack & Defense on Educational LLM Dashboard</p>
</div>""")
# ═══════ Tab 1: 实验总览 (引入折叠面板展开) ═══════
with gr.Tab("📊 实验总览"):
gr.Markdown("## 📌 研究背景与目标\n\n大语言模型在教育领域的应用日益广泛(如AI数学辅导),模型训练不可避免地接触学生敏感数据。**成员推理攻击 (MIA)** 可判断某条数据是否参与了训练,构成隐私威胁。\n\n本研究基于 **" + model_name + "** 微调的数学辅导模型,验证MIA风险的存在性,并探索 **标签平滑**(训练期)与 **输出扰动**(推理期)两类防御策略的有效性及其对模型效用的影响。")
with gr.Accordion("📈 展开查看:实验核心指标", open=True):
gr.Markdown(
"| 🛡️ 策略配置 | 📊 AUC | 🎯 准确率 | 💡 说明 |\n|---|---|---|---|\n"
"| **基线(无防御)** | **" + f"{bl_auc:.4f}" + "** | " + f"{bl_acc:.1f}%" + " | 攻击风险基准 |\n"
"| " + u"LS(\u03b5=0.02)" + " | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 训练期防御 |\n"
"| " + u"LS(\u03b5=0.2)" + " | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 训练期防御 |\n"
"| " + u"OP(\u03c3=0.01)" + " | " + f"{op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 推理期防御 |\n"
"| " + u"OP(\u03c3=0.015)" + " | " + f"{op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 推理期防御 |\n"
"| " + u"OP(\u03c3=0.02)" + " | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 推理期防御 |\n\n"
"> 💡 **指标提示**: AUC越接近0.5 = 防御越有效;准确率越高 = 模型效用越好。"
)
with gr.Accordion("🚀 展开查看:实验流程规划", open=False):
gr.Markdown(
"| 阶段 | 内容 | 方法 |\n|---|---|---|\n"
"| 1. 数据准备 | 2000条数学辅导对话 | 模板化生成,含姓名/学号/成绩 |\n"
"| 2. 基线训练 | " + model_name + " + LoRA | 标准微调(r=8, alpha=16, 10 epochs) |\n"
"| 3. 防御训练 | " + u"\u03b5=0.02 / \u03b5=0.2" + " | 两组标签平滑参数分别训练 |\n"
"| 4. 攻击测试 | 3个模型 + 3组扰动 | Loss阈值判定,AUC评估 |\n"
"| 5. 效用评估 | 300道数学题 | 6种配置分别测试准确率 |\n"
"| 6. 综合分析 | 隐私-效用权衡 | 定量对比与可视化 |\n"
)
# ═══════ Tab 2: 数据与模型 ═══════
with gr.Tab("📁 数据与模型"):
gr.Markdown("## 📦 实验数据集概况")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(
"| 数据组 | 数量 | 用途 | 说明 |\n|---|---|---|---|\n"
"| 🔴 成员数据 | 1000条 | 模型训练 | 模型会\"记住\",Loss偏低 |\n"
"| 🟢 非成员数据 | 1000条 | 攻击对照 | 模型\"没见过\",Loss偏高 |\n\n"
"> ⚠️ 两组数据格式完全相同(均含隐私字段),这是MIA实验的标准设置——攻击者无法从格式区分。"
)
with gr.Column(scale=1):
gr.Markdown(
"| 任务类别 | 数量 | 占比 |\n|---|---|---|\n"
"| 🧮 基础计算 | 800 | 40% |\n| 📝 应用题 | 600 | 30% |\n| 🧠 概念问答 | 400 | 20% |\n| ✍️ 错题订正 | 200 | 10% |\n"
)
gr.Markdown("### 🔍 数据样例浏览提取")
with gr.Row(equal_height=True):
with gr.Column(scale=2):
d_src = gr.Radio(["成员数据(训练集)","非成员数据(测试集)"], value="成员数据(训练集)", label="选择靶向数据来源")
d_btn = gr.Button("🎲 随机提取样本", variant="primary")
d_meta = gr.Markdown()
with gr.Column(scale=3):
d_q = gr.Textbox(label="🧑‍🎓 学生提问 (Prompt)", lines=4, interactive=False)
d_a = gr.Textbox(label="🤖 标准回答 (Ground Truth)", lines=4, interactive=False)
d_btn.click(cb_sample, [d_src], [d_meta, d_q, d_a])
# ═══════ Tab 3: 攻击与防御验证 ═══════
with gr.Tab("🎯 攻击验证"):
gr.Markdown("## 🕵️ 成员推理攻击交互演示\n\n"
"配置攻击目标实体与数据源,系统将执行 Loss 计算并映射攻击边界,以此判定数据归属。")
with gr.Row(equal_height=True):
with gr.Column(scale=2):
a_target = gr.Radio([u"基线模型 (Baseline)",u"标签平滑 (\u03b5=0.02)",u"标签平滑 (\u03b5=0.2)",
u"输出扰动 (\u03c3=0.01)",u"输出扰动 (\u03c3=0.015)",u"输出扰动 (\u03c3=0.02)"],
value=u"基线模型 (Baseline)", label="选择攻击目标")
a_src = gr.Radio(["成员数据(训练集)","非成员数据(测试集)"], value="成员数据(训练集)", label="数据来源")
a_idx = gr.Slider(0, 999, step=1, value=12, label="定位样本 ID")
a_btn = gr.Button("⚡ 执行成员推理攻击", variant="primary", size="lg")
a_qtxt = gr.Markdown()
with gr.Column(scale=3):
a_gauge = gr.Plot(label="Loss位置判定 (Decision Boundary)")
a_res = gr.Markdown()
a_btn.click(cb_attack, [a_idx, a_src, a_target], [a_qtxt, a_gauge, a_res])
# ═══════ Tab 4: 防御效果分析 ═══════
with gr.Tab("🛡️ 防御分析"):
with gr.Accordion("📊 展开查看:防御对比直方图", open=False):
gr.Markdown("### MIA攻击AUC对比\n\n> 柱子越矮 = AUC越低 = 攻击越难成功 = 防御越有效")
gr.Plot(value=fig_auc_bar())
gr.Markdown("### Loss分布对比\n#### 三个模型(训练期防御效果)\n\n> 蓝色=成员,红色=非成员。两色重叠越多 = 攻击者越难区分")
gr.Plot(value=fig_loss_dist())
gr.Markdown("#### 输出扰动效果(推理期防御)\n\n> 在基线模型Loss上加噪声,随噪声增大分布更加重叠")
gr.Plot(value=fig_perturb_dist())
with gr.Accordion("⚙️ 展开查看:完整实验数据与机制说明", open=True):
gr.Markdown(
"### 完整实验数据表\n\n"
"| 策略 | 类型 | AUC | 准确率 | AUC变化 |\n|---|---|---|---|---|\n"
"| 基线 | — | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | — |\n"
"| " + u"LS(\u03b5=0.02)" + " | 训练期 | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | " + f"{s002_auc-bl_auc:+.4f}" + " |\n"
"| " + u"LS(\u03b5=0.2)" + " | 训练期 | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | " + f"{s02_auc-bl_auc:+.4f}" + " |\n"
"| " + u"OP(\u03c3=0.01)" + " | 推理期 | " + f"{op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op001_auc-bl_auc:+.4f}" + " |\n"
"| " + u"OP(\u03c3=0.015)" + " | 推理期 | " + f"{op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op0015_auc-bl_auc:+.4f}" + " |\n"
"| " + u"OP(\u03c3=0.02)" + " | 推理期 | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op002_auc-bl_auc:+.4f}" + " |\n\n"
"### 防御机制对比\n\n"
"| 维度 | 标签平滑 | 输出扰动 |\n|---|---|---|\n"
"| **阶段** | 训练期 | 推理期 |\n"
"| **原理** | 软化标签降低记忆 | Loss加噪遮蔽信号 |\n"
"| **需重训** | 是 | 否 |\n"
"| **效用** | 取决于参数 | 无 |\n"
"| **部署** | 训练时介入 | 即插即用 |\n\n"
"**标签平滑公式**: `y_smooth = (1 - ε) * y_onehot + ε / V`\n\n"
"**输出扰动公式**: `L_perturbed = L_original + N(0, σ²)`\n")
with gr.Accordion("🖼️ 展开查看:静态高分辨率学术图表", open=False):
for fn, cap in [("fig1_loss_distribution_comparison.png","Loss分布对比"),
("fig2_privacy_utility_tradeoff_fixed.png","隐私-效用权衡"),
("fig3_defense_comparison_bar.png","防御策略AUC对比")]:
p = os.path.join(BASE_DIR,"figures",fn)
if os.path.exists(p):
gr.Markdown("#### "+cap); gr.Image(value=p, show_label=False, height=420)
# ═══════ Tab 5: 效用评估 ═══════
with gr.Tab("⚖️ 效用评估"):
gr.Markdown("## 🎯 模型效用测试\n\n> 基于300道数学测试题评估各策略对模型实际能力的影响")
with gr.Row(equal_height=True):
with gr.Column(): gr.Plot(value=fig_acc_bar())
with gr.Column(): gr.Plot(value=fig_tradeoff())
gr.Markdown("## 🎮 在线效用抽样演示\n\n从测试题库中随机抽取,流式验证不同模型/策略的保留作答情况。")
with gr.Row(equal_height=True):
with gr.Column(scale=1):
e_model = gr.Radio([u"基线模型",u"标签平滑 (\u03b5=0.02)",u"标签平滑 (\u03b5=0.2)",
u"输出扰动 (\u03c3=0.01)",u"输出扰动 (\u03c3=0.015)",u"输出扰动 (\u03c3=0.02)"], value=u"基线模型", label="选择验证模型")
e_btn = gr.Button("🧪 随机抽题测试", variant="primary")
with gr.Column(scale=2):
e_res = gr.Markdown()
e_btn.click(cb_eval, [e_model], [e_res])
# ═══════ Tab 6: 研究结论 (引入折叠面板展开) ═══════
with gr.Tab("📝 研究结论"):
gr.Markdown("## 💡 核心研究发现与最佳实践\n\n---")
with gr.Accordion("🚨 一、教育大模型存在可量化的MIA风险", open=True):
gr.Markdown("基线模型 AUC = **" + f"{bl_auc:.4f}" + "** > 0.5,成员平均Loss (" + f"{bl_m_mean:.4f}"
+ ") 显著小于 非成员 (" + f"{bl_nm_mean:.4f}" + "),实验铁证表明模型对训练数据存在可利用的记忆效应。")
with gr.Accordion("🛡️ 二、标签平滑(训练期防御)有效性验证", open=True):
gr.Markdown(
"| 参数 | AUC | 准确率 | 分析 |\n|---|---|---|---|\n"
"| 基线 | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 无防御 |\n"
"| " + u"\u03b5=0.02" + " | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 正则化提升泛化 |\n"
"| " + u"\u03b5=0.2" + " | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 防御增强 |\n"
)
with gr.Accordion("🎛️ 三、输出扰动(推理期防御)有效性验证", open=True):
gr.Markdown(
"| 参数 | AUC | AUC降幅 | 准确率 |\n|---|---|---|---|\n"
"| " + u"\u03c3=0.01" + " | " + f"{op001_auc:.4f}" + " | " + f"{bl_auc-op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n"
"| " + u"\u03c3=0.015" + " | " + f"{op0015_auc:.4f}" + " | " + f"{bl_auc-op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n"
"| " + u"\u03c3=0.02" + " | " + f"{op002_auc:.4f}" + " | " + f"{bl_auc-op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n\n"
"**结论:零效用损失,适合已部署系统的后期加固。**"
)
with gr.Accordion("⚖️ 四、隐私-效用权衡总结", open=False):
gr.Markdown(
"| 策略 | AUC | 准确率 | 隐私 | 效用 |\n|---|---|---|---|---|\n"
"| 基线 | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 风险最高 | 基准 |\n"
"| " + u"LS(\u03b5=0.02)" + " | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 降低 | 提升 |\n"
"| " + u"LS(\u03b5=0.2)" + " | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 显著降低 | 可接受 |\n"
"| " + u"OP(\u03c3=0.02)" + " | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 显著降低 | 不变 |\n\n"
"> 两类策略机制互补:标签平滑从训练阶段降低记忆,输出扰动从推理阶段遮蔽信号。建议组合使用以构建立体防御体系。"
)
gr.HTML("<div style='text-align:center;color:#94a3b8;font-size:.82rem;padding:16px 0 8px'></div>")
demo.launch()