# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import sys
import time

from nemo.deploy.nlp import NemoQueryLLMPyTorch

LOGGER = logging.getLogger("NeMo")


def get_args(argv):
    """Parse command-line arguments for querying an in-framework NeMo model served by Triton."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Queries a Triton server running an in-framework NeMo model",
    )
    parser.add_argument("-u", "--url", default="0.0.0.0", type=str, help="URL of the Triton server")
    parser.add_argument("-mn", "--model_name", required=True, type=str, help="Name of the model deployed on Triton")
    prompt_group = parser.add_mutually_exclusive_group(required=True)
    prompt_group.add_argument("-p", "--prompt", required=False, type=str, help="Prompt text")
    prompt_group.add_argument("-pf", "--prompt_file", required=False, type=str, help="File to read the prompt from")
    parser.add_argument("-mol", "--max_output_len", default=128, type=int, help="Maximum number of output tokens")
    parser.add_argument("-tk", "--top_k", default=1, type=int, help="Top-k sampling parameter")
    parser.add_argument("-tpp", "--top_p", default=0.0, type=float, help="Top-p (nucleus) sampling parameter")
    parser.add_argument("-t", "--temperature", default=1.0, type=float, help="Sampling temperature")
    parser.add_argument("-it", "--init_timeout", default=60.0, type=float, help="Initialization timeout for the Triton server")
    parser.add_argument("-clp", "--compute_logprob", default=None, action='store_true', help="Also return log probabilities")

    args = parser.parse_args(argv)
    return args
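

# Example command-line invocation. This is a sketch: "query_script.py" is a placeholder
# for this file's actual name, and "my_model" is a hypothetical model name that must
# match a model already deployed on the Triton server at the given URL.
#
#   python query_script.py -u 0.0.0.0 -mn my_model -p "What is the capital of France?" -mol 32
#
# Exactly one of -p/--prompt or -pf/--prompt_file must be given. Because both belong to a
# required mutually exclusive group, argparse rejects a command line with both or neither.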


def query_llm(
    url,
    model_name,
    prompts,
    max_output_len=128,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    compute_logprob=None,
    init_timeout=60.0,
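):
    """Send prompts to an in-framework NeMo model served by Triton and return the response.

    The logged execution time covers both client creation and the query itself.
    """
    # perf_counter is a monotonic clock intended for measuring elapsed time.
    start_time = time.perf_counter()
    nemo_query = NemoQueryLLMPyTorch(url, model_name)
    result = nemo_query.query_llm(
        prompts=prompts,
        max_length=max_output_len,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        compute_logprob=compute_logprob,
        init_timeout=init_timeout,
    )
    end_time = time.perf_counter()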
    LOGGER.info(f"Query execution time: {end_time - start_time:.2f} seconds")
    return result
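

# Programmatic use of query_llm. This is a sketch that assumes a Triton server is reachable
# at "localhost" and serves a model deployed under the hypothetical name "my_model":
#
#   result = query_llm(url="localhost", model_name="my_model",
#                      prompts=["What is the capital of France?"], max_output_len=32)
#   print(result)
#
# `result` is whatever NemoQueryLLMPyTorch.query_llm returns. This script does not
# post-process it, so the CLI path simply prints the same object.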


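def query(argv):
    """Parse ``argv``, send the prompt to the Triton server, and print the response."""
    args = get_args(argv)

    # The whole file is read and sent as a single prompt.
    if args.prompt_file is not None:
        with open(args.prompt_file, "r", encoding="utf-8") as f:
            args.prompt = f.read()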

    outputs = query_llm(
        url=args.url,
        model_name=args.model_name,
        prompts=[args.prompt],
        max_output_len=args.max_output_len,
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
        compute_logprob=args.compute_logprob,
        init_timeout=args.init_timeout,
    )
    print(outputs)


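if __name__ == '__main__':
    # Without a handler configured, INFO messages (such as the query execution time
    # logged in query_llm) are dropped by default.
    logging.basicConfig(level=logging.INFO)
    query(sys.argv[1:])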