Update convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
Browse files
convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# Copyright
|
| 2 |
#
|
| 3 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
# you may not use this file except in compliance with the License.
|
|
@@ -23,9 +23,6 @@ from transformers.modeling_utils import dtype_byte_size
|
|
| 23 |
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
| 24 |
|
| 25 |
|
| 26 |
-
# 'encoder.layers.7.moe_layer.experts.0.fc2.bias', 'encoder.layers.11.moe_layer.experts.0.fc1.weight',
|
| 27 |
-
|
| 28 |
-
|
| 29 |
def remove_ignore_keys_(state_dict):
|
| 30 |
ignore_keys = [
|
| 31 |
"encoder.version",
|
|
@@ -48,30 +45,30 @@ def make_linear_from_emb(emb):
|
|
| 48 |
return lin_layer
|
| 49 |
|
| 50 |
|
| 51 |
-
def rename_fairseq_keys(state_dict, expert_idx
|
| 52 |
-
# 'encoder.layers.7.moe_layer.experts.0.fc2.bias' ->'encoder.layers.7.ffn.mlp.experts.0.fc2.bias'
|
| 53 |
-
# 'encoder.layers.7.fc2.bias' -> 'encoder.layers.7.ffn.mlp.fc2.bias'
|
| 54 |
-
# encoder.layers.7.wg -> encoder.layers.7.ffn.mlp.router.classifier
|
| 55 |
new_dict = {}
|
| 56 |
for old_key in state_dict.keys():
|
| 57 |
key = old_key
|
| 58 |
if "experts" in key:
|
| 59 |
-
key = key.replace("moe_layer.experts.0", f"ffn.mlp.experts.{expert_idx}")
|
| 60 |
-
elif "fc2" :
|
| 61 |
-
key = key.replace(".fc2.", ".ffn.mlp.fc2")
|
| 62 |
-
elif "fc1" :
|
| 63 |
-
key = key.replace(".fc1.", ".ffn.mlp.fc1")
|
| 64 |
elif "gate" in key:
|
| 65 |
key = key.replace(".moe_layer.gate.wg", ".ffn.mlp.router.classifier")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
new_dict[key] = state_dict[old_key]
|
| 67 |
return new_dict
|
| 68 |
|
| 69 |
|
| 70 |
-
def shard_on_the_fly(
|
| 71 |
-
switch_checkpoint_path, dump_path, num_experts, dtype, weights_name: str = WEIGHTS_NAME
|
| 72 |
-
):
|
| 73 |
sharded_state_dicts = []
|
| 74 |
-
current_block = {}
|
| 75 |
total_size = 0
|
| 76 |
os.makedirs(dump_path, exist_ok=True)
|
| 77 |
|
|
@@ -105,7 +102,6 @@ def shard_on_the_fly(
|
|
| 105 |
|
| 106 |
# Otherwise, let's build the index
|
| 107 |
weight_map = {}
|
| 108 |
-
shards = {}
|
| 109 |
for idx, shard in enumerate(sharded_state_dicts):
|
| 110 |
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
|
| 111 |
temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx+1:05d}-of-???.bin"))
|
|
@@ -143,23 +139,17 @@ if __name__ == "__main__":
|
|
| 143 |
help="Path to the output pytorch model.",
|
| 144 |
)
|
| 145 |
args = parser.parse_args()
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
|
| 154 |
config = NllbMoeConfig.from_pretrained(
|
| 155 |
-
"facebook/nllb-200-3.3B",
|
| 156 |
-
num_sparse_encoder_layers=4,
|
| 157 |
-
num_sparse_decoder_layers=4,
|
| 158 |
)
|
| 159 |
config.save_pretrained(args.pytorch_dump_folder_path)
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
model = NllbMoeModel(config)
|
| 163 |
model.save_pretrained(args.pytorch_dump_folder_path)
|
| 164 |
-
# model.push_to_hub("ArthurZ/nllb-moe-54b", use_auth_token="")
|
| 165 |
-
# model.save_pretrained(args.pytorch_dump_folder_path)
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
#
|
| 3 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
# you may not use this file except in compliance with the License.
|
|
|
|
| 23 |
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
| 26 |
def remove_ignore_keys_(state_dict):
|
| 27 |
ignore_keys = [
|
| 28 |
"encoder.version",
|
|
|
|
| 45 |
return lin_layer
|
| 46 |
|
| 47 |
|
| 48 |
+
def rename_fairseq_keys(state_dict, expert_idx=None):
|
|
|
|
|
|
|
|
|
|
| 49 |
new_dict = {}
|
| 50 |
for old_key in state_dict.keys():
|
| 51 |
key = old_key
|
| 52 |
if "experts" in key:
|
| 53 |
+
key = key.replace("moe_layer.experts.0", f"ffn.mlp.experts.expert_{expert_idx}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
elif "gate" in key:
|
| 55 |
key = key.replace(".moe_layer.gate.wg", ".ffn.mlp.router.classifier")
|
| 56 |
+
if "fc2" and "experts" not in key:
|
| 57 |
+
key = key.replace(".fc2.", ".ffn.mlp.fc2.")
|
| 58 |
+
if "fc1" and "experts" not in key:
|
| 59 |
+
key = key.replace(".fc1.", ".ffn.mlp.fc1.")
|
| 60 |
+
if ".encoder_attn." in key:
|
| 61 |
+
key = key.replace(".encoder_attn.", ".cross_attention.")
|
| 62 |
+
if "encoder_attn_layer_norm" in key:
|
| 63 |
+
key = key.replace("encoder_attn_layer_norm", "cross_attention_layer_norm")
|
| 64 |
+
if "final_layer_norm" in key:
|
| 65 |
+
key = key.replace("final_layer_norm", "ffn.layer_norm")
|
| 66 |
new_dict[key] = state_dict[old_key]
|
| 67 |
return new_dict
|
| 68 |
|
| 69 |
|
| 70 |
+
def shard_on_the_fly(switch_checkpoint_path, dump_path, num_experts, dtype, weights_name: str = WEIGHTS_NAME):
|
|
|
|
|
|
|
| 71 |
sharded_state_dicts = []
|
|
|
|
| 72 |
total_size = 0
|
| 73 |
os.makedirs(dump_path, exist_ok=True)
|
| 74 |
|
|
|
|
| 102 |
|
| 103 |
# Otherwise, let's build the index
|
| 104 |
weight_map = {}
|
|
|
|
| 105 |
for idx, shard in enumerate(sharded_state_dicts):
|
| 106 |
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
|
| 107 |
temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx+1:05d}-of-???.bin"))
|
|
|
|
| 139 |
help="Path to the output pytorch model.",
|
| 140 |
)
|
| 141 |
args = parser.parse_args()
|
| 142 |
+
metadata, index = shard_on_the_fly(
|
| 143 |
+
args.nllb_moe_checkpoint_path,
|
| 144 |
+
args.pytorch_dump_folder_path,
|
| 145 |
+
128,
|
| 146 |
+
args.dtype,
|
| 147 |
+
)
|
|
|
|
| 148 |
|
| 149 |
config = NllbMoeConfig.from_pretrained(
|
| 150 |
+
"facebook/nllb-200-3.3B", encoder_sparse_step=4, decoder_sparse_step=4, num_experts=128
|
|
|
|
|
|
|
| 151 |
)
|
| 152 |
config.save_pretrained(args.pytorch_dump_folder_path)
|
| 153 |
+
model = NllbMoeModel.from_pretrained(args.pytorch_dump_folder_path)
|
| 154 |
+
print("Done")
|
|
|
|
| 155 |
model.save_pretrained(args.pytorch_dump_folder_path)
|
|
|
|
|
|