"""Deep diagnostic tests for understanding a trained binary LM. Beyond the first-pass analysis: F. Per-head attention pattern classification (recent / first-token / content / long) G. Position-wise BPC — how does BPC depend on position in the sequence? H. Context-length sweep — BPC as a function of how much context we give I. Layer-wise CKA similarity — which layers carry redundant information? J. Logit margin distribution — how confident is the model on right vs wrong? K. Per-head knockout — which heads are load-bearing? L. Effective parameter count — how many weights actually move the output? M. Character embedding clustering — do similar chars cluster in ±1 space? N. Bit-flip robustness — how much does one random flip cost? """ import argparse, json, math, os, time import numpy as np import torch import torch.nn.functional as F from model_v18 import BitLMv18 from model_fp32 import FP32LM from model_v16 import set_gumbel_tau def load_binary(path, device='cuda'): ck = torch.load(path, map_location=device, weights_only=False) cfg = ck['args'] m = BitLMv18(vocab_size=cfg['vocab_size'], d_model=cfg['d_model'], n_layers=cfg['n_layers'], n_heads=cfg['n_heads'], d_ff=cfg['d_ff'], max_seq_len=cfg['seq_len']).to(device) m.load_state_dict(ck['model']) m.eval() return m, ck def sample_batch(data, batch_size, seq_len, device='cuda'): ix = torch.randint(0, len(data) - seq_len - 1, (batch_size,)) x = torch.stack([torch.from_numpy(data[i:i+seq_len].astype(np.int64)) for i in ix]).to(device) y = torch.stack([torch.from_numpy(data[i+1:i+1+seq_len].astype(np.int64)) for i in ix]).to(device) return x, y # ---------------- F: Attention head pattern ---------------- @torch.no_grad() def head_attention_patterns(m, val, n_batches=5, bs=8, seq_len=256, device='cuda'): """Classify each (layer, head) by where it attends: recent = mean(|i-j|) small long-range = mean(|i-j|) large first-token = argmax often = 0 content-sensitive = variance of argmax across identical positions """ results = [] with torch.no_grad(): for li, blk in enumerate(m.blocks): attn = blk.attn H, Dh = attn.n_heads, attn.head_dim dists_per_head = [[] for _ in range(H)] first_tok_per_head = [[] for _ in range(H)] var_per_head = [[] for _ in range(H)] for _ in range(n_batches): x, _ = sample_batch(val, bs, seq_len, device) xe = m.embed(x) for k in range(li): xe = m.blocks[k](xe) B, T, D = xe.shape Q = attn.q_proj(xe).view(B, T, H, Dh).transpose(1, 2) K = attn.k_proj(xe).view(B, T, H, Dh).transpose(1, 2) scores = torch.matmul(Q, K.transpose(-2, -1)) pos = torch.arange(T, device=device).float() dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs() alibi = attn.alibi_slopes_int.view(1, H, 1, 1).float() * dist.view(1, 1, T, T) scores = scores - alibi mask = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1) scores = scores.masked_fill(mask, -1e9) argmax_keys = scores.argmax(dim=-1) # (B, H, T) for h in range(H): ak = argmax_keys[:, h, :] # (B, T) # Average distance (only valid positions where i >= 0) pos_t = torch.arange(T, device=device).unsqueeze(0).expand(B, -1) d = (pos_t - ak).abs().float() # Only count positions where attention is meaningful (j != -inf masked) dists_per_head[h].append(d.mean().item()) first_tok_per_head[h].append((ak == 0).float().mean().item()) # Content variance: for the LAST position, how much does the choice # vary across different inputs? High variance = content-sensitive last_pos_ak = ak[:, T // 2] # mid position var_per_head[h].append(last_pos_ak.float().std().item()) for h in range(H): mean_dist = np.mean(dists_per_head[h]) first_frac = np.mean(first_tok_per_head[h]) content_var = np.mean(var_per_head[h]) # Classify if first_frac > 0.5: kind = 'first-token-sink' elif mean_dist < 3: kind = 'recent' elif mean_dist > seq_len / 4: kind = 'long-range' elif content_var > 5: kind = 'content-sensitive' else: kind = 'positional' results.append({'layer': li, 'head': h, 'mean_dist': float(mean_dist), 'first_tok_frac': float(first_frac), 'content_var': float(content_var), 'kind': kind, 'alibi_slope': int(attn.alibi_slopes_int[h].item())}) return results # ---------------- G: Position-wise BPC ---------------- @torch.no_grad() def position_bpc(m, val, n_batches=20, bs=32, seq_len=256, device='cuda'): """BPC per position in the sequence, averaged over batches.""" loss_sum = torch.zeros(seq_len, device=device) loss_cnt = torch.zeros(seq_len, device=device) for _ in range(n_batches): x, y = sample_batch(val, bs, seq_len, device) logits, _ = m(x, y) losses = F.cross_entropy(logits.permute(0, 2, 1), y, reduction='none') # (B, T) loss_sum += losses.sum(dim=0) loss_cnt += losses.shape[0] avg = (loss_sum / loss_cnt).cpu().numpy() / math.log(2) return {'bpc_per_position': avg.tolist(), 'bpc_quartile_starts': [float(avg[:seq_len//4].mean()), float(avg[seq_len//4:seq_len//2].mean()), float(avg[seq_len//2:3*seq_len//4].mean()), float(avg[3*seq_len//4:].mean())]} # ---------------- H: Context-length sweep ---------------- @torch.no_grad() def context_length_sweep(m, val, n_batches=20, bs=32, seq_len=256, device='cuda'): """For held-out data, BPC at different context lengths. Prediction position = last.""" results = [] ctx_lens = [1, 4, 16, 64, 128, 256] for cl in ctx_lens: if cl > seq_len: continue losses = [] for _ in range(n_batches): x, y = sample_batch(val, bs, seq_len, device) x_ctx = x[:, :cl] y_target = y[:, cl - 1:cl] logits, _ = m(x_ctx) # predict the last position pred_logits = logits[:, -1, :] loss = F.cross_entropy(pred_logits, y_target.squeeze(-1)) losses.append(loss.item()) avg = float(np.mean(losses)) / math.log(2) results.append({'context_len': cl, 'bpc_last_position': avg}) return results # ---------------- I: Layer-wise CKA similarity ---------------- @torch.no_grad() def layer_similarity(m, val, n_batches=5, bs=16, seq_len=256, device='cuda'): """Centered Kernel Alignment between hidden states at each pair of layers. High = redundant layers.""" n_layers = len(m.blocks) # Collect hidden states H_all = [[] for _ in range(n_layers)] for _ in range(n_batches): x, _ = sample_batch(val, bs, seq_len, device) xe = m.embed(x) for li, blk in enumerate(m.blocks): xe = blk(xe) H_all[li].append(xe.reshape(-1, xe.shape[-1]).float().cpu()) # For CKA, we need large matrices; compute cross-layer similarity via # simple agreement (both are ±1) for efficiency. agree = np.zeros((n_layers, n_layers)) for i in range(n_layers): hi = torch.cat(H_all[i], dim=0) for j in range(n_layers): hj = torch.cat(H_all[j], dim=0) # Cosine-ish: for ±1 vectors, row-averaged per-token agreement # Here we want COLUMN-wise (dimension-wise) correlation # Simpler: just mean element-wise agreement agree[i, j] = (hi == hj).float().mean().item() return {'similarity_matrix': agree.tolist()} # ---------------- J: Logit margin distribution ---------------- @torch.no_grad() def logit_margin_distribution(m, val, n_batches=20, bs=32, seq_len=256, device='cuda'): """For correct vs incorrect predictions, distribution of top1-top2 logit margin.""" correct_margins = [] wrong_margins = [] wrong_top2_correct = 0 # fraction of wrong predictions where correct is top-2 total_wrong = 0 for _ in range(n_batches): x, y = sample_batch(val, bs, seq_len, device) logits, _ = m(x, y) y_flat = y.view(-1) l_flat = logits.view(-1, logits.shape[-1]) pred = l_flat.argmax(dim=-1) correct_mask = (pred == y_flat) # top1 - top2 margin sorted_vals, sorted_idx = torch.topk(l_flat, 2, dim=-1) margin = (sorted_vals[:, 0] - sorted_vals[:, 1]).cpu().numpy() cm = margin[correct_mask.cpu().numpy()] wm = margin[~correct_mask.cpu().numpy()] correct_margins.append(cm) wrong_margins.append(wm) # For wrong preds, is correct in top 2? wrong_mask = ~correct_mask top2 = sorted_idx[:, 1] wrong_top2_correct += (top2[wrong_mask] == y_flat[wrong_mask]).float().sum().item() total_wrong += wrong_mask.sum().item() correct_margins = np.concatenate(correct_margins) wrong_margins = np.concatenate(wrong_margins) return { 'correct_count': int(correct_margins.size), 'wrong_count': int(wrong_margins.size), 'correct_margin_mean': float(correct_margins.mean()), 'correct_margin_median': float(np.median(correct_margins)), 'wrong_margin_mean': float(wrong_margins.mean()), 'wrong_margin_median': float(np.median(wrong_margins)), 'wrong_frac_correct_in_top2': wrong_top2_correct / max(1, total_wrong), } # ---------------- K: Per-head knockout ---------------- @torch.no_grad() def per_head_knockout(m, val, n_batches=10, bs=32, seq_len=256, device='cuda'): """Zero out each individual attention head, measure BPC delta.""" # Baseline base_losses = [] for _ in range(n_batches): x, y = sample_batch(val, bs, seq_len, device) _, loss = m(x, y) base_losses.append(loss.item()) base_bpc = float(np.mean(base_losses)) / math.log(2) results = [] for li, blk in enumerate(m.blocks): attn = blk.attn H = attn.n_heads Dh = attn.head_dim orig = attn.forward for h_idx in range(H): # Wrap attention to zero-out head h_idx def make_wrapped(blk_ref, head_to_zero): def wrapped(x_in): out = orig(x_in) # Head h_idx occupies bits [h*Dh : (h+1)*Dh] in d_model # Zero that slice in the ±1 output B, T, D = out.shape start = head_to_zero * Dh end = start + Dh out = out.clone() out[..., start:end] = 0 # 0 is "null" not ±1, breaks strictness but OK for analysis return out return wrapped attn.forward = make_wrapped(attn, h_idx) ko_losses = [] for _ in range(n_batches): x, y = sample_batch(val, bs, seq_len, device) _, loss = m(x, y) ko_losses.append(loss.item()) attn.forward = orig ko_bpc = float(np.mean(ko_losses)) / math.log(2) results.append({'layer': li, 'head': h_idx, 'baseline_bpc': base_bpc, 'knockout_bpc': ko_bpc, 'delta_bpc': ko_bpc - base_bpc}) return {'baseline_bpc': base_bpc, 'per_head': results} # ---------------- L: Effective parameter count via random bit flip ---------------- @torch.no_grad() def bit_flip_robustness(m, val, n_batches=10, bs=32, seq_len=256, device='cuda'): """Measure how much BPC degrades when we flip p% of latent weight signs.""" base_losses = [] for _ in range(n_batches): x, y = sample_batch(val, bs, seq_len, device) _, loss = m(x, y) base_losses.append(loss.item()) base_bpc = float(np.mean(base_losses)) / math.log(2) # Collect flippable weights (2D only) params = [(name, p) for name, p in m.named_parameters() if p.dim() >= 2] results = [] for p_flip in [0.001, 0.01, 0.05, 0.10]: # Save originals originals = [p.clone() for _, p in params] # Flip random fraction for _, p in params: flip_mask = torch.rand_like(p) < p_flip p.mul_(torch.where(flip_mask, -1.0, 1.0)) # Measure flip_losses = [] for _ in range(n_batches): x, y = sample_batch(val, bs, seq_len, device) _, loss = m(x, y) flip_losses.append(loss.item()) flip_bpc = float(np.mean(flip_losses)) / math.log(2) # Restore for (_, p), orig in zip(params, originals): p.copy_(orig) results.append({'flip_fraction': p_flip, 'bpc_after_flip': flip_bpc, 'delta_bpc': flip_bpc - base_bpc}) return {'baseline_bpc': base_bpc, 'flip_sweep': results} # ---------------- M: Character embedding clustering ---------------- @torch.no_grad() def char_embedding_geometry(m): """Compute pairwise Hamming similarity between character embedding codebooks.""" W = torch.sign(m.embed.weight) # (V, D) W[W == 0] = 1 V, D = W.shape # Similarity = Hamming agreement sim = (W @ W.t()) / D # value in [-1, 1] sim_np = sim.cpu().numpy() # Find clusters by looking at top-5 similar chars for a few test chars interest_chars = [ord(c) for c in 'aetoiAEnbz .,?!0'] neighbors = {} for c in interest_chars: if c < V: vals, idx = torch.topk(sim[c], 6) # itself + 5 neighbors ns = [(int(idx[k].item()), float(vals[k].item())) for k in range(6)] neighbors[repr(chr(c))] = ns return { 'mean_abs_similarity': float(sim_np[~np.eye(V, dtype=bool)].mean()), 'max_similarity_off_diag': float(sim_np[~np.eye(V, dtype=bool)].max()), 'neighbors_sample': {k: [(chr(c) if 32 <= c < 127 else f'<{c}>', float(s)) for c, s in v] for k, v in neighbors.items()} } # ---------------- Main ---------------- def main(): ap = argparse.ArgumentParser() ap.add_argument('--ckpt', required=True) ap.add_argument('--data', default='/root/bitnet1/data/validation.bin') ap.add_argument('--out', required=True) ap.add_argument('--tau', type=float, default=0.1) args = ap.parse_args() set_gumbel_tau(args.tau) val = np.memmap(args.data, dtype=np.uint8, mode='r') m, ck = load_binary(args.ckpt) cfg = ck['args'] out = { 'ckpt': args.ckpt, 'config': cfg, 'val_bpc': ck.get('val_bpc'), 'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'), } print("F. Attention head patterns...") out['head_patterns'] = head_attention_patterns(m, val) print(f" {len(out['head_patterns'])} heads classified") print("G. Position-wise BPC...") out['position_bpc'] = position_bpc(m, val) print(f" quartiles: {out['position_bpc']['bpc_quartile_starts']}") print("H. Context-length sweep...") out['context_sweep'] = context_length_sweep(m, val) print("I. Layer similarity matrix...") out['layer_similarity'] = layer_similarity(m, val) print("J. Logit margin distribution...") out['logit_margins'] = logit_margin_distribution(m, val) print("K. Per-head knockout...") out['head_knockout'] = per_head_knockout(m, val) print("L. Bit-flip robustness...") out['bit_flip'] = bit_flip_robustness(m, val) print("M. Character embedding geometry...") out['char_geometry'] = char_embedding_geometry(m) with open(args.out, 'w') as f: json.dump(out, f, indent=2, default=str) print(f"Wrote {args.out}") if __name__ == '__main__': main()