# Respair's picture
# Upload folder using huggingface_hub
# b386992 verified
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from lightning.pytorch import Trainer
from omegaconf import OmegaConf, open_dict
from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel
from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model
try:
from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel
except ModuleNotFoundError:
from abc import ABC
MegatronNMTModel = ABC
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.core import ModelPT
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.app_state import AppState
from nemo.utils.model_utils import inject_model_parallel_rank
def get_model_class(cfg):
    """Pick the Megatron model class that corresponds to ``cfg.model_type``.

    Args:
        cfg: Config object exposing a ``model_type`` attribute. Recognised
            values are ``'gpt'``, ``'bert'``, ``'t5'``, ``'bart'``, ``'nmt'``
            and ``'retro'``.

    Returns:
        The model class bound to the matching module-level name. For
        ``'nmt'`` this is whatever ``MegatronNMTModel`` is bound to at module
        level (the import fallback binds it to ``abc.ABC`` when the NMT
        module cannot be imported).

    Raises:
        ValueError: If ``cfg.model_type`` is none of the recognised values.
    """
    model_type = cfg.model_type

    # Guard clauses: a class name is only looked up when its branch is taken.
    if model_type == 'gpt':
        return MegatronGPTModel
    if model_type == 'bert':
        return MegatronBertModel
    if model_type == 't5':
        return MegatronT5Model
    if model_type == 'bart':
        return MegatronBARTModel
    if model_type == 'nmt':
        return MegatronNMTModel
    if model_type == 'retro':
        return MegatronRetrievalModel

    raise ValueError("Invalid Model Type")
def _resolve_input_path(cfg):
    """Return the model path named by ``cfg``, or ``None`` if it names none.

    ``cfg.gpt_model_file`` takes precedence; otherwise the path is built from
    ``cfg.checkpoint_dir`` and ``cfg.checkpoint_name``.
    """
    if cfg.gpt_model_file:
        return cfg.gpt_model_file
    if cfg.checkpoint_dir:
        return os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name)
    return None


def _restore_from_nemo_file(cfg, trainer):
    """Restore a model from ``cfg.gpt_model_file`` with overridden config values.

    The stored config is fetched first (``return_config=True``), a handful of
    fields are overwritten, and the model is then restored using that modified
    config via ``override_config_path``.
    """
    save_restore_connector = NLPSaveRestoreConnector()
    # A directory path is handed to the connector as an already-extracted model.
    if os.path.isdir(cfg.gpt_model_file):
        save_restore_connector.model_extracted_dir = cfg.gpt_model_file

    pretrained_cfg = ModelPT.restore_from(
        restore_path=cfg.gpt_model_file,
        trainer=trainer,
        return_config=True,
        save_restore_connector=save_restore_connector,
    )

    OmegaConf.set_struct(pretrained_cfg, True)
    # open_dict lets keys be assigned even though struct mode was just enabled.
    with open_dict(pretrained_cfg):
        pretrained_cfg.sequence_parallel = False
        pretrained_cfg.activations_checkpoint_granularity = None
        pretrained_cfg.activations_checkpoint_method = None
        pretrained_cfg.precision = trainer.precision
        # NOTE(review): this only matches the exact string "16". If the
        # installed Lightning version reports precision in another form
        # (e.g. "16-mixed"), megatron_amp_O2 is left untouched — confirm.
        if trainer.precision == "16":
            pretrained_cfg.megatron_amp_O2 = False

    return ModelPT.restore_from(
        restore_path=cfg.gpt_model_file,
        trainer=trainer,
        override_config_path=pretrained_cfg,
        save_restore_connector=save_restore_connector,
    )


def _restore_from_checkpoint(cfg, trainer):
    """Restore a model from ``cfg.checkpoint_dir`` / ``cfg.checkpoint_name``.

    When either parallel size in ``cfg`` is greater than one, the ``AppState``
    model-parallel fields are filled in from ``fake_initialize_model_parallel``
    before the checkpoint path is computed.
    """
    app_state = AppState()
    if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1:
        app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
        app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size
        app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size
        (
            app_state.tensor_model_parallel_rank,
            app_state.pipeline_model_parallel_rank,
            app_state.model_parallel_size,
            app_state.data_parallel_size,
            app_state.pipeline_model_parallel_split_rank,
            app_state.virtual_pipeline_model_parallel_rank,
        ) = fake_initialize_model_parallel(
            world_size=app_state.model_parallel_size,
            rank=trainer.global_rank,
            tensor_model_parallel_size_=cfg.tensor_model_parallel_size,
            pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank,
        )

    checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name))
    model_cls = get_model_class(cfg)
    return model_cls.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer)


@hydra_runner(config_path="conf", config_name="megatron_gpt_export")
def nemo_export(cfg):
    """Restore a NeMo Megatron model and export it to ONNX.

    The model is taken from ``cfg.gpt_model_file`` when set, otherwise from
    ``cfg.checkpoint_dir`` / ``cfg.checkpoint_name``. The restored model is
    moved to ``cfg.export_options.device``, frozen, put in eval mode and
    exported to ``cfg.onnx_model_file``.

    Args:
        cfg: Hydra config. Fields read here: ``gpt_model_file``,
            ``checkpoint_dir``, ``checkpoint_name``, ``hparams_file``,
            ``model_type``, ``onnx_model_file``, ``trainer``,
            ``tensor_model_parallel_size``, ``pipeline_model_parallel_size``,
            ``pipeline_model_parallel_split_rank`` and ``export_options``.

    Raises:
        AssertionError: If no input model is configured, or if
            ``trainer.devices * trainer.num_nodes`` differs from
            ``tensor_model_parallel_size * pipeline_model_parallel_size``.
        Exception: Any error raised while restoring or exporting is logged
            and re-raised unchanged.
    """
    nemo_in = _resolve_input_path(cfg)
    assert nemo_in is not None, "NeMo model not provided. Please provide the path to the .nemo or .ckpt file"

    onnx_out = cfg.onnx_model_file

    trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
    assert (
        cfg.trainer.devices * cfg.trainer.num_nodes
        == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
    ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"

    logging.info("Restoring NeMo model from '{}'".format(nemo_in))
    try:
        if cfg.gpt_model_file:
            model = _restore_from_nemo_file(cfg, trainer)
        elif cfg.checkpoint_dir:
            model = _restore_from_checkpoint(cfg, trainer)
        else:
            raise ValueError("need at least a nemo file or checkpoint dir")
    except Exception:
        logging.error(
            "Failed to restore model from NeMo file : {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format(
                nemo_in
            )
        )
        # Bare raise re-raises the active exception as-is.
        raise

    logging.info("Model {} restored from '{}'".format(model.__class__.__name__, nemo_in))

    # Export
    check_trace = cfg.export_options.runtime_check

    try:
        model.to(device=cfg.export_options.device).freeze()
        model.eval()
        model.export(
            onnx_out,
            onnx_opset_version=cfg.export_options.onnx_opset,
            do_constant_folding=cfg.export_options.do_constant_folding,
            # Axes 0 and 1 of each listed tensor are declared dynamic and
            # labelled "sequence" and "batch" respectively.
            dynamic_axes={
                'input_ids': {0: "sequence", 1: "batch"},
                'position_ids': {0: "sequence", 1: "batch"},
                'logits': {0: "sequence", 1: "batch"},
            },
            check_trace=check_trace,
            check_tolerance=cfg.export_options.check_tolerance,
            verbose=cfg.export_options.verbose,
        )
    except Exception:
        logging.error(
            "Export failed. Please make sure your NeMo model class ({}) has working export() and that you have the latest NeMo package installed with [all] dependencies.".format(
                model.__class__
            )
        )
        raise
# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    nemo_export()