File size: 4,065 Bytes
dfefe0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index 4af6f2d5b954..f17410a20daf 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -257,6 +257,18 @@ def to_py_obj(obj):
     """
     Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
     """
+    if isinstance(obj, (int, float)):
+        return obj
+    elif isinstance(obj, (dict, UserDict)):
+        return {k: to_py_obj(v) for k, v in obj.items()}
+    elif isinstance(obj, (list, tuple)):
+        try:
+            arr = np.array(obj)
+            if np.issubdtype(arr.dtype, np.integer) or np.issubdtype(arr.dtype, np.floating):
+                return arr.tolist()
+        except Exception:
+            pass
+        return [to_py_obj(o) for o in obj]
 
     framework_to_py_obj = {
         "pt": lambda obj: obj.detach().cpu().tolist(),
@@ -265,11 +277,6 @@ def to_py_obj(obj):
         "np": lambda obj: obj.tolist(),
     }
 
-    if isinstance(obj, (dict, UserDict)):
-        return {k: to_py_obj(v) for k, v in obj.items()}
-    elif isinstance(obj, (list, tuple)):
-        return [to_py_obj(o) for o in obj]
-
     # This gives us a smart order to test the frameworks with the corresponding tests.
     framework_to_test_func = _get_frameworks_and_test_func(obj)
     for framework, test_func in framework_to_test_func.items():
diff --git a/tests/utils/test_generic.py b/tests/utils/test_generic.py
index 287887038ab4..3eed30c16e3a 100644
--- a/tests/utils/test_generic.py
+++ b/tests/utils/test_generic.py
@@ -28,6 +28,7 @@
     is_torch_available,
     reshape,
     squeeze,
+    to_py_obj,
     transpose,
 )
 
@@ -201,6 +202,77 @@ def test_expand_dims_flax(self):
         t = jnp.array(x)
         self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1))))
 
+    def test_to_py_obj_native(self):
+        self.assertTrue(to_py_obj(1) == 1)
+        self.assertTrue(to_py_obj([1, 2, 3]) == [1, 2, 3])
+        self.assertTrue(to_py_obj([((1.0, 1.1), 1.2), (2, 3)]) == [[[1.0, 1.1], 1.2], [2, 3]])
+
+    def test_to_py_obj_numpy(self):
+        x1 = [[1, 2, 3], [4, 5, 6]]
+        t1 = np.array(x1)
+        self.assertTrue(to_py_obj(t1) == x1)
+
+        x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
+        t2 = np.array(x2)
+        self.assertTrue(to_py_obj(t2) == x2)
+
+        self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
+
+    @require_torch
+    def test_to_py_obj_torch(self):
+        x1 = [[1, 2, 3], [4, 5, 6]]
+        t1 = torch.tensor(x1)
+        self.assertTrue(to_py_obj(t1) == x1)
+
+        x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
+        t2 = torch.tensor(x2)
+        self.assertTrue(to_py_obj(t2) == x2)
+
+        self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
+
+    @require_tf
+    def test_to_py_obj_tf(self):
+        x1 = [[1, 2, 3], [4, 5, 6]]
+        t1 = tf.constant(x1)
+        self.assertTrue(to_py_obj(t1) == x1)
+
+        x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
+        t2 = tf.constant(x2)
+        self.assertTrue(to_py_obj(t2) == x2)
+
+        self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
+
+    @require_flax
+    def test_to_py_obj_flax(self):
+        x1 = [[1, 2, 3], [4, 5, 6]]
+        t1 = jnp.array(x1)
+        self.assertTrue(to_py_obj(t1) == x1)
+
+        x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
+        t2 = jnp.array(x2)
+        self.assertTrue(to_py_obj(t2) == x2)
+
+        self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
+
+    @require_torch
+    @require_tf
+    @require_flax
+    def test_to_py_obj_mixed(self):
+        x1 = [[1], [2]]
+        t1 = np.array(x1)
+
+        x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
+        t2 = torch.tensor(x2)
+
+        x3 = [1, 2, 3]
+        t3 = tf.constant(x3)
+
+        x4 = [[[1.0, 2.0]]]
+        t4 = jnp.array(x4)
+
+        mixed = [(t1, t2), (t3, t4)]
+        self.assertTrue(to_py_obj(mixed) == [[x1, x2], [x3, x4]])
+
 
 class ValidationDecoratorTester(unittest.TestCase):
     def test_cases_no_warning(self):