NeMo_Canary / nemo /deploy /multimodal /query_multimodal.py
Respair's picture
Upload folder using huggingface_hub
b386992 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from io import BytesIO
import numpy as np
import requests
import soundfile as sf
from PIL import Image
from nemo.deploy.utils import str_list2numpy
# Optional runtime dependencies: pytriton is needed to talk to the Triton server,
# decord only for the video model types. A failed import is tolerated here.
try:
    from pytriton.client import ModelClient
except Exception:
    use_pytriton = False
else:
    use_pytriton = True

try:
    from decord import VideoReader
except Exception:
    import logging

    logging.warning("The package `decord` was not installed in this environment.")
class NemoQueryMultimodal:
    """
    Sends a query to Triton for Multimodal inference

    Example:
        from nemo.deploy.multimodal import NemoQueryMultimodal

        nq = NemoQueryMultimodal(url="localhost", model_name="neva", model_type="neva")

        input_text = "Hi! What is in this image?"
        output = nq.query(
            input_text=input_text,
            input_media="/path/to/image.jpg",
            max_output_len=30,
            top_k=1,
            top_p=0.0,
            temperature=1.0,
        )
        print("output: ", output)
    """

    def __init__(self, url, model_name, model_type):
        """
        Args:
            url: Address of the Triton server (passed to pytriton's ``ModelClient``).
            model_name: Name of the model deployed on the Triton server.
            model_type: Selects how ``input_media`` is preprocessed in ``setup_media``;
                one of "video-neva", "lita", "vita", "neva", "vila", "mllama", "salm".
        """
        self.url = url
        self.model_name = model_name
        self.model_type = model_type

    @staticmethod
    def _is_url(path):
        """Return True if ``path`` is an HTTP(S) URL rather than a local file path."""
        # Require the scheme separator so local files such as "http_image.jpg"
        # are opened from disk instead of being fetched over the network.
        return path.startswith(("http://", "https://"))

    def _read_video_frames(self, input_media):
        """Decode every frame of the video at ``input_media`` into a list of numpy arrays."""
        # VideoReader comes from the optional module-level `decord` import.
        reader = VideoReader(input_media)
        return [frame.asnumpy() for frame in reader]

    def setup_media(self, input_media):
        """Load ``input_media`` and convert it into the input(s) the deployed model expects.

        Args:
            input_media: Path to a video or audio file, or a local path / HTTP(S) URL
                of an image, depending on ``self.model_type``.

        Returns:
            For video and image model types, a numpy array whose leading axis indexes
            frames (an image gets a length-1 leading axis). For "salm", a dict with
            "input_signal" and "input_signal_length" arrays.

        Raises:
            RuntimeError: If ``self.model_type`` is not supported.
            requests.HTTPError: If an image URL returns an HTTP error status.
        """
        if self.model_type == "video-neva":
            return np.array(self._read_video_frames(input_media))

        if self.model_type in ("lita", "vita"):
            frames = self._read_video_frames(input_media)
            subsample_len = self.frame_len(frames)
            return np.array(self.get_subsampled_frames(frames, subsample_len))

        if self.model_type in ("neva", "vila", "mllama"):
            if self._is_url(input_media):
                response = requests.get(input_media, timeout=5)
                # Fail with the HTTP error rather than a confusing image-decoding error.
                response.raise_for_status()
                with Image.open(BytesIO(response.content)) as img:
                    media = img.convert("RGB")
            else:
                # Context manager closes the file handle; convert() returns an independent image.
                with Image.open(input_media) as img:
                    media = img.convert("RGB")
            return np.expand_dims(np.array(media), axis=0)

        if self.model_type == "salm":
            waveform, _sample_rate = sf.read(input_media, dtype=np.float32)
            input_signal = np.array([waveform], dtype=np.float32)
            input_signal_length = np.array([[len(waveform)]], dtype=np.int32)
            return {"input_signal": input_signal, "input_signal_length": input_signal_length}

        raise RuntimeError(f"Invalid model type {self.model_type}")

    def frame_len(self, frames, max_frames=256):
        """Return how many frames to keep so that at most ``max_frames`` are sent.

        Videos with ``max_frames`` frames or fewer are kept whole. Longer videos are
        reduced by the smallest integer stride that brings them within the limit, and
        the resulting (rounded) frame count is returned.

        Args:
            frames: Sequence of decoded frames; only its length is used.
            max_frames: Upper bound on the number of frames to keep.

        Raises:
            ValueError: If ``max_frames`` is not positive.
        """
        if max_frames <= 0:
            raise ValueError(f"max_frames must be positive, got {max_frames}")
        num_frames = len(frames)
        if num_frames <= max_frames:
            return num_frames
        stride = int(np.ceil(float(num_frames) / max_frames))
        return int(np.round(float(num_frames) / stride))

    def get_subsampled_frames(self, frames, subsample_len):
        """Pick ``subsample_len`` frames at evenly spaced indices.

        When ``subsample_len >= 2`` the first and last frames are always included.
        """
        idx = np.round(np.linspace(0, len(frames) - 1, subsample_len)).astype(int)
        return [frames[i] for i in idx]

    def query(
        self,
        input_text,
        input_media,
        batch_size=1,
        max_output_len=30,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        repetition_penalty=1.0,
        num_beams=1,
        init_timeout=60.0,
        lora_uids=None,
    ):
        """Send one prompt plus its media to the Triton server and return the model output.

        Args:
            input_text: Prompt string.
            input_media: Media path/URL, interpreted according to ``self.model_type``
                (see ``setup_media``).
            batch_size, max_output_len, top_k, top_p, temperature, repetition_penalty,
            num_beams: Generation settings sent to the server as one value per prompt.
                Passing None omits that input from the request.
            init_timeout: Passed to ``ModelClient`` as ``init_timeout_s``.
            lora_uids: Optional sequence of LoRA ids; UTF-8 encoded and repeated for
                every prompt.

        Returns:
            A numpy array of decoded strings when the model's first output is bytes,
            otherwise the raw "outputs" array returned by the server.

        Raises:
            ImportError: If pytriton is not installed.
        """
        if not use_pytriton:
            raise ImportError("pytriton is not installed; it is required to query a Triton server.")

        prompts = str_list2numpy([input_text])
        inputs = {"input_text": prompts}

        media = self.setup_media(input_media)
        if isinstance(media, dict):
            inputs.update(media)
        else:
            # Add a leading axis and repeat the media once per prompt.
            inputs["input_media"] = np.repeat(media[np.newaxis, :, :, :, :], prompts.shape[0], axis=0)

        # Insertion order matches the request layout used previously.
        optional_inputs = {
            "batch_size": (batch_size, np.int_),
            "max_output_len": (max_output_len, np.int_),
            "top_k": (top_k, np.int_),
            "top_p": (top_p, np.single),
            "temperature": (temperature, np.single),
            "repetition_penalty": (repetition_penalty, np.single),
            "num_beams": (num_beams, np.int_),
        }
        for name, (value, dtype) in optional_inputs.items():
            if value is not None:
                inputs[name] = np.full(prompts.shape, value, dtype=dtype)

        if lora_uids is not None:
            lora_uids = np.char.encode(lora_uids, "utf-8")
            inputs["lora_uids"] = np.full((prompts.shape[0], len(lora_uids)), lora_uids)

        with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client:
            result_dict = client.infer_batch(**inputs)
            output_type = client.model_config.outputs[0].dtype

            if output_type == np.bytes_:
                return np.char.decode(result_dict["outputs"].astype("bytes"), "utf-8")
            return result_dict["outputs"]