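"""Test-time ensembling: K forward passes with Gumbel noise, averaged in probability space.

Monkey-patches model_v16.gumbel_hard_attention so that Gumbel noise is always
added (even in eval mode), runs K independent stochastic passes per batch, and
scores the mean of the K predictive distributions (log-mean-exp of the
per-pass log-softmax outputs).
"""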
| import argparse, math, os, json |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| import model_v16 as _v16 |
| from model_v16 import _get_tau |
| from model_v57 import BitLMv57 |
|
|
|
|
def _force_gumbel(scores, mask=None):
    """Sample a one-hot selection along the last axis using Gumbel noise.

    Intended as a drop-in replacement for ``model_v16.gumbel_hard_attention``
    that perturbs the scores unconditionally (no train/eval switch).

    Args:
        scores: Tensor of unnormalised scores; the selection is made over
            its last dimension.
        mask: Optional mask passed to ``masked_fill``; positions where it is
            true are set to ``-1e9`` before noise is added.

    Returns:
        A tensor shaped like ``scores`` holding 1.0 at the sampled index of
        each last-axis slice and 0.0 elsewhere.

    Note:
        Only the hard one-hot is returned (no straight-through soft term).
        The temperature is applied before a softmax whose argmax is taken, so
        for a positive temperature it should not change which index wins,
        apart from floating-point ties -- TODO confirm this is intended.
    """
    temperature = _get_tau(scores.device)
    masked = scores if mask is None else scores.masked_fill(mask, -1e9)

    # Gumbel(0, 1) noise via -log(-log(U)); the clamp and the epsilon keep
    # both logarithms finite.
    uniform = torch.rand_like(masked).clamp(min=1e-9)
    gumbel = -torch.log(-torch.log(uniform) + 1e-9)

    probs = F.softmax((masked + gumbel) / temperature, dim=-1)
    winner = probs.argmax(-1, keepdim=True)

    one_hot = torch.zeros_like(probs)
    one_hot.scatter_(-1, winner, 1.0)
    return one_hot
|
|
|
|
def _load_batch(val, starts, seq_len):
    """Return (inputs, targets) CUDA tensors for windows beginning at ``starts``.

    Each window reads ``seq_len + 1`` consecutive bytes from ``val``; inputs
    are the first ``seq_len`` of them and targets are the same window shifted
    by one position.
    """
    windows = [torch.from_numpy(val[s:s + seq_len + 1].astype(np.int64)) for s in starts]
    block = torch.stack(windows).cuda()
    # .contiguous() so the model receives ordinary dense tensors, exactly as
    # if inputs and targets had been stacked separately.
    return block[:, :-1].contiguous(), block[:, 1:].contiguous()


def main():
    """Compare the eval-mode baseline loss against a K-pass Gumbel ensemble.

    Steps:
      1. Load the checkpoint given by ``--ckpt`` and build a ``BitLMv57`` from
         the hyper-parameters stored under its ``'args'`` key.
      2. Draw ``--iters`` batches of random window offsets from ``--data``.
      3. Baseline: one forward pass per batch with the unpatched
         ``gumbel_hard_attention`` (reported as "argmax" in the output).
      4. Ensemble: replace ``model_v16.gumbel_hard_attention`` with
         ``_force_gumbel``, run ``--K`` passes per batch and score the mean of
         the K predictive distributions (log-mean-exp of log-probabilities).
      5. Print both losses, their difference in bits per character, and a
         one-line JSON summary.

    Both evaluations use the *same* batches, so the printed delta reflects the
    ensembling effect rather than differences between random samples.

    Raises:
        ValueError: if the data file is too short to hold one
            ``seq_len + 1`` window.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--data', default='/root/bitnet1/data/validation.bin')
    ap.add_argument('--K', type=int, default=16)
    ap.add_argument('--iters', type=int, default=40)
    ap.add_argument('--batch-size', type=int, default=64)
    ap.add_argument('--tau', type=float, default=0.5)
    args = ap.parse_args()
    if args.K < 1 or args.iters < 1 or args.batch_size < 1:
        ap.error('--K, --iters and --batch-size must all be >= 1')

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    ck = torch.load(args.ckpt, map_location='cuda', weights_only=False)
    cfg = ck['args']
    m = BitLMv57(vocab_size=128, d_model=cfg['d_model'], n_layers=cfg['n_layers'],
                 n_heads=cfg['n_heads'], d_ff=cfg['d_ff'], max_seq_len=cfg['seq_len']).cuda()
    m.load_state_dict(ck['model'])
    m.eval()
    _v16.set_gumbel_tau(args.tau)

    seq_len = cfg['seq_len']
    val = np.memmap(args.data, dtype=np.uint8, mode='r')
    max_start = len(val) - seq_len - 1  # exclusive upper bound for window starts
    if max_start <= 0:
        raise ValueError(
            f'{args.data} holds {len(val)} bytes; need more than {seq_len + 1} '
            f'for seq_len={seq_len}'
        )

    # Sample every batch's offsets once so that baseline and ensemble are
    # evaluated on identical data.
    starts = torch.randint(0, max_start, (args.iters, args.batch_size)).tolist()

    # --- Baseline: deterministic eval-mode forward pass -------------------
    losses_argmax = []
    with torch.no_grad():
        for batch_starts in starts:
            X, Y = _load_batch(val, batch_starts, seq_len)
            _, loss = m(X, Y)
            losses_argmax.append(loss.item())

    argmax_loss = float(np.mean(losses_argmax))
    print(f'baseline (argmax) loss={argmax_loss:.4f} BPC={argmax_loss/math.log(2):.4f}')

    # --- Ensemble: K stochastic passes with forced Gumbel noise -----------
    log_k = math.log(args.K)
    losses_ens = []
    orig_gumbel = _v16.gumbel_hard_attention
    _v16.gumbel_hard_attention = _force_gumbel
    try:
        with torch.no_grad():
            for batch_starts in starts:
                X, Y = _load_batch(val, batch_starts, seq_len)
                target_idx = Y.unsqueeze(-1)
                per_pass = []
                for _ in range(args.K):
                    logits, _unused = m(X)
                    log_p = F.log_softmax(logits, dim=-1)
                    # Keep only the target token's log-prob: the reduction over
                    # K is element-wise, so gathering first gives the same
                    # result without storing K full vocabulary tensors.
                    per_pass.append(log_p.gather(-1, target_idx).squeeze(-1))
                # log of the mean probability over the K passes.
                log_p_correct = torch.logsumexp(torch.stack(per_pass, dim=0), dim=0) - log_k
                losses_ens.append((-log_p_correct).mean().item())
    finally:
        # Always undo the monkey-patch, even if a forward pass fails.
        _v16.gumbel_hard_attention = orig_gumbel

    ens_loss = float(np.mean(losses_ens))
    print(f'ensemble K={args.K} tau={args.tau} loss={ens_loss:.4f} BPC={ens_loss/math.log(2):.4f}')
    print(f'delta BPC: {(ens_loss - argmax_loss)/math.log(2):+.4f}')

    result = {
        'ckpt': args.ckpt, 'K': args.K, 'tau': args.tau,
        'argmax_bpc': argmax_loss/math.log(2),
        'ensemble_bpc': ens_loss/math.log(2),
    }
    print(json.dumps(result))
|
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|