import os

from examples.nlp.language_modeling.megatron_gpt_eval import RequestDataSet
from lightning.pytorch import Trainer
from omegaconf.omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader

from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.core.config import hydra_runner

try:
    from megatron.core import parallel_state

    HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):

    HAVE_MEGATRON_CORE = False
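# NOTE: HAVE_MEGATRON_CORE is only set here; main() below uses parallel_state
# unconditionally, so megatron.core must be installed for generation to run.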
|
|
| """ |
| This is the script to run RETRO Model text generation. |
| (This inferencing script for native NeMo RETRO will be soon deprecated. For new inferencing script for mcore RETRO, see ./megatron_retro_eval.py) |
| |
| Usage: |
| Assume the model has TP=1, PP=1 |
| run greedy inference from a nemo file: |
| python megatron_retro_eval.py \ |
| trainer.devices=1 \ |
| trainer.num_nodes=1 \ |
| trainer.accelerator=gpu \ |
| trainer.precision=16 \ |
| inference.tokens_to_generate=128 \ |
| inference.greedy=True \ |
| retro_model_file=path_to_retro_nemo_file \ |
| tensor_model_parallel_size=-1 \ |
| pipeline_model_parallel_size=-1 \ |
| retrieval_service.faiss_devices='0' \ |
| retrieval_service.faiss_index=path_to_faiss_index \ |
| retrieval_service.retrieval_index=path_to_retrieval_dataset \ |
| retrieval_service.neighbors=20 |
| """ |
|
|
|
|
@hydra_runner(config_path="conf", config_name="megatron_retro_inference_legacy")
def main(cfg) -> None:
    trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)

    model_path = cfg.retro_model_file

    save_restore_connector = NLPSaveRestoreConnector()

    if os.path.isdir(model_path):
        save_restore_connector.model_extracted_dir = model_path

    # Restore only the checkpoint's config first, so it can be patched for
    # inference before the full model restore below.
    model_cfg = MegatronRetrievalModel.restore_from(
        model_path,
        trainer=trainer,
        return_config=True,
        save_restore_connector=save_restore_connector,
    )

    with open_dict(model_cfg):
        model_cfg.precision = trainer.precision
        model_cfg.sequence_parallel = False
        model_cfg.activations_checkpoint_granularity = None
        model_cfg.activations_checkpoint_method = None

    # A negative parallel size means "take the value stored in the checkpoint".
    if (
        cfg.tensor_model_parallel_size < 0
        or cfg.pipeline_model_parallel_size < 0
        or cfg.get('pipeline_model_parallel_split_rank', -1) < 0
    ):
        with open_dict(cfg):
            cfg.tensor_model_parallel_size = model_cfg.get('tensor_model_parallel_size', 1)
            cfg.pipeline_model_parallel_size = model_cfg.get('pipeline_model_parallel_size', 1)
            cfg.pipeline_model_parallel_split_rank = model_cfg.get('pipeline_model_parallel_split_rank', 0)

    model = MegatronRetrievalModel.restore_from(
        model_path,
        trainer=trainer,
        save_restore_connector=save_restore_connector,
        override_config_path=model_cfg,
    )
|
    # LengthParam and SamplingParam are TypedDicts from NeMo's text_generation
    # module (imported above); the two dicts below regroup the cfg.inference
    # values into the shape that model.generate() expects.
    length_params: LengthParam = {
        "max_length": cfg.inference.tokens_to_generate,
        "min_length": cfg.inference.min_tokens_to_generate,
    }

    sampling_params: SamplingParam = {
        "use_greedy": cfg.inference.greedy,
        "temperature": cfg.inference.temperature,
        "top_k": cfg.inference.top_k,
        "top_p": cfg.inference.top_p,
        "repetition_penalty": cfg.inference.repetition_penalty,
        "add_BOS": cfg.inference.add_BOS,
        "all_probs": cfg.inference.all_probs,
        "compute_logprob": cfg.inference.compute_logprob,
    }
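    # When inference.greedy is True, the sampling knobs above (temperature,
    # top_k, top_p) are typically ignored in favor of argmax decoding.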
|
|
    # Check whether the DDP / model-parallel state is initialized; note that
    # `is_unitialized` (sic) is the spelling used by this legacy megatron.core API.
    if parallel_state.is_unitialized():

        def dummy():
            return

        # Launching a no-op function forces Lightning to set up the distributed
        # environment.
        if model.trainer.strategy.launcher is not None:
            model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
        model.trainer.strategy.setup_environment()
|
    # Hand the inference settings and the retrieval-service configuration to the
    # model; these drive neighbor retrieval during generation.
    config = OmegaConf.to_container(cfg.inference)
    retrieval_service = OmegaConf.to_container(cfg.retrieval_service)
    model.set_inference_config(config, retrieval_service)
|
    if not cfg.use_predict_method:
        # First method: call model.generate() directly on the list of prompts.
        response = model.generate(
            inputs=OmegaConf.to_container(cfg.prompts),
            length_params=length_params,
            sampling_params=sampling_params,
            strategy=model.inference_strategy,
        )
    else:
        # Second method: wrap the prompts in a dataset and run the Lightning
        # predict loop.
        ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
        request_dl = DataLoader(dataset=ds, batch_size=cfg.inference_batch_size)
        response = trainer.predict(model, request_dl)
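    # `response` is a dict of generation results (e.g. the generated sentences)
    # when produced by model.generate(), or a list of per-batch outputs when it
    # comes from trainer.predict(); exact keys depend on the NeMo version.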
|
|
    print("***************************")
    print(response)
    print("***************************")


if __name__ == '__main__':
    main()