"""Initialize a BitLMv18 binary student's latent weights from a trained FP32LM teacher.

Copies:
  - embedding codebook (FP32 embed → student.embed)
  - output codebook (FP32 embed → student.out_codebook; the teacher's embedding
    matrix is used as the head source, presumably because its head is tied)
  - per-block attention projections (FP32 fused qkv split → student q/k/v; o → o)
  - per-block FFN (FP32 g/u/d → student gate/up/down)

Skips (not present in v18): positional embedding and RMSNorm. ALiBi replaces the
positional embedding.
Teacher and student must have identical d_model, n_layers, n_heads, d_ff, and vocab size.
"""
| import torch |
| from model_v18 import BitLMv18 |
| from model_fp32 import FP32LM |
|
|
|
|
@torch.no_grad()
def init_binary_from_fp32(student: BitLMv18, teacher: FP32LM, mode='sign_rescale',
                          latent_scale=0.02):
    """Copy a trained FP32LM teacher's weights into a BitLMv18 student's latent weights.

    Modes:
        mode='raw'          → copy teacher weights verbatim (tried in v36; failed
                              because teacher magnitudes are too large → weights
                              frozen).
        mode='sign_rescale' → copy sign(teacher) and rescale to `latent_scale`.
                              The student starts with the teacher's sign pattern
                              BUT has flippable latent magnitudes. Preferred.

    Args:
        student: BitLMv18 whose embed, out_codebook and per-block attention/FFN
            latent weights are overwritten in place.
        teacher: trained FP32LM; only read.
        mode: 'raw' or 'sign_rescale' (see above).
        latent_scale: magnitude given to every latent weight in 'sign_rescale'
            mode; unused in 'raw' mode.

    Returns:
        None. The student is modified in place (gradients disabled).

    Raises:
        ValueError: if `mode` is unknown, or if d_model, the number of blocks,
            or vocab_size differ between student and teacher.
    """
    # Validate explicitly rather than with `assert`: asserts are stripped under
    # `python -O`, and a block-count mismatch would then be silently truncated
    # by the zip() below.
    if mode not in ('raw', 'sign_rescale'):
        raise ValueError(f"unknown mode {mode!r}; expected 'raw' or 'sign_rescale'")
    if student.d_model != teacher.d_model:
        raise ValueError(
            f"d_model mismatch: student={student.d_model}, teacher={teacher.d_model}")
    if len(student.blocks) != len(teacher.blocks):
        raise ValueError(
            f"block count mismatch: student={len(student.blocks)}, "
            f"teacher={len(teacher.blocks)}")
    if student.vocab_size != teacher.vocab_size:
        raise ValueError(
            f"vocab_size mismatch: student={student.vocab_size}, "
            f"teacher={teacher.vocab_size}")

    def convert(w):
        """Return the tensor to copy into the student for teacher weight `w`."""
        if mode == 'raw':
            return w.clone()
        # 'sign_rescale': every entry becomes ±latent_scale. torch.sign maps
        # exact zeros to 0, so force them to +1 to avoid zero latents.
        s = torch.sign(w)
        s[s == 0] = 1
        return s * latent_scale

    student.embed.weight.copy_(convert(teacher.embed.weight))
    # NOTE(review): the module docstring originally described this source as the
    # "FP32 head", but it is teacher.embed.weight — correct only if the teacher
    # ties its output head to the embedding. Confirm against FP32LM.
    student.out_codebook.copy_(convert(teacher.embed.weight))

    for b_blk, t_blk in zip(student.blocks, teacher.blocks):
        # Fused teacher projection, split row-wise into q, k, v below.
        # Assumes an nn.Linear-style (3 * d_model, d_model) weight with rows
        # ordered [q; k; v] — TODO confirm against FP32LM's attention module.
        qkv = t_blk.a.qkv.weight
        d = qkv.shape[1]
        b_blk.attn.q_proj.raw.weight.copy_(convert(qkv[:d]))
        b_blk.attn.k_proj.raw.weight.copy_(convert(qkv[d:2 * d]))
        b_blk.attn.v_proj.raw.weight.copy_(convert(qkv[2 * d:]))
        b_blk.attn.o_proj.raw.weight.copy_(convert(t_blk.a.o.weight))

        b_blk.ffn.gate.raw.weight.copy_(convert(t_blk.f.g.weight))
        b_blk.ffn.up.raw.weight.copy_(convert(t_blk.f.u.weight))
        b_blk.ffn.down.raw.weight.copy_(convert(t_blk.f.d.weight))
|
|
|
|
def load_teacher(path, device='cuda'):
    """Rebuild an FP32LM teacher from a saved training checkpoint.

    The checkpoint is expected to hold the hyper-parameters under 'args'
    (indexed like a dict) and the state dict under 'model'.

    Args:
        path: checkpoint file passed to torch.load.
        device: device used as torch.load's map_location and for the model.

    Returns:
        (teacher, checkpoint): the teacher on `device` in eval mode, and the raw
        loaded checkpoint object.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hparams = checkpoint['args']
    model_kwargs = dict(
        vocab_size=hparams['vocab_size'],
        d_model=hparams['d_model'],
        n_layers=hparams['n_layers'],
        n_heads=hparams['n_heads'],
        d_ff=hparams['d_ff'],
        max_seq_len=hparams['seq_len'],
    )
    teacher = FP32LM(**model_kwargs).to(device)
    teacher.load_state_dict(checkpoint['model'])
    teacher.eval()
    return teacher, checkpoint
|
|