| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import torch |
| | from omegaconf.omegaconf import open_dict |
| | from pytorch_lightning.trainer.trainer import Trainer |
| |
|
| | from nemo.collections.nlp.models.language_modeling.megatron_t5_prompt_learning_model import ( |
| | MegatronT5PromptLearningModel, |
| | ) |
| | from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel |
| | from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy |
| | from nemo.core.config import hydra_runner |
| | from nemo.utils.app_state import AppState |
| |
|
# apex is an optional dependency: record whether `parallel_state` could be
# imported so later code can tell if the name is actually bound.
try:
    from apex.transformer import parallel_state
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False
else:
    HAVE_APEX = True
| |
|
| |
|
# Checked at import time: abort immediately when torch reports no usable CUDA
# device. `EnvironmentError` is a built-in alias of `OSError`, so the raised
# exception type is unchanged.
if not torch.cuda.is_available():
    raise OSError("GPU is needed for the inference")
| |
|
| |
|
@hydra_runner(config_path="conf", config_name="megatron_t5_prompt_learning_inference")
def main(cfg) -> None:
    """Run prediction with a restored T5 prompt-learning model and save the text outputs.

    Config entries read from ``cfg``:
        trainer: keyword arguments forwarded to ``Trainer`` (``devices`` and
            ``num_nodes`` are also used for the parallelism sanity check).
        tensor_model_parallel_size, pipeline_model_parallel_size,
        pipeline_model_parallel_split_rank: model-parallel layout.
        virtual_prompt_model_file: checkpoint passed to ``restore_from``.
        language_model_path: path written into the restored model config.
        data.micro_batch_size, data.global_batch_size: batch sizes (default 4
            each when absent).
        data.test_ds, data.num_workers: passed to the dataset builder.
        pred_file_path: output file; one prediction per line.

    Raises:
        ImportError: if ``apex.transformer.parallel_state`` is unavailable.
        AssertionError: if ``devices * num_nodes`` differs from
            ``tensor_model_parallel_size * pipeline_model_parallel_size``.
        ValueError: if ``virtual_prompt_model_file`` or ``language_model_path``
            is missing from the config.
    """
    # `parallel_state` is only bound when the guarded apex import at the top of
    # the file succeeded. Fail up front with a clear message instead of hitting
    # a NameError after the trainer and model have already been built.
    if not HAVE_APEX:
        raise ImportError("apex.transformer is required to run this script but could not be imported")

    # The trainer is created first because it is handed to `restore_from` below.
    trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
    assert (
        cfg.trainer.devices * cfg.trainer.num_nodes
        == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
    ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"

    # When any model parallelism is requested, populate the AppState rank/size
    # fields from `fake_initialize_model_parallel` before the model is restored.
    app_state = AppState()
    if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1:
        app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
        (
            app_state.tensor_model_parallel_rank,
            app_state.pipeline_model_parallel_rank,
            app_state.model_parallel_size,
            app_state.data_parallel_size,
            app_state.pipeline_model_parallel_split_rank,
            app_state.virtual_pipeline_model_parallel_rank,
        ) = fake_initialize_model_parallel(
            world_size=app_state.model_parallel_size,
            rank=trainer.global_rank,
            tensor_model_parallel_size_=cfg.tensor_model_parallel_size,
            pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank,
        )

    # Both paths are mandatory; the message names the keys that are actually checked.
    if cfg.get('virtual_prompt_model_file', None) is None or cfg.get('language_model_path', None) is None:
        raise ValueError("virtual_prompt_model_file and language_model_path must be provided in config")

    # Resolve the batch sizes once so the restored model config and the test
    # dataloader below always agree, including when a key is absent from cfg.data.
    micro_batch_size = cfg.data.get('micro_batch_size', 4)
    global_batch_size = cfg.data.get('global_batch_size', 4)

    # First pass: fetch only the stored config so it can be patched before the
    # weights are loaded.
    prompt_learning_cfg = MegatronT5PromptLearningModel.restore_from(
        cfg.virtual_prompt_model_file, trainer=trainer, return_config=True
    )
    with open_dict(prompt_learning_cfg):
        # An empty (falsy) path leaves the stored value untouched.
        if cfg.get("language_model_path"):
            # Write to whichever key the stored config already uses
            # (presumably `pretrained_language_model_path` is the older name).
            if hasattr(prompt_learning_cfg, 'pretrained_language_model_path'):
                prompt_learning_cfg.pretrained_language_model_path = cfg.language_model_path
            else:
                prompt_learning_cfg.language_model_path = cfg.language_model_path
        prompt_learning_cfg.micro_batch_size = micro_batch_size
        prompt_learning_cfg.global_batch_size = global_batch_size

    # Second pass: restore the model itself using the patched config.
    model = MegatronT5PromptLearningModel.restore_from(
        restore_path=cfg.virtual_prompt_model_file, trainer=trainer, override_config_path=prompt_learning_cfg
    )

    # If model-parallel state has not been set up yet, launch a no-op through
    # the strategy's launcher (when there is one) and set up the environment.
    # NOTE: `is_unitialized` is spelled as in the apex API.
    if parallel_state.is_unitialized():

        def dummy():
            return

        if model.trainer.strategy.launcher is not None:
            model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
        model.trainer.strategy.setup_environment()

    model.freeze()

    # Only the second element of the returned pair (the dataloader) is needed.
    _, test_dl = model.build_virtual_prompt_dataset(
        dataset_paths=cfg.data.test_ds,
        batch_size=global_batch_size,
        for_train=False,
        drop_last=False,
        shuffle=False,
        num_workers=cfg.data.num_workers,
        pin_memory=True,
    )

    outputs = trainer.predict(model, test_dl)

    # One prediction per output line: trim surrounding whitespace and replace
    # embedded newlines with spaces so each prediction stays on its own line.
    with open(cfg.pred_file_path, "w", encoding="utf-8") as pred_file:
        for batch in outputs:
            for raw_pred in batch["preds_text"]:
                cleaned = raw_pred.strip().replace("\n", " ")
                pred_file.write(cleaned + "\n")
    print('test finish---------------------------------')
| |
|
| |
|
# Script entry point: `main` takes its config from the `hydra_runner`
# decorator, so it is called without arguments.
if __name__ == '__main__':
    main()
| |
|