File size: 2,625 Bytes
4754707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Initialize a BitLMv18 binary student's latent weights from a trained FP32LM teacher.

Copies:
  - embedding codebook (FP32 embed → student.embed)
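  - output codebook (teacher.embed.weight → student.out_codebook; the code reuses the
    teacher's embedding matrix, which matches the "FP32 head" only if the teacher
    ties its head to its embedding — TODO confirm)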
  - per-block attention projections (FP32 qkv split → student q/k/v, o → o)
  - per-block FFN (FP32 g/u/d → student gate/up/down)

Skips (not present in v18): positional embedding, RMSNorm. ALiBi replaces pos.
Teacher and student must have identical d_model, n_layers, n_heads, d_ff, vocab.
"""
import torch
from model_v18 import BitLMv18
from model_fp32 import FP32LM


@torch.no_grad()
def init_binary_from_fp32(student: BitLMv18, teacher: FP32LM, mode='sign_rescale',
                          latent_scale=0.02):
    """Overwrite the student's latent weights in place from the teacher's weights.

    Tensors written on the student: ``embed.weight``, ``out_codebook`` and, per
    block, ``attn.{q,k,v,o}_proj.raw.weight`` and ``ffn.{gate,up,down}.raw.weight``.
    Tensors read from the teacher: ``embed.weight`` and, per block,
    ``a.qkv.weight`` (fused, split row-wise into q/k/v), ``a.o.weight`` and
    ``f.{g,u,d}.weight``.

    Args:
        student: model whose latent weights are overwritten.
        teacher: model supplying the source weights.
        mode:
            'raw'           — copy teacher weights verbatim (tried in v36; failed
                              because teacher magnitudes are too large → weights
                              frozen).
            'sign_rescale'  — copy sign(teacher) and rescale to `latent_scale`.
                              Student starts with teacher's sign pattern BUT has
                              flippable latent magnitudes. Preferred.
        latent_scale: magnitude given to every weight in 'sign_rescale' mode;
            unused in 'raw' mode.

    Returns:
        None. The student is modified in place.

    Raises:
        ValueError: if `mode` is unknown, if d_model / number of blocks /
            vocab_size differ between student and teacher, or if a teacher
            block's fused qkv weight does not have 3x as many rows as columns.
    """
    # Reject a bad mode before doing any work.
    if mode not in ('raw', 'sign_rescale'):
        raise ValueError(mode)

    # Explicit raises rather than `assert`: assert statements are removed when
    # Python runs with -O, which would let mismatched models through silently.
    def require_equal(label, student_value, teacher_value):
        if student_value != teacher_value:
            raise ValueError(
                f'{label} mismatch: student={student_value}, teacher={teacher_value}'
            )

    require_equal('d_model', student.d_model, teacher.d_model)
    require_equal('n_layers', len(student.blocks), len(teacher.blocks))
    require_equal('vocab_size', student.vocab_size, teacher.vocab_size)

    # Validate every fused qkv layout before touching the student, so a bad
    # teacher block cannot leave the student half-initialised.
    for idx, t_blk in enumerate(teacher.blocks):
        rows, cols = t_blk.a.qkv.weight.shape
        if rows != 3 * cols:
            raise ValueError(
                f'block {idx}: fused qkv weight has shape ({rows}, {cols}); '
                f'expected ({3 * cols}, {cols})'
            )

    def convert(w):
        if mode == 'raw':
            return w.clone()
        signs = torch.sign(w)
        # torch.sign maps 0 to 0; force those entries to +1 so every entry of
        # the result is exactly +latent_scale or -latent_scale.
        signs[signs == 0] = 1
        return signs * latent_scale

    student.embed.weight.copy_(convert(teacher.embed.weight))
    # NOTE(review): the output codebook is seeded from the teacher's *embedding*
    # matrix, not from a separate output head. That is equivalent only if the
    # teacher ties its head to its embedding — confirm against FP32LM.
    student.out_codebook.copy_(convert(teacher.embed.weight))

    for b_blk, t_blk in zip(student.blocks, teacher.blocks):
        qkv = t_blk.a.qkv.weight
        d = qkv.shape[1]
        # Rows [0, d) are q, [d, 2d) are k, [2d, 3d) are v.
        q_w, k_w, v_w = qkv.split(d, dim=0)
        b_blk.attn.q_proj.raw.weight.copy_(convert(q_w))
        b_blk.attn.k_proj.raw.weight.copy_(convert(k_w))
        b_blk.attn.v_proj.raw.weight.copy_(convert(v_w))
        b_blk.attn.o_proj.raw.weight.copy_(convert(t_blk.a.o.weight))

        b_blk.ffn.gate.raw.weight.copy_(convert(t_blk.f.g.weight))
        b_blk.ffn.up.raw.weight.copy_(convert(t_blk.f.u.weight))
        b_blk.ffn.down.raw.weight.copy_(convert(t_blk.f.d.weight))


def load_teacher(path, device='cuda'):
    """Rebuild the FP32LM teacher saved in a checkpoint and return it in eval mode.

    The checkpoint is expected to hold the constructor hyper-parameters under
    ``'args'`` and the state dict under ``'model'``.

    Args:
        path: checkpoint location handed to ``torch.load``.
        device: used both as ``map_location`` for loading and as the device the
            model is moved to.

    Returns:
        A ``(teacher, checkpoint)`` tuple: the loaded model and the full
        checkpoint object as returned by ``torch.load``.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only point this at checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hparams = checkpoint['args']

    # The checkpoint stores the context length as 'seq_len'; the model
    # constructor takes it as max_seq_len.
    model_kwargs = {
        'vocab_size': hparams['vocab_size'],
        'd_model': hparams['d_model'],
        'n_layers': hparams['n_layers'],
        'n_heads': hparams['n_heads'],
        'd_ff': hparams['d_ff'],
        'max_seq_len': hparams['seq_len'],
    }
    teacher = FP32LM(**model_kwargs).to(device)

    teacher.load_state_dict(checkpoint['model'])
    teacher.eval()
    return teacher, checkpoint