| import argparse |
| import importlib |
| import torch.multiprocessing as mp |
| import os |
| import sys |
|
|
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def _missing_module_is_plugin(err, module_name):
    """Return True when ``err`` says that ``module_name`` itself is absent.

    ``ModuleNotFoundError.name`` holds the module that could not be found.
    It equals ``module_name`` (or one of its parent packages for a dotted
    name) when the plugin is missing, and some other name when the plugin
    exists but one of its own imports failed.
    """
    missing = err.name
    if missing is None:
        # No module name attached to the exception: nothing more specific
        # can be reported, so treat it like a plain "not found".
        return True
    return module_name == missing or module_name.startswith(missing + ".")


def load_plugin(plugin_type, name):
    """Import and return the plugin module for ``plugin_type``.

    The module ``checkpoint_{plugin_type}_{name}`` is tried first; if it
    cannot be imported, ``name`` is tried as a module name on its own.

    Args:
        plugin_type: Role of the plugin; used to build the first candidate
            module name and in the messages (``main`` passes ``'loader'``
            or ``'saver'``).
        name: Plugin name, or a full module name.

    Returns:
        The imported module, which is guaranteed to expose ``add_arguments``.

    Raises:
        SystemExit: If neither candidate can be imported, or if the imported
            module has no ``add_arguments`` attribute. When a candidate
            exists but one of its own imports is missing, the exit message
            names that missing module instead of hiding it.
    """
    candidates = (f"checkpoint_{plugin_type}_{name}", name)
    broken = []  # candidates that exist but failed on one of their own imports

    for module_name in candidates:
        try:
            plugin = importlib.import_module(module_name)
        except ModuleNotFoundError as err:
            if not _missing_module_is_plugin(err, module_name):
                broken.append(
                    f"{module_name} requires missing module {err.name!r}")
        else:
            break
    else:
        # Every candidate failed. Keep the historical message when the
        # plugin is simply absent; add the cause when a dependency is.
        detail = ""
        if broken:
            detail = " (" + "; ".join(broken) + ")"
        sys.exit(f"Unable to load {plugin_type} plugin {name}{detail}. Exiting.")

    if not hasattr(plugin, 'add_arguments'):
        sys.exit(f"{module_name} module is not a plugin. Exiting.")

    print(f"Loaded {module_name} as the {plugin_type}.")
    return plugin
|
|
def main():
    """Parse the command line and run a loader/saver checkpoint conversion.

    The saver plugin's ``save_checkpoint`` runs in a child process and the
    loader plugin's ``load_checkpoint`` runs in this process; both receive
    the same bounded queue and the parsed arguments.

    Raises:
        SystemExit: On argument errors, when a plugin cannot be loaded, or
            when the saver process finishes with a non-zero exit code.
    """
    parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
                                     allow_abbrev=False, conflict_handler='resolve')

    parser.add_argument('--model-type', type=str, required=True,
                        choices=['GPT', 'BERT'],
                        help='Type of the model')
    parser.add_argument('--loader', type=str, default='megatron',
                        help='Module name to load checkpoint, should be on python path')
    parser.add_argument('--saver', type=str, default='megatron',
                        help='Module name to save checkpoint, should be on python path')
    parser.add_argument('--load-dir', type=str, required=True,
                        help='Directory to load model checkpoint from')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='Directory to save model checkpoint to')
    parser.add_argument('--max-queue-size', type=int, default=50,
                        help='Maximum number of tensors in the queue')
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')

    # First pass: the plugin-specific options are not registered yet, so
    # tolerate unknown arguments and only read --loader / --saver.
    known_args, _ = parser.parse_known_args()
    loader = load_plugin('loader', known_args.loader)
    saver = load_plugin('saver', known_args.saver)

    # Let each plugin register its own options, then parse strictly.
    loader.add_arguments(parser)
    saver.add_arguments(parser)

    args = parser.parse_args()

    queue = mp.Queue(maxsize=args.max_queue_size)

    print("Starting saver...")
    saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
    saver_proc.start()

    print("Starting loader...")
    try:
        loader.load_checkpoint(queue, args)
    except BaseException:
        # The saver is a non-daemon child: if it is left running, the
        # interpreter waits for it at exit and the tool can hang. Stop it
        # before letting the loader's error propagate.
        if saver_proc.is_alive():
            saver_proc.terminate()
        saver_proc.join()
        raise

    print("Waiting for saver to complete...")
    saver_proc.join()

    # Surface a failed saver instead of exiting successfully.
    if saver_proc.exitcode != 0:
        sys.exit(f"Saver process exited with code {saver_proc.exitcode}.")
|
|
|
|
# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|