Spaces:
Sleeping
Sleeping
| import re | |
| from pathlib import Path | |
| import torch | |
| import math | |
| from einops import rearrange | |
def load_checkpoint(path, device='cpu'):
    """Load a model state dict from a checkpoint file or a DeepSpeed directory.

    Args:
        path: Path to a checkpoint file, or to a directory. A directory is
            treated as a DeepSpeed checkpoint: it must contain a ``latest``
            file whose (stripped) content names the sub-directory holding
            ``mp_rank_00_model_states.pt``.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        A new dict mapping parameter names to the loaded values, with a
        leading ``module.model.`` removed from every key that has it.

    Raises:
        ValueError: If ``path`` is a directory without a ``latest`` file.
    """
    path = Path(path).expanduser()
    is_deepspeed = path.is_dir()  # DeepSpeed checkpoint
    if is_deepspeed:
        latest_path = path / 'latest'
        if not latest_path.is_file():
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()
        path = path / f'{tag}/mp_rank_00_model_states.pt'

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(path, map_location=device)
    if is_deepspeed:
        state_dict = state_dict['module']

    # Replace the names of some of the submodules. The dots are escaped so
    # that only the literal prefix "module.model." is stripped; unescaped,
    # '.' would match any character (e.g. "moduleXmodelY...").
    prefix_pattern = re.compile(r'^module\.model\.')
    return {prefix_pattern.sub('', key): value for key, value in state_dict.items()}
def blockdiag_to_dense_mlp_bert(state_dict):
    """Convert the block-diagonal MLP weights of a BERT state dict to dense.

    Every key that starts with ``bert.encoder.layer.<N>.`` followed by
    ``mlp.fc1.weight``, ``mlp.fc2.weight``, ``intermediate.dense.weight`` or
    ``output.dense.weight`` has its value replaced, in place, by the result
    of ``blockdiag_weight_to_dense_weight``.

    Args:
        state_dict: Mapping from parameter names to weights; modified in place.

    Returns:
        The same ``state_dict`` object.
    """
    # Raw string with escaped dots: the digit class is a real regex escape
    # (not an invalid string escape), and '.' only matches a literal dot.
    pattern = re.compile(
        r'bert\.encoder\.layer\.(\d+)\.(mlp\.fc(1|2)|(intermediate|output)\.dense)\.weight'
    )
    # Collect the names up front so the dict is not mutated while iterating it.
    names = {name for name in state_dict if pattern.match(name)}
    if names:
        # Imported lazily so a state dict with nothing to convert does not
        # require the block-diagonal op to be importable.
        from src.ops.blockdiag_multiply import blockdiag_weight_to_dense_weight
        for name in names:
            state_dict[name] = blockdiag_weight_to_dense_weight(state_dict[name])
    return state_dict
def interpolate_pos_embedding(state_dict, out_seqlen, pos_embedding_name='model.pos_encoder.pe', interleave=False):
    """Stretch a stored positional embedding to ``out_seqlen`` positions.

    The embedding at ``state_dict['state_dict'][pos_embedding_name]`` is
    replaced in place by a version whose dimension 1 has length
    ``out_seqlen``, which must be a multiple of the original length.

    Args:
        state_dict: Checkpoint dict holding the weights under ``'state_dict'``.
        out_seqlen: Target length of dimension 1 of the embedding.
        pos_embedding_name: Key of the embedding inside the weights dict.
        interleave: If false, the embedding is tiled along dimension 1.
            If true, dimension 1 is viewed as a square grid and every grid
            entry is repeated along both grid axes; this requires the
            original length, ``out_seqlen`` and their ratio to be perfect
            squares.

    Returns:
        The weights dict as returned by ``remove_model_prefix``, without the
        entries whose key ends in ``inner_attn.layout``.
    """
    weights = state_dict['state_dict']
    orig_emb = weights[pos_embedding_name]
    in_seqlen = orig_emb.shape[1]
    assert (out_seqlen % in_seqlen) == 0, 'out_seqlen must be a multiple of the original sequence length'
    factor = out_seqlen // in_seqlen

    if not interleave:
        # Tile only along dimension 1; every other dimension is repeated once.
        reps = [1] * len(orig_emb.shape)
        reps[1] = factor
        weights[pos_embedding_name] = orig_emb.repeat(*reps)
    else:
        side = math.isqrt(in_seqlen)
        assert side ** 2 == in_seqlen, 'interleave only works for square lengths'
        assert math.isqrt(out_seqlen) ** 2 == out_seqlen, 'interleave only works for square lengths'
        scale = math.isqrt(factor)
        assert scale ** 2 == factor, 'out_seqlen / seqlen must be a perfect square'
        # Unflatten to a grid, repeat each entry along both grid axes, flatten back.
        grid = rearrange(orig_emb, 'b (h w) d -> b h w d', h=side)
        grid = grid.repeat_interleave(scale, axis=1)
        grid = grid.repeat_interleave(scale, axis=2)
        weights[pos_embedding_name] = rearrange(grid, 'b h w d -> b (h w) d')

    cleaned = remove_model_prefix(state_dict)
    # HACK: this is a hack for block-sparse flash attention
    return {
        name: value
        for name, value in cleaned.items()
        if not name.endswith('inner_attn.layout')
    }
def remove_model_prefix(state_dict):
    """Strip one leading ``model.`` from the keys of ``state_dict['state_dict']``.

    The inner dict is modified in place: each key that starts with
    ``model.`` is popped and re-inserted under the name without that prefix.

    Args:
        state_dict: Checkpoint dict holding the weights under ``'state_dict'``.

    Returns:
        The (mutated) inner ``state_dict['state_dict']`` dict.
    """
    # HACK: this is a hack to get the model to load properly, get rid of 'model.' prefix
    weights = state_dict['state_dict']
    prefix = 'model.'
    # Iterate over a snapshot of the keys because the dict is edited in the loop.
    for name in tuple(weights):
        if not name.startswith(prefix):
            continue
        weights[name[len(prefix):]] = weights.pop(name)
    # HACK: something is wrong with the state dict being loaded...
    return weights