Spaces:
Running on Zero
Running on Zero
File size: 7,415 Bytes
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Sample Generate"""
import os
import sys
import warnings
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))
import os
import sys
from argparse import Namespace
from contextlib import nullcontext
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
import torch
from megatron.core.inference.engines import AbstractEngine, StaticInferenceEngine
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from megatron.training import get_model
from megatron.core.transformer.module import MegatronModule
from megatron.inference.text_generation import beam_search_and_post_process
from megatron.inference.text_generation.mcore_engine_server import (
ModelInferenceWrapperServer,
run_mcore_engine,
)
from megatron.inference.text_generation_server import MegatronServer
from megatron.training import print_rank_0
from megatron.core import mpu
from megatron.training import get_args, get_model, get_tokenizer
from megatron.training.checkpointing import load_checkpoint
from megatron.training.initialize import initialize_megatron
from megatron.post_training.arguments import add_modelopt_args
def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine:
    """Build the MCore static inference engine used by the text generation server.

    Only the MCore backend is wired up here: this function always returns a
    ``StaticInferenceEngine``. Selecting a TRT-LLM backend is not implemented.

    The model is wrapped in ``ModelInferenceWrapperServer`` using an
    ``InferenceWrapperConfig`` built from ``args``. That wrapper, together with
    the Megatron tokenizer, feeds a ``TextGenerationController``, which is
    handed to the engine.

    Args:
        args (Namespace): Parsed Megatron arguments. The fields read here are
            ``hidden_size``, ``inference_batch_times_seqlen_threshold``,
            ``fp32_residual_connection``, ``params_dtype``, ``padded_vocab_size``,
            ``inference_max_seq_length``, ``inference_max_batch_size`` and
            ``nccl_all_reduce_for_prefill``.
        model (MegatronModule): The Megatron model to wrap for inference.

    Returns:
        AbstractEngine: A ``StaticInferenceEngine`` whose maximum batch size is
        ``args.inference_max_batch_size``.
    """
    tokenizer = get_tokenizer()
    inference_wrapper_config = InferenceWrapperConfig(
        hidden_size=args.hidden_size,
        inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold,
        fp32_residual_connection=args.fp32_residual_connection,
        params_dtype=args.params_dtype,
        padded_vocab_size=args.padded_vocab_size,
        inference_max_seq_length=args.inference_max_seq_length,
        # The wrapper's request limit mirrors the engine's max batch size below.
        inference_max_requests=args.inference_max_batch_size,
        nccl_all_reduce_for_prefill=args.nccl_all_reduce_for_prefill,
    )
    inference_wrapped_model = ModelInferenceWrapperServer(model, inference_wrapper_config)
    text_generation_controller = TextGenerationController(
        inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer
    )
    return StaticInferenceEngine(
        text_generation_controller=text_generation_controller,
        max_batch_size=args.inference_max_batch_size,
    )
def add_text_generate_args(parser):
    """Register the text generation server's command-line arguments.

    Adds a "text generation" argument group to ``parser``, then lets
    ``add_modelopt_args`` register the ModelOpt arguments on the same parser.

    Args:
        parser (argparse.ArgumentParser): The parser passed in through
            ``initialize_megatron(extra_args_provider=...)``.

    Returns:
        argparse.ArgumentParser: The same ``parser`` with the arguments added.
    """
    group = parser.add_argument_group(title='text generation')
    group.add_argument(
        "--port", type=int, default=5000, help='port for text generation server to run on'
    )
    group.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.')
    group.add_argument("--top_k", type=int, default=1, help='Top k sampling.')
    group.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.')
    # Log probabilities are returned by default. A ``store_true`` flag whose
    # default is already True can never yield False, so a matching
    # ``--no-return-log-probs`` switch writes to the same destination to make
    # the behavior actually configurable.
    group.add_argument(
        "--return-log-probs",
        action='store_true',
        default=True,
        help='Return the log probabilities of the final output tokens',
    )
    group.add_argument(
        "--no-return-log-probs",
        dest='return_log_probs',
        action='store_false',
        help='Do not return the log probabilities of the final output tokens',
    )
    group.add_argument(
        "--num-tokens-to-generate",
        type=int,
        default=30,
        help='Number of tokens to generate for each prompt',
    )
    group.add_argument(
        "--prompts",
        metavar='N',
        type=str,
        nargs='+',
        help='Input prompts with each prompt within quotes and separated by space',
    )
    # Kept only for backward compatibility; main() folds it into
    # --inference-max-batch-size.
    group.add_argument(
        "--max-batch-size",
        type=int,
        default=None,
        help='Deprecated in favor of `--inference-max-batch-size`',
    )
    add_modelopt_args(parser)
    return parser
def _build_model(args, model_provider):
    """Build the model selected by ``model_provider`` and load its checkpoint.

    Args:
        args (Namespace): Parsed Megatron arguments (``fp8`` and ``load`` are read).
        model_provider (str): Which model to build; only ``"gpt"`` is implemented.

    Returns:
        MegatronModule: The single model chunk, switched to eval mode.

    Raises:
        NotImplementedError: If ``model_provider`` is ``"mamba"``.
        ValueError: If ``model_provider`` is not recognized.
    """
    # With --fp8, construct the model inside Transformer Engine's fp8_model_init context.
    load_context = nullcontext()
    if args.fp8:
        from transformer_engine.pytorch.fp8 import fp8_model_init

        load_context = fp8_model_init()
    with load_context:
        from megatron.post_training.model_provider import model_provider as modelopt_model_provider

        if model_provider == "gpt":
            model = get_model(modelopt_model_provider, wrap_with_ddp=False)
        elif model_provider == "mamba":
            # This branch used to fall through with ``model`` unbound and crash
            # later with an UnboundLocalError; fail fast with a clear message.
            raise NotImplementedError("The 'mamba' model provider is not supported yet.")
        else:
            raise ValueError(f"Invalid model provider {model_provider}")

    if args.load is not None:
        _ = load_checkpoint(model, None, None, strict=False)

    # main() exits early for interleaved (virtual) pipelines, so exactly one
    # model chunk is expected here.
    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]
    model.eval()
    return model


def _serve_worker_requests(inference_engine):
    """Loop forever, executing the command that global rank 0 broadcasts.

    On each iteration a command code is received via
    ``torch.distributed.broadcast`` from rank 0. ``0`` runs the MCore engine and
    ``1`` runs beam search on the wrapped model; any other code is ignored.
    This function never returns.
    """
    while True:
        choice = torch.tensor(1, dtype=torch.long, device='cuda')
        torch.distributed.broadcast(choice, 0)
        command = choice.item()
        # ValueErrors are deliberately swallowed so one bad request does not
        # bring the worker down. NOTE(review): presumably the serving rank
        # reports the error to the client — confirm in MegatronServer.
        if command == 0:
            try:
                run_mcore_engine(inference_engine)
            except ValueError:
                pass
        elif command == 1:
            try:
                beam_search_and_post_process(
                    inference_engine.text_generation_controller.inference_wrapped_model.model
                )
            except ValueError:
                pass


@torch.inference_mode()
def main(model_provider: str = "gpt"):
    """Run the text generation server with the specified model provider.

    Steps:
    1. Initialize Megatron with the text-generation and ModelOpt arguments.
    2. Build the model and load its checkpoint.
    3. Create the inference engine and, if requested, warm up CUDA graphs.
    4. Start the HTTP server on the rank that is on the first pipeline stage
       with tensor- and expert-parallel rank 0.
    5. Enter the distributed request loop.

    Args:
        model_provider (str): Which model to build. Only ``"gpt"`` is
            implemented; ``"mamba"`` raises ``NotImplementedError`` and any
            other value raises ``ValueError``.
    """
    initialize_megatron(
        extra_args_provider=add_text_generate_args,
        args_defaults={
            'no_load_rng': True,
            'no_load_optim': True,
            'exit_on_missing_checkpoint': True,
        },
    )

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        # sys.exit() does not rely on the ``site``-injected ``exit`` builtin.
        sys.exit()
    print_rank_0("WARNING: Forcing exit_on_missing_checkpoint to True for text " "generation.")
    args.exit_on_missing_checkpoint = True

    model = _build_model(args, model_provider)

    # Backward compatibility: fold the deprecated --max-batch-size into
    # --inference-max-batch-size, which is the value actually used below.
    if args.max_batch_size is not None:
        assert args.inference_max_batch_size is not None
        args.inference_max_batch_size = max(args.inference_max_batch_size, args.max_batch_size)
        warnings.warn(
            "`--max-batch-size` has been deprecated in favor of `--inference-max-batch-size`, "
            f"setting maximum batch size to {args.inference_max_batch_size}"
        )

    inference_engine = get_inference_engine(args, model)

    if args.enable_cuda_graph:
        print("Running warmup for CUDA graphs...")
        inference_engine.generate(
            prompts=["Test prompt"], sampling_params=SamplingParams(num_tokens_to_generate=10)
        )

    if (
        mpu.is_pipeline_first_stage()
        and mpu.get_tensor_model_parallel_rank() == 0
        and mpu.get_expert_model_parallel_rank() == 0
    ):
        server = MegatronServer(inference_engine, args)
        server.run("0.0.0.0", port=args.port)

    _serve_worker_requests(inference_engine)
if __name__ == "__main__":
    # Script entry point: serve the GPT model built by the ModelOpt model provider.
    main("gpt")
|