# ascad-v1-models/analysis/scripts/structural_bias_analysis.py
# lemousehunter's picture
# Upload analysis/scripts/structural_bias_analysis.py with huggingface_hub
# b414070 verified
"""
Analyze WHY bytes 0 and 1 dominate the shared representation.
Look at POI positions relative to the global window, CNN receptive field,
and how GAP interacts with spatial features.
"""
import numpy as np
# Values copied from constants.py.

# Bounds of the global window, in samples.
GLOBAL_WINDOW_START = 18000
GLOBAL_WINDOW_END = 50272
GLOBAL_WINDOW_SIZE = GLOBAL_WINDOW_END - GLOBAL_WINDOW_START  # 32,272

# Byte index -> (start, end) of that byte's POI window.
# Every window below is 700 samples wide.
BYTE_POI_WINDOWS = {
    0: (30838, 31538),
    1: (24525, 25225),
    2: (45400, 46100),
    3: (32824, 33524),
    4: (47508, 48208),
    5: (41258, 41958),
    6: (37094, 37794),
    7: (35018, 35718),
    8: (26631, 27331),
    9: (39145, 39845),
    10: (28766, 29466),
    11: (43333, 44033),
    12: (20418, 21118),
    13: (22499, 23199),
    14: (49571, 50271),
    15: (18363, 19063),
}

# Byte index -> peak SNR.
BYTE_PEAK_SNR = {
    0: 1.72, 1: 1.44, 2: 2.99, 3: 1.82,
    4: 2.24, 5: 2.22, 6: 2.10, 7: 1.95,
    8: 1.28, 9: 1.68, 10: 1.15, 11: 2.48,
    12: 0.78, 13: 0.94, 14: 2.55, 15: 0.83,
}

# Byte index -> HPS final rank.
HPS_RANKS = {
    0: 0, 1: 0, 2: 2, 3: 66,
    4: 57, 5: 86, 6: 27, 7: 19,
    8: 41, 9: 91, 10: 129, 11: 14,
    12: 39, 13: 88, 14: 109, 15: 28,
}
# Per-byte table: where each byte's POI window sits inside the global window,
# together with its peak SNR and HPS final rank.
print("=" * 90)
print("BYTE POI POSITION ANALYSIS")
print("=" * 90)
# Header and row cells use the same widths (4/10/10/12/5/8/8) so that the
# " | " separators line up column by column.
print(f"{'Byte':>4} | {'POI Start':>10} | {'POI Center':>10} | {'Rel Position':>12} | {'SNR':>5} | {'HPS Rank':>8} | {'Learned?':>8}")
print("-" * 90)
for i in sorted(BYTE_POI_WINDOWS):
    start, end = BYTE_POI_WINDOWS[i]
    center = (start + end) / 2
    # Position relative to global window (0.0 = start, 1.0 = end)
    rel_pos = (center - GLOBAL_WINDOW_START) / GLOBAL_WINDOW_SIZE
    rank = HPS_RANKS[i]
    snr = BYTE_PEAK_SNR[i]
    # rank <= 2 -> "YES", rank <= 30 -> "close", anything worse -> "NO"
    learned = "YES" if rank <= 2 else ("close" if rank <= 30 else "NO")
    print(f"{i:>4} | {start:>10} | {center:>10.0f} | {rel_pos:>12.3f} | {snr:>5.2f} | {rank:>8} | {learned:>8}")
# List the bytes in order of increasing POI window start, with a text bar
# showing how far into the global window each POI center lies.
print("\n" + "=" * 90)
print("BYTES SORTED BY POI POSITION (left to right in global window)")
print("=" * 90)
sorted_bytes = sorted(range(16), key=lambda i: BYTE_POI_WINDOWS[i][0])
for i in sorted_bytes:
    start, end = BYTE_POI_WINDOWS[i]
    center = (start + end) / 2
    # Position relative to global window (0.0 = start, 1.0 = end)
    rel_pos = (center - GLOBAL_WINDOW_START) / GLOBAL_WINDOW_SIZE
    rank = HPS_RANKS[i]
    # Bar length scales with rel_pos; 50 characters would be the window end.
    bar = "#" * int(rel_pos * 50)
    print(f" byte {i:>2} [{rel_pos:.3f}] rank={rank:>3} {bar}")
# CNN architecture analysis: follow the window length and the POI width
# through successive halvings, one per pooling layer.
print("\n" + "=" * 90)
print("CNN SPATIAL DIMENSION AFTER EACH POOLING LAYER")
print("=" * 90)
num_pools = 5  # single source for the loop bound and the summary line
dim = GLOBAL_WINDOW_SIZE  # 32,272
poi_width = 700  # width of every entry in BYTE_POI_WINDOWS
print(f"Input: {dim} samples, POI width: {poi_width} samples ({poi_width/dim*100:.1f}% of window)")
for layer in range(1, num_pools + 1):
    dim = dim // 2  # AvgPool(2)
    poi_width = poi_width // 2
    print(f"After pool {layer}: {dim} samples, POI width: ~{poi_width} samples ({poi_width/dim*100:.1f}% of window)")
# Report the values the loop actually ended on instead of re-deriving them
# from literals; repeated floor-halving equals one floor division by 2**n.
print(f"\nAfter {num_pools} pools: each POI region is ~{poi_width} samples out of {dim} total")
print("After GAP: ALL spatial information is collapsed to a single vector")
print("The model CANNOT distinguish which spatial region a feature came from after GAP")
# Correlation analysis: Pearson correlation of the HPS final rank against
# (a) relative POI position, (b) peak SNR, (c) distance of the POI center
# from the middle of the global window.
print("\n" + "=" * 90)
print("CORRELATION ANALYSIS")
print("=" * 90)
from scipy import stats

# Relative POI center per byte (0.0 = window start, 1.0 = window end).
positions = [
    ((BYTE_POI_WINDOWS[i][0] + BYTE_POI_WINDOWS[i][1]) / 2 - GLOBAL_WINDOW_START)
    / GLOBAL_WINDOW_SIZE
    for i in range(16)
]
ranks = [HPS_RANKS[i] for i in range(16)]
snrs = [BYTE_PEAK_SNR[i] for i in range(16)]
# Distance from center of global window
center_distances = [abs(p - 0.5) for p in positions]
r_pos, p_pos = stats.pearsonr(positions, ranks)
r_snr, p_snr = stats.pearsonr(snrs, ranks)
r_center, p_center = stats.pearsonr(center_distances, ranks)
print(f"Correlation (POI position vs rank): r={r_pos:.3f}, p={p_pos:.3f}")
print(f"Correlation (SNR vs rank): r={r_snr:.3f}, p={p_snr:.3f}")
print(f"Correlation (distance from center vs rank): r={r_center:.3f}, p={p_center:.3f}")
# Check if bytes 0,1 have anything special about their POI positions.
# Values come from the constants and from `positions`, so they cannot drift
# away from the tables used by the rest of the script.
print()
for i in (0, 1):
    poi_start, poi_end = BYTE_POI_WINDOWS[i]
    print(f"Byte {i} POI center: {(poi_start + poi_end) / 2:.0f}, relative position: {positions[i]:.3f}")
print(f"Window center: {GLOBAL_WINDOW_START + GLOBAL_WINDOW_SIZE / 2:.0f}, relative position: 0.500")
# Compare the order in which the POI windows occur with the byte indices,
# then list which bytes reached a low HPS final rank and where their POIs are.
banner = "=" * 90
print("\n" + banner)
print("BYTE POI ORDERING vs AES EXECUTION ORDER")
print(banner)
print("AES SubBytes processes bytes 0-15 sequentially.")
print("The POI positions should reflect this execution order.")
print()

# Byte indices ordered by the start of their POI window.
sorted_by_poi = sorted(range(16), key=lambda b: BYTE_POI_WINDOWS[b][0])
print(f"Execution order (by POI position): {sorted_by_poi}")
print("This is NOT sequential 0-15, suggesting the compiler/hardware reorders operations.")
print()

# Rank buckets: "learned" is rank <= 2, "close" is 2 < rank <= 30.
learned_bytes = [b for b in range(16) if HPS_RANKS[b] <= 2]
close_bytes = [b for b in range(16) if 2 < HPS_RANKS[b] <= 30]

print("Bytes that learned (rank <= 2):", learned_bytes)
print("Their POI positions:", [(b, BYTE_POI_WINDOWS[b]) for b in learned_bytes])
print()
print("Bytes that were close (rank <= 30):", close_bytes)
print("Their POI positions:", [(b, BYTE_POI_WINDOWS[b]) for b in close_bytes])