"""Initialize a BitLMv18 binary student's latent weights from a trained FP32LM teacher.

Copies:
  - embedding codebook (FP32 embed → student.embed)
  - output codebook (FP32 head → student.out_codebook)
    NOTE(review): the code actually reads ``teacher.embed.weight`` for this,
    not a separate head tensor. That is only equivalent if the teacher ties
    its head to its embedding — confirm against FP32LM.
  - per-block attention projections (FP32 qkv split → student q/k/v, o → o)
  - per-block FFN (FP32 g/u/d → student gate/up/down)

Skips (not present in v18): positional embedding, RMSNorm. ALiBi replaces pos.

Teacher and student must have identical d_model, n_layers, n_heads, d_ff, vocab.
"""
import torch

from model_v18 import BitLMv18
from model_fp32 import FP32LM


def _convert_latent(w: torch.Tensor, mode: str, latent_scale: float) -> torch.Tensor:
    """Map one teacher tensor to the value written into the student.

    'raw' returns a copy of `w`; 'sign_rescale' returns sign(w) * latent_scale
    with exact zeros treated as +1. Any other mode raises ValueError(mode).
    """
    if mode == 'raw':
        return w.clone()
    if mode == 'sign_rescale':
        s = torch.sign(w)
        # sign(0) == 0 would give a zero latent; force those entries to +1.
        s[s == 0] = 1
        return s * latent_scale
    raise ValueError(mode)


def _weight_pairs(student, teacher):
    """Yield (name, student_tensor, teacher_tensor) for every tensor that is copied."""
    yield 'embed', student.embed.weight, teacher.embed.weight
    # NOTE(review): sourced from the teacher *embedding*, see module docstring.
    yield 'out_codebook', student.out_codebook, teacher.embed.weight

    for idx, (b_blk, t_blk) in enumerate(zip(student.blocks, teacher.blocks)):
        # The teacher keeps q, k and v in one fused weight. It is cut along
        # dim 0 into three chunks of `d` rows, where `d` is the size of dim 1.
        # Assumes the fused weight is stacked [q; k; v] — TODO confirm.
        qkv = t_blk.a.qkv.weight
        d = qkv.shape[1]
        prefix = f'blocks[{idx}]'
        yield f'{prefix}.attn.q_proj', b_blk.attn.q_proj.raw.weight, qkv[:d]
        yield f'{prefix}.attn.k_proj', b_blk.attn.k_proj.raw.weight, qkv[d:2 * d]
        yield f'{prefix}.attn.v_proj', b_blk.attn.v_proj.raw.weight, qkv[2 * d:]
        yield f'{prefix}.attn.o_proj', b_blk.attn.o_proj.raw.weight, t_blk.a.o.weight
        yield f'{prefix}.ffn.gate', b_blk.ffn.gate.raw.weight, t_blk.f.g.weight
        yield f'{prefix}.ffn.up', b_blk.ffn.up.raw.weight, t_blk.f.u.weight
        yield f'{prefix}.ffn.down', b_blk.ffn.down.raw.weight, t_blk.f.d.weight


@torch.no_grad()
def init_binary_from_fp32(student: BitLMv18, teacher: FP32LM,
                          mode: str = 'sign_rescale',
                          latent_scale: float = 0.02) -> None:
    """Overwrite the student's latent weights in place from the teacher.

    mode='raw'          — copy teacher weights verbatim (tried in v36; failed
                          because teacher magnitudes are too large → weights
                          frozen).
    mode='sign_rescale' — copy sign(teacher) and rescale to `latent_scale`.
                          Student starts with teacher's sign pattern BUT has
                          flippable latent magnitudes. Preferred.

    All source/destination tensors are checked before the first write, so a
    mismatch no longer leaves the student half-initialised.

    Raises:
        AssertionError: d_model, number of blocks or vocab_size differ.
        ValueError: `mode` is not one of the two modes above.
        RuntimeError: a teacher tensor cannot be broadcast to the shape of the
            student tensor it is copied into (the same condition, and the same
            exception type, that ``Tensor.copy_`` rejects).

    n_heads is not checked by this function.
    """
    assert student.d_model == teacher.d_model, 'd_model mismatch'
    assert len(student.blocks) == len(teacher.blocks), 'n_layers mismatch'
    assert student.vocab_size == teacher.vocab_size, 'vocab_size mismatch'

    if mode not in ('raw', 'sign_rescale'):
        raise ValueError(mode)

    # Resolve every attribute and slice up front; nothing is written yet.
    pairs = list(_weight_pairs(student, teacher))

    for name, dst, src in pairs:
        try:
            # expand() is a view, so this costs no memory; it fails when
            # `src` cannot be broadcast to `dst`'s shape.
            src.expand(dst.shape)
        except RuntimeError as err:
            raise RuntimeError(
                f'{name}: teacher tensor of shape {tuple(src.shape)} cannot be '
                f'copied into student tensor of shape {tuple(dst.shape)}'
            ) from err

    # Convert one tensor at a time so only one temporary is alive at once.
    for _, dst, src in pairs:
        dst.copy_(_convert_latent(src, mode, latent_scale))


def load_teacher(path, device='cuda'):
    """Rebuild the FP32LM teacher from a checkpoint file.

    The checkpoint must hold the constructor arguments under ``'args'``
    (``vocab_size``, ``d_model``, ``n_layers``, ``n_heads``, ``d_ff``,
    ``seq_len``) and the state dict under ``'model'``.

    Returns:
        (teacher, checkpoint): the model on `device`, switched to eval mode,
        and the loaded checkpoint object.
    """
    # SECURITY: weights_only=False runs the full pickle machinery, which can
    # execute arbitrary code. Only load checkpoints from a trusted source.
    ck = torch.load(path, map_location=device, weights_only=False)
    cfg = ck['args']
    t = FP32LM(
        vocab_size=cfg['vocab_size'],
        d_model=cfg['d_model'],
        n_layers=cfg['n_layers'],
        n_heads=cfg['n_heads'],
        d_ff=cfg['d_ff'],
        max_seq_len=cfg['seq_len'],
    ).to(device)
    t.load_state_dict(ck['model'])
    t.eval()
    return t, ck