# bitnet-1bitllm / vm_backup /code /export_v18.py
# hidude562's picture
# 1bitllm code (checkpoints to follow)
# 4754707 verified
"""Export a v18 checkpoint to a flat packed-binary file the C inference reads.
File format (little-endian; all integers fit in their declared widths):
HEADER (40 bytes):
magic uint32 = 'BIT1' = 0x31544942
version uint32 = 1
vocab_size uint32
d_model uint32
n_layers uint32
n_heads uint32
d_ff uint32
max_seq_len uint32
logit_scale_M int64 = 65536 (integer logit scale multiplier)
EMBEDDING (vocab_size rows × d_model bits, row-packed in uint64 words):
[row for each char] each row = ceil(d_model/64) uint64 words
For each of n_layers:
ATTENTION:
alibi_slopes int32[n_heads]
For each of Q,K,V,O (all are BitLinear[d_model -> d_model]):
weight_bits d_model rows × (d_model bits packed as uint64 per row)
int_threshold int32[d_model] (= ceil(threshold_float * sqrt(d_in)))
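FFN (each BitLinear is stored like Q/K/V/O above: one packed row of d_in bits
per output feature, followed by its int_threshold values):
gate BitLinear[d_model -> d_ff]
up BitLinear[d_model -> d_ff]
down BitLinear[d_ff -> d_model]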
OUTPUT HEAD:
out_codebook vocab_size rows × d_model bits
int_out_bias int64[vocab_size] (= round(out_bias * M / logit_scale))
All weight bits use the convention: bit = 1 means +1, bit = 0 means −1.
Rows are stored back to back; each row's bits are packed least-significant-first into
successive uint64 words. If d_in % 64 != 0, each row's last word is padded (unused bits
stored as 0 — the C code masks out the unused range when popcount-ing).
"""
import argparse
import math
import os
import struct
import numpy as np
import torch
from model_v18 import BitLMv18
# File magic: the ASCII bytes b'BIT1' read as a little-endian uint32,
# i.e. 0x31544942.
MAGIC = int.from_bytes(b'BIT1', 'little')
# Format version number written into the header.
VERSION = 1
# Fixed-point multiplier for logits (header field logit_scale_M); the integer
# output bias is scaled by this value.
M_SCALE = 2 ** 16
def pack_bits_rows(sign_tensor: torch.Tensor) -> np.ndarray:
    """Pack the signs of a 2-D tensor into rows of little-endian uint64 words.

    Each entry of ``sign_tensor`` (shape ``(rows, cols)``, nominally in
    {-1, +1}) becomes one bit: strictly positive values map to 1, every other
    value (-1, and also an exact 0) maps to 0. Column ``k`` of a row lands in
    bit ``k % 64`` of word ``k // 64`` (LSB-first). When ``cols`` is not a
    multiple of 64, the unused high bits of each row's last word are 0.

    Args:
        sign_tensor: 2-D tensor on any device.

    Returns:
        Array of shape ``(rows, ceil(cols / 64))`` with a little-endian uint64
        dtype, so ``.tobytes()`` yields the on-disk layout on any host.

    Raises:
        ValueError: if ``sign_tensor`` is not 2-D.
    """
    if sign_tensor.ndim != 2:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise ValueError(
            f"pack_bits_rows expects a 2-D tensor, got shape {tuple(sign_tensor.shape)}"
        )
    rows, cols = sign_tensor.shape
    words = (cols + 63) // 64
    if rows == 0 or words == 0:
        # Nothing to pack; avoid reinterpreting a zero-length byte axis.
        return np.zeros((rows, words), dtype='<u8')
    # Zero-pad every row to a whole number of 64-bit words so the byte view
    # below splits evenly; the padding bits are therefore 0.
    bits = np.zeros((rows, words * 64), dtype=np.uint8)
    bits[:, :cols] = (sign_tensor > 0).cpu().numpy()
    # bitorder='little' puts column k into bit k % 8 of byte k // 8. Reading
    # each run of 8 bytes as a little-endian uint64 then puts column k into
    # bit k % 64 of word k // 64, which is the layout the C reader expects.
    packed = np.packbits(bits, axis=1, bitorder='little')
    return packed.view('<u8')
def export_bitlinear(f, bitlinear):
    """Write one BitLinear module: packed weight-sign bits, then int32 thresholds.

    Args:
        f: binary file object opened for writing.
        bitlinear: module exposing ``raw.weight`` (float latent weights,
            shaped ``(out_features, in_features)``) and ``threshold`` (float
            tensor; presumably one value per output feature -- confirm
            against the model definition).

    Both sections are written in little-endian byte order, as the file
    format requires, regardless of the host's native order.
    """
    w = bitlinear.raw.weight.detach()  # (out, in) float latent
    # Binarize: w >= 0 -> +1 (an exact 0 counts as +1), w < 0 -> -1.
    W_sign = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w))
    # '<u8' pins the documented little-endian word order; on little-endian
    # hosts this is a no-op that makes no copy.
    packed = pack_bits_rows(W_sign).astype('<u8', copy=False)
    f.write(packed.tobytes())
    in_features = w.shape[1]
    # For integer popcount y, Python compares y*scale >= threshold (float).
    # Equivalent integer form: y >= ceil(threshold/scale) = ceil(threshold*sqrt(in)).
    # ceil (not round) is required so that for threshold/scale = 0.5, we get 1,
    # not 0 -- otherwise y=0 would falsely pass the comparison.
    threshold = bitlinear.threshold.detach().cpu().numpy()
    int_thr = np.ceil(threshold * math.sqrt(in_features)).astype('<i4')
    f.write(int_thr.tobytes())
def export(ckpt_path: str, out_path: str):
    """Convert a v18 training checkpoint into the packed-binary inference file.

    Rebuilds ``BitLMv18`` from the hyper-parameters in ``ckpt['args']``, loads
    ``ckpt['model']``, then writes, in order: the 40-byte header, the
    binarized embedding, each block's attention (ALiBi slopes followed by the
    Q/K/V/O BitLinears) and FFN (gate/up/down BitLinears), and finally the
    output codebook and the integer output bias.

    Args:
        ckpt_path: checkpoint file readable by ``torch.load``; it must contain
            the keys 'args', 'model', 'step' and 'val_bpc'.
        out_path: destination file; truncated and overwritten if it exists.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only run this on checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    cfg = ckpt['args']
    model = BitLMv18(
        vocab_size=cfg['vocab_size'], d_model=cfg['d_model'], n_layers=cfg['n_layers'],
        n_heads=cfg['n_heads'], d_ff=cfg['d_ff'], max_seq_len=cfg['seq_len'],
    )
    model.load_state_dict(ckpt['model'])
    model.eval()
    V = cfg['vocab_size']
    D = cfg['d_model']
    L = cfg['n_layers']
    H = cfg['n_heads']
    F = cfg['d_ff']
    T = cfg['seq_len']
    print(f"exporting v18 ckpt step={ckpt['step']} val_bpc={ckpt['val_bpc']:.4f}")
    print(f" vocab={V} d_model={D} n_layers={L} n_heads={H} d_ff={F} seq_len={T}")
    with open(out_path, 'wb') as f:
        # HEADER: eight little-endian uint32 fields + int64 logit scale = 40 bytes.
        f.write(struct.pack('<IIIIIIII', MAGIC, VERSION, V, D, L, H, F, T))
        f.write(struct.pack('<q', M_SCALE))
        # EMBEDDING: E >= 0 -> +1 (bit 1), E < 0 -> -1 (bit 0).
        E = model.embed.weight.detach()
        E_sign = torch.where(E >= 0, torch.ones_like(E), -torch.ones_like(E))
        # '<u8' pins little-endian word order even on big-endian hosts; it is
        # a no-copy no-op on little-endian ones.
        f.write(pack_bits_rows(E_sign).astype('<u8', copy=False).tobytes())
        # LAYERS
        for blk in model.blocks:
            # Attention: integer ALiBi slopes, then the four projections.
            slopes = blk.attn.alibi_slopes_int.detach().cpu().numpy().astype('<i4')
            f.write(slopes.tobytes())
            for proj in [blk.attn.q_proj, blk.attn.k_proj, blk.attn.v_proj, blk.attn.o_proj]:
                export_bitlinear(f, proj)
            # FFN: order must match the C reader (gate, up, down).
            export_bitlinear(f, blk.ffn.gate)
            export_bitlinear(f, blk.ffn.up)
            export_bitlinear(f, blk.ffn.down)
        # OUTPUT HEAD: binarized codebook, then the bias rescaled into the
        # integer logit domain: round(out_bias * M / logit_scale).
        W_out = model.out_codebook.detach()
        W_out_sign = torch.where(W_out >= 0, torch.ones_like(W_out), -torch.ones_like(W_out))
        f.write(pack_bits_rows(W_out_sign).astype('<u8', copy=False).tobytes())
        int_out_bias = np.round(
            model.out_bias.detach().cpu().numpy() * M_SCALE / model.logit_scale.detach().item()
        ).astype('<i8')
        f.write(int_out_bias.tobytes())
    sz = os.path.getsize(out_path)
    print(f"wrote {sz:,} bytes to {out_path} ({sz / 2**20:.2f} MiB)")
if __name__ == '__main__':
    # Command-line entry point: both --ckpt (input checkpoint) and --out
    # (destination binary) are mandatory.
    cli = argparse.ArgumentParser()
    for flag in ('--ckpt', '--out'):
        cli.add_argument(flag, required=True)
    opts = cli.parse_args()
    export(opts.ckpt, opts.out)