| import json |
| import os |
|
|
| try: |
| import decord |
| except Exception: |
| import logging |
|
|
| logging.warning("The package `decord` was not installed in this environment.") |
|
|
| import einops |
| import numpy as np |
| import soundfile as sf |
| import tensorrt as trt |
| import tensorrt_llm |
| import tensorrt_llm.profiler as profiler |
| import torch |
| import yaml |
| from PIL import Image |
| from tensorrt_llm import logger |
| from tensorrt_llm._utils import str_dtype_to_trt, torch_dtype_to_trt |
| from tensorrt_llm.runtime import ModelRunner, Session, TensorInfo |
| from torch.nn import functional as F |
| from torchvision import transforms |
| from transformers import AutoProcessor, CLIPImageProcessor |
|
|
| from nemo.export.utils.constants import TRTLLM_ENGINE_DIR |
|
|
|
|
| def trt_dtype_to_torch(dtype): |
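| """Map a TensorRT dtype to the corresponding torch dtype (fp16, fp32, int32, bf16).""" |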
| if dtype == trt.float16: |
| return torch.float16 |
| elif dtype == trt.float32: |
| return torch.float32 |
| elif dtype == trt.int32: |
| return torch.int32 |
| elif dtype == trt.bfloat16: |
| return torch.bfloat16 |
| else: |
| raise TypeError("%s is not supported" % dtype) |
|
|
|
|
| class MultimodalModelRunner: |
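| """Pair a TensorRT visual/modality encoder engine with a TensorRT-LLM engine for multimodal generation. |
| The model type (neva, video-neva, lita, vila, vita) is read from the visual engine's config.json. |
| """ |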
|
|
| def __init__(self, visual_engine_dir, llm_engine_dir, modality='vision'): |
| self.modality = modality |
| self.runtime_rank = tensorrt_llm.mpi_rank() |
| device_id = self.runtime_rank % torch.cuda.device_count() |
| torch.cuda.set_device(device_id) |
| self.device = "cuda:%d" % (device_id) |
|
|
| self.stream = torch.cuda.Stream(torch.cuda.current_device()) |
| torch.cuda.set_stream(self.stream) |
|
|
| |
| with open(os.path.join(visual_engine_dir, "config.json"), "r") as f: |
| config = json.load(f) |
| self.model_type = config['builder_config']['model_type'] |
| self.vision_precision = config['builder_config']['precision'] |
| self.modality_precision = config['builder_config']['precision'] |
|
|
| self.num_frames = config['builder_config'].get('num_frames', None) |
| self.image_size = config['builder_config'].get('image_size', None) |
|
|
| self.profiling_iterations = 20 |
|
|
| if modality == 'vision': |
| self.init_image_encoder(visual_engine_dir) |
| self.init_tokenizer(llm_engine_dir) |
| self.init_llm(os.path.join(llm_engine_dir, TRTLLM_ENGINE_DIR)) |
| if self.model_type == 'lita' or self.model_type == 'vila' or self.model_type == 'vita': |
| self.init_vision_preprocessor(visual_engine_dir) |
|
|
| def init_tokenizer(self, llm_engine_dir): |
| if os.path.exists(os.path.join(llm_engine_dir, "tokenizer_config.json")): |
| from transformers import AutoTokenizer |
|
|
| self.tokenizer = AutoTokenizer.from_pretrained(llm_engine_dir) |
| self.tokenizer.pad_token = self.tokenizer.eos_token |
| if self.model_type == 'vita': |
| self.tokenizer.im_start_id = self.tokenizer.convert_tokens_to_ids("<extra_id_4>") |
| self.tokenizer.im_end_id = self.tokenizer.convert_tokens_to_ids("<extra_id_5>") |
| self.tokenizer.vid_start_id = self.tokenizer.convert_tokens_to_ids("<extra_id_8>") |
| self.tokenizer.vid_end_id = self.tokenizer.convert_tokens_to_ids("<extra_id_9>") |
| else: |
| from sentencepiece import SentencePieceProcessor |
|
|
| sp = SentencePieceProcessor(os.path.join(llm_engine_dir, 'tokenizer.model')) |
|
|
| class return_obj: |
|
|
| def __init__(self, input_ids): |
| self.input_ids = input_ids |
|
|
| def __getitem__(self, name): |
| if name == "input_ids": |
| return self.input_ids |
| else: |
| raise AttributeError(f"'return_obj' has no item '{name}'") |
|
|
| |
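| # Minimal wrapper exposing the subset of the Hugging Face tokenizer API used by this runner (encode, __call__, decode, batch_decode), backed by SentencePiece. |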
| class HFTokenizerInterface: |
|
|
| def encode(self, x, return_tensors=None, **kwargs): |
| out = sp.encode(x) |
| if return_tensors == "pt": |
| out = torch.tensor(out) |
| return return_obj(out) |
|
|
| def __call__(self, x, return_tensors=None, **kwargs): |
| return self.encode(x, return_tensors, **kwargs) |
|
|
| def decode(self, x, **kwargs): |
| return sp.decode(x.tolist()) |
|
|
| def batch_decode(self, x, **kwargs): |
| return self.decode(x, **kwargs) |
|
|
| self.tokenizer = HFTokenizerInterface() |
| self.tokenizer.eos_token_id = sp.eos_id() |
| self.tokenizer.bos_token_id = sp.bos_id() |
| self.tokenizer.pad_token_id = sp.pad_id() |
|
|
| self.tokenizer.padding_side = "right" |
|
|
| if self.model_type == 'lita': |
| self.tokenizer.im_start_id = sp.piece_to_id("<extra_id_4>") |
| self.tokenizer.im_end_id = sp.piece_to_id("<extra_id_5>") |
| self.tokenizer.vid_start_id = sp.piece_to_id("<extra_id_8>") |
| self.tokenizer.vid_end_id = sp.piece_to_id("<extra_id_9>") |
|
|
| def init_image_encoder(self, visual_engine_dir): |
| vision_encoder_path = os.path.join(visual_engine_dir, 'visual_encoder.engine') |
| logger.info(f'Loading engine from {vision_encoder_path}') |
| with open(vision_encoder_path, 'rb') as f: |
| engine_buffer = f.read() |
| logger.info(f'Creating session from engine {vision_encoder_path}') |
| self.visual_encoder_session = Session.from_serialized_engine(engine_buffer) |
|
|
| def init_vision_preprocessor(self, visual_encoder_dir): |
| with open(os.path.join(visual_encoder_dir, 'nemo_config.yaml'), 'r') as f: |
| self.nemo_config = yaml.safe_load(f) |
|
|
| vision_config = self.nemo_config["mm_cfg"]["vision_encoder"] |
|
|
| if self.model_type == 'lita': |
| self.image_processor = AutoProcessor.from_pretrained( |
| vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True |
| ) |
| elif self.model_type == 'vila' or self.model_type == 'vita': |
| from transformers import SiglipImageProcessor |
|
|
| self.image_processor = SiglipImageProcessor.from_pretrained( |
| vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True |
| ) |
| else: |
| raise ValueError(f"Invalid model type: {self.model_type}") |
|
|
| def init_llm(self, llm_engine_dir): |
| self.model = ModelRunner.from_dir( |
| llm_engine_dir, |
| rank=tensorrt_llm.mpi_rank(), |
| debug_mode=False, |
| stream=self.stream, |
| ) |
| self.model_config = self.model.session._model_config |
| self.runtime_mapping = self.model.session.mapping |
|
|
| def video_preprocess(self, video_path): |
| from decord import VideoReader |
|
|
| if isinstance(video_path, str): |
| vr = VideoReader(video_path) |
| num_frames = self.num_frames |
| if num_frames == -1: |
| frames = [Image.fromarray(frame.asnumpy()).convert('RGB') for frame in vr] |
| else: |
| |
| |
| num_frames = min(num_frames, len(vr)) |
| indices = np.linspace(0, len(vr) - 1, num=num_frames, dtype=int) |
| frames = [Image.fromarray(vr[idx].asnumpy()).convert('RGB') for idx in indices] |
| if len(frames) < num_frames: |
| frames += [frames[-1]] * (num_frames - len(frames)) |
| elif isinstance(video_path, np.ndarray): |
| num_frames = self.num_frames |
| if num_frames == -1: |
| frames = [Image.fromarray(frame).convert('RGB') for frame in video_path] |
| else: |
| |
| |
| num_frames = min(num_frames, video_path.shape[0]) |
| indices = np.linspace(0, video_path.shape[0] - 1, num=num_frames, dtype=int) |
| frames = [Image.fromarray(video_path[idx]).convert('RGB') for idx in indices] |
| if len(frames) < num_frames: |
| frames += [frames[-1]] * (num_frames - len(frames)) |
| else: |
| frames = video_path |
|
|
| processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16) |
| frames = processor.preprocess(frames, return_tensors="pt")['pixel_values'] |
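| # CLIP preprocessing yields [num_frames, 3, 224, 224]; cast to the engine precision and add a batch dimension. |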
| |
| media_tensors = frames.to( |
| tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision) |
| ) |
| return media_tensors.unsqueeze(0) |
|
|
| def insert_tokens_by_index(self, input_ids, num_frames): |
| im_start_id = self.tokenizer.im_start_id |
| im_end_id = self.tokenizer.im_end_id |
| vid_start_id = self.tokenizer.vid_start_id |
| vid_end_id = self.tokenizer.vid_end_id |
|
|
| image_token_indices = (input_ids == 0).nonzero(as_tuple=False).squeeze().tolist() |
| input_ids = input_ids.squeeze().tolist() |
| offset = 0 |
|
|
| |
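| # Wrap each sampled frame's placeholder token (id 0) in <im_start> ... <im_end>, then append a <vid_start> 0 <vid_end> block and drop the original placeholder. |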
| for i in range(num_frames): |
| idx = image_token_indices[1] + offset |
| input_ids.insert(idx + 1, im_end_id) |
| input_ids.insert(idx + 1, 0) |
| input_ids.insert(idx + 1, im_start_id) |
| offset += 3 |
|
|
| |
| vid_idx = image_token_indices[1] + offset |
| input_ids.insert(vid_idx + 1, vid_end_id) |
| input_ids.insert(vid_idx + 1, 0) |
| input_ids.insert(vid_idx + 1, vid_start_id) |
|
|
| input_ids.pop(image_token_indices[1]) |
| input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) |
|
|
| return input_ids |
|
|
| def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, batch_size): |
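| """Build input_ids, input_lengths and prompt-tuning args for the configured model type, splicing the visual features into the prompt as fake token ids.""" |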
| if not warmup: |
| profiler.start(self.modality.capitalize()) |
|
|
| if not warmup: |
| profiler.stop(self.modality.capitalize()) |
|
|
| if self.model_type == 'vila': |
| visual_features, visual_atts = self.get_visual_features(image, attention_mask) |
| input_ids = self.tokenizer_image_token(batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer) |
| batch_split_prompts = self.split_prompt_by_images(input_ids) |
| first_batch_split_prompts = batch_split_prompts[0] |
| |
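| # Total sequence length = text tokens from all prompt chunks plus one slot per visual embedding. |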
| length = sum([ids.shape[1] for ids in first_batch_split_prompts]) |
| if batch_size == 1 and len(image) > 1: |
| |
| length += visual_atts.shape[0] * visual_atts.shape[1] |
| else: |
| |
| length += visual_atts.shape[1] |
|
|
| input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32) |
| input_ids, ptuning_args = self.setup_fake_prompts_vila( |
| batch_size, visual_features, first_batch_split_prompts, input_lengths |
| ) |
| return input_ids, input_lengths, ptuning_args, visual_features |
|
|
| elif self.model_type == 'lita' or self.model_type == 'vita': |
| visual_input = [] |
| for i, img in enumerate(image): |
| visual_features, visual_atts = self.get_visual_features(img, attention_mask) |
| visual_features = visual_features.unsqueeze(0) |
| im_tokens, vid_tokens, num_sample_frames = self.preprocess_lita_visual(visual_features, self.nemo_config) |
| visual_input.extend([im_tokens, vid_tokens]) |
|
|
| input_ids = self.tokenizer_image_token(batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer) |
| input_ids = self.insert_tokens_by_index(input_ids, num_sample_frames) |
| batch_splits = self.split_prompt_by_images(input_ids) |
| first_batch_split_prompts = batch_splits[0] |
| length = sum([ids.shape[1] for ids in first_batch_split_prompts]) |
|
|
| |
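| # Flatten the per-frame image tokens and concatenate them with the pooled video tokens along the sequence dimension. |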
| im_tokens = im_tokens.view(1, -1, im_tokens.shape[-1]) |
| visual_features = torch.cat([im_tokens, vid_tokens], dim=1) |
| visual_atts = torch.ones(visual_features.size()[:-1], dtype=torch.long).to(image.device) |
|
|
| if batch_size == 1: |
| length += visual_atts.shape[0] * visual_atts.shape[1] |
| else: |
| raise ValueError("Batch size greater than 1 is not supported for LITA and VITA models") |
|
|
| input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32) |
| input_ids, ptuning_args = self.setup_fake_prompts_vila( |
| batch_size, visual_input, first_batch_split_prompts, input_lengths |
| ) |
| return input_ids, input_lengths, ptuning_args, visual_features |
| else: |
| visual_features, visual_atts = self.get_visual_features(image, attention_mask) |
| pre_input_ids = self.tokenizer(pre_prompt, return_tensors="pt", padding=True).input_ids |
| if post_prompt[0] is not None: |
| post_input_ids = self.tokenizer(post_prompt, return_tensors="pt", padding=True).input_ids |
| if self.model_type == 'video-neva': |
| length = ( |
| pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[2] * visual_atts.shape[1] |
| ) |
| else: |
| length = pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[1] |
| else: |
| post_input_ids = None |
| length = pre_input_ids.shape[1] + visual_atts.shape[1] |
|
|
| input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32) |
|
|
| input_ids, ptuning_args = self.setup_fake_prompts( |
| visual_features, pre_input_ids, post_input_ids, input_lengths |
| ) |
|
|
| return input_ids, input_lengths, ptuning_args, visual_features |
|
|
| @staticmethod |
| def tokenizer_image_token(batch_size, prompt, tokenizer, image_token_index=-200): |
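| """Tokenize a prompt containing <image> placeholders; each placeholder becomes a single token id 0 in the returned [batch_size, seq_len] tensor.""" |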
| prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")] |
|
|
| def insert_separator(X, sep): |
| return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] |
|
|
| input_ids = [] |
| offset = 0 |
| if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: |
| offset = 1 |
| input_ids.append(prompt_chunks[0][0]) |
|
|
| for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): |
| input_ids.extend(x[offset:]) |
|
|
| input_ids = torch.tensor(input_ids, dtype=torch.long) |
| input_ids[input_ids == image_token_index] = 0 |
| input_ids = input_ids.unsqueeze(0).expand(batch_size, -1) |
|
|
| return input_ids |
|
|
| def split_prompt_by_images(self, tensor): |
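| """Split each sequence in the batch at the image placeholder positions (token id 0) and return the text chunks between them.""" |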
| batch_splits = [] |
| for batch in tensor: |
| |
| zero_indices = (batch == 0).nonzero(as_tuple=False).squeeze(0) |
| |
| start_idx = 0 |
| splits = [] |
| for idx in zero_indices: |
| if start_idx != idx: |
| splits.append(batch[start_idx:idx].unsqueeze(0)) |
| start_idx = idx + 1 |
| if start_idx < len(batch): |
| splits.append(batch[start_idx:].unsqueeze(0)) |
| |
| splits = [split for split in splits if split.numel() > 0] |
| batch_splits.append(splits) |
|
|
| return batch_splits |
|
|
| def generate( |
| self, |
| pre_prompt, |
| post_prompt, |
| image, |
| decoder_input_ids, |
| max_new_tokens, |
| attention_mask, |
| warmup, |
| batch_size, |
| top_k, |
| top_p, |
| temperature, |
| repetition_penalty, |
| num_beams, |
| lora_uids=None, |
| ): |
| if not warmup: |
| profiler.start("Generate") |
|
|
| input_ids, input_lengths, ptuning_args, visual_features = self.preprocess( |
| warmup, pre_prompt, post_prompt, image, attention_mask, batch_size |
| ) |
|
|
| if warmup: |
| return None |
|
|
| profiler.start("LLM") |
| end_id = self.tokenizer.eos_token_id |
|
|
| ptuning_args[0] = torch.stack([ptuning_args[0]]) |
| output_ids = self.model.generate( |
| input_ids, |
| sampling_config=None, |
| prompt_table=ptuning_args[0], |
| max_new_tokens=max_new_tokens, |
| end_id=end_id, |
| pad_id=( |
| self.tokenizer.pad_token_id |
| if self.tokenizer.pad_token_id is not None |
| else self.tokenizer.all_special_ids[0] |
| ), |
| top_k=top_k, |
| top_p=top_p, |
| temperature=temperature, |
| repetition_penalty=repetition_penalty, |
| num_beams=num_beams, |
| output_sequence_lengths=False, |
| lora_uids=lora_uids, |
| return_dict=False, |
| ) |
|
|
| profiler.stop("LLM") |
|
|
| if tensorrt_llm.mpi_rank() == 0: |
| |
| output_beams_list = [ |
| self.tokenizer.batch_decode( |
| output_ids[batch_idx, :, input_lengths[batch_idx] :], skip_special_tokens=True |
| ) |
| for batch_idx in range(batch_size) |
| ] |
|
|
| stripped_text = [ |
| [output_beams_list[batch_idx][beam_idx].strip() for beam_idx in range(num_beams)] |
| for batch_idx in range(batch_size) |
| ] |
| profiler.stop("Generate") |
| return stripped_text |
| else: |
| profiler.stop("Generate") |
| return None |
|
|
| def get_visual_features(self, image, attention_mask): |
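| """Run the visual encoder engine on the preprocessed image tensor and return the embeddings plus an all-ones attention mask.""" |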
| visual_features = {'input': image.to(tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision))} |
| if attention_mask is not None: |
| visual_features['attention_mask'] = attention_mask |
| tensor_info = [TensorInfo('input', str_dtype_to_trt(self.vision_precision), image.shape)] |
| if attention_mask is not None: |
| tensor_info.append(TensorInfo('attention_mask', trt.DataType.INT32, attention_mask.shape)) |
|
|
| visual_output_info = self.visual_encoder_session.infer_shapes(tensor_info) |
|
|
| visual_outputs = { |
| t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device=image.device) |
| for t in visual_output_info |
| } |
|
|
| ok = self.visual_encoder_session.run(visual_features, visual_outputs, self.stream.cuda_stream) |
| assert ok, "Runtime execution failed for vision encoder session" |
| self.stream.synchronize() |
|
|
| image_embeds = visual_outputs['output'] |
| image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) |
|
|
| return image_embeds, image_atts |
|
|
| def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids, input_lengths): |
| |
| if hasattr(self, 'num_frames') and (visual_features.shape[1] == self.num_frames): |
| visual_features = visual_features.view(visual_features.shape[0], -1, visual_features.shape[-1]) |
|
|
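| # Assign each visual embedding a fake token id just past the LLM vocabulary; ptuning_setup later maps these ids back to the embeddings via the prompt table. |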
| fake_prompt_id = torch.arange( |
| self.model_config.vocab_size, |
| self.model_config.vocab_size + visual_features.shape[0] * visual_features.shape[1], |
| ) |
| fake_prompt_id = fake_prompt_id.reshape(visual_features.shape[0], visual_features.shape[1]) |
|
|
| if post_input_ids is not None: |
| input_ids = [pre_input_ids, fake_prompt_id, post_input_ids] |
| else: |
| input_ids = [fake_prompt_id, pre_input_ids] |
| input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32) |
|
|
| ptuning_args = self.ptuning_setup(visual_features, input_ids, input_lengths) |
|
|
| return input_ids, ptuning_args |
|
|
| def setup_fake_prompts_vila(self, batch_size, visual_features, split_input_ids, input_lengths): |
|
|
| if self.model_type == 'lita' or self.model_type == 'vita': |
| squeeze_img_tokens = visual_features[0].squeeze(0) |
| reshape_img_tokens = [t.unsqueeze(0) for t in squeeze_img_tokens] |
| visual_features = reshape_img_tokens + [visual_features[1]] |
|
|
| fake_prompt_counter = self.model_config.vocab_size |
| if batch_size == 1: |
| |
| assert len(visual_features) <= len( |
| split_input_ids |
| ), "Unexpected number of visual features. Please check #<image> in prompt and the #image files." |
|
|
| input_ids = [] |
| if batch_size == 1: |
| input_ids = [split_input_ids[0]] |
|
|
| if self.model_type == 'vila': |
| |
| for idx, visual_feature in enumerate(visual_features): |
| fake_prompt_id = torch.arange(fake_prompt_counter, fake_prompt_counter + visual_feature.shape[0]) |
| fake_prompt_counter += visual_feature.shape[0] |
| fake_prompt_id = fake_prompt_id.unsqueeze(0) |
| input_ids.append(fake_prompt_id) |
|
|
| |
| if len(split_input_ids) > idx + 1: |
| input_ids.append(split_input_ids[idx + 1]) |
| elif self.model_type == 'lita' or self.model_type == 'vita': |
| for idx, visual_f in enumerate(visual_features): |
| fake_prompt_id = torch.arange(fake_prompt_counter, fake_prompt_counter + visual_f.shape[1]) |
| fake_prompt_id = fake_prompt_id.reshape(visual_f.shape[1]) |
| fake_prompt_counter += visual_f.shape[1] |
| fake_prompt_id = fake_prompt_id.unsqueeze(0) |
| input_ids.append(fake_prompt_id) |
|
|
| |
| if len(split_input_ids) > idx + 1: |
| input_ids.append(split_input_ids[idx + 1]) |
|
|
| elif batch_size > 1 and self.model_type == 'vila': |
| |
| for idx, visual_feature in enumerate(visual_features): |
| input_ids.append(split_input_ids[0]) |
| fake_prompt_id = torch.arange(fake_prompt_counter, fake_prompt_counter + visual_feature.shape[0]) |
| fake_prompt_counter += visual_feature.shape[0] |
| fake_prompt_id = fake_prompt_id.unsqueeze(0) |
| input_ids.append(fake_prompt_id) |
| if len(split_input_ids) > 1: |
| input_ids.append(split_input_ids[1]) |
|
|
| input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32) |
| input_ids = input_ids.reshape(batch_size, -1) |
| ptuning_args = self.ptuning_setup(visual_features, input_ids, input_lengths) |
| return input_ids, ptuning_args |
|
|
| def preprocess_lita_visual(self, visual_features, config): |
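| """Split encoded frame features [b, t, s, d] into two visual token streams plus a sampled-frame count, according to the mm_cfg.lita configuration.""" |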
|
|
| b, t, s, d = visual_features.shape |
|
|
| num_frames = t |
| if ( |
| 'visual_token_format' in config['mm_cfg']['lita'] |
| and config['mm_cfg']['lita']['visual_token_format'] == 'im_vid_start_end' |
| ): |
| num_image_frames = min(num_frames, config['mm_cfg']['lita']['sample_frames']) |
| idx = np.round(np.linspace(0, num_frames - 1, num_image_frames)).astype(int) |
|
|
| |
| im_features = visual_features[:, idx, ...] |
|
|
| vid_features = einops.reduce(visual_features, 'b t s d -> b t d', 'mean') |
| return im_features, vid_features, num_image_frames |
|
|
| elif ( |
| 'lita_video_arch' in config['mm_cfg']['lita'] |
| and config['mm_cfg']['lita']['lita_video_arch'] == 'temporal_spatial_pool' |
| ): |
| pool_size = 2 |
| selected_frames = np.round(np.linspace(0, visual_features.shape[1] - 1, pool_size * pool_size)).astype(int) |
| s_tokens = visual_features[:, selected_frames, ...] |
| s_tokens = einops.rearrange(s_tokens, 'b t (h w) d -> (b t) d h w', h=16, w=16) |
| s_tokens = F.avg_pool2d(s_tokens, kernel_size=pool_size) |
| s_tokens = einops.rearrange(s_tokens, '(b t) d h w -> b (t h w) d', b=b) |
|
|
| t_tokens = einops.reduce(visual_features, 'b t s d -> b t d', 'mean') |
|
|
| return t_tokens, s_tokens, pool_size**2 |
|
|
| else: |
| raise ValueError(f'Invalid visual token format: {config["mm_cfg"]["lita"]["visual_token_format"]}') |
|
|
| def ptuning_setup(self, prompt_table, input_ids, input_lengths): |
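| """Build the TensorRT-LLM prompt-tuning inputs: the flattened prompt table of encoder embeddings, per-token task ids, and the task vocabulary size.""" |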
| hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size |
|
|
| if self.model_type == 'lita' or self.model_type == 'vita': |
| prompt_table = torch.cat(prompt_table, dim=1) |
| if prompt_table is not None: |
| task_vocab_size = torch.tensor( |
| [prompt_table.shape[1]], |
| dtype=torch.int32, |
| ).cuda() |
| prompt_table = prompt_table.view((prompt_table.shape[0] * prompt_table.shape[1], prompt_table.shape[2])) |
|
|
| assert prompt_table.shape[1] == hidden_size, "Prompt table dimensions do not match hidden size" |
|
|
| prompt_table = prompt_table.cuda().to( |
| dtype=tensorrt_llm._utils.str_dtype_to_torch(self.model_config.dtype) |
| ) |
| else: |
| prompt_table = torch.empty([1, hidden_size]).cuda() |
| task_vocab_size = torch.zeros([1]).cuda() |
|
|
| if self.model_config.remove_input_padding: |
| tasks = torch.zeros([torch.sum(input_lengths)], dtype=torch.int32).cuda() |
| else: |
| tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda() |
|
|
| return [prompt_table, tasks, task_vocab_size] |
|
|
| def expand2square_pt(self, images, background_color): |
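| """Pad a batch of CHW image tensors to a square canvas filled with background_color, centering the original content along the shorter side.""" |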
| height, width = images.shape[-2:] |
| b = len(images) |
| background_color = torch.Tensor(background_color) |
| if width == height: |
| return images |
| elif width > height: |
| result = einops.repeat(background_color, 'c -> b c h w', b=b, h=width, w=width).clone() |
| paste_start = (width - height) // 2 |
| paste_end = paste_start + height |
| result[:, :, paste_start:paste_end, :] = images |
| return result |
| else: |
| result = einops.repeat(background_color, 'c -> b c h w', b=b, h=height, w=height).clone() |
| paste_start = (height - width) // 2 |
| paste_end = paste_start + width |
| result[:, :, :, paste_start:paste_end] = images |
| return result |
|
|
| def load_video(self, config, video_path, processor, num_frames=None): |
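| """Decode a video from a file path (via decord) or accept a pre-extracted frame array, then run the image processor; returns pixel values shaped [T, C, H, W].""" |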
| frames = None |
| if isinstance(video_path, str): |
| decord.bridge.set_bridge('torch') |
| video_reader = decord.VideoReader(uri=video_path) |
| if num_frames is not None: |
| idx = np.round(np.linspace(0, len(video_reader) - 1, num_frames)).astype(int) |
| frames = video_reader.get_batch(idx) |
| else: |
| frames = torch.cat([torch.tensor(f.asnumpy()) for f in video_reader]) |
| elif isinstance(video_path, np.ndarray): |
| frames = torch.tensor(video_path, dtype=torch.float32) |
|
|
| return self.preprocess_frames(frames, config, processor) |
|
|
| def preprocess_frames(self, frames, config, processor): |
| frames = einops.rearrange(frames, 't h w c -> t c h w') |
| if config['data']['image_aspect_ratio'] == 'pad': |
| frames = self.expand2square_pt(frames, tuple(int(x * 255) for x in processor.image_mean)) |
| processed_frames = processor.preprocess(frames, return_tensors='pt')['pixel_values'] |
| return processed_frames |
|
|
| def get_num_sample_frames(self, config, vid_len): |
| if ( |
| 'visual_token_format' in config['mm_cfg']['lita'] |
| and config['mm_cfg']['lita']['visual_token_format'] == 'im_vid_start_end' |
| ): |
| max_frames = config['data']['num_frames'] |
| if vid_len <= max_frames: |
| return vid_len |
| else: |
| subsample = int(np.ceil(float(vid_len) / max_frames)) |
| return int(np.round(float(vid_len) / subsample)) |
| else: |
| return config['mm_cfg']['lita']['sample_frames'] |
|
|
| def process_lita_video(self, nemo_config, video_path, image_processor): |
| image = None |
| if isinstance(video_path, str): |
| vid_len = len(decord.VideoReader(video_path)) |
| num_sample_frames = self.get_num_sample_frames(nemo_config, vid_len) |
| image = ( |
| self.load_video(nemo_config, video_path, image_processor, num_sample_frames) |
| .unsqueeze(0) |
| .to(self.device, dtype=torch.bfloat16) |
| ) |
| elif isinstance(video_path, np.ndarray): |
| image = ( |
| self.load_video(nemo_config, video_path, image_processor) |
| .unsqueeze(0) |
| .to(self.device, dtype=torch.bfloat16) |
| ) |
| return image |
|
|
| def process_image(self, image_file, image_processor, nemo_config, image_folder): |
| if isinstance(image_file, str): |
| if image_folder is not None: |
| image = Image.open(os.path.join(image_folder, image_file)).convert("RGB") |
| else: |
| image = Image.open(image_file).convert("RGB") |
| else: |
| |
| image = image_file |
|
|
| crop_size = nemo_config['mm_cfg']['vision_encoder']['crop_size'] |
| crop_size = tuple(crop_size) |
| image = image.resize(crop_size) |
| if nemo_config['data']['image_aspect_ratio'] == 'pad': |
| image = self.expand2square_pt(image, tuple(int(x * 255) for x in image_processor.image_mean)) |
| image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0] |
| else: |
| image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0] |
| return image |
|
|
| def process_vila_img(self, images): |
| new_images = [self.process_image(image, self.image_processor, self.nemo_config, None) for image in images] |
|
|
| if all(x.shape == new_images[0].shape for x in new_images): |
| new_images = torch.stack(new_images, dim=0) |
| return new_images |
|
|
| def setup_inputs(self, input_text, raw_image, batch_size): |
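| """Build the model-type-specific pre/post prompts and preprocess the raw media into the tensor(s) expected by the visual encoder engine.""" |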
| attention_mask = None |
| image = None |
|
|
| if self.model_type == "neva": |
| image_size = self.image_size |
| dtype = torch.float32 |
| transform = transforms.Compose( |
| [ |
| transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC), |
| transforms.ToTensor(), |
| transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), |
| ] |
| ) |
| image = transform(raw_image).to(dtype).unsqueeze(0) |
|
|
| if input_text is None: |
| input_text = "Hi! What is in this image?" |
|
|
| pre_prompt = "<extra_id_0>System\n\n<extra_id_1>User\n" |
| post_prompt = f"\n{input_text}\n<extra_id_1>Assistant\n" |
| elif self.model_type == "video-neva": |
| image = self.video_preprocess(raw_image) |
|
|
| if input_text is None: |
| input_text = "Hi! What is in this video?" |
|
|
| |
| pre_prompt = ( |
| "<extra_id_0>System\nA chat between a curious user and an artificial intelligence assistant. " |
| "The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n" |
| "<extra_id_1>User" |
| ) |
| post_prompt = ( |
| f"\n{input_text}\n<extra_id_1>Assistant\n" |
| "<extra_id_2>quality:4,toxicity:0,humor:0,creativity:0,helpfulness:4," |
| "correctness:4,coherence:4,complexity:4,verbosity:4\n" |
| ) |
| elif self.model_type in ['vila', 'lita', 'vita']: |
| if self.model_type == "vila" or self.model_type == "lita": |
| pre_prompt = ( |
| "A chat between a curious user and an artificial intelligence assistant. " |
| "The assistant gives helpful, detailed, and polite answers to the user's questions. USER: " |
| ) |
| if input_text is None: |
| input_text = "<image>\n Please elaborate what you see in the images?" |
| post_prompt = input_text + " ASSISTANT:" |
|
|
| elif self.model_type == "vita": |
| |
| pre_prompt = ( |
| "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" |
| "You are a helpful language and vision assistant. " |
| "You are able to understand the visual content that the user provides, " |
| "and assist the user with a variety of tasks using natural language. " |
| "<|start_header_id|>user<|end_header_id|>\n\n" |
| ) |
| if input_text is None: |
| input_text = "<image>\n Please elaborate what you see in the images?" |
| post_prompt = input_text + "<|start_header_id|>assistant<|end_header_id|>\n\n" |
|
|
| else: |
| raise RuntimeError(f"Invalid model type {self.model_type}") |
|
|
| if self.model_type == 'lita' or self.model_type == 'vita': |
| image = self.process_lita_video(self.nemo_config, raw_image, self.image_processor) |
|
|
| if self.model_type == 'vila': |
| raw_image = [raw_image] * batch_size |
| image = self.process_vila_img(raw_image) |
|
|
| |
| pre_prompt = [pre_prompt] * batch_size |
| post_prompt = [post_prompt] * batch_size |
| if self.model_type not in ['vila', 'lita', 'vita']: |
| if image.dim() == 5: |
| image = image.expand(batch_size, -1, -1, -1, -1).contiguous() |
| else: |
| image = image.expand(batch_size, -1, -1, -1).contiguous() |
| image = image.to(self.device) |
|
|
| decoder_input_ids = None |
|
|
| return input_text, pre_prompt, post_prompt, image, decoder_input_ids, attention_mask |
|
|
| def run( |
| self, |
| input_text, |
| input_image, |
| max_new_tokens, |
| batch_size, |
| top_k, |
| top_p, |
| temperature, |
| repetition_penalty, |
| num_beams, |
| lora_uids=None, |
| run_profiling=False, |
| check_accuracy=False, |
| ): |
| input_text, pre_prompt, post_prompt, processed_image, decoder_input_ids, attention_mask = self.setup_inputs( |
| input_text, input_image, batch_size |
| ) |
|
|
| self.generate( |
| pre_prompt, |
| post_prompt, |
| processed_image, |
| decoder_input_ids, |
| max_new_tokens, |
| attention_mask=attention_mask, |
| warmup=True, |
| batch_size=batch_size, |
| top_k=top_k, |
| top_p=top_p, |
| temperature=temperature, |
| repetition_penalty=repetition_penalty, |
| num_beams=num_beams, |
| lora_uids=lora_uids, |
| ) |
| num_iters = self.profiling_iterations if run_profiling else 1 |
| for _ in range(num_iters): |
| output_text = self.generate( |
| pre_prompt, |
| post_prompt, |
| processed_image, |
| decoder_input_ids, |
| max_new_tokens, |
| attention_mask=attention_mask, |
| warmup=False, |
| batch_size=batch_size, |
| top_k=top_k, |
| top_p=top_p, |
| temperature=temperature, |
| repetition_penalty=repetition_penalty, |
| num_beams=num_beams, |
| lora_uids=lora_uids, |
| ) |
| if self.runtime_rank == 0: |
| self.print_result(input_text, output_text, batch_size, num_beams, run_profiling, check_accuracy) |
| return output_text |
|
|
| def print_result(self, input_text, output_text, batch_size, num_beams, run_profiling, check_accuracy): |
| if not run_profiling and not check_accuracy: |
| return |
| logger.info("---------------------------------------------------------") |
| if self.model_type != 'nougat': |
| logger.info(f"\n[Q] {input_text}") |
| logger.info(f"\n[A] {output_text[0]}") |
|
|
| if num_beams == 1: |
| output_ids = self.tokenizer(output_text[0][0], add_special_tokens=False)['input_ids'] |
| logger.info(f"Generated {len(output_ids)} tokens") |
|
|
| if check_accuracy: |
| for i in range(batch_size - 1): |
| if output_text[i] != output_text[i + 1]: |
| logger.info(f"Output {i} and {i + 1} do not match") |
| assert False |
|
|
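| # The accuracy check assumes the reference test media, for which the expected answer contains the word "robot". |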
| assert 'robot' in output_text[0][0].lower() |
|
|
| if run_profiling: |
| msec_per_batch = lambda name: 1000 * profiler.elapsed_time_in_sec(name) / self.profiling_iterations |
| logger.info('Latencies per batch (msec)') |
| logger.info(f'TRT {self.modality} encoder: %.1f' % (msec_per_batch(self.modality.capitalize()))) |
| logger.info('TRTLLM LLM generate: %.1f' % (msec_per_batch('LLM'))) |
| logger.info('Multimodal generate: %.1f' % (msec_per_batch('Generate'))) |
|
|
| logger.info("---------------------------------------------------------") |
|
|
| def load_test_media(self, input_media): |
| media_model = ["video-neva", "lita", "vita"] |
| if self.model_type in media_model: |
| media = input_media |
| elif self.model_type == "neva" or self.model_type == "vila": |
| media = Image.open(input_media).convert('RGB') |
| else: |
| raise RuntimeError(f"Invalid model type {self.model_type}") |
|
|
| return media |
|
|
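| # Illustrative usage sketch (hypothetical paths; assumes prebuilt engines, here for a 'neva'-type model): |
| #   runner = MultimodalModelRunner("/engines/visual", "/engines/llm", modality="vision") |
| #   media = runner.load_test_media("/data/test_image.jpg") |
| #   text = runner.run("What is in this image?", media, max_new_tokens=30, batch_size=1, top_k=1, |
| #                     top_p=0.0, temperature=1.0, repetition_penalty=1.0, num_beams=1) |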
|
|
| class SpeechllmModelRunner(MultimodalModelRunner): |
| def __init__(self, perception_engine_dir, llm_engine_dir, modality): |
| """ |
| perception_engine_dir: path to the perception engine directory |
| it should contain: |
| config.json, nemo_config.yaml |
| perception_encoder.engine : tensorrt engine |
| feature_extractor.ts : torchscript model |
| llm_engine_dir: path to the LLM engine directory |
| """ |
| super().__init__(perception_engine_dir, llm_engine_dir, modality) |
| assert self.model_type == 'salm' |
| |
| feature_extractor_path = os.path.join(perception_engine_dir, 'feature_extractor.ts') |
| self.feature_extractor = self.init_speech_preprocessor(feature_extractor_path) |
| self.init_modality_encoder(perception_engine_dir) |
|
|
| def init_modality_encoder(self, engine_dir): |
| """ |
| Initialize the modality encoder session from the prebuilt engine directory |
| Args: |
| engine_dir: str, path to the engine directory |
| """ |
| |
| engine_file = None |
| for file in os.listdir(engine_dir): |
| if file.endswith('.engine'): |
| engine_file = file |
| break |
| assert engine_file is not None, f"Engine file not found in {engine_dir}" |
| encoder_path = os.path.join(engine_dir, engine_file) |
| logger.info(f'Loading engine from {encoder_path}') |
| with open(encoder_path, 'rb') as f: |
| engine_buffer = f.read() |
| logger.info(f'Creating session from engine {encoder_path}') |
| self.modality_encoder_session = Session.from_serialized_engine(engine_buffer) |
|
|
| def init_speech_preprocessor(self, feature_extractor_path): |
| feature_extractor = torch.jit.load(feature_extractor_path) |
| feature_extractor.eval() |
| return feature_extractor |
|
|
| def process_audio(self, input_signal, input_signal_length): |
| """ |
| Args: |
| input_signal: audio signal in numpy array |
| input_signal_length: length of the audio signal in numpy array |
| |
| Returns: |
| processed_signal: torch.tensor [B, 80, T] |
| processed_signal_length [B] |
| """ |
| input_signal = torch.tensor(input_signal, dtype=torch.float32) |
| input_signal_length = torch.tensor(input_signal_length, dtype=torch.int32) |
| processed_signal, processed_signal_length = self.feature_extractor(input_signal, input_signal_length) |
| return processed_signal, processed_signal_length |
|
|
| def setup_inputs(self, input_text, input_media, batch_size): |
| """ |
| Args: |
| input_text: str or List[str] or None |
| input_media: Tuple[np.array, np.array] |
| input_signal: audio signal in numpy array [b, -1] |
| input_signal_length: length of the audio signal in numpy array [b] |
| batch_size: int |
| |
| """ |
| input_signal, input_signal_length = input_media |
| processed_signal, processed_signal_length = self.process_audio(input_signal, input_signal_length) |
| processed_signal = processed_signal.to(self.device) |
| processed_signal_length = processed_signal_length.to(self.device) |
| if input_text is None: |
| input_text = "Q: what's the transcription of the audio? A:" |
|
|
| if isinstance(input_text, str): |
| input_text = [input_text] * batch_size |
|
|
| assert len(input_text) == batch_size |
| pre_prompt = [''] * batch_size |
| post_prompt = input_text |
| decoder_input_ids = None |
| attention_mask = None |
| return ( |
| input_text, |
| pre_prompt, |
| post_prompt, |
| processed_signal, |
| processed_signal_length, |
| decoder_input_ids, |
| attention_mask, |
| ) |
|
|
| def load_test_media(self, input_media_path): |
| """ |
| Args: |
| input_media_path: str, path to the audio file |
| Returns: |
| input_signal: np.array [1, -1] |
| input_signal_length: np.array [1] |
| """ |
| waveform, sample_rate = sf.read(input_media_path, dtype=np.float32) |
| input_signal = np.array([waveform], dtype=np.float32) |
| input_signal_length = np.array([len(waveform)], dtype=np.int32) |
| return input_signal, input_signal_length |
|
|
| def get_modality_encoder_features(self, modality_features, attention_mask): |
| """ |
| Do inference on the modality encoder engine |
| Args: |
| modality_features: dict {'input1': torch.tensor, 'input2': torch.tensor, ..} |
| attention_mask: None |
| Returns: |
| outputs: dict mapping engine output names to torch.Tensors (for salm: 'encoded' and 'encoded_length') |
| """ |
|
|
| if attention_mask is not None: |
| modality_features['attention_mask'] = attention_mask |
|
|
| tensor_info = [] |
| for key, tensor in modality_features.items(): |
| tensor_info.append(TensorInfo(key, torch_dtype_to_trt(tensor.dtype), tensor.shape)) |
|
|
| output_info = self.modality_encoder_session.infer_shapes(tensor_info) |
|
|
| outputs = { |
| t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device=self.device) |
| for t in output_info |
| } |
|
|
| ok = self.modality_encoder_session.run(modality_features, outputs, self.stream.cuda_stream) |
| assert ok, "Runtime execution failed for vision encoder session" |
| self.stream.synchronize() |
|
|
| return outputs |
|
|
| def preprocess(self, warmup, pre_prompt, post_prompt, processed_features, attention_mask, batch_size): |
| """ |
| Args: |
| warmup: bool |
| pre_prompt: List[str] |
| post_prompt: List[str] |
| processed_features: Tuple[torch.tensor, torch.tensor] |
| processed_signal: torch.tensor [B, 80, T] |
| processed_signal_length: torch.tensor [B] |
| attention_mask: None |
| batch_size: int |
| Returns: |
| input_ids: torch.tensor [B, L] |
| input_lengths: torch.tensor [B] |
| ptuning_args: List[torch.tensor] |
| encoded_features: torch.tensor [B, L, D] |
| """ |
| if not warmup: |
| profiler.start(self.modality.capitalize()) |
|
|
| if not warmup: |
| profiler.stop(self.modality.capitalize()) |
|
|
| assert self.model_type == 'salm', f"Invalid model type {self.model_type}" |
|
|
| processed_features = { |
| "processed_signal": processed_features[0], |
| "processed_signal_length": processed_features[1].to(torch.int32), |
| } |
| encoded_outputs = self.get_modality_encoder_features(processed_features, attention_mask) |
| encoded_features, encoded_length = encoded_outputs['encoded'], encoded_outputs['encoded_length'] |
| pre_input_ids = self.tokenizer(pre_prompt).input_ids |
| post_input_ids = self.tokenizer(post_prompt).input_ids |
| input_lengths = [] |
| input_ids = [] |
| encoded_length = encoded_length.cpu().numpy() |
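| # Splice each sample's audio features into the prompt as consecutive fake token ids placed past the vocabulary; ptuning_setup maps them back to the encoded features. |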
| fake_id_start = self.model.vocab_size |
| for i in range(batch_size): |
| feat_len = encoded_length[i] |
| feat_fake_ids = np.arange(fake_id_start, fake_id_start + feat_len) |
| cur_input_ids = np.concatenate([pre_input_ids[i], feat_fake_ids, post_input_ids[i]]) |
| fake_id_start += feat_len |
| input_lengths.append(len(cur_input_ids)) |
| input_ids.append(cur_input_ids) |
|
|
| max_length = max(input_lengths) |
| |
| input_ids = [ |
| np.pad(ids, (0, max_length - len(ids)), 'constant', constant_values=self.tokenizer.pad_token_id) |
| for ids in input_ids |
| ] |
| input_ids = torch.tensor(input_ids, dtype=torch.int32) |
| input_lengths = torch.tensor(input_lengths, dtype=torch.int32) |
| ptuning_args = self.ptuning_setup(encoded_features, input_ids, input_lengths) |
|
|
| return input_ids, input_lengths, ptuning_args, encoded_features |
|
|
| def run( |
| self, |
| input_text, |
| input_media=None, |
| max_new_tokens: int = 30, |
| batch_size: int = 1, |
| top_k: int = 1, |
| top_p: float = 0.0, |
| temperature: float = 1.0, |
| repetition_penalty: float = 1.0, |
| num_beams: int = 1, |
| run_profiling=False, |
| check_accuracy=False, |
| input_signal=None, |
| input_signal_length=None, |
| lora_uids=None, |
| ): |
| """ |
| Args: |
| input_text: str or List[str] or None |
| input_media: Tuple[np.array, np.array] or None |
| input_signal: audio signal in numpy array [b, -1] |
| input_signal_length: length of the audio signal in numpy array [b] |
| max_new_tokens: int |
| batch_size: int |
| top_k: int |
| top_p: float |
| temperature: float |
| repetition_penalty: float |
| num_beams: int |
| run_profiling: bool |
| check_accuracy: bool |
| """ |
| if input_media is None: |
| assert input_signal is not None and input_signal_length is not None |
| input_media = (input_signal, input_signal_length) |
|
|
| ( |
| input_text, |
| pre_prompt, |
| post_prompt, |
| processed_signal, |
| processed_signal_length, |
| decoder_input_ids, |
| attention_mask, |
| ) = self.setup_inputs(input_text, input_media, batch_size) |
| processed_media = (processed_signal, processed_signal_length) |
|
|
| self.generate( |
| pre_prompt, |
| post_prompt, |
| processed_media, |
| decoder_input_ids, |
| max_new_tokens, |
| attention_mask=attention_mask, |
| warmup=True, |
| batch_size=batch_size, |
| top_k=top_k, |
| top_p=top_p, |
| temperature=temperature, |
| repetition_penalty=repetition_penalty, |
| num_beams=num_beams, |
| ) |
| num_iters = self.profiling_iterations if run_profiling else 1 |
| for _ in range(num_iters): |
| output_text = self.generate( |
| pre_prompt, |
| post_prompt, |
| processed_media, |
| decoder_input_ids, |
| max_new_tokens, |
| attention_mask=attention_mask, |
| warmup=False, |
| batch_size=batch_size, |
| top_k=top_k, |
| top_p=top_p, |
| temperature=temperature, |
| repetition_penalty=repetition_penalty, |
| num_beams=num_beams, |
| ) |
| if self.runtime_rank == 0: |
| self.print_result(input_text, output_text, batch_size, num_beams, run_profiling, check_accuracy) |
| return output_text |
|
|
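| # Illustrative usage sketch (hypothetical paths; assumes a prebuilt SALM perception engine and TRT-LLM engine): |
| #   runner = SpeechllmModelRunner("/engines/perception", "/engines/llm", modality="audio") |
| #   signal, signal_length = runner.load_test_media("/data/sample.wav") |
| #   text = runner.run(input_text=None, input_media=(signal, signal_length), max_new_tokens=64) |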