xiaohy commited on
Commit
4c5f3ae
·
verified ·
1 Parent(s): 9ad7f6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +567 -274
app.py CHANGED
@@ -9,23 +9,25 @@ import matplotlib.pyplot as plt
9
  import gradio as gr
10
 
11
  # ========================================
12
- # 1. 数据加载 (保持你原有的优秀逻辑不变)
13
  # ========================================
14
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
15
 
 
16
  def load_json(path):
17
- try:
18
- with open(os.path.join(BASE_DIR, path), 'r', encoding='utf-8') as f:
19
- return json.load(f)
20
- except Exception as e:
21
- print(f"Warning: Could not load {path}. Error: {e}")
22
- return {}
23
 
24
  def clean_text(text):
 
 
25
  text = re.sub(r'[\U00010000-\U0010ffff]', '', text)
 
26
  text = text.encode('utf-8', errors='ignore').decode('utf-8')
27
  return text
28
 
 
29
  member_data = load_json("data/member.json")
30
  non_member_data = load_json("data/non_member.json")
31
  mia_results = load_json("results/mia_results.json")
@@ -34,52 +36,57 @@ perturb_results = load_json("results/perturbation_results.json")
34
  full_results = load_json("results/mia_full_results.json")
35
  config = load_json("config.json")
36
 
37
- plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'SimHei', 'Arial']
38
  plt.rcParams['axes.unicode_minus'] = False
39
 
40
  # 预取数值
41
- bl_auc = mia_results.get('baseline', {}).get('auc', 0.6308)
42
- s002_auc = mia_results.get('smooth_0.02', {}).get('auc', 0.6223)
43
- s02_auc = mia_results.get('smooth_0.2', {}).get('auc', 0.5869)
44
- op001_auc = perturb_results.get('perturbation_0.01', {}).get('auc', 0.6120)
45
- op0015_auc = perturb_results.get('perturbation_0.015', {}).get('auc', 0.6025)
46
- op002_auc = perturb_results.get('perturbation_0.02', {}).get('auc', 0.5947)
47
-
48
- bl_acc = utility_results.get('baseline', {}).get('accuracy', 0.633) * 100
49
- s002_acc = utility_results.get('smooth_0.02', {}).get('accuracy', 0.747) * 100
50
- s02_acc = utility_results.get('smooth_0.2', {}).get('accuracy', 0.710) * 100
51
-
52
- bl_m_mean = mia_results.get('baseline', {}).get('member_loss_mean', 0.1992)
53
- bl_nm_mean = mia_results.get('baseline', {}).get('non_member_loss_mean', 0.2126)
54
- bl_m_std = mia_results.get('baseline', {}).get('member_loss_std', 0.0323)
55
- bl_nm_std = mia_results.get('baseline', {}).get('non_member_loss_std', 0.0371)
56
 
57
  model_name_str = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct')
58
  gpu_name_str = config.get('gpu_name', 'T4')
59
  data_size_str = str(config.get('data_size', 2000))
60
  setup_date_str = config.get('setup_date', 'N/A')
61
 
62
- # 如果数据为空,提供假数据防止报错 (为了界面调试)
63
- if not member_data:
64
- member_data = [{"task_type": "calculation", "metadata": {"name": "张三", "student_id": "001", "class": "1班", "score": 90}, "question": "1+1=?", "answer": "2"}]
65
- if not non_member_data:
66
- non_member_data = [{"task_type": "concept", "metadata": {"name": "李四", "student_id": "002", "class": "2班", "score": 85}, "question": "什么是质数?", "answer": "只有1和它本身两个因数的自然数。"}]
67
-
68
 
69
  # ========================================
70
- # 2. 图表函数 (保持原样,绘图本身已经很不错)
71
  # ========================================
 
72
  def make_pie_chart():
73
  task_counts = {}
74
  for item in member_data + non_member_data:
75
  t = item.get('task_type', 'unknown')
76
  task_counts[t] = task_counts.get(t, 0) + 1
77
- name_map = {'calculation': 'Calculation', 'word_problem': 'Word Problem', 'concept': 'Concept Q&A', 'error_correction': 'Error Correction'}
 
 
 
 
 
78
  labels = [name_map.get(k, k) for k in task_counts]
79
  sizes = list(task_counts.values())
80
  colors = ['#5B8FF9', '#5AD8A6', '#F6BD16', '#E86452']
81
  fig, ax = plt.subplots(figsize=(7, 5.5))
82
- wedges, texts, autotexts = ax.pie(sizes, labels=labels, autopct='%1.1f%%', colors=colors[:len(labels)], startangle=90, textprops={'fontsize': 11}, wedgeprops={'edgecolor': 'white', 'linewidth': 2})
 
 
 
 
 
83
  for t in autotexts:
84
  t.set_fontsize(11)
85
  t.set_fontweight('bold')
@@ -87,6 +94,7 @@ def make_pie_chart():
87
  plt.tight_layout()
88
  return fig
89
 
 
90
  def make_loss_distribution():
91
  plot_items = []
92
  for k, t in [('baseline', 'Baseline'), ('smooth_0.02', 'LS (e=0.02)'), ('smooth_0.2', 'LS (e=0.2)')]:
@@ -95,12 +103,12 @@ def make_loss_distribution():
95
  plot_items.append((k, t + " | AUC=" + f"{auc:.4f}"))
96
  n = len(plot_items)
97
  if n == 0:
98
- fig, ax = plt.subplots(figsize=(6,4))
99
- ax.text(0.5, 0.5, 'Requires full results data', ha='center', color='gray')
100
- ax.axis('off')
101
  return fig
102
  fig, axes = plt.subplots(1, n, figsize=(5.5 * n, 4.5))
103
- if n == 1: axes = [axes]
 
104
  for ax, (k, title) in zip(axes, plot_items):
105
  m = full_results[k]['member_losses']
106
  nm = full_results[k]['non_member_losses']
@@ -117,26 +125,34 @@ def make_loss_distribution():
117
  plt.tight_layout()
118
  return fig
119
 
 
120
  def make_auc_bar():
121
  methods, aucs, colors = [], [], []
122
- items = [('baseline', 'Baseline', '#8C8C8C'), ('smooth_0.02', 'LS (e=0.02)', '#5B8FF9'), ('smooth_0.2', 'LS (e=0.2)', '#3D76DD')]
 
 
 
 
123
  for k, name, c in items:
124
  if k in mia_results:
125
- methods.append(name); aucs.append(mia_results[k]['auc']); colors.append(c)
126
- p_items = [('perturbation_0.01', 'OP (s=0.01)', '#5AD8A6'), ('perturbation_0.015', 'OP (s=0.015)', '#2EAD78'), ('perturbation_0.02', 'OP (s=0.02)', '#1A7F5A')]
 
 
 
 
 
 
127
  for k, name, c in p_items:
128
  if k in perturb_results:
129
- methods.append(name); aucs.append(perturb_results[k]['auc']); colors.append(c)
130
-
131
- if not methods: # Fallback dummy data if JSON not loaded
132
- methods = ['Baseline', 'LS (e=0.02)', 'LS (e=0.2)', 'OP (s=0.015)']
133
- aucs = [0.6308, 0.6223, 0.5869, 0.6025]
134
- colors = ['#8C8C8C', '#5B8FF9', '#3D76DD', '#2EAD78']
135
-
136
  fig, ax = plt.subplots(figsize=(10, 5.5))
137
  bars = ax.bar(methods, aucs, color=colors, width=0.52, edgecolor='white', linewidth=1.5)
138
  for bar, a in zip(bars, aucs):
139
- ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.003, f'{a:.4f}', ha='center', va='bottom', fontsize=11, fontweight='bold')
 
140
  ax.axhline(y=0.5, color='#E86452', linestyle='--', linewidth=1.5, alpha=0.7, label='Random Guess (0.5)')
141
  ax.set_ylabel('MIA AUC', fontsize=12)
142
  ax.set_title('Defense Mechanisms - AUC Comparison', fontsize=14, fontweight='bold')
@@ -149,23 +165,28 @@ def make_auc_bar():
149
  plt.tight_layout()
150
  return fig
151
 
 
152
  def make_tradeoff():
153
  fig, ax = plt.subplots(figsize=(9, 6.5))
154
  points = []
155
- for k, name, marker, color, sz in [('baseline', 'Baseline', 'o', '#8C8C8C', 200), ('smooth_0.02', 'LS (e=0.02)', 's', '#5B8FF9', 180), ('smooth_0.2', 'LS (e=0.2)', 's', '#3D76DD', 180)]:
 
 
 
156
  if k in mia_results and k in utility_results:
157
- points.append({'name': name, 'auc': mia_results[k]['auc'], 'acc': utility_results[k]['accuracy'], 'marker': marker, 'color': color, 'size': sz})
 
 
158
  base_acc = utility_results.get('baseline', {}).get('accuracy', 0.633)
159
- for k, name, marker, color, sz in [('perturbation_0.01', 'OP (s=0.01)', '^', '#5AD8A6', 190), ('perturbation_0.02', 'OP (s=0.02)', '^', '#1A7F5A', 190)]:
 
 
160
  if k in perturb_results:
161
- points.append({'name': name, 'auc': perturb_results[k]['auc'], 'acc': base_acc, 'marker': marker, 'color': color, 'size': sz})
162
-
163
- if not points: # Fallback dummy
164
- points = [{'name':'Baseline','auc':0.6308,'acc':0.633,'marker':'o','color':'#8C8C8C','size':200},
165
- {'name':'LS (e=0.02)','auc':0.6223,'acc':0.747,'marker':'s','color':'#5B8FF9','size':180}]
166
-
167
  for p in points:
168
- ax.scatter(p['acc'], p['auc'], label=p['name'], marker=p['marker'], color=p['color'], s=p['size'], edgecolors='white', linewidth=2, zorder=5)
 
169
  ax.axhline(y=0.5, color='#BFBFBF', linestyle='--', alpha=0.8, label='Random Guess')
170
  ax.set_xlabel('Model Utility (Accuracy)', fontsize=12, fontweight='bold')
171
  ax.set_ylabel('Privacy Risk (MIA AUC)', fontsize=12, fontweight='bold')
@@ -182,23 +203,27 @@ def make_tradeoff():
182
  plt.tight_layout()
183
  return fig
184
 
 
185
  def make_accuracy_bar():
186
  names, accs, colors = [], [], []
187
- for k, name, c in [('baseline', 'Baseline', '#8C8C8C'), ('smooth_0.02', 'LS (e=0.02)', '#5B8FF9'), ('smooth_0.2', 'LS (e=0.2)', '#3D76DD')]:
 
188
  if k in utility_results:
189
- names.append(name); accs.append(utility_results[k]['accuracy'] * 100); colors.append(c)
190
- base_pct = utility_results.get('baseline', {}).get('accuracy', 0.633) * 100
191
- for k, name, c in [('perturbation_0.01', 'OP (s=0.01)', '#5AD8A6'), ('perturbation_0.02', 'OP (s=0.02)', '#1A7F5A')]:
 
 
 
192
  if k in perturb_results:
193
- names.append(name); accs.append(base_pct); colors.append(c)
194
-
195
- if not names: # Fallback dummy
196
- names = ['Baseline', 'LS (e=0.02)', 'LS (e=0.2)']; accs = [63.3, 74.7, 71.0]; colors = ['#8C8C8C', '#5B8FF9', '#3D76DD']
197
-
198
  fig, ax = plt.subplots(figsize=(10, 5.5))
199
  bars = ax.bar(names, accs, color=colors, width=0.5, edgecolor='white', linewidth=1.5)
200
  for bar, acc in zip(bars, accs):
201
- ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.6, f'{acc:.1f}%', ha='center', va='bottom', fontsize=11, fontweight='bold')
 
202
  ax.set_ylabel('Accuracy (%)', fontsize=12)
203
  ax.set_title('Model Utility (300 Math Questions)', fontsize=14, fontweight='bold')
204
  ax.set_ylim(0, 100)
@@ -209,33 +234,52 @@ def make_accuracy_bar():
209
  plt.tight_layout()
210
  return fig
211
 
 
212
  def make_loss_gauge(loss_val, m_mean, nm_mean, threshold):
 
213
  fig, ax = plt.subplots(figsize=(8, 2.5))
 
 
214
  x_min = min(m_mean - 3 * bl_m_std, loss_val - 0.01)
215
  x_max = max(nm_mean + 3 * bl_nm_std, loss_val + 0.01)
216
 
 
217
  ax.axvspan(x_min, threshold, alpha=0.15, color='#5B8FF9')
 
218
  ax.axvspan(threshold, x_max, alpha=0.15, color='#E86452')
219
 
 
220
  ax.axvline(x=threshold, color='#595959', linewidth=2, linestyle='-', zorder=3)
221
- ax.text(threshold, 1.15, 'Threshold', ha='center', va='bottom', fontsize=9, fontweight='bold', color='#595959', transform=ax.get_xaxis_transform())
 
222
 
 
223
  ax.axvline(x=m_mean, color='#5B8FF9', linewidth=1.5, linestyle='--', alpha=0.7)
224
- ax.text(m_mean, -0.25, f'Member\n({m_mean:.4f})', ha='center', va='top', fontsize=8, color='#5B8FF9', transform=ax.get_xaxis_transform())
 
225
 
 
226
  ax.axvline(x=nm_mean, color='#E86452', linewidth=1.5, linestyle='--', alpha=0.7)
227
- ax.text(nm_mean, -0.25, f'Non-Member\n({nm_mean:.4f})', ha='center', va='top', fontsize=8, color='#E86452', transform=ax.get_xaxis_transform())
 
228
 
 
229
  is_member_zone = loss_val < threshold
230
  marker_color = '#5B8FF9' if is_member_zone else '#E86452'
231
- ax.plot(loss_val, 0.5, marker='v', markersize=18, color=marker_color, zorder=5, transform=ax.get_xaxis_transform())
 
232
  label_text = f'Loss={loss_val:.4f}'
233
- ax.text(loss_val, 0.75, label_text, ha='center', va='bottom', fontsize=10, fontweight='bold', color=marker_color, transform=ax.get_xaxis_transform(), bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor=marker_color, alpha=0.9))
 
 
234
 
 
235
  member_center = (x_min + threshold) / 2
236
  nonmember_center = (threshold + x_max) / 2
237
- ax.text(member_center, 0.5, 'Member Zone', ha='center', va='center', fontsize=10, color='#5B8FF9', fontweight='bold', alpha=0.6, transform=ax.get_xaxis_transform())
238
- ax.text(nonmember_center, 0.5, 'Non-Member Zone', ha='center', va='center', fontsize=10, color='#E86452', fontweight='bold', alpha=0.6, transform=ax.get_xaxis_transform())
 
 
239
 
240
  ax.set_xlim(x_min, x_max)
241
  ax.set_yticks([])
@@ -246,33 +290,55 @@ def make_loss_gauge(loss_val, m_mean, nm_mean, threshold):
246
  plt.tight_layout()
247
  return fig
248
 
 
249
  def risk_badge(auc_val):
250
- if auc_val > 0.62: return "🔴 高风险"
251
- elif auc_val > 0.55: return "🟡 中风险"
252
- else: return "🟢 低风险"
 
 
 
 
253
 
254
  # ========================================
255
  # 3. 回调函数
256
  # ========================================
 
257
  def show_random_sample(data_type):
258
- data = member_data if data_type == "成员数据(训练集)" else non_member_data
 
 
 
259
  sample = data[np.random.randint(0, len(data))]
260
- meta = sample.get('metadata', {'name': '未知', 'student_id': '未知', 'class': '未知', 'score': '未知'})
261
- task_map = {'calculation': '基础计算', 'word_problem': '应用题', 'concept': '概念问答', 'error_correction': '错题订正'}
262
-
 
 
 
 
263
  info = (
264
- "### 🛡️ 截获的隐私元数据\n"
265
- f"- **姓名**: {meta.get('name', 'N/A')}\n"
266
- f"- **学号**: {meta.get('student_id', 'N/A')}\n"
267
- f"- **班级**: {meta.get('class', 'N/A')}\n"
268
- f"- **成绩**: {meta.get('score', 'N/A')} \n"
269
- f"- **类型**: {task_map.get(sample.get('task_type', ''), sample.get('task_type', ''))}\n"
 
 
 
270
  )
271
- return info, clean_text(sample.get('question', '')), clean_text(sample.get('answer', ''))
 
272
 
273
  def run_mia_demo(sample_index, data_type):
274
- is_member = (data_type == "成员数据(训练集)")
275
- data = member_data if is_member else non_member_data
 
 
 
 
 
276
  idx = min(int(sample_index), len(data) - 1)
277
  sample = data[idx]
278
 
@@ -282,241 +348,468 @@ def run_mia_demo(sample_index, data_type):
282
  elif not is_member and idx < len(bl.get('non_member_losses', [])):
283
  loss = bl['non_member_losses'][idx]
284
  else:
285
- loss = float(np.random.normal(bl_m_mean, 0.02)) if is_member else float(np.random.normal(bl_nm_mean, 0.02))
 
 
 
286
 
287
  threshold = (bl_m_mean + bl_nm_mean) / 2.0
288
  pred_member = (loss < threshold)
289
  actual_member = is_member
290
  attack_correct = (pred_member == actual_member)
291
 
 
292
  gauge_fig = make_loss_gauge(loss, bl_m_mean, bl_nm_mean, threshold)
293
 
294
- pred_html = f"<span style='color: {'#E86452' if pred_member else '#5AD8A6'}; font-weight: bold;'>{'训练成员 🔴' if pred_member else '非训练成员 🟢'}</span>"
295
- actual_html = f"<span style='color: {'#E86452' if actual_member else '#5AD8A6'}; font-weight: bold;'>{'训练成员 🔴' if actual_member else '非训练成员 🟢'}</span>"
296
-
 
 
 
 
 
 
 
 
 
 
 
297
  if attack_correct and pred_member and actual_member:
298
- res_style = "background: #ffebe9; border-left: 5px solid #E86452; color: #cf222e;"
299
- res_title = "⚠️ 攻击成功:发生了隐私泄露"
300
- res_desc = "模型对该样本过于熟悉(Loss低于阈值),攻击者成功判定其为训练集数据!"
301
  elif attack_correct:
302
- res_style = "background: #e6ffec; border-left: 5px solid #5AD8A6; color: #1a7f37;"
303
- res_title = "✅ 判定正确:未发生隐私泄露"
304
- res_desc = "样本确实未参与训练,且模型表现出正常的陌生感。"
 
 
 
 
 
 
 
 
 
305
  else:
306
- res_style = "background: #f6f8fa; border-left: 5px solid #8c8c8c; color: #57606a;"
307
- res_title = " 攻击失误:攻击者推断失败"
308
- res_desc = "攻击者的判断与事实不符,隐私得到了保护。"
309
-
310
- # 使用 HTML 输出精美的高亮结果板
311
- result_html = f"""
312
- <div style='{res_style} padding: 15px; border-radius: 8px; margin-bottom: 15px;'>
313
- <h3 style='margin-top: 0; margin-bottom: 8px; display: flex; align-items: center; font-size: 1.2em;'>{res_title}</h3>
314
- <p style='margin: 0; font-size: 0.95em;'>{res_desc}</p>
315
- </div>
316
- <div style='display: grid; grid-template-columns: 1fr 1fr; gap: 10px; font-size: 1em;'>
317
- <div style='background: #fff; padding: 12px; border-radius: 6px; border: 1px solid #eaecef;'>
318
- <div style='color: #57606a; font-size: 0.85em; margin-bottom: 4px;'>攻击者计算得出</div>
319
- <div>{pred_html}</div>
320
- <div style='color: #8c8c8c; font-size: 0.85em; margin-top: 4px;'>Sample Loss: {loss:.4f}</div>
321
- </div>
322
- <div style='background: #fff; padding: 12px; border-radius: 6px; border: 1px solid #eaecef;'>
323
- <div style='color: #57606a; font-size: 0.85em; margin-bottom: 4px;'>系统真实身份</div>
324
- <div>{actual_html}</div>
325
- <div style='color: #8c8c8c; font-size: 0.85em; margin-top: 4px;'>System Threshold: {threshold:.4f}</div>
326
- </div>
327
- </div>
328
- """
329
-
330
- question_display = f"**样本追踪号 [ {idx} ] :**\n\n> {clean_text(sample.get('question', '')[:600])}"
331
- return question_display, gauge_fig, result_html
332
 
333
 
334
  # ========================================
335
- # 4. 前端精装修 CSS & Gradio 构建
336
  # ========================================
337
 
338
  custom_css = """
339
- /* 全局设定 */
340
- body { background-color: #f7f9fc !important; }
341
  .gradio-container {
342
  max-width: 1200px !important;
343
- margin: 2rem auto !important;
344
  font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif !important;
345
  }
346
 
347
- /* 核心升级:精美卡片式布局 */
348
- .custom-card {
349
- background: #ffffff !important;
350
- border-radius: 12px !important;
351
- box-shadow: 0 4px 15px rgba(0, 0, 0, 0.05) !important;
352
- padding: 24px !important;
353
- border: 1px solid #edf2f9 !important;
354
- margin-bottom: 16px !important;
355
- }
356
-
357
- /* Tab 导航美化 */
358
- .tab-nav {
359
- border-bottom: 2px solid #edf2f9 !important;
360
- margin-bottom: 20px !important;
361
- }
362
  .tab-nav button {
363
- font-size: 15px !important;
364
- padding: 12px 20px !important;
365
  font-weight: 500 !important;
366
- color: #5e6e82 !important;
367
  border-radius: 8px 8px 0 0 !important;
368
- transition: all 0.2s ease !important;
369
  }
370
- .tab-nav button:hover { background: #f8faff !important; color: #5B8FF9 !important; }
371
  .tab-nav button.selected {
372
  font-weight: 700 !important;
373
- color: #2b52ff !important;
374
  border-bottom: 3px solid #5B8FF9 !important;
375
- background: transparent !important;
376
  }
377
 
378
- /* 标题样式重构 */
379
  .prose h1 {
380
- font-size: 2rem !important;
381
  color: #1a1a2e !important;
382
- border-bottom: 3px solid #5B8FF9 !important;
383
- padding-bottom: 12px !important;
384
- margin-bottom: 24px !important;
385
- font-weight: 800 !important;
386
  }
387
  .prose h2 {
388
- font-size: 1.4rem !important;
389
- color: #2d3748 !important;
390
- margin-top: 0.5em !important;
391
- border-left: 4px solid #5B8FF9 !important;
392
- padding-left: 10px !important;
393
  }
394
-
395
- /* 引用块设计感 */
396
- .prose blockquote {
397
- border-left: 4px solid #5AD8A6 !important;
398
- background: #f0fbf7 !important;
399
- padding: 16px 20px !important;
400
- margin: 16px 0 !important;
401
- border-radius: 0 8px 8px 0 !important;
402
- color: #2d3748 !important;
403
  }
404
 
405
- /* 按钮微调 */
406
- button.primary {
407
- background: linear-gradient(135deg, #5B8FF9 0%, #3D76DD 100%) !important;
408
- border: none !important;
409
- box-shadow: 0 4px 6px rgba(91, 143, 249, 0.2) !important;
410
  }
411
- button.primary:hover {
412
- transform: translateY(-1px);
413
- box-shadow: 0 6px 10px rgba(91, 143, 249, 0.3) !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  }
415
 
416
- /* 隐藏 Footer */
417
  footer { display: none !important; }
418
  """
419
 
 
420
  with gr.Blocks(
421
  title="教育大模型隐私攻防实验",
422
- theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"),
423
  css=custom_css
424
  ) as demo:
425
 
426
- gr.Markdown("# 🎓 教育大模型中的成员推理攻击及其防御研究\n\n> 探究教育大语言模型的隐私泄露风险,验证 **标签平滑** 与 **输出扰动** 两种防御策略的有效性及其对模型效用的影响。")
 
 
 
 
 
 
 
427
 
428
- with gr.Tabs():
429
- # --- Tab 1: 项目概览 ---
430
- with gr.Tab("项目概览"):
431
- with gr.Column(elem_classes="custom-card"):
432
- gr.Markdown(
433
- "## 📌 研究背景\n\n"
434
- "大语言模型在教育领域(智能辅导系统)应用广泛其训练数据往往包含学生敏感隐私。"
435
- "**成员推理攻击 (MIA)** 以判断某条数据是否被模型“��住”,从而引发隐私泄露威胁\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  )
437
- with gr.Row():
438
- with gr.Column(elem_classes="custom-card"):
439
- gr.Markdown(
440
- "## ⚙️ 实验配置\n"
441
- f"- **基座模型**: `{model_name_str}`\n"
442
- f"- **微调方法**: LoRA (r=8, alpha=16)\n"
443
- f"- **数据总量**: {data_size_str} 条 (1:1 成员与非成员)\n"
444
- f"- **计算硬件**: {gpu_name_str}\n"
445
- )
446
- with gr.Column(elem_classes="custom-card"):
447
- gr.Markdown(
448
- "## 🛡️ 防御架构\n"
449
- "- **训练期 - 标签平滑**: 软化目标标签,抑制模型过拟合。\n"
450
- "- **推理期 - 输出扰动**: 注入高斯噪声,物理隔绝攻击者对 Loss 的精确探测。\n"
451
- )
452
-
453
- # --- Tab 2: 数据展示 ---
454
- with gr.Tab("数据展示"):
455
- with gr.Row():
456
- with gr.Column(scale=1, elem_classes="custom-card"):
457
- gr.Markdown("## 📊 数据分布概况")
458
- gr.Plot(value=make_pie_chart())
459
-
460
- with gr.Column(scale=1, elem_classes="custom-card"):
461
- gr.Markdown("## 🎲 数据抽样与隐私探查")
462
- data_sel = gr.Radio(choices=["成员数据(训练集)", "非成员数据(测试集)"], value="成员数据(训练集)", label="选择靶向数据池", interactive=True)
463
- sample_btn = gr.Button("🔍 随机提取样本与隐私字典", variant="primary")
464
- sample_info = gr.Markdown()
465
-
466
- with gr.Column(elem_classes="custom-card"):
467
- gr.Markdown("### 📄 原始对话内容")
468
- with gr.Row():
469
- sample_q = gr.Textbox(label="🧑‍🎓 学生提问 (Prompt)", lines=5)
470
- sample_a = gr.Textbox(label="🤖 模型回答 (Ground Truth)", lines=5)
471
-
472
- sample_btn.click(fn=show_random_sample, inputs=[data_sel], outputs=[sample_info, sample_q, sample_a])
473
-
474
- # --- Tab 3: MIA攻击演示 ---
475
- with gr.Tab("MIA攻击演示"):
476
- with gr.Row():
477
- with gr.Column(scale=1, elem_classes="custom-card"):
478
- gr.Markdown("## 🥷 发起黑盒 API 攻击")
479
- gr.Markdown("调整下方滑块选择一条截获的数据,系统将计算该条数据的 Loss 值并实施判定。")
480
- atk_data_type = gr.Radio(choices=["成员数据(训练集)", "非成员数据(测试集)"], value="成员数据(训练集)", label="模拟真实数据来源")
481
- atk_index = gr.Slider(minimum=0, maximum=999, step=1, value=12, label="样本游标 ID (0-999)")
482
- atk_btn = gr.Button("⚡ 执行成员推理攻击", variant="primary", size="lg")
483
- atk_question = gr.Markdown(elem_classes="prose")
484
-
485
- with gr.Column(scale=1, elem_classes="custom-card"):
486
- gr.Markdown("## 📡 攻击侦测控制台")
487
- atk_gauge = gr.Plot(label="Loss 分布雷达")
488
- atk_result_html = gr.HTML() # 改用 HTML 渲染精美面板
489
-
490
- atk_btn.click(fn=run_mia_demo, inputs=[atk_index, atk_data_type], outputs=[atk_question, atk_gauge, atk_result_html])
491
-
492
- # --- Tab 4: 防御对比 ---
493
- with gr.Tab("防御对比"):
494
- with gr.Row():
495
- with gr.Column(elem_classes="custom-card"):
496
- gr.Markdown("## 📉 隐私风险 (AUC) 宏观对比")
497
- gr.Plot(value=make_auc_bar())
498
- with gr.Column(elem_classes="custom-card"):
499
- gr.Markdown("## 🔔 底层 Loss 分布位移")
500
- gr.Plot(value=make_loss_distribution())
501
-
502
- # --- Tab 5: 效用评估 ---
503
- with gr.Tab("效用评估"):
504
- with gr.Row():
505
- with gr.Column(elem_classes="custom-card"):
506
- gr.Markdown("## 🎯 模型数学能力基准 (Accuracy)")
507
- gr.Plot(value=make_accuracy_bar())
508
- with gr.Column(elem_classes="custom-card"):
509
- gr.Markdown("## ⚖️ 隐私-效用权衡空间 (Trade-off)")
510
- gr.Plot(value=make_tradeoff())
511
-
512
- # --- Tab 6: 研究结论 ---
513
- with gr.Tab("研究结论"):
514
- with gr.Column(elem_classes="custom-card"):
515
- gr.Markdown(
516
- "## 💡 核心学术贡献\n\n"
517
- "1. **确认了教育场景下的内生风险**:证明了基于大模型的教育应用在未经特殊处理时,极易向外部暴露学生的学情数据(基线 AUC=0.6308)。\n\n"
518
- "2. **论证了正则化与隐私防御的协同性**:适度的标签平滑(ε=0.02)不仅没有降低模型智商,反而由于抑制了过拟合,使数学准确率提升至 74.7%,是一种“双赢”的内建防御��制。\n\n"
519
- "3. **验证了推断期防御的高效性**:输出扰动(σ=0.02)作为一种零效用损耗的插件方案,能有效混淆黑盒探测者的统计雷达,尤其适合算力受限或模型已定型的生产环境。\n"
520
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
 
 
 
 
522
  demo.launch()
 
9
  import gradio as gr
10
 
11
  # ========================================
12
+ # 1. 数据加载
13
  # ========================================
14
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
15
 
16
+
17
  def load_json(path):
18
+ with open(os.path.join(BASE_DIR, path), 'r', encoding='utf-8') as f:
19
+ return json.load(f)
20
+
 
 
 
21
 
22
  def clean_text(text):
23
+ """清理文本中的特殊字符和emoji,防止乱码"""
24
+ # 移除emoji和特殊Unicode字符
25
  text = re.sub(r'[\U00010000-\U0010ffff]', '', text)
26
+ # 移除其他可能导致乱码的字符
27
  text = text.encode('utf-8', errors='ignore').decode('utf-8')
28
  return text
29
 
30
+
31
  member_data = load_json("data/member.json")
32
  non_member_data = load_json("data/non_member.json")
33
  mia_results = load_json("results/mia_results.json")
 
36
  full_results = load_json("results/mia_full_results.json")
37
  config = load_json("config.json")
38
 
39
+ plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
40
  plt.rcParams['axes.unicode_minus'] = False
41
 
42
  # 预取数值
43
+ bl_auc = mia_results.get('baseline', {}).get('auc', 0)
44
+ s002_auc = mia_results.get('smooth_0.02', {}).get('auc', 0)
45
+ s02_auc = mia_results.get('smooth_0.2', {}).get('auc', 0)
46
+ op001_auc = perturb_results.get('perturbation_0.01', {}).get('auc', 0)
47
+ op0015_auc = perturb_results.get('perturbation_0.015', {}).get('auc', 0)
48
+ op002_auc = perturb_results.get('perturbation_0.02', {}).get('auc', 0)
49
+
50
+ bl_acc = utility_results.get('baseline', {}).get('accuracy', 0) * 100
51
+ s002_acc = utility_results.get('smooth_0.02', {}).get('accuracy', 0) * 100
52
+ s02_acc = utility_results.get('smooth_0.2', {}).get('accuracy', 0) * 100
53
+
54
+ bl_m_mean = mia_results.get('baseline', {}).get('member_loss_mean', 0.19)
55
+ bl_nm_mean = mia_results.get('baseline', {}).get('non_member_loss_mean', 0.23)
56
+ bl_m_std = mia_results.get('baseline', {}).get('member_loss_std', 0.03)
57
+ bl_nm_std = mia_results.get('baseline', {}).get('non_member_loss_std', 0.03)
58
 
59
  model_name_str = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct')
60
  gpu_name_str = config.get('gpu_name', 'T4')
61
  data_size_str = str(config.get('data_size', 2000))
62
  setup_date_str = config.get('setup_date', 'N/A')
63
 
 
 
 
 
 
 
64
 
65
  # ========================================
66
+ # 2. 图表函数
67
  # ========================================
68
+
69
  def make_pie_chart():
70
  task_counts = {}
71
  for item in member_data + non_member_data:
72
  t = item.get('task_type', 'unknown')
73
  task_counts[t] = task_counts.get(t, 0) + 1
74
+ name_map = {
75
+ 'calculation': 'Calculation',
76
+ 'word_problem': 'Word Problem',
77
+ 'concept': 'Concept Q&A',
78
+ 'error_correction': 'Error Correction'
79
+ }
80
  labels = [name_map.get(k, k) for k in task_counts]
81
  sizes = list(task_counts.values())
82
  colors = ['#5B8FF9', '#5AD8A6', '#F6BD16', '#E86452']
83
  fig, ax = plt.subplots(figsize=(7, 5.5))
84
+ wedges, texts, autotexts = ax.pie(
85
+ sizes, labels=labels, autopct='%1.1f%%',
86
+ colors=colors[:len(labels)],
87
+ startangle=90, textprops={'fontsize': 11},
88
+ wedgeprops={'edgecolor': 'white', 'linewidth': 2}
89
+ )
90
  for t in autotexts:
91
  t.set_fontsize(11)
92
  t.set_fontweight('bold')
 
94
  plt.tight_layout()
95
  return fig
96
 
97
+
98
  def make_loss_distribution():
99
  plot_items = []
100
  for k, t in [('baseline', 'Baseline'), ('smooth_0.02', 'LS (e=0.02)'), ('smooth_0.2', 'LS (e=0.2)')]:
 
103
  plot_items.append((k, t + " | AUC=" + f"{auc:.4f}"))
104
  n = len(plot_items)
105
  if n == 0:
106
+ fig, ax = plt.subplots()
107
+ ax.text(0.5, 0.5, 'No data', ha='center')
 
108
  return fig
109
  fig, axes = plt.subplots(1, n, figsize=(5.5 * n, 4.5))
110
+ if n == 1:
111
+ axes = [axes]
112
  for ax, (k, title) in zip(axes, plot_items):
113
  m = full_results[k]['member_losses']
114
  nm = full_results[k]['non_member_losses']
 
125
  plt.tight_layout()
126
  return fig
127
 
128
+
129
  def make_auc_bar():
130
  methods, aucs, colors = [], [], []
131
+ items = [
132
+ ('baseline', 'Baseline', '#8C8C8C'),
133
+ ('smooth_0.02', 'LS (e=0.02)', '#5B8FF9'),
134
+ ('smooth_0.2', 'LS (e=0.2)', '#3D76DD'),
135
+ ]
136
  for k, name, c in items:
137
  if k in mia_results:
138
+ methods.append(name)
139
+ aucs.append(mia_results[k]['auc'])
140
+ colors.append(c)
141
+ p_items = [
142
+ ('perturbation_0.01', 'OP (s=0.01)', '#5AD8A6'),
143
+ ('perturbation_0.015', 'OP (s=0.015)', '#2EAD78'),
144
+ ('perturbation_0.02', 'OP (s=0.02)', '#1A7F5A'),
145
+ ]
146
  for k, name, c in p_items:
147
  if k in perturb_results:
148
+ methods.append(name)
149
+ aucs.append(perturb_results[k]['auc'])
150
+ colors.append(c)
 
 
 
 
151
  fig, ax = plt.subplots(figsize=(10, 5.5))
152
  bars = ax.bar(methods, aucs, color=colors, width=0.52, edgecolor='white', linewidth=1.5)
153
  for bar, a in zip(bars, aucs):
154
+ ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.003,
155
+ f'{a:.4f}', ha='center', va='bottom', fontsize=11, fontweight='bold')
156
  ax.axhline(y=0.5, color='#E86452', linestyle='--', linewidth=1.5, alpha=0.7, label='Random Guess (0.5)')
157
  ax.set_ylabel('MIA AUC', fontsize=12)
158
  ax.set_title('Defense Mechanisms - AUC Comparison', fontsize=14, fontweight='bold')
 
165
  plt.tight_layout()
166
  return fig
167
 
168
+
169
  def make_tradeoff():
170
  fig, ax = plt.subplots(figsize=(9, 6.5))
171
  points = []
172
+ for k, name, marker, color, sz in [
173
+ ('baseline', 'Baseline', 'o', '#8C8C8C', 200),
174
+ ('smooth_0.02', 'LS (e=0.02)', 's', '#5B8FF9', 180),
175
+ ('smooth_0.2', 'LS (e=0.2)', 's', '#3D76DD', 180)]:
176
  if k in mia_results and k in utility_results:
177
+ points.append({'name': name, 'auc': mia_results[k]['auc'],
178
+ 'acc': utility_results[k]['accuracy'],
179
+ 'marker': marker, 'color': color, 'size': sz})
180
  base_acc = utility_results.get('baseline', {}).get('accuracy', 0.633)
181
+ for k, name, marker, color, sz in [
182
+ ('perturbation_0.01', 'OP (s=0.01)', '^', '#5AD8A6', 190),
183
+ ('perturbation_0.02', 'OP (s=0.02)', '^', '#1A7F5A', 190)]:
184
  if k in perturb_results:
185
+ points.append({'name': name, 'auc': perturb_results[k]['auc'],
186
+ 'acc': base_acc, 'marker': marker, 'color': color, 'size': sz})
 
 
 
 
187
  for p in points:
188
+ ax.scatter(p['acc'], p['auc'], label=p['name'], marker=p['marker'],
189
+ color=p['color'], s=p['size'], edgecolors='white', linewidth=2, zorder=5)
190
  ax.axhline(y=0.5, color='#BFBFBF', linestyle='--', alpha=0.8, label='Random Guess')
191
  ax.set_xlabel('Model Utility (Accuracy)', fontsize=12, fontweight='bold')
192
  ax.set_ylabel('Privacy Risk (MIA AUC)', fontsize=12, fontweight='bold')
 
203
  plt.tight_layout()
204
  return fig
205
 
206
+
207
  def make_accuracy_bar():
208
  names, accs, colors = [], [], []
209
+ for k, name, c in [('baseline', 'Baseline', '#8C8C8C'), ('smooth_0.02', 'LS (e=0.02)', '#5B8FF9'),
210
+ ('smooth_0.2', 'LS (e=0.2)', '#3D76DD')]:
211
  if k in utility_results:
212
+ names.append(name)
213
+ accs.append(utility_results[k]['accuracy'] * 100)
214
+ colors.append(c)
215
+ base_pct = utility_results.get('baseline', {}).get('accuracy', 0) * 100
216
+ for k, name, c in [('perturbation_0.01', 'OP (s=0.01)', '#5AD8A6'),
217
+ ('perturbation_0.02', 'OP (s=0.02)', '#1A7F5A')]:
218
  if k in perturb_results:
219
+ names.append(name)
220
+ accs.append(base_pct)
221
+ colors.append(c)
 
 
222
  fig, ax = plt.subplots(figsize=(10, 5.5))
223
  bars = ax.bar(names, accs, color=colors, width=0.5, edgecolor='white', linewidth=1.5)
224
  for bar, acc in zip(bars, accs):
225
+ ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.6,
226
+ f'{acc:.1f}%', ha='center', va='bottom', fontsize=11, fontweight='bold')
227
  ax.set_ylabel('Accuracy (%)', fontsize=12)
228
  ax.set_title('Model Utility (300 Math Questions)', fontsize=14, fontweight='bold')
229
  ax.set_ylim(0, 100)
 
234
  plt.tight_layout()
235
  return fig
236
 
237
+
238
  def make_loss_gauge(loss_val, m_mean, nm_mean, threshold):
239
+ """生成精致的Loss位置可视化图表(替代粗糙的ASCII字符画)"""
240
  fig, ax = plt.subplots(figsize=(8, 2.5))
241
+
242
+ # 绘制底部色条
243
  x_min = min(m_mean - 3 * bl_m_std, loss_val - 0.01)
244
  x_max = max(nm_mean + 3 * bl_nm_std, loss_val + 0.01)
245
 
246
+ # 成员区域(蓝色渐变)
247
  ax.axvspan(x_min, threshold, alpha=0.15, color='#5B8FF9')
248
+ # 非成员区域(红色渐变)
249
  ax.axvspan(threshold, x_max, alpha=0.15, color='#E86452')
250
 
251
+ # 阈值线
252
  ax.axvline(x=threshold, color='#595959', linewidth=2, linestyle='-', zorder=3)
253
+ ax.text(threshold, 1.15, 'Threshold', ha='center', va='bottom', fontsize=9,
254
+ fontweight='bold', color='#595959', transform=ax.get_xaxis_transform())
255
 
256
+ # 成员均值标记
257
  ax.axvline(x=m_mean, color='#5B8FF9', linewidth=1.5, linestyle='--', alpha=0.7)
258
+ ax.text(m_mean, -0.25, f'Member\n({m_mean:.4f})', ha='center', va='top', fontsize=8,
259
+ color='#5B8FF9', transform=ax.get_xaxis_transform())
260
 
261
+ # 非成员均值标记
262
  ax.axvline(x=nm_mean, color='#E86452', linewidth=1.5, linestyle='--', alpha=0.7)
263
+ ax.text(nm_mean, -0.25, f'Non-Member\n({nm_mean:.4f})', ha='center', va='top', fontsize=8,
264
+ color='#E86452', transform=ax.get_xaxis_transform())
265
 
266
+ # 当前样本标记(大箭头)
267
  is_member_zone = loss_val < threshold
268
  marker_color = '#5B8FF9' if is_member_zone else '#E86452'
269
+ ax.plot(loss_val, 0.5, marker='v', markersize=18, color=marker_color, zorder=5,
270
+ transform=ax.get_xaxis_transform())
271
  label_text = f'Loss={loss_val:.4f}'
272
+ ax.text(loss_val, 0.75, label_text, ha='center', va='bottom', fontsize=10,
273
+ fontweight='bold', color=marker_color, transform=ax.get_xaxis_transform(),
274
+ bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor=marker_color, alpha=0.9))
275
 
276
+ # 区域标签
277
  member_center = (x_min + threshold) / 2
278
  nonmember_center = (threshold + x_max) / 2
279
+ ax.text(member_center, 0.5, 'Member Zone', ha='center', va='center', fontsize=10,
280
+ color='#5B8FF9', fontweight='bold', alpha=0.6, transform=ax.get_xaxis_transform())
281
+ ax.text(nonmember_center, 0.5, 'Non-Member Zone', ha='center', va='center', fontsize=10,
282
+ color='#E86452', fontweight='bold', alpha=0.6, transform=ax.get_xaxis_transform())
283
 
284
  ax.set_xlim(x_min, x_max)
285
  ax.set_yticks([])
 
290
  plt.tight_layout()
291
  return fig
292
 
293
+
294
  def risk_badge(auc_val):
295
+ if auc_val > 0.62:
296
+ return "High"
297
+ elif auc_val > 0.55:
298
+ return "Medium"
299
+ else:
300
+ return "Low"
301
+
302
 
303
  # ========================================
304
  # 3. 回调函数
305
  # ========================================
306
+
307
  def show_random_sample(data_type):
308
+ if data_type == "成员数据(训练集)":
309
+ data = member_data
310
+ else:
311
+ data = non_member_data
312
  sample = data[np.random.randint(0, len(data))]
313
+ meta = sample['metadata']
314
+ task_map = {
315
+ 'calculation': '基础计算',
316
+ 'word_problem': '应用题',
317
+ 'concept': '概念问答',
318
+ 'error_correction': '错题订正'
319
+ }
320
  info = (
321
+ "### 样本元信息(隐私字段)\n\n"
322
+ "| 字段 | 值 |\n"
323
+ "|------|-----|\n"
324
+ "| **姓名** | " + str(meta['name']) + " |\n"
325
+ "| **学号** | " + str(meta['student_id']) + " |\n"
326
+ "| **班级** | " + str(meta['class']) + " |\n"
327
+ "| **成绩** | " + str(meta['score']) + " 分 |\n"
328
+ "| **任务类型** | " + task_map.get(sample['task_type'], sample['task_type']) + " |\n\n"
329
+ "> 以上即为攻击者试图推断的 **学生隐私信息**\n"
330
  )
331
+ return info, clean_text(sample['question']), clean_text(sample['answer'])
332
+
333
 
334
  def run_mia_demo(sample_index, data_type):
335
+ if data_type == "成员数据(训练集)":
336
+ is_member = True
337
+ data = member_data
338
+ else:
339
+ is_member = False
340
+ data = non_member_data
341
+
342
  idx = min(int(sample_index), len(data) - 1)
343
  sample = data[idx]
344
 
 
348
  elif not is_member and idx < len(bl.get('non_member_losses', [])):
349
  loss = bl['non_member_losses'][idx]
350
  else:
351
+ if is_member:
352
+ loss = float(np.random.normal(bl_m_mean, 0.02))
353
+ else:
354
+ loss = float(np.random.normal(bl_nm_mean, 0.02))
355
 
356
  threshold = (bl_m_mean + bl_nm_mean) / 2.0
357
  pred_member = (loss < threshold)
358
  actual_member = is_member
359
  attack_correct = (pred_member == actual_member)
360
 
361
+ # 生成精致的可视化图表
362
  gauge_fig = make_loss_gauge(loss, bl_m_mean, bl_nm_mean, threshold)
363
 
364
+ if pred_member:
365
+ pred_text = "训练成员(Loss < 阈值,模型过于熟悉)"
366
+ pred_icon = "🔴"
367
+ else:
368
+ pred_text = "非训练成员(Loss >= 阈值,模型不熟悉)"
369
+ pred_icon = "🟢"
370
+
371
+ if actual_member:
372
+ actual_text = "是训练成员(此数据参与了训练)"
373
+ actual_icon = "🔴"
374
+ else:
375
+ actual_text = "非训练成员(此数据未参与训练)"
376
+ actual_icon = "🟢"
377
+
378
  if attack_correct and pred_member and actual_member:
379
+ result_text = "**攻击成功 -- 隐私泄露**"
380
+ result_icon = "⚠️"
 
381
  elif attack_correct:
382
+ result_text = "**判断正确**"
383
+ result_icon = "✅"
384
+ else:
385
+ result_text = "**攻击失误**"
386
+ result_icon = "❌"
387
+
388
+ if pred_member:
389
+ warning = (
390
+ "> **隐私风险** : 此样本 Loss = " + f"{loss:.4f}"
391
+ + " 低于阈值 " + f"{threshold:.4f}"
392
+ + ",模型对它过于熟悉,学生隐私可能被推断。"
393
+ )
394
  else:
395
+ warning = (
396
+ "> **相对安全** : 此样本 Loss = " + f"{loss:.4f}"
397
+ + " 高于阈值 " + f"{threshold:.4f}"
398
+ + ",模型对其无特殊记忆。"
399
+ )
400
+
401
+ result_md = (
402
+ "### Loss 计算结果\n\n"
403
+ "| 指标 | |\n"
404
+ "|------|-----|\n"
405
+ "| 样本 Loss | " + f"{loss:.6f}" + " |\n"
406
+ "| 判定阈值 | " + f"{threshold:.6f}" + " |\n"
407
+ "| 成员平均 Loss | " + f"{bl_m_mean:.6f}" + " |\n"
408
+ "| 非成员平均 Loss | " + f"{bl_nm_mean:.6f}" + " |\n\n"
409
+ "### 攻击判定\n\n"
410
+ "| 项目 | 结果 |\n"
411
+ "|------|------|\n"
412
+ "| 攻击者预测 | " + pred_icon + " " + pred_text + " |\n"
413
+ "| 实际身份 | " + actual_icon + " " + actual_text + " |\n"
414
+ "| 攻击结果 | " + result_icon + " " + result_text + " |\n\n"
415
+ + warning + "\n"
416
+ )
417
+
418
+ question_display = "**第 " + str(idx) + " 号样本 :**\n\n" + clean_text(sample['question'][:600])
419
+ return question_display, gauge_fig, result_md
 
420
 
421
 
422
  # ========================================
423
+ # 4. 构建界面
424
  # ========================================
425
 
426
  custom_css = """
427
+ /* 整体容器 */
 
428
  .gradio-container {
429
  max-width: 1200px !important;
430
+ margin: auto !important;
431
  font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif !important;
432
  }
433
 
434
+ /* Tab 按钮 */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  .tab-nav button {
436
+ font-size: 14px !important;
437
+ padding: 10px 16px !important;
438
  font-weight: 500 !important;
 
439
  border-radius: 8px 8px 0 0 !important;
 
440
  }
 
441
  .tab-nav button.selected {
442
  font-weight: 700 !important;
 
443
  border-bottom: 3px solid #5B8FF9 !important;
 
444
  }
445
 
446
+ /* 标题样式 */
447
  .prose h1 {
448
+ font-size: 1.8rem !important;
449
  color: #1a1a2e !important;
450
+ border-bottom: 2px solid #5B8FF9 !important;
451
+ padding-bottom: 8px !important;
 
 
452
  }
453
  .prose h2 {
454
+ font-size: 1.35rem !important;
455
+ color: #16213e !important;
456
+ margin-top: 1.2em !important;
 
 
457
  }
458
+ .prose h3 {
459
+ font-size: 1.1rem !important;
460
+ color: #0f3460 !important;
 
 
 
 
 
 
461
  }
462
 
463
+ /* 表格美化 */
464
+ .prose table {
465
+ border-collapse: collapse !important;
466
+ width: 100% !important;
467
+ font-size: 0.9rem !important;
468
  }
469
+ .prose th {
470
+ background: #f0f5ff !important;
471
+ color: #1a1a2e !important;
472
+ font-weight: 600 !important;
473
+ padding: 10px 14px !important;
474
+ }
475
+ .prose td {
476
+ padding: 8px 14px !important;
477
+ border-bottom: 1px solid #eee !important;
478
+ }
479
+
480
+ /* 引用块 */
481
+ .prose blockquote {
482
+ border-left: 4px solid #5B8FF9 !important;
483
+ background: #f7f9fc !important;
484
+ padding: 12px 16px !important;
485
+ margin: 12px 0 !important;
486
+ border-radius: 0 6px 6px 0 !important;
487
  }
488
 
489
+ /* 隐藏底部 */
490
  footer { display: none !important; }
491
  """
492
 
493
+
494
  with gr.Blocks(
495
  title="教育大模型隐私攻防实验",
496
+ theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky", neutral_hue="slate"),
497
  css=custom_css
498
  ) as demo:
499
 
500
+ # ============================
501
+ # 顶部
502
+ # ============================
503
+ gr.Markdown(
504
+ "# 教育大模型中的成员推理攻击及其防御研究\n\n"
505
+ "> 探究教育场景下大语言模型的隐私泄露风险,"
506
+ "验证 **标签平滑** 与 **输出扰动** 两种防御策略的有效性及其对模型效用的影响。\n"
507
+ )
508
 
509
+ # ============================
510
+ # Tab 1: 项目概览
511
+ # ============================
512
+ with gr.Tab("项目概览"):
513
+ gr.Markdown(
514
+ "## 研究背景\n\n"
515
+ "随着大语言模型在教育领域的广泛应用(智能辅导系统、个性化学习推荐等),"
516
+ "模型训练过程中不避免地接触到学生敏感数据。"
517
+ "**成员推理攻击 (Membership Inference Attack, MIA)** 可以判断某条数据是否参与了模型训练,"
518
+ "进而推断学生的隐私信息。\n\n"
519
+ "---\n\n"
520
+ "## 研究设计\n\n"
521
+ "| 阶段 | 内容 | 说明 |\n"
522
+ "|------|------|------|\n"
523
+ "| 数据准备 | 2000条小学数学辅导对话 | 含姓名、学号、班级、成绩等隐私字段 |\n"
524
+ "| 模型训练 | Qwen2.5-Math-1.5B + LoRA | 基线模型 + 标签平滑模型 (e=0.02, 0.2) |\n"
525
+ "| 攻击测试 | Loss-based MIA | 利用模型输出Loss判断成员身份 |\n"
526
+ "| 训练期防御 | 标签平滑 | 软化训练标签,降低模型对训练数据的记忆程度 |\n"
527
+ "| 推理期防御 | 输出扰动 | 在推理阶段对输出Loss添加高斯噪声 |\n"
528
+ "| 综合评估 | 隐私-效用权衡分析 | AUC(隐私风险)+ 准确率(模型效用)|\n\n"
529
+ "---\n\n"
530
+ "## 实验配置\n\n"
531
+ "| 配置项 | 值 |\n"
532
+ "|--------|-----|\n"
533
+ "| 基座模型 | " + model_name_str + " |\n"
534
+ "| 微调方法 | LoRA (r=8, alpha=16, target: q/k/v/o_proj) |\n"
535
+ "| 训练轮数 | 10 epochs |\n"
536
+ "| 数据总量 | " + data_size_str + " 条 (成员1000 + 非成员1000) |\n"
537
+ "| GPU | " + gpu_name_str + " |\n\n"
538
+ "---\n\n"
539
+ "## 技术路线\n\n"
540
+ "| 步骤 | 阶段 | 方法 | 输出 |\n"
541
+ "|------|------|------|------|\n"
542
+ "| 1 | 数据生成 | 模板化生成2000条对话 | member.json + non_member.json |\n"
543
+ "| 2 | 基线训练 | LoRA微调Qwen2.5-Math | baseline模型 |\n"
544
+ "| 3 | 防御训练 | 标签平滑 (e=0.02, e=0.2) | smooth模型 x2 |\n"
545
+ "| 4 | MIA攻击 | 计算全量样本Loss,AUC评估 | mia_results.json |\n"
546
+ "| 5 | 输出扰动 | 对baseline Loss加高斯噪声 (s=0.01~0.02) | perturbation_results.json |\n"
547
+ "| 6 | 效用评估 | 300道数学测试题 | utility_results.json |\n"
548
+ "| 7 | 综合分析 | 隐私-效用权衡图 | 研究结论 |\n"
549
+ )
550
+
551
+ # ============================
552
+ # Tab 2: 数据展示
553
+ # ============================
554
+ with gr.Tab("数据展示"):
555
+ gr.Markdown(
556
+ "## 数据集概况\n\n"
557
+ "- **成员数据(训练集)**: 1000条,用于模型微调训练\n"
558
+ "- **非成员数据(测试集)**: 1000条,不参与训练,作为MIA攻击的对照组\n"
559
+ "- 每条数据均包含学生隐私字段(姓名、学号、班级、成绩),模拟真实教育场景\n"
560
+ )
561
+
562
+ with gr.Row():
563
+ with gr.Column(scale=1):
564
+ gr.Markdown("### 任务类型分布")
565
+ gr.Plot(value=make_pie_chart())
566
+ with gr.Column(scale=1):
567
+ gr.Markdown("### 随机查看样本")
568
+ data_sel = gr.Radio(
569
+ choices=["成员数据(训练集)", "非成员数据(测试集)"],
570
+ value="成员数据(训练集)",
571
+ label="数据类型"
572
+ )
573
+ sample_btn = gr.Button("随机抽取样本", variant="primary")
574
+
575
+ sample_info = gr.Markdown()
576
+ with gr.Row():
577
+ sample_q = gr.Textbox(label="学生提问", lines=6, interactive=False)
578
+ sample_a = gr.Textbox(label="模型回答", lines=6, interactive=False)
579
+
580
+ sample_btn.click(
581
+ fn=show_random_sample,
582
+ inputs=[data_sel],
583
+ outputs=[sample_info, sample_q, sample_a]
584
+ )
585
+
586
+ # ============================
587
+ # Tab 3: MIA攻击演示
588
+ # ============================
589
+ with gr.Tab("MIA攻击演示"):
590
+ gr.Markdown(
591
+ "## 成员推理攻击演示\n\n"
592
+ "**原理**: 模型对训练过的数据产生更低的Loss,"
593
+ "攻击者利用Loss与阈值的比较判断样本是否为训练成员。\n\n"
594
+ "1. 选择数据来源 (成员 / 非成员)\n"
595
+ "2. 拖动滑块选择样本编号\n"
596
+ "3. 点击 **执行攻击**\n"
597
+ )
598
+
599
+ with gr.Row():
600
+ with gr.Column(scale=1):
601
+ atk_data_type = gr.Radio(
602
+ choices=["成员数据(训练集)", "非成员数据(测试集)"],
603
+ value="成员数据(训练集)",
604
+ label="数据来源"
605
  )
606
+ atk_index = gr.Slider(
607
+ minimum=0, maximum=999, step=1, value=0,
608
+ label="样本编号 (0-999)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609
  )
610
+ atk_btn = gr.Button("执行MIA攻击", variant="primary", size="lg")
611
+ with gr.Column(scale=1):
612
+ atk_question = gr.Markdown()
613
+
614
+ atk_gauge = gr.Plot(label="Loss位置可视化")
615
+ atk_result = gr.Markdown()
616
+
617
+ atk_btn.click(
618
+ fn=run_mia_demo,
619
+ inputs=[atk_index, atk_data_type],
620
+ outputs=[atk_question, atk_gauge, atk_result]
621
+ )
622
+
623
+ # ============================
624
+ # Tab 4: 防御对比
625
+ # ============================
626
+ with gr.Tab("防御对比"):
627
+ gr.Markdown(
628
+ "## 防御策略效果对比\n\n"
629
+ "| 策略 | 类型 | 原理 | 优势 | 局限 |\n"
630
+ "|------|------|------|------|------|\n"
631
+ "| 标签平滑 | 训练期 | 软化one-hot标签,抑制过拟合 | 从根���降低模型记忆 | 可能影响模型效用 |\n"
632
+ "| 输出扰动 | 推理期 | 对输出Loss添加高斯噪声 | 零效用损失,即插即用 | 仅遮蔽统计信号 |\n"
633
+ )
634
+
635
+ with gr.Row():
636
+ with gr.Column():
637
+ gr.Markdown("### AUC对比(所有防御策略)")
638
+ gr.Plot(value=make_auc_bar())
639
+ with gr.Column():
640
+ gr.Markdown("### Loss分布对比")
641
+ gr.Plot(value=make_loss_distribution())
642
+
643
+ table = (
644
+ "### 实验结果汇总\n\n"
645
+ "| 策略 | 类型 | AUC | 风险等级 |\n"
646
+ "|------|------|-----|----------|\n"
647
+ )
648
+ for k, name, cat in [('baseline', '基线 (无防御)', '--'),
649
+ ('smooth_0.02', '标签平滑 (e=0.02)', '训练期'),
650
+ ('smooth_0.2', '标签平滑 (e=0.2)', '训练期')]:
651
+ if k in mia_results:
652
+ a = mia_results[k]['auc']
653
+ table += "| " + name + " | " + cat + " | " + f"{a:.4f}" + " | " + risk_badge(a) + " |\n"
654
+ for k, name in [('perturbation_0.01', '输出扰动 (s=0.01)'),
655
+ ('perturbation_0.015', '输出扰动 (s=0.015)'),
656
+ ('perturbation_0.02', '输出扰动 (s=0.02)')]:
657
+ if k in perturb_results:
658
+ a = perturb_results[k]['auc']
659
+ table += "| " + name + " | 推理期 | " + f"{a:.4f}" + " | " + risk_badge(a) + " |\n"
660
+ gr.Markdown(table)
661
+
662
+ # ============================
663
+ # Tab 5: 防御详解(标签平滑 + 输出扰动)
664
+ # ============================
665
+ with gr.Tab("防御详解"):
666
+ gr.Markdown(
667
+ "## 防御策略详解\n\n"
668
+ "---\n\n"
669
+ "### 一、标签平滑 (Label Smoothing)\n\n"
670
+ "**类型** : 训练期防御\n\n"
671
+ "**原理** : 将训练标签从硬标签 (one-hot) 转换为软标签,"
672
+ "降低模型对训练样本的过度拟合程度,从而缩小成员与非成员之间的Loss差异。\n\n"
673
+ "**公式** : y_smooth = (1 - e) * y_onehot + e / V\n\n"
674
+ "其中 e 为平滑系数,V 为词汇表大小。\n\n"
675
+ "| 参数 | AUC | 准确率 | 分析 |\n"
676
+ "|------|-----|--------|------|\n"
677
+ "| 基线 (e=0) | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}" + "% | 无防御,MIA风险较高 |\n"
678
+ "| e=0.02 | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}" + "% | 温和防御,效用保持良好 |\n"
679
+ "| e=0.2 | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}" + "% | 强力防御,AUC显著下降 |\n\n"
680
+ "---\n\n"
681
+ "### 二、输出扰动 (Output Perturbation)\n\n"
682
+ "**类型** : 推理期防御\n\n"
683
+ "**原理** : 在推理阶段对模型返回的Loss值添加高斯噪声,"
684
+ "模糊成员与非成员之间的统计差异,使攻击者难以准确判别。\n\n"
685
+ "**公式** : Loss_perturbed = Loss_original + N(0, s^2)\n\n"
686
+ "**核心优势** : 不修改模型参数,准确率完全不变。\n\n"
687
+ "| 参数 | AUC | AUC降幅 | 准确率 |\n"
688
+ "|------|-----|---------|--------|\n"
689
+ "| 基线 (s=0) | " + f"{bl_auc:.4f}" + " | -- | " + f"{bl_acc:.1f}" + "% |\n"
690
+ "| s=0.01 | " + f"{op001_auc:.4f}" + " | " + f"{bl_auc - op001_auc:.4f}" + " | " + f"{bl_acc:.1f}" + "% (不变) |\n"
691
+ "| s=0.015 | " + f"{op0015_auc:.4f}" + " | " + f"{bl_auc - op0015_auc:.4f}" + " | " + f"{bl_acc:.1f}" + "% (不变) |\n"
692
+ "| s=0.02 | " + f"{op002_auc:.4f}" + " | " + f"{bl_auc - op002_auc:.4f}" + " | " + f"{bl_acc:.1f}" + "% (不变) |\n\n"
693
+ "---\n\n"
694
+ "### 三、综合对比\n\n"
695
+ "| 维度 | 标签平滑 | 输出扰动 |\n"
696
+ "|------|---------|----------|\n"
697
+ "| 作用阶段 | 训练期 | 推理期 |\n"
698
+ "| 是否需要重训 | 是 | 否 |\n"
699
+ "| 对效用的影响 | 可能有影响 | 无影响 |\n"
700
+ "| 防御机制 | 降低过拟合 | 遮蔽统计信号 |\n"
701
+ "| 可叠加使用 | 是 | 是 |\n\n"
702
+ "> **推荐方案** : 标签平滑 (e=0.02) + 输出扰动 (s=0.02) 双重防护\n"
703
+ )
704
+
705
+ # ============================
706
+ # Tab 6: 效用评估
707
+ # ============================
708
+ with gr.Tab("效用评估"):
709
+ gr.Markdown(
710
+ "## 模型效用评估\n\n"
711
+ "> 测试集: 300道数学题,覆盖基础计算、应用题、概念问答三类任务。\n"
712
+ )
713
+
714
+ with gr.Row():
715
+ with gr.Column():
716
+ gr.Markdown("### 准确率对比")
717
+ gr.Plot(value=make_accuracy_bar())
718
+ with gr.Column():
719
+ gr.Markdown("### 隐私-效用权衡")
720
+ gr.Plot(value=make_tradeoff())
721
+
722
+ ut = (
723
+ "### 效用评估详情\n\n"
724
+ "| 策略 | 准确率 | AUC | 风险等级 | 效用影响 |\n"
725
+ "|------|--------|-----|---------|----------|\n"
726
+ )
727
+ for k, name in [('baseline', '基线'), ('smooth_0.02', '标签平滑 (e=0.02)'),
728
+ ('smooth_0.2', '标签平滑 (e=0.2)')]:
729
+ if k in utility_results and k in mia_results:
730
+ acc = utility_results[k]['accuracy'] * 100
731
+ auc = mia_results[k]['auc']
732
+ impact = "--" if k == 'baseline' else ("提升" if acc > bl_acc else "下降")
733
+ ut += "| " + name + " | " + f"{acc:.1f}" + "% | " + f"{auc:.4f}" + " | " + risk_badge(auc) + " | " + impact + " |\n"
734
+ for k, name in [('perturbation_0.01', '输出扰动 (s=0.01)'), ('perturbation_0.02', '输出扰动 (s=0.02)')]:
735
+ if k in perturb_results:
736
+ ut += "| " + name + " | " + f"{bl_acc:.1f}" + "% | " + f"{perturb_results[k]['auc']:.4f}" + " | " + risk_badge(perturb_results[k]['auc']) + " | 无影响 |\n"
737
+ gr.Markdown(ut)
738
+
739
+ # ============================
740
+ # Tab 7: 论文图表
741
+ # ============================
742
+ with gr.Tab("论文图表"):
743
+ gr.Markdown("## 学术图表 (300 DPI)")
744
+ for fn, cap in [("fig1_loss_distribution_comparison.png", "图1 : Loss分布对比"),
745
+ ("fig2_privacy_utility_tradeoff_fixed.png", "图2 : 隐私-效用权衡"),
746
+ ("fig3_defense_comparison_bar.png", "图3 : 防御效果柱状图")]:
747
+ path = os.path.join(BASE_DIR, "figures", fn)
748
+ if os.path.exists(path):
749
+ gr.Markdown("### " + cap)
750
+ gr.Image(value=path, show_label=False, height=420)
751
+ gr.Markdown("---")
752
+ else:
753
+ gr.Markdown("### " + cap + "\n\n> 文件未找到: " + fn)
754
+
755
+ # ============================
756
+ # Tab 8: 研究结论
757
+ # ============================
758
+ with gr.Tab("研究结论"):
759
+ gr.Markdown(
760
+ "## 研究结论\n\n"
761
+ "---\n\n"
762
+ "### 一、教育大模型面临显著的成员推理攻击风险\n\n"
763
+ "实验结果表明,基于Qwen2.5-Math-1.5B经LoRA微调的教育辅导模型,"
764
+ "在面对基于Loss的成员推理攻击时,AUC达到 **" + f"{bl_auc:.4f}" + "**,"
765
+ "显著高于随机猜测基准 (0.5)。这意味着攻击者仅通过观察模型对某一样本的输出置信度,"
766
+ "即可以高于随机的概率推断该样本是否被纳入训练集。"
767
+ "在教育场景中,训练数据通常包含学生的姓名、学号、学业成绩等敏感信息,"
768
+ "上述攻击能力构成了切实的隐私威胁。\n\n"
769
+ "---\n\n"
770
+ "### 二、标签平滑作为训练期防御策略的有效性与局限性\n\n"
771
+ "标签平滑通过软化训练标签分布,抑制模型对训练样本的过度拟合,"
772
+ "从而缩��成员与非成员之间的Loss分布差异。实验中:\n\n"
773
+ "- **e=0.02** (温和平滑): AUC从 " + f"{bl_auc:.4f}" + " 降至 " + f"{s002_auc:.4f}"
774
+ + ",准确率为 " + f"{s002_acc:.1f}" + "%,在隐私保护与效用保持之间取得了较好的平衡。\n"
775
+ "- **e=0.2** (强力平滑): AUC进一步降至 " + f"{s02_auc:.4f}"
776
+ + ",防御效果更为显著,准确率为 " + f"{s02_acc:.1f}" + "%。\n\n"
777
+ "该结果揭示了标签平滑系数的选取需在隐私保护强度与模型效用之间进行权衡。"
778
+ "过小的平滑系数防御效果有限,而过大的系数可能影响模型在下游任务上的表现。\n\n"
779
+ "---\n\n"
780
+ "### 三、输出扰动作为推理期防御策略的独特优势\n\n"
781
+ "输出扰动在推理阶段对模型输出的Loss值注入高斯噪声,"
782
+ "其核心优势在于**完全不改变模型参数**,因此对模型效用无任何影响。实验中:\n\n"
783
+ "- **s=0.02**: AUC从 " + f"{bl_auc:.4f}" + " 降至 " + f"{op002_auc:.4f}"
784
+ + ",而准确率保持 " + f"{bl_acc:.1f}" + "% 不变。\n\n"
785
+ "这表明输出扰动是一种**零效用成本**的防御手段,"
786
+ "特别适合已部署的模型系统进行后期隐私加固,具有良好的工程实用性。\n\n"
787
+ "---\n\n"
788
+ "### 四、隐私-效用权衡的定量分析\n\n"
789
+ "综合所有实验结果,本研究揭示了教育大模型隐私保护中的核心矛盾:\n\n"
790
+ "| 策略 | AUC (隐私风险) | 准确率 (效用) | 特点 |\n"
791
+ "|------|----------------|--------------|------|\n"
792
+ "| 基线 (无防御) | " + f"{bl_auc:.4f}" + " | " + f"{bl_acc:.1f}" + "% | 风险最高 |\n"
793
+ "| 标签平滑 e=0.02 | " + f"{s002_auc:.4f}" + " | " + f"{s002_acc:.1f}" + "% | 训练期防御,效用保持良好 |\n"
794
+ "| 标签平滑 e=0.2 | " + f"{s02_auc:.4f}" + " | " + f"{s02_acc:.1f}" + "% | 强力防御 |\n"
795
+ "| 输出扰动 s=0.02 | " + f"{op002_auc:.4f}" + " | " + f"{bl_acc:.1f}" + "% | 零效用损失 |\n\n"
796
+ "上述分析表明,将**训练期标签平滑** (e=0.02) 与**推理期输出扰动** (s=0.02) 组合使用,"
797
+ "可以在两个独立维度上削弱攻击者的推断能力,实现更为全面的隐私保护,"
798
+ "同时将效用损失控制在可接受范围内。\n"
799
+ )
800
+
801
+ # ============================
802
+ # 底部
803
+ # ============================
804
+ gr.Markdown(
805
+ "---\n\n"
806
+ "<center>\n\n"
807
+ "教育大模型中的成员推理攻击及其防御思路研究\n\n"
808
+ "Qwen2.5-Math-1.5B | LoRA | MIA | Label Smoothing | Output Perturbation\n\n"
809
+ "</center>\n"
810
+ )
811
 
812
+ # ========================================
813
+ # 5. 启动
814
+ # ========================================
815
  demo.launch()