# bitnet-1bitllm / vm_backup / code / fp32_to_binary_init.py
# Uploaded by: hidude562
# Commit message: 1bitllm code (checkpoints to follow)
# Commit: 4754707 (verified)
"""Initialize a BitLMv18 binary student's latent weights from a trained FP32LM teacher.

Copies:
- embedding codebook (FP32 embed → student.embed)
- output codebook (FP32 embed → student.out_codebook; presumably the teacher
  ties its output head to the embedding)
- per-block attention projections (FP32 fused qkv split → student q/k/v, o → o)
- per-block FFN (FP32 g/u/d → student gate/up/down)

Skipped because v18 has no counterpart: the positional embedding (ALiBi
replaces it) and RMSNorm.

Teacher and student must have identical d_model, n_layers, n_heads, d_ff, and vocab size.
"""
import torch
from model_v18 import BitLMv18
from model_fp32 import FP32LM
@torch.no_grad()
def init_binary_from_fp32(student: "BitLMv18", teacher: "FP32LM", mode='sign_rescale',
                          latent_scale=0.02):
    """Overwrite a binary student's latent weights with values derived from an FP32 teacher.

    All writes are in-place ``copy_`` calls on the student's tensors under
    ``torch.no_grad()``; the teacher is only read. Every destination/source
    pair is validated before the first write, so a mismatch leaves the
    student untouched.

    Args:
        student: BitLMv18 whose ``embed.weight``, ``out_codebook`` and per-block
            ``attn.{q,k,v,o}_proj.raw.weight`` / ``ffn.{gate,up,down}.raw.weight``
            are overwritten.
        teacher: FP32LM providing ``embed.weight``, a fused ``a.qkv.weight``
            (rows ordered q, k, v), ``a.o.weight`` and FFN ``f.g`` / ``f.u`` /
            ``f.d`` weights.
        mode: how teacher weights become student latents:
            'raw'          — copy teacher weights verbatim (tried in v36; failed
                             because teacher magnitudes are too large → weights
                             frozen).
            'sign_rescale' — copy sign(teacher) scaled to ``latent_scale``
                             (exact zeros map to +latent_scale). The student
                             starts with the teacher's sign pattern but with
                             small, flippable latent magnitudes. Preferred.
        latent_scale: magnitude used by 'sign_rescale'; ignored by 'raw'.

    Raises:
        ValueError: unknown ``mode``; mismatched d_model, block count or vocab
            size; a fused qkv weight whose row count is not 3x its column
            count; or any destination/source shape mismatch (``copy_`` would
            otherwise broadcast a wrongly-shaped source silently).
    """
    if mode not in ('raw', 'sign_rescale'):
        raise ValueError(mode)
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if student.d_model != teacher.d_model:
        raise ValueError(
            f"d_model mismatch: student={student.d_model}, teacher={teacher.d_model}")
    if len(student.blocks) != len(teacher.blocks):
        raise ValueError(
            f"block count mismatch: student={len(student.blocks)}, "
            f"teacher={len(teacher.blocks)}")
    if student.vocab_size != teacher.vocab_size:
        raise ValueError(
            f"vocab_size mismatch: student={student.vocab_size}, "
            f"teacher={teacher.vocab_size}")

    def convert(w):
        if mode == 'raw':
            # No clone needed: dst.copy_(w) below already copies the data.
            return w
        # torch.sign maps 0 -> 0; force those to +1 so every latent is nonzero.
        s = torch.sign(w)
        s[s == 0] = 1
        return s * latent_scale

    # (name, destination, source) in the same order the original copied them.
    # NOTE(review): the output codebook is initialized from the teacher's
    # *embedding*, not a separate head — presumably the teacher ties head and
    # embedding weights; confirm.
    pairs = [
        ('embed', student.embed.weight, teacher.embed.weight),
        ('out_codebook', student.out_codebook, teacher.embed.weight),
    ]
    for i, (b_blk, t_blk) in enumerate(zip(student.blocks, teacher.blocks)):
        qkv = t_blk.a.qkv.weight
        d = qkv.shape[1]
        # The split below assumes q, k and v each contribute exactly d rows.
        if qkv.shape[0] != 3 * d:
            raise ValueError(
                f"block {i}: expected fused qkv weight of shape ({3 * d}, {d}), "
                f"got {tuple(qkv.shape)}")
        pairs += [
            (f"block {i} q_proj", b_blk.attn.q_proj.raw.weight, qkv[:d]),
            (f"block {i} k_proj", b_blk.attn.k_proj.raw.weight, qkv[d:2 * d]),
            (f"block {i} v_proj", b_blk.attn.v_proj.raw.weight, qkv[2 * d:]),
            (f"block {i} o_proj", b_blk.attn.o_proj.raw.weight, t_blk.a.o.weight),
            (f"block {i} gate", b_blk.ffn.gate.raw.weight, t_blk.f.g.weight),
            (f"block {i} up", b_blk.ffn.up.raw.weight, t_blk.f.u.weight),
            (f"block {i} down", b_blk.ffn.down.raw.weight, t_blk.f.d.weight),
        ]

    # Validate everything first so a bad pair cannot leave a half-initialized student.
    for name, dst, src in pairs:
        if dst.shape != src.shape:
            raise ValueError(
                f"{name}: shape mismatch, student {tuple(dst.shape)} "
                f"vs teacher {tuple(src.shape)}")
    for _, dst, src in pairs:
        dst.copy_(convert(src))
def load_teacher(path, device='cuda'):
    """Rebuild an FP32LM teacher from a checkpoint file and put it in eval mode.

    The architecture hyperparameters are read from ``checkpoint['args']`` and
    the weights from ``checkpoint['model']``.

    Args:
        path: checkpoint path passed to ``torch.load``.
        device: device used both as ``map_location`` and as the model's target
            device.

    Returns:
        ``(model, checkpoint)``: the FP32LM on ``device`` in eval mode, and the
        raw object returned by ``torch.load`` (indexed here like a dict).
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects and can run
    # code embedded in the file — only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hparams = checkpoint['args']
    teacher = FP32LM(
        vocab_size=hparams['vocab_size'],
        d_model=hparams['d_model'],
        n_layers=hparams['n_layers'],
        n_heads=hparams['n_heads'],
        d_ff=hparams['d_ff'],
        max_seq_len=hparams['seq_len'],
    )
    teacher = teacher.to(device)
    teacher.load_state_dict(checkpoint['model'])
    teacher.eval()
    return teacher, checkpoint