Spaces:
Sleeping
Sleeping
Commit
·
962c25c
1
Parent(s):
4d5aec3
Add test for mod mapping in torch
Browse files- test/test_torch.py +17 -0
test/test_torch.py
CHANGED
|
@@ -51,3 +51,20 @@ class TestTorch(unittest.TestCase):
|
|
| 51 |
np.square(np.cos(X[:, 1])), # Selection 1st feature
|
| 52 |
decimal=4,
|
| 53 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
np.square(np.cos(X[:, 1])), # Selection 1st feature
|
| 52 |
decimal=4,
|
| 53 |
)
|
| 54 |
+
|
| 55 |
+
def test_mod_mapping(self):
|
| 56 |
+
x, y, z = sympy.symbols("x y z")
|
| 57 |
+
expression = x ** 2 + sympy.atanh(sympy.Mod(y + 1, 2) - 1) * 3.2 * z
|
| 58 |
+
|
| 59 |
+
module = sympy2torch(expression, [x, y, z])
|
| 60 |
+
|
| 61 |
+
X = torch.rand(100, 3).float() * 10
|
| 62 |
+
|
| 63 |
+
true_out = (
|
| 64 |
+
X[:, 0] ** 2 + torch.atanh(torch.fmod(X[:, 1] + 1, 2) - 1) * 3.2 * X[:, 2]
|
| 65 |
+
)
|
| 66 |
+
torch_out = module(X)
|
| 67 |
+
|
| 68 |
+
np.testing.assert_array_almost_equal(
|
| 69 |
+
true_out.detach(), torch_out.detach(), decimal=4
|
| 70 |
+
)
|