|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import sys |
|
|
|
|
|
from nemo.deploy.multimodal import NemoQueryMultimodal |
|
|
|
|
|
|
|
|
def get_args(argv=None) -> argparse.Namespace:
    """Parse command-line options for querying a Triton-served multimodal model.

    Args:
        argv: Sequence of argument strings to parse (e.g. ``sys.argv[1:]``).
            When ``None``, argparse falls back to ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. ``lora_task_uids`` is ``None``
        unless ``-lt/--lora_task_uids`` is given, in which case it is a list
        of strings.

    Raises:
        SystemExit: raised by argparse when a required option is missing or a
            value cannot be converted to its declared type.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Query Triton Multimodal server",
    )
    parser.add_argument("-u", "--url", default="0.0.0.0", type=str, help="url for the triton server")
    parser.add_argument("-mn", "--model_name", required=True, type=str, help="Name of the triton model")
    parser.add_argument("-mt", "--model_type", required=True, type=str, help="Type of the triton model")
    parser.add_argument("-int", "--input_text", required=True, type=str, help="Input text")
    parser.add_argument("-im", "--input_media", required=True, type=str, help="File path of input media")
    parser.add_argument("-bs", "--batch_size", default=1, type=int, help="Batch size")
    parser.add_argument("-mol", "--max_output_len", default=128, type=int, help="Max output token length")
    parser.add_argument("-tk", "--top_k", default=1, type=int, help="top_k")
    parser.add_argument("-tpp", "--top_p", default=0.0, type=float, help="top_p")
    parser.add_argument("-t", "--temperature", default=1.0, type=float, help="temperature")
    parser.add_argument("-rp", "--repetition_penalty", default=1.0, type=float, help="repetition_penalty")
    parser.add_argument("-nb", "--num_beams", default=1, type=int, help="num_beams")
    parser.add_argument("-it", "--init_timeout", default=60.0, type=float, help="init timeout for the triton server")
    parser.add_argument(
        "-lt",
        "--lora_task_uids",
        default=None,
        type=str,
        # One or more uids; kept as strings so "-1" (disable) passes through unchanged.
        nargs="+",
        help="The list of LoRA task uids; use -1 to disable the LoRA module",
    )

    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse CLI options, build a query client for the
    # Triton server at --url, issue a single query, and print what it returns.
    cli = get_args(sys.argv[1:])

    client = NemoQueryMultimodal(
        url=cli.url,
        model_name=cli.model_name,
        model_type=cli.model_type,
    )

    # Every CLI option below is forwarded unchanged as a keyword argument to
    # `query`; note the CLI's `lora_task_uids` maps onto the `lora_uids` kwarg.
    query_kwargs = dict(
        input_text=cli.input_text,
        input_media=cli.input_media,
        batch_size=cli.batch_size,
        max_output_len=cli.max_output_len,
        top_k=cli.top_k,
        top_p=cli.top_p,
        temperature=cli.temperature,
        repetition_penalty=cli.repetition_penalty,
        num_beams=cli.num_beams,
        init_timeout=cli.init_timeout,
        lora_uids=cli.lora_task_uids,
    )

    print(client.query(**query_kwargs))
|
|
|