|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
|
|
from lightning.pytorch import Trainer |
|
|
from omegaconf import OmegaConf, open_dict |
|
|
|
|
|
from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel |
|
|
from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel |
|
|
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel |
|
|
from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel |
|
|
from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model |
|
|
|
|
|
try: |
|
|
from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel |
|
|
except ModuleNotFoundError: |
|
|
from abc import ABC |
|
|
|
|
|
MegatronNMTModel = ABC |
|
|
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel |
|
|
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector |
|
|
from nemo.core import ModelPT |
|
|
from nemo.core.config import hydra_runner |
|
|
from nemo.utils import logging |
|
|
from nemo.utils.app_state import AppState |
|
|
from nemo.utils.model_utils import inject_model_parallel_rank |
|
|
|
|
|
|
|
|
def get_model_class(cfg): |
|
|
if cfg.model_type == 'gpt': |
|
|
return MegatronGPTModel |
|
|
elif cfg.model_type == 'bert': |
|
|
return MegatronBertModel |
|
|
elif cfg.model_type == 't5': |
|
|
return MegatronT5Model |
|
|
elif cfg.model_type == 'bart': |
|
|
return MegatronBARTModel |
|
|
elif cfg.model_type == 'nmt': |
|
|
return MegatronNMTModel |
|
|
elif cfg.model_type == 'retro': |
|
|
return MegatronRetrievalModel |
|
|
else: |
|
|
raise ValueError("Invalid Model Type") |
|
|
|
|
|
|
|
|
@hydra_runner(config_path="conf", config_name="megatron_gpt_export") |
|
|
def nemo_export(cfg): |
|
|
"""Convert a nemo model into .onnx ONNX format.""" |
|
|
nemo_in = None |
|
|
if cfg.gpt_model_file: |
|
|
nemo_in = cfg.gpt_model_file |
|
|
elif cfg.checkpoint_dir: |
|
|
nemo_in = os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name) |
|
|
assert nemo_in is not None, "NeMo model not provided. Please provide the path to the .nemo or .ckpt file" |
|
|
|
|
|
onnx_out = cfg.onnx_model_file |
|
|
|
|
|
trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer) |
|
|
assert ( |
|
|
cfg.trainer.devices * cfg.trainer.num_nodes |
|
|
== cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size |
|
|
), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size" |
|
|
|
|
|
logging.info("Restoring NeMo model from '{}'".format(nemo_in)) |
|
|
try: |
|
|
if cfg.gpt_model_file: |
|
|
save_restore_connector = NLPSaveRestoreConnector() |
|
|
if os.path.isdir(cfg.gpt_model_file): |
|
|
save_restore_connector.model_extracted_dir = cfg.gpt_model_file |
|
|
|
|
|
pretrained_cfg = ModelPT.restore_from( |
|
|
restore_path=cfg.gpt_model_file, |
|
|
trainer=trainer, |
|
|
return_config=True, |
|
|
save_restore_connector=save_restore_connector, |
|
|
) |
|
|
OmegaConf.set_struct(pretrained_cfg, True) |
|
|
with open_dict(pretrained_cfg): |
|
|
pretrained_cfg.sequence_parallel = False |
|
|
pretrained_cfg.activations_checkpoint_granularity = None |
|
|
pretrained_cfg.activations_checkpoint_method = None |
|
|
pretrained_cfg.precision = trainer.precision |
|
|
if trainer.precision == "16": |
|
|
pretrained_cfg.megatron_amp_O2 = False |
|
|
model = ModelPT.restore_from( |
|
|
restore_path=cfg.gpt_model_file, |
|
|
trainer=trainer, |
|
|
override_config_path=pretrained_cfg, |
|
|
save_restore_connector=save_restore_connector, |
|
|
) |
|
|
elif cfg.checkpoint_dir: |
|
|
app_state = AppState() |
|
|
if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1: |
|
|
app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size |
|
|
app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size |
|
|
app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size |
|
|
( |
|
|
app_state.tensor_model_parallel_rank, |
|
|
app_state.pipeline_model_parallel_rank, |
|
|
app_state.model_parallel_size, |
|
|
app_state.data_parallel_size, |
|
|
app_state.pipeline_model_parallel_split_rank, |
|
|
app_state.virtual_pipeline_model_parallel_rank, |
|
|
) = fake_initialize_model_parallel( |
|
|
world_size=app_state.model_parallel_size, |
|
|
rank=trainer.global_rank, |
|
|
tensor_model_parallel_size_=cfg.tensor_model_parallel_size, |
|
|
pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size, |
|
|
pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank, |
|
|
) |
|
|
checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name)) |
|
|
model_cls = get_model_class(cfg) |
|
|
model = model_cls.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer) |
|
|
else: |
|
|
raise ValueError("need at least a nemo file or checkpoint dir") |
|
|
except Exception as e: |
|
|
logging.error( |
|
|
"Failed to restore model from NeMo file : {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format( |
|
|
nemo_in |
|
|
) |
|
|
) |
|
|
raise e |
|
|
|
|
|
logging.info("Model {} restored from '{}'".format(model.__class__.__name__, nemo_in)) |
|
|
|
|
|
|
|
|
check_trace = cfg.export_options.runtime_check |
|
|
|
|
|
try: |
|
|
model.to(device=cfg.export_options.device).freeze() |
|
|
model.eval() |
|
|
model.export( |
|
|
onnx_out, |
|
|
onnx_opset_version=cfg.export_options.onnx_opset, |
|
|
do_constant_folding=cfg.export_options.do_constant_folding, |
|
|
dynamic_axes={ |
|
|
'input_ids': {0: "sequence", 1: "batch"}, |
|
|
'position_ids': {0: "sequence", 1: "batch"}, |
|
|
'logits': {0: "sequence", 1: "batch"}, |
|
|
}, |
|
|
check_trace=check_trace, |
|
|
check_tolerance=cfg.export_options.check_tolerance, |
|
|
verbose=cfg.export_options.verbose, |
|
|
) |
|
|
except Exception as e: |
|
|
logging.error( |
|
|
"Export failed. Please make sure your NeMo model class ({}) has working export() and that you have the latest NeMo package installed with [all] dependencies.".format( |
|
|
model.__class__ |
|
|
) |
|
|
) |
|
|
raise e |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
nemo_export() |
|
|
|