File size: 3,270 Bytes
e45d058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import re
from pathlib import Path

import torch
import math
from einops import rearrange

def load_checkpoint(path, device='cpu'):
    """Load model weights from a ``torch.save`` file or a DeepSpeed checkpoint directory.

    If ``path`` (after ``~`` expansion) is a directory, it is treated as a DeepSpeed
    checkpoint: the tag stored in ``<path>/latest`` selects
    ``<path>/<tag>/mp_rank_00_model_states.pt``, its ``'module'`` entry is taken, and a
    leading ``module.model.`` is stripped from every key. Otherwise ``path`` is passed
    straight to ``torch.load``.

    Args:
        path: File or DeepSpeed checkpoint directory (``str`` or ``Path``).
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The loaded object (for DeepSpeed, the renamed ``'module'`` dict).

    Raises:
        ValueError: ``path`` is a directory without a ``latest`` file.
    """
    path = Path(path).expanduser()
    is_deepspeed = path.is_dir()  # DeepSpeed checkpoints are directories
    if is_deepspeed:
        latest_path = path / 'latest'
        if not latest_path.is_file():
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()
        path /= f'{tag}/mp_rank_00_model_states.pt'
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(path, map_location=device)
    if is_deepspeed:
        # Dots escaped so only the literal 'module.model.' prefix is removed.
        prefix = re.compile(r'^module\.model\.')
        state_dict = {prefix.sub('', k): v for k, v in state_dict['module'].items()}
    return state_dict


def blockdiag_to_dense_mlp_bert(state_dict):
    """Convert the block-diagonal MLP weights of a BERT encoder to dense weights, in place.

    Every key that is exactly ``bert.encoder.layer.<i>.mlp.fc1.weight``,
    ``...mlp.fc2.weight``, ``...intermediate.dense.weight`` or ``...output.dense.weight``
    (where the component directly follows ``layer.<i>.``, so e.g.
    ``attention.output.dense`` is not matched) is replaced by
    ``blockdiag_weight_to_dense_weight(<old value>)``. All other entries are untouched.

    Args:
        state_dict: Mapping from parameter name to tensor; modified in place.

    Returns:
        The same ``state_dict`` object.
    """
    from src.ops.blockdiag_multiply import blockdiag_weight_to_dense_weight

    # Raw string so `\d` is not an (invalid) string escape; literal dots are escaped and
    # fullmatch requires the name to end at `.weight` (no `.weight_scale` etc.).
    pattern = re.compile(
        r'bert\.encoder\.layer\.(\d+)\.(mlp\.fc(1|2)|(intermediate|output)\.dense)\.weight'
    )
    # Collect the matching names first, then rewrite their values.
    names = [name for name in state_dict if pattern.fullmatch(name)]
    for name in names:
        state_dict[name] = blockdiag_weight_to_dense_weight(state_dict[name])
    return state_dict

def interpolate_pos_embedding(state_dict, out_seqlen, pos_embedding_name='model.pos_encoder.pe', interleave=False):
    """Tile a checkpoint's positional embedding up to ``out_seqlen`` positions.

    The embedding ``state_dict['state_dict'][pos_embedding_name]`` is grown along dim 1
    by the integer factor ``out_seqlen // seqlen``:

    * ``interleave=False``: the whole sequence is repeated end to end (``Tensor.repeat``).
    * ``interleave=True``: positions are viewed as a square ``sqrt(seqlen) x sqrt(seqlen)``
      grid and every grid cell is repeated ``sqrt(factor)`` times along both grid axes.
      This branch requires a 3-D ``(b, seqlen, d)`` tensor, which the ``rearrange``
      pattern enforces.

    ``state_dict['state_dict']`` is modified in place. The ``'model.'`` prefix is then
    stripped via ``remove_model_prefix``, and entries ending in ``'inner_attn.layout'``
    are dropped from the returned dict.

    Args:
        state_dict: Checkpoint dict whose ``'state_dict'`` entry holds the weights.
        out_seqlen: Target length; must be a multiple of the original length.
        pos_embedding_name: Key of the positional embedding inside ``'state_dict'``.
        interleave: Use the 2-D grid expansion instead of plain tiling.

    Returns:
        A new dict of weights with prefixes stripped and layout buffers removed.

    Raises:
        AssertionError: ``out_seqlen`` is not a multiple of the original length, or (with
            ``interleave``) a length or the factor is not a perfect square.
    """
    weights = state_dict['state_dict']
    orig_emb = weights[pos_embedding_name]
    seqlen = orig_emb.shape[1]
    assert (out_seqlen % seqlen) == 0, 'out_seqlen must be a multiple of the original sequence length'
    factor = out_seqlen // seqlen

    if interleave:
        side = math.isqrt(seqlen)
        side_factor = math.isqrt(factor)
        assert side ** 2 == seqlen, 'interleave only works for square lengths'
        assert math.isqrt(out_seqlen) ** 2 == out_seqlen, 'interleave only works for square lengths'
        assert side_factor ** 2 == factor, 'out_seqlen / seqlen must be a perfect square'

        emb_square = rearrange(orig_emb, 'b (h w) d -> b h w d', h=side)
        emb_square = emb_square.repeat_interleave(side_factor, dim=1).repeat_interleave(side_factor, dim=2)
        weights[pos_embedding_name] = rearrange(emb_square, 'b h w d -> b (h w) d')
    else:
        reps = [1] * orig_emb.dim()
        reps[1] = factor
        weights[pos_embedding_name] = orig_emb.repeat(*reps)

    ret = remove_model_prefix(state_dict)
    # HACK: drop 'inner_attn.layout' entries (original note: needed for block-sparse
    # flash attention; presumably the model rebuilds these itself — confirm).
    return {k: v for k, v in ret.items() if not k.endswith('inner_attn.layout')}

def remove_model_prefix(state_dict):
    """Strip one leading ``'model.'`` from every key of ``state_dict['state_dict']``.

    The inner dict is edited in place. Each prefixed entry is popped and re-inserted
    under its shortened name, so it moves to the end of the dict's order. If the
    shortened name already exists, it is overwritten. Only one ``'model.'`` is removed
    per key.

    Args:
        state_dict: Checkpoint dict with a ``'state_dict'`` entry holding the weights.

    Returns:
        The inner ``state_dict['state_dict']`` dict itself (not a copy).
    """
    # HACK: drop the 'model.' prefix so the weights load into the model properly.
    prefix = 'model.'
    weights = state_dict['state_dict']
    # Snapshot the keys first because entries are popped and re-added while walking.
    for old_key in tuple(weights):
        if not old_key.startswith(prefix):
            continue
        weights[old_key[len(prefix):]] = weights.pop(old_key)
    return weights