File size: 3,773 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
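"""Test-time ensembling: K forward passes with Gumbel noise, averaged predictive probabilities.

Monkey-patches gumbel_hard_attention to always add Gumbel noise (even in eval),
runs K independent noisy passes per batch, and scores the mixture
log((1/K) * sum_k softmax(logits_k)) against a noise-free (argmax) baseline.
"""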
import argparse, math, os, json
import numpy as np
import torch
import torch.nn.functional as F

import model_v16 as _v16
from model_v16 import _get_tau
from model_v57 import BitLMv57


def _force_gumbel(scores, mask=None):
    tau = _get_tau(scores.device)
    if mask is not None:
        scores = scores.masked_fill(mask, -1e9)
    g = -torch.log(-torch.log(torch.rand_like(scores).clamp(min=1e-9)) + 1e-9)
    y_soft = F.softmax((scores + g) / tau, dim=-1)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, y_soft.argmax(-1, keepdim=True), 1.0)
    return y_hard


def _get_batch(val, starts, seq_len):
    """Slice (X, Y) CUDA int64 tensors of shape (len(starts), seq_len) out of the byte stream.

    Y is X shifted one position forward in ``val`` (next-byte targets).
    """
    X = torch.stack([torch.from_numpy(val[s:s + seq_len].astype(np.int64)) for s in starts]).cuda()
    Y = torch.stack([torch.from_numpy(val[s + 1:s + 1 + seq_len].astype(np.int64)) for s in starts]).cuda()
    return X, Y


def _load_model(ckpt_path):
    """Rebuild BitLMv57 from a checkpoint with 'args' and 'model' entries; returns (model, cfg) in eval mode on CUDA."""
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    ck = torch.load(ckpt_path, map_location='cuda', weights_only=False)
    cfg = ck['args']
    m = BitLMv57(vocab_size=128, d_model=cfg['d_model'], n_layers=cfg['n_layers'],
                 n_heads=cfg['n_heads'], d_ff=cfg['d_ff'], max_seq_len=cfg['seq_len']).cuda()
    m.load_state_dict(ck['model'])
    m.eval()
    return m, cfg


def _argmax_loss(m, val, starts, seq_len):
    """Average of the per-batch losses returned by ``m(X, Y)`` (the model's own eval path)."""
    losses = []
    with torch.no_grad():
        for batch_starts in starts:
            X, Y = _get_batch(val, batch_starts, seq_len)
            _, loss = m(X, Y)
            losses.append(loss.item())
    return float(np.mean(losses))


def _ensemble_loss(m, val, starts, seq_len, K):
    """Average per-batch NLL of the K-pass mixture log((1/K) * sum_k softmax(logits_k)).

    The mixture is accumulated with a running logaddexp, so only a few
    (batch, seq, vocab) tensors are live instead of K of them.

    Returns:
        (loss, varied): ``varied`` is False when every pass in every batch
        produced identical log-probs. That means the Gumbel patch never
        changed the model's output.
    """
    losses = []
    varied = False
    with torch.no_grad():
        for batch_starts in starts:
            X, Y = _get_batch(val, batch_starts, seq_len)
            acc = first = None
            for _ in range(K):
                logits, _ = m(X)
                log_p = F.log_softmax(logits, dim=-1)
                if acc is None:
                    acc = first = log_p
                    continue
                if not varied and not torch.equal(log_p, first):
                    varied = True
                acc = torch.logaddexp(acc, log_p)
            avg_log_p = acc - math.log(K)
            log_p_correct = avg_log_p.gather(-1, Y.unsqueeze(-1)).squeeze(-1)
            losses.append((-log_p_correct).mean().item())
    return float(np.mean(losses)), varied


def main():
    """Compare noise-free eval with a K-pass Gumbel-noise ensemble on identical batches.

    Prints both losses, their BPC (loss / ln 2, so losses are presumably in
    nats — confirm against the model) and the BPC delta, then a JSON summary.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--data', default='/root/bitnet1/data/validation.bin')
    ap.add_argument('--K', type=int, default=16)
    ap.add_argument('--iters', type=int, default=40)
    ap.add_argument('--batch-size', type=int, default=64)
    # _force_gumbel returns a hard argmax, so for tau > 0 this value cannot change
    # its output; it only matters wherever else model_v16 reads it — TODO confirm.
    ap.add_argument('--tau', type=float, default=0.5)
    ap.add_argument('--seed', type=int, default=None,
                    help='seed torch RNGs (batch sampling and Gumbel noise); unseeded if omitted')
    args = ap.parse_args()
    if args.K < 1:
        ap.error('--K must be >= 1')
    if args.seed is not None:
        torch.manual_seed(args.seed)

    m, cfg = _load_model(args.ckpt)
    _v16.set_gumbel_tau(args.tau)

    seq_len = cfg['seq_len']
    val = np.memmap(args.data, dtype=np.uint8, mode='r')

    # Draw every batch once, so the baseline and the ensemble are scored on the
    # same data. Otherwise the reported delta mixes the ensembling effect with
    # sampling noise. The largest valid start is len(val) - seq_len - 1 (Y then
    # ends at the last byte), and randint's upper bound is exclusive.
    starts = torch.randint(0, len(val) - seq_len, (args.iters, args.batch_size)).tolist()

    # BASELINE: argmax eval (no Gumbel) — same as training eval.
    argmax_loss = _argmax_loss(m, val, starts, seq_len)
    print(f'baseline (argmax) loss={argmax_loss:.4f}  BPC={argmax_loss/math.log(2):.4f}')

    # ENSEMBLE: patch hard attention so every pass samples fresh Gumbel noise.
    # NOTE(review): this only reaches the model if it looks up
    # model_v16.gumbel_hard_attention at call time. A by-name import in
    # model_v57 would bypass it; the warning below detects that case.
    orig_gumbel = _v16.gumbel_hard_attention
    _v16.gumbel_hard_attention = _force_gumbel
    try:
        ens_loss, varied = _ensemble_loss(m, val, starts, seq_len, args.K)
    finally:
        # Always restore, so an error (or a caller importing main) doesn't leave the patch installed.
        _v16.gumbel_hard_attention = orig_gumbel
    if args.K > 1 and not varied:
        print('WARNING: all ensemble passes were identical; the Gumbel patch may not be reaching the model')

    print(f'ensemble K={args.K} tau={args.tau} loss={ens_loss:.4f}  BPC={ens_loss/math.log(2):.4f}')
    print(f'delta BPC: {(ens_loss - argmax_loss)/math.log(2):+.4f}')

    result = {
        'ckpt': args.ckpt, 'K': args.K, 'tau': args.tau,
        'argmax_bpc': argmax_loss/math.log(2),
        'ensemble_bpc': ens_loss/math.log(2),
    }
    print(json.dumps(result))


if __name__ == '__main__':
    # Script entry point: parse CLI args and run the baseline-vs-ensemble comparison.
    main()