NeMo_Canary / nemo /export /tensorrt_mm_exporter.py
Respair's picture
Upload folder using huggingface_hub
b386992 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import List
import numpy as np
import wrapt
from tensorrt_llm.runtime import MultimodalModelRunner as TRTLLMRunner
from nemo.deploy import ITritonDeployable
from nemo.export.multimodal.build import (
build_mllama_engine,
build_perception_engine,
build_trtllm_engine,
build_visual_engine,
extract_lora_ckpt,
)
from nemo.export.multimodal.run import MultimodalModelRunner, SpeechllmModelRunner
use_deploy = True
try:
from nemo.deploy.utils import cast_output, ndarray2img, str_ndarray2list
except Exception:
use_deploy = False
@wrapt.decorator
def noop_decorator(func):
"""No op decorator"""
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
use_pytriton = True
batch = noop_decorator
try:
from pytriton.decorators import batch, first_value
from pytriton.model_config import Tensor
except Exception:
use_pytriton = False
LOGGER = logging.getLogger("NeMo")
class TensorRTMMExporter(ITritonDeployable):
"""
Exports nemo checkpoints to TensorRT and run fast inference.
Example:
from nemo.export import TensorRTMMExporter
exporter = TensorRTMMExporter(model_dir="/path/for/model/files")
exporter.export(
visual_checkpoint_path="/path/for/nemo/checkpoint",
model_type="neva",
tensor_parallel_size=1,
)
output = exporter.forward("Hi! What is in this image?", "/path/for/input_media")
print("output: ", output)
"""
def __init__(
self,
model_dir: str,
load_model: bool = True,
modality: str = "vision",
):
self.model_dir = model_dir
self.runner = None
# vision modality is for image and video
assert modality in ["vision", "audio"]
self.modality = modality
if load_model:
self._load()
def export(
self,
visual_checkpoint_path: str,
llm_checkpoint_path: str = None,
model_type: str = "neva",
llm_model_type: str = "llama",
tensor_parallel_size: int = 1,
max_input_len: int = 4096,
max_output_len: int = 256,
max_batch_size: int = 1,
vision_max_batch_size: int = 1,
max_multimodal_len: int = 3072,
dtype: str = "bfloat16",
delete_existing_files: bool = True,
load_model: bool = True,
use_lora_plugin: str = None,
lora_target_modules: List[str] = None,
lora_checkpoint_path: str = None,
max_lora_rank: int = 64,
):
"""Export multimodal models to TRTLLM"""
if Path(self.model_dir).exists():
if delete_existing_files and len(os.listdir(self.model_dir)) > 0:
for files in os.listdir(self.model_dir):
path = os.path.join(self.model_dir, files)
try:
shutil.rmtree(path)
except OSError:
os.remove(path)
if len(os.listdir(self.model_dir)) > 0:
raise Exception("Couldn't delete all files.")
elif len(os.listdir(self.model_dir)) > 0:
raise Exception("There are files in this folder. Try setting delete_existing_files=True.")
else:
Path(self.model_dir).mkdir(parents=True, exist_ok=True)
if model_type == "mllama":
build_mllama_engine(
model_dir=self.model_dir,
checkpoint_path=visual_checkpoint_path,
tensor_parallelism_size=tensor_parallel_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_batch_size=max_batch_size,
vision_max_batch_size=vision_max_batch_size,
max_multimodal_len=max_multimodal_len,
dtype=dtype,
)
else:
if lora_checkpoint_path is not None:
tmp_dir = tempfile.TemporaryDirectory()
if os.path.isdir(lora_checkpoint_path):
lora_dir = lora_checkpoint_path
else:
raise ValueError("lora_checkpoint_path in nemo1 is not supported. It must be a directory")
llm_lora_path = [extract_lora_ckpt(lora_dir, tmp_dir.name)]
else:
tmp_dir = None
llm_lora_path = None
lora_dir = None
llm_dir = os.path.join(self.model_dir, "llm_engine")
build_trtllm_engine(
model_dir=llm_dir,
visual_checkpoint_path=visual_checkpoint_path,
llm_checkpoint_path=llm_checkpoint_path,
model_type=model_type,
llm_model_type=llm_model_type,
tensor_parallelism_size=tensor_parallel_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_batch_size=max_batch_size,
max_multimodal_len=max_multimodal_len,
dtype=dtype,
use_lora_plugin=use_lora_plugin,
lora_target_modules=lora_target_modules,
max_lora_rank=max_lora_rank,
lora_ckpt_list=llm_lora_path,
)
if model_type == "salm":
perception_dir = os.path.join(self.model_dir, "perception_engine")
build_perception_engine(perception_dir, visual_checkpoint_path, model_type, vision_max_batch_size)
else:
visual_dir = os.path.join(self.model_dir, "visual_engine")
build_visual_engine(
visual_dir,
visual_checkpoint_path if lora_dir is None else lora_dir,
model_type,
vision_max_batch_size,
)
if tmp_dir is not None:
tmp_dir.cleanup()
if load_model:
self._load()
def forward(
self,
input_text: str,
input_media: str,
batch_size: int = 1,
max_output_len: int = 30,
top_k: int = 1,
top_p: float = 0.0,
temperature: float = 1.0,
repetition_penalty: float = 1.0,
num_beams: int = 1,
lora_uids: List[str] = None,
):
"""Run forward with loaded TRTLLM engine"""
if self.runner is None:
raise Exception(
"A nemo checkpoint should be exported and " "then it should be loaded first to run inference."
)
if isinstance(self.runner, TRTLLMRunner):
self.runner.args.image_path = input_media
self.runner.args.batch_size = batch_size
self.runner.args.top_k = top_k
self.runner.args.top_p = top_p
self.runner.args.temperature = temperature
self.runner.args.repetition_penalty = repetition_penalty
self.runner.args.num_beams = num_beams
raw_image = self.runner.load_test_data(input_media)
return self.runner.run(
input_text,
raw_image,
max_output_len,
)[1]
else:
input_media = self.runner.load_test_media(input_media)
return self.runner.run(
input_text,
input_media,
max_output_len,
batch_size,
top_k,
top_p,
temperature,
repetition_penalty,
num_beams,
lora_uids,
)
def get_input_media_tensors(self):
"""Get input media tensors"""
if self.modality == "vision":
return [Tensor(name="input_media", shape=(-1, -1, -1, 3), dtype=np.uint8)]
elif self.modality == "audio":
return [
Tensor(name="input_signal", shape=(-1,), dtype=np.single),
Tensor(name="input_signal_length", shape=(1,), dtype=np.intc),
]
return []
@property
def get_triton_input(self):
inputs = (
[Tensor(name="input_text", shape=(-1,), dtype=bytes)]
+ self.get_input_media_tensors()
+ [
Tensor(name="batch_size", shape=(-1,), dtype=np.int_, optional=True),
Tensor(name="max_output_len", shape=(-1,), dtype=np.int_, optional=True),
Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True),
Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True),
Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True),
Tensor(name="repetition_penalty", shape=(-1,), dtype=np.single, optional=True),
Tensor(name="num_beams", shape=(-1,), dtype=np.int_, optional=True),
Tensor(name="lora_uids", shape=(-1,), dtype=bytes, optional=True),
]
)
inputs = tuple(inputs)
return inputs
@property
def get_triton_output(self):
outputs = (Tensor(name="outputs", shape=(-1,), dtype=bytes),)
return outputs
@batch
@first_value("batch_size", "max_output_len", "top_k", "top_p", "temperature", "repetition_penalty", "num_beams")
def triton_infer_fn(self, **inputs: np.ndarray):
try:
if self.runner is None:
raise Exception(
"A nemo checkpoint should be exported and then it should be loaded first to run inference."
)
infer_input = {"input_text": str_ndarray2list(inputs.pop("input_text")[0])}
video_model_list = ["video-neva", "lita", "vita"]
if self.runner.model_type in ["neva", "vila", "mllama"]:
infer_input["input_image"] = ndarray2img(inputs.pop("input_media")[0])[0]
elif self.runner.model_type in video_model_list:
infer_input["input_image"] = inputs.pop("input_media")[0]
elif self.runner.model_type == "salm":
infer_input["input_signal"] = inputs.pop("input_signal")
infer_input["input_signal_length"] = inputs.pop("input_signal_length")[:, 0]
if "batch_size" in inputs:
infer_input["batch_size"] = inputs.pop("batch_size")
if "max_output_len" in inputs:
infer_input["max_new_tokens"] = inputs.pop("max_output_len")
if "top_k" in inputs:
infer_input["top_k"] = inputs.pop("top_k")
if "top_p" in inputs:
infer_input["top_p"] = inputs.pop("top_p")
if "temperature" in inputs:
infer_input["temperature"] = inputs.pop("temperature")
if "repetition_penalty" in inputs:
infer_input["repetition_penalty"] = inputs.pop("repetition_penalty")
if "num_beams" in inputs:
infer_input["num_beams"] = inputs.pop("num_beams")
if "lora_uids" in inputs:
lora_uids = np.char.decode(inputs.pop("lora_uids").astype("bytes"), encoding="utf-8")
infer_input["lora_uids"] = lora_uids[0].tolist()
if isinstance(self.runner, TRTLLMRunner):
self.runner.args.batch_size = infer_input.pop("batch_size")
self.runner.args.top_k = infer_input.pop("top_k")
self.runner.args.top_p = infer_input.pop("top_p")
self.runner.args.temperature = infer_input.pop("temperature")
self.runner.args.repetition_penalty = infer_input.pop("repetition_penalty")
self.runner.args.num_beams = infer_input.pop("num_beams")
output_texts = self.runner.run(**infer_input)[1]
else:
output_texts = self.runner.run(**infer_input)
output = cast_output(output_texts, np.bytes_)
except Exception as error:
err_msg = "An error occurred: {0}".format(str(error))
output = cast_output([err_msg], np.bytes_)
return {"outputs": output}
def _load(self):
llm_dir = os.path.join(self.model_dir, "llm_engine")
if not os.path.exists(llm_dir):
return
if self.modality == "vision":
import json
visual_dir = os.path.join(self.model_dir, "visual_engine")
with open(os.path.join(visual_dir, "config.json"), "r") as f:
config = json.load(f)
if config["builder_config"]["model_type"] == "mllama":
from types import SimpleNamespace
args = SimpleNamespace(
visual_engine_dir=visual_dir,
visual_engine_name="visual_encoder.engine",
llm_engine_dir=llm_dir,
hf_model_dir='meta-llama/Llama-3.2-11B-Vision-Instruct',
use_py_session=True,
cross_kv_cache_fraction=0.5,
enable_context_fmha_fp32_acc=None,
enable_chunked_context=False,
kv_cache_free_gpu_memory_fraction=0.9,
multi_block_mode=True,
)
self.runner = TRTLLMRunner(args)
else:
self.runner = MultimodalModelRunner(visual_dir, llm_dir, self.modality)
elif self.modality == "audio":
perception_dir = os.path.join(self.model_dir, "perception_engine")
self.runner = SpeechllmModelRunner(perception_dir, llm_dir, self.modality)