xiaohy commited on
Commit
154a7f8
·
verified ·
1 Parent(s): a548b21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +897 -412
app.py CHANGED
@@ -1,6 +1,6 @@
1
  # ================================================================
2
- # 教育大模型MIA攻防研究 - Gradio演示系统
3
- # 支持: 11组实验 × 8维度指标
4
  # ================================================================
5
 
6
  import os
@@ -10,6 +10,8 @@ import numpy as np
10
  import matplotlib
11
  matplotlib.use('Agg')
12
  import matplotlib.pyplot as plt
 
 
13
  import gradio as gr
14
 
15
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -18,7 +20,8 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18
  # 数据加载
19
  # ================================================================
20
  def load_json(path):
21
- with open(os.path.join(BASE_DIR, path), 'r', encoding='utf-8') as f:
 
22
  return json.load(f)
23
 
24
  def clean_text(text):
@@ -29,63 +32,81 @@ def clean_text(text):
29
  text = re.sub(r'[\u200b-\u200f\u2028-\u202f\u2060-\u206f\ufeff]', '', text)
30
  return text.strip()
31
 
32
- # 加载所有数据
33
  member_data = load_json("data/member.json")
34
  non_member_data = load_json("data/non_member.json")
35
  config = load_json("config.json")
36
-
37
- # 加载汇总结果
38
  all_data = load_json("results/all_results.json")
39
  mia_results = all_data["mia_results"]
40
  perturb_results = all_data["perturbation_results"]
41
  utility_results = all_data["utility_results"]
42
  full_losses = all_data["full_losses"]
43
-
44
  model_name = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct')
45
 
46
  # ================================================================
47
- # 提取指标
48
  # ================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # 标签平滑模型
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  LS_KEYS = ["baseline", "smooth_eps_0.02", "smooth_eps_0.05", "smooth_eps_0.1", "smooth_eps_0.2"]
52
- LS_LABELS = ["基线", "LS(\u03b5=0.02)", "LS(\u03b5=0.05)", "LS(\u03b5=0.1)", "LS(\u03b5=0.2)"]
 
53
 
54
- # 输出扰动
55
  OP_SIGMAS = [0.005, 0.01, 0.015, 0.02, 0.025, 0.03]
56
  OP_KEYS = [f"perturbation_{s}" for s in OP_SIGMAS]
57
- OP_LABELS = [f"OP(\u03c3={s})" for s in OP_SIGMAS]
 
58
 
59
  ALL_KEYS = LS_KEYS + OP_KEYS
60
- ALL_LABELS = LS_LABELS + OP_LABELS
 
61
 
62
- def get_metric(key, metric_name, default=0):
63
- if key in mia_results:
64
- return mia_results[key].get(metric_name, default)
65
- if key in perturb_results:
66
- return perturb_results[key].get(metric_name, default)
67
  return default
68
 
69
- def get_utility(key):
70
- if key in utility_results:
71
- return utility_results[key].get("accuracy", 0) * 100
72
- if key.startswith("perturbation_"):
73
- return utility_results.get("baseline", {}).get("accuracy", 0) * 100
74
  return 0
75
 
76
- # 基线数据
77
- bl_auc = get_metric("baseline", "auc")
78
- bl_acc = get_utility("baseline")
79
- bl_m_mean = get_metric("baseline", "member_loss_mean")
80
- bl_nm_mean = get_metric("baseline", "non_member_loss_mean")
81
-
82
- plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
83
- plt.rcParams['axes.unicode_minus'] = False
84
 
85
- TYPE_CN = {
86
- 'calculation': '基础计算', 'word_problem': '应用题',
87
- 'concept': '概念问答', 'error_correction': '错题订正'
88
- }
89
 
90
  # ================================================================
91
  # 效用评估题库
@@ -96,580 +117,1044 @@ _types = ['calculation']*120 + ['word_problem']*90 + ['concept']*60 + ['error_co
96
  for _i in range(300):
97
  _t = _types[_i]
98
  if _t == 'calculation':
99
- _a, _b = int(np.random.randint(10, 500)), int(np.random.randint(10, 500))
100
- _op = ['+', '-', '\u00d7'][_i % 3]
101
- if _op == '+': _q, _ans = f"请计算: {_a} + {_b} = ?", str(_a + _b)
102
- elif _op == '-': _q, _ans = f"请计算: {_a} - {_b} = ?", str(_a - _b)
103
- else: _q, _ans = f"请计算: {_a} \u00d7 {_b} = ?", str(_a * _b)
104
  elif _t == 'word_problem':
105
- _a, _b = int(np.random.randint(5, 200)), int(np.random.randint(3, 50))
106
- _tpls = [
107
- (f"小明有{_a}个苹果,吃掉{_b}个,还剩多少?", str(_a - _b)),
108
- (f"每组{_a}人,共{_b}组,总计多少人?", str(_a * _b)),
109
- (f"商店有{_a}支笔,卖出{_b}支,还剩?", str(_a - _b)),
110
- (f"小红有{_a}颗糖,小明给她{_b}颗,现在多少?", str(_a + _b)),
111
- ]
112
- _q, _ans = _tpls[_i % len(_tpls)]
113
  elif _t == 'concept':
114
- _cs = [("面积", "面积是平面图形所占平面的大小"), ("周长", "周长是封闭图形边线一周的总长度"),
115
- ("分数", "分数表示整体等分后取若干份"), ("小数", "小数用小数点表示比1小的数"),
116
- ("平均数", "平均数是总和除以个数")]
117
- _cn, _df = _cs[_i % len(_cs)]
118
- _q, _ans = f"请解释什么是{_cn}?", _df
119
  else:
120
- _a, _b = int(np.random.randint(10, 99)), int(np.random.randint(10, 99))
121
- _w = _a + _b + int(np.random.choice([-1, 1, -10, 10]))
122
- _q, _ans = f"有同学算{_a}+{_b}={_w},正确答案是?", str(_a + _b)
123
-
124
- # 为每个模型模拟结果
125
- item = {'question': _q, 'answer': _ans, 'type_cn': TYPE_CN[_t]}
126
  for key in LS_KEYS:
127
- acc = get_utility(key) / 100
128
- item[key] = bool(np.random.random() < acc)
129
  EVAL_POOL.append(item)
130
 
131
  # ================================================================
132
- # 图表函数
133
  # ================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
136
- fig, ax = plt.subplots(figsize=(9, 2.6))
137
- xlo = min(m_mean - 3*m_std, loss_val - 0.01)
138
- xhi = max(nm_mean + 3*nm_std, loss_val + 0.01)
139
- ax.axvspan(xlo, thr, alpha=0.08, color='#3b82f6')
140
- ax.axvspan(thr, xhi, alpha=0.08, color='#ef4444')
141
- ax.axvline(thr, color='#1e293b', lw=2, zorder=3)
142
- ax.text(thr, 1.08, f'Threshold={thr:.4f}', ha='center', va='bottom',
143
- fontsize=8.5, fontweight='bold', color='#1e293b',
144
- transform=ax.get_xaxis_transform())
145
- mc = '#3b82f6' if loss_val < thr else '#ef4444'
146
- ax.plot(loss_val, 0.5, marker='v', ms=15, color=mc, zorder=5,
147
- transform=ax.get_xaxis_transform())
148
- ax.text(loss_val, 0.78, f'Loss={loss_val:.4f}', ha='center', fontsize=10,
149
- fontweight='bold', color=mc, transform=ax.get_xaxis_transform(),
150
- bbox=dict(boxstyle='round,pad=.25', fc='white', ec=mc, alpha=0.9))
151
- ax.text((xlo+thr)/2, 0.42, 'Member Zone', ha='center', fontsize=9.5,
152
- color='#3b82f6', alpha=0.35, fontweight='bold', transform=ax.get_xaxis_transform())
153
- ax.text((thr+xhi)/2, 0.42, 'Non-Member Zone', ha='center', fontsize=9.5,
154
- color='#ef4444', alpha=0.35, fontweight='bold', transform=ax.get_xaxis_transform())
155
- ax.set_xlim(xlo, xhi)
156
- ax.set_yticks([])
157
- for s in ['top', 'right', 'left']:
158
- ax.spines[s].set_visible(False)
159
- ax.set_xlabel('Loss Value', fontsize=9)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  plt.tight_layout()
161
  return fig
162
 
 
 
 
163
  def fig_loss_dist():
164
  items = []
165
- for k, l in zip(LS_KEYS, LS_LABELS):
166
  if k in full_losses:
167
- auc = get_metric(k, 'auc', 0)
168
- items.append((k, l, auc))
169
  n = len(items)
170
- if n == 0:
171
- return plt.figure()
172
- fig, axes = plt.subplots(1, n, figsize=(5*n, 4.5))
173
- if n == 1:
174
- axes = [axes]
 
175
  for ax, (k, l, a) in zip(axes, items):
176
  m = full_losses[k]['member_losses']
177
  nm = full_losses[k]['non_member_losses']
178
- bins = np.linspace(min(min(m), min(nm)), max(max(m), max(nm)), 30)
179
- ax.hist(m, bins=bins, alpha=0.5, color='#3b82f6', label='Member', density=True)
180
- ax.hist(nm, bins=bins, alpha=0.5, color='#ef4444', label='Non-Member', density=True)
181
- ax.set_title(f'{l} | AUC={a:.4f}', fontsize=11, fontweight='bold')
182
  ax.set_xlabel('Loss', fontsize=9)
183
  ax.set_ylabel('Density', fontsize=9)
184
- ax.legend(fontsize=8)
185
- ax.grid(axis='y', alpha=0.15)
186
- ax.spines['top'].set_visible(False)
187
- ax.spines['right'].set_visible(False)
188
  plt.tight_layout()
189
  return fig
190
 
 
 
 
191
  def fig_perturb_dist():
192
- if 'baseline' not in full_losses:
193
- return plt.figure()
194
  ml = np.array(full_losses['baseline']['member_losses'])
195
  nl = np.array(full_losses['baseline']['non_member_losses'])
196
- sigmas = OP_SIGMAS
197
- n = len(sigmas)
198
- fig, axes = plt.subplots(1, n, figsize=(4*n, 4.5))
199
- if n == 1:
200
- axes = [axes]
201
- for ax, s in zip(axes, sigmas):
202
  rng_m = np.random.RandomState(42)
203
  rng_nm = np.random.RandomState(137)
204
  mp = ml + rng_m.normal(0, s, len(ml))
205
  np_ = nl + rng_nm.normal(0, s, len(nl))
206
  v = np.concatenate([mp, np_])
207
  bins = np.linspace(v.min(), v.max(), 28)
208
- ax.hist(mp, bins=bins, alpha=0.5, color='#3b82f6', label='Mem+noise', density=True)
209
- ax.hist(np_, bins=bins, alpha=0.5, color='#ef4444', label='Non+noise', density=True)
210
- pa = get_metric(f'perturbation_{s}', 'auc', 0)
211
- ax.set_title(f'OP(\u03c3={s}) | AUC={pa:.4f}', fontsize=10, fontweight='bold')
212
  ax.set_xlabel('Loss', fontsize=9)
213
- ax.legend(fontsize=7)
214
- ax.grid(axis='y', alpha=0.15)
215
- ax.spines['top'].set_visible(False)
216
- ax.spines['right'].set_visible(False)
217
  plt.tight_layout()
218
  return fig
219
 
220
- def fig_auc_bar():
221
- names, vals, colors = [], [], []
222
- color_map = {
223
- 'baseline': '#64748b',
224
- 'smooth_eps_0.02': '#93c5fd', 'smooth_eps_0.05': '#60a5fa',
225
- 'smooth_eps_0.1': '#3b82f6', 'smooth_eps_0.2': '#1d4ed8',
226
- }
227
- op_colors = ['#86efac', '#4ade80', '#22c55e', '#16a34a', '#15803d', '#166534']
228
-
229
- for k, l in zip(LS_KEYS, LS_LABELS):
230
- if k in mia_results:
231
- names.append(l)
232
- vals.append(mia_results[k]['auc'])
233
- colors.append(color_map.get(k, '#64748b'))
234
- for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS)):
235
- if k in perturb_results:
236
- names.append(l)
237
- vals.append(perturb_results[k]['auc'])
238
- colors.append(op_colors[i % len(op_colors)])
239
-
240
- fig, ax = plt.subplots(figsize=(14, 5.5))
241
- bars = ax.bar(range(len(names)), vals, color=colors, width=0.6, edgecolor='white', lw=1.5)
242
- for b, v in zip(bars, vals):
243
- ax.text(b.get_x() + b.get_width()/2, v + 0.003, f'{v:.4f}',
244
- ha='center', fontsize=9, fontweight='bold')
245
- ax.axhline(0.5, color='#ef4444', ls='--', lw=1.5, alpha=0.5, label='Random (0.5)')
246
- ax.set_ylabel('MIA AUC', fontsize=11)
247
- ax.set_ylim(0.48, max(vals) + 0.03)
248
- ax.set_xticks(range(len(names)))
249
- ax.set_xticklabels(names, rotation=30, ha='right', fontsize=9)
250
- ax.legend(fontsize=9)
251
- ax.spines['top'].set_visible(False)
252
- ax.spines['right'].set_visible(False)
253
- ax.grid(axis='y', alpha=0.15)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  plt.tight_layout()
255
  return fig
256
 
 
 
 
257
  def fig_acc_bar():
258
- names, vals, colors = [], [], []
259
- color_map = {
260
- 'baseline': '#64748b',
261
- 'smooth_eps_0.02': '#93c5fd', 'smooth_eps_0.05': '#60a5fa',
262
- 'smooth_eps_0.1': '#3b82f6', 'smooth_eps_0.2': '#1d4ed8',
263
- }
264
- op_colors = ['#86efac', '#4ade80', '#22c55e', '#16a34a', '#15803d', '#166534']
265
-
266
- for k, l in zip(LS_KEYS, LS_LABELS):
267
  if k in utility_results:
268
- names.append(l)
269
- vals.append(utility_results[k]['accuracy'] * 100)
270
- colors.append(color_map.get(k, '#64748b'))
271
-
272
- bl_a = utility_results.get('baseline', {}).get('accuracy', 0) * 100
273
- for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS)):
274
  if k in perturb_results:
275
- names.append(l)
276
- vals.append(bl_a)
277
- colors.append(op_colors[i % len(op_colors)])
278
-
279
- fig, ax = plt.subplots(figsize=(14, 5.5))
280
- bars = ax.bar(range(len(names)), vals, color=colors, width=0.6, edgecolor='white', lw=1.5)
281
  for b, v in zip(bars, vals):
282
- ax.text(b.get_x() + b.get_width()/2, v + 0.5, f'{v:.1f}%',
283
- ha='center', fontsize=9, fontweight='bold')
284
- ax.set_ylabel('Accuracy (%)', fontsize=11)
 
285
  ax.set_ylim(0, 100)
286
  ax.set_xticks(range(len(names)))
287
- ax.set_xticklabels(names, rotation=30, ha='right', fontsize=9)
288
- ax.spines['top'].set_visible(False)
289
- ax.spines['right'].set_visible(False)
290
- ax.grid(axis='y', alpha=0.15)
291
  plt.tight_layout()
292
  return fig
293
 
 
 
 
294
  def fig_tradeoff():
295
- fig, ax = plt.subplots(figsize=(10, 7))
296
-
297
- markers = {'baseline': 'o', 'smooth_eps_0.02': 's', 'smooth_eps_0.05': 's',
298
- 'smooth_eps_0.1': 's', 'smooth_eps_0.2': 's'}
299
- colors_ls = {'baseline': '#64748b', 'smooth_eps_0.02': '#93c5fd',
300
- 'smooth_eps_0.05': '#60a5fa', 'smooth_eps_0.1': '#3b82f6',
301
- 'smooth_eps_0.2': '#1d4ed8'}
302
- op_markers = ['^', 'D', 'v', 'P', 'X', 'h']
303
- op_colors_list = ['#86efac', '#4ade80', '#22c55e', '#16a34a', '#15803d', '#166534']
304
-
305
- for k, l in zip(LS_KEYS, LS_LABELS):
306
  if k in mia_results and k in utility_results:
307
- ax.scatter(utility_results[k]['accuracy'], mia_results[k]['auc'],
308
- label=l, marker=markers.get(k, 'o'), color=colors_ls.get(k, '#64748b'),
309
- s=180, edgecolors='white', lw=2, zorder=5)
310
-
311
- bl_a = utility_results.get('baseline', {}).get('accuracy', 0.66)
312
- for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS)):
313
  if k in perturb_results:
314
- ax.scatter(bl_a, perturb_results[k]['auc'], label=l,
315
- marker=op_markers[i % len(op_markers)],
316
- color=op_colors_list[i % len(op_colors_list)],
317
- s=180, edgecolors='white', lw=2, zorder=5)
318
-
319
- ax.axhline(0.5, color='#cbd5e1', ls='--', alpha=0.8, label='Random (AUC=0.5)')
320
- ax.set_xlabel('Model Utility (Accuracy)', fontsize=12, fontweight='bold')
 
 
 
 
 
 
321
  ax.set_ylabel('Privacy Risk (MIA AUC)', fontsize=12, fontweight='bold')
322
- ax.set_title('Privacy-Utility Trade-off', fontsize=14, pad=15)
323
- ax.legend(fontsize=7, loc='upper right', ncol=2)
324
- ax.grid(True, alpha=0.12)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  ax.spines['top'].set_visible(False)
 
326
  ax.spines['right'].set_visible(False)
 
 
327
  plt.tight_layout()
328
  return fig
329
 
330
  # ================================================================
331
  # 回调函数
332
  # ================================================================
333
-
334
  def cb_sample(src):
335
- pool = member_data if src == "成员数据(训练集)" else non_member_data
336
  s = pool[np.random.randint(len(pool))]
337
  m = s['metadata']
338
- md = ("| 字段 | |\n|---|---|\n"
339
- f"| 姓名 | {clean_text(str(m.get('name','')))} |\n"
340
- f"| 学号 | {clean_text(str(m.get('student_id','')))} |\n"
341
- f"| 班级 | {clean_text(str(m.get('class','')))} |\n"
342
- f"| 成绩 | {clean_text(str(m.get('score','')))} |\n"
343
- f"| 类型 | {TYPE_CN.get(s.get('task_type',''), '')} |\n")
344
  return md, clean_text(s.get('question', '')), clean_text(s.get('answer', ''))
345
 
346
- # 攻击目标映射
347
  ATK_CHOICES = (
348
- ["基线模型 (Baseline)"] +
349
- [f"标签平滑 (\u03b5={e})" for e in [0.02, 0.05, 0.1, 0.2]] +
350
- [f"输出扰动 (\u03c3={s})" for s in OP_SIGMAS]
351
  )
352
-
353
  ATK_MAP = {}
354
- ATK_MAP["基线模型 (Baseline)"] = "baseline"
355
  for e in [0.02, 0.05, 0.1, 0.2]:
356
- ATK_MAP[f"标签平滑 (\u03b5={e})"] = f"smooth_eps_{e}"
357
  for s in OP_SIGMAS:
358
- ATK_MAP[f"输出扰动 (\u03c3={s})"] = f"perturbation_{s}"
359
 
360
  def cb_attack(idx, src, target):
361
- is_mem = src == "成员数据(训练集)"
362
  pool = member_data if is_mem else non_member_data
363
- idx = min(int(idx), len(pool) - 1)
364
  sample = pool[idx]
365
  key = ATK_MAP.get(target, "baseline")
366
-
367
  is_op = key.startswith("perturbation_")
368
-
369
  if is_op:
370
  sigma = float(key.split("_")[1])
371
  fr = full_losses.get('baseline', {})
372
  lk = 'member_losses' if is_mem else 'non_member_losses'
373
- losses_list = fr.get(lk, [])
374
- base_loss = losses_list[idx] if idx < len(losses_list) else float(np.random.normal(bl_m_mean if is_mem else bl_nm_mean, 0.02))
375
- np.random.seed(idx * 1000 + int(sigma * 10000))
376
  loss = base_loss + np.random.normal(0, sigma)
377
- mm = get_metric("baseline", "member_loss_mean", 0.19)
378
- nm_m = get_metric("baseline", "non_member_loss_mean", 0.20)
379
- ms = get_metric("baseline", "member_loss_std", 0.03)
380
- ns = get_metric("baseline", "non_member_loss_std", 0.03)
381
- auc_v = get_metric(key, "auc", 0)
382
- lbl = f"OP(\u03c3={sigma})"
383
  else:
384
  info = mia_results.get(key, mia_results.get('baseline', {}))
385
  fr = full_losses.get(key, full_losses.get('baseline', {}))
386
  lk = 'member_losses' if is_mem else 'non_member_losses'
387
- losses_list = fr.get(lk, [])
388
- loss = losses_list[idx] if idx < len(losses_list) else float(np.random.normal(info.get('member_loss_mean', 0.19), 0.02))
389
  mm = info.get('member_loss_mean', 0.19)
390
  nm_m = info.get('non_member_loss_mean', 0.20)
391
  ms = info.get('member_loss_std', 0.03)
392
  ns = info.get('non_member_loss_std', 0.03)
393
  auc_v = info.get('auc', 0)
394
- if key == "baseline":
395
- lbl = "Baseline"
396
  else:
397
- eps = key.replace("smooth_eps_", "")
398
  lbl = f"LS(\u03b5={eps})"
399
-
400
  thr = (mm + nm_m) / 2
401
  pred = loss < thr
402
  correct = pred == is_mem
403
-
404
  gauge = fig_gauge(loss, mm, nm_m, thr, ms, ns)
405
-
406
- pl, pc = ("训练成员", "\U0001f534") if pred else ("非训练成员", "\U0001f7e2")
407
- al, ac = ("训练成员", "\U0001f534") if is_mem else ("非训练成员", "\U0001f7e2")
408
-
409
  if correct and pred and is_mem:
410
- v = "⚠️ **攻击成功:隐私泄露**\n\n> 模型对该样本过于熟悉(Loss < 阈值),攻击者成功判定为训练数据。"
411
  elif correct:
412
- v = "**判定正确**\n\n> 攻击者的判定与真实身份一致。"
413
  else:
414
- v = "**防御成功**\n\n> 攻击者的判定错误,防御起到了保护作用。"
415
-
416
- res = (v + f"\n\n**攻击目标**: {lbl} | **AUC**: {auc_v:.4f}\n\n"
417
- "| | 攻击者判定 | 真实身份 |\n|---|---|---|\n"
418
- f"| 身份 | {pc} {pl} | {ac} {al} |\n"
419
- f"| Loss | {loss:.4f} | 阈值: {thr:.4f} |\n")
420
-
421
- qtxt = f"**样本 #{idx}**\n\n" + clean_text(sample.get('question', ''))[:500]
422
  return qtxt, gauge, res
423
 
424
- # 效用评估
425
- EVAL_MODEL_CHOICES = (
426
- ["基线模型"] +
427
- [f"标签平滑 (\u03b5={e})" for e in [0.02, 0.05, 0.1, 0.2]] +
428
- [f"输出扰动 (\u03c3={s})" for s in OP_SIGMAS]
429
  )
430
-
431
- EVAL_KEY_MAP = {"基线模型": "baseline"}
432
  for e in [0.02, 0.05, 0.1, 0.2]:
433
- EVAL_KEY_MAP[f"标签平滑 (\u03b5={e})"] = f"smooth_eps_{e}"
434
  for s in OP_SIGMAS:
435
- EVAL_KEY_MAP[f"输出扰动 (\u03c3={s})"] = "baseline"
436
 
437
  def cb_eval(model_choice):
438
  k = EVAL_KEY_MAP.get(model_choice, "baseline")
439
- acc = get_utility(k) if not model_choice.startswith("输出扰动") else bl_acc
440
  q = EVAL_POOL[np.random.randint(len(EVAL_POOL))]
441
  ok = q.get(k, q.get('baseline', False))
442
- ic = " 正确" if ok else " 错误"
443
- note = "\n\n> 输出扰动不改变模型参数,准确率与基线一致。" if "输出扰动" in model_choice else ""
444
- return (f"**模型**: {model_choice} (准确率: {acc:.1f}%)\n\n"
445
- "| 项目 | 内容 |\n|---|---|\n"
446
- f"| 类型 | {q['type_cn']} |\n"
447
- f"| 题目 | {q['question']} |\n"
448
- f"| 正确答案 | {q['answer']} |\n"
449
- f"| 判定 | {ic} |{note}")
450
 
451
  # ================================================================
452
- # 构建完整结果表
453
  # ================================================================
454
-
455
  def build_full_table():
456
  rows = []
457
- # 标签平滑
458
- for k, l in zip(LS_KEYS, LS_LABELS):
459
  if k in mia_results:
460
- m = mia_results[k]
461
- u = get_utility(k)
462
- t = "" if k == "baseline" else "训练期"
463
- auc_delta = "" if k == "baseline" else f"{m['auc'] - bl_auc:+.4f}"
464
  rows.append(f"| {l} | {t} | {m['auc']:.4f} | {m['attack_accuracy']:.4f} | "
465
  f"{m['precision']:.4f} | {m['recall']:.4f} | {m['f1']:.4f} | "
466
  f"{m['tpr_at_5fpr']:.4f} | {m['tpr_at_1fpr']:.4f} | "
467
- f"{m['loss_gap']:.4f} | {u:.1f}% | {auc_delta} |")
468
- # 输出扰动
469
- for k, l in zip(OP_KEYS, OP_LABELS):
470
  if k in perturb_results:
471
- m = perturb_results[k]
472
- auc_delta = f"{m['auc'] - bl_auc:+.4f}"
473
- rows.append(f"| {l} | 推理期 | {m['auc']:.4f} | {m['attack_accuracy']:.4f} | "
474
  f"{m['precision']:.4f} | {m['recall']:.4f} | {m['f1']:.4f} | "
475
  f"{m['tpr_at_5fpr']:.4f} | {m['tpr_at_1fpr']:.4f} | "
476
- f"{m['loss_gap']:.4f} | {bl_acc:.1f}% | {auc_delta} |")
477
-
478
- header = "| 策略 | 类型 | AUC | Acc | Prec | Rec | F1 | TPR@5% | TPR@1% | LossGap | 效用 | AUC\u0394 |\n|---|---|---|---|---|---|---|---|---|---|---|---|"
479
  return header + "\n" + "\n".join(rows)
480
 
481
  # ================================================================
482
- # CSS
483
  # ================================================================
484
-
485
  CSS = """
486
- body { background-color: #f8fafc !important; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, sans-serif !important; }
487
- .gradio-container { max-width: 1200px !important; margin: 40px auto !important; }
488
- .title-area { background: #ffffff; padding: 28px 40px; border-radius: 12px;
489
- box-shadow: 0 4px 6px -1px rgba(0,0,0,0.05); margin-bottom: 24px; border-left: 6px solid #2563eb; }
490
- .title-area h1 { color: #0f172a !important; font-size: 1.7rem !important; font-weight: 800 !important; margin: 0 0 8px 0 !important; }
491
- .title-area p { color: #64748b !important; font-size: 1rem !important; margin: 0 !important; }
492
- .tabitem { background: rgba(255,255,255,0.98) !important; border-radius: 0 0 12px 12px !important;
493
- border: 1px solid #e2e8f0 !important; border-top: none !important;
494
- box-shadow: 0 4px 12px rgba(0,0,0,0.05) !important; padding: 32px 40px !important;
495
- min-height: 760px !important; overflow-y: auto !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  .tab-nav { border-bottom: none !important; gap: 4px !important; }
497
- .tab-nav button { font-size: 15px !important; padding: 12px 24px !important; font-weight: 600 !important;
498
- color: #64748b !important; background: #e2e8f0 !important; border-radius: 10px 10px 0 0 !important; }
499
- .tab-nav button.selected { color: #2563eb !important; background: #ffffff !important; border-top: 3px solid #2563eb !important; }
500
- .prose table { width: 100% !important; border-collapse: separate !important; border-spacing: 0 !important;
501
- border-radius: 8px !important; overflow: hidden !important; border: 1px solid #e2e8f0 !important; font-size: 0.85rem !important; }
502
- .prose th { background: #f8fafc !important; color: #475569 !important; font-weight: 600 !important; padding: 10px 12px !important; }
503
- .prose td { padding: 10px 12px !important; color: #1e293b !important; border-bottom: 1px solid #f1f5f9 !important; }
504
- button.primary { background: #2563eb !important; color: white !important; border: none !important;
505
- border-radius: 6px !important; font-weight: 600 !important; }
506
- button.primary:hover { background: #1d4ed8 !important; }
507
- .prose blockquote { border-left: 4px solid #3b82f6 !important; background: #eff6ff !important;
508
- padding: 16px 20px !important; border-radius: 0 8px 8px 0 !important; color: #1e40af !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
  footer { display: none !important; }
510
  """
511
 
512
  # ================================================================
513
  # 界面
514
  # ================================================================
515
-
516
- with gr.Blocks(title="MIA攻防研究", theme=gr.themes.Base(), css=CSS) as demo:
517
 
518
  gr.HTML("""<div class="title-area">
519
- <h1>教育大模型中的成员推理攻击及其防御研究</h1>
520
- <p>Membership Inference Attack & Defense on Educational LLM — 11组实验 × 8维度指标</p>
 
521
  </div>""")
522
 
523
- # ═══ Tab 1: 实验总览 ═══
524
- with gr.Tab("实验总览"):
525
- gr.Markdown(f"""## 研究背景与目标
526
 
527
- 大语言模型在教育领域的应用日益广泛,模型训练不可避免地接触学生敏感数据。**成员推理攻击 (MIA)** 可判断某条数据是否参与了训练,构成隐私威胁。
528
 
529
- 本研究基于 **{model_name}** 微调的数学辅导模型,系统验证MIA风险并评估两类防御策略。
530
 
531
- ### 实验规模
532
- - **5个模型**: 1个基线 + 4组标签平滑 (\u03b5=0.02/0.05/0.1/0.2)
533
- - **6组输��扰动**: \u03c3=0.005/0.01/0.015/0.02/0.025/0.03
534
- - **8维度评估**: AUC / 攻击准确率 / 精确率 / 召回率 / F1 / TPR@5%FPR / TPR@1%FPR / Loss差距
535
- - **效用测试**: 300道数学题
 
 
 
536
  """)
537
- with gr.Accordion("展开查看:完整实验结果表(11组×8维度)", open=True):
538
  gr.Markdown(build_full_table())
539
- gr.Markdown("> AUC越接近0.5 = 防御越有效;效用越高 = 模型能力越好。AUC\u0394为相对基线的变化。")
540
 
541
- # ═══ Tab 2: 数据与模型 ═══
542
- with gr.Tab("数据与模型"):
543
- gr.Markdown("""## 实验数据集
 
544
 
545
- | 数据组 | 数量 | 用途 | 说明 |
546
  |---|---|---|---|
547
- | 成员数据 | 1000 | 模型训练 | 模型会\"记住\",Loss偏低 |
548
- | 非成员数据 | 1000 | 攻击对照 | 模型\"没见过\",Loss偏高 |
549
 
550
- | 任务类别 | 数量 | 占比 |
551
  |---|---|---|
552
- | 基础计算 | 800 | 40% |
553
- | 应用题 | 600 | 30% |
554
- | 概念问答 | 400 | 20% |
555
- | 错题订正 | 200 | 10% |
556
 
557
- > 两组数据格式完全相同(均含隐私字段),攻击者无法从格式区分。
558
  """)
559
- gr.Markdown("### 数据样例浏览")
560
  with gr.Row():
561
  with gr.Column(scale=2):
562
- d_src = gr.Radio(["成员数据(训练集)", "非成员数据(测试集)"],
563
- value="成员数据(训练集)", label="数据来源")
564
- d_btn = gr.Button("随机提取样本", variant="primary")
565
  d_meta = gr.Markdown()
566
  with gr.Column(scale=3):
567
- d_q = gr.Textbox(label="学生提问", lines=4, interactive=False)
568
- d_a = gr.Textbox(label="标准回答", lines=4, interactive=False)
569
  d_btn.click(cb_sample, [d_src], [d_meta, d_q, d_a])
570
 
571
- # ═══ Tab 3: 攻击验证 ═══
572
- with gr.Tab("攻击验证"):
573
- gr.Markdown("## 成员推理攻击交互演示\n\n选择攻击目标与数据源,系统执行Loss计算并判定数据归属。")
574
  with gr.Row():
575
  with gr.Column(scale=2):
576
- a_target = gr.Radio(ATK_CHOICES, value=ATK_CHOICES[0], label="攻击目标")
577
- a_src = gr.Radio(["成员数据(训练集)", "非成员数据(测试集)"],
578
- value="成员数据(训练集)", label="数据来源")
579
- a_idx = gr.Slider(0, 999, step=1, value=12, label="样本ID")
580
- a_btn = gr.Button("执行成员推理攻击", variant="primary", size="lg")
581
  a_qtxt = gr.Markdown()
582
  with gr.Column(scale=3):
583
- a_gauge = gr.Plot(label="Loss位置判定")
584
  a_res = gr.Markdown()
585
  a_btn.click(cb_attack, [a_idx, a_src, a_target], [a_qtxt, a_gauge, a_res])
586
 
587
- # ═══ Tab 4: 防御分析 ═══
588
- with gr.Tab("防御分析"):
589
- with gr.Accordion("AUC对比柱状图(11组)", open=True):
590
- gr.Markdown("> 柱子越矮 = AUC越低 = 防御越有效")
591
- gr.Plot(value=fig_auc_bar())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592
 
593
- with gr.Accordion("Loss分布对比(标签平滑 5个模型)", open=False):
594
- gr.Markdown("> 蓝色=成员,红色=非成员。重叠越多=攻击越难")
 
 
595
  gr.Plot(value=fig_loss_dist())
596
-
597
- with gr.Accordion("输出扰动效果(6组\u03c3", open=False):
598
  gr.Plot(value=fig_perturb_dist())
599
 
600
- with gr.Accordion("完整数据表 + 防御机制说明", open=False):
 
601
  gr.Markdown(build_full_table())
602
- gr.Markdown("""
603
- ### 防御机制对比
604
 
605
- | 维度 | 标签平滑 | 输出扰动 |
606
  |---|---|---|
607
- | **阶段** | 训练期 | 推理期 |
608
- | **原理** | 软化标签降低记忆 | Loss加噪遮蔽信号 |
609
- | **需重训** | | |
610
- | **效用影响** | 正则化可能提升 | 完全无影响 |
611
- | **部署** | 训练时介入 | 即插即用 |
612
 
613
- **标签平滑公式**: `y_smooth = (1 - ε) × y_onehot + ε / V`
614
 
615
- **输出扰动公式**: `L_perturbed = L_original + N(0, σ²)`
616
  """)
617
 
618
- # ═══ Tab 5: 效用评估 ═══
619
- with gr.Tab("效用评估"):
620
- gr.Markdown("## 模型效用测试\n\n> 基于300道数学测试题评估各策略的实际能力影响")
 
 
621
  with gr.Row():
622
  with gr.Column():
623
  gr.Plot(value=fig_acc_bar())
624
  with gr.Column():
625
  gr.Plot(value=fig_tradeoff())
626
 
627
- gr.Markdown("### 在线抽样演示")
 
 
 
 
 
 
 
 
 
 
 
 
 
628
  with gr.Row():
629
  with gr.Column(scale=1):
630
- e_model = gr.Radio(EVAL_MODEL_CHOICES, value="基线模型", label="选择模型")
631
- e_btn = gr.Button("随机抽题测试", variant="primary")
632
  with gr.Column(scale=2):
633
  e_res = gr.Markdown()
634
  e_btn.click(cb_eval, [e_model], [e_res])
635
 
636
- # ═══ Tab 6: 研究结论 ═══
637
- with gr.Tab("研究结论"):
638
- gr.Markdown(f"""## 核心研究发现
 
639
 
640
  ---
641
 
642
- ### 一、教育大模型存在可量化的MIA风险
643
 
644
- 基线模型 AUC = **{bl_auc:.4f}** > 0.5,成员平均Loss ({bl_m_mean:.4f}) < 非成员 ({bl_nm_mean:.4f})。
645
- 攻击者判定正确率 {get_metric('baseline','attack_accuracy',0)*100:.1f}%,远超随机猜测的50%。
646
-
647
- ### 二、标签平滑(训练期防御)
 
 
648
 
649
- | 参数 | AUC | 效用 | 特点 |
650
- |---|---|---|---|
651
- | \u03b5=0.02 | {get_metric('smooth_eps_0.02','auc',0):.4f} | {get_utility('smooth_eps_0.02'):.1f}% | 轻度防御 |
652
- | \u03b5=0.05 | {get_metric('smooth_eps_0.05','auc',0):.4f} | {get_utility('smooth_eps_0.05'):.1f}% | 温和防御 |
653
- | \u03b5=0.1 | {get_metric('smooth_eps_0.1','auc',0):.4f} | {get_utility('smooth_eps_0.1'):.1f}% | 推荐配置 |
654
- | \u03b5=0.2 | {get_metric('smooth_eps_0.2','auc',0):.4f} | {get_utility('smooth_eps_0.2'):.1f}% | 强力防御 |
655
 
656
- 标签平滑通过正则化同时提升了隐私保护和模型效用(效用从{bl_acc:.1f}%升至{get_utility('smooth_eps_0.2'):.1f}%)。
 
 
 
 
 
657
 
658
- ### 三、输出扰动(推理期防御)
659
 
660
- | 参数 | AUC | AUC降幅 | 效用 |
661
- |---|---|---|---|
662
- | \u03c3=0.005 | {get_metric('perturbation_0.005','auc',0):.4f} | {bl_auc-get_metric('perturbation_0.005','auc',0):.4f} | {bl_acc:.1f}% |
663
- | \u03c3=0.01 | {get_metric('perturbation_0.01','auc',0):.4f} | {bl_auc-get_metric('perturbation_0.01','auc',0):.4f} | {bl_acc:.1f}% |
664
- | \u03c3=0.02 | {get_metric('perturbation_0.02','auc',0):.4f} | {bl_auc-get_metric('perturbation_0.02','auc',0):.4f} | {bl_acc:.1f}% |
665
- | \u03c3=0.03 | {get_metric('perturbation_0.03','auc',0):.4f} | {bl_auc-get_metric('perturbation_0.03','auc',0):.4f} | {bl_acc:.1f}% |
666
 
667
- **零效用损失,适合已部署系统的后期加固。**
668
 
669
- ### 四、最佳实践建议
670
 
671
- > 两类策略机制互补:标签平滑从训练阶段降低记忆,输出扰动从推理阶段遮蔽信号。
672
- > 推荐组合: **LS(\u03b5=0.1) + OP(\u03c3=0.02)** 兼顾隐私保护与模型效用。
673
  """)
674
 
675
  demo.launch()
 
1
  # ================================================================
2
+ # 教育大模型MIA攻防研究 - Gradio演示系统 v3.0
3
+ # 全英文图表(解决乱码) + 科技感界面 + 多维度对比分析
4
  # ================================================================
5
 
6
  import os
 
10
  import matplotlib
11
  matplotlib.use('Agg')
12
  import matplotlib.pyplot as plt
13
+ from matplotlib.gridspec import GridSpec
14
+ from sklearn.metrics import roc_curve, roc_auc_score
15
  import gradio as gr
16
 
17
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
20
  # 数据加载
21
  # ================================================================
22
  def load_json(path):
23
+ full = os.path.join(BASE_DIR, path)
24
+ with open(full, 'r', encoding='utf-8') as f:
25
  return json.load(f)
26
 
27
  def clean_text(text):
 
32
  text = re.sub(r'[\u200b-\u200f\u2028-\u202f\u2060-\u206f\ufeff]', '', text)
33
  return text.strip()
34
 
 
35
  member_data = load_json("data/member.json")
36
  non_member_data = load_json("data/non_member.json")
37
  config = load_json("config.json")
 
 
38
  all_data = load_json("results/all_results.json")
39
  mia_results = all_data["mia_results"]
40
  perturb_results = all_data["perturbation_results"]
41
  utility_results = all_data["utility_results"]
42
  full_losses = all_data["full_losses"]
 
43
  model_name = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct')
44
 
45
  # ================================================================
46
+ # 全局图表配置 - 科技感深色主题
47
  # ================================================================
48
+ COLORS = {
49
+ 'bg': '#0f172a',
50
+ 'panel': '#1e293b',
51
+ 'grid': '#334155',
52
+ 'text': '#e2e8f0',
53
+ 'text_dim': '#94a3b8',
54
+ 'accent': '#3b82f6',
55
+ 'accent2': '#8b5cf6',
56
+ 'danger': '#ef4444',
57
+ 'success': '#22c55e',
58
+ 'warning': '#f59e0b',
59
+ 'baseline': '#64748b',
60
+ 'ls_colors': ['#93c5fd', '#60a5fa', '#3b82f6', '#1d4ed8'],
61
+ 'op_colors': ['#86efac', '#4ade80', '#22c55e', '#16a34a', '#15803d', '#166534'],
62
+ }
63
 
64
+ def apply_dark_style(fig, ax_or_axes):
65
+ fig.patch.set_facecolor(COLORS['bg'])
66
+ axes = ax_or_axes if hasattr(ax_or_axes, '__iter__') else [ax_or_axes]
67
+ for ax in axes:
68
+ ax.set_facecolor(COLORS['panel'])
69
+ ax.tick_params(colors=COLORS['text_dim'], labelsize=9)
70
+ ax.xaxis.label.set_color(COLORS['text'])
71
+ ax.yaxis.label.set_color(COLORS['text'])
72
+ ax.title.set_color(COLORS['text'])
73
+ for spine in ax.spines.values():
74
+ spine.set_color(COLORS['grid'])
75
+ ax.grid(True, color=COLORS['grid'], alpha=0.3, linestyle='--')
76
+
77
+ # ================================================================
78
+ # 提取指标的辅助函数
79
+ # ================================================================
80
  LS_KEYS = ["baseline", "smooth_eps_0.02", "smooth_eps_0.05", "smooth_eps_0.1", "smooth_eps_0.2"]
81
+ LS_LABELS_EN = ["Baseline", "LS(e=0.02)", "LS(e=0.05)", "LS(e=0.1)", "LS(e=0.2)"]
82
+ LS_LABELS_CN = ["\u57fa\u7ebf", "LS(\u03b5=0.02)", "LS(\u03b5=0.05)", "LS(\u03b5=0.1)", "LS(\u03b5=0.2)"]
83
 
 
84
  OP_SIGMAS = [0.005, 0.01, 0.015, 0.02, 0.025, 0.03]
85
  OP_KEYS = [f"perturbation_{s}" for s in OP_SIGMAS]
86
+ OP_LABELS_EN = [f"OP(s={s})" for s in OP_SIGMAS]
87
+ OP_LABELS_CN = [f"OP(\u03c3={s})" for s in OP_SIGMAS]
88
 
89
  ALL_KEYS = LS_KEYS + OP_KEYS
90
+ ALL_LABELS_EN = LS_LABELS_EN + OP_LABELS_EN
91
+ ALL_LABELS_CN = LS_LABELS_CN + OP_LABELS_CN
92
 
93
+ def gm(key, metric, default=0):
94
+ if key in mia_results: return mia_results[key].get(metric, default)
95
+ if key in perturb_results: return perturb_results[key].get(metric, default)
 
 
96
  return default
97
 
98
+ def gu(key):
99
+ if key in utility_results: return utility_results[key].get("accuracy", 0) * 100
100
+ if key.startswith("perturbation_"): return utility_results.get("baseline", {}).get("accuracy", 0) * 100
 
 
101
  return 0
102
 
103
+ bl_auc = gm("baseline", "auc")
104
+ bl_acc = gu("baseline")
105
+ bl_m_mean = gm("baseline", "member_loss_mean")
106
+ bl_nm_mean = gm("baseline", "non_member_loss_mean")
 
 
 
 
107
 
108
+ TYPE_CN = {'calculation': '\u57fa\u7840\u8ba1\u7b97', 'word_problem': '\u5e94\u7528\u9898',
109
+ 'concept': '\u6982\u5ff5\u95ee\u7b54', 'error_correction': '\u9519\u9898\u8ba2\u6b63'}
 
 
110
 
111
  # ================================================================
112
  # 效用评估题库
 
117
  for _i in range(300):
118
  _t = _types[_i]
119
  if _t == 'calculation':
120
+ _a, _b = int(np.random.randint(10,500)), int(np.random.randint(10,500))
121
+ _op = ['+','-','x'][_i%3]
122
+ if _op=='+': _q,_ans=f"{_a} + {_b} = ?",str(_a+_b)
123
+ elif _op=='-': _q,_ans=f"{_a} - {_b} = ?",str(_a-_b)
124
+ else: _q,_ans=f"{_a} x {_b} = ?",str(_a*_b)
125
  elif _t == 'word_problem':
126
+ _a,_b = int(np.random.randint(5,200)), int(np.random.randint(3,50))
127
+ _tpls = [(f"{_a} apples, ate {_b}, left?",str(_a-_b)),
128
+ (f"{_a} per group, {_b} groups, total?",str(_a*_b)),
129
+ (f"{_a} pens, sold {_b}, left?",str(_a-_b)),
130
+ (f"Had {_a}, got {_b} more, total?",str(_a+_b))]
131
+ _q,_ans = _tpls[_i%len(_tpls)]
 
 
132
  elif _t == 'concept':
133
+ _cs = [("area","Area = space occupied by a shape"),("perimeter","Perimeter = total boundary length"),
134
+ ("fraction","Fraction = equal parts of a whole"),("decimal","Decimal = number with point"),
135
+ ("average","Average = sum / count")]
136
+ _cn,_df = _cs[_i%len(_cs)]; _q,_ans = f"What is {_cn}?",_df
 
137
  else:
138
+ _a,_b = int(np.random.randint(10,99)), int(np.random.randint(10,99))
139
+ _w = _a+_b+int(np.random.choice([-1,1,-10,10]))
140
+ _q,_ans = f"Student got {_a}+{_b}={_w}, correct?",str(_a+_b)
141
+ item = {'question':_q,'answer':_ans,'type_cn':TYPE_CN[_t]}
 
 
142
  for key in LS_KEYS:
143
+ acc = gu(key)/100; item[key] = bool(np.random.random()<acc)
 
144
  EVAL_POOL.append(item)
145
 
146
  # ================================================================
147
+ # 图表1: AUC对比柱状图(11组)
148
  # ================================================================
149
+ def fig_auc_bar():
150
+ names, vals, clrs = [], [], []
151
+ ls_c = [COLORS['baseline']] + COLORS['ls_colors']
152
+ for i,(k,l) in enumerate(zip(LS_KEYS, LS_LABELS_EN)):
153
+ if k in mia_results:
154
+ names.append(l); vals.append(mia_results[k]['auc']); clrs.append(ls_c[i])
155
+ for i,(k,l) in enumerate(zip(OP_KEYS, OP_LABELS_EN)):
156
+ if k in perturb_results:
157
+ names.append(l); vals.append(perturb_results[k]['auc']); clrs.append(COLORS['op_colors'][i])
158
+
159
+ fig, ax = plt.subplots(figsize=(14, 6))
160
+ apply_dark_style(fig, ax)
161
+ bars = ax.bar(range(len(names)), vals, color=clrs, width=0.6, edgecolor=COLORS['bg'], lw=1.5, zorder=3)
162
+ for b,v in zip(bars, vals):
163
+ ax.text(b.get_x()+b.get_width()/2, v+0.003, f'{v:.4f}',
164
+ ha='center', fontsize=9, fontweight='bold', color=COLORS['text'])
165
+ ax.axhline(0.5, color=COLORS['danger'], ls='--', lw=1.5, alpha=0.7, label='Random Guess (0.5)', zorder=2)
166
+ ax.axhline(bl_auc, color=COLORS['warning'], ls=':', lw=1, alpha=0.5, label=f'Baseline ({bl_auc:.4f})', zorder=2)
167
+ ax.set_ylabel('MIA Attack AUC', fontsize=11, fontweight='bold')
168
+ ax.set_title('Defense Effectiveness: MIA AUC Comparison (11 Configs)', fontsize=13, fontweight='bold', pad=15)
169
+ ax.set_ylim(0.48, max(vals)+0.035)
170
+ ax.set_xticks(range(len(names)))
171
+ ax.set_xticklabels(names, rotation=35, ha='right', fontsize=9)
172
+ ax.legend(facecolor=COLORS['panel'], edgecolor=COLORS['grid'], labelcolor=COLORS['text'], fontsize=9)
173
+ plt.tight_layout()
174
+ return fig
175
 
176
+ # ================================================================
177
+ # 图表2: 8维指标雷达图对比(关键新增!)
178
+ # ================================================================
179
+ def fig_radar_compare():
180
+ metrics = ['AUC', 'Attack Acc', 'Precision', 'Recall', 'F1', 'TPR@5%', 'TPR@1%', 'LossGap']
181
+ metric_keys = ['auc', 'attack_accuracy', 'precision', 'recall', 'f1', 'tpr_at_5fpr', 'tpr_at_1fpr', 'loss_gap']
182
+
183
+ configs = [
184
+ ("Baseline", "baseline", COLORS['danger']),
185
+ ("LS(e=0.1)", "smooth_eps_0.1", COLORS['accent']),
186
+ ("LS(e=0.2)", "smooth_eps_0.2", COLORS['accent2']),
187
+ ("OP(s=0.02)", "perturbation_0.02", COLORS['success']),
188
+ ]
189
+
190
+ N = len(metrics)
191
+ angles = np.linspace(0, 2*np.pi, N, endpoint=False).tolist()
192
+ angles += angles[:1]
193
+
194
+ fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
195
+ fig.patch.set_facecolor(COLORS['bg'])
196
+ ax.set_facecolor(COLORS['panel'])
197
+
198
+ # Normalize values for radar
199
+ all_vals = {}
200
+ for name, key, color in configs:
201
+ vals = [gm(key, mk) for mk in metric_keys]
202
+ all_vals[name] = vals
203
+
204
+ # Get max for each metric for normalization
205
+ maxes = []
206
+ for i, mk in enumerate(metric_keys):
207
+ m = max(gm(k, mk) for _, k, _ in configs)
208
+ maxes.append(m if m > 0 else 1)
209
+
210
+ for name, key, color in configs:
211
+ vals = [gm(key, mk)/maxes[i] for i, mk in enumerate(metric_keys)]
212
+ vals += vals[:1]
213
+ ax.plot(angles, vals, 'o-', linewidth=2, label=name, color=color, markersize=6)
214
+ ax.fill(angles, vals, alpha=0.1, color=color)
215
+
216
+ ax.set_xticks(angles[:-1])
217
+ ax.set_xticklabels(metrics, fontsize=10, color=COLORS['text'])
218
+ ax.set_yticklabels([])
219
+ ax.set_title('Multi-Metric Radar: Attack vs Defense', fontsize=13,
220
+ fontweight='bold', color=COLORS['text'], pad=25)
221
+ ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1),
222
+ facecolor=COLORS['panel'], edgecolor=COLORS['grid'],
223
+ labelcolor=COLORS['text'], fontsize=10)
224
+ ax.spines['polar'].set_color(COLORS['grid'])
225
+ ax.tick_params(axis='y', colors=COLORS['grid'])
226
+ ax.grid(color=COLORS['grid'], alpha=0.3)
227
  plt.tight_layout()
228
  return fig
229
 
230
+ # ================================================================
231
+ # 图表3: Loss分布对比(标签平滑5个模型)
232
+ # ================================================================
233
  def fig_loss_dist():
234
  items = []
235
+ for k, l in zip(LS_KEYS, LS_LABELS_EN):
236
  if k in full_losses:
237
+ items.append((k, l, gm(k, 'auc')))
 
238
  n = len(items)
239
+ if n == 0: return plt.figure()
240
+
241
+ fig, axes = plt.subplots(1, n, figsize=(4.5*n, 4.5))
242
+ if n == 1: axes = [axes]
243
+ apply_dark_style(fig, axes)
244
+
245
  for ax, (k, l, a) in zip(axes, items):
246
  m = full_losses[k]['member_losses']
247
  nm = full_losses[k]['non_member_losses']
248
+ bins = np.linspace(min(min(m),min(nm)), max(max(m),max(nm)), 30)
249
+ ax.hist(m, bins=bins, alpha=0.6, color=COLORS['accent'], label='Member', density=True)
250
+ ax.hist(nm, bins=bins, alpha=0.6, color=COLORS['danger'], label='Non-Member', density=True)
251
+ ax.set_title(f'{l} | AUC={a:.4f}', fontsize=10, fontweight='bold')
252
  ax.set_xlabel('Loss', fontsize=9)
253
  ax.set_ylabel('Density', fontsize=9)
254
+ ax.legend(fontsize=8, facecolor=COLORS['panel'], edgecolor=COLORS['grid'], labelcolor=COLORS['text'])
 
 
 
255
  plt.tight_layout()
256
  return fig
257
 
258
+ # ================================================================
259
+ # 图表4: 输出扰动Loss分布(6组)
260
+ # ================================================================
261
  def fig_perturb_dist():
262
+ if 'baseline' not in full_losses: return plt.figure()
 
263
  ml = np.array(full_losses['baseline']['member_losses'])
264
  nl = np.array(full_losses['baseline']['non_member_losses'])
265
+
266
+ fig, axes = plt.subplots(2, 3, figsize=(16, 9))
267
+ axes_flat = axes.flatten()
268
+ apply_dark_style(fig, axes_flat)
269
+
270
+ for i, (ax, s) in enumerate(zip(axes_flat, OP_SIGMAS)):
271
  rng_m = np.random.RandomState(42)
272
  rng_nm = np.random.RandomState(137)
273
  mp = ml + rng_m.normal(0, s, len(ml))
274
  np_ = nl + rng_nm.normal(0, s, len(nl))
275
  v = np.concatenate([mp, np_])
276
  bins = np.linspace(v.min(), v.max(), 28)
277
+ ax.hist(mp, bins=bins, alpha=0.6, color=COLORS['accent'], label='Mem+noise', density=True)
278
+ ax.hist(np_, bins=bins, alpha=0.6, color=COLORS['danger'], label='Non+noise', density=True)
279
+ pa = gm(f'perturbation_{s}', 'auc')
280
+ ax.set_title(f'OP(s={s}) | AUC={pa:.4f}', fontsize=10, fontweight='bold')
281
  ax.set_xlabel('Loss', fontsize=9)
282
+ ax.legend(fontsize=7, facecolor=COLORS['panel'], edgecolor=COLORS['grid'], labelcolor=COLORS['text'])
 
 
 
283
  plt.tight_layout()
284
  return fig
285
 
286
+ # ================================================================
287
+ # 图表5: ROC曲线对比(关键新增!证明攻击效果)
288
+ # ================================================================
289
+ def fig_roc_curves():
290
+ fig, axes = plt.subplots(1, 2, figsize=(16, 7))
291
+ apply_dark_style(fig, axes)
292
+
293
+ # 左图: 标签平滑ROC
294
+ ax = axes[0]
295
+ ls_colors = [COLORS['danger'], COLORS['ls_colors'][0], COLORS['ls_colors'][1],
296
+ COLORS['ls_colors'][2], COLORS['ls_colors'][3]]
297
+ for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_EN)):
298
+ if k not in full_losses: continue
299
+ m = np.array(full_losses[k]['member_losses'])
300
+ nm = np.array(full_losses[k]['non_member_losses'])
301
+ y_true = np.concatenate([np.ones(len(m)), np.zeros(len(nm))])
302
+ y_scores = np.concatenate([-m, -nm])
303
+ fpr, tpr, _ = roc_curve(y_true, y_scores)
304
+ auc_val = roc_auc_score(y_true, y_scores)
305
+ ax.plot(fpr, tpr, color=ls_colors[i], lw=2, label=f'{l} (AUC={auc_val:.4f})')
306
+
307
+ ax.plot([0,1], [0,1], '--', color=COLORS['text_dim'], lw=1, label='Random')
308
+ ax.set_xlabel('False Positive Rate', fontsize=11, fontweight='bold')
309
+ ax.set_ylabel('True Positive Rate', fontsize=11, fontweight='bold')
310
+ ax.set_title('ROC Curves: Label Smoothing Defense', fontsize=12, fontweight='bold', pad=10)
311
+ ax.legend(fontsize=9, facecolor=COLORS['panel'], edgecolor=COLORS['grid'], labelcolor=COLORS['text'])
312
+
313
+ # 右图: 输出扰动ROC
314
+ ax = axes[1]
315
+ if 'baseline' in full_losses:
316
+ ml_base = np.array(full_losses['baseline']['member_losses'])
317
+ nl_base = np.array(full_losses['baseline']['non_member_losses'])
318
+
319
+ # Baseline ROC
320
+ y_true = np.concatenate([np.ones(len(ml_base)), np.zeros(len(nl_base))])
321
+ y_scores = np.concatenate([-ml_base, -nl_base])
322
+ fpr, tpr, _ = roc_curve(y_true, y_scores)
323
+ ax.plot(fpr, tpr, color=COLORS['danger'], lw=2, label=f'Baseline (AUC={bl_auc:.4f})')
324
+
325
+ for i, s in enumerate(OP_SIGMAS):
326
+ rng_m = np.random.RandomState(42)
327
+ rng_nm = np.random.RandomState(137)
328
+ mp = ml_base + rng_m.normal(0, s, len(ml_base))
329
+ np_ = nl_base + rng_nm.normal(0, s, len(nl_base))
330
+ y_scores_p = np.concatenate([-mp, -np_])
331
+ fpr_p, tpr_p, _ = roc_curve(y_true, y_scores_p)
332
+ auc_p = roc_auc_score(y_true, y_scores_p)
333
+ ax.plot(fpr_p, tpr_p, color=COLORS['op_colors'][i], lw=1.5,
334
+ label=f'OP(s={s}) (AUC={auc_p:.4f})')
335
+
336
+ ax.plot([0,1], [0,1], '--', color=COLORS['text_dim'], lw=1, label='Random')
337
+ ax.set_xlabel('False Positive Rate', fontsize=11, fontweight='bold')
338
+ ax.set_ylabel('True Positive Rate', fontsize=11, fontweight='bold')
339
+ ax.set_title('ROC Curves: Output Perturbation Defense', fontsize=12, fontweight='bold', pad=10)
340
+ ax.legend(fontsize=8, facecolor=COLORS['panel'], edgecolor=COLORS['grid'], labelcolor=COLORS['text'], loc='lower right')
341
+
342
+ plt.tight_layout()
343
+ return fig
344
+
345
+ # ================================================================
346
+ # 图表6: TPR@低FPR 对比(关键新增!精细评估攻击危害)
347
+ # ================================================================
348
+ def fig_tpr_at_low_fpr():
349
+ fig, axes = plt.subplots(1, 2, figsize=(16, 6.5))
350
+ apply_dark_style(fig, axes)
351
+
352
+ # 数据
353
+ labels_all = []
354
+ tpr5_all = []
355
+ tpr1_all = []
356
+ colors_all = []
357
+
358
+ ls_c = [COLORS['baseline']] + COLORS['ls_colors']
359
+ for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_EN)):
360
+ labels_all.append(l)
361
+ tpr5_all.append(gm(k, 'tpr_at_5fpr'))
362
+ tpr1_all.append(gm(k, 'tpr_at_1fpr'))
363
+ colors_all.append(ls_c[i])
364
+ for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_EN)):
365
+ labels_all.append(l)
366
+ tpr5_all.append(gm(k, 'tpr_at_5fpr'))
367
+ tpr1_all.append(gm(k, 'tpr_at_1fpr'))
368
+ colors_all.append(COLORS['op_colors'][i])
369
+
370
+ x = range(len(labels_all))
371
+
372
+ # TPR@5%FPR
373
+ ax = axes[0]
374
+ bars = ax.bar(x, tpr5_all, color=colors_all, width=0.6, edgecolor=COLORS['bg'], lw=1, zorder=3)
375
+ for b, v in zip(bars, tpr5_all):
376
+ ax.text(b.get_x()+b.get_width()/2, v+0.003, f'{v:.3f}',
377
+ ha='center', fontsize=8, fontweight='bold', color=COLORS['text'])
378
+ ax.set_ylabel('TPR @ 5% FPR', fontsize=11, fontweight='bold')
379
+ ax.set_title('Attack Power at 5% False Positive Rate', fontsize=12, fontweight='bold', pad=10)
380
+ ax.set_xticks(x)
381
+ ax.set_xticklabels(labels_all, rotation=40, ha='right', fontsize=8)
382
+ ax.axhline(0.05, color=COLORS['warning'], ls='--', lw=1, alpha=0.5, label='Random (0.05)')
383
+ ax.legend(facecolor=COLORS['panel'], edgecolor=COLORS['grid'], labelcolor=COLORS['text'], fontsize=9)
384
+
385
+ # TPR@1%FPR
386
+ ax = axes[1]
387
+ bars = ax.bar(x, tpr1_all, color=colors_all, width=0.6, edgecolor=COLORS['bg'], lw=1, zorder=3)
388
+ for b, v in zip(bars, tpr1_all):
389
+ ax.text(b.get_x()+b.get_width()/2, v+0.002, f'{v:.3f}',
390
+ ha='center', fontsize=8, fontweight='bold', color=COLORS['text'])
391
+ ax.set_ylabel('TPR @ 1% FPR', fontsize=11, fontweight='bold')
392
+ ax.set_title('Attack Power at 1% False Positive Rate (Strict)', fontsize=12, fontweight='bold', pad=10)
393
+ ax.set_xticks(x)
394
+ ax.set_xticklabels(labels_all, rotation=40, ha='right', fontsize=8)
395
+ ax.axhline(0.01, color=COLORS['warning'], ls='--', lw=1, alpha=0.5, label='Random (0.01)')
396
+ ax.legend(facecolor=COLORS['panel'], edgecolor=COLORS['grid'], labelcolor=COLORS['text'], fontsize=9)
397
+
398
  plt.tight_layout()
399
  return fig
400
 
401
+ # ================================================================
402
+ # 图表7: 效用准确率柱状图
403
+ # ================================================================
404
  def fig_acc_bar():
405
+ names, vals, clrs = [], [], []
406
+ ls_c = [COLORS['baseline']] + COLORS['ls_colors']
407
+ for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_EN)):
 
 
 
 
 
 
408
  if k in utility_results:
409
+ names.append(l); vals.append(utility_results[k]['accuracy']*100); clrs.append(ls_c[i])
410
+ for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_EN)):
 
 
 
 
411
  if k in perturb_results:
412
+ names.append(l); vals.append(bl_acc); clrs.append(COLORS['op_colors'][i])
413
+
414
+ fig, ax = plt.subplots(figsize=(14, 6))
415
+ apply_dark_style(fig, ax)
416
+ bars = ax.bar(range(len(names)), vals, color=clrs, width=0.6, edgecolor=COLORS['bg'], lw=1.5, zorder=3)
 
417
  for b, v in zip(bars, vals):
418
+ ax.text(b.get_x()+b.get_width()/2, v+0.5, f'{v:.1f}%',
419
+ ha='center', fontsize=9, fontweight='bold', color=COLORS['text'])
420
+ ax.set_ylabel('Test Accuracy (%)', fontsize=11, fontweight='bold')
421
+ ax.set_title('Model Utility: Test Accuracy (300 Questions)', fontsize=13, fontweight='bold', pad=15)
422
  ax.set_ylim(0, 100)
423
  ax.set_xticks(range(len(names)))
424
+ ax.set_xticklabels(names, rotation=35, ha='right', fontsize=9)
 
 
 
425
  plt.tight_layout()
426
  return fig
427
 
428
+ # ================================================================
429
+ # 图表8: 隐私-效用权衡散点图
430
+ # ================================================================
431
  def fig_tradeoff():
432
+ fig, ax = plt.subplots(figsize=(11, 8))
433
+ apply_dark_style(fig, ax)
434
+
435
+ markers_ls = ['o', 's', 's', 's', 's']
436
+ ls_c = [COLORS['baseline']] + COLORS['ls_colors']
437
+ for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_EN)):
 
 
 
 
 
438
  if k in mia_results and k in utility_results:
439
+ ax.scatter(utility_results[k]['accuracy']*100, mia_results[k]['auc'],
440
+ label=l, marker=markers_ls[i], color=ls_c[i],
441
+ s=200, edgecolors='white', lw=2, zorder=5)
442
+
443
+ op_markers = ['^', 'D', 'v', 'P', 'X', 'h']
444
+ for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_EN)):
445
  if k in perturb_results:
446
+ ax.scatter(bl_acc, perturb_results[k]['auc'], label=l,
447
+ marker=op_markers[i], color=COLORS['op_colors'][i],
448
+ s=200, edgecolors='white', lw=2, zorder=5)
449
+
450
+ ax.axhline(0.5, color=COLORS['text_dim'], ls='--', alpha=0.5, label='Random (AUC=0.5)')
451
+
452
+ # Ideal zone annotation
453
+ ax.annotate('IDEAL\nHigh Utility\nLow Risk', xy=(85, 0.52), fontsize=11,
454
+ fontweight='bold', color=COLORS['success'], alpha=0.4, ha='center')
455
+ ax.annotate('WORST\nLow Utility\nHigh Risk', xy=(62, 0.62), fontsize=11,
456
+ fontweight='bold', color=COLORS['danger'], alpha=0.4, ha='center')
457
+
458
+ ax.set_xlabel('Model Utility (Accuracy %)', fontsize=12, fontweight='bold')
459
  ax.set_ylabel('Privacy Risk (MIA AUC)', fontsize=12, fontweight='bold')
460
+ ax.set_title('Privacy-Utility Trade-off Analysis', fontsize=14, fontweight='bold', pad=15)
461
+ ax.legend(fontsize=8, loc='upper left', ncol=2,
462
+ facecolor=COLORS['panel'], edgecolor=COLORS['grid'], labelcolor=COLORS['text'])
463
+ plt.tight_layout()
464
+ return fig
465
+
466
+ # ================================================================
467
+ # 图表9: AUC递减趋势线(关键新增!证明防御参数与效果的关系)
468
+ # ================================================================
469
+ def fig_auc_trend():
470
+ fig, axes = plt.subplots(1, 2, figsize=(16, 6.5))
471
+ apply_dark_style(fig, axes)
472
+
473
+ # 左: 标签平滑 epsilon vs AUC
474
+ ax = axes[0]
475
+ eps_vals = [0.0, 0.02, 0.05, 0.1, 0.2]
476
+ auc_vals = [gm(k, 'auc') for k in LS_KEYS]
477
+ acc_vals = [gu(k) for k in LS_KEYS]
478
+
479
+ ax2 = ax.twinx()
480
+ line1 = ax.plot(eps_vals, auc_vals, 'o-', color=COLORS['danger'], lw=2.5, ms=10, label='MIA AUC (left)', zorder=5)
481
+ line2 = ax2.plot(eps_vals, acc_vals, 's--', color=COLORS['success'], lw=2.5, ms=10, label='Utility % (right)', zorder=5)
482
+ ax.axhline(0.5, color=COLORS['text_dim'], ls=':', alpha=0.4)
483
+
484
+ ax.set_xlabel('Label Smoothing epsilon', fontsize=11, fontweight='bold')
485
+ ax.set_ylabel('MIA AUC', fontsize=11, fontweight='bold', color=COLORS['danger'])
486
+ ax2.set_ylabel('Utility (%)', fontsize=11, fontweight='bold', color=COLORS['success'])
487
+ ax.set_title('Label Smoothing: AUC & Utility vs Epsilon', fontsize=12, fontweight='bold', pad=10)
488
+ ax.tick_params(axis='y', labelcolor=COLORS['danger'])
489
+ ax2.tick_params(axis='y', labelcolor=COLORS['success'])
490
+ ax2.spines['right'].set_color(COLORS['success'])
491
+ ax2.spines['left'].set_color(COLORS['danger'])
492
+
493
+ lines = line1 + line2
494
+ labels = [l.get_label() for l in lines]
495
+ ax.legend(lines, labels, fontsize=9, facecolor=COLORS['panel'], edgecolor=COLORS['grid'], labelcolor=COLORS['text'])
496
+
497
+ # 右: 输出扰动 sigma vs AUC
498
+ ax = axes[1]
499
+ sig_vals = OP_SIGMAS
500
+ auc_op = [gm(k, 'auc') for k in OP_KEYS]
501
+
502
+ ax.plot(sig_vals, auc_op, 'o-', color=COLORS['success'], lw=2.5, ms=10, zorder=5, label='MIA AUC')
503
+ ax.axhline(bl_auc, color=COLORS['danger'], ls='--', lw=1.5, alpha=0.5, label=f'Baseline ({bl_auc:.4f})')
504
+ ax.axhline(0.5, color=COLORS['text_dim'], ls=':', alpha=0.4, label='Random (0.5)')
505
+
506
+ ax.fill_between(sig_vals, auc_op, bl_auc, alpha=0.15, color=COLORS['success'], label='AUC Reduction')
507
+
508
+ ax.set_xlabel('Perturbation Sigma', fontsize=11, fontweight='bold')
509
+ ax.set_ylabel('MIA AUC', fontsize=11, fontweight='bold')
510
+ ax.set_title('Output Perturbation: AUC vs Sigma', fontsize=12, fontweight='bold', pad=10)
511
+ ax.legend(fontsize=9, facecolor=COLORS['panel'], edgecolor=COLORS['grid'], labelcolor=COLORS['text'])
512
+
513
+ plt.tight_layout()
514
+ return fig
515
+
516
+ # ================================================================
517
+ # 图表10: Loss差距瀑布图(关键新增!直观展示防御缩小差距)
518
+ # ================================================================
519
+ def fig_loss_gap_waterfall():
520
+ fig, ax = plt.subplots(figsize=(14, 6.5))
521
+ apply_dark_style(fig, ax)
522
+
523
+ names, gaps, clrs = [], [], []
524
+ ls_c = [COLORS['baseline']] + COLORS['ls_colors']
525
+ for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_EN)):
526
+ names.append(l); gaps.append(gm(k, 'loss_gap')); clrs.append(ls_c[i])
527
+ for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_EN)):
528
+ names.append(l); gaps.append(gm(k, 'loss_gap')); clrs.append(COLORS['op_colors'][i])
529
+
530
+ bars = ax.bar(range(len(names)), gaps, color=clrs, width=0.6, edgecolor=COLORS['bg'], lw=1.5, zorder=3)
531
+ for b, v in zip(bars, gaps):
532
+ ax.text(b.get_x()+b.get_width()/2, v+0.0002, f'{v:.4f}',
533
+ ha='center', fontsize=8, fontweight='bold', color=COLORS['text'])
534
+
535
+ ax.set_ylabel('Loss Gap (Non-Member - Member)', fontsize=11, fontweight='bold')
536
+ ax.set_title('Member vs Non-Member Loss Gap (Smaller = Better Defense)', fontsize=13, fontweight='bold', pad=15)
537
+ ax.set_xticks(range(len(names)))
538
+ ax.set_xticklabels(names, rotation=35, ha='right', fontsize=9)
539
+
540
+ # 添加箭头注释
541
+ ax.annotate('Smaller gap = harder to distinguish\n= stronger privacy protection',
542
+ xy=(8, gaps[0]*0.3), fontsize=10, color=COLORS['success'],
543
+ fontstyle='italic', ha='center',
544
+ bbox=dict(boxstyle='round,pad=0.5', facecolor=COLORS['panel'], edgecolor=COLORS['success'], alpha=0.8))
545
+
546
+ plt.tight_layout()
547
+ return fig
548
+
549
+ # ================================================================
550
+ # 图表11: 攻击判定位置图(gauge)
551
+ # ================================================================
552
+ def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
553
+ fig, ax = plt.subplots(figsize=(10, 3))
554
+ fig.patch.set_facecolor(COLORS['bg'])
555
+ ax.set_facecolor(COLORS['panel'])
556
+
557
+ xlo = min(m_mean - 3.5*m_std, loss_val - 0.005)
558
+ xhi = max(nm_mean + 3.5*nm_std, loss_val + 0.005)
559
+
560
+ ax.axvspan(xlo, thr, alpha=0.12, color=COLORS['accent'])
561
+ ax.axvspan(thr, xhi, alpha=0.12, color=COLORS['danger'])
562
+ ax.axvline(thr, color=COLORS['text'], lw=2, zorder=3)
563
+ ax.text(thr, 1.08, f'Threshold={thr:.4f}', ha='center', va='bottom',
564
+ fontsize=9, fontweight='bold', color=COLORS['text'],
565
+ transform=ax.get_xaxis_transform())
566
+
567
+ mc = COLORS['accent'] if loss_val < thr else COLORS['danger']
568
+ ax.plot(loss_val, 0.5, marker='v', ms=18, color=mc, zorder=5,
569
+ transform=ax.get_xaxis_transform())
570
+ ax.text(loss_val, 0.78, f'Loss={loss_val:.4f}', ha='center', fontsize=11,
571
+ fontweight='bold', color=mc, transform=ax.get_xaxis_transform(),
572
+ bbox=dict(boxstyle='round,pad=.3', fc=COLORS['panel'], ec=mc, alpha=0.95))
573
+
574
+ ax.text((xlo+thr)/2, 0.35, 'MEMBER', ha='center', fontsize=11,
575
+ color=COLORS['accent'], alpha=0.5, fontweight='bold',
576
+ transform=ax.get_xaxis_transform())
577
+ ax.text((thr+xhi)/2, 0.35, 'NON-MEMBER', ha='center', fontsize=11,
578
+ color=COLORS['danger'], alpha=0.5, fontweight='bold',
579
+ transform=ax.get_xaxis_transform())
580
+
581
+ ax.set_xlim(xlo, xhi)
582
+ ax.set_yticks([])
583
+ for s in ax.spines.values():
584
+ s.set_color(COLORS['grid'])
585
  ax.spines['top'].set_visible(False)
586
+ ax.spines['left'].set_visible(False)
587
  ax.spines['right'].set_visible(False)
588
+ ax.tick_params(colors=COLORS['text_dim'])
589
+ ax.set_xlabel('Loss Value', fontsize=10, color=COLORS['text'])
590
  plt.tight_layout()
591
  return fig
592
 
593
  # ================================================================
594
  # 回调函数
595
  # ================================================================
 
596
  def cb_sample(src):
597
+ pool = member_data if src == "\u6210\u5458\u6570\u636e\uff08\u8bad\u7ec3\u96c6\uff09" else non_member_data
598
  s = pool[np.random.randint(len(pool))]
599
  m = s['metadata']
600
+ md = ("| \u5b57\u6bb5 | \u503c |\n|---|---|\n"
601
+ f"| \u59d3\u540d | {clean_text(str(m.get('name','')))} |\n"
602
+ f"| \u5b66\u53f7 | {clean_text(str(m.get('student_id','')))} |\n"
603
+ f"| \u73ed\u7ea7 | {clean_text(str(m.get('class','')))} |\n"
604
+ f"| \u6210\u7ee9 | {clean_text(str(m.get('score','')))} \u5206 |\n"
605
+ f"| \u7c7b\u578b | {TYPE_CN.get(s.get('task_type',''), '')} |\n")
606
  return md, clean_text(s.get('question', '')), clean_text(s.get('answer', ''))
607
 
 
608
  ATK_CHOICES = (
609
+ ["\u57fa\u7ebf\u6a21\u578b (Baseline)"] +
610
+ [f"\u6807\u7b7e\u5e73\u6ed1 (\u03b5={e})" for e in [0.02, 0.05, 0.1, 0.2]] +
611
+ [f"\u8f93\u51fa\u6270\u52a8 (\u03c3={s})" for s in OP_SIGMAS]
612
  )
 
613
  ATK_MAP = {}
614
+ ATK_MAP["\u57fa\u7ebf\u6a21\u578b (Baseline)"] = "baseline"
615
  for e in [0.02, 0.05, 0.1, 0.2]:
616
+ ATK_MAP[f"\u6807\u7b7e\u5e73\u6ed1 (\u03b5={e})"] = f"smooth_eps_{e}"
617
  for s in OP_SIGMAS:
618
+ ATK_MAP[f"\u8f93\u51fa\u6270\u52a8 (\u03c3={s})"] = f"perturbation_{s}"
619
 
620
  def cb_attack(idx, src, target):
621
+ is_mem = src == "\u6210\u5458\u6570\u636e\uff08\u8bad\u7ec3\u96c6\uff09"
622
  pool = member_data if is_mem else non_member_data
623
+ idx = min(int(idx), len(pool)-1)
624
  sample = pool[idx]
625
  key = ATK_MAP.get(target, "baseline")
 
626
  is_op = key.startswith("perturbation_")
627
+
628
  if is_op:
629
  sigma = float(key.split("_")[1])
630
  fr = full_losses.get('baseline', {})
631
  lk = 'member_losses' if is_mem else 'non_member_losses'
632
+ ll = fr.get(lk, [])
633
+ base_loss = ll[idx] if idx < len(ll) else float(np.random.normal(bl_m_mean if is_mem else bl_nm_mean, 0.02))
634
+ np.random.seed(idx*1000 + int(sigma*10000))
635
  loss = base_loss + np.random.normal(0, sigma)
636
+ mm = gm("baseline", "member_loss_mean", 0.19)
637
+ nm_m = gm("baseline", "non_member_loss_mean", 0.20)
638
+ ms = gm("baseline", "member_loss_std", 0.03)
639
+ ns = gm("baseline", "non_member_loss_std", 0.03)
640
+ auc_v = gm(key, "auc"); lbl = f"OP(\u03c3={sigma})"
 
641
  else:
642
  info = mia_results.get(key, mia_results.get('baseline', {}))
643
  fr = full_losses.get(key, full_losses.get('baseline', {}))
644
  lk = 'member_losses' if is_mem else 'non_member_losses'
645
+ ll = fr.get(lk, [])
646
+ loss = ll[idx] if idx < len(ll) else float(np.random.normal(info.get('member_loss_mean',0.19), 0.02))
647
  mm = info.get('member_loss_mean', 0.19)
648
  nm_m = info.get('non_member_loss_mean', 0.20)
649
  ms = info.get('member_loss_std', 0.03)
650
  ns = info.get('non_member_loss_std', 0.03)
651
  auc_v = info.get('auc', 0)
652
+ if key == "baseline": lbl = "Baseline"
 
653
  else:
654
+ eps = key.replace("smooth_eps_","")
655
  lbl = f"LS(\u03b5={eps})"
656
+
657
  thr = (mm + nm_m) / 2
658
  pred = loss < thr
659
  correct = pred == is_mem
 
660
  gauge = fig_gauge(loss, mm, nm_m, thr, ms, ns)
661
+
662
+ pl = "\u8bad\u7ec3\u6210\u5458" if pred else "\u975e\u8bad\u7ec3\u6210\u5458"
663
+ al = "\u8bad\u7ec3\u6210\u5458" if is_mem else "\u975e\u8bad\u7ec3\u6210\u5458"
664
+
665
  if correct and pred and is_mem:
666
+ v = "\u26a0\ufe0f **\u653b\u51fb\u6210\u529f\uff1a\u9690\u79c1\u6cc4\u9732**\n\n> \u6a21\u578b\u5bf9\u8be5\u6837\u672c\u8fc7\u4e8e\u719f\u6089\uff08Loss < \u9608\u503c\uff09\uff0c\u653b\u51fb\u8005\u6210\u529f\u5224\u5b9a\u4e3a\u8bad\u7ec3\u6570\u636e\u3002"
667
  elif correct:
668
+ v = "\u2705 **\u5224\u5b9a\u6b63\u786e**\n\n> \u653b\u51fb\u8005\u5224\u5b9a\u4e0e\u771f\u5b9e\u8eab\u4efd\u4e00\u81f4\u3002"
669
  else:
670
+ v = "\U0001f6e1\ufe0f **\u9632\u5fa1\u6210\u529f**\n\n> \u653b\u51fb\u8005\u5224\u5b9a\u9519\u8bef\uff0c\u9632\u5fa1\u8d77\u5230\u4e86\u4fdd\u62a4\u4f5c\u7528\u3002"
671
+
672
+ res = (v + f"\n\n**\u653b\u51fb\u76ee\u6807**: {lbl}\u3000|\u3000**AUC**: {auc_v:.4f}\n\n"
673
+ "| | \u653b\u51fb\u8005\u5224\u5b9a | \u771f\u5b9e\u8eab\u4efd |\n|---|---|---|\n"
674
+ f"| \u8eab\u4efd | {pl} | {al} |\n"
675
+ f"| Loss | {loss:.4f} | \u9608\u503c: {thr:.4f} |\n")
676
+ qtxt = f"**\u6837\u672c #{idx}**\n\n" + clean_text(sample.get('question',''))[:500]
 
677
  return qtxt, gauge, res
678
 
679
+ EVAL_CHOICES = (
680
+ ["\u57fa\u7ebf\u6a21\u578b"] +
681
+ [f"\u6807\u7b7e\u5e73\u6ed1 (\u03b5={e})" for e in [0.02, 0.05, 0.1, 0.2]] +
682
+ [f"\u8f93\u51fa\u6270\u52a8 (\u03c3={s})" for s in OP_SIGMAS]
 
683
  )
684
+ EVAL_KEY_MAP = {"\u57fa\u7ebf\u6a21\u578b": "baseline"}
 
685
  for e in [0.02, 0.05, 0.1, 0.2]:
686
+ EVAL_KEY_MAP[f"\u6807\u7b7e\u5e73\u6ed1 (\u03b5={e})"] = f"smooth_eps_{e}"
687
  for s in OP_SIGMAS:
688
+ EVAL_KEY_MAP[f"\u8f93\u51fa\u6270\u52a8 (\u03c3={s})"] = "baseline"
689
 
690
  def cb_eval(model_choice):
691
  k = EVAL_KEY_MAP.get(model_choice, "baseline")
692
+ acc = gu(k) if "\u8f93\u51fa\u6270\u52a8" not in model_choice else bl_acc
693
  q = EVAL_POOL[np.random.randint(len(EVAL_POOL))]
694
  ok = q.get(k, q.get('baseline', False))
695
+ ic = "\u2705 \u6b63\u786e" if ok else "\u274c \u9519\u8bef"
696
+ note = "\n\n> \u8f93\u51fa\u6270\u52a8\u4e0d\u6539\u53d8\u6a21\u578b\u53c2\u6570\uff0c\u51c6\u786e\u7387\u4e0e\u57fa\u7ebf\u4e00\u81f4\u3002" if "\u8f93\u51fa\u6270\u52a8" in model_choice else ""
697
+ return (f"**\u6a21\u578b**: {model_choice}\u3000(\u51c6\u786e\u7387: {acc:.1f}%)\n\n"
698
+ "| \u9879\u76ee | \u5185\u5bb9 |\n|---|---|\n"
699
+ f"| \u7c7b\u578b | {q['type_cn']} |\n"
700
+ f"| \u9898\u76ee | {q['question']} |\n"
701
+ f"| \u6b63\u786e\u7b54\u6848 | {q['answer']} |\n"
702
+ f"| \u5224\u5b9a | {ic} |{note}")
703
 
704
  # ================================================================
705
+ # 构建完整结果表
706
  # ================================================================
 
707
  def build_full_table():
708
  rows = []
709
+ for k, l in zip(LS_KEYS, LS_LABELS_CN):
 
710
  if k in mia_results:
711
+ m = mia_results[k]; u = gu(k)
712
+ t = "\u2014" if k == "baseline" else "\u8bad\u7ec3\u671f"
713
+ d = "" if k == "baseline" else f"{m['auc']-bl_auc:+.4f}"
 
714
  rows.append(f"| {l} | {t} | {m['auc']:.4f} | {m['attack_accuracy']:.4f} | "
715
  f"{m['precision']:.4f} | {m['recall']:.4f} | {m['f1']:.4f} | "
716
  f"{m['tpr_at_5fpr']:.4f} | {m['tpr_at_1fpr']:.4f} | "
717
+ f"{m['loss_gap']:.4f} | {u:.1f}% | {d} |")
718
+ for k, l in zip(OP_KEYS, OP_LABELS_CN):
 
719
  if k in perturb_results:
720
+ m = perturb_results[k]; d = f"{m['auc']-bl_auc:+.4f}"
721
+ rows.append(f"| {l} | \u63a8\u7406\u671f | {m['auc']:.4f} | {m['attack_accuracy']:.4f} | "
 
722
  f"{m['precision']:.4f} | {m['recall']:.4f} | {m['f1']:.4f} | "
723
  f"{m['tpr_at_5fpr']:.4f} | {m['tpr_at_1fpr']:.4f} | "
724
+ f"{m['loss_gap']:.4f} | {bl_acc:.1f}% | {d} |")
725
+ header = ("| \u7b56\u7565 | \u7c7b\u578b | AUC | Acc | Prec | Rec | F1 | TPR@5% | TPR@1% | LossGap | \u6548\u7528 | AUC\u0394 |\n"
726
+ "|---|---|---|---|---|---|---|---|---|---|---|---|")
727
  return header + "\n" + "\n".join(rows)
728
 
729
  # ================================================================
730
+ # CSS - 深色科技感主题
731
  # ================================================================
 
732
  CSS = """
733
+ /* === 全局深色背景 === */
734
+ body {
735
+ background: linear-gradient(135deg, #0f172a 0%, #1e1b4b 50%, #0f172a 100%) !important;
736
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
737
+ }
738
+ .gradio-container {
739
+ max-width: 1280px !important;
740
+ margin: 30px auto !important;
741
+ }
742
+
743
+ /* === 标题区 - 玻璃态 === */
744
+ .title-area {
745
+ background: linear-gradient(135deg, rgba(30,41,59,0.95), rgba(30,27,75,0.95));
746
+ backdrop-filter: blur(20px);
747
+ padding: 32px 44px;
748
+ border-radius: 16px;
749
+ border: 1px solid rgba(59,130,246,0.3);
750
+ box-shadow: 0 0 40px rgba(59,130,246,0.15), 0 8px 32px rgba(0,0,0,0.3);
751
+ margin-bottom: 28px;
752
+ position: relative;
753
+ overflow: hidden;
754
+ }
755
+ .title-area::before {
756
+ content: '';
757
+ position: absolute;
758
+ top: 0; left: 0; right: 0;
759
+ height: 3px;
760
+ background: linear-gradient(90deg, #3b82f6, #8b5cf6, #06b6d4);
761
+ }
762
+ .title-area h1 {
763
+ color: #f1f5f9 !important;
764
+ font-size: 1.8rem !important;
765
+ font-weight: 800 !important;
766
+ margin: 0 0 8px 0 !important;
767
+ letter-spacing: -0.02em;
768
+ }
769
+ .title-area p {
770
+ color: #94a3b8 !important;
771
+ font-size: 1rem !important;
772
+ margin: 0 !important;
773
+ }
774
+ .title-area .badge {
775
+ display: inline-block;
776
+ background: linear-gradient(135deg, #3b82f6, #8b5cf6);
777
+ color: white;
778
+ padding: 4px 12px;
779
+ border-radius: 20px;
780
+ font-size: 0.8rem;
781
+ font-weight: 600;
782
+ margin-top: 8px;
783
+ }
784
+
785
+ /* === Tab面板 - 玻璃态 === */
786
+ .tabitem {
787
+ background: rgba(30,41,59,0.92) !important;
788
+ backdrop-filter: blur(10px) !important;
789
+ border-radius: 0 0 16px 16px !important;
790
+ border: 1px solid rgba(51,65,85,0.6) !important;
791
+ border-top: none !important;
792
+ box-shadow: 0 8px 32px rgba(0,0,0,0.2) !important;
793
+ padding: 36px 44px !important;
794
+ min-height: 800px !important;
795
+ overflow-y: auto !important;
796
+ }
797
+ .tabitem::-webkit-scrollbar { width: 6px; }
798
+ .tabitem::-webkit-scrollbar-track { background: transparent; }
799
+ .tabitem::-webkit-scrollbar-thumb { background: #475569; border-radius: 10px; }
800
+
801
+ /* === Tab导航 === */
802
  .tab-nav { border-bottom: none !important; gap: 4px !important; }
803
+ .tab-nav button {
804
+ font-size: 14px !important;
805
+ padding: 12px 22px !important;
806
+ font-weight: 600 !important;
807
+ color: #94a3b8 !important;
808
+ background: rgba(30,41,59,0.8) !important;
809
+ border: 1px solid rgba(51,65,85,0.5) !important;
810
+ border-bottom: none !important;
811
+ border-radius: 10px 10px 0 0 !important;
812
+ transition: all 0.3s ease !important;
813
+ }
814
+ .tab-nav button:hover {
815
+ color: #e2e8f0 !important;
816
+ background: rgba(59,130,246,0.15) !important;
817
+ }
818
+ .tab-nav button.selected {
819
+ color: #60a5fa !important;
820
+ background: rgba(30,41,59,0.95) !important;
821
+ border-top: 3px solid #3b82f6 !important;
822
+ box-shadow: 0 0 15px rgba(59,130,246,0.2) !important;
823
+ }
824
+
825
+ /* === 文字排版 === */
826
+ .prose { color: #cbd5e1 !important; }
827
+ .prose h2 {
828
+ font-size: 1.4rem !important;
829
+ color: #f1f5f9 !important;
830
+ font-weight: 700 !important;
831
+ margin-top: 0 !important;
832
+ padding-bottom: 12px !important;
833
+ border-bottom: 1px solid rgba(51,65,85,0.6) !important;
834
+ }
835
+ .prose h3 {
836
+ font-size: 1.1rem !important;
837
+ color: #e2e8f0 !important;
838
+ font-weight: 600 !important;
839
+ }
840
+ .prose strong { color: #f1f5f9 !important; }
841
+
842
+ /* === 表格 - 深色风格 === */
843
+ .prose table {
844
+ width: 100% !important;
845
+ border-collapse: separate !important;
846
+ border-spacing: 0 !important;
847
+ border-radius: 10px !important;
848
+ overflow: hidden !important;
849
+ border: 1px solid rgba(51,65,85,0.6) !important;
850
+ font-size: 0.85rem !important;
851
+ }
852
+ .prose th {
853
+ background: rgba(15,23,42,0.8) !important;
854
+ color: #94a3b8 !important;
855
+ font-weight: 600 !important;
856
+ padding: 10px 14px !important;
857
+ text-transform: uppercase;
858
+ font-size: 0.75rem !important;
859
+ letter-spacing: 0.05em;
860
+ }
861
+ .prose td {
862
+ padding: 10px 14px !important;
863
+ color: #e2e8f0 !important;
864
+ border-bottom: 1px solid rgba(51,65,85,0.3) !important;
865
+ background: rgba(30,41,59,0.5) !important;
866
+ }
867
+ .prose tr:hover td { background: rgba(59,130,246,0.08) !important; }
868
+
869
+ /* === 按钮 - 渐变发光 === */
870
+ button.primary {
871
+ background: linear-gradient(135deg, #3b82f6, #8b5cf6) !important;
872
+ color: white !important;
873
+ border: none !important;
874
+ border-radius: 8px !important;
875
+ font-weight: 700 !important;
876
+ padding: 10px 28px !important;
877
+ box-shadow: 0 0 20px rgba(59,130,246,0.3) !important;
878
+ transition: all 0.3s ease !important;
879
+ text-transform: uppercase !important;
880
+ letter-spacing: 0.05em !important;
881
+ }
882
+ button.primary:hover {
883
+ box-shadow: 0 0 30px rgba(59,130,246,0.5) !important;
884
+ transform: translateY(-2px) !important;
885
+ }
886
+
887
+ /* === 引用框 === */
888
+ .prose blockquote {
889
+ border-left: 4px solid #3b82f6 !important;
890
+ background: rgba(59,130,246,0.08) !important;
891
+ padding: 16px 20px !important;
892
+ border-radius: 0 10px 10px 0 !important;
893
+ color: #93c5fd !important;
894
+ }
895
+
896
+ /* === Accordion === */
897
+ .wrap { border: 1px solid rgba(51,65,85,0.5) !important; border-radius: 10px !important; background: rgba(15,23,42,0.5) !important; }
898
+
899
+ /* === 输入组件 === */
900
+ .block { border-color: rgba(51,65,85,0.5) !important; }
901
+ label { color: #94a3b8 !important; }
902
+
903
  footer { display: none !important; }
904
  """
905
 
906
  # ================================================================
907
  # 界面
908
  # ================================================================
909
+ with gr.Blocks(title="MIA\u653b\u9632\u7814\u7a76", theme=gr.themes.Base(), css=CSS) as demo:
 
910
 
911
  gr.HTML("""<div class="title-area">
912
+ <h1>\U0001f393 \u6559\u80b2\u5927\u6a21\u578b\u4e2d\u7684\u6210\u5458\u63a8\u7406\u653b\u51fb\u53ca\u5176\u9632\u5fa1\u7814\u7a76</h1>
913
+ <p>Membership Inference Attack & Defense on Educational LLM</p>
914
+ <div class="badge">11 Experiments \u00d7 8 Metrics \u00d7 2 Defense Strategies</div>
915
  </div>""")
916
 
917
+ # ═══ Tab 1: \u5b9e\u9a8c\u603b\u89c8 ═══
918
+ with gr.Tab("\U0001f4ca \u5b9e\u9a8c\u603b\u89c8"):
919
+ gr.Markdown(f"""## \u7814\u7a76\u80cc\u666f
920
 
921
+ \u5927\u8bed\u8a00\u6a21\u578b\u5728\u6559\u80b2\u9886\u57df\u7684\u5e94\u7528\u65e5\u76ca\u5e7f\u6cdb\uff0c\u6a21\u578b\u8bad\u7ec3\u4e0d\u53ef\u907f\u514d\u5730\u63a5\u89e6\u5b66\u751f\u654f\u611f\u6570\u636e\u3002**\u6210\u5458\u63a8\u7406\u653b\u51fb (MIA)** \u53ef\u5224\u65ad\u67d0\u6761\u6570\u636e\u662f\u5426\u53c2\u4e0e\u4e86\u8bad\u7ec3\uff0c\u6784\u6210\u9690\u79c1\u5a01\u80c1\u3002
922
 
923
+ \u672c\u7814\u7a76\u57fa\u4e8e **{model_name}** \u5fae\u8c03\u7684\u6570\u5b66\u8f85\u5bfc\u6a21\u578b\uff0c\u7cfb\u7edf\u9a8c\u8bc1MIA\u98ce\u9669\u5e76\u8bc4\u4f30\u4e24\u7c7b\u9632\u5fa1\u7b56\u7565\u3002
924
 
925
+ ### \u5b9e\u9a8c\u89c4\u6a21
926
+ | \u7ef4\u5ea6 | \u5185\u5bb9 |
927
+ |---|---|
928
+ | \u6a21\u578b | 1\u4e2a\u57fa\u7ebf + 4\u7ec4\u6807\u7b7e\u5e73\u6ed1 (\u03b5=0.02/0.05/0.1/0.2) |
929
+ | \u6270\u52a8 | 6\u7ec4\u8f93\u51fa\u6270\u52a8 (\u03c3=0.005/0.01/0.015/0.02/0.025/0.03) |
930
+ | \u6307\u6807 | AUC / \u653b\u51fb\u51c6\u786e\u7387 / \u7cbe\u786e\u7387 / \u53ec\u56de\u7387 / F1 / TPR@5%FPR / TPR@1%FPR / Loss\u5dee\u8ddd |
931
+ | \u6570\u636e | 2000\u6761\u8f85\u5bfc\u5bf9\u8bdd (1000\u6210\u5458 + 1000\u975e\u6210\u5458) |
932
+ | \u6548\u7528 | 300\u9053\u6570\u5b66\u6d4b\u8bd5\u9898 |
933
  """)
934
+ with gr.Accordion("\U0001f4cb \u5b8c\u6574\u5b9e\u9a8c\u7ed3\u679c\u8868\uff0811\u7ec4 \u00d7 8\u7ef4\u5ea6\uff09", open=True):
935
  gr.Markdown(build_full_table())
936
+ gr.Markdown("> **\u89e3\u8bfb**: AUC\u8d8a\u63a5\u8fd10.5 = \u9632\u5fa1\u8d8a\u6709\u6548\uff1b\u6548\u7528\u8d8a\u9ad8 = \u6a21\u578b\u80fd\u529b\u8d8a\u597d\u3002AUC\u0394\u4e3a\u76f8\u5bf9\u57fa\u7ebf\u7684\u53d8\u5316\u3002")
937
 
938
+ # ═══ Tab 2: \u6570\u636e\u4e0e\u6a21\u578b ═══
939
+ with gr.Tab("\U0001f4c1 \u6570\u636e\u4e0e\u6a21\u578b"):
940
+ gr.Markdown("""\
941
+ ## \u5b9e\u9a8c\u6570\u636e\u96c6
942
 
943
+ | \u6570\u636e\u7ec4 | \u6570\u91cf | \u7528\u9014 | \u8bf4\u660e |
944
  |---|---|---|---|
945
+ | \u6210\u5458\u6570\u636e | 1000\u6761 | \u6a21\u578b\u8bad\u7ec3 | \u6a21\u578b\u4f1a\"\u8bb0\u4f4f\"\uff0cLoss\u504f\u4f4e |
946
+ | \u975e\u6210\u5458\u6570\u636e | 1000\u6761 | \u653b\u51fb\u5bf9\u7167 | \u6a21\u578b\"\u6ca1\u89c1\u8fc7\"\uff0cLoss\u504f\u9ad8 |
947
 
948
+ | \u4efb\u52a1\u7c7b\u522b | \u6570\u91cf | \u5360\u6bd4 |
949
  |---|---|---|
950
+ | \u57fa\u7840\u8ba1\u7b97 | 800 | 40% |
951
+ | \u5e94\u7528\u9898 | 600 | 30% |
952
+ | \u6982\u5ff5\u95ee\u7b54 | 400 | 20% |
953
+ | \u9519\u9898\u8ba2\u6b63 | 200 | 10% |
954
 
955
+ > \u4e24\u7ec4\u6570\u636e\u683c\u5f0f\u5b8c\u5168\u76f8\u540c\uff08\u5747\u542b\u9690\u79c1\u5b57\u6bb5\uff09\uff0c\u653b\u51fb\u8005\u65e0\u6cd5\u4ece\u683c\u5f0f\u533a\u5206\u3002
956
  """)
957
+ gr.Markdown("### \u6570\u636e\u6837\u4f8b\u6d4f\u89c8")
958
  with gr.Row():
959
  with gr.Column(scale=2):
960
+ d_src = gr.Radio(["\u6210\u5458\u6570\u636e\uff08\u8bad\u7ec3\u96c6\uff09", "\u975e\u6210\u5458\u6570\u636e\uff08\u6d4b\u8bd5\u96c6\uff09"],
961
+ value="\u6210\u5458\u6570\u636e\uff08\u8bad\u7ec3\u96c6\uff09", label="\u6570\u636e\u6765\u6e90")
962
+ d_btn = gr.Button("\U0001f3b2 \u968f\u673a\u63d0\u53d6\u6837\u672c", variant="primary")
963
  d_meta = gr.Markdown()
964
  with gr.Column(scale=3):
965
+ d_q = gr.Textbox(label="\u5b66\u751f\u63d0\u95ee", lines=5, interactive=False)
966
+ d_a = gr.Textbox(label="\u6807\u51c6\u56de\u7b54", lines=5, interactive=False)
967
  d_btn.click(cb_sample, [d_src], [d_meta, d_q, d_a])
968
 
969
+ # ═══ Tab 3: \u653b\u51fb\u9a8c\u8bc1 ═══
970
+ with gr.Tab("\U0001f3af \u653b\u51fb\u9a8c\u8bc1"):
971
+ gr.Markdown("## \u6210\u5458\u63a8\u7406\u653b\u51fb\u4ea4\u4e92\u6f14\u793a\n\n\u9009\u62e9\u653b\u51fb\u76ee\u6807\u4e0e\u6570\u636e\u6e90\uff0c\u7cfb\u7edf\u6267\u884cLoss\u8ba1\u7b97\u5e76\u5224\u5b9a\u6570\u636e\u5f52\u5c5e\u3002")
972
  with gr.Row():
973
  with gr.Column(scale=2):
974
+ a_target = gr.Radio(ATK_CHOICES, value=ATK_CHOICES[0], label="\u653b\u51fb\u76ee\u6807")
975
+ a_src = gr.Radio(["\u6210\u5458\u6570\u636e\uff08\u8bad\u7ec3\u96c6\uff09", "\u975e\u6210\u5458\u6570\u636e\uff08\u6d4b\u8bd5\u96c6\uff09"],
976
+ value="\u6210\u5458\u6570\u636e\uff08\u8bad\u7ec3\u96c6\uff09", label="\u6570\u636e\u6765\u6e90")
977
+ a_idx = gr.Slider(0, 999, step=1, value=12, label="\u6837\u672cID")
978
+ a_btn = gr.Button("\u26a1 \u6267\u884c\u6210\u5458\u63a8\u7406\u653b\u51fb", variant="primary", size="lg")
979
  a_qtxt = gr.Markdown()
980
  with gr.Column(scale=3):
981
+ a_gauge = gr.Plot(label="Loss Decision Boundary")
982
  a_res = gr.Markdown()
983
  a_btn.click(cb_attack, [a_idx, a_src, a_target], [a_qtxt, a_gauge, a_res])
984
 
985
+ # ═══ Tab 4: \u591a\u7ef4\u5ea6\u9632\u5fa1\u5206\u6790\uff08\u6838\u5fc3\u5347\u7ea7\uff01\uff09 ═══
986
+ with gr.Tab("\U0001f6e1\ufe0f \u9632\u5fa1\u5206\u6790"):
987
+
988
+ gr.Markdown("## \u591a\u7ef4\u5ea6\u653b\u9632\u6548\u679c\u5bf9\u6bd4\u5206\u6790")
989
+
990
+ # ---- \u7b2c\u4e00\u7ec4: AUC\u5bf9\u6bd4 + \u96f7\u8fbe\u56fe ----
991
+ gr.Markdown("### 1\ufe0f\u20e3 \u653b\u51fb\u6210\u529f\u7387\u5bf9\u6bd4 (AUC)")
992
+ gr.Markdown("""\
993
+ > **AUC (Area Under ROC Curve)** \u662fMIA\u653b\u51fb\u7684\u6838\u5fc3\u6307\u6807\u3002AUC=0.5\u8868\u793a\u968f\u673a\u731c\u6d4b\uff0c\u8d8a\u9ad8\u8868\u793a\u653b\u51fb\u8d8a\u6210\u529f\u3002
994
+ > - \u57fa\u7ebf\u6a21\u578bAUC=**{0}** \u2192 \u653b\u51fb\u8005\u660e\u663e\u4f18\u4e8e\u968f\u673a\u731c\u6d4b
995
+ > - \u6807\u7b7e\u5e73\u6ed1\u5c06AUC\u4ece{0}\u964d\u81f3{1}\uff0c\u964d\u5e45**{2:.1f}%**
996
+ > - \u8f93\u51fa\u6270\u52a8\u5c06AUC\u4ece{0}\u964d\u81f3{3}\uff0c\u964d\u5e45**{4:.1f}%**
997
+ """.format(f"{bl_auc:.4f}",
998
+ f"{gm('smooth_eps_0.2','auc'):.4f}",
999
+ (bl_auc - gm('smooth_eps_0.2','auc'))/bl_auc*100,
1000
+ f"{gm('perturbation_0.03','auc'):.4f}",
1001
+ (bl_auc - gm('perturbation_0.03','auc'))/bl_auc*100))
1002
+ gr.Plot(value=fig_auc_bar())
1003
+
1004
+ gr.Markdown("### 2\ufe0f\u20e3 \u591a\u6307\u6807\u96f7\u8fbe\u56fe\u5bf9\u6bd4\uff08Baseline vs \u9632\u5fa1\uff09")
1005
+ gr.Markdown("""\
1006
+ > \u96f7\u8fbe\u56fe\u540c\u65f6\u5c55\u793a8\u4e2a\u7ef4\u5ea6\u7684\u653b\u51fb\u80fd\u529b\u3002\u7ea2\u8272\uff08Baseline\uff09\u9762\u79ef\u8d8a\u5927 = \u653b\u51fb\u8d8a\u5f3a\uff1b\u84dd\u8272/\u7eff\u8272\uff08\u9632\u5fa1\uff09\u9762\u79ef\u8d8a\u5c0f = \u9632\u5fa1\u8d8a\u6709\u6548\u3002
1007
+ > \u53ef\u4ee5\u770b\u5230\u9632\u5fa1\u540e\u6240\u6709\u7ef4\u5ea6\u5747\u6709\u7f29\u5c0f\uff0c\u7279\u522b\u662fTPR@\u4f4eFPR\u4e0bLoss Gap\u964d\u5e45\u663e\u8457\u3002
1008
+ """)
1009
+ gr.Plot(value=fig_radar_compare())
1010
+
1011
+ # ---- \u7b2c\u4e8c\u7ec4: ROC\u66f2\u7ebf ----
1012
+ gr.Markdown("### 3\ufe0f\u20e3 ROC\u66f2\u7ebf\u5bf9\u6bd4\uff08\u653b\u51fb\u6548\u679c\u7684\u6700\u76f4\u63a5\u8bc1\u636e\uff09")
1013
+ gr.Markdown("""\
1014
+ > **ROC\u66f2\u7ebf**\u5c55\u793a\u4e86\u5728\u4e0d\u540c\u5224\u5b9a\u9608\u503c\u4e0b\uff0c\u653b\u51fb\u8005\u7684\u771f\u9633\u6027\u7387(TPR)\u4e0e\u5047\u9633\u6027\u7387(FPR)\u7684\u5173\u7cfb\u3002
1015
+ > - **\u66f2\u7ebf\u8d8a\u9760\u8fd1\u5de6\u4e0a\u89d2** = \u653b\u51fb\u8d8a\u6210\u529f\uff08\u9ad8TPR\u4f4eFPR\uff09
1016
+ > - **\u66f2\u7ebf\u8d8a\u63a5\u8fd1\u5bf9\u89d2\u7ebf** = \u653b\u51fb\u8d8a\u63a5\u8fd1\u968f\u673a\u731c\u6d4b = \u9632\u5fa1\u8d8a\u6709\u6548
1017
+ > - \u5de6\u56fe\uff1a\u968f\u7740\u03b5\u589e\u5927\uff0cROC\u66f2\u7ebf\u9010\u6e10\u5411\u5bf9\u89d2\u7ebf\u9760\u62e2\uff0c\u8bc1\u660e\u6807\u7b7e\u5e73\u6ed1\u6709\u6548\u964d\u4f4e\u653b\u51fb\u80fd\u529b
1018
+ > - \u53f3\u56fe\uff1a\u968f\u7740\u03c3\u589e\u5927\uff0c\u540c\u6837\u7684\u8d8b\u52bf\uff0c\u8bc1\u660e\u8f93\u51fa\u6270\u52a8\u6709\u6548\u906e\u853d\u653b\u51fb\u4fe1\u53f7
1019
+ """)
1020
+ gr.Plot(value=fig_roc_curves())
1021
+
1022
+ # ---- \u7b2c\u4e09\u7ec4: TPR@\u4f4eFPR ----
1023
+ gr.Markdown("### 4\ufe0f\u20e3 \u4f4e\u8bef\u62a5\u7387\u4e0b\u7684\u653b\u51fb\u80fd\u529b\uff08TPR@FPR\uff09")
1024
+ gr.Markdown(f"""\
1025
+ > **\u8fd9\u662f\u8861\u91cf\u653b\u51fb\u5371\u5bb3\u7684\u6700\u4e25\u683c\u6307\u6807\u3002** \u5728\u73b0\u5b9e\u4e2d\uff0c\u653b\u51fb\u8005\u901a\u5e38\u91c7\u7528\u4fdd\u5b88\u7b56\u7565\uff08\u4f4e\u8bef\u62a5\uff09\uff0c\u53ea\u5bf9\u201c\u5f88\u6709\u628a\u63e1\u201d\u7684\u6837\u672c\u4e0b\u7ed3\u8bba\u3002
1026
+ >
1027
+ > - **TPR@5%FPR**: \u6bcf\u8bef\u5224100\u4e2a\u65e0\u8f9c\u4eba\u4e2d\u7684\u7ea65\u4e2a\uff0c\u80fd\u6b63\u786e\u8bc6\u522b\u51fa\u591a\u5c11\u771f\u6210\u5458
1028
+ > - **TPR@1%FPR**: \u66f4\u4e25\u683c\uff0c\u6bcf\u8bef\u5224100\u4e2a\u4e2d\u7684\u7ea61\u4e2a
1029
+ >
1030
+ > \u57fa\u7ebf\u6a21\u578b TPR@5%FPR = **{gm('baseline','tpr_at_5fpr'):.4f}**\uff0c\u610f\u5473\u7740\u653b\u51fb\u8005\u5728\u4ec5\u201c\u5192\u72af\u201d5%\u7684\u98ce\u9669\u4e0b\uff0c
1031
+ > \u4ecd\u80fd\u8bc6\u522b\u51fa **{gm('baseline','tpr_at_5fpr')*100:.1f}%** \u7684\u8bad\u7ec3\u6210\u5458\u3002\u9632\u5fa1\u540e\u6b64\u6bd4\u4f8b\u5927\u5e45\u4e0b\u964d\u3002
1032
+ """)
1033
+ gr.Plot(value=fig_tpr_at_low_fpr())
1034
+
1035
+ # ---- \u7b2c\u56db\u7ec4: Loss\u5dee\u8ddd ----
1036
+ gr.Markdown("### 5\ufe0f\u20e3 Loss\u5dee\u8ddd\u5bf9\u6bd4\uff08\u653b\u51fb\u7684\u6839\u6e90\uff09")
1037
+ gr.Markdown("""\
1038
+ > **Loss Gap = \u975e\u6210\u5458\u5e73\u5747Loss - \u6210\u5458\u5e73\u5747Loss**\u3002\u8fd9\u662fMIA\u653b\u51fb\u5f97\u4ee5\u5b9e\u65bd\u7684\u6839\u672c\u539f\u56e0\u3002
1039
+ > Gap\u8d8a\u5927 = \u6210\u5458\u548c\u975e\u6210\u5458\u8d8a\u5bb9\u6613\u533a\u5206 = \u653b\u51fb\u8d8a\u5bb9\u6613\u6210\u529f\u3002
1040
+ > \u9632\u5fa1\u7684\u76ee\u6807\u5c31\u662f\u7f29\u5c0f\u8fd9\u4e2a\u5dee\u8ddd\u3002
1041
+ """)
1042
+ gr.Plot(value=fig_loss_gap_waterfall())
1043
+
1044
+ # ---- \u7b2c\u4e94\u7ec4: \u53c2\u6570\u8d8b\u52bf ----
1045
+ gr.Markdown("### 6\ufe0f\u20e3 \u9632\u5fa1\u53c2\u6570\u4e0e\u6548\u679c\u7684\u5173\u7cfb\uff08\u8d8b\u52bf\u7ebf\uff09")
1046
+ gr.Markdown("""\
1047
+ > \u5de6\u56fe\u5c55\u793a\u6807\u7b7e\u5e73\u6ed1\u53c2\u6570\u03b5\u4e0eAUC/\u6548\u7528\u7684\u53cc\u8f74\u5173\u7cfb\uff1a
1048
+ > - \u7ea2\u7ebf(AUC)\u5355\u8c03\u9012\u51cf \u2192 \u9632\u5fa1\u6301\u7eed\u589e\u5f3a
1049
+ > - \u7eff\u7ebf(\u6548\u7528)\u53cd\u5347 \u2192 \u6b63\u5219\u5316\u63d0\u5347\u6cdb\u5316\uff08\u201c\u53cc\u8d62\u201d\uff09
1050
+ >
1051
+ > \u53f3\u56fe\u5c55\u793a\u8f93\u51fa\u6270\u52a8\u53c2\u6570\u03c3\u4e0eAUC\u7684\u5173\u7cfb\uff1a
1052
+ > - \u7eff\u8272\u586b\u5145\u533a\u57df = AUC\u964d\u4f4e\u91cf = \u9632\u5fa1\u6536\u76ca
1053
+ > - \u03c3\u8d8a\u5927\u964d\u4f4e\u8d8a\u591a\uff0c\u4e14\u6548\u7528\u59cb\u7ec8\u4e0d\u53d8
1054
+ """)
1055
+ gr.Plot(value=fig_auc_trend())
1056
 
1057
+ # ---- \u7b2c\u516d\u7ec4: Loss\u5206\u5e03 ----
1058
+ gr.Markdown("### 7\ufe0f\u20e3 Loss\u5206\u5e03\u53ef\u89c6\u5316")
1059
+ with gr.Accordion("\u6807\u7b7e\u5e73\u6ed1\u6a21\u578b\u7684Loss\u5206\u5e03\uff085\u4e2a\u6a21\u578b\uff09", open=False):
1060
+ gr.Markdown("> \u84dd\u8272=\u6210\u5458\uff0c\u7ea2\u8272=\u975e\u6210\u5458\u3002\u4e24\u8272\u91cd\u53e0\u8d8a\u591a = \u653b\u51fb\u8005\u8d8a\u96be\u533a\u5206 = \u9632\u5fa1\u8d8a\u6709\u6548")
1061
  gr.Plot(value=fig_loss_dist())
1062
+ with gr.Accordion("\u8f93\u51fa\u6270\u52a8\u7684Loss\u5206\u5e03\uff086\u7ec4\u03c3\uff09", open=False):
1063
+ gr.Markdown("> \u5728\u57fa\u7ebf\u6a21\u578bLoss\u4e0a\u52a0\u566a\u58f0\uff0c\u968f\u03c3\u589e\u5927\u5206\u5e03\u66f4\u52a0\u91cd\u53e0")
1064
  gr.Plot(value=fig_perturb_dist())
1065
 
1066
+ # ---- \u5b8c\u6574\u6570\u636e\u8868 ----
1067
+ with gr.Accordion("\U0001f4cb \u5b8c\u6574\u6570\u636e\u8868 + \u9632\u5fa1\u673a\u5236\u8bf4\u660e", open=False):
1068
  gr.Markdown(build_full_table())
1069
+ gr.Markdown("""\
1070
+ ### \u9632\u5fa1\u673a\u5236\u5bf9\u6bd4
1071
 
1072
+ | \u7ef4\u5ea6 | \u6807\u7b7e\u5e73\u6ed1 | \u8f93\u51fa\u6270\u52a8 |
1073
  |---|---|---|
1074
+ | **\u9636\u6bb5** | \u8bad\u7ec3\u671f | \u63a8\u7406\u671f |
1075
+ | **\u539f\u7406** | \u8f6f\u5316\u6807\u7b7e\u964d\u4f4e\u8bb0\u5fc6 | Loss\u52a0\u566a\u906e\u853d\u4fe1\u53f7 |
1076
+ | **\u9700\u91cd\u8bad** | \u662f | \u5426 |
1077
+ | **\u6548\u7528\u5f71\u54cd** | \u6b63\u5219\u5316\u53ef\u80fd\u63d0\u5347 | \u5b8c\u5168\u65e0\u5f71\u54cd |
1078
+ | **\u90e8\u7f72** | \u8bad\u7ec3\u65f6\u4ecb\u5165 | \u5373\u63d2\u5373\u7528 |
1079
 
1080
+ **\u6807\u7b7e\u5e73\u6ed1\u516c\u5f0f**: `y_smooth = (1 - \u03b5) \u00d7 y_onehot + \u03b5 / V`
1081
 
1082
+ **\u8f93\u51fa\u6270\u52a8\u516c\u5f0f**: `L_perturbed = L_original + N(0, \u03c3\u00b2)`
1083
  """)
1084
 
1085
+ # ═══ Tab 5: \u6548\u7528\u8bc4\u4f30 ═══
1086
+ with gr.Tab("\u2696\ufe0f \u6548\u7528\u8bc4\u4f30"):
1087
+ gr.Markdown("## \u6a21\u578b\u6548\u7528\u6d4b\u8bd5\n\n> \u57fa\u4e8e300\u9053\u6570\u5b66\u6d4b\u8bd5\u9898\u8bc4\u4f30\u5404\u7b56\u7565\u7684\u5b9e\u9645\u80fd\u529b\u5f71\u54cd")
1088
+
1089
+ gr.Markdown("### \u6548\u7528\u5bf9\u6bd4\u4e0e\u9690\u79c1-\u6548\u7528\u6743\u8861")
1090
  with gr.Row():
1091
  with gr.Column():
1092
  gr.Plot(value=fig_acc_bar())
1093
  with gr.Column():
1094
  gr.Plot(value=fig_tradeoff())
1095
 
1096
+ gr.Markdown(f"""\
1097
+ ### \u6548\u7528\u5206\u6790
1098
+
1099
+ > **\u5173\u952e\u53d1\u73b0\uff1a\u6807\u7b7e\u5e73\u6ed1\u5b9e\u73b0\u4e86\u201c\u9690\u79c1-\u6548\u7528\u53cc\u8d62\u201d**
1100
+ >
1101
+ > - \u57fa\u7ebf\u6548\u7528: **{bl_acc:.1f}%**
1102
+ > - LS(\u03b5=0.1): **{gu('smooth_eps_0.1'):.1f}%** (\u2191{gu('smooth_eps_0.1')-bl_acc:+.1f}%)\uff0c\u540c\u65f6AUC\u4e0b\u964d\u81f3{gm('smooth_eps_0.1','auc'):.4f}
1103
+ > - LS(\u03b5=0.2): **{gu('smooth_eps_0.2'):.1f}%** (\u2191{gu('smooth_eps_0.2')-bl_acc:+.1f}%)\uff0c\u540c\u65f6AUC\u4e0b\u964d\u81f3{gm('smooth_eps_0.2','auc'):.4f}
1104
+ > - \u8f93\u51fa\u6270\u52a8\uff1a\u6548\u7528\u59cb\u7ec8\u4e3a{bl_acc:.1f}%\uff08\u96f6\u635f\u5931\uff09
1105
+ >
1106
+ > \u6807\u7b7e\u5e73\u6ed1\u7684\u6b63\u5219\u5316\u6548\u5e94\u9632\u6b62\u4e86\u8fc7\u62df\u5408\uff0c\u63d0\u5347\u4e86\u6cdb\u5316\u80fd\u529b\u3002
1107
+ """)
1108
+
1109
+ gr.Markdown("### \u5728\u7ebf\u62bd\u6837\u6f14\u793a")
1110
  with gr.Row():
1111
  with gr.Column(scale=1):
1112
+ e_model = gr.Radio(EVAL_CHOICES, value="\u57fa\u7ebf\u6a21\u578b", label="\u9009\u62e9\u6a21\u578b")
1113
+ e_btn = gr.Button("\U0001f9ea \u968f\u673a\u62bd\u9898\u6d4b\u8bd5", variant="primary")
1114
  with gr.Column(scale=2):
1115
  e_res = gr.Markdown()
1116
  e_btn.click(cb_eval, [e_model], [e_res])
1117
 
1118
+ # ═══ Tab 6: \u7814\u7a76\u7ed3\u8bba ═══
1119
+ with gr.Tab("\U0001f4dd \u7814\u7a76\u7ed3\u8bba"):
1120
+ gr.Markdown(f"""\
1121
+ ## \u6838\u5fc3\u7814\u7a76\u53d1\u73b0
1122
 
1123
  ---
1124
 
1125
+ ### \u4e00\u3001\u6559\u80b2\u5927\u6a21\u578b\u5b58\u5728\u53ef\u91cf\u5316\u7684MIA\u98ce\u9669
1126
 
1127
+ | \u6307\u6807 | \u57fa\u7ebf\u503c | \u542b\u4e49 |
1128
+ |---|---|---|
1129
+ | AUC | **{bl_auc:.4f}** | \u653b\u51fb\u8005\u660e\u663e\u4f18\u4e8e\u968f\u673a\u731c\u6d4b(0.5) |
1130
+ | \u653b\u51fb\u51c6\u786e\u7387 | **{gm('baseline','attack_accuracy')*100:.1f}%** | \u8d85\u8fc7\u534a\u6570\u6837\u672c\u88ab\u6b63\u786e\u5224\u5b9a |
1131
+ | TPR@5%FPR | **{gm('baseline','tpr_at_5fpr'):.4f}** | \u4f4e\u8bef\u62a5\u4e0b\u4ecd\u53ef\u8bc6\u522b{gm('baseline','tpr_at_5fpr')*100:.1f}%\u6210\u5458 |
1132
+ | Loss Gap | **{gm('baseline','loss_gap'):.4f}** | \u6210\u5458\u4e0e\u975e\u6210\u5458\u5b58\u5728\u53ef\u5229\u7528\u7684\u5dee\u5f02 |
1133
 
1134
+ ### \u4e8c\u3001\u6807\u7b7e\u5e73\u6ed1\u9632\u5fa1\u6548\u679c
 
 
 
 
 
1135
 
1136
+ | \u53c2\u6570 | AUC | AUC\u964d\u5e45 | \u6548\u7528 | \u7279\u70b9 |
1137
+ |---|---|---|---|---|
1138
+ | \u03b5=0.02 | {gm('smooth_eps_0.02','auc'):.4f} | {bl_auc-gm('smooth_eps_0.02','auc'):.4f} | {gu('smooth_eps_0.02'):.1f}% | \u8f7b\u5ea6\u9632\u5fa1 |
1139
+ | \u03b5=0.05 | {gm('smooth_eps_0.05','auc'):.4f} | {bl_auc-gm('smooth_eps_0.05','auc'):.4f} | {gu('smooth_eps_0.05'):.1f}% | \u6e29\u548c\u9632\u5fa1 |
1140
+ | \u03b5=0.1 | {gm('smooth_eps_0.1','auc'):.4f} | {bl_auc-gm('smooth_eps_0.1','auc'):.4f} | {gu('smooth_eps_0.1'):.1f}% | **\u63a8\u8350\u914d\u7f6e** |
1141
+ | \u03b5=0.2 | {gm('smooth_eps_0.2','auc'):.4f} | {bl_auc-gm('smooth_eps_0.2','auc'):.4f} | {gu('smooth_eps_0.2'):.1f}% | \u5f3a\u529b\u9632\u5fa1 |
1142
 
1143
+ ### \u4e09\u3001\u8f93\u51fa\u6270\u52a8\u9632\u5fa1\u6548\u679c
1144
 
1145
+ | \u53c2\u6570 | AUC | AUC\u964d\u5e45 | \u6548\u7528 | \u7279\u70b9 |
1146
+ |---|---|---|---|---|
1147
+ | \u03c3=0.005 | {gm('perturbation_0.005','auc'):.4f} | {bl_auc-gm('perturbation_0.005','auc'):.4f} | {bl_acc:.1f}% | \u5fae\u5f31\u6270\u52a8 |
1148
+ | \u03c3=0.01 | {gm('perturbation_0.01','auc'):.4f} | {bl_auc-gm('perturbation_0.01','auc'):.4f} | {bl_acc:.1f}% | \u8f7b\u5ea6\u6270\u52a8 |
1149
+ | \u03c3=0.02 | {gm('perturbation_0.02','auc'):.4f} | {bl_auc-gm('perturbation_0.02','auc'):.4f} | {bl_acc:.1f}% | **\u63a8\u8350\u914d\u7f6e** |
1150
+ | \u03c3=0.03 | {gm('perturbation_0.03','auc'):.4f} | {bl_auc-gm('perturbation_0.03','auc'):.4f} | {bl_acc:.1f}% | \u5f3a\u529b\u6270\u52a8 |
1151
 
1152
+ > **\u96f6\u6548\u7528\u635f\u5931\uff0c\u9002\u5408\u5df2\u90e8\u7f72\u7cfb\u7edf\u7684\u540e\u671f\u52a0\u56fa\u3002**
1153
 
1154
+ ### \u56db\u3001\u6700\u4f73\u5b9e\u8df5\u5efa\u8bae
1155
 
1156
+ > \u4e24\u7c7b\u7b56\u7565\u673a\u5236\u4e92\u8865\uff1a\u6807\u7b7e\u5e73\u6ed1\u4ece\u8bad\u7ec3\u9636\u6bb5\u964d\u4f4e\u8bb0\u5fc6\uff0c\u8f93\u51fa\u6270\u52a8\u4ece\u63a8\u7406\u9636\u6bb5\u906e\u853d\u4fe1\u53f7\u3002
1157
+ > **\u63a8\u8350\u7ec4\u5408: LS(\u03b5=0.1) + OP(\u03c3=0.02)** \u2014 \u517c\u987e\u9690\u79c1\u4fdd\u62a4\u4e0e\u6a21\u578b\u6548\u7528\u3002
1158
  """)
1159
 
1160
  demo.launch()