| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Builds a .nemo file with average weights over multiple .ckpt files (assumes .ckpt files in same folder as .nemo file). |
| |
| Usage example for building *-averaged.nemo for a given .nemo file: |
| |
| NeMo/scripts/checkpoint_averaging/checkpoint_averaging.py my_model.nemo |
| |
| Usage example for building *-averaged.nemo files for all results in sub-directories under current path: |
| |
| find . -name '*.nemo' | grep -v -- "-averaged.nemo" | xargs NeMo/scripts/checkpoint_averaging/checkpoint_averaging.py |
| |
| |
| NOTE: if you get the following error `AttributeError: Can't get attribute '???' on <module '__main__' from '???'>` |
| use --import_fname_list <FILE> with all files that contain missing classes. |
| """ |
|
|
| import argparse |
| import glob |
| import importlib |
| import os |
| import sys |
|
|
| import torch |
| from lightning.pytorch.trainer.trainer import Trainer |
| from omegaconf.omegaconf import OmegaConf, open_dict |
|
|
| from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector |
| from nemo.core import ModelPT |
| from nemo.utils import logging, model_utils |
|
|
|
|
def _find_checkpoints(folder):
    """
    Return the paths of the ``.ckpt`` files located directly inside ``folder``.

    Files whose name ends with ``-last.ckpt`` are skipped. The result is sorted
    so that the accumulation order (and the log output) does not depend on the
    arbitrary ordering of ``os.listdir``.
    """
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.endswith('.ckpt') and not name.endswith('-last.ckpt')
    )


def _average_checkpoints(checkpoint_paths, device):
    """
    Load every checkpoint in ``checkpoint_paths`` and return the averaged state dict.

    Args:
        checkpoint_paths: non-empty list of checkpoint file paths.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        The state dict of the first checkpoint with every other checkpoint added
        to it key by key, then divided by the number of checkpoints. Entries
        whose dtype string starts with ``torch.int`` are left as the
        accumulated sum and are not divided.
    """
    n = len(checkpoint_paths)
    avg_state = None

    logging.info(f"Averaging {n} checkpoints ...")

    for path in checkpoint_paths:
        # NOTE(security): torch.load unpickles the file; only run this script on
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(path, map_location=device)
        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']

        if avg_state is None:
            # The first checkpoint seeds the running sum.
            avg_state = checkpoint

            logging.info(f"Initialized average state dict with checkpoint : {path}")
        else:
            # Every later checkpoint is accumulated key by key.
            for k in avg_state:
                avg_state[k] = avg_state[k] + checkpoint[k]

            logging.info(f"Updated average state dict with state from checkpoint : {path}")

    for k in avg_state:
        # Integer-typed entries (presumably counters rather than weights) keep
        # their accumulated value; everything else becomes the mean.
        if not str(avg_state[k].dtype).startswith("torch.int"):
            avg_state[k] = avg_state[k] / n

    return avg_state


def main():
    """
    Build a ``*-averaged.nemo`` file for every model given on the command line.

    For each positional argument (a ``.nemo`` file, or a folder holding exactly
    one non-averaged ``.nemo`` file) the model is restored, the weights of all
    ``.ckpt`` files found next to it (``*-last.ckpt`` excluded) are averaged,
    loaded into the model, and the result is saved as ``<name>-averaged.nemo``.

    Raises:
        RuntimeError: if a folder argument does not contain exactly one
            non-averaged ``.nemo`` file, or if no checkpoint to average is found
            next to a model.
    """

    logging.info("This script is deprecated and will be removed in the 25.01 release.")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'model_fname_list',
        metavar='N',
        type=str,
        nargs='+',
        help='Input .nemo files (or folders who contains them) to parse',
    )
    parser.add_argument(
        '--import_fname_list',
        type=str,
        nargs='+',
        default=[],
        help='A list of Python file names to "from FILE import *"',
    )
    parser.add_argument(
        '--class_path',
        type=str,
        default='',
        help='A path to class "module.submodule.class" (if given)',
    )
    args = parser.parse_args()

    logging.info(
        "\n\nIMPORTANT: Use --import_fname_list for all files that contain missing classes:\n\t"
        "(AttributeError: Can't get attribute '???' on <module '__main__' from '???'>)\n\n"
    )

    # Emulate "from FILE import *" so that classes referenced by the pickled
    # checkpoints can be resolved from this module's namespace.
    for import_fname in args.import_fname_list:
        logging.info(f"Importing * from {import_fname}")
        sys.path.insert(0, os.path.dirname(import_fname))
        module_name = os.path.splitext(os.path.basename(import_fname))[0]
        globals().update(importlib.import_module(module_name).__dict__)

    device = torch.device("cpu")

    trainer = Trainer(strategy=NLPDDPStrategy(), devices=1, num_nodes=1, precision=16, accelerator='gpu')

    total = len(args.model_fname_list)
    # Each argument is either a .nemo file or a folder that contains one.
    for model_index, model_fname in enumerate(args.model_fname_list, start=1):
        if not model_fname.endswith(".nemo"):
            # Treat the argument as a folder; previously averaged outputs are ignored.
            nemo_files = [
                candidate
                for candidate in glob.glob(os.path.join(model_fname, "*.nemo"))
                if not candidate.endswith("-averaged.nemo")
            ]
            if len(nemo_files) != 1:
                raise RuntimeError(f"Expected only a single .nemo files but discovered {len(nemo_files)} .nemo files")

            model_fname = nemo_files[0]

        # os.path.dirname() returns '' for a bare file name such as
        # "my_model.nemo", and os.listdir('') fails, so fall back to the
        # current directory.
        model_folder_path = os.path.dirname(model_fname) or '.'
        stem, ext = os.path.splitext(model_fname)
        avg_model_fname = f"{stem}-averaged{ext}"

        logging.info(f"\n===> [{model_index} / {total}] Parsing folder {model_folder_path}\n")

        # Look for checkpoints before the model restore so that a folder
        # without any checkpoint fails fast with a clear message.
        checkpoint_paths = _find_checkpoints(model_folder_path)
        if not checkpoint_paths:
            raise RuntimeError(f"No .ckpt files to average were found in {model_folder_path}")

        # First pass over the .nemo file: fetch only its configuration.
        model_cfg = ModelPT.restore_from(
            restore_path=model_fname,
            return_config=True,
            save_restore_connector=NLPSaveRestoreConnector(),
            trainer=trainer,
        )
        # An explicit --class_path takes precedence over the class recorded in the config.
        classpath = args.class_path if args.class_path else model_cfg.target

        OmegaConf.set_struct(model_cfg, True)
        with open_dict(model_cfg):
            # Switch megatron_amp_O2 off when the stored config has it enabled.
            if model_cfg.get('megatron_amp_O2', False):
                model_cfg.megatron_amp_O2 = False
        imported_class = model_utils.import_class_by_path(classpath)
        logging.info(f"Loading model {model_fname}")
        nemo_model = imported_class.restore_from(
            restore_path=model_fname,
            map_location=device,
            save_restore_connector=NLPSaveRestoreConnector(),
            trainer=trainer,
            override_config_path=model_cfg,
        )

        avg_state = _average_checkpoints(checkpoint_paths, device)

        # Load the averaged weights into the restored model and write it out.
        nemo_model.load_state_dict(avg_state, strict=True)

        logging.info(f"Saving average mdel to: {avg_model_fname}")
        nemo_model.save_to(avg_model_fname)
|
|
|
|
# Run the averaging workflow only when the file is executed as a script.
if __name__ == "__main__":
    main()
|
|