# bitnet-1bitllm / vm_backup/code/eval_ensemble.py
# Author: hidude562 -- commit 4754707 (verified): "1bitllm code (checkpoints to follow)"
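"""Test-time ensembling: K forward passes with Gumbel noise, averaged predictive probabilities.
Monkey-patches gumbel_hard_attention to always add Gumbel noise (even in eval),
then averages the next-token probabilities of K independent passes (a uniform
mixture, not a logit average) and compares the resulting loss/BPC against the
standard argmax (no-Gumbel) eval loss.
"""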
import argparse, math, os, json
import numpy as np
import torch
import torch.nn.functional as F
import model_v16 as _v16
from model_v16 import _get_tau
from model_v57 import BitLMv57
def _force_gumbel(scores, mask=None):
tau = _get_tau(scores.device)
if mask is not None:
scores = scores.masked_fill(mask, -1e9)
g = -torch.log(-torch.log(torch.rand_like(scores).clamp(min=1e-9)) + 1e-9)
y_soft = F.softmax((scores + g) / tau, dim=-1)
y_hard = torch.zeros_like(y_soft).scatter_(-1, y_soft.argmax(-1, keepdim=True), 1.0)
return y_hard
def _sample_starts(n_tokens, seq_len, batch_size, iters, generator):
    """Return ``iters`` int64 tensors, each holding ``batch_size`` random window starts.

    Starts are drawn from [0, n_tokens - seq_len - 2] (the same range the
    original inline sampling used), so a seq_len-long input window and the
    target window starting one token later both fit inside the data.

    Raises:
        ValueError: if the data is too short for even one (input, target) pair.
    """
    high = n_tokens - seq_len - 1
    if high < 1:
        raise ValueError(
            f'data has {n_tokens} tokens; need at least {seq_len + 2} for seq_len={seq_len}')
    return [torch.randint(0, high, (batch_size,), generator=generator) for _ in range(iters)]


def _get_batch(val, starts, seq_len):
    """Slice inputs X and next-token targets Y (int64, moved to CUDA) at the given offsets.

    Both tensors have shape (len(starts), seq_len); row i of Y is the window
    starting one token after row i of X.
    """
    offsets = starts.tolist()
    X = torch.stack([torch.from_numpy(val[s:s + seq_len].astype(np.int64)) for s in offsets]).cuda()
    Y = torch.stack([torch.from_numpy(val[s + 1:s + 1 + seq_len].astype(np.int64)) for s in offsets]).cuda()
    return X, Y


def _eval_argmax(m, val, batches, seq_len):
    """Mean of the loss the model itself returns from ``m(X, Y)`` over all batches."""
    losses = []
    with torch.no_grad():
        for starts in batches:
            X, Y = _get_batch(val, starts, seq_len)
            _, loss = m(X, Y)
            losses.append(loss.item())
    return float(np.mean(losses))


def _eval_ensemble(m, val, batches, seq_len, K):
    """Mean NLL of the uniform mixture of K Gumbel-noised forward passes.

    Temporarily swaps ``model_v16.gumbel_hard_attention`` for ``_force_gumbel``
    and always restores the original, even if a forward pass raises.
    """
    orig_gumbel = _v16.gumbel_hard_attention
    _v16.gumbel_hard_attention = _force_gumbel
    losses = []
    try:
        with torch.no_grad():
            for batch_idx, starts in enumerate(batches):
                X, Y = _get_batch(val, starts, seq_len)
                log_ps = torch.stack([F.log_softmax(m(X)[0], dim=-1) for _ in range(K)], dim=0)
                # If the model reaches gumbel_hard_attention through a name bound
                # elsewhere, the patch is a silent no-op and every pass is identical.
                if batch_idx == 0 and K > 1 and all(torch.equal(log_ps[0], lp) for lp in log_ps[1:]):
                    print('warning: all K ensemble passes gave identical outputs; '
                          'the gumbel_hard_attention patch may not be reaching the model')
                # log of the mean probability: log((1/K) * sum_k p_k), computed stably.
                avg_log_p = torch.logsumexp(log_ps, dim=0) - math.log(K)
                log_p_correct = avg_log_p.gather(-1, Y.unsqueeze(-1)).squeeze(-1)
                losses.append((-log_p_correct).mean().item())
    finally:
        _v16.gumbel_hard_attention = orig_gumbel
    return float(np.mean(losses))


def main():
    """CLI entry point: compare argmax-eval loss with a K-pass Gumbel ensemble.

    Both evaluations score the same randomly sampled validation windows. The
    script prints loss/BPC for each, their BPC difference, and a JSON summary.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--data', default='/root/bitnet1/data/validation.bin')
    ap.add_argument('--K', type=int, default=16)
    ap.add_argument('--iters', type=int, default=40)
    ap.add_argument('--batch-size', type=int, default=64)
    ap.add_argument('--tau', type=float, default=0.5)
    ap.add_argument('--seed', type=int, default=0,
                    help='seed for batch sampling and Gumbel noise (reproducible runs)')
    args = ap.parse_args()
    if min(args.K, args.iters, args.batch_size) < 1:
        ap.error('--K, --iters and --batch-size must all be >= 1')

    # Load. weights_only=False unpickles arbitrary Python objects from the
    # checkpoint -- only point --ckpt at files you trust.
    ck = torch.load(args.ckpt, map_location='cuda', weights_only=False)
    cfg = ck['args']
    m = BitLMv57(vocab_size=128, d_model=cfg['d_model'], n_layers=cfg['n_layers'],
                 n_heads=cfg['n_heads'], d_ff=cfg['d_ff'], max_seq_len=cfg['seq_len']).cuda()
    m.load_state_dict(ck['model'])
    m.eval()
    # NOTE(review): for tau > 0, argmax(softmax((s + g) / tau)) == argmax(s + g), so tau
    # should not change the hard ensemble samples. It is still set for parity with
    # model_v16 and reported in the output -- confirm whether the baseline path reads it.
    _v16.set_gumbel_tau(args.tau)
    seq_len = cfg['seq_len']
    val = np.memmap(args.data, dtype=np.uint8, mode='r')

    # Sample all batches up front so both evaluations score the SAME windows;
    # otherwise "delta BPC" would mix the ensemble effect with sampling noise.
    torch.manual_seed(args.seed)  # also makes the Gumbel noise reproducible
    gen = torch.Generator().manual_seed(args.seed)
    batches = _sample_starts(len(val), seq_len, args.batch_size, args.iters, gen)

    # BASELINE: argmax eval (no Gumbel) -- same as training eval.
    argmax_loss = _eval_argmax(m, val, batches, seq_len)
    print(f'baseline (argmax) loss={argmax_loss:.4f} BPC={argmax_loss/math.log(2):.4f}')

    # ENSEMBLE: monkey-patch gumbel to always inject noise, then average K.
    ens_loss = _eval_ensemble(m, val, batches, seq_len, args.K)
    print(f'ensemble K={args.K} tau={args.tau} loss={ens_loss:.4f} BPC={ens_loss/math.log(2):.4f}')
    print(f'delta BPC: {(ens_loss - argmax_loss)/math.log(2):+.4f}')
    result = {
        'ckpt': args.ckpt, 'K': args.K, 'tau': args.tau, 'seed': args.seed,
        'argmax_bpc': argmax_loss/math.log(2),
        'ensemble_bpc': ens_loss/math.log(2),
    }
    print(json.dumps(result))
if __name__ == "__main__":  # run the evaluation only when executed as a script, not on import
    main()