PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
211c22d verified
raw
history blame
1.87 kB
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import re
from collections import OrderedDict
import torch
from fairseq.file_io import PathManager
def is_update(param_name, module_name):
    """Return whether ``module_name`` occurs as a substring of ``param_name``.

    Args:
        param_name: Full parameter key to inspect.
        module_name: Substring to look for inside ``param_name``.

    Returns:
        ``True`` if ``module_name`` is contained in ``param_name``, else ``False``.
    """
    return module_name in param_name
def load_checkpoint(src_cpt):
    """Read the checkpoint at ``src_cpt`` and return the deserialized object.

    Every storage is restored onto the CPU, so the checkpoint can be loaded
    regardless of the device it was saved from.

    NOTE(review): newer PyTorch releases change the default of ``torch.load``'s
    ``weights_only`` argument; checkpoints that pickle non-tensor objects may
    need an explicit setting there — confirm against the installed version.
    """

    def _restore_on_cpu(storage, _location):
        # Ignore the saved location and always restore onto CPU.
        return torch.serialization.default_restore_location(storage, "cpu")

    with PathManager.open(src_cpt, "rb") as handle:
        checkpoint = torch.load(handle, map_location=_restore_on_cpu)
    return checkpoint
def save_checkpoint(tgt_cpt, states):
    """Serialize ``states`` with ``torch.save`` into the file at ``tgt_cpt``."""
    with PathManager.open(tgt_cpt, "wb") as out_file:
        torch.save(states, out_file)
# Convert a pre-trained model checkpoint into a BART-style model checkpoint.
def main():
    """Rename the text encoder/decoder weights of a checkpoint and save it.

    Reads the checkpoint given by ``--input-model`` and rebuilds its ``model``
    entry. Keys starting with ``encoder.text_encoder`` get that prefix replaced
    by ``encoder``. Keys starting with ``decoder.text_decoder`` get that prefix
    replaced by ``decoder``. Every other ``model`` entry is dropped. All other
    top-level checkpoint entries are kept as-is. The result is written to
    ``--output-model``.
    """
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument('--input-model', required=True,
                        help='Input checkpoint file path.')
    parser.add_argument('--output-model', required=True,
                        help='output checkpoint file path.')
    # fmt: on
    args = parser.parse_args()
    print(args)

    states = load_checkpoint(args.input_model)
    states["model"] = _rename_text_modules(states["model"])
    save_checkpoint(args.output_model, states)


def _rename_text_modules(model):
    """Return a new OrderedDict keeping only text encoder/decoder entries, renamed.

    Only the leading prefix is rewritten, matched literally. The previous
    regex-based version treated ``.`` as a wildcard and replaced every
    occurrence of the pattern anywhere in the key.
    """
    # (old prefix, new prefix), checked in order; a key matches at most one.
    prefix_renames = (
        ("encoder.text_encoder", "encoder"),
        ("decoder.text_decoder", "decoder"),
    )
    new_model = OrderedDict()
    for key, value in model.items():
        for old_prefix, new_prefix in prefix_renames:
            if key.startswith(old_prefix):
                new_model[new_prefix + key[len(old_prefix):]] = value
                break
    return new_model
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()