Spaces:
Running on Zero
Running on Zero
| # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | |
| import json | |
| import os | |
| import sys | |
| import torch | |
| import types | |
| from functools import partial | |
| from schema_core import get_model_schema | |
| from loader_base import MegatronCheckpointLoaderBase | |
def add_arguments(parser):
    """Register the Megatron loader's command-line options on ``parser``.

    All options are placed in a dedicated argument group titled
    'Megatron loader'.

    Args:
        parser: An ``argparse`` parser (or compatible object) exposing
            ``add_argument_group``.
    """
    loader_group = parser.add_argument_group(title='Megatron loader')
    register = loader_group.add_argument

    # Vocabulary-related options; both default to None (i.e. not specified).
    register(
        '--true-vocab-size',
        type=int,
        default=None,
        help='Original size of vocab; if specified, trims padding from embedding table.',
    )
    register(
        '--vocab-file',
        type=str,
        default=None,
        help='Path to a vocab file. If specified, determines vocab size to trim padding.',
    )

    # Location of the Megatron sources.
    register(
        '--megatron-path',
        type=str,
        default=None,
        help='Base directory of Megatron repository',
    )

    # Model-construction options restricted to a fixed set of choices.
    register(
        '--position-embedding-type',
        type=str,
        default='learned_absolute',
        choices=['learned_absolute', 'rope'],
        help='Type of position embedding.',
    )
    register(
        '--loader-transformer-impl',
        default='transformer_engine',
        choices=['local', 'transformer_engine'],
        help='Which Transformer implementation to use.',
    )
class MegatronCheckpointLoaderLLM(MegatronCheckpointLoaderBase):
    """
    Orchestrates loading a Megatron checkpoint and sending
    model parameters over a given multiprocessing queue.

    Args:
        args: argparse Namespace with Megatron checkpoint configurations.
        queue: A multiprocessing.Queue (or similar) used to send out loaded tensors.
    """

    def build_sys_argv(self):
        """
        Construct a sys.argv list for Megatron's argument parser.

        Extends the list produced by the base class with the
        ``--position-embedding-type`` option taken from ``self.args``.
        This centralizes the hack of overwriting sys.argv.

        Returns:
            list: The base-class arguments followed by the position
            embedding flag and its value.
        """
        return [
            *super().build_sys_argv(),
            '--position-embedding-type', self.args.position_embedding_type,
        ]

    def import_model_provider(self):
        """
        Return the correct model_provider function depending on GPT vs. BERT.

        For GPT the provider is bound to ``gpt_builder`` with
        ``functools.partial``; the bound callable is both stored on
        ``self.model_provider`` and returned. For BERT the provider from
        ``pretrain_bert`` is returned as-is.

        Returns:
            callable: The model provider for ``self.args.model_type``.

        Raises:
            Exception: If ``self.args.model_type`` is neither 'GPT' nor 'BERT'.
        """
        model_type = self.args.model_type

        if model_type == 'GPT':
            # Imported lazily: these modules are resolved from the Megatron
            # checkout at call time rather than at module import.
            from model_provider import model_provider
            from gpt_builders import gpt_builder
            self.model_provider = partial(model_provider, gpt_builder)
            # Return the builder-bound callable, not the bare function: the
            # bare `model_provider` takes the builder as its first positional
            # argument, so a caller using the return value would otherwise
            # have to supply the builder itself.
            # NOTE(review): how the base class consumes this value is not
            # visible here — confirm it expects the bound provider.
            return self.model_provider

        if model_type == 'BERT':
            from pretrain_bert import model_provider
            return model_provider

        raise Exception(f"Unrecognized model type: {self.args.model_type}")

    def send_model_over_queue(self):
        """
        Send metadata, then the model, then a "done" marker over the queue.

        The model schema is looked up from the model type in ``self.md`` and
        the transformer implementation / expert settings in ``self.margs``
        (presumably populated by the base class before this is called —
        verify against the base loader).
        """
        self.send_metadata_over_queue()

        # Model schema.
        schema = get_model_schema(
            self.md.model_type,
            self.margs.transformer_impl,
            self.margs.num_experts,
            self.margs.expert_model_parallel_size,
        )
        self.send_llm_over_queue(schema)

        # Final sentinel put on the queue once everything has been sent.
        self.queue.put("done")
def load_checkpoint(queue, args):
    """
    Required top-level function that creates the loader,
    calls its .load(), and handles exceptions by signaling 'exit'.

    Args:
        queue: Queue handed to the loader; receives the string "exit" if
            building or running the loader raises.
        args: argparse Namespace forwarded to ``MegatronCheckpointLoaderLLM``.

    Raises:
        Exception: Whatever the loader raised is re-raised unchanged after
            "exit" has been put on the queue.
    """
    try:
        # Construct inside the try so that a failure while building the
        # loader is signaled on the queue too, not only a failure in load().
        loader = MegatronCheckpointLoaderLLM(args, queue)
        loader.load()
    except Exception:
        queue.put("exit")
        # Bare raise re-raises the active exception with its original traceback.
        raise