| |
| """ |
| Generate comparative gradient saliency visualizations for thesis claims. |
| |
| This script produces publication-quality figures from the gradient maps |
| generated by generate_gradient_maps.py. Each figure directly supports |
| specific claims in the thesis chapters. |
| |
Figures generated:
1. LMIC-TSBN vs CNN vs MLP saliency comparison (byte 2, desync=0)
2. CNN degradation under desync (byte 0: desync=0 vs 50 vs 100)
3. LMIC-TSBN 16-byte grid showing uniform attention (desync=0)
4. HPS representation competition (all bytes; desync=100, falling back to desync=50)
5. Ablation: binary encoding effect on gradient quality
6. Ablation: TSBN effect on cross-byte interference
7. MLP diffuse attention vs CNN localized attention
8. Desync robustness: LMIC-TSBN gradient stability across desync levels
9. Quantitative summary: mean saliency magnitude by architecture and desync level (from summary.json)
10. Seed sensitivity: LMIC-TSBN gradient consistency across random seeds (desync=0)
| """ |
|
|
| import os |
| import json |
| import numpy as np |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| from matplotlib.patches import Rectangle |
| from pathlib import Path |
|
|
| |
| |
| |
|
|
# Input directory (maps produced by generate_gradient_maps.py) and output
# directory for the rendered figures. Both keep their original defaults but
# can be redirected through the GRADIENT_MAPS_DIR / GRADIENT_FIGURES_DIR
# environment variables, so the script also runs outside the original host.
GRADIENT_MAPS_DIR = os.environ.get(
    "GRADIENT_MAPS_DIR", "/home/ubuntu/gradient_maps_data/gradient_maps"
)
OUTPUT_DIR = os.environ.get("GRADIENT_FIGURES_DIR", "/home/ubuntu/gradient_figures")
# Created at import time so every figure function can save without checking.
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
| |
# Global publication styling applied to every figure produced below.
_PUBLICATION_RC = {
    # Typography: serif text at print-friendly sizes.
    'font.family': 'serif',
    'font.size': 9,
    'axes.labelsize': 10,
    'axes.titlesize': 11,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 8,
    # Raster output: 300 dpi, cropped tightly to the drawn content.
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    # Faint background grid on all axes.
    'axes.grid': True,
    'grid.alpha': 0.3,
}
plt.rcParams.update(_PUBLICATION_RC)
|
|
| |
# Width, in samples, of every per-byte points-of-interest window below.
WINDOW_SIZE = 700

# Start sample of each byte's POI window, indexed by byte position 0..15.
_POI_WINDOW_STARTS = (
    45800, 46100, 45400, 46400,
    47000, 47300, 47600, 47900,
    48200, 48500, 48800, 49100,
    49400, 49700, 50000, 50300,
)

# byte index -> (start, end) sample indices; each window is WINDOW_SIZE long.
BYTE_POI_WINDOWS = {
    byte: (start, start + WINDOW_SIZE)
    for byte, start in enumerate(_POI_WINDOW_STARTS)
}

# byte index -> stored "peak SNR" value.
# NOTE(review): these look like SNR magnitudes, yet the figure functions place
# the "SNR peak" marker at int(BYTE_PEAK_SNR[b] * WINDOW_SIZE), i.e. they treat
# each value as a fractional position inside the window. Confirm which meaning
# is intended before relying on the marker positions.
BYTE_PEAK_SNR = dict(enumerate((
    0.1523, 0.1843, 0.2103, 0.1402,
    0.1652, 0.1891, 0.1340, 0.1578,
    0.1815, 0.1253, 0.1489, 0.1725,
    0.1161, 0.1396, 0.1631, 0.1065,
)))
|
|
|
|
def load_gradient_map(run_name, byte_idx=None, maps_dir=None):
    """Load a saved gradient saliency map for one run.

    Args:
        run_name: Name of the run sub-directory inside ``maps_dir``.
        byte_idx: If given, prefer the per-byte file
            ``gradient_map_byte{byte_idx}.npy``. When that file is absent,
            fall back to the single ``gradient_map.npy``.
        maps_dir: Root directory holding the run sub-directories. Defaults
            to the module-level ``GRADIENT_MAPS_DIR``.

    Returns:
        The array loaded with ``np.load``, or ``None`` when the run
        directory or every candidate file is missing.
    """
    if maps_dir is None:
        maps_dir = GRADIENT_MAPS_DIR
    run_dir = os.path.join(maps_dir, run_name)
    if not os.path.isdir(run_dir):
        return None

    candidates = []
    if byte_idx is not None:
        candidates.append(f"gradient_map_byte{byte_idx}.npy")
    # The single-map file is both the only file tried without byte_idx and
    # the fallback when the requested per-byte file is missing.
    candidates.append("gradient_map.npy")

    for filename in candidates:
        path = os.path.join(run_dir, filename)
        if os.path.isfile(path):
            return np.load(path)
    return None
|
|
|
|
def load_all_bytes(run_name, maps_dir=None, num_bytes=16):
    """Load every available per-byte gradient map of a multi-task run.

    Args:
        run_name: Name of the run sub-directory inside ``maps_dir``.
        maps_dir: Root directory holding the run sub-directories. Defaults
            to the module-level ``GRADIENT_MAPS_DIR``.
        num_bytes: Number of byte indices to probe (``0 .. num_bytes - 1``).

    Returns:
        A dict mapping byte index to the array loaded from
        ``gradient_map_byte{b}.npy``. It contains only bytes whose file
        exists and is empty when the run directory is missing.
    """
    if maps_dir is None:
        maps_dir = GRADIENT_MAPS_DIR
    run_dir = os.path.join(maps_dir, run_name)
    if not os.path.isdir(run_dir):
        return {}

    maps = {}
    for b in range(num_bytes):
        path = os.path.join(run_dir, f"gradient_map_byte{b}.npy")
        if os.path.isfile(path):
            maps[b] = np.load(path)
    return maps
|
|
|
|
def normalize_saliency(sal):
    """Scale a saliency array so that its maximum becomes 1.

    The input is divided by its maximum. For non-negative input (saliency
    magnitudes), the result lies in [0, 1].

    Args:
        sal: Array (or array-like) of saliency values.

    Returns:
        ``sal / sal.max()``. The input array is returned unchanged when it is
        empty (``max`` would raise) or when its maximum is 0 (avoids dividing
        by zero).
    """
    sal = np.asarray(sal)
    if sal.size == 0:
        return sal
    peak = sal.max()
    if peak == 0:
        return sal
    return sal / peak
|
|
|
|
| |
| |
| |
| |
| |
|
|
def figure1_architecture_comparison():
    """Plot normalized gradient saliency of LMIC-TSBN, CNN and MLP for byte 2.

    Draws three stacked panels (desync=0) that share the x axis. The MLP map
    comes from the RERUN run, falling back to the CLEAN run. A panel whose
    data is missing shows a "Data not available" placeholder. The figure is
    saved as OUTPUT_DIR/fig1_architecture_comparison_byte2.png.
    """
    fig, axes = plt.subplots(3, 1, figsize=(8, 6), sharex=True)

    lmic = load_gradient_map("RERUN-LMIC-TSBN-V7b-multibit-desync0", byte_idx=2)
    cnn = load_gradient_map("RERUN-CNN-byte2-desync0")
    # Explicit None check: `array or fallback` raises ValueError for arrays
    # with more than one element, whose truth value is ambiguous.
    mlp = load_gradient_map("RERUN-MLP-byte2-desync0")
    if mlp is None:
        mlp = load_gradient_map("CLEAN-MLP-byte2-desync0")

    models = [
        ("LMIC-TSBN (Multi-Task)", lmic, 'tab:blue'),
        ("CNN (Single-Byte)", cnn, 'tab:orange'),
        ("MLP (Single-Byte)", mlp, 'tab:green'),
    ]

    # NOTE(review): the stored BYTE_PEAK_SNR value is used as a fraction of
    # the window to place the marker; confirm it is a position and not an
    # SNR magnitude.
    peak_pos = int(BYTE_PEAK_SNR[2] * WINDOW_SIZE)

    for ax, (name, sal, color) in zip(axes, models):
        if sal is None:
            ax.text(0.5, 0.5, f"{name}: Data not available",
                    transform=ax.transAxes, ha='center', va='center')
            continue

        sal_norm = normalize_saliency(sal)
        # Take the x axis from the map itself. A map whose length differs from
        # WINDOW_SIZE would otherwise fail fill_between/plot on shape mismatch.
        x = np.arange(len(sal_norm))
        ax.fill_between(x, sal_norm, alpha=0.3, color=color)
        ax.plot(x, sal_norm, linewidth=0.8, color=color)
        ax.set_ylabel("Norm. Saliency")
        ax.set_title(name, fontsize=10, fontweight='bold')
        ax.set_ylim(0, 1.05)
        ax.axvline(peak_pos, color='red', linestyle='--', alpha=0.7, linewidth=1,
                   label=f'SNR Peak (sample {peak_pos})')
        ax.legend(loc='upper right')

    axes[-1].set_xlabel("Time Sample (within 700-sample window)")
    fig.suptitle("Gradient Saliency Comparison: Byte 2, Desync=0", fontsize=12, fontweight='bold')
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, "fig1_architecture_comparison_byte2.png"))
    plt.close(fig)
    print(" Figure 1: Architecture comparison saved")
|
|
|
|
| |
| |
| |
| |
| |
|
|
def figure2_cnn_desync_degradation():
    """Plot CNN byte-0 gradient saliency at desync 0, 50 and 100.

    Each panel is normalized to its own maximum. A text box shows the raw
    (un-normalized) max and mean, so magnitude changes across desync levels
    stay visible. Panels without data show "Data not available". The figure
    is saved as OUTPUT_DIR/fig2_cnn_desync_degradation.png.
    """
    fig, axes = plt.subplots(3, 1, figsize=(8, 6), sharex=True)

    # NOTE(review): BYTE_PEAK_SNR is used as a fraction of the window here;
    # confirm it encodes a position rather than an SNR magnitude.
    peak_pos = int(BYTE_PEAK_SNR[0] * WINDOW_SIZE)

    for ax, desync in zip(axes, [0, 50, 100]):
        sal = load_gradient_map(f"RERUN-CNN-byte0-desync{desync}")
        if sal is None:
            ax.text(0.5, 0.5, "Data not available", transform=ax.transAxes, ha='center')
            continue

        sal_norm = normalize_saliency(sal)
        # Take the x axis from the map's own length, so a map that is not
        # exactly WINDOW_SIZE long does not break fill_between/plot.
        x = np.arange(len(sal_norm))
        ax.fill_between(x, sal_norm, alpha=0.3, color='tab:orange')
        ax.plot(x, sal_norm, linewidth=0.8, color='tab:orange')
        ax.set_ylabel("Norm. Saliency")
        ax.set_title(f"CNN Byte 0, Desync={desync}", fontsize=10, fontweight='bold')
        ax.set_ylim(0, 1.05)
        ax.axvline(peak_pos, color='red', linestyle='--', alpha=0.7, linewidth=1)

        ax.text(0.98, 0.85, f"max={sal.max():.4f}\nmean={sal.mean():.4f}",
                transform=ax.transAxes, ha='right', va='top', fontsize=7,
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    axes[-1].set_xlabel("Time Sample (within 700-sample window)")
    fig.suptitle("CNN Gradient Saliency Degradation Under Desynchronization (Byte 0)",
                 fontsize=12, fontweight='bold')
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, "fig2_cnn_desync_degradation.png"))
    plt.close(fig)
    print(" Figure 2: CNN desync degradation saved")
|
|
|
|
| |
| |
| |
| |
| |
|
|
def figure3_lmic_tsbn_16byte_grid():
    """Plot LMIC-TSBN gradient saliency for all 16 bytes in a 4x4 grid.

    Uses the V7b multibit rerun at desync=0, falling back to the V8a bit-DTP
    run when the rerun has no per-byte maps. Each panel is normalized to its
    own maximum, and panels without data show "N/A". The figure is saved as
    OUTPUT_DIR/fig3_lmic_tsbn_16byte_grid.png.
    """
    fig, axes = plt.subplots(4, 4, figsize=(12, 8))

    maps = load_all_bytes("RERUN-LMIC-TSBN-V7b-multibit-desync0")
    if not maps:
        maps = load_all_bytes("LMIC-TSBN-V8a-bitDTP-desync0")

    for byte_idx in range(16):
        row, col = divmod(byte_idx, 4)
        ax = axes[row, col]

        if byte_idx in maps:
            sal_norm = normalize_saliency(maps[byte_idx])
            # Take the x axis from the map's own length, so a map that is not
            # exactly WINDOW_SIZE long does not break fill_between/plot.
            x = np.arange(len(sal_norm))
            ax.fill_between(x, sal_norm, alpha=0.3, color='tab:blue')
            ax.plot(x, sal_norm, linewidth=0.5, color='tab:blue')

            # NOTE(review): the marker treats BYTE_PEAK_SNR as a fraction of
            # the window; confirm this is a position, not an SNR magnitude.
            peak_pos = int(BYTE_PEAK_SNR[byte_idx] * WINDOW_SIZE)
            ax.axvline(peak_pos, color='red', linestyle='--', alpha=0.7, linewidth=0.8)
            ax.set_title(f"Byte {byte_idx}", fontsize=8)
            ax.set_ylim(0, 1.05)
        else:
            ax.text(0.5, 0.5, "N/A", transform=ax.transAxes, ha='center')

        ax.set_xticks([])
        ax.set_yticks([])
        if col == 0:
            ax.set_ylabel("Saliency", fontsize=7)
        if row == 3:
            ax.set_xlabel("Time", fontsize=7)

    fig.suptitle("LMIC-TSBN Gradient Saliency: All 16 Bytes (Desync=0)\n"
                 "Red dashed = SNR peak position", fontsize=12, fontweight='bold')
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, "fig3_lmic_tsbn_16byte_grid.png"))
    plt.close(fig)
    print(" Figure 3: LMIC-TSBN 16-byte grid saved")
|
|
|
|
| |
| |
| |
| |
| |
|
|
def figure4_hps_representation_competition():
    """Plot HPS per-byte gradient saliency on one shared (global) scale.

    Loads the HPS baseline at desync=100, falling back to desync=50 when that
    run has no per-byte maps. The figure is skipped, with a message, when
    neither run has maps. All 16 panels are divided by the largest value
    across bytes, so relative magnitudes between bytes remain comparable. The
    title reports the desync level that was actually loaded. The figure is
    saved as OUTPUT_DIR/fig4_hps_representation_competition.png.
    """
    desync = 100
    maps = load_all_bytes(f"HPS-baseline-desync{desync}")
    if not maps:
        desync = 50
        maps = load_all_bytes(f"HPS-baseline-desync{desync}")

    if not maps:
        print(" Figure 4: SKIPPED - No HPS gradient data available")
        return

    # Create the figure only once data exists, so the skip path does not
    # leave an unclosed figure behind.
    fig, axes = plt.subplots(4, 4, figsize=(12, 8))

    # One scale for every byte, so bytes with weak gradients appear weak.
    global_max = max(m.max() for m in maps.values())

    for byte_idx in range(16):
        row, col = divmod(byte_idx, 4)
        ax = axes[row, col]

        if byte_idx in maps:
            sal = maps[byte_idx]
            sal_global = sal / global_max if global_max > 0 else sal
            ax.fill_between(range(len(sal)), sal_global, alpha=0.3, color='tab:red')
            ax.plot(sal_global, linewidth=0.5, color='tab:red')
            ax.set_title(f"Byte {byte_idx} (max={sal.max():.2e})", fontsize=7)
            ax.set_ylim(0, 1.05)
        else:
            ax.text(0.5, 0.5, "N/A", transform=ax.transAxes, ha='center')

        ax.set_xticks([])
        ax.set_yticks([])

    # NOTE(review): the second title line is a hard-coded observation about
    # the data (bytes 0/1 dominating). Re-check it against the loaded run,
    # especially when the desync=50 fallback was used.
    fig.suptitle(f"HPS Gradient Saliency: Representation Competition (Desync={desync})\n"
                 "Note: Globally normalized - bytes 0/1 dominate, others near-zero",
                 fontsize=11, fontweight='bold')
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, "fig4_hps_representation_competition.png"))
    plt.close(fig)
    print(" Figure 4: HPS representation competition saved")
|
|
|
|
| |
| |
| |
| |
| |
|
|
def figure5_binary_encoding_effect():
    """Contrast LMIC-TSBN gradient saliency under binary vs identity encoding.

    The top row shows the multibit (binary-encoded) rerun and the bottom row
    the ABLATION-A1 no-multibit run, for bytes 0, 2, 8 and 14. Each panel is
    normalized to its own maximum, and panels without data stay empty. The
    V8a bit-DTP run replaces the top row only when NEITHER primary run has
    per-byte maps. The figure is saved as
    OUTPUT_DIR/fig5_binary_encoding_effect.png.
    """
    fig, axes = plt.subplots(2, 4, figsize=(12, 5))

    binary_maps = load_all_bytes("RERUN-LMIC-TSBN-V7b-multibit-desync0")
    identity_maps = load_all_bytes("ABLATION-A1-no-multibit-desync0")
    if not (binary_maps or identity_maps):
        binary_maps = load_all_bytes("LMIC-TSBN-V8a-bitDTP-desync0")

    # (maps, line colour, row label) for the top and bottom rows.
    row_specs = (
        (binary_maps, 'tab:blue', "Binary Enc."),
        (identity_maps, 'tab:purple', "Identity Enc."),
    )

    for col, byte_idx in enumerate((0, 2, 8, 14)):
        for row, (maps, color, row_label) in enumerate(row_specs):
            ax = axes[row, col]
            if byte_idx in maps:
                curve = normalize_saliency(maps[byte_idx])
                ax.fill_between(range(len(curve)), curve, alpha=0.3, color=color)
                ax.plot(curve, linewidth=0.5, color=color)
                if row == 0:
                    ax.set_title(f"Byte {byte_idx}", fontsize=9)
                ax.set_ylim(0, 1.05)
            ax.set_xticks([])
            if col == 0:
                ax.set_ylabel(row_label, fontsize=9, fontweight='bold')

    fig.suptitle("Effect of Binary Encoding on Gradient Saliency (Desync=0)\n"
                 "Binary encoding produces sharper, more focused gradients",
                 fontsize=11, fontweight='bold')
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, "fig5_binary_encoding_effect.png"))
    plt.close(fig)
    print(" Figure 5: Binary encoding effect saved")
|
|
|
|
| |
| |
| |
| |
| |
|
|
def figure6_tsbn_effect():
    """Contrast LMIC gradient saliency with and without task-specific BN.

    The top row shows LMIC-TSBN: the V7b multibit rerun, or the V8a bit-DTP
    run when the rerun has no per-byte maps. The bottom row shows the no-TSBN
    ablation. Bytes 0, 2, 8 and 14 are plotted, each panel normalized to its
    own maximum, and panels without data stay empty. The figure is saved as
    OUTPUT_DIR/fig6_tsbn_effect.png.
    """
    fig, axes = plt.subplots(2, 4, figsize=(12, 5))

    tsbn_on = load_all_bytes("RERUN-LMIC-TSBN-V7b-multibit-desync0")
    tsbn_off = load_all_bytes("ABLATION-LMIC-no-TSBN-desync0")
    tsbn_on = tsbn_on or load_all_bytes("LMIC-TSBN-V8a-bitDTP-desync0")

    # (maps, line colour, row label) for the top and bottom rows.
    row_specs = (
        (tsbn_on, 'tab:blue', "With TSBN"),
        (tsbn_off, 'tab:red', "Without TSBN"),
    )

    for col, byte_idx in enumerate((0, 2, 8, 14)):
        for row, (maps, color, row_label) in enumerate(row_specs):
            ax = axes[row, col]
            if byte_idx in maps:
                curve = normalize_saliency(maps[byte_idx])
                ax.fill_between(range(len(curve)), curve, alpha=0.3, color=color)
                ax.plot(curve, linewidth=0.5, color=color)
                if row == 0:
                    ax.set_title(f"Byte {byte_idx}", fontsize=9)
                ax.set_ylim(0, 1.05)
            ax.set_xticks([])
            if col == 0:
                ax.set_ylabel(row_label, fontsize=9, fontweight='bold')

    fig.suptitle("Effect of Task-Specific Batch Normalization on Gradient Saliency (Desync=0)\n"
                 "TSBN prevents cross-byte interference in normalization statistics",
                 fontsize=11, fontweight='bold')
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, "fig6_tsbn_effect.png"))
    plt.close(fig)
    print(" Figure 6: TSBN effect saved")
|
|
|
|
| |
| |
| |
| |
| |
|
|
def figure7_mlp_vs_cnn_attention():
    """Compare single-byte CNN and MLP gradient saliency at desync=0.

    There is one column per byte (0, 5, 11): the CNN on top and the MLP below.
    The MLP map comes from the RERUN run, falling back to the CLEAN run. Each
    panel is normalized to its own maximum and carries a red dashed marker
    derived from BYTE_PEAK_SNR. Panels without data stay empty. The figure is
    saved as OUTPUT_DIR/fig7_mlp_vs_cnn_attention.png.
    """
    fig, axes = plt.subplots(2, 3, figsize=(10, 5))

    for col, byte_idx in enumerate((0, 5, 11)):
        marker = int(BYTE_PEAK_SNR[byte_idx] * WINDOW_SIZE)

        cnn_map = load_gradient_map(f"RERUN-CNN-byte{byte_idx}-desync0")
        mlp_map = load_gradient_map(f"RERUN-MLP-byte{byte_idx}-desync0")
        if mlp_map is None:
            mlp_map = load_gradient_map(f"CLEAN-MLP-byte{byte_idx}-desync0")

        # (axes, map, colour, row label, panel title or None)
        panels = (
            (axes[0, col], cnn_map, 'tab:orange', "CNN", f"Byte {byte_idx}"),
            (axes[1, col], mlp_map, 'tab:green', "MLP", None),
        )
        for ax, sal, color, row_label, title in panels:
            if sal is not None:
                curve = normalize_saliency(sal)
                ax.fill_between(range(len(curve)), curve, alpha=0.3, color=color)
                ax.plot(curve, linewidth=0.5, color=color)
                if title is not None:
                    ax.set_title(title, fontsize=9)
                ax.set_ylim(0, 1.05)
                ax.axvline(marker, color='red', linestyle='--', alpha=0.7, linewidth=0.8)
            ax.set_xticks([])
            if col == 0:
                ax.set_ylabel(row_label, fontsize=9, fontweight='bold')

    fig.suptitle("CNN vs MLP Gradient Saliency Patterns (Desync=0)\n"
                 "CNN: Localized peaks at leakage points | MLP: Diffuse, no spatial structure",
                 fontsize=11, fontweight='bold')
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, "fig7_mlp_vs_cnn_attention.png"))
    plt.close(fig)
    print(" Figure 7: MLP vs CNN attention saved")
|
|
|
|
| |
| |
| |
| |
| |
|
|
def figure8_lmic_desync_robustness():
    """Plot LMIC-TSBN saliency for bytes 0, 2, 8, 14 at desync 0, 50 and 100.

    There is one row per desync level. Each row uses the V7b multibit rerun
    for that level, or the V8a bit-DTP run when the rerun has no per-byte
    maps. Panels are normalized to their own maximum, carry a red dashed
    marker derived from BYTE_PEAK_SNR, and stay empty without data. The
    figure is saved as OUTPUT_DIR/fig8_lmic_desync_robustness.png.
    """
    fig, axes = plt.subplots(3, 4, figsize=(12, 7))
    byte_columns = (0, 2, 8, 14)

    for row, desync in enumerate((0, 50, 100)):
        maps = (load_all_bytes(f"RERUN-LMIC-TSBN-V7b-multibit-desync{desync}")
                or load_all_bytes(f"LMIC-TSBN-V8a-bitDTP-desync{desync}"))

        for col, byte_idx in enumerate(byte_columns):
            ax = axes[row, col]
            if byte_idx in maps:
                curve = normalize_saliency(maps[byte_idx])
                ax.fill_between(range(len(curve)), curve, alpha=0.3, color='tab:blue')
                ax.plot(curve, linewidth=0.5, color='tab:blue')
                ax.set_ylim(0, 1.05)
                marker = int(BYTE_PEAK_SNR[byte_idx] * WINDOW_SIZE)
                ax.axvline(marker, color='red', linestyle='--', alpha=0.7, linewidth=0.8)
            ax.set_xticks([])
            ax.set_yticks([])
            if col == 0:
                ax.set_ylabel(f"Desync={desync}", fontsize=9, fontweight='bold')
            if row == 0:
                ax.set_title(f"Byte {byte_idx}", fontsize=9)

    fig.suptitle("LMIC-TSBN Gradient Saliency Stability Across Desynchronization Levels\n"
                 "Gradient focus remains on correct leakage points regardless of temporal shift",
                 fontsize=11, fontweight='bold')
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, "fig8_lmic_desync_robustness.png"))
    plt.close(fig)
    print(" Figure 8: LMIC-TSBN desync robustness saved")
|
|
|
|
| |
| |
| |
| |
|
|
def _aggregate_saliency_means(summary):
    """Group per-run mean saliency values under "<arch>-desync<d>" keys.

    Only entries whose 'status' is 'success' are used. RERUN CNN/MLP entries
    contribute their 'saliency_mean' (0 if absent). LMIC-TSBN entries
    (V7b-multibit or V8a-bitDTP) and HPS entries contribute the average of
    the per-byte 'mean' values in 'saliency_stats', or nothing when that
    mapping is empty.

    NOTE(review): assumes ``summary`` is an iterable of dicts as written by
    generate_gradient_maps.py; confirm against that script.
    """
    arch_data = {}
    for entry in summary:
        if entry.get('status') != 'success':
            continue
        name = entry['name']
        desync = entry.get('desync', 0)

        if 'CNN' in name and 'RERUN' in name:
            arch_data.setdefault(f"CNN-desync{desync}", []).append(
                entry.get('saliency_mean', 0))
        elif 'MLP' in name and 'RERUN' in name:
            arch_data.setdefault(f"MLP-desync{desync}", []).append(
                entry.get('saliency_mean', 0))
        elif 'V7b-multibit' in name or 'V8a-bitDTP' in name or 'HPS' in name:
            arch = 'HPS' if 'HPS' in name and not (
                'V7b-multibit' in name or 'V8a-bitDTP' in name) else 'LMIC-TSBN'
            values = arch_data.setdefault(f"{arch}-desync{desync}", [])
            # Multi-task runs store one stats dict per byte; average their means.
            stats = entry.get('saliency_stats', {})
            if stats:
                values.append(np.mean([v['mean'] for v in stats.values()]))
    return arch_data


def figure9_quantitative_summary():
    """Bar chart of mean gradient saliency by architecture and desync level.

    Reads GRADIENT_MAPS_DIR/summary.json and skips the figure, with a
    message, when that file does not exist. Values are aggregated per
    architecture (MLP, CNN, HPS, LMIC-TSBN) and desync level (0, 50, 100) and
    drawn as grouped bars on a log y axis. The figure is saved as
    OUTPUT_DIR/fig9_quantitative_summary.png.
    """
    summary_path = os.path.join(GRADIENT_MAPS_DIR, "summary.json")
    # Skip instead of raising, so the remaining figures still get generated.
    if not os.path.isfile(summary_path):
        print(" Figure 9: SKIPPED - No gradient summary (summary.json) available")
        return
    with open(summary_path, encoding='utf-8') as f:
        summary = json.load(f)

    arch_data = _aggregate_saliency_means(summary)

    fig, ax = plt.subplots(figsize=(10, 5))

    architectures = ['MLP', 'CNN', 'HPS', 'LMIC-TSBN']
    colors = ['tab:green', 'tab:orange', 'tab:red', 'tab:blue']
    desyncs = [0, 50, 100]

    x = np.arange(len(desyncs))
    width = 0.2

    for i, (arch, color) in enumerate(zip(architectures, colors)):
        means = []
        for d in desyncs:
            vals = arch_data.get(f"{arch}-desync{d}", [])
            # NOTE(review): missing combinations become 0-height bars, which
            # cannot be shown on the log axis below.
            means.append(np.mean(vals) if vals else 0)
        ax.bar(x + i * width, means, width, label=arch, color=color, alpha=0.8)

    ax.set_xlabel("Desynchronization Level")
    ax.set_ylabel("Mean Gradient Saliency Magnitude")
    ax.set_title("Mean Gradient Saliency by Architecture and Desynchronization Level",
                 fontweight='bold')
    # Center each tick under its group of len(architectures) bars.
    ax.set_xticks(x + (len(architectures) - 1) / 2 * width)
    ax.set_xticklabels([f"Desync={d}" for d in desyncs])
    ax.legend()
    ax.set_yscale('log')

    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, "fig9_quantitative_summary.png"))
    plt.close(fig)
    print(" Figure 9: Quantitative summary saved")
|
|
|
|
| |
| |
| |
| |
|
|
def figure10_seed_sensitivity():
    """Plot LMIC-TSBN saliency for two random seeds at desync=0.

    Row ``r`` shows the run "SEED-sensitivity-seed<r>-desync0" for bytes 0,
    2, 8 and 14. Each panel is normalized to its own maximum, carries a red
    dashed marker derived from BYTE_PEAK_SNR, and stays empty without data.
    The figure is saved as OUTPUT_DIR/fig10_seed_sensitivity.png.
    """
    fig, axes = plt.subplots(2, 4, figsize=(12, 5))
    byte_columns = (0, 2, 8, 14)

    for row in range(2):
        seed_maps = load_all_bytes(f"SEED-sensitivity-seed{row}-desync0")

        for col, byte_idx in enumerate(byte_columns):
            ax = axes[row, col]
            if byte_idx in seed_maps:
                curve = normalize_saliency(seed_maps[byte_idx])
                ax.fill_between(range(len(curve)), curve, alpha=0.3, color='tab:blue')
                ax.plot(curve, linewidth=0.5, color='tab:blue')
                ax.set_ylim(0, 1.05)
                marker = int(BYTE_PEAK_SNR[byte_idx] * WINDOW_SIZE)
                ax.axvline(marker, color='red', linestyle='--', alpha=0.7, linewidth=0.8)
            ax.set_xticks([])
            ax.set_yticks([])
            if col == 0:
                ax.set_ylabel(f"Seed {row}", fontsize=9, fontweight='bold')
            if row == 0:
                ax.set_title(f"Byte {byte_idx}", fontsize=9)

    fig.suptitle("LMIC-TSBN Gradient Saliency Consistency Across Random Seeds (Desync=0)\n"
                 "Gradient patterns are highly reproducible across different initializations",
                 fontsize=11, fontweight='bold')
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, "fig10_seed_sensitivity.png"))
    plt.close(fig)
    print(" Figure 10: Seed sensitivity saved")
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Banner: where gradient maps are read from and figures are written to.
    print("Generating gradient saliency visualizations...")
    print(f"Input: {GRADIENT_MAPS_DIR}")
    print(f"Output: {OUTPUT_DIR}")
    print()

    # Render every figure in numbered order; each one prints its own status.
    figure_builders = (
        figure1_architecture_comparison,
        figure2_cnn_desync_degradation,
        figure3_lmic_tsbn_16byte_grid,
        figure4_hps_representation_competition,
        figure5_binary_encoding_effect,
        figure6_tsbn_effect,
        figure7_mlp_vs_cnn_attention,
        figure8_lmic_desync_robustness,
        figure9_quantitative_summary,
        figure10_seed_sensitivity,
    )
    for build_figure in figure_builders:
        build_figure()

    print()
    print(f"All figures saved to: {OUTPUT_DIR}")
    print("Done!")
|
|