"""Export a v18 checkpoint to a flat packed-binary file the C inference reads.

File format (little-endian; all integers fit in their declared widths):

  HEADER (40 bytes):
    magic          uint32 = 'BIT1' = 0x31544942
    version        uint32 = 1
    vocab_size     uint32
    d_model        uint32
    n_layers       uint32
    n_heads        uint32
    d_ff           uint32
    max_seq_len    uint32
    logit_scale_M  int64  = 65536  (integer logit scale multiplier)

  EMBEDDING (vocab_size rows × d_model bits, row-packed in uint64 words):
    [row for each char]  each row = ceil(d_model/64) uint64 words

  For each of n_layers:
    ATTENTION:
      alibi_slopes   int32[n_heads]
      For each of Q,K,V,O (all are BitLinear[d_model -> d_model]):
        weight_bits    d_model rows × (d_model bits packed as uint64 per row)
        int_threshold  int32[d_model]  (= ceil(threshold_float * sqrt(d_in)))
    FFN (each BitLinear is written as weight_bits followed by int_threshold,
         exactly like the attention projections):
      gate  BitLinear[d_model -> d_ff]
      up    BitLinear[d_model -> d_ff]
      down  BitLinear[d_ff -> d_model]

  OUTPUT HEAD:
    out_codebook   vocab_size rows × d_model bits
    int_out_bias   int64[vocab_size]  (= round(out_bias * M / logit_scale))

All weight bits use the convention: bit = 1 means +1, bit = 0 means −1.
Rows are stored tightly; each row's bits are packed least-significant-first
into successive uint64 words. If d_in % 64 != 0, the last word is padded
(unused bits stored as 0 — the C code masks out the unused range when
popcount-ing).
"""
import argparse
import math
import os
import struct

import numpy as np
import torch

from model_v18 import BitLMv18

MAGIC = 0x31544942  # 'BIT1' little-endian
VERSION = 1
M_SCALE = 1 << 16

# Header layout: eight uint32 fields followed by one int64. The '<' prefix
# forces little-endian with no alignment padding, giving the documented 40 bytes.
_HEADER_FORMAT = '<8Iq'

# NOTE(review): the line that fetched the embedding matrix was lost from the
# text this revision was prepared from, so the attribute name on BitLMv18 is an
# assumption. Confirm it against model_v18 and collapse this tuple to the one
# real name.
_EMBEDDING_ATTRS = (
    'embed',
    'embedding',
    'tok_emb',
    'tok_embed',
    'token_embedding',
    'embed_codebook',
    'in_codebook',
)


def pack_bits_rows(sign_tensor: torch.Tensor) -> np.ndarray:
    """Pack a (rows, cols) tensor of ±1 values into uint64 words, LSB-first.

    Strictly positive entries become bit 1; everything else becomes bit 0.
    Bit ``k`` of a row lands in word ``k // 64`` at bit position ``k % 64``.
    Padding bits in the last word of each row are 0.

    Returns:
        A native-order ``np.uint64`` array of shape
        ``(rows, ceil(cols / 64))``.

    Raises:
        ValueError: if ``sign_tensor`` is not 2-dimensional.
    """
    if sign_tensor.ndim != 2:
        raise ValueError(
            f"pack_bits_rows expects a 2-D tensor, got ndim={sign_tensor.ndim}"
        )
    rows, cols = sign_tensor.shape
    words = (cols + 63) // 64
    if rows == 0 or words == 0:
        return np.zeros((rows, words), dtype=np.uint64)

    # Map ±1 → {0,1}: 1 for +1, 0 for −1.
    bits = (sign_tensor > 0).to(torch.uint8).cpu().numpy()

    # Zero-pad every row to a whole number of 64-bit words so the padding bits
    # are guaranteed to be 0 in the output.
    padded = np.zeros((rows, words * 64), dtype=np.uint8)
    padded[:, :cols] = bits

    # bitorder='little' (NumPy >= 1.17) puts bit k at position k % 8 of byte
    # k // 8; reading each group of 8 bytes as a little-endian uint64 then
    # places bit k at position k % 64 of word k // 64.
    packed_bytes = np.packbits(padded, axis=1, bitorder='little')
    return packed_bytes.view('<u8').astype(np.uint64, copy=False)


def _signs(values: torch.Tensor) -> torch.Tensor:
    """Return a tensor of +1 where ``values >= 0`` and −1 elsewhere."""
    ones = torch.ones_like(values)
    return torch.where(values >= 0, ones, -ones)


def _write_le(f, array: np.ndarray, dtype: str) -> None:
    """Write ``array`` to ``f`` as raw values of the little-endian ``dtype``.

    Using an explicit '<' dtype keeps the file little-endian even when the
    exporter runs on a big-endian host (plain ``.tobytes()`` is native-order).
    """
    f.write(np.ascontiguousarray(array).astype(dtype, copy=False).tobytes())


def _checked_int(values: np.ndarray, dtype, what: str) -> np.ndarray:
    """Cast already-rounded float ``values`` to ``dtype``, refusing bad input.

    Raises:
        ValueError: if any value is NaN/inf or lies outside the range of
            ``dtype`` (a plain ``astype`` would silently wrap).
    """
    info = np.iinfo(dtype)
    if not np.all(np.isfinite(values)):
        raise ValueError(f"{what}: non-finite value cannot be exported")
    if np.any(values < info.min) or np.any(values > info.max):
        raise ValueError(f"{what}: value does not fit in {np.dtype(dtype).name}")
    return values.astype(dtype)


def _embedding_matrix(model, vocab_size: int, d_model: int) -> torch.Tensor:
    """Locate the (vocab_size, d_model) input-embedding tensor on ``model``.

    Tries each name in ``_EMBEDDING_ATTRS``; an attribute may be either a
    module exposing ``.weight`` or a tensor/parameter itself.

    Raises:
        AttributeError: if no candidate with the expected shape exists.
    """
    for name in _EMBEDDING_ATTRS:
        candidate = getattr(model, name, None)
        if candidate is None:
            continue
        weight = getattr(candidate, 'weight', candidate)
        if torch.is_tensor(weight) and tuple(weight.shape) == (vocab_size, d_model):
            return weight.detach()
    raise AttributeError(
        "could not find an embedding tensor of shape "
        f"({vocab_size}, {d_model}) on the model; tried {_EMBEDDING_ATTRS}"
    )


def export_bitlinear(f, bitlinear):
    """Write weight bits then int32 thresholds for a BitLinear module.

    Reads ``bitlinear.raw.weight`` (laid out as (out, in)) and
    ``bitlinear.threshold``.

    Raises:
        ValueError: if a scaled threshold is non-finite or overflows int32.
    """
    w = bitlinear.raw.weight.detach()  # (out, in) float latent
    _write_le(f, pack_bits_rows(_signs(w)), '<u8')

    in_features = w.shape[1]
    # For integer popcount y, Python compares y*scale >= threshold (float).
    # Equivalent integer form: y >= ceil(threshold/scale) = ceil(threshold*sqrt(in)).
    # ceil (not round) is required so that for threshold/scale = 0.5, we get 1,
    # not 0 — otherwise y=0 would falsely pass the comparison.
    threshold = bitlinear.threshold.detach().cpu().numpy()
    scaled = np.ceil(threshold * math.sqrt(in_features))
    int_thr = _checked_int(scaled, np.int32, "BitLinear threshold")
    _write_le(f, int_thr, '<i4')


def export(ckpt_path: str, out_path: str):
    """Load the checkpoint at ``ckpt_path`` and write the packed file to ``out_path``.

    The checkpoint must contain 'args' (model hyper-parameters), 'model'
    (state dict), 'step' and 'val_bpc'. Sections are written in the order
    described in the module docstring.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. Only run this exporter on checkpoints you produced yourself.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    cfg = ckpt['args']
    model = BitLMv18(
        vocab_size=cfg['vocab_size'],
        d_model=cfg['d_model'],
        n_layers=cfg['n_layers'],
        n_heads=cfg['n_heads'],
        d_ff=cfg['d_ff'],
        max_seq_len=cfg['seq_len'],
    )
    model.load_state_dict(ckpt['model'])
    model.eval()

    V = cfg['vocab_size']
    D = cfg['d_model']
    L = cfg['n_layers']
    H = cfg['n_heads']
    F = cfg['d_ff']
    T = cfg['seq_len']

    print(f"exporting v18 ckpt step={ckpt['step']} val_bpc={ckpt['val_bpc']:.4f}")
    print(f" vocab={V} d_model={D} n_layers={L} n_heads={H} d_ff={F} seq_len={T}")

    with open(out_path, 'wb') as f:
        # HEADER — struct.pack raises struct.error if a field overflows its width.
        f.write(struct.pack(_HEADER_FORMAT,
                            MAGIC, VERSION, V, D, L, H, F, T, M_SCALE))

        # EMBEDDING
        E = _embedding_matrix(model, V, D)
        _write_le(f, pack_bits_rows(_signs(E)), '<u8')

        # LAYERS
        for blk in model.blocks:
            # Attention
            slopes = blk.attn.alibi_slopes_int.detach().cpu().numpy()
            _write_le(f, slopes.astype(np.int32), '<i4')
            for proj in (blk.attn.q_proj, blk.attn.k_proj,
                         blk.attn.v_proj, blk.attn.o_proj):
                export_bitlinear(f, proj)
            # FFN
            export_bitlinear(f, blk.ffn.gate)
            export_bitlinear(f, blk.ffn.up)
            export_bitlinear(f, blk.ffn.down)

        # OUTPUT HEAD
        W_out = model.out_codebook.detach()
        _write_le(f, pack_bits_rows(_signs(W_out)), '<u8')
        scaled_bias = np.round(
            model.out_bias.detach().cpu().numpy() * M_SCALE
            / model.logit_scale.detach().item()
        )
        int_out_bias = _checked_int(scaled_bias, np.int64, "output bias")
        _write_le(f, int_out_bias, '<i8')

    # Measured after the file is closed so buffered bytes are included.
    sz = os.path.getsize(out_path)
    print(f"wrote {sz:,} bytes to {out_path} ({sz / 2**20:.2f} MiB)")


def main() -> None:
    """Parse ``--ckpt`` / ``--out`` from the command line and run the export."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', required=True)
    ap.add_argument('--out', required=True)
    args = ap.parse_args()
    export(args.ckpt, args.out)


if __name__ == '__main__':
    main()