| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
"""
Converts BERT NeMo0.* checkpoints to NeMo1.0 format.

The script merges up to three separately saved NeMo0.* state dicts (encoder,
token classifier and, optionally, sequence classifier) into one dictionary of
the form ``{"state_dict": {...}}`` and writes it with ``torch.save``.

Example:
    python <this_script>.py \
        --bert_encoder /path/BERT-STEP-2285714.pt \
        --bert_token_classifier /path/BertTokenClassifier-STEP-2285714.pt \
        --output_path converted_model.pt
"""

from argparse import ArgumentParser

import torch


def build_parser():
    """Build the command-line parser for the conversion script.

    Returns:
        ArgumentParser: parser exposing ``--bert_encoder``,
        ``--bert_token_classifier``, ``--bert_sequence_classifier`` and
        ``--output_path``.
    """
    parser = ArgumentParser()
    parser.add_argument("--bert_encoder", required=True, help="path to BERT encoder, e.g. /../BERT-STEP-2285714.pt")
    parser.add_argument(
        "--bert_token_classifier",
        required=True,
        help="path to BERT token classifier, e.g. /../BertTokenClassifier-STEP-2285714.pt",
    )
    parser.add_argument(
        "--bert_sequence_classifier",
        required=False,
        default=None,
        help="path to BERT sequence classifier, e.g /../SequenceClassifier-STEP-2285714.pt",
    )
    parser.add_argument(
        "--output_path", required=False, default="converted_model.pt", help="output path to newly converted model"
    )
    return parser


def convert_checkpoints(bert_in, tok_in, seq_in=None):
    """Merge NeMo0.* state dicts into a single NeMo1.0-style checkpoint dict.

    Args:
        bert_in: mapping of encoder parameter names to values. Every
            occurrence of ``"bert."`` in a name is rewritten to
            ``"bert_model."``.
        tok_in: mapping of token-classifier parameter names to values. Names
            are prefixed with ``"mlm_classifier."``.
        seq_in: optional mapping of sequence-classifier parameter names to
            values. Names are prefixed with ``"nsp_classifier."``. Skipped
            when ``None``.

    Returns:
        dict: ``{"state_dict": merged}`` where ``merged`` holds the renamed
        entries in encoder, token-classifier, sequence-classifier order. If
        two inputs produce the same new name, the later one wins.
    """
    new_dict = {name.replace("bert.", "bert_model."): value for name, value in bert_in.items()}
    new_dict.update(("mlm_classifier." + name, value) for name, value in tok_in.items())
    if seq_in is not None:
        new_dict.update(("nsp_classifier." + name, value) for name, value in seq_in.items())
    return {"state_dict": new_dict}


def _load_checkpoint(path):
    """Load a checkpoint onto CPU so conversion also works without a GPU."""
    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    return torch.load(path, map_location="cpu")


def main(argv=None):
    """Run the conversion from the command line.

    Args:
        argv: optional list of argument strings; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    bert_in = _load_checkpoint(args.bert_encoder)
    tok_in = _load_checkpoint(args.bert_token_classifier)
    # The sequence classifier is optional; an unset/empty path skips it.
    seq_in = _load_checkpoint(args.bert_sequence_classifier) if args.bert_sequence_classifier else None

    torch.save(convert_checkpoints(bert_in, tok_in, seq_in), args.output_path)


if __name__ == "__main__":
    main()
| |
|