"""Test-time ensembling: K forward passes with Gumbel noise, averaged probabilities.

Monkey-patches ``model_v16.gumbel_hard_attention`` so that Gumbel noise is
always added (even in eval mode), runs K independent stochastic passes per
batch, and averages the resulting predictive distributions (a log-mean-exp of
the per-pass log-probabilities).  The ensemble loss is compared against the
model's ordinary eval-mode loss on the *same* randomly drawn validation
batches.
"""
import argparse
import json
import math
import os

import numpy as np
import torch
import torch.nn.functional as F

import model_v16 as _v16
from model_v16 import _get_tau
from model_v57 import BitLMv57


def _force_gumbel(scores, mask=None):
    """Sample a hard one-hot selection over the last dim using Gumbel noise.

    Intended as a replacement for ``model_v16.gumbel_hard_attention`` that
    perturbs the scores unconditionally, regardless of train/eval mode.

    Args:
        scores: Float tensor of selection scores; the choice is made over the
            last dimension.
        mask: Optional boolean tensor accepted by ``Tensor.masked_fill``;
            positions where it is True are filled with -1e9 before sampling.

    Returns:
        A tensor shaped like ``scores`` holding 1.0 at the sampled index of
        each last-dim slice and 0.0 elsewhere.  Only the hard one-hot is
        returned (no straight-through soft term), so it carries no gradient.
    """
    tau = _get_tau(scores.device)
    if mask is not None:
        scores = scores.masked_fill(mask, -1e9)
    # Standard Gumbel(0, 1) sample: -log(-log(U)); the clamp and the epsilon
    # keep both logarithms finite when U is very close to 0 or 1.
    uniform = torch.rand_like(scores).clamp(min=1e-9)
    gumbel = -torch.log(-torch.log(uniform) + 1e-9)
    # NOTE(review): only the argmax of the softmax is used below, so for any
    # tau > 0 the selection is the same up to float rounding; tau only
    # sharpens y_soft, which is discarded.
    y_soft = F.softmax((scores + gumbel) / tau, dim=-1)
    winner = y_soft.argmax(-1, keepdim=True)
    return torch.zeros_like(y_soft).scatter_(-1, winner, 1.0)


def _make_batch(data, starts, seq_len):
    """Build (input, next-token target) int64 CUDA tensors for the given offsets.

    Each input row is ``data[s:s + seq_len]`` and the matching target row is
    the same window shifted right by one position.
    """
    def window(offset):
        return torch.from_numpy(data[offset:offset + seq_len].astype(np.int64))

    inputs = torch.stack([window(s) for s in starts]).cuda()
    targets = torch.stack([window(s + 1) for s in starts]).cuda()
    return inputs, targets


def main():
    """Compare eval-mode loss with the K-pass Gumbel ensemble loss and print both.

    Prints the baseline and ensemble loss/BPC, their difference, and a final
    JSON line with the summary.  Requires a CUDA device.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--data', default='/root/bitnet1/data/validation.bin')
    ap.add_argument('--K', type=int, default=16)
    ap.add_argument('--iters', type=int, default=40)
    ap.add_argument('--batch-size', type=int, default=64)
    ap.add_argument('--tau', type=float, default=0.5)
    args = ap.parse_args()
    if args.K < 1:
        ap.error('--K must be >= 1')
    if args.iters < 1:
        ap.error('--iters must be >= 1')
    if args.batch_size < 1:
        ap.error('--batch-size must be >= 1')

    # Load
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    ck = torch.load(args.ckpt, map_location='cuda', weights_only=False)
    cfg = ck['args']
    m = BitLMv57(vocab_size=128, d_model=cfg['d_model'], n_layers=cfg['n_layers'],
                 n_heads=cfg['n_heads'], d_ff=cfg['d_ff'],
                 max_seq_len=cfg['seq_len']).cuda()
    m.load_state_dict(ck['model'])
    m.eval()
    _v16.set_gumbel_tau(args.tau)

    seq_len = cfg['seq_len']
    val = np.memmap(args.data, dtype=np.uint8, mode='r')
    n_starts = len(val) - seq_len - 1
    if n_starts <= 0:
        raise ValueError(
            f'{args.data} has {len(val)} bytes, too short for seq_len={seq_len}')

    # Draw the window offsets once so the baseline and the ensemble are scored
    # on identical batches; otherwise the reported delta mixes sampling noise
    # with the actual effect of ensembling.
    batches = [torch.randint(0, n_starts, (args.batch_size,)).tolist()
               for _ in range(args.iters)]

    # BASELINE: the model's own eval-mode loss with the unpatched attention
    # (described by the original author as argmax, no Gumbel noise).
    losses_argmax = []
    with torch.no_grad():
        for starts in batches:
            X, Y = _make_batch(val, starts, seq_len)
            _, loss = m(X, Y)
            losses_argmax.append(loss.item())
    argmax_loss = float(np.mean(losses_argmax))
    print(f'baseline (argmax) loss={argmax_loss:.4f} BPC={argmax_loss/math.log(2):.4f}')

    # ENSEMBLE: monkey-patch gumbel to always inject noise, then average K.
    # NOTE(review): the patch only takes effect if the model looks up
    # gumbel_hard_attention through the model_v16 module namespace at call
    # time; confirm BitLMv57 does not hold its own imported reference.
    orig_gumbel = _v16.gumbel_hard_attention
    _v16.gumbel_hard_attention = _force_gumbel
    losses_ens = []
    try:
        with torch.no_grad():
            for starts in batches:
                X, Y = _make_batch(val, starts, seq_len)
                target_idx = Y.unsqueeze(-1)
                # Keep only the log-probability of the correct token from each
                # pass: the average over passes is element-wise, so gathering
                # first gives the same result as gathering afterwards while
                # storing K x (B, T) values instead of K x (B, T, V).
                per_pass = []
                for _pass in range(args.K):
                    logits, _ = m(X)
                    log_p = F.log_softmax(logits, dim=-1)
                    per_pass.append(log_p.gather(-1, target_idx).squeeze(-1))
                # log of the mean probability across the K passes.
                log_p_correct = (torch.logsumexp(torch.stack(per_pass, dim=0), dim=0)
                                 - math.log(args.K))
                losses_ens.append((-log_p_correct).mean().item())
    finally:
        # Always restore the original function, even if a forward pass fails.
        _v16.gumbel_hard_attention = orig_gumbel

    ens_loss = float(np.mean(losses_ens))
    print(f'ensemble K={args.K} tau={args.tau} loss={ens_loss:.4f} BPC={ens_loss/math.log(2):.4f}')
    print(f'delta BPC: {(ens_loss - argmax_loss)/math.log(2):+.4f}')

    result = {
        'ckpt': args.ckpt,
        'K': args.K,
        'tau': args.tau,
        'argmax_bpc': argmax_loss/math.log(2),
        'ensemble_bpc': ens_loss/math.log(2),
    }
    print(json.dumps(result))


if __name__ == '__main__':
    main()