| """Deep diagnostic tests for understanding a trained binary LM. |
| |
| Beyond the first-pass analysis: |
| F. Per-head attention pattern classification (recent / first-token / content / long) |
| G. Position-wise BPC — how does BPC depend on position in the sequence? |
| H. Context-length sweep — BPC as a function of how much context we give |
| I. Layer-wise CKA similarity — which layers carry redundant information? |
| J. Logit margin distribution — how confident is the model on right vs wrong? |
| K. Per-head knockout — which heads are load-bearing? |
| L. Effective parameter count — how many weights actually move the output? |
| M. Character embedding clustering — do similar chars cluster in ±1 space? |
| N. Bit-flip robustness — how much does one random flip cost? |
| """ |
| import argparse, json, math, os, time |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| from model_v18 import BitLMv18 |
| from model_fp32 import FP32LM |
| from model_v16 import set_gumbel_tau |
|
|
|
|
def load_binary(path, device='cuda'):
    """Restore a BitLMv18 model from a training checkpoint.

    The checkpoint is expected to be a mapping with an ``'args'`` entry
    (hyper-parameters used to build the model) and a ``'model'`` entry
    (the state dict).

    Args:
        path: Location of the checkpoint file.
        device: Device the checkpoint tensors and the model are placed on.

    Returns:
        ``(model, checkpoint)`` -- the model in eval mode and the raw
        checkpoint object as returned by ``torch.load``.
    """
    # NOTE: weights_only=False lets torch.load unpickle arbitrary objects;
    # only point this at checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hparams = checkpoint['args']

    model_kwargs = {
        key: hparams[key]
        for key in ('vocab_size', 'd_model', 'n_layers', 'n_heads', 'd_ff')
    }
    # The checkpoint stores the context length under 'seq_len'.
    model_kwargs['max_seq_len'] = hparams['seq_len']

    model = BitLMv18(**model_kwargs).to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model, checkpoint
|
|
|
|
def sample_batch(data, batch_size, seq_len, device='cuda'):
    """Draw a random batch of (input, next-token target) windows from *data*.

    Args:
        data: 1-D array-like of token ids supporting ``len`` and slicing
            (e.g. a ``numpy`` array or ``np.memmap``).
        batch_size: Number of windows to draw.
        seq_len: Length of each input window.
        device: Device the returned tensors are moved to.

    Returns:
        ``(x, y)``: two contiguous ``int64`` tensors of shape
        ``(batch_size, seq_len)`` where ``y`` is ``x`` shifted one token
        to the right in the underlying stream.

    Raises:
        ValueError: If *data* is too short to hold a single window.
    """
    n_starts = len(data) - seq_len - 1
    if n_starts <= 0:
        raise ValueError(
            f"data has {len(data)} tokens; need more than {seq_len + 1} "
            f"to sample windows of seq_len={seq_len}"
        )

    # Plain Python ints index numpy/memmap data directly (no 0-d tensors).
    starts = torch.randint(0, n_starts, (batch_size,)).tolist()

    # Read each window once (seq_len + 1 tokens) and split it into the
    # input and the shifted target, instead of slicing the data twice.
    windows = np.stack(
        [np.asarray(data[s:s + seq_len + 1]) for s in starts]
    ).astype(np.int64)
    windows = torch.from_numpy(windows)

    # .contiguous() keeps downstream ``.view(-1)`` calls valid.
    x = windows[:, :-1].contiguous().to(device)
    y = windows[:, 1:].contiguous().to(device)
    return x, y
|
|
|
|
| |
@torch.no_grad()
def head_attention_patterns(m, val, n_batches=5, bs=8, seq_len=256, device='cuda'):
    """Classify every (layer, head) by where its strongest attention score lands.

    For each layer the raw scores ``Q @ K^T`` are rebuilt from the layer's
    input, the head's ALiBi penalty (``alibi_slopes_int * |i - j|``) is
    subtracted, future keys are masked out, and the argmax key per query
    is analysed:

    * ``mean_dist``      -- mean ``|query_pos - argmax_key_pos|``
    * ``first_tok_frac`` -- fraction of queries whose argmax key is position 0
    * ``content_var``    -- standard deviation, across the batch, of the
      argmax key chosen by the query at position ``T // 2``

    Heads are then labelled, in priority order: ``first-token-sink``
    (first_tok_frac > 0.5), ``recent`` (mean_dist < 3), ``long-range``
    (mean_dist > seq_len / 4), ``content-sensitive`` (content_var > 5),
    otherwise ``positional``.

    Each sampled batch is pushed through the layers once and every layer
    is analysed on the way, so the same batches are shared by all layers.

    Args:
        m: Model exposing ``embed`` and ``blocks``; each block exposes
            ``attn`` with ``n_heads``, ``head_dim``, ``q_proj``, ``k_proj``
            and ``alibi_slopes_int``.
        val: Token array passed to ``sample_batch``.
        n_batches: Number of batches averaged.
        bs: Batch size.
        seq_len: Window length.
        device: Device used for sampling and for the helper tensors.

    Returns:
        List of dicts (layer-major, head-minor) with keys ``layer``,
        ``head``, ``mean_dist``, ``first_tok_frac``, ``content_var``,
        ``kind`` and ``alibi_slope``.
    """
    blocks = list(m.blocks)

    # Per layer: running sum of a (3, H) tensor holding
    # [mean distance, first-token fraction, mid-position argmax std].
    totals = [torch.zeros(3, blk.attn.n_heads, dtype=torch.float64)
              for blk in blocks]

    for _ in range(n_batches):
        x, _ = sample_batch(val, bs, seq_len, device)
        hidden = m.embed(x)
        B, T = hidden.shape[0], hidden.shape[1]

        # Position helpers depend only on T, so build them once per batch.
        positions = torch.arange(T, device=device)
        rel_dist = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs().float()
        future = torch.triu(
            torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)

        for li, blk in enumerate(blocks):
            attn = blk.attn
            H, Dh = attn.n_heads, attn.head_dim

            q = attn.q_proj(hidden).view(B, T, H, Dh).transpose(1, 2)
            k = attn.k_proj(hidden).view(B, T, H, Dh).transpose(1, 2)
            scores = torch.matmul(q, k.transpose(-2, -1))
            slopes = attn.alibi_slopes_int.view(1, H, 1, 1).float()
            scores = scores - slopes * rel_dist.view(1, 1, T, T)
            scores = scores.masked_fill(future, -1e9)

            top_key = scores.argmax(dim=-1)                      # (B, H, T)
            offset = (positions.view(1, 1, T) - top_key).abs().float()
            batch_stats = torch.stack([
                offset.mean(dim=(0, 2)),
                (top_key == 0).float().mean(dim=(0, 2)),
                # Spread of the chosen key for the mid-sequence query.
                top_key[:, :, T // 2].float().std(dim=0),
            ])
            # One device->host transfer per layer instead of 3*H .item() calls.
            totals[li] += batch_stats.double().cpu()

            # Advance to the next layer's input.
            hidden = blk(hidden)

    results = []
    for li, blk in enumerate(blocks):
        dist_by_head, first_by_head, spread_by_head = (
            totals[li] / n_batches).tolist()
        slope_by_head = blk.attn.alibi_slopes_int.reshape(-1).tolist()
        for h in range(blk.attn.n_heads):
            mean_dist = dist_by_head[h]
            first_frac = first_by_head[h]
            content_var = spread_by_head[h]

            if first_frac > 0.5:
                kind = 'first-token-sink'
            elif mean_dist < 3:
                kind = 'recent'
            elif mean_dist > seq_len / 4:
                kind = 'long-range'
            elif content_var > 5:
                kind = 'content-sensitive'
            else:
                kind = 'positional'

            results.append({'layer': li, 'head': h,
                            'mean_dist': float(mean_dist),
                            'first_tok_frac': float(first_frac),
                            'content_var': float(content_var),
                            'kind': kind,
                            'alibi_slope': int(slope_by_head[h])})
    return results
|
|
|
|
| |
@torch.no_grad()
def position_bpc(m, val, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Average bits-per-character at every position of the window.

    Args:
        m: Model called as ``m(x, y)`` and returning ``(logits, loss)``.
        val: Token array passed to ``sample_batch``.
        n_batches: Number of batches averaged.
        bs: Batch size.
        seq_len: Window length.
        device: Device for sampling and accumulation.

    Returns:
        Dict with ``bpc_per_position`` (list of ``seq_len`` floats) and
        ``bpc_quartile_starts`` (mean BPC over each quarter of the window,
        first quarter first).
    """
    nats_total = torch.zeros(seq_len, device=device)
    samples_seen = torch.zeros(seq_len, device=device)

    for _ in range(n_batches):
        tokens, targets = sample_batch(val, bs, seq_len, device)
        logits, _ = m(tokens, targets)
        # cross_entropy wants the class dimension second: (B, V, T).
        per_token = F.cross_entropy(
            logits.transpose(1, 2), targets, reduction='none')
        nats_total += per_token.sum(dim=0)
        samples_seen += per_token.shape[0]

    # Convert mean nats to bits.
    bpc = (nats_total / samples_seen).cpu().numpy() / math.log(2)

    edges = [0, seq_len // 4, seq_len // 2, 3 * seq_len // 4, seq_len]
    quartiles = [float(bpc[lo:hi].mean()) for lo, hi in zip(edges, edges[1:])]

    return {'bpc_per_position': bpc.tolist(),
            'bpc_quartile_starts': quartiles}
|
|
|
|
| |
@torch.no_grad()
def context_length_sweep(m, val, n_batches=20, bs=32, seq_len=256, device='cuda',
                         ctx_lens=None):
    """Measure next-token BPC when the model sees only a truncated context.

    For each context length ``cl`` the model is fed the first ``cl`` tokens
    of every sampled window and scored on predicting token ``cl`` from the
    logits at the last fed position.

    Args:
        m: Model called as ``m(x)`` and returning ``(logits, _)``.
        val: Token array passed to ``sample_batch``.
        n_batches: Number of batches averaged per context length.
        bs: Batch size.
        seq_len: Length of the sampled windows; upper bound on ``cl``.
        device: Device for sampling.
        ctx_lens: Iterable of context lengths to evaluate. Defaults to
            ``(1, 4, 16, 64, 128, 256)``. Entries outside ``[1, seq_len]``
            are skipped.

    Returns:
        List of ``{'context_len': cl, 'bpc_last_position': bpc}`` dicts in
        the order the lengths were given.
    """
    if ctx_lens is None:
        ctx_lens = (1, 4, 16, 64, 128, 256)

    results = []
    for cl in ctx_lens:
        # cl < 1 would leave no position to predict from; cl > seq_len
        # would exceed the sampled window.
        if not 1 <= cl <= seq_len:
            continue
        losses = []
        for _ in range(n_batches):
            x, y = sample_batch(val, bs, seq_len, device)
            logits, _ = m(x[:, :cl])
            # y[:, cl - 1] is the token that follows x[:, cl - 1].
            loss = F.cross_entropy(logits[:, -1, :], y[:, cl - 1])
            losses.append(loss.item())
        avg = float(np.mean(losses)) / math.log(2)
        results.append({'context_len': cl, 'bpc_last_position': avg})
    return results
|
|
|
|
| |
@torch.no_grad()
def layer_similarity(m, val, n_batches=5, bs=16, seq_len=256, device='cuda'):
    """Pairwise agreement between the outputs of every pair of layers.

    Entry ``[i][j]`` of the result is the fraction of hidden-state elements
    (over all sampled tokens and feature dimensions) that are exactly equal
    between the output of block ``i`` and the output of block ``j``. This is
    an elementwise-equality score, not Centered Kernel Alignment.

    NOTE(review): exact equality is only informative if block outputs take
    discrete values (presumably +/-1 for this binary model) -- confirm.

    Agreement counts are accumulated batch by batch, so only one batch of
    activations is held at a time and the final ratio is exact.

    Args:
        m: Model exposing ``embed`` and ``blocks``.
        val: Token array passed to ``sample_batch``.
        n_batches: Number of batches evaluated (must be >= 1).
        bs: Batch size.
        seq_len: Window length.
        device: Device for sampling.

    Returns:
        ``{'similarity_matrix': n_layers x n_layers nested list}``; the
        matrix is symmetric.

    Raises:
        ValueError: If no activations were collected (``n_batches < 1``)
            while the model has at least one block.
    """
    n_layers = len(m.blocks)
    equal_counts = np.zeros((n_layers, n_layers))
    n_elements = 0

    for _ in range(n_batches):
        x, _ = sample_batch(val, bs, seq_len, device)
        hidden = m.embed(x)
        layer_outputs = []
        for blk in m.blocks:
            hidden = blk(hidden)
            layer_outputs.append(hidden.float())

        if layer_outputs:
            n_elements += layer_outputs[0].numel()

        # Equality is symmetric: fill the upper triangle and mirror it.
        for i in range(n_layers):
            for j in range(i, n_layers):
                same = (layer_outputs[i] == layer_outputs[j]).sum().item()
                equal_counts[i, j] += same
                if i != j:
                    equal_counts[j, i] += same

    if n_layers and n_elements == 0:
        raise ValueError("layer_similarity needs n_batches >= 1")

    agree = equal_counts / n_elements if n_elements else equal_counts
    return {'similarity_matrix': agree.tolist()}
|
|
|
|
| |
@torch.no_grad()
def logit_margin_distribution(m, val, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Summarise the top-1 minus top-2 logit gap for right and wrong predictions.

    Args:
        m: Model called as ``m(x, y)`` and returning ``(logits, loss)``.
        val: Token array passed to ``sample_batch``.
        n_batches: Number of batches evaluated.
        bs: Batch size.
        seq_len: Window length.
        device: Device for sampling.

    Returns:
        Dict with the number of correct / wrong predictions, the mean and
        median margin of each group, and the fraction of wrong predictions
        whose target was the second-highest logit.
    """
    margins_right = []
    margins_wrong = []
    runner_up_hits = 0
    n_missed = 0

    for _ in range(n_batches):
        x, y = sample_batch(val, bs, seq_len, device)
        logits, _ = m(x, y)

        targets = y.view(-1)
        flat_logits = logits.view(-1, logits.shape[-1])
        is_right = flat_logits.argmax(dim=-1) == targets
        is_wrong = ~is_right

        # Two largest logits per token: values give the margin, indices
        # tell whether the runner-up was the true target.
        top_vals, top_idx = torch.topk(flat_logits, 2, dim=-1)
        gap = (top_vals[:, 0] - top_vals[:, 1]).cpu().numpy()

        right_np = is_right.cpu().numpy()
        margins_right.append(gap[right_np])
        margins_wrong.append(gap[~right_np])

        runner_up = top_idx[:, 1]
        runner_up_hits += (runner_up[is_wrong] == targets[is_wrong]).float().sum().item()
        n_missed += is_wrong.sum().item()

    margins_right = np.concatenate(margins_right)
    margins_wrong = np.concatenate(margins_wrong)

    summary = {}
    summary['correct_count'] = int(margins_right.size)
    summary['wrong_count'] = int(margins_wrong.size)
    summary['correct_margin_mean'] = float(margins_right.mean())
    summary['correct_margin_median'] = float(np.median(margins_right))
    summary['wrong_margin_mean'] = float(margins_wrong.mean())
    summary['wrong_margin_median'] = float(np.median(margins_wrong))
    summary['wrong_frac_correct_in_top2'] = runner_up_hits / max(1, n_missed)
    return summary
|
|
|
|
| |
@torch.no_grad()
def per_head_knockout(m, val, n_batches=10, bs=32, seq_len=256, device='cuda'):
    """Zero one head-sized slice of each attention output and report the BPC change.

    For every block and every head index ``h`` the block's ``attn.forward``
    is temporarily wrapped so that features ``[h * head_dim, (h + 1) * head_dim)``
    of its output are set to zero; the model is then re-scored.

    NOTE(review): the slice is taken from whatever ``attn.forward`` returns.
    It isolates a single head only if that output is still laid out
    head-by-head (i.e. before any mixing output projection) -- confirm
    against the attention module.

    The baseline and all knockouts are scored on the same fixed set of
    sampled batches, so ``delta_bpc`` is a paired difference rather than a
    difference between two independent random samples.

    Args:
        m: Model called as ``m(x, y)`` returning ``(logits, loss)`` and
            exposing ``blocks``; each block exposes ``attn`` with
            ``n_heads`` and ``head_dim``.
        val: Token array passed to ``sample_batch``.
        n_batches: Number of evaluation batches.
        bs: Batch size.
        seq_len: Window length.
        device: Device for sampling.

    Returns:
        ``{'baseline_bpc': float, 'per_head': [...]}`` where each entry has
        ``layer``, ``head``, ``baseline_bpc``, ``knockout_bpc`` and
        ``delta_bpc``.
    """
    batches = [sample_batch(val, bs, seq_len, device) for _ in range(n_batches)]

    def mean_bpc():
        losses = []
        for x, y in batches:
            _, loss = m(x, y)
            losses.append(loss.item())
        return float(np.mean(losses)) / math.log(2)

    base_bpc = mean_bpc()

    results = []
    for li, blk in enumerate(m.blocks):
        attn = blk.attn
        Dh = attn.head_dim
        orig_forward = attn.forward

        for h_idx in range(attn.n_heads):
            # Bind the slice bounds as defaults so each wrapper keeps its own.
            def knocked_out(*args, _lo=h_idx * Dh, _hi=(h_idx + 1) * Dh, **kwargs):
                out = orig_forward(*args, **kwargs).clone()
                out[..., _lo:_hi] = 0
                return out

            attn.forward = knocked_out
            try:
                ko_bpc = mean_bpc()
            finally:
                # Always undo the patch, even if evaluation fails.
                attn.forward = orig_forward

            results.append({'layer': li, 'head': h_idx,
                            'baseline_bpc': base_bpc,
                            'knockout_bpc': ko_bpc,
                            'delta_bpc': ko_bpc - base_bpc})
    return {'baseline_bpc': base_bpc, 'per_head': results}
|
|
|
|
| |
@torch.no_grad()
def bit_flip_robustness(m, val, n_batches=10, bs=32, seq_len=256, device='cuda'):
    """Measure the BPC increase after negating a random fraction of the weights.

    Every parameter with at least two dimensions takes part. For each flip
    fraction in ``(0.001, 0.01, 0.05, 0.10)`` an independent random mask is
    drawn per parameter and the selected entries of the *original* weights
    are negated. The original weights are restored before returning, even
    if evaluation raises.

    The baseline and all perturbed models are scored on the same fixed set
    of sampled batches, so ``delta_bpc`` is a paired difference.

    Args:
        m: Model called as ``m(x, y)`` returning ``(logits, loss)``.
        val: Token array passed to ``sample_batch``.
        n_batches: Number of evaluation batches.
        bs: Batch size.
        seq_len: Window length.
        device: Device for sampling.

    Returns:
        ``{'baseline_bpc': float, 'flip_sweep': [...]}`` where each entry
        has ``flip_fraction``, ``bpc_after_flip`` and ``delta_bpc``.
    """
    batches = [sample_batch(val, bs, seq_len, device) for _ in range(n_batches)]

    def mean_bpc():
        losses = []
        for x, y in batches:
            _, loss = m(x, y)
            losses.append(loss.item())
        return float(np.mean(losses)) / math.log(2)

    base_bpc = mean_bpc()

    weights = [p for _, p in m.named_parameters() if p.dim() >= 2]
    # Snapshot once; every flip level is derived from these pristine copies.
    pristine = [p.clone() for p in weights]

    results = []
    try:
        for p_flip in [0.001, 0.01, 0.05, 0.10]:
            for p, saved in zip(weights, pristine):
                flip_mask = torch.rand_like(p) < p_flip
                p.copy_(torch.where(flip_mask, -saved, saved))

            flip_bpc = mean_bpc()
            results.append({'flip_fraction': p_flip,
                            'bpc_after_flip': flip_bpc,
                            'delta_bpc': flip_bpc - base_bpc})
    finally:
        # Leave the model exactly as it was handed in.
        for p, saved in zip(weights, pristine):
            p.copy_(saved)

    return {'baseline_bpc': base_bpc, 'flip_sweep': results}
|
|
|
|
| |
@torch.no_grad()
def char_embedding_geometry(m):
    """Compare the sign patterns of the rows of ``m.embed.weight``.

    Each embedding row is reduced to its signs (zeros count as +1) and the
    similarity of two rows is their dot product divided by the embedding
    width, i.e. a value in ``[-1, 1]`` where 1 means identical sign
    patterns and -1 means every sign differs.

    Args:
        m: Model exposing ``embed.weight`` as a 2-D ``(vocab, dim)`` tensor.
            The weight tensor itself is not modified.

    Returns:
        Dict with:

        * ``mean_abs_similarity`` -- mean absolute off-diagonal similarity
          (``None`` if the vocabulary has fewer than two entries).
        * ``max_similarity_off_diag`` -- largest off-diagonal similarity
          (``None`` if the vocabulary has fewer than two entries).
        * ``neighbors_sample`` -- for a fixed set of probe characters whose
          code is inside the vocabulary, the up-to-6 most similar rows as
          ``(char_or_<code>, similarity)`` pairs, keyed by ``repr(char)``.
          The probe character itself is included among its neighbours.
    """
    signs = torch.sign(m.embed.weight)
    # Treat exact zeros as +1 so every entry is a valid +/-1 code bit.
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    V, D = signs.shape

    sim = (signs @ signs.t()) / D
    off_diag = sim.cpu().numpy()[~np.eye(V, dtype=bool)]

    if off_diag.size:
        mean_abs = float(np.abs(off_diag).mean())
        max_off = float(off_diag.max())
    else:
        # A single-row vocabulary has no pairs to compare.
        mean_abs = None
        max_off = None

    # topk cannot return more entries than there are rows.
    k = min(6, V)
    neighbors = {}
    for ch in 'aetoiAEnbz .,?!0':
        code = ord(ch)
        if code >= V:
            continue
        vals, idx = torch.topk(sim[code], k)
        neighbors[repr(ch)] = [
            (chr(i) if 32 <= i < 127 else f'<{i}>', float(s))
            for i, s in zip(idx.tolist(), vals.tolist())
        ]

    return {
        'mean_abs_similarity': mean_abs,
        'max_similarity_off_diag': max_off,
        'neighbors_sample': neighbors,
    }
|
|
|
|
| |
def main():
    """Run every diagnostic on one checkpoint and write the results as JSON.

    Command-line arguments:
        --ckpt    Checkpoint to analyse (required).
        --data    Validation data, read as a flat ``uint8`` memmap.
        --out     Path of the JSON report (required).
        --tau     Value passed to ``set_gumbel_tau`` before evaluation.
        --device  Device to run on; defaults to ``cuda`` when available,
                  otherwise ``cpu``.

    All diagnostics run with a window of at most 256 tokens, clamped to
    the checkpoint's own ``seq_len`` so the model is never fed more
    context than it was built for.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--data', default='/root/bitnet1/data/validation.bin')
    ap.add_argument('--out', required=True)
    ap.add_argument('--tau', type=float, default=0.1)
    ap.add_argument('--device',
                    default='cuda' if torch.cuda.is_available() else 'cpu')
    args = ap.parse_args()

    set_gumbel_tau(args.tau)
    val = np.memmap(args.data, dtype=np.uint8, mode='r')
    m, ck = load_binary(args.ckpt, device=args.device)
    cfg = ck['args']

    # Shared settings for every diagnostic that samples data.
    run_kwargs = {'seq_len': min(256, cfg['seq_len']), 'device': args.device}

    out = {
        'ckpt': args.ckpt, 'config': cfg, 'val_bpc': ck.get('val_bpc'),
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    }

    print("F. Attention head patterns...")
    out['head_patterns'] = head_attention_patterns(m, val, **run_kwargs)
    print(f" {len(out['head_patterns'])} heads classified")

    print("G. Position-wise BPC...")
    out['position_bpc'] = position_bpc(m, val, **run_kwargs)
    print(f" quartiles: {out['position_bpc']['bpc_quartile_starts']}")

    print("H. Context-length sweep...")
    out['context_sweep'] = context_length_sweep(m, val, **run_kwargs)

    print("I. Layer similarity matrix...")
    out['layer_similarity'] = layer_similarity(m, val, **run_kwargs)

    print("J. Logit margin distribution...")
    out['logit_margins'] = logit_margin_distribution(m, val, **run_kwargs)

    print("K. Per-head knockout...")
    out['head_knockout'] = per_head_knockout(m, val, **run_kwargs)

    print("L. Bit-flip robustness...")
    out['bit_flip'] = bit_flip_robustness(m, val, **run_kwargs)

    print("M. Character embedding geometry...")
    out['char_geometry'] = char_embedding_geometry(m)

    # default=str stringifies any value json cannot encode natively
    # (e.g. non-JSON objects stored in the checkpoint config).
    with open(args.out, 'w') as f:
        json.dump(out, f, indent=2, default=str)
    print(f"Wrote {args.out}")
|
|
|
|
# Run the full diagnostic suite only when executed as a script, so that
# importing this module for its individual diagnostics has no side effects.
if __name__ == '__main__':
    main()
|
|