File size: 8,798 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import sys
from pathlib import Path

from nemo.deploy import DeployPyTriton

LOGGER = logging.getLogger("NeMo")

multimodal_supported = True
try:
    from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter
except Exception as e:
    LOGGER.warning(f"Cannot import the TensorRTMMExporter exporter, it will not be available. {type(e).__name__}: {e}")
    multimodal_supported = False


def get_args(argv):
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=f"Deploy nemo models to Triton",
    )
    # default modality is vision, can be changed to audio
    parser.add_argument(
        "-mod",
        "--modality",
        type=str,
        required=False,
        default="vision",
        choices=["vision", "audio"],
        help="Modality of the model",
    )
    parser.add_argument("-vc", "--visual_checkpoint", type=str, help="Source .nemo file for visual model")
    parser.add_argument(
        "-lc",
        "--llm_checkpoint",
        type=str,
        required=False,
        help="Source .nemo file for llm",
    )
    parser.add_argument(
        "-mt",
        "--model_type",
        type=str,
        required=True,
        choices=["neva", "video-neva", "lita", "vila", "vita", "salm", "mllama"],
        help="Type of the model that is supported.",
    )
    parser.add_argument(
        "-lmt",
        "--llm_model_type",
        type=str,
        required=True,
        choices=["gptnext", "gpt", "llama", "falcon", "starcoder", "mixtral", "gemma", "mllama"],
        help="Type of LLM. gptnext, gpt, llama, falcon, and starcoder are only supported."
        " gptnext and gpt are the same and keeping it for backward compatibility",
    )
    parser.add_argument("-tmn", "--triton_model_name", required=True, type=str, help="Name for the service")
    parser.add_argument("-tmv", "--triton_model_version", default=1, type=int, help="Version for the service")
    parser.add_argument(
        "-trp", "--triton_port", default=8000, type=int, help="Port for the Triton server to listen for requests"
    )
    parser.add_argument(
        "-tha", "--triton_http_address", default="0.0.0.0", type=str, help="HTTP address for the Triton server"
    )
    parser.add_argument(
        "-tmr", "--triton_model_repository", default=None, type=str, help="Folder for the trt-llm conversion"
    )
    parser.add_argument("-ng", "--num_gpus", default=1, type=int, help="Number of GPUs for the deployment")
    parser.add_argument(
        "-dt",
        "--dtype",
        choices=["bfloat16", "float16"],
        default="bfloat16",
        type=str,
        help="dtype of the model on TensorRT",
    )
    parser.add_argument("-mil", "--max_input_len", default=4096, type=int, help="Max input length of the model")
    parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model")
    parser.add_argument("-mbs", "--max_batch_size", default=1, type=int, help="Max batch size of the llm model")
    parser.add_argument("-mml", "--max_multimodal_len", default=3072, type=int, help="Max length of multimodal input")
    parser.add_argument(
        "-vmb",
        "--vision_max_batch_size",
        default=1,
        type=int,
        help="Max batch size of the visual inputs, for lita/vita model with video inference, this should be set to 256",
    )
    parser.add_argument(
        '--use_lora_plugin',
        nargs='?',
        const=None,
        choices=['float16', 'float32', 'bfloat16'],
        help="Activates the lora plugin which enables embedding sharing.",
    )
    parser.add_argument(
        '--lora_target_modules',
        nargs='+',
        default=None,
        choices=[
            "attn_qkv",
            "attn_q",
            "attn_k",
            "attn_v",
            "attn_dense",
            "mlp_h_to_4h",
            "mlp_gate",
            "mlp_4h_to_h",
        ],
        help="Add lora in which modules. Only be activated when use_lora_plugin is enabled.",
    )
    parser.add_argument(
        '--max_lora_rank',
        type=int,
        default=64,
        help='maximum lora rank for different lora modules. '
        'It is used to compute the workspace size of lora plugin.',
    )
    parser.add_argument("--lora_checkpoint_path", default=None, type=str, help="The checkpoint path of LoRA weights")
    args = parser.parse_args(argv)
    return args


def get_trt_deployable(args):
    if args.triton_model_repository is None:
        trt_path = "/tmp/trt_model_dir/"
        LOGGER.info(
            "/tmp/trt_model_dir/ path will be used as the TensorRT folder. "
            "Please set the --triton_model_repository parameter if you'd like to use a path that already "
            "includes the TensorRT model files."
        )
        Path(trt_path).mkdir(parents=True, exist_ok=True)
    else:
        trt_path = args.triton_model_repository

    if args.visual_checkpoint is None and args.triton_model_repository is None:
        raise ValueError(
            "The provided model repository is not a valid TensorRT model "
            "directory. Please provide a --visual_checkpoint."
        )

    if args.visual_checkpoint is None and not os.path.isdir(args.triton_model_repository):
        raise ValueError(
            "The provided model repository is not a valid TensorRT model "
            "directory. Please provide a --visual_checkpoint."
        )

    if args.visual_checkpoint is not None and args.model_type is None:
        raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.")

    exporter = TensorRTMMExporter(
        model_dir=trt_path,
        load_model=(args.visual_checkpoint is None),
        modality=args.modality,
    )

    if args.visual_checkpoint is not None:
        try:
            LOGGER.info("Export operation will be started to export the nemo checkpoint to TensorRT.")
            exporter.export(
                visual_checkpoint_path=args.visual_checkpoint,
                llm_checkpoint_path=args.llm_checkpoint,
                model_type=args.model_type,
                llm_model_type=args.llm_model_type,
                tensor_parallel_size=args.num_gpus,
                max_input_len=args.max_input_len,
                max_output_len=args.max_output_len,
                vision_max_batch_size=args.vision_max_batch_size,
                max_batch_size=args.max_batch_size,
                max_multimodal_len=args.max_multimodal_len,
                dtype=args.dtype,
                use_lora_plugin=args.use_lora_plugin,
                lora_target_modules=args.lora_target_modules,
                max_lora_rank=args.max_lora_rank,
                lora_checkpoint_path=args.lora_checkpoint_path,
            )
        except Exception as error:
            raise RuntimeError("An error has occurred during the model export. Error message: " + str(error))

    return exporter


def nemo_deploy(argv):
    args = get_args(argv)

    loglevel = logging.INFO

    LOGGER.setLevel(loglevel)
    LOGGER.info("Logging level set to {}".format(loglevel))
    LOGGER.info(args)

    triton_deployable = get_trt_deployable(args)

    try:
        nm = DeployPyTriton(
            model=triton_deployable,
            triton_model_name=args.triton_model_name,
            triton_model_version=args.triton_model_version,
            max_batch_size=args.max_batch_size,
            http_port=args.triton_port,
            address=args.triton_http_address,
        )

        LOGGER.info("Triton deploy function will be called.")
        nm.deploy()
    except Exception as error:
        LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error))
        return

    try:
        LOGGER.info("Model serving on Triton is will be started.")
        nm.serve()
    except Exception as error:
        LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error))
        return

    LOGGER.info("Model serving will be stopped.")
    nm.stop()


if __name__ == '__main__':
    nemo_deploy(sys.argv[1:])