File size: 2,585 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import pytest
from nemo.deploy.triton_deployable import ITritonDeployable


class MockTritonDeployable(ITritonDeployable):
    """Minimal concrete ``ITritonDeployable`` used as a test double in this module.

    It advertises a single float32 tensor named ``"input"`` with shape
    ``(1, 10)`` and a single float32 tensor named ``"output"`` with shape
    ``(1, 5)``.
    """

    def __init__(self):
        self.input_shape = (1, 10)
        self.output_shape = (1, 5)

    def get_triton_input(self):
        """Return the input spec: ``"input"`` mapped to its shape and dtype."""
        return {"input": {"shape": self.input_shape, "dtype": np.float32}}

    def get_triton_output(self):
        """Return the output spec: ``"output"`` mapped to its shape and dtype."""
        return {"output": {"shape": self.output_shape, "dtype": np.float32}}

    def triton_infer_fn(self, **inputs: np.ndarray):
        """Return an array of ``output_shape`` filled with the mean of ``inputs["input"]``.

        Args:
            **inputs: Keyword tensors; only ``input`` is read.

        Returns:
            A dict with one key, ``"output"``, holding a float32 array.

        Raises:
            KeyError: If no ``input`` keyword argument is supplied.
        """
        input_data = inputs["input"]
        # Build the result as float32 explicitly: np.ones() defaults to float64,
        # which would disagree with the dtype declared by get_triton_output().
        return {"output": np.full(self.output_shape, np.mean(input_data), dtype=np.float32)}


@pytest.fixture
def mock_deployable():
    """Build and return a new ``MockTritonDeployable`` instance."""
    deployable = MockTritonDeployable()
    return deployable


def test_get_triton_input(mock_deployable):
    """Check that the input spec exposes ``"input"`` with shape (1, 10) and dtype float32."""
    spec = mock_deployable.get_triton_input()

    assert "input" in spec

    # Look the entry up once and check both of its fields.
    tensor_spec = spec["input"]
    assert tensor_spec["shape"] == (1, 10)
    assert tensor_spec["dtype"] == np.float32


def test_get_triton_output(mock_deployable):
    """Check that the output spec exposes ``"output"`` with shape (1, 5) and dtype float32."""
    spec = mock_deployable.get_triton_output()

    assert "output" in spec

    # Look the entry up once and check both of its fields.
    tensor_spec = spec["output"]
    assert tensor_spec["shape"] == (1, 5)
    assert tensor_spec["dtype"] == np.float32


def test_triton_infer_fn(mock_deployable):
    """Test that triton_infer_fn fills the output with the mean of the input.

    The input is drawn from a seeded generator so the exact same data is used
    on every run and any failure can be reproduced.
    """
    # Create a deterministic float32 test input of shape (1, 10)
    rng = np.random.default_rng(seed=0)
    test_input = rng.random((1, 10), dtype=np.float32)
    input_mean = np.mean(test_input)

    # Run inference
    result = mock_deployable.triton_infer_fn(input=test_input)

    # Check output
    assert "output" in result
    assert result["output"].shape == (1, 5)
    assert np.allclose(result["output"], input_mean)


def test_abstract_class_instantiation():
    """Check that calling ``ITritonDeployable`` with no arguments raises ``TypeError``."""
    interface_cls = ITritonDeployable
    with pytest.raises(TypeError):
        interface_cls()