| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| import logging |
| from pathlib import Path |
| from typing import List, Optional |
|
|
| import numpy as np |
| import torch |
| import torch.distributed |
| import wrapt |
| from jinja2 import Template |
| from megatron.core.dist_checkpointing.validation import StrictHandling |
| from megatron.core.inference.common_inference_params import CommonInferenceParams |
| from megatron.core.inference.inference_request import InferenceRequest |
|
|
| import nemo.lightning as nl |
| from nemo.collections.llm import inference |
| from nemo.deploy import ITritonDeployable |
| from nemo.deploy.utils import NEMO2, broadcast_list, cast_output, nemo_checkpoint_version, str_ndarray2list |
|
|
|
|
def noop_decorator(func):
    """A no-op decorator that returns a transparent pass-through wrapper.

    Used as a fallback when pytriton's ``batch`` decorator is not available.

    Note: the previous version applied ``@wrapt.decorator`` to this function.
    ``wrapt.decorator`` invokes the wrapped callable with the signature
    ``(wrapped, instance, args, kwargs)``, so a one-argument closure-style
    decorator like this one raised ``TypeError`` on every call of the
    decorated function. A plain closure decorator is what is intended here.

    Args:
        func: The function to decorate.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` unchanged.
    """

    def wrapper(*args, **kwargs):
        """Forward the call to ``func`` without modification."""
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
use_pytriton = True
batch = noop_decorator


def first_value(*args, **kwargs):
    """No-op stand-in for pytriton's ``first_value`` decorator factory.

    ``first_value`` is invoked with arguments at class-definition time
    (``@first_value("max_length", ...)``), so without this fallback the module
    fails to import with a ``NameError`` when pytriton is absent, defeating the
    ``use_pytriton`` fallback below. Returns a no-op decorator.
    """
    return noop_decorator


try:
    from pytriton.decorators import batch, first_value
    from pytriton.model_config import Tensor
except Exception:
    # pytriton is optional; fall back to the no-op decorators defined above.
    use_pytriton = False

LOGGER = logging.getLogger("NeMo")
|
|
|
|
class MegatronLLMDeploy:
    """
    A factory class for creating deployable instances of Megatron LLM models.

    This class provides a method to get the appropriate deployable instance
    based on the version of the NeMo checkpoint model used.
    """

    @staticmethod
    def get_deployable(
        nemo_checkpoint_filepath: str,
        num_devices: int = 1,
        num_nodes: int = 1,
        tensor_model_parallel_size: int = 1,
        pipeline_model_parallel_size: int = 1,
        context_parallel_size: int = 1,
        max_batch_size: int = 32,
        random_seed: Optional[int] = None,
        enable_flash_decode: bool = False,
        legacy_ckpt: bool = False,
    ):
        """
        Returns the appropriate deployable instance for the given NeMo checkpoint.

        Args:
            nemo_checkpoint_filepath (str): Path to the .nemo checkpoint file.
            num_devices (int): Number of devices to use for deployment.
            num_nodes (int): Number of nodes to use for deployment.
            tensor_model_parallel_size (int): Size of the tensor model parallelism.
            pipeline_model_parallel_size (int): Size of the pipeline model parallelism.
            context_parallel_size (int): Size of the context parallelism.
            max_batch_size (int): Maximum inference batch size. Defaults to 32.
            random_seed (Optional[int]): Random seed for inference. Defaults to None.
            enable_flash_decode (bool): Whether to enable flash decode for inference.
            legacy_ckpt (bool): Whether to load the checkpoint with relaxed
                (logged) strictness for older checkpoints. Defaults to False.

        Returns:
            ITritonDeployable: An instance of a deployable class compatible with Triton inference server.

        Raises:
            ValueError: If the checkpoint is not a NeMo 2.0 checkpoint.
        """
        if nemo_checkpoint_version(nemo_checkpoint_filepath) == NEMO2:
            return MegatronLLMDeployableNemo2(
                nemo_checkpoint_filepath=nemo_checkpoint_filepath,
                num_devices=num_devices,
                num_nodes=num_nodes,
                tensor_model_parallel_size=tensor_model_parallel_size,
                pipeline_model_parallel_size=pipeline_model_parallel_size,
                context_parallel_size=context_parallel_size,
                max_batch_size=max_batch_size,
                random_seed=random_seed,
                enable_flash_decode=enable_flash_decode,
                legacy_ckpt=legacy_ckpt,
            )
        # ValueError (a subclass of Exception, so existing handlers still match)
        # is more precise than the previous bare Exception.
        raise ValueError("Only NeMo 2.0 checkpoint is supported.")
|
|
|
|
def dict_to_str(messages):
    """
    Serialize a Python object (e.g. a dict or a list of chat messages) to its
    JSON string representation.
    """
    return json.dumps(messages)
|
|
|
|
class MegatronLLMDeployableNemo2(ITritonDeployable):
    """
    Triton inference server compatible deploy class for a .nemo model file.

    Args:
        nemo_checkpoint_filepath (str): path for the nemo checkpoint.
        num_devices (int): number of GPUs.
        num_nodes (int): number of nodes.
        tensor_model_parallel_size (int): tensor parallelism.
        pipeline_model_parallel_size (int): pipeline parallelism.
        context_parallel_size (int): context parallelism.
        expert_model_parallel_size (int): expert model parallelism.
        expert_tensor_parallel_size (int): expert tensor parallelism.
        params_dtype (torch.dtype): dtype used for model parameters at inference.
            Defaults to torch.bfloat16.
        inference_batch_times_seqlen_threshold (int): batch-size x sequence-length
            threshold passed to the inference engine.
        inference_max_seq_length (int): max_seq_length for inference. Required by MCoreEngine (>=0.12).
            Defaults to 4096.
        max_batch_size (int): max batch size for inference. Defaults to 32.
        random_seed (Optional[int]): random seed for inference. Defaults to None.
        enable_flash_decode (bool): enable flash decode for inference. Defaults to True.
        legacy_ckpt (bool): load the checkpoint with relaxed (logged) strictness
            for older checkpoints. Defaults to False.
    """

    def __init__(
        self,
        nemo_checkpoint_filepath: str = None,
        num_devices: int = 1,
        num_nodes: int = 1,
        tensor_model_parallel_size: int = 1,
        pipeline_model_parallel_size: int = 1,
        context_parallel_size: int = 1,
        expert_model_parallel_size: int = 1,
        expert_tensor_parallel_size: int = 1,
        params_dtype: torch.dtype = torch.bfloat16,
        inference_batch_times_seqlen_threshold: int = 1000,
        inference_max_seq_length: int = 4096,
        max_batch_size: int = 32,
        random_seed: Optional[int] = None,
        enable_flash_decode: bool = True,
        legacy_ckpt: bool = False,
    ):
        self.nemo_checkpoint_filepath = nemo_checkpoint_filepath

        # Inference-only strategy: optimizers are neither set up nor stored.
        strategy = nl.MegatronStrategy(
            tensor_model_parallel_size=tensor_model_parallel_size,
            pipeline_model_parallel_size=pipeline_model_parallel_size,
            context_parallel_size=context_parallel_size,
            expert_model_parallel_size=expert_model_parallel_size,
            expert_tensor_parallel_size=expert_tensor_parallel_size,
            sequence_parallel=False,
            setup_optimizers=False,
            store_optimizer_states=False,
            # LOG_ALL loads legacy checkpoints leniently, logging (not raising on)
            # any key mismatches.
            ckpt_load_strictness=StrictHandling.LOG_ALL if legacy_ckpt else None,
        )

        trainer = nl.Trainer(
            accelerator="gpu",
            devices=num_devices,
            num_nodes=num_nodes,
            strategy=strategy,
            # NOTE(review): the precision plugin is hard-coded to bf16 regardless of
            # params_dtype; params_dtype only reaches setup_mcore_engine below —
            # confirm this is intentional for non-bf16 deployments.
            plugins=nl.MegatronMixedPrecision(
                precision="bf16-mixed",
                params_dtype=torch.bfloat16,
                pipeline_dtype=torch.bfloat16,
                autocast_enabled=False,
                grad_reduce_in_fp32=False,
            ),
        )

        self.mcore_engine, self.inference_wrapped_model, self.mcore_tokenizer = inference.setup_mcore_engine(
            path=Path(nemo_checkpoint_filepath),
            trainer=trainer,
            params_dtype=params_dtype,
            inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
            inference_max_seq_length=inference_max_seq_length,
            max_batch_size=max_batch_size,
            random_seed=random_seed,
            enable_flash_decode=enable_flash_decode,
        )

    def generate(
        self, prompts: List[str], inference_params: Optional[CommonInferenceParams] = None
    ) -> List[InferenceRequest]:
        """
        Generates text based on the provided input prompts.

        Args:
            prompts (List[str]): A list of input strings.
            inference_params (Optional[CommonInferenceParams]): Parameters for controlling the inference process.

        Returns:
            List[InferenceRequest]: A list containing the generated results.
        """
        inference_params = inference_params or CommonInferenceParams()
        results = self.mcore_engine.generate(
            prompts=prompts,
            add_BOS=False,
            common_inference_params=inference_params,
        )
        return list(results)

    def generate_other_ranks(self):
        """
        Service loop for ranks other than rank 0.

        Blocks on a control message broadcast from rank 0: a value of 0 means
        "participate in a generation step" (the prompts and sampling parameters
        follow via broadcast_list); any other value terminates the loop.
        """
        while True:
            message = torch.empty(1, dtype=torch.long, device="cuda")
            torch.distributed.broadcast(message, src=0)
            if message == 0:
                # Rank 0 broadcasts prompts first, then the sampling parameters
                # (see triton_infer_fn for the matching sends).
                prompts = broadcast_list(data=[None], src=0)
                temperature, top_k, top_p, num_tokens_to_generate, log_probs = broadcast_list(data=[None], src=0)

                inference_params = CommonInferenceParams(
                    temperature=temperature,
                    top_k=int(top_k),
                    top_p=float(top_p),
                    num_tokens_to_generate=num_tokens_to_generate,
                    return_log_probs=log_probs,
                )

                self.generate(prompts, inference_params)
            else:
                return

    def apply_chat_template(self, messages, add_generation_prompt=True):
        """
        Render messages through the tokenizer's chat template.

        Works when the model's tokenizer has a chat template (typically chat models).

        Args:
            messages: the conversation to render (passed as ``messages`` to the template).
            add_generation_prompt (bool): whether the template should append the
                assistant generation prompt.

        Returns:
            str: the rendered prompt string.

        Raises:
            ValueError: if the underlying tokenizer does not define a chat template.
        """
        try:
            tokenizer_chat_template = self.mcore_tokenizer.tokenizer.tokenizer.chat_template
            bos_token = self.mcore_tokenizer.tokenizer.tokenizer.bos_token
            template = Template(tokenizer_chat_template)
        except AttributeError as err:
            # The previous message embedded a literal line-continuation backslash
            # and a long run of indentation spaces; use implicit concatenation and
            # chain the original error for context.
            raise ValueError(
                "The tokenizer does not have chat template, if you would like to evaluate chat model "
                "ensure your model's tokenizer has a chat template"
            ) from err

        rendered_output = template.render(
            messages=messages, bos_token=bos_token, add_generation_prompt=add_generation_prompt
        )

        return rendered_output

    def remove_eos_token(self, text):
        """
        Removes the EOS token (everything from its last occurrence onward) from
        each string in ``text``; strings without the EOS token pass through unchanged.
        """
        eos_token = self.mcore_tokenizer.tokenizer.tokenizer.eos_token
        output = []
        for t in text:
            if eos_token in t:
                output.append(t.rsplit(eos_token, 1)[0])
            else:
                output.append(t)
        return output

    def str_to_dict(self, json_str):
        """
        Convert a JSON string to the corresponding Python object.
        """
        return json.loads(json_str)

    @property
    def get_triton_input(self):
        """Triton input tensor specs accepted by ``triton_infer_fn``."""
        inputs = (
            Tensor(name="prompts", shape=(-1,), dtype=bytes),
            Tensor(name="max_length", shape=(-1,), dtype=np.int_, optional=True),
            Tensor(name="max_batch_size", shape=(-1,), dtype=np.int_, optional=True),
            Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True),
            Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True),
            Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True),
            Tensor(name="random_seed", shape=(-1,), dtype=np.int_, optional=True),
            Tensor(name="compute_logprob", shape=(-1,), dtype=np.bool_, optional=True),
            Tensor(name="apply_chat_template", shape=(-1,), dtype=np.bool_, optional=True),
            Tensor(name="n_top_logprobs", shape=(-1,), dtype=np.int_, optional=True),
            Tensor(name="echo", shape=(-1,), dtype=np.bool_, optional=True),
        )
        return inputs

    @property
    def get_triton_output(self):
        """Triton output tensor specs produced by ``triton_infer_fn``."""
        return (
            Tensor(name="sentences", shape=(-1,), dtype=bytes),
            Tensor(name="log_probs", shape=(-1,), dtype=np.single),
            Tensor(name="top_logprobs", shape=(-1,), dtype=bytes),
        )

    @batch
    @first_value(
        "max_length",
        "max_batch_size",
        "top_k",
        "top_p",
        "temperature",
        "random_seed",
        "compute_logprob",
        "apply_chat_template",
        "n_top_logprobs",
        "echo",
    )
    def triton_infer_fn(self, **inputs: np.ndarray):
        """
        Triton-facing inference entry point.

        Decodes the batched Triton tensors, optionally applies the chat template,
        broadcasts the request to other ranks, runs generation, and packs the
        sentences / log-probs into Triton-compatible numpy arrays.
        """
        prompts = str_ndarray2list(inputs.pop("prompts"))
        temperature = inputs.pop("temperature", 1.0)
        top_k = inputs.pop("top_k", 1)
        top_p = inputs.pop("top_p", 0.0)
        num_tokens_to_generate = inputs.pop("max_length", 256)
        log_probs = inputs.pop("compute_logprob", False)
        apply_chat_template = inputs.pop("apply_chat_template", False)
        top_logprobs = inputs.pop("n_top_logprobs", 0)
        echo = inputs.pop("echo", False)
        # Always return plain strings (the `else r` branches below are kept for
        # parity with a possible non-text mode but are currently unreachable).
        text_only = True

        if apply_chat_template:
            # Prompts arrive as JSON-serialized conversations for chat models.
            prompts = [self.str_to_dict(prompt) for prompt in prompts]
            prompts = [self.apply_chat_template(prompt) for prompt in prompts]

        if torch.distributed.is_initialized():
            if torch.distributed.get_world_size() > 1:
                # Control value 0 tells generate_other_ranks() to join this step.
                torch.distributed.broadcast(torch.tensor([0], dtype=torch.long, device="cuda"), src=0)
                broadcast_list(prompts, src=0)
                # NOTE(review): top_logprobs and echo are not broadcast — other
                # ranks run with defaults for those; confirm this matches
                # generate_other_ranks(), which unpacks exactly these five values.
                broadcast_list(
                    data=[
                        temperature,
                        top_k,
                        top_p,
                        num_tokens_to_generate,
                        log_probs,
                    ],
                    src=0,
                )

        inference_params = CommonInferenceParams(
            temperature=temperature,
            top_k=int(top_k),
            top_p=float(top_p),
            num_tokens_to_generate=num_tokens_to_generate,
            return_log_probs=log_probs,
            top_n_logprobs=top_logprobs,
        )

        results = self.generate(prompts, inference_params)
        if echo:
            # Echo mode prepends the prompt to the generated continuation.
            output_texts = [r.prompt + r.generated_text if text_only else r for r in results]
        else:
            output_texts = [r.generated_text if text_only else r for r in results]
        output_texts = self.remove_eos_token(output_texts)
        output_infer = {"sentences": cast_output(output_texts, np.bytes_)}
        if log_probs:
            output_log_probs = []
            for r in results:
                # In echo mode, include the prompt's log-probs ahead of the
                # generated ones.
                if echo:
                    lp = torch.tensor(r.prompt_log_probs + r.generated_log_probs).cpu().detach().numpy()
                else:
                    lp = torch.tensor(r.generated_log_probs).cpu().detach().numpy()
                if len(lp) == 0:
                    # Keep a placeholder so every request contributes one row.
                    output_log_probs.append([0])
                else:
                    output_log_probs.append(lp)
            if echo:
                # Echoed sequences can differ in length; right-pad with zeros so
                # they stack into a rectangular array.
                max_len = max(len(arr) for arr in output_log_probs)
                padded = np.array(
                    [np.pad(arr, (0, max_len - len(arr)), constant_values=0) for arr in output_log_probs]
                )
                output_infer["log_probs"] = padded
            else:
                output_infer["log_probs"] = np.array(output_log_probs)
        if top_logprobs:
            output_top_n_log_probs = []
            for r in results:
                # Serialize each request's top-n log-prob structure to JSON.
                top_n_lp = dict_to_str(r.generated_top_n_logprobs)
                output_top_n_log_probs.append(top_n_lp)
            output_infer["top_logprobs"] = cast_output(output_top_n_log_probs, np.bytes_)

        return output_infer
|
|