| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from pathlib import Path |
|
|
| |
| import torch._dynamo |
| import torch.multiprocessing as mp |
| from omegaconf.omegaconf import OmegaConf, open_dict |
|
|
| from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel |
| from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder |
| from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector |
| from nemo.core.config import hydra_runner |
| from nemo.utils import logging |
| from nemo.utils.exp_manager import exp_manager |
|
|
# Import-time global setting: turn on TorchDynamo's `suppress_errors` flag,
# presumably so torch.compile failures are suppressed (falling back to eager)
# instead of aborting the run — confirm against torch._dynamo.config.
torch._dynamo.config.suppress_errors = True

# Import-time global setting: use the "spawn" start method for
# torch.multiprocessing; force=True replaces any start method already set.
mp.set_start_method("spawn", force=True)
|
|
|
|
@hydra_runner(config_path="conf", config_name="megatron_gpt_config")
def main(cfg) -> None:
    """Pretrain, or continue training, a Megatron GPT model.

    ``cfg`` is supplied by the ``hydra_runner`` decorator from the
    ``conf/megatron_gpt_config`` config. After building the trainer and
    experiment manager, the model is created in one of three ways:

    1. ``cfg.model.restore_from_path`` is set: restore weights via
       ``MegatronGPTModel.restore_from`` using a config produced by
       ``MegatronGPTSFTModel.merge_cfg_with``.
    2. ``cfg.model.restore_from_ckpt`` is set: assign the checkpoint path to
       ``trainer.ckpt_path`` and build a fresh model; weights and optimizer
       states are presumably loaded from it by the trainer during ``fit``.
    3. Neither is set: build a fresh ``MegatronGPTModel``.

    If both restore options are set, ``restore_from_path`` takes precedence
    and a warning is logged, because ``restore_from_ckpt`` is then ignored.

    Args:
        cfg: The Hydra/OmegaConf experiment configuration; reads
            ``cfg.model`` and ``cfg.exp_manager``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    trainer = MegatronTrainerBuilder(cfg).create_trainer()
    exp_manager(trainer, cfg.exp_manager)

    restore_path = cfg.model.get("restore_from_path")
    restore_ckpt = cfg.model.get("restore_from_ckpt")

    # The branches below are mutually exclusive; without this warning a
    # configured restore_from_ckpt would be dropped silently.
    if restore_path is not None and restore_ckpt is not None:
        logging.warning(
            f"Both model.restore_from_path ({restore_path}) and model.restore_from_ckpt ({restore_ckpt}) "
            "are set; using restore_from_path and ignoring restore_from_ckpt."
        )

    if restore_path is not None:
        # Option 1: continual training from restored model weights.
        logging.info(f"Continual training: loading weights from {restore_path}")
        # Imported only on this path (presumably to avoid loading the SFT
        # model module when it is not needed).
        from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel

        model_cfg = MegatronGPTSFTModel.merge_cfg_with(restore_path, cfg)
        model = MegatronGPTModel.restore_from(
            restore_path=restore_path,
            override_config_path=model_cfg,
            trainer=trainer,
            save_restore_connector=NLPSaveRestoreConnector(),
        )
    elif restore_ckpt is not None:
        # Option 2: hand the checkpoint to the trainer so fit() resumes from it.
        logging.info(f"Continual training: loading weights and optimizer states from {restore_ckpt}")
        trainer.ckpt_path = Path(restore_ckpt)
        model = MegatronGPTModel(cfg.model, trainer)
    else:
        # Option 3: fresh model (any resume is left to the trainer/exp_manager setup).
        model = MegatronGPTModel(cfg.model, trainer)

    trainer.fit(model)
|
|
|
|
if __name__ == '__main__':
    # main() is called without `cfg`: the hydra_runner decorator supplies it.
    main()
|
|