Spaces:
Runtime error
Runtime error
| VOCAB_SIZE = 32000 | |
| import torch | |
| from scipy.optimize import linear_sum_assignment as LAP | |
| from ..utils import pdists | |
| from .model import permute_model, permute_transformer_block | |
def match_wmats(wmat0, wmat1):
    """Return the row permutation of ``wmat1`` that best matches ``wmat0``.

    Solves a linear assignment problem over the pairwise-distance matrix
    (cast to float64 for a numerically stable solve), so that
    ``wmat1[perm]`` is the closest row-for-row match to ``wmat0``.
    """
    cost = pdists(wmat0, wmat1).type(torch.float64)
    # linear_sum_assignment returns (row_ind, col_ind); with rows taken in
    # order, col_ind is exactly the permutation we want.
    _, perm = LAP(cost)
    return perm
def match_mlp(base_model, ft_model, i=0):
    """Permutation aligning layer ``i``'s MLP gate_proj rows of ``ft_model``
    to those of ``base_model`` (see ``match_wmats``)."""
    key = f"model.layers.{i}.mlp.gate_proj.weight"
    base_wmat = base_model.state_dict()[key]
    ft_wmat = ft_model.state_dict()[key]
    return match_wmats(base_wmat, ft_wmat)
def match_emb(base_model, ft_model, i="inp"):
    """Permutation aligning an embedding matrix of ``ft_model`` to ``base_model``.

    Args:
        base_model: reference model.
        ft_model: model whose embedding rows should be permuted to match.
        i: ``"inp"`` to compare the input embedding
           (``model.embed_tokens.weight``), ``"out"`` to compare the output
           head (``lm_head.weight``).

    Returns:
        The permutation from ``match_wmats`` over the feature dimension.

    Raises:
        ValueError: if ``i`` is neither ``"inp"`` nor ``"out"``.
    """
    if i == "inp":
        weight_id = "model.embed_tokens.weight"
    elif i == "out":
        weight_id = "lm_head.weight"
    else:
        # Previously an unrecognized `i` fell through with `weight_id` unbound,
        # dying with an opaque NameError; fail loudly and clearly instead.
        raise ValueError(f"i must be 'inp' or 'out', got {i!r}")
    # Embedding weights are (vocab, dim); slice to VOCAB_SIZE (presumably to
    # drop padded/extra tokens — confirm against tokenizer) and transpose so
    # rows are feature dimensions, the orientation match_wmats permutes.
    base_wmat = base_model.state_dict()[weight_id][:VOCAB_SIZE].T
    ft_wmat = ft_model.state_dict()[weight_id][:VOCAB_SIZE].T
    return match_wmats(base_wmat, ft_wmat)
def align_model(base_model, ft_model, tmp_model, n_blocks=32, d_mlp=11008, d_model=4096):
    """Permute ``ft_model`` into ``tmp_model`` so its neurons align with ``base_model``.

    First matches the embedding feature dimension, then per transformer block
    matches MLP neurons, writing results into ``tmp_model`` in place.

    Args:
        base_model: reference model.
        ft_model: fine-tuned model to align.
        tmp_model: model receiving the permuted weights (mutated in place).
        n_blocks: number of transformer blocks to process.
        d_mlp: MLP intermediate size (default 11008; previously hard-coded —
            presumably the LLaMA-7B dimension).
        d_model: hidden size (default 4096; previously hard-coded —
            presumably the LLaMA-7B dimension).
    """
    emb_perm = match_emb(base_model, ft_model)
    # Identity over MLP neurons at this stage: only the embedding/feature
    # permutation is applied model-wide here.
    permute_model(ft_model, tmp_model, torch.arange(d_mlp), emb_perm)
    for i in range(n_blocks):
        # Match block i's MLP against the already feature-aligned tmp_model,
        # then apply it with an identity permutation on the hidden dimension.
        mlp_perm = match_mlp(base_model, tmp_model, i=i)
        permute_transformer_block(tmp_model, i, tmp_model, mlp_perm, torch.arange(d_model))