| import numpy as np |
| import pytest |
| import tree |
|
|
| from openpi_client import msgpack_numpy |
|
|
|
|
| def _check(expected, actual): |
| if isinstance(expected, np.ndarray): |
| assert expected.shape == actual.shape |
| assert expected.dtype == actual.dtype |
| assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f") |
| else: |
| assert expected == actual |
|
|
|
|
@pytest.mark.parametrize(
    "data",
    [
        # Plain Python scalars.
        1,
        1.0,
        "hello",
        # NumPy scalar values (indexing an array yields a NumPy scalar).
        np.bool_(True),
        np.array([1, 2, 3])[0],
        np.str_("asdf"),
        # Built-in containers without arrays.
        [1, 2, 3],
        {"key": "value"},
        {"key": [1, 2, 3]},
        # Arrays of assorted ranks and dtypes, starting with a 0-d array.
        np.array(1.0),
        np.array([1, 2, 3], dtype=np.int32),
        np.array(["asdf", "qwer"]),
        np.array([True, False]),
        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
        np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16),
        # Non-finite floating-point values.
        np.array([np.nan, np.inf, -np.inf]),
        # Containers that hold arrays, including a nested dict.
        {"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}},
        [np.array([1, 2]), np.array([3, 4])],
        # Constant-filled multi-dimensional arrays.
        np.zeros((3, 4, 5), dtype=np.float32),
        np.ones((2, 3), dtype=np.float64),
    ],
)
def test_pack_unpack(data):
    """Round-trip ``data`` through ``packb``/``unpackb`` and compare the result.

    The unpacked structure is walked in lockstep with the original via
    ``tree.map_structure`` and each pair of leaves is verified by ``_check``.
    """
    round_tripped = msgpack_numpy.unpackb(msgpack_numpy.packb(data))
    tree.map_structure(_check, data, round_tripped)
|
|