# Copyright (c) Alibaba, Inc. and its affiliates.

import math

import torch
from megatron.training.checkpointing import load_checkpoint
from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint
from megatron.training.initialize import initialize_megatron
from megatron.training.utils import get_ltor_masks_and_position_ids

from swift.llm import ExportArguments, get_model_tokenizer, get_template, save_checkpoint
from swift.utils import get_logger, get_n_params_grads
from ..argument import MegatronArguments
from ..model import get_megatron_model_meta
from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard

logger = get_logger()


def test_convert_precision(hf_model, mg_model, processor):
    """Check that the HF model and the Megatron model produce matching logits.

    A fixed single-turn prompt is encoded with the HF model's template and fed
    through both models, each temporarily cast to float32 on CUDA. Afterwards
    each model is cast back to the HF model's original dtype and moved to CPU.
    The mean/max absolute logit difference and the greedy (argmax) tokens of
    both models are printed.

    Args:
        hf_model: The HuggingFace-side model; must expose ``dtype`` and
            ``model_meta.template`` and return an object with ``.logits``.
        mg_model: The Megatron-side model; called with ``input_ids``,
            ``attention_mask`` and ``position_ids`` and expected to return the
            logits tensor directly.
        processor: Tokenizer/processor passed to ``get_template``.

    Raises:
        AssertionError: If the mean absolute logit difference is not below
            0.1, or if any argmax token differs between the two models.
    """
    torch_dtype = hf_model.dtype
    template = get_template(hf_model.model_meta.template, processor)
    encoded = template.encode({'messages': [{'role': 'user', 'content': 'who are you?'}]})
    # Add a leading batch dimension of size 1.
    input_ids = torch.tensor(encoded['input_ids'])[None].to('cuda')

    # HF forward pass in float32, then restore the original dtype/device.
    hf_model.to('cuda')
    hf_model.to(torch.float32)
    with torch.inference_mode():
        hf_logits = hf_model(input_ids).logits
    hf_model.to(torch_dtype)
    hf_model.to('cpu')

    # The second returned value is not needed here and is discarded.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True)

    # Megatron forward pass in float32, then restore the dtype/device as well.
    mg_model.to('cuda')
    mg_model.to(torch.float32)
    with torch.inference_mode():
        mg_logits = mg_model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
    mg_model.to(torch_dtype)
    mg_model.to('cpu')

    # Compute the element-wise absolute difference once and reuse it.
    abs_diff = (mg_logits - hf_logits).abs()
    mean_diff = abs_diff.mean().item()
    max_diff = abs_diff.max().item()
    print(f'mean_diff: {mean_diff}, max_diff: {max_diff}')

    hf_tokens = hf_logits.argmax(-1)
    mg_tokens = mg_logits.argmax(-1)
    print(f'hf_tokens: {hf_tokens[0].tolist()}\nmg_tokens: {mg_tokens[0].tolist()}')
    assert mean_diff < 0.1
    assert (hf_tokens == mg_tokens).all()


# Extra keyword arguments passed to ``MegatronArguments`` by both conversion
# directions: CPU initialization, no optimizer/RNG state saved or loaded, and
# the fused kernels listed below switched off.
convert_kwargs = {
    'use_cpu_initialization': True,
    'no_save_optim': True,
    'no_save_rng': True,
    'no_load_optim': True,
    'no_load_rng': True,
    'no_masked_softmax_fusion': True,
    'no_bias_dropout_fusion': True,
    'no_bias_swiglu_fusion': True,
    'no_rope_fusion': True
}


def _ensure_thread_count(args: ExportArguments, hf_model) -> None:
    """Set ``args.thread_count`` from the model size when it is ``None``.

    The heuristic is one thread per 10 GB of checkpoint, with a minimum of 2.
    An explicitly configured ``args.thread_count`` is left untouched.
    """
    if args.thread_count is not None:
        return
    n_params = sum(get_n_params_grads(hf_model)[0])
    # Size in GB: params * bits-per-param / 8 bits-per-byte / 1e9.
    # NOTE(review): `//` floors the size to whole GB before the ceil below, so
    # e.g. 20.5 GB yields 2 threads rather than 3 - confirm this is intended.
    checkpoint_size = n_params * torch.finfo(args.torch_dtype).bits // 8e9
    args.thread_count = max(math.ceil(checkpoint_size / 10), 2)  # 10GB


def _build_megatron_model(args: ExportArguments, processor, **io_kwargs):
    """Initialize Megatron for ``args.model_type`` and instantiate its model.

    Args:
        args: Export arguments; ``model_type``, ``model`` and ``torch_dtype``
            are read.
        processor: Tokenizer/processor; its ``model_info.config`` is converted
            into Megatron arguments and it is passed to
            ``patch_megatron_tokenizer``.
        **io_kwargs: Extra ``MegatronArguments`` fields, i.e. ``save=...`` or
            ``load=...`` depending on the conversion direction.

    Returns:
        A ``(megatron_model_meta, mg_model)`` tuple.

    Raises:
        AssertionError: If no Megatron model meta is registered for
            ``args.model_type``.
    """
    megatron_model_meta = get_megatron_model_meta(args.model_type)
    assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
    config_kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
    megatron_args = MegatronArguments(**config_kwargs, **convert_kwargs, **io_kwargs, torch_dtype=args.torch_dtype)
    patch_megatron_tokenizer(processor)
    extra_args = megatron_args.parse_to_megatron()
    initialize_megatron(args_defaults=extra_args)
    mg_model = megatron_model_meta.model_provider()
    return megatron_model_meta, mg_model


def convert_hf2mcore(args: ExportArguments) -> None:
    """Convert HuggingFace weights to a Megatron checkpoint.

    Loads the HF model, builds the matching Megatron model, copies the weights
    across, optionally verifies the logits of both models agree (when
    ``args.test_convert_precision`` is truthy), then saves the Megatron
    checkpoint (``save=args.output_dir``) and the export arguments.

    Args:
        args: Export arguments. ``args.thread_count`` is filled in when it is
            ``None``.

    Raises:
        AssertionError: If the model type has no Megatron support, or if the
            optional precision check fails.
    """
    kwargs = args.get_model_kwargs()
    hf_model, processor = get_model_tokenizer(**kwargs)
    _ensure_thread_count(args, hf_model)
    patch_torch_dist_shard(args.thread_count)

    megatron_model_meta, mg_model = _build_megatron_model(args, processor, save=args.output_dir)
    logger.info('Megatron model created successfully.')

    megatron_model_meta.convert_hf2mcore(hf_model, mg_model)
    if args.test_convert_precision:
        test_convert_precision(hf_model, mg_model, processor)
    logger.info('Successfully transferred HF model weights to MG model.')

    # No optimizer or scheduler is passed (both None).
    # NOTE(review): the leading 1 and trailing 0 are presumably the iteration
    # number and the FLOP counter - confirm against Megatron's signature.
    mg_save_checkpoint(1, [mg_model], None, None, 0)
    args.save_args()
    logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.')


def convert_mcore2hf(args: ExportArguments) -> None:
    """Convert a Megatron checkpoint back to HuggingFace weights.

    Loads the HF model, builds the matching Megatron model and loads the
    checkpoint from ``args.mcore_model`` into it, copies the weights into the
    HF model, optionally verifies the logits of both models agree (when
    ``args.test_convert_precision`` is truthy), then saves the HF checkpoint
    and the export arguments to ``args.output_dir``.

    Args:
        args: Export arguments. ``args.thread_count`` is filled in when it is
            ``None``.

    Raises:
        AssertionError: If the model type has no Megatron support, or if the
            optional precision check fails.
    """
    kwargs = args.get_model_kwargs()
    hf_model, processor = get_model_tokenizer(**kwargs)
    _ensure_thread_count(args, hf_model)
    patch_torch_dist_shard(args.thread_count)

    megatron_model_meta, mg_model = _build_megatron_model(args, processor, load=args.mcore_model)
    # No optimizer or scheduler is passed; strict=True is forwarded to the loader.
    load_checkpoint([mg_model], None, None, strict=True)
    logger.info('Megatron model created successfully.')

    megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
    if args.test_convert_precision:
        test_convert_precision(hf_model, mg_model, processor)
    logger.info('Successfully transferred MG model weights to HF model.')

    save_checkpoint(
        hf_model,
        processor,
        args.output_dir,
        safe_serialization=args.safe_serialization,
        model_dirs=[args.mcore_model, args.model_dir],
        max_shard_size=args.max_shard_size,
        additional_saved_files=hf_model.model_meta.additional_saved_files)
    args.save_args()
    logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')