"""
Activation Energy Analysis for HPS Model

Generates a visualization showing:
1. Pre-GAP spatial activation energy heatmap (channels x spatial positions)
2. Per-channel energy before vs after GAP
3. Overall energy loss quantification
"""
import os

# These must be set before TensorFlow is imported to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import LogNorm
import warnings

warnings.filterwarnings('ignore')

MODEL_PATH = '/home/ubuntu/models/hps_model.h5'
FIGURE_PATH = '/home/ubuntu/figures/fig7_activation_energy.png'
TOP_CHANNELS = 64        # number of highest-variance channels shown in panels (a)/(b)
MAX_SPATIAL_BINS = 100   # upper bound on heat-map columns in panel (a)


def _output_shape(layer):
    """Return ``layer.output.shape`` as a plain tuple, or None if unavailable.

    ``layer.output`` can raise (layer never called, several inbound nodes) or
    be a list (multi-output layer); such layers are simply not candidates for
    the shape-based search, so they are reported as None instead of aborting.
    """
    try:
        return tuple(layer.output.shape)
    except (AttributeError, TypeError, ValueError, RuntimeError):
        return None


def _find_gap_layers(layers):
    """Locate the global-average-pooling layer and the layer listed before it.

    Search order:
      1. a layer whose name contains 'global_average_pooling' or that is a
         ``GlobalAveragePooling1D`` instance;
      2. otherwise, the first layer with a rank-2 output whose predecessor in
         ``layers`` has a rank-3 output.

    NOTE(review): "predecessor" means the previous entry of ``layers``; this
    assumes a sequential topology -- confirm for branched models.

    Returns:
        (gap_layer, pre_gap_layer)

    Raises:
        RuntimeError: if neither search finds a matching pair.
    """
    # Pair every layer with its predecessor; starting at index 1 avoids the
    # negative-index wrap-around that `layers[i - 1]` has for i == 0.
    indexed_pairs = list(enumerate(zip(layers, layers[1:]), start=1))

    for i, (prev, layer) in indexed_pairs:
        if ('global_average_pooling' in layer.name.lower()
                or isinstance(layer, tf.keras.layers.GlobalAveragePooling1D)):
            print(f"Found GAP layer: {layer.name} (index {i})")
            print(f"Pre-GAP layer: {prev.name} (index {i-1})")
            return layer, prev

    # Diagnostics: list anything that looks like a pooling layer.
    for layer in layers:
        name = layer.name.lower()
        if 'pool' in name or 'gap' in name:
            print(f"  Candidate pooling layer: {layer.name}, output: {_output_shape(layer)}")

    # Search more broadly: first rank-3 -> rank-2 transition.
    for _, (prev, layer) in indexed_pairs:
        out_shape = _output_shape(layer)
        prev_shape = _output_shape(prev)
        if out_shape is None or prev_shape is None:
            continue
        if len(out_shape) == 2 and len(prev_shape) == 3:
            print(f"Found dimension reduction: {prev.name} {prev_shape} -> {layer.name} {out_shape}")
            return layer, prev

    raise RuntimeError(
        "Could not locate a global-average-pooling layer (or any rank-3 -> "
        "rank-2 transition) in the model"
    )


def _bin_spatial_energy(energy, max_bins):
    """Average ``energy`` (channels, spatial) into at most ``max_bins`` columns.

    Bin edges are computed with integer arithmetic so that every spatial
    position belongs to exactly one bin and no bin is empty, including when
    the spatial size is smaller than ``max_bins`` or not a multiple of it.
    """
    spatial = energy.shape[1]
    n_bins = max(1, min(max_bins, spatial))
    edges = (np.arange(n_bins + 1) * spatial) // n_bins
    columns = [energy[:, lo:hi].mean(axis=1) for lo, hi in zip(edges[:-1], edges[1:])]
    return np.stack(columns, axis=1)


# Load the HPS model
print("Loading HPS model...")
model = tf.keras.models.load_model(MODEL_PATH, compile=False)

# Find the GAP layer and the layer before it
gap_layer, pre_gap_layer = _find_gap_layers(model.layers)

print(f"\nPre-GAP output shape: {pre_gap_layer.output.shape}")
print(f"GAP output shape: {gap_layer.output.shape}")

# Create a sub-model that outputs both the pre-GAP and the GAP activations
input_layer = model.input
pre_gap_output = pre_gap_layer.output
gap_output = gap_layer.output
sub_model = tf.keras.Model(inputs=input_layer, outputs=[pre_gap_output, gap_output])

# Generate synthetic input data (random traces, same shape as ASCAD)
# We use random data because we're analyzing the model's learned spatial structure,
# not the data itself. The weights determine how energy is distributed.
input_shape = model.input_shape[1:]
print(f"Input shape: {input_shape}")

# Use 100 random traces as stated in the report.
# The legacy seeded global RNG is kept so the generated batch stays reproducible.
np.random.seed(42)
batch_size = 100
X_batch = np.random.randn(batch_size, *input_shape).astype(np.float32)

print("Running forward pass...")
pre_gap_acts, gap_acts = sub_model.predict(X_batch, verbose=0)

print(f"Pre-GAP activations shape: {pre_gap_acts.shape}")  # (100, spatial, channels)
print(f"Post-GAP activations shape: {gap_acts.shape}")  # (100, channels)

# Everything below indexes the activations as (batch, spatial, channels) and
# (batch, channels); fail early if the located layers do not match that layout.
if (pre_gap_acts.ndim != 3
        or gap_acts.shape != (pre_gap_acts.shape[0], pre_gap_acts.shape[2])):
    raise ValueError(
        f"Unexpected activation shapes: pre-GAP {pre_gap_acts.shape}, "
        f"post-GAP {gap_acts.shape}; expected (batch, spatial, channels) and "
        f"(batch, channels)"
    )

# Compute activation energy metrics
# Energy = sum of squared activations (L2 energy)
spatial_dim = pre_gap_acts.shape[1]
channel_dim = pre_gap_acts.shape[2]

# Per-position, per-channel energy (averaged over batch)
energy_map = np.mean(pre_gap_acts ** 2, axis=0)  # (spatial, channels)
print(f"Energy map shape: {energy_map.shape}")

# Per-channel spatial variance (what GAP destroys)
per_channel_variance = np.mean(np.var(pre_gap_acts, axis=1), axis=0)  # (channels,)
per_channel_mean_energy = np.mean(np.mean(pre_gap_acts ** 2, axis=1), axis=0)  # (channels,)

# Total energy before and after GAP.
# For one channel, sum_t x_t^2 = T * mean^2 + T * var, so T * mean^2 is the
# share of the energy that is still represented by the pooled value.
total_energy_pre = np.mean(np.sum(pre_gap_acts ** 2, axis=(1, 2)))
total_energy_post = np.mean(np.sum(gap_acts ** 2, axis=1)) * spatial_dim  # scale back

if not total_energy_pre > 0:
    raise ValueError("Pre-GAP activations carry no energy; energy ratios are undefined")

energy_retained = total_energy_post / total_energy_pre * 100
if energy_retained > 100 + 1e-3:
    # For a true mean over the spatial axis the ratio cannot exceed 100 %.
    raise ValueError(
        f"Post-pooling energy exceeds pre-pooling energy ({energy_retained:.3f}%); "
        f"layer '{gap_layer.name}' does not behave like average pooling"
    )
# Float rounding can overshoot 100 % by a hair, which would hand pie() a
# negative wedge size; cap it.
energy_retained = min(float(energy_retained), 100.0)
energy_lost = 100 - energy_retained

# Variance-based analysis
total_variance_pre = np.sum(per_channel_variance)
mean_variance = np.mean(per_channel_variance)

print("\n=== ENERGY ANALYSIS ===")
print(f"Total pre-GAP energy: {total_energy_pre:.2f}")
print(f"Total post-GAP energy (scaled): {total_energy_post:.2f}")
print(f"Energy retained: {energy_retained:.1f}%")
print(f"Energy lost: {energy_lost:.1f}%")
print(f"Mean per-channel spatial variance: {mean_variance:.2f}")
print(f"Total spatial variance: {total_variance_pre:.2f}")

# Sort channels by variance (descending) for visualization
channel_variance_order = np.argsort(per_channel_variance)[::-1]

# ============================================================
# Generate the figure
# ============================================================
fig = plt.figure(figsize=(16, 14))
gs = gridspec.GridSpec(3, 2, height_ratios=[1.2, 1, 1], hspace=0.35, wspace=0.3)

# --- Panel (a): Pre-GAP Activation Energy Heatmap ---
ax1 = fig.add_subplot(gs[0, :])

# Show the top channels by variance for readability
top_channels = channel_variance_order[:TOP_CHANNELS]
energy_subset = energy_map[:, top_channels].T  # (n_top, spatial)

# Downsample spatial dimension for readability
energy_binned = _bin_spatial_energy(energy_subset, MAX_SPATIAL_BINS)
spatial_bins = energy_binned.shape[1]

im = ax1.imshow(energy_binned, aspect='auto', cmap='hot', interpolation='nearest')
ax1.set_xlabel(f'Spatial Position (binned into {spatial_bins} segments)', fontsize=11)
ax1.set_ylabel('Channel (sorted by variance)', fontsize=11)
ax1.set_title(
    f'(a) Pre-GAP Activation Energy Map (Top {len(top_channels)} Channels by Spatial Variance)',
    fontsize=13, fontweight='bold')
cbar = plt.colorbar(im, ax=ax1, shrink=0.8)
cbar.set_label('Mean Squared Activation', fontsize=10)

# Add annotation showing spatial structure exists
ax1.text(0.02, 0.95, f'Spatial dim: {spatial_dim} positions × {channel_dim} channels',
         transform=ax1.transAxes, fontsize=9, color='white', va='top',
         bbox=dict(boxstyle='round', facecolor='black', alpha=0.7))

# --- Panel (b): Per-Channel Energy Before vs After GAP ---
ax2 = fig.add_subplot(gs[1, 0])

# Show per-channel: total energy vs GAP-retained energy
channels_sorted = channel_variance_order[:TOP_CHANNELS]
pre_energy_sorted = per_channel_mean_energy[channels_sorted]
post_energy_sorted = np.mean(gap_acts ** 2, axis=0)[channels_sorted]

x_pos = np.arange(len(channels_sorted))
width = 0.35
ax2.bar(x_pos - width/2, pre_energy_sorted, width,
        label='Pre-GAP (per position)', color='#e74c3c', alpha=0.8)
ax2.bar(x_pos + width/2, post_energy_sorted, width,
        label='Post-GAP (averaged)', color='#3498db', alpha=0.8)
ax2.set_xlabel('Channel Index (sorted by variance)', fontsize=10)
ax2.set_ylabel('Mean Squared Activation', fontsize=10)
ax2.set_title('(b) Per-Channel Energy: Before vs After GAP', fontsize=12, fontweight='bold')
ax2.legend(fontsize=9)
ax2.set_xticks(x_pos[::8])
ax2.set_xticklabels([str(c) for c in channels_sorted[::8]], fontsize=8)

# --- Panel (c): Spatial Variance Distribution ---
ax3 = fig.add_subplot(gs[1, 1])

variance_sorted = np.sort(per_channel_variance)[::-1]
ax3.bar(range(len(variance_sorted)), variance_sorted, color='#e67e22', alpha=0.8, width=1.0)
ax3.axhline(y=mean_variance, color='red', linestyle='--', linewidth=2,
            label=f'Mean variance: {mean_variance:.2f}')
ax3.set_xlabel('Channel (sorted by variance)', fontsize=10)
ax3.set_ylabel('Spatial Variance', fontsize=10)
ax3.set_title('(c) Per-Channel Spatial Variance (Destroyed by GAP)', fontsize=12, fontweight='bold')
ax3.legend(fontsize=10)

# --- Panel (d): Before/After GAP Summary ---
ax4 = fig.add_subplot(gs[2, 0])

# Show the averaging effect: pick the highest-variance channel
example_channel = channel_variance_order[0]
example_activations = np.mean(pre_gap_acts[:, :, example_channel], axis=0)  # (spatial,)
gap_value = np.mean(gap_acts[:, example_channel])

ax4.plot(range(spatial_dim), example_activations, color='#e74c3c', alpha=0.6,
         linewidth=0.5, label='Pre-GAP spatial activations')
ax4.axhline(y=gap_value, color='#3498db', linewidth=3, linestyle='-',
            label=f'Post-GAP value: {gap_value:.4f}')
ax4.fill_between(range(spatial_dim), example_activations, gap_value,
                 alpha=0.15, color='red')
ax4.set_xlabel('Spatial Position', fontsize=10)
ax4.set_ylabel('Activation Value', fontsize=10)
ax4.set_title(f'(d) GAP Collapse: Channel {example_channel} (Highest Variance)',
              fontsize=12, fontweight='bold')
ax4.legend(fontsize=9, loc='upper right')
ax4.text(0.02, 0.05, 'Red shaded area = information destroyed by averaging',
         transform=ax4.transAxes, fontsize=9, style='italic', color='#c0392b')

# --- Panel (e): Overall Energy Budget ---
ax5 = fig.add_subplot(gs[2, 1])

labels = ['Energy\nRetained', 'Energy\nDestroyed']
sizes = [energy_retained, energy_lost]
colors = ['#3498db', '#e74c3c']
explode = (0, 0.05)

_, _, autotexts = ax5.pie(sizes, explode=explode, labels=labels, colors=colors,
                          autopct='%1.1f%%', shadow=True, startangle=90,
                          textprops={'fontsize': 12})
for autotext in autotexts:
    autotext.set_fontsize(14)
    autotext.set_fontweight('bold')
ax5.set_title('(e) Activation Energy Budget After GAP', fontsize=12, fontweight='bold')

# Add text box with exact numbers
textstr = (f'Pre-GAP: {spatial_dim} × {channel_dim} = {spatial_dim * channel_dim:,} values\n'
           f'Post-GAP: 1 × {channel_dim} = {channel_dim} values\n'
           f'Compression ratio: {spatial_dim}:1\n'
           f'Mean spatial variance: {mean_variance:.2f}')
props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
ax5.text(-0.1, -0.15, textstr, transform=ax5.transAxes, fontsize=9,
         verticalalignment='top', bbox=props)

plt.suptitle('HPS Model: Activation Energy Analysis Before and After Global Average Pooling',
             fontsize=15, fontweight='bold', y=0.98)

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(FIGURE_PATH), exist_ok=True)
plt.savefig(FIGURE_PATH, dpi=200, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.close()

print(f"\nFigure saved to {FIGURE_PATH}")
print("\nKey numbers for report:")
print(f"  Spatial dimension: {spatial_dim}")
print(f"  Channel dimension: {channel_dim}")
print(f"  Energy retained: {energy_retained:.1f}%")
print(f"  Energy destroyed: {energy_lost:.1f}%")
print(f"  Mean per-channel variance: {mean_variance:.2f}")
print(f"  Compression ratio: {spatial_dim}:1")