File size: 16,516 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
"""Deep diagnostic tests for understanding a trained binary LM.

Beyond the first-pass analysis:
  F. Per-head attention pattern classification (recent / first-token / content / long)
  G. Position-wise BPC — how does BPC depend on position in the sequence?
  H. Context-length sweep — BPC as a function of how much context we give
  I. Layer-wise similarity (element-wise agreement, a cheap stand-in for CKA) — which layers carry redundant information?
  J. Logit margin distribution — how confident is the model on right vs wrong?
  K. Per-head knockout — which heads are load-bearing?
  L. Bit-flip robustness (effective parameter count) — how much does flipping a random fraction of weight signs cost?
  M. Character embedding clustering — do similar chars cluster in ±1 space?
"""
import argparse, json, math, os, time
import numpy as np
import torch
import torch.nn.functional as F

from model_v18 import BitLMv18
from model_fp32 import FP32LM
from model_v16 import set_gumbel_tau


def load_binary(path, device='cuda'):
    """Rebuild a ``BitLMv18`` from a checkpoint file and put it in eval mode.

    The checkpoint is read as a dict holding the training configuration
    under ``'args'`` and the model weights under ``'model'``.

    Args:
        path: checkpoint file to load.
        device: device the checkpoint is mapped to and the model is moved to.

    Returns:
        ``(model, checkpoint)`` -- the model with weights loaded, and the raw
        checkpoint dict.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hparams = checkpoint['args']

    # Constructor argument -> key in the saved config.
    arg_to_key = {
        'vocab_size': 'vocab_size',
        'd_model': 'd_model',
        'n_layers': 'n_layers',
        'n_heads': 'n_heads',
        'd_ff': 'd_ff',
        'max_seq_len': 'seq_len',
    }
    ctor_kwargs = {arg: hparams[key] for arg, key in arg_to_key.items()}

    model = BitLMv18(**ctor_kwargs).to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model, checkpoint


def sample_batch(data, batch_size, seq_len, device='cuda'):
    """Draw a random batch of next-token prediction windows from ``data``.

    Args:
        data: 1-D numpy array (or ``np.memmap``) of token ids.
        batch_size: number of windows to draw.
        seq_len: length of every window.
        device: device the returned tensors are moved to.

    Returns:
        ``(x, y)``: two int64 tensors of shape ``(batch_size, seq_len)``.
        ``y`` holds the same windows as ``x`` shifted one token forward
        in ``data``.

    Raises:
        ValueError: if ``data`` is too short to hold a window plus its
            shifted target.
    """
    n_tokens = len(data)
    high = n_tokens - seq_len - 1  # exclusive upper bound on the window start
    if high <= 0:
        # torch.randint would otherwise fail with an opaque "from >= to" error.
        raise ValueError(
            f"data has {n_tokens} tokens but at least {seq_len + 2} are "
            f"needed to sample windows of length {seq_len}"
        )

    # Convert to plain Python ints so numpy is not sliced with 0-d tensors.
    starts = torch.randint(0, high, (batch_size,)).tolist()

    def gather(offset):
        rows = [
            torch.from_numpy(data[s + offset:s + offset + seq_len].astype(np.int64))
            for s in starts
        ]
        return torch.stack(rows).to(device)

    return gather(0), gather(1)


# ---------------- F: Attention head pattern ----------------
@torch.no_grad()
def head_attention_patterns(m, val, n_batches=5, bs=8, seq_len=256, device='cuda'):
    """Label every (layer, head) by where its hard-argmax attention lands.

    For each layer the input hidden state is rebuilt by embedding a fresh
    batch and running all earlier blocks. Scores are the raw ``Q @ K^T``
    minus the head's ALiBi penalty (slope * |i - j|), causally masked. Only
    the argmax key per query is used. Three statistics are averaged over
    batches:

      * ``mean_dist``      -- mean |query position - argmax key position|
      * ``first_tok_frac`` -- fraction of queries whose argmax key is 0
      * ``content_var``    -- std across the batch of the argmax key chosen
                              by the mid-sequence query (position T // 2)

    Labels are assigned in priority order: 'first-token-sink', 'recent',
    'long-range', 'content-sensitive', otherwise 'positional'.

    Returns:
        A list with one dict per (layer, head) holding the statistics above,
        the label under ``'kind'``, and the head's integer ALiBi slope.
    """

    def classify(mean_dist, first_frac, content_var):
        # Order matters: the first matching rule wins.
        if first_frac > 0.5:
            return 'first-token-sink'
        if mean_dist < 3:
            return 'recent'
        if mean_dist > seq_len / 4:
            return 'long-range'
        if content_var > 5:
            return 'content-sensitive'
        return 'positional'

    results = []
    for li, blk in enumerate(m.blocks):
        attn = blk.attn
        H, Dh = attn.n_heads, attn.head_dim
        dist_log = [[] for _ in range(H)]
        sink_log = [[] for _ in range(H)]
        spread_log = [[] for _ in range(H)]

        for _ in range(n_batches):
            x, _ = sample_batch(val, bs, seq_len, device)
            # Hidden state entering layer ``li``.
            hidden = m.embed(x)
            for k in range(li):
                hidden = m.blocks[k](hidden)
            B, T, _ = hidden.shape

            q = attn.q_proj(hidden).view(B, T, H, Dh).transpose(1, 2)
            k_ = attn.k_proj(hidden).view(B, T, H, Dh).transpose(1, 2)
            scores = torch.matmul(q, k_.transpose(-2, -1))

            # ALiBi: subtract slope * |i - j| from every (query, key) score.
            pos = torch.arange(T, device=device).float()
            rel = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs()
            slopes = attn.alibi_slopes_int.view(1, H, 1, 1).float()
            scores = scores - slopes * rel.view(1, 1, T, T)

            # Causal mask: a query may not look at later keys.
            future = torch.triu(
                torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
            scores = scores.masked_fill(future, -1e9)

            best_key = scores.argmax(dim=-1)  # (B, H, T)
            query_pos = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
            mid = T // 2
            for h in range(H):
                picked = best_key[:, h, :]  # (B, T)
                gap = (query_pos - picked).abs().float()
                dist_log[h].append(gap.mean().item())
                sink_log[h].append((picked == 0).float().mean().item())
                # Spread of the mid-sequence query's choice across inputs:
                # a large spread means the choice depends on content.
                spread_log[h].append(picked[:, mid].float().std().item())

        for h in range(H):
            mean_dist = np.mean(dist_log[h])
            first_frac = np.mean(sink_log[h])
            content_var = np.mean(spread_log[h])
            results.append({
                'layer': li,
                'head': h,
                'mean_dist': float(mean_dist),
                'first_tok_frac': float(first_frac),
                'content_var': float(content_var),
                'kind': classify(mean_dist, first_frac, content_var),
                'alibi_slope': int(attn.alibi_slopes_int[h].item()),
            })
    return results


# ---------------- G: Position-wise BPC ----------------
@torch.no_grad()
def position_bpc(m, val, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Measure mean bits-per-character at every position of the sequence.

    Per-token cross-entropy (in nats) is summed over all sampled sequences
    for each position, divided by the number of sequences seen, and
    converted to bits by dividing by ln 2.

    Returns:
        A dict with ``'bpc_per_position'`` (list of ``seq_len`` floats) and
        ``'bpc_quartile_starts'`` (mean BPC over each quarter of the
        sequence, first quarter first).
    """
    nll_total = torch.zeros(seq_len, device=device)
    n_seen = torch.zeros(seq_len, device=device)
    for _ in range(n_batches):
        x, y = sample_batch(val, bs, seq_len, device)
        logits, _ = m(x, y)
        # cross_entropy wants the class dimension second: (B, V, T).
        token_nll = F.cross_entropy(logits.transpose(1, 2), y, reduction='none')  # (B, T)
        nll_total += token_nll.sum(dim=0)
        n_seen += token_nll.shape[0]

    bpc = (nll_total / n_seen).cpu().numpy() / math.log(2)

    q1, q2, q3 = seq_len // 4, seq_len // 2, 3 * seq_len // 4
    quarter_bounds = [(None, q1), (q1, q2), (q2, q3), (q3, None)]
    quartile_means = [float(bpc[lo:hi].mean()) for lo, hi in quarter_bounds]

    return {'bpc_per_position': bpc.tolist(),
            'bpc_quartile_starts': quartile_means}


# ---------------- H: Context-length sweep ----------------
@torch.no_grad()
def context_length_sweep(m, val, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Measure BPC of the final prediction as the visible context grows.

    For each context length in (1, 4, 16, 64, 128, 256) that fits within
    ``seq_len``, the model is fed only the first ``ctx`` tokens of each
    sampled window and scored on the token that follows them.

    Returns:
        A list of ``{'context_len', 'bpc_last_position'}`` dicts, one per
        context length evaluated, in increasing order of context length.
    """
    sweep = []
    for ctx in (1, 4, 16, 64, 128, 256):
        if ctx > seq_len:
            continue
        batch_nll = []
        for _ in range(n_batches):
            x, y = sample_batch(val, bs, seq_len, device)
            logits, _ = m(x[:, :ctx])
            # y is x shifted by one, so y[:, ctx - 1] is the token that
            # follows the ctx-token prefix.
            next_token = y[:, ctx - 1]
            nll = F.cross_entropy(logits[:, -1, :], next_token)
            batch_nll.append(nll.item())
        sweep.append({
            'context_len': ctx,
            'bpc_last_position': float(np.mean(batch_nll)) / math.log(2),
        })
    return sweep


# ---------------- I: Layer-wise CKA similarity ----------------
@torch.no_grad()
def layer_similarity(m, val, n_batches=5, bs=16, seq_len=256, device='cuda'):
    """Pairwise similarity between the outputs of every pair of blocks.

    This is NOT true Centered Kernel Alignment. As a cheap stand-in, the
    similarity of layers i and j is the fraction of (token, dimension)
    entries whose values are exactly equal in the two layers' outputs.
    That is only meaningful if activations are discrete (e.g. ±1) --
    presumably the case for this model; confirm before interpreting.

    Returns:
        ``{'similarity_matrix': ...}`` -- an ``n_layers x n_layers`` nested
        list; high values suggest redundant layers.
    """
    n_layers = len(m.blocks)

    # Gather each block's output, flattened to (tokens, d_model) on the CPU.
    chunks = [[] for _ in range(n_layers)]
    for _ in range(n_batches):
        x, _ = sample_batch(val, bs, seq_len, device)
        h = m.embed(x)
        for li, blk in enumerate(m.blocks):
            h = blk(h)
            chunks[li].append(h.reshape(-1, h.shape[-1]).float().cpu())

    # Concatenate once per layer instead of once per (i, j) pair.
    acts = []
    for per_layer in chunks:
        acts.append(torch.cat(per_layer, dim=0))
        per_layer.clear()  # release the pieces as soon as they are merged

    # Equality is symmetric, so fill the upper triangle and mirror it.
    agree = np.zeros((n_layers, n_layers))
    for i in range(n_layers):
        for j in range(i, n_layers):
            score = (acts[i] == acts[j]).float().mean().item()
            agree[i, j] = score
            agree[j, i] = score
    return {'similarity_matrix': agree.tolist()}


# ---------------- J: Logit margin distribution ----------------
@torch.no_grad()
def logit_margin_distribution(m, val, n_batches=20, bs=32, seq_len=256, device='cuda'):
    """Summarise the top-1 minus top-2 logit margin for right and wrong predictions.

    Every token position in every sampled batch is one prediction. Margins
    are split by whether the argmax prediction matches the target. For the
    wrong ones, the function also counts how often the target was the
    second-highest logit.

    Returns:
        A dict with counts, mean and median margins for correct and wrong
        predictions, and ``'wrong_frac_correct_in_top2'``.
    """
    margins_right = []
    margins_wrong = []
    n_wrong = 0
    n_wrong_but_second = 0  # wrong predictions whose target was ranked second
    for _ in range(n_batches):
        x, y = sample_batch(val, bs, seq_len, device)
        logits, _ = m(x, y)
        targets = y.view(-1)
        flat_logits = logits.view(-1, logits.shape[-1])

        is_right = flat_logits.argmax(dim=-1) == targets
        is_wrong = ~is_right

        top_vals, top_idx = torch.topk(flat_logits, 2, dim=-1)
        margin = (top_vals[:, 0] - top_vals[:, 1]).cpu().numpy()

        right_np = is_right.cpu().numpy()
        margins_right.append(margin[right_np])
        margins_wrong.append(margin[~right_np])

        runner_up = top_idx[:, 1]
        n_wrong_but_second += ((runner_up == targets) & is_wrong).sum().item()
        n_wrong += is_wrong.sum().item()

    margins_right = np.concatenate(margins_right)
    margins_wrong = np.concatenate(margins_wrong)
    return {
        'correct_count': int(margins_right.size),
        'wrong_count': int(margins_wrong.size),
        'correct_margin_mean': float(margins_right.mean()),
        'correct_margin_median': float(np.median(margins_right)),
        'wrong_margin_mean': float(margins_wrong.mean()),
        'wrong_margin_median': float(np.median(margins_wrong)),
        'wrong_frac_correct_in_top2': n_wrong_but_second / max(1, n_wrong),
    }


# ---------------- K: Per-head knockout ----------------
@torch.no_grad()
def per_head_knockout(m, val, n_batches=10, bs=32, seq_len=256, device='cuda'):
    """Zero out one attention head at a time and report the BPC change.

    A head is knocked out by temporarily replacing ``attn.forward`` with a
    wrapper that zeroes the head's slice of the attention output. The
    original ``forward`` is put back even if evaluation raises, so a failure
    cannot leave the model with a patched head.

    Baseline and knockout losses are each measured on freshly sampled
    batches, so ``delta_bpc`` carries sampling noise on top of the effect.

    Returns:
        ``{'baseline_bpc': float, 'per_head': [...]}`` where each per-head
        entry has ``layer``, ``head``, ``baseline_bpc``, ``knockout_bpc``
        and ``delta_bpc`` (knockout minus baseline).
    """

    def mean_bpc():
        # Average loss over n_batches random batches, converted nats -> bits.
        losses = []
        for _ in range(n_batches):
            x, y = sample_batch(val, bs, seq_len, device)
            _, loss = m(x, y)
            losses.append(loss.item())
        return float(np.mean(losses)) / math.log(2)

    def zero_head(forward, head, head_dim):
        # Wrap ``forward`` so the [head*head_dim, (head+1)*head_dim) slice of
        # its output's last dimension is zeroed.
        # NOTE(review): assumes the attention output is laid out as n_heads
        # contiguous head_dim slices with no cross-head mixing (e.g. no
        # output projection inside attn.forward) -- confirm against the
        # attention module.
        start = head * head_dim
        end = start + head_dim

        def wrapped(*args, **kwargs):
            out = forward(*args, **kwargs).clone()
            # 0 is "null" rather than ±1; acceptable for analysis only.
            out[..., start:end] = 0
            return out

        return wrapped

    base_bpc = mean_bpc()

    results = []
    for li, blk in enumerate(m.blocks):
        attn = blk.attn
        orig = attn.forward
        for h_idx in range(attn.n_heads):
            attn.forward = zero_head(orig, h_idx, attn.head_dim)
            try:
                ko_bpc = mean_bpc()
            finally:
                # Always undo the patch, even if the forward pass fails.
                attn.forward = orig
            results.append({'layer': li, 'head': h_idx,
                            'baseline_bpc': base_bpc,
                            'knockout_bpc': ko_bpc,
                            'delta_bpc': ko_bpc - base_bpc})
    return {'baseline_bpc': base_bpc, 'per_head': results}


# ---------------- L: Effective parameter count via random bit flip ----------------
@torch.no_grad()
def bit_flip_robustness(m, val, n_batches=10, bs=32, seq_len=256, device='cuda'):
    """Measure how BPC degrades when a random fraction of weight signs is flipped.

    For each flip fraction in (0.001, 0.01, 0.05, 0.10), every parameter
    with at least two dimensions has that fraction of its entries negated
    in place. BPC is then measured on fresh random batches and the
    parameters are restored. Restoration happens in a ``finally`` block, so
    a failure during evaluation cannot leave the model with corrupted
    weights.

    Returns:
        ``{'baseline_bpc': float, 'flip_sweep': [...]}`` where each sweep
        entry has ``flip_fraction``, ``bpc_after_flip`` and ``delta_bpc``
        (flipped minus baseline).
    """

    def mean_bpc():
        # Average loss over n_batches random batches, converted nats -> bits.
        losses = []
        for _ in range(n_batches):
            x, y = sample_batch(val, bs, seq_len, device)
            _, loss = m(x, y)
            losses.append(loss.item())
        return float(np.mean(losses)) / math.log(2)

    base_bpc = mean_bpc()

    # Flippable weights: matrices and higher-rank tensors only.
    params = [p for _, p in m.named_parameters() if p.dim() >= 2]
    # Parameters are restored after every sweep step, so one snapshot
    # taken up front serves the whole sweep.
    originals = [p.clone() for p in params]

    results = []
    for p_flip in [0.001, 0.01, 0.05, 0.10]:
        try:
            for p in params:
                flip_mask = torch.rand_like(p) < p_flip
                p.mul_(torch.where(flip_mask, -1.0, 1.0))
            flip_bpc = mean_bpc()
        finally:
            # Always put the pristine weights back, even on failure.
            for p, saved in zip(params, originals):
                p.copy_(saved)
        results.append({'flip_fraction': p_flip,
                        'bpc_after_flip': flip_bpc,
                        'delta_bpc': flip_bpc - base_bpc})
    return {'baseline_bpc': base_bpc, 'flip_sweep': results}


# ---------------- M: Character embedding clustering ----------------
@torch.no_grad()
def char_embedding_geometry(m):
    """Compare the sign patterns of the character embedding rows pairwise.

    Each row of ``m.embed.weight`` is binarised to ±1 (exact zeros map to
    +1). The similarity of two rows is their dot product divided by the
    embedding width, i.e. (agreements - disagreements) / D, in [-1, 1].

    Returns:
        A dict with:
          * ``'mean_abs_similarity'`` -- mean of |similarity| over all
            off-diagonal pairs (NaN if there is only one row)
          * ``'max_similarity_off_diag'`` -- largest off-diagonal similarity
            (NaN if there is only one row)
          * ``'neighbors_sample'`` -- for a fixed set of probe characters that
            fall inside the vocabulary, the most similar rows (the probe
            itself included) as ``(label, similarity)`` pairs, where the
            label is the printable ASCII character or ``'<code>'``.
    """
    W = torch.sign(m.embed.weight)  # (V, D)
    W[W == 0] = 1
    V, D = W.shape
    sim = (W @ W.t()) / D  # values in [-1, 1]
    sim_np = sim.cpu().numpy()

    off_diag = sim_np[~np.eye(V, dtype=bool)]
    if off_diag.size:
        # The key promises an absolute value; a signed mean would let
        # positive and negative similarities cancel out.
        mean_abs = float(np.abs(off_diag).mean())
        max_off = float(off_diag.max())
    else:
        mean_abs = float('nan')
        max_off = float('nan')

    def label(code):
        return chr(code) if 32 <= code < 127 else f'<{code}>'

    # Probe a handful of characters: itself plus its 5 nearest rows.
    k = min(6, V)
    neighbors = {}
    for code in (ord(ch) for ch in 'aetoiAEnbz .,?!0'):
        if code >= V:
            continue
        vals, idx = torch.topk(sim[code], k)
        neighbors[repr(chr(code))] = [
            (label(int(i.item())), float(v.item())) for i, v in zip(idx, vals)
        ]

    return {
        'mean_abs_similarity': mean_abs,
        'max_similarity_off_diag': max_off,
        'neighbors_sample': neighbors,
    }


# ---------------- Main ----------------
def main():
    """Run every diagnostic on one checkpoint and write the results as JSON.

    Command-line arguments: ``--ckpt`` (required checkpoint path), ``--data``
    (uint8 validation file), ``--out`` (required JSON output path) and
    ``--tau`` (value passed to ``set_gumbel_tau``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', required=True)
    parser.add_argument('--data', default='/root/bitnet1/data/validation.bin')
    parser.add_argument('--out', required=True)
    parser.add_argument('--tau', type=float, default=0.1)
    args = parser.parse_args()

    set_gumbel_tau(args.tau)
    val = np.memmap(args.data, dtype=np.uint8, mode='r')
    m, ck = load_binary(args.ckpt)

    out = dict(
        ckpt=args.ckpt,
        config=ck['args'],
        val_bpc=ck.get('val_bpc'),
        timestamp=time.strftime('%Y-%m-%d %H:%M:%S'),
    )

    # F and G print a short summary after they finish.
    print("F. Attention head patterns...")
    head_report = head_attention_patterns(m, val)
    out['head_patterns'] = head_report
    print(f"   {len(head_report)} heads classified")

    print("G. Position-wise BPC...")
    pos_report = position_bpc(m, val)
    out['position_bpc'] = pos_report
    print(f"   quartiles: {pos_report['bpc_quartile_starts']}")

    # The remaining diagnostics only announce themselves and store a result.
    remaining = (
        ("H. Context-length sweep...", 'context_sweep',
         lambda: context_length_sweep(m, val)),
        ("I. Layer similarity matrix...", 'layer_similarity',
         lambda: layer_similarity(m, val)),
        ("J. Logit margin distribution...", 'logit_margins',
         lambda: logit_margin_distribution(m, val)),
        ("K. Per-head knockout...", 'head_knockout',
         lambda: per_head_knockout(m, val)),
        ("L. Bit-flip robustness...", 'bit_flip',
         lambda: bit_flip_robustness(m, val)),
        ("M. Character embedding geometry...", 'char_geometry',
         lambda: char_embedding_geometry(m)),
    )
    for banner, key, run in remaining:
        print(banner)
        out[key] = run()

    with open(args.out, 'w') as f:
        # default=str stringifies anything json cannot encode natively.
        json.dump(out, f, indent=2, default=str)
    print(f"Wrote {args.out}")


# Run the diagnostic suite only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()