| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from omegaconf.omegaconf import OmegaConf, open_dict |
| | from pytorch_lightning import Trainer |
| | from pytorch_lightning.plugins.environments import TorchElasticEnvironment |
| | from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin |
| |
|
| | from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel |
| | from nemo.collections.nlp.modules.common.megatron.mup.shape import make_base_shapes |
| | from nemo.collections.nlp.parts.nlp_overrides import GradScaler, MegatronHalfPrecisionPlugin, NLPDDPStrategy |
| | from nemo.core.config import hydra_runner |
| | from nemo.utils import logging |
| |
|
| |
|
@hydra_runner(config_path="conf", config_name="megatron_retro_mutransfer")
def main(cfg) -> None:
    """Build a base and a delta MegatronRetrievalModel and compute their base shapes.

    A single PyTorch Lightning ``Trainer`` is configured from ``cfg.trainer``
    (an ``NLPDDPStrategy`` plus optional precision / cluster-environment plugins)
    and shared by both models. The two models are built from ``cfg.base_model``
    and ``cfg.delta_model`` and passed to ``make_base_shapes`` (from NeMo's muP
    ``shape`` module) with ``savefile=cfg.model.shape_file``.

    Args:
        cfg: OmegaConf configuration injected by ``hydra_runner`` from
            ``conf/megatron_retro_mutransfer``. Keys read here: ``trainer``,
            ``model`` (``megatron_amp_O2``, ``gradient_as_bucket_view``,
            ``native_amp_init_scale``, ``native_amp_growth_interval``,
            ``hysteresis``, ``shape_file``), ``base_model``, ``delta_model`` and
            the optional top-level ``cluster_type``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
    plugins = []
    strategy = NLPDDPStrategy(
        # When megatron_amp_O2 is enabled, the DDP communication hook is disabled.
        no_ddp_communication_hook=bool(megatron_amp_o2),
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )

    # Mixed precision: fp16 (precision == 16) gets a loss scaler; bf16 runs
    # without one, so `scaler` stays None in that case.
    if cfg.trainer.precision in [16, 'bf16']:
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        # megatron_amp_O2 selects NeMo's Megatron half-precision plugin instead
        # of Lightning's native mixed-precision plugin.
        if megatron_amp_o2:
            plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
        else:
            plugins.append(NativeMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

    # On the 'BCP' cluster type the process environment is taken from TorchElastic.
    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)

    # Copy the trainer precision into both model configs so the models are built
    # with the same precision as the Trainer. `open_dict` lifts OmegaConf struct
    # mode so the key can be set even if a model config does not declare it.
    with open_dict(cfg):
        cfg.base_model.precision = cfg.trainer.precision
        cfg.delta_model.precision = cfg.trainer.precision

    base_model = MegatronRetrievalModel(cfg.base_model, trainer)
    delta_model = MegatronRetrievalModel(cfg.delta_model, trainer)
    make_base_shapes(base_model, delta_model, savefile=cfg.model.shape_file)
| |
|
| |
|
if __name__ == '__main__':
    # Run only when executed as a script; importing this module triggers nothing.
    main()
| |
|