| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import os |
| import torch.multiprocessing as mp |
| from omegaconf.omegaconf import OmegaConf |
| from nemo.collections.nlp.models.language_modeling.megatron_griffin_sft_model import MegatronGriffinSFTModel |
| from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder |
| from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP |
| from nemo.core.config import hydra_runner |
| from nemo.utils import logging |
| from nemo.utils.model_utils import inject_model_parallel_rank |
|
|
|
|
| mp.set_start_method("spawn", force=True) |
|
|
|
|
| @hydra_runner(config_path="conf", config_name="megatron_griffin_generate_config") |
| def main(cfg) -> None: |
| logging.info("\n\n************** Experiment configuration ***********") |
| logging.info(f"\n{OmegaConf.to_yaml(cfg)}") |
| trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() |
|
|
| if cfg.model.peft.restore_from_path: |
| model_cfg = MegatronGriffinSFTModel.merge_inference_cfg(cfg.model.peft.restore_from_path, cfg) |
| else: |
| model_cfg = MegatronGriffinSFTModel.merge_inference_cfg(cfg.model.restore_from_path, cfg) |
|
|
| model = MegatronGriffinSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) |
|
|
| if cfg.model.peft.restore_from_path: |
| model.load_adapters(cfg.model.peft.restore_from_path) |
| elif cfg.model.peft.restore_from_ckpt.checkpoint_dir and cfg.model.peft.restore_from_ckpt.checkpoint_name: |
| peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] |
| checkpoint_path = os.path.join( |
| cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name |
| ) |
| |
| if not os.path.isdir(checkpoint_path): |
| |
| checkpoint_path = inject_model_parallel_rank( |
| os.path.join( |
| cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name |
| ) |
| ) |
| model.load_adapters(checkpoint_path, peft_cfgs=peft_cfg_cls(model_cfg)) |
| else: |
| raise NotImplementedError("distributed checkpointing of PEFT weights is not supported") |
|
|
| model.freeze() |
| logging.info(f"Freezing parameters for PEFT eval:\n{model.summarize()}") |
|
|
| trainer.test(model) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|