xiaohy commited on
Commit
72d2d35
·
verified ·
1 Parent(s): e6788b5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +963 -0
app.py ADDED
@@ -0,0 +1,963 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ================================================================
2
+ # 🎓 教育大模型中的成员推理攻击及其防御研究
3
+ # Membership Inference Attack & Defense in Educational LLMs
4
+ # ================================================================
5
+ # 部署平台:Hugging Face Spaces (永久免费)
6
+ # SDK:Gradio
7
+ # 硬件:CPU basic (Free) — 不需要 GPU
8
+ # ================================================================
9
+
10
+ import os
11
+ import json
12
+ import numpy as np
13
+ import matplotlib
14
+ matplotlib.use('Agg')
15
+ import matplotlib.pyplot as plt
16
+ import gradio as gr
17
+
18
+ # ========================================
19
+ # 1. 加载所有数据
20
+ # ========================================
21
+
22
# Directory containing this file; all data/result paths are resolved against it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def load_json(relative_path):
    """Read a JSON document located relative to this file's directory.

    Args:
        relative_path: Path that is joined onto ``BASE_DIR``.

    Returns:
        The object produced by ``json.load`` for that file.

    Raises:
        FileNotFoundError: If nothing exists at the resolved path.
    """
    full_path = os.path.join(BASE_DIR, relative_path)
    if os.path.exists(full_path):
        with open(full_path, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    raise FileNotFoundError(f"文件不存在: {full_path}")
32
+
33
+
34
# Training (member) and held-out (non-member) samples.
member_data, non_member_data = map(
    load_json,
    ("data/member.json", "data/non_member.json"),
)

# Saved experiment outputs: attack metrics, utility metrics, output-perturbation
# metrics, and the full result file (read elsewhere for per-sample losses).
mia_results, utility_results, perturb_results, full_results = map(
    load_json,
    (
        "results/mia_results.json",
        "results/utility_results.json",
        "results/perturbation_results.json",
        "results/mia_full_results.json",
    ),
)

# Project configuration.
config = load_json("config.json")

# Font settings (chosen for compatibility with the Spaces environment).
plt.rcParams.update({
    'font.sans-serif': ['DejaVu Sans', 'Arial', 'sans-serif'],
    'axes.unicode_minus': False,
})
50
+
51
+
52
+ # ========================================
53
+ # 2. 图表生成函数
54
+ # ========================================
55
+
56
def make_pie_chart():
    """Draw a pie chart of task types over member and non-member samples.

    Samples without a ``task_type`` key are counted under ``'unknown'``.

    Returns:
        The matplotlib figure containing the chart.
    """
    display_names = {
        'calculation': 'Calculation (40%)',
        'word_problem': 'Word Problem (30%)',
        'concept': 'Concept Q&A (20%)',
        'error_correction': 'Error Correction (10%)'
    }
    palette = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']

    # Tally samples per task type, keeping first-seen order.
    tally = {}
    for record in member_data + non_member_data:
        kind = record.get('task_type', 'unknown')
        tally[kind] = 1 + tally.get(kind, 0)

    slice_labels = [display_names.get(kind, kind) for kind in tally]
    slice_sizes = list(tally.values())
    slice_count = len(slice_labels)

    fig, ax = plt.subplots(figsize=(8, 6))
    _wedges, _label_texts, pct_texts = ax.pie(
        slice_sizes,
        labels=slice_labels,
        autopct='%1.1f%%',
        colors=palette[:slice_count],
        explode=[0.04] * slice_count,
        shadow=True,
        startangle=90,
        textprops={'fontsize': 11}
    )
    # Emphasize the percentage text drawn inside each wedge.
    for pct_text in pct_texts:
        pct_text.set_fontsize(12)
        pct_text.set_fontweight('bold')

    ax.set_title(
        'Dataset Task Distribution (2000 samples)',
        fontsize=15, fontweight='bold', pad=15
    )
    plt.tight_layout()
    return fig
95
+
96
+
97
def make_loss_distribution():
    """Plot member vs non-member loss histograms, one panel per model.

    Only models present in ``full_results`` get a panel; each panel title
    carries the AUC from ``mia_results`` (0 when that entry is missing).

    Returns:
        The matplotlib figure, or a placeholder figure reading
        "No data available" when no model has saved losses.
    """
    variants = (
        ('baseline', 'Baseline'),
        ('smooth_0.02', 'Label Smoothing e=0.02'),
        ('smooth_0.2', 'Label Smoothing e=0.2'),
    )
    panels = [
        (key, f"{label}\nAUC = {mia_results.get(key, {}).get('auc', 0):.4f}")
        for key, label in variants
        if key in full_results
    ]

    if not panels:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, 'No data available', ha='center', va='center')
        return fig

    panel_count = len(panels)
    fig, axes = plt.subplots(1, panel_count, figsize=(6 * panel_count, 5))
    if panel_count == 1:
        # With a single panel, subplots returns one Axes rather than a sequence.
        axes = [axes]

    for ax, (key, heading) in zip(axes, panels):
        member_losses = full_results[key]['member_losses']
        outsider_losses = full_results[key]['non_member_losses']

        # Shared bin edges so both histograms are directly comparable.
        combined = member_losses + outsider_losses
        edges = np.linspace(min(combined), max(combined), 40)

        for values, shade, who in (
            (member_losses, '#4A90D9', 'Members'),
            (outsider_losses, '#E74C3C', 'Non-Members'),
        ):
            ax.hist(values, bins=edges, alpha=0.55, color=shade,
                    label=f'{who} (u={np.mean(values):.3f})', density=True)

        ax.set_title(heading, fontsize=13, fontweight='bold')
        ax.set_xlabel('Loss Value', fontsize=11)
        ax.set_ylabel('Density', fontsize=11)
        ax.legend(fontsize=9)
        ax.grid(True, linestyle='--', alpha=0.4)

    plt.suptitle(
        'Member vs Non-Member Loss Distribution',
        fontsize=16, fontweight='bold', y=1.02
    )
    plt.tight_layout()
    return fig
141
+
142
+
143
def make_auc_bar():
    """Plot the attack AUC of every available defense as a bar chart.

    Bars come from ``mia_results`` (baseline and label-smoothing models)
    followed by ``perturb_results`` (output perturbation); entries missing
    from their table are skipped.

    Returns:
        The matplotlib figure containing the chart.
    """
    # (result table, key, axis label, bar colour), in display order.
    candidates = [
        (mia_results, 'baseline', 'Baseline', '#95A5A6'),
        (mia_results, 'smooth_0.02', 'LS e=0.02', '#5B9BD5'),
        (mia_results, 'smooth_0.2', 'LS e=0.2', '#2E5FA1'),
        (perturb_results, 'perturbation_0.01', 'OP s=0.01', '#27AE60'),
        (perturb_results, 'perturbation_0.015', 'OP s=0.015', '#1E8449'),
        (perturb_results, 'perturbation_0.02', 'OP s=0.02', '#145A32'),
    ]
    present = [
        (label, table[key]['auc'], tint)
        for table, key, label, tint in candidates
        if key in table
    ]
    methods = [label for label, _, _ in present]
    aucs = [value for _, value, _ in present]
    colors = [tint for _, _, tint in present]

    fig, ax = plt.subplots(figsize=(11, 6))
    bars = ax.bar(
        methods, aucs, color=colors, width=0.55,
        edgecolor='white', linewidth=1.5
    )

    # Value label just above each bar.
    for rect, value in zip(bars, aucs):
        x_mid = rect.get_x() + rect.get_width() / 2
        ax.text(
            x_mid,
            rect.get_height() + 0.004,
            f'{value:.3f}',
            ha='center', va='bottom', fontsize=13, fontweight='bold'
        )

    # Reference lines: chance level and the undefended baseline.
    reference_auc = mia_results.get('baseline', {}).get('auc', 0.63)
    ax.axhline(y=0.5, color='red', linestyle='--', linewidth=2,
               label='Random Guess (AUC=0.5)')
    ax.axhline(y=reference_auc, color='black', linestyle=':',
               linewidth=1.5, label='Baseline Risk')

    ax.set_ylabel('MIA Attack AUC', fontsize=13)
    ax.set_title(
        'Comparison of All Defense Mechanisms',
        fontsize=15, fontweight='bold'
    )
    upper_limit = max(aucs) + 0.06 if aucs else 1.0
    ax.set_ylim(0.45, upper_limit)
    ax.legend(fontsize=11)
    ax.grid(axis='y', linestyle='--', alpha=0.4)
    plt.xticks(rotation=10)
    plt.tight_layout()
    return fig
204
+
205
+
206
def make_tradeoff():
    """Scatter-plot privacy risk (MIA AUC) against utility (test accuracy).

    Trained models are plotted only when present in both ``mia_results``
    and ``utility_results``. Output-perturbation points reuse the baseline
    accuracy as their x-coordinate.

    Returns:
        The matplotlib figure containing the scatter plot.
    """
    fig, ax = plt.subplots(figsize=(10, 7))

    trained_models = [
        ('baseline', 'Baseline (No Defense)', 'o', 'black', 180),
        ('smooth_0.02', 'Label Smoothing e=0.02', 's', '#5B9BD5', 160),
        ('smooth_0.2', 'Label Smoothing e=0.2', 's', '#2E5FA1', 160),
    ]
    perturbed_models = [
        ('perturbation_0.01', 'Output Perturb s=0.01', '^', '#27AE60', 170),
        ('perturbation_0.02', 'Output Perturb s=0.02', '^', '#145A32', 170),
    ]

    points = []
    for key, label, marker, color, size in trained_models:
        if key not in mia_results or key not in utility_results:
            continue
        points.append({
            'name': label,
            'auc': mia_results[key]['auc'],
            'acc': utility_results[key]['accuracy'],
            'marker': marker, 'color': color, 'size': size
        })

    # Output perturbation points are plotted at the baseline accuracy.
    base_acc = utility_results.get('baseline', {}).get('accuracy', 0.633)
    for key, label, marker, color, size in perturbed_models:
        if key not in perturb_results:
            continue
        points.append({
            'name': label,
            'auc': perturb_results[key]['auc'],
            'acc': base_acc,
            'marker': marker, 'color': color, 'size': size
        })

    for point in points:
        ax.scatter(
            point['acc'], point['auc'],
            label=point['name'], marker=point['marker'], color=point['color'],
            s=point['size'], edgecolors='white', linewidth=1.5, zorder=5
        )

    ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.7,
               label='Random Guess (AUC=0.5)')

    ax.set_xlabel('Model Utility (Test Accuracy)', fontsize=13, fontweight='bold')
    ax.set_ylabel('Privacy Risk (MIA AUC)', fontsize=13, fontweight='bold')
    ax.set_title('Privacy-Utility Trade-off Analysis', fontsize=15, fontweight='bold')

    # Fit the axes to the plotted points and label the two corner regions.
    all_acc = [point['acc'] for point in points]
    all_auc = [point['auc'] for point in points]
    if all_acc and all_auc:
        acc_lo, acc_hi = min(all_acc), max(all_acc)
        auc_lo = min(min(all_auc), 0.5)  # always keep the chance line in view
        auc_hi = max(all_auc)

        ax.set_xlim(acc_lo - 0.03, acc_hi + 0.05)
        ax.set_ylim(auc_lo - 0.02, auc_hi + 0.02)

        ax.text(
            acc_lo - 0.02, auc_hi + 0.01,
            'High Risk / Low Utility', fontsize=10, color='red',
            bbox=dict(facecolor='red', alpha=0.1)
        )
        ax.text(
            acc_hi + 0.03, auc_lo + 0.005,
            'Ideal Zone', fontsize=10, color='green',
            bbox=dict(facecolor='green', alpha=0.1)
        )

    ax.legend(loc='upper right', frameon=True, shadow=True, fontsize=10)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig
276
+
277
+
278
def make_accuracy_bar():
    """Plot test accuracy (in percent) for each available strategy.

    Trained models take their accuracy from ``utility_results``. Each
    output-perturbation entry present in ``perturb_results`` is drawn at
    the baseline accuracy.

    Returns:
        The matplotlib figure containing the chart.
    """
    trained_models = [
        ('baseline', 'Baseline', '#95A5A6'),
        ('smooth_0.02', 'LS e=0.02', '#5B9BD5'),
        ('smooth_0.2', 'LS e=0.2', '#2E5FA1'),
    ]
    perturbed_models = [
        ('perturbation_0.01', 'OP s=0.01', '#27AE60'),
        ('perturbation_0.02', 'OP s=0.02', '#145A32'),
    ]

    bars_spec = [
        (label, utility_results[key]['accuracy'] * 100, tint)
        for key, label, tint in trained_models
        if key in utility_results
    ]

    # Output perturbation bars reuse the baseline accuracy.
    base_pct = utility_results.get('baseline', {}).get('accuracy', 0) * 100
    bars_spec += [
        (label, base_pct, tint)
        for key, label, tint in perturbed_models
        if key in perturb_results
    ]

    names = [label for label, _, _ in bars_spec]
    accs = [pct for _, pct, _ in bars_spec]
    colors = [tint for _, _, tint in bars_spec]

    fig, ax = plt.subplots(figsize=(11, 6))
    bars = ax.bar(
        names, accs, color=colors, width=0.5,
        edgecolor='white', linewidth=1.5
    )

    # Percentage label just above each bar.
    for rect, pct in zip(bars, accs):
        x_mid = rect.get_x() + rect.get_width() / 2
        ax.text(
            x_mid,
            rect.get_height() + 0.8,
            f'{pct:.1f}%',
            ha='center', va='bottom', fontsize=13, fontweight='bold'
        )

    ax.set_ylabel('Accuracy (%)', fontsize=13)
    ax.set_title('Model Utility Comparison (300 Math Questions)', fontsize=15, fontweight='bold')
    ax.set_ylim(0, 100)
    ax.grid(axis='y', alpha=0.3)
    plt.xticks(rotation=10)
    plt.tight_layout()
    return fig
326
+
327
+
328
+ # ========================================
329
+ # 3. 界面回调函数
330
+ # ========================================
331
+
332
def show_random_sample(data_type):
    """Pick one random sample from the selected split and format it.

    Args:
        data_type: Text of the selected radio choice, e.g.
            "Member Data (Training Set / 成员数据)" or
            "Non-Member Data (Test Set / 非成员数据)".

    Returns:
        A tuple ``(info_markdown, question, answer)``. When the selected
        split is empty, a short notice and two empty strings are returned.
    """
    # "Non-Member" contains "member" and "非成员" contains "成员", so the
    # negative form has to be ruled out before testing for the positive one;
    # otherwise the non-member split could never be selected.
    normalized = (
        data_type.lower().replace("-", "").replace("_", "").replace(" ", "")
    )
    wants_non_member = "nonmember" in normalized or "非成员" in data_type
    wants_member = not wants_non_member and (
        "member" in normalized or "成员" in data_type
    )
    data = member_data if wants_member else non_member_data

    if not data:
        # np.random.randint(0, 0) would raise; report the situation instead.
        return "> ⚠️ No samples available for this data type.", "", ""

    sample = data[np.random.randint(0, len(data))]
    meta = sample['metadata']

    task_name_map = {
        'calculation': 'Calculation (基础计算)',
        'word_problem': 'Word Problem (应用题)',
        'concept': 'Concept Q&A (概念问答)',
        'error_correction': 'Error Correction (错题订正)'
    }
    task_label = task_name_map.get(sample['task_type'], sample['task_type'])

    info = (
        "### 📋 Sample Metadata (Privacy Fields)\n"
        "\n"
        "| Field | Value |\n"
        "|-------|-------|\n"
        f"| **Name (姓名)** | {meta['name']} |\n"
        f"| **Student ID (学号)** | {meta['student_id']} |\n"
        f"| **Class (班级)** | {meta['class']} |\n"
        f"| **Score (成绩)** | {meta['score']} |\n"
        f"| **Task Type** | {task_label} |\n"
        "\n"
        "> ⚠️ The above are **student privacy fields** that attackers attempt to infer!\n"
    )
    return info, sample['question'], sample['answer']
362
+
363
+
364
def run_mia_demo(sample_index, data_type):
    """Replay the loss-threshold membership inference attack on one sample.

    The loss is taken from the saved baseline losses in ``full_results``
    when available; otherwise it is drawn from a normal distribution around
    the saved mean for the chosen split. The decision threshold is the
    midpoint between the member and non-member mean losses in ``mia_results``.

    Args:
        sample_index: Index of the sample; clamped to the valid range of
            the selected split.
        data_type: Text of the selected radio choice, e.g.
            "Member Data (成员数据)" or "Non-Member Data (非成员数据)".

    Returns:
        A tuple ``(question_markdown, result_markdown)``. When the selected
        split is empty, a short notice and an empty string are returned.
    """
    # "Non-Member" contains "Member" and "非成员" contains "成员", so the
    # negative form has to be ruled out first; otherwise every choice would
    # be treated as member data.
    normalized = (
        data_type.lower().replace("-", "").replace("_", "").replace(" ", "")
    )
    is_non_member = "nonmember" in normalized or "非成员" in data_type
    is_member = not is_non_member and (
        "member" in normalized or "成员" in data_type
    )
    data = member_data if is_member else non_member_data
    if not data:
        return "⚠️ No samples available for this data source.", ""

    # Clamp to the real dataset size rather than a hard-coded 999.
    idx = max(0, min(int(sample_index), len(data) - 1))
    sample = data[idx]

    baseline_stats = mia_results.get('baseline', {})
    m_mean = baseline_stats.get('member_loss_mean', 0.19)
    nm_mean = baseline_stats.get('non_member_loss_mean', 0.23)
    threshold = (m_mean + nm_mean) / 2.0

    # Prefer the loss saved for this exact sample; otherwise simulate one
    # from the summary statistics.
    loss_key = 'member_losses' if is_member else 'non_member_losses'
    saved_losses = full_results.get('baseline', {}).get(loss_key, [])
    if idx < len(saved_losses):
        loss = saved_losses[idx]
    else:
        loss = float(np.random.normal(m_mean if is_member else nm_mean, 0.02))

    pred_member = loss < threshold
    attack_correct = pred_member == is_member

    # ----- Loss position bar: member mean at the left end, non-member mean
    # at the right end, threshold in the middle. -----
    bar_total = 40
    if nm_mean > m_mean:
        ratio = (loss - m_mean) / (nm_mean - m_mean)
    else:
        ratio = 0.5
    ratio = max(0.0, min(1.0, ratio))
    pos = int(bar_total * ratio)
    bar_visual = "=" * pos + "V" + "=" * (bar_total - pos)

    half = bar_total // 2
    width = bar_total + 1  # number of characters inside the brackets
    if pred_member:
        position_text = "Member zone (below threshold)"
    else:
        position_text = "Non-member zone (at or above threshold)"

    viz_lines = [
        " Member Zone (Low Loss)      Non-Member Zone (High Loss)",
        " <" + "-" * (half - 1) + "|" + "-" * (half - 1) + ">",
        " " + "Threshold".center(width),
        "",
        f"[{bar_visual}]",
        " |" + " " * (half - 1) + "|" + " " * (half - 1) + "|",
        " " + f"{'Member Mean':<16}{'Threshold':^9}{'Non-Member Mean':>16}",
        " " + f"{m_mean:<16.4f}{threshold:^9.4f}{nm_mean:>16.4f}",
        "",
        f" Current Loss = {loss:.4f}",
        f" Position: {position_text}",
    ]
    viz_block = "\n".join(viz_lines) + "\n"

    # ----- Verdict texts -----
    if pred_member:
        pred_text = "🔴 **MEMBER** (Loss < Threshold → Model is too familiar)"
    else:
        pred_text = "🟢 **NON-MEMBER** (Loss >= Threshold → Model is not familiar)"

    if is_member:
        actual_text = "🔴 **MEMBER** (This data WAS used in training)"
    else:
        actual_text = "🟢 **NON-MEMBER** (This data was NOT used in training)"

    if attack_correct and pred_member:
        result_icon, result_text = "⚠️", "ATTACK SUCCESS — Privacy Leaked!"
    elif attack_correct:
        result_icon, result_text = "✅", "Correct Judgment"
    else:
        result_icon, result_text = "❌", "Attack Failed"

    if pred_member:
        warning_msg = (
            f"⚠️ **Privacy Risk!** This sample Loss = {loss:.4f} "
            f"is BELOW the threshold ({threshold:.4f}). "
            "The model 'remembers' this data — student privacy may be compromised!"
        )
    else:
        warning_msg = (
            f"✅ This sample Loss = {loss:.4f} "
            f"is ABOVE the threshold ({threshold:.4f}). "
            "The model shows no special memorization of this data."
        )

    # ----- Assemble the final Markdown -----
    result_md = (
        "## 🔍 MIA Attack Result\n\n"
        "### 📊 Loss Computation\n\n"
        "| Metric | Value |\n"
        "|--------|-------|\n"
        f"| **Sample Loss** | `{loss:.6f}` |\n"
        f"| **Decision Threshold** | `{threshold:.6f}` |\n"
        f"| **Member Mean Loss** | `{m_mean:.6f}` |\n"
        f"| **Non-Member Mean Loss** | `{nm_mean:.6f}` |\n\n"
        "### 📏 Loss Position Visualization\n\n"
        "```\n"
        f"{viz_block}"
        "```\n\n"
        "### 🎯 Attack Judgment\n\n"
        "| Item | Result |\n"
        "|------|--------|\n"
        f"| **Attacker Prediction** | {pred_text} |\n"
        f"| **Actual Identity** | {actual_text} |\n"
        f"| **Attack Outcome** | {result_icon} **{result_text}** |\n\n"
        "### 💡 How It Works\n\n"
        "The model produces **lower Loss** on data it was **trained on** "
        "(it's more \"confident\"). The attacker exploits this statistical difference:\n\n"
        f"- Loss **below** threshold `{threshold:.4f}` → Predicted as **training member** → ⚠️ Privacy risk\n"
        f"- Loss **above** threshold `{threshold:.4f}` → Predicted as **non-member** → ✅ Relatively safe\n\n"
        f"{warning_msg}\n\n"
        "> 📌 *This demo uses real Loss values saved from the experiment (not real-time inference).*\n"
    )

    question_display = f"**📝 Sample #{idx}:**\n\n{sample['question'][:600]}"
    return question_display, result_md
520
+
521
+
522
+ # ========================================
523
+ # 4. 构建完整 Gradio 界面
524
+ # ========================================
525
+
526
custom_css = """
.gradio-container {
    max-width: 1280px !important;
    margin: auto !important;
}
.tab-nav button {
    font-size: 15px !important;
    padding: 10px 18px !important;
    font-weight: 600 !important;
}
footer {
    display: none !important;
}
"""


def _auc_or_zero(table, key):
    """Return ``table[key]['auc']``, or 0 when the entry is missing."""
    return table.get(key, {}).get('auc', 0)


def _accuracy_percent(key):
    """Return the accuracy stored under ``key`` as a percentage (0 if missing)."""
    return utility_results.get(key, {}).get('accuracy', 0) * 100


# Pull out frequently used numbers up front so the Markdown templates below
# never fail on a missing key.
bl_auc = _auc_or_zero(mia_results, 'baseline')
s002_auc = _auc_or_zero(mia_results, 'smooth_0.02')
s02_auc = _auc_or_zero(mia_results, 'smooth_0.2')
op001_auc = _auc_or_zero(perturb_results, 'perturbation_0.01')
op0015_auc = _auc_or_zero(perturb_results, 'perturbation_0.015')
op002_auc = _auc_or_zero(perturb_results, 'perturbation_0.02')

bl_acc = _accuracy_percent('baseline')
s002_acc = _accuracy_percent('smooth_0.02')
s02_acc = _accuracy_percent('smooth_0.2')

# Experiment metadata shown on the overview tab.
model_name_str = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct')
gpu_name_str = config.get('gpu_name', 'T4')
data_size_str = config.get('data_size', 2000)
setup_date_str = config.get('setup_date', 'N/A')
557
+
558
+
559
def risk_badge(auc_val):
    """Map an attack AUC to a coarse privacy-risk label.

    Defined before the UI is built because both the defense-comparison tab
    and the utility tab use it.
    """
    if auc_val > 0.62:
        return "🔴 High"
    if auc_val > 0.55:
        return "🟡 Medium"
    return "🟢 Low"


# Highest index that is valid for BOTH datasets, so the attack-demo slider can
# never offer an out-of-range sample (equals 999 for the 1000/1000 split).
max_sample_index = max(min(len(member_data), len(non_member_data)) - 1, 0)

# ASCII pipeline diagram shown on the overview tab.
pipeline_diagram = "\n".join([
    "+-------------+     +-------------------+     +-----------+     +-------------------+     +------------+",
    "|  Data Gen   | --> | Baseline Training | --> | MIA Attack| --> |  Defense Deploy   | --> | Evaluation |",
    "|   (2000)    |     | (LoRA fine-tune)  |     |  (Loss)   |     |     (LS + OP)     |     | (AUC+Acc)  |",
    "+-------------+     +--------+----------+     +-----------+     +-------------------+     +------------+",
    " " * 29 + "|" + " " * 42 + "|",
    " " * 29 + "+-- Label Smoothing Training --------------+",
    " " * 33 + "(e=0.02, e=0.2)",
])

with gr.Blocks(
    title="Education LLM Privacy Attack & Defense",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="sky",
        neutral_hue="slate"
    ),
    css=custom_css
) as demo:

    # ============================
    # Header
    # ============================
    gr.Markdown("""
    # 🎓 Membership Inference Attack & Defense in Educational LLMs
    ### 教育大模型中的成员推理攻击及其防御研究

    ---

    > **Goal**: Investigate privacy leakage risks in educational LLMs and evaluate **Label Smoothing** + **Output Perturbation** as defense strategies.

    > **Tech Stack**: `Qwen2.5-Math-1.5B` · `LoRA Fine-tuning` · `Loss-based MIA` · `Label Smoothing` · `Output Perturbation`
    """)

    # ============================
    # Tab 1: Project Overview
    # ============================
    with gr.Tab("🏠 Project Overview"):

        overview_md = (
            "## 📖 Research Background\n\n"
            "As LLMs are increasingly deployed in education (tutoring systems, personalized learning),\n"
            "they inevitably process student **private data** (names, IDs, grades).\n\n"
            "**Membership Inference Attack (MIA)** can determine whether a data sample was used to train\n"
            "the model, potentially exposing student privacy.\n\n"
            "---\n\n"
            "## 🔬 Research Design\n\n"
            "| Phase | Content | Details |\n"
            "|-------|---------|--------|\n"
            "| 📂 Data | 2000 math tutoring dialogues | Contains names, student IDs, grades |\n"
            "| 🧠 Training | Qwen2.5-Math + LoRA | Baseline + 2 label smoothing models |\n"
            "| ⚔️ Attack | Loss-based MIA | Classify members by output loss |\n"
            "| 🛡️ Train-time Defense | Label Smoothing (e=0.02, 0.2) | Regularization during training |\n"
            "| 🛡️ Inference-time Defense | Output Perturbation (s=0.01~0.02) | Add noise at inference |\n"
            "| 📊 Evaluation | Privacy-Utility Trade-off | AUC + Accuracy |\n\n"
            "---\n\n"
            "## ⚙️ Experiment Configuration\n\n"
            "| Item | Value |\n"
            "|------|-------|\n"
            f"| **Base Model** | {model_name_str} |\n"
            "| **Fine-tuning** | LoRA (r=8, alpha=16) |\n"
            "| **Training Epochs** | 10 |\n"
            f"| **Dataset Size** | {data_size_str} (1000 member + 1000 non-member) |\n"
            f"| **GPU** | {gpu_name_str} |\n"
            f"| **Date** | {setup_date_str} |\n\n"
            "---\n\n"
            "## 📐 Technical Pipeline\n\n"
            "```\n"
            f"{pipeline_diagram}\n"
            "```\n"
        )

        gr.Markdown(overview_md)

    # ============================
    # Tab 2: Data Explorer
    # ============================
    with gr.Tab("📊 Data Explorer"):
        gr.Markdown("""
        ## 📂 Dataset Overview

        - **Member data (Training set)**: 1000 samples — used to train the model
        - **Non-member data (Test set)**: 1000 samples — NOT used in training
        - Each sample contains **student privacy**: name, student ID, class, score
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📊 Task Distribution")
                gr.Plot(value=make_pie_chart())

            with gr.Column(scale=1):
                gr.Markdown("### 🔍 Random Sample Viewer")
                data_type_selector = gr.Radio(
                    choices=[
                        "Member Data (Training Set / 成员数据)",
                        "Non-Member Data (Test Set / 非成员数据)"
                    ],
                    value="Member Data (Training Set / 成员数据)",
                    label="Select Data Type"
                )
                sample_btn = gr.Button(
                    "🎲 Random Sample", variant="primary"
                )

        sample_info = gr.Markdown()
        with gr.Row():
            sample_q = gr.Textbox(
                label="📝 Student Question", lines=7, interactive=False
            )
            sample_a = gr.Textbox(
                label="💡 Model Answer", lines=7, interactive=False
            )

        sample_btn.click(
            fn=show_random_sample,
            inputs=[data_type_selector],
            outputs=[sample_info, sample_q, sample_a]
        )

    # ============================
    # Tab 3: MIA Attack Demo
    # ============================
    with gr.Tab("⚔️ MIA Attack Demo"):
        gr.Markdown(f"""
        ## ⚔️ Live Membership Inference Attack

        **How it works**: The model produces **lower Loss** on training data (it's more "confident").
        The attacker uses a **threshold** on Loss to predict membership.

        ### 📌 Steps:
        1️⃣ Select data source (Member / Non-Member)
        2️⃣ Choose a sample index (0-{max_sample_index})
        3️⃣ Click **"Run Attack"** to see the result
        """)

        with gr.Row():
            with gr.Column(scale=1):
                atk_data_type = gr.Radio(
                    choices=[
                        "Member Data (成员数据)",
                        "Non-Member Data (非成员数据)"
                    ],
                    value="Member Data (成员数据)",
                    label="📂 Data Source"
                )
                atk_index = gr.Slider(
                    minimum=0, maximum=max_sample_index, step=1, value=0,
                    label=f"📌 Sample Index (0-{max_sample_index})"
                )
                atk_btn = gr.Button(
                    "⚔️ Run MIA Attack",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                atk_question = gr.Markdown(label="Sample Content")

        atk_result = gr.Markdown()

        atk_btn.click(
            fn=run_mia_demo,
            inputs=[atk_index, atk_data_type],
            outputs=[atk_question, atk_result]
        )

    # ============================
    # Tab 4: Defense Comparison
    # ============================
    with gr.Tab("🛡️ Defense Comparison"):
        gr.Markdown("""
        ## 🛡️ Defense Strategy Comparison

        | Strategy | Type | Mechanism | Pros | Cons |
        |----------|------|-----------|------|------|
        | **Label Smoothing** | Train-time | Soften labels to prevent overfitting | Reduces memorization | May hurt utility |
        | **Output Perturbation** | Inference-time | Add Gaussian noise to Loss | Zero utility loss | Only masks signal |
        """)

        with gr.Row():
            with gr.Column():
                gr.Markdown("### 📊 AUC Comparison (All Defenses)")
                gr.Plot(value=make_auc_bar())

            with gr.Column():
                gr.Markdown("### 📈 Loss Distribution (Baseline vs LS)")
                gr.Plot(value=make_loss_distribution())

        # Results table
        gr.Markdown("### 📋 Detailed Results")

        table_rows = [
            "| Strategy | Type | AUC | Privacy Risk |\n",
            "|----------|------|-----|-------------|\n",
        ]

        for k, name, cat in [
            ('baseline', 'Baseline (No Defense)', '—'),
            ('smooth_0.02', 'Label Smoothing e=0.02', 'Train-time'),
            ('smooth_0.2', 'Label Smoothing e=0.2', 'Train-time'),
        ]:
            if k in mia_results:
                a = mia_results[k]['auc']
                table_rows.append(
                    f"| {name} | {cat} | **{a:.4f}** | {risk_badge(a)} |\n"
                )

        for k, name in [
            ('perturbation_0.01', 'Output Perturbation s=0.01'),
            ('perturbation_0.015', 'Output Perturbation s=0.015'),
            ('perturbation_0.02', 'Output Perturbation s=0.02'),
        ]:
            if k in perturb_results:
                a = perturb_results[k]['auc']
                table_rows.append(
                    f"| {name} | Inference-time | **{a:.4f}** | {risk_badge(a)} |\n"
                )

        table = "".join(table_rows)
        gr.Markdown(table)

    # ============================
    # Tab 5: Output Perturbation
    # ============================
    with gr.Tab("🔊 Output Perturbation"):
        gr.Markdown(f"""
        ## 🔊 Output Perturbation Defense

        ### 📌 Core Idea

        At **inference time**, add **Gaussian noise** to the model's output Loss:

        **Loss_perturbed = Loss_original + N(0, sigma^2)**

        ### ✅ Key Advantage
        - **No retraining needed** (zero deployment cost)
        - **No utility loss** (accuracy stays exactly the same)
        - Noise level sigma can be tuned dynamically

        ### 📊 Experiment Results

        | sigma | AUC | AUC Reduction | Accuracy | Note |
        |-------|-----|--------------|----------|------|
        | 0 (Baseline) | **{bl_auc:.4f}** | — | {bl_acc:.1f}% | No defense |
        | 0.01 | **{op001_auc:.4f}** | ↓{bl_auc - op001_auc:.4f} | {bl_acc:.1f}% (unchanged) | Mild |
        | 0.015 | **{op0015_auc:.4f}** | ↓{bl_auc - op0015_auc:.4f} | {bl_acc:.1f}% (unchanged) | Moderate |
        | 0.02 | **{op002_auc:.4f}** | ↓{bl_auc - op002_auc:.4f} | {bl_acc:.1f}% (unchanged) | **Recommended** |

        ### 💡 Key Finding

        > Output Perturbation (s=0.02) reduces AUC from {bl_auc:.4f} to **{op002_auc:.4f}**
        > while keeping accuracy at **{bl_acc:.1f}%** — truly a **zero-cost defense**!
        """)

    # ============================
    # Tab 6: Utility Evaluation
    # ============================
    with gr.Tab("📝 Utility Evaluation"):
        gr.Markdown("""
        ## 📐 Model Utility Evaluation

        > Defense must not sacrifice too much utility.
        > Test set: **300 math questions** covering calculation, word problems, and concept Q&A.
        """)

        with gr.Row():
            with gr.Column():
                gr.Markdown("### 📊 Accuracy Comparison")
                gr.Plot(value=make_accuracy_bar())
            with gr.Column():
                gr.Markdown("### ⚖️ Privacy-Utility Trade-off")
                gr.Plot(value=make_tradeoff())

        # Utility table
        ut_rows = [
            "| Strategy | Accuracy | AUC | Risk | Utility Impact |\n",
            "|----------|----------|-----|------|---------------|\n",
        ]

        for k, name in [
            ('baseline', 'Baseline'),
            ('smooth_0.02', 'LS e=0.02'),
            ('smooth_0.2', 'LS e=0.2'),
        ]:
            if k in utility_results and k in mia_results:
                acc = utility_results[k]['accuracy'] * 100
                auc = mia_results[k]['auc']
                impact = "—" if k == 'baseline' else (
                    "✅ Improved" if acc > bl_acc else "⚠️ Decreased"
                )
                ut_rows.append(
                    f"| {name} | **{acc:.1f}%** | {auc:.4f} | {risk_badge(auc)} | {impact} |\n"
                )

        # Output perturbation rows are listed at the baseline accuracy.
        for k, name in [
            ('perturbation_0.01', 'OP s=0.01'),
            ('perturbation_0.02', 'OP s=0.02'),
        ]:
            if k in perturb_results:
                ut_rows.append(
                    f"| {name} | **{bl_acc:.1f}%** | "
                    f"{perturb_results[k]['auc']:.4f} | "
                    f"{risk_badge(perturb_results[k]['auc'])} | ✅ No change |\n"
                )

        ut = "".join(ut_rows)
        gr.Markdown(ut)

    # ============================
    # Tab 7: Paper Figures
    # ============================
    with gr.Tab("📄 Paper Figures"):
        gr.Markdown("## 📄 Publication-Quality Figures (300 DPI)")

        figure_items = [
            ("fig1_loss_distribution_comparison.png",
             "Figure 1: Loss Distribution — Baseline vs Label Smoothing"),
            ("fig2_privacy_utility_tradeoff_fixed.png",
             "Figure 2: Privacy-Utility Trade-off Analysis"),
            ("fig3_defense_comparison_bar.png",
             "Figure 3: Defense Mechanism AUC Comparison"),
        ]

        # Figures are optional: show each one if present, otherwise a notice.
        for filename, caption in figure_items:
            path = os.path.join(BASE_DIR, "figures", filename)
            if os.path.exists(path):
                gr.Markdown(f"### {caption}")
                gr.Image(value=path, show_label=False, height=420)
                gr.Markdown("---")
            else:
                gr.Markdown(
                    f"### {caption}\n\n"
                    f"> ⚠️ File not found: `figures/{filename}` — "
                    f"this figure is optional."
                )

    # ============================
    # Tab 8: Conclusions
    # ============================
    with gr.Tab("🎓 Conclusions"):
        gr.Markdown(f"""
        ## 📝 Research Conclusions

        ---

        ### 🔬 Core Findings

        #### Finding 1: MIA poses a real threat to educational LLMs
        Baseline AUC = **{bl_auc:.4f}** (significantly above random guess of 0.5).
        Attackers can infer student data membership with high probability.

        #### Finding 2: Label Smoothing effectively reduces risk

        | Strategy | AUC | Accuracy | Verdict |
        |----------|-----|----------|---------|
        | Baseline | {bl_auc:.4f} | {bl_acc:.1f}% | High privacy risk |
        | LS e=0.02 | {s002_auc:.4f} | {s002_acc:.1f}% | ✅ **Recommended** |
        | LS e=0.2 | {s02_auc:.4f} | {s02_acc:.1f}% | ⚠️ Strong defense, possible utility impact |

        #### Finding 3: Output Perturbation is a zero-cost defense
        sigma=0.02 reduces AUC from {bl_auc:.4f} to **{op002_auc:.4f}** with **zero accuracy loss**.

        #### Finding 4: Best practice — combine both defenses
        > **Recommended**: LS e=0.02 (training) + OP s=0.02 (inference) = **Dual Protection**

        ---

        ### 🎤 Defense Presentation Script

        > "This study uses a math tutoring scenario with Qwen2.5-Math-1.5B + LoRA fine-tuning.
        > The baseline model shows AUC={bl_auc:.4f}, indicating significant privacy leakage risk.
        > We evaluate two complementary defenses: **Label Smoothing** (train-time, e=0.02)
        > and **Output Perturbation** (inference-time, s=0.02).
        > Output perturbation achieves **zero utility loss**, making it ideal for practical deployment.
        > The study reveals the fundamental privacy-utility trade-off in educational AI."

        ---

        ### 📚 Innovation Points

        1. **Novel scenario** — Focus on educational LLM privacy (not general NLP)
        2. **Dual defense** — Both train-time and inference-time strategies
        3. **Practical** — Label smoothing = 1 line of code; Output perturbation = 1 line of code
        4. **Comprehensive** — Attack + Defense + Utility + Trade-off analysis

        ---

        ### 🔮 Future Work

        - Explore **Differential Privacy (DP-SGD)** for stronger guarantees
        - Test **Shadow Model Attack** and other advanced MIA variants
        - Validate on real educational datasets
        - Investigate **Federated Learning** for educational model privacy
        """)

    # ============================
    # Footer
    # ============================
    gr.Markdown("""
    ---
    <center>

    🎓 **Membership Inference Attack & Defense in Educational LLMs**

    `Qwen2.5-Math-1.5B` · `LoRA` · `MIA` · `Label Smoothing` · `Output Perturbation` · `Gradio`

    </center>
    """)


# ========================================
# 5. Launch
# ========================================
demo.launch()
963
+