| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Dict |
| |
|
| | import torch |
| |
|
| |
|
class AttnProcsLayers(torch.nn.Module):
    """Expose a dict of attention-processor modules as a single ``torch.nn.Module``.

    The processors are held in ``self.layers`` (a ``torch.nn.ModuleList``), so
    their parameters are natively named ``layers.<idx>.<param>``. Two hooks
    translate those names to and from the dict keys given to the constructor:

    * ``state_dict()`` emits ``<name>.<param>`` instead of ``layers.<idx>.<param>``;
    * ``load_state_dict()`` accepts ``<name>.<param>`` and maps it back.

    Both hooks only touch keys under the ``prefix`` torch passes to them, so the
    module also round-trips when it is registered inside another module.

    Args:
        state_dict: Mapping from processor name to processor module; insertion
            order defines each processor's index in ``self.layers``. For
            loading, every name must contain one of ``self.split_keys``.
    """

    def __init__(self, state_dict: Dict[str, torch.nn.Module]):
        super().__init__()
        self.layers = torch.nn.ModuleList(state_dict.values())
        # Position in ``self.layers`` -> processor name, and the inverse.
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # Markers that terminate a processor name inside a full parameter key;
        # on load they are used to recover "<name>" from "<name><marker>.<param>".
        self.split_keys = [".processor", ".self_attn"]

        def map_to(module, state_dict, prefix="", *args, **kwargs):
            """State-dict hook: rename ``<prefix>layers.<idx>.<rest>`` to
            ``<prefix><name>.<rest>`` in place.

            The dict is mutated and ``None`` returned (instead of returning a
            new dict) because torch discards a child module's return value, and
            a fresh ``dict`` would also lose the ``_metadata`` attribute.
            """
            layers_prefix = prefix + "layers."
            renamed = {}
            for key in state_dict:
                if not key.startswith(layers_prefix):
                    continue  # entry owned by a parent or sibling module
                num, sep, rest = key[len(layers_prefix):].partition(".")
                renamed[key] = prefix + module.mapping[int(num)] + sep + rest
            # Pop every old key before inserting any new one so that a new name
            # can never overwrite an entry that has not been renamed yet.
            moved = {new: state_dict.pop(old) for old, new in renamed.items()}
            state_dict.update(moved)

        def remap_key(module, key, state_dict):
            """Return the leading ``<name><split_key>`` portion of ``key``.

            Raises:
                ValueError: if ``key`` contains none of ``module.split_keys``.
            """
            for k in module.split_keys:
                if k in key:
                    return key.split(k)[0] + k

            raise ValueError(
                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {module.split_keys}."
            )

        def map_from(module, state_dict, prefix="", *args, **kwargs):
            """Load-state-dict pre-hook: rename ``<prefix><name>.<rest>`` back to
            ``<prefix>layers.<idx>.<rest>`` in place.

            All new names are computed before the dict is modified, so an error
            leaves ``state_dict`` untouched.

            Raises:
                ValueError: if a key under ``prefix`` has no split key.
                KeyError: if the recovered processor name is not in ``rev_mapping``.
            """
            renamed = {}
            for key in state_dict:
                if not key.startswith(prefix):
                    continue  # entry owned by another module
                local_key = key[len(prefix):]
                name = remap_key(module, local_key, state_dict)
                idx = module.rev_mapping[name]
                renamed[key] = f"{prefix}layers.{idx}{local_key[len(name):]}"
            moved = {new: state_dict.pop(old) for old, new in renamed.items()}
            state_dict.update(moved)

        # Hooks keep checkpoint names in the constructor's naming scheme rather
        # than the internal ``layers.<idx>`` scheme.
        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)
| |
|