# File: NeMo_Canary/tests/deploy/test_deploy_query.py
# Uploaded by Respair using huggingface_hub (commit b386992, verified).
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from pytriton.decorators import batch
from pytriton.model_config import Tensor
from nemo.deploy import DeployPyTriton, ITritonDeployable
from nemo.deploy.nlp import NemoQueryLLM
from nemo.deploy.utils import cast_output, str_ndarray2list
class MockModel(ITritonDeployable):
    """Minimal deployable used by the test: it replies with one fixed sentence."""

    @property
    def get_triton_input(self):
        """Return the tuple of input tensor specifications for this mock."""
        return (
            Tensor(name="prompts", shape=(-1,), dtype=bytes),
            Tensor(name="max_output_len", shape=(-1,), dtype=np.int_, optional=True),
            Tensor(name="output_context_logits", shape=(-1,), dtype=np.bool_, optional=False),
            Tensor(name="output_generation_logits", shape=(-1,), dtype=np.bool_, optional=False),
        )

    @property
    def get_triton_output(self):
        """Return the one-element tuple describing the ``outputs`` tensor."""
        return (Tensor(name="outputs", shape=(-1,), dtype=bytes),)

    @batch
    def triton_infer_fn(self, **inputs: np.ndarray):
        """Parse the request tensors and return a constant reply.

        The parsed request is built but never consulted; the returned
        dictionary always carries the same sentence under ``"outputs"``.
        """
        prompt_array = inputs.pop("prompts")
        parsed_request = {"input_texts": str_ndarray2list(prompt_array)}
        if "max_output_len" in inputs:
            max_len_array = inputs.pop("max_output_len")
            # Only the first element of the first row is read.
            parsed_request["max_output_len"] = max_len_array[0][0]

        return {"outputs": cast_output("I am good, how about you?", np.bytes_)}
def test_nemo_deploy_query():
    """Deploy ``MockModel`` with ``DeployPyTriton`` and query it over HTTP.

    The server is started on HTTP port 9002 / gRPC port 8001, a single prompt
    is sent through ``NemoQueryLLM``, and the reply is compared with the
    fixed sentence that ``MockModel`` returns.
    """
    model_name = "mock_model"
    model = MockModel()
    nm = DeployPyTriton(
        model=model,
        triton_model_name=model_name,
        max_batch_size=32,
        http_port=9002,
        grpc_port=8001,
        address="0.0.0.0",
        allow_grpc=True,
        allow_http=True,
        streaming=False,
    )
    nm.deploy()
    nm.run()

    # Always shut the server down, even when the query raises; otherwise the
    # ports stay bound and later tests in the same session cannot start.
    try:
        nq = NemoQueryLLM(url="localhost:9002", model_name=model_name)
        output_deployed = nq.query_llm(
            prompts=["Hey, how is it going?"],
            max_output_len=20,
        )
    finally:
        nm.stop()

    assert output_deployed is not None, "Output cannot be none."
    assert (
        output_deployed == "I am good, how about you?"
    ), f"Unexpected output from deployed model: {output_deployed!r}"