"""
Gradient and Activation Analysis: HPS GAP Information Collapse - V4
Key fix: Use sum of logits (not max) and handle softmax saturation.
Also include weight-based analysis.
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import numpy as np
import tensorflow as tf
import json
from numpy.linalg import norm
print("Loading HPS model...")
hps = tf.keras.models.load_model('/home/ubuntu/models/hps_model.h5', compile=False)
np.random.seed(42)
dummy_input = np.random.randn(32, 32272, 1).astype(np.float32)
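# Fixed-seed random noise is sufficient here: Analyses 1-3 probe architecture and
# weights, which are input-independent; Analysis 4 only needs *some* input to
# sample the heads' output distributions.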
pre_gap_layer = hps.get_layer('hps_drop4')  # last feature layer before pooling
gap_layer = hps.get_layer('hps_gap')        # global average pooling over the spatial axis
# ============================================================
# Analysis 1: Structural proof - identical input to all heads
# ============================================================
print("="*80)
print("ANALYSIS 1: Structural Proof of Information Bottleneck")
print("="*80)
pre_gap_model = tf.keras.Model(inputs=hps.input, outputs=pre_gap_layer.output)
post_gap_model = tf.keras.Model(inputs=hps.input, outputs=gap_layer.output)
pre_gap_out = pre_gap_model.predict(dummy_input, verbose=0)
post_gap_out = post_gap_model.predict(dummy_input, verbose=0)
print(f"\nPre-GAP feature map: {pre_gap_out.shape}")
print(f" = {pre_gap_out.shape[1]} spatial positions x {pre_gap_out.shape[2]} channels")
print(f"Post-GAP vector: {post_gap_out.shape}")
print(f" = 1 vector of {post_gap_out.shape[1]} values (spatial info destroyed)")
manual_gap = pre_gap_out.mean(axis=1)
gap_error = np.abs(manual_gap - post_gap_out).max()
print(f"\nVerification: max|manual_avg - GAP_output| = {gap_error:.2e}")
# ============================================================
# Analysis 2: Pre-GAP spatial activation analysis
# ============================================================
print("\n" + "="*80)
print("ANALYSIS 2: Pre-GAP Spatial Activation Analysis")
print("="*80)
spatial_variance = np.var(pre_gap_out, axis=1) # (32, 512)
print(f"Mean spatial variance per channel: {spatial_variance.mean():.4f}")
print(f"Max spatial variance per channel: {spatial_variance.max():.4f}")
# Quantify how much activation energy the spatial averaging discards: compare the
# squared magnitude of the GAP output (the spatial mean) against the total
# squared activation summed over all positions and channels.
total_energy = np.sum(pre_gap_out**2, axis=(1, 2)).mean()  # sum over positions and channels, mean over batch
gap_energy = np.sum(post_gap_out**2, axis=1).mean()
print(f"\nTotal pre-GAP energy (mean squared L2): {total_energy:.4f}")
print(f"Post-GAP energy (mean squared L2): {gap_energy:.4f}")
print(f"Energy ratio retained: {gap_energy/total_energy:.4f} ({100*gap_energy/total_energy:.1f}%)")
print(f"Energy lost by averaging: {100*(1-gap_energy/total_energy):.1f}%")
# Spatial entropy: how evenly each channel's activation mass spreads across positions
mean_abs_acts = np.abs(pre_gap_out.mean(axis=0))  # (positions, channels) = (1008, 512)
entropies = []
for ch in range(mean_abs_acts.shape[1]):
    acts = mean_abs_acts[:, ch]
    total = acts.sum()
    if total > 0:
        probs = acts / total
        entropy = -np.sum(probs * np.log(probs + 1e-10))
        max_entropy = np.log(len(probs))  # entropy of a uniform spatial distribution
        entropies.append(entropy / max_entropy)
entropies = np.array(entropies)
print(f"\nNormalized spatial entropy: {entropies.mean():.4f}")
print(f" (1.0 = uniform spread, 0.0 = concentrated)")
print(f" High entropy confirms activations spread across all {pre_gap_out.shape[1]} positions")
# ============================================================
# Analysis 3: Weight-based proof of identical representations
# ============================================================
print("\n" + "="*80)
print("ANALYSIS 3: Byte Head Weight Analysis")
print("="*80)
byte_layer_names = sorted(
    [l.name for l in hps.layers
     if l.name.startswith('byte_') and l.name.count('_') == 1],
    key=lambda x: int(x.replace('byte_', ''))
)
byte_weights = {}
for name in byte_layer_names:
    layer = hps.get_layer(name)
    W, b = layer.get_weights()  # two arrays per head: kernel and bias (Dense-style output layer)
    byte_weights[name] = {'W': W, 'b': b}
# Cosine similarity between byte head weight matrices
n = 16
weight_cos_sim = np.zeros((n, n))
for i in range(n):
    Wi = byte_weights[f'byte_{i}']['W'].flatten()
    for j in range(n):
        Wj = byte_weights[f'byte_{j}']['W'].flatten()
        weight_cos_sim[i, j] = np.dot(Wi, Wj) / (norm(Wi) * norm(Wj) + 1e-10)
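# Equivalent vectorized computation (a sketch; assumes every head's kernel has the
# same shape so the flattened vectors can be stacked):
#   Ws = np.stack([byte_weights[f'byte_{i}']['W'].flatten() for i in range(n)])
#   norms = norm(Ws, axis=1)
#   weight_cos_sim = (Ws @ Ws.T) / (np.outer(norms, norms) + 1e-10)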
upper_tri = weight_cos_sim[np.triu_indices(n, k=1)]
print(f"\nWeight cosine similarity across all 120 byte pairs:")
print(f" Mean: {upper_tri.mean():.6f}")
print(f" Min: {upper_tri.min():.6f}")
print(f" Max: {upper_tri.max():.6f}")
print(f" Std: {upper_tri.std():.6f}")
# Full 16x16 similarity matrix (labels aligned with the 6-char data cells)
print(f"{'':>7s}", end='')
for j in range(n):
    print(f"{f'b{j}':>6s}", end='')
print()
for i in range(n):
    print(f"{f'byte_{i}':>7s}", end='')
    for j in range(n):
        print(f" {weight_cos_sim[i, j]:5.2f}", end='')
    print()
# Byte 0,1 vs others
print(f"\n--- Succeeded (bytes 0,1) vs Failed (bytes 2-15) ---")
b01_sim = weight_cos_sim[0, 1]
b0_vs_failed = [weight_cos_sim[0, j] for j in range(2, 16)]
b1_vs_failed = [weight_cos_sim[1, j] for j in range(2, 16)]
failed_pairs = [weight_cos_sim[i, j] for i in range(2, 16) for j in range(i+1, 16)]
print(f"Byte 0 vs Byte 1 (both succeeded): {b01_sim:.4f}")
print(f"Byte 0 vs failed bytes (2-15): mean={np.mean(b0_vs_failed):.4f}")
print(f"Byte 1 vs failed bytes (2-15): mean={np.mean(b1_vs_failed):.4f}")
print(f"Failed vs failed (2-15 pairs): mean={np.mean(failed_pairs):.4f}")
# Weight magnitude
print(f"\n--- Weight magnitude per byte head ---")
print(f"{'Byte':>8s} {'W_norm':>10s} {'W_mean':>10s} {'W_std':>10s} {'b_mean':>10s} {'Rank':>6s}")
for i, name in enumerate(byte_layer_names):
    W = byte_weights[name]['W']
    b = byte_weights[name]['b']
    rank_str = "0" if i < 2 else "FAIL"  # bytes 0-1 were recovered (rank 0); bytes 2-15 failed
    print(f"{name:>8s} {norm(W):10.4f} {W.mean():10.6f} {W.std():10.6f} {b.mean():10.6f} {rank_str:>6s}")
# ============================================================
# Analysis 4: Output predictions - are failed bytes near-uniform?
# ============================================================
print("\n" + "="*80)
print("ANALYSIS 4: Output Prediction Distributions")
print("="*80)
all_preds = hps.predict(dummy_input, verbose=0)
# all_preds is a list of 16 arrays, each (32, 256)
print(f"\nPer-byte prediction statistics (batch of {dummy_input.shape[0]}):")
print(f"{'Byte':>8s} {'Max_prob':>10s} {'Entropy':>10s} {'Max_ent':>10s} {'Ent_ratio':>10s} {'Rank':>6s}")
max_entropy = np.log(256)
for i in range(16):
    preds = all_preds[i]  # (32, 256) softmax distribution per sample
    max_prob = preds.max(axis=1).mean()
    # Mean Shannon entropy of the predicted distributions
    ent = -np.sum(preds * np.log(preds + 1e-10), axis=1).mean()
    rank_str = "0" if i < 2 else "FAIL"
    print(f"{f'byte_{i}':>8s} {max_prob:10.6f} {ent:10.4f} {max_entropy:10.4f} {ent/max_entropy:10.4f} {rank_str:>6s}")
# ============================================================
# Save all results
# ============================================================
results = {
    'pre_gap_shape': [int(x) for x in pre_gap_out.shape],
    'post_gap_shape': [int(x) for x in post_gap_out.shape],
    'gap_verification_error': float(gap_error),
    'spatial_variance_mean': float(spatial_variance.mean()),
    'spatial_entropy_mean': float(entropies.mean()),
    'energy_retained_pct': float(100 * gap_energy / total_energy),
    'weight_cosine_similarity': {
        'mean': float(upper_tri.mean()),
        'min': float(upper_tri.min()),
        'max': float(upper_tri.max()),
        'std': float(upper_tri.std()),
    },
    'byte0_vs_byte1_weight_sim': float(b01_sim),
    'byte0_vs_failed_mean': float(np.mean(b0_vs_failed)),
    'failed_vs_failed_mean': float(np.mean(failed_pairs)),
}
with open('/home/ubuntu/gap_analysis_results.json', 'w') as f:
    json.dump(results, f, indent=2)
print("\nResults saved to gap_analysis_results.json")
print("ANALYSIS COMPLETE")