# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
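
# pytriton is an optional dependency; fall back gracefully so this module can still
# be imported when the Triton client libraries are not installed.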
use_pytriton = True
try:
from pytriton.model_config import ModelConfig
from pytriton.triton import Triton, TritonConfig
except Exception:
use_pytriton = False

from nemo.deploy.deploy_base import DeployBase


class DeployPyTriton(DeployBase):
"""
    Deploys any model that implements the ITritonDeployable interface in nemo.deploy to Triton Inference Server.
Example:
from nemo.deploy import DeployPyTriton, NemoQueryLLM
from nemo.export.tensorrt_llm import TensorRTLLM
trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files")
trt_llm_exporter.export(
nemo_checkpoint_path="/path/for/nemo/checkpoint",
model_type="llama",
tensor_parallelism_size=1,
)
nm = DeployPyTriton(model=trt_llm_exporter, triton_model_name="model_name", http_port=8000)
nm.deploy()
nm.run()
nq = NemoQueryLLM(url="localhost", model_name="model_name")
prompts = ["hello, testing GPT inference", "another GPT inference test?"]
output = nq.query_llm(prompts=prompts, max_output_len=100)
print("prompts: ", prompts)
print("")
print("output: ", output)
print("")
prompts = ["Give me some info about Paris", "Do you think Londan is a good city to visit?", "What do you think about Rome?"]
output = nq.query_llm(prompts=prompts, max_output_len=250)
print("prompts: ", prompts)
print("")
print("output: ", output)
print("")
"""

    def __init__(
self,
triton_model_name: str,
triton_model_version: int = 1,
checkpoint_path: str = None,
model=None,
max_batch_size: int = 128,
http_port: int = 8000,
grpc_port: int = 8001,
address="0.0.0.0",
allow_grpc=True,
allow_http=True,
streaming=False,
pytriton_log_verbose=0,
):
"""
A nemo checkpoint or model is expected for serving on Triton Inference Server.
Args:
            triton_model_name (str): Name for the service
            triton_model_version (int): Version for the service
            checkpoint_path (str): path of the nemo checkpoint file
            model (ITritonDeployable): a model that implements the ITritonDeployable interface from nemo.deploy
            max_batch_size (int): max batch size
            http_port (int): HTTP port for the Triton server
            grpc_port (int): gRPC port for the Triton server
            address (str): address for the Triton server to bind
            allow_grpc (bool): allow gRPC requests at the server
            allow_http (bool): allow HTTP requests at the server
            streaming (bool): enable streaming (decoupled) responses
            pytriton_log_verbose (int): pytriton log verbosity level
        """
super().__init__(
triton_model_name=triton_model_name,
triton_model_version=triton_model_version,
checkpoint_path=checkpoint_path,
model=model,
max_batch_size=max_batch_size,
http_port=http_port,
grpc_port=grpc_port,
address=address,
allow_grpc=allow_grpc,
allow_http=allow_http,
streaming=streaming,
pytriton_log_verbose=pytriton_log_verbose,
)

    def deploy(self):
        """
        Deploys the model to Triton Inference Server by binding it to a Triton instance.
        """
self._init_nemo_model()
try:
if self.streaming:
# TODO: can't set allow_http=True due to a bug in pytriton, will fix in latest pytriton
triton_config = TritonConfig(
log_verbose=self.pytriton_log_verbose,
allow_grpc=self.allow_grpc,
allow_http=self.allow_http,
grpc_address=self.address,
)
self.triton = Triton(config=triton_config)
self.triton.bind(
model_name=self.triton_model_name,
model_version=self.triton_model_version,
infer_func=self.model.triton_infer_fn_streaming,
inputs=self.model.get_triton_input,
outputs=self.model.get_triton_output,
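                    # decoupled=True enables Triton's decoupled (streaming) mode, so the
                    # infer function can return multiple responses per request over gRPC.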
config=ModelConfig(decoupled=True),
)
else:
triton_config = TritonConfig(
http_address=self.address,
http_port=self.http_port,
grpc_address=self.address,
grpc_port=self.grpc_port,
allow_grpc=self.allow_grpc,
allow_http=self.allow_http,
)
self.triton = Triton(config=triton_config)
self.triton.bind(
model_name=self.triton_model_name,
model_version=self.triton_model_version,
infer_func=self.model.triton_infer_fn,
inputs=self.model.get_triton_input,
outputs=self.model.get_triton_output,
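                    # Standard request/response mode; max_batch_size caps how many requests
                    # Triton may batch together per call to the infer function.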
config=ModelConfig(max_batch_size=self.max_batch_size),
)
except Exception as e:
self.triton = None
print(e)

    def serve(self):
        """
        Starts serving the model and blocks while waiting for requests.
        """
if self.triton is None:
raise Exception("deploy should be called first.")
try:
self.triton.serve()
except Exception as e:
self.triton = None
print(e)

    def run(self):
"""
Starts serving the model asynchronously.
"""
if self.triton is None:
raise Exception("deploy should be called first.")
self.triton.run()

    def stop(self):
"""
Stops serving the model.
"""
if self.triton is None:
raise Exception("deploy should be called first.")
self.triton.stop()
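

if __name__ == "__main__":
    # Minimal lifecycle sketch mirroring the class docstring above. It assumes a
    # TensorRT-LLM engine has already been exported to the given directory; the
    # model directory and model name are placeholders, not defaults shipped with NeMo.
    from nemo.export.tensorrt_llm import TensorRTLLM

    exporter = TensorRTLLM(model_dir="/path/for/model/files")
    nm = DeployPyTriton(model=exporter, triton_model_name="model_name", http_port=8000)
    nm.deploy()  # bind the model to a local Triton instance
    nm.serve()   # block and answer requests; use nm.run() / nm.stop() for non-blocking serving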