File size: 5,573 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""Export a v18 checkpoint to a flat packed-binary file the C inference reads.

File format (little-endian; all integers fit in their declared widths):

  HEADER (40 bytes):
    magic             uint32  = 'BIT1' = 0x31544942
    version           uint32  = 1
    vocab_size        uint32
    d_model           uint32
    n_layers          uint32
    n_heads           uint32
    d_ff              uint32
    max_seq_len       uint32
    logit_scale_M     int64   = 65536  (integer logit scale multiplier)

  EMBEDDING (vocab_size rows × d_model bits, row-packed in uint64 words):
    [row for each char] each row = ceil(d_model/64) uint64 words

  For each of n_layers:
    ATTENTION:
      alibi_slopes  int32[n_heads]
      For each of Q,K,V,O (all are BitLinear[d_model -> d_model]):
        weight_bits   d_model rows × (d_model bits packed as uint64 per row)
        int_threshold int32[d_model]   (= ceil(threshold_float * sqrt(d_in)))

    FFN (each BitLinear written as weight_bits then int_threshold, as above):
      gate  BitLinear[d_model -> d_ff]
      up    BitLinear[d_model -> d_ff]
      down  BitLinear[d_ff    -> d_model]

  OUTPUT HEAD:
    out_codebook  vocab_size rows × d_model bits
    int_out_bias  int64[vocab_size]   (= round(out_bias * M / logit_scale))

All weight bits use the convention: bit = 1 means +1, bit = 0 means −1.
Rows are stored tightly; each row's bits are packed least-significant-first into
successive uint64 words. If d_in % 64 != 0, the last word is padded (unused bits
stored as 0 — the C code masks out the unused range when popcount-ing).
"""
import argparse
import math
import os
import struct

import numpy as np
import torch

from model_v18 import BitLMv18


MAGIC = int.from_bytes(b'BIT1', 'little')  # == 0x31544942: the ASCII bytes 'BIT1' read as a little-endian uint32
VERSION = 1  # file-format version written into the header
M_SCALE = 2 ** 16  # integer logit scale multiplier, stored in the header as logit_scale_M


def pack_bits_rows(sign_tensor: torch.Tensor) -> np.ndarray:
    """Pack a ±1 sign matrix into per-row uint64 bit words.

    Every entry > 0 becomes bit 1 and every other entry becomes bit 0, so
    +1 -> 1 and −1 -> 0. Column k of a row lands in bit ``k % 64`` of word
    ``k // 64`` (least-significant bit first). Padding bits past ``cols`` in a
    row's last word are 0.

    Args:
        sign_tensor: 2-D tensor of shape (rows, cols), expected to hold ±1.

    Returns:
        np.ndarray of dtype uint64 and shape (rows, ceil(cols / 64)).

    Raises:
        ValueError: if ``sign_tensor`` is not 2-D.
    """
    if sign_tensor.ndim != 2:
        raise ValueError(
            f"expected a 2-D tensor, got shape {tuple(sign_tensor.shape)}"
        )
    rows, cols = sign_tensor.shape
    words = (cols + 63) // 64

    # Map ±1 -> {0, 1}: 1 for entries > 0, 0 otherwise.
    bits = (sign_tensor > 0).cpu().numpy().astype(np.uint8)

    # Zero-pad each row to a whole number of 64-bit words so unused bits are 0.
    padded = np.zeros((rows, words * 64), dtype=np.uint8)
    padded[:, :cols] = bits

    # bitorder='little' puts column k into bit k % 8 of byte k // 8. Reading
    # each run of 8 bytes as a little-endian uint64 then places it in bit
    # k % 64 of word k // 64, which is exactly the LSB-first layout of a
    # per-bit shift loop, computed in vectorized native code.
    packed_bytes = np.packbits(padded, axis=1, bitorder='little')
    return packed_bytes.view('<u8').astype(np.uint64)


def export_bitlinear(f, bitlinear):
    """Write one BitLinear module: packed sign bits, then int32 thresholds.

    Layout written (explicitly little-endian, matching the file format):
      * one row per output feature, each ceil(in_features / 64) uint64 words;
        bit k of a row is 1 iff latent weight[row, k] >= 0;
      * the integer thresholds as int32 (presumably one per output feature,
        per the format description; TODO confirm against the module).

    Args:
        f: binary file object opened for writing.
        bitlinear: module exposing ``raw.weight`` (float latent weights, shape
            (out, in)) and ``threshold`` (float thresholds).
    """
    w = bitlinear.raw.weight.detach()   # (out, in) float latent
    W_sign = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w))
    # '<u8' pins the on-disk byte order to the documented little-endian format
    # (a no-op on little-endian hosts; plain .tobytes() uses host order).
    f.write(pack_bits_rows(W_sign).astype('<u8').tobytes())

    in_features = w.shape[1]
    # For integer popcount y, Python compares y*scale >= threshold (float).
    # Equivalent integer form: y >= ceil(threshold/scale) = ceil(threshold*sqrt(in)).
    # ceil (not round) is required so that for threshold/scale = 0.5, we get 1,
    # not 0 — otherwise y=0 would falsely pass the comparison.
    threshold = bitlinear.threshold.detach().cpu().numpy()
    int_thr = np.ceil(threshold * math.sqrt(in_features)).astype('<i4')
    f.write(int_thr.tobytes())


def _sign_rows_bytes(weight: torch.Tensor) -> bytes:
    """Binarize a float matrix (>= 0 -> +1, else -1) and return its packed rows as little-endian bytes."""
    sign = torch.where(weight >= 0, torch.ones_like(weight), -torch.ones_like(weight))
    return pack_bits_rows(sign).astype('<u8').tobytes()


def export(ckpt_path: str, out_path: str):
    """Load a v18 checkpoint and write it to *out_path* in the packed binary format.

    Sections are written in order: header, embedding, per-layer attention
    (ALiBi slopes + Q/K/V/O BitLinears) and FFN (gate/up/down BitLinears),
    then the output codebook and integer output bias. Every multi-byte value
    is written explicitly little-endian, so the file does not depend on host
    byte order.

    Args:
        ckpt_path: checkpoint file holding a dict with 'args' (model config
            with vocab_size, d_model, n_layers, n_heads, d_ff, seq_len) and
            'model' (state dict). Optional 'step' / 'val_bpc' are only printed.
        out_path: destination file; it is overwritten if it exists.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only export checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    cfg = ckpt['args']

    V = cfg['vocab_size']
    D = cfg['d_model']
    L = cfg['n_layers']
    H = cfg['n_heads']
    F = cfg['d_ff']
    T = cfg['seq_len']

    model = BitLMv18(
        vocab_size=V, d_model=D, n_layers=L,
        n_heads=H, d_ff=F, max_seq_len=T,
    )
    model.load_state_dict(ckpt['model'])
    model.eval()

    # 'step' / 'val_bpc' are informational only; don't fail the export if absent.
    step = ckpt.get('step', '?')
    val_bpc = ckpt.get('val_bpc')
    bpc_str = 'n/a' if val_bpc is None else f"{val_bpc:.4f}"
    print(f"exporting v18 ckpt step={step} val_bpc={bpc_str}")
    print(f"  vocab={V} d_model={D} n_layers={L} n_heads={H} d_ff={F} seq_len={T}")

    with open(out_path, 'wb') as f:
        # HEADER: 8 x uint32 + 1 x int64 = 40 bytes ('<' means no padding).
        f.write(struct.pack('<8Iq', MAGIC, VERSION, V, D, L, H, F, T, M_SCALE))

        # EMBEDDING
        f.write(_sign_rows_bytes(model.embed.weight.detach()))

        # LAYERS
        for blk in model.blocks:
            # Attention: integer ALiBi slopes, then the four projections.
            slopes = blk.attn.alibi_slopes_int.detach().cpu().numpy().astype('<i4')
            f.write(slopes.tobytes())
            for proj in (blk.attn.q_proj, blk.attn.k_proj, blk.attn.v_proj, blk.attn.o_proj):
                export_bitlinear(f, proj)
            # FFN
            for proj in (blk.ffn.gate, blk.ffn.up, blk.ffn.down):
                export_bitlinear(f, proj)

        # OUTPUT HEAD
        f.write(_sign_rows_bytes(model.out_codebook.detach()))

        # Fold the float logit scale into the integer bias: round(out_bias * M / logit_scale).
        logit_scale = model.logit_scale.detach().item()
        int_out_bias = np.round(
            model.out_bias.detach().cpu().numpy() * M_SCALE / logit_scale
        ).astype('<i8')
        f.write(int_out_bias.tobytes())

    sz = os.path.getsize(out_path)
    print(f"wrote {sz:,} bytes to {out_path} ({sz / 2**20:.2f} MiB)")


if __name__ == '__main__':
    # Command-line entry point: both the source checkpoint and the output path are required.
    cli_parser = argparse.ArgumentParser()
    for flag in ('--ckpt', '--out'):
        cli_parser.add_argument(flag, required=True)
    cli_args = cli_parser.parse_args()
    export(cli_args.ckpt, cli_args.out)