"""Diagnostic tests on trained binary and FP32 models.
Outputs structured JSON that analyze_report.py compiles into a readable report.
Each test tries to reveal *mechanism*, not just measure BPC.
"""
import argparse, json, math, os, time
import numpy as np
import torch
import torch.nn.functional as F
from model_v18 import BitLMv18
from model_fp32 import FP32LM
from model_v16 import set_gumbel_tau
def load_binary_ckpt(path, device='cuda'):
    """Rebuild a BitLMv18 model from a checkpoint file and put it in eval mode.

    Args:
        path: Checkpoint location passed to ``torch.load``.
        device: Target for ``map_location`` and for the constructed model.

    Returns:
        A ``(model, checkpoint)`` pair; ``checkpoint`` is the loaded dict,
        whose ``'args'`` entry provides the hyper-parameters and whose
        ``'model'`` entry provides the state dict.
    """
    # NOTE: weights_only=False unpickles arbitrary objects -- only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hparams = checkpoint['args']

    # Hyper-parameters whose checkpoint key equals the constructor keyword.
    shared = {
        key: hparams[key]
        for key in ('vocab_size', 'd_model', 'n_layers', 'n_heads', 'd_ff')
    }
    model = BitLMv18(**shared, max_seq_len=hparams['seq_len'])
    model = model.to(device)

    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model, checkpoint
def load_fp32_ckpt(path, device='cuda'):
    """Rebuild an FP32LM model from a checkpoint file and put it in eval mode.

    Args:
        path: Checkpoint location passed to ``torch.load``.
        device: Target for ``map_location`` and for the constructed model.

    Returns:
        A ``(model, checkpoint)`` pair; ``checkpoint`` is the loaded dict,
        whose ``'args'`` entry provides the hyper-parameters and whose
        ``'model'`` entry provides the state dict.
    """
    # NOTE: weights_only=False unpickles arbitrary objects -- only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hparams = checkpoint['args']

    # Hyper-parameters whose checkpoint key equals the constructor keyword.
    shared = {
        key: hparams[key]
        for key in ('vocab_size', 'd_model', 'n_layers', 'n_heads', 'd_ff')
    }
    model = FP32LM(**shared, max_seq_len=hparams['seq_len'])
    model = model.to(device)

    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model, checkpoint
def sample_eval_batch(data, batch_size, seq_len, device='cuda'):
    """Draw a random batch of (input, next-token target) windows from ``data``.

    Args:
        data: 1-D array-like of token ids that supports slicing and
            ``.astype`` (a NumPy array or memmap).
        batch_size: Number of windows to draw.
        seq_len: Length of every window.
        device: Device the resulting tensors are moved to.

    Returns:
        ``(x, y)`` int64 tensors of shape ``(batch_size, seq_len)``; row ``r``
        of ``y`` is the slice of ``data`` starting one position after the
        slice used for row ``r`` of ``x``.
    """
    # Start offsets are drawn with torch's global RNG, upper bound exclusive.
    starts = torch.randint(0, len(data) - seq_len - 1, (batch_size,)).tolist()

    def _window(begin):
        # Copy the slice to int64 so it can be used as token indices.
        return torch.from_numpy(data[begin:begin + seq_len].astype(np.int64))

    inputs = torch.stack([_window(s) for s in starts]).to(device)
    targets = torch.stack([_window(s + 1) for s in starts]).to(device)
    return inputs, targets
# ---------------- Test A: Layer ablation ----------------
def layer_ablation_bpc(m, val_data, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Measure how much BPC degrades when each block is replaced by identity.

    A fixed set of ``n_batches`` evaluation batches is sampled once and reused
    for the baseline and for every ablation, so each ``delta_bpc`` compares
    like with like instead of mixing in batch-to-batch sampling noise.

    Args:
        m: Model exposing ``blocks`` (indexable, each with a ``forward``) and
            returning ``(logits, loss)`` when called as ``m(x, y)``.
        val_data: Token array handed to ``sample_eval_batch``.
        n_batches: Number of evaluation batches.
        bs: Batch size.
        seq_len: Sequence length of each sample.
        device: Device the batches are placed on.

    Returns:
        ``{'baseline_bpc': float, 'per_layer': [...]}`` where every per-layer
        entry holds ``layer``, ``baseline_bpc``, ``ablated_bpc`` and
        ``delta_bpc`` (ablated minus baseline).
    """
    m.eval()

    def _skip_block(x, *args, **kwargs):
        # Identity stand-in for a block: pass the input through untouched,
        # tolerating any extra arguments the model may hand to its blocks.
        return x

    with torch.no_grad():
        batches = [sample_eval_batch(val_data, bs, seq_len, device)
                   for _ in range(n_batches)]

        def _mean_bpc():
            # The loss is averaged in nats and converted to bits.
            losses = [m(x, y)[1].item() for x, y in batches]
            return float(np.mean(losses)) / math.log(2)

        base_bpc = _mean_bpc()

        results = []
        for li, blk in enumerate(m.blocks):
            original = blk.forward
            blk.forward = _skip_block
            try:
                abl_bpc = _mean_bpc()
            finally:
                # Always undo the patch, even if a forward pass raised, so the
                # model is never left with a silently disabled block.
                blk.forward = original
            results.append({'layer': li, 'baseline_bpc': base_bpc, 'ablated_bpc': abl_bpc,
                            'delta_bpc': abl_bpc - base_bpc})
    return {'baseline_bpc': base_bpc, 'per_layer': results}
# ---------------- Test B: Weight saturation / flip-flop potential ----------------
def _sorted_quantile(ordered, q):
    """Linearly interpolated ``q``-quantile of an ascending, non-empty 1-D tensor."""
    last = ordered.numel() - 1
    pos = q * last
    lo = int(math.floor(pos))
    hi = min(lo + 1, last)
    lo_val = ordered[lo].item()
    hi_val = ordered[hi].item()
    return lo_val + (hi_val - lo_val) * (pos - lo)


def weight_saturation(m):
    """Summarise the distribution of ``|value|`` for every >=2-D parameter.

    High |latent| = 'locked sign' (won't flip easily). Near zero = 'flippable'.

    Quantiles are read off a single sort with linear interpolation rather than
    via ``Tensor.quantile``, which rejects inputs above 2**24 elements and
    non-float dtypes and would re-sort the data for every statistic.

    Args:
        m: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        A list with one dict per parameter of dimension >= 2 (empty tensors
        are skipped) holding ``name``, ``shape``, ``n``, ``mean``, ``median``,
        ``q10``, ``q90``, ``q99``, ``frac_below_0.01``, ``frac_below_0.05``,
        ``frac_above_0.5`` and ``max`` of the absolute values.
    """
    stats = []
    with torch.no_grad():
        for name, p in m.named_parameters():
            if p.dim() < 2 or p.numel() == 0:
                continue
            # float() keeps the statistics well-defined for half/bfloat16 weights.
            abs_vals = p.detach().abs().flatten().float()
            ordered = abs_vals.sort().values
            n = ordered.numel()
            stats.append({
                'name': name, 'shape': list(p.shape), 'n': n,
                'mean': abs_vals.mean().item(),
                # Lower middle element, matching torch.median's convention.
                'median': ordered[(n - 1) // 2].item(),
                'q10': _sorted_quantile(ordered, 0.1),
                'q90': _sorted_quantile(ordered, 0.9),
                'q99': _sorted_quantile(ordered, 0.99),
                'frac_below_0.01': (abs_vals < 0.01).float().mean().item(),
                'frac_below_0.05': (abs_vals < 0.05).float().mean().item(),
                'frac_above_0.5': (abs_vals > 0.5).float().mean().item(),
                'max': ordered[-1].item(),
            })
    return stats
# ---------------- Test C: Attention entropy per head/layer ----------------
def attention_entropy(m, val_data, n_batches=5, bs=8, seq_len=256, device='cuda'):
    """Per layer and head, measure how sharp the causal attention scores are.

    For every block the raw scores ``Q @ K^T`` are rebuilt from the block's
    ``attn.q_proj`` / ``attn.k_proj`` applied to the block input, the ALiBi
    penalty ``alibi_slopes_int * |i - j|`` is subtracted, future positions are
    masked, and a softmax is taken. The entropy of that softmax (``log(T)``
    for uniform, 0 for a one-hot row) and its maximum probability are averaged
    over batch and query positions. This is a sharpness proxy computed outside
    the attention module; it does not run the module's own forward.

    The evaluation batches are sampled once and their hidden states are pushed
    through the blocks incrementally, so every layer is measured on the same
    inputs and each block runs once per batch instead of once per later layer.

    Args:
        m: Model exposing ``embed`` and ``blocks``; each block needs an
            ``attn`` with ``n_heads``, ``head_dim``, ``q_proj``, ``k_proj``
            and ``alibi_slopes_int``.
        val_data: Token array handed to ``sample_eval_batch``.
        n_batches: Number of evaluation batches.
        bs: Batch size.
        seq_len: Sequence length of each sample.
        device: Device used for the batches and helper tensors.

    Returns:
        One dict per layer with ``layer``, ``entropy_per_head``,
        ``max_prob_per_head``, ``mean_entropy`` and ``mean_max_prob``.
    """
    m.eval()
    results = []
    n_layers = len(m.blocks)
    with torch.no_grad():
        # Hidden state of every batch at the input of the current layer.
        hiddens = [m.embed(sample_eval_batch(val_data, bs, seq_len, device)[0])
                   for _ in range(n_batches)]
        # Position-distance matrix and causal mask depend only on T; build once.
        geometry = {}
        for li, blk in enumerate(m.blocks):
            attn = blk.attn
            H, Dh = attn.n_heads, attn.head_dim
            slopes = attn.alibi_slopes_int.view(1, H, 1, 1).float()
            per_head_entropies = []
            per_head_max_prob = []
            for bi, xe in enumerate(hiddens):
                B, T, _ = xe.shape
                if T not in geometry:
                    pos = torch.arange(T, device=device).float()
                    dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs().view(1, 1, T, T)
                    mask = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool),
                                      diagonal=1)
                    geometry[T] = (dist, mask)
                dist, mask = geometry[T]

                Q = attn.q_proj(xe).view(B, T, H, Dh).transpose(1, 2)
                K = attn.k_proj(xe).view(B, T, H, Dh).transpose(1, 2)
                scores = torch.matmul(Q, K.transpose(-2, -1)) - slopes * dist
                # Hide future keys so each query only sees itself and the past.
                scores = scores.masked_fill(mask, -1e9)

                probs = F.softmax(scores, dim=-1)  # (B, H, T, T)
                max_p = probs.max(dim=-1).values  # (B, H, T)
                # clamp avoids log(0) on the masked (zero-probability) entries.
                entropies = -(probs * probs.clamp(min=1e-9).log()).sum(dim=-1)  # (B, H, T)
                per_head_entropies.append(entropies.mean(dim=(0, 2)).cpu().numpy())
                per_head_max_prob.append(max_p.mean(dim=(0, 2)).cpu().numpy())

                # Advance this batch to the next layer's input (the output of
                # the last block is never needed).
                if li + 1 < n_layers:
                    hiddens[bi] = blk(xe)

            ph_ent = np.stack(per_head_entropies).mean(axis=0)
            ph_maxp = np.stack(per_head_max_prob).mean(axis=0)
            results.append({'layer': li,
                            'entropy_per_head': ph_ent.tolist(),
                            'max_prob_per_head': ph_maxp.tolist(),
                            'mean_entropy': float(ph_ent.mean()),
                            'mean_max_prob': float(ph_maxp.mean())})
    return results
# ---------------- Test D: Student-teacher representation similarity ----------------
def student_teacher_similarity(student_m, teacher_m, val_data, n_batches=5, bs=16, seq_len=256, device='cuda'):
    """Per layer, fraction of student activations equal to sign(teacher activation).

    Both models run on the same sampled tokens. After each block the student's
    hidden state is compared element-wise (exact equality) with the sign of
    the teacher's hidden state, where a teacher value of 0 counts as +1. Only
    the first ``min(len(student.blocks), len(teacher.blocks))`` layers are
    compared.

    Args:
        student_m: Model exposing ``embed`` and ``blocks``.
        teacher_m: Model exposing ``embed``, ``pos`` and ``blocks``.
        val_data: Token array handed to ``sample_eval_batch``.
        n_batches: Number of evaluation batches.
        bs: Batch size.
        seq_len: Sequence length of each sample.
        device: Device used for the batches and position indices.

    Returns:
        A list of ``{'layer': i, 'sign_agreement': float}`` dicts, the
        agreement being averaged over the batches.
    """
    student_m.eval()
    teacher_m.eval()

    n_compared = min(len(student_m.blocks), len(teacher_m.blocks))
    agreement = [[] for _ in range(n_compared)]

    with torch.no_grad():
        for _ in range(n_batches):
            tokens, _ = sample_eval_batch(val_data, bs, seq_len, device)

            # Student: snapshot the hidden state after every block.
            hidden = student_m.embed(tokens)
            student_states = []
            for block in student_m.blocks:
                hidden = block(hidden)
                student_states.append(hidden.clone())

            # Teacher: token embedding plus positional embedding, then blocks.
            positions = torch.arange(tokens.shape[1], device=device)
            hidden = teacher_m.embed(tokens) + teacher_m.pos(positions)
            teacher_states = []
            for block in teacher_m.blocks:
                hidden = block(hidden)
                teacher_states.append(hidden.clone())

            # zip stops at the shortest sequence, i.e. after n_compared layers.
            for layer_scores, s_state, t_state in zip(agreement, student_states, teacher_states):
                target = torch.sign(t_state)
                target = target.masked_fill(target == 0, 1)
                s_rows = s_state.reshape(-1, s_state.shape[-1])
                t_rows = target.reshape(-1, target.shape[-1])
                matches = (s_rows == t_rows).float()
                layer_scores.append(matches.mean().item())

    return [{'layer': idx, 'sign_agreement': float(np.mean(scores))}
            for idx, scores in enumerate(agreement)]
# ---------------- Test E: Prediction error breakdown ----------------
def error_breakdown(m, val_data, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Report next-token accuracy overall and per ASCII character class.

    Predictions are the argmax of the model's logits. Only targets with a
    code in ``[0, 128)`` are counted. Counting is done per batch with
    ``np.bincount`` instead of a per-token Python loop, which avoids two
    host/device synchronisations for every token.

    Args:
        m: Model returning ``(logits, loss)`` when called as ``m(x, y)``.
        val_data: Token array handed to ``sample_eval_batch``.
        n_batches: Number of evaluation batches.
        bs: Batch size.
        seq_len: Sequence length of each sample.
        device: Device the batches are placed on.

    Returns:
        ``{'overall_accuracy': float, 'per_class': {name: {'accuracy': float,
        'n': int}}}`` for the classes space, newline, lowercase, uppercase,
        digit and punct.
    """
    m.eval()
    n_chars = 128
    per_char_correct = np.zeros(n_chars)
    per_char_total = np.zeros(n_chars)
    class_groups = {
        'space': {32},
        'newline': {10},
        'lowercase': set(range(97, 123)),
        'uppercase': set(range(65, 91)),
        'digit': set(range(48, 58)),
        'punct': {46, 44, 33, 63, 39, 34, 58, 59, 40, 41, 45},
    }
    with torch.no_grad():
        for _ in range(n_batches):
            x, y = sample_eval_batch(val_data, bs, seq_len, device)
            logits, _ = m(x, y)
            pred = logits.argmax(dim=-1)

            # One transfer per batch, then count on the host.
            targets = y.reshape(-1).cpu().numpy()
            hits = (pred.reshape(-1) == y.reshape(-1)).cpu().numpy()
            in_range = (targets >= 0) & (targets < n_chars)
            per_char_total += np.bincount(targets[in_range], minlength=n_chars)
            per_char_correct += np.bincount(targets[in_range & hits], minlength=n_chars)

    per_class_acc = {}
    for name, chars in class_groups.items():
        codes = sorted(chars)
        tot = per_char_total[codes].sum()
        cor = per_char_correct[codes].sum()
        # max(..., 1) keeps classes that never occurred at accuracy 0.
        per_class_acc[name] = {'accuracy': float(cor / max(tot, 1)), 'n': int(tot)}
    overall_tot = per_char_total.sum()
    overall_cor = per_char_correct.sum()
    return {'overall_accuracy': float(overall_cor / max(overall_tot, 1)),
            'per_class': per_class_acc}
# ---------------- Main ----------------
def main():
    """Run the diagnostic suite on a student checkpoint and write a JSON report.

    Tests A (layer ablation), B (weight saturation), C (attention entropy) and
    E (error breakdown) always run on the student; test D (student-teacher
    similarity) runs only when ``--teacher-ckpt`` is given.

    ``--device`` selects where models and batches live; when omitted, CUDA is
    used if available and CPU otherwise. ``--seed`` optionally seeds torch's
    RNG so the sampled evaluation batches are reproducible; when omitted the
    RNG is left untouched.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--student-ckpt', required=True)
    ap.add_argument('--teacher-ckpt', default=None)
    ap.add_argument('--data', default='/root/bitnet1/data/validation.bin')
    ap.add_argument('--out', required=True)
    ap.add_argument('--tau-eval', type=float, default=0.1,
                    help='Gumbel tau used for eval-mode forwards.')
    ap.add_argument('--device', default=None,
                    help='Device for models and batches (default: cuda if available, else cpu).')
    ap.add_argument('--seed', type=int, default=None,
                    help='Optional torch RNG seed for reproducible batch sampling.')
    args = ap.parse_args()

    device = args.device or ('cuda' if torch.cuda.is_available() else 'cpu')
    if args.seed is not None:
        torch.manual_seed(args.seed)

    set_gumbel_tau(args.tau_eval)
    # Read-only memmap: the validation file is treated as a flat uint8 array.
    val = np.memmap(args.data, dtype=np.uint8, mode='r')

    print(f"Loading student {args.student_ckpt}")
    student, s_ck = load_binary_ckpt(args.student_ckpt, device)
    s_cfg = s_ck['args']
    out = {
        'student_ckpt': args.student_ckpt,
        'student_config': s_cfg,
        'student_step': s_ck.get('step'),
        'student_val_bpc': s_ck.get('val_bpc'),
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    }

    print("A. Layer ablation BPC...")
    out['layer_ablation'] = layer_ablation_bpc(student, val, device=device)
    print(f" baseline {out['layer_ablation']['baseline_bpc']:.4f}, {len(out['layer_ablation']['per_layer'])} layers")

    print("B. Weight saturation...")
    out['weight_saturation'] = weight_saturation(student)
    print(f" {len(out['weight_saturation'])} weight tensors analyzed")

    print("C. Attention entropy...")
    out['attention_entropy'] = attention_entropy(student, val, device=device)
    print(f" {len(out['attention_entropy'])} layers analyzed")

    print("E. Error breakdown...")
    out['error_breakdown'] = error_breakdown(student, val, device=device)
    print(f" overall acc {out['error_breakdown']['overall_accuracy']:.4f}")

    if args.teacher_ckpt:
        print(f"Loading teacher {args.teacher_ckpt}")
        teacher, t_ck = load_fp32_ckpt(args.teacher_ckpt, device)
        out['teacher_ckpt'] = args.teacher_ckpt
        out['teacher_val_bpc'] = t_ck.get('val_bpc')
        print("D. Student-teacher similarity...")
        out['student_teacher_similarity'] = student_teacher_similarity(
            student, teacher, val, device=device)
        print(" done")

    with open(args.out, 'w') as f:
        # default=str stringifies any config value json cannot encode natively.
        json.dump(out, f, indent=2, default=str)
    print(f"Wrote {args.out}")
# Run the diagnostics only when this file is executed as a script.
if __name__ == '__main__':
    main()