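"""
Activation Energy Analysis for HPS Model

Generates a five-panel figure showing:
1. Pre-GAP spatial activation energy heatmap (channels x spatial positions)
2. Per-channel energy before vs after GAP (global average pooling)
3. Per-channel spatial variance, which GAP discards
4. One channel's spatial activation profile collapsing to its GAP value
5. Overall energy loss quantification (share of energy retained vs destroyed)

Activations are measured on a fixed-seed batch of random standard-normal
inputs, not on real data.
"""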
import os
import warnings

# Quiet TensorFlow's C++ logging and disable oneDNN custom ops; these are set
# before TensorFlow is imported so they are in place when it initializes.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '3',
    'TF_ENABLE_ONEDNN_OPTS': '0',
})

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import LogNorm

# Suppress all Python warnings for the rest of the script (installed after the
# third-party imports above).
warnings.filterwarnings('ignore')
|
|
| |
# Location of the trained HPS model file.
MODEL_PATH = '/home/ubuntu/models/hps_model.h5'

print("Loading HPS model...")
# compile=False: the script only runs inference, so the model is not re-compiled.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
|
|
| |
def _output_shape(layer):
    """Return ``layer``'s output shape as a tuple, or None if it cannot be read.

    Reads ``layer.output.shape`` rather than ``layer.output_shape``, which is
    not available on Keras 3 layers. Layers whose output tensor cannot be
    resolved (e.g. never called, or called at several sites) yield None
    instead of raising.
    """
    try:
        return tuple(layer.output.shape)
    except (AttributeError, ValueError, TypeError):
        return None


gap_layer = None
pre_gap_layer = None

# Pass 1: an explicit global-average-pooling layer, matched by name or by type.
# NOTE(review): model.layers[i - 1] is the layer feeding the GAP only for a
# linear layer stack; for branching functional models confirm against
# gap_layer.input.
for i, layer in enumerate(model.layers):
    looks_like_gap = ('global_average_pooling' in layer.name.lower()
                      or isinstance(layer, tf.keras.layers.GlobalAveragePooling1D))
    if not looks_like_gap:
        continue
    if i == 0:
        # model.layers[i - 1] would silently wrap around to the last layer.
        print(f"Skipping {layer.name}: no preceding layer to treat as pre-GAP")
        continue
    gap_layer = layer
    pre_gap_layer = model.layers[i - 1]
    print(f"Found GAP layer: {gap_layer.name} (index {i})")
    print(f"Pre-GAP layer: {pre_gap_layer.name} (index {i-1})")
    break

if gap_layer is None:
    # List anything pooling-like to help diagnose the model's structure.
    for layer in model.layers:
        name = layer.name.lower()
        if 'pool' in name or 'gap' in name:
            print(f"  Candidate pooling layer: {layer.name}, output: {_output_shape(layer)}")

    # Pass 2: the first adjacent pair whose output rank drops from 3 to 2.
    # zip() pairs each layer only with its successor, so the first layer is
    # never paired with the last one.
    for prev, layer in zip(model.layers, model.layers[1:]):
        prev_shape = _output_shape(prev)
        out_shape = _output_shape(layer)
        if prev_shape is None or out_shape is None:
            continue
        if len(prev_shape) == 3 and len(out_shape) == 2:
            print(f"Found dimension reduction: {prev.name} {prev_shape} -> {layer.name} {out_shape}")
            gap_layer = layer
            pre_gap_layer = prev
            break

if gap_layer is None:
    # Fail loudly here instead of with an AttributeError on None below.
    raise RuntimeError(
        "Could not locate a global-average-pooling layer (or any rank-3 -> rank-2 "
        "reduction) in the model; cannot run the pre/post-GAP energy analysis."
    )

print(f"\nPre-GAP output shape: {pre_gap_layer.output.shape}")
print(f"GAP output shape: {gap_layer.output.shape}")
|
|
| |
# A two-output view of the same graph: one forward pass yields both the
# feature map entering the pooling layer and the vector leaving it.
input_layer = model.input
pre_gap_output, gap_output = pre_gap_layer.output, gap_layer.output
sub_model = tf.keras.Model(
    inputs=input_layer,
    outputs=[pre_gap_output, gap_output],
)
|
|
| |
| |
| |
# Build a synthetic batch matching the model's (single) input signature.
full_input_shape = model.input_shape
if full_input_shape and isinstance(full_input_shape[0], (list, tuple)):
    raise ValueError(f"Expected a single-input model, got input shapes {full_input_shape}")
input_shape = full_input_shape[1:]
print(f"Input shape: {input_shape}")
if any(dim is None for dim in input_shape):
    # np.random.randn cannot size an axis of unknown length.
    raise ValueError(
        f"Input shape {input_shape} has unspecified (None) dimensions; "
        "fix them to concrete sizes before generating a random batch."
    )

# Standard-normal inputs with a fixed seed: a reproducible proxy, not real data.
# NOTE(review): assumes the model accepts float inputs (e.g. not integer token
# ids for an Embedding layer) — confirm.
np.random.seed(42)
batch_size = 100
X_batch = np.random.randn(batch_size, *input_shape).astype(np.float32)

print("Running forward pass...")
pre_gap_acts, gap_acts = sub_model.predict(X_batch, verbose=0)
print(f"Pre-GAP activations shape: {pre_gap_acts.shape}")
print(f"Post-GAP activations shape: {gap_acts.shape}")
|
|
| |
| |
# Everything below treats activations as (batch, spatial, channels) and the
# pooled output as (batch, channels). A pre-GAP map with several spatial axes
# (e.g. when the name match hit a GlobalAveragePooling2D layer) is flattened to
# a single spatial axis, and a pooled output that kept singleton spatial axes
# is squeezed to (batch, channels).
# NOTE(review): the flattening assumes a channels_last layout — confirm for
# channels_first models.
if pre_gap_acts.ndim > 3:
    print(f"Flattening spatial axes {pre_gap_acts.shape[1:-1]} into a single axis")
    pre_gap_acts = pre_gap_acts.reshape(pre_gap_acts.shape[0], -1, pre_gap_acts.shape[-1])
if gap_acts.ndim > 2:
    gap_acts = gap_acts.reshape(gap_acts.shape[0], -1)
if pre_gap_acts.ndim != 3:
    raise ValueError(
        f"Pre-GAP activations must be (batch, spatial, channels); got shape {pre_gap_acts.shape}"
    )

spatial_dim = pre_gap_acts.shape[1]
channel_dim = pre_gap_acts.shape[2]
|
|
| |
| |
# Per-(position, channel) energy: squared activation averaged over the batch.
energy_map = np.mean(pre_gap_acts ** 2, axis=0)
print(f"Energy map shape: {energy_map.shape}")

# Per-channel statistics, averaged over the batch:
#   per_channel_variance    - variance across spatial positions (what GAP discards)
#   per_channel_mean_energy - mean squared activation per spatial position
per_channel_variance = np.mean(np.var(pre_gap_acts, axis=1), axis=0)
per_channel_mean_energy = np.mean(np.mean(pre_gap_acts ** 2, axis=1), axis=0)

# Energy budget. For one sample and channel over T positions,
#   sum_t x_t**2 == T * mean(x)**2 + T * var(x),
# so T * (pooled value)**2 is the share a spatial average preserves and the
# remainder is the spatial variance it discards. Sums are done in float64 so
# float32 rounding cannot push the retained share above 100% (which would make
# the "lost" slice negative and break the pie chart).
pre_acts_64 = pre_gap_acts.astype(np.float64)
post_acts_64 = gap_acts.astype(np.float64)
total_energy_pre = np.mean(np.sum(pre_acts_64 ** 2, axis=(1, 2)))
total_energy_post = np.mean(np.sum(post_acts_64 ** 2, axis=1)) * spatial_dim
if not np.isfinite(total_energy_pre) or total_energy_pre <= 0:
    raise ValueError(
        f"Pre-GAP energy is {total_energy_pre}; energy retention is undefined "
        "for all-zero or non-finite activations."
    )
raw_retained = total_energy_post / total_energy_pre * 100
if raw_retained > 100.5:
    # By the identity above this cannot happen for an exact spatial mean.
    print(f"NOTE: post-GAP energy exceeds pre-GAP energy ({raw_retained:.1f}%); "
          "the detected pooling layer may not be a plain spatial average. Clipping to 100%.")
energy_retained = float(np.clip(raw_retained, 0.0, 100.0))
energy_lost = 100 - energy_retained

total_variance_pre = np.sum(per_channel_variance)
mean_variance = np.mean(per_channel_variance)

print(f"\n=== ENERGY ANALYSIS ===")
print(f"Total pre-GAP energy: {total_energy_pre:.2f}")
print(f"Total post-GAP energy (scaled): {total_energy_post:.2f}")
print(f"Energy retained: {energy_retained:.1f}%")
print(f"Energy lost: {energy_lost:.1f}%")
print(f"Mean per-channel spatial variance: {mean_variance:.2f}")
print(f"Total spatial variance: {total_variance_pre:.2f}")

# Channel indices ordered from highest to lowest spatial variance.
channel_variance_order = np.argsort(per_channel_variance)[::-1]
|
|
| |
| |
| |
fig = plt.figure(figsize=(16, 14))
gs = gridspec.GridSpec(3, 2, height_ratios=[1.2, 1, 1], hspace=0.35, wspace=0.3)

# (a) Heatmap of pre-GAP energy for the highest-variance channels.
ax1 = fig.add_subplot(gs[0, :])
top_channels = channel_variance_order[:64]
energy_subset = energy_map[:, top_channels].T  # (n_top_channels, spatial_dim)

# Average the spatial axis into at most 100 contiguous bins. np.array_split
# spreads any remainder across the bins, so no trailing positions are dropped
# when spatial_dim is not a multiple of the bin count. Capping the bin count at
# spatial_dim keeps every bin non-empty (no NaN columns for short sequences).
spatial_bins = max(1, min(100, spatial_dim))
energy_binned = np.stack(
    [chunk.mean(axis=1) for chunk in np.array_split(energy_subset, spatial_bins, axis=1)],
    axis=1,
)

im = ax1.imshow(energy_binned, aspect='auto', cmap='hot', interpolation='nearest')
ax1.set_xlabel(f'Spatial Position (binned into {spatial_bins} segments)', fontsize=11)
ax1.set_ylabel('Channel (sorted by variance)', fontsize=11)
ax1.set_title(f'(a) Pre-GAP Activation Energy Map (Top {len(top_channels)} Channels by Spatial Variance)',
              fontsize=13, fontweight='bold')
cbar = plt.colorbar(im, ax=ax1, shrink=0.8)
cbar.set_label('Mean Squared Activation', fontsize=10)

ax1.text(0.02, 0.95, f'Spatial dim: {spatial_dim} positions × {channel_dim} channels',
         transform=ax1.transAxes, fontsize=9, color='white', va='top',
         bbox=dict(boxstyle='round', facecolor='black', alpha=0.7))
|
|
| |
# (b) Grouped bars: per-position energy before pooling vs. pooled energy after.
ax2 = fig.add_subplot(gs[1, 0])

channels_sorted = channel_variance_order[:64]
pre_energy_sorted = per_channel_mean_energy[channels_sorted]
post_energy_sorted = (gap_acts ** 2).mean(axis=0)[channels_sorted]

x_pos = np.arange(len(channels_sorted))
width = 0.35
bar_specs = (
    (-width / 2, pre_energy_sorted, 'Pre-GAP (per position)', '#e74c3c'),
    (width / 2, post_energy_sorted, 'Post-GAP (averaged)', '#3498db'),
)
for shift, heights, series_label, series_color in bar_specs:
    ax2.bar(x_pos + shift, heights, width, label=series_label, color=series_color, alpha=0.8)

ax2.set_xlabel('Channel Index (sorted by variance)', fontsize=10)
ax2.set_ylabel('Mean Squared Activation', fontsize=10)
ax2.set_title('(b) Per-Channel Energy: Before vs After GAP', fontsize=12, fontweight='bold')
ax2.legend(fontsize=9)

# Label every 8th bar with its original channel index.
every_eighth = slice(None, None, 8)
ax2.set_xticks(x_pos[every_eighth])
ax2.set_xticklabels([str(ch) for ch in channels_sorted[every_eighth]], fontsize=8)
|
|
| |
# (c) Spatial variance of every channel, largest first, with the mean marked.
ax3 = fig.add_subplot(gs[1, 1])
variance_sorted = np.flip(np.sort(per_channel_variance))
ax3.bar(range(variance_sorted.size), variance_sorted,
        color='#e67e22', alpha=0.8, width=1.0)
ax3.axhline(y=mean_variance, color='red', linestyle='--', linewidth=2,
            label=f'Mean variance: {mean_variance:.2f}')
ax3.set_xlabel('Channel (sorted by variance)', fontsize=10)
ax3.set_ylabel('Spatial Variance', fontsize=10)
ax3.set_title('(c) Per-Channel Spatial Variance (Destroyed by GAP)',
              fontsize=12, fontweight='bold')
ax3.legend(fontsize=10)
|
|
| |
# (d) Highest-variance channel: its batch-averaged spatial trace vs. the single
# pooled value GAP reduces it to.
ax4 = fig.add_subplot(gs[2, 0])

example_channel = channel_variance_order[0]
example_activations = pre_gap_acts[:, :, example_channel].mean(axis=0)
gap_value = gap_acts[:, example_channel].mean()
positions = range(spatial_dim)

ax4.plot(positions, example_activations, color='#e74c3c', alpha=0.6, linewidth=0.5,
         label='Pre-GAP spatial activations')
ax4.axhline(y=gap_value, color='#3498db', linewidth=3, linestyle='-',
            label=f'Post-GAP value: {gap_value:.4f}')
ax4.fill_between(positions, example_activations, gap_value, alpha=0.15, color='red')

ax4.set_xlabel('Spatial Position', fontsize=10)
ax4.set_ylabel('Activation Value', fontsize=10)
ax4.set_title(f'(d) GAP Collapse: Channel {example_channel} (Highest Variance)',
              fontsize=12, fontweight='bold')
ax4.legend(fontsize=9, loc='upper right')
ax4.text(0.02, 0.05, 'Red shaded area = information destroyed by averaging',
         transform=ax4.transAxes, fontsize=9, style='italic', color='#c0392b')
|
|
| |
# (e) Pie of the energy budget, with the destroyed slice pulled out, plus a
# text box summarising the compression GAP performs.
ax5 = fig.add_subplot(gs[2, 1])

_, _, pct_labels = ax5.pie(
    [energy_retained, energy_lost],
    explode=(0, 0.05),
    labels=['Energy\nRetained', 'Energy\nDestroyed'],
    colors=['#3498db', '#e74c3c'],
    autopct='%1.1f%%',
    shadow=True,
    startangle=90,
    textprops={'fontsize': 12},
)
plt.setp(pct_labels, fontsize=14, fontweight='bold')
ax5.set_title('(e) Activation Energy Budget After GAP', fontsize=12, fontweight='bold')

summary_lines = [
    f'Pre-GAP: {spatial_dim} × {channel_dim} = {spatial_dim * channel_dim:,} values',
    f'Post-GAP: 1 × {channel_dim} = {channel_dim} values',
    f'Compression ratio: {spatial_dim}:1',
    f'Mean spatial variance: {mean_variance:.2f}',
]
ax5.text(-0.1, -0.15, '\n'.join(summary_lines), transform=ax5.transAxes, fontsize=9,
         verticalalignment='top',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
|
|
plt.suptitle('HPS Model: Activation Energy Analysis Before and After Global Average Pooling',
             fontsize=15, fontweight='bold', y=0.98)

# Single source of truth for the output path; savefig does not create missing
# parent directories, so make sure the target directory exists first.
figure_path = os.path.join('/home/ubuntu/figures', 'fig7_activation_energy.png')
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
plt.savefig(figure_path, dpi=200, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.close()

print(f"\nFigure saved to {figure_path}")
print(f"\nKey numbers for report:")
print(f"  Spatial dimension: {spatial_dim}")
print(f"  Channel dimension: {channel_dim}")
print(f"  Energy retained: {energy_retained:.1f}%")
print(f"  Energy destroyed: {energy_lost:.1f}%")
print(f"  Mean per-channel variance: {mean_variance:.2f}")
print(f"  Compression ratio: {spatial_dim}:1")
|
|