xiaohy commited on
Commit
490cead
·
verified ·
1 Parent(s): d4c815e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -207,15 +207,20 @@ def fig_radar():
207
  ls_cfgs = [("Baseline", "baseline", '#F04438'), ("LS(ε=0.02)", "smooth_eps_0.02", '#B2DDFF'), ("LS(ε=0.05)", "smooth_eps_0.05", '#84CAFF'), ("LS(ε=0.1)", "smooth_eps_0.1", '#2E90FA'), ("LS(ε=0.2)", "smooth_eps_0.2", '#7A5AF8')]
208
  op_cfgs = [("Baseline", "baseline", '#F04438'), ("OP(σ=0.005)", "perturbation_0.005", '#A6F4C5'), ("OP(σ=0.01)", "perturbation_0.01", '#6CE9A6'), ("OP(σ=0.015)", "perturbation_0.015", '#32D583'), ("OP(σ=0.02)", "perturbation_0.02", '#12B76A'), ("OP(σ=0.025)", "perturbation_0.025", '#039855'), ("OP(σ=0.03)", "perturbation_0.03", '#027A48')]
209
 
210
- for ax_idx, (ax, cfgs, title) in enumerate([(axes[0], ls_cfgs, 'Label Smoothing (5 models)'), (axes[1], op_cfgs, 'Output Perturbation (7 configs)')]):
 
 
 
 
 
211
  ax.set_facecolor('white')
212
- mx = [max(gm(k, m_key) for _, k, _ in cfgs) for m_key in mk]; mx = [m if m > 0 else 1 for m in mx]
213
  for nm, ky, cl in cfgs:
214
- v = [gm(ky, m_key) / mx[i] for i, m_key in enumerate(mk)]; v += [v[0]]
 
215
  ax.plot(ag, v, 'o-', lw=2.8 if ky == 'baseline' else 1.8, label=nm, color=cl, ms=5, alpha=0.95 if ky == 'baseline' else 0.85)
216
  ax.fill(ag, v, alpha=0.10 if ky == 'baseline' else 0.04, color=cl)
217
  ax.set_xticks(ag[:-1]); ax.set_xticklabels(ms, fontsize=10, color=COLORS['text']); ax.set_yticklabels([])
218
- ax.set_title(title, fontsize=12, fontweight='700', color=COLORS['text'], pad=18)
219
  ax.legend(loc='upper right', bbox_to_anchor=(1.35 if ax_idx == 1 else 1.30, 1.12), fontsize=9, framealpha=0.9, edgecolor=COLORS['grid'])
220
  ax.spines['polar'].set_color(COLORS['grid']); ax.grid(color=COLORS['grid'], alpha=0.5)
221
  plt.tight_layout()
 
207
  ls_cfgs = [("Baseline", "baseline", '#F04438'), ("LS(ε=0.02)", "smooth_eps_0.02", '#B2DDFF'), ("LS(ε=0.05)", "smooth_eps_0.05", '#84CAFF'), ("LS(ε=0.1)", "smooth_eps_0.1", '#2E90FA'), ("LS(ε=0.2)", "smooth_eps_0.2", '#7A5AF8')]
208
  op_cfgs = [("Baseline", "baseline", '#F04438'), ("OP(σ=0.005)", "perturbation_0.005", '#A6F4C5'), ("OP(σ=0.01)", "perturbation_0.01", '#6CE9A6'), ("OP(σ=0.015)", "perturbation_0.015", '#32D583'), ("OP(σ=0.02)", "perturbation_0.02", '#12B76A'), ("OP(σ=0.025)", "perturbation_0.025", '#039855'), ("OP(σ=0.03)", "perturbation_0.03", '#027A48')]
209
 
210
+ # 🌟 核心修改:全局归一化,锁定统一的天花板
211
+ all_configs_keys = ["baseline"] + LS_KEYS[1:] + OP_KEYS
212
+ mx_global = [max(gm(k, m_key) for k in all_configs_keys) for m_key in mk]
213
+ mx_global = [m if m > 0 else 1 for m in mx_global]
214
+
215
+ for ax_idx, (ax, cfgs, title) in enumerate([(axes[0], ls_cfgs, 'Radar Analysis: Label Smoothing'), (axes[1], op_cfgs, 'Radar Analysis: Output Perturbation')]):
216
  ax.set_facecolor('white')
 
217
  for nm, ky, cl in cfgs:
218
+ # 🌟 这里统一除以 mx_global,确保红线形状完全一致
219
+ v = [gm(ky, m_key) / mx_global[i] for i, m_key in enumerate(mk)]; v += [v[0]]
220
  ax.plot(ag, v, 'o-', lw=2.8 if ky == 'baseline' else 1.8, label=nm, color=cl, ms=5, alpha=0.95 if ky == 'baseline' else 0.85)
221
  ax.fill(ag, v, alpha=0.10 if ky == 'baseline' else 0.04, color=cl)
222
  ax.set_xticks(ag[:-1]); ax.set_xticklabels(ms, fontsize=10, color=COLORS['text']); ax.set_yticklabels([])
223
+ ax.set_title(title, fontsize=14, fontweight='bold', color=COLORS['text'], pad=25)
224
  ax.legend(loc='upper right', bbox_to_anchor=(1.35 if ax_idx == 1 else 1.30, 1.12), fontsize=9, framealpha=0.9, edgecolor=COLORS['grid'])
225
  ax.spines['polar'].set_color(COLORS['grid']); ax.grid(color=COLORS['grid'], alpha=0.5)
226
  plt.tight_layout()