#!/usr/bin/env python3
"""
Generate comparative gradient saliency visualizations for thesis claims.

This script produces publication-quality figures from the gradient maps
generated by generate_gradient_maps.py. Each figure directly supports
specific claims in the thesis chapters.

Figures generated:
    1. LMIC-TSBN vs CNN vs MLP saliency comparison (byte 2, desync=0)
    2. CNN degradation under desync (byte 0: desync=0 vs 50 vs 100)
    3. LMIC-TSBN 16-byte grid showing uniform attention (desync=0)
    4. HPS representation competition (all bytes, desync=100, falling back
       to desync=50)
    5. Ablation: binary encoding effect on gradient quality
    6. Ablation: TSBN effect on cross-byte interference
    7. MLP diffuse attention vs CNN localized attention
    8. Desync robustness: LMIC-TSBN gradient stability across desync levels
    9. Quantitative summary: mean saliency magnitude per architecture and
       desync level (read from summary.json)
   10. Seed sensitivity: LMIC-TSBN gradient consistency across random seeds

Panels whose gradient map is missing are labelled "N/A". Figures 4 and 9
are skipped entirely (with a message) when their input data is missing.
"""
import os
import json
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from pathlib import Path

# ============================================================================
# Configuration
# ============================================================================

GRADIENT_MAPS_DIR = "/home/ubuntu/gradient_maps_data/gradient_maps"
OUTPUT_DIR = "/home/ubuntu/gradient_figures"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Publication style
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 9,
    'axes.labelsize': 10,
    'axes.titlesize': 11,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 8,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'axes.grid': True,
    'grid.alpha': 0.3,
})

# POI windows from the pipeline constants: (start, end) sample indices of each
# byte's analysis window within the full trace. Every window spans WINDOW_SIZE
# samples.
BYTE_POI_WINDOWS = {
    0: (45800, 46500), 1: (46100, 46800), 2: (45400, 46100), 3: (46400, 47100),
    4: (47000, 47700), 5: (47300, 48000), 6: (47600, 48300), 7: (47900, 48600),
    8: (48200, 48900), 9: (48500, 49200), 10: (48800, 49500), 11: (49100, 49800),
    12: (49400, 50100), 13: (49700, 50400), 14: (50000, 50700), 15: (50300, 51000),
}
WINDOW_SIZE = 700

# Per-byte values from the pipeline. The red "SNR peak" marker in the figures
# is drawn at int(value * WINDOW_SIZE) samples into the window.
# NOTE(review): values of ~0.1-0.2 look more like peak SNR magnitudes than
# fractional positions within the window. Confirm against the pipeline that
# value * WINDOW_SIZE really is the leakage sample before citing the marker.
BYTE_PEAK_SNR = {
    0: 0.1523, 1: 0.1843, 2: 0.2103, 3: 0.1402,
    4: 0.1652, 5: 0.1891, 6: 0.1340, 7: 0.1578,
    8: 0.1815, 9: 0.1253, 10: 0.1489, 11: 0.1725,
    12: 0.1161, 13: 0.1396, 14: 0.1631, 15: 0.1065,
}

# Bytes shown in the 4-column comparison figures (5, 6, 8 and 10).
REPRESENTATIVE_BYTES = [0, 2, 8, 14]


# ============================================================================
# Data loading
# ============================================================================

def load_gradient_map(run_name, byte_idx=None):
    """Load one gradient map for a run.

    Args:
        run_name: Sub-directory of GRADIENT_MAPS_DIR holding the run's maps.
        byte_idx: For multi-task runs, the byte whose map to load
            (``gradient_map_byte{byte_idx}.npy``). If that file is absent,
            or when ``byte_idx`` is None, ``gradient_map.npy`` (single-byte
            runs) is used instead.

    Returns:
        The loaded NumPy array, or None if the run or file does not exist.
    """
    run_dir = os.path.join(GRADIENT_MAPS_DIR, run_name)
    if not os.path.exists(run_dir):
        return None

    candidates = []
    if byte_idx is not None:
        candidates.append(f"gradient_map_byte{byte_idx}.npy")
    candidates.append("gradient_map.npy")

    for filename in candidates:
        path = os.path.join(run_dir, filename)
        if os.path.exists(path):
            return np.load(path)
    return None


def load_all_bytes(run_name):
    """Load every available per-byte map (bytes 0-15) of a multi-task run.

    Returns:
        Dict mapping byte index to its NumPy array. Bytes without a
        ``gradient_map_byte{b}.npy`` file are omitted. The dict is empty if
        the run directory does not exist.
    """
    maps = {}
    run_dir = os.path.join(GRADIENT_MAPS_DIR, run_name)
    if not os.path.exists(run_dir):
        return maps
    for b in range(16):
        path = os.path.join(run_dir, f"gradient_map_byte{b}.npy")
        if os.path.exists(path):
            maps[b] = np.load(path)
    return maps


def _load_lmic_tsbn(desync):
    """Load LMIC-TSBN per-byte maps for ``desync``: V7b rerun first, then V8a."""
    maps = load_all_bytes(f"RERUN-LMIC-TSBN-V7b-multibit-desync{desync}")
    if not maps:
        maps = load_all_bytes(f"LMIC-TSBN-V8a-bitDTP-desync{desync}")
    return maps


def _load_mlp(byte_idx):
    """Load the desync=0 MLP map for a byte: RERUN run first, then CLEAN run."""
    # Explicit None check: ``array or other`` raises ValueError for
    # multi-element NumPy arrays.
    sal = load_gradient_map(f"RERUN-MLP-byte{byte_idx}-desync0")
    if sal is None:
        sal = load_gradient_map(f"CLEAN-MLP-byte{byte_idx}-desync0")
    return sal


def normalize_saliency(sal):
    """Scale saliency so its maximum is 1 for visualization.

    Empty or all-zero arrays are returned unchanged, to avoid a reduction
    error or division by zero.
    """
    if sal.size == 0 or sal.max() == 0:
        return sal
    return sal / sal.max()


# ============================================================================
# Plotting helpers
# ============================================================================

def _snr_peak_sample(byte_idx):
    """Return the sample index at which the red SNR-peak marker is drawn."""
    return int(BYTE_PEAK_SNR[byte_idx] * WINDOW_SIZE)


def _plot_trace(ax, values, color, linewidth=0.5):
    """Draw a filled saliency trace, using its own length for the x-axis."""
    x = np.arange(len(values))
    ax.fill_between(x, values, alpha=0.3, color=color)
    ax.plot(x, values, linewidth=linewidth, color=color)
    ax.set_ylim(0, 1.05)


def _mark_snr_peak(ax, byte_idx, linewidth=0.8, label=None):
    """Draw the dashed red SNR-peak marker; ``label`` adds it to the legend."""
    kwargs = {} if label is None else {'label': label}
    ax.axvline(_snr_peak_sample(byte_idx), color='red', linestyle='--',
               alpha=0.7, linewidth=linewidth, **kwargs)


def _mark_missing(ax, message):
    """Write a centred placeholder message in a panel with no data."""
    ax.text(0.5, 0.5, message, transform=ax.transAxes,
            ha='center', va='center')


def _plot_byte_row(axes_row, maps, byte_indices, color, row_label, *,
                   mark_peak=False, hide_yticks=False, show_titles=False):
    """Fill one row of per-byte panels from a ``{byte_idx: map}`` dict.

    Each map is normalized to [0, 1]. Bytes that are missing, or whose value
    is None, get an "N/A" placeholder. The leftmost panel carries
    ``row_label`` as its y-label. When ``show_titles`` is set, every panel is
    titled "Byte N".
    """
    for col, (ax, byte_idx) in enumerate(zip(axes_row, byte_indices)):
        sal = maps.get(byte_idx)
        if sal is not None:
            _plot_trace(ax, normalize_saliency(sal), color)
            if mark_peak:
                _mark_snr_peak(ax, byte_idx)
        else:
            _mark_missing(ax, "N/A")
        ax.set_xticks([])
        if hide_yticks:
            ax.set_yticks([])
        if col == 0:
            ax.set_ylabel(row_label, fontsize=9, fontweight='bold')
        if show_titles:
            ax.set_title(f"Byte {byte_idx}", fontsize=9)


def _save_figure(fig, filename):
    """Lay out ``fig``, write it to OUTPUT_DIR/filename and close it."""
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, filename))
    plt.close(fig)


# ============================================================================
# Figure 1: LMIC-TSBN vs CNN vs MLP (byte 2, desync=0)
# Shows that LMIC-TSBN focuses sharply on correct leakage point
# Supports Claims 14, 19, 44 (Ch4)
# ============================================================================

def figure1_architecture_comparison():
    """Compare gradient saliency across architectures for byte 2 at desync=0."""
    byte_idx = 2
    fig, axes = plt.subplots(3, 1, figsize=(8, 6), sharex=True)

    lmic = load_gradient_map("RERUN-LMIC-TSBN-V7b-multibit-desync0",
                             byte_idx=byte_idx)
    if lmic is None:
        lmic = load_gradient_map("LMIC-TSBN-V8a-bitDTP-desync0",
                                 byte_idx=byte_idx)
    cnn = load_gradient_map(f"RERUN-CNN-byte{byte_idx}-desync0")
    mlp = _load_mlp(byte_idx)

    models = [
        ("LMIC-TSBN (Multi-Task)", lmic, 'tab:blue'),
        ("CNN (Single-Byte)", cnn, 'tab:orange'),
        ("MLP (Single-Byte)", mlp, 'tab:green'),
    ]
    peak_pos = _snr_peak_sample(byte_idx)

    for ax, (name, sal, color) in zip(axes, models):
        if sal is None:
            _mark_missing(ax, f"{name}: Data not available")
            continue
        _plot_trace(ax, normalize_saliency(sal), color, linewidth=0.8)
        ax.set_ylabel("Norm. Saliency")
        ax.set_title(name, fontsize=10, fontweight='bold')
        _mark_snr_peak(ax, byte_idx, linewidth=1,
                       label=f'SNR Peak (sample {peak_pos})')
        ax.legend(loc='upper right')

    axes[-1].set_xlabel("Time Sample (within 700-sample window)")
    fig.suptitle("Gradient Saliency Comparison: Byte 2, Desync=0",
                 fontsize=12, fontweight='bold')
    _save_figure(fig, "fig1_architecture_comparison_byte2.png")
    print(" Figure 1: Architecture comparison saved")


# ============================================================================
# Figure 2: CNN degradation under desync (byte 0)
# Shows CNN losing focus on correct leakage point as desync increases
# Supports Claims 17, 18, 19 (Ch4)
# ============================================================================

def figure2_cnn_desync_degradation():
    """Show CNN gradient maps for byte 0 degrading under increasing desync."""
    fig, axes = plt.subplots(3, 1, figsize=(8, 6), sharex=True)

    for ax, desync in zip(axes, (0, 50, 100)):
        sal = load_gradient_map(f"RERUN-CNN-byte0-desync{desync}")
        if sal is None:
            _mark_missing(ax, "Data not available")
            continue
        _plot_trace(ax, normalize_saliency(sal), 'tab:orange', linewidth=0.8)
        ax.set_ylabel("Norm. Saliency")
        ax.set_title(f"CNN Byte 0, Desync={desync}",
                     fontsize=10, fontweight='bold')
        _mark_snr_peak(ax, 0, linewidth=1)
        # Raw (un-normalized) statistics keep absolute gradient strength
        # comparable across the three panels.
        ax.text(0.98, 0.85,
                f"max={sal.max():.4f}\nmean={sal.mean():.4f}",
                transform=ax.transAxes, ha='right', va='top', fontsize=7,
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    axes[-1].set_xlabel("Time Sample (within 700-sample window)")
    fig.suptitle("CNN Gradient Saliency Degradation Under Desynchronization (Byte 0)",
                 fontsize=12, fontweight='bold')
    _save_figure(fig, "fig2_cnn_desync_degradation.png")
    print(" Figure 2: CNN desync degradation saved")


# ============================================================================
# Figure 3: LMIC-TSBN 16-byte grid (desync=0)
# Shows uniform attention across all bytes
# Supports Claims 44, 45 (Ch4)
# ============================================================================

def figure3_lmic_tsbn_16byte_grid():
    """Show LMIC-TSBN gradient maps for all 16 bytes in a 4x4 grid."""
    fig, axes = plt.subplots(4, 4, figsize=(12, 8))
    maps = _load_lmic_tsbn(0)

    # axes.flat walks the grid row-major, so panel i shows byte i.
    for byte_idx, ax in enumerate(axes.flat):
        row, col = divmod(byte_idx, 4)
        sal = maps.get(byte_idx)
        if sal is not None:
            _plot_trace(ax, normalize_saliency(sal), 'tab:blue')
            _mark_snr_peak(ax, byte_idx)
            ax.set_title(f"Byte {byte_idx}", fontsize=8)
        else:
            _mark_missing(ax, "N/A")
        ax.set_xticks([])
        ax.set_yticks([])
        if col == 0:
            ax.set_ylabel("Saliency", fontsize=7)
        if row == 3:
            ax.set_xlabel("Time", fontsize=7)

    fig.suptitle("LMIC-TSBN Gradient Saliency: All 16 Bytes (Desync=0)\n"
                 "Red dashed = SNR peak position",
                 fontsize=12, fontweight='bold')
    _save_figure(fig, "fig3_lmic_tsbn_16byte_grid.png")
    print(" Figure 3: LMIC-TSBN 16-byte grid saved")


# ============================================================================
# Figure 4: HPS representation competition
# Shows HPS gradient concentrated on bytes 0/1, negligible for others
# Supports Claims 29, 30, 33, 34 (Ch4)
# ============================================================================

def figure4_hps_representation_competition():
    """Show HPS gradient maps revealing representation competition.

    Uses the desync=100 run (clearest failure), falling back to desync=50.
    All bytes share one global normalization so relative magnitudes are
    visible. Skipped when neither run has data.
    """
    for desync in (100, 50):
        maps = load_all_bytes(f"HPS-baseline-desync{desync}")
        if maps:
            break
    else:
        print(" Figure 4: SKIPPED - No HPS gradient data available")
        return

    # Create the figure only once data exists, so the skip path leaks nothing.
    fig, axes = plt.subplots(4, 4, figsize=(12, 8))
    global_max = max(m.max() for m in maps.values())

    for byte_idx, ax in enumerate(axes.flat):
        sal = maps.get(byte_idx)
        if sal is not None:
            sal_global = sal / global_max if global_max > 0 else sal
            _plot_trace(ax, sal_global, 'tab:red')
            ax.set_title(f"Byte {byte_idx} (max={sal.max():.2e})", fontsize=7)
        else:
            _mark_missing(ax, "N/A")
        ax.set_xticks([])
        ax.set_yticks([])

    # Title reports the desync level that was actually loaded.
    fig.suptitle(f"HPS Gradient Saliency: Representation Competition (Desync={desync})\n"
                 f"Note: Globally normalized - bytes 0/1 dominate, others near-zero",
                 fontsize=11, fontweight='bold')
    _save_figure(fig, "fig4_hps_representation_competition.png")
    print(" Figure 4: HPS representation competition saved")


# ============================================================================
# Figure 5: Binary encoding effect on gradient quality
# Compares LMIC-TSBN (multibit) vs ABLATION-A1 (no-multibit)
# Supports Claims 53, 55, 56, 57 (Ch4)
# ============================================================================

def figure5_binary_encoding_effect():
    """Compare gradient quality with and without binary encoding (desync=0)."""
    fig, axes = plt.subplots(2, 4, figsize=(12, 5))

    # The V8a fallback depends only on the multibit run being missing.
    multibit_maps = _load_lmic_tsbn(0)
    no_multibit_maps = load_all_bytes("ABLATION-A1-no-multibit-desync0")

    _plot_byte_row(axes[0], multibit_maps, REPRESENTATIVE_BYTES,
                   'tab:blue', "Binary Enc.", show_titles=True)
    _plot_byte_row(axes[1], no_multibit_maps, REPRESENTATIVE_BYTES,
                   'tab:purple', "Identity Enc.")

    fig.suptitle("Effect of Binary Encoding on Gradient Saliency (Desync=0)\n"
                 "Binary encoding produces sharper, more focused gradients",
                 fontsize=11, fontweight='bold')
    _save_figure(fig, "fig5_binary_encoding_effect.png")
    print(" Figure 5: Binary encoding effect saved")


# ============================================================================
# Figure 6: TSBN effect on cross-byte interference
# Compares LMIC-TSBN vs LMIC (no TSBN) at desync=0
# Supports Claims 59, 60, 61 (Ch4)
# ============================================================================

def figure6_tsbn_effect():
    """Compare gradient maps with and without TSBN (desync=0)."""
    fig, axes = plt.subplots(2, 4, figsize=(12, 5))

    with_tsbn = _load_lmic_tsbn(0)
    without_tsbn = load_all_bytes("ABLATION-LMIC-no-TSBN-desync0")

    _plot_byte_row(axes[0], with_tsbn, REPRESENTATIVE_BYTES,
                   'tab:blue', "With TSBN", show_titles=True)
    _plot_byte_row(axes[1], without_tsbn, REPRESENTATIVE_BYTES,
                   'tab:red', "Without TSBN")

    fig.suptitle("Effect of Task-Specific Batch Normalization on Gradient Saliency (Desync=0)\n"
                 "TSBN prevents cross-byte interference in normalization statistics",
                 fontsize=11, fontweight='bold')
    _save_figure(fig, "fig6_tsbn_effect.png")
    print(" Figure 6: TSBN effect saved")


# ============================================================================
# Figure 7: MLP diffuse vs CNN localized attention
# Shows MLP has no spatial structure in gradients
# Supports Claim 14 (Ch4)
# ============================================================================

def figure7_mlp_vs_cnn_attention():
    """Compare MLP (diffuse) vs CNN (localized) gradient patterns at desync=0."""
    fig, axes = plt.subplots(2, 3, figsize=(10, 5))
    bytes_to_show = [0, 5, 11]

    # Missing runs load as None; _plot_byte_row renders those as "N/A".
    cnn_maps = {b: load_gradient_map(f"RERUN-CNN-byte{b}-desync0")
                for b in bytes_to_show}
    mlp_maps = {b: _load_mlp(b) for b in bytes_to_show}

    _plot_byte_row(axes[0], cnn_maps, bytes_to_show, 'tab:orange', "CNN",
                   mark_peak=True, show_titles=True)
    _plot_byte_row(axes[1], mlp_maps, bytes_to_show, 'tab:green', "MLP",
                   mark_peak=True)

    fig.suptitle("CNN vs MLP Gradient Saliency Patterns (Desync=0)\n"
                 "CNN: Localized peaks at leakage points | MLP: Diffuse, no spatial structure",
                 fontsize=11, fontweight='bold')
    _save_figure(fig, "fig7_mlp_vs_cnn_attention.png")
    print(" Figure 7: MLP vs CNN attention saved")


# ============================================================================
# Figure 8: LMIC-TSBN desync robustness
# Shows gradient maps remain stable across desync levels
# Supports Claims 68, 69, 70 (Ch4)
# ============================================================================

def figure8_lmic_desync_robustness():
    """Show LMIC-TSBN gradient stability across desync levels 0/50/100."""
    fig, axes = plt.subplots(3, 4, figsize=(12, 7))

    for row, desync in enumerate((0, 50, 100)):
        _plot_byte_row(axes[row], _load_lmic_tsbn(desync), REPRESENTATIVE_BYTES,
                       'tab:blue', f"Desync={desync}",
                       mark_peak=True, hide_yticks=True,
                       show_titles=(row == 0))

    fig.suptitle("LMIC-TSBN Gradient Saliency Stability Across Desynchronization Levels\n"
                 "Gradient focus remains on correct leakage points regardless of temporal shift",
                 fontsize=11, fontweight='bold')
    _save_figure(fig, "fig8_lmic_desync_robustness.png")
    print(" Figure 8: LMIC-TSBN desync robustness saved")


# ============================================================================
# Figure 9: Quantitative summary - mean saliency magnitude comparison
# Bar chart comparing mean gradient magnitudes across architectures
# ============================================================================

def _categorize_run(name):
    """Map a summary run name to its Figure 9 architecture label, or None."""
    if 'CNN' in name and 'RERUN' in name:
        return 'CNN'
    if 'MLP' in name and 'RERUN' in name:
        return 'MLP'
    if 'V7b-multibit' in name or 'V8a-bitDTP' in name:
        return 'LMIC-TSBN'
    if 'HPS' in name:
        return 'HPS'
    return None


def _mean_saliency(entry, arch):
    """Return one mean-saliency value for a summary entry, or None if absent.

    Single-byte models (CNN, MLP) use ``saliency_mean``. Multi-task models
    average the per-byte ``mean`` values found in ``saliency_stats``.
    """
    if arch in ('CNN', 'MLP'):
        return entry.get('saliency_mean', 0)
    stats = entry.get('saliency_stats', {})
    if not stats:
        return None
    return np.mean([v['mean'] for v in stats.values()])


def figure9_quantitative_summary():
    """Plot mean saliency magnitude per architecture and desync level.

    Reads GRADIENT_MAPS_DIR/summary.json and uses only entries with
    ``status == 'success'``. Skipped (with a message) when the file is
    missing or has no usable entries.
    """
    summary_path = os.path.join(GRADIENT_MAPS_DIR, "summary.json")
    try:
        with open(summary_path, encoding='utf-8') as f:
            summary = json.load(f)
    except FileNotFoundError:
        # Do not abort the remaining figures over a missing summary.
        print(" Figure 9: SKIPPED - summary.json not found")
        return

    # Keys look like "CNN-desync50"; values are per-run mean saliencies.
    arch_data = {}
    for entry in summary:
        if entry.get('status') != 'success':
            continue
        arch = _categorize_run(entry['name'])
        if arch is None:
            continue
        value = _mean_saliency(entry, arch)
        if value is None:
            continue
        key = f"{arch}-desync{entry.get('desync', 0)}"
        arch_data.setdefault(key, []).append(value)

    if not arch_data:
        print(" Figure 9: SKIPPED - no usable entries in summary.json")
        return

    fig, ax = plt.subplots(figsize=(10, 5))
    architectures = ['MLP', 'CNN', 'HPS', 'LMIC-TSBN']
    colors = ['tab:green', 'tab:orange', 'tab:red', 'tab:blue']
    desyncs = [0, 50, 100]
    x = np.arange(len(desyncs))
    width = 0.2

    for i, (arch, color) in enumerate(zip(architectures, colors)):
        means = []
        for d in desyncs:
            vals = arch_data.get(f"{arch}-desync{d}", [])
            means.append(np.mean(vals) if vals else 0)
        ax.bar(x + i * width, means, width, label=arch, color=color, alpha=0.8)

    ax.set_xlabel("Desynchronization Level")
    ax.set_ylabel("Mean Gradient Saliency Magnitude")
    ax.set_title("Mean Gradient Saliency by Architecture and Desynchronization Level",
                 fontweight='bold')
    # Centre each tick under its group of len(architectures) bars.
    ax.set_xticks(x + (len(architectures) - 1) / 2 * width)
    ax.set_xticklabels([f"Desync={d}" for d in desyncs])
    ax.legend()
    ax.set_yscale('log')
    _save_figure(fig, "fig9_quantitative_summary.png")
    print(" Figure 9: Quantitative summary saved")


# ============================================================================
# Figure 10: Seed sensitivity - gradient consistency across seeds
# Supports Claims 65, 66, 67 (Ch4)
# ============================================================================

def figure10_seed_sensitivity():
    """Compare LMIC-TSBN gradient maps across two random seeds (desync=0)."""
    fig, axes = plt.subplots(2, 4, figsize=(12, 5))

    for row, seed in enumerate(('seed0', 'seed1')):
        maps = load_all_bytes(f"SEED-sensitivity-{seed}-desync0")
        _plot_byte_row(axes[row], maps, REPRESENTATIVE_BYTES,
                       'tab:blue', f"Seed {row}",
                       mark_peak=True, hide_yticks=True,
                       show_titles=(row == 0))

    fig.suptitle("LMIC-TSBN Gradient Saliency Consistency Across Random Seeds (Desync=0)\n"
                 "Gradient patterns are highly reproducible across different initializations",
                 fontsize=11, fontweight='bold')
    _save_figure(fig, "fig10_seed_sensitivity.png")
    print(" Figure 10: Seed sensitivity saved")


# ============================================================================
# Main
# ============================================================================

def main():
    """Generate every figure in order, reporting input/output locations."""
    print("Generating gradient saliency visualizations...")
    print(f"Input: {GRADIENT_MAPS_DIR}")
    print(f"Output: {OUTPUT_DIR}")
    print()

    for make_figure in (
        figure1_architecture_comparison,
        figure2_cnn_desync_degradation,
        figure3_lmic_tsbn_16byte_grid,
        figure4_hps_representation_competition,
        figure5_binary_encoding_effect,
        figure6_tsbn_effect,
        figure7_mlp_vs_cnn_attention,
        figure8_lmic_desync_robustness,
        figure9_quantitative_summary,
        figure10_seed_sensitivity,
    ):
        make_figure()

    print()
    print(f"All figures saved to: {OUTPUT_DIR}")
    print("Done!")


if __name__ == "__main__":
    main()