"""Diagnostic tests on trained binary and FP32 models. Outputs structured JSON that analyze_report.py compiles into a readable report. Each test tries to reveal *mechanism*, not just measure BPC. """ import argparse, json, math, os, time import numpy as np import torch import torch.nn.functional as F from model_v18 import BitLMv18 from model_fp32 import FP32LM from model_v16 import set_gumbel_tau def load_binary_ckpt(path, device='cuda'): ck = torch.load(path, map_location=device, weights_only=False) cfg = ck['args'] m = BitLMv18( vocab_size=cfg['vocab_size'], d_model=cfg['d_model'], n_layers=cfg['n_layers'], n_heads=cfg['n_heads'], d_ff=cfg['d_ff'], max_seq_len=cfg['seq_len'], ).to(device) m.load_state_dict(ck['model']) m.eval() return m, ck def load_fp32_ckpt(path, device='cuda'): ck = torch.load(path, map_location=device, weights_only=False) cfg = ck['args'] m = FP32LM( vocab_size=cfg['vocab_size'], d_model=cfg['d_model'], n_layers=cfg['n_layers'], n_heads=cfg['n_heads'], d_ff=cfg['d_ff'], max_seq_len=cfg['seq_len'], ).to(device) m.load_state_dict(ck['model']) m.eval() return m, ck def sample_eval_batch(data, batch_size, seq_len, device='cuda'): ix = torch.randint(0, len(data) - seq_len - 1, (batch_size,)) x = torch.stack([torch.from_numpy(data[i:i+seq_len].astype(np.int64)) for i in ix]).to(device) y = torch.stack([torch.from_numpy(data[i+1:i+1+seq_len].astype(np.int64)) for i in ix]).to(device) return x, y # ---------------- Test A: Layer ablation ---------------- def layer_ablation_bpc(m, val_data, n_batches=20, bs=32, seq_len=256, device='cuda'): """Zero out each layer's contribution (residual-only), measure BPC delta.""" # Baseline m.eval() base_losses = [] with torch.no_grad(): for _ in range(n_batches): x, y = sample_eval_batch(val_data, bs, seq_len, device) _, loss = m(x, y) base_losses.append(loss.item()) base_bpc = float(np.mean(base_losses)) / math.log(2) # For each layer, replace its forward with identity (skip connection only) results = [] for li in range(len(m.blocks)): original = m.blocks[li].forward # Wrap forward to return x unchanged m.blocks[li].forward = lambda x: x with torch.no_grad(): abl_losses = [] for _ in range(n_batches): x, y = sample_eval_batch(val_data, bs, seq_len, device) _, loss = m(x, y) abl_losses.append(loss.item()) m.blocks[li].forward = original abl_bpc = float(np.mean(abl_losses)) / math.log(2) results.append({'layer': li, 'baseline_bpc': base_bpc, 'ablated_bpc': abl_bpc, 'delta_bpc': abl_bpc - base_bpc}) return {'baseline_bpc': base_bpc, 'per_layer': results} # ---------------- Test B: Weight saturation / flip-flop potential ---------------- def weight_saturation(m): """For each 2D weight tensor, compute the distribution of |latent|. High |latent| = 'locked sign' (won't flip easily). Near zero = 'flippable'. Returns per-parameter distribution stats. """ stats = [] for name, p in m.named_parameters(): if p.dim() < 2: continue with torch.no_grad(): abs_vals = p.abs().flatten() stats.append({ 'name': name, 'shape': list(p.shape), 'n': abs_vals.numel(), 'mean': abs_vals.mean().item(), 'median': abs_vals.median().item(), 'q10': abs_vals.quantile(0.1).item(), 'q90': abs_vals.quantile(0.9).item(), 'q99': abs_vals.quantile(0.99).item(), 'frac_below_0.01': (abs_vals < 0.01).float().mean().item(), 'frac_below_0.05': (abs_vals < 0.05).float().mean().item(), 'frac_above_0.5': (abs_vals > 0.5).float().mean().item(), 'max': abs_vals.max().item(), }) return stats # ---------------- Test C: Attention entropy per head/layer ---------------- def attention_entropy(m, val_data, n_batches=5, bs=8, seq_len=256, device='cuda'): """For each layer and head, compute the entropy of attention-weight distribution averaged over queries. Entropy should be log(T) for uniform, 0 for argmax. For our Gumbel hard-attention, score distribution is what matters. We compute entropy of the *softmax* of raw integer scores (sharpness proxy).""" from model_v16 import _get_tau results = [] with torch.no_grad(): for li, blk in enumerate(m.blocks): attn = blk.attn per_head_entropies = [] per_head_max_score = [] for _ in range(n_batches): x, _ = sample_eval_batch(val_data, bs, seq_len, device) # Mirror attention forward but capture scores xe = m.embed(x) for k in range(li): xe = m.blocks[k](xe) B, T, D = xe.shape H, Dh = attn.n_heads, attn.head_dim Q = attn.q_proj(xe).view(B, T, H, Dh).transpose(1, 2) K = attn.k_proj(xe).view(B, T, H, Dh).transpose(1, 2) scores = torch.matmul(Q, K.transpose(-2, -1)) pos = torch.arange(T, device=device).float() dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs() alibi = attn.alibi_slopes_int.view(1, H, 1, 1).float() * dist.view(1, 1, T, T) scores = scores - alibi mask = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1) scores = scores.masked_fill(mask, -1e9) # Per head: take argmax concentration = max softmax prob averaged over queries probs = F.softmax(scores, dim=-1) # (B, H, T, T) # For each (q, h), compute max prob and entropy max_p = probs.max(dim=-1).values # (B, H, T) entropies = -(probs * probs.clamp(min=1e-9).log()).sum(dim=-1) # (B, H, T) per_head_entropies.append(entropies.mean(dim=(0, 2)).cpu().numpy()) per_head_max_score.append(max_p.mean(dim=(0, 2)).cpu().numpy()) ph_ent = np.stack(per_head_entropies).mean(axis=0) ph_maxp = np.stack(per_head_max_score).mean(axis=0) results.append({'layer': li, 'entropy_per_head': ph_ent.tolist(), 'max_prob_per_head': ph_maxp.tolist(), 'mean_entropy': float(ph_ent.mean()), 'mean_max_prob': float(ph_maxp.mean())}) return results # ---------------- Test D: Student-teacher representation similarity ---------------- def student_teacher_similarity(student_m, teacher_m, val_data, n_batches=5, bs=16, seq_len=256, device='cuda'): """Per-layer: how well does the student's ±1 hidden state match sign(teacher hidden)?""" student_m.eval(); teacher_m.eval() n_layers_s = len(student_m.blocks) n_layers_t = len(teacher_m.blocks) # We assume aligned architectures (student layers == teacher layers) sims = [[] for _ in range(min(n_layers_s, n_layers_t))] with torch.no_grad(): for _ in range(n_batches): x, _ = sample_eval_batch(val_data, bs, seq_len, device) # Student path with hidden snapshots s = student_m.embed(x) s_hiddens = [] for blk in student_m.blocks: s = blk(s) s_hiddens.append(s.clone()) # ±1 valued # Teacher path T_ids = x.shape[1] t_pos = torch.arange(T_ids, device=device) t = teacher_m.embed(x) + teacher_m.pos(t_pos) t_hiddens = [] for blk in teacher_m.blocks: t = blk(t) t_hiddens.append(t.clone()) # Compare: student vs sign(teacher) for i in range(min(n_layers_s, n_layers_t)): tg = torch.sign(t_hiddens[i]) tg[tg == 0] = 1 s_flat = s_hiddens[i].reshape(-1, s_hiddens[i].shape[-1]) t_flat = tg.reshape(-1, tg.shape[-1]) # Cosine similarity: (a · b) / (|a| |b|); for ±1 it simplifies to # agreement fraction × 2 - 1 agree = (s_flat == t_flat).float().mean().item() sims[i].append(agree) per_layer = [{'layer': i, 'sign_agreement': float(np.mean(sims[i]))} for i in range(len(sims))] return per_layer # ---------------- Test E: Prediction error breakdown ---------------- def error_breakdown(m, val_data, n_batches=20, bs=32, seq_len=256, device='cuda'): """Classify errors by character class.""" m.eval() per_char_correct = np.zeros(128) per_char_total = np.zeros(128) class_groups = { 'space': {32}, 'newline': {10}, 'lowercase': set(range(97, 123)), 'uppercase': set(range(65, 91)), 'digit': set(range(48, 58)), 'punct': {46, 44, 33, 63, 39, 34, 58, 59, 40, 41, 45}, } with torch.no_grad(): for _ in range(n_batches): x, y = sample_eval_batch(val_data, bs, seq_len, device) logits, _ = m(x, y) pred = logits.argmax(dim=-1) for i in range(y.numel()): t = y.flatten()[i].item() p = pred.flatten()[i].item() if t < 128: per_char_total[t] += 1 if p == t: per_char_correct[t] += 1 per_class_acc = {} for name, chars in class_groups.items(): tot = sum(per_char_total[c] for c in chars) cor = sum(per_char_correct[c] for c in chars) per_class_acc[name] = {'accuracy': cor / max(tot, 1), 'n': int(tot)} overall_tot = per_char_total.sum() overall_cor = per_char_correct.sum() return {'overall_accuracy': float(overall_cor / max(overall_tot, 1)), 'per_class': per_class_acc} # ---------------- Main ---------------- def main(): ap = argparse.ArgumentParser() ap.add_argument('--student-ckpt', required=True) ap.add_argument('--teacher-ckpt', default=None) ap.add_argument('--data', default='/root/bitnet1/data/validation.bin') ap.add_argument('--out', required=True) ap.add_argument('--tau-eval', type=float, default=0.1, help='Gumbel tau used for eval-mode forwards.') args = ap.parse_args() set_gumbel_tau(args.tau_eval) val = np.memmap(args.data, dtype=np.uint8, mode='r') print(f"Loading student {args.student_ckpt}") student, s_ck = load_binary_ckpt(args.student_ckpt) s_cfg = s_ck['args'] out = { 'student_ckpt': args.student_ckpt, 'student_config': s_cfg, 'student_step': s_ck.get('step'), 'student_val_bpc': s_ck.get('val_bpc'), 'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'), } print("A. Layer ablation BPC...") out['layer_ablation'] = layer_ablation_bpc(student, val) print(f" baseline {out['layer_ablation']['baseline_bpc']:.4f}, {len(out['layer_ablation']['per_layer'])} layers") print("B. Weight saturation...") out['weight_saturation'] = weight_saturation(student) print(f" {len(out['weight_saturation'])} weight tensors analyzed") print("C. Attention entropy...") out['attention_entropy'] = attention_entropy(student, val) print(f" {len(out['attention_entropy'])} layers analyzed") print("E. Error breakdown...") out['error_breakdown'] = error_breakdown(student, val) print(f" overall acc {out['error_breakdown']['overall_accuracy']:.4f}") if args.teacher_ckpt: print(f"Loading teacher {args.teacher_ckpt}") teacher, t_ck = load_fp32_ckpt(args.teacher_ckpt) out['teacher_ckpt'] = args.teacher_ckpt out['teacher_val_bpc'] = t_ck.get('val_bpc') print("D. Student-teacher similarity...") out['student_teacher_similarity'] = student_teacher_similarity(student, teacher, val) print(f" done") with open(args.out, 'w') as f: json.dump(out, f, indent=2, default=str) print(f"Wrote {args.out}") if __name__ == '__main__': main()