| """Benchmark FastBitLMv57 vs BitLMv57 end-to-end train-step wall time.""" |
| import time |
| import torch |
| torch.set_float32_matmul_precision('high') |
|
|
| import model_v16 as _v16 |
| from model_v57 import BitLMv57 |
| from model_v57_fast import BitLMv57Fast |
|
|
|
|
def bench(m, bs=64, T=256, iters=20, warmup=5):
    """Return the mean time, in milliseconds, of one compiled training step.

    The model is moved to CUDA with ``m.cuda()`` and wrapped with
    ``torch.compile``. It is then driven with one fixed random batch through
    forward, backward, gradient-norm clipping and an AdamW update.
    ``warmup`` untimed steps run first, so that torch.compile's lazy
    first-call compilation happens before the clock starts.

    Args:
        m: Model whose call ``m(x, y)`` returns a pair ``(_, loss)``.
        bs: Batch size of the random input.
        T: Sequence length of the random input.
        iters: Number of timed steps; must be positive.
        warmup: Number of untimed steps run before timing.

    Returns:
        Mean milliseconds per step over the ``iters`` timed steps.

    Raises:
        ValueError: If ``iters`` is not positive (the mean would be undefined).
    """
    if iters <= 0:
        raise ValueError(f'iters must be positive, got {iters}')

    m = m.cuda()
    opt = torch.optim.AdamW(m.parameters(), lr=3e-4, betas=(0.9, 0.95))
    mm = torch.compile(m)
    # Token ids are drawn from [0, 128). This assumes the model's vocabulary
    # has at least 128 entries -- TODO confirm against the model config.
    x = torch.randint(0, 128, (bs, T), device='cuda')
    y = torch.randint(0, 128, (bs, T), device='cuda')
    # Presumably pins the Gumbel temperature used by model_v16 layers so both
    # models are timed under the same setting -- confirm in model_v16.
    _v16.set_gumbel_tau(0.5)

    def step():
        # One full optimizer step. It is shared by the warmup and timed loops
        # so that both always measure the same work.
        _, loss = mm(x, y)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)
        opt.step()

    for _ in range(warmup):
        step()

    # CUDA kernels launch asynchronously. Drain the queue before and after the
    # timed loop so the host clock brackets exactly the GPU work. perf_counter
    # is monotonic, unlike time.time(), which can jump if the system clock is
    # adjusted.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        step()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1000
|
|
|
|
def main():
    """Time the reference and Triton-backed BitLMv57 models and print results.

    Both models are built from the same hyperparameters and measured with
    ``bench``. The reference model is deleted, and the CUDA cache emptied,
    before the fast model is constructed. The per-step timings and their
    ratio are then printed.
    """
    config = {'d_model': 1024, 'n_layers': 8, 'n_heads': 32, 'd_ff': 512}

    reference_model = BitLMv57(**config)
    ref_ms = bench(reference_model)
    # Release the reference model and its cached allocator blocks before
    # building the second model.
    del reference_model
    torch.cuda.empty_cache()

    fast_model = BitLMv57Fast(**config)
    fast_ms = bench(fast_model)

    print(f'reference BitLMv57 : {ref_ms:.2f} ms / step')
    print(f'triton BitLMv57Fast : {fast_ms:.2f} ms / step')
    ratio = ref_ms / fast_ms
    print(f'speedup: {ratio:.2f}x')
|
|
|
|
if __name__ == '__main__':
    # Run the benchmark only when executed as a script, not when imported.
    main()
|
|