| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from io import BytesIO |
| |
|
| | import numpy as np |
| | import requests |
| | import soundfile as sf |
| | from PIL import Image |
| |
|
| | from nemo.deploy.utils import str_list2numpy |
| |
|
| | use_pytriton = True |
| | try: |
| | from pytriton.client import ModelClient |
| | except Exception: |
| | use_pytriton = False |
| |
|
| | try: |
| | from decord import VideoReader |
| | except Exception: |
| | import logging |
| |
|
| | logging.warning("The package `decord` was not installed in this environment.") |
| |
|
| |
|
class NemoQueryMultimodal:
    """
    Sends a query to Triton for Multimodal inference

    Example:
        from nemo.deploy.multimodal import NemoQueryMultimodal

        nq = NemoQueryMultimodal(url="localhost", model_name="neva", model_type="neva")

        input_text = "Hi! What is in this image?"
        output = nq.query(
            input_text=input_text,
            input_media="/path/to/image.jpg",
            max_output_len=30,
            top_k=1,
            top_p=0.0,
            temperature=1.0,
        )
        print("output: ", output)
    """

    def __init__(self, url, model_name, model_type):
        """
        Args:
            url (str): Triton inference server address, e.g. "localhost".
            model_name (str): name of the model as deployed on Triton.
            model_type (str): one of "neva", "video-neva", "lita", "vita",
                "vila", "mllama" or "salm"; selects how input media is
                prepared in setup_media().
        """
        self.url = url
        self.model_name = model_name
        self.model_type = model_type

    def setup_media(self, input_media):
        """Load and preprocess the input media for the configured model type.

        Args:
            input_media (str): path (or, for image models, an http(s) URL)
                to the image / video / audio input.

        Returns:
            np.ndarray for image/video model types, or a dict of numpy
            arrays ({"input_signal", "input_signal_length"}) for "salm".

        Raises:
            RuntimeError: if self.model_type is not a supported type.
        """
        if self.model_type == "video-neva":
            # Decode every frame of the video into a (num_frames, H, W, C) array.
            vr = VideoReader(input_media)
            frames = [f.asnumpy() for f in vr]
            return np.array(frames)
        elif self.model_type in ("lita", "vita"):
            # These models cap the frame count; subsample evenly when needed.
            vr = VideoReader(input_media)
            frames = [f.asnumpy() for f in vr]
            subsample_len = self.frame_len(frames)
            sub_frames = self.get_subsampled_frames(frames, subsample_len)
            return np.array(sub_frames)
        elif self.model_type in ["neva", "vila", "mllama"]:
            # Single-image models: load from URL or local path, add a leading
            # frame axis so images share the video layout (1, H, W, C).
            # Note: "https" also startswith "http", so one tuple check suffices.
            if input_media.startswith(("http", "https")):
                response = requests.get(input_media, timeout=5)
                media = Image.open(BytesIO(response.content)).convert("RGB")
            else:
                media = Image.open(input_media).convert('RGB')
            return np.expand_dims(np.array(media), axis=0)
        elif self.model_type == "salm":
            # Audio model: Triton expects the waveform and its length as
            # separate batched inputs.
            waveform, sample_rate = sf.read(input_media, dtype=np.float32)
            input_signal = np.array([waveform], dtype=np.float32)
            input_signal_length = np.array([[len(waveform)]], dtype=np.int32)
            return {"input_signal": input_signal, "input_signal_length": input_signal_length}
        else:
            raise RuntimeError(f"Invalid model type {self.model_type}")

    def frame_len(self, frames):
        """Return how many frames to keep so the result stays under 256.

        If the video already has <= 256 frames, all of them are kept;
        otherwise an integer subsampling stride is chosen and the
        resulting (rounded) frame count is returned.
        """
        max_frames = 256
        if len(frames) <= max_frames:
            return len(frames)
        else:
            subsample = int(np.ceil(float(len(frames)) / max_frames))
            return int(np.round(float(len(frames)) / subsample))

    def get_subsampled_frames(self, frames, subsample_len):
        """Pick subsample_len frames evenly spaced across the whole video."""
        idx = np.round(np.linspace(0, len(frames) - 1, subsample_len)).astype(int)
        sub_frames = [frames[i] for i in idx]
        return sub_frames

    def query(
        self,
        input_text,
        input_media,
        batch_size=1,
        max_output_len=30,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        repetition_penalty=1.0,
        num_beams=1,
        init_timeout=60.0,
        lora_uids=None,
    ):
        """Send one multimodal inference request to Triton and return the output.

        Args:
            input_text (str): the text prompt.
            input_media (str): path/URL of the media input (see setup_media).
            batch_size, max_output_len, top_k, num_beams (int): sampling
                controls; sent only when not None.
            top_p, temperature, repetition_penalty (float): sampling
                controls; sent only when not None.
            init_timeout (float): seconds to wait for the Triton client.
            lora_uids (list[str] | None): optional LoRA adapter UIDs.

        Returns:
            np.ndarray of decoded strings when the model emits bytes,
            otherwise the raw "outputs" array from Triton.
        """

        prompts = str_list2numpy([input_text])
        inputs = {"input_text": prompts}

        media = self.setup_media(input_media)
        if isinstance(media, dict):
            # Audio ("salm"): setup_media already produced named inputs.
            inputs.update(media)
        else:
            # Add a batch axis and tile the media once per prompt.
            inputs["input_media"] = np.repeat(media[np.newaxis, :, :, :, :], prompts.shape[0], axis=0)

        if batch_size is not None:
            inputs["batch_size"] = np.full(prompts.shape, batch_size, dtype=np.int_)

        if max_output_len is not None:
            inputs["max_output_len"] = np.full(prompts.shape, max_output_len, dtype=np.int_)

        if top_k is not None:
            inputs["top_k"] = np.full(prompts.shape, top_k, dtype=np.int_)

        if top_p is not None:
            inputs["top_p"] = np.full(prompts.shape, top_p, dtype=np.single)

        if temperature is not None:
            inputs["temperature"] = np.full(prompts.shape, temperature, dtype=np.single)

        if repetition_penalty is not None:
            inputs["repetition_penalty"] = np.full(prompts.shape, repetition_penalty, dtype=np.single)

        if num_beams is not None:
            inputs["num_beams"] = np.full(prompts.shape, num_beams, dtype=np.int_)

        if lora_uids is not None:
            lora_uids = np.char.encode(lora_uids, "utf-8")
            inputs["lora_uids"] = np.full((prompts.shape[0], len(lora_uids)), lora_uids)

        with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client:
            result_dict = client.infer_batch(**inputs)
            output_type = client.model_config.outputs[0].dtype

        if output_type == np.bytes_:
            # Triton returns bytes; decode to python strings for the caller.
            sentences = np.char.decode(result_dict["outputs"].astype("bytes"), "utf-8")
            return sentences
        else:
            return result_dict["outputs"]
| |
|