|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import pytest |
|
|
from nemo.deploy.triton_deployable import ITritonDeployable |
|
|
|
|
|
|
|
|
class MockTritonDeployable(ITritonDeployable):
    """Minimal concrete ``ITritonDeployable`` used to exercise the interface in tests.

    The mock describes a single float32 tensor named ``"input"`` and a single
    float32 tensor named ``"output"``; inference fills the output with the
    mean of the input.
    """

    def __init__(self):
        """Set the fixed input and output tensor shapes reported by the mock."""
        self.input_shape = (1, 10)
        self.output_shape = (1, 5)

    def get_triton_input(self):
        """Return the input specification: name mapped to its shape and dtype."""
        return {"input": {"shape": self.input_shape, "dtype": np.float32}}

    def get_triton_output(self):
        """Return the output specification: name mapped to its shape and dtype."""
        return {"output": {"shape": self.output_shape, "dtype": np.float32}}

    def triton_infer_fn(self, **inputs: np.ndarray):
        """Return an ``"output"`` array filled with the mean of ``inputs["input"]``.

        Args:
            **inputs: Named input arrays; only ``"input"`` is read.

        Returns:
            A dict with one key, ``"output"``, holding an array of shape
            ``self.output_shape`` and dtype ``np.float32``.

        Raises:
            KeyError: If no ``"input"`` array is supplied.
        """
        input_data = inputs["input"]
        # Build the array directly as float32: ``np.ones(shape) * mean`` would
        # produce float64, contradicting the dtype declared by get_triton_output.
        return {"output": np.full(self.output_shape, np.mean(input_data), dtype=np.float32)}
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_deployable():
    """Build a ``MockTritonDeployable`` instance for tests that request this fixture."""
    deployable = MockTritonDeployable()
    return deployable
|
|
|
|
|
|
|
|
def test_get_triton_input(mock_deployable):
    """Check the input specification reported by the mock deployable."""
    spec = mock_deployable.get_triton_input()

    assert "input" in spec

    # Look the tensor entry up once, then verify each field of it.
    tensor_spec = spec["input"]
    assert tensor_spec["shape"] == (1, 10)
    assert tensor_spec["dtype"] == np.float32
|
|
|
|
|
|
|
|
def test_get_triton_output(mock_deployable):
    """Check the output specification reported by the mock deployable."""
    spec = mock_deployable.get_triton_output()

    assert "output" in spec

    # Look the tensor entry up once, then verify each field of it.
    tensor_spec = spec["output"]
    assert tensor_spec["shape"] == (1, 5)
    assert tensor_spec["dtype"] == np.float32
|
|
|
|
|
|
|
|
def test_triton_infer_fn(mock_deployable):
    """Check that inference returns an output array filled with the input mean."""
    # Random float32 sample matching the mock's (1, 10) input shape.
    sample = np.random.rand(1, 10).astype(np.float32)
    expected_value = np.mean(sample)

    outputs = mock_deployable.triton_infer_fn(input=sample)

    assert "output" in outputs

    prediction = outputs["output"]
    assert prediction.shape == (1, 5)
    # Every element of the output should equal the mean of the input.
    assert np.allclose(prediction, expected_value)
|
|
|
|
|
|
|
|
def test_abstract_class_instantiation():
    """Check that instantiating ITritonDeployable directly raises TypeError."""
    interface_cls = ITritonDeployable
    with pytest.raises(TypeError):
        interface_cls()
|
|
|