| |
|
|
| """Sample Generate""" |
| import os |
| import sys |
| import warnings |
|
|
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) |
| import os |
| import sys |
| from argparse import Namespace |
| from contextlib import nullcontext |
|
|
| from megatron.core.inference.engines.abstract_engine import AbstractEngine |
| from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( |
| InferenceWrapperConfig, |
| ) |
| from megatron.core.inference.sampling_params import SamplingParams |
| from megatron.core.inference.text_generation_controllers.text_generation_controller import ( |
| TextGenerationController, |
| ) |
| import torch |
|
|
| from megatron.core.inference.engines import AbstractEngine, StaticInferenceEngine |
| from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( |
| InferenceWrapperConfig, |
| ) |
| from megatron.training import get_model |
| from megatron.core.transformer.module import MegatronModule |
| from megatron.inference.text_generation import beam_search_and_post_process |
| from megatron.inference.text_generation.mcore_engine_server import ( |
| ModelInferenceWrapperServer, |
| run_mcore_engine, |
| ) |
| from megatron.inference.text_generation_server import MegatronServer |
| from megatron.training import print_rank_0 |
|
|
| from megatron.core import mpu |
| from megatron.training import get_args, get_model, get_tokenizer |
| from megatron.training.checkpointing import load_checkpoint |
| from megatron.training.initialize import initialize_megatron |
|
|
|
|
| from megatron.post_training.arguments import add_modelopt_args |
|
|
|
|
| def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine: |
| """Get the relevant backend for running inference |
| |
| This function will automatically choose the TRTLLMBackend when possible, and default to Mcore |
| backend if the user does not specify any backends. TRTLLMBackend is not implmented yet. |
| |
| Args: |
| args (Namespace): The user arguments parsed from command line |
| model (MegatronModule): The megatron model. |
| |
| Returns: |
| AbstractBackend: The chosen backend |
| """ |
| tokenizer = get_tokenizer() |
|
|
| inference_wrapper_config = InferenceWrapperConfig( |
| hidden_size=args.hidden_size, |
| inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold, |
| fp32_residual_connection=args.fp32_residual_connection, |
| params_dtype=args.params_dtype, |
| padded_vocab_size=args.padded_vocab_size, |
| inference_max_seq_length=args.inference_max_seq_length, |
| inference_max_requests=args.inference_max_batch_size, |
| nccl_all_reduce_for_prefill=args.nccl_all_reduce_for_prefill, |
| ) |
|
|
| inference_wrapped_model = ModelInferenceWrapperServer(model, inference_wrapper_config) |
| text_generation_controller = TextGenerationController( |
| inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer |
| ) |
| return StaticInferenceEngine( |
| text_generation_controller=text_generation_controller, |
| max_batch_size=args.inference_max_batch_size, |
| ) |
|
|
|
|
| def add_text_generate_args(parser): |
| """Adds text generation arguments to parser.""" |
| group = parser.add_argument_group(title='text generation') |
| group.add_argument( |
| "--port", type=int, default=5000, help='port for text generation server to run on' |
| ) |
| group.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.') |
| group.add_argument("--top_k", type=int, default=1, help='Top k sampling.') |
| group.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.') |
| group.add_argument( |
| "--return-log-probs", |
| action='store_true', |
| default=True, |
| help='Return the log probabilities of the final output tokens', |
| ) |
| group.add_argument( |
| "--num-tokens-to-generate", |
| type=int, |
| default=30, |
| help='Number of tokens to generate for each prompt', |
| ) |
| group.add_argument( |
| "--prompts", |
| metavar='N', |
| type=str, |
| nargs='+', |
| help='Input prompts with each prompt within quotes and seperated by space', |
| ) |
| group.add_argument( |
| "--max-batch-size", |
| type=int, |
| default=None, |
| help='Deprecated in favor of `--inference-max-batch-size`', |
| ) |
| add_modelopt_args(parser) |
| return parser |
|
|
|
|
| @torch.inference_mode() |
| def main(model_provider: str = "gpt"): |
| """Runs the text generation server with the specified model provider.""" |
| initialize_megatron( |
| extra_args_provider=add_text_generate_args, |
| args_defaults={ |
| 'no_load_rng': True, |
| 'no_load_optim': True, |
| 'exit_on_missing_checkpoint': True, |
| }, |
| ) |
| args = get_args() |
| if args.num_layers_per_virtual_pipeline_stage is not None: |
| print("Interleaved pipeline schedule is not yet supported for text generation.") |
| exit() |
| print_rank_0("WARNING: Forcing exit_on_missing_checkpoint to True for text " "generation.") |
| args.exit_on_missing_checkpoint = True |
|
|
| |
| load_context = nullcontext() |
| if args.fp8: |
| from transformer_engine.pytorch.fp8 import fp8_model_init |
|
|
| load_context = fp8_model_init() |
| with load_context: |
|
|
| from megatron.post_training.model_provider import model_provider as modelopt_model_provider |
| if model_provider == "gpt": |
| model = get_model(modelopt_model_provider, wrap_with_ddp=False) |
| elif model_provider == "mamba": |
| pass |
| else: |
| raise ValueError(f"Invalid model provider {model_provider}") |
|
|
| if args.load is not None: |
| _ = load_checkpoint(model, None, None, strict=False) |
|
|
| assert len(model) == 1, "Above condition should have caught this" |
| model = model[0] |
| model.eval() |
|
|
| if args.max_batch_size is not None: |
| assert args.inference_max_batch_size is not None |
| args.inference_max_batch_size = max(args.inference_max_batch_size, args.max_batch_size) |
| warnings.warn( |
| "`--max-batch-size` has been deprecated in favor of `--inference-max-requests`, " |
| f"setting maximum batch size to {args.inference_max_batch_size}" |
| ) |
|
|
| inference_engine = get_inference_engine(args, model) |
|
|
| if args.enable_cuda_graph: |
| print(f"Running warmup for CUDA graphs...") |
| inference_engine.generate( |
| prompts=["Test prompt"], sampling_params=SamplingParams(num_tokens_to_generate=10) |
| ) |
|
|
| if ( |
| mpu.is_pipeline_first_stage() |
| and mpu.get_tensor_model_parallel_rank() == 0 |
| and mpu.get_expert_model_parallel_rank() == 0 |
| ): |
| server = MegatronServer(inference_engine, args) |
| server.run("0.0.0.0", port=args.port) |
|
|
| while True: |
| choice = torch.tensor(1, dtype=torch.long, device='cuda') |
| torch.distributed.broadcast(choice, 0) |
| if choice.item() == 0: |
| try: |
| run_mcore_engine(inference_engine) |
| except ValueError as ve: |
| pass |
| elif choice.item() == 1: |
| try: |
| beam_search_and_post_process( |
| inference_engine.text_generation_controller.inference_wrapped_model.model |
| ) |
| except ValueError as ve: |
| pass |
|
|
|
|
| if __name__ == "__main__": |
| main(model_provider="gpt") |
|
|