# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from io import BytesIO

import numpy as np
import requests
import soundfile as sf
from PIL import Image

from nemo.deploy.utils import str_list2numpy

# pytriton is an optional dependency; record availability instead of failing at import time.
use_pytriton = True
try:
    from pytriton.client import ModelClient
except Exception:
    use_pytriton = False

# decord is only needed for the video model types; warn (don't fail) if it is missing.
try:
    from decord import VideoReader
except Exception:
    import logging

    logging.warning("The package `decord` was not installed in this environment.")


class NemoQueryMultimodal:
    """
    Sends a query to Triton for Multimodal inference

    Example:
        from nemo.deploy.multimodal import NemoQueryMultimodal

        nq = NemoQueryMultimodal(url="localhost", model_name="neva", model_type="neva")

        input_text = "Hi! What is in this image?"
        output = nq.query(
            input_text=input_text,
            input_media="/path/to/image.jpg",
            max_output_len=30,
            top_k=1,
            top_p=0.0,
            temperature=1.0,
        )
        print("output: ", output)
    """

    def __init__(self, url, model_name, model_type):
        """
        Args:
            url (str): Triton server address, e.g. "localhost".
            model_name (str): name of the model as deployed on Triton.
            model_type (str): one of "neva", "video-neva", "lita", "vita",
                "vila", "mllama", or "salm"; selects how input media is loaded.
        """
        self.url = url
        self.model_name = model_name
        self.model_type = model_type

    def setup_media(self, input_media):
        """Load ``input_media`` into the array form expected by ``self.model_type``.

        Args:
            input_media (str): local path (or, for image models, an http(s) URL)
                of the media file.

        Returns:
            np.ndarray for video/image model types, or a dict of named arrays
            ("input_signal", "input_signal_length") for the "salm" audio model.

        Raises:
            RuntimeError: if ``self.model_type`` is not a supported type.
        """
        if self.model_type == "video-neva":
            # Decode every frame of the video into a (num_frames, H, W, C) array.
            vr = VideoReader(input_media)
            frames = [f.asnumpy() for f in vr]
            return np.array(frames)
        elif self.model_type in ("lita", "vita"):
            # These models cap the frame count; subsample uniformly if needed.
            vr = VideoReader(input_media)
            frames = [f.asnumpy() for f in vr]
            subsample_len = self.frame_len(frames)
            sub_frames = self.get_subsampled_frames(frames, subsample_len)
            return np.array(sub_frames)
        elif self.model_type in ["neva", "vila", "mllama"]:
            # Image models accept either a URL or a local file path.
            # Note: startswith("http") already covers "https" URLs.
            if input_media.startswith("http"):
                response = requests.get(input_media, timeout=5)
                media = Image.open(BytesIO(response.content)).convert("RGB")
            else:
                media = Image.open(input_media).convert('RGB')
            # Add a leading frame axis so images share the video layout (1, H, W, C).
            return np.expand_dims(np.array(media), axis=0)
        elif self.model_type == "salm":
            # Audio model: return the waveform plus its length as named inputs.
            waveform, sample_rate = sf.read(input_media, dtype=np.float32)
            input_signal = np.array([waveform], dtype=np.float32)
            input_signal_length = np.array([[len(waveform)]], dtype=np.int32)
            return {"input_signal": input_signal, "input_signal_length": input_signal_length}
        else:
            raise RuntimeError(f"Invalid model type {self.model_type}")

    def frame_len(self, frames):
        """Return the number of frames to keep, capped around ``max_frames``.

        If the clip already fits, keep every frame; otherwise pick an integer
        stride and return the resulting (rounded) subsampled length.
        """
        max_frames = 256
        if len(frames) <= max_frames:
            return len(frames)
        else:
            # Smallest integer stride that brings the count under max_frames.
            subsample = int(np.ceil(float(len(frames)) / max_frames))
            return int(np.round(float(len(frames)) / subsample))

    def get_subsampled_frames(self, frames, subsample_len):
        """Return ``subsample_len`` frames spaced evenly across ``frames``."""
        idx = np.round(np.linspace(0, len(frames) - 1, subsample_len)).astype(int)
        sub_frames = [frames[i] for i in idx]
        return sub_frames

    def query(
        self,
        input_text,
        input_media,
        batch_size=1,
        max_output_len=30,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        repetition_penalty=1.0,
        num_beams=1,
        init_timeout=60.0,
        lora_uids=None,
    ):
        """Run a single multimodal inference request against Triton.

        Args:
            input_text (str): prompt text.
            input_media (str): path/URL of the media file (see ``setup_media``).
            batch_size (int): batch size forwarded to the server; skipped if None.
            max_output_len (int): maximum generated tokens; skipped if None.
            top_k (int): top-k sampling parameter; skipped if None.
            top_p (float): top-p sampling parameter; skipped if None.
            temperature (float): sampling temperature; skipped if None.
            repetition_penalty (float): repetition penalty; skipped if None.
            num_beams (int): beam search width; skipped if None.
            init_timeout (float): ModelClient init timeout in seconds.
            lora_uids (list[str] | None): optional LoRA adapter UIDs.

        Returns:
            np.ndarray of decoded strings when the server output is bytes,
            otherwise the raw output array from the result dict.
        """
        prompts = str_list2numpy([input_text])
        inputs = {"input_text": prompts}

        media = self.setup_media(input_media)

        if isinstance(media, dict):
            # Audio ("salm") media already comes as named input tensors.
            inputs.update(media)
        else:
            # Tile the media along a new batch axis to match the prompt batch.
            inputs["input_media"] = np.repeat(media[np.newaxis, :, :, :, :], prompts.shape[0], axis=0)

        if batch_size is not None:
            inputs["batch_size"] = np.full(prompts.shape, batch_size, dtype=np.int_)

        if max_output_len is not None:
            inputs["max_output_len"] = np.full(prompts.shape, max_output_len, dtype=np.int_)

        if top_k is not None:
            inputs["top_k"] = np.full(prompts.shape, top_k, dtype=np.int_)

        if top_p is not None:
            inputs["top_p"] = np.full(prompts.shape, top_p, dtype=np.single)

        if temperature is not None:
            inputs["temperature"] = np.full(prompts.shape, temperature, dtype=np.single)

        if repetition_penalty is not None:
            inputs["repetition_penalty"] = np.full(prompts.shape, repetition_penalty, dtype=np.single)

        if num_beams is not None:
            inputs["num_beams"] = np.full(prompts.shape, num_beams, dtype=np.int_)

        if lora_uids is not None:
            lora_uids = np.char.encode(lora_uids, "utf-8")
            inputs["lora_uids"] = np.full((prompts.shape[0], len(lora_uids)), lora_uids)

        with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client:
            result_dict = client.infer_batch(**inputs)
            output_type = client.model_config.outputs[0].dtype

            if output_type == np.bytes_:
                # Byte outputs are utf-8 encoded strings; decode before returning.
                sentences = np.char.decode(result_dict["outputs"].astype("bytes"), "utf-8")
                return sentences
            else:
                return result_dict["outputs"]