| |
| |
| |
| |
| @@ -257,6 +257,18 @@ def to_py_obj(obj): |
| """ |
| Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list. |
| """ |
| + if isinstance(obj, (int, float)): |
| + return obj |
| + elif isinstance(obj, (dict, UserDict)): |
| + return {k: to_py_obj(v) for k, v in obj.items()} |
| + elif isinstance(obj, (list, tuple)): |
| + try: |
| + arr = np.array(obj) |
| + if np.issubdtype(arr.dtype, np.integer) or np.issubdtype(arr.dtype, np.floating): |
| + return arr.tolist() |
| + except Exception: |
| + pass |
| + return [to_py_obj(o) for o in obj] |
| |
| framework_to_py_obj = { |
| "pt": lambda obj: obj.detach().cpu().tolist(), |
| @@ -265,11 +277,6 @@ def to_py_obj(obj): |
| "np": lambda obj: obj.tolist(), |
| } |
| |
| - if isinstance(obj, (dict, UserDict)): |
| - return {k: to_py_obj(v) for k, v in obj.items()} |
| - elif isinstance(obj, (list, tuple)): |
| - return [to_py_obj(o) for o in obj] |
| - |
| # This gives us a smart order to test the frameworks with the corresponding tests. |
| framework_to_test_func = _get_frameworks_and_test_func(obj) |
| for framework, test_func in framework_to_test_func.items(): |
| |
| |
| |
| |
| @@ -28,6 +28,7 @@ |
| is_torch_available, |
| reshape, |
| squeeze, |
| + to_py_obj, |
| transpose, |
| ) |
| |
| @@ -201,6 +202,77 @@ def test_expand_dims_flax(self): |
| t = jnp.array(x) |
| self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1)))) |
| |
| + def test_to_py_obj_native(self): |
| + self.assertTrue(to_py_obj(1) == 1) |
| + self.assertTrue(to_py_obj([1, 2, 3]) == [1, 2, 3]) |
| + self.assertTrue(to_py_obj([((1.0, 1.1), 1.2), (2, 3)]) == [[[1.0, 1.1], 1.2], [2, 3]]) |
| + |
| + def test_to_py_obj_numpy(self): |
| + x1 = [[1, 2, 3], [4, 5, 6]] |
| + t1 = np.array(x1) |
| + self.assertTrue(to_py_obj(t1) == x1) |
| + |
| + x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] |
| + t2 = np.array(x2) |
| + self.assertTrue(to_py_obj(t2) == x2) |
| + |
| + self.assertTrue(to_py_obj([t1, t2]) == [x1, x2]) |
| + |
| + @require_torch |
| + def test_to_py_obj_torch(self): |
| + x1 = [[1, 2, 3], [4, 5, 6]] |
| + t1 = torch.tensor(x1) |
| + self.assertTrue(to_py_obj(t1) == x1) |
| + |
| + x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] |
| + t2 = torch.tensor(x2) |
| + self.assertTrue(to_py_obj(t2) == x2) |
| + |
| + self.assertTrue(to_py_obj([t1, t2]) == [x1, x2]) |
| + |
| + @require_tf |
| + def test_to_py_obj_tf(self): |
| + x1 = [[1, 2, 3], [4, 5, 6]] |
| + t1 = tf.constant(x1) |
| + self.assertTrue(to_py_obj(t1) == x1) |
| + |
| + x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] |
| + t2 = tf.constant(x2) |
| + self.assertTrue(to_py_obj(t2) == x2) |
| + |
| + self.assertTrue(to_py_obj([t1, t2]) == [x1, x2]) |
| + |
| + @require_flax |
| + def test_to_py_obj_flax(self): |
| + x1 = [[1, 2, 3], [4, 5, 6]] |
| + t1 = jnp.array(x1) |
| + self.assertTrue(to_py_obj(t1) == x1) |
| + |
| + x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] |
| + t2 = jnp.array(x2) |
| + self.assertTrue(to_py_obj(t2) == x2) |
| + |
| + self.assertTrue(to_py_obj([t1, t2]) == [x1, x2]) |
| + |
| + @require_torch |
| + @require_tf |
| + @require_flax |
| + def test_to_py_obj_mixed(self): |
| + x1 = [[1], [2]] |
| + t1 = np.array(x1) |
| + |
| + x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] |
| + t2 = torch.tensor(x2) |
| + |
| + x3 = [1, 2, 3] |
| + t3 = tf.constant(x3) |
| + |
| + x4 = [[[1.0, 2.0]]] |
| + t4 = jnp.array(x4) |
| + |
| + mixed = [(t1, t2), (t3, t4)] |
| + self.assertTrue(to_py_obj(mixed) == [[x1, x2], [x3, x4]]) |
| + |
| |
| class ValidationDecoratorTester(unittest.TestCase): |
| def test_cases_no_warning(self): |
|
|