|
|
import numpy as np |
|
|
import pytest |
|
|
import tree |
|
|
|
|
|
from openpi_client import msgpack_numpy |
|
|
|
|
|
|
|
|
def _check(expected, actual): |
|
|
if isinstance(expected, np.ndarray): |
|
|
assert expected.shape == actual.shape |
|
|
assert expected.dtype == actual.dtype |
|
|
assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f") |
|
|
else: |
|
|
assert expected == actual |
|
|
|
|
|
|
|
|
# Round-trip fixtures: plain Python values, NumPy scalars, NumPy arrays of
# several dtypes and ranks (including non-finite floats), and containers that
# mix them. The order is significant only for the generated pytest ids.
_ROUND_TRIP_CASES = [
    # Plain Python scalars.
    1,
    1.0,
    "hello",
    # NumPy scalar objects.
    np.bool_(True),
    np.array([1, 2, 3])[0],
    np.str_("asdf"),
    # Plain Python containers.
    [1, 2, 3],
    {"key": "value"},
    {"key": [1, 2, 3]},
    # NumPy arrays: 0-d through 3-d, across int/float/str/bool dtypes.
    np.array(1.0),
    np.array([1, 2, 3], dtype=np.int32),
    np.array(["asdf", "qwer"]),
    np.array([True, False]),
    np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
    np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16),
    np.array([np.nan, np.inf, -np.inf]),
    # Containers holding arrays, including a nested dict.
    {"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}},
    [np.array([1, 2]), np.array([3, 4])],
    # Constant-filled arrays.
    np.zeros((3, 4, 5), dtype=np.float32),
    np.ones((2, 3), dtype=np.float64),
]


@pytest.mark.parametrize("data", _ROUND_TRIP_CASES)
def test_pack_unpack(data):
    """Packing then unpacking ``data`` must reproduce it leaf-for-leaf.

    The original and decoded structures are walked together, and each pair of
    leaves is handed to ``_check``.
    """
    round_tripped = msgpack_numpy.unpackb(msgpack_numpy.packb(data))
    tree.map_structure(_check, data, round_tripped)
|
|
|