xiaohy commited on
Commit
afed715
·
verified ·
1 Parent(s): 687782e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -38
app.py CHANGED
@@ -1,6 +1,7 @@
1
  # ================================================================
2
- # 教育大模型MIA攻防研究 - Gradio演示系统 v7.0 学术巅峰版
3
- # 彻底消灭普通 e/s,全量启用 LaTeX 原生数学斜体 $\epsilon$ $\sigma$
 
4
  # ================================================================
5
 
6
  import os
@@ -31,6 +32,7 @@ def clean_text(text):
31
  text = re.sub(r'[\u200b-\u200f\u2028-\u202f\u2060-\u206f\ufeff]', '', text)
32
  return text.strip()
33
 
 
34
  try:
35
  member_data = load_json("data/member.json")
36
  non_member_data = load_json("data/non_member.json")
@@ -57,11 +59,12 @@ except FileNotFoundError:
57
  for s in [0.005, 0.01, 0.015, 0.02, 0.025, 0.03]:
58
  k = f"perturbation_{s}"
59
  perturb_results[k] = {m: v*0.85 for m, v in mia_results["baseline"].items()}
 
60
  perturb_results[k]["member_loss_std"] = np.sqrt(0.03**2 + s**2)
61
  perturb_results[k]["non_member_loss_std"] = np.sqrt(0.03**2 + s**2)
62
 
63
  # ================================================================
64
- # 全局图表配置
65
  # ================================================================
66
  COLORS = {
67
  'bg': '#FFFFFF',
@@ -78,6 +81,8 @@ COLORS = {
78
  'ls_colors': ['#A0C4FF', '#70A1FF', '#478EFF', '#007AFF'],
79
  'op_colors': ['#98F5E1', '#6EE7B7', '#34D399', '#10B981', '#059669', '#047857'],
80
  }
 
 
81
  CHART_W = 14
82
 
83
  def apply_light_style(fig, ax_or_axes):
@@ -96,15 +101,17 @@ def apply_light_style(fig, ax_or_axes):
96
  ax.grid(True, color=COLORS['grid'], alpha=0.6, linestyle='-', linewidth=0.8)
97
  ax.set_axisbelow(True)
98
 
99
- # 🌟🌟🌟 核心修改:专门为画图准备的 LaTeX 格式标签 🌟🌟🌟
 
 
100
  LS_KEYS = ["baseline", "smooth_eps_0.02", "smooth_eps_0.05", "smooth_eps_0.1", "smooth_eps_0.2"]
101
  LS_LABELS_PLOT = ["Baseline", r"LS($\epsilon$=0.02)", r"LS($\epsilon$=0.05)", r"LS($\epsilon$=0.1)", r"LS($\epsilon$=0.2)"]
102
- LS_LABELS_UI = ["基线(Baseline)", "LS(ε=0.02)", "LS(ε=0.05)", "LS(ε=0.1)", "LS(ε=0.2)"]
103
 
104
  OP_SIGMAS = [0.005, 0.01, 0.015, 0.02, 0.025, 0.03]
105
  OP_KEYS = [f"perturbation_{s}" for s in OP_SIGMAS]
106
- OP_LABELS_PLOT = [f"OP($\sigma$={s})" for s in OP_SIGMAS]
107
- OP_LABELS_UI = [f"OP(σ={s})" for s in OP_SIGMAS]
108
 
109
  ALL_KEYS = LS_KEYS + OP_KEYS
110
 
@@ -123,8 +130,12 @@ bl_acc = gu("baseline")
123
  bl_m_mean = gm("baseline", "member_loss_mean")
124
  bl_nm_mean = gm("baseline", "non_member_loss_mean")
125
 
126
- TYPE_CN = {'calculation': '基础计算', 'word_problem': '应用题', 'concept': '概念问答', 'error_correction': '错题订正'}
 
127
 
 
 
 
128
  np.random.seed(777)
129
  EVAL_POOL = []
130
  _types = ['calculation']*120 + ['word_problem']*90 + ['concept']*60 + ['error_correction']*30
@@ -138,7 +149,8 @@ for _i in range(300):
138
  else: _q,_ans=f"{_a} x {_b} = ?",str(_a*_b)
139
  elif _t == 'word_problem':
140
  _a,_b = int(np.random.randint(5,200)), int(np.random.randint(3,50))
141
- _tpls = [(f"{_a} apples, ate {_b}, left?",str(_a-_b)), (f"{_a} per group, {_b} groups, total?",str(_a*_b))]
 
142
  _q,_ans = _tpls[_i%len(_tpls)]
143
  elif _t == 'concept':
144
  _cs = [("area","Area = space occupied by a shape"),("perimeter","Perimeter = total boundary length")]
@@ -153,27 +165,41 @@ for _i in range(300):
153
  EVAL_POOL.append(item)
154
 
155
  # ================================================================
156
- # 图表绘制函数 (全部更换为 LaTeX 渲染)
157
  # ================================================================
158
  def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
159
- fig, ax = plt.subplots(figsize=(10, 2.6)); fig.patch.set_facecolor(COLORS['bg']); ax.set_facecolor(COLORS['panel'])
160
- xlo = min(m_mean - 3.0 * m_std, loss_val - 0.005); xhi = max(nm_mean + 3.0 * nm_std, loss_val + 0.005)
161
- ax.axvspan(xlo, thr, alpha=0.2, color=COLORS['accent']); ax.axvspan(thr, xhi, alpha=0.2, color=COLORS['danger'])
 
 
 
 
 
 
 
162
  ax.axvline(m_mean, color=COLORS['accent'], lw=2, ls=':', alpha=0.8, zorder=2)
163
  ax.text(m_mean - 0.002, 1.02, f'Member Mean\n{m_mean:.4f}', ha='right', va='bottom', fontsize=9, color=COLORS['accent'], transform=ax.get_xaxis_transform())
 
164
  ax.axvline(nm_mean, color=COLORS['danger'], lw=2, ls=':', alpha=0.8, zorder=2)
165
  ax.text(nm_mean + 0.002, 1.02, f'Non-Member Mean\n{nm_mean:.4f}', ha='left', va='bottom', fontsize=9, color=COLORS['danger'], transform=ax.get_xaxis_transform())
 
166
  ax.axvline(thr, color=COLORS['text_dim'], lw=2.5, ls='--', zorder=3)
167
  ax.text(thr, 1.25, f'Threshold\n{thr:.4f}', ha='center', va='bottom', fontsize=10, fontweight='bold', color=COLORS['text_dim'], transform=ax.get_xaxis_transform())
 
168
  mc = COLORS['accent'] if loss_val < thr else COLORS['danger']
169
  ax.plot(loss_val, 0.5, marker='o', ms=16, color='white', mec=mc, mew=3, zorder=5, transform=ax.get_xaxis_transform())
170
  ax.text(loss_val, 0.75, f'Current Loss\n{loss_val:.4f}', ha='center', fontsize=11, fontweight='bold', color=mc, transform=ax.get_xaxis_transform())
 
171
  ax.text((xlo+thr)/2, 0.25, 'MEMBER', ha='center', fontsize=12, color=COLORS['accent'], alpha=0.6, fontweight='bold', transform=ax.get_xaxis_transform())
172
  ax.text((thr+xhi)/2, 0.25, 'NON-MEMBER', ha='center', fontsize=12, color=COLORS['danger'], alpha=0.6, fontweight='bold', transform=ax.get_xaxis_transform())
 
173
  ax.set_xlim(xlo, xhi); ax.set_yticks([])
174
  for s in ax.spines.values(): s.set_visible(False)
175
- ax.spines['bottom'].set_visible(True); ax.spines['bottom'].set_color(COLORS['grid']); ax.tick_params(colors=COLORS['text_dim'], width=1)
176
- ax.set_xlabel('Loss Value', fontsize=11, color=COLORS['text'], fontweight='medium'); plt.tight_layout(pad=0.5)
 
 
177
  return fig
178
 
179
  def fig_auc_bar():
@@ -195,24 +221,66 @@ def fig_auc_bar():
195
 
196
  def fig_radar():
197
  ms = ['AUC', 'Atk Acc', 'Prec', 'Recall', 'F1', 'TPR@5%', 'TPR@1%', 'Gap']
198
- mk = ['auc', 'attack_accuracy', 'precision', 'recall', 'f1', 'tpr_at_5fpr', 'tpr_at_1fpr', 'loss_gap']
199
- N = len(ms); ag = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist() + [0]
200
- fig, axes = plt.subplots(1, 2, figsize=(CHART_W + 2, 7), subplot_kw=dict(polar=True)); fig.patch.set_facecolor('white')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
- ls_cfgs = [("Baseline", "baseline", '#F04438'), (r"LS($\epsilon$=0.02)", "smooth_eps_0.02", '#B2DDFF'), (r"LS($\epsilon$=0.05)", "smooth_eps_0.05", '#84CAFF'), (r"LS($\epsilon$=0.1)", "smooth_eps_0.1", '#2E90FA'), (r"LS($\epsilon$=0.2)", "smooth_eps_0.2", '#7A5AF8')]
203
- op_cfgs = [("Baseline", "baseline", '#F04438'), (r"OP($\sigma$=0.005)", "perturbation_0.005", '#A6F4C5'), (r"OP($\sigma$=0.01)", "perturbation_0.01", '#6CE9A6'), (r"OP($\sigma$=0.015)", "perturbation_0.015", '#32D583'), (r"OP($\sigma$=0.02)", "perturbation_0.02", '#12B76A'), (r"OP($\sigma$=0.025)", "perturbation_0.025", '#039855'), (r"OP($\sigma$=0.03)", "perturbation_0.03", '#027A48')]
 
 
204
 
205
- for ax_idx, (ax, cfgs, title) in enumerate([(axes[0], ls_cfgs, 'Label Smoothing (5 models)'), (axes[1], op_cfgs, 'Output Perturbation (7 configs)')]):
206
- ax.set_facecolor('white')
207
- mx = [max(gm(k, m_key) for _, k, _ in cfgs) for m_key in mk]; mx = [m if m > 0 else 1 for m in mx]
208
  for nm, ky, cl in cfgs:
209
- v = [gm(ky, m_key) / mx[i] for i, m_key in enumerate(mk)]; v += [v[0]]
210
- ax.plot(ag, v, 'o-', lw=2.8 if ky == 'baseline' else 1.8, label=nm, color=cl, ms=5, alpha=0.95 if ky == 'baseline' else 0.85)
211
- ax.fill(ag, v, alpha=0.10 if ky == 'baseline' else 0.04, color=cl)
212
- ax.set_xticks(ag[:-1]); ax.set_xticklabels(ms, fontsize=10, color=COLORS['text']); ax.set_yticklabels([])
213
- ax.set_title(title, fontsize=12, fontweight='700', color=COLORS['text'], pad=18)
214
- ax.legend(loc='upper right', bbox_to_anchor=(1.35 if ax_idx == 1 else 1.30, 1.12), fontsize=9, framealpha=0.9, edgecolor=COLORS['grid'])
215
- ax.spines['polar'].set_color(COLORS['grid']); ax.grid(color=COLORS['grid'], alpha=0.5)
 
 
 
 
 
 
 
 
 
 
 
 
216
  plt.tight_layout()
217
  return fig
218
 
@@ -239,7 +307,7 @@ def fig_perturb_dist():
239
  ax.hist(mp, bins=bins, alpha=0.6, color=COLORS['accent'], label='Mem+noise', density=True, edgecolor='white')
240
  ax.hist(np_, bins=bins, alpha=0.6, color=COLORS['danger'], label='Non+noise', density=True, edgecolor='white')
241
  pa = gm(f'perturbation_{s}', 'auc')
242
- ax.set_title(f'OP($\sigma$={s})\nAUC={pa:.4f}', fontsize=11, fontweight='semibold'); ax.set_xlabel('Loss', fontsize=10)
243
  ax.legend(fontsize=9, facecolor=COLORS['bg'], edgecolor='none', labelcolor=COLORS['text'])
244
  plt.tight_layout(); return fig
245
 
@@ -259,7 +327,7 @@ def fig_roc_curves():
259
  fpr, tpr, _ = roc_curve(y_true, y_scores); ax.plot(fpr, tpr, color=COLORS['danger'], lw=2.5, label=f'Baseline (AUC={bl_auc:.4f})')
260
  for i, s in enumerate(OP_SIGMAS):
261
  rng_m = np.random.RandomState(42); rng_nm = np.random.RandomState(137); mp = ml_base + rng_m.normal(0, s, len(ml_base)); np_ = nl_base + rng_nm.normal(0, s, len(nl_base)); y_scores_p = np.concatenate([-mp, -np_]); fpr_p, tpr_p, _ = roc_curve(y_true, y_scores_p); auc_p = roc_auc_score(y_true, y_scores_p)
262
- ax.plot(fpr_p, tpr_p, color=COLORS['op_colors'][i], lw=2, label=f'OP($\sigma$={s}) (AUC={auc_p:.4f})')
263
  ax.plot([0,1], [0,1], '--', color=COLORS['text_dim'], lw=1.5, label='Random'); ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='medium'); ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='medium'); ax.set_title('ROC Curves: Output Perturbation', fontsize=14, fontweight='bold', pad=15); ax.legend(fontsize=10, facecolor=COLORS['bg'], edgecolor='none', labelcolor=COLORS['text'], loc='lower right'); plt.tight_layout()
264
  return fig
265
 
@@ -452,12 +520,12 @@ def cb_eval(model_choice):
452
 
453
  def build_full_table():
454
  rows = []
455
- for k, l in zip(LS_KEYS, LS_LABELS_UI):
456
  if k in mia_results:
457
  m = mia_results[k]; u = gu(k)
458
  t = "—" if k == "baseline" else "训练期"; d = "" if k == "baseline" else f"{m['auc']-bl_auc:+.4f}"
459
  rows.append(f"| {l} | {t} | {m['auc']:.4f} | {m['attack_accuracy']:.4f} | {m['precision']:.4f} | {m['recall']:.4f} | {m['f1']:.4f} | {m['tpr_at_5fpr']:.4f} | {m['tpr_at_1fpr']:.4f} | {m['loss_gap']:.4f} | {u:.1f}% | {d} |")
460
- for k, l in zip(OP_KEYS, OP_LABELS_UI):
461
  if k in perturb_results:
462
  m = perturb_results[k]; d = f"{m['auc']-bl_auc:+.4f}"
463
  rows.append(f"| {l} | 推理期 | {m['auc']:.4f} | {m['attack_accuracy']:.4f} | {m['precision']:.4f} | {m['recall']:.4f} | {m['f1']:.4f} | {m['tpr_at_5fpr']:.4f} | {m['tpr_at_1fpr']:.4f} | {m['loss_gap']:.4f} | {bl_acc:.1f}% | {d} |")
@@ -539,7 +607,6 @@ with gr.Blocks(title="MIA攻防研究") as demo:
539
  * 🛡️ **输出扰动 (Output Perturbation, 推理期)**:给 AI 的输出加上“变声器”。在攻击者探查 Loss 值时,强行混入高斯噪声(加沙子),让攻击者看到的 Loss 忽高忽低,彻底瞎掉,但普通用户看到的文字回答依然绝对正确。
540
  """)
541
 
542
- # 实验体系总览图 (如果在目录里则显示)
543
  if os.path.exists(os.path.join(BASE_DIR, "figures", "algo4_overview_cn_final.png")):
544
  gr.Image(os.path.join(BASE_DIR, "figures", "algo4_overview_cn_final.png"), label="实验体系总览", show_label=True)
545
 
@@ -764,7 +831,7 @@ with gr.Blocks(title="MIA攻防研究") as demo:
764
 
765
  ### 结论二:标签平滑是有效的训练期防御
766
 
767
- | 参数 | AUC | AUC降幅 | 效用 | 效用变化 |
768
  |---|---|---|---|---|
769
  | ε=0.02 | {gm('smooth_eps_0.02','auc'):.4f} | {bl_auc-gm('smooth_eps_0.02','auc'):.4f} | {gu('smooth_eps_0.02'):.1f}% | {gu('smooth_eps_0.02')-bl_acc:+.1f}% |
770
  | ε=0.05 | {gm('smooth_eps_0.05','auc'):.4f} | {bl_auc-gm('smooth_eps_0.05','auc'):.4f} | {gu('smooth_eps_0.05'):.1f}% | {gu('smooth_eps_0.05')-bl_acc:+.1f}% |
@@ -775,7 +842,7 @@ with gr.Blocks(title="MIA攻防研究") as demo:
775
 
776
  ### 结论三:输出扰动是有效的推理期防御
777
 
778
- | 参数 | AUC | AUC降幅 | 效用 |
779
  |---|---|---|---|
780
  | σ=0.005 | {gm('perturbation_0.005','auc'):.4f} | {bl_auc-gm('perturbation_0.005','auc'):.4f} | {bl_acc:.1f}% |
781
  | σ=0.01 | {gm('perturbation_0.01','auc'):.4f} | {bl_auc-gm('perturbation_0.01','auc'):.4f} | {bl_acc:.1f}% |
@@ -795,4 +862,4 @@ with gr.Blocks(title="MIA攻防研究") as demo:
795
 
796
  """)
797
 
798
- demo.launch()
 
1
  # ================================================================
2
+ # 教育大模型MIA攻防研究 - Gradio演示系统 v6.2 学术巅峰版 (苹果风)
3
+ # 整合了双雷达图 + 算法流程图 + 伪代码 + 详尽数据分析 + 完整结论
4
+ # !!!全局修正:所有 e 替换为 ε / $\epsilon$,所有 s 替换为 σ / $\sigma$ !!!
5
  # ================================================================
6
 
7
  import os
 
32
  text = re.sub(r'[\u200b-\u200f\u2028-\u202f\u2060-\u206f\ufeff]', '', text)
33
  return text.strip()
34
 
35
+ # 尝试加载数据,如果不存在则使用虚拟数据以确保运行
36
  try:
37
  member_data = load_json("data/member.json")
38
  non_member_data = load_json("data/non_member.json")
 
59
  for s in [0.005, 0.01, 0.015, 0.02, 0.025, 0.03]:
60
  k = f"perturbation_{s}"
61
  perturb_results[k] = {m: v*0.85 for m, v in mia_results["baseline"].items()}
62
+ # 模拟方差变大
63
  perturb_results[k]["member_loss_std"] = np.sqrt(0.03**2 + s**2)
64
  perturb_results[k]["non_member_loss_std"] = np.sqrt(0.03**2 + s**2)
65
 
66
  # ================================================================
67
+ # 全局图表配置 - 简约苹果风
68
  # ================================================================
69
  COLORS = {
70
  'bg': '#FFFFFF',
 
81
  'ls_colors': ['#A0C4FF', '#70A1FF', '#478EFF', '#007AFF'],
82
  'op_colors': ['#98F5E1', '#6EE7B7', '#34D399', '#10B981', '#059669', '#047857'],
83
  }
84
+
85
+ # 图表宽度配置 (为了适配双雷达图)
86
  CHART_W = 14
87
 
88
  def apply_light_style(fig, ax_or_axes):
 
101
  ax.grid(True, color=COLORS['grid'], alpha=0.6, linestyle='-', linewidth=0.8)
102
  ax.set_axisbelow(True)
103
 
104
+ # ================================================================
105
+ # 提取指标的辅助函数 (核心替换:使用 LaTeX \epsilon 和 \sigma 画图)
106
+ # ================================================================
107
  LS_KEYS = ["baseline", "smooth_eps_0.02", "smooth_eps_0.05", "smooth_eps_0.1", "smooth_eps_0.2"]
108
  LS_LABELS_PLOT = ["Baseline", r"LS($\epsilon$=0.02)", r"LS($\epsilon$=0.05)", r"LS($\epsilon$=0.1)", r"LS($\epsilon$=0.2)"]
109
+ LS_LABELS_MD = ["基线(Baseline)", "LS(ε=0.02)", "LS(ε=0.05)", "LS(ε=0.1)", "LS(ε=0.2)"]
110
 
111
  OP_SIGMAS = [0.005, 0.01, 0.015, 0.02, 0.025, 0.03]
112
  OP_KEYS = [f"perturbation_{s}" for s in OP_SIGMAS]
113
+ OP_LABELS_PLOT = [r"OP($\sigma$={})".format(s) for s in OP_SIGMAS]
114
+ OP_LABELS_MD = [f"OP(σ={s})" for s in OP_SIGMAS]
115
 
116
  ALL_KEYS = LS_KEYS + OP_KEYS
117
 
 
130
  bl_m_mean = gm("baseline", "member_loss_mean")
131
  bl_nm_mean = gm("baseline", "non_member_loss_mean")
132
 
133
+ TYPE_CN = {'calculation': '基础计算', 'word_problem': '应用题',
134
+ 'concept': '概念问答', 'error_correction': '错题订正'}
135
 
136
+ # ================================================================
137
+ # 效用评估题库
138
+ # ================================================================
139
  np.random.seed(777)
140
  EVAL_POOL = []
141
  _types = ['calculation']*120 + ['word_problem']*90 + ['concept']*60 + ['error_correction']*30
 
149
  else: _q,_ans=f"{_a} x {_b} = ?",str(_a*_b)
150
  elif _t == 'word_problem':
151
  _a,_b = int(np.random.randint(5,200)), int(np.random.randint(3,50))
152
+ _tpls = [(f"{_a} apples, ate {_b}, left?",str(_a-_b)),
153
+ (f"{_a} per group, {_b} groups, total?",str(_a*_b))]
154
  _q,_ans = _tpls[_i%len(_tpls)]
155
  elif _t == 'concept':
156
  _cs = [("area","Area = space occupied by a shape"),("perimeter","Perimeter = total boundary length")]
 
165
  EVAL_POOL.append(item)
166
 
167
  # ================================================================
168
+ # 图表绘制函数 (全面应用 LaTeX 标签渲染)
169
  # ================================================================
170
  def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
171
+ fig, ax = plt.subplots(figsize=(10, 2.6))
172
+ fig.patch.set_facecolor(COLORS['bg'])
173
+ ax.set_facecolor(COLORS['panel'])
174
+
175
+ xlo = min(m_mean - 3.0 * m_std, loss_val - 0.005)
176
+ xhi = max(nm_mean + 3.0 * nm_std, loss_val + 0.005)
177
+
178
+ ax.axvspan(xlo, thr, alpha=0.2, color=COLORS['accent'])
179
+ ax.axvspan(thr, xhi, alpha=0.2, color=COLORS['danger'])
180
+
181
  ax.axvline(m_mean, color=COLORS['accent'], lw=2, ls=':', alpha=0.8, zorder=2)
182
  ax.text(m_mean - 0.002, 1.02, f'Member Mean\n{m_mean:.4f}', ha='right', va='bottom', fontsize=9, color=COLORS['accent'], transform=ax.get_xaxis_transform())
183
+
184
  ax.axvline(nm_mean, color=COLORS['danger'], lw=2, ls=':', alpha=0.8, zorder=2)
185
  ax.text(nm_mean + 0.002, 1.02, f'Non-Member Mean\n{nm_mean:.4f}', ha='left', va='bottom', fontsize=9, color=COLORS['danger'], transform=ax.get_xaxis_transform())
186
+
187
  ax.axvline(thr, color=COLORS['text_dim'], lw=2.5, ls='--', zorder=3)
188
  ax.text(thr, 1.25, f'Threshold\n{thr:.4f}', ha='center', va='bottom', fontsize=10, fontweight='bold', color=COLORS['text_dim'], transform=ax.get_xaxis_transform())
189
+
190
  mc = COLORS['accent'] if loss_val < thr else COLORS['danger']
191
  ax.plot(loss_val, 0.5, marker='o', ms=16, color='white', mec=mc, mew=3, zorder=5, transform=ax.get_xaxis_transform())
192
  ax.text(loss_val, 0.75, f'Current Loss\n{loss_val:.4f}', ha='center', fontsize=11, fontweight='bold', color=mc, transform=ax.get_xaxis_transform())
193
+
194
  ax.text((xlo+thr)/2, 0.25, 'MEMBER', ha='center', fontsize=12, color=COLORS['accent'], alpha=0.6, fontweight='bold', transform=ax.get_xaxis_transform())
195
  ax.text((thr+xhi)/2, 0.25, 'NON-MEMBER', ha='center', fontsize=12, color=COLORS['danger'], alpha=0.6, fontweight='bold', transform=ax.get_xaxis_transform())
196
+
197
  ax.set_xlim(xlo, xhi); ax.set_yticks([])
198
  for s in ax.spines.values(): s.set_visible(False)
199
+ ax.spines['bottom'].set_visible(True); ax.spines['bottom'].set_color(COLORS['grid'])
200
+ ax.tick_params(colors=COLORS['text_dim'], width=1)
201
+ ax.set_xlabel('Loss Value', fontsize=11, color=COLORS['text'], fontweight='medium')
202
+ plt.tight_layout(pad=0.5)
203
  return fig
204
 
205
  def fig_auc_bar():
 
221
 
222
  def fig_radar():
223
  ms = ['AUC', 'Atk Acc', 'Prec', 'Recall', 'F1', 'TPR@5%', 'TPR@1%', 'Gap']
224
+ mk = ['auc', 'attack_accuracy', 'precision', 'recall', 'f1',
225
+ 'tpr_at_5fpr', 'tpr_at_1fpr', 'loss_gap']
226
+ N = len(ms)
227
+ ag = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist() + [0]
228
+
229
+ fig, axes = plt.subplots(1, 2, figsize=(CHART_W + 2, 7),
230
+ subplot_kw=dict(polar=True))
231
+ fig.patch.set_facecolor('white')
232
+
233
+ # --- 左图: 5个标签平滑模型 (替换LaTeX) ---
234
+ ls_cfgs = [
235
+ ("Baseline", "baseline", '#F04438'),
236
+ (r"LS($\epsilon$=0.02)", "smooth_eps_0.02", '#B2DDFF'),
237
+ (r"LS($\epsilon$=0.05)", "smooth_eps_0.05", '#84CAFF'),
238
+ (r"LS($\epsilon$=0.1)", "smooth_eps_0.1", '#2E90FA'),
239
+ (r"LS($\epsilon$=0.2)", "smooth_eps_0.2", '#7A5AF8'),
240
+ ]
241
+
242
+ # --- 右图: Baseline + 6个输出扰动 (替换LaTeX) ---
243
+ op_cfgs = [
244
+ ("Baseline", "baseline", '#F04438'),
245
+ (r"OP($\sigma$=0.005)", "perturbation_0.005", '#A6F4C5'),
246
+ (r"OP($\sigma$=0.01)", "perturbation_0.01", '#6CE9A6'),
247
+ (r"OP($\sigma$=0.015)", "perturbation_0.015", '#32D583'),
248
+ (r"OP($\sigma$=0.02)", "perturbation_0.02", '#12B76A'),
249
+ (r"OP($\sigma$=0.025)", "perturbation_0.025", '#039855'),
250
+ (r"OP($\sigma$=0.03)", "perturbation_0.03", '#027A48'),
251
+ ]
252
+
253
+ for ax_idx, (ax, cfgs, title) in enumerate([
254
+ (axes[0], ls_cfgs, 'Label Smoothing (5 models)'),
255
+ (axes[1], op_cfgs, 'Output Perturbation (7 configs)')
256
+ ]):
257
+ ax.set_facecolor('white')
258
 
259
+ mx = []
260
+ for i, m_key in enumerate(mk):
261
+ val_max = max(gm(k, m_key) for _, k, _ in cfgs)
262
+ mx.append(val_max if val_max > 0 else 1)
263
 
 
 
 
264
  for nm, ky, cl in cfgs:
265
+ v = [gm(ky, m_key) / mx[i] for i, m_key in enumerate(mk)]
266
+ v += [v[0]] # 闭合
267
+ lw = 2.8 if ky == 'baseline' else 1.8
268
+ alpha_fill = 0.10 if ky == 'baseline' else 0.04
269
+ ax.plot(ag, v, 'o-', lw=lw, label=nm, color=cl, ms=5,
270
+ alpha=0.95 if ky == 'baseline' else 0.85)
271
+ ax.fill(ag, v, alpha=alpha_fill, color=cl)
272
+
273
+ ax.set_xticks(ag[:-1])
274
+ ax.set_xticklabels(ms, fontsize=9, color=COLORS['text'])
275
+ ax.set_yticklabels([])
276
+ ax.set_title(title, fontsize=11, fontweight='700',
277
+ color=COLORS['text'], pad=18)
278
+ ax.legend(loc='upper right',
279
+ bbox_to_anchor=(1.35 if ax_idx == 1 else 1.30, 1.12),
280
+ fontsize=8, framealpha=0.9, edgecolor=COLORS['grid'])
281
+ ax.spines['polar'].set_color(COLORS['grid'])
282
+ ax.grid(color=COLORS['grid'], alpha=0.5)
283
+
284
  plt.tight_layout()
285
  return fig
286
 
 
307
  ax.hist(mp, bins=bins, alpha=0.6, color=COLORS['accent'], label='Mem+noise', density=True, edgecolor='white')
308
  ax.hist(np_, bins=bins, alpha=0.6, color=COLORS['danger'], label='Non+noise', density=True, edgecolor='white')
309
  pa = gm(f'perturbation_{s}', 'auc')
310
+ ax.set_title(r'OP($\sigma$={})'.format(s) + f'\nAUC={pa:.4f}', fontsize=11, fontweight='semibold'); ax.set_xlabel('Loss', fontsize=10)
311
  ax.legend(fontsize=9, facecolor=COLORS['bg'], edgecolor='none', labelcolor=COLORS['text'])
312
  plt.tight_layout(); return fig
313
 
 
327
  fpr, tpr, _ = roc_curve(y_true, y_scores); ax.plot(fpr, tpr, color=COLORS['danger'], lw=2.5, label=f'Baseline (AUC={bl_auc:.4f})')
328
  for i, s in enumerate(OP_SIGMAS):
329
  rng_m = np.random.RandomState(42); rng_nm = np.random.RandomState(137); mp = ml_base + rng_m.normal(0, s, len(ml_base)); np_ = nl_base + rng_nm.normal(0, s, len(nl_base)); y_scores_p = np.concatenate([-mp, -np_]); fpr_p, tpr_p, _ = roc_curve(y_true, y_scores_p); auc_p = roc_auc_score(y_true, y_scores_p)
330
+ ax.plot(fpr_p, tpr_p, color=COLORS['op_colors'][i], lw=2, label=r'OP($\sigma$={}) (AUC={:.4f})'.format(s, auc_p))
331
  ax.plot([0,1], [0,1], '--', color=COLORS['text_dim'], lw=1.5, label='Random'); ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='medium'); ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='medium'); ax.set_title('ROC Curves: Output Perturbation', fontsize=14, fontweight='bold', pad=15); ax.legend(fontsize=10, facecolor=COLORS['bg'], edgecolor='none', labelcolor=COLORS['text'], loc='lower right'); plt.tight_layout()
332
  return fig
333
 
 
520
 
521
  def build_full_table():
522
  rows = []
523
+ for k, l in zip(LS_KEYS, LS_LABELS_MD):
524
  if k in mia_results:
525
  m = mia_results[k]; u = gu(k)
526
  t = "—" if k == "baseline" else "训练期"; d = "" if k == "baseline" else f"{m['auc']-bl_auc:+.4f}"
527
  rows.append(f"| {l} | {t} | {m['auc']:.4f} | {m['attack_accuracy']:.4f} | {m['precision']:.4f} | {m['recall']:.4f} | {m['f1']:.4f} | {m['tpr_at_5fpr']:.4f} | {m['tpr_at_1fpr']:.4f} | {m['loss_gap']:.4f} | {u:.1f}% | {d} |")
528
+ for k, l in zip(OP_KEYS, OP_LABELS_MD):
529
  if k in perturb_results:
530
  m = perturb_results[k]; d = f"{m['auc']-bl_auc:+.4f}"
531
  rows.append(f"| {l} | 推理期 | {m['auc']:.4f} | {m['attack_accuracy']:.4f} | {m['precision']:.4f} | {m['recall']:.4f} | {m['f1']:.4f} | {m['tpr_at_5fpr']:.4f} | {m['tpr_at_1fpr']:.4f} | {m['loss_gap']:.4f} | {bl_acc:.1f}% | {d} |")
 
607
  * 🛡️ **输出扰动 (Output Perturbation, 推理期)**:给 AI 的输出加上“变声器”。在攻击者探查 Loss 值时,强行混入高斯噪声(加沙子),让攻击者看到的 Loss 忽高忽低,彻底瞎掉,但普通用户看到的文字回答依然绝对正确。
608
  """)
609
 
 
610
  if os.path.exists(os.path.join(BASE_DIR, "figures", "algo4_overview_cn_final.png")):
611
  gr.Image(os.path.join(BASE_DIR, "figures", "algo4_overview_cn_final.png"), label="实验体系总览", show_label=True)
612
 
 
831
 
832
  ### 结论二:标签平滑是有效的训练期防御
833
 
834
+ | ε 参数 | AUC | AUC降幅 | 效用 | 效用变化 |
835
  |---|---|---|---|---|
836
  | ε=0.02 | {gm('smooth_eps_0.02','auc'):.4f} | {bl_auc-gm('smooth_eps_0.02','auc'):.4f} | {gu('smooth_eps_0.02'):.1f}% | {gu('smooth_eps_0.02')-bl_acc:+.1f}% |
837
  | ε=0.05 | {gm('smooth_eps_0.05','auc'):.4f} | {bl_auc-gm('smooth_eps_0.05','auc'):.4f} | {gu('smooth_eps_0.05'):.1f}% | {gu('smooth_eps_0.05')-bl_acc:+.1f}% |
 
842
 
843
  ### 结论三:输出扰动是有效的推理期防御
844
 
845
+ | σ 参数 | AUC | AUC降幅 | 效用 |
846
  |---|---|---|---|
847
  | σ=0.005 | {gm('perturbation_0.005','auc'):.4f} | {bl_auc-gm('perturbation_0.005','auc'):.4f} | {bl_acc:.1f}% |
848
  | σ=0.01 | {gm('perturbation_0.01','auc'):.4f} | {bl_auc-gm('perturbation_0.01','auc'):.4f} | {bl_acc:.1f}% |
 
862
 
863
  """)
864
 
865
+ demo.launch(theme=gr.themes.Soft(), css=CSS)