# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Export a GPTModel."""

import functools
import os
import sys
import warnings

# Make the repository root importable before pulling in the megatron packages below.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))

import modelopt.torch.export as mtex
import torch

from megatron.post_training.arguments import add_modelopt_args
from megatron.post_training.checkpointing import load_modelopt_checkpoint
from megatron.post_training.model_provider import model_provider
from megatron.training import get_args, get_model
from megatron.training.initialize import initialize_megatron
from megatron.training.utils import unwrap_model

warnings.filterwarnings('ignore')


def add_modelopt_export_args(parser):
    """Add additional arguments for ModelOpt hf-like export.

    Registers the ``--export-extra-modules``, ``--pretrained-model-name`` and
    ``--export-dir`` options in their own argument group, then lets
    ``add_modelopt_args`` register its options on the same parser.

    Args:
        parser: The ``argparse`` parser to extend.

    Returns:
        The same ``parser`` object, for use as an ``extra_args_provider``.
    """
    group = parser.add_argument_group(title='ModelOpt hf-like export')
    group.add_argument(
        "--export-extra-modules",
        action="store_true",
        help="Export extra modules such as Medusa, EAGLE, or MTP.",
    )
    group.add_argument(
        "--pretrained-model-name",
        type=str,
        help="A pretrained model hosted inside a model repo on huggingface.co.",
    )
    group.add_argument("--export-dir", type=str, help="The target export path.")
    add_modelopt_args(parser)
    return parser


if __name__ == "__main__":
    initialize_megatron(
        extra_args_provider=add_modelopt_export_args,
        args_defaults={
            'tokenizer_type': 'HuggingFaceTokenizer',
            'no_load_rng': True,
            'no_load_optim': True,
        },
    )

    args = get_args()

    # Meta device initialization for ParallelLinear only works if using cpu initialization.
    # Meta device initialization is used such that models can be materialized in low-precision
    # directly when ModelOpt real quant is used. Otherwise, the model is first initialized
    # as BF16 in memory which may result in OOM and defeat the purpose of real quant.
    args.use_cpu_initialization = True
    if not args.init_model_with_meta_device:
        warnings.warn(
            "--init-model-with-meta-device is not set. If you would like to resume the "
            "model in low-bit directly (low-memory initialization and skipping 16-bit), "
            "--init-model-with-meta-device must be set.",
            UserWarning,
        )

    model = get_model(functools.partial(model_provider, parallel_output=True), wrap_with_ddp=False)

    # Materialize the model from meta device to cpu before loading the checkpoint.
    unwrapped_model = unwrap_model(model)[0]
    unwrapped_model.to_empty(device="cpu")

    if args.load is not None:
        _ = load_modelopt_checkpoint(model)

    # Decide whether we are exporting only the extra_modules (e.g. EAGLE3).
    # Only the last pp stage may have extra_modules, hence broadcast from the last rank.
    # NOTE(review): args.export_extra_modules is parsed above but not read here; the
    # flag is derived from the model instead -- confirm this is intended.
    export_extra_modules = (
        hasattr(unwrapped_model, "eagle_module") or hasattr(unwrapped_model, "medusa_heads")
    )

    # broadcast_object_list writes the received object into the list in place on
    # non-source ranks, so the list must outlive the call and be read back afterwards.
    # Passing a temporary list would silently discard the broadcast value.
    broadcast_buffer = [export_extra_modules]
    torch.distributed.broadcast_object_list(
        broadcast_buffer,
        src=torch.distributed.get_world_size() - 1,
    )
    export_extra_modules = broadcast_buffer[0]

    mtex.export_mcore_gpt_to_hf(
        unwrapped_model,
        args.pretrained_model_name,
        export_extra_modules=export_extra_modules,
        dtype=torch.bfloat16,
        export_dir=args.export_dir,
    )