# bitnet-1bitllm / vm_backup / code / analyze_deep.py
# hidude562's picture
# 1bitllm code (checkpoints to follow)
# 4754707 verified
"""Deep diagnostic tests for understanding a trained binary LM.
Beyond the first-pass analysis:
F. Per-head attention pattern classification (first-token sink / recent / long-range / content-sensitive / positional)
G. Position-wise BPC — how does BPC depend on position in the sequence?
H. Context-length sweep — BPC as a function of how much context we give
I. Layer-wise similarity (exact element-wise agreement of block outputs) — which layers carry redundant information?
J. Logit margin distribution — how confident is the model on right vs wrong?
K. Per-head knockout — which heads are load-bearing?
L. Bit-flip robustness — how much does BPC degrade when a random fraction of weight signs is flipped?
   (a rough proxy for how many weights actually move the output)
M. Character embedding clustering — do similar chars cluster in ±1 space?
"""
import argparse, json, math, os, time
import numpy as np
import torch
import torch.nn.functional as F
from model_v18 import BitLMv18
from model_fp32 import FP32LM
from model_v16 import set_gumbel_tau
def load_binary(path, device='cuda'):
    """Rebuild a BitLMv18 from a training checkpoint and switch it to eval mode.

    The checkpoint dict must carry the training config under ``'args'``
    (vocab_size, d_model, n_layers, n_heads, d_ff, seq_len) and the weights
    under ``'model'``.

    Returns:
        (model, checkpoint) — the ready-to-use model and the raw checkpoint dict.
    """
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hp = checkpoint['args']
    model = BitLMv18(
        vocab_size=hp['vocab_size'],
        d_model=hp['d_model'],
        n_layers=hp['n_layers'],
        n_heads=hp['n_heads'],
        d_ff=hp['d_ff'],
        max_seq_len=hp['seq_len'],
    )
    model = model.to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model, checkpoint
def sample_batch(data, batch_size, seq_len, device='cuda'):
    """Draw ``batch_size`` random windows of ``seq_len`` tokens from ``data``.

    Args:
        data: 1-D indexable token array (e.g. a NumPy memmap).
        batch_size: number of windows to draw.
        seq_len: tokens per window.
        device: torch device the returned tensors are moved to.

    Returns:
        (x, y) — int64 tensors of shape (batch_size, seq_len); ``y`` holds the
        next-token targets, i.e. the same windows shifted right by one.
    """
    starts = torch.randint(0, len(data) - seq_len - 1, (batch_size,))

    def windows(shift):
        # One row per start offset, read `shift` tokens further into the stream.
        rows = [torch.from_numpy(data[s + shift:s + shift + seq_len].astype(np.int64))
                for s in starts]
        return torch.stack(rows).to(device)

    return windows(0), windows(1)
# ---------------- F: Attention head pattern ----------------
def _classify_attention_head(first_frac, mean_dist, content_var, seq_len):
    """Map one head's argmax-attention statistics to a coarse label (first matching rule wins)."""
    if first_frac > 0.5:
        return 'first-token-sink'
    if mean_dist < 3:
        return 'recent'
    if mean_dist > seq_len / 4:
        return 'long-range'
    if content_var > 5:
        return 'content-sensitive'
    return 'positional'


@torch.no_grad()
def head_attention_patterns(m, val, n_batches=5, bs=8, seq_len=256, device='cuda'):
    """Classify each (layer, head) by where its hard (argmax) attention lands.

    For every layer, ``n_batches`` fresh validation batches are pushed through
    the embedding and all preceding blocks.  This layer's Q/K projections are
    then recomputed by hand, the ALiBi distance penalty is subtracted, and a
    causal mask is applied.  Per head, averaged over batches, we record:

      mean_dist      mean |query_pos - argmax_key_pos| over all (batch, query)
      first_tok_frac fraction of queries whose argmax key is position 0
      content_var    std, across the batch, of the argmax key chosen by the
                     middle query position (seq_len // 2); a high value means
                     the choice depends on content, not position alone

    Labels (first match wins): first_tok_frac > 0.5 -> 'first-token-sink';
    mean_dist < 3 -> 'recent'; mean_dist > seq_len/4 -> 'long-range';
    content_var > 5 -> 'content-sensitive'; otherwise 'positional'.

    NOTE(review): the recomputation feeds the block input straight into
    attn.q_proj / attn.k_proj and uses raw Q·K^T minus slope·distance; this
    assumes the block applies no normalisation before attention and the model
    uses no score scaling — confirm against model_v18.

    Returns:
        list of dicts with keys layer, head, mean_dist, first_tok_frac,
        content_var, kind, alibi_slope.
    """
    T = seq_len
    # Position grids and the causal mask depend only on seq_len: build them once.
    pos = torch.arange(T, device=device)
    dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs().float()   # (T, T) |i - j|
    causal_mask = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
    query_pos = pos.view(1, 1, T)                                  # broadcasts over (B, H, T)
    mid = T // 2

    results = []
    for li, blk in enumerate(m.blocks):
        attn = blk.attn
        H, Dh = attn.n_heads, attn.head_dim
        slopes = attn.alibi_slopes_int.view(1, H, 1, 1).float()
        per_batch = []  # each entry: (3, H) array of [mean_dist, first_frac, mid_std]
        for _ in range(n_batches):
            x, _ = sample_batch(val, bs, seq_len, device)
            xe = m.embed(x)
            # Re-run the prefix so xe is the input this layer actually sees.
            for k in range(li):
                xe = m.blocks[k](xe)
            B = xe.shape[0]
            Q = attn.q_proj(xe).view(B, T, H, Dh).transpose(1, 2)
            K = attn.k_proj(xe).view(B, T, H, Dh).transpose(1, 2)
            scores = torch.matmul(Q, K.transpose(-2, -1)) - slopes * dist
            scores = scores.masked_fill(causal_mask, -1e9)
            argmax_keys = scores.argmax(dim=-1)  # (B, H, T)
            # Per-head statistics, vectorised over heads; reduce over batch and query dims.
            mean_dist = (query_pos - argmax_keys).abs().float().mean(dim=(0, 2))
            first_frac = (argmax_keys == 0).float().mean(dim=(0, 2))
            mid_std = argmax_keys[:, :, mid].float().std(dim=0)
            per_batch.append(torch.stack([mean_dist, first_frac, mid_std]).cpu().numpy())
        if per_batch:
            stats = np.asarray(per_batch, dtype=np.float64).mean(axis=0)  # (3, H)
        else:
            stats = np.full((3, H), np.nan)
        for h in range(H):
            mean_dist_h, first_frac_h, content_var_h = (float(v) for v in stats[:, h])
            results.append({'layer': li, 'head': h,
                            'mean_dist': mean_dist_h,
                            'first_tok_frac': first_frac_h,
                            'content_var': content_var_h,
                            'kind': _classify_attention_head(first_frac_h, mean_dist_h,
                                                             content_var_h, seq_len),
                            'alibi_slope': int(attn.alibi_slopes_int[h].item())})
    return results
# ---------------- G: Position-wise BPC ----------------
@torch.no_grad()
def position_bpc(m, val, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Mean bits-per-character at every position 0..seq_len-1 of a window.

    Returns:
        dict with 'bpc_per_position' (list of seq_len floats) and
        'bpc_quartile_starts' (mean BPC over each quarter of the window).
    """
    totals = torch.zeros(seq_len, device=device)
    counts = torch.zeros(seq_len, device=device)
    for _ in range(n_batches):
        x, y = sample_batch(val, bs, seq_len, device)
        # NOTE(review): targets are passed although only the logits are used;
        # presumably the model needs them to return logits for every position — confirm.
        logits, _ = m(x, y)
        # cross_entropy wants the class axis second: (B, T, V) -> (B, V, T).
        per_token = F.cross_entropy(logits.transpose(1, 2), y, reduction='none')  # (B, T)
        totals += per_token.sum(dim=0)
        counts += per_token.shape[0]
    bpc = (totals / counts).cpu().numpy() / math.log(2)
    quarter_bounds = [
        (0, seq_len // 4),
        (seq_len // 4, seq_len // 2),
        (seq_len // 2, 3 * seq_len // 4),
        (3 * seq_len // 4, seq_len),
    ]
    return {
        'bpc_per_position': bpc.tolist(),
        'bpc_quartile_starts': [float(bpc[lo:hi].mean()) for lo, hi in quarter_bounds],
    }
# ---------------- H: Context-length sweep ----------------
@torch.no_grad()
def context_length_sweep(m, val, n_batches=20, bs=32, seq_len=256, device='cuda',
                         ctx_lens=(1, 4, 16, 64, 128, 256)):
    """BPC of the next-token prediction made after exactly ``cl`` context tokens.

    For each ``cl`` in ``ctx_lens``, random validation windows are drawn, only
    their first ``cl`` tokens are fed to the model, and the loss of the
    prediction at the last fed position (target = token ``cl``) is averaged.

    Args:
        ctx_lens: context lengths to evaluate; values < 1 or > seq_len are
            skipped (no valid target inside the sampled window).

    Returns:
        list of {'context_len': cl, 'bpc_last_position': float}.
    """
    results = []
    for cl in ctx_lens:
        if cl < 1 or cl > seq_len:
            continue
        losses = []
        for _ in range(n_batches):
            x, y = sample_batch(val, bs, seq_len, device)
            logits, _ = m(x[:, :cl])
            # The target for context position cl-1 is y[:, cl-1] (== x[:, cl]).
            loss = F.cross_entropy(logits[:, -1, :], y[:, cl - 1])
            losses.append(loss.item())
        results.append({'context_len': cl,
                        'bpc_last_position': float(np.mean(losses)) / math.log(2)})
    return results
# ---------------- I: Layer-wise hidden-state agreement ----------------
@torch.no_grad()
def layer_similarity(m, val, n_batches=5, bs=16, seq_len=256, device='cuda'):
    """Exact element-wise agreement between the outputs of every pair of blocks.

    Despite the historical name this is not CKA: entry (i, j) is the fraction
    of (token, dimension) positions at which block i's output equals block j's
    output exactly, pooled over n_batches * bs * seq_len tokens.  High values
    suggest redundant layers.  The measure is only meaningful when hidden states
    take a small set of discrete values (the original notes assume ±1) —
    NOTE(review): confirm for BitLMv18.

    Returns:
        {'similarity_matrix': n_layers x n_layers nested list of floats}
    """
    n_layers = len(m.blocks)
    chunks = [[] for _ in range(n_layers)]
    for _ in range(n_batches):
        x, _ = sample_batch(val, bs, seq_len, device)
        xe = m.embed(x)
        for li, blk in enumerate(m.blocks):
            xe = blk(xe)
            # Flatten (batch, time) into rows and park on CPU to bound GPU memory.
            chunks[li].append(xe.reshape(-1, xe.shape[-1]).float().cpu())
    # Concatenate each layer exactly once (previously redone inside the O(L^2)
    # pair loop) and drop the per-batch pieces as we go to limit peak memory.
    hidden = []
    for li in range(n_layers):
        hidden.append(torch.cat(chunks[li], dim=0))
        chunks[li] = None
    agree = np.zeros((n_layers, n_layers))
    for i in range(n_layers):
        for j in range(i, n_layers):
            # Equality is symmetric, so compute each pair once and mirror it.
            frac = (hidden[i] == hidden[j]).sum().item() / hidden[i].numel()
            agree[i, j] = agree[j, i] = frac
    return {'similarity_matrix': agree.tolist()}
# ---------------- J: Logit margin distribution ----------------
@torch.no_grad()
def logit_margin_distribution(m, val, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Top-1 minus top-2 logit margin, split by whether the top-1 token is correct.

    Also reports, among wrong predictions, the fraction whose target is the
    runner-up (rank-2) token.

    Returns:
        dict with correct/wrong counts, mean and median margins per group
        (NaN when a group is empty) and 'wrong_frac_correct_in_top2'.
    """
    correct_chunks, wrong_chunks = [], []
    wrong_top2_correct = 0  # wrong predictions whose target is the rank-2 token
    total_wrong = 0
    for _ in range(n_batches):
        x, y = sample_batch(val, bs, seq_len, device)
        logits, _ = m(x, y)
        y_flat = y.reshape(-1)
        # reshape (not view): logits are not guaranteed to be contiguous.
        l_flat = logits.reshape(-1, logits.shape[-1])
        # Take the prediction from the same topk call as the runner-up so that,
        # even under tied logits, rank-1 and rank-2 are two distinct tokens.
        top_vals, top_idx = torch.topk(l_flat, 2, dim=-1)
        correct = top_idx[:, 0] == y_flat
        wrong = ~correct
        margin = (top_vals[:, 0] - top_vals[:, 1]).cpu().numpy()
        correct_np = correct.cpu().numpy()
        correct_chunks.append(margin[correct_np])
        wrong_chunks.append(margin[~correct_np])
        wrong_top2_correct += int((top_idx[wrong, 1] == y_flat[wrong]).sum().item())
        total_wrong += int(wrong.sum().item())

    correct_margins = np.concatenate(correct_chunks) if correct_chunks else np.empty(0)
    wrong_margins = np.concatenate(wrong_chunks) if wrong_chunks else np.empty(0)

    def mean_median(a):
        # np.mean / np.median warn on empty input; report NaN quietly instead.
        if a.size == 0:
            return float('nan'), float('nan')
        return float(a.mean()), float(np.median(a))

    c_mean, c_median = mean_median(correct_margins)
    w_mean, w_median = mean_median(wrong_margins)
    return {
        'correct_count': int(correct_margins.size),
        'wrong_count': int(wrong_margins.size),
        'correct_margin_mean': c_mean,
        'correct_margin_median': c_median,
        'wrong_margin_mean': w_mean,
        'wrong_margin_median': w_median,
        'wrong_frac_correct_in_top2': wrong_top2_correct / max(1, total_wrong),
    }
# ---------------- K: Per-head knockout ----------------
@torch.no_grad()
def per_head_knockout(m, val, n_batches=10, bs=32, seq_len=256, device='cuda'):
    """Zero one attention head's output slice at a time and measure the BPC change.

    The baseline and every knockout are evaluated on the SAME ``n_batches``
    validation batches, so ``delta_bpc`` reflects the knockout rather than
    batch-to-batch sampling noise.

    A "head" here is the channel slice [h*head_dim, (h+1)*head_dim) of whatever
    ``blk.attn`` returns.  NOTE(review): if attn applies an output projection
    after merging heads, this slice does not isolate a single head — confirm
    against model_v18.

    Returns:
        {'baseline_bpc': float,
         'per_head': [{layer, head, baseline_bpc, knockout_bpc, delta_bpc}, ...]}
    """
    # Fix the evaluation data once so all measurements are directly comparable.
    batches = [sample_batch(val, bs, seq_len, device) for _ in range(n_batches)]

    def mean_bpc():
        # Average the model's per-batch loss (nats) and convert to bits.
        return float(np.mean([m(x, y)[1].item() for x, y in batches])) / math.log(2)

    def zeroing_forward(orig_forward, start, end):
        # Wrap the real forward: run it, then zero channels [start, end) of its output.
        def wrapped(*args, **kwargs):
            out = orig_forward(*args, **kwargs).clone()
            out[..., start:end] = 0  # 0 is "null" not ±1, breaks strictness but OK for analysis
            return out
        return wrapped

    base_bpc = mean_bpc()
    results = []
    for li, blk in enumerate(m.blocks):
        attn = blk.attn
        Dh = attn.head_dim
        orig = attn.forward
        for h_idx in range(attn.n_heads):
            attn.forward = zeroing_forward(orig, h_idx * Dh, (h_idx + 1) * Dh)
            try:
                ko_bpc = mean_bpc()
            finally:
                # Always restore, even if evaluation raised, so later tests see the real model.
                attn.forward = orig
            results.append({'layer': li, 'head': h_idx,
                            'baseline_bpc': base_bpc,
                            'knockout_bpc': ko_bpc,
                            'delta_bpc': ko_bpc - base_bpc})
    return {'baseline_bpc': base_bpc, 'per_head': results}
# ---------------- L: Bit-flip robustness (random weight sign flips) ----------------
@torch.no_grad()
def bit_flip_robustness(m, val, n_batches=10, bs=32, seq_len=256, device='cuda',
                        flip_fractions=(0.001, 0.01, 0.05, 0.10)):
    """Measure how much BPC degrades when a random fraction of weight signs is flipped.

    For each fraction ``p`` in ``flip_fractions``, every parameter with at least
    two dimensions has each entry independently negated with probability ``p``
    (parameters with fewer than two dimensions are left alone).  BPC is
    measured, then the original values are restored.  The baseline and every
    flip level use the SAME ``n_batches`` batches, so ``delta_bpc`` isolates the
    effect of the flips.

    Returns:
        {'baseline_bpc': float,
         'flip_sweep': [{flip_fraction, bpc_after_flip, delta_bpc}, ...]}
    """
    # Fix the evaluation data once so all measurements are directly comparable.
    batches = [sample_batch(val, bs, seq_len, device) for _ in range(n_batches)]

    def mean_bpc():
        # Average the model's per-batch loss (nats) and convert to bits.
        return float(np.mean([m(x, y)[1].item() for x, y in batches])) / math.log(2)

    base_bpc = mean_bpc()
    params = [p for _, p in m.named_parameters() if p.dim() >= 2]
    # A single snapshot suffices: every sweep step restores from it.
    originals = [p.detach().clone() for p in params]
    results = []
    for p_flip in flip_fractions:
        try:
            for p in params:
                flip_mask = torch.rand_like(p) < p_flip
                p.mul_(torch.where(flip_mask, -1.0, 1.0))
            flip_bpc = mean_bpc()
        finally:
            # Restore even if evaluation raised so the caller never keeps corrupted weights.
            for p, orig in zip(params, originals):
                p.copy_(orig)
        results.append({'flip_fraction': p_flip,
                        'bpc_after_flip': flip_bpc,
                        'delta_bpc': flip_bpc - base_bpc})
    return {'baseline_bpc': base_bpc, 'flip_sweep': results}
# ---------------- M: Character embedding clustering ----------------
@torch.no_grad()
def char_embedding_geometry(m):
    """Pairwise sign-agreement similarity between character embedding rows.

    Each row of ``m.embed.weight`` is binarised with sign() (exact zeros map to
    +1), and similarity(a, b) = <sign(a), sign(b)> / D, which lies in [-1, 1]
    (equivalently 1 - 2 * normalised Hamming distance).

    Returns:
        dict with
          'mean_abs_similarity'     mean |similarity| over distinct row pairs
          'max_similarity_off_diag' largest similarity between two distinct rows
          'neighbors_sample'        for each sample char whose code point is a
                                    valid row, its top-6 rows by similarity as
                                    (char-or-<code>, similarity) pairs
        The two off-diagonal statistics are NaN when there are fewer than 2 rows.
    """
    W = torch.sign(m.embed.weight.float())  # (V, D); float so matmul works on any device
    W[W == 0] = 1
    V, D = W.shape
    sim = (W @ W.t()) / D  # value in [-1, 1]
    sim_np = sim.cpu().numpy()

    off_diag = sim_np[~np.eye(V, dtype=bool)]
    if off_diag.size:
        # Bug fix: the key promises a mean of ABSOLUTE similarities.
        mean_abs = float(np.abs(off_diag).mean())
        max_off = float(off_diag.max())
    else:
        mean_abs = max_off = float('nan')

    def label(code):
        # Printable ASCII as itself, everything else as <code>.
        return chr(code) if 32 <= code < 127 else f'<{code}>'

    neighbors = {}
    for ch in 'aetoiAEnbz .,?!0':
        c = ord(ch)
        if c >= V:
            continue
        vals, idx = torch.topk(sim[c], 6)  # normally the row itself + 5 neighbours
        neighbors[repr(ch)] = [(label(int(i)), float(v))
                               for i, v in zip(idx.tolist(), vals.tolist())]
    return {
        'mean_abs_similarity': mean_abs,
        'max_similarity_off_diag': max_off,
        'neighbors_sample': neighbors,
    }
# ---------------- Main ----------------
def main():
    """Command-line entry point: run diagnostics F-M on one checkpoint and write a JSON report."""
    ap = argparse.ArgumentParser(description='Deep diagnostics for a trained binary LM.')
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--data', default='/root/bitnet1/data/validation.bin')
    ap.add_argument('--out', required=True)
    ap.add_argument('--tau', type=float, default=0.1)
    # New, backward-compatible options: their defaults reproduce the previous behaviour.
    ap.add_argument('--device', default='cuda',
                    help="torch device for the model and batches (default: 'cuda')")
    ap.add_argument('--seed', type=int, default=None,
                    help='seed torch RNGs so batch sampling and bit flips are reproducible')
    args = ap.parse_args()
    if args.seed is not None:
        torch.manual_seed(args.seed)
    set_gumbel_tau(args.tau)
    # Tokens are uint8; memory-mapped so the file is not read into RAM up front.
    val = np.memmap(args.data, dtype=np.uint8, mode='r')
    dev = args.device
    m, ck = load_binary(args.ckpt, device=dev)
    cfg = ck['args']
    out = {
        'ckpt': args.ckpt, 'config': cfg, 'val_bpc': ck.get('val_bpc'),
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    }
    print("F. Attention head patterns...")
    out['head_patterns'] = head_attention_patterns(m, val, device=dev)
    print(f" {len(out['head_patterns'])} heads classified")
    print("G. Position-wise BPC...")
    out['position_bpc'] = position_bpc(m, val, device=dev)
    print(f" quartiles: {out['position_bpc']['bpc_quartile_starts']}")
    print("H. Context-length sweep...")
    out['context_sweep'] = context_length_sweep(m, val, device=dev)
    print("I. Layer similarity matrix...")
    out['layer_similarity'] = layer_similarity(m, val, device=dev)
    print("J. Logit margin distribution...")
    out['logit_margins'] = logit_margin_distribution(m, val, device=dev)
    print("K. Per-head knockout...")
    out['head_knockout'] = per_head_knockout(m, val, device=dev)
    print("L. Bit-flip robustness...")
    out['bit_flip'] = bit_flip_robustness(m, val, device=dev)
    print("M. Character embedding geometry...")
    out['char_geometry'] = char_embedding_geometry(m)
    # default=str stringifies anything json cannot encode natively (e.g. config values).
    with open(args.out, 'w', encoding='utf-8') as f:
        json.dump(out, f, indent=2, default=str)
    print(f"Wrote {args.out}")
# Script entry point: run the diagnostics only when executed directly, not on import.
if __name__ == '__main__':
    main()