"""
Gradient and Activation Analysis: HPS GAP Information Collapse - V4
Key fix: Use sum of logits (not max) and handle softmax saturation.
Also includes weight-based analysis.
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import numpy as np
import tensorflow as tf
import json
from numpy.linalg import norm
print("Loading HPS model...")
hps = tf.keras.models.load_model('/home/ubuntu/models/hps_model.h5', compile=False)
np.random.seed(42)
dummy_input = np.random.randn(32, 32272, 1).astype(np.float32)
pre_gap_layer = hps.get_layer('hps_drop4')
gap_layer = hps.get_layer('hps_gap')
# ============================================================
# Analysis 1: Structural proof - identical input to all heads
# ============================================================
print("="*80)
print("ANALYSIS 1: Structural Proof of Information Bottleneck")
print("="*80)
pre_gap_model = tf.keras.Model(inputs=hps.input, outputs=pre_gap_layer.output)
post_gap_model = tf.keras.Model(inputs=hps.input, outputs=gap_layer.output)
pre_gap_out = pre_gap_model.predict(dummy_input, verbose=0)
post_gap_out = post_gap_model.predict(dummy_input, verbose=0)
print(f"\nPre-GAP feature map: {pre_gap_out.shape}")
print(f" = {pre_gap_out.shape[1]} spatial positions x {pre_gap_out.shape[2]} channels")
print(f"Post-GAP vector: {post_gap_out.shape}")
print(f" = 1 vector of {post_gap_out.shape[1]} values (spatial info destroyed)")
manual_gap = pre_gap_out.mean(axis=1)
gap_error = np.abs(manual_gap - post_gap_out).max()
print(f"\nVerification: max|manual_avg - GAP_output| = {gap_error:.2e}")
# ============================================================
# Analysis 2: Pre-GAP spatial activation analysis
# ============================================================
print("\n" + "="*80)
print("ANALYSIS 2: Pre-GAP Spatial Activation Analysis")
print("="*80)
spatial_variance = np.var(pre_gap_out, axis=1) # (32, 512)
print(f"Mean spatial variance per channel: {spatial_variance.mean():.4f}")
print(f"Max spatial variance per channel: {spatial_variance.max():.4f}")
# Compute how much signal energy is lost by averaging.
# GAP keeps only the per-channel spatial mean, so the spatial variation of each
# channel is what gets discarded.
total_energy = np.sum(pre_gap_out**2, axis=(1, 2)).mean()  # per-trace energy, mean over batch
gap_energy = np.sum(post_gap_out**2, axis=1).mean()        # pooled-vector energy, mean over batch
print(f"\nTotal pre-GAP energy (mean squared L2 norm): {total_energy:.4f}")
print(f"Post-GAP energy (mean squared L2 norm): {gap_energy:.4f}")
print(f"Energy ratio retained: {gap_energy/total_energy:.4f} ({100*gap_energy/total_energy:.1f}%)")
print(f"Energy lost by averaging: {100*(1-gap_energy/total_energy):.1f}%")
# Spatial entropy
channel_means = np.abs(pre_gap_out.mean(axis=0)) # (1008, 512)
entropies = []
for ch in range(512):
    acts = channel_means[:, ch]
    total = acts.sum()
    if total > 0:
        probs = acts / total
        entropy = -np.sum(probs * np.log(probs + 1e-10))
        max_entropy = np.log(len(probs))
        entropies.append(entropy / max_entropy)
entropies = np.array(entropies)
print(f"\nNormalized spatial entropy: {entropies.mean():.4f}")
print(f" (1.0 = uniform spread, 0.0 = concentrated)")
print(f" High entropy confirms activations spread across all {pre_gap_out.shape[1]} positions")
# ============================================================
# Analysis 3: Weight-based proof of identical representations
# ============================================================
print("\n" + "="*80)
print("ANALYSIS 3: Byte Head Weight Analysis")
print("="*80)
byte_layer_names = sorted(
    [l.name for l in hps.layers
     if l.name.startswith('byte_') and l.name.count('_') == 1],
    key=lambda x: int(x.replace('byte_', ''))
)
byte_weights = {}
for name in byte_layer_names:
    layer = hps.get_layer(name)
    W, b = layer.get_weights()
    byte_weights[name] = {'W': W, 'b': b}
# Cosine similarity between byte head weight matrices
n = 16
weight_cos_sim = np.zeros((n, n))
for i in range(n):
    Wi = byte_weights[f'byte_{i}']['W'].flatten()
    for j in range(n):
        Wj = byte_weights[f'byte_{j}']['W'].flatten()
        weight_cos_sim[i, j] = np.dot(Wi, Wj) / (norm(Wi) * norm(Wj) + 1e-10)
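# Equivalent vectorized form (sketch): stack the flattened head weights, L2-normalize the
# rows, and take the Gram matrix; it should match the loop above up to float32 round-off.
W_stack = np.stack([byte_weights[f'byte_{i}']['W'].flatten() for i in range(n)])
W_unit = W_stack / (norm(W_stack, axis=1, keepdims=True) + 1e-10)
print(f"Max |vectorized - loop| cosine difference: {np.abs(W_unit @ W_unit.T - weight_cos_sim).max():.2e}")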
upper_tri = weight_cos_sim[np.triu_indices(n, k=1)]
print(f"\nWeight cosine similarity across all 120 byte pairs:")
print(f" Mean: {upper_tri.mean():.6f}")
print(f" Min: {upper_tri.min():.6f}")
print(f" Max: {upper_tri.max():.6f}")
print(f" Std: {upper_tri.std():.6f}")
# Full matrix
print(f"\n{'':>8s}", end='')
for j in range(16):
    print(f" b{j:>2d}", end='')
print()
for i in range(16):
    print(f"byte_{i:>2d}", end='')
    for j in range(16):
        print(f" {weight_cos_sim[i,j]:5.2f}", end='')
    print()
# Byte 0,1 vs others
print(f"\n--- Succeeded (bytes 0,1) vs Failed (bytes 2-15) ---")
b01_sim = weight_cos_sim[0, 1]
b0_vs_failed = [weight_cos_sim[0, j] for j in range(2, 16)]
b1_vs_failed = [weight_cos_sim[1, j] for j in range(2, 16)]
failed_pairs = [weight_cos_sim[i, j] for i in range(2, 16) for j in range(i+1, 16)]
print(f"Byte 0 vs Byte 1 (both succeeded): {b01_sim:.4f}")
print(f"Byte 0 vs failed bytes (2-15): mean={np.mean(b0_vs_failed):.4f}")
print(f"Byte 1 vs failed bytes (2-15): mean={np.mean(b1_vs_failed):.4f}")
print(f"Failed vs failed (2-15 pairs): mean={np.mean(failed_pairs):.4f}")
# Weight magnitude
print(f"\n--- Weight magnitude per byte head ---")
print(f"{'Byte':>8s} {'W_norm':>10s} {'W_mean':>10s} {'W_std':>10s} {'b_mean':>10s} {'Rank':>6s}")
for i, name in enumerate(byte_layer_names):
    W = byte_weights[name]['W']
    b = byte_weights[name]['b']
    rank_str = "0" if i < 2 else "FAIL"
    print(f"{name:>8s} {norm(W):10.4f} {W.mean():10.6f} {W.std():10.6f} {b.mean():10.6f} {rank_str:>6s}")
# ============================================================
# Analysis 4: Output predictions - are failed bytes near-uniform?
# ============================================================
print("\n" + "="*80)
print("ANALYSIS 4: Output Prediction Distributions")
print("="*80)
all_preds = hps.predict(dummy_input, verbose=0)
# all_preds is a list of 16 arrays, each (32, 256)
print(f"\nPer-byte prediction statistics (batch of {dummy_input.shape[0]}):")
print(f"{'Byte':>8s} {'Max_prob':>10s} {'Entropy':>10s} {'Max_ent':>10s} {'Ent_ratio':>10s} {'Rank':>6s}")
max_entropy = np.log(256)
for i in range(16):
    preds = all_preds[i]  # (32, 256)
    max_probs = preds.max(axis=1).mean()
    # Compute entropy
    ent = -np.sum(preds * np.log(preds + 1e-10), axis=1).mean()
    rank_str = "0" if i < 2 else "FAIL"
    print(f"byte_{i:>2d} {max_probs:10.6f} {ent:10.4f} {max_entropy:10.4f} {ent/max_entropy:10.4f} {rank_str:>6s}")
# ============================================================
# Save all results
# ============================================================
results = {
    'pre_gap_shape': [int(x) for x in pre_gap_out.shape],
    'post_gap_shape': [int(x) for x in post_gap_out.shape],
    'gap_verification_error': float(gap_error),
    'spatial_variance_mean': float(spatial_variance.mean()),
    'spatial_entropy_mean': float(entropies.mean()),
    'energy_retained_pct': float(100*gap_energy/total_energy),
    'weight_cosine_similarity': {
        'mean': float(upper_tri.mean()),
        'min': float(upper_tri.min()),
        'max': float(upper_tri.max()),
        'std': float(upper_tri.std()),
    },
    'byte0_vs_byte1_weight_sim': float(b01_sim),
    'byte0_vs_failed_mean': float(np.mean(b0_vs_failed)),
    'failed_vs_failed_mean': float(np.mean(failed_pairs)),
}
with open('/home/ubuntu/gap_analysis_results.json', 'w') as f:
    json.dump(results, f, indent=2)
print("\nResults saved to gap_analysis_results.json")
print("ANALYSIS COMPLETE")