CLIMB -> HF -> Megatron

#2
by sarahyurick - opened
NVIDIA org

Hi folks, in case it is useful, here are two conversion scripts (generated with Codex).

The first script, climb_to_hf.py, maps the released checkpoint tensors into a standard LlamaForCausalLM layout. In particular, the released state dict stores the feed-forward block under moe.experts.0.{gate,up,down}_proj. Because each layer has only a single expert (experts.0), the block is effectively a dense MLP, so the script maps those tensors to the usual Llama mlp.{gate,up,down}_proj keys. Only the tensor names change; the weights are copied unchanged.

If you only want a Hugging Face checkpoint, climb_to_hf.py is enough. If you want to continue training with Megatron Bridge / Megatron-Core, run hf_to_megatron.py after that.

  1. climb_to_hf.py
from __future__ import annotations

import argparse
import enum
import re
import sys
import types
from pathlib import Path

import torch
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer


# Per-variant architecture sizes used by make_config. The keys double as the
# non-"auto" choices for --variant. Both variants have a head dimension of
# hidden_size / num_attention_heads = 64 and 3 query heads per key/value head.
VARIANTS = {
    "62m": dict(
        hidden_size=384,
        intermediate_size=1024,
        num_attention_heads=6,
        num_key_value_heads=2,
    ),
    "350m": dict(
        hidden_size=960,
        intermediate_size=2560,
        num_attention_heads=15,
        num_key_value_heads=5,
    ),
}


def parse_args() -> argparse.Namespace:
    """Parse the command line for the CLIMB -> Hugging Face conversion.

    Required path flags: --checkpoint-path, --tokenizer-model, --output-dir.
    Optional: --variant ("auto" or a VARIANTS key, default "auto") and
    --max-position-embeddings (default 1024).
    """
    cli = argparse.ArgumentParser(
        description="Convert Nemotron-CLIMB proxy model_optim_rng.pt to Hugging Face Llama format."
    )
    for path_flag in ("--checkpoint-path", "--tokenizer-model", "--output-dir"):
        cli.add_argument(path_flag, required=True, type=Path)
    variant_choices = ["auto", *VARIANTS.keys()]
    cli.add_argument("--variant", choices=variant_choices, default="auto")
    cli.add_argument("--max-position-embeddings", type=int, default=1024)
    return cli.parse_args()


def install_megatron_pickle_stub() -> None:
    """Provide the one Megatron enum needed to unpickle released checkpoints.

    If ``megatron.core.enums`` is importable, nothing is done. Otherwise, three
    placeholder modules (``megatron``, ``megatron.core``, ``megatron.core.enums``)
    are registered in ``sys.modules``. The innermost one carries a minimal
    ``ModelType`` enum, so pickle can resolve ``megatron.core.enums.ModelType``.
    """
    try:
        import megatron.core.enums  # noqa: F401
    except ImportError:
        pass
    else:
        return

    class ModelType(enum.Enum):
        encoder_or_decoder = 1

    # Pickle resolves classes by module path, so claim the real module name.
    ModelType.__module__ = "megatron.core.enums"

    module_names = ("megatron", "megatron.core", "megatron.core.enums")
    megatron, core, enums = (types.ModuleType(name) for name in module_names)
    enums.ModelType = ModelType
    core.enums = enums
    megatron.core = core
    for module in (megatron, core, enums):
        sys.modules[module.__name__] = module


def load_model_state(checkpoint_path: Path) -> dict[str, torch.Tensor]:
    """Load the tensors from a released ``model_optim_rng.pt`` checkpoint.

    The checkpoint is a full pickle that references a Megatron enum, so it is
    loaded with ``weights_only=False``. The Megatron stub is installed first, in
    case Megatron itself is not importable.

    Args:
        checkpoint_path: path to the ``model_optim_rng.pt`` file.

    Returns:
        The checkpoint's own tensor names mapped to CPU tensors. The state is taken
        from the ``"model"`` entry when present, otherwise from the top-level dict.
        Non-tensor entries are dropped.

    Raises:
        ValueError: if the file does not hold a dict, or holds no tensors.
    """
    install_megatron_pickle_stub()
    # SECURITY: weights_only=False allows arbitrary code execution while
    # unpickling. Only run this on checkpoints from a source you trust.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    if not isinstance(checkpoint, dict):
        raise ValueError(
            f"Expected {checkpoint_path} to contain a dict, got {type(checkpoint).__name__}."
        )
    state = checkpoint.get("model", checkpoint)
    tensors = {key: value for key, value in state.items() if isinstance(value, torch.Tensor)}
    if not tensors:
        raise ValueError(f"No tensors found in {checkpoint_path}.")
    return tensors


def detect_variant(state: dict[str, torch.Tensor]) -> str:
    """Infer the VARIANTS key from the hidden size of the embedding table.

    The hidden size is dimension 1 of the first tensor whose name ends with
    ``embed_tokens.weight``. The first VARIANTS entry with that hidden size wins.

    Raises:
        ValueError: if no embedding tensor is present, or no variant matches its
            hidden size.
    """
    # Passing a default to next() avoids leaking a bare StopIteration on checkpoints
    # without an embedding tensor.
    embed_key = next((key for key in state if key.endswith("embed_tokens.weight")), None)
    if embed_key is None:
        raise ValueError(
            "Checkpoint has no '*embed_tokens.weight' tensor; cannot infer the variant."
        )
    hidden_size = state[embed_key].shape[1]
    for variant, config in VARIANTS.items():
        if config["hidden_size"] == hidden_size:
            return variant
    raise ValueError(f"Could not infer variant from hidden size {hidden_size}; pass --variant explicitly.")


def make_config(
    state: dict[str, torch.Tensor],
    variant: str,
    max_position_embeddings: int,
) -> LlamaConfig:
    """Build the Hugging Face LlamaConfig for the checkpoint tensors in *state*.

    Sources for the config values:
    - ``vocab_size`` and the hidden size come from the ``*embed_tokens.weight`` tensor.
    - The layer count comes from the ``.layers.<i>.`` indices in the tensor names.
    - Intermediate size and head counts come from ``VARIANTS[variant]``.
    - The remaining values are fixed: SiLU activation, RMSNorm eps 1e-5, pad/bos/eos
      ids 0/1/2, and tied input/output embeddings.

    Args:
        state: tensor name -> tensor, e.g. as returned by load_model_state.
        variant: a key of VARIANTS.
        max_position_embeddings: written into the config unchanged.

    Raises:
        KeyError: if *variant* is not a key of VARIANTS.
        ValueError: in any of these cases:
            - no embedding tensor or no decoder-layer tensor is present;
            - the layer indices are not contiguous from 0;
            - the checkpoint's hidden size disagrees with the variant.
    """
    embed_key = next((key for key in state if key.endswith("embed_tokens.weight")), None)
    if embed_key is None:
        raise ValueError("Checkpoint has no '*embed_tokens.weight' tensor.")
    vocab_size, hidden_size = state[embed_key].shape

    layer_ids = {
        int(match.group(1))
        for key in state
        if (match := re.search(r"\.layers\.(\d+)\.", key))
    }
    if not layer_ids:
        raise ValueError("Checkpoint has no '.layers.<i>.' tensors; cannot count decoder layers.")
    num_layers = max(layer_ids) + 1
    # The layer count is derived from the highest index. Report a gap here rather
    # than as a long list of missing keys at load time.
    missing_layers = sorted(set(range(num_layers)) - layer_ids)
    if missing_layers:
        raise ValueError(f"Checkpoint is missing decoder layers {missing_layers}.")

    variant_config = VARIANTS[variant]
    if hidden_size != variant_config["hidden_size"]:
        raise ValueError(
            f"Variant {variant} expects hidden size {variant_config['hidden_size']}, "
            f"but checkpoint has {hidden_size}."
        )

    return LlamaConfig(
        vocab_size=vocab_size,
        hidden_size=variant_config["hidden_size"],
        intermediate_size=variant_config["intermediate_size"],
        num_hidden_layers=num_layers,
        num_attention_heads=variant_config["num_attention_heads"],
        num_key_value_heads=variant_config["num_key_value_heads"],
        hidden_act="silu",
        max_position_embeddings=max_position_embeddings,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=True,
        architectures=["LlamaForCausalLM"],
        torch_dtype="bfloat16",
    )


def convert_key(key: str) -> str | None:
    if key.endswith("embed_tokens.weight"):
        return "model.embed_tokens.weight"
    if key.endswith("lm_head.weight"):
        return "lm_head.weight"
    if key.endswith("final_layernorm.weight"):
        return "model.norm.weight"

    match = re.search(r"\.layers\.(\d+)\.", key)
    if not match:
        return None

    layer = match.group(1)
    layer_prefix = f"model.layers.{layer}"
    suffix_map = {
        "self_attn.q_proj.weight": f"{layer_prefix}.self_attn.q_proj.weight",
        "self_attn.k_proj.weight": f"{layer_prefix}.self_attn.k_proj.weight",
        "self_attn.v_proj.weight": f"{layer_prefix}.self_attn.v_proj.weight",
        "self_attn.o_proj.weight": f"{layer_prefix}.self_attn.o_proj.weight",
        "moe.experts.0.gate_proj.weight": f"{layer_prefix}.mlp.gate_proj.weight",
        "moe.experts.0.up_proj.weight": f"{layer_prefix}.mlp.up_proj.weight",
        "moe.experts.0.down_proj.weight": f"{layer_prefix}.mlp.down_proj.weight",
        "input_layernorm.weight": f"{layer_prefix}.input_layernorm.weight",
        "pre_moe_layernorm.weight": f"{layer_prefix}.post_attention_layernorm.weight",
    }
    for old_suffix, new_key in suffix_map.items():
        if key.endswith(old_suffix):
            return new_key
    return None


def convert_state_dict(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Rename every tensor in *state* to its Hugging Face Llama key via convert_key.

    Tensors are passed through unchanged; only their keys are rewritten.

    Raises:
        ValueError: as soon as two source keys map to the same target key. After
            the full pass, also raises (listing the keys sorted) if any source key
            has no mapping.
    """
    renamed: dict[str, torch.Tensor] = {}
    unmapped: list[str] = []
    for source_key, weight in state.items():
        target_key = convert_key(source_key)
        if target_key is None:
            unmapped.append(source_key)
        elif target_key in renamed:
            raise ValueError(f"Two checkpoint keys map to {target_key}")
        else:
            renamed[target_key] = weight

    if unmapped:
        raise ValueError("Unsupported checkpoint tensor keys:\n" + "\n".join(sorted(unmapped)))
    return renamed


def make_tokenizer(tokenizer_model: Path) -> LlamaTokenizer:
    """Wrap a SentencePiece ``tokenizer.model`` file in a Hugging Face LlamaTokenizer.

    The special tokens are <s>, </s> and <unk>, using the non-legacy tokenization
    behavior. The tokenizer has no dedicated pad token, so <unk> is reused for
    padding.
    """
    special_tokens = {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>"}
    tokenizer = LlamaTokenizer(
        vocab_file=str(tokenizer_model),
        legacy=False,
        **special_tokens,
    )
    tokenizer.pad_token = tokenizer.unk_token
    return tokenizer


def main() -> None:
    """Convert a released CLIMB proxy checkpoint into a Hugging Face Llama directory.

    Steps:
    1. Load the checkpoint tensors.
    2. Pick the variant (auto-detected unless --variant is given).
    3. Build the LlamaConfig and rename the tensors to Llama keys.
    4. Load them into a bfloat16 LlamaForCausalLM.
    5. Save the model (safetensors) and the tokenizer to --output-dir.

    Raises:
        ValueError: if the config ties the embeddings but the checkpoint's lm_head
            differs from its embedding table.
        RuntimeError: if the renamed tensors do not exactly match the model's
            parameter names, or their shapes do not match.
    """
    args = parse_args()
    state = load_model_state(args.checkpoint_path)
    variant = detect_variant(state) if args.variant == "auto" else args.variant
    config = make_config(state, variant, args.max_position_embeddings)
    converted = convert_state_dict(state)

    if config.tie_word_embeddings:
        # With tied embeddings, Transformers makes lm_head.weight and
        # model.embed_tokens.weight the same parameter. Loading both entries
        # would copy into it twice, and whichever loads last would silently win.
        embed = converted["model.embed_tokens.weight"]
        lm_head = converted.get("lm_head.weight")
        if lm_head is None:
            converted["lm_head.weight"] = embed
        elif lm_head.shape != embed.shape or not torch.equal(lm_head, embed.to(lm_head.dtype)):
            raise ValueError(
                "Checkpoint lm_head.weight differs from embed_tokens.weight, but the "
                "config ties them; refusing to silently drop one of them."
            )

    model = LlamaForCausalLM(config).to(dtype=torch.bfloat16)
    # strict=False so key mismatches reach the explicit check below. With
    # strict=True that check was unreachable. Shape mismatches still raise inside
    # load_state_dict.
    missing, unexpected = model.load_state_dict(converted, strict=False)
    if missing or unexpected:
        raise RuntimeError(f"Missing keys: {missing}\nUnexpected keys: {unexpected}")

    tokenizer = make_tokenizer(args.tokenizer_model)
    print(
        f"Converted {args.checkpoint_path} as {variant}: "
        f"{config.num_hidden_layers} layers, hidden={config.hidden_size}, vocab={config.vocab_size}"
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(args.output_dir, safe_serialization=True, max_shard_size="10GB")
    tokenizer.save_pretrained(args.output_dir)
    print(f"Saved Hugging Face checkpoint to {args.output_dir}")


if __name__ == "__main__":
    main()
  2. hf_to_megatron.py (requires Megatron Bridge, not just plain Transformers)
from __future__ import annotations

import argparse
from pathlib import Path

import torch
from megatron.bridge import AutoBridge


def parse_args() -> argparse.Namespace:
    """Parse the command line for the Hugging Face -> Megatron import.

    Required: --hf-model-dir and --megatron-output-dir (paths).
    Optional: --dtype (bf16/fp16/fp32, default bf16) and the --trust-remote-code flag.
    """
    cli = argparse.ArgumentParser(
        description="Import a Hugging Face Llama checkpoint into a Megatron checkpoint."
    )
    for path_flag in ("--hf-model-dir", "--megatron-output-dir"):
        cli.add_argument(path_flag, required=True, type=Path)
    cli.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
    cli.add_argument("--trust-remote-code", action="store_true")
    return cli.parse_args()


def main() -> None:
    """Import --hf-model-dir into a Megatron checkpoint at --megatron-output-dir.

    The work is delegated to Megatron Bridge's ``AutoBridge.import_ckpt``. The
    --dtype flag is translated to the matching torch dtype, and --trust-remote-code
    is forwarded as-is.
    """
    args = parse_args()
    dtype_by_flag = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    import_kwargs = {
        "torch_dtype": dtype_by_flag[args.dtype],
        "trust_remote_code": args.trust_remote_code,
    }
    AutoBridge.import_ckpt(args.hf_model_dir, args.megatron_output_dir, **import_kwargs)
    print(f"Imported {args.hf_model_dir} to {args.megatron_output_dir}")


if __name__ == "__main__":
    main()

I validated this conversion path locally for both the 62M and 350M checkpoints. Please treat these as practical helper scripts rather than an official conversion utility.

Example usage:

python climb_to_hf.py \
  --checkpoint-path nemotron_climb_proxy_model_62m/iter_2499000/mp_rank_00/model_optim_rng.pt \
  --tokenizer-model tokenizer.model \
  --output-dir nemotron_climb_proxy_model_62m_hf \
  --variant 62m

python hf_to_megatron.py \
  --hf-model-dir nemotron_climb_proxy_model_62m_hf \
  --megatron-output-dir nemotron_climb_proxy_model_62m_megatron_core

Happy training!

sarahyurick changed discussion title from Megatron -> HF -> Megatron Bridge to CLIMB -> HF -> Megatron

Sign up or log in to comment