# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. import json import os import sys import types import torch from utils import _ConverterFakeProcessGroup, print_memory_usage class MegatronCheckpointLoaderBase: """Orchestrates loading a Megatron checkpoint and sending model parameters over a given multiprocessing queue. Args: args: argparse Namespace with Megatron checkpoint configurations. queue: A multiprocessing.Queue (or similar) used to send out loaded tensors. """ def __init__(self, args, queue, build_tokenizer=False): self.args = args self.queue = queue self.build_tokenizer = build_tokenizer self.margs = None # Will hold Megatron's main args self.checkpoint_args = None # Will hold additional checkpoint args self.all_models = None # Model sharded over different parallelism self.md = None # Metadata sent to the saver self.consumed_train_samples = None self.consumed_valid_samples = None def _maybe_parse_additional_megatron_args(self, margs, checkpoint_args): """ Method used to optionally add arguments from the checkpoint to the main args. For instance, using margs.some_arg = checkpoint_args.some_arg """ return margs def parse_megatron_args(self): """ Parse Megatron arguments by forcibly overwriting sys.argv. Populates self.margs and self.checkpoint_args. """ # Ensure we can import Megatron sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) if self.args.megatron_path is not None: sys.path.insert(0, self.args.megatron_path) try: from megatron.training.arguments import parse_args, validate_args from megatron.training.checkpointing import load_args_from_checkpoint except ModuleNotFoundError: print("Unable to import Megatron. Please specify --megatron-path. Exiting.") self.queue.put("exit") sys.exit(1) # Overwrite sys.argv sys.argv = self.build_sys_argv() margs = parse_args() margs, checkpoint_args = load_args_from_checkpoint(margs) # Adjust world size so validation doesn't fail margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size # Copy data types from checkpoint margs.fp16 = checkpoint_args.fp16 margs.bf16 = checkpoint_args.bf16 # Expert parallelism requires sequence parallelism if margs.expert_model_parallel_size > 1: margs.sequence_parallel = True margs = self._maybe_parse_additional_megatron_args(margs, checkpoint_args) # Validate final arguments try: from megatron.training.arguments import validate_args margs = validate_args(margs) except Exception as e: print(f"Error validating Megatron arguments: {e}") self.queue.put("exit") sys.exit(1) margs.use_legacy_models = False margs.transformer_impl = self.args.loader_transformer_impl if self.args.loader_transformer_impl == "local" and margs.normalization == "RMSNorm": margs.no_persist_layer_norm = True self.margs = margs self.checkpoint_args = checkpoint_args def _maybe_ensure_additional_required_arguments(self): """ Can be used to ensure some expected args are present. For instance, use self.check_for_arg('some_arg') """ pass def check_for_arg(self, arg_name, default=None): if getattr(self.margs, arg_name, None) is None: if default is not None: setattr(self.margs, arg_name, default) else: print(f"Checkpoint does not specify argument {arg_name}. Exiting.") print(f"Arguments: {self.margs}") self.queue.put("exit") sys.exit(1) def ensure_required_arguments(self): """ Ensure that certain Megatron arguments (from checkpoint) are present. If missing, either set defaults or exit. """ self.check_for_arg('tensor_model_parallel_size') self.check_for_arg('pipeline_model_parallel_size') self.check_for_arg('num_layers') self.check_for_arg('hidden_size') self.check_for_arg('seq_length') self.check_for_arg('num_attention_heads') self.check_for_arg('max_position_embeddings') self.check_for_arg('position_embedding_type') self.check_for_arg('tokenizer_type') self.check_for_arg('iteration') self.check_for_arg('bert_binary_head') self.check_for_arg('disable_bias_linear', False) self.check_for_arg('params_dtype') self.check_for_arg('swiglu', False) self._maybe_ensure_additional_required_arguments() def initialize_megatron_env(self): """ Initialize Megatron global variables and fused kernels. """ try: from megatron.training.global_vars import set_global_variables from megatron.core import mpu from megatron.legacy import fused_kernels except ModuleNotFoundError as e: print(f"Unable to import required Megatron modules: {e}") self.queue.put("exit") sys.exit(1) set_global_variables(self.margs, build_tokenizer=self.build_tokenizer) mpu.set_tensor_model_parallel_world_size(self.margs.tensor_model_parallel_size) mpu.set_pipeline_model_parallel_world_size(self.margs.pipeline_model_parallel_size) mpu.set_virtual_pipeline_model_parallel_world_size(self.margs.virtual_pipeline_model_parallel_size) mpu.set_expert_model_parallel_world_size(self.margs.expert_model_parallel_size) # For backward compatibility during local parallel states refactoring fake_tp_group = _ConverterFakeProcessGroup(size=self.margs.tensor_model_parallel_size) fake_ep_group = _ConverterFakeProcessGroup(size=self.margs.expert_model_parallel_size) mpu._TENSOR_MODEL_PARALLEL_GROUP = fake_tp_group mpu._EXPERT_MODEL_PARALLEL_GROUP = fake_ep_group fused_kernels.load(self.margs) def compute_true_vocab_size(self): """Determine the 'true' (non-padded) vocab size.""" if self.args.true_vocab_size is not None: return self.args.true_vocab_size elif self.args.vocab_file is not None: vocab = json.load(open(self.args.vocab_file)) return len(vocab) else: return None def verify_vocabs_match(self, true_vocab_size): """ If both --true-vocab-size and --vocab-file are specified, verify they match. Return False (and exit) if they don't match; True otherwise. """ if self.args.true_vocab_size is not None and self.args.vocab_file is not None: vocab = json.load(open(self.args.vocab_file)) if len(vocab) != self.args.true_vocab_size: print("Both --true-vocab-size and --vocab-file specified but vocab sizes do not match. Aborting.") return False return True def load_model_shards(self, model_provider, dtype): """ Build and load model shards for each tensor-parallel rank, returning: - A nested list of loaded models by [pipeline_rank][virtual_pipeline_rank]. - consumed_train_samples, consumed_valid_samples """ from megatron.core import mpu from megatron.training.checkpointing import load_checkpoint consumed_train_samples = None consumed_valid_samples = None tp_size = self.margs.tensor_model_parallel_size pp_size = self.margs.pipeline_model_parallel_size vp_size = self.margs.virtual_pipeline_model_parallel_size or 1 all_models = [] # all_models[pp_rank][vp_rank] = [list of models across TP ranks] def get_models_for_pipeline_stage(count, dtype): local_models_for_stage = [[] for _ in range(vp_size)] for tp_rank in range(count): fake_tp_group = mpu.get_tensor_model_parallel_group() fake_tp_group.set_rank(tp_rank) mpu.set_tensor_model_parallel_rank(tp_rank) model_list = [] for i in range(vp_size): mpu.set_virtual_pipeline_model_parallel_rank(i) pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() this_model = model_provider(pre_process=pre_process, post_process=post_process).to(dtype) model_list.append(this_model) # Each time we load, we set counters to 0, pass None for optimizer/ LR self.margs.consumed_train_samples = 0 self.margs.consumed_valid_samples = 0 self.margs.exit_on_missing_checkpoint = True load_checkpoint(model_list, None, None) # Validate that train/valid samples match across ranks nonlocal consumed_train_samples, consumed_valid_samples if consumed_train_samples is not None: assert self.margs.consumed_train_samples == consumed_train_samples else: consumed_train_samples = self.margs.consumed_train_samples if consumed_valid_samples is not None: assert self.margs.consumed_valid_samples == consumed_valid_samples else: consumed_valid_samples = self.margs.consumed_valid_samples for vp_rank in range(vp_size): local_models_for_stage[vp_rank].append(model_list[vp_rank]) # Print memory usage print_memory_usage("loader", tp_rank, count) return local_models_for_stage # Load shards for each pipeline rank mpu.set_virtual_pipeline_model_parallel_rank(0) for pp_rank in range(pp_size): mpu.set_pipeline_model_parallel_rank(pp_rank) all_models.append(get_models_for_pipeline_stage(tp_size, dtype)) return all_models, consumed_train_samples, consumed_valid_samples def send_metadata_over_queue(self): # Let the consumer know the overall metadata: self.md.consumed_train_samples = self.consumed_train_samples self.md.consumed_valid_samples = self.consumed_valid_samples self.queue.put(self.md) def queue_put(self, name, msg): print(f"sending {name}") msg["name"] = name self.queue.put(msg) def send_llm_over_queue(self, schema): """ Using self.all_models, extract model parameters and send them over the queue. """ # 2) Transformer layers tp_size = self.margs.tensor_model_parallel_size pp_size = self.margs.pipeline_model_parallel_size vp_size = self.margs.virtual_pipeline_model_parallel_size or 1 # all_models[pp_rank][vp_rank] is a list across TP ranks # We'll start with pipeline=0, vp=0 for embeddings/final norm first_pipeline_models = self.all_models[0][0] # 1) Embeddings embeddings = [schema.get("embeddings", m) for m in first_pipeline_models] message = { "word embeddings": torch.cat([e["word"] for e in embeddings], dim=0) } if self.md.position_embedding_type == 'learned_absolute': # Only send one set from rank 0 message["position embeddings"] = embeddings[0]["pos"] else: assert embeddings[0]["pos"] is None self.queue_put("embeddings", message) total_layer_num = 0 for vp_rank in range(vp_size): for pp_rank in range(pp_size): models = self.all_models[pp_rank][vp_rank] num_layers = schema.get_num_layers(models[0]) for layer_idx in range(num_layers): message = {} layer = schema.get_layer(models[0], layer_idx) # Non-parallel params message["input norm weight"] = layer["self_attn_norm_weight"] message["post norm weight"] = layer["mlp_norm_weight"] if self.md.norm_has_bias: message["input norm bias"] = layer["self_attn_norm_bias"] message["post norm bias"] = layer["mlp_norm_bias"] if self.md.linear_bias: message["dense bias"] = layer["self_attn_proj_bias"] message["mlp l1 bias"] = layer["mlp_fc2_bias"] # Collect parallel parameters qkv_weight, qkv_bias = [], [] dense_weight = [] mlp_l0_weight, mlp_l0_bias = [], [] mlp_l1_weight = [] for model_tp in models: layer_p = schema.get_layer(model_tp, layer_idx) qkv_weight.append(layer_p["self_attn_qkv_weight"]) dense_weight.append(layer_p["self_attn_proj_weight"]) mlp_l0_weight.append(layer_p["mlp_fc1_weight"]) mlp_l1_weight.append(layer_p["mlp_fc2_weight"]) if self.md.qkv_bias: qkv_bias.append(layer_p["self_attn_qkv_bias"]) if self.md.linear_bias: mlp_l0_bias.append(layer_p["mlp_fc1_bias"]) # If we are using SwiGLU, chunk each mlp_l0_weight if self.md.swiglu: for i in range(tp_size): mlp_l0_weight[i] = torch.chunk(mlp_l0_weight[i], 2, dim=0) message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0) message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0) else: message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0) # Standard concatenations message["qkv weight"] = torch.cat(qkv_weight, dim=0) message["dense weight"] = torch.cat(dense_weight, dim=1) message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1) if self.md.qkv_bias: message["qkv bias"] = torch.cat(qkv_bias, dim=0) if self.md.linear_bias: if self.md.swiglu: for i in range(tp_size): mlp_l0_bias[i] = torch.chunk(mlp_l0_bias[i], 2, dim=0) message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias], dim=0) message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias], dim=0) else: message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0) self.queue_put(f"transformer layer {total_layer_num}", message) total_layer_num += 1 # 3) Final norm final_norm = schema.get("final_norm", models[0]) message = {"weight": final_norm["weight"]} if self.md.norm_has_bias: message["bias"] = final_norm["bias"] self.queue_put("final norm", message) # 4) Output layer if self.md.output_layer: output_layers = [schema.get("output_layer", m) for m in models] message = { "weight": torch.cat([layer["weight"] for layer in output_layers], dim=0), } self.queue_put("output layer", message) # 5) BERT-specific parameters if self.md.model_type == 'BERT': # Pooler pooler = schema.get("pooler", models[0]) message = { "weight": pooler["weight"], "bias": pooler["bias"], } self.queue_put("pooler", message) # LM head lm_head = schema.get("lm_head", models[0]) message = { "dense weight": lm_head["dense_weight"], "dense bias": lm_head["dense_bias"], "norm weight": lm_head["norm_weight"], } if self.md.norm_has_bias: message["norm bias"] = lm_head["norm_bias"] self.queue_put("lm head", message) # Binary head if self.md.bert_binary_head: binary_head = schema.get("binary_head", models[0]) message = { "weight": binary_head["weight"], "bias": binary_head["bias"], } self.queue_put("binary head", message) # Done self.queue.put("done") def load(self): """ Orchestrate the entire flow of loading the Megatron checkpoint. """ # 1) Parse Megatron arguments self.parse_megatron_args() # 2) Ensure required arguments are present self.ensure_required_arguments() # 3) Import the correct model provider (GPT or BERT) model_provider = self.import_model_provider() # 4) Initialize the Megatron environment self.initialize_megatron_env() # 5) Determine the true vocab size and verify if both sources match true_vocab_size = self.compute_true_vocab_size() if not self.verify_vocabs_match(true_vocab_size): self.queue.put("exit") sys.exit(1) # 6) Build metadata self.md = self.build_checkpoint_metadata(true_vocab_size) # 7) Load all model shards self.all_models, self.consumed_train_samples, self.consumed_valid_samples = self.load_model_shards( model_provider, self.md.params_dtype ) # 8) Send model over the queue self.send_model_over_queue() def build_checkpoint_metadata(self, true_vocab_size): """ Construct a simple namespace for all relevant model metadata. """ norm_has_bias = True if hasattr(self.checkpoint_args, 'normalization'): # For older models, normalization was always "LayerNorm". norm_has_bias = (self.checkpoint_args.normalization == "LayerNorm") md = types.SimpleNamespace() md.model_type = self.args.model_type md.num_layers = self.margs.num_layers md.hidden_size = self.margs.hidden_size md.seq_length = self.margs.seq_length md.num_attention_heads = self.margs.num_attention_heads md.max_position_embeddings = self.margs.max_position_embeddings md.tokenizer_type = self.margs.tokenizer_type md.iteration = self.margs.iteration md.params_dtype = self.margs.params_dtype md.bert_binary_head = self.margs.bert_binary_head md.output_layer = self.margs.untie_embeddings_and_output_weights md.position_embedding_type = self.margs.position_embedding_type md.linear_bias = self.margs.add_bias_linear md.qkv_bias = self.margs.add_qkv_bias md.norm_has_bias = norm_has_bias md.swiglu = self.margs.swiglu md.previous_tensor_parallel_size = self.margs.tensor_model_parallel_size md.previous_pipeline_parallel_size = self.margs.pipeline_model_parallel_size md.true_vocab_size = true_vocab_size md.make_vocab_size_divisible_by = self.margs.make_vocab_size_divisible_by md.checkpoint_args = self.checkpoint_args md.use_legacy_models = self.margs.use_legacy_models return md def build_sys_argv(self): """ Construct a sys.argv list for Megatron's argument parser. This centralizes the hack of overwriting sys.argv. """ return [ 'script.py', '--no-masked-softmax-fusion', '--no-bias-gelu-fusion', '--no-bias-dropout-fusion', '--no-async-tensor-model-parallel-allreduce', '--use-cpu-initialization', '--micro-batch-size', '1', '--no-load-optim', '--no-load-rng', '--no-save-optim', '--no-save-rng', '--no-initialization', '--mock-data', # To pass the "blend data checks" in arguments.py '--load', self.args.load_dir, '--exit-on-missing-checkpoint', '--use-mp-args-from-checkpoint-args', '--no-one-logger', ] def import_model_provider(self): """Return the correct model_provider function depending on GPT vs. BERT.""" raise NotImplementedError def send_model_over_queue(self): """Creates model schema and sends the model over the queue""" raise NotImplementedError