# NeMo_Canary/tests/deploy/test_triton_deployable.py
# (uploaded by Respair via huggingface_hub, commit b386992)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
from nemo.deploy.triton_deployable import ITritonDeployable
class MockTritonDeployable(ITritonDeployable):
    """Minimal concrete ``ITritonDeployable`` used to exercise the interface in tests.

    It advertises one float32 input named ``"input"`` of shape ``(1, 10)`` and
    one float32 output named ``"output"`` of shape ``(1, 5)``. Inference fills
    the output with the mean of the input.
    """

    def __init__(self):
        self.input_shape = (1, 10)
        self.output_shape = (1, 5)

    def get_triton_input(self):
        """Return the input spec as ``{name: {"shape": ..., "dtype": ...}}``."""
        return {"input": {"shape": self.input_shape, "dtype": np.float32}}

    def get_triton_output(self):
        """Return the output spec as ``{name: {"shape": ..., "dtype": ...}}``."""
        return {"output": {"shape": self.output_shape, "dtype": np.float32}}

    def triton_infer_fn(self, **inputs: np.ndarray):
        """Return an ``output_shape`` array filled with the mean of ``inputs["input"]``.

        The array is built directly as float32 so it matches the dtype advertised
        by ``get_triton_output``. Computing ``np.ones(...) * mean`` would silently
        upcast the result to float64.

        Raises:
            KeyError: if no ``input`` keyword argument is supplied.
        """
        input_data = inputs["input"]
        return {"output": np.full(self.output_shape, np.mean(input_data), dtype=np.float32)}
@pytest.fixture
def mock_deployable():
    """Hand each requesting test its own freshly constructed ``MockTritonDeployable``."""
    deployable = MockTritonDeployable()
    return deployable
def test_get_triton_input(mock_deployable):
"""Test that get_triton_input returns the correct input specification."""
input_spec = mock_deployable.get_triton_input()
assert "input" in input_spec
assert input_spec["input"]["shape"] == (1, 10)
assert input_spec["input"]["dtype"] == np.float32
def test_get_triton_output(mock_deployable):
"""Test that get_triton_output returns the correct output specification."""
output_spec = mock_deployable.get_triton_output()
assert "output" in output_spec
assert output_spec["output"]["shape"] == (1, 5)
assert output_spec["output"]["dtype"] == np.float32
def test_triton_infer_fn(mock_deployable):
"""Test that triton_infer_fn processes inputs correctly."""
# Create test input
test_input = np.random.rand(1, 10).astype(np.float32)
input_mean = np.mean(test_input)
# Run inference
result = mock_deployable.triton_infer_fn(input=test_input)
# Check output
assert "output" in result
assert result["output"].shape == (1, 5)
assert np.allclose(result["output"], input_mean)
def test_abstract_class_instantiation():
    """Instantiating the ITritonDeployable interface directly must raise TypeError."""
    interface_cls = ITritonDeployable
    with pytest.raises(TypeError):
        interface_cls()