|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from argparse import ArgumentParser |
|
|
|
|
|
import torch |
|
|
import torch.distributed |
|
|
from megatron.core.inference.common_inference_params import CommonInferenceParams |
|
|
|
|
|
import nemo.lightning as nl |
|
|
from nemo.collections.llm import api |
|
|
|
|
|
""" |
|
|
torchrun --nproc-per-node=8 /opt/NeMo/scripts/llm/generate.py \ |
|
|
--model_path=<PATH_TO_NEMO2_MODEL> \ |
|
|
--tp=8 \ |
|
|
--devices=8 \ |
|
|
--num_tokens_to_generate=40 \ |
|
|
--temperature=0.001 \ |
|
|
--top_p=0.0 \ |
|
|
--top_k=1 \ |
|
|
--fp8 |
|
|
""" |
|
|
|
|
|
|
|
|
def get_args(argv=None):
    """Parse the command line arguments for the generation script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for testing). Defaults to None,
            which preserves the original behavior of reading ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = ArgumentParser(description="""Run generation on a few sample prompts given the checkpoint path.""")
    parser.add_argument(
        "--prompts",
        type=str,
        nargs="+",
        default=[
            "Q: How are you?",
            "Q: How big is the universe?",
            "Q: How is the weather?",
            "Q: How many stars are there?",
            # Typos fixed: "is know for" -> "is known for", "In a hot" -> "On a hot".
            "Paris is known for its ",
            "On a hot sunny day, you should ",
            "Q: How many planets are in the solar system?",
            "Q: How old are you?",
        ],
        help="List of prompt strings",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="""Path to NeMo 2 checkpoint""",
    )
    parser.add_argument(
        "--tp",
        type=int,
        default=1,
        help="""Tensor parallel size""",
    )
    parser.add_argument(
        "--pp",
        type=int,
        default=1,
        help="""Pipeline parallel size""",
    )
    parser.add_argument(
        "--ep",
        type=int,
        default=1,
        help="""Expert parallel size""",
    )
    parser.add_argument(
        "--devices",
        type=int,
        default=1,
        help="""Number of GPUs to use on a single node""",
    )
    parser.add_argument(
        "--nodes",
        type=int,
        default=1,
        help="""Number of nodes to use""",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="""Temperature to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.95,
        help="""top_p to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=0,
        help="""top_k to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
    )
    parser.add_argument(
        "--add_BOS",
        action="store_true",
        help="""Whether to add BOS token to the prompt""",
    )
    parser.add_argument(
        "--num_tokens_to_generate",
        type=int,
        default=25,
        help="""Number of tokens to generate per prompt""",
    )
    parser.add_argument(
        "--fp8",
        action="store_true",
        help="""Whether to run inference in FP8 precision""",
    )
    parser.add_argument(
        "--fp8_recipe",
        type=str,
        default="tensorwise",
        help="""fp8 recipe, can be 'tensorwise', 'delayed', or 'mxfp8'""",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=8,
        help="""Maximum batch size for inference""",
    )
    parser.add_argument(
        "--random_seed",
        type=int,
        default=1234,
        help="""Random seed for generation""",
    )
    parser.add_argument(
        "--legacy_ckpt",
        action="store_true",
        help="""Load ckpt saved with TE < 1.14""",
    )
    # Passing argv=None falls back to sys.argv, so CLI behavior is unchanged.
    args = parser.parse_args(argv)
    return args
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
args = get_args() |
|
|
|
|
|
if args.fp8: |
|
|
assert len(args.prompts) % 8 == 0, "Batch size should be divisible by 8 for FP8 inference" |
|
|
|
|
|
strategy = nl.MegatronStrategy( |
|
|
tensor_model_parallel_size=args.tp, |
|
|
pipeline_model_parallel_size=args.pp, |
|
|
expert_model_parallel_size=args.ep, |
|
|
expert_tensor_parallel_size=1 if args.ep > 1 else None, |
|
|
context_parallel_size=1, |
|
|
sequence_parallel=False, |
|
|
setup_optimizers=False, |
|
|
store_optimizer_states=False, |
|
|
) |
|
|
|
|
|
trainer = nl.Trainer( |
|
|
accelerator="gpu", |
|
|
devices=args.devices, |
|
|
num_nodes=args.nodes, |
|
|
strategy=strategy, |
|
|
plugins=nl.MegatronMixedPrecision( |
|
|
precision="bf16-mixed", |
|
|
params_dtype=torch.bfloat16, |
|
|
pipeline_dtype=torch.bfloat16, |
|
|
autocast_enabled=False, |
|
|
grad_reduce_in_fp32=False, |
|
|
fp8="hybrid" if args.fp8 else None, |
|
|
fp8_recipe=args.fp8_recipe if args.fp8 else None, |
|
|
fp8_amax_history_len=1, |
|
|
fp8_amax_compute_algo="max" if args.fp8 else "most_recent", |
|
|
), |
|
|
) |
|
|
|
|
|
|
|
|
if args.legacy_ckpt: |
|
|
trainer.strategy.ckpt_load_strictness = False |
|
|
|
|
|
prompts = args.prompts |
|
|
|
|
|
results = api.generate( |
|
|
path=args.model_path, |
|
|
prompts=prompts, |
|
|
trainer=trainer, |
|
|
add_BOS=args.add_BOS, |
|
|
inference_params=CommonInferenceParams( |
|
|
temperature=args.temperature, |
|
|
top_p=args.top_p, |
|
|
top_k=args.top_k, |
|
|
num_tokens_to_generate=args.num_tokens_to_generate, |
|
|
), |
|
|
text_only=True, |
|
|
max_batch_size=args.max_batch_size, |
|
|
random_seed=args.random_seed, |
|
|
) |
|
|
if torch.distributed.get_rank() == 0: |
|
|
for i, r in enumerate(results): |
|
|
print(prompts[i]) |
|
|
print("*" * 50) |
|
|
print(r) |
|
|
print("\n\n") |
|
|
|