r"""
Conversion script to convert HuggingFace Baichuan2 checkpoints into NeMo checkpoints.

Example to run this conversion script:
    python convert_baichuan2_hf_to_nemo.py \
     --input_name_or_path <path_to_hf_checkpoints_folder> \
     --output_path <path_to_output_nemo_file>
"""

import os
from argparse import ArgumentParser
from collections import OrderedDict

import torch
from lightning.pytorch.core.saving import _load_state as ptl_load_state
from lightning.pytorch.trainer.trainer import Trainer
from omegaconf import OmegaConf
from transformers import AutoModelForCausalLM, AutoTokenizer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
    MegatronHalfPrecisionPlugin,
    NLPDDPStrategy,
    NLPSaveRestoreConnector,
    PipelineMixedPrecisionPlugin,
)
from nemo.collections.nlp.parts.utils_funcs import load_state_dict_helper, torch_dtype_from_precision
from nemo.utils import logging


def get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--input_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to HuggingFace Baichuan2 checkpoints",
    )
    parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.")
    parser.add_argument(
        "--hparams_file",
        type=str,
        default=os.path.join(
            os.path.dirname(__file__), '../../examples/nlp/language_modeling/conf/megatron_baichuan2_config.yaml'
        ),
        required=False,
        help="Path to the config used for restoring. It is created during training and may need to be modified "
        "if the restore environment differs from the training environment. "
        "Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml",
    )
    parser.add_argument("--precision", type=str, default="32", help="Model precision")
    args = parser.parse_args()
    return args
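

# Build a NeMo model from a checkpoint dict: either delegate to PyTorch Lightning's
# state loading (when a cfg override is passed) or copy parameters and buffers by name,
# then register the tokenizer files as artifacts so they are bundled into the .nemo file.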
def load_model(cls, checkpoint, strict, **kwargs):
    try:
        if 'cfg' in kwargs:
            model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
        else:
            model = cls(cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], **kwargs)
            for name, module in model.named_parameters():
                if name in checkpoint['state_dict']:
                    module.data = checkpoint['state_dict'][name]
                    checkpoint['state_dict'].pop(name)
                else:
                    print(f"Unexpected key: {name} not in checkpoint but in model.")

            for name, buffer in model.named_buffers():
                if name in checkpoint['state_dict']:
                    buffer.data = checkpoint['state_dict'][name]
                    checkpoint['state_dict'].pop(name)

            if len(checkpoint['state_dict'].keys()) != 0:
                raise RuntimeError(
                    f"Additional keys: {checkpoint['state_dict'].keys()} in checkpoint but not in model."
                )

            cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
            if cfg.tokenizer.model is not None:
                model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.model)
            if cfg.tokenizer.vocab_file is not None:
                model.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file)
            if cfg.tokenizer.merge_file is not None:
                model.register_artifact("tokenizer.merge_file", cfg.tokenizer.merge_file)
    finally:
        cls._set_model_restore_state(is_being_restored=False)
    return model
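

# Map the HuggingFace Baichuan2 config fields onto the NeMo megatron config
# loaded from --hparams_file.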
def load_config(args, baichuan2_config):
    nemo_config = OmegaConf.load(args.hparams_file).model
    if 'max_position_embeddings' in baichuan2_config:
        nemo_config.encoder_seq_length = baichuan2_config['max_position_embeddings']
    nemo_config.num_layers = baichuan2_config['num_hidden_layers']
    nemo_config.hidden_size = baichuan2_config['hidden_size']
    nemo_config.ffn_hidden_size = baichuan2_config['intermediate_size']
    nemo_config.num_attention_heads = baichuan2_config['num_attention_heads']
    nemo_config.num_query_groups = baichuan2_config['num_attention_heads']
    nemo_config.init_method_std = baichuan2_config['initializer_range']
    nemo_config.layernorm_epsilon = baichuan2_config['rms_norm_eps']
    nemo_config.use_cpu_initialization = True
    nemo_config.activation = 'fast-swiglu'
    nemo_config.tokenizer.model = baichuan2_config['tokenizer_model']
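    # Baichuan2-7B (32 layers) uses rotary position embeddings, while Baichuan2-13B
    # (40 layers) uses ALiBi; infer the embedding type from the layer count.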
    if nemo_config.num_layers == 32:
        nemo_config.position_embedding_type = 'rope'
    elif nemo_config.num_layers == 40:
        nemo_config.position_embedding_type = 'alibi'
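
    # Choose the largest power-of-two divisor of the vocab size (capped at 128) so
    # that Megatron does not need to pad the embedding table.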
    base = 128
    while baichuan2_config['vocab_size'] % base != 0:
        base //= 2
    nemo_config.make_vocab_size_divisible_by = base

    return nemo_config
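

# Main conversion routine: load the HF model, translate its weights into the
# megatron naming/layout expected by NeMo, and save the result as a .nemo file.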
def convert(args):
    logging.info(f"loading checkpoint {args.input_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(args.input_name_or_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(args.input_name_or_path, trust_remote_code=True)
    hf_config = vars(model.config)
    hf_config['tokenizer_model'] = str(tokenizer.vocab_file)
    print(f"hf_config: {hf_config}")
    print("named parameters:")
    for name, param in model.named_parameters():
        print(f"- {name}")

    nemo_config = load_config(args, hf_config)

    if args.precision in ["32", "16"]:
        precision = int(float(args.precision))
    elif args.precision in ["bf16", "bf16-mixed"]:
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            precision = args.precision
        else:
            logging.warning("BF16 is not supported on this device. Using FP16 instead.")
            precision = args.precision[2:]
    else:
        precision = args.precision

    plugins = []
    if precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']:
        scaler = None
        if precision in [16, '16', '16-mixed']:
            scaler = GradScaler(
                init_scale=nemo_config.get('native_amp_init_scale', 2**32),
                growth_interval=nemo_config.get('native_amp_growth_interval', 1000),
                hysteresis=nemo_config.get('hysteresis', 2),
            )
            plugin_precision = '16-mixed'
        else:
            plugin_precision = 'bf16-mixed'

        if nemo_config.get('megatron_amp_O2', False):
            plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
        else:
            plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))

    nemo_config.precision = precision
    print(f"nemo_config: {nemo_config}")

    trainer = Trainer(plugins=plugins, accelerator='cpu', strategy=NLPDDPStrategy())

    hidden_size = hf_config["hidden_size"]
    head_num = hf_config["num_attention_heads"]
    head_size = hidden_size // head_num
    num_layers = hf_config["num_hidden_layers"]

    mcore_gpt = nemo_config.mcore_gpt

    assert mcore_gpt == nemo_config.get(
        'transformer_engine', False
    ), "mcore_gpt and transformer_engine must be enabled (or disabled) together."

    param_to_weights = lambda param: param.float()

    checkpoint = OrderedDict()
    checkpoint['state_dict'] = OrderedDict()
    embed_weight = model.state_dict()['model.embed_tokens.weight']
    if mcore_gpt:
        embed_weights_base_name = 'model.embedding.word_embeddings.weight'
    else:
        embed_weights_base_name = 'model.language_model.embedding.word_embeddings.weight'
    checkpoint['state_dict'][embed_weights_base_name] = param_to_weights(embed_weight)
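
    # Some Baichuan2 HF checkpoints expose the rotary inv_freq tensor as a buffer;
    # copy it over when present.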
    if 'model.layers.0.self_attn.rotary_emb.inv_freq' in model.state_dict():
        rotary_embed_weight = model.state_dict()['model.layers.0.self_attn.rotary_emb.inv_freq']
        if mcore_gpt:
            rotary_embed_weight_base_name = 'model.rotary_pos_emb.inv_freq'
        else:
            rotary_embed_weight_base_name = 'model.language_model.rotary_pos_emb.inv_freq'
        checkpoint['state_dict'][rotary_embed_weight_base_name] = param_to_weights(rotary_embed_weight)
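
    # Baichuan2 uses plain multi-head attention, so the number of query groups
    # equals the number of attention heads (no grouped-query attention).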
    if nemo_config.num_query_groups is None or nemo_config.num_query_groups == head_num:
        num_query_groups = head_num
    else:
        num_query_groups = nemo_config.num_query_groups
        assert head_num % num_query_groups == 0, 'head_num must be divisible by num_query_groups'
    if mcore_gpt:
        assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.'

    for l in range(int(num_layers)):
        print(f"converting layer {l}")
        qkv_weights = model.state_dict()[f'model.layers.{l}.self_attn.W_pack.weight']
        qkv_weights = qkv_weights.unflatten(0, (3, hidden_size))
        old_tensor_shape = qkv_weights[0].squeeze().size()
        new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
        new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]
        q = qkv_weights[0].squeeze().view(*new_q_tensor_shape)
        k = qkv_weights[1].squeeze().view(*new_kv_tensor_shape)
        v = qkv_weights[2].squeeze().view(*new_kv_tensor_shape)
        qkv_weights = torch.empty((0, head_size) + old_tensor_shape[1:])
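        # Megatron expects QKV interleaved per query group: all query heads of a
        # group, followed by that group's key head and value head.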
        heads_per_group = head_num // num_query_groups
        for i in range(num_query_groups):
            qkv_weights = torch.cat((qkv_weights, q[i * heads_per_group : (i + 1) * heads_per_group, :, :]))
            qkv_weights = torch.cat((qkv_weights, k[i : i + 1, :, :]))
            qkv_weights = torch.cat((qkv_weights, v[i : i + 1, :, :]))

        qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])

        if mcore_gpt:
            qkv_weights_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.weight'
        else:
            qkv_weights_base_name = f'model.language_model.encoder.layers.{l}.self_attention.query_key_value.weight'
        checkpoint['state_dict'][qkv_weights_base_name] = param_to_weights(qkv_weights)

        o_weight = model.state_dict()[f'model.layers.{l}.self_attn.o_proj.weight']
        if mcore_gpt:
            o_weight_base_name = f'model.decoder.layers.{l}.self_attention.linear_proj.weight'
        else:
            o_weight_base_name = f'model.language_model.encoder.layers.{l}.self_attention.dense.weight'
        checkpoint['state_dict'][o_weight_base_name] = param_to_weights(o_weight)
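
        # NeMo's fused 'fast-swiglu' MLP expects linear_fc1 to hold the gate projection
        # and the up projection concatenated along the output dimension.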
        mlp_gate_weight = model.state_dict()[f'model.layers.{l}.mlp.gate_proj.weight']
        mlp_up_weight = model.state_dict()[f'model.layers.{l}.mlp.up_proj.weight']
        if mcore_gpt:
            mlp_fc1_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.weight'
        else:
            mlp_fc1_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_h_to_4h.weight'
        mlp_fc1_weight = torch.cat((mlp_gate_weight, mlp_up_weight), dim=0)
        checkpoint['state_dict'][mlp_fc1_base_name] = param_to_weights(mlp_fc1_weight)

        mlp_down_weight = model.state_dict()[f'model.layers.{l}.mlp.down_proj.weight']
        if mcore_gpt:
            mlp_fc2_base_name = f'model.decoder.layers.{l}.mlp.linear_fc2.weight'
        else:
            mlp_fc2_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_4h_to_h.weight'
        checkpoint['state_dict'][mlp_fc2_base_name] = param_to_weights(mlp_down_weight)

        input_ln_weight = model.state_dict()[f'model.layers.{l}.input_layernorm.weight']
        if mcore_gpt:
            input_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight'
        else:
            input_ln_base_name = f'model.language_model.encoder.layers.{l}.input_layernorm.weight'
        checkpoint['state_dict'][input_ln_base_name] = param_to_weights(input_ln_weight)

        post_attn_ln_weight = model.state_dict()[f'model.layers.{l}.post_attention_layernorm.weight']
        if mcore_gpt:
            post_attn_ln_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight'
        else:
            post_attn_ln_base_name = f'model.language_model.encoder.layers.{l}.post_attention_layernorm.weight'
        checkpoint['state_dict'][post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)

        print(f"done layer {l}")

    final_ln_weight = model.state_dict()['model.norm.weight']
    if mcore_gpt:
        final_ln_base_name = 'model.decoder.final_layernorm.weight'
    else:
        final_ln_base_name = 'model.language_model.encoder.final_layernorm.weight'
    checkpoint['state_dict'][final_ln_base_name] = param_to_weights(final_ln_weight)

    output_layer_weight = model.state_dict()['lm_head.weight']
    if mcore_gpt:
        output_layer_base_name = 'model.output_layer.weight'
    else:
        output_layer_base_name = 'model.language_model.output_layer.weight'
    checkpoint['state_dict'][output_layer_base_name] = param_to_weights(output_layer_weight)

    checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY] = nemo_config

    del model
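
    # With megatron_amp_O2 the model parameters live under an extra 'module.'
    # prefix; rename the state-dict keys to match.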
    if nemo_config.get('megatron_amp_O2', False):
        keys = list(checkpoint['state_dict'].keys())
        for key in keys:
            checkpoint['state_dict'][key.replace('model.', 'model.module.', 1)] = checkpoint['state_dict'].pop(key)

    model = load_state_dict_helper(MegatronGPTModel, nemo_config, trainer, checkpoint['state_dict'])

    model._save_restore_connector = NLPSaveRestoreConnector()
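
    # All weights were accumulated in fp32 above; cast them to the requested precision.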
    dtype = torch_dtype_from_precision(precision)
    model = model.to(dtype=dtype)
    model.cfg.use_cpu_initialization = False

    model.save_to(args.output_path)
    logging.info(f'NeMo model saved to: {args.output_path}')


if __name__ == '__main__':
    args = get_args()
    convert(args)