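# ascad-v1-models / analysis/scripts/activation_energy_analysis.py
# Uploaded by lemousehunter with huggingface_hub
# ("Upload analysis/scripts/activation_energy_analysis.py with huggingface_hub"),
# commit 58d9bc3 (verified).
"""
Activation Energy Analysis for the HPS Model

Runs synthetic random traces through the HPS model and compares the
activations entering its Global Average Pooling (GAP) layer with the GAP
output. Generates a five-panel figure showing:
(a) Pre-GAP spatial activation energy heatmap (channels x spatial positions)
(b) Per-channel energy before vs after GAP
(c) Per-channel spatial variance, i.e. the component GAP averages away
(d) Spatial activation profile of the highest-variance channel vs its GAP value
(e) Overall energy budget: share of activation energy retained vs lost by GAP
"""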
import os
# These environment variables are set before `import tensorflow` below so they
# are already present when TensorFlow initialises (presumably TF reads them at
# import time -- keep them above the TF import if this block is reordered).
# Intent: '3' quiets TensorFlow's C++ logging; '0' opts out of TF's oneDNN
# optimisations (exact effect depends on the TF version/build).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import LogNorm  # NOTE(review): not referenced anywhere in this script
import warnings
# Suppresses ALL Python warnings for the rest of the script. This includes
# NumPy RuntimeWarnings (e.g. mean of an empty slice, division by zero), so
# such problems show up only as NaN values rather than as warnings.
warnings.filterwarnings('ignore')
# Load the trained HPS model. compile=False: the model is not compiled after
# loading -- this script only runs forward passes, so no optimizer/loss is needed.
MODEL_PATH = '/home/ubuntu/models/hps_model.h5'
print("Loading HPS model...")
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
# ------------------------------------------------------------
# Locate the Global Average Pooling (GAP) layer and the layer before it
# ------------------------------------------------------------


def _safe_output_shape(layer):
    """Return ``layer.output_shape``, or ``None`` if Keras cannot report it.

    Keras raises ``AttributeError`` here for layers that were never called or
    that have several inbound nodes with different shapes, and newer Keras
    releases may not define ``output_shape`` at all. Only those failures are
    swallowed; anything else (including KeyboardInterrupt) propagates.
    """
    try:
        return layer.output_shape
    except (AttributeError, RuntimeError):
        return None


gap_layer = None
pre_gap_layer = None

# Strategy 1: match by layer name or by GlobalAveragePooling1D type.
# Index 0 is skipped because it has no predecessor in model.layers
# (model.layers[-1] would silently wrap around to the model's *last* layer).
for i, layer in enumerate(model.layers[1:], start=1):
    if 'global_average_pooling' in layer.name.lower() or isinstance(layer, tf.keras.layers.GlobalAveragePooling1D):
        gap_layer = layer
        pre_gap_layer = model.layers[i - 1]
        print(f"Found GAP layer: {gap_layer.name} (index {i})")
        print(f"Pre-GAP layer: {pre_gap_layer.name} (index {i-1})")
        break

if gap_layer is None:
    # List plausible pooling layers to help diagnose the model structure.
    for layer in model.layers:
        name = layer.name.lower()
        if 'pool' in name or 'gap' in name:
            print(f" Candidate pooling layer: {layer.name}, output: {_safe_output_shape(layer)}")

    # Strategy 2: first adjacent pair that reduces (batch, steps, channels)
    # to (batch, channels). Requiring the channel count to be preserved
    # rejects Flatten-style reshapes, whose outputs are not per-channel values
    # and would break the per-channel indexing used later in this script.
    # NOTE(review): a GlobalMaxPooling1D would still match; the energy
    # decomposition later assumes the reduction is a spatial mean.
    for prev, layer in zip(model.layers, model.layers[1:]):
        out_shape = _safe_output_shape(layer)
        if not (isinstance(out_shape, tuple) and len(out_shape) == 2):
            continue
        prev_shape = _safe_output_shape(prev)
        if isinstance(prev_shape, tuple) and len(prev_shape) == 3 and prev_shape[-1] == out_shape[-1]:
            print(f"Found dimension reduction: {prev.name} {prev_shape} -> {layer.name} {out_shape}")
            gap_layer = layer
            pre_gap_layer = prev
            break

if gap_layer is None:
    # Fail with a clear message instead of an AttributeError on None below.
    raise RuntimeError(
        "Could not find a GlobalAveragePooling1D layer (or a channel-preserving "
        "rank-3 -> rank-2 reduction) in the model; cannot analyse pre-GAP activations."
    )

print(f"\nPre-GAP output shape: {pre_gap_layer.output.shape}")
print(f"GAP output shape: {gap_layer.output.shape}")
# Build a probe model that returns, for one forward pass, both the tensor fed
# into GAP and GAP's output. ``gap_layer.input`` is by definition the tensor
# GAP averages. By contrast, the layer listed just before GAP in model.layers
# only feeds it in purely sequential graphs. The energy bookkeeping later in
# this script relies on the GAP output being the spatial mean of exactly this
# tensor.
input_layer = model.input
pre_gap_output = gap_layer.input
gap_output = gap_layer.output
sub_model = tf.keras.Model(inputs=input_layer, outputs=[pre_gap_output, gap_output])
# Synthetic input: Gaussian-noise traces with the model's per-sample input
# shape. Author's rationale: the goal is to probe how the learned weights
# distribute activation energy, not to characterise real ASCAD traces.
input_shape = model.input_shape[1:]
print(f"Input shape: {input_shape}")

# 100 traces (the count quoted in the report), seeded for reproducibility.
np.random.seed(42)
batch_size = 100
X_batch = np.random.standard_normal(size=(batch_size, *input_shape)).astype(np.float32)

print("Running forward pass...")
pre_gap_acts, gap_acts = sub_model.predict(X_batch, verbose=0)
# Expected: pre-GAP (100, spatial, channels); post-GAP (100, channels).
for tag, acts in (("Pre-GAP", pre_gap_acts), ("Post-GAP", gap_acts)):
    print(f"{tag} activations shape: {acts.shape}")
# ------------------------------------------------------------
# Activation energy metrics
# ------------------------------------------------------------
# Energy = sum of squared activations (L2 energy).
# pre_gap_acts is indexed as (batch, spatial, channels) below; this assumes a
# channels-last GAP input -- TODO confirm if the model uses another format.
spatial_dim = pre_gap_acts.shape[1]
channel_dim = pre_gap_acts.shape[2]

# Per-position, per-channel energy averaged over the batch: (spatial, channels).
energy_map = np.mean(pre_gap_acts ** 2, axis=0)
print(f"Energy map shape: {energy_map.shape}")

# Spatial variance per channel, computed per trace and then averaged over the
# batch: this is the component of each channel's signal that GAP averages away.
per_channel_variance = np.mean(np.var(pre_gap_acts, axis=1), axis=0)  # (channels,)
per_channel_mean_energy = np.mean(np.mean(pre_gap_acts ** 2, axis=1), axis=0)  # (channels,)

# Total energy before and after GAP. For each trace and channel over S positions,
#     sum_t a_t**2 = S * mean**2 + S * var      (np.var uses ddof=0),
# so if the GAP output is the spatial mean, scaling its energy (sum of mean**2)
# by S makes it comparable to the pre-GAP energy. The shortfall is exactly the
# spatial-variance energy.
total_energy_pre = np.mean(np.sum(pre_gap_acts ** 2, axis=(1, 2)))
total_energy_post = np.mean(np.sum(gap_acts ** 2, axis=1)) * spatial_dim  # scale back
if not (np.isfinite(total_energy_pre) and total_energy_pre > 0):
    # All-zero (e.g. entirely inactive) or non-finite activations make the
    # ratio undefined. Fail loudly rather than propagate NaN: warnings are
    # globally suppressed at the top of this script.
    raise RuntimeError(
        f"Pre-GAP activation energy is {total_energy_pre!r}; "
        "cannot compute the retained-energy ratio."
    )
energy_retained = total_energy_post / total_energy_pre * 100
energy_lost = 100 - energy_retained

# Variance-based summary
total_variance_pre = np.sum(per_channel_variance)
mean_variance = np.mean(per_channel_variance)
print("\n=== ENERGY ANALYSIS ===")
print(f"Total pre-GAP energy: {total_energy_pre:.2f}")
print(f"Total post-GAP energy (scaled): {total_energy_post:.2f}")
print(f"Energy retained: {energy_retained:.1f}%")
print(f"Energy lost: {energy_lost:.1f}%")
print(f"Mean per-channel spatial variance: {mean_variance:.2f}")
print(f"Total spatial variance: {total_variance_pre:.2f}")

# Channel indices ordered by spatial variance, highest first (used by the panels).
channel_variance_order = np.argsort(per_channel_variance)[::-1]
# ============================================================
# Generate the figure
# ============================================================
fig = plt.figure(figsize=(16, 14))
gs = gridspec.GridSpec(3, 2, height_ratios=[1.2, 1, 1], hspace=0.35, wspace=0.3)

# --- Panel (a): Pre-GAP Activation Energy Heatmap ---
ax1 = fig.add_subplot(gs[0, :])
# Show (up to) the top 64 channels by spatial variance for readability.
n_top_channels = min(64, channel_dim)
top_channels = channel_variance_order[:n_top_channels]
energy_subset = energy_map[:, top_channels].T  # (n_top_channels, spatial)

# Downsample the spatial axis into (up to) 100 contiguous bins.
# Fixed-width bins of spatial_dim // 100 positions would drop the trailing
# spatial_dim % 100 positions, and would be empty (NaN) when spatial_dim < 100.
# np.array_split assigns every position to a bin. Capping the bin count at
# spatial_dim guarantees every bin is non-empty.
spatial_bins = min(100, spatial_dim)
energy_binned = np.stack(
    [chunk.mean(axis=1) for chunk in np.array_split(energy_subset, spatial_bins, axis=1)],
    axis=1,
)  # (n_top_channels, spatial_bins)

im = ax1.imshow(energy_binned, aspect='auto', cmap='hot', interpolation='nearest')
ax1.set_xlabel(f'Spatial Position (binned into {spatial_bins} segments)', fontsize=11)
ax1.set_ylabel('Channel (sorted by variance)', fontsize=11)
ax1.set_title(f'(a) Pre-GAP Activation Energy Map (Top {n_top_channels} Channels by Spatial Variance)',
              fontsize=13, fontweight='bold')
cbar = plt.colorbar(im, ax=ax1, shrink=0.8)
cbar.set_label('Mean Squared Activation', fontsize=10)
# Annotate the full (unbinned) dimensions of the pre-GAP tensor.
ax1.text(0.02, 0.95, f'Spatial dim: {spatial_dim} positions × {channel_dim} channels',
         transform=ax1.transAxes, fontsize=9, color='white', va='top',
         bbox=dict(boxstyle='round', facecolor='black', alpha=0.7))
# --- Panel (b): Per-Channel Energy Before vs After GAP ---
ax2 = fig.add_subplot(gs[1, 0])
# For the 64 highest-variance channels: mean squared activation per spatial
# position (pre-GAP) side by side with the mean squared GAP output (post-GAP).
channels_sorted = channel_variance_order[:64]
post_gap_channel_energy = np.mean(gap_acts ** 2, axis=0)
pre_energy_sorted = per_channel_mean_energy[channels_sorted]
post_energy_sorted = post_gap_channel_energy[channels_sorted]

x_pos = np.arange(len(channels_sorted))
bar_width = 0.35
bars1, bars2 = [
    ax2.bar(x_pos + offset, heights, bar_width, label=label, color=color, alpha=0.8)
    for offset, heights, label, color in (
        (-bar_width / 2, pre_energy_sorted, 'Pre-GAP (per position)', '#e74c3c'),
        (bar_width / 2, post_energy_sorted, 'Post-GAP (averaged)', '#3498db'),
    )
]

ax2.set_xlabel('Channel Index (sorted by variance)', fontsize=10)
ax2.set_ylabel('Mean Squared Activation', fontsize=10)
ax2.set_title('(b) Per-Channel Energy: Before vs After GAP', fontsize=12, fontweight='bold')
ax2.legend(fontsize=9)
# Label every 8th bar with its original channel index.
ax2.set_xticks(x_pos[::8])
ax2.set_xticklabels(list(map(str, channels_sorted[::8])), fontsize=8)
# --- Panel (c): Spatial Variance Distribution ---
ax3 = fig.add_subplot(gs[1, 1])
# All channel variances in descending order (reusing the precomputed ranking),
# with the mean variance drawn as a reference line.
variance_sorted = per_channel_variance[channel_variance_order]
ax3.bar(np.arange(variance_sorted.size), variance_sorted, width=1.0, color='#e67e22', alpha=0.8)
ax3.axhline(y=mean_variance, color='red', linestyle='--', linewidth=2,
            label=f'Mean variance: {mean_variance:.2f}')
ax3.set_xlabel('Channel (sorted by variance)', fontsize=10)
ax3.set_ylabel('Spatial Variance', fontsize=10)
ax3.set_title('(c) Per-Channel Spatial Variance (Destroyed by GAP)', fontsize=12, fontweight='bold')
ax3.legend(fontsize=10)
# --- Panel (d): Before/After GAP Summary ---
ax4 = fig.add_subplot(gs[2, 0])
# Illustrate the averaging on a single channel: the one with the highest
# spatial variance.
example_channel = channel_variance_order[0]
# Batch-averaged spatial profile of that channel. NOTE(review): the channel is
# ranked by per-trace spatial variance, but this curve is averaged over traces
# first, so its fluctuations may understate that variance.
example_activations = pre_gap_acts[:, :, example_channel].mean(axis=0)  # (spatial,)
gap_value = gap_acts[:, example_channel].mean()
positions = range(spatial_dim)

ax4.plot(positions, example_activations, color='#e74c3c', alpha=0.6, linewidth=0.5,
         label='Pre-GAP spatial activations')
ax4.axhline(y=gap_value, color='#3498db', linewidth=3, linestyle='-',
            label=f'Post-GAP value: {gap_value:.4f}')
# Shade the deviation between the spatial profile and its single GAP value.
ax4.fill_between(positions, example_activations, gap_value, alpha=0.15, color='red')
ax4.set_xlabel('Spatial Position', fontsize=10)
ax4.set_ylabel('Activation Value', fontsize=10)
ax4.set_title(f'(d) GAP Collapse: Channel {example_channel} (Highest Variance)',
              fontsize=12, fontweight='bold')
ax4.legend(fontsize=9, loc='upper right')
ax4.text(0.02, 0.05, 'Red shaded area = information destroyed by averaging',
         transform=ax4.transAxes, fontsize=9, style='italic', color='#c0392b')
# --- Panel (e): Overall Energy Budget ---
ax5 = fig.add_subplot(gs[2, 1])
labels = ['Energy\nRetained', 'Energy\nDestroyed']
# Floating-point round-off can push one share marginally below zero. For
# example, when the spatial variance is ~0, energy_retained can land just above
# 100%. Axes.pie rejects negative wedge sizes, so the shares are clipped for
# plotting only; the numbers printed elsewhere are unchanged.
sizes = np.clip([energy_retained, energy_lost], 0.0, None)
colors = ['#3498db', '#e74c3c']
explode = (0, 0.05)
wedges, texts, autotexts = ax5.pie(sizes, explode=explode, labels=labels, colors=colors,
                                   autopct='%1.1f%%', shadow=True, startangle=90,
                                   textprops={'fontsize': 12})
for autotext in autotexts:
    autotext.set_fontsize(14)
    autotext.set_fontweight('bold')
ax5.set_title('(e) Activation Energy Budget After GAP', fontsize=12, fontweight='bold')

# Text box with the exact tensor sizes before and after pooling.
textstr = (f'Pre-GAP: {spatial_dim} × {channel_dim} = {spatial_dim * channel_dim:,} values\n'
           f'Post-GAP: 1 × {channel_dim} = {channel_dim} values\n'
           f'Compression ratio: {spatial_dim}:1\n'
           f'Mean spatial variance: {mean_variance:.2f}')
props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
ax5.text(-0.1, -0.15, textstr, transform=ax5.transAxes, fontsize=9,
         verticalalignment='top', bbox=props)
plt.suptitle('HPS Model: Activation Energy Analysis Before and After Global Average Pooling',
             fontsize=15, fontweight='bold', y=0.98)
figure_path = '/home/ubuntu/figures/fig7_activation_energy.png'
# savefig does not create missing directories; ensure the target folder exists.
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
plt.savefig(figure_path, dpi=200, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.close()
print(f"\nFigure saved to {figure_path}")
# Headline numbers, formatted for copying into the report.
print("\nKey numbers for report:")
report_items = (
    ("Spatial dimension", f"{spatial_dim}"),
    ("Channel dimension", f"{channel_dim}"),
    ("Energy retained", f"{energy_retained:.1f}%"),
    ("Energy destroyed", f"{energy_lost:.1f}%"),
    ("Mean per-channel variance", f"{mean_variance:.2f}"),
    ("Compression ratio", f"{spatial_dim}:1"),
)
for item_name, item_value in report_items:
    print(f" {item_name}: {item_value}")