xiaohy committed on
Commit
bac3c00
·
verified ·
1 Parent(s): 72d2d35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +403 -673
app.py CHANGED
@@ -1,10 +1,6 @@
1
  # ================================================================
2
- # 🎓 教育大模型中的成员推理攻击及其防御研究
3
- # Membership Inference Attack & Defense in Educational LLMs
4
- # ================================================================
5
- # 部署平台:Hugging Face Spaces (永久免费)
6
- # SDK:Gradio
7
- # 硬件:CPU basic (Free) — 不需要 GPU
8
  # ================================================================
9
 
10
  import os
@@ -22,179 +18,129 @@ import gradio as gr
22
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
23
 
24
 
25
- def load_json(relative_path):
26
- """安全加载 JSON 文件"""
27
- path = os.path.join(BASE_DIR, relative_path)
28
- if not os.path.exists(path):
29
- raise FileNotFoundError(f"文件不存在: {path}")
30
- with open(path, 'r', encoding='utf-8') as f:
31
  return json.load(f)
32
 
33
 
34
- # 训练/测试数据
35
  member_data = load_json("data/member.json")
36
  non_member_data = load_json("data/non_member.json")
37
-
38
- # 实验结果
39
  mia_results = load_json("results/mia_results.json")
40
  utility_results = load_json("results/utility_results.json")
41
  perturb_results = load_json("results/perturbation_results.json")
42
  full_results = load_json("results/mia_full_results.json")
43
-
44
- # 项目配置
45
  config = load_json("config.json")
46
 
47
- # 字体设置(兼容 Spaces 环境)
48
- plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'sans-serif']
49
  plt.rcParams['axes.unicode_minus'] = False
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # ========================================
53
- # 2. 图表生成函数
54
  # ========================================
55
 
56
  def make_pie_chart():
57
- """数据集任务分布饼图"""
58
  task_counts = {}
59
  for item in member_data + non_member_data:
60
  t = item.get('task_type', 'unknown')
61
  task_counts[t] = task_counts.get(t, 0) + 1
62
-
63
  name_map = {
64
- 'calculation': 'Calculation (40%)',
65
- 'word_problem': 'Word Problem (30%)',
66
- 'concept': 'Concept Q&A (20%)',
67
- 'error_correction': 'Error Correction (10%)'
68
  }
69
-
70
  labels = [name_map.get(k, k) for k in task_counts]
71
  sizes = list(task_counts.values())
72
  colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']
73
-
74
  fig, ax = plt.subplots(figsize=(8, 6))
75
- wedges, texts, autotexts = ax.pie(
76
- sizes,
77
- labels=labels,
78
- autopct='%1.1f%%',
79
- colors=colors[:len(labels)],
80
- explode=[0.04] * len(labels),
81
- shadow=True,
82
- startangle=90,
83
- textprops={'fontsize': 11}
84
- )
85
- for t in autotexts:
86
- t.set_fontsize(12)
87
- t.set_fontweight('bold')
88
-
89
- ax.set_title(
90
- 'Dataset Task Distribution (2000 samples)',
91
- fontsize=15, fontweight='bold', pad=15
92
  )
 
93
  plt.tight_layout()
94
  return fig
95
 
96
 
97
  def make_loss_distribution():
98
- """Loss 分布直方图(使用真实 loss 数据)"""
99
  plot_items = []
100
- for k, t in [('baseline', 'Baseline'),
101
- ('smooth_0.02', 'Label Smoothing e=0.02'),
102
- ('smooth_0.2', 'Label Smoothing e=0.2')]:
103
  if k in full_results:
104
  auc = mia_results.get(k, {}).get('auc', 0)
105
- plot_items.append((k, f"{t}\nAUC = {auc:.4f}"))
106
-
107
  n = len(plot_items)
108
  if n == 0:
109
  fig, ax = plt.subplots()
110
- ax.text(0.5, 0.5, 'No data available', ha='center', va='center')
111
  return fig
112
-
113
  fig, axes = plt.subplots(1, n, figsize=(6 * n, 5))
114
  if n == 1:
115
  axes = [axes]
116
-
117
  for ax, (k, title) in zip(axes, plot_items):
118
- m_losses = full_results[k]['member_losses']
119
- nm_losses = full_results[k]['non_member_losses']
120
-
121
- all_losses = m_losses + nm_losses
122
- bins = np.linspace(min(all_losses), max(all_losses), 40)
123
-
124
- ax.hist(m_losses, bins=bins, alpha=0.55, color='#4A90D9',
125
- label=f'Members (u={np.mean(m_losses):.3f})', density=True)
126
- ax.hist(nm_losses, bins=bins, alpha=0.55, color='#E74C3C',
127
- label=f'Non-Members (u={np.mean(nm_losses):.3f})', density=True)
128
-
129
- ax.set_title(title, fontsize=13, fontweight='bold')
130
- ax.set_xlabel('Loss Value', fontsize=11)
131
- ax.set_ylabel('Density', fontsize=11)
132
  ax.legend(fontsize=9)
133
  ax.grid(True, linestyle='--', alpha=0.4)
134
-
135
- plt.suptitle(
136
- 'Member vs Non-Member Loss Distribution',
137
- fontsize=16, fontweight='bold', y=1.02
138
- )
139
  plt.tight_layout()
140
  return fig
141
 
142
 
143
  def make_auc_bar():
144
- """所有防御策略 AUC 柱状图"""
145
- methods = []
146
- aucs = []
147
- colors = []
148
-
149
- # MIA 模型结果
150
- for k, name, c in [
151
- ('baseline', 'Baseline', '#95A5A6'),
152
- ('smooth_0.02', 'LS e=0.02', '#5B9BD5'),
153
- ('smooth_0.2', 'LS e=0.2', '#2E5FA1'),
154
- ]:
155
  if k in mia_results:
156
  methods.append(name)
157
  aucs.append(mia_results[k]['auc'])
158
  colors.append(c)
159
-
160
- # 输出扰动结果
161
- for k, name, c in [
162
- ('perturbation_0.01', 'OP s=0.01', '#27AE60'),
163
- ('perturbation_0.015', 'OP s=0.015', '#1E8449'),
164
- ('perturbation_0.02', 'OP s=0.02', '#145A32'),
165
- ]:
166
  if k in perturb_results:
167
  methods.append(name)
168
  aucs.append(perturb_results[k]['auc'])
169
  colors.append(c)
170
-
171
  fig, ax = plt.subplots(figsize=(11, 6))
172
- bars = ax.bar(
173
- methods, aucs, color=colors, width=0.55,
174
- edgecolor='white', linewidth=1.5
175
- )
176
-
177
- # 数值标签
178
- for bar, auc_val in zip(bars, aucs):
179
- ax.text(
180
- bar.get_x() + bar.get_width() / 2,
181
- bar.get_height() + 0.004,
182
- f'{auc_val:.3f}',
183
- ha='center', va='bottom', fontsize=13, fontweight='bold'
184
- )
185
-
186
- # 参考线
187
- baseline_auc = mia_results.get('baseline', {}).get('auc', 0.63)
188
- ax.axhline(y=0.5, color='red', linestyle='--', linewidth=2,
189
- label='Random Guess (AUC=0.5)')
190
- ax.axhline(y=baseline_auc, color='black', linestyle=':',
191
- linewidth=1.5, label='Baseline Risk')
192
-
193
- ax.set_ylabel('MIA Attack AUC', fontsize=13)
194
- ax.set_title(
195
- 'Comparison of All Defense Mechanisms',
196
- fontsize=15, fontweight='bold'
197
- )
198
  ax.set_ylim(0.45, max(aucs) + 0.06 if aucs else 1.0)
199
  ax.legend(fontsize=11)
200
  ax.grid(axis='y', linestyle='--', alpha=0.4)
@@ -204,71 +150,35 @@ def make_auc_bar():
204
 
205
 
206
  def make_tradeoff():
207
- """隐私-效用权衡散点图"""
208
  fig, ax = plt.subplots(figsize=(10, 7))
209
  points = []
210
-
211
- # MIA 模型
212
  for k, name, marker, color, sz in [
213
- ('baseline', 'Baseline (No Defense)', 'o', 'black', 180),
214
- ('smooth_0.02', 'Label Smoothing e=0.02', 's', '#5B9BD5', 160),
215
- ('smooth_0.2', 'Label Smoothing e=0.2', 's', '#2E5FA1', 160),
216
- ]:
217
  if k in mia_results and k in utility_results:
218
- points.append({
219
- 'name': name,
220
- 'auc': mia_results[k]['auc'],
221
- 'acc': utility_results[k]['accuracy'],
222
- 'marker': marker, 'color': color, 'size': sz
223
- })
224
-
225
- # 输出扰动(准确率 = 基线准确率)
226
  base_acc = utility_results.get('baseline', {}).get('accuracy', 0.633)
227
  for k, name, marker, color, sz in [
228
- ('perturbation_0.01', 'Output Perturb s=0.01', '^', '#27AE60', 170),
229
- ('perturbation_0.02', 'Output Perturb s=0.02', '^', '#145A32', 170),
230
- ]:
231
  if k in perturb_results:
232
- points.append({
233
- 'name': name,
234
- 'auc': perturb_results[k]['auc'],
235
- 'acc': base_acc,
236
- 'marker': marker, 'color': color, 'size': sz
237
- })
238
-
239
  for p in points:
240
- ax.scatter(
241
- p['acc'], p['auc'],
242
- label=p['name'], marker=p['marker'], color=p['color'],
243
- s=p['size'], edgecolors='white', linewidth=1.5, zorder=5
244
- )
245
-
246
- ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.7,
247
- label='Random Guess (AUC=0.5)')
248
-
249
- ax.set_xlabel('Model Utility (Test Accuracy)', fontsize=13, fontweight='bold')
250
- ax.set_ylabel('Privacy Risk (MIA AUC)', fontsize=13, fontweight='bold')
251
- ax.set_title('Privacy-Utility Trade-off Analysis', fontsize=15, fontweight='bold')
252
-
253
- # 自动坐标范围
254
  all_acc = [p['acc'] for p in points]
255
  all_auc = [p['auc'] for p in points]
256
  if all_acc and all_auc:
257
  ax.set_xlim(min(all_acc) - 0.03, max(all_acc) + 0.05)
258
  ax.set_ylim(min(min(all_auc), 0.5) - 0.02, max(all_auc) + 0.02)
259
-
260
- # 区域标注
261
- ax.text(
262
- min(all_acc) - 0.02, max(all_auc) + 0.01,
263
- 'High Risk / Low Utility', fontsize=10, color='red',
264
- bbox=dict(facecolor='red', alpha=0.1)
265
- )
266
- ax.text(
267
- max(all_acc) + 0.03, min(min(all_auc), 0.5) + 0.005,
268
- 'Ideal Zone', fontsize=10, color='green',
269
- bbox=dict(facecolor='green', alpha=0.1)
270
- )
271
-
272
  ax.legend(loc='upper right', frameon=True, shadow=True, fontsize=10)
273
  ax.grid(True, alpha=0.3)
274
  plt.tight_layout()
@@ -276,48 +186,27 @@ def make_tradeoff():
276
 
277
 
278
  def make_accuracy_bar():
279
- """准确率对比柱状图"""
280
- names = []
281
- accs = []
282
- colors = []
283
-
284
- for k, name, c in [
285
- ('baseline', 'Baseline', '#95A5A6'),
286
- ('smooth_0.02', 'LS e=0.02', '#5B9BD5'),
287
- ('smooth_0.2', 'LS e=0.2', '#2E5FA1'),
288
- ]:
289
  if k in utility_results:
290
  names.append(name)
291
  accs.append(utility_results[k]['accuracy'] * 100)
292
  colors.append(c)
293
-
294
- # 输出扰动准确率 = 基线准确率
295
  base_pct = utility_results.get('baseline', {}).get('accuracy', 0) * 100
296
- for k, name, c in [
297
- ('perturbation_0.01', 'OP s=0.01', '#27AE60'),
298
- ('perturbation_0.02', 'OP s=0.02', '#145A32'),
299
- ]:
300
  if k in perturb_results:
301
  names.append(name)
302
  accs.append(base_pct)
303
  colors.append(c)
304
-
305
  fig, ax = plt.subplots(figsize=(11, 6))
306
- bars = ax.bar(
307
- names, accs, color=colors, width=0.5,
308
- edgecolor='white', linewidth=1.5
309
- )
310
-
311
  for bar, acc in zip(bars, accs):
312
- ax.text(
313
- bar.get_x() + bar.get_width() / 2,
314
- bar.get_height() + 0.8,
315
- f'{acc:.1f}%',
316
- ha='center', va='bottom', fontsize=13, fontweight='bold'
317
- )
318
-
319
  ax.set_ylabel('Accuracy (%)', fontsize=13)
320
- ax.set_title('Model Utility Comparison (300 Math Questions)', fontsize=15, fontweight='bold')
321
  ax.set_ylim(0, 100)
322
  ax.grid(axis='y', alpha=0.3)
323
  plt.xticks(rotation=10)
@@ -325,390 +214,302 @@ def make_accuracy_bar():
325
  return fig
326
 
327
 
 
 
 
 
 
 
 
 
 
328
  # ========================================
329
- # 3. 界面回调函数
330
  # ========================================
331
 
332
  def show_random_sample(data_type):
333
- """随机展示一条数据样本"""
334
- if "member" in data_type.lower() or "成员" in data_type:
335
  data = member_data
336
  else:
337
  data = non_member_data
338
-
339
  sample = data[np.random.randint(0, len(data))]
340
  meta = sample['metadata']
341
-
342
- task_name_map = {
343
- 'calculation': 'Calculation (基础计算)',
344
- 'word_problem': 'Word Problem (应用题)',
345
- 'concept': 'Concept Q&A (概念问答)',
346
- 'error_correction': 'Error Correction (错题订正)'
347
  }
348
-
349
- info = f"""### 📋 Sample Metadata (Privacy Fields)
350
-
351
- | Field | Value |
352
- |-------|-------|
353
- | **Name (姓名)** | {meta['name']} |
354
- | **Student ID (学号)** | {meta['student_id']} |
355
- | **Class (班级)** | {meta['class']} |
356
- | **Score (成绩)** | {meta['score']} |
357
- | **Task Type** | {task_name_map.get(sample['task_type'], sample['task_type'])} |
358
-
359
- > ⚠️ The above are **student privacy fields** that attackers attempt to infer!
360
- """
361
  return info, sample['question'], sample['answer']
362
 
363
 
364
  def run_mia_demo(sample_index, data_type):
365
- """MIA 攻击演示(使用实验中保存的真实 loss 数据)"""
 
 
 
 
 
 
366
 
367
- is_member = ("Member" in data_type or "成员" in data_type)
368
- idx = min(int(sample_index), 999)
369
- data = member_data if is_member else non_member_data
370
  sample = data[idx]
371
 
372
- # 从保存的完整 loss 数据中取出对应样本的真实 loss
373
  bl = full_results.get('baseline', {})
374
  if is_member and idx < len(bl.get('member_losses', [])):
375
  loss = bl['member_losses'][idx]
376
  elif not is_member and idx < len(bl.get('non_member_losses', [])):
377
  loss = bl['non_member_losses'][idx]
378
  else:
379
- # 兜底:用统计信息模拟
380
- m_mean_fb = mia_results.get('baseline', {}).get('member_loss_mean', 0.19)
381
- nm_mean_fb = mia_results.get('baseline', {}).get('non_member_loss_mean', 0.23)
382
  if is_member:
383
- loss = float(np.random.normal(m_mean_fb, 0.02))
384
  else:
385
- loss = float(np.random.normal(nm_mean_fb, 0.02))
386
-
387
- # 计算阈值
388
- m_mean = mia_results.get('baseline', {}).get('member_loss_mean', 0.19)
389
- nm_mean = mia_results.get('baseline', {}).get('non_member_loss_mean', 0.23)
390
- threshold = (m_mean + nm_mean) / 2.0
391
 
 
392
  pred_member = (loss < threshold)
393
  actual_member = is_member
394
  attack_correct = (pred_member == actual_member)
395
 
396
- # ===== Loss 位置可视化 =====
397
  bar_total = 40
398
- if nm_mean > m_mean:
399
- ratio = (loss - m_mean) / (nm_mean - m_mean)
400
  else:
401
  ratio = 0.5
402
  ratio = max(0.0, min(1.0, ratio))
403
  pos = int(bar_total * ratio)
 
404
 
405
- bar_left = "=" * pos
406
- bar_right = "=" * (bar_total - pos)
407
- bar_visual = bar_left + "V" + bar_right
408
-
409
- # 阈值位置标记
410
- threshold_pos = int(bar_total * 0.5)
411
- threshold_bar = " " * threshold_pos + "|"
412
 
413
- # 判定文字
414
  if pred_member:
415
- pred_text = "🔴 **MEMBER** (Loss < Threshold → Model is too familiar)"
416
  else:
417
- pred_text = "🟢 **NON-MEMBER** (Loss >= Threshold → Model is not familiar)"
418
 
419
  if actual_member:
420
- actual_text = "🔴 **MEMBER** (This data WAS used in training)"
421
  else:
422
- actual_text = "🟢 **NON-MEMBER** (This data was NOT used in training)"
423
 
424
  if attack_correct and pred_member and actual_member:
425
- result_text = "✅ **ATTACK SUCCESS Privacy Leaked!**"
426
- result_emoji = "⚠️"
427
  elif attack_correct:
428
- result_text = "✅ **Correct Judgment**"
429
- result_emoji = "✅"
430
  else:
431
- result_text = "❌ **Attack Failed**"
432
- result_emoji = "❌"
433
-
434
- result_md = f"""## 🔍 MIA Attack Result
435
-
436
- # ===== Build the visualization block as a separate string =====
437
- viz_block = (
438
- " Member Zone (Low Loss) Non-Member Zone (High Loss)\n"
439
- " <--------------------|----------------------->\n"
440
- " Threshold\n"
441
- "\n"
442
- f" [{bar_visual}]\n"
443
- " | | |\n"
444
- " Member Mean Threshold Non-Member Mean\n"
445
- f" {m_mean:.4f} {threshold:.4f} {nm_mean:.4f}\n"
446
- "\n"
447
- f" Current Loss = {loss:.4f}\n"
448
- f" Position: {position_text}\n"
449
- )
450
 
451
- # ===== Build the warning/safe message =====
452
  if pred_member:
453
- warning_msg = (
454
- f"⚠️ **Privacy Risk!** This sample Loss = {loss:.4f} "
455
- f"is BELOW the threshold ({threshold:.4f}). "
456
- "The model 'remembers' this data — student privacy may be compromised!"
457
  )
458
  else:
459
- warning_msg = (
460
- f"✅ This sample Loss = {loss:.4f} "
461
- f"is ABOVE the threshold ({threshold:.4f}). "
462
- "The model shows no special memorization of this data."
463
  )
464
 
465
- # ===== Assemble the final Markdown =====
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  result_md = (
467
- "## 🔍 MIA Attack Result\n\n"
468
- "### 📊 Loss Computation\n\n"
469
- "| Metric | Value |\n"
470
- "|--------|-------|\n"
471
- f"| **Sample Loss** | `{loss:.6f}` |\n"
472
- f"| **Decision Threshold** | `{threshold:.6f}` |\n"
473
- f"| **Member Mean Loss** | `{m_mean:.6f}` |\n"
474
- f"| **Non-Member Mean Loss** | `{nm_mean:.6f}` |\n\n"
475
- "### 📏 Loss Position Visualization\n\n"
476
  "```\n"
477
- f"{viz_block}"
478
- "```\n\n"
479
- "### 🎯 Attack Judgment\n\n"
480
- "| Item | Result |\n"
481
- "|------|--------|\n"
482
- f"| **Attacker Prediction** | {result_icon} {pred_text} |\n"
483
- f"| **Actual Identity** | {actual_text} |\n"
484
- f"| **Attack Outcome** | {result_icon} **{result_text}** |\n\n"
485
- "### 💡 How It Works\n\n"
486
- "The model produces **lower Loss** on data it was **trained on** "
487
- "(it's more \"confident\"). The attacker exploits this statistical difference:\n\n"
488
- f"- Loss **below** threshold `{threshold:.4f}` Predicted as **training member** → ⚠️ Privacy risk\n"
489
- f"- Loss **above** threshold `{threshold:.4f}` Predicted as **non-member** → ✅ Relatively safe\n\n"
490
- f"{warning_msg}\n\n"
491
- "> 📌 *This demo uses real Loss values saved from the experiment (not real-time inference).*\n"
492
  )
493
 
494
- question_display = f"**📝 Sample #{idx}:**\n\n{sample['question'][:600]}"
495
- return question_display, result_md
496
-
497
- ### 🎯 Attack Judgment
498
-
499
- | Item | Result |
500
- |------|--------|
501
- | **Attacker Prediction** | {pred_text} |
502
- | **Actual Identity** | {actual_text} |
503
- | **Attack Outcome** | {result_emoji} {result_text} |
504
-
505
- ### 💡 How It Works
506
-
507
- The model produces **lower Loss** on data it was **trained on** (it's more "confident").
508
- The attacker exploits this statistical difference:
509
-
510
- - Loss **below** threshold `{threshold:.4f}` → Predicted as **training member** → ⚠️ Privacy risk
511
- - Loss **above** threshold `{threshold:.4f}` → Predicted as **non-member** → ✅ Relatively safe
512
-
513
- {"⚠️ **Privacy Risk!** This sample's Loss = " + f"{loss:.4f}" + " is BELOW the threshold. The model 'remembers' this data — student privacy may be compromised!" if pred_member else "✅ This sample's Loss = " + f"{loss:.4f}" + " is ABOVE the threshold. The model shows no special memorization of this data."}
514
-
515
- > 📌 *This demo uses real Loss values saved from the experiment (not real-time inference).*
516
- """
517
-
518
- question_display = f"**📝 Sample #{idx}:**\n\n{sample['question'][:600]}"
519
  return question_display, result_md
520
 
521
 
522
  # ========================================
523
- # 4. 构建完整 Gradio 界面
524
  # ========================================
525
 
526
- custom_css = """
527
- .gradio-container {
528
- max-width: 1280px !important;
529
- margin: auto !important;
530
- }
531
- .tab-nav button {
532
- font-size: 15px !important;
533
- padding: 10px 18px !important;
534
- font-weight: 600 !important;
535
- }
536
- footer {
537
- display: none !important;
538
- }
539
- """
540
-
541
- # 预先取出常用数值(避免在 Markdown 中报错)
542
- bl_auc = mia_results.get('baseline', {}).get('auc', 0)
543
- s002_auc = mia_results.get('smooth_0.02', {}).get('auc', 0)
544
- s02_auc = mia_results.get('smooth_0.2', {}).get('auc', 0)
545
- op001_auc = perturb_results.get('perturbation_0.01', {}).get('auc', 0)
546
- op0015_auc = perturb_results.get('perturbation_0.015', {}).get('auc', 0)
547
- op002_auc = perturb_results.get('perturbation_0.02', {}).get('auc', 0)
548
-
549
- bl_acc = utility_results.get('baseline', {}).get('accuracy', 0) * 100
550
- s002_acc = utility_results.get('smooth_0.02', {}).get('accuracy', 0) * 100
551
- s02_acc = utility_results.get('smooth_0.2', {}).get('accuracy', 0) * 100
552
-
553
- model_name_str = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct')
554
- gpu_name_str = config.get('gpu_name', 'T4')
555
- data_size_str = config.get('data_size', 2000)
556
- setup_date_str = config.get('setup_date', 'N/A')
557
-
558
 
559
  with gr.Blocks(
560
- title="Education LLM Privacy Attack & Defense",
561
- theme=gr.themes.Soft(
562
- primary_hue="blue",
563
- secondary_hue="sky",
564
- neutral_hue="slate"
565
- ),
566
  css=custom_css
567
  ) as demo:
568
 
569
  # ============================
570
- # Header
571
  # ============================
572
- gr.Markdown(f"""
573
- # 🎓 Membership Inference Attack & Defense in Educational LLMs
574
- ### 教育大模型中的成员推理攻击及其防御研究
575
-
576
- ---
577
-
578
- > **Goal**: Investigate privacy leakage risks in educational LLMs and evaluate **Label Smoothing** + **Output Perturbation** as defense strategies.
579
-
580
- > **Tech Stack**: `Qwen2.5-Math-1.5B` · `LoRA Fine-tuning` · `Loss-based MIA` · `Label Smoothing` · `Output Perturbation`
581
- """)
582
 
583
  # ============================
584
- # Tab 1: Project Overview
585
  # ============================
586
- with gr.Tab("🏠 Project Overview"):
587
-
588
- overview_md = (
589
- "## 📖 Research Background\n\n"
590
- "As LLMs are increasingly deployed in education (tutoring systems, personalized learning),\n"
591
- "they inevitably process student **private data** (names, IDs, grades).\n\n"
592
- "**Membership Inference Attack (MIA)** can determine whether a data sample was used to train\n"
593
- "the model, potentially exposing student privacy.\n\n"
594
  "---\n\n"
595
- "## 🔬 Research Design\n\n"
596
- "| Phase | Content | Details |\n"
597
- "|-------|---------|--------|\n"
598
- "| 📂 Data | 2000 math tutoring dialogues | Contains names, student IDs, grades |\n"
599
- "| 🧠 Training | Qwen2.5-Math + LoRA | Baseline + 2 label smoothing models |\n"
600
- "| ⚔️ Attack | Loss-based MIA | Classify members by output loss |\n"
601
- "| 🛡️ Train-time Defense | Label Smoothing (e=0.02, 0.2) | Regularization during training |\n"
602
- "| 🛡️ Inference-time Defense | Output Perturbation (s=0.01~0.02) | Add noise at inference |\n"
603
- "| 📊 Evaluation | Privacy-Utility Trade-off | AUC + Accuracy |\n\n"
604
  "---\n\n"
605
- "## ⚙️ Experiment Configuration\n\n"
606
- "| Item | Value |\n"
607
- "|------|-------|\n"
608
- f"| **Base Model** | {model_name_str} |\n"
609
- "| **Fine-tuning** | LoRA (r=8, alpha=16) |\n"
610
- "| **Training Epochs** | 10 |\n"
611
- f"| **Dataset Size** | {data_size_str} (1000 member + 1000 non-member) |\n"
612
- f"| **GPU** | {gpu_name_str} |\n"
613
- f"| **Date** | {setup_date_str} |\n\n"
614
  "---\n\n"
615
- "## 📐 Technical Pipeline\n\n"
616
  "```\n"
617
- "+-------------+ +-------------------+ +-----------+ +-------------------+ +------------+\n"
618
- "| Data Gen | --> | Baseline Training | --> | MIA Attack| --> | Defense Deploy | --> | Evaluation |\n"
619
- "| (2000) | | (LoRA fine-tune) | | (Loss) | | (LS + OP) | | (AUC+Acc) |\n"
620
- "+-------------+ +--------+----------+ +-----------+ +-------------------+ +------------+\n"
621
- " | |\n"
622
- " +-- Label Smoothing Training --------------+\n"
623
- " (e=0.02, e=0.2)\n"
624
  "```\n"
625
  )
626
 
627
- gr.Markdown(overview_md)
628
- # ============================
629
- # Tab 2: Data Explorer
630
  # ============================
631
- with gr.Tab("📊 Data Explorer"):
632
- gr.Markdown("""
633
- ## 📂 Dataset Overview
634
-
635
- - **Member data (Training set)**: 1000 samples — used to train the model
636
- - **Non-member data (Test set)**: 1000 samples — NOT used in training
637
- - Each sample contains **student privacy**: name, student ID, class, score
638
- """)
 
639
 
640
  with gr.Row():
641
  with gr.Column(scale=1):
642
- gr.Markdown("### 📊 Task Distribution")
643
  gr.Plot(value=make_pie_chart())
644
-
645
  with gr.Column(scale=1):
646
- gr.Markdown("### 🔍 Random Sample Viewer")
647
- data_type_selector = gr.Radio(
648
- choices=[
649
- "Member Data (Training Set / 成员数据)",
650
- "Non-Member Data (Test Set / 非成员数据)"
651
- ],
652
- value="Member Data (Training Set / 成员数据)",
653
- label="Select Data Type"
654
- )
655
- sample_btn = gr.Button(
656
- "🎲 Random Sample", variant="primary"
657
  )
 
658
 
659
  sample_info = gr.Markdown()
660
  with gr.Row():
661
- sample_q = gr.Textbox(
662
- label="📝 Student Question", lines=7, interactive=False
663
- )
664
- sample_a = gr.Textbox(
665
- label="💡 Model Answer", lines=7, interactive=False
666
- )
667
 
668
  sample_btn.click(
669
  fn=show_random_sample,
670
- inputs=[data_type_selector],
671
  outputs=[sample_info, sample_q, sample_a]
672
  )
673
 
674
  # ============================
675
- # Tab 3: MIA Attack Demo
676
  # ============================
677
- with gr.Tab("⚔️ MIA Attack Demo"):
678
- gr.Markdown("""
679
- ## ⚔️ Live Membership Inference Attack
680
-
681
- **How it works**: The model produces **lower Loss** on training data (it's more "confident").
682
- The attacker uses a **threshold** on Loss to predict membership.
683
-
684
- ### 📌 Steps:
685
- 1️⃣ Select data source (Member / Non-Member)
686
- 2️⃣ Choose a sample index (0-999)
687
- 3️⃣ Click **"Run Attack"** to see the result
688
- """)
689
 
690
  with gr.Row():
691
  with gr.Column(scale=1):
692
  atk_data_type = gr.Radio(
693
- choices=[
694
- "Member Data (成员数据)",
695
- "Non-Member Data (非成员数据)"
696
- ],
697
- value="Member Data (成员数据)",
698
- label="📂 Data Source"
699
  )
700
  atk_index = gr.Slider(
701
  minimum=0, maximum=999, step=1, value=0,
702
- label="📌 Sample Index (0-999)"
703
- )
704
- atk_btn = gr.Button(
705
- "⚔️ Run MIA Attack",
706
- variant="primary",
707
- size="lg"
708
  )
709
-
710
  with gr.Column(scale=1):
711
- atk_question = gr.Markdown(label="Sample Content")
712
 
713
  atk_result = gr.Markdown()
714
 
@@ -719,245 +520,174 @@ The attacker uses a **threshold** on Loss to predict membership.
719
  )
720
 
721
  # ============================
722
- # Tab 4: Defense Comparison
723
  # ============================
724
- with gr.Tab("🛡️ Defense Comparison"):
725
- gr.Markdown("""
726
- ## 🛡️ Defense Strategy Comparison
727
-
728
- | Strategy | Type | Mechanism | Pros | Cons |
729
- |----------|------|-----------|------|------|
730
- | **Label Smoothing** | Train-time | Soften labels to prevent overfitting | Reduces memorization | May hurt utility |
731
- | **Output Perturbation** | Inference-time | Add Gaussian noise to Loss | Zero utility loss | Only masks signal |
732
- """)
733
 
734
  with gr.Row():
735
  with gr.Column():
736
- gr.Markdown("### 📊 AUC Comparison (All Defenses)")
737
  gr.Plot(value=make_auc_bar())
738
-
739
  with gr.Column():
740
- gr.Markdown("### 📈 Loss Distribution (Baseline vs LS)")
741
  gr.Plot(value=make_loss_distribution())
742
 
743
- # Results table
744
- gr.Markdown("### 📋 Detailed Results")
745
-
746
- def risk_badge(auc_val):
747
- if auc_val > 0.62:
748
- return "🔴 High"
749
- elif auc_val > 0.55:
750
- return "🟡 Medium"
751
- else:
752
- return "🟢 Low"
753
-
754
- table = "| Strategy | Type | AUC | Privacy Risk |\n"
755
- table += "|----------|------|-----|-------------|\n"
756
-
757
- for k, name, cat in [
758
- ('baseline', 'Baseline (No Defense)', '—'),
759
- ('smooth_0.02', 'Label Smoothing e=0.02', 'Train-time'),
760
- ('smooth_0.2', 'Label Smoothing e=0.2', 'Train-time'),
761
- ]:
762
  if k in mia_results:
763
  a = mia_results[k]['auc']
764
- table += f"| {name} | {cat} | **{a:.4f}** | {risk_badge(a)} |\n"
765
-
766
- for k, name in [
767
- ('perturbation_0.01', 'Output Perturbation s=0.01'),
768
- ('perturbation_0.015', 'Output Perturbation s=0.015'),
769
- ('perturbation_0.02', 'Output Perturbation s=0.02'),
770
- ]:
771
  if k in perturb_results:
772
  a = perturb_results[k]['auc']
773
- table += f"| {name} | Inference-time | **{a:.4f}** | {risk_badge(a)} |\n"
774
-
775
  gr.Markdown(table)
776
 
777
  # ============================
778
- # Tab 5: Output Perturbation
779
  # ============================
780
- with gr.Tab("🔊 Output Perturbation"):
781
- gr.Markdown(f"""
782
- ## 🔊 Output Perturbation Defense
783
-
784
- ### 📌 Core Idea
785
-
786
- At **inference time**, add **Gaussian noise** to the model's output Loss:
787
-
788
- **Loss_perturbed = Loss_original + N(0, sigma^2)**
789
-
790
- ### Key Advantage
791
- - **No retraining needed** (zero deployment cost)
792
- - **No utility loss** (accuracy stays exactly the same)
793
- - Noise level sigma can be tuned dynamically
794
-
795
- ### 📊 Experiment Results
796
-
797
- | sigma | AUC | AUC Reduction | Accuracy | Note |
798
- |-------|-----|--------------|----------|------|
799
- | 0 (Baseline) | **{bl_auc:.4f}** | — | {bl_acc:.1f}% | No defense |
800
- | 0.01 | **{op001_auc:.4f}** | ↓{bl_auc - op001_auc:.4f} | {bl_acc:.1f}% (unchanged) | Mild |
801
- | 0.015 | **{op0015_auc:.4f}** | ↓{bl_auc - op0015_auc:.4f} | {bl_acc:.1f}% (unchanged) | Moderate |
802
- | 0.02 | **{op002_auc:.4f}** | ↓{bl_auc - op002_auc:.4f} | {bl_acc:.1f}% (unchanged) | **Recommended** |
803
-
804
- ### 💡 Key Finding
805
-
806
- > Output Perturbation (s=0.02) reduces AUC from {bl_auc:.4f} to **{op002_auc:.4f}**
807
- > while keeping accuracy at **{bl_acc:.1f}%** — truly a **zero-cost defense**!
808
- """)
809
 
810
  # ============================
811
- # Tab 6: Utility Evaluation
812
  # ============================
813
- with gr.Tab("📝 Utility Evaluation"):
814
- gr.Markdown("""
815
- ## 📐 Model Utility Evaluation
816
-
817
- > Defense must not sacrifice too much utility.
818
- > Test set: **300 math questions** covering calculation, word problems, and concept Q&A.
819
- """)
820
 
821
  with gr.Row():
822
  with gr.Column():
823
- gr.Markdown("### 📊 Accuracy Comparison")
824
  gr.Plot(value=make_accuracy_bar())
825
  with gr.Column():
826
- gr.Markdown("### ⚖️ Privacy-Utility Trade-off")
827
  gr.Plot(value=make_tradeoff())
828
 
829
- # Utility table
830
- ut = "| Strategy | Accuracy | AUC | Risk | Utility Impact |\n"
831
- ut += "|----------|----------|-----|------|---------------|\n"
832
-
833
- for k, name in [
834
- ('baseline', 'Baseline'),
835
- ('smooth_0.02', 'LS e=0.02'),
836
- ('smooth_0.2', 'LS e=0.2'),
837
- ]:
838
  if k in utility_results and k in mia_results:
839
  acc = utility_results[k]['accuracy'] * 100
840
  auc = mia_results[k]['auc']
841
- impact = "—" if k == 'baseline' else (
842
- " Improved" if acc > bl_acc else "⚠️ Decreased"
843
- )
844
- ut += f"| {name} | **{acc:.1f}%** | {auc:.4f} | {risk_badge(auc)} | {impact} |\n"
845
-
846
- for k, name in [
847
- ('perturbation_0.01', 'OP s=0.01'),
848
- ('perturbation_0.02', 'OP s=0.02'),
849
- ]:
850
  if k in perturb_results:
851
- ut += (f"| {name} | **{bl_acc:.1f}%** | "
852
- f"{perturb_results[k]['auc']:.4f} | "
853
- f"{risk_badge(perturb_results[k]['auc'])} | ✅ No change |\n")
854
-
855
  gr.Markdown(ut)
856
 
857
  # ============================
858
- # Tab 7: Paper Figures
859
  # ============================
860
- with gr.Tab("📄 Paper Figures"):
861
- gr.Markdown("## 📄 Publication-Quality Figures (300 DPI)")
862
-
863
- figure_items = [
864
- ("fig1_loss_distribution_comparison.png",
865
- "Figure 1: Loss Distribution — Baseline vs Label Smoothing"),
866
- ("fig2_privacy_utility_tradeoff_fixed.png",
867
- "Figure 2: Privacy-Utility Trade-off Analysis"),
868
- ("fig3_defense_comparison_bar.png",
869
- "Figure 3: Defense Mechanism AUC Comparison"),
870
- ]
871
-
872
- for filename, caption in figure_items:
873
- path = os.path.join(BASE_DIR, "figures", filename)
874
  if os.path.exists(path):
875
- gr.Markdown(f"### {caption}")
876
  gr.Image(value=path, show_label=False, height=420)
877
  gr.Markdown("---")
878
  else:
879
- gr.Markdown(
880
- f"### {caption}\n\n"
881
- f"> ⚠️ File not found: `figures/{filename}` — "
882
- f"this figure is optional."
883
- )
884
 
885
  # ============================
886
- # Tab 8: Conclusions
887
  # ============================
888
- with gr.Tab("🎓 Conclusions"):
889
- gr.Markdown(f"""
890
- ## 📝 Research Conclusions
891
-
892
- ---
893
-
894
- ### 🔬 Core Findings
895
-
896
- #### Finding 1: MIA poses a real threat to educational LLMs
897
- Baseline AUC = **{bl_auc:.4f}** (significantly above random guess of 0.5).
898
- Attackers can infer student data membership with high probability.
899
-
900
- #### Finding 2: Label Smoothing effectively reduces risk
901
-
902
- | Strategy | AUC | Accuracy | Verdict |
903
- |----------|-----|----------|---------|
904
- | Baseline | {bl_auc:.4f} | {bl_acc:.1f}% | High privacy risk |
905
- | LS e=0.02 | {s002_auc:.4f} | {s002_acc:.1f}% | ✅ **Recommended** |
906
- | LS e=0.2 | {s02_auc:.4f} | {s02_acc:.1f}% | ⚠️ Strong defense, possible utility impact |
907
-
908
- #### Finding 3: Output Perturbation is a zero-cost defense
909
- sigma=0.02 reduces AUC from {bl_auc:.4f} to **{op002_auc:.4f}** with **zero accuracy loss**.
910
-
911
- #### Finding 4: Best practice — combine both defenses
912
- > **Recommended**: LS e=0.02 (training) + OP s=0.02 (inference) = **Dual Protection**
913
-
914
- ---
915
-
916
- ### 🎤 Defense Presentation Script
917
-
918
- > "This study uses a math tutoring scenario with Qwen2.5-Math-1.5B + LoRA fine-tuning.
919
- > The baseline model shows AUC={bl_auc:.4f}, indicating significant privacy leakage risk.
920
- > We evaluate two complementary defenses: **Label Smoothing** (train-time, e=0.02)
921
- > and **Output Perturbation** (inference-time, s=0.02).
922
- > Output perturbation achieves **zero utility loss**, making it ideal for practical deployment.
923
- > The study reveals the fundamental privacy-utility trade-off in educational AI."
924
-
925
- ---
926
-
927
- ### 📚 Innovation Points
928
-
929
- 1. **Novel scenario** — Focus on educational LLM privacy (not general NLP)
930
- 2. **Dual defense** — Both train-time and inference-time strategies
931
- 3. **Practical** — Label smoothing = 1 line of code; Output perturbation = 1 line of code
932
- 4. **Comprehensive** — Attack + Defense + Utility + Trade-off analysis
933
-
934
- ---
935
-
936
- ### 🔮 Future Work
937
-
938
- - Explore **Differential Privacy (DP-SGD)** for stronger guarantees
939
- - Test **Shadow Model Attack** and other advanced MIA variants
940
- - Validate on real educational datasets
941
- - Investigate **Federated Learning** for educational model privacy
942
- """)
943
 
944
  # ============================
945
- # Footer
946
  # ============================
947
- gr.Markdown("""
948
- ---
949
- <center>
950
-
951
- 🎓 **Membership Inference Attack & Defense in Educational LLMs**
952
-
953
- `Qwen2.5-Math-1.5B` · `LoRA` · `MIA` · `Label Smoothing` · `Output Perturbation` · `Gradio`
954
-
955
- </center>
956
- """)
957
-
958
 
959
  # ========================================
960
- # 5. Launch
961
  # ========================================
962
- demo.launch()
963
-
 
1
  # ================================================================
2
+ # 🎓 教育大模型中的成员推理攻击及其防御研究
3
+ # 完整演示界面 - Hugging Face Spaces 永久部署版
 
 
 
 
4
  # ================================================================
5
 
6
  import os
 
18
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19
 
20
 
21
+ def load_json(path):
22
+ full = os.path.join(BASE_DIR, path)
23
+ with open(full, 'r', encoding='utf-8') as f:
 
 
 
24
  return json.load(f)
25
 
26
 
 
27
  member_data = load_json("data/member.json")
28
  non_member_data = load_json("data/non_member.json")
 
 
29
  mia_results = load_json("results/mia_results.json")
30
  utility_results = load_json("results/utility_results.json")
31
  perturb_results = load_json("results/perturbation_results.json")
32
  full_results = load_json("results/mia_full_results.json")
 
 
33
  config = load_json("config.json")
34
 
35
+ # 字体
36
+ plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
37
  plt.rcParams['axes.unicode_minus'] = False
38
 
39
+ # 预取数值
40
+ bl_auc = mia_results.get('baseline', {}).get('auc', 0)
41
+ s002_auc = mia_results.get('smooth_0.02', {}).get('auc', 0)
42
+ s02_auc = mia_results.get('smooth_0.2', {}).get('auc', 0)
43
+ op001_auc = perturb_results.get('perturbation_0.01', {}).get('auc', 0)
44
+ op0015_auc = perturb_results.get('perturbation_0.015', {}).get('auc', 0)
45
+ op002_auc = perturb_results.get('perturbation_0.02', {}).get('auc', 0)
46
+
47
+ bl_acc = utility_results.get('baseline', {}).get('accuracy', 0) * 100
48
+ s002_acc = utility_results.get('smooth_0.02', {}).get('accuracy', 0) * 100
49
+ s02_acc = utility_results.get('smooth_0.2', {}).get('accuracy', 0) * 100
50
+
51
+ bl_m_mean = mia_results.get('baseline', {}).get('member_loss_mean', 0.19)
52
+ bl_nm_mean = mia_results.get('baseline', {}).get('non_member_loss_mean', 0.23)
53
+
54
+ model_name_str = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct')
55
+ gpu_name_str = config.get('gpu_name', 'T4')
56
+ data_size_str = config.get('data_size', 2000)
57
+ setup_date_str = config.get('setup_date', 'N/A')
58
+
59
 
60
  # ========================================
61
+ # 2. 图表函数
62
  # ========================================
63
 
64
  def make_pie_chart():
 
65
  task_counts = {}
66
  for item in member_data + non_member_data:
67
  t = item.get('task_type', 'unknown')
68
  task_counts[t] = task_counts.get(t, 0) + 1
 
69
  name_map = {
70
+ 'calculation': 'Calculation',
71
+ 'word_problem': 'Word Problem',
72
+ 'concept': 'Concept Q&A',
73
+ 'error_correction': 'Error Correction'
74
  }
 
75
  labels = [name_map.get(k, k) for k in task_counts]
76
  sizes = list(task_counts.values())
77
  colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']
 
78
  fig, ax = plt.subplots(figsize=(8, 6))
79
+ ax.pie(
80
+ sizes, labels=labels, autopct='%1.1f%%',
81
+ colors=colors[:len(labels)], explode=[0.04] * len(labels),
82
+ shadow=True, startangle=90, textprops={'fontsize': 11}
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  )
84
+ ax.set_title('Task Distribution (2000 samples)', fontsize=14, fontweight='bold', pad=15)
85
  plt.tight_layout()
86
  return fig
87
 
88
 
89
  def make_loss_distribution():
 
90
  plot_items = []
91
+ for k, t in [('baseline', 'Baseline'), ('smooth_0.02', 'LS e=0.02'), ('smooth_0.2', 'LS e=0.2')]:
 
 
92
  if k in full_results:
93
  auc = mia_results.get(k, {}).get('auc', 0)
94
+ plot_items.append((k, t + " (AUC=" + f"{auc:.4f}" + ")"))
 
95
  n = len(plot_items)
96
  if n == 0:
97
  fig, ax = plt.subplots()
98
+ ax.text(0.5, 0.5, 'No data', ha='center')
99
  return fig
 
100
  fig, axes = plt.subplots(1, n, figsize=(6 * n, 5))
101
  if n == 1:
102
  axes = [axes]
 
103
  for ax, (k, title) in zip(axes, plot_items):
104
+ m = full_results[k]['member_losses']
105
+ nm = full_results[k]['non_member_losses']
106
+ bins = np.linspace(min(min(m), min(nm)), max(max(m), max(nm)), 40)
107
+ ax.hist(m, bins=bins, alpha=0.55, color='#4A90D9',
108
+ label='Members (u=' + f"{np.mean(m):.3f}" + ')', density=True)
109
+ ax.hist(nm, bins=bins, alpha=0.55, color='#E74C3C',
110
+ label='Non-Members (u=' + f"{np.mean(nm):.3f}" + ')', density=True)
111
+ ax.set_title(title, fontsize=12, fontweight='bold')
112
+ ax.set_xlabel('Loss')
113
+ ax.set_ylabel('Density')
 
 
 
 
114
  ax.legend(fontsize=9)
115
  ax.grid(True, linestyle='--', alpha=0.4)
 
 
 
 
 
116
  plt.tight_layout()
117
  return fig
118
 
119
 
120
  def make_auc_bar():
121
+ methods, aucs, colors = [], [], []
122
+ for k, name, c in [('baseline', 'Baseline', '#95A5A6'), ('smooth_0.02', 'LS e=0.02', '#5B9BD5'),
123
+ ('smooth_0.2', 'LS e=0.2', '#2E5FA1')]:
 
 
 
 
 
 
 
 
124
  if k in mia_results:
125
  methods.append(name)
126
  aucs.append(mia_results[k]['auc'])
127
  colors.append(c)
128
+ for k, name, c in [('perturbation_0.01', 'OP s=0.01', '#27AE60'),
129
+ ('perturbation_0.015', 'OP s=0.015', '#1E8449'),
130
+ ('perturbation_0.02', 'OP s=0.02', '#145A32')]:
 
 
 
 
131
  if k in perturb_results:
132
  methods.append(name)
133
  aucs.append(perturb_results[k]['auc'])
134
  colors.append(c)
 
135
  fig, ax = plt.subplots(figsize=(11, 6))
136
+ bars = ax.bar(methods, aucs, color=colors, width=0.55, edgecolor='white', linewidth=1.5)
137
+ for bar, a in zip(bars, aucs):
138
+ ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.004,
139
+ f'{a:.3f}', ha='center', va='bottom', fontsize=13, fontweight='bold')
140
+ ax.axhline(y=0.5, color='red', linestyle='--', linewidth=2, label='Random Guess (0.5)')
141
+ ax.axhline(y=bl_auc, color='black', linestyle=':', linewidth=1.5, label='Baseline')
142
+ ax.set_ylabel('MIA AUC', fontsize=13)
143
+ ax.set_title('All Defense Mechanisms - AUC', fontsize=14, fontweight='bold')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  ax.set_ylim(0.45, max(aucs) + 0.06 if aucs else 1.0)
145
  ax.legend(fontsize=11)
146
  ax.grid(axis='y', linestyle='--', alpha=0.4)
 
150
 
151
 
152
  def make_tradeoff():
 
153
  fig, ax = plt.subplots(figsize=(10, 7))
154
  points = []
 
 
155
  for k, name, marker, color, sz in [
156
+ ('baseline', 'Baseline', 'o', 'black', 180),
157
+ ('smooth_0.02', 'LS e=0.02', 's', '#5B9BD5', 160),
158
+ ('smooth_0.2', 'LS e=0.2', 's', '#2E5FA1', 160)]:
 
159
  if k in mia_results and k in utility_results:
160
+ points.append({'name': name, 'auc': mia_results[k]['auc'],
161
+ 'acc': utility_results[k]['accuracy'],
162
+ 'marker': marker, 'color': color, 'size': sz})
 
 
 
 
 
163
  base_acc = utility_results.get('baseline', {}).get('accuracy', 0.633)
164
  for k, name, marker, color, sz in [
165
+ ('perturbation_0.01', 'OP s=0.01', '^', '#27AE60', 170),
166
+ ('perturbation_0.02', 'OP s=0.02', '^', '#145A32', 170)]:
 
167
  if k in perturb_results:
168
+ points.append({'name': name, 'auc': perturb_results[k]['auc'],
169
+ 'acc': base_acc, 'marker': marker, 'color': color, 'size': sz})
 
 
 
 
 
170
  for p in points:
171
+ ax.scatter(p['acc'], p['auc'], label=p['name'], marker=p['marker'],
172
+ color=p['color'], s=p['size'], edgecolors='white', linewidth=1.5, zorder=5)
173
+ ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.7, label='Random Guess (0.5)')
174
+ ax.set_xlabel('Accuracy', fontsize=13, fontweight='bold')
175
+ ax.set_ylabel('MIA AUC (Privacy Risk)', fontsize=13, fontweight='bold')
176
+ ax.set_title('Privacy-Utility Trade-off', fontsize=14, fontweight='bold')
 
 
 
 
 
 
 
 
177
  all_acc = [p['acc'] for p in points]
178
  all_auc = [p['auc'] for p in points]
179
  if all_acc and all_auc:
180
  ax.set_xlim(min(all_acc) - 0.03, max(all_acc) + 0.05)
181
  ax.set_ylim(min(min(all_auc), 0.5) - 0.02, max(all_auc) + 0.02)
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  ax.legend(loc='upper right', frameon=True, shadow=True, fontsize=10)
183
  ax.grid(True, alpha=0.3)
184
  plt.tight_layout()
 
186
 
187
 
188
  def make_accuracy_bar():
189
+ names, accs, colors = [], [], []
190
+ for k, name, c in [('baseline', 'Baseline', '#95A5A6'), ('smooth_0.02', 'LS e=0.02', '#5B9BD5'),
191
+ ('smooth_0.2', 'LS e=0.2', '#2E5FA1')]:
 
 
 
 
 
 
 
192
  if k in utility_results:
193
  names.append(name)
194
  accs.append(utility_results[k]['accuracy'] * 100)
195
  colors.append(c)
 
 
196
  base_pct = utility_results.get('baseline', {}).get('accuracy', 0) * 100
197
+ for k, name, c in [('perturbation_0.01', 'OP s=0.01', '#27AE60'),
198
+ ('perturbation_0.02', 'OP s=0.02', '#145A32')]:
 
 
199
  if k in perturb_results:
200
  names.append(name)
201
  accs.append(base_pct)
202
  colors.append(c)
 
203
  fig, ax = plt.subplots(figsize=(11, 6))
204
+ bars = ax.bar(names, accs, color=colors, width=0.5, edgecolor='white', linewidth=1.5)
 
 
 
 
205
  for bar, acc in zip(bars, accs):
206
+ ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.8,
207
+ f'{acc:.1f}%', ha='center', va='bottom', fontsize=13, fontweight='bold')
 
 
 
 
 
208
  ax.set_ylabel('Accuracy (%)', fontsize=13)
209
+ ax.set_title('Model Utility (300 Math Questions)', fontsize=14, fontweight='bold')
210
  ax.set_ylim(0, 100)
211
  ax.grid(axis='y', alpha=0.3)
212
  plt.xticks(rotation=10)
 
214
  return fig
215
 
216
 
217
+ def risk_badge(auc_val):
218
+ if auc_val > 0.62:
219
+ return "🔴 高"
220
+ elif auc_val > 0.55:
221
+ return "🟡 中"
222
+ else:
223
+ return "🟢 低"
224
+
225
+
226
  # ========================================
227
+ # 3. 回调函数
228
  # ========================================
229
 
230
  def show_random_sample(data_type):
231
+ if data_type == "成员数据(训练集)":
 
232
  data = member_data
233
  else:
234
  data = non_member_data
 
235
  sample = data[np.random.randint(0, len(data))]
236
  meta = sample['metadata']
237
+ task_map = {
238
+ 'calculation': '基础计算',
239
+ 'word_problem': '应用题',
240
+ 'concept': '概念问答',
241
+ 'error_correction': '错题订正'
 
242
  }
243
+ info = (
244
+ "### 📋 样本元信息(隐私字段)\n\n"
245
+ "| 字段 | 值 |\n"
246
+ "|------|-----|\n"
247
+ "| **姓名** | " + str(meta['name']) + " |\n"
248
+ "| **学号** | " + str(meta['student_id']) + " |\n"
249
+ "| **班级** | " + str(meta['class']) + " |\n"
250
+ "| **成绩** | " + str(meta['score']) + " 分 |\n"
251
+ "| **任务类型** | " + task_map.get(sample['task_type'], sample['task_type']) + " |\n\n"
252
+ "> ⚠️ 以上就是攻击者试图推断的**学生隐私信息**!\n"
253
+ )
 
 
254
  return info, sample['question'], sample['answer']
255
 
256
 
257
  def run_mia_demo(sample_index, data_type):
258
+ # 判断成员/非成员
259
+ if data_type == "成员数据(训练集)":
260
+ is_member = True
261
+ data = member_data
262
+ else:
263
+ is_member = False
264
+ data = non_member_data
265
 
266
+ idx = min(int(sample_index), len(data) - 1)
 
 
267
  sample = data[idx]
268
 
269
+ # 取真实 loss
270
  bl = full_results.get('baseline', {})
271
  if is_member and idx < len(bl.get('member_losses', [])):
272
  loss = bl['member_losses'][idx]
273
  elif not is_member and idx < len(bl.get('non_member_losses', [])):
274
  loss = bl['non_member_losses'][idx]
275
  else:
 
 
 
276
  if is_member:
277
+ loss = float(np.random.normal(bl_m_mean, 0.02))
278
  else:
279
+ loss = float(np.random.normal(bl_nm_mean, 0.02))
 
 
 
 
 
280
 
281
+ threshold = (bl_m_mean + bl_nm_mean) / 2.0
282
  pred_member = (loss < threshold)
283
  actual_member = is_member
284
  attack_correct = (pred_member == actual_member)
285
 
286
+ # 可视化进度条
287
  bar_total = 40
288
+ if bl_nm_mean > bl_m_mean:
289
+ ratio = (loss - bl_m_mean) / (bl_nm_mean - bl_m_mean)
290
  else:
291
  ratio = 0.5
292
  ratio = max(0.0, min(1.0, ratio))
293
  pos = int(bar_total * ratio)
294
+ bar_visual = "=" * pos + "V" + "=" * (bar_total - pos)
295
 
296
+ if pred_member:
297
+ position_text = "成员区(左侧)⚠️ 隐私风险"
298
+ else:
299
+ position_text = "非成员区(右侧)✅ 相对安全"
 
 
 
300
 
 
301
  if pred_member:
302
+ pred_text = "🔴 是训练成员(Loss < 阈值,模型过于熟悉)"
303
  else:
304
+ pred_text = "🟢 不是训练成员(Loss >= 阈值,模型不熟悉)"
305
 
306
  if actual_member:
307
+ actual_text = "🔴 是训练成员(此数据参与了训练)"
308
  else:
309
+ actual_text = "🟢 不是训练成员(此数据未参与训练)"
310
 
311
  if attack_correct and pred_member and actual_member:
312
+ result_text = "✅ **攻击成功隐私泄露!**"
 
313
  elif attack_correct:
314
+ result_text = "✅ **判断正确**"
 
315
  else:
316
+ result_text = "❌ **攻击失误**"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
 
318
  if pred_member:
319
+ warning = (
320
+ "⚠️ **隐私风险!** 此样本 Loss = " + f"{loss:.4f}"
321
+ + " 低于阈值(" + f"{threshold:.4f}"
322
+ + "),模型对它过于'熟悉',学生隐私可能被推断!"
323
  )
324
  else:
325
+ warning = (
326
+ "✅ 此样本 Loss = " + f"{loss:.4f}"
327
+ + " 高于阈值(" + f"{threshold:.4f}"
328
+ + "),模型对其无特殊记忆,隐私相对安全。"
329
  )
330
 
331
+ viz = (
332
+ " 成员区(低Loss) 非成员区(高Loss)\n"
333
+ " <-----------------------|------------------------->\n"
334
+ " 阈值\n"
335
+ "\n"
336
+ " [" + bar_visual + "]\n"
337
+ " | | |\n"
338
+ " 成员均值 阈值 非成员均值\n"
339
+ " " + f"{bl_m_mean:.4f}" + " "
340
+ + f"{threshold:.4f}" + " "
341
+ + f"{bl_nm_mean:.4f}" + "\n"
342
+ "\n"
343
+ " 当前 Loss = " + f"{loss:.4f}" + "\n"
344
+ " 位置: " + position_text + "\n"
345
+ )
346
+
347
  result_md = (
348
+ "## 🔍 MIA 攻击结果\n\n"
349
+ "### 📊 Loss 计算\n\n"
350
+ "| 指标 | 值 |\n"
351
+ "|------|-----|\n"
352
+ "| **样本 Loss** | `" + f"{loss:.6f}" + "` |\n"
353
+ "| **判定阈值** | `" + f"{threshold:.6f}" + "` |\n"
354
+ "| **成员平均 Loss** | `" + f"{bl_m_mean:.6f}" + "` |\n"
355
+ "| **非成员平均 Loss** | `" + f"{bl_nm_mean:.6f}" + "` |\n\n"
356
+ "### 📏 Loss 位置可视化\n\n"
357
  "```\n"
358
+ + viz
359
+ + "```\n\n"
360
+ "### 🎯 攻击判定\n\n"
361
+ "| 项目 | 结果 |\n"
362
+ "|------|------|\n"
363
+ "| **攻击者预测** | " + pred_text + " |\n"
364
+ "| **实际身份** | " + actual_text + " |\n"
365
+ "| **攻击结果** | " + result_text + " |\n\n"
366
+ "### 💡 原理说明\n\n"
367
+ "模型对**训练过的数据**产生**更低的 Loss**(更\"自信\"),"
368
+ "攻击者利用这一统计差异推断成员身份:\n\n"
369
+ "- Loss **低于** 阈值 " + f"{threshold:.4f}" + " 判定为**训练成员** → ⚠️ 隐私风险\n"
370
+ "- Loss **高于** 阈值 " + f"{threshold:.4f}" + " 判定为**非成员** → ✅ 相对安全\n\n"
371
+ + warning + "\n\n"
372
+ "> 📌 本演示使用实验中保存的真实 Loss 数据。\n"
373
  )
374
 
375
+ question_display = "**📝 " + str(idx) + " 号样本:**\n\n" + sample['question'][:600]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  return question_display, result_md
377
 
378
 
379
  # ========================================
380
+ # 4. 构建界面
381
  # ========================================
382
 
383
+ custom_css = (
384
+ ".gradio-container { max-width: 1280px !important; margin: auto !important; }\n"
385
+ ".tab-nav button { font-size: 15px !important; padding: 10px 18px !important; font-weight: 600 !important; }\n"
386
+ "footer { display: none !important; }\n"
387
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
 
389
  with gr.Blocks(
390
+ title="教育大模型隐私攻防实验",
391
+ theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky", neutral_hue="slate"),
 
 
 
 
392
  css=custom_css
393
  ) as demo:
394
 
395
  # ============================
396
+ # 顶部标题
397
  # ============================
398
+ gr.Markdown(
399
+ "# 🎓 教育大模型中的成员推理攻击及其防御研究\n"
400
+ "### Membership Inference Attack & Defense in Educational LLMs\n\n"
401
+ "---\n\n"
402
+ "> **研究目标**:探究教育场景下大语言模型的隐私泄露风险,"
403
+ "验证**标签平滑**和**输出扰动**两种防御策略的效果与局限。\n\n"
404
+ "> **技术栈**:`Qwen2.5-Math-1.5B` · `LoRA微调` · `Loss-based MIA` · "
405
+ "`标签平滑` · `输出扰动`\n"
406
+ )
 
407
 
408
  # ============================
409
+ # Tab 1: 项目概览
410
  # ============================
411
+ with gr.Tab("🏠 项目概览"):
412
+ gr.Markdown(
413
+ "## 📖 研究背景\n\n"
414
+ "随着大语言模型在教育领域广泛应用(智能辅导、个性化学习等),"
415
+ "模型训练不可避免地接触到学生隐私数据(姓名、学号、成绩等)。\n\n"
416
+ "**成员推理攻击(MIA)** 可以判断某条数据是否用于模型训练,从而推断学生隐私。\n\n"
 
 
417
  "---\n\n"
418
+ "## 🔬 研究设计\n\n"
419
+ "| 阶段 | 内容 | 说明 |\n"
420
+ "|------|------|------|\n"
421
+ "| 📂 数据准备 | 2000条小学数学辅导对话 | 含姓名、学号、成绩等隐私 |\n"
422
+ "| 🧠 模型训练 | Qwen2.5-Math + LoRA | 基线 + 两个标签平滑模型 |\n"
423
+ "| ⚔️ 攻击测试 | Loss-based MIA | 基于Loss判断成员身份 |\n"
424
+ "| 🛡️ 训练期防御 | 标签平滑 (e=0.02, 0.2) | 训练时正则化 |\n"
425
+ "| 🛡️ 推理期防御 | 输出扰动 (s=0.01~0.02) | 推理时加噪声 |\n"
426
+ "| 📊 综合评估 | 隐私-效用权衡 | AUC + 准确率 |\n\n"
427
  "---\n\n"
428
+ "## ⚙️ 实验配置\n\n"
429
+ "| 配置项 | 值 |\n"
430
+ "|--------|-----|\n"
431
+ "| **基座模型** | " + model_name_str + " |\n"
432
+ "| **微调方法** | LoRA (r=8, alpha=16) |\n"
433
+ "| **训练轮数** | 10 epochs |\n"
434
+ "| **数据总量** | " + str(data_size_str) + " 条(成员1000 + 非成员1000)|\n"
435
+ "| **GPU** | " + gpu_name_str + " |\n"
436
+ "| **实验日期** | " + setup_date_str + " |\n\n"
437
  "---\n\n"
438
+ "## 📐 技术路线\n\n"
439
  "```\n"
440
+ "+----------+ +-----------+ +----------+ +----------+ +----------+\n"
441
+ "| 数据生成 |--->| 基线训练 |--->| MIA攻击 |--->| 防御部署 |--->| 综合评估 |\n"
442
+ "| (2000) | | (LoRA) | | (Loss) | | (LS+OP) | | (AUC+Acc)|\n"
443
+ "+----------+ +-----+-----+ +----------+ +----------+ +----------+\n"
444
+ " | |\n"
445
+ " +--- 标签平滑模型训练 -----------+\n"
446
+ " (e=0.02, e=0.2)\n"
447
  "```\n"
448
  )
449
 
 
 
 
450
  # ============================
451
+ # Tab 2: 数据展示
452
+ # ============================
453
+ with gr.Tab("📊 数据展示"):
454
+ gr.Markdown(
455
+ "## 📂 数据集概况\n\n"
456
+ "- **成员数据(训练集)**:1000条,用于训练模型\n"
457
+ "- **非成员数据(测试集)**:1000条,不参与训练\n"
458
+ "- 每条数据包含**学生隐私信息**(姓名、学号、班级、成绩)\n"
459
+ )
460
 
461
  with gr.Row():
462
  with gr.Column(scale=1):
463
+ gr.Markdown("### 📊 任务类型分布")
464
  gr.Plot(value=make_pie_chart())
 
465
  with gr.Column(scale=1):
466
+ gr.Markdown("### 🔍 随机查看样本")
467
+ data_sel = gr.Radio(
468
+ choices=["成员数据(训练集)", "非成员数据(测试集)"],
469
+ value="成员数据(训练集)",
470
+ label="选择数据类型"
 
 
 
 
 
 
471
  )
472
+ sample_btn = gr.Button("🎲 随机抽取样本", variant="primary")
473
 
474
  sample_info = gr.Markdown()
475
  with gr.Row():
476
+ sample_q = gr.Textbox(label="📝 学生提问", lines=7, interactive=False)
477
+ sample_a = gr.Textbox(label="💡 模型回答", lines=7, interactive=False)
 
 
 
 
478
 
479
  sample_btn.click(
480
  fn=show_random_sample,
481
+ inputs=[data_sel],
482
  outputs=[sample_info, sample_q, sample_a]
483
  )
484
 
485
  # ============================
486
+ # Tab 3: MIA 攻击演示
487
  # ============================
488
+ with gr.Tab("⚔️ MIA攻击演示"):
489
+ gr.Markdown(
490
+ "## ⚔️ 实时成员推理攻击\n\n"
491
+ "**攻击原理**:模型对训练过的数据产生更低的Loss(更\"自信\"),"
492
+ "攻击者利用Loss阈值判断成员身份。\n\n"
493
+ "### 📌 操作步骤\n"
494
+ "1️⃣ 选择数据来源(成员/非成员)\n"
495
+ "2️⃣ 拖动滑块选择样本编号\n"
496
+ "3️⃣ 点击 **\"执行攻击\"** 查看结果\n"
497
+ )
 
 
498
 
499
  with gr.Row():
500
  with gr.Column(scale=1):
501
  atk_data_type = gr.Radio(
502
+ choices=["成员数据(训练集)", "非成员数据(测试集)"],
503
+ value="成员数据(训练集)",
504
+ label="📂 数据来源"
 
 
 
505
  )
506
  atk_index = gr.Slider(
507
  minimum=0, maximum=999, step=1, value=0,
508
+ label="📌 样本编号 (0-999)"
 
 
 
 
 
509
  )
510
+ atk_btn = gr.Button("⚔️ 执行MIA攻击", variant="primary", size="lg")
511
  with gr.Column(scale=1):
512
+ atk_question = gr.Markdown()
513
 
514
  atk_result = gr.Markdown()
515
 
 
520
  )
521
 
522
  # ============================
523
+ # Tab 4: 防御对比
524
  # ============================
525
+ with gr.Tab("🛡️ 防御对比"):
526
+ gr.Markdown(
527
+ "## 🛡️ 防御策略效果对比\n\n"
528
+ "| 策略 | 类型 | 原理 | 优点 | 缺点 |\n"
529
+ "|------|------|------|------|------|\n"
530
+ "| **标签平滑** | 训练期 | 软化标签防止过拟合 | 从根源降低记忆 | 可能损失效用 |\n"
531
+ "| **输出扰动** | 推理期 | Loss加高斯噪声 | 零效用损失 | 只遮蔽统计信号 |\n"
532
+ )
 
533
 
534
  with gr.Row():
535
  with gr.Column():
536
+ gr.Markdown("### 📊 所有防御策略AUC对比")
537
  gr.Plot(value=make_auc_bar())
 
538
  with gr.Column():
539
+ gr.Markdown("### 📈 Loss分布对比")
540
  gr.Plot(value=make_loss_distribution())
541
 
542
+ # 结果表格
543
+ table = (
544
+ "### 📋 完整实验结果\n\n"
545
+ "| 策略 | 类型 | AUC | 隐私风险 |\n"
546
+ "|------|------|-----|----------|\n"
547
+ )
548
+ for k, name, cat in [('baseline', '基线(无防御)', '—'),
549
+ ('smooth_0.02', '标签平滑 e=0.02', '训练期'),
550
+ ('smooth_0.2', '标签平滑 e=0.2', '训练期')]:
 
 
 
 
 
 
 
 
 
 
551
  if k in mia_results:
552
  a = mia_results[k]['auc']
553
+ table += "| " + name + " | " + cat + " | **" + f"{a:.4f}" + "** | " + risk_badge(a) + " |\n"
554
+ for k, name in [('perturbation_0.01', '输出扰动 s=0.01'),
555
+ ('perturbation_0.015', '输出扰动 s=0.015'),
556
+ ('perturbation_0.02', '输出扰动 s=0.02')]:
 
 
 
557
  if k in perturb_results:
558
  a = perturb_results[k]['auc']
559
+ table += "| " + name + " | 推理期 | **" + f"{a:.4f}" + "** | " + risk_badge(a) + " |\n"
 
560
  gr.Markdown(table)
561
 
562
  # ============================
563
+ # Tab 5: 输出扰动
564
  # ============================
565
+ with gr.Tab("🔊 输出扰动"):
566
+ gr.Markdown(
567
+ "## 🔊 输出扰动防御详解\n\n"
568
+ "### 📌 核心思想\n\n"
569
+ "在**推理阶段**,对模型返回的Loss值添加**高斯噪声**:\n\n"
570
+ "**Loss_new = Loss_original + N(0, sigma^2)**\n\n"
571
+ "### 最大优势\n"
572
+ "- **不需要重新训练模型**(部署成本为零)\n"
573
+ "- **不影响模型效用**(准确率完全不变)\n"
574
+ "- 噪声强度sigma可以动态调节\n\n"
575
+ "### 📊 实验结果\n\n"
576
+ "| sigma | AUC | 相比基线降低 | 准确率 | 说明 |\n"
577
+ "|-------|-----|-------------|--------|------|\n"
578
+ "| 0(基线)| **" + f"{bl_auc:.4f}" + "** | — | " + f"{bl_acc:.1f}" + "% | 无防御 |\n"
579
+ "| 0.01 | **" + f"{op001_auc:.4f}" + "** | ↓" + f"{bl_auc - op001_auc:.4f}" + " | " + f"{bl_acc:.1f}" + "%(不变)| 温和 |\n"
580
+ "| 0.015 | **" + f"{op0015_auc:.4f}" + "** | ↓" + f"{bl_auc - op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}" + "%(不变)| 适中 |\n"
581
+ "| 0.02 | **" + f"{op002_auc:.4f}" + "** | ↓" + f"{bl_auc - op002_auc:.4f}" + " | " + f"{bl_acc:.1f}" + "%(不变)| **推荐** |\n\n"
582
+ "### 💡 核心发现\n\n"
583
+ "> 输出扰动 (s=0.02) 将AUC从 " + f"{bl_auc:.4f}" + " 降至 **" + f"{op002_auc:.4f}" + "**,"
584
+ "准确率 **" + f"{bl_acc:.1f}" + "%** 完全不变 — 真正的**零成本防御**!\n"
585
+ )
 
 
 
 
 
 
 
 
586
 
587
  # ============================
588
+ # Tab 6: 效用评估
589
  # ============================
590
+ with gr.Tab("📝 效用评估"):
591
+ gr.Markdown(
592
+ "## 📐 模型效用评估\n\n"
593
+ "> 防御不能\"只管隐私不管效果\"。本节评估各模型在 **300道数学题** 上的准确率。\n"
594
+ )
 
 
595
 
596
  with gr.Row():
597
  with gr.Column():
598
+ gr.Markdown("### 📊 准确率对比")
599
  gr.Plot(value=make_accuracy_bar())
600
  with gr.Column():
601
+ gr.Markdown("### ⚖️ 隐私-效用权衡")
602
  gr.Plot(value=make_tradeoff())
603
 
604
+ ut = (
605
+ "### 📋 效用评估详情\n\n"
606
+ "| 策略 | 准确率 | AUC | 风险 | 效用影响 |\n"
607
+ "|------|--------|-----|------|----------|\n"
608
+ )
609
+ for k, name in [('baseline', '基线'), ('smooth_0.02', '标签平滑 e=0.02'),
610
+ ('smooth_0.2', '标签平滑 e=0.2')]:
 
 
611
  if k in utility_results and k in mia_results:
612
  acc = utility_results[k]['accuracy'] * 100
613
  auc = mia_results[k]['auc']
614
+ impact = "—" if k == 'baseline' else ("✅ 提升" if acc > bl_acc else "⚠️ 下降")
615
+ ut += "| " + name + " | **" + f"{acc:.1f}" + "%** | " + f"{auc:.4f}" + " | " + risk_badge(auc) + " | " + impact + " |\n"
616
+ for k, name in [('perturbation_0.01', '输出扰动 s=0.01'), ('perturbation_0.02', '输出扰动 s=0.02')]:
 
 
 
 
 
 
617
  if k in perturb_results:
618
+ ut += "| " + name + " | **" + f"{bl_acc:.1f}" + "%** | " + f"{perturb_results[k]['auc']:.4f}" + " | " + risk_badge(perturb_results[k]['auc']) + " | ✅ 无影响 |\n"
 
 
 
619
  gr.Markdown(ut)
620
 
621
  # ============================
622
+ # Tab 7: 论文图表
623
  # ============================
624
+ with gr.Tab("📄 论文图表"):
625
+ gr.Markdown("## 📄 学术级论文图表(300 DPI)")
626
+
627
+ for fn, cap in [("fig1_loss_distribution_comparison.png", "图1:Loss分布对比"),
628
+ ("fig2_privacy_utility_tradeoff_fixed.png", "图2:隐私-效用权衡"),
629
+ ("fig3_defense_comparison_bar.png", "图3:防御效果柱状图")]:
630
+ path = os.path.join(BASE_DIR, "figures", fn)
 
 
 
 
 
 
 
631
  if os.path.exists(path):
632
+ gr.Markdown("### " + cap)
633
  gr.Image(value=path, show_label=False, height=420)
634
  gr.Markdown("---")
635
  else:
636
+ gr.Markdown("### " + cap + "\n\n> ⚠️ 文件未找到:" + fn + "(不影响核心功能)")
 
 
 
 
637
 
638
  # ============================
639
+ # Tab 8: 研究结论
640
  # ============================
641
+ with gr.Tab("🎓 研究结论"):
642
+ gr.Markdown(
643
+ "## 📝 核心结论\n\n"
644
+ "---\n\n"
645
+ "### 发现一:MIA对教育大模型构成现实威胁\n\n"
646
+ "基线模型AUC = **" + f"{bl_auc:.4f}" + "**,远高于随机猜测(0.5),"
647
+ "攻击者可以较高概率推断学生隐私。\n\n"
648
+ "### 发现二:标签平滑是有效的训练期防御\n\n"
649
+ "| 策略 | AUC | 准确率 | 评价 |\n"
650
+ "|------|-----|--------|------|\n"
651
+ "| 基线(无防御)| " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}" + "% | 隐私风险高 |\n"
652
+ "| 标签平滑 e=0.02 | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}" + "% | ✅ **推荐** |\n"
653
+ "| 标签平滑 e=0.2 | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}" + "% | ⚠️ 防御强但效用受影响 |\n\n"
654
+ "### 发现三:输出扰动是零成本的推理期防御\n\n"
655
+ "s=0.02 将AUC从 " + f"{bl_auc:.4f}" + " 降至 **" + f"{op002_auc:.4f}" + "**,准确率**不变**。\n\n"
656
+ "### 发现四:双重防御可叠加使用\n\n"
657
+ "> **推荐方案**:标签平滑 e=0.02(训练期)+ 输出扰动 s=0.02(推理期)= **双重防护**\n\n"
658
+ "---\n\n"
659
+ "### 🎤 答辩话术\n\n"
660
+ "> \"本研究以小学数学智能辅导系统为场景,使用Qwen2.5-Math-1.5B + LoRA微调。\n"
661
+ "> 基线模型AUC=" + f"{bl_auc:.4f}" + ",存在显著隐私泄露风险。\n"
662
+ "> 通过**训练期标签平滑**(e=0.02)和**推理期输出扰动**(s=0.02)两种防御,\n"
663
+ "> 有效降低了攻击成功率,其中输出扰动实现了**零效用损失**。\n"
664
+ "> 研究揭示了教育AI领域隐私保护与模型效用之间的权衡关系。\"\n\n"
665
+ "---\n\n"
666
+ "### 📚 创新点\n\n"
667
+ "1. **场景新颖** — 聚焦教育领域LLM隐私(而非通用NLP)\n"
668
+ "2. **双重防御** — 同时研究训练期 + 推理期防御策略\n"
669
+ "3. **工程可行** — 标签平滑一行代码,输出扰动一行代码\n"
670
+ "4. **实验完整** — 攻击 + 防御 + 效用评估 + 权衡分析\n\n"
671
+ "---\n\n"
672
+ "### 🔮 未来工作\n\n"
673
+ "- 探索**差分隐私 (DP-SGD)** 等更强防御\n"
674
+ "- 测试 **Shadow Model Attack** 等更强攻击\n"
675
+ "- 在真实教育数据集上验证\n"
676
+ "- 研究**联邦学习**框架下的教育模型隐私\n"
677
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
678
 
679
  # ============================
680
+ # 底部
681
  # ============================
682
+ gr.Markdown(
683
+ "---\n"
684
+ "<center>\n\n"
685
+ "🎓 **教育大模型中的成员推理攻击及其防御思路研究**\n\n"
686
+ "`Qwen2.5-Math-1.5B` · `LoRA` · `MIA` · `标签平滑` · `输出扰动` · `Gradio`\n\n"
687
+ "</center>\n"
688
+ )
 
 
 
 
689
 
690
  # ========================================
691
+ # 5. 启动
692
  # ========================================
693
+ demo.launch()