| |
| import json |
| import os |
| import sys |
| import types |
| import torch |
|
|
| from utils import _ConverterFakeProcessGroup, print_memory_usage |
|
|
class MegatronCheckpointLoaderBase:
    """Orchestrates loading a Megatron checkpoint and sending
    model parameters over a given multiprocessing queue.

    Subclasses must implement ``import_model_provider`` and
    ``send_model_over_queue``; they may additionally override the
    ``_maybe_*`` hooks to customize argument handling.

    Args:
        args: argparse Namespace with Megatron checkpoint configurations.
        queue: A multiprocessing.Queue (or similar) used to send out loaded tensors.
        build_tokenizer: Forwarded to Megatron's ``set_global_variables``
            when the Megatron environment is initialized.
    """

    def __init__(self, args, queue, build_tokenizer=False):
        self.args = args
        self.queue = queue
        self.build_tokenizer = build_tokenizer
        # The attributes below are populated step by step by ``load()``.
        self.margs = None  # Megatron args (set by parse_megatron_args)
        self.checkpoint_args = None  # args stored in the checkpoint
        self.all_models = None  # indexed [pp_rank][vp_rank][tp_rank]
        self.md = None  # metadata namespace (build_checkpoint_metadata)
        self.consumed_train_samples = None
        self.consumed_valid_samples = None

    def _maybe_parse_additional_megatron_args(self, margs, checkpoint_args):
        """
        Method used to optionally add arguments from the checkpoint to the main args.
        For instance, using margs.some_arg = checkpoint_args.some_arg

        The base implementation returns ``margs`` unchanged.
        """
        return margs

    def parse_megatron_args(self):
        """
        Parse Megatron arguments by forcibly overwriting sys.argv.
        Populates self.margs and self.checkpoint_args.

        On import or validation failure, "exit" is put on the queue and the
        process terminates via ``sys.exit(1)``.
        """
        # Make the parent directory of this file importable, and give an
        # explicitly supplied Megatron location priority on sys.path.
        sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
        if self.args.megatron_path is not None:
            sys.path.insert(0, self.args.megatron_path)

        try:
            from megatron.training.arguments import parse_args, validate_args
            from megatron.training.checkpointing import load_args_from_checkpoint
        except ModuleNotFoundError:
            print("Unable to import Megatron. Please specify --megatron-path. Exiting.")
            self.queue.put("exit")
            sys.exit(1)

        # HACK: Megatron's parser reads sys.argv, so replace it with the
        # arguments this loader needs (see build_sys_argv).
        sys.argv = self.build_sys_argv()

        margs = parse_args()
        margs, checkpoint_args = load_args_from_checkpoint(margs)

        # NOTE(review): presumably set so that argument validation sees a world
        # size consistent with the checkpoint's TP x PP layout -- confirm.
        margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size

        # Use the precision flags recorded in the checkpoint.
        margs.fp16 = checkpoint_args.fp16
        margs.bf16 = checkpoint_args.bf16

        # Sequence parallelism is switched on whenever expert parallelism is
        # in use (NOTE(review): presumably a Megatron validation requirement).
        if margs.expert_model_parallel_size > 1:
            margs.sequence_parallel = True

        margs = self._maybe_parse_additional_megatron_args(margs, checkpoint_args)

        # Broad catch is deliberate: whatever goes wrong during validation, the
        # consumer on the other end of the queue must be told to stop.
        try:
            margs = validate_args(margs)
        except Exception as e:
            print(f"Error validating Megatron arguments: {e}")
            self.queue.put("exit")
            sys.exit(1)

        margs.use_legacy_models = False
        margs.transformer_impl = self.args.loader_transformer_impl
        if self.args.loader_transformer_impl == "local" and margs.normalization == "RMSNorm":
            margs.no_persist_layer_norm = True

        self.margs = margs
        self.checkpoint_args = checkpoint_args

    def _maybe_ensure_additional_required_arguments(self):
        """
        Can be used to ensure some expected args are present.
        For instance, use self.check_for_arg('some_arg')
        """
        pass

    def check_for_arg(self, arg_name, default=None):
        """
        Ensure ``self.margs.<arg_name>`` is set (i.e. present and not None).

        If it is missing and ``default`` is not None, the default is stored on
        ``self.margs``. If it is missing and no default is given, "exit" is
        put on the queue and the process terminates via ``sys.exit(1)``.
        """
        if getattr(self.margs, arg_name, None) is not None:
            return
        if default is not None:
            setattr(self.margs, arg_name, default)
            return
        print(f"Checkpoint does not specify argument {arg_name}. Exiting.")
        print(f"Arguments: {self.margs}")
        self.queue.put("exit")
        sys.exit(1)

    def ensure_required_arguments(self):
        """
        Ensure that certain Megatron arguments (from checkpoint) are present.
        If missing, either set defaults or exit.
        """
        self.check_for_arg('tensor_model_parallel_size')
        self.check_for_arg('pipeline_model_parallel_size')
        self.check_for_arg('num_layers')
        self.check_for_arg('hidden_size')
        self.check_for_arg('seq_length')
        self.check_for_arg('num_attention_heads')
        self.check_for_arg('max_position_embeddings')
        self.check_for_arg('position_embedding_type')
        self.check_for_arg('tokenizer_type')
        self.check_for_arg('iteration')
        self.check_for_arg('bert_binary_head')
        self.check_for_arg('disable_bias_linear', False)
        self.check_for_arg('params_dtype')
        self.check_for_arg('swiglu', False)

        self._maybe_ensure_additional_required_arguments()

    def initialize_megatron_env(self):
        """
        Initialize Megatron global variables and fused kernels.

        On import failure, "exit" is put on the queue and the process
        terminates via ``sys.exit(1)``.
        """
        try:
            from megatron.training.global_vars import set_global_variables
            from megatron.core import mpu
            from megatron.legacy import fused_kernels
        except ModuleNotFoundError as e:
            print(f"Unable to import required Megatron modules: {e}")
            self.queue.put("exit")
            sys.exit(1)

        set_global_variables(self.margs, build_tokenizer=self.build_tokenizer)
        mpu.set_tensor_model_parallel_world_size(self.margs.tensor_model_parallel_size)
        mpu.set_pipeline_model_parallel_world_size(self.margs.pipeline_model_parallel_size)
        mpu.set_virtual_pipeline_model_parallel_world_size(self.margs.virtual_pipeline_model_parallel_size)
        mpu.set_expert_model_parallel_world_size(self.margs.expert_model_parallel_size)

        # Install stand-in process groups directly on mpu. The tensor-parallel
        # one is later fetched through mpu.get_tensor_model_parallel_group()
        # in load_model_shards, where its rank is updated for every shard.
        fake_tp_group = _ConverterFakeProcessGroup(size=self.margs.tensor_model_parallel_size)
        fake_ep_group = _ConverterFakeProcessGroup(size=self.margs.expert_model_parallel_size)
        mpu._TENSOR_MODEL_PARALLEL_GROUP = fake_tp_group
        mpu._EXPERT_MODEL_PARALLEL_GROUP = fake_ep_group
        fused_kernels.load(self.margs)

    def _load_vocab_size(self):
        """Return the number of entries in the JSON file at ``args.vocab_file``."""
        # Context manager so the file handle is closed deterministically.
        with open(self.args.vocab_file, encoding="utf-8") as vocab_fh:
            return len(json.load(vocab_fh))

    def compute_true_vocab_size(self):
        """Determine the 'true' (non-padded) vocab size.

        ``--true-vocab-size`` takes precedence; otherwise the size is the
        number of entries in ``--vocab-file``. Returns None when neither
        option is given.
        """
        if self.args.true_vocab_size is not None:
            return self.args.true_vocab_size
        if self.args.vocab_file is not None:
            return self._load_vocab_size()
        return None

    def verify_vocabs_match(self, true_vocab_size):
        """
        If both --true-vocab-size and --vocab-file are specified, verify they match.

        Returns False (after printing a message) if they don't match; True
        otherwise. This method does not exit itself; the caller decides what
        to do on a mismatch.

        ``true_vocab_size`` is accepted for interface compatibility but is not
        consulted: the check compares the vocab file against
        ``args.true_vocab_size`` directly.
        """
        if self.args.true_vocab_size is None or self.args.vocab_file is None:
            return True
        if self._load_vocab_size() != self.args.true_vocab_size:
            print("Both --true-vocab-size and --vocab-file specified but vocab sizes do not match. Aborting.")
            return False
        return True

    def load_model_shards(self, model_provider, dtype):
        """
        Build and load model shards for each tensor-parallel rank, returning:
          - A nested list of loaded models indexed by
            [pipeline_rank][virtual_pipeline_rank][tensor_rank].
          - consumed_train_samples, consumed_valid_samples

        Args:
            model_provider: Callable taking ``pre_process`` / ``post_process``
                keyword arguments and returning a model.
            dtype: Passed to ``.to(dtype)`` on every freshly built model.
        """
        from megatron.core import mpu
        from megatron.training.checkpointing import load_checkpoint

        consumed_train_samples = None
        consumed_valid_samples = None
        tp_size = self.margs.tensor_model_parallel_size
        pp_size = self.margs.pipeline_model_parallel_size
        vp_size = self.margs.virtual_pipeline_model_parallel_size or 1

        def get_models_for_pipeline_stage(count, dtype):
            """Load ``count`` tensor-parallel shards for the current pipeline rank."""
            nonlocal consumed_train_samples, consumed_valid_samples

            local_models_for_stage = [[] for _ in range(vp_size)]
            for tp_rank in range(count):
                # Point both the stand-in process group and mpu's own state at
                # the shard being loaded.
                fake_tp_group = mpu.get_tensor_model_parallel_group()
                fake_tp_group.set_rank(tp_rank)
                mpu.set_tensor_model_parallel_rank(tp_rank)

                model_list = []
                for i in range(vp_size):
                    mpu.set_virtual_pipeline_model_parallel_rank(i)
                    pre_process = mpu.is_pipeline_first_stage()
                    post_process = mpu.is_pipeline_last_stage()
                    this_model = model_provider(pre_process=pre_process,
                                                post_process=post_process).to(dtype)
                    model_list.append(this_model)

                # Reset the counters before each load; they are read back from
                # self.margs right after load_checkpoint returns.
                self.margs.consumed_train_samples = 0
                self.margs.consumed_valid_samples = 0
                self.margs.exit_on_missing_checkpoint = True
                load_checkpoint(model_list, None, None)

                # Every shard must report the same sample counters.
                if consumed_train_samples is not None:
                    assert self.margs.consumed_train_samples == consumed_train_samples, \
                        "consumed_train_samples differs between checkpoint shards"
                else:
                    consumed_train_samples = self.margs.consumed_train_samples

                if consumed_valid_samples is not None:
                    assert self.margs.consumed_valid_samples == consumed_valid_samples, \
                        "consumed_valid_samples differs between checkpoint shards"
                else:
                    consumed_valid_samples = self.margs.consumed_valid_samples

                for vp_rank in range(vp_size):
                    local_models_for_stage[vp_rank].append(model_list[vp_rank])

                print_memory_usage("loader", tp_rank, count)

            return local_models_for_stage

        all_models = []
        mpu.set_virtual_pipeline_model_parallel_rank(0)
        for pp_rank in range(pp_size):
            mpu.set_pipeline_model_parallel_rank(pp_rank)
            all_models.append(get_models_for_pipeline_stage(tp_size, dtype))

        return all_models, consumed_train_samples, consumed_valid_samples

    def send_metadata_over_queue(self):
        """Attach the consumed-sample counters to ``self.md`` and put it on the queue."""
        self.md.consumed_train_samples = self.consumed_train_samples
        self.md.consumed_valid_samples = self.consumed_valid_samples
        self.queue.put(self.md)

    def queue_put(self, name, msg):
        """Tag ``msg`` (a dict, mutated in place) with ``name`` and put it on the queue."""
        print(f"sending {name}")
        msg["name"] = name
        self.queue.put(msg)

    def send_llm_over_queue(self, schema):
        """
        Using self.all_models, extract model parameters and send them over the queue.

        Messages are sent in this order: "embeddings", one
        "transformer layer N" per layer, "final norm", optionally
        "output layer", for BERT models "pooler" / "lm head" / optionally
        "binary head", and finally the string "done".

        Args:
            schema: Object providing ``get(key, model)``,
                ``get_num_layers(model)`` and ``get_layer(model, layer_idx)``
                accessors returning dict-like collections of tensors.
        """
        pp_size = self.margs.pipeline_model_parallel_size
        vp_size = self.margs.virtual_pipeline_model_parallel_size or 1

        # Embeddings are read from the first pipeline stage, first virtual rank.
        first_pipeline_models = self.all_models[0][0]
        embeddings = [schema.get("embeddings", m) for m in first_pipeline_models]
        message = {
            "word embeddings": torch.cat([e["word"] for e in embeddings], dim=0)
        }
        if self.md.position_embedding_type == 'learned_absolute':
            # Taken from the first shard only, i.e. not concatenated across TP.
            message["position embeddings"] = embeddings[0]["pos"]
        else:
            assert embeddings[0]["pos"] is None
        self.queue_put("embeddings", message)

        total_layer_num = 0
        for vp_rank in range(vp_size):
            for pp_rank in range(pp_size):
                models = self.all_models[pp_rank][vp_rank]
                for layer_idx in range(schema.get_num_layers(models[0])):
                    # One lookup per tensor-parallel shard for this layer.
                    tp_layers = [schema.get_layer(model_tp, layer_idx) for model_tp in models]
                    layer = tp_layers[0]
                    message = {}

                    # Parameters taken from the first shard only.
                    message["input norm weight"] = layer["self_attn_norm_weight"]
                    message["post norm weight"] = layer["mlp_norm_weight"]
                    if self.md.norm_has_bias:
                        message["input norm bias"] = layer["self_attn_norm_bias"]
                        message["post norm bias"] = layer["mlp_norm_bias"]
                    if self.md.linear_bias:
                        message["dense bias"] = layer["self_attn_proj_bias"]
                        message["mlp l1 bias"] = layer["mlp_fc2_bias"]

                    # Parameters gathered from every shard and concatenated.
                    qkv_weight = [lp["self_attn_qkv_weight"] for lp in tp_layers]
                    dense_weight = [lp["self_attn_proj_weight"] for lp in tp_layers]
                    mlp_l0_weight = [lp["mlp_fc1_weight"] for lp in tp_layers]
                    mlp_l1_weight = [lp["mlp_fc2_weight"] for lp in tp_layers]

                    if self.md.swiglu:
                        # Each shard's fc1 weight is split in two halves along
                        # dim 0; halves are concatenated across shards as W and V.
                        halves = [torch.chunk(w, 2, dim=0) for w in mlp_l0_weight]
                        message["mlp l0 weight W"] = torch.cat([h[0] for h in halves], dim=0)
                        message["mlp l0 weight V"] = torch.cat([h[1] for h in halves], dim=0)
                    else:
                        message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)

                    message["qkv weight"] = torch.cat(qkv_weight, dim=0)
                    message["dense weight"] = torch.cat(dense_weight, dim=1)
                    message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)

                    if self.md.qkv_bias:
                        qkv_bias = [lp["self_attn_qkv_bias"] for lp in tp_layers]
                        message["qkv bias"] = torch.cat(qkv_bias, dim=0)
                    if self.md.linear_bias:
                        mlp_l0_bias = [lp["mlp_fc1_bias"] for lp in tp_layers]
                        if self.md.swiglu:
                            bias_halves = [torch.chunk(b, 2, dim=0) for b in mlp_l0_bias]
                            message["mlp l0 bias W"] = torch.cat([h[0] for h in bias_halves], dim=0)
                            message["mlp l0 bias V"] = torch.cat([h[1] for h in bias_halves], dim=0)
                        else:
                            message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)

                    self.queue_put(f"transformer layer {total_layer_num}", message)
                    total_layer_num += 1

        # Everything below is read from the last pipeline stage / last virtual
        # rank, i.e. the models visited by the final iteration of the loops above.
        models = self.all_models[pp_size - 1][vp_size - 1]

        final_norm = schema.get("final_norm", models[0])
        message = {"weight": final_norm["weight"]}
        if self.md.norm_has_bias:
            message["bias"] = final_norm["bias"]
        self.queue_put("final norm", message)

        if self.md.output_layer:
            output_layers = [schema.get("output_layer", m) for m in models]
            message = {
                "weight": torch.cat([layer["weight"] for layer in output_layers], dim=0),
            }
            self.queue_put("output layer", message)

        if self.md.model_type == 'BERT':
            pooler = schema.get("pooler", models[0])
            message = {
                "weight": pooler["weight"],
                "bias": pooler["bias"],
            }
            self.queue_put("pooler", message)

            lm_head = schema.get("lm_head", models[0])
            message = {
                "dense weight": lm_head["dense_weight"],
                "dense bias": lm_head["dense_bias"],
                "norm weight": lm_head["norm_weight"],
            }
            if self.md.norm_has_bias:
                message["norm bias"] = lm_head["norm_bias"]
            self.queue_put("lm head", message)

            if self.md.bert_binary_head:
                binary_head = schema.get("binary_head", models[0])
                message = {
                    "weight": binary_head["weight"],
                    "bias": binary_head["bias"],
                }
                self.queue_put("binary head", message)

        # Sentinel telling the consumer that no more messages follow.
        self.queue.put("done")

    def load(self):
        """
        Orchestrate the entire flow of loading the Megatron checkpoint.

        Steps: parse args, check required args, obtain the model provider,
        initialize the Megatron environment, resolve/verify the vocab size,
        build metadata, load all shards, and send the model over the queue.
        """
        self.parse_megatron_args()
        self.ensure_required_arguments()
        model_provider = self.import_model_provider()
        self.initialize_megatron_env()

        true_vocab_size = self.compute_true_vocab_size()
        if not self.verify_vocabs_match(true_vocab_size):
            self.queue.put("exit")
            sys.exit(1)

        self.md = self.build_checkpoint_metadata(true_vocab_size)

        self.all_models, self.consumed_train_samples, self.consumed_valid_samples = self.load_model_shards(
            model_provider,
            self.md.params_dtype
        )

        self.send_model_over_queue()

    def build_checkpoint_metadata(self, true_vocab_size):
        """
        Construct a simple namespace for all relevant model metadata.

        Args:
            true_vocab_size: Stored as-is on the namespace (may be None).

        Returns:
            types.SimpleNamespace populated from ``self.args``, ``self.margs``
            and ``self.checkpoint_args``.
        """
        # Norm layers are assumed to carry a bias unless the checkpoint args
        # name a normalization other than "LayerNorm".
        norm_has_bias = True
        if hasattr(self.checkpoint_args, 'normalization'):
            norm_has_bias = (self.checkpoint_args.normalization == "LayerNorm")

        md = types.SimpleNamespace()
        md.model_type = self.args.model_type
        md.num_layers = self.margs.num_layers
        md.hidden_size = self.margs.hidden_size
        md.seq_length = self.margs.seq_length
        md.num_attention_heads = self.margs.num_attention_heads
        md.max_position_embeddings = self.margs.max_position_embeddings
        md.tokenizer_type = self.margs.tokenizer_type
        md.iteration = self.margs.iteration
        md.params_dtype = self.margs.params_dtype
        md.bert_binary_head = self.margs.bert_binary_head
        md.output_layer = self.margs.untie_embeddings_and_output_weights
        md.position_embedding_type = self.margs.position_embedding_type
        md.linear_bias = self.margs.add_bias_linear
        md.qkv_bias = self.margs.add_qkv_bias
        md.norm_has_bias = norm_has_bias
        md.swiglu = self.margs.swiglu
        md.previous_tensor_parallel_size = self.margs.tensor_model_parallel_size
        md.previous_pipeline_parallel_size = self.margs.pipeline_model_parallel_size
        md.true_vocab_size = true_vocab_size
        md.make_vocab_size_divisible_by = self.margs.make_vocab_size_divisible_by
        md.checkpoint_args = self.checkpoint_args
        md.use_legacy_models = self.margs.use_legacy_models
        return md

    def build_sys_argv(self):
        """
        Construct a sys.argv list for Megatron's argument parser.
        This centralizes the hack of overwriting sys.argv.

        Subclasses can extend the returned list with extra flags.
        """
        return [
            'script.py',
            '--no-masked-softmax-fusion',
            '--no-bias-gelu-fusion',
            '--no-bias-dropout-fusion',
            '--no-async-tensor-model-parallel-allreduce',
            '--use-cpu-initialization',
            '--micro-batch-size', '1',
            '--no-load-optim',
            '--no-load-rng',
            '--no-save-optim',
            '--no-save-rng',
            '--no-initialization',
            '--mock-data',
            '--load', self.args.load_dir,
            '--exit-on-missing-checkpoint',
            '--use-mp-args-from-checkpoint-args',
            '--no-one-logger',
        ]

    def import_model_provider(self):
        """Return the correct model_provider function depending on GPT vs. BERT."""
        raise NotImplementedError

    def send_model_over_queue(self):
        """Creates model schema and sends the model over the queue"""
        raise NotImplementedError
|
|
|
|