| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from omegaconf.omegaconf import OmegaConf |
| | from pytorch_lightning import Trainer |
| |
|
| | from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel |
| | from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank |
| | from nemo.collections.nlp.parts.nlp_overrides import ( |
| | NLPDDPStrategy, |
| | NLPNativeMixedPrecisionPlugin, |
| | NLPPrecisionPlugin, |
| | NLPSaveRestoreConnector, |
| | ) |
| | from nemo.core.config import hydra_runner |
| | from nemo.utils import logging |
| | from nemo.utils.app_state import AppState |
| |
|
| |
|
@hydra_runner(config_path="conf", config_name="megatron_gpt_config")
def main(cfg) -> None:
    """Restore a Megatron GPT model from a checkpoint and run ``trainer.test`` on it.

    Args:
        cfg: Experiment configuration supplied through ``hydra_runner``. This
            function reads ``cfg.trainer`` (expanded as keyword arguments to
            ``Trainer``), ``cfg.trainer.precision``,
            ``cfg.model.tensor_model_parallel_size``,
            ``cfg.model.data.splits_string``, ``cfg.restore_from_path`` and,
            optionally, ``cfg.model.native_amp_init_scale`` and
            ``cfg.model.native_amp_growth_interval``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    # Select the precision plugin that matches the requested trainer precision.
    precision = cfg.trainer.precision
    if precision == 16:
        plugins = [
            NLPNativeMixedPrecisionPlugin(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
            ),
        ]
    elif precision == 'bf16':
        # The bf16 plugin is not part of this file's module-level imports, so
        # referencing it directly raised NameError. Import it where it is used.
        # NOTE(review): assumes the class lives next to the other NLP* plugins
        # in nlp_overrides -- confirm against the installed NeMo version.
        from nemo.collections.nlp.parts.nlp_overrides import NLPNativeBfloat16PrecisionPlugin

        plugins = [NLPNativeBfloat16PrecisionPlugin()]
    else:
        plugins = [NLPPrecisionPlugin()]

    trainer = Trainer(plugins=plugins, strategy=NLPDDPStrategy(), **cfg.trainer)

    # Record the model-parallel layout in the shared AppState before restoring.
    app_state = AppState()
    app_state.model_parallel_size = cfg.model.tensor_model_parallel_size
    app_state.model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size)

    model = MegatronGPTModel.restore_from(
        cfg.restore_from_path, trainer=trainer, save_restore_connector=NLPSaveRestoreConnector(),
    )

    # Override the restored model's data split with the one from the run config.
    model.cfg.data.splits_string = cfg.model.data.splits_string

    trainer.test(model)
| |
|
| |
|
# Script entry point: invoke main() only when this file is executed directly.
# main() is called without arguments; cfg is presumably supplied by the
# hydra_runner decorator -- verify against nemo.core.config.hydra_runner.
if __name__ == '__main__':
    main()
| |
|