#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import re
from collections import OrderedDict
import torch
from fairseq.file_io import PathManager
def is_update(param_name, module_name):
    """Return True if *module_name* occurs as a substring of *param_name*.

    Note: this is a plain substring test, not a prefix or dotted-path
    match, so e.g. ``is_update("text_encoder.w", "encoder")`` is True.
    """
    # Idiomatic form of the original `if ...: return True / return False`.
    return module_name in param_name
def load_checkpoint(src_cpt):
    """Read a checkpoint file and return its state dict, remapped to CPU.

    Uses PathManager so the path may live on any supported filesystem;
    every storage is restored to CPU regardless of where it was saved.
    """
    def _to_cpu(storage, _location):
        return torch.serialization.default_restore_location(storage, "cpu")

    with PathManager.open(src_cpt, "rb") as f:
        return torch.load(f, map_location=_to_cpu)
def save_checkpoint(tgt_cpt, states):
    """Serialize *states* to *tgt_cpt* via PathManager + torch.save."""
    with PathManager.open(tgt_cpt, "wb") as f:
        torch.save(states, f)
# convert the pre-trained model into bart model
def main():
    """Convert a pre-trained checkpoint into a BART-style checkpoint.

    Parameters named ``encoder.text_encoder.*`` / ``decoder.text_decoder.*``
    are renamed to ``encoder.*`` / ``decoder.*``; all other parameters are
    dropped from the output model.
    """
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument('--input-model', required=True,
                        help='Input checkpoint file path.')
    parser.add_argument('--output-model', required=True,
                        help='output checkpoint file path.')
    # fmt: on
    args = parser.parse_args()
    print(args)

    states = load_checkpoint(args.input_model)
    renamed = OrderedDict()
    for name, value in states["model"].items():
        # Only keys under the text encoder/decoder sub-modules survive.
        if re.search("^encoder.text_encoder", name):
            renamed[re.sub("encoder.text_encoder", "encoder", name)] = value
        elif re.search("^decoder.text_decoder", name):
            renamed[re.sub("decoder.text_decoder", "decoder", name)] = value
    states["model"] = renamed
    save_checkpoint(args.output_model, states)
if __name__ == "__main__":
main()
|