# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
try:
import decord
except Exception:
import logging
logging.warning("The package `decord` was not installed in this environment.")
import einops
import numpy as np
import soundfile as sf
import tensorrt as trt
import tensorrt_llm
import tensorrt_llm.profiler as profiler
import torch
import yaml
from PIL import Image
from tensorrt_llm import logger
from tensorrt_llm._utils import str_dtype_to_trt, torch_dtype_to_trt
from tensorrt_llm.runtime import ModelRunner, Session, TensorInfo
from torch.nn import functional as F
from torchvision import transforms
from transformers import AutoProcessor, CLIPImageProcessor
from nemo.export.utils.constants import TRTLLM_ENGINE_DIR
def trt_dtype_to_torch(dtype):
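    """Map a TensorRT dtype to the corresponding torch dtype."""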
if dtype == trt.float16:
return torch.float16
elif dtype == trt.float32:
return torch.float32
elif dtype == trt.int32:
return torch.int32
elif dtype == trt.bfloat16:
return torch.bfloat16
else:
        raise TypeError(f"{dtype} is not supported")
class MultimodalModelRunner:
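    """Runner for NeMo multimodal models (NeVA, Video-NeVA, VILA, LITA, VITA)
    exported to TensorRT-LLM: a visual-encoder TensorRT engine paired with a
    TensorRT-LLM decoder engine, with visual features injected through the
    prompt-table (p-tuning) mechanism.
    """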
def __init__(self, visual_engine_dir, llm_engine_dir, modality='vision'):
self.modality = modality
self.runtime_rank = tensorrt_llm.mpi_rank()
device_id = self.runtime_rank % torch.cuda.device_count()
torch.cuda.set_device(device_id)
self.device = "cuda:%d" % (device_id)
self.stream = torch.cuda.Stream(torch.cuda.current_device())
torch.cuda.set_stream(self.stream)
# parse model type from visual engine config
with open(os.path.join(visual_engine_dir, "config.json"), "r") as f:
config = json.load(f)
self.model_type = config['builder_config']['model_type']
self.vision_precision = config['builder_config']['precision']
self.modality_precision = config['builder_config']['precision']
self.num_frames = config['builder_config'].get('num_frames', None)
self.image_size = config['builder_config'].get('image_size', None)
self.profiling_iterations = 20
if modality == 'vision':
self.init_image_encoder(visual_engine_dir)
self.init_tokenizer(llm_engine_dir)
self.init_llm(os.path.join(llm_engine_dir, TRTLLM_ENGINE_DIR)) # Engine is stored in subdirectory
if self.model_type == 'lita' or self.model_type == 'vila' or self.model_type == 'vita':
self.init_vision_preprocessor(visual_engine_dir)
def init_tokenizer(self, llm_engine_dir):
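        """Load the tokenizer: an HF AutoTokenizer when tokenizer_config.json
        exists, otherwise a minimal HF-style wrapper around tokenizer.model
        (SentencePiece)."""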
if os.path.exists(os.path.join(llm_engine_dir, "tokenizer_config.json")):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(llm_engine_dir)
self.tokenizer.pad_token = self.tokenizer.eos_token
if self.model_type == 'vita':
self.tokenizer.im_start_id = self.tokenizer.convert_tokens_to_ids("<extra_id_4>")
self.tokenizer.im_end_id = self.tokenizer.convert_tokens_to_ids("<extra_id_5>")
self.tokenizer.vid_start_id = self.tokenizer.convert_tokens_to_ids("<extra_id_8>")
self.tokenizer.vid_end_id = self.tokenizer.convert_tokens_to_ids("<extra_id_9>")
else:
from sentencepiece import SentencePieceProcessor
sp = SentencePieceProcessor(os.path.join(llm_engine_dir, 'tokenizer.model'))
class return_obj:
def __init__(self, input_ids):
self.input_ids = input_ids
def __getitem__(self, name):
                    if name == "input_ids":
return self.input_ids
else:
raise AttributeError(f"'return_obj' has no item '{name}'")
# sentencepiece does not follow the same interface as HF
class HFTokenizerInterface:
def encode(self, x, return_tensors=None, **kwargs):
out = sp.encode(x)
if return_tensors == "pt":
out = torch.tensor(out)
return return_obj(out)
def __call__(self, x, return_tensors=None, **kwargs):
return self.encode(x, return_tensors, **kwargs)
def decode(self, x, **kwargs):
return sp.decode(x.tolist())
def batch_decode(self, x, **kwargs):
return self.decode(x, **kwargs)
self.tokenizer = HFTokenizerInterface()
self.tokenizer.eos_token_id = sp.eos_id()
self.tokenizer.bos_token_id = sp.bos_id()
self.tokenizer.pad_token_id = sp.pad_id()
self.tokenizer.padding_side = "right"
if self.model_type == 'lita':
self.tokenizer.im_start_id = sp.piece_to_id("<extra_id_4>")
self.tokenizer.im_end_id = sp.piece_to_id("<extra_id_5>")
self.tokenizer.vid_start_id = sp.piece_to_id("<extra_id_8>")
self.tokenizer.vid_end_id = sp.piece_to_id("<extra_id_9>")
def init_image_encoder(self, visual_engine_dir):
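        """Deserialize visual_encoder.engine into a TensorRT-LLM Session."""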
vision_encoder_path = os.path.join(visual_engine_dir, 'visual_encoder.engine')
logger.info(f'Loading engine from {vision_encoder_path}')
with open(vision_encoder_path, 'rb') as f:
engine_buffer = f.read()
logger.info(f'Creating session from engine {vision_encoder_path}')
self.visual_encoder_session = Session.from_serialized_engine(engine_buffer)
def init_vision_preprocessor(self, visual_encoder_dir):
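        """Load nemo_config.yaml and instantiate the image processor used by
        LITA (AutoProcessor) or VILA/VITA (SiglipImageProcessor)."""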
with open(os.path.join(visual_encoder_dir, 'nemo_config.yaml'), 'r') as f:
self.nemo_config = yaml.safe_load(f)
vision_config = self.nemo_config["mm_cfg"]["vision_encoder"]
if self.model_type == 'lita':
self.image_processor = AutoProcessor.from_pretrained(
vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True
)
elif self.model_type == 'vila' or self.model_type == 'vita':
from transformers import SiglipImageProcessor
self.image_processor = SiglipImageProcessor.from_pretrained(
vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True
)
else:
raise ValueError(f"Invalid model type: {self.model_type}")
def init_llm(self, llm_engine_dir):
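        """Create the TensorRT-LLM ModelRunner for the decoder engine and keep
        its model config and tensor-parallel mapping."""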
self.model = ModelRunner.from_dir(
llm_engine_dir,
rank=tensorrt_llm.mpi_rank(),
debug_mode=False,
stream=self.stream,
)
self.model_config = self.model.session._model_config
self.runtime_mapping = self.model.session.mapping
def video_preprocess(self, video_path):
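        """Decode a video (path or ndarray) and CLIP-preprocess the frames.

        Returns a tensor of shape [1, num_frames, 3, H, W] cast to the vision
        encoder's precision.
        """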
from decord import VideoReader
if isinstance(video_path, str):
vr = VideoReader(video_path)
num_frames = self.num_frames
if num_frames == -1:
frames = [Image.fromarray(frame.asnumpy()).convert('RGB') for frame in vr]
else:
                # sample self.num_frames frames evenly across the video;
                # if the video has fewer frames than requested, repeat the last frame
num_frames = min(num_frames, len(vr))
indices = np.linspace(0, len(vr) - 1, num=num_frames, dtype=int)
frames = [Image.fromarray(vr[idx].asnumpy()).convert('RGB') for idx in indices]
if len(frames) < num_frames:
frames += [frames[-1]] * (num_frames - len(frames))
elif isinstance(video_path, np.ndarray):
num_frames = self.num_frames
if num_frames == -1:
frames = [Image.fromarray(frame).convert('RGB') for frame in video_path]
else:
                # sample self.num_frames frames evenly across the array;
                # if it has fewer frames than requested, repeat the last frame
num_frames = min(num_frames, video_path.shape[0])
indices = np.linspace(0, video_path.shape[0] - 1, num=num_frames, dtype=int)
frames = [Image.fromarray(video_path[idx]).convert('RGB') for idx in indices]
if len(frames) < num_frames:
frames += [frames[-1]] * (num_frames - len(frames))
else:
            frames = video_path
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16)
frames = processor.preprocess(frames, return_tensors="pt")['pixel_values']
# make dtype consistent with vision encoder
media_tensors = frames.to(
tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision)
) # [num_frames, 3, H, W]
return media_tensors.unsqueeze(0) # [1, num_frames, 3, H, W]
def insert_tokens_by_index(self, input_ids, num_frames):
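        """Replace the single media placeholder (token id 0) with num_frames
        <im_start> 0 <im_end> triplets followed by one <vid_start> 0 <vid_end>
        triplet, as expected by LITA/VITA prompts."""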
im_start_id = self.tokenizer.im_start_id
im_end_id = self.tokenizer.im_end_id
vid_start_id = self.tokenizer.vid_start_id
vid_end_id = self.tokenizer.vid_end_id
image_token_indices = (input_ids == 0).nonzero(as_tuple=False).squeeze().tolist()
input_ids = input_ids.squeeze().tolist()
offset = 0
# Insert the image tokens and corresponding start/end tokens
for i in range(num_frames):
idx = image_token_indices[1] + offset
input_ids.insert(idx + 1, im_end_id)
input_ids.insert(idx + 1, 0)
input_ids.insert(idx + 1, im_start_id)
offset += 3
# Insert the video start and end tokens around the video token
vid_idx = image_token_indices[1] + offset
input_ids.insert(vid_idx + 1, vid_end_id)
input_ids.insert(vid_idx + 1, 0)
input_ids.insert(vid_idx + 1, vid_start_id)
input_ids.pop(image_token_indices[1])
input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
return input_ids
def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, batch_size):
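        """Encode the media and assemble input ids, input lengths, and p-tuning
        arguments for the configured model type."""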
if not warmup:
profiler.start(self.modality.capitalize())
if not warmup:
profiler.stop(self.modality.capitalize())
if self.model_type == 'vila':
visual_features, visual_atts = self.get_visual_features(image, attention_mask)
input_ids = self.tokenizer_image_token(batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer)
batch_split_prompts = self.split_prompt_by_images(input_ids)
first_batch_split_prompts = batch_split_prompts[0]
# compute prompt length + visual length
length = sum([ids.shape[1] for ids in first_batch_split_prompts])
if batch_size == 1 and len(image) > 1:
                # mode 1: treat multiple images as a whole; flatten the visual dims
length += visual_atts.shape[0] * visual_atts.shape[1]
else:
# mode 2: multiple images individually (replicate prompt for each image)
length += visual_atts.shape[1]
input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32)
input_ids, ptuning_args = self.setup_fake_prompts_vila(
batch_size, visual_features, first_batch_split_prompts, input_lengths
)
return input_ids, input_lengths, ptuning_args, visual_features
elif self.model_type == 'lita' or self.model_type == 'vita':
visual_input = []
for i, img in enumerate(image):
visual_features, visual_atts = self.get_visual_features(img, attention_mask)
visual_features = visual_features.unsqueeze(0)
im_tokens, vid_tokens, num_sample_frames = self.preprocess_lita_visual(visual_features, self.nemo_config)
visual_input.extend([im_tokens, vid_tokens])
input_ids = self.tokenizer_image_token(batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer)
input_ids = self.insert_tokens_by_index(input_ids, num_sample_frames)
batch_splits = self.split_prompt_by_images(input_ids)
first_batch_split_prompts = batch_splits[0]
length = sum([ids.shape[1] for ids in first_batch_split_prompts])
            # Rebuild visual_atts to match the concatenated im_tokens/vid_tokens shape
im_tokens = im_tokens.view(1, -1, im_tokens.shape[-1])
visual_features = torch.cat([im_tokens, vid_tokens], dim=1)
visual_atts = torch.ones(visual_features.size()[:-1], dtype=torch.long).to(image.device)
if batch_size == 1:
length += visual_atts.shape[0] * visual_atts.shape[1]
else:
raise ValueError("Batch size greater than 1 is not supported for LITA and VITA models")
input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32)
input_ids, ptuning_args = self.setup_fake_prompts_vila(
batch_size, visual_input, first_batch_split_prompts, input_lengths
)
return input_ids, input_lengths, ptuning_args, visual_features
else:
visual_features, visual_atts = self.get_visual_features(image, attention_mask)
pre_input_ids = self.tokenizer(pre_prompt, return_tensors="pt", padding=True).input_ids
if post_prompt[0] is not None:
post_input_ids = self.tokenizer(post_prompt, return_tensors="pt", padding=True).input_ids
if self.model_type == 'video-neva':
length = (
pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[2] * visual_atts.shape[1]
)
else:
length = pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[1]
else:
post_input_ids = None
length = pre_input_ids.shape[1] + visual_atts.shape[1]
input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32)
input_ids, ptuning_args = self.setup_fake_prompts(
visual_features, pre_input_ids, post_input_ids, input_lengths
)
return input_ids, input_lengths, ptuning_args, visual_features
@staticmethod
def tokenizer_image_token(batch_size, prompt, tokenizer, image_token_index=-200):
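        """LLaVA-style tokenization: split the prompt on <image>, tokenize each
        chunk, and rejoin with image_token_index placeholders (stored as id 0)."""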
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
input_ids = torch.tensor(input_ids, dtype=torch.long)
input_ids[input_ids == image_token_index] = 0
input_ids = input_ids.unsqueeze(0).expand(batch_size, -1)
return input_ids
def split_prompt_by_images(self, tensor):
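        """Split each row of token ids at the zero (<image>) placeholders and
        return the non-empty text segments for every batch entry."""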
batch_splits = []
for batch in tensor:
# Find indices where value is zero (<image>)
zero_indices = (batch == 0).nonzero(as_tuple=False).squeeze(0)
# Add starting point for slicing
start_idx = 0
splits = []
for idx in zero_indices:
if start_idx != idx: # Ensure not slicing zero-length tensors
splits.append(batch[start_idx:idx].unsqueeze(0))
start_idx = idx + 1 # Move start index past the zero
if start_idx < len(batch): # Handle last segment if it's not zero-ending
splits.append(batch[start_idx:].unsqueeze(0))
# Remove empty tensors resulting from consecutive zeros
splits = [split for split in splits if split.numel() > 0]
batch_splits.append(splits)
return batch_splits
def generate(
self,
pre_prompt,
post_prompt,
image,
decoder_input_ids,
max_new_tokens,
attention_mask,
warmup,
batch_size,
top_k,
top_p,
temperature,
repetition_penalty,
num_beams,
lora_uids=None,
):
if not warmup:
profiler.start("Generate")
input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
warmup, pre_prompt, post_prompt, image, attention_mask, batch_size
)
if warmup:
return None
profiler.start("LLM")
end_id = self.tokenizer.eos_token_id
ptuning_args[0] = torch.stack([ptuning_args[0]])
output_ids = self.model.generate(
input_ids,
sampling_config=None,
prompt_table=ptuning_args[0],
max_new_tokens=max_new_tokens,
end_id=end_id,
pad_id=(
self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id is not None
else self.tokenizer.all_special_ids[0]
),
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
num_beams=num_beams,
output_sequence_lengths=False,
lora_uids=lora_uids,
return_dict=False,
)
profiler.stop("LLM")
if tensorrt_llm.mpi_rank() == 0:
            # Decode beam_width outputs per batch entry, skipping the prompt tokens.
output_beams_list = [
self.tokenizer.batch_decode(
output_ids[batch_idx, :, input_lengths[batch_idx] :], skip_special_tokens=True
)
for batch_idx in range(batch_size)
]
stripped_text = [
[output_beams_list[batch_idx][beam_idx].strip() for beam_idx in range(num_beams)]
for batch_idx in range(batch_size)
]
profiler.stop("Generate")
return stripped_text
else:
profiler.stop("Generate")
return None
def get_visual_features(self, image, attention_mask):
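        """Run the visual-encoder engine on the image batch and return the
        embeddings plus an all-ones attention mask over them."""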
visual_features = {'input': image.to(tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision))}
if attention_mask is not None:
visual_features['attention_mask'] = attention_mask
tensor_info = [TensorInfo('input', str_dtype_to_trt(self.vision_precision), image.shape)]
if attention_mask is not None:
tensor_info.append(TensorInfo('attention_mask', trt.DataType.INT32, attention_mask.shape))
visual_output_info = self.visual_encoder_session.infer_shapes(tensor_info)
visual_outputs = {
t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device=image.device)
for t in visual_output_info
}
ok = self.visual_encoder_session.run(visual_features, visual_outputs, self.stream.cuda_stream)
assert ok, "Runtime execution failed for vision encoder session"
self.stream.synchronize()
image_embeds = visual_outputs['output']
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
return image_embeds, image_atts
def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids, input_lengths):
        # Assemble fake prompt ids that index into the image-embedding prompt table
if hasattr(self, 'num_frames') and (visual_features.shape[1] == self.num_frames):
visual_features = visual_features.view(visual_features.shape[0], -1, visual_features.shape[-1])
fake_prompt_id = torch.arange(
self.model_config.vocab_size,
self.model_config.vocab_size + visual_features.shape[0] * visual_features.shape[1],
)
fake_prompt_id = fake_prompt_id.reshape(visual_features.shape[0], visual_features.shape[1])
if post_input_ids is not None:
input_ids = [pre_input_ids, fake_prompt_id, post_input_ids]
else:
input_ids = [fake_prompt_id, pre_input_ids]
input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)
ptuning_args = self.ptuning_setup(visual_features, input_ids, input_lengths)
return input_ids, ptuning_args
def setup_fake_prompts_vila(self, batch_size, visual_features, split_input_ids, input_lengths):
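        """Interleave the split text segments with fake prompt ids (starting at
        vocab_size) that index into the visual prompt table, then build the
        p-tuning arguments."""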
if self.model_type == 'lita' or self.model_type == 'vita':
squeeze_img_tokens = visual_features[0].squeeze(0)
reshape_img_tokens = [t.unsqueeze(0) for t in squeeze_img_tokens]
visual_features = reshape_img_tokens + [visual_features[1]]
fake_prompt_counter = self.model_config.vocab_size
if batch_size == 1:
# only check for multi-image inference (mode 1)
            assert len(visual_features) <= len(
                split_input_ids
            ), "Unexpected number of visual features. Check that the number of <image> tags in the prompt matches the number of image files."
input_ids = []
if batch_size == 1:
input_ids = [split_input_ids[0]]
if self.model_type == 'vila':
                # mode 1: multiple images as a whole; concatenate all prompts: <pre><image1><inter><image2>...<post>
for idx, visual_feature in enumerate(visual_features):
fake_prompt_id = torch.arange(fake_prompt_counter, fake_prompt_counter + visual_feature.shape[0])
fake_prompt_counter += visual_feature.shape[0]
fake_prompt_id = fake_prompt_id.unsqueeze(0)
input_ids.append(fake_prompt_id)
# in case no post prompt
if len(split_input_ids) > idx + 1:
input_ids.append(split_input_ids[idx + 1])
elif self.model_type == 'lita' or self.model_type == 'vita':
for idx, visual_f in enumerate(visual_features):
fake_prompt_id = torch.arange(fake_prompt_counter, fake_prompt_counter + visual_f.shape[1])
fake_prompt_id = fake_prompt_id.reshape(visual_f.shape[1])
fake_prompt_counter += visual_f.shape[1]
fake_prompt_id = fake_prompt_id.unsqueeze(0)
input_ids.append(fake_prompt_id)
# in case no post prompt
if len(split_input_ids) > idx + 1:
input_ids.append(split_input_ids[idx + 1])
elif batch_size > 1 and self.model_type == 'vila':
            # mode 2: each image has its own prompt: <pre><image><post>
for idx, visual_feature in enumerate(visual_features):
input_ids.append(split_input_ids[0])
fake_prompt_id = torch.arange(fake_prompt_counter, fake_prompt_counter + visual_feature.shape[0])
fake_prompt_counter += visual_feature.shape[0]
fake_prompt_id = fake_prompt_id.unsqueeze(0)
input_ids.append(fake_prompt_id)
if len(split_input_ids) > 1:
input_ids.append(split_input_ids[1])
input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)
input_ids = input_ids.reshape(batch_size, -1)
ptuning_args = self.ptuning_setup(visual_features, input_ids, input_lengths)
return input_ids, ptuning_args
def preprocess_lita_visual(self, visual_features, config):
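        """Pool LITA/VITA visual features of shape [b, t, s, d] into image and
        video tokens; returns (image_tokens, video_tokens, num_sampled_frames)."""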
b, t, s, d = visual_features.shape
num_frames = t
if (
'visual_token_format' in config['mm_cfg']['lita']
and config['mm_cfg']['lita']['visual_token_format'] == 'im_vid_start_end'
):
num_image_frames = min(num_frames, config['mm_cfg']['lita']['sample_frames'])
idx = np.round(np.linspace(0, num_frames - 1, num_image_frames)).astype(int)
# Image and video features
im_features = visual_features[:, idx, ...]
vid_features = einops.reduce(visual_features, 'b t s d -> b t d', 'mean')
return im_features, vid_features, num_image_frames
elif (
'lita_video_arch' in config['mm_cfg']['lita']
and config['mm_cfg']['lita']['lita_video_arch'] == 'temporal_spatial_pool'
):
pool_size = 2
selected_frames = np.round(np.linspace(0, visual_features.shape[1] - 1, pool_size * pool_size)).astype(int)
s_tokens = visual_features[:, selected_frames, ...]
s_tokens = einops.rearrange(s_tokens, 'b t (h w) d -> (b t) d h w', h=16, w=16)
s_tokens = F.avg_pool2d(s_tokens, kernel_size=pool_size)
s_tokens = einops.rearrange(s_tokens, '(b t) d h w -> b (t h w) d', b=b)
t_tokens = einops.reduce(visual_features, 'b t s d -> b t d', 'mean')
return t_tokens, s_tokens, pool_size**2
else:
raise ValueError(f'Invalid visual token format: {config["mm_cfg"]["lita"]["visual_token_format"]}')
def ptuning_setup(self, prompt_table, input_ids, input_lengths):
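        """Assemble the TensorRT-LLM prompt-tuning inputs.

        Returns [prompt_table, tasks, task_vocab_size], where the flattened
        prompt table holds the encoder features referenced by the fake ids.
        """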
hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size
if self.model_type == 'lita' or self.model_type == 'vita':
prompt_table = torch.cat(prompt_table, dim=1)
if prompt_table is not None:
task_vocab_size = torch.tensor(
[prompt_table.shape[1]],
dtype=torch.int32,
).cuda()
prompt_table = prompt_table.view((prompt_table.shape[0] * prompt_table.shape[1], prompt_table.shape[2]))
assert prompt_table.shape[1] == hidden_size, "Prompt table dimensions do not match hidden size"
prompt_table = prompt_table.cuda().to(
dtype=tensorrt_llm._utils.str_dtype_to_torch(self.model_config.dtype)
)
else:
prompt_table = torch.empty([1, hidden_size]).cuda()
task_vocab_size = torch.zeros([1]).cuda()
if self.model_config.remove_input_padding:
tasks = torch.zeros([torch.sum(input_lengths)], dtype=torch.int32).cuda()
else:
tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda()
return [prompt_table, tasks, task_vocab_size]
def expand2square_pt(self, images, background_color):
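        """Pad a batch of CHW image tensors to a square canvas filled with
        background_color, centering the original content."""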
height, width = images.shape[-2:]
b = len(images)
background_color = torch.Tensor(background_color)
if width == height:
return images
elif width > height:
result = einops.repeat(background_color, 'c -> b c h w', b=b, h=width, w=width).clone()
paste_start = (width - height) // 2
paste_end = paste_start + height
result[:, :, paste_start:paste_end, :] = images
return result
else:
result = einops.repeat(background_color, 'c -> b c h w', b=b, h=height, w=height).clone()
paste_start = (height - width) // 2
paste_end = paste_start + width
result[:, :, :, paste_start:paste_end] = images
return result
def load_video(self, config, video_path, processor, num_frames=None):
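        """Decode frames with decord (for a path) or wrap an ndarray, optionally
        sampling num_frames evenly, then run preprocess_frames."""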
frames = None
if isinstance(video_path, str):
decord.bridge.set_bridge('torch')
video_reader = decord.VideoReader(uri=video_path)
if num_frames is not None:
idx = np.round(np.linspace(0, len(video_reader) - 1, num_frames)).astype(int)
frames = video_reader.get_batch(idx)
else:
frames = torch.cat([torch.tensor(f.asnumpy()) for f in video_reader])
elif isinstance(video_path, np.ndarray):
frames = torch.tensor(video_path, dtype=torch.float32)
return self.preprocess_frames(frames, config, processor)
def preprocess_frames(self, frames, config, processor):
frames = einops.rearrange(frames, 't h w c -> t c h w')
if config['data']['image_aspect_ratio'] == 'pad':
frames = self.expand2square_pt(frames, tuple(int(x * 255) for x in processor.image_mean))
processed_frames = processor.preprocess(frames, return_tensors='pt')['pixel_values']
return processed_frames
def get_num_sample_frames(self, config, vid_len):
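        """Pick the number of frames to sample: for 'im_vid_start_end', cap at
        data.num_frames by subsampling; otherwise use the configured
        sample_frames."""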
if (
'visual_token_format' in config['mm_cfg']['lita']
and config['mm_cfg']['lita']['visual_token_format'] == 'im_vid_start_end'
):
max_frames = config['data']['num_frames']
if vid_len <= max_frames:
return vid_len
else:
subsample = int(np.ceil(float(vid_len) / max_frames))
return int(np.round(float(vid_len) / subsample))
else:
return config['mm_cfg']['lita']['sample_frames']
def process_lita_video(self, nemo_config, video_path, image_processor):
image = None
if isinstance(video_path, str):
vid_len = len(decord.VideoReader(video_path))
num_sample_frames = self.get_num_sample_frames(nemo_config, vid_len)
image = (
self.load_video(nemo_config, video_path, image_processor, num_sample_frames)
.unsqueeze(0)
.to(self.device, dtype=torch.bfloat16)
)
elif isinstance(video_path, np.ndarray):
image = (
self.load_video(nemo_config, video_path, image_processor)
.unsqueeze(0)
.to(self.device, dtype=torch.bfloat16)
)
return image
def process_image(self, image_file, image_processor, nemo_config, image_folder):
if isinstance(image_file, str):
if image_folder is not None:
image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
else:
image = Image.open(image_file).convert("RGB")
else:
            # the image was passed in directly (e.g. a PIL image) rather than as a path
image = image_file
crop_size = nemo_config['mm_cfg']['vision_encoder']['crop_size']
crop_size = tuple(crop_size)
image = image.resize(crop_size)
if nemo_config['data']['image_aspect_ratio'] == 'pad':
image = self.expand2square_pt(image, tuple(int(x * 255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
else:
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
return image
def process_vila_img(self, images):
new_images = [self.process_image(image, self.image_processor, self.nemo_config, None) for image in images]
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
def setup_inputs(self, input_text, raw_image, batch_size):
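        """Build the model-specific prompt template, preprocess the raw media,
        and replicate prompts and images to batch_size."""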
attention_mask = None
image = None
if self.model_type == "neva":
image_size = self.image_size
dtype = torch.float32
transform = transforms.Compose(
[
transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
image = transform(raw_image).to(dtype).unsqueeze(0)
if input_text is None:
input_text = "Hi! What is in this image?"
pre_prompt = "<extra_id_0>System\n\n<extra_id_1>User\n"
post_prompt = f"\n{input_text}\n<extra_id_1>Assistant\n"
elif self.model_type == "video-neva":
image = self.video_preprocess(raw_image) # shape (1, num_frames, 3, H, W)
if input_text is None:
input_text = "Hi! What is in this video?"
# SteerLM prompt template
pre_prompt = (
"<extra_id_0>System\nA chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n"
"<extra_id_1>User"
)
post_prompt = (
f"\n{input_text}\n<extra_id_1>Assistant\n"
"<extra_id_2>quality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,"
"correctness:4,coherence:4,complexity:4,verbosity:4\n"
)
elif self.model_type in ['vila', 'lita', 'vita']:
if self.model_type == "vila" or self.model_type == "lita":
pre_prompt = (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "
)
if input_text is None:
input_text = "<image>\n Please elaborate what you see in the images?"
post_prompt = input_text + " ASSISTANT:"
elif self.model_type == "vita":
# llama3 prompt template
pre_prompt = (
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
"You are a helpful language and vision assistant. "
"You are able to understand the visual content that the user provides, "
"and assist the user with a variety of tasks using natural language. "
"<|start_header_id|>user<|end_header_id|>\n\n"
)
if input_text is None:
input_text = "<image>\n Please elaborate what you see in the images?"
post_prompt = input_text + "<|start_header_id|>assistant<|end_header_id|>\n\n"
else:
raise RuntimeError(f"Invalid model type {self.model_type}")
if self.model_type == 'lita' or self.model_type == 'vita':
image = self.process_lita_video(self.nemo_config, raw_image, self.image_processor)
if self.model_type == 'vila':
raw_image = [raw_image] * batch_size
image = self.process_vila_img(raw_image)
# Repeat inputs to match batch size
pre_prompt = [pre_prompt] * batch_size
post_prompt = [post_prompt] * batch_size
if self.model_type not in ['vila', 'lita', 'vita']:
if image.dim() == 5:
image = image.expand(batch_size, -1, -1, -1, -1).contiguous()
else:
image = image.expand(batch_size, -1, -1, -1).contiguous()
image = image.to(self.device)
decoder_input_ids = None
return input_text, pre_prompt, post_prompt, image, decoder_input_ids, attention_mask
def run(
self,
input_text,
input_image,
max_new_tokens,
batch_size,
top_k,
top_p,
temperature,
repetition_penalty,
num_beams,
lora_uids=None,
run_profiling=False,
check_accuracy=False,
):
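        """Warm up once, then run generation (profiling_iterations times when
        run_profiling is set) and print results on rank 0."""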
input_text, pre_prompt, post_prompt, processed_image, decoder_input_ids, attention_mask = self.setup_inputs(
input_text, input_image, batch_size
)
self.generate(
pre_prompt,
post_prompt,
processed_image,
decoder_input_ids,
max_new_tokens,
attention_mask=attention_mask,
warmup=True,
batch_size=batch_size,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
num_beams=num_beams,
lora_uids=lora_uids,
)
num_iters = self.profiling_iterations if run_profiling else 1
for _ in range(num_iters):
output_text = self.generate(
pre_prompt,
post_prompt,
processed_image,
decoder_input_ids,
max_new_tokens,
attention_mask=attention_mask,
warmup=False,
batch_size=batch_size,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
num_beams=num_beams,
lora_uids=lora_uids,
)
if self.runtime_rank == 0:
self.print_result(input_text, output_text, batch_size, num_beams, run_profiling, check_accuracy)
return output_text
def print_result(self, input_text, output_text, batch_size, num_beams, run_profiling, check_accuracy):
if not run_profiling and not check_accuracy:
return
logger.info("---------------------------------------------------------")
if self.model_type != 'nougat':
logger.info(f"\n[Q] {input_text}")
logger.info(f"\n[A] {output_text[0]}")
if num_beams == 1:
output_ids = self.tokenizer(output_text[0][0], add_special_tokens=False)['input_ids']
logger.info(f"Generated {len(output_ids)} tokens")
if check_accuracy:
for i in range(batch_size - 1):
                if output_text[i] != output_text[i + 1]:
logger.info(f"Output {i} and {i + 1} do not match")
assert False
assert 'robot' in output_text[0][0].lower()
if run_profiling:
msec_per_batch = lambda name: 1000 * profiler.elapsed_time_in_sec(name) / self.profiling_iterations
logger.info('Latencies per batch (msec)')
logger.info(f'TRT {self.modality} encoder: %.1f' % (msec_per_batch(self.modality.capitalize())))
logger.info('TRTLLM LLM generate: %.1f' % (msec_per_batch('LLM')))
logger.info('Multimodal generate: %.1f' % (msec_per_batch('Generate')))
logger.info("---------------------------------------------------------")
def load_test_media(self, input_media):
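        """Load test media for this model type: open an image for NeVA/VILA,
        pass the input through unchanged for video-based models."""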
media_model = ["video-neva", "lita", "vita"]
if self.model_type in media_model:
media = input_media
elif self.model_type == "neva" or self.model_type == "vila":
media = Image.open(input_media).convert('RGB')
else:
raise RuntimeError(f"Invalid model type {self.model_type}")
return media
class SpeechllmModelRunner(MultimodalModelRunner):
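    """Runner for SALM speech-LLM models: an audio perception TensorRT engine
    plus a TorchScript feature extractor feeding a TensorRT-LLM decoder.
    """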
def __init__(self, perception_engine_dir, llm_engine_dir, modality):
"""
perception_engine_dir: path to the perception engine directory
it should contain:
config.json nemo_config.yaml
perception_encoder.engine : tensorrt engine
feature_extractor.ts : torchscript model
llm_engine_dir: path to the LLM engine directory
"""
super().__init__(perception_engine_dir, llm_engine_dir, modality)
assert self.model_type == 'salm'
# init preprocessor
feature_extractor_path = os.path.join(perception_engine_dir, 'feature_extractor.ts')
self.feature_extractor = self.init_speech_preprocessor(feature_extractor_path)
self.init_modality_encoder(perception_engine_dir)
def init_modality_encoder(self, engine_dir):
"""
Initialize the modality encoder session from the prebuilt engine directory
Args:
engine_dir: str, path to the engine directory
"""
# find file with .engine extension
engine_file = None
for file in os.listdir(engine_dir):
if file.endswith('.engine'):
engine_file = file
break
assert engine_file is not None, f"Engine file not found in {engine_dir}"
encoder_path = os.path.join(engine_dir, engine_file)
logger.info(f'Loading engine from {encoder_path}')
with open(encoder_path, 'rb') as f:
engine_buffer = f.read()
logger.info(f'Creating session from engine {encoder_path}')
self.modality_encoder_session = Session.from_serialized_engine(engine_buffer)
def init_speech_preprocessor(self, feature_extractor_path):
feature_extractor = torch.jit.load(feature_extractor_path)
feature_extractor.eval()
return feature_extractor
def process_audio(self, input_signal, input_signal_length):
"""
Args:
input_signal: audio signal in numpy array
input_signal_length: length of the audio signal in numpy array
Returns:
processed_signal: torch.tensor [B, 80, T]
processed_signal_length [B]
"""
input_signal = torch.tensor(input_signal, dtype=torch.float32)
input_signal_length = torch.tensor(input_signal_length, dtype=torch.int32)
processed_signal, processed_signal_length = self.feature_extractor(input_signal, input_signal_length)
return processed_signal, processed_signal_length
def setup_inputs(self, input_text, input_media, batch_size):
"""
Args:
input_text: str or List[str] or None
input_media: Tuple[np.array, np.array]
input_signal: audio signal in numpy array [b, -1]
input_signal_length: length of the audio signal in numpy array [b]
batch_size: int
"""
input_signal, input_signal_length = input_media
processed_signal, processed_signal_length = self.process_audio(input_signal, input_signal_length)
processed_signal = processed_signal.to(self.device)
processed_signal_length = processed_signal_length.to(self.device)
if input_text is None:
input_text = "Q: what's the transcription of the audio? A:"
if isinstance(input_text, str):
input_text = [input_text] * batch_size
assert len(input_text) == batch_size
pre_prompt = [''] * batch_size
post_prompt = input_text
decoder_input_ids = None
attention_mask = None
return (
input_text,
pre_prompt,
post_prompt,
processed_signal,
processed_signal_length,
decoder_input_ids,
attention_mask,
)
def load_test_media(self, input_media_path):
"""
Args:
input_media_path: str, path to the audio file
Returns:
input_signal: np.array [1, -1]
input_signal_length: np.array [1]
"""
waveform, sample_rate = sf.read(input_media_path, dtype=np.float32)
input_signal = np.array([waveform], dtype=np.float32)
input_signal_length = np.array([len(waveform)], dtype=np.int32)
return input_signal, input_signal_length
def get_modality_encoder_features(self, modality_features, attention_mask):
"""
Do inference on the modality encoder engine
Args:
modality_features: dict {'input1': torch.tensor, 'input2': torch.tensor, ..}
attention_mask: None
Returns:
"""
if attention_mask is not None:
modality_features['attention_mask'] = attention_mask
tensor_info = []
for key, tensor in modality_features.items():
tensor_info.append(TensorInfo(key, torch_dtype_to_trt(tensor.dtype), tensor.shape))
output_info = self.modality_encoder_session.infer_shapes(tensor_info)
outputs = {
t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device=self.device)
for t in output_info
}
ok = self.modality_encoder_session.run(modality_features, outputs, self.stream.cuda_stream)
        assert ok, "Runtime execution failed for modality encoder session"
self.stream.synchronize()
return outputs
def preprocess(self, warmup, pre_prompt, post_prompt, processed_features, attention_mask, batch_size):
"""
Args:
warmup: bool
pre_prompt: List[str]
post_prompt: List[str]
processed_features: Tuple[torch.tensor, torch.tensor]
processed_signal: torch.tensor [B, 80, T]
processed_signal_length: torch.tensor [B]
attention_mask: None
batch_size: int
Returns:
input_ids: torch.tensor [B, L]
input_lengths: torch.tensor [B]
ptuning_args: List[torch.tensor]
encoded_features: torch.tensor [B, L, D]
"""
if not warmup:
profiler.start(self.modality.capitalize())
if not warmup:
profiler.stop(self.modality.capitalize())
assert self.model_type == 'salm', f"Invalid model type {self.model_type}"
processed_features = {
"processed_signal": processed_features[0],
"processed_signal_length": processed_features[1].to(torch.int32),
}
encoded_outputs = self.get_modality_encoder_features(processed_features, attention_mask)
encoded_features, encoded_length = encoded_outputs['encoded'], encoded_outputs['encoded_length']
pre_input_ids = self.tokenizer(pre_prompt).input_ids
post_input_ids = self.tokenizer(post_prompt).input_ids
input_lengths = []
input_ids = []
encoded_length = encoded_length.cpu().numpy()
fake_id_start = self.model.vocab_size
for i in range(batch_size):
feat_len = encoded_length[i]
feat_fake_ids = np.arange(fake_id_start, fake_id_start + feat_len)
cur_input_ids = np.concatenate([pre_input_ids[i], feat_fake_ids, post_input_ids[i]])
fake_id_start += feat_len
input_lengths.append(len(cur_input_ids))
input_ids.append(cur_input_ids)
max_length = max(input_lengths)
# convert input_ids to torch tensor with padding
input_ids = [
np.pad(ids, (0, max_length - len(ids)), 'constant', constant_values=self.tokenizer.pad_token_id)
for ids in input_ids
]
input_ids = torch.tensor(input_ids, dtype=torch.int32)
input_lengths = torch.tensor(input_lengths, dtype=torch.int32)
ptuning_args = self.ptuning_setup(encoded_features, input_ids, input_lengths)
return input_ids, input_lengths, ptuning_args, encoded_features
def run(
self,
input_text,
input_media=None,
max_new_tokens: int = 30,
batch_size: int = 1,
top_k: int = 1,
top_p: float = 0.0,
temperature: float = 1.0,
repetition_penalty: float = 1.0,
num_beams: int = 1,
run_profiling=False,
check_accuracy=False,
input_signal=None,
input_signal_length=None,
lora_uids=None,
):
"""
Args:
input_text: str or List[str] or None
input_media: Tuple[np.array, np.array] or None
input_signal: audio signal in numpy array [b, -1]
input_signal_length: length of the audio signal in numpy array [b]
max_new_tokens: int
batch_size: int
top_k: int
top_p: float
temperature: float
repetition_penalty: float
num_beams: int
run_profiling: bool
check_accuracy: bool
"""
if input_media is None:
assert input_signal is not None and input_signal_length is not None
input_media = (input_signal, input_signal_length)
(
input_text,
pre_prompt,
post_prompt,
processed_signal,
processed_signal_length,
decoder_input_ids,
attention_mask,
) = self.setup_inputs(input_text, input_media, batch_size)
processed_media = (processed_signal, processed_signal_length)
self.generate(
pre_prompt,
post_prompt,
processed_media,
decoder_input_ids,
max_new_tokens,
attention_mask=attention_mask,
warmup=True,
batch_size=batch_size,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
num_beams=num_beams,
)
num_iters = self.profiling_iterations if run_profiling else 1
for _ in range(num_iters):
output_text = self.generate(
pre_prompt,
post_prompt,
processed_media,
decoder_input_ids,
max_new_tokens,
attention_mask=attention_mask,
warmup=False,
batch_size=batch_size,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
num_beams=num_beams,
)
if self.runtime_rank == 0:
self.print_result(input_text, output_text, batch_size, num_beams, run_profiling, check_accuracy)
return output_text
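# Example usage (a minimal sketch; the engine directories and media path below
# are placeholders, not files shipped with this module):
#
#     runner = MultimodalModelRunner(
#         visual_engine_dir="/path/to/visual_engine_dir",  # config.json + visual_encoder.engine
#         llm_engine_dir="/path/to/llm_engine_dir",
#         modality="vision",
#     )
#     media = runner.load_test_media("/path/to/test_image.jpg")
#     output = runner.run(
#         input_text=None,  # None falls back to the model-specific default question
#         input_image=media,
#         max_new_tokens=30,
#         batch_size=1,
#         top_k=1,
#         top_p=0.0,
#         temperature=1.0,
#         repetition_penalty=1.0,
#         num_beams=1,
#     )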