"""Benchmark FastBitLMv57 vs BitLMv57 end-to-end train-step wall time."""
import time

import torch

torch.set_float32_matmul_precision('high')

import model_v16 as _v16
from model_v57 import BitLMv57
from model_v57_fast import BitLMv57Fast


def _train_step(compiled, model, opt, x, y):
    """Run one forward / backward / grad-clip / optimizer step.

    ``compiled`` is the ``torch.compile`` wrapper used for the forward pass.
    ``model`` is the underlying module whose parameters get clipped.
    The model is called as ``compiled(x, y)`` and is expected to return a
    2-tuple whose second element is a scalar loss.
    """
    _, loss = compiled(x, y)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()


def bench(m, bs=64, T=256, iters=20, warmup=5, vocab=128):
    """Return the mean wall time in milliseconds of one compiled training step.

    Args:
        m: model to benchmark; it is moved to CUDA and wrapped in torch.compile.
        bs: batch size of the random token batch.
        T: sequence length of the random token batch.
        iters: number of timed steps; must be >= 1.
        warmup: untimed steps run first. These absorb torch.compile's
            first-call compilation and allocator warm-up.
        vocab: token ids are drawn uniformly from ``[0, vocab)``.
            NOTE(review): assumed to match the model's vocabulary size.

    Raises:
        ValueError: if ``iters`` < 1, since the mean would be undefined.
    """
    if iters < 1:
        raise ValueError(f'iters must be >= 1, got {iters}')
    m = m.cuda()
    opt = torch.optim.AdamW(m.parameters(), lr=3e-4, betas=(0.9, 0.95))
    mm = torch.compile(m)
    x = torch.randint(0, vocab, (bs, T), device='cuda')
    y = torch.randint(0, vocab, (bs, T), device='cuda')
    _v16.set_gumbel_tau(0.5)

    for _ in range(warmup):
        _train_step(mm, m, opt, x, y)
    # CUDA kernels run asynchronously: drain the queue before starting the
    # clock, and again before stopping it, so only the timed steps are measured.
    torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() is not
    # intended for measuring intervals.
    t0 = time.perf_counter()
    for _ in range(iters):
        _train_step(mm, m, opt, x, y)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1000


def main():
    """Benchmark the reference and Triton models on the same config and print the speedup."""
    # Match v73 config
    kw = dict(d_model=1024, n_layers=8, n_heads=32, d_ff=512)

    m_ref = BitLMv57(**kw)
    t_ref = bench(m_ref)
    del m_ref
    # Drop torch.compile caches from the first model so its compiled graphs and
    # guards do not carry over into (or hold memory during) the second benchmark.
    torch._dynamo.reset()
    torch.cuda.empty_cache()

    m_fast = BitLMv57Fast(**kw)
    t_fast = bench(m_fast)

    print(f'reference BitLMv57 : {t_ref:.2f} ms / step')
    print(f'triton BitLMv57Fast : {t_fast:.2f} ms / step')
    print(f'speedup: {t_ref/t_fast:.2f}x')


if __name__ == '__main__':
    main()