| """Export a v18 checkpoint to a flat packed-binary file the C inference reads. |
| |
| File format (little-endian; all integers fit in their declared widths): |
| |
| HEADER (40 bytes): |
| magic uint32 = 'BIT1' = 0x31544942 |
| version uint32 = 1 |
| vocab_size uint32 |
| d_model uint32 |
| n_layers uint32 |
| n_heads uint32 |
| d_ff uint32 |
| max_seq_len uint32 |
| logit_scale_M int64 = 65536 (integer logit scale multiplier) |
| |
EMBEDDING (vocab_size rows × d_model bits, row-packed in uint64 words):
    [row for each char]   each row = ceil(d_model/64) uint64 words
| |
| For each of n_layers: |
    ATTENTION:
        alibi_slopes   int32[n_heads]
        For each of Q,K,V,O (all are BitLinear[d_model -> d_model]):
            weight_bits    d_model rows × (d_model bits packed as uint64 per row)
            int_threshold  int32[d_model]   (= ceil(threshold_float * sqrt(d_in)))

    FFN (each BitLinear[d_in -> d_out] uses the same layout as above: d_out packed
    rows of ceil(d_in/64) uint64 words, then int32[d_out] thresholds):
        gate   BitLinear[d_model -> d_ff]
        up     BitLinear[d_model -> d_ff]
        down   BitLinear[d_ff -> d_model]
| |
OUTPUT HEAD:
    out_codebook   vocab_size rows × d_model bits
    int_out_bias   int64[vocab_size]   (= round(out_bias * M / logit_scale))
| |
All weight bits use the convention: bit = 1 means +1, bit = 0 means -1.
Rows are stored back-to-back; each row's bits are packed least-significant-first into
successive uint64 words. If d_in % 64 != 0, the last word is padded (unused bits
stored as 0 — the C code masks out the unused range when popcount-ing).
| """ |
| import argparse |
| import math |
| import os |
| import struct |
|
|
| import numpy as np |
| import torch |
|
|
| from model_v18 import BitLMv18 |
|
|
|
|
# File magic: the ASCII bytes b"BIT1" read as a little-endian uint32 (== 0x31544942).
MAGIC = int.from_bytes(b'BIT1', 'little')
# Format version recorded in the header.
VERSION = 1
# Integer logit scale multiplier M (== 1 << 16); written to the header and used
# to quantize the output bias as round(out_bias * M / logit_scale).
M_SCALE = 65536
|
|
|
|
def pack_bits_rows(sign_tensor: torch.Tensor) -> np.ndarray:
    """Pack a 2-D sign matrix into LSB-first uint64 words, one run of words per row.

    Each value > 0 becomes bit 1 and every other value becomes bit 0, so a
    {-1, +1} matrix maps onto the file convention "bit 1 means +1". Column k of
    a row lands in word k // 64 at bit position k % 64. The unused high bits of
    a row's last word are 0.

    Args:
        sign_tensor: tensor of shape (rows, cols).

    Returns:
        np.ndarray of dtype '<u8' (little-endian uint64) with shape
        (rows, ceil(cols / 64)). Because the dtype is explicitly little-endian,
        ``.tobytes()`` yields the on-disk layout on any host.

    Raises:
        ValueError: if ``sign_tensor`` is not 2-D.
    """
    if sign_tensor.ndim != 2:
        raise ValueError(f"expected a 2-D tensor, got shape {tuple(sign_tensor.shape)}")
    rows, cols = sign_tensor.shape
    words = (cols + 63) // 64
    if words == 0:
        return np.zeros((rows, 0), dtype='<u8')

    bits = (sign_tensor > 0).cpu().numpy()
    # Zero-pad every row to a whole number of 64-bit words so padding bits are 0.
    padded = np.zeros((rows, words * 64), dtype=bool)
    padded[:, :cols] = bits
    # bitorder='little' puts column k at bit k % 8 of byte k // 8. Reading each
    # group of 8 bytes as a little-endian uint64 then puts column k at bit k % 64
    # of word k // 64, which is exactly the LSB-first word packing of the format.
    packed = np.packbits(padded, axis=1, bitorder='little')
    return packed.view('<u8').reshape(rows, words)
|
|
|
|
def export_bitlinear(f, bitlinear):
    """Write one BitLinear's packed sign bits followed by its integer thresholds.

    Layout appended to ``f`` (little-endian):
      * d_out rows of ceil(d_in / 64) uint64 words, with bit 1 where the raw
        weight is >= 0 and bit 0 otherwise;
      * int32[d_out] thresholds = ceil(threshold * sqrt(d_in)).

    Args:
        f: binary file object opened for writing.
        bitlinear: module exposing ``raw.weight`` of shape (d_out, d_in) and a
            ``threshold`` tensor that must hold exactly one value per output row.

    Raises:
        ValueError: if ``threshold`` does not contain exactly d_out values. A
            mismatch would otherwise shift every later read in the C loader.
            The check runs before anything is written.
    """
    w = bitlinear.raw.weight.detach()
    out_features, in_features = w.shape

    threshold = bitlinear.threshold.detach().cpu().numpy().reshape(-1)
    if threshold.size != out_features:
        raise ValueError(
            f"BitLinear threshold has {threshold.size} values but the weight has "
            f"{out_features} output rows; the file format needs one per row"
        )

    w_sign = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w))
    f.write(pack_bits_rows(w_sign).tobytes())

    # NOTE(review): ceil() is the exact integer equivalent if the C side tests
    # `dot >= int_thr` on the integer ±1 dot product; confirm against the reader.
    int_thr = np.ceil(threshold * math.sqrt(in_features)).astype('<i4')
    f.write(int_thr.tobytes())
|
|
|
|
def _packed_row_bytes(n_bits: int) -> int:
    """Bytes taken by one packed row of ``n_bits`` sign bits (whole uint64 words)."""
    return 8 * ((n_bits + 63) // 64)


def _expected_file_size(V: int, D: int, L: int, H: int, F: int) -> int:
    """Total byte size implied by the layout in the module docstring."""
    def bitlinear(d_in: int, d_out: int) -> int:
        # d_out packed rows of d_in bits, then int32[d_out] thresholds.
        return d_out * _packed_row_bytes(d_in) + 4 * d_out

    header = 8 * 4 + 8  # eight uint32 fields + int64 logit scale
    embedding = V * _packed_row_bytes(D)
    per_layer = (
        4 * H                    # int32 ALiBi slopes
        + 4 * bitlinear(D, D)    # Q, K, V, O
        + 2 * bitlinear(D, F)    # gate, up
        + bitlinear(F, D)        # down
    )
    head = V * _packed_row_bytes(D) + 8 * V  # codebook bits + int64 bias
    return header + embedding + L * per_layer + head


def _write_sign_rows(f, weight) -> None:
    """Write a float matrix as packed sign rows (bit 1 where weight >= 0)."""
    w = weight.detach()
    w_sign = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w))
    f.write(pack_bits_rows(w_sign).tobytes())


def export(ckpt_path: str, out_path: str):
    """Load a v18 checkpoint and write it in the packed binary format.

    The format is the one described in the module docstring.

    Args:
        ckpt_path: path to a checkpoint dict with keys 'args' (config dict with
            vocab_size, d_model, n_layers, n_heads, d_ff, seq_len), 'model'
            (state dict), 'step' and 'val_bpc'.
        out_path: destination file, opened with mode 'wb' (truncated if present).

    Raises:
        ValueError: if a layer's ALiBi slope count differs from n_heads, or if the
            written file's size does not match the documented layout (the file is
            left on disk but should not be used).
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects from the
    # checkpoint; only export checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    cfg = ckpt['args']

    V = cfg['vocab_size']
    D = cfg['d_model']
    L = cfg['n_layers']
    H = cfg['n_heads']
    F = cfg['d_ff']
    T = cfg['seq_len']

    model = BitLMv18(
        vocab_size=V, d_model=D, n_layers=L,
        n_heads=H, d_ff=F, max_seq_len=T,
    )
    model.load_state_dict(ckpt['model'])
    model.eval()

    print(f"exporting v18 ckpt step={ckpt['step']} val_bpc={ckpt['val_bpc']:.4f}")
    print(f"  vocab={V} d_model={D} n_layers={L} n_heads={H} d_ff={F} seq_len={T}")

    with open(out_path, 'wb') as f:
        # Header: eight little-endian uint32 fields, then the int64 logit scale.
        f.write(struct.pack('<IIIIIIII', MAGIC, VERSION, V, D, L, H, F, T))
        f.write(struct.pack('<q', M_SCALE))

        _write_sign_rows(f, model.embed.weight)

        for li, blk in enumerate(model.blocks):
            slopes = blk.attn.alibi_slopes_int.detach().cpu().numpy()
            slopes = slopes.astype('<i4').reshape(-1)
            if slopes.size != H:
                raise ValueError(
                    f"layer {li}: {slopes.size} ALiBi slopes but n_heads={H}"
                )
            f.write(slopes.tobytes())
            for proj in (blk.attn.q_proj, blk.attn.k_proj, blk.attn.v_proj, blk.attn.o_proj):
                export_bitlinear(f, proj)
            for proj in (blk.ffn.gate, blk.ffn.up, blk.ffn.down):
                export_bitlinear(f, proj)

        _write_sign_rows(f, model.out_codebook)

        int_out_bias = np.round(
            model.out_bias.detach().cpu().numpy() * M_SCALE / model.logit_scale.detach().item()
        ).astype('<i8')
        f.write(int_out_bias.tobytes())

    sz = os.path.getsize(out_path)
    expected = _expected_file_size(V, D, L, H, F)
    if sz != expected:
        raise ValueError(
            f"{out_path}: wrote {sz:,} bytes but the documented layout needs "
            f"{expected:,}; the C loader would misparse this file"
        )
    print(f"wrote {sz:,} bytes to {out_path} ({sz / 2**20:.2f} MiB)")
|
|
|
|
if __name__ == '__main__':
    # CLI entry point: both the input checkpoint and the output path are mandatory.
    parser = argparse.ArgumentParser()
    for flag in ('--ckpt', '--out'):
        parser.add_argument(flag, required=True)
    parsed = parser.parse_args()
    export(ckpt_path=parsed.ckpt, out_path=parsed.out)
|
|