File size: 5,573 Bytes
"""Export a v18 checkpoint to a flat packed-binary file the C inference reads.
File format (little-endian; all integers fit in their declared widths):
HEADER (40 bytes):
magic uint32 = 'BIT1' = 0x31544942
version uint32 = 1
vocab_size uint32
d_model uint32
n_layers uint32
n_heads uint32
d_ff uint32
max_seq_len uint32
logit_scale_M int64 = 65536 (integer logit scale multiplier)
EMBEDDING (vocab_size rows × d_model bits, row-packed in uint64 words):
[row for each char] each row = ceil(d_model/64) uint64 words
For each of n_layers:
ATTENTION:
alibi_slopes int32[n_heads]
For each of Q,K,V,O (all are BitLinear[d_model -> d_model]):
weight_bits d_model rows × (d_model bits packed as uint64 per row)
int_threshold int32[d_model] (= ceil(threshold_float * sqrt(d_in)))
FFN:
gate BitLinear[d_model -> d_ff]
up BitLinear[d_model -> d_ff]
down BitLinear[d_ff -> d_model]
OUTPUT HEAD:
out_codebook vocab_size rows × d_model bits
int_out_bias int64[vocab_size] (= round(out_bias * M / logit_scale))
All weight bits use the convention: bit = 1 means +1, bit = 0 means -1.
Rows are stored tightly; each row's bits are packed least-significant-first into
successive uint64 words. If d_in % 64 != 0, the last word is padded (unused bits
stored as 0 — the C code masks out the unused range when popcount-ing).
"""
import argparse
import math
import os
import struct
import numpy as np
import torch
from model_v18 import BitLMv18
# Header identification: the ASCII tag 'BIT1' read as a little-endian uint32
# (0x31544942), followed by the format version.
MAGIC = int.from_bytes(b'BIT1', 'little')
VERSION = 1
# Integer logit scale multiplier stored in the header as an int64 (65536).
M_SCALE = 2 ** 16
def pack_bits_rows(sign_tensor: torch.Tensor) -> np.ndarray:
    """Pack a 2-D sign tensor into per-row uint64 bit words.

    Every strictly positive entry becomes bit 1; every other entry (so -1,
    and also 0) becomes bit 0. Column ``k`` of a row is stored in bit
    ``k % 64`` of word ``k // 64``, i.e. bits are packed least-significant
    first. When ``cols`` is not a multiple of 64, the unused high bits of the
    last word in each row are left as 0.

    Args:
        sign_tensor: tensor of shape (rows, cols).

    Returns:
        ``np.uint64`` array of shape (rows, ceil(cols / 64)).
    """
    assert sign_tensor.ndim == 2
    rows, cols = sign_tensor.shape
    words = (cols + 63) // 64
    if rows == 0 or words == 0:
        # Nothing to pack; keep the documented shape and dtype.
        return np.zeros((rows, words), dtype=np.uint64)

    bits = (sign_tensor > 0).to(torch.uint8).cpu().numpy()

    # Pad each row up to a whole number of 64-bit words so the packed bytes
    # can be reinterpreted as uint64 without a ragged tail.
    padded = np.zeros((rows, words * 64), dtype=np.uint8)
    padded[:, :cols] = bits

    # bitorder='little' puts column 8*j + i into bit i of byte j. Reading each
    # group of 8 bytes as a little-endian uint64 therefore puts column k into
    # bit (k % 64) of word (k // 64), regardless of the host byte order.
    packed_bytes = np.packbits(padded, axis=1, bitorder='little')
    return packed_bytes.view('<u8').astype(np.uint64)
def export_bitlinear(f, bitlinear):
    """Write one BitLinear module: packed weight sign bits, then int32 thresholds.

    Args:
        f: binary file object opened for writing.
        bitlinear: module exposing ``raw.weight`` (the (out, in) float latent
            weights) and ``threshold``.

    Both arrays are written in little-endian byte order, as the file format
    requires, independent of the host's native byte order.
    """
    latent = bitlinear.raw.weight.detach()  # (out, in) float latent
    signs = torch.where(latent >= 0, torch.ones_like(latent), -torch.ones_like(latent))
    # ndarray.tobytes() emits the array's own byte order, so pin it to
    # little-endian explicitly instead of relying on the host being LE.
    f.write(pack_bits_rows(signs).astype('<u8').tobytes())

    in_features = latent.shape[1]
    # For integer popcount y, Python compares y*scale >= threshold (float).
    # Equivalent integer form: y >= ceil(threshold/scale) = ceil(threshold*sqrt(in)).
    # ceil (not round) is required so that for threshold/scale = 0.5 we get 1,
    # not 0 — otherwise y=0 would falsely pass the comparison.
    threshold = bitlinear.threshold.detach().cpu().numpy()
    int_thr = np.ceil(threshold * math.sqrt(in_features)).astype('<i4')
    f.write(int_thr.tobytes())
def export(ckpt_path: str, out_path: str) -> None:
    """Load a v18 checkpoint and write it as a flat packed-binary file.

    The output is written in this order: header, embedding sign bits, then for
    each block the ALiBi slopes, the Q/K/V/O projections and the FFN
    gate/up/down projections, and finally the output codebook sign bits and
    the integer output bias. All multi-byte values are written little-endian.

    Args:
        ckpt_path: path of the checkpoint to read. It must contain the keys
            'args', 'model', 'step' and 'val_bpc'.
        out_path: path of the binary file to create (overwritten if present).
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only use this on trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    cfg = ckpt['args']
    model = BitLMv18(
        vocab_size=cfg['vocab_size'], d_model=cfg['d_model'], n_layers=cfg['n_layers'],
        n_heads=cfg['n_heads'], d_ff=cfg['d_ff'], max_seq_len=cfg['seq_len'],
    )
    model.load_state_dict(ckpt['model'])
    model.eval()

    V = cfg['vocab_size']
    D = cfg['d_model']
    L = cfg['n_layers']
    H = cfg['n_heads']
    F = cfg['d_ff']
    T = cfg['seq_len']
    print(f"exporting v18 ckpt step={ckpt['step']} val_bpc={ckpt['val_bpc']:.4f}")
    print(f" vocab={V} d_model={D} n_layers={L} n_heads={H} d_ff={F} seq_len={T}")

    def sign_bits_le(t):
        """Return the packed sign bits of ``t`` (>= 0 -> 1) as little-endian bytes."""
        signs = torch.where(t >= 0, torch.ones_like(t), -torch.ones_like(t))
        # tobytes() uses the array's own byte order; pin it to little-endian.
        return pack_bits_rows(signs).astype('<u8').tobytes()

    with open(out_path, 'wb') as f:
        # HEADER: eight uint32 fields followed by one int64 (40 bytes; '<'
        # means little-endian with no alignment padding).
        f.write(struct.pack('<8Iq', MAGIC, VERSION, V, D, L, H, F, T, M_SCALE))

        # EMBEDDING
        f.write(sign_bits_le(model.embed.weight.detach()))

        # LAYERS
        for blk in model.blocks:
            # Attention
            slopes = blk.attn.alibi_slopes_int.detach().cpu().numpy().astype('<i4')
            f.write(slopes.tobytes())
            for proj in (blk.attn.q_proj, blk.attn.k_proj, blk.attn.v_proj, blk.attn.o_proj):
                export_bitlinear(f, proj)
            # FFN
            export_bitlinear(f, blk.ffn.gate)
            export_bitlinear(f, blk.ffn.up)
            export_bitlinear(f, blk.ffn.down)

        # OUTPUT HEAD
        f.write(sign_bits_le(model.out_codebook.detach()))
        int_out_bias = np.round(
            model.out_bias.detach().cpu().numpy() * M_SCALE / model.logit_scale.detach().item()
        ).astype('<i8')
        f.write(int_out_bias.tobytes())

    # Measure after the file is closed so all buffered data has been flushed.
    sz = os.path.getsize(out_path)
    print(f"wrote {sz:,} bytes to {out_path} ({sz / 2**20:.2f} MiB)")
if __name__ == '__main__':
    # Command-line entry point: both the checkpoint path and the output path
    # are mandatory.
    parser = argparse.ArgumentParser()
    for flag in ('--ckpt', '--out'):
        parser.add_argument(flag, required=True)
    args = parser.parse_args()
    export(ckpt_path=args.ckpt, out_path=args.out)
|