""" Gradient and Activation Analysis: HPS GAP Information Collapse - V4 Key fix: Use sum of logits (not max) and handle softmax saturation. Also include weight-based analysis. """ import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' import numpy as np import tensorflow as tf import json from numpy.linalg import norm print("Loading HPS model...") hps = tf.keras.models.load_model('/home/ubuntu/models/hps_model.h5', compile=False) np.random.seed(42) dummy_input = np.random.randn(32, 32272, 1).astype(np.float32) pre_gap_layer = hps.get_layer('hps_drop4') gap_layer = hps.get_layer('hps_gap') # ============================================================ # Analysis 1: Structural proof - identical input to all heads # ============================================================ print("="*80) print("ANALYSIS 1: Structural Proof of Information Bottleneck") print("="*80) pre_gap_model = tf.keras.Model(inputs=hps.input, outputs=pre_gap_layer.output) post_gap_model = tf.keras.Model(inputs=hps.input, outputs=gap_layer.output) pre_gap_out = pre_gap_model.predict(dummy_input, verbose=0) post_gap_out = post_gap_model.predict(dummy_input, verbose=0) print(f"\nPre-GAP feature map: {pre_gap_out.shape}") print(f" = {pre_gap_out.shape[1]} spatial positions x {pre_gap_out.shape[2]} channels") print(f"Post-GAP vector: {post_gap_out.shape}") print(f" = 1 vector of {post_gap_out.shape[1]} values (spatial info destroyed)") manual_gap = pre_gap_out.mean(axis=1) gap_error = np.abs(manual_gap - post_gap_out).max() print(f"\nVerification: max|manual_avg - GAP_output| = {gap_error:.2e}") # ============================================================ # Analysis 2: Pre-GAP spatial activation analysis # ============================================================ print("\n" + "="*80) print("ANALYSIS 2: Pre-GAP Spatial Activation Analysis") print("="*80) spatial_variance = np.var(pre_gap_out, axis=1) # (32, 512) print(f"Mean spatial variance per channel: {spatial_variance.mean():.4f}") print(f"Max spatial variance per channel: {spatial_variance.max():.4f}") # Compute how much information is lost by averaging # Information loss = variance of the spatial dimension (what GAP discards) total_energy = np.sum(pre_gap_out**2, axis=1).mean() # mean over batch gap_energy = np.sum(post_gap_out**2, axis=1).mean() print(f"\nTotal pre-GAP energy (mean L2): {total_energy:.4f}") print(f"Post-GAP energy (mean L2): {gap_energy:.4f}") print(f"Energy ratio retained: {gap_energy/total_energy:.4f} ({100*gap_energy/total_energy:.1f}%)") print(f"Energy lost by averaging: {100*(1-gap_energy/total_energy):.1f}%") # Spatial entropy channel_means = np.abs(pre_gap_out.mean(axis=0)) # (1008, 512) entropies = [] for ch in range(512): acts = channel_means[:, ch] total = acts.sum() if total > 0: probs = acts / total entropy = -np.sum(probs * np.log(probs + 1e-10)) max_entropy = np.log(len(probs)) entropies.append(entropy / max_entropy) entropies = np.array(entropies) print(f"\nNormalized spatial entropy: {entropies.mean():.4f}") print(f" (1.0 = uniform spread, 0.0 = concentrated)") print(f" High entropy confirms activations spread across all {pre_gap_out.shape[1]} positions") # ============================================================ # Analysis 3: Weight-based proof of identical representations # ============================================================ print("\n" + "="*80) print("ANALYSIS 3: Byte Head Weight Analysis") print("="*80) byte_layer_names = sorted( [l.name for l in hps.layers if 
l.name.startswith('byte_') and l.name.count('_') == 1], key=lambda x: int(x.replace('byte_', '')) ) byte_weights = {} for name in byte_layer_names: layer = hps.get_layer(name) W, b = layer.get_weights() byte_weights[name] = {'W': W, 'b': b} # Cosine similarity between byte head weight matrices n = 16 weight_cos_sim = np.zeros((n, n)) for i in range(n): Wi = byte_weights[f'byte_{i}']['W'].flatten() for j in range(n): Wj = byte_weights[f'byte_{j}']['W'].flatten() weight_cos_sim[i, j] = np.dot(Wi, Wj) / (norm(Wi) * norm(Wj) + 1e-10) upper_tri = weight_cos_sim[np.triu_indices(n, k=1)] print(f"\nWeight cosine similarity across all 120 byte pairs:") print(f" Mean: {upper_tri.mean():.6f}") print(f" Min: {upper_tri.min():.6f}") print(f" Max: {upper_tri.max():.6f}") print(f" Std: {upper_tri.std():.6f}") # Full matrix print(f"\n{'':>8s}", end='') for j in range(16): print(f" b{j:>2d}", end='') print() for i in range(16): print(f"byte_{i:>2d}", end='') for j in range(16): print(f" {weight_cos_sim[i,j]:5.2f}", end='') print() # Byte 0,1 vs others print(f"\n--- Succeeded (bytes 0,1) vs Failed (bytes 2-15) ---") b01_sim = weight_cos_sim[0, 1] b0_vs_failed = [weight_cos_sim[0, j] for j in range(2, 16)] b1_vs_failed = [weight_cos_sim[1, j] for j in range(2, 16)] failed_pairs = [weight_cos_sim[i, j] for i in range(2, 16) for j in range(i+1, 16)] print(f"Byte 0 vs Byte 1 (both succeeded): {b01_sim:.4f}") print(f"Byte 0 vs failed bytes (2-15): mean={np.mean(b0_vs_failed):.4f}") print(f"Byte 1 vs failed bytes (2-15): mean={np.mean(b1_vs_failed):.4f}") print(f"Failed vs failed (2-15 pairs): mean={np.mean(failed_pairs):.4f}") # Weight magnitude print(f"\n--- Weight magnitude per byte head ---") print(f"{'Byte':>8s} {'W_norm':>10s} {'W_mean':>10s} {'W_std':>10s} {'b_mean':>10s} {'Rank':>6s}") for i, name in enumerate(byte_layer_names): W = byte_weights[name]['W'] b = byte_weights[name]['b'] rank_str = "0" if i < 2 else "FAIL" print(f"{name:>8s} {norm(W):10.4f} {W.mean():10.6f} {W.std():10.6f} {b.mean():10.6f} {rank_str:>6s}") # ============================================================ # Analysis 4: Output predictions - are failed bytes near-uniform? 
# ============================================================ print("\n" + "="*80) print("ANALYSIS 4: Output Prediction Distributions") print("="*80) all_preds = hps.predict(dummy_input, verbose=0) # all_preds is a list of 16 arrays, each (32, 256) print(f"\nPer-byte prediction statistics (batch of {dummy_input.shape[0]}):") print(f"{'Byte':>8s} {'Max_prob':>10s} {'Entropy':>10s} {'Max_ent':>10s} {'Ent_ratio':>10s} {'Rank':>6s}") max_entropy = np.log(256) for i in range(16): preds = all_preds[i] # (32, 256) max_probs = preds.max(axis=1).mean() # Compute entropy ent = -np.sum(preds * np.log(preds + 1e-10), axis=1).mean() rank_str = "0" if i < 2 else "FAIL" print(f"byte_{i:>2d} {max_probs:10.6f} {ent:10.4f} {max_entropy:10.4f} {ent/max_entropy:10.4f} {rank_str:>6s}") # ============================================================ # Save all results # ============================================================ results = { 'pre_gap_shape': [int(x) for x in pre_gap_out.shape], 'post_gap_shape': [int(x) for x in post_gap_out.shape], 'gap_verification_error': float(gap_error), 'spatial_variance_mean': float(spatial_variance.mean()), 'spatial_entropy_mean': float(entropies.mean()), 'energy_retained_pct': float(100*gap_energy/total_energy), 'weight_cosine_similarity': { 'mean': float(upper_tri.mean()), 'min': float(upper_tri.min()), 'max': float(upper_tri.max()), 'std': float(upper_tri.std()), }, 'byte0_vs_byte1_weight_sim': float(b01_sim), 'byte0_vs_failed_mean': float(np.mean(b0_vs_failed)), 'failed_vs_failed_mean': float(np.mean(failed_pairs)), } with open('/home/ubuntu/gap_analysis_results.json', 'w') as f: json.dump(results, f, indent=2) print("\nResults saved to gap_analysis_results.json") print("ANALYSIS COMPLETE")
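# ------------------------------------------------------------
# Optional gradient check (sketch, not part of the saved results).
# The module docstring mentions a gradient analysis based on the summed
# logits; the saved .h5 exposes only softmax outputs, so this sketch uses
# the summed log-probability of each head as a stand-in (it sidesteps
# softmax saturation in the same spirit). It reuses hps, dummy_input, np,
# tf, and norm defined above; the batch is truncated only to keep the 16
# backward passes cheap. Near-identical input gradients across heads would
# corroborate the weight-level finding that the heads are interchangeable.
# ------------------------------------------------------------
x = tf.convert_to_tensor(dummy_input[:4])
head_grads = []
for i in range(16):
    with tf.GradientTape() as tape:
        tape.watch(x)
        preds = hps(x, training=False)  # list of 16 (batch, 256) softmax tensors
        score = tf.reduce_sum(tf.math.log(preds[i] + 1e-10))  # summed log-probs for head i
    head_grads.append(tape.gradient(score, x).numpy().reshape(x.shape[0], -1))

grad_sim = np.zeros((16, 16))
for i in range(16):
    for j in range(16):
        gi, gj = head_grads[i].flatten(), head_grads[j].flatten()
        grad_sim[i, j] = np.dot(gi, gj) / (norm(gi) * norm(gj) + 1e-10)
upper = grad_sim[np.triu_indices(16, k=1)]
print(f"\n[Sketch] Input-gradient cosine similarity across byte heads: "
      f"mean={upper.mean():.4f}, min={upper.min():.4f}, max={upper.max():.4f}")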