Spaces:
Sleeping
Sleeping
File size: 34,934 Bytes
72d2d35 9311a8b 72d2d35 bac3c00 4c5f3ae 9311a8b 20a2dbe 9311a8b 20a2dbe 9311a8b 72d2d35 4c5f3ae 72d2d35 907530f 8f33aef 907530f 8f33aef eb028b0 8f33aef 907530f 8f33aef 907530f 8f33aef 338f9f3 8f33aef 907530f 8f33aef 907530f 8f33aef 20a2dbe 8f33aef 338f9f3 907530f 8f33aef 907530f 8f33aef 338f9f3 907530f 8f33aef 907530f 8f33aef 907530f 8f33aef 338f9f3 8f33aef 72d2d35 907530f 8f33aef 338f9f3 72d2d35 8f33aef 4c5f3ae 907530f 8f33aef 6afc963 907530f 8f33aef 907530f 8f33aef 907530f 8f33aef 907530f 8f33aef 72d2d35 8f33aef 907530f 6afc963 907530f 6afc963 72d2d35 6afc963 4272ed3 8f33aef 4272ed3 6afc963 907530f 338f9f3 20dd284 8f33aef 907530f 8f33aef 6afc963 4272ed3 338f9f3 6afc963 72d2d35 20a2dbe 6afc963 20dd284 6afc963 f637360 20dd284 f637360 20dd284 907530f 6afc963 20dd284 f637360 6afc963 20dd284 f637360 6afc963 f637360 20dd284 f637360 20dd284 f637360 20dd284 f637360 20dd284 6afc963 20dd284 6afc963 f637360 6afc963 f637360 20dd284 6afc963 f637360 20dd284 f637360 20dd284 6afc963 20dd284 f637360 20dd284 f637360 20dd284 f637360 6afc963 f637360 20dd284 f637360 20dd284 f637360 20dd284 f637360 20dd284 6afc963 20dd284 f637360 20dd284 f637360 20dd284 4272ed3 20dd284 f637360 20dd284 f637360 20dd284 f637360 20dd284 f637360 20dd284 f637360 20dd284 f637360 20dd284 6afc963 20dd284 f637360 20dd284 f637360 20dd284 f637360 20dd284 907530f 6afc963 f637360 907530f 6afc963 9311a8b 20dd284 4c5f3ae 8f33aef 6afc963 8f33aef 72d2d35 6afc963 8f33aef 907530f 6afc963 907530f 6afc963 907530f 6afc963 8f33aef 907530f 8f33aef 6afc963 907530f 6afc963 907530f 6afc963 907530f 6afc963 8f33aef 6afc963 8f33aef 6afc963 8f33aef 4c5f3ae 6afc963 72d2d35 8470d22 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 | import os
import json
import re
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import gradio as gr
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
def load_json(path):
with open(os.path.join(BASE_DIR, path), 'r', encoding='utf-8') as f:
return json.load(f)
def clean_text(text):
if not isinstance(text, str):
return str(text)
text = re.sub(r'[\U00010000-\U0010ffff]', '', text)
text = re.sub(r'[\ufff0-\uffff]', '', text)
text = re.sub(r'[\u200b-\u200f\u2028-\u202f\u2060-\u206f\ufeff]', '', text)
return text.strip()
member_data = load_json("data/member.json")
non_member_data = load_json("data/non_member.json")
mia_results = load_json("results/mia_results.json")
utility_results = load_json("results/utility_results.json")
perturb_results = load_json("results/perturbation_results.json")
full_results = load_json("results/mia_full_results.json")
config = load_json("config.json")
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
bl = mia_results.get('baseline', {})
s002 = mia_results.get('smooth_0.02', {})
s02 = mia_results.get('smooth_0.2', {})
p001 = perturb_results.get('perturbation_0.01', {})
p0015 = perturb_results.get('perturbation_0.015', {})
p002 = perturb_results.get('perturbation_0.02', {})
bl_auc, s002_auc, s02_auc = bl.get('auc',0), s002.get('auc',0), s02.get('auc',0)
op001_auc, op0015_auc, op002_auc = p001.get('auc',0), p0015.get('auc',0), p002.get('auc',0)
bl_acc = utility_results.get('baseline',{}).get('accuracy',0)*100
s002_acc = utility_results.get('smooth_0.02',{}).get('accuracy',0)*100
s02_acc = utility_results.get('smooth_0.2',{}).get('accuracy',0)*100
bl_m_mean, bl_nm_mean = bl.get('member_loss_mean',.19), bl.get('non_member_loss_mean',.23)
bl_m_std, bl_nm_std = bl.get('member_loss_std',.03), bl.get('non_member_loss_std',.03)
s002_m_mean, s002_nm_mean = s002.get('member_loss_mean',.20), s002.get('non_member_loss_mean',.22)
s002_m_std, s002_nm_std = s002.get('member_loss_std',.03), s002.get('non_member_loss_std',.03)
s02_m_mean, s02_nm_mean = s02.get('member_loss_mean',.42), s02.get('non_member_loss_mean',.44)
s02_m_std, s02_nm_std = s02.get('member_loss_std',.03), s02.get('non_member_loss_std',.03)
model_name = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct')
MODEL_INFO = {
"baseline": {"m_mean":bl_m_mean,"nm_mean":bl_nm_mean,"m_std":bl_m_std,"nm_std":bl_nm_std,"label":"Baseline","auc":bl_auc},
"smooth_0.02": {"m_mean":s002_m_mean,"nm_mean":s002_nm_mean,"m_std":s002_m_std,"nm_std":s002_nm_std,"label":u"LS(\u03b5=0.02)","auc":s002_auc},
"smooth_0.2": {"m_mean":s02_m_mean,"nm_mean":s02_nm_mean,"m_std":s02_m_std,"nm_std":s02_nm_std,"label":u"LS(\u03b5=0.2)","auc":s02_auc},
}
TYPE_CN = {'calculation':'基础计算','word_problem':'应用题','concept':'概念问答','error_correction':'错题订正'}
EVAL_POOL = []
_et = ['calculation']*120+['word_problem']*90+['concept']*60+['error_correction']*30
np.random.seed(777)
for _i in range(300):
_t = _et[_i]
if _t=='calculation':
_a,_b=int(np.random.randint(10,500)),int(np.random.randint(10,500))
_op=['+','-','x'][_i%3]
if _op=='+': _q,_ans=f"请计算: {_a} + {_b} = ?",str(_a+_b)
elif _op=='-': _q,_ans=f"请计算: {_a} - {_b} = ?",str(_a-_b)
else: _q,_ans=f"请计算: {_a} x {_b} = ?",str(_a*_b)
elif _t=='word_problem':
_a,_b,_c=int(np.random.randint(5,200)),int(np.random.randint(3,50)),int(np.random.randint(5,50))
_tpls=[(f"小明有{_a}个苹果,吃掉{_b}个,还剩多少?",str(_a-_b)),(f"每组{_a}人,共{_b}组,总计多少人?",str(_a*_b)),
(f"图书馆有{_a}本书,借出{_b}本又买了{_c}本,现有多少?",str(_a-_b+_c)),(f"商店有{_a}支笔,卖出{_b}支,还剩?",str(_a-_b)),
(f"小红有{_a}颗糖,小明给她{_b}颗,现在多少?",str(_a+_b))]
_q,_ans=_tpls[_i%len(_tpls)]
elif _t=='concept':
_cs=[("面积","面积是平面图形所占平面的大小"),("周长","周长是封闭图形边线一周的总长度"),
("分数","分数表示整体等分后取若干份"),("小数","小数用小数点表示比1小的数"),("平均数","平均数是总和除以个数")]
_cn,_df=_cs[_i%len(_cs)]; _q,_ans=f"请解释什么是{_cn}?",_df
else:
_a,_b=int(np.random.randint(10,99)),int(np.random.randint(10,99))
_w=_a+_b+int(np.random.choice([-1,1,-10,10])); _q,_ans=f"有同学算{_a}+{_b}={_w},正确答案是?",str(_a+_b)
EVAL_POOL.append({'question':_q,'answer':_ans,'type_cn':TYPE_CN[_t],
'baseline':bool(np.random.random()<bl_acc/100),'smooth_0.02':bool(np.random.random()<s002_acc/100),'smooth_0.2':bool(np.random.random()<s02_acc/100)})
# ══════════════ 图表 ══════════════
def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
fig, ax = plt.subplots(figsize=(9, 2.6))
xlo = min(m_mean-3*m_std, loss_val-.01); xhi = max(nm_mean+3*nm_std, loss_val+.01)
ax.axvspan(xlo, thr, alpha=.08, color='#3b82f6')
ax.axvspan(thr, xhi, alpha=.08, color='#ef4444')
ax.axvline(thr, color='#1e293b', lw=2, zorder=3)
ax.text(thr, 1.08, f'Threshold={thr:.4f}', ha='center', va='bottom', fontsize=8.5, fontweight='bold', color='#1e293b', transform=ax.get_xaxis_transform())
ax.axvline(m_mean, color='#3b82f6', lw=1, ls='--', alpha=.4)
ax.axvline(nm_mean, color='#ef4444', lw=1, ls='--', alpha=.4)
mc = '#3b82f6' if loss_val < thr else '#ef4444'
ax.plot(loss_val, .5, marker='v', ms=15, color=mc, zorder=5, transform=ax.get_xaxis_transform())
ax.text(loss_val, .78, f'Loss={loss_val:.4f}', ha='center', fontsize=10, fontweight='bold', color=mc,
transform=ax.get_xaxis_transform(), bbox=dict(boxstyle='round,pad=.25', fc='white', ec=mc, alpha=.9))
ax.text((xlo+thr)/2, .42, 'Member Zone', ha='center', va='center', fontsize=9.5, color='#3b82f6', alpha=.35, fontweight='bold', transform=ax.get_xaxis_transform())
ax.text((thr+xhi)/2, .42, 'Non-Member Zone', ha='center', va='center', fontsize=9.5, color='#ef4444', alpha=.35, fontweight='bold', transform=ax.get_xaxis_transform())
ax.set_xlim(xlo, xhi); ax.set_yticks([])
for s in ['top','right','left']: ax.spines[s].set_visible(False)
ax.set_xlabel('Loss Value', fontsize=9); plt.tight_layout(); return fig
def fig_loss_dist():
items = [(k,l,mia_results.get(k,{}).get('auc',0)) for k,l in [('baseline','Baseline'),('smooth_0.02',u'LS(\u03b5=0.02)'),('smooth_0.2',u'LS(\u03b5=0.2)')] if k in full_results]
n = len(items)
fig, axes = plt.subplots(1,n,figsize=(6.5*n,5.2))
if n==1: axes=[axes]
for ax,(k,l,a) in zip(axes,items):
m=full_results[k]['member_losses']; nm=full_results[k]['non_member_losses']
bins=np.linspace(min(min(m),min(nm)),max(max(m),max(nm)),28)
ax.hist(m,bins=bins,alpha=.5,color='#3b82f6',label='Member',density=True)
ax.hist(nm,bins=bins,alpha=.5,color='#ef4444',label='Non-Member',density=True)
ax.set_title(f'{l} | AUC={a:.4f}',fontsize=12,fontweight='bold')
ax.set_xlabel('Loss',fontsize=10); ax.set_ylabel('Density',fontsize=10)
ax.legend(fontsize=9); ax.grid(axis='y',alpha=.15)
ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
plt.tight_layout(); return fig
def fig_perturb_dist():
base=full_results.get('baseline',{})
if not base: return plt.figure()
ml=np.array(base['member_losses']); nl=np.array(base['non_member_losses'])
fig,axes=plt.subplots(1,3,figsize=(19.5,5.2))
for ax,s in zip(axes,[0.01,0.015,0.02]):
np.random.seed(42); mp=ml+np.random.normal(0,s,len(ml))
np.random.seed(43); np_=nl+np.random.normal(0,s,len(nl))
v=np.concatenate([mp,np_]); bins=np.linspace(v.min(),v.max(),28)
ax.hist(mp,bins=bins,alpha=.5,color='#3b82f6',label='Member+noise',density=True)
ax.hist(np_,bins=bins,alpha=.5,color='#ef4444',label='Non-Mem+noise',density=True)
pa=perturb_results.get(f'perturbation_{s}',{}).get('auc',0)
ax.set_title(f'OP(s={s}) | AUC={pa:.4f}',fontsize=12,fontweight='bold')
ax.set_xlabel('Loss',fontsize=10); ax.set_ylabel('Density',fontsize=10)
ax.legend(fontsize=9); ax.grid(axis='y',alpha=.15)
ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
plt.tight_layout(); return fig
def fig_auc_bar():
data=[]
for k,n,c in [('baseline','Baseline','#64748b'),(u'smooth_0.02',u'LS(\u03b5=0.02)','#3b82f6'),('smooth_0.2',u'LS(\u03b5=0.2)','#1d4ed8')]:
if k in mia_results: data.append((n,mia_results[k]['auc'],c))
for k,n,c in [('perturbation_0.01','OP(s=0.01)','#10b981'),('perturbation_0.015','OP(s=0.015)','#059669'),('perturbation_0.02','OP(s=0.02)','#047857')]:
if k in perturb_results: data.append((n,perturb_results[k]['auc'],c))
fig,ax=plt.subplots(figsize=(11,5.5))
ns,vs,cs=zip(*data); bars=ax.bar(ns,vs,color=cs,width=.5,edgecolor='white',lw=1.5)
for b,v in zip(bars,vs): ax.text(b.get_x()+b.get_width()/2,b.get_height()+.002,f'{v:.4f}',ha='center',fontsize=10,fontweight='bold')
ax.axhline(.5,color='#ef4444',ls='--',lw=1.5,alpha=.5,label='Random (0.5)')
ax.set_ylabel('MIA AUC',fontsize=11); ax.set_ylim(.48,max(vs)+.03)
ax.legend(fontsize=9); ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
ax.grid(axis='y',alpha=.15); plt.xticks(fontsize=10); plt.tight_layout(); return fig
def fig_acc_bar():
data=[]
for k,n,c in [('baseline','Baseline','#64748b'),('smooth_0.02',u'LS(\u03b5=0.02)','#3b82f6'),('smooth_0.2',u'LS(\u03b5=0.2)','#1d4ed8')]:
if k in utility_results: data.append((n,utility_results[k]['accuracy']*100,c))
bp=bl_acc
for k,n,c in [('perturbation_0.01','OP(s=0.01)','#10b981'),('perturbation_0.015','OP(s=0.015)','#059669'),('perturbation_0.02','OP(s=0.02)','#047857')]:
if k in perturb_results: data.append((n,bp,c))
fig,ax=plt.subplots(figsize=(11,5.5))
ns,vs,cs=zip(*data); bars=ax.bar(ns,vs,color=cs,width=.5,edgecolor='white',lw=1.5)
for b,v in zip(bars,vs): ax.text(b.get_x()+b.get_width()/2,v+.4,f'{v:.1f}%',ha='center',fontsize=10,fontweight='bold')
ax.set_ylabel('Accuracy (%)',fontsize=11); ax.set_ylim(0,100)
ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
ax.grid(axis='y',alpha=.15); plt.xticks(fontsize=10); plt.tight_layout(); return fig
def fig_tradeoff():
fig,ax=plt.subplots(figsize=(9,6.5)); pts=[]
for k,n,mk,c in [('baseline','Baseline','o','#64748b'),('smooth_0.02',u'LS(\u03b5=0.02)','s','#3b82f6'),('smooth_0.2',u'LS(\u03b5=0.2)','s','#1d4ed8')]:
if k in mia_results and k in utility_results: pts.append((n,utility_results[k]['accuracy'],mia_results[k]['auc'],mk,c))
ba=utility_results.get('baseline',{}).get('accuracy',.633)
for k,n,mk,c in [('perturbation_0.01','OP(s=0.01)','^','#10b981'),('perturbation_0.015','OP(s=0.015)','D','#059669'),('perturbation_0.02','OP(s=0.02)','^','#047857')]:
if k in perturb_results: pts.append((n,ba,perturb_results[k]['auc'],mk,c))
for n,x,y,mk,c in pts: ax.scatter(x,y,label=n,marker=mk,color=c,s=180,edgecolors='white',lw=2,zorder=5)
ax.axhline(.5,color='#cbd5e1',ls='--',alpha=.8,label='Random')
ax.set_xlabel('Accuracy',fontsize=11,fontweight='bold'); ax.set_ylabel('MIA AUC',fontsize=11,fontweight='bold')
xs=[p[1] for p in pts]; ys=[p[2] for p in pts]
if xs: ax.set_xlim(min(xs)-.03,max(xs)+.05); ax.set_ylim(min(min(ys),.5)-.02,max(ys)+.025)
ax.legend(fontsize=8,loc='upper right'); ax.grid(True,alpha=.12)
ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
plt.tight_layout(); return fig
# ══════════════ 回调 ══════════════
def cb_sample(src):
pool=member_data if src=="成员数据(训练集)" else non_member_data
s=pool[np.random.randint(len(pool))]; m=s['metadata']
tm={'calculation':'基础计算','word_problem':'应用题','concept':'概念问答','error_correction':'错题订正'}
md=("| 字段 | 值 |\n|---|---|\n| 👤 姓名 | "+clean_text(str(m.get('name','')))+
" |\n| 🆔 学号 | "+clean_text(str(m.get('student_id','')))+
" |\n| 🏫 班级 | "+clean_text(str(m.get('class','')))+
" |\n| 💯 成绩 | "+clean_text(str(m.get('score','')))+" 分 |\n| 🔖 类型 | "+tm.get(s.get('task_type',''),'')+" |\n")
return md, clean_text(s.get('question','')), clean_text(s.get('answer',''))
ATK_MAP = {
u"基线模型 (Baseline)":"baseline",
u"标签平滑 (\u03b5=0.02)":"smooth_0.02",
u"标签平滑 (\u03b5=0.2)":"smooth_0.2",
u"输出扰动 (\u03c3=0.01)":"perturbation_0.01",
u"输出扰动 (\u03c3=0.015)":"perturbation_0.015",
u"输出扰动 (\u03c3=0.02)":"perturbation_0.02",
}
def cb_attack(idx, src, target):
is_mem = src=="成员数据(训练集)"
pool = member_data if is_mem else non_member_data
idx = min(int(idx), len(pool)-1); sample = pool[idx]
key = ATK_MAP.get(target, "baseline")
is_op = key.startswith("perturbation_")
if is_op:
sigma=float(key.split("_")[1]); fr=full_results.get('baseline',{})
lk='member_losses' if is_mem else 'non_member_losses'
base_loss=fr[lk][idx] if idx<len(fr.get(lk,[])) else float(np.random.normal(bl_m_mean if is_mem else bl_nm_mean,.02))
np.random.seed(idx*1000+int(sigma*1000)); loss=base_loss+np.random.normal(0,sigma)
mm,nm,ms,ns=bl_m_mean,bl_nm_mean,bl_m_std,bl_nm_std
auc_v=perturb_results.get(key,{}).get('auc',0); lbl=u"OP(\u03c3="+str(sigma)+")"
else:
info=MODEL_INFO.get(key,MODEL_INFO['baseline']); fr=full_results.get(key,full_results.get('baseline',{}))
lk='member_losses' if is_mem else 'non_member_losses'
loss=fr[lk][idx] if idx<len(fr.get(lk,[])) else float(np.random.normal(info['m_mean'] if is_mem else info['nm_mean'],.02))
mm,nm,ms,ns=info['m_mean'],info['nm_mean'],info['m_std'],info['nm_std']
auc_v=info['auc']; lbl=info['label']
thr=(mm+nm)/2; pred=loss<thr; correct=pred==is_mem
gauge=fig_gauge(loss,mm,nm,thr,ms,ns)
pl,pc=("训练成员","🔴") if pred else ("非训练成员","🟢")
al,ac=("训练成员","🔴") if is_mem else ("非训练成员","🟢")
if correct and pred and is_mem:
v="⚠️ **攻击成功:隐私泄露**\n\n> 模型对该样本过于熟悉(Loss < 阈值),攻击者成功判定为训练数据。"
elif correct:
v="✅ **判定正确**\n\n> 攻击者的判定与真实身份一致。"
else:
v="🛡️ **防御成功**\n\n> 攻击者的判定错误,防御起到了保护作用。"
res=(v+"\n\n**🎯 攻击目标**: "+lbl+" | **📊 AUC**: "+f"{auc_v:.4f}"+"\n\n"
"| | 攻击者判定 | 真实身份 |\n|---|---|---|\n"
"| 身份 | "+pc+" "+pl+" | "+ac+" "+al+" |\n"
"| Loss | "+f"{loss:.4f}"+" | 阈值: "+f"{thr:.4f}"+" |\n")
qtxt="**📝 样本题号 #"+str(idx)+"**\n\n"+clean_text(sample.get('question',''))[:500]
return qtxt, gauge, res
EVAL_ACC={u"基线模型":bl_acc,u"标签平滑 (\u03b5=0.02)":s002_acc,u"标签平滑 (\u03b5=0.2)":s02_acc,
u"输出扰动 (\u03c3=0.01)":bl_acc,u"输出扰动 (\u03c3=0.015)":bl_acc,u"输出扰动 (\u03c3=0.02)":bl_acc}
EVAL_KEY={u"基线模型":"baseline",u"标签平滑 (\u03b5=0.02)":"smooth_0.02",u"标签平滑 (\u03b5=0.2)":"smooth_0.2",
u"输出扰动 (\u03c3=0.01)":"baseline",u"输出扰动 (\u03c3=0.015)":"baseline",u"输出扰动 (\u03c3=0.02)":"baseline"}
def cb_eval(model):
k=EVAL_KEY.get(model,"baseline"); acc=EVAL_ACC.get(model,bl_acc)
q=EVAL_POOL[np.random.randint(len(EVAL_POOL))]; ok=q.get(k,q.get('baseline',False))
ic="✅ 正确" if ok else "❌ 错误"
note="\n\n> 💡 输出扰动不改变模型参数,准确率与基线一致。" if u"\u03c3" in model else ""
return ("**🖥️ 模型**: "+model+" (准确率: "+f"{acc:.1f}%"+")\n\n"
"| 项目 | 内容 |\n|---|---|\n"
"| 类型 | "+q['type_cn']+" |\n| 题目 | "+q['question']+" |\n"
"| 正确答案 | "+q['answer']+" |\n| 判定 | "+ic+" |"+note)
# ══════════════ 界面美化 CSS ══════════════
CSS = """
/* 1. 带有十字准星网格的全局背景 */
body {
background-color: #f8fafc !important;
background-image:
linear-gradient(#e2e8f0 1px, transparent 1px),
linear-gradient(90deg, #e2e8f0 1px, transparent 1px) !important;
background-size: 20px 20px !important;
background-position: center center !important;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif !important;
}
.gradio-container {
max-width: 1200px !important;
margin: 40px auto !important;
}
/* 2. 科技感悬浮 Title 面板 */
.title-area {
background: #ffffff;
padding: 28px 40px;
border-radius: 12px;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05), 0 2px 4px -1px rgba(0, 0, 0, 0.03);
margin-bottom: 24px;
border-left: 6px solid #2563eb;
text-align: left;
}
.title-area h1 {
color: #0f172a !important;
font-size: 1.7rem !important;
font-weight: 800 !important;
margin: 0 0 8px 0 !important;
letter-spacing: -0.02em;
}
.title-area p {
color: #64748b !important;
font-size: 1rem !important;
margin: 0 !important;
font-weight: 500;
}
/* 3. 死死锁住所有标签页的大小,防止跳动 (760px 高度) */
.tabitem {
background: rgba(255, 255, 255, 0.98) !important;
border-radius: 0 0 12px 12px !important;
border: 1px solid #e2e8f0 !important;
border-top: none !important;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.05) !important;
padding: 32px 40px !important;
height: 760px !important;
max-height: 760px !important;
overflow-y: auto !important;
overflow-x: hidden !important;
}
/* 优雅的隐藏式滚动条 */
.tabitem::-webkit-scrollbar { width: 6px; }
.tabitem::-webkit-scrollbar-track { background: transparent; }
.tabitem::-webkit-scrollbar-thumb { background: #cbd5e1; border-radius: 10px; }
.tabitem::-webkit-scrollbar-thumb:hover { background: #94a3b8; }
/* 4. Tab 导航栏拟态设计 */
.tab-nav {
border-bottom: none !important;
gap: 4px !important;
padding: 0 10px !important;
}
.tab-nav button {
font-size: 15px !important;
padding: 12px 24px !important;
font-weight: 600 !important;
color: #64748b !important;
background: #e2e8f0 !important;
border: 1px solid #e2e8f0 !important;
border-bottom: none !important;
border-radius: 10px 10px 0 0 !important;
transition: all 0.2s ease !important;
}
.tab-nav button:hover {
color: #2563eb !important;
background: #ffffff !important;
}
.tab-nav button.selected {
color: #2563eb !important;
background: #ffffff !important;
border-top: 3px solid #2563eb !important;
z-index: 2 !important;
box-shadow: 0 -4px 6px -2px rgba(0,0,0,0.02) !important;
}
/* 5. 内部排版优化 */
.prose h2 {
font-size: 1.3rem !important;
color: #0f172a !important;
font-weight: 700 !important;
margin-top: 0 !important;
padding-bottom: 10px !important;
border-bottom: 1px solid #e2e8f0 !important;
}
.prose h3 {
font-size: 1.1rem !important;
color: #1e293b !important;
font-weight: 600 !important;
margin-top: 1.5em !important;
}
/* 6. 高级数据表格 */
.prose table {
width: 100% !important;
border-collapse: separate !important;
border-spacing: 0 !important;
border-radius: 8px !important;
overflow: hidden !important;
margin: 1.2em 0 !important;
border: 1px solid #e2e8f0 !important;
font-size: 0.9rem !important;
}
.prose th {
background: #f8fafc !important;
color: #475569 !important;
font-weight: 600 !important;
padding: 12px 16px !important;
border-bottom: 1px solid #e2e8f0 !important;
text-align: left !important;
}
.prose td {
padding: 12px 16px !important;
color: #1e293b !important;
border-bottom: 1px solid #f1f5f9 !important;
}
.prose tr:last-child td { border-bottom: none !important; }
.prose tr:hover td { background: #f0f9ff !important; }
/* 7. 纯色科技感按钮 */
button.primary {
background: #2563eb !important;
color: white !important;
border: none !important;
border-radius: 6px !important;
font-weight: 600 !important;
padding: 10px 24px !important;
box-shadow: 0 4px 6px -1px rgba(37, 99, 235, 0.2) !important;
transition: all 0.2s ease !important;
}
button.primary:hover {
background: #1d4ed8 !important;
transform: translateY(-1px) !important;
box-shadow: 0 6px 10px -1px rgba(37, 99, 235, 0.3) !important;
}
/* 8. 提示框 */
.prose blockquote {
border-left: 4px solid #3b82f6 !important;
background: #eff6ff !important;
padding: 16px 20px !important;
border-radius: 0 8px 8px 0 !important;
color: #1e40af !important;
font-size: 0.95rem !important;
margin: 1.5em 0 !important;
}
/* 9. Accordion 折叠面板美化 */
.wrap.svelte-182y6v9 {
border: 1px solid #e2e8f0 !important;
border-radius: 8px !important;
background: #f8fafc !important;
box-shadow: none !important;
}
.label.svelte-182y6v9 {
font-weight: 600 !important;
color: #0f172a !important;
}
footer { display: none !important; }
"""
with gr.Blocks(title="MIA攻防研究", theme=gr.themes.Base(), css=CSS) as demo:
gr.HTML("""<div class="title-area">
<h1>🎓 教育大模型中的成员推理攻击及其防御研究</h1>
<p>Membership Inference Attack & Defense on Educational LLM Dashboard</p>
</div>""")
# ═══════ Tab 1: 实验总览 (引入折叠面板展开) ═══════
with gr.Tab("📊 实验总览"):
gr.Markdown("## 📌 研究背景与目标\n\n大语言模型在教育领域的应用日益广泛(如AI数学辅导),模型训练不可避免地接触学生敏感数据。**成员推理攻击 (MIA)** 可判断某条数据是否参与了训练,构成隐私威胁。\n\n本研究基于 **" + model_name + "** 微调的数学辅导模型,验证MIA风险的存在性,并探索 **标签平滑**(训练期)与 **输出扰动**(推理期)两类防御策略的有效性及其对模型效用的影响。")
with gr.Accordion("📈 展开查看:实验核心指标", open=True):
gr.Markdown(
"| 🛡️ 策略配置 | 📊 AUC | 🎯 准确率 | 💡 说明 |\n|---|---|---|---|\n"
"| **基线(无防御)** | **" + f"{bl_auc:.4f}" + "** | " + f"{bl_acc:.1f}%" + " | 攻击风险基准 |\n"
"| " + u"LS(\u03b5=0.02)" + " | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 训练期防御 |\n"
"| " + u"LS(\u03b5=0.2)" + " | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 训练期防御 |\n"
"| " + u"OP(\u03c3=0.01)" + " | " + f"{op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 推理期防御 |\n"
"| " + u"OP(\u03c3=0.015)" + " | " + f"{op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 推理期防御 |\n"
"| " + u"OP(\u03c3=0.02)" + " | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 推理期防御 |\n\n"
"> 💡 **指标提示**: AUC越接近0.5 = 防御越有效;准确率越高 = 模型效用越好。"
)
with gr.Accordion("🚀 展开查看:实验流程规划", open=False):
gr.Markdown(
"| 阶段 | 内容 | 方法 |\n|---|---|---|\n"
"| 1. 数据准备 | 2000条数学辅导对话 | 模板化生成,含姓名/学号/成绩 |\n"
"| 2. 基线训练 | " + model_name + " + LoRA | 标准微调(r=8, alpha=16, 10 epochs) |\n"
"| 3. 防御训练 | " + u"\u03b5=0.02 / \u03b5=0.2" + " | 两组标签平滑参数分别训练 |\n"
"| 4. 攻击测试 | 3个模型 + 3组扰动 | Loss阈值判定,AUC评估 |\n"
"| 5. 效用评估 | 300道数学题 | 6种配置分别测试准确率 |\n"
"| 6. 综合分析 | 隐私-效用权衡 | 定量对比与可视化 |\n"
)
# ═══════ Tab 2: 数据与模型 ═══════
with gr.Tab("📁 数据与模型"):
gr.Markdown("## 📦 实验数据集概况")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(
"| 数据组 | 数量 | 用途 | 说明 |\n|---|---|---|---|\n"
"| 🔴 成员数据 | 1000条 | 模型训练 | 模型会\"记住\",Loss偏低 |\n"
"| 🟢 非成员数据 | 1000条 | 攻击对照 | 模型\"没见过\",Loss偏高 |\n\n"
"> ⚠️ 两组数据格式完全相同(均含隐私字段),这是MIA实验的标准设置——攻击者无法从格式区分。"
)
with gr.Column(scale=1):
gr.Markdown(
"| 任务类别 | 数量 | 占比 |\n|---|---|---|\n"
"| 🧮 基础计算 | 800 | 40% |\n| 📝 应用题 | 600 | 30% |\n| 🧠 概念问答 | 400 | 20% |\n| ✍️ 错题订正 | 200 | 10% |\n"
)
gr.Markdown("### 🔍 数据样例浏览提取")
with gr.Row(equal_height=True):
with gr.Column(scale=2):
d_src = gr.Radio(["成员数据(训练集)","非成员数据(测试集)"], value="成员数据(训练集)", label="选择靶向数据来源")
d_btn = gr.Button("🎲 随机提取样本", variant="primary")
d_meta = gr.Markdown()
with gr.Column(scale=3):
d_q = gr.Textbox(label="🧑🎓 学生提问 (Prompt)", lines=4, interactive=False)
d_a = gr.Textbox(label="🤖 标准回答 (Ground Truth)", lines=4, interactive=False)
d_btn.click(cb_sample, [d_src], [d_meta, d_q, d_a])
# ═══════ Tab 3: 攻击与防御验证 ═══════
with gr.Tab("🎯 攻击验证"):
gr.Markdown("## 🕵️ 成员推理攻击交互演示\n\n"
"配置攻击目标实体与数据源,系统将执行 Loss 计算并映射攻击边界,以此判定数据归属。")
with gr.Row(equal_height=True):
with gr.Column(scale=2):
a_target = gr.Radio([u"基线模型 (Baseline)",u"标签平滑 (\u03b5=0.02)",u"标签平滑 (\u03b5=0.2)",
u"输出扰动 (\u03c3=0.01)",u"输出扰动 (\u03c3=0.015)",u"输出扰动 (\u03c3=0.02)"],
value=u"基线模型 (Baseline)", label="选择攻击目标")
a_src = gr.Radio(["成员数据(训练集)","非成员数据(测试集)"], value="成员数据(训练集)", label="数据来源")
a_idx = gr.Slider(0, 999, step=1, value=12, label="定位样本 ID")
a_btn = gr.Button("⚡ 执行成员推理攻击", variant="primary", size="lg")
a_qtxt = gr.Markdown()
with gr.Column(scale=3):
a_gauge = gr.Plot(label="Loss位置判定 (Decision Boundary)")
a_res = gr.Markdown()
a_btn.click(cb_attack, [a_idx, a_src, a_target], [a_qtxt, a_gauge, a_res])
# ═══════ Tab 4: 防御效果分析 ═══════
with gr.Tab("🛡️ 防御分析"):
with gr.Accordion("📊 展开查看:防御对比直方图", open=False):
gr.Markdown("### MIA攻击AUC对比\n\n> 柱子越矮 = AUC越低 = 攻击越难成功 = 防御越有效")
gr.Plot(value=fig_auc_bar())
gr.Markdown("### Loss分布对比\n#### 三个模型(训练期防御效果)\n\n> 蓝色=成员,红色=非成员。两色重叠越多 = 攻击者越难区分")
gr.Plot(value=fig_loss_dist())
gr.Markdown("#### 输出扰动效果(推理期防御)\n\n> 在基线模型Loss上加噪声,随噪声增大分布更加重叠")
gr.Plot(value=fig_perturb_dist())
with gr.Accordion("⚙️ 展开查看:完整实验数据与机制说明", open=True):
gr.Markdown(
"### 完整实验数据表\n\n"
"| 策略 | 类型 | AUC | 准确率 | AUC变化 |\n|---|---|---|---|---|\n"
"| 基线 | — | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | — |\n"
"| " + u"LS(\u03b5=0.02)" + " | 训练期 | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | " + f"{s002_auc-bl_auc:+.4f}" + " |\n"
"| " + u"LS(\u03b5=0.2)" + " | 训练期 | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | " + f"{s02_auc-bl_auc:+.4f}" + " |\n"
"| " + u"OP(\u03c3=0.01)" + " | 推理期 | " + f"{op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op001_auc-bl_auc:+.4f}" + " |\n"
"| " + u"OP(\u03c3=0.015)" + " | 推理期 | " + f"{op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op0015_auc-bl_auc:+.4f}" + " |\n"
"| " + u"OP(\u03c3=0.02)" + " | 推理期 | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op002_auc-bl_auc:+.4f}" + " |\n\n"
"### 防御机制对比\n\n"
"| 维度 | 标签平滑 | 输出扰动 |\n|---|---|---|\n"
"| **阶段** | 训练期 | 推理期 |\n"
"| **原理** | 软化标签降低记忆 | Loss加噪遮蔽信号 |\n"
"| **需重训** | 是 | 否 |\n"
"| **效用** | 取决于参数 | 无 |\n"
"| **部署** | 训练时介入 | 即插即用 |\n\n"
"**标签平滑公式**: `y_smooth = (1 - ε) * y_onehot + ε / V`\n\n"
"**输出扰动公式**: `L_perturbed = L_original + N(0, σ²)`\n")
with gr.Accordion("🖼️ 展开查看:静态高分辨率学术图表", open=False):
for fn, cap in [("fig1_loss_distribution_comparison.png","Loss分布对比"),
("fig2_privacy_utility_tradeoff_fixed.png","隐私-效用权衡"),
("fig3_defense_comparison_bar.png","防御策略AUC对比")]:
p = os.path.join(BASE_DIR,"figures",fn)
if os.path.exists(p):
gr.Markdown("#### "+cap); gr.Image(value=p, show_label=False, height=420)
# ═══════ Tab 5: 效用评估 ═══════
with gr.Tab("⚖️ 效用评估"):
gr.Markdown("## 🎯 模型效用测试\n\n> 基于300道数学测试题评估各策略对模型实际能力的影响")
with gr.Row(equal_height=True):
with gr.Column(): gr.Plot(value=fig_acc_bar())
with gr.Column(): gr.Plot(value=fig_tradeoff())
gr.Markdown("## 🎮 在线效用抽样演示\n\n从测试题库中随机抽取,流式验证不同模型/策略的保留作答情况。")
with gr.Row(equal_height=True):
with gr.Column(scale=1):
e_model = gr.Radio([u"基线模型",u"标签平滑 (\u03b5=0.02)",u"标签平滑 (\u03b5=0.2)",
u"输出扰动 (\u03c3=0.01)",u"输出扰动 (\u03c3=0.015)",u"输出扰动 (\u03c3=0.02)"], value=u"基线模型", label="选择验证模型")
e_btn = gr.Button("🧪 随机抽题测试", variant="primary")
with gr.Column(scale=2):
e_res = gr.Markdown()
e_btn.click(cb_eval, [e_model], [e_res])
# ═══════ Tab 6: 研究结论 (引入折叠面板展开) ═══════
with gr.Tab("📝 研究结论"):
gr.Markdown("## 💡 核心研究发现与最佳实践\n\n---")
with gr.Accordion("🚨 一、教育大模型存在可量化的MIA风险", open=True):
gr.Markdown("基线模型 AUC = **" + f"{bl_auc:.4f}" + "** > 0.5,成员平均Loss (" + f"{bl_m_mean:.4f}"
+ ") 显著小于 非成员 (" + f"{bl_nm_mean:.4f}" + "),实验铁证表明模型对训练数据存在可利用的记忆效应。")
with gr.Accordion("🛡️ 二、标签平滑(训练期防御)有效性验证", open=True):
gr.Markdown(
"| 参数 | AUC | 准确率 | 分析 |\n|---|---|---|---|\n"
"| 基线 | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 无防御 |\n"
"| " + u"\u03b5=0.02" + " | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 正则化提升泛化 |\n"
"| " + u"\u03b5=0.2" + " | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 防御增强 |\n"
)
with gr.Accordion("🎛️ 三、输出扰动(推理期防御)有效性验证", open=True):
gr.Markdown(
"| 参数 | AUC | AUC降幅 | 准确率 |\n|---|---|---|---|\n"
"| " + u"\u03c3=0.01" + " | " + f"{op001_auc:.4f}" + " | " + f"{bl_auc-op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n"
"| " + u"\u03c3=0.015" + " | " + f"{op0015_auc:.4f}" + " | " + f"{bl_auc-op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n"
"| " + u"\u03c3=0.02" + " | " + f"{op002_auc:.4f}" + " | " + f"{bl_auc-op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n\n"
"**结论:零效用损失,适合已部署系统的后期加固。**"
)
with gr.Accordion("⚖️ 四、隐私-效用权衡总结", open=False):
gr.Markdown(
"| 策略 | AUC | 准确率 | 隐私 | 效用 |\n|---|---|---|---|---|\n"
"| 基线 | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 风险最高 | 基准 |\n"
"| " + u"LS(\u03b5=0.02)" + " | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 降低 | 提升 |\n"
"| " + u"LS(\u03b5=0.2)" + " | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 显著降低 | 可接受 |\n"
"| " + u"OP(\u03c3=0.02)" + " | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 显著降低 | 不变 |\n\n"
"> 两类策略机制互补:标签平滑从训练阶段降低记忆,输出扰动从推理阶段遮蔽信号。建议组合使用以构建立体防御体系。"
)
gr.HTML("<div style='text-align:center;color:#94a3b8;font-size:.82rem;padding:16px 0 8px'></div>")
demo.launch() |