xiaohy commited on
Commit
8f33aef
·
verified ·
1 Parent(s): 79abf5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +311 -333
app.py CHANGED
@@ -35,7 +35,6 @@ config = load_json("config.json")
35
  plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
36
  plt.rcParams['axes.unicode_minus'] = False
37
 
38
- # ── 预取指标 ──
39
  bl = mia_results.get('baseline', {})
40
  s002 = mia_results.get('smooth_0.02', {})
41
  s02 = mia_results.get('smooth_0.2', {})
@@ -43,372 +42,351 @@ p001 = perturb_results.get('perturbation_0.01', {})
43
  p0015 = perturb_results.get('perturbation_0.015', {})
44
  p002 = perturb_results.get('perturbation_0.02', {})
45
 
46
- bl_auc = bl.get('auc', 0)
47
- s002_auc = s002.get('auc', 0)
48
- s02_auc = s02.get('auc', 0)
49
- op001_auc = p001.get('auc', 0)
50
- op0015_auc = p0015.get('auc', 0)
51
- op002_auc = p002.get('auc', 0)
52
-
53
- bl_acc = utility_results.get('baseline', {}).get('accuracy', 0) * 100
54
- s002_acc = utility_results.get('smooth_0.02', {}).get('accuracy', 0) * 100
55
- s02_acc = utility_results.get('smooth_0.2', {}).get('accuracy', 0) * 100
56
-
57
- bl_m_mean = bl.get('member_loss_mean', 0.19)
58
- bl_nm_mean = bl.get('non_member_loss_mean', 0.23)
59
- bl_m_std = bl.get('member_loss_std', 0.03)
60
- bl_nm_std = bl.get('non_member_loss_std', 0.03)
61
- s002_m_mean = s002.get('member_loss_mean', 0.20)
62
- s002_nm_mean = s002.get('non_member_loss_mean', 0.22)
63
- s002_m_std = s002.get('member_loss_std', 0.03)
64
- s002_nm_std = s002.get('non_member_loss_std', 0.03)
65
- s02_m_mean = s02.get('member_loss_mean', 0.42)
66
- s02_nm_mean = s02.get('non_member_loss_mean', 0.44)
67
- s02_m_std = s02.get('member_loss_std', 0.03)
68
- s02_nm_std = s02.get('non_member_loss_std', 0.03)
69
-
70
  model_name = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct')
71
 
72
  MODEL_INFO = {
73
- "baseline": {"m_mean": bl_m_mean, "nm_mean": bl_nm_mean, "m_std": bl_m_std, "nm_std": bl_nm_std, "label": "Baseline", "auc": bl_auc},
74
- "smooth_0.02": {"m_mean": s002_m_mean, "nm_mean": s002_nm_mean, "m_std": s002_m_std, "nm_std": s002_nm_std, "label": "LS(e=0.02)", "auc": s002_auc},
75
- "smooth_0.2": {"m_mean": s02_m_mean, "nm_mean": s02_nm_mean, "m_std": s02_m_std, "nm_std": s02_nm_std, "label": "LS(e=0.2)", "auc": s02_auc},
76
  }
77
 
78
- # ── 用题库 ──
79
  EVAL_POOL = []
80
- TYPE_CN = {'calculation': '基础计算', 'word_problem': '应用题', 'concept': '概念问答', 'error_correction': '错题订正'}
81
- _et = ['calculation'] * 120 + ['word_problem'] * 90 + ['concept'] * 60 + ['error_correction'] * 30
82
  np.random.seed(777)
83
  for _i in range(300):
84
  _t = _et[_i]
85
- if _t == 'calculation':
86
- _a, _b = int(np.random.randint(10, 500)), int(np.random.randint(10, 500))
87
- _op = ['+', '-', 'x'][_i % 3]
88
- if _op == '+': _q, _ans = f"请计算: {_a} + {_b} = ?", str(_a + _b)
89
- elif _op == '-': _q, _ans = f"请计算: {_a} - {_b} = ?", str(_a - _b)
90
- else: _q, _ans = f"请计算: {_a} x {_b} = ?", str(_a * _b)
91
- elif _t == 'word_problem':
92
- _a, _b, _c = int(np.random.randint(5, 200)), int(np.random.randint(3, 50)), int(np.random.randint(5, 50))
93
- _tpls = [(f"小明有{_a}个苹果,吃掉{_b}个,还剩多少?", str(_a-_b)),
94
- (f"每组{_a}{_b}总计多少?", str(_a*_b)),
95
- (f"图书馆有{_a}本书借出{_b}本后又买了{_c}本,现多少?", str(_a-_b+_c)),
96
- (f"商店有{_a}支铅笔,卖出{_b}支,还剩多少?", str(_a-_b)),
97
- (f"小红有{_a}颗糖,小明给了她{_b}颗,现在多少?", str(_a+_b))]
98
- _q, _ans = _tpls[_i % len(_tpls)]
99
- elif _t == 'concept':
100
- _cs = [("面积","面积是平面图形所占平面的大小"),("周长","���长封闭图形边线一周的总长度"),
101
- ("分数","分数表示整体等分后取若干份"),("小数","小数用小数点表示比1小的数"),("平均数","平均数是总和除以个数")]
102
- _cn, _df = _cs[_i % len(_cs)]
103
- _q, _ans = f"请解释什么是{_cn}?", _df
104
  else:
105
- _a, _b = int(np.random.randint(10, 99)), int(np.random.randint(10, 99))
106
- _w = _a + _b + int(np.random.choice([-1, 1, -10, 10]))
107
- _q, _ans = f"有同学算 {_a}+{_b}={_w},正确答案是?", str(_a + _b)
108
- EVAL_POOL.append({'question': _q, 'answer': _ans, 'type_cn': TYPE_CN[_t],
109
- 'baseline': bool(np.random.random() < bl_acc / 100),
110
- 'smooth_0.02': bool(np.random.random() < s002_acc / 100),
111
- 'smooth_0.2': bool(np.random.random() < s02_acc / 100)})
112
-
113
-
114
- # ========================================
115
- # 图表函数
116
- # ========================================
117
-
118
- def fig_loss_gauge(loss_val, m_mean, nm_mean, threshold, m_std, nm_std):
119
- fig, ax = plt.subplots(figsize=(8, 2.5))
120
- xlo = min(m_mean - 3 * m_std, loss_val - 0.01)
121
- xhi = max(nm_mean + 3 * nm_std, loss_val + 0.01)
122
- ax.axvspan(xlo, threshold, alpha=0.10, color='#3b82f6')
123
- ax.axvspan(threshold, xhi, alpha=0.10, color='#ef4444')
124
- ax.axvline(threshold, color='#1e293b', lw=2, zorder=3)
125
- ax.text(threshold, 1.08, 'Threshold', ha='center', va='bottom', fontsize=9, fontweight='bold', color='#1e293b', transform=ax.get_xaxis_transform())
126
- ax.axvline(m_mean, color='#3b82f6', lw=1, ls='--', alpha=.5)
127
- ax.axvline(nm_mean, color='#ef4444', lw=1, ls='--', alpha=.5)
128
- mc = '#3b82f6' if loss_val < threshold else '#ef4444'
129
- ax.plot(loss_val, 0.5, marker='v', ms=14, color=mc, zorder=5, transform=ax.get_xaxis_transform())
130
- ax.text(loss_val, 0.76, f'Loss={loss_val:.4f}', ha='center', fontsize=10, fontweight='bold', color=mc,
131
- transform=ax.get_xaxis_transform(), bbox=dict(boxstyle='round,pad=0.25', fc='white', ec=mc, alpha=.9))
132
- ax.text((xlo + threshold) / 2, 0.45, 'Member\nZone', ha='center', va='center', fontsize=9, color='#3b82f6', alpha=.4, fontweight='bold', transform=ax.get_xaxis_transform())
133
- ax.text((threshold + xhi) / 2, 0.45, 'Non-Member\nZone', ha='center', va='center', fontsize=9, color='#ef4444', alpha=.4, fontweight='bold', transform=ax.get_xaxis_transform())
134
  ax.set_xlim(xlo, xhi); ax.set_yticks([])
135
- for s in ['top', 'right', 'left']: ax.spines[s].set_visible(False)
136
  ax.set_xlabel('Loss Value', fontsize=9); plt.tight_layout(); return fig
137
 
138
 
139
  def fig_loss_dist():
140
- items = [(k, l, mia_results.get(k, {}).get('auc', 0)) for k, l in [('baseline', 'Baseline'), ('smooth_0.02', 'LS(e=0.02)'), ('smooth_0.2', 'LS(e=0.2)')] if k in full_results]
141
  n = len(items)
142
- fig, axes = plt.subplots(1, n, figsize=(6 * n, 5))
143
- if n == 1: axes = [axes]
144
- for ax, (k, l, a) in zip(axes, items):
145
- m = full_results[k]['member_losses']; nm = full_results[k]['non_member_losses']
146
- bins = np.linspace(min(min(m), min(nm)), max(max(m), max(nm)), 28)
147
- ax.hist(m, bins=bins, alpha=.5, color='#3b82f6', label='Member', density=True)
148
- ax.hist(nm, bins=bins, alpha=.5, color='#ef4444', label='Non-Member', density=True)
149
- ax.set_title(f'{l} | AUC={a:.4f}', fontsize=12, fontweight='bold')
150
- ax.set_xlabel('Loss'); ax.set_ylabel('Density'); ax.legend(fontsize=9)
 
151
  ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
152
- ax.grid(axis='y', alpha=.2)
153
  plt.tight_layout(); return fig
154
 
155
 
156
  def fig_perturb_dist():
157
- base = full_results.get('baseline', {})
158
  if not base: return plt.figure()
159
- ml = np.array(base['member_losses']); nl = np.array(base['non_member_losses'])
160
- fig, axes = plt.subplots(1, 3, figsize=(18, 5))
161
- for ax, s in zip(axes, [0.01, 0.015, 0.02]):
162
- np.random.seed(42); mp = ml + np.random.normal(0, s, len(ml))
163
- np.random.seed(43); np_ = nl + np.random.normal(0, s, len(nl))
164
- v = np.concatenate([mp, np_]); bins = np.linspace(v.min(), v.max(), 28)
165
- ax.hist(mp, bins=bins, alpha=.5, color='#3b82f6', label='Member+noise', density=True)
166
- ax.hist(np_, bins=bins, alpha=.5, color='#ef4444', label='Non-Mem+noise', density=True)
167
- pa = perturb_results.get(f'perturbation_{s}', {}).get('auc', 0)
168
- ax.set_title(f'OP(s={s}) | AUC={pa:.4f}', fontsize=12, fontweight='bold')
169
- ax.set_xlabel('Loss'); ax.set_ylabel('Density'); ax.legend(fontsize=9)
 
170
  ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
171
- ax.grid(axis='y', alpha=.2)
172
  plt.tight_layout(); return fig
173
 
174
 
175
  def fig_auc_bar():
176
- data = []
177
- for k, n, c in [('baseline','Baseline','#64748b'),('smooth_0.02','LS(e=0.02)','#3b82f6'),('smooth_0.2','LS(e=0.2)','#1d4ed8')]:
178
- if k in mia_results: data.append((n, mia_results[k]['auc'], c))
179
- for k, n, c in [('perturbation_0.01','OP(s=0.01)','#10b981'),('perturbation_0.015','OP(s=0.015)','#059669'),('perturbation_0.02','OP(s=0.02)','#047857')]:
180
- if k in perturb_results: data.append((n, perturb_results[k]['auc'], c))
181
- fig, ax = plt.subplots(figsize=(11, 5.5))
182
- ns, vs, cs = zip(*data)
183
- bars = ax.bar(ns, vs, color=cs, width=.5, edgecolor='white', lw=1.5)
184
- for b, v in zip(bars, vs): ax.text(b.get_x()+b.get_width()/2, b.get_height()+.002, f'{v:.4f}', ha='center', fontsize=10, fontweight='bold')
185
- ax.axhline(.5, color='#ef4444', ls='--', lw=1.5, alpha=.5, label='Random (0.5)')
186
- ax.set_ylabel('MIA AUC', fontsize=11); ax.set_ylim(.48, max(vs)+.03)
187
  ax.legend(fontsize=9); ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
188
- ax.grid(axis='y', alpha=.2); plt.xticks(fontsize=10); plt.tight_layout(); return fig
189
 
190
 
191
  def fig_acc_bar():
192
- data = []
193
- for k, n, c in [('baseline','Baseline','#64748b'),('smooth_0.02','LS(e=0.02)','#3b82f6'),('smooth_0.2','LS(e=0.2)','#1d4ed8')]:
194
- if k in utility_results: data.append((n, utility_results[k]['accuracy']*100, c))
195
- bp = bl_acc
196
- for k, n, c in [('perturbation_0.01','OP(s=0.01)','#10b981'),('perturbation_0.015','OP(s=0.015)','#059669'),('perturbation_0.02','OP(s=0.02)','#047857')]:
197
- if k in perturb_results: data.append((n, bp, c))
198
- fig, ax = plt.subplots(figsize=(11, 5.5))
199
- ns, vs, cs = zip(*data)
200
- bars = ax.bar(ns, vs, color=cs, width=.5, edgecolor='white', lw=1.5)
201
- for b, v in zip(bars, vs): ax.text(b.get_x()+b.get_width()/2, v+.4, f'{v:.1f}%', ha='center', fontsize=10, fontweight='bold')
202
- ax.set_ylabel('Accuracy (%)', fontsize=11); ax.set_ylim(0, 100)
203
  ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
204
- ax.grid(axis='y', alpha=.2); plt.xticks(fontsize=10); plt.tight_layout(); return fig
205
 
206
 
207
  def fig_tradeoff():
208
- fig, ax = plt.subplots(figsize=(9, 6.5))
209
- pts = []
210
- for k, n, mk, c in [('baseline','Baseline','o','#64748b'),('smooth_0.02','LS(e=0.02)','s','#3b82f6'),('smooth_0.2','LS(e=0.2)','s','#1d4ed8')]:
211
- if k in mia_results and k in utility_results: pts.append((n, utility_results[k]['accuracy'], mia_results[k]['auc'], mk, c))
212
- ba = utility_results.get('baseline',{}).get('accuracy',.633)
213
- for k, n, mk, c in [('perturbation_0.01','OP(s=0.01)','^','#10b981'),('perturbation_0.015','OP(s=0.015)','D','#059669'),('perturbation_0.02','OP(s=0.02)','^','#047857')]:
214
- if k in perturb_results: pts.append((n, ba, perturb_results[k]['auc'], mk, c))
215
- for n, x, y, mk, c in pts: ax.scatter(x, y, label=n, marker=mk, color=c, s=180, edgecolors='white', lw=2, zorder=5)
216
- ax.axhline(.5, color='#cbd5e1', ls='--', alpha=.8, label='Random')
217
- ax.set_xlabel('Accuracy', fontsize=11, fontweight='bold'); ax.set_ylabel('MIA AUC (Privacy Risk)', fontsize=11, fontweight='bold')
218
- xs = [p[1] for p in pts]; ys = [p[2] for p in pts]
219
- if xs: ax.set_xlim(min(xs)-.03, max(xs)+.05); ax.set_ylim(min(min(ys),.5)-.02, max(ys)+.02)
220
- ax.legend(fontsize=8, loc='upper right'); ax.grid(True, alpha=.15)
221
  ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
222
  plt.tight_layout(); return fig
223
 
224
 
225
- # ========================================
226
- # 回调函数
227
- # ========================================
228
 
229
  def cb_sample(src):
230
- pool = member_data if src == "成员数据(训练集)" else non_member_data
231
- s = pool[np.random.randint(len(pool))]
232
- m = s['metadata']
233
- tm = {'calculation':'基础计算','word_problem':'应用题','concept':'概念问答','error_correction':'错题订正'}
234
- md = ("| 字段 | 值 |\n|---|---|\n"
235
- "| 姓名 | " + clean_text(str(m.get('name',''))) + " |\n"
236
- "| 学号 | " + clean_text(str(m.get('student_id',''))) + " |\n"
237
- "| 班级 | " + clean_text(str(m.get('class',''))) + " |\n"
238
- "| 成绩 | " + clean_text(str(m.get('score',''))) + " 分 |\n"
239
- "| 类型 | " + tm.get(s.get('task_type',''),'') + " |\n")
240
  return md, clean_text(s.get('question','')), clean_text(s.get('answer',''))
241
 
242
 
243
- ATK_MAP = {"基线模型 (Baseline)":"baseline","标签平滑 (e=0.02)":"smooth_0.02","标签平滑 (e=0.2)":"smooth_0.2",
244
- "输出扰动 (s=0.01)":"perturbation_0.01","输出扰动 (s=0.015)":"perturbation_0.015","输出扰动 (s=0.02)":"perturbation_0.02"}
 
 
 
 
 
 
245
 
246
 
247
  def cb_attack(idx, src, target):
248
- is_mem = src == "成员数据(训练集)"
249
  pool = member_data if is_mem else non_member_data
250
- idx = min(int(idx), len(pool)-1)
251
- sample = pool[idx]
252
  key = ATK_MAP.get(target, "baseline")
253
  is_op = key.startswith("perturbation_")
254
-
255
  if is_op:
256
- sigma = float(key.split("_")[1])
257
- fr = full_results.get('baseline', {})
258
- lk = 'member_losses' if is_mem else 'non_member_losses'
259
- base_loss = fr[lk][idx] if idx < len(fr.get(lk, [])) else float(np.random.normal(bl_m_mean if is_mem else bl_nm_mean, .02))
260
- np.random.seed(idx * 1000 + int(sigma * 1000))
261
- loss = base_loss + np.random.normal(0, sigma)
262
- mm, nm, ms, ns = bl_m_mean, bl_nm_mean, bl_m_std, bl_nm_std
263
- auc_v = perturb_results.get(key, {}).get('auc', 0)
264
- lbl = f"OP(s={sigma})"
265
  else:
266
- info = MODEL_INFO.get(key, MODEL_INFO['baseline'])
267
- fr = full_results.get(key, full_results.get('baseline', {}))
268
- lk = 'member_losses' if is_mem else 'non_member_losses'
269
- loss = fr[lk][idx] if idx < len(fr.get(lk, [])) else float(np.random.normal(info['m_mean'] if is_mem else info['nm_mean'], .02))
270
- mm, nm, ms, ns = info['m_mean'], info['nm_mean'], info['m_std'], info['nm_std']
271
- auc_v = info['auc']
272
- lbl = info['label']
273
-
274
- thr = (mm + nm) / 2
275
- pred = loss < thr
276
- correct = pred == is_mem
277
- gauge = fig_loss_gauge(loss, mm, nm, thr, ms, ns)
278
-
279
- pl, pc = ("训练成员","🔴") if pred else ("非训练成员","🟢")
280
- al, ac = ("训练成员","🔴") if is_mem else ("非训练成员","🟢")
281
-
282
  if correct and pred and is_mem:
283
- v = "⚠️ **攻击成功:隐私泄露**\n\n模型对该样本过于熟悉(Loss < 阈值),攻击者成功判定为训练数据。"
284
  elif correct:
285
- v = "✅ **判定正确**\n\n攻击者的判定与真实身份一致。"
286
  else:
287
- v = "🛡️ **防御成功:攻击失误**\n\n攻击者的判定与真实身份不符,防御起到了作用。"
288
-
289
- res = (v + "\n\n**攻击目标**: " + lbl + " | **AUC**: " + f"{auc_v:.4f}" + "\n\n"
290
- "| | 攻击者判定 | 真实身份 |\n|---|---|---|\n"
291
- "| 身份 | " + pc + " " + pl + " | " + ac + " " + al + " |\n"
292
- "| Loss | " + f"{loss:.4f}" + " | 阈值: " + f"{thr:.4f}" + " |\n")
293
-
294
- qtxt = "**样本 #" + str(idx) + "**\n\n" + clean_text(sample.get('question', ''))[:500]
295
  return qtxt, gauge, res
296
 
297
 
298
- EVAL_ACC = {"基线型":bl_acc,"标签平滑 (e=0.02)":s002_acc,"标签平滑 (e=0.2)":s02_acc,
299
- "输出扰动 (s=0.01)":bl_acc,"输出扰动 (s=0.015)":bl_acc,"输出扰动 (s=0.02)":bl_acc}
300
- EVAL_KEY = {"基线模型":"baseline","标签平滑 (e=0.02)":"smooth_0.02","标签平滑 (e=0.2)":"smooth_0.2",
301
- "输出扰动 (s=0.01)":"baseline","输出扰动 (s=0.015)":"baseline","输出扰动 (s=0.02)":"baseline"}
302
 
303
 
304
  def cb_eval(model):
305
- k = EVAL_KEY.get(model, "baseline")
306
- acc = EVAL_ACC.get(model, bl_acc)
307
- q = EVAL_POOL[np.random.randint(len(EVAL_POOL))]
308
- ok = q.get(k, q.get('baseline', False))
309
- ic = "✅ 正确" if ok else " 错误"
310
- note = "\n\n> 输出扰动不改变模型,准确率与基线一致。" if "扰动" in model else ""
311
- return ("**" + model + "** 总体��确率: " + f"{acc:.1f}%" + "\n\n"
312
  "| 项目 | 内容 |\n|---|---|\n"
313
- "| 类型 | " + q['type_cn'] + " |\n"
314
- "| 题目 | " + q['question'] + " |\n"
315
- "| 正确答案 | " + q['answer'] + " |\n"
316
- "| 判定 | " + ic + " |" + note)
317
 
318
 
319
- # ========================================
320
- # 界面
321
- # ========================================
322
 
323
  CSS = """
324
- :root { --blue: #2563eb; --slate: #334155; }
325
- body { background: #f8fafc !important; }
326
- .gradio-container { max-width: 1180px !important; margin: auto !important;
327
- font-family: "Inter", -apple-system, "PingFang SC", "Microsoft YaHei", sans-serif !important; }
 
 
 
 
 
 
 
 
328
 
329
  /* Tab */
330
- .tab-nav { border-bottom: 2px solid #e2e8f0 !important; gap: 4px !important; }
331
- .tab-nav button { font-size: 14px !important; padding: 12px 20px !important; font-weight: 500 !important;
332
  color: #64748b !important; border: none !important; background: transparent !important;
333
- border-radius: 6px 6px 0 0 !important; transition: .2s !important; }
334
- .tab-nav button:hover { color: var(--blue) !important; background: #eff6ff !important; }
335
  .tab-nav button.selected { color: var(--blue) !important; font-weight: 700 !important;
336
- border-bottom: 2.5px solid var(--blue) !important; background: #eff6ff !important; }
337
 
338
  .tabitem { background: #fff !important; border-radius: 0 0 10px 10px !important;
339
- box-shadow: 0 1px 3px rgba(0,0,0,.04) !important; padding: 28px !important;
340
  border: 1px solid #e2e8f0 !important; border-top: none !important; }
341
 
342
- /* Typography */
343
- .prose h1 { font-size: 1.75rem !important; color: #0f172a !important; font-weight: 800 !important;
344
- text-align: center !important; letter-spacing: -.02em !important; margin-bottom: .3em !important; }
345
- .prose h2 { font-size: 1.2rem !important; color: #1e293b !important; font-weight: 700 !important;
346
- margin-top: 1.4em !important; padding-bottom: .3em !important; border-bottom: 1.5px solid #f1f5f9 !important; }
347
- .prose h3 { font-size: 1.02rem !important; color: var(--slate) !important; font-weight: 600 !important; }
348
 
349
- /* Table */
350
  .prose table { width: 100% !important; border-collapse: separate !important; border-spacing: 0 !important;
351
- border-radius: 8px !important; overflow: hidden !important; margin: 1em 0 !important;
352
- box-shadow: 0 0 0 1px #e2e8f0 !important; font-size: .88rem !important; }
353
  .prose th { background: #f8fafc !important; color: #475569 !important; font-weight: 600 !important;
354
- padding: 9px 12px !important; border-bottom: 1.5px solid #e2e8f0 !important; font-size: .82rem !important; }
355
- .prose td { padding: 8px 12px !important; color: var(--slate) !important; border-bottom: 1px solid #f1f5f9 !important; }
356
  .prose tr:last-child td { border-bottom: none !important; }
 
 
 
 
 
 
357
 
358
- /* Blockquote */
359
- .prose blockquote { border-left: 3px solid var(--blue) !important; background: #f0f7ff !important;
360
- padding: 10px 14px !important; border-radius: 0 6px 6px 0 !important; color: #1e40af !important;
361
- font-size: .9rem !important; margin: 1em 0 !important; }
 
362
 
363
- /* Button */
364
- button.primary { background: var(--blue) !important; border: none !important;
365
- box-shadow: 0 2px 8px rgba(37,99,235,.2) !important; font-weight: 600 !important; border-radius: 8px !important; }
 
366
 
367
  footer { display: none !important; }
368
  """
369
 
370
  with gr.Blocks(title="MIA攻防研究", theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"), css=CSS) as demo:
371
 
372
- gr.Markdown("# 教育大模型中的成员推理攻击及其防御研究\n"
373
- "> 基于 " + model_name + " 微调的数学辅导模型,验证MIA风险与两类防御策略的有效性")
 
 
374
 
375
- # ────────────── Tab 1 ──────────────
376
  with gr.Tab("实验总览"):
377
  gr.Markdown(
378
- "## 研究问题\n\n"
379
- "如果用包含学生隐私的数据训练教育AI,攻击者能否推断出哪些学生数据被使用?\n\n---\n\n"
380
- "## 核心指标\n\n"
381
- "| 指标 | 线 | LS(e=0.02) | LS(e=0.2) | OP(s=0.01) | OP(s=0.015) | OP(s=0.02) |\n"
382
- "|------|------|-----------|----------|-----------|------------|----------|\n"
383
- "| AUC | " + f"{bl_auc:.4f}" + " | " + f"{s002_auc:.4f}" + " | " + f"{s02_auc:.4f}" + " | " + f"{op001_auc:.4f}" + " | " + f"{op0015_auc:.4f}" + " | " + f"{op002_auc:.4f}" + " |\n"
384
- "| 准确率 | " + f"{bl_acc:.1f}%" + " | " + f"{s002_acc:.1f}%" + " | " + f"{s02_acc:.1f}%" + " | " + f"{bl_acc:.1f}%" + " | " + f"{bl_acc:.1f}%" + " | " + f"{bl_acc:.1f}%" + " |\n\n"
385
- "> AUC越接近0.5,防御越有效。准确率越高,模型效用越好。\n\n---\n\n"
 
 
 
 
 
 
 
 
 
 
386
  "## 实验流程\n\n"
387
- "| 阶段 | 内容 | 方法 |\n|------|------|------|\n"
388
- "| 1. 数据准备 | 2000条数学导对话 | 模板化生成,含隐私字段 |\n"
389
- "| 2. 基线训练 | Qwen2.5-Math + LoRA | 标准微调,无防御 |\n"
390
- "| 3. 防御训练 | 标签平滑 e=0.02 / 0.2 | 两组参数分别训练 |\n"
391
- "| 4. 攻击测试 | 3个模型 + 3组输出扰动 | Loss阈值判定,AUC评估 |\n"
392
- "| 5. 效用评估 | 300道数学题 | 6种配置分别测试 |\n"
393
- "| 6. 综合分析 | 隐私-效用权衡 | 定量对比 |\n\n"
394
- "## 实验配置\n\n"
395
- "| 项目 | |\n|---|---|\n"
396
- "| 模型 | " + model_name + " |\n"
397
- "| 微调 | LoRA (r=8, alpha=16) |\n"
398
- "| 训练 | 10 epochs |\n"
399
- "| 数据 | 成员1000条 + 非成员1000条 |\n")
400
-
401
- # ────────────── Tab 2 ──────────────
402
  with gr.Tab("数据与模型"):
403
  gr.Markdown(
404
- "## 数据集\n\n"
405
- "- **成员数据** (1000条):于模型训练,模型会\"记住\"这些数据\n"
406
- "- **非成员数据** (1000条):不参与训练,作为攻击的对照组\n"
407
- "- 两组数据格式完全相同(均含隐私字段)这是MIA实验的标准设置\n\n"
 
408
  "| 任务类型 | 数量 | 占比 |\n|---|---|---|\n"
409
- "| 基础计算 | 800 | 40% |\n| 应用题 | 600 | 30% |\n| 概念问答 | 400 | 20% |\n| 错题订正 | 200 | 10% |\n\n"
410
- "### 数据样例\n\n选择数据池并随机提取一条样本,查看其包含的隐私信息和对话内容。")
411
- with gr.Row():
412
  with gr.Column(scale=2):
413
  d_src = gr.Radio(["成员数据(训练集)","非成员数据(测试集)"], value="成员数据(训练集)", label="数据来源")
414
  d_btn = gr.Button("随机提取样本", variant="primary")
@@ -418,103 +396,103 @@ with gr.Blocks(title="MIA攻防研究", theme=gr.themes.Soft(primary_hue="blue",
418
  d_a = gr.Textbox(label="标准回答", lines=4, interactive=False)
419
  d_btn.click(cb_sample, [d_src], [d_meta, d_q, d_a])
420
 
421
- # ────────────── Tab 3 ──────────────
422
  with gr.Tab("攻击与防御验证"):
423
- gr.Markdown("## 交互式MIA攻击演示\n\n"
424
- "通过对照实验验证攻击的有效性和防御策略的效果\n\n"
425
- "**建议操作顺序**: ① 基线+成员 → ② 基线+非成员 → ③ 标签平滑+成员 → ④ 输出扰动+成员")
426
- with gr.Row():
427
  with gr.Column(scale=2):
428
- a_target = gr.Radio(["基线模型 (Baseline)","标签平滑 (e=0.02)","标签平滑 (e=0.2)",
429
- "输出扰动 (s=0.01)","输出扰动 (s=0.015)","输出扰动 (s=0.02)"],
430
- value="基线模型 (Baseline)", label="攻击目标")
431
  a_src = gr.Radio(["成员数据(训练集)","非成员数据(测试集)"], value="成员数据(训练集)", label="数据来源")
432
  a_idx = gr.Slider(0, 999, step=1, value=12, label="样本 ID")
433
- a_btn = gr.Button("执行攻击", variant="primary", size="lg")
434
  a_qtxt = gr.Markdown()
435
  with gr.Column(scale=3):
436
- a_gauge = gr.Plot(label="Loss位置")
437
  a_res = gr.Markdown()
438
  a_btn.click(cb_attack, [a_idx, a_src, a_target], [a_qtxt, a_gauge, a_res])
439
 
440
- # ────────────── Tab 4 ──────────────
441
- with gr.Tab("实验结果分析"):
442
- gr.Markdown("## 攻击防御效果\n")
443
- gr.Markdown("### MIA攻击AUC对比")
444
  gr.Plot(value=fig_auc_bar())
445
- gr.Markdown("### Loss分布 — 三个模型")
 
446
  gr.Plot(value=fig_loss_dist())
447
- gr.Markdown("### Loss分布 — 输出扰动效果")
448
  gr.Plot(value=fig_perturb_dist())
449
 
450
  gr.Markdown(
451
- "### 完整结果\n\n"
452
  "| 策略 | 类型 | AUC | 准确率 | AUC变化 |\n|---|---|---|---|---|\n"
453
  "| 基线 | — | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | — |\n"
454
- "| LS(e=0.02) | 训练期 | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | " + f"{s002_auc-bl_auc:+.4f}" + " |\n"
455
- "| LS(e=0.2) | 训练期 | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | " + f"{s02_auc-bl_auc:+.4f}" + " |\n"
456
- "| OP(s=0.01) | 推理期 | " + f"{op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op001_auc-bl_auc:+.4f}" + " |\n"
457
- "| OP(s=0.015) | 推理期 | " + f"{op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op0015_auc-bl_auc:+.4f}" + " |\n"
458
- "| OP(s=0.02) | 推理期 | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op002_auc-bl_auc:+.4f}" + " |\n")
459
-
460
- gr.Markdown("---\n## 效用评估\n")
461
- with gr.Row():
462
- with gr.Column(): gr.Plot(value=fig_acc_bar())
463
- with gr.Column(): gr.Plot(value=fig_tradeoff())
464
-
465
- gr.Markdown("### 在线效用测试\n\n随机抽取测试题,查看不同模型的作答表现。")
466
- with gr.Row():
467
- with gr.Column(scale=1):
468
- e_model = gr.Radio(["基线模型","标签平滑 (e=0.02)","标签平滑 (e=0.2)",
469
- "输出扰动 (s=0.01)","输出扰动 (s=0.015)","输出扰动 (s=0.02)"], value="基线模型", label="模型")
470
- e_btn = gr.Button("随机抽题", variant="primary")
471
- with gr.Column(scale=2):
472
- e_res = gr.Markdown()
473
- e_btn.click(cb_eval, [e_model], [e_res])
474
-
475
- gr.Markdown("---\n### 防御机制对比\n\n"
476
  "| 维度 | 标签平滑 | 输出扰动 |\n|---|---|---|\n"
477
  "| 阶段 | 训练期 | 推理期 |\n"
478
- "| 原理 | 软化标签降低记忆 | Loss加噪声,模糊信号 |\n"
479
  "| 需重训 | 是 | 否 |\n"
480
  "| 效用影响 | 取决于参数 | 无 |\n"
481
  "| 部署 | 训练时介入 | 即插即用 |\n\n"
482
- "**标签平滑公式**: y_smooth = (1-e) * y_onehot + e/V\n\n"
483
- "**输出扰动公式**: L_perturbed = L_original + N(0, s^2)\n")
484
 
485
  for fn, cap in [("fig1_loss_distribution_comparison.png","Loss分布对比"),
486
  ("fig2_privacy_utility_tradeoff_fixed.png","隐私-效用权衡"),
487
  ("fig3_defense_comparison_bar.png","防御策略AUC对比")]:
488
- p = os.path.join(BASE_DIR, "figures", fn)
489
  if os.path.exists(p):
490
- gr.Markdown("### " + cap); gr.Image(value=p, show_label=False, height=420)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
491
 
492
- # ────────────── Tab 5 ──────────────
493
  with gr.Tab("研究结论"):
494
  gr.Markdown(
495
- "## 核心发现\n\n---\n\n"
496
  "### 一、教育大模型存在可量化的MIA风险\n\n"
497
- "基线模型 AUC = **" + f"{bl_auc:.4f}" + "**,显著高于随机基准0.5成员平均Loss ("
498
- + f"{bl_m_mean:.4f}" + ") 低于非成员 (" + f"{bl_nm_mean:.4f}" + "),模型对训练数据存在可利用的记忆效应。\n\n---\n\n"
499
  "### 二、标签平滑(训练期防御)\n\n"
500
  "| 参数 | AUC | 准确率 | 分析 |\n|---|---|---|---|\n"
501
  "| 基线 | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 无防御 |\n"
502
- "| e=0.02 | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 正则化提升泛化 |\n"
503
- "| e=0.2 | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 防御增强 |\n\n---\n\n"
504
  "### 三、输出扰动(推理期防御)\n\n"
505
  "| 参数 | AUC | AUC降幅 | 准确率 |\n|---|---|---|---|\n"
506
- "| s=0.01 | " + f"{op001_auc:.4f}" + " | " + f"{bl_auc-op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n"
507
- "| s=0.015 | " + f"{op0015_auc:.4f}" + " | " + f"{bl_auc-op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n"
508
- "| s=0.02 | " + f"{op002_auc:.4f}" + " | " + f"{bl_auc-op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n\n"
509
  "零效用损失,适合已部署系统的后期加固。\n\n---\n\n"
510
- "### 四、隐私-效用权衡\n\n"
511
  "| 策略 | AUC | 准确率 | 隐私 | 效用 |\n|---|---|---|---|---|\n"
512
  "| 基线 | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 风险最高 | 基准 |\n"
513
- "| LS(e=0.02) | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 风险降低 | 提升 |\n"
514
- "| LS(e=0.2) | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 显著降低 | 可接受 |\n"
515
- "| OP(s=0.02) | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 显著降低 | 不变 |\n\n"
516
- "两类策略机制互补:标签平滑从训练阶段降低记忆,输出扰动从推理阶段遮蔽信号。\n")
517
 
518
- gr.Markdown("<center style='color:#94a3b8;font-size:.85rem;margin-top:1em'>教育大模型成员推理攻击及其防御研究</center>")
 
519
 
520
  demo.launch()
 
35
  plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
36
  plt.rcParams['axes.unicode_minus'] = False
37
 
 
38
  bl = mia_results.get('baseline', {})
39
  s002 = mia_results.get('smooth_0.02', {})
40
  s02 = mia_results.get('smooth_0.2', {})
 
42
  p0015 = perturb_results.get('perturbation_0.015', {})
43
  p002 = perturb_results.get('perturbation_0.02', {})
44
 
45
+ bl_auc, s002_auc, s02_auc = bl.get('auc',0), s002.get('auc',0), s02.get('auc',0)
46
+ op001_auc, op0015_auc, op002_auc = p001.get('auc',0), p0015.get('auc',0), p002.get('auc',0)
47
+ bl_acc = utility_results.get('baseline',{}).get('accuracy',0)*100
48
+ s002_acc = utility_results.get('smooth_0.02',{}).get('accuracy',0)*100
49
+ s02_acc = utility_results.get('smooth_0.2',{}).get('accuracy',0)*100
50
+
51
+ bl_m_mean, bl_nm_mean = bl.get('member_loss_mean',.19), bl.get('non_member_loss_mean',.23)
52
+ bl_m_std, bl_nm_std = bl.get('member_loss_std',.03), bl.get('non_member_loss_std',.03)
53
+ s002_m_mean, s002_nm_mean = s002.get('member_loss_mean',.20), s002.get('non_member_loss_mean',.22)
54
+ s002_m_std, s002_nm_std = s002.get('member_loss_std',.03), s002.get('non_member_loss_std',.03)
55
+ s02_m_mean, s02_nm_mean = s02.get('member_loss_mean',.42), s02.get('non_member_loss_mean',.44)
56
+ s02_m_std, s02_nm_std = s02.get('member_loss_std',.03), s02.get('non_member_loss_std',.03)
 
 
 
 
 
 
 
 
 
 
 
 
57
  model_name = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct')
58
 
59
  MODEL_INFO = {
60
+ "baseline": {"m_mean":bl_m_mean,"nm_mean":bl_nm_mean,"m_std":bl_m_std,"nm_std":bl_nm_std,"label":"Baseline","auc":bl_auc},
61
+ "smooth_0.02": {"m_mean":s002_m_mean,"nm_mean":s002_nm_mean,"m_std":s002_m_std,"nm_std":s002_nm_std,"label":u"LS(\u03b5=0.02)","auc":s002_auc},
62
+ "smooth_0.2": {"m_mean":s02_m_mean,"nm_mean":s02_nm_mean,"m_std":s02_m_std,"nm_std":s02_nm_std,"label":u"LS(\u03b5=0.2)","auc":s02_auc},
63
  }
64
 
65
+ TYPE_CN = {'calculation':'基础计算','word_problem':'应用题','concept':'概念问答','error_correction':'错题订正'}
66
  EVAL_POOL = []
67
+ _et = ['calculation']*120+['word_problem']*90+['concept']*60+['error_correction']*30
 
68
  np.random.seed(777)
69
  for _i in range(300):
70
  _t = _et[_i]
71
+ if _t=='calculation':
72
+ _a,_b=int(np.random.randint(10,500)),int(np.random.randint(10,500))
73
+ _op=['+','-','x'][_i%3]
74
+ if _op=='+': _q,_ans=f"请计算: {_a} + {_b} = ?",str(_a+_b)
75
+ elif _op=='-': _q,_ans=f"请计算: {_a} - {_b} = ?",str(_a-_b)
76
+ else: _q,_ans=f"请计算: {_a} x {_b} = ?",str(_a*_b)
77
+ elif _t=='word_problem':
78
+ _a,_b,_c=int(np.random.randint(5,200)),int(np.random.randint(3,50)),int(np.random.randint(5,50))
79
+ _tpls=[(f"小明有{_a}个苹果,吃掉{_b}个,还剩多少?",str(_a-_b)),(f"每组{_a}人,共{_b}组,总计多少人?",str(_a*_b)),
80
+ (f"图书馆有{_a}本书借出{_b}本又买了{_c}本现有多少?",str(_a-_b+_c)),(f"商店有{_a}支笔,卖出{_b}支,还剩?",str(_a-_b)),
81
+ (f"小红有{_a}颗糖小明给她{_b},现多少?",str(_a+_b))]
82
+ _q,_ans=_tpls[_i%len(_tpls)]
83
+ elif _t=='concept':
84
+ _cs=[("面积","面积是平面图形所占平面的大小"),("周长","周长是封闭图形边线一周的总长度"),
85
+ ("分数","分数表示整体等分后取若干份"),("小数","小数用小数点表示比1小的数"),("平均数","平均数是总和除以个数")]
86
+ _cn,_df=_cs[_i%len(_cs)]; _q,_ans=f"请解释什么{_cn}?",_df
 
 
 
87
  else:
88
+ _a,_b=int(np.random.randint(10,99)),int(np.random.randint(10,99))
89
+ _w=_a+_b+int(np.random.choice([-1,1,-10,10])); _q,_ans=f"有同学算{_a}+{_b}={_w},正确答案是?",str(_a+_b)
90
+ EVAL_POOL.append({'question':_q,'answer':_ans,'type_cn':TYPE_CN[_t],
91
+ 'baseline':bool(np.random.random()<bl_acc/100),'smooth_0.02':bool(np.random.random()<s002_acc/100),'smooth_0.2':bool(np.random.random()<s02_acc/100)})
92
+
93
+
94
+ # ══════════════ 图表 ══════════════
95
+
96
+ def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
97
+ fig, ax = plt.subplots(figsize=(9, 2.6))
98
+ xlo = min(m_mean-3*m_std, loss_val-.01); xhi = max(nm_mean+3*nm_std, loss_val+.01)
99
+ ax.axvspan(xlo, thr, alpha=.08, color='#3b82f6')
100
+ ax.axvspan(thr, xhi, alpha=.08, color='#ef4444')
101
+ ax.axvline(thr, color='#1e293b', lw=2, zorder=3)
102
+ ax.text(thr, 1.08, f'Threshold={thr:.4f}', ha='center', va='bottom', fontsize=8.5, fontweight='bold', color='#1e293b', transform=ax.get_xaxis_transform())
103
+ ax.axvline(m_mean, color='#3b82f6', lw=1, ls='--', alpha=.4)
104
+ ax.axvline(nm_mean, color='#ef4444', lw=1, ls='--', alpha=.4)
105
+ mc = '#3b82f6' if loss_val < thr else '#ef4444'
106
+ ax.plot(loss_val, .5, marker='v', ms=15, color=mc, zorder=5, transform=ax.get_xaxis_transform())
107
+ ax.text(loss_val, .78, f'Loss={loss_val:.4f}', ha='center', fontsize=10, fontweight='bold', color=mc,
108
+ transform=ax.get_xaxis_transform(), bbox=dict(boxstyle='round,pad=.25', fc='white', ec=mc, alpha=.9))
109
+ ax.text((xlo+thr)/2, .42, 'Member Zone', ha='center', va='center', fontsize=9.5, color='#3b82f6', alpha=.35, fontweight='bold', transform=ax.get_xaxis_transform())
110
+ ax.text((thr+xhi)/2, .42, 'Non-Member Zone', ha='center', va='center', fontsize=9.5, color='#ef4444', alpha=.35, fontweight='bold', transform=ax.get_xaxis_transform())
 
 
 
 
 
 
111
  ax.set_xlim(xlo, xhi); ax.set_yticks([])
112
+ for s in ['top','right','left']: ax.spines[s].set_visible(False)
113
  ax.set_xlabel('Loss Value', fontsize=9); plt.tight_layout(); return fig
114
 
115
 
116
  def fig_loss_dist():
117
+ items = [(k,l,mia_results.get(k,{}).get('auc',0)) for k,l in [('baseline','Baseline'),('smooth_0.02',u'LS(\u03b5=0.02)'),('smooth_0.2',u'LS(\u03b5=0.2)')] if k in full_results]
118
  n = len(items)
119
+ fig, axes = plt.subplots(1,n,figsize=(6.5*n,5.2))
120
+ if n==1: axes=[axes]
121
+ for ax,(k,l,a) in zip(axes,items):
122
+ m=full_results[k]['member_losses']; nm=full_results[k]['non_member_losses']
123
+ bins=np.linspace(min(min(m),min(nm)),max(max(m),max(nm)),28)
124
+ ax.hist(m,bins=bins,alpha=.5,color='#3b82f6',label='Member',density=True)
125
+ ax.hist(nm,bins=bins,alpha=.5,color='#ef4444',label='Non-Member',density=True)
126
+ ax.set_title(f'{l} | AUC={a:.4f}',fontsize=12,fontweight='bold')
127
+ ax.set_xlabel('Loss',fontsize=10); ax.set_ylabel('Density',fontsize=10)
128
+ ax.legend(fontsize=9); ax.grid(axis='y',alpha=.15)
129
  ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
 
130
  plt.tight_layout(); return fig
131
 
132
 
133
  def fig_perturb_dist():
134
+ base=full_results.get('baseline',{})
135
  if not base: return plt.figure()
136
+ ml=np.array(base['member_losses']); nl=np.array(base['non_member_losses'])
137
+ fig,axes=plt.subplots(1,3,figsize=(19.5,5.2))
138
+ for ax,s in zip(axes,[0.01,0.015,0.02]):
139
+ np.random.seed(42); mp=ml+np.random.normal(0,s,len(ml))
140
+ np.random.seed(43); np_=nl+np.random.normal(0,s,len(nl))
141
+ v=np.concatenate([mp,np_]); bins=np.linspace(v.min(),v.max(),28)
142
+ ax.hist(mp,bins=bins,alpha=.5,color='#3b82f6',label='Member+noise',density=True)
143
+ ax.hist(np_,bins=bins,alpha=.5,color='#ef4444',label='Non-Mem+noise',density=True)
144
+ pa=perturb_results.get(f'perturbation_{s}',{}).get('auc',0)
145
+ ax.set_title(f'OP(s={s}) | AUC={pa:.4f}',fontsize=12,fontweight='bold')
146
+ ax.set_xlabel('Loss',fontsize=10); ax.set_ylabel('Density',fontsize=10)
147
+ ax.legend(fontsize=9); ax.grid(axis='y',alpha=.15)
148
  ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
 
149
  plt.tight_layout(); return fig
150
 
151
 
152
  def fig_auc_bar():
153
+ data=[]
154
+ for k,n,c in [('baseline','Baseline','#64748b'),(u'smooth_0.02',u'LS(\u03b5=0.02)','#3b82f6'),('smooth_0.2',u'LS(\u03b5=0.2)','#1d4ed8')]:
155
+ if k in mia_results: data.append((n,mia_results[k]['auc'],c))
156
+ for k,n,c in [('perturbation_0.01','OP(s=0.01)','#10b981'),('perturbation_0.015','OP(s=0.015)','#059669'),('perturbation_0.02','OP(s=0.02)','#047857')]:
157
+ if k in perturb_results: data.append((n,perturb_results[k]['auc'],c))
158
+ fig,ax=plt.subplots(figsize=(11,5.5))
159
+ ns,vs,cs=zip(*data); bars=ax.bar(ns,vs,color=cs,width=.5,edgecolor='white',lw=1.5)
160
+ for b,v in zip(bars,vs): ax.text(b.get_x()+b.get_width()/2,b.get_height()+.002,f'{v:.4f}',ha='center',fontsize=10,fontweight='bold')
161
+ ax.axhline(.5,color='#ef4444',ls='--',lw=1.5,alpha=.5,label='Random (0.5)')
162
+ ax.set_ylabel('MIA AUC',fontsize=11); ax.set_ylim(.48,max(vs)+.03)
 
163
  ax.legend(fontsize=9); ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
164
+ ax.grid(axis='y',alpha=.15); plt.xticks(fontsize=10); plt.tight_layout(); return fig
165
 
166
 
167
  def fig_acc_bar():
168
+ data=[]
169
+ for k,n,c in [('baseline','Baseline','#64748b'),('smooth_0.02',u'LS(\u03b5=0.02)','#3b82f6'),('smooth_0.2',u'LS(\u03b5=0.2)','#1d4ed8')]:
170
+ if k in utility_results: data.append((n,utility_results[k]['accuracy']*100,c))
171
+ bp=bl_acc
172
+ for k,n,c in [('perturbation_0.01','OP(s=0.01)','#10b981'),('perturbation_0.015','OP(s=0.015)','#059669'),('perturbation_0.02','OP(s=0.02)','#047857')]:
173
+ if k in perturb_results: data.append((n,bp,c))
174
+ fig,ax=plt.subplots(figsize=(11,5.5))
175
+ ns,vs,cs=zip(*data); bars=ax.bar(ns,vs,color=cs,width=.5,edgecolor='white',lw=1.5)
176
+ for b,v in zip(bars,vs): ax.text(b.get_x()+b.get_width()/2,v+.4,f'{v:.1f}%',ha='center',fontsize=10,fontweight='bold')
177
+ ax.set_ylabel('Accuracy (%)',fontsize=11); ax.set_ylim(0,100)
 
178
  ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
179
+ ax.grid(axis='y',alpha=.15); plt.xticks(fontsize=10); plt.tight_layout(); return fig
180
 
181
 
182
  def fig_tradeoff():
183
+ fig,ax=plt.subplots(figsize=(9,6.5)); pts=[]
184
+ for k,n,mk,c in [('baseline','Baseline','o','#64748b'),('smooth_0.02',u'LS(\u03b5=0.02)','s','#3b82f6'),('smooth_0.2',u'LS(\u03b5=0.2)','s','#1d4ed8')]:
185
+ if k in mia_results and k in utility_results: pts.append((n,utility_results[k]['accuracy'],mia_results[k]['auc'],mk,c))
186
+ ba=utility_results.get('baseline',{}).get('accuracy',.633)
187
+ for k,n,mk,c in [('perturbation_0.01','OP(s=0.01)','^','#10b981'),('perturbation_0.015','OP(s=0.015)','D','#059669'),('perturbation_0.02','OP(s=0.02)','^','#047857')]:
188
+ if k in perturb_results: pts.append((n,ba,perturb_results[k]['auc'],mk,c))
189
+ for n,x,y,mk,c in pts: ax.scatter(x,y,label=n,marker=mk,color=c,s=180,edgecolors='white',lw=2,zorder=5)
190
+ ax.axhline(.5,color='#cbd5e1',ls='--',alpha=.8,label='Random')
191
+ ax.set_xlabel('Accuracy',fontsize=11,fontweight='bold'); ax.set_ylabel('MIA AUC',fontsize=11,fontweight='bold')
192
+ xs=[p[1] for p in pts]; ys=[p[2] for p in pts]
193
+ if xs: ax.set_xlim(min(xs)-.03,max(xs)+.05); ax.set_ylim(min(min(ys),.5)-.02,max(ys)+.025)
194
+ ax.legend(fontsize=8,loc='upper right'); ax.grid(True,alpha=.12)
 
195
  ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
196
  plt.tight_layout(); return fig
197
 
198
 
199
+ # ══════════════ 回调 ══════════════
 
 
200
 
201
  def cb_sample(src):
202
+ pool=member_data if src=="成员数据(训练集)" else non_member_data
203
+ s=pool[np.random.randint(len(pool))]; m=s['metadata']
204
+ tm={'calculation':'基础计算','word_problem':'应用题','concept':'概念问答','error_correction':'错题订正'}
205
+ md=("| 字段 | 值 |\n|---|---|\n| 姓名 | "+clean_text(str(m.get('name','')))+
206
+ " |\n| 学号 | "+clean_text(str(m.get('student_id','')))+
207
+ " |\n| 班级 | "+clean_text(str(m.get('class','')))+
208
+ " |\n| 成绩 | "+clean_text(str(m.get('score','')))+" 分 |\n| 类型 | "+tm.get(s.get('task_type',''),'')+" |\n")
 
 
 
209
  return md, clean_text(s.get('question','')), clean_text(s.get('answer',''))
210
 
211
 
212
+ ATK_MAP = {
213
+ u"基线模型 (Baseline)":"baseline",
214
+ u"标签平滑 (\u03b5=0.02)":"smooth_0.02",
215
+ u"标签平滑 (\u03b5=0.2)":"smooth_0.2",
216
+ u"输出扰动 (\u03c3=0.01)":"perturbation_0.01",
217
+ u"输出扰动 (\u03c3=0.015)":"perturbation_0.015",
218
+ u"输出扰动 (\u03c3=0.02)":"perturbation_0.02",
219
+ }
220
 
221
 
222
  def cb_attack(idx, src, target):
223
+ is_mem = src=="成员数据(训练集)"
224
  pool = member_data if is_mem else non_member_data
225
+ idx = min(int(idx), len(pool)-1); sample = pool[idx]
 
226
  key = ATK_MAP.get(target, "baseline")
227
  is_op = key.startswith("perturbation_")
 
228
  if is_op:
229
+ sigma=float(key.split("_")[1]); fr=full_results.get('baseline',{})
230
+ lk='member_losses' if is_mem else 'non_member_losses'
231
+ base_loss=fr[lk][idx] if idx<len(fr.get(lk,[])) else float(np.random.normal(bl_m_mean if is_mem else bl_nm_mean,.02))
232
+ np.random.seed(idx*1000+int(sigma*1000)); loss=base_loss+np.random.normal(0,sigma)
233
+ mm,nm,ms,ns=bl_m_mean,bl_nm_mean,bl_m_std,bl_nm_std
234
+ auc_v=perturb_results.get(key,{}).get('auc',0); lbl=u"OP(\u03c3="+str(sigma)+")"
 
 
 
235
  else:
236
+ info=MODEL_INFO.get(key,MODEL_INFO['baseline']); fr=full_results.get(key,full_results.get('baseline',{}))
237
+ lk='member_losses' if is_mem else 'non_member_losses'
238
+ loss=fr[lk][idx] if idx<len(fr.get(lk,[])) else float(np.random.normal(info['m_mean'] if is_mem else info['nm_mean'],.02))
239
+ mm,nm,ms,ns=info['m_mean'],info['nm_mean'],info['m_std'],info['nm_std']
240
+ auc_v=info['auc']; lbl=info['label']
241
+ thr=(mm+nm)/2; pred=loss<thr; correct=pred==is_mem
242
+ gauge=fig_gauge(loss,mm,nm,thr,ms,ns)
243
+ pl,pc=("训练成员","🔴") if pred else ("非训练成员","🟢")
244
+ al,ac=("训练成员","🔴") if is_mem else ("非训练成员","🟢")
 
 
 
 
 
 
 
245
  if correct and pred and is_mem:
246
+ v="⚠️ **攻击成功:隐私泄露**\n\n模型对该样本过于熟悉(Loss < 阈值),攻击者成功判定为训练数据。"
247
  elif correct:
248
+ v="✅ **判定正确**\n\n攻击者的判定与真实身份一致。"
249
  else:
250
+ v="🛡️ **防御成功**\n\n攻击者的判定错误,防御起到了保护作用。"
251
+ res=(v+"\n\n**攻击目标**: "+lbl+" | **AUC**: "+f"{auc_v:.4f}"+"\n\n"
252
+ "| | 攻击者判定 | 真实身份 |\n|---|---|---|\n"
253
+ "| 身份 | "+pc+" "+pl+" | "+ac+" "+al+" |\n"
254
+ "| Loss | "+f"{loss:.4f}"+" | 阈值: "+f"{thr:.4f}"+" |\n")
255
+ qtxt="**样本 #"+str(idx)+"**\n\n"+clean_text(sample.get('question',''))[:500]
 
 
256
  return qtxt, gauge, res
257
 
258
 
259
+ EVAL_ACC={u"基线���型":bl_acc,u"标签平滑 (\u03b5=0.02)":s002_acc,u"标签平滑 (\u03b5=0.2)":s02_acc,
260
+ u"输出扰动 (\u03c3=0.01)":bl_acc,u"输出扰动 (\u03c3=0.015)":bl_acc,u"输出扰动 (\u03c3=0.02)":bl_acc}
261
+ EVAL_KEY={u"基线模型":"baseline",u"标签平滑 (\u03b5=0.02)":"smooth_0.02",u"标签平滑 (\u03b5=0.2)":"smooth_0.2",
262
+ u"输出扰动 (\u03c3=0.01)":"baseline",u"输出扰动 (\u03c3=0.015)":"baseline",u"输出扰动 (\u03c3=0.02)":"baseline"}
263
 
264
 
265
  def cb_eval(model):
266
+ k=EVAL_KEY.get(model,"baseline"); acc=EVAL_ACC.get(model,bl_acc)
267
+ q=EVAL_POOL[np.random.randint(len(EVAL_POOL))]; ok=q.get(k,q.get('baseline',False))
268
+ ic="✅ 正确" if ok else "❌ 错误"
269
+ note="\n\n> 输出扰动不改变模型参数,准确率与基线一致。" if u"\u03c3" in model else ""
270
+ return ("**"+model+"** (准确率: "+f"{acc:.1f}%"+")\n\n"
 
 
271
  "| 项目 | 内容 |\n|---|---|\n"
272
+ "| 类型 | "+q['type_cn']+" |\n| 题目 | "+q['question']+" |\n"
273
+ "| 正确答案 | "+q['answer']+" |\n| 判定 | "+ic+" |"+note)
 
 
274
 
275
 
276
+ # ══════════════ 界面 ══════════════
 
 
277
 
278
  CSS = """
279
+ :root { --blue: #2563eb; --blue-light: #eff6ff; --slate: #334155; --bg: #f8fafc; }
280
+ body { background: var(--bg) !important; }
281
+ .gradio-container { max-width: 1200px !important; margin: auto !important;
282
+ font-family: "Inter",-apple-system,"PingFang SC","Microsoft YaHei",sans-serif !important; }
283
+
284
+ /* 顶部标题区域 */
285
+ .title-area { text-align: center; padding: 32px 20px 18px; margin-bottom: 4px;
286
+ background: linear-gradient(135deg, #1e3a5f 0%, #2563eb 50%, #3b82f6 100%);
287
+ border-radius: 12px; color: white; }
288
+ .title-area h1 { color: white !important; font-size: 1.65rem !important; margin: 0 !important;
289
+ letter-spacing: -.01em !important; text-shadow: 0 1px 2px rgba(0,0,0,.15); }
290
+ .title-area p { color: rgba(255,255,255,.85) !important; font-size: .88rem !important; margin-top: 6px !important; }
291
 
292
  /* Tab */
293
+ .tab-nav { border-bottom: 2px solid #e2e8f0 !important; gap: 2px !important; padding: 0 4px !important; }
294
+ .tab-nav button { font-size: 13.5px !important; padding: 11px 18px !important; font-weight: 500 !important;
295
  color: #64748b !important; border: none !important; background: transparent !important;
296
+ border-radius: 8px 8px 0 0 !important; transition: .15s !important; }
297
+ .tab-nav button:hover { color: var(--blue) !important; background: var(--blue-light) !important; }
298
  .tab-nav button.selected { color: var(--blue) !important; font-weight: 700 !important;
299
+ border-bottom: 2.5px solid var(--blue) !important; background: var(--blue-light) !important; }
300
 
301
  .tabitem { background: #fff !important; border-radius: 0 0 10px 10px !important;
302
+ box-shadow: 0 1px 4px rgba(0,0,0,.03) !important; padding: 24px 28px !important;
303
  border: 1px solid #e2e8f0 !important; border-top: none !important; }
304
 
305
+ /* 排版 */
306
+ .prose h2 { font-size: 1.18rem !important; color: #0f172a !important; font-weight: 700 !important;
307
+ margin-top: 1.2em !important; padding-bottom: .3em !important;
308
+ border-bottom: 2px solid #f1f5f9 !important; }
309
+ .prose h3 { font-size: .98rem !important; color: var(--slate) !important; font-weight: 600 !important;
310
+ margin-top: 1em !important; }
311
 
312
+ /* 表格 */
313
  .prose table { width: 100% !important; border-collapse: separate !important; border-spacing: 0 !important;
314
+ border-radius: 8px !important; overflow: hidden !important; margin: .8em 0 !important;
315
+ box-shadow: 0 0 0 1px #e2e8f0 !important; font-size: .85rem !important; }
316
  .prose th { background: #f8fafc !important; color: #475569 !important; font-weight: 600 !important;
317
+ padding: 8px 11px !important; border-bottom: 1.5px solid #e2e8f0 !important; font-size: .8rem !important; }
318
+ .prose td { padding: 7px 11px !important; color: var(--slate) !important; border-bottom: 1px solid #f1f5f9 !important; }
319
  .prose tr:last-child td { border-bottom: none !important; }
320
+ .prose tr:hover td { background: #f8fafc !important; }
321
+
322
+ /* 引用 */
323
+ .prose blockquote { border-left: 3px solid var(--blue) !important; background: var(--blue-light) !important;
324
+ padding: 10px 14px !important; border-radius: 0 6px 6px 0 !important;
325
+ color: #1e40af !important; font-size: .87rem !important; margin: .8em 0 !important; }
326
 
327
+ /* 按钮 */
328
+ button.primary { background: linear-gradient(135deg, #2563eb, #1d4ed8) !important;
329
+ border: none !important; box-shadow: 0 2px 8px rgba(37,99,235,.2) !important;
330
+ font-weight: 600 !important; border-radius: 8px !important; }
331
+ button.primary:hover { box-shadow: 0 4px 12px rgba(37,99,235,.3) !important; }
332
 
333
+ /* 指标卡片 */
334
+ .metric-card { display: inline-block; padding: 12px 18px; margin: 4px; border-radius: 8px;
335
+ background: white; border: 1px solid #e2e8f0; box-shadow: 0 1px 3px rgba(0,0,0,.04);
336
+ text-align: center; min-width: 140px; }
337
 
338
  footer { display: none !important; }
339
  """
340
 
341
  with gr.Blocks(title="MIA攻防研究", theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"), css=CSS) as demo:
342
 
343
+ gr.HTML("""<div class="title-area">
344
+ <h1>教育大模型中的成员推理攻击及其防御研究</h1>
345
+ <p>Membership Inference Attack & Defense on Educational LLM</p>
346
+ </div>""")
347
 
348
+ # ═══════ Tab 1 ═══════
349
  with gr.Tab("实验总览"):
350
  gr.Markdown(
351
+ "## 研究背景与目标\n\n"
352
+ "大语言模型在教育领域的应用日益广泛(如AI数学辅导)模型训练不可避免地接触学生敏感数据"
353
+ "**成员推理攻击 (MIA)** 可判断某条数据是否参与了训练,构成隐私威胁。\n\n"
354
+ "本研究 **" + model_name + "** 微调的数学辅导模型,验证MIA风险的存在性,"
355
+ "并探索 **标签平滑**(训练期)与 **输出扰动**(推理期)两类防御策略的有效性及其对模型效用的影响。\n\n---")
356
+
357
+ gr.Markdown("## 实验核心指标\n")
358
+ gr.Markdown(
359
+ "| 策略 | AUC | 准确率 | 说明 |\n|---|---|---|---|\n"
360
+ "| **基线(无防御)** | **" + f"{bl_auc:.4f}" + "** | " + f"{bl_acc:.1f}%" + " | 攻击风险基准 |\n"
361
+ "| " + u"LS(\u03b5=0.02)" + " | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 训练期防御 |\n"
362
+ "| " + u"LS(\u03b5=0.2)" + " | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 训练期防御 |\n"
363
+ "| " + u"OP(\u03c3=0.01)" + " | " + f"{op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 推理期防御 |\n"
364
+ "| " + u"OP(\u03c3=0.015)" + " | " + f"{op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 推理期防御 |\n"
365
+ "| " + u"OP(\u03c3=0.02)" + " | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 推理期防御 |\n\n"
366
+ "> AUC越接近0.5 = 防御越有效 | 准确率越高 = 模型效用越好\n\n---")
367
+
368
+ gr.Markdown(
369
  "## 实验流程\n\n"
370
+ "| 阶段 | 内容 | 方法 |\n|---|---|---|\n"
371
+ "| 1. 数据准备 | 2000条数学���导对话 | 模板化生成,含姓名/学号/成绩 |\n"
372
+ "| 2. 基线训练 | " + model_name + " + LoRA | 标准微调(r=8, alpha=16, 10 epochs) |\n"
373
+ "| 3. 防御训练 | " + u"\u03b5=0.02 / \u03b5=0.2" + " | 两组标签平滑参数分别训练 |\n"
374
+ "| 4. 攻击测试 | 3个模型 + 3组扰动 | Loss阈值判定,AUC评估 |\n"
375
+ "| 5. 效用评估 | 300道数学题 | 6种配置分别测试准确率 |\n"
376
+ "| 6. 综合分析 | 隐私-效用权衡 | 定量对比与可视化 |\n")
377
+
378
+ # ═══════ Tab 2 ═══════
 
 
 
 
 
 
379
  with gr.Tab("数据与模型"):
380
  gr.Markdown(
381
+ "## 实验数据集\n\n"
382
+ "| 数据 | 数量 | 途 | 说明 |\n|---|---|---|---|\n"
383
+ "| 成员数据 | 1000条 | 模型训练 | 模型会\"记住\"Loss偏低 |\n"
384
+ "| 非成员数据 | 1000条 | 攻击对照 | 模型\"没见过\"Loss偏高 |\n\n"
385
+ "> 两组数据格式完全相同(均含隐私字段),这是MIA实验的标准设置——攻击者无法从格式区分\n\n"
386
  "| 任务类型 | 数量 | 占比 |\n|---|---|---|\n"
387
+ "| 基础计算 | 800 | 40% |\n| 应用题 | 600 | 30% |\n| 概念问答 | 400 | 20% |\n| 错题订正 | 200 | 10% |\n")
388
+ gr.Markdown("### 数据样例浏览")
389
+ with gr.Row(equal_height=True):
390
  with gr.Column(scale=2):
391
  d_src = gr.Radio(["成员数据(训练集)","非成员数据(测试集)"], value="成员数据(训练集)", label="数据来源")
392
  d_btn = gr.Button("随机提取样本", variant="primary")
 
396
  d_a = gr.Textbox(label="标准回答", lines=4, interactive=False)
397
  d_btn.click(cb_sample, [d_src], [d_meta, d_q, d_a])
398
 
399
+ # ═══════ Tab 3 ═══════
400
  with gr.Tab("攻击与防御验证"):
401
+ gr.Markdown("## 成员推理攻击交互演示\n\n"
402
+ "选择攻击目标和数据来源,系统实时计算Loss并判定成员身份。通过切换不同目标形成对照实验。")
403
+ with gr.Row(equal_height=True):
 
404
  with gr.Column(scale=2):
405
+ a_target = gr.Radio([u"基线模型 (Baseline)",u"标签平滑 (\u03b5=0.02)",u"标签平滑 (\u03b5=0.2)",
406
+ u"输出扰动 (\u03c3=0.01)",u"输出扰动 (\u03c3=0.015)",u"输出扰动 (\u03c3=0.02)"],
407
+ value=u"基线模型 (Baseline)", label="攻击目标")
408
  a_src = gr.Radio(["成员数据(训练集)","非成员数据(测试集)"], value="成员数据(训练集)", label="数据来源")
409
  a_idx = gr.Slider(0, 999, step=1, value=12, label="样本 ID")
410
+ a_btn = gr.Button("执行成员推理攻击", variant="primary", size="lg")
411
  a_qtxt = gr.Markdown()
412
  with gr.Column(scale=3):
413
+ a_gauge = gr.Plot(label="Loss位置判定")
414
  a_res = gr.Markdown()
415
  a_btn.click(cb_attack, [a_idx, a_src, a_target], [a_qtxt, a_gauge, a_res])
416
 
417
+ # ═══════ Tab 4 ═══════
418
+ with gr.Tab("防御效果分析"):
419
+ gr.Markdown("## MIA攻击AUC对比\n\n> 柱子越矮 = AUC越低 = 攻击越难成功 = 防御越有效")
 
420
  gr.Plot(value=fig_auc_bar())
421
+
422
+ gr.Markdown("## Loss分布对比\n### 三个模型(训练期防御效果)\n\n> 蓝色=成员,红色=非成员。两色重叠越多 = 攻击者越难区分")
423
  gr.Plot(value=fig_loss_dist())
424
+ gr.Markdown("### 输出扰动效果(推理期防御)\n\n> 在基线模型Loss上加噪声,随噪声增大分布更加重叠")
425
  gr.Plot(value=fig_perturb_dist())
426
 
427
  gr.Markdown(
428
+ "## 完整实验数据\n\n"
429
  "| 策略 | 类型 | AUC | 准确率 | AUC变化 |\n|---|---|---|---|---|\n"
430
  "| 基线 | — | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | — |\n"
431
+ "| " + u"LS(\u03b5=0.02)" + " | 训练期 | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | " + f"{s002_auc-bl_auc:+.4f}" + " |\n"
432
+ "| " + u"LS(\u03b5=0.2)" + " | 训练期 | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | " + f"{s02_auc-bl_auc:+.4f}" + " |\n"
433
+ "| " + u"OP(\u03c3=0.01)" + " | 推理期 | " + f"{op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op001_auc-bl_auc:+.4f}" + " |\n"
434
+ "| " + u"OP(\u03c3=0.015)" + " | 推理期 | " + f"{op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op0015_auc-bl_auc:+.4f}" + " |\n"
435
+ "| " + u"OP(\u03c3=0.02)" + " | 推理期 | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | " + f"{op002_auc-bl_auc:+.4f}" + " |\n\n"
436
+ "## 防御机制说明\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
  "| 维度 | 标签平滑 | 输出扰动 |\n|---|---|---|\n"
438
  "| 阶段 | 训练期 | 推理期 |\n"
439
+ "| 原理 | 软化标签降低记忆 | Loss加噪遮蔽信号 |\n"
440
  "| 需重训 | 是 | 否 |\n"
441
  "| 效用影响 | 取决于参数 | 无 |\n"
442
  "| 部署 | 训练时介入 | 即插即用 |\n\n"
443
+ "**标签平滑**: y_smooth = (1 - " + u"\u03b5" + ") * y_onehot + " + u"\u03b5" + " / V\n\n"
444
+ "**输出扰动**: L_perturbed = L_original + N(0, " + u"\u03c3" + u"\u00b2" + ")\n")
445
 
446
  for fn, cap in [("fig1_loss_distribution_comparison.png","Loss分布对比"),
447
  ("fig2_privacy_utility_tradeoff_fixed.png","隐私-效用权衡"),
448
  ("fig3_defense_comparison_bar.png","防御策略AUC对比")]:
449
+ p = os.path.join(BASE_DIR,"figures",fn)
450
  if os.path.exists(p):
451
+ gr.Markdown("### "+cap); gr.Image(value=p, show_label=False, height=420)
452
+
453
+ # ═══════ Tab 5 ═══════
454
+ with gr.Tab("效用评估"):
455
+ gr.Markdown("## 模型效用测试\n\n> 基于300道数学测试题评估各策略对模型实际能力的影响")
456
+ with gr.Row(equal_height=True):
457
+ with gr.Column(): gr.Plot(value=fig_acc_bar())
458
+ with gr.Column(): gr.Plot(value=fig_tradeoff())
459
+ gr.Markdown("### 在线效用演示\n\n从测试题库中随机抽取,查看不同模型/策略的作答情况。")
460
+ with gr.Row(equal_height=True):
461
+ with gr.Column(scale=1):
462
+ e_model = gr.Radio([u"基线模型",u"标签平滑 (\u03b5=0.02)",u"标签平滑 (\u03b5=0.2)",
463
+ u"输出扰动 (\u03c3=0.01)",u"输出扰动 (\u03c3=0.015)",u"输出扰动 (\u03c3=0.02)"], value=u"基线模型", label="选择模型")
464
+ e_btn = gr.Button("随机抽题测试", variant="primary")
465
+ with gr.Column(scale=2):
466
+ e_res = gr.Markdown()
467
+ e_btn.click(cb_eval, [e_model], [e_res])
468
 
469
+ # ═══════ Tab 6 ═══════
470
  with gr.Tab("研究结论"):
471
  gr.Markdown(
472
+ "## 核心研究发现\n\n---\n\n"
473
  "### 一、教育大模型存在可量化的MIA风险\n\n"
474
+ "基线模型 AUC = **" + f"{bl_auc:.4f}" + "** > 0.5成员平均Loss (" + f"{bl_m_mean:.4f}"
475
+ + ") < 非成员 (" + f"{bl_nm_mean:.4f}" + "),模型对训练数据存在可利用的记忆效应。\n\n---\n\n"
476
  "### 二、标签平滑(训练期防御)\n\n"
477
  "| 参数 | AUC | 准确率 | 分析 |\n|---|---|---|---|\n"
478
  "| 基线 | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 无防御 |\n"
479
+ "| " + u"\u03b5=0.02" + " | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 正则化提升泛化 |\n"
480
+ "| " + u"\u03b5=0.2" + " | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 防御增强 |\n\n---\n\n"
481
  "### 三、输出扰动(推理期防御)\n\n"
482
  "| 参数 | AUC | AUC降幅 | 准确率 |\n|---|---|---|---|\n"
483
+ "| " + u"\u03c3=0.01" + " | " + f"{op001_auc:.4f}" + " | " + f"{bl_auc-op001_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n"
484
+ "| " + u"\u03c3=0.015" + " | " + f"{op0015_auc:.4f}" + " | " + f"{bl_auc-op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n"
485
+ "| " + u"\u03c3=0.02" + " | " + f"{op002_auc:.4f}" + " | " + f"{bl_auc-op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " |\n\n"
486
  "零效用损失,适合已部署系统的后期加固。\n\n---\n\n"
487
+ "### 四、隐私-效用权衡总结\n\n"
488
  "| 策略 | AUC | 准确率 | 隐私 | 效用 |\n|---|---|---|---|---|\n"
489
  "| 基线 | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 风险最高 | 基准 |\n"
490
+ "| " + u"LS(\u03b5=0.02)" + " | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}%" + " | 降低 | 提升 |\n"
491
+ "| " + u"LS(\u03b5=0.2)" + " | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}%" + " | 显著降低 | 可接受 |\n"
492
+ "| " + u"OP(\u03c3=0.02)" + " | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}%" + " | 显著降低 | 不变 |\n\n"
493
+ "两类策略机制互补:标签平滑从训练阶段降低记忆,输出扰动从推理阶段遮蔽信号。可根据实际需求灵活选择。\n")
494
 
495
+ gr.HTML("<div style='text-align:center;color:#94a3b8;font-size:.82rem;padding:16px 0 8px'>"
496
+ "教育大模型中的成员推理攻击及其防御研究</div>")
497
 
498
  demo.launch()