|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import tarfile |
|
|
import tempfile |
|
|
import typing |
|
|
|
|
|
import numpy as np |
|
|
import pytest |
|
|
import torch |
|
|
from PIL import Image |
|
|
from pytriton.model_config import Tensor |
|
|
|
|
|
from nemo.deploy.utils import ( |
|
|
NEMO1, |
|
|
NEMO2, |
|
|
broadcast_list, |
|
|
cast_output, |
|
|
ndarray2img, |
|
|
nemo_checkpoint_version, |
|
|
str_list2numpy, |
|
|
str_ndarray2list, |
|
|
typedict2tensor, |
|
|
) |
|
|
|
|
|
|
|
|
class TestTypedict2Tensor:
    """Tests for ``typedict2tensor``: a type-hinted class -> a collection of pytriton ``Tensor`` specs."""

    class SampleTypedict:
        # One annotated field per scalar type exercised below, plus a typing.List[...] of each.
        int_field: int
        float_field: float
        bool_field: bool
        str_field: str
        int_list: typing.List[int]
        float_list: typing.List[float]
        bool_list: typing.List[bool]
        str_list: typing.List[str]

    # dtype each SampleTypedict field is expected to map to (``str`` maps to ``bytes``).
    EXPECTED_DTYPES = {
        "int_field": np.int32,
        "float_field": np.float32,
        "bool_field": np.bool_,
        "str_field": bytes,
        "int_list": np.int32,
        "float_list": np.float32,
        "bool_list": np.bool_,
        "str_list": bytes,
    }

    @staticmethod
    def _tensors_by_name(tensors):
        """Index ``tensors`` by their ``name`` attribute.

        A dict lookup makes a missing tensor fail with a ``KeyError`` naming the
        field, instead of the bare ``StopIteration`` raised by ``next(...)`` over
        a generator that finds no match.
        """
        return {t.name: t for t in tensors}

    def _assert_fields(self, tensors, names):
        """Assert the expected dtype and a ``(1,)`` shape for each field in ``names``."""
        by_name = self._tensors_by_name(tensors)
        for name in names:
            tensor = by_name[name]
            assert tensor.dtype == self.EXPECTED_DTYPES[name], name
            # Both scalar fields and one-level lists are expected to produce shape (1,).
            assert tensor.shape == (1,), name

    def test_typedict2tensor_basic(self):
        tensors = typedict2tensor(self.SampleTypedict)
        assert len(tensors) == 8
        assert all(isinstance(t, Tensor) for t in tensors)
        # Every annotated field must yield exactly one tensor carrying the field's name.
        assert {t.name for t in tensors} == set(self.EXPECTED_DTYPES)
        self._assert_fields(tensors, ("int_field", "float_field", "bool_field", "str_field"))

    def test_typedict2tensor_with_overwrite(self):
        overwrite_kwargs = {"optional": True}
        tensors = typedict2tensor(self.SampleTypedict, overwrite_kwargs=overwrite_kwargs)
        # Guard against a vacuous pass: all() over an empty result is True.
        assert len(tensors) == 8
        assert all(t.optional for t in tensors)

    def test_typedict2tensor_list_types(self):
        tensors = typedict2tensor(self.SampleTypedict)
        self._assert_fields(tensors, ("int_list", "float_list", "bool_list", "str_list"))
|
|
|
|
|
|
|
|
class TestNemoCheckpointVersion:
    """Tests for ``nemo_checkpoint_version`` on directory and tar checkpoints.

    The tests expect a checkpoint containing both ``context`` and ``weights``
    directories to be reported as NEMO2, and one without them as NEMO1.
    """

    @staticmethod
    def _write_tar(tar_path, dirnames):
        """Create an uncompressed tar at ``tar_path`` with one empty directory entry per name."""
        with tarfile.open(tar_path, "w") as archive:
            for dirname in dirnames:
                entry = tarfile.TarInfo(dirname)
                entry.type = tarfile.DIRTYPE
                archive.addfile(entry)

    def test_nemo2_checkpoint_dir(self):
        with tempfile.TemporaryDirectory() as root:
            for sub in ("context", "weights"):
                os.makedirs(os.path.join(root, sub))
            assert nemo_checkpoint_version(root) == NEMO2

    def test_nemo1_checkpoint_dir(self):
        # A freshly created directory has neither marker subdirectory.
        with tempfile.TemporaryDirectory() as root:
            assert nemo_checkpoint_version(root) == NEMO1

    def test_nemo2_checkpoint_tar(self):
        with tempfile.TemporaryDirectory() as root:
            archive_path = os.path.join(root, "checkpoint.tar")
            # The archive is fully written and closed before it is inspected.
            self._write_tar(archive_path, ("context", "weights"))
            assert nemo_checkpoint_version(archive_path) == NEMO2

    def test_nemo1_checkpoint_tar(self):
        with tempfile.TemporaryDirectory() as root:
            archive_path = os.path.join(root, "checkpoint.tar")
            # A valid but empty archive: no members at all.
            self._write_tar(archive_path, ())
            assert nemo_checkpoint_version(archive_path) == NEMO1
|
|
|
|
|
|
|
|
class TestStringConversions:
    """Tests for ``str_list2numpy`` / ``str_ndarray2list`` (str list <-> byte-string column array)."""

    def test_str_list2numpy(self):
        encoded = str_list2numpy(["hello", "world", "test"])
        assert isinstance(encoded, np.ndarray)
        # One row per input string, with a trailing singleton axis.
        assert encoded.shape == (3, 1)
        for item in encoded.ravel():
            assert isinstance(item, bytes)

    def test_str_ndarray2list(self):
        # Build the (3, 1) byte-string column by adding a trailing axis.
        column = np.array([b"hello", b"world", b"test"])[:, np.newaxis]
        decoded = str_ndarray2list(column)
        assert isinstance(decoded, list)
        assert decoded == ["hello", "world", "test"]

    def test_str_conversion_roundtrip(self):
        words = ["hello", "world", "test"]
        assert str_ndarray2list(str_list2numpy(words)) == words
|
|
|
|
|
|
|
|
class TestImageConversions:
    """Tests for ``ndarray2img``: a batch of HxWxC uint8 arrays -> a list of PIL images."""

    def test_ndarray2img(self):
        batch, height, width = 2, 64, 48
        # A seeded generator keeps the test deterministic. The upper bound is exclusive,
        # so 256 is required to cover the full uint8 range (255 would never be produced).
        rng = np.random.default_rng(0)
        img_array = rng.integers(0, 256, size=(batch, height, width, 3), dtype=np.uint8)

        result = ndarray2img(img_array)

        assert isinstance(result, list)
        assert len(result) == batch
        assert all(isinstance(img, Image.Image) for img in result)
        # PIL reports size as (width, height); a non-square input catches swapped axes.
        assert all(img.size == (width, height) for img in result)
        # Each image must carry exactly the pixels of its source slice.
        for img, expected in zip(result, img_array):
            np.testing.assert_array_equal(np.asarray(img), expected)
|
|
|
|
|
|
|
|
class TestCastOutput:
    """Tests for ``cast_output``.

    These tests expect 1-D inputs (torch tensors, numpy arrays, lists) to come
    back as an ``(N, 1)`` numpy array of the requested dtype.
    """

    def test_cast_tensor(self):
        input_tensor = torch.tensor([1, 2, 3])
        result = cast_output(input_tensor, np.int32)
        assert isinstance(result, np.ndarray)
        assert result.dtype == np.int32
        assert result.shape == (3, 1)
        # Values must survive the tensor -> numpy -> int32 conversion unchanged.
        assert result.ravel().tolist() == [1, 2, 3]

    def test_cast_numpy(self):
        input_array = np.array([1, 2, 3])
        result = cast_output(input_array, np.float32)
        assert isinstance(result, np.ndarray)
        assert result.dtype == np.float32
        assert result.shape == (3, 1)
        np.testing.assert_array_equal(result, [[1.0], [2.0], [3.0]])

    def test_cast_string(self):
        input_list = ["hello", "world"]
        result = cast_output(input_list, bytes)
        assert isinstance(result, np.ndarray)
        assert result.shape == (2, 1)

    def test_cast_1d_to_2d(self):
        input_array = np.array([1, 2, 3])
        result = cast_output(input_array, np.int32)
        assert result.ndim == 2
        assert result.shape == (3, 1)
        # The new axis is trailing, so element i of the input lands in row i.
        assert result[:, 0].tolist() == [1, 2, 3]
|
|
|
|
|
|
|
|
class TestBroadcastList:
    """Tests for ``broadcast_list`` with the ``torch.distributed`` calls it uses monkeypatched."""

    def test_broadcast_list_no_distributed(self, monkeypatch):
        # Force the "not initialized" state so the outcome does not depend on whether
        # another test or the environment has already set up a process group.
        monkeypatch.setattr(torch.distributed, "is_initialized", lambda: False)
        with pytest.raises(RuntimeError, match="Distributed environment is not initialized"):
            broadcast_list(["test"])

    def test_broadcast_list_distributed(self, monkeypatch):
        # Pretend to be rank 0 in an initialized process group.
        monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
        monkeypatch.setattr(torch.distributed, "get_rank", lambda: 0)

        broadcast_srcs = []

        def mock_broadcast_object_list(object_list, src, group=None):
            # Record the call: rank 0 already holds the data, so without this the
            # test would pass even if no broadcast happened.
            broadcast_srcs.append(src)
            if src == 0:
                object_list[0] = ["test"]

        monkeypatch.setattr(torch.distributed, "broadcast_object_list", mock_broadcast_object_list)

        result = broadcast_list(["test"])
        assert result == ["test"]
        assert broadcast_srcs == [0]

    def test_broadcast_list_non_source_rank(self, monkeypatch):
        # On a rank other than the source, the returned value must be whatever the
        # broadcast delivered, not the caller's local argument.
        monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
        monkeypatch.setattr(torch.distributed, "get_rank", lambda: 1)

        def mock_broadcast_object_list(object_list, src, group=None):
            # Simulate receiving the source rank's payload.
            object_list[0] = ["from-source"]

        monkeypatch.setattr(torch.distributed, "broadcast_object_list", mock_broadcast_object_list)

        assert broadcast_list(["local"]) == ["from-source"]
|
|
|