"""Deep diagnostic tests for understanding a trained binary LM.
Beyond the first-pass analysis:
F. Per-head attention pattern classification (recent / first-token / content / long)
G. Position-wise BPC — how does BPC depend on position in the sequence?
H. Context-length sweep — BPC as a function of how much context we give
I. Layer-wise hidden-state agreement (element-wise match rate, not true CKA) — which layers carry redundant information?
J. Logit margin distribution — how confident is the model on right vs wrong?
K. Per-head knockout — which heads are load-bearing?
L. Bit-flip robustness — how much does BPC degrade when a random fraction of
   weight signs is flipped? (a rough proxy for effective parameter count)
M. Character embedding clustering — do similar chars cluster in ±1 space?
"""
import argparse, json, math, os, time
import numpy as np
import torch
import torch.nn.functional as F
from model_v18 import BitLMv18
from model_fp32 import FP32LM
from model_v16 import set_gumbel_tau
def load_binary(path, device='cuda'):
    """Rebuild a BitLMv18 from a checkpoint file and switch it to eval mode.

    Args:
        path: checkpoint file; it must hold an 'args' hyperparameter dict
            (vocab_size, d_model, n_layers, n_heads, d_ff, seq_len) and a
            'model' state_dict.
        device: used both as ``map_location`` and as the model's device.

    Returns:
        (model, checkpoint): the eval-mode model and the raw checkpoint dict.
    """
    # weights_only=False un-pickles arbitrary Python objects (presumably needed
    # for the non-tensor 'args' entry); only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hp = checkpoint['args']
    model = BitLMv18(
        vocab_size=hp['vocab_size'],
        d_model=hp['d_model'],
        n_layers=hp['n_layers'],
        n_heads=hp['n_heads'],
        d_ff=hp['d_ff'],
        max_seq_len=hp['seq_len'],
    )
    model = model.to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model, checkpoint
def sample_batch(data, batch_size, seq_len, device='cuda'):
    """Draw ``batch_size`` random (input, next-token target) windows from ``data``.

    Args:
        data: 1-D token array (e.g. the uint8 ``np.memmap`` opened in ``main``).
        batch_size: number of windows to draw.
        seq_len: tokens per window.
        device: device for the returned tensors.

    Returns:
        (x, y): int64 tensors of shape (batch_size, seq_len) with
        ``x[b, t] == data[s_b + t]`` and ``y[b, t] == data[s_b + t + 1]`` for a
        random start ``s_b``.

    Raises:
        ValueError: if ``data`` holds fewer than ``seq_len + 1`` tokens.
    """
    # Valid starts are 0 .. len(data) - seq_len - 1 (y needs one token past x);
    # randint's upper bound is exclusive, so it is len(data) - seq_len.
    n_starts = len(data) - seq_len
    if n_starts < 1:
        raise ValueError(
            f'need at least seq_len + 1 = {seq_len + 1} tokens, got {len(data)}')
    ix = torch.randint(0, n_starts, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i + seq_len].astype(np.int64))
                     for i in ix]).to(device)
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + seq_len].astype(np.int64))
                     for i in ix]).to(device)
    return x, y
# ---------------- F: Attention head pattern ----------------
def _head_argmax_keys(attn, xe):
    """Return the most-attended key index for every (batch, head, query): shape (B, H, T).

    Rebuilds causal, ALiBi-biased attention scores from ``attn.q_proj`` and
    ``attn.k_proj`` applied to ``xe``; future keys are masked out.
    NOTE(review): assumes ``xe`` is exactly what the block feeds to ``attn``
    (no norm/binarization in between) and that the real attention subtracts
    ALiBi from unscaled Q·K (no 1/sqrt(head_dim)). Either mismatch can shift
    the argmax — confirm against the model definition.
    """
    B, T, _ = xe.shape
    H, Dh = attn.n_heads, attn.head_dim
    q = attn.q_proj(xe).view(B, T, H, Dh).transpose(1, 2)  # (B, H, T, Dh)
    k = attn.k_proj(xe).view(B, T, H, Dh).transpose(1, 2)
    scores = torch.matmul(q, k.transpose(-2, -1))  # (B, H, T, T)
    pos = torch.arange(T, device=xe.device).float()
    dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()
    scores = scores - attn.alibi_slopes_int.view(1, H, 1, 1).float() * dist.view(1, 1, T, T)
    future = torch.triu(torch.ones(T, T, device=xe.device, dtype=torch.bool), diagonal=1)
    return scores.masked_fill(future, -1e9).argmax(dim=-1)


def _classify_head(first_frac, mean_dist, content_var, seq_len):
    """Map one head's statistics to a pattern label; the first matching rule wins."""
    if first_frac > 0.5:
        return 'first-token-sink'
    if mean_dist < 3:
        return 'recent'
    if mean_dist > seq_len / 4:
        return 'long-range'
    if content_var > 5:
        return 'content-sensitive'
    return 'positional'


@torch.no_grad()
def head_attention_patterns(m, val, n_batches=5, bs=8, seq_len=256, device='cuda'):
    """Classify each (layer, head) by where its attention argmax lands.

    For every query, the most-attended key (argmax of the causal, ALiBi-biased
    scores) is found. Per head, three statistics are computed:

      mean_dist       mean |query - argmax key| over all queries
      first_tok_frac  fraction of queries whose argmax key is token 0
                      (query 0 always counts, since it can only see itself)
      content_var     std, across the batch, of the argmax key chosen by the
                      middle query (T // 2); 0.0 when the batch has one row

    Each statistic is averaged over ``n_batches`` batches. Each batch is
    pushed through the blocks once, so every layer is measured on the same
    tokens.

    Classification (first match wins):
      first-token-sink   first_tok_frac > 0.5
      recent             mean_dist < 3
      long-range         mean_dist > seq_len / 4
      content-sensitive  content_var > 5
      positional         otherwise

    Returns:
        List of dicts in layer-major, head-minor order. Each dict has the
        keys layer, head, mean_dist, first_tok_frac, content_var, kind and
        alibi_slope.
    """
    blocks = m.blocks
    n_layers = len(blocks)
    # *_hist[layer][head] -> list of per-batch values.
    dist_hist = [[[] for _ in range(blk.attn.n_heads)] for blk in blocks]
    first_hist = [[[] for _ in range(blk.attn.n_heads)] for blk in blocks]
    spread_hist = [[[] for _ in range(blk.attn.n_heads)] for blk in blocks]

    for _ in range(n_batches):
        x, _ = sample_batch(val, bs, seq_len, device)
        xe = m.embed(x)
        for li, blk in enumerate(blocks):
            ak = _head_argmax_keys(blk.attn, xe)  # (B, H, T)
            B, H, T = ak.shape
            query_pos = torch.arange(T, device=ak.device).view(1, 1, T)
            per_head_dist = (query_pos - ak).abs().float().mean(dim=(0, 2)).tolist()
            per_head_first = (ak == 0).float().mean(dim=(0, 2)).tolist()
            mid = ak[:, :, T // 2].float()  # (B, H): argmax key of the middle query
            # The std of a single sample is NaN; report zero spread instead.
            per_head_spread = mid.std(dim=0).tolist() if B > 1 else [0.0] * H
            for h in range(H):
                dist_hist[li][h].append(per_head_dist[h])
                first_hist[li][h].append(per_head_first[h])
                spread_hist[li][h].append(per_head_spread[h])
            if li + 1 < n_layers:
                xe = blk(xe)  # becomes the input of the next block

    results = []
    for li, blk in enumerate(blocks):
        slopes = blk.attn.alibi_slopes_int
        for h in range(blk.attn.n_heads):
            mean_dist = float(np.mean(dist_hist[li][h]))
            first_frac = float(np.mean(first_hist[li][h]))
            content_var = float(np.mean(spread_hist[li][h]))
            results.append({'layer': li, 'head': h,
                            'mean_dist': mean_dist,
                            'first_tok_frac': first_frac,
                            'content_var': content_var,
                            'kind': _classify_head(first_frac, mean_dist,
                                                   content_var, seq_len),
                            'alibi_slope': int(slopes[h].item())})
    return results
# ---------------- G: Position-wise BPC ----------------
@torch.no_grad()
def position_bpc(m, val, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Mean bits-per-character at each position of a ``seq_len`` window.

    Draws ``n_batches`` batches of random windows, scores every next-token
    prediction, and averages the loss per position across all windows.

    Returns:
        A dict with two keys:
          'bpc_per_position'     list of ``seq_len`` floats
          'bpc_quartile_starts'  mean BPC over each quarter of the window,
                                 first quarter to last
    """
    nats_total = torch.zeros(seq_len, device=device)
    n_seen = torch.zeros(seq_len, device=device)
    for _ in range(n_batches):
        inputs, targets = sample_batch(val, bs, seq_len, device)
        logits, _ = m(inputs, targets)
        # cross_entropy expects classes on dim 1: (B, T, V) -> (B, V, T).
        token_nats = F.cross_entropy(logits.permute(0, 2, 1), targets,
                                     reduction='none')  # (B, T)
        nats_total += token_nats.sum(dim=0)
        n_seen += token_nats.shape[0]
    bpc = (nats_total / n_seen).cpu().numpy() / math.log(2)
    q1, q2, q3 = seq_len // 4, seq_len // 2, 3 * seq_len // 4
    quarters = (bpc[:q1], bpc[q1:q2], bpc[q2:q3], bpc[q3:])
    return {'bpc_per_position': bpc.tolist(),
            'bpc_quartile_starts': [float(chunk.mean()) for chunk in quarters]}
# ---------------- H: Context-length sweep ----------------
@torch.no_grad()
def context_length_sweep(m, val, n_batches=20, bs=32, seq_len=256, device='cuda',
                         ctx_lens=(1, 4, 16, 64, 128, 256)):
    """BPC of a single next-token prediction given exactly ``cl`` tokens of context.

    For each ``cl`` in ``ctx_lens``, the model sees only the first ``cl`` tokens
    of each sampled window. It is scored on predicting the following token,
    i.e. ``y[:, cl - 1]``. Lengths above ``seq_len`` are skipped because the
    sampled windows are only ``seq_len`` tokens long.

    Args:
        ctx_lens: context lengths to evaluate, each >= 1. The default
            reproduces the original fixed sweep.

    Returns:
        List of {'context_len', 'bpc_last_position'} dicts, in ``ctx_lens`` order.

    Raises:
        ValueError: if any context length is < 1.
    """
    if any(cl < 1 for cl in ctx_lens):
        raise ValueError(f'context lengths must be >= 1, got {list(ctx_lens)}')
    results = []
    for cl in ctx_lens:
        if cl > seq_len:
            continue
        losses = []
        for _ in range(n_batches):
            x, y = sample_batch(val, bs, seq_len, device)
            logits, _ = m(x[:, :cl])
            # Score only the final position: token cl predicted from cl tokens.
            loss = F.cross_entropy(logits[:, -1, :], y[:, cl - 1])
            losses.append(loss.item())
        results.append({'context_len': cl,
                        'bpc_last_position': float(np.mean(losses)) / math.log(2)})
    return results
# ---------------- I: Layer-wise hidden-state agreement ----------------
@torch.no_grad()
def layer_similarity(m, val, n_batches=5, bs=16, seq_len=256, device='cuda'):
    """Element-wise agreement between the outputs of every pair of blocks.

    This is NOT Centered Kernel Alignment. Entry (i, j) is the fraction of
    hidden-state elements that are exactly equal at the output of block i
    and block j, over the same sampled tokens. That is a meaningful overlap
    score when block outputs are ±1 codes. NOTE(review): confirm for this
    model; for real-valued outputs almost every off-diagonal entry would be ~0.
    A high value means the two layers emit near-identical codes (redundancy).

    Returns:
        {'similarity_matrix': n_layers x n_layers nested list}, symmetric.
    """
    n_layers = len(m.blocks)
    per_layer = [[] for _ in range(n_layers)]
    for _ in range(n_batches):
        x, _ = sample_batch(val, bs, seq_len, device)
        h = m.embed(x)
        for li, blk in enumerate(m.blocks):
            h = blk(h)
            per_layer[li].append(h.reshape(-1, h.shape[-1]).float().cpu())
    # Concatenate each layer once: (n_tokens, d_model) per layer.
    states = [torch.cat(chunks, dim=0) for chunks in per_layer]
    agree = np.zeros((n_layers, n_layers))
    for i in range(n_layers):
        for j in range(i, n_layers):
            # Equality is symmetric, so one computation fills both triangles.
            agree[i, j] = agree[j, i] = (states[i] == states[j]).float().mean().item()
    return {'similarity_matrix': agree.tolist()}
# ---------------- J: Logit margin distribution ----------------
@torch.no_grad()
def logit_margin_distribution(m, val, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Distribution of the top-1 minus top-2 logit margin, split by correctness.

    A prediction is "correct" when its argmax logit equals the next-token
    target. Margins are in raw logit units.

    Returns:
        A dict with these keys:
          correct_count, wrong_count    number of scored tokens per group
          correct_margin_mean/_median   margin stats over correct predictions
          wrong_margin_mean/_median     margin stats over wrong predictions
          wrong_frac_correct_in_top2    among wrong predictions, the fraction
                                        whose target is the runner-up logit
        A mean/median is None when its group is empty. NaN would be written
        by ``json.dump`` as the non-standard token ``NaN``.
    """
    def summarize(arr, reducer):
        return float(reducer(arr)) if arr.size else None

    correct_chunks, wrong_chunks = [], []
    n_wrong = 0
    n_wrong_target_second = 0
    for _ in range(n_batches):
        x, y = sample_batch(val, bs, seq_len, device)
        logits, _ = m(x, y)
        # reshape (not view) so non-contiguous logits are also accepted.
        targets = y.reshape(-1)
        flat = logits.reshape(-1, logits.shape[-1])
        correct = flat.argmax(dim=-1) == targets
        wrong = ~correct
        top_vals, top_idx = torch.topk(flat, 2, dim=-1)
        margin = (top_vals[:, 0] - top_vals[:, 1]).cpu().numpy()
        correct_np = correct.cpu().numpy()
        correct_chunks.append(margin[correct_np])
        wrong_chunks.append(margin[~correct_np])
        n_wrong_target_second += (top_idx[wrong, 1] == targets[wrong]).sum().item()
        n_wrong += wrong.sum().item()
    correct_margins = np.concatenate(correct_chunks) if correct_chunks else np.empty(0)
    wrong_margins = np.concatenate(wrong_chunks) if wrong_chunks else np.empty(0)
    return {
        'correct_count': int(correct_margins.size),
        'wrong_count': int(wrong_margins.size),
        'correct_margin_mean': summarize(correct_margins, np.mean),
        'correct_margin_median': summarize(correct_margins, np.median),
        'wrong_margin_mean': summarize(wrong_margins, np.mean),
        'wrong_margin_median': summarize(wrong_margins, np.median),
        'wrong_frac_correct_in_top2': n_wrong_target_second / max(1, n_wrong),
    }
# ---------------- K: Per-head knockout ----------------
@torch.no_grad()
def per_head_knockout(m, val, n_batches=10, bs=32, seq_len=256, device='cuda'):
    """Zero each attention head's output slice in turn and measure the BPC change.

    The baseline and every knockout are scored on the SAME ``n_batches``
    batches. This way ``delta_bpc`` reflects the knockout and not sampling
    noise between independently drawn batches.

    A knockout is a forward hook on ``blk.attn`` that zeroes output channels
    ``[h * head_dim, (h + 1) * head_dim)``. The hook is always removed, even
    if evaluation raises.
    NOTE(review): this isolates head ``h`` only if the attention module's
    output is the per-head concatenation, with no channel-mixing output
    projection. Confirm against the model definition. Zero is also outside
    the ±1 alphabet, so this is an analysis-only intervention. The hook only
    fires when the block calls ``attn(...)``, not ``attn.forward(...)``.

    Returns:
        {'baseline_bpc': float, 'per_head': [dict(layer, head, baseline_bpc,
        knockout_bpc, delta_bpc), ...]} in layer-major, head-minor order.
    """
    batches = [sample_batch(val, bs, seq_len, device) for _ in range(n_batches)]

    def eval_bpc():
        # Assumes m(x, y) returns (logits, mean cross-entropy in nats). TODO confirm.
        losses = []
        for x, y in batches:
            _, loss = m(x, y)
            losses.append(loss.item())
        return float(np.mean(losses)) / math.log(2)

    base_bpc = eval_bpc()
    results = []
    for li, blk in enumerate(m.blocks):
        attn = blk.attn
        head_dim = attn.head_dim
        for h_idx in range(attn.n_heads):
            start, end = h_idx * head_dim, (h_idx + 1) * head_dim

            def zero_head(module, inputs, output, start=start, end=end):
                # Returning a tensor from a forward hook replaces the output.
                output = output.clone()
                output[..., start:end] = 0
                return output

            handle = attn.register_forward_hook(zero_head)
            try:
                ko_bpc = eval_bpc()
            finally:
                handle.remove()
            results.append({'layer': li, 'head': h_idx,
                            'baseline_bpc': base_bpc,
                            'knockout_bpc': ko_bpc,
                            'delta_bpc': ko_bpc - base_bpc})
    return {'baseline_bpc': base_bpc, 'per_head': results}
# ---------------- L: Bit-flip robustness (random weight-sign flips) ----------------
@torch.no_grad()
def bit_flip_robustness(m, val, n_batches=10, bs=32, seq_len=256, device='cuda',
                        flip_fractions=(0.001, 0.01, 0.05, 0.10)):
    """Measure how much BPC degrades when a random fraction of weight signs is flipped.

    For each fraction ``p`` in ``flip_fractions``, every parameter with
    ``dim() >= 2`` has each element's sign flipped independently with
    probability ``p``. 1-D parameters are left untouched. BPC is then
    re-measured and the original weights are restored, even if evaluation
    raises, so later analyses see the trained model. The baseline and every
    flipped condition are scored on the SAME batches, so ``delta_bpc`` is not
    dominated by sampling noise.

    Args:
        flip_fractions: per-element flip probabilities to sweep. The default
            reproduces the original fixed sweep.

    Returns:
        {'baseline_bpc': float, 'flip_sweep': [dict(flip_fraction,
        bpc_after_flip, delta_bpc), ...]} in ``flip_fractions`` order.
    """
    batches = [sample_batch(val, bs, seq_len, device) for _ in range(n_batches)]

    def eval_bpc():
        # Assumes m(x, y) returns (logits, mean cross-entropy in nats). TODO confirm.
        losses = []
        for x, y in batches:
            _, loss = m(x, y)
            losses.append(loss.item())
        return float(np.mean(losses)) / math.log(2)

    base_bpc = eval_bpc()
    params = [p for p in m.parameters() if p.dim() >= 2]
    # Snapshot once; every condition starts from, and is restored to, these values.
    originals = [p.detach().clone() for p in params]
    results = []
    for p_flip in flip_fractions:
        try:
            for p in params:
                flip = torch.rand_like(p) < p_flip
                p.copy_(torch.where(flip, p.neg(), p))
            flip_bpc = eval_bpc()
        finally:
            for p, orig in zip(params, originals):
                p.copy_(orig)
        results.append({'flip_fraction': p_flip,
                        'bpc_after_flip': flip_bpc,
                        'delta_bpc': flip_bpc - base_bpc})
    return {'baseline_bpc': base_bpc, 'flip_sweep': results}
# ---------------- M: Character embedding clustering ----------------
@torch.no_grad()
def char_embedding_geometry(m):
    """Compare the binarized character embeddings with one another.

    Each row of ``m.embed.weight`` is reduced to a ±1 code with ``sign``; exact
    zeros count as +1. The similarity of rows a and b is
    ``(code_a · code_b) / D``, which equals ``1 - 2 * hamming(a, b) / D`` and
    lies in [-1, 1].

    Returns:
        A dict with three keys:
          'mean_abs_similarity'      mean |similarity| over all off-diagonal pairs
          'max_similarity_off_diag'  largest off-diagonal similarity
          'neighbors_sample'         for each probe character whose code point
                                     is < vocab size, the top-6 (char,
                                     similarity) pairs, usually the probe
                                     itself first and then its 5 nearest codes
        Both scalar entries are None when the vocabulary has fewer than two rows.
    """
    codes = torch.sign(m.embed.weight)  # (V, D)
    codes[codes == 0] = 1
    V, D = codes.shape
    sim = (codes @ codes.t()) / D  # value in [-1, 1]
    off_diag = sim.cpu().numpy()[~np.eye(V, dtype=bool)]
    # Take |sim| as the key name says. With signed values, anti-correlated
    # pairs would cancel correlated ones and hide the clustering.
    mean_abs = float(np.abs(off_diag).mean()) if off_diag.size else None
    max_off = float(off_diag.max()) if off_diag.size else None

    def label(c):
        return chr(c) if 32 <= c < 127 else f'<{c}>'

    neighbors = {}
    for c in (ord(ch) for ch in 'aetoiAEnbz .,?!0'):
        if c < V:
            vals, idx = torch.topk(sim[c], 6)  # the probe itself + 5 neighbors
            neighbors[repr(chr(c))] = [(label(int(i)), float(v))
                                       for v, i in zip(vals.tolist(), idx.tolist())]
    return {
        'mean_abs_similarity': mean_abs,
        'max_similarity_off_diag': max_off,
        'neighbors_sample': neighbors,
    }
# ---------------- Main ----------------
def main():
    """Command-line entry point: run diagnostics F–M on one checkpoint and write JSON.

    The report holds the checkpoint path, its saved 'args' config, its
    'val_bpc' (if present), a timestamp, and one entry per diagnostic.
    """
    ap = argparse.ArgumentParser(description='Deep diagnostics for a trained binary LM.')
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--data', default='/root/bitnet1/data/validation.bin')
    ap.add_argument('--out', required=True)
    ap.add_argument('--tau', type=float, default=0.1,
                    help='value passed to set_gumbel_tau before loading the model')
    ap.add_argument('--device', default='cuda',
                    help='torch device for the model and sampled batches (default: cuda)')
    ap.add_argument('--seed', type=int, default=None,
                    help='seed torch RNGs for reproducible batch sampling and bit flips')
    args = ap.parse_args()
    if args.seed is not None:
        torch.manual_seed(args.seed)
    set_gumbel_tau(args.tau)
    # Raw byte stream: one uint8 token per byte.
    val = np.memmap(args.data, dtype=np.uint8, mode='r')
    dev = args.device
    m, ck = load_binary(args.ckpt, device=dev)
    cfg = ck['args']
    out = {
        'ckpt': args.ckpt, 'config': cfg, 'val_bpc': ck.get('val_bpc'),
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    }
    print("F. Attention head patterns...")
    out['head_patterns'] = head_attention_patterns(m, val, device=dev)
    print(f" {len(out['head_patterns'])} heads classified")
    print("G. Position-wise BPC...")
    out['position_bpc'] = position_bpc(m, val, device=dev)
    print(f" quartiles: {out['position_bpc']['bpc_quartile_starts']}")
    print("H. Context-length sweep...")
    out['context_sweep'] = context_length_sweep(m, val, device=dev)
    print("I. Layer similarity matrix...")
    out['layer_similarity'] = layer_similarity(m, val, device=dev)
    print("J. Logit margin distribution...")
    out['logit_margins'] = logit_margin_distribution(m, val, device=dev)
    print("K. Per-head knockout...")
    out['head_knockout'] = per_head_knockout(m, val, device=dev)
    print("L. Bit-flip robustness...")
    out['bit_flip'] = bit_flip_robustness(m, val, device=dev)
    print("M. Character embedding geometry...")
    out['char_geometry'] = char_embedding_geometry(m)
    with open(args.out, 'w') as f:
        json.dump(out, f, indent=2, default=str)
    print(f"Wrote {args.out}")


if __name__ == '__main__':
    main()