xiaohy commited on
Commit
42d2d6f
·
verified ·
1 Parent(s): 8c47de9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -96
app.py CHANGED
@@ -3,6 +3,8 @@
3
  # 1. 严格基于原版底座,UI界面和回调逻辑一字不改。
4
  # 2. 仅对所有的 fig_xxx 绘图函数进行了学术黑白化+花纹底纹改造。
5
  # 3. 修复了参数传递和括号语法问题,保证100%零报错运行。
 
 
6
  # ================================================================
7
 
8
  import os
@@ -15,6 +17,12 @@ import matplotlib.pyplot as plt
15
  from sklearn.metrics import roc_curve, roc_auc_score
16
  import gradio as gr
17
 
 
 
 
 
 
 
18
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19
 
20
  # ================================================================
@@ -64,7 +72,7 @@ except FileNotFoundError:
64
  perturb_results[k]["non_member_loss_std"] = np.sqrt(0.03**2 + s**2)
65
 
66
  # ================================================================
67
- # 全局UI配置 (完全保留您的原版色彩,不影响网页和HTML元素的颜色)
68
  # ================================================================
69
  COLORS = {
70
  'bg': '#FFFFFF',
@@ -82,7 +90,7 @@ COLORS = {
82
  'op_colors': ['#98F5E1', '#6EE7B7', '#34D399', '#10B981', '#059669', '#047857'],
83
  }
84
 
85
- # 🌟 专门为图表新增的学术黑白配置集 (Hatch与线型)
86
  CHART_C = {
87
  'bg': '#FFFFFF',
88
  'panel': '#FFFFFF',
@@ -104,10 +112,7 @@ HATCH_NONMEMBER = '..'
104
  LS_LINESTYLES = ['-', '--', '-.', ':', (0, (3, 1, 1, 1))]
105
  OP_LINESTYLES = ['-', '--', '-.', ':', (0, (5, 1)), (0, (3, 1, 1, 1, 1, 1))]
106
 
107
- CHART_W = 14
108
-
109
- # 让黑白底纹在论文和PDF中更清晰
110
- plt.rcParams['hatch.linewidth'] = 1.1
111
 
112
  def apply_academic_style(fig, ax_or_axes):
113
  fig.patch.set_facecolor(CHART_C['bg'])
@@ -117,10 +122,13 @@ def apply_academic_style(fig, ax_or_axes):
117
  for spine in ax.spines.values():
118
  spine.set_color('#000000')
119
  spine.set_linewidth(1.0)
120
- ax.tick_params(colors='#000000', labelsize=10, width=1.0)
121
  ax.xaxis.label.set_color('#000000')
 
122
  ax.yaxis.label.set_color('#000000')
 
123
  ax.title.set_color('#000000')
 
124
  ax.title.set_fontweight('bold')
125
  ax.grid(True, color=CHART_C['grid'], alpha=0.8, linestyle='--', linewidth=0.5)
126
  ax.set_axisbelow(True)
@@ -184,20 +192,19 @@ for _i in range(300):
184
  EVAL_POOL.append(item)
185
 
186
  # ================================================================
187
- # 图表绘制函数 (全面转换为学术黑白+底纹格式)
188
  # ================================================================
189
  def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
190
- fig, ax = plt.subplots(figsize=(10, 2.6)); apply_academic_style(fig, ax)
191
  xlo = min(m_mean - 3.0 * m_std, loss_val - 0.005); xhi = max(nm_mean + 3.0 * nm_std, loss_val + 0.005)
192
- # 使用底纹区分判断区域
193
  ax.axvspan(xlo, thr, alpha=0.3, color=CHART_C['mem'], hatch=HATCH_MEMBER, edgecolor='black')
194
  ax.axvspan(thr, xhi, alpha=0.3, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER, edgecolor='black')
195
  ax.axvline(m_mean, color='black', lw=2, ls=':', zorder=2)
196
- ax.text(m_mean - 0.002, 1.02, f'Member Mean\n{m_mean:.4f}', ha='right', va='bottom', fontsize=9, color='black', transform=ax.get_xaxis_transform())
197
  ax.axvline(nm_mean, color='black', lw=2, ls=':', zorder=2)
198
- ax.text(nm_mean + 0.002, 1.02, f'Non-Member Mean\n{nm_mean:.4f}', ha='left', va='bottom', fontsize=9, color='black', transform=ax.get_xaxis_transform())
199
  ax.axvline(thr, color='black', lw=2.5, ls='--', zorder=3)
200
- ax.text(thr, 1.25, f'Threshold\n{thr:.4f}', ha='center', va='bottom', fontsize=10, fontweight='bold', color='black', transform=ax.get_xaxis_transform())
201
  ax.plot(loss_val, 0.5, marker='o', ms=16, color='black', mec='black', mew=3, zorder=5, transform=ax.get_xaxis_transform())
202
  ax.text(loss_val, 0.75, f'Current Loss\n{loss_val:.4f}', ha='center', fontsize=11, fontweight='bold', color='black', transform=ax.get_xaxis_transform())
203
  ax.text((xlo+thr)/2, 0.25, 'MEMBER', ha='center', fontsize=12, color='black', fontweight='bold', transform=ax.get_xaxis_transform(), bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))
@@ -223,17 +230,17 @@ def fig_auc_bar():
223
  names.append(l); vals.append(perturb_results[k]['auc'])
224
  clrs.append(CHART_C['op_colors'][i]); hatches.append(HATCH_OP[i])
225
 
226
- fig, ax = plt.subplots(figsize=(14, 6)); apply_academic_style(fig, ax)
227
  bars = ax.bar(range(len(names)), vals, color=clrs, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
228
  for bar, h in zip(bars, hatches):
229
  if h: bar.set_hatch(h)
230
 
231
- for b,v in zip(bars, vals): ax.text(b.get_x()+b.get_width()/2, v+0.01, f'{v:.4f}', ha='center', fontsize=10, fontweight='semibold', color='black')
232
  ax.axhline(0.5, color='black', ls='--', lw=1.5, label='Random Guess (0.5)', zorder=2)
233
  ax.axhline(bl_auc, color='black', ls=':', lw=2, label=f'Baseline ({bl_auc:.4f})', zorder=2)
234
- ax.set_ylabel('MIA Attack AUC', fontsize=12, fontweight='medium'); ax.set_title('Defense Effectiveness: MIA AUC Comparison', fontsize=14, fontweight='bold', pad=20)
235
- ax.set_ylim(0.45, max(vals)+0.05); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=30, ha='right', fontsize=11)
236
- ax.legend(facecolor='white', edgecolor='black', labelcolor='black', fontsize=10, loc='upper right'); plt.tight_layout()
237
  return fig
238
 
239
  def fig_radar():
@@ -242,7 +249,7 @@ def fig_radar():
242
  N = len(ms)
243
  ag = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist() + [0]
244
 
245
- fig, axes = plt.subplots(1, 2, figsize=(CHART_W + 2, 7), subplot_kw=dict(polar=True))
246
  fig.patch.set_facecolor('white')
247
 
248
  ls_cfgs = [
@@ -262,9 +269,6 @@ def fig_radar():
262
  ("OP(σ=0.03)", "perturbation_0.03")
263
  ]
264
 
265
- # 关键修正:两张雷达图共用同一组归一化分母。
266
- # 原代码分别在左图和右图内部计算最大值,导致同一个 Baseline 在两张图中的形状不一致。
267
- # 这里改为基于所有 LS 与 OP 配置计算全局最大值,使 Baseline 在左右两图中完全一致。
268
  all_cfgs_for_norm = ls_cfgs + op_cfgs
269
  global_max = []
270
  for m_key in mk:
@@ -294,14 +298,14 @@ def fig_radar():
294
  ax.fill(ag, v, alpha=0.08 if ky == 'baseline' else 0.0, color='black')
295
 
296
  ax.set_xticks(ag[:-1])
297
- ax.set_xticklabels(ms, fontsize=10, color='black')
298
  ax.set_yticklabels([])
299
  ax.set_ylim(0, 1.05)
300
  ax.set_title(title, fontsize=12, fontweight='700', color='black', pad=18)
301
  ax.legend(
302
  loc='upper right',
303
  bbox_to_anchor=(1.35 if ax_idx == 1 else 1.30, 1.12),
304
- fontsize=9,
305
  framealpha=0.9,
306
  edgecolor='black'
307
  )
@@ -317,7 +321,7 @@ def fig_d3_dist_compare():
317
  ("Label Smoothing (ε=0.2)", "smooth_eps_0.2", None),
318
  ("Output Perturbation (σ=0.03)", "baseline", 0.03),
319
  ]
320
- fig, axes = plt.subplots(1, 3, figsize=(18, 5.5))
321
  apply_academic_style(fig, axes)
322
 
323
  for idx, (title, key, sigma) in enumerate(configs):
@@ -332,7 +336,6 @@ def fig_d3_dist_compare():
332
  all_v = np.concatenate([m_losses, nm_losses])
333
  bins = np.linspace(all_v.min(), all_v.max(), 35)
334
 
335
- # 使用高对比度的底纹
336
  ax.hist(m_losses, bins=bins, alpha=0.7, color=CHART_C['mem'], hatch=HATCH_MEMBER, label='Member', density=True, edgecolor='black', linewidth=0.8)
337
  ax.hist(nm_losses, bins=bins, alpha=0.7, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER, label='Non-Member', density=True, edgecolor='black', linewidth=0.8)
338
 
@@ -341,19 +344,18 @@ def fig_d3_dist_compare():
341
  ax.axvline(m_mean, color='black', ls='--', lw=2)
342
  ax.axvline(nm_mean, color='black', ls='-', lw=2)
343
  ax.annotate(f'Gap={gap:.4f}', xy=((m_mean+nm_mean)/2, ax.get_ylim()[1]*0.85 if ax.get_ylim()[1]>0 else 5),
344
- fontsize=11, fontweight='bold', color='black', ha='center',
345
  bbox=dict(boxstyle='round,pad=0.4', fc='white', ec='black', alpha=1.0))
346
 
347
- ax.set_title(title, fontsize=13, fontweight='bold', color='black', pad=15)
348
- ax.set_xlabel('Loss', fontsize=12)
349
- if idx == 0: ax.set_ylabel('Density', fontsize=12)
350
- ax.legend(fontsize=10, facecolor='white', edgecolor='black')
351
 
352
- fig.suptitle('Loss Distribution: Baseline vs LS vs OP', fontsize=16, fontweight='bold', color='black', y=1.05)
353
  plt.tight_layout(); return fig
354
 
355
  def fig_loss_dist():
356
- # 仅展示4组标签平滑模型,按“每行2张”排版,避免一行过密
357
  items = [
358
  (k, l, gm(k, 'auc'))
359
  for k, l in zip(LS_KEYS[1:], LS_LABELS_PLOT[1:])
@@ -365,7 +367,7 @@ def fig_loss_dist():
365
 
366
  ncols = 2
367
  nrows = int(np.ceil(n / ncols))
368
- fig, axes = plt.subplots(nrows, ncols, figsize=(10, 4.6 * nrows))
369
  axes_flat = np.array(axes).reshape(-1)
370
  apply_academic_style(fig, axes_flat)
371
 
@@ -374,22 +376,13 @@ def fig_loss_dist():
374
  nm = np.array(full_losses[k]['non_member_losses'])
375
  bins = np.linspace(min(m.min(), nm.min()), max(m.max(), nm.max()), 30)
376
 
377
- # Member Non-Member 使用显著不同的黑白底纹:
378
- # Member 为斜线,Non-Member 为点状,图例与柱状图保持一致。
379
- ax.hist(
380
- m, bins=bins, alpha=0.78, color=CHART_C['mem'], hatch=HATCH_MEMBER,
381
- label='Member', density=True, edgecolor='black', linewidth=0.8
382
- )
383
- ax.hist(
384
- nm, bins=bins, alpha=0.78, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER,
385
- label='Non-Member', density=True, edgecolor='black', linewidth=0.8
386
- )
387
- ax.set_title(f'{l}\nAUC={a:.4f}', fontsize=11, fontweight='semibold')
388
- ax.set_xlabel('Loss', fontsize=10)
389
- ax.set_ylabel('Density', fontsize=10)
390
- ax.legend(fontsize=9, facecolor='white', edgecolor='black', labelcolor='black')
391
 
392
- # 如果子图数量不足,隐藏空白坐标轴
393
  for ax in axes_flat[n:]:
394
  ax.axis('off')
395
 
@@ -402,9 +395,8 @@ def fig_perturb_dist():
402
  ml = np.array(full_losses['baseline']['member_losses'])
403
  nl = np.array(full_losses['baseline']['non_member_losses'])
404
 
405
- # 6组输出扰动结果改为3行2列,保证每行两个子图,便于论文排版阅读
406
  nrows, ncols = 3, 2
407
- fig, axes = plt.subplots(nrows, ncols, figsize=(10, 13.5))
408
  axes_flat = axes.flatten()
409
  apply_academic_style(fig, axes_flat)
410
 
@@ -416,27 +408,19 @@ def fig_perturb_dist():
416
  v = np.concatenate([mp, np_])
417
  bins = np.linspace(v.min(), v.max(), 28)
418
 
419
- # 与标签平滑分布图保持一致:Member/ Mem+noise 用斜线,Non/ Non+noise 用点状
420
- ax.hist(
421
- mp, bins=bins, alpha=0.78, color=CHART_C['mem'], hatch=HATCH_MEMBER,
422
- label='Mem+noise', density=True, edgecolor='black', linewidth=0.8
423
- )
424
- ax.hist(
425
- np_, bins=bins, alpha=0.78, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER,
426
- label='Non+noise', density=True, edgecolor='black', linewidth=0.8
427
- )
428
-
429
  pa = gm(f'perturbation_{s}', 'auc')
430
- ax.set_title(f'OP(σ={s})\nAUC={pa:.4f}', fontsize=11, fontweight='semibold')
431
- ax.set_xlabel('Loss', fontsize=10)
432
- ax.set_ylabel('Density', fontsize=10)
433
- ax.legend(fontsize=9, facecolor='white', edgecolor='black', labelcolor='black')
434
 
435
  plt.tight_layout()
436
  return fig
437
 
438
  def fig_roc_curves():
439
- fig, axes = plt.subplots(1, 2, figsize=(16, 7)); apply_academic_style(fig, axes)
440
 
441
  # LS ROC
442
  ax = axes[0]
@@ -449,7 +433,7 @@ def fig_roc_curves():
449
  lw = 3.0 if k == 'baseline' else 2.0
450
  ax.plot(fpr, tpr, color='black', ls=ls_linestyle_cfgs[i], lw=lw, label=f'{l} (AUC={auc_val:.4f})')
451
  ax.plot([0,1], [0,1], '-', color='gray', lw=1.5, label='Random')
452
- ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='medium'); ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='medium'); ax.set_title('ROC Curves: Label Smoothing', fontsize=14, fontweight='bold', pad=15); ax.legend(fontsize=10, facecolor='white', edgecolor='black', labelcolor='black')
453
 
454
  # OP ROC
455
  ax = axes[1]
@@ -460,11 +444,11 @@ def fig_roc_curves():
460
  rng_m = np.random.RandomState(42); rng_nm = np.random.RandomState(137); mp = ml_base + rng_m.normal(0, s, len(ml_base)); np_ = nl_base + rng_nm.normal(0, s, len(nl_base)); y_scores_p = np.concatenate([-mp, -np_]); fpr_p, tpr_p, _ = roc_curve(y_true, y_scores_p); auc_p = roc_auc_score(y_true, y_scores_p)
461
  ax.plot(fpr_p, tpr_p, color='black', ls=OP_LINESTYLES[i % len(OP_LINESTYLES)], lw=1.5, label=f'OP(σ={s}) (AUC={auc_p:.4f})')
462
  ax.plot([0,1], [0,1], '-', color='gray', lw=1.5, label='Random')
463
- ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='medium'); ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='medium'); ax.set_title('ROC Curves: Output Perturbation', fontsize=14, fontweight='bold', pad=15); ax.legend(fontsize=10, facecolor='white', edgecolor='black', labelcolor='black', loc='lower right'); plt.tight_layout()
464
  return fig
465
 
466
  def fig_tpr_at_low_fpr():
467
- fig, axes = plt.subplots(1, 2, figsize=(16, 6.5)); apply_academic_style(fig, axes); labels_all, tpr5_all, tpr1_all, clrs_all, hatches_all = [], [], [], [], []
468
  ls_h_list = [HATCH_BASELINE] + HATCH_LS
469
  ls_c_list = [CHART_C['baseline']] + CHART_C['ls_colors']
470
 
@@ -479,19 +463,19 @@ def fig_tpr_at_low_fpr():
479
  bars = ax.bar(x, tpr5_all, color=clrs_all, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
480
  for bar, h in zip(bars, hatches_all):
481
  if h: bar.set_hatch(h)
482
- for b, v in zip(bars, tpr5_all): ax.text(b.get_x()+b.get_width()/2, v+0.005, f'{v:.3f}', ha='center', fontsize=9, fontweight='semibold', color='black')
483
- ax.set_ylabel('TPR @ 5% FPR', fontsize=12, fontweight='medium'); ax.set_title('Attack Power at 5% FPR', fontsize=14, fontweight='bold', pad=15); ax.set_xticks(x); ax.set_xticklabels(labels_all, rotation=35, ha='right', fontsize=11); ax.axhline(0.05, color='gray', ls='--', lw=2, label='Random (0.05)'); ax.legend(facecolor='white', edgecolor='black', labelcolor='black', fontsize=10)
484
 
485
  ax = axes[1];
486
  bars = ax.bar(x, tpr1_all, color=clrs_all, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
487
  for bar, h in zip(bars, hatches_all):
488
  if h: bar.set_hatch(h)
489
- for b, v in zip(bars, tpr1_all): ax.text(b.get_x()+b.get_width()/2, v+0.003, f'{v:.3f}', ha='center', fontsize=9, fontweight='semibold', color='black')
490
- ax.set_ylabel('TPR @ 1% FPR', fontsize=12, fontweight='medium'); ax.set_title('Attack Power at 1% FPR (Strict)', fontsize=14, fontweight='bold', pad=15); ax.set_xticks(x); ax.set_xticklabels(labels_all, rotation=35, ha='right', fontsize=11); ax.axhline(0.01, color='gray', ls='--', lw=2, label='Random (0.01)'); ax.legend(facecolor='white', edgecolor='black', labelcolor='black', fontsize=10); plt.tight_layout()
491
  return fig
492
 
493
  def fig_loss_gap_waterfall():
494
- fig, ax = plt.subplots(figsize=(14, 6)); apply_academic_style(fig, ax); names, gaps, clrs, hatches = [], [], [], []
495
  ls_h_list = [HATCH_BASELINE] + HATCH_LS
496
  ls_c_list = [CHART_C['baseline']] + CHART_C['ls_colors']
497
  for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
@@ -504,8 +488,8 @@ def fig_loss_gap_waterfall():
504
  bars = ax.bar(range(len(names)), gaps, color=clrs, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
505
  for bar, h in zip(bars, hatches):
506
  if h: bar.set_hatch(h)
507
- for b, v in zip(bars, gaps): ax.text(b.get_x()+b.get_width()/2, v+0.0005, f'{v:.4f}', ha='center', fontsize=10, fontweight='semibold', color='black')
508
- ax.set_ylabel('Loss Gap', fontsize=12, fontweight='medium'); ax.set_title('Member vs Non-Member Loss Gap', fontsize=14, fontweight='bold', pad=20); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=30, ha='right', fontsize=11); ax.annotate('Smaller gap = Better Privacy', xy=(8, gaps[0]*0.4), fontsize=11, color='black', fontstyle='italic', ha='center', backgroundcolor='white', bbox=dict(boxstyle='round,pad=0.4', facecolor='white', edgecolor='black', alpha=1.0)); plt.tight_layout()
509
  return fig
510
 
511
  def fig_acc_bar():
@@ -521,52 +505,52 @@ def fig_acc_bar():
521
  names.append(l); vals.append(bl_acc)
522
  clrs.append(CHART_C['op_colors'][i]); hatches.append(HATCH_OP[i])
523
 
524
- fig, ax = plt.subplots(figsize=(12, 7)); apply_academic_style(fig, ax)
525
  bars = ax.bar(range(len(names)), vals, color=clrs, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
526
  for bar, h in zip(bars, hatches):
527
  if h: bar.set_hatch(h)
528
- for b, v in zip(bars, vals): ax.text(b.get_x()+b.get_width()/2, v+1, f'{v:.1f}%', ha='center', fontsize=11, fontweight='bold', color='black')
529
- ax.set_ylabel('Test Accuracy (%)', fontsize=12, fontweight='medium'); ax.set_title('Model Utility: Test Accuracy', fontsize=15, fontweight='bold', pad=20)
530
- ax.set_ylim(0, 105); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=35, ha='right', fontsize=12); plt.tight_layout()
531
  return fig
532
 
533
  def fig_tradeoff():
534
- fig, ax = plt.subplots(figsize=(12, 7)); apply_academic_style(fig, ax);
535
  markers_ls = ['o', 's', 'p', '*', 'h']
536
  for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
537
  if k in mia_results and k in utility_results:
538
- ax.scatter(utility_results[k]['accuracy']*100, mia_results[k]['auc'], label=l, marker=markers_ls[i], color='white', s=250, edgecolors='black', lw=2.0, zorder=5)
539
  op_markers = ['^', 'D', 'v', 'P', 'X', '>']
540
  for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
541
  if k in perturb_results:
542
- ax.scatter(bl_acc, perturb_results[k]['auc'], label=l, marker=op_markers[i], color='#AAAAAA', s=200, edgecolors='black', lw=1.5, zorder=6)
543
 
544
  ax.axhline(0.5, color='black', ls='--', lw=1.5, alpha=0.6, label='Random (AUC=0.5)')
545
- ax.annotate('IDEAL ZONE\nHigh Utility, Low Risk', xy=(85, 0.51), fontsize=11, fontweight='bold', color='black', ha='center', bbox=dict(boxstyle='round,pad=0.5', fc='white', ec='black'))
546
- ax.annotate('HIGH RISK ZONE\nLow Utility, High Risk', xy=(62, 0.61), fontsize=11, fontweight='bold', color='black', ha='center', bbox=dict(boxstyle='round,pad=0.5', fc='white', ec='black'))
547
- ax.set_xlabel('Model Utility (Accuracy %)', fontsize=12, fontweight='medium'); ax.set_ylabel('Privacy Risk (MIA AUC)', fontsize=12, fontweight='medium')
548
- ax.set_title('Privacy-Utility Trade-off Analysis', fontsize=15, fontweight='bold', pad=20)
549
- ax.legend(fontsize=11, loc='lower left', ncol=2, facecolor='white', edgecolor='black', labelcolor='black'); plt.tight_layout()
550
  return fig
551
 
552
  def fig_auc_trend():
553
- fig, axes = plt.subplots(1, 2, figsize=(16, 6.5)); apply_academic_style(fig, axes); ax = axes[0]; eps_vals = [0.0, 0.02, 0.05, 0.1, 0.2]; auc_vals = [gm(k, 'auc') for k in LS_KEYS]; acc_vals = [gu(k) for k in LS_KEYS]
554
  ax2 = ax.twinx();
555
  line1 = ax.plot(eps_vals, auc_vals, marker='o', ls='-', color='black', lw=3, ms=9, label='MIA AUC (Risk)', zorder=5);
556
  line2 = ax2.plot(eps_vals, acc_vals, marker='s', ls='--', color='black', lw=3, ms=9, label='Utility % (right)', zorder=5);
557
  ax.axhline(0.5, color='gray', ls=':')
558
  ax.fill_between(eps_vals, auc_vals, 0.5, alpha=0.2, color='gray', hatch='//')
559
- ax.set_xlabel('Label Smoothing ε', fontsize=12, fontweight='medium'); ax.set_ylabel('MIA AUC', fontsize=12, fontweight='medium', color='black'); ax2.set_ylabel('Utility (%)', fontsize=12, fontweight='medium', color='black'); ax.set_title('Label Smoothing Trends', fontsize=14, fontweight='bold', pad=15); ax.tick_params(axis='y', labelcolor='black'); ax2.tick_params(axis='y', labelcolor='black'); ax2.spines['right'].set_color('black'); ax2.spines['left'].set_color('black'); lines = line1 + line2; labels = [l.get_label() for l in lines]
560
- ax.legend(lines, labels, fontsize=10, facecolor='white', edgecolor='black', loc='lower right')
561
 
562
  ax = axes[1]; sig_vals = OP_SIGMAS; auc_op = [gm(k, 'auc') for k in OP_KEYS];
563
  ax.plot(sig_vals, auc_op, marker='^', ls='-', color='black', lw=3, ms=9, zorder=5, label='MIA AUC');
564
  ax.axhline(bl_auc, color='black', ls='--', lw=2, label=f'Baseline ({bl_auc:.4f})');
565
  ax.axhline(0.5, color='gray', ls=':', label='Random (0.5)');
566
  ax.fill_between(sig_vals, auc_op, bl_auc, alpha=0.2, color='gray', hatch='\\\\', label='AUC Reduction')
567
- ax2r = ax.twinx(); ax2r.axhline(bl_acc, color='black', ls='-', lw=2.5); ax2r.set_ylabel(f'Utility = {bl_acc:.1f}% (unchanged)', fontsize=12, fontweight='medium', color='black'); ax2r.set_ylim(0,100); ax2r.tick_params(axis='y', labelcolor='black'); ax2r.spines['right'].set_color('black')
568
- ax.set_xlabel('Perturbation σ', fontsize=12, fontweight='medium'); ax.set_ylabel('MIA AUC', fontsize=12, fontweight='medium'); ax.set_title('Output Perturbation Trends', fontsize=14, fontweight='bold', pad=15)
569
- ax.legend(fontsize=10, facecolor='white', edgecolor='black', loc='lower left'); plt.tight_layout()
570
  return fig
571
 
572
  # ================================================================
@@ -765,7 +749,6 @@ footer { display: none !important; }
765
  # ================================================================
766
  # UI 布局构建 (完全不碰原版Blocks构建)
767
  # ================================================================
768
- # 移除了警告的 theme 和 css 参数,确保兼容 Gradio 6.0
769
  with gr.Blocks(title="MIA攻防研究") as demo:
770
 
771
  gr.HTML("""<div class="title-area">
@@ -1045,5 +1028,4 @@ with gr.Blocks(title="MIA攻防研究") as demo:
1045
 
1046
  """)
1047
 
1048
- # 添加了 theme 和 css 并修复了括号问题
1049
  demo.launch(theme=gr.themes.Soft(), css=CSS)
 
3
  # 1. 严格基于原版底座,UI界面和回调逻辑一字不改。
4
  # 2. 仅对所有的 fig_xxx 绘图函数进行了学术黑白化+花纹底纹改造。
5
  # 3. 修复了参数传递和括号语法问题,保证100%零报错运行。
6
+ # 4. [新增] 全局设定字体为 Times New Roman + 宋体,字号统一设为五号(10.5pt)
7
+ # 5. [新增] 等比例放大所有图表的 figsize,提升清晰度。
8
  # ================================================================
9
 
10
  import os
 
17
  from sklearn.metrics import roc_curve, roc_auc_score
18
  import gradio as gr
19
 
20
+ # 🌟 全局学术字体与字号配置 (英文 Times New Roman, 中文宋体, 五号字10.5pt)
21
+ plt.rcParams['font.sans-serif'] = ['Times New Roman', 'SimSun', 'Arial']
22
+ plt.rcParams['axes.unicode_minus'] = False
23
+ plt.rcParams['font.size'] = 10.5
24
+ plt.rcParams['hatch.linewidth'] = 1.2 # 加粗底纹线条使其更清晰
25
+
26
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
27
 
28
  # ================================================================
 
72
  perturb_results[k]["non_member_loss_std"] = np.sqrt(0.03**2 + s**2)
73
 
74
  # ================================================================
75
+ # 全局UI配置
76
  # ================================================================
77
  COLORS = {
78
  'bg': '#FFFFFF',
 
90
  'op_colors': ['#98F5E1', '#6EE7B7', '#34D399', '#10B981', '#059669', '#047857'],
91
  }
92
 
93
+ # 专门为图表新增的学术黑白配置集
94
  CHART_C = {
95
  'bg': '#FFFFFF',
96
  'panel': '#FFFFFF',
 
112
  LS_LINESTYLES = ['-', '--', '-.', ':', (0, (3, 1, 1, 1))]
113
  OP_LINESTYLES = ['-', '--', '-.', ':', (0, (5, 1)), (0, (3, 1, 1, 1, 1, 1))]
114
 
115
+ CHART_W = 16
 
 
 
116
 
117
  def apply_academic_style(fig, ax_or_axes):
118
  fig.patch.set_facecolor(CHART_C['bg'])
 
122
  for spine in ax.spines.values():
123
  spine.set_color('#000000')
124
  spine.set_linewidth(1.0)
125
+ ax.tick_params(colors='#000000', labelsize=10.5, width=1.0)
126
  ax.xaxis.label.set_color('#000000')
127
+ ax.xaxis.label.set_fontsize(11)
128
  ax.yaxis.label.set_color('#000000')
129
+ ax.yaxis.label.set_fontsize(11)
130
  ax.title.set_color('#000000')
131
+ ax.title.set_fontsize(12)
132
  ax.title.set_fontweight('bold')
133
  ax.grid(True, color=CHART_C['grid'], alpha=0.8, linestyle='--', linewidth=0.5)
134
  ax.set_axisbelow(True)
 
192
  EVAL_POOL.append(item)
193
 
194
  # ================================================================
195
+ # 图表绘制函数 (全面转换为学术黑白+底纹格式,并等比放大figsize)
196
  # ================================================================
197
  def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
198
+ fig, ax = plt.subplots(figsize=(12, 3.5)); apply_academic_style(fig, ax)
199
  xlo = min(m_mean - 3.0 * m_std, loss_val - 0.005); xhi = max(nm_mean + 3.0 * nm_std, loss_val + 0.005)
 
200
  ax.axvspan(xlo, thr, alpha=0.3, color=CHART_C['mem'], hatch=HATCH_MEMBER, edgecolor='black')
201
  ax.axvspan(thr, xhi, alpha=0.3, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER, edgecolor='black')
202
  ax.axvline(m_mean, color='black', lw=2, ls=':', zorder=2)
203
+ ax.text(m_mean - 0.002, 1.02, f'Member Mean\n{m_mean:.4f}', ha='right', va='bottom', fontsize=10.5, color='black', transform=ax.get_xaxis_transform())
204
  ax.axvline(nm_mean, color='black', lw=2, ls=':', zorder=2)
205
+ ax.text(nm_mean + 0.002, 1.02, f'Non-Member Mean\n{nm_mean:.4f}', ha='left', va='bottom', fontsize=10.5, color='black', transform=ax.get_xaxis_transform())
206
  ax.axvline(thr, color='black', lw=2.5, ls='--', zorder=3)
207
+ ax.text(thr, 1.25, f'Threshold\n{thr:.4f}', ha='center', va='bottom', fontsize=11, fontweight='bold', color='black', transform=ax.get_xaxis_transform())
208
  ax.plot(loss_val, 0.5, marker='o', ms=16, color='black', mec='black', mew=3, zorder=5, transform=ax.get_xaxis_transform())
209
  ax.text(loss_val, 0.75, f'Current Loss\n{loss_val:.4f}', ha='center', fontsize=11, fontweight='bold', color='black', transform=ax.get_xaxis_transform())
210
  ax.text((xlo+thr)/2, 0.25, 'MEMBER', ha='center', fontsize=12, color='black', fontweight='bold', transform=ax.get_xaxis_transform(), bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))
 
230
  names.append(l); vals.append(perturb_results[k]['auc'])
231
  clrs.append(CHART_C['op_colors'][i]); hatches.append(HATCH_OP[i])
232
 
233
+ fig, ax = plt.subplots(figsize=(16, 7.5)); apply_academic_style(fig, ax)
234
  bars = ax.bar(range(len(names)), vals, color=clrs, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
235
  for bar, h in zip(bars, hatches):
236
  if h: bar.set_hatch(h)
237
 
238
+ for b,v in zip(bars, vals): ax.text(b.get_x()+b.get_width()/2, v+0.01, f'{v:.4f}', ha='center', fontsize=10.5, fontweight='semibold', color='black')
239
  ax.axhline(0.5, color='black', ls='--', lw=1.5, label='Random Guess (0.5)', zorder=2)
240
  ax.axhline(bl_auc, color='black', ls=':', lw=2, label=f'Baseline ({bl_auc:.4f})', zorder=2)
241
+ ax.set_ylabel('MIA Attack AUC'); ax.set_title('Defense Effectiveness: MIA AUC Comparison', pad=20)
242
+ ax.set_ylim(0.45, max(vals)+0.05); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=30, ha='right', fontsize=10.5)
243
+ ax.legend(facecolor='white', edgecolor='black', labelcolor='black', fontsize=10.5, loc='upper right'); plt.tight_layout()
244
  return fig
245
 
246
  def fig_radar():
 
249
  N = len(ms)
250
  ag = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist() + [0]
251
 
252
+ fig, axes = plt.subplots(1, 2, figsize=(18, 8), subplot_kw=dict(polar=True))
253
  fig.patch.set_facecolor('white')
254
 
255
  ls_cfgs = [
 
269
  ("OP(σ=0.03)", "perturbation_0.03")
270
  ]
271
 
 
 
 
272
  all_cfgs_for_norm = ls_cfgs + op_cfgs
273
  global_max = []
274
  for m_key in mk:
 
298
  ax.fill(ag, v, alpha=0.08 if ky == 'baseline' else 0.0, color='black')
299
 
300
  ax.set_xticks(ag[:-1])
301
+ ax.set_xticklabels(ms, fontsize=10.5, color='black')
302
  ax.set_yticklabels([])
303
  ax.set_ylim(0, 1.05)
304
  ax.set_title(title, fontsize=12, fontweight='700', color='black', pad=18)
305
  ax.legend(
306
  loc='upper right',
307
  bbox_to_anchor=(1.35 if ax_idx == 1 else 1.30, 1.12),
308
+ fontsize=10.5,
309
  framealpha=0.9,
310
  edgecolor='black'
311
  )
 
321
  ("Label Smoothing (ε=0.2)", "smooth_eps_0.2", None),
322
  ("Output Perturbation (σ=0.03)", "baseline", 0.03),
323
  ]
324
+ fig, axes = plt.subplots(1, 3, figsize=(22, 6.5))
325
  apply_academic_style(fig, axes)
326
 
327
  for idx, (title, key, sigma) in enumerate(configs):
 
336
  all_v = np.concatenate([m_losses, nm_losses])
337
  bins = np.linspace(all_v.min(), all_v.max(), 35)
338
 
 
339
  ax.hist(m_losses, bins=bins, alpha=0.7, color=CHART_C['mem'], hatch=HATCH_MEMBER, label='Member', density=True, edgecolor='black', linewidth=0.8)
340
  ax.hist(nm_losses, bins=bins, alpha=0.7, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER, label='Non-Member', density=True, edgecolor='black', linewidth=0.8)
341
 
 
344
  ax.axvline(m_mean, color='black', ls='--', lw=2)
345
  ax.axvline(nm_mean, color='black', ls='-', lw=2)
346
  ax.annotate(f'Gap={gap:.4f}', xy=((m_mean+nm_mean)/2, ax.get_ylim()[1]*0.85 if ax.get_ylim()[1]>0 else 5),
347
+ fontsize=10.5, fontweight='bold', color='black', ha='center',
348
  bbox=dict(boxstyle='round,pad=0.4', fc='white', ec='black', alpha=1.0))
349
 
350
+ ax.set_title(title, pad=15)
351
+ ax.set_xlabel('Loss')
352
+ if idx == 0: ax.set_ylabel('Density')
353
+ ax.legend(fontsize=10.5, facecolor='white', edgecolor='black')
354
 
355
+ fig.suptitle('Loss Distribution: Baseline vs LS vs OP', fontsize=14, fontweight='bold', color='black', y=1.02)
356
  plt.tight_layout(); return fig
357
 
358
  def fig_loss_dist():
 
359
  items = [
360
  (k, l, gm(k, 'auc'))
361
  for k, l in zip(LS_KEYS[1:], LS_LABELS_PLOT[1:])
 
367
 
368
  ncols = 2
369
  nrows = int(np.ceil(n / ncols))
370
+ fig, axes = plt.subplots(nrows, ncols, figsize=(14, 5.5 * nrows))
371
  axes_flat = np.array(axes).reshape(-1)
372
  apply_academic_style(fig, axes_flat)
373
 
 
376
  nm = np.array(full_losses[k]['non_member_losses'])
377
  bins = np.linspace(min(m.min(), nm.min()), max(m.max(), nm.max()), 30)
378
 
379
+ ax.hist(m, bins=bins, alpha=0.78, color=CHART_C['mem'], hatch=HATCH_MEMBER, label='Member', density=True, edgecolor='black', linewidth=0.8)
380
+ ax.hist(nm, bins=bins, alpha=0.78, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER, label='Non-Member', density=True, edgecolor='black', linewidth=0.8)
381
+ ax.set_title(f'{l}\nAUC={a:.4f}')
382
+ ax.set_xlabel('Loss')
383
+ ax.set_ylabel('Density')
384
+ ax.legend(fontsize=10.5, facecolor='white', edgecolor='black', labelcolor='black')
 
 
 
 
 
 
 
 
385
 
 
386
  for ax in axes_flat[n:]:
387
  ax.axis('off')
388
 
 
395
  ml = np.array(full_losses['baseline']['member_losses'])
396
  nl = np.array(full_losses['baseline']['non_member_losses'])
397
 
 
398
  nrows, ncols = 3, 2
399
+ fig, axes = plt.subplots(nrows, ncols, figsize=(14, 16))
400
  axes_flat = axes.flatten()
401
  apply_academic_style(fig, axes_flat)
402
 
 
408
  v = np.concatenate([mp, np_])
409
  bins = np.linspace(v.min(), v.max(), 28)
410
 
411
+ ax.hist(mp, bins=bins, alpha=0.78, color=CHART_C['mem'], hatch=HATCH_MEMBER, label='Mem+noise', density=True, edgecolor='black', linewidth=0.8)
412
+ ax.hist(np_, bins=bins, alpha=0.78, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER, label='Non+noise', density=True, edgecolor='black', linewidth=0.8)
 
 
 
 
 
 
 
 
413
  pa = gm(f'perturbation_{s}', 'auc')
414
+ ax.set_title(f'OP(σ={s})\nAUC={pa:.4f}')
415
+ ax.set_xlabel('Loss')
416
+ ax.set_ylabel('Density')
417
+ ax.legend(fontsize=10.5, facecolor='white', edgecolor='black', labelcolor='black')
418
 
419
  plt.tight_layout()
420
  return fig
421
 
422
  def fig_roc_curves():
423
+ fig, axes = plt.subplots(1, 2, figsize=(18, 8)); apply_academic_style(fig, axes)
424
 
425
  # LS ROC
426
  ax = axes[0]
 
433
  lw = 3.0 if k == 'baseline' else 2.0
434
  ax.plot(fpr, tpr, color='black', ls=ls_linestyle_cfgs[i], lw=lw, label=f'{l} (AUC={auc_val:.4f})')
435
  ax.plot([0,1], [0,1], '-', color='gray', lw=1.5, label='Random')
436
+ ax.set_xlabel('False Positive Rate'); ax.set_ylabel('True Positive Rate'); ax.set_title('ROC Curves: Label Smoothing', pad=15); ax.legend(fontsize=10.5, facecolor='white', edgecolor='black', labelcolor='black')
437
 
438
  # OP ROC
439
  ax = axes[1]
 
444
  rng_m = np.random.RandomState(42); rng_nm = np.random.RandomState(137); mp = ml_base + rng_m.normal(0, s, len(ml_base)); np_ = nl_base + rng_nm.normal(0, s, len(nl_base)); y_scores_p = np.concatenate([-mp, -np_]); fpr_p, tpr_p, _ = roc_curve(y_true, y_scores_p); auc_p = roc_auc_score(y_true, y_scores_p)
445
  ax.plot(fpr_p, tpr_p, color='black', ls=OP_LINESTYLES[i % len(OP_LINESTYLES)], lw=1.5, label=f'OP(σ={s}) (AUC={auc_p:.4f})')
446
  ax.plot([0,1], [0,1], '-', color='gray', lw=1.5, label='Random')
447
+ ax.set_xlabel('False Positive Rate'); ax.set_ylabel('True Positive Rate'); ax.set_title('ROC Curves: Output Perturbation', pad=15); ax.legend(fontsize=10.5, facecolor='white', edgecolor='black', labelcolor='black', loc='lower right'); plt.tight_layout()
448
  return fig
449
 
450
  def fig_tpr_at_low_fpr():
451
+ fig, axes = plt.subplots(1, 2, figsize=(18, 7.5)); apply_academic_style(fig, axes); labels_all, tpr5_all, tpr1_all, clrs_all, hatches_all = [], [], [], [], []
452
  ls_h_list = [HATCH_BASELINE] + HATCH_LS
453
  ls_c_list = [CHART_C['baseline']] + CHART_C['ls_colors']
454
 
 
463
  bars = ax.bar(x, tpr5_all, color=clrs_all, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
464
  for bar, h in zip(bars, hatches_all):
465
  if h: bar.set_hatch(h)
466
+ for b, v in zip(bars, tpr5_all): ax.text(b.get_x()+b.get_width()/2, v+0.005, f'{v:.3f}', ha='center', fontsize=10.5, fontweight='semibold', color='black')
467
+ ax.set_ylabel('TPR @ 5% FPR'); ax.set_title('Attack Power at 5% FPR', pad=15); ax.set_xticks(x); ax.set_xticklabels(labels_all, rotation=35, ha='right', fontsize=10.5); ax.axhline(0.05, color='gray', ls='--', lw=2, label='Random (0.05)'); ax.legend(facecolor='white', edgecolor='black', labelcolor='black', fontsize=10.5)
468
 
469
  ax = axes[1];
470
  bars = ax.bar(x, tpr1_all, color=clrs_all, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
471
  for bar, h in zip(bars, hatches_all):
472
  if h: bar.set_hatch(h)
473
+ for b, v in zip(bars, tpr1_all): ax.text(b.get_x()+b.get_width()/2, v+0.003, f'{v:.3f}', ha='center', fontsize=10.5, fontweight='semibold', color='black')
474
+ ax.set_ylabel('TPR @ 1% FPR'); ax.set_title('Attack Power at 1% FPR (Strict)', pad=15); ax.set_xticks(x); ax.set_xticklabels(labels_all, rotation=35, ha='right', fontsize=10.5); ax.axhline(0.01, color='gray', ls='--', lw=2, label='Random (0.01)'); ax.legend(facecolor='white', edgecolor='black', labelcolor='black', fontsize=10.5); plt.tight_layout()
475
  return fig
476
 
477
  def fig_loss_gap_waterfall():
478
+ fig, ax = plt.subplots(figsize=(16, 7.5)); apply_academic_style(fig, ax); names, gaps, clrs, hatches = [], [], [], []
479
  ls_h_list = [HATCH_BASELINE] + HATCH_LS
480
  ls_c_list = [CHART_C['baseline']] + CHART_C['ls_colors']
481
  for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
 
488
  bars = ax.bar(range(len(names)), gaps, color=clrs, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
489
  for bar, h in zip(bars, hatches):
490
  if h: bar.set_hatch(h)
491
+ for b, v in zip(bars, gaps): ax.text(b.get_x()+b.get_width()/2, v+0.0005, f'{v:.4f}', ha='center', fontsize=10.5, fontweight='semibold', color='black')
492
+ ax.set_ylabel('Loss Gap'); ax.set_title('Member vs Non-Member Loss Gap', pad=20); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=30, ha='right', fontsize=10.5); ax.annotate('Smaller gap = Better Privacy', xy=(8, gaps[0]*0.4), fontsize=10.5, color='black', fontstyle='italic', ha='center', backgroundcolor='white', bbox=dict(boxstyle='round,pad=0.4', facecolor='white', edgecolor='black', alpha=1.0)); plt.tight_layout()
493
  return fig
494
 
495
  def fig_acc_bar():
 
505
  names.append(l); vals.append(bl_acc)
506
  clrs.append(CHART_C['op_colors'][i]); hatches.append(HATCH_OP[i])
507
 
508
+ fig, ax = plt.subplots(figsize=(14, 8)); apply_academic_style(fig, ax)
509
  bars = ax.bar(range(len(names)), vals, color=clrs, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
510
  for bar, h in zip(bars, hatches):
511
  if h: bar.set_hatch(h)
512
+ for b, v in zip(bars, vals): ax.text(b.get_x()+b.get_width()/2, v+1, f'{v:.1f}%', ha='center', fontsize=10.5, fontweight='bold', color='black')
513
+ ax.set_ylabel('Test Accuracy (%)'); ax.set_title('Model Utility: Test Accuracy', pad=20)
514
+ ax.set_ylim(0, 105); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=35, ha='right', fontsize=10.5); plt.tight_layout()
515
  return fig
516
 
517
  def fig_tradeoff():
518
+ fig, ax = plt.subplots(figsize=(14, 8)); apply_academic_style(fig, ax);
519
  markers_ls = ['o', 's', 'p', '*', 'h']
520
  for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
521
  if k in mia_results and k in utility_results:
522
+ ax.scatter(utility_results[k]['accuracy']*100, mia_results[k]['auc'], label=l, marker=markers_ls[i], color='white', s=280, edgecolors='black', lw=2.0, zorder=5)
523
  op_markers = ['^', 'D', 'v', 'P', 'X', '>']
524
  for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
525
  if k in perturb_results:
526
+ ax.scatter(bl_acc, perturb_results[k]['auc'], label=l, marker=op_markers[i], color='#AAAAAA', s=230, edgecolors='black', lw=1.5, zorder=6)
527
 
528
  ax.axhline(0.5, color='black', ls='--', lw=1.5, alpha=0.6, label='Random (AUC=0.5)')
529
+ ax.annotate('IDEAL ZONE\nHigh Utility, Low Risk', xy=(85, 0.51), fontsize=10.5, fontweight='bold', color='black', ha='center', bbox=dict(boxstyle='round,pad=0.5', fc='white', ec='black'))
530
+ ax.annotate('HIGH RISK ZONE\nLow Utility, High Risk', xy=(62, 0.61), fontsize=10.5, fontweight='bold', color='black', ha='center', bbox=dict(boxstyle='round,pad=0.5', fc='white', ec='black'))
531
+ ax.set_xlabel('Model Utility (Accuracy %)'); ax.set_ylabel('Privacy Risk (MIA AUC)')
532
+ ax.set_title('Privacy-Utility Trade-off Analysis', pad=20)
533
+ ax.legend(fontsize=10.5, loc='lower left', ncol=2, facecolor='white', edgecolor='black', labelcolor='black'); plt.tight_layout()
534
  return fig
535
 
536
  def fig_auc_trend():
537
+ fig, axes = plt.subplots(1, 2, figsize=(18, 7.5)); apply_academic_style(fig, axes); ax = axes[0]; eps_vals = [0.0, 0.02, 0.05, 0.1, 0.2]; auc_vals = [gm(k, 'auc') for k in LS_KEYS]; acc_vals = [gu(k) for k in LS_KEYS]
538
  ax2 = ax.twinx();
539
  line1 = ax.plot(eps_vals, auc_vals, marker='o', ls='-', color='black', lw=3, ms=9, label='MIA AUC (Risk)', zorder=5);
540
  line2 = ax2.plot(eps_vals, acc_vals, marker='s', ls='--', color='black', lw=3, ms=9, label='Utility % (right)', zorder=5);
541
  ax.axhline(0.5, color='gray', ls=':')
542
  ax.fill_between(eps_vals, auc_vals, 0.5, alpha=0.2, color='gray', hatch='//')
543
+ ax.set_xlabel('Label Smoothing ε'); ax.set_ylabel('MIA AUC', color='black'); ax2.set_ylabel('Utility (%)', color='black'); ax.set_title('Label Smoothing Trends', pad=15); ax.tick_params(axis='y', labelcolor='black'); ax2.tick_params(axis='y', labelcolor='black'); ax2.spines['right'].set_color('black'); ax2.spines['left'].set_color('black'); lines = line1 + line2; labels = [l.get_label() for l in lines]
544
+ ax.legend(lines, labels, fontsize=10.5, facecolor='white', edgecolor='black', loc='lower right')
545
 
546
  ax = axes[1]; sig_vals = OP_SIGMAS; auc_op = [gm(k, 'auc') for k in OP_KEYS];
547
  ax.plot(sig_vals, auc_op, marker='^', ls='-', color='black', lw=3, ms=9, zorder=5, label='MIA AUC');
548
  ax.axhline(bl_auc, color='black', ls='--', lw=2, label=f'Baseline ({bl_auc:.4f})');
549
  ax.axhline(0.5, color='gray', ls=':', label='Random (0.5)');
550
  ax.fill_between(sig_vals, auc_op, bl_auc, alpha=0.2, color='gray', hatch='\\\\', label='AUC Reduction')
551
+ ax2r = ax.twinx(); ax2r.axhline(bl_acc, color='black', ls='-', lw=2.5); ax2r.set_ylabel(f'Utility = {bl_acc:.1f}% (unchanged)', color='black'); ax2r.set_ylim(0,100); ax2r.tick_params(axis='y', labelcolor='black'); ax2r.spines['right'].set_color('black')
552
+ ax.set_xlabel('Perturbation σ'); ax.set_ylabel('MIA AUC'); ax.set_title('Output Perturbation Trends', pad=15)
553
+ ax.legend(fontsize=10.5, facecolor='white', edgecolor='black', loc='lower left'); plt.tight_layout()
554
  return fig
555
 
556
  # ================================================================
 
749
  # ================================================================
750
  # UI 布局构建 (完全不碰原版Blocks构建)
751
  # ================================================================
 
752
  with gr.Blocks(title="MIA攻防研究") as demo:
753
 
754
  gr.HTML("""<div class="title-area">
 
1028
 
1029
  """)
1030
 
 
1031
  demo.launch(theme=gr.themes.Soft(), css=CSS)