File size: 3,025 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import sys

from nemo.deploy.multimodal import NemoQueryMultimodal


def get_args(argv=None):
    """Parse the command-line options used to query a Triton multimodal model.

    Args:
        argv (list[str] | None): Argument strings to parse, without the
            program name. When None, argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options. ``lora_task_uids`` is a list
        of strings when the flag is given and None otherwise.

    Note:
        argparse prints a usage message and raises SystemExit when a required
        option is missing or a value cannot be converted to its declared type.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Query Triton Multimodal server",
    )

    # Server and model selection.
    parser.add_argument("-u", "--url", default="0.0.0.0", type=str, help="url for the triton server")
    parser.add_argument("-mn", "--model_name", required=True, type=str, help="Name of the triton model")
    parser.add_argument("-mt", "--model_type", required=True, type=str, help="Type of the triton model")

    # Request payload.
    parser.add_argument("-int", "--input_text", required=True, type=str, help="Input text")
    parser.add_argument("-im", "--input_media", required=True, type=str, help="File path of input media")
    parser.add_argument("-bs", "--batch_size", default=1, type=int, help="Batch size")

    # Generation / sampling controls.
    parser.add_argument("-mol", "--max_output_len", default=128, type=int, help="Max output token length")
    parser.add_argument("-tk", "--top_k", default=1, type=int, help="top_k")
    parser.add_argument("-tpp", "--top_p", default=0.0, type=float, help="top_p")
    parser.add_argument("-t", "--temperature", default=1.0, type=float, help="temperature")
    parser.add_argument("-rp", "--repetition_penalty", default=1.0, type=float, help="repetition_penalty")
    parser.add_argument("-nb", "--num_beams", default=1, type=int, help="num_beams")

    # Connection behaviour and optional LoRA selection.
    parser.add_argument("-it", "--init_timeout", default=60.0, type=float, help="init timeout for the triton server")
    parser.add_argument(
        "-lt",
        "--lora_task_uids",
        default=None,
        type=str,
        nargs="+",
        help="The list of LoRA task uids; use -1 to disable the LoRA module",
    )

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: parse the CLI options, send one query to the
    # Triton-hosted multimodal model, and print whatever the client returns.
    args = get_args(sys.argv[1:])

    nq = NemoQueryMultimodal(
        url=args.url,
        model_name=args.model_name,
        model_type=args.model_type,
    )

    # Collect the keyword arguments for the query up front. The CLI option
    # --lora_task_uids is forwarded under the client's `lora_uids` keyword.
    query_kwargs = dict(
        input_text=args.input_text,
        input_media=args.input_media,
        batch_size=args.batch_size,
        max_output_len=args.max_output_len,
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        num_beams=args.num_beams,
        init_timeout=args.init_timeout,
        lora_uids=args.lora_task_uids,
    )

    output = nq.query(**query_kwargs)
    print(output)